Source code for tsdat.config.utils

import os
import yaml
import warnings
from jsonpointer import set_pointer  # type: ignore
from dunamai import Style, Version
from pathlib import Path
from pydantic import BaseModel, Extra, Field, StrictStr, validator, FilePath
from pydantic.utils import import_string
from pydantic.generics import GenericModel
from typing import (
    Any,
    Optional,
    cast,
    Dict,
    Generic,
    List,
    Protocol,
    Sequence,
    Set,
    TypeVar,
)


# Public API of this module: the names exported by ``from <this module> import *``.
# Other module-level names defined in this file (e.g. ``Overrideable`` and
# ``find_duplicates``) are not included in this list.
__all__ = [
    "ParameterizedConfigClass",
    "recursive_instantiate",
    "read_yaml",
    "get_code_version",
    "YamlModel",
]


# Base class for pydantic config models that can be loaded from a YAML file and can
# export their JSON schema.
# NOTE: documented with comments rather than a class docstring on purpose: pydantic uses
# a model's (possibly inherited) class docstring as the JSON-schema "description", so a
# docstring here could leak into the schemas of subclasses that have none of their own.
class YamlModel(BaseModel):
    @classmethod
    def from_yaml(cls, filepath: Path, overrides: Optional[Dict[str, Any]] = None):
        """Instantiate the model from the first document of a YAML file.

        Args:
            filepath: Path to the YAML file to read. A ``str`` path is also accepted.
            overrides: Optional mapping of JSON Pointer strings (e.g. ``"/a/b"``) to
                replacement values. Each one is applied to the parsed config with
                ``jsonpointer.set_pointer`` before the model is validated.

        Returns:
            An instance of ``cls`` built from the (possibly overridden) config.
        """
        config = read_yaml(Path(filepath))
        if overrides:
            # Mutating ``config`` is safe: it was freshly parsed from disk above.
            for pointer, new_value in overrides.items():
                set_pointer(config, pointer, new_value)
        return cls(**config)

    @classmethod
    def generate_schema(cls, output_file: Path):
        """Write this model's JSON schema (indented by 4 spaces) to ``output_file``.

        Args:
            output_file: Destination file; it is overwritten if it exists. A ``str``
                path is also accepted.
        """
        Path(output_file).write_text(cls.schema_json(indent=4))
Config = TypeVar("Config", bound=BaseModel)


# Reference to a YAML config file plus JSON-Pointer overrides to apply to its contents.
# Generic over ``Config`` (any ``BaseModel`` subclass); the type parameter is not used by
# the fields below. Extra keys are rejected (``extra=Extra.forbid``).
# NOTE: documented with comments rather than a class docstring because pydantic would
# otherwise emit the docstring as this model's JSON-schema "description".
class Overrideable(YamlModel, GenericModel, Generic[Config], extra=Extra.forbid):
    # Must point at an existing file (enforced by pydantic's ``FilePath`` type).
    path: FilePath
    # JSON Pointer -> replacement value. pydantic copies field defaults per instance,
    # so this ``{}`` literal is not shared between models.
    overrides: Dict[str, Any] = {}


def matches_overrideable_schema(model_dict: Dict[str, Any]) -> bool:
    """Return True if ``model_dict`` has a ``"path"`` key.

    ``path`` is the only required field of ``Overrideable``, so this is the check for
    whether a raw config dict looks like an ``Overrideable`` reference.
    """
    return "path" in model_dict
# Config entry naming a Python class (by import path) and the parameters to build it with.
# NOTE: documented with comments rather than a class docstring because pydantic uses a
# model's class docstring as its JSON-schema "description".
class ParameterizedConfigClass(BaseModel, extra=Extra.forbid):
    # Unfortunately, the classname has to be a string type unless PyObject becomes JSON
    # serializable: https://github.com/samuelcolvin/pydantic/discussions/3842
    classname: StrictStr = Field(
        description="The import path to the Python class that should be used, e.g., if"
        " your import statement looks like `from foo.bar import Baz`, then your"
        " classname would be `foo.bar.Baz`.",
    )
    parameters: Dict[str, Any] = Field(
        {},
        description="Optional dictionary that will be passed to the Python class"
        " specified by 'classname' when it is instantiated. If the object is a tsdat"
        " class, then the parameters will typically be made accessible under the"
        " `params` property on an instance of the class. See the documentation for"
        " individual classes for more information.",
    )

    @validator("classname")
    @classmethod
    def classname_looks_like_a_module(cls, v: StrictStr) -> StrictStr:
        """Require a dotted path of at least two Python identifiers, e.g. ``foo.bar.Baz``.

        Malformed values such as ``"Baz"``, ``"foo..Baz"``, ``".foo.Baz"`` or
        ``"foo.Baz."`` are rejected here, at validation time, instead of failing later
        with an obscure error when the class is imported.

        Raises:
            ValueError: If ``v`` is not a valid dotted classname.
        """
        parts = v.split(".")
        # ``len(parts) < 2`` is the original "must contain a '.'" rule; requiring every
        # segment to be an identifier also rules out empty segments.
        if len(parts) < 2 or not all(part.isidentifier() for part in parts):
            raise ValueError(f"Classname '{v}' is not a valid classname.")
        return v

    def instantiate(self) -> Any:
        """------------------------------------------------------------------------------------
        Instantiates and returns the class specified by the 'classname' parameter.

        Only fields that were explicitly set on this model (pydantic's ``__fields_set__``)
        are forwarded as keyword arguments, so ``parameters=...`` is passed to the class
        only when it was provided in the config.

        Returns:
            Any: An instance of the specified class.

        ------------------------------------------------------------------------------------"""
        params = {field: getattr(self, field) for field in self.__fields_set__}
        _cls = import_string(params.pop("classname"))
        return _cls(**params)
def recursive_instantiate(model: Any) -> Any:
    """---------------------------------------------------------------------------------
    Instantiates all ParameterizedConfigClass components and subcomponents of a given
    model.

    Recursively calls model.instantiate() on all ParameterizedConfigClass instances under
    the model, resulting in a new model which follows the same general structure as the
    given model, but possibly containing totally different properties and methods.

    Note that this method does a depth-first traversal of the model tree to instantiate
    leaf nodes first. Traversing breadth-first would result in new pydantic models
    attempting to call the __init__ method of child models, which is not valid because
    the child models are ParameterizedConfigClass instances. Traversing depth-first
    allows us to first transform child models into the appropriate type using the
    classname of the ParameterizedConfigClass.

    Plain dicts are handled the same way: a dict containing a "classname" key is
    replaced by an instance of that class, built with the dict's remaining items as
    keyword arguments.

    This method is primarily used to instantiate a Pipeline subclass and all of its
    properties from a yaml pipeline config file, but it can be applied to any other
    pydantic model.

    Args:
        model (Any): The object to recursively instantiate.

    Returns:
        Any: The recursively-instantiated object.

    ---------------------------------------------------------------------------------"""
    # Case: ParameterizedConfigClass. Want to instantiate any sub-models then return the
    # class with all submodels recursively instantiated, then statically instantiate the
    # model. Note: the model is instantiated last so that sub-models are only processed
    # once.
    if isinstance(model, ParameterizedConfigClass):
        fields = model.__fields_set__ - {"classname"}  # No point checking classname
        for field in fields:
            setattr(model, field, recursive_instantiate(getattr(model, field)))
        return model.instantiate()

    # Case: BaseModel. Want to instantiate any sub-models then return the model itself.
    elif isinstance(model, BaseModel):
        fields = model.__fields_set__
        # Internal invariant: only ParameterizedConfigClass models carry a classname.
        assert "classname" not in fields
        for field in fields:
            setattr(model, field, recursive_instantiate(getattr(model, field)))
        return model

    # Case: list. Want to iterate through and recursively instantiate all sub-models in
    # the list, then return everything as a list.
    elif isinstance(model, list):
        return [recursive_instantiate(m) for m in cast(List[Any], model)]

    # Case: dict. Want to iterate through and recursively instantiate all sub-models in
    # the dict's values, then return everything as a dict, unless the dict is meant to
    # be turned into a parameterized class, in which case we instantiate it as the
    # intended object.
    elif isinstance(model, dict):
        model = {
            k: recursive_instantiate(v) for k, v in cast(Dict[str, Any], model).items()
        }
        if "classname" in model:
            # ``model`` is the fresh dict built above, so popping does not touch the
            # caller's data.
            classname: str = model.pop("classname")  # type: ignore
            _cls = import_string(classname)
            return _cls(**model)
        return model

    # Base case: Anything else; just return the value
    return model
class _NamedClass(Protocol):
    # Structural type: any object exposing a string ``name`` attribute qualifies.
    name: str


def find_duplicates(entries: Sequence[_NamedClass]) -> List[str]:
    """Return the ``name`` of every entry whose name already appeared earlier.

    Names are reported in the order the repeats are encountered. A name that occurs
    ``k`` times appears ``k - 1`` times in the result. An empty input yields ``[]``.
    """
    already_seen: Set[str] = set()
    repeats: List[str] = []
    for item in entries:
        name = item.name
        if name not in already_seen:
            already_seen.add(name)
            continue
        repeats.append(name)
    return repeats
def read_yaml(filepath: Path) -> Dict[Any, Any]:
    """Parse a YAML file and return its first document.

    Every document in the file is parsed, so a syntax error anywhere in the file is
    still reported. Only the first document is returned.

    Args:
        filepath: Path to the YAML file. A ``str`` path is also accepted.

    Returns:
        The first YAML document as loaded by ``yaml.safe_load_all``. The ``Dict`` return
        annotation is not checked here; it is whatever the first document contains.

    Raises:
        IndexError: If the file contains no YAML documents (e.g. it is empty).
    """
    # Read as UTF-8 explicitly: YAML text is Unicode, and the default
    # ``read_text()`` encoding depends on the platform's locale.
    text = Path(filepath).read_text(encoding="utf-8")
    # ``safe_load_all`` only constructs standard YAML tags, never arbitrary Python objects.
    documents = list(yaml.safe_load_all(text))
    return documents[0]
def get_code_version() -> str:
    """Return a version string for the running code.

    Resolution order:
      1. The ``CODE_VERSION`` environment variable, if it is set (even to "").
      2. A SemVer-style version computed from git history via dunamai's
         ``Version.from_git().serialize(dirty=True, style=Style.SemVer)``.
      3. ``"N/A"``, after emitting a ``RuntimeWarning``, if step 2 raises
         ``RuntimeError``.
    """
    from_env = os.environ.get("CODE_VERSION")
    if from_env is not None:
        return from_env

    try:
        return Version.from_git().serialize(dirty=True, style=Style.SemVer)
    except RuntimeError:
        warnings.warn(
            "Could not get code_version from either the 'CODE_VERSION' environment"
            " variable nor from git history. The 'code_version' global attribute"
            " will be set to 'N/A'.",
            RuntimeWarning,
        )
        return "N/A"