Coverage for src/diffusionlab/distributions/base.py: 100%

32 statements  

« prev     ^ index     » next       coverage.py v7.6.12, created at 2025-04-19 14:17 -0700

1from dataclasses import dataclass 

2from typing import Any, Callable, Dict, Tuple 

3 

4from jax import Array 

5 

6from diffusionlab.dynamics import DiffusionProcess 

7from diffusionlab.vector_fields import VectorFieldType 

8 

9 

@dataclass(frozen=True)
class Distribution:
    """
    Common base class for distributions used together with diffusion processes.

    Subclass this when ground-truth vector fields are needed for a distribution:
    a score function, a denoiser (``x0``), a noise predictor (``eps``), or a
    velocity estimator (``v``). In this base class ``sample``, ``score``, ``x0``,
    ``eps`` and ``v`` all raise ``NotImplementedError``. Concrete distributions
    override them, and any method a subclass leaves alone keeps raising when called.

    Instances are frozen dataclasses, so the two fields below cannot be
    reassigned after construction.

    Attributes:
        dist_params (``Dict[str, Array]``): Distribution parameters stored as JAX
            arrays; their shapes are defined by each concrete distribution.
        dist_hparams (``Dict[str, Any]``): Distribution hyperparameters that are
            not arrays.
    """

    dist_params: Dict[str, Array]
    dist_hparams: Dict[str, Any]

    def sample(
        self,
        key: Array,
        num_samples: int,
    ) -> Tuple[Array, Any]:
        """
        Draw ``num_samples`` samples from the distribution.

        Args:
            key (``Array``): JAX PRNG key that drives the sampling.
            num_samples (``int``): How many samples to draw.

        Returns:
            ``Tuple[Array[num_samples, *data_dims], Any]``: The samples, paired with
            whatever auxiliary information the concrete distribution returns.

        Raises:
            NotImplementedError: Always, in this base class; subclasses override it.
        """
        raise NotImplementedError()

    def get_vector_field(
        self, vector_field_type: VectorFieldType
    ) -> Callable[[Array, Array, DiffusionProcess], Array]:
        """
        Look up the vector-field method that matches ``vector_field_type``.

        The result is one of the bound methods ``self.x0``, ``self.eps``,
        ``self.v`` or ``self.score``, so it dispatches to whatever the subclass
        implements.

        Args:
            vector_field_type (``VectorFieldType``): Which field to return; one of
                ``VectorFieldType.X0``, ``VectorFieldType.EPS``,
                ``VectorFieldType.V`` or ``VectorFieldType.SCORE``.

        Returns:
            ``Callable[[Array[*data_dims], Array[], DiffusionProcess], Array[*data_dims]]``:
            A function of ``(x_t, t, diffusion_process)`` that evaluates the
            requested field at state ``x_t`` (``Array[*data_dims]``) and time ``t``
            (``Array[]``), returning an ``Array[*data_dims]``.

        Raises:
            ValueError: If ``vector_field_type`` equals none of the four supported members.
        """
        # Members are tested with ``==`` in a fixed order; the first equal one wins.
        if vector_field_type == VectorFieldType.X0:
            return self.x0
        if vector_field_type == VectorFieldType.EPS:
            return self.eps
        if vector_field_type == VectorFieldType.V:
            return self.v
        if vector_field_type == VectorFieldType.SCORE:
            return self.score
        raise ValueError(
            f"Vector field type {vector_field_type} is not supported."
        )

    def score(
        self,
        x_t: Array,
        t: Array,
        diffusion_process: DiffusionProcess,
    ) -> Array:
        """
        Evaluate the score ``∇_x log p_t(x)`` of the noised distribution at time ``t``.

        Args:
            x_t (``Array[*data_dims]``): Noisy state at time ``t``.
            t (``Array[]``): Scalar time.
            diffusion_process (``DiffusionProcess``): Diffusion process that defines
                how the distribution is noised.

        Returns:
            ``Array[*data_dims]``: The score evaluated at ``(x_t, t)``.

        Raises:
            NotImplementedError: Always, in this base class; subclasses override it.
        """
        raise NotImplementedError()

    def x0(
        self,
        x_t: Array,
        t: Array,
        diffusion_process: DiffusionProcess,
    ) -> Array:
        """
        Estimate the clean sample ``x0`` underlying the noisy state ``x_t`` at time ``t``.

        Args:
            x_t (``Array[*data_dims]``): Noisy state at time ``t``.
            t (``Array[]``): Scalar time.
            diffusion_process (``DiffusionProcess``): Diffusion process that defines
                how the distribution is noised.

        Returns:
            ``Array[*data_dims]``: The denoised estimate of ``x0``.

        Raises:
            NotImplementedError: Always, in this base class; subclasses override it.
        """
        raise NotImplementedError()

    def eps(
        self,
        x_t: Array,
        t: Array,
        diffusion_process: DiffusionProcess,
    ) -> Array:
        """
        Estimate the noise ``ε`` that produced the noisy state ``x_t`` at time ``t``.

        Args:
            x_t (``Array[*data_dims]``): Noisy state at time ``t``.
            t (``Array[]``): Scalar time.
            diffusion_process (``DiffusionProcess``): Diffusion process that defines
                how the distribution is noised.

        Returns:
            ``Array[*data_dims]``: The estimated noise ``ε``.

        Raises:
            NotImplementedError: Always, in this base class; subclasses override it.
        """
        raise NotImplementedError()

    def v(
        self,
        x_t: Array,
        t: Array,
        diffusion_process: DiffusionProcess,
    ) -> Array:
        """
        Evaluate the velocity field ``v(x_t, t)`` at the noisy state ``x_t`` and time ``t``.

        Args:
            x_t (``Array[*data_dims]``): Noisy state at time ``t``.
            t (``Array[]``): Scalar time.
            diffusion_process (``DiffusionProcess``): Diffusion process that defines
                how the distribution is noised.

        Returns:
            ``Array[*data_dims]``: The velocity ``v`` at ``(x_t, t)``.

        Raises:
            NotImplementedError: Always, in this base class; subclasses override it.
        """
        raise NotImplementedError()