Skip to content

Commit

Permalink
Merge pull request #2321 from opensafely-core/evansd/dataset-qm
Browse files Browse the repository at this point in the history
Add a proper query model type for datasets
  • Loading branch information
evansd authored Dec 17, 2024
2 parents 055d510 + 62c76d8 commit 452b758
Show file tree
Hide file tree
Showing 37 changed files with 371 additions and 265 deletions.
7 changes: 3 additions & 4 deletions ehrql/assurance.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
UNEXPECTED_OUTPUT_VALUE = "unexpected-output-value"


def validate(variable_definitions, test_data):
def validate(dataset, test_data):
"""Validates that the given test data
(1) meet the constraints in the tables and
(2) produce the given expected output.
Expand All @@ -23,7 +23,7 @@ def validate(variable_definitions, test_data):
"""

# Create objects to insert into database
table_nodes = get_table_nodes(*variable_definitions.values())
table_nodes = get_table_nodes(dataset)

constraint_validation_errors = {}
input_data = {table: [] for table in table_nodes}
Expand All @@ -50,8 +50,7 @@ def validate(variable_definitions, test_data):
# Query the database
engine = InMemoryQueryEngine(database)
query_results = {
row.patient_id: row._asdict()
for row in engine.get_results(variable_definitions)
row.patient_id: row._asdict() for row in engine.get_results(dataset)
}

# Validate results of query
Expand Down
7 changes: 4 additions & 3 deletions ehrql/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import sqlalchemy

from ehrql.query_language import get_tables_from_namespace
from ehrql.query_model import nodes as qm


class ValidationError(Exception): ...
Expand All @@ -22,11 +23,11 @@ def modify_dsn(self, dsn: str | None) -> str | None:
"""
return dsn

def modify_query_variables(self, variables: dict) -> dict:
def modify_dataset(self, dataset: qm.Dataset) -> qm.Dataset:
"""
This hook gives backends the option to modify queries before they are run
This hook gives backends the option to modify the dataset before running it
"""
return variables
return dataset

def modify_inline_table_args(self, columns, rows):
"""
Expand Down
14 changes: 8 additions & 6 deletions ehrql/backends/tpp.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,11 +94,11 @@ def modify_dsn(self, dsn):
new_parts = parts._replace(query=new_query)
return parse.urlunparse(new_parts)

def modify_query_variables(self, variables):
def modify_dataset(self, dataset):
# If this query has been explicitly flagged as including T1OO patients then
# return it unmodified
if self.include_t1oo:
return variables
return dataset

# Otherwise we add an extra condition to the population definition which is that
# the patient does not appear in the T1OO table.
Expand All @@ -111,9 +111,8 @@ def modify_query_variables(self, variables):
# From ehrQL's point of view, the construction of the T1OO table is opaque. For
# discussion of the approach currently used to populate this see:
# https://docs.google.com/document/d/1nBAwDucDCeoNeC5IF58lHk6LT-RJg6YZRp5RRkI7HI8/
variables = dict(variables)
variables["population"] = qm.Function.And(
variables["population"],
new_population = qm.Function.And(
dataset.population,
qm.Function.Not(
qm.AggregateByPatient.Exists(
# We don't currently expose this table in the user-facing schema. If
Expand All @@ -126,7 +125,10 @@ def modify_query_variables(self, variables):
)
),
)
return variables
return qm.Dataset(
population=new_population,
variables=dataset.variables,
)

def get_exit_status_for_exception(self, exception):
# Checking for "DatabaseError" in the MRO means we can identify database errors without
Expand Down
17 changes: 9 additions & 8 deletions ehrql/dummy_data/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from ehrql.query_engines.in_memory import InMemoryQueryEngine
from ehrql.query_engines.in_memory_database import InMemoryDatabase
from ehrql.query_model.introspection import all_inline_patient_ids
from ehrql.query_model.nodes import Dataset
from ehrql.tables import Constraint
from ehrql.utils.regex_utils import create_regex_generator

Expand All @@ -26,14 +27,14 @@
class DummyDataGenerator:
def __init__(
self,
variable_definitions,
dataset,
population_size=10,
batch_size=5000,
random_seed="BwRV3spP",
timeout=60,
today=None,
):
self.variable_definitions = variable_definitions
self.dataset = dataset
self.population_size = population_size
self.batch_size = batch_size
self.random_seed = random_seed
Expand All @@ -43,7 +44,7 @@ def __init__(
# suitable time range by inspecting the query, this will have to do.
self.today = today if today is not None else date.today()
self.patient_generator = DummyPatientGenerator(
self.variable_definitions, self.random_seed, self.today
self.dataset, self.random_seed, self.today
)
log.info("Using legacy dummy data generation")

Expand All @@ -55,7 +56,7 @@ def get_data(self):

# Create a version of the query with just the population definition, and an
# in-memory engine to run it against
population_query = {"population": self.variable_definitions["population"]}
population_query = Dataset(population=self.dataset.population, variables={})
database = InMemoryDatabase()
engine = InMemoryQueryEngine(database)

Expand Down Expand Up @@ -124,7 +125,7 @@ def get_patient_id_batches(self):
def get_patient_id_stream(self):
# Where a query involves inline tables we want to extract all the patient IDs
# and include them in the IDs for which we're going to generate dummy data
inline_patient_ids = all_inline_patient_ids(*self.variable_definitions.values())
inline_patient_ids = all_inline_patient_ids(self.dataset)
yield from sorted(inline_patient_ids)
for i in range(1, 2**63): # pragma: no branch
if i not in inline_patient_ids:
Expand All @@ -133,15 +134,15 @@ def get_patient_id_stream(self):
def get_results(self):
database = InMemoryDatabase(self.get_data())
engine = InMemoryQueryEngine(database)
return engine.get_results(self.variable_definitions)
return engine.get_results(self.dataset)


class DummyPatientGenerator:
def __init__(self, variable_definitions, random_seed, today):
def __init__(self, dataset, random_seed, today):
self.rnd = random.Random()
self.random_seed = random_seed
self.today = today
self.query_info = QueryInfo.from_variable_definitions(variable_definitions)
self.query_info = QueryInfo.from_dataset(dataset)

def get_patient_data_for_population_condition(self, patient_id):
# Generate data for just those tables needed for determining whether the patient
Expand Down
21 changes: 10 additions & 11 deletions ehrql/dummy_data/measures.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@
)
from ehrql.query_engines.in_memory import InMemoryQueryEngine
from ehrql.query_engines.in_memory_database import InMemoryDatabase
from ehrql.query_model.nodes import Function
from ehrql.query_model.nodes import Dataset, Function


class DummyMeasuresDataGenerator:
def __init__(self, measures, dummy_data_config, **kwargs):
self.measures = measures
combined = CombinedMeasureComponents.from_measures(measures)
self.generator = DummyDataGenerator(
get_dataset_variables(combined),
get_dataset(combined),
population_size=get_population_size(dummy_data_config, combined),
timeout=dummy_data_config.timeout,
**kwargs,
Expand Down Expand Up @@ -55,27 +55,26 @@ def from_measures(cls, measures):
)


def get_dataset_variables(combined):
def get_dataset(combined):
"""
Return a dict of dataset definition variables suitable for passing to the dummy data
generator which should produce dummy data of the right shape to use for calculating
measures
Return a query model dataset suitable for passing to the dummy data generator which
should produce dummy data of the right shape to use for calculating measures
"""
variable_placeholders = {
dataset_placeholders = Dataset(
# Use the union of all denominators as the population
"population": reduce(Function.Or, combined.denominators),
**{
population=reduce(Function.Or, combined.denominators),
variables={
f"column_{i}": column
for i, column in enumerate([*combined.numerators, *combined.groups])
},
}
)

# Use the maximum range over all intervals as a date range
min_interval_start = min(interval[0] for interval in combined.intervals)
max_interval_end = max(interval[1] for interval in combined.intervals)

return substitute_interval_parameters(
variable_placeholders, (min_interval_start, max_interval_end)
dataset_placeholders, (min_interval_start, max_interval_end)
)


Expand Down
8 changes: 5 additions & 3 deletions ehrql/dummy_data/query_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from ehrql.query_model.introspection import all_unique_nodes, get_table_nodes
from ehrql.query_model.nodes import (
Column,
Dataset,
Function,
InlinePatientTable,
SelectColumn,
Expand Down Expand Up @@ -97,8 +98,9 @@ class QueryInfo:
other_table_names: list[str]

@classmethod
def from_variable_definitions(cls, variable_definitions):
all_nodes = all_unique_nodes(*variable_definitions.values())
def from_dataset(cls, dataset):
assert isinstance(dataset, Dataset)
all_nodes = all_unique_nodes(dataset)
by_type = get_nodes_by_type(all_nodes)

tables = {
Expand Down Expand Up @@ -155,7 +157,7 @@ def from_variable_definitions(cls, variable_definitions):
# Record which tables are used in determining population membership and which
# are not
population_table_names = {
node.name for node in get_table_nodes(variable_definitions["population"])
node.name for node in get_table_nodes(dataset.population)
}

other_table_names = tables.keys() - population_table_names
Expand Down
31 changes: 17 additions & 14 deletions ehrql/dummy_data_nextgen/generator.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import dataclasses
import functools
import itertools
import logging
Expand All @@ -16,7 +17,7 @@
from ehrql.query_engines.in_memory_database import InMemoryDatabase
from ehrql.query_language import DummyDataConfig
from ehrql.query_model.introspection import all_inline_patient_ids
from ehrql.query_model.nodes import Function
from ehrql.query_model.nodes import Dataset, Function
from ehrql.tables import Constraint
from ehrql.utils.regex_utils import create_regex_generator

Expand Down Expand Up @@ -73,14 +74,13 @@ def get_possible_values(self, column_info):
class DummyDataGenerator:
@classmethod
def from_dataset(cls, dataset, **kwargs):
variable_definitions = dataset._compile()
return cls(
variable_definitions, configuration=dataset.dummy_data_config, **kwargs
dataset._compile(), configuration=dataset.dummy_data_config, **kwargs
)

def __init__(
self,
variable_definitions,
dataset,
configuration=None,
batch_size=5000,
random_seed="BwRV3spP",
Expand All @@ -95,12 +95,15 @@ def __init__(
)
assert not configuration.legacy
self.configuration = configuration
self.variable_definitions = variable_definitions
if self.configuration.additional_population_constraint is not None:
variable_definitions["population"] = Function.And(
lhs=variable_definitions["population"],
rhs=self.configuration.additional_population_constraint,
dataset = dataclasses.replace(
dataset,
population=Function.And(
lhs=dataset.population,
rhs=self.configuration.additional_population_constraint,
),
)
self.dataset = dataset
self.population_size = configuration.population_size
self.batch_size = batch_size
self.random_seed = random_seed
Expand All @@ -110,7 +113,7 @@ def __init__(
# suitable time range by inspecting the query, this will have to do.
self.today = today if today is not None else date.today()
self.patient_generator = DummyPatientGenerator(
self.variable_definitions,
self.dataset,
self.random_seed,
self.today,
self.population_size,
Expand All @@ -125,7 +128,7 @@ def get_data(self):

# Create a version of the query with just the population definition, and an
# in-memory engine to run it against
population_query = {"population": self.variable_definitions["population"]}
population_query = Dataset(population=self.dataset.population, variables={})
database = InMemoryDatabase()
engine = InMemoryQueryEngine(database)

Expand Down Expand Up @@ -236,7 +239,7 @@ def get_patient_id_batches(self):
def get_patient_id_stream(self):
# Where a query involves inline tables we want to extract all the patient IDs
# and include them in the IDs for which we're going to generate dummy data
inline_patient_ids = all_inline_patient_ids(*self.variable_definitions.values())
inline_patient_ids = all_inline_patient_ids(self.dataset)
yield from sorted(inline_patient_ids)
for i in range(1, 2**63): # pragma: no branch
if i not in inline_patient_ids:
Expand All @@ -245,15 +248,15 @@ def get_patient_id_stream(self):
def get_results(self):
database = InMemoryDatabase(self.get_data())
engine = InMemoryQueryEngine(database)
return engine.get_results(self.variable_definitions)
return engine.get_results(self.dataset)


class DummyPatientGenerator:
def __init__(self, variable_definitions, random_seed, today, population_size):
def __init__(self, dataset, random_seed, today, population_size):
self.__rnd = None
self.random_seed = random_seed
self.today = today
self.query_info = QueryInfo.from_variable_definitions(variable_definitions)
self.query_info = QueryInfo.from_dataset(dataset)
self.population_size = population_size

self.__active_population_subsets = []
Expand Down
21 changes: 10 additions & 11 deletions ehrql/dummy_data_nextgen/measures.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,15 @@
)
from ehrql.query_engines.in_memory import InMemoryQueryEngine
from ehrql.query_engines.in_memory_database import InMemoryDatabase
from ehrql.query_model.nodes import Function
from ehrql.query_model.nodes import Dataset, Function


class DummyMeasuresDataGenerator:
def __init__(self, measures, dummy_data_config, **kwargs):
self.measures = measures
combined = CombinedMeasureComponents.from_measures(measures)
self.generator = DummyDataGenerator(
get_dataset_variables(combined),
get_dataset(combined),
configuration=replace(
dummy_data_config,
population_size=get_population_size(dummy_data_config, combined),
Expand Down Expand Up @@ -59,27 +59,26 @@ def from_measures(cls, measures):
)


def get_dataset_variables(combined):
def get_dataset(combined):
"""
Return a dict of dataset definition variables suitable for passing to the dummy data
generator which should produce dummy data of the right shape to use for calculating
measures
Return a query model dataset suitable for passing to the dummy data generator which
should produce dummy data of the right shape to use for calculating measures
"""
variable_placeholders = {
dataset_placeholders = Dataset(
# Use the union of all denominators as the population
"population": reduce(Function.Or, combined.denominators),
**{
population=reduce(Function.Or, combined.denominators),
variables={
f"column_{i}": column
for i, column in enumerate([*combined.numerators, *combined.groups])
},
}
)

# Use the maximum range over all intervals as a date range
min_interval_start = min(interval[0] for interval in combined.intervals)
max_interval_end = max(interval[1] for interval in combined.intervals)

return substitute_interval_parameters(
variable_placeholders, (min_interval_start, max_interval_end)
dataset_placeholders, (min_interval_start, max_interval_end)
)


Expand Down
Loading

0 comments on commit 452b758

Please sign in to comment.