Mapping Cells to Healthy Tissue Reference Using Human-scATAC-Corpus

This tutorial shows how to map cells from your dataset onto the healthy in vivo tissue cell atlas provided by the Human-scATAC-Corpus. The mapping can be done at the individual cell level for fine-grained alignment, or at the cell type level for broader categorization, depending on the resolution your analysis requires.

Required files:

  • pretrained_EpiAgent.pth: The pretrained EpiAgent model.

  • cCRE_document_frequency.npy: The cCRE document frequencies used as input for the TF-IDF transformation.

  • mean_embeddings.pkl: The reference embeddings of healthy in vivo tissue cells.

These files are available at the following links:

  • https://drive.google.com/drive/folders/1WlNykSCNtZGsUp2oG0dw3cDdVKYDR-iX?usp=sharing

  • https://github.com/xy-chen16/EpiAgent/tree/main/data
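
Before running the steps below, it can help to confirm that the downloaded files are where the code expects them. The following is a minimal sketch, assuming the files sit in the working directory as in the later snippets. The helper name find_missing_files is illustrative and not part of EpiAgent.

```python
from pathlib import Path

# File names taken from the "Required files" list above.
REQUIRED_FILES = [
    "pretrained_EpiAgent.pth",
    "cCRE_document_frequency.npy",
    "mean_embeddings.pkl",
]

def find_missing_files(directory=".", required=REQUIRED_FILES):
    """Return the required file names that are absent from `directory`."""
    base = Path(directory)
    return [name for name in required if not (base / name).is_file()]

missing = find_missing_files(".")
if missing:
    print(f"Missing files: {missing}")
else:
    print("All required files found.")
```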

Step 1: Data Processing (TF-IDF and Tokenization)

To prepare the dataset for EpiAgent, we perform the following preprocessing steps:

1. TF-IDF Transformation: Convert discrete count data into continuous importance scores for accessible cCREs.

2. Tokenization: Generate cell_sentences to represent each cell as a sequence of accessible cCRE indices.

import scanpy as sc
import numpy as np
from epiagent.tokenization import tokenization
from epiagent.preprocessing import global_TFIDF

# Load the dataset
input_path = './demo_mapping.h5ad'
adata = sc.read_h5ad(input_path)

# Load the cCRE document frequency data
cCRE_document_frequency = np.load('./cCRE_document_frequency.npy')

# Apply TFIDF transformation
global_TFIDF(adata, cCRE_document_frequency)

# Perform tokenization
tokenization(adata)
Tokenization complete: 'cell_sentences' column added to adata.obs.
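
To give an intuition for what these two steps produce, here is a toy sketch on a 2-cell, 4-cCRE count matrix. It uses a textbook TF-IDF variant and assumes that accessible cCREs are ordered by descending score. Both are assumptions made for illustration: the exact formula in global_TFIDF and the ordering used by tokenization may differ, and the function names toy_tfidf and toy_tokenize are hypothetical.

```python
import numpy as np

def toy_tfidf(counts, document_frequency, n_reference_cells):
    """Generic TF-IDF: per-cell term frequency times log inverse document frequency.

    Textbook variant for illustration only; global_TFIDF may differ.
    """
    counts = np.asarray(counts, dtype=np.float64)
    # Term frequency: each count divided by the cell's total count.
    totals = counts.sum(axis=1, keepdims=True)
    tf = np.divide(counts, totals, out=np.zeros_like(counts), where=totals > 0)
    # Inverse document frequency: rarer cCREs receive larger weights.
    df = np.asarray(document_frequency, dtype=np.float64)
    idf = np.log(1.0 + n_reference_cells / df)
    return tf * idf

def toy_tokenize(scores):
    """List accessible cCRE indices per cell, highest score first (assumed ordering)."""
    sentences = []
    for row in scores:
        accessible = np.flatnonzero(row > 0)
        order = np.argsort(-row[accessible], kind="stable")
        sentences.append(accessible[order].tolist())
    return sentences

# Two cells, four cCREs; cCRE 0 is common in the reference, cCRE 3 is rare.
counts = np.array([[1, 0, 2, 1],
                   [0, 3, 0, 0]])
document_frequency = np.array([900, 500, 100, 10])
scores = toy_tfidf(counts, document_frequency, n_reference_cells=1000)
print(toy_tokenize(scores))  # [[2, 3, 0], [1]]
```

In the first cell, cCRE 3 has half the count of cCRE 2 but nearly matches its score because it is rare in the reference, which is the effect the TF-IDF weighting is meant to capture.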

Step 2: Create Dataset and DataLoader

We create a PyTorch-compatible Dataset and DataLoader to handle tokenized cell_sentences from the processed AnnData object. Each cell is represented as a sequence of tokens with special tokens [CLS] and [SEP].

from epiagent.dataset import CellDataset, collate_fn
from torch.utils.data import DataLoader

# Create the dataset
cell_sentences = adata.obs['cell_sentences'].tolist()
cell_dataset = CellDataset(cell_sentences=cell_sentences)

# Create the DataLoader
batch_size = 8
dataloader = DataLoader(cell_dataset, batch_size=batch_size, shuffle=False, num_workers=4, collate_fn=collate_fn)
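
To make the batching concrete, the sketch below wraps variable-length cell sentences with [CLS] and [SEP] and pads them to a common length. The token IDs (0, 1, 2) and the function name toy_collate are hypothetical. The real values and padding logic are defined inside CellDataset and collate_fn and may differ.

```python
# Hypothetical special-token IDs; EpiAgent's actual IDs may differ.
PAD_ID, CLS_ID, SEP_ID = 0, 1, 2

def toy_collate(cell_sentences):
    """Wrap each sentence in [CLS] ... [SEP] and right-pad to the batch maximum."""
    wrapped = [[CLS_ID] + list(s) + [SEP_ID] for s in cell_sentences]
    max_len = max(len(w) for w in wrapped)
    input_ids = [w + [PAD_ID] * (max_len - len(w)) for w in wrapped]
    # 1 marks real tokens, 0 marks padding.
    attention_mask = [[1] * len(w) + [0] * (max_len - len(w)) for w in wrapped]
    return input_ids, attention_mask

ids, mask = toy_collate([[10, 11, 12], [20]])
print(ids)   # [[1, 10, 11, 12, 2], [1, 20, 2, 0, 0]]
print(mask)  # [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]]
```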

Step 3: Load Pretrained EpiAgent Model

The pretrained EpiAgent model (pretrained_EpiAgent.pth) is loaded for zero-shot inference to compute cell embeddings.

from epiagent.model import EpiAgent
import torch

# Load the pretrained model
model_path = './pretrained_EpiAgent.pth'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

pretrained_model = EpiAgent(
    vocab_size=1355449,
    num_layers=18,
    embedding_dim=512,
    num_attention_heads=8,
    max_rank_embeddings=8192,
    use_flash_attn=True,
    pos_weight_for_RLM=torch.tensor(1.),
    pos_weight_for_CCA=torch.tensor(1.),
)
pretrained_model.load_state_dict(torch.load(model_path, map_location=device))
<All keys matched successfully>

Step 4: Extract Cell Embeddings

The infer_cell_embeddings function is used to compute cell embeddings using the pretrained EpiAgent model.

from epiagent.inference import infer_cell_embeddings

# Extract cell embeddings
cell_embeddings = infer_cell_embeddings(pretrained_model, device, dataloader)
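
A quick sanity check on the result can catch problems before mapping. The sketch below assumes that infer_cell_embeddings returns an array-like object with one row per cell, in the same order as adata (the DataLoader above uses shuffle=False). The helper check_embeddings is hypothetical and is shown here on random toy data.

```python
import numpy as np

def check_embeddings(embeddings, n_cells):
    """Validate shape and finiteness, returning a float32 2-D array."""
    emb = np.asarray(embeddings, dtype=np.float32)
    if emb.ndim != 2:
        raise ValueError(f"Expected a 2-D array, got shape {emb.shape}")
    if emb.shape[0] != n_cells:
        raise ValueError(f"{emb.shape[0]} embeddings for {n_cells} cells")
    if not np.isfinite(emb).all():
        raise ValueError("Embeddings contain NaN or infinite values")
    return emb

# Toy stand-in for the real output: 5 cells with 4-dimensional embeddings.
toy = np.random.default_rng(0).normal(size=(5, 4))
checked = check_embeddings(toy, n_cells=5)
print(checked.shape, checked.dtype)
```

With the real data, the equivalent call would be check_embeddings(cell_embeddings, adata.n_obs).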

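Step 5: Load the Reference Embeddings of Healthy In Vivo Tissue Cells

The reference file mean_embeddings.pkl contains a dictionary that maps (dataset, cell type) tuples to reference embedding vectors.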

import pickle

with open('mean_embeddings.pkl', 'rb') as f:
    mean_embeddings = pickle.load(f)
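
It is worth inspecting the reference before using it. The sketch below summarizes a dictionary keyed by (dataset, cell type) tuples, which is the structure the mapping code in Step 6 relies on. The helper summarize_reference and the toy dataset and cell type names are made up for illustration.

```python
import numpy as np
from collections import Counter

def summarize_reference(mean_embeddings):
    """Report entry count, embedding shape, and number of cell types per dataset."""
    keys = list(mean_embeddings.keys())
    shapes = {np.asarray(v).shape for v in mean_embeddings.values()}
    if len(shapes) != 1:
        raise ValueError(f"Inconsistent embedding shapes: {shapes}")
    per_dataset = Counter(dataset for dataset, _ in keys)
    return {
        "n_entries": len(keys),
        "shape": shapes.pop(),
        "per_dataset": dict(per_dataset),
    }

# Toy stand-in with made-up dataset and cell type names.
toy_reference = {
    ("dataset_A", "T cell"): np.zeros(4, dtype=np.float32),
    ("dataset_A", "B cell"): np.ones(4, dtype=np.float32),
    ("dataset_B", "Hepatocyte"): np.full(4, 2.0, dtype=np.float32),
}
print(summarize_reference(toy_reference))
```

On the real file, summarize_reference(mean_embeddings) should report an embedding dimension that matches the second dimension of cell_embeddings. If the two differ, the FAISS search in Step 6 will fail.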

Step 6: Mapping Cells to Reference Embeddings

This step covers two scenarios for mapping your query data to reference embeddings from the Human-scATAC-Corpus:

  • Cell Type Level Mapping: Map average embeddings of each cell type in your data to the reference.
  • Individual Cell Level Mapping: Map each individual cell embedding to its nearest neighbor in the reference.

6.1 Cell Type Level Mapping

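This approach averages the query embeddings within each cell type label and retrieves the 10 nearest reference entries for each cell type. Because the search uses faiss.IndexFlatL2, the reported distances are squared Euclidean (L2) distances.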

import os
import numpy as np
import pandas as pd
import faiss
import matplotlib.pyplot as plt
import zipfile
from openpyxl import Workbook
from openpyxl.utils.dataframe import dataframe_to_rows

# Column in adata.obs that holds the cell type labels
label_key = 'Cancer'

# Check if the label_key column exists in adata.obs
if label_key not in adata.obs.columns:
    raise KeyError(f"[Error] The key '{label_key}' is not found in adata.obs. Available keys: {list(adata.obs.columns)}")

# Extract cell types
celltypes = adata.obs[label_key].astype(str)

# Compute average embedding for each cell type in the query data
query_embeddings = {}
for ct in np.unique(celltypes):
    mask = celltypes == ct
    if mask.sum() == 0:
        print(f"[Warning] No cells found for {ct}")
        continue
    avg_emb = cell_embeddings[mask.values].mean(axis=0).astype('float32')
    query_embeddings[ct] = avg_emb

# Prepare reference embeddings and keys
all_keys = list(mean_embeddings.keys())  # tuples like (dataset, celltype)
all_vectors = np.stack([mean_embeddings[k] for k in all_keys]).astype('float32')

# Build FAISS index
index = faiss.IndexFlatL2(all_vectors.shape[1])
index.add(all_vectors)

# Create output directory and Excel workbook
output_dir = "demo_mapping_output_cell_type"
os.makedirs(output_dir, exist_ok=True)
wb = Workbook()
wb.remove(wb.active)  # Remove default sheet
figure_paths = []

# Search and record top-10 nearest neighbors for each cell type embedding
for query_ct, query_vec in query_embeddings.items():
    D, I = index.search(query_vec.reshape(1, -1), 10)
    records = []
    for dist, idx in zip(D[0], I[0]):
        dataset, ref_ct = all_keys[idx]
        records.append((dataset, ref_ct, dist))
    df = pd.DataFrame(records, columns=["Dataset", "CellType", "Distance"])

    # Save results in Excel sheet
    sheet = wb.create_sheet(title=query_ct[:31])  # Excel sheet name max length = 31
    for row in dataframe_to_rows(df, index=False, header=True):
        sheet.append(row)

    # Plot bar chart of distances
    fig, ax = plt.subplots(figsize=(8, 5))
    labels = [f"{d} | {c}" for d, c in zip(df["Dataset"], df["CellType"])]
    ax.barh(labels, df["Distance"])
    ax.invert_yaxis()
    ax.set_xlabel("Distance")
    ax.set_title(f"Top 10 Nearest Neighbors for {query_ct}")
    fig.tight_layout()

    fig_path = os.path.join(output_dir, f"{query_ct}_mapping_umap32.pdf")
    fig.savefig(fig_path)
    plt.show()
    plt.close(fig)
    figure_paths.append(fig_path)

# Save Excel workbook
excel_path = os.path.join(output_dir, "mapping_results.xlsx")
wb.save(excel_path)

# Zip Excel and figures
zip_path = os.path.join(output_dir, "celltype_mapping_results.zip")
with zipfile.ZipFile(zip_path, 'w') as zipf:
    zipf.write(excel_path, arcname="mapping_results.xlsx")
    for fp in figure_paths:
        zipf.write(fp, arcname=os.path.basename(fp))

print(f"Cell type level mapping saved to: {zip_path}")
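[Figure: horizontal bar chart titled "Top 10 Nearest Neighbors for <query cell type>", with distance on the x-axis and "Dataset | CellType" labels on the y-axis]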
Cell type level mapping saved to: demo_mapping_output_cell_type/celltype_mapping_results.zip
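
For readers who want to see what the index computes, the following NumPy sketch reproduces an exact nearest-neighbor search by brute force. faiss.IndexFlatL2 performs an exact search and reports squared Euclidean distances, so the values below are squared as well. The function name top_k_squared_l2 and the toy vectors are illustrative only.

```python
import numpy as np

def top_k_squared_l2(query, reference, k):
    """Return (distances, indices) of the k reference rows closest to `query`."""
    diffs = reference - query  # broadcast the query across all reference rows
    sq_dists = np.einsum("ij,ij->i", diffs, diffs)  # squared distance per row
    order = np.argsort(sq_dists, kind="stable")[:k]
    return sq_dists[order], order

reference = np.array([[0.0, 0.0], [3.0, 4.0], [1.0, 1.0]], dtype=np.float32)
query = np.array([0.0, 0.0], dtype=np.float32)
dists, idx = top_k_squared_l2(query, reference, k=2)
print(idx.tolist(), dists.tolist())  # [0, 2] [0.0, 2.0]
```

Taking np.sqrt of the distances gives ordinary Euclidean distances. The ranking of neighbors is the same either way.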

6.2 Individual Cell Level Mapping

This approach maps each individual query cell embedding to its nearest neighbor in the reference.

import numpy as np
import pandas as pd
import faiss

# Prepare reference embeddings and keys
all_keys = list(mean_embeddings.keys())  # tuples like (dataset, celltype)
all_vectors = np.stack([mean_embeddings[k] for k in all_keys]).astype('float32')

# Build FAISS index
index = faiss.IndexFlatL2(all_vectors.shape[1])
index.add(all_vectors)

# For each query cell embedding, find top-1 nearest neighbor
nearest_datasets = []
nearest_celltypes = []

for cell_emb in cell_embeddings:
    cell_emb = cell_emb.reshape(1, -1).astype('float32')
    D, I = index.search(cell_emb, 1)
    dataset, celltype = all_keys[I[0][0]]
    nearest_datasets.append(dataset)
    nearest_celltypes.append(celltype)

# Add results to adata.obs and save as CSV
adata.obs["Mapping_Dataset"] = nearest_datasets
adata.obs["Mapping_Dataset_CellType"] = nearest_celltypes

output_path = "demo_mapping_output_cell.csv"
adata.obs.to_csv(output_path)

print(f"Individual cell level mapping saved to: {output_path}")
Individual cell level mapping saved to: demo_mapping_output_cell.csv
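
Once each cell has a mapped reference cell type, a per-label tally shows how consistently the cells of one query group map to the same reference entry. The sketch below is one way to do this with the standard library. The helper mapping_summary and the toy labels are hypothetical.

```python
from collections import Counter

def mapping_summary(query_labels, mapped_celltypes):
    """For each query label, count how many cells mapped to each reference cell type."""
    summary = {}
    for label, mapped_ct in zip(query_labels, mapped_celltypes):
        summary.setdefault(label, Counter())[mapped_ct] += 1
    return summary

# Made-up labels standing in for adata.obs columns.
query_labels = ["Tumor", "Tumor", "Tumor", "Normal"]
mapped = ["Epithelial", "Epithelial", "Fibroblast", "Fibroblast"]
for label, counts in mapping_summary(query_labels, mapped).items():
    print(label, counts.most_common())
```

With the real data, and assuming the label column from section 6.1, the call would be mapping_summary(adata.obs[label_key], adata.obs["Mapping_Dataset_CellType"]).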