diff --git a/wot/tmap/transport_map_model.py b/wot/tmap/transport_map_model.py index 3d08ef1..3e401c3 100644 --- a/wot/tmap/transport_map_model.py +++ b/wot/tmap/transport_map_model.py @@ -37,7 +37,7 @@ def __init__(self, tmaps, meta, timepoints=None, day_pairs=None, cache=False): day_pairs = [(timepoints[i], timepoints[i + 1]) for i in range(len(timepoints) - 1)] self.day_pairs = day_pairs - def fates(self, populations): + def fates(self, populations, include_forward_fates=False): """ Computes fates for each population @@ -47,6 +47,8 @@ def fates(self, populations): The TransportMapModel used to find fates populations : list of wot.Population The target populations such as ones from self.population_from_cell_sets. The populations must be from the same time. + include_forward_fates : bool + Whether to also compute fates forward in time (after the time specified in the populations) Returns ------- fates : anndata.AnnData @@ -54,6 +56,7 @@ def fates(self, populations): """ start_day = wot.tmap.unique_timepoint(*populations) # check for unique timepoint populations = Population.copy(*populations, normalize=False, add_missing=True) + initial_populations = populations pop_names = [pop.name for pop in populations] results = [] @@ -62,10 +65,17 @@ def fates(self, populations): populations = self.pull_back(*populations, as_list=True, normalize=False) results.insert(0, np.array([pop.p for pop in populations]).T) + if include_forward_fates: + populations = initial_populations + while self.can_push_forward(*populations): + populations = self.push_forward(*populations, as_list=True, normalize=False) + results.append(np.array([pop.p for pop in populations]).T) + X = np.concatenate(results) X /= X.sum(axis=1, keepdims=1) obs = self.meta.copy() - obs = obs[obs['day'] <= start_day] + if not include_forward_fates: + obs = obs[obs['day'] <= start_day] return anndata.AnnData(X=X, obs=obs, var=pd.DataFrame(index=pop_names)) def transition_table(self, start_populations, end_populations): diff --git a/wot/tmap/util.py b/wot/tmap/util.py index 0b3f87a..47ac0f8 100644 --- a/wot/tmap/util.py +++ b/wot/tmap/util.py @@ -7,6 +7,8 @@ def generate_comparisons(comparison_names, compare, days, reference_day='start'): + if compare == "all_times": + return itertools.product([(name, name) for name in comparison_names], itertools.combinations(days, 2)) if compare != 'within': # within, match, all, or trajectory name if compare == 'all': comparisons = itertools.combinations(comparison_names, 2)