From 1481e61e99632a2c28d9fae5211c9a5a991bda13 Mon Sep 17 00:00:00 2001 From: Johnson Sun Date: Sat, 27 Jul 2024 03:05:46 +0800 Subject: [PATCH] feat(core): Add assign shifts evenly preference The `diff` and `L2` usages are directly copied from the 2023/08/20 POC. --- core/nurse_scheduling/context.py | 7 +- core/nurse_scheduling/preference_types.py | 30 ++++++- core/nurse_scheduling/scheduler.py | 87 +++++++++++--------- core/tests/test_or_tools_example_1.py | 13 +++ core/tests/testcases/or_tools_example_1.yaml | 27 ++++++ 5 files changed, 122 insertions(+), 42 deletions(-) create mode 100644 core/tests/test_or_tools_example_1.py create mode 100644 core/tests/testcases/or_tools_example_1.yaml diff --git a/core/nurse_scheduling/context.py b/core/nurse_scheduling/context.py index ff53234..83871f9 100644 --- a/core/nurse_scheduling/context.py +++ b/core/nurse_scheduling/context.py @@ -1,3 +1,6 @@ +from ortools.sat.python import cp_model + + class Context: def __init__(self) -> None: self.startdate = None @@ -9,10 +12,12 @@ def __init__(self) -> None: self.n_days = None self.n_requirements = None self.n_people = None - self.model = None + self.model: cp_model.CpModel = None + self.model_vars = None self.shifts = None self.map_dr_p = None self.map_dp_r = None self.map_d_rp = None self.map_r_dp = None self.map_p_dr = None + self.objective = None diff --git a/core/nurse_scheduling/preference_types.py b/core/nurse_scheduling/preference_types.py index 9dc4b6c..688a167 100644 --- a/core/nurse_scheduling/preference_types.py +++ b/core/nurse_scheduling/preference_types.py @@ -1,6 +1,7 @@ from . import utils +from .context import Context -def all_requirements_fulfilled(ctx, args): +def all_requirements_fulfilled(ctx: Context, args, preference_id): # Hard constraint # For all shifts, the requirements (# of people) must be fulfilled. # Note that a shift is represented as (d, r) @@ -10,7 +11,7 @@ def all_requirements_fulfilled(ctx, args): required_n_people = utils.required_n_people(ctx.requirements[r]) ctx.model.Add(actual_n_people == required_n_people) -def all_people_work_at_most_one_shift_per_day(ctx, args): +def all_people_work_at_most_one_shift_per_day(ctx: Context, args, preference_id): # Hard constraint # For all people, for all days, only work at most one shift. # Note that a shift in day `d` can be represented as `r` instead of (d, r). @@ -20,7 +21,32 @@ def all_people_work_at_most_one_shift_per_day(ctx, args): maximum_n_shifts = 1 ctx.model.Add(actual_n_shifts <= maximum_n_shifts) +def assign_shifts_evenly(ctx: Context, args, preference_id): + # Soft constraint + # For all people, spread the shifts evenly. + # Note that a shift is represented as (d, r) + # i.e., max(weight * (actual_n_shifts - target_n_shifts) ** 2), for all p, + # where actual_n_shifts = sum_{(d, r)}(shifts[(d, r, p)]) + for p in range(ctx.n_people): + actual_n_shifts = sum(ctx.shifts[(d, r, p)] for d, r in ctx.map_p_dr[p]) + target_n_shifts = round(ctx.n_days * sum(requirement.required_people for requirement in ctx.requirements) / ctx.n_people) + unique_var_prefix = f"pref_{preference_id}_p_{p}_" + + # Construct: L2 = actual_n_shifts - target_n_shifts) ** 2 + L, U = -100, 100 # TODO: Calculate the actual bounds + diff_var_name = f"{unique_var_prefix}_diff" + ctx.model_vars[diff_var_name] = diff = ctx.model.NewIntVar(L, U, diff_var_name) + ctx.model.Add(diff == (actual_n_shifts - target_n_shifts)) + L2_var_name = f"{unique_var_prefix}_L2" + ctx.model_vars[L2_var_name] = L2 = ctx.model.NewIntVar(0, max(L**2, U**2), L2_var_name) + ctx.model.AddMultiplicationEquality(L2, diff, diff) + + # Add the objective + weight = -1 + ctx.objective += weight * L2 + PREFERENCE_TYPES_TO_FUNC = { "all requirements fulfilled": all_requirements_fulfilled, "all people work at most one shift per day": all_people_work_at_most_one_shift_per_day, + "assign shifts evenly": assign_shifts_evenly, } diff --git a/core/nurse_scheduling/scheduler.py b/core/nurse_scheduling/scheduler.py index ca60685..4da3c77 100644 --- a/core/nurse_scheduling/scheduler.py +++ b/core/nurse_scheduling/scheduler.py @@ -16,64 +16,69 @@ def schedule(filepath: str, validate=True, deterministic=False): logging.info("Extracting scenario data...") if scenario.apiVersion != "alpha": raise NotImplementedError(f"Unsupported API version: {scenario.apiVersion}") - startdate = scenario.startdate - enddate = scenario.enddate - requirements = scenario.requirements - people = scenario.people - preferences = scenario.preferences + ctx = Context() + ctx.startdate = scenario.startdate + ctx.enddate = scenario.enddate + ctx.requirements = scenario.requirements + ctx.people = scenario.people + ctx.preferences = scenario.preferences del scenario - n_days = (enddate - startdate).days + 1 - n_requirements = len(requirements) - n_people = len(people) - dates = [startdate + timedelta(days=d) for d in range(n_days)] + ctx.n_days = (ctx.enddate - ctx.startdate).days + 1 + ctx.n_requirements = len(ctx.requirements) + ctx.n_people = len(ctx.people) + ctx.dates = [ctx.startdate + timedelta(days=d) for d in range(ctx.n_days)] logging.info("Initializing solver model...") - model = cp_model.CpModel() - shifts = {} + ctx.model = cp_model.CpModel() + ctx.model_vars = {} + ctx.shifts = {} """A set of indicator variables that are 1 if and only if a person (p) is assigned to a shift (d, r).""" logging.info("Creating shift variables...") # Ref: https://developers.google.com/optimization/scheduling/employee_scheduling - for d in range(n_days): - for r in range(n_requirements): + for d in range(ctx.n_days): + for r in range(ctx.n_requirements): # TODO(Optimize): Skip if no people is required in that day - for p in range(n_people): + for p in range(ctx.n_people): # TODO(Optimize): Skip if the person does not qualify for the requirement - shifts[(d, r, p)] = model.NewBoolVar(f"shift_d{d}_r{r}_p{p}") + var_name = f"shift_d{d}_r{r}_p{p}" + ctx.model_vars[var_name] = ctx.shifts[(d, r, p)] = ctx.model.NewBoolVar(var_name) logging.info("Creating maps for faster lookup...") - map_dr_p = { - (d, r): {p for p in range(n_people) if (d, r, p) in shifts} - for (d, r) in itertools.product(range(n_days), range(n_requirements)) + ctx.map_dr_p = { + (d, r): {p for p in range(ctx.n_people) if (d, r, p) in ctx.shifts} + for (d, r) in itertools.product(range(ctx.n_days), range(ctx.n_requirements)) } - map_dp_r = { - (d, p): {r for r in range(n_requirements) if (d, r, p) in shifts} - for (d, p) in itertools.product(range(n_days), range(n_people)) + ctx.map_dp_r = { + (d, p): {r for r in range(ctx.n_requirements) if (d, r, p) in ctx.shifts} + for (d, p) in itertools.product(range(ctx.n_days), range(ctx.n_people)) } - map_d_rp = { - d: {(r, p) for (r, p) in itertools.product(range(n_requirements), range(n_people)) if (d, r, p) in shifts} - for d in range(n_days) + ctx.map_d_rp = { + d: {(r, p) for (r, p) in itertools.product(range(ctx.n_requirements), range(ctx.n_people)) if (d, r, p) in ctx.shifts} + for d in range(ctx.n_days) } - map_r_dp = { - r: {(d, p) for (d, p) in itertools.product(range(n_days), range(n_people)) if (d, r, p) in shifts} - for r in range(n_requirements) + ctx.map_r_dp = { + r: {(d, p) for (d, p) in itertools.product(range(ctx.n_days), range(ctx.n_people)) if (d, r, p) in ctx.shifts} + for r in range(ctx.n_requirements) } - map_p_dr = { - p: {(d, r) for (d, r) in itertools.product(range(n_days), range(n_requirements)) if (d, r, p) in shifts} - for p in range(n_people) + ctx.map_p_dr = { + p: {(d, r) for (d, r) in itertools.product(range(ctx.n_days), range(ctx.n_requirements)) if (d, r, p) in ctx.shifts} + for p in range(ctx.n_people) } - ctx = Context() - for k in vars(ctx): - setattr(ctx, k, locals()[k]) + ctx.objective = 0 logging.info("Adding preferences (including constraints)...") # TODO: Check no duplicated preferences # TODO: Check no overlapping preferences # TODO: Check all required preferences are present - for preference in preferences: - preference_types.PREFERENCE_TYPES_TO_FUNC[preference.type](ctx, preference.args) + for i, preference in enumerate(ctx.preferences): + preference_types.PREFERENCE_TYPES_TO_FUNC[preference.type](ctx, preference.args, i) + + # Define objective (i.e., soft constraints) + print(ctx.objective) + ctx.model.Maximize(ctx.objective) logging.info("Initializing solver...") solver = cp_model.CpSolver() @@ -95,7 +100,7 @@ def on_solution_callback(self): solution_printer = PartialSolutionPrinter() logging.info("Solving and showing partial results...") - status = solver.Solve(model, solution_printer) + status = solver.Solve(ctx.model, solution_printer) logging.info(f"Status: {solver.StatusName(status)}") @@ -110,7 +115,7 @@ def on_solution_callback(self): elif status == cp_model.MODEL_INVALID: logging.info("Model invalid!") logging.info("Validation Info:") - logging.info(model.Validate()) + logging.info(ctx.model.Validate()) else: logging.info("No solution found!") @@ -119,13 +124,17 @@ def on_solution_callback(self): logging.info(f" - branches : {solver.NumBranches()}") logging.info(f" - wall time: {solver.WallTime()}s") + logging.info("Variables:") + for k, v in ctx.model_vars.items(): + logging.info(f" - {k}: {solver.Value(v)}") + logging.info(f"Done.") if not found: return None df = export.get_people_versus_date_dataframe( - dates, people, requirements, - shifts, solver, + ctx.dates, ctx.people, ctx.requirements, + ctx.shifts, solver, ) return df diff --git a/core/tests/test_or_tools_example_1.py b/core/tests/test_or_tools_example_1.py new file mode 100644 index 0000000..93f31f4 --- /dev/null +++ b/core/tests/test_or_tools_example_1.py @@ -0,0 +1,13 @@ +import nurse_scheduling + +def test_example_1(): + filepath = "tests/testcases/or_tools_example_1.yaml" + df = nurse_scheduling.schedule(filepath, validate=False, deterministic=True) + assert df.values.tolist() == [ + ['', 18, 19, 20], + ['', 'Fri', 'Sat', 'Sun'], + ['Nurse 0', '', 'N', 'D'], + ['Nurse 1', 'E', 'E', 'E'], + ['Nurse 2', 'N', '', 'N'], + ['Nurse 3', 'D', 'D', ''] + ] diff --git a/core/tests/testcases/or_tools_example_1.yaml b/core/tests/testcases/or_tools_example_1.yaml new file mode 100644 index 0000000..d98d104 --- /dev/null +++ b/core/tests/testcases/or_tools_example_1.yaml @@ -0,0 +1,27 @@ +apiVersion: alpha +description: OR-Tools Example 1. From . +startdate: 2023-08-18 +enddate: 2023-08-20 +people: + - id: 0 + description: Nurse 0 + - id: 1 + description: Nurse 1 + - id: 2 + description: Nurse 2 + - id: 3 + description: Nurse 3 +requirements: + - id: D + description: Day shift requirement + required_people: 1 + - id: E + description: Evening shift requirement + required_people: 1 + - id: N + description: Night shift requirement + required_people: 1 +preferences: + - type: all requirements fulfilled + - type: all people work at most one shift per day + - type: assign shifts evenly