Using Pretrained EpiAgent for Zero-Shot Cell Embeddings on Your Dataset
This tutorial provides an example workflow using the pretrained EpiAgent model, trained on the large-scale Human-scATAC-Corpus (~5 million cells and 35 billion tokens). It demonstrates zero-shot inference to extract cell embeddings from the demo dataset. You can follow a similar procedure to apply this method to your own dataset.
Required files:
pretrained_EpiAgent.pth: The pretrained EpiAgent model.
cCRE_document_frequency.npy: The cCRE document frequencies used as input to the TF-IDF transformation.
These files are available at the following link:
Step 1: Data Processing (TF-IDF and Tokenization)
To prepare the dataset for EpiAgent, we perform the following preprocessing steps:
1. TF-IDF Transformation: Convert discrete count data into continuous importance scores for accessible cCREs.
2. Tokenization: Generate cell_sentences to represent each cell as a sequence of accessible cCRE indices.
import scanpy as sc
import numpy as np
from epiagent.tokenization import tokenization
from epiagent.preprocessing import global_TFIDF
# Load the dataset
input_path = './demo-cell_embedding.h5ad'
adata = sc.read_h5ad(input_path)
# Load the cCRE document frequency data
cCRE_document_frequency = np.load('./cCRE_document_frequency.npy')
# Apply TFIDF transformation
global_TFIDF(adata, cCRE_document_frequency)
# Perform tokenization
tokenization(adata)
Tokenization complete: 'cell_sentences' column added to adata.obs.
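To make the two preprocessing steps above more concrete, here is a minimal pure-Python sketch on toy data. It is not the implementation behind global_TFIDF or tokenization: the function name tfidf_and_tokenize, the exact TF-IDF formula, the interpretation of document frequency as "number of reference cells in which a cCRE is accessible", and the choice to order tokens by descending score are all assumptions made for illustration.

```python
import math

def tfidf_and_tokenize(counts, document_frequency, n_reference_cells):
    """Toy TF-IDF scoring and tokenization for a single cell (illustrative only).

    counts: per-cCRE accessibility counts for one cell
    document_frequency: assumed to be the number of reference cells in which
        each cCRE is accessible (must be > 0 for accessible cCREs here)
    n_reference_cells: assumed size of the reference corpus
    """
    total = sum(counts)
    scores = {}
    for idx, count in enumerate(counts):
        if count == 0:
            continue  # inaccessible cCREs get no score and no token
        tf = count / total
        idf = math.log(1 + n_reference_cells / document_frequency[idx])
        scores[idx] = tf * idf
    # Assumed ordering: highest score first, ties broken by cCRE index
    sentence = sorted(scores, key=lambda i: (-scores[i], i))
    return scores, sentence

# Hypothetical toy inputs: 5 cCREs, 100 reference cells
counts = [0, 2, 1, 0, 1]
document_frequency = [50, 90, 10, 5, 40]
scores, sentence = tfidf_and_tokenize(counts, document_frequency, n_reference_cells=100)
print(sentence)  # [2, 1, 4]
```

In this toy cell, cCRE 2 comes first even though cCRE 1 has the higher raw count, because cCRE 2 is rare in the hypothetical reference corpus and therefore receives a larger weight.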
Step 2: Create Dataset and DataLoader
We create a PyTorch-compatible Dataset and DataLoader to handle tokenized cell_sentences from the processed AnnData object. Each cell is represented as a sequence of tokens with special tokens [CLS] and [SEP].
from epiagent.dataset import CellDataset, collate_fn
from torch.utils.data import DataLoader
# Create the dataset
cell_sentences = adata.obs['cell_sentences'].tolist()
cell_dataset = CellDataset(cell_sentences=cell_sentences)
# Create the DataLoader
batch_size = 8
dataloader = DataLoader(cell_dataset, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=collate_fn)
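The idea of wrapping each cell sentence in [CLS] and [SEP] and batching sequences of different lengths can be sketched in plain Python. This is only a conceptual illustration of what a dataset and collate function typically do, not the code of CellDataset or collate_fn: the special-token ids, the helper names add_special_tokens and pad_batch, and the truncation rule are hypothetical.

```python
# Hypothetical special-token ids for illustration; EpiAgent's real ids may differ.
PAD_ID, CLS_ID, SEP_ID = 0, 1, 2

def add_special_tokens(sentence, max_length):
    """Truncate a cell sentence and wrap it as [CLS] tokens... [SEP]."""
    body = sentence[: max_length - 2]  # leave room for the two special tokens
    return [CLS_ID] + body + [SEP_ID]

def pad_batch(sentences, max_length):
    """Pad wrapped sentences to the longest one in the batch and build masks."""
    wrapped = [add_special_tokens(s, max_length) for s in sentences]
    longest = max(len(w) for w in wrapped)
    input_ids = [w + [PAD_ID] * (longest - len(w)) for w in wrapped]
    attention_mask = [[1] * len(w) + [0] * (longest - len(w)) for w in wrapped]
    return input_ids, attention_mask

# Three toy cells with different numbers of accessible cCREs
batch = [[10, 11, 12], [20], [30, 31, 32, 33, 34, 35]]
input_ids, attention_mask = pad_batch(batch, max_length=6)
for row in input_ids:
    print(row)
# [1, 10, 11, 12, 2, 0]
# [1, 20, 2, 0, 0, 0]
# [1, 30, 31, 32, 33, 2]
```

Padding to the longest sequence in each batch, rather than to a global maximum, keeps small batches cheap; the attention mask marks which positions hold real tokens.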
Step 3: Load Pretrained EpiAgent Model
The pretrained EpiAgent model (pretrained_EpiAgent.pth) is loaded for zero-shot inference to compute cell embeddings.
from epiagent.model import EpiAgent
import torch
# Load the pretrained model
model_path = './pretrained_EpiAgent.pth'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
pretrained_model = EpiAgent(vocab_size=1355449, num_layers=18, embedding_dim=512, num_attention_heads=8, max_rank_embeddings=8192, use_flash_attn=True, pos_weight_for_RLM=torch.tensor(1.), pos_weight_for_CCA=torch.tensor(1.))
pretrained_model.load_state_dict(torch.load(model_path, map_location=device))
<All keys matched successfully>
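The message "All keys matched successfully" means that the parameter names in the checkpoint line up exactly with those expected by the model. The check can be pictured as a set comparison, as in the sketch below. This is a conceptual illustration in plain Python, not PyTorch's implementation; the function compare_keys and the parameter names are hypothetical.

```python
def compare_keys(model_keys, checkpoint_keys):
    """Return (missing, unexpected) parameter names, mimicking the idea of the check."""
    model_keys, checkpoint_keys = set(model_keys), set(checkpoint_keys)
    missing = sorted(model_keys - checkpoint_keys)      # expected by the model, absent from the file
    unexpected = sorted(checkpoint_keys - model_keys)   # present in the file, unknown to the model
    return missing, unexpected

# Hypothetical parameter names for illustration
model_keys = ["embedding.weight", "layers.0.attn.weight", "head.bias"]
good_checkpoint = ["head.bias", "embedding.weight", "layers.0.attn.weight"]
bad_checkpoint = ["embedding.weight", "layers.0.attn.weight", "classifier.bias"]

print(compare_keys(model_keys, good_checkpoint))  # ([], [])
print(compare_keys(model_keys, bad_checkpoint))   # (['head.bias'], ['classifier.bias'])
```

If the real load reports missing or unexpected keys instead, a likely cause is that the constructor arguments (for example num_layers or embedding_dim) do not match the configuration the checkpoint was trained with.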
Step 4: Extract Cell Embeddings
The infer_cell_embeddings function computes cell embeddings with the pretrained EpiAgent model.
from epiagent.inference import infer_cell_embeddings
# Extract cell embeddings
cell_embeddings = infer_cell_embeddings(pretrained_model, device, dataloader)
Step 5: UMAP Visualization
Store the embeddings in the AnnData object, then use UMAP to project them into a 2D space and color the cells by their annotated cell type.
import scanpy as sc
# Assign embeddings to the AnnData object
adata.obsm['cell_embeddings_zero_shot'] = cell_embeddings
# UMAP visualization
sc.pp.neighbors(adata, use_rep='cell_embeddings_zero_shot')
sc.tl.umap(adata)
# Plot UMAP with original cell types
sc.pl.umap(adata, color='Cell_type (HSC)')
# Optional: uncomment the lines below to save the processed AnnData
# output_path = './demo_output.h5ad'
# adata.write(output_path)
# print(f"Processed AnnData saved at {output_path}")