Coverage for src/shephex/executor/slurm/slurm_executor.py: 74%

126 statements  

« prev     ^ index     » next       coverage.py v7.6.1, created at 2025-06-20 14:13 +0200

1import inspect 

2import json 

3from pathlib import Path 

4from typing import List, Literal, Optional, Union 

5 

6from shephex.executor.executor import Executor 

7from shephex.executor.slurm import ( 

8 SlurmBody, 

9 SlurmHeader, 

10 SlurmProfileManager, 

11 SlurmScript, 

12) 

13from shephex.experiment import FutureResult 

14from shephex.experiment.experiment import Experiment 

15 

16 

class SlurmSafetyError(Exception):
    """Raised by the SLURM safety check when the calling module is not ``__main__``."""

19 

20class SlurmExecutor(Executor): 

21 """ 

22 Shephex SLURM executor for executing experiments on a SLURM cluster. 

23 """ 

def __init__(
    self, directory: Union[str, Path] = None,
    scratch: bool = False,
    ulimit: Union[int, Literal['default']] = 8000,
    move_output_file: bool = True,
    safety_check: bool = True,
    array_limit: int | None = None,
    **kwargs
    ) -> None:
    """
    shephex SLURM executor.

    Parameters
    ----------
    directory : Union[str, Path], optional
        Directory in which the dumped configuration files are stored and
        which is handed to the SLURM script. Defaults to the relative
        path ``slurm``.
    scratch : bool, optional
        If True, ``-e /scratch/$SLURM_JOB_ID`` is appended to the execute
        command and the content of that scratch directory is copied back
        to the experiment directory at the end of the job. Defaults to False.
    ulimit : Union[int, Literal['default']], optional
        Value used for the ``ulimit -Su`` command added to the script.
        Pass ``'default'`` to leave the limit untouched. Defaults to 8000.
    move_output_file : bool, optional
        If True, a command moving the ``slurm-<job>_<task>.out`` file into
        the experiment directory is added to the script. Defaults to True.
    safety_check : bool, optional
        If True, ``SlurmExecutor.safety_check`` is run first; it raises
        ``SlurmSafetyError`` when the instantiating module is not
        ``__main__``. Defaults to True.
    array_limit : int | None, optional
        Upper bound used after the ``%`` of the array specification
        (see the sbatch ``--array`` option). ``None`` means no cap below
        the number of experiments. Defaults to None.
    **kwargs
        Additional keyword arguments to be passed to the SlurmHeader,
        these are the SLURM parameters for the job. Supports all the
        arguments for sbatch, see https://slurm.schedmd.com/sbatch.html.
    """
    if safety_check:
        # Seen from safety_check: [0] safety_check, [1] this __init__,
        # [2] the code that instantiated the executor.
        self.safety_check(frame_index=2)

    # Every extra keyword argument becomes an entry of the SLURM header.
    self.header = SlurmHeader()
    for option, setting in kwargs.items():
        self.header.add(option, setting)

    self.directory = Path('slurm' if directory is None else directory)

    # Containers for commands to be executed before and after the main execution
    self._commands_pre_execution = []
    self._commands_post_execution = []

    # To keep track of the special options for saving the config
    self.special_options = {
        'scratch': scratch,
        'ulimit': ulimit,
        'move_output_file': move_output_file,
        'array_limit': array_limit,
    }

    # The same special options, exposed as plain attributes
    self.scratch = scratch
    self.ulimit = ulimit
    self.move_output_file = move_output_file
    self.array_limit = array_limit

79 

80 

81 @classmethod 

82 def from_config(cls, path: Path, safety_check: bool = True, **kwargs) -> 'SlurmExecutor': 

83 

84 if safety_check: 

85 cls.safety_check(frame_index=1) 

86 

87 if not isinstance(path, Path): 

88 path = Path(path) 

89 

90 assert path.exists(), f'File {path} does not exist' 

91 assert path.suffix == '.json', f'File {path} is not a json file' 

92 

93 with open(path) as f: 

94 config = json.load(f) 

95 config.update(kwargs) 

96 

97 pre_commands = config.pop('commands_pre_execution', list()) 

98 post_commands = config.pop('commands_post_execution', list()) 

99 

100 instance = cls(**config, safety_check=False) 

101 instance._commands_pre_execution = pre_commands 

102 instance._commands_post_execution = post_commands 

103 return instance 

104 

105 def to_config(self, path: Path | str) -> None: 

106 if not isinstance(path, Path): 

107 path = Path(path) 

108 config = self.header.to_dict() 

109 config.update(self.special_options) 

110 

111 config['commands_pre_execution'] = self._commands_pre_execution 

112 config['commands_post_execution'] = self._commands_post_execution 

113 

114 with open(path, 'w') as f: 

115 json.dump(config, f, indent=4) 

116 

@classmethod
def from_profile(cls, name: str, safety_check: bool = True, **kwargs) -> 'SlurmExecutor':
    """
    Build a SlurmExecutor from a stored profile.

    Parameters
    ----------
    name : str
        Name of the profile, resolved to a file through the
        SlurmProfileManager.
    safety_check : bool, optional
        If True, the safety check is run against the code calling this
        method, so that the executor is only created from the main script.
        Defaults to True.
    **kwargs
        Additional keyword arguments forwarded to ``from_config``.

    Returns
    -------
    SlurmExecutor
        Executor created from the profile file.
    """
    if safety_check:
        # Seen from safety_check: [0] safety_check, [1] from_profile,
        # [2] the code that called from_profile.
        cls.safety_check(frame_index=2)

    profile_path = SlurmProfileManager().get_profile_path(name)

    # The check has been handled here, so it is switched off downstream.
    return cls.from_config(profile_path, **kwargs, safety_check=False)

139 

140 def _single_execute(self) -> None: 

141 raise NotImplementedError('Single execution is not supported for SLURM Executor, everything is executed with _sequence execute.') 

142 

143 def _sequence_execute( 

144 self, 

145 experiments: List[Experiment], 

146 dry: bool = False, 

147 execution_directory: Union[Path, str] = None, 

148 ) -> List[FutureResult]: 

149 """ 

150 Execute a sequence of experiments as an array job. 

151 

152 Parameters 

153 ---------- 

154 experiments : List[Experiment] 

155 List of experiments to be executed. 

156 dry : bool, optional 

157 If True, the script will be printed instead of executed. 

158 execution_directory : Union[Path, str], optional 

159 Directory where the experiments will be executed. 

160 

161 Returns 

162 ------- 

163 List[FutureResult] 

164 List of FutureResult objects. 

165 """ 

166 

167 if len(experiments) == 0: 

168 return [] 

169 

170 # Dump config: 

171 self.directory.mkdir(parents=True, exist_ok=True) 

172 index = len(list(self.directory.glob('config*.json'))) 

173 path = self.directory / f'config_{index}.json' 

174 self.to_config(path) 

175 

176 header = self.header.copy() 

177 if self.array_limit is not None: 

178 array_limit = min(self.array_limit, len(experiments)) 

179 else: 

180 array_limit = len(experiments) 

181 

182 header.add('array', f'0-{len(experiments)-1}%{array_limit}') 

183 

184 body = self._make_slurm_body(experiments) 

185 

186 count = len(list(self.directory.glob('submit*.sh'))) 

187 script = SlurmScript(header, body, directory=self.directory, name=f'submit_{count}.sh') 

188 if dry: 

189 print(script) 

190 return [FutureResult() for _ in experiments] 

191 

192 script.write() 

193 

194 job_id = script.submit() 

195 for experiment in experiments: 

196 experiment.update_status('submitted') 

197 

198 return [FutureResult(info={'job_id': job_id}) for _ in experiments] 

199 

200 def _bash_array_str(self, strings: List[str]) -> str: 

201 """ 

202 Convert a list of strings into a nicely formatted bash array of string. 

203 

204 Parameters 

205 ---------- 

206 strings : List[str] 

207 List of strings to be converted. 

208 

209 Returns 

210 ------- 

211 str 

212 A python string representing a bash array of strings. 

213 """ 

214 bash_str = ' \n\t'.join(strings) 

215 return f'(\n\t{bash_str}\n)' 

216 

217 def _body_add(self, command: str, when: Optional[Literal['pre', 'post']] = None) -> None: 

218 """ 

219 Add a command to the body of the SLURM script. 

220 

221 Parameters 

222 ---------- 

223 command : str 

224 Command to be added to the body. 

225 """ 

226 if when is None: 

227 when = 'pre' 

228 

229 if when == 'pre': 

230 self._commands_pre_execution.append(command) 

231 

232 elif when == 'post': 

233 self._commands_post_execution.append(command) 

234 

def _make_slurm_body(self, experiments: List[Experiment]) -> SlurmBody:
    """
    Make a new SlurmBody object for the given experiments.

    The body defines the ``directories`` and ``identifiers`` bash arrays,
    then holds the pre-execution commands, the ``hex slurm add-info`` and
    ``hex execute`` commands, the post-execution commands and, when
    scratch is enabled, the copy back from the scratch directory.

    Parameters
    ----------
    experiments : List[Experiment]
        Experiments whose identifiers and resolved directories are written
        into the bash arrays, indexed by the array task id.

    Returns
    -------
    SlurmBody
        A new SlurmBody object.
    """
    identifiers = [str(experiment.identifier) for experiment in experiments]
    directories = [str(experiment.directory.resolve()) for experiment in experiments]

    body = SlurmBody()
    body.add(f'directories={self._bash_array_str(directories)}')
    body.add(f'identifiers={self._bash_array_str(identifiers)}')

    # Collect the option-driven commands in a local copy. Appending them to
    # self._commands_pre_execution would add them again on every call and
    # leak them into configurations saved afterwards.
    pre_commands = list(self._commands_pre_execution)
    if self.move_output_file:
        pre_commands.append(r"mv slurm-${SLURM_ARRAY_JOB_ID}_${SLURM_ARRAY_TASK_ID}.out ${directories[$SLURM_ARRAY_TASK_ID]}")
    if self.ulimit != 'default':
        pre_commands.append(f'ulimit -Su {self.ulimit}')

    for command in pre_commands:
        body.add(command)

    # Slurm info command:
    body.add(r'hex slurm add-info -d ${directories[$SLURM_ARRAY_TASK_ID]} -j "${SLURM_ARRAY_JOB_ID}_${SLURM_ARRAY_TASK_ID}"')

    # Execution command
    execute = r'hex execute ${directories[$SLURM_ARRAY_TASK_ID]}'
    if self.scratch:
        execute += ' -e /scratch/$SLURM_JOB_ID'
    body.add(execute)

    for command in self._commands_post_execution:
        body.add(command)

    if self.scratch:
        # Bring the results back from scratch to the experiment directory.
        body.add(
            r'cp -r /scratch/$SLURM_JOB_ID/* ${directories[$SLURM_ARRAY_TASK_ID]}'
        )

    return body

282 

@staticmethod
def safety_check(frame_index: int = 2) -> None:
    """
    Check if the executor is being called from the main script.

    Parameters
    ----------
    frame_index : int, optional
        Index into the call stack of the frame to be checked. Index 0 is
        this function, 1 is the function that called it and 2 is the
        caller of that function. A constructor that calls this check
        directly passes 2 to inspect the code that invoked the
        constructor. Defaults to 2.

    Raises
    ------
    SlurmSafetyError
        If the module of the inspected frame cannot be determined or is
        not ``__main__``.
    IndexError
        If the call stack has no frame at ``frame_index``.
    """
    # context=0: only the frame objects are needed, not source lines.
    caller_frames = inspect.stack(context=0)
    try:
        caller_module = inspect.getmodule(caller_frames[frame_index][0])
    finally:
        # Drop the frame references promptly so they do not keep
        # reference cycles alive.
        del caller_frames

    if caller_module is None or caller_module.__name__ != "__main__":
        raise SlurmSafetyError(
            "SlurmExecutor should only be called from the main script. "
            "If you really want to, you can disable this check by passing "
            "safety_check=False. This error may be caused by not having "
            "a 'if __name__ == \"__main__\":' block in the main script."
        )

312 

313