"""Module for binning Lagrangian data."""
import datetime
import warnings
from functools import partial, wraps
from typing import Callable
import numpy as np
import pandas as pd
import xarray as xr
DEFAULT_BINS_NUMBER = 10
DEFAULT_COORD_NAME = "coord"
DEFAULT_DATA_NAME = "data"
def binned_statistics(
coords: np.ndarray | list[np.ndarray],
data: np.ndarray | list[np.ndarray] | None = None,
bins: int | list = DEFAULT_BINS_NUMBER,
bins_range: list | None = None,
dim_names: list[str] | None = None,
output_names: list[str] | None = None,
    statistics: str | tuple | list | Callable[[np.ndarray], float] = "count",
) -> xr.Dataset:
"""
    Perform N-dimensional binning and compute statistics of values in each bin.

    The result is returned as an xarray Dataset.

    Parameters
    ----------
coords : array-like or list of array-like
Array(s) of Lagrangian data coordinates to be binned. For 1D, provide a single array.
For N-dimensions, provide a list of N arrays, each giving coordinates along one dimension.
    data : array-like or list of array-like, optional
        Data values associated with the Lagrangian coordinates in coords.
        Can be a single array or a list of arrays for multiple variables.
        Complex values are supported for all statistics except 'min', 'max', and 'median'.
    bins : int or list, optional
        Number of bins or bin edges per dimension. It can be:

        - an int: same number of bins for all dimensions,
        - a list of ints or arrays: one per dimension, specifying either a bin
          count or bin edges,
        - None: defaults to 10 bins per dimension.
    bins_range : list of tuples, optional
        Outer bin limits for each dimension. If None, the minimum and maximum of
        each coordinate are used.
    dim_names : list of str, optional
        Names for the dimensions of the output xr.Dataset.
        If None, default names are "coord_0", "coord_1", etc.
    output_names : list of str, optional
        Names for output variables in the xr.Dataset.
        If None, default names are "data_0_{statistic}", "data_1_{statistic}", etc.
    statistics : str, Callable, tuple, or list of these, optional
        Statistics to compute for each bin. It can be:

        - a string; supported values are 'count', 'sum', 'mean', 'median', 'std',
          'min', and 'max' (default: "count"),
        - a callable implementing a univariate statistic, which takes a 1D array of
          values and returns a single value. The callable is applied to each
          variable of `data`,
        - a tuple of (output_name, Callable) for multivariate statistics, where
          'output_name' identifies the resulting variable. The callable receives the
          list of arrays provided in `data`. For example, to calculate kinetic energy
          from velocity components `u` and `v`, pass `data=[u, v]` and
          `statistics=("ke", lambda data: np.sqrt(np.mean(data[0] ** 2 + data[1] ** 2)))`,
        - a list containing any combination of the above, e.g.,
          ['mean', np.nanmax, ('ke', lambda data: np.sqrt(np.mean(data[0] ** 2 + data[1] ** 2)))].

Returns
-------
    xr.Dataset
        Dataset containing the requested binned statistics for each variable,
        with bin centers as coordinates.
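
    Examples
    --------
    A minimal example with synthetic coordinates and values:

    >>> lon = np.array([-75.2, -74.9, -60.1, -59.8])
    >>> lat = np.array([35.0, 35.1, 40.2, 40.3])
    >>> sst = np.array([28.1, 27.9, 22.4, 22.6])
    >>> ds = binned_statistics([lon, lat], sst, bins=[2, 2], statistics=["mean", "count"])
    >>> sorted(ds.data_vars)
    ['data_0_count', 'data_0_mean']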
"""
# convert coords, data parameters to numpy arrays and validate dimensions
# D, N = number of dimensions and number of data points
if not isinstance(coords, np.ndarray) or coords.ndim == 1:
coords = np.atleast_2d(coords)
D, N = coords.shape
# validate coordinates are finite
for c in coords:
var = c.copy()
if var.dtype == "O":
var = var.astype(type(var[0]))
if _is_datetime_array(var):
if pd.isna(var).any():
raise ValueError("Datetime coordinates must be finite values.")
else:
if pd.isna(var).any() or np.isinf(var).any():
raise ValueError("Coordinates must be finite values.")
# V, VN = number of variables and number of data points per variable
if data is None:
data = np.empty((1, 0))
V, VN = 1, N # no data provided
elif not isinstance(data, np.ndarray) or data.ndim == 1:
data = np.atleast_2d(data)
V, VN = data.shape
else:
V, VN = data.shape
# convert datetime coordinates to numeric values
coords_datetime_index = np.where([_is_datetime_array(c) for c in coords])[0]
for i in coords_datetime_index:
coords[i] = _datetime64_to_float(coords[i])
coords = coords.astype(np.float64)
# set default bins and bins range
if isinstance(bins, (list, tuple)):
if len(bins) != len(coords):
raise ValueError("`bins` must match the dimensions of the coordinates")
bins = [b if b is not None else DEFAULT_BINS_NUMBER for b in bins]
    elif isinstance(bins, int):
        bins = [bins] * len(coords)
if bins_range is None:
bins_range = [(np.nanmin(c), np.nanmax(c)) for c in coords]
else:
if isinstance(bins_range, tuple):
bins_range = [bins_range] * len(coords)
bins_range = [
r if r is not None else (np.nanmin(c), np.nanmax(c))
for r, c in zip(bins_range, coords)
]
# validate statistics parameter
ordered_statistics = ["count", "sum", "mean", "median", "std", "min", "max"]
if isinstance(statistics, (str, tuple)) or callable(statistics):
statistics = [statistics]
    elif not isinstance(statistics, list):
        raise ValueError(
            "`statistics` must be a string, Callable, tuple of (output_name, Callable), or a list of these. "
            f"Supported string values: {', '.join(ordered_statistics)}."
        )
if invalid := [
stat
for stat in statistics
if (stat not in ordered_statistics)
and not callable(stat)
and not isinstance(stat, tuple)
]:
raise ValueError(
f"Unsupported statistic(s): {', '.join(map(str, invalid))}. "
f"Supported: {ordered_statistics} or a Callable."
)
    # validate multivariate statistics
for statistic in statistics:
if isinstance(statistic, tuple):
output_name, statistic = statistic
if not isinstance(output_name, str):
raise ValueError(
f"Invalid output name '{output_name}', must be a string."
)
            if not callable(statistic):
                raise ValueError(
                    "Multivariate `statistics` must be provided as a tuple of "
                    "(output_name, Callable); the second element must be callable."
                )
    # deduplicate string statistics and sort them so simpler statistics are computed first
statistics_str = [s for s in statistics if isinstance(s, str)]
statistics_func = [s for s in statistics if not isinstance(s, str)]
statistics = (
sorted(
set(statistics_str),
key=lambda x: ordered_statistics.index(x),
)
+ statistics_func
)
    if statistics and not data.size:
        warnings.warn(
            f"No `data` provided; `statistics` ({statistics}) will be computed on the coordinates."
        )
# set default dimension names
if dim_names is None:
dim_names = [f"{DEFAULT_COORD_NAME}_{i}" for i in range(len(coords))]
else:
dim_names = [
name if name is not None else f"{DEFAULT_COORD_NAME}_{i}"
for i, name in enumerate(dim_names)
]
# set default variable names
if output_names is None:
output_names = [
f"{DEFAULT_DATA_NAME}_{i}" if data[0].size else DEFAULT_DATA_NAME
for i in range(len(data))
]
else:
output_names = [
name if name is not None else f"{DEFAULT_DATA_NAME}_{i}"
for i, name in enumerate(output_names)
]
# ensure inputs are consistent
if D != len(dim_names):
raise ValueError("`coords` and `dim_names` must have the same length")
if V != len(output_names):
raise ValueError("`data` and `output_names` must have the same length")
if N != VN:
raise ValueError("`coords` and `data` must have the same number of data points")
    # edges and bin centers; each dimension may specify either a bin count or explicit edges
    edges = [
        np.linspace(r[0], r[1], b + 1) if isinstance(b, int) else np.asarray(b)
        for r, b in zip(bins_range, bins)
    ]
edges_sz = [len(e) - 1 for e in edges]
n_bins = int(np.prod(edges_sz))
bin_centers = [0.5 * (e[:-1] + e[1:]) for e in edges]
# convert bin centers back to datetime64 for output dataset
for i in coords_datetime_index:
bin_centers[i] = _float_to_datetime64(bin_centers[i])
# digitize coordinates into bin indices
    # modify edges to ensure the last edge is inclusive by nudging it outward
    # (1 s for datetime coordinates); np.nextafter is used because adding a
    # fixed eps is a no-op for edge values much larger than 1
    edges_with_tol = [e.copy() for e in edges]
    for i, e in enumerate(edges_with_tol):
        e[-1] = e[-1] + 1 if i in coords_datetime_index else np.nextafter(e[-1], np.inf)
indices = [np.digitize(c, edges_with_tol[j]) - 1 for j, c in enumerate(coords)]
valid = np.all(
[(j >= 0) & (j < edges_sz[i]) for i, j in enumerate(indices)], axis=0
)
indices = [i[valid] for i in indices]
# create an iterable of statistics to compute
statistics_iter = []
for statistic in statistics:
if isinstance(statistic, str) or callable(statistic):
for var, name in zip(data, output_names):
statistics_iter.append((var, name, statistic))
elif isinstance(statistic, tuple):
output_name, statistic = statistic
statistics_iter.append((data, output_name, statistic))
ds = xr.Dataset()
for var, name, statistic in statistics_iter:
        # drop points that fall outside the bin edges and non-finite values
var_finite, indices_finite = _filter_valid_and_finite(var, indices, valid)
flat_idx = np.ravel_multi_index(indices_finite, edges_sz)
# convert object arrays to a common type
if var_finite.dtype == "O":
var_finite = var_finite.astype(type(var_finite[0]))
# loop through statistics for the variable
bin_count, bin_mean, bin_sum = None, None, None
if statistic == "count":
binned_stats = _binned_count(flat_idx, n_bins)
bin_count = binned_stats.copy()
elif statistic == "sum":
if _is_datetime_array(var_finite):
raise ValueError("Datetime data is not supported for 'sum' statistic.")
binned_stats = _binned_sum(flat_idx, n_bins, values=var_finite)
bin_sum = binned_stats.copy()
elif statistic == "mean":
binned_stats = _binned_mean(
flat_idx,
n_bins,
values=var_finite,
bin_counts=bin_count,
bin_sum=bin_sum,
)
bin_mean = binned_stats.copy()
elif statistic == "std":
binned_stats = _binned_std(
flat_idx,
n_bins,
values=var_finite,
bin_counts=bin_count,
bin_mean=bin_mean,
)
elif statistic == "min":
binned_stats = _binned_min(
flat_idx,
n_bins,
values=var_finite,
)
elif statistic == "max":
binned_stats = _binned_max(
flat_idx,
n_bins,
values=var_finite,
)
elif statistic == "median":
if np.iscomplexobj(var_finite):
raise ValueError(
"Complex values are not supported for 'median' statistic."
)
binned_stats = _binned_apply_func(
flat_idx,
n_bins,
values=var_finite,
func=np.median,
)
else:
binned_stats = _binned_apply_func(
flat_idx,
n_bins,
values=var_finite,
func=statistic,
)
        # add the binned statistics variable to the Dataset
        if var_finite.ndim == 2:
            # multivariate statistic: use the provided output name directly
            variable_name = name
        elif callable(statistic):
            variable_name = _get_variable_name(name, statistic, ds.data_vars)
        else:
            variable_name = f"{name}_{statistic}"
ds[variable_name] = xr.DataArray(
binned_stats.reshape(edges_sz),
dims=dim_names,
coords=dict(zip(dim_names, bin_centers)),
)
return ds
def _get_variable_name(
output_name: str,
func: Callable,
ds_vars: xr.core.dataset.DataVariables | dict[str, xr.DataArray],
) -> str:
"""
    Build a variable name from the output name and the statistic function,
    using a default name for lambda functions and avoiding collisions with
    existing dataset variables.

    Parameters
    ----------
    output_name : str
        Name of the output variable to which the function is applied.
    func : Callable
        Function to get the name of.
    ds_vars : dict[str, xr.DataArray]
        Existing variables in the dataset, used to avoid name collisions.

    Returns
    -------
    str
        Unique variable name for the dataset.
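
    Examples
    --------
    Name derivation for a named function and for a lambda (illustrative):

    >>> _get_variable_name("ke", np.nanmax, {})
    'ke_nanmax'
    >>> _get_variable_name("ke", lambda x: x.sum(), {"ke_stat": None})
    'ke_stat_1'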
"""
default_name = "stat"
if isinstance(func, partial):
function_name = getattr(func.func, "__name__", default_name)
else:
function_name = getattr(func, "__name__", default_name)
if function_name == "<lambda>":
function_name = default_name
# avoid name collisions with existing variables
# by adding a suffix if the name already exists
base_name = f"{output_name}_{function_name}"
name = base_name
i = 1
while name in ds_vars:
name = f"{base_name}_{i}"
i += 1
return name
def _filter_valid_and_finite(
var: np.ndarray, indices: list, valid: np.ndarray
) -> tuple[np.ndarray, list[np.ndarray]]:
"""
Filter valid and finite values from the variable and indices.
Args:
var : np.ndarray
Variable data to filter.
indices : list
List of index arrays to filter.
valid : np.ndarray
Boolean array indicating valid entries.
V : int
Size of the 'data' parameter to determine if the variable is multivariate.
Returns:
tuple[np.ndarray, list[np.ndarray]]: Filtered variable and indices.
"""
if var.ndim == 2:
var_valid = [v[valid] for v in var]
mask = np.logical_or.reduce([~pd.isna(v) for v in var_valid])
var_finite = np.array([v[mask] for v in var_valid])
indices_finite = [i[mask] for i in indices]
elif var.size:
var = var[valid]
mask = ~pd.isna(var)
var_finite = var[mask]
indices_finite = [i[mask] for i in indices]
else:
var_finite = var.copy()
indices_finite = indices.copy()
return var_finite, indices_finite
def _is_datetime_subelement(arr: np.ndarray) -> bool:
"""
    Check whether the first non-null element of an array is a datetime type.
Parameters
----------
arr : np.ndarray
Numpy array to check.
Returns
-------
bool
True if the first non-null element is a datetime type, False otherwise.
"""
for item in arr.flat:
if item is not None:
return isinstance(item, (datetime.date, np.datetime64))
return False
def _is_datetime_array(arr: np.ndarray) -> bool:
"""
Verify if an array contains datetime values.
Parameters
----------
arr : np.ndarray
Numpy array to check.
Returns
-------
bool
        True if the array contains datetime values, False otherwise.
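
    Examples
    --------
    Both datetime64 and plain numeric arrays are handled:

    >>> _is_datetime_array(np.array(["2020-01-01"], dtype="datetime64[s]"))
    True
    >>> _is_datetime_array(np.array([1.0, 2.0]))
    False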
"""
if arr.dtype.kind == "M": # numpy datetime64
return True
# if array is object, check first element
if arr.dtype == object and arr.size > 0:
return _is_datetime_subelement(arr)
return False
def _datetime64_to_float(time_dt: np.ndarray) -> np.ndarray:
"""
Convert np.datetime64 or array of datetime64 to float time since epoch.
    Parameters
    ----------
    time_dt : np.datetime64 or array-like
        Datetime64 values to convert.

    Returns
    -------
    float or np.ndarray of floats
        Seconds since the UNIX epoch (1970-01-01T00:00:00).
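
    Examples
    --------
    One minute past the epoch corresponds to 60 seconds:

    >>> _datetime64_to_float(np.array(["1970-01-01T00:01:00"], dtype="datetime64[s]"))
    array([60.])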
"""
reference_date = np.datetime64("1970-01-01T00:00:00")
return np.array(
(pd.to_datetime(time_dt) - pd.to_datetime(reference_date))
/ pd.to_timedelta(1, "s")
)
def _float_to_datetime64(time_float: np.ndarray) -> np.ndarray:
    """
    Convert float seconds since the UNIX epoch to np.datetime64.

    Parameters
    ----------
    time_float : float or array-like
        Seconds since the UNIX epoch (1970-01-01T00:00:00).

    Returns
    -------
    np.datetime64 or np.ndarray of np.datetime64
        Converted datetime64 values.
    """
reference_date = np.datetime64("1970-01-01T00:00:00")
date = reference_date + time_float.astype("timedelta64[s]")
return date
def handle_datetime_conversion(func: Callable) -> Callable:
"""
A decorator to handle datetime64/timedelta64 conversion for
statistics functions. For datetime `values`, it converts the time to float
seconds since epoch before calling the function, and converts the result back
to datetime64 after the function call.
    Assumes that the function accepts `values` as a keyword argument.
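
    Examples
    --------
    A toy statistic that returns the first value it receives (illustrative):

    >>> @handle_datetime_conversion
    ... def stat_first(flat_idx, n_bins, values=None):
    ...     return values[:1]
    >>> t = np.array(["2000-01-01", "2000-01-02"], dtype="datetime64[s]")
    >>> stat_first(None, 1, values=t)
    array(['2000-01-01T00:00:00'], dtype='datetime64[s]')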
"""
@wraps(func)
def wrapper(*args, **kwargs) -> np.ndarray:
values = kwargs.get("values")
datetime_conversion = False
if values is not None:
if datetime_conversion := _is_datetime_array(values):
kwargs["values"] = _datetime64_to_float(values)
        # call the original function
result = func(*args, **kwargs)
# Convert the result to datetime if necessary
if datetime_conversion:
if func.__name__ == "_binned_std":
return result.astype("timedelta64[s]")
return _float_to_datetime64(result)
return result
return wrapper
def _binned_count(flat_idx: np.ndarray, n_bins: int) -> np.ndarray:
"""
Compute the count of values in each bin.
Parameters
----------
    flat_idx : array-like
        1D array of bin indices.
    n_bins : int
        Number of bins.

    Returns
    -------
    result : array-like
        1D array of length n_bins with the count per bin.
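
    Examples
    --------
    Counts per bin for three points in three bins:

    >>> _binned_count(np.array([0, 0, 2]), 3)
    array([2, 0, 1])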
"""
return np.bincount(flat_idx, minlength=n_bins)
@handle_datetime_conversion
def _binned_sum(flat_idx: np.ndarray, n_bins: int, values: np.ndarray) -> np.ndarray:
"""
Compute the sum of values per bin.
Parameters
----------
flat_idx : array-like
1D array of bin indices, same shape as values.
    n_bins : int
        Number of bins.
values : array-like
1D array of data values
Returns
-------
result : array-like
1D array of length n_bins with the sum per bin
"""
if np.iscomplexobj(values):
real = np.bincount(flat_idx, weights=values.real, minlength=n_bins)
imag = np.bincount(flat_idx, weights=values.imag, minlength=n_bins)
return real + 1j * imag
else:
return np.bincount(flat_idx, weights=values, minlength=n_bins)
@handle_datetime_conversion
def _binned_mean(
flat_idx: np.ndarray,
n_bins: int,
values: np.ndarray,
bin_counts: np.ndarray | None = None,
bin_sum: np.ndarray | None = None,
) -> np.ndarray:
"""
Compute the mean of values per bin.
Parameters
----------
flat_idx : array-like
1D array of bin indices, same shape as values.
    n_bins : int
        Number of bins.
values : array-like
1D array of data values
bin_counts : array-like, optional
Precomputed counts per bin. If None, it will be computed using `_binned_count`.
bin_sum : array-like, optional
Precomputed sum per bin. If None, it will be computed using `_binned_sum`.
Returns
-------
result : array-like
1D array of length n_bins with the mean per bin
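
    Examples
    --------
    Mean per bin for two bins:

    >>> _binned_mean(np.array([0, 0, 1]), 2, values=np.array([1.0, 3.0, 10.0]))
    array([ 2., 10.])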
"""
if bin_counts is None:
bin_counts = _binned_count(flat_idx, n_bins)
if bin_sum is None:
bin_sum = _binned_sum(flat_idx, n_bins, values)
return np.divide(
bin_sum,
bin_counts,
out=np.full_like(bin_sum, np.nan, dtype=bin_sum.dtype),
where=bin_counts > 0,
)
@handle_datetime_conversion
def _binned_std(
flat_idx: np.ndarray,
n_bins: int,
values: np.ndarray,
bin_counts: np.ndarray | None = None,
bin_mean: np.ndarray | None = None,
) -> np.ndarray:
"""
Compute the standard deviation of values per bin.
Parameters
----------
flat_idx : array-like
1D array of bin indices, same shape as values.
    n_bins : int
        Number of bins.
values : array-like
1D array of data values
bin_counts : array-like, optional
Precomputed counts per bin. If None, it will be computed using `_binned_count`.
bin_mean : array-like, optional
Precomputed mean per bin. If None, it will be computed using `_binned_mean`.
Returns
-------
result : array-like
1D array of length n_bins with the standard deviation per bin
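
    Examples
    --------
    Standard deviation per bin for two bins (a single-point bin has zero spread):

    >>> _binned_std(np.array([0, 0, 1]), 2, values=np.array([1.0, 3.0, 5.0]))
    array([1., 0.])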
"""
if bin_counts is None:
bin_counts = _binned_count(flat_idx, n_bins)
if bin_mean is None:
bin_mean = _binned_mean(flat_idx, n_bins, values, bin_counts)
if np.iscomplexobj(values):
# Use modulus for variance
abs_values = np.abs(values)
bin_sumsq = np.bincount(flat_idx, weights=abs_values**2, minlength=n_bins)
bin_mean_sq = np.divide(
bin_sumsq,
bin_counts,
out=np.full(n_bins, np.nan, dtype=bin_sumsq.dtype),
where=bin_counts > 0,
)
abs_bin_mean = np.abs(bin_mean)
variance = np.maximum(bin_mean_sq - abs_bin_mean**2, 0)
else:
bin_sumsq = np.bincount(flat_idx, weights=values**2, minlength=n_bins)
bin_mean_sq = np.divide(
bin_sumsq,
bin_counts,
out=np.full(n_bins, np.nan, dtype=bin_sumsq.dtype),
where=bin_counts > 0,
)
variance = np.maximum(bin_mean_sq - bin_mean**2, 0)
return np.sqrt(variance)
@handle_datetime_conversion
def _binned_min(flat_idx: np.ndarray, n_bins: int, values: np.ndarray) -> np.ndarray:
"""
Compute the minimum of values per bin.
Parameters
----------
flat_idx : array-like
1D array of bin indices, same shape as values.
    n_bins : int
        Number of bins.
values : array-like
1D array of data values
Returns
-------
result : array-like
1D array of length n_bins with the minimum per bin
"""
if np.iscomplexobj(values):
raise ValueError("Complex values are not supported for 'min' statistic.")
output = np.full(n_bins, np.inf)
np.minimum.at(output, flat_idx, values)
output[output == np.inf] = np.nan
return output
@handle_datetime_conversion
def _binned_max(flat_idx: np.ndarray, n_bins: int, values: np.ndarray) -> np.ndarray:
"""
Compute the maximum of values per bin.
Parameters
----------
flat_idx : array-like
1D array of bin indices, same shape as values.
    n_bins : int
        Number of bins.
values : array-like
1D array of data values
Returns
-------
result : array-like
1D array of length n_bins with the maximum per bin
"""
if np.iscomplexobj(values):
raise ValueError("Complex values are not supported for 'max' statistic.")
output = np.full(n_bins, -np.inf)
np.maximum.at(output, flat_idx, values)
output[output == -np.inf] = np.nan
return output
@handle_datetime_conversion
def _binned_apply_func(
flat_idx: np.ndarray,
n_bins: int,
values: np.ndarray,
func: Callable[[np.ndarray | list[np.ndarray]], float] = np.mean,
) -> np.ndarray:
"""
    Generic wrapper to apply an arbitrary function (e.g., a percentile) to binned data.
Parameters
----------
flat_idx : array-like
1D array of bin indices.
n_bins : int
Number of bins.
values : array-like or list of array-like
1D array (univariate) or list of 1D arrays (multivariate) of data values.
func : Callable[[list[np.ndarray]], float]
Function to apply to each bin. If multivariate, will receive a list of arrays.
Returns
-------
result : np.ndarray
1D array of length n_bins with results from func per bin.
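
    Examples
    --------
    Applying np.median per bin:

    >>> _binned_apply_func(np.array([0, 0, 1]), 2, values=np.array([1.0, 3.0, 5.0]), func=np.median)
    array([2., 5.])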
"""
sort_indices = np.argsort(flat_idx)
sorted_flat_idx = flat_idx[sort_indices]
# single or all variables can be passed as input values
if is_multivariate := values.ndim == 2:
sorted_values = [v[sort_indices] for v in values]
else:
sorted_values = [values[sort_indices]]
unique_bins, bin_starts = np.unique(sorted_flat_idx, return_index=True)
bin_ends = np.append(bin_starts[1:], len(sorted_flat_idx))
result = np.full(n_bins, np.nan)
for i, bin_idx in enumerate(unique_bins):
if is_multivariate:
bin_values = [v[bin_starts[i] : bin_ends[i]] for v in sorted_values]
else:
bin_values = sorted_values[0][bin_starts[i] : bin_ends[i]]
result[bin_idx] = func(bin_values)
return result