import numpy as np
from collections import Counter
from collections import defaultdict
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
import scanpy as sc
# Logistic model methods
# [docs]
def generate_training_X(adata, ct_key, select_num=200, exclude=None):
    """Draw a class-balanced subset of ``adata`` for model training.

    For every distinct label found in ``adata.obs[ct_key]`` that is not listed
    in ``exclude``, ``select_num`` row positions are drawn with
    ``np.random.choice``. The draw uses that function's default settings, so
    positions are sampled with replacement: a label with fewer than
    ``select_num`` rows is oversampled and rows can repeat.

    Parameters
    ----------
    adata
        Object exposing ``.obs[ct_key]`` and positional row indexing
        (``adata[int_array]``).
    ct_key
        Name of the ``obs`` column holding the labels.
    select_num
        Number of rows drawn per label.
    exclude
        Labels to leave out; ``None`` (default) excludes nothing.

    Returns
    -------
    ``adata`` indexed by the concatenated row positions, grouped by label in
    order of each label's first appearance.

    Raises
    ------
    ValueError
        If no label is left after applying ``exclude``.
    """
    excluded = () if exclude is None else exclude
    labels = adata.obs[ct_key]
    # dict.fromkeys de-duplicates while keeping first-appearance order.
    # Iterating a set would make the draw order depend on string hashing,
    # so a seeded RNG would still give different subsets between runs.
    kept = [ct for ct in dict.fromkeys(labels) if ct not in excluded]
    if not kept:
        raise ValueError(
            "no labels left in obs[{!r}] after applying exclude".format(ct_key)
        )
    select = np.hstack(
        [np.random.choice(np.where(labels == ct)[0], select_num) for ct in kept]
    )
    return adata[select]
# [docs]
def logistic_model(data, cell_types, sparsity=0.2, fraction=0.2, penalty='l2', max_iter=100):
    """Fit a logistic-regression classifier on a random train/test split.

    Parameters
    ----------
    data
        Feature matrix passed to ``train_test_split``.
    cell_types
        Labels aligned with the rows of ``data``.
    sparsity
        Forwarded to ``LogisticRegression`` as its ``C`` argument.
    fraction
        Forwarded to ``train_test_split`` as ``test_size``.
    penalty, max_iter
        Forwarded unchanged to ``LogisticRegression``.

    Returns
    -------
    tuple
        ``(y_prob, y_test, lr)``: predicted class probabilities for the
        held-out rows, the held-out labels, and the fitted classifier.
    """
    classifier = LogisticRegression(penalty=penalty, C=sparsity, max_iter=max_iter)
    features_train, features_test, labels_train, labels_test = train_test_split(
        data, cell_types, test_size=fraction
    )
    classifier.fit(features_train, labels_train)
    held_out_prob = classifier.predict_proba(features_test)
    return held_out_prob, labels_test, classifier
from sklearn import metrics
from sceleto.data import vega_20, vega_20_scanpy, zeileis_26, godsnot_64
# [docs]
def plot_roc(y_prob, y_test, lr):
    """Plot one-vs-rest ROC curves for every class and return the worst AUC.

    Parameters
    ----------
    y_prob
        2-D array of predicted probabilities; column ``i`` is used for
        ``lr.classes_[i]``.
    y_test
        True labels for the same rows; compared to each class with ``==``.
    lr
        Fitted classifier; only ``lr.classes_`` is read.

    Returns
    -------
    The minimum AUC over all classes (also written into the plot title).
    """
    classes = lr.classes_
    n_classes = len(classes)
    # Pick the smallest palette that can give each class its own colour.
    if n_classes < 21:
        colors = vega_20
    elif n_classes < 27:
        colors = zeileis_26
    else:
        colors = godsnot_64
    # Assumes the palettes are sized sequences (TODO confirm against
    # sceleto.data). Cycling avoids an IndexError when there are more
    # classes than colours in the largest palette.
    n_colors = len(colors)

    aucs = []
    for i, cell_type in enumerate(classes):
        fpr, tpr, _ = metrics.roc_curve(y_test == cell_type, y_prob[:, i])
        aucs.append(metrics.auc(fpr, tpr))
        plt.plot(fpr, tpr, c=colors[i % n_colors], lw=2, label=cell_type)

    # Diagonal reference line (the curve of a random classifier).
    plt.plot([0, 1], [0, 1], color='k', ls=':')
    plt.legend(loc=(1, 0))
    min_auc = np.min(aucs)
    plt.title("Min AUC: %.3f" % (min_auc))
    return min_auc
def transfer_annotation_jp(input_adata, y_id, output_adata, y_out, select_num=200, log=None, exclude=None, raw=True, max_iter=100):
    """Predict labels for ``output_adata`` from the labels of ``input_adata``.

    A class-balanced training subset is drawn from ``input_adata`` with
    ``generate_training_X``, a logistic-regression model is fitted on its
    expression matrix with ``logistic_model`` (``sparsity=0.2``,
    ``fraction=0.2``, ``penalty='l2'``), the held-out ROC curves are drawn
    with ``plot_roc``, and the model's predictions for ``output_adata`` are
    written to ``output_adata.obs[y_out]``.

    Parameters
    ----------
    input_adata : AnnData, required
        Labelled object used for training.
    y_id : str, required
        Name of the label column in ``input_adata.obs``.
    output_adata : AnnData, required
        Object whose labels are predicted; modified in place.
    y_out : str, required
        Name of the ``obs`` column that receives the predictions.
    select_num : int
        Rows sampled per label for training.
    log
        Accepted for backward compatibility; it has no effect on the result.
    exclude : list, optional
        Labels left out of training; ``None`` excludes nothing.
    raw : bool
        When ``False`` the ``.X`` matrices are used; otherwise ``.raw.X`` is
        used, after ``check_raw_exists`` has been called on each object.
    max_iter : int
        Forwarded to ``logistic_model``.

    Returns
    -------
    The fitted classifier.

    Notes
    -----
    Presumably both matrices must contain the same features in the same
    order; this function does not check that.
    """
    if exclude is None:
        exclude = []
    ad_model = generate_training_X(input_adata, y_id, select_num=select_num, exclude=exclude)

    # Equality test kept on purpose: only a value equal to False selects .X,
    # so any other value (including None) keeps using .raw.X as before.
    use_raw = raw != False  # noqa: E712

    if use_raw:
        check_raw_exists(ad_model)
        X_model = ad_model.raw.X
    else:
        X_model = ad_model.X
    y_model = ad_model.obs[y_id]

    if use_raw:
        check_raw_exists(output_adata)
        X_predict = output_adata.raw.X
    else:
        X_predict = output_adata.X

    print("creating lr model...")
    y_prob, y_test, lr = logistic_model(X_model, y_model, sparsity=0.2, fraction=0.2, penalty='l2', max_iter=max_iter)
    plot_roc(y_prob, y_test, lr)

    print("making lr prediction...")
    # Only the hard predictions are stored, so the class probabilities for
    # the whole output matrix are not computed here.
    predicted = lr.predict(X_predict)

    print("updating lr to adata...")
    output_adata.obs[y_out] = predicted
    return lr
# [docs]
def update_label(from_adata, from_label, to_adata, old_label, new_label,
                 exclude=None, include=None, replace=False, unknown=None, keep_replaced=True):
    """Copy labels from ``from_adata`` onto matching observations of ``to_adata``.

    Observations are matched by name (``obs_names``). The result starts from
    ``to_adata.obs[old_label]`` and takes the ``from_adata`` label wherever a
    match exists, then is stored in ``to_adata.obs[new_label]``.

    Parameters
    ----------
    from_adata, from_label
        Source object and the ``obs`` column holding its labels.
    to_adata
        Target object; modified in place.
    old_label
        Existing column of ``to_adata.obs`` used as the starting point. If it
        is missing it is created and filled with ``'unknown'``.
    new_label
        Column of ``to_adata.obs`` that receives the merged labels.
    exclude
        Source labels that must not be transferred. Takes priority over
        ``include`` when both are given.
    include
        If given (and ``exclude`` is empty), only these source labels are
        transferred.
    replace
        Used only when ``new_label`` already exists. If truthy, the existing
        column is first saved as ``new_label + '.replaced.<i>'`` (first free
        ``i`` starting at 1) and then replaced.
    unknown
        If given, only target entries whose current value equals ``unknown``
        are overwritten.
    keep_replaced
        Used only when ``new_label`` already exists and ``replace`` is falsy:
        truthy raises ``SystemError``; falsy overwrites the column without a
        backup.

    Raises
    ------
    SystemError
        If ``new_label`` already exists, ``replace`` is falsy and
        ``keep_replaced`` is truthy.
    """
    if old_label not in to_adata.obs.columns:
        to_adata.obs[old_label] = 'unknown'

    # obs name -> source label, restricted by exclude / include.
    source_pairs = zip(from_adata.obs_names, from_adata.obs[from_label])
    if exclude:
        transfer = {name: label for name, label in source_pairs if label not in exclude}
    elif include:
        transfer = {name: label for name, label in source_pairs if label in include}
    else:
        transfer = dict(source_pairs)

    target_pairs = zip(to_adata.obs_names, to_adata.obs[old_label])
    if unknown:
        new_anno = [
            transfer[name] if (name in transfer and current == unknown) else current
            for name, current in target_pairs
        ]
    else:
        new_anno = [transfer.get(name, current) for name, current in target_pairs]

    if new_label in to_adata.obs.columns:
        if replace:
            # Keep the previous column under the first unused backup name.
            suffix = 1
            backup_key = new_label + '.replaced.' + str(suffix)
            while backup_key in to_adata.obs.columns:
                suffix += 1
                backup_key = new_label + '.replaced.' + str(suffix)
            to_adata.obs[backup_key] = list(to_adata.obs[new_label])
            # Drop before re-adding so the new column ends up last.
            del to_adata.obs[new_label]
        elif keep_replaced:
            raise SystemError(
                "obs column {!r} already exists; pass replace=True to back it up "
                "and replace it, or keep_replaced=False to overwrite it".format(new_label)
            )

    to_adata.obs[new_label] = new_anno
# [docs]
def get_common_var_raw(a, b):
    """Build two new AnnData objects restricted to the raw variables shared by ``a`` and ``b``.

    The shared names are the sorted intersection of ``a.raw.var_names`` and
    ``b.raw.var_names``. Each output holds the matching columns of the
    input's ``raw.X`` (in that sorted order), the input's ``obs`` and
    ``obsm``, and the shared names as ``var_names``.

    Parameters
    ----------
    a, b
        AnnData objects with ``.raw`` populated.

    Returns
    -------
    tuple
        ``(a_new, b_new)``.
    """
    a_names = list(a.raw.var_names)
    b_names = list(b.raw.var_names)
    common = sorted(set(a_names).intersection(b_names))

    # Dictionary lookups replace a list.index() call per shared name, which
    # scanned the full name list each time. dtype=int keeps an empty
    # intersection usable as a column index.
    a_lookup = _first_position_lookup(a_names)
    b_lookup = _first_position_lookup(b_names)
    a_index = np.array([a_lookup[name] for name in common], dtype=int)
    b_index = np.array([b_lookup[name] for name in common], dtype=int)

    print('calculating a...')
    a_new_X = a.raw.X[:, a_index]
    print('calculating b...')
    b_new_X = b.raw.X[:, b_index]

    a_new = sc.AnnData(a_new_X, obs=a.obs)
    a_new.obsm = a.obsm
    a_new.var_names = common
    b_new = sc.AnnData(b_new_X, obs=b.obs)
    b_new.obsm = b.obsm
    b_new.var_names = common
    return a_new, b_new


def _first_position_lookup(names):
    """Map each name to the position of its first occurrence (what list.index returns)."""
    lookup = {}
    for position, name in enumerate(names):
        lookup.setdefault(name, position)
    return lookup
def check_raw_exists(adata):
    """Make sure ``adata.raw`` is usable, setting it from ``adata`` if missing.

    ``adata.raw.X`` is accessed as a probe. If that raises ``AttributeError``
    (for example because ``adata.raw`` is ``None``), ``adata.raw`` is set to
    ``adata`` itself. Any other exception propagates, so an unrelated failure
    no longer causes an existing ``raw`` to be overwritten.

    Parameters
    ----------
    adata
        Object with a ``raw`` attribute; may be modified in place.
    """
    try:
        adata.raw.X
    except AttributeError:
        print("raw not found from adata, adding raw...")
        adata.raw = adata
    else:
        print('raw exist for adata')
# [docs]
def predict_high(lr, adata, out_name, cl_to_focus=None, p=0.9):
    """Label observations whose predicted probability for a class exceeds ``p``.

    Probabilities come from ``lr.predict_proba(adata.raw.X)``. The result is
    written to ``adata.obs['<out_name>_<p>']``, which is first filled with
    the string ``'None'``.

    Parameters
    ----------
    lr
        Fitted classifier exposing ``classes_`` and ``predict_proba``.
    adata
        Object with ``.raw.X`` and a pandas ``.obs``; modified in place.
    out_name
        Prefix of the output column name; ``'_' + str(p)`` is appended.
    cl_to_focus
        If not ``None``, only this class is considered. Otherwise every class
        is processed in ``lr.classes_`` order, so a later class overwrites an
        earlier one where both exceed ``p``.
    p
        Probability threshold (strictly greater than).

    Raises
    ------
    IndexError
        If ``cl_to_focus`` is not one of ``lr.classes_``.
    """
    out_name = '{}_{}'.format(out_name, str(p))
    pb = lr.predict_proba(adata.raw.X)
    adata.obs[out_name] = 'None'

    # "is not None" so that falsy labels such as 0 can still be focused on.
    if cl_to_focus is not None:
        matches = np.flatnonzero(np.asarray(lr.classes_) == cl_to_focus)
        if matches.size == 0:
            raise IndexError(
                "{!r} is not one of lr.classes_".format(cl_to_focus)
            )
        targets = [(matches[0], cl_to_focus)]
    else:
        targets = list(enumerate(lr.classes_))

    for column, label in targets:
        # A single .loc assignment writes into obs itself. The chained form
        # obs[col][mask] = value may write to a temporary copy and is a no-op
        # under pandas Copy-on-Write.
        adata.obs.loc[pb[:, column] > p, out_name] = label
def fill_columns(a_new, b_new):
    """Give ``a_new.obs`` and ``b_new.obs`` the same set of columns.

    Every column present in one object's ``obs`` but not in the other's is
    added to the other and filled with the string ``'nan'``. Both objects are
    modified in place; nothing is returned.
    """
    # Snapshot both column lists first so additions made to one object are
    # not mistaken for columns the other one is missing.
    cols_a = list(a_new.obs.columns)
    cols_b = list(b_new.obs.columns)

    # First complete b with a's columns, then a with b's.
    for target, present, wanted in ((b_new, cols_b, cols_a), (a_new, cols_a, cols_b)):
        for column in wanted:
            if column not in present:
                target.obs[column] = 'nan'
def remove_minor_anno(adata, obskey, num_cut=20):
    """Write a copy of an ``obs`` column in which rare labels are replaced by ``'none'``.

    Labels of ``adata.obs[obskey]`` occurring fewer than ``num_cut`` times are
    replaced by the string ``'none'``; the result is stored in
    ``adata.obs[obskey + '_major']``. The source column is left untouched.

    Parameters
    ----------
    adata
        Object whose ``.obs`` supports ``obs[key]`` reads and writes.
    obskey
        Name of the column to clean up.
    num_cut
        Minimum number of occurrences for a label to be kept.
    """
    values = adata.obs[obskey]
    rare = [label for label, count in Counter(values).items() if count < num_cut]

    labels = np.array(values)
    if rare:
        # Cast to object before inserting 'none': a numeric array would raise
        # ValueError and a short fixed-width string array would silently
        # truncate the replacement text.
        labels = labels.astype(object)
        for label in rare:
            labels[labels == label] = 'none'
    adata.obs[obskey + '_major'] = labels