From 2984c33b6d0933c8d72f79940dd3caf3f7e3dde1 Mon Sep 17 00:00:00 2001 From: Romesh Abeysuriya Date: Fri, 24 Jan 2025 11:55:42 +1100 Subject: [PATCH 1/5] Initial refactor --- starsim/disease.py | 10 +++++--- starsim/networks.py | 60 ++++++++++++++++++++++++++++----------------- starsim/products.py | 4 +-- 3 files changed, 46 insertions(+), 28 deletions(-) diff --git a/starsim/disease.py b/starsim/disease.py index e340b34f..6b7be671 100644 --- a/starsim/disease.py +++ b/starsim/disease.py @@ -299,15 +299,17 @@ def infect(self): rel_trans = self.rel_trans.asnew(self.infectious * self.rel_trans) rel_sus = self.rel_sus.asnew(self.susceptible * self.rel_sus) - for i, (nkey,net) in enumerate(self.sim.networks.items()): + for i, (nkey,route) in enumerate(self.sim.networks.items()): nk = ss.standardize_netkey(nkey) - if isinstance(net, ss.Network) and len(net): # Skip networks with no edges - edges = net.edges + if isinstance(route, (ss.MixingPool, ss.MixingPools)): + new_cases.append(route.get_transmission(self, betamap[nk])) + elif isinstance(route, ss.Network) and len(route): # Skip networks with no edges + edges = route.edges p1p2b0 = [edges.p1, edges.p2, betamap[nk][0]] # Person 1, person 2, beta 0 p2p1b1 = [edges.p2, edges.p1, betamap[nk][1]] # Person 2, person 1, beta 1 for src, trg, beta in [p1p2b0, p2p1b1]: if beta: # Skip networks with no transmission - beta_per_dt = net.net_beta(disease_beta=beta) # Compute beta for this network and timestep + beta_per_dt = route.net_beta(disease_beta=beta) # Compute beta for this network and timestep randvals = self.trans_rng.rvs(src, trg) # Generate a new random number based on the two other random numbers args = (src, trg, rel_trans, rel_sus, beta_per_dt, randvals) # Set up the arguments to calculate transmission target_uids, source_uids = self.compute_transmission(*args) # Actually calculate it diff --git a/starsim/networks.py b/starsim/networks.py index e957c6c9..ce206b50 100644 --- a/starsim/networks.py +++ b/starsim/networks.py @@ -1133,11 +1133,11 @@ def init_post(self): mp.init_post() return - def step(self): - """ Step each mixing pool """ + def get_transmission(self, *args, **kwargs): + new_cases = [] for mp in self.pools: - mp.step() - return + new_cases.extend(mp.get_transmission(*args, **kwargs)) + return new_cases def remove_uids(self, uids): """ Remove UIDs from each mixing pool """ @@ -1145,6 +1145,9 @@ def remove_uids(self, uids): mp.remove_uids(uids) return + def step(self): + return + class MixingPool(Route): """ @@ -1249,28 +1252,41 @@ def remove_uids(self, uids): self.pars[key] = inds.remove(uids) return - def step(self): - self.src_uids = self.get_uids(self.pars.src) - self.dst_uids = self.get_uids(self.pars.dst) + def get_transmission(self, disease, disease_beta): + """ + Calculate transmission + + This is called from Infection.infect() together with network transmission + + :param disease: ss.Infection instance + :param disease_beta: The beta value for the disease. This is typically calculated as a + pair of values as networks are bidirectional, however, only the first value + is used because mixing pools are unidirectional. + :return: UIDs of agents who acquired the disease at this step + """ + + if disease_beta[0] == 0: + return [] + + # Determine the mixing pool beta value beta = self.pars.beta if isinstance(beta, ss.beta): - beta = beta.values # Don't use as a time probability - + beta = beta.values # Don't use as a time probability if beta == 0: - return 0 + return [] + # Get source and target UIDs + self.src_uids = self.get_uids(self.pars.src) + self.dst_uids = self.get_uids(self.pars.dst) if len(self.src_uids) == 0 or len(self.dst_uids) == 0: - return 0 + return [] - n_new_cases = 0 - for disease in self.diseases: - trans = np.mean(disease.infectious[self.src_uids] * disease.rel_trans[self.src_uids]) - acq = self.eff_contacts[self.dst_uids] * disease.susceptible[self.dst_uids] * disease.rel_sus[self.dst_uids] - p = beta*trans*acq + # Calculate transmission + trans = np.mean(disease.infectious[self.src_uids] * disease.rel_trans[self.src_uids]) + acq = self.eff_contacts[self.dst_uids] * disease.susceptible[self.dst_uids] * disease.rel_sus[self.dst_uids] + p = beta*disease_beta[0]*trans*acq + self.p_acquire.set(p=p) + return self.p_acquire.filter(self.dst_uids) - self.p_acquire.set(p=p) - new_cases = self.p_acquire.filter(self.dst_uids) - n_new_cases += len(new_cases) - disease.set_prognoses(new_cases) - - return n_new_cases + def step(self): + return \ No newline at end of file diff --git a/starsim/products.py b/starsim/products.py index 145a5434..f312608c 100644 --- a/starsim/products.py +++ b/starsim/products.py @@ -138,8 +138,8 @@ def administer(self, uids, return_format='dict'): class Vx(Product): """ Vaccine product """ - def __init__(self, diseases=None, pars=None, *args, **kwargs): - super().__init__(pars, *args, **kwargs) + def __init__(self, diseases=None, *args, **kwargs): + super().__init__(*args, **kwargs) self.diseases = sc.tolist(diseases) return From 6f0f619026b2a882c983092ae7f8d65203e607ee Mon Sep 17 00:00:00 2001 From: Romesh Abeysuriya Date: Fri, 24 Jan 2025 12:06:06 +1100 Subject: [PATCH 2/5] Running but not quite right --- starsim/disease.py | 6 +++++- tests/test_mixingpools.py | 6 +++--- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/starsim/disease.py b/starsim/disease.py index 6b7be671..e274c490 100644 --- a/starsim/disease.py +++ b/starsim/disease.py @@ -9,6 +9,7 @@ import pandas as pd ss_int_ = ss.dtypes.int +ss_float_ = ss.dtypes.float __all__ = ['Disease', 'Infection', 'InfectionLog'] @@ -302,7 +303,10 @@ def infect(self): for i, (nkey,route) in enumerate(self.sim.networks.items()): nk = ss.standardize_netkey(nkey) if isinstance(route, (ss.MixingPool, ss.MixingPools)): - new_cases.append(route.get_transmission(self, betamap[nk])) + target_uids = route.get_transmission(self, betamap[nk]) + new_cases.append(target_uids) + sources.append(np.full(len(target_uids), dtype=ss_float_, fill_value=np.nan)) + networks.append(np.full(len(target_uids), dtype=ss_int_, fill_value=i)) elif isinstance(route, ss.Network) and len(route): # Skip networks with no edges edges = route.edges p1p2b0 = [edges.p1, edges.p2, betamap[nk][0]] # Person 1, person 2, beta 0 diff --git a/tests/test_mixingpools.py b/tests/test_mixingpools.py index 43792cb9..de2aadda 100644 --- a/tests/test_mixingpools.py +++ b/tests/test_mixingpools.py @@ -58,7 +58,7 @@ def test_single_ncd(): mp_pars = { 'src': ss.AgeGroup(0, 15), 'dst': ss.AgeGroup(15, None), - 'beta': ss.beta(0.15), + 'beta': ss.beta(1), 'contacts': ss.poisson(lam=5), 'diseases': 'ncd' } @@ -78,7 +78,7 @@ def test_single_missing_disease(): mp_pars = { 'src': ss.AgeGroup(0, 15), 'dst': ss.AgeGroup(15, None), - 'beta': ss.beta(0.15), + 'beta': ss.beta(1), 'contacts': ss.poisson(lam=5), 'diseases': 'hiv' } @@ -100,7 +100,7 @@ def test_single_age(do_plot=do_plot): mp_pars = { 'src': ss.AgeGroup(0, 15), 'dst': ss.AgeGroup(15, None), - 'beta': ss.beta(0.15), + 'beta': ss.beta(1), 'contacts': ss.poisson(lam=5), } mp = ss.MixingPool(mp_pars) From 4f32cf95d67f7a81030060a659ba64228cd59d00 Mon Sep 17 00:00:00 2001 From: Romesh Abeysuriya Date: Fri, 24 Jan 2025 13:15:13 +1100 Subject: [PATCH 3/5] Explicitly pass in rel_sus/rel_trans, rename method --- starsim/disease.py | 2 +- starsim/networks.py | 15 ++++++++------- tests/test_mixingpools.py | 18 +++++++++--------- 3 files changed, 18 insertions(+), 17 deletions(-) diff --git a/starsim/disease.py b/starsim/disease.py index e274c490..e2e74756 100644 --- a/starsim/disease.py +++ b/starsim/disease.py @@ -303,7 +303,7 @@ def infect(self): for i, (nkey,route) in enumerate(self.sim.networks.items()): nk = ss.standardize_netkey(nkey) if isinstance(route, (ss.MixingPool, ss.MixingPools)): - target_uids = route.get_transmission(self, betamap[nk]) + target_uids = route.compute_transmission(self, rel_sus, rel_trans, betamap[nk]) new_cases.append(target_uids) sources.append(np.full(len(target_uids), dtype=ss_float_, fill_value=np.nan)) networks.append(np.full(len(target_uids), dtype=ss_int_, fill_value=i)) diff --git a/starsim/networks.py b/starsim/networks.py index ce206b50..ca0ec41f 100644 --- a/starsim/networks.py +++ b/starsim/networks.py @@ -1133,10 +1133,10 @@ def init_post(self): mp.init_post() return - def get_transmission(self, *args, **kwargs): + def compute_transmission(self, *args, **kwargs): new_cases = [] for mp in self.pools: - new_cases.extend(mp.get_transmission(*args, **kwargs)) + new_cases.extend(mp.compute_transmission(*args, **kwargs)) return new_cases def remove_uids(self, uids): @@ -1252,13 +1252,14 @@ def remove_uids(self, uids): self.pars[key] = inds.remove(uids) return - def get_transmission(self, disease, disease_beta): + def compute_transmission(self, rel_sus, rel_trans, disease_beta): """ Calculate transmission - This is called from Infection.infect() together with network transmission + This is called from Infection.infect() together with network transmission. - :param disease: ss.Infection instance + :param rel_sus: Relative susceptibility + :param rel_trans: Relative infectiousness :param disease_beta: The beta value for the disease. This is typically calculated as a pair of values as networks are bidirectional, however, only the first value is used because mixing pools are unidirectional. @@ -1282,8 +1283,8 @@ def get_transmission(self, disease, disease_beta): return [] # Calculate transmission - trans = np.mean(disease.infectious[self.src_uids] * disease.rel_trans[self.src_uids]) - acq = self.eff_contacts[self.dst_uids] * disease.susceptible[self.dst_uids] * disease.rel_sus[self.dst_uids] + trans = np.mean(rel_trans[self.src_uids]) + acq = self.eff_contacts[self.dst_uids] * rel_sus[self.dst_uids] p = beta*disease_beta[0]*trans*acq self.p_acquire.set(p=p) return self.p_acquire.filter(self.dst_uids) diff --git a/tests/test_mixingpools.py b/tests/test_mixingpools.py index de2aadda..c5065f38 100644 --- a/tests/test_mixingpools.py +++ b/tests/test_mixingpools.py @@ -123,7 +123,7 @@ def test_single_sex(do_plot=do_plot): mp_pars = { 'src': lambda sim: sim.people.female, # female to male (only) transmission 'dst': lambda sim: sim.people.male, - 'beta': ss.beta(0.2), + 'beta': ss.beta(1), 'contacts': ss.poisson(lam=4), } mp = ss.MixingPool(mp_pars) @@ -166,7 +166,7 @@ def test_multi(do_plot=do_plot): mps_pars = dict( contacts = np.array([[1.4, 0.5], [1.2, 0.7]]), - beta = ss.beta(0.2), + beta = 1, src = groups, dst = groups, ) @@ -187,14 +187,14 @@ def test_multi(do_plot=do_plot): sc.options(interactive=do_plot) T = sc.timer() - sim0 = test_single_defaults(do_plot) - sim1 = test_single_uids(do_plot) - sim2 = test_single_ncd() - sim3 = test_single_missing_disease() - sim4 = test_single_age(do_plot) - sim5 = test_single_sex(do_plot) + # sim0 = test_single_defaults(do_plot) + # sim1 = test_single_uids(do_plot) + # sim2 = test_single_ncd() + # sim3 = test_single_missing_disease() + # sim4 = test_single_age(do_plot) + # sim5 = test_single_sex(do_plot) - sim6 = test_multi_defaults(do_plot) + # sim6 = test_multi_defaults(do_plot) sim7 = test_multi(do_plot) T.toc() From 09bacab9da5ea7b131e04e9950a2613898adcd4d Mon Sep 17 00:00:00 2001 From: Romesh Abeysuriya Date: Fri, 24 Jan 2025 13:16:22 +1100 Subject: [PATCH 4/5] Fix argument --- starsim/disease.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/starsim/disease.py b/starsim/disease.py index e2e74756..e2a46c71 100644 --- a/starsim/disease.py +++ b/starsim/disease.py @@ -303,7 +303,7 @@ def infect(self): for i, (nkey,route) in enumerate(self.sim.networks.items()): nk = ss.standardize_netkey(nkey) if isinstance(route, (ss.MixingPool, ss.MixingPools)): - target_uids = route.compute_transmission(self, rel_sus, rel_trans, betamap[nk]) + target_uids = route.compute_transmission(rel_sus, rel_trans, betamap[nk]) new_cases.append(target_uids) sources.append(np.full(len(target_uids), dtype=ss_float_, fill_value=np.nan)) networks.append(np.full(len(target_uids), dtype=ss_int_, fill_value=i)) From b1eb49c32a041a07b5730d3f69dcfb4a666c434f Mon Sep 17 00:00:00 2001 From: Romesh Abeysuriya Date: Wed, 29 Jan 2025 10:45:35 +1000 Subject: [PATCH 5/5] Restore tests --- tests/test_mixingpools.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/test_mixingpools.py b/tests/test_mixingpools.py index c5065f38..0509bd7c 100644 --- a/tests/test_mixingpools.py +++ b/tests/test_mixingpools.py @@ -187,14 +187,14 @@ def test_multi(do_plot=do_plot): sc.options(interactive=do_plot) T = sc.timer() - # sim0 = test_single_defaults(do_plot) - # sim1 = test_single_uids(do_plot) - # sim2 = test_single_ncd() - # sim3 = test_single_missing_disease() - # sim4 = test_single_age(do_plot) - # sim5 = test_single_sex(do_plot) - - # sim6 = test_multi_defaults(do_plot) + sim0 = test_single_defaults(do_plot) + sim1 = test_single_uids(do_plot) + sim2 = test_single_ncd() + sim3 = test_single_missing_disease() + sim4 = test_single_age(do_plot) + sim5 = test_single_sex(do_plot) + + sim6 = test_multi_defaults(do_plot) sim7 = test_multi(do_plot) T.toc()