Source code for uclchem.plot.panels

"""Atomic plot units — each function draws a single panel onto a provided axis."""

from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any

import pandas as pd

if TYPE_CHECKING:
    import matplotlib.pyplot as plt

from uclchem.style import format_chemical_formula, format_reaction_label

from ._helpers import _color_for

#: Coefficient-of-variation (std / mean) threshold above which a rate constant
#: is considered time-varying; ``0.01`` corresponds to 1 %.
_K_VARY_THRESHOLD = 0.01


def plot_species(
    ax: plt.Axes,
    df: pd.DataFrame,
    species: list[str],
    legend: bool = True,
    plot_kwargs: dict[str, Any] | None = None,
) -> plt.Axes:
    """Plot the abundance of a list of species through time directly onto an axis.

    Neither ``df`` nor ``plot_kwargs`` is modified by this function.

    Parameters
    ----------
    ax : plt.Axes
        An axis object to plot on
    df : pd.DataFrame
        A dataframe created by ``uclchem.analysis.read_output_file``,
        ``uclchem.model.load_model`` or ``uclchem.model.Model.get_dataframes``.
    species : list[str]
        A list of species names to be plotted. If species name starts with "$"
        instead of "#" or "@", plots the sum of surface and bulk abundances
        (drawn dashed); the bulk ("@") column is only added when it exists.
    legend : bool
        Whether to add a legend to the plot. Default = True.
    plot_kwargs : dict[str, Any] | None
        keyword arguments passed to ``ax.plot``. ``linestyle`` and ``label``
        are always set per species; the line width defaults to 2 unless
        ``lw`` or ``linewidth`` is supplied. Default = None.

    Returns
    -------
    ax : plt.Axes
        Modified input axis is returned

    Raises
    ------
    KeyError
        if neither an ``"age"`` nor a ``"Time"`` column is present in ``df``,
        or if a requested species column is missing.
    """
    # Copy so that the per-species keys added below never leak into the
    # caller's dictionary.
    base_kwargs = dict(plot_kwargs) if plot_kwargs else {}
    # Only apply the default width when the caller did not choose one;
    # passing both would be a duplicate/aliased keyword error.
    if "lw" not in base_kwargs and "linewidth" not in base_kwargs:
        base_kwargs["lw"] = 2

    # Support legacy code that use either "age" or "Time" as the time variable.
    # The choice does not depend on the species, so resolve it once up front.
    if "age" in df.columns:
        timecolumn = "age"
    elif "Time" in df.columns:
        timecolumn = "Time"
    else:
        msg = "No time variable in dataframe"
        raise KeyError(msg)
    time = df[timecolumn]

    for species_name in species:
        if species_name.startswith("$"):
            linestyle = "dashed"
            surface_name = species_name.replace("$", "#")
            bulk_name = species_name.replace("$", "@")
            abundances = df[surface_name]
            if bulk_name in df.columns:
                # Build a new Series: an in-place ``+=`` here could write the
                # sum back into the caller's dataframe column.
                abundances = abundances + df[bulk_name]
        else:
            linestyle = "solid"
            abundances = df[species_name]

        ax.plot(
            time,
            abundances,
            **{**base_kwargs, "linestyle": linestyle, "label": species_name},
        )

    ax.set(yscale="log")
    if legend:
        ax.legend()
    return ax
def draw_panel_abundances(
    ax: plt.Axes,
    time: pd.Series,
    species: str,
    chem: pd.DataFrame,
    companion: list[str] | None = None,
    *,
    reactant_species: set[str] | None = None,
    color_registry: dict[str, str] | None = None,
) -> plt.Axes:
    """Draw species abundances onto *ax* (Panel A of a deepdive figure).

    The target *species* is drawn in black at full weight (and skipped when it
    is not a column of *chem*). Each entry in *companion* is drawn in a color
    obtained from ``_color_for``; species that appear in *reactant_species* get
    a thicker, more opaque line to signal their direct chemical involvement.
    Companions share a linestyle in groups of six, cycling through
    ``-``, ``--``, ``:`` and ``-.``.

    Parameters
    ----------
    ax : plt.Axes
        Axes to draw on.
    time : pd.Series
        Time series (years) for the x-axis.
    species : str
        UCLCHEM name of the primary species.
    chem : pd.DataFrame
        Chemistry (abundance) DataFrame, already filtered to the desired
        time range.
    companion : list[str] | None
        Additional species to overlay. Pass ``None`` (default) to show only
        *species*.
    reactant_species : set[str] | None
        Species that appear as reactants in the top reactions; these are
        rendered with higher visual weight. Default: empty set.
    color_registry : dict[str, str] | None
        Shared color map (species name → hex string). Pass the same dict to
        multiple panel calls to keep colors consistent. A fresh registry is
        created when ``None`` is passed. Default: ``None``.

    Returns
    -------
    ax : plt.Axes
        The modified axes.

    Raises
    ------
    KeyError
        If a *companion* species is not a column of *chem*.
    """
    if companion is None:
        companion = []
    if reactant_species is None:
        reactant_species = set()
    if color_registry is None:
        color_registry = {}

    # Six companions per linestyle. The style index wraps around so that any
    # number of companions is accepted (a fixed two-entry list would raise
    # IndexError from the 13th companion onwards).
    line_styles = ("-", "--", ":", "-.")
    group_size = 6

    sp_color = {sp: _color_for(sp, color_registry) for sp in companion}
    sp_ls = {
        sp: line_styles[(i // group_size) % len(line_styles)]
        for i, sp in enumerate(companion)
    }

    if species in chem.columns:
        ax.plot(
            time,
            chem[species],
            label=format_chemical_formula(species),
            linewidth=2.0,
            color="black",
            alpha=0.9,
            zorder=10,  # keep the primary species on top of the companions
        )

    for sp in companion:
        in_rxns = sp in reactant_species
        ax.plot(
            time,
            chem[sp],
            label=format_chemical_formula(sp),
            linewidth=1.5 if in_rxns else 0.8,
            color=sp_color[sp],
            alpha=0.85 if in_rxns else 0.5,
            linestyle=sp_ls[sp],
        )

    ax.set_ylabel("Abundance (w.r.t. H)")
    # Panel identifier in the top-left corner, in axes-fraction coordinates.
    ax.text(
        0.02,
        0.98,
        "A",
        transform=ax.transAxes,
        fontsize=10,
        verticalalignment="top",
        weight="bold",
    )
    ax.set_xscale("log")
    ax.set_yscale("log")
    ax.legend(fontsize=7, loc="lower left", ncol=3, framealpha=0.9)
    # Hide the x tick labels on this panel.
    ax.tick_params(labelbottom=False)
    return ax
def _draw_reaction_time_series(
    ax: plt.Axes,
    time: pd.Series,
    prod_df: pd.DataFrame,
    dest_df: pd.DataFrame,
    top_prod: list[str],
    top_dest: list[str],
    color_registry: dict[str, str],
    ylabel: str,
    panel_label: str,
) -> plt.Axes:
    """Render per-reaction time series for the rates / rate-constants panels.

    Lines are added in a fixed order, which is also the legend order:
    the two black totals (row-wise sum of every column), the two gray
    "Other" aggregates (row-wise sum of the columns that are *not* listed in
    *top_prod* / *top_dest*, drawn only when such columns exist), then one
    colored line per listed production reaction and one per listed
    destruction reaction. Production uses solid lines, destruction dashed.

    Parameters
    ----------
    ax : plt.Axes
        Axes to draw on.
    time : pd.Series
        Time series (years) for the x-axis.
    prod_df : pd.DataFrame
        Per-reaction production data (rates or rate constants).
    dest_df : pd.DataFrame
        Per-reaction destruction data.
    top_prod : list[str]
        Columns of *prod_df* to draw individually.
    top_dest : list[str]
        Columns of *dest_df* to draw individually.
    color_registry : dict[str, str]
        Shared color map (reaction string → hex string).
    ylabel : str
        Label for the y-axis.
    panel_label : str
        Panel identifier drawn in the top-left corner.

    Returns
    -------
    ax : plt.Axes
        The modified axes.
    """
    # One row per direction:
    # (frame, listed reactions, linestyle, total label, other linestyle, other label)
    directions = (
        (prod_df, top_prod, "-", "Total formation", ":", "Other formation"),
        (dest_df, top_dest, "--", "Total destruction", "-.", "Other destruction"),
    )

    # 1) Envelopes: sum over every reaction.
    for frame, _listed, style, total_label, _other_style, _other_label in directions:
        ax.plot(
            time,
            frame.sum(axis=1),
            lw=1.5,
            color="black",
            alpha=0.45,
            linestyle=style,
            zorder=10,
            label=total_label,
        )

    # 2) Aggregate of everything that is not drawn individually.
    for frame, listed, _style, _total_label, other_style, other_label in directions:
        unlisted = [column for column in frame.columns if column not in listed]
        if not unlisted:
            continue
        ax.plot(
            time,
            frame[unlisted].sum(axis=1),
            lw=1.2,
            color="gray",
            alpha=0.6,
            linestyle=other_style,
            zorder=9,
            label=other_label,
        )

    # 3) Individually listed reactions.
    for frame, listed, style, _total_label, _other_style, _other_label in directions:
        for reaction in listed:
            ax.plot(
                time,
                frame[reaction],
                lw=1.2,
                color=_color_for(reaction, color_registry),
                alpha=0.85,
                linestyle=style,
                label=format_reaction_label(reaction),
            )

    legend = ax.legend(
        loc="lower center",
        mode="expand",
        ncol=2,
        fontsize=7,
        framealpha=0.9,
        handlelength=3.5,
    )
    # Uniform handle width in the legend, independent of the plotted widths.
    for legend_line in legend.legend_handles:
        legend_line.set_linewidth(1.75)

    ax.set_xlabel("Time / years")
    ax.set_ylabel(ylabel)
    ax.text(
        0.02,
        0.98,
        panel_label,
        transform=ax.transAxes,
        fontsize=10,
        verticalalignment="top",
        weight="bold",
    )
    ax.set_xscale("log")
    ax.set_yscale("log")
    return ax
def draw_panel_rates(
    ax: plt.Axes,
    time: pd.Series,
    prod_rates: pd.DataFrame,
    dest_rates: pd.DataFrame,
    top_prod: list[str] | None = None,
    top_dest: list[str] | None = None,
    *,
    top_k: int | None = 5,
    color_registry: dict[str, str] | None = None,
) -> plt.Axes:
    """Draw production and destruction rates onto *ax* (Panel B of a deepdive figure).

    This resolves the default reaction selections and then delegates all
    drawing to ``_draw_reaction_time_series`` with the rate y-label and the
    panel letter ``"B"``.

    Parameters
    ----------
    ax : plt.Axes
        Axes to draw on.
    time : pd.Series
        Time series (years) for the x-axis.
    prod_rates : pd.DataFrame
        Per-reaction production rates (abundance wrt H s⁻¹), already filtered
        to the desired time range.
    dest_rates : pd.DataFrame
        Per-reaction destruction rates (absolute values), already filtered.
    top_prod : list[str] | None
        Production reaction columns to draw individually. When ``None``, the
        *top_k* columns with the largest column mean are used.
        Default: ``None``.
    top_dest : list[str] | None
        Destruction reaction columns to draw individually. When ``None``, the
        *top_k* columns with the largest column mean are used.
        Default: ``None``.
    top_k : int | None
        How many reactions to select when *top_prod* / *top_dest* are
        ``None``. ``None`` selects every column. Default: 5.
    color_registry : dict[str, str] | None
        Shared color map (reaction string → hex string). Pass the same dict to
        multiple panel calls to keep colors consistent. A fresh registry is
        created when ``None`` is passed. Default: ``None``.

    Returns
    -------
    ax : plt.Axes
        The modified axes.
    """

    def _resolve(frame: pd.DataFrame, chosen: list[str] | None) -> list[str]:
        # An explicit selection always wins over the automatic one.
        if chosen is not None:
            return chosen
        columns = frame.columns
        if top_k is None:
            return list(columns)
        ranked = frame[columns].mean().nlargest(top_k)
        return list(ranked.index)

    registry = {} if color_registry is None else color_registry
    production = _resolve(prod_rates, top_prod)
    destruction = _resolve(dest_rates, top_dest)

    return _draw_reaction_time_series(
        ax,
        time,
        prod_rates,
        dest_rates,
        production,
        destruction,
        registry,
        ylabel=r"Reaction rate (abundance wrt H s$^{-1}$)",
        panel_label="B",
    )
def draw_panel_rate_constants(
    ax: plt.Axes,
    time: pd.Series,
    prod_k: pd.DataFrame,
    dest_k: pd.DataFrame,
    top_prod: list[str] | None = None,
    top_dest: list[str] | None = None,
    *,
    top_k: int | None = 5,
    bar: bool = False,
    color_registry: dict[str, str] | None = None,
) -> plt.Axes:
    """Draw rate constants onto *ax* (Panel C of a deepdive figure).

    By default draws rate constants as time series so trends are visible.
    Pass ``bar=True`` for a mean bar chart; a warning is emitted for any
    reaction whose rate constant varies significantly over time (CV > 1 %).

    In bar mode, reactions that are not columns of the corresponding frame
    are ignored, and a group ("Formation" / "Destruction") that ends up with
    no reactions gets neither bars nor a tick label.

    Parameters
    ----------
    ax : plt.Axes
        Axes to draw on.
    time : pd.Series
        Time series (years) for the x-axis. Only used when ``bar=False``.
    prod_k : pd.DataFrame
        Rate-constant DataFrame for production reactions, already filtered
        to the desired time range.
    dest_k : pd.DataFrame
        Rate-constant DataFrame for destruction reactions.
    top_prod : list[str] | None
        Reaction column names to include for production. When ``None``, the
        top *top_k* reactions by mean rate constant are selected.
        Default: ``None``.
    top_dest : list[str] | None
        Reaction column names to include for destruction. When ``None``, the
        top *top_k* reactions by mean rate constant are selected.
        Default: ``None``.
    top_k : int | None
        Number of top reactions to show when *top_prod* / *top_dest* are
        ``None``. Pass ``None`` to show all. Default: 5.
    bar : bool
        If ``True``, draw a mean bar chart instead of time series.
        Default: ``False``.
    color_registry : dict[str, str] | None
        Shared color map (reaction string → hex string). A fresh registry is
        created when ``None`` is passed. Default: ``None``.

    Returns
    -------
    ax : plt.Axes
        The modified axes.
    """

    def _select(frame: pd.DataFrame, chosen: list[str] | None) -> list[str]:
        # Explicit selection wins; otherwise take the top_k largest column means.
        if chosen is not None:
            return chosen
        cols = frame.columns
        if top_k is None:
            return list(cols)
        return list(frame[cols].mean().nlargest(top_k).index)

    def _varies(series: pd.Series) -> bool:
        # Coefficient of variation (std / mean) above the module threshold.
        # Non-positive means are never flagged, which also avoids dividing by 0.
        mean = series.mean()
        return bool(mean > 0 and series.std() / mean > _K_VARY_THRESHOLD)

    if color_registry is None:
        color_registry = {}
    top_prod = _select(prod_k, top_prod)
    top_dest = _select(dest_k, top_dest)

    if not bar:
        return _draw_reaction_time_series(
            ax,
            time,
            prod_k,
            dest_k,
            top_prod,
            top_dest,
            color_registry,
            ylabel=r"Rate constant $k$ (s$^{-1}$)",
            panel_label="C",
        )

    # Bar mode: warn for time-varying rate constants, then plot means.
    import numpy as np  # noqa: PLC0415

    top_prod_k = [r for r in top_prod if r in prod_k.columns]
    top_dest_k = [r for r in top_dest if r in dest_k.columns]

    varying = [r for r in top_prod_k if _varies(prod_k[r])] + [
        r for r in top_dest_k if _varies(dest_k[r])
    ]
    if varying:
        warnings.warn(
            f"{len(varying)} rate constant(s) vary over time (CV > {_K_VARY_THRESHOLD:.0%}); "
            "bar shows time-mean. Use bar=False for a time-resolved view.\n"
            f"Varying reactions: {', '.join(varying)}",
            stacklevel=2,
        )

    n_prod_k = len(top_prod_k)
    n_dest_k = len(top_dest_k)

    # Formation bars occupy x = 0 .. n_prod_k-1, destruction bars follow
    # directly after. Each group is labelled with a single tick at its
    # centre; empty groups are skipped so no label points at nothing.
    tick_positions: list[float] = []
    tick_labels: list[str] = []

    if n_prod_k:
        ax.bar(
            np.arange(n_prod_k),
            prod_k[top_prod_k].mean().values,
            width=0.75,
            color=[_color_for(rxn, color_registry) for rxn in top_prod_k],
            alpha=0.85,
        )
        tick_positions.append(n_prod_k / 2.0 - 0.5)
        tick_labels.append("Formation")

    if n_dest_k:
        ax.bar(
            np.arange(n_prod_k, n_prod_k + n_dest_k),
            dest_k[top_dest_k].mean().values,
            width=0.75,
            color=[_color_for(rxn, color_registry) for rxn in top_dest_k],
            alpha=0.85,
        )
        tick_positions.append(n_prod_k + n_dest_k / 2.0 - 0.5)
        tick_labels.append("Destruction")

    ax.set_xticks(tick_positions)
    ax.set_xticklabels(tick_labels)
    ax.set_ylabel(r"Mean $k$ (s$^{-1}$)")
    ax.text(
        0.02,
        0.98,
        "C",
        transform=ax.transAxes,
        fontsize=10,
        verticalalignment="top",
        weight="bold",
    )
    ax.set_yscale("log")
    return ax