import inspect
import os
import pickle as pkl
from typing import IO, Literal, Optional, Tuple, Union

import cloudpickle as cpkl
import numpy as np

from ._meta_class import SaveAfterInitMetaClass
9
class SaveableObject(metaclass=SaveAfterInitMetaClass):
    """A utility class for saving objects to pickles and checkpointing.

    Instances are pickled with ``cloudpickle``. A pickle file written by this
    class holds the instance first, optionally followed by a second pickled
    dictionary holding the arguments the instance was created with (see
    :meth:`loadifparams` and :meth:`update_save`).

    Notes
    -----
    The class is created with ``SaveAfterInitMetaClass``, which is defined in
    another module. The save-after-initialisation behaviour described in
    :meth:`__init__` is presumably implemented there.
    """

    def __init__(self, path: Optional[str] = None):
        """Initialises an instance of the :class:`SaveableObject`.
        If a path is specified the saveable object is automatically saved after
        initialisation.

        Parameters
        ----------
        path : str, optional
            The file :attr:`path` to save the object to. If ``None`` then the
            object is not saved. By default ``None``.

        Notes
        -----
        If no file extension is provided for `path` then the class name and the
        ``.pkl`` extension are appended to the file name.
        """
        self.path = path

    @property
    def path(self) -> Optional[str]:
        """The current file path of the object.

        Notes
        -----
        On setting the value, if no file extension is provided then the class
        name and the ``.pkl`` extension are appended to the file name.
        """
        return self._path

    @path.setter
    def path(self, value: Optional[str]):
        # Normalise on assignment so that ``self._path`` always carries a file
        # extension (or is ``None``).
        self._path = self._updatepathroot(value)

    @classmethod
    def _get_name(cls, path: Optional[str]) -> Optional[str]:
        """Returns the file name of the specified path (without the file
        extension).

        Parameters
        ----------
        path : str, optional
            The path to obtain the file name for.

        Returns
        -------
        str, optional
            The file name (without the file extension), or ``None`` if `path`
            is ``None``.
        """
        if path is None:
            return None
        # Normalise first so that the class-name suffix is part of the name.
        root, _ = os.path.splitext(cls._updatepathroot(path))
        return os.path.basename(root)

    @property
    def name(self) -> Optional[str]:
        """The file name of the object (without the file extension). Note that
        `name` is read only.
        """
        return self._get_name(self._path)

    @staticmethod
    def _dump(obj: object, path: str, write_mode: str = "wb") -> None:
        """Pickles `obj` to `path`, creating the parent directory if needed.

        Parameters
        ----------
        obj : object
            The object to pickle.
        path : str
            The path of the file to write to.
        write_mode : str, optional
            The mode passed to ``open``. By default ``"wb"``.
        """
        dirname = os.path.dirname(path)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        with open(path, write_mode) as file:
            cpkl.dump(obj, file, pkl.HIGHEST_PROTOCOL)

    def _save(self,
              path: str,
              write_mode: Literal["w", "wb", "a", "ab", "x", "xb"] = "wb"):
        """Saves the object to `path` using `write_mode`.

        Parameters
        ----------
        path : str
            The path to save the object to.
        write_mode : Literal["w", "wb", "a", "ab", "x", "xb"], optional
            The mode with which to open the file to write to. These are the same
            as `mode` for ``open``. By default ``"wb"``.

        Notes
        -----
        Pickle data is binary, so the text modes accepted by the annotation
        are expected to fail when the data is written.
        """
        SaveableObject._dump(self, path, write_mode)

    def _getpath(self, path: Optional[str]) -> str:
        """Returns the specified or saved :attr:`path`.

        Parameters
        ----------
        path : str, optional
            Specified path. If ``None`` the attribute :attr:`path` is used.

        Returns
        -------
        str
            The specified or saved :attr:`path`, with the default file name
            suffix applied.

        Raises
        ------
        ValueError
            No save path provided. Raised if no path is saved or specified.
        """
        # ``hasattr`` is False when ``_path`` has not been assigned yet, because
        # the ``path`` property then raises ``AttributeError``.
        if path is None and hasattr(self, 'path'):
            path = self.path
        if path is None:
            raise ValueError("No save path provided.")
        return self._updatepathroot(path)

    @classmethod
    def _updatepathroot(cls, path: Optional[str]) -> Optional[str]:
        """If no file extension is provided then the class name and the ``.pkl``
        extension are appended to the file name.

        Parameters
        ----------
        path : str, optional
            The file path.

        Returns
        -------
        str, optional
            The modified file path, or ``None`` if `path` is ``None``.
        """
        if path is None:
            return None
        root, extension = os.path.splitext(path)
        if extension:
            return path
        file_name = os.path.basename(root)
        # Separate the file name from the class name with one underscore,
        # unless the file name is empty or already ends with an underscore.
        separator = "_" if file_name and not file_name.endswith("_") else ""
        return path + separator + cls.__name__ + ".pkl"

    def save(self, path: Optional[str] = None):
        """Pickles the current instance.

        Parameters
        ----------
        path : str, optional
            The path to pickle the instance to. If ``None`` is specified
            then the attribute :attr:`path` is used instead.
            By default ``None``.

        Raises
        ------
        ValueError
            Raised if no path specified either by the parameter `path` or the
            attribute :attr:`path`.

        Notes
        -----
        If no file extension is provided then the class name and the ``.pkl``
        extension are appended to the file name. The attribute :attr:`path` is
        updated to the path that was saved to.
        """
        self.path = self._getpath(path)
        self._save(self.path)

    def update_save(self, path: Optional[str] = None) -> bool:
        """Pickles the current instance and retains the saved arguments if
        they exist.

        Parameters
        ----------
        path : str, optional
            The path to pickle the instance to. If ``None`` is specified
            then the attribute :attr:`path` is used instead. By default
            ``None``.

        Returns
        -------
        bool
            ``True`` if there was an argument pickle to retain. ``False`` if
            there was not an argument pickle to retain.

        Raises
        ------
        ValueError
            Raised if no path specified.
        FileNotFoundError
            Raised if there is no existing file at the path.

        Notes
        -----
        If no file extension is provided then the class name and the ``.pkl``
        extension are appended to the file name.
        """
        self.path = self._getpath(path)
        # The context manager closes the file before it is rewritten below,
        # including when reading the retained arguments fails unexpectedly.
        with open(self.path, "rb") as file:
            # Read past the prior save of the instance. It is discarded, and a
            # failure to load it must not prevent the new save.
            try:
                type(self)._load(file)
            except Exception:
                pass
            # Retain the arguments stored after the instance, if any.
            try:
                params = pkl.load(file)
            except EOFError:
                params = None
                retained = False
            else:
                retained = True
        self._save(self.path)
        if retained:
            SaveableObject._dump(params, self.path, "ab")
        return retained

    @classmethod
    def _load(cls,
              file: IO,
              new_path: Optional[str] = None,
              strict_typing: bool = True
              ) -> "SaveableObject":
        """Loads an instance from the `file`.

        Parameters
        ----------
        file : IO
            The file to load the instance from.
        new_path : str, optional
            The path to replace the previous path with. If ``None`` the `path`
            is not replaced. By default ``None``.
        strict_typing : bool, optional
            If ``True`` then the loaded instance must be an instance of `cls`.
            By default ``True``.

        Returns
        -------
        `cls`
            The loaded instance.

        Raises
        ------
        TypeError
            If `strict_typing` and the loaded instance is not an instance of
            `cls`.

        Notes
        -----
        ``strict_typing=True`` acts as a safety guard. Setting
        ``strict_typing=False`` may increase the probability of unexpected or
        uncaught errors.
        """
        # SECURITY: unpickling can execute arbitrary code. Only load files
        # from trusted sources. The type check below runs after unpickling.
        instance = pkl.load(file)
        if strict_typing and not isinstance(instance, cls):
            raise TypeError(f"The loaded instance is not an instance of {cls}.")
        if new_path is not None:
            instance.path = new_path
        return instance

    @classmethod
    def load(cls,
             path: str,
             new_path: Optional[str] = None,
             strict_typing: bool = True
             ) -> "SaveableObject":
        """Loads a pickled instance.

        Parameters
        ----------
        path : str
            The path of the pickle.
        new_path : str, optional
            The path to replace the previous path with. If ``None`` the `path`
            is not replaced. By default ``None``.
        strict_typing : bool, optional
            If ``True`` then the loaded instance must be an instance of `cls`.
            By default ``True``.

        Returns
        -------
        SaveableObject
            The loaded instance.

        Raises
        ------
        TypeError
            If `strict_typing` and the loaded instance is not an instance of
            `cls`.

        Notes
        -----
        ``strict_typing=True`` acts as a safety guard. Setting
        ``strict_typing=False`` may increase the probability of unexpected or
        uncaught errors.

        If no file extension is provided for `path` then the class name and the
        ``.pkl`` extension are appended to the file name.
        """
        path = cls._updatepathroot(path)
        with open(path, "rb") as file:
            return cls._load(file, new_path, strict_typing)

    @classmethod
    def tryload(cls,
                path: Optional[str],
                new_path: Optional[str] = None,
                strict_typing: bool = True
                ) -> Union["SaveableObject", Literal[False]]:
        """Attempts to :meth:`load` from the specified `path`. If the loading
        fails then ``False`` is returned.

        Parameters
        ----------
        path : str, optional
            The path of the pickle. If ``None`` then ``False`` is returned.
        new_path : str, optional
            The path to replace the previous path with. If ``None`` the `path`
            is not replaced. By default ``None``.
        strict_typing : bool, optional
            If ``True`` then the loaded instance must be an instance of `cls`.
            By default ``True``.

        Returns
        -------
        SaveableObject | Literal[False]
            If succeeded the loaded instance, else False.

        Notes
        -----
        ``strict_typing=True`` acts as a safety guard. Setting
        ``strict_typing=False`` may increase the probability of unexpected or
        uncaught errors.

        A missing file, an empty or truncated file (``EOFError``) and a
        ``TypeError`` (for example from the `strict_typing` check) count as a
        failed load. Other errors are propagated.
        """
        if path is None:
            return False
        try:
            return cls.load(path, new_path, strict_typing)
        except (FileNotFoundError, EOFError, TypeError):
            return False

    @classmethod
    def _bind_init_arguments(cls, args: tuple, kwargs: dict) -> dict:
        """Binds `args` and `kwargs` to the signature of ``cls.__init__``.

        Parameters
        ----------
        args : tuple
            The positional arguments intended for the initialisation.
        kwargs : dict
            The keyword arguments intended for the initialisation.

        Returns
        -------
        dict
            The mapping of parameter names to the bound values. The entry for
            ``self`` holds ``Ellipsis``. Defaults are not applied, so parameters
            that were not passed are absent.

        Raises
        ------
        TypeError
            If the arguments do not match the signature of ``cls.__init__``.
        """
        # ``...`` stands in for ``self``, as no instance exists yet.
        signature = inspect.signature(cls.__init__)
        return signature.bind(..., *args, **kwargs).arguments

    @staticmethod
    def _params_match(saved, current: dict) -> bool:
        """Checks every entry of `current` against the entry saved under the
        same key.

        Parameters
        ----------
        saved
            The previously saved parameters (expected to be a mapping).
        current : dict
            The parameters of the current call.

        Returns
        -------
        bool
            ``True`` if every value in `current` equals the saved value.

        Raises
        ------
        KeyError
            If a key of `current` is missing from `saved`.
        """
        # NOTE(review): keys present in `saved` but absent from `current` are
        # not compared, so an argument that was saved but is now omitted does
        # not count as a mismatch.
        for key, value in current.items():
            differs = saved[key] != value
            if isinstance(differs, bool):
                if differs:
                    return False
            # A non-bool result (e.g. an element-wise array comparison) is
            # resolved with a whole-array equality check.
            elif not np.array_equal(saved[key], value):
                return False
        return True

    @classmethod
    def loadif(cls, *args, **kwargs) -> Tuple["SaveableObject", bool]:
        """Attempts to load from a specified `path`. If the loading fails or no
        `path` is specified then a new instance of the object is generated with
        the specified `*args` and `**kwargs`.

        Parameters
        ----------
        *args
            The arguments to pass to the initialisation on a failed
            :meth:`load`.
        path : str, optional
            The path of the pickle, by default the parameter is not specified.
        **kwargs
            The keyword arguments to pass to the initialisation on a failed
            :meth:`load`.

        Returns
        -------
        (SaveableObject, bool)
            The loaded or initialised instance followed by ``True`` if the
            instance was loaded and ``False`` if the instance was initialised.
        """
        path = cls._bind_init_arguments(args, kwargs).get("path")
        instance = cls.tryload(path)
        # Compare against the failure sentinel explicitly: a successfully
        # loaded instance may itself be falsy (e.g. it defines ``__len__``).
        if instance is not False:
            return instance, True
        return cls(*args, **kwargs), False

    @classmethod
    def loadifparams(cls,
                     *args,
                     dependencies: Optional[dict] = None,
                     **kwargs
                     ) -> Tuple["SaveableObject", bool]:
        """Attempts to :meth:`load` from a specified `path`. If the loading
        fails or no `path` is specified or the parameters do not match the saved
        parameters then a new instance of the object is generated with the
        specified `*args` and `**kwargs`.

        Parameters
        ----------
        *args
            The arguments to pass to the initialisation on a failed
            :meth:`load`.
        path : str, optional
            The path of the pickle, by default the parameter is not specified.
        dependencies : dict, optional, must be specified as a keyword argument
            A dictionary of additional dependencies to check. By default no
            additional dependencies are checked.
        **kwargs
            The keyword arguments to pass to the initialisation on a failed
            :meth:`load`.

        Returns
        -------
        (SaveableObject, bool)
            The loaded or initialised instance followed by ``True`` if the
            instance was loaded and ``False`` if the instance was initialised.

        Raises
        ------
        TypeError
            If a key of `dependencies` is also the name of a bound argument.

        Notes
        -----
        When a new instance is initialised and a `path` is specified, the
        arguments and `dependencies` are appended to the file at `path` as a
        second pickle.
        """
        if dependencies is None:
            dependencies = {}
        bound_arguments = cls._bind_init_arguments(args, kwargs)
        path = cls._updatepathroot(bound_arguments.get("path"))
        duplicates = [key for key in dependencies if key in bound_arguments]
        if duplicates:
            raise TypeError(f"The dependencies {duplicates} are also arguments. They must have different names.")
        arguments = {**bound_arguments, **dependencies}
        if path is not None:
            try:
                with open(path, "rb") as file:
                    instance = cls._load(file)
                    params = pkl.load(file)
                if cls._params_match(params, arguments):
                    return instance, True
            except (FileNotFoundError, EOFError, ValueError, TypeError, KeyError):
                # Any of these means there is no usable, matching save, so
                # fall through to initialising a new instance.
                pass
        instance = cls(*args, **kwargs)
        if path is not None:
            # Append the arguments after the instance's own pickle.
            SaveableObject._dump(arguments, path, "ab")
        return instance, False