from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, List, Sequence, Tuple, Optional
import numpy as np
import pandas as pd
from ._gene_filter import GeneFilter
@dataclass(frozen=True)
class BatchExpression:
    """Immutable bundle of per-batch expression statistics for one resolution level.

    The first axis of every array follows ``groups`` and the second follows
    ``batches``; the last axis of ``mean`` / ``frac_expr`` follows ``genes``.
    """

    # Mean expression, shape (n_groups, n_batches, n_genes)
    mean: np.ndarray
    # Fraction of cells with a stored value, shape (n_groups, n_batches, n_genes)
    frac_expr: np.ndarray
    # Number of cells per group/batch, shape (n_groups, n_batches)
    n_cells: np.ndarray
    # Group (cluster) names
    groups: List[str]
    # Batch names
    batches: List[str]
    # Gene names
    genes: np.ndarray
    # Group name -> row index along the first array axis
    group_to_idx: Dict[str, int]
[docs]
@dataclass
class HierarchyRun:
    """Artifacts of a cross-resolution hierarchy run plus plotting helpers.

    Instances are produced by :func:`hierarchy`. The plotting methods compare
    the top-N marker genes of the clusters that make up one icls path.
    """

    # --- Inputs / meta ---
    # Groupby keys of the resolution levels; also the column names used to
    # look up the per-level leiden IDs in ``icls_path_df``.
    levels: List[str]
    # Run parameters; the methods below read ``params["n_top_markers"]``.
    params: Dict[str, Any]
    # --- Key artifacts ---
    # icls key -> full path string
    icls_full_dict: Dict[str, str]
    # One row per icls, with an "icls" column and one column per level.
    icls_path_df: pd.DataFrame
    marker_rank_df: pd.DataFrame
    # Full (untruncated) ranked gene lists per leiden ID
    full_gene_lists: Dict[str, List[str]]
    # Expression contexts per resolution (groupby -> MarkerContext)
    contexts: Dict[str, Any]
    # Per-batch expression data (groupby -> BatchExpression); None if no batch_key
    batch_expression: Optional[Dict[str, BatchExpression]]
    # Batch key used (None if not provided)
    batch_key: Optional[str]

    def interactive_viewer(
        self,
        adata,
        mgr,
        *,
        save: str = "interactive_viewer.html",
        n_top: Optional[int] = None,
    ) -> None:
        """Generate an interactive HTML viewer with edge-activation panel.

        Layout: icls UMAP (left) + marker comparison heatmap (top-right) +
        per-gene edge-activation graph (bottom-right). In batch mode the
        heatmap shows per-batch expression strips instead of presence.

        Parameters
        ----------
        adata
            AnnData with ``obs['icls']`` (set by hierarchy) and
            ``obsm['X_umap']``.
        mgr
            :class:`sceleto.markers.graph.MarkerGraphRun` driving the
            bottom-right edge-activation graph. Typically::

                mgr = scl.markers.marker(adata, "icls")
        save
            Output HTML file path.
        n_top
            Number of top markers per cluster. Defaults to the value
            used in the hierarchy run.
        """
        if n_top is None:
            n_top = self.params["n_top_markers"]

        # The viewer backends are imported lazily so that only the one that
        # is actually used gets loaded.
        if self.batch_expression is not None:
            from ._viewer import build_interactive_html_batch

            build_interactive_html_batch(
                adata=adata,
                icls_full_dict=self.icls_full_dict,
                full_gene_lists=self.full_gene_lists,
                batch_expression=self.batch_expression,
                n_top=n_top,
                save=save,
                mgr=mgr,
            )
        else:
            from ._branching_viewer import build_branching_html

            build_branching_html(
                adata=adata,
                hr=self,
                save=save,
                mgr=mgr,
                n_top=n_top,
            )

    def _leiden_ids_for(self, icls: str) -> List[str]:
        """Return the leiden ID of *icls* at every level (raises KeyError if unknown)."""
        return self.icls_path_df.set_index("icls").loc[icls, self.levels].tolist()

    def _presence_table(self, leiden_list: List[str], gene_filter):
        """Return ``(sets, presence)`` for the top-N markers of *leiden_list*.

        ``sets`` holds one gene set per leiden ID. ``presence`` is a 0/1
        DataFrame (genes x leiden IDs) whose rows are sorted by presence
        pattern; its index is the gene order used for display.
        """
        n = self.params["n_top_markers"]
        sets = _build_gene_sets(leiden_list, self.full_gene_lists, n, gene_filter)
        union = sorted(set().union(*sets))
        presence = pd.DataFrame(
            {
                lid: [1 if g in s else 0 for g in union]
                for lid, s in zip(leiden_list, sets)
            },
            index=union,
        ).sort_values(leiden_list, ascending=False)
        return sets, presence

    def compare_markers(
        self,
        icls: str,
        *,
        figsize=None,
        gene_filter: Optional[GeneFilter] = None,
        return_genes: bool = False,
    ):
        """Visualize top-N marker overlap across levels for a given icls.

        Parameters
        ----------
        icls
            Key of the path, looked up in the ``"icls"`` column of
            ``icls_path_df``.
        figsize
            Figure size; defaults to ``(0.4 * n_genes, 2)``.
        gene_filter
            Optional filter applied to each ranked gene list before the
            top-N cut.
        return_genes
            If True, skip plotting and return the genes in the order in
            which the heatmap would show them (same convention as
            :meth:`compare_markers_batch`).

        Returns
        -------
        The list of genes if ``return_genes`` is True, otherwise the
        matplotlib figure. The figure is closed in pyplot before it is
        returned.

        Raises
        ------
        KeyError
            If *icls* is not present in ``icls_path_df``.
        ValueError
            If there is no marker gene to plot.
        """
        leiden_list = self._leiden_ids_for(icls)
        _, presence = self._presence_table(leiden_list, gene_filter)
        genes = presence.index.tolist()
        if return_genes:
            # No plotting libraries are needed on this path.
            return genes
        if not genes:
            raise ValueError(f"No marker genes to plot for path {icls}.")

        import matplotlib.pyplot as plt
        import seaborn as sns
        from matplotlib.patches import Patch

        if figsize is None:
            figsize = (len(genes) * 0.4, 2)
        fig, ax = plt.subplots(figsize=figsize)
        cmap = plt.get_cmap("Blues")
        sns.heatmap(
            presence.T, cmap=cmap, linewidths=0.5, linecolor="black",
            cbar=False, xticklabels=True, ax=ax,
            # Pin the colour range: without it an all-ones table (every
            # level shares the same top-N set) is drawn with cmap(0) and
            # contradicts the legend below.
            vmin=0, vmax=1,
        )
        ax.set_title(f"Marker genes for path {icls}")
        ax.set_xlabel("")
        legend_handles = [
            Patch(facecolor=cmap(1.0), edgecolor="black", label="in top-N markers"),
            Patch(facecolor=cmap(0.0), edgecolor="black", label="not in top-N markers"),
        ]
        ax.legend(
            handles=legend_handles, loc="upper left",
            bbox_to_anchor=(1.0, 1.0), fontsize=7, frameon=False,
        )
        plt.close(fig)
        return fig

    def compare_markers_batch(
        self,
        icls: str,
        *,
        figsize=None,
        gene_filter: Optional[GeneFilter] = None,
        return_genes: bool = False,
    ):
        """Visualize top-N marker overlap with per-batch expression strips.

        Each strip in a (level, gene) cell encodes one batch:

        - grey  : batch has no cells in this cluster (no data)
        - white : batch has cells but mean expression is 0
        - red   : colored by mean / cell_max, where cell_max is the
          maximum batch mean within that (cluster, gene) cell

        Color scale is per-cell normalized (0–1), so each cell's brightest
        batch is always 1. This makes batch consistency visible regardless
        of absolute expression level. Cells whose gene is not among the
        level's top-N markers are left empty.

        Parameters
        ----------
        icls
            Key of the path, looked up in the ``"icls"`` column of
            ``icls_path_df``.
        figsize
            Figure size; defaults to ``(0.6 * n_genes, 0.8 * n_levels + 0.5)``.
        gene_filter
            Optional filter applied to each ranked gene list before the
            top-N cut.
        return_genes
            If True, skip plotting and return the genes in display order.

        Returns
        -------
        The list of genes if ``return_genes`` is True, otherwise the
        matplotlib figure. Like :meth:`compare_markers`, the figure is closed
        in pyplot before it is returned so repeated calls do not accumulate
        open figures.

        Raises
        ------
        ValueError
            If the run has no batch expression data, or if there is no
            marker gene to plot.
        KeyError
            If *icls* is not present in ``icls_path_df``.
        """
        # Validate before importing the plotting stack so the error is cheap.
        if self.batch_expression is None:
            raise ValueError(
                "No batch expression data. Re-run hierarchy() with batch_key."
            )

        leiden_list = self._leiden_ids_for(icls)
        # Genes are sorted by presence pattern.
        sets, presence = self._presence_table(leiden_list, gene_filter)
        union = presence.index.tolist()
        if return_genes:
            return union
        if not union:
            raise ValueError(f"No marker genes to plot for path {icls}.")

        import matplotlib.pyplot as plt
        import matplotlib.colors as mcolors
        from matplotlib.patches import Rectangle, Patch

        n_rows = len(leiden_list)
        n_cols = len(union)
        batch_data = _collect_batch_values(leiden_list, union, self.batch_expression)
        n_batches = len(next(iter(self.batch_expression.values())).batches)

        if figsize is None:
            figsize = (n_cols * 0.6, n_rows * 0.8 + 0.5)
        fig, ax = plt.subplots(figsize=figsize)
        cmap = plt.cm.Reds
        grey_color = "#cccccc"
        cell_w, cell_h = 1.0, 1.0
        strip_w = cell_w / n_batches

        for i in range(n_rows):
            # The first level is drawn in the top row.
            y = n_rows - 1 - i
            for j, gene in enumerate(union):
                if gene in sets[i]:
                    means_sorted, active_sorted = batch_data[i][j]
                    for b in range(n_batches):
                        if not active_sorted[b]:
                            facecolor = grey_color
                        elif means_sorted[b] == 0:
                            facecolor = "white"
                        else:
                            facecolor = cmap(means_sorted[b])
                        ax.add_patch(Rectangle(
                            (j * cell_w + b * strip_w, y * cell_h),
                            strip_w, cell_h,
                            facecolor=facecolor, edgecolor="none",
                        ))
                # Cell outline, drawn whether or not the cell has strips.
                ax.add_patch(Rectangle(
                    (j * cell_w, y * cell_h), cell_w, cell_h,
                    facecolor="none", edgecolor="black", linewidth=0.5,
                ))

        fs = 8
        ax.set_xlim(0, n_cols * cell_w)
        ax.set_ylim(0, n_rows * cell_h)
        ax.set_xticks([j * cell_w + cell_w / 2 for j in range(n_cols)])
        ax.set_xticklabels(union, rotation=90, ha="center", fontsize=fs)
        ax.set_yticks([i * cell_h + cell_h / 2 for i in range(n_rows)])
        ax.set_yticklabels(leiden_list[::-1], fontsize=fs)
        ax.set_title(f"Marker genes for path {icls} (per-batch)", fontsize=fs)

        sm = plt.cm.ScalarMappable(cmap=cmap, norm=mcolors.Normalize(vmin=0, vmax=1))
        sm.set_array([])
        cbar = fig.colorbar(sm, ax=ax, fraction=0.02, pad=0.02)
        cbar.set_label("Mean expression\n(per-cell max = 1)", fontsize=fs)
        cbar.ax.tick_params(labelsize=fs)

        legend_handles = [
            Patch(facecolor=grey_color, edgecolor="black", label="not in cluster"),
        ]
        ax.legend(
            handles=legend_handles, loc="upper left",
            bbox_to_anchor=(1.0, -0.05), fontsize=fs, frameon=False,
        )
        fig.tight_layout()
        plt.close(fig)
        return fig
# ---------------------------------------------------------------------------
# Helper functions
# ---------------------------------------------------------------------------
def _build_gene_sets(
    leiden_list: List[str],
    full_gene_lists: Dict[str, List[str]],
    n: int,
    gene_filter: Optional[GeneFilter],
) -> List[set]:
    """Return, for each leiden ID, the set of its first *n* ranked genes.

    When *gene_filter* is given, each ranked list is passed through
    ``gene_filter.filter`` before the first *n* entries are taken.
    """

    def _ranked(leiden_id):
        ranked = full_gene_lists[leiden_id]
        return ranked if gene_filter is None else gene_filter.filter(ranked)

    return [set(_ranked(leiden_id)[:n]) for leiden_id in leiden_list]
def _collect_batch_values(
    leiden_list: List[str],
    union: List[str],
    batch_expression: Dict[str, BatchExpression],
) -> List[List[Tuple[np.ndarray, np.ndarray]]]:
    """Gather per-batch means for every (cluster, gene) cell, scaled per cell.

    Each leiden ID is split at the first ``"@"`` into a level key (used to
    index *batch_expression*) and a group name. For every cell:

    - the raw batch means are divided by the largest mean among batches
      that have cells in the cluster (the divisor is 1 when there is no
      such batch or that maximum is not positive);
    - batches are reordered so that batches with cells come first, then
      by descending scaled value.

    Genes missing from a level's gene list keep all-zero values.

    Returns
    -------
    ``result[i][j]`` is a ``(values, has_cells)`` pair of arrays of length
    n_batches for ``leiden_list[i]`` and ``union[j]``, both in the
    reordered batch order.
    """
    n_levels, n_genes = len(leiden_list), len(union)
    # The batch count is taken from the first level in the mapping.
    first_level = next(iter(batch_expression.values()))
    n_batches = len(first_level.batches)

    raw = np.zeros((n_levels, n_genes, n_batches), dtype=np.float32)
    has_cells = np.zeros((n_levels, n_batches), dtype=bool)

    for row_idx, leiden_id in enumerate(leiden_list):
        level_key, cluster = leiden_id.split("@", 1)
        level = batch_expression[level_key]
        cluster_idx = level.group_to_idx[cluster]
        has_cells[row_idx] = level.n_cells[cluster_idx] > 0
        gene_pos = {name: int(pos) for pos, name in enumerate(level.genes)}
        for col_idx, gene in enumerate(union):
            pos = gene_pos.get(gene)
            if pos is not None:
                raw[row_idx, col_idx] = level.mean[cluster_idx, :, pos]

    result: List[List[Tuple[np.ndarray, np.ndarray]]] = []
    for row_idx in range(n_levels):
        row_active = has_cells[row_idx]
        row_inactive = ~row_active
        any_active = row_active.any()
        cells: List[Tuple[np.ndarray, np.ndarray]] = []
        for col_idx in range(n_genes):
            values = raw[row_idx, col_idx].copy()
            # Per-cell max normalization over batches that have cells.
            scale = 1.0
            if any_active:
                peak = values[row_active].max()
                if peak > 0:
                    scale = float(peak)
            values /= scale
            # lexsort uses the last key as primary: active batches first,
            # then descending scaled value.
            order = np.lexsort((-values, row_inactive))
            cells.append((values[order], row_active[order]))
        result.append(cells)
    return result
def _compute_batch_expression(adata, ctx, batch_key):
    """Compute per-(group, batch) expression statistics for one resolution level.

    Parameters
    ----------
    adata
        Object whose ``obs`` holds the ``ctx.groupby`` and *batch_key*
        columns; the expression matrix is obtained via ``resolve_expression``.
    ctx
        Marker context providing ``groups``, ``group_to_idx``, ``genes``
        and ``groupby``.
    batch_key
        Name of the ``adata.obs`` column holding batch labels.

    Returns
    -------
    BatchExpression
        Mean expression, fraction of stored entries per gene and cell
        counts for every group/batch combination. Combinations without
        cells keep zeros. Batches are ordered by sorted string label.
    """
    from scipy import sparse
    from sceleto._expr import resolve_expression

    expr, _, _ = resolve_expression(adata)
    # Work on CSR so that boolean row selection is supported and cheap.
    expr = expr.tocsr() if sparse.issparse(expr) else sparse.csr_matrix(expr)

    group_names = ctx.groups
    group_index = ctx.group_to_idx
    gene_names = ctx.genes

    batch_labels = adata.obs[batch_key].astype(str)
    batches = sorted(batch_labels.unique().tolist())

    stats_shape = (len(group_names), len(batches), len(gene_names))
    mean = np.zeros(stats_shape, dtype=np.float32)
    frac_expr = np.zeros(stats_shape, dtype=np.float32)
    n_cells_arr = np.zeros(stats_shape[:2], dtype=np.int32)

    cell_groups = adata.obs[ctx.groupby].astype(str).to_numpy()
    cell_batches = batch_labels.to_numpy()
    # One boolean mask per batch, shared by all groups.
    batch_masks = [cell_batches == name for name in batches]

    for group_name, row in group_index.items():
        in_group = cell_groups == group_name
        for col, in_batch in enumerate(batch_masks):
            selected = in_group & in_batch
            count = int(selected.sum())
            n_cells_arr[row, col] = count
            if count:
                block = expr[selected]
                mean[row, col] = np.asarray(block.mean(axis=0)).ravel()
                # getnnz counts stored entries per gene column.
                frac_expr[row, col] = block.getnnz(axis=0) / count

    return BatchExpression(
        mean=mean, frac_expr=frac_expr, n_cells=n_cells_arr,
        groups=group_names, batches=batches, genes=gene_names,
        group_to_idx=group_index,
    )
# ---------------------------------------------------------------------------
# Main entry point
# ---------------------------------------------------------------------------
def hierarchy(
    adata: Any,
    markers_list: Sequence[Any],
    *,
    min_cells_for_path: Optional[int] = None,
    n_top_markers: int = 10,
    gene_filter: Optional[GeneFilter] = None,
    batch_key: Optional[str] = None,
) -> HierarchyRun:
    """Run cross-resolution hierarchy pipeline.

    Combines three resolution levels of marker outputs into icls
    (integration cell lineage strings) and prepares marker comparison.

    Side effects on *adata*: the three groupby columns in ``adata.obs`` are
    converted to ordered categoricals, and the columns ``obs["path"]`` and
    ``obs["icls"]`` are (over)written.

    Parameters
    ----------
    adata
        Annotated data object providing ``obs``, ``shape`` and, when small
        paths have to be reassigned, ``obsp["connectivities"]``.
    markers_list
        Exactly three marker outputs, one per resolution level. Each must
        expose ``ctx`` (with ``groupby``) and ``markers`` (group name ->
        ranked genes).
    min_cells_for_path
        Paths with fewer cells are reassigned to their major neighbor path
        via kNN connectivities.
        Default: ``max(int(adata.shape[0] * 0.005), 1)``.
    n_top_markers
        Number of top-ranked genes per cluster written to the marker rank
        table; also stored in ``params`` for the plotting helpers.
    gene_filter
        Optional filter applied to the ranked genes before the top-N cut
        of the marker rank table. ``full_gene_lists`` stays unfiltered.
    batch_key
        ``adata.obs`` column with batch labels. When given, per-batch
        expression statistics are computed for every level.

    Returns
    -------
    HierarchyRun

    Raises
    ------
    ValueError
        If *markers_list* does not contain exactly three entries.
    KeyError
        If small paths must be reassigned but ``adata.obsp`` has no
        ``"connectivities"`` entry.
    """
    from scipy import sparse

    markers_list = list(markers_list)
    if len(markers_list) != 3:
        raise ValueError(
            "hierarchy() expects exactly three marker outputs "
            f"(one per resolution level), got {len(markers_list)}."
        )
    if min_cells_for_path is None:
        min_cells_for_path = max(int(adata.shape[0] * 0.005), 1)

    g0, g1, g2 = (mo.ctx.groupby for mo in markers_list)

    # 1) Ensure categorical dtype
    for g in (g0, g1, g2):
        # isinstance on the dtype replaces the deprecated
        # pd.api.types.is_categorical_dtype.
        if not isinstance(adata.obs[g].dtype, pd.CategoricalDtype):
            adata.obs[g] = adata.obs[g].astype("category")
        adata.obs[g] = adata.obs[g].cat.set_categories(
            adata.obs[g].cat.categories, ordered=True,
        )

    # 2) Build per-cell path strings
    adata.obs["path"] = (
        f"{g0}@" + adata.obs[g0].astype(str)
        + "|" + f"{g1}@" + adata.obs[g1].astype(str)
        + "|" + f"{g2}@" + adata.obs[g2].astype(str)
    )

    # 3) Make path categorical with cartesian product order
    path_categories = [
        f"{g0}@{x}|{g1}@{y}|{g2}@{z}"
        for x in adata.obs[g0].cat.categories
        for y in adata.obs[g1].cat.categories
        for z in adata.obs[g2].cat.categories
    ]
    adata.obs["path"] = pd.Categorical(
        adata.obs["path"], categories=path_categories, ordered=True,
    )

    # 4) Identify small paths and reassign their cells to neighbor paths
    small_paths = (
        adata.obs["path"].value_counts()
        .loc[lambda s: s < min_cells_for_path].index
    )
    small_mask = adata.obs["path"].isin(small_paths).to_numpy()
    if small_mask.any():
        try:
            conn = adata.obsp["connectivities"]
        except KeyError as err:
            raise KeyError(
                "adata.obsp['connectivities'] is required to reassign cells "
                "of small paths to neighboring paths."
            ) from err
        conn = conn.tocsr() if sparse.issparse(conn) else sparse.csr_matrix(conn)
        # Neighbors of cell i are indices[indptr[i]:indptr[i + 1]]; slicing
        # the CSR arrays avoids building a sparse row object per cell.
        indptr, indices = conn.indptr, conn.indices

        # Path labels for non-small cells; None for small-path cells.
        path_arr = adata.obs["path"].to_numpy(dtype=object, copy=True)
        path_arr[small_mask] = None
        # Cells that currently carry a usable label. Missing labels (None,
        # or NaN for cells whose path is not a category) never vote.
        labelled = pd.notna(path_arr)

        # Iterative reassignment: repeat until convergence.
        # Each pass may unlock neighbors that were unlabelled in the
        # previous pass; labels assigned earlier in a pass are already
        # visible to later cells of the same pass.
        remaining = list(np.where(small_mask)[0])
        while remaining:
            next_remaining = []
            for idx in remaining:
                nbrs = indices[indptr[idx]:indptr[idx + 1]]
                valid = path_arr[nbrs[labelled[nbrs]]]
                if len(valid) > 0:
                    # Majority vote among labelled neighbors.
                    values, counts = np.unique(valid, return_counts=True)
                    path_arr[idx] = values[counts.argmax()]
                    labelled[idx] = True
                else:
                    next_remaining.append(idx)
            if len(next_remaining) == len(remaining):  # no progress -> stop
                break
            remaining = next_remaining

        adata.obs["path"] = pd.Categorical(
            path_arr, categories=path_categories, ordered=True,
        )

    # Keep only the paths that actually occur, preserving cartesian order.
    present = set(adata.obs["path"].dropna().unique())
    new_categories = [c for c in path_categories if c in present]
    adata.obs["path"] = adata.obs["path"].cat.set_categories(
        new_categories, ordered=True,
    )

    # Build icls mapping
    icls_full_dict: Dict[str, str] = {
        str(i): path for i, path in enumerate(adata.obs["path"].cat.categories)
    }
    path_to_key = {v: k for k, v in icls_full_dict.items()}
    adata.obs["icls"] = adata.obs["path"].map(path_to_key).astype("string")

    # Build path dataframe
    df_icls_path = pd.DataFrame(
        pd.Series(icls_full_dict), columns=["icls_full"],
    )
    df_icls_path[g0] = [x.split("|")[0] for x in df_icls_path["icls_full"]]
    df_icls_path[g1] = [x.split("|")[1] for x in df_icls_path["icls_full"]]
    df_icls_path[g2] = [x.split("|")[2] for x in df_icls_path["icls_full"]]
    df_icls_path["root"] = df_icls_path[g0]
    df_icls_path = df_icls_path.reset_index(names="icls")

    # Build marker rank table and full gene lists
    rows: List[List[Any]] = []
    full_gene_lists: Dict[str, List[str]] = {}
    for level_key, mo in zip([g0, g1, g2], markers_list):
        for k, v in mo.markers.items():
            leiden_id = f"{level_key}@{k}"
            full_gene_lists[leiden_id] = list(v)
            genes = gene_filter.filter(v) if gene_filter is not None else v
            for rank, gene in enumerate(genes[:n_top_markers], start=1):
                rows.append([level_key, leiden_id, gene, rank])
    df_marker_rank = pd.DataFrame(
        rows, columns=["resolution", "leiden", "gene", "rank"],
    )

    # Build contexts dict
    contexts = {mo.ctx.groupby: mo.ctx for mo in markers_list}

    # Compute batch expression if requested
    batch_expression = None
    if batch_key is not None:
        batch_expression = {
            mo.ctx.groupby: _compute_batch_expression(adata, mo.ctx, batch_key)
            for mo in markers_list
        }

    return HierarchyRun(
        levels=[str(g0), str(g1), str(g2)],
        params={
            "min_cells_for_path": int(min_cells_for_path),
            "n_top_markers": int(n_top_markers),
        },
        icls_full_dict=icls_full_dict,
        icls_path_df=df_icls_path,
        marker_rank_df=df_marker_rank,
        full_gene_lists=full_gene_lists,
        contexts=contexts,
        batch_expression=batch_expression,
        batch_key=batch_key,
    )