1import os
2import inspect
3import numpy as np
4import pickle as pkl
5import cloudpickle as cpkl
6from typing import Optional, Literal, IO, Tuple
7
8from ._meta_class import SaveAfterInitMetaClass
9
[docs]
class SaveableObject(metaclass=SaveAfterInitMetaClass):
    """A utility class for saving objects to pickles and checkpointing.

    A save file contains the pickled instance, optionally followed by a second
    pickle holding the ``dict`` of arguments the instance was created with.
    That second pickle is appended by :meth:`loadifparams` and is preserved by
    :meth:`update_save`.

    Notes
    -----
    Loading relies on :mod:`pickle`, which can execute arbitrary code while
    unpickling. Only load files from trusted sources.
    """

    def __init__(self, path: Optional[str] = None):
        """Initialises an instance of the :class:`SaveableObject`.
        If a path is specified the saveable object is automatically saved after
        initialisation.

        Parameters
        ----------
        path : str, optional
            The file :attr:`path` to save the object to. If ``None`` then the
            object is not saved. By default ``None``.

        Notes
        -----
        If no file extension is provided for `path` then the class name and the
        ``.pkl`` extension are appended to the file name.
        """
        # No save happens in this method; presumably the automatic save is
        # performed by ``SaveAfterInitMetaClass`` once ``__init__`` returns.
        self.path = path

    @property
    def path(self) -> Optional[str]:
        """The current file path of the object.

        Notes
        -----
        On setting the value, if no file extension is provided then the class
        name and the ``.pkl`` extension are appended to the file name.
        """
        return self._path

    @path.setter
    def path(self, value: Optional[str]):
        """Set the current file path of the object.

        Parameters
        ----------
        value : str, optional
            The new file path of the object. By default ``None``.

        Notes
        -----
        If no file extension is provided then the class name and the ``.pkl``
        extension are appended to the file name.
        """
        self._path = self._updatepathroot(value)

    @classmethod
    def _get_name(cls, path: Optional[str]) -> Optional[str]:
        """Returns the file name of the specified path (without the file
        extension).

        Parameters
        ----------
        path : str, optional
            The path to obtain the file name for. It is first normalised with
            :meth:`_updatepathroot`.

        Returns
        -------
        str, optional
            The file name (without the file extension), or ``None`` if `path`
            is ``None``.
        """
        if path is None:
            return None
        root, _ = os.path.splitext(cls._updatepathroot(path))
        return os.path.basename(root)

    @property
    def name(self) -> Optional[str]:
        """The file name of the object (without the file extension). Note that
        `name` is read only.
        """
        return self._get_name(self._path)

    def _save(self, path: str, write_mode: Literal["w", "wb", "a", "ab", "x", "xb"] = "wb"):
        """Saves the object to `path` using `write_mode`.

        Parameters
        ----------
        path : str
            The path to save the object to. Missing parent directories are
            created.
        write_mode : Literal["w", "wb", "a", "ab", "x", "xb"], optional
            The mode with which to open the file to write to. These are the same
            as `mode` for ``open``. By default ``"wb"``.

        Notes
        -----
        Pickles are bytes, so only the binary modes (``"wb"``, ``"ab"``,
        ``"xb"``) can be used. Writing a pickle to a file opened with one of
        the text modes fails.
        """
        self._dump(self, path, write_mode)

    @staticmethod
    def _dump(obj, path: str, write_mode: Literal["w", "wb", "a", "ab", "x", "xb"] = "wb"):
        """Pickles an arbitrary `obj` to `path` using `write_mode`.

        Used by :meth:`_save` for the instance itself, and by
        :meth:`update_save` / :meth:`loadifparams` to append the argument
        ``dict`` after an instance.

        Parameters
        ----------
        obj : object
            The object to pickle (serialised with ``cloudpickle``).
        path : str
            The destination file. Missing parent directories are created.
        write_mode : Literal["w", "wb", "a", "ab", "x", "xb"], optional
            The mode passed to ``open``. It must be a binary mode.
            By default ``"wb"``.
        """
        dirname = os.path.dirname(path)
        if dirname:
            # exist_ok makes this a no-op when the directory already exists.
            os.makedirs(dirname, exist_ok=True)
        with open(path, write_mode) as file:
            cpkl.dump(obj, file, pkl.HIGHEST_PROTOCOL)

    def _getpath(self, path: Optional[str]) -> str:
        """Returns the specified or saved :attr:`path`.

        Parameters
        ----------
        path : str, optional
            Specified path. If ``None`` the saved :attr:`path` is used, when
            one has been set.

        Returns
        -------
        str
            The chosen path, normalised with :meth:`_updatepathroot`.

        Raises
        ------
        ValueError
            No save path provided. Raised if no path is saved or specified.
        """
        if path is None:
            # getattr's default also covers instances whose ``_path`` was
            # never assigned (the property then raises AttributeError).
            path = getattr(self, "path", None)
        if path is None:
            raise ValueError("No save path provided.")
        return self._updatepathroot(path)

    @classmethod
    def _updatepathroot(cls, path: Optional[str]) -> Optional[str]:
        """If no file extension is provided then the class name and the ``.pkl``
        extension are appended to the file name.

        Parameters
        ----------
        path : str, optional
            The file path.

        Returns
        -------
        str, optional
            The modified file path (``None`` if `path` is ``None``). For
            example ``"dir/run"`` becomes ``"dir/run_<ClassName>.pkl"`` and
            ``"dir/"`` becomes ``"dir/<ClassName>.pkl"``.
        """
        if path is None:
            return None
        root, extension = os.path.splitext(path)
        if extension:
            return path
        file_name = os.path.basename(root)
        # Join a non-empty stem to the class name with exactly one "_".
        separator = "_" if file_name and not file_name.endswith("_") else ""
        return path + separator + cls.__name__ + ".pkl"

    def save(self, path: Optional[str] = None):
        """Pickles the current instance.

        Parameters
        ----------
        path : str, optional
            The path to pickle the instance to. If ``None`` is specified
            then the attribute :attr:`path` is used instead.
            By default ``None``. The resolved path is stored in :attr:`path`.

        Raises
        ------
        ValueError
            Raised if no path specified either by the parameter `path` or the
            attribute :attr:`path`.

        Notes
        -----
        If no file extension is provided then the class name and the ``.pkl``
        extension are appended to the file name.
        """
        self.path = self._getpath(path)
        self._save(self.path)

    def update_save(self, path: Optional[str] = None) -> bool:
        """Pickles the current instance and retains the saved arguments if
        they exist.

        Parameters
        ----------
        path : str, optional
            The path to pickle the instance to. If ``None`` is specified
            then the attribute :attr:`path` is used instead. By default
            ``None``. The resolved path is stored in :attr:`path`.

        Returns
        -------
        bool
            ``True`` if there was an argument pickle to retain. ``False`` if
            there was not an argument pickle to retain (including when no file
            exists at the path yet).

        Raises
        ------
        ValueError
            Raised if no path specified.

        Notes
        -----
        If no file extension is provided then the class name and the ``.pkl``
        extension are appended to the file name.
        """
        self.path = self._getpath(path)
        params = None
        has_params = False
        try:
            with open(self.path, "rb") as file:
                # Skip over the previously saved instance; only the position
                # in the stream matters, so a failed load is ignored.
                try:
                    type(self)._load(file)
                except Exception:
                    pass
                # The argument pickle, if any, follows the instance.
                try:
                    params = pkl.load(file)
                    has_params = True
                except EOFError:
                    pass
        except FileNotFoundError:
            # Nothing saved yet, so there are no arguments to retain.
            pass
        # The read handle is closed before the file is rewritten.
        self._save(self.path)
        if has_params:
            SaveableObject._dump(params, self.path, write_mode="ab")
        return has_params

    @classmethod
    def _load(cls, file: IO, new_path: Optional[str] = None, strict_typing: bool = True) -> "SaveableObject":
        """Loads an instance from the `file`.

        Parameters
        ----------
        file : IO
            The binary file to load the instance from. Exactly one pickle is
            read, so the file position ends just after the instance.
        new_path : str, optional
            The path to replace the previous path with. If ``None`` the `path`
            is not replaced. By default ``None``.
        strict_typing : bool, optional
            If ``True`` then the loaded instance must be an instance of `cls`.
            By default ``True``.

        Returns
        -------
        `cls`
            The loaded instance.

        Raises
        ------
        TypeError
            If `strict_typing` and the loaded instance is not an instance of
            `cls`.

        Notes
        -----
        ``strict_typing=True`` acts as a safety guard. Setting
        ``strict_typing=False`` may increase the probability of unexpected or
        uncaught errors. Unpickling can execute arbitrary code, so only load
        trusted files.
        """
        instance = pkl.load(file)
        if strict_typing and not isinstance(instance, cls):
            raise TypeError(f"The loaded instance is not an instance of {cls}.")
        if new_path is not None:
            instance.path = new_path
        return instance

    @classmethod
    def load(cls, path: str, new_path: Optional[str] = None, strict_typing: bool = True) -> "SaveableObject":
        """Loads a pickled instance.

        Parameters
        ----------
        path : str
            The path of the pickle. If no file extension is provided then the
            class name and the ``.pkl`` extension are appended to the file
            name.
        new_path : str, optional
            The path to replace the previous path with. If ``None`` the `path`
            is not replaced. By default ``None``.
        strict_typing : bool, optional
            If ``True`` then the loaded instance must be an instance of `cls`.
            By default ``True``.

        Returns
        -------
        SaveableObject
            The loaded instance.

        Raises
        ------
        TypeError
            If `strict_typing` and the loaded instance is not an instance of
            `cls`.

        Notes
        -----
        ``strict_typing=True`` acts as a safety guard. Setting
        ``strict_typing=False`` may increase the probability of unexpected or
        uncaught errors. Unpickling can execute arbitrary code, so only load
        trusted files.
        """
        path = cls._updatepathroot(path)
        with open(path, "rb") as file:
            return cls._load(file, new_path, strict_typing)

    @classmethod
    def tryload(cls, path: Optional[str], new_path: Optional[str] = None, strict_typing: bool = True) -> "SaveableObject | Literal[False]":
        """Attempts to :meth:`load` from the specified `path`. If the loading
        fails then ``False`` is returned.

        Parameters
        ----------
        path : str, optional
            The path of the pickle. If ``None`` then ``False`` is returned.
        new_path : str, optional
            The path to replace the previous path with. If ``None`` the `path`
            is not replaced. By default ``None``.
        strict_typing : bool, optional
            If ``True`` then the loaded instance must be an instance of `cls`.
            By default ``True``.

        Returns
        -------
        SaveableObject | Literal[False]
            If succeeded the loaded instance, else False. Compare the result
            with ``is False`` rather than by truthiness, because a loaded
            instance may itself be falsy.

        Notes
        -----
        ``strict_typing=True`` acts as a safety guard. Setting
        ``strict_typing=False`` may increase the probability of unexpected or
        uncaught errors.
        """
        # The return annotation is quoted: an unquoted ``str | Literal[...]``
        # is evaluated at definition time and fails before Python 3.10.
        if path is None:
            return False
        try:
            return cls.load(path, new_path, strict_typing)
        except (FileNotFoundError, TypeError):
            return False

    @classmethod
    def loadif(cls, *args, **kwargs) -> Tuple["SaveableObject", bool]:
        """Attempts to load from a specified `path`. If the loading fails or no
        `path` is specified then a new instance of the object is generated with
        the specified `*args` and `**kwargs`.

        Parameters
        ----------
        *args
            The arguments to pass to the initialisation on a failed
            :meth:`load`.
        path : str, optional
            The path of the pickle, by default the parameter is not specified.
        **kwargs
            The keyword arguments to pass to the initialisation on a failed
            :meth:`load`.

        Returns
        -------
        (SaveableObject, bool)
            The loaded or initialised instance followed by ``True`` if the
            instance was loaded and ``False`` if the instance was initialised.
        """
        # ``...`` stands in for ``self`` so the remaining arguments bind to
        # the right parameters of ``__init__``.
        bound_args = inspect.signature(cls.__init__).bind(..., *args, **kwargs)
        instance = cls.tryload(bound_args.arguments.get("path"))
        # Identity check: a successfully loaded instance may define
        # __bool__/__len__ and be falsy.
        if instance is not False:
            return instance, True
        return cls(*args, **kwargs), False

    @staticmethod
    def _params_match(saved: dict, current: dict) -> bool:
        """Returns whether every entry of `current` equals the entry with the
        same key in `saved`.

        Values are compared with ``!=``. When that does not produce a plain
        ``bool`` (for example NumPy arrays compare elementwise),
        :func:`numpy.array_equal` decides instead.

        Parameters
        ----------
        saved : dict
            The previously saved arguments.
        current : dict
            The arguments of the current call.

        Returns
        -------
        bool
            ``True`` if all entries of `current` match.

        Raises
        ------
        KeyError
            If a key of `current` is missing from `saved`.
        TypeError, ValueError
            Propagated from the comparisons themselves.
        """
        for key, value in current.items():
            comparison = saved[key] != value
            if isinstance(comparison, bool):
                if comparison:
                    return False
            elif not np.array_equal(saved[key], value):
                return False
        return True

    @classmethod
    def loadifparams(cls, *args, dependencies: Optional[dict] = None, **kwargs) -> Tuple["SaveableObject", bool]:
        """Attempts to :meth:`load` from a specified `path`. If the loading
        fails or no `path` is specified or the parameters do not match the saved
        parameters then a new instance of the object is generated with the
        specified `*args` and `**kwargs`.

        Parameters
        ----------
        *args
            The arguments to pass to the initialisation on a failed
            :meth:`load`.
        path : str, optional
            The path of the pickle, by default the parameter is not specified.
        dependencies : dict, optional, must be specified as a keyword argument
            A dictionary of additional dependencies to check. By default no
            additional dependencies.
        **kwargs
            The keyword arguments to pass to the initialisation on a failed
            :meth:`load`.

        Returns
        -------
        (SaveableObject, bool)
            The loaded or initialised instance followed by ``True`` if the
            instance was loaded and ``False`` if the instance was initialised.

        Raises
        ------
        TypeError
            If a key of `dependencies` is also the name of an argument of the
            call, or if the arguments do not fit the signature of ``__init__``.

        Notes
        -----
        Only the arguments passed explicitly, plus `dependencies`, are saved
        and compared. Parameters left at their default values are not part of
        the check.
        """
        if dependencies is None:
            dependencies = {}
        # ``...`` stands in for ``self`` so the remaining arguments bind to
        # the right parameters of ``__init__``.
        bound_args = inspect.signature(cls.__init__).bind(..., *args, **kwargs)
        path = cls._updatepathroot(bound_args.arguments.get("path"))
        duplicates = [key for key in dependencies if key in bound_args.arguments]
        if duplicates:
            raise TypeError(f"The dependencies {duplicates} are also arguments. They must have different names.")
        arguments = {**bound_args.arguments, **dependencies}
        if path is not None:
            try:
                with open(path, "rb") as file:
                    instance = cls._load(file)
                    params = pkl.load(file)
                if cls._params_match(params, arguments):
                    return instance, True
            except (FileNotFoundError, EOFError, ValueError, TypeError, KeyError):
                # Missing or incomplete save, wrong type, or incomparable
                # parameters: fall through and build a fresh instance.
                pass
        instance = cls(*args, **kwargs)
        if path is not None:
            # Relies on the fresh instance having just been written to `path`
            # by the automatic save after initialisation; the arguments are
            # appended after it so that the next call can compare them.
            SaveableObject._dump(arguments, path, write_mode="ab")
        return instance, False