From 22440ff4121b3de720c8a8db566ca757d930fb2a Mon Sep 17 00:00:00 2001
From: David Evans
Date: Wed, 15 Jan 2025 11:56:06 +0000
Subject: [PATCH 01/14] Support multiple queries in
 `get_setup_and_cleanup_queries`

---
 ehrql/main.py                                 |  2 +-
 ehrql/query_engines/base_sql.py               |  2 +-
 ehrql/query_engines/mssql.py                  |  2 +-
 ehrql/utils/sqlalchemy_query_utils.py         | 28 ++++++-----
 tests/integration/backends/test_emis.py       |  4 +-
 .../unit/utils/test_sqlalchemy_query_utils.py | 47 ++++++++++++++---
 6 files changed, 61 insertions(+), 24 deletions(-)

diff --git a/ehrql/main.py b/ehrql/main.py
index 122ef3d83..1ee96c8bc 100644
--- a/ehrql/main.py
+++ b/ehrql/main.py
@@ -176,7 +176,7 @@ def dump_dataset_sql(
 
 def get_sql_strings(query_engine, dataset):
     results_query = query_engine.get_query(dataset)
-    setup_queries, cleanup_queries = get_setup_and_cleanup_queries(results_query)
+    setup_queries, cleanup_queries = get_setup_and_cleanup_queries([results_query])
     dialect = query_engine.sqlalchemy_dialect()
 
     sql_strings = []
diff --git a/ehrql/query_engines/base_sql.py b/ehrql/query_engines/base_sql.py
index 8d567ec2c..c698741e6 100644
--- a/ehrql/query_engines/base_sql.py
+++ b/ehrql/query_engines/base_sql.py
@@ -837,7 +837,7 @@ def get_select_query_for_node_domain(self, node):
 
     def get_results(self, dataset):
         results_query = self.get_query(dataset)
-        setup_queries, cleanup_queries = get_setup_and_cleanup_queries(results_query)
+        setup_queries, cleanup_queries = get_setup_and_cleanup_queries([results_query])
         with self.engine.connect() as connection:
             for i, setup_query in enumerate(setup_queries, start=1):
                 log.info(f"Running setup query {i:03} / {len(setup_queries):03}")
diff --git a/ehrql/query_engines/mssql.py b/ehrql/query_engines/mssql.py
index 1b9c803b7..dce83355e 100644
--- a/ehrql/query_engines/mssql.py
+++ b/ehrql/query_engines/mssql.py
@@ -170,7 +170,7 @@ def get_results(self, dataset):
         results_table = results_query.get_final_froms()[0]
         assert str(results_query) == str(sqlalchemy.select(results_table))
 
-        setup_queries, cleanup_queries = get_setup_and_cleanup_queries(results_query)
+        setup_queries, cleanup_queries = get_setup_and_cleanup_queries([results_query])
 
         with self.engine.connect() as connection:
             # All our queries are either (a) read-only queries against static data, or
diff --git a/ehrql/utils/sqlalchemy_query_utils.py b/ehrql/utils/sqlalchemy_query_utils.py
index 29ef72704..4ab99a294 100644
--- a/ehrql/utils/sqlalchemy_query_utils.py
+++ b/ehrql/utils/sqlalchemy_query_utils.py
@@ -72,9 +72,10 @@ def from_query(cls, name, query, metadata=None, **kwargs):
         return cls(name, metadata, *columns, **kwargs)
 
 
-def get_setup_and_cleanup_queries(query):
+def get_setup_and_cleanup_queries(queries):
     """
-    Given a SQLAlchemy query find all GeneratedTables embeded in it and return a pair:
+    Given a list of SQLAlchemy queries find all GeneratedTables embedded in them and
+    return a pair:
 
         setup_queries, cleanup_queries
 
@@ -90,7 +91,7 @@
     # give it a sequence of pairs of tables (A, B) indicating that A depends on B and it
     # returns a suitable ordering over the tables.
     sorter = graphlib.TopologicalSorter()
-    for parent_table, table in get_generated_table_dependencies(query):
+    for parent_table, table in get_generated_table_dependencies(queries):
         # A parent_table of None indicates a root table (i.e. one with no dependants) so
         # we record its existence without specifying any dependencies
         if parent_table is None:
@@ -118,10 +119,10 @@
     return setup_queries, cleanup_queries
 
 
-def get_generated_table_dependencies(query, parent_table=None, seen_tables=None):
+def get_generated_table_dependencies(queries, parent_table=None, seen_tables=None):
     """
-    Recursively find all GeneratedTable objects referenced by `query` and yield as pairs
-    of dependencies:
+    Recursively find all GeneratedTable objects referenced by any query in `queries` and
+    yield as pairs of dependencies:
 
         table_X, table_Y_which_is_referenced_by_X
 
@@ -130,14 +131,15 @@ def get_generated_table_dependencies(query, parent_table=None, seen_tables=None)
     if seen_tables is None:
         seen_tables = set()
 
-    for table in get_generated_tables(query):
-        yield parent_table, table
-        # Don't recurse into the same table twice
-        if table not in seen_tables:
-            seen_tables.add(table)
-            for child_query in [*table.setup_queries, *table.cleanup_queries]:
+    for query in queries:
+        for table in get_generated_tables(query):
+            yield parent_table, table
+            # Don't recurse into the same table twice
+            if table not in seen_tables:
+                seen_tables.add(table)
+                child_queries = [*table.setup_queries, *table.cleanup_queries]
                 yield from get_generated_table_dependencies(
-                    child_query, parent_table=table, seen_tables=seen_tables
+                    child_queries, parent_table=table, seen_tables=seen_tables
                 )
 
 
diff --git a/tests/integration/backends/test_emis.py b/tests/integration/backends/test_emis.py
index f55d2b1fe..9994e5c7a 100644
--- a/tests/integration/backends/test_emis.py
+++ b/tests/integration/backends/test_emis.py
@@ -591,7 +591,7 @@ class t(PatientFrame):
     ]
     assert len(set(inline_tables)) == 1
     inline_table = inline_tables[0]
-    setup_queries, cleanup_queries = get_setup_and_cleanup_queries(results_query)
+    setup_queries, cleanup_queries = get_setup_and_cleanup_queries([results_query])
 
     with query_engine.engine.connect() as connection:
         for setup_query in setup_queries:
@@ -646,7 +646,7 @@ def test_temp_table_includes_organisation_hash(trino_database):
     ]
     assert len(set(temp_tables)) == 1
     temp_table = temp_tables[0]
-    setup_queries, cleanup_queries = get_setup_and_cleanup_queries(results_query)
+    setup_queries, cleanup_queries = get_setup_and_cleanup_queries([results_query])
 
     with query_engine.engine.connect() as connection:
         for setup_query in setup_queries:
diff --git a/tests/unit/utils/test_sqlalchemy_query_utils.py b/tests/unit/utils/test_sqlalchemy_query_utils.py
index 5c53551c8..f6dfb606b 100644
--- a/tests/unit/utils/test_sqlalchemy_query_utils.py
+++ b/tests/unit/utils/test_sqlalchemy_query_utils.py
@@ -72,7 +72,7 @@ def test_get_setup_and_cleanup_queries_basic():
     query = sqlalchemy.select(temp_table.c.foo)
 
     # Check that we get the right queries in the right order
-    assert _queries_as_strs(query) == [
+    assert _queries_as_strs([query]) == [
         "CREATE TABLE temp_table (\n\tfoo NULL\n)",
         "INSERT INTO temp_table (foo) VALUES (:foo)",
         "SELECT temp_table.foo \nFROM temp_table",
@@ -100,7 +100,7 @@ def test_get_setup_and_cleanup_queries_nested():
     query = sqlalchemy.select(temp_table2.c.baz)
 
     # Check that we create and drop the temporary tables in the right order
-    assert _queries_as_strs(query) == [
+    assert _queries_as_strs([query]) == [
         "CREATE TABLE temp_table1 (\n\tfoo NULL\n)",
         "INSERT INTO temp_table1 (foo) VALUES (:foo)",
         "CREATE TABLE temp_table2 (\n\tbaz NULL\n)",
@@ -111,6 +111,41 @@ def test_get_setup_and_cleanup_queries_nested():
         "INSERT INTO temp_table2 (baz) SELECT temp_table1.foo \nFROM temp_table1",
         "SELECT temp_table2.baz \nFROM temp_table2",
         "DROP TABLE temp_table2",
         "DROP TABLE temp_table1",
     ]
 
 
+def test_get_setup_and_cleanup_queries_multiple():
+    # Make a temporary table
+    temp_table1 = _make_temp_table("temp_table1", "foo")
+    temp_table1.setup_queries.append(
+        temp_table1.insert().values(foo="bar"),
+    )
+
+    # Make a second temporary table ...
+    temp_table2 = _make_temp_table("temp_table2", "baz")
+    temp_table2.setup_queries.append(
+        # ... populated by a SELECT query against the first table
+        temp_table2.insert().from_select(
+            [temp_table2.c.baz], sqlalchemy.select(temp_table1.c.foo)
+        ),
+    )
+
+    # Select something from the second table
+    query_1 = sqlalchemy.select(temp_table2.c.baz)
+
+    # Select something from the first table
+    query_2 = sqlalchemy.select(temp_table1.c.foo)
+
+    # Check that we create and drop the temporary tables in the right order
+    assert _queries_as_strs([query_1, query_2]) == [
+        "CREATE TABLE temp_table1 (\n\tfoo NULL\n)",
+        "INSERT INTO temp_table1 (foo) VALUES (:foo)",
+        "CREATE TABLE temp_table2 (\n\tbaz NULL\n)",
+        "INSERT INTO temp_table2 (baz) SELECT temp_table1.foo \nFROM temp_table1",
+        "SELECT temp_table2.baz \nFROM temp_table2",
+        "SELECT temp_table1.foo \nFROM temp_table1",
+        "DROP TABLE temp_table2",
+        "DROP TABLE temp_table1",
+    ]
+
+
 def _make_temp_table(name, *columns):
     table = GeneratedTable(
         name, sqlalchemy.MetaData(), *[sqlalchemy.Column(c) for c in columns]
@@ -122,11 +157,11 @@ def _make_temp_table(name, *columns):
     return table
 
 
-def _queries_as_strs(query):
-    setup_queries, cleanup_queries = get_setup_and_cleanup_queries(query)
+def _queries_as_strs(queries):
+    setup_queries, cleanup_queries = get_setup_and_cleanup_queries(queries)
     return (
         [str(q).strip() for q in setup_queries]
-        + [str(query).strip()]
+        + [str(q).strip() for q in queries]
         + [str(q).strip() for q in cleanup_queries]
     )
 
@@ -193,7 +228,7 @@ def test_get_setup_and_cleanup_queries_with_insert_many():
         sqlalchemy.Column("i", sqlalchemy.Integer()),
     )
     statement = InsertMany(table, rows=[])
-    setup_cleanup = get_setup_and_cleanup_queries(statement)
+    setup_cleanup = get_setup_and_cleanup_queries([statement])
     assert setup_cleanup == ([], [])
 

From fb189da55dab73638ef41181c2ce79da390db705 Mon Sep 17 00:00:00 2001
From: David Evans
Date: Wed, 15 Jan 2025 11:58:12 +0000
Subject: [PATCH 02/14] Change `get_query()` to `get_queries()`

---
 ehrql/main.py                           | 11 +++++++----
 ehrql/query_engines/base_sql.py         | 16 ++++++++++-----
 ehrql/query_engines/mssql.py            | 26 ++++++++++++++-----------
 tests/integration/backends/test_emis.py | 14 +++++++------
 4 files changed, 41 insertions(+), 26 deletions(-)

diff --git a/ehrql/main.py b/ehrql/main.py
index 1ee96c8bc..83d13b66c 100644
--- a/ehrql/main.py
+++ b/ehrql/main.py
@@ -175,8 +175,8 @@ def dump_dataset_sql(
 
 def get_sql_strings(query_engine, dataset):
-    results_query = query_engine.get_query(dataset)
-    setup_queries, cleanup_queries = get_setup_and_cleanup_queries([results_query])
+    results_queries = query_engine.get_queries(dataset)
+    setup_queries, cleanup_queries = get_setup_and_cleanup_queries(results_queries)
     dialect = query_engine.sqlalchemy_dialect()
 
     sql_strings = []
@@ -184,8 +184,11 @@ def get_sql_strings(query_engine, dataset):
         sql = clause_as_str(query, dialect)
         sql_strings.append(f"-- Setup query {i:03} / {len(setup_queries):03}\n{sql}")
 
-    sql = clause_as_str(results_query, dialect)
-    sql_strings.append(f"-- Results query\n{sql}")
+    for i, query in enumerate(results_queries, start=1):
+        sql = clause_as_str(query, dialect)
+        sql_strings.append(
+            f"-- Results query {i:03} / {len(results_queries):03}\n{sql}"
+        )
 
     for i, query in enumerate(cleanup_queries, start=1):
         sql = clause_as_str(query, dialect)
diff --git a/ehrql/query_engines/base_sql.py b/ehrql/query_engines/base_sql.py
index c698741e6..ad88b1404 100644
--- a/ehrql/query_engines/base_sql.py
+++ b/ehrql/query_engines/base_sql.py
@@ -84,9 +84,9 @@ def get_next_id(self):
         self.counter += 1
         return self.counter
 
-    def get_query(self, dataset):
+    def get_queries(self, dataset):
         """
-        Return the SQL query to fetch the results for `dataset`
+        Return the SQL queries to fetch the results for `dataset`
 
         Note that this query might make use of intermediate tables. The SQL queries
         needed to create these tables and clean them up can be retrieved by calling
@@ -127,7 +127,9 @@
         self.get_sql.cache_clear()
         self.get_table.cache_clear()
 
-        return query
+        # At the moment we only support a single results table and so we'll only ever
+        # have a single query
+        return [query]
 
     def select_patient_id_for_population(self, population_expression):
         """
@@ -836,8 +838,12 @@ def get_select_query_for_node_domain(self, node):
         return query
 
     def get_results(self, dataset):
-        results_query = self.get_query(dataset)
-        setup_queries, cleanup_queries = get_setup_and_cleanup_queries([results_query])
+        results_queries = self.get_queries(dataset)
+        setup_queries, cleanup_queries = get_setup_and_cleanup_queries(results_queries)
+
+        assert len(results_queries) == 1
+        results_query = results_queries[0]
+
         with self.engine.connect() as connection:
             for i, setup_query in enumerate(setup_queries, start=1):
                 log.info(f"Running setup query {i:03} / {len(setup_queries):03}")
diff --git a/ehrql/query_engines/mssql.py b/ehrql/query_engines/mssql.py
index dce83355e..93bf6aa47 100644
--- a/ehrql/query_engines/mssql.py
+++ b/ehrql/query_engines/mssql.py
@@ -152,25 +152,29 @@ def create_inline_table(self, columns, rows):
         ]
         return table
 
-    def get_query(self, dataset):
-        results_query = super().get_query(dataset)
-        # Write results to a temporary table and select them from there. This allows us
+    def get_queries(self, dataset):
+        results_queries = super().get_queries(dataset)
+        # Write results to temporary tables and select them from there. This allows us
         # to use more efficient/robust mechanisms to retrieve the results.
-        results_table = temporary_table_from_query(
-            "#results", results_query, index_col="patient_id"
-        )
-        return sqlalchemy.select(results_table)
+        select_queries = []
+        for n, results_query in enumerate(results_queries, start=1):
+            results_table = temporary_table_from_query(
+                f"#results_{n}", results_query, index_col="patient_id"
+            )
+            select_queries.append(sqlalchemy.select(results_table))
+        return select_queries
 
     def get_results(self, dataset):
-        results_query = self.get_query(dataset)
+        results_queries = self.get_queries(dataset)
 
         # We're expecting a query in a very specific form which is "select everything
         # from one table"; so we assert that it has this form and retrieve a reference
         # to the table
-        results_table = results_query.get_final_froms()[0]
-        assert str(results_query) == str(sqlalchemy.select(results_table))
+        assert len(results_queries) == 1
+        results_table = results_queries[0].get_final_froms()[0]
+        assert str(results_queries[0]) == str(sqlalchemy.select(results_table))
 
-        setup_queries, cleanup_queries = get_setup_and_cleanup_queries([results_query])
+        setup_queries, cleanup_queries = get_setup_and_cleanup_queries(results_queries)
 
         with self.engine.connect() as connection:
             # All our queries are either (a) read-only queries against static data, or
diff --git a/tests/integration/backends/test_emis.py b/tests/integration/backends/test_emis.py
index 9994e5c7a..ba3944121 100644
--- a/tests/integration/backends/test_emis.py
+++ b/tests/integration/backends/test_emis.py
@@ -583,15 +583,16 @@ class t(PatientFrame):
 
     variables = dataset._compile()
 
-    results_query = query_engine.get_query(variables)
+    results_queries = query_engine.get_queries(variables)
+    assert len(results_queries) == 1
     inline_tables = [
         ch
-        for ch in results_query.get_children()
+        for ch in results_queries[0].get_children()
         if isinstance(ch, GeneratedTable) and "inline_data" in ch.name
     ]
     assert len(set(inline_tables)) == 1
     inline_table = inline_tables[0]
-    setup_queries, cleanup_queries = get_setup_and_cleanup_queries([results_query])
+    setup_queries, cleanup_queries = get_setup_and_cleanup_queries(results_queries)
 
     with query_engine.engine.connect() as connection:
         for setup_query in setup_queries:
@@ -638,15 +639,16 @@ def test_temp_table_includes_organisation_hash(trino_database):
     )
     variables = dataset._compile()
 
-    results_query = query_engine.get_query(variables)
+    results_queries = query_engine.get_queries(variables)
+    assert len(results_queries) == 1
     temp_tables = [
         ch
-        for ch in results_query.get_children()
+        for ch in results_queries[0].get_children()
        if isinstance(ch, GeneratedTable) and "tmp" in ch.name
     ]
     assert len(set(temp_tables)) == 1
     temp_table = temp_tables[0]
-    setup_queries, cleanup_queries = get_setup_and_cleanup_queries([results_query])
+    setup_queries, cleanup_queries = get_setup_and_cleanup_queries(results_queries)
 
     with query_engine.engine.connect() as connection:
         for setup_query in setup_queries:

From b6ab544f9ee1ac995a4003e57c1038f877f0bdbc Mon Sep 17 00:00:00 2001
From: David Evans
Date: Wed, 15 Jan 2025 17:09:05 +0000
Subject: [PATCH 03/14] Add `iter_groups` function

---
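Note: a minimal sketch of the intended usage (the sentinel and values below
are illustrative; any unique object can serve as the separator):

    from ehrql.utils.itertools_utils import iter_groups

    SEPARATOR = object()
    stream = [SEPARATOR, 1, 2, SEPARATOR, 3, 4]

    # Each group must be exhausted before advancing to the next one
    for group in iter_groups(stream, SEPARATOR):
        print(list(group))  # prints [1, 2] then [3, 4]
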
 ehrql/utils/itertools_utils.py           | 74 ++++++++++++++++++++++++
 tests/unit/utils/test_itertools_utils.py | 72 ++++++++++++++++++++++-
 2 files changed, 145 insertions(+), 1 deletion(-)

diff --git a/ehrql/utils/itertools_utils.py b/ehrql/utils/itertools_utils.py
index 4b5faffa2..56b62f19d 100644
--- a/ehrql/utils/itertools_utils.py
+++ b/ehrql/utils/itertools_utils.py
@@ -27,3 +27,77 @@ def iter_flatten(iterable, iter_classes=(list, tuple, GeneratorType)):
             yield from iter_flatten(item, iter_classes)
         else:
             yield item
+
+
+def iter_groups(iterable, separator):
+    """
+    Split a flat iterator of items into a nested iterator of groups of items. Groups are
+    delineated by the presence of a sentinel `separator` value which marks the start of
+    each group.
+
+    For example, the iterator:
+
+        - SEPARATOR
+        - 1
+        - 2
+        - SEPARATOR
+        - 3
+        - 4
+
+    Will be transformed into the nested iterator:
+
+        -
+            - 1
+            - 2
+        -
+            - 3
+            - 4
+
+    This is useful for situations where a nested iterator is the natural API for
+    representing the data but the flat iterator is much easier to generate correctly.
+    """
+    iterator = iter(iterable)
+    try:
+        first_item = next(iterator)
+    except StopIteration:
+        return
+    assert first_item is separator, (
+        f"Invalid iterator: does not start with `separator` value {separator!r}"
+    )
+    while True:
+        group_iter = GroupIterator(iterator, separator)
+        yield group_iter
+        # Prevent the caller from trying to consume the next group before they've
+        # finished consuming the current one (as would happen if you naively called
+        # `list()` on the result of `iter_groups()`)
+        assert group_iter._group_complete, (
+            "Cannot consume next group until current group has been exhausted"
+        )
+        if group_iter._exhausted:
+            break
+
+
+class GroupIterator:
+    def __init__(self, iterator, separator):
+        self._iterator = iterator
+        self._separator = separator
+        self._group_complete = False
+        self._exhausted = False
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self._group_complete:
+            raise StopIteration()
+        try:
+            value = next(self._iterator)
+        except StopIteration:
+            self._group_complete = True
+            self._exhausted = True
+            raise
+        if value is self._separator:
+            self._group_complete = True
+            raise StopIteration()
+        else:
+            return value
diff --git a/tests/unit/utils/test_itertools_utils.py b/tests/unit/utils/test_itertools_utils.py
index 04756d64e..2d805d16b 100644
--- a/tests/unit/utils/test_itertools_utils.py
+++ b/tests/unit/utils/test_itertools_utils.py
@@ -1,6 +1,8 @@
+import hypothesis as hyp
+import hypothesis.strategies as st
 import pytest
 
-from ehrql.utils.itertools_utils import eager_iterator, iter_flatten
+from ehrql.utils.itertools_utils import eager_iterator, iter_flatten, iter_groups
 
 
 def test_eager_iterator():
@@ -49,3 +51,71 @@ def test_iter_flatten():
     ]
     flattened = list(iter_flatten(nested))
     assert flattened == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, "foo"]
+
+
+SEPARATOR = object()
+
+
+@pytest.mark.parametrize(
+    "stream,expected",
+    [
+        (
+            [],
+            [],
+        ),
+        (
+            [SEPARATOR],
+            [[]],
+        ),
+        (
+            [SEPARATOR, 1, 2, SEPARATOR, SEPARATOR, 3, 4],
+            [[1, 2], [], [3, 4]],
+        ),
+    ],
+)
+def test_iter_groups(stream, expected):
+    results = [list(group) for group in iter_groups(stream, SEPARATOR)]
+    assert results == expected
+
+
+@hyp.given(
+    nested=st.lists(
+        st.lists(st.integers(), max_size=5),
+        max_size=5,
+    )
+)
+def test_iter_groups_roundtrips(nested):
+    flattened = []
+    for inner in nested:
+        flattened.append(SEPARATOR)
+        for item in inner:
+            flattened.append(item)
+
+    results = [list(group) for group in iter_groups(flattened, SEPARATOR)]
+    assert results == nested
+
+
+def test_iter_groups_rejects_invalid_stream():
+    stream_no_separator = [1, 2]
+    with pytest.raises(
+        AssertionError,
+        match="Invalid iterator: does not start with `separator` value",
+    ):
+        list(iter_groups(stream_no_separator, SEPARATOR))
+
+
+def test_iter_groups_rejects_out_of_order_reads():
+    stream = [SEPARATOR, 1, 2, SEPARATOR, 3, 4]
+    with pytest.raises(
+        AssertionError,
+        match="Cannot consume next group until current group has been exhausted",
+    ):
+        list(iter_groups(stream, SEPARATOR))
+
+
+def test_iter_groups_allows_overreading_groups():
+    stream = [SEPARATOR, 1, 2, SEPARATOR, 3, 4]
+    # We call `list` on each group twice: this should make no difference because on the
+    # second call the group should be exhausted and so result in an empty list
+    results = [list(group) + list(group) for group in iter_groups(stream, SEPARATOR)]
+    assert results == [[1, 2], [3, 4]]

From 8ff16b412066c2982f8932e2759c440fffdb005c Mon Sep 17 00:00:00 2001
From: David Evans
Date: Wed, 15 Jan 2025 12:28:43 +0000
Subject: [PATCH 04/14] Update QueryEngine API to support multiple results
 tables

---
 ehrql/__main__.py                     |  2 +-
 ehrql/dummy_data/generator.py         | 10 ++++--
 ehrql/dummy_data_nextgen/generator.py | 10 ++++--
 ehrql/query_engines/base.py           | 43 +++++++++++++++++++++++--
 ehrql/query_engines/base_sql.py       | 31 +++++++++---------
 ehrql/query_engines/in_memory.py      |  3 +-
 ehrql/query_engines/local_file.py     |  4 +--
 ehrql/query_engines/mssql.py          | 46 +++++++++++++++++----------
 tests/unit/test___main__.py           |  2 +-
 9 files changed, 106 insertions(+), 45 deletions(-)

diff --git a/ehrql/__main__.py b/ehrql/__main__.py
index b55701469..9d0bf849f 100644
--- a/ehrql/__main__.py
+++ b/ehrql/__main__.py
@@ -701,7 +701,7 @@ def query_engine_from_id(str_id):
             f"(or a full dotted path to a query engine class)"
         )
     query_engine = import_string(str_id)
-    assert_duck_type(query_engine, "query engine", "get_results")
+    assert_duck_type(query_engine, "query engine", "get_results_tables")
     return query_engine
 
 
diff --git a/ehrql/dummy_data/generator.py b/ehrql/dummy_data/generator.py
index d87f82541..03beb711a 100644
--- a/ehrql/dummy_data/generator.py
+++ b/ehrql/dummy_data/generator.py
@@ -131,10 +131,16 @@ def get_patient_id_stream(self):
             if i not in inline_patient_ids:
                 yield i
 
-    def get_results(self):
+    def get_results_tables(self):
         database = InMemoryDatabase(self.get_data())
         engine = InMemoryQueryEngine(database)
-        return engine.get_results(self.dataset)
+        return engine.get_results_tables(self.dataset)
+
+    def get_results(self):
+        tables = self.get_results_tables()
+        yield from next(tables)
+        for remaining in tables:
+            assert False, "Expected only one results table"
 
 
 class DummyPatientGenerator:
diff --git a/ehrql/dummy_data_nextgen/generator.py b/ehrql/dummy_data_nextgen/generator.py
index fe96e1fbd..043185e1c 100644
--- a/ehrql/dummy_data_nextgen/generator.py
+++ b/ehrql/dummy_data_nextgen/generator.py
@@ -245,10 +245,16 @@ def get_patient_id_stream(self):
             if i not in inline_patient_ids:
                 yield i
 
-    def get_results(self):
+    def get_results_tables(self):
         database = InMemoryDatabase(self.get_data())
         engine = InMemoryQueryEngine(database)
-        return engine.get_results(self.dataset)
+        return engine.get_results_tables(self.dataset)
+
+    def get_results(self):
+        tables = self.get_results_tables()
+        yield from next(tables)
+        for remaining in tables:
+            assert False, "Expected only one results table"
 
 
 class DummyPatientGenerator:
diff --git a/ehrql/query_engines/base.py b/ehrql/query_engines/base.py
index beeb8cba4..895779dd8 100644
--- a/ehrql/query_engines/base.py
+++ b/ehrql/query_engines/base.py
@@ -2,6 +2,10 @@
 from typing import Any
 
 from ehrql.query_model import nodes as qm
+from ehrql.utils.itertools_utils import iter_groups
+
+
+class Marker: ...
 
 
 class BaseQueryEngine:
@@ -12,6 +16,9 @@ class BaseQueryEngine:
     flavour of tables and query language (SQL, pandas dataframes etc).
     """
 
+    # Sentinel value used to mark the start of a new results table in a stream of results
+    RESULTS_START = Marker()
+
     def __init__(self, dsn: str, backend: Any = None, config: dict | None = None):
         """
         `dsn` is Data Source Name — a string (usually a URL) which provides connection
@@ -25,12 +32,42 @@ def __init__(self, dsn: str, backend: Any = None, config: dict | None = None):
         self.backend = backend
         self.config = config or {}
 
-    def get_results(self, dataset: qm.Dataset) -> Iterator[Sequence]:
+    def get_results_tables(self, dataset: qm.Dataset) -> Iterator[Iterator[Sequence]]:
+        """
+        Given a query model `Dataset` return an iterator of "results tables", where each
+        table is an iterator of rows (usually tuples, but any sequence type will do)
+
+        This is the primary interface to query engines and the one required method.
+
+        Typically however, query engine subclasses will implement `get_results_stream`
+        instead which yields a flat sequence of rows, with tables separated by the
+        `RESULTS_START` marker value. This is converted into the appropriate structure
+        by `iter_groups` which also enforces that the caller interacts with it safely.
         """
-        Given a query model `Dataset` return the results as an iterator of "rows" (which
-        are usually tuples, but any sequence type will do)
+        return iter_groups(self.get_results_stream(dataset), self.RESULTS_START)
+
+    def get_results_stream(self, dataset: qm.Dataset) -> Iterator[Sequence | Marker]:
+        """
+        Given a query model `Dataset` return an iterator of rows over all the results
+        tables, with each table's results separated by the `RESULTS_START` marker value
 
         Override this method to do the things necessary to generate query code and
         execute it against a particular backend.
+
+        Emitting results in a flat sequence like this with separators between the tables
+        ends up making the query code _much_ easier to reason about because everything
+        happens in a clear linear sequence rather than inside nested generators. This
+        makes things like transaction management and error handling much more
+        straightforward.
         """
         raise NotImplementedError()
+
+    def get_results(self, dataset: qm.Dataset) -> Iterator[Sequence]:
+        """
+        Temporary method to continue to support code which assumes only a single results
+        table
+        """
+        tables = self.get_results_tables(dataset)
+        yield from next(tables)
+        for remaining in tables:
+            assert False, "Expected only one results table"
diff --git a/ehrql/query_engines/base_sql.py b/ehrql/query_engines/base_sql.py
index ad88b1404..ed2939fa3 100644
--- a/ehrql/query_engines/base_sql.py
+++ b/ehrql/query_engines/base_sql.py
@@ -837,30 +837,29 @@ def get_select_query_for_node_domain(self, node):
         query = query.where(sqlalchemy.and_(*where_clauses))
         return query
 
-    def get_results(self, dataset):
+    def get_results_stream(self, dataset):
         results_queries = self.get_queries(dataset)
         setup_queries, cleanup_queries = get_setup_and_cleanup_queries(results_queries)
 
-        assert len(results_queries) == 1
-        results_query = results_queries[0]
-
         with self.engine.connect() as connection:
             for i, setup_query in enumerate(setup_queries, start=1):
                 log.info(f"Running setup query {i:03} / {len(setup_queries):03}")
                 connection.execute(setup_query)
 
-            log.info("Fetching results")
-            cursor_result = connection.execute(results_query)
-            try:
-                yield from cursor_result
-            except Exception:  # pragma: no cover
-                # If we hit an error part way through fetching results then we should
-                # close the cursor to make it clear we're not going to be fetching any
-                # more (only really relevant for the in-memory SQLite tests, but good
-                # hygiene in any case)
-                cursor_result.close()
-                # Make sure the cleanup happens before raising the error
-                raise
+            for i, results_query in enumerate(results_queries, start=1):
+                log.info(f"Fetching results {i:03} / {len(results_queries):03}")
+                cursor_result = connection.execute(results_query)
+                yield self.RESULTS_START
+                try:
+                    yield from cursor_result
+                except Exception:  # pragma: no cover
+                    # If we hit an error part way through fetching results then we should
+                    # close the cursor to make it clear we're not going to be fetching any
+                    # more (only really relevant for the in-memory SQLite tests, but good
+                    # hygiene in any case)
+                    cursor_result.close()
+                    # Make sure the cleanup happens before raising the error
+                    raise
 
             for i, cleanup_query in enumerate(cleanup_queries, start=1):
                 log.info(f"Running cleanup query {i:03} / {len(cleanup_queries):03}")
diff --git a/ehrql/query_engines/in_memory.py b/ehrql/query_engines/in_memory.py
index 649b2ec31..f857379da 100644
--- a/ehrql/query_engines/in_memory.py
+++ b/ehrql/query_engines/in_memory.py
@@ -28,9 +28,10 @@ class InMemoryQueryEngine(BaseQueryEngine):
     tests, and a to provide a reference implementation for other engines.
     """
 
-    def get_results(self, dataset):
+    def get_results_stream(self, dataset):
         table = self.get_results_as_patient_table(dataset)
         Row = namedtuple("Row", table.name_to_col.keys())
+        yield self.RESULTS_START
         for record in table.to_records():
             yield Row(**record)
 
diff --git a/ehrql/query_engines/local_file.py b/ehrql/query_engines/local_file.py
index c2019f979..c65ebc6fc 100644
--- a/ehrql/query_engines/local_file.py
+++ b/ehrql/query_engines/local_file.py
@@ -14,14 +14,14 @@ class LocalFileQueryEngine(InMemoryQueryEngine):
 
     database = None
 
-    def get_results(self, dataset):
+    def get_results_stream(self, dataset):
         # Given the dataset supplied determine the tables used and load the associated
         # data into the database
         self.populate_database(
             get_table_nodes(dataset),
         )
         # Run the query as normal
-        return super().get_results(dataset)
+        return super().get_results_stream(dataset)
 
     def populate_database(self, table_nodes, allow_missing_columns=True):
         table_specs = {
diff --git a/ehrql/query_engines/mssql.py b/ehrql/query_engines/mssql.py
index 93bf6aa47..4adb37d6c 100644
--- a/ehrql/query_engines/mssql.py
+++ b/ehrql/query_engines/mssql.py
@@ -164,15 +164,17 @@ def get_queries(self, dataset):
             select_queries.append(sqlalchemy.select(results_table))
         return select_queries
 
-    def get_results(self, dataset):
+    def get_results_stream(self, dataset):
         results_queries = self.get_queries(dataset)
 
-        # We're expecting a query in a very specific form which is "select everything
-        # from one table"; so we assert that it has this form and retrieve a reference
-        # to the table
-        assert len(results_queries) == 1
-        results_table = results_queries[0].get_final_froms()[0]
-        assert str(results_queries[0]) == str(sqlalchemy.select(results_table))
+        # We're expecting queries in a very specific form which is "select everything
+        # from one table"; so we assert that they have this form and retrieve references
+        # to the tables
+        results_tables = []
+        for results_query in results_queries:
+            results_table = results_query.get_final_froms()[0]
+            assert str(results_query) == str(sqlalchemy.select(results_table))
+            results_tables.append(results_table)
 
         setup_queries, cleanup_queries = get_setup_and_cleanup_queries(results_queries)
 
@@ -197,16 +199,26 @@
                 log=log.info,
             )
 
-            yield from fetch_table_in_batches(
-                execute_with_retry,
-                results_table,
-                key_column=results_table.c.patient_id,
-                key_is_unique=True,
-                # This value was copied from the previous cohortextractor. I suspect it
-                # has no real scientific basis.
-                batch_size=32000,
-                log=log.info,
-            )
+            for i, results_table in enumerate(results_tables):
+                yield self.RESULTS_START
+                yield from fetch_table_in_batches(
+                    execute_with_retry,
+                    results_table,
+                    key_column=results_table.c.patient_id,
+                    # TODO: We need to find a better way to identify which tables have a
+                    # unique `patient_id` column because it lets the batch fetcher use a
+                    # more efficient algorithm. At present, we know that the first
+                    # results table does but this isn't a very comfortable approach. The
+                    # other option is to just always use the non-unique algorithm on the
+                    # basis that the lost efficiency probably isn't noticeable. But
+                    # until we're supporting event-level data for real I'm reluctant to
+                    # make things worse for the currently supported case.
+                    key_is_unique=(i == 0),
+                    # This value was copied from the previous cohortextractor. I suspect
+                    # it has no real scientific basis.
+                    batch_size=32000,
+                    log=log.info,
+                )
 
             for i, cleanup_query in enumerate(cleanup_queries, start=1):
                 query_id = f"cleanup query {i:03} / {len(cleanup_queries):03}"
diff --git a/tests/unit/test___main__.py b/tests/unit/test___main__.py
index 141979299..0f6fe5bb7 100644
--- a/tests/unit/test___main__.py
+++ b/tests/unit/test___main__.py
@@ -248,7 +248,7 @@ def test_import_string_no_such_attribute():
 
 
 class DummyQueryEngine:
-    def get_results(self):
+    def get_results_tables(self):
         raise NotImplementedError()
 
 

From c14ccc1b3789980e7d3b01d5b6f5c73000114744 Mon Sep 17 00:00:00 2001
From: David Evans
Date: Fri, 17 Jan 2025 17:16:58 +0000
Subject: [PATCH 05/14] Generator must now be consumed to trigger error

---
 tests/unit/dummy_data_nextgen/test_specific_datasets.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/unit/dummy_data_nextgen/test_specific_datasets.py b/tests/unit/dummy_data_nextgen/test_specific_datasets.py
index 7993f9324..fcf14de7c 100644
--- a/tests/unit/dummy_data_nextgen/test_specific_datasets.py
+++ b/tests/unit/dummy_data_nextgen/test_specific_datasets.py
@@ -325,7 +325,7 @@ def test_will_raise_if_all_data_is_impossible(patched_time, query):
     generator.timeout = 1
     patched_time.time.side_effect = [0.0, 20.0]
     with pytest.raises(CannotGenerate):
-        generator.get_results()
+        next(generator.get_results())
 
 
 def test_generates_events_starting_from_birthdate():

From 8c7bb21610dd7995784bb7dc192fe55a367ee339 Mon Sep 17 00:00:00 2001
From: David Evans
Date: Fri, 17 Jan 2025 11:06:27 +0000
Subject: [PATCH 06/14] Change `get_column_specs()` to `get_table_specs()`

In future Datasets will be able to produce more than one results table
and so this is the API we need.
---
 ehrql/main.py                               |  7 +++--
 ehrql/query_model/column_specs.py           | 31 ++++++++++++---------
 tests/unit/query_model/test_column_specs.py | 20 +++++++------
 3 files changed, 34 insertions(+), 24 deletions(-)

diff --git a/ehrql/main.py b/ehrql/main.py
index 83d13b66c..851a1d1b0 100644
--- a/ehrql/main.py
+++ b/ehrql/main.py
@@ -35,8 +35,8 @@
 from ehrql.query_engines.local_file import LocalFileQueryEngine
 from ehrql.query_engines.sqlite import SQLiteQueryEngine
 from ehrql.query_model.column_specs import (
-    get_column_specs,
     get_column_specs_from_schema,
+    get_table_specs,
 )
 from ehrql.query_model.graphs import graph_to_svg
 from ehrql.serializer import serialize
@@ -71,7 +71,10 @@ def generate_dataset(
         log.info(f"Testing dataset definition with tests in {str(definition_file)}")
         assure(test_data_file, environ=environ, user_args=user_args)
 
-    column_specs = get_column_specs(dataset)
+    table_specs = get_table_specs(dataset)
+    # For now we only handle datasets with a single output table
+    assert len(table_specs) == 1
+    column_specs = list(table_specs.values())[0]
 
     if dsn:
         log.info("Generating dataset")
diff --git a/ehrql/query_model/column_specs.py b/ehrql/query_model/column_specs.py
index 52ead2529..648e7fe3e 100644
--- a/ehrql/query_model/column_specs.py
+++ b/ehrql/query_model/column_specs.py
@@ -6,7 +6,6 @@
     AggregateByPatient,
     Case,
     Constraint,
-    Dataset,
     SelectColumn,
     SelectPatientTable,
     Value,
@@ -27,16 +26,25 @@ class ColumnSpec:
     max_value: T | None = None
 
 
-def get_column_specs(dataset):
+def get_table_specs(dataset):
     """
-    Given a dataset return a dict of ColumnSpec objects, given the types (and other
-    associated metadata) of all the columns in the output
+    Return the specifications for all the results tables this Dataset will produce
+    """
+    # At present, Datasets only ever produce a single results table (which we call
+    # `dataset`) but this gives us the API we need for future expansion
+    return {"dataset": get_column_specs_from_variables(dataset.variables)}
+
+
+def get_column_specs_from_variables(variables):
+    """
+    Given a dict of dataset variables return a dict of ColumnSpec objects, given
+    the types (and other associated metadata) of all the columns in the output
     """
     # TODO: It may not be universally true that IDs are ints. Internally the EMIS IDs
     # are SHA512 hashes stored as hex strings which we convert to ints. But we can't
     # guarantee always to be able to pull this trick.
     column_specs = {"patient_id": ColumnSpec(int, nullable=False)}
-    for name, series in dataset.variables.items():
+    for name, series in variables.items():
         column_specs[name] = get_column_spec_from_series(series)
     return column_specs
 
@@ -46,14 +54,11 @@ def get_column_specs_from_schema(schema):
     # reusing all the logic above: we create a table node and then create some variables
     # by selecting each column in the schema from it.
     table = SelectPatientTable(name="table", schema=schema)
-    dataset = Dataset(
-        population=Value(False),
-        variables={
-            column_name: SelectColumn(source=table, name=column_name)
-            for column_name in schema.column_names
-        },
-    )
-    return get_column_specs(dataset)
+    variables = {
+        column_name: SelectColumn(source=table, name=column_name)
+        for column_name in schema.column_names
+    }
+    return get_column_specs_from_variables(variables)
 
 
 def get_column_spec_from_series(series):
diff --git a/tests/unit/query_model/test_column_specs.py b/tests/unit/query_model/test_column_specs.py
index 3ab52506f..14f59edc9 100644
--- a/tests/unit/query_model/test_column_specs.py
+++ b/tests/unit/query_model/test_column_specs.py
@@ -4,8 +4,8 @@
 from ehrql.query_model.column_specs import (
     ColumnSpec,
     get_categories,
-    get_column_specs,
     get_range,
+    get_table_specs,
 )
 from ehrql.query_model.nodes import (
     AggregateByPatient,
@@ -46,14 +46,16 @@ def test_get_column_specs():
             category=SelectColumn(patients, "category"),
         ),
     )
-    column_specs = get_column_specs(dataset)
-    assert column_specs == {
-        "patient_id": ColumnSpec(type=int, nullable=False, categories=None),
-        "dob": ColumnSpec(type=datetime.date, nullable=True, categories=None),
-        "code": ColumnSpec(type=str, nullable=True, categories=None),
-        "category": ColumnSpec(
-            type=str, nullable=True, categories=("123000", "456000")
-        ),
+    table_specs = get_table_specs(dataset)
+    assert table_specs == {
+        "dataset": {
+            "patient_id": ColumnSpec(type=int, nullable=False, categories=None),
+            "dob": ColumnSpec(type=datetime.date, nullable=True, categories=None),
+            "code": ColumnSpec(type=str, nullable=True, categories=None),
+            "category": ColumnSpec(
+                type=str, nullable=True, categories=("123000", "456000")
+            ),
+        }
     }
 

From be92f98ad1967d8d2f3cc589bac5d786370cc6ab Mon Sep 17 00:00:00 2001
From: David Evans
Date: Tue, 14 Jan 2025 08:26:27 +0000
Subject: [PATCH 07/14] Support single-table formats in
 `read_tables`/`write_tables`

This allows you to call the `write_tables` function with a file
specification that only supports a single table, as long as you've only
got a single table to write. And similarly for `read_tables`.

This allows us to use a common code path for datasets which will only
produce a single results table and those which will produce multiple
(and similarly will require multiple dummy data files).

We pay special attention to making the error messages here as helpful
as possible. This is on the assumption that the first time anyone tries
to use event-level data they won't immediately realise that they need
to specify their output and dummy data files in a different way.
---
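Note: a rough sketch of the behaviour this enables (file names and column
specs below are illustrative):

    from pathlib import Path
    from ehrql.file_formats import read_tables

    # With exactly one table spec, a single-table file is accepted ...
    tables = read_tables(Path("results.csv"), {"dataset": column_specs})

    # ... but with several table specs the same path raises a
    # FileValidationError whose message points at a per-table directory
    # layout (results/dataset.csv, results/events.csv) instead
    tables = read_tables(Path("results.csv"), {"dataset": specs_1, "events": specs_2})
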
This is on the assumption that the first time anyone tries to use event-level data they won't immediately realise that they need to specify their output and dummy data files in a different way. --- ehrql/file_formats/main.py | 56 ++++++++++++++++ tests/integration/file_formats/test_main.py | 72 +++++++++++++++++++++ 2 files changed, 128 insertions(+) diff --git a/ehrql/file_formats/main.py b/ehrql/file_formats/main.py index 570696c24..216fee361 100644 --- a/ehrql/file_formats/main.py +++ b/ehrql/file_formats/main.py @@ -48,6 +48,30 @@ def read_rows(filename, column_specs, allow_missing_columns=False): def read_tables(filename, table_specs, allow_missing_columns=False): + # If we've got a single-table input file and only a single table to read then that's + # fine, but it needs slightly special handling + if not input_filename_supports_multiple_tables(filename): + if len(table_specs) == 1: + column_specs = list(table_specs.values())[0] + rows = read_rows( + filename, + column_specs, + allow_missing_columns=allow_missing_columns, + ) + yield from [rows] + return + else: + files = list(table_specs.keys()) + suffix = filename.suffix + raise FileValidationError( + f"Attempting to read {len(table_specs)} tables, but input only " + f"provides a single table\n" + f" Try moving -> {filename}\n" + f" to -> {filename.parent / filename.stem}/{files[0]}{suffix}\n" + f" adding -> {', '.join(f + suffix for f in files[1:])}\n" + f" and using path -> {filename.parent / filename.stem}/" + ) + extension = get_extension_from_directory(filename) # Using ExitStack here allows us to open and validate all files before emiting any # rows while still correctly closing all open files if we raise an error part way @@ -66,6 +90,22 @@ def read_tables(filename, table_specs, allow_missing_columns=False): def write_tables(filename, tables, table_specs): + # If we've got a single-table output file and only a single table to write then + # that's fine, but it needs slightly special handling + if not output_filename_supports_multiple_tables(filename): + if len(table_specs) == 1: + column_specs = list(table_specs.values())[0] + rows = next(iter(tables)) + return write_rows(filename, rows, column_specs) + else: + raise FileValidationError( + f"Attempting to write {len(table_specs)} tables, but output only " + f"supports a single table\n" + f" Instead of -> {filename}\n" + f" try -> " + f"{filename.parent / filename.stem}/:{filename.suffix.lstrip('.')}" + ) + filename, extension = split_directory_and_extension(filename) for rows, (table_name, column_specs) in zip(tables, table_specs.items()): table_filename = get_table_filename(filename, table_name, extension) @@ -121,6 +161,22 @@ def split_directory_and_extension(filename): return filename.with_name(name), f".{extension}" +def input_filename_supports_multiple_tables(filename): + # At present, supplying a directory is the only way to provide multiple input + # tables, but it's not inconceivable that in future we might support single-file + # multiple-table formats e.g SQLite or DuckDB files. If we do then updating this + # function and its sibling below should be all that's required. 
+ return filename.is_dir() + + +def output_filename_supports_multiple_tables(filename): + if filename is None: + return False + # Again, at present only directories support multiple output tables but see above + extension = split_directory_and_extension(filename)[1] + return extension != "" + + def get_table_filename(base_filename, table_name, extension): # Use URL quoting as an easy way of escaping any potentially problematic characters # in filenames diff --git a/tests/integration/file_formats/test_main.py b/tests/integration/file_formats/test_main.py index 5c0770c93..1debca5c3 100644 --- a/tests/integration/file_formats/test_main.py +++ b/tests/integration/file_formats/test_main.py @@ -1,4 +1,6 @@ +import contextlib import datetime +import textwrap import pytest @@ -230,3 +232,73 @@ def test_read_and_write_tables_roundtrip(tmp_path, extension): results = read_tables(tmp_path / "output", table_specs) assert [list(rows) for rows in results] == tables + + +def test_read_tables_allows_single_table_format_if_only_one_table(tmp_path): + filename = tmp_path / "file.csv" + filename.write_text("i,s\n1,a\n2,b\n3,c\n") + table_specs = { + "table_1": {"i": ColumnSpec(int), "s": ColumnSpec(str)}, + } + results = read_tables(filename, table_specs) + assert [list(rows) for rows in results] == [ + [(1, "a"), (2, "b"), (3, "c")], + ] + + +def test_write_tables_allows_single_table_format_if_only_one_table(tmp_path): + filename = tmp_path / "file.csv" + table_specs = { + "table_1": {"i": ColumnSpec(int), "s": ColumnSpec(str)}, + } + table_data = [ + [(1, "a"), (2, "b"), (3, "c")], + ] + write_tables(filename, table_data, table_specs) + assert filename.read_text() == "i,s\n1,a\n2,b\n3,c\n" + + +def test_read_tables_rejects_single_table_format_if_multiple_tables(tmp_path): + filename = tmp_path / "input.csv" + filename.touch() + table_specs = { + "table_1": {"i": ColumnSpec(int), "s": ColumnSpec(str)}, + "table_2": {"j": ColumnSpec(int), "k": ColumnSpec(float)}, + "table_3": {"l": ColumnSpec(int), "m": ColumnSpec(float)}, + } + expected_error = textwrap.dedent( + """\ + Attempting to read 3 tables, but input only provides a single table + Try moving -> input.csv + to -> input/table_1.csv + adding -> table_2.csv, table_3.csv + and using path -> input/ + """ + ) + with contextlib.chdir(tmp_path): + # Use relative paths to get predictable error message + relpath = filename.relative_to(tmp_path) + with pytest.raises(FileValidationError, match=expected_error.rstrip()): + list(read_tables(relpath, table_specs)) + + +def test_write_tables_rejects_single_table_format_if_multiple_tables(tmp_path): + filename = tmp_path / "output.csv" + table_specs = { + "table_1": {"i": ColumnSpec(int), "s": ColumnSpec(str)}, + "table_2": {"j": ColumnSpec(int), "k": ColumnSpec(float)}, + "table_3": {"l": ColumnSpec(int), "m": ColumnSpec(float)}, + } + table_data = [[], []] + expected_error = textwrap.dedent( + """\ + Attempting to write 3 tables, but output only supports a single table + Instead of -> output.csv + try -> output/:csv + """ + ) + with contextlib.chdir(tmp_path): + # Use relative paths to get predictable error message + relpath = filename.relative_to(tmp_path) + with pytest.raises(FileValidationError, match=expected_error.rstrip()): + write_tables(relpath, table_data, table_specs) From 4f7b5297f7acf929472703a8c1cd17557bb062a3 Mon Sep 17 00:00:00 2001 From: David Evans Date: Thu, 23 Jan 2025 13:53:38 +0000 Subject: [PATCH 08/14] Raise correct error for missing directories in `read_rows` Previously you'd get a misleading 
error about the input file only supporting a single table (because it isn't a directory, because it doesn't exist). --- ehrql/file_formats/main.py | 3 +++ tests/integration/file_formats/test_main.py | 11 +++++++++++ 2 files changed, 14 insertions(+) diff --git a/ehrql/file_formats/main.py b/ehrql/file_formats/main.py index 216fee361..4e299dec5 100644 --- a/ehrql/file_formats/main.py +++ b/ehrql/file_formats/main.py @@ -48,6 +48,9 @@ def read_rows(filename, column_specs, allow_missing_columns=False): def read_tables(filename, table_specs, allow_missing_columns=False): + if not filename.exists(): + raise FileValidationError(f"Missing file or directory: {filename}") + # If we've got a single-table input file and only a single table to read then that's # fine, but it needs slightly special handling if not input_filename_supports_multiple_tables(filename): diff --git a/tests/integration/file_formats/test_main.py b/tests/integration/file_formats/test_main.py index 1debca5c3..8d39ddaf0 100644 --- a/tests/integration/file_formats/test_main.py +++ b/tests/integration/file_formats/test_main.py @@ -302,3 +302,14 @@ def test_write_tables_rejects_single_table_format_if_multiple_tables(tmp_path): relpath = filename.relative_to(tmp_path) with pytest.raises(FileValidationError, match=expected_error.rstrip()): write_tables(relpath, table_data, table_specs) + + +def test_read_tables_with_missing_file_raises_appropriate_error(tmp_path): + missing_file = tmp_path / "aint-no-such-file" + table_specs = { + "table_1": {"i": ColumnSpec(int), "s": ColumnSpec(str)}, + "table_2": {"j": ColumnSpec(int), "k": ColumnSpec(float)}, + "table_3": {"l": ColumnSpec(int), "m": ColumnSpec(float)}, + } + with pytest.raises(FileValidationError, match="Missing file or directory"): + next(read_tables(missing_file, table_specs)) From e9c994403ea3889daa829d2ad37bf99f4482ee54 Mon Sep 17 00:00:00 2001 From: David Evans Date: Wed, 22 Jan 2025 12:31:13 +0000 Subject: [PATCH 09/14] Add dedicated console output handler This means we now support writing multiple tables of output to the console, as we'll need to once `generate-dataset` produces more than one output table. At present we just re-use the CSV renderer for this which means that there's no user-facing change. I think in future we should consider using some kind of pretty-printer to make the output more user-friendly and readable. --- ehrql/file_formats/console.py | 28 +++++++ ehrql/file_formats/main.py | 7 ++ tests/integration/file_formats/test_main.py | 25 ++++++ tests/unit/file_formats/test_console.py | 85 +++++++++++++++++++++ 4 files changed, 145 insertions(+) create mode 100644 ehrql/file_formats/console.py create mode 100644 tests/unit/file_formats/test_console.py diff --git a/ehrql/file_formats/console.py b/ehrql/file_formats/console.py new file mode 100644 index 000000000..6ab4f7513 --- /dev/null +++ b/ehrql/file_formats/console.py @@ -0,0 +1,28 @@ +""" +Handles writing rows/tables to the console for local development and debugging. + +At present, this just uses the CSV writer but there's scope for using something a bit +prettier and more readable here in future. 
+""" + +import sys + +from ehrql.file_formats.csv import write_rows_csv_lines + + +def write_rows_console(rows, column_specs): + write_rows_csv_lines(sys.stdout, rows, column_specs) + + +def write_tables_console(tables, table_specs): + write_table_names = len(table_specs) > 1 + first_table = True + for rows, (table_name, column_specs) in zip(tables, table_specs.items()): + if first_table: + first_table = False + else: + # Add whitespace between tables + sys.stdout.write("\n\n") + if write_table_names: + sys.stdout.write(f"{table_name}\n") + write_rows_console(rows, column_specs) diff --git a/ehrql/file_formats/main.py b/ehrql/file_formats/main.py index 4e299dec5..31b3be7d6 100644 --- a/ehrql/file_formats/main.py +++ b/ehrql/file_formats/main.py @@ -6,6 +6,7 @@ write_rows_arrow, ) from ehrql.file_formats.base import FileValidationError +from ehrql.file_formats.console import write_rows_console, write_tables_console from ehrql.file_formats.csv import ( CSVGZRowsReader, CSVRowsReader, @@ -23,6 +24,9 @@ def write_rows(filename, rows, column_specs): + if filename is None: + return write_rows_console(rows, column_specs) + extension = get_file_extension(filename) writer = FILE_FORMATS[extension][0] # `rows` is often a generator which won't actually execute until we start consuming @@ -93,6 +97,9 @@ def read_tables(filename, table_specs, allow_missing_columns=False): def write_tables(filename, tables, table_specs): + if filename is None: + return write_tables_console(tables, table_specs) + # If we've got a single-table output file and only a single table to write then # that's fine, but it needs slightly special handling if not output_filename_supports_multiple_tables(filename): diff --git a/tests/integration/file_formats/test_main.py b/tests/integration/file_formats/test_main.py index 8d39ddaf0..880390702 100644 --- a/tests/integration/file_formats/test_main.py +++ b/tests/integration/file_formats/test_main.py @@ -313,3 +313,28 @@ def test_read_tables_with_missing_file_raises_appropriate_error(tmp_path): } with pytest.raises(FileValidationError, match="Missing file or directory"): next(read_tables(missing_file, table_specs)) + + +def test_write_rows_without_filename_writes_to_console(capsys): + write_rows(None, TEST_FILE_DATA, TEST_FILE_SPECS) + output = capsys.readouterr().out + # The exact content here is tested elsewhere, we just want to make sure things are + # wired up correctly + assert "patient_id" in output + + +def test_write_tables_without_filename_writes_to_console(capsys): + table_specs = { + "table_1": TEST_FILE_SPECS, + "table_2": TEST_FILE_SPECS, + } + table_data = [ + TEST_FILE_DATA, + TEST_FILE_DATA, + ] + write_tables(None, table_data, table_specs) + output = capsys.readouterr().out + # The exact content here is tested elsewhere, we just want to make sure things are + # wired up correctly + assert "patient_id" in output + assert "table_2" in output diff --git a/tests/unit/file_formats/test_console.py b/tests/unit/file_formats/test_console.py new file mode 100644 index 000000000..c76415598 --- /dev/null +++ b/tests/unit/file_formats/test_console.py @@ -0,0 +1,85 @@ +import datetime +import textwrap + +from ehrql.file_formats.console import write_rows_console, write_tables_console +from ehrql.query_model.column_specs import ColumnSpec +from ehrql.sqlalchemy_types import TYPE_MAP + + +def test_write_rows_console(capsys): + column_specs = { + "patient_id": ColumnSpec(int), + "b": ColumnSpec(bool), + "i": ColumnSpec(int), + "f": ColumnSpec(float), + "s": ColumnSpec(str), + "c": 
ColumnSpec(str, categories=("A", "B")), + "d": ColumnSpec(datetime.date), + } + + rows = [ + (123, True, 1, 2.3, "a", "A", datetime.date(2020, 1, 1)), + (456, False, -5, -0.4, "b", "B", datetime.date(2022, 12, 31)), + (789, None, None, None, None, None, None), + ] + + # Check the example uses at least one of every supported type + assert {spec.type for spec in column_specs.values()} == set(TYPE_MAP) + + write_rows_console(rows, column_specs) + output = capsys.readouterr().out + + # The CSV module does its own newline handling, hence the carriage returns below + assert output == textwrap.dedent( + """\ + patient_id,b,i,f,s,c,d\r + 123,T,1,2.3,a,A,2020-01-01\r + 456,F,-5,-0.4,b,B,2022-12-31\r + 789,,,,,,\r + """ + ) + + +def test_write_tables_console(capsys): + table_specs = { + "table_1": { + "patient_id": ColumnSpec(int), + "b": ColumnSpec(bool), + "i": ColumnSpec(int), + }, + "table_2": { + "patient_id": ColumnSpec(int), + "s": ColumnSpec(str), + "d": ColumnSpec(datetime.date), + }, + } + + tables = [ + [ + (123, True, 1), + (456, False, 2), + ], + [ + (789, "a", datetime.date(2025, 1, 1)), + (987, "B", datetime.date(2025, 2, 3)), + ], + ] + + write_tables_console(tables, table_specs) + output = capsys.readouterr().out + + # The CSV module does its own newline handling, hence the carriage returns below + assert output == textwrap.dedent( + """\ + table_1 + patient_id,b,i\r + 123,T,1\r + 456,F,2\r + + + table_2 + patient_id,s,d\r + 789,a,2025-01-01\r + 987,B,2025-02-03\r + """ + ) From 8d4137309c29a455d88747a9b728a335fd0572a5 Mon Sep 17 00:00:00 2001 From: David Evans Date: Wed, 22 Jan 2025 12:38:03 +0000 Subject: [PATCH 10/14] Remove special case handling for missing filename --- ehrql/file_formats/csv.py | 11 ++--------- ehrql/file_formats/main.py | 11 ++--------- tests/unit/file_formats/test_main.py | 1 - 3 files changed, 4 insertions(+), 19 deletions(-) diff --git a/ehrql/file_formats/csv.py b/ehrql/file_formats/csv.py index f081e0215..1c691b061 100644 --- a/ehrql/file_formats/csv.py +++ b/ehrql/file_formats/csv.py @@ -1,8 +1,6 @@ import csv import datetime import gzip -import sys -from contextlib import nullcontext from ehrql.file_formats.base import ( BaseRowsReader, @@ -12,13 +10,8 @@ def write_rows_csv(filename, rows, column_specs): - if filename is None: - context = nullcontext(sys.stdout) - else: - # Set `newline` as per Python docs: - # https://docs.python.org/3/library/csv.html#id3 - context = filename.open(mode="w", newline="") - with context as f: + # Set `newline` as per Python docs: https://docs.python.org/3/library/csv.html#id3 + with filename.open(mode="w", newline="") as f: write_rows_csv_lines(f, rows, column_specs) diff --git a/ehrql/file_formats/main.py b/ehrql/file_formats/main.py index 31b3be7d6..acd68dabc 100644 --- a/ehrql/file_formats/main.py +++ b/ehrql/file_formats/main.py @@ -35,9 +35,7 @@ def write_rows(filename, rows, column_specs): # whole thing into memory. So we wrap it in a function which draws the first item # upfront, but doesn't consume the rest of the iterator. 
rows = eager_iterator(rows) - # We use None for stdout - if filename is not None: - filename.parent.mkdir(parents=True, exist_ok=True) + filename.parent.mkdir(parents=True, exist_ok=True) writer(filename, rows, column_specs) @@ -123,10 +121,7 @@ def write_tables(filename, tables, table_specs): def get_file_extension(filename): - if filename is None: - # If we have no filename we're writing to stdout, so default to CSV - return ".csv" - elif filename.suffix == ".gz": + if filename.suffix == ".gz": return "".join(filename.suffixes[-2:]) else: return filename.suffix @@ -180,8 +175,6 @@ def input_filename_supports_multiple_tables(filename): def output_filename_supports_multiple_tables(filename): - if filename is None: - return False # Again, at present only directories support multiple output tables but see above extension = split_directory_and_extension(filename)[1] return extension != "" diff --git a/tests/unit/file_formats/test_main.py b/tests/unit/file_formats/test_main.py index 8ac33c074..aeb20cb23 100644 --- a/tests/unit/file_formats/test_main.py +++ b/tests/unit/file_formats/test_main.py @@ -17,7 +17,6 @@ @pytest.mark.parametrize( "filename,extension", [ - (None, ".csv"), (Path("a/b.c/file.txt"), ".txt"), (Path("a/b.c/file.txt.foo"), ".foo"), (Path("a/b.c/file.txt.gz"), ".txt.gz"), From f7e646a359d65ff049d784ba2085c67b0bc2b164 Mon Sep 17 00:00:00 2001 From: David Evans Date: Fri, 17 Jan 2025 12:33:42 +0000 Subject: [PATCH 11/14] fix: Add plumbing to support multiple output files --- ehrql/__main__.py | 26 ++++++++++++++++++++++---- ehrql/main.py | 23 ++++++++++------------- tests/unit/test___main__.py | 30 ++++++++++++++++++++++++++++++ 3 files changed, 62 insertions(+), 17 deletions(-) diff --git a/ehrql/__main__.py b/ehrql/__main__.py index 9d0bf849f..87a56377a 100644 --- a/ehrql/__main__.py +++ b/ehrql/__main__.py @@ -662,13 +662,31 @@ def existing_python_file(value): def valid_output_path(value): + # This can be either a single file or a directory, but either way it needs to + # specify a valid output format path = Path(value) - extension = get_file_extension(path) - if extension not in FILE_FORMATS: + directory_ext = split_directory_and_extension(path)[1] + file_ext = get_file_extension(path) + if not directory_ext and not file_ext: raise ArgumentTypeError( - f"'{extension}' is not a supported format, must be one of: " - f"{backtick_join(FILE_FORMATS)}" + f"No file format supplied\n" + f"To write a single file use a file extension: {backtick_join(FILE_FORMATS)}" + f"To write multiple files use a directory extension: " + f"{backtick_join(format_directory_extension(e) for e in FILE_FORMATS)}\n" ) + elif directory_ext: + if directory_ext not in FILE_FORMATS: + raise ArgumentTypeError( + f"'{format_directory_extension(directory_ext)}' is not a supported format, " + f"must be one of: " + f"{backtick_join(format_directory_extension(e) for e in FILE_FORMATS)}" + ) + else: + if file_ext not in FILE_FORMATS: + raise ArgumentTypeError( + f"'{file_ext}' is not a supported format, must be one of: " + f"{backtick_join(FILE_FORMATS)}" + ) return path diff --git a/ehrql/main.py b/ehrql/main.py index 851a1d1b0..490dd47bd 100644 --- a/ehrql/main.py +++ b/ehrql/main.py @@ -14,6 +14,7 @@ ) from ehrql.file_formats import ( read_rows, + read_tables, split_directory_and_extension, write_rows, write_tables, @@ -72,13 +73,10 @@ def generate_dataset( assure(test_data_file, environ=environ, user_args=user_args) table_specs = get_table_specs(dataset) - # For now we only handle datasets with a single output 
-    assert len(table_specs) == 1
-    column_specs = list(table_specs.values())[0]
 
     if dsn:
         log.info("Generating dataset")
-        results = generate_dataset_with_dsn(
+        results_tables = generate_dataset_with_dsn(
             dataset=dataset,
             dsn=dsn,
             backend_class=backend_class,
@@ -87,15 +85,15 @@
         )
     else:
         log.info("Generating dummy dataset")
-        results = generate_dataset_with_dummy_data(
+        results_tables = generate_dataset_with_dummy_data(
             dataset=dataset,
             dummy_data_config=dummy_data_config,
-            column_specs=column_specs,
+            table_specs=table_specs,
             dummy_data_file=dummy_data_file,
             dummy_tables_path=dummy_tables_path,
         )
 
-    write_rows(output_file, results, column_specs)
+    write_tables(output_file, results_tables, table_specs)
 
 
 def generate_dataset_with_dsn(
@@ -108,23 +106,22 @@ def generate_dataset_with_dsn(
         environ,
         default_query_engine_class=LocalFileQueryEngine,
     )
-    return query_engine.get_results(dataset)
+    return query_engine.get_results_tables(dataset)
 
 
 def generate_dataset_with_dummy_data(
-    *, dataset, dummy_data_config, column_specs, dummy_data_file, dummy_tables_path
+    *, dataset, dummy_data_config, table_specs, dummy_data_file, dummy_tables_path
 ):
     if dummy_data_file:
         log.info(f"Reading dummy data from {dummy_data_file}")
-        reader = read_rows(dummy_data_file, column_specs)
-        return iter(reader)
+        return read_tables(dummy_data_file, table_specs)
     elif dummy_tables_path:
         log.info(f"Reading table data from {dummy_tables_path}")
         query_engine = LocalFileQueryEngine(dummy_tables_path)
-        return query_engine.get_results(dataset)
+        return query_engine.get_results_tables(dataset)
     else:
         generator = get_dummy_data_generator(dataset, dummy_data_config)
-        return generator.get_results()
+        return generator.get_results_tables()
 
 
 def create_dummy_tables(definition_file, dummy_tables_path, user_args, environ):
diff --git a/tests/unit/test___main__.py b/tests/unit/test___main__.py
index 0f6fe5bb7..5ade7a7ea 100644
--- a/tests/unit/test___main__.py
+++ b/tests/unit/test___main__.py
@@ -1,3 +1,5 @@
+from pathlib import Path
+
 import pytest
 
 from ehrql.__main__ import (
@@ -10,6 +12,7 @@
     import_string,
     main,
     query_engine_from_id,
+    valid_output_path,
 )
 from ehrql.backends.base import SQLBackend
 from ehrql.query_engines.base import BaseQueryEngine
@@ -323,3 +326,30 @@ def test_all_backends_have_an_alias():
 def test_all_backend_aliases_match_display_names():
     for alias in BACKEND_ALIASES.keys():
         assert backend_from_id(alias).display_name.lower() == alias
+
+
+@pytest.mark.parametrize(
+    "path",
+    [
+        "some/path/file.csv",
+        "some/path/dir:csv",
+        "some/path/dir/:csv",
+        "some/path/dir.foo:csv",
+    ],
+)
+def test_valid_output_path(path):
+    assert valid_output_path(path) == Path(path)
+
+
+@pytest.mark.parametrize(
+    "path, message",
+    [
+        ("no/extension", "No file format supplied"),
+        ("some/path.badfile", "'.badfile' is not a supported format"),
+        ("some/path:baddir", "':baddir' is not a supported format"),
+        ("some/path/:baddir", "':baddir' is not a supported format"),
+    ],
+)
+def test_valid_output_path_errors(path, message):
+    with pytest.raises(ArgumentTypeError, match=message):
+        valid_output_path(path)

From 51e16d9b984d0324e28dde7dce42f95349bd584f Mon Sep 17 00:00:00 2001
From: David Evans
Date: Wed, 22 Jan 2025 12:48:13 +0000
Subject: [PATCH 12/14] Allow `create-dummy-tables` to write to console

This isn't a massively useful feature in itself, but I think in general
it's better for commands to show you what they produce by default,
rather than requiring outputs to be written to disk.
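
For example (the file and directory names here are illustrative):

```
# Write one file per table into a directory, as the existing tests do:
ehrql create-dummy-tables dataset_definition.py dummy_tables

# Omit the path and the tables are printed to the console instead:
ehrql create-dummy-tables dataset_definition.py
```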
--- ehrql/__main__.py | 1 + ehrql/main.py | 6 ++++-- tests/functional/test_create_dummy_tables.py | 9 +++++++++ 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/ehrql/__main__.py b/ehrql/__main__.py index 87a56377a..e0812c506 100644 --- a/ehrql/__main__.py +++ b/ehrql/__main__.py @@ -293,6 +293,7 @@ def add_create_dummy_tables(subparsers, environ, user_args): add_dataset_definition_file_argument(parser, environ) parser.add_argument( "dummy_tables_path", + nargs="?", help=strip_indent( f""" Path to directory where files (one per table) will be written. diff --git a/ehrql/main.py b/ehrql/main.py index 490dd47bd..9a5cc9a84 100644 --- a/ehrql/main.py +++ b/ehrql/main.py @@ -132,8 +132,10 @@ def create_dummy_tables(definition_file, dummy_tables_path, user_args, environ): generator = get_dummy_data_generator(dataset, dummy_data_config) table_data = generator.get_data() - directory, extension = split_directory_and_extension(dummy_tables_path) - log.info(f"Writing tables as '{extension}' files to '{directory}'") + if dummy_tables_path is not None: + directory, extension = split_directory_and_extension(dummy_tables_path) + log.info(f"Writing tables as '{extension}' files to '{directory}'") + table_specs = { table.name: get_column_specs_from_schema(table.schema) for table in table_data.keys() diff --git a/tests/functional/test_create_dummy_tables.py b/tests/functional/test_create_dummy_tables.py index 5e0e5649d..2bda96a0f 100644 --- a/tests/functional/test_create_dummy_tables.py +++ b/tests/functional/test_create_dummy_tables.py @@ -51,3 +51,12 @@ def test_create_dummy_tables( lines = (dummy_tables_path / "patients.csv").read_text().splitlines() assert lines[0] == expected_columns assert len(lines) == 11 # 1 header, 10 rows + + +def test_create_dummy_tables_console_output(call_cli, tmp_path): + dataset_definition_path = tmp_path / "dataset_definition.py" + dataset_definition_path.write_text(trivial_dataset_definition) + captured = call_cli("create-dummy-tables", dataset_definition_path) + + assert "patient_id" in captured.out + assert "date_of_birth" in captured.out From c038516116a03582a06a42450860448d53bd2607 Mon Sep 17 00:00:00 2001 From: David Evans Date: Wed, 22 Jan 2025 14:47:19 +0000 Subject: [PATCH 13/14] Handle optional positional arguments in CLI docs --- ehrql/docs/cli.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/ehrql/docs/cli.py b/ehrql/docs/cli.py index 52e4e8bdd..c0e985fd2 100644 --- a/ehrql/docs/cli.py +++ b/ehrql/docs/cli.py @@ -56,6 +56,13 @@ def get_argument(action): "usage_long": ", ".join(action.option_strings), "description": action.help, } + elif isinstance(action, argparse._StoreAction) and action.nargs == "?": + return { + "id": action.dest, + "usage_short": f"[{action.dest.upper()}]", + "usage_long": action.dest.upper(), + "description": action.help, + } elif isinstance(action, argparse._StoreAction) and not action.required: return { "id": action.option_strings[-1].lstrip("-"), From d8cb142cca310bdd8b4f56d90fbd9d9f18bb9d2c Mon Sep 17 00:00:00 2001 From: David Evans Date: Wed, 22 Jan 2025 16:14:06 +0000 Subject: [PATCH 14/14] Run `just generate-docs` --- docs/includes/generated_docs/cli.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/includes/generated_docs/cli.md b/docs/includes/generated_docs/cli.md index bc3599934..9856affcc 100644 --- a/docs/includes/generated_docs/cli.md +++ b/docs/includes/generated_docs/cli.md @@ -467,7 +467,7 @@ double-dash ` -- `. 
create-dummy-tables ``` -ehrql create-dummy-tables DEFINITION_FILE DUMMY_TABLES_PATH [--help] +ehrql create-dummy-tables DEFINITION_FILE [DUMMY_TABLES_PATH] [--help] [ -- ... PARAMETERS ...] ``` Generate dummy tables and write them out as files – one per table, CSV by
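
The bracketed `[DUMMY_TABLES_PATH]` usage above mirrors how argparse itself renders an optional positional. A minimal, self-contained sketch (a toy parser, not the real ehrql one):

```python
import argparse

# Toy parser mirroring the relevant arguments of `create-dummy-tables`
parser = argparse.ArgumentParser(prog="ehrql create-dummy-tables")
parser.add_argument("definition_file")
action = parser.add_argument("dummy_tables_path", nargs="?")

# This is the case the new `get_argument` branch matches on:
assert isinstance(action, argparse._StoreAction)
assert action.nargs == "?"

# argparse brackets optional positionals in its own usage string; the
# generated docs do the same, but upper-case the destination name.
parser.print_usage()
# usage: ehrql create-dummy-tables [-h] definition_file [dummy_tables_path]
```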