diff --git a/starsim/disease.py b/starsim/disease.py index 58f69cc7..adc626dd 100644 --- a/starsim/disease.py +++ b/starsim/disease.py @@ -9,6 +9,7 @@ import pandas as pd ss_int_ = ss.dtypes.int +ss_float_ = ss.dtypes.float __all__ = ['Disease', 'Infection', 'InfectionLog'] @@ -299,15 +300,20 @@ def infect(self): rel_trans = self.rel_trans.asnew(self.infectious * self.rel_trans) rel_sus = self.rel_sus.asnew(self.susceptible * self.rel_sus) - for i, (nkey,net) in enumerate(self.sim.networks.items()): + for i, (nkey,route) in enumerate(self.sim.networks.items()): nk = ss.standardize_netkey(nkey) - if isinstance(net, ss.Network) and len(net): # Skip networks with no edges - edges = net.edges + if isinstance(route, (ss.MixingPool, ss.MixingPools)): + target_uids = route.compute_transmission(rel_sus, rel_trans, betamap[nk]) + new_cases.append(target_uids) + sources.append(np.full(len(target_uids), dtype=ss_float_, fill_value=np.nan)) + networks.append(np.full(len(target_uids), dtype=ss_int_, fill_value=i)) + elif isinstance(route, ss.Network) and len(route): # Skip networks with no edges + edges = route.edges p1p2b0 = [edges.p1, edges.p2, betamap[nk][0]] # Person 1, person 2, beta 0 p2p1b1 = [edges.p2, edges.p1, betamap[nk][1]] # Person 2, person 1, beta 1 for src, trg, beta in [p1p2b0, p2p1b1]: if beta: # Skip networks with no transmission - beta_per_dt = net.net_beta(disease_beta=beta) # Compute beta for this network and timestep + beta_per_dt = route.net_beta(disease_beta=beta) # Compute beta for this network and timestep randvals = self.trans_rng.rvs(src, trg) # Generate a new random number based on the two other random numbers args = (src, trg, rel_trans, rel_sus, beta_per_dt, randvals) # Set up the arguments to calculate transmission target_uids, source_uids = self.compute_transmission(*args) # Actually calculate it diff --git a/starsim/networks.py b/starsim/networks.py index d55ebd97..d1713f51 100644 --- a/starsim/networks.py +++ b/starsim/networks.py @@ -1133,11 +1133,11 @@ def init_post(self): mp.init_post() return - def step(self): - """ Step each mixing pool """ + def compute_transmission(self, *args, **kwargs): + new_cases = [] for mp in self.pools: - mp.step() - return + new_cases.extend(mp.compute_transmission(*args, **kwargs)) + return new_cases def remove_uids(self, uids): """ Remove UIDs from each mixing pool """ @@ -1145,6 +1145,9 @@ def remove_uids(self, uids): mp.remove_uids(uids) return + def step(self): + return + class MixingPool(Route): """ @@ -1249,26 +1252,40 @@ def remove_uids(self, uids): self.pars[key] = inds.remove(uids) return - def step(self): - self.src_uids = self.get_uids(self.pars.src) - self.dst_uids = self.get_uids(self.pars.dst) - beta = self.pars.beta + def compute_transmission(self, rel_sus, rel_trans, disease_beta): + """ + Calculate transmission + + This is called from Infection.infect() together with network transmission. + :param rel_sus: Relative susceptibility + :param rel_trans: Relative infectiousness + :param disease_beta: The beta value for the disease. This is typically calculated as a + pair of values as networks are bidirectional, however, only the first value + is used because mixing pools are unidirectional. + :return: UIDs of agents who acquired the disease at this step + """ + + if disease_beta[0] == 0: + return [] + + # Determine the mixing pool beta value + beta = self.pars.beta if beta == 0: - return 0 + return [] + # Get source and target UIDs + self.src_uids = self.get_uids(self.pars.src) + self.dst_uids = self.get_uids(self.pars.dst) if len(self.src_uids) == 0 or len(self.dst_uids) == 0: - return 0 - - n_new_cases = 0 - for disease in self.diseases: - trans = np.mean(disease.infectious[self.src_uids] * disease.rel_trans[self.src_uids]) - acq = self.eff_contacts[self.dst_uids] * disease.susceptible[self.dst_uids] * disease.rel_sus[self.dst_uids] - p = beta*trans*acq + return [] - self.p_acquire.set(p=p) - new_cases = self.p_acquire.filter(self.dst_uids) - n_new_cases += len(new_cases) - disease.set_prognoses(new_cases) + # Calculate transmission + trans = np.mean(rel_trans[self.src_uids]) + acq = self.eff_contacts[self.dst_uids] * rel_sus[self.dst_uids] + p = beta*disease_beta[0]*trans*acq + self.p_acquire.set(p=p) + return self.p_acquire.filter(self.dst_uids) - return n_new_cases + def step(self): + return diff --git a/starsim/products.py b/starsim/products.py index 145a5434..f312608c 100644 --- a/starsim/products.py +++ b/starsim/products.py @@ -138,8 +138,8 @@ def administer(self, uids, return_format='dict'): class Vx(Product): """ Vaccine product """ - def __init__(self, diseases=None, pars=None, *args, **kwargs): - super().__init__(pars, *args, **kwargs) + def __init__(self, diseases=None, *args, **kwargs): + super().__init__(*args, **kwargs) self.diseases = sc.tolist(diseases) return diff --git a/tests/test_mixingpools.py b/tests/test_mixingpools.py index 43792cb9..0509bd7c 100644 --- a/tests/test_mixingpools.py +++ b/tests/test_mixingpools.py @@ -58,7 +58,7 @@ def test_single_ncd(): mp_pars = { 'src': ss.AgeGroup(0, 15), 'dst': ss.AgeGroup(15, None), - 'beta': ss.beta(0.15), + 'beta': ss.beta(1), 'contacts': ss.poisson(lam=5), 'diseases': 'ncd' } @@ -78,7 +78,7 @@ def test_single_missing_disease(): mp_pars = { 'src': ss.AgeGroup(0, 15), 'dst': ss.AgeGroup(15, None), - 'beta': ss.beta(0.15), + 'beta': ss.beta(1), 'contacts': ss.poisson(lam=5), 'diseases': 'hiv' } @@ -100,7 +100,7 @@ def test_single_age(do_plot=do_plot): mp_pars = { 'src': ss.AgeGroup(0, 15), 'dst': ss.AgeGroup(15, None), - 'beta': ss.beta(0.15), + 'beta': ss.beta(1), 'contacts': ss.poisson(lam=5), } mp = ss.MixingPool(mp_pars) @@ -123,7 +123,7 @@ def test_single_sex(do_plot=do_plot): mp_pars = { 'src': lambda sim: sim.people.female, # female to male (only) transmission 'dst': lambda sim: sim.people.male, - 'beta': ss.beta(0.2), + 'beta': ss.beta(1), 'contacts': ss.poisson(lam=4), } mp = ss.MixingPool(mp_pars) @@ -166,7 +166,7 @@ def test_multi(do_plot=do_plot): mps_pars = dict( contacts = np.array([[1.4, 0.5], [1.2, 0.7]]), - beta = ss.beta(0.2), + beta = 1, src = groups, dst = groups, )