# Building and distributing a (non-trivial) pytorch extension -- The story so far
## 'Official' ways of building pytorch extensions
* [JIT loading mechanism](https://pytorch.org/tutorials/advanced/cpp_extension.html#jit-compiling-extensions)
```python
from torch.utils.cpp_extension import load
lltm_cpp = load(name="lltm_cpp", sources=["lltm.cpp"])
```
- :heavy_plus_sign: ensures correct pytorch build flags
- :heavy_plus_sign: can distribute pure python code
- :heavy_minus_sign: requires build tools on the end users' side
- :heavy_minus_sign: need to handle C++ side dependencies manually (see the sketch below)
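For the last point, "manually" means passing include paths and linker flags yourself. A minimal sketch using `load()`'s optional arguments; the dependency name and all paths are made up:
```python
from torch.utils.cpp_extension import load

# Hypothetical C++ dependency, handled by hand: point the JIT build
# at its headers and library. All paths and names are placeholders.
lltm_cpp = load(
    name="lltm_cpp",
    sources=["lltm.cpp"],
    extra_include_paths=["/opt/mydep/include"],
    extra_ldflags=["-L/opt/mydep/lib", "-lmydep"],
    verbose=True,  # log the generated compiler invocations
)
```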
* [setuptools-based building](https://pytorch.org/tutorials/advanced/cpp_extension.html#building-with-setuptools)
```python
from setuptools import setup, Extension
from torch.utils import cpp_extension

setup(
    name='lltm_cpp',
    ext_modules=[cpp_extension.CppExtension('lltm_cpp', ['lltm.cpp'])],
    cmdclass={'build_ext': cpp_extension.BuildExtension},
)
```
- :heavy_plus_sign: ensures correct pytorch build flags
- :heavy_plus_sign: can be built ahead of time
- :heavy_minus_sign: need to handle C++ side dependencies manually
## Binding Code
Similar to [pybind11](https://github.com/pybind/pybind11) :gem:
```cpp
#include <torch/extension.h>

TORCH_LIBRARY(torch_sparse_ops, m) {
  m.def("ffi_forward(Tensor features, Tensor weights, Tensor locations, Tensor? bias) -> Tensor");
}

TORCH_LIBRARY_IMPL(torch_sparse_ops, CPU, m) {
  m.impl("ffi_forward", &SparseForwardOp<float, xck::ForwardImplementations::CPU, false>::dispatch);
}

TORCH_LIBRARY_IMPL(torch_sparse_ops, CUDA, m) {
  m.impl("ffi_forward", &SparseForwardOp<float, xck::ForwardImplementations::GPU_Unit_FanBatch_Vector, false>::dispatch);
}
```
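Once the resulting shared library is loaded into the process, the op appears under `torch.ops.torch_sparse_ops` and dispatches by input device. A minimal usage sketch; the library path and tensor shapes are made up:
```python
import torch

# loading the shared library runs the TORCH_LIBRARY registration code;
# the path is a placeholder for wherever the build puts the .so
torch.ops.load_library("build/libtorch_sparse_ops.so")

features = torch.randn(8, 16)
weights = torch.randn(16, 4)
locations = torch.randint(0, 16, (8, 4))

# runs the CPU implementation; moving the inputs to "cuda" would
# select the CUDA kernel instead
out = torch.ops.torch_sparse_ops.ffi_forward(features, weights, locations, None)
```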
## [Scikit-build](https://scikit-build.readthedocs.io/en/latest/)
Python packaging for CMake-based projects. :gem:?
Setup file with some customizations (`setup.py`):
```python
from skbuild import setup

pkgdir_mapping = {'torch_sparse': 'src/python'}


def should_exclude(name: str):
    # keep static libraries, cmake config files, and headers out of the wheel
    return (name.endswith('.a') or name.endswith('.cmake')
            or name.endswith('.hpp') or 'include/experimental' in name)


def exclude_files(cmake_manifest):
    return [name for name in cmake_manifest if not should_exclude(name)]


setup(
    # ... regular setup args
    cmake_args=["-DCMAKE_BUILD_TYPE=RelWithDebInfo"],
    packages=list(pkgdir_mapping.keys()),
    package_dir=pkgdir_mapping,
    cmake_process_manifest_hook=exclude_files,
    cmake_languages=("C", "CXX", "CUDA"),
)
```
Create the extension library and make it available to the python package:
```cmake
add_library(torch_sparse_ops SHARED ...)
# ...
install(TARGETS torch_sparse_ops LIBRARY DESTINATION src/python)
```
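Since `src/python` maps to the `torch_sparse` package, the installed `.so` ends up inside the package directory. A sketch, under that assumption, of an `__init__.py` that registers the library with torch at import time (the exact file name is platform-dependent and made up here):
```python
# src/python/__init__.py (sketch)
import os

import torch

# the extension library was installed next to this file by the
# `install(TARGETS ...)` rule above; the file name is an assumption
_lib_path = os.path.join(os.path.dirname(__file__), "libtorch_sparse_ops.so")
torch.ops.load_library(_lib_path)
```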
Modern way of specifying package metadata and build requirements, `pyproject.toml`:
```toml
[build-system]
requires = [
    "setuptools>=42",
    "scikit-build>=0.13",
    "cmake>=3.20",
    "numpy",
    "pybind11",
    "ninja",
    "torch==2.0.1",
]
build-backend = "setuptools.build_meta"
```
Notable: the user does not need `cmake` or `ninja` installed, even when building from source; they are installed into the temporary build virtualenv.
- :heavy_plus_sign: cmake handles C++ dependencies
- :heavy_plus_sign: can be built ahead of time
- :heavy_plus_sign: can build just the C++ part
- :heavy_minus_sign: need to handle pytorch compile flags
- :page_facing_up::scissors: built extension may depend on the exact (major) version of pytorch (see the sketch below)
- :page_facing_up::scissors: requires CUDA development libraries to be present
  ([pip packages](https://pypi.org/project/nvidia-cuda-runtime-cu11/) :gem: exist for the _runtime_ library)
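One way to surface the pytorch version dependency early, instead of through cryptic import errors: a hypothetical runtime guard in the package. The version string and the major.minor matching policy are assumptions, not what the project currently does:
```python
import torch

# hypothetical guard: the wheel was built against a specific torch release,
# so fail with a clear message if a different major.minor version is loaded
_BUILT_AGAINST = "2.0.1"

_running = torch.__version__.split("+")[0]  # drop local tags like "+cu117"
if _running.rsplit(".", 1)[0] != _BUILT_AGAINST.rsplit(".", 1)[0]:
    raise ImportError(
        f"torch_sparse was built against torch {_BUILT_AGAINST}, "
        f"but torch {torch.__version__} is installed"
    )
```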
## Adding pytorch to a CMake project
* Query pytorch for cmake package path:
```cmake
find_package(Python3 COMPONENTS Interpreter REQUIRED)
execute_process(
    COMMAND "${Python3_EXECUTABLE}" "-c" "import torch;print(torch.utils.cmake_prefix_path)"
    OUTPUT_VARIABLE PT_CMAKE_PREFIX
    COMMAND_ECHO STDOUT
    OUTPUT_STRIP_TRAILING_WHITESPACE
    COMMAND_ERROR_IS_FATAL ANY
)
set(CMAKE_PREFIX_PATH ${CMAKE_PREFIX_PATH};${PT_CMAKE_PREFIX})
find_package(Torch REQUIRED CONFIG)
```
and then `target_link_libraries(torch_sparse_ops PUBLIC torch)`
* :page_facing_up::scissors: skbuild caches the build directory, but the build virtualenv is re-created for every build -> cached paths into the old virtualenv break
```cmake
if(SKBUILD)
    message(STATUS "Building using SKBUILD: Resetting torch directories")
    unset(C10_CUDA_LIBRARY CACHE)
    unset(TORCH_LIBRARY CACHE)
endif()
```
* :page_facing_up::scissors: The above is sufficient for pure C++ projects (?), but does not _link_ (in the cmake sense) the required python parts
```cmake
find_package(Python3 REQUIRED Development)
target_link_libraries(torch_sparse_ops PUBLIC Python3::Python)
```
* :page_facing_up::scissors: Extension cannot be loaded because of missing symbols: the `torch` target links `libtorch.so`, `libtorch_cpu.so`, and `libtorch_cuda.so`, but not `libtorch_python.so`
```cmake
# libtorch_python.so isn't linked by the default `torch` target, but it
# provides the pybind11 type caster for at::Tensor, i.e.
# _ZN8pybind116detail11type_casterIN2at6TensorEvE4loadENS_6handleEb
cmake_path(REPLACE_FILENAME TORCH_LIBRARY libtorch_python.so OUTPUT_VARIABLE LIBTORCH_PYTHON)
target_link_libraries(torch INTERFACE ${LIBTORCH_PYTHON})
```
Can use the full power of CMake for the build process:
```cmake
cmake_minimum_required(VERSION 3.20)
# we need 3.20 for the `CUDAARCHS` environment variable
project(XMC-Kernels-PyTorch LANGUAGES CXX CUDA)

include(cmake/pytorch.cmake)
find_package(Python3 REQUIRED Development)
find_package(OpenMP)

set(CMAKE_CXX_STANDARD 17)

set(KERNELS_REPO_URL "https://version.aalto.fi/gitlab/xmc/xmc-kernels.git"
    CACHE STRING "URL to the git repository containing the XMC kernels")

include(FetchContent)
FetchContent_Declare(
    xmc-kernels
    GIT_REPOSITORY ${KERNELS_REPO_URL}
    GIT_TAG master
    GIT_SHALLOW TRUE
)
set(XCK_BUILD_TESTS OFF)
FetchContent_MakeAvailable(xmc-kernels)

add_library(torch_sparse_ops SHARED src/cc/sparse.cpp src/cc/aten_device.cpp)
target_link_libraries(torch_sparse_ops PUBLIC torch
                                              implement-all-kernels
                                              Python3::Python)
target_compile_definitions(torch_sparse_ops PUBLIC
    gsl_CONFIG_CONTRACT_VIOLATION_THROWS
    gsl_CONFIG_DEVICE_CONTRACT_CHECKING_OFF)

# link OpenMP only if it was actually found
if(OpenMP_CXX_FOUND)
    target_link_libraries(torch_sparse_ops PUBLIC OpenMP::OpenMP_CXX)
    target_compile_definitions(torch_sparse_ops PUBLIC ATEN_THREADING=OMP)
endif()

install(TARGETS torch_sparse_ops LIBRARY DESTINATION src/python)
```
* :gem: [FetchContent](https://cmake.org/cmake/help/latest/module/FetchContent.html)
* :gem:? [pytorch.cmake](https://version.aalto.fi/gitlab/AaltoRSE/xmc-sparse-pytorch/-/blob/master/cmake/pytorch.cmake)
## Building and running locally
* cmake based build process:
  `cmake -S . -B build && cmake --build build`
* skbuild/python:
  `pip install .`
* wheel:
  `pip wheel --no-deps .`

Note: `pip install .` takes longer than the plain `cmake` build, as it first sets up a _new_ virtualenv and installs the build dependencies.
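A quick smoke test after `pip install .`, assuming the package `__init__.py` loads the ops library as sketched earlier:
```python
import torch
import torch_sparse  # noqa: F401 -- importing registers the custom ops

# if the extension library was found and loaded, the bound op is available
print(torch.ops.torch_sparse_ops.ffi_forward)
```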
## Making things portable(ish)
:page_facing_up::scissors: Chances are, if you try to use the `wheel` file on a different machine, there will be problems
### CUDA compute capability
* Default: compile for the locally available GPUs
* The library loads and works on a different computer, until you actually try to call a GPU kernel
* => explicitly tell cmake to build for more architectures via [CUDAARCHS](https://cmake.org/cmake/help/latest/envvar/CUDAARCHS.html) (a helper for picking values is sketched after the workaround below)
:page_facing_up::scissors: `find_package(Torch)` somehow resets the CUDA architecture options :interrobang:
:wrench::
```cmake
# cache CUDA_ARCHITECTURES, which seems to be reset by Torch
set(TMP_STORE_CUDA_ARCHITECTURES "${CMAKE_CUDA_ARCHITECTURES}")
set(CMAKE_PREFIX_PATH ${CMAKE_PREFIX_PATH};${PT_CMAKE_PREFIX})
find_package(Torch REQUIRED CONFIG)
set(CMAKE_CUDA_ARCHITECTURES ${TMP_STORE_CUDA_ARCHITECTURES})
```
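To decide which values to put into `CUDAARCHS`, one can query the GPUs visible to a local torch installation; a small helper sketch:
```python
import torch

# list the compute capabilities of all visible GPUs in the ";"-separated
# format that the CUDAARCHS environment variable expects, e.g. "70;80"
caps = {
    "%d%d" % torch.cuda.get_device_capability(i)
    for i in range(torch.cuda.device_count())
}
print(";".join(sorted(caps)))
```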
### glibc and manylinux
Building on a recent linux (e.g. an Aalto desktop) links against recent C standard library versions, with symbols that are not available on older systems (triton; CSC clusters).
[PEP 600](https://peps.python.org/pep-0600/) defines the `manylinux` standards
to ensure basic portability: maximum allowed versions of some foundational libraries.
* :gem: [auditwheel](https://github.com/pypa/auditwheel) detects the required symbol versions and patches the wheel to include all non-standard `.so`s
* First attempt: the centos7 image (supposedly corresponding to `manylinux2014`) from docker [nvidia/cuda](https://hub.docker.com/r/nvidia/cuda/)
  Does not work: the `patchelf` version that comes with CentOS 7 is too old for `auditwheel` :interrobang:
* The official manylinux [docker images](https://github.com/pypa/manylinux) :gem:? do not come with CUDA already set up
* Finally: pytorch's own build/CI docker images are publicly available, e.g. `pytorch/manylinux-builder:cuda11.7`
### Building in `pytorch/manylinux-builder`
Finding the correct python versions: the interpreters live under
`/opt/python/${PYTHON}-${PYTHON}` (e.g. `/opt/python/cp310-cp310`)
* :page_facing_up::scissors: Finding the correct CUDA version -- by default, the cuda 11.7 image uses CUDA 11.2
* CMake can figure things out: `export CUDACXX=/usr/local/cuda-11.7/bin/nvcc`
* :page_facing_up::scissors: the pre-installed `auditwheel` does not work with python310/311
* => install it *after* setting up the python version
```bash
export PIP=${PYPATH}/bin/pip
${PIP} install auditwheel
export AUDITWHEEL="${PYPATH}/bin/python -m auditwheel"
```
* :page_facing_up::scissors: Don't include the entire (wrong!) CUDA toolkit in the wheel
```bash
${AUDITWHEEL} repair --plat manylinux2014_x86_64 \
    --exclude libcudart.so.10.2 \
    --exclude libtorch.so --exclude libtorch_cpu.so \
    --exclude libtorch_cuda.so --exclude libc10.so \
    *.whl
```
Full CI script:
```yaml
- export CUDAARCHS="${CUDAARCHS}"
- export CUDACXX=/usr/local/cuda-11.7/bin/nvcc
- export PYPATH=/opt/python/${PYTHON}-${PYTHON}
- export PIP=${PYPATH}/bin/pip
- export AUDITWHEEL="${PYPATH}/bin/python -m auditwheel"
- export XCK_PT_DEV_SUFFIX=${CI_PIPELINE_IID}
# print versions so we have them logged
- g++ --version
- ${CUDACXX} --version
- ${PIP} --version
# the system-installed auditwheel version doesn't work for py310/py311, as it wants to include libpython into the wheel
# and cannot find it in the search paths. With this fresh install, libpython does not get included in the wheel.
- ${PIP} install auditwheel
# build the wheel
- ${PIP} wheel --progress-bar off --no-deps .
# make it manylinux
- ${AUDITWHEEL} show *.whl
# auditwheel picks up the system-level cuda 10.2 here. Since we exclude it from the wheel anyway, I don't think it matters,
# but this might be a brittle assumption
- ${AUDITWHEEL} repair --exclude libcudart.so.10.2 --exclude libtorch.so --exclude libtorch_cpu.so --exclude libtorch_cuda.so --exclude libc10.so --plat manylinux2014_x86_64 *.whl
```
### Open problems
:page_facing_up::scissors: Version matching: cuda versions; pytorch versions (one idea is sketched at the end)
:page_facing_up::scissors: Pip package versioning? Current approach:
set a `dev` version based on the pipeline id
```bash
export XCK_PT_DEV_SUFFIX=${CI_PIPELINE_IID}
```
and dynamically adjust the version string in `setup.py`:
```python
version = "0.0.4"
# potentially add a development version suffix
if "XCK_PT_DEV_SUFFIX" in os.environ:
version = version + ".dev" + os.environ["XCK_PT_DEV_SUFFIX"]
```