Coverage for src/diffusionlab/distributions/empirical.py: 100%
85 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-19 14:17 -0700
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-19 14:17 -0700
1from dataclasses import dataclass
2from typing import Iterable, Tuple, cast
4import jax
5from jax import Array, numpy as jnp
7from diffusionlab.dynamics import DiffusionProcess
8from diffusionlab.distributions.base import Distribution
9from diffusionlab.vector_fields import VectorFieldType, convert_vector_field_type
@dataclass(frozen=True)
class EmpiricalDistribution(Distribution):
    """
    An empirical distribution, i.e., the uniform distribution over a dataset.
    The probability measure is defined as:

    ``μ(A) = (1/N) * sum_{i=1}^{N} 1[x_i ∈ A]``

    where ``x_i`` is the ith data point in the dataset, ``N`` is the number of data points,
    and ``1[·]`` is the indicator function.

    This class provides methods for sampling from the empirical distribution and computing various vector fields (``VectorFieldType.SCORE``, ``VectorFieldType.X0``, ``VectorFieldType.EPS``, ``VectorFieldType.V``) related to the distribution under a given diffusion process.

    Attributes:
        dist_params (``Dict[str, Array]``): Dictionary containing distribution parameters (currently unused).
        dist_hparams (``Dict[str, Any]``): Dictionary for storing hyperparameters. It may contain the following keys:

            - ``labeled_data`` (``Iterable[Tuple[Array, Array]] | Iterable[Tuple[Array, None]]``): An iterable of data whose elements (samples) are tuples of (data batch, label batch). The label batch can be ``None`` if the data is unlabelled.

    Note:
        ``labeled_data`` is iterated once per call to ``sample`` and ``x0``, so it must
        be re-iterable (e.g. a list of batches), not a one-shot iterator.
    """

    def __init__(
        self, labeled_data: Iterable[Tuple[Array, Array]] | Iterable[Tuple[Array, None]]
    ):
        super().__init__(
            dist_params={},
            dist_hparams={"labeled_data": labeled_data},
        )

    def sample(
        self, key: Array, num_samples: int
    ) -> Tuple[Array, Array] | Tuple[Array, None]:
        """
        Sample from the empirical distribution using reservoir sampling.
        Assumes all batches in ``labeled_data`` are consistent: either all have labels (``Array``)
        or none have labels (``None``).

        Args:
            key (``Array``): The JAX PRNG key to use for sampling.
            num_samples (``int``): The number of samples to draw (must be at least 1).

        Returns:
            ``Tuple[Array[num_samples, *data_dims], Array[num_samples, *label_dims]] | Tuple[Array[num_samples, *data_dims], None]``: A tuple ``(samples, labels)`` containing the samples and corresponding labels (stacked into an ``Array``), or ``(samples, None)`` if the data is unlabelled.

        Raises:
            ValueError: If ``num_samples < 1``; if the dataset holds fewer than
                ``num_samples`` items; or if the first labeled batch's labels are not a
                ``jnp.ndarray`` or their leading dimension differs from the data batch's.
        """
        if num_samples < 1:
            # Fail fast instead of scanning the whole dataset and then failing in
            # jnp.stack on an empty reservoir.
            raise ValueError(
                f"num_samples must be a positive integer, got {num_samples}."
            )

        data_iterator = iter(self.dist_hparams["labeled_data"])

        # Initialize reservoir
        reservoir_samples = []
        reservoir_labels = []  # Only filled when the data is labeled
        items_seen = 0
        is_labeled = None  # Decided from the first batch encountered

        for X_batch, y_batch in data_iterator:
            # Determine if data is labeled based on the first batch encountered
            if is_labeled is None:
                is_labeled = y_batch is not None
                if is_labeled:
                    # Basic validation for the first labeled batch
                    if (
                        not isinstance(y_batch, jnp.ndarray)
                        or y_batch.shape[0] != X_batch.shape[0]
                    ):
                        raise ValueError(
                            f"First labeled batch has inconsistent shape. X shape: {X_batch.shape}, Y shape: {getattr(y_batch, 'shape', 'N/A')}"
                        )

            current_batch_size = X_batch.shape[0]

            # Reservoir sampling (Algorithm R): fill the reservoir with the first
            # num_samples items, then replace a uniformly chosen slot with
            # probability num_samples / (items_seen + 1).
            for i in range(current_batch_size):
                x = X_batch[i]
                y = y_batch[i] if is_labeled else None

                if items_seen < num_samples:
                    reservoir_samples.append(x)
                    if is_labeled:
                        reservoir_labels.append(y)
                else:
                    key, subkey = jax.random.split(key)
                    j = int(
                        jax.random.randint(
                            subkey, shape=(), minval=0, maxval=items_seen + 1
                        )
                    )
                    if j < num_samples:
                        reservoir_samples[j] = x
                        if is_labeled:
                            reservoir_labels[j] = y

                items_seen += 1

        # Final checks and return
        if items_seen < num_samples:
            raise ValueError(
                f"Requested {num_samples} samples, but only {items_seen} items are available in the dataset."
            )

        stacked_samples = jnp.stack(reservoir_samples)
        if is_labeled:
            return stacked_samples, jnp.stack(reservoir_labels)
        return stacked_samples, None

    def score(
        self,
        x_t: Array,
        t: Array,
        diffusion_process: DiffusionProcess,
    ) -> Array:
        """
        Computes the score function (``∇_x log p_t(x)``) of the empirical distribution at time ``t``,
        given the noisy state ``x_t`` and the diffusion process.

        The score is obtained by converting the denoiser ``self.x0`` output with
        ``convert_vector_field_type``.

        Args:
            x_t (``Array[*data_dims]``): The noisy state tensor at time ``t``.
            t (``Array[]``): The time tensor.
            diffusion_process (``DiffusionProcess``): The diffusion process.

        Returns:
            ``Array[*data_dims]``: The score of the empirical distribution at ``(x_t, t)``.
        """
        return self._convert_x0_prediction(
            x_t, t, diffusion_process, VectorFieldType.SCORE
        )

    def x0(
        self,
        x_t: Array,
        t: Array,
        diffusion_process: DiffusionProcess,
    ) -> Array:
        """
        Computes the denoiser ``E[x_0 | x_t]`` for an empirical distribution w.r.t. a given diffusion process.

        This method computes the denoiser by performing a weighted average of the
        dataset samples, where the weight of sample ``x_i`` is proportional to the
        Gaussian likelihood ``exp(-||x_t - alpha_t * x_i||^2 / (2 * sigma_t^2))``.

        The weights are normalized with a streaming (online) log-sum-exp across
        batches: the largest logit seen so far is subtracted before exponentiating,
        so the nearest data points keep weights of order 1. This avoids the
        underflow that otherwise drives the result toward zero when squared
        distances are large relative to ``sigma_t**2``.

        Arguments:
            x_t (``Array[*data_dims]``): The input tensor.
            t (``Array[]``): The time tensor.
            diffusion_process (``DiffusionProcess``): The diffusion process.

        Returns:
            ``Array[*data_dims]``: The prediction of ``x_0``. If no data point has a
            finite logit (e.g. an empty dataset), this is all zeros.
        """
        data = self.dist_hparams["labeled_data"]

        alpha_t = diffusion_process.alpha(t)
        sigma_t = diffusion_process.sigma(t)

        # Online softmax state: running maximum logit, and the weight sum and
        # weighted data sum, both measured relative to exp(running_max).
        # Python -inf is weakly typed, so it does not change the result dtype.
        running_max = -jnp.inf
        softmax_denom = jnp.zeros_like(t)
        x0_hat = jnp.zeros_like(x_t)
        for X_batch, _ in data:
            if X_batch.shape[0] == 0:
                # jnp.max would fail on an empty batch; it contributes nothing anyway.
                continue
            squared_dists = jax.vmap(lambda x: jnp.sum((x_t - alpha_t * x) ** 2))(
                X_batch
            )
            logits = -squared_dists / (2 * sigma_t**2)
            new_max = jnp.maximum(running_max, jnp.max(logits))
            # If every logit so far is -inf (e.g. sigma_t == 0 with no exact match),
            # shift by 0 instead to avoid (-inf) - (-inf) = nan; all weights are 0.
            shift = jnp.where(jnp.isfinite(new_max), new_max, 0.0)
            # Re-express the previous accumulators relative to the new reference.
            rescale = jnp.exp(running_max - shift)
            weights = jnp.exp(logits - shift)
            softmax_denom = softmax_denom * rescale + jnp.sum(weights)
            # Contract the batch axis: sum_i weights[i] * X_batch[i].
            x0_hat = x0_hat * rescale + jnp.tensordot(weights, X_batch, axes=1)
            running_max = new_max

        # With the max subtracted the denominator is >= 1 whenever any logit is
        # finite; the clamp only protects the degenerate all-zero-weights case.
        eps = cast(float, jnp.finfo(softmax_denom.dtype).eps)
        softmax_denom = jnp.maximum(softmax_denom, eps)
        return x0_hat / softmax_denom

    def eps(
        self,
        x_t: Array,
        t: Array,
        diffusion_process: DiffusionProcess,
    ) -> Array:
        """
        Computes the noise field ``eps(x_t, t)`` for an empirical distribution w.r.t. a given diffusion process.

        The noise field is obtained by converting the denoiser ``self.x0`` output with
        ``convert_vector_field_type``.

        Args:
            x_t (``Array[*data_dims]``): The input tensor.
            t (``Array[]``): The time tensor.
            diffusion_process (``DiffusionProcess``): The diffusion process.

        Returns:
            ``Array[*data_dims]``: The noise field at ``(x_t, t)``.
        """
        return self._convert_x0_prediction(
            x_t, t, diffusion_process, VectorFieldType.EPS
        )

    def v(
        self,
        x_t: Array,
        t: Array,
        diffusion_process: DiffusionProcess,
    ) -> Array:
        """
        Computes the velocity field ``v(x_t, t)`` for an empirical distribution w.r.t. a given diffusion process.

        The velocity field is obtained by converting the denoiser ``self.x0`` output
        with ``convert_vector_field_type``.

        Args:
            x_t (``Array[*data_dims]``): The input tensor.
            t (``Array[]``): The time tensor.
            diffusion_process (``DiffusionProcess``): The diffusion process.

        Returns:
            ``Array[*data_dims]``: The velocity field at ``(x_t, t)``.
        """
        return self._convert_x0_prediction(
            x_t, t, diffusion_process, VectorFieldType.V
        )

    def _convert_x0_prediction(
        self,
        x_t: Array,
        t: Array,
        diffusion_process: DiffusionProcess,
        target_type: VectorFieldType,
    ) -> Array:
        """
        Evaluates ``self.x0`` at ``(x_t, t)`` and converts it to ``target_type``.

        Shared by ``score``, ``eps`` and ``v``. The schedule quantities
        ``alpha``, ``sigma``, ``alpha_prime`` and ``sigma_prime`` at ``t`` are read
        from ``diffusion_process`` and passed to ``convert_vector_field_type``.
        """
        x0_x_t = self.x0(x_t, t, diffusion_process)
        alpha_t = diffusion_process.alpha(t)
        sigma_t = diffusion_process.sigma(t)
        alpha_prime_t = diffusion_process.alpha_prime(t)
        sigma_prime_t = diffusion_process.sigma_prime(t)
        return convert_vector_field_type(
            x_t,
            x0_x_t,
            alpha_t,
            sigma_t,
            alpha_prime_t,
            sigma_prime_t,
            VectorFieldType.X0,
            target_type,
        )