diff --git a/examples/plot_basic_usage.py b/examples/plot_basic_usage.py index dce35b0..1d725ee 100644 --- a/examples/plot_basic_usage.py +++ b/examples/plot_basic_usage.py @@ -63,8 +63,8 @@ def score_function(X): print("Pairwise comparison (one vs group):", xai.pairwise(X[2], X[5:10])) -pairlist=[(X[2], X[3]), (X[2], X[4]), (X[2], X[2]), (X[4], X[2])] -print("Pairwise comparison (group of pairs):", xai.pairwise_all(pairlist)) +pairlist = ([X[2], X[2], X[2], X[4]], [X[3], X[4], X[2], X[2]]) +print("Pairwise comparison (group of pairs):", xai.pairwise_set(*pairlist)) ###################################################################################### diff --git a/sharp/base.py b/sharp/base.py index eb84359..c1dd13c 100644 --- a/sharp/base.py +++ b/sharp/base.py @@ -259,23 +259,16 @@ def pairwise(self, sample1, sample2, **kwargs): **kwargs ) - def pairwise_all(self, pairs, **kwargs): + def pairwise_set(self, samples1, samples2, **kwargs): """ set_cols_idx should be passed in kwargs if measure is marginal pairs is a list of tuples of indexes """ - # X_ref = self._X if self._X is not None else check_inputs(X)[0] - - if "sample_size" in kwargs.keys(): - sample_size = 1 - - influences = parallel_loop( - lambda idx: self.individual( - pairs[idx][0].reshape(1, -1), X=pairs[idx][1].reshape(1, -1), verbose=False, **kwargs - ), - range(len(pairs)), + contributions = parallel_loop( + lambda samples: self.pairwise(*samples, verbose=False, **kwargs), + zip(samples1, samples2), n_jobs=self.n_jobs, progress_bar=self.verbose, ) - return np.array(influences) + return np.array(contributions)