Source code for tsdat.testing

"""-------------------------------------------------------------------------------------
Utility functions used in tsdat tests.

-------------------------------------------------------------------------------------"""

import numpy as np
import xarray as xr
from typing import Any, Hashable, List

__all__ = [
    # "compare",
    "assert_close",
]


def get_pydantic_error_message(error: Any) -> str:
    return error.getrepr().reprcrash.message


def get_pydantic_warning_message(warning: Any) -> str:
    warnings: List[str] = [_warning.message.args[0] for _warning in warning.list]
    return "\n".join(warnings)


# def compare(*model_dicts: Any):
#     """---------------------------------------------------------------------------------
#     Method used to compare dictionaries side-by-side in the terminal. Primarily useful
#     for debugging.

#     ---------------------------------------------------------------------------------"""
#     # IDEA: highlight differences

#     from rich.console import Console
#     from rich.columns import Columns
#     from rich.pretty import Pretty
#     from rich.panel import Panel

#     console = Console()
#     renderables: List[Panel] = [Panel(Pretty(model_dict)) for model_dict in model_dicts]
#     console.print(Columns(renderables, equal=True, expand=True))


[docs]def assert_close( a: xr.Dataset, b: xr.Dataset, check_attrs: bool = True, check_fill_value: bool = True, **kwargs: Any, ) -> None: """--------------------------------------------------------------------------------- Thin wrapper around xarray.assert_allclose. Also checks dataset and variable attrs. Removes global attributes that are allowed to be different, which are currently just the 'history' attribute and the 'code_version' attribute. Also handles some obscure edge cases for variable attributes. Args: a (xr.Dataset): The first dataset to compare. b (xr.Dataset): The second dataset to compare. check_attrs (bool): Check global and variable attributes in addition to the data. Defaults to True. check_fill_value (bool): Check the _FillValue attribute. This is a special case because xarray moves the _FillValue from a variable's attributes to its encoding upon saving the dataset. Defaults to True. ---------------------------------------------------------------------------------""" a, b = a.copy(), b.copy() # type: ignore _convert_time(a, b) xr.testing.assert_allclose(a, b, **kwargs) # type: ignore if check_attrs: _check_global_attrs(a, b) _check_variable_attrs(a, b, check_fill_value)
def _convert_time(a: xr.Dataset, b: xr.Dataset) -> None: # Converts datetime64 to seconds since 1970 for v in a.variables: if np.issubdtype(a[v].dtype, np.datetime64): # type: ignore a[v] = a[v].astype("datetime64[ns]").astype("float") / 1e9 # type: ignore b[v] = b[v].astype("datetime64[ns]").astype("float") / 1e9 # type: ignore def _check_global_attrs(a: xr.Dataset, b: xr.Dataset) -> None: _drop_incomparable_global_attrs(a) _drop_incomparable_global_attrs(b) if a.attrs != b.attrs: raise AssertionError(f"global attributes do not match:\n{a.attrs}\n{b.attrs}") def _check_variable_attrs(a: xr.Dataset, b: xr.Dataset, check_fill_value: bool) -> None: _drop_incomparable_variable_attrs(a) _drop_incomparable_variable_attrs(b) for var_name in a.variables: _check_var_attrs(a, b, var_name, check_fill_value) def _check_var_attrs( a: xr.Dataset, b: xr.Dataset, var_name: Hashable, check_fill_value: bool ) -> None: if check_fill_value: _check_fillvalue(a[var_name], b[var_name]) a_attrs, b_attrs = a[var_name].attrs, b[var_name].attrs a_attrs.pop("_FillValue", None), b_attrs.pop("_FillValue", None) if a.attrs != b.attrs: raise AssertionError(f"attributes do not match:\n{a_attrs}\n{b_attrs}") def _drop_incomparable_global_attrs(ds: xr.Dataset): ds.attrs.pop("history", None) ds.attrs.pop("code_version", None) def _drop_incomparable_variable_attrs(ds: xr.Dataset): if "time" in ds.variables: ds["time"].attrs.pop("units", None) def _check_fillvalue(a: xr.DataArray, b: xr.DataArray) -> None: a_fill = a.attrs.get("_FillValue") or a.encoding.get("_FillValue") b_fill = b.attrs.get("_FillValue") or b.encoding.get("_FillValue") if not (a_fill == b_fill or (np.isnan(a_fill) and np.isnan(b_fill))): raise AssertionError( f"'{a.name}' _FillValue attrs/encoding do not match:\n" f"{a_fill},\n" f"{b_fill}" )