Coverage for src/diffusionlab/distributions/gmm/low_rank_gmm.py: 100%
45 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-19 14:17 -0700
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-19 14:17 -0700
1from typing import Tuple, cast
2from dataclasses import dataclass
3from jax import numpy as jnp, Array
4import jax
5from diffusionlab.distributions.base import Distribution
6from diffusionlab.distributions.gmm.gmm import GMM
7from diffusionlab.distributions.gmm.utils import (
8 _logdeth,
9 _lstsq,
10 create_gmm_vector_field_fns,
11)
12from diffusionlab.dynamics import DiffusionProcess
@dataclass(frozen=True)
class LowRankGMM(Distribution):
    """
    Gaussian Mixture Model (GMM) whose component covariances are stored in low-rank factored form.

    The probability measure is given by:

    ``μ(A) = sum_{i=1}^{num_components} priors[i] * N(A; means[i], cov_factors[i] @ cov_factors[i].T)``

    Besides sampling, the class exposes the vector fields (``VectorFieldType.SCORE``, ``VectorFieldType.X0``, ``VectorFieldType.EPS``, ``VectorFieldType.V``) of the distribution under a given diffusion process. Each of those methods forwards to the matching module-level ``low_rank_gmm_*`` function.

    Attributes:
        dist_params (``Dict[str, Array]``): The low-rank GMM parameters, stored under three keys.

            - ``means`` (``Array[num_components, data_dim]``): Mean of each component.
            - ``cov_factors`` (``Array[num_components, data_dim, rank]``): Factor ``F_i`` of each component; the component covariance is ``F_i @ F_i.T``.
            - ``priors`` (``Array[num_components]``): Mixture weight of each component.

        dist_hparams (``Dict[str, Any]``): Hyperparameter dictionary; always passed as empty by this class.
    """

    def __init__(self, means: Array, cov_factors: Array, priors: Array):
        """
        Validates the parameter shapes and stores them on the base distribution.

        Args:
            means (``Array[num_components, data_dim]``): Mean of each Gaussian component.
            cov_factors (``Array[num_components, data_dim, rank]``): Low-rank covariance factor of each Gaussian component.
            priors (``Array[num_components]``): Mixture weight of each component. Must sum to 1.

        Raises:
            AssertionError: If the shapes are inconsistent or ``priors`` does not sum to 1.
            ValueError: If ``cov_factors`` is not 3-dimensional (shape unpacking fails).
        """
        # Absolute tolerance for the "priors sum to one" check: machine epsilon
        # of the covariance-factor dtype.
        tolerance = cast(float, jnp.finfo(cov_factors.dtype).eps)
        assert means.ndim == 2
        num_components, data_dim, _rank = cov_factors.shape
        assert means.shape == (num_components, data_dim)
        assert priors.shape == (num_components,)
        assert jnp.isclose(jnp.sum(priors), 1.0, atol=tolerance)

        params = {
            "means": means,
            "cov_factors": cov_factors,
            "priors": priors,
        }
        super().__init__(dist_params=params, dist_hparams={})

    def sample(self, key: Array, num_samples: int) -> Tuple[Array, Array]:
        """
        Draws samples from the low-rank GMM distribution.

        Each factor is expanded into its dense covariance ``F @ F.T`` and the
        actual sampling is delegated to a full-covariance ``GMM``.

        Args:
            key (``Array``): JAX PRNG key for random sampling.
            num_samples (``int``): The total number of samples to generate.

        Returns:
            ``Tuple[Array[num_samples, data_dim], Array[num_samples]]``: A tuple ``(samples, component_indices)`` with the drawn samples and the index of the component each sample came from, as returned by ``GMM.sample``.
        """
        params = self.dist_params

        def _dense_cov(factor: Array) -> Array:
            # (data_dim, rank) -> (data_dim, data_dim)
            return factor @ factor.T

        dense_covs = jax.vmap(_dense_cov)(params["cov_factors"])
        dense_gmm = GMM(params["means"], dense_covs, params["priors"])
        return dense_gmm.sample(key, num_samples)

    def score(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array:
        """
        Computes the score vector field ``∇_x log p_t(x_t)`` for the low-rank GMM distribution.

        The score is taken with respect to the perturbed distribution ``p_t`` induced by
        ``diffusion_process`` at time ``t``.

        Args:
            x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
            t (``Array[]``): The time step (scalar).
            diffusion_process (``DiffusionProcess``): The diffusion process definition.

        Returns:
            ``Array[data_dim]``: The score vector field evaluated at ``x_t`` and ``t``.
        """
        params = self.dist_params
        return low_rank_gmm_score(
            x_t,
            t,
            diffusion_process,
            params["means"],
            params["cov_factors"],
            params["priors"],
        )

    def x0(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array:
        """
        Computes the denoised prediction ``x0 = E[x_0 | x_t]`` for the low-rank GMM distribution.

        This is the expected clean sample ``x_0`` given the noisy observation ``x_t``
        at time ``t`` under ``diffusion_process``.

        Args:
            x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
            t (``Array[]``): The time step (scalar).
            diffusion_process (``DiffusionProcess``): The diffusion process definition.

        Returns:
            ``Array[data_dim]``: The denoised prediction ``x0`` evaluated at ``x_t`` and ``t``.
        """
        params = self.dist_params
        return low_rank_gmm_x0(
            x_t,
            t,
            diffusion_process,
            params["means"],
            params["cov_factors"],
            params["priors"],
        )

    def eps(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array:
        """
        Computes the noise prediction ``ε`` for the low-rank GMM distribution.

        This predicts the noise that was added to the clean sample ``x_0`` to obtain ``x_t``
        at time ``t`` under ``diffusion_process``.

        Args:
            x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
            t (``Array[]``): The time step (scalar).
            diffusion_process (``DiffusionProcess``): The diffusion process definition.

        Returns:
            ``Array[data_dim]``: The noise prediction ``ε`` evaluated at ``x_t`` and ``t``.
        """
        params = self.dist_params
        return low_rank_gmm_eps(
            x_t,
            t,
            diffusion_process,
            params["means"],
            params["cov_factors"],
            params["priors"],
        )

    def v(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array:
        """
        Computes the velocity vector field ``v`` for the low-rank GMM distribution.

        This is the conditional velocity ``E[dx_t/dt | x_t]`` under ``diffusion_process``.

        Args:
            x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
            t (``Array[]``): The time step (scalar).
            diffusion_process (``DiffusionProcess``): The diffusion process definition.

        Returns:
            ``Array[data_dim]``: The velocity vector field ``v`` evaluated at ``x_t`` and ``t``.
        """
        params = self.dist_params
        return low_rank_gmm_v(
            x_t,
            t,
            diffusion_process,
            params["means"],
            params["cov_factors"],
            params["priors"],
        )
def low_rank_gmm_x0(
    x_t: Array,
    t: Array,
    diffusion_process: DiffusionProcess,
    means: Array,
    cov_factors: Array,
    priors: Array,
) -> Array:
    """
    Computes the denoised prediction ``x0 = E[x_0 | x_t]`` for a low-rank GMM.

    This implements the closed-form conditional expectation ``E[x_0 | x_t]`` where
    ``x_t ~ N(α_t x_0, σ_t^2 I)`` and ``x_0`` follows the low-rank GMM defined by
    ``means``, ``cov_factors``, and ``priors``. With ``F_i = cov_factors[i]``, the
    covariance of ``x_t`` under component ``i`` is ``α_t^2 F_i F_i^T + σ_t^2 I``;
    it is never formed densely. Its inverse and log-determinant are obtained from
    the ``rank x rank`` matrix ``F_i^T F_i`` (Woodbury identity and matrix
    determinant lemma).

    The expressions divide by ``σ_t``, so ``σ_t`` must be non-zero. ``α_t == 0`` is
    supported: the prediction then reduces to the prior-weighted mean of ``means``.

    Args:
        x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
        t (``Array[]``): The time step (scalar).
        diffusion_process (``DiffusionProcess``): Provides ``α(t)`` and ``σ(t)``.
        means (``Array[num_components, data_dim]``): Means of the GMM components.
        cov_factors (``Array[num_components, data_dim, rank]``): Low-rank covariance factors of the GMM components.
        priors (``Array[num_components]``): Mixture weights of the GMM components.

    Returns:
        ``Array[data_dim]``: The denoised prediction ``x0`` evaluated at ``x_t`` and ``t``.
    """
    _, data_dim, rank = cov_factors.shape
    alpha_t = diffusion_process.alpha(t)
    sigma_t = diffusion_process.sigma(t)

    means_t = jax.vmap(lambda mean: alpha_t * mean)(means)  # (num_components, data_dim)
    inner_covs = jax.vmap(lambda cov_factor: cov_factor.T @ cov_factor)(
        cov_factors
    )  # (num_components, rank, rank)

    xbars_t = jax.vmap(lambda mean_t: x_t - mean_t)(
        means_t
    )  # (num_components, data_dim)

    # Woodbury identity:
    #   (σ² I + α² F Fᵀ)⁻¹ x = (1/σ²) (x - α² F (α² FᵀF + σ² I)⁻¹ Fᵀ x)
    # The inner system is written as α² FᵀF + σ² I rather than the equivalent
    # FᵀF + (σ/α)² I so that no division by α_t occurs; the latter turns into
    # inf * eye(rank) (NaN off the diagonal) when α_t == 0.
    covs_t_inverse_xbars_t = jax.vmap(
        lambda cov_factor, inner_cov, xbar_t: (1 / sigma_t**2)
        * (
            xbar_t
            - alpha_t**2
            * (
                cov_factor
                @ _lstsq(
                    alpha_t**2 * inner_cov + sigma_t**2 * jnp.eye(rank),
                    cov_factor.T @ xbar_t,
                )
            )
        )
    )(cov_factors, inner_covs, xbars_t)  # (num_components, data_dim)

    # Matrix determinant lemma:
    #   log det(σ² I + α² F Fᵀ) = 2 d log σ + log det(I + (α/σ)² FᵀF)
    logdets_covs_t = jax.vmap(
        lambda inner_cov: _logdeth((alpha_t / sigma_t) ** 2 * inner_cov + jnp.eye(rank))
    )(inner_covs) + 2 * data_dim * jnp.log(sigma_t)  # (num_components,)

    # Gaussian log-density of x_t under each component, up to a constant that
    # is shared by all components and therefore cancels in the softmax below.
    log_likelihoods_unnormalized = jax.vmap(
        lambda xbar_t, covs_t_inverse_xbar_t, logdet_covs_t: -(1 / 2)
        * (jnp.sum(xbar_t * covs_t_inverse_xbar_t) + logdet_covs_t)
    )(xbars_t, covs_t_inverse_xbars_t, logdets_covs_t)  # (num_components,)

    log_posterior_unnormalized = (
        jnp.log(priors) + log_likelihoods_unnormalized
    )  # (num_components,)

    # Posterior probability of each component given x_t.
    posterior_probs = jax.nn.softmax(
        log_posterior_unnormalized, axis=0
    )  # (num_components,)

    # Per-component posterior mean: mean_i + α F_i F_iᵀ Σ_{t,i}⁻¹ (x_t - α mean_i).
    posterior_means = jax.vmap(
        lambda mean, cov_factor, covs_t_inverse_xbar_t: mean
        + alpha_t * cov_factor @ (cov_factor.T @ covs_t_inverse_xbar_t)
    )(means, cov_factors, covs_t_inverse_xbars_t)  # (num_components, data_dim)

    x0_pred = jnp.sum(posterior_probs[:, None] * posterior_means, axis=0)  # (data_dim,)

    return x0_pred
# Derive the eps, score and v vector-field functions (returned in that order)
# from the x0 predictor defined in this module.
(
    low_rank_gmm_eps,
    low_rank_gmm_score,
    low_rank_gmm_v,
) = create_gmm_vector_field_fns(low_rank_gmm_x0)