Using Pretrained EpiAgent for Zero-Shot Cell Embeddings on Your Dataset

This tutorial walks through an example workflow with the pretrained EpiAgent model, which was trained on the large-scale Human-scATAC-Corpus (~5 million cells and 35 billion tokens). It uses zero-shot inference, meaning the pretrained weights are applied without any fine-tuning, to extract cell embeddings from the demo dataset. You can follow the same procedure to apply the method to your own dataset.

Required files:

  • pretrained_EpiAgent.pth: The pretrained EpiAgent model.

  • cCRE_document_frequency.npy: Precomputed cCRE document frequencies, required as input for the TF-IDF transformation.

These files are available at the following links:

  • https://drive.google.com/drive/folders/1WlNykSCNtZGsUp2oG0dw3cDdVKYDR-iX?usp=sharing

  • https://github.com/xy-chen16/EpiAgent/tree/main/data
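As a rough illustration of what the document-frequency file holds, the sketch below assumes that the document frequency of a cCRE is the number of cells (documents) in the reference corpus in which that cCRE is accessible, and that the vector is indexed in the same cCRE order as your AnnData features. Both points are assumptions: the exact definition used to build cCRE_document_frequency.npy is not stated in this tutorial, and the helper `document_frequency` is hypothetical.

```python
import numpy as np

def document_frequency(counts: np.ndarray) -> np.ndarray:
    """Hypothetical helper: number of cells (rows) in which each cCRE (column) is accessible."""
    # Assumption: a cCRE counts as accessible in a cell if it has at least one read
    return (np.asarray(counts) > 0).sum(axis=0)

# Toy cell-by-cCRE count matrix: 3 cells x 4 cCREs
counts = np.array([
    [0, 2, 1, 0],
    [1, 0, 3, 0],
    [0, 1, 1, 0],
])
df = document_frequency(counts)
print(df)

# A global TF-IDF needs exactly one document-frequency value per cCRE feature
assert df.shape == (counts.shape[1],)
```

If the assumption about feature order holds, the analogous check on the real files would compare `cCRE_document_frequency.shape[0]` with `adata.n_vars` before calling global_TFIDF.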

Step 1: Data Processing (TF-IDF and Tokenization)

To prepare the dataset for EpiAgent, we perform the following preprocessing steps:

1. TF-IDF transformation: Convert discrete count data into continuous importance scores for accessible cCREs.

2. Tokenization: Generate cell_sentences, which represent each cell as a sequence of accessible cCRE indices.
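The two steps can be sketched in plain numpy. This is an illustration under stated assumptions, not the code behind global_TFIDF or tokenization: it assumes a common TF-IDF variant (each cell's counts divided by its total counts, multiplied by IDF = log(1 + N / (1 + df)) computed from corpus-level document frequencies), and it assumes a cell sentence lists the cell's accessible cCRE indices in decreasing order of TF-IDF score. The functions `tfidf_with_global_df` and `to_cell_sentence` are hypothetical.

```python
import numpy as np

def tfidf_with_global_df(counts, doc_freq, n_docs):
    """Hypothetical TF-IDF using document frequencies from a reference corpus."""
    counts = np.asarray(counts, dtype=float)
    # Term frequency: each cell's counts divided by that cell's total counts
    row_sums = counts.sum(axis=1, keepdims=True)
    tf = np.divide(counts, row_sums, out=np.zeros_like(counts), where=row_sums > 0)
    # Inverse document frequency from corpus-level statistics (not from this dataset)
    idf = np.log1p(n_docs / (1.0 + np.asarray(doc_freq, dtype=float)))
    return tf * idf

def to_cell_sentence(scores_row):
    """Hypothetical tokenizer: accessible cCRE indices sorted by decreasing score."""
    accessible = np.flatnonzero(scores_row > 0)
    # Stable sort on negated scores so ties keep ascending index order
    order = np.argsort(-scores_row[accessible], kind="stable")
    return accessible[order].tolist()

counts = np.array([[0, 2, 1, 0],
                   [1, 0, 3, 0]])
doc_freq = np.array([10, 500, 50, 5])  # toy corpus-level document frequencies
scores = tfidf_with_global_df(counts, doc_freq, n_docs=1000)
sentences = [to_cell_sentence(row) for row in scores]
print(sentences)
```

In the first toy cell, cCRE 1 has more reads than cCRE 2 but is accessible in far more corpus cells, so the IDF term ranks cCRE 2 first. This is the motivation for using corpus-level rather than dataset-level document frequencies: importance is judged against the pretraining corpus.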

import scanpy as sc
import numpy as np
from epiagent.tokenization import tokenization
from epiagent.preprocessing import global_TFIDF

# Load the dataset
input_path = './demo-cell_embedding.h5ad'
adata = sc.read_h5ad(input_path)

# Load the cCRE document frequency data
cCRE_document_frequency = np.load('./cCRE_document_frequency.npy')

# Apply the TF-IDF transformation
global_TFIDF(adata, cCRE_document_frequency)

# Perform tokenization
tokenization(adata)
Tokenization complete: 'cell_sentences' column added to adata.obs.

Step 2: Create Dataset and DataLoader

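We wrap the tokenized cell_sentences from the processed AnnData object in a PyTorch-compatible Dataset and DataLoader. Each cell is represented as a sequence of cCRE tokens together with the special tokens [CLS] and [SEP]. Setting shuffle=False keeps the batches in the same order as adata.obs, so the embeddings computed later line up with the cells.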

from epiagent.dataset import CellDataset, collate_fn
from torch.utils.data import DataLoader

# Create the dataset
cell_sentences = adata.obs['cell_sentences'].tolist()
cell_dataset = CellDataset(cell_sentences=cell_sentences)

# Create the DataLoader
batch_size = 8
dataloader = DataLoader(cell_dataset, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=collate_fn)
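A collate function for variable-length cell sentences typically wraps each sentence with [CLS] and [SEP] and right-pads the batch to a common length with an attention mask. The sketch below illustrates that idea only; the token IDs (PAD_ID, CLS_ID, SEP_ID), the index OFFSET, and the helper `pad_cell_sentences` are hypothetical, and epiagent.dataset.CellDataset and collate_fn may use different IDs, offsets, or output structures.

```python
import numpy as np

# Hypothetical special-token IDs; EpiAgent's actual vocabulary layout may differ
PAD_ID, CLS_ID, SEP_ID = 0, 1, 2
OFFSET = 3  # hypothetical: shift cCRE indices past the special tokens

def pad_cell_sentences(sentences):
    """Hypothetical collate: add [CLS]/[SEP], right-pad, and build an attention mask."""
    wrapped = [[CLS_ID] + [i + OFFSET for i in s] + [SEP_ID] for s in sentences]
    max_len = max(len(w) for w in wrapped)
    input_ids = np.full((len(wrapped), max_len), PAD_ID, dtype=np.int64)
    attention_mask = np.zeros((len(wrapped), max_len), dtype=bool)
    for row, tokens in enumerate(wrapped):
        input_ids[row, :len(tokens)] = tokens
        attention_mask[row, :len(tokens)] = True  # True marks real (non-padding) tokens
    return input_ids, attention_mask

ids, mask = pad_cell_sentences([[5, 9, 2], [7]])
print(ids)
print(mask)
```

Padding only to the longest sentence in each batch, rather than to a global maximum, keeps memory proportional to the batch contents, which matters because scATAC cells can differ widely in the number of accessible cCREs.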

Step 3: Load Pretrained EpiAgent Model

The pretrained EpiAgent model (pretrained_EpiAgent.pth) is loaded for zero-shot inference to compute cell embeddings.

from epiagent.model import EpiAgent
import torch

# Load the pretrained model
model_path = './pretrained_EpiAgent.pth'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

pretrained_model = EpiAgent(vocab_size=1355449, num_layers=18, embedding_dim=512,
                            num_attention_heads=8, max_rank_embeddings=8192, use_flash_attn=True,
                            pos_weight_for_RLM=torch.tensor(1.), pos_weight_for_CCA=torch.tensor(1.))
# Move the model to the target device and switch to inference mode (disables dropout)
pretrained_model = pretrained_model.to(device).eval()
pretrained_model.load_state_dict(torch.load(model_path, map_location=device))
<All keys matched successfully>
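In PyTorch, `load_state_dict` returns a named tuple with `missing_keys` and `unexpected_keys`; the `<All keys matched successfully>` output above is how it prints when both lists are empty. With the default strict=True, any mismatch raises a RuntimeError instead. The self-contained sketch below reproduces the save/load pattern with a small stand-in `torch.nn` model (not EpiAgent; `build_model` and the checkpoint filename are hypothetical) to show how to inspect that result explicitly when loading with strict=False.

```python
import os
import tempfile

import torch
from torch import nn

def build_model():
    """Hypothetical stand-in model; in practice this would be the EpiAgent constructor."""
    return nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

with tempfile.TemporaryDirectory() as tmp:
    ckpt_path = os.path.join(tmp, "stand_in.pth")
    torch.save(build_model().state_dict(), ckpt_path)

    model = build_model().to(device)
    # map_location places the stored tensors on the chosen device while loading
    state_dict = torch.load(ckpt_path, map_location=device)
    # strict=False reports mismatches in the returned named tuple instead of raising
    result = model.load_state_dict(state_dict, strict=False)

print(result.missing_keys, result.unexpected_keys)
model.eval()  # disable dropout and other training-only behavior before inference
```

Non-empty `missing_keys` usually means the constructor hyperparameters do not match the checkpoint, so it is worth checking them before running inference.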

Step 4: Extract Cell Embeddings

The infer_cell_embeddings function runs the pretrained EpiAgent model over the DataLoader and returns the cell embeddings, one row per cell.

from epiagent.inference import infer_cell_embeddings

# Extract cell embeddings
cell_embeddings = infer_cell_embeddings(pretrained_model, device, dataloader)
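This tutorial does not describe how infer_cell_embeddings pools the transformer output into one vector per cell. Purely as an illustration, the sketch below assumes the common choice of taking the hidden state at the [CLS] position (index 0) and concatenating per-batch results in DataLoader order, which yields an array of shape (n_cells, embedding_dim). The function `pool_cls` and the random toy hidden states are hypothetical.

```python
import numpy as np

def pool_cls(hidden_states):
    """Hypothetical pooling: take the [CLS] token (position 0) of each sequence."""
    # hidden_states has shape (batch, seq_len, embedding_dim)
    return hidden_states[:, 0, :]

rng = np.random.default_rng(0)
embedding_dim = 512  # same value as embedding_dim passed to EpiAgent in Step 3
# Two toy batches with different padded lengths: 8 cells, then 3 cells
batches = [rng.normal(size=(8, 40, embedding_dim)),
           rng.normal(size=(3, 25, embedding_dim))]

# Concatenate batch by batch; with shuffle=False the rows follow adata.obs order
cell_embeddings = np.concatenate([pool_cls(h) for h in batches], axis=0)
print(cell_embeddings.shape)  # (11, 512)
```

Whatever pooling the library actually uses, the result must have exactly adata.n_obs rows, because AnnData requires every array stored in adata.obsm to match the number of observations.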

Step 5: UMAP Visualization

Build a neighborhood graph on the zero-shot embeddings, project it into 2D with UMAP, and color the cells by their annotated cell types.

import scanpy as sc

# Assign embeddings to the AnnData object
adata.obsm['cell_embeddings_zero_shot'] = cell_embeddings

# UMAP visualization
sc.pp.neighbors(adata, use_rep='cell_embeddings_zero_shot')
sc.tl.umap(adata)

# Plot UMAP with original cell types
sc.pl.umap(adata, color='Cell_type (HSC)')

# Optionally save the processed AnnData (uncomment to run)
# output_path = './demo_output.h5ad'
# adata.write(output_path)
# print(f"Processed AnnData saved at {output_path}")
[Figure: UMAP of the zero-shot cell embeddings, colored by Cell_type (HSC)]
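sc.pp.neighbors builds a k-nearest-neighbor graph on the representation passed as use_rep (by default n_neighbors=15 with the Euclidean metric), and sc.tl.umap then embeds that graph in 2D. The brute-force numpy sketch below only illustrates the first idea, finding each cell's nearest neighbors in embedding space. Scanpy itself may use approximate neighbor search and converts distances into weighted connectivities. The function `knn_indices` is hypothetical and only suitable for small toy inputs.

```python
import numpy as np

def knn_indices(embeddings, k):
    """Hypothetical brute-force kNN: indices of the k nearest other cells (Euclidean)."""
    embeddings = np.asarray(embeddings, dtype=float)
    # Pairwise Euclidean distances between all cells: shape (n_cells, n_cells)
    diffs = embeddings[:, None, :] - embeddings[None, :, :]
    dist = np.linalg.norm(diffs, axis=-1)
    np.fill_diagonal(dist, np.inf)  # exclude each cell from its own neighbor list
    return np.argsort(dist, axis=1, kind="stable")[:, :k]

# Toy 2-D embeddings: two well-separated groups of three cells
emb = np.array([[0.0, 0.0], [0.1, 0.0], [0.0, 0.1],
                [5.0, 5.0], [5.1, 5.0], [5.0, 5.1]])
neighbors = knn_indices(emb, k=2)
print(neighbors)
```

Every cell's neighbors fall inside its own group, which is the property that makes cells of the same type cluster together in the UMAP when the zero-shot embeddings separate them well.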