Source code for clouddrift.adapters.gdp.gdpsource

from __future__ import annotations

import asyncio
import datetime
import logging
import os
import tempfile
import warnings
from collections import defaultdict
from concurrent.futures import Future, ProcessPoolExecutor, as_completed
from typing import Callable

import numpy as np
import pandas as pd
import xarray as xr
from tqdm.asyncio import tqdm

from clouddrift.adapters.gdp import get_gdp_metadata
from clouddrift.adapters.utils import download_with_progress
from clouddrift.ragged import subset
from clouddrift.raggedarray import RaggedArray

_DATA_URL = "https://www.aoml.noaa.gov/ftp/pub/phod/pub/pazos/data/shane/sst"
_TMP_PATH = os.path.join(tempfile.gettempdir(), "clouddrift", "gdpsource")
_FILENAME_TEMPLATE = "buoydata_{start}_{end}_{suffix}.dat.gz"
_SECONDS_IN_DAY = 86_400

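# Coordinate, data, and metadata variable names of the output ragged-array dataset.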
_COORDS = ["id", "obs_index"]

_DATA_VARS = [
    "latitude",
    "longitude",
    "position_datetime",
    "sensor_datetime",
    "drogue",
    "sst",
    "voltage",
    "sensor4",
    "sensor5",
    "sensor6",
    "qualityIndex",
]

_METADATA_VARS = [
    "rowsize",
    "wmo_number",
    "program_number",
    "buoys_type",
    "start_date",
    "start_lat",
    "start_lon",
    "end_date",
    "end_lat",
    "end_lon",
    "drogue_off_date",
    "death_code",
]

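# Fill values used for metadata variables when a drifter has no entry in the metadata table.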
_VARS_FILL_MAP: dict = {
    "wmo_number": -999,
    "program_number": -999,
    "buoys_type": "N/A",
    "start_date": np.datetime64("1970-01-01 00:00:00"),
    "start_lat": -999,
    "start_lon": -999,
    "end_date": np.datetime64("1970-01-01 00:00:00"),
    "end_lat": -999,
    "end_lon": -999,
    "drogue_off_date": np.datetime64("1970-01-01 00:00:00"),
    "death_code": -999,
}

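# Target dtypes for the metadata variables.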
_VAR_DTYPES: dict = {
    "rowsize": np.int64,
    "wmo_number": np.int64,
    "program_number": np.int64,
    "buoys_type": np.str_,
    "start_date": np.dtype("datetime64[ns]"),
    "start_lat": np.float64,
    "start_lon": np.float64,
    "end_date": np.dtype("datetime64[ns]"),
    "end_lat": np.float64,
    "end_lon": np.float64,
    "drogue_off_date": np.dtype("datetime64[ns]"),
    "death_code": np.int64,
}


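# Column names (in file order) and dtypes of the whitespace-delimited raw buoydata files.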
_INPUT_COLS = [
    "id",
    "posObsMonth",
    "posObsDay",
    "posObsYear",
    "latitude",
    "longitude",
    "senObsMonth",
    "senObsDay",
    "senObsYear",
    "drogue",
    "sst",
    "voltage",
    "sensor4",
    "sensor5",
    "sensor6",
    "qualityIndex",
]


_INPUT_COLS_DTYPES = {
    "id": np.int64,
    "posObsMonth": np.int8,
    "posObsDay": np.float64,
    "posObsYear": np.int16,
    "latitude": np.float32,
    "longitude": np.float32,
    "qualityIndex": np.float32,
    "senObsMonth": np.int8,
    "senObsDay": np.float64,
    "senObsYear": np.int16,
    "drogue": np.float32,
    "sst": np.float32,
    "voltage": np.float32,
    "sensor4": np.float32,
    "sensor5": np.float32,
    "sensor6": np.float32,
}


VARS_ATTRS: dict = {
    "id": {"long_name": "Global Drifter Program Buoy ID", "units": "-"},
    "rowsize": {
        "long_name": "Number of observations per trajectory",
        "sample_dimension": "obs",
        "units": "-",
    },
    "wmo_number": {
        "long_name": "World Meteorological Organization buoy identification number",
        "units": "-",
    },
    "program_number": {
        "long_name": "Current Program",
        "units": "-",
    },
    "buoys_type": {
        "long_name": "Buoy type (see https://www.aoml.noaa.gov/phod/dac/dirall.html)",
        "units": "-",
    },
    "start_date": {
        "long_name": "First good date and time derived by DAC quality control",
    },
    "start_lon": {
        "long_name": "First good longitude derived by DAC quality control",
        "units": "degrees_east",
    },
    "start_lat": {
        "long_name": "First good latitude derived by DAC quality control",
        "units": "degrees_north",
    },
    "end_date": {
        "long_name": "Last good date and time derived by DAC quality control",
    },
    "end_lon": {
        "long_name": "Last good longitude derived by DAC quality control",
        "units": "degrees_east",
    },
    "end_lat": {
        "long_name": "Last good latitude derived by DAC quality control",
        "units": "degrees_north",
    },
    "drogue_off_date": {
        "long_name": "Date and time of drogue loss",
    },
    "death_code": {
        "long_name": "Type of death",
        "units": "-",
        "comments": "0 (buoy still alive), 1 (buoy ran aground), 2 (picked up by vessel), 3 (stop transmitting), 4 (sporadic transmissions), 5 (bad batteries), 6 (inactive status)",
    },
    "position_datetime": {
        "comments": "Position datetime derived from the year, month, day and time (represented as a ratio of a day) columns found in the source dataset that represent when the position of the drifter was measured. This value is only different from the sensor_datetime when the position of the drifter was determined onboard the Argos satellites using the doppler shift.",
    },
    "sensor_datetime": {
        "comments": "Sensor datetime derived from the year, month, day and time (represented as a ratio of a day) columns found in the source dataset that represent when the sensor (like temp) data is recorded",
    },
    "longitude": {"long_name": "Longitude", "units": "degrees_east"},
    "latitude": {"long_name": "Latitude", "units": "degrees_north"},
    "drogue": {
        "long_name": "Drogue",
        "units": "-",
        "comments": "Values returned by drogue sensor",
    },
    "sst": {
        "long_name": "sea water temperature",
        "units": "degree_Celsius",
        "comments": "sea water temperature from drifting buoy measurements",
    },
    "voltage": {
        "long_name": "Voltage",
        "units": "V",
    },
    "sensor4": {
        "long_name": "Sensor 4",
        "units": "-",
        "comments": "placeholders for additional sensors such as barometers, salinity etc.",
    },
    "sensor5": {
        "long_name": "Sensor 5",
        "units": "-",
        "comments": "placeholders for additional sensors such as barometers, salinity etc.",
    },
    "sensor6": {
        "long_name": "Sensor 6",
        "units": "-",
        "comments": "placeholders for additional sensors such as barometers, salinity etc.",
    },
    "qualityIndex": {
        "long_name": "Quality Index",
        "units": "-",
    },
}

ATTRS = {
    "title": "Global Drifter Program source drifter dataset",
    "Conventions": "CF-1.6",
    "date_created": datetime.datetime.now().isoformat(),
    "publisher_name": "GDP Drifter DAC",
    "publisher_email": "aoml.dftr@noaa.gov",
    "publisher_url": "https://www.aoml.noaa.gov/phod/gdp",
    "license": "freely available",
    "processing_level": "source files",
    "metadata_link": "https://www.aoml.noaa.gov/phod/dac/dirall.html",
    "contributor_name": "NOAA Global Drifter Program",
    "contributor_role": "Data Acquisition Center",
    "institution": "NOAA Atlantic Oceanographic and Meteorological Laboratory",
    "summary": "Global Drifter Program source (raw) data",
}

_logger = logging.getLogger(__name__)


def _get_download_list(tmp_path: str) -> list[tuple[str, str]]:
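    """Build the list of (url, local destination path) pairs for the raw source data files."""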
    suffix = "rawfiles"
    batches = [(1, 5000), (5001, 10_000), (10_001, 15_000), (15_001, "current")]

    requests = list()

    for start, end in batches:
        filename = _FILENAME_TEMPLATE.format(start=start, end=end, suffix=suffix)
        requests.append((f"{_DATA_URL}/{filename}", os.path.join(tmp_path, filename)))
    return requests


def _rowsize(id_, **kwargs) -> int:
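    """Return the number of observations for drifter `id_` in the `data_df` dataframe passed through kwargs."""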
    df: pd.DataFrame | None = kwargs.get("data_df")

    if df is None:
        raise KeyError(
            "Missing `data_df`, please pass them into the `from_files` method"
        )

    traj_data_df = df[df["id"] == id_]
    return len(traj_data_df)


def _preprocess(id_, **kwargs) -> xr.Dataset:
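    """Build a single-drifter xarray Dataset from the metadata (`md_df`) and observation
    (`data_df`) dataframes passed through kwargs, using fill values when the drifter is
    absent from the metadata table and `use_fill_values` is enabled.
    """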
    md_df: pd.DataFrame | None = kwargs.get("md_df")
    data_df: pd.DataFrame | None = kwargs.get("data_df")
    use_fill_values: bool = kwargs.get("use_fill_values", False)

    if md_df is None or data_df is None:
        raise KeyError(
            "Missing `md_df` or `data_df`, pass them into the `from_files` method"
        )

    traj_md_df = md_df[md_df["ID"] == id_]
    traj_data_df = data_df[data_df["id"] == id_]
    rowsize = len(traj_data_df)

    md_variables = {
        "rowsize": np.array([rowsize], dtype=np.int64),
    }

    if use_fill_values and len(traj_md_df) == 0:
        for md_var in _METADATA_VARS:
            if md_var == "rowsize":
                continue

            fill_val = _VARS_FILL_MAP.get(md_var)
            if fill_val is None:
                raise ValueError(
                    f"Fill value missing for metadata variable: `{md_var}`"
                )

            md_variables.update({md_var: np.array([fill_val])})
    else:
        md_variables.update(
            {
                "wmo_number": traj_md_df[["WMO_number"]].values.flatten(),
                "program_number": traj_md_df[["program_number"]].values.flatten(),
                "buoys_type": traj_md_df[["buoys_type"]].values.flatten(),
                "start_date": traj_md_df[["Start_date"]].values.flatten(),
                "start_lat": traj_md_df[["Start_lat"]].values.flatten(),
                "start_lon": traj_md_df[["Start_lon"]].values.flatten(),
                "end_date": traj_md_df[["End_date"]].values.flatten(),
                "end_lat": traj_md_df[["End_lat"]].values.flatten(),
                "end_lon": traj_md_df[["End_lon"]].values.flatten(),
                "drogue_off_date": traj_md_df[["Drogue_off_date"]].values.flatten(),
                "death_code": traj_md_df[["death_code"]].values.flatten(),
            }
        )

    for md_var in _METADATA_VARS:
        var = md_variables.get(md_var)

        if var is None:
            raise ValueError(f"Metadata variable `{md_var}` cannot be found")

        md_variables.update({md_var: var.astype(_VAR_DTYPES.get(md_var))})

    variables = {k: (["traj"], md_variables[k]) for k in md_variables}
    data_vars = {
        var: (["obs"], traj_data_df[[var]].values.flatten()) for var in _DATA_VARS
    }

    coords = {
        "id": (["traj"], np.array([id_]).astype(np.int64)),
        "obs_index": (
            ["obs"],
            traj_data_df[["obs_index"]].values.flatten().astype(np.int32),
        ),
    }

    variables.update(data_vars)
    dataset = xr.Dataset(variables, coords=coords)
    return dataset


def _apply_remove(df: pd.DataFrame, filters: list[Callable]) -> pd.DataFrame:
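    """Return `df` with the rows matched by any of the boolean `filters` removed."""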
    temp_df = df
    for filter_ in filters:
        mask = filter_(temp_df)
        temp_df = temp_df[~mask]
    return temp_df


def _apply_transform(
    df: pd.DataFrame,
    transforms: dict[str, tuple[list[str], Callable]],
) -> pd.DataFrame:
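    """Derive each output column by applying its function to the associated input columns,
    dropping those input columns afterwards.
    """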
    tmp_df = df
    for output_col in transforms.keys():
        input_cols, func = transforms[output_col]
        args = list()
        for col in input_cols:
            arg = df[[col]].values.flatten()
            args.append(arg)
        tmp_df = tmp_df.assign(**{output_col: func(*args)})
        tmp_df = tmp_df.drop(input_cols, axis=1)
    return tmp_df


def _parse_datetime_with_day_ratio(
    month_series: np.ndarray, day_series: np.ndarray, year_series: np.ndarray
) -> np.ndarray:
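    """Combine month, fractional-day, and year columns into datetime64[ns] values,
    converting the fractional part of each day into seconds.
    """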
    values = list()
    for month, day_with_ratio, year in zip(month_series, day_series, year_series):
        day = day_with_ratio // 1
        dayratio = day_with_ratio - day
        seconds = dayratio * _SECONDS_IN_DAY
        dt_ns = (
            datetime.datetime(year=int(year), month=int(month), day=int(1))
            + datetime.timedelta(days=int(day), seconds=seconds)
        ).timestamp() * 10**9
        values.append(int(dt_ns))
    return np.array(values).astype("datetime64[ns]")


def _process_chunk(
    df_chunk: pd.DataFrame,
    start_idx: int,
    end_idx: int,
    gdp_metadata_df: pd.DataFrame,
    use_fill_values: bool,
) -> dict[int, xr.Dataset]:
    """Process each dataframe chunk. Return a dictionary mapping each drifter to a unique xarray Dataset."""

    # Transform the initial dataframe, filtering out rows with clearly anomalous values.
    # Examples include years in the future, years predating the GDP program, etc.
    preremove_df_chunk = df_chunk.assign(obs_index=range(start_idx, end_idx))
    df_chunk = _apply_remove(
        preremove_df_chunk,
        filters=[
            # Filter out year values that are in the future or predating the GDP program
            lambda df: (df["posObsYear"] > datetime.datetime.now().year)
            | (df["posObsYear"] < 0),
            lambda df: (df["senObsYear"] > datetime.datetime.now().year)
            | (df["senObsYear"] < 0),
            # Filter out month values that contain non-numeric characters
            lambda df: df["senObsMonth"].astype(np.str_).str.contains(r"[\D]"),
            lambda df: df["posObsMonth"].astype(np.str_).str.contains(r"[\D]"),
            # Filter out drogue values that cannot be interpreted as floating-point values
            # (e.g. values with more than one decimal point).
            lambda df: df["drogue"].astype(np.str_).str.match(r"(\d+[\.]+){2,}"),
        ],
    )

    drifter_ds_map = dict[int, xr.Dataset]()

    preremove_len = len(preremove_df_chunk)
    postremove_len = len(df_chunk)

    if preremove_len != postremove_len:
        warnings.warn(
            f"Filters removed {preremove_len - postremove_len} rows from chunk"
        )

    if postremove_len == 0:
        return drifter_ds_map

    df_chunk = df_chunk.astype(_INPUT_COLS_DTYPES)
    df_chunk = _apply_transform(
        df_chunk,
        {
            "position_datetime": (
                ["posObsMonth", "posObsDay", "posObsYear"],
                _parse_datetime_with_day_ratio,
            ),
            "sensor_datetime": (
                ["senObsMonth", "senObsDay", "senObsYear"],
                _parse_datetime_with_day_ratio,
            ),
        },
    )

    # Identify the drifters in this chunk that are documented in the drifter metadata.
    ids_in_data = np.unique(df_chunk[["id"]].values)
    ids_with_md = np.intersect1d(ids_in_data, gdp_metadata_df[["ID"]].values)

    if len(ids_with_md) < len(ids_in_data):
        warnings.warn(
            "Chunk has drifter ids not found in the metadata table. "
            + ("Using fill values" if use_fill_values else "Ignoring data observations")
            + f" for missing metadata ids: {np.setdiff1d(ids_in_data, ids_with_md)}."
        )

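    # Assemble a ragged array for this chunk, then split it back into one dataset
    # per drifter present in the metadata table.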
    ra = RaggedArray.from_files(
        indices=ids_in_data if use_fill_values else ids_with_md,
        preprocess_func=_preprocess,
        rowsize_func=_rowsize,
        name_coords=_COORDS,
        name_meta=_METADATA_VARS,
        name_data=_DATA_VARS,
        name_dims={"traj": "rows", "obs": "obs"},
        md_df=gdp_metadata_df,
        data_df=df_chunk,
        use_fill_values=use_fill_values,
        tqdm=dict(disable=True),
    )
    ds = ra.to_xarray()

    for id_ in ids_with_md:
        id_f_ds = subset(ds, dict(id=id_), row_dim_name="traj")
        drifter_ds_map[id_] = id_f_ds
    return drifter_ds_map


def _combine_chunked_drifter_datasets(datasets: list[xr.Dataset]) -> xr.Dataset:
    """Combines several drifter observations found in separate chunks, ordering them
    by the observations row index.
    """
    traj_dataset = xr.concat(
        datasets, dim="obs", coords="minimal", data_vars=_DATA_VARS, compat="override"
    )

    new_rowsize = sum([ds.rowsize.values[0] for ds in datasets])
    traj_dataset["rowsize"] = xr.DataArray(
        np.array([new_rowsize], dtype=np.int64), coords=traj_dataset["rowsize"].coords
    )

    sort_coord = traj_dataset.coords["obs_index"]
    vals: np.ndarray = sort_coord.data
    sort_coord_dim = sort_coord.dims[-1]
    sort_key = vals.argsort()

    for coord_name in _COORDS:
        coord = traj_dataset.coords[coord_name]
        dim = coord.dims[-1]

        if dim == sort_coord_dim:
            sorted_coord = coord.isel({dim: sort_key})
            traj_dataset.coords[coord_name] = sorted_coord

    for varname in _DATA_VARS:
        var = traj_dataset[varname]
        dim = var.dims[-1]
        sorted_var = var.isel({dim: sort_key})
        traj_dataset[varname] = sorted_var

    return traj_dataset


async def _parallel_get(
    sources: list[str],
    gdp_metadata_df: pd.DataFrame,
    chunk_size: int,
    tmp_path: str,
    use_fill_values: bool,
    max_chunks: int | None,
) -> list[xr.Dataset]:
    """Parallel process dataset in chunks leveraging multiprocessing."""
    max_workers = max(1, (os.cpu_count() or 1) // 2)
    with ProcessPoolExecutor(max_workers=max_workers) as ppe:
        drifter_chunked_datasets: dict[int, list[xr.Dataset]] = defaultdict(list)
        start_idx = 0
        for fp in tqdm(
            sources,
            desc="Loading files",
            unit="file",
            ncols=80,
            total=len(sources),
            position=0,
        ):
            file_chunks = pd.read_csv(
                fp,
                sep=r"\s+",
                header=None,
                names=_INPUT_COLS,
                engine="c",
                compression="gzip",
                chunksize=chunk_size,
            )

            joblist = list[Future]()
            jobmap = dict[Future, pd.DataFrame]()
            for idx, chunk in enumerate(file_chunks):
                if max_chunks is not None and idx >= max_chunks:
                    break
                ajob = ppe.submit(
                    _process_chunk,
                    chunk,
                    start_idx,
                    start_idx + len(chunk),
                    gdp_metadata_df,
                    use_fill_values,
                )
                start_idx += len(chunk)
                jobmap[ajob] = chunk
                joblist.append(ajob)

            bar = tqdm(
                desc="Processing file chunks",
                unit="chunk",
                ncols=80,
                total=len(joblist),
                position=1,
            )

            for ajob in as_completed(jobmap.keys()):
                if (exc := ajob.exception()) is not None:
                    chunk = jobmap[ajob]
                    _logger.warning(f"bad chunk detected, exception: {exc}")
                    raise exc

                job_drifter_ds_map: dict[int, xr.Dataset] = ajob.result()
                for id_ in job_drifter_ds_map.keys():
                    drifter_ds = job_drifter_ds_map[id_]
                    drifter_chunked_datasets[id_].append(drifter_ds)
                bar.update()

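        # All chunks are processed; merge each drifter's per-chunk datasets in worker processes.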
        combine_jobmap = dict[Future, int]()
        for id_ in drifter_chunked_datasets.keys():
            datasets = drifter_chunked_datasets[id_]

            combine_job = ppe.submit(_combine_chunked_drifter_datasets, datasets)
            combine_jobmap[combine_job] = id_

        bar.close()
        bar = tqdm(
            desc="merging drifter chunks",
            unit="drifter",
            ncols=80,
            total=len(drifter_chunked_datasets.keys()),
            position=2,
        )

        os.makedirs(os.path.join(tmp_path, "drifters"), exist_ok=True)

        drifter_datasets = list[xr.Dataset]()
        for combine_job in as_completed(combine_jobmap.keys()):
            dataset: xr.Dataset = combine_job.result()
            drifter_datasets.append(dataset)
            bar.update()
        bar.close()
        return drifter_datasets


def to_raggedarray(
    tmp_path: str = _TMP_PATH,
    skip_download: bool = False,
    max: int | None = None,
    chunk_size: int = 100_000,
    use_fill_values: bool = True,
    max_chunks: int | None = None,
) -> xr.Dataset:
    """Get the GDP source dataset."""
    os.makedirs(tmp_path, exist_ok=True)
    requests = _get_download_list(tmp_path)

    # Filter down for testing purposes.
    if max:
        requests = [requests[max]]

    # Download necessary data and metadata files.
    if not skip_download:
        download_with_progress(requests)

    gdp_metadata_df = get_gdp_metadata(tmp_path)

    # Run async process to parallelize data processing.
    drifter_datasets = asyncio.run(
        _parallel_get(
            [dst for (_, dst) in requests],
            gdp_metadata_df,
            chunk_size,
            tmp_path,
            use_fill_values,
            max_chunks,
        )
    )

    # Sort the drifters by their start date.
    deploy_date_id_map = {
        ds["id"].data[0]: ds["start_date"].data[0] for ds in drifter_datasets
    }
    deploy_date_sort_key = np.argsort(list(deploy_date_id_map.values()))
    sorted_drifter_datasets = [drifter_datasets[idx] for idx in deploy_date_sort_key]

    # Concatenate drifter data and metadata variables separately.
    obs_ds = xr.concat(
        [ds.drop_dims("traj") for ds in sorted_drifter_datasets],
        dim="obs",
        data_vars=_DATA_VARS,
    )
    traj_ds = xr.concat(
        [ds.drop_dims("obs") for ds in sorted_drifter_datasets],
        dim="traj",
        data_vars=_METADATA_VARS,
    )

    # Merge the separate datasets.
    agg_ds = xr.merge([obs_ds, traj_ds])

    # Add variable metadata.
    for var_name in _DATA_VARS + _METADATA_VARS:
        if var_name in VARS_ATTRS.keys():
            agg_ds[var_name].attrs = VARS_ATTRS[var_name]

    agg_ds.attrs = ATTRS

    return agg_ds
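# A minimal usage sketch (not part of the module): processing the full source dataset
# requires network access and substantial disk space, so `max_chunks` is used here only
# to limit processing for a quick trial run.
#
#     from clouddrift.adapters.gdp.gdpsource import to_raggedarray
#
#     ds = to_raggedarray(max_chunks=2)
#     ds.to_netcdf("gdp_source_subset.nc")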