import numpy as np
import xarray as xr
from pydantic import BaseModel, Extra, validator
from typing import Any, Dict, List, Optional, Union
from numpy.typing import NDArray
from .base import QualityChecker
# Public API of this module: the names exported by a star-import. Only the concrete
# checker classes are listed; the underscore-prefixed base classes defined in this
# file are intentionally left out.
__all__ = [
"CheckMissing",
"CheckMonotonic",
"CheckValidMin",
"CheckValidMax",
"CheckFailMin",
"CheckFailMax",
"CheckWarnMin",
"CheckWarnMax",
"CheckValidRangeMin",
"CheckValidRangeMax",
"CheckFailRangeMin",
"CheckFailRangeMax",
"CheckWarnRangeMin",
"CheckWarnRangeMax",
"CheckValidDelta",
"CheckFailDelta",
"CheckWarnDelta",
]
class CheckMissing(QualityChecker):
    """---------------------------------------------------------------------------------
    Checks if any data are missing. A variable's data are considered missing if they are
    set to the variable's _FillValue (if it has a _FillValue) or NaN (NaT for datetime-
    like variables). String variables that have no _FillValue attribute are also
    considered missing wherever they hold the empty string.

    ---------------------------------------------------------------------------------"""

    def run(self, dataset: xr.Dataset, variable_name: str) -> NDArray[np.bool_]:
        """Flag the missing values of one variable.

        Args:
            dataset (xr.Dataset): The dataset containing the variable to check.
            variable_name (str): The name of the variable to check.

        Returns:
            NDArray[np.bool_]: Boolean array that is True wherever the value is null,
            equal to the `_FillValue` attribute, or (string data without a
            `_FillValue`) equal to the empty string.
        """
        # `np.bool_` is used instead of the `np.bool8` alias, which no longer exists
        # in NumPy 2.0 and would make this annotation fail at definition time.
        variable = dataset[variable_name]
        results: NDArray[np.bool_] = variable.isnull().data

        if "_FillValue" in variable.attrs:
            results |= variable.data == variable.attrs["_FillValue"]
        elif np.issubdtype(variable.data.dtype, str):  # type: ignore
            # No explicit fill value: treat the empty string as the missing marker.
            results |= variable.data == ""

        return results
class CheckMonotonic(QualityChecker):
    """---------------------------------------------------------------------------------
    Checks if any values are not ordered strictly monotonically (i.e. values must all be
    increasing or all decreasing). The check marks all values as failed if any data values
    are not ordered monotonically.

    ---------------------------------------------------------------------------------"""

    class Parameters(BaseModel, extra=Extra.forbid):
        # If True, only strictly decreasing data passes the check.
        require_decreasing: bool = False
        # If True, only strictly increasing data passes the check.
        require_increasing: bool = False
        # Dimension along which to difference the data; the variable's first
        # dimension is used when this is not set.
        dim: Optional[str] = None

        @validator("require_increasing")
        @classmethod
        def check_monotonic_not_increasing_and_decreasing(
            cls, inc: bool, values: Dict[str, Any]
        ) -> bool:
            """Reject configurations that require both directions at once.

            Raises:
                ValueError: If both `require_increasing` and `require_decreasing`
                    are set.
            """
            # `.get` rather than indexing: a previously-declared field is left out of
            # `values` when its own validation failed, and indexing would then raise
            # a KeyError instead of letting the original validation error surface.
            if inc and values.get("require_decreasing", False):
                raise ValueError(
                    "CheckMonotonic -> Parameters: cannot set both 'require_increasing'"
                    " and 'require_decreasing'. Please set one or both to False."
                )
            return inc

    parameters: Parameters = Parameters()

    def run(self, dataset: xr.Dataset, variable_name: str) -> NDArray[np.bool_]:
        """Flag every value of the variable if it is not strictly monotonic.

        Args:
            dataset (xr.Dataset): The dataset containing the variable to check.
            variable_name (str): The name of the variable to check.

        Returns:
            NDArray[np.bool_]: Array with the variable's shape, filled entirely with
            True when the ordering requirement is violated and entirely with False
            otherwise.
        """
        variable = dataset[variable_name]
        axis = self.get_axis(variable)
        diff: NDArray[Any] = np.diff(variable.data, axis=axis)  # type: ignore

        # Differences of datetime-like data are timedeltas, so compare them against a
        # timedelta zero. `np.issubdtype` takes a single dtype as its second argument,
        # not a tuple of alternatives, so each kind is tested on its own.
        dtype = variable.data.dtype
        zero: Any = 0
        if np.issubdtype(dtype, np.datetime64) or np.issubdtype(dtype, np.timedelta64):
            zero = np.timedelta64(0)

        increasing = bool(np.all(diff > zero))
        decreasing = bool(np.all(diff < zero))

        if self.parameters.require_increasing:
            is_monotonic = increasing
        elif self.parameters.require_decreasing:
            is_monotonic = decreasing
        else:
            is_monotonic = increasing or decreasing

        # `np.bool_` replaces the `np.bool8` alias, which was removed in NumPy 2.0.
        return np.full(variable.shape, not is_monotonic, dtype=np.bool_)

    def get_axis(self, variable: xr.DataArray) -> int:
        """Return the axis number of the dimension to difference along.

        Uses `parameters.dim` when it is set, otherwise the variable's first
        dimension.
        """
        if not (dim := self.parameters.dim):
            dim = variable.dims[0]
        return variable.get_axis_num(dim)  # type: ignore
class _ThresholdChecker(QualityChecker):
    """---------------------------------------------------------------------------------
    Base class for threshold-based classes where the threshold value is stored in a
    variable attribute.

    Args:
        attribute_name (str): The name of the variable attribute containing the
            threshold. If the attribute holds a sequence of values (e.g. a
            '*_range' attribute) then the first value is used as the minimum
            threshold and the last value as the maximum threshold.
        allow_equal (bool): True if values equal to the threshold should pass the check,
            False otherwise.

    ---------------------------------------------------------------------------------"""

    allow_equal: bool = True
    """True if values equal to the threshold should pass, False otherwise."""

    attribute_name: str
    """The attribute on the data variable that should be used to get the threshold."""

    def _get_threshold(
        self, dataset: xr.Dataset, variable_name: str, min_: bool
    ) -> Optional[float]:
        """Look up the threshold stored on the variable's attributes.

        Args:
            dataset (xr.Dataset): The dataset containing the variable.
            variable_name (str): The name of the variable whose attributes are read.
            min_ (bool): When the attribute holds a sequence, True selects its first
                entry (minimum) and False selects its last entry (maximum).

        Returns:
            Optional[float]: The threshold, or None if the variable does not have the
            attribute.
        """
        threshold: Any = dataset[variable_name].attrs.get(self.attribute_name)
        if threshold is None:
            return None

        # Attributes may hold a list, a tuple, or a NumPy array (multi-valued
        # attributes read from a file commonly arrive as arrays), so unpack all of
        # them. Zero-dimensional arrays are scalars and cannot be indexed.
        is_sequence = isinstance(threshold, (list, tuple, np.ndarray))
        if is_sequence and np.ndim(threshold) > 0:
            threshold = threshold[0 if min_ else -1]
        return threshold
class _CheckMin(_ThresholdChecker):
    """---------------------------------------------------------------------------------
    Checks that no values for the specified variable are less than a specified minimum
    threshold. The value of the threshold is specified by an attribute on each data
    variable, and the attribute to search for is given by the `attribute_name` property
    of the subclass.

    If the specified attribute does not exist on the variable being checked then no
    failures will be reported.

    Args:
        attribute_name (str): The name of the attribute containing the minimum
            threshold. If the attribute ends in '_range' then it is assumed to be a list,
            and the first value from the list will be used as the minimum threshold.
        allow_equal (bool): True if values equal to the threshold should pass the check,
            False otherwise.

    ---------------------------------------------------------------------------------"""

    def run(self, dataset: xr.Dataset, variable_name: str) -> NDArray[np.bool_]:
        """Flag values that fall below the minimum threshold.

        Args:
            dataset (xr.Dataset): The dataset containing the variable to check.
            variable_name (str): The name of the variable to check.

        Returns:
            NDArray[np.bool_]: Boolean array that is True where the value is below the
            threshold (or equal to it when `allow_equal` is False). All False when the
            variable has no threshold attribute.
        """
        var_data = dataset[variable_name]
        min_value = self._get_threshold(dataset, variable_name, min_=True)

        if min_value is None:
            # No threshold configured for this variable: nothing can fail.
            # `np.bool_` replaces the `np.bool8` alias removed in NumPy 2.0.
            return np.zeros(var_data.shape, dtype=np.bool_)

        # Strict comparison lets values equal to the threshold pass.
        compare = np.less if self.allow_equal else np.less_equal
        return compare(var_data.data, min_value)
class _CheckMax(_ThresholdChecker):
    """---------------------------------------------------------------------------------
    Checks that no values for the specified variable are greater than a specified
    maximum threshold. The value of the threshold is specified by an attribute on each
    data variable, and the attribute to search for is given by the `attribute_name`
    property of the subclass.

    If the specified attribute does not exist on the variable being checked then no
    failures will be reported.

    Args:
        attribute_name (str): The name of the attribute containing the maximum
            threshold. If the attribute ends in '_range' then it is assumed to be a list,
            and the last value from the list will be used as the maximum threshold.
        allow_equal (bool): True if values equal to the threshold should pass the check,
            False otherwise.

    ---------------------------------------------------------------------------------"""

    def run(self, dataset: xr.Dataset, variable_name: str) -> NDArray[np.bool_]:
        """Flag values that exceed the maximum threshold.

        Args:
            dataset (xr.Dataset): The dataset containing the variable to check.
            variable_name (str): The name of the variable to check.

        Returns:
            NDArray[np.bool_]: Boolean array that is True where the value is above the
            threshold (or equal to it when `allow_equal` is False). All False when the
            variable has no threshold attribute.
        """
        var_data = dataset[variable_name]
        max_value = self._get_threshold(dataset, variable_name, min_=False)

        if max_value is None:
            # No threshold configured for this variable: nothing can fail.
            # `np.bool_` replaces the `np.bool8` alias removed in NumPy 2.0.
            return np.zeros(var_data.shape, dtype=np.bool_)

        # Strict comparison lets values equal to the threshold pass.
        compare = np.greater if self.allow_equal else np.greater_equal
        return compare(var_data.data, max_value)
class CheckValidMin(_CheckMin):
    """---------------------------------------------------------------------------------
    Minimum-threshold check that reads its threshold from the variable's `valid_min`
    attribute.

    ---------------------------------------------------------------------------------"""

    attribute_name: str = "valid_min"
class CheckValidMax(_CheckMax):
    """---------------------------------------------------------------------------------
    Maximum-threshold check that reads its threshold from the variable's `valid_max`
    attribute.

    ---------------------------------------------------------------------------------"""

    attribute_name: str = "valid_max"
class CheckFailMin(_CheckMin):
    """---------------------------------------------------------------------------------
    Minimum-threshold check that reads its threshold from the variable's `fail_min`
    attribute.

    ---------------------------------------------------------------------------------"""

    attribute_name: str = "fail_min"
class CheckFailMax(_CheckMax):
    """---------------------------------------------------------------------------------
    Maximum-threshold check that reads its threshold from the variable's `fail_max`
    attribute.

    ---------------------------------------------------------------------------------"""

    attribute_name: str = "fail_max"
class CheckWarnMin(_CheckMin):
    """---------------------------------------------------------------------------------
    Minimum-threshold check that reads its threshold from the variable's `warn_min`
    attribute.

    ---------------------------------------------------------------------------------"""

    attribute_name: str = "warn_min"
class CheckWarnMax(_CheckMax):
    """---------------------------------------------------------------------------------
    Maximum-threshold check that reads its threshold from the variable's `warn_max`
    attribute.

    ---------------------------------------------------------------------------------"""

    attribute_name: str = "warn_max"
class CheckValidRangeMin(_CheckMin):
    """---------------------------------------------------------------------------------
    Minimum-threshold check that takes its threshold from the lower bound (first entry)
    of the variable's `valid_range` attribute.

    ---------------------------------------------------------------------------------"""

    attribute_name: str = "valid_range"
class CheckValidRangeMax(_CheckMax):
    """---------------------------------------------------------------------------------
    Maximum-threshold check that takes its threshold from the upper bound (last entry)
    of the variable's `valid_range` attribute.

    ---------------------------------------------------------------------------------"""

    attribute_name: str = "valid_range"
class CheckFailRangeMin(_CheckMin):
    """---------------------------------------------------------------------------------
    Minimum-threshold check that takes its threshold from the lower bound (first entry)
    of the variable's `fail_range` attribute.

    ---------------------------------------------------------------------------------"""

    attribute_name: str = "fail_range"
class CheckFailRangeMax(_CheckMax):
    """---------------------------------------------------------------------------------
    Maximum-threshold check that takes its threshold from the upper bound (last entry)
    of the variable's `fail_range` attribute.

    ---------------------------------------------------------------------------------"""

    attribute_name: str = "fail_range"
class CheckWarnRangeMin(_CheckMin):
    """---------------------------------------------------------------------------------
    Minimum-threshold check that takes its threshold from the lower bound (first entry)
    of the variable's `warn_range` attribute.

    ---------------------------------------------------------------------------------"""

    attribute_name: str = "warn_range"
class CheckWarnRangeMax(_CheckMax):
    """---------------------------------------------------------------------------------
    Maximum-threshold check that takes its threshold from the upper bound (last entry)
    of the variable's `warn_range` attribute.

    ---------------------------------------------------------------------------------"""

    attribute_name: str = "warn_range"
class _CheckDelta(_ThresholdChecker):
    """------------------------------------------------------------------------------------
    Checks the difference between consecutive values and reports a failure if the
    absolute difference is greater than the threshold specified by the value in the
    attribute provided to this check.

    If the specified attribute does not exist on the variable being checked then no
    failures will be reported.

    Args:
        attribute_name (str): The name of the attribute containing the threshold to use.
        allow_equal (bool): True if differences equal to the threshold should pass the
            check, False otherwise.

    ------------------------------------------------------------------------------------"""

    class Parameters(BaseModel, extra=Extra.forbid):
        dim: str = "time"
        """The dimension on which to perform the diff."""

    parameters: Parameters = Parameters()

    def run(self, dataset: xr.Dataset, variable_name: str) -> NDArray[np.bool_]:
        """Flag values that jump too far from the previous value along `dim`.

        Args:
            dataset (xr.Dataset): The dataset containing the variable to check.
            variable_name (str): The name of the variable to check.

        Returns:
            NDArray[np.bool_]: Boolean array with the variable's shape, True where the
            absolute difference from the preceding value exceeds the threshold (or
            equals it when `allow_equal` is False). The first value along `dim` is
            compared with itself. All False when the variable has no threshold
            attribute.
        """
        variable = dataset[variable_name]
        data: NDArray[Any] = variable.data

        threshold = self._get_threshold(dataset, variable_name, True)
        if threshold is None:
            # No threshold configured: report no failures rather than comparing
            # against None, which raises a TypeError.
            return np.zeros(data.shape, dtype=np.bool_)

        axis = variable.get_axis_num(self.parameters.dim)

        # Prepend the first slice *along the diff axis* with its dimensions kept so
        # the output has the input's shape. `data[0]` only has a compatible shape for
        # one-dimensional data differenced along axis 0.
        selector = [slice(None)] * data.ndim
        selector[axis] = slice(0, 1)
        first = data[tuple(selector)]

        diff: NDArray[Any] = np.absolute(np.diff(data, axis=axis, prepend=first))  # type: ignore
        return diff > threshold if self.allow_equal else diff >= threshold
class CheckValidDelta(_CheckDelta):
    """---------------------------------------------------------------------------------
    Consecutive-difference check that reads its threshold from the variable's
    `valid_delta` attribute.

    ---------------------------------------------------------------------------------"""

    attribute_name: str = "valid_delta"
class CheckFailDelta(_CheckDelta):
    """---------------------------------------------------------------------------------
    Consecutive-difference check that reads its threshold from the variable's
    `fail_delta` attribute.

    ---------------------------------------------------------------------------------"""

    attribute_name: str = "fail_delta"
class CheckWarnDelta(_CheckDelta):
    """---------------------------------------------------------------------------------
    Consecutive-difference check that reads its threshold from the variable's
    `warn_delta` attribute.

    ---------------------------------------------------------------------------------"""

    attribute_name: str = "warn_delta"
# check_outlier(std_dev)
# check_time_gap --> parameters: min_time_gap (str), max_time_gap (str)