Coverage for src/diffusionlab/dynamics.py: 100%
29 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-19 14:17 -0700
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-19 14:17 -0700
1from dataclasses import dataclass, field
2from typing import Callable
4import jax
5from jax import Array, numpy as jnp
@dataclass(frozen=True)
class DiffusionProcess:
    """
    Base class for implementing various diffusion processes.

    A diffusion process defines how data evolves over time when noise is added according to
    specific dynamics operating on scalar time inputs. This class provides a framework to
    implement diffusion processes based on a schedule defined by ``α(t)`` and ``σ(t)``.

    The diffusion is parameterized by two scalar functions of scalar time ``t``:

    - ``α(t)``: Controls how much of the original signal is preserved at time ``t``.
    - ``σ(t)``: Controls how much noise is added at time ``t``.

    The forward process for a single data point ``x_0`` is defined as:

    ``x_t = α(t) * x_0 + σ(t) * ε``

    where:

    - ``x_0`` is the original data (``Array[*data_dims]``)
    - ``x_t`` is the noised data at time ``t`` (``Array[*data_dims]``)
    - ``ε`` is random noise sampled from a standard Gaussian distribution (``Array[*data_dims]``)
    - ``t`` is the scalar diffusion time parameter (``Array[]``)

    Equality and hashing are determined by ``alpha`` and ``sigma`` only. The derivative
    fields are rebuilt with ``jax.grad`` on every construction and are therefore excluded
    from comparison, so two processes created from the same schedule callables are equal.

    Attributes:
        alpha (``Callable[[Array[]], Array[]]``): Function mapping scalar time ``t`` -> scalar signal coefficient ``α(t)``.
        sigma (``Callable[[Array[]], Array[]]``): Function mapping scalar time ``t`` -> scalar noise coefficient ``σ(t)``.
        alpha_prime (``Callable[[Array[]], Array[]]``): Derivative of ``α`` w.r.t. scalar time ``t``. Set to ``jax.grad(alpha)`` after initialization; not an ``__init__`` argument.
        sigma_prime (``Callable[[Array[]], Array[]]``): Derivative of ``σ`` w.r.t. scalar time ``t``. Set to ``jax.grad(sigma)`` after initialization; not an ``__init__`` argument.
    """

    alpha: Callable[[Array], Array]
    sigma: Callable[[Array], Array]
    # Derived from ``alpha`` / ``sigma`` in ``__post_init__``. ``jax.grad`` hands back a
    # new wrapper object on each call, so including these in ``__eq__`` / ``__hash__``
    # would make otherwise identical processes compare unequal.
    alpha_prime: Callable[[Array], Array] = field(init=False, compare=False)
    sigma_prime: Callable[[Array], Array] = field(init=False, compare=False)

    def __post_init__(self) -> None:
        """Populate ``alpha_prime`` and ``sigma_prime`` with ``jax.grad`` of the schedules."""
        # The dataclass is frozen, so regular attribute assignment raises;
        # ``object.__setattr__`` bypasses the frozen guard for these derived fields.
        object.__setattr__(self, "alpha_prime", jax.grad(self.alpha))
        object.__setattr__(self, "sigma_prime", jax.grad(self.sigma))

    def forward(self, x: Array, t: Array, eps: Array) -> Array:
        """
        Applies the forward diffusion process to a data tensor ``x`` at time ``t`` using noise ``ε``.

        Computes ``x_t = α(t) * x + σ(t) * ε``.

        Args:
            x (``Array[*data_dims]``): The input data tensor ``x_0``.
            t (``Array[]``): The scalar time parameter ``t``.
            eps (``Array[*data_dims]``): The Gaussian noise tensor ``ε``, matching the shape of ``x``.

        Returns:
            ``Array[*data_dims]``: The noised data tensor ``x_t`` at time ``t``.
        """
        alpha_t = self.alpha(t)
        sigma_t = self.sigma(t)
        return alpha_t * x + sigma_t * eps
@dataclass(frozen=True)
class VarianceExplodingProcess(DiffusionProcess):
    """
    Variance Exploding (VE) diffusion process.

    The signal coefficient is held constant at one while the noise coefficient follows the
    user-supplied ``σ(t)``, so the variance of the noised sample ``x_t`` grows with ``σ(t)``.

    Forward process:

    ``x_t = x_0 + σ(t) * ε``.

    Schedule:

    - ``α(t) = 1``
    - ``σ(t) =`` supplied by the caller

    Attributes:
        alpha (``Callable[[Array[]], Array[]]``): Maps scalar time ``t`` to the signal coefficient ``α(t)``; always 1.
        sigma (``Callable[[Array[]], Array[]]``): Maps scalar time ``t`` to the noise coefficient ``σ(t)``; supplied by the caller.
        alpha_prime (``Callable[[Array[]], Array[]]``): Derivative of ``α`` with respect to ``t``; evaluates to 0.
        sigma_prime (``Callable[[Array[]], Array[]]``): Derivative of ``σ`` with respect to ``t``.
    """

    def __init__(self, sigma: Callable[[Array], Array]):
        """
        Build a Variance Exploding process from a noise schedule.

        Args:
            sigma (``Callable[[Array], Array]``): Maps scalar time ``t`` to the scalar noise coefficient ``σ(t)``.
        """

        def unit_signal(t):
            # Constant-one schedule shaped and typed like ``t``.
            return jnp.ones_like(t)

        super().__init__(alpha=unit_signal, sigma=sigma)
@dataclass(frozen=True)
class VariancePreservingProcess(DiffusionProcess):
    """
    Variance Preserving (VP) diffusion process, as commonly used in DDPMs.

    Signal and noise are scaled so that the variance of ``x_t`` stays close to 1
    throughout the diffusion, provided ``x_0`` and ``ε`` both have unit variance.

    Schedule:

    - ``α(t) = sqrt(1 - t²)``
    - ``σ(t) = t``

    Forward process:

    ``x_t = sqrt(1 - t²) * x_0 + t * ε``.

    Attributes:
        alpha (``Callable[[Array[]], Array[]]``): Maps scalar time ``t`` to the signal coefficient ``α(t) = sqrt(1 - t²)``.
        sigma (``Callable[[Array[]], Array[]]``): Maps scalar time ``t`` to the noise coefficient ``σ(t) = t``.
        alpha_prime (``Callable[[Array[]], Array[]]``): Derivative of ``α`` with respect to ``t``, i.e. ``-t / sqrt(1 - t²)``.
        sigma_prime (``Callable[[Array[]], Array[]]``): Derivative of ``σ`` with respect to ``t``, i.e. ``1``.
    """

    def __init__(self):
        """
        Build a Variance Preserving process; the schedule is fixed and takes no arguments.
        """

        def signal_coeff(t):
            # sqrt(1 - t²), with the constant one shaped and typed like ``t``.
            return jnp.sqrt(jnp.ones_like(t) - t**2)

        def noise_coeff(t):
            return t

        super().__init__(alpha=signal_coeff, sigma=noise_coeff)
@dataclass(frozen=True)
class FlowMatchingProcess(DiffusionProcess):
    """
    Diffusion process built on Flow Matching principles.

    The dynamics interpolate linearly between the data distribution at ``t=0`` and a
    standard Gaussian noise distribution at ``t=1``.

    Schedule:

    - ``α(t) = 1 - t``
    - ``σ(t) = t``

    Forward process:

    ``x_t = (1 - t) * x_0 + t * ε``.

    Attributes:
        alpha (``Callable[[Array[]], Array[]]``): Maps scalar time ``t`` to the signal coefficient ``α(t) = 1 - t``.
        sigma (``Callable[[Array[]], Array[]]``): Maps scalar time ``t`` to the noise coefficient ``σ(t) = t``.
        alpha_prime (``Callable[[Array[]], Array[]]``): Derivative of ``α`` with respect to ``t``, i.e. ``-1``.
        sigma_prime (``Callable[[Array[]], Array[]]``): Derivative of ``σ`` with respect to ``t``, i.e. ``1``.
    """

    def __init__(self):
        """
        Build a Flow Matching process; the linear schedule is fixed and takes no arguments.
        """

        def signal_coeff(t):
            # 1 - t, with the constant one shaped and typed like ``t``.
            return jnp.ones_like(t) - t

        def noise_coeff(t):
            return t

        super().__init__(alpha=signal_coeff, sigma=noise_coeff)