Mnn_correct needs entire var_subset expressed in all groups?

I have the following which crashes at the call to mnn_correct on the last line:

script

import numpy as np
import pandas as pd
import scanpy as sc
import gzip
import matplotlib.pyplot as plt
import os
import scanorama
import mnnpy

sc.settings.verbosity = 3
sc.logging.print_header()
sc.settings.set_figure_params(dpi=80, facecolor='white')

# # set directory
datadir = "Pellin2019"
savedir = "Pellin2019/saves"
figdir = "Pellin2019/figs"

os.makedirs(figdir, exist_ok=True)
os.makedirs(savedir, exist_ok=True)
os.chdir(datadir)

# # data files
fn = "Pellin2019/GSM3305359_HSC.raw_counts.tsv.gz"
hsc = pd.read_csv(fn, compression='gzip', header=None, sep='\t', quotechar='"')
hsc = hsc.transpose()
hsc.columns = hsc.iloc[0,:]
hsc.drop(0,inplace=True)
hsc.index = hsc['Barcode'].values
hsc.drop(columns=['Barcode','Library'],inplace=True)
adata_hsc = sc.AnnData(hsc)
adata_hsc.obs = adata_hsc.obs.assign(batch=0)
adata_hsc.obs = adata_hsc.obs.assign(celltype='hsc')
adata_hsc.obs.index.name = "barcode"
adata_hsc.var.index.name = "gene"

fn = "Pellin2019/GSM3305360_MPP.raw_counts.tsv.gz"
mpp = pd.read_csv(fn, compression='gzip', header=None, sep='\t', quotechar='"')
mpp = mpp.transpose()
mpp.columns = mpp.iloc[0,:]
mpp.drop(0,inplace=True)
mpp.index = mpp['Barcode'].values
mpp.drop(columns=['Barcode','Library'],inplace=True)
adata_mpp = sc.AnnData(mpp)
adata_mpp.obs = adata_mpp.obs.assign(batch=1)
adata_mpp.obs = adata_mpp.obs.assign(celltype='mpp')
adata_mpp.obs.index.name = "barcode"
adata_mpp.var.index.name = "gene"

fn = "Pellin2019/GSM3305361_MLP.raw_counts.tsv.gz"
mlp = pd.read_csv(fn, compression='gzip', header=None, sep='\t', quotechar='"')
mlp = mlp.transpose()
mlp.columns = mlp.iloc[0,:]
mlp.drop(0,inplace=True)
mlp.index = mlp['Barcode'].values
mlp.drop(columns=['Barcode','Library'],inplace=True)
adata_mlp = sc.AnnData(mlp)
adata_mlp.obs = adata_mlp.obs.assign(batch=2)
adata_mlp.obs = adata_mlp.obs.assign(celltype='mlp')
adata_mlp.obs.index.name = "barcode"
adata_mlp.var.index.name = "gene"

fn = "Pellin2019/GSM3305362_PreBNK.raw_counts.tsv.gz"
preBnk = pd.read_csv(fn, compression='gzip', header=None, sep='\t', quotechar='"')
preBnk = preBnk.transpose()
preBnk.columns = preBnk.iloc[0,:]
preBnk.drop(0,inplace=True)
preBnk.index = preBnk['Barcode'].values
preBnk.drop(columns=['Barcode','Library'],inplace=True)
adata_preBnk = sc.AnnData(preBnk)
adata_preBnk.obs = adata_preBnk.obs.assign(batch=3)
adata_preBnk.obs = adata_preBnk.obs.assign(celltype='preBnk')
adata_preBnk.obs.index.name = "barcode"
adata_preBnk.var.index.name = "gene"

fn = "Pellin2019/GSM3305363_MEP.raw_counts.tsv.gz"
mep = pd.read_csv(fn, compression='gzip', header=None, sep='\t', quotechar='"')
mep = mep.transpose()
mep.columns = mep.iloc[0,:]
mep.drop(0,inplace=True)
mep.index = mep['Barcode'].values
mep.drop(columns=['Barcode','Library'],inplace=True)
adata_mep = sc.AnnData(mep)
adata_mep.obs = adata_mep.obs.assign(batch=4)
adata_mep.obs = adata_mep.obs.assign(celltype='mep')
adata_mep.obs.index.name = "barcode"
adata_mep.var.index.name = "gene"

fn = "Pellin2019/GSM3305364_CMP.raw_counts.tsv.gz"
cmp = pd.read_csv(fn, compression='gzip', header=None, sep='\t', quotechar='"')
cmp = cmp.transpose()
cmp.columns = cmp.iloc[0,:]
cmp.drop(0,inplace=True)
cmp.index = cmp['Barcode'].values
cmp.drop(columns=['Barcode','Library'],inplace=True)
adata_cmp = sc.AnnData(cmp)
adata_cmp.obs = adata_cmp.obs.assign(batch=5)
adata_cmp.obs = adata_cmp.obs.assign(celltype='cmp')
adata_cmp.obs.index.name = "barcode"
adata_cmp.var.index.name = "gene"

fn = "Pellin2019/GSM3305364_CMP.raw_counts.tsv.gz"
gmp = pd.read_csv(fn, compression='gzip', header=None, sep='\t', quotechar='"')
gmp = gmp.transpose()
gmp.columns = gmp.iloc[0,:]
gmp.drop(0,inplace=True)
gmp.index = gmp['Barcode'].values
gmp.drop(columns=['Barcode','Library'],inplace=True)
adata_gmp = sc.AnnData(gmp)
adata_gmp.obs = adata_gmp.obs.assign(batch=6)
adata_gmp.obs = adata_gmp.obs.assign(celltype='gmp')
adata_gmp.obs.index.name = "barcode"
adata_gmp.var.index.name = "gene"

# # Integrate the data
# concatenate to get common genes
adata = adata_hsc.concatenate(adata_mpp,adata_mlp,adata_preBnk,adata_mep,adata_cmp,adata_gmp)


# qc 
# mitochondrial genes
adata.var['mt'] = adata.var_names.str.startswith('MT-') 
# ribosomal genes
adata.var['ribo'] = adata.var_names.str.startswith(("RPS","RPL"))
# hemoglobin genes.
adata.var['hb'] = adata.var_names.str.contains(("^HB[^(P)]"))

sc.pp.calculate_qc_metrics(adata, qc_vars=['mt','ribo','hb'], percent_top=None, log1p=False, inplace=True)

# mitochondrial genes
adata.var['mt'] = adata.var_names.str.startswith('MT-') 
# ribosomal genes
adata.var['ribo'] = adata.var_names.str.startswith(("RPS","RPL"))
# hemoglobin genes.
adata.var['hb'] = adata.var_names.str.contains(("^HB[^(P)]"))

sc.pl.violin(adata, ['n_genes_by_counts', 'total_counts', 'pct_counts_mt','pct_counts_ribo', 'pct_counts_hb'],
             jitter=0.4, groupby = 'celltype', rotation= 45)


# normalize to depth 10 000
sc.pp.normalize_per_cell(adata, counts_per_cell_after=1e4)

# logaritmize
sc.pp.log1p(adata)


# store normalized counts in the raw slot, 
# we will subset adata.X for variable genes, but want to keep all genes matrix as well.
adata.raw = adata

# compute variable genes
sc.pp.highly_variable_genes(adata, min_mean=0.0125, max_mean=3, min_disp=0.5)
print("Highly variable genes: %d"%sum(adata.var.highly_variable))

# create a list of hvg
var_genes = adata.var.highly_variable
var_genes.index[var_genes]

#plot variable genes
sc.pl.highly_variable_genes(adata)

# regress out unwanted variables
# sc.pp.regress_out(adata, ['total_counts', 'pct_counts_mt'])

# scale data, clip values exceeding standard deviation 10.
sc.pp.scale(adata, max_value=10)

sc.tl.pca(adata, svd_solver='arpack')

sc.pl.pca(adata, color='celltype', components = ['1,2','3,4','5,6','7,8'], ncols=2)

sc.pp.neighbors(adata, n_pcs = 30, n_neighbors = 20)
sc.tl.umap(adata)
sc.pl.umap(adata, color='celltype')

# split per batch into new objects.
celltypes = adata.obs['celltype'].cat.categories.tolist()
alldata = {}
for celltype in celltypes:
    alldata[celltype] = adata[adata.obs['celltype'] == celltype,]

cdata = sc.external.pp.mnn_correct(alldata['cmp'],alldata['gmp'],alldata['hsc'],
                                   alldata['mep'],alldata['mlp'],alldata['mpp'],
                                   alldata['preBnk'], 
                                   svd_dim = 50, batch_key = 'celltype', save_raw = True, var_subset = var_genes)

Here is the error


cdata = sc.external.pp.mnn_correct(alldata['cmp'],alldata['gmp'],alldata['hsc'],...
Performing cosine normalization...
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
~/acquisition_pellin2019_basic.py in 
----> 183 cdata = sc.external.pp.mnn_correct(alldata['cmp'],alldata['gmp'],alldata['hsc'],
      184                                    alldata['mep'],alldata['mlp'],alldata['mpp'],
      185                                    alldata['preBnk'],
      186                                    svd_dim = 50, batch_key = 'celltype', save_raw = True, var_subset = var_genes)

~/anaconda3/envs/p385/lib/python3.8/site-packages/scanpy/external/pp/_mnn_correct.py in mnn_correct(var_index, var_subset, batch_key, index_unique, batch_categories, k, sigma, cos_norm_in, cos_norm_out, svd_dim, var_adj, compute_angle, mnn_order, svd_mode, do_concatenate, save_raw, n_jobs, *datas, **kwargs)
    133 
    134     n_jobs = settings.n_jobs if n_jobs is None else n_jobs
--> 135     datas, mnn_list, angle_list = mnn_correct(
    136         *datas,
    137         var_index=var_index,

~/anaconda3/envs/p385/lib/python3.8/site-packages/mnnpy/mnn.py in mnn_correct(var_index, var_subset, batch_key, index_unique, batch_categories, k, sigma, cos_norm_in, cos_norm_out, svd_dim, var_adj, compute_angle, mnn_order, svd_mode, do_concatenate, save_raw, n_jobs, *datas, **kwargs)
    120         if var_subset is not None and set(adata_vars) == set(var_subset):
    121             var_subset = None
--> 122         corrected = mnn_correct(*(adata.X for adata in datas), var_index=adata_vars,
    123                                 var_subset=var_subset, k=k, sigma=sigma, cos_norm_in=cos_norm_in,
    124                                 cos_norm_out=cos_norm_out, svd_dim=svd_dim, var_adj=var_adj,

~/anaconda3/envs/p385/lib/python3.8/site-packages/mnnpy/mnn.py in mnn_correct(var_index, var_subset, batch_key, index_unique, batch_categories, k, sigma, cos_norm_in, cos_norm_out, svd_dim, var_adj, compute_angle, mnn_order, svd_mode, do_concatenate, save_raw, n_jobs, *datas, **kwargs)
    153     # ------------------------------------------------------------
    154     print('Performing cosine normalization...')
--> 155     in_batches, out_batches, var_subset, same_set = transform_input_data(datas, cos_norm_in,
    156                                                                          cos_norm_out, var_index,
    157                                                                          var_subset, n_jobs)

~/anaconda3/envs/p385/lib/python3.8/site-packages/mnnpy/utils.py in transform_input_data(datas, cos_norm_in, cos_norm_out, var_index, var_subset, n_jobs)
     37     if var_subset is not None:
     38         if set(var_subset) - set(var_index) != set():
---> 39             raise ValueError('Some items in var_subset are not in var_index.')
     40         do_subset = True
     41         if set(var_index) == set(var_subset):

ValueError: Some items in var_subset are not in var_index.

Is there a way to tell mnn_correct to just ignore a gene if its not in this list?

Alternatively, is there a way to simply build a list of HVG, but only include genes that mnn_correct can use?

You could find the intersection of the genes for each dataset? I’m not sure how you’re getting highly variable genes, but you could do something like:

from functools import reduce

common_var_names = reduce(
    lambda x, y: x.intersection(y), 
    (adata.var_names[adata.var["highly_variable"]] for adata in to_combine)
)
1 Like