Source code for saveable_objects._saveable_object

import os
import inspect
import numpy as np
import pickle as pkl
import cloudpickle as cpkl
from typing import Optional, Literal, IO, Tuple

from ._meta_class import SaveAfterInitMetaClass

class SaveableObject(metaclass=SaveAfterInitMetaClass):
    """A utility class for saving objects to pickles and checkpointing.

    Instances are written with ``cloudpickle``. :meth:`update_save` and
    :meth:`loadifparams` may append a second pickle (the saved
    initialisation arguments) after the instance in the same file.

    Warnings
    --------
    Loading unpickles data, which can execute arbitrary code. Only load
    files from trusted sources.
    """

    def __init__(self, path: Optional[str] = None):
        """Initialises an instance of the :class:`SaveableObject`.

        If a path is specified the saveable object is automatically saved
        after initialisation.

        Parameters
        ----------
        path : str, optional
            The file :attr:`path` to save the object to. If ``None`` then the
            object is not saved. By default ``None``.

        Notes
        -----
        If no file extension is provided for `path` then the class name and
        the ``.pkl`` extension are appended to the file name.
        """
        self.path = path

    @property
    def path(self) -> Optional[str]:
        """The current file path of the object.

        Notes
        -----
        On setting the value, if no file extension is provided then the class
        name and the ``.pkl`` extension are appended to the file name.
        """
        return self._path

    @path.setter
    def path(self, value: Optional[str]):
        """Set the current file path of the object.

        Parameters
        ----------
        value : str, optional
            The new file path of the object. ``None`` clears the path.

        Notes
        -----
        If no file extension is provided then the class name and the ``.pkl``
        extension are appended to the file name.
        """
        self._path = self._updatepathroot(value)

    @classmethod
    def _get_name(cls, path: Optional[str]) -> Optional[str]:
        """Returns the file name of the specified path (without the file
        extension).

        Parameters
        ----------
        path : str, optional
            The path to obtain the file name for.

        Returns
        -------
        str, optional
            The file name (without the file extension), or ``None`` if
            `path` is ``None``.
        """
        if path is None:
            return None
        return os.path.split(os.path.splitext(cls._updatepathroot(path))[0])[-1]

    @property
    def name(self) -> Optional[str]:
        """The file name of the object (without the file extension).

        Note that `name` is read only.
        """
        return self._get_name(self._path)

    def _save(self, path: str,
              write_mode: Literal["w", "wb", "a", "ab", "x", "xb"] = "wb"):
        """Saves the object to `path` using `write_mode`.

        This is also called unbound (``SaveableObject._save(obj, path, ...)``)
        to pickle arbitrary objects, e.g. the argument dictionary appended by
        :meth:`update_save` and :meth:`loadifparams`.

        Parameters
        ----------
        path : str
            The path to save the object to. Missing parent directories are
            created.
        write_mode : Literal["w", "wb", "a", "ab", "x", "xb"], optional
            The mode with which to open the file to write to. These are the
            same as `mode` for ``open``. Pickles are bytes, so in practice a
            binary mode (``"wb"``, ``"ab"`` or ``"xb"``) is required.
            By default ``"wb"``.
        """
        # makedirs(exist_ok=True) is a no-op when the directory exists.
        dirname = os.path.dirname(path)
        if dirname:
            os.makedirs(dirname, exist_ok=True)
        with open(path, write_mode) as file:
            cpkl.dump(self, file, pkl.HIGHEST_PROTOCOL)

    def _getpath(self, path: Optional[str]) -> str:
        """Returns the specified or saved :attr:`path`.

        Parameters
        ----------
        path : str, optional
            Specified path. Takes precedence over the saved :attr:`path`.

        Returns
        -------
        str
            The specified or saved :attr:`path` with the default file name
            suffix applied (see :meth:`_updatepathroot`).

        Raises
        ------
        ValueError
            No save path provided. Raised if no path is saved or specified.
        """
        # hasattr guards against _path not having been set yet.
        if path is None and hasattr(self, 'path'):
            path = self.path
        if path is None:
            raise ValueError("No save path provided.")
        return self._updatepathroot(path)

    @classmethod
    def _updatepathroot(cls, path: Optional[str]) -> Optional[str]:
        """If no file extension is provided then the class name and the
        ``.pkl`` extension are appended to the file name.

        A ``"_"`` separator is inserted unless the file name is empty (e.g.
        `path` is a directory ending in a separator) or already ends in
        ``"_"``.

        Parameters
        ----------
        path : str, optional
            The file path.

        Returns
        -------
        str, optional
            The modified file path, or ``None`` if `path` is ``None``.
        """
        if path is None:
            return None
        root, extension = os.path.splitext(path)
        if len(extension) == 0:
            file_name = os.path.split(root)[-1]
            prefix = "_" if len(file_name) != 0 and file_name[-1] != "_" else ""
            path += prefix + cls.__name__ + ".pkl"
        return path

    def save(self, path: Optional[str] = None):
        """Pickles the current instance.

        Parameters
        ----------
        path : str, optional
            The path to pickle the instance to. If ``None`` is specified then
            the attribute :attr:`path` is used instead. By default ``None``.

        Raises
        ------
        ValueError
            Raised if no path specified either by the parameter `path` or the
            attribute :attr:`path`.

        Notes
        -----
        If no file extension is provided then the class name and the ``.pkl``
        extension are appended to the file name.
        """
        self.path = self._getpath(path)
        self._save(self.path)

    def update_save(self, path: Optional[str] = None) -> bool:
        """Pickles the current instance and retains the saved arguments if
        they exist.

        Parameters
        ----------
        path : str, optional
            The path to pickle the instance to. If ``None`` is specified then
            the attribute :attr:`path` is used instead. By default ``None``.

        Returns
        -------
        bool
            ``True`` if there was an argument pickle to retain. ``False`` if
            there was not an argument pickle to retain (including when no
            file existed at the path yet).

        Raises
        ------
        ValueError
            Raised if no path specified.

        Notes
        -----
        If no file extension is provided then the class name and the ``.pkl``
        extension are appended to the file name.
        """
        self.path = self._getpath(path)
        has_params, params = self._read_retained_params(self.path)
        # The prior file has been closed by now, so it is safe to overwrite.
        self._save(self.path)
        if has_params:
            SaveableObject._save(params, self.path, write_mode="ab")
        return has_params

    def _read_retained_params(self, path: str) -> Tuple[bool, object]:
        """Reads the argument pickle stored after the instance in `path`.

        Parameters
        ----------
        path : str
            The file to read.

        Returns
        -------
        (bool, object)
            ``(True, params)`` if an argument pickle follows the instance,
            otherwise ``(False, None)``.
        """
        try:
            file = open(path, "rb")
        except FileNotFoundError:
            # Nothing saved yet, so there are no parameters to retain.
            return False, None
        with file:
            # Throw away the prior save. This is best-effort: the stored
            # instance may be unreadable (e.g. a different type).
            try:
                type(self)._load(file)
            except Exception:
                pass
            # Retain the parameters, if any were appended after the instance.
            try:
                return True, pkl.load(file)
            except EOFError:
                return False, None

    @classmethod
    def _load(cls, file: IO, new_path: Optional[str] = None,
              strict_typing: bool = True) -> "SaveableObject":
        """Loads an instance from the `file`.

        Parameters
        ----------
        file : IO
            The file to load the instance from. It is read up to the end of
            the first pickle, so a following pickle can be read afterwards.
        new_path : str, optional
            The path to replace the previous path with. If ``None`` the
            `path` is not replaced. By default ``None``.
        strict_typing : bool, optional
            If ``True`` then the loaded instance must be an instance of
            `cls`. By default ``True``.

        Returns
        -------
        `cls`
            The loaded instance.

        Raises
        ------
        TypeError
            If `strict_typing` and the loaded instance is not an instance of
            `cls`.

        Notes
        -----
        ``strict_typing=True`` acts as a safety guard. Setting
        ``strict_typing=False`` may increase the probability of unexpected or
        uncaught errors. Unpickling can execute arbitrary code; only load
        trusted files.
        """
        instance = pkl.load(file)
        if strict_typing and not isinstance(instance, cls):
            raise TypeError(f"The loaded instance is not an instance of {cls}.")
        if new_path is not None:
            instance.path = new_path
        return instance

    @classmethod
    def load(cls, path: str, new_path: Optional[str] = None,
             strict_typing: bool = True) -> "SaveableObject":
        """Loads a pickled instance.

        Parameters
        ----------
        path : str
            The path of the pickle.
        new_path : str, optional
            The path to replace the previous path with. If ``None`` the
            `path` is not replaced. By default ``None``.
        strict_typing : bool, optional
            If ``True`` then the loaded instance must be an instance of
            `cls`. By default ``True``.

        Returns
        -------
        SaveableObject
            The loaded instance.

        Raises
        ------
        TypeError
            If `strict_typing` and the loaded instance is not an instance of
            `cls`.

        Notes
        -----
        ``strict_typing=True`` acts as a safety guard. Setting
        ``strict_typing=False`` may increase the probability of unexpected or
        uncaught errors.
        """
        path = cls._updatepathroot(path)
        with open(path, "rb") as file:
            return cls._load(file, new_path, strict_typing)

    @classmethod
    def tryload(cls, path: Optional[str], new_path: Optional[str] = None,
                strict_typing: bool = True
                ) -> "SaveableObject | Literal[False]":
        """Attempts to :meth:`load` from the specified `path`. If the loading
        fails then ``False`` is returned.

        Parameters
        ----------
        path : str, optional
            The path of the pickle. If ``None`` then ``False`` is returned.
        new_path : str, optional
            The path to replace the previous path with. If ``None`` the
            `path` is not replaced. By default ``None``.
        strict_typing : bool, optional
            If ``True`` then the loaded instance must be an instance of
            `cls`. By default ``True``.

        Returns
        -------
        SaveableObject | Literal[False]
            If succeeded the loaded instance, else False. Compare the result
            with ``is False``: a loaded instance may itself be falsy.

        Notes
        -----
        A missing file, an empty or truncated pickle, and a type mismatch are
        treated as failed loads. ``strict_typing=True`` acts as a safety
        guard. Setting ``strict_typing=False`` may increase the probability
        of unexpected or uncaught errors.
        """
        if path is None:
            return False
        try:
            return cls.load(path, new_path, strict_typing)
        except (FileNotFoundError, EOFError, pkl.UnpicklingError, TypeError):
            return False

    @classmethod
    def loadif(cls, *args, **kwargs) -> Tuple["SaveableObject", bool]:
        """Attempts to load from a specified `path`. If the loading fails or
        no `path` is specified then a new instance of the object is generated
        with the specified `*args` and `**kwargs`.

        Parameters
        ----------
        *args
            The arguments to pass to the initialisation on a failed
            :meth:`load`.
        path : str, optional
            The path of the pickle, by default the parameter is not
            specified.
        **kwargs
            The keyword arguments to pass to the initialisation on a failed
            :meth:`load`.

        Returns
        -------
        (SaveableObject, bool)
            The loaded or initialised instance followed by ``True`` if the
            instance was loaded and ``False`` if the instance was initialised.
        """
        # Ellipsis stands in for `self` so the call arguments can be bound.
        bound_args = inspect.signature(cls.__init__).bind(..., *args, **kwargs)
        path = bound_args.arguments.get("path")
        instance = cls.tryload(path)
        # Explicit identity check: a loaded instance may be falsy (e.g. it
        # defines __len__/__bool__) and must still count as loaded.
        if instance is not False:
            return instance, True
        return cls(*args, **kwargs), False

    @classmethod
    def loadifparams(cls, *args, dependencies: Optional[dict] = None,
                     **kwargs) -> Tuple["SaveableObject", bool]:
        """Attempts to :meth:`load` from a specified `path`. If the loading
        fails or no `path` is specified or the parameters do not match the
        saved parameters then a new instance of the object is generated with
        the specified `*args` and `**kwargs`.

        Parameters
        ----------
        *args
            The arguments to pass to the initialisation on a failed
            :meth:`load`.
        path : str, optional
            The path of the pickle, by default the parameter is not
            specified.
        dependencies : dict, optional, must be specified as a keyword argument
            A dictionary of additional dependencies to check. By default no
            dependencies.
        **kwargs
            The keyword arguments to pass to the initialisation on a failed
            :meth:`load`.

        Returns
        -------
        (SaveableObject, bool)
            The loaded or initialised instance followed by ``True`` if the
            instance was loaded and ``False`` if the instance was initialised.

        Raises
        ------
        TypeError
            If a dependency has the same name as an argument.

        Notes
        -----
        The parameters match only if the saved and current parameter names
        are identical and every value compares equal. Values that do not
        compare to a plain ``bool`` (e.g. arrays) are compared with
        ``numpy.array_equal``.
        """
        if dependencies is None:
            dependencies = {}
        # Ellipsis stands in for `self` so the call arguments can be bound.
        bound_args = inspect.signature(cls.__init__).bind(..., *args, **kwargs)
        path = cls._updatepathroot(bound_args.arguments.get("path"))
        duplicates = [key for key in dependencies
                      if key in bound_args.arguments]
        if len(duplicates) != 0:
            raise TypeError(f"The dependencies {duplicates} are also arguments. They must have different names.")
        arguments = {**bound_args.arguments, **dependencies}
        if path is not None:
            try:
                with open(path, "rb") as file:
                    instance = cls._load(file)
                    params = pkl.load(file)
                if cls._params_match(params, arguments):
                    return instance, True
            except (FileNotFoundError, EOFError, ValueError, TypeError,
                    KeyError):
                # Missing/incomplete file, wrong type or incomparable
                # values: fall through and re-initialise.
                pass
        instance = cls(*args, **kwargs)
        if path is not None:
            SaveableObject._save(arguments, path, write_mode="ab")
        return instance, False

    @staticmethod
    def _params_match(saved: object, current: dict) -> bool:
        """Returns whether the `saved` parameters equal the `current` ones.

        Parameters
        ----------
        saved : object
            The unpickled saved parameters; anything other than a ``dict``
            never matches.
        current : dict
            The current parameters.

        Returns
        -------
        bool
            ``True`` if both have the same keys and all values compare equal.

        Raises
        ------
        ValueError, TypeError
            Propagated from value comparisons that cannot be evaluated (e.g.
            arrays of incompatible shapes).
        """
        if not isinstance(saved, dict) or saved.keys() != current.keys():
            return False
        for key, value in current.items():
            comparison = saved[key] != value
            if isinstance(comparison, bool):
                if comparison:
                    return False
            # Non-bool results (numpy scalars/arrays, ...) are compared
            # element-wise instead.
            elif not np.array_equal(saved[key], value):
                return False
        return True