Coverage for src/diffusionlab/samplers.py: 100%

160 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-19 14:17 -0700

1from typing import Callable, Tuple 

2 

3from dataclasses import dataclass, field 

4import jax 

5from jax import Array, numpy as jnp 

6 

7from diffusionlab.dynamics import DiffusionProcess 

8from diffusionlab.vector_fields import ( 

9 VectorFieldType, 

10 convert_vector_field_type, 

11) 

12 

13 

@dataclass
class Sampler:
    """
    Shared machinery for drawing samples from a diffusion model via its reverse process.

    A ``Sampler`` bundles a diffusion process, a function that predicts a vector field, and a
    per-step update rule. Subclasses decide which update rule is used by implementing
    ``get_sample_step_function``; this base class only drives that rule over a time schedule.

    The vector field may be any of ``VectorFieldType.SCORE``, ``VectorFieldType.X0``,
    ``VectorFieldType.EPS`` or ``VectorFieldType.V``. Whether sampling is stochastic or
    deterministic depends on the subclass and on the ``use_stochastic_sampler`` flag.

    Attributes:
        diffusion_process (``DiffusionProcess``): The diffusion process defining the forward dynamics.
        vector_field (``Callable[[Array[*data_dims], Array[]], Array[*data_dims]]``): The function
            predicting the vector field from the current state ``x_t`` and time ``t``.
        vector_field_type (``VectorFieldType``): The kind of quantity returned by ``vector_field``.
        use_stochastic_sampler (``bool``): Whether to use a stochastic or deterministic reverse process.
        sample_step (``Callable[[int, Array, Array, Array], Array]``): The function performing one
            sampling step, called as ``sample_step(idx, x_t, zs, ts)``. It is not a constructor
            argument; it is assigned in ``__post_init__`` from ``get_sample_step_function()``.
    """

    diffusion_process: DiffusionProcess
    vector_field: Callable[[Array, Array], Array]
    vector_field_type: VectorFieldType
    use_stochastic_sampler: bool
    # Excluded from ``__init__``; populated by ``__post_init__``.
    sample_step: Callable[[int, Array, Array, Array], Array] = field(init=False)

    def __post_init__(self):
        # Resolve the step rule once, so the scan loops below can call it directly.
        self.sample_step = self.get_sample_step_function()

    def sample(self, x_init: Array, zs: Array, ts: Array) -> Array:
        """
        Run the reverse process from ``x_init`` and return only the final state.

        ``sample_step`` is applied once per entry along the leading axis of ``zs``, carrying
        the state from one step to the next with ``jax.lax.scan``.

        Args:
            x_init (``Array[*data_dims]``): The initial noisy tensor (typically drawn from the
                prior distribution at ``ts[0]``).
            zs (``Array[num_steps, *data_dims]``): The noise tensors used at each step for
                stochastic sampling. Deterministic samplers do not read the values, but the
                leading dimension still sets the number of steps.
            ts (``Array[num_steps+1]``): The time schedule; a sorted decreasing array of times
                from ``t_max`` to ``t_min``.

        Returns:
            ``Array[*data_dims]``: The generated sample at the final time ``ts[-1]``.
        """
        num_steps = zs.shape[0]

        def advance(state, step_idx):
            # Nothing is collected per step; only the carried state matters here.
            return self.sample_step(step_idx, state, zs, ts), None

        x_final, _ = jax.lax.scan(advance, x_init, jnp.arange(num_steps))
        return x_final

    def sample_trajectory(self, x_init: Array, zs: Array, ts: Array) -> Array:
        """
        Run the reverse process from ``x_init`` and return every intermediate state.

        Works like ``sample``, except that each state produced by ``sample_step`` is also
        recorded, and ``x_init`` is prepended to the recorded states.

        Args:
            x_init (``Array[*data_dims]``): The initial noisy tensor (at time ``ts[0]``).
            zs (``Array[num_steps, *data_dims]``): The noise tensors used at each step for
                stochastic sampling. Deterministic samplers do not read the values, but the
                leading dimension still sets the number of steps.
            ts (``Array[num_steps+1]``): The time schedule; a sorted decreasing array of times
                from ``t_max`` to ``t_min``.

        Returns:
            ``Array[num_steps+1, *data_dims]``: The full trajectory, with ``x_init`` as its first entry.
        """
        num_steps = zs.shape[0]

        def advance(state, step_idx):
            new_state = self.sample_step(step_idx, state, zs, ts)
            # The new state is both the next carry and the recorded output.
            return new_state, new_state

        _, later_states = jax.lax.scan(advance, x_init, jnp.arange(num_steps))

        # The scan output holds only the states after each step; put the start in front.
        trajectory = jnp.concatenate([x_init[None, ...], later_states], axis=0)
        return trajectory

    def get_sample_step_function(self) -> Callable[[int, Array, Array, Array], Array]:
        """
        Return the function that performs one step of the reverse process.

        This base implementation always raises; subclasses must override it and choose the
        step function according to their integrator and the ``use_stochastic_sampler`` flag.

        Returns:
            ``Callable[[int, Array, Array, Array], Array]``: The sampling step function, with signature

            ``(idx: int, x_t: Array[*data_dims], zs: Array[num_steps, *data_dims], ts: Array[num_steps+1]) -> Array[*data_dims]``

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

107 

108 

@dataclass
class EulerMaruyamaSampler(Sampler):
    """
    Sampler that integrates the reverse process with the first-order Euler-Maruyama scheme.

    With ``use_stochastic_sampler=True`` each step discretizes the reverse SDE; with
    ``use_stochastic_sampler=False`` each step discretizes the corresponding probability flow
    ODE. All vector field types are supported (``VectorFieldType.SCORE``,
    ``VectorFieldType.X0``, ``VectorFieldType.EPS``, ``VectorFieldType.V``), each with its own
    pair of step functions.

    Attributes:
        diffusion_process (``DiffusionProcess``): The diffusion process defining the forward dynamics.
        vector_field (``Callable[[Array[*data_dims], Array[]], Array[*data_dims]]``): The function
            predicting the vector field from the current state ``x_t`` and time ``t``.
        vector_field_type (``VectorFieldType``): The kind of quantity returned by ``vector_field``.
        use_stochastic_sampler (``bool``): Whether to use a stochastic or deterministic reverse process.
        sample_step (``Callable[[int, Array, Array, Array], Array]``): The step function selected at
            initialization, called as ``sample_step(idx, x_t, zs, ts)``.
    """

    def get_sample_step_function(self) -> Callable[[int, Array, Array, Array], Array]:
        """
        Select the Euler-Maruyama step function for this sampler's vector field type and
        stochasticity.

        Returns:
            ``Callable[[int, Array, Array, Array], Array]``: The step function to use, with signature

            ``(idx: int, x_t: Array[*data_dims], zs: Array[num_steps, *data_dims], ts: Array[num_steps+1]) -> Array[*data_dims]``

        Raises:
            ValueError: If the vector field type is not one of the four supported types, or if
                ``use_stochastic_sampler`` is not exactly ``True`` or ``False``.
        """
        field_type = self.vector_field_type
        stochastic = self.use_stochastic_sampler

        # One row per vector field type: (type, deterministic step, stochastic step).
        step_table = (
            (
                VectorFieldType.SCORE,
                self._sample_step_score_deterministic,
                self._sample_step_score_stochastic,
            ),
            (
                VectorFieldType.X0,
                self._sample_step_x0_deterministic,
                self._sample_step_x0_stochastic,
            ),
            (
                VectorFieldType.EPS,
                self._sample_step_eps_deterministic,
                self._sample_step_eps_stochastic,
            ),
            (
                VectorFieldType.V,
                self._sample_step_v_deterministic,
                self._sample_step_v_stochastic,
            ),
        )

        for candidate_type, deterministic_step, stochastic_step in step_table:
            if field_type == candidate_type:
                # The flag is tested by identity, so only the real ``True``/``False``
                # singletons select a step function.
                if stochastic is False:
                    return deterministic_step
                if stochastic is True:
                    return stochastic_step

        raise ValueError(
            f"Unsupported vector field type: {self.vector_field_type} and stochasticity: {self.use_stochastic_sampler}"
        )

    def _get_step_quantities(
        self,
        idx: int,
        zs: Array,
        ts: Array,
    ) -> Tuple[
        Array, Array, Array, Array, Array, Array, Array, Array, Array, Array, Array
    ]:
        """
        Compute the time- and schedule-dependent quantities shared by every Euler-Maruyama step.

        Args:
            idx (``int``): Current step index (corresponds to time ``ts[idx]``).
            zs (``Array[num_steps, *data_dims]``): Noise tensors for stochastic sampling. Only ``zs[idx]`` is read.
            ts (``Array[num_steps+1]``): Time schedule for sampling. Only ``ts[idx]`` and ``ts[idx+1]`` are read.

        Returns:
            ``Tuple[Array[], Array[], Array[], Array[*data_dims], Array[], Array[], Array[], Array[], Array[], Array[], Array[]]``: In order,

            - t (``Array[]``): Current time ``ts[idx]``.
            - t1 (``Array[]``): Next time ``ts[idx+1]``.
            - dt (``Array[]``): Time difference ``t1 - t``; should be negative.
            - dw_t (``Array[*data_dims]``): Scaled noise increment ``zs[idx] * sqrt(-dt)``.
            - alpha_t (``Array[]``): ``α`` at time ``t``.
            - sigma_t (``Array[]``): ``σ`` at time ``t``.
            - alpha_prime_t (``Array[]``): Derivative of ``α`` at time ``t``.
            - sigma_prime_t (``Array[]``): Derivative of ``σ`` at time ``t``.
            - alpha_ratio_t (``Array[]``): ``alpha_prime_t / alpha_t``.
            - sigma_ratio_t (``Array[]``): ``sigma_prime_t / sigma_t``.
            - diff_ratio_t (``Array[]``): ``sigma_ratio_t - alpha_ratio_t``.
        """
        process = self.diffusion_process

        t, t1 = ts[idx], ts[idx + 1]
        dt = t1 - t
        # The schedule decreases, so ``dt`` is negative and ``-dt`` is the step length.
        dw_t = zs[idx] * jnp.sqrt(-dt)

        alpha_t = process.alpha(t)
        sigma_t = process.sigma(t)
        alpha_prime_t = process.alpha_prime(t)
        sigma_prime_t = process.sigma_prime(t)

        alpha_ratio_t = alpha_prime_t / alpha_t
        sigma_ratio_t = sigma_prime_t / sigma_t
        diff_ratio_t = sigma_ratio_t - alpha_ratio_t

        return (
            t,
            t1,
            dt,
            dw_t,
            alpha_t,
            sigma_t,
            alpha_prime_t,
            sigma_prime_t,
            alpha_ratio_t,
            sigma_ratio_t,
            diff_ratio_t,
        )

    def _sample_step_score_deterministic(
        self, idx: int, x_t: Array, zs: Array, ts: Array
    ) -> Array:
        """
        Take one deterministic Euler step when ``vector_field`` predicts the score
        (``VectorFieldType.SCORE``), i.e. one step of the probability flow ODE.

        Args:
            idx (``int``): Current step index.
            x_t (``Array[*data_dims]``): Current state tensor at time ``ts[idx]``.
            zs (``Array[num_steps, *data_dims]``): Noise tensors (do not affect the result).
            ts (``Array[num_steps+1]``): Time schedule.

        Returns:
            ``Array[*data_dims]``: Next state tensor at time ``ts[idx+1]``.
        """
        # Name only what this update needs; the remaining slots are discarded.
        t, _, dt, _, _, sigma_t, _, _, alpha_ratio_t, _, diff_ratio_t = (
            self._get_step_quantities(idx, zs, ts)
        )
        score = self.vector_field(x_t, t)
        score_weight = (sigma_t**2) * diff_ratio_t
        drift = alpha_ratio_t * x_t - score_weight * score
        return x_t + drift * dt

    def _sample_step_score_stochastic(
        self, idx: int, x_t: Array, zs: Array, ts: Array
    ) -> Array:
        """
        Take one stochastic Euler-Maruyama step when ``vector_field`` predicts the score
        (``VectorFieldType.SCORE``), i.e. one step of the discretized reverse SDE.

        Args:
            idx (``int``): Current step index.
            x_t (``Array[*data_dims]``): Current state tensor at time ``ts[idx]``.
            zs (``Array[num_steps, *data_dims]``): Noise tensors. Uses ``zs[idx]``.
            ts (``Array[num_steps+1]``): Time schedule.

        Returns:
            ``Array[*data_dims]``: Next state tensor at time ``ts[idx+1]``.
        """
        t, _, dt, dw_t, _, sigma_t, _, _, alpha_ratio_t, _, diff_ratio_t = (
            self._get_step_quantities(idx, zs, ts)
        )
        score = self.vector_field(x_t, t)
        # The score weight is twice that of the deterministic step.
        score_weight = 2 * (sigma_t**2) * diff_ratio_t
        drift = alpha_ratio_t * x_t - score_weight * score
        noise_scale = jnp.sqrt(2 * diff_ratio_t) * sigma_t
        return x_t + drift * dt + noise_scale * dw_t

    def _sample_step_x0_deterministic(
        self, idx: int, x_t: Array, zs: Array, ts: Array
    ) -> Array:
        """
        Take one deterministic Euler step when ``vector_field`` predicts ``x_0``
        (``VectorFieldType.X0``), i.e. one step of the probability flow ODE.

        Args:
            idx (``int``): Current step index.
            x_t (``Array[*data_dims]``): Current state tensor at time ``ts[idx]``.
            zs (``Array[num_steps, *data_dims]``): Noise tensors (do not affect the result).
            ts (``Array[num_steps+1]``): Time schedule.

        Returns:
            ``Array[*data_dims]``: Next state tensor at time ``ts[idx+1]``.
        """
        t, _, dt, _, alpha_t, _, _, _, _, sigma_ratio_t, diff_ratio_t = (
            self._get_step_quantities(idx, zs, ts)
        )
        x0_hat = self.vector_field(x_t, t)
        x0_weight = alpha_t * diff_ratio_t
        drift = sigma_ratio_t * x_t - x0_weight * x0_hat
        return x_t + drift * dt

    def _sample_step_x0_stochastic(
        self, idx: int, x_t: Array, zs: Array, ts: Array
    ) -> Array:
        """
        Take one stochastic Euler-Maruyama step when ``vector_field`` predicts ``x_0``
        (``VectorFieldType.X0``), i.e. one step of the discretized reverse SDE.

        Args:
            idx (``int``): Current step index.
            x_t (``Array[*data_dims]``): Current state tensor at time ``ts[idx]``.
            zs (``Array[num_steps, *data_dims]``): Noise tensors. Uses ``zs[idx]``.
            ts (``Array[num_steps+1]``): Time schedule.

        Returns:
            ``Array[*data_dims]``: Next state tensor at time ``ts[idx+1]``.
        """
        t, _, dt, dw_t, alpha_t, sigma_t, _, _, alpha_ratio_t, _, diff_ratio_t = (
            self._get_step_quantities(idx, zs, ts)
        )
        x0_hat = self.vector_field(x_t, t)
        state_weight = alpha_ratio_t + 2 * diff_ratio_t
        x0_weight = 2 * alpha_t * diff_ratio_t
        drift = state_weight * x_t - x0_weight * x0_hat
        noise_scale = jnp.sqrt(2 * diff_ratio_t) * sigma_t
        return x_t + drift * dt + noise_scale * dw_t

    def _sample_step_eps_deterministic(
        self, idx: int, x_t: Array, zs: Array, ts: Array
    ) -> Array:
        """
        Take one deterministic Euler step when ``vector_field`` predicts ε
        (``VectorFieldType.EPS``), i.e. one step of the probability flow ODE.

        Args:
            idx (``int``): Current step index.
            x_t (``Array[*data_dims]``): Current state tensor at time ``ts[idx]``.
            zs (``Array[num_steps, *data_dims]``): Noise tensors (do not affect the result).
            ts (``Array[num_steps+1]``): Time schedule.

        Returns:
            ``Array[*data_dims]``: Next state tensor at time ``ts[idx+1]``.
        """
        t, _, dt, _, _, sigma_t, _, _, alpha_ratio_t, _, diff_ratio_t = (
            self._get_step_quantities(idx, zs, ts)
        )
        eps_hat = self.vector_field(x_t, t)
        eps_weight = sigma_t * diff_ratio_t
        drift = alpha_ratio_t * x_t + eps_weight * eps_hat
        return x_t + drift * dt

    def _sample_step_eps_stochastic(
        self, idx: int, x_t: Array, zs: Array, ts: Array
    ) -> Array:
        """
        Take one stochastic Euler-Maruyama step when ``vector_field`` predicts ε
        (``VectorFieldType.EPS``), i.e. one step of the discretized reverse SDE.

        Args:
            idx (``int``): Current step index.
            x_t (``Array[*data_dims]``): Current state tensor at time ``ts[idx]``.
            zs (``Array[num_steps, *data_dims]``): Noise tensors. Uses ``zs[idx]``.
            ts (``Array[num_steps+1]``): Time schedule.

        Returns:
            ``Array[*data_dims]``: Next state tensor at time ``ts[idx+1]``.
        """
        t, _, dt, dw_t, _, sigma_t, _, _, alpha_ratio_t, _, diff_ratio_t = (
            self._get_step_quantities(idx, zs, ts)
        )
        eps_hat = self.vector_field(x_t, t)
        # The ε weight is twice that of the deterministic step.
        eps_weight = 2 * sigma_t * diff_ratio_t
        drift = alpha_ratio_t * x_t + eps_weight * eps_hat
        noise_scale = jnp.sqrt(2 * diff_ratio_t) * sigma_t
        return x_t + drift * dt + noise_scale * dw_t

    def _sample_step_v_deterministic(
        self, idx: int, x_t: Array, zs: Array, ts: Array
    ) -> Array:
        """
        Take one deterministic Euler step when ``vector_field`` predicts the velocity
        (``VectorFieldType.V``), i.e. one step of the probability flow ODE.

        The predicted velocity is used directly as the drift.

        Args:
            idx (``int``): Current step index.
            x_t (``Array[*data_dims]``): Current state tensor at time ``ts[idx]``.
            zs (``Array[num_steps, *data_dims]``): Noise tensors (do not affect the result).
            ts (``Array[num_steps+1]``): Time schedule.

        Returns:
            ``Array[*data_dims]``: Next state tensor at time ``ts[idx+1]``.
        """
        # Only the current time and the step size are needed here.
        t, _, dt, *_ = self._get_step_quantities(idx, zs, ts)
        velocity = self.vector_field(x_t, t)
        return x_t + velocity * dt

    def _sample_step_v_stochastic(
        self, idx: int, x_t: Array, zs: Array, ts: Array
    ) -> Array:
        """
        Take one stochastic Euler-Maruyama step when ``vector_field`` predicts the velocity
        (``VectorFieldType.V``), i.e. one step of the discretized reverse SDE.

        Args:
            idx (``int``): Current step index.
            x_t (``Array[*data_dims]``): Current state tensor at time ``ts[idx]``.
            zs (``Array[num_steps, *data_dims]``): Noise tensors. Uses ``zs[idx]``.
            ts (``Array[num_steps+1]``): Time schedule.

        Returns:
            ``Array[*data_dims]``: Next state tensor at time ``ts[idx+1]``.
        """
        t, _, dt, dw_t, _, sigma_t, _, _, alpha_ratio_t, _, diff_ratio_t = (
            self._get_step_quantities(idx, zs, ts)
        )
        velocity = self.vector_field(x_t, t)
        drift = -alpha_ratio_t * x_t + 2 * velocity
        noise_scale = jnp.sqrt(2 * diff_ratio_t) * sigma_t
        return x_t + drift * dt + noise_scale * dw_t

499 

500 

@dataclass
class DDMSampler(Sampler):
    """
    Sampler implementing the Denoising Diffusion Probabilistic Models (DDPM) and Denoising
    Diffusion Implicit Models (DDIM) update rules.

    Whatever ``vector_field`` predicts (``VectorFieldType.SCORE``, ``VectorFieldType.X0``,
    ``VectorFieldType.EPS`` or ``VectorFieldType.V``), the prediction is first converted to an
    ``x_0`` prediction with ``convert_vector_field_type``. The DDPM update is then applied if
    ``use_stochastic_sampler`` is ``True``, and the DDIM update otherwise.

    Attributes:
        diffusion_process (``DiffusionProcess``): The diffusion process defining the forward dynamics.
        vector_field (``Callable[[Array[*data_dims], Array[]], Array[*data_dims]]``): The function predicting the vector field.
        vector_field_type (``VectorFieldType``): The kind of quantity returned by ``vector_field``.
        use_stochastic_sampler (``bool``): If ``True``, uses DDPM (stochastic); otherwise, uses DDIM (deterministic).
        sample_step (``Callable[[int, Array, Array, Array], Array]``): The DDPM or DDIM step function.
    """

    def get_sample_step_function(self) -> Callable[[int, Array, Array, Array], Array]:
        """
        Select the DDPM or DDIM step function according to ``use_stochastic_sampler``.

        Returns:
            ``Callable[[int, Array, Array, Array], Array]``: The DDPM (stochastic) or DDIM (deterministic) step function, with signature

            ``(idx: int, x: Array[*data_dims], zs: Array[num_steps, *data_dims], ts: Array[num_steps+1]) -> Array[*data_dims]``
        """
        return (
            self._sample_step_stochastic
            if self.use_stochastic_sampler
            else self._sample_step_deterministic
        )

    def _get_x0_prediction(self, x_t: Array, t: Array) -> Array:
        """
        Predict the clean state ``x_0`` from the noisy state ``x_t`` at time ``t``.

        The raw output of ``vector_field`` is converted from ``vector_field_type`` to
        ``VectorFieldType.X0`` via ``convert_vector_field_type``.

        Args:
            x_t (``Array[*data_dims]``): The current state tensor.
            t (``Array[]``): The current time.

        Returns:
            ``Array[*data_dims]``: The predicted initial state ``x_0``.
        """
        process = self.diffusion_process
        alpha_t = process.alpha(t)
        sigma_t = process.sigma(t)
        alpha_prime_t = process.alpha_prime(t)
        sigma_prime_t = process.sigma_prime(t)

        raw_prediction = self.vector_field(x_t, t)
        return convert_vector_field_type(
            x_t,
            raw_prediction,
            alpha_t,
            sigma_t,
            alpha_prime_t,
            sigma_prime_t,
            self.vector_field_type,
            VectorFieldType.X0,
        )

    def _sample_step_deterministic(
        self, idx: int, x_t: Array, zs: Array, ts: Array
    ) -> Array:
        """
        Perform one deterministic DDIM sampling step.

        ``x_0`` is predicted from ``(x_t, t)``, and the next state is a weighted combination of
        ``x_t`` and that prediction, with weights built from ``α`` and ``σ`` at the current time
        ``ts[idx]`` and the next time ``ts[idx+1]``.

        Args:
            idx (``int``): The current step index (corresponds to time ``ts[idx]``).
            x_t (``Array[*data_dims]``): The current state tensor at time ``ts[idx]``.
            zs (``Array[num_steps, *data_dims]``): Noise tensors (unused in DDIM).
            ts (``Array[num_steps+1]``): The time schedule for sampling.

        Returns:
            ``Array[*data_dims]``: The next state tensor at time ``ts[idx+1]`` after applying the DDIM update.
        """
        process = self.diffusion_process

        t = ts[idx]
        x0_hat = self._get_x0_prediction(x_t, t)
        t_next = ts[idx + 1]

        alpha_t = process.alpha(t)
        sigma_t = process.sigma(t)
        alpha_next = process.alpha(t_next)
        sigma_next = process.sigma(t_next)

        sigma_shrink = sigma_next / sigma_t
        signal_shrink = (alpha_t / alpha_next) * sigma_shrink

        x_weight = sigma_shrink
        x0_weight = alpha_next * (1 - signal_shrink)
        return x_weight * x_t + x0_weight * x0_hat

    def _sample_step_stochastic(
        self, idx: int, x_t: Array, zs: Array, ts: Array
    ) -> Array:
        """
        Perform one stochastic DDPM sampling step.

        ``x_0`` is predicted from ``(x_t, t)``. The next state is then drawn as a mean (a
        weighted combination of ``x_t`` and the ``x_0`` prediction) plus ``zs[idx]`` scaled by a
        standard deviation; this corresponds to sampling from the conditional distribution
        ``p(x_{t-1} | x_t, x_0)``. With ``r = (α_t / α_{t1}) * (σ_{t1} / σ_t)``, the standard
        deviation is ``σ_{t1} * sqrt(1 - r^2)``.

        Args:
            idx (``int``): The current step index (corresponds to time ``ts[idx]``).
            x_t (``Array[*data_dims]``): The current state tensor at time ``ts[idx]``.
            zs (``Array[num_steps, *data_dims]``): Noise tensors. Uses ``zs[idx]``.
            ts (``Array[num_steps+1]``): The time schedule for sampling.

        Returns:
            ``Array[*data_dims]``: The next state tensor at time ``ts[idx+1]`` after applying the DDPM update.
        """
        process = self.diffusion_process

        t = ts[idx]
        x0_hat = self._get_x0_prediction(x_t, t)
        noise = zs[idx]
        t_next = ts[idx + 1]

        alpha_t = process.alpha(t)
        sigma_t = process.sigma(t)
        alpha_next = process.alpha(t_next)
        sigma_next = process.sigma(t_next)

        # Ratios are built up incrementally: each multiplies the previous one by one
        # more factor of either (sigma_next / sigma_t) or (alpha_t / alpha_next).
        ratio_a1_s1 = (alpha_t / alpha_next) * (sigma_next / sigma_t)
        ratio_a1_s2 = ratio_a1_s1 * (sigma_next / sigma_t)
        ratio_a2_s2 = (alpha_t / alpha_next) * ratio_a1_s2

        mean = ratio_a1_s2 * x_t + alpha_next * (1 - ratio_a2_s2) * x0_hat
        std = sigma_next * (1 - (ratio_a1_s1**2)) ** (1 / 2)
        return mean + std * noise