Package pytorch_nd_semiconv

This package offers a generalisation of linear convolutions and max-pooling using semifields, for an arbitrary number of spatial dimensions, with many common semifield choices provided.

Additionally, this package provides example implementations of 2D isotropic and anisotropic quadratic kernels for use in e.g. dilation of 2-dimensional images. These implementations are efficient when run under torch.compile.

As an example of anisotropic quadratic dilation, consider replacing:

pooling = torch.nn.MaxPool2d(3, stride=2, padding=1)

with:

import pytorch_nd_semiconv as semiconv

dilation = semiconv.GenericConv(
    semiconv.QuadraticKernelSpectral2D(
        in_channels=5, out_channels=5, kernel_size=3
    ),
    semiconv.SelectSemifield.tropical_max().lazy_fixed(),
    padding="same",
    stride=2,
    groups=5,
)
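
The resulting module is applied just like the pooling layer it replaces. A minimal sketch (the input shape is illustrative; SelectSemifield-based modules are compiled for and run only on CUDA devices):

import torch

img = torch.randn(8, 5, 64, 64, device="cuda")
out = dilation(img)  # same call convention as the MaxPool2d above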

See also the GitHub page

Classes

class GenericConv (kernel,
conv,
stride=1,
padding=0,
dilation=1,
groups=1,
group_broadcasting=False,
kind='conv')
class GenericConv(nn.Module):
    """
    A generic convolution Module using a kernel and a convolution Module

    Parameters
    -------
    kernel : nn.Module
        A module that produces a convolutional kernel from its `forward` method.
        Must not take arguments.

        See e.g. `QuadraticKernelSpectral2D` or `LearnedKernel2D`.
    conv : nn.Module
        A module that can take `image, kernel` as positional arguments, as well as
        `dilation`, `padding`, `stride` and `groups` as keyword arguments, optionally
        supporting `group_broadcasting` and `kind`.

        See e.g. `BroadcastSemifield.dynamic` or `SelectSemifield.lazy_fixed`.
    stride : int, (int, ...) = 1
        The stride passed to `conv`, either for all spatial dimensions or for each
        separately.
    padding : int, (int, ...), ((int, int), ...), "valid", "same" = 0
        The padding passed to `conv`.
        Depending on the type of `padding`:

        - `P` indicates padding at the start and end of all spatial axes with `P`.
        - `(P0, ...)` indicates padding at the start and end of the first spatial axis
          with `P0`, and similarly for all other spatial axes.
        - `((PBeg0, PEnd0), ...)` indicates padding the start of the first spatial axis
           with `PBeg0` and the end with `PEnd0`, similarly for all other spatial axes.
        - `"valid"` indicates to only perform the convolution with valid values of the
          image, i.e. no padding.
        - `"same"` indicates to pad the input such that a stride-1 convolution would
          produce an output of the same spatial size.
          Convolutions with higher stride will use the same padding scheme, but result
          in outputs of reduced size.
    dilation : int, (int, ...) = 1
        The dilation passed to `conv`, either for all spatial dimensions or for each
        separately.
    groups : int = 1
        The number of convolutional groups for this convolution.
    group_broadcasting : bool = False
        Whether to take the input kernels as a single output group, and broadcast
        across all input groups.
        `group_broadcasting` has no effect when `groups=1`.
    kind : literal "conv" or "corr"
        Represents whether the kernel should be mirrored during the convolution
        (`"conv"`) or not (`"corr"`).

    Examples
    -------
    >>> import pytorch_nd_semiconv as semiconv
    >>> dilation = semiconv.GenericConv(
    ...     semiconv.QuadraticKernelSpectral2D(5, 5, 3),
    ...     semiconv.SelectSemifield.tropical_max().lazy_fixed(),
    ...     padding="same",
    ...     stride=2,
    ...     groups=5,
    ... )
    >>> root = semiconv.GenericConv(
    ...     semiconv.QuadraticKernelIso2D(5, 10, 3),
    ...     semiconv.BroadcastSemifield.root(3.0).dynamic(),
    ...     padding="same",
    ... )
    """

    def __init__(
        self,
        kernel: nn.Module,
        conv: nn.Module,
        stride: int | tuple[int, ...] = 1,
        padding: (
            int
            | tuple[int, ...]
            | tuple[tuple[int, int], ...]
            | Literal["valid", "same"]
        ) = 0,
        dilation: int | tuple[int, ...] = 1,
        groups: int = 1,
        group_broadcasting: bool = False,
        kind: Literal["conv", "corr"] = "conv",
    ):
        super().__init__()
        self.padding = padding
        self.stride = stride
        self.dilation = dilation
        self.kernel = kernel
        self.conv = conv
        self.groups = groups
        self.group_broadcasting = group_broadcasting
        self.kind = kind

        # Since these are custom arguments, we only want to pass them if they differ
        # from the default values (otherwise, they may be unexpected)
        self.kwargs = {}
        if self.group_broadcasting:
            self.kwargs["group_broadcasting"] = True
        if self.kind == "corr":
            self.kwargs["kind"] = "corr"

    def forward(self, img):
        """
        Run a forward step with this convolution.

        Parameters
        ----------
        img : Tensor (B, C, *Spatial)
            The input images as a 2+N tensor, of shape (Batch, Channels, ...Spatial)

            For example: a 2D image would typically be (Batch, Channels, Height, Width)

        Returns
        -------
        out_img : Tensor (B, C', *Spatial')
            The output images as a 2+N tensor, with the same batch shape but possibly
            adjusted other dimensions.
        """
        return self.conv(
            img,
            self.kernel(),
            dilation=self.dilation,
            padding=self.padding,
            stride=self.stride,
            groups=self.groups,
            **self.kwargs,
        )

    def extra_repr(self) -> str:
        res = []
        if self.padding:
            res.append(f"padding={self.padding}")
        if self.stride != 1:
            res.append(f"stride={self.stride}")
        if self.dilation != 1:
            res.append(f"dilation={self.dilation}")
        if self.groups != 1:
            res.append(f"groups={self.groups}")
        if self.group_broadcasting:
            res.append("group_broadcasting=True")
        if self.kind == "corr":
            res.append("kind=corr")

        return ", ".join(res)

A generic convolution Module using a kernel and a convolution Module

Parameters

kernel : nn.Module

A module that produces a convolutional kernel from its forward method. Must not take arguments.

See e.g. QuadraticKernelSpectral2D or LearnedKernel2D.

conv : nn.Module

A module that can take image, kernel as positional arguments, as well as dilation, padding, stride and groups as keyword arguments, optionally supporting group_broadcasting and kind.

See e.g. BroadcastSemifield.dynamic() or SelectSemifield.lazy_fixed().

stride : int, (int, ...) = 1
The stride passed to conv, either for all spatial dimensions or for each separately.
padding : int, (int, ...), ((int, int), ...), "valid", "same" = 0

The padding passed to conv. Depending on the type of padding:

  • P indicates padding at the start and end of all spatial axes with P.
  • (P0, …) indicates padding at the start and end of the first spatial axis with P0, and similarly for all other spatial axes.
  • ((PBeg0, PEnd0), …) indicates padding the start of the first spatial axis with PBeg0 and the end with PEnd0, similarly for all other spatial axes.
  • "valid" indicates to only perform the convolution with valid values of the image, i.e. no padding.
  • "same" indicates to pad the input such that a stride-1 convolution would produce an output of the same spatial size. Convolutions with higher stride will use the same padding scheme, but result in outputs of reduced size.
dilation : int, (int, ...) = 1
The dilation passed to conv, either for all spatial dimensions or for each separately.
groups : int = 1
The number of convolutional groups for this convolution.
group_broadcasting : bool = False
Whether to take the input kernels as a single output group, and broadcast across all input groups. group_broadcasting has no effect when groups=1.
kind : literal "conv" or "corr"
Represents whether the kernel should be mirrored during the convolution ("conv") or not ("corr").

Examples

>>> import pytorch_nd_semiconv as semiconv
>>> dilation = semiconv.GenericConv(
...     semiconv.QuadraticKernelSpectral2D(5, 5, 3),
...     semiconv.SelectSemifield.tropical_max().lazy_fixed(),
...     padding="same",
...     stride=2,
...     groups=5,
... )
>>> root = semiconv.GenericConv(
...     semiconv.QuadraticKernelIso2D(5, 10, 3),
...     semiconv.BroadcastSemifield.root(3.0).dynamic(),
...     padding="same",
... )

Methods

def forward(self, img)
def forward(self, img):
    """
    Run a forward step with this convolution.

    Parameters
    ----------
    img : Tensor (B, C, *Spatial)
        The input images as a 2+N tensor, of shape (Batch, Channels, ...Spatial)

        For example: a 2D image would typically be (Batch, Channels, Height, Width)

    Returns
    -------
    out_img : Tensor (B, C', *Spatial')
        The output images as a 2+N tensor, with the same batch shape but possibly
        adjusted other dimensions.
    """
    return self.conv(
        img,
        self.kernel(),
        dilation=self.dilation,
        padding=self.padding,
        stride=self.stride,
        groups=self.groups,
        **self.kwargs,
    )

Run a forward step with this convolution.

Parameters

img : Tensor (B, C, *Spatial)

The input images as a 2+N tensor, of shape (Batch, Channels, …Spatial)

For example: a 2D image would typically be (Batch, Channels, Height, Width)

Returns

out_img : Tensor (B, C', *Spatial')
The output images as a 2+N tensor, with the same batch shape but possibly adjusted other dimensions.
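
As a sketch of the shape convention, using the dilation module constructed in the Examples above (shapes are illustrative; with padding="same" and stride=2 the spatial size is roughly halved):

>>> img = torch.randn(2, 5, 32, 32, device="cuda")
>>> out = dilation(img)
>>> out.shape  # expected: torch.Size([2, 5, 16, 16])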
class BroadcastSemifield (add_reduce, multiply, zero, add_reduce_channels=None)
class BroadcastSemifield(typing.NamedTuple):
    r"""
    A semifield definition using PyTorch broadcasting operators

    Using a technique similar to nn.Unfold, we can create a view of the input array and
    apply broadcasting functions along kernel axes to perform a semifield convolution.
    All functions must take PyTorch Tensors, and should have a backwards implementation.

    This implementation does not use JIT components, and therefore has no compilation
    time (and can be run on non-CUDA devices as well).

    Parameters
    -------
    add_reduce : (Tensor, tuple of ints) -> Tensor
        To characterise semifield summation \(\bigoplus\), this function takes a single
        tensor with several axes, and performs reduction with \(\oplus\) along the axes
        indicated in the second argument.

        Example: ``lambda arr, dims: torch.sum(arr, dim=dims)``
    multiply : (Tensor, Tensor) -> Tensor
        To characterise semifield multiplication \(\otimes\), this function takes two
        tensors and performs a broadcasting, element-wise version of \(\otimes\).

        Example: ``lambda img, krn: img * krn``
    zero : float
        The absorbing semifield zero.

    Other Parameters
    -------
    add_reduce_channels : (Tensor, int) -> Tensor, optional
        An alternate reduction function (similar to `add_reduce`) that is applied along
        specifically the channel dimension.
        This alternate function could be e.g. addition, in a modified version of \(T_+\)
        (see `channels_add` parameter of `BroadcastSemifield.tropical_max`).

    Examples
    -------
    \(T_+\) convolution:

    >>> dilation = BroadcastSemifield.tropical_max().dynamic()

    \(L_{-3}\) convolution:

    >>> log = BroadcastSemifield.log(-3.0).dynamic()

    For examples of how to construct a `BroadcastSemifield` manually, see the source.
    """

    # (multiplied, dims) -> `multiplied` reduced with (+) along every dim in `dims`
    add_reduce: Callable[[torch.Tensor, tuple[int, ...]], torch.Tensor]
    # (img, krn) -> `img` (x) `krn`
    multiply: Callable[[torch.Tensor, torch.Tensor], torch.Tensor]
    # forall a, b: `zero` (x) a  (+) b  ==  b
    zero: float
    # Similar to add_reduce, but only used for channel axis (so takes one dimension)
    add_reduce_channels: Callable[[torch.Tensor, int], torch.Tensor] = None

    @classmethod
    def tropical_max(cls, channels_add: bool = False, spread_gradient: bool = False):
        r"""
        Construct a \(T_+\) `BroadcastSemifield`.

        The tropical max semifield / semiring is defined as:
        \[(\mathbb{R}\cup \{-\infty\}, \max, +)\]

        Parameters
        ----------
        channels_add : bool = False
            Whether to use standard addition \(+\) instead of the semifield addition
            \(\max\) along specifically the channel axis.
        spread_gradient : bool = False
            Whether to, in cases of multiple equal maxima, spread the gradient equally
            amongst all maxima.
        """
        return cls(
            add_reduce=(lambda multiplied, dim: torch.amax(multiplied, dim=dim))
            if spread_gradient
            else (
                _repeated_dim(
                    lambda multiplied, dim: torch.max(multiplied, dim=dim).values
                )
            ),
            multiply=lambda img, krn: img + krn,
            zero=-torch.inf,
            add_reduce_channels=(
                (lambda multiplied, dim: torch.sum(multiplied, dim=dim))
                if channels_add
                else None
            ),
        )

    @classmethod
    def tropical_min_negated(
        cls, channels_add: bool = False, spread_gradient: bool = False
    ):
        r"""
        Construct a `BroadcastSemifield` similar to \(T_-\), where the kernel is negated

        The usual tropical min semifield / semiring is defined as:
        \[(\mathbb{R}\cup \{\infty\}, \min, +)\]

        This version is slightly modified:
        while performing erosion using \(T_-\) requires first negating the kernel, this
        modified semifield has \(-\) instead of \(+\) as the semifield multiplication.
        As such, the resulting convolution will work with non-negated kernels as inputs,
        making the interface more similar to the dilation in \(T_+\).

        Parameters
        ----------
        channels_add : bool = False
            Whether to use standard addition \(+\) instead of the semifield addition
            \(\min\) along specifically the channel axis.
        spread_gradient : bool = False
            Whether to, in cases of multiple equal minima, spread the gradient equally
            amongst all minima.
        """
        return cls(
            add_reduce=(lambda multiplied, dim: torch.amin(multiplied, dim=dim))
            if spread_gradient
            else (
                _repeated_dim(
                    lambda multiplied, dim: torch.min(multiplied, dim=dim).values
                )
            ),
            multiply=lambda img, krn: img - krn,
            zero=torch.inf,
            add_reduce_channels=(
                (lambda multiplied, dim: torch.sum(multiplied, dim=dim))
                if channels_add
                else None
            ),
        )

    @classmethod
    def linear(cls):
        r"""
        Construct a linear `BroadcastSemifield`.

        The linear field is defined as:
        \[(\mathbb{R}, +, \times)\]

        Mainly for comparison purposes: the linear convolutions offered by PyTorch
        use CUDNN, which is far better optimised for CUDA devices.
        """
        return cls(
            add_reduce=(lambda multiplied, dim: torch.sum(multiplied, dim=dim)),
            multiply=lambda img, krn: img * krn,
            zero=0,
        )

    @classmethod
    def root(cls, p: float):
        r"""
        Construct a \(R_p\) `BroadcastSemifield`.

        The root semifields are defined as:
        \[(\mathbb{R}_+, \oplus_p, \times) \textrm{ for all } p\ne0 \textrm{ where }
        a\oplus_p b= \sqrt[p]{a^p+b^p} \]
        with the semifield zero being \(0\) and the semifield one being \(1\).

        Parameters
        ----------
        p : float
            The power to use in \(\oplus_p\).
            May not be zero.
        """
        assert p != 0, f"Invalid value: {p=}"
        return cls(
            add_reduce=(
                lambda multiplied, dim: multiplied.pow(p).sum(dim=dim).pow(1 / p)
            ),
            multiply=lambda img, krn: img * krn,
            zero=float(torch.finfo(torch.float32).eps),
        )

    @classmethod
    def log(cls, mu: float):
        r"""
        Construct an \(L_{+\mu}\) or \(L_{-\mu}\) `BroadcastSemifield`.

        The log semifields are defined as:
        \[(\mathbb{R}\cup \{\pm\infty\}, \oplus_\mu, +) \textrm{ for all } \mu\ne0
        \textrm{ where }
        a\oplus_\mu b= \frac{1}{\mu}\ln(e^{\mu a}+e^{\mu b}) \]
        with the semifield zero being \(-\infty\) for \(\mu>0\) and \(\infty\)
        otherwise, and the semifield one being \(0\).

        Parameters
        ----------
        mu : float
            The base to use in \(\oplus_\mu\).
            May not be zero.
        """
        assert mu != 0, f"Invalid value: {mu=}"
        return cls(
            add_reduce=(
                lambda multiplied, dim: torch.logsumexp(multiplied * mu, dim=dim) / mu
            ),
            multiply=lambda img, krn: img + krn,
            zero=-torch.inf if mu > 0 else torch.inf,
        )

    def dynamic(self, unfold_copy: bool = False) -> torch.nn.Module:
        """
        Create a convolution Module based on this `BroadcastSemifield`.

        This method is named `dynamic`, because the Module it creates will dynamically
        adjust itself based on new input types, unlike e.g. `SelectSemifield.lazy_fixed`

        Parameters
        ----------
        unfold_copy : bool = False
            Whether to use `nn.functional.unfold` during computation, which results in
            a copy of the data.
            This is only supported for 2D convolutions; 1D or 3+D convolutions cannot
            use `nn.functional.unfold`.

            Mainly for comparison purposes: in tests, it always results in slowdown.

        Returns
        -------
        conv : nn.Module
            A convolution module, suitable for use in `GenericConv`
        """
        return BroadcastConv(self, unfold_copy)

A semifield definition using PyTorch broadcasting operators

Using a technique similar to nn.Unfold, we can create a view of the input array and apply broadcasting functions along kernel axes to perform a semifield convolution. All functions must take PyTorch Tensors, and should have a backwards implementation.

This implementation does not use JIT components, and therefore has no compilation time (and can be run on non-CUDA devices as well).

Parameters

add_reduce : (Tensor, tuple of ints) -> Tensor

To characterise semifield summation \bigoplus, this function takes a single tensor with several axes, and performs reduction with \oplus along the axes indicated in the second argument.

Example: lambda arr, dims: torch.sum(arr, dim=dims)

multiply : (Tensor, Tensor) -> Tensor

To characterise semifield multiplication \otimes, this function takes two tensors and performs a broadcasting, element-wise version of \otimes.

Example: lambda img, krn: img * krn

zero : float
The absorbing semifield zero.

Other Parameters

add_reduce_channels : (Tensor, int) -> Tensor, optional
An alternate reduction function (similar to add_reduce) that is applied along specifically the channel dimension. This alternate function could be e.g. addition, in a modified version of T_+ (see channels_add parameter of BroadcastSemifield.tropical_max()).

Examples

T_+ convolution:

>>> dilation = BroadcastSemifield.tropical_max().dynamic()

L_{-3} convolution:

>>> log = BroadcastSemifield.log(-3.0).dynamic()

For examples of how to construct a BroadcastSemifield manually, see the source.
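
For instance, a manual construction equivalent to linear() (a minimal sketch mirroring the source shown above):

>>> linear = BroadcastSemifield(
...     add_reduce=lambda multiplied, dim: torch.sum(multiplied, dim=dim),
...     multiply=lambda img, krn: img * krn,
...     zero=0.0,
... ).dynamic()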

Static methods

def tropical_max(channels_add=False, spread_gradient=False)

Construct a T_+ BroadcastSemifield.

The tropical max semifield / semiring is defined as: (\mathbb{R}\cup \{-\infty\}, \max, +)

Parameters

channels_add : bool = False
Whether to use standard addition + instead of the semifield addition \max along specifically the channel axis.
spread_gradient : bool = False
Whether to, in cases of multiple equal maxima, spread the gradient equally amongst all maxima.
def tropical_min_negated(channels_add=False, spread_gradient=False)

Construct a BroadcastSemifield similar to T_-, where the kernel is negated

The usual tropical min semifield / semiring is defined as: (\mathbb{R}\cup \{\infty\}, \min, +)

This version is slightly modified: while performing erosion using T_- requires first negating the kernel, this modified semifield has - instead of + as the semifield multiplication. As such, the resulting convolution will work with non-negated kernels as inputs, making the interface more similar to the dilation in T_+.

Parameters

channels_add : bool = False
Whether to use standard addition + instead of the semifield addition \min along specifically the channel axis.
spread_gradient : bool = False
Whether to, in cases of multiple equal minima, spread the gradient equally amongst all minima.
def linear()

Construct a linear BroadcastSemifield.

The linear field is defined as: (\mathbb{R}, +, \times)

Mainly for comparison purposes: the linear convolutions offered by PyTorch use CUDNN, which is far better optimised for CUDA devices.
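As a sanity check against PyTorch's built-in linear convolution. This is a sketch only: the kernel layout (C_out, C_in/groups, kH, kW) and the default mirroring convention (kind="conv", so the kernel is flipped before comparing with the cross-correlation computed by torch.nn.functional.conv2d) are assumptions here.

>>> import torch.nn.functional as F
>>> conv = BroadcastSemifield.linear().dynamic()
>>> img = torch.randn(1, 3, 16, 16)
>>> krn = torch.randn(4, 3, 3, 3)
>>> out = conv(img, krn, dilation=1, padding=1, stride=1, groups=1)
>>> torch.allclose(out, F.conv2d(img, krn.flip(-2, -1), padding=1), atol=1e-5)  # expected: True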

def root(p)

Construct a R_p BroadcastSemifield.

The root semifields are defined as: (\mathbb{R}_+, \oplus_p, \times) \textrm{ for all } p\ne0 \textrm{ where } a\oplus_p b= \sqrt[p]{a^p+b^p} with the semifield zero being 0 and the semifield one being 1.

Parameters

p : float
The power to use in \oplus_p. May not be zero.
def log(mu)

Construct an L_{+\mu} or L_{-\mu} BroadcastSemifield.

The log semifields are defined as: (\mathbb{R}\cup \{\pm\infty\}, \oplus_\mu, +) \textrm{ for all } \mu\ne0 \textrm{ where } a\oplus_\mu b= \frac{1}{\mu}\ln(e^{\mu a}+e^{\mu b}) with the semifield zero being -\infty for \mu>0 and \infty otherwise, and the semifield one being 0.

Parameters

mu : float
The base to use in \oplus_\mu. May not be zero.
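
For intuition: as |\mu| grows, the log addition \oplus_\mu approaches the tropical max (for \mu > 0) or min (for \mu < 0). A small sketch applying the add_reduce operator directly:

>>> vals = torch.tensor([0.0, 1.0, 3.0])
>>> BroadcastSemifield.log(20.0).add_reduce(vals, (0,))   # close to max(vals) == 3.0
>>> BroadcastSemifield.log(-20.0).add_reduce(vals, (0,))  # close to min(vals) == 0.0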

Methods

def dynamic(self, unfold_copy=False)
def dynamic(self, unfold_copy: bool = False) -> torch.nn.Module:
    """
    Create a convolution Module based on this `BroadcastSemifield`.

    This method is named `dynamic`, because the Module it creates will dynamically
    adjust itself based on new input types, unlike e.g. `SelectSemifield.lazy_fixed`

    Parameters
    ----------
    unfold_copy : bool = False
        Whether to use `nn.functional.unfold` during computation, which results in
        a copy of the data.
        This is only supported for 2D convolutions; 1D or 3+D convolutions cannot
        use `nn.functional.unfold`.

        Mainly for comparison purposes: in tests, it always results in slowdown.

    Returns
    -------
    conv : nn.Module
        A convolution module, suitable for use in `GenericConv`
    """
    return BroadcastConv(self, unfold_copy)

Create a convolution Module based on this BroadcastSemifield.

This method is named dynamic, because the Module it creates will dynamically adjust itself based on new input types, unlike e.g. SelectSemifield.lazy_fixed()

Parameters

unfold_copy : bool = False

Whether to use nn.functional.unfold during computation, which results in a copy of the data. This is only supported for 2D convolutions; 1D or 3+D convolutions cannot use nn.functional.unfold.

Mainly for comparison purposes: in tests, it always results in slowdown.

Returns

conv : nn.Module
A convolution module, suitable for use in GenericConv
class SelectSemifield (add_select, times, d_times_d_img, d_times_d_kernel, zero, cache_name=None)
class SelectSemifield(NamedTuple):
    r"""
    A semifield definition where semifield addition selects a single value

    For such semifields, the backwards pass can be done very efficiently by memoizing
    the output provenance (index of the chosen value).
    The resulting module is compiled and works only on CUDA devices.

    Parameters
    -------
    add_select : (float, float) -> bool
        Given two values, return whether we should pick the second value (`True`), or
        instead keep the first (`False`).
        If there is no meaningful difference between the two values, `False` should be
        preferred.
    times : (float, float) -> float
        Given an image and a kernel value, perform scalar semifield multiplication
        \(\otimes\).
    d_times_d_img : (float, float) -> float
        Given the two arguments to `times`, compute the derivative to the first:
        \[\frac{\delta (\textrm{img} \otimes \textrm{kernel}) }{\delta\textrm{img}}\]
    d_times_d_kernel : (float, float) -> float
        Given the two arguments to `times`, compute the derivative to the second:
        \[\frac{\delta (\textrm{img} \otimes \textrm{kernel}) }{\delta\textrm{kernel}}\]
    zero : float
        The semifield zero.
    cache_name : str, optional
        Identifier for this semifield, allows for extension compilations to be cached.

        Instances of `SelectSemifield` that are meaningfully different should not have
        the same `cache_name`, as this may lead to the wrong compilation being used.

    Examples
    -------
    \(T_+\) convolution that will recompile for new inputs:

    >>> dilation = SelectSemifield.tropical_max().dynamic()

    \(T_-\) convolution that will compile only once:

    >>> erosion = SelectSemifield.tropical_min_negated().lazy_fixed()

    For examples of how to construct a `SelectSemifield` manually, see the source code.
    """

    add_select: Callable[[float, float], bool]  # Return True if we should pick right
    times: Callable[[float, float], float]  # (img_val, krn_val) -> multiplied_val
    d_times_d_img: Callable[[float, float], float]
    d_times_d_kernel: Callable[[float, float], float]
    zero: float
    cache_name: str = None  # Cache identifier: distinct for different operators

    @classmethod
    def tropical_max(cls) -> Self:
        r"""
        Construct a \(T_+\) `SelectSemifield`.

        The tropical max semifield / semiring is defined as:
        \[(\mathbb{R}\cup \{-\infty\}, \max, +)\]
        """
        return cls(
            add_select=lambda left, right: left < right,
            times=lambda img_val, kernel_val: img_val + kernel_val,
            d_times_d_img=lambda _i, _k: 1.0,
            d_times_d_kernel=lambda _i, _k: 1.0,
            zero=-math.inf,
            cache_name="_tropical_max",
        )

    @classmethod
    def tropical_min_negated(cls) -> Self:
        r"""
        Construct a `SelectSemifield` similar to \(T_-\), where the kernel is negated.

        The usual tropical min semifield / semiring is defined as:
        \[(\mathbb{R}\cup \{\infty\}, \min, +)\]

        This version is slightly modified:
        while performing erosion using \(T_-\) requires first negating the kernel, this
        modified semifield has \(-\) instead of \(+\) as the semifield multiplication.
        As such, the resulting convolution will work with non-negated kernels as inputs,
        making the interface more similar to the dilation in \(T_+\).
        """
        return cls(
            add_select=lambda left, right: left > right,
            times=lambda img_val, kernel_val: img_val - kernel_val,
            d_times_d_img=lambda _i, _k: 1.0,
            d_times_d_kernel=lambda _i, _k: -1.0,
            zero=math.inf,
            cache_name="_tropical_min",
        )

    # The torch compiler doesn't understand the Numba compiler
    @torch.compiler.disable
    @lru_cache  # noqa: B019
    def _compile(
        self,
        meta: ConvMeta,
        compile_options: Mapping[str, Any],
    ) -> Callable[[torch.Tensor, torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
        impl = compile_options.get("impl", "glb")
        if impl not in ("glb",):
            raise ValueError(f"Unknown {impl=}")

        cmp_semi = CompiledSelectSemifield.compile(self)

        impls = {
            "glb": compile_forwards,
        }
        forwards = impls[impl](
            semifield=cmp_semi,
            meta=meta,
            thread_block_size=compile_options.get("thread_block_size"),
            debug=compile_options.get("debug", False),
            cache_name="_temporary" if self.cache_name is None else self.cache_name,
            to_extension=compile_options.get("to_extension", False),
        )
        backwards, backwards_setup = compile_backwards(
            semifield=cmp_semi,
            meta=meta,
            thread_block_size=compile_options.get("thread_block_size"),
            debug=compile_options.get("debug", False),
            cache_name="_temporary" if self.cache_name is None else self.cache_name,
            to_extension=compile_options.get("to_extension", False),
            kernel_inflation=compile_options.get("kernel_inflation", 16),
        )
        forwards.register_autograd(backwards, setup_context=backwards_setup)

        return forwards

    def dynamic(
        self,
        thread_block_size: int = None,
        to_extension: bool = False,
        debug: bool = False,
        kernel_inflation: int = 16,
    ) -> torch.nn.Module:
        """
        Create a *recompiling* convolution Module based on this `SelectSemifield`.

        Returns
        -------
        conv : nn.Module
            A convolution module, suitable for use in `GenericConv`.
            Note that the compilation process is not traceable, and recompilations
            **may cause errors when using `torch.compile`** for backends other than
            CUDA Graphs

        Other Parameters
        ----------
        thread_block_size : int = 128
            The number of threads per CUDA block.
        to_extension : bool = False
            Whether the resulting module should compile to a PyTorch extension.
            Doing so increases compilation times, but reduces per-call overhead
            when not using CUDA-Graphs.

            For neural networks, it is best to keep `to_extension` as False and use
            CUDA Graphs via `torch.compile(model, mode="reduce-overhead",
            fullgraph=True)` to eliminate the wrapper code.
            If this is not possible (due to highly dynamic code or irregular shapes),
            then the next best option would be to use `to_extension`
            and minimise call overhead.
        debug : bool = False
            Whether to print additional debugging and compilation information.
        kernel_inflation : int = 16
            The factor to inflate the kernel gradient with, to better distribute
            atomic operations.
            A larger factor can improve performance when the number of output pixels
            per kernel value is high, but only up to a point, and at the cost of memory
            efficiency.
        """
        return CompiledConv(
            self,
            {
                "thread_block_size": thread_block_size,
                "debug": debug,
                "to_extension": to_extension,
                "kernel_inflation": kernel_inflation,
            },
        )

    def lazy_fixed(
        self,
        thread_block_size: int = None,
        to_extension: bool = False,
        debug: bool = False,
        kernel_inflation: int = 16,
    ) -> torch.nn.Module:
        """
        Create a *once-compiling* convolution Module based on this `SelectSemifield`.

        In general, `SelectSemifield.dynamic` should be preferred for testing and also
        for training if the model can be traced by CUDA Graphs.
        If CUDA Graphs cannot capture the model code due to dynamic elements, then using
        `SelectSemifield.lazy_fixed` with `to_extension=True` will minimise overhead.

        Returns
        -------
        conv : nn.Module
            A convolution module, suitable for use in `GenericConv`.
            Note that compilation will be based on the first inputs seen, after which
            the operation will be fixed: **only batch size may be changed afterwards**.
            The module is, however, traceable by e.g. `torch.compile` on all backends.

        Other Parameters
        ----------
        thread_block_size : int = 128
            The number of threads per CUDA block.
        to_extension : bool = False
            Whether the resulting module should compile to a PyTorch extension.
            Doing so increases compilation times, but reduces per-call overhead
            when not using CUDA-Graphs.

            For neural networks, it is best to keep `to_extension` as False and use
            CUDA Graphs via `torch.compile(model, mode="reduce-overhead",
            fullgraph=True)` to eliminate the wrapper code.
            If this is not possible (due to highly dynamic code or irregular shapes),
            then the next best option would be to use `to_extension`
            and minimise call overhead.
        debug : bool = False
            Whether to print additional debugging and compilation information.
        kernel_inflation : int = 16
            The factor to inflate the kernel gradient with, to better distribute
            atomic operations.
            A larger factor can improve performance when the number of output pixels
            per kernel value is high, but only up to a point, and at the cost of memory
            efficiency.
        """
        return CompiledConvFixedLazy(
            self,
            {
                "thread_block_size": thread_block_size,
                "debug": debug,
                "to_extension": to_extension,
                "kernel_inflation": kernel_inflation,
            },
        )

    def __hash__(self):
        if self.cache_name is not None:
            return hash(self.cache_name)

        return hash(
            (
                self.add_select,
                self.times,
                self.d_times_d_img,
                self.d_times_d_kernel,
                self.zero,
            )
        )

    def __eq__(self, other):
        if not isinstance(other, SelectSemifield):
            return False
        if self.cache_name is not None:
            return self.cache_name == other.cache_name

        return self is other

    @staticmethod
    def _get_result(res: tuple[torch.Tensor, torch.Tensor]):
        return res[0]

A semifield definition where semifield addition selects a single value

For such semifields, the backwards pass can be done very efficiently by memoizing the output provenance (index of the chosen value). The resulting module is compiled and works only on CUDA devices.

Parameters

add_select : (float, float) -> bool
Given two values, return whether we should pick the second value (True), or instead keep the first (False). If there is no meaningful difference between the two values, False should be preferred.
times : (float, float) -> float
Given an image and a kernel value, perform scalar semifield multiplication \otimes.
d_times_d_img : (float, float) -> float
Given the two arguments to times, compute the derivative to the first: \frac{\delta (\textrm{img} \otimes \textrm{kernel}) }{\delta\textrm{img}}
d_times_d_kernel : (float, float) -> float
Given the two arguments to times, compute the derivative to the second: \frac{\delta (\textrm{img} \otimes \textrm{kernel}) }{\delta\textrm{kernel}}
zero : float
The semifield zero.
cache_name : str, optional

Identifier for this semifield, allows for extension compilations to be cached.

Instances of SelectSemifield that are meaningfully different should not have the same cache_name, as this may lead to the wrong compilation being used.

Examples

T_+ convolution that will recompile for new inputs:

>>> dilation = SelectSemifield.tropical_max().dynamic()

T_- convolution that will compile only once:

>>> erosion = SelectSemifield.tropical_min_negated().lazy_fixed()

For examples of how to construct a SelectSemifield manually, see the source code.
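
For instance, a manual construction equivalent to tropical_max() (a minimal sketch mirroring the source shown above; the cache_name is a placeholder and should be distinct for meaningfully different semifields):

>>> import math
>>> dilation_semifield = SelectSemifield(
...     add_select=lambda left, right: left < right,    # pick the larger value
...     times=lambda img_val, kernel_val: img_val + kernel_val,
...     d_times_d_img=lambda _i, _k: 1.0,
...     d_times_d_kernel=lambda _i, _k: 1.0,
...     zero=-math.inf,
...     cache_name="my_tropical_max",
... )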

Static methods

def tropical_max()

Construct a T_+ SelectSemifield.

The tropical max semifield / semiring is defined as: (\mathbb{R}\cup \{-\infty\}, \max, +)

def tropical_min_negated()

Construct a SelectSemifield similar to T_-, where the kernel is negated.

The usual tropical min semifield / semiring is defined as: (\mathbb{R}\cup \{\infty\}, \min, +)

This version is slightly modified: while performing erosion using T_- requires first negating the kernel, this modified semifield has - instead of + as the semifield multiplication. As such, the resulting convolution will work with non-negated kernels as inputs, making the interface more similar to the dilation in T_+.
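
For instance, an erosion counterpart to the dilation shown for GenericConv, using the same non-negated kernel module (a sketch; the parameters are illustrative):

>>> erosion = GenericConv(
...     QuadraticKernelSpectral2D(5, 5, 3),
...     SelectSemifield.tropical_min_negated().lazy_fixed(),
...     padding="same",
...     groups=5,
... )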

Methods

def dynamic(self, thread_block_size=None, to_extension=False, debug=False, kernel_inflation=16)
def dynamic(
    self,
    thread_block_size: int = None,
    to_extension: bool = False,
    debug: bool = False,
    kernel_inflation: int = 16,
) -> torch.nn.Module:
    """
    Create a *recompiling* convolution Module based on this `SelectSemifield`.

    Returns
    -------
    conv : nn.Module
        A convolution module, suitable for use in `GenericConv`.
        Note that the compilation process is not traceable, and recompilations
        **may cause errors when using `torch.compile`** for backends other than
        CUDA Graphs

    Other Parameters
    ----------
    thread_block_size : int = 128
        The number of threads per CUDA block.
    to_extension : bool = False
        Whether the resulting module should compile to a PyTorch extension.
        Doing so increases compilation times, but reduces per-call overhead
        when not using CUDA-Graphs.

        For neural networks, it is best to keep `to_extension` as False and use
        CUDA Graphs via `torch.compile(model, mode="reduce-overhead",
        fullgraph=True)` to eliminate the wrapper code.
        If this is not possible (due to highly dynamic code or irregular shapes),
        then the next best option would be to use `to_extension`
        and minimise call overhead.
    debug : bool = False
        Whether to print additional debugging and compilation information.
    kernel_inflation : int = 16
        The factor to inflate the kernel gradient with, to better distribute
        atomic operations.
        A larger factor can improve performance when the number of output pixels
        per kernel value is high, but only up to a point, and at the cost of memory
        efficiency.
    """
    return CompiledConv(
        self,
        {
            "thread_block_size": thread_block_size,
            "debug": debug,
            "to_extension": to_extension,
            "kernel_inflation": kernel_inflation,
        },
    )

Create a recompiling convolution Module based on this SelectSemifield.

Returns

conv : nn.Module
A convolution module, suitable for use in GenericConv. Note that the compilation process is not traceable, and recompilations may cause errors when using torch.compile for backends other than CUDA Graphs

Other Parameters

thread_block_size : int = 128
The number of threads per CUDA block.
to_extension : bool = False

Whether the resulting module should compile to a PyTorch extension. Doing so increases compilation times, but reduces per-call overhead when not using CUDA-Graphs.

For neural networks, it is best to keep to_extension as False and use CUDA Graphs via torch.compile(model, mode="reduce-overhead", fullgraph=True) to eliminate the wrapper code. If this is not possible (due to highly dynamic code or irregular shapes), then the next best option would be to use to_extension and minimise call overhead.

debug : bool = False
Whether to print additional debugging and compilation information.
kernel_inflation : int = 16
The factor to inflate the kernel gradient with, to better distribute atomic operations. A larger factor can improve performance when the number of output pixels per kernel value is high, but only up to a point, and at the cost of memory efficiency.
def lazy_fixed(self, thread_block_size=None, to_extension=False, debug=False, kernel_inflation=16)
def lazy_fixed(
    self,
    thread_block_size: int = None,
    to_extension: bool = False,
    debug: bool = False,
    kernel_inflation: int = 16,
) -> torch.nn.Module:
    """
    Create a *once-compiling* convolution Module based on this `SelectSemifield`.

    In general, `SelectSemifield.dynamic` should be preferred for testing and also
    for training if the model can be traced by CUDA Graphs.
    If CUDA Graphs cannot capture the model code due to dynamic elements, then using
    `SelectSemifield.lazy_fixed` with `to_extension=True` will minimise overhead.

    Returns
    -------
    conv : nn.Module
        A convolution module, suitable for use in `GenericConv`.
        Note that compilation will be based on the first inputs seen, after which
        the operation will be fixed: **only batch size may be changed afterwards**.
        The module is, however, traceable by e.g. `torch.compile` on all backends.

    Other Parameters
    ----------
    thread_block_size : int = 128
        The number of threads per CUDA block.
    to_extension : bool = False
        Whether the resulting module should compile to a PyTorch extension.
        Doing so increases compilation times, but reduces per-call overhead
        when not using CUDA-Graphs.

        For neural networks, it is best to keep `to_extension` as False and use
        CUDA Graphs via `torch.compile(model, mode="reduce-overhead",
        fullgraph=True)` to eliminate the wrapper code.
        If this is not possible (due to highly dynamic code or irregular shapes),
        then the next best option would be to use `to_extension`
        and minimise call overhead.
    debug : bool = False
        Whether to print additional debugging and compilation information.
    kernel_inflation : int = 16
        The factor to inflate the kernel gradient with, to better distribute
        atomic operations.
        A larger factor can improve performance when the number of output pixels
        per kernel value is high, but only up to a point, and at the cost of memory
        efficiency.
    """
    return CompiledConvFixedLazy(
        self,
        {
            "thread_block_size": thread_block_size,
            "debug": debug,
            "to_extension": to_extension,
            "kernel_inflation": kernel_inflation,
        },
    )

Create a once-compiling convolution Module based on this SelectSemifield.

In general, SelectSemifield.dynamic() should be preferred for testing and also for training if the model can be traced by CUDA Graphs. If CUDA Graphs cannot capture the model code due to dynamic elements, then using SelectSemifield.lazy_fixed() with to_extension=True will minimise overhead.

Returns

conv : nn.Module
A convolution module, suitable for use in GenericConv. Note that compilation will be based on the first inputs seen, after which the operation will be fixed: only batch size may be changed afterwards. The module is, however, traceable by e.g. torch.compile on all backends.

Other Parameters

thread_block_size : int = 128
The number of threads per CUDA block.
to_extension : bool = False

Whether the resulting module should compile to a PyTorch extension. Doing so increases compilation times, but reduces per-call overhead when not using CUDA-Graphs.

For neural networks, it is best to keep to_extension as False and use CUDA Graphs via torch.compile(model, mode="reduce-overhead", fullgraph=True) to eliminate the wrapper code. If this is not possible (due to highly dynamic code or irregular shapes), then the next best option would be to use to_extension and minimise call overhead.

debug : bool = False
Whether to print additional debugging and compilation information.
kernel_inflation : int = 16
The factor to inflate the kernel gradient with, to better distribute atomic operations. A larger factor can improve performance when the number of output pixels per kernel value is high, but only up to a point, and at the cost of memory efficiency.
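
A sketch of the CUDA Graphs setup recommended above for lazy_fixed modules (the surrounding model and the input tensor img are placeholders):

>>> model = torch.nn.Sequential(
...     GenericConv(
...         QuadraticKernelSpectral2D(5, 5, 3),
...         SelectSemifield.tropical_max().lazy_fixed(),
...         padding="same",
...         groups=5,
...     ),
...     torch.nn.ReLU(),
... )
>>> compiled = torch.compile(model, mode="reduce-overhead", fullgraph=True)
>>> out = compiled(img)  # img: a CUDA tensor of shape (B, 5, H, W)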
class SubtractSemifield (add,
times,
d_times_d_img,
d_times_d_kernel,
subtract,
d_add_d_right,
zero,
cache_name=None,
post_sum=None,
undo_post_sum=None,
d_post_d_acc=None)
class SubtractSemifield(NamedTuple):
    r"""
    A semifield definition where semifield addition has an inverse

    For such semifields, the backwards pass can be done by 'subtracting' every value
    from the result to get the arguments for the additive derivative.
    The resulting module is compiled and works only on CUDA devices.

    Note that, while this implementation is more memory-efficient than
    `BroadcastSemifield`, it is typically slower in execution speed.
    If memory usage is not a concern but training speed is, then `BroadcastSemifield`
    should therefore be preferred.

    Parameters
    -------
    add : (float, float) -> float
        Given an accumulator and a multiplied value, perform scalar semifield addition
        \(\oplus\).
    times : (float, float) -> float
        Given an image and a kernel value, perform scalar semifield multiplication
        \(\otimes\).
    d_times_d_img : (float, float) -> float
        Given the two arguments to `times`, compute the derivative to the first:
        \[\frac{\delta (\textrm{img} \otimes \textrm{kernel}) }{\delta\textrm{img}}\]
    d_times_d_kernel : (float, float) -> float
        Given the two arguments to `times`, compute the derivative to the second:
        \[\frac{\delta (\textrm{img} \otimes \textrm{kernel}) }{\delta\textrm{kernel}}\]
    subtract : (float, float) -> float
        Given the final accumulator value `res` and a multiplied value `val`, use the
        inverse of `add` to return an `acc` such that `add(acc, val) == res`.
        In other words: perform semifield subtraction.
    d_add_d_right : (float, float) -> float
        Given the two arguments to `add`, compute the derivative to the second:
        \[\frac{\delta (\textrm{acc} \oplus \textrm{val}) }{\delta\textrm{val}}\]
    zero : float
        The semifield zero.
    cache_name : str, optional
        Identifier for this semifield, allows for extension compilations to be cached.

        Instances of `SubtractSemifield` that are meaningfully different should not have
        the same `cache_name`, as this may lead to the wrong compilation being used.

    Other Parameters
    -------
    post_sum : (float) -> float, optional
        Some semifield additions are fairly complex and computationally expensive, but
        can be reinterpreted as a repeated simpler operation, followed by a scalar
        transformation of the final accumulator value.
        `post_sum` is then this scalar transformation, taking the final accumulator
        value `res` and transforming it into a value `out`.

        Taking the root semifield \(R_3\) as an example, we can see that if we use

        - `times` as \(a \otimes_3 b = (a \times b)^3 \)
        - `add` as \(+\)
        - `post_sum` as \(\textrm{out} = \sqrt[3]{\textrm{res}} \)

        then we can perform the reduction in terms of simple scalar addition, instead
        of having to take the power and root at every step.

        Using such a transform does, however, require defining two other operators,
        namely the inverse and the derivative.
        When these are given, `subtract` and `d_add_d_right` will be given untransformed
        arguments: in the root semifield example, that would mean that the arguments
        to `subtract` and `d_add_d_right` are not yet taken to the `p`'th root.
    undo_post_sum : (float) -> float, optional
        The inverse of `post_sum`, required if `post_sum` is given.
    d_post_d_acc : (float) -> float, optional
        The derivative of `post_sum` to its argument, required if `post_sum` is given:
        \[\frac{\delta \textrm{post_sum}(\textrm{res}) }{\delta\textrm{res}}\]

    Examples
    -------
    Linear convolution that will recompile for new parameters:

    >>> linear = SubtractSemifield.linear().dynamic()

    \(R_3\) convolution that will compile only once:

    >>> root = SubtractSemifield.root(3.0).lazy_fixed()

    For examples of how to construct a `SubtractSemifield`, see the source code.
    """

    add: Callable[[float, float], float]  # (acc, val) -> acc (+) val
    times: Callable[[float, float], float]  # (img_val, krn_val) -> multiplied_val
    d_times_d_img: Callable[[float, float], float]
    d_times_d_kernel: Callable[[float, float], float]
    # (res, val) -> res-val, such that val (+) (res - val) == res
    subtract: Callable[[float, float], float]
    # d(acc (+) val) / dval
    d_add_d_right: Callable[[float, float], float]
    zero: float
    cache_name: str = None  # Cache identifier: distinct for different operators

    post_sum: Callable[[float], float] = None  # (final_acc) -> res
    undo_post_sum: Callable[[float], float] = None  # (res) -> final_acc
    d_post_d_acc: Callable[[float], float] = None  # (final_acc) -> dacc

    @classmethod
    def linear(cls) -> Self:
        r"""
        Construct a linear `SubtractSemifield`

        The linear field is defined as:
        \[(\mathbb{R}, +, \times)\]

        Mainly for comparison purposes: the linear convolutions offered by PyTorch
        use CUDNN, which is far better optimised for CUDA devices.
        """
        return cls(
            add=lambda acc, val: acc + val,
            times=lambda img_val, kernel_val: img_val * kernel_val,
            d_times_d_img=lambda _i, kernel_val: kernel_val,
            d_times_d_kernel=lambda img_val, _k: img_val,
            subtract=lambda res, val: res - val,
            d_add_d_right=lambda _a, _v: 1,
            zero=0,
            cache_name="_linear",
        )

    @classmethod
    def root(cls, p: float) -> Self:
        r"""
        Construct a \(R_p\) `SubtractSemifield`.

        The root semifields are defined as:
        \[(\mathbb{R}_+, \oplus_p, \times) \textrm{ for all } p\ne0 \textrm{ where }
        a\oplus_p b= \sqrt[p]{a^p+b^p} \]
        with the semifield zero being \(0\) and the semifield one being \(1\).

        Parameters
        ----------
        p : float
            The power to use in \(\oplus_p\).
            May not be zero.
        """
        assert p != 0, f"Invalid value: {p=}"
        return cls(
            times=lambda img_val, kernel_val: (img_val * kernel_val) ** p,
            add=lambda acc, val: (acc + val),
            post_sum=lambda acc: acc ** (1 / p),
            zero=0,
            cache_name=f"_root_{cls._number_to_cache(p)}",
            undo_post_sum=lambda res: res**p,
            subtract=lambda acc, val: acc - val,
            d_times_d_img=lambda a, b: ((a * b) ** p) * p / a,
            d_times_d_kernel=lambda a, b: ((a * b) ** p) * p / b,
            d_add_d_right=lambda _a, _b: 1,
            d_post_d_acc=lambda acc: (1 / p) * acc ** (1 / p - 1),
        )

    @classmethod
    def log(cls, mu: float) -> Self:
        r"""
        Construct an \(L_{+\mu}\) or \(L_{-\mu}\) `SubtractSemifield`.

        The log semifields are defined as:
        \[(\mathbb{R}\cup \{\pm\infty\}, \oplus_\mu, +) \textrm{ for all } \mu\ne0
        \textrm{ where }
        a\oplus_\mu b= \frac{1}{\mu}\ln(e^{\mu a}+e^{\mu b}) \]
        with the semifield zero being \(-\infty\) for \(\mu>0\) and \(\infty\)
        otherwise, and the semifield one being \(0\).

        Parameters
        ----------
        mu : float
            The base to use in \(\oplus_\mu\).
            May not be zero.
        """
        assert mu != 0, f"Invalid value: {mu=}"
        return cls(
            times=lambda img_val, kernel_val: math.exp((img_val + kernel_val) * mu),
            add=lambda acc, val: (acc + val),
            post_sum=lambda acc: math.log(acc) / mu,
            zero=0,
            cache_name=f"_log_{cls._number_to_cache(mu)}",
            d_times_d_img=lambda a, b: mu * math.exp((a + b) * mu),
            d_times_d_kernel=lambda a, b: mu * math.exp((a + b) * mu),
            undo_post_sum=lambda res: math.exp(res * mu),
            subtract=lambda acc, val: acc - val,
            d_add_d_right=lambda _a, _v: 1,
            d_post_d_acc=lambda acc: 1 / (mu * acc),
        )

    # The torch compiler doesn't understand the Numba compiler
    @torch.compiler.disable
    @lru_cache  # noqa: B019
    def _compile(
        self,
        meta: ConvMeta,
        compile_options: Mapping[str, Any],
    ) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
        impl = compile_options.get("impl", "glb")
        if impl not in ("glb",):
            raise ValueError(f"Unknown {impl=}")

        cmp_semi = CompiledSubtractSemifield.compile(self)
        impls = {"glb": compile_forwards}

        forwards = impls[impl](
            semifield=cmp_semi,
            meta=meta,
            thread_block_size=compile_options.get("thread_block_size"),
            debug=compile_options.get("debug", False),
            cache_name="_temporary" if self.cache_name is None else self.cache_name,
            to_extension=compile_options.get("to_extension", False),
        )
        backwards, backwards_setup = compile_backwards(
            semifield=cmp_semi,
            meta=meta,
            thread_block_size=compile_options.get("thread_block_size"),
            debug=compile_options.get("debug", False),
            cache_name="_temporary" if self.cache_name is None else self.cache_name,
            to_extension=compile_options.get("to_extension", False),
        )
        forwards.register_autograd(backwards, setup_context=backwards_setup)

        return forwards

    def dynamic(
        self,
        thread_block_size: int = 256,
        to_extension: bool = False,
        debug: bool = False,
        kernel_inflation: int = 16,
    ) -> torch.nn.Module:
        """
        Create a *recompiling* convolution Module based on this `SubtractSemifield`.

        Returns
        -------
        conv : nn.Module
            A convolution module, suitable for use in `GenericConv`.
            Note that the compilation process is not traceable, and recompilations
            **may cause errors when using `torch.compile`** for backends other than
            CUDA Graphs

        Other Parameters
        ----------
        thread_block_size : int = 256
            The number of threads per CUDA block.
        to_extension : bool = False
            Whether the resulting module should compile to a PyTorch extension.
            Doing so increases compilation times, but reduces per-call overhead
            when not using CUDA-Graphs.

            For neural networks, it is best to keep `to_extension` as False and use
            CUDA Graphs via `torch.compile(model, mode="reduce-overhead",
            fullgraph=True)` to eliminate the wrapper code.
            If this is not possible (due to highly dynamic code or irregular shapes),
            then the next best option would be to use `to_extension`
            and minimise call overhead.
        debug : bool = False
            Whether to print additional debugging and compilation information.
        kernel_inflation : int = 16
            The factor to inflate the kernel gradient with, to better distribute
            atomic operations.
            A larger factor can improve performance when the number of output pixels
            per kernel value is high, but only up to a point, and at the cost of memory
            efficiency.
        """
        return CompiledConv(
            self,
            {
                "thread_block_size": thread_block_size,
                "debug": debug,
                "to_extension": to_extension,
                "kernel_inflation": kernel_inflation,
            },
        )

    def lazy_fixed(
        self,
        thread_block_size: int = 256,
        to_extension: bool = False,
        debug: bool = False,
        kernel_inflation: int = 16,
    ) -> torch.nn.Module:
        """
        Create a *once-compiling* convolution Module based on this `SubtractSemifield`.

        In general, `SubtractSemifield.dynamic` should be preferred for testing and also
        for training if the model can be traced by CUDA Graphs.
        If CUDA Graphs cannot capture the model code due to dynamic elements, then using
        `SubtractSemifield.lazy_fixed` with `to_extension=True` will minimise overhead.

        Returns
        -------
        conv : nn.Module
            A convolution module, suitable for use in `GenericConv`.
            Note that compilation will be based on the first inputs seen, after which
            the operation will be fixed: **only batch size may be changed afterwards**.
            The module is, however, traceable by e.g. `torch.compile`.

        Other Parameters
        ----------
        thread_block_size : int = 256
            The number of threads per CUDA block.
        to_extension : bool = False
            Whether the resulting module should compile to a PyTorch extension.
            Doing so increases compilation times, but reduces per-call overhead
            when not using CUDA-Graphs.

            For neural networks, it is best to keep `to_extension` as False and use
            CUDA Graphs via `torch.compile(model, mode="reduce-overhead",
            fullgraph=True)` to eliminate the wrapper code.
            If this is not possible (due to highly dynamic code or irregular shapes),
            then the next best option would be to use `to_extension`
            and minimise call overhead.
        debug : bool = False
            Whether to print additional debugging and compilation information.
        kernel_inflation : int = 16
            The factor to inflate the kernel gradient with, to better distribute
            atomic operations.
            A larger factor can improve performance when the number of output pixels
            per kernel value is high, but only up to a point, and at the cost of memory
            efficiency.
        """
        return CompiledConvFixedLazy(
            self,
            {
                "thread_block_size": thread_block_size,
                "debug": debug,
                "to_extension": to_extension,
                "kernel_inflation": kernel_inflation,
            },
        )

    def __hash__(self):
        if self.cache_name is not None:
            return hash(self.cache_name)

        return hash(
            (
                self.add,
                self.times,
                self.d_times_d_img,
                self.d_times_d_kernel,
                self.subtract,
                self.d_add_d_right,
                self.zero,
            )
        )

    def __eq__(self, other):
        if not isinstance(other, SubtractSemifield):
            return False
        if self.cache_name is not None:
            return self.cache_name == other.cache_name

        return self is other

    @staticmethod
    def _get_result(res: torch.Tensor):
        return res

    @staticmethod
    def _number_to_cache(n: float):
        return str(n).replace(".", "_").replace("-", "_minus_")

A semifield definition where semifield addition has an inverse

For such semifields, the backwards pass can be done by 'subtracting' every value from the result to get the arguments for the additive derivative. The resulting module is compiled and works only on CUDA devices.

Note that, while this implementation is more memory-efficient than BroadcastSemifield, it is typically slower to execute. If memory usage is not a concern but training speed is, then BroadcastSemifield should be preferred.

Parameters

add : (float, float) -> float
Given an accumulator and a multiplied value, perform scalar semifield addition \oplus.
times : (float, float) -> float
Given an image and a kernel value, perform scalar semifield multiplication \otimes.
d_times_d_img : (float, float) -> float
Given the two arguments to times, compute the derivative to the first: \frac{\delta (\textrm{img} \otimes \textrm{kernel}) }{\delta\textrm{img}}
d_times_d_kernel : (float, float) -> float
Given the two arguments to times, compute the derivative to the second: \frac{\delta (\textrm{img} \otimes \textrm{kernel}) }{\delta\textrm{kernel}}
subtract : (float, float) -> float
Given the final accumulator value res and a multiplied value val, use the inverse of add and return an acc such that add(acc, val) == res. In other words: perform semifield subtraction.

d_add_d_right : (float, float) -> float
Given the two arguments to add, compute the derivative to the second: \frac{\delta (\textrm{acc} \oplus \textrm{val}) }{\delta\textrm{val}}
zero : float
The semifield zero.
cache_name : str, optional

Identifier for this semifield, allows for extension compilations to be cached.

Instances of SubtractSemifield that are meaningfully different should not have the same cache_name, as this may lead to the wrong compilation being used.

Other Parameters

post_sum : (float) -> float, optional

Some semifield additions are fairly complex and computationally expensive, but can be reinterpreted as a repeated simpler operation, followed by a scalar transformation of the final accumulator value. post_sum is then this scalar transformation, taking the final accumulator value res and transforming it into a value out.

Taking the root semifield R_3 as an example, we can see that if we use

  • times as a \otimes_3 b = (a \times b)^3
  • add as +
  • post_sum as \textrm{out} = \sqrt[3]{\textrm{res}}

then we can perform the reduction in terms of simple scalar addition, instead of having to take the power and root at every step.

Using such a transform does, however, require defining two other operators, namely the inverse and the derivative. When these are given, subtract and d_add_d_right will be given untransformed arguments: in the root semifield example, that would mean that the arguments to subtract and d_add_d_right are not yet taken to the p'th root (see the short numerical check after this parameter list).

undo_post_sum : (float) -> float, optional
The inverse of post_sum, required if post_sum is given.
d_post_d_acc : (float) -> float, optional
The derivative of post_sum to its argument, required if post_sum is given: \frac{\delta \textrm{post_sum}(\textrm{res}) }{\delta\textrm{res}}
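
To make the post_sum reinterpretation concrete, here is a small numerical check in plain Python (no package code involved) showing that folding with \oplus_3 at every step gives the same result as accumulating cubed values with ordinary + and applying the cube root once at the end:

>>> a, b, c = 1.0, 2.0, 3.0
>>> def oplus3(x, y):
...     return (x ** 3 + y ** 3) ** (1 / 3)
>>> step_by_step = oplus3(oplus3(a, b), c)
>>> fused = (a ** 3 + b ** 3 + c ** 3) ** (1 / 3)  # plain sum of cubes, then post_sum
>>> round(step_by_step, 10) == round(fused, 10)
True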

Examples

Linear convolution that will recompile for new parameters:

>>> linear = SubtractSemifield.linear().dynamic()

R_3 convolution that will compile only once:

>>> root = SubtractSemifield.root(3.0).lazy_fixed()

For examples of how to construct a SubtractSemifield, see the source code.
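
As a rough illustration of the parameter list above, a linear semifield could be spelled out by hand along the following lines. This is a sketch only: it assumes the constructor accepts keyword arguments matching the documented parameter names and plain scalar functions for the operators, as the (float, float) -> float signatures suggest; the factory methods below (e.g. SubtractSemifield.linear) and the source code are the authoritative reference.

import pytorch_nd_semiconv as semiconv

linear_field = semiconv.SubtractSemifield(
    add=lambda acc, val: acc + val,        # semifield addition \oplus
    times=lambda img, krn: img * krn,      # semifield multiplication \otimes
    d_times_d_img=lambda img, krn: krn,
    d_times_d_kernel=lambda img, krn: img,
    subtract=lambda res, val: res - val,   # inverse of add
    d_add_d_right=lambda acc, val: 1.0,
    zero=0.0,
    cache_name="linear_sketch",            # hypothetical cache identifier
)
conv = linear_field.dynamic()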

Static methods

def linear()

Construct a linear SubtractSemifield

The linear field is defined as: (\mathbb{R}, +, \times)

Mainly for comparison purposes: the linear convolutions offered by PyTorch use CUDNN, which is far better optimised for CUDA devices.

def root(p)

Construct a R_p SubtractSemifield.

The root semifields are defined as: (\mathbb{R}_+, \oplus_p, \times) \textrm{ for all } p\ne0 \textrm{ where } a\oplus_p b= \sqrt[p]{a^p+b^p} with the semifield zero being 0 and the semifield one being 1.

Parameters

p : int
The power to use in \oplus_p. May not be zero.

def log(mu)

Construct a L_+\mu or L_-\mu SubtractSemifield.

The log semifields are defined as: (\mathbb{R}\cup \{\pm\infty\}, \oplus_\mu, +) \textrm{ for all } \mu\ne0 \textrm{ where } a\oplus_\mu b= \frac{1}{\mu}\ln(e^{\mu a}+e^{\mu b}) with the semifield zero being -\infty for \mu>0 and \infty otherwise, and the semifield one being 0.

Parameters

mu : int
The base to use in \oplus_\mu. May not be zero.
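
As a quick numerical illustration of the \oplus_\mu formula above, computed in plain Python for \mu = 1:

>>> import math
>>> mu, a, b = 1.0, 2.0, 3.0
>>> round(math.log(math.exp(mu * a) + math.exp(mu * b)) / mu, 4)
3.3133

The result lies just above max(a, b) = 3, as expected for \mu > 0.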

Methods

def dynamic(self, thread_block_size=256, to_extension=False, debug=False, kernel_inflation=16)
Expand source code Browse git
def dynamic(
    self,
    thread_block_size: int = 256,
    to_extension: bool = False,
    debug: bool = False,
    kernel_inflation: int = 16,
) -> torch.nn.Module:
    """
    Create a *recompiling* convolution Module based on this `SubtractSemifield`.

    Returns
    -------
    conv : nn.Module
        A convolution module, suitable for use in `GenericConv`.
        Note that the compilation process is not traceable, and recompilations
        **may cause errors when using `torch.compile`** for backends other than
        CUDA Graphs.

    Other Parameters
    ----------
    thread_block_size : int = 256
        The number of threads per CUDA block.
    to_extension : bool = False
        Whether the resulting module should compile to a PyTorch extension.
        Doing so increases compilation times, but reduces per-call overhead
        when not using CUDA-Graphs.

        For neural networks, it is best to keep `to_extension` as False and use
        CUDA Graphs via `torch.compile(model, mode="reduce-overhead",
        fullgraph=True)` to eliminate the wrapper code.
        If this is not possible (due to highly dynamic code or irregular shapes),
        then the next best option would be to use `to_extension`
        and minimise call overhead.
    debug : bool = False
        Whether to print additional debugging and compilation information.
    kernel_inflation : int = 16
        The factor to inflate the kernel gradient with, to better distribute
        atomic operations.
        A larger factor can improve performance when the number of output pixels
        per kernel value is high, but only up to a point, and at the cost of memory
        efficiency.
    """
    return CompiledConv(
        self,
        {
            "thread_block_size": thread_block_size,
            "debug": debug,
            "to_extension": to_extension,
            "kernel_inflation": kernel_inflation,
        },
    )

Create a recompiling convolution Module based on this SubtractSemifield.

Returns

conv : nn.Module
A convolution module, suitable for use in GenericConv. Note that the compilation process is not traceable, and recompilations may cause errors when using torch.compile for backends other than CUDA Graphs.

Other Parameters

thread_block_size : int = 256
The number of threads per CUDA block.
to_extension : bool = False

Whether the resulting module should compile to a PyTorch extension. Doing so increases compilation times, but reduces per-call overhead when not using CUDA-Graphs.

For neural networks, it is best to keep to_extension as False and use CUDA Graphs via torch.compile(model, mode="reduce-overhead", fullgraph=True) to eliminate the wrapper code. If this is not possible (due to highly dynamic code or irregular shapes), then the next best option would be to use to_extension and minimise call overhead.

debug : bool = False
Whether to print additional debugging and compilation information.
kernel_inflation : int = 16
The factor to inflate the kernel gradient with, to better distribute atomic operations. A larger factor can improve performance when the number of output pixels per kernel value is high, but only up to a point, and at the cost of memory efficiency.
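
As a hedged sketch of the recommendation above (channel counts and shapes are illustrative, and a CUDA device is assumed):

import torch
import pytorch_nd_semiconv as semiconv

model = semiconv.GenericConv(
    semiconv.LearnedKernel2D(4, 4, 3),
    semiconv.SubtractSemifield.linear().dynamic(),
    padding="same",
).cuda()
# reduce-overhead mode uses CUDA Graphs to eliminate the wrapper overhead, as suggested above.
compiled = torch.compile(model, mode="reduce-overhead", fullgraph=True)
out = compiled(torch.randn(8, 4, 32, 32, device="cuda"))
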
def lazy_fixed(self, thread_block_size=256, to_extension=False, debug=False, kernel_inflation=16)
Expand source code Browse git
def lazy_fixed(
    self,
    thread_block_size: int = 256,
    to_extension: bool = False,
    debug: bool = False,
    kernel_inflation: int = 16,
) -> torch.nn.Module:
    """
    Create a *once-compiling* convolution Module based on this `SubtractSemifield`.

    In general, `SubtractSemifield.dynamic` should be preferred for testing and also
    for training if the model can be traced by CUDA Graphs.
    If CUDA Graphs cannot capture the model code due to dynamic elements, then using
    `SubtractSemifield.lazy_fixed` with `to_extension=True` will minimise overhead.

    Returns
    -------
    conv : nn.Module
        A convolution module, suitable for use in `GenericConv`.
        Note that compilation will be based on the first inputs seen, after which
        the operation will be fixed: **only batch size may be changed afterwards**.
        The module is, however, traceable by e.g. `torch.compile`.

    Other Parameters
    ----------
    thread_block_size : int = 256
        The number of threads per CUDA block.
    to_extension : bool = False
        Whether the resulting module should compile to a PyTorch extension.
        Doing so increases compilation times, but reduces per-call overhead
        when not using CUDA-Graphs.

        For neural networks, it is best to keep `to_extension` as False and use
        CUDA Graphs via `torch.compile(model, mode="reduce-overhead",
        fullgraph=True)` to eliminate the wrapper code.
        If this is not possible (due to highly dynamic code or irregular shapes),
        then the next best option would be to use `to_extension`
        and minimise call overhead.
    debug : bool = False
        Whether to print additional debugging and compilation information.
    kernel_inflation : int = 16
        The factor to inflate the kernel gradient with, to better distribute
        atomic operations.
        A larger factor can improve performance when the number of output pixels
        per kernel value is high, but only up to a point, and at the cost of memory
        efficiency.
    """
    return CompiledConvFixedLazy(
        self,
        {
            "thread_block_size": thread_block_size,
            "debug": debug,
            "to_extension": to_extension,
            "kernel_inflation": kernel_inflation,
        },
    )

Create a once-compiling convolution Module based on this SubtractSemifield.

In general, SubtractSemifield.dynamic() should be preferred for testing and also for training if the model can be traced by CUDA Graphs. If CUDA Graphs cannot capture the model code due to dynamic elements, then using SubtractSemifield.lazy_fixed() with to_extension=True will minimise overhead.

Returns

conv : nn.Module
A convolution module, suitable for use in GenericConv. Note that compilation will be based on the first inputs seen, after which the operation will be fixed: only batch size may be changed afterwards. The module is, however, traceable by e.g. torch.compile.

Other Parameters

thread_block_size : int = 256
The number of threads per CUDA block.
to_extension : bool = False

Whether the resulting module should compile to a PyTorch extension. Doing so increases compilation times, but reduces per-call overhead when not using CUDA-Graphs.

For neural networks, it is best to keep to_extension as False and use CUDA Graphs via torch.compile(model, mode="reduce-overhead", fullgraph=True) to eliminate the wrapper code. If this is not possible (due to highly dynamic code or irregular shapes), then the next best option would be to use to_extension and minimise call overhead.

debug : bool = False
Whether to print additional debugging and compilation information.
kernel_inflation : int = 16
The factor to inflate the kernel gradient with, to better distribute atomic operations. A larger factor can improve performance when the number of output pixels per kernel value is high, but only up to a point, and at the cost of memory efficiency.
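
When CUDA Graphs cannot capture the surrounding code, the once-compiled extension route mentioned above keeps per-call overhead low. A minimal sketch (the choice of semifield is illustrative):

import pytorch_nd_semiconv as semiconv

conv = semiconv.SubtractSemifield.root(3.0).lazy_fixed(to_extension=True)
# The first call compiles for the shapes it sees; afterwards only the batch
# size may change, but the module remains traceable by torch.compile.
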
class QuadraticKernelSpectral2D (in_channels, out_channels, kernel_size, init=None)
Expand source code Browse git
class QuadraticKernelSpectral2D(nn.Module):
    r"""
    A kernel that evaluates \(x^T S^{-1} x\), with skew parameterised as an angle \(\theta\)

    This module takes no arguments in `forward` and produces a
    `Tensor` of `OIHW`, making this Module suitable as a kernel for `GenericConv`.

    Parameters
    -------
    in_channels : int
        The number of input channels: the `I` in `OIHW`.
    out_channels : int
        The number of output channels: the `O` in `OIHW`.
    kernel_size : int
        The height `H` and width `W` of the kernel (rectangular kernels are not supported).
    init : dict, optional
        The initialisation strategy for the underlying covariance matrices.
        If provided, the dictionary must have keys:

        `"var"` for the variances, which can take values:

        - `float` to indicate a constant initialisation
        - `"normal"` to indicate values normally distributed around 2.0
        - `"uniform"` to indicate uniform-random initialisation
        - `"uniform-iso"` to indicate isotropic uniform-random initialisation
        - `"ss-iso"` to indicate scale-space isotropic initialisation
        - `"skewed"` to indicate uniform-random initialisation with the second primary
          axis having a significantly higher variance (**default**)

        and `"theta"` for the rotations, which can take values:

        - `"uniform"` to indicate uniform-random initialisation
        - `"spin"` to indicate shuffled but evenly spaced angles (**default**)

    Examples
    -------

    >>> kernel = QuadraticKernelSpectral2D(5, 6, 3, {"var": 3.0, "theta": "spin"})
    >>> tuple(kernel().shape)
    (6, 5, 3, 3)
    """

    _pos_grid: torch.Tensor

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        init: dict[str, str | float] | None = None,
    ):
        super().__init__()
        self.covs = CovSpectral2D(in_channels, out_channels, init, kernel_size)
        self.kernel_size = kernel_size
        self.out_channels = out_channels
        self.in_channels = in_channels
        self.register_buffer(
            "_pos_grid",
            make_pos_grid(kernel_size).reshape(kernel_size * kernel_size, 2),
        )

    def forward(self):
        dists = torch.einsum(
            "kx,oixX,kX->oik", self._pos_grid, self.covs.inverse_cov(), self._pos_grid
        ).view(
            (
                self.out_channels,
                self.in_channels,
                self.kernel_size,
                self.kernel_size,
            )
        )
        return -dists

    def extra_repr(self):
        kernel_size = self.kernel_size
        return f"{self.in_channels}, {self.out_channels}, {kernel_size=}"

    def inspect_parameters(self) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Inspect the parameters of the underlying covariance matrices.

        Returns
        -------
        log_std : Tensor of (O, I, 2)
            The logarithms of the standard deviations in both axes for all kernels
        theta : Tensor of (O, I)
            The counter-clockwise angles between the first axis and the X-axis for all kernels
        """
        return self.covs.log_std, self.covs.theta

    @torch.no_grad()
    def plot(self):
        """Provide a simple visualisation of some kernels. Requires `seaborn`."""
        plot_kernels(self.forward())

A kernel that evaluates x^T S^{-1} x, with skew parameterised as an angle \theta

This module takes no arguments in forward and produces a Tensor of OIHW, making this Module suitable as a kernel for GenericConv.

Parameters

in_channels : int
The number of input channels: the I in OIHW.
out_channels : int
The number of output channels: the O in OIHW.
kernel_size : int
The height H and width W of the kernel (rectangular kernels are not supported).
init : dict, optional

The initialisation strategy for the underlying covariance matrices. If provided, the dictionary must have keys:

"var" for the variances, which can take values:

  • float to indicate a constant initialisation
  • "normal" to indicate values normally distributed around 2.0
  • "uniform" to indicate uniform-random initialisation
  • "uniform-iso" to indicate isotropic uniform-random initialisation
  • "ss-iso" to indicate scale-space isotropic initialisation
  • "skewed" to indicate uniform-random initialisation with the second primary axis having a significantly higher variance (default)

and "theta" for the rotations, which can take values:

  • "uniform" to indicate uniform-random initialisation
  • "spin" to indicate shuffled but evenly spaced angles (default)

Examples

>>> kernel = QuadraticKernelSpectral2D(5, 6, 3, {"var": 3.0, "theta": "spin"})
>>> tuple(kernel().shape)
(6, 5, 3, 3)

Methods

def inspect_parameters(self)
Expand source code Browse git
def inspect_parameters(self) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Inspect the parameters of the underlying covariance matrices.

    Returns
    -------
    log_std : Tensor of (O, I, 2)
        The logarithms of the standard deviations in both axes for all kernels
    theta : Tensor of (O, I)
        The counter-clockwise angles between the first axis and the X-axis for all kernels
    """
    return self.covs.log_std, self.covs.theta

Inspect the parameters of the underlying covariance matrices.

Returns

log_std : Tensor of (O, I, 2)
The logarithms of the standard deviations in both axes for all kernels
theta : Tensor of (O, I)
The counter-clockwise angles between the first axis and the X-axis for all kernels
def plot(self)
Expand source code Browse git
@torch.no_grad()
def plot(self):
    """Provide a simple visualisation of some kernels. Requires `seaborn`."""
    plot_kernels(self.forward())

Provide a simple visualisation of some kernels. Requires seaborn.
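
A short usage sketch for inspect_parameters (shapes follow the documented return values; channel counts are illustrative):

from pytorch_nd_semiconv import QuadraticKernelSpectral2D

kernel = QuadraticKernelSpectral2D(5, 6, 3)
log_std, theta = kernel.inspect_parameters()
# Per the documentation above: log_std has shape (O, I, 2) = (6, 5, 2)
# and theta has shape (O, I) = (6, 5).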

class QuadraticKernelCholesky2D (in_channels, out_channels, kernel_size, init=None)
Expand source code Browse git
class QuadraticKernelCholesky2D(nn.Module):
    r"""
    A kernel that evaluates \(x^T S^{-1} x\), with skew parameterised as Pearson's R

    This module takes no arguments in `forward` and produces a
    `Tensor` of `OIHW`, making this Module suitable as a kernel for `GenericConv`.

    Parameters
    -------
    in_channels : int
        The number of input channels: the `I` in `OIHW`.
    out_channels : int
        The number of output channels: the `O` in `OIHW`.
    kernel_size : int
        The height `H` and width `W` of the kernel (rectangular kernels are not supported).
    init : dict, optional
        The initialisation strategy for the underlying covariance matrices.
        If provided, the dictionary must have the key `"var"`, which can take values:

        - `float` to indicate a constant initialisation
        - `"normal"` to indicate values normally distributed around 2.0
        - `"uniform"` to indicate uniform-random initialisation
        - `"uniform-iso"` to indicate isotropic uniform-random initialisation
        - `"ss-iso"` to indicate scale-space isotropic initialisation
        - `"skewed"` to indicate uniform-random initialisation with the second primary
          axis having a significantly higher variance (**default**)

        The skew parameter is always initialised using a clipped normal distribution
        centred around 0.

    Examples
    -------

    >>> kernel = QuadraticKernelCholesky2D(5, 6, 3, {"var": 3.0})
    >>> tuple(kernel().shape)
    (6, 5, 3, 3)
    """

    _pos_grid: torch.Tensor

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        init: dict[str, str | float] | None = None,
    ):
        super().__init__()
        self.covs = CovCholesky2D(in_channels, out_channels, init, kernel_size)
        self.kernel_size = kernel_size
        self.out_channels = out_channels
        self.in_channels = in_channels
        self.register_buffer("_pos_grid", make_pos_grid(kernel_size, grid_at_end=True))

    def forward(self):
        # [o, i, 2, k*k]
        bs = torch.linalg.solve_triangular(
            self.covs.cholesky(), self._pos_grid, upper=False
        )
        dists = (
            bs.pow(2)
            .sum(-2)
            .view(
                (
                    self.out_channels,
                    self.in_channels,
                    self.kernel_size,
                    self.kernel_size,
                )
            )
        )
        return -dists

    def extra_repr(self):
        kernel_size = self.kernel_size
        return f"{self.in_channels}, {self.out_channels}, {kernel_size=}"

    def inspect_parameters(self) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Inspect the parameters of the underlying covariance matrices.

        Returns
        -------
        log_std : Tensor of (O, I, 2)
            The logarithms of the standard deviations in both axes for all kernels
        corr : Tensor of (O, I)
            The skews, as values of Pearson's R, for all kernels
        """
        return self.covs.log_std.moveaxis(0, 2), self.covs.corr

    @torch.no_grad()
    def plot(self):
        """Provide a simple visualisation of some kernels. Requires `seaborn`."""
        plot_kernels(self.forward())

A kernel that evaluates x^T S^{-1} x, with skew parameterised as Pearson's R

This module takes no arguments in forward and produces a Tensor of OIHW, making this Module suitable as a kernel for GenericConv.

Parameters

in_channels : int
The number of input channels: the I in OIHW.
out_channels : int
The number of output channels: the O in OIHW.
kernel_size : int
The height H and width W of the kernel (rectangular kernels are not supported).
init : dict, optional

The initialisation strategy for the underlying covariance matrices. If provided, the dictionary must have the key "var", which can take values:

  • float to indicate a constant initialisation
  • "normal" to indicate values normally distributed around 2.0
  • "uniform" to indicate uniform-random initialisation
  • "uniform-iso" to indicate isotropic uniform-random initialisation
  • "ss-iso" to indicate scale-space isotropic initialisation
  • "skewed" to indicate uniform-random initialisation with the second primary axis having a significantly higher variance (default)

The skew parameter is always initialised using a clipped normal distribution centred around 0.

Examples

>>> kernel = QuadraticKernelCholesky2D(5, 6, 3, {"var": 3.0})
>>> tuple(kernel().shape)
(6, 5, 3, 3)

Methods

def inspect_parameters(self)
Expand source code Browse git
def inspect_parameters(self) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Inspect the parameters of the underlying covariance matrices.

    Returns
    -------
    log_std : Tensor of (O, I, 2)
        The logarithms of the standard deviations in both axes for all kernels
    corr : Tensor of (O, I)
        The skews, as values of Pearson's R, for all kernels
    """
    return self.covs.log_std.moveaxis(0, 2), self.covs.corr

Inspect the parameters of the underlying covariance matrices.

Returns

log_std : Tensor of (O, I, 2)
The logarithms of the standard deviations in both axes for all kernels
corr : Tensor of (O, I)
The skews, as values of Pearson's R, for all kernels
def plot(self)
Expand source code Browse git
@torch.no_grad()
def plot(self):
    """Provide a simple visualisation of some kernels. Requires `seaborn`."""
    plot_kernels(self.forward())

Provide a simple visualisation of some kernels. Requires seaborn.
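
A short usage sketch, using one of the documented initialisation options (channel counts are illustrative):

from pytorch_nd_semiconv import QuadraticKernelCholesky2D

kernel = QuadraticKernelCholesky2D(5, 6, 3, {"var": "uniform-iso"})
log_std, corr = kernel.inspect_parameters()
# Per the documentation above: log_std has shape (O, I, 2) = (6, 5, 2)
# and corr has shape (O, I) = (6, 5), holding Pearson's R values.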

class QuadraticKernelIso2D (in_channels, out_channels, kernel_size, init=None)
Expand source code Browse git
class QuadraticKernelIso2D(nn.Module):
    r"""
    A kernel that evaluates \(x^T sI x\), representing an isotropic quadratic

    This module takes no arguments in `forward` and produces a
    `Tensor` of `OIHW`, making this Module suitable as a kernel for `GenericConv`.

    Parameters
    -------
    in_channels : int
        The number of input channels: the `I` in `OIHW`.
    out_channels : int
        The number of output channels: the `O` in `OIHW`.
    kernel_size : int
        The height `H` and width `W` of the kernel (rectangular kernels are not supported).
    init : dict, optional
        The initialisation strategy for the variances / scale parameters.
        If provided, the dictionary must have the key `"var"`, which can take values:

        - `float` to indicate a constant initialisation
        - `"normal"` to indicate values normally distributed around 2.0
        - `"uniform"` to indicate uniform-random initialisation
        - `"ss"` to indicate scale-space initialisation (**default**)

    Attributes
    -------
    log_std : Tensor of (O, I)
        The logarithms of the standard deviations for all kernels

    Examples
    -------

    >>> kernel = QuadraticKernelIso2D(5, 6, 3, {"var": 3.0})
    >>> tuple(kernel().shape)
    (6, 5, 3, 3)
    """

    log_std: torch.Tensor
    _pos_grid: torch.Tensor

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        init: dict[str, str | float] | None = None,
    ):
        super().__init__()
        init: dict[str, str | float] = init or {"var": "ss"}

        variances = torch.empty((out_channels, in_channels))
        if isinstance(init["var"], float):
            nn.init.constant_(variances, init["var"])
        elif init["var"] == "uniform":
            nn.init.uniform_(variances, 1, 16)
        elif init["var"] == "ss":
            spaced_vars = torch.linspace(
                1,
                2 * (kernel_size // 2) ** 2,
                steps=out_channels * in_channels,
            )
            permutation = torch.randperm(spaced_vars.shape[0])
            variances[:] = spaced_vars[permutation].reshape(out_channels, in_channels)
        elif init["var"] == "log-ss":
            spaced_vars = torch.logspace(
                math.log10(1),
                math.log10(2 * (kernel_size // 2) ** 2),
                steps=out_channels * in_channels,
            )
            permutation = torch.randperm(spaced_vars.shape[0])
            variances[:] = spaced_vars[permutation].reshape(out_channels, in_channels)
        elif init["var"] == "normal":
            nn.init.trunc_normal_(variances, mean=8.0, a=1.0, b=16.0)
        else:
            raise ValueError(f"Invalid {init['var']=}")

        self.log_std = nn.Parameter(variances.log().mul(0.5))

        self.kernel_size = kernel_size
        self.out_channels = out_channels
        self.in_channels = in_channels
        self.register_buffer("_pos_grid", make_pos_grid(kernel_size, grid_at_end=True))

    def forward(self):
        dists = (
            self._pos_grid.pow(2).sum(-2) / self.log_std.mul(2).exp().unsqueeze(2)
        ).reshape(
            self.out_channels, self.in_channels, self.kernel_size, self.kernel_size
        )
        return -dists

    def extra_repr(self):
        kernel_size = self.kernel_size
        return f"{self.in_channels}, {self.out_channels}, {kernel_size=}"

    @torch.no_grad()
    def plot(self):
        """Provide a simple visualisation of some kernels. Requires `seaborn`."""
        plot_kernels(self.forward())

A kernel that evaluates x^T sI x, representing an isotropic quadratic

This module takes no arguments in forward and produces a Tensor of OIHW, making this Module suitable as a kernel for GenericConv.

Parameters

in_channels : int
The number of input channels: the I in OIHW.
out_channels : int
The number of output channels: the O in OIHW.
kernel_size : int
The height H and width W of the kernel (rectangular kernels are not supported).
init : dict, optional

The initialisation strategy for the variances / scale parameters. If provided, the dictionary must have the key "var", which can take values:

  • float to indicate a constant initialisation
  • "normal" to indicate values normally distributed around 2.0
  • "uniform" to indicate uniform-random initialisation
  • "ss" to indicate scale-space initialisation (default)

Attributes

log_std : Tensor of (O, I)
The logarithms of the standard deviations for all kernels

Examples

>>> kernel = QuadraticKernelIso2D(5, 6, 3, {"var": 3.0})
>>> tuple(kernel().shape)
(6, 5, 3, 3)

Methods

def plot(self)
Expand source code Browse git
@torch.no_grad()
def plot(self):
    """Provide a simple visualisation of some kernels. Requires `seaborn`."""
    plot_kernels(self.forward())

Provide a simple visualisation of some kernels. Requires seaborn.
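
A short check of the documented log_std attribute (channel counts are illustrative):

from pytorch_nd_semiconv import QuadraticKernelIso2D

kernel = QuadraticKernelIso2D(5, 6, 3, {"var": "ss"})
print(tuple(kernel.log_std.shape))  # (6, 5): one log standard deviation per (O, I) pair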

class GenericClosing (kernel,
conv_dilation,
conv_erosion,
stride=1,
padding=0,
dilation=1,
groups=1,
group_broadcasting=False)
Expand source code Browse git
class GenericClosing(nn.Module):
    """
    A generic Module for implementing a morphological closing with a dilation and an erosion.

    `kind` is fixed to `"conv"` for dilation and `"corr"` for erosion, to simplify the
    implementation of the common morphological closing.

    Parameters
    -------
    kernel : nn.Module
        A module that produces a convolutional kernel from its `forward` method.
    conv_dilation : nn.Module
        A module representing the adjoint dilation that can take `image, kernel`
         as positional arguments, as well as
        `dilation`, `padding`, `stride`, `groups` **and `kind`** as keyword arguments,
        optionally supporting `group_broadcasting`.
    conv_erosion : nn.Module
        A module representing the adjoint erosion that can take `image, kernel`
         as positional arguments, as well as
        `dilation`, `padding`, `stride`, `groups` **and `kind`** as keyword arguments,
        optionally supporting `group_broadcasting`.
    stride : int, (int, ...) = 1
        The stride passed to `conv`, either for all spatial dimensions or for each
        separately.
    padding : int, (int, ...), ((int, int), ...), "valid", "same" = 0
        The padding passed to `conv`.
        Depending on the type of `padding`:

        - `P` indicates padding at the start and end of all spatial axes with `P`.
        - `(P0, ...)` indicates padding at the start and end of the first spatial axis
          with `P0`, and similarly for all other spatial axes.
        - `((PBeg0, PEnd0), ...)` indicates padding the start of the first spatial axis
           with `PBeg0` and the end with `PEnd0`, similarly for all other spatial axes.
        - `"valid"` indicates to only perform the convolution with valid values of the
          image, i.e. no padding.
        - `"same"` indicates to pad the input such that a stride-1 convolution would
          produce an output of the same spatial size.
          Convolutions with higher stride will use the same padding scheme, but result
          in outputs of reduced size.
    dilation : int, (int, ...) = 1
        The dilation passed to `conv`, either for all spatial dimensions or for each
        separately.
    groups : int = 1
        The number of convolutional groups for this convolution.
    group_broadcasting : bool = False
        Whether to take the input kernels as a single output group, and broadcast
        across all input groups.
        `group_broadcasting` has no effect when `groups=1`

    Examples
    -------

    >>> import pytorch_nd_semiconv as semiconv
    >>> common_closing = semiconv.GenericClosing(
    ...     semiconv.QuadraticKernelCholesky2D(5, 5, 3),
    ...     semiconv.SelectSemifield.tropical_max().lazy_fixed(),
    ...     semiconv.SelectSemifield.tropical_min_negated().lazy_fixed(),
    ...     padding="same",
    ...     groups=5,
    ... )
    """

    def __init__(
        self,
        kernel: nn.Module,
        conv_dilation: nn.Module,
        conv_erosion: nn.Module,
        stride: int | tuple[int, ...] = 1,
        padding: (
            int
            | tuple[int, ...]
            | tuple[tuple[int, int], ...]
            | Literal["valid", "same"]
        ) = 0,
        dilation: int | tuple[int, ...] = 1,
        groups: int = 1,
        group_broadcasting: bool = False,
    ):
        super().__init__()
        self.padding = padding
        self.stride = stride
        self.dilation = dilation
        self.kernel = kernel
        self.conv_dilation = conv_dilation
        self.conv_erosion = conv_erosion
        self.groups = groups
        self.group_broadcasting = group_broadcasting
        self.kind = "closing"

        # Since these are custom arguments, we only want to pass them if they differ
        # from the default values (otherwise, they may be unexpected)
        self.kwargs = {}
        if self.group_broadcasting:
            self.kwargs["group_broadcasting"] = True

    def forward(self, x):
        kernel = self.kernel()
        dilated = self.conv_dilation(
            x,
            kernel,
            dilation=self.dilation,
            padding=self.padding,
            stride=self.stride,
            groups=self.groups,
            kind="conv",
            **self.kwargs,
        )
        closed = self.conv_erosion(
            dilated,
            kernel,
            dilation=self.dilation,
            padding=self.padding,
            stride=self.stride,
            groups=self.groups,
            kind="corr",
            **self.kwargs,
        )
        return closed

    extra_repr = GenericConv.extra_repr

A generic Module for implementing a morphological closing with a dilation and an erosion.

kind is fixed to "conv" for dilation and "corr" for erosion, to simplify the implementation of the common morphological closing.

Parameters

kernel : nn.Module
A module that produces a convolutional kernel from its forward method.
conv_dilation : nn.Module
A module representing the adjoint dilation that can take image, kernel as positional arguments, as well as dilation, padding, stride, groups and kind as keyword arguments, optionally supporting group_broadcasting.
conv_erosion : nn.Module
A module representing the adjoint erosion that can take image, kernel as positional arguments, as well as dilation, padding, stride, groups and kind as keyword arguments, optionally supporting group_broadcasting.
stride : int, (int, ...) = 1
The stride passed to conv, either for all spatial dimensions or for each separately.
padding : int, (int, ...), ((int, int), ...), "valid", "same" = 0

The padding passed to conv. Depending on the type of padding:

  • P indicates padding at the start and end of all spatial axes with P.
  • (P0, …) indicates padding at the start and end of the first spatial axis with P0, and similarly for all other spatial axes.
  • ((PBeg0, PEnd0), …) indicates padding the start of the first spatial axis with PBeg0 and the end with PEnd0, similarly for all other spatial axes.
  • "valid" indicates to only perform the convolution with valid values of the image, i.e. no padding.
  • "same" indicates to pad the input such that a stride-1 convolution would produce an output of the same spatial size. Convolutions with higher stride will use the same padding scheme, but result in outputs of reduced size.
dilation : int, (int, ...) = 1
The dilation passed to conv, either for all spatial dimensions or for each separately.
groups : int = 1
The number of convolutional groups for this convolution.
group_broadcasting : bool = False
Whether to take the input kernels as a single output group, and broadcast across all input groups. group_broadcasting has no effect when groups=1

Examples

>>> import pytorch_nd_semiconv as semiconv
>>> common_closing = semiconv.GenericClosing(
...     semiconv.QuadraticKernelCholesky2D(5, 5, 3),
...     semiconv.SelectSemifield.tropical_max().lazy_fixed(),
...     semiconv.SelectSemifield.tropical_min_negated().lazy_fixed(),
...     padding="same",
...     groups=5,
... )
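
Applying the module from the example above is then a single forward call. A sketch, assuming a CUDA device (the compiled semifield convolutions target CUDA) and an illustrative input size; with padding="same" and the default stride of 1, the spatial size should be preserved:

import torch

x = torch.randn(2, 5, 64, 64, device="cuda")
closed = common_closing(x)  # expected shape: (2, 5, 64, 64)
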
class LearnedKernel2D (in_channels, out_channels, kernel_size)
Expand source code Browse git
class LearnedKernel2D(nn.Module):
    """
    A utility that provides a fully learnable kernel compatible with `GenericConv`

    Parameters
    -------
    in_channels : int
        The number of input channels: the `I` in `OIHW`.
    out_channels : int
        The number of output channels: the `O` in `OIHW`.
    kernel_size : int
        The height `H` and width `W` of the kernel (rectangular kernels not supported).
    """

    def __init__(self, in_channels: int, out_channels: int, kernel_size: int):
        super().__init__()
        self.kernel = nn.Parameter(
            torch.empty(out_channels, in_channels, kernel_size, kernel_size)
        )
        nn.init.normal_(self.kernel)

    def forward(self):
        return self.kernel

A utility that provides a fully learnable kernel compatible with GenericConv

Parameters

in_channels : int
The number of input channels: the I in OIHW.
out_channels : int
The number of output channels: the O in OIHW.
kernel_size : int
The height H and width W of the kernel (rectangular kernels not supported).
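
A short shape check in the style of the other kernel modules:

>>> from pytorch_nd_semiconv import LearnedKernel2D
>>> kernel = LearnedKernel2D(5, 6, 3)
>>> tuple(kernel().shape)
(6, 5, 3, 3)
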
class TorchLinearConv2D (*args, **kwargs)
Expand source code Browse git
class TorchLinearConv2D(nn.Module):
    """
    A utility that provides PyTorch Conv2D in a form compatible with `GenericConv`.
    """

    @staticmethod
    def forward(
        img: torch.Tensor,
        kernel: torch.Tensor,
        stride: int | tuple[int, int] = 1,
        padding: (
            int
            | tuple[int, int]
            | tuple[tuple[int, int], tuple[int, int]]
            | Literal["valid", "same"]
        ) = 0,
        dilation: int | tuple[int, int] = 1,
        groups: int = 1,
        group_broadcasting: bool = False,
        kind: Literal["conv", "corr"] = "conv",
    ):
        if group_broadcasting:
            if kernel.shape[0] != 1:
                raise ValueError("Torch conv2d cannot broadcast groups with grp_o > 1")

            kernel = kernel.broadcast_to(
                (groups, kernel.shape[1], kernel.shape[2], kernel.shape[3])
            )
        if kind == "conv":
            kernel = kernel.flip((2, 3))

        dilation = _as_tup_n(dilation, 2)
        (pad_y_beg, pad_y_end), (pad_x_beg, pad_x_end) = get_padding(
            padding, 2, dilation, kernel.shape[2:]
        )

        if pad_y_beg != pad_y_end or pad_x_beg != pad_x_end:
            padded = torch.constant_pad_nd(
                img,
                # Yes, the padding really is in this order.
                (pad_x_beg, pad_x_end, pad_y_beg, pad_y_end),
            )
            return torch.nn.functional.conv2d(
                padded, kernel, stride=stride, dilation=dilation, groups=groups
            )

        return torch.nn.functional.conv2d(
            img,
            kernel,
            stride=stride,
            dilation=dilation,
            groups=groups,
            padding=(pad_y_beg, pad_x_beg),
        )

A utility that provides PyTorch Conv2D in a form compatible with GenericConv.
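
A sketch of using this wrapper as the conv argument of GenericConv, reproducing an ordinary linear convolution through the same interface (channel counts and kernel size are illustrative):

import pytorch_nd_semiconv as semiconv

linear_conv = semiconv.GenericConv(
    semiconv.LearnedKernel2D(3, 8, 5),   # fully learnable OIHW kernel
    semiconv.TorchLinearConv2D(),        # PyTorch conv2d backend
    padding="same",
)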

class TorchMaxPool2D (kernel_size, stride=None, padding=0, dilation=1)
Expand source code Browse git
class TorchMaxPool2D(nn.Module):
    """
    A utility that provides torch.nn.MaxPool2d with padding like `GenericConv`.
    """

    def __init__(
        self,
        kernel_size: int | tuple[int, int],
        stride: int | tuple[int, int] | None = None,
        padding: (
            int
            | tuple[int, int]
            | tuple[tuple[int, int], tuple[int, int]]
            | Literal["valid", "same"]
        ) = 0,
        dilation: int | tuple[int, int] = 1,
    ):
        super().__init__()
        self.kernel_size = kernel_size
        self.stride = kernel_size if stride is None else stride
        self.padding = padding
        self.dilation = dilation

    def forward(
        self,
        img: torch.Tensor,
    ):
        dilation = _as_tup_n(self.dilation, 2)
        krn_spatial = _as_tup_n(self.kernel_size, 2)
        (pad_y_beg, pad_y_end), (pad_x_beg, pad_x_end) = get_padding(
            self.padding, 2, dilation, krn_spatial
        )

        if pad_y_beg == pad_y_end and pad_x_beg == pad_x_end:
            use_padding = (pad_y_beg, pad_x_beg)
        else:
            img = torch.constant_pad_nd(
                img,
                # Yes, the padding really is in this order.
                (pad_x_beg, pad_x_end, pad_y_beg, pad_y_end),
            )
            use_padding = 0

        return torch.nn.functional.max_pool2d(
            input=img,
            kernel_size=krn_spatial,
            stride=self.stride,
            padding=use_padding,
            dilation=dilation,
            ceil_mode=False,
            return_indices=False,
        )

A utility that provides torch.nn.MaxPool2d with padding like GenericConv.
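
A short usage sketch; with "same" padding resolved by the module itself, a 3x3 window at stride 2 halves the spatial size (sizes are illustrative):

import torch
import pytorch_nd_semiconv as semiconv

pool = semiconv.TorchMaxPool2D(3, stride=2, padding="same")
out = pool(torch.randn(1, 5, 32, 32))  # expected shape: (1, 5, 16, 16)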