Coverage for src/shephex/executor/slurm/slurm_executor.py: 74%
126 statements
« prev ^ index » next coverage.py v7.6.1, created at 2025-06-20 14:13 +0200
« prev ^ index » next coverage.py v7.6.1, created at 2025-06-20 14:13 +0200
1import inspect
2import json
3from pathlib import Path
4from typing import List, Literal, Optional, Union
6from shephex.executor.executor import Executor
7from shephex.executor.slurm import (
8 SlurmBody,
9 SlurmHeader,
10 SlurmProfileManager,
11 SlurmScript,
12)
13from shephex.experiment import FutureResult
14from shephex.experiment.experiment import Experiment
class SlurmSafetyError(Exception):
    """Raised when a SlurmExecutor is created outside the ``__main__`` module."""
class SlurmExecutor(Executor):
    """
    Shephex SLURM executor for executing experiments on a SLURM cluster.

    A sequence of experiments is submitted as one SLURM array job with one
    array task per experiment. The generated submission script and a JSON
    dump of the executor configuration are written to ``directory``.
    """

    def __init__(
        self,
        directory: Optional[Union[str, Path]] = None,
        scratch: bool = False,
        ulimit: Union[int, Literal['default']] = 8000,
        move_output_file: bool = True,
        safety_check: bool = True,
        array_limit: int | None = None,
        **kwargs,
    ) -> None:
        """
        shephex SLURM executor.

        Parameters
        ----------
        directory : Union[str, Path], optional
            Directory where the SLURM script and output files will be stored,
            defaults to ``slurm`` (a path relative to the working directory).
        scratch : bool, optional
            If True, the executor will use the /scratch directory for the
            execution of the experiments. Defaults to False. When true
            files will automatically be copied back to the original directory
            once the job is finished.
        ulimit : Union[int, Literal['default']], optional
            Value passed to ``ulimit -Su`` at the start of every array task.
            Use ``'default'`` to omit the ulimit command. Defaults to 8000.
        move_output_file : bool, optional
            If True, each array task moves its SLURM output file into the
            experiment's directory. Defaults to True.
        safety_check : bool, optional
            If True, raise SlurmSafetyError unless the executor is created
            from the ``__main__`` module. Defaults to True.
        array_limit : int or None, optional
            Maximum number of simultaneously running array tasks (the ``%``
            suffix of ``--array``). None uses the number of experiments.
        **kwargs
            Additional keyword arguments to be passed to the SlurmHeader,
            these are the SLURM parameters for the job. Supports all the
            arguments for sbatch, see https://slurm.schedmd.com/sbatch.html.

        Raises
        ------
        SlurmSafetyError
            If ``safety_check`` is True and the caller is not ``__main__``.
        """
        if safety_check:
            # Frames: [0] safety_check, [1] __init__, [2] the caller.
            self.safety_check(frame_index=2)

        self.header = SlurmHeader()
        for key, value in kwargs.items():
            self.header.add(key, value)

        if directory is None:
            directory = 'slurm'
        self.directory = Path(directory)

        # Containers for commands to be executed before and after the main execution
        self._commands_pre_execution = []
        self._commands_post_execution = []

        # Special options
        self.ulimit = ulimit
        self.move_output_file = move_output_file
        self.scratch = scratch
        self.array_limit = array_limit

        # To keep track of the special options for saving the config
        self.special_options = {
            'scratch': scratch,
            'ulimit': ulimit,
            'move_output_file': move_output_file,
            'array_limit': array_limit,
        }

    @classmethod
    def from_config(
        cls, path: Union[str, Path], safety_check: bool = True, **kwargs
    ) -> 'SlurmExecutor':
        """
        Create a new SlurmExecutor from a JSON config file (as written by ``to_config``).

        Parameters
        ----------
        path : Union[str, Path]
            Path to a ``.json`` config file.
        safety_check : bool, optional
            If True, raise SlurmSafetyError unless called from ``__main__``.
            Defaults to True.
        **kwargs
            Values that override entries of the loaded config.

        Returns
        -------
        SlurmExecutor
            The reconstructed executor, including its pre/post commands.

        Raises
        ------
        SlurmSafetyError
            If ``safety_check`` is True and the caller is not ``__main__``.
        AssertionError
            If the file does not exist or does not have a ``.json`` suffix.
        """
        if safety_check:
            # Frames: [0] safety_check, [1] from_config, [2] the caller.
            # Index 1 would inspect this method's own module, which is never
            # __main__, so the check would always fail.
            cls.safety_check(frame_index=2)

        if not isinstance(path, Path):
            path = Path(path)

        assert path.exists(), f'File {path} does not exist'
        assert path.suffix == '.json', f'File {path} is not a json file'

        with open(path) as f:
            config = json.load(f)
        config.update(kwargs)

        pre_commands = config.pop('commands_pre_execution', list())
        post_commands = config.pop('commands_post_execution', list())

        instance = cls(**config, safety_check=False)
        instance._commands_pre_execution = pre_commands
        instance._commands_post_execution = post_commands
        return instance

    def to_config(self, path: Path | str) -> None:
        """
        Write the executor configuration to a JSON file.

        The file contains the SLURM header entries, the special options
        (scratch, ulimit, move_output_file, array_limit) and the user-added
        pre/post execution commands. ``directory`` is not stored.

        Parameters
        ----------
        path : Path | str
            Destination file; it is overwritten if it exists.
        """
        if not isinstance(path, Path):
            path = Path(path)
        config = self.header.to_dict()
        config.update(self.special_options)

        config['commands_pre_execution'] = self._commands_pre_execution
        config['commands_post_execution'] = self._commands_post_execution

        with open(path, 'w') as f:
            json.dump(config, f, indent=4)

    @classmethod
    def from_profile(cls, name: str, safety_check: bool = True, **kwargs) -> 'SlurmExecutor':
        """
        Create a new SlurmExecutor from a profile.

        Parameters
        ----------
        name : str
            Name of the profile.
        safety_check : bool, optional
            If True, a safety check will be performed to ensure that the executor
            is not instantiated on a script that is not the main script. Defaults to True.
        **kwargs
            Additional keyword arguments to be passed to the SlurmExecutor.
        """
        if safety_check:
            # Frames: [0] safety_check, [1] from_profile, [2] the caller.
            cls.safety_check(frame_index=2)

        spm = SlurmProfileManager()
        profile = spm.get_profile_path(name)
        # The check was already done above; disable it for the nested call.
        return cls.from_config(profile, **kwargs, safety_check=False)

    def _single_execute(self, *args, **kwargs) -> None:
        """Not supported: SLURM execution always goes through ``_sequence_execute``."""
        # Accept any arguments so callers get the intended NotImplementedError
        # instead of a TypeError about the signature.
        raise NotImplementedError('Single execution is not supported for SLURM Executor, everything is executed with _sequence execute.')

    def _sequence_execute(
        self,
        experiments: List[Experiment],
        dry: bool = False,
        execution_directory: Optional[Union[Path, str]] = None,
    ) -> List[FutureResult]:
        """
        Execute a sequence of experiments as an array job.

        The config is dumped to ``config_<n>.json`` in ``self.directory``
        (also in dry mode), then a ``submit_<n>.sh`` script is built.

        Parameters
        ----------
        experiments : List[Experiment]
            List of experiments to be executed.
        dry : bool, optional
            If True, the script will be printed instead of written and submitted.
        execution_directory : Union[Path, str], optional
            NOTE(review): currently unused by this method.

        Returns
        -------
        List[FutureResult]
            One FutureResult per experiment; when submitted, each carries the
            job id returned by ``SlurmScript.submit`` under ``'job_id'``.
        """
        if len(experiments) == 0:
            return []

        # Dump config:
        self.directory.mkdir(parents=True, exist_ok=True)
        index = len(list(self.directory.glob('config*.json')))
        path = self.directory / f'config_{index}.json'
        self.to_config(path)

        header = self.header.copy()
        if self.array_limit is not None:
            array_limit = min(self.array_limit, len(experiments))
        else:
            array_limit = len(experiments)

        header.add('array', f'0-{len(experiments)-1}%{array_limit}')

        body = self._make_slurm_body(experiments)

        count = len(list(self.directory.glob('submit*.sh')))
        script = SlurmScript(header, body, directory=self.directory, name=f'submit_{count}.sh')
        if dry:
            print(script)
            return [FutureResult() for _ in experiments]

        script.write()

        job_id = script.submit()
        for experiment in experiments:
            experiment.update_status('submitted')

        return [FutureResult(info={'job_id': job_id}) for _ in experiments]

    def _bash_array_str(self, strings: List[str]) -> str:
        """
        Convert a list of strings into a nicely formatted bash array of string.

        Elements are inserted verbatim (no shell quoting).

        Parameters
        ----------
        strings : List[str]
            List of strings to be converted.

        Returns
        -------
        str
            A python string representing a bash array of strings.
        """
        bash_str = ' \n\t'.join(strings)
        return f'(\n\t{bash_str}\n)'

    def _body_add(self, command: str, when: Optional[Literal['pre', 'post']] = None) -> None:
        """
        Add a command to the body of the SLURM script.

        Parameters
        ----------
        command : str
            Command to be added to the body.
        when : {'pre', 'post'} or None, optional
            Run the command before (``'pre'``, the default) or after
            (``'post'``) the main ``hex execute`` command.

        Raises
        ------
        ValueError
            If ``when`` is not ``'pre'``, ``'post'`` or None.
        """
        if when is None:
            when = 'pre'

        if when == 'pre':
            self._commands_pre_execution.append(command)
        elif when == 'post':
            self._commands_post_execution.append(command)
        else:
            raise ValueError(f"when must be 'pre', 'post' or None, got {when!r}")

    def _make_slurm_body(self, experiments: List[Experiment]) -> SlurmBody:
        """
        Make a new SlurmBody object.

        Emitted order: the ``directories``/``identifiers`` bash arrays, the
        pre-execution commands (user commands, then the output-file move and
        ulimit options), the ``hex slurm add-info`` call, ``hex execute``,
        the post-execution commands, and finally the scratch copy-back.

        Parameters
        ----------
        experiments : List[Experiment]
            Experiments indexed by ``$SLURM_ARRAY_TASK_ID``.

        Returns
        -------
        SlurmBody
            A new SlurmBody object.
        """
        identifiers = [str(experiment.identifier) for experiment in experiments]
        directories = [str(experiment.directory.resolve()) for experiment in experiments]

        body = SlurmBody()

        body.add(f'directories={self._bash_array_str(directories)}')
        body.add(f'identifiers={self._bash_array_str(identifiers)}')

        # Build the option-derived commands locally. Appending them to
        # self._commands_pre_execution would duplicate them on every call and
        # leak them into subsequent config dumps.
        pre_commands = list(self._commands_pre_execution)
        if self.move_output_file:
            pre_commands.append(r"mv slurm-${SLURM_ARRAY_JOB_ID}_${SLURM_ARRAY_TASK_ID}.out ${directories[$SLURM_ARRAY_TASK_ID]}")
        if self.ulimit != 'default':
            pre_commands.append(f'ulimit -Su {self.ulimit}')

        for command in pre_commands:
            body.add(command)

        # Slurm info command:
        command = r'hex slurm add-info -d ${directories[$SLURM_ARRAY_TASK_ID]} -j "${SLURM_ARRAY_JOB_ID}_${SLURM_ARRAY_TASK_ID}"'
        body.add(command)

        # Execution command
        command = r'hex execute ${directories[$SLURM_ARRAY_TASK_ID]}'

        if self.scratch:
            command += ' -e /scratch/$SLURM_JOB_ID'

        body.add(command)

        for command in self._commands_post_execution:
            body.add(command)

        if self.scratch:
            body.add(
                r'cp -r /scratch/$SLURM_JOB_ID/* ${directories[$SLURM_ARRAY_TASK_ID]}'
            )

        return body

    @staticmethod
    def safety_check(frame_index: int = 2) -> None:
        """
        Check if the executor is being called from the main script.

        Parameters
        ----------
        frame_index : int, optional
            Index into ``inspect.stack()`` of the frame to be checked.
            Defaults to 2.

        Raises
        ------
        SlurmSafetyError
            If the executor is not being called from the main script.

        Frame 0 is this method and frame 1 is the creation method that called
        it, so every creation method (``__init__``, ``from_config``,
        ``from_profile``) uses index 2 to inspect its caller.
        """
        # context=0 avoids reading source lines for every frame on the stack.
        caller_frame = inspect.stack(context=0)[frame_index]
        caller_module = inspect.getmodule(caller_frame[0])

        if caller_module is None or caller_module.__name__ != "__main__":
            raise SlurmSafetyError("""SlurmExecutor should only be called from the main script.
If the you really want, you can disable this check. This error may be caused by not having
a 'if __name__ == "__main__":' block in the main script.""")
311 a 'if __name__ == "__main__":' block in the main script.""")