Package pytorch_nd_semiconv
This package offers a generalisation of linear convolutions and max-poolings using semifields and for an arbitrary number of spatial dimensions, with many common choices for semifields being available.
Additionally, this package provides example implementations of 2D isotropic and
anisotropic quadratic kernels for use in e.g. dilation of 2-dimensional images.
These implementations are efficient when run under torch.compile
.
As an example of anisotropic quadratic dilation, consider replacing:
pooling = torch.nn.MaxPool2d(3, stride=2, padding=1)
with:
import pytorch_nd_semiconv as semiconv
dilation = semiconv.GenericConv(
semiconv.QuadraticKernelSpectral2D(
in_channels=5, out_channels=5, kernel_size=3
),
semiconv.SelectSemifield.tropical_max().lazy_fixed(),
padding="same",
stride=2,
groups=5,
)
See also the GitHub page and the PyPi page
Classes
class GenericConv (kernel,
conv,
stride=1,
padding=0,
dilation=1,
groups=1,
group_broadcasting=False,
kind='conv')-
Expand source code Browse git
class GenericConv(nn.Module): """ A generic convolution Module using a kernel and a convolution Module Parameters ------- kernel : nn.Module A module that produces a convolutional kernel from its `forward` method. Must not take arguments. See e.g. `QuadraticKernelSpectral2D` or `LearnedKernel2D`. conv : nn.Module A module that can take `image, kernel` as positional arguments, as well as `dilation`, `padding`, `stride` and `groups` as keyword arguments, optionally supporting `group_broadcasting` and `kind`. See e.g. `BroadcastSemifield.dynamic` or `SelectSemifield.lazy_fixed`. stride : int, (int, ...) = 1 The stride passed to `conv`, either for all spatial dimensions or for each separately. padding : int, (int, ...), ((int, int), ...), "valid", "same" = 0 The padding passed to `conv`. Depending on the type of `padding`: - `P` indicates padding at the start and end of all spatial axes with `P`. - `(P0, ...)` indicates padding at the start and end of the first spatial axis with `P0`, and similarly for all other spatial axes. - `((PBeg0, PEnd0), ...)` indicates padding the start of the first spatial axis with `PBeg0` and the end with `PEnd0`, similarly for all other spatial axes. - `"valid"` indicates to only perform the convolution with valid values of the image, i.e. no padding. - `"same"` indicates to pad the input such that a stride-1 convolution would produce an output of the same spatial size. Convolutions with higher stride will use the same padding scheme, but result in outputs of reduced size. dilation : int, (int, ...) = 1 The dilation passed to `conv`, either for all spatial dimensions or for each separately. groups : int = 1 The number of convolutional groups for this convolution. group_broadcasting : bool = False Whether to take the input kernels as a single output group, and broadcast across all input groups. `group_broadcasting` has no effect when `groups=1` kind : literal "conv" or "corr" Represents whether the kernel should be mirrored during the convolution (`"conv"`) or not (`"corr"`). Examples ------- >>> import pytorch_nd_semiconv as semiconv >>> dilation = semiconv.GenericConv( ... semiconv.QuadraticKernelSpectral2D(5, 5, 3), ... semiconv.SelectSemifield.tropical_max().lazy_fixed(), ... padding="same", ... stride=2, ... groups=5, ... ) >>> root = semiconv.GenericConv( ... semiconv.QuadraticKernelIso2D(5, 10, 3), ... semiconv.BroadcastSemifield.root(3.0).dynamic(), ... padding="same", ... ) """ def __init__( self, kernel: nn.Module, conv: nn.Module, stride: int | tuple[int, ...] = 1, padding: ( int | tuple[int, ...] | tuple[tuple[int, int], ...] | Literal["valid", "same"] ) = 0, dilation: int | tuple[int, ...] = 1, groups: int = 1, group_broadcasting: bool = False, kind: Literal["conv", "corr"] = "conv", ): super().__init__() self.padding = padding self.stride = stride self.dilation = dilation self.kernel = kernel self.conv = conv self.groups = groups self.group_broadcasting = group_broadcasting self.kind = kind # Since these are custom arguments, we only want to pass them if they differ # from the default values (otherwise, they may be unexpected) self.kwargs = {} if self.group_broadcasting: self.kwargs["group_broadcasting"] = True if self.kind == "corr": self.kwargs["kind"] = "corr" def forward(self, img): """ Run a forward step with this convolution. Parameters ---------- img : Tensor (B, C, *Spatial) The input images as a 2+N tensor, of shape (Batch, Channels, ...Spatial) For example: a 2D image would typically be (Batch, Channels, Height, Width) Returns ------- out_img : Tensor (B, C', *Spatial') The output images as a 2+N tensor, with the same batch shape but possibly adjusted other dimensions. """ return self.conv( img, self.kernel(), dilation=self.dilation, padding=self.padding, stride=self.stride, groups=self.groups, **self.kwargs, ) def extra_repr(self) -> str: res = [] if self.padding: res.append(f"padding={self.padding}") if self.stride != 1: res.append(f"stride={self.stride}") if self.dilation != 1: res.append(f"dilation={self.dilation}") if self.groups != 1: res.append(f"groups={self.groups}") if self.group_broadcasting: res.append("group_broadcasting=True") if self.kind == "corr": res.append("kind=corr") return ", ".join(res)
A generic convolution Module using a kernel and a convolution Module
Parameters
kernel
:nn.Module
-
A module that produces a convolutional kernel from its
forward
method. Must not take arguments.See e.g.
QuadraticKernelSpectral2D
orLearnedKernel2D
. conv
:nn.Module
-
A module that can take
image, kernel
as positional arguments, as well asdilation
,padding
,stride
andgroups
as keyword arguments, optionally supportinggroup_broadcasting
andkind
.See e.g.
BroadcastSemifield.dynamic()
orSelectSemifield.lazy_fixed()
. stride
:int, (int, ...) = 1
- The stride passed to
conv
, either for all spatial dimensions or for each separately. padding
:int, (int, ...), ((int, int), ...), "valid", "same" = 0
-
The padding passed to
conv
. Depending on the type ofpadding
:P
indicates padding at the start and end of all spatial axes withP
.(P0, …)
indicates padding at the start and end of the first spatial axis withP0
, and similarly for all other spatial axes.((PBeg0, PEnd0), …)
indicates padding the start of the first spatial axis withPBeg0
and the end withPEnd0
, similarly for all other spatial axes."valid"
indicates to only perform the convolution with valid values of the image, i.e. no padding."same"
indicates to pad the input such that a stride-1 convolution would produce an output of the same spatial size. Convolutions with higher stride will use the same padding scheme, but result in outputs of reduced size.
dilation
:int, (int, ...) = 1
- The dilation passed to
conv
, either for all spatial dimensions or for each separately. groups
:int = 1
- The number of convolutional groups for this convolution.
group_broadcasting
:bool = False
- Whether to take the input kernels as a single output group, and broadcast
across all input groups.
group_broadcasting
has no effect whengroups=1
kind
:literal "conv"
or"corr"
- Represents whether the kernel should be mirrored during the convolution
(
"conv"
) or not ("corr"
).
Examples
>>> import pytorch_nd_semiconv as semiconv >>> dilation = semiconv.GenericConv( ... semiconv.QuadraticKernelSpectral2D(5, 5, 3), ... semiconv.SelectSemifield.tropical_max().lazy_fixed(), ... padding="same", ... stride=2, ... groups=5, ... ) >>> root = semiconv.GenericConv( ... semiconv.QuadraticKernelIso2D(5, 10, 3), ... semiconv.BroadcastSemifield.root(3.0).dynamic(), ... padding="same", ... )
Methods
def forward(self, img)
-
Expand source code Browse git
def forward(self, img): """ Run a forward step with this convolution. Parameters ---------- img : Tensor (B, C, *Spatial) The input images as a 2+N tensor, of shape (Batch, Channels, ...Spatial) For example: a 2D image would typically be (Batch, Channels, Height, Width) Returns ------- out_img : Tensor (B, C', *Spatial') The output images as a 2+N tensor, with the same batch shape but possibly adjusted other dimensions. """ return self.conv( img, self.kernel(), dilation=self.dilation, padding=self.padding, stride=self.stride, groups=self.groups, **self.kwargs, )
Run a forward step with this convolution.
Parameters
img
:Tensor (B, C, *Spatial)
-
The input images as a 2+N tensor, of shape (Batch, Channels, …Spatial)
For example: a 2D image would typically be (Batch, Channels, Height, Width)
Returns
out_img
:Tensor (B, C', *Spatial')
- The output images as a 2+N tensor, with the same batch shape but possibly adjusted other dimensions.
class BroadcastSemifield (add_reduce, multiply, zero, add_reduce_channels=None)
-
Expand source code Browse git
class BroadcastSemifield(typing.NamedTuple): r""" A semifield definition using PyTorch broadcasting operators Using a technique similar to nn.Unfold, we can create a view of the input array and apply broadcasting functions along kernel axes to perform a semifield convolution. All functions must take PyTorch Tensors, and should have a backwards implementation. This function does not use JIT components, and therefore has no compilation time (and can be run on non-CUDA devices as well). Parameters ------- add_reduce : (Tensor, tuple of ints) -> Tensor To characterise semifield summation \(\bigoplus\), this function takes a single tensor with several axes, and performs reduction with \(\oplus\) along the axes indicated in the second argument. Example: ``lambda arr, dims: torch.sum(arr, dim=dims)`` multiply : (Tensor, Tensor) -> Tensor To characterise semifield multiplication \(\otimes\), this function takes two tensors and performs a broadcasting, element-wise version of \(\otimes\). Example: ``lambda img, krn: img * krn`` zero : float The absorbing semifield zero. Other Parameters ------- add_reduce_channels : (Tensor, int) -> Tensor, optional An alternate reduction function (similar to `add_reduce`) that is applied along specifically the channel dimension. This alternate function could be e.g. addition, in a modified version of \(T_+\) (see `channels_add` parameter of `BroadcastSemifield.tropical_max`). Examples ------- \(T_+\) convolution: >>> dilation = BroadcastSemifield.tropical_max().dynamic() \(L_{-3}\) convolution: >>> log = BroadcastSemifield.log(-3.0).dynamic() For examples of how to construct a `BroadcastSemifield` manually, see the source. """ # (multiplied, dims) -> `multipled` reduced with (+) along every dim in `dims` add_reduce: Callable[[torch.Tensor, tuple[int, ...]], torch.Tensor] # (img, krn) -> `img` (x) `krn` multiply: Callable[[torch.Tensor, torch.Tensor], torch.Tensor] # forall a, b: `zero` (x) a (+) b == b zero: float # Similar to add_reduce, but only used for channel axis (so takes one dimension) add_reduce_channels: Callable[[torch.Tensor, int], torch.Tensor] = None @classmethod def tropical_max(cls, channels_add: bool = False, spread_gradient: bool = False): r""" Construct a \(T_+\) `BroadcastSemifield`. The tropical max semifield / semiring is defined as: \[(\mathbb{R}\cup \{-\infty\}, \max, +)\] Parameters ---------- channels_add : bool = False Whether to use standard addition \(+\) instead of the semifield addition \(\max\) along specifically the channel axis. spread_gradient : bool = False Whether to, in cases of multiple equal maxima, spread the gradient equally amongst all maxima. """ return cls( add_reduce=(lambda multiplied, dim: torch.amax(multiplied, dim=dim)) if spread_gradient else ( _repeated_dim( lambda multiplied, dim: torch.max(multiplied, dim=dim).values ) ), multiply=lambda img, krn: img + krn, zero=-torch.inf, add_reduce_channels=( (lambda multiplied, dim: torch.sum(multiplied, dim=dim)) if channels_add else None ), ) @classmethod def tropical_min_negated( cls, channels_add: bool = False, spread_gradient: bool = False ): r""" Construct a `BroadcastSemifield` similar to \(T_-\), where the kernel is negated The usual tropical min semifield / semiring is defined as: \[(\mathbb{R}\cup \{\infty\}, \min, +)\] This version is slightly modified: while performing erosion using \(T_-\) requires first negating the kernel, this modified semifield has \(-\) instead of \(+\) as the semifield multiplication. As such, the resulting convolution will work with non-negated kernels as inputs, making the interface more similar to the dilation in \(T_+\). Parameters ---------- channels_add : bool = False Whether to use standard addition \(+\) instead of the semifield addition \(\min\) along specifically the channel axis. spread_gradient : bool = False Whether to, in cases of multiple equal minima, spread the gradient equally amongst all minima. """ return cls( add_reduce=(lambda multiplied, dim: torch.amin(multiplied, dim=dim)) if spread_gradient else ( _repeated_dim( lambda multiplied, dim: torch.min(multiplied, dim=dim).values ) ), multiply=lambda img, krn: img - krn, zero=torch.inf, add_reduce_channels=( (lambda multiplied, dim: torch.sum(multiplied, dim=dim)) if channels_add else None ), ) @classmethod def linear(cls): r""" Construct a linear `BroadcastSemifield`. The linear field is defined as: \[(\mathbb{R}, +, \times)\] Mainly for comparison purposes: the linear convolutions offered by PyTorch use CUDNN, which is far better optimised for CUDA devices. """ return cls( add_reduce=(lambda multiplied, dim: torch.sum(multiplied, dim=dim)), multiply=lambda img, krn: img * krn, zero=0, ) @classmethod def root(cls, p: float): r""" Construct a \(R_p\) `BroadcastSemifield`. The root semifields are defined as: \[(\mathbb{R}_+, \oplus_p, \times) \textrm{ for all } p\ne0 \textrm{ where } a\oplus_p b= \sqrt[p]{a^p+b^p} \] with the semifield zero being \(0\) and the semifield one being \(1\). Parameters ---------- p : int The power to use in \(\oplus_p\). May not be zero. """ assert p != 0, f"Invalid value: {p=}" return cls( add_reduce=( lambda multiplied, dim: multiplied.pow(p).sum(dim=dim).pow(1 / p) ), multiply=lambda img, krn: img * krn, zero=float(torch.finfo(torch.float32).eps), ) @classmethod def log(cls, mu: float): r""" Construct a \(L_+\mu\) or \(L_-\mu\) `BroadcastSemifield`. The log semifields are defined as: \[(\mathbb{R}\cup \{\pm\infty\}, \oplus_\mu, +) \textrm{ for all } \mu\ne0 \textrm{ where } a\oplus_\mu b= \frac{1}{\mu}\ln(e^{\mu a}+e^{\mu b}) \] with the semifield zero being \(-\infty\) for \(\mu>0\) and \(\infty\) otherwise, and the semifield one being \(0\). Parameters ---------- mu : int The base to use in \(\oplus_mu\). May not be zero. """ assert mu != 0, f"Invalid value: {mu=}" return cls( add_reduce=( lambda multiplied, dim: torch.logsumexp(multiplied * mu, dim=dim) / mu ), multiply=lambda img, krn: img + krn, zero=-torch.inf if mu > 0 else torch.inf, ) def dynamic(self, unfold_copy: bool = False) -> torch.nn.Module: """ Create a convolution Module based on this `BroadcastSemifield`. This method is named `dynamic`, because the Module it creates will dynamically adjust itself based on new input types, unlike e.g. `SelectSemifield.lazy_fixed` Parameters ---------- unfold_copy : bool = False Whether to use `nn.functional.unfold` during computation, which results in a copy of the data. This is only supported for 2D convolutions 1D or 3+D convolutions cannot use `nn.functional.unfold`. Mainly for comparison purposes: in tests, it always results in slowdown. Returns ------- conv : nn.Module A convolution module, suitable for use in `GenericConv` """ return BroadcastConv(self, unfold_copy)
A semifield definition using PyTorch broadcasting operators
Using a technique similar to nn.Unfold, we can create a view of the input array and apply broadcasting functions along kernel axes to perform a semifield convolution. All functions must take PyTorch Tensors, and should have a backwards implementation.
This function does not use JIT components, and therefore has no compilation time (and can be run on non-CUDA devices as well).
Parameters
add_reduce
:(Tensor, tuple
ofints) -> Tensor
-
To characterise semifield summation \bigoplus, this function takes a single tensor with several axes, and performs reduction with \oplus along the axes indicated in the second argument.
Example:
lambda arr, dims: torch.sum(arr, dim=dims)
multiply
:(Tensor, Tensor) -> Tensor
-
To characterise semifield multiplication \otimes, this function takes two tensors and performs a broadcasting, element-wise version of \otimes.
Example:
lambda img, krn: img * krn
zero
:float
- The absorbing semifield zero.
Other Parameters
add_reduce_channels
:(Tensor, int) -> Tensor
, optional- An alternate reduction function (similar to
add_reduce
) that is applied along specifically the channel dimension. This alternate function could be e.g. addition, in a modified version of T_+ (seechannels_add
parameter ofBroadcastSemifield.tropical_max()
).
Examples
T_+ convolution:
>>> dilation = BroadcastSemifield.tropical_max().dynamic()
L_{-3} convolution:
>>> log = BroadcastSemifield.log(-3.0).dynamic()
For examples of how to construct a
BroadcastSemifield
manually, see the source.Static methods
def tropical_max(channels_add=False, spread_gradient=False)
-
Construct a T_+
BroadcastSemifield
.The tropical max semifield / semiring is defined as: (\mathbb{R}\cup \{-\infty\}, \max, +)
Parameters
channels_add
:bool = False
- Whether to use standard addition + instead of the semifield addition \max along specifically the channel axis.
spread_gradient
:bool = False
- Whether to, in cases of multiple equal maxima, spread the gradient equally amongst all maxima.
def tropical_min_negated(channels_add=False, spread_gradient=False)
-
Construct a
BroadcastSemifield
similar to T_-, where the kernel is negatedThe usual tropical min semifield / semiring is defined as: (\mathbb{R}\cup \{\infty\}, \min, +)
This version is slightly modified: while performing erosion using T_- requires first negating the kernel, this modified semifield has - instead of + as the semifield multiplication. As such, the resulting convolution will work with non-negated kernels as inputs, making the interface more similar to the dilation in T_+.
Parameters
channels_add
:bool = False
- Whether to use standard addition + instead of the semifield addition \min along specifically the channel axis.
spread_gradient
:bool = False
- Whether to, in cases of multiple equal minima, spread the gradient equally amongst all minima.
def linear()
-
Construct a linear
BroadcastSemifield
.The linear field is defined as: (\mathbb{R}, +, \times)
Mainly for comparison purposes: the linear convolutions offered by PyTorch use CUDNN, which is far better optimised for CUDA devices.
def root(p)
-
Construct a R_p
BroadcastSemifield
.The root semifields are defined as: (\mathbb{R}_+, \oplus_p, \times) \textrm{ for all } p\ne0 \textrm{ where } a\oplus_p b= \sqrt[p]{a^p+b^p} with the semifield zero being 0 and the semifield one being 1.
Parameters
p
:int
- The power to use in \oplus_p. May not be zero.
def log(mu)
-
Construct a L_+\mu or L_-\mu
BroadcastSemifield
.The log semifields are defined as: (\mathbb{R}\cup \{\pm\infty\}, \oplus_\mu, +) \textrm{ for all } \mu\ne0 \textrm{ where } a\oplus_\mu b= \frac{1}{\mu}\ln(e^{\mu a}+e^{\mu b}) with the semifield zero being -\infty for \mu>0 and \infty otherwise, and the semifield one being 0.
Parameters
mu
:int
- The base to use in \oplus_mu. May not be zero.
Methods
def dynamic(self, unfold_copy=False)
-
Expand source code Browse git
def dynamic(self, unfold_copy: bool = False) -> torch.nn.Module: """ Create a convolution Module based on this `BroadcastSemifield`. This method is named `dynamic`, because the Module it creates will dynamically adjust itself based on new input types, unlike e.g. `SelectSemifield.lazy_fixed` Parameters ---------- unfold_copy : bool = False Whether to use `nn.functional.unfold` during computation, which results in a copy of the data. This is only supported for 2D convolutions 1D or 3+D convolutions cannot use `nn.functional.unfold`. Mainly for comparison purposes: in tests, it always results in slowdown. Returns ------- conv : nn.Module A convolution module, suitable for use in `GenericConv` """ return BroadcastConv(self, unfold_copy)
Create a convolution Module based on this
BroadcastSemifield
.This method is named
dynamic
, because the Module it creates will dynamically adjust itself based on new input types, unlike e.g.SelectSemifield.lazy_fixed()
Parameters
unfold_copy
:bool = False
-
Whether to use
nn.functional.unfold
during computation, which results in a copy of the data. This is only supported for 2D convolutions 1D or 3+D convolutions cannot usenn.functional.unfold
.Mainly for comparison purposes: in tests, it always results in slowdown.
Returns
conv
:nn.Module
- A convolution module, suitable for use in
GenericConv
class SelectSemifield (add_select, times, d_times_d_img, d_times_d_kernel, zero, cache_name=None)
-
Expand source code Browse git
class SelectSemifield(NamedTuple): r""" A semifield definition where semifield addition selects a single value For such semifields, the backwards pass can be done very efficiently by memoizing the output provenance (index of the chosen value). The resulting module is compiled and works only on CUDA devices. Parameters ------- add_select : (float, float) -> bool Given two values, return whether we should pick the second value (`True`), or instead keep the first (`False`). If there is no meaningful difference between the two values, `False` should be preferred. times : (float, float) -> float Given an image and a kernel value, perform scalar semifield multiplication \(\otimes\). d_times_d_img : (float, float) -> float Given the two arguments to `times`, compute the derivative to the first: \[\frac{\delta (\textrm{img} \otimes \textrm{kernel}) }{\delta\textrm{img}}\] d_times_d_kernel : (float, float) -> float Given the two arguments to `times`, compute the derivative to the second: \[\frac{\delta (\textrm{img} \otimes \textrm{kernel}) }{\delta\textrm{kernel}}\] zero : float The semifield zero. cache_name : str, optional Identifier for this semifield, allows for extension compilations to be cached. Instances of `SelectSemifield` that are meaningfully different should not have the same `cache_name`, as this may lead to the wrong compilation being used. Examples ------- \(T_+\) convolution that will recompile for new inputs: >>> dilation = SelectSemifield.tropical_max().dynamic() \(T_-\) convolution that will compile only once: >>> erosion = SelectSemifield.tropical_min_negated().lazy_fixed() For examples of how to construct a `SelectSemifield` manually, see the source code. """ add_select: Callable[[float, float], bool] # Return True if we should pick right times: Callable[[float, float], float] # (img_val, krn_val) -> multiplied_val d_times_d_img: Callable[[float, float], float] d_times_d_kernel: Callable[[float, float], float] zero: float cache_name: str = None # Cache identifier: distinct for different operators @classmethod def tropical_max(cls) -> Self: r""" Construct a \(T_+\) `SelectSemifield`. The tropical max semifield / semiring is defined as: \[(\mathbb{R}\cup \{-\infty\}, \max, +)\] """ return cls( add_select=lambda left, right: left < right, times=lambda img_val, kernel_val: img_val + kernel_val, d_times_d_img=lambda _i, _k: 1.0, d_times_d_kernel=lambda _i, _k: 1.0, zero=-math.inf, cache_name="_tropical_max", ) @classmethod def tropical_min_negated(cls) -> Self: r""" Construct a `SelectSemifield` similar to \(T_-\), where the kernel is negated. The usual tropical min semifield / semiring is defined as: \[(\mathbb{R}\cup \{\infty\}, \min, +)\] This version is slightly modified: while performing erosion using \(T_-\) requires first negating the kernel, this modified semifield has \(-\) instead of \(+\) as the semifield multiplication. As such, the resulting convolution will work with non-negated kernels as inputs, making the interface more similar to the dilation in \(T_+\). """ return cls( add_select=lambda left, right: left > right, times=lambda img_val, kernel_val: img_val - kernel_val, d_times_d_img=lambda _i, _k: 1.0, d_times_d_kernel=lambda _i, _k: -1.0, zero=math.inf, cache_name="_tropical_min", ) # The torch compiler doesn't understand the Numba compiler @torch.compiler.disable @lru_cache # noqa: B019 def _compile( self, meta: ConvMeta, compile_options: Mapping[str, Any], ) -> Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]: impl = compile_options.get("impl", "glb") if impl not in ("glb",): raise ValueError(f"Unknown {impl=}") cmp_semi = CompiledSelectSemifield.compile(self) impls = { "glb": compile_forwards, } forwards = impls[impl]( semifield=cmp_semi, meta=meta, thread_block_size=compile_options.get("thread_block_size"), debug=compile_options.get("debug", False), cache_name="_temporary" if self.cache_name is None else self.cache_name, to_extension=compile_options.get("to_extension", False), ) backwards, backwards_setup = compile_backwards( semifield=cmp_semi, meta=meta, thread_block_size=compile_options.get("thread_block_size"), debug=compile_options.get("debug", False), cache_name="_temporary" if self.cache_name is None else self.cache_name, to_extension=compile_options.get("to_extension", False), kernel_inflation=compile_options.get("kernel_inflation", 16), ) forwards.register_autograd(backwards, setup_context=backwards_setup) return forwards def dynamic( self, thread_block_size: int = None, to_extension: bool = False, debug: bool = False, kernel_inflation: int = 16, ) -> torch.nn.Module: """ Create a *recompiling* convolution Module based on this `SelectSemifield`. Returns ------- conv : nn.Module A convolution module, suitable for use in `GenericConv`. Note that the compilation process is not traceable, and recompilations **may cause errors when using `torch.compile`** for backends other than CUDA Graphs Other Parameters ---------- thread_block_size : int = 128 The number of threads per CUDA block. to_extension : bool = False Whether the resulting module should compile to a PyTorch extension. Doing so increases compilation times, but reduces per-call overhead when not using CUDA-Graphs. For neural networks, it is best to keep `to_extension` as False and use CUDA Graphs via `torch.compile(model, mode="reduce-overhead", fullgraph=True)` to eliminate the wrapper code. If this is not possible (due to highly dynamic code or irregular shapes), then the next best option would be to use `to_extension` and minimise call overhead. debug : bool = False Whether to print additional debugging and compilation information. kernel_inflation : int = 16 The factor to inflate the kernel gradient with, to better distribute atomic operations. A larger factor can improve performance when the number of output pixels per kernel value is high, but only up to a point, and at the cost of memory efficiency. """ return CompiledConv( self, { "thread_block_size": thread_block_size, "debug": debug, "to_extension": to_extension, "kernel_inflation": kernel_inflation, }, ) def lazy_fixed( self, thread_block_size: int = None, to_extension: bool = False, debug: bool = False, kernel_inflation: int = 16, ) -> torch.nn.Module: """ Create a *once-compiling* convolution Module based on this `SelectSemifield`. In general, `SelectSemifield.dynamic` should be preferred for testing and also for training if the model can be traced by CUDA Graphs. If CUDA Graphs cannot capture the model code due to dynamic elements, then using `SelectSemifield.lazy_fixed` with `to_extension=True` will minimise overhead. Returns ------- conv : nn.Module A convolution module, suitable for use in `GenericConv`. Note that compilation will be based on the first inputs seen, after which the operation will be fixed: **only batch size may be changed afterwards**. The module is, however, traceable by e.g. `torch.compile` on all backends. Other Parameters ---------- thread_block_size : int = 128 The number of threads per CUDA block. to_extension : bool = False Whether the resulting module should compile to a PyTorch extension. Doing so increases compilation times, but reduces per-call overhead when not using CUDA-Graphs. For neural networks, it is best to keep `to_extension` as False and use CUDA Graphs via `torch.compile(model, mode="reduce-overhead", fullgraph=True)` to eliminate the wrapper code. If this is not possible (due to highly dynamic code or irregular shapes), then the next best option would be to use `to_extension` and minimise call overhead. debug : bool = False Whether to print additional debugging and compilation information. kernel_inflation : int = 16 The factor to inflate the kernel gradient with, to better distribute atomic operations. A larger factor can improve performance when the number of output pixels per kernel value is high, but only up to a point, and at the cost of memory efficiency. """ return CompiledConvFixedLazy( self, { "thread_block_size": thread_block_size, "debug": debug, "to_extension": to_extension, "kernel_inflation": kernel_inflation, }, ) def __hash__(self): if self.cache_name is not None: return hash(self.cache_name) return hash( ( self.add_select, self.times, self.d_times_d_img, self.d_times_d_kernel, self.zero, ) ) def __eq__(self, other): if not isinstance(other, SelectSemifield): return False if self.cache_name is not None: return self.cache_name == other.cache_name return self is other @staticmethod def _get_result(res: tuple[torch.Tensor, torch.Tensor]): return res[0]
A semifield definition where semifield addition selects a single value
For such semifields, the backwards pass can be done very efficiently by memoizing the output provenance (index of the chosen value). The resulting module is compiled and works only on CUDA devices.
Parameters
add_select
:(float, float) -> bool
- Given two values, return whether we should pick the second value (
True
), or instead keep the first (False
). If there is no meaningful difference between the two values,False
should be preferred. times
:(float, float) -> float
- Given an image and a kernel value, perform scalar semifield multiplication \otimes.
d_times_d_img
:(float, float) -> float
- Given the two arguments to
times
, compute the derivative to the first: \frac{\delta (\textrm{img} \otimes \textrm{kernel}) }{\delta\textrm{img}} d_times_d_kernel
:(float, float) -> float
- Given the two arguments to
times
, compute the derivative to the second: \frac{\delta (\textrm{img} \otimes \textrm{kernel}) }{\delta\textrm{kernel}} zero
:float
- The semifield zero.
cache_name
:str
, optional-
Identifier for this semifield, allows for extension compilations to be cached.
Instances of
SelectSemifield
that are meaningfully different should not have the samecache_name
, as this may lead to the wrong compilation being used.
Examples
T_+ convolution that will recompile for new inputs:
>>> dilation = SelectSemifield.tropical_max().dynamic()
T_- convolution that will compile only once:
>>> erosion = SelectSemifield.tropical_min_negated().lazy_fixed()
For examples of how to construct a
SelectSemifield
manually, see the source code.Static methods
def tropical_max()
-
Construct a T_+
SelectSemifield
.The tropical max semifield / semiring is defined as: (\mathbb{R}\cup \{-\infty\}, \max, +)
def tropical_min_negated()
-
Construct a
SelectSemifield
similar to T_-, where the kernel is negated.The usual tropical min semifield / semiring is defined as: (\mathbb{R}\cup \{\infty\}, \min, +)
This version is slightly modified: while performing erosion using T_- requires first negating the kernel, this modified semifield has - instead of + as the semifield multiplication. As such, the resulting convolution will work with non-negated kernels as inputs, making the interface more similar to the dilation in T_+.
Methods
def dynamic(self, thread_block_size=None, to_extension=False, debug=False, kernel_inflation=16)
-
Expand source code Browse git
def dynamic( self, thread_block_size: int = None, to_extension: bool = False, debug: bool = False, kernel_inflation: int = 16, ) -> torch.nn.Module: """ Create a *recompiling* convolution Module based on this `SelectSemifield`. Returns ------- conv : nn.Module A convolution module, suitable for use in `GenericConv`. Note that the compilation process is not traceable, and recompilations **may cause errors when using `torch.compile`** for backends other than CUDA Graphs Other Parameters ---------- thread_block_size : int = 128 The number of threads per CUDA block. to_extension : bool = False Whether the resulting module should compile to a PyTorch extension. Doing so increases compilation times, but reduces per-call overhead when not using CUDA-Graphs. For neural networks, it is best to keep `to_extension` as False and use CUDA Graphs via `torch.compile(model, mode="reduce-overhead", fullgraph=True)` to eliminate the wrapper code. If this is not possible (due to highly dynamic code or irregular shapes), then the next best option would be to use `to_extension` and minimise call overhead. debug : bool = False Whether to print additional debugging and compilation information. kernel_inflation : int = 16 The factor to inflate the kernel gradient with, to better distribute atomic operations. A larger factor can improve performance when the number of output pixels per kernel value is high, but only up to a point, and at the cost of memory efficiency. """ return CompiledConv( self, { "thread_block_size": thread_block_size, "debug": debug, "to_extension": to_extension, "kernel_inflation": kernel_inflation, }, )
Create a recompiling convolution Module based on this
SelectSemifield
.Returns
conv
:nn.Module
- A convolution module, suitable for use in
GenericConv
. Note that the compilation process is not traceable, and recompilations may cause errors when usingtorch.compile
for backends other than CUDA Graphs
Other Parameters
thread_block_size
:int = 128
- The number of threads per CUDA block.
to_extension
:bool = False
-
Whether the resulting module should compile to a PyTorch extension. Doing so increases compilation times, but reduces per-call overhead when not using CUDA-Graphs.
For neural networks, it is best to keep
to_extension
as False and use CUDA Graphs viatorch.compile(model, mode="reduce-overhead", fullgraph=True)
to eliminate the wrapper code. If this is not possible (due to highly dynamic code or irregular shapes), then the next best option would be to useto_extension
and minimise call overhead. debug
:bool = False
- Whether to print additional debugging and compilation information.
kernel_inflation
:int = 16
- The factor to inflate the kernel gradient with, to better distribute atomic operations. A larger factor can improve performance when the number of output pixels per kernel value is high, but only up to a point, and at the cost of memory efficiency.
def lazy_fixed(self, thread_block_size=None, to_extension=False, debug=False, kernel_inflation=16)
-
Expand source code Browse git
def lazy_fixed( self, thread_block_size: int = None, to_extension: bool = False, debug: bool = False, kernel_inflation: int = 16, ) -> torch.nn.Module: """ Create a *once-compiling* convolution Module based on this `SelectSemifield`. In general, `SelectSemifield.dynamic` should be preferred for testing and also for training if the model can be traced by CUDA Graphs. If CUDA Graphs cannot capture the model code due to dynamic elements, then using `SelectSemifield.lazy_fixed` with `to_extension=True` will minimise overhead. Returns ------- conv : nn.Module A convolution module, suitable for use in `GenericConv`. Note that compilation will be based on the first inputs seen, after which the operation will be fixed: **only batch size may be changed afterwards**. The module is, however, traceable by e.g. `torch.compile` on all backends. Other Parameters ---------- thread_block_size : int = 128 The number of threads per CUDA block. to_extension : bool = False Whether the resulting module should compile to a PyTorch extension. Doing so increases compilation times, but reduces per-call overhead when not using CUDA-Graphs. For neural networks, it is best to keep `to_extension` as False and use CUDA Graphs via `torch.compile(model, mode="reduce-overhead", fullgraph=True)` to eliminate the wrapper code. If this is not possible (due to highly dynamic code or irregular shapes), then the next best option would be to use `to_extension` and minimise call overhead. debug : bool = False Whether to print additional debugging and compilation information. kernel_inflation : int = 16 The factor to inflate the kernel gradient with, to better distribute atomic operations. A larger factor can improve performance when the number of output pixels per kernel value is high, but only up to a point, and at the cost of memory efficiency. """ return CompiledConvFixedLazy( self, { "thread_block_size": thread_block_size, "debug": debug, "to_extension": to_extension, "kernel_inflation": kernel_inflation, }, )
Create a once-compiling convolution Module based on this
SelectSemifield
.In general,
SelectSemifield.dynamic()
should be preferred for testing and also for training if the model can be traced by CUDA Graphs. If CUDA Graphs cannot capture the model code due to dynamic elements, then usingSelectSemifield.lazy_fixed()
withto_extension=True
will minimise overhead.Returns
conv
:nn.Module
- A convolution module, suitable for use in
GenericConv
. Note that compilation will be based on the first inputs seen, after which the operation will be fixed: only batch size may be changed afterwards. The module is, however, traceable by e.g.torch.compile
on all backends.
Other Parameters
thread_block_size
:int = 128
- The number of threads per CUDA block.
to_extension
:bool = False
-
Whether the resulting module should compile to a PyTorch extension. Doing so increases compilation times, but reduces per-call overhead when not using CUDA-Graphs.
For neural networks, it is best to keep
to_extension
as False and use CUDA Graphs viatorch.compile(model, mode="reduce-overhead", fullgraph=True)
to eliminate the wrapper code. If this is not possible (due to highly dynamic code or irregular shapes), then the next best option would be to useto_extension
and minimise call overhead. debug
:bool = False
- Whether to print additional debugging and compilation information.
kernel_inflation
:int = 16
- The factor to inflate the kernel gradient with, to better distribute atomic operations. A larger factor can improve performance when the number of output pixels per kernel value is high, but only up to a point, and at the cost of memory efficiency.
class SubtractSemifield (add,
times,
d_times_d_img,
d_times_d_kernel,
subtract,
d_add_d_right,
zero,
cache_name=None,
post_sum=None,
undo_post_sum=None,
d_post_d_acc=None)-
Expand source code Browse git
class SubtractSemifield(NamedTuple): r""" A semifield definition where semifield addition has an inverse For such semifields, the backwards pass can be done by 'subtracting' every value from the result to get the arguments for the additive derivative. The resulting module is compiled and works only on CUDA devices. Note that, while this implementation is more memory-efficient than `BroadcastSemifield`, it is typically slower in execution speed. If memory usage is not a concern but training speed is, then `BroadastSemifield` should therefore be preferred. Parameters ------- add : (float, float) -> float Given an accumulator and a multiplied value, perform scalar semifield addition \(\oplus\). times : (float, float) -> float Given an image and a kernel value, perform scalar semifield multiplication \(\otimes\). d_times_d_img : (float, float) -> float Given the two arguments to `times`, compute the derivative to the first: \[\frac{\delta (\textrm{img} \otimes \textrm{kernel}) }{\delta\textrm{img}}\] d_times_d_kernel : (float, float) -> float Given the two arguments to `times`, compute the derivative to the second: \[\frac{\delta (\textrm{img} \otimes \textrm{kernel}) }{\delta\textrm{kernel}}\] subtract : (float, float) -> float Given the final accumulator value `res` and a multiplied value `val`, use the inverse of `add` and return an `acc` such that `add(acc, val) == res`. In other words: perform semifield subtraction. d_add_d_right : (float, float) -> float Given the two arguments to `add`, compute the derivative to the second: \[\frac{\delta (\textrm{acc} \oplus \textrm{val}) }{\delta\textrm{val}}\] zero : float The semifield zero. cache_name : str, optional Identifier for this semifield, allows for extension compilations to be cached. Instances of `SubtractSemifield` that are meaningfully different should not have the same `cache_name`, as this may lead to the wrong compilation being used. Other Parameters ------- post_sum : (float) -> float, optional Some semifield additions are fairly complex and computationally expensive, but can be reinterpreted as a repeated simpler operation, followed by a scalar transformation of the final accumulator value. `post_sum` is then this scalar transformation, taking the final accumulator value `res` and transforming it into a value `out`. Taking the root semifield \(R_3\) as an example, we can see that if we use - `times` as \(a \otimes_3 b = (a \times b)^3 \) - `add` as \(+\) - `post_sum` as \(\textrm{out} = \sqrt[3]{\textrm{res}} \) then we can perform the reduction in terms of simple scalar addition, instead of having to take the power and root at every step. Using such a transfrom does, however, require defining two other operators, namely the inverse and the derivative. When these are given, `subtract` and `d_add_d_right` will be given untransformed arguments: in the root semifield example, that would mean that the arguments to `subtract` and `d_add_d_right` are not yet taken to the `p`'th root. undo_post_sum : (float) -> float, optional The inverse of `post_sum`, required if `post_sum` is given. d_post_d_acc : (float) -> float, optional The derivative of `post_sum` to its argument, required if `post_sum` is given: \[\frac{\delta \textrm{post_sum}(\textrm{res}) }{\delta\textrm{res}}\] Examples ------- Linear convolution that will recompile for new parameters: >>> linear = SubtractSemifield.linear().dynamic() \(R_3\) convolution that will compile only once: >>> root = SubtractSemifield.root(3.0).lazy_fixed() For examples of how to construct a `SubtractSemifield`, see the source code. """ add: Callable[[float, float], float] # (acc, val) -> acc (+) val times: Callable[[float, float], float] # (img_val, krn_val) -> multiplied_val d_times_d_img: Callable[[float, float], float] d_times_d_kernel: Callable[[float, float], float] # (res, val) -> res-val, such that val (+) (res - val) == res subtract: Callable[[float, float], float] # d(acc (+) val) / dval d_add_d_right: Callable[[float, float], float] zero: float cache_name: str = None # Cache identifier: distinct for different operators post_sum: Callable[[float], float] = None # (final_acc) -> res undo_post_sum: Callable[[float], float] = None # (res) -> final_acc d_post_d_acc: Callable[[float], float] = None # (final_acc) -> dacc @classmethod def linear(cls) -> Self: r""" Construct a linear `SubtractSemifield` The linear field is defined as: \[(\mathbb{R}, +, \times)\] Mainly for comparison purposes: the linear convolutions offered by PyTorch use CUDNN, which is far better optimised for CUDA devices. """ return cls( add=lambda acc, val: acc + val, times=lambda img_val, kernel_val: img_val * kernel_val, d_times_d_img=lambda _i, kernel_val: kernel_val, d_times_d_kernel=lambda img_val, _k: img_val, subtract=lambda res, val: res - val, d_add_d_right=lambda _a, _v: 1, zero=0, cache_name="_linear", ) @classmethod def root(cls, p: float) -> Self: r""" Construct a \(R_p\) `SubtractSemifield`. The root semifields are defined as: \[(\mathbb{R}_+, \oplus_p, \times) \textrm{ for all } p\ne0 \textrm{ where } a\oplus_p b= \sqrt[p]{a^p+b^p} \] with the semifield zero being \(0\) and the semifield one being \(1\). Parameters ---------- p : int The power to use in \(\oplus_p\). May not be zero. """ assert p != 0, f"Invalid value: {p=}" return cls( times=lambda img_val, kernel_val: (img_val * kernel_val) ** p, add=lambda acc, val: (acc + val), post_sum=lambda acc: acc ** (1 / p), zero=0, cache_name=f"_root_{cls._number_to_cache(p)}", undo_post_sum=lambda res: res**p, subtract=lambda acc, val: acc - val, d_times_d_img=lambda a, b: ((a * b) ** p) * p / a, d_times_d_kernel=lambda a, b: ((a * b) ** p) * p / b, d_add_d_right=lambda _a, _b: 1, d_post_d_acc=lambda acc: (1 / p) * acc ** (1 / p - 1), ) @classmethod def log(cls, mu: float) -> Self: r""" Construct a \(L_+\mu\) or \(L_-\mu\) `SubtractSemifield`. The log semifields are defined as: \[(\mathbb{R}\cup \{\pm\infty\}, \oplus_\mu, +) \textrm{ for all } \mu\ne0 \textrm{ where } a\oplus_\mu b= \frac{1}{\mu}\ln(e^{\mu a}+e^{\mu b}) \] with the semifield zero being \(-\infty\) for \(\mu>0\) and \(\infty\) otherwise, and the semifield one being \(0\). Parameters ---------- mu : int The base to use in \(\oplus_\mu\). May not be zero. """ assert mu != 0, f"Invalid value: {mu=}" return cls( times=lambda img_val, kernel_val: math.exp((img_val + kernel_val) * mu), add=lambda acc, val: (acc + val), post_sum=lambda acc: math.log(acc) / mu, zero=0, cache_name=f"_log_{cls._number_to_cache(mu)}", d_times_d_img=lambda a, b: mu * math.exp((a + b) * mu), d_times_d_kernel=lambda a, b: mu * math.exp((a + b) * mu), undo_post_sum=lambda res: math.exp(res * mu), subtract=lambda acc, val: acc - val, d_add_d_right=lambda _a, _v: 1, d_post_d_acc=lambda acc: 1 / (mu * acc), ) # The torch compiler doesn't understand the Numba compiler @torch.compiler.disable @lru_cache # noqa: B019 def _compile( self, meta: ConvMeta, compile_options: Mapping[str, Any], ) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]: impl = compile_options.get("impl", "glb") if impl not in ("glb",): raise ValueError(f"Unknown {impl=}") cmp_semi = CompiledSubtractSemifield.compile(self) impls = {"glb": compile_forwards} forwards = impls[impl]( semifield=cmp_semi, meta=meta, thread_block_size=compile_options.get("thread_block_size"), debug=compile_options.get("debug", False), cache_name="_temporary" if self.cache_name is None else self.cache_name, to_extension=compile_options.get("to_extension", False), ) backwards, backwards_setup = compile_backwards( semifield=cmp_semi, meta=meta, thread_block_size=compile_options.get("thread_block_size"), debug=compile_options.get("debug", False), cache_name="_temporary" if self.cache_name is None else self.cache_name, to_extension=compile_options.get("to_extension", False), ) forwards.register_autograd(backwards, setup_context=backwards_setup) return forwards def dynamic( self, thread_block_size: int = 256, to_extension: bool = False, debug: bool = False, kernel_inflation: int = 16, ) -> torch.nn.Module: """ Create a *recompiling* convolution Module based on this `SubtractSemifield`. Returns ------- conv : nn.Module A convolution module, suitable for use in `GenericConv`. Note that the compilation process is not traceable, and recompilations **may cause errors when using `torch.compile`** for backends other than CUDA Graphs Other Parameters ---------- thread_block_size : int = 128 The number of threads per CUDA block. to_extension : bool = False Whether the resulting module should compile to a PyTorch extension. Doing so increases compilation times, but reduces per-call overhead when not using CUDA-Graphs. For neural networks, it is best to keep `to_extension` as False and use CUDA Graphs via `torch.compile(model, mode="reduce-overhead", fullgraph=True)` to eliminate the wrapper code. If this is not possible (due to highly dynamic code or irregular shapes), then the next best option would be to use `to_extension` and minimise call overhead. debug : bool = False Whether to print additional debugging and compilation information. kernel_inflation : int = 16 The factor to inflate the kernel gradient with, to better distribute atomic operations. A larger factor can improve performance when the number of output pixels per kernel value is high, but only up to a point, and at the cost of memory efficiency. """ return CompiledConv( self, { "thread_block_size": thread_block_size, "debug": debug, "to_extension": to_extension, "kernel_inflation": kernel_inflation, }, ) def lazy_fixed( self, thread_block_size: int = 256, to_extension: bool = False, debug: bool = False, kernel_inflation: int = 16, ) -> torch.nn.Module: """ Create a *once-compiling* convolution Module based on this `SubtractSemifield`. In general, `SubtractSemifield.dynamic` should be preferred for testing and also for training if the model can be traced by CUDA Graphs. If CUDA Graphs cannot capture the model code due to dynamic elements, then using `SubtractSemifield.lazy_fixed` with `to_extension=True` will minimise overhead. Returns ------- conv : nn.Module A convolution module, suitable for use in `GenericConv`. Note that compilation will be based on the first inputs seen, after which the operation will be fixed: **only batch size may be changed afterwards**. The module is, however, traceable by e.g. `torch.compile`. Other Parameters ---------- thread_block_size : int = 256 The number of threads per CUDA block to_extension : bool = False Whether the resulting module should compile to a PyTorch extension. Doing so increases compilation times, but reduces per-call overhead when not using CUDA-Graphs. For neural networks, it is best to keep `to_extension` as False and use CUDA Graphs via `torch.compile(model, mode="reduce-overhead", fullgraph=True)` to eliminate the wrapper code. If this is not possible (due to highly dynamic code or irregular shapes), then the next best option would be to use `to_extension` and minimise call overhead. debug : bool = False Whether to print additional debugging and compilation information. kernel_inflation : int = 16 The factor to inflate the kernel gradient with, to better distribute atomic operations. A larger factor can improve performance when the number of output pixels per kernel value is high, but only up to a point, and at the cost of memory efficiency. """ return CompiledConvFixedLazy( self, { "thread_block_size": thread_block_size, "debug": debug, "to_extension": to_extension, "kernel_inflation": kernel_inflation, }, ) def __hash__(self): if self.cache_name is not None: return hash(self.cache_name) return hash( ( self.add, self.times, self.d_times_d_img, self.d_times_d_kernel, self.subtract, self.d_add_d_right, self.zero, ) ) def __eq__(self, other): if not isinstance(other, SubtractSemifield): return False if self.cache_name is not None: return self.cache_name == other.cache_name return self is other @staticmethod def _get_result(res: torch.Tensor): return res @staticmethod def _number_to_cache(n: float): return str(n).replace(".", "_").replace("-", "_minus_")
A semifield definition where semifield addition has an inverse
For such semifields, the backwards pass can be done by 'subtracting' every value from the result to get the arguments for the additive derivative. The resulting module is compiled and works only on CUDA devices.
Note that, while this implementation is more memory-efficient than
BroadcastSemifield
, it is typically slower in execution speed. If memory usage is not a concern but training speed is, thenBroadastSemifield
should therefore be preferred.Parameters
add
:(float, float) -> float
- Given an accumulator and a multiplied value, perform scalar semifield addition \oplus.
times
:(float, float) -> float
- Given an image and a kernel value, perform scalar semifield multiplication \otimes.
d_times_d_img
:(float, float) -> float
- Given the two arguments to
times
, compute the derivative to the first: \frac{\delta (\textrm{img} \otimes \textrm{kernel}) }{\delta\textrm{img}} d_times_d_kernel
:(float, float) -> float
- Given the two arguments to
times
, compute the derivative to the second: \frac{\delta (\textrm{img} \otimes \textrm{kernel}) }{\delta\textrm{kernel}} subtract
:(float, float) -> float
-
Given the final accumulator value
res
and a multiplied valueval
, use the inverse ofadd
andreturn an
acc
such thatadd(acc, val) == res
. In other words: perform semifield subtraction. d_add_d_right
:(float, float) -> float
- Given the two arguments to
add
, compute the derivative to the second: \frac{\delta (\textrm{acc} \oplus \textrm{val}) }{\delta\textrm{val}} zero
:float
- The semifield zero.
cache_name
:str
, optional-
Identifier for this semifield, allows for extension compilations to be cached.
Instances of
SubtractSemifield
that are meaningfully different should not have the samecache_name
, as this may lead to the wrong compilation being used.
Other Parameters
post_sum
:(float) -> float
, optional-
Some semifield additions are fairly complex and computationally expensive, but can be reinterpreted as a repeated simpler operation, followed by a scalar transformation of the final accumulator value.
post_sum
is then this scalar transformation, taking the final accumulator valueres
and transforming it into a valueout
.Taking the root semifield R_3 as an example, we can see that if we use
times
as a \otimes_3 b = (a \times b)^3add
as +post_sum
as \textrm{out} = \sqrt[3]{\textrm{res}}
then we can perform the reduction in terms of simple scalar addition, instead of having to take the power and root at every step.
Using such a transfrom does, however, require defining two other operators, namely the inverse and the derivative. When these are given,
subtract
andd_add_d_right
will be given untransformed arguments: in the root semifield example, that would mean that the arguments tosubtract
andd_add_d_right
are not yet taken to thep
'th root. undo_post_sum
:(float) -> float
, optional- The inverse of
post_sum
, required ifpost_sum
is given. d_post_d_acc
:(float) -> float
, optional- The derivative of
post_sum
to its argument, required ifpost_sum
is given: \frac{\delta \textrm{post_sum}(\textrm{res}) }{\delta\textrm{res}}
Examples
Linear convolution that will recompile for new parameters:
>>> linear = SubtractSemifield.linear().dynamic()
R_3 convolution that will compile only once:
>>> root = SubtractSemifield.root(3.0).lazy_fixed()
For examples of how to construct a
SubtractSemifield
, see the source code.Static methods
def linear()
-
Construct a linear
SubtractSemifield
The linear field is defined as: (\mathbb{R}, +, \times)
Mainly for comparison purposes: the linear convolutions offered by PyTorch use CUDNN, which is far better optimised for CUDA devices.
def root(p)
-
Construct a R_p
SubtractSemifield
.The root semifields are defined as: (\mathbb{R}_+, \oplus_p, \times) \textrm{ for all } p\ne0 \textrm{ where } a\oplus_p b= \sqrt[p]{a^p+b^p} with the semifield zero being 0 and the semifield one being 1.
Parameters
p
:int
- The power to use in \oplus_p. May not be zero.
def log(mu)
-
Construct a L_+\mu or L_-\mu
SubtractSemifield
.The log semifields are defined as: (\mathbb{R}\cup \{\pm\infty\}, \oplus_\mu, +) \textrm{ for all } \mu\ne0 \textrm{ where } a\oplus_\mu b= \frac{1}{\mu}\ln(e^{\mu a}+e^{\mu b}) with the semifield zero being -\infty for \mu>0 and \infty otherwise, and the semifield one being 0.
Parameters
mu
:int
- The base to use in \oplus_\mu. May not be zero.
Methods
def dynamic(self, thread_block_size=256, to_extension=False, debug=False, kernel_inflation=16)
-
Expand source code Browse git
def dynamic( self, thread_block_size: int = 256, to_extension: bool = False, debug: bool = False, kernel_inflation: int = 16, ) -> torch.nn.Module: """ Create a *recompiling* convolution Module based on this `SubtractSemifield`. Returns ------- conv : nn.Module A convolution module, suitable for use in `GenericConv`. Note that the compilation process is not traceable, and recompilations **may cause errors when using `torch.compile`** for backends other than CUDA Graphs Other Parameters ---------- thread_block_size : int = 128 The number of threads per CUDA block. to_extension : bool = False Whether the resulting module should compile to a PyTorch extension. Doing so increases compilation times, but reduces per-call overhead when not using CUDA-Graphs. For neural networks, it is best to keep `to_extension` as False and use CUDA Graphs via `torch.compile(model, mode="reduce-overhead", fullgraph=True)` to eliminate the wrapper code. If this is not possible (due to highly dynamic code or irregular shapes), then the next best option would be to use `to_extension` and minimise call overhead. debug : bool = False Whether to print additional debugging and compilation information. kernel_inflation : int = 16 The factor to inflate the kernel gradient with, to better distribute atomic operations. A larger factor can improve performance when the number of output pixels per kernel value is high, but only up to a point, and at the cost of memory efficiency. """ return CompiledConv( self, { "thread_block_size": thread_block_size, "debug": debug, "to_extension": to_extension, "kernel_inflation": kernel_inflation, }, )
Create a recompiling convolution Module based on this
SubtractSemifield
.Returns
conv
:nn.Module
- A convolution module, suitable for use in
GenericConv
. Note that the compilation process is not traceable, and recompilations may cause errors when usingtorch.compile
for backends other than CUDA Graphs
Other Parameters
thread_block_size
:int = 128
- The number of threads per CUDA block.
to_extension
:bool = False
-
Whether the resulting module should compile to a PyTorch extension. Doing so increases compilation times, but reduces per-call overhead when not using CUDA-Graphs.
For neural networks, it is best to keep
to_extension
as False and use CUDA Graphs viatorch.compile(model, mode="reduce-overhead", fullgraph=True)
to eliminate the wrapper code. If this is not possible (due to highly dynamic code or irregular shapes), then the next best option would be to useto_extension
and minimise call overhead. debug
:bool = False
- Whether to print additional debugging and compilation information.
kernel_inflation
:int = 16
- The factor to inflate the kernel gradient with, to better distribute atomic operations. A larger factor can improve performance when the number of output pixels per kernel value is high, but only up to a point, and at the cost of memory efficiency.
def lazy_fixed(self, thread_block_size=256, to_extension=False, debug=False, kernel_inflation=16)
-
Expand source code Browse git
def lazy_fixed( self, thread_block_size: int = 256, to_extension: bool = False, debug: bool = False, kernel_inflation: int = 16, ) -> torch.nn.Module: """ Create a *once-compiling* convolution Module based on this `SubtractSemifield`. In general, `SubtractSemifield.dynamic` should be preferred for testing and also for training if the model can be traced by CUDA Graphs. If CUDA Graphs cannot capture the model code due to dynamic elements, then using `SubtractSemifield.lazy_fixed` with `to_extension=True` will minimise overhead. Returns ------- conv : nn.Module A convolution module, suitable for use in `GenericConv`. Note that compilation will be based on the first inputs seen, after which the operation will be fixed: **only batch size may be changed afterwards**. The module is, however, traceable by e.g. `torch.compile`. Other Parameters ---------- thread_block_size : int = 256 The number of threads per CUDA block to_extension : bool = False Whether the resulting module should compile to a PyTorch extension. Doing so increases compilation times, but reduces per-call overhead when not using CUDA-Graphs. For neural networks, it is best to keep `to_extension` as False and use CUDA Graphs via `torch.compile(model, mode="reduce-overhead", fullgraph=True)` to eliminate the wrapper code. If this is not possible (due to highly dynamic code or irregular shapes), then the next best option would be to use `to_extension` and minimise call overhead. debug : bool = False Whether to print additional debugging and compilation information. kernel_inflation : int = 16 The factor to inflate the kernel gradient with, to better distribute atomic operations. A larger factor can improve performance when the number of output pixels per kernel value is high, but only up to a point, and at the cost of memory efficiency. """ return CompiledConvFixedLazy( self, { "thread_block_size": thread_block_size, "debug": debug, "to_extension": to_extension, "kernel_inflation": kernel_inflation, }, )
Create a once-compiling convolution Module based on this
SubtractSemifield
.In general,
SubtractSemifield.dynamic()
should be preferred for testing and also for training if the model can be traced by CUDA Graphs. If CUDA Graphs cannot capture the model code due to dynamic elements, then usingSubtractSemifield.lazy_fixed()
withto_extension=True
will minimise overhead.Returns
conv
:nn.Module
- A convolution module, suitable for use in
GenericConv
. Note that compilation will be based on the first inputs seen, after which the operation will be fixed: only batch size may be changed afterwards. The module is, however, traceable by e.g.torch.compile
.
Other Parameters
thread_block_size
:int = 256
- The number of threads per CUDA block
to_extension
:bool = False
-
Whether the resulting module should compile to a PyTorch extension. Doing so increases compilation times, but reduces per-call overhead when not using CUDA-Graphs.
For neural networks, it is best to keep
to_extension
as False and use CUDA Graphs viatorch.compile(model, mode="reduce-overhead", fullgraph=True)
to eliminate the wrapper code. If this is not possible (due to highly dynamic code or irregular shapes), then the next best option would be to useto_extension
and minimise call overhead. debug
:bool = False
- Whether to print additional debugging and compilation information.
kernel_inflation
:int = 16
- The factor to inflate the kernel gradient with, to better distribute atomic operations. A larger factor can improve performance when the number of output pixels per kernel value is high, but only up to a point, and at the cost of memory efficiency.
class QuadraticKernelSpectral2D (in_channels, out_channels, kernel_size, init=None)
-
Expand source code Browse git
class QuadraticKernelSpectral2D(nn.Module): r""" A kernel that evaluates \(x^T S^{-1} x\), with skew parameterised as an angle \(\theta\) This module takes no arguments in `forward` and produces a `Tensor` of `OIHW`, making this Module suitable as a kernel for `GenericConv`. Parameters ------- in_channels : int The number of input channels: the `I` in `OIHW`. out_channels : int The number of output channels: the `O` in `OIHW`. kernel_size : int The height `H` and width `W` of the kernel (rectangular kernels are not supported). init : dict, optional The initialisation stratergy for the underlying covariance matrices. If provided, the dictionary must have keys: `"var"` for the variances, which can take values: - `float` to indicate a constant initialisation - `"normal"` to indicate values normally distributed around 2.0 - `"uniform"` to indicate uniform-random initialisation - `"uniform-iso"` to indicate isotropic uniform-random initialisation - `"ss-iso"` to indicate scale-space isotropic initialisation - `"skewed"` to indicate uniform-random initialisation with the second primary axis having a significantly higher variance (**default**) and `"theta"` for the rotations, which can take values: - `"uniform"` to indicate uniform-random initialisation - `"spin"` to indicate shuffled but evenly spaced angles (**default**) Examples ------- >>> kernel = QuadraticKernelSpectral2D(5, 6, 3, {"var": 3.0, "theta": "spin"}) >>> tuple(kernel().shape) (6, 5, 3, 3) """ _pos_grid: torch.Tensor def __init__( self, in_channels: int, out_channels: int, kernel_size: int, init: dict[str, str | float] | None = None, ): super().__init__() self.covs = CovSpectral2D(in_channels, out_channels, init, kernel_size) self.kernel_size = kernel_size self.out_channels = out_channels self.in_channels = in_channels self.register_buffer( "_pos_grid", make_pos_grid(kernel_size).reshape(kernel_size * kernel_size, 2), ) def forward(self): dists = torch.einsum( "kx,oixX,kX->oik", self._pos_grid, self.covs.inverse_cov(), self._pos_grid ).view( ( self.out_channels, self.in_channels, self.kernel_size, self.kernel_size, ) ) return -dists def extra_repr(self): kernel_size = self.kernel_size return f"{self.in_channels}, {self.out_channels}, {kernel_size=}" def inspect_parameters(self) -> tuple[torch.Tensor, torch.Tensor]: """ Inspect the parameters of the underlying covariance matrices. Returns ------- log_std : Tensor of (O, I, 2) The logathirms of the standard deviations in both axes for all kernels theta : Tensor of (O, I) The counter-clockwise angles between the first axis and the X-axis for all kernels """ return self.covs.log_std, self.covs.theta @torch.no_grad() def plot(self): """Provide a simple visualisation of some kernels. Requires `seaborn`.""" plot_kernels(self.forward())
A kernel that evaluates x^T S^{-1} x, with skew parameterised as an angle \theta
This module takes no arguments in
forward
and produces aTensor
ofOIHW
, making this Module suitable as a kernel forGenericConv
.Parameters
in_channels
:int
- The number of input channels: the
I
inOIHW
. out_channels
:int
- The number of output channels: the
O
inOIHW
. kernel_size
:int
- The height
H
and widthW
of the kernel (rectangular kernels are not supported). init
:dict
, optional-
The initialisation stratergy for the underlying covariance matrices. If provided, the dictionary must have keys:
"var"
for the variances, which can take values:float
to indicate a constant initialisation"normal"
to indicate values normally distributed around 2.0"uniform"
to indicate uniform-random initialisation"uniform-iso"
to indicate isotropic uniform-random initialisation"ss-iso"
to indicate scale-space isotropic initialisation"skewed"
to indicate uniform-random initialisation with the second primary axis having a significantly higher variance (default)
and
"theta"
for the rotations, which can take values:"uniform"
to indicate uniform-random initialisation"spin"
to indicate shuffled but evenly spaced angles (default)
Examples
>>> kernel = QuadraticKernelSpectral2D(5, 6, 3, {"var": 3.0, "theta": "spin"}) >>> tuple(kernel().shape) (6, 5, 3, 3)
Methods
def inspect_parameters(self)
-
Expand source code Browse git
def inspect_parameters(self) -> tuple[torch.Tensor, torch.Tensor]: """ Inspect the parameters of the underlying covariance matrices. Returns ------- log_std : Tensor of (O, I, 2) The logathirms of the standard deviations in both axes for all kernels theta : Tensor of (O, I) The counter-clockwise angles between the first axis and the X-axis for all kernels """ return self.covs.log_std, self.covs.theta
Inspect the parameters of the underlying covariance matrices.
Returns
log_std
:Tensor
of(O, I, 2)
- The logathirms of the standard deviations in both axes for all kernels
theta
:Tensor
of(O, I)
- The counter-clockwise angles between the first axis and the X-axis for all kernels
def plot(self)
-
Expand source code Browse git
@torch.no_grad() def plot(self): """Provide a simple visualisation of some kernels. Requires `seaborn`.""" plot_kernels(self.forward())
Provide a simple visualisation of some kernels. Requires
seaborn
.
class QuadraticKernelCholesky2D (in_channels, out_channels, kernel_size, init=None)
-
Expand source code Browse git
class QuadraticKernelCholesky2D(nn.Module): r""" A kernel that evaluates \(x^T S^{-1} x\), with skew parameterised as Pearson's R This module takes no arguments in `forward` and produces a `Tensor` of `OIHW`, making this Module suitable as a kernel for `GenericConv`. Parameters ------- in_channels : int The number of input channels: the `I` in `OIHW`. out_channels : int The number of output channels: the `O` in `OIHW`. kernel_size : int The height `H` and width `W` of the kernel (rectangular kernels are not supported). init : dict, optional The initialisation stratergy for the underlying covariance matrices. If provided, the dictionary must have the key `"var"`, which can take values: - `float` to indicate a constant initialisation - `"normal"` to indicate values normally distributed around 2.0 - `"uniform"` to indicate uniform-random initialisation - `"uniform-iso"` to indicate isotropic uniform-random initialisation - `"ss-iso"` to indicate scale-space isotropic initialisation - `"skewed"` to indicate uniform-random initialisation with the second primary axis having a significantly higher variance (**default**) The skew parameter is always initialised using a clipped normal distribution centred around 0. Examples ------- >>> kernel = QuadraticKernelCholesky2D(5, 6, 3, {"var": 3.0}) >>> tuple(kernel().shape) (6, 5, 3, 3) """ _pos_grid: torch.Tensor def __init__( self, in_channels: int, out_channels: int, kernel_size: int, init: dict[str, str | float] | None = None, ): super().__init__() self.covs = CovCholesky2D(in_channels, out_channels, init, kernel_size) self.kernel_size = kernel_size self.out_channels = out_channels self.in_channels = in_channels self.register_buffer("_pos_grid", make_pos_grid(kernel_size, grid_at_end=True)) def forward(self): # [o, i, 2, k*k] bs = torch.linalg.solve_triangular( self.covs.cholesky(), self._pos_grid, upper=False ) dists = ( bs.pow(2) .sum(-2) .view( ( self.out_channels, self.in_channels, self.kernel_size, self.kernel_size, ) ) ) return -dists def extra_repr(self): kernel_size = self.kernel_size return f"{self.in_channels}, {self.out_channels}, {kernel_size=}" def inspect_parameters(self) -> tuple[torch.Tensor, torch.Tensor]: """ Inspect the parameters of the underlying covariance matrices. Returns ------- log_std : Tensor of (O, I, 2) The logathirms of the standard deviations in both axes for all kernels corr : Tensor of (O, I) The skews, as values for Person's R, for all kernels """ return self.covs.log_std.moveaxis(0, 2), self.covs.corr @torch.no_grad() def plot(self): """Provide a simple visualisation of some kernels. Requires `seaborn`.""" plot_kernels(self.forward())
A kernel that evaluates x^T S^{-1} x, with skew parameterised as Pearson's R
This module takes no arguments in
forward
and produces aTensor
ofOIHW
, making this Module suitable as a kernel forGenericConv
.Parameters
in_channels
:int
- The number of input channels: the
I
inOIHW
. out_channels
:int
- The number of output channels: the
O
inOIHW
. kernel_size
:int
- The height
H
and widthW
of the kernel (rectangular kernels are not supported). init
:dict
, optional-
The initialisation stratergy for the underlying covariance matrices. If provided, the dictionary must have the key
"var"
, which can take values:float
to indicate a constant initialisation"normal"
to indicate values normally distributed around 2.0"uniform"
to indicate uniform-random initialisation"uniform-iso"
to indicate isotropic uniform-random initialisation"ss-iso"
to indicate scale-space isotropic initialisation"skewed"
to indicate uniform-random initialisation with the second primary axis having a significantly higher variance (default)
The skew parameter is always initialised using a clipped normal distribution centred around 0.
Examples
>>> kernel = QuadraticKernelCholesky2D(5, 6, 3, {"var": 3.0}) >>> tuple(kernel().shape) (6, 5, 3, 3)
Methods
def inspect_parameters(self)
-
Expand source code Browse git
def inspect_parameters(self) -> tuple[torch.Tensor, torch.Tensor]: """ Inspect the parameters of the underlying covariance matrices. Returns ------- log_std : Tensor of (O, I, 2) The logathirms of the standard deviations in both axes for all kernels corr : Tensor of (O, I) The skews, as values for Person's R, for all kernels """ return self.covs.log_std.moveaxis(0, 2), self.covs.corr
Inspect the parameters of the underlying covariance matrices.
Returns
log_std
:Tensor
of(O, I, 2)
- The logathirms of the standard deviations in both axes for all kernels
corr
:Tensor
of(O, I)
- The skews, as values for Person's R, for all kernels
def plot(self)
-
Expand source code Browse git
@torch.no_grad() def plot(self): """Provide a simple visualisation of some kernels. Requires `seaborn`.""" plot_kernels(self.forward())
Provide a simple visualisation of some kernels. Requires
seaborn
.
class QuadraticKernelIso2D (in_channels, out_channels, kernel_size, init=None)
-
Expand source code Browse git
class QuadraticKernelIso2D(nn.Module): r""" A kernel that evaluates \(x^T sI x\), representing an isotropic quadratic This module takes no arguments in `forward` and produces a `Tensor` of `OIHW`, making this Module suitable as a kernel for `GenericConv`. Parameters ------- in_channels : int The number of input channels: the `I` in `OIHW`. out_channels : int The number of output channels: the `O` in `OIHW`. kernel_size : int The height `H` and width `W` of the kernel (rectangular kernels are not supported). init : dict, optional The initialisation stratergy for the variances / scale parameters. If provided, the dictionary must have the key `"var"`, which can take values: - `float` to indicate a constant initialisation - `"normal"` to indicate values normally distributed around 2.0 - `"uniform"` to indicate uniform-random initialisation - `"ss"` to indicate scale-space initialisation (**default**) Attributes ------- log_std : Tensor of (O, I) The logathirms of the standard deviations for all kernels Examples ------- >>> kernel = QuadraticKernelIso2D(5, 6, 3, {"var": 3.0}) >>> tuple(kernel().shape) (6, 5, 3, 3) """ log_std: torch.Tensor _pos_grid: torch.Tensor def __init__( self, in_channels: int, out_channels: int, kernel_size: int, init: dict = None, ): super().__init__() init: dict[str, str | float] = init or {"var": "ss"} variances = torch.empty((out_channels, in_channels)) if isinstance(init["var"], float): nn.init.constant_(variances, init["var"]) elif init["var"] == "uniform": nn.init.uniform_(variances, 1, 16) elif init["var"] == "ss": spaced_vars = torch.linspace( 1, 2 * (kernel_size // 2) ** 2, steps=out_channels * in_channels, ) permutation = torch.randperm(spaced_vars.shape[0]) variances[:] = spaced_vars[permutation].reshape(out_channels, in_channels) elif init["var"] == "log-ss": spaced_vars = torch.logspace( math.log10(1), math.log10(2 * (kernel_size // 2) ** 2), steps=out_channels * in_channels, ) permutation = torch.randperm(spaced_vars.shape[0]) variances[:] = spaced_vars[permutation].reshape(out_channels, in_channels) elif init["var"] == "normal": nn.init.trunc_normal_(variances, mean=8.0, a=1.0, b=16.0) else: raise ValueError(f"Invalid {init['var']=}") self.log_std = nn.Parameter(variances.log().mul(0.5)) self.kernel_size = kernel_size self.out_channels = out_channels self.in_channels = in_channels self.register_buffer("_pos_grid", make_pos_grid(kernel_size, grid_at_end=True)) def forward(self): dists = ( self._pos_grid.pow(2).sum(-2) / self.log_std.mul(2).exp().unsqueeze(2) ).reshape( self.out_channels, self.in_channels, self.kernel_size, self.kernel_size ) return -dists def extra_repr(self): kernel_size = self.kernel_size return f"{self.in_channels}, {self.out_channels}, {kernel_size=}" @torch.no_grad() def plot(self): """Provide a simple visualisation of some kernels. Requires `seaborn`.""" plot_kernels(self.forward())
A kernel that evaluates x^T sI x, representing an isotropic quadratic
This module takes no arguments in
forward
and produces aTensor
ofOIHW
, making this Module suitable as a kernel forGenericConv
.Parameters
in_channels
:int
- The number of input channels: the
I
inOIHW
. out_channels
:int
- The number of output channels: the
O
inOIHW
. kernel_size
:int
- The height
H
and widthW
of the kernel (rectangular kernels are not supported). init
:dict
, optional-
The initialisation stratergy for the variances / scale parameters. If provided, the dictionary must have the key
"var"
, which can take values:float
to indicate a constant initialisation"normal"
to indicate values normally distributed around 2.0"uniform"
to indicate uniform-random initialisation"ss"
to indicate scale-space initialisation (default)
Attributes
log_std
:Tensor
of(O, I)
- The logathirms of the standard deviations for all kernels
Examples
>>> kernel = QuadraticKernelIso2D(5, 6, 3, {"var": 3.0}) >>> tuple(kernel().shape) (6, 5, 3, 3)
Methods
def plot(self)
-
Expand source code Browse git
@torch.no_grad() def plot(self): """Provide a simple visualisation of some kernels. Requires `seaborn`.""" plot_kernels(self.forward())
Provide a simple visualisation of some kernels. Requires
seaborn
.
class GenericClosing (kernel,
conv_dilation,
conv_erosion,
stride=1,
padding=0,
dilation=1,
groups=1,
group_broadcasting=False)-
Expand source code Browse git
class GenericClosing(nn.Module): """ A generic Module for implement a morphological closing with a dilation and erosion. `kind` is fixed to `"conv"` for dilation and `"corr"` for erosion, to simplify the implementation of the common morphological closing. Parameters ------- kernel : nn.Module A module that produces a convolutional kernel from its `forward` method. conv_dilation : nn.Module A module representing the adjoint dilation that can take `image, kernel` as positional arguments, as well as `dilation`, `padding`, `stride`, `groups` **and `kind`** as keyword arguments, optionally supporting `group_broadcasting` and `kind`. conv_erosion : nn.Module A module representing the adjoint erosion that can take `image, kernel` as positional arguments, as well as `dilation`, `padding`, `stride`, `groups` **and `kind`** as keyword arguments, optionally supporting `group_broadcasting`. stride : int, (int, ...) = 1 The stride passed to `conv`, either for all spatial dimensions or for each separately. padding : int, (int, ...), ((int, int), ...), "valid", "same" = 0 The padding passed to `conv`. Depending on the type of `padding`: - `P` indicates padding at the start and end of all spatial axes with `P`. - `(P0, ...)` indicates padding at the start and end of the first spatial axis with `P0`, and similarly for all other spatial axes. - `((PBeg0, PEnd0), ...)` indicates padding the start of the first spatial axis with `PBeg0` and the end with `PEnd0`, similarly for all other spatial axes. - `"valid"` indicates to only perform the convolution with valid values of the image, i.e. no padding. - `"same"` indicates to pad the input such that a stride-1 convolution would produce an output of the same spatial size. Convolutions with higher stride will use the same padding scheme, but result in outputs of reduced size. dilation : int, (int, ...) = 1 The dilation passed to `conv`, either for all spatial dimensions or for each separately. groups : int = 1 The number of convolutional groups for this convolution. group_broadcasting : bool = False Whether to take the input kernels as a single output group, and broadcast across all input groups. `group_broadcasting` has no effect when `groups=1` Examples ------- >>> import pytorch_nd_semiconv as semiconv >>> common_closing = semiconv.GenericClosing( ... semiconv.QuadraticKernelCholesky2D(5, 5, 3), ... semiconv.SelectSemifield.tropical_max().lazy_fixed(), ... semiconv.SelectSemifield.tropical_min_negated().lazy_fixed(), ... padding="same", ... groups=5, ... ) """ def __init__( self, kernel: nn.Module, conv_dilation: nn.Module, conv_erosion: nn.Module, stride: int | tuple[int, ...] = 1, padding: ( int | tuple[int, ...] | tuple[tuple[int, int], ...] | Literal["valid", "same"] ) = 0, dilation: int | tuple[int, ...] = 1, groups: int = 1, group_broadcasting: bool = False, ): super().__init__() self.padding = padding self.stride = stride self.dilation = dilation self.kernel = kernel self.conv_dilation = conv_dilation self.conv_erosion = conv_erosion self.groups = groups self.group_broadcasting = group_broadcasting self.kind = "closing" # Since these are custom arguments, we only want to pass them if they differ # from the default values (otherwise, they may be unexpected) self.kwargs = {} if self.group_broadcasting: self.kwargs["group_broadcasting"] = True def forward(self, x): kernel = self.kernel() dilated = self.conv_dilation( x, kernel, dilation=self.dilation, padding=self.padding, stride=self.stride, groups=self.groups, kind="conv", **self.kwargs, ) closed = self.conv_erosion( dilated, kernel, dilation=self.dilation, padding=self.padding, stride=self.stride, groups=self.groups, kind="corr", **self.kwargs, ) return closed extra_repr = GenericConv.extra_repr
A generic Module for implement a morphological closing with a dilation and erosion.
kind
is fixed to"conv"
for dilation and"corr"
for erosion, to simplify the implementation of the common morphological closing.Parameters
kernel
:nn.Module
- A module that produces a convolutional kernel from its
forward
method. conv_dilation
:nn.Module
- A module representing the adjoint dilation that can take
image, kernel
as positional arguments, as well asdilation
,padding
,stride
,groups
andkind
as keyword arguments, optionally supportinggroup_broadcasting
andkind
. conv_erosion
:nn.Module
- A module representing the adjoint erosion that can take
image, kernel
as positional arguments, as well asdilation
,padding
,stride
,groups
andkind
as keyword arguments, optionally supportinggroup_broadcasting
. stride
:int, (int, ...) = 1
- The stride passed to
conv
, either for all spatial dimensions or for each separately. padding
:int, (int, ...), ((int, int), ...), "valid", "same" = 0
-
The padding passed to
conv
. Depending on the type ofpadding
:P
indicates padding at the start and end of all spatial axes withP
.(P0, …)
indicates padding at the start and end of the first spatial axis withP0
, and similarly for all other spatial axes.((PBeg0, PEnd0), …)
indicates padding the start of the first spatial axis withPBeg0
and the end withPEnd0
, similarly for all other spatial axes."valid"
indicates to only perform the convolution with valid values of the image, i.e. no padding."same"
indicates to pad the input such that a stride-1 convolution would produce an output of the same spatial size. Convolutions with higher stride will use the same padding scheme, but result in outputs of reduced size.
dilation
:int, (int, ...) = 1
- The dilation passed to
conv
, either for all spatial dimensions or for each separately. groups
:int = 1
- The number of convolutional groups for this convolution.
group_broadcasting
:bool = False
- Whether to take the input kernels as a single output group, and broadcast
across all input groups.
group_broadcasting
has no effect whengroups=1
Examples
>>> import pytorch_nd_semiconv as semiconv >>> common_closing = semiconv.GenericClosing( ... semiconv.QuadraticKernelCholesky2D(5, 5, 3), ... semiconv.SelectSemifield.tropical_max().lazy_fixed(), ... semiconv.SelectSemifield.tropical_min_negated().lazy_fixed(), ... padding="same", ... groups=5, ... )
class LearnedKernel2D (in_channels, out_channels, kernel_size)
-
Expand source code Browse git
class LearnedKernel2D(nn.Module): """ A utility that provides a fully learnable kernel compatible with `GenericConv` Parameters ------- in_channels : int The number of input channels: the `I` in `OIHW`. out_channels : int The number of output channels: the `O` in `OIHW`. kernel_size : int The height `H` and width `W` of the kernel (rectangular kernels not supported). """ def __init__(self, in_channels: int, out_channels: int, kernel_size: int): super().__init__() self.kernel = nn.Parameter( torch.empty(out_channels, in_channels, kernel_size, kernel_size) ) nn.init.normal_(self.kernel) def forward(self): return self.kernel
A utility that provides a fully learnable kernel compatible with
GenericConv
Parameters
in_channels
:int
- The number of input channels: the
I
inOIHW
. out_channels
:int
- The number of output channels: the
O
inOIHW
. kernel_size
:int
- The height
H
and widthW
of the kernel (rectangular kernels not supported).
class TorchLinearConv2D (*args, **kwargs)
-
Expand source code Browse git
class TorchLinearConv2D(nn.Module): """ A utility that provides PyTorch Conv2D in a form compatible with `GenericConv`. """ @staticmethod def forward( img: torch.Tensor, kernel: torch.Tensor, stride: int | tuple[int, int] = 1, padding: ( int | tuple[int, int] | tuple[tuple[int, int], tuple[int, int]] | Literal["valid", "same"] ) = 0, dilation: int | tuple[int, int] = 1, groups: int = 1, group_broadcasting: bool = False, kind: Literal["conv", "corr"] = "conv", ): if group_broadcasting: if kernel.shape[0] != 1: raise ValueError("Torch conv2d cannot broadcast groups with grp_o > 1") kernel = kernel.broadcast_to( (groups, kernel.shape[1], kernel.shape[2], kernel.shape[3]) ) if kind == "conv": kernel = kernel.flip((2, 3)) dilation = _as_tup_n(dilation, 2) (pad_y_beg, pad_y_end), (pad_x_beg, pad_x_end) = get_padding( padding, 2, dilation, kernel.shape[2:] ) if pad_y_beg != pad_y_end or pad_x_beg != pad_x_end: padded = torch.constant_pad_nd( img, # Yes, the padding really is in this order. (pad_x_beg, pad_x_end, pad_y_beg, pad_y_end), ) return torch.nn.functional.conv2d( padded, kernel, stride=stride, dilation=dilation, groups=groups ) return torch.nn.functional.conv2d( img, kernel, stride=stride, dilation=dilation, groups=groups, padding=(pad_y_beg, pad_x_beg), )
A utility that provides PyTorch Conv2D in a form compatible with
GenericConv
. class TorchMaxPool2D (kernel_size, stride=None, padding=0, dilation=1)
-
Expand source code Browse git
class TorchMaxPool2D(nn.Module): """ A utility that provides torch.nn.MaxPool2d with padding like `GenericConv`. """ def __init__( self, kernel_size: int | tuple[int, int], stride: int | tuple[int, int] = None, padding: ( int | tuple[int, int] | tuple[tuple[int, int], tuple[int, int]] | Literal["valid", "same"] ) = 0, dilation: int | tuple[int, int] = 1, ): super().__init__() self.kernel_size = kernel_size self.stride = kernel_size if stride is None else stride self.padding = padding self.dilation = dilation def forward( self, img: torch.Tensor, ): dilation = _as_tup_n(self.dilation, 2) krn_spatial = _as_tup_n(self.kernel_size, 2) (pad_y_beg, pad_y_end), (pad_x_beg, pad_x_end) = get_padding( self.padding, 2, dilation, krn_spatial ) if pad_y_beg == pad_y_end and pad_x_beg == pad_x_end: use_padding = (pad_y_beg, pad_x_beg) else: img = torch.constant_pad_nd( img, # Yes, the padding really is in this order. (pad_x_beg, pad_x_end, pad_y_beg, pad_y_end), ) use_padding = 0 return torch.nn.functional.max_pool2d( input=img, kernel_size=krn_spatial, stride=self.stride, padding=use_padding, dilation=dilation, ceil_mode=False, return_indices=False, )
A utility that provides torch.nn.MaxPool2d with padding like
GenericConv
.