Coverage for src/diffusionlab/samplers.py: 100%
160 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-19 14:17 -0700
« prev ^ index » next coverage.py v7.6.12, created at 2025-04-19 14:17 -0700
1from typing import Callable, Tuple
3from dataclasses import dataclass, field
4import jax
5from jax import Array, numpy as jnp
7from diffusionlab.dynamics import DiffusionProcess
8from diffusionlab.vector_fields import (
9 VectorFieldType,
10 convert_vector_field_type,
11)
@dataclass
class Sampler:
    """
    Base class for sampling from diffusion models using various vector field types.

    A Sampler combines a diffusion process, a vector field prediction function, and a scheduler
    to generate samples from a trained diffusion model using the reverse process (denoising/sampling).

    The sampler supports different vector field types (``VectorFieldType.SCORE``, ``VectorFieldType.X0``,
    ``VectorFieldType.EPS``, ``VectorFieldType.V``) and can perform both stochastic and deterministic
    sampling based on the subclass implementation and the ``use_stochastic_sampler`` flag.

    Attributes:
        diffusion_process (``DiffusionProcess``): The diffusion process defining the forward dynamics.
        vector_field (``Callable[[Array[*data_dims], Array[]], Array[*data_dims]]``): The function predicting the vector field.
            Takes the current state ``x_t`` and time ``t`` as input.
        vector_field_type (``VectorFieldType``): The type of the vector field predicted by ``vector_field``.
        use_stochastic_sampler (``bool``): Whether to use a stochastic or deterministic reverse process.
        sample_step (``Callable[[int, Array, Array, Array], Array]``): The specific function used to perform one sampling step.
            Takes step index ``idx``, current state ``x_t``, noise array ``zs``, and time schedule ``ts`` as input.
            Set during initialization based on the sampler type and ``use_stochastic_sampler``.
    """

    diffusion_process: DiffusionProcess
    vector_field: Callable[[Array, Array], Array]
    vector_field_type: VectorFieldType
    use_stochastic_sampler: bool
    sample_step: Callable[[int, Array, Array, Array], Array] = field(init=False)

    def __post_init__(self):
        # Resolve the step function once, at construction time, so that ``sample`` and
        # ``sample_trajectory`` can call it directly inside ``jax.lax.scan``.
        self.sample_step = self.get_sample_step_function()

    @staticmethod
    def _num_steps(zs: Array, ts: Array) -> int:
        """
        Return the number of reverse steps and check that ``zs`` and ``ts`` agree in length.

        JAX clamps out-of-bounds indices instead of raising, so if ``ts`` did not have exactly
        ``zs.shape[0] + 1`` entries, the lookups ``ts[idx]`` / ``ts[idx + 1]`` inside the step
        functions would silently read the wrong times (truncating the trajectory or repeating
        ``ts[-1]``) rather than failing.

        Args:
            zs (``Array[num_steps, *data_dims]``): Noise tensors; the leading axis sets the step count.
            ts (``Array[num_steps+1]``): Time schedule.

        Returns:
            ``int``: ``num_steps`` (equal to ``zs.shape[0]``).

        Raises:
            ValueError: If ``ts.shape[0] != zs.shape[0] + 1``.
        """
        num_steps = zs.shape[0]
        if ts.shape[0] != num_steps + 1:
            raise ValueError(
                f"Expected ts to have zs.shape[0] + 1 = {num_steps + 1} entries, "
                f"got ts.shape[0] = {ts.shape[0]}."
            )
        return num_steps

    def sample(self, x_init: Array, zs: Array, ts: Array) -> Array:
        """
        Sample from the model using the reverse diffusion process.

        This method generates a final sample by iteratively applying the ``sample_step`` function,
        starting from an initial state ``x_init`` and using the provided noise ``zs`` and time schedule ``ts``.

        Args:
            x_init (``Array[*data_dims]``): The initial noisy tensor from which to initialize sampling (typically sampled from the prior distribution at ``ts[0]``).
            zs (``Array[num_steps, *data_dims]``): The noise tensors used at each step for stochastic sampling.
                Their values are unused by deterministic samplers, but ``zs.shape[0]`` always sets the number of steps.
            ts (``Array[num_steps+1]``): The time schedule for sampling. A sorted decreasing array of times from ``t_max`` to ``t_min``.

        Returns:
            ``Array[*data_dims]``: The generated sample at the final time ``ts[-1]``.

        Raises:
            ValueError: If ``ts.shape[0] != zs.shape[0] + 1``.
        """
        num_steps = self._num_steps(zs, ts)

        def scan_fn(x, idx):
            next_x = self.sample_step(idx, x, zs, ts)
            return next_x, None

        final_x, _ = jax.lax.scan(scan_fn, x_init, jnp.arange(num_steps))

        return final_x

    def sample_trajectory(self, x_init: Array, zs: Array, ts: Array) -> Array:
        """
        Sample a trajectory from the model using the reverse diffusion process.

        This method generates the entire trajectory of intermediate samples by iteratively
        applying the ``sample_step`` function.

        Args:
            x_init (``Array[*data_dims]``): The initial noisy tensor from which to start sampling (at time ``ts[0]``).
            zs (``Array[num_steps, *data_dims]``): The noise tensors used at each step for stochastic sampling.
                Their values are unused by deterministic samplers, but ``zs.shape[0]`` always sets the number of steps.
            ts (``Array[num_steps+1]``): The time schedule for sampling. A sorted decreasing array of times from ``t_max`` to ``t_min``.

        Returns:
            ``Array[num_steps+1, *data_dims]``: The complete generated trajectory including the initial state ``x_init``.

        Raises:
            ValueError: If ``ts.shape[0] != zs.shape[0] + 1``.
        """
        num_steps = self._num_steps(zs, ts)

        def scan_fn(x, idx):
            next_x = self.sample_step(idx, x, zs, ts)
            return next_x, next_x

        _, xs = jax.lax.scan(scan_fn, x_init, jnp.arange(num_steps))

        # Prepend the starting state so entry k of the result is the state at time ts[k].
        xs = jnp.concatenate([x_init[None, ...], xs], axis=0)
        return xs

    def get_sample_step_function(self) -> Callable[[int, Array, Array, Array], Array]:
        """
        Abstract method to get the appropriate sampling step function.

        Subclasses must implement this method to return the specific function used
        for performing one step of the reverse process, based on the sampler's
        implementation details (e.g., integrator type) and the ``use_stochastic_sampler`` flag.

        Returns:
            ``Callable[[int, Array, Array, Array], Array]``: The sampling step function, which has signature:

            ``(idx: int, x_t: Array[*data_dims], zs: Array[num_steps, *data_dims], ts: Array[num_steps+1]) -> Array[*data_dims]``

        Raises:
            NotImplementedError: Always, in the base class.
        """
        raise NotImplementedError
109@dataclass
110class EulerMaruyamaSampler(Sampler):
111 """
112 Class for sampling from diffusion models using the first-order Euler-Maruyama integrator
113 for the reverse process SDE/ODE.
115 This sampler implements the step function based on the Euler-Maruyama discretization
116 of the reverse SDE (if ``use_stochastic_sampler`` is True) or the corresponding
117 probability flow ODE (if ``use_stochastic_sampler`` is False). It supports all
118 vector field types (``VectorFieldType.SCORE``, ``VectorFieldType.X0``, ``VectorFieldType.EPS``, ``VectorFieldType.V``).
121 Attributes:
122 diffusion_process (``DiffusionProcess``): The diffusion process defining the forward dynamics.
123 vector_field (``Callable[[Array[*data_dims], Array[]], Array[*data_dims]]``): The function predicting the vector field.
124 Takes the current state ``x_t`` and time ``t`` as input.
125 vector_field_type (``VectorFieldType``): The type of the vector field predicted by ``vector_field``.
126 use_stochastic_sampler (``bool``): Whether to use a stochastic or deterministic reverse process.
127 sample_step (``Callable[[int, Array, Array, Array], Array]``): The specific function used to perform one sampling step.
128 Takes step index ``idx``, current state ``x_t``, noise array ``zs``, and time schedule ``ts`` as input.
129 Set during initialization based on the sampler type and ``use_stochastic_sampler``.
130 """
132 def get_sample_step_function(self) -> Callable[[int, Array, Array, Array], Array]:
133 """
134 Get the appropriate Euler-Maruyama sampling step function based on the
135 vector field type and stochasticity.
137 Returns:
138 Callable[[int, Array, Array, Array], Array]: The specific Euler-Maruyama step function to use.
140 Signature: ``(idx: int, x_t: Array[*data_dims], zs: Array[num_steps, *data_dims], ts: Array[num_steps+1]) -> Array[*data_dims]``
141 """
142 match (self.vector_field_type, self.use_stochastic_sampler):
143 case (VectorFieldType.SCORE, False):
144 return self._sample_step_score_deterministic
145 case (VectorFieldType.SCORE, True):
146 return self._sample_step_score_stochastic
147 case (VectorFieldType.X0, False):
148 return self._sample_step_x0_deterministic
149 case (VectorFieldType.X0, True):
150 return self._sample_step_x0_stochastic
151 case (VectorFieldType.EPS, False):
152 return self._sample_step_eps_deterministic
153 case (VectorFieldType.EPS, True):
154 return self._sample_step_eps_stochastic
155 case (VectorFieldType.V, False):
156 return self._sample_step_v_deterministic
157 case (VectorFieldType.V, True):
158 return self._sample_step_v_stochastic
159 case _:
160 raise ValueError(
161 f"Unsupported vector field type: {self.vector_field_type} and stochasticity: {self.use_stochastic_sampler}"
162 )
164 def _get_step_quantities(
165 self,
166 idx: int,
167 zs: Array,
168 ts: Array,
169 ) -> Tuple[
170 Array, Array, Array, Array, Array, Array, Array, Array, Array, Array, Array
171 ]:
172 """
173 Calculate common quantities used in Euler-Maruyama sampling steps based on the diffusion process.
175 Args:
176 idx (``int``): Current step index (corresponds to time ``ts[idx]``).
177 zs (``Array[num_steps, *data_dims]``): Noise tensors for stochastic sampling. Only ``zs[idx]`` is used if needed.
178 ts (``Array[num_steps+1]``): Time schedule for sampling. Used to get ``ts[idx]`` and ``ts[idx+1]``.
180 Returns:
181 ``Tuple[Array[], Array[], Array[], Array[*data_dims], Array[], Array[], Array[], Array[], Array[], Array[], Array[]]``: A tuple containing
183 - t (``Array[]``): Current time ``ts[idx]``.
184 - t1 (``Array[]``): Next time ``ts[idx+1]``.
185 - dt (``Array[]``): Time difference ``(t1 - t)``, should be negative.
186 - dwt (``Array[*data_dims]``): Scaled noise increment ``sqrt(-dt) * zs[idx]`` for the stochastic step.
187 - alpha_t (``Array[]``): ``α`` at current time ``t``.
188 - sigma_t (``Array[]``): ``σ`` at current time ``t``.
189 - alpha_prime_t (``Array[]``): Derivative of ``α`` at current time ``t``.
190 - sigma_prime_t (``Array[]``): Derivative of ``σ`` at current time ``t``.
191 - alpha_ratio_t (``Array[]``): ``alpha_prime_t / alpha_t``.
192 - sigma_ratio_t (``Array[]``): ``sigma_prime_t / sigma_t``.
193 - diff_ratio_t (``Array[]``): ``sigma_ratio_t - alpha_ratio_t``.
194 """
195 t = ts[idx]
196 t1 = ts[idx + 1]
197 dt = t1 - t
198 dw_t = zs[idx] * jnp.sqrt(-dt) # dt is negative
200 alpha_t = self.diffusion_process.alpha(t)
201 sigma_t = self.diffusion_process.sigma(t)
202 alpha_prime_t = self.diffusion_process.alpha_prime(t)
203 sigma_prime_t = self.diffusion_process.sigma_prime(t)
204 alpha_ratio_t = alpha_prime_t / alpha_t
205 sigma_ratio_t = sigma_prime_t / sigma_t
206 diff_ratio_t = sigma_ratio_t - alpha_ratio_t
208 return (
209 t,
210 t1,
211 dt,
212 dw_t,
213 alpha_t,
214 sigma_t,
215 alpha_prime_t,
216 sigma_prime_t,
217 alpha_ratio_t,
218 sigma_ratio_t,
219 diff_ratio_t,
220 )
222 def _sample_step_score_deterministic(
223 self, idx: int, x_t: Array, zs: Array, ts: Array
224 ) -> Array:
225 """
226 Perform one deterministic Euler step using the score vector field (i.e., ``VectorFieldType.SCORE``).
227 Corresponds to the probability flow ODE associated with the score SDE.
229 Args:
230 idx (``int``): Current step index.
231 x_t (``Array[*data_dims]``): Current state tensor at time ``ts[idx]``.
232 zs (``Array[num_steps, *data_dims]``): Noise tensors (unused).
233 ts (``Array[num_steps+1]``): Time schedule.
235 Returns:
236 ``Array[*data_dims]``: Next state tensor at time ``ts[idx+1]``.
237 """
238 (
239 t,
240 t1,
241 dt,
242 dw_t,
243 alpha_t,
244 sigma_t,
245 alpha_prime_t,
246 sigma_prime_t,
247 alpha_ratio_t,
248 sigma_ratio_t,
249 diff_ratio_t,
250 ) = self._get_step_quantities(idx, zs, ts)
251 score_x_t = self.vector_field(x_t, t)
252 drift_t = alpha_ratio_t * x_t - (sigma_t**2) * diff_ratio_t * score_x_t
253 x_t1 = x_t + drift_t * dt
254 return x_t1
256 def _sample_step_score_stochastic(
257 self, idx: int, x_t: Array, zs: Array, ts: Array
258 ) -> Array:
259 """
260 Perform one stochastic Euler-Maruyama step using the score vector field (i.e., ``VectorFieldType.SCORE``).
261 Corresponds to discretizing the reverse SDE derived using the score field.
263 Args:
264 idx (``int``): Current step index.
265 x_t (``Array[*data_dims]``): Current state tensor at time ``ts[idx]``.
266 zs (``Array[num_steps, *data_dims]``): Noise tensors. Uses ``zs[idx]``.
267 ts (``Array[num_steps+1]``): Time schedule.
269 Returns:
270 ``Array[*data_dims]``: Next state tensor at time ``ts[idx+1]``.
271 """
272 (
273 t,
274 t1,
275 dt,
276 dw_t,
277 alpha_t,
278 sigma_t,
279 alpha_prime_t,
280 sigma_prime_t,
281 alpha_ratio_t,
282 sigma_ratio_t,
283 diff_ratio_t,
284 ) = self._get_step_quantities(idx, zs, ts)
285 score_x_t = self.vector_field(x_t, t)
286 drift_t = alpha_ratio_t * x_t - 2 * (sigma_t**2) * diff_ratio_t * score_x_t
287 diffusion_t = jnp.sqrt(2 * diff_ratio_t) * sigma_t
288 x_t1 = x_t + drift_t * dt + diffusion_t * dw_t
289 return x_t1
291 def _sample_step_x0_deterministic(
292 self, idx: int, x_t: Array, zs: Array, ts: Array
293 ) -> Array:
294 """
295 Perform one deterministic Euler step using the ``x_0`` vector field (i.e., ``VectorFieldType.X0``).
296 Corresponds to the probability flow ODE associated with the ``x_0`` SDE.
298 Args:
299 idx (``int``): Current step index.
300 x_t (``Array[*data_dims]``): Current state tensor at time ``ts[idx]``.
301 zs (``Array[num_steps, *data_dims]``): Noise tensors (unused).
302 ts (``Array[num_steps+1]``): Time schedule.
304 Returns:
305 ``Array[*data_dims]``: Next state tensor at time ``ts[idx+1]``.
306 """
307 (
308 t,
309 t1,
310 dt,
311 dw_t,
312 alpha_t,
313 sigma_t,
314 alpha_prime_t,
315 sigma_prime_t,
316 alpha_ratio_t,
317 sigma_ratio_t,
318 diff_ratio_t,
319 ) = self._get_step_quantities(idx, zs, ts)
320 x0_x_t = self.vector_field(x_t, t)
321 drift_t = sigma_ratio_t * x_t - alpha_t * diff_ratio_t * x0_x_t
322 x_t1 = x_t + drift_t * dt
323 return x_t1
325 def _sample_step_x0_stochastic(
326 self, idx: int, x_t: Array, zs: Array, ts: Array
327 ) -> Array:
328 """
329 Perform one stochastic Euler-Maruyama step using the ``x_0`` vector field (i.e., ``VectorFieldType.X0``).
330 Corresponds to discretizing the reverse SDE derived using the ``x_0`` field.
332 Args:
333 idx (``int``): Current step index.
334 x_t (``Array[*data_dims]``): Current state tensor at time ``ts[idx]``.
335 zs (``Array[num_steps, *data_dims]``): Noise tensors. Uses ``zs[idx]``.
336 ts (``Array[num_steps+1]``): Time schedule.
338 Returns:
339 ``Array[*data_dims]``: Next state tensor at time ``ts[idx+1]``.
340 """
341 (
342 t,
343 t1,
344 dt,
345 dw_t,
346 alpha_t,
347 sigma_t,
348 alpha_prime_t,
349 sigma_prime_t,
350 alpha_ratio_t,
351 sigma_ratio_t,
352 diff_ratio_t,
353 ) = self._get_step_quantities(idx, zs, ts)
354 x0_x_t = self.vector_field(x_t, t)
355 drift_t = (
356 alpha_ratio_t + 2 * diff_ratio_t
357 ) * x_t - 2 * alpha_t * diff_ratio_t * x0_x_t
358 diffusion_t = jnp.sqrt(2 * diff_ratio_t) * sigma_t
359 x_t1 = x_t + drift_t * dt + diffusion_t * dw_t
360 return x_t1
362 def _sample_step_eps_deterministic(
363 self, idx: int, x_t: Array, zs: Array, ts: Array
364 ) -> Array:
365 """
366 Perform one deterministic Euler step using the ε vector field (i.e., ``VectorFieldType.EPS``).
367 Corresponds to the probability flow ODE associated with the ε SDE.
369 Args:
370 idx (``int``): Current step index.
371 x_t (``Array[*data_dims]``): Current state tensor at time ``ts[idx]``.
372 zs (``Array[num_steps, *data_dims]``): Noise tensors (unused).
373 ts (``Array[num_steps+1]``): Time schedule.
375 Returns:
376 ``Array[*data_dims]``: Next state tensor at time ``ts[idx+1]``.
377 """
378 (
379 t,
380 t1,
381 dt,
382 dw_t,
383 alpha_t,
384 sigma_t,
385 alpha_prime_t,
386 sigma_prime_t,
387 alpha_ratio_t,
388 sigma_ratio_t,
389 diff_ratio_t,
390 ) = self._get_step_quantities(idx, zs, ts)
391 eps_x_t = self.vector_field(x_t, t)
392 drift_t = alpha_ratio_t * x_t + sigma_t * diff_ratio_t * eps_x_t
393 x_t1 = x_t + drift_t * dt
394 return x_t1
396 def _sample_step_eps_stochastic(
397 self, idx: int, x_t: Array, zs: Array, ts: Array
398 ) -> Array:
399 """
400 Perform one stochastic Euler-Maruyama step using the ε vector field (i.e., ``VectorFieldType.EPS``).
401 Corresponds to discretizing the reverse SDE derived using the ε field.
403 Args:
404 idx (int): Current step index.
405 x_t (``Array[*data_dims]``): Current state tensor at time ``ts[idx]``.
406 zs (``Array[num_steps, *data_dims]``): Noise tensors. Uses ``zs[idx]``.
407 ts (``Array[num_steps+1]``): Time schedule.
409 Returns:
410 ``Array[*data_dims]``: Next state tensor at time ``ts[idx+1]``.
411 """
412 (
413 t,
414 t1,
415 dt,
416 dw_t,
417 alpha_t,
418 sigma_t,
419 alpha_prime_t,
420 sigma_prime_t,
421 alpha_ratio_t,
422 sigma_ratio_t,
423 diff_ratio_t,
424 ) = self._get_step_quantities(idx, zs, ts)
425 eps_x_t = self.vector_field(x_t, t)
426 drift_t = alpha_ratio_t * x_t + 2 * sigma_t * diff_ratio_t * eps_x_t
427 diffusion_t = jnp.sqrt(2 * diff_ratio_t) * sigma_t
428 x_t1 = x_t + drift_t * dt + diffusion_t * dw_t
429 return x_t1
431 def _sample_step_v_deterministic(
432 self, idx: int, x_t: Array, zs: Array, ts: Array
433 ) -> Array:
434 """
435 Perform one deterministic Euler step using the velocity vector field (i.e., ``VectorFieldType.V``).
436 Corresponds to the probability flow ODE associated with the velocity SDE.
438 Args:
439 idx (``int``): Current step index.
440 x_t (``Array[*data_dims]``): Current state tensor at time ``ts[idx]``.
441 zs (``Array[num_steps, *data_dims]``): Noise tensors (unused).
442 ts (``Array[num_steps+1]``): Time schedule.
444 Returns:
445 ``Array[*data_dims]``: Next state tensor at time ``ts[idx+1]``.
446 """
447 (
448 t,
449 t1,
450 dt,
451 dw_t,
452 alpha_t,
453 sigma_t,
454 alpha_prime_t,
455 sigma_prime_t,
456 alpha_ratio_t,
457 sigma_ratio_t,
458 diff_ratio_t,
459 ) = self._get_step_quantities(idx, zs, ts)
460 v_x_t = self.vector_field(x_t, t)
461 drift_t = v_x_t
462 x_t1 = x_t + drift_t * dt
463 return x_t1
465 def _sample_step_v_stochastic(
466 self, idx: int, x_t: Array, zs: Array, ts: Array
467 ) -> Array:
468 """
469 Perform one stochastic Euler-Maruyama step using the velocity vector field (i.e., ``VectorFieldType.V``).
470 Corresponds to discretizing the reverse SDE derived using the velocity field.
472 Args:
473 idx (``int``): Current step index.
474 x_t (``Array[*data_dims]``): Current state tensor at time ``ts[idx]``.
475 zs (``Array[num_steps, *data_dims]``): Noise tensors. Uses ``zs[idx]``.
476 ts (``Array[num_steps+1]``): Time schedule.
478 Returns:
479 ``Array[*data_dims]``: Next state tensor at time ``ts[idx+1]``.
480 """
481 (
482 t,
483 t1,
484 dt,
485 dw_t,
486 alpha_t,
487 sigma_t,
488 alpha_prime_t,
489 sigma_prime_t,
490 alpha_ratio_t,
491 sigma_ratio_t,
492 diff_ratio_t,
493 ) = self._get_step_quantities(idx, zs, ts)
494 v_x_t = self.vector_field(x_t, t)
495 drift_t = -alpha_ratio_t * x_t + 2 * v_x_t
496 diffusion_t = jnp.sqrt(2 * diff_ratio_t) * sigma_t
497 x_t1 = x_t + drift_t * dt + diffusion_t * dw_t
498 return x_t1
@dataclass
class DDMSampler(Sampler):
    """
    Class for sampling from diffusion models using the Denoising Diffusion Probabilistic Models (DDPM)
    or Denoising Diffusion Implicit Models (DDIM) sampling strategy.

    This sampler first converts any given vector field type (``VectorFieldType.SCORE``, ``VectorFieldType.X0``,
    ``VectorFieldType.EPS``, ``VectorFieldType.V``) provided by ``vector_field`` into an equivalent x0 prediction
    using the ``convert_vector_field_type`` utility. Then, it applies the DDPM (if ``use_stochastic_sampler`` is
    ``True``) or DDIM (if ``use_stochastic_sampler`` is ``False``) update rule based on this x0 prediction.

    Attributes:
        diffusion_process (``DiffusionProcess``): The diffusion process defining the forward dynamics.
        vector_field (``Callable[[Array[*data_dims], Array[]], Array[*data_dims]]``): The function predicting the vector field.
        vector_field_type (``VectorFieldType``): The type of the vector field predicted by ``vector_field``.
        use_stochastic_sampler (``bool``): If ``True``, uses DDPM (stochastic); otherwise, uses DDIM (deterministic).
        sample_step (``Callable[[int, Array, Array, Array], Array]``): The DDPM or DDIM step function.
    """

    def get_sample_step_function(self) -> Callable[[int, Array, Array, Array], Array]:
        """
        Get the appropriate DDPM/DDIM sampling step function based on stochasticity.

        Returns:
            ``Callable[[int, Array, Array, Array], Array]``: The DDPM (stochastic) or DDIM (deterministic) step function, which has signature:

            ``(idx: int, x: Array[*data_dims], zs: Array[num_steps, *data_dims], ts: Array[num_steps+1]) -> Array[*data_dims]``
        """
        if self.use_stochastic_sampler:
            return self._sample_step_stochastic
        return self._sample_step_deterministic

    def _get_x0_prediction(self, x_t: Array, t: Array) -> Array:
        """
        Predict the initial state x_0 from the current noisy state x_t at time t.

        This uses the provided ``vector_field`` function and its ``vector_field_type``
        to compute the prediction, converting it to an X0 prediction if necessary.

        Args:
            x_t (``Array[*data_dims]``): The current state tensor.
            t (``Array[]``): The current time.

        Returns:
            ``Array[*data_dims]``: The predicted initial state x_0.
        """
        alpha_t = self.diffusion_process.alpha(t)
        sigma_t = self.diffusion_process.sigma(t)
        alpha_prime_t = self.diffusion_process.alpha_prime(t)
        sigma_prime_t = self.diffusion_process.sigma_prime(t)
        f_x_t = self.vector_field(x_t, t)
        x0_x_t = convert_vector_field_type(
            x_t,
            f_x_t,
            alpha_t,
            sigma_t,
            alpha_prime_t,
            sigma_prime_t,
            self.vector_field_type,
            VectorFieldType.X0,
        )
        return x0_x_t

    def _sample_step_deterministic(
        self, idx: int, x_t: Array, zs: Array, ts: Array
    ) -> Array:
        """
        Perform one deterministic DDIM sampling step.

        This predicts x0 from the current state ``(x_t, t)`` and then applies the DDIM update
        to reach the next time ``t1 = ts[idx+1]``:

            ``x_t1 = (σ_t1/σ_t) x_t + α_t1 (1 - (α_t/α_t1)(σ_t1/σ_t)) x0``

        which is algebraically ``α_t1 x0 + σ_t1 (x_t - α_t x0) / σ_t``.

        Args:
            idx (``int``): The current step index (corresponds to time ``ts[idx]``).
            x_t (``Array[*data_dims]``): The current state tensor at time ``ts[idx]``.
            zs (``Array[num_steps, *data_dims]``): Noise tensors (unused in DDIM).
            ts (``Array[num_steps+1]``): The time schedule for sampling.

        Returns:
            ``Array[*data_dims]``: The next state tensor at time ``ts[idx+1]`` after applying the DDIM update.
        """
        t = ts[idx]
        x0_x_t = self._get_x0_prediction(x_t, t)

        t1 = ts[idx + 1]
        alpha_t = self.diffusion_process.alpha(t)
        sigma_t = self.diffusion_process.sigma(t)
        alpha_t1 = self.diffusion_process.alpha(t1)
        sigma_t1 = self.diffusion_process.sigma(t1)

        r01 = sigma_t1 / sigma_t
        r11 = (alpha_t / alpha_t1) * r01

        x_t1 = r01 * x_t + alpha_t1 * (1 - r11) * x0_x_t
        return x_t1

    def _sample_step_stochastic(
        self, idx: int, x_t: Array, zs: Array, ts: Array
    ) -> Array:
        """
        Perform one stochastic DDPM sampling step.

        This predicts x0 from the current state ``(x_t, t)`` and then draws from the Gaussian
        posterior ``q(x_t1 | x_t, x0)`` of the forward process, with ``t1 = ts[idx+1]``:

            ``mean = r12 x_t + α_t1 (1 - r22) x0``,  ``std = σ_t1 sqrt(1 - r11²)``

        where ``r11 = (α_t/α_t1)(σ_t1/σ_t)``, ``r12 = r11 (σ_t1/σ_t)`` and ``r22 = (α_t/α_t1) r12``.
        The fresh noise ``zs[idx]`` is scaled by this posterior standard deviation.

        Args:
            idx (``int``): The current step index (corresponds to time ``ts[idx]``).
            x_t (``Array[*data_dims]``): The current state tensor at time ``ts[idx]``.
            zs (``Array[num_steps, *data_dims]``): Noise tensors. Uses ``zs[idx]``.
            ts (``Array[num_steps+1]``): The time schedule for sampling.

        Returns:
            ``Array[*data_dims]``: The next state tensor at time ``ts[idx+1]`` after applying the DDPM update.
        """
        t = ts[idx]
        x0_x_t = self._get_x0_prediction(x_t, t)
        z_t = zs[idx]

        t1 = ts[idx + 1]
        alpha_t = self.diffusion_process.alpha(t)
        sigma_t = self.diffusion_process.sigma(t)
        alpha_t1 = self.diffusion_process.alpha(t1)
        sigma_t1 = self.diffusion_process.sigma(t1)

        r11 = (alpha_t / alpha_t1) * (sigma_t1 / sigma_t)
        r12 = r11 * (sigma_t1 / sigma_t)
        r22 = (alpha_t / alpha_t1) * r12

        mean = r12 * x_t + alpha_t1 * (1 - r22) * x0_x_t
        # 1 - r11**2 should be non-negative when α/σ decreases along the schedule, but
        # round-off can push r11 just above 1 when t1 is very close to t; clamp at 0 so
        # the square root does not produce NaN.
        std = sigma_t1 * jnp.maximum(1 - (r11**2), 0.0) ** (1 / 2)
        x_t1 = mean + std * z_t
        return x_t1