Coverage for src/shephex/study/study.py: 87%

78 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-06-20 14:13 +0200

1from pathlib import Path 

2from typing import List, Optional, Self, Tuple, Union 

3 

4import dill 

5from rich.progress import track 

6 

7from shephex.experiment.experiment import Experiment 

8from shephex.experiment.options import Options 

9from shephex.study.table import LittleTable as Table 

10 

11 

12class Study: 

13 """ 

14 A study is a series of related experiments. The role of the study object is 

15 to manage experiments and provide a common interface for interacting with 

16 them. 

17 """ 

18 

19 def __init__(self, path: Union[Path, str], refresh: bool = True, avoid_duplicates: bool = True) -> None: 

20 self.path = Path(path) 

21 self.avoid_duplicates = avoid_duplicates 

22 if refresh: 

23 self.refresh(clear_table=True) 

24 

25 def add_experiment( 

26 self, experiment: Experiment, verbose: bool = True, check_contain: bool = True 

27 ) -> bool: 

28 """ 

29 Add an experiment to the study. 

30 """ 

31 if check_contain: 

32 contained = self.contains_experiment(experiment) 

33 elif not check_contain: 

34 contained = False 

35 

36 if not contained: 

37 self.table.add_row(experiment.to_dict(), add_columns=True) 

38 if not experiment.directory.exists(): 

39 experiment.dump() 

40 

41 def update_experiment(self, experiment: Experiment) -> None: 

42 """ 

43 Update the experiment in the study. 

44 """ 

45 self.table.update_row(experiment.to_dict()) 

46 

47 def contains_experiment(self, experiment: Experiment) -> bool: 

48 """ 

49 Check if the experiment is already in the study. 

50 

51 Returns 

52 -------- 

53 contains: bool 

54 True if the experiment is in the study, False otherwise. 

55 """ 

56 if not self.avoid_duplicates: 

57 return False 

58 

59 contains = self.table.contains_row(experiment.to_dict()) 

60 return contains 

61 

62 def discover_experiments(self) -> List[Path]: 

63 return self.path.glob(f'*-{Experiment.extension}') 

64 

65 def refresh(self, clear_table: bool = False, progress_bar: bool = False) -> None: 

66 # Check if the study directory exists 

67 if not self.path.exists(): 

68 self.path.mkdir(parents=True) # pragma: no cover 

69 

70 if clear_table: 

71 self.table = Table() 

72 

73 # Get the list of experiments 

74 experiments_paths = list(self.discover_experiments()) 

75 

76 # Add the experiments to the table 

77 for experiment_path in track(experiments_paths, description="Loading experiments", disable=not progress_bar, show_speed=True): 

78 experiment = Experiment.load(experiment_path, load_procedure=False) 

79 if not self.contains_experiment(experiment): 

80 self.add_experiment(experiment, verbose=True, check_contain=True) 

81 else: 

82 self.update_experiment(experiment) 

83 

84 def report(self) -> None: 

85 from shephex.study.renderer import StudyRenderer 

86 StudyRenderer().render_study(self) 

87 

88 def get_experiments( 

89 self, 

90 status: str = None, 

91 load_procedure: bool = True, 

92 loaded_experiments: Optional[List[Experiment]] = None, 

93 ) -> List[Experiment]: 

94 """ 

95 Get the experiments in the study. 

96 

97 Todo: This is probably quite inefficient, so should be improved. 

98 """ 

99 experiments = [] 

100 

101 if status != 'all' or status is None: 

102 identifiers = self.table.where(status=status) 

103 else: 

104 identifiers = [row.identifier for row in self.table.table] 

105 

106 if loaded_experiments is not None: 

107 loaded_ids = [experiment.identifier for experiment in loaded_experiments] 

108 else: 

109 loaded_ids = [] 

110 

111 for identifier in identifiers: 

112 if identifier in loaded_ids: 

113 experiment = loaded_experiments[loaded_ids.index(identifier)] 

114 else: 

115 path = self.path / f'{identifier}-{Experiment.extension}' 

116 experiment = Experiment.load(path, load_procedure=load_procedure) 

117 

118 experiments.append(experiment) 

119 return experiments 

120 

121 def where(self, load_shephex_options: bool = True, *args, **kwargs) -> Union[List[str], Tuple[List[str], List[Options]]]: 

122 identifiers = self.table.where(*args, **kwargs) 

123 if load_shephex_options: 

124 paths = [Path(self.path) / f"{id_}-{Experiment.extension}" for id_ in identifiers] 

125 options = [Options.load(path / "shephex") for path in paths] 

126 return identifiers, options 

127 

128 return identifiers 

129 

130 def dump(self, path: str | Path | None = None) -> None: 

131 if path is None: 

132 path = self.path / "study.pckl" 

133 

134 with open(path, "wb") as f: 

135 dill.dump(self, f) 

136 

137 @classmethod 

138 def load(cls, path: str | Path | None = None) -> Self: 

139 if path is None: 

140 path = Path.cwd() / "study.pckl" 

141 

142 with open(path, "rb") as f: 

143 study = dill.load(f) 

144 return study 

145 

146