"""Dotplot wrapper around ``scanpy.pl.dotplot`` with per-gene max-normalized color.
Usage
-----
>>> import sceleto as scl
>>> scl.dotplot(adata, ['CD3D', 'CD8A'], 'leiden')
For marker outputs, prefer the convenience method:
>>> mk = scl.markers.simple(adata, 'leiden')
>>> mk.plot()
"""
from __future__ import annotations
import warnings
from collections.abc import Mapping, Sequence
from typing import Optional, Tuple, Union
import anndata
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import scanpy as sc
from scipy import sparse
_LAYER_NAME = "_scl_scaled"
# ── helpers ─────────────────────────────────────────────────────────
def _resolve_var_names(var_names, available: set):
"""Return ``(var_group_dict_or_None, flat_gene_list)``.
- If *var_names* is a mapping, keep the mapping structure so scanpy
renders bracket-grouped x-axis labels. Tuple entries ``(gene, score)``
are accepted.
- Groups with no valid genes are silently dropped.
- Genes absent from *available* are dropped.
- For mappings, the flat list is built round-robin (one gene per cluster
per round) so every cluster with unique genes gets represented even when
markers are shared across clusters.
"""
if isinstance(var_names, Mapping):
clean: dict = {}
for k, items in var_names.items():
names = []
for it in items:
g = it[0] if isinstance(it, tuple) else str(it)
if g in available:
names.append(g)
if names: # groups with no valid genes are silently dropped
clean[k] = names
# flat list preserves duplicates so the same gene can appear in
# multiple bracket groups (scanpy DotPlot handles this fine).
flat: list = []
for gs in clean.values():
flat.extend(gs)
return clean, flat
flat = [g for g in var_names if g in available]
return None, flat
def _check_log1p_normalized(X, label: str = "adata.X"):
"""Check if *X* looks like log1p-normalized data.
- Negative values → ``ValueError`` (definitive failure).
- Max > 30 → ``UserWarning`` only (heuristic; proceed anyway).
"""
x_sub = X[: min(500, X.shape[0])]
min_val = float(np.asarray(x_sub.min() if sparse.issparse(x_sub) else x_sub.min()))
if min_val < 0:
raise ValueError(
f"{label} has negative values — looks like scaled data."
)
max_val = float(np.asarray(x_sub.max() if sparse.issparse(x_sub) else x_sub.max()))
if max_val > 30:
warnings.warn(
f"{label} max = {max_val:.1f}; may not be log1p-normalized.",
UserWarning,
stacklevel=3,
)
def _add_scaled_layer(adata, groupby: str, layer_name: str = _LAYER_NAME):
"""Attach a per-gene max-normalized layer to *adata* (in-place).
For each gene g, ``gene_max[g] = max over groups of (group mean of X[:, g])``.
The layer stores ``X[:, g] / gene_max[g]``. Because mean is linear, the
per-group mean of the layer equals ``group_mean / gene_max`` — i.e. the
``x / max`` normalization used in ``sceleto.markers``.
"""
X = adata.X
labels = adata.obs[groupby].astype(str).to_numpy()
groups_u = np.unique(labels)
gene_max = np.zeros(adata.n_vars, dtype=np.float64)
for g in groups_u:
mask = labels == g
if not mask.any():
continue
mean_g = np.asarray(X[mask].mean(axis=0)).ravel()
gene_max = np.maximum(gene_max, mean_g)
gene_max[gene_max == 0] = 1.0
inv = 1.0 / gene_max
if sparse.issparse(X):
adata.layers[layer_name] = X @ sparse.diags(inv)
else:
adata.layers[layer_name] = np.asarray(X) * inv[np.newaxis, :]
# ── main API ────────────────────────────────────────────────────────
[docs]
def dotplot(
adata,
var_names: Union[Sequence[str], Mapping[str, Sequence]],
groupby: str,
*,
groups: Optional[Sequence[str]] = None,
swap_axes: bool = False,
use_raw: bool = True,
dendrogram: bool = False,
cmap: str = "OrRd",
figsize: Optional[Tuple[float, float]] = None,
save: Optional[str] = None,
show: bool = True,
**kwargs,
):
"""Dotplot with per-gene max-normalized color, built on ``scanpy.pl.dotplot``.
Size encodes fraction of cells expressing the gene (scanpy default).
Color encodes ``group_mean(gene) / max_group(group_mean(gene))`` per gene,
so ``vmax=1`` always corresponds to the highest-expressing group.
Follows scanpy's default axis orientation: genes on x-axis, groups on y-axis.
Pass ``swap_axes=True`` to put genes on y-axis, groups on x-axis.
Parameters
----------
adata
AnnData with log1p-normalized expression.
var_names
Gene list or ``{bracket_name: [gene, ...]}`` / ``{bracket_name: [(gene, score), ...]}``
mapping. Mappings render as bracket-grouped labels via scanpy.
groupby
Column in ``adata.obs`` to group cells by.
groups
Subset of groups to display. ``None`` shows all.
swap_axes
If ``True``, genes on y-axis, groups on x-axis (swaps scanpy default).
use_raw
If ``True`` (default), read from ``adata.raw.X``.
If ``False``, read from ``adata.X``.
Both sources are checked for log1p normalization.
cmap
Matplotlib colormap for color scale (default ``OrRd``).
figsize
Manual ``(width, height)`` in inches.
save
Path to save figure (PDF, dpi=300).
show
Whether to call ``plt.show()``.
**kwargs
Forwarded to ``scanpy.pl.dotplot``.
"""
# ── select expression source ─────────────────────────────────────
if use_raw:
if adata.raw is None:
raise ValueError("use_raw=True but adata.raw is None.")
src_var_names = list(adata.raw.var_names)
else:
src_var_names = list(adata.var_names)
available = set(src_var_names)
var_group_dict, flat_genes = _resolve_var_names(var_names, available)
if not flat_genes:
raise ValueError("sceleto.dotplot: none of the provided genes are in var_names.")
# unique genes for building the intermediate AnnData (var_names must be unique)
unique_genes = list(dict.fromkeys(flat_genes))
# ── filter cells ─────────────────────────────────────────────────
if groups is not None:
cell_mask = adata.obs[groupby].astype(str).isin([str(g) for g in groups]).values
adata_c = adata[cell_mask]
else:
adata_c = adata
# ── build working AnnData ─────────────────────────────────────────
if use_raw:
gene_idx = np.array([src_var_names.index(g) for g in unique_genes])
X_work = adata_c.raw.X[:, gene_idx]
_check_log1p_normalized(X_work, "adata.raw.X")
X_copy = X_work.copy() if sparse.issparse(X_work) else np.asarray(X_work)
ad = anndata.AnnData(
X=X_copy,
obs=adata_c.obs[[groupby]].copy(),
var=pd.DataFrame(index=pd.Index(unique_genes)),
)
else:
ad = adata_c[:, unique_genes].copy()
_check_log1p_normalized(ad.X, "adata.X")
# ── per-gene max-normalized layer ─────────────────────────────────
_add_scaled_layer(ad, groupby, layer_name=_LAYER_NAME)
# dict → bracket-grouped x-axis via scanpy; else flat list
sc_var = var_group_dict if var_group_dict is not None else flat_genes
# Block kwargs that conflict with sceleto's normalization logic
_BLOCKED = {
"layer", "standard_scale", # normalization
"vmin", "vmax", "vcenter", "norm", # color range
"var_group_positions", "var_group_labels", # bracket structure
"dot_color_df", "dot_size_df", # bypass sceleto logic entirely
}
bad = _BLOCKED & set(kwargs)
if bad:
raise ValueError(f"sceleto.dotplot: {sorted(bad)} cannot be set.")
# Split kwargs: .style() params must not go to the constructor
_STYLE_KEYS = {
"color_on", "dot_max", "dot_min", "smallest_dot", "largest_dot",
"size_exponent", "grid", "x_padding", "y_padding",
}
style_kwargs = {k: v for k, v in kwargs.items() if k in _STYLE_KEYS}
dp_kwargs = {k: v for k, v in kwargs.items() if k not in _STYLE_KEYS}
# Use DotPlot class API directly: module-level sc.pl.dotplot does not
# expose dot_edge_* in scanpy 1.12; those live on DotPlot.style().
# ── compact figsize ───────────────────────────────────────────────
# Passing figsize directly sets min_figure_height = figsize[1], causing
# legend to scale with plot size. Instead let scanpy auto-calculate by
# overriding per-cell size on the instance (read in make_figure()).
dp = sc.pl.DotPlot(
ad,
sc_var,
groupby,
use_raw=False,
layer=_LAYER_NAME,
vmin=0,
vmax=1,
figsize=figsize,
**dp_kwargs,
)
if dendrogram:
dendro_key = f"dendrogram_{groupby}"
if dendro_key not in adata.uns:
sc.tl.dendrogram(adata, groupby)
ad.uns[dendro_key] = adata.uns[dendro_key]
dp.add_dendrogram()
if figsize is None:
dp.DEFAULT_CATEGORY_HEIGHT = 0.27
dp.DEFAULT_CATEGORY_WIDTH = 0.29
style_kwargs.setdefault("x_padding", 0.6)
style_kwargs.setdefault("y_padding", 0.6)
dp.style(cmap=cmap, dot_edge_color="none", dot_edge_lw=0, **style_kwargs)
dp.legend(colorbar_title="Max-scaled\nmean")
if swap_axes:
dp.swap_axes()
dp.make_figure()
if save:
dp.fig.savefig(save, bbox_inches="tight", format="pdf", dpi=300)
if show:
plt.show()