## Model Summary
ScDiVa (Single-cell Deep Variational Analysis) is a 94.5M parameter foundation model pre-trained on 59 million single-cell transcriptomes. It utilizes a novel Masked Discrete Diffusion framework to model gene expression as an unordered set, effectively capturing the complex topology of gene regulatory networks.
Unlike traditional autoregressive models, ScDiVa employs a bidirectional Transformer encoder with SwiGLU activations, Rotary Positional Embeddings (RoPE), and RMSNorm, optimized for:
- Reconstruction
- Cell Type Annotation
- Multi-batch Integration
- Gene Perturbation Prediction
- Gene Regulatory Network (GRN) Inference
## Model Specifications
| Attribute | Value |
|---|---|
| Parameters | ~94.5M |
| Layers | 12 |
| Hidden Size | 512 |
| Attention Heads | 8 |
| Max Sequence Length | 1,200 genes |
| Vocabulary | 41,818 genes |
| Training Objective | Dual Denoising (Identity Classification + Value Regression) |
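If you want to inspect these hyperparameters programmatically, the repository's `config.json` (listed in the repository structure below) can be fetched with `huggingface_hub`. This is a minimal sketch; the exact key names inside the file are not documented here and may not match the table labels above.

```python
import json
from huggingface_hub import hf_hub_download

# Download config.json from the model repository and print its contents
config_path = hf_hub_download(repo_id="warming666/ScDiVa", filename="config.json")
with open(config_path) as f:
    config = json.load(f)

print(json.dumps(config, indent=2))  # key names may differ from the spec table above
```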
## Quick Start
To use ScDiVa, you need the `modeling_scdiva.py` file (included in this repository).
### 1. Installation
```bash
pip install torch numpy huggingface_hub
```
### 2. Loading the Pre-trained Model
You can load the model directly using the `from_pretrained` method defined in our architecture.
```python
from modeling_scdiva import ScDiVaModel
import torch

# Load the model directly from Hugging Face
# This will automatically download model.safetensors and the config
model = ScDiVaModel.from_pretrained("warming666/ScDiVa")
model.eval()

# Move to GPU if available
device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

print(f"ScDiVa loaded successfully on {device}")
```
### 3. Basic Inference Example
```python
# Create a dummy input (batch size: 2, num genes: 41,818)
# In practice, replace this with your normalized gene expression matrix
input_data = torch.randn(2, 41818).to(device)

with torch.no_grad():
    # Get latent embeddings (for clustering/integration)
    outputs = model.encode(input_data)
    embeddings = outputs['latent']
    print(f"Latent Embedding Shape: {embeddings.shape}")  # [2, 128]

    # Get annotation logits
    predictions = model.predict(input_data, task="annotation")
    print(f"Annotation Logits Shape: {predictions.shape}")  # [2, 100]
```
## Repository Structure
This repository contains the core pre-trained weights and fine-tuned checkpoints for downstream tasks.
```
warming666/ScDiVa
├── config.json               # Model configuration
├── model.safetensors         # Pre-trained Base Weights (94.5M)
├── modeling_scdiva.py        # Model architecture definition code
└── downstream/               # Fine-tuned Checkpoints
    ├── Multi-batch_Integration/
    │   ├── immune.pt
    │   ├── pbmc12k.pt
    │   └── ...
    ├── Annotation_FT/        # Fine-tuned for specific tissues
    │   ├── hpancreas.pt
    │   └── ms.pt
    ├── Annotation_Zeroshot/  # Weights for zero-shot projection
    └── Perturbation/         # Weights for gene perturbation tasks
```
To load a specific downstream model (e.g., for batch integration on the Immune dataset), download the corresponding `.pt` file from the `downstream/` folder and load it with `torch.load()`, as sketched below.
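A minimal sketch of that workflow, assuming the `.pt` file contains a plain state dict compatible with the base architecture (the exact checkpoint format is not documented here, so verify the keys before relying on it):

```python
import torch
from huggingface_hub import hf_hub_download
from modeling_scdiva import ScDiVaModel

# Download one fine-tuned checkpoint from the downstream folder
ckpt_path = hf_hub_download(
    repo_id="warming666/ScDiVa",
    filename="downstream/Multi-batch_Integration/immune.pt",
)

# Load the base architecture, then apply the fine-tuned weights.
# Assumption: the .pt file stores a plain state dict; adjust if it wraps
# the weights in another structure (e.g. {"state_dict": ...}).
model = ScDiVaModel.from_pretrained("warming666/ScDiVa")
state_dict = torch.load(ckpt_path, map_location="cpu")
model.load_state_dict(state_dict, strict=False)
model.eval()
```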
## Benchmarks
ScDiVa achieves state-of-the-art performance across multiple benchmarks:
- Batch Integration: Top-tier performance on PBMC12k (Avg-Bio: 0.9566) and BMMC datasets.
- Annotation: 98.6% accuracy on hPancreas fine-tuning; 91.4% average accuracy on zero-shot tasks.
- Perturbation: Pearson correlation of 0.837 on Adamson dataset.
For detailed results, please refer to our arXiv paper.
## Limitations & Bias
- Input Normalization: The model expects log-normalized gene expression data; raw counts may lead to suboptimal performance (see the preprocessing sketch after this list).
- Gene Vocabulary: Inputs must be aligned to the specific 41,818 gene vocabulary used during pre-training.
- Not for Clinical Use: This model is for research purposes only and has not been validated for clinical diagnosis or treatment.
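For the normalization point above, here is a minimal preprocessing sketch using Scanpy, under the assumption that standard total-count normalization followed by log1p approximates the pre-training recipe (the exact normalization target is not specified in this card). Aligning genes to the 41,818-gene vocabulary is left as a note because the vocabulary file and ordering are not documented here.

```python
import scanpy as sc

adata = sc.read_h5ad("my_dataset.h5ad")  # hypothetical input file

# Total-count normalize and log-transform (assumed to approximate the
# pre-training normalization; verify against the paper).
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)

# NOTE: inputs must still be aligned/reordered to ScDiVa's 41,818-gene
# vocabulary before being passed to the model; this step is dataset-specific.
```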
## Citation
If you use ScDiVa in your research, please cite:
```bibtex
@article{wang2026scdiva,
  title={ScDiva: Masked Discrete Diffusion for Joint Modeling of Single-Cell Identity and Expression},
  author={Wang, Mingxuan and Chen, Cheng and Jiang, Gaoyang and Ren, Zijia and Zhao, Chuangxin and Shi, Lu and Ma, Yanbiao},
  journal={arXiv preprint arXiv:2602.03477},
  year={2026}
}
```