Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Mixing pool updates #848

Merged
merged 6 commits into from
Feb 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 10 additions & 4 deletions starsim/disease.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pandas as pd

ss_int_ = ss.dtypes.int
ss_float_ = ss.dtypes.float

__all__ = ['Disease', 'Infection', 'InfectionLog']

Expand Down Expand Up @@ -299,15 +300,20 @@ def infect(self):
rel_trans = self.rel_trans.asnew(self.infectious * self.rel_trans)
rel_sus = self.rel_sus.asnew(self.susceptible * self.rel_sus)

for i, (nkey,net) in enumerate(self.sim.networks.items()):
for i, (nkey,route) in enumerate(self.sim.networks.items()):
nk = ss.standardize_netkey(nkey)
if isinstance(net, ss.Network) and len(net): # Skip networks with no edges
edges = net.edges
if isinstance(route, (ss.MixingPool, ss.MixingPools)):
target_uids = route.compute_transmission(rel_sus, rel_trans, betamap[nk])
new_cases.append(target_uids)
sources.append(np.full(len(target_uids), dtype=ss_float_, fill_value=np.nan))
networks.append(np.full(len(target_uids), dtype=ss_int_, fill_value=i))
elif isinstance(route, ss.Network) and len(route): # Skip networks with no edges
edges = route.edges
p1p2b0 = [edges.p1, edges.p2, betamap[nk][0]] # Person 1, person 2, beta 0
p2p1b1 = [edges.p2, edges.p1, betamap[nk][1]] # Person 2, person 1, beta 1
for src, trg, beta in [p1p2b0, p2p1b1]:
if beta: # Skip networks with no transmission
beta_per_dt = net.net_beta(disease_beta=beta) # Compute beta for this network and timestep
beta_per_dt = route.net_beta(disease_beta=beta) # Compute beta for this network and timestep
randvals = self.trans_rng.rvs(src, trg) # Generate a new random number based on the two other random numbers
args = (src, trg, rel_trans, rel_sus, beta_per_dt, randvals) # Set up the arguments to calculate transmission
target_uids, source_uids = self.compute_transmission(*args) # Actually calculate it
Expand Down
59 changes: 38 additions & 21 deletions starsim/networks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1133,18 +1133,21 @@ def init_post(self):
mp.init_post()
return

def step(self):
""" Step each mixing pool """
def compute_transmission(self, *args, **kwargs):
new_cases = []
for mp in self.pools:
mp.step()
return
new_cases.extend(mp.compute_transmission(*args, **kwargs))
return new_cases

def remove_uids(self, uids):
""" Remove UIDs from each mixing pool """
for mp in self.pools:
mp.remove_uids(uids)
return

def step(self):
return


class MixingPool(Route):
"""
Expand Down Expand Up @@ -1249,26 +1252,40 @@ def remove_uids(self, uids):
self.pars[key] = inds.remove(uids)
return

def step(self):
self.src_uids = self.get_uids(self.pars.src)
self.dst_uids = self.get_uids(self.pars.dst)
beta = self.pars.beta
def compute_transmission(self, rel_sus, rel_trans, disease_beta):
"""
Calculate transmission

This is called from Infection.infect() together with network transmission.

:param rel_sus: Relative susceptibility
:param rel_trans: Relative infectiousness
:param disease_beta: The beta value for the disease. This is typically calculated as a
pair of values as networks are bidirectional, however, only the first value
is used because mixing pools are unidirectional.
:return: UIDs of agents who acquired the disease at this step
"""

if disease_beta[0] == 0:
return []

# Determine the mixing pool beta value
beta = self.pars.beta
if beta == 0:
return 0
return []

# Get source and target UIDs
self.src_uids = self.get_uids(self.pars.src)
self.dst_uids = self.get_uids(self.pars.dst)
if len(self.src_uids) == 0 or len(self.dst_uids) == 0:
return 0

n_new_cases = 0
for disease in self.diseases:
trans = np.mean(disease.infectious[self.src_uids] * disease.rel_trans[self.src_uids])
acq = self.eff_contacts[self.dst_uids] * disease.susceptible[self.dst_uids] * disease.rel_sus[self.dst_uids]
p = beta*trans*acq
return []

self.p_acquire.set(p=p)
new_cases = self.p_acquire.filter(self.dst_uids)
n_new_cases += len(new_cases)
disease.set_prognoses(new_cases)
# Calculate transmission
trans = np.mean(rel_trans[self.src_uids])
acq = self.eff_contacts[self.dst_uids] * rel_sus[self.dst_uids]
p = beta*disease_beta[0]*trans*acq
self.p_acquire.set(p=p)
return self.p_acquire.filter(self.dst_uids)

return n_new_cases
def step(self):
return
4 changes: 2 additions & 2 deletions starsim/products.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,8 +138,8 @@ def administer(self, uids, return_format='dict'):

class Vx(Product):
""" Vaccine product """
def __init__(self, diseases=None, pars=None, *args, **kwargs):
super().__init__(pars, *args, **kwargs)
def __init__(self, diseases=None, *args, **kwargs):
super().__init__(*args, **kwargs)
self.diseases = sc.tolist(diseases)
return

Expand Down
10 changes: 5 additions & 5 deletions tests/test_mixingpools.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def test_single_ncd():
mp_pars = {
'src': ss.AgeGroup(0, 15),
'dst': ss.AgeGroup(15, None),
'beta': ss.beta(0.15),
'beta': ss.beta(1),
'contacts': ss.poisson(lam=5),
'diseases': 'ncd'
}
Expand All @@ -78,7 +78,7 @@ def test_single_missing_disease():
mp_pars = {
'src': ss.AgeGroup(0, 15),
'dst': ss.AgeGroup(15, None),
'beta': ss.beta(0.15),
'beta': ss.beta(1),
'contacts': ss.poisson(lam=5),
'diseases': 'hiv'
}
Expand All @@ -100,7 +100,7 @@ def test_single_age(do_plot=do_plot):
mp_pars = {
'src': ss.AgeGroup(0, 15),
'dst': ss.AgeGroup(15, None),
'beta': ss.beta(0.15),
'beta': ss.beta(1),
'contacts': ss.poisson(lam=5),
}
mp = ss.MixingPool(mp_pars)
Expand All @@ -123,7 +123,7 @@ def test_single_sex(do_plot=do_plot):
mp_pars = {
'src': lambda sim: sim.people.female, # female to male (only) transmission
'dst': lambda sim: sim.people.male,
'beta': ss.beta(0.2),
'beta': ss.beta(1),
'contacts': ss.poisson(lam=4),
}
mp = ss.MixingPool(mp_pars)
Expand Down Expand Up @@ -166,7 +166,7 @@ def test_multi(do_plot=do_plot):

mps_pars = dict(
contacts = np.array([[1.4, 0.5], [1.2, 0.7]]),
beta = ss.beta(0.2),
beta = 1,
src = groups,
dst = groups,
)
Expand Down
Loading