Coverage for src/diffusionlab/dynamics.py: 100%

29 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-19 14:17 -0700

1from dataclasses import dataclass, field 

2from typing import Callable 

3 

4import jax 

5from jax import Array, numpy as jnp 

6 

7 

@dataclass(frozen=True)
class DiffusionProcess:
    """
    Generic scalar-time diffusion process defined by a signal schedule and a noise schedule.

    A clean sample ``x_0`` is mixed with standard Gaussian noise ``ε`` of the same shape
    according to two user-supplied scalar functions of time:

    ``x_t = α(t) * x_0 + σ(t) * ε``

    - ``α(t)`` scales the retained signal at time ``t``.
    - ``σ(t)`` scales the injected noise at time ``t``.

    Shapes:

    - ``x_0``, ``x_t`` and ``ε``: ``Array[*data_dims]``
    - ``t``: ``Array[]`` (a scalar)

    The time derivatives ``α'`` and ``σ'`` are not passed by the caller. They are derived
    once, at construction time, with ``jax.grad``. ``jax.grad`` only differentiates
    scalar-valued functions, so ``alpha`` and ``sigma`` are expected to map a scalar
    floating-point ``t`` to a scalar. Instances are immutable (frozen dataclass).

    Attributes:
        alpha (``Callable[[Array[]], Array[]]``): Maps scalar time ``t`` to the scalar signal coefficient ``α(t)``.
        sigma (``Callable[[Array[]], Array[]]``): Maps scalar time ``t`` to the scalar noise coefficient ``σ(t)``.
        alpha_prime (``Callable[[Array[]], Array[]]``): ``dα/dt``, built as ``jax.grad(alpha)``; not an ``__init__`` argument.
        sigma_prime (``Callable[[Array[]], Array[]]``): ``dσ/dt``, built as ``jax.grad(sigma)``; not an ``__init__`` argument.
    """

    alpha: Callable[[Array], Array]
    sigma: Callable[[Array], Array]
    alpha_prime: Callable[[Array], Array] = field(init=False)
    sigma_prime: Callable[[Array], Array] = field(init=False)

    def __post_init__(self):
        # The dataclass is frozen, so the generated __setattr__ refuses writes;
        # go through object.__setattr__ to install the derived derivative fields.
        derived = (("alpha_prime", self.alpha), ("sigma_prime", self.sigma))
        for attr_name, schedule in derived:
            object.__setattr__(self, attr_name, jax.grad(schedule))

    def forward(self, x: Array, t: Array, eps: Array) -> Array:
        """
        Noise a clean sample ``x`` to time ``t`` with the given noise ``ε``.

        Evaluates ``α(t) * x + σ(t) * ε``.

        Args:
            x (``Array[*data_dims]``): Clean data ``x_0``.
            t (``Array[]``): Scalar diffusion time.
            eps (``Array[*data_dims]``): Gaussian noise ``ε`` with the same shape as ``x``.

        Returns:
            ``Array[*data_dims]``: The noised sample ``x_t``.
        """
        signal_scale, noise_scale = self.alpha(t), self.sigma(t)
        return signal_scale * x + noise_scale * eps

66 

67 

@dataclass(frozen=True)
class VarianceExplodingProcess(DiffusionProcess):
    """
    Variance Exploding (VE) diffusion process.

    The signal is never attenuated: ``α(t) = 1`` for every ``t``. All of the time
    dependence comes from the caller-supplied noise schedule ``σ(t)``. The spread of
    ``x_t`` therefore grows as ``σ(t)`` grows.

    Forward process:

    ``x_t = x_0 + σ(t) * ε``.

    Schedule:

    - ``α(t) = 1`` (an array of ones shaped like ``t``)
    - ``σ(t)``: supplied by the caller

    Attributes:
        alpha (``Callable[[Array[]], Array[]]``): Constant signal coefficient, always ``1``.
        sigma (``Callable[[Array[]], Array[]]``): Caller-supplied map from scalar time ``t`` to ``σ(t)``.
        alpha_prime (``Callable[[Array[]], Array[]]``): ``dα/dt``, which evaluates to ``0``.
        sigma_prime (``Callable[[Array[]], Array[]]``): ``dσ/dt``, obtained by differentiating ``sigma``.
    """

    def __init__(self, sigma: Callable[[Array], Array]):
        """
        Build a VE process around a user-provided noise schedule.

        Args:
            sigma (``Callable[[Array[]], Array[]]``): Maps scalar time ``t`` to the scalar noise coefficient ``σ(t)``.
        """

        def unit_alpha(t: Array) -> Array:
            # Constant signal coefficient with the same shape/dtype as t.
            return jnp.ones_like(t)

        super().__init__(alpha=unit_alpha, sigma=sigma)

101 

102 

@dataclass(frozen=True)
class VariancePreservingProcess(DiffusionProcess):
    """
    Variance Preserving (VP) diffusion process, in the family associated with DDPM-style models.

    The signal and noise coefficients satisfy ``α(t)² + σ(t)² = 1`` for ``t`` in ``[0, 1]``.
    If ``x_0`` and ``ε`` are independent with unit variance, then ``x_t`` also has unit
    variance at every such ``t``.

    Schedule:

    - ``α(t) = sqrt(1 - t²)``
    - ``σ(t) = t``

    Forward process:

    ``x_t = sqrt(1 - t²) * x_0 + t * ε``.

    Outside ``[-1, 1]`` the radicand ``1 - t²`` is negative, so ``α(t)`` is not real-valued
    there. As ``t → 1``, ``α'(t)`` diverges.

    Attributes:
        alpha (``Callable[[Array[]], Array[]]``): Signal coefficient ``sqrt(1 - t²)``.
        sigma (``Callable[[Array[]], Array[]]``): Noise coefficient ``t``.
        alpha_prime (``Callable[[Array[]], Array[]]``): ``dα/dt = -t / sqrt(1 - t²)``.
        sigma_prime (``Callable[[Array[]], Array[]]``): ``dσ/dt = 1``.
    """

    def __init__(self):
        """
        Build a VP process with the fixed ``sqrt(1 - t²)`` / ``t`` schedule.
        """

        def vp_alpha(t: Array) -> Array:
            return jnp.sqrt(jnp.ones_like(t) - t**2)

        def vp_sigma(t: Array) -> Array:
            return t

        super().__init__(alpha=vp_alpha, sigma=vp_sigma)

135 

136 

@dataclass(frozen=True)
class FlowMatchingProcess(DiffusionProcess):
    """
    Flow-matching process: straight-line interpolation from data to noise.

    At ``t = 0`` the sample is pure data. At ``t = 1`` it is pure standard Gaussian noise.
    In between, the two are mixed linearly.

    Schedule:

    - ``α(t) = 1 - t``
    - ``σ(t) = t``

    Forward process:

    ``x_t = (1 - t) * x_0 + t * ε``.

    Attributes:
        alpha (``Callable[[Array[]], Array[]]``): Signal coefficient ``1 - t``.
        sigma (``Callable[[Array[]], Array[]]``): Noise coefficient ``t``.
        alpha_prime (``Callable[[Array[]], Array[]]``): ``dα/dt = -1``.
        sigma_prime (``Callable[[Array[]], Array[]]``): ``dσ/dt = 1``.
    """

    def __init__(self):
        """
        Build a flow-matching process with the fixed linear ``1 - t`` / ``t`` schedule.
        """

        def linear_alpha(t: Array) -> Array:
            return jnp.ones_like(t) - t

        def linear_sigma(t: Array) -> Array:
            return t

        super().__init__(alpha=linear_alpha, sigma=linear_sigma)