Source code for tsdat.testing

"""-------------------------------------------------------------------------------------
Utility functions used in tsdat tests.

-------------------------------------------------------------------------------------"""

import xarray as xr
from typing import Any, Hashable, List

__all__ = [
    # "compare",
    "assert_close",
]


def get_pydantic_error_message(error: Any) -> str:
    return error.getrepr().reprcrash.message


def get_pydantic_warning_message(warning: Any) -> str:
    warnings: List[str] = [_warning.message.args[0] for _warning in warning.list]
    return "\n".join(warnings)


# def compare(*model_dicts: Any):
#     """---------------------------------------------------------------------------------
#     Method used to compare dictionaries side-by-side in the terminal. Primarily useful
#     for debugging.

#     ---------------------------------------------------------------------------------"""
#     # IDEA: highlight differences

#     from rich.console import Console
#     from rich.columns import Columns
#     from rich.pretty import Pretty
#     from rich.panel import Panel

#     console = Console()
#     renderables: List[Panel] = [Panel(Pretty(model_dict)) for model_dict in model_dicts]
#     console.print(Columns(renderables, equal=True, expand=True))


[docs]def assert_close( a: xr.Dataset, b: xr.Dataset, check_attrs: bool = True, check_fill_value: bool = True, **kwargs: Any, ) -> None: """--------------------------------------------------------------------------------- Thin wrapper around xarray.assert_allclose. Also checks dataset and variable attrs. Removes global attributes that are allowed to be different, which are currently just the 'history' attribute and the 'code_version' attribute. Also handles some obscure edge cases for variable attributes. Args: a (xr.Dataset): The first dataset to compare. b (xr.Dataset): The second dataset to compare. check_attrs (bool): Check global and variable attributes in addition to the data. Defaults to True. check_fill_value (bool): Check the _FillValue attribute. This is a special case because xarray moves the _FillValue from a variable's attributes to its encoding upon saving the dataset. Defaults to True. ---------------------------------------------------------------------------------""" a, b = a.copy(), b.copy() # type: ignore xr.testing.assert_allclose(a, b, **kwargs) # type: ignore if check_attrs: _check_global_attrs(a, b) _check_variable_attrs(a, b, check_fill_value)
def _check_global_attrs(a: xr.Dataset, b: xr.Dataset) -> None: _drop_incomparable_global_attrs(a) _drop_incomparable_global_attrs(b) assert a.attrs == b.attrs, f"global attributes do not match:\n{a.attrs}\n{b.attrs}" def _check_variable_attrs(a: xr.Dataset, b: xr.Dataset, check_fill_value: bool) -> None: _drop_incomparable_variable_attrs(a) _drop_incomparable_variable_attrs(b) for var_name in a.variables: _check_var_attrs(a, b, var_name, check_fill_value) def _check_var_attrs( a: xr.Dataset, b: xr.Dataset, var_name: Hashable, check_fill_value: bool ) -> None: if check_fill_value: _check_fillvalue(a[var_name], b[var_name]) a_attrs, b_attrs = a[var_name].attrs, b[var_name].attrs a_attrs.pop("_FillValue", None), b_attrs.pop("_FillValue", None) assert ( a_attrs == b_attrs ), f"'{var_name}' attributes do not match:\n{a_attrs},\n{b_attrs}" def _drop_incomparable_global_attrs(ds: xr.Dataset): ds.attrs.pop("history", None) ds.attrs.pop("code_version", None) def _drop_incomparable_variable_attrs(ds: xr.Dataset): if "time" in ds.variables: ds["time"].attrs.pop("units", None) def _check_fillvalue(a: xr.DataArray, b: xr.DataArray) -> None: def _is_nan(num: Any) -> bool: return num != num nan = float("nan") a_fill = a.attrs.get("_FillValue", nan) or a.encoding.get("_FillValue", nan) b_fill = b.attrs.get("_FillValue", nan) or b.encoding.get("_FillValue", nan) assert (a_fill == b_fill) or ( _is_nan(a_fill) and _is_nan(b_fill) ), f"'{a.name}' _FillValue attrs/encoding does not match:\n{a_fill},\n{b_fill}"