import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
# ---------------------------------------------------------------------
# Group orderings & colors
# ---------------------------------------------------------------------
# Lineage groups in canonical display order. Each key is a matplotlib color
# applied to the y tick labels (and row-color legend) of that group; each value
# lists the ZMAP_CellType labels in the group, top to bottom.
# NOTE(review): "r" is matplotlib shorthand for red, the same color as the
# "red" group — confirm the notochord group is meant to share it.
DEFAULT_ROW_COLOR_GROUPS_ZMAP_CellType = {
    "black": ["blastomeres"],
    "green": [
        'ectoderm_progenitor', 'fin_epidermis', 'basal_epidermis', 'ionocytes',
        'lens', 'lateral_line', 'olfactory', 'otic', 'periderm',
        'superficial_epidermis',
    ],
    "darkorange": [
        'endoderm_progenitor', 'pharyngeal_pouch', 'forerunner_cells',
        'intestinal_epithelium', 'hepatocyte', 'endocrine_pancreas',
        'exocrine_pancreas',
    ],
    "steelblue": [
        'neural_crest_progenitor', 'pharyngeal_arch_crest', 'iridophore',
        'melanocyte', 'xanthophore', 'schwann',
    ],
    "blue": [
        'neural_progenitor', 'telencephalon', 'diencephalon', 'mesencephalon',
        'rhombocephalon', 'spinal_cord_progenitor', 'differentiating_neurons',
    ],
    "darkblue": [
        'optic_cup', 'retinal_pigment_epithelium', 'retina', 'photoreceptors',
    ],
    "red": [
        'mesoderm_progenitor', 'tailbud', 'presomitic_meso', 'adaxial',
        'myotome', 'slow_muscle', 'fast_muscle', 'somites', 'dermatome',
        'dermis', 'fibroblasts', 'sclerotome', 'axial_skeleton', 'tendon',
    ],
    "crimson": [
        'lateral_plate_progenitor', 'pharyngeal_arch_meso', 'head_mesenchyme',
        'meninges', 'periocular_meso', 'fin_bud_mesenchyme',
        'fin_bud_chondrocytes', 'facial_skeleton', 'cephalic_muscle',
        'visceral_smooth_muscle', 'vascular_smooth_muscle',
        'cardiac_mesenchyme', 'cardiac_muscle', 'pronephros',
    ],
    "orangered": ['vasculature', 'vein', 'endocardial', 'aorta', 'lymphatic'],
    "firebrick": [
        'hematopoietic_progenitor', 'erythroid', 'myeloid_progenitor',
        'macrophage', 'neutrophil',
    ],
    "r": ['notochord', 'prechordal_plate', 'hatching_gland'],
    "purple": ["pgc_progenitor"],
}

# Default row order: the members of every group above, concatenated in group
# order. Derived so the order and the color groups cannot drift apart.
DEFAULT_CLUSTER_ORDER_ZMAP_CellType = [
    cell_type
    for members in DEFAULT_ROW_COLOR_GROUPS_ZMAP_CellType.values()
    for cell_type in members
]

# Row index where each group after the first begins: running totals of the
# group sizes, without the final grand total. Used for divider lines and bold
# group-header labels.
DEFAULT_ROW_DIVIDERS_ZMAP_CellType = [
    int(boundary)
    for boundary in np.cumsum(
        [len(members) for members in DEFAULT_ROW_COLOR_GROUPS_ZMAP_CellType.values()]
    )[:-1]
]

# Default left-to-right column order for the per-study dotplot.
DEFAULT_STUDY_ORDER = [
    "Kamimoto2023", "Farrell2018", "Kukreja2024", "Wagner2018",
    "Farnsworth2020", "Lange2023", "Sur2023", "Spanjaard2018",
]
# ---------------------------------------------------------------------
# Shared helpers
# ---------------------------------------------------------------------
def _select_matrix_and_var_names(adata, layer, use_raw):
"""Return (X, var_names) based on raw/layer/X selection."""
if use_raw and getattr(adata, "raw", None) is not None:
X = adata.raw.X
var_names = pd.Index(adata.raw.var_names).astype(str)
elif layer is not None:
X = adata.layers[layer]
var_names = pd.Index(adata.var_names).astype(str)
else:
X = adata.X
var_names = pd.Index(adata.var_names).astype(str)
return X, var_names
def _extract_gene_vector(X, var_names, gene: str) -> np.ndarray:
"""Return 1D expression vector for a given gene name from X."""
gene = str(gene)
if gene not in var_names:
raise ValueError("Gene %r not found in var_names." % gene)
j = int(np.where(var_names == gene)[0][0])
# Handle dense vs sparse
if hasattr(X, "getcol"): # scipy.sparse
col = X.getcol(j).toarray().ravel()
elif hasattr(X, "tocsc") or hasattr(X, "tocsr"):
col = np.asarray(X[:, j].toarray()).ravel()
else:
col = np.asarray(X[:, j]).ravel()
return col
def _extract_color_vector(
    adata,
    color: str,
    layer: str | None,
    use_raw: bool | None,
) -> tuple[np.ndarray, str]:
    """
    Resolve the ``color`` key to one value per cell.

    Lookup order:
      1. a gene in the matrix chosen by ``_select_matrix_and_var_names``
         (``adata.raw`` / ``adata.layers[layer]`` / ``adata.X``);
      2. a numeric-dtype column of ``adata.obs``, passed through
         ``pd.to_numeric(..., errors="coerce")``.

    Parameters
    ----------
    adata
        AnnData object.
    color
        Gene name or numeric ``adata.obs`` column; converted with ``str``.
    layer, use_raw
        Forwarded to ``_select_matrix_and_var_names``.

    Returns
    -------
    vec : np.ndarray
        1D array of per-cell values used for color and size.
    source : {"gene", "obs"}
        Which lookup succeeded.

    Raises
    ------
    ValueError
        If ``color`` is an ``adata.obs`` column with a non-numeric dtype, or
        matches neither a gene nor an ``adata.obs`` column.
    """
    from pandas.api.types import is_numeric_dtype

    key = str(color)

    # 1) gene lookup takes priority over obs columns of the same name
    matrix, gene_names = _select_matrix_and_var_names(adata, layer, use_raw)
    if key in gene_names:
        return _extract_gene_vector(matrix, gene_names, key), "gene"

    # 2) numeric obs column
    if key not in adata.obs.columns:
        raise ValueError(
            f"Color key '{key}' not found as gene (var_names/raw/layer) or numeric adata.obs column."
        )
    obs_values = adata.obs[key]
    if not is_numeric_dtype(obs_values):
        raise ValueError(
            f"Color key '{key}' found in adata.obs, but dtype is not numeric."
        )
    return pd.to_numeric(obs_values, errors="coerce").to_numpy(), "obs"
def _scale01(vec: np.ndarray) -> np.ndarray:
"""
In-place scale of finite entries in `vec` to [0, 1].
Non-finite values are left untouched. If the dynamic range is zero,
all finite entries are set to 0.0.
"""
mask = np.isfinite(vec)
if not np.any(mask):
return vec
v = vec[mask]
vmin_ = np.nanmin(v)
vmax_ = np.nanmax(v)
if (not np.isfinite(vmin_)) or (not np.isfinite(vmax_)) or vmax_ == vmin_:
vec[mask] = 0.0
else:
vec[mask] = (vec[mask] - vmin_) / (vmax_ - vmin_)
return vec
def _compute_sizes_from_fraction(
frac: np.ndarray,
s_min: float,
s_max: float,
size_zero_for_missing: bool,
) -> np.ndarray:
"""Map fraction-expressing values to pixel sizes."""
if size_zero_for_missing:
frac_plot = np.nan_to_num(frac, nan=0.0)
else:
frac_plot = frac
f = np.clip(frac_plot, 0.0, 1.0)
sizes = s_min + f * (s_max - s_min)
sizes[~np.isfinite(sizes)] = s_min
return sizes
def _compute_color_norm(vals_color, vmin, vmax):
    """
    Build the color normalization for the dot colors.

    Parameters
    ----------
    vals_color : array-like
        Per-bin color values. Non-finite entries are ignored when deriving
        limits.
    vmin, vmax : float or None
        Fixed color limits. A ``None`` limit is replaced by the 1st (``vmin``)
        or 99th (``vmax``) percentile of the finite values. With no finite
        values, 0.0 / 1.0 are used instead.

    Returns
    -------
    norm : matplotlib.colors.Normalize
        Normalization spanning the resolved limits.
    vmin, vmax : float
        The limits actually used.

    Notes
    -----
    If exactly one limit is fixed and the derived one lands on the wrong side
    of it, the derived limit is collapsed onto the fixed one. Example: the
    default ``vmin=0`` with all-negative values. Without this guard,
    ``Normalize`` raises ``ValueError`` (vmin > vmax) as soon as it is
    applied. Limits that are both supplied explicitly are passed through
    unchanged.
    """
    vals_color = np.asarray(vals_color, dtype=float)
    finite_vals = vals_color[np.isfinite(vals_color)]
    has_data = finite_vals.size > 0

    if vmin is not None:
        _vmin = vmin
    else:
        _vmin = np.nanpercentile(finite_vals, 1) if has_data else 0.0
    if vmax is not None:
        _vmax = vmax
    else:
        _vmax = np.nanpercentile(finite_vals, 99) if has_data else 1.0

    # Keep a fixed limit authoritative: pull the auto-derived one onto it
    # instead of producing an inverted (crashing) normalization.
    if _vmin > _vmax:
        if vmin is not None and vmax is None:
            _vmax = _vmin
        elif vmin is None and vmax is not None:
            _vmin = _vmax

    norm = mpl.colors.Normalize(vmin=_vmin, vmax=_vmax)
    return norm, _vmin, _vmax
def _style_axes_spines(ax, lw: float = 0.5):
    """
    Give ``ax`` a clean frame.

    Major tick marks get zero length (labels stay), the grid is switched off,
    and every spine is made visible with line width ``lw``.
    """
    ax.tick_params(length=0)
    ax.grid(False)
    spines = list(ax.spines.values())
    for spine in spines:
        spine.set_visible(True)
        spine.set_linewidth(lw)
def _bold_first_rows(ax, row_dividers):
    """
    Bold the y tick label that opens each row group.

    The group starts are row 0 plus every index in ``row_dividers``. Indices
    outside the current tick labels are ignored. Nothing is done when
    ``row_dividers`` is empty or None.
    """
    if not row_dividers:
        return
    labels = ax.get_yticklabels()
    n_labels = len(labels)
    for start in [0, *row_dividers]:
        if 0 <= start < n_labels:
            labels[start].set_fontweight("bold")
def _color_row_labels(ax, row_color_groups):
    """
    Color y tick labels by group membership.

    ``row_color_groups`` maps ``{color: [label, ...]}``. A label listed under
    several colors takes the one that appears last. Labels not listed keep
    their current color. Nothing is done when the mapping is empty or None.
    """
    if not row_color_groups:
        return
    color_of = {}
    for group_color, members in row_color_groups.items():
        for member in members:
            color_of[member] = group_color
    for tick_label in ax.get_yticklabels():
        text = tick_label.get_text()
        if text in color_of:
            tick_label.set_color(color_of[text])
def _draw_tick_grids(
    ax,
    n_x: int,
    n_y: int,
    tick_grid_x: bool,
    tick_grid_y: bool,
    tick_grid_lw: float,
    tick_grid_alpha: float,
):
    """
    Optionally draw faint light-grey guide lines behind the dots.

    A vertical line is drawn at each x in ``range(n_x)`` when ``tick_grid_x``
    is set, and a horizontal line at each y in ``range(n_y)`` when
    ``tick_grid_y`` is set. All lines use ``zorder=0``.
    """
    line_style = dict(
        color="0.85",
        lw=tick_grid_lw,
        alpha=tick_grid_alpha,
        zorder=0,
    )
    if tick_grid_x:
        for x_pos in range(n_x):
            ax.axvline(x_pos, **line_style)
    if tick_grid_y:
        for y_pos in range(n_y):
            ax.axhline(y_pos, **line_style)
def _draw_row_dividers(
    ax,
    C: int,
    row_dividers,
    row_divider_color: str,
    row_divider_lw: float,
    row_divider_alpha: float,
):
    """
    Draw a horizontal separator above each row index in ``row_dividers``.

    Rows sit at integer y positions, so the line for index ``i`` is placed at
    ``i - 0.5``, between rows ``i - 1`` and ``i``. Entries that cannot be
    converted with ``int`` are skipped, as are indices outside ``[0, C - 1]``.
    """
    if not row_dividers:
        return
    for raw_idx in row_dividers:
        # best-effort: silently skip anything that is not integer-like
        try:
            row = int(raw_idx)
        except Exception:
            continue
        if not 0 <= row < C:
            continue
        ax.axhline(
            row - 0.5,
            color=row_divider_color,
            lw=row_divider_lw,
            alpha=row_divider_alpha,
            zorder=1,
        )
def _add_vertical_colorbar(
    fig,
    norm,
    cmap_obj,
    *,
    left: float,
    bottom: float,
    width: float,
    height: float,
    title: str,
    tick_fontsize: int = 8,
    title_fontsize: int = 8,
):
    """
    Attach a vertical colorbar in a new axes at figure-fraction rectangle
    ``[left, bottom, width, height]``.

    The bar is driven by a ``ScalarMappable`` built from ``norm`` and
    ``cmap_obj``. ``title`` is set above it.

    Returns
    -------
    (cax, cbar)
        The new axes and the ``Colorbar`` object.
    """
    rect = [left, bottom, width, height]
    cax = fig.add_axes(rect)
    mappable = mpl.cm.ScalarMappable(norm=norm, cmap=cmap_obj)
    cbar = fig.colorbar(mappable, cax=cax, orientation="vertical")
    cbar_axes = cbar.ax
    cbar_axes.set_title(title, fontsize=title_fontsize, pad=6)
    cbar_axes.tick_params(labelsize=tick_fontsize)
    return cax, cbar
def _add_fraction_size_legend(
    fig,
    *,
    s_min: float,
    s_max: float,
    size_legend_values,
    size_legend_title: str,
    size_legend_facecolor: str,
    left: float,
    bottom: float,
    width: float,
    height: float,
    background: str,
):
    """
    Draw a key explaining the dot-size encoding of fraction expressing.

    The key goes in a new axes at figure-fraction rectangle
    ``[left, bottom, width, height]``, excluded from layout.
    ``size_legend_values`` holds fractions; they are clipped to [0, 1] and
    default to (0.1, 0.25, 0.5, 0.75, 1.0). Each fraction gets one marker of
    area ``s_min + frac * (s_max - s_min)`` and a percentage label. Entries
    are stacked 0.5 data units apart, with the largest fraction on top.
    A title sits above the entries; it defaults to a two-line
    "Fraction / Expressing" when ``size_legend_title`` is None.
    """
    if size_legend_values is None:
        size_legend_values = (0.1, 0.25, 0.5, 0.75, 1.0)
    fractions = np.clip(np.asarray(size_legend_values, dtype=float), 0.0, 1.0)
    areas = s_min + fractions * (s_max - s_min)

    # Indices from largest to smallest fraction, paired with y from top down.
    ranked = np.argsort(fractions)[::-1]
    row_y = (np.arange(len(ranked)) * 0.5)[::-1]
    has_entries = len(ranked) > 0

    leg_ax = fig.add_axes([left, bottom, width, height], facecolor=background)
    leg_ax.set_in_layout(False)

    for y, idx in zip(row_y, ranked):
        leg_ax.scatter(
            [0.2],
            [y],
            s=areas[idx],
            facecolor=size_legend_facecolor,
            edgecolor="none",
            linewidth=0,
            zorder=3,
            clip_on=False,
        )
        leg_ax.text(
            0.45,
            y,
            f"{fractions[idx] * 100:.0f}% ",
            va="center",
            ha="left",
            fontsize=8,
            clip_on=False,
        )

    leg_ax.set_xlim(0.0, 1.0)
    if has_entries:
        leg_ax.set_ylim(-0.5, row_y.max() + 0.7)
    else:
        leg_ax.set_ylim(0, 1)
    leg_ax.axis("off")

    title_text = (
        "Fraction\nExpressing" if size_legend_title is None else size_legend_title
    )
    title_y = (row_y.max() + 0.6) if has_entries else 0.6
    leg_ax.text(
        0.02,
        title_y,
        title_text,
        fontsize=8,
        va="bottom",
        ha="left",
        rotation=0,
        clip_on=False,
    )
def _add_row_color_legend(
    fig,
    cluster_order,
    row_color_groups,
    *,
    left: float,
    bottom: float,
    width: float,
    height: float,
    background: str,
):
    """
    Draw a small legend summarizing the row color groups.

    Each color in ``row_color_groups`` gets a short thick line swatch and a
    "<color> (<n>)" label. ``n`` is how many of the group's members appear
    in ``cluster_order``. Entries are sorted by descending count, with ties
    broken by color name. The legend goes in a new axes at figure-fraction
    rectangle ``[left, bottom, width, height]``. Nothing is drawn when
    ``row_color_groups`` is empty or None.
    """
    if not row_color_groups:
        return

    counts = {
        group_color: sum(name in cluster_order for name in members)
        for group_color, members in row_color_groups.items()
    }
    ranked = sorted(counts.items(), key=lambda item: (-item[1], str(item[0])))

    sw_ax = fig.add_axes([left, bottom, width, height], facecolor=background)
    top, step = 0.9, 0.18
    for rank, (group_color, count) in enumerate(ranked):
        y = top - rank * step
        sw_ax.plot(
            [0.05, 0.20],
            [y, y],
            color=group_color,
            lw=6,
            solid_capstyle="butt",
        )
        sw_ax.text(
            0.25,
            y,
            f"{group_color} ({count})",
            va="center",
            ha="left",
            fontsize=8,
            clip_on=False,
        )
    sw_ax.set_xlim(0, 1)
    sw_ax.set_ylim(0, 1)
    sw_ax.axis("off")
# ---------------------------------------------------------------------
# 1) Dotplot: color feature vs time
# ---------------------------------------------------------------------
def _remap_row_dividers(full_order, kept, row_dividers):
    """
    Re-express row-divider indices after absent rows are dropped.

    ``row_dividers`` holds indices into ``full_order``. An index ``d`` marks
    the boundary just above ``full_order[d]``. Once ``full_order`` is reduced
    to the entries contained in ``kept``, that boundary sits above the row
    whose new index equals the number of kept entries in ``full_order[:d]``.

    Parameters
    ----------
    full_order : sequence
        Row labels in the order the dividers refer to.
    kept : container
        Labels that survive filtering (tested with ``in``).
    row_dividers : iterable of int or None
        Divider indices into ``full_order``.

    Returns
    -------
    list of int or None
        Remapped indices in first-seen order, with duplicates removed.
        Several boundaries collapse onto one when a whole group is absent.
        Entries that are not integer-like, or that lie outside
        ``[0, len(full_order)]``, are dropped. ``None`` is returned unchanged.
    """
    if row_dividers is None:
        return None
    # kept_before[d] == number of kept labels among full_order[:d]
    kept_before = np.concatenate(
        ([0], np.cumsum([label in kept for label in full_order]))
    ).astype(np.int64)
    remapped = []
    for raw_idx in row_dividers:
        try:
            idx = int(raw_idx)
        except (TypeError, ValueError, OverflowError):
            continue
        if not 0 <= idx <= len(full_order):
            continue
        new_idx = int(kept_before[idx])
        if new_idx not in remapped:
            remapped.append(new_idx)
    return remapped


def gene_groups_vs_time(
    adata,
    color: str,
    *,
    groupby: str = "ZMAP_CellType",
    time_col: str = "time_block_id",
    layer: str | None = "tpm_log",
    use_raw: bool | None = None,
    detect_threshold: float = 0.0,
    show: bool = True,
    # ===== color (mean expression) =====
    cmap: str = "viridis",
    vmin: float | None = 0,
    vmax: float | None = None,
    standard_scale: str | None = None,  # None | "time" | "cluster"
    # ===== DOT SIZE (fraction expressing) =====
    s_min: float = 4.0,
    s_max: float = 60.0,
    size_zero_for_missing: bool = True,
    # ===== LOW-SUPPORT / MISSING =====
    abs_min_cells: int = 10,
    rel_min_frac: float = 0.01,
    rel_abs_cap: int = 300,
    low_support_grey: str = "0.7",
    low_support_alpha: float = 0.5,
    # ===== layout/labels =====
    base_col_width: float = 0.14,
    base_row_height: float = 0.14,
    gutter_width: float = 1.5,  # fixed right gutter (plot area is stable)
    gutter_pad: float = 0.03,  # extra right-side pad (in figure coords)
    figsize: tuple | None = None,
    xlabel_rotation: int = 90,
    title: str | None = None,
    add_colorbar: bool = True,
    cbar_title: str = "log(tpm)\ncounts",
    # ===== size legend =====
    show_size_legend: bool = True,
    size_legend_values: tuple | None = None,
    size_legend_title: str | None = None,
    size_legend_facecolor: str = "black",
    # ===== ordering & filtering =====
    cluster_order: list[str] | None = DEFAULT_CLUSTER_ORDER_ZMAP_CellType,
    time_order: list[str] | None = None,
    omit_groups: list[str] | None = ['nan', 'unknown'],
    # ===== cosmetics =====
    edgecolor: str = "black",
    edge_lw: float = 0.1,
    background: str = "white",
    # ===== faint tick/grid lines =====
    tick_grid_x: bool = False,
    tick_grid_y: bool = False,
    tick_grid_alpha: float = 0.15,
    tick_grid_lw: float = 0.6,
    # ===== Row color group mapping & legend =====
    row_color_groups: dict[str, list[str]] | None = DEFAULT_ROW_COLOR_GROUPS_ZMAP_CellType,
    show_row_color_legend: bool = False,
    # ===== Horizontal split lines between rows =====
    row_dividers: list[int] | None = DEFAULT_ROW_DIVIDERS_ZMAP_CellType,
    row_divider_color: str = "black",
    row_divider_lw: float = 0.5,
    row_divider_alpha: float = 1,
    # ===== performance/output controls =====
    return_long: bool = False,
):
    """
    Dotplot of a single gene across cell types and developmental timepoints.

    Renders a ``(cell type × timepoint)`` grid. Dot color encodes the mean
    value of ``color`` in each bin. Dot size encodes the fraction of the
    bin's cells whose value exceeds ``detect_threshold``. Bins with too few
    cells (see ``abs_min_cells``), or with no cells at all, are drawn as
    grey dots at minimum size.

    ``color`` is resolved first as a gene name in ``adata.var_names``
    (or ``adata.raw`` / ``adata.layers[layer]``). If no match is found, it
    falls back to a numeric column in ``adata.obs``.

    Parameters
    ----------
    adata : anndata.AnnData
        Reference dataset containing expression data and ``obs`` columns
        for ``groupby`` and ``time_col``.
    color : str
        Gene name or numeric ``obs`` column to visualize.
    groupby : str, default ``"ZMAP_CellType"``
        ``obs`` column whose values form the rows (compared as strings).
    time_col : str, default ``"time_block_id"``
        ``obs`` column containing developmental time group labels (columns).
    layer : str or None, default ``"tpm_log"``
        Layer to use for expression values.
    use_raw : bool or None, default ``None``
        If truthy and ``adata.raw`` exists, use ``adata.raw`` (takes
        precedence over ``layer``).
    detect_threshold : float, default ``0.0``
        A cell counts as "expressing" when its value is strictly greater.
    show : bool, default ``True``
        If ``False``, the figure is closed with ``plt.close`` before
        returning, so inline backends do not auto-display it. The returned
        Axes remains usable, e.g. ``ax.figure.savefig(...)``.
        ``plt.show()`` is not called in either case.
    cmap : str, default ``"viridis"``
        Colormap for mean expression.
    vmin, vmax : float or None
        Color scale limits. A ``None`` limit is taken from the 1st / 99th
        percentile of the finite plotted values.
    standard_scale : {None, "time", "cluster"}, default ``None``
        ``"time"``: each cluster row is min-max scaled to [0, 1] across
        timepoints. ``"cluster"``: each timepoint column is min-max scaled
        across clusters. Empty bins are left as NaN.
    s_min, s_max : float
        Dot area range in points squared. Fraction 0 maps to ``s_min`` and
        fraction 1 maps to ``s_max``.
    size_zero_for_missing : bool, default ``True``
        Treat NaN fractions as 0 when sizing dots.
    abs_min_cells, rel_min_frac, rel_abs_cap : int, float, int
        A bin is drawn normally only if its cell count is at least
        ``max(abs_min_cells, min(ceil(N_clust * rel_min_frac), rel_abs_cap))``.
        ``N_clust`` is the row's total cell count after filtering.
    low_support_grey, low_support_alpha : str, float
        Color and alpha for bins below that requirement (including empty bins).
    base_col_width, base_row_height : float
        Inches per column / row of the dot grid. The grid is at least
        1.5 inches in each direction. Ignored for the figure size when
        ``figsize`` is given.
    gutter_width : float, default ``1.5``
        Width (inches) of the empty right-hand gutter reserved for legends.
    gutter_pad : float, default ``0.03``
        Offset (figure fraction) of the legends from the right figure edge.
    figsize : tuple or None
        Explicit figure size; computed from the grid and gutter when ``None``.
    xlabel_rotation : int, default ``90``
        Rotation of the timepoint tick labels.
    title : str or None
        Axes title; defaults to ``color``.
    add_colorbar : bool, default ``True``
        Draw a colorbar for the mean-expression scale.
    cbar_title : str
        Colorbar title.
    show_size_legend : bool, default ``True``
        Draw a dot-size key.
    size_legend_values : tuple or None
        Fractions shown in the size key; defaults to
        ``(0.1, 0.25, 0.5, 0.75, 1.0)``.
    size_legend_title : str or None
        Size-key title; defaults to a two-line "Fraction / Expressing".
    size_legend_facecolor : str, default ``"black"``
        Marker color in the size key.
    cluster_order : list of str or None
        Explicit row ordering; clusters absent from the data are dropped.
        Defaults to the canonical ZMAP CellType order. When ``None``, the
        order follows the categorical categories of ``groupby``, or else the
        order of appearance.
    time_order : list of str or None
        Explicit column ordering. Inferred from data when ``None`` (categorical
        categories if available, else order of appearance).
    omit_groups : list of str or None
        Cell-type labels to exclude.
    edgecolor, edge_lw : str, float
        Dot outline color and width.
    background : str, default ``"white"``
        Figure and axes face color.
    tick_grid_x, tick_grid_y : bool
        Draw faint vertical / horizontal guide lines at each column / row.
    tick_grid_alpha, tick_grid_lw : float
        Alpha and line width of those guide lines.
    row_color_groups : dict or None
        ``{color: [group, ...]}`` mapping used to color row labels by lineage.
    show_row_color_legend : bool, default ``False``
        Draw a legend summarizing ``row_color_groups``.
    row_dividers : list of int or None
        Row indices at which to draw horizontal divider lines and bold the
        row label. Indices refer to rows of ``cluster_order`` as passed. When
        clusters absent from the data are dropped, the indices are shifted so
        the dividers stay on the same group boundaries.
    row_divider_color, row_divider_lw, row_divider_alpha : str, float, float
        Styling of the divider lines.
    return_long : bool, default ``False``
        Also return the underlying long-form DataFrame.

    Returns
    -------
    tuple of (matplotlib.axes.Axes, pd.DataFrame or None)
        ``(ax, long_df)``. ``long_df`` is ``None`` unless
        ``return_long=True``; then it has one row per (cluster, time) bin,
        with columns:

        - ``cluster``, ``time``
        - ``n`` (cells in the bin), ``k`` (expressing cells)
        - ``N_clust``, ``N_time`` (row / column totals)
        - ``mean_expr``, ``mean_expr_plot`` (after ``standard_scale``)
        - ``frac_expressing``, ``size_signal``, ``size_pixels``
        - ``is_low_support``

    Examples
    --------
    >>> ax, _ = zmap.dotplot.dotplot_gene.gene_groups_vs_time(adata_ref, "sox2")
    """
    color = str(color)

    # ------------------------ color vector ------------------------
    x, _color_source = _extract_color_vector(adata, color, layer, use_raw)

    # ------------------------ metadata as arrays (no full copy) --------------
    obs = adata.obs
    if groupby not in obs or time_col not in obs:
        raise ValueError(f"Missing columns in adata.obs: need '{groupby}' and '{time_col}'.")
    clusters_ser = obs[groupby].astype(str)
    times_ser = obs[time_col].astype(str)

    # Optional omit (labels are compared as strings after astype(str) above)
    if omit_groups:
        omit_set = set(map(str, omit_groups))
        keep_mask = ~clusters_ser.isin(omit_set)
        clusters_ser = clusters_ser[keep_mask]
        times_ser = times_ser[keep_mask]
        x = x[keep_mask.values]
    else:
        omit_set = set()

    # Determine orders (respect categorical if present)
    if cluster_order is None:
        if isinstance(obs[groupby].dtype, pd.CategoricalDtype):
            cluster_order = [
                c for c in obs[groupby].cat.categories.astype(str) if c not in omit_set
            ]
        else:
            cluster_order = list(pd.unique(clusters_ser))
    else:
        present = set(pd.unique(clusters_ser))
        requested_order = list(cluster_order)
        cluster_order = [c for c in requested_order if c in present]
        # row_dividers index rows of the *requested* order (e.g. the default
        # dividers match the full default order); shift them so they still
        # fall on group boundaries after absent clusters are dropped.
        row_dividers = _remap_row_dividers(requested_order, present, row_dividers)
    if time_order is None:
        if isinstance(obs[time_col].dtype, pd.CategoricalDtype):
            time_order = list(obs[time_col].cat.categories.astype(str))
        else:
            time_order = list(pd.unique(times_ser))
    C = len(cluster_order)
    T = len(time_order)

    # Map to categorical codes aligned to the chosen orders; labels outside
    # the orders get code -1 and are dropped.
    clust_cat = pd.Categorical(clusters_ser, categories=cluster_order, ordered=True)
    time_cat = pd.Categorical(times_ser, categories=time_order, ordered=True)
    c_codes = clust_cat.codes.astype(np.int64)
    t_codes = time_cat.codes.astype(np.int64)
    valid = (c_codes >= 0) & (t_codes >= 0)
    c_codes = c_codes[valid]
    t_codes = t_codes[valid]
    x = x[valid]

    # Linearize (cluster, time) -> flat bin index, row-major over (C, T)
    lin = (c_codes * T + t_codes).astype(np.int64)

    # ------------------------ fast aggregation via bincount ------------------
    n = np.bincount(lin, minlength=C * T).astype(np.int32)
    is_expr = (x > detect_threshold)
    k = np.bincount(lin, weights=is_expr.astype(np.int8), minlength=C * T).astype(np.int32)
    sum_expr = np.bincount(lin, weights=x, minlength=C * T)
    with np.errstate(divide="ignore", invalid="ignore"):
        mean_expr = np.where(n > 0, sum_expr / n, np.nan)

    # N_clust per cluster, N_time per time
    N_clust = np.bincount(c_codes, minlength=C).astype(np.int32)
    N_time = np.bincount(t_codes, minlength=T).astype(np.int32)

    # ------------------------ low-support eligibility ------------------------
    # Required cells per bin: max(abs_min_cells, min(ceil(N_clust*frac), cap))
    req = np.maximum(
        abs_min_cells,
        np.minimum(np.ceil(N_clust * float(rel_min_frac)).astype(int), int(rel_abs_cap)),
    ).astype(np.int32)
    req_grid = np.repeat(req[:, None], T, axis=1).ravel()
    eligible = (n >= req_grid)
    has_cells = (n > 0)

    # ------------------------ standard scaling (optional) --------------------
    vals_color = mean_expr.copy()
    if standard_scale == "time":
        # each cluster row scaled across timepoints
        vals_color = vals_color.reshape(C, T)
        for ci in range(C):
            vals_color[ci, :] = _scale01(vals_color[ci, :])
        vals_color = vals_color.ravel()
    elif standard_scale == "cluster":
        # each timepoint column scaled across clusters
        vals_color = vals_color.reshape(C, T)
        for ti in range(T):
            vals_color[:, ti] = _scale01(vals_color[:, ti])
        vals_color = vals_color.ravel()
    elif standard_scale is not None:
        raise ValueError("standard_scale must be None, 'time', or 'cluster'.")

    # ------------------------ color mapping ------------------------
    cmap_obj = plt.get_cmap(cmap)
    norm, _vmin, _vmax = _compute_color_norm(vals_color, vmin, vmax)
    colors = np.zeros((C * T, 4), dtype=float)
    valid_color = has_cells
    colors[valid_color] = cmap_obj(
        norm(np.nan_to_num(vals_color[valid_color], nan=_vmin))
    )

    # ------------------------ dot sizes (fraction expressing) ----------------
    with np.errstate(divide="ignore", invalid="ignore"):
        frac = np.where(n > 0, k / np.maximum(n, 1), np.nan)
    sizes = _compute_sizes_from_fraction(
        frac, s_min=s_min, s_max=s_max, size_zero_for_missing=size_zero_for_missing
    )

    # Treat low-support and missing identically: grey, minimum size
    low_or_missing = (~eligible) | (~has_cells)
    if np.any(low_or_missing):
        colors[low_or_missing] = mpl.colors.to_rgba(low_support_grey, alpha=low_support_alpha)
        sizes[low_or_missing] = s_min

    # ------------------------ positions ---------------------------
    xs = np.tile(np.arange(T, dtype=float), C)
    ys = np.repeat(np.arange(C, dtype=float), T)

    # ------------------------ stable-layout figure ---------------------------
    plot_w = max(1.5, base_col_width * max(1, T))
    plot_h = max(1.5, base_row_height * max(1, C))
    fig_w = plot_w + gutter_width + 1.0  # extra buffer for legends
    fig_h = plot_h
    fig = plt.figure(
        figsize=(figsize if figsize is not None else (fig_w, fig_h)),
        constrained_layout=False,
        facecolor=background,
    )
    gs = fig.add_gridspec(
        nrows=1,
        ncols=2,
        width_ratios=[plot_w, gutter_width],  # dot grid area fixed
        wspace=0.02,
    )
    ax = fig.add_subplot(gs[0, 0], facecolor=background)
    # Empty right-hand gutter: only reserves width. Legends and the colorbar
    # are added later at fixed figure coordinates.
    gutter_ax = fig.add_subplot(gs[0, 1], facecolor=background)
    gutter_ax.set_axis_off()
    gutter_ax.set_in_layout(False)

    # plot dots
    ax.scatter(xs, ys, s=sizes, c=colors, edgecolor=edgecolor, linewidth=edge_lw)

    # Exact limits so ticks align with dot centers; y inverted so row 0 is on top
    ax.set_xlim(-0.5, T - 0.5)
    ax.set_ylim(C - 0.5, -0.5)
    ax.set_xticks(range(T))
    ax.set_xticklabels(time_order, rotation=xlabel_rotation, ha="center", fontsize=9)
    ax.set_yticks(range(C))
    ax.set_yticklabels(cluster_order, fontsize=9)

    # label styling
    _bold_first_rows(ax, row_dividers)
    _color_row_labels(ax, row_color_groups)

    # optional faint grid
    _draw_tick_grids(
        ax,
        n_x=T,
        n_y=C,
        tick_grid_x=tick_grid_x,
        tick_grid_y=tick_grid_y,
        tick_grid_lw=tick_grid_lw,
        tick_grid_alpha=tick_grid_alpha,
    )

    # horizontal row dividers
    _draw_row_dividers(
        ax,
        C=C,
        row_dividers=row_dividers,
        row_divider_color=row_divider_color,
        row_divider_lw=row_divider_lw,
        row_divider_alpha=row_divider_alpha,
    )

    _style_axes_spines(ax)
    ax.set_title(color if title is None else title, fontsize=12)

    # ------------------------ colorbar (figure-anchored) ---------------------
    if add_colorbar:
        _add_vertical_colorbar(
            fig,
            norm,
            cmap_obj,
            left=1.0 - gutter_pad - 0.12,
            bottom=0.7,
            width=0.05,
            height=0.15,
            title=cbar_title,
        )

    # ------------------------ size legend (figure-anchored) ------------------
    if show_size_legend:
        _add_fraction_size_legend(
            fig,
            s_min=s_min,
            s_max=s_max,
            size_legend_values=size_legend_values,
            size_legend_title=size_legend_title,
            size_legend_facecolor=size_legend_facecolor,
            left=1.0 - gutter_pad - 0.18,
            bottom=0.55,
            width=0.22,
            height=0.10,
            background=background,
        )

    # ------------------------ row color legend (optional) --------------------
    if show_row_color_legend and row_color_groups:
        _add_row_color_legend(
            fig,
            cluster_order,
            row_color_groups,
            left=1.0 - gutter_pad - 0.26,
            bottom=0.08,
            width=0.22,
            height=0.18,
            background=background,
        )

    # ------------------------ long table output ------------------------------
    if return_long:
        long_out = pd.DataFrame({
            "cluster": np.repeat(np.array(cluster_order), T),
            "time": np.tile(np.array(time_order), C),
            "n": n.astype(np.int32),
            "k": k.astype(np.int32),
            "N_clust": np.repeat(N_clust, T).astype(np.int32),
            "N_time": np.tile(N_time, C).astype(np.int32),
            "mean_expr": mean_expr.astype(float),
            "mean_expr_plot": vals_color.astype(float),
            "frac_expressing": frac,
            "size_signal": np.clip(frac, 0.0, 1.0),
            "size_pixels": sizes.astype(float),
            "is_low_support": (~eligible).astype(bool),
        })
    else:
        long_out = None

    if not show:
        # Prevent auto-display by inline backends; the Axes is still returned.
        plt.close(fig)
    return ax, long_out
# ---------------------------------------------------------------------
# 2) Dotplot: color feature vs studies
# ---------------------------------------------------------------------
[docs]
def gene_groups_vs_studies(
adata,
color: str,
*,
groupby: str = "ZMAP_CellType", # cluster column in .obs
study_col: str = "study_id", # study column in .obs (x-axis)
layer: str | None = "tpm_log",
use_raw: bool | None = None,
detect_threshold: float = 0.0, # > threshold => “expressing”
show: bool = True,
# ===== color (mean expression) =====
cmap: str = "viridis",
vmin: float | None = 0,
vmax: float | None = None,
standard_scale: str | None = None, # None | "study" | "cluster"
# ===== DOT SIZE (fraction expressing) =====
s_min: float = 4.0,
s_max: float = 60.0,
size_zero_for_missing: bool = True,
# ===== LOW-SUPPORT / MISSING =====
abs_min_cells: int = 10,
rel_min_frac: float = 0.01,
rel_abs_cap: int = 300,
low_support_grey: str = "0.7",
low_support_alpha: float = 0.5,
# ===== layout/labels =====
base_col_width: float = 0.14, # inches per study column (plot area)
base_row_height: float = 0.14, # inches per cluster row (plot area)
gutter_width: float = 1.5, # inches, fixed-width right gutter
gutter_pad: float = 0.03, # extra right-side pad (in figure coords)
figsize: tuple | None = None, # if None, computed from grid + gutter
xlabel_rotation: int = 90,
title: str | None = None, # defaults to color name only
add_colorbar: bool = True,
cbar_title: str = "log(tpm)\ncounts",
# ===== size legend =====
show_size_legend: bool = True,
size_legend_values: tuple | None = None, # fractions to show (0–1)
size_legend_title: str | None = None,
size_legend_facecolor: str = "black",
# ===== ordering & filtering =====
cluster_order: list[str] | None = DEFAULT_CLUSTER_ORDER_ZMAP_CellType,
study_order: list[str] | None = DEFAULT_STUDY_ORDER,
omit_groups: list[str] | None = ['nan', 'unknown'], # omits clusters
# ===== cosmetics =====
edgecolor: str = "black",
edge_lw: float = 0.1,
background: str = "white",
# ===== faint tick/grid lines =====
tick_grid_x: bool = False,
tick_grid_y: bool = False,
tick_grid_alpha: float = 0.15,
tick_grid_lw: float = 0.6,
# ===== Row color group mapping & legend =====
row_color_groups: dict[str, list[str]] | None = DEFAULT_ROW_COLOR_GROUPS_ZMAP_CellType,
show_row_color_legend: bool = False,
# ===== Horizontal split lines between rows =====
row_dividers: list[int] | None = DEFAULT_ROW_DIVIDERS_ZMAP_CellType,
row_divider_color: str = "black",
row_divider_lw: float = 0.5,
row_divider_alpha: float = 1,
# ===== performance/output controls =====
return_long: bool = False,
):
"""
Dotplot of a single gene across cell types and studies.
Renders a ``(cell type × study)`` grid where dot color encodes mean
expression and dot size encodes the fraction of cells expressing the gene
above ``detect_threshold``. Useful for assessing cross-study
reproducibility of marker gene expression patterns.
``color`` is resolved first as a gene name in ``adata.var_names``; if no
match is found it falls back to a numeric column in ``adata.obs``.
Parameters
----------
adata : anndata.AnnData
Reference dataset containing expression data and ``obs`` columns
for ``groupby`` and ``study_col``.
color : str
Gene name or numeric ``obs`` column to visualize.
groupby : str, default ``"ZMAP_CellType"``
``obs`` column whose categories form the rows.
study_col : str, default ``"study_id"``
``obs`` column containing study identifiers (columns).
layer : str or None, default ``"tpm_log"``
Layer to use for expression values.
use_raw : bool or None, default ``None``
If ``True``, use ``adata.raw``.
detect_threshold : float, default ``0.0``
Minimum value for a cell to count as "expressing".
show : bool, default ``True``
Call ``plt.show()`` after rendering.
cmap : str, default ``"viridis"``
Colormap for mean expression.
vmin, vmax : float or None
Color scale limits.
standard_scale : str or None, default ``None``
Scale values per study (``"study"``) or per cluster (``"cluster"``).
s_min, s_max : float
Dot size range in points squared.
cluster_order : list of str or None
Explicit row ordering. Defaults to the canonical ZMAP CellType order.
study_order : list of str or None
Explicit column ordering. Defaults to the canonical ZMAP study order.
omit_groups : list of str or None
Cell-type labels to exclude.
row_color_groups : dict or None
``{color: [group, ...]}`` mapping to color row labels by lineage.
row_dividers : list of int or None
Row indices at which to draw horizontal divider lines.
return_long : bool, default ``False``
Also return the underlying long-form DataFrame.
Returns
-------
tuple of (matplotlib.axes.Axes, pd.DataFrame or None)
``(ax, long_df)`` where ``long_df`` is the aggregated data table
when ``return_long=True``, otherwise ``None``.
Examples
--------
>>> ax, _ = zmap.dotplot.dotplot_gene.gene_groups_vs_studies(adata_ref, "myod1")
"""
color = str(color)
# ------------------------ color vector ------------------------
x, color_source = _extract_color_vector(adata, color, layer, use_raw)
# ------------------------ metadata as arrays (no full copy) --------------
obs = adata.obs
if groupby not in obs or study_col not in obs:
raise ValueError(f"Missing columns in adata.obs: need '{groupby}' and '{study_col}'.")
clusters_ser = obs[groupby].astype(str)
studies_ser = obs[study_col].astype(str)
# Optional omit
if omit_groups:
omit_set = set(map(str, omit_groups))
keep_mask = ~clusters_ser.isin(omit_set)
clusters_ser = clusters_ser[keep_mask]
studies_ser = studies_ser[keep_mask]
x = x[keep_mask.values]
else:
omit_set = set()
# Determine orders (respect categorical if present)
if cluster_order is None:
if isinstance(obs[groupby].dtype, pd.CategoricalDtype):
cluster_order = [
c for c in obs[groupby].cat.categories.astype(str) if c not in omit_set
]
else:
cluster_order = list(pd.unique(clusters_ser))
else:
present = set(pd.unique(clusters_ser))
cluster_order = [c for c in cluster_order if c in present]
if study_order is None:
if isinstance(obs[study_col].dtype, pd.CategoricalDtype):
study_order = list(obs[study_col].cat.categories.astype(str))
else:
study_order = list(pd.unique(studies_ser))
C = len(cluster_order)
S = len(study_order)
# Map to categorical codes aligned to the chosen orders
clust_cat = pd.Categorical(clusters_ser, categories=cluster_order, ordered=True)
study_cat = pd.Categorical(studies_ser, categories=study_order, ordered=True)
c_codes = clust_cat.codes.astype(np.int64)
s_codes = study_cat.codes.astype(np.int64)
valid = (c_codes >= 0) & (s_codes >= 0)
c_codes = c_codes[valid]
s_codes = s_codes[valid]
x = x[valid]
# Linearize (cluster,study) → idx
lin = (c_codes * S + s_codes).astype(np.int64)
# ------------------------ fast aggregation via bincount ------------------
n = np.bincount(lin, minlength=C*S).astype(np.int32)
is_expr = (x > detect_threshold)
k = np.bincount(lin, weights=is_expr.astype(np.int8), minlength=C*S).astype(np.int32)
sum_expr = np.bincount(lin, weights=x, minlength=C*S)
with np.errstate(divide="ignore", invalid="ignore"):
mean_expr = np.where(n > 0, sum_expr / n, np.nan)
# N_clust per cluster, N_study per study
N_clust = np.bincount(c_codes, minlength=C).astype(np.int32)
N_study = np.bincount(s_codes, minlength=S).astype(np.int32)
# ------------------------ low-support eligibility ------------------------
req = np.maximum(
abs_min_cells,
np.minimum(np.ceil(N_clust * float(rel_min_frac)).astype(int), int(rel_abs_cap)),
).astype(np.int32)
req_grid = np.repeat(req[:, None], S, axis=1).ravel()
eligible = (n >= req_grid)
has_cells = (n > 0)
# ------------------------ standard scaling (optional) --------------------
vals_color = mean_expr.copy()
if standard_scale == "study":
vals_color = vals_color.reshape(C, S)
for ci in range(C):
vals_color[ci, :] = _scale01(vals_color[ci, :])
vals_color = vals_color.ravel()
elif standard_scale == "cluster":
vals_color = vals_color.reshape(C, S)
for si in range(S):
vals_color[:, si] = _scale01(vals_color[:, si])
vals_color = vals_color.ravel()
elif standard_scale is not None:
raise ValueError("standard_scale must be None, 'study', or 'cluster'.")
# ------------------------ color mapping ------------------------
cmap_obj = plt.get_cmap(cmap)
norm, _vmin, _vmax = _compute_color_norm(vals_color, vmin, vmax)
colors = np.zeros((C*S, 4), dtype=float)
valid_color = has_cells
colors[valid_color] = cmap_obj(
norm(np.nan_to_num(vals_color[valid_color], nan=_vmin))
)
# ------------------------ dot sizes (fraction expressing) ----------------
with np.errstate(divide="ignore", invalid="ignore"):
frac = np.where(n > 0, k / np.maximum(n, 1), np.nan)
sizes = _compute_sizes_from_fraction(
frac, s_min=s_min, s_max=s_max, size_zero_for_missing=size_zero_for_missing
)
# Treat low-support and missing identically
low_or_missing = (~eligible) | (~has_cells)
if np.any(low_or_missing):
colors[low_or_missing] = mpl.colors.to_rgba(low_support_grey, alpha=low_support_alpha)
sizes[low_or_missing] = s_min
# ------------------------ positions ---------------------------
xs = np.tile(np.arange(S, dtype=float), C)
ys = np.repeat(np.arange(C, dtype=float), S)
# ------------------------ STABLE LAYOUT FIGURE ---------------------------
plot_w = max(1.5, base_col_width * max(1, S))
plot_h = max(1.5, base_row_height * max(1, C))
fig_w = plot_w + gutter_width + 1.0
fig_h = plot_h
fig = plt.figure(
figsize=(figsize if figsize is not None else (fig_w, fig_h)),
constrained_layout=False,
facecolor=background,
)
gs = fig.add_gridspec(
nrows=1,
ncols=2,
width_ratios=[plot_w, gutter_width],
wspace=0.02,
)
ax = fig.add_subplot(gs[0, 0], facecolor=background)
# right gutter host (empty)
gutter_ax = fig.add_subplot(gs[0, 1], facecolor=background)
gutter_ax.set_axis_off()
gutter_ax.set_in_layout(False)
# plot dots
ax.scatter(xs, ys, s=sizes, c=colors, edgecolor=edgecolor, linewidth=edge_lw)
# Exact limits so ticks align with dot centers
ax.set_xlim(-0.5, S - 0.5)
ax.set_ylim(C - 0.5, -0.5)
# ticks & labels
ax.set_xticks(range(S))
ax.set_xticklabels(study_order, rotation=xlabel_rotation, ha="center", fontsize=9)
ax.set_yticks(range(C))
ax.set_yticklabels(cluster_order, fontsize=9)
# label styling
_bold_first_rows(ax, row_dividers)
_color_row_labels(ax, row_color_groups)
# Optional faint tick/grid
_draw_tick_grids(
ax,
n_x=S,
n_y=C,
tick_grid_x=tick_grid_x,
tick_grid_y=tick_grid_y,
tick_grid_lw=tick_grid_lw,
tick_grid_alpha=tick_grid_alpha,
)
# Horizontal row dividers
_draw_row_dividers(
ax,
C=C,
row_dividers=row_dividers,
row_divider_color=row_divider_color,
row_divider_lw=row_divider_lw,
row_divider_alpha=row_divider_alpha,
)
_style_axes_spines(ax)
ax.set_title(color if title is None else title, fontsize=12)
# ------------------------ colorbar (figure-anchored) ---------------------
if add_colorbar:
left = 1.0 - gutter_pad - 0.12
bottom = 0.70
width = 0.05
height = 0.15
_add_vertical_colorbar(
fig,
norm,
cmap_obj,
left=left,
bottom=bottom,
width=width,
height=height,
title=cbar_title,
)
# ------------------------ size legend (figure-anchored) ------------------
if show_size_legend:
_add_fraction_size_legend(
fig,
s_min=s_min,
s_max=s_max,
size_legend_values=size_legend_values,
size_legend_title=size_legend_title,
size_legend_facecolor=size_legend_facecolor,
left=1.0 - gutter_pad - 0.18,
bottom=0.55,
width=0.22,
height=0.10,
background=background,
)
# ------------------------ row color legend (optional) --------------------
if show_row_color_legend and row_color_groups:
_add_row_color_legend(
fig,
cluster_order,
row_color_groups,
left=1.0 - gutter_pad - 0.26,
bottom=0.08,
width=0.22,
height=0.18,
background=background,
)
# ------------------------ long table output ------------------------------
if return_long:
clust_rep = np.repeat(np.array(cluster_order), S)
study_rep = np.tile(np.array(study_order), C)
long_out = pd.DataFrame({
"cluster": clust_rep,
"study": study_rep,
"n": n.astype(np.int32),
"k": k.astype(np.int32),
"N_clust": np.repeat(N_clust, S).astype(np.int32),
"N_study": np.tile(N_study, C).astype(np.int32),
"mean_expr": mean_expr.astype(float),
"mean_expr_plot": vals_color.astype(float),
"frac_expressing": np.where(n > 0, k / np.maximum(n, 1), np.nan),
"size_signal": np.clip(
np.where(n > 0, k / np.maximum(n, 1), np.nan), 0.0, 1.0
),
"size_pixels": sizes.astype(float),
"is_low_support": (~eligible).astype(bool),
})
else:
long_out = None
if not show:
plt.close(ax.figure)
return ax, long_out
# ---------------------------------------------------------------------
# 3) Combined dotplot: time and studies in one figure
# ---------------------------------------------------------------------
[docs]
def gene_groups_vs_time_and_studies(
    adata,
    color: str,
    *,
    # shared metadata
    groupby: str = "ZMAP_CellType",
    time_col: str = "time_group_id",
    study_col: str = "study_id",
    layer: str | None = "tpm_log",
    use_raw: bool | None = None,
    detect_threshold: float = 0.0,
    show: bool = True,
    # ===== COLOR (mean expression) shared across both panels =====
    cmap: str = "viridis",
    vmin: float | None = 0,
    vmax: float | None = None,
    cbar_title: str = "log(tpm)\ncounts",
    add_colorbar: bool = True,
    # ===== DOT SIZE (fraction expressing) shared mapping =====
    s_min: float = 4.0,
    s_max: float = 60.0,
    size_zero_for_missing: bool = True,
    # ===== LOW-SUPPORT / MISSING =====
    abs_min_cells: int = 10,
    rel_min_frac: float = 0.01,
    rel_abs_cap: int = 300,
    low_support_grey: str = "0.7",
    low_support_alpha: float = 0.5,
    # ===== layout/labels (SQUARE cells) =====
    cell_size: float = 0.14,  # inches per data cell in BOTH x and y (for both panels)
    middle_gap: float = 0.2,  # inches between panels (blank spacer)
    gutter_width: float = 0.8,  # inches for right gutter (legends live here)
    gutter_pad: float = 0.03,  # right-side pad in figure coords (for cbar/legends)
    figsize: tuple | None = None,
    xlabel_rotation: int = 90,
    title_left: str | None = "Timepoints",
    title_right: str | None = "Studies",
    # ===== size legend (shared) =====
    show_size_legend: bool = True,
    size_legend_values: tuple | None = None,
    size_legend_title: str | None = None,
    size_legend_facecolor: str = "black",
    # ===== ordering & filtering =====
    cluster_order: list[str] | None = DEFAULT_CLUSTER_ORDER_ZMAP_CellType,
    time_order: list[str] | None = None,
    study_order: list[str] | None = DEFAULT_STUDY_ORDER,
    omit_groups: list[str] | None = ['nan', 'unknown'],  # read-only; never mutated
    # ===== cosmetics =====
    edgecolor: str = "black",
    edge_lw: float = 0.1,
    background: str = "white",
    # ===== faint tick/grid lines =====
    tick_grid_x_left: bool = False,
    tick_grid_x_right: bool = False,
    tick_grid_y: bool = False,
    tick_grid_alpha: float = 0.15,
    tick_grid_lw: float = 0.6,
    # ===== Row color group mapping & legend =====
    row_color_groups: dict[str, list[str]] | None = DEFAULT_ROW_COLOR_GROUPS_ZMAP_CellType,
    show_row_color_legend: bool = False,
    # ===== Horizontal split lines between rows =====
    row_dividers: list[int] | None = DEFAULT_ROW_DIVIDERS_ZMAP_CellType,
    row_divider_color: str = "black",
    row_divider_lw: float = 0.5,
    row_divider_alpha: float = 1.0,
    # ===== output controls =====
    return_long: bool = False,
):
    """
    Two-panel dotplot of one gene (or obs feature) across cell types, split by
    developmental timepoint and by study.

    The left panel shows ``(cell type × timepoint)`` and the right panel shows
    ``(cell type × study)``. Both panels share one color scale (mean value per
    bin) and one size scale (fraction of cells whose value exceeds
    ``detect_threshold``), so temporal dynamics and cross-study
    reproducibility can be compared directly.

    ``color`` is resolved by ``_extract_color_vector``; per the package
    convention it is looked up first in ``adata.var_names`` and otherwise
    taken from a numeric column of ``adata.obs``.

    Parameters
    ----------
    adata : anndata.AnnData
        Reference dataset. Must contain ``groupby``, ``time_col``, and
        ``study_col`` in ``adata.obs``.
    color : str
        Gene name or numeric ``obs`` column to visualize. Gene lookup is
        case-sensitive and must match ``adata.var_names`` exactly. Also drawn
        as a rotated label next to the colorbar.
    groupby : str, default ``"ZMAP_CellType"``
        ``obs`` column whose labels form the rows of both panels.
    time_col : str, default ``"time_group_id"``
        ``obs`` column with developmental time-group labels (e.g. ``"6hpf"``).
        Forms the columns of the left panel.
    study_col : str, default ``"study_id"``
        ``obs`` column with dataset/study identifiers. Forms the columns of
        the right panel.
    layer : str or None, default ``"tpm_log"``
        Expression layer, forwarded to ``_extract_color_vector`` (``adata.X``
        is used when ``None``).
    use_raw : bool or None, default ``None``
        Forwarded to ``_extract_color_vector``; ``True`` reads ``adata.raw``.
    detect_threshold : float, default ``0.0``
        A cell counts as "expressing" when its value is strictly greater
        than this threshold.
    show : bool, default ``True``
        When ``False`` the figure is closed before returning (so notebooks do
        not auto-display it). ``plt.show()`` is never called here.
    cmap : str, default ``"viridis"``
        Matplotlib colormap for the dot fill color (mean value).
    vmin, vmax : float or None, default ``0`` and ``None``
        Color-scale bounds shared by both panels. A ``None`` bound is inferred
        as the 1st / 99th percentile of all finite bin means. If an inferred
        bound falls on the wrong side of a user-supplied one, it is clamped to
        it (the scale collapses instead of raising).
    cbar_title : str, default ``"log(tpm)\\ncounts"``
        Title of the colorbar.
    add_colorbar : bool, default ``True``
        Draw the shared colorbar (and the rotated ``color`` label).
    s_min, s_max : float, default ``4.0`` and ``60.0``
        Minimum and maximum marker sizes (points²) for the fraction scale.
    size_zero_for_missing : bool, default ``True``
        Forwarded to ``_compute_sizes_from_fraction``. Note that empty and
        low-support bins are afterwards always drawn at ``s_min``.
    abs_min_cells : int, default ``10``
        Absolute minimum number of cells a bin needs to be drawn in color.
    rel_min_frac : float, default ``0.01``
        Relative minimum: a bin also needs ``ceil(rel_min_frac * N_clust)``
        cells, where ``N_clust`` is the number of cells of that cell type in
        the panel. The effective threshold is
        ``max(abs_min_cells, min(ceil(rel_min_frac * N_clust), rel_abs_cap))``.
    rel_abs_cap : int, default ``300``
        Upper cap on the relative requirement above.
    low_support_grey, low_support_alpha : str, float, default ``"0.7"``, ``0.5``
        Color and alpha of empty / low-support bins.
    cell_size : float, default ``0.14``
        Size of each data cell in x and y, in inches. Determines the figure
        size when ``figsize`` is ``None``.
    middle_gap : float, default ``0.2``
        Gap between the left (time) and right (study) panels, in inches.
    gutter_width : float, default ``0.8``
        Width of the right gutter hosting colorbar and legends, in inches.
    gutter_pad : float, default ``0.03``
        Right-side padding, in figure coordinates, for colorbar/legends.
    figsize : tuple or None, default ``None``
        Explicit ``(width, height)`` in inches; overrides ``cell_size``.
    xlabel_rotation : int, default ``90``
        Rotation of the column tick labels of both panels.
    title_left, title_right : str or None
        Titles above the time and study panels; falsy values are skipped.
    show_size_legend : bool, default ``True``
        Draw the shared fraction-expressing size legend.
    size_legend_values, size_legend_title, size_legend_facecolor
        Forwarded to ``_add_fraction_size_legend``.
    cluster_order : list of str or None
        Row order. Labels not present in the data are dropped, and cells whose
        label is not listed are excluded. When ``None``, the categorical order
        of ``groupby`` (or order of first appearance) is used.
    time_order : list of str or None, default ``None``
        Column order of the left panel. Inferred when ``None``. Explicit
        entries absent from the data are kept and drawn as empty (grey) bins.
    study_order : list of str or None
        Column order of the right panel, with the same semantics as
        ``time_order``. Defaults to the package study order.
    omit_groups : list of str or None, default ``['nan', 'unknown']``
        Cell-type labels excluded from the rows.
    edgecolor, edge_lw, background
        Marker edge color, marker edge width, and figure/axes face color.
    tick_grid_x_left, tick_grid_x_right, tick_grid_y : bool
        Faint vertical grid lines per panel and shared horizontal grid lines.
    tick_grid_alpha, tick_grid_lw : float
        Style of the faint grid lines.
    row_color_groups : dict or None
        ``{color: [group1, group2, ...]}`` mapping used to color the row
        labels by lineage.
    show_row_color_legend : bool, default ``False``
        Draw a legend for ``row_color_groups`` in the gutter.
    row_dividers : list of int or None
        Row indices at which horizontal divider lines are drawn.
    row_divider_color, row_divider_lw, row_divider_alpha
        Style of the row divider lines.
    return_long : bool, default ``False``
        If ``True``, build the long-form tables behind both panels.

    Returns
    -------
    fig : matplotlib.figure.Figure
        The rendered figure.
    axes : tuple of matplotlib.axes.Axes
        ``(ax_left, ax_right)`` — the time panel and the study panel.
    long_tables : tuple
        ``(long_time, long_study)`` DataFrames when ``return_long=True``,
        otherwise ``(None, None)``. Columns: ``panel, cluster, time|study, n,
        k, mean_expr, mean_expr_plot, frac_expressing, size_pixels,
        is_low_support, N_clust, N_time|N_study``.

    Examples
    --------
    >>> fig, (ax_t, ax_s), _ = gene_groups_vs_time_and_studies(adata_ref, "sox2")
    >>> fig, axes, (long_t, long_s) = gene_groups_vs_time_and_studies(
    ...     adata_ref, "myod1", cmap="Reds", vmax=5, return_long=True)
    """
    color = str(color)

    # ------------------------ color vector ------------------------
    x_full, _color_source = _extract_color_vector(adata, color, layer, use_raw)
    obs = adata.obs
    if groupby not in obs or time_col not in obs or study_col not in obs:
        raise ValueError(
            f"Missing columns in adata.obs: need '{groupby}', '{time_col}', and '{study_col}'."
        )

    clusters_ser = obs[groupby].astype(str)
    times_ser = obs[time_col].astype(str)
    studies_ser = obs[study_col].astype(str)

    # Optional omit: drop unwanted cell-type labels before anything else.
    omit_set = set(map(str, omit_groups)) if omit_groups else set()
    if omit_set:
        keep_mask = ~clusters_ser.isin(omit_set)
        clusters_ser = clusters_ser[keep_mask]
        times_ser = times_ser[keep_mask]
        studies_ser = studies_ser[keep_mask]
        # Positional mask; assumes x_full is row-aligned with adata.obs.
        x_full = x_full[keep_mask.values]

    def _default_order(col, ser):
        """Categorical order of obs[col] if categorical, else order of first appearance in ser."""
        if isinstance(obs[col].dtype, pd.CategoricalDtype):
            return list(obs[col].cat.categories.astype(str))
        return list(pd.unique(ser))

    # Row order: explicit orders are restricted to labels present in the data.
    if cluster_order is None:
        cluster_order = [
            c for c in _default_order(groupby, clusters_ser) if c not in omit_set
        ]
    else:
        present = set(pd.unique(clusters_ser))
        cluster_order = [c for c in cluster_order if c in present]
    # Column orders: explicit orders are kept verbatim (absent columns render grey).
    if time_order is None:
        time_order = _default_order(time_col, times_ser)
    if study_order is None:
        study_order = _default_order(study_col, studies_ser)
    C, T, S = len(cluster_order), len(time_order), len(study_order)

    # ------------------------ integer codes (-1 = not in order) --------------
    def _codes(ser, order):
        """Integer codes of ser w.r.t. order; labels outside order become -1."""
        return pd.Categorical(ser, categories=order, ordered=True).codes.astype(np.int64)

    # The cluster codes are identical for both panels, so compute them once.
    c_codes = _codes(clusters_ser, cluster_order)
    t_codes = _codes(times_ser, time_order)
    s_codes = _codes(studies_ser, study_order)
    valid_t = (c_codes >= 0) & (t_codes >= 0)
    valid_s = (c_codes >= 0) & (s_codes >= 0)

    # -------- aggregation helper (local on purpose) --------
    # Closure over detect_threshold and the low-support parameters. Returns flat
    # arrays indexed by cluster * Xn + column.
    def _aggregate(c_codes, x_codes, Xvec, C, Xn):
        lin = (c_codes * Xn + x_codes).astype(np.int64)
        n = np.bincount(lin, minlength=C * Xn).astype(np.int32)
        is_expr = (Xvec > detect_threshold)
        k = np.bincount(lin, weights=is_expr.astype(np.int8), minlength=C * Xn).astype(np.int32)
        sum_expr = np.bincount(lin, weights=Xvec, minlength=C * Xn)
        with np.errstate(divide="ignore", invalid="ignore"):
            mean_expr = np.where(n > 0, sum_expr / n, np.nan)
        N_clust = np.bincount(c_codes, minlength=C).astype(np.int32)
        N_x = np.bincount(x_codes, minlength=Xn).astype(np.int32)
        # Per-row requirement: max(abs floor, min(relative share of row, cap)).
        req = np.maximum(
            abs_min_cells,
            np.minimum(np.ceil(N_clust * float(rel_min_frac)).astype(int), int(rel_abs_cap)),
        ).astype(np.int32)
        req_grid = np.repeat(req[:, None], Xn, axis=1).ravel()
        eligible = (n >= req_grid)
        has_cells = (n > 0)
        with np.errstate(divide="ignore", invalid="ignore"):
            frac = np.where(n > 0, k / np.maximum(n, 1), np.nan)
        return n, k, mean_expr, frac, eligible, has_cells, N_clust, N_x

    n_t, k_t, mean_t, frac_t, elig_t, has_t, Nclust_t, Ntime = _aggregate(
        c_codes[valid_t], t_codes[valid_t], x_full[valid_t], C, T,
    )
    n_s, k_s, mean_s, frac_s, elig_s, has_s, Nclust_s, Nstudy = _aggregate(
        c_codes[valid_s], s_codes[valid_s], x_full[valid_s], C, S,
    )

    # ------------------------ shared color scale (absolute values) -----------
    cmap_obj = plt.get_cmap(cmap)
    finite_all = np.concatenate([
        mean_t[np.isfinite(mean_t)],
        mean_s[np.isfinite(mean_s)],
    ])
    if vmin is not None:
        _vmin = vmin
    else:
        _vmin = np.nanpercentile(finite_all, 1) if finite_all.size else 0.0
    if vmax is not None:
        _vmax = vmax
    else:
        _vmax = np.nanpercentile(finite_all, 99) if finite_all.size else 1.0
    # An inferred bound can land on the wrong side of a user-supplied one
    # (e.g. default vmin=0 with an all-negative obs feature). Normalize raises
    # on vmin > vmax, so clamp the inferred bound instead of crashing.
    if vmax is None and _vmax < _vmin:
        _vmax = _vmin
    elif vmin is None and _vmin > _vmax:
        _vmin = _vmax
    norm = mpl.colors.Normalize(vmin=_vmin, vmax=_vmax)

    def _style_dots(mean, frac, elig, has, n_bins):
        """RGBA fill colors and marker sizes for one panel's flattened grid."""
        rgba = np.zeros((n_bins, 4), dtype=float)
        if np.any(has):
            rgba[has] = cmap_obj(norm(np.nan_to_num(mean[has], nan=_vmin)))
        sizes = _compute_sizes_from_fraction(
            frac, s_min=s_min, s_max=s_max, size_zero_for_missing=size_zero_for_missing
        )
        # Low-support and empty bins are drawn identically: small grey dots.
        low_or_missing = (~elig) | (~has)
        if np.any(low_or_missing):
            rgba[low_or_missing] = mpl.colors.to_rgba(
                low_support_grey, alpha=low_support_alpha
            )
            sizes[low_or_missing] = s_min
        return rgba, sizes

    colors_t, sizes_t = _style_dots(mean_t, frac_t, elig_t, has_t, C * T)
    colors_s, sizes_s = _style_dots(mean_s, frac_s, elig_s, has_s, C * S)

    # ===================== LAYOUT IN CELL UNITS =====================
    gap_cells = max(0.0, middle_gap / max(cell_size, 1e-9))
    gutter_cells = max(0.0, gutter_width / max(cell_size, 1e-9))
    if figsize is None:
        fig_w = (T + gap_cells + S + gutter_cells) * cell_size
        fig_h = C * cell_size
    else:
        fig_w, fig_h = figsize
    fig = plt.figure(
        figsize=(fig_w, fig_h),
        constrained_layout=False,
        facecolor=background,
    )
    # Width ratios in cell units keep data cells square in both panels.
    gs = fig.add_gridspec(
        nrows=1,
        ncols=4,
        width_ratios=[T, gap_cells, S, gutter_cells],
        wspace=0.0,
    )
    ax_left = fig.add_subplot(gs[0, 0], facecolor=background)
    spacer = fig.add_subplot(gs[0, 1], facecolor=background)
    spacer.set_axis_off()
    ax_right = fig.add_subplot(gs[0, 2], facecolor=background)
    gutter_ax = fig.add_subplot(gs[0, 3], facecolor=background)
    gutter_ax.set_axis_off()
    gutter_ax.set_in_layout(False)

    def _draw_panel(ax, n_cols, col_labels, rgba, sizes, tick_grid_x, label_rows):
        """Scatter one panel and apply limits, ticks, row styling, grids, dividers, spines."""
        xs = np.tile(np.arange(n_cols, dtype=float), C)
        ys = np.repeat(np.arange(C, dtype=float), n_cols)
        ax.scatter(
            xs, ys, s=sizes, c=rgba, marker="o",
            edgecolor=edgecolor, linewidth=edge_lw,
        )
        # Exact limits so ticks sit on dot centers; y is inverted (row 0 on top).
        ax.set_xlim(-0.5, n_cols - 0.5)
        ax.set_ylim(C - 0.5, -0.5)
        ax.set_xticks(range(n_cols))
        ax.set_xticklabels(col_labels, rotation=xlabel_rotation, ha="center", fontsize=9)
        ax.set_yticks(range(C))
        if label_rows:
            ax.set_yticklabels(cluster_order, fontsize=9)
            _bold_first_rows(ax, row_dividers)
            _color_row_labels(ax, row_color_groups)
        else:
            ax.set_yticklabels([""] * C)
        _draw_tick_grids(
            ax,
            n_x=n_cols,
            n_y=C,
            tick_grid_x=tick_grid_x,
            tick_grid_y=tick_grid_y,
            tick_grid_lw=tick_grid_lw,
            tick_grid_alpha=tick_grid_alpha,
        )
        _draw_row_dividers(
            ax,
            C=C,
            row_dividers=row_dividers,
            row_divider_color=row_divider_color,
            row_divider_lw=row_divider_lw,
            row_divider_alpha=row_divider_alpha,
        )
        _style_axes_spines(ax)

    _draw_panel(ax_left, T, time_order, colors_t, sizes_t, tick_grid_x_left, label_rows=True)
    _draw_panel(ax_right, S, study_order, colors_s, sizes_s, tick_grid_x_right, label_rows=False)

    # Titles over each panel (keep small)
    if title_left:
        ax_left.set_title(f"{title_left}", fontsize=8)
    if title_right:
        ax_right.set_title(f"{title_right}", fontsize=8)

    # ------------------------ SHARED colorbar in gutter + rotated label -----
    if add_colorbar:
        left_cb = 1.0 - gutter_pad - 0.12
        bottom_cb = 0.70
        width_cb = 0.02
        height_cb = 0.15
        _add_vertical_colorbar(
            fig,
            norm,
            cmap_obj,
            left=left_cb,
            bottom=bottom_cb,
            width=width_cb,
            height=height_cb,
            title=cbar_title,
        )
        # Rotated feature name: centered vertically, to the LEFT of the colorbar.
        gene_pad = 0.025  # horizontal spacing from cbar, figure coords
        fig.text(
            left_cb - gene_pad,
            bottom_cb + height_cb / 2.0,
            color,
            rotation=90,
            rotation_mode="anchor",
            va="center",
            ha="center",
            fontsize=10,
        )

    # ------------------------ SHARED size legend in gutter -------------------
    if show_size_legend:
        _add_fraction_size_legend(
            fig,
            s_min=s_min,
            s_max=s_max,
            size_legend_values=size_legend_values,
            size_legend_title=size_legend_title,
            size_legend_facecolor=size_legend_facecolor,
            left=1.0 - gutter_pad - 0.14,
            bottom=0.55,
            width=0.12,
            height=0.10,
            background=background,
        )

    # ------------------------ optional row-color legend (gutter) -------------
    if show_row_color_legend and row_color_groups:
        _add_row_color_legend(
            fig,
            cluster_order,
            row_color_groups,
            left=1.0 - gutter_pad - 0.26,
            bottom=0.08,
            width=0.22,
            height=0.18,
            background=background,
        )

    # ------------------------ optional long-form tables ----------------------
    def _long_table(panel, col_name, col_labels, n_cols, n, k, mean, frac, sizes,
                    elig, n_clust, n_col_totals):
        """Long-form DataFrame of one panel, one row per (cluster, column) bin."""
        return pd.DataFrame({
            "panel": panel,
            "cluster": np.repeat(np.array(cluster_order), n_cols),
            col_name: np.tile(np.array(col_labels), C),
            "n": n.astype(np.int32),
            "k": k.astype(np.int32),
            "mean_expr": mean.astype(float),
            # No standardization in this plot: the plotted value is the raw mean.
            "mean_expr_plot": mean.astype(float),
            "frac_expressing": frac,
            "size_pixels": sizes.astype(float),
            "is_low_support": (~elig).astype(bool),
            # Totals appended last so pre-existing column positions are unchanged.
            "N_clust": np.repeat(n_clust, n_cols).astype(np.int32),
            f"N_{col_name}": np.tile(n_col_totals, C).astype(np.int32),
        })

    long_time = None
    long_study = None
    if return_long:
        long_time = _long_table(
            "time", "time", time_order, T,
            n_t, k_t, mean_t, frac_t, sizes_t, elig_t, Nclust_t, Ntime,
        )
        long_study = _long_table(
            "study", "study", study_order, S,
            n_s, k_s, mean_s, frac_s, sizes_s, elig_s, Nclust_s, Nstudy,
        )

    if not show:
        plt.close(fig)
    return fig, (ax_left, ax_right), (long_time, long_study)