Coverage for src/diffusionlab/distributions/gmm/iso_gmm.py: 100%
47 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-19 14:17 -0700
1from dataclasses import dataclass
2from typing import Tuple, cast
3from jax import numpy as jnp, Array
4import jax
5from diffusionlab.distributions.base import Distribution
6from diffusionlab.distributions.gmm.gmm import GMM
7from diffusionlab.distributions.gmm.utils import create_gmm_vector_field_fns
8from diffusionlab.dynamics import DiffusionProcess
@dataclass(frozen=True)
class IsoGMM(Distribution):
    """
    Isotropic Gaussian Mixture Model (GMM) distribution.

    Every component has a single scalar variance, so its covariance is that variance times
    the identity matrix. The probability measure is:

    ``μ(A) = sum_{i=1}^{num_components} priors[i] * N(A; means[i], variances[i] * I)``

    Besides sampling, the class exposes the vector fields associated with the distribution
    under a diffusion process (``VectorFieldType.SCORE``, ``VectorFieldType.X0``,
    ``VectorFieldType.EPS``, ``VectorFieldType.V``). Each one is delegated to the matching
    module-level ``iso_gmm_*`` function.

    Attributes:
        dist_params (``Dict[str, Array]``): The GMM parameters, stored under these keys:

            - ``means`` (``Array[num_components, data_dim]``): Component means.
            - ``variances`` (``Array[num_components]``): Scalar variance of each component.
            - ``priors`` (``Array[num_components]``): Mixture weights of the components.

        dist_hparams (``Dict[str, Any]``): Hyperparameters; this class always stores an empty dict.
    """

    def __init__(self, means: Array, variances: Array, priors: Array):
        """
        Validates the parameters and builds the isotropic GMM.

        Args:
            means (``Array[num_components, data_dim]``): Mean of each Gaussian component.
            variances (``Array[num_components]``): Scalar variance of each Gaussian component.
                Values down to ``-eps`` (machine epsilon of the variance dtype) are tolerated.
            priors (``Array[num_components]``): Mixture weights. They must sum to 1 within
                the same ``eps``.

        Raises:
            AssertionError: If a shape, normalization, or non-negativity check fails.
                These checks are plain ``assert`` statements, so ``python -O`` skips them.
        """
        # The machine epsilon of the variance dtype is the tolerance for both numeric checks.
        tol = cast(float, jnp.finfo(variances.dtype).eps)
        assert means.ndim == 2
        k, _ = means.shape
        assert variances.shape == (k,)
        assert priors.shape == (k,)
        assert jnp.isclose(jnp.sum(priors), 1.0, atol=tol)
        assert jnp.all(variances >= -tol)

        params = {"means": means, "variances": variances, "priors": priors}
        super().__init__(dist_params=params, dist_hparams={})

    def _gmm_params(self) -> Tuple[Array, Array, Array]:
        """Returns ``(means, variances, priors)`` in the argument order of the ``iso_gmm_*`` functions."""
        p = self.dist_params
        return p["means"], p["variances"], p["priors"]

    def sample(self, key: Array, num_samples: int) -> Tuple[Array, Array]:
        """
        Draws ``num_samples`` points from the isotropic GMM.

        Each isotropic covariance is expanded into the dense matrix ``variances[i] * I``.
        Sampling is then delegated to the general ``GMM`` class.

        Args:
            key (``Array``): JAX PRNG key used for the random draws.
            num_samples (``int``): How many samples to generate in total.

        Returns:
            ``Tuple[Array[num_samples, data_dim], Array[num_samples]]``: ``(samples, component_indices)``,
            i.e. the samples and, for each one, the index of the component it came from.
        """
        means, variances, priors = self._gmm_params()
        identity = jnp.eye(means.shape[1])
        # (num_components, data_dim, data_dim): one scaled identity matrix per component.
        covs = variances[:, None, None] * identity
        return GMM(means, covs, priors).sample(key, num_samples)

    def score(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array:
        """
        Evaluates the score ``∇_x log p_t(x_t)`` of the noised isotropic GMM.

        ``p_t`` is the distribution of ``x_t`` induced by ``diffusion_process`` at time ``t``.

        Args:
            x_t (``Array[data_dim]``): Noisy state at time ``t``.
            t (``Array[]``): Scalar time step.
            diffusion_process (``DiffusionProcess``): The diffusion process definition.

        Returns:
            ``Array[data_dim]``: The score at ``(x_t, t)``.
        """
        return iso_gmm_score(x_t, t, diffusion_process, *self._gmm_params())

    def x0(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array:
        """
        Evaluates the denoiser ``x0 = E[x_0 | x_t]`` of the isotropic GMM.

        This is the posterior mean of the clean sample ``x_0`` given the noisy observation
        ``x_t`` at time ``t`` under ``diffusion_process``.

        Args:
            x_t (``Array[data_dim]``): Noisy state at time ``t``.
            t (``Array[]``): Scalar time step.
            diffusion_process (``DiffusionProcess``): The diffusion process definition.

        Returns:
            ``Array[data_dim]``: The denoised prediction ``x0`` at ``(x_t, t)``.
        """
        return iso_gmm_x0(x_t, t, diffusion_process, *self._gmm_params())

    def eps(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array:
        """
        Evaluates the noise prediction ``ε`` of the isotropic GMM.

        ``ε`` is the prediction of the noise that turned the clean sample ``x_0`` into ``x_t``
        at time ``t`` under ``diffusion_process``.

        Args:
            x_t (``Array[data_dim]``): Noisy state at time ``t``.
            t (``Array[]``): Scalar time step.
            diffusion_process (``DiffusionProcess``): The diffusion process definition.

        Returns:
            ``Array[data_dim]``: The noise prediction ``ε`` at ``(x_t, t)``.
        """
        return iso_gmm_eps(x_t, t, diffusion_process, *self._gmm_params())

    def v(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array:
        """
        Evaluates the velocity field ``v`` of the isotropic GMM.

        ``v`` is the conditional velocity ``E[dx_t/dt | x_t]`` under ``diffusion_process``.

        Args:
            x_t (``Array[data_dim]``): Noisy state at time ``t``.
            t (``Array[]``): Scalar time step.
            diffusion_process (``DiffusionProcess``): The diffusion process definition.

        Returns:
            ``Array[data_dim]``: The velocity ``v`` at ``(x_t, t)``.
        """
        return iso_gmm_v(x_t, t, diffusion_process, *self._gmm_params())
def iso_gmm_x0(
    x_t: Array,
    t: Array,
    diffusion_process: DiffusionProcess,
    means: Array,
    variances: Array,
    priors: Array,
) -> Array:
    """
    Computes the denoised prediction ``x0 = E[x_0 | x_t]`` for an isotropic GMM.

    This is the closed-form conditional expectation ``E[x_0 | x_t]`` when
    ``x_t ~ N(α_t x_0, σ_t^2 I)`` and ``x_0`` follows the isotropic GMM defined by
    ``means``, ``variances``, and ``priors``. Component ``i`` is ``N(means[i], variances[i] * I)``.

    After the forward process, component ``i`` of ``x_t`` is distributed as
    ``N(α_t means[i], (α_t^2 variances[i] + σ_t^2) I)``. The prediction is the
    posterior-weighted average of the per-component posterior means:

    ``means[i] + α_t variances[i] / (α_t^2 variances[i] + σ_t^2) * (x_t - α_t means[i])``

    Args:
        x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
        t (``Array[]``): The time step (scalar).
        diffusion_process (``DiffusionProcess``): Provides ``α(t)`` and ``σ(t)``.
        means (``Array[num_components, data_dim]``): Means of the GMM components.
        variances (``Array[num_components]``): Scalar variances of the GMM components.
            Component ``i`` has covariance ``variances[i] * I``.
        priors (``Array[num_components]``): Mixture weights of the GMM components.
            A zero weight gives ``log(0) = -inf``, which the softmax maps to zero posterior weight.

    Returns:
        ``Array[data_dim]``: The denoised prediction ``x0`` evaluated at ``x_t`` and ``t``.
    """
    data_dim = means.shape[1]
    alpha_t = diffusion_process.alpha(t)
    sigma_t = diffusion_process.sigma(t)

    # Parameters of each component after the forward (noising) process.
    means_t = alpha_t * means  # (num_components, data_dim)
    variances_t = alpha_t**2 * variances + sigma_t**2  # (num_components,)

    # Residual of x_t from every noisy component mean, and that residual scaled by 1 / variance_t.
    xbars_t = x_t - means_t  # (num_components, data_dim)
    variances_t_inv_xbars_t = xbars_t / variances_t[:, None]  # (num_components, data_dim)

    # Gaussian log-likelihood of x_t under each noisy component. The constant
    # -data_dim/2 * log(2π) is shared by all components, so it is dropped: softmax cancels it.
    log_likelihoods_unnormalized = -0.5 * (
        jnp.sum(xbars_t * variances_t_inv_xbars_t, axis=1)
        + data_dim * jnp.log(variances_t)
    )  # (num_components,)
    log_posterior_unnormalized = jnp.log(priors) + log_likelihoods_unnormalized
    posterior_probs = jax.nn.softmax(
        log_posterior_unnormalized, axis=0
    )  # (num_components,), sums to 1

    # E[x_0 | x_t, component i] for every component.
    posterior_means = (
        means + alpha_t * variances[:, None] * variances_t_inv_xbars_t
    )  # (num_components, data_dim)

    x0_pred = jnp.sum(posterior_probs[:, None] * posterior_means, axis=0)  # (data_dim,)
    return x0_pred
# The noise-prediction (eps), score, and velocity (v) vector fields of the isotropic
# GMM are all derived from its closed-form denoiser ``iso_gmm_x0``. They are unpacked
# from ``create_gmm_vector_field_fns`` in the order eps, score, v.
(
    iso_gmm_eps,
    iso_gmm_score,
    iso_gmm_v,
) = create_gmm_vector_field_fns(iso_gmm_x0)