# Source code for clouddrift.plotting

"""
This module provides a function to easily and efficiently plot the rows of a ragged array.
"""

import numpy as np
import pandas as pd
import xarray as xr

from clouddrift.ragged import rowsize_to_index, segment


def plot_ragged(
    ax,
    longitude: list | np.ndarray | pd.Series | xr.DataArray,
    latitude: list | np.ndarray | pd.Series | xr.DataArray,
    rowsize: list | np.ndarray | pd.Series | xr.DataArray,
    *args,
    colors: list | np.ndarray | pd.Series | xr.DataArray | None = None,
    tolerance: float | int = 180,
    **kwargs,
):
    """Plot individually the rows of a ragged array dataset on a Matplotlib Axes
    or a Cartopy GeoAxes object ``ax``.

    This function wraps Matplotlib's ``plot`` function (``plt.plot``) and
    ``LineCollection`` (``matplotlib.collections``) to efficiently plot the rows
    of a ragged array dataset. Each row is additionally split into segments
    according to ``tolerance`` (see below) before being drawn.

    Parameters
    ----------
    ax : matplotlib.axes.Axes or cartopy.mpl.geoaxes.GeoAxes
        Axis to plot on.
    longitude : array-like
        Longitude sequence. Unidimensional array input.
    latitude : array-like
        Latitude sequence. Unidimensional array input.
    rowsize : array-like
        Integers specifying the number of data points in each row. Must sum
        to the length of ``longitude``.
    *args : tuple
        Additional positional arguments passed to ``ax.plot`` (per-row colors)
        or to ``LineCollection`` (per-point colors).
    colors : array-like, optional
        Values to map on the colormap. If ``colors`` has the length of
        ``rowsize``, the data points of each row are uniformly colored
        according to the value for the row. If ``colors`` has the length of
        ``longitude`` and ``latitude``, the line between two consecutive data
        points is colored by the average of their two values. When both
        lengths coincide, the per-row interpretation is used. Defaults to
        ``np.arange(len(rowsize))``, i.e. one distinct value per row.
    tolerance : float
        Longitude tolerance gap between data points (in degrees) for segmenting
        rows. For periodic domains, the tolerance parameter should be set to
        the maximum allowed gap between data points. Defaults to 180.
    **kwargs : dict
        Additional keyword arguments to pass to ``ax.plot`` or
        ``LineCollection``. The ``cmap`` (a colormap or its name; default
        viridis) and ``norm`` (default: linear normalization over the
        NaN-ignoring range of ``colors``) keywords are consumed by this
        function.

    Returns
    -------
    matplotlib.cm.ScalarMappable
        A mappable sharing the colormap and normalization used for the plotted
        lines. Pass it to ``fig.colorbar`` to add a colorbar after plotting.

    Notes
    -----
    The axis limits are set to the NaN-ignoring range of ``longitude`` and
    ``latitude``.

    Examples
    --------
    Load 100 trajectories from the gdp1h dataset for the examples.

    >>> from clouddrift import datasets
    >>> from clouddrift.ragged import subset
    >>> from clouddrift.plotting import plot_ragged
    >>> import numpy as np
    >>> import matplotlib.pyplot as plt
    >>> from mpl_toolkits.axes_grid1 import make_axes_locatable
    >>> ds = datasets.gdp1h()
    >>> ds = subset(ds, {"id": ds.id[:100].values}).load()

    Plot the trajectories, assigning a different color to each trajectory:

    >>> fig = plt.figure()
    >>> ax = fig.add_subplot(1, 1, 1)
    >>> l = plot_ragged(
    ...     ax,
    ...     ds.lon,
    ...     ds.lat,
    ...     ds.rowsize,
    ...     colors=np.arange(len(ds.rowsize))
    ... )
    >>> divider = make_axes_locatable(ax)
    >>> cax = divider.append_axes('right', size='3%', pad=0.05)
    >>> fig.colorbar(l, cax=cax)

    To plot the same trajectories, but assigning a different color to each
    data point based on a transformation of the time variable mapped onto the
    ``inferno`` colormap:

    >>> fig = plt.figure()
    >>> ax = fig.add_subplot(1, 1, 1)
    >>> time = [v.astype(np.int64) / 86400 / 1e9 for v in ds.time.values]
    >>> l = plot_ragged(
    ...     ax,
    ...     ds.lon,
    ...     ds.lat,
    ...     ds.rowsize,
    ...     colors=np.floor(time),
    ...     cmap="inferno"
    ... )
    >>> divider = make_axes_locatable(ax)
    >>> cax = divider.append_axes('right', size="3%", pad=0.05)
    >>> fig.colorbar(l, cax=cax)

    Finally, to plot the same trajectories, but using a cartopy projection:

    >>> import cartopy.crs as ccrs
    >>> fig = plt.figure()
    >>> ax = fig.add_subplot(1, 1, 1, projection=ccrs.Mollweide())
    >>> l = plot_ragged(
    ...     ax,
    ...     ds.lon,
    ...     ds.lat,
    ...     ds.rowsize,
    ...     colors=np.arange(len(ds.rowsize)),
    ...     transform=ccrs.PlateCarree(),
    ...     cmap="Blues",
    ... )
    >>> ax.set_extent([-180, 180, -90, 90])
    >>> ax.coastlines()
    >>> ax.gridlines(draw_labels=True)
    >>> divider = make_axes_locatable(ax)
    >>> cax = divider.append_axes('right', size="3%", pad=0.25, axes_class=plt.Axes)
    >>> fig.colorbar(l, cax=cax)

    Raises
    ------
    ValueError
        If the sum of rowsize does not equal the length of longitude.
        If longitude and latitude arrays do not have the same length.
        If colors do not have the same length as longitude and latitude arrays or rowsize.
        If ax is not a matplotlib Axes or GeoAxes object.
        If ax is a GeoAxes object and the transform keyword argument is not provided.
    ImportError
        If matplotlib is not installed.
        If the axis is a GeoAxes object and cartopy is not installed.
    """
    # optional dependency
    try:
        import matplotlib.colors as mcolors
        import matplotlib.pyplot as plt
        from matplotlib import cm
        from matplotlib.collections import LineCollection
    except ImportError as err:
        raise ImportError("missing optional dependency 'matplotlib'") from err

    if hasattr(ax, "coastlines"):  # check if GeoAxes without cartopy
        try:
            from cartopy.mpl.geoaxes import GeoAxes
        except ImportError as err:
            raise ImportError("missing optional dependency 'cartopy'") from err
        if isinstance(ax, GeoAxes) and not kwargs.get("transform"):
            raise ValueError(
                "For GeoAxes, the transform keyword argument must be provided."
            )
    elif not isinstance(ax, plt.Axes):
        raise ValueError("ax must be either: plt.Axes or GeoAxes.")

    # Work on plain NumPy arrays: indexing/slicing is then always positional
    # (pandas index labels are ignored) and per-row slicing avoids the
    # overhead of slicing pandas/xarray containers in the loop below.
    longitude = np.asarray(longitude)
    latitude = np.asarray(latitude)
    rowsize = np.asarray(rowsize)

    if np.sum(rowsize) != len(longitude):
        raise ValueError("The sum of rowsize must equal the length of lon and lat.")

    if len(longitude) != len(latitude):
        raise ValueError("lon and lat must have the same length.")

    if colors is None:
        # default: one distinct color value per row
        colors = np.arange(len(rowsize))
    else:
        colors = np.asarray(colors)
        if len(colors) not in (len(longitude), len(rowsize)):
            raise ValueError(
                "shape colors must match the shape of lon/lat or rowsize."
            )

    # define a colormap
    cmap = kwargs.pop("cmap", cm.viridis)
    if isinstance(cmap, str):
        cmap = plt.get_cmap(cmap)

    # define a normalization to obtain uniform colors across the sequence of
    # lines or LineCollections; an explicit norm=None falls back to the default
    norm = kwargs.pop("norm", None)
    if norm is None:
        norm = mcolors.Normalize(vmin=np.nanmin(colors), vmax=np.nanmax(colors))

    # create Mappable for colorbar
    cb = plt.cm.ScalarMappable(norm=norm, cmap=cmap)

    # one value per row -> uniformly colored ax.plot lines;
    # one value per data point -> LineCollection colored per segment
    per_row_colors = len(colors) == len(rowsize)

    # row offsets: row i spans traj_idx[i]:traj_idx[i + 1]
    traj_idx = rowsize_to_index(rowsize)

    for i, (i0, i1) in enumerate(zip(traj_idx[:-1], traj_idx[1:])):
        if i1 <= i0:
            continue  # empty row: nothing to draw

        lon_i = longitude[i0:i1]
        lat_i = latitude[i0:i1]
        colors_i = None if per_row_colors else colors[i0:i1]

        # split the row using segment() with +tolerance and -tolerance;
        # presumably this breaks the row at longitude jumps larger than
        # tolerance in either direction (see clouddrift.ragged.segment)
        start = 0
        for length in segment(lon_i, tolerance, rowsize=segment(lon_i, -tolerance)):
            end = start + length
            if per_row_colors:
                ax.plot(
                    lon_i[start:end],
                    lat_i[start:end],
                    *args,
                    c=cmap(norm(colors[i])),
                    **kwargs,
                )
            else:
                # one (2, 2) line piece per pair of consecutive points
                segments = np.column_stack(
                    [
                        lon_i[start : end - 1],
                        lat_i[start : end - 1],
                        lon_i[start + 1 : end],
                        lat_i[start + 1 : end],
                    ]
                ).reshape(-1, 2, 2)
                lc = LineCollection(segments, *args, cmap=cmap, norm=norm, **kwargs)
                lc.set_array(
                    # color of a segment is the average of its two data points
                    np.convolve(colors_i[start:end], [0.5, 0.5], mode="valid")
                )
                ax.add_collection(lc)
            start = end

    # set axis limits, ignoring NaNs (NaN limits would make matplotlib raise)
    ax.set_xlim([np.nanmin(longitude), np.nanmax(longitude)])
    ax.set_ylim([np.nanmin(latitude), np.nanmax(latitude)])

    return cb