Skip to content

Commit

Permalink
zscore use matrix operations
Browse files Browse the repository at this point in the history
  • Loading branch information
PauBadiaM authored Jul 23, 2024
1 parent 2cb8697 commit b1fb97c
Showing 1 changed file with 11 additions and 26 deletions.
37 changes: 11 additions & 26 deletions decoupler/method_zscore.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,32 +15,17 @@

from tqdm import tqdm

def zscore(m, net, flavor='RoKAI', verbose=False):

# Get dims
n_samples = m.shape[0]
n_features, n_fsets = net.shape

es = np.zeros((n_samples, n_fsets))
pv = np.zeros((n_samples, n_fsets))

# Compute each element of Matrix3
for i in range(n_samples):
for f in range(n_fsets):
m_f = m[:, i]
if isspmatrix_csr(m_f):
m_f = m_f.toarray()
m_f = m_f.reshape(1, -1)
net_i = net[:, f]

mean_product = np.sum(m_f * net_i) / np.sum(abs(net_i))
mean_m_f = 0 if flavor == "RoKAI" else np.mean(m_f)
std_m_f = np.std(m_f)
count_non_zeros = np.count_nonzero(net_i)

es[i, f] = (mean_product - mean_m_f) * np.sqrt(count_non_zeros) / std_m_f
pv[i, f] = norm.cdf(-abs(es[i, f]))

def zscore(m, net, flavor='RoKAI', verbose=False):
stds = np.std(m, axis=1, ddof=1)
if flavor != 'RoKAI':
mean_all = np.mean(m, axis=1)
else:
mean_all = np.zeros(stds.shape)
n = np.sqrt(np.count_nonzero(net, axis=0))
mean = m.dot(net) / np.sum(np.abs(net), axis=0)
es = ((mean - mean_all.reshape(-1, 1)) * n) / stds.reshape(-1, 1)
pv = norm.cdf(-np.abs(es))
return es, pv


Expand Down Expand Up @@ -105,4 +90,4 @@ def run_zscore(mat, net, source='source', target='target', weight='weight', batc
pvals = pd.DataFrame(pvals, index=r, columns=sources)
pvals.name = 'zscore_pvals'

return return_data(mat=mat, results=(estimate, pvals))
return return_data(mat=mat, results=(estimate, pvals))

0 comments on commit b1fb97c

Please sign in to comment.