from __future__ import annotations
import warnings
from collections.abc import Callable
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
import numpy as np
import xarray as xr
if TYPE_CHECKING:
import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
[docs]
@dataclass()
class AnimationUnit:
    """Pair a per-frame redraw callback with the number of frames it can draw.

    Attributes
    ----------
    update
        Callable that receives an integer frame index and redraws that frame.
    n_frames
        Total number of frames available for animation.
    """

    # Frame-index -> drawn artist callback
    update: Callable[[int], object]
    # How many frames ``update`` may be called with (0 .. n_frames - 1)
    n_frames: int
[docs]
def get_frame_title(
    data: xr.DataArray,
    frame: int,
    display_sdf_name: bool = False,
    title_custom: str | None = None,
    t: str = "time",
) -> str:
    """Generate the title for a frame.

    The title has the form
    ``"[<title_custom>, ]<long_name> = <value>[ [<units>]][, <frame>.sdf]"``.

    Parameters
    ----------
    data
        DataArray containing the target data
    frame
        Frame number (positional index along ``t``)
    display_sdf_name
        Display the sdf file name in the animation title
    title_custom
        Custom title to add to the plot
    t
        Time coordinate. Its ``long_name`` attribute is used as the label; if it
        has none, the coordinate name itself is used instead.

    Returns
    -------
    str
        The formatted title text.
    """
    t_coord = data[t]

    # Adds custom text to the start of the title, if specified
    prefix = "" if title_custom is None else f"{title_custom}, "

    # Adds the time axis value and associated units to the title
    t_value = t_coord[frame].values
    t_units = t_coord.attrs.get("units", False)
    units_suffix = f" [{t_units}]" if t_units else ""
    # Coordinates without metadata (e.g. a bare index dimension) have no
    # long_name attribute; fall back to the coordinate name instead of raising.
    t_label = t_coord.attrs.get("long_name", t)
    title_t_axis = f"{t_label} = {t_value:.2e}{units_suffix}"

    # Adds sdf name to the title, if specified
    sdf_suffix = f", {frame:04d}.sdf" if display_sdf_name else ""
    return f"{prefix}{title_t_axis}{sdf_suffix}"
[docs]
def calculate_window_boundaries(
    data: xr.DataArray,
    xlim: tuple[float, float] | None = None,
    x_axis_name: str = "X_Grid_mid",
    t: str = "time",
) -> np.ndarray:
    """Calculate the boundaries of a moving window frame. If the user specifies xlim,
    this will be used as the initial boundaries and the window will move along
    accordingly.

    For each frame, a lineout along ``x_axis_name`` is taken (for 2D-per-frame data,
    at index 0 of the other spatial dimension). The window spans from the first to
    the last non-NaN sample, padded by half a grid cell on each side.

    Parameters
    ----------
    data
        DataArray containing the target data (2D or 3D including ``t``)
    xlim
        x limits
    x_axis_name
        Name of coordinate to assign to the x-axis
    t
        Time coordinate

    Returns
    -------
    numpy.ndarray
        Array of shape ``(n_frames, 2)`` with the lower and upper x boundary of
        each frame.

    Raises
    ------
    ValueError
        If ``data`` is not 2D or 3D, or if a frame's lineout contains only NaNs.
    """
    if data.ndim not in (2, 3):
        raise ValueError(
            f"Moving window boundaries require 2D or 3D data, got {data.ndim}D"
        )

    x_coord = data[x_axis_name]
    x_grid = x_coord.values
    x_dim = x_coord.dims[0]
    # Half of the spacing between the first two x samples (a uniform grid is
    # presumably assumed here — NOTE(review): confirm for non-uniform grids).
    x_half_cell = (x_grid[1] - x_grid[0]) / 2

    # Select frames along the dimension the time coordinate lives on, rather than
    # assuming time is the first axis of ``data``.
    t_dim = data[t].dims[0]
    n_frames = data[t].size

    # Find the window boundaries by finding the first and last non-NaN values in
    # the 0th lineout along the x-axis.
    window_boundaries = np.zeros((n_frames, 2))
    for i in range(n_frames):
        lineout = data.isel({t_dim: i})
        if lineout.ndim == 2:
            # 2D frame: reduce to a 1D lineout along x at index 0 of the other dim
            other_dim = next(dim for dim in lineout.dims if dim != x_dim)
            lineout = lineout.isel({other_dim: 0})
        x_non_nan = x_grid[~np.isnan(lineout.values)]
        if x_non_nan.size == 0:
            raise ValueError(
                f"Frame {i} along '{t}' contains only NaN values; "
                "cannot determine moving window boundaries"
            )
        window_boundaries[i, 0] = x_non_nan[0] - x_half_cell
        window_boundaries[i, 1] = x_non_nan[-1] + x_half_cell

    # User's choice for initial window edge supersedes the one calculated
    if xlim is not None:
        window_boundaries = (
            window_boundaries + np.asarray(xlim, dtype=float) - window_boundaries[0]
        )
    return window_boundaries
[docs]
def compute_global_limits(
    data: xr.DataArray,
    min_percentile: float = 0,
    max_percentile: float = 100,
) -> tuple[float, float]:
    """Remove all NaN values from the target data to calculate the global minimum
    and maximum of the data. User defined percentiles can remove extreme outliers.

    Parameters
    ----------
    data
        DataArray containing the target data
    min_percentile
        Minimum percentile of the data
    max_percentile
        Maximum percentile of the data

    Returns
    -------
    tuple[float, float]
        ``(global_min, global_max)`` at the requested percentiles.

    Raises
    ------
    ValueError
        If ``data`` contains no non-NaN values.
    """
    # Materialise the values only once; for lazily-loaded arrays each ``.values``
    # access can trigger a full read.
    values = data.values
    # Removes NaN values, needed for moving windows
    valid_mask = ~np.isnan(values)
    if not valid_mask.any():
        raise ValueError(
            "Cannot compute global limits: data contains no non-NaN values"
        )
    # Finds the global minimum and maximum of the plot, based on the percentile of the data
    global_min, global_max = np.percentile(
        values[valid_mask], [min_percentile, max_percentile]
    )
    return global_min, global_max
def _set_axes_labels(ax: plt.Axes, axis_kwargs: dict) -> None:
"""Set the labels for the x and y axes"""
if "xlabel" in axis_kwargs:
ax.set_xlabel(axis_kwargs["xlabel"])
if "ylabel" in axis_kwargs:
ax.set_ylabel(axis_kwargs["ylabel"])
def _setup_2d_plot(
    data: xr.DataArray,
    ax: plt.Axes,
    coord_names: list[str],
    kwargs: dict,
    axis_kwargs: dict,
    min_percentile: float,
    max_percentile: float,
    t: str,
) -> tuple[float, float]:
    """Draw the first frame of a line-plot animation and fix its y-range.

    ``kwargs`` is updated in place: ``"x"`` defaults to the first entry of
    ``coord_names``. Unless the caller supplied ``"ylim"``, the y-axis is pinned to
    the percentile limits computed over the whole of ``data``.

    Returns
    -------
    tuple[float, float]
        The ``(global_min, global_max)`` percentile limits of ``data``.
    """
    # Plot against the first non-animated dimension unless the caller chose an x.
    kwargs.setdefault("x", coord_names[0])
    first_frame = data.isel({t: 0})
    first_frame.plot(ax=ax, **kwargs)

    lower, upper = compute_global_limits(data, min_percentile, max_percentile)
    _set_axes_labels(ax, axis_kwargs)

    # Limits span every frame, so the y-axis stays fixed while animating.
    caller_set_ylim = "ylim" in kwargs
    if not caller_set_ylim:
        ax.set_ylim(lower, upper)
    return lower, upper
def _setup_3d_plot(
    data: xr.DataArray,
    ax: plt.Axes,
    coord_names: list[str],
    kwargs: dict,
    kwargs_original: dict,
    axis_kwargs: dict,
    min_percentile: float,
    max_percentile: float,
    t: str,
) -> None:
    """Setup 3D plot initialization.

    Draws one frame of ``data`` (two plotted dimensions plus ``t``) on ``ax``.
    It then mutates ``kwargs`` so later frames reuse the same normalisation and
    colormap and do not add their own colorbar.

    Parameters
    ----------
    data
        DataArray containing the target data
    ax
        Matplotlib axes on which to plot
    coord_names
        Dimensions of ``data`` other than ``t``; the first two are the default
        ``x`` and ``y``
    kwargs
        Plot keyword arguments; updated in place with ``norm``, ``add_colorbar``,
        ``x``, ``y`` and ``cmap``
    kwargs_original
        The caller's keyword arguments before any modification; used to decide
        whether a colorbar is wanted
    axis_kwargs
        ``xlabel`` / ``ylabel`` overrides
    min_percentile
        Minimum percentile of the data (used only when no ``norm`` is given)
    max_percentile
        Maximum percentile of the data (used only when no ``norm`` is given)
    t
        Dimension being animated over
    """
    import matplotlib.pyplot as plt  # noqa: PLC0415

    # One normalisation shared by all frames keeps colours comparable over time.
    if "norm" not in kwargs:
        global_min, global_max = compute_global_limits(
            data, min_percentile, max_percentile
        )
        kwargs["norm"] = plt.Normalize(vmin=global_min, vmax=global_max)
    # A single colorbar is added below (if requested), never one per frame.
    kwargs["add_colorbar"] = False
    kwargs.setdefault("x", coord_names[0])
    kwargs.setdefault("y", coord_names[1])

    # Start from the frame containing the global minimum.
    # - NaNs (e.g. outside a moving window) are ignored; plain argmin would
    #   return the position of the first NaN instead.
    # - The index is read along t's own axis, which need not be axis 0.
    values = data.values
    if np.isnan(values).all():
        initial_frame = 0
    else:
        min_position = np.unravel_index(np.nanargmin(values), values.shape)
        initial_frame = int(min_position[data.get_axis_num(t)])
    plot = data.isel({t: initial_frame}).plot(ax=ax, **kwargs)
    # Reuse the colormap of this first draw for every later frame.
    kwargs["cmap"] = plot.cmap
    _set_axes_labels(ax, axis_kwargs)

    if kwargs_original.get("add_colorbar", True):
        long_name = data.attrs.get("long_name", data.name)
        units = data.attrs.get("units")
        # Omit missing metadata instead of rendering "None [None]".
        label_parts = []
        if long_name is not None:
            label_parts.append(str(long_name))
        if units:
            label_parts.append(f"[{units}]")
        fig = plot.get_figure()
        fig.colorbar(plot, ax=ax, label=" ".join(label_parts))
def _generate_animation(
    data: xr.DataArray,
    clear_axes: bool = False,
    min_percentile: float = 0,
    max_percentile: float = 100,
    title: str | None = None,
    display_sdf_name: bool = False,
    move_window: bool = False,
    t: str | None = None,
    ax: plt.Axes | None = None,
    kwargs: dict | None = None,
) -> AnimationUnit:
    """
    Internal function for generating the plotting logic required for animations.

    The initial frame is drawn on ``ax`` immediately. The returned
    `AnimationUnit` holds an ``update`` callback that redraws a given frame.

    Parameters
    ----------
    data
        DataArray containing the target data
    clear_axes
        Decide whether to run ``ax.clear()`` in every update
    min_percentile
        Minimum percentile of the data
    max_percentile
        Maximum percentile of the data
    title
        Custom title to add to the plot
    display_sdf_name
        Display the sdf file name in the animation title
    move_window
        Update the ``xlim`` to be only values that are not NaNs at each time interval
    t
        Coordinate for t axis (the coordinate which will be animated over).
        If ``None``, use ``"time"`` when it is a dimension of ``data``, otherwise
        the last dimension of ``data``
    ax
        Matplotlib axes on which to plot. If ``None``, a new figure and axes are
        created
    kwargs
        Keyword arguments to be passed to matplotlib. The mapping is copied, so
        the caller's dictionary is left unmodified

    Returns
    -------
    AnimationUnit
        The per-frame update callback and the number of frames along ``t``
    """
    # Work on a private copy: the setup helpers and the update closure add keys
    # such as "x", "norm", "cmap" and "xlim".
    kwargs = {} if kwargs is None else dict(kwargs)
    kwargs_original = kwargs.copy()

    # Axis labels are applied separately through _set_axes_labels rather than
    # being forwarded to ``.plot``.
    axis_kwargs = {}
    for key in ("xlabel", "ylabel"):
        if key in kwargs:
            axis_kwargs[key] = kwargs.pop(key)

    if ax is None:
        import matplotlib.pyplot as plt  # noqa: PLC0415

        _, ax = plt.subplots()

    # Sets the animation coordinate (t) for iteration.
    # - If "time" is one of the dimensions, animate over it.
    # - Otherwise fall back to the last dimension. xarray passes dimensions in
    #   the form x, y, z, so animating over the final one (z here) keeps x and y
    #   on their respective axes.
    coord_names = list(data.dims)
    if t is None:
        t = "time" if "time" in coord_names else coord_names[-1]
    coord_names.remove(t)
    n_frames = data[t].size

    global_min = global_max = None
    if data.ndim == 2:
        global_min, global_max = _setup_2d_plot(
            data=data,
            ax=ax,
            coord_names=coord_names,
            kwargs=kwargs,
            axis_kwargs=axis_kwargs,
            min_percentile=min_percentile,
            max_percentile=max_percentile,
            t=t,
        )
    elif data.ndim == 3:
        _setup_3d_plot(
            data=data,
            ax=ax,
            coord_names=coord_names,
            kwargs=kwargs,
            kwargs_original=kwargs_original,
            axis_kwargs=axis_kwargs,
            min_percentile=min_percentile,
            max_percentile=max_percentile,
            t=t,
        )

    ax.set_title(get_frame_title(data, 0, display_sdf_name, title, t))

    window_boundaries = None
    if move_window:
        # Pass the animated coordinate explicitly; without it the helper would
        # look up a "time" coordinate that may not exist.
        window_boundaries = calculate_window_boundaries(
            data, kwargs.get("xlim"), kwargs["x"], t=t
        )

    def update(frame):
        if clear_axes:
            ax.clear()
        # Set the xlim for each frame in the case of a moving window
        if move_window:
            kwargs["xlim"] = window_boundaries[frame]
        plot = data.isel({t: frame}).plot(ax=ax, **kwargs)
        ax.set_title(get_frame_title(data, frame, display_sdf_name, title, t))
        _set_axes_labels(ax, axis_kwargs)
        # Re-apply the global y-range of line plots after each redraw
        if data.ndim == 2 and "ylim" not in kwargs and global_min is not None:
            ax.set_ylim(global_min, global_max)
        return plot

    return AnimationUnit(
        update=update,
        n_frames=n_frames,
    )
[docs]
def animate(
    data: xr.DataArray,
    fps: float = 10,
    min_percentile: float = 0,
    max_percentile: float = 100,
    title: str | None = None,
    display_sdf_name: bool = False,
    move_window: bool = False,
    t: str | None = None,
    ax: plt.Axes | None = None,
    **kwargs,
) -> FuncAnimation:
    """
    Generate an animation using an `xarray.DataArray`. The intended use
    of this function is via `sdf_xarray.dataarray_accessor.EpochAccessor.animate`.

    Parameters
    ----------
    data
        DataArray containing the target data
    fps
        Frames per second for the animation. Must be positive
    min_percentile
        Minimum percentile of the data
    max_percentile
        Maximum percentile of the data
    title
        Custom title to add to the plot
    display_sdf_name
        Display the sdf file name in the animation title
    move_window
        Update the ``xlim`` to be only values that are not NaNs at each time interval
    t
        Coordinate for t axis (the coordinate which will be animated over).
        If ``None``, use ``"time"`` when it is a dimension of ``data``, otherwise
        the last dimension of ``data``
    ax
        Matplotlib axes on which to plot. If ``None``, a new figure and axes are
        created
    kwargs
        Keyword arguments to be passed to matplotlib

    Returns
    -------
    matplotlib.animation.FuncAnimation
        Repeating animation over every frame along ``t``

    Raises
    ------
    ValueError
        If ``fps`` is not positive

    Examples
    --------
    >>> anim = ds["Derived_Number_Density_Electron"].epoch.animate()
    >>> anim.save("animation.gif")
    """
    import matplotlib.pyplot as plt  # noqa: PLC0415
    from matplotlib.animation import FuncAnimation  # noqa: PLC0415

    # Validate before creating any figure; the frame interval is 1000 / fps ms.
    if fps <= 0:
        raise ValueError("fps must be a positive number")

    # Create plot if no ax is provided
    if ax is None:
        _, ax = plt.subplots()

    animation = _generate_animation(
        data,
        clear_axes=True,
        min_percentile=min_percentile,
        max_percentile=max_percentile,
        title=title,
        display_sdf_name=display_sdf_name,
        move_window=move_window,
        t=t,
        ax=ax,
        kwargs=kwargs,
    )

    return FuncAnimation(
        ax.get_figure(),
        animation.update,
        frames=range(animation.n_frames),
        interval=1000 / fps,
        repeat=True,
    )
[docs]
def animate_multiple(
    *datasets: xr.DataArray,
    datasets_kwargs: list[dict[str, Any]] | None = None,
    fps: float = 10,
    min_percentile: float = 0,
    max_percentile: float = 100,
    title: str | None = None,
    display_sdf_name: bool = False,
    move_window: bool = False,
    t: str | None = None,
    ax: plt.Axes | None = None,
    **common_kwargs,
) -> FuncAnimation:
    """
    Generate an animation using multiple `xarray.DataArray`. The intended use
    of this function is via `sdf_xarray.dataset_accessor.EpochAccessor.animate_multiple`.

    Parameters
    ----------
    datasets
        `xarray.DataArray` objects containing the data to be animated
    datasets_kwargs
        A list of dictionaries, following the same order as ``datasets``, containing
        per-dataset matplotlib keyword arguments. The list does not need to be the same
        length as ``datasets``; missing entries are treated as empty dictionaries.
        The list passed in is not modified
    fps
        Frames per second for the animation. Must be positive
    min_percentile
        Minimum percentile of the data
    max_percentile
        Maximum percentile of the data
    title
        Custom title to add to the plot
    display_sdf_name
        Display the sdf file name in the animation title
    move_window
        Update the ``xlim`` to be only values that are not NaNs at each time interval
    t
        Coordinate for t axis (the coordinate which will be animated over). If ``None``,
        use ``"time"`` when it is a dimension of each dataset, otherwise its last
        dimension
    ax
        Matplotlib axes on which to plot. If ``None``, a new figure and axes are
        created
    common_kwargs
        Matplotlib keyword arguments applied to all datasets. These are overridden by
        per-dataset entries in ``datasets_kwargs``

    Returns
    -------
    matplotlib.animation.FuncAnimation
        Repeating animation truncated to the shortest dataset's frame count

    Raises
    ------
    ValueError
        If no dataset is given or ``fps`` is not positive

    Examples
    --------
    >>> anim = animate_multiple(
    ...     ds["Derived_Number_Density_Electron"],
    ...     ds["Derived_Number_Density_Ion"],
    ...     datasets_kwargs=[{"label": "Electron"}, {"label": "Ion"}],
    ...     ylim=(0e27, 4e27),
    ...     display_sdf_name=True,
    ...     ylabel="Derived Number Density [1/m$^3$]",
    ... )
    >>> anim.save("animation.gif")
    """
    import matplotlib.pyplot as plt  # noqa: PLC0415
    from matplotlib.animation import FuncAnimation  # noqa: PLC0415

    if not datasets:
        raise ValueError("At least one dataset must be provided")
    # Validate before creating any figure; the frame interval is 1000 / fps ms.
    if fps <= 0:
        raise ValueError("fps must be a positive number")

    # Create plot if no ax is provided
    if ax is None:
        _, ax = plt.subplots()

    n_datasets = len(datasets)
    # Build a new list instead of extending the caller's: extending would leak
    # empty dicts back into their object on every call.
    per_dataset_kwargs = [] if datasets_kwargs is None else list(datasets_kwargs)
    per_dataset_kwargs += [{} for _ in range(n_datasets - len(per_dataset_kwargs))]
    # Per-dataset kwargs override common matplotlib kwargs
    merged_kwargs = [
        {**common_kwargs, **kw} for kw in per_dataset_kwargs[:n_datasets]
    ]

    # Render the legend if a label (per-dataset or common) exists for any 2D dataset
    show_legend = any(
        "label" in kw and da.ndim == 2 for da, kw in zip(datasets, merged_kwargs)
    )

    animations: list[AnimationUnit] = []
    for da, kw in zip(datasets, merged_kwargs):
        animations.append(
            _generate_animation(
                da,
                ax=ax,
                min_percentile=min_percentile,
                max_percentile=max_percentile,
                title=title,
                display_sdf_name=display_sdf_name,
                move_window=move_window,
                t=t,
                kwargs=kw,
            )
        )

    lengths = [anim.n_frames for anim in animations]
    n_frames = min(lengths)
    if len(set(lengths)) > 1:
        warnings.warn(
            "Datasets have different frame counts; truncating to the shortest",
            stacklevel=2,
        )

    def update(frame):
        # Clear once per frame here; the individual updates draw on top of each other
        ax.clear()
        for anim in animations:
            anim.update(frame)
        if show_legend:
            ax.legend(loc="upper right")

    return FuncAnimation(
        ax.get_figure(),
        update,
        frames=range(n_frames),
        interval=1000 / fps,
        repeat=True,
    )
[docs]
def show(anim):
    """Render an animation as embeddable JavaScript/HTML for a Jupyter notebook.

    Parameters
    ----------
    anim
        `matplotlib.animation.FuncAnimation` to render

    Returns
    -------
    IPython.display.HTML
        HTML object wrapping the output of ``anim.to_jshtml()``.
    """
    from IPython.display import HTML  # noqa: PLC0415

    js_html = anim.to_jshtml()
    return HTML(js_html)