Coverage for src/diffusionlab/distributions/gmm/low_rank_gmm.py: 100%

45 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-19 14:17 -0700

1from typing import Tuple, cast 

2from dataclasses import dataclass 

3from jax import numpy as jnp, Array 

4import jax 

5from diffusionlab.distributions.base import Distribution 

6from diffusionlab.distributions.gmm.gmm import GMM 

7from diffusionlab.distributions.gmm.utils import ( 

8 _logdeth, 

9 _lstsq, 

10 create_gmm_vector_field_fns, 

11) 

12from diffusionlab.dynamics import DiffusionProcess 

13 

14 

@dataclass(frozen=True)
class LowRankGMM(Distribution):
    """
    Gaussian Mixture Model (GMM) whose component covariances are stored as low-rank factors.

    The probability measure is:

    ``μ(A) = sum_{i=1}^{num_components} priors[i] * N(A; means[i], cov_factors[i] @ cov_factors[i].T)``

    The class supports sampling and evaluates the vector fields (``VectorFieldType.SCORE``, ``VectorFieldType.X0``, ``VectorFieldType.EPS``, ``VectorFieldType.V``) of the distribution under a given diffusion process.

    Attributes:
        dist_params (``Dict[str, Array]``): The low-rank GMM parameters.

            - ``means`` (``Array[num_components, data_dim]``): Component means.
            - ``cov_factors`` (``Array[num_components, data_dim, rank]``): Component covariance factors; component ``i`` has covariance ``cov_factors[i] @ cov_factors[i].T``.
            - ``priors`` (``Array[num_components]``): Mixture weights of the components.

        dist_hparams (``Dict[str, Any]``): Hyperparameters; this class always passes an empty dict.
    """

    def __init__(self, means: Array, cov_factors: Array, priors: Array):
        """
        Validates the parameter shapes and stores them on the base ``Distribution``.

        Args:
            means (``Array[num_components, data_dim]``): Mean of each Gaussian component.
            cov_factors (``Array[num_components, data_dim, rank]``): Low-rank covariance factor of each Gaussian component.
            priors (``Array[num_components]``): Mixture weight of each component. Must sum to 1.
        """
        # Absolute tolerance for the "priors sum to 1" check: machine epsilon
        # of the covariance-factor dtype.
        tol = cast(float, jnp.finfo(cov_factors.dtype).eps)
        assert means.ndim == 2
        n_components, dim, _ = cov_factors.shape
        assert means.shape == (n_components, dim)
        assert priors.shape == (n_components,)
        assert jnp.isclose(jnp.sum(priors), 1.0, atol=tol)

        params = {
            "means": means,
            "cov_factors": cov_factors,
            "priors": priors,
        }
        super().__init__(dist_params=params, dist_hparams={})

    def sample(self, key: Array, num_samples: int) -> Tuple[Array, Array]:
        """
        Draws samples from the low-rank GMM distribution.

        The dense covariance ``cov_factor @ cov_factor.T`` of every component is
        materialised and sampling is delegated to a ``GMM`` built from it.

        Args:
            key (``Array``): JAX PRNG key for random sampling.
            num_samples (``int``): The total number of samples to generate.

        Returns:
            ``Tuple[Array[num_samples, data_dim], Array[num_samples]]``: A tuple ``(samples, component_indices)`` with the drawn samples and, for each sample, the index of the component it came from.
        """
        params = self.dist_params

        def dense_cov(factor: Array) -> Array:
            return factor @ factor.T

        full_covs = jax.vmap(dense_cov)(params["cov_factors"])
        dense_gmm = GMM(params["means"], full_covs, params["priors"])
        return dense_gmm.sample(key, num_samples)

    def score(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array:
        """
        Computes the score vector field ``∇_x log p_t(x_t)`` for the low-rank GMM distribution.

        The score is taken with respect to the perturbed distribution ``p_t`` induced by
        the ``diffusion_process`` at time ``t``.

        Args:
            x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
            t (``Array[]``): The time step (scalar).
            diffusion_process (``DiffusionProcess``): The diffusion process definition.

        Returns:
            ``Array[data_dim]``: The score vector field evaluated at ``x_t`` and ``t``.
        """
        params = self.dist_params
        return low_rank_gmm_score(
            x_t,
            t,
            diffusion_process,
            params["means"],
            params["cov_factors"],
            params["priors"],
        )

    def x0(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array:
        """
        Computes the denoised prediction ``x0 = E[x_0 | x_t]`` for the low-rank GMM distribution.

        This is the expected clean sample ``x_0`` given the noisy observation ``x_t``
        at time ``t`` under the ``diffusion_process``.

        Args:
            x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
            t (``Array[]``): The time step (scalar).
            diffusion_process (``DiffusionProcess``): The diffusion process definition.

        Returns:
            ``Array[data_dim]``: The denoised prediction ``x0`` evaluated at ``x_t`` and ``t``.
        """
        params = self.dist_params
        return low_rank_gmm_x0(
            x_t,
            t,
            diffusion_process,
            params["means"],
            params["cov_factors"],
            params["priors"],
        )

    def eps(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array:
        """
        Computes the noise prediction ``ε`` for the low-rank GMM distribution.

        This predicts the noise that was added to the original sample ``x_0`` to obtain
        ``x_t`` at time ``t`` under the ``diffusion_process``.

        Args:
            x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
            t (``Array[]``): The time step (scalar).
            diffusion_process (``DiffusionProcess``): The diffusion process definition.

        Returns:
            ``Array[data_dim]``: The noise prediction ``ε`` evaluated at ``x_t`` and ``t``.
        """
        params = self.dist_params
        return low_rank_gmm_eps(
            x_t,
            t,
            diffusion_process,
            params["means"],
            params["cov_factors"],
            params["priors"],
        )

    def v(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array:
        """
        Computes the velocity vector field ``v`` for the low-rank GMM distribution.

        This is the conditional velocity ``E[dx_t/dt | x_t]`` under the ``diffusion_process``.

        Args:
            x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
            t (``Array[]``): The time step (scalar).
            diffusion_process (``DiffusionProcess``): The diffusion process definition.

        Returns:
            ``Array[data_dim]``: The velocity vector field ``v`` evaluated at ``x_t`` and ``t``.
        """
        params = self.dist_params
        return low_rank_gmm_v(
            x_t,
            t,
            diffusion_process,
            params["means"],
            params["cov_factors"],
            params["priors"],
        )

173 

174 

def low_rank_gmm_x0(
    x_t: Array,
    t: Array,
    diffusion_process: DiffusionProcess,
    means: Array,
    cov_factors: Array,
    priors: Array,
) -> Array:
    """
    Computes the denoised prediction ``x0 = E[x_0 | x_t]`` for a low-rank GMM.

    Closed-form conditional expectation ``E[x_0 | x_t]`` where ``x_t ~ N(α_t x_0, σ_t^2 I)``
    and ``x_0`` follows the low-rank GMM defined by ``means``, ``cov_factors`` and ``priors``.
    Only ``rank x rank`` systems are solved: the noisy covariance
    ``α_t^2 F F^T + σ_t^2 I`` of each component is inverted through the Woodbury
    identity and its log-determinant obtained through the matrix determinant lemma.

    NOTE(review): the expressions divide by both ``α_t`` and ``σ_t`` and take
    ``log σ_t``; presumably callers only evaluate at times where both are
    nonzero — confirm.

    Args:
        x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
        t (``Array[]``): The time step (scalar).
        diffusion_process (``DiffusionProcess``): Provides ``α(t)`` and ``σ(t)``.
        means (``Array[num_components, data_dim]``): Means of the GMM components.
        cov_factors (``Array[num_components, data_dim, rank]``): Low-rank covariance factors of the GMM components.
        priors (``Array[num_components]``): Mixture weights of the GMM components.

    Returns:
        ``Array[data_dim]``: The denoised prediction ``x0`` evaluated at ``x_t`` and ``t``.
    """
    _, data_dim, rank = cov_factors.shape
    alpha_t = diffusion_process.alpha(t)
    sigma_t = diffusion_process.sigma(t)
    identity = jnp.eye(rank)

    def scale_mean(mean: Array) -> Array:
        return alpha_t * mean

    def gram(factor: Array) -> Array:
        return factor.T @ factor

    def residual(scaled_mean: Array) -> Array:
        return x_t - scaled_mean

    scaled_means = jax.vmap(scale_mean)(means)  # (num_components, data_dim)
    grams = jax.vmap(gram)(cov_factors)  # (num_components, rank, rank)
    residuals = jax.vmap(residual)(scaled_means)  # (num_components, data_dim)

    # Woodbury: (α² F Fᵀ + σ² I)⁻¹ r = (1/σ²) (r - F (FᵀF + (σ/α)² I)⁻¹ Fᵀ r)
    inv_noise_var = 1 / sigma_t**2
    ridge = (sigma_t / alpha_t) ** 2 * identity

    def apply_precision(factor: Array, gram_i: Array, resid: Array) -> Array:
        coeffs = _lstsq(gram_i + ridge, factor.T @ resid)
        return inv_noise_var * (resid - factor @ coeffs)

    precision_residuals = jax.vmap(apply_precision)(
        cov_factors, grams, residuals
    )  # (num_components, data_dim)

    # Determinant lemma: log det(α² F Fᵀ + σ² I) = log det((α/σ)² FᵀF + I) + 2 d log σ
    snr = (alpha_t / sigma_t) ** 2

    def low_rank_logdet(gram_i: Array) -> Array:
        return _logdeth(snr * gram_i + identity)

    logdets = jax.vmap(low_rank_logdet)(grams) + 2 * data_dim * jnp.log(
        sigma_t
    )  # (num_components,)

    # Gaussian log-density per component; no 2π normalisation term is added
    # (it would be identical for every component and cancel in the softmax).
    def log_density(resid: Array, precision_resid: Array, logdet: Array) -> Array:
        mahalanobis = jnp.sum(resid * precision_resid)
        return -(1 / 2) * (mahalanobis + logdet)

    log_likelihoods = jax.vmap(log_density)(
        residuals, precision_residuals, logdets
    )  # (num_components,)

    log_weights = jnp.log(priors) + log_likelihoods  # (num_components,)
    responsibilities = jax.nn.softmax(log_weights, axis=0)  # (num_components,)

    # Per-component posterior mean: μ + α F Fᵀ Σ_t⁻¹ (x_t - α μ)
    def component_posterior_mean(
        mean: Array, factor: Array, precision_resid: Array
    ) -> Array:
        return mean + (alpha_t * factor) @ (factor.T @ precision_resid)

    component_means = jax.vmap(component_posterior_mean)(
        means, cov_factors, precision_residuals
    )  # (num_components, data_dim)

    # Mix the component posterior means by their responsibilities.
    return jnp.sum(responsibilities[:, None] * component_means, axis=0)  # (data_dim,)

250 

251 

# Build the eps, score and v vector-field functions from the x0 predictor;
# the helper's three results are bound to the module-level names below.
(
    low_rank_gmm_eps,
    low_rank_gmm_score,
    low_rank_gmm_v,
) = create_gmm_vector_field_fns(low_rank_gmm_x0)

255)