Source code for sceleto.markers.graph._onestep

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, Union

import pandas as pd

from sceleto.markers._base import MarkersBase


@dataclass
class MarkerGraphRun(MarkersBase):
    """Result bundle produced by the one-step marker-graph pipeline.

    Notes
    -----
    - Intermediate artifacts are retained so they can be inspected or
      debugged after the run.
    """

    # --- Core artifacts ---
    ctx: Any
    edge_gene_df: pd.DataFrame
    edge_fc: pd.DataFrame
    edge_delta: pd.DataFrame
    labels: Any
    note_df: pd.DataFrame

    # --- Graph and visualisation ---
    G: Any
    pos: Any
    gene_edge_fc: Dict[str, Dict[Tuple[object, object], float]]
    gene_to_edges: Dict[str, List[str]]
    viz: Any

    # --- Specific-marker outputs ---
    specific_ranking_df: pd.DataFrame
    _marker_log: Dict[str, List[str]]

    # Batch column name; stays None when no batch key was supplied.
    batch_key: Optional[str] = None

    # Edge metric that drove the jump filter ("fc" or "delta").
    edge_metric: Literal["fc", "delta"] = "fc"

    # Threshold sweep results; both stay None unless the active threshold
    # was "auto".
    sweep_df: Optional[pd.DataFrame] = None
    suggested_thres_fc: Optional[float] = None

    def __post_init__(self):
        """Do nothing: ``adata`` and ``groupby`` are proxied from ``ctx``.

        ``MarkersBase.__init__`` is intentionally not invoked.
        """
        return None

    # ------------------------------------------------------------------
    # Values proxied from the context / marker log
    # ------------------------------------------------------------------
    @property
    def adata(self):
        """The ``adata`` attribute of the stored context."""
        context = self.ctx
        return context.adata

    @property
    def groupby(self):
        """The ``groupby`` attribute of the stored context."""
        context = self.ctx
        return context.groupby

    @property
    def markers(self) -> Dict[str, List[str]]:
        """Per-group marker gene lists, ranked by specificity score."""
        return self._marker_log

    # ------------------------------------------------------------------
    # Plotting helpers
    # ------------------------------------------------------------------
    def plot_fc_threshold(self, **kwargs):
        """Plot the threshold sweep.

        Only usable when the active threshold was ``"auto"``; otherwise no
        sweep data exists and a ``ValueError`` is raised. ``edge_metric``
        defaults to the metric recorded on this run unless the caller
        passes one explicitly.
        """
        sweep = self.sweep_df
        if sweep is None:
            raise ValueError("No sweep data. Re-run with thres_fc='auto' (or thres_delta='auto').")

        from ._threshold import plot_fc_threshold as _plot_sweep

        if "edge_metric" not in kwargs:
            kwargs["edge_metric"] = self.edge_metric
        return _plot_sweep(sweep, suggested=self.suggested_thres_fc, **kwargs)

    def plot_gene_edges_fc(self, gene: str, **kwargs):
        """Delegate to ``viz.plot_gene_edges_fc`` for *gene*."""
        painter = self.viz
        return painter.plot_gene_edges_fc(gene, **kwargs)

    def plot_gene_levels_with_edges(self, gene: str, level: Optional[int] = None, **kwargs):
        """Delegate to ``viz.plot_gene_levels_with_edges`` for *gene*."""
        painter = self.viz
        return painter.plot_gene_levels_with_edges(gene, level=level, **kwargs)

    def plot_highlight_edges(self, edges, **kwargs):
        """Delegate to ``viz.plot_highlight_edges`` for *edges*."""
        painter = self.viz
        return painter.plot_highlight_edges(edges, **kwargs)

    # ------------------------------------------------------------------
    # Batch inspection
    # ------------------------------------------------------------------
    def batch_mean_detail(
        self,
        adata,
        gene: str,
        group: str,
    ):
        """Return per-batch mean expression for a specific (gene, group).

        Parameters
        ----------
        adata : AnnData
            The same AnnData used in :func:`run_marker_graph`.
        gene : str
            Marker gene name.
        group : str
            Cluster where the gene is highly expressed.

        Returns
        -------
        DataFrame with columns: ``edge_start``, ``edge_end``, ``batch``,
        ``mean_start``, ``mean_end``, ``n_cells_start``, ``n_cells_end``.

        Raises
        ------
        ValueError
            If the run was produced without a ``batch_key``.
        """
        key = self.batch_key
        if key is None:
            raise ValueError(
                "No batch data. Re-run run_marker_graph() with batch_key='...'."
            )

        from ._batch import get_batch_mean_detail

        return get_batch_mean_detail(
            adata,
            self.ctx,
            self.edge_gene_df,
            key,
            gene,
            group,
        )
def _check_graph_prerequisites(adata: Any, k: Any) -> None:
    """Validate ``k`` and the precomputed ``adata`` slots that each mode needs."""
    if k == "all":
        # Complete-graph mode: no PAGA. Requires UMAP for cluster centroid positions.
        if "X_umap" not in getattr(adata, "obsm", {}):
            raise ValueError(
                "k='all' requires adata.obsm['X_umap'] for cluster positions. "
                "Run sc.tl.umap(adata) first."
            )
    elif isinstance(k, int):
        # PAGA-trim mode: neighbors graph must already exist (no auto-compute).
        if "neighbors" not in getattr(adata, "uns", {}):
            raise ValueError(
                f"k={k} (PAGA trim) requires precomputed neighbors. "
                "Run sc.pp.neighbors(adata) first, or set k='all' to skip PAGA."
            )
    else:
        raise ValueError(f"k must be int or 'all', got {k!r}")


def _ensure_paga_layout(adata: Any, groupby: str) -> None:
    """Ensure ``adata.uns['paga']`` holds connectivities and positions for *groupby*."""
    import matplotlib.pyplot as plt
    import scanpy as sc

    paga = getattr(adata, "uns", {}).get("paga", None)
    need_paga = paga is None or "connectivities" not in paga
    if not need_paga:
        n_groups = adata.obs[groupby].nunique()
        stored_groups = paga.get("groups")
        if paga["connectivities"].shape[0] != n_groups:
            need_paga = True
        elif stored_groups is not None and str(stored_groups) != str(groupby):
            # Same number of clusters but computed on a different obs column:
            # the stored PAGA is stale even though the shapes agree.
            need_paga = True
    if need_paga:
        sc.tl.paga(adata, groups=groupby)

    # Ensure PAGA positions exist; populate if missing or stale.
    paga = adata.uns.get("paga", {})
    if "pos" not in paga or need_paga:
        # Remember the caller's open figures so only the throwaway layout
        # figures created below are closed afterwards.
        preexisting = set(plt.get_fignums())
        try:
            try:
                sc.pl.paga_compare(adata, show=False)
            except Exception:
                # Best-effort fallback kept from the original implementation:
                # if the comparison plot fails for any reason, use plain PAGA.
                sc.pl.paga(adata, show=False)
        finally:
            for fig_num in plt.get_fignums():
                if fig_num not in preexisting:
                    plt.close(fig_num)


def run_marker_graph(
    adata: Any,
    *,
    groupby: str,
    edge_metric: Literal["fc", "delta"] = "fc",
    thres_fc: Union[float, str] = "auto",
    thres_delta: Union[float, str] = 0.5,
    # Specific ranking params
    specific_A: float = 1.0,
    specific_B: float = 0.5,
    specific_only_high_markers: bool = True,
    specific_score_col: str = "specific_weight",
    specific_score_fn: Optional[Callable[[pd.DataFrame], object]] = None,
    # Context defaults
    use_raw: bool = True,
    k: Union[int, Literal["all"]] = 5,
    exclude: Optional[List[str]] = None,
    min_cells_per_group: int = 0,
    min_expr_cells_per_gene: int = 0,
    # FC/delta defaults
    eps: float = 1e-3,
    min_mean_any: float = 0.05,
    min_mean_high: float = 0.5,
    min_frac_high: float = 0.2,
    max_mean_low: float = 0.2,
    min_nexpr_any: int = 0,
    # Labeling defaults
    fc_cutoff: Optional[float] = None,
    delta_cutoff: Optional[float] = None,
    label_k: float = 2.0,
    sigma_method: str = "sd",
    min_gap: float = 0.2,
    min_margin: float = 0.0,
    level: int = 3,
    # Graph/Viz defaults
    bidirectional: bool = True,
    node_size_scale: float = 10.0,
    # Batch t-test (activated automatically when batch_key is provided)
    batch_key: Optional[str] = None,
    batch_min_cells: int = 5,
    batch_ttest_alpha: float = 0.05,
    batch_ttest_min_batches: int = 3,
) -> "MarkerGraphRun":
    """One-step wrapper: context -> edge metrics -> labels -> viz -> specific marker discovery.

    Parameters
    ----------
    adata
        Annotated data object. With ``k='all'`` it must provide
        ``adata.obsm['X_umap']``; with an integer ``k`` it must provide
        ``adata.uns['neighbors']``.
    groupby
        Column of ``adata.obs`` holding the group labels.
    edge_metric
        ``"fc"`` or ``"delta"``. Selects which of ``thres_fc`` /
        ``thres_delta`` is the active threshold and which column of the
        edge table is used to build ``gene_to_edges``.
    thres_fc, thres_delta
        Threshold for the respective metric. If the *active* one is the
        string ``"auto"``, a threshold sweep is run and its suggestion is
        used; any other value is converted with ``float()``. If the
        *inactive* one is a string it is replaced by a fixed fallback
        (3.0 for ``thres_fc``, 0.5 for ``thres_delta``).
    specific_A, specific_B
        Passed as ``A`` and ``B`` to ``weight_local_prioritized`` when
        ``specific_score_fn`` is not given.
    specific_only_high_markers
        Passed as ``only_high_markers`` to ``build_local_marker_inputs``.
    specific_score_col
        Name of the score column added to the ranking table and used for
        sorting.
    specific_score_fn
        Optional callable that receives the ranking table and returns the
        values stored in ``specific_score_col``.
    use_raw, k, exclude, min_cells_per_group, min_expr_cells_per_gene
        Forwarded to ``build_context`` (and to the threshold sweep).
        ``use_raw`` is also forwarded to the batch t-test filter.
    eps, min_mean_any, min_mean_high, min_frac_high, max_mean_low, min_nexpr_any
        Forwarded to ``compute_fc_delta`` (and to the threshold sweep).
        ``eps`` is also forwarded to the batch t-test filter.
    fc_cutoff, delta_cutoff
        Cutoffs forwarded to ``label_levels``. When ``None``, the cutoff of
        the active metric follows the resolved threshold and the other one
        falls back to 3.0 (fc) or 0.5 (delta).
    label_k, sigma_method, min_gap, min_margin
        Forwarded to ``label_levels`` (``label_k`` as ``k``).
    level
        Forwarded to ``labels_to_note_df``.
    bidirectional
        Forwarded to ``build_graph_and_pos_from_ctx``.
    node_size_scale
        Forwarded to ``GraphVizContext``.
    batch_key
        When not ``None``, the edge table is filtered with
        ``filter_edge_gene_df_by_ttest`` using this key.
    batch_min_cells, batch_ttest_alpha, batch_ttest_min_batches
        Forwarded to the batch t-test filter as ``min_cells``, ``alpha``
        and ``min_batches``.

    Returns
    -------
    MarkerGraphRun
        Container with the context, edge tables, labels, graph, viz helper,
        ranking table and per-group marker lists.

    Raises
    ------
    ValueError
        If ``edge_metric`` is not ``"fc"``/``"delta"``, if ``k`` is neither
        an int nor ``"all"``, or if the precomputed UMAP / neighbors data
        required by the chosen ``k`` mode is missing.

    Notes
    -----
    With an integer ``k`` this function may call ``sc.tl.paga`` and the PAGA
    plotting functions on ``adata`` to (re)build connectivities and layout
    positions. Only figures opened by that step are closed again. When the
    active threshold is ``"auto"`` the resolved value is printed.
    """
    # --- Validate cheap arguments first, before any expensive work ---
    if edge_metric not in ("fc", "delta"):
        raise ValueError(f"edge_metric must be 'fc' or 'delta', got {edge_metric!r}")
    _check_graph_prerequisites(adata, k)

    from ._context import build_context
    from ._metrics import (
        compute_fc_delta,
        edge_gene_df_to_matrices,
        build_gene_edge_fc_from_edge_gene_df,
    )
    from ._labels import label_levels, labels_to_note_df
    from ._viz import GraphVizContext, build_graph_and_pos_from_ctx
    from ._local import (
        build_local_marker_inputs,
        weight_local_prioritized,
    )

    # --- Prepare graph inputs (PAGA is only needed in PAGA-trim mode) ---
    if k != "all":
        _ensure_paga_layout(adata, groupby)

    # --- Pick the active threshold based on edge_metric ---
    active_thres: Union[float, str] = thres_fc if edge_metric == "fc" else thres_delta

    sweep_df = None
    suggested_thres_fc = None

    if isinstance(active_thres, str) and active_thres == "auto":
        from ._threshold import sweep_fc_threshold, suggest_fc_threshold

        sweep_df = sweep_fc_threshold(
            adata,
            groupby,
            edge_metric=edge_metric,
            use_raw=use_raw,
            # Context kwargs
            k=k,
            exclude=exclude,
            min_cells_per_group=min_cells_per_group,
            min_expr_cells_per_gene=min_expr_cells_per_gene,
            # Expression filter kwargs — must match the main pipeline so the
            # suggested threshold is computed on the same edge population.
            eps=eps,
            min_mean_any=min_mean_any,
            min_mean_high=min_mean_high,
            min_frac_high=min_frac_high,
            max_mean_low=max_mean_low,
            min_nexpr_any=min_nexpr_any,
        )
        suggested_thres_fc = suggest_fc_threshold(sweep_df)
        active_thres = suggested_thres_fc
        print(f" Auto thres_{edge_metric}: {active_thres:.2f}")

    # Write the resolved threshold back to the metric-specific variable.
    # The inactive threshold is not used for filtering but must be a finite
    # float so it can be passed through to compute_fc_delta.
    if edge_metric == "fc":
        thres_fc = float(active_thres)
        if isinstance(thres_delta, str):
            thres_delta = 0.5
    else:
        thres_delta = float(active_thres)
        if isinstance(thres_fc, str):
            thres_fc = 3.0

    # Sync cutoffs for label_levels with the active threshold (unless caller overrode).
    if fc_cutoff is None:
        fc_cutoff = float(thres_fc) if edge_metric == "fc" else 3.0
    if delta_cutoff is None:
        delta_cutoff = float(thres_delta) if edge_metric == "delta" else 0.5

    ctx = build_context(
        adata,
        groupby=groupby,
        use_raw=use_raw,
        exclude=exclude,
        min_cells_per_group=min_cells_per_group,
        min_expr_cells_per_gene=min_expr_cells_per_gene,
        k=k,
    )

    edge_gene_df = compute_fc_delta(
        ctx,
        edge_metric=edge_metric,
        thres_fc=float(thres_fc),
        thres_delta=float(thres_delta),
        eps=eps,
        min_mean_any=min_mean_any,
        min_mean_high=min_mean_high,
        min_frac_high=min_frac_high,
        max_mean_low=max_mean_low,
        min_nexpr_any=min_nexpr_any,
    )

    # --- Batch t-test filter (activated when batch_key is provided) ---
    if batch_key is not None:
        from ._batch import filter_edge_gene_df_by_ttest

        edge_gene_df = filter_edge_gene_df_by_ttest(
            adata,
            ctx,
            edge_gene_df,
            batch_key,
            use_raw=use_raw,
            min_cells=batch_min_cells,
            min_batches=batch_ttest_min_batches,
            alpha=batch_ttest_alpha,
            eps=eps,
        )

    edge_fc, edge_delta = edge_gene_df_to_matrices(edge_gene_df)

    labels = label_levels(
        ctx,
        edge_gene_df,
        edge_metric=edge_metric,
        fc_cutoff=float(fc_cutoff),
        delta_cutoff=float(delta_cutoff),
        k=label_k,
        sigma_method=sigma_method,  # type: ignore[arg-type]
        min_gap=min_gap,
        min_margin=min_margin,
    )
    note_df = labels_to_note_df(ctx, labels, level=level)  # type: ignore[arg-type]

    G, pos = build_graph_and_pos_from_ctx(ctx, bidirectional=bidirectional)
    gene_edge_fc = build_gene_edge_fc_from_edge_gene_df(edge_fc, G=G)

    # Edges that pass the active threshold, grouped per gene as "start->end".
    if edge_metric == "fc":
        sub = edge_gene_df[edge_gene_df["fc"] >= float(thres_fc)]
    else:
        sub = edge_gene_df[edge_gene_df["delta"] >= float(thres_delta)]

    gene_to_edges: Dict[str, List[str]] = {}
    if len(sub) > 0:
        for g, sdf in sub.groupby("gene"):
            gene_to_edges[str(g)] = (
                sdf["start"].astype(str) + "->" + sdf["end"].astype(str)
            ).tolist()

    viz = GraphVizContext(
        G=G,
        ctx=ctx,
        note_df=note_df,
        labels=labels,
        gene_edge_fc=gene_edge_fc,
        gene_to_edges=gene_to_edges,
        node_size_scale=node_size_scale,
    )

    # --- Specific marker ranking ---
    specific_inputs_df = build_local_marker_inputs(
        ctx=ctx,
        labels=labels,
        note_df=note_df,
        edge_fc=edge_fc,
        edge_delta=edge_delta,
        only_high_markers=specific_only_high_markers,
    )

    specific_ranking_df = specific_inputs_df.copy()
    if specific_score_fn is not None:
        specific_ranking_df[specific_score_col] = specific_score_fn(specific_ranking_df)
    else:
        specific_ranking_df[specific_score_col] = weight_local_prioritized(
            specific_ranking_df, A=specific_A, B=specific_B
        )

    specific_ranking_df = specific_ranking_df.sort_values(
        ["group", specific_score_col], ascending=[True, False]
    )

    # Every group gets an entry, even when no marker survived for it.
    _marker_log: Dict[str, List[str]] = {str(g): [] for g in ctx.groups}
    for g, sdf in specific_ranking_df.groupby("group", sort=False):
        _marker_log[str(g)] = sdf["gene"].astype(str).tolist()

    return MarkerGraphRun(
        ctx=ctx,
        edge_gene_df=edge_gene_df,
        edge_fc=edge_fc,
        edge_delta=edge_delta,
        labels=labels,
        note_df=note_df,
        G=G,
        pos=pos,
        gene_edge_fc=gene_edge_fc,
        gene_to_edges=gene_to_edges,
        viz=viz,
        specific_ranking_df=specific_ranking_df,
        _marker_log=_marker_log,
        batch_key=batch_key,
        edge_metric=edge_metric,
        sweep_df=sweep_df,
        suggested_thres_fc=suggested_thres_fc,
    )