"""
SpatialEx training utilities.
This module contains three trainer classes:
- :class:`Train_SpatialEx`: Trains two SpatialEx models (one per slice) and
evaluates cross-panel prediction quality via cosine similarity, SSIM, PCC,
and CMD.
- :class:`SpatialExP`: Trains SpatialEx+ with additional regression
mapping heads in a cycle-style setup to translate between gene panels.
- :class:`SpatialExP_Big`: Trains SpatialEx+ specifically for millons cells.
The trainers expect two AnnData slices whose `.obsm['he']` stores
histology-derived embeddings and whose `var_names` align with gene features.
Note:
This module imports project-specific components from sibling modules:
``model.Model``, ``model.Model_Plus``, ``model.Model_Big``, ``model.Regression``,
``utils.create_optimizer``, ``utils.Compute_metrics``, and preprocessing functions as ``pp``.
"""
import os
import torch
import random
import warnings
import numpy as np
import pandas as pd
from tqdm import tqdm
import scipy.sparse as sp
from . import preprocess as pp
from .utils import create_optimizer, Generate_pseudo_spot
from .model import Model, Model_Plus, Model_Big, Regression
warnings.filterwarnings("ignore")
[docs]
class SpatialEx:
"""Trainer for baseline SpatialEx on two slices.
This trainer fits two models (:attr:`module_HA` for slice 1 and
:attr:`module_HB` for slice 2) independently using hypergraph-based
batches, then evaluates cross-panel predictions at the end.
Attributes:
adata1 (AnnData): Slice 1.
adata2 (AnnData): Slice 2.
num_layers (int): Number of HGNN layers.
hidden_dim (int): Hidden width of the backbone.
epochs (int): Number of training epochs.
seed (int): Random seed.
device (torch.device): Device on which models are trained.
weight_decay (float): Weight decay for the optimizer.
optimizer (torch.optim.Optimizer): Optimizer instance.
batch_size (int): Batch size when building the hypergraph.
encoder (str): Encoder architecture key (e.g., ``"hgnn"``).
lr (float): Learning rate.
loss_fn (str): Loss function key (e.g., ``"mse"``).
num_neighbors (int): K for KNN used in hypergraph construction.
graph_kind (str): Spatial graph/hypergraph type (e.g., ``"spatial"``).
prune (int): Pruning threshold for dataloader construction.
save (bool): Whether to save the results.
"""
[docs]
def __init__(self,
adata1,
adata2,
graph1,
graph2,
num_layers=2,
hidden_dim=512,
epochs=500,
seed=0,
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
weight_decay=0,
optimizer="adam",
batch_size=4096,
encoder="hgnn",
lr=0.001,
loss_fn="mse",
num_neighbors=7,
graph_kind='spatial',
prune=10000,
save_path=None
):
self.adata1 = adata1
self.adata2 = adata2
self.graph1 = graph1
self.graph2 = graph2
self.num_layers = num_layers
self.hidden_dim = hidden_dim
self.epochs = epochs
self.seed = seed
self.device = device
self.weight_decay = weight_decay
self.batch_size = batch_size
self.encoder = encoder
self.lr = lr
self.loss_fn = loss_fn
self.num_neighbors = num_neighbors
self.graph_kind = graph_kind
self.prune = prune
self.save_path = save_path
self.in_dim1 = self.adata1.obsm['he'].shape[1]
self.in_dim2 = self.adata2.obsm['he'].shape[1]
self.out_dim1 = self.adata1.n_vars
self.out_dim2 = self.adata2.n_vars
self.module_HA = Model(self.num_layers, self.in_dim1, self.hidden_dim, self.out_dim1, self.loss_fn, self.device)
self.module_HB = Model(self.num_layers, self.in_dim2, self.hidden_dim, self.out_dim2, self.loss_fn, self.device)
self.models = [self.module_HA, self.module_HB]
self.optimizer = create_optimizer(optimizer, self.models, self.lr, self.weight_decay)
# H1 = pp.Build_hypergraph_spatial_and_HE(adata1, num_neighbors, batch_size, False, 'spatial', 'crs') # 注释掉
self.slice1_dataloader = pp.Build_dataloader(adata1, graph=graph1, graph_norm='hpnn', feat_norm=False,
prune=[prune, prune], drop_last=False)
# H2 = pp.Build_hypergraph_spatial_and_HE(adata2, num_neighbors, batch_size, False, 'spatial', 'crs') # 注释掉
self.slice2_dataloader = pp.Build_dataloader(adata2, graph=graph2, graph_norm='hpnn', feat_norm=False,
prune=[prune, prune], drop_last=False)
[docs]
def train(self):
"""Run the training loop and evaluate cross-panel predictions.
The method trains :attr:`module_HA` and :attr:`module_HB` jointly by
iterating over paired mini-batches from two slices. After training, it
predicts the missing panel on each slice and computes metrics at
gene-level (cosine similarity, SSIM, PCC, CMD).
self:
data_dir: Project root containing a ``datasets/`` folder with:
- ``Human_Breast_Cancer_Rep1/cell_feature_matrix.h5``
- ``Human_Breast_Cancer_Rep1/cells.csv``
- ``Human_Breast_Cancer_Rep2/cell_feature_matrix.h5``
- ``Human_Breast_Cancer_Rep2/cells.csv``
Prints:
Aggregated metrics per slice (cosine similarity, SSIM, PCC, CMD).
Raises:
FileNotFoundError: If any expected dataset file is missing.
Returns:
None
"""
pp.set_random_seed(self.seed)
self.module_HA.train()
self.module_HB.train()
print('\n')
print('=================================== Start training =========================================')
epoch_iter = tqdm(range(self.epochs))
for epoch in epoch_iter:
batch_iter = zip(self.slice1_dataloader, self.slice2_dataloader)
for data1, data2 in batch_iter:
graph1, he1, panel_1a, selection1 = data1[0]['graph'].to(self.device), data1[0]['he'].to(self.device), \
data1[0]['exp'].to(self.device), data1[0]['selection']
graph2, he2, panel_2b, selection2 = data2[0]['graph'].to(self.device), data2[0]['he'].to(self.device), \
data2[0]['exp'].to(self.device), data2[0]['selection']
agg_mtx1, agg_exp1 = data1[0]['agg_mtx'].to(self.device), data1[0]['agg_exp'].to(self.device)
agg_mtx2, agg_exp2 = data2[0]['agg_mtx'].to(self.device), data2[0]['agg_exp'].to(self.device)
loss1, _ = self.module_HA(graph1, he1, agg_exp1, agg_mtx1, selection1)
loss2, _ = self.module_HB(graph2, he2, agg_exp2, agg_mtx2, selection2)
loss = loss1 + loss2
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
epoch_iter.set_description(f"#Epoch: {epoch}: train_loss: {loss.item():.2f}")
'''========================= 测试 ========================'''
[docs]
def inference(self, he, graph, panel):
"""Predict gene expression for a given panel on a single slice.
This is a lightweight inference helper that runs the corresponding
trained model (:attr:`module_HA` for ``panelA`` or :attr:`module_HB` for
``panelB``) on the provided histology embedding and spatial graph.
Parameters
----------
he : array-like
Histology-derived embedding of shape ``(n_cells, n_he_features)``.
graph : scipy.sparse.spmatrix or compatible
Sparse adjacency / hypergraph matrix for the slice.
panel : {"panelA", "panelB"}, default "panelA"
Which trained model to use.
Returns
-------
numpy.ndarray
Predicted expression matrix of shape ``(n_cells, n_genes_in_panel)``.
Notes
-----
If :attr:`save_path` is set, predictions are saved as ``.npy`` files.
"""
he = torch.Tensor(he).to(self.device)
graph = pp.sparse_mx_to_torch_sparse_tensor(graph).to(self.device)
if panel == 'panelA':
self.module_HA.eval()
pred = self.module_HA.predict(he, graph).detach().cpu().numpy()
panel_name = "panelA"
elif panel == 'panelB':
self.module_HB.eval()
pred = self.module_HB.predict(he, graph).detach().cpu().numpy()
panel_name = "panelB"
if self.save_path is not None:
if not os.path.exists(self.save_path):
os.mkdir(self.save_path)
np.save(self.save_path + panel_name + '.npy', pred)
print(f'The results have been sucessfully saved in {self.save_path}') # 改成保存路径
return pred
[docs]
def auto_inference(self):
"""Run cross-panel prediction for both slices using internal dataloaders.
The method uses the trained models to predict the *missing* panel for
each slice:
- Slice 1: predict panel B using :attr:`module_HB`
- Slice 2: predict panel A using :attr:`module_HA`
Returns
-------
tuple[numpy.ndarray, numpy.ndarray]
``(panel_1b, panel_2a)`` where:
- ``panel_1b``: panel-B prediction for slice 1
- ``panel_2a``: panel-A prediction for slice 2
Notes
-----
If :attr:`save_path` is set, predictions are saved as ``B1.npy`` and
``A2.npy``.
"""
self.module_HA.eval()
self.module_HB.eval()
'''PanelA1'''
panel_1b = []
obs_list = []
for data in self.slice1_dataloader:
graph, he, obs = data[0]['graph'].to(self.device), data[0]['he'].to(self.device), data[0]['obs']
panelB1 = self.module_HB.predict(he, graph).detach().cpu().numpy()
panel_1b.append(panelB1)
obs_list = obs_list + obs
panel_1b = np.vstack(panel_1b)
panel_1b = pd.DataFrame(panel_1b)
panel_1b.columns = self.adata2.var_names ##修改
panel_1b = panel_1b.values
'''Panel2B'''
panel_2a = []
obs_list = []
for data in self.slice2_dataloader:
graph, he, obs = data[0]['graph'].to(self.device), data[0]['he'].to(self.device), data[0]['obs']
panel2A = self.module_HA.predict(he, graph).detach().cpu().numpy()
panel_2a.append(panel2A)
obs_list = obs_list + obs
panel_2a = np.vstack(panel_2a)
panel_2a = pd.DataFrame(panel_2a)
panel_2a.columns = self.adata1.var_names ##修改
panel_2a = panel_2a.values
if self.save_path is not None:
if not os.path.exists(self.save_path):
os.mkdir(self.save_path)
np.save(self.save_path + 'B1.npy', panel_1b)
np.save(self.save_path + 'A2.npy', panel_2a)
print(f'The results have been sucessfully saved in {self.save_path}') # 改成保存路径
return panel_1b, panel_2a
[docs]
class SpatialExP:
[docs]
def __init__(self,
adata1,
adata2,
graph1,
graph2,
use_agg = True, ##计算损失的时候使用原始分辨率还是spot分辨率
platform = 'Xenium',
seed=0,
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
weight_decay=0,
optimizer="adam",
batch_size=4096,
encoder="hgnn",
hidden_dim=512,
num_layers=2,
epochs=1000,
lr=0.001,
prune=10000,
loss_fn="mse",
num_neighbors=7,
graph_kind='spatial',
save_path=None
):
"""Initialize the SpatialEx+ trainer with cycle-style regression heads.
This trainer fits two SpatialEx+ backbones (one per slice) and two
regression mappers (:attr:`rm_AB`, :attr:`rm_BA`) to translate between
gene panels. During training it optimizes reconstruction losses and
cycle-consistency-style mapping losses.
Parameters
----------
adata1, adata2 : AnnData
Two slices with expression matrices in ``.X`` and histology
embeddings in ``.obsm['he']``.
graph1, graph2 : scipy.sparse.spmatrix or compatible
Spatial graphs for the two slices.
platform : str, default "Xenium"
Platform name forwarded to :class:`~model.Model_Plus`.
seed : int, default 0
Random seed.
device : torch.device, optional
Device on which models and tensors are placed.
weight_decay : float, default 0
Weight decay for the optimizer.
optimizer : str, default "adam"
Optimizer key understood by :func:`utils.create_optimizer`.
batch_size : int, default 4096
Kept for compatibility with other trainers (not used directly here).
encoder : str, default "hgnn"
Encoder key (kept for logging/compatibility).
hidden_dim : int, default 512
Hidden dimension of the backbone.
num_layers : int, default 2
Number of backbone layers.
epochs : int, default 1000
Training epochs.
lr : float, default 0.001
Learning rate.
loss_fn : str, default "mse"
Loss function key.
num_neighbors : int, default 7
K for KNN graph construction (kept for compatibility).
graph_kind : str, default "spatial"
Graph kind label (kept for compatibility).
save_path : str or None, optional
If provided, directory to save inference outputs.
"""
self.adata1 = adata1
self.adata2 = adata2
self.graph1 = pp.sparse_mx_to_torch_sparse_tensor(graph1).to(device)
self.graph2 = pp.sparse_mx_to_torch_sparse_tensor(graph2).to(device)
# self.graph1 = graph1
# self.graph2 = graph2
# 基础参数
self.seed = seed
self.device = device
self.weight_decay = weight_decay
self.batch_size = batch_size
self.encoder = encoder
self.hidden_dim = hidden_dim
self.num_layers = num_layers
self.epochs = epochs
self.lr = lr
self.loss_fn = loss_fn
self.save_path = save_path
self.use_agg = use_agg
self.platform = platform
# 空间参数
self.num_neighbors = num_neighbors
self.graph_kind = graph_kind
self.slice1_dataloader = pp.Build_dataloader(adata1, graph=graph1, graph_norm='hpnn', feat_norm=False,
prune=[prune, prune], drop_last=False)
self.slice2_dataloader = pp.Build_dataloader(adata2, graph=graph2, graph_norm='hpnn', feat_norm=False,
prune=[prune, prune], drop_last=False)
self.HE1, self.HE2 = torch.Tensor(adata1.obsm['he']).to(self.device), torch.Tensor(adata2.obsm['he']).to(self.device)
self.panelA1, self.panelB2 = torch.Tensor(adata1.X).to(self.device), torch.Tensor(adata2.X).to(self.device)
self.in_dim1 = adata1.obsm['he'].shape[1]
self.in_dim2 = adata2.obsm['he'].shape[1]
self.out_dim1 = adata1.n_vars
self.out_dim2 = adata2.n_vars
self.module_HA = Model_Plus(in_dim=self.in_dim1, hidden_dim=self.hidden_dim, out_dim=self.out_dim1, num_layers=self.num_layers,
platform=self.platform).to(self.device)
self.module_HB = Model_Plus(in_dim=self.in_dim2, hidden_dim=self.hidden_dim, out_dim=self.out_dim2, num_layers=self.num_layers,
platform=self.platform).to(self.device)
self.rm_AB = Regression(self.out_dim1, self.out_dim2, self.out_dim2).to(self.device)
self.rm_BA = Regression(self.out_dim2, self.out_dim1, self.out_dim1).to(self.device)
self.models = [self.module_HA, self.module_HB, self.rm_AB, self.rm_BA]
self.optimizer = create_optimizer(optimizer, self.models, self.lr, self.weight_decay)
[docs]
def train(self):
"""Train SpatialEx+ backbones and regression mappers.
The optimization combines:
1) Per-slice reconstruction losses from :class:`~model.Model_Plus`.
2) Cross-panel mapping losses for ``A->B`` and ``B->A`` via the regression
heads.
3) Cycle-style consistency losses mapping real panel expressions through
the opposite backbone.
Returns
-------
None
"""
pp.set_random_seed(self.seed)
self.module_HA.train()
self.module_HB.train()
self.rm_AB.train()
self.rm_BA.train()
print('\n')
print('=================================== Start training =========================================')
if self.platform == 'Xenium':
for epoch in tqdm(range(self.epochs)):
batch_iter = zip(self.slice1_dataloader, self.slice2_dataloader)
for data1, data2 in batch_iter:
graph1, he1, panel_1a = data1[0]['graph'].to(self.device), data1[0]['he'].to(self.device), data1[0]['exp'].to(self.device)
graph2, he2, panel_2b = data2[0]['graph'].to(self.device), data2[0]['he'].to(self.device), data2[0]['exp'].to(self.device)
agg_mtx1, agg_exp1 = data1[0]['agg_mtx'].to(self.device), data1[0]['agg_exp'].to(self.device)
agg_mtx2, agg_exp2 = data2[0]['agg_mtx'].to(self.device), data2[0]['agg_exp'].to(self.device)
loss1, _ = self.module_HA(he1, graph1, panel_1a, agg_exp1, agg_mtx1, self.use_agg)
loss2, _ = self.module_HB(he2, graph2, panel_2b, agg_exp2, agg_mtx2, self.use_agg)
panel_2a = self.module_HA.predict(he2, graph2, grad=False) ##对切片2的组学a进行预测
panel_1b = self.module_HB.predict(he1, graph1, grad=False) ##对切片1的组学b进行预测
loss3, _ = self.rm_AB(panel_1a, panel_1b, torch.spmm(agg_mtx1, panel_1b), agg_mtx1, self.use_agg) ##将切片1的组学a映射成切片1的组学b,与预测的组学b进行比较
loss4, _ = self.rm_BA(panel_2b, panel_2a, torch.spmm(agg_mtx2, panel_2a), agg_mtx2, self.use_agg) ##将切片2的组学b映射成切片2的组学a,与预测的组学a进行比较
loss5, _ = self.rm_AB(panel_2a, panel_2b, agg_exp2, agg_mtx2, self.use_agg) #对切片2的组学a进行预测,在映射回组学b,与切片2的真实标签进行比较
loss6, _ = self.rm_BA(panel_1b, panel_1a, agg_exp1, agg_mtx1, self.use_agg) #对切片1的组学b进行预测,在映射回组学a,与切片1的真实标签进行比较
loss = loss1 + loss2 + loss3 + loss4 + loss5 + loss6
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
elif self.platform == 'Visium': ##不能用dataloader,直接用全量数据训练
for epoch in tqdm(range(self.epochs)):
loss1, _ = self.module_HA(self.HE1, self.graph1, self.panelA1, use_agg=False)
loss2, _ = self.module_HB(self.HE2, self.graph2, self.panelB2, use_agg=False)
panelA2 = self.module_HA.predict(self.HE2, self.graph2, grad=False)
panelB1 = self.module_HB.predict(self.HE1, self.graph1, grad=False)
loss3, _ = self.rm_AB(panelA2, self.panelB2, use_agg=False)
loss4, _ = self.rm_BA(panelB1, self.panelA1, use_agg=False)
loss5, _ = self.rm_AB(self.panelA1, panelB1, use_agg=False)
loss6, _ = self.rm_BA(self.panelB2, panelA2, use_agg=False)
loss = loss1 + loss2 + loss3 + loss4 + loss5 + loss6
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
[docs]
def inference_direct(self, he, graph, panel):
"""Directly predict the specified panel with its corresponding backbone.
Parameters
----------
he : array-like
Histology embedding for the query slice.
graph : scipy.sparse.spmatrix or compatible
Sparse graph for the query slice.
panel : {"panelA", "panelB"}, default "panelA"
Which panel to predict directly.
Returns
-------
numpy.ndarray
Direct panel prediction of shape ``(n_cells, n_genes_in_panel)``.
Notes
-----
If :attr:`save_path` is set, outputs are saved as
``<panel>_direct.npy``.
"""
he = torch.Tensor(he).to(self.device)
graph = pp.sparse_mx_to_torch_sparse_tensor(graph).to(self.device)
if panel == 'panelA':
self.module_HA.eval()
omics_direct = self.module_HA.predict(he, graph, grad=False)
if panel == 'panelB':
self.module_HB.eval()
omics_direct = self.module_HB.predict(he, graph, grad=False)
if self.save_path is not None:
if not os.path.exists(self.save_path):
os.mkdir(self.save_path)
np.save(self.save_path + panel + '_direct.npy', omics_direct.detach().cpu().numpy())
print(f'The results have been sucessfully saved in {self.save_path}')
return omics_direct.detach().cpu().numpy()
[docs]
def inference_indirect(self, he, graph, panel):
"""Indirectly infer the missing panel using a regression mapper.
For ``panelB`` inference, the method first predicts panel A with
:attr:`module_HA`, then maps to panel B using :attr:`rm_AB`. For
``panelA`` inference it uses :attr:`module_HB` followed by
:attr:`rm_BA`.
Parameters
----------
he : array-like
Histology embedding for the query slice.
graph : scipy.sparse.spmatrix or compatible
Sparse graph for the query slice.
panel : {"panelA", "panelB"}, default "panelA"
Which panel to infer indirectly.
Returns
-------
numpy.ndarray
Indirect panel prediction of shape ``(n_cells, n_genes_in_target_panel)``.
Notes
-----
If :attr:`save_path` is set, outputs are saved as ``omics.npy``.
"""
he = torch.Tensor(he).to(self.device)
graph = pp.sparse_mx_to_torch_sparse_tensor(graph).to(self.device)
if panel == 'panelB':
self.module_HA.eval()
self.rm_AB.eval()
panelA1_direct = self.module_HA.predict(he, graph, grad=False)
omics_indirect = self.rm_AB.predict(panelA1_direct)
omics_indirect = omics_indirect.detach().cpu().numpy()
if panel == 'panelA':
self.module_HB.eval()
self.rm_BA.eval()
panelB2_direct = self.module_HB.predict(he, graph, grad=False)
omics_indirect = self.rm_BA.predict(panelB2_direct)
omics_indirect = omics_indirect.detach().cpu().numpy()
if self.save_path:
if not os.path.exists(self.save_path):
os.mkdir(self.save_path)
np.save(self.save_path + 'omics.npy', omics_indirect)
print(f'The results have been sucessfully saved in {self.save_path}')
return omics_indirect
'''========================= 测试 ========================'''
[docs]
def auto_inference(self):
"""Run indirect cross-panel prediction for both slices.
Returns
-------
tuple[numpy.ndarray, numpy.ndarray]
``(panelB1_indirect, panelA2_indirect)`` predictions for slice 1 and
slice 2 respectively.
Notes
-----
If :attr:`save_path` is set, outputs are saved as ``B1.npy`` and
``A2.npy``.
"""
self.module_HA.eval()
self.module_HB.eval()
self.rm_AB.eval()
self.rm_BA.eval()
'''PanelB1'''
panelA1_direct = self.module_HA.predict(self.HE1, self.graph1, grad=False)
panelB1_indirect = self.rm_AB.predict(panelA1_direct).detach().cpu().numpy()
'''PanelA2'''
panelB2_direct = self.module_HB.predict(self.HE2, self.graph2, grad=False)
panelA2_indirect = self.rm_BA.predict(panelB2_direct).detach().cpu().numpy()
if self.save_path is not None:
if not os.path.exists(self.save_path):
os.mkdir(self.save_path)
np.save(self.save_path + 'B1.npy', panelB1_indirect)
np.save(self.save_path + 'A2.npy', panelA2_indirect)
print(f'The results have been sucessfully saved in {self.save_path}')
return panelB1_indirect, panelA2_indirect
[docs]
class SpatialExP_Big:
[docs]
def __init__(self,
adata1,
adata2,
graph1,
graph2,
use_agg = True, ##计算损失的时候使用原始分辨率还是spot分辨率
num_layers=2,
hidden_dim=512,
epochs=200,
seed=0,
device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
weight_decay=0,
optimizer="adam",
batch_size=4096,
batch_num=10,
encoder="hgnn",
lr=0.001,
loss_fn="mse",
num_neighbors=7,
graph_kind='spatial',
save_path=None
):
"""Initialize the large-scale SpatialEx+ trainer using pseudo-spots.
This variant aggregates single-cell expression into pseudo-spots to
reduce memory and enable training on very large datasets. It trains a
shared big backbone (:attr:`model_big`) plus two regression mappers
(:attr:`model_AB`, :attr:`model_BA`) for cross-panel translation.
Parameters
----------
adata1, adata2 : AnnData
Two slices with histology embeddings in ``.obsm['he']`` and
expression in ``.X``.
graph1, graph2 : scipy.sparse.spmatrix or compatible
Spatial graphs for the two slices.
num_layers : int, default 2
Number of backbone layers.
hidden_dim : int, default 512
Hidden dimension of the backbone.
epochs : int, default 200
Number of training epochs.
seed : int, default 0
Random seed.
device : torch.device, optional
Device to run on.
weight_decay : float, default 0
Weight decay for the optimizer.
optimizer : str, default "adam"
Optimizer key.
batch_size : int, default 4096
Kept for compatibility (batching here is driven by ``batch_num``).
batch_num : int, default 10
Number of pseudo-spot batches per epoch.
encoder : str, default "hgnn"
Encoder key (kept for compatibility).
lr : float, default 0.001
Learning rate.
loss_fn : str, default "mse"
Loss function key.
num_neighbors : int, default 7
K for KNN (kept for compatibility).
graph_kind : str, default "spatial"
Graph kind label (kept for compatibility).
save_path : str or None, optional
Directory to save inference outputs.
"""
self.adata1 = adata1
self.adata2 = adata2
self.graph1 = graph1,
self.graph2 = graph2,
self.num_layers = num_layers
self.hidden_dim = hidden_dim
self.epochs = epochs
self.seed = seed
self.device = device
self.weight_decay = weight_decay
self.use_agg = use_agg
self.batch_size = batch_size
self.batch_num = batch_num
self.encoder = encoder
self.lr = lr
self.loss_fn = loss_fn
self.num_neighbors = num_neighbors
self.graph_kind = graph_kind
self.save_path = save_path
self.in_dim1 = self.adata1.obsm['he'].shape[1]
self.in_dim2 = self.adata2.obsm['he'].shape[1]
self.out_dim1 = self.adata1.n_vars
self.out_dim2 = self.adata2.n_vars
# H1 = pp.Build_hypergraph_spatial_and_HE(adata1, num_neighbors, batch_size, False, 'spatial', 'crs')
_, _, adata1 = Generate_pseudo_spot(adata1, all_in=True)
spot_id = adata1.obs['spot'].values
head = spot_id[~pd.isna(adata1.obs['spot'])].astype(int)
tail = np.where(~pd.isna(adata1.obs['spot']))[0]
values = np.ones_like(tail)
self.agg_mtx1 = sp.coo_matrix((values, (head, tail)), shape=(head.max() + 1, adata1.n_obs)).tocsr()
self.spot_A1 = torch.Tensor(self.agg_mtx1 @ adata1.X)
# H2 = pp.Build_hypergraph_spatial_and_HE(adata2, num_neighbors, batch_size, False, 'spatial', 'crs')
_, _, adata2 = Generate_pseudo_spot(adata2, all_in=True)
spot_id = adata2.obs['spot'].values
head = spot_id[~pd.isna(adata2.obs['spot'])].astype(int)
tail = np.where(~pd.isna(adata2.obs['spot']))[0]
values = np.ones_like(tail)
self.agg_mtx2 = sp.coo_matrix((values, (head, tail)), shape=(head.max()+1, adata2.n_obs)).tocsr()
self.spot_B2 = torch.Tensor(self.agg_mtx2 @ adata2.X)
self.HE1, self.HE2 = torch.Tensor(adata1.obsm['he']), torch.Tensor(adata2.obsm['he'])
self.panelA1, self.panelB2 = torch.Tensor(adata1.X), torch.Tensor(adata2.X)
self.model_big = Model_Big([graph1, graph2], [self.in_dim1, self.in_dim2], [self.out_dim1, self.out_dim2], num_layers=self.num_layers,
hidden_dim=self.hidden_dim, device=self.device).to(self.device)
self.model_AB = Regression(self.out_dim1, int(self.out_dim1/2), self.out_dim2).to(self.device)
self.model_BA = Regression(self.out_dim2, int(self.out_dim1/2), self.out_dim1).to(self.device)
self.models = [self.model_big, self.model_AB, self.model_BA]
self.optimizer = create_optimizer(optimizer, self.models, self.lr, self.weight_decay)
[docs]
def train(self):
"""Train the big model with pseudo-spot batching.
Each epoch shuffles pseudo-spot indices and iterates over ``batch_num``
batches. For each batch it:
1) Computes losses for the shared backbone on both slices.
2) Generates exchanged predictions and trains regression mappers.
3) Applies reconstruction-style mapping losses using the aggregation
matrices.
Returns
-------
None
"""
batch_num = self.batch_num
obs_index1 = list(range(self.agg_mtx1.shape[0]))
obs_index2 = list(range(self.agg_mtx2.shape[0]))
batch_size1 = int(self.agg_mtx1.shape[0]/batch_num)
batch_size2 = int(self.agg_mtx2.shape[0]/batch_num)
for epoch in range(self.epochs):
random.shuffle(obs_index1)
random.shuffle(obs_index2)
batch_iter = tqdm(range(batch_num), leave=False)
for batch_idx in batch_iter:
torch.cuda.empty_cache()
tgt_spot1 = obs_index1[batch_idx*batch_size1:(batch_idx+1)*batch_size1]
tgt_cell1 = self.agg_mtx1[tgt_spot1].tocoo().col
sub_agg_mtx1 = self.agg_mtx1[tgt_spot1][:,tgt_cell1]
sub_agg_mtx1 = pp.sparse_mx_to_torch_sparse_tensor(sub_agg_mtx1).to(self.device)
spot_A1_batch = self.spot_A1[tgt_spot1].to(self.device)
tgt_spot2 = obs_index2[batch_idx*batch_size2:(batch_idx+1)*batch_size2]
tgt_cell2 = self.agg_mtx2[tgt_spot2].tocoo().col
sub_agg_mtx2 = self.agg_mtx2[tgt_spot2][:,tgt_cell2]
sub_agg_mtx2 = pp.sparse_mx_to_torch_sparse_tensor(sub_agg_mtx2).to(self.device)
spot_B2_batch = self.spot_B2[tgt_spot2].to(self.device)
loss1, loss2 = self.model_big([tgt_cell1, tgt_cell2], [self.HE1, self.HE2], [spot_A1_batch, spot_B2_batch], [sub_agg_mtx1, sub_agg_mtx2])
x_prime = self.model_big.predict([tgt_cell1, tgt_cell2], [self.HE1, self.HE2], exchange=True, which='both', grad=False)
panel_A2, panel_B1 = x_prime[0], x_prime[1]
loss3, _ = self.model_AB(panel_A2, self.panelB2[tgt_cell2].to(self.device), spot_B2_batch, sub_agg_mtx2, self.use_agg)
loss4, _ = self.model_BA(panel_B1, self.panelA1[tgt_cell1].to(self.device), spot_A1_batch, sub_agg_mtx1, self.use_agg)
loss5, _ = self.model_AB(self.panelA1[tgt_cell1].to(self.device), panel_B1, torch.spmm(sub_agg_mtx1, panel_B1), sub_agg_mtx1, self.use_agg)
loss6, _ = self.model_BA(self.panelB2[tgt_cell2].to(self.device), panel_A2, torch.spmm(sub_agg_mtx2, panel_A2), sub_agg_mtx2, self.use_agg)
loss = loss1 + loss2 + loss3 + loss4 + loss5 + loss6
self.optimizer.zero_grad()
loss.backward()
self.optimizer.step()
batch_iter.set_description(f"#Epoch {epoch}, loss: {round(loss.item(), 2)}")
[docs]
def auto_inference(self):
"""Run indirect cross-panel prediction for both original slices.
Returns
-------
tuple[numpy.ndarray, numpy.ndarray]
``(indirect_panel_B1, indirect_panel_A2)`` predictions.
Notes
-----
If :attr:`save_path` is set, outputs are saved as ``B1.npy`` and
``A2.npy``.
"""
self.model_big.eval()
self.model_AB.eval()
self.model_BA.eval()
obs_index1 = list(range(self.HE1.shape[0]))
obs_index2 = list(range(self.HE2.shape[0]))
batch_size1 = int(np.ceil(self.HE1.shape[0]/self.batch_num))
batch_size2 = int(np.ceil(self.HE2.shape[0]/self.batch_num))
batch_iter = tqdm(range(self.batch_num), leave=False)
indirect_panel_B1_list = []
indirect_panel_A2_list = []
tgt_id1_list = []
tgt_id2_list = []
for batch_idx in batch_iter:
tgt_id1 = obs_index1[batch_idx*batch_size1:min((batch_idx+1)*batch_size1, self.HE1.shape[0])]
tgt_id2 = obs_index2[batch_idx*batch_size2:min((batch_idx+1)*batch_size2, self.HE2.shape[0])]
x_prime = self.model_big.predict([tgt_id1, tgt_id2], [self.HE1, self.HE2], exchange=False, which='both')
panel_A1_predict, panel_B2_predict = x_prime[0], x_prime[1]
indirect_panel_B1 = self.model_AB.predict(panel_A1_predict)
indirect_panel_A2 = self.model_BA.predict(panel_B2_predict)
tgt_id1_list = tgt_id1_list + tgt_id1
tgt_id2_list = tgt_id2_list + tgt_id2
indirect_panel_A2_list.append(indirect_panel_A2.detach().cpu().numpy())
indirect_panel_B1_list.append(indirect_panel_B1.detach().cpu().numpy())
indirect_panel_A2_list = np.vstack(indirect_panel_A2_list)
indirect_panel_B1_list = np.vstack(indirect_panel_B1_list)
if self.save_path is not None:
if not os.path.exists(self.save_path):
os.mkdir(self.save_path)
np.save(self.save_path + 'B1.npy', indirect_panel_B1_list)
np.save(self.save_path + 'A2.npy', indirect_panel_A2_list)
print(f'The results have been sucessfully saved in {self.save_path}')
return indirect_panel_B1_list, indirect_panel_A2_list