API

SpatialEx training utilities.

This module contains three trainer classes:

  • SpatialEx: Trains two SpatialEx models (one per slice) and evaluates cross-panel prediction quality via cosine similarity, SSIM, PCC, and CMD.

  • SpatialExP: Trains SpatialEx+ with additional regression mapping heads in a cycle-style setup to translate between gene panels.

  • SpatialExP_Big: Trains SpatialEx+ on large datasets with millions of cells.

The trainers expect two AnnData slices whose .obsm['he'] stores histology-derived embeddings and whose var_names align with gene features.
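A minimal sketch of checking these input expectations before constructing a trainer. The helper name check_slice is hypothetical and not part of this module, and the toy object is a SimpleNamespace stand-in exposing only the AnnData attributes used here (X, obsm, var_names, n_obs).

```python
from types import SimpleNamespace

import numpy as np


def check_slice(adata, name="slice"):
    """Validate the inputs the trainers are documented to expect (hypothetical helper)."""
    if "he" not in adata.obsm:
        raise KeyError(f"{name}: .obsm['he'] is missing")
    he = np.asarray(adata.obsm["he"])
    # One histology embedding row per cell.
    if he.ndim != 2 or he.shape[0] != adata.n_obs:
        raise ValueError(f"{name}: .obsm['he'] must have shape (n_obs, n_he_features)")
    # One gene name per expression column.
    if len(adata.var_names) != adata.X.shape[1]:
        raise ValueError(f"{name}: var_names do not match the columns of .X")
    return he.shape


# Stand-in for an AnnData slice: 5 cells, 3 genes, 8 histology features.
toy = SimpleNamespace(
    X=np.zeros((5, 3)),
    obsm={"he": np.ones((5, 8))},
    var_names=["G1", "G2", "G3"],
    n_obs=5,
)
he_shape = check_slice(toy, "slice 1")
```

The same function can be called on real AnnData objects, since it only reads attributes that AnnData provides.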

Note

This module imports project-specific components from sibling modules: model.Model, model.Model_Plus, model.Model_Big, model.Regression, utils.create_optimizer, utils.Compute_metrics, and preprocessing functions as pp.

class SpatialEx.SpatialEx.SpatialEx(adata1, adata2, graph1, graph2, num_layers=2, hidden_dim=512, epochs=500, seed=0, device=device(type='cpu'), weight_decay=0, optimizer='adam', batch_size=4096, encoder='hgnn', lr=0.001, loss_fn='mse', num_neighbors=7, graph_kind='spatial', prune=10000, save_path=None)

Bases: object

Trainer for baseline SpatialEx on two slices.

This trainer fits two models (module_HA for slice 1 and module_HB for slice 2) independently using hypergraph-based batches, then evaluates cross-panel predictions at the end.

Attributes:

adata1 (AnnData): Slice 1.

adata2 (AnnData): Slice 2.

num_layers (int): Number of HGNN layers.

hidden_dim (int): Hidden width of the backbone.

epochs (int): Number of training epochs.

seed (int): Random seed.

device (torch.device): Device on which models are trained.

weight_decay (float): Weight decay for the optimizer.

optimizer (torch.optim.Optimizer): Optimizer instance.

batch_size (int): Batch size when building the hypergraph.

encoder (str): Encoder architecture key (e.g., "hgnn").

lr (float): Learning rate.

loss_fn (str): Loss function key (e.g., "mse").

num_neighbors (int): K for KNN used in hypergraph construction.

graph_kind (str): Spatial graph/hypergraph type (e.g., "spatial").

prune (int): Pruning threshold for dataloader construction.

save_path (str or None): Directory in which results are saved; nothing is saved when it is None.
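A hedged usage sketch based only on the signatures listed on this page. The wrapper run_baseline is hypothetical, the import path is inferred from the qualified class name SpatialEx.SpatialEx.SpatialEx, and calling train() before auto_inference() is an assumption drawn from the method descriptions below. The keyword values shown are the documented defaults.

```python
def run_baseline(adata1, adata2, graph1, graph2, save_path=None):
    """Train the baseline trainer and return both cross-panel predictions (hypothetical wrapper)."""
    # Imported lazily so this sketch can be defined without the package installed.
    from SpatialEx.SpatialEx import SpatialEx

    trainer = SpatialEx(
        adata1, adata2, graph1, graph2,
        epochs=500,          # documented default
        lr=0.001,            # documented default
        save_path=save_path, # if set, predictions are also written to disk
    )
    trainer.train()  # fits module_HA and module_HB, then prints metrics
    panel_1b, panel_2a = trainer.auto_inference()
    return panel_1b, panel_2a
```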

__init__(adata1, adata2, graph1, graph2, num_layers=2, hidden_dim=512, epochs=500, seed=0, device=device(type='cpu'), weight_decay=0, optimizer='adam', batch_size=4096, encoder='hgnn', lr=0.001, loss_fn='mse', num_neighbors=7, graph_kind='spatial', prune=10000, save_path=None)

auto_inference()

Run cross-panel prediction for both slices using internal dataloaders.

The method uses the trained models to predict the missing panel for each slice:

  • Slice 1: predict panel B using module_HB

  • Slice 2: predict panel A using module_HA

Returns

tuple[numpy.ndarray, numpy.ndarray]

(panel_1b, panel_2a) where:

  • panel_1b: panel-B prediction for slice 1

  • panel_2a: panel-A prediction for slice 2

Notes

If save_path is set, predictions are saved as B1.npy and A2.npy.
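A small sketch of reading the saved predictions back. It assumes that B1.npy and A2.npy are written directly inside save_path; the helper load_cross_panel_predictions is hypothetical, and the temporary directory only simulates what a previous auto_inference() run might have left on disk.

```python
import tempfile
from pathlib import Path

import numpy as np


def load_cross_panel_predictions(save_path):
    """Load the two prediction files named in the Notes above (hypothetical helper)."""
    save_path = Path(save_path)
    panel_1b = np.load(save_path / "B1.npy")  # panel-B prediction for slice 1
    panel_2a = np.load(save_path / "A2.npy")  # panel-A prediction for slice 2
    return panel_1b, panel_2a


# Simulate earlier outputs with placeholder arrays, then load them.
with tempfile.TemporaryDirectory() as tmp:
    np.save(Path(tmp) / "B1.npy", np.zeros((4, 3)))
    np.save(Path(tmp) / "A2.npy", np.ones((6, 2)))
    b1, a2 = load_cross_panel_predictions(tmp)
```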

inference(he, graph, panel)

Predict gene expression for a given panel on a single slice.

This is a lightweight inference helper that runs the corresponding trained model (module_HA for panelA or module_HB for panelB) on the provided histology embedding and spatial graph.

Parameters

he : array-like

Histology-derived embedding of shape (n_cells, n_he_features).

graph : scipy.sparse.spmatrix or compatible

Sparse adjacency / hypergraph matrix for the slice.

panel : {"panelA", "panelB"}, default "panelA"

Which trained model to use.

Returns

numpy.ndarray

Predicted expression matrix of shape (n_cells, n_genes_in_panel).

Notes

If save_path is set, predictions are saved as .npy files.

train()

Run the training loop and evaluate cross-panel predictions.

The method trains module_HA and module_HB jointly by iterating over paired mini-batches from the two slices. After training, it predicts the missing panel on each slice and computes gene-level metrics (cosine similarity, SSIM, PCC, CMD).

Data layout (data_dir): the project root is expected to contain a datasets/ folder with:

  • Human_Breast_Cancer_Rep1/cell_feature_matrix.h5

  • Human_Breast_Cancer_Rep1/cells.csv

  • Human_Breast_Cancer_Rep2/cell_feature_matrix.h5

  • Human_Breast_Cancer_Rep2/cells.csv

Prints:

Aggregated metrics per slice (cosine similarity, SSIM, PCC, CMD).

Raises:

FileNotFoundError – If any expected dataset file is missing.

Returns:

None
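To illustrate what gene-level metrics mean here, the following sketch computes per-gene cosine similarity and PCC with NumPy. It is not the implementation in utils.Compute_metrics; the function names and the eps stabilizer are assumptions, and SSIM and CMD are omitted.

```python
import numpy as np


def gene_cosine(pred, true, eps=1e-12):
    """Cosine similarity per gene (column) between (n_cells, n_genes) matrices."""
    num = (pred * true).sum(axis=0)
    den = np.linalg.norm(pred, axis=0) * np.linalg.norm(true, axis=0)
    return num / (den + eps)


def gene_pcc(pred, true, eps=1e-12):
    """Pearson correlation per gene: cosine similarity of mean-centered columns."""
    pc = pred - pred.mean(axis=0)
    tc = true - true.mean(axis=0)
    den = np.linalg.norm(pc, axis=0) * np.linalg.norm(tc, axis=0)
    return (pc * tc).sum(axis=0) / (den + eps)


rng = np.random.default_rng(0)
true = rng.random((100, 5))   # toy ground truth: 100 cells, 5 genes
pred = 2.0 * true + 1.0       # toy prediction: a positive affine rescaling

cos = gene_cosine(pred, true)
pcc = gene_pcc(pred, true)    # PCC ignores positive affine rescaling, so it is ~1 here
```

Averaging either vector over genes gives one aggregated score per slice, which is the kind of summary train() is documented to print.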

class SpatialEx.SpatialEx.SpatialExP(adata1, adata2, graph1, graph2, use_agg=True, platform='Xenium', seed=0, device=device(type='cpu'), weight_decay=0, optimizer='adam', batch_size=4096, encoder='hgnn', hidden_dim=512, num_layers=2, epochs=1000, lr=0.001, prune=10000, loss_fn='mse', num_neighbors=7, graph_kind='spatial', save_path=None)

Bases: object

__init__(adata1, adata2, graph1, graph2, use_agg=True, platform='Xenium', seed=0, device=device(type='cpu'), weight_decay=0, optimizer='adam', batch_size=4096, encoder='hgnn', hidden_dim=512, num_layers=2, epochs=1000, lr=0.001, prune=10000, loss_fn='mse', num_neighbors=7, graph_kind='spatial', save_path=None)

Initialize the SpatialEx+ trainer with cycle-style regression heads.

This trainer fits two SpatialEx+ backbones (one per slice) and two regression mappers (rm_AB, rm_BA) to translate between gene panels. During training it optimizes reconstruction losses and cycle-consistency-style mapping losses.

Parameters

adata1, adata2 : AnnData

Two slices with expression matrices in .X and histology embeddings in .obsm['he'].

graph1, graph2 : scipy.sparse.spmatrix or compatible

Spatial graphs for the two slices.

platform : str, default "Xenium"

Platform name forwarded to Model_Plus.

seed : int, default 0

Random seed.

device : torch.device, optional

Device on which models and tensors are placed.

weight_decay : float, default 0

Weight decay for the optimizer.

optimizer : str, default "adam"

Optimizer key understood by utils.create_optimizer().

batch_size : int, default 4096

Kept for compatibility with other trainers (not used directly here).

encoder : str, default "hgnn"

Encoder key (kept for logging/compatibility).

hidden_dim : int, default 512

Hidden dimension of the backbone.

num_layers : int, default 2

Number of backbone layers.

epochs : int, default 1000

Training epochs.

lr : float, default 0.001

Learning rate.

loss_fn : str, default "mse"

Loss function key.

num_neighbors : int, default 7

K for KNN graph construction (kept for compatibility).

graph_kind : str, default "spatial"

Graph kind label (kept for compatibility).

save_path : str or None, optional

If provided, directory to save inference outputs.

auto_inference()

Run indirect cross-panel prediction for both slices.

Returns

tuple[numpy.ndarray, numpy.ndarray]

(panelB1_indirect, panelA2_indirect) predictions for slice 1 and slice 2 respectively.

Notes

If save_path is set, outputs are saved as B1.npy and A2.npy.

inference_direct(he, graph, panel)

Directly predict the specified panel with its corresponding backbone.

Parameters

he : array-like

Histology embedding for the query slice.

graph : scipy.sparse.spmatrix or compatible

Sparse graph for the query slice.

panel : {"panelA", "panelB"}, default "panelA"

Which panel to predict directly.

Returns

numpy.ndarray

Direct panel prediction of shape (n_cells, n_genes_in_panel).

Notes

If save_path is set, outputs are saved as <panel>_direct.npy.

inference_indirect(he, graph, panel)

Indirectly infer the missing panel using a regression mapper.

For panelB inference, the method first predicts panel A with module_HA, then maps the result to panel B using rm_AB. For panelA inference, it predicts panel B with module_HB, then maps the result to panel A using rm_BA.

Parameters

he : array-like

Histology embedding for the query slice.

graph : scipy.sparse.spmatrix or compatible

Sparse graph for the query slice.

panel : {"panelA", "panelB"}, default "panelA"

Which panel to infer indirectly.

Returns

numpy.ndarray

Indirect panel prediction of shape (n_cells, n_genes_in_target_panel).

Notes

If save_path is set, outputs are saved as omics.npy.
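The two-stage composition behind indirect inference can be sketched with plain callables. This is a toy illustration, not the library's code: indirect_predict is a hypothetical name, the linear maps stand in for module_HA and rm_AB, and the spatial graph input is left out for brevity.

```python
import numpy as np


def indirect_predict(he, backbone, mapper):
    """Predict the source panel from histology, then map it to the target panel."""
    source_panel = backbone(he)   # e.g. module_HA: histology -> panel A
    return mapper(source_panel)   # e.g. rm_AB: panel A -> panel B


rng = np.random.default_rng(0)
W_backbone = rng.normal(size=(16, 4))  # toy: 16 histology features -> 4 panel-A genes
W_mapper = rng.normal(size=(4, 3))     # toy: 4 panel-A genes -> 3 panel-B genes
he = rng.normal(size=(10, 16))         # toy embeddings for 10 cells

panel_b = indirect_predict(
    he,
    backbone=lambda x: x @ W_backbone,
    mapper=lambda a: a @ W_mapper,
)
```

The output has one row per cell and one column per target-panel gene, matching the documented return shape (n_cells, n_genes_in_target_panel).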

train()

Train SpatialEx+ backbones and regression mappers.

The optimization combines:

  1. Per-slice reconstruction losses from Model_Plus.

  2. Cross-panel mapping losses for A->B and B->A via the regression heads.

  3. Cycle-style consistency losses mapping real panel expressions through the opposite backbone.

Returns

None
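One possible reading of how the three loss groups listed above combine, written as a NumPy sketch. The exact wiring, the equal weighting, and every name below (mse, combined_loss, the argument names) are assumptions for illustration; the actual training code may differ, in particular for the cycle-style term.

```python
import numpy as np


def mse(a, b):
    """Mean squared error between two arrays of the same shape."""
    return float(np.mean((a - b) ** 2))


def combined_loss(true_a, pred_a, true_b, pred_b,
                  pred_a_on_slice2, pred_b_on_slice1, rm_ab, rm_ba):
    # 1. Per-slice reconstruction: each backbone against its measured panel.
    recon = mse(pred_a, true_a) + mse(pred_b, true_b)
    # 2. Cross-panel mapping: a predicted panel mapped into the measured one.
    mapping = (mse(rm_ab(pred_a_on_slice2), true_b)
               + mse(rm_ba(pred_b_on_slice1), true_a))
    # 3. Cycle-style consistency (assumed form): real expression mapped
    #    to the other panel and back should be recovered.
    cycle = (mse(rm_ba(rm_ab(true_a)), true_a)
             + mse(rm_ab(rm_ba(true_b)), true_b))
    return recon + mapping + cycle


rng = np.random.default_rng(0)
true_a = rng.random((8, 3))   # slice 1 measures panel A
true_b = rng.random((6, 3))   # slice 2 measures panel B
rm_ab = lambda x: 2.0 * x     # toy mapper A -> B
rm_ba = lambda x: x / 2.0     # toy mapper B -> A (exact inverse)

# With perfect predictions every term vanishes.
perfect = combined_loss(true_a, true_a, true_b, true_b,
                        true_b / 2.0, 2.0 * true_a, rm_ab, rm_ba)
# A perturbed backbone prediction makes the reconstruction term positive.
noisy = combined_loss(true_a, true_a + 0.1, true_b, true_b,
                      true_b / 2.0, 2.0 * true_a, rm_ab, rm_ba)
```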

class SpatialEx.SpatialEx.SpatialExP_Big(adata1, adata2, graph1, graph2, use_agg=True, num_layers=2, hidden_dim=512, epochs=200, seed=0, device=device(type='cpu'), weight_decay=0, optimizer='adam', batch_size=4096, batch_num=10, encoder='hgnn', lr=0.001, loss_fn='mse', num_neighbors=7, graph_kind='spatial', save_path=None)

Bases: object

__init__(adata1, adata2, graph1, graph2, use_agg=True, num_layers=2, hidden_dim=512, epochs=200, seed=0, device=device(type='cpu'), weight_decay=0, optimizer='adam', batch_size=4096, batch_num=10, encoder='hgnn', lr=0.001, loss_fn='mse', num_neighbors=7, graph_kind='spatial', save_path=None)

Initialize the large-scale SpatialEx+ trainer using pseudo-spots.

This variant aggregates single-cell expression into pseudo-spots to reduce memory and enable training on very large datasets. It trains a shared big backbone (model_big) plus two regression mappers (model_AB, model_BA) for cross-panel translation.

Parameters

adata1, adata2 : AnnData

Two slices with histology embeddings in .obsm['he'] and expression in .X.

graph1, graph2 : scipy.sparse.spmatrix or compatible

Spatial graphs for the two slices.

num_layers : int, default 2

Number of backbone layers.

hidden_dim : int, default 512

Hidden dimension of the backbone.

epochs : int, default 200

Number of training epochs.

seed : int, default 0

Random seed.

device : torch.device, optional

Device to run on.

weight_decay : float, default 0

Weight decay for the optimizer.

optimizer : str, default "adam"

Optimizer key.

batch_size : int, default 4096

Kept for compatibility (batching here is driven by batch_num).

batch_num : int, default 10

Number of pseudo-spot batches per epoch.

encoder : str, default "hgnn"

Encoder key (kept for compatibility).

lr : float, default 0.001

Learning rate.

loss_fn : str, default "mse"

Loss function key.

num_neighbors : int, default 7

K for KNN (kept for compatibility).

graph_kind : str, default "spatial"

Graph kind label (kept for compatibility).

save_path : str or None, optional

Directory to save inference outputs.
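The pseudo-spot idea can be illustrated with a cell-to-spot aggregation matrix. This sketch assumes square grid binning of cell coordinates, which may not be how the trainer actually defines pseudo-spots; build_aggregation_matrix and spot_size are hypothetical names, and a dense matrix is used only for readability (a sparse matrix would be the practical choice at scale).

```python
import numpy as np


def build_aggregation_matrix(coords, spot_size):
    """Return an (n_spots, n_cells) 0/1 matrix assigning each cell to one grid bin."""
    bins = np.floor(coords / spot_size).astype(int)
    # Each distinct occupied bin becomes one pseudo-spot.
    _, spot_ids = np.unique(bins, axis=0, return_inverse=True)
    spot_ids = spot_ids.ravel()
    n_cells = coords.shape[0]
    agg = np.zeros((spot_ids.max() + 1, n_cells))
    agg[spot_ids, np.arange(n_cells)] = 1.0
    return agg


# Toy slice: 4 cells in 2-D, 2 genes.
coords = np.array([[0.1, 0.1], [0.4, 0.2], [1.5, 0.5], [1.2, 1.7]])
X = np.array([[1.0, 0.0], [2.0, 1.0], [0.0, 3.0], [4.0, 4.0]])

agg = build_aggregation_matrix(coords, spot_size=1.0)
spot_expr = agg @ X  # summed expression per pseudo-spot
```

Because every cell belongs to exactly one pseudo-spot, the total expression per gene is preserved by the aggregation.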

auto_inference()

Run indirect cross-panel prediction for both original slices.

Returns

tuple[numpy.ndarray, numpy.ndarray]

(indirect_panel_B1, indirect_panel_A2) predictions.

Notes

If save_path is set, outputs are saved as B1.npy and A2.npy.

train()

Train the big model with pseudo-spot batching.

Each epoch shuffles pseudo-spot indices and iterates over batch_num batches. For each batch it:

  1. Computes losses for the shared backbone on both slices.

  2. Generates exchanged predictions and trains regression mappers.

  3. Applies reconstruction-style mapping losses using the aggregation matrices.

Returns

None
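The per-epoch shuffling and splitting described above can be sketched as follows. The helper pseudo_spot_batches is hypothetical, and splitting into nearly equal chunks with np.array_split is an assumption about how batch_num batches are formed.

```python
import numpy as np


def pseudo_spot_batches(n_spots, batch_num, rng):
    """Shuffle pseudo-spot indices and split them into batch_num index arrays."""
    order = rng.permutation(n_spots)
    # array_split tolerates n_spots not being divisible by batch_num.
    return np.array_split(order, batch_num)


rng = np.random.default_rng(0)
for epoch in range(2):
    batches = pseudo_spot_batches(n_spots=23, batch_num=10, rng=rng)
    for batch_idx in batches:
        # A real loop would compute the three loss groups on these indices.
        pass
```

Each epoch visits every pseudo-spot exactly once, in a different order.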