Source code for saveable_objects._saveable_object

  1import os
  2import inspect
  3import numpy as np
  4import pickle as pkl
  5import cloudpickle as cpkl
  6from typing import Optional, Literal, IO, Tuple
  7
  8from ._meta_class import SaveAfterInitMetaClass
  9
class SaveableObject(metaclass=SaveAfterInitMetaClass):
    """A utility class for saving objects to pickles and checkpointing.

    Instances are serialised with ``cloudpickle`` (see :meth:`_save`). A saved
    file contains the pickled instance and, optionally, a second pickle
    appended after it that holds the arguments used to construct the instance
    (written by :meth:`loadifparams` and preserved by :meth:`update_save`).
    """

    def __init__(self, path: Optional[str] = None):
        """Initialises an instance of the :class:`SaveableObject`.
        If a path is specified the saveable object is automatically saved after
        initialisation.

        Parameters
        ----------
        path : str, optional
            The file :attr:`path` to save the object to. If ``None`` then the
            object is not saved. By default ``None``.

        Notes
        -----
        If no file extension is provided for `path` then the class name and the
        ``.pkl`` extension are appended to the file name.
        """
        self.path = path

    @property
    def path(self) -> Optional[str]:
        """The current file path of the object.

        Notes
        -----
        On setting the value, if no file extension is provided then the class
        name and the ``.pkl`` extension are appended to the file name.
        """
        return self._path

    @path.setter
    def path(self, value: Optional[str]):
        """Set the current file path of the object.

        Parameters
        ----------
        value : str, optional
            The new file path of the object. By default ``None``.

        Notes
        -----
        If no file extension is provided then the class name and the ``.pkl``
        extension are appended to the file name.
        """
        self._path = self._updatepathroot(value)

    @classmethod
    def _get_name(cls, path: Optional[str]) -> Optional[str]:
        """Returns the file name of the specified path (without the file
        extension).

        Parameters
        ----------
        path : str, optional
            The path to obtain the file name for.

        Returns
        -------
        str, optional
            The file name (without the file extension), or ``None`` if `path`
            is ``None``.
        """
        if path is None:
            return None
        return os.path.split(os.path.splitext(cls._updatepathroot(path))[0])[-1]

    @property
    def name(self) -> Optional[str]:
        """The file name of the object (without the file extension). Note that
        `name` is read only.
        """
        return self._get_name(self._path)

    def _save(self, path: str, write_mode: Literal["w", "wb", "a", "ab", "x", "xb"] = "wb"):
        """Saves the object to `path` using `write_mode`.

        Missing parent directories of `path` are created first.

        Parameters
        ----------
        path : str
            The path to save the object to.
        write_mode : Literal["w", "wb", "a", "ab", "x", "xb"], optional
            The mode with which to open the file to write to. These are the same
            as `mode` for ``open``. By default ``"wb"``.

        Notes
        -----
        This is also invoked unbound as ``SaveableObject._save(obj, ...)`` to
        append a plain object (the argument dictionary) after an instance.
        """
        if not os.path.exists(path):
            dirname = os.path.dirname(path)
            if dirname != '':
                os.makedirs(dirname, exist_ok=True)
        with open(path, write_mode) as file:
            cpkl.dump(self, file, pkl.HIGHEST_PROTOCOL)

    def _getpath(self, path: Optional[str]) -> str:
        """Returns the specified or saved :attr:`path`.

        Parameters
        ----------
        path : str, optional
            Specified path. If ``None`` the stored :attr:`path` is used.

        Returns
        -------
        str
            The specified or saved :attr:`path`, normalised with
            :meth:`_updatepathroot`.

        Raises
        ------
        ValueError
            No save path provided. Raised if no path is saved or specified.
        """
        # ``hasattr`` is False while ``_path`` has not been assigned yet,
        # because the ``path`` property then raises AttributeError.
        if path is None and hasattr(self, 'path'):
            path = self.path
        if path is None:
            raise ValueError("No save path provided.")
        return self._updatepathroot(path)

    @classmethod
    def _updatepathroot(cls, path: Optional[str]) -> Optional[str]:
        """If no file extension is provided then the class name and the ``.pkl``
        extension are appended to the file name.

        For example, for a class named ``SaveableObject``: ``"runs/a"`` becomes
        ``"runs/a_SaveableObject.pkl"``, ``"runs/a_"`` becomes
        ``"runs/a_SaveableObject.pkl"`` and ``"runs/"`` becomes
        ``"runs/SaveableObject.pkl"``. Paths that already have an extension are
        returned unchanged, so the operation is idempotent.

        Parameters
        ----------
        path : str, optional
            The file path.

        Returns
        -------
        str, optional
            The modified file path, or ``None`` if `path` is ``None``.
        """
        if path is None:
            return None
        root, extension = os.path.splitext(path)
        if len(extension) == 0:
            file_name = os.path.split(root)[-1]
            # Separate the user's file name from the class name with "_",
            # unless there is no file name or it already ends in "_".
            prefix = "_" if len(file_name) != 0 and file_name[-1] != "_" else ""
            path += prefix + cls.__name__ + ".pkl"
        return path

    def save(self, path: Optional[str] = None):
        """Pickles the current instance.

        Parameters
        ----------
        path : str, optional
            The path to pickle the instance to. If ``None`` is specified
            then the attribute :attr:`path` is used instead.
            By default ``None``.

        Raises
        ------
        ValueError
            Raised if no path specified either by the parameter `path` or the
            attribute :attr:`path`.

        Notes
        -----
        If no file extension is provided then the class name and the ``.pkl``
        extension are appended to the file name. The file is overwritten, so
        any argument pickle previously appended to it is discarded; use
        :meth:`update_save` to keep it.
        """
        self.path = self._getpath(path)
        self._save(self.path)

    def update_save(self, path: Optional[str] = None) -> bool:
        """Pickles the current instance and retains the saved arguments if
        they exist.

        Parameters
        ----------
        path : str, optional
            The path to pickle the instance to. If ``None`` is specified
            then the attribute :attr:`path` is used instead. By default
            ``None``.

        Returns
        -------
        bool
            ``True`` if there was an argument pickle to retain. ``False`` if
            there was not an argument pickle to retain.

        Raises
        ------
        ValueError
            Raised if no path specified.
        FileNotFoundError
            Raised if there is no existing save at the path.

        Notes
        -----
        If no file extension is provided then the class name and the ``.pkl``
        extension are appended to the file name.
        """
        self.path = self._getpath(path)
        with open(self.path, "rb") as file:
            # Step over (and throw away) the prior saved instance. This is a
            # deliberate best-effort: e.g. a strict-typing TypeError is raised
            # only after the object has been read, so the file position is
            # still correct for reading the parameters that follow.
            try:
                type(self)._load(file)
            except Exception:
                pass
            # Retain the parameters, if any were appended after the instance.
            try:
                params = pkl.load(file)
                has_params = True
            except EOFError:
                has_params = False
        # The read handle is closed (by ``with``) before the file is rewritten.
        self._save(self.path)
        if has_params:
            SaveableObject._save(params, self.path, write_mode="ab")
        return has_params

    @classmethod
    def _load(cls, file: IO, new_path: Optional[str] = None, strict_typing: bool = True) -> "SaveableObject":
        """Loads an instance from the `file`.

        Only the first pickle in `file` is read; the file position is left
        just after it.

        Parameters
        ----------
        file : IO
            The file to load the instance from.
        new_path : str, optional
            The path to replace the previous path with. If ``None`` the `path`
            is not replaced. By default ``None``.
        strict_typing : bool, optional
            If ``True`` then the loaded instance must be an instance of `cls`.
            By default ``True``.

        Returns
        -------
        `cls`
            The loaded instance.

        Raises
        ------
        TypeError
            If `strict_typing` and the loaded instance is not an instance of
            `cls`.

        Notes
        -----
        ``strict_typing=True`` acts as a safety guard. Setting
        ``strict_typing=False`` may increase the probability of unexpected or
        uncaught errors. Unpickling can execute arbitrary code, so only load
        files from trusted sources.
        """
        instance = pkl.load(file)
        if strict_typing and not isinstance(instance, cls):
            raise TypeError(f"The loaded instance is not an instance of {cls}.")
        if new_path is not None:
            instance.path = new_path
        return instance

    @classmethod
    def load(cls, path: str, new_path: Optional[str] = None, strict_typing: bool = True) -> "SaveableObject":
        """Loads a pickled instance.

        Parameters
        ----------
        path : str
            The path of the pickle. If no file extension is provided then the
            class name and the ``.pkl`` extension are appended.
        new_path : str, optional
            The path to replace the previous path with. If ``None`` the `path`
            is not replaced. By default ``None``.
        strict_typing : bool, optional
            If ``True`` then the loaded instance must be an instance of `cls`.
            By default ``True``.

        Returns
        -------
        SaveableObject
            The loaded instance.

        Raises
        ------
        TypeError
            If `strict_typing` and the loaded instance is not an instance of
            `cls`.

        Notes
        -----
        ``strict_typing=True`` acts as a safety guard. Setting
        ``strict_typing=False`` may increase the probability of unexpected or
        uncaught errors. Unpickling can execute arbitrary code, so only load
        files from trusted sources.
        """
        path = cls._updatepathroot(path)
        with open(path, "rb") as file:
            return cls._load(file, new_path, strict_typing)

    @classmethod
    def tryload(cls, path: Optional[str], new_path: Optional[str] = None,
                strict_typing: bool = True) -> "SaveableObject | Literal[False]":
        """Attempts to :meth:`load` from the specified `path`. If the loading
        fails then ``False`` is returned.

        Parameters
        ----------
        path : str, optional
            The path of the pickle. If ``None`` then ``False`` is returned.
        new_path : str, optional
            The path to replace the previous path with. If ``None`` the `path`
            is not replaced. By default ``None``.
        strict_typing : bool, optional
            If ``True`` then the loaded instance must be an instance of `cls`.
            By default ``True``.

        Returns
        -------
        SaveableObject | Literal[False]
            If succeeded the loaded instance, else False.

        Notes
        -----
        ``strict_typing=True`` acts as a safety guard. Setting
        ``strict_typing=False`` may increase the probability of unexpected or
        uncaught errors. Because a loaded instance may itself be falsy, callers
        should test the result with ``is False`` rather than truthiness.
        """
        if path is None:
            return False
        try:
            return cls.load(path, new_path, strict_typing)
        except (FileNotFoundError, EOFError, pkl.UnpicklingError, TypeError):
            # FileNotFoundError: nothing saved yet. EOFError/UnpicklingError:
            # empty, truncated or corrupt file. TypeError: e.g. strict_typing
            # rejected the loaded object.
            return False

    @classmethod
    def loadif(cls, *args, **kwargs) -> Tuple["SaveableObject", bool]:
        """Attempts to load from a specified `path`. If the loading fails or no
        `path` is specified then a new instance of the object is generated with
        the specified `*args` and `**kwargs`.

        Parameters
        ----------
        *args
            The arguments to pass to the initialisation on a failed
            :meth:`load`.
        path : str, optional
            The path of the pickle, by default the parameter is not specified.
        **kwargs
            The keyword arguments to pass to the initialisation on a failed
            :meth:`load`.

        Returns
        -------
        (SaveableObject, bool)
            The loaded or initialised instance followed by ``True`` if the
            instance was loaded and ``False`` if the instance was initialised.
        """
        # ``...`` fills the ``self`` slot so the remaining arguments bind
        # exactly as they would in ``cls(*args, **kwargs)``.
        bound_args = inspect.signature(cls.__init__).bind(..., *args, **kwargs)
        path = bound_args.arguments.get("path")
        instance = cls.tryload(path)
        # Compare against False explicitly: a successfully loaded instance may
        # be falsy (e.g. a subclass defining ``__len__`` or ``__bool__``).
        if instance is not False:
            return instance, True
        return cls(*args, **kwargs), False

    @classmethod
    def loadifparams(cls, *args, dependencies: Optional[dict] = None,
                     **kwargs) -> Tuple["SaveableObject", bool]:
        """Attempts to :meth:`load` from a specified `path`. If the loading
        fails or no `path` is specified or the parameters do not match the saved
        parameters then a new instance of the object is generated with the
        specified `*args` and `**kwargs`.

        Parameters
        ----------
        *args
            The arguments to pass to the initialisation on a failed
            :meth:`load`.
        path : str, optional
            The path of the pickle, by default the parameter is not specified.
        dependencies : dict, optional, must be specified as a keyword argument
            A dictionary of additional dependencies to check. By default no
            additional dependencies.
        **kwargs
            The keyword arguments to pass to the initialisation on a failed
            :meth:`load`.

        Returns
        -------
        (SaveableObject, bool)
            The loaded or initialised instance followed by ``True`` if the
            instance was loaded and ``False`` if the instance was initialised.

        Raises
        ------
        TypeError
            If a key of `dependencies` has the same name as an argument.

        Notes
        -----
        Only the arguments explicitly passed (plus `dependencies`) are
        compared; defaults are not filled in. When a new instance is created
        and a `path` is given, those arguments are appended to the file after
        the instance.
        """
        if dependencies is None:
            dependencies = {}
        # ``...`` fills the ``self`` slot so the remaining arguments bind
        # exactly as they would in ``cls(*args, **kwargs)``.
        bound_args = inspect.signature(cls.__init__).bind(..., *args, **kwargs)
        path = cls._updatepathroot(bound_args.arguments.get("path"))
        duplicates = [key for key in dependencies if key in bound_args.arguments]
        if len(duplicates) != 0:
            raise TypeError(f"The dependencies {duplicates} are also arguments. They must have different names.")
        arguments = {**bound_args.arguments, **dependencies}
        if path is not None:
            try:
                with open(path, "rb") as file:
                    instance = cls._load(file)
                    params = pkl.load(file)
                for key, value in arguments.items():
                    saved = params[key]  # a missing key (KeyError) is a mismatch
                    differs = saved != value
                    if not isinstance(differs, bool):
                        # Non-bool comparison results (e.g. element-wise numpy
                        # arrays) are reduced with ``np.array_equal``.
                        differs = not np.array_equal(saved, value)
                    if differs:
                        raise ValueError
                return instance, True
            except (FileNotFoundError, EOFError, ValueError, TypeError, KeyError):
                # No usable save, or the saved parameters do not match.
                pass
        instance = cls(*args, **kwargs)
        if path is not None:
            # Appends after the instance; presumably the instance was written
            # to ``path`` during initialisation (see ``__init__``) — confirm
            # against SaveAfterInitMetaClass.
            SaveableObject._save(arguments, path, write_mode="ab")
        return instance, False