spcoral.model.regist_model
- class spcoral.model.regist_model(adata_omics1, adata_omics2, graph_method, k_spatial_omics1=None, radius_spatial_omics1=None, k_spatial_omics2=None, radius_spatial_omics2=None, use_obsm='spatial', n_layer=[1, 2, 3, 4, 5], alpha=0.1, device=device(type='cuda', index=0), random_seed=2024, strict_repro=False, learning_rate=0.001, weight_decay=0.0001, epochs=100, gradient_clipping=5.0, hidden_dim_shared=32, out_dim_shared=8, hidden_dim_pcc=32, out_dim_pcc=8, GAN_batch_d_per_iter=5)
Bases: object
A PyTorch-based model for cross-modality spatial omics integration via graph attention networks and adversarial alignment.
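Each modality is turned into a spatial neighbor graph from the coordinates in .obsm[use_obsm]. The graph uses either a fixed number of neighbors (graph_method='knn') or a distance cutoff (graph_method='radius'). The plain-Python sketch below contrasts the two rules on raw coordinates. It assumes that edges are directed (source, target) pairs without self-loops. It is not the adata_to_dgl implementation, and the helper names knn_edges and radius_edges are hypothetical.

```python
import math


def knn_edges(coords, k):
    """Hypothetical helper: connect each spot to its k nearest other spots."""
    edges = set()
    for i, p in enumerate(coords):
        # Sort all other spots by Euclidean distance to spot i.
        others = sorted((math.dist(p, q), j) for j, q in enumerate(coords) if j != i)
        for _, j in others[:k]:
            edges.add((i, j))
    return edges


def radius_edges(coords, radius):
    """Hypothetical helper: connect every pair of distinct spots within `radius`."""
    edges = set()
    for i, p in enumerate(coords):
        for j, q in enumerate(coords):
            if i != j and math.dist(p, q) <= radius:
                edges.add((i, j))
    return edges


# Four spots on a line; spot 3 lies far from the others.
coords = [(0.0, 0.0), (1.0, 0.0), (2.2, 0.0), (7.0, 0.0)]
print(sorted(knn_edges(coords, k=1)))            # [(0, 1), (1, 0), (2, 1), (3, 2)]
print(sorted(radius_edges(coords, radius=1.5)))  # [(0, 1), (1, 0), (1, 2), (2, 1)]
```

With k=1 every spot gets exactly one outgoing edge, including the distant spot 3. With radius=1.5, spot 3 is left isolated. A radius graph can therefore contain isolated spots, while a kNN graph cannot.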
- Parameters:
adata_omics1 (anndata.AnnData) – First omics AnnData object. Must contain obsm['feat'] (modality-specific features) and spatial coordinates (via use_obsm).
adata_omics2 (anndata.AnnData) – Second omics AnnData object with the same requirements.
graph_method (str) – Method used to construct spatial graphs (passed to adata_to_dgl; e.g., 'knn' or 'radius').
k_spatial_omics1 (int, optional) – Number of nearest neighbors for the omics1 graph when graph_method='knn'.
radius_spatial_omics1 (float, optional) – Radius for the omics1 graph when graph_method='radius'.
k_spatial_omics2 (int, optional) – Number of nearest neighbors for the omics2 graph when graph_method='knn'.
radius_spatial_omics2 (float, optional) – Radius for the omics2 graph when graph_method='radius'.
use_obsm (str, optional (default: 'spatial')) – Key in .obsm containing spatial coordinates.
n_layer (list of int, optional (default: [1, 2, 3, 4, 5])) – Orders of neighborhood aggregation (hop distances) for morphological feature computation. The actual layers used are [0] + n_layer.
alpha (float, optional (default: 0.1)) – Weight balancing the reconstruction loss against the adversarial (GAN) loss.
device (torch.device, optional (default: torch.device('cuda:0'))) – Device for training.
random_seed (int, optional (default: 2024)) – Random seed for reproducibility.
strict_repro (bool, optional (default: False)) – If True, enforces stricter reproducibility (e.g., deterministic CUDA operations).
learning_rate (float, optional (default: 0.001)) – Learning rate for the main model optimizer.
weight_decay (float, optional (default: 0.0001)) – Weight decay for both optimizers (the main model optimizer and the GAN discriminator optimizer).
epochs (int, optional (default: 100)) – Number of training epochs.
gradient_clipping (float, optional (default: 5.0)) – Maximum gradient norm for clipping.
hidden_dim_shared (int, optional (default: 32)) – Hidden dimension in the shared feature branch.
out_dim_shared (int, optional (default: 8)) – Output dimension of the shared embedding.
hidden_dim_pcc (int, optional (default: 32)) – Hidden dimension in modality-specific (PCC) branches.
out_dim_pcc (int, optional (default: 8)) – Output dimension of modality-specific embeddings.
GAN_batch_d_per_iter (int, optional (default: 5)) – Number of discriminator updates per generator update in the Wasserstein GAN.
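The n_layer description states that the hop distances actually used are [0] + n_layer. Hop 0 (a spot's own features) is therefore always included. The sketch below shows one common way to build such hop-indexed features: repeated mean aggregation over spatial neighbors. The averaging rule, the helper names and the toy path graph are assumptions for illustration, not spcoral's implementation.

```python
def neighbor_mean(features, neighbors):
    """One aggregation round: replace each spot's vector by the mean of its neighbors' vectors."""
    out = []
    for i, vec in enumerate(features):
        nbrs = neighbors[i] or [i]  # isolated spot: keep its own features
        out.append([sum(features[j][d] for j in nbrs) / len(nbrs)
                    for d in range(len(vec))])
    return out


def multi_hop_features(features, neighbors, n_layer):
    """Hypothetical helper: features after k rounds of averaging, for k in [0] + n_layer."""
    result = {}
    current, current_hop = features, 0
    for hop in sorted(set([0] + list(n_layer))):
        # Keep aggregating until the requested hop distance is reached.
        while current_hop < hop:
            current = neighbor_mean(current, neighbors)
            current_hop += 1
        result[hop] = current
    return result


# Path graph 0 - 1 - 2 with one feature per spot.
features = [[0.0], [2.0], [10.0]]
neighbors = [[1], [0, 2], [1]]
hops = multi_hop_features(features, neighbors, n_layer=[1, 2])
print(hops[1])  # [[2.0], [5.0], [2.0]]
print(hops[2])  # [[5.0], [2.0], [5.0]]
```

Each hop reuses the previous hop's result, so the default n_layer=[1, 2, 3, 4, 5] costs five aggregation rounds rather than fifteen.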
- adata_omics1, adata_omics2
Processed AnnData objects; after training they also hold obsm['embedding'] and obsm['share_feature'].
- Type:
anndata.AnnData
- model
The main COM_NET encoder-decoder model (instantiated during train()).
- Type:
COM_NET
- __init__(adata_omics1, adata_omics2, graph_method, k_spatial_omics1=None, radius_spatial_omics1=None, k_spatial_omics2=None, radius_spatial_omics2=None, use_obsm='spatial', n_layer=[1, 2, 3, 4, 5], alpha=0.1, device=device(type='cuda', index=0), random_seed=2024, strict_repro=False, learning_rate=0.001, weight_decay=0.0001, epochs=100, gradient_clipping=5.0, hidden_dim_shared=32, out_dim_shared=8, hidden_dim_pcc=32, out_dim_pcc=8, GAN_batch_d_per_iter=5)
- Parameters:
adata_omics1 (AnnData)
adata_omics2 (AnnData)
graph_method (str)
k_spatial_omics1 (int | None)
radius_spatial_omics1 (float | None)
k_spatial_omics2 (int | None)
radius_spatial_omics2 (float | None)
use_obsm (str)
alpha (float)
device (device)
random_seed (int)
strict_repro (bool)
learning_rate (float)
weight_decay (float)
epochs (int)
gradient_clipping (float)
hidden_dim_shared (int)
out_dim_shared (int)
hidden_dim_pcc (int)
out_dim_pcc (int)
GAN_batch_d_per_iter (int)
Methods
- train()
Train the integration model.
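train() combines several constructor parameters: epochs, GAN_batch_d_per_iter discriminator (critic) updates per generator update, the alpha weight between the reconstruction and adversarial losses, and gradient-norm clipping at gradient_clipping. The framework-free sketch below shows only that control flow. It makes these assumptions:

- one full-batch generator update per epoch;
- a combined loss of recon + alpha * gan;
- hypothetical step callables and helper names.

The clipping rule is the standard global-norm rule, which torch.nn.utils.clip_grad_norm_ also applies (up to a small epsilon). None of this is taken from the spcoral source.

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Scale a flat list of gradient values so their L2 norm is at most max_norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        return list(grads)
    scale = max_norm / total_norm
    return [g * scale for g in grads]


def train_schedule(epochs, n_critic, alpha, max_norm,
                   critic_step, generator_losses, apply_update):
    """Hypothetical control flow: n_critic critic updates, then one generator update, per epoch."""
    loss_list = []
    for _ in range(epochs):
        # Several discriminator (critic) updates per generator update, as in a Wasserstein GAN.
        for _ in range(n_critic):
            critic_step()
        recon, gan, grads = generator_losses()
        total = recon + alpha * gan  # assumed weighting of the two losses
        apply_update(total, clip_by_global_norm(grads, max_norm))
        loss_list.append([recon, gan])  # one [reconstruction, GAN] pair per epoch
    return loss_list


# Toy usage with counting stubs in place of real networks.
calls = {"critic": 0, "gen": 0}
updates = []


def critic_step():
    calls["critic"] += 1


def generator_losses():
    calls["gen"] += 1
    return 1.0, 0.5, [3.0, 4.0]  # gradient vector with L2 norm 5.0


def apply_update(total, grads):
    updates.append((total, grads))


history = train_schedule(epochs=3, n_critic=5, alpha=0.1, max_norm=2.5,
                         critic_step=critic_step,
                         generator_losses=generator_losses,
                         apply_update=apply_update)
print(calls)          # {'critic': 15, 'gen': 3}
print(updates[0][1])  # [1.5, 2.0]
```

With the defaults (epochs=100, GAN_batch_d_per_iter=5), this schedule would perform 500 critic updates and 100 generator updates.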
- Returns:
adata_omics1 : anndata.AnnData – Updated first omics object with obsm['embedding'] and obsm['share_feature'].
adata_omics2 : anndata.AnnData – Updated second omics object with the same added keys.
loss_list : list – Training loss history, one [reconstruction_loss, gan_loss] pair per epoch.
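The loss_list returned by train() holds one [reconstruction_loss, gan_loss] pair per epoch. The minimal sketch below splits that history into two series. It then checks how much the reconstruction loss still changes over the last few epochs. The helper names and the relative-change heuristic are hypothetical. The numbers are made-up toy values, not spcoral output.

```python
def split_loss_history(loss_list):
    """Split [[reconstruction_loss, gan_loss], ...] into two per-epoch series."""
    # float() also accepts 0-d torch tensors and NumPy scalars, in case the
    # stored losses are not plain Python floats.
    recon = [float(r) for r, _ in loss_list]
    gan = [float(g) for _, g in loss_list]
    return recon, gan


def relative_change(series, window=10):
    """Relative change between the first and last values of the trailing window."""
    tail = series[-window:]
    start, end = tail[0], tail[-1]
    return abs(end - start) / max(abs(start), 1e-12)


# From a regist_model instance reg: _, _, loss_list = reg.train()
loss_list = [[1.0, 0.2], [0.8, 0.1], [0.5, 0.05], [0.4, 0.04]]  # toy values
recon, gan = split_loss_history(loss_list)
print(recon)  # [1.0, 0.8, 0.5, 0.4]
print(relative_change(recon, window=2))
```

A small relative change over the trailing window suggests the reconstruction loss has flattened. What counts as "small" is a judgment call that depends on the data.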
- Return type: