API
SpatialEx training utilities.
This module contains three trainer classes:
Train_SpatialEx: Trains two SpatialEx models (one per slice) and evaluates cross-panel prediction quality via cosine similarity, SSIM, PCC, and CMD.
SpatialExP: Trains SpatialEx+ with additional regression mapping heads in a cycle-style setup to translate between gene panels.
SpatialExP_Big: Trains SpatialEx+ on datasets with millions of cells.
The trainers expect two AnnData slices whose .obsm['he'] stores histology-derived embeddings and whose var_names align with gene features.
Note
This module imports project-specific components from sibling modules:
model.Model, model.Model_Plus, model.Model_Big, model.Regression,
utils.create_optimizer, utils.Compute_metrics, and preprocessing functions as pp.
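A minimal usage sketch for the baseline trainer, assuming the import path implied by the class entry SpatialEx.SpatialEx.SpatialEx and that the two slices and their graphs have already been prepared. The wrapper name run_baseline is hypothetical; only the constructor arguments, train(), and auto_inference() come from this reference.

```python
def run_baseline(adata1, adata2, graph1, graph2, save_path=None):
    """Train the baseline trainer on two slices and return cross-panel predictions.

    Hypothetical wrapper: how adata1/adata2 and graph1/graph2 are built is left
    to the caller, since this reference does not describe the preprocessing.
    """
    # Imported lazily so the sketch can be read without the package installed.
    from SpatialEx.SpatialEx import SpatialEx

    trainer = SpatialEx(adata1, adata2, graph1, graph2, epochs=500, save_path=save_path)
    trainer.train()
    # Returns (panel_1b, panel_2a) as documented for auto_inference().
    return trainer.auto_inference()
```

The same construct, train, infer order should apply to SpatialExP and SpatialExP_Big, which expose the same train() and auto_inference() entry points.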
- class SpatialEx.SpatialEx.SpatialEx(adata1, adata2, graph1, graph2, num_layers=2, hidden_dim=512, epochs=500, seed=0, device=device(type='cpu'), weight_decay=0, optimizer='adam', batch_size=4096, encoder='hgnn', lr=0.001, loss_fn='mse', num_neighbors=7, graph_kind='spatial', prune=10000, save_path=None)[source]
Bases: object
Trainer for baseline SpatialEx on two slices.
This trainer fits two models (module_HA for slice 1 and module_HB for slice 2) independently using hypergraph-based batches, then evaluates cross-panel predictions at the end.
Attributes:
adata1 (AnnData): Slice 1.
adata2 (AnnData): Slice 2.
num_layers (int): Number of HGNN layers.
hidden_dim (int): Hidden width of the backbone.
epochs (int): Number of training epochs.
seed (int): Random seed.
device (torch.device): Device on which models are trained.
weight_decay (float): Weight decay for the optimizer.
optimizer (torch.optim.Optimizer): Optimizer instance.
batch_size (int): Batch size when building the hypergraph.
encoder (str): Encoder architecture key (e.g., "hgnn").
lr (float): Learning rate.
loss_fn (str): Loss function key (e.g., "mse").
num_neighbors (int): K for KNN used in hypergraph construction.
graph_kind (str): Spatial graph/hypergraph type (e.g., "spatial").
prune (int): Pruning threshold for dataloader construction.
save (bool): Whether to save the results.
- __init__(adata1, adata2, graph1, graph2, num_layers=2, hidden_dim=512, epochs=500, seed=0, device=device(type='cpu'), weight_decay=0, optimizer='adam', batch_size=4096, encoder='hgnn', lr=0.001, loss_fn='mse', num_neighbors=7, graph_kind='spatial', prune=10000, save_path=None)[source]
- auto_inference()[source]
Run cross-panel prediction for both slices using internal dataloaders.
The method uses the trained models to predict the missing panel for each slice:
Slice 1: predict panel B using module_HB
Slice 2: predict panel A using module_HA
Returns
- tuple[numpy.ndarray, numpy.ndarray]
(panel_1b, panel_2a) where:
panel_1b: panel-B prediction for slice 1
panel_2a: panel-A prediction for slice 2
Notes
If save_path is set, predictions are saved as B1.npy and A2.npy.
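As a rough sketch of reading the saved predictions back, assuming save_path is a directory and the files are ordinary NumPy arrays written with numpy.save (the reference only states the file names). The helper name load_predictions is hypothetical, and the dummy arrays stand in for a finished run.

```python
import os
import tempfile

import numpy as np


def load_predictions(save_path):
    """Load the two cross-panel predictions stored under save_path.

    Hypothetical helper: assumes B1.npy and A2.npy are plain NumPy arrays.
    """
    panel_1b = np.load(os.path.join(save_path, "B1.npy"))
    panel_2a = np.load(os.path.join(save_path, "A2.npy"))
    return panel_1b, panel_2a


# Stand-in for a finished run: write dummy arrays under the documented names.
with tempfile.TemporaryDirectory() as tmp:
    np.save(os.path.join(tmp, "B1.npy"), np.zeros((5, 3)))
    np.save(os.path.join(tmp, "A2.npy"), np.ones((4, 2)))
    b1, a2 = load_predictions(tmp)

print(b1.shape, a2.shape)
```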
- inference(he, graph, panel)[source]
Predict gene expression for a given panel on a single slice.
This is a lightweight inference helper that runs the corresponding trained model (module_HA for panelA or module_HB for panelB) on the provided histology embedding and spatial graph.
Parameters
- he : array-like
Histology-derived embedding of shape (n_cells, n_he_features).
- graph : scipy.sparse.spmatrix or compatible
Sparse adjacency / hypergraph matrix for the slice.
- panel : {"panelA", "panelB"}, default "panelA"
Which trained model to use.
Returns
- numpy.ndarray
Predicted expression matrix of shape
(n_cells, n_genes_in_panel).
Notes
If save_path is set, predictions are saved as .npy files.
- train()[source]
Run the training loop and evaluate cross-panel predictions.
The method trains module_HA and module_HB jointly by iterating over paired mini-batches from two slices. After training, it predicts the missing panel on each slice and computes gene-level metrics (cosine similarity, SSIM, PCC, CMD).
- data_dir: Project root containing a datasets/ folder with:
Human_Breast_Cancer_Rep1/cell_feature_matrix.h5
Human_Breast_Cancer_Rep1/cells.csv
Human_Breast_Cancer_Rep2/cell_feature_matrix.h5
Human_Breast_Cancer_Rep2/cells.csv
- Prints:
Aggregated metrics per slice (cosine similarity, SSIM, PCC, CMD).
- Raises:
FileNotFoundError – If any expected dataset file is missing.
- Returns:
None
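To make "gene-level" concrete, the sketch below computes cosine similarity and PCC per gene (per column) between a predicted and a measured expression matrix. It is an illustration in plain NumPy, not the project's utils.Compute_metrics; the function names are hypothetical, and SSIM and CMD are omitted.

```python
import numpy as np


def gene_level_cosine(pred, true, eps=1e-12):
    """Cosine similarity per gene (column) for two (n_cells, n_genes) arrays."""
    num = (pred * true).sum(axis=0)
    den = np.linalg.norm(pred, axis=0) * np.linalg.norm(true, axis=0)
    return num / np.maximum(den, eps)


def gene_level_pcc(pred, true, eps=1e-12):
    """Pearson correlation per gene: cosine similarity of mean-centred columns."""
    return gene_level_cosine(pred - pred.mean(axis=0), true - true.mean(axis=0), eps)


rng = np.random.default_rng(0)
true = rng.random((100, 4))
pred = 2.0 * true + 1.0  # perfectly correlated with the truth, but shifted and scaled

print(gene_level_pcc(pred, true))     # all entries close to 1
print(gene_level_cosine(pred, true))  # below 1 because of the constant offset
```

The toy data shows why the two metrics are reported together: PCC ignores a constant offset and scale, while cosine similarity penalises the offset.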
- class SpatialEx.SpatialEx.SpatialExP(adata1, adata2, graph1, graph2, use_agg=True, platform='Xenium', seed=0, device=device(type='cpu'), weight_decay=0, optimizer='adam', batch_size=4096, encoder='hgnn', hidden_dim=512, num_layers=2, epochs=1000, lr=0.001, prune=10000, loss_fn='mse', num_neighbors=7, graph_kind='spatial', save_path=None)[source]
Bases: object
- __init__(adata1, adata2, graph1, graph2, use_agg=True, platform='Xenium', seed=0, device=device(type='cpu'), weight_decay=0, optimizer='adam', batch_size=4096, encoder='hgnn', hidden_dim=512, num_layers=2, epochs=1000, lr=0.001, prune=10000, loss_fn='mse', num_neighbors=7, graph_kind='spatial', save_path=None)[source]
Initialize the SpatialEx+ trainer with cycle-style regression heads.
This trainer fits two SpatialEx+ backbones (one per slice) and two regression mappers (rm_AB, rm_BA) to translate between gene panels. During training it optimizes reconstruction losses and cycle-consistency-style mapping losses.
Parameters
- adata1, adata2 : AnnData
Two slices with expression matrices in .X and histology embeddings in .obsm['he'].
- graph1, graph2 : scipy.sparse.spmatrix or compatible
Spatial graphs for the two slices.
- platform : str, default "Xenium"
Platform name forwarded to Model_Plus.
- seed : int, default 0
Random seed.
- device : torch.device, optional
Device on which models and tensors are placed.
- weight_decay : float, default 0
Weight decay for the optimizer.
- optimizer : str, default "adam"
Optimizer key understood by utils.create_optimizer().
- batch_size : int, default 4096
Kept for compatibility with other trainers (not used directly here).
- encoder : str, default "hgnn"
Encoder key (kept for logging/compatibility).
- hidden_dim : int, default 512
Hidden dimension of the backbone.
- num_layers : int, default 2
Number of backbone layers.
- epochs : int, default 1000
Training epochs.
- lr : float, default 0.001
Learning rate.
- loss_fn : str, default "mse"
Loss function key.
- num_neighbors : int, default 7
K for KNN graph construction (kept for compatibility).
- graph_kind : str, default "spatial"
Graph kind label (kept for compatibility).
- save_path : str or None, optional
If provided, directory to save inference outputs.
- auto_inference()[source]
Run indirect cross-panel prediction for both slices.
Returns
- tuple[numpy.ndarray, numpy.ndarray]
(panelB1_indirect, panelA2_indirect) predictions for slice 1 and slice 2, respectively.
Notes
If save_path is set, outputs are saved as B1.npy and A2.npy.
- inference_direct(he, graph, panel)[source]
Directly predict the specified panel with its corresponding backbone.
Parameters
- he : array-like
Histology embedding for the query slice.
- graph : scipy.sparse.spmatrix or compatible
Sparse graph for the query slice.
- panel : {"panelA", "panelB"}, default "panelA"
Which panel to predict directly.
Returns
- numpy.ndarray
Direct panel prediction of shape
(n_cells, n_genes_in_panel).
Notes
If save_path is set, outputs are saved as <panel>_direct.npy.
- inference_indirect(he, graph, panel)[source]
Indirectly infer the missing panel using a regression mapper.
For panelB inference, the method first predicts panel A with module_HA, then maps to panel B using rm_AB. For panelA inference it uses module_HB followed by rm_BA.
Parameters
- he : array-like
Histology embedding for the query slice.
- graph : scipy.sparse.spmatrix or compatible
Sparse graph for the query slice.
- panel : {"panelA", "panelB"}, default "panelA"
Which panel to infer indirectly.
Returns
- numpy.ndarray
Indirect panel prediction of shape
(n_cells, n_genes_in_target_panel).
Notes
If save_path is set, outputs are saved as omics.npy.
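The two-step composition behind indirect inference can be sketched with toy linear maps. Everything here is a stand-in: backbone_A plays the role of module_HA and map_AB the role of rm_AB, the weights are random, and the spatial graph that the real backbone consumes is ignored for brevity.

```python
import numpy as np

rng = np.random.default_rng(0)
n_cells, n_he, n_a, n_b = 6, 8, 4, 3

# Hypothetical stand-ins: a linear "backbone" for panel A and a linear A->B mapper.
W_backbone_A = rng.normal(size=(n_he, n_a))
W_map_AB = rng.normal(size=(n_a, n_b))


def backbone_A(he):
    """Stand-in for module_HA: histology embedding -> panel A expression."""
    return he @ W_backbone_A


def map_AB(panel_a):
    """Stand-in for rm_AB: panel A expression -> panel B expression."""
    return panel_a @ W_map_AB


def indirect_panel_b(he):
    """Indirect panel B inference: predict panel A first, then map A -> B."""
    return map_AB(backbone_A(he))


he = rng.normal(size=(n_cells, n_he))
panel_b = indirect_panel_b(he)
print(panel_b.shape)  # (6, 3)
```

Indirect inference for panel A would mirror this with stand-ins for module_HB and rm_BA.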
- train()[source]
Train SpatialEx+ backbones and regression mappers.
The optimization combines:
Per-slice reconstruction losses from Model_Plus.
Cross-panel mapping losses for A->B and B->A via the regression heads.
Cycle-style consistency losses mapping real panel expressions through the opposite backbone.
Returns
None
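This reference does not spell out the exact form of the cycle-style term, so the following is only a generic cycle-consistency loss for illustration, not necessarily the one the trainer optimizes. It assumes linear stand-ins for the two regression mappers and an MSE penalty; all names are hypothetical.

```python
import numpy as np


def mse(a, b):
    """Mean squared error between two arrays of the same shape."""
    return float(np.mean((a - b) ** 2))


rng = np.random.default_rng(0)
n_cells, n_a, n_b = 50, 3, 5

panel_a = rng.random((n_cells, n_a))    # measured panel A on one slice
W_ab = rng.normal(size=(n_a, n_b))      # hypothetical linear stand-in for rm_AB
W_ba_good = np.linalg.pinv(W_ab)        # a B->A map that exactly inverts W_ab
W_ba_bad = rng.normal(size=(n_b, n_a))  # an untrained B->A map


def cycle_loss(panel_a, W_ab, W_ba):
    """Map A -> B -> A and penalise the distance to the original panel A."""
    return mse(panel_a @ W_ab @ W_ba, panel_a)


print(cycle_loss(panel_a, W_ab, W_ba_good))  # close to 0
print(cycle_loss(panel_a, W_ab, W_ba_bad))   # clearly larger
```

A term of this kind rewards mapper pairs that undo each other, which is the intuition behind calling the setup "cycle-style".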
- class SpatialEx.SpatialEx.SpatialExP_Big(adata1, adata2, graph1, graph2, use_agg=True, num_layers=2, hidden_dim=512, epochs=200, seed=0, device=device(type='cpu'), weight_decay=0, optimizer='adam', batch_size=4096, batch_num=10, encoder='hgnn', lr=0.001, loss_fn='mse', num_neighbors=7, graph_kind='spatial', save_path=None)[source]
Bases:
object- __init__(adata1, adata2, graph1, graph2, use_agg=True, num_layers=2, hidden_dim=512, epochs=200, seed=0, device=device(type='cpu'), weight_decay=0, optimizer='adam', batch_size=4096, batch_num=10, encoder='hgnn', lr=0.001, loss_fn='mse', num_neighbors=7, graph_kind='spatial', save_path=None)[source]
Initialize the large-scale SpatialEx+ trainer using pseudo-spots.
This variant aggregates single-cell expression into pseudo-spots to reduce memory and enable training on very large datasets. It trains a shared big backbone (
model_big) plus two regression mappers (model_AB,model_BA) for cross-panel translation.Parameters
- adata1, adata2 : AnnData
Two slices with histology embeddings in .obsm['he'] and expression in .X.
- graph1, graph2 : scipy.sparse.spmatrix or compatible
Spatial graphs for the two slices.
- num_layers : int, default 2
Number of backbone layers.
- hidden_dim : int, default 512
Hidden dimension of the backbone.
- epochs : int, default 200
Number of training epochs.
- seed : int, default 0
Random seed.
- device : torch.device, optional
Device to run on.
- weight_decay : float, default 0
Weight decay for the optimizer.
- optimizer : str, default "adam"
Optimizer key.
- batch_size : int, default 4096
Kept for compatibility (batching here is driven by batch_num).
- batch_num : int, default 10
Number of pseudo-spot batches per epoch.
- encoder : str, default "hgnn"
Encoder key (kept for compatibility).
- lr : float, default 0.001
Learning rate.
- loss_fn : str, default "mse"
Loss function key.
- num_neighbors : int, default 7
K for KNN (kept for compatibility).
- graph_kind : str, default "spatial"
Graph kind label (kept for compatibility).
- save_path : str or None, optional
Directory to save inference outputs.
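The pseudo-spot idea can be illustrated with an aggregation matrix that sums cell-level expression into spots. This is a toy sketch under stated assumptions: the cell-to-spot assignment (spot_ids) is given, the matrix is dense for readability, and build_aggregation_matrix is a hypothetical helper rather than a function of this package. A real pipeline on millions of cells would presumably use a sparse matrix.

```python
import numpy as np


def build_aggregation_matrix(spot_ids, n_spots):
    """Return an (n_spots, n_cells) 0/1 matrix; row s marks the cells in pseudo-spot s."""
    n_cells = len(spot_ids)
    agg = np.zeros((n_spots, n_cells))
    agg[spot_ids, np.arange(n_cells)] = 1.0
    return agg


# Toy data: 6 cells, 2 genes, each cell assigned to one of 3 pseudo-spots.
expr = np.array([[1.0, 0.0], [2.0, 1.0], [0.0, 3.0], [1.0, 1.0], [4.0, 0.0], [0.0, 2.0]])
spot_ids = np.array([0, 0, 1, 1, 2, 2])

agg = build_aggregation_matrix(spot_ids, n_spots=3)
spot_expr = agg @ expr  # summed expression per pseudo-spot, shape (3, 2)
print(spot_expr)
```

Because aggregation is a single matrix product, the same matrix can also be applied to cell-level predictions so that losses are compared at pseudo-spot resolution.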
- auto_inference()[source]
Run indirect cross-panel prediction for both original slices.
Returns
- tuple[numpy.ndarray, numpy.ndarray]
(indirect_panel_B1, indirect_panel_A2) predictions.
Notes
If save_path is set, outputs are saved as B1.npy and A2.npy.
- train()[source]
Train the big model with pseudo-spot batching.
Each epoch shuffles pseudo-spot indices and iterates over batch_num batches. For each batch it:
Computes losses for the shared backbone on both slices.
Generates exchanged predictions and trains regression mappers.
Applies reconstruction-style mapping losses using the aggregation matrices.
Returns
None
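The per-epoch batching described above might look like the sketch below: shuffle the pseudo-spot indices, then split them into batch_num nearly equal batches. The helper name and the use of numpy.array_split are assumptions for illustration; the trainer's actual splitting logic is not documented here.

```python
import numpy as np


def pseudo_spot_batches(n_spots, batch_num, rng):
    """Shuffle pseudo-spot indices and split them into batch_num nearly equal batches."""
    perm = rng.permutation(n_spots)
    return np.array_split(perm, batch_num)


rng = np.random.default_rng(0)
for epoch in range(2):
    # A fresh shuffle each epoch, so batches differ between epochs.
    batches = pseudo_spot_batches(n_spots=23, batch_num=10, rng=rng)
    print(epoch, [len(b) for b in batches])
```

With this scheme every pseudo-spot is visited exactly once per epoch, and memory use per step scales with n_spots / batch_num rather than with the full dataset.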