Source code for sdf_xarray.dataset_accessor
from __future__ import annotations
from types import MethodType
from typing import TYPE_CHECKING
import xarray as xr
from .plotting import animate_multiple, show
if TYPE_CHECKING:
from matplotlib.animation import FuncAnimation
[docs]
@xr.register_dataset_accessor("epoch")
class EpochAccessor:
    """
    Dataset accessor registered with xarray under the name ``epoch``.

    The methods are reached as ``ds.epoch.<method>(...)`` and operate on the
    Dataset the accessor was created for.
    """

    def __init__(self, xarray_obj: xr.Dataset):
        # The Dataset this accessor is attached to; every method reads it
        # through ``self._ds``.
        self._ds = xarray_obj

    def rescale_coords(
        self,
        multiplier: float,
        unit_label: str,
        coord_names: str | list[str],
    ) -> xr.Dataset:
        """
        Rescale the specified coordinates in the Dataset by a given multiplier
        and update their ``units`` attribute.

        Parameters
        ----------
        multiplier : float
            The factor by which to multiply the coordinate values
            (e.g., 1e6 for meters to microns).
        unit_label : str
            The new unit label for the coordinates (e.g., "µm").
        coord_names : str or list of str
            The name(s) of the coordinate variable(s) to rescale.
            If a string, only that coordinate is rescaled.
            If a list, all listed coordinates are rescaled.

        Returns
        -------
        xr.Dataset
            The Dataset returned by ``assign_coords`` with the rescaled
            coordinates.

        Raises
        ------
        ValueError
            If any requested name is not a coordinate of the Dataset.

        Examples
        --------
        # Convert X, Y, and Z from meters to microns
        >>> ds_in_microns = ds.epoch.rescale_coords(1e6, "µm", coord_names=["X_Grid", "Y_Grid", "Z_Grid"])

        # Convert only X to millimeters
        >>> ds_in_mm = ds.epoch.rescale_coords(1000, "mm", coord_names="X_Grid")
        """
        ds = self._ds

        # A bare string is a single coordinate name and must not be iterated
        # character by character; any other iterable is copied into a list.
        if isinstance(coord_names, str):
            coords_to_process = [coord_names]
        else:
            coords_to_process = list(coord_names)

        new_coords = {}
        for coord_name in coords_to_process:
            if coord_name not in ds.coords:
                raise ValueError(
                    f"Coordinate '{coord_name}' not found in the Dataset. Cannot rescale."
                )

            coord_original = ds[coord_name]
            coord_rescaled = coord_original * multiplier
            # Give the result its own copy of the original attrs so that only
            # "units" differs and the source attrs dict is left unmodified.
            coord_rescaled.attrs = coord_original.attrs.copy()
            coord_rescaled.attrs["units"] = unit_label
            new_coords[coord_name] = coord_rescaled

        return ds.assign_coords(new_coords)

    def animate_multiple(
        self,
        *variables: str | xr.DataArray,
        datasets_kwargs: list[dict] | None = None,
        **kwargs,
    ) -> FuncAnimation:
        """
        Animate multiple Dataset variables on the same axes.

        Parameters
        ----------
        variables
            The variables to animate. Each one is either a ``DataArray`` or
            the name of a variable, which is looked up in this Dataset.
        datasets_kwargs
            Per-dataset keyword arguments passed to plotting.
        kwargs
            Common keyword arguments forwarded to animation.

        Returns
        -------
        FuncAnimation
            The animation object, with an extra ``show`` method bound to it.

        Examples
        --------
        >>> anim = ds.epoch.animate_multiple(
        ...     ds["Derived_Number_Density_Electron"],
        ...     ds["Derived_Number_Density_Ion"],
        ...     datasets_kwargs=[{"label": "Electron"}, {"label": "Ion"}],
        ...     ylabel="Derived Number Density [1/m$^3$]"
        ... )
        >>> anim.save("animation.gif")
        >>> # Or in a jupyter notebook:
        >>> anim.show()
        """
        # Resolve names against the wrapped Dataset. The attribute is
        # ``_ds`` (set in ``__init__``); there is no ``_obj`` on this class.
        dataarrays = [
            self._ds[var] if isinstance(var, str) else var for var in variables
        ]

        # Inside the method body the bare name ``animate_multiple`` resolves
        # to the module-level plotting function, not to this method.
        anim = animate_multiple(
            *dataarrays,
            datasets_kwargs=datasets_kwargs,
            **kwargs,
        )

        # Bind the plotting helper ``show`` to this animation instance so
        # callers can write ``anim.show()``.
        anim.show = MethodType(show, anim)
        return anim