"""Module for binning Lagrangian data."""
import datetime
import warnings
from functools import partial, wraps
from typing import Callable
import numpy as np
import pandas as pd
import xarray as xr
DEFAULT_BINS_NUMBER = 10
DEFAULT_COORD_NAME = "coord"
DEFAULT_DATA_NAME = "data"
[docs]
def binned_statistics(
    coords: np.ndarray | list[np.ndarray],
    data: np.ndarray | list[np.ndarray] | None = None,
    bins: int | list = DEFAULT_BINS_NUMBER,
    bins_range: list | None = None,
    dim_names: list[str] | None = None,
    output_names: list[str] | None = None,
    statistics: str | list | Callable[[np.ndarray], float] = "count",
) -> xr.Dataset:
    """
    Perform N-dimensional binning and compute statistics of values in each bin. The result is returned as an Xarray Dataset.
    Parameters
    ----------
    coords : array-like or list of array-like
        Array(s) of Lagrangian data coordinates to be binned. For 1D, provide a single array.
        For N-dimensions, provide a list of N arrays, each giving coordinates along one dimension.
    data : array-like or list of array-like
        Data values associated with the Lagrangian coordinates in coords.
        Can be a single array or a list of arrays for multiple variables.
        Complex values are supported for the supported statistics except for 'min', 'max', and 'median'.
    bins : int or lists, optional
        Number of bins or bin edges per dimension. It can be:
        - An int: same number of bins for all dimensions,
        - A list of ints or arrays: one per dimension, specifying either bin count or bin edges,
        - None: defaults to 10 bins per dimension.
    bins_range : list of tuples, optional
        Outer bin limits for each dimension.
    statistics : str or list of str, Callable[[np.ndarray], float] or list[Callable[[np.ndarray], float]]
        Statistics to compute for each bin. It can be:
        - a string, supported values: 'count', 'sum', 'mean', 'median', 'std', 'min', 'max', (default: "count"),
        - a custom function as a callable for univariate statistics that take a 1D array of values and return a single value.
          The callable is applied to each variable of data.
        - a tuple of (output_name, callable) for multivariate statistics. 'output_name' is used to identify the resulting variables.
          In this case, the callable will receive the list of arrays provided in `data`. For example, to calculate kinetic energy from data with velocity components `u` and `v`,
          you can pass `data = [u, v]` and  `statistics=("ke", lambda data: np.sqrt(np.mean(data[0] ** 2 + data[1] ** 2)))`.
        - a list containing any combination of the above, e.g., ['mean', np.nanmax, ('ke', lambda data: np.sqrt(np.mean(data[0] ** 2 + data[1] ** 2)))].
    dim_names : list of str, optional
        Names for the dimensions of the output xr.Dataset.
        If None, default names are "coord_0", "coord_1", etc.
    output_names : list of str, optional
        Names for output variables in the xr.Dataset.
        If None, default names are "data_0_{statistic}", "data_1_{statistic}", etc.
    Returns
    -------
    xr.Dataset
        Xarray dataset with binned means and count for each variable.
    """
    # convert coords, data parameters to numpy arrays and validate dimensions
    # D, N = number of dimensions and number of data points
    if not isinstance(coords, np.ndarray) or coords.ndim == 1:
        coords = np.atleast_2d(coords)
    D, N = coords.shape
    # validate coordinates are finite
    for c in coords:
        var = c.copy()
        if var.dtype == "O":
            var = var.astype(type(var[0]))
        if _is_datetime_array(var):
            if pd.isna(var).any():
                raise ValueError("Datetime coordinates must be finite values.")
        else:
            if pd.isna(var).any() or np.isinf(var).any():
                raise ValueError("Coordinates must be finite values.")
    # V, VN = number of variables and number of data points per variable
    if data is None:
        data = np.empty((1, 0))
        V, VN = 1, N  # no data provided
    elif not isinstance(data, np.ndarray) or data.ndim == 1:
        data = np.atleast_2d(data)
        V, VN = data.shape
    else:
        V, VN = data.shape
    # convert datetime coordinates to numeric values
    coords_datetime_index = np.where([_is_datetime_array(c) for c in coords])[0]
    for i in coords_datetime_index:
        coords[i] = _datetime64_to_float(coords[i])
    coords = coords.astype(np.float64)
    # set default bins and bins range
    if isinstance(bins, (list, tuple)):
        if len(bins) != len(coords):
            raise ValueError("`bins` must match the dimensions of the coordinates")
        bins = [b if b is not None else DEFAULT_BINS_NUMBER for b in bins]
    elif isinstance(bins, int):
        bins = [bins if bins is not None else DEFAULT_BINS_NUMBER] * len(coords)
    if bins_range is None:
        bins_range = [(np.nanmin(c), np.nanmax(c)) for c in coords]
    else:
        if isinstance(bins_range, tuple):
            bins_range = [bins_range] * len(coords)
        bins_range = [
            r if r is not None else (np.nanmin(c), np.nanmax(c))
            for r, c in zip(bins_range, coords)
        ]
    # validate statistics parameter
    ordered_statistics = ["count", "sum", "mean", "median", "std", "min", "max"]
    if isinstance(statistics, (str, tuple)) or callable(statistics):
        statistics = [statistics]
    elif not isinstance(statistics, list):
        raise ValueError(
            "`statistics` must be a string, list of strings, Callable, or a list of Callables. "
            f"Supported values: {', '.join(ordered_statistics)}."
        )
    if invalid := [
        stat
        for stat in statistics
        if (stat not in ordered_statistics)
        and not callable(stat)
        and not isinstance(stat, tuple)
    ]:
        raise ValueError(
            f"Unsupported statistic(s): {', '.join(map(str, invalid))}. "
            f"Supported: {ordered_statistics} or a Callable."
        )
    # validate multivariable statistics
    for statistic in statistics:
        if isinstance(statistic, tuple):
            output_name, statistic = statistic
            if not isinstance(output_name, str):
                raise ValueError(
                    f"Invalid output name '{output_name}', must be a string."
                )
            if not callable(statistic):
                raise ValueError(
                    "Multivariable `statistics` function is not Callable, must provide as a tuple(output_name, Callable)."
                )
    # validate and sort statistics for efficiency
    statistics_str = [s for s in statistics if isinstance(s, str)]
    statistics_func = [s for s in statistics if not isinstance(s, str)]
    statistics = (
        sorted(
            set(statistics_str),
            key=lambda x: ordered_statistics.index(x),
        )
        + statistics_func
    )
    if statistics and not data.size:
        warnings.warn(
            f"no `data` provided, `statistics` ({statistics}) will be computed on the coordinates."
        )
    # set default dimension names
    if dim_names is None:
        dim_names = [f"{DEFAULT_COORD_NAME}_{i}" for i in range(len(coords))]
    else:
        dim_names = [
            name if name is not None else f"{DEFAULT_COORD_NAME}_{i}"
            for i, name in enumerate(dim_names)
        ]
    # set default variable names
    if output_names is None:
        output_names = [
            f"{DEFAULT_DATA_NAME}_{i}" if data[0].size else DEFAULT_DATA_NAME
            for i in range(len(data))
        ]
    else:
        output_names = [
            name if name is not None else f"{DEFAULT_DATA_NAME}_{i}"
            for i, name in enumerate(output_names)
        ]
    # ensure inputs are consistent
    if D != len(dim_names):
        raise ValueError("`coords` and `dim_names` must have the same length")
    if V != len(output_names):
        raise ValueError("`data` and `output_names` must have the same length")
    if N != VN:
        raise ValueError("`coords` and `data` must have the same number of data points")
    # edges and bin centers
    if isinstance(bins, int) or isinstance(bins[0], int):
        edges = [np.linspace(r[0], r[1], b + 1) for r, b in zip(bins_range, bins)]
    else:
        edges = [np.asarray(b) for b in bins]
    edges_sz = [len(e) - 1 for e in edges]
    n_bins = int(np.prod(edges_sz))
    bin_centers = [0.5 * (e[:-1] + e[1:]) for e in edges]
    # convert bin centers back to datetime64 for output dataset
    for i in coords_datetime_index:
        bin_centers[i] = _float_to_datetime64(bin_centers[i])
    # digitize coordinates into bin indices
    # modify edges to ensure the last edge is inclusive
    # by adding a small tolerance to the last edge (1s for date coordinates)
    edges_with_tol = [e.copy() for e in edges]
    for i, e in enumerate(edges_with_tol):
        e[-1] += np.finfo(float).eps if i not in coords_datetime_index else 1
    indices = [np.digitize(c, edges_with_tol[j]) - 1 for j, c in enumerate(coords)]
    valid = np.all(
        [(j >= 0) & (j < edges_sz[i]) for i, j in enumerate(indices)], axis=0
    )
    indices = [i[valid] for i in indices]
    # create an iterable of statistics to compute
    statistics_iter = []
    for statistic in statistics:
        if isinstance(statistic, str) or callable(statistic):
            for var, name in zip(data, output_names):
                statistics_iter.append((var, name, statistic))
        elif isinstance(statistic, tuple):
            output_name, statistic = statistic
            statistics_iter.append((data, output_name, statistic))
    ds = xr.Dataset()
    for var, name, statistic in statistics_iter:
        # count the number of points in each bin
        var_finite, indices_finite = _filter_valid_and_finite(var, indices, valid)
        flat_idx = np.ravel_multi_index(indices_finite, edges_sz)
        # convert object arrays to a common type
        if var_finite.dtype == "O":
            var_finite = var_finite.astype(type(var_finite[0]))
        # loop through statistics for the variable
        bin_count, bin_mean, bin_sum = None, None, None
        if statistic == "count":
            binned_stats = _binned_count(flat_idx, n_bins)
            bin_count = binned_stats.copy()
        elif statistic == "sum":
            if _is_datetime_array(var_finite):
                raise ValueError("Datetime data is not supported for 'sum' statistic.")
            binned_stats = _binned_sum(flat_idx, n_bins, values=var_finite)
            bin_sum = binned_stats.copy()
        elif statistic == "mean":
            binned_stats = _binned_mean(
                flat_idx,
                n_bins,
                values=var_finite,
                bin_counts=bin_count,
                bin_sum=bin_sum,
            )
            bin_mean = binned_stats.copy()
        elif statistic == "std":
            binned_stats = _binned_std(
                flat_idx,
                n_bins,
                values=var_finite,
                bin_counts=bin_count,
                bin_mean=bin_mean,
            )
        elif statistic == "min":
            binned_stats = _binned_min(
                flat_idx,
                n_bins,
                values=var_finite,
            )
        elif statistic == "max":
            binned_stats = _binned_max(
                flat_idx,
                n_bins,
                values=var_finite,
            )
        elif statistic == "median":
            if np.iscomplexobj(var_finite):
                raise ValueError(
                    "Complex values are not supported for 'median' statistic."
                )
            binned_stats = _binned_apply_func(
                flat_idx,
                n_bins,
                values=var_finite,
                func=np.median,
            )
        else:
            binned_stats = _binned_apply_func(
                flat_idx,
                n_bins,
                values=var_finite,
                func=statistic,
            )
        # add the binned statistics variable to the Dataset
        variable_name = (
            name
            if var_finite.ndim == 2
            else _get_variable_name(name, statistic, ds.data_vars)
            if callable(statistic)
            else f"{name}_{statistic}"
        )
        ds[variable_name] = xr.DataArray(
            binned_stats.reshape(edges_sz),
            dims=dim_names,
            coords=dict(zip(dim_names, bin_centers)),
        )
    return ds 
def _get_variable_name(
    output_name: str,
    func: Callable,
    ds_vars: xr.core.dataset.DataVariables | dict[str, xr.DataArray],
) -> str:
    """
    Get the name of the function or a default name if it is a lambda function.
    Parameters
    ----------
    func : Callable
        Function to get the name of.
    output_name : str
        Name of the output variable to which the function is applied.
    ds_vars : dict[str, xr.DataArray]
        Dictionary of existing variables in the dataset to avoid name collisions.
    Returns
    -------
    str
        Name of the function or a custom function name for lambda function.
    """
    default_name = "stat"
    if isinstance(func, partial):
        function_name = getattr(func.func, "__name__", default_name)
    else:
        function_name = getattr(func, "__name__", default_name)
        if function_name == "<lambda>":
            function_name = default_name
    # avoid name collisions with existing variables
    # by adding a suffix if the name already exists
    base_name = f"{output_name}_{function_name}"
    name = base_name
    i = 1
    while name in ds_vars:
        name = f"{base_name}_{i}"
        i += 1
    return name
def _filter_valid_and_finite(
    var: np.ndarray, indices: list, valid: np.ndarray
) -> tuple[np.ndarray, list[np.ndarray]]:
    """
    Filter valid and finite values from the variable and indices.
    Args:
        var : np.ndarray
            Variable data to filter.
        indices : list
            List of index arrays to filter.
        valid : np.ndarray
            Boolean array indicating valid entries.
        V : int
            Size of the 'data' parameter to determine if the variable is multivariate.
    Returns:
        tuple[np.ndarray, list[np.ndarray]]: Filtered variable and indices.
    """
    if var.ndim == 2:
        var_valid = [v[valid] for v in var]
        mask = np.logical_or.reduce([~pd.isna(v) for v in var_valid])
        var_finite = np.array([v[mask] for v in var_valid])
        indices_finite = [i[mask] for i in indices]
    elif var.size:
        var = var[valid]
        mask = ~pd.isna(var)
        var_finite = var[mask]
        indices_finite = [i[mask] for i in indices]
    else:
        var_finite = var.copy()
        indices_finite = indices.copy()
    return var_finite, indices_finite
def _is_datetime_subelement(arr: np.ndarray) -> bool:
    """
    Get the type of the first non-null element in an array.
    Parameters
    ----------
    arr : np.ndarray
        Numpy array to check.
    Returns
    -------
    bool
        True if the first non-null element is a datetime type, False otherwise.
    """
    for item in arr.flat:
        if item is not None:
            return isinstance(item, (datetime.date, np.datetime64))
    return False
def _is_datetime_array(arr: np.ndarray) -> bool:
    """
    Verify if an array contains datetime values.
    Parameters
    ----------
    arr : np.ndarray
        Numpy array to check.
    Returns
    -------
    bool
        True if the array contains datetime64 or timedelta64 values, False otherwise.
    """
    if arr.dtype.kind == "M":  # numpy datetime64
        return True
    # if array is object, check first element
    if arr.dtype == object and arr.size > 0:
        return _is_datetime_subelement(arr)
    return False
def _datetime64_to_float(time_dt: np.ndarray) -> np.ndarray:
    """
    Convert np.datetime64 or array of datetime64 to float time since epoch.
    Parameters:
    ----------
    time_dt : np.datetime64 or array-like
        Datetime64 values to convert.
    Returns:
    -------
    float or np.ndarray of floats
        Seconds since UNIX epoch (1970-01-01T00:00:00).
    """
    reference_date = np.datetime64("1970-01-01T00:00:00")
    return np.array(
        (pd.to_datetime(time_dt) - pd.to_datetime(reference_date))
        / pd.to_timedelta(1, "s")
    )
def _float_to_datetime64(time_float, count=None):
    """
    Convert float seconds since UNIX epoch to np.datetime64.
    Parameters:
    ----------
    time_float : float or array-like
        Seconds since epoch (1970-01-01T00:00:00).
    Returns:
    -------
    np.datetime64 or np.ndarray of np.datetime64
        Converted datetime64 values.
    """
    reference_date = np.datetime64("1970-01-01T00:00:00")
    date = reference_date + time_float.astype("timedelta64[s]")
    return date
[docs]
def handle_datetime_conversion(func: Callable) -> Callable:
    """
    A decorator to handle datetime64/timedelta64 conversion for
    statistics functions. For datetime `values`, it converts the time to float
    seconds since epoch before calling the function, and converts the result back
    to datetime64 after the function call.
    Assumes that the function accepts `values` as keyword arguments.
    """
    @wraps(func)
    def wrapper(*args, **kwargs) -> np.ndarray:
        values = kwargs.get("values")
        datetime_conversion = False
        if values is not None:
            if datetime_conversion := _is_datetime_array(values):
                kwargs["values"] = _datetime64_to_float(values)
        # call back the original function
        result = func(*args, **kwargs)
        # Convert the result to datetime if necessary
        if datetime_conversion:
            if func.__name__ == "_binned_std":
                return result.astype("timedelta64[s]")
            return _float_to_datetime64(result)
        return result
    return wrapper 
def _binned_count(flat_idx: np.ndarray, n_bins: int) -> np.ndarray:
    """
    Compute the count of values in each bin.
    Parameters
    ----------
    flat_idx : array-like
        1D array of bin indices, same shape as values.
    n_bins: int
        number of bins
    Returns
    -------
    result : array-like
        1D array of length n_bins with the count per bin
    """
    return np.bincount(flat_idx, minlength=n_bins)
@handle_datetime_conversion
def _binned_sum(flat_idx: np.ndarray, n_bins: int, values: np.ndarray) -> np.ndarray:
    """
    Compute the sum of values per bin.
    Parameters
    ----------
    flat_idx : array-like
        1D array of bin indices, same shape as values.
    n_bins: int
        number of bins
    values : array-like
        1D array of data values
    Returns
    -------
    result : array-like
        1D array of length n_bins with the sum per bin
    """
    if np.iscomplexobj(values):
        real = np.bincount(flat_idx, weights=values.real, minlength=n_bins)
        imag = np.bincount(flat_idx, weights=values.imag, minlength=n_bins)
        return real + 1j * imag
    else:
        return np.bincount(flat_idx, weights=values, minlength=n_bins)
@handle_datetime_conversion
def _binned_mean(
    flat_idx: np.ndarray,
    n_bins: int,
    values: np.ndarray,
    bin_counts: np.ndarray | None = None,
    bin_sum: np.ndarray | None = None,
) -> np.ndarray:
    """
    Compute the mean of values per bin.
    Parameters
    ----------
    flat_idx : array-like
        1D array of bin indices, same shape as values.
    n_bins: int
        number of bins
    values : array-like
        1D array of data values
    bin_counts : array-like, optional
        Precomputed counts per bin. If None, it will be computed using `_binned_count`.
    bin_sum : array-like, optional
        Precomputed sum per bin. If None, it will be computed using `_binned_sum`.
    Returns
    -------
    result : array-like
        1D array of length n_bins with the mean per bin
    """
    if bin_counts is None:
        bin_counts = _binned_count(flat_idx, n_bins)
    if bin_sum is None:
        bin_sum = _binned_sum(flat_idx, n_bins, values)
    return np.divide(
        bin_sum,
        bin_counts,
        out=np.full_like(bin_sum, np.nan, dtype=bin_sum.dtype),
        where=bin_counts > 0,
    )
@handle_datetime_conversion
def _binned_std(
    flat_idx: np.ndarray,
    n_bins: int,
    values: np.ndarray,
    bin_counts: np.ndarray | None = None,
    bin_mean: np.ndarray | None = None,
) -> np.ndarray:
    """
    Compute the standard deviation of values per bin.
    Parameters
    ----------
    flat_idx : array-like
        1D array of bin indices, same shape as values.
    n_bins: int
        number of bins
    values : array-like
        1D array of data values
    bin_counts : array-like, optional
        Precomputed counts per bin. If None, it will be computed using `_binned_count`.
    bin_mean : array-like, optional
        Precomputed mean per bin. If None, it will be computed using `_binned_mean`.
    Returns
    -------
    result : array-like
        1D array of length n_bins with the standard deviation per bin
    """
    if bin_counts is None:
        bin_counts = _binned_count(flat_idx, n_bins)
    if bin_mean is None:
        bin_mean = _binned_mean(flat_idx, n_bins, values, bin_counts)
    if np.iscomplexobj(values):
        # Use modulus for variance
        abs_values = np.abs(values)
        bin_sumsq = np.bincount(flat_idx, weights=abs_values**2, minlength=n_bins)
        bin_mean_sq = np.divide(
            bin_sumsq,
            bin_counts,
            out=np.full(n_bins, np.nan, dtype=bin_sumsq.dtype),
            where=bin_counts > 0,
        )
        abs_bin_mean = np.abs(bin_mean)
        variance = np.maximum(bin_mean_sq - abs_bin_mean**2, 0)
    else:
        bin_sumsq = np.bincount(flat_idx, weights=values**2, minlength=n_bins)
        bin_mean_sq = np.divide(
            bin_sumsq,
            bin_counts,
            out=np.full(n_bins, np.nan, dtype=bin_sumsq.dtype),
            where=bin_counts > 0,
        )
        variance = np.maximum(bin_mean_sq - bin_mean**2, 0)
    return np.sqrt(variance)
@handle_datetime_conversion
def _binned_min(flat_idx: np.ndarray, n_bins: int, values: np.ndarray) -> np.ndarray:
    """
    Compute the minimum of values per bin.
    Parameters
    ----------
    flat_idx : array-like
        1D array of bin indices, same shape as values.
    n_bins: int
        number of bins
    values : array-like
        1D array of data values
    Returns
    -------
    result : array-like
        1D array of length n_bins with the minimum per bin
    """
    if np.iscomplexobj(values):
        raise ValueError("Complex values are not supported for 'min' statistic.")
    output = np.full(n_bins, np.inf)
    np.minimum.at(output, flat_idx, values)
    output[output == np.inf] = np.nan
    return output
@handle_datetime_conversion
def _binned_max(flat_idx: np.ndarray, n_bins: int, values: np.ndarray) -> np.ndarray:
    """
    Compute the maximum of values per bin.
    Parameters
    ----------
    flat_idx : array-like
        1D array of bin indices, same shape as values.
    n_bins: int
        number of bins
    values : array-like
        1D array of data values
    Returns
    -------
    result : array-like
        1D array of length n_bins with the maximum per bin
    """
    if np.iscomplexobj(values):
        raise ValueError("Complex values are not supported for 'max' statistic.")
    output = np.full(n_bins, -np.inf)
    np.maximum.at(output, flat_idx, values)
    output[output == -np.inf] = np.nan
    return output
@handle_datetime_conversion
def _binned_apply_func(
    flat_idx: np.ndarray,
    n_bins: int,
    values: np.ndarray,
    func: Callable[[np.ndarray | list[np.ndarray]], float] = np.mean,
) -> np.ndarray:
    """
    Generic wrapper to apply any functions (e.g., percentile) to binned data.
    Parameters
    ----------
    flat_idx : array-like
        1D array of bin indices.
    n_bins : int
        Number of bins.
    values : array-like or list of array-like
        1D array (univariate) or list of 1D arrays (multivariate) of data values.
    func : Callable[[list[np.ndarray]], float]
        Function to apply to each bin. If multivariate, will receive a list of arrays.
    Returns
    -------
    result : np.ndarray
        1D array of length n_bins with results from func per bin.
    """
    sort_indices = np.argsort(flat_idx)
    sorted_flat_idx = flat_idx[sort_indices]
    # single or all variables can be passed as input values
    if is_multivariate := values.ndim == 2:
        sorted_values = [v[sort_indices] for v in values]
    else:
        sorted_values = [values[sort_indices]]
    unique_bins, bin_starts = np.unique(sorted_flat_idx, return_index=True)
    bin_ends = np.append(bin_starts[1:], len(sorted_flat_idx))
    result = np.full(n_bins, np.nan)
    for i, bin_idx in enumerate(unique_bins):
        if is_multivariate:
            bin_values = [v[bin_starts[i] : bin_ends[i]] for v in sorted_values]
        else:
            bin_values = sorted_values[0][bin_starts[i] : bin_ends[i]]
        result[bin_idx] = func(bin_values)
    return result