Source code for clouddrift.raggedarray

"""
This module defines the RaggedArray class, which is the intermediate data
structure used by CloudDrift to process custom Lagrangian datasets to Xarray
Datasets and Awkward Arrays.
"""

from __future__ import annotations

import warnings
from collections.abc import Callable
from typing import Any, Literal

import awkward as ak  # type: ignore
import numpy as np
import xarray as xr
from tqdm import tqdm

from clouddrift.ragged import rowsize_to_index

# Allowed dimension aliases in a ragged array: "rows" (one value per
# trajectory/row) and "obs" (one value per observation across all rows).
DimNames = Literal["rows", "obs"]
_DISABLE_SHOW_PROGRESS = False  # purely to de-noise our test suite output, should never be used/configured outside of that.


class RaggedArray:
    """Intermediate ragged-array data structure used by CloudDrift to convert
    custom Lagrangian datasets to Xarray Datasets and Awkward Arrays."""

    def __init__(
        self,
        coords: dict,
        metadata: dict,
        data: dict,
        attrs_global: dict | None = None,
        attrs_variables: dict | None = None,
        name_dims: dict[str, DimNames] | None = None,
        coord_dims: dict[str, str] | None = None,
    ):
        """Instantiate a RaggedArray from its components.

        Parameters
        ----------
        coords : dict
            Coordinate variables of the ragged array.
        metadata : dict
            Metadata variables (one value per row).
        data : dict
            Data variables (one value per observation).
        attrs_global : dict, optional
            Global attributes (default empty).
        attrs_variables : dict, optional
            Per-variable attributes (default empty).
        name_dims : dict, optional
            Map a dimension alias to "rows" or "obs" (default empty).
        coord_dims : dict, optional
            Map a coordinate name to a dimension alias (default empty).
        """
        # NOTE: defaults are None rather than {} so that a shared mutable
        # default dict is never mutated across instances (validate_attributes()
        # writes into attrs_variables).
        self.coords = coords
        self.coord_dims = {} if coord_dims is None else coord_dims
        self.metadata = metadata
        self.data = data
        self.attrs_global = {} if attrs_global is None else attrs_global
        self.attrs_variables = {} if attrs_variables is None else attrs_variables
        self.name_dims = {} if name_dims is None else name_dims
        # Internal alias of coord_dims (same object), used by to_xarray/to_awkward.
        self._coord_dims = self.coord_dims
        self.validate_attributes()
[docs] @classmethod def from_awkward( cls, array: ak.Array, name_coords: list, name_dims: dict[str, DimNames], coord_dims: dict[str, str], ): """Load a RaggedArray instance from an Awkward Array. Parameters ---------- array : ak.Array Awkward Array instance to load the data from name_coords : list, optional Names of the coordinate variables in the ragged arrays name_dims: dict Map a dimension to an alias. coord_dims: dict Map a coordinate to a dimension alias. Returns ------- RaggedArray A RaggedArray instance """ coords: dict[str, Any] = {} metadata = {} data = {} attrs_variables = {} attrs_global = array.layout.parameters["attrs"] for var in name_coords: alias = coord_dims[var] if name_dims[alias] == "obs": coords[var] = ak.flatten(array.obs[var]).to_numpy() else: coords[var] = array.obs[var].to_numpy() attrs_variables[var] = array.obs[var].layout.parameters["attrs"] for var in [v for v in array.fields if v != "obs"]: metadata[var] = array[var].to_numpy() attrs_variables[var] = array[var].layout.parameters["attrs"] for var in [v for v in array.obs.fields if v not in coords.keys()]: data[var] = ak.flatten(array.obs[var]).to_numpy() attrs_variables[var] = array.obs[var].layout.parameters["attrs"] return RaggedArray( coords, metadata, data, attrs_global, attrs_variables, name_dims, coord_dims )
[docs] @classmethod def from_files( cls, indices: list[int], preprocess_func: Callable[[int], xr.Dataset], name_coords: list, name_meta: list = list(), name_data: list = list(), name_dims: dict[str, DimNames] = {}, rowsize_func: Callable[[int], int] | None = None, attrs_global: dict | None = None, attrs_variables: dict | None = None, **kwargs, ): """Generate a ragged array archive from a list of files Parameters ---------- indices : list Identification numbers list to iterate preprocess_func : Callable[[int], xr.Dataset] Returns a processed xarray Dataset from an identification number name_meta : list, optional Name of metadata variables to include in the archive (Defaults to []) name_data : list, optional Name of the data variables to include in the archive (Defaults to []) name_dims: dict Map an alias to a dimension. rowsize_func : Optional[Callable[[int], int]], optional Returns the number of observations from an identification number (to speed up processing) (Defaults to None) Returns ------- RaggedArray A RaggedArray instance """ # if no method is supplied, get the dimension from the preprocessing function rowsize_func = ( rowsize_func if rowsize_func else lambda i, **kwargs: preprocess_func(i, **kwargs).sizes["obs"] ) rowsize = cls.number_of_observations(rowsize_func, indices, **kwargs) coords, metadata, data, coord_dims = cls.allocate( preprocess_func, indices, rowsize, name_coords, name_meta, name_data, name_dims, **kwargs, ) extracted_attrs_global, extracted_attrs_variables = cls.attributes( preprocess_func(indices[0], **kwargs), name_coords, name_meta, name_data, ) if attrs_global is None: attrs_global = extracted_attrs_global if attrs_variables is None: attrs_variables = extracted_attrs_variables return RaggedArray( coords, metadata, data, attrs_global, attrs_variables, name_dims, coord_dims )
[docs] @classmethod def from_netcdf(cls, filename: str, rows_dim_name="rows", obs_dim_name="obs"): """Read a ragged arrays archive from a NetCDF file. This is a thin wrapper around ``from_xarray()``. Parameters ---------- filename : str File name of the NetCDF archive to read. Returns ------- RaggedArray A ragged array instance """ return cls.from_xarray(xr.open_dataset(filename), rows_dim_name, obs_dim_name)
[docs] @classmethod def from_parquet( cls, filename: str, name_coords: list, name_dims: dict[str, DimNames], coord_dims: dict[str, str], ): """Read a ragged array from a parquet file. Parameters ---------- filename : str File name of the parquet archive to read. name_coords : list, optional Names of the coordinate variables in the ragged arrays name_dims: dict Map a alias to a dimension. coord_dims: dict Map a coordinate to a dimension alias. Returns ------- RaggedArray A ragged array instance """ return RaggedArray.from_awkward( ak.from_parquet(filename), name_coords, name_dims, coord_dims )
[docs] @classmethod def from_xarray( cls, ds: xr.Dataset, rows_dim_name: str = "rows", obs_dim_name: str = "obs" ): """Populate a RaggedArray instance from an xarray Dataset instance. Parameters ---------- ds : xr.Dataset Xarray Dataset from which to load the RaggedArray rows_dim_name : str, optional Name of the row dimension in the xarray Dataset obs_dim_name : str, optional Name of the observations dimension in the xarray Dataset Returns ------- RaggedArray A RaggedArray instance """ coords = {} metadata = {} data = {} coord_dims = {} name_dims: dict[str, DimNames] = {rows_dim_name: "rows", obs_dim_name: "obs"} attrs_global = {} attrs_variables = {} attrs_global = ds.attrs for var in ds.coords.keys(): var = str(var) dim = ds[var].dims[-1] coord_dims[var] = str(dim) coords[var] = ds[var].data attrs_variables[var] = ds[var].attrs for var in ds.data_vars.keys(): if len(ds[var]) == ds.sizes.get(rows_dim_name): metadata[var] = ds[var].data elif len(ds[var]) == ds.sizes.get(obs_dim_name): data[var] = ds[var].data else: warnings.warn( f""" Variable '{var}' has unknown dimension size of {len(ds[var])}, which is not rows={ds.sizes.get(rows_dim_name)} or obs={ds.sizes.get(obs_dim_name)}; skipping. """ ) attrs_variables[str(var)] = ds[var].attrs return RaggedArray( coords, metadata, data, attrs_global, attrs_variables, name_dims, coord_dims )
[docs] @staticmethod def number_of_observations( rowsize_func: Callable[[int], int], indices: list, **kwargs ) -> np.ndarray: """Iterate through the files and evaluate the number of observations. Parameters ---------- rowsize_func : Callable[[int], int]] Function that returns the number observations of a row from its identification number indices : list Identification numbers list to iterate Returns ------- np.ndarray Number of observations """ rowsize = np.zeros(len(indices), dtype="int") for i, index in tqdm( enumerate(indices), total=len(indices), desc="Retrieving the number of obs", ncols=80, disable=_DISABLE_SHOW_PROGRESS, ): rowsize[i] = rowsize_func(index, **kwargs) return rowsize
[docs] @staticmethod def attributes( ds: xr.Dataset, name_coords: list, name_meta: list, name_data: list, ) -> tuple[dict, dict]: """Return global attributes and the attributes of all variables (name_coords, name_meta, and name_data) from an Xarray Dataset. Parameters ---------- ds : xr.Dataset _description_ name_coords : list, optional Name of metadata variables to include in the archive (default is []) name_meta : list, optional Name of metadata variables to include in the archive (default is []) name_data : list, optional Name of the data variables to include in the archive (default is []) Returns ------- Tuple[dict, dict] The global and variables attributes """ attrs_global = ds.attrs # coordinates, metadata, and data attrs_variables = {} for var in name_meta + name_data + name_coords: if var in ds.keys(): attrs_variables[var] = ds[var].attrs else: warnings.warn(f"Variable {var} requested but not found; skipping.") return attrs_global, attrs_variables
    @staticmethod
    def allocate(
        preprocess_func: Callable[[int], xr.Dataset],
        indices: list,
        rowsize: list | np.ndarray | xr.DataArray,
        name_coords: list,
        name_meta: list,
        name_data: list,
        name_dims: dict[str, DimNames],
        **kwargs,
    ) -> tuple[dict, dict, dict, dict]:
        """
        Iterate through the files and fill for the ragged array associated
        with coordinates, and selected metadata and data variables.

        Parameters
        ----------
        preprocess_func : Callable[[int], xr.Dataset]
            Returns a processed xarray Dataset from an identification number.
        indices : list
            List of indices separating row in the ragged arrays.
        rowsize : list
            List of the number of observations per row.
        name_coords : list
            Name of the coordinate variables to include in the archive.
        name_meta : list, optional
            Name of metadata variables to include in the archive (Defaults to []).
        name_data : list, optional
            Name of the data variables to include in the archive (Defaults to []).
        name_dims: dict[str, DimNames]
            Dimension alias mapped to the name used by clouddrift.

        Returns
        -------
        Tuple[dict, dict, dict, dict]
            Dictionaries containing numerical data and attributes of coordinates,
            metadata and data variables.
        """
        # open one file to get dtype of variables
        ds = preprocess_func(indices[0], **kwargs)

        nb_rows = len(rowsize)
        nb_obs = np.sum(rowsize).astype("int")
        # Cumulative offsets marking where each row starts in the flat arrays.
        index_traj = rowsize_to_index(rowsize)

        # Map each dimension alias to its total length in the ragged array.
        dim_sizes = {}
        for alias in name_dims.keys():
            if name_dims[alias] == "rows":
                dim_sizes[alias] = nb_rows
            else:
                dim_sizes[alias] = nb_obs

        # allocate memory
        coords = {}
        coord_dims: dict[str, str] = {}
        for var in name_coords:
            # A coordinate's dimension alias is taken from its last dimension
            # in the sample dataset; it selects the allocated length.
            dim = ds[var].dims[-1]
            dim_size = dim_sizes[dim]
            coords[var] = np.zeros(dim_size, dtype=ds[var].dtype)
            coord_dims[var] = dim

        metadata = {}
        for var in name_meta:
            try:
                # One value per row; a missing variable raises KeyError here.
                metadata[var] = np.zeros(nb_rows, dtype=ds[var].dtype)
            except KeyError:
                warnings.warn(f"Variable {var} requested but not found; skipping.")

        data = {}
        for var in name_data:
            if var in ds.keys():
                # One value per observation (flat across all rows).
                data[var] = np.zeros(nb_obs, dtype=ds[var].dtype)
            else:
                warnings.warn(f"Variable {var} requested but not found; skipping.")
        ds.close()

        # loop and fill the ragged array
        for i, index in tqdm(
            enumerate(indices),
            total=len(indices),
            desc="Filling the Ragged Array",
            ncols=80,
            disable=_DISABLE_SHOW_PROGRESS,
        ):
            # Each dataset is opened as a context manager so it is closed
            # even if filling raises.
            with preprocess_func(index, **kwargs) as ds:
                size = rowsize[i]
                oid = index_traj[i]  # start offset of this row in the flat arrays

                for var in name_coords:
                    dim = ds[var].dims[-1]
                    if name_dims[dim] == "obs":
                        # Ragged coordinate: copy the whole row slice.
                        coords[var][oid : oid + size] = ds[var].data
                    else:
                        # Per-row coordinate: a single (first) value per file.
                        coords[var][i] = ds[var].data[0]

                for var in name_meta:
                    try:
                        metadata[var][i] = ds[var][0].data
                    except KeyError:
                        # Variable absent in this file (or was never allocated).
                        warnings.warn(
                            f"Variable {var} requested but not found; skipping."
                        )

                for var in name_data:
                    if var in ds.keys():
                        data[var][oid : oid + size] = ds[var].data
                    else:
                        warnings.warn(
                            f"Variable {var} requested but not found; skipping."
                        )

        return coords, metadata, data, coord_dims
[docs] def validate_attributes(self): """Validate that each variable has an assigned attribute tag.""" for key in ( list(self.coords.keys()) + list(self.metadata.keys()) + list(self.data.keys()) ): if key not in self.attrs_variables: self.attrs_variables[key] = {}
[docs] def to_xarray(self): """Convert ragged array object to a xarray Dataset. Parameters ---------- cast_to_float32 : bool, optional Cast all float64 variables to float32 (default is True). This option aims at minimizing the size of the xarray dataset. Returns ------- xr.Dataset Xarray Dataset containing the ragged arrays and their attributes """ dim_name_map = {v: k for k, v in self.name_dims.items()} xr_coords = {} for var in self.coords.keys(): xr_coords[var] = ( [self._coord_dims[var]], self.coords[var], self.attrs_variables[var], ) xr_data = {} for var in self.metadata.keys(): xr_data[var] = ( [dim_name_map["rows"]], self.metadata[var], self.attrs_variables[var], ) for var in self.data.keys(): xr_data[var] = ( [dim_name_map["obs"]], self.data[var], self.attrs_variables[var], ) return xr.Dataset(coords=xr_coords, data_vars=xr_data, attrs=self.attrs_global)
[docs] def to_awkward(self): """Convert ragged array object to an Awkward Array. Returns ------- ak.Array Awkward Array containing the ragged array and its attributes """ index_traj = rowsize_to_index(self.metadata["rowsize"]) offset = ak.index.Index64(index_traj) data = [] for var in self.coords.keys(): dim = self._coord_dims[var] if self.name_dims[dim] == "obs": data.append( ak.contents.ListOffsetArray( offset, ak.contents.NumpyArray(self.coords[var]), parameters={"attrs": self.attrs_variables[var]}, ) ) else: data.append( ak.with_parameter( self.coords[var], "attrs", self.attrs_variables[var], highlevel=False, ) ) for var in self.data.keys(): data.append( ak.contents.ListOffsetArray( offset, ak.contents.NumpyArray(self.data[var]), parameters={"attrs": self.attrs_variables[var]}, ) ) data_names = list(self.coords.keys()) + list(self.data.keys()) metadata = [] for var in self.metadata.keys(): metadata.append( ak.with_parameter( self.metadata[var], "attrs", self.attrs_variables[var], highlevel=False, ) ) metadata_names = list(self.metadata.keys()) # include the data inside the metadata list as a nested array metadata_names.append("obs") metadata.append(ak.Array(ak.contents.RecordArray(data, data_names)).layout) return ak.Array( ak.contents.RecordArray( metadata, metadata_names, parameters={"attrs": self.attrs_global} ) )
[docs] def to_netcdf(self, filename: str): """Export ragged array object to a NetCDF file. Parameters ---------- filename : str Name of the NetCDF file to create. """ self.to_xarray().to_netcdf(filename)
[docs] def to_parquet(self, filename: str): """Export ragged array object to a parquet file. Parameters ---------- filename : str Name of the parquet file to create. """ ak.to_parquet(self.to_awkward(), filename)