Coverage for src/shephex/study/study.py: 87%
78 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-06-20 14:13 +0200
« prev ^ index » next coverage.py v7.6.1, created at 2025-06-20 14:13 +0200
1from pathlib import Path
2from typing import List, Optional, Self, Tuple, Union
4import dill
5from rich.progress import track
7from shephex.experiment.experiment import Experiment
8from shephex.experiment.options import Options
9from shephex.study.table import LittleTable as Table
12class Study:
13 """
14 A study is a series of related experiments. The role of the study object is
15 to manage experiments and provide a common interface for interacting with
16 them.
17 """
19 def __init__(self, path: Union[Path, str], refresh: bool = True, avoid_duplicates: bool = True) -> None:
20 self.path = Path(path)
21 self.avoid_duplicates = avoid_duplicates
22 if refresh:
23 self.refresh(clear_table=True)
25 def add_experiment(
26 self, experiment: Experiment, verbose: bool = True, check_contain: bool = True
27 ) -> bool:
28 """
29 Add an experiment to the study.
30 """
31 if check_contain:
32 contained = self.contains_experiment(experiment)
33 elif not check_contain:
34 contained = False
36 if not contained:
37 self.table.add_row(experiment.to_dict(), add_columns=True)
38 if not experiment.directory.exists():
39 experiment.dump()
41 def update_experiment(self, experiment: Experiment) -> None:
42 """
43 Update the experiment in the study.
44 """
45 self.table.update_row(experiment.to_dict())
47 def contains_experiment(self, experiment: Experiment) -> bool:
48 """
49 Check if the experiment is already in the study.
51 Returns
52 --------
53 contains: bool
54 True if the experiment is in the study, False otherwise.
55 """
56 if not self.avoid_duplicates:
57 return False
59 contains = self.table.contains_row(experiment.to_dict())
60 return contains
62 def discover_experiments(self) -> List[Path]:
63 return self.path.glob(f'*-{Experiment.extension}')
65 def refresh(self, clear_table: bool = False, progress_bar: bool = False) -> None:
66 # Check if the study directory exists
67 if not self.path.exists():
68 self.path.mkdir(parents=True) # pragma: no cover
70 if clear_table:
71 self.table = Table()
73 # Get the list of experiments
74 experiments_paths = list(self.discover_experiments())
76 # Add the experiments to the table
77 for experiment_path in track(experiments_paths, description="Loading experiments", disable=not progress_bar, show_speed=True):
78 experiment = Experiment.load(experiment_path, load_procedure=False)
79 if not self.contains_experiment(experiment):
80 self.add_experiment(experiment, verbose=True, check_contain=True)
81 else:
82 self.update_experiment(experiment)
84 def report(self) -> None:
85 from shephex.study.renderer import StudyRenderer
86 StudyRenderer().render_study(self)
88 def get_experiments(
89 self,
90 status: str = None,
91 load_procedure: bool = True,
92 loaded_experiments: Optional[List[Experiment]] = None,
93 ) -> List[Experiment]:
94 """
95 Get the experiments in the study.
97 Todo: This is probably quite inefficient, so should be improved.
98 """
99 experiments = []
101 if status != 'all' or status is None:
102 identifiers = self.table.where(status=status)
103 else:
104 identifiers = [row.identifier for row in self.table.table]
106 if loaded_experiments is not None:
107 loaded_ids = [experiment.identifier for experiment in loaded_experiments]
108 else:
109 loaded_ids = []
111 for identifier in identifiers:
112 if identifier in loaded_ids:
113 experiment = loaded_experiments[loaded_ids.index(identifier)]
114 else:
115 path = self.path / f'{identifier}-{Experiment.extension}'
116 experiment = Experiment.load(path, load_procedure=load_procedure)
118 experiments.append(experiment)
119 return experiments
121 def where(self, load_shephex_options: bool = True, *args, **kwargs) -> Union[List[str], Tuple[List[str], List[Options]]]:
122 identifiers = self.table.where(*args, **kwargs)
123 if load_shephex_options:
124 paths = [Path(self.path) / f"{id_}-{Experiment.extension}" for id_ in identifiers]
125 options = [Options.load(path / "shephex") for path in paths]
126 return identifiers, options
128 return identifiers
130 def dump(self, path: str | Path | None = None) -> None:
131 if path is None:
132 path = self.path / "study.pckl"
134 with open(path, "wb") as f:
135 dill.dump(self, f)
137 @classmethod
138 def load(cls, path: str | Path | None = None) -> Self:
139 if path is None:
140 path = Path.cwd() / "study.pckl"
142 with open(path, "rb") as f:
143 study = dill.load(f)
144 return study