Coverage for src/diffusionlab/distributions/gmm/iso_gmm.py: 100%

47 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-19 14:17 -0700

1from dataclasses import dataclass 

2from typing import Tuple, cast 

3from jax import numpy as jnp, Array 

4import jax 

5from diffusionlab.distributions.base import Distribution 

6from diffusionlab.distributions.gmm.gmm import GMM 

7from diffusionlab.distributions.gmm.utils import create_gmm_vector_field_fns 

8from diffusionlab.dynamics import DiffusionProcess 

9 

10 

@dataclass(frozen=True)
class IsoGMM(Distribution):
    """
    Isotropic Gaussian Mixture Model (GMM) distribution.

    Every component's covariance is a scalar multiple of the identity, so the
    probability measure is

    ``μ(A) = sum_{i=1}^{num_components} priors[i] * N(A; means[i], variances[i] * I)``

    Besides sampling, the class exposes the vector fields ``VectorFieldType.SCORE``,
    ``VectorFieldType.X0``, ``VectorFieldType.EPS`` and ``VectorFieldType.V`` of the
    distribution perturbed by a given diffusion process. Each method forwards to the
    module-level ``iso_gmm_*`` function of the same name.

    Attributes:
        dist_params (``Dict[str, Array]``): The GMM parameters, stored under the keys

            - ``means`` (``Array[num_components, data_dim]``): component means.
            - ``variances`` (``Array[num_components]``): per-component scalar variances.
            - ``priors`` (``Array[num_components]``): mixture weights.

        dist_hparams (``Dict[str, Any]``): Hyperparameters; always empty for this class.
    """

    def __init__(self, means: Array, variances: Array, priors: Array):
        """
        Validates the isotropic GMM parameters and stores them in ``dist_params``.

        Args:
            means (``Array[num_components, data_dim]``): Mean of each Gaussian component.
            variances (``Array[num_components]``): Scalar variance of each component. Must
                be non-negative, up to the machine epsilon of its dtype.
            priors (``Array[num_components]``): Mixture weights. Must sum to 1, up to the
                machine epsilon of ``variances.dtype``.

        Raises:
            AssertionError: If a shape or value check fails. The checks use ``assert``,
                so they are skipped when Python runs with ``-O``.
        """
        # Tolerance for both value checks, derived from the dtype of ``variances``.
        tol = cast(float, jnp.finfo(variances.dtype).eps)

        assert means.ndim == 2
        n_components = means.shape[0]
        assert variances.shape == (n_components,)
        assert priors.shape == (n_components,)
        assert jnp.isclose(jnp.sum(priors), 1.0, atol=tol)
        assert jnp.all(variances >= -tol)

        params = {"means": means, "variances": variances, "priors": priors}
        super().__init__(dist_params=params, dist_hparams={})

    def _iso_gmm_params(self) -> Tuple[Array, Array, Array]:
        """Returns ``(means, variances, priors)`` read from ``dist_params``."""
        p = self.dist_params
        return p["means"], p["variances"], p["priors"]

    def sample(self, key: Array, num_samples: int) -> Tuple[Array, Array]:
        """
        Draws samples from the isotropic GMM.

        The scalar variances are expanded into full covariance matrices, and sampling is
        delegated to the general ``GMM`` distribution.

        Args:
            key (``Array``): JAX PRNG key used for sampling.
            num_samples (``int``): Number of samples to draw.

        Returns:
            ``Tuple[Array[num_samples, data_dim], Array[num_samples]]``: ``(samples, component_indices)``,
            i.e. the samples and the index of the component each sample came from.
        """
        means, variances, priors = self._iso_gmm_params()
        identity = jnp.eye(means.shape[1])
        # v_i -> v_i * I for every component: (num_components, data_dim, data_dim).
        full_covs = variances[:, None, None] * identity
        return GMM(means, full_covs, priors).sample(key, num_samples)

    def score(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array:
        """
        Score ``∇_x log p_t(x_t)`` of the isotropic GMM perturbed by ``diffusion_process``.

        Args:
            x_t (``Array[data_dim]``): Noisy state at time ``t``.
            t (``Array[]``): Scalar time.
            diffusion_process (``DiffusionProcess``): The diffusion process definition.

        Returns:
            ``Array[data_dim]``: The score evaluated at ``(x_t, t)``.
        """
        return iso_gmm_score(x_t, t, diffusion_process, *self._iso_gmm_params())

    def x0(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array:
        """
        Denoised prediction ``E[x_0 | x_t]`` for the isotropic GMM under ``diffusion_process``.

        Args:
            x_t (``Array[data_dim]``): Noisy state at time ``t``.
            t (``Array[]``): Scalar time.
            diffusion_process (``DiffusionProcess``): The diffusion process definition.

        Returns:
            ``Array[data_dim]``: The predicted clean sample ``x0`` at ``(x_t, t)``.
        """
        return iso_gmm_x0(x_t, t, diffusion_process, *self._iso_gmm_params())

    def eps(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array:
        """
        Noise prediction ``ε`` for the isotropic GMM under ``diffusion_process``.

        This is the noise added to ``x_0`` to produce ``x_t`` at time ``t``.

        Args:
            x_t (``Array[data_dim]``): Noisy state at time ``t``.
            t (``Array[]``): Scalar time.
            diffusion_process (``DiffusionProcess``): The diffusion process definition.

        Returns:
            ``Array[data_dim]``: The predicted noise ``ε`` at ``(x_t, t)``.
        """
        return iso_gmm_eps(x_t, t, diffusion_process, *self._iso_gmm_params())

    def v(self, x_t: Array, t: Array, diffusion_process: DiffusionProcess) -> Array:
        """
        Velocity field ``v`` for the isotropic GMM under ``diffusion_process``.

        This is the conditional velocity ``E[dx_t/dt | x_t]``.

        Args:
            x_t (``Array[data_dim]``): Noisy state at time ``t``.
            t (``Array[]``): Scalar time.
            diffusion_process (``DiffusionProcess``): The diffusion process definition.

        Returns:
            ``Array[data_dim]``: The velocity ``v`` at ``(x_t, t)``.
        """
        return iso_gmm_v(x_t, t, diffusion_process, *self._iso_gmm_params())

170 

171 

def iso_gmm_x0(
    x_t: Array,
    t: Array,
    diffusion_process: DiffusionProcess,
    means: Array,
    variances: Array,
    priors: Array,
) -> Array:
    """
    Computes the denoised prediction ``x0 = E[x_0 | x_t]`` for an isotropic GMM.

    This implements the closed-form solution for the conditional expectation
    ``E[x_0 | x_t]``, where ``x_t ~ N(α_t x_0, σ_t^2 I)`` and ``x_0`` follows the
    isotropic GMM ``sum_i priors[i] * N(means[i], variances[i] * I)``.

    Conditioned on component ``i``:

    - ``x_t ~ N(α_t means[i], s_i I)`` with ``s_i = α_t^2 variances[i] + σ_t^2``;
    - ``E[x_0 | x_t, i] = means[i] + (α_t variances[i] / s_i) (x_t - α_t means[i])``.

    The result averages these per-component means, weighted by the posterior component
    probabilities ``p(i | x_t) ∝ priors[i] * N(x_t; α_t means[i], s_i I)``. Components
    with a zero prior receive zero posterior weight. Every ``s_i`` must be positive;
    a zero ``s_i`` (zero variance and ``σ_t = 0``) makes the division produce inf/nan.

    Args:
        x_t (``Array[data_dim]``): The noisy state tensor at time ``t``.
        t (``Array[]``): The time step (scalar).
        diffusion_process (``DiffusionProcess``): Provides ``α(t)`` and ``σ(t)``.
        means (``Array[num_components, data_dim]``): Means of the GMM components.
        variances (``Array[num_components]``): Scalar variance of each component;
            component ``i`` has covariance ``variances[i] * I``.
        priors (``Array[num_components]``): Mixture weights of the GMM components.

    Returns:
        ``Array[data_dim]``: The denoised prediction ``x0`` evaluated at ``x_t`` and ``t``.
    """
    data_dim = means.shape[1]
    alpha_t = diffusion_process.alpha(t)
    sigma_t = diffusion_process.sigma(t)

    # Marginal variance of x_t given each component: s_i = α_t^2 v_i + σ_t^2.
    variances_t = alpha_t**2 * variances + sigma_t**2  # (num_components,)

    # Residuals x_t - α_t μ_i and their scaling by 1 / s_i.
    xbars_t = x_t - alpha_t * means  # (num_components, data_dim)
    variances_t_inv_xbars_t = xbars_t / variances_t[:, None]  # (num_components, data_dim)

    # Gaussian log-likelihood of x_t under each component, up to a shared constant:
    # -0.5 * (||x_t - α_t μ_i||^2 / s_i + d * log s_i).
    log_likelihoods_unnormalized = -0.5 * (
        jnp.sum(xbars_t * variances_t_inv_xbars_t, axis=-1)
        + data_dim * jnp.log(variances_t)
    )  # (num_components,)
    # log(0) = -inf for zero-prior components; softmax maps those to weight 0.
    posterior_probs = jax.nn.softmax(
        jnp.log(priors) + log_likelihoods_unnormalized, axis=0
    )  # (num_components,), sums to 1

    # E[x_0 | x_t, i] = μ_i + α_t v_i / s_i * (x_t - α_t μ_i).
    posterior_means = (
        means + (alpha_t * variances)[:, None] * variances_t_inv_xbars_t
    )  # (num_components, data_dim)

    return jnp.sum(posterior_probs[:, None] * posterior_means, axis=0)  # (data_dim,)

233 

234 

# Generate eps, score, v functions from iso_gmm_x0.
# All three share iso_gmm_x0's signature (x_t, t, diffusion_process, means, variances, priors)
# and are what the IsoGMM methods forward to.
# NOTE(review): the unpacking order assumes create_gmm_vector_field_fns returns
# (eps_fn, score_fn, v_fn) in that order — confirm against diffusionlab.distributions.gmm.utils.
iso_gmm_eps, iso_gmm_score, iso_gmm_v = create_gmm_vector_field_fns(iso_gmm_x0)