Source code for saveable_objects._saveable_object

  1import os
  2import inspect
  3import numpy as np
  4import pickle as pkl
  5import cloudpickle as cpkl
  6from typing import Optional, Literal, IO, Tuple
  7
  8from ._meta_class import SaveAfterInitMetaClass
  9
class SaveableObject(metaclass=SaveAfterInitMetaClass):
    """A utility class for saving objects to pickles and checkpointing.

    Instances are written with ``cloudpickle`` and read back with
    :func:`pickle.load`. A save file may contain, after the pickled instance,
    a second pickle holding the arguments the instance was saved with (see
    :meth:`update_save` and :meth:`loadifparams`).

    Warning
    -------
    Unpickling can execute arbitrary code. Only load files from trusted
    sources.
    """

    def __init__(self, path: Optional[str] = None):
        """Initialises an instance of the :class:`SaveableObject`.
        If a path is specified the saveable object is automatically saved after
        initialisation.

        Parameters
        ----------
        path : str, optional
            The file :attr:`path` to save the object to. If ``None`` then the
            object is not saved. By default ``None``.

        Notes
        -----
        If no file extension is provided for `path` then the class name and the
        ``.pkl`` extension are appended to the file name.
        """
        self.path = path

    @property
    def path(self) -> Optional[str]:
        """The current file path of the object.

        Notes
        -----
        On setting the value, if no file extension is provided then the class
        name and the ``.pkl`` extension are appended to the file name.
        """
        return self._path

    @path.setter
    def path(self, value: Optional[str]):
        """Set the current file path of the object.

        Parameters
        ----------
        value : str, optional
            The new file path of the object. By default ``None``.

        Notes
        -----
        If no file extension is provided then the class name and the ``.pkl``
        extension are appended to the file name.
        """
        self._path = self._updatepathroot(value)

    @classmethod
    def _get_name(cls, path: Optional[str]) -> Optional[str]:
        """Returns the file name of the specified path (without the file
        extension).

        Parameters
        ----------
        path : str, optional
            The path to obtain the file name for.

        Returns
        -------
        str, optional
            The file name (without the file extension), or ``None`` if `path`
            is ``None``.
        """
        if path is None:
            return None
        stem = os.path.splitext(cls._updatepathroot(path))[0]
        return os.path.split(stem)[-1]

    @property
    def name(self) -> Optional[str]:
        """The file name of the object (without the file extension). Note that
        `name` is read only.
        """
        return self._get_name(self._path)

    def _save(self,
              path: str,
              write_mode: Literal["w", "wb", "a", "ab", "x", "xb"] = "wb"):
        """Saves the object to `path` using `write_mode`.

        Missing parent directories of `path` are created first.

        Parameters
        ----------
        path : str
            The path to save the object to.
        write_mode : Literal["w", "wb", "a", "ab", "x", "xb"], optional
            The mode with which to open the file to write to. These are the same
            as `mode` for ``open``. By default ``"wb"``.

        Notes
        -----
        Pickles are bytes, so only the binary modes (``"wb"``, ``"ab"``,
        ``"xb"``) can actually be written. The text modes are accepted by the
        signature but writing to a text-mode file will fail.
        """
        dirname = os.path.dirname(path)
        if dirname and not os.path.exists(path):
            os.makedirs(dirname, exist_ok=True)
        with open(path, write_mode) as file:
            # ``self`` need not be a SaveableObject: this method is also called
            # unbound (``SaveableObject._save(params, ...)``) to append plain
            # argument dictionaries after the pickled instance.
            cpkl.dump(self, file, pkl.HIGHEST_PROTOCOL)

    def _getpath(self, path: Optional[str]) -> str:
        """Returns the specified or saved :attr:`path`.

        Parameters
        ----------
        path : str, optional
            Specified path. If ``None`` the saved :attr:`path` is used.

        Returns
        -------
        str
            Returns the specified or saved :attr:`path`, with the class name
            and ``.pkl`` appended if it has no file extension.

        Raises
        ------
        ValueError
            No save path provided. Raised if no path is saved or specified.
        """
        if path is None:
            # ``path`` may not be set yet (e.g. before ``__init__`` ran), in
            # which case the property raises AttributeError and we get None.
            path = getattr(self, "path", None)
        if path is None:
            raise ValueError("No save path provided.")
        return self._updatepathroot(path)

    @classmethod
    def _updatepathroot(cls, path: Optional[str]) -> Optional[str]:
        """If no file extension is provided then the class name and the ``.pkl``
        extension are appended to the file name.

        Parameters
        ----------
        path : str, optional
            The file path.

        Returns
        -------
        str, optional
            The modified file path, or ``None`` if `path` is ``None``.
        """
        if path is None:
            return None
        root, extension = os.path.splitext(path)
        if not extension:
            file_name = os.path.split(root)[-1]
            # Separate a non-empty file name from the class name with "_",
            # unless it already ends with one (or is a bare directory).
            prefix = "_" if file_name and file_name[-1] != "_" else ""
            path += prefix + cls.__name__ + ".pkl"
        return path

    def save(self, path: Optional[str] = None):
        """Pickles the current instance.

        Parameters
        ----------
        path : str, optional
            The path to pickle the instance to. If ``None`` is specified
            then the attribute :attr:`path` is used instead.
            By default ``None``.

        Raises
        ------
        ValueError
            Raised if no path specified either by the parameter `path` or the
            attribute :attr:`path`.

        Notes
        -----
        If no file extension is provided then the class name and the ``.pkl``
        extension are appended to the file name.
        """
        self.path = self._getpath(path)
        self._save(self.path)

    def update_save(self, path: Optional[str] = None) -> bool:
        """Pickles the current instance and retains the saved arguments if
        they exist.

        Parameters
        ----------
        path : str, optional
            The path to pickle the instance to. If ``None`` is specified
            then the attribute :attr:`path` is used instead. By default
            ``None``.

        Returns
        -------
        bool
            ``True`` if there was an argument pickle to retain. ``False`` if
            there was not an argument pickle to retain (including when no file
            exists at the path yet).

        Raises
        ------
        ValueError
            Raised if no path specified.

        Notes
        -----
        If no file extension is provided then the class name and the ``.pkl``
        extension are appended to the file name.
        """
        self.path = self._getpath(path)
        has_params = False
        params = None
        try:
            with open(self.path, "rb") as file:
                # Throw away the prior save. It may be unreadable (empty file,
                # foreign class, ...); we still try to read what follows it.
                try:
                    pkl.load(file)
                except Exception:
                    pass
                # Retain the parameters if an argument pickle follows.
                try:
                    params = pkl.load(file)
                    has_params = True
                except EOFError:
                    pass
        except FileNotFoundError:
            # Nothing saved yet, so there are no arguments to retain.
            pass
        # The file is closed here, so it is safe to overwrite it.
        self._save(self.path)
        if has_params:
            SaveableObject._save(params, self.path, write_mode="ab")
        return has_params

    @classmethod
    def _load(cls,
              file: IO,
              new_path: Optional[str] = None,
              strict_typing: bool = True
              ) -> "SaveableObject":
        """Loads an instance from the `file`.

        Parameters
        ----------
        file : IO
            The binary file to load the instance from. Only the first pickle in
            the file is consumed.
        new_path : str, optional
            The path to replace the previous path with. If ``None`` the `path`
            is not replaced. By default ``None``.
        strict_typing : bool, optional
            If ``True`` then the loaded instance must be an instance of `cls`.
            By default ``True``.

        Returns
        -------
        `cls`
            The loaded instance.

        Raises
        ------
        TypeError
            If `strict_typing` and the loaded instance is not an instance of
            `cls`.

        Notes
        -----
        ``strict_typing=True`` acts as a safety guard. Setting
        ``strict_typing=False`` may increase the probability of unexpected or
        uncaught errors.
        """
        # Security: pickle.load can execute arbitrary code; trusted files only.
        instance = pkl.load(file)
        if strict_typing and not isinstance(instance, cls):
            raise TypeError(f"The loaded instance is not an instance of {cls}.")
        if new_path is not None:
            instance.path = new_path
        return instance

    @classmethod
    def load(cls,
             path: str,
             new_path: Optional[str] = None,
             strict_typing: bool = True
             ) -> "SaveableObject":
        """Loads a pickled instance.

        Parameters
        ----------
        path : str
            The path of the pickle.
        new_path : str, optional
            The path to replace the previous path with. If ``None`` the `path`
            is not replaced. By default ``None``.
        strict_typing : bool, optional
            If ``True`` then the loaded instance must be an instance of `cls`.
            By default ``True``.

        Returns
        -------
        SaveableObject
            The loaded instance.

        Raises
        ------
        TypeError
            If `strict_typing` and the loaded instance is not an instance of
            `cls`.

        Notes
        -----
        ``strict_typing=True`` acts as a safety guard. Setting
        ``strict_typing=False`` may increase the probability of unexpected or
        uncaught errors.
        """
        path = cls._updatepathroot(path)
        with open(path, "rb") as file:
            return cls._load(file, new_path, strict_typing)

    @classmethod
    def tryload(cls,
                path: Optional[str],
                new_path: Optional[str] = None,
                strict_typing: bool = True
                ) -> "SaveableObject | Literal[False]":
        """Attempts to :meth:`load` from the specified `path`. If the loading
        fails then ``False`` is returned.

        Parameters
        ----------
        path : str, optional
            The path of the pickle. If ``None`` then ``False`` is returned.
        new_path : str, optional
            The path to replace the previous path with. If ``None`` the `path`
            is not replaced. By default ``None``.
        strict_typing : bool, optional
            If ``True`` then the loaded instance must be an instance of `cls`.
            By default ``True``.

        Returns
        -------
        SaveableObject | Literal[False]
            If succeeded the loaded instance, else False. Compare the result
            with ``is False`` since a loaded instance may itself be falsy.

        Notes
        -----
        ``strict_typing=True`` acts as a safety guard. Setting
        ``strict_typing=False`` may increase the probability of unexpected or
        uncaught errors.
        """
        if path is None:
            return False
        try:
            return cls.load(path, new_path, strict_typing)
        except (FileNotFoundError, EOFError, pkl.UnpicklingError, TypeError):
            # Missing file, empty/corrupt file, or (under strict_typing) an
            # instance of the wrong type all count as a failed load.
            return False

    @classmethod
    def _bind_init_args(cls, args: tuple, kwargs: dict) -> inspect.BoundArguments:
        """Binds `args` and `kwargs` to the signature of ``cls.__init__``.

        ``self`` is bound to the placeholder ``...`` since no instance exists
        yet. As with :meth:`inspect.Signature.bind`, only explicitly passed
        arguments appear in ``.arguments``; defaults are not applied.

        Raises
        ------
        TypeError
            If the arguments do not match the signature of ``cls.__init__``.
        """
        return inspect.signature(cls.__init__).bind(..., *args, **kwargs)

    @staticmethod
    def _params_equal(saved, current) -> bool:
        """Returns whether a saved parameter value equals the current one.

        Plain ``!=`` is used when it yields a ``bool``. Otherwise (e.g. numpy
        arrays, whose comparison is element-wise) :func:`numpy.array_equal`
        decides. Exceptions raised by the comparison propagate to the caller.
        """
        comparison = saved != current
        if isinstance(comparison, bool):
            return not comparison
        return bool(np.array_equal(saved, current))

    @classmethod
    def loadif(cls, *args, **kwargs) -> Tuple["SaveableObject", bool]:
        """Attempts to load from a specified `path`. If the loading fails or no
        `path` is specified then a new instance of the object is generated with
        the specified `*args` and `**kwargs`.

        Parameters
        ----------
        *args
            The arguments to pass to the initialisation on a failed
            :meth:`load`.
        path : str, optional
            The path of the pickle, by default the parameter is not specified.
        **kwargs
            The keyword arguments to pass to the initialisation on a failed
            :meth:`load`.

        Returns
        -------
        (SaveableObject, bool)
            The loaded or initialised instance followed by ``True`` if the
            instance was loaded and ``False`` if the instance was initialised.
        """
        path = cls._bind_init_args(args, kwargs).arguments.get("path")
        instance = cls.tryload(path)
        # Compare against the sentinel explicitly: a loaded instance may be
        # falsy (e.g. a subclass defining __len__ or __bool__).
        if instance is not False:
            return instance, True
        return cls(*args, **kwargs), False

    @classmethod
    def loadifparams(cls,
                     *args,
                     dependencies: Optional[dict] = None,
                     **kwargs
                     ) -> Tuple["SaveableObject", bool]:
        """Attempts to :meth:`load` from a specified `path`. If the loading
        fails or no `path` is specified or the parameters do not match the saved
        parameters then a new instance of the object is generated with the
        specified `*args` and `**kwargs`.

        Parameters
        ----------
        *args
            The arguments to pass to the initialisation on a failed
            :meth:`load`.
        path : str, optional
            The path of the pickle, by default the parameter is not specified.
        dependencies : dict, optional, must be specified as a keyword argument
            A dictionary of additional dependencies to check. By default no
            additional dependencies.
        **kwargs
            The keyword arguments to pass to the initialisation on a failed
            :meth:`load`.

        Returns
        -------
        (SaveableObject, bool)
            The loaded or initialised instance followed by ``True`` if the
            instance was loaded and ``False`` if the instance was initialised.

        Raises
        ------
        TypeError
            If a key of `dependencies` is also the name of a bound argument, or
            if the arguments do not match the signature of ``__init__``.

        Notes
        -----
        Only explicitly passed arguments (plus `dependencies`) are recorded and
        compared. Default values are not applied, so omitting an argument is not
        treated as passing its default.
        """
        if dependencies is None:
            dependencies = {}
        bound_args = cls._bind_init_args(args, kwargs)
        path = cls._updatepathroot(bound_args.arguments.get("path"))
        duplicates = [key for key in dependencies if key in bound_args.arguments]
        if duplicates:
            raise TypeError(
                f"The dependencies {duplicates} are also arguments. "
                "They must have different names."
            )
        arguments = {**bound_args.arguments, **dependencies}
        if path is not None:
            try:
                with open(path, "rb") as file:
                    instance = cls._load(file)
                    params = pkl.load(file)
                # A missing key (KeyError) or a failing comparison
                # (ValueError/TypeError) is treated as a mismatch.
                if all(cls._params_equal(params[key], value)
                       for key, value in arguments.items()):
                    return instance, True
            except (FileNotFoundError, EOFError, pkl.UnpicklingError,
                    ValueError, TypeError, KeyError):
                pass
        instance = cls(*args, **kwargs)
        if path is not None:
            # NOTE(review): assumes the metaclass has just (re)written the
            # instance to `path`; the argument pickle is appended after it.
            SaveableObject._save(arguments, path, write_mode="ab")
        return instance, False