# TODO: Retrieval from S3; another retriever class, or parameters on the default?
# IDEA: Implement MultiDatastreamRetriever & variable finders
import logging
import xarray as xr
from pydantic import BaseModel, Extra
from typing import Any, Dict, List, Pattern, cast
from ..config.dataset import DatasetConfig
from .base import Retriever, DataReader, DataConverter
__all__ = ["DefaultRetriever"]
logger = logging.getLogger(__name__)
class RetrievedVariable(BaseModel, extra=Extra.forbid):
name: str
data_converters: List[DataConverter] = []
class InputKeyRetrieverConfig:
def __init__(self, input_key: str, retriever: "DefaultRetriever") -> None:
self.coords: Dict[str, RetrievedVariable] = {}
self.data_vars: Dict[str, RetrievedVariable] = {}
def update_mapping(
to_update: Dict[str, RetrievedVariable],
variable_dict: Dict[str, Dict[Pattern[str], RetrievedVariable]],
):
for name, retriever_dict in variable_dict.items():
for pattern, variable_retriever in retriever_dict.items():
if pattern.match(input_key):
to_update[name] = variable_retriever
break
update_mapping(self.coords, retriever.coords) # type: ignore
update_mapping(self.data_vars, retriever.data_vars) # type: ignore
[docs]class DefaultRetriever(Retriever):
"""------------------------------------------------------------------------------------
Default API for retrieving data from one or more input sources.
Reads data from one or more inputs, renames coordinates and data variables according
to retrieval and dataset configurations, and applies registered DataConverters to
retrieved data.
Args:
readers (Dict[Pattern[str], DataReader]): A mapping of patterns to DataReaders
that the retriever uses to determine which DataReader to use for reading any
given input key.
coords (Dict[str, Dict[Pattern[str], VariableRetriever]]): A dictionary mapping
output coordinate variable names to rules for how they should be retrieved.
data_vars (Dict[str, Dict[Pattern[str], VariableRetriever]]): A dictionary
mapping output data variable names to rules for how they should be
retrieved.
------------------------------------------------------------------------------------"""
[docs] class Parameters(BaseModel, extra=Extra.forbid):
[docs] merge_kwargs: Dict[str, Any] = {}
"""Keyword arguments passed to xr.merge(). This is only relevant if multiple
input keys are provided simultaneously, or if any registered DataReader objects
could return a dataset mapping instead of a single dataset."""
# IDEA: option to disable retrieval of input attrs
# retain_global_attrs: bool = True
# retain_variable_attrs: bool = True
[docs] parameters: Parameters = Parameters()
[docs] readers: Dict[Pattern, DataReader] # type: ignore
"""A dictionary of DataReaders that should be used to read data provided an input
key."""
[docs] coords: Dict[str, Dict[Pattern, RetrievedVariable]] # type: ignore
"""A dictionary mapping output coordinate names to the retrieval rules and
preprocessing actions (e.g., DataConverters) that should be applied to each retrieved
coordinate variable."""
[docs] data_vars: Dict[str, Dict[Pattern, RetrievedVariable]] # type: ignore
"""A dictionary mapping output data variable names to the retrieval rules and
preprocessing actions (e.g., DataConverters) that should be applied to each
retrieved data variable."""
[docs] def retrieve(
self, input_keys: List[str], dataset_config: DatasetConfig, **kwargs: Any
) -> xr.Dataset:
raw_mapping = self._get_raw_mapping(input_keys)
dataset_mapping: Dict[str, xr.Dataset] = {}
for key, dataset in raw_mapping.items():
input_config = InputKeyRetrieverConfig(key, self)
dataset = self._rename_variables(dataset, input_config)
dataset = self._reindex_dataset_coords(
dataset, dataset_config, input_config
)
dataset = self._run_data_converters(dataset, dataset_config, input_config)
dataset_mapping[key] = dataset
output_dataset = self._merge_raw_mapping(dataset_mapping)
return output_dataset
def _get_raw_mapping(self, input_keys: List[str]) -> Dict[str, xr.Dataset]:
dataset_mapping: Dict[str, xr.Dataset] = {}
input_reader_mapping = self._match_inputs(input_keys)
for input_key, reader in input_reader_mapping.items(): # IDEA: async
logger.debug("Using %s to read input_key '%s'", reader, input_key)
data = reader.read(input_key)
if isinstance(data, xr.Dataset):
data = {input_key: data}
dataset_mapping.update(data)
return dataset_mapping
def _match_inputs(self, input_keys: List[str]) -> Dict[str, DataReader]:
input_reader_mapping: Dict[str, DataReader] = {}
for input_key in input_keys:
for regex, reader in self.readers.items(): # type: ignore
regex = cast(Pattern[str], regex)
if regex.match(input_key):
input_reader_mapping[input_key] = reader
break
return input_reader_mapping
def _rename_variables(
self,
dataset: xr.Dataset,
input_config: InputKeyRetrieverConfig,
) -> xr.Dataset:
"""-----------------------------------------------------------------------------
Renames variables in the retrieved dataset according to retrieval configurations.
Args:
raw_dataset (xr.Dataset): The raw dataset.
Returns:
xr.Dataset: The simplified raw dataset.
-----------------------------------------------------------------------------"""
to_rename: Dict[str, str] = {} # {raw_name: output_name}
coords_to_rename = {
c.name: output_name for output_name, c in input_config.coords.items()
}
vars_to_rename = {
v.name: output_name for output_name, v in input_config.data_vars.items()
}
to_rename.update(coords_to_rename)
to_rename.update(vars_to_rename)
for raw_name, output_name in coords_to_rename.items():
if raw_name not in dataset:
to_rename.pop(raw_name)
logger.warning(
"Coordinate variable '%s' could not be retrieved from input. Please"
" ensure the retrieval configuration file for the '%s' coord has"
" the 'name' property set to the exact name of the variable in the"
" dataset returned by the input DataReader.",
raw_name,
output_name,
)
for raw_name, output_name in vars_to_rename.items():
if raw_name not in dataset:
to_rename.pop(raw_name)
logger.warning(
"Data variable '%s' could not be retrieved from input. Please"
" ensure the retrieval configuration file for the '%s' data"
" variable has the 'name' property set to the exact name of the"
" variable in the dataset returned by the input DataReader.",
raw_name,
output_name,
)
return dataset.rename(to_rename)
def _run_data_converters(
self,
dataset: xr.Dataset,
dataset_config: DatasetConfig,
input_config: InputKeyRetrieverConfig,
) -> xr.Dataset:
"""------------------------------------------------------------------------------------
Runs the declared DataConverters on the dataset's coords and data_vars.
Returns the dataset after all converters have been run.
Args:
dataset (xr.Dataset): The dataset to convert.
dataset_config (DatasetConfig): The DatasetConfig
Returns:
xr.Dataset: The converted dataset.
------------------------------------------------------------------------------------"""
for coord_name, coord_config in input_config.coords.items():
for converter in coord_config.data_converters:
dataset = converter.convert(dataset, dataset_config, coord_name)
for var_name, var_config in input_config.data_vars.items():
for converter in var_config.data_converters:
dataset = converter.convert(dataset, dataset_config, var_name)
return dataset
def _reindex_dataset_coords(
self,
dataset: xr.Dataset,
dataset_config: DatasetConfig,
input_config: InputKeyRetrieverConfig,
) -> xr.Dataset:
"""-----------------------------------------------------------------------------
Swaps dimensions and coordinates to match the structure of the DatasetConfig.
Ensures that the retriever coordinates are set as coordinates in the dataset,
promoting them to coordinates from data_vars as needed, and reindexes data_vars
so they are dimensioned by the appropriate coordinates.
This is useful in situations where the DataReader does not know which variables
to set as coordinates in its returned xr.Dataset, so it instead creates some
arbitrary index coordinate to dimension the data variables. This is very common
when reading from non-heirarchal formats such as csv.
Args:
dataset (xr.Dataset): The dataset to reindex.
dataset_config (DatasetConfig): The DatasetConfig.
Returns:
xr.Dataset: The reindexed dataset.
-----------------------------------------------------------------------------"""
for coord_name in input_config.coords:
expected_dim = dataset_config[coord_name].dims[0]
actual_dims = dataset[coord_name].dims
if (ndims := len(actual_dims)) != 1:
raise ValueError(
f"Retrieved coordinate '{coord_name}' must have exactly one"
f" dimension in the retrieved dataset, found {ndims} (dims="
f"{actual_dims}). If '{coord_name}' is not actually a coordinate"
" variable, please move it to the data_vars section in the"
" retriever config file."
)
dim = actual_dims[0]
if dim != expected_dim:
dataset = dataset.swap_dims({dim: expected_dim}) # type: ignore
return dataset
def _merge_raw_mapping(self, raw_mapping: Dict[str, xr.Dataset]) -> xr.Dataset:
return xr.merge(list(raw_mapping.values()), **self.parameters.merge_kwargs) # type: ignore