Package pytorch_numba_extension_jit

This package simplifies the use of Numba-CUDA kernels within projects that use the PyTorch deep learning framework.

When a function written in the style of a Numba-CUDA kernel is annotated with type hints from this package, jit() can generate PyTorch custom operator bindings that allow the kernel to be used within a traced (e.g. torch.compile) environment. Furthermore, by setting to_extension=True, the kernel can also be compiled to PTX, and C++ code can be generated to invoke it with minimal overhead.

As a toy example, consider the task of creating a copy of a 1D array:

>>> import torch
>>> from numba import cuda
>>> import pytorch_numba_extension_jit as pnex
>>> @pnex.jit(n_threads="a.numel()")
... def copy(
...     a: pnex.In(dtype="f32", shape=(None,)),
...     result: pnex.Out(dtype="f32", shape="a"),
... ):
...     x = cuda.grid(1)
...     if x < a.shape[0]:
...         result[x] = a[x]
>>> A = torch.arange(5, dtype=torch.float32, device="cuda")
>>> copy(A)
tensor([0., 1., 2., 3., 4.], device='cuda:0')

For more examples of usage, see jit() and the examples directory of the project.

See also the GitHub page and the PyPI page.

Note: Correct CUDA toolkit versions

When this package is installed via pip, versions of nvidia-cuda-nvcc and nvidia-cuda-runtime will likely be installed alongside it. However, depending on your environment, these versions may not match the CUDA version supported by your driver. As such, if you experience issues during compilation (especially if you see the error cuModuleLoadData(&cuModule, ptx) failed with error CUDA_ERROR_UNSUPPORTED_PTX_VERSION), then it may be worth verifying your installation. This can be done by running nvidia-smi to find your CUDA version, and then pip list to find the currently installed versions of the relevant NVIDIA libraries. The versions of these libraries should begin with your CUDA version, e.g. for CUDA 12.8 the expected output might look like:

$ pip list | grep nvidia-cuda-
nvidia-cuda-cupti-cu12    12.8.57
nvidia-cuda-nvcc-cu12     12.8.61
nvidia-cuda-nvrtc-cu12    12.8.61
nvidia-cuda-runtime-cu12  12.8.57

Entrypoint

def jit(*,
    n_threads: str | tuple[str, str] | tuple[str, str, str],
    to_extension: bool = False,
    cache_id: str = None,
    verbose: bool = False,
    threads_per_block: int | tuple[int, int] | tuple[int, int, int] = None,
    max_registers: int = None,
) -> Callable[[Callable[..., None]], torch._library.custom_ops.CustomOpDef]

Compile a Python function in the form of a Numba-CUDA kernel to a PyTorch operator

All parameters must be annotated with one of the argument types exported by this module, and the resulting operator will take In/InMut/Scalar parameters as arguments, while returning Out parameters.

The keyword-only argument n_threads must be specified to indicate how many threads the resulting kernel should be launched with. The dimensionality of n_threads determines the dimensionality of the launched kernel, while threads_per_block controls the size of each block.

With to_extension=True, this function will also compile the PTX generated by Numba to a PyTorch native C++ extension, thereby reducing the overhead per call. If the resulting compilation times (several seconds on the first compilation, cached afterwards) are not acceptable, this additional compilation step can be skipped with to_extension=False.

Parameters

n_threads : str, tuple[str, str], tuple[str, str, str]

Expression(s) that evaluate to the total number of threads that the kernel should be launched with. Thread axes are filled in the order X, Y, Z: as such, passing only a single string n_threads is equivalent to passing (n_threads, 1, 1), with only the X thread-dimension being non-unit.

In practice, this number is then divided by threads_per_block and rounded up to get the number of blocks for a single kernel invocation (blocks per grid).
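
As an illustrative sketch of this arithmetic (mirroring the description above, not the package's actual internals), the block count is a ceiling division:

>>> def blocks_per_grid(n_threads, threads_per_block):
...     return -(-n_threads // threads_per_block)  # ceiling division
>>> blocks_per_grid(1000, 256)  # 1000 threads at the default 1D block size of 256
4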

to_extension : bool = False

Whether the function should be compiled to a PyTorch C++ extension or instead be left as a wrapped Numba-CUDA kernel. The signature of the returned function is identical in both cases, but compiling an extension can take 5+ seconds, while not compiling an extension incurs a small runtime overhead on every call.

For neural networks, it is best to keep to_extension as False and use CUDA Graphs via torch.compile(model, mode="reduce-overhead", fullgraph=True) to eliminate the wrapper code. If this is not possible (due to highly dynamic code or irregular shapes), then the next best option would be to use to_extension and minimise call overhead.
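
As a minimal sketch of this approach, a hypothetical module wrapping the copy operator from the example at the top of this page (with A being the tensor defined there) could be captured as follows:

>>> class CopyModule(torch.nn.Module):
...     def forward(self, a):
...         return copy(a)  # the pnex-compiled operator defined earlier
>>> model = torch.compile(CopyModule().cuda(), mode="reduce-overhead", fullgraph=True)
>>> out = model(A)  # after warm-up, CUDA Graphs replay eliminates the wrapper code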

cache_id : str, optional

The name to save the compiled extension under: clashing cache_ids will result in recompilations (clashing functions will evict each other from the cache), but not miscompilations (the results will still be correct).

Only used when to_extension=True
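
For example (a hedged sketch; the cache_id value is arbitrary):

>>> @pnex.jit(n_threads="a.numel()", to_extension=True, cache_id="copy_f32_v1")
... def copy_ext(
...     a: pnex.In("f32", (None,)),
...     result: pnex.Out("f32", "a"),
... ):
...     x = cuda.grid(1)
...     if x < a.shape[0]:
...         result[x] = a[x]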

Returns

decorator : (kernel) -> torch.library.CustomOpDef

The resulting decorator will transform a Python function (if properly annotated, and the function is a valid Numba-CUDA kernel) into a CustomOpDef, where the signature is such that all parameters annotated with In, InMut or Scalar must be provided as arguments, and all Out parameters are returned.

All parameters must be annotated with one of In, InMut, Out, Scalar or Unused

Other Parameters

verbose : bool = False
Whether to print additional information about the compilation process. Compilation errors are always printed.
threads_per_block : int, tuple[int, int], tuple[int, int, int] = None

The number of threads within a thread block across the various dimensions.

Depending on the dimensionality of n_threads, this defaults to one of:

  • For 1 dimension: 256
  • For 2 dimensions: (16, 16)
  • For 3 dimensions: (8, 8, 4)
max_registers : int, optional
Specify the maximum number of registers to be used by the kernel, with excess spilling over to local memory. Typically, the compiler is quite good at guessing the number of registers it should use, but limiting this to hit occupancy targets may help in some cases. This option is only available with to_extension=False, due to the structure of the Numba-CUDA API.
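
As a hedged sketch of these two options (the kernel itself is illustrative), a 2D launch with a non-default block size and a register cap might look like:

>>> @pnex.jit(
...     n_threads=("a.size(1)", "a.size(0)"),
...     threads_per_block=(32, 8),  # overrides the 2D default of (16, 16)
...     max_registers=32,  # requires to_extension=False, which is the default
... )
... def double_2d(
...     a: pnex.In("f32", (None, None)),
...     result: pnex.Out("f32", "a"),
... ):
...     x, y = cuda.grid(2)
...     if y < a.shape[0] and x < a.shape[1]:
...         result[y, x] = a[y, x] * 2.0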

Examples

This is an example implementation of the mymuladd function from the PyTorch Custom C++ and CUDA Operators documentation, where we take 2D inputs instead of flattening. A variety of methods for specifying dtype and shape are used in this example, but sticking to one convention may be better for readability.

>>> import torch
>>> from numba import cuda
>>> import pytorch_numba_extension_jit as pnex
>>> # Can be invoked as mymuladd_2d(A, B, C) to return RESULT
... @pnex.jit(n_threads="result.numel()")
... def mymuladd_2d(
...     a: pnex.In(torch.float32, (None, None)),
...     b: pnex.In("f32", ("a.size(0)", "a.size(1)")),
...     c: float,  # : pnex.Scalar(float)
...     result: pnex.Out("float32", "a"),
... ):
...     idx = cuda.grid(1)
...     y, x = divmod(idx, result.shape[1])  # row-major: idx == y * n_cols + x
...     if y < result.shape[0]:
...         result[y, x] = a[y, x] * b[y, x] + c

Here, we can see an alternate version that uses multidimensional blocks to achieve the same task, while compiling the result to a C++ operator using to_extension. Note that the n_threads argument is given sizes in the X, Y, Z order (consistent with C++ CUDA kernels), and that numba.cuda.grid also returns indices in this order, even if we might later use indices in e.g. y, x order.

>>> @pnex.jit(n_threads=("result.size(1)", "result.size(0)"), to_extension=True)
... def mymuladd_grid(
...     a: pnex.In("f32", (None, None)),
...     b: pnex.In("f32", ("a.size(0)", "a.size(1)")),
...     c: float,
...     result: pnex.Out("f32", "a"),
... ):
...     # always use this order for names to be consistent with CUDA terminology:
...     x, y = cuda.grid(2)
...
...     if y < result.shape[0] and x < result.shape[1]:
...         result[y, x] = a[y, x] * b[y, x] + c

Notes

This function relies heavily on internals and undocumented behaviour of the Numba-CUDA PTX compiler. However, these internals have not changed in over 3 years, so it is reasonable to assume they will remain similar in future versions as well. Versions 0.9.0 and 0.10.0 of Numba-CUDA have been verified to work as expected.

Additionally, it should be noted that storing the function and compiling it later from a different stack frame may cause issues if some annotations use local variables and the module is using from __future__ import annotations. This is because annotations are not considered part of the function proper, so no closure cell is created for the variables they reference when the function is constructed. Using jit() directly with the decorator syntax @pnex.jit(n_threads=...) has no such problems; alternatively, from __future__ import annotations can be omitted in the file where the function to be compiled is defined.
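
As an illustrative sketch of this pitfall (the names here are hypothetical), consider a module with postponed annotations where an annotation refers to a local variable:

>>> # from __future__ import annotations  # assumed active in this module
>>> def make_copy_kernel():
...     WIDTH = 128  # local variable referenced only from an annotation
...     def copy_fixed(
...         a: pnex.In("f32", (WIDTH,)),
...         result: pnex.Out("f32", "a"),
...     ):
...         x = cuda.grid(1)
...         if x < a.shape[0]:
...             result[x] = a[x]
...     return copy_fixed
>>> # Compiling outside make_copy_kernel's frame, the stringified annotation
>>> # can no longer resolve WIDTH, so this call may fail with a NameError:
>>> copy_fixed = pnex.jit(n_threads="a.numel()")(make_copy_kernel())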

See Also

numba.cuda.compile_for_current_device
used to compile the Python function into PTX: all functions must therefore also be valid numba.cuda kernels.
numba.cuda.jit
used instead when to_extension=False
torch.utils.cpp_extension.load_inline
used to compile the PyTorch C++ extension

Argument types

class In (dtype: torch.dtype | numpy.dtype | str,
          shape: tuple[int | str | None, ...] | str)

A type annotation for immutable input tensor parameters in a jit() function

An input tensor is part of the argument list in the final operator, meaning it must be provided by the caller. This variant is immutable, meaning the kernel must not modify the tensor.

To use this annotation, use the syntax param: In(dtype, shape).

Parameters

dtype : torch.dtype, np.dtype, str

The data type of the input tensor.

Some equivalent examples: torch.float32, float, "float32" or "f32"

shape : str, tuple of (int or str or None)

The shape of the input tensor.

If shape is a string, it must be the name of a previously defined tensor parameter, and the shape of this parameter must be equal to the shape of the parameter named by shape.

If shape is a tuple, every element in the tuple corresponds with one axis in the input tensor. For every such element:

  • int constrains the axis to be exactly of the given dimension.
  • str represents an expression that evaluates to an integer, and constrains the axis to be equal to the result of the expression. If the name of a tensor parameter is provided, this is equivalent to param_name.shape[nth_dim] where nth_dim is the index of the current axis.
  • None does not constrain the size of the axis.
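
For instance (an illustrative annotation, not taken from the package's own examples), the following combines all three element kinds: the first axis must be exactly 3, the second must match a.size(1), and the third is unconstrained.

>>> b_annotation = pnex.In("f32", (3, "a.size(1)", None))
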
class InMut (dtype: torch.dtype | numpy.dtype | str,
             shape: tuple[int | str | None, ...] | str)

A type annotation for mutable input tensor parameters in a jit() function

An input tensor is part of the argument list in the final operator, meaning it must be provided by the caller. This variant is mutable, meaning the kernel may modify the tensor.

To use this annotation, use the syntax param: InMut(dtype, shape).

For information on the parameters, see In.
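
As a hedged sketch (assuming an operator with no Out parameters is permitted), an in-place update would be written as:

>>> @pnex.jit(n_threads="a.numel()")
... def add_one_(
...     a: pnex.InMut("f32", (None,)),  # modified in place by the kernel
... ):
...     x = cuda.grid(1)
...     if x < a.shape[0]:
...         a[x] = a[x] + 1.0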

class Out (dtype: torch.dtype | numpy.dtype | str,
           shape: tuple[int | str, ...] | str,
           init: float = None)

A type annotation for output tensor parameters in a jit() function

An output tensor is not part of the argument list in the final operator, meaning the caller must not attempt to provide it. Instead, parameters marked as Out are created by the wrapper code before being passed to the kernel, and are returned to the caller afterwards as return values from the final operator. Since parameters marked Out are returned, they can receive a gradient and can work with the PyTorch autograd system.

To use this annotation, use the syntax param: Out(dtype, shape[, init=init]).

Parameters

dtype : torch.dtype, np.dtype, str

The data type of the output tensor.

Some equivalent examples: torch.float32, float, "float32" or "f32"

shape : str, tuple of (int or str)

The shape of the output tensor.

If shape is a string, it must be the name of a previously defined tensor parameter, and this tensor will be constructed to have the same shape as the parameter named by shape

If shape is a tuple, every element in the tuple corresponds with one axis in the output tensor. For every such element:

  • int sets the size to be exactly the provided value.
  • str represents an expression that evaluates to an integer, and sets the size of the axis to be equal to the result of the expression. If the name of a tensor parameter is provided, this is equivalent to param_name.shape[nth_dim] where nth_dim is the index of the current axis.
init : float or int, optional

The initial value used to fill the output tensor. If not provided, the output tensor will contain uninitialised memory (in the style of torch.empty).

Example: gradient tensors for the backward pass should be initialised with 0.
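
As a hedged sketch of init (an illustrative backward-style kernel, not taken from the package's examples), a zero-initialised gradient buffer accumulated with atomics might look like:

>>> @pnex.jit(n_threads="grad_out.numel()")
... def copy_backward(
...     grad_out: pnex.In("f32", (None,)),
...     grad_a: pnex.Out("f32", "grad_out", init=0),  # must start at zero
... ):
...     x = cuda.grid(1)
...     if x < grad_out.shape[0]:
...         cuda.atomic.add(grad_a, x, grad_out[x])  # accumulates into the zeroed buffer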

class Scalar (dtype: torch.dtype | numpy.dtype | str)

A type annotation for scalar input parameters in a jit() function

A scalar input is part of the argument list in the final operator, meaning the caller must provide it. It is not returned: for scalar outputs, use Out(dtype, (1,)) instead.

To use this annotation, use the syntax param: Scalar(dtype), or the shorthand param: dtype.

Parameters

dtype : torch.dtype, np.dtype, str

The data type of the scalar.

Some equivalent examples: torch.float32, float, "float32" or "f32"
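
For instance (an illustrative kernel), the explicit and shorthand spellings below are equivalent:

>>> @pnex.jit(n_threads="a.numel()")
... def scale_shift(
...     a: pnex.In("f32", (None,)),
...     scale: pnex.Scalar(float),  # explicit form
...     shift: float,  # shorthand for pnex.Scalar(float)
...     result: pnex.Out("f32", "a"),
... ):
...     x = cuda.grid(1)
...     if x < a.shape[0]:
...         result[x] = a[x] * scale + shift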

class Unused

A type annotation for ignored parameters in a jit() function

This is a utility class for marking certain parameters to be skipped during compilation. An example of this would be a kernel which can optionally return an additional output (such as provenance indices for a maximum operation), allowing this output to be skipped programmatically.

Note that all array accesses of a parameter marked Unused must be statically determined to be dead code (e.g. if False), as compilation will otherwise fail.

To use this annotation, use e.g. param: Out(...) if condition else Unused
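
As a hedged sketch (assuming a module-level constant is sufficient for the unused accesses to be pruned as dead code, and that "i32" is an accepted dtype string), an optional index output could be skipped like this:

>>> RETURN_INDICES = False  # hypothetical compile-time switch
>>> @pnex.jit(n_threads="a.size(0)")
... def row_max(
...     a: pnex.In("f32", (None, None)),
...     result: pnex.Out("f32", ("a.size(0)",)),
...     idx: pnex.Out("i32", ("a.size(0)",)) if RETURN_INDICES else pnex.Unused,
... ):
...     y = cuda.grid(1)
...     if y < a.shape[0]:
...         best = a[y, 0]
...         best_x = 0
...         for x in range(1, a.shape[1]):
...             if a[y, x] > best:
...                 best = a[y, x]
...                 best_x = x
...         result[y] = best
...         if RETURN_INDICES:  # statically False, so the idx access is dead code
...             idx[y] = best_x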