diff --git a/ehrql/dummy_data_nextgen/generator.py b/ehrql/dummy_data_nextgen/generator.py index f9880b0d0..3b16feb5b 100644 --- a/ehrql/dummy_data_nextgen/generator.py +++ b/ehrql/dummy_data_nextgen/generator.py @@ -4,10 +4,12 @@ import random import string import time +from bisect import bisect_left from contextlib import contextmanager from datetime import date, timedelta -from ehrql.dummy_data_nextgen.query_info import QueryInfo +from ehrql.dummy_data_nextgen.query_info import QueryInfo, filter_values +from ehrql.exceptions import CannotGenerate from ehrql.query_engines.in_memory import InMemoryQueryEngine from ehrql.query_engines.in_memory_database import InMemoryDatabase from ehrql.query_model.introspection import all_inline_patient_ids @@ -44,7 +46,10 @@ def __init__( # suitable time range by inspecting the query, this will have to do. self.today = today if today is not None else date.today() self.patient_generator = DummyPatientGenerator( - self.variable_definitions, self.random_seed, self.today + self.variable_definitions, + self.random_seed, + self.today, + self.population_size, ) log.info("Using next generation dummy data generation") @@ -138,11 +143,15 @@ def get_results(self): class DummyPatientGenerator: - def __init__(self, variable_definitions, random_seed, today): + def __init__(self, variable_definitions, random_seed, today, population_size): self.__rnd = None self.random_seed = random_seed self.today = today self.query_info = QueryInfo.from_variable_definitions(variable_definitions) + self.population_size = population_size + + self.__column_values = {} + self.__reset_event_range() @property def rnd(self): @@ -202,21 +211,23 @@ def get_patient_column(self, column_name): except KeyError: pass + def __reset_event_range(self): + self.events_start = date(1900, 1, 1) + self.events_end = self.today + def generate_patient_facts(self, patient_id): # Seed the random generator using the patient_id so we always generate the same # data for the same patient with self.seed(patient_id): - # TODO: We could obviously generate more realistic age distributions than this - + self.__reset_event_range() + iters = 0 while True: + iters += 1 + assert iters <= 1000 # Retry until we have a date of birth and date of death that are # within reasonable ranges dob_column = self.get_patient_column("date_of_birth") - if dob_column is not None and dob_column.get_constraint( - Constraint.GeneralRange - ): - self.events_start = self.today - timedelta(days=120 * 365) - self.events_end = self.today + if dob_column is not None: date_of_birth = self.get_random_value(dob_column) else: date_of_birth = self.today - timedelta( @@ -224,23 +235,33 @@ def generate_patient_facts(self, patient_id): ) dod_column = self.get_patient_column("date_of_death") - if dod_column is not None and dod_column.get_constraint( - Constraint.GeneralRange - ): + if dod_column is not None: date_of_death = self.get_random_value(dod_column) else: age_days = self.rnd.randrange(105 * 365) date_of_death = date_of_birth + timedelta(days=age_days) - if date_of_death >= date_of_birth and ( - date_of_death - date_of_birth < timedelta(105 * 365) + if ( + date_of_birth is None + or date_of_death is None + or ( + date_of_death >= date_of_birth + and (date_of_death - date_of_birth < timedelta(105 * 365)) + ) ): break self.date_of_birth = date_of_birth - self.date_of_death = date_of_death if date_of_death < self.today else None self.events_start = self.date_of_birth - self.events_end = min(self.today, date_of_death) + + if date_of_death is None: + self.date_of_death = None + self.events_end = self.today + else: + self.date_of_death = ( + date_of_death if date_of_death < self.today else None + ) + self.events_end = min(self.today, date_of_death) def rows_for_patients(self, table_info): row = { @@ -279,94 +300,164 @@ def populate_row(self, table_info, row): if name not in row: row[name] = self.get_random_value(column_info) - def get_random_value(self, column_info): - # TODO: This never returns None although for realism it sometimes should - if cat_constraint := column_info.get_constraint(Constraint.Categorical): - # TODO: It's obviously not true in general that categories are equiprobable - return self.rnd.choice(cat_constraint.values) - elif range_constraint := column_info.get_constraint(Constraint.ClosedRange): - return self.rnd.randrange( - range_constraint.minimum, - range_constraint.maximum + 1, - range_constraint.step, + def __check_values(self, column_info, result): + if not result: + raise CannotGenerate( + f"Unable to find any values for {column_info.name} that satisfy the population definition." + + ( + "" + if self.__is_exhaustive(column_info) + else " If you believe this should be possible, please report this as a bug." + ) ) - elif (column_info.type is date) and ( - date_range_constraint := column_info.get_constraint(Constraint.GeneralRange) - ): - if date_range_constraint.maximum is not None: - maximum = date_range_constraint.maximum - else: - maximum = self.today - - if not date_range_constraint.includes_maximum: - maximum -= timedelta(days=1) - if date_range_constraint.minimum is not None: - minimum = date_range_constraint.minimum - if not date_range_constraint.includes_minimum: - minimum += timedelta(days=1) - # TODO: Currently this code only runs when the column is date_of_birth - # so condition is always hit. Remove this pragma when that stops being - # the case. - if ( - column_info.get_constraint(Constraint.FirstOfMonth) - and minimum.day != 1 + + for v in result: + assert v is None or isinstance(v, column_info.type) + return result + + def __is_exhaustive(self, column_info): + if column_info.get_constraint( + Constraint.Categorical + ) or column_info.get_constraint(Constraint.ClosedRange): + return True + return column_info.type not in (int, float, str) + + def get_possible_values(self, column_info): + try: + return self.__column_values[column_info] + except KeyError: + pass + + with self.seed(f"columns:{column_info.name}"): + exhaustive = True + + # Arbitrary small number of retries for when we don't manage + # to generate enough of some unbounded range the first time. + for _ in range(3): + if cat_constraint := column_info.get_constraint(Constraint.Categorical): + base_values = list(cat_constraint.values) + elif range_constraint := column_info.get_constraint( + Constraint.ClosedRange ): - if minimum.month == 12: - minimum = minimum.replace(year=minimum.year + 1, month=1, day=1) + base_values = range( + range_constraint.minimum, + range_constraint.maximum + 1, + range_constraint.step, + ) + elif column_info.type is date: + earliest_possible = date(1900, 1, 1) + base_values = [ + earliest_possible + timedelta(days=i) + for i in range((self.today - earliest_possible).days + 1) + ] + elif column_info.type is bool: + base_values = [False, True] + elif column_info.type is int: + base_values = list( + range( + -max(100, self.population_size * 2), + max(100, self.population_size * 2), + ) + ) + exhaustive = False + elif column_info.type is float: + base_values = [ + 0.01 * i for i in range(max(101, self.population_size * 2 + 1)) + ] + exhaustive = False + elif column_info.type is str: + exhaustive = False + if column_info._values_used: + # If we know some good strings already there's no point in generating + # additional strings that almost certainly won't work.' + base_values = [] + elif regex_constraint := column_info.get_constraint( + Constraint.Regex + ): + generator = get_regex_generator(regex_constraint.regex) + base_values = [ + generator(self.rnd) + for _ in range(self.population_size * 10) + ] else: - minimum = minimum.replace(month=minimum.month + 1, day=1) - else: - minimum = (maximum - timedelta(days=100 * 365)).replace(day=1) - - assert minimum <= maximum - - days = (maximum - minimum).days - result = minimum + timedelta(days=random.randint(0, days)) - # TODO: Currently this code only runs when the column is date_of_birth - # so condition is always hit. Remove this pragma when that stops being - # the case. - if column_info.get_constraint(Constraint.FirstOfMonth): # pragma: no branch - assert minimum.day == 1 - result = result.replace(day=1) - assert minimum <= result <= maximum - return result - elif column_info.values_used: - if self.rnd.randint(0, len(column_info.values_used)) != 0: - return self.rnd.choice(column_info.values_used) - elif column_info.type is bool: - return self.rnd.choice((True, False)) - elif column_info.type is int: - # TODO: This distributon is obviously ridiculous but will do for now - return self.rnd.randrange(100) - elif column_info.type is float: - # TODO: As is this - return self.rnd.random() * 100 - elif column_info.type is str: - # If the column must match a regex then generate matching strings - if regex_constraint := column_info.get_constraint(Constraint.Regex): - generator = get_regex_generator(regex_constraint.regex) - return generator(self.rnd) - # A random ASCII string is unlikely to be very useful here, but it at least - # makes it a bit clearer what the issue is (that we don't know enough about - # the column to generate anything more helpful) rather than the blank string - # we always used to return - return "".join( - self.rnd.choice(CHARS) for _ in range(self.rnd.randrange(16)) + # A random ASCII string is unlikely to be very useful here, but it at least + # makes it a bit clearer what the issue is (that we don't know enough about + # the column to generate anything more helpful) rather than the blank string + # we always used to return + base_values = [ + "".join( + self.rnd.choice(CHARS) + for _ in range(self.rnd.randrange(16)) + ) + for _ in range(self.population_size * 10) + ] + else: + assert False + + base_values = list(base_values) + base_values.extend(column_info._values_used) + base_values.append(None) + if column_info.name == "date_of_death": + base_values = [ + v for v in base_values if v is None or v < self.today + ] + + base_values = [ + v + for v in base_values + if all(c.validate(v) for c in column_info.constraints) + ] + + if column_info.query is None: + values = base_values + else: + values = filter_values(column_info.query, base_values) + + if exhaustive or values: + break + + values.sort(key=lambda x: (x is not None, x)) + + values = self.__check_values(column_info, values) + assert values[0] is None or None not in values + + self.__column_values[column_info] = values + return values + + def get_random_value(self, column_info): + values = self.get_possible_values(column_info) + assert values + if column_info.type is date: + result = self.rnd.choice(values) + if result is None: + return result + if self.events_start <= result <= self.events_end: + return result + + lo = bisect_left( + values, self.events_start, lo=1 if values[0] is None else 0 ) - elif column_info.type is date: - # Use an exponential distribution to preferentially generate recent events - # (mean of one year ago). This works OK for the our immediate purposes but - # we'll no doubt have to iterate on this. - days_ago = int(self.rnd.expovariate(1 / 365)) - event_date = self.events_end - timedelta(days=days_ago) - # Clip to the available time range - event_date = max(event_date, self.events_start) - # Apply any FirstOfMonth constraints - if column_info.get_constraint(Constraint.FirstOfMonth): - event_date = event_date.replace(day=1) - return event_date + hi = bisect_left(values, self.events_end, lo=lo) + if hi < len(values) and values[hi] == self.events_end: + hi += 1 + if lo >= len(values) or hi == 0 or lo == hi: + # TODO: This is something of a bad hack. + # We've found ourselves in a situation where we've generated + # a patient that can't actually have a valid value for this, + # but we are required to have one. The solution here is to just + # return some random nonsense and let the population definition + # exclude this patient. + # + # We pick values[0] in particular because that's where None will + # be, so it's the only possible valid value, but if this column + # is not nullable then it'll just be an arbitrary date that can't + # work. + return values[0] + + i = self.rnd.randrange(lo, hi) + return values[i] else: - assert False, f"Unhandled type: {column_info.type}" + return self.rnd.choice(values) def get_empty_data(self): return { diff --git a/ehrql/dummy_data_nextgen/measures.py b/ehrql/dummy_data_nextgen/measures.py index 6b31eff43..bd9d5ada8 100644 --- a/ehrql/dummy_data_nextgen/measures.py +++ b/ehrql/dummy_data_nextgen/measures.py @@ -28,7 +28,8 @@ def get_data(self): return self.generator.get_data() def get_results(self): - database = InMemoryDatabase(self.get_data()) + data = self.get_data() + database = InMemoryDatabase(data) engine = InMemoryQueryEngine(database) return get_measure_results(engine, self.measures) diff --git a/ehrql/dummy_data_nextgen/query_info.py b/ehrql/dummy_data_nextgen/query_info.py index 159b9022c..74c7bde03 100644 --- a/ehrql/dummy_data_nextgen/query_info.py +++ b/ehrql/dummy_data_nextgen/query_info.py @@ -1,24 +1,30 @@ import dataclasses from collections import defaultdict -from datetime import date, timedelta +from collections.abc import Mapping from functools import cached_property, lru_cache +from ehrql.query_engines.in_memory import InMemoryQueryEngine +from ehrql.query_engines.in_memory_database import InMemoryDatabase, Rows from ehrql.query_model.introspection import all_unique_nodes, get_table_nodes from ehrql.query_model.nodes import ( + AggregateByPatient, + Case, Column, Function, InlinePatientTable, + Node, SelectColumn, SelectPatientTable, SelectTable, TableSchema, Value, + get_input_nodes, get_root_frame, ) -from ehrql.query_model.table_schema import Constraint +from ehrql.query_model.query_graph_rewriter import QueryGraphRewriter -@dataclasses.dataclass +@dataclasses.dataclass(unsafe_hash=True) class ColumnInfo: """ Captures information about a column as used in a particular dataset definition @@ -27,19 +33,19 @@ class ColumnInfo: name: str type: type # NOQA: A003 constraints: tuple = () - _values_used: set = dataclasses.field(default_factory=set) + query: Node | None = None + _values_used: set = dataclasses.field(default_factory=set, hash=False) @classmethod - def from_column(cls, name, column, extra_constraints=()): + def from_column(cls, name, column, query): type_ = column.type_ if hasattr(type_, "_primitive_type"): type_ = type_._primitive_type() return cls( name, type_, - constraints=normalize_constraints( - tuple(column.constraints) + tuple(extra_constraints) - ), + query=query, + constraints=tuple(column.constraints), ) def __post_init__(self): @@ -109,10 +115,6 @@ def from_variable_definitions(cls, variable_definitions): all_nodes = all_unique_nodes(*variable_definitions.values()) by_type = get_nodes_by_type(all_nodes) - extra_constraints = query_to_column_constraints( - variable_definitions["population"] - ) - tables = { # Create a TableInfo object … table.name: TableInfo.from_table(table) @@ -139,10 +141,22 @@ def from_variable_definitions(cls, variable_definitions): column_info = table_info.columns.get(name) if column_info is None: # … insert a ColumnInfo object into the appropriate table + base_column = SelectColumn(source=table, name=column.name) + + specialized_query = specialize( + variable_definitions["population"], base_column + ) + + # TODO: We should actually check whether it's a False value here and raise an + # error if it is. + if specialized_query is not None and is_value(specialized_query): + specialized_query = None + if specialized_query is not None: + assert tuple(columns_for_query(specialized_query)) == (base_column,) column_info = ColumnInfo.from_column( name, table.schema.get_column(name), - extra_constraints=extra_constraints.get(column, ()), + query=specialized_query, ) table_info.columns[name] = column_info # Record the ColumnInfo object associated with each SelectColumn node @@ -193,161 +207,207 @@ def sort_by_name(iterable): @lru_cache -def query_to_column_constraints(query): - """Converts a query (typically a population definition) into - constraints that would have to be applied to a record in order - to satisfy it.""" +def is_value(query): + if query is None: + return True + elif isinstance(query, Value): + return True + else: + children = get_input_nodes(query) + return children and all(is_value(child) for child in children) + + +@lru_cache +def columns_for_query(query: Node): + """Returns all columns referenced in a given query.""" + return frozenset( + { + subnode + for subnode in all_unique_nodes(query) + if isinstance(subnode, SelectColumn) + and isinstance( + subnode.source, SelectTable | SelectPatientTable | InlinePatientTable + ) + } + ) + + +def specialize(query, column) -> Node | None: + """Takes query and specialises it to one that only references column. + Satisfying the resulting query is necessary but not in general sufficient + to satisfy the source query. + """ + assert len(columns_for_query(column)) == 1 + if is_value(query): + return query match query: case Function.And(lhs=lhs, rhs=rhs): - left = query_to_column_constraints(lhs) - right = query_to_column_constraints(rhs) - keys = set(left) | set(right) - return {k: left.get(k, []) + right.get(k, []) for k in keys} + lhs = specialize(lhs, column) + rhs = specialize(rhs, column) + if lhs is None: + return rhs + if rhs is None: + return lhs + result = Function.And(lhs, rhs) + assert len(columns_for_query(result)) == 1 + return result case Function.Or(lhs=lhs, rhs=rhs): - left = query_to_column_constraints(lhs) - right = query_to_column_constraints(rhs) - result = {} - for k, v in left.items(): - try: - result[k] = list(set(v) & set(right[k])) - except KeyError: - pass - for k, v in list(result.items()): - if not v: - del result[k] + lhs = specialize(lhs, column) + rhs = specialize(rhs, column) + if lhs is None or rhs is None: + return None + result = Function.Or(lhs=lhs, rhs=rhs) + assert len(columns_for_query(result)) == 1 return result - case Function.EQ( - lhs=SelectColumn() as lhs, - rhs=Value(value=value), - ): - return {lhs: [Constraint.Categorical(values=(value,))]} - case Function.EQ( - lhs=Function.YearFromDate(source=SelectColumn() as column), - rhs=Value(value=year), - ): - return { - column: [ - Constraint.GeneralRange( - minimum=date(year, 1, 1), - maximum=date(year, 12, 31), - ) - ] - } - case Function.In( - lhs=SelectColumn() as lhs, - rhs=Value(value=values), - ): - return {lhs: [Constraint.Categorical(values=values)]} - case Function.GE( - lhs=Function.DateDifferenceInYears( - lhs=Value(value=reference_date), rhs=column - ), - rhs=Value(value=difference), - ): - return { - column: [ - Constraint.GeneralRange( - maximum=reference_date - timedelta(days=365 * difference) - ) - ] - } - case Function.LE( - lhs=Function.DateDifferenceInYears( - lhs=Value(value=reference_date), rhs=column - ), - rhs=Value(value=difference), + + # TODO: This could really use a nicer way of handling it. + # All of them create some sort of follow on obligations though + # that can only be handled by creating additional records, + # so for this first pass generation of data we do need to + # exclude them. + case ( + AggregateByPatient.Count() + | AggregateByPatient.CountDistinct() + | AggregateByPatient.Exists() ): - return { - column: [ - Constraint.GeneralRange( - minimum=reference_date.replace( - year=reference_date.year - difference - ) - ) - ] + return None + case ( + ( + Function.EQ(rhs=Case()) + | Function.NE(rhs=Case()) + | Function.LT(rhs=Case()) + | Function.GT(rhs=Case()) + | Function.LE(rhs=Case()) + | Function.GE(rhs=Case()) + ) as comp + ) if column not in columns_for_query(comp.rhs): + case_statement = comp.rhs + if case_statement.default is None: + rewritten = None + else: + rewritten = comp.__class__(lhs=comp.lhs, rhs=case_statement.default) + for v in case_statement.cases.values(): + if v is None: + continue + part = comp.__class__(lhs=comp.lhs, rhs=v) + if rewritten is None: + rewritten = part + else: + rewritten = Function.Or(rewritten, part) + return specialize(rewritten, column) + case ( + ( + Function.EQ(lhs=Case()) + | Function.NE(lhs=Case()) + | Function.LT(lhs=Case()) + | Function.GT(lhs=Case()) + | Function.LE(lhs=Case()) + | Function.GE(lhs=Case()) + ) as comp + ) if column not in columns_for_query(comp.lhs): + opposites: dict[type, type] = { + Function.LT: Function.GT, + Function.LE: Function.GE, } - case Function.LT( - lhs=Function.DateAddYears( - lhs=SelectColumn() as column, - rhs=Value(value=difference), - ), - rhs=Value(value=reference_date), + opposites.update([(v, k) for k, v in opposites.items()]) + assert len(opposites) == 4 + opposite_type = opposites.get(type(comp), type(comp)) + return specialize(opposite_type(lhs=comp.rhs, rhs=comp.lhs), column) + case ( + ( + Function.EQ() + | Function.NE() + | Function.LT() + | Function.GT() + | Function.LE() + | Function.GE() + ) as comp ): - return { - column: [ - Constraint.GeneralRange( - maximum=reference_date.replace( - year=reference_date.year - difference - ), - includes_maximum=False, - ) - ] - } - case Function.GT(lhs=SelectColumn() as column, rhs=Value(value=min_value)): - return { - column: [ - Constraint.GeneralRange(minimum=min_value, includes_minimum=False) - ] - } - case Function.GE(lhs=SelectColumn() as column, rhs=Value(value=min_value)): - return { - column: [ - Constraint.GeneralRange(minimum=min_value, includes_minimum=True) - ] - } - case Function.LT(lhs=SelectColumn() as column, rhs=Value(value=max_value)): - return { - column: [ - Constraint.GeneralRange(maximum=max_value, includes_maximum=False) - ] - } - case Function.LE(lhs=SelectColumn() as column, rhs=Value(value=max_value)): - return { - column: [ - Constraint.GeneralRange(maximum=max_value, includes_maximum=True) - ] - } - case Function.IsNull(source=SelectColumn() as column): - return {column: [Constraint.NotNull()]} - - return {} - - -def normalize_constraints(constraints): - group_by_type = defaultdict(list) - for constraint in constraints: - group_by_type[type(constraint)].append(constraint) - if len(group_by_type[Constraint.Categorical]) > 1: - constraint, *rest = group_by_type[Constraint.Categorical] - for more in rest: - constraint = Constraint.Categorical( - values=set(constraint.values) & set(more.values) - ) - group_by_type[Constraint.Categorical] = [constraint] - if len(ranges := group_by_type[Constraint.GeneralRange]) > 1: - minimum = None - maximum = None - for r in ranges: - if minimum is None: - minimum = r.minimum - elif r.minimum is not None: - minimum = max(minimum, r.minimum) - if maximum is None: - maximum = r.maximum - elif r.maximum is not None: - maximum = min(maximum, r.maximum) - - includes_minimum = minimum is None or all(r.validate(minimum) for r in ranges) - includes_maximum = maximum is None or all(r.validate(maximum) for r in ranges) - group_by_type[Constraint.GeneralRange] = [ - Constraint.GeneralRange( - minimum=minimum, - maximum=maximum, - includes_maximum=includes_maximum, - includes_minimum=includes_minimum, - ) - ] - - return tuple( - [constraint for group in group_by_type.values() for constraint in group] + lhs = specialize(comp.lhs, column) + rhs = specialize(comp.rhs, column) + if lhs is None or rhs is None: + return None + return type(comp)(lhs=lhs, rhs=rhs) + case SelectColumn() as q: + if column == q: + assert len(columns_for_query(q)) == 1 + return q + else: + return None + case _: + fields = query.__dataclass_fields__ + specialized = {} + for k in fields: + v = getattr(query, k) + if isinstance(v, Node): + v = specialize(v, column) + if v is None: + return None + elif isinstance(v, Mapping): + items = list(v.items()) + new_items = {} + for x, y in items: + x = specialize(x, column) + if x is None: + return None + y = specialize(y, column) + if y is None: + return None + new_items[x] = y + v = type(v)(new_items) + else: + try: + values = list(v) + except TypeError: + pass + else: + new_values = [] + for elt in values: + elt = specialize(elt, column) + if elt is None: + return None + new_values.append(elt) + v = type(v)(new_values) + specialized[k] = v + return type(query)(**specialized) + + +def filter_values(query, values): + """Returns the subset of `values` that can appear in a result for `query`. + + `query` may only refer to a single column (which `values` will be interpreted + as belonging to). + """ + (column,) = columns_for_query(query) + source = get_root_frame(column) + column_name = column.name + simplified_schema = TableSchema( + dummy_column_for_sorting=Column(int), + **{column_name: source.schema.get_column(column_name)}, + ) + fake_table = type(source)(name=source.name, schema=simplified_schema) + replacement_column = SelectColumn(source=fake_table, name=column_name) + database = InMemoryDatabase( + {fake_table: [(i, i, v) for i, v in enumerate(values, 1)]} ) + engine = InMemoryQueryEngine(database) + + rewriter = QueryGraphRewriter() + rewriter.replace(column, replacement_column) + + rows = list(engine.get_results({"population": rewriter.rewrite(query)})) + + # If we're picking from an event frame we may get a Rows object rather than + # a value back. We only care about the distinct values that can be returned + # here, so we just care about what values are in the rows. + result = [] + for row in rows: + value = database.tables[fake_table.name][column_name][row.patient_id] + if isinstance(value, Rows): + result.extend(value.values()) + else: + result.append(value) + for v in result: + assert not isinstance(v, Rows) + + return result diff --git a/ehrql/exceptions.py b/ehrql/exceptions.py new file mode 100644 index 000000000..b142f9cd9 --- /dev/null +++ b/ehrql/exceptions.py @@ -0,0 +1,17 @@ +class EHRQLException(Exception): + """Base exception for EHRQL errors of all sorts. + + This is not yet reliably used everywhere it should be. + """ + + +class DummyDataException(EHRQLException): + """Base class for dummy data errors.""" + + +class CannotGenerate(DummyDataException): + """Raised when a population definition cannot be satisfied. + + This may be because it is logically impossible, or it may be + logically possible but we were unable to do so. + """ diff --git a/tests/generative/test_query_model.py b/tests/generative/test_query_model.py index 81346c406..85ffd9249 100644 --- a/tests/generative/test_query_model.py +++ b/tests/generative/test_query_model.py @@ -9,7 +9,8 @@ import pytest from hypothesis.vendor.pretty import _singleton_pprinters, pretty -from ehrql.dummy_data import DummyDataGenerator +from ehrql import dummy_data, dummy_data_nextgen +from ehrql.exceptions import CannotGenerate from ehrql.query_model.introspection import all_unique_nodes from ehrql.query_model.nodes import ( AggregateByPatient, @@ -83,7 +84,14 @@ settings = dict( max_examples=(int(os.environ.get("GENTEST_EXAMPLES", 10))), deadline=None, - derandomize=not os.environ.get("GENTEST_RANDOMIZE"), + # On CI we want to run derandomized unless we're explicitly in the + # long-running bug finding tests. This prevents flaky CI. + # In development we never want to run with derandomize because it's + # less effective at finding bugs and, more importantly, has very slow + # replay due to turning off the test database + derandomize=bool(os.environ.get("GENTEST_RANDOMIZE")) + if os.environ.get("CI") + else False, # The explain phase is comparatively expensive here given how # costly data generation is for our tests here, so we turn it # off by default. @@ -92,6 +100,7 @@ if os.environ.get("GENTEST_EXPLAIN") != "true" else hyp.Phase ), + report_multiple_bugs=False, ) @@ -114,6 +123,7 @@ def query_engines(request): class EnabledTests(Enum): serializer = auto() dummy_data = auto() + dummy_data_nextgen = auto() main_query = auto() all_population = auto() pretty_printing = auto() @@ -149,6 +159,8 @@ def test_query_model( run_serializer_test(population, variable) if EnabledTests.dummy_data in test_types: run_dummy_data_test(population, variable) + if EnabledTests.dummy_data_nextgen in test_types: + run_dummy_data_test(population, variable, next_gen=True) if EnabledTests.main_query in test_types: run_test(query_engines, data, population, variable, recorder) if EnabledTests.pretty_printing in test_types: @@ -252,18 +264,21 @@ def run_with(engine, instances, variables): engine.teardown() -def run_dummy_data_test(population, variable): +def run_dummy_data_test(population, variable, next_gen=False): try: - run_dummy_data_test_without_error_handling(population, variable) + run_dummy_data_test_without_error_handling(population, variable, next_gen) except Exception as e: # pragma: no cover if not get_ignored_error_type(e): raise -def run_dummy_data_test_without_error_handling(population, variable): +def run_dummy_data_test_without_error_handling(population, variable, next_gen=False): # We can't do much more here than check that the generator runs without error, but # that's enough to catch quite a few issues - dummy_data_generator = DummyDataGenerator( + + dummy = dummy_data_nextgen if next_gen else dummy_data + + dummy_data_generator = dummy.DummyDataGenerator( {"population": population, "v": variable}, population_size=1, # We need a batch size bigger than one otherwise by chance (or, more strictly, @@ -273,10 +288,17 @@ def run_dummy_data_test_without_error_handling(population, variable): batch_size=5, timeout=-1, ) - assert isinstance(dummy_data_generator.get_data(), dict) + try: + assert isinstance(dummy_data_generator.get_data(), dict) + # TODO: This isn't reliably getting hit. Figure out how to make it be so. + # This error is logically possible here but the actual code paths are tested + # elsewhere so it's not that important for the generative tests to be able to + # hit it. + except CannotGenerate: # pragma: no cover + pass # Using a simplified population definition which should always have matching patients # we can confirm that we generate at least some data - dummy_data_generator = DummyDataGenerator( + dummy_data_generator = dummy.DummyDataGenerator( {"population": all_patients_query, "v": variable}, population_size=1, batch_size=1, @@ -338,8 +360,8 @@ def test_variable_strategy_is_comprehensive(): # The specific seed used has no particular significance. This test is just # a bit fragile. If it fails and you think this isn't a real failure, feel # free to tweak the seed a bit and see if that fixes it. - @hyp.settings(max_examples=500, database=None, deadline=None) - @hyp.seed(2789686902) + @hyp.settings(max_examples=600, database=None, deadline=None) + @hyp.seed(3457902459072) @hyp.given(variable=variable_strategy) def record_operations_seen(variable): operations_seen.update(type(node) for node in all_unique_nodes(variable)) diff --git a/tests/generative/variable_strategies.py b/tests/generative/variable_strategies.py index b54fa2efc..7f585770b 100644 --- a/tests/generative/variable_strategies.py +++ b/tests/generative/variable_strategies.py @@ -532,14 +532,14 @@ def filtered_table(draw): @st.composite def sorted_frame(draw): # Decide how many Sorts and Filters (if any) we're going to apply - num_sorts = draw(st.integers(min_value=1, max_value=3)) - num_filters = draw(st.integers(min_value=0, max_value=6)) - # Mix up the order of operations - operations = [filter_] * num_filters + [sort] * num_sorts - shuffled_operations = draw(st.permutations(operations)) + operations = draw( + st.lists(st.sampled_from([sort, filter_]), min_size=1, max_size=9).filter( + lambda ls: (1 <= ls.count(sort) <= 3) and (ls.count(filter_) <= 6) + ) + ) # Pick a table and apply the operations source = draw(select_table()) - for operation in shuffled_operations: + for operation in operations: source = draw(operation(source)) return source @@ -559,15 +559,23 @@ def select_patient_table(): @st.composite def inline_patient_table(draw): - patient_ids = draw(st.lists(st.integers(1, 10), unique=True)) - rows = tuple( - ( - patient_id, - *[draw(value_strategies[type_]) for name, type_ in schema.column_types], - ) - for patient_id in patient_ids + return InlinePatientTable( + rows=tuple( + draw( + st.lists( + st.tuples( + st.integers(1, 10), + *[ + value_strategies[type_] + for name, type_ in schema.column_types + ], + ), + unique_by=lambda r: r[0], + ), + ) + ), + schema=schema, ) - return InlinePatientTable(rows=rows, schema=schema) @st.composite def filter_(draw, source): diff --git a/tests/unit/dummy_data_nextgen/test_edge_cases_for_coverage.py b/tests/unit/dummy_data_nextgen/test_edge_cases_for_coverage.py new file mode 100644 index 000000000..c416c46f3 --- /dev/null +++ b/tests/unit/dummy_data_nextgen/test_edge_cases_for_coverage.py @@ -0,0 +1,88 @@ +from datetime import date + +from ehrql.dummy_data_nextgen.query_info import QueryInfo, is_value, specialize +from ehrql.query_language import ( + Series, +) +from ehrql.query_model.nodes import ( + Case, + Function, + SelectColumn, + SelectPatientTable, + TableSchema, + Value, +) + + +def test_check_is_value(): + assert is_value(Function.GT(lhs=Value(value=0.0), rhs=Value(value=0.0))) + + +schema = TableSchema( + i1=Series(int), + b0=Series(bool), + b1=Series(bool), + d1=Series(date), +) +p0 = SelectPatientTable(name="p0", schema=schema) + + +def test_case_is_not_value_with_non_value_outcome(): + assert not is_value( + Case( + cases={ + Value(True): SelectColumn(p0, "i1"), + }, + default=None, + ) + ) + + +def test_case_is_value_with_value_outcome(): + assert is_value( + Case( + cases={ + Value(True): Value(0.0), + }, + default=None, + ) + ) + + +def test_minimum_of_values_is_value(): + assert is_value(Function.MinimumOf((Value(0), Value(1)))) + + +def test_some_nonsense(): + QueryInfo.from_variable_definitions( + { + "population": Function.LT( + lhs=Case( + cases={SelectColumn(source=p0, name="b1"): None}, + default=Value(value=date(2010, 1, 1)), + ), + rhs=Value(value=date(2010, 1, 1)), + ), + "v": SelectColumn( + source=p0, + name="i1", + ), + } + ) + + +def test_rewrites_rhs_case_to_or(): + table = SelectPatientTable(name="p0", schema=schema) + + specialized = specialize( + Function.LT( + lhs=Value(value=date(2010, 1, 1)), + rhs=Case( + cases={SelectColumn(table, name="b1"): Value(value=date(2010, 1, 1))}, + default=Value(value=date(2010, 1, 1)), + ), + ), + SelectColumn(table, name="i1"), + ) + + assert isinstance(specialized, Function.Or) diff --git a/tests/unit/dummy_data_nextgen/test_generator.py b/tests/unit/dummy_data_nextgen/test_generator.py index d1fd25307..0ff7fe401 100644 --- a/tests/unit/dummy_data_nextgen/test_generator.py +++ b/tests/unit/dummy_data_nextgen/test_generator.py @@ -29,7 +29,11 @@ class patients(PatientFrame): date_of_birth = Series(datetime.date, constraints=[Constraint.FirstOfMonth()]) date_of_death = Series(datetime.date) sex = Series( - str, constraints=[Constraint.Categorical(["male", "female", "intersex"])] + str, + constraints=[ + Constraint.Categorical(["male", "female", "intersex"]), + Constraint.NotNull(), + ], ) @@ -70,26 +74,31 @@ def test_dummy_data_generator(): dataset.date = last_event.date # Generate some results + target_size = 7 + variable_definitions = compile(dataset) - generator = DummyDataGenerator(variable_definitions) - generator.population_size = 7 + generator = DummyDataGenerator(variable_definitions, population_size=target_size) generator.batch_size = 4 results = list(generator.get_results()) # Check they look right - assert len(results) == 7 + + assert any(r.code is not None for r in results) + assert any(r.date is not None for r in results) for r in results: assert isinstance(r.date_of_birth, datetime.date) assert r.date_of_birth.day == 1 assert r.date_of_death is None or r.date_of_death > r.date_of_birth - assert r.sex in {"male", "female", "intersex"} + assert r.sex in {"male", "female", "intersex", None} # To get full coverage here we need to generate enough data so that we get at # least one patient with a matching event and one without if r.code is not None or r.date is not None: assert r.code in {"abc", "def"} assert isinstance(r.date, datetime.date) - assert r.imd in {0, 1000, 2000, 3000, 4000, 5000} + assert r.imd in {0, 1000, 2000, 3000, 4000, 5000, None} + + assert len(results) == target_size @mock.patch("ehrql.dummy_data_nextgen.generator.time") @@ -118,7 +127,10 @@ def test_dummy_data_generator_timeout_with_some_results(patched_time): def test_dummy_data_generator_timeout_with_no_results(patched_time): # Define a dataset with a condition no patient can match dataset = Dataset() - dataset.define_population(patients.sex != patients.sex) + dataset.define_population( + (patients.date_of_birth == patients.date_of_death) + & (patients.date_of_death.day == 2) + ) variable_definitions = compile(dataset) generator = DummyDataGenerator(variable_definitions) @@ -234,9 +246,10 @@ def test_get_random_value_on_first_of_month_with_last_month_minimum( constraints=( Constraint.FirstOfMonth(), Constraint.GeneralRange( - minimum=datetime.datetime(2020, 12, 5), - maximum=datetime.datetime(2021, 1, 30), + minimum=datetime.date(2020, 12, 5), + maximum=datetime.date(2021, 1, 30), ), + Constraint.NotNull(), ), ) with dummy_patient_generator.seed(""): @@ -245,7 +258,7 @@ def test_get_random_value_on_first_of_month_with_last_month_minimum( ] # All generated dates should be forced to 2021-01-01 assert len(set(values)) == 1 - assert all(value == datetime.datetime(2021, 1, 1) for value in values) + assert all(value == datetime.date(2021, 1, 1) for value in values) def test_get_random_str(dummy_patient_generator): @@ -306,7 +319,7 @@ def test_get_random_int_with_range(dummy_patient_generator): values = [ dummy_patient_generator.get_random_value(column_info) for _ in range(10) ] - assert all(value in [0, 2, 4, 6, 8, 10] for value in values), values + assert all(value in [0, 2, 4, 6, 8, 10, None] for value in values), values def test_cannot_generate_data_outside_of_a_seed_block(dummy_patient_generator): @@ -328,6 +341,7 @@ def dummy_patient_generator(): variable_definitions, random_seed="abc", today=datetime.date(2024, 1, 1), + population_size=1000, ) generator.generate_patient_facts(patient_id=1) # Ensure that this patient has a long enough history that we get a sensible diff --git a/tests/unit/dummy_data_nextgen/test_measures.py b/tests/unit/dummy_data_nextgen/test_measures.py index 81a9c2e18..a8248362f 100644 --- a/tests/unit/dummy_data_nextgen/test_measures.py +++ b/tests/unit/dummy_data_nextgen/test_measures.py @@ -10,22 +10,28 @@ class patients(PatientFrame): sex = Series( str, - constraints=[Constraint.Categorical(["male", "female"])], + constraints=[Constraint.Categorical(["male", "female"]), Constraint.NotNull()], ) region = Series( str, constraints=[ - Constraint.Categorical(["London", "The North", "The Countryside"]) + Constraint.Categorical( + ["London", "The North", "The Countryside"], + ), + Constraint.NotNull(), ], ) @table class events(EventFrame): - date = Series(date) + date = Series(date, constraints=[Constraint.NotNull()]) code = Series( str, - constraints=[Constraint.Categorical(["abc", "def", "foo"])], + constraints=[ + Constraint.Categorical(["abc", "def", "foo"]), + Constraint.NotNull(), + ], ) @@ -37,6 +43,7 @@ def test_dummy_measures_data_generator(): intervals = years(2).starting_on("2020-01-01") measures = Measures() + measures.dummy_data_config.population_size = 200 measures.define_measure( "foo_events_by_sex", @@ -56,6 +63,7 @@ def test_dummy_measures_data_generator(): generator = DummyMeasuresDataGenerator( measures, measures.dummy_data_config, today=date(2024, 1, 1) ) + results = list(generator.get_results()) # Check we generated the right number of rows: 2 rows for each breakdown by sex, 3 diff --git a/tests/unit/dummy_data_nextgen/test_query_info.py b/tests/unit/dummy_data_nextgen/test_query_info.py index 982041eab..f7e0c6d99 100644 --- a/tests/unit/dummy_data_nextgen/test_query_info.py +++ b/tests/unit/dummy_data_nextgen/test_query_info.py @@ -103,11 +103,7 @@ class test_table(PatientFrame): query_info = QueryInfo.from_variable_definitions(variable_definitions) column_info = query_info.tables["test_table"].columns["value"] - assert column_info == ColumnInfo( - name="value", - type=str, - _values_used={"a", "b", "c", "d"}, - ) + assert column_info._values_used == {"a", "b", "c", "d"} def test_query_info_ignores_inline_patient_tables(): diff --git a/tests/unit/dummy_data_nextgen/test_query_to_constraints.py b/tests/unit/dummy_data_nextgen/test_query_to_constraints.py deleted file mode 100644 index e69e2a84c..000000000 --- a/tests/unit/dummy_data_nextgen/test_query_to_constraints.py +++ /dev/null @@ -1,89 +0,0 @@ -from datetime import date - -from ehrql import create_dataset, years -from ehrql.dummy_data_nextgen.query_info import ( - normalize_constraints, - query_to_column_constraints, -) -from ehrql.query_language import compile -from ehrql.tables import Constraint -from ehrql.tables.core import patients - - -def test_or_query_includes_constraints_on_each_side(): - dataset = create_dataset() - - dataset.define_population( - ((patients.date_of_birth.year == 1970) & (patients.sex == "male")) - | ((patients.date_of_birth.year == 1963) & (patients.sex == "male")) - ) - - variable_definitions = compile(dataset) - constraints = query_to_column_constraints(variable_definitions["population"]) - - assert len(constraints) == 1 - - (column,) = constraints.keys() - assert column.name == "sex" - (column_constraints,) = constraints.values() - assert column_constraints == [Constraint.Categorical(values=("male",))] - - -def test_combine_date_range_constraints(): - dataset = create_dataset() - index_date = "2023-10-01" - - dataset = create_dataset() - - was_adult = (patients.age_on(index_date) >= 18) & ( - patients.age_on(index_date) <= 100 - ) - - was_born_in_particular_range = (patients.date_of_birth < date(2000, 1, 1)) & ( - patients.date_of_birth > date(1970, 1, 1) - ) - - dataset.define_population(was_adult & was_born_in_particular_range) - - variable_definitions = compile(dataset) - constraints = query_to_column_constraints(variable_definitions["population"]) - - assert len(constraints) == 1 - - (column,) = constraints.keys() - assert column.name == "date_of_birth" - (column_constraints,) = constraints.values() - assert normalize_constraints(column_constraints) == ( - Constraint.GeneralRange( - minimum=date(1970, 1, 1), - maximum=date(2000, 1, 1), - includes_minimum=False, - includes_maximum=False, - ), - ) - - -def test_or_query_does_not_includes_constraints_on_only_one_size(): - dataset = create_dataset() - - dataset.define_population( - (patients.date_of_birth.year == 1970) | (patients.sex == "male") - ) - - variable_definitions = compile(dataset) - constraints = query_to_column_constraints(variable_definitions["population"]) - - assert len(constraints) == 0 - - -def test_gt_query_with_date_addition(): - dataset = create_dataset() - - index_date = date(2022, 3, 1) - died_more_than_10_years_ago = (patients.date_of_death + years(10)) < index_date - dataset.define_population(died_more_than_10_years_ago) - - variable_definitions = compile(dataset) - constraints = query_to_column_constraints(variable_definitions["population"]) - - assert len(constraints) == 1 diff --git a/tests/unit/dummy_data_nextgen/test_specific_datasets.py b/tests/unit/dummy_data_nextgen/test_specific_datasets.py index aee1835bd..d5c5d7b01 100644 --- a/tests/unit/dummy_data_nextgen/test_specific_datasets.py +++ b/tests/unit/dummy_data_nextgen/test_specific_datasets.py @@ -5,99 +5,66 @@ import pytest from hypothesis import example, given, settings from hypothesis import strategies as st +from hypothesis.vendor.pretty import pretty -from ehrql import create_dataset, years +from ehrql import case, create_dataset, maximum_of, when, years from ehrql.dummy_data_nextgen.generator import DummyDataGenerator -from ehrql.query_language import compile -from ehrql.tables.core import patients - - -@pytest.mark.parametrize("sex", ["male", "female", "intersex"]) -@mock.patch("ehrql.dummy_data_nextgen.generator.time") -def test_can_generate_single_sex_data_in_one_shot(patched_time, sex): - dataset = create_dataset() - - dataset.define_population(patients.sex == sex) - - target_size = 1000 - - variable_definitions = compile(dataset) - generator = DummyDataGenerator(variable_definitions, population_size=target_size) - generator.batch_size = target_size - generator.timeout = 10 - - # Configure `time.time()` so we timeout after one loop pass, as we - # should be able to generate these correctly in the first pass. - patched_time.time.side_effect = [0.0, 20.0] - data = generator.get_data() - - # Expecting a single table - assert len(data) == 1 - data_for_table = list(data.values())[0] - # Within that table expecting we generated a full population - assert len(data_for_table) == target_size - - -@mock.patch("ehrql.dummy_data_nextgen.generator.time") -def test_can_generate_patients_from_a_specific_year(patched_time): - dataset = create_dataset() - - dataset.define_population(patients.date_of_birth.year == 1950) - - target_size = 1000 - - variable_definitions = compile(dataset) - generator = DummyDataGenerator(variable_definitions, population_size=target_size) - generator.batch_size = target_size - generator.timeout = 10 - - # Configure `time.time()` so we timeout after one loop pass, as we - # should be able to generate these correctly in the first pass. - patched_time.time.side_effect = [0.0, 20.0] - data = generator.get_data() +from ehrql.exceptions import CannotGenerate +from ehrql.query_language import ( + EventFrame, + PatientFrame, + Series, + compile, + table, + table_from_rows, +) +from ehrql.tables.core import clinical_events, medications, patients - # Expecting a single table - assert len(data) == 1 - data_for_table = list(data.values())[0] - # Within that table expecting we generated a full population - assert len(data_for_table) == target_size +index_date = date(2022, 3, 1) -@mock.patch("ehrql.dummy_data_nextgen.generator.time") -def test_can_combine_constraints_on_generated_data(patched_time): - dataset = create_dataset() +is_female_or_male = patients.sex.is_in(["female", "male"]) - dataset.define_population( - (patients.date_of_birth.year == 1970) & (patients.sex == "intersex") - ) +was_adult = (patients.age_on(index_date) >= 18) & (patients.age_on(index_date) <= 110) - target_size = 1000 +was_alive = ( + patients.date_of_death.is_after(index_date) | patients.date_of_death.is_null() +) - variable_definitions = compile(dataset) - generator = DummyDataGenerator(variable_definitions, population_size=target_size) - generator.batch_size = target_size - generator.timeout = 10 +died_more_than_10_years_ago = (patients.date_of_death + years(10)) < index_date - # Configure `time.time()` so we timeout after one loop pass, as we - # should be able to generate these correctly in the first pass. - patched_time.time.side_effect = [0.0, 20.0] - data = generator.get_data() - # Expecting a single table - assert len(data) == 1 - data_for_table = list(data.values())[0] - # Within that table expecting we generated a full population - assert len(data_for_table) == target_size +@table +class extra_patients(PatientFrame): + some_integer = Series(int) @mock.patch("ehrql.dummy_data_nextgen.generator.time") -def test_will_satisfy_constraints_on_both_sides_of_an_or(patched_time): +@pytest.mark.parametrize( + "query", + [ + patients.sex == "male", + patients.date_of_birth.year == 1950, + (patients.date_of_birth.year == 1970) & (patients.sex == "intersex"), + ((patients.date_of_birth.year == 1970) & (patients.sex == "male")) + | ((patients.date_of_birth.year == 1963) & (patients.sex == "male")), + is_female_or_male & was_adult & was_alive, + died_more_than_10_years_ago, + patients.date_of_death.is_null(), + ~patients.date_of_death.is_null(), + case( + when(patients.sex == "male").then(1), + when(patients.sex == "female").then(2), + ) + >= 2, + maximum_of(patients.date_of_birth, patients.date_of_birth) <= index_date, + ], + ids=pretty, +) +def test_queries_with_exact_one_shot_generation(patched_time, query): dataset = create_dataset() - dataset.define_population( - ((patients.date_of_birth.year == 1970) & (patients.sex == "male")) - | ((patients.date_of_birth.year == 1963) & (patients.sex == "male")) - ) + dataset.define_population(patients.exists_for_patient() & query) target_size = 1000 @@ -109,53 +76,9 @@ def test_will_satisfy_constraints_on_both_sides_of_an_or(patched_time): # Configure `time.time()` so we timeout after one loop pass, as we # should be able to generate these correctly in the first pass. patched_time.time.side_effect = [0.0, 20.0] - data = generator.get_data() - - # Expecting a single table - assert len(data) == 1 - data_for_table = list(data.values())[0] - # Within that table expecting we generated a full population - assert len(data_for_table) > 0 - - -@mock.patch("ehrql.dummy_data_nextgen.generator.time") -def test_basic_patient_constraints_age_and_sex(patched_time): - index_date = "2023-10-01" - - dataset = create_dataset() + patient_ids = {row.patient_id for row in generator.get_results()} - is_female_or_male = patients.sex.is_in(["female", "male"]) - - was_adult = (patients.age_on(index_date) >= 18) & ( - patients.age_on(index_date) <= 110 - ) - - was_alive = ( - patients.date_of_death.is_after(index_date) | patients.date_of_death.is_null() - ) - - dataset.define_population(is_female_or_male & was_adult & was_alive) - - target_size = 1000 - - dataset.age = patients.age_on(index_date) - dataset.sex = patients.sex - - variable_definitions = compile(dataset) - generator = DummyDataGenerator(variable_definitions, population_size=target_size) - generator.batch_size = target_size - generator.timeout = 10 - - # Configure `time.time()` so we timeout after three loop passes. - # We cannot currently generate the exact data reliably because the - # constraint on date of death doesn't straightforwardly map to a - # single logical constraint. - patched_time.time.side_effect = [0.0, 1.0, 2.0, 20.0] - data = generator.get_data() - - (data_for_table,) = (v for k, v in data.items() if k.name == "patients") - # Within that table expecting we generated a full population - assert len(data_for_table) == target_size + assert len(patient_ids) == target_size @st.composite @@ -232,33 +155,163 @@ def test_combined_age_range_in_one_shot(patched_time, query, target_size): assert len(data_for_table) == target_size -@pytest.mark.xfail(reason="FIXME: This test is very slightly flaky at the moment.") +@table_from_rows([(i, i, False) for i in range(1, 1000)]) +class p(PatientFrame): + i = Series(int) + b = Series(bool) + + +@table +class p0(PatientFrame): + i1 = Series(int) + b1 = Series(bool) + d1 = Series(date) + + +@table +class e0(EventFrame): + b0 = Series(bool) + i1 = Series(int) + + @mock.patch("ehrql.dummy_data_nextgen.generator.time") -def test_date_arithmetic_comparison(patched_time): +@pytest.mark.parametrize( + "query", + [ + date(2010, 1, 1) < medications.date.minimum_for_patient(), + medications.sort_by(medications.date).first_for_patient().date + < date(2020, 1, 1), + clinical_events.where( + clinical_events.snomedct_code == "123456789" + ).exists_for_patient(), + maximum_of(patients.date_of_birth, patients.date_of_death) + <= index_date - years(10), + case( + when(patients.sex == "male").then(1), + when(patients.date_of_birth == index_date).then(2), + otherwise=3, + ) + >= 2, + case( + when(patients.date_of_birth <= index_date).then(patients.date_of_death), + otherwise=patients.date_of_birth, + ) + <= index_date, + p0.i1 + < case( + when(e0.sort_by(e0.i1).first_for_patient().b0).then( + e0.sort_by(e0.i1).first_for_patient().i1 + ), + otherwise=None, + ), + patients.date_of_birth < case(when(p0.b1).then(None), otherwise=p0.d1), + case(when(p0.b1).then(None), otherwise=p0.d1) + < patients.date_of_birth + years(1), + case(when(p0.b1).then(date(2010, 1, 2)), otherwise=date(2010, 1, 1)) + > date(2010, 1, 1), + ], + ids=pretty, +) +def test_queries_not_yet_well_handled(patched_time, query): + """Tests queries that we need to work, but do not currently + expect to be handled particularly well. + """ dataset = create_dataset() - index_date = date(2022, 3, 1) - died_more_than_10_years_ago = (patients.date_of_death + years(10)) < index_date - dataset.define_population(died_more_than_10_years_ago) - dataset.date_of_birth = patients.date_of_birth - dataset.date_of_death = patients.date_of_death + dataset.define_population(patients.exists_for_patient() & query) target_size = 1000 + variable_definitions = compile(dataset) + + generator = DummyDataGenerator(variable_definitions, population_size=target_size) + generator.batch_size = target_size + generator.timeout = 10 + + patched_time.time.side_effect = [0.0, 20.0] + + patient_ids = {row.patient_id for row in generator.get_results()} + + # Should be able to generate at least one patient satisfying this + assert len(patient_ids) > 0 + + # If one of these queries manages to generate fully in a single pass then + # it deserves a more specific test. This is effectively a sort of xfail + # assertion. + assert len(patient_ids) < target_size + + +@mock.patch("ehrql.dummy_data_nextgen.generator.time") +def test_inline_table_query(patched_time): + """Tests queries that we need to work, but do not currently + expect to be handled particularly well. + """ + dataset = create_dataset() + dataset.define_population(p.i < 1000) + dataset.i = p.i + + target_size = 1000 variable_definitions = compile(dataset) + generator = DummyDataGenerator(variable_definitions, population_size=target_size) generator.batch_size = target_size generator.timeout = 10 - # Configure `time.time()` so we timeout after one loop pass, as we - # should be able to generate these correctly in the first pass. patched_time.time.side_effect = [0.0, 20.0] - data = generator.get_data() - # Expecting a single table - assert len(data) == 1 - data_for_table = list(data.values())[0] - # Confirm that all patients have date of birth before date of death - assert all(row[1] <= row[2] for row in data_for_table) - # Within that table expecting we generated a full population - assert len(data_for_table) == target_size + patient_ids = {row.patient_id for row in generator.get_results()} + + # Should be able to generate at least one patient satisfying this + assert len(patient_ids) > 0 + + # If one of these queries manages to generate fully in a single pass then + # it deserves a more specific test. This is effectively a sort of xfail + # assertion. + assert len(patient_ids) < target_size + + +@pytest.mark.parametrize( + "query", + [ + patients.sex == "book", + extra_patients.some_integer + 1 < extra_patients.some_integer, + ], + ids=pretty, +) +@mock.patch("ehrql.dummy_data_nextgen.generator.time") +def test_will_raise_if_all_data_is_impossible(patched_time, query): + dataset = create_dataset() + + dataset.define_population(query) + target_size = 1000 + variable_definitions = compile(dataset) + + generator = DummyDataGenerator(variable_definitions, population_size=target_size) + generator.timeout = 1 + patched_time.time.side_effect = [0.0, 20.0] + with pytest.raises(CannotGenerate): + generator.get_results() + + +def test_generates_events_starting_from_birthdate(): + dataset = create_dataset() + + age = patients.age_on("2020-03-31") + + dataset.age = age + dataset.sex = patients.sex + + events = clinical_events.sort_by(clinical_events.date).first_for_patient() + dataset.dob = patients.date_of_birth + dataset.event_date = events.date + dataset.after_dob = events.date >= patients.date_of_birth + dataset.define_population((age > 18) & (age < 80) & ~dataset.event_date.is_null()) + + target_size = 1000 + dataset.configure_dummy_data(population_size=target_size) + variable_definitions = compile(dataset) + + generator = DummyDataGenerator(variable_definitions, population_size=target_size) + + for row in generator.get_results(): + assert row.after_dob