Coverage for src/diffusionlab/distributions/empirical.py: 100%

85 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-19 14:17 -0700

1from dataclasses import dataclass 

2from typing import Iterable, Tuple, cast 

3 

4import jax 

5from jax import Array, numpy as jnp 

6 

7from diffusionlab.dynamics import DiffusionProcess 

8from diffusionlab.distributions.base import Distribution 

9from diffusionlab.vector_fields import VectorFieldType, convert_vector_field_type 

10 

11 

@dataclass(frozen=True)
class EmpiricalDistribution(Distribution):
    """
    An empirical distribution, i.e., the uniform distribution over a dataset.
    The probability measure is defined as:

    ``μ(A) = (1/N) * sum_{i=1}^{N} delta(x_i in A)``

    where ``x_i`` is the ith data point in the dataset, and ``N`` is the number of data points.

    This class provides methods for sampling from the empirical distribution and computing various vector fields (``VectorFieldType.SCORE``, ``VectorFieldType.X0``, ``VectorFieldType.EPS``, ``VectorFieldType.V``) related to the distribution under a given diffusion process.

    Attributes:
        dist_params (``Dict[str, Array]``): Dictionary containing distribution parameters (currently unused).
        dist_hparams (``Dict[str, Any]``): Dictionary for storing hyperparameters. It may contain the following keys:

            - ``labeled_data`` (``Iterable[Tuple[Array, Array]] | Iterable[Tuple[Array, None]]``): An iterable of data whose elements (samples) are tuples of (data batch, label batch). The label batch can be ``None`` if the data is unlabelled. Every call to ``sample`` and to the vector-field methods traverses this iterable from the start, so it must be re-iterable (e.g. a list of batches, not a one-shot generator).
    """

    def __init__(
        self, labeled_data: Iterable[Tuple[Array, Array]] | Iterable[Tuple[Array, None]]
    ):
        super().__init__(
            dist_params={},
            dist_hparams={"labeled_data": labeled_data},
        )

    def sample(
        self, key: Array, num_samples: int
    ) -> Tuple[Array, Array] | Tuple[Array, None]:
        """
        Sample from the empirical distribution using reservoir sampling (Algorithm R).
        Assumes all batches in ``labeled_data`` are consistent: either all have labels (``Array``)
        or none have labels (``None``).

        Args:
            key (``Array``): The JAX PRNG key to use for sampling.
            num_samples (``int``): The number of samples to draw (without replacement).

        Returns:
            ``Tuple[Array[num_samples, *data_dims], Array[num_samples, *label_dims]] | Tuple[Array[num_samples, *data_dims], None]``: A tuple ``(samples, labels)`` containing the samples and corresponding labels (stacked into an ``Array``), or ``(samples, None)`` if the data is unlabelled.

        Raises:
            ValueError: If ``num_samples`` is not positive, if the first labeled batch's
                labels are not an ``Array`` whose leading dimension matches the data batch,
                or if the dataset contains fewer than ``num_samples`` items.
        """
        # Fail fast instead of scanning the whole dataset and then failing inside
        # ``jnp.stack`` on an empty reservoir.
        if num_samples <= 0:
            raise ValueError(
                f"num_samples must be a positive integer, got {num_samples}."
            )

        reservoir_samples = []
        reservoir_labels = []  # Only populated when the data is labeled.
        items_seen = 0
        is_labeled = None  # Decided by the first batch encountered.

        for X_batch, y_batch in self.dist_hparams["labeled_data"]:
            if is_labeled is None:
                is_labeled = y_batch is not None
                # Basic validation, performed on the first labeled batch only.
                if is_labeled and (
                    not isinstance(y_batch, jnp.ndarray)
                    or y_batch.shape[0] != X_batch.shape[0]
                ):
                    raise ValueError(
                        f"First labeled batch has inconsistent shape. X shape: {X_batch.shape}, Y shape: {getattr(y_batch, 'shape', 'N/A')}"
                    )

            for i in range(X_batch.shape[0]):
                x = X_batch[i]
                y = y_batch[i] if is_labeled else None

                if items_seen < num_samples:
                    # Fill the reservoir with the first ``num_samples`` items.
                    reservoir_samples.append(x)
                    if is_labeled:
                        reservoir_labels.append(y)
                else:
                    # Keep item ``items_seen`` with probability num_samples / (items_seen + 1)
                    # by drawing j uniformly from {0, ..., items_seen}.
                    key, subkey = jax.random.split(key)
                    j = int(
                        jax.random.randint(
                            subkey, shape=(), minval=0, maxval=items_seen + 1
                        )
                    )
                    if j < num_samples:
                        reservoir_samples[j] = x
                        if is_labeled:
                            reservoir_labels[j] = y

                items_seen += 1

        if items_seen < num_samples:
            raise ValueError(
                f"Requested {num_samples} samples, but only {items_seen} items are available in the dataset."
            )

        stacked_samples = jnp.stack(reservoir_samples)
        stacked_labels = jnp.stack(reservoir_labels) if is_labeled else None
        return stacked_samples, stacked_labels

    def _x0_to_vector_field(
        self,
        x_t: Array,
        t: Array,
        diffusion_process: DiffusionProcess,
        target_type: VectorFieldType,
    ) -> Array:
        """
        Computes the denoiser ``x0(x_t, t)`` and converts it into the vector field ``target_type``.

        Shared implementation behind ``score``, ``eps`` and ``v``.

        Args:
            x_t (``Array[*data_dims]``): The noisy state tensor at time ``t``.
            t (``Array[]``): The time tensor.
            diffusion_process (``DiffusionProcess``): The diffusion process.
            target_type (``VectorFieldType``): The vector field type to convert into.

        Returns:
            ``Array[*data_dims]``: The requested vector field evaluated at ``(x_t, t)``.
        """
        return convert_vector_field_type(
            x_t,
            self.x0(x_t, t, diffusion_process),
            diffusion_process.alpha(t),
            diffusion_process.sigma(t),
            diffusion_process.alpha_prime(t),
            diffusion_process.sigma_prime(t),
            VectorFieldType.X0,
            target_type,
        )

    def score(
        self,
        x_t: Array,
        t: Array,
        diffusion_process: DiffusionProcess,
    ) -> Array:
        """
        Computes the score function (``∇_x log p_t(x)``) of the empirical distribution at time ``t``,
        given the noisy state ``x_t`` and the diffusion process.

        Args:
            x_t (``Array[*data_dims]``): The noisy state tensor at time ``t``.
            t (``Array[]``): The time tensor.
            diffusion_process (``DiffusionProcess``): The diffusion process.

        Returns:
            ``Array[*data_dims]``: The score of the empirical distribution at ``(x_t, t)``.
        """
        return self._x0_to_vector_field(
            x_t, t, diffusion_process, VectorFieldType.SCORE
        )

    def x0(
        self,
        x_t: Array,
        t: Array,
        diffusion_process: DiffusionProcess,
    ) -> Array:
        """
        Computes the denoiser ``E[x_0 | x_t]`` for an empirical distribution w.r.t. a given diffusion process.

        The denoiser is a weighted average of the dataset samples ``x_i``, with
        ``softmax`` weights over the logits ``-||x_t - alpha_t * x_i||^2 / (2 * sigma_t^2)``.
        The softmax is accumulated batch by batch relative to a running maximum
        logit (streaming log-sum-exp). This avoids the case where every
        unnormalized weight ``exp(logit)`` underflows when ``sigma_t`` is small,
        which would otherwise collapse the prediction towards zero.

        Args:
            x_t (``Array[*data_dims]``): The input tensor.
            t (``Array[]``): The time tensor.
            diffusion_process (``DiffusionProcess``): The diffusion process.

        Returns:
            ``Array[*data_dims]``: The prediction of ``x_0``. If the dataset is empty the
            result is all zeros.
        """
        alpha_t = diffusion_process.alpha(t)
        sigma_t = diffusion_process.sigma(t)

        # Loop-invariant per-sample operations, built once.
        batch_squared_dists = jax.vmap(lambda x: jnp.sum((x_t - alpha_t * x) ** 2))
        batch_weighted = jax.vmap(lambda xi, wi: xi * wi)

        # Online-softmax state: ``softmax_denom`` and ``x0_hat`` are both stored
        # scaled by exp(-running_max); the common scale cancels in the final ratio.
        running_max = jnp.full_like(t, -jnp.inf)
        softmax_denom = jnp.zeros_like(t)
        x0_hat = jnp.zeros_like(x_t)
        for X_batch, _ in self.dist_hparams["labeled_data"]:
            if X_batch.shape[0] == 0:
                # jnp.max is undefined on an empty batch; it contributes nothing anyway.
                continue
            logits = -batch_squared_dists(X_batch) / (2 * sigma_t**2)
            new_max = jnp.maximum(running_max, jnp.max(logits))
            # While no finite logit has been seen (new_max == -inf), shift by 0 to
            # avoid evaluating (-inf) - (-inf) = nan.
            shift = jnp.where(jnp.isfinite(new_max), new_max, 0.0)
            rescale = jnp.exp(running_max - shift)
            weights = jnp.exp(logits - shift)
            softmax_denom = softmax_denom * rescale + jnp.sum(weights)
            x0_hat = x0_hat * rescale + jnp.sum(
                batch_weighted(X_batch, weights), axis=0
            )
            running_max = new_max

        # With at least one finite logit the denominator is >= 1; the clamp only
        # guards the degenerate case where every weight is zero.
        eps = cast(float, jnp.finfo(softmax_denom.dtype).eps)
        return x0_hat / jnp.maximum(softmax_denom, eps)

    def eps(
        self,
        x_t: Array,
        t: Array,
        diffusion_process: DiffusionProcess,
    ) -> Array:
        """
        Computes the noise field ``eps(x_t, t)`` for an empirical distribution w.r.t. a given diffusion process.

        Args:
            x_t (``Array[*data_dims]``): The input tensor.
            t (``Array[]``): The time tensor.
            diffusion_process (``DiffusionProcess``): The diffusion process.

        Returns:
            ``Array[*data_dims]``: The noise field at ``(x_t, t)``.
        """
        return self._x0_to_vector_field(
            x_t, t, diffusion_process, VectorFieldType.EPS
        )

    def v(
        self,
        x_t: Array,
        t: Array,
        diffusion_process: DiffusionProcess,
    ) -> Array:
        """
        Computes the velocity field ``v(x_t, t)`` for an empirical distribution w.r.t. a given diffusion process.

        Args:
            x_t (``Array[*data_dims]``): The input tensor.
            t (``Array[]``): The time tensor.
            diffusion_process (``DiffusionProcess``): The diffusion process.

        Returns:
            ``Array[*data_dims]``: The velocity field at ``(x_t, t)``.
        """
        return self._x0_to_vector_field(
            x_t, t, diffusion_process, VectorFieldType.V
        )