from __future__ import annotations
from typing import Iterable, Literal, Optional, Sequence, Mapping, Tuple
import re
import ast
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.transforms as mtransforms
import anndata as ad
from mpl_toolkits.axes_grid1 import make_axes_locatable # kept if used elsewhere
from mpl_toolkits.axes_grid1.inset_locator import inset_axes
from matplotlib.colors import ListedColormap
from matplotlib.patches import Patch
from scipy.cluster.hierarchy import linkage, leaves_list
from scipy.spatial.distance import pdist
from zmap.reference import load_consensus_markers
# ---------------------------------------------------------------------
# Level configuration (ontology levels and their parents)
# ---------------------------------------------------------------------
_ZMAP_LEVEL_CONFIG = {
"ZMAP_GermLayer": {
"parent_col": "ZMAP_GermLayer",
"consensus_level": "GermLayer",
},
"ZMAP_Tissue": {
"parent_col": "ZMAP_GermLayer",
"consensus_level": "Tissue",
},
"ZMAP_CellType": {
"parent_col": "ZMAP_Tissue",
"consensus_level": "CellType",
},
"ZMAP_CellTypeFine": {
"parent_col": "ZMAP_Tissue",
"consensus_level": "CellTypeFine",
},
"ZMAP_Cluster": {
"parent_col": "ZMAP_CellTypeFine",
"consensus_level": "Cluster",
},
}
_ZMAP_LEVEL_LOOKUP = {
"ZMAP_Tissue": "Tissue",
"ZMAP_CellType": "CellType",
"ZMAP_CellTypeFine": "CellTypeFine",
"ZMAP_Cluster": "Cluster",
}
# ---------------------------------------------------------------------
# Helper Functions
# ---------------------------------------------------------------------
def _add_top_sibling_axes(ax, *, height_in: float, pad_in: float = 0.0):
"""
Create a NEW axes directly above `ax` with fixed height in inches.
Width matches `ax`. Returns the new axes.
"""
fig = ax.figure
fig_w, fig_h = fig.get_size_inches()
base = ax.get_position() # in figure fraction
h_frac = float(height_in) / fig_h
pad_frac = float(pad_in) / fig_h
left, right = base.x0, base.x1
bottom = base.y1 + pad_frac
height = h_frac
return fig.add_axes([left, bottom, right - left, height])
def _pt(val: float) -> float:
"""Convert points to inches for inset_axes / fig.add_axes math."""
return float(val) / 72.0
def _inset_fixed(ax, *, width, height, loc, bbox_to_anchor=(0, 0, 1, 1), borderpad=0.0):
    """
    Thin wrapper around ``inset_axes`` that anchors an inset to ``ax``.

    ``bbox_to_anchor`` is always interpreted in ``ax``'s axes coordinates
    (``ax.transAxes``). ``width``/``height`` are passed through unchanged:
    floats are inches, strings such as ``'100%'`` are relative sizes.

    Returns the new inset axes.
    """
    placement = dict(
        loc=loc,
        bbox_to_anchor=bbox_to_anchor,
        bbox_transform=ax.transAxes,
        borderpad=borderpad,
    )
    return inset_axes(ax, width=width, height=height, **placement)
def _add_left_time_strip_inset(
    fig,
    ax_main,
    *,
    adata,
    groupby: str,
    row_order: list[str],
    time_key: str = "time_id",
    vmin: float = 0.0,
    vmax: float = 120.0,
    cmap: str = "grey",
    width_pt: float = 12.0,  # strip width (points)
    gap_pt: float = 4.0,  # gap between y-axis and strip (points)
    show_yticklabels: bool = False,
    label_colors: dict[str, str] | None = None,
    draw_border: bool = False,
    ylabel: str = "Time (hpf)",
    ylabel_fontsize: float = 7.0,
    ylabel_color: str = "0.4",
):
    """
    Add a vertical time strip (median time per group) to the LEFT of ax_main.

    Both the strip width and the gap between the main y-axis and the strip are
    specified in physical units (points), so the geometry is stable across
    different figure sizes. Row ``i`` of the strip shows the median of
    ``adata.obs[time_key]`` over the cells whose ``groupby`` label equals
    ``row_order[i]``, drawn with ``imshow(origin="lower")`` at y = i.

    Group labels are compared as strings, matching how the dotplot builds its
    row labels (``astype(str)``).

    Parameters
    ----------
    fig : matplotlib.figure.Figure
        Figure containing ``ax_main``; its width converts inches into an
        axes-fraction offset.
    ax_main : matplotlib.axes.Axes
        Axes the strip is attached to (the strip spans its full height).
    adata : anndata.AnnData
        Source of ``obs[groupby]`` and ``obs[time_key]``.
    groupby : str
        obs column holding group labels.
    row_order : list of str
        Groups to show, one strip row each.
    time_key : str
        obs column holding the time values.
    vmin, vmax, cmap
        Color normalization and colormap for ``imshow``.
    width_pt, gap_pt : float
        Strip width and gap to the main y-axis, in points.
    show_yticklabels, label_colors
        Optionally label strip rows with ``row_order``, coloring labels found
        in ``label_colors``.
    draw_border : bool
        Show black 1.0-wide spines around the strip; otherwise hide them.
    ylabel, ylabel_fontsize, ylabel_color
        Rotated text placed just below the strip (skipped when empty).

    Returns
    -------
    tuple (ax_strip, im) or None
        The inset axes and the image drawn into it, or ``None`` when no cell
        has both a ``groupby`` label and a ``time_key`` value.

    Raises
    ------
    KeyError
        If ``groupby`` or ``time_key`` is not a column of ``adata.obs``.
    ValueError
        If a group in ``row_order`` has no valid ``time_key`` values.
    """
    if groupby not in adata.obs.columns:
        raise KeyError(f"groupby '{groupby}' not found in adata.obs.")
    if time_key not in adata.obs.columns:
        raise KeyError(f"time_key '{time_key}' not found in adata.obs.")
    df = adata.obs[[groupby, time_key]].dropna(subset=[groupby, time_key])
    if df.empty:
        # NOTE(review): callers that unpack (ax_strip, im) must handle None.
        return None
    # Group on string labels, passed as an ndarray so there is no index
    # alignment. row_order holds strings, and grouping a categorical column
    # directly would also yield NaN medians for unobserved categories. Those
    # would slip past the `missing` check and leave blank strip rows.
    group_labels = df[groupby].astype(str).to_numpy()
    med = df[time_key].groupby(group_labels).median()
    row_order = [str(g) for g in row_order]
    missing = [g for g in row_order if g not in med.index]
    if missing:
        raise ValueError(
            f"The following groups in row_order have no valid '{time_key}' values: {missing}"
        )
    med = med.reindex(row_order)
    arr = med.to_numpy(dtype=float)[:, None]  # one image column, one row per group
    # --- Compute physical sizes (inches) and convert to axes-fraction offset ---
    strip_width_in = _pt(width_pt)  # strip width in inches
    gap_in = _pt(gap_pt)  # gap between y-axis and strip in inches
    fig_w = float(fig.get_size_inches()[0])
    pos = ax_main.get_position()  # Bbox in figure fraction
    ax_width_in = fig_w * pos.width if fig_w > 0 else 0.0
    if ax_width_in > 0:
        # axes x=0 is at the main y-axis; a negative fraction shifts left
        left_edge_frac = -(gap_in + strip_width_in) / ax_width_in
    else:
        left_edge_frac = 0.0
    # Fixed-size inset: strip width in inches, height spans the dotplot axis box
    ax_strip = _inset_fixed(
        ax_main,
        width=strip_width_in,  # inches
        height="100%",
        loc="lower left",
        bbox_to_anchor=(left_edge_frac, 0.0, 1, 1),
        borderpad=0.0,
    )
    im = ax_strip.imshow(
        arr,
        aspect="auto",
        origin="lower",
        cmap=cmap,
        vmin=vmin,
        vmax=vmax,
        interpolation="nearest",
    )
    nrows = len(row_order)
    ax_strip.set_ylim(-0.5, nrows - 0.5)
    ax_strip.grid(False)
    ax_strip.set_axisbelow(False)
    ax_strip.set_xticks([])
    if show_yticklabels:
        ax_strip.set_yticks(np.arange(nrows))
        ax_strip.set_yticklabels(row_order, fontsize=8)
        if label_colors is not None:
            for lab in ax_strip.get_yticklabels():
                t = lab.get_text()
                if t in label_colors:
                    lab.set_color(label_colors[t])
    else:
        ax_strip.set_yticks([])
        ax_strip.set_yticklabels([])
    if draw_border:
        for s in ax_strip.spines.values():
            s.set_visible(True)
            s.set_linewidth(1.0)
            s.set_color("black")
    else:
        for s in ax_strip.spines.values():
            s.set_visible(False)
    if ylabel:
        # Placed just below the strip (y slightly < 0 in axes coords), rotated.
        ax_strip.text(
            0.5,
            -0.005,
            ylabel,
            ha="center",
            va="top",
            rotation=90,
            fontsize=ylabel_fontsize,
            color=ylabel_color,
            transform=ax_strip.transAxes,
        )
    return ax_strip, im
def find_level_for_node(
    adata: ad.AnnData,
    node: str,
    candidate_cols: Optional[Sequence[str]] = None,
) -> str:
    """
    Return the name of the obs column in which ``node`` appears as a value.

    Columns are searched in the order of ``candidate_cols`` (by default the
    key order of ``_ZMAP_LEVEL_CONFIG``, i.e. highest -> lowest level), and
    the first column containing ``node`` is returned. Candidate columns that
    are absent from ``adata.obs`` are skipped. Values are compared as
    strings after dropping missing values, so a node literally named
    ``"nan"`` does not match NaN cells.

    Raises
    ------
    ValueError
        If ``node`` is not found in any candidate column.
    """
    if candidate_cols is None:
        candidate_cols = list(_ZMAP_LEVEL_CONFIG.keys())  # highest -> lowest
    node_str = str(node)
    obs = adata.obs
    for col in candidate_cols:
        if col not in obs.columns:
            continue
        # dropna() first, as get_parent_and_siblings does, so NaN is not
        # stringified to "nan" and matched spuriously.
        if (obs[col].dropna().astype(str) == node_str).any():
            return col
    raise ValueError(
        f"{node_str!r} not found in any candidate level columns: {candidate_cols!r}."
    )
def _string_level_pairs(obs: pd.DataFrame, level_col: str, parent_col: str) -> pd.DataFrame:
    """Return obs[[level_col, parent_col]] with NaN rows dropped and both columns cast to str."""
    # astype(str) on the whole frame returns a new object, so callers never
    # assign into a possible view of `obs`.
    return obs[[level_col, parent_col]].dropna().astype(str)


def get_parent_and_siblings(
    adata: ad.AnnData,
    node: str,
    level_col: str,
    parent_col: Optional[str],
) -> tuple[Optional[str], list[str]]:
    """
    Return (parent_label, siblings) for a given node in `level_col`.

    - Top-level case: level_col == "ZMAP_Tissue", or parent_col is the same
      column as level_col (as configured for "ZMAP_GermLayer"):
        siblings are *all* non-missing values of level_col.
        parent_label is inferred from parent_col when it is a different,
        existing column that maps the node to exactly one value; else None.
    - Otherwise:
        siblings are all values in `level_col` that share the node's parent
        in `parent_col` (rows with a missing value in either column ignored).

    Values are compared as strings; `siblings` is sorted and includes the
    focal node.

    Raises
    ------
    KeyError
        If level_col, or (outside the top-level case) parent_col, is not a
        column of adata.obs.
    ValueError
        If the node is absent, parent_col is None outside the top-level case,
        or the node maps to zero or several parents.
    """
    obs = adata.obs
    if level_col not in obs.columns:
        raise KeyError(f"level_col={level_col!r} not found in adata.obs.")
    node_str = str(node)
    # A level listed as its own parent has no parent level. Selecting
    # obs[[col, col]] would create duplicate columns and break the boolean
    # .loc lookups below, so handle it like the tissue level.
    is_self_parent = parent_col is not None and parent_col == level_col
    if level_col == "ZMAP_Tissue" or is_self_parent:
        vals = obs[level_col].dropna().astype(str).unique().tolist()
        if node_str not in vals:
            raise ValueError(
                f"{node_str!r} not found in level_col={level_col!r} when treating as top-level."
            )
        parent_label: Optional[str] = None
        if parent_col is not None and not is_self_parent and parent_col in obs.columns:
            pairs = _string_level_pairs(obs, level_col, parent_col)
            parents = pairs.loc[pairs[level_col] == node_str, parent_col].unique()
            if len(parents) == 1:
                parent_label = parents[0]
        return parent_label, sorted(vals)
    # Normal case: siblings share the node's (unique) parent in parent_col
    if parent_col is None:
        raise ValueError(
            f"parent_col is None for level_col={level_col!r}; cannot infer siblings."
        )
    if parent_col not in obs.columns:
        raise KeyError(f"parent_col={parent_col!r} not found in adata.obs.")
    pairs = _string_level_pairs(obs, level_col, parent_col)
    parents = pairs.loc[pairs[level_col] == node_str, parent_col].unique()
    if len(parents) == 0:
        raise ValueError(
            f"No rows with {level_col} == {node_str!r} when looking for parent in {parent_col!r}."
        )
    if len(parents) > 1:
        raise ValueError(
            f"{node_str!r} has multiple parents in {parent_col!r}: {parents!r}."
        )
    parent_label = parents[0]
    siblings = sorted(
        pairs.loc[pairs[parent_col] == parent_label, level_col].unique().tolist()
    )
    return parent_label, siblings
def get_node_markers(
    node: str,
    level_col: str,
    *,
    marker_types: Iterable[Literal["overall", "exclusivity", "contrast", "consensus"]] = ("overall",),
    n_per_type: int = 10,
    min_support_ratio: float | None = None,
    min_log2fc: float | None = None,
    min_enrich: float | None = None,
    omit_unannotated: bool = False,
) -> pd.DataFrame:
    """
    Load consensus markers for a single node across one or more marker_types.

    ``level_col`` is mapped to its consensus level through
    ``_ZMAP_LEVEL_CONFIG``. The function makes one
    ``load_consensus_markers(..., format="table")`` call per marker type and
    concatenates the non-empty results.

    Parameters
    ----------
    node : str
        Group label to load markers for (sent as ``groups=[str(node)]``).
    level_col : str
        Key of ``_ZMAP_LEVEL_CONFIG`` (e.g. ``"ZMAP_CellType"``).
    marker_types : iterable of str, or a single str
        Marker types to request. A bare string is treated as one type. Any
        iterable, including a generator, is materialized once up front.
    n_per_type, min_support_ratio, min_log2fc, min_enrich, omit_unannotated
        Forwarded to ``load_consensus_markers`` (``n_per_type`` as
        ``n_per_group``).

    Returns
    -------
    pd.DataFrame
        Concatenated table with at least columns ['celltype', 'gene',
        'marker_type']; 'celltype' and 'gene' are cast to str.

    Raises
    ------
    KeyError
        If ``level_col`` is not a key of ``_ZMAP_LEVEL_CONFIG``.
    ValueError
        If no marker type returned any rows.
    """
    if level_col not in _ZMAP_LEVEL_CONFIG:
        raise KeyError(
            f"level_col={level_col!r} not recognized in _ZMAP_LEVEL_CONFIG."
        )
    consensus_level = _ZMAP_LEVEL_CONFIG[level_col]["consensus_level"]
    # A bare string would otherwise be iterated character by character. A
    # generator would be exhausted by the loop, and the error message would
    # then report an empty list.
    if isinstance(marker_types, str):
        marker_types = (marker_types,)
    mt = list(marker_types)
    all_frames: list[pd.DataFrame] = []
    for mtype in mt:
        df = load_consensus_markers(
            level=consensus_level,
            groups=[str(node)],
            marker_type=mtype,
            n_per_group=n_per_type,
            min_support_ratio=min_support_ratio,
            min_log2fc=min_log2fc,
            min_enrich=min_enrich,
            omit_unannotated=omit_unannotated,
            format="table",
        )
        if df is None or df.empty:
            continue
        df = df.copy()
        df["marker_type"] = mtype
        df["celltype"] = df["celltype"].astype(str)
        df["gene"] = df["gene"].astype(str)
        all_frames.append(df)
    if not all_frames:
        raise ValueError(
            f"No markers returned for node={node!r}, level={consensus_level!r}, "
            f"marker_types={mt}."
        )
    return pd.concat(all_frames, ignore_index=True)
def make_sibling_design_df(
    node: str,
    siblings: Sequence[str],
    marker_df: pd.DataFrame,
    *,
    genes: Sequence[str] | None = None,
    support_col: str = "support_ratio",
) -> pd.DataFrame:
    """
    Construct a design_df with columns ['celltype', 'gene', support_col]
    for use with plot_dotplot_basegrid, where:
      - rows are all (sibling, gene) pairs, siblings-major, in the given
        sibling order; sibling labels are cast to str like every other label.
      - if `genes` is provided, it is treated as the canonical column order
        and ONLY genes in this list (that are present in marker_df for `node`)
        are included.
      - if `genes` is None, genes come from the focal node's markers
        (marker_df rows with celltype == node), in first-appearance order.
      - support_col is populated for the focal node (if present in marker_df),
        and NaN for other siblings.

    Raises
    ------
    ValueError
        If marker_df has no rows for `node`, or none of `genes` is present
        for it.
    """
    marker_df = marker_df.copy()
    marker_df["gene"] = marker_df["gene"].astype(str)
    marker_df["celltype"] = marker_df["celltype"].astype(str)
    node_str = str(node)
    # Cast siblings like the other labels so the focal-node comparison below
    # matches even when siblings arrive as non-string labels.
    sibling_list = [str(s) for s in siblings]
    node_markers = marker_df.loc[marker_df["celltype"] == node_str]
    if node_markers.empty:
        raise ValueError(f"No rows in marker_df for focal node={node_str!r}.")
    if genes is None:
        # Fallback: use the intrinsic order of genes for this node
        genes_list = node_markers["gene"].unique().tolist()
    else:
        # Enforce provided canonical order; keep only genes present for this node
        genes_set = set(node_markers["gene"].unique().tolist())
        genes_list = [str(g) for g in genes if str(g) in genes_set]
        if not genes_list:
            raise ValueError(
                f"None of the requested genes are present for node={node_str!r} "
                f"in marker_df."
            )
    # Cross product: siblings x genes (in canonical genes_list order)
    design_df = pd.DataFrame(
        [(ct, g) for ct in sibling_list for g in genes_list],
        columns=["celltype", "gene"],
    )
    # Propagate support values for the focal node only
    if support_col in node_markers.columns:
        # node_markers holds a single celltype, so keep the first row per gene.
        sup_map = (
            node_markers
            .drop_duplicates(["celltype", "gene"])
            .set_index("gene")[support_col]
        )
        design_df[support_col] = np.where(
            design_df["celltype"] == node_str,
            design_df["gene"].map(sup_map),
            np.nan,
        )
    else:
        design_df[support_col] = np.nan
    return design_df
def _get_expression_slice(
adata: ad.AnnData,
genes: Sequence[str],
*,
layer: str | None = None,
use_raw: bool | None = None,
) -> "tuple[np.ndarray, list[str]]":
"""
Densify a gene-subset of the expression matrix exactly once.
Returns (X_dense, valid_genes) where X_dense has shape (n_cells, len(valid_genes)).
Use the returned arrays as _x_dense / _x_genes in _compute_group_gene_means to
avoid repeated toarray() calls when the same gene set is queried across levels.
"""
genes = [str(g) for g in genes]
if use_raw and getattr(adata, "raw", None) is not None:
X = adata.raw.X
var_names = pd.Index(adata.raw.var_names).astype(str)
elif layer is not None:
X = adata.layers[layer]
var_names = pd.Index(adata.var_names).astype(str)
else:
X = adata.X
var_names = pd.Index(adata.var_names).astype(str)
gene_idx_arr = var_names.get_indexer(genes)
valid_gene = gene_idx_arr >= 0
valid_genes = [g for g, v in zip(genes, valid_gene) if v]
gene_cols = gene_idx_arr[valid_gene]
if not valid_genes:
return np.empty((X.shape[0], 0), dtype=np.float32), []
X_sub = X[:, gene_cols]
X_dense = X_sub.toarray() if hasattr(X_sub, "toarray") else np.asarray(X_sub)
return X_dense.astype(np.float32), valid_genes # float32 saves memory; promoted to float64 on use
def _compute_group_gene_means(
adata: ad.AnnData,
groups: Sequence[str],
genes: Sequence[str],
*,
level_col: str,
layer: str | None = None,
use_raw: bool | None = None,
# Optional: pass a pre-densified (n_cells x n_genes) matrix and matching gene list
# to skip the sparse slice + toarray() entirely.
_x_dense: "np.ndarray | None" = None,
_x_genes: "list[str] | None" = None,
) -> "tuple[np.ndarray, list[str], list[str]]":
"""
Compute a (n_groups x n_genes) mean-expression matrix.
Returns (mat, valid_groups, valid_genes).
Groups absent from adata get an all-NaN row; genes absent are dropped.
Pass _x_dense + _x_genes to reuse an already-densified gene slice and
avoid repeated X[:, gene_cols].toarray() calls across multiple groupby levels.
"""
groups = [str(s) for s in groups]
genes = [str(g) for g in genes]
if _x_dense is not None and _x_genes is not None:
# Reuse pre-densified slice; filter columns to requested genes
x_gene_idx = pd.Index(_x_genes)
col_pos = x_gene_idx.get_indexer(genes)
valid_gene = col_pos >= 0
valid_genes = [g for g, v in zip(genes, valid_gene) if v]
X_dense = _x_dense[:, col_pos[valid_gene]]
else:
# Choose expression source
if use_raw and getattr(adata, "raw", None) is not None:
X = adata.raw.X
var_names = pd.Index(adata.raw.var_names).astype(str)
elif layer is not None:
X = adata.layers[layer]
var_names = pd.Index(adata.var_names).astype(str)
else:
X = adata.X
var_names = pd.Index(adata.var_names).astype(str)
gene_idx_arr = var_names.get_indexer(genes)
valid_gene = gene_idx_arr >= 0
valid_genes = [g for g, v in zip(genes, valid_gene) if v]
gene_cols = gene_idx_arr[valid_gene]
if not valid_genes:
return np.empty((len(groups), 0), dtype=np.float64), groups, []
X_sub = X[:, gene_cols]
X_dense = X_sub.toarray() if hasattr(X_sub, "toarray") else np.asarray(X_sub)
if not valid_genes:
return np.empty((len(groups), 0), dtype=np.float64), groups, []
# Map each cell to its group integer index (-1 = not in requested groups)
labels = adata.obs[level_col].astype(str).values
ct_enum = {ct: i for i, ct in enumerate(groups)}
label_idx = np.fromiter(
(ct_enum.get(l, -1) for l in labels), dtype=np.intp, count=len(labels)
)
valid_mask = label_idx >= 0
valid_labels = label_idx[valid_mask]
X_valid = X_dense[valid_mask].astype(np.float64)
n_groups = len(groups)
# Pandas groupby-mean is BLAS-backed and far faster than np.add.at
raw_means = (
pd.DataFrame(X_valid, index=valid_labels)
.groupby(level=0)
.mean()
.reindex(np.arange(n_groups)) # missing groups become NaN rows
.to_numpy(dtype=np.float64)
)
return raw_means, groups, valid_genes
def reorder_groups_by_mean_expression(
    adata: ad.AnnData,
    groups: Sequence[str],
    *,
    level_col: str,
    genes: Sequence[str],
    layer: str | None = None,
    use_raw: bool | None = None,
    _x_dense=None,
    _x_genes=None,
) -> list[str]:
    """
    Reorder groups by their average expression across ``genes`` (descending).

    A group's score is the NaN-ignoring mean of its per-gene mean expression,
    as computed by ``_compute_group_gene_means``. Groups without a finite
    score (e.g. no cells in ``adata.obs[level_col]``) sort last. Ties keep
    their input order, because ``sorted`` is stable.
    Pass _x_dense/_x_genes to reuse a pre-densified gene slice.

    Returns
    -------
    list of str
        ``groups`` cast to str, reordered. The list is returned in input
        order when there is at most one group, no genes, no gene found in the
        expression source, or no group with a finite score.

    Raises
    ------
    KeyError
        If ``level_col`` is not a column of ``adata.obs``.
    """
    groups = [str(s) for s in groups]
    if len(groups) <= 1 or len(genes) == 0:
        return groups
    if level_col not in adata.obs.columns:
        raise KeyError(f"level_col={level_col!r} not found in adata.obs.")
    mat, valid_groups, _valid_genes = _compute_group_gene_means(
        adata, groups, genes, level_col=level_col, layer=layer, use_raw=use_raw,
        _x_dense=_x_dense, _x_genes=_x_genes,
    )
    if mat.size == 0:
        return groups
    # Absent groups come back as all-NaN rows by design. np.nanmean would
    # warn "Mean of empty slice" for each of them, so silence that expected
    # RuntimeWarning.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        row_means_arr = np.nanmean(mat, axis=1)
    row_means = {g: float(v) for g, v in zip(valid_groups, row_means_arr) if np.isfinite(v)}
    if not row_means:
        return groups
    return sorted(groups, key=lambda s: row_means.get(s, -np.inf), reverse=True)
def reorder_genes_by_mean_expression(
    adata: ad.AnnData,
    genes: Sequence[str],
    *,
    groups: Sequence[str],
    level_col: str,
    layer: str | None = None,
    use_raw: bool | None = None,
    _x_dense=None,
    _x_genes=None,
) -> list[str]:
    """
    Reorder genes by their average expression across ``groups`` (descending).

    A gene's score is the NaN-ignoring mean, over the requested groups, of
    the per-group mean expression from ``_compute_group_gene_means``. Genes
    with no finite score sort last, and ties keep their input order (stable
    sort).
    Pass _x_dense/_x_genes to reuse a pre-densified gene slice.

    Returns
    -------
    list of str
        The reordered genes, cast to str. Genes not found in the expression
        source are dropped. The exception is the early exits (at most one
        gene, or an empty mean matrix, e.g. no groups), where the input genes
        are returned unchanged.

    Raises
    ------
    KeyError
        If ``level_col`` is not a column of ``adata.obs``.
    """
    genes = [str(g) for g in genes]
    if len(genes) <= 1:
        return genes
    if level_col not in adata.obs.columns:
        raise KeyError(f"level_col={level_col!r} not found in adata.obs.")
    mat, _valid_groups, valid_genes = _compute_group_gene_means(
        adata, groups, genes, level_col=level_col, layer=layer, use_raw=use_raw,
        _x_dense=_x_dense, _x_genes=_x_genes,
    )
    if mat.size == 0:
        return genes
    # A column is all-NaN when none of `groups` has cells. np.nanmean warns
    # on it, which is expected here, so silence that RuntimeWarning.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore", category=RuntimeWarning)
        col_means = np.nanmean(mat, axis=0)
    # Stable sort, so tied genes keep their input order, matching
    # reorder_groups_by_mean_expression (stable `sorted`). NaN sorts last.
    order = np.argsort(-col_means, kind="stable")
    return [valid_genes[i] for i in order]
# ---------------------------------------------------------------------
# Primary DotPlot Helper Function
# ---------------------------------------------------------------------
# Public entry point; shared backend for the group_* dotplot wrappers.
def plot_dotplot_basegrid(
adata,
design_df: pd.DataFrame, # must include 'celltype' and 'gene'; optional 'support'
*,
groupby: str = "ZMAP_CellType",
layer: str | None = None,
use_raw: bool | None = None,
detect_threshold: float = 0.0, # > threshold => “expressing”
# ring settings
support_col: str = "support_ratio",
add_support_ring: bool = True,
ring_min_lw: float = 0.02,
ring_max_lw: float = 1.5,
ring_color: str = "darkorange",
ring_alpha: float = 0.8,
# color/size encodings
cmap: str = "Blues",
vmin: float | None = None,
vmax: float | None = None,
standard_scale: str | None = None, # None | "var" | "obs"
size_mode: str = "sqrt", # "sqrt" or "linear"
s_min: float = 1,
s_max: float = 80,
# row label suffix
rowlabel_append_child_counts: bool = True,
rowlabel_child_col: str = "ZMAP_Cluster",
rowlabel_fmt: str = "{name} ({n})",
# figure
figsize=(12, 12),
xlabel_rotation: int = 90,
title: str | None = None,
add_colorbar: bool = True,
cbar_title: str = "log(tpm)\ncounts",
cbar_title_fontsize: float = 8.0,
cbar_ticklabel_fontsize: float = 8.0,
xticklabel_fontsize: float = 10.0,
yticklabel_fontsize: float = 10.0,
# ===================== DOT SIZE LEGEND =====================
show_size_legend: bool = True,
size_legend_fracs: tuple = (1.0, 0.75, 0.5, 0.25, 0.1),
size_legend_title: str = "Fraction\nExpressing",
size_legend_loc: str = "upper left",
size_legend_bbox_to_anchor: tuple | None = (0.0, 0.7),
size_legend_label_fmt: str = "{:.0%}",
size_legend_facecolor: str = "black",
size_legend_edgecolor: str = "black",
size_legend_edge_lw: float = 0.1,
size_legend_framealpha: float = 0.9,
size_legend_title_fontsize: float = 8.0,
size_legend_label_fontsize: float = 8.0,
# ===================== SUPPORT RING LEGEND =====================
show_ring_legend: bool = True,
ring_legend_fracs: tuple = (1, 0.5, 0.1),
ring_legend_title: str = "Consensus\nsupport",
ring_legend_loc: str = "lower left",
ring_legend_bbox_to_anchor: tuple | None = (0.0, 0.25),
ring_legend_label_fmt: str = "{:.0%}",
ring_legend_title_fontsize: float = 8.0,
ring_legend_label_fontsize: float = 8.0,
# ===================== CONSENSUS PANEL (pick one) =====================
consensus_panel: str | None = None, # None | "line" | "support_grid" | "stacked_bar" | "strip"
support_stack_field: str = "supporting_study_names",
# ---- Line panel options ("line") ----
support_line_agg: str = "median", # "median" | "mean"
support_line_color: str = "black",
support_line_lw: float = 1,
support_line_marker: str = '',
support_line_ms: float = 0,
support_line_alpha: float = 0.9,
support_line_ylim: tuple | None = (0.0, 1.2),
# ---- Support GRID options ("support_grid") ----
support_grid_show_legend: bool = False,
support_grid_legend_ncols: int = 2,
# ---- Stacked BAR options ("stacked_bar") ----
stacked_bar_show_legend: bool = True,
stacked_bar_legend_ncols: int = 2,
# ---- Strip options ("strip") ----
support_strip_agg: str = "median",
support_strip_cmap: str = "grey",
support_strip_vmin: float | None = 0.0,
support_strip_vmax: float | None = 1.0,
support_strip_show_colorbar: bool = False,
# ===== Palette sampling from a continuous cmap for study colors =====
support_palette_from_cmap: str | mpl.colors.Colormap | None = None,
support_palette_range: tuple[float, float] = (0.1, 0.9),
# ===== Group color & duplicate gene columns =====
group_color_dict: dict[str, str] | None = None,
duplicate_gene_columns: bool = True,
# ===== Vertical group separators =====
draw_group_separators: bool = True,
separator_color: str = "0.88",
separator_lw: float = 0.7,
separator_ls: str = "-",
include_outer_separators: bool = False,
# ===== Left time strip (inset) ==================================
left_time_strip: bool = False,
time_key: str = "time_id",
time_cmap: str = "jet",
time_vmin: float = 0.0,
time_vmax: float = 120.0,
yticklabel_gap_pt: float = 18.0, # gap (points) between ticklabels and dotplot
time_strip_show_yticklabels: bool = False,
time_strip_label_colors: dict[str, str] | None = None,
# ===================== PHYSICAL SIZE CONTROLS =====================
# consensus/top panel fixed height (in inches)
consensus_height_in: float = 0.35,
# support grid/bar specific heights (in inches); if None, use consensus_height_in
support_grid_height_in: float | None = None,
stacked_bar_height_in: float | None = None,
strip_height_pt: float = 12.0, # for 'strip' variant
# left time strip fixed width (points)
time_strip_width_pt: float = 12.0,
time_strip_gap_pt: float = 4.0,
# colorbar physical sizing
cbar_width_pt: float = 12.0,
cbar_height_frac: float = 0.90, # relative to dotplot height
cbar_bbox_to_anchor: tuple | None = (1.05, 0.85),
# reserve right gutter for legends/colorbar
right_gutter_frac: float = 0.82,
# ===== NEW: external axes control =====
ax: plt.Axes | None = None,
ax_leg: plt.Axes | None = None,
adjust_right_gutter: bool = True,
):
"""
Low-level rendering engine for ZMAP dotplots.
Takes a design DataFrame specifying which ``(cell type, gene)`` pairs to
plot, computes mean expression and fraction-expressing statistics from
``adata``, and renders the result as a scatter-based dotplot with optional
support rings, consensus panels, time strips, colorbars, and size legends.
This function is the shared backend for :func:`group_siblings_vs_markers`
and :func:`group_descendants_vs_markers`. Most users should call one of
those higher-level wrappers instead of invoking this directly.
Parameters
----------
adata : anndata.AnnData
Reference dataset containing expression data and ``obs`` annotations.
design_df : pd.DataFrame
Long-form table with at least columns ``["celltype", "gene"]``.
Each row defines one dot in the grid. An optional ``support_ratio``
column adds consensus-support rings.
groupby : str, default ``"ZMAP_CellType"``
``obs`` column whose categories form the rows of the dotplot.
layer : str or None, default ``None``
Layer to use for expression values. Falls back to ``adata.X``.
use_raw : bool or None, default ``None``
If ``True``, use ``adata.raw`` for expression. Overrides ``layer``.
detect_threshold : float, default ``0.0``
Minimum value for a cell to count as "expressing".
cmap : str, default ``"Blues"``
Matplotlib colormap for dot fill color (mean expression).
vmin, vmax : float or None
Colormap normalization limits. Inferred from data when ``None``.
standard_scale : ``"var"``, ``"obs"``, or None, default ``None``
Scale mean expression per gene (``"var"``) or per group (``"obs"``).
s_min, s_max : float, default ``1`` and ``80``
Minimum and maximum dot sizes (points squared).
add_support_ring : bool, default ``True``
Draw orange rings whose thickness encodes ``support_ratio``.
consensus_panel : str or None, default ``None``
Type of consensus summary panel above the dotplot. One of
``None``, ``"line"``, ``"support_grid"``, ``"stacked_bar"``, ``"strip"``.
left_time_strip : bool, default ``False``
Draw a vertical developmental-time strip to the left of the dotplot.
ax : matplotlib.axes.Axes or None, default ``None``
Pre-existing axes to draw into. A new figure is created when ``None``.
Returns
-------
tuple of (matplotlib.axes.Axes, pd.DataFrame)
``(ax, grid)`` where ``ax`` is the dotplot axes and ``grid`` is the
long-form DataFrame used for rendering, with computed columns
``mean_expr``, ``frac_pos``, and positional indices.
"""
# --------------------- validation & ordering ---------------------
if not {"celltype", "gene"}.issubset(design_df.columns):
raise ValueError("design_df must contain columns: 'celltype' and 'gene'.")
if standard_scale not in (None, "var", "obs"):
raise ValueError("standard_scale must be None, 'var', or 'obs'.")
if consensus_panel not in (None, "line", "support_grid", "stacked_bar", "strip"):
raise ValueError(
"consensus_panel must be one of: None, 'line', 'support_grid', 'stacked_bar', 'strip'."
)
if not (0.0 <= support_palette_range[0] < support_palette_range[1] <= 1.0):
raise ValueError("support_palette_range must be within [0,1] and low < high.")
if groupby not in adata.obs.columns:
raise KeyError(f"groupby '{groupby}' not found in adata.obs.")
if left_time_strip and time_key not in adata.obs.columns:
raise KeyError(f"time_key '{time_key}' not found in adata.obs.")
if consensus_panel in ("support_grid", "stacked_bar"):
if support_stack_field not in design_df.columns:
raise ValueError(
f"consensus_panel='{consensus_panel}' requires column "
f"'{support_stack_field}' in design_df."
)
design = design_df.copy()
design["celltype"] = design["celltype"].astype(str)
design["gene"] = design["gene"].astype(str)
def _ordered_unique(seq):
seen = set()
out = []
for x in seq:
if x not in seen:
seen.add(x)
out.append(x)
return out
# row order
row_order = _ordered_unique(design["celltype"].tolist())
row_index = {r: i for i, r in enumerate(row_order)}
# column slots
if duplicate_gene_columns:
cols_df = (
design.drop_duplicates(subset=["celltype", "gene"], keep="first")
.loc[:, ["celltype", "gene"]]
.copy()
)
cols_df.rename(columns={"celltype": "col_owner"}, inplace=True)
cols_df["col_id"] = np.arange(len(cols_df))
col_labels = cols_df["gene"].tolist()
else:
col_order = _ordered_unique(design["gene"].tolist())
cols_df = pd.DataFrame({"col_owner": [None] * len(col_order), "gene": col_order})
cols_df["col_id"] = np.arange(len(col_order))
col_labels = col_order
col_count = len(cols_df)
col_index = {i: i for i in range(col_count)}
# --------------------- choose expression matrix ---------------------
if use_raw and getattr(adata, "raw", None) is not None:
X = adata.raw.X
var_names = pd.Index(adata.raw.var_names).astype(str)
elif layer is not None:
X = adata.layers[layer]
var_names = pd.Index(adata.var_names).astype(str)
else:
X = adata.X
var_names = pd.Index(adata.var_names).astype(str)
# --------------------- build grid ---------------------
grid = (
pd.MultiIndex.from_product(
[row_order, cols_df["col_id"].tolist()],
names=["celltype", "col_id"],
)
.to_frame(index=False)
.merge(cols_df[["col_id", "gene", "col_owner"]], on="col_id", how="left")
)
if support_col in design.columns:
sup_map = design.groupby(["celltype", "gene"], as_index=False)[support_col].first()
grid = grid.merge(sup_map, on=["celltype", "gene"], how="left")
else:
grid[support_col] = np.nan
grid["mean_expr"] = np.nan
grid["frac_pos"] = np.nan
# --------------------- compute stats (vectorized) ---------------------
present_genes = pd.Index(cols_df["gene"].unique()).intersection(var_names)
present_list = present_genes.tolist()
if present_list:
jmap = pd.Series(range(len(var_names)), index=var_names)
gene_cols = jmap.loc[present_list].to_numpy()
# Densify only the needed gene columns once
X_sub = X[:, gene_cols]
X_dense = X_sub.toarray() if hasattr(X_sub, "toarray") else np.asarray(X_sub)
labels = adata.obs[groupby].astype(str).values
# Build integer label map for fast groupby via np.add.at
ct_list = row_order # preserve row order
ct_enum = {ct: i for i, ct in enumerate(ct_list)}
label_idx = np.array([ct_enum.get(l, -1) for l in labels], dtype=np.intp)
n_groups = len(ct_list)
n_genes = len(present_list)
sum_expr = np.zeros((n_groups, n_genes), dtype=np.float64)
sum_pos = np.zeros((n_groups, n_genes), dtype=np.float64)
counts = np.zeros(n_groups, dtype=np.int64)
valid_mask = label_idx >= 0
valid_idx = label_idx[valid_mask]
X_valid = X_dense[valid_mask]
np.add.at(sum_expr, valid_idx, X_valid)
np.add.at(sum_pos, valid_idx, (X_valid > detect_threshold).astype(np.float64))
np.add.at(counts, valid_idx, 1)
# Avoid divide-by-zero for groups with no cells
safe_counts = counts[:, None].clip(min=1)
mean_mat = sum_expr / safe_counts # shape: (n_groups, n_genes)
frac_mat = sum_pos / safe_counts
# Zero out groups that had no cells
no_cells = (counts == 0)
mean_mat[no_cells] = np.nan
frac_mat[no_cells] = np.nan
# Build lookup DataFrames and merge into grid in one pass
gene_idx_ser = pd.Index(present_list)
rows_idx_ser = pd.Index(ct_list)
mean_df = pd.DataFrame(mean_mat, index=rows_idx_ser, columns=gene_idx_ser)
frac_df = pd.DataFrame(frac_mat, index=rows_idx_ser, columns=gene_idx_ser)
# Stack to long form and merge
mean_long = (
mean_df.stack()
.rename("mean_expr_new")
.reset_index()
.rename(columns={"level_0": "celltype", "level_1": "gene"})
)
frac_long = (
frac_df.stack()
.rename("frac_pos_new")
.reset_index()
.rename(columns={"level_0": "celltype", "level_1": "gene"})
)
grid = grid.merge(mean_long, on=["celltype", "gene"], how="left")
grid = grid.merge(frac_long, on=["celltype", "gene"], how="left")
grid["mean_expr"] = grid["mean_expr_new"].combine_first(grid["mean_expr"])
grid["frac_pos"] = grid["frac_pos_new"].combine_first(grid["frac_pos"])
grid.drop(columns=["mean_expr_new", "frac_pos_new"], inplace=True)
grid["_ri"] = grid["celltype"].map(row_index)
grid["_cj"] = grid["col_id"].map(col_index)
# --------------------- optional standard-scale ---------------------
if standard_scale is not None:
mat_df = grid.pivot(
index="celltype",
columns="col_id",
values="mean_expr",
).reindex(index=row_order, columns=cols_df["col_id"].tolist())
mat = mat_df.to_numpy(dtype=float)
def _scale_0_1(a, axis):
amin = np.nanmin(a, axis=axis, keepdims=True)
amax = np.nanmax(a, axis=axis, keepdims=True)
rng = amax - amin
with np.errstate(invalid="ignore", divide="ignore"):
scaled = (a - amin) / rng
scaled = np.where(np.isfinite(scaled), scaled, 0.0)
scaled = np.where(rng == 0, 0.0, scaled)
return scaled
mat_scaled = _scale_0_1(mat, axis=0 if standard_scale == "var" else 1)
scaled_df = pd.DataFrame(
mat_scaled, index=row_order, columns=cols_df["col_id"].tolist()
)
scaled_long = (
scaled_df.stack()
.rename("mean_expr_scaled")
.reset_index()
.rename(columns={"level_0": "celltype", "level_1": "col_id"})
)
grid = grid.merge(scaled_long, on=["celltype", "col_id"], how="left")
grid["mean_expr"] = grid["mean_expr_scaled"]
grid.drop(columns=["mean_expr_scaled"], inplace=True)
# --------------------- color/size scaling ---------------------
vals = grid["mean_expr"].astype(float).to_numpy()
if vmin is None:
vmin = np.nanpercentile(vals, 1) if np.isfinite(vals).any() else 0.0
if vmax is None:
vmax = np.nanpercentile(vals, 99) if np.isfinite(vals).any() else 1.0
norm = mpl.colors.Normalize(vmin=vmin, vmax=vmax)
cmap_obj = plt.get_cmap(cmap)
sz_raw = grid["frac_pos"].astype(float).clip(lower=0.0)
sz_scale = np.sqrt(sz_raw) if size_mode == "sqrt" else sz_raw
sz = (
s_min
+ (s_max - s_min) * (sz_scale / np.nanmax(sz_scale))
if np.nanmax(sz_scale) > 0
else np.full_like(sz_scale, s_min)
)
# --------------------- main dot grid ---------------------
if ax is None:
fig, ax = plt.subplots(figsize=figsize)
else:
fig = ax.figure
if figsize is not None:
fig.set_size_inches(figsize, forward=True)
# Group separators behind dots
col_count = len(cols_df)
if draw_group_separators and col_count > 1:
owners = cols_df["col_owner"].tolist()
boundaries = [j for j in range(1, col_count) if owners[j] != owners[j - 1]]
if include_outer_separators:
boundaries = [0] + boundaries + [col_count]
for j in boundaries:
ax.axvline(
x=j - 0.5,
color=separator_color,
linewidth=separator_lw,
linestyle=separator_ls,
zorder=0,
)
# Dots
ax.scatter(
grid["_cj"],
grid["_ri"],
s=sz,
c=cmap_obj(norm(grid["mean_expr"].astype(float))),
edgecolor="black",
linewidth=0.1,
)
# -----------------------------------------------------------------
# Reserve / use right gutter (for legends/colorbar) and create ax_leg
# -----------------------------------------------------------------
fig = ax.get_figure()
if ax_leg is None:
if adjust_right_gutter:
fig.subplots_adjust(right=right_gutter_frac)
base = ax.get_position() # after any subplots_adjust
gutter_left = base.x1 + 0.01
gutter_right = 1.0 - 0.02
gutter_width = max(0.01, gutter_right - gutter_left)
ax_leg = fig.add_axes([gutter_left, base.y0, gutter_width, base.height])
ax_leg.axis("off")
# else: use provided ax_leg as-is
# Support ring overlay
if add_support_ring and support_col in grid.columns:
sr = grid[support_col].astype(float).clip(0.0, 1.0).fillna(0.0)
lw = ring_min_lw + (ring_max_lw - ring_min_lw) * sr
mask = lw.to_numpy() > 0
if np.any(mask):
ax.scatter(
grid.loc[mask, "_cj"],
grid.loc[mask, "_ri"],
s=sz[mask],
facecolors="none",
edgecolors=ring_color,
linewidths=lw[mask],
alpha=ring_alpha,
)
# Cosmetics
xpad = 0.3
ax.set_xlim(-0.5 - xpad, col_count - 0.5 + xpad)
ax.set_ylim(len(row_order) - 0.5, -0.5)
ax.set_xticks(range(col_count))
xt = ax.set_xticklabels(col_labels, rotation=xlabel_rotation, ha="center", fontsize=xticklabel_fontsize)
ax.set_yticks(range(len(row_order)))
yt = ax.set_yticklabels(row_order, fontsize=yticklabel_fontsize)
# Append child counts / leaf marker
if rowlabel_append_child_counts:
if groupby not in adata.obs.columns:
raise KeyError(f"groupby '{groupby}' not found in adata.obs.")
if rowlabel_child_col not in adata.obs.columns:
raise KeyError(f"rowlabel_child_col '{rowlabel_child_col}' not found in adata.obs.")
if rowlabel_child_col == groupby:
counts = pd.Series(1, index=pd.Index(row_order, name=groupby))
else:
tmp = adata.obs[[groupby, rowlabel_child_col]].dropna()
tmp = tmp.drop_duplicates([groupby, rowlabel_child_col])
counts = tmp.groupby(groupby, sort=False, observed=True)[rowlabel_child_col].nunique()
display_labels = []
for r in row_order:
n = int(counts.get(r, 0))
if n <= 1:
lbl = f"{r} •"
else:
lbl = rowlabel_fmt.format(name=r, n=n)
display_labels.append(lbl)
ax.set_yticklabels(display_labels, fontsize=yticklabel_fontsize)
if left_time_strip:
# Convert ticklabel gap from points → inches
lbl_gap_in = _pt(yticklabel_gap_pt)
# Figure + axis width in inches
fig_w, fig_h = fig.get_size_inches()
pos = ax.get_position()
ax_width_in = fig_w * pos.width if fig_w > 0 else 0.0
# Convert physical gap (in) → axes fraction
if ax_width_in > 0:
x_label_frac = - lbl_gap_in / ax_width_in
else:
x_label_frac = 0.0
# x in axes fraction, y in data → stable rows + constant physical gap
trans = mtransforms.blended_transform_factory(ax.transAxes, ax.transData)
for label in ax.get_yticklabels():
label.set_ha("right")
label.set_transform(trans)
label.set_x(x_label_frac)
# Clean up lines/grids
ax.grid(False)
ax.xaxis.grid(False, which="both")
ax.yaxis.grid(False, which="both")
ax.tick_params(axis="both", which="both", length=0)
ax.minorticks_off()
# Title & spines
if title:
ax.set_title(title, fontsize=10)
for spine in ax.spines.values():
spine.set_visible(True)
spine.set_linewidth(1)
spine.set_edgecolor("black")
# ----------------- CONSENSUS PANELS (fixed-height sibling axes) -----------------
def _build_study_palette(studies_order, base="tab10", from_cmap=None, t_range=(0.1, 0.9)):
S_ = len(studies_order)
if S_ == 0:
return []
if from_cmap is not None:
cmap_obj_ = plt.get_cmap(from_cmap) if isinstance(from_cmap, str) else from_cmap
t0, t1 = float(t_range[0]), float(t_range[1])
ts = np.linspace(t0, t1, S_, endpoint=True)
return [cmap_obj_(t) for t in ts]
if S_ <= 10:
base_cmap = plt.get_cmap(base)
return [base_cmap(i) for i in range(10)][:S_]
base_cmap = plt.get_cmap("tab20")
colors_ = [base_cmap(i) for i in range(20)]
return [colors_[i % 20] for i in range(S_)]
def _canon_study_name(s: str) -> str:
s = re.sub(r"^[\s'\"\[\]]+|[\s'\"\[\]]+$", "", str(s))
s = re.sub(r"\s+", " ", s).strip()
return s
def _parse_listlike_strict(x):
if x is None:
return []
if isinstance(x, (list, tuple, set)):
return [
_canon_study_name(v)
for v in x
if v is not None and str(v).strip().lower() not in ("nan", "none")
]
if hasattr(x, "tolist"):
return _parse_listlike_strict(x.tolist())
s = str(x).strip()
if not s or s.lower() in ("nan", "none"):
return []
try:
lit = ast.literal_eval(s)
if isinstance(lit, (list, tuple, set)):
out = [
_canon_study_name(v)
for v in lit
if v is not None and str(v).strip().lower() not in ("nan", "none")
]
return list(dict.fromkeys(out))
except Exception:
pass
toks = re.split(r"[;,\|]+", s)
out = [
_canon_study_name(t)
for t in toks
if t and _canon_study_name(t) and _canon_study_name(t).lower() not in ("nan", "none")
]
return list(dict.fromkeys(out))
has_any_support = (support_col in grid.columns and grid[support_col].notna().any())
if consensus_panel in ("line", "strip") and not has_any_support:
raise ValueError(
f"consensus_panel='{consensus_panel}' requires non-null values in support_col='{support_col}'."
)
if consensus_panel == "line" and has_any_support:
if support_line_agg == "median":
gene_support = grid.groupby("gene", sort=False)[support_col].median()
elif support_line_agg == "mean":
gene_support = grid.groupby("gene", sort=False)[support_col].mean()
else:
raise ValueError("support_line_agg must be 'median' or 'mean'.")
yvals = np.array([float(gene_support.get(g, np.nan)) for g in cols_df["gene"]], dtype=float)
yvals = np.nan_to_num(yvals, nan=0.0)
line_ax = _add_top_sibling_axes(ax, height_in=consensus_height_in, pad_in=0.0)
x = np.arange(col_count)
line_ax.plot(
x, yvals,
color=support_line_color,
linewidth=support_line_lw,
marker=support_line_marker,
markersize=support_line_ms,
alpha=support_line_alpha,
)
line_ax.set_xlim(-0.5, col_count - 0.5)
if support_line_ylim is not None:
line_ax.set_ylim(*support_line_ylim)
line_ax.yaxis.set_major_locator(mpl.ticker.MultipleLocator(0.2))
line_ax.tick_params(axis="y", labelleft=False, length=0)
line_ax.set_xticks([])
line_ax.grid(False)
line_ax.xaxis.grid(False, which="both")
line_ax.yaxis.grid(False, which="both")
line_ax.spines["top"].set_visible(False)
line_ax.spines["right"].set_visible(False)
line_ax.spines["left"].set_linewidth(1)
line_ax.spines["bottom"].set_linewidth(1)
elif consensus_panel in ("support_grid", "stacked_bar"):
sub = design.loc[:, ["celltype", "gene", support_stack_field]].copy()
sub["__studies__"] = sub[support_stack_field].apply(_parse_listlike_strict)
studies_order, seen = [], set()
for g in cols_df["gene"]:
for lst in sub.loc[sub["gene"] == g, "__studies__"]:
for st in lst:
if st not in seen:
seen.add(st)
studies_order.append(st)
S = len(studies_order)
if S > 0:
base_choice = "tab10"
colors = _build_study_palette(
studies_order,
base=base_choice if consensus_panel == "support_grid" else base_choice,
from_cmap=support_palette_from_cmap,
t_range=support_palette_range,
)
# Fixed-height panel above dotplot
panel_height_in = (
support_grid_height_in
if consensus_panel == "support_grid" and support_grid_height_in is not None
else stacked_bar_height_in
if consensus_panel == "stacked_bar" and stacked_bar_height_in is not None
else consensus_height_in
)
top_ax = _add_top_sibling_axes(ax, height_in=panel_height_in, pad_in=0.0)
if consensus_panel == "support_grid":
M = np.zeros((S, col_count), dtype=int)
st_to_i = {s: i for i, s in enumerate(studies_order)}
by_gene = sub.groupby("gene")["__studies__"].apply(list)
for j, g in enumerate(cols_df["gene"]):
present = set()
for lst in by_gene.get(g, []):
present.update(lst)
for st in present:
M[st_to_i[st], j] = st_to_i[st] + 1
listed = ListedColormap([(1, 1, 1, 1)] + colors)
top_ax.imshow(
M,
aspect="auto",
interpolation="nearest",
cmap=listed,
vmin=0,
vmax=max(1, S),
extent=[-0.5, col_count - 0.5, -0.5, S - 0.5],
)
top_ax.set_xlim(-0.5, col_count - 0.5)
top_ax.set_ylim(-0.5, max(S - 0.5, 0.5))
top_ax.set_xticks([])
top_ax.set_yticks([])
top_ax.grid(False)
for s in top_ax.spines.values():
s.set_visible(True)
s.set_linewidth(1)
if support_grid_show_legend:
handles = [
Patch(facecolor=colors[i], edgecolor="none", label=studies_order[i])
for i in range(S)
]
top_ax.legend(
handles=handles,
loc="lower left",
bbox_to_anchor=(0, 1.02),
ncol=support_grid_legend_ncols,
frameon=False,
fontsize=8,
)
else: # "stacked_bar"
st_to_color = {s: colors[i % len(colors)] for i, s in enumerate(studies_order)}
by_gene = sub.groupby("gene")["__studies__"].apply(list)
max_h = 0
for j, g in enumerate(cols_df["gene"]):
present = set()
for lst in by_gene.get(g, []):
present.update(lst)
stack_list = [s for s in studies_order if s in present]
for k, st in enumerate(stack_list):
rect = mpl.patches.Rectangle(
(j - 0.5, k - 0.5), 1.0, 1.0, facecolor=st_to_color[st], edgecolor="none",
)
top_ax.add_patch(rect)
max_h = max(max_h, len(stack_list))
top_ax.set_xlim(-0.5, col_count - 0.5)
top_ax.set_ylim(-0.5, max(max_h - 0.5, 0.5))
top_ax.set_xticks([])
top_ax.set_yticks([])
top_ax.grid(False)
for s in top_ax.spines.values():
s.set_visible(True)
s.set_linewidth(1)
if stacked_bar_show_legend:
handles = [
Patch(facecolor=colors[i], edgecolor="none", label=studies_order[i])
for i in range(S)
]
top_ax.legend(
handles=handles,
loc="lower left",
bbox_to_anchor=(0, 1.02),
ncol=stacked_bar_legend_ncols,
frameon=False,
fontsize=8,
)
elif consensus_panel == "strip" and has_any_support:
if support_strip_agg == "median":
gene_support = grid.groupby("gene", sort=False)[support_col].median()
elif support_strip_agg == "mean":
gene_support = grid.groupby("gene", sort=False)[support_col].mean()
else:
raise ValueError("support_strip_agg must be 'median' or 'mean'.")
vals_strip = np.array(
[float(gene_support.get(g, np.nan)) for g in cols_df["gene"]],
dtype=float,
)[None, :]
vals_strip = np.nan_to_num(vals_strip, nan=0.0)
svmin = 0.0 if support_strip_vmin is None else support_strip_vmin
svmax = 1.0 if support_strip_vmax is None else support_strip_vmax
top_ax = _add_top_sibling_axes(ax, height_in=_pt(strip_height_pt), pad_in=0.0)
im_strip = top_ax.imshow(
vals_strip,
aspect="auto",
interpolation="nearest",
cmap=support_strip_cmap,
vmin=svmin,
vmax=svmax,
extent=[-0.5, col_count - 0.5, -0.5, 0.5],
)
top_ax.set_xlim(ax.get_xlim())
top_ax.set_xticks([])
top_ax.set_yticks([])
top_ax.grid(False)
for s in top_ax.spines.values():
s.set_visible(True)
s.set_linewidth(1)
# Optional colorbar for strip
if support_strip_show_colorbar:
cax_strip = _inset_fixed(
top_ax, width=_pt(cbar_width_pt), height="80%", loc="center left",
bbox_to_anchor=(1.02, 0.5, 0, 0), borderpad=0.0
)
cbar_strip = plt.colorbar(
mpl.cm.ScalarMappable(
norm=mpl.colors.Normalize(vmin=svmin, vmax=svmax),
cmap=support_strip_cmap,
),
cax=cax_strip,
)
cbar_strip.ax.tick_params(labelsize=7)
# --------------------- Color tick labels by group -----------------------
# Important: ytick texts may include suffixes ("name (n)" or "name •").
# Color using the *canonical* names from `row_order` by index, not the label text.
if group_color_dict is not None:
# Y (cells): map by index -> canonical name
yticks = ax.get_yticklabels()
for i, label in enumerate(yticks):
if 0 <= i < len(row_order):
ct_base = row_order[i]
if ct_base in group_color_dict:
label.set_color(group_color_dict[ct_base])
# X (genes): unchanged — still keyed by column owner
xticks = ax.get_xticklabels()
for j, lbl in enumerate(xticks):
owner = cols_df.iloc[j]["col_owner"]
if owner is not None and owner in group_color_dict:
lbl.set_color(group_color_dict[owner])
# --------------------- Colorbar in a fixed-width inset ------------------
if add_colorbar:
height_str = cbar_height_frac
cax = _inset_fixed(
ax,
width=_pt(cbar_width_pt), # inches
height=height_str, # percent of dotplot height
loc="center left",
bbox_to_anchor=cbar_bbox_to_anchor,
borderpad=0.0,
)
cbar = plt.colorbar(
mpl.cm.ScalarMappable(norm=norm, cmap=cmap_obj),
cax=cax,
)
# Add left-aligned title at the top-left corner of the colorbar axis
cbar.ax.set_title("")
cbar.ax.text(
0.0, 1.1,
cbar_title,
ha="left", va="bottom",
rotation=0,
fontsize=cbar_title_fontsize,
transform=cbar.ax.transAxes,
)
cbar.ax.tick_params(labelsize=cbar_ticklabel_fontsize)
# --------------------- Size Legend (inside right gutter) ------------------
if show_size_legend and len(size_legend_fracs) > 0:
fracs = np.clip(np.array(size_legend_fracs, dtype=float), 0.0, 1.0)
fr_scaled = np.sqrt(fracs) if size_mode == "sqrt" else fracs
sizes = (
s_min + (s_max - s_min) * (fr_scaled / np.max(fr_scaled))
if np.max(fr_scaled) > 0
else np.full_like(fr_scaled, s_min)
)
handles = [
ax.scatter(
[], [], s=s,
facecolor=size_legend_facecolor,
edgecolor=size_legend_edgecolor,
linewidth=size_legend_edge_lw,
)
for s in sizes
]
labels = [size_legend_label_fmt.format(f) for f in fracs]
legend_kwargs = dict(
loc=size_legend_loc,
frameon=False,
scatterpoints=1,
handletextpad=0.9,
columnspacing=1.0,
labelspacing=0.6,
borderpad=0.0,
framealpha=size_legend_framealpha,
title=size_legend_title,
)
if size_legend_bbox_to_anchor is not None:
legend_kwargs["bbox_to_anchor"] = size_legend_bbox_to_anchor
size_legend = ax_leg.legend(
handles,
labels,
**legend_kwargs,
)
for text in size_legend.get_texts():
text.set_fontsize(size_legend_label_fontsize)
size_legend.get_title().set_fontsize(size_legend_title_fontsize)
ax_leg.add_artist(size_legend)
# --------------------- Ring Legend (inside right gutter) ------------------
if show_ring_legend and add_support_ring and len(ring_legend_fracs) > 0:
rfracs = np.clip(np.array(ring_legend_fracs, dtype=float), 0.0, 1.0)
lw_ring = ring_min_lw + (ring_max_lw - ring_min_lw) * rfracs
base_size = (s_min + s_max) / 2.0
ring_handles = [
ax.scatter([], [], s=base_size, facecolors="none",
edgecolors=ring_color, linewidths=lw_, alpha=ring_alpha)
for lw_ in lw_ring
]
ring_labels = [ring_legend_label_fmt.format(f) for f in rfracs]
legend_kwargs = dict(
loc=ring_legend_loc,
frameon=False,
scatterpoints=1,
handletextpad=0.9,
columnspacing=1.0,
labelspacing=0.6,
borderpad=0.0,
framealpha=size_legend_framealpha,
title=ring_legend_title,
)
if ring_legend_bbox_to_anchor is not None:
legend_kwargs["bbox_to_anchor"] = ring_legend_bbox_to_anchor
ring_legend = ax_leg.legend(
ring_handles,
ring_labels,
**legend_kwargs,
)
for text in ring_legend.get_texts():
text.set_fontsize(ring_legend_label_fontsize)
ring_legend.get_title().set_fontsize(ring_legend_title_fontsize)
ax_leg.add_artist(ring_legend)
# --------------------- OPTIONAL left inset time strip -------------------
if left_time_strip:
_add_left_time_strip_inset(
fig,
ax,
adata=adata,
groupby=groupby,
row_order=row_order,
time_key=time_key,
vmin=time_vmin,
vmax=time_vmax,
cmap=time_cmap,
width_pt=time_strip_width_pt,
gap_pt=time_strip_gap_pt,
show_yticklabels=time_strip_show_yticklabels,
label_colors=time_strip_label_colors,
draw_border=False,
)
return ax, grid
# ---------------------------------------------------------------------
# Dotplot Wrapper: Siblings+Tissues vs Consensus Markers
# ---------------------------------------------------------------------
[docs]
def group_siblings_vs_markers(
    adata: ad.AnnData,
    node: str,
    *,
    # ontology level
    level_col: str | None = None,
    parent_col: str | None = None,
    # marker selection
    marker_type: Literal["overall", "exclusivity", "contrast", "consensus"] = "overall",
    n_markers: int = 20,
    min_support_ratio: float | None = None,
    min_log2fc: float | None = None,
    min_enrich: float | None = None,
    omit_unannotated: bool = True,
    # expression source
    layer: str | None = "tpm_log",
    use_raw: bool | None = None,
    # style
    highlight_color: str = "black",
    standard_scale: Literal["var", "obs"] | None = None,
    cmap: str = "Blues",
    enforce_global_colorscale: bool = False,
    # layout / sizing
    width_per_gene: float = 0.04,
    height_per_group: float = 0.2,
    min_figsize: tuple[float, float] = (0.5, 0.5),
    max_figsize: tuple[float, float] = (25.0, 25.0),
    hspace: float = 0.04,
    # time strip & consensus
    left_time_strip: bool = False,
    consensus_panel: str | None = None,
    show_size_legend: bool = True,
    xlabel_rotation: int = 90,
    duplicate_gene_columns: bool = False,
    # legends
    show_ring_legend: bool = True,
    add_colorbar: bool = True,
    # titles
    sibling_title: str | None = None,
    tissue_title: str | None = None,
    **dotplot_kwargs,
):
    """
    Two-block dotplot comparing a focal node to its siblings and to all tissues.

    Given a focal node (e.g. ``"hepatocyte"``), this function:

    - Fetches the consensus marker genes for that node (``get_node_markers``).
    - Plots a **sibling block** (top): all groups at ``level_col`` that share
      the focal node's parent in ``parent_col``, colored by mean expression of
      the focal node's markers. This block is omitted when
      ``level_col == "ZMAP_Tissue"``.
    - Plots a **tissue block** (bottom): all ``ZMAP_Tissue`` groups, showing
      how broadly those same marker genes are expressed across tissues.

    Both blocks share the same gene columns and the same right-hand legend
    axes; the gene labels are repeated above the top block. Rows are reordered
    with ``reorder_groups_by_mean_expression`` and gene columns with
    ``reorder_genes_by_mean_expression``.

    The tissue block's design carries a ``support_ratio`` column taken from the
    Tissue-level consensus marker table (NaN where unavailable), which the
    underlying dotplot can render as rings indicating cross-study
    reproducibility.

    Parameters
    ----------
    adata : anndata.AnnData
        Reference dataset. Must contain the relevant ``obs`` columns for the
        ZMAP hierarchy (e.g. ``ZMAP_CellType``) and always ``ZMAP_Tissue``.
    node : str
        The focal cell type or cluster whose markers and siblings to plot
        (e.g. ``"hepatocyte"``, ``"Neurons"``).
    level_col : str or None, default ``None``
        ``obs`` column defining the annotation level of ``node``
        (e.g. ``"ZMAP_CellType"``). Detected with ``find_level_for_node``
        when ``None``. Must be a key of ``_ZMAP_LEVEL_CONFIG``.
    parent_col : str or None, default ``None``
        ``obs`` column defining the parent level used to identify siblings
        (e.g. ``"ZMAP_Tissue"``). Taken from ``_ZMAP_LEVEL_CONFIG`` when
        ``None``.
    marker_type : str, default ``"overall"``
        Scoring criterion for selecting markers. One of
        ``"overall"``, ``"exclusivity"``, ``"contrast"``, ``"consensus"``.
    n_markers : int, default ``20``
        Number of top marker genes to display as columns.
    min_support_ratio, min_log2fc, min_enrich : float or None, default ``None``
        Optional filters applied to the marker tables before selecting genes.
        See ``load_consensus_markers`` for definitions.
    omit_unannotated : bool, default ``True``
        Remove genes with unannotated/placeholder names from the marker set.
    layer : str or None, default ``"tpm_log"``
        Layer in ``adata.layers`` to use for expression values; forwarded to
        the expression helpers and ``plot_dotplot_basegrid``.
    use_raw : bool or None, default ``None``
        Whether to read expression from ``adata.raw``; forwarded to the
        expression helpers and ``plot_dotplot_basegrid``.
    highlight_color : str, default ``"black"``
        Text color applied to the focal node's row label in the sibling block
        (the label is also bolded).
    standard_scale : ``"var"``, ``"obs"``, or None, default ``None``
        Min-max scale the group-by-gene mean matrix to [0, 1] per gene
        (``"var"``, across groups) or per group (``"obs"``, across genes).
        ``None`` applies no scaling.
    cmap : str, default ``"Blues"``
        Matplotlib colormap for dot fill color.
    enforce_global_colorscale : bool, default ``False``
        If ``True``, use a single shared color scale (1st/99th percentile of
        the group means of both blocks). If ``False``, each block is scaled
        independently.
    width_per_gene : float, default ``0.04``
        Figure width contribution per gene column, in inches (a fixed 3.0 in
        is added for margins).
    height_per_group : float, default ``0.2``
        Figure height contribution per row (group), in inches (a fixed 2.5 in
        is added for margins).
    min_figsize, max_figsize : tuple of float
        Hard lower and upper bounds on the auto-computed figure size.
    hspace : float, default ``0.04``
        Vertical gap between the sibling and tissue blocks, passed to the
        gridspec (fraction of the average axes height).
    left_time_strip : bool, default ``False``
        Forwarded to ``plot_dotplot_basegrid`` to draw a time strip inset on
        the left margin of each block.
    consensus_panel : str or None, default ``None``
        Panel drawn above each block by ``plot_dotplot_basegrid``: one of
        ``"line"``, ``"strip"``, ``"support_grid"``, ``"stacked_bar"``, or
        ``None`` for no panel. ``"line"`` and ``"strip"`` raise ``ValueError``
        in that function when a block has no support values. Forwarded to
        both blocks.
    show_size_legend : bool, default ``True``
        Show the dot-size legend (fraction of cells expressing each gene).
    xlabel_rotation : int, default ``90``
        Rotation angle for gene name tick labels, both at the bottom of the
        tissue block and in the duplicated labels at the top.
    duplicate_gene_columns : bool, default ``False``
        Forwarded to ``plot_dotplot_basegrid``. Note that this function always
        repeats the gene labels above the top block, independent of this flag.
    show_ring_legend : bool, default ``True``
        Show the ring (support_ratio) legend for the tissue block.
    add_colorbar : bool, default ``True``
        Add a colorbar for mean expression (drawn with the tissue block).
    sibling_title, tissue_title : str or None, default ``None``
        Custom titles for the sibling and tissue blocks respectively.
        Auto-generated when ``None``. (The tissue block is currently rendered
        with an empty title.)
    **dotplot_kwargs
        Additional keyword arguments forwarded to both
        ``plot_dotplot_basegrid`` calls.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The rendered figure.
    ax_sib : matplotlib.axes.Axes or None
        Sibling-block axes; ``None`` when ``level_col == "ZMAP_Tissue"``.
    ax_tiss : matplotlib.axes.Axes
        Tissue-block axes.
    grids : tuple
        ``(grid_sib, grid_tiss)`` as returned by ``plot_dotplot_basegrid``;
        ``grid_sib`` is ``None`` when there is no sibling block.
    marker_df : pandas.DataFrame
        The marker table returned by ``get_node_markers``.

    Raises
    ------
    KeyError
        If ``level_col`` is not a key of ``_ZMAP_LEVEL_CONFIG`` or
        ``ZMAP_Tissue`` is missing from ``adata.obs``.
    ValueError
        If the marker table contains no rows for ``node``.
    RuntimeError
        If the sibling and tissue designs end up with different gene orders.

    Examples
    --------
    >>> fig, ax_sib, ax_tiss, grids, markers = zmap.dotplot.group_siblings_vs_markers(
    ...     adata_ref, "hepatocyte"
    ... )
    >>> zmap.dotplot.group_siblings_vs_markers(adata_ref, "Neurons", n_markers=15, cmap="Reds")
    """
    node_str = str(node)

    def _strip_row_suffix(text: str) -> str:
        # Row labels may carry a child-count suffix ("name (n)") or a leaf
        # marker ("name •"); strip them to recover the canonical group name.
        return text.split(" (")[0].split(" •")[0]

    def _labels_for(labels, name: str) -> list:
        # Prefer exact text matches so that group names which themselves
        # contain " (" are still found; fall back to suffix-stripped matching.
        exact = [lab for lab in labels if lab.get_text() == name]
        if exact:
            return exact
        return [lab for lab in labels if _strip_row_suffix(lab.get_text()) == name]

    # ---- 1) Determine level and parent ----
    if level_col is None:
        level_col = find_level_for_node(adata, node_str)
    if level_col not in _ZMAP_LEVEL_CONFIG:
        raise KeyError(f"level_col={level_col!r} not recognized in _ZMAP_LEVEL_CONFIG.")
    if parent_col is None:
        parent_col = _ZMAP_LEVEL_CONFIG[level_col]["parent_col"]
    # A Tissue-level node has no separate sibling block: the tissue block
    # already shows all of its siblings.
    has_sibling_block = (level_col != "ZMAP_Tissue")
    parent_label: Optional[str] = None

    # ---- 2) Focal markers and gene set ----
    marker_df = get_node_markers(
        node=node_str,
        level_col=level_col,
        marker_types=(marker_type,),
        n_per_type=n_markers,
        min_support_ratio=min_support_ratio,
        min_log2fc=min_log2fc,
        min_enrich=min_enrich,
        omit_unannotated=omit_unannotated,
    )
    node_markers = marker_df.loc[marker_df["celltype"] == node_str]
    if node_markers.empty:
        raise ValueError(f"No rows in marker_df for focal node={node_str!r}.")
    genes = node_markers["gene"].astype(str).unique().tolist()

    # ---- 2b) Densify X[:, gene_cols] ONCE for all pre-render steps ----
    _x_dense, _x_genes = _get_expression_slice(
        adata, genes, layer=layer, use_raw=use_raw
    )

    # ---- 3) Siblings (if not Tissue-level) ----
    if has_sibling_block:
        parent_label, siblings = get_parent_and_siblings(
            adata,
            node=node_str,
            level_col=level_col,
            parent_col=parent_col,
        )
        siblings = reorder_groups_by_mean_expression(
            adata,
            groups=siblings,
            level_col=level_col,
            genes=genes,
            layer=layer,
            use_raw=use_raw,
            _x_dense=_x_dense,
            _x_genes=_x_genes,
        )
    else:
        siblings = []

    # ---- 4) Tissues block (list of tissues, not yet design_df) ----
    if "ZMAP_Tissue" not in adata.obs.columns:
        raise KeyError("ZMAP_Tissue column not found in adata.obs; required for tissue block.")
    tissues_all = sorted(
        adata.obs["ZMAP_Tissue"]
        .dropna()
        .astype(str)
        .unique()
        .tolist()
    )
    tissues = reorder_groups_by_mean_expression(
        adata,
        groups=tissues_all,
        level_col="ZMAP_Tissue",
        genes=genes,
        layer=layer,
        use_raw=use_raw,
        _x_dense=_x_dense,
        _x_genes=_x_genes,
    )

    # ---- 4b) Reorder genes (column order) using the top-most block ----
    if has_sibling_block:
        genes = reorder_genes_by_mean_expression(
            adata,
            genes,
            groups=siblings,
            level_col=level_col,
            layer=layer,
            use_raw=use_raw,
            _x_dense=_x_dense,
            _x_genes=_x_genes,
        )
    else:
        genes = reorder_genes_by_mean_expression(
            adata,
            genes,
            groups=tissues,
            level_col="ZMAP_Tissue",
            layer=layer,
            use_raw=use_raw,
            _x_dense=_x_dense,
            _x_genes=_x_genes,
        )

    # ---- 5) Now build design_dfs using the reordered genes ----
    if has_sibling_block:
        design_sib = make_sibling_design_df(
            node=node_str,
            siblings=siblings,
            marker_df=marker_df,
            genes=genes,
            support_col="support_ratio",
        )
        design_sib = design_sib[design_sib["gene"].isin(genes)]
        # All labels black — bold (and highlight_color) applied after render
        group_color_sib = {ct: "black" for ct in siblings}
        if sibling_title is None:
            sibling_title = f"{node_str}"
    else:
        design_sib = None
        group_color_sib = {}

    # tissues design_df with same gene order
    design_tiss = pd.DataFrame(
        [(t, g) for t in tissues for g in genes],
        columns=["celltype", "gene"],
    )

    # ---- 5b) Attach Tissue-level support_ratio for (Tissue, gene) pairs ----
    try:
        tissue_support_df = load_consensus_markers(
            level="Tissue",
            groups=tissues,
            marker_type=marker_type,
            n_per_group=None,
            min_support_ratio=min_support_ratio,
            min_log2fc=min_log2fc,
            min_enrich=min_enrich,
            omit_unannotated=omit_unannotated,
            format="table",
        )
    except TypeError:
        # NOTE(review): presumably some load_consensus_markers versions reject
        # n_per_group=None; retry with a large explicit cap instead.
        tissue_support_df = load_consensus_markers(
            level="Tissue",
            groups=tissues,
            marker_type=marker_type,
            n_per_group=1000,
            min_support_ratio=min_support_ratio,
            min_log2fc=min_log2fc,
            min_enrich=min_enrich,
            omit_unannotated=omit_unannotated,
            format="table",
        )
    if tissue_support_df is not None and not tissue_support_df.empty:
        ts = tissue_support_df.copy()
        ts["celltype"] = ts["celltype"].astype(str)
        ts["gene"] = ts["gene"].astype(str)
        ts = ts[ts["gene"].isin(genes)]
        if "support_ratio" in ts.columns:
            # One support value per (tissue, gene); keep the first occurrence.
            ts_small = (
                ts
                .drop_duplicates(["celltype", "gene"])
                .loc[:, ["celltype", "gene", "support_ratio"]]
            )
            design_tiss = design_tiss.merge(
                ts_small,
                on=["celltype", "gene"],
                how="left",
            )
        else:
            design_tiss["support_ratio"] = np.nan
    else:
        design_tiss["support_ratio"] = np.nan

    # All labels black — bold applied after render
    group_color_tiss = {t: "black" for t in tissues}
    if tissue_title is None:
        tissue_title = f"Tissue context for node = {node_str}"

    # ---- 5c) Sanity check: ensure gene order matches between sibling & tissue blocks ----
    if has_sibling_block and design_sib is not None:
        def _ordered_genes(df: pd.DataFrame) -> list[str]:
            return list(dict.fromkeys(df["gene"].astype(str).tolist()))

        sib_genes_order = _ordered_genes(design_sib)
        tiss_genes_order = _ordered_genes(design_tiss)
        if sib_genes_order != tiss_genes_order:
            raise RuntimeError(
                "Gene order mismatch between sibling and tissue design_dfs.\n"
                f"sib_genes: {sib_genes_order}\n"
                f"tiss_genes: {tiss_genes_order}"
            )

    # ---- 6) Figure sizing ----
    n_cols = len(genes)
    n_rows_sib = len(siblings) if has_sibling_block else 0
    n_rows_tiss = len(tissues)
    total_rows = n_rows_sib + n_rows_tiss
    width = width_per_gene * max(n_cols, 1) + 3.0
    height = height_per_group * max(total_rows, 1) + 2.5
    width = float(np.clip(width, min_figsize[0], max_figsize[0]))
    height = float(np.clip(height, min_figsize[1], max_figsize[1]))
    fig = plt.figure(figsize=(width, height))
    if has_sibling_block:
        gs = fig.add_gridspec(
            2, 1,
            height_ratios=[max(n_rows_sib, 1), max(n_rows_tiss, 1)],
            hspace=hspace,
        )
        ax_sib = fig.add_subplot(gs[0, 0])
        ax_tiss = fig.add_subplot(gs[1, 0], sharex=ax_sib)
    else:
        gs = fig.add_gridspec(1, 1)
        ax_sib = None
        ax_tiss = fig.add_subplot(gs[0, 0])

    # ---- 7) Shared right-hand gutter for legends/colorbar ----
    base = ax_tiss.get_position()
    gutter_left = base.x1 + 0.01
    gutter_right = 1.0 - 0.02
    gutter_width = max(0.01, gutter_right - gutter_left)
    ax_leg = fig.add_axes([gutter_left, base.y0, gutter_width, base.height])
    ax_leg.axis("off")

    # ---- 8) Optional: global colorscale (vmin/vmax) across both panels ----
    vmin_global = vmax_global = None
    if enforce_global_colorscale:
        panels: list[tuple[list[str], str]] = []
        if design_sib is not None:
            panels.append((siblings, level_col))
        panels.append((tissues, "ZMAP_Tissue"))
        all_vals: list[np.ndarray] = []
        for panel_groups, gcol in panels:
            mat, _, _ = _compute_group_gene_means(
                adata,
                groups=panel_groups,
                genes=genes,
                level_col=gcol,
                layer=layer,
                use_raw=use_raw,
                _x_dense=_x_dense,
                _x_genes=_x_genes,
            )
            if standard_scale is not None and mat.size > 0:
                # Mirror the per-block min-max scaling done by the dotplot so
                # the shared limits are computed on the same value scale.
                axis = 0 if standard_scale == "var" else 1
                with warnings.catch_warnings():
                    # All-NaN slices only yield NaN here, which is zeroed
                    # below; silence the otherwise noisy RuntimeWarning.
                    warnings.simplefilter("ignore", category=RuntimeWarning)
                    amin = np.nanmin(mat, axis=axis, keepdims=True)
                    amax = np.nanmax(mat, axis=axis, keepdims=True)
                rng = amax - amin
                with np.errstate(invalid="ignore", divide="ignore"):
                    mat = (mat - amin) / rng
                mat = np.where(np.isfinite(mat), mat, 0.0)
                mat = np.where(rng == 0, 0.0, mat)
            all_vals.append(mat.ravel())
        if all_vals:
            all_vals_concat = np.concatenate(all_vals)
            finite = np.isfinite(all_vals_concat)
            if finite.any():
                vmin_global = float(np.nanpercentile(all_vals_concat[finite], 1))
                vmax_global = float(np.nanpercentile(all_vals_concat[finite], 99))

    # NOTE(review): consensus_panel is forwarded to both blocks. Each block's
    # panel is created above its own axes by plot_dotplot_basegrid; with a
    # sibling block present the tissue panel presumably sits in the narrow
    # gap between the blocks — confirm the intended layout.

    # ---- 9) Draw sibling block (top) ----
    grid_sib = None
    if has_sibling_block:
        ax_sib, grid_sib = plot_dotplot_basegrid(
            adata,
            design_df=design_sib,
            groupby=level_col,
            layer=layer,
            use_raw=use_raw,
            standard_scale=standard_scale,
            cmap=cmap,
            vmin=vmin_global,
            vmax=vmax_global,
            group_color_dict=group_color_sib,
            consensus_panel=consensus_panel,
            left_time_strip=left_time_strip,
            show_size_legend=False,
            show_ring_legend=False,
            add_colorbar=False,
            xlabel_rotation=xlabel_rotation,
            duplicate_gene_columns=duplicate_gene_columns,
            title=sibling_title,
            figsize=None,
            ax=ax_sib,
            ax_leg=ax_leg,
            adjust_right_gutter=False,
            rowlabel_append_child_counts=False,
            **dotplot_kwargs,
        )
        # Gene labels are shown under the tissue block and duplicated on top.
        ax_sib.tick_params(axis="x", labelbottom=False)
        # Bold and color (highlight_color) the focal node's row label.
        focal_labels = _labels_for(ax_sib.get_yticklabels(), node_str)
        if focal_labels:
            focal_labels[0].set_fontweight("bold")
            focal_labels[0].set_color(highlight_color)

    # ---- 10) Draw tissue block (bottom) ----
    ax_tiss, grid_tiss = plot_dotplot_basegrid(
        adata,
        design_df=design_tiss,
        groupby="ZMAP_Tissue",
        layer=layer,
        use_raw=use_raw,
        standard_scale=standard_scale,
        cmap=cmap,
        vmin=vmin_global,
        vmax=vmax_global,
        group_color_dict=group_color_tiss,
        consensus_panel=consensus_panel,
        left_time_strip=left_time_strip,
        show_size_legend=show_size_legend,
        show_ring_legend=show_ring_legend,
        add_colorbar=add_colorbar,
        xlabel_rotation=xlabel_rotation,
        duplicate_gene_columns=duplicate_gene_columns,
        title='',
        figsize=None,
        ax=ax_tiss,
        ax_leg=ax_leg,
        adjust_right_gutter=False,
        rowlabel_append_child_counts=False,
        **dotplot_kwargs,
    )

    # ---- 11) Bold the parent tissue and/or focal node in the tissue panel ----
    for name in (parent_label, node_str):
        if name is None:
            continue
        for lab in _labels_for(ax_tiss.get_yticklabels(), name):
            lab.set_fontweight("bold")

    # ---- 12) Duplicate gene labels at top ----
    xticks = ax_tiss.get_xticks()
    xlabels = [t.get_text() for t in ax_tiss.get_xticklabels()]
    anchor_ax = ax_sib if ax_sib is not None else ax_tiss
    top_ax = anchor_ax.secondary_xaxis("top")
    top_ax.set_xticks(xticks)
    top_ax.set_xticklabels(
        xlabels,
        rotation=xlabel_rotation,
        ha="center",
        fontsize=10,
    )
    top_ax.tick_params(axis="x", which="both", length=0)
    top_ax.set_xlabel("")

    # ---- 13) Centered title for ZMAP_Tissue nodes (single-panel version) ----
    if level_col == "ZMAP_Tissue":
        anchor_ax = ax_sib if ax_sib is not None else ax_tiss
        if anchor_ax is not None:
            fig = anchor_ax.figure
            # A draw is needed so tick-label extents are available.
            fig.canvas.draw()
            renderer = fig.canvas.get_renderer()
            bboxes = [lbl.get_window_extent(renderer=renderer) for lbl in anchor_ax.get_xticklabels()]
            if len(bboxes) > 0:
                max_label_height_px = max(bb.height for bb in bboxes)
                dpi = fig.dpi
                max_label_height_in = max_label_height_px / dpi
            else:
                max_label_height_in = 0.0
            # Lift the title axes clear of the (duplicated) top gene labels.
            extra_gap_in = 0.05
            required_pad = max_label_height_in + extra_gap_in
            title_ax = _add_top_sibling_axes(
                anchor_ax,
                height_in=0.28,
                pad_in=required_pad,
            )
            title_ax.set_axis_off()
            title_ax.text(
                0.5, 0.55,
                f"{level_col} = {node}",
                fontsize=10,
                ha="center",
                va="center",
            )

    return fig, ax_sib, ax_tiss, (grid_sib, grid_tiss), marker_df
# ---------------------------------------------------------------------
# Dotplot Wrapper: Descendants vs Consensus Markers
# ---------------------------------------------------------------------
[docs]
def _dendrogram_order_from_group_means(
    mat: np.ndarray,
    group_order: Sequence[str],
    *,
    metric: str,
    method: str,
) -> list:
    """
    Reorder ``group_order`` by the leaf order of a hierarchical clustering of ``mat``.

    Parameters
    ----------
    mat : numpy.ndarray
        2-D array with one row per entry of ``group_order`` (row ``i`` belongs
        to ``group_order[i]``) and one column per gene.
    group_order : sequence of str
        Labels of the rows of ``mat``.
    metric : str
        Distance metric passed to :func:`scipy.spatial.distance.pdist`.
    method : str
        Linkage method passed to :func:`scipy.cluster.hierarchy.linkage`.

    Returns
    -------
    list of str
        The labels in dendrogram leaf order. The input order is returned
        unchanged when clustering is not meaningful: fewer than two rows,
        all rows (approximately) identical, or no finite pairwise distance.
    """
    order = list(group_order)
    if len(order) < 2 or np.allclose(mat, mat[0]):
        return order
    # Some metrics are undefined for particular rows -- e.g. "correlation" on a
    # row that is constant across genes (0/0) -- and yield NaN. linkage()
    # rejects non-finite input, so park such pairs at the largest observed
    # distance instead of crashing.
    with np.errstate(divide="ignore", invalid="ignore"):
        dists = pdist(mat, metric=metric)
    finite = np.isfinite(dists)
    if not finite.any():
        return order
    if not finite.all():
        dists = np.where(finite, dists, dists[finite].max())
    leaf_idx = leaves_list(linkage(dists, method=method))
    return [order[i] for i in leaf_idx]


def group_descendants_vs_markers(
    adata: ad.AnnData,
    parent: Optional[str] = None,  # the selected parent group (e.g., a Tissue or CellType label)
    *,
    # ---- Hierarchy columns (parent/child) ----
    parent_col: Optional[str] = None,  # higher-level obs column (e.g., ZMAP_Tissue)
    child_col: Optional[str] = None,  # lower-level obs column used for rows (e.g., ZMAP_Cluster)
    # ---- Back-compat aliases (deprecated) ----
    celltype: Optional[str] = None,  # alias of `parent`
    celltype_col: Optional[str] = None,  # alias of `parent_col`
    cluster_col: Optional[str] = None,  # alias of `child_col`
    # ---- expression source ----
    layer: Optional[str] = "tpm_log",
    use_raw: Optional[bool] = None,
    # ---- marker loading ----
    marker_type: Literal["specificity", "contrast", "consensus", "overall"] = "overall",
    n_markers_per_group: Optional[int] = 5,
    min_support_ratio: Optional[float] = None,
    min_log2fc: Optional[float] = None,
    min_enrich: Optional[float] = None,
    omit_unannotated: bool = True,
    # ---- group filtering ----
    min_cells_per_group: Optional[int] = None,
    # ---- ordering / dendrogram ----
    use_dendrogram_rows: bool = True,
    dendrogram_metric: str = "correlation",
    dendrogram_method: str = "average",
    # ---- dotplot scaling & style ----
    standard_scale: Optional[str] = "var",
    detect_threshold: float = 0.0,
    cmap: str = "Blues",
    consensus_panel: Optional[str] = None,
    stacked_bar_show_legend: bool = True,
    group_color_dict: Optional[Mapping[str, str]] = None,
    duplicate_gene_columns: bool = False,
    # ---- figure size control ----
    figsize: Optional[Tuple[float, float]] = None,
    min_figsize: Tuple[float, float] = (10.0, 6.0),
    max_figsize: Tuple[float, float] = (80.0, 40.0),
    width_per_gene: float = 0.15,
    height_per_group: float = 0.15,
    left_margin: float = 1.5,
    right_margin: float = 2.5,
    top_margin: float = 1.0,
    bottom_margin: float = 1.0,
    # ---- misc dotplot passthrough ----
    s_min: float = 1.0,
    s_max: float = 80.0,
    xlabel_rotation: int = 90,
    title: Optional[str] = None,
    show_size_legend: bool = True,
    **dotplot_kwargs,
):
    """
    Dotplot of the child groups inside one parent group against their consensus markers.

    Given a parent label (e.g. ``"forebrain"`` at the Tissue level), plots
    every child group found inside that parent (e.g. the ZMAP Clusters of that
    tissue) as a row, with each child group's consensus marker genes as
    columns. Marker rows are ordered by child group, and child groups are
    optionally reordered by a dendrogram computed over their mean expression
    of the marker genes.

    Typical usage:
    - ``parent_col="ZMAP_Tissue"``, ``child_col="ZMAP_Cluster"``, ``parent="forebrain"``
    - ``parent_col="ZMAP_Tissue"``, ``child_col="ZMAP_CellType"``, ``parent="forebrain"``

    Parameters
    ----------
    adata : anndata.AnnData
        Reference dataset. Must contain both ``parent_col`` and ``child_col``
        in ``adata.obs``.
    parent : str or None, default ``None``
        The specific parent label to drill into (e.g. ``"forebrain"``).
        Required -- raises ``ValueError`` when ``None``.
    parent_col : str or None, default ``None``
        ``obs`` column at the parent level (e.g. ``"ZMAP_Tissue"``). When
        ``None`` it is resolved with ``find_level_for_node(adata, parent)``.
    child_col : str or None, default ``None``
        ``obs`` column at the child level to use as rows. Required. It must be
        one of the columns that have consensus markers, i.e. a key of
        ``_ZMAP_LEVEL_LOOKUP`` (``"ZMAP_Tissue"``, ``"ZMAP_CellType"``,
        ``"ZMAP_CellTypeFine"``, ``"ZMAP_Cluster"``); other columns
        (e.g. Leiden clusterings) raise ``KeyError``.
    layer : str or None, default ``"tpm_log"``
        Layer in ``adata.layers`` used for expression values. ``adata.X`` is
        used when ``None``. Forwarded to the plotting routine.
    use_raw : bool or None, default ``None``
        If truthy and ``adata.raw`` is present, the row dendrogram is computed
        from ``adata.raw`` (taking precedence over ``layer``). Also forwarded
        to the plotting routine.
    marker_type : str, default ``"overall"``
        Scoring criterion for selecting markers per child group. One of
        ``"specificity"``, ``"contrast"``, ``"consensus"``, ``"overall"``.
    n_markers_per_group : int or None, default ``5``
        Number of top marker genes to load per child group.
    min_support_ratio, min_log2fc, min_enrich : float or None, default ``None``
        Optional filters applied to the marker table. See ``load_consensus_markers``.
    omit_unannotated : bool, default ``True``
        Passed to ``load_consensus_markers`` to drop unannotated/placeholder genes.
    min_cells_per_group : int or None, default ``None``
        Exclude child groups with fewer than this many cells inside ``parent``.
    use_dendrogram_rows : bool, default ``True``
        Reorder rows by a hierarchical clustering of per-group mean expression
        of the marker genes. Falls back to the observed order when clustering
        is degenerate (one group, identical profiles, or undefined distances).
    dendrogram_metric : str, default ``"correlation"``
        Distance metric for the row dendrogram (``scipy.spatial.distance.pdist``).
    dendrogram_method : str, default ``"average"``
        Linkage method for the row dendrogram (``scipy.cluster.hierarchy.linkage``).
    standard_scale : str or None, default ``"var"``
        Forwarded to the plotting routine: scale per gene (``"var"``) or per
        group (``"obs"``); ``None`` applies no scaling.
    detect_threshold : float, default ``0.0``
        Forwarded to the plotting routine; expression threshold used for the
        fraction-expressing dot size.
    cmap : str, default ``"Blues"``
        Matplotlib colormap for dot fill color.
    consensus_panel : str or None, default ``None``
        Forwarded to the plotting routine (optional consensus support panel).
    stacked_bar_show_legend : bool, default ``True``
        Forwarded to the plotting routine.
    group_color_dict : mapping or None, default ``None``
        ``{group_label: color}`` mapping for child group row labels.
    duplicate_gene_columns : bool, default ``False``
        Forwarded to the plotting routine. When auto-sizing the figure, each
        (group, gene) pair is counted as a column instead of each unique gene.
    figsize : tuple of float or None, default ``None``
        Explicit figure size ``(width, height)`` in inches. Auto-computed when ``None``.
    min_figsize, max_figsize : tuple of float
        Bounds the auto-computed figure size is clipped to.
    width_per_gene, height_per_group : float
        Per-column and per-row size contributions when auto-computing figure size.
    left_margin, right_margin, top_margin, bottom_margin : float
        Fixed amounts added to the auto-computed width/height.
    s_min, s_max : float, default ``1.0`` and ``80.0``
        Minimum and maximum dot sizes, forwarded to the plotting routine.
    xlabel_rotation : int, default ``90``
        Rotation angle for gene name tick labels.
    title : str or None, default ``None``
        Forwarded to the plotting routine.
    show_size_legend : bool, default ``True``
        Show the dot-size legend.
    **dotplot_kwargs
        Additional keyword arguments forwarded to ``plot_dotplot_basegrid``.
        They may also override the default ``cbar_bbox_to_anchor``,
        ``size_legend_bbox_to_anchor`` and ``ring_legend_bbox_to_anchor``.

    Returns
    -------
    ax
        First value returned by ``plot_dotplot_basegrid``.
    grid
        Second value returned by ``plot_dotplot_basegrid``.
    marker_df : pandas.DataFrame
        The marker table that was plotted: columns include ``celltype`` (an
        ordered categorical in final row order) and ``gene``, restricted to
        child groups inside ``parent`` and to genes present in ``adata.var_names``.

    Raises
    ------
    ValueError
        If ``parent`` or ``child_col`` is missing, no cells match ``parent``,
        no child groups survive ``min_cells_per_group``, or no usable markers remain.
    KeyError
        If ``parent_col``/``child_col`` is not an ``adata.obs`` column, or
        ``child_col`` has no consensus-level mapping.

    Notes
    -----
    The ``celltype``, ``celltype_col``, and ``cluster_col`` parameter names are
    deprecated aliases for ``parent``, ``parent_col``, and ``child_col``
    respectively, and will be removed in a future release. When both an alias
    and its replacement are given, the replacement wins.

    Examples
    --------
    >>> zmap.dotplot.group_descendants_vs_markers(
    ...     adata_ref,
    ...     parent="forebrain",
    ...     parent_col="ZMAP_Tissue",
    ...     child_col="ZMAP_Cluster",
    ... )
    """
    # -------------------- 0) Back-compat parameter resolution --------------------
    # Aliases only fill in a value the caller did not pass under the new name.
    if celltype is not None:
        warnings.warn("`celltype` is deprecated; use `parent`.", DeprecationWarning, stacklevel=2)
        if parent is None:
            parent = celltype
    if celltype_col is not None:
        warnings.warn("`celltype_col` is deprecated; use `parent_col`.", DeprecationWarning, stacklevel=2)
        if parent_col is None:
            parent_col = celltype_col
    if cluster_col is not None:
        warnings.warn("`cluster_col` is deprecated; use `child_col`.", DeprecationWarning, stacklevel=2)
        if child_col is None:
            child_col = cluster_col

    # -------------------- 1) derive group_order from obs only (no copy yet) ------
    if parent is None:
        raise ValueError("`parent` must be provided (e.g., a specific Tissue or CellType label).")
    if child_col is None:
        raise ValueError("`child_col` must be provided (e.g., 'ZMAP_Cluster').")
    if parent_col is None:
        parent_col = find_level_for_node(adata, str(parent))
    if parent_col not in adata.obs.columns:
        raise KeyError(f"{parent_col!r} not found in adata.obs columns.")
    if child_col not in adata.obs.columns:
        raise KeyError(f"{child_col!r} not found in adata.obs columns.")

    parent_mask = adata.obs[parent_col].astype(str) == str(parent)
    if not parent_mask.any():
        raise ValueError(f"No cells found with {parent_col} == {parent!r}.")

    children_series = adata.obs.loc[parent_mask, child_col].astype(str)
    group_counts = children_series.value_counts(sort=False)
    if min_cells_per_group is not None:
        group_counts = group_counts[group_counts >= int(min_cells_per_group)]
    keep_groups = group_counts.index.astype(str).tolist()
    if not keep_groups:
        raise ValueError(
            f"After applying min_cells_per_group={min_cells_per_group}, "
            f"no {child_col} groups remain within parent {parent!r}."
        )
    # Row order = order of first appearance among the kept cells.
    group_order = list(pd.unique(children_series[children_series.isin(keep_groups)]))

    # -------------------- 2) load consensus markers at child level ---------------
    consensus_level = _ZMAP_LEVEL_LOOKUP.get(child_col)
    if consensus_level is None:
        raise KeyError(
            f"child_col={child_col!r} has no mapping in _ZMAP_LEVEL_LOOKUP; "
            f"add it to map to a canonical consensus level."
        )
    marker_df = load_consensus_markers(
        level=consensus_level,
        groups=group_order,  # restrict at load time
        marker_type=marker_type,
        n_per_group=n_markers_per_group,
        min_support_ratio=min_support_ratio,
        min_log2fc=min_log2fc,
        min_enrich=min_enrich,
        omit_unannotated=omit_unannotated,
        format="table",
    )
    if marker_df.empty:
        raise ValueError(
            f"No markers returned for level={consensus_level!r}, "
            f"marker_type={marker_type!r}, groups={group_order}."
        )

    # assign() returns a new frame, so the loader's table is never mutated and
    # later column writes do not act on a filtered view.
    marker_df = marker_df.assign(
        celltype=marker_df["celltype"].astype(str),
        gene=marker_df["gene"].astype(str),
    )
    marker_df = marker_df[marker_df["celltype"].isin(group_order)]
    if marker_df.empty:
        raise ValueError(
            "Marker table is empty after restricting to present groups. "
            "Check that your consensus tables and obs labels match."
        )

    # Intersect marker genes with adata.var_names *before* copying the AnnData.
    present_genes = pd.Index(marker_df["gene"].unique()).intersection(adata.var_names.astype(str))
    marker_df = marker_df[marker_df["gene"].isin(present_genes)]
    if marker_df.empty:
        raise ValueError("No marker genes overlap with adata.var_names for this subset.")

    # Stable sort keeps the loader's per-group marker ranking within each group.
    marker_df = (
        marker_df.assign(
            celltype=pd.Categorical(marker_df["celltype"], categories=group_order, ordered=True)
        )
        .sort_values(["celltype"], kind="stable")
        .reset_index(drop=True)
    )

    # -------------------- 3) single combined copy: cells AND genes ---------------
    cell_mask = parent_mask & adata.obs[child_col].astype(str).isin(keep_groups)
    adata_tmp = adata[cell_mask.to_numpy(), present_genes].copy()

    # -------------------- 4) optional row dendrogram on mean expression ----------
    if use_dendrogram_rows:
        genes_unique = pd.Index(marker_df["gene"].unique())
        read_raw = bool(use_raw) and adata_tmp.raw is not None
        # AnnData slicing does not subset `.raw` along the var axis, so raw
        # column positions must be looked up against raw.var_names, not the
        # (already gene-subset) adata_tmp.var_names.
        source_var_names = (
            adata_tmp.raw.var_names if read_raw else adata_tmp.var_names
        ).astype(str)
        gene_idx = source_var_names.get_indexer(genes_unique)
        valid_mask = gene_idx >= 0
        genes_unique = genes_unique[valid_mask]
        gene_idx = gene_idx[valid_mask]
        if genes_unique.empty:
            raise ValueError("No valid genes remain for dendrogram computation.")

        if read_raw:
            X = adata_tmp.raw.X[:, gene_idx]
        elif layer is not None:
            X = adata_tmp.layers[layer][:, gene_idx]
        else:
            X = adata_tmp.X[:, gene_idx]
        X_dense = X.toarray() if hasattr(X, "toarray") else np.asarray(X)

        child_labels = adata_tmp.obs[child_col].astype(str).to_numpy()
        mat = np.zeros((len(group_order), len(genes_unique)), dtype=float)
        for i, grp in enumerate(group_order):
            rows = child_labels == grp
            if rows.any():
                mat[i, :] = np.asarray(X_dense[rows, :].mean(axis=0)).ravel()

        group_order = _dendrogram_order_from_group_means(
            mat,
            group_order,
            metric=dendrogram_metric,
            method=dendrogram_method,
        )
        marker_df["celltype"] = marker_df["celltype"].cat.set_categories(group_order, ordered=True)
        marker_df = marker_df.sort_values(["celltype"], kind="stable").reset_index(drop=True)

    # -------------------- 5) auto figure size if needed -------------------------
    if figsize is None:
        n_groups = len(group_order)
        if duplicate_gene_columns:
            # columns can be repeated per (celltype, gene)
            n_cols = marker_df.drop_duplicates(["celltype", "gene"]).shape[0]
        else:
            # each gene appears only once as a column
            n_cols = marker_df["gene"].nunique()
        width = left_margin + width_per_gene * max(n_cols, 1) + right_margin
        height = top_margin + height_per_group * max(n_groups, 1) + bottom_margin
        figsize = (
            float(np.clip(width, min_figsize[0], max_figsize[0])),
            float(np.clip(height, min_figsize[1], max_figsize[1])),
        )

    # -------------------- 6) call the core dotplot function ---------------------
    # Layout defaults that callers may override through **dotplot_kwargs
    # (passing them previously raised "got multiple values for keyword").
    plot_kwargs = {
        "cbar_bbox_to_anchor": (1.04, 0.82),
        "size_legend_bbox_to_anchor": (0.0, 0.6),
        "ring_legend_bbox_to_anchor": (0.0, 0.02),
        **dotplot_kwargs,
    }
    ax, grid = plot_dotplot_basegrid(
        adata_tmp,
        design_df=marker_df.copy(),  # own copy, in case the plotter edits it in place
        groupby=child_col,  # rows = child level
        layer=layer,
        use_raw=use_raw,
        detect_threshold=detect_threshold,
        cmap=cmap,
        standard_scale=standard_scale,
        s_min=s_min,
        s_max=s_max,
        figsize=figsize,
        xlabel_rotation=xlabel_rotation,
        title=title,
        consensus_panel=consensus_panel,
        stacked_bar_show_legend=stacked_bar_show_legend,
        group_color_dict=None if group_color_dict is None else dict(group_color_dict),
        duplicate_gene_columns=duplicate_gene_columns,
        show_size_legend=show_size_legend,
        **plot_kwargs,
    )
    return ax, grid, marker_df