diff --git a/docs/includes/generated_docs/cli.md b/docs/includes/generated_docs/cli.md index bc3599934..9856affcc 100644 --- a/docs/includes/generated_docs/cli.md +++ b/docs/includes/generated_docs/cli.md @@ -467,7 +467,7 @@ double-dash ` -- `. create-dummy-tables ``` -ehrql create-dummy-tables DEFINITION_FILE DUMMY_TABLES_PATH [--help] +ehrql create-dummy-tables DEFINITION_FILE [DUMMY_TABLES_PATH] [--help] [ -- ... PARAMETERS ...] ``` Generate dummy tables and write them out as files – one per table, CSV by diff --git a/ehrql/__main__.py b/ehrql/__main__.py index b55701469..e0812c506 100644 --- a/ehrql/__main__.py +++ b/ehrql/__main__.py @@ -293,6 +293,7 @@ def add_create_dummy_tables(subparsers, environ, user_args): add_dataset_definition_file_argument(parser, environ) parser.add_argument( "dummy_tables_path", + nargs="?", help=strip_indent( f""" Path to directory where files (one per table) will be written. @@ -662,13 +663,31 @@ def existing_python_file(value): def valid_output_path(value): + # This can be either a single file or a directory, but either way it needs to + # specify a valid output format path = Path(value) - extension = get_file_extension(path) - if extension not in FILE_FORMATS: + directory_ext = split_directory_and_extension(path)[1] + file_ext = get_file_extension(path) + if not directory_ext and not file_ext: raise ArgumentTypeError( - f"'{extension}' is not a supported format, must be one of: " - f"{backtick_join(FILE_FORMATS)}" + f"No file format supplied\n" + f"To write a single file use a file extension: {backtick_join(FILE_FORMATS)}" + f"To write multiple files use a directory extension: " + f"{backtick_join(format_directory_extension(e) for e in FILE_FORMATS)}\n" ) + elif directory_ext: + if directory_ext not in FILE_FORMATS: + raise ArgumentTypeError( + f"'{format_directory_extension(directory_ext)}' is not a supported format, " + f"must be one of: " + f"{backtick_join(format_directory_extension(e) for e in FILE_FORMATS)}" + ) + else: + if file_ext not in FILE_FORMATS: + raise ArgumentTypeError( + f"'{file_ext}' is not a supported format, must be one of: " + f"{backtick_join(FILE_FORMATS)}" + ) return path @@ -701,7 +720,7 @@ def query_engine_from_id(str_id): f"(or a full dotted path to a query engine class)" ) query_engine = import_string(str_id) - assert_duck_type(query_engine, "query engine", "get_results") + assert_duck_type(query_engine, "query engine", "get_results_tables") return query_engine diff --git a/ehrql/docs/cli.py b/ehrql/docs/cli.py index 52e4e8bdd..c0e985fd2 100644 --- a/ehrql/docs/cli.py +++ b/ehrql/docs/cli.py @@ -56,6 +56,13 @@ def get_argument(action): "usage_long": ", ".join(action.option_strings), "description": action.help, } + elif isinstance(action, argparse._StoreAction) and action.nargs == "?": + return { + "id": action.dest, + "usage_short": f"[{action.dest.upper()}]", + "usage_long": action.dest.upper(), + "description": action.help, + } elif isinstance(action, argparse._StoreAction) and not action.required: return { "id": action.option_strings[-1].lstrip("-"), diff --git a/ehrql/dummy_data/generator.py b/ehrql/dummy_data/generator.py index d87f82541..03beb711a 100644 --- a/ehrql/dummy_data/generator.py +++ b/ehrql/dummy_data/generator.py @@ -131,10 +131,16 @@ def get_patient_id_stream(self): if i not in inline_patient_ids: yield i - def get_results(self): + def get_results_tables(self): database = InMemoryDatabase(self.get_data()) engine = InMemoryQueryEngine(database) - return engine.get_results(self.dataset) + return 
engine.get_results_tables(self.dataset) + + def get_results(self): + tables = self.get_results_tables() + yield from next(tables) + for remaining in tables: + assert False, "Expected only one results table" class DummyPatientGenerator: diff --git a/ehrql/dummy_data_nextgen/generator.py b/ehrql/dummy_data_nextgen/generator.py index fe96e1fbd..043185e1c 100644 --- a/ehrql/dummy_data_nextgen/generator.py +++ b/ehrql/dummy_data_nextgen/generator.py @@ -245,10 +245,16 @@ def get_patient_id_stream(self): if i not in inline_patient_ids: yield i - def get_results(self): + def get_results_tables(self): database = InMemoryDatabase(self.get_data()) engine = InMemoryQueryEngine(database) - return engine.get_results(self.dataset) + return engine.get_results_tables(self.dataset) + + def get_results(self): + tables = self.get_results_tables() + yield from next(tables) + for remaining in tables: + assert False, "Expected only one results table" class DummyPatientGenerator: diff --git a/ehrql/file_formats/console.py b/ehrql/file_formats/console.py new file mode 100644 index 000000000..6ab4f7513 --- /dev/null +++ b/ehrql/file_formats/console.py @@ -0,0 +1,28 @@ +""" +Handles writing rows/tables to the console for local development and debugging. + +At present, this just uses the CSV writer but there's scope for using something a bit +prettier and more readable here in future. +""" + +import sys + +from ehrql.file_formats.csv import write_rows_csv_lines + + +def write_rows_console(rows, column_specs): + write_rows_csv_lines(sys.stdout, rows, column_specs) + + +def write_tables_console(tables, table_specs): + write_table_names = len(table_specs) > 1 + first_table = True + for rows, (table_name, column_specs) in zip(tables, table_specs.items()): + if first_table: + first_table = False + else: + # Add whitespace between tables + sys.stdout.write("\n\n") + if write_table_names: + sys.stdout.write(f"{table_name}\n") + write_rows_console(rows, column_specs) diff --git a/ehrql/file_formats/csv.py b/ehrql/file_formats/csv.py index f081e0215..1c691b061 100644 --- a/ehrql/file_formats/csv.py +++ b/ehrql/file_formats/csv.py @@ -1,8 +1,6 @@ import csv import datetime import gzip -import sys -from contextlib import nullcontext from ehrql.file_formats.base import ( BaseRowsReader, @@ -12,13 +10,8 @@ def write_rows_csv(filename, rows, column_specs): - if filename is None: - context = nullcontext(sys.stdout) - else: - # Set `newline` as per Python docs: - # https://docs.python.org/3/library/csv.html#id3 - context = filename.open(mode="w", newline="") - with context as f: + # Set `newline` as per Python docs: https://docs.python.org/3/library/csv.html#id3 + with filename.open(mode="w", newline="") as f: write_rows_csv_lines(f, rows, column_specs) diff --git a/ehrql/file_formats/main.py b/ehrql/file_formats/main.py index 570696c24..acd68dabc 100644 --- a/ehrql/file_formats/main.py +++ b/ehrql/file_formats/main.py @@ -6,6 +6,7 @@ write_rows_arrow, ) from ehrql.file_formats.base import FileValidationError +from ehrql.file_formats.console import write_rows_console, write_tables_console from ehrql.file_formats.csv import ( CSVGZRowsReader, CSVRowsReader, @@ -23,6 +24,9 @@ def write_rows(filename, rows, column_specs): + if filename is None: + return write_rows_console(rows, column_specs) + extension = get_file_extension(filename) writer = FILE_FORMATS[extension][0] # `rows` is often a generator which won't actually execute until we start consuming @@ -31,9 +35,7 @@ def write_rows(filename, rows, column_specs): # whole thing into 
memory. So we wrap it in a function which draws the first item # upfront, but doesn't consume the rest of the iterator. rows = eager_iterator(rows) - # We use None for stdout - if filename is not None: - filename.parent.mkdir(parents=True, exist_ok=True) + filename.parent.mkdir(parents=True, exist_ok=True) writer(filename, rows, column_specs) @@ -48,6 +50,33 @@ def read_rows(filename, column_specs, allow_missing_columns=False): def read_tables(filename, table_specs, allow_missing_columns=False): + if not filename.exists(): + raise FileValidationError(f"Missing file or directory: {filename}") + + # If we've got a single-table input file and only a single table to read then that's + # fine, but it needs slightly special handling + if not input_filename_supports_multiple_tables(filename): + if len(table_specs) == 1: + column_specs = list(table_specs.values())[0] + rows = read_rows( + filename, + column_specs, + allow_missing_columns=allow_missing_columns, + ) + yield from [rows] + return + else: + files = list(table_specs.keys()) + suffix = filename.suffix + raise FileValidationError( + f"Attempting to read {len(table_specs)} tables, but input only " + f"provides a single table\n" + f" Try moving -> {filename}\n" + f" to -> {filename.parent / filename.stem}/{files[0]}{suffix}\n" + f" adding -> {', '.join(f + suffix for f in files[1:])}\n" + f" and using path -> {filename.parent / filename.stem}/" + ) + extension = get_extension_from_directory(filename) # Using ExitStack here allows us to open and validate all files before emiting any # rows while still correctly closing all open files if we raise an error part way @@ -66,6 +95,25 @@ def read_tables(filename, table_specs, allow_missing_columns=False): def write_tables(filename, tables, table_specs): + if filename is None: + return write_tables_console(tables, table_specs) + + # If we've got a single-table output file and only a single table to write then + # that's fine, but it needs slightly special handling + if not output_filename_supports_multiple_tables(filename): + if len(table_specs) == 1: + column_specs = list(table_specs.values())[0] + rows = next(iter(tables)) + return write_rows(filename, rows, column_specs) + else: + raise FileValidationError( + f"Attempting to write {len(table_specs)} tables, but output only " + f"supports a single table\n" + f" Instead of -> {filename}\n" + f" try -> " + f"{filename.parent / filename.stem}/:{filename.suffix.lstrip('.')}" + ) + filename, extension = split_directory_and_extension(filename) for rows, (table_name, column_specs) in zip(tables, table_specs.items()): table_filename = get_table_filename(filename, table_name, extension) @@ -73,10 +121,7 @@ def write_tables(filename, tables, table_specs): def get_file_extension(filename): - if filename is None: - # If we have no filename we're writing to stdout, so default to CSV - return ".csv" - elif filename.suffix == ".gz": + if filename.suffix == ".gz": return "".join(filename.suffixes[-2:]) else: return filename.suffix @@ -121,6 +166,20 @@ def split_directory_and_extension(filename): return filename.with_name(name), f".{extension}" +def input_filename_supports_multiple_tables(filename): + # At present, supplying a directory is the only way to provide multiple input + # tables, but it's not inconceivable that in future we might support single-file + # multiple-table formats e.g SQLite or DuckDB files. If we do then updating this + # function and its sibling below should be all that's required. 
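The single-table / multi-table handling above means callers can target either one combined file or a directory of per-table files, using the `<directory>/:<format>` convention shown in the error messages. A rough usage sketch (assuming the patched `write_tables` and `ColumnSpec` APIs from this diff; the table names, columns and rows are invented):

```python
from pathlib import Path

from ehrql.file_formats import write_tables
from ehrql.query_model.column_specs import ColumnSpec

# Invented example tables: two results tables with hand-written rows
table_specs = {
    "patients": {"patient_id": ColumnSpec(int), "sex": ColumnSpec(str)},
    "events": {"patient_id": ColumnSpec(int), "code": ColumnSpec(str)},
}
tables = [
    [(1, "female"), (2, "male")],    # rows for "patients"
    [(1, "abc123"), (2, "def456")],  # rows for "events"
]

# Directory-style path ("<dir>/:<format>") writes one file per table:
# outputs/patients.csv and outputs/events.csv
write_tables(Path("outputs/:csv"), tables, table_specs)

# A plain single-file path is only accepted when there is exactly one table;
# with several tables write_tables raises FileValidationError and suggests the
# directory form above.
write_tables(Path("patients.csv"), [tables[0]], {"patients": table_specs["patients"]})
```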
+ return filename.is_dir() + + +def output_filename_supports_multiple_tables(filename): + # Again, at present only directories support multiple output tables but see above + extension = split_directory_and_extension(filename)[1] + return extension != "" + + def get_table_filename(base_filename, table_name, extension): # Use URL quoting as an easy way of escaping any potentially problematic characters # in filenames diff --git a/ehrql/main.py b/ehrql/main.py index 122ef3d83..9a5cc9a84 100644 --- a/ehrql/main.py +++ b/ehrql/main.py @@ -14,6 +14,7 @@ ) from ehrql.file_formats import ( read_rows, + read_tables, split_directory_and_extension, write_rows, write_tables, @@ -35,8 +36,8 @@ from ehrql.query_engines.local_file import LocalFileQueryEngine from ehrql.query_engines.sqlite import SQLiteQueryEngine from ehrql.query_model.column_specs import ( - get_column_specs, get_column_specs_from_schema, + get_table_specs, ) from ehrql.query_model.graphs import graph_to_svg from ehrql.serializer import serialize @@ -71,11 +72,11 @@ def generate_dataset( log.info(f"Testing dataset definition with tests in {str(definition_file)}") assure(test_data_file, environ=environ, user_args=user_args) - column_specs = get_column_specs(dataset) + table_specs = get_table_specs(dataset) if dsn: log.info("Generating dataset") - results = generate_dataset_with_dsn( + results_tables = generate_dataset_with_dsn( dataset=dataset, dsn=dsn, backend_class=backend_class, @@ -84,15 +85,15 @@ def generate_dataset( ) else: log.info("Generating dummy dataset") - results = generate_dataset_with_dummy_data( + results_tables = generate_dataset_with_dummy_data( dataset=dataset, dummy_data_config=dummy_data_config, - column_specs=column_specs, + table_specs=table_specs, dummy_data_file=dummy_data_file, dummy_tables_path=dummy_tables_path, ) - write_rows(output_file, results, column_specs) + write_tables(output_file, results_tables, table_specs) def generate_dataset_with_dsn( @@ -105,23 +106,22 @@ def generate_dataset_with_dsn( environ, default_query_engine_class=LocalFileQueryEngine, ) - return query_engine.get_results(dataset) + return query_engine.get_results_tables(dataset) def generate_dataset_with_dummy_data( - *, dataset, dummy_data_config, column_specs, dummy_data_file, dummy_tables_path + *, dataset, dummy_data_config, table_specs, dummy_data_file, dummy_tables_path ): if dummy_data_file: log.info(f"Reading dummy data from {dummy_data_file}") - reader = read_rows(dummy_data_file, column_specs) - return iter(reader) + return read_tables(dummy_data_file, table_specs) elif dummy_tables_path: log.info(f"Reading table data from {dummy_tables_path}") query_engine = LocalFileQueryEngine(dummy_tables_path) - return query_engine.get_results(dataset) + return query_engine.get_results_tables(dataset) else: generator = get_dummy_data_generator(dataset, dummy_data_config) - return generator.get_results() + return generator.get_results_tables() def create_dummy_tables(definition_file, dummy_tables_path, user_args, environ): @@ -132,8 +132,10 @@ def create_dummy_tables(definition_file, dummy_tables_path, user_args, environ): generator = get_dummy_data_generator(dataset, dummy_data_config) table_data = generator.get_data() - directory, extension = split_directory_and_extension(dummy_tables_path) - log.info(f"Writing tables as '{extension}' files to '{directory}'") + if dummy_tables_path is not None: + directory, extension = split_directory_and_extension(dummy_tables_path) + log.info(f"Writing tables as '{extension}' files to '{directory}'") + 
table_specs = { table.name: get_column_specs_from_schema(table.schema) for table in table_data.keys() @@ -175,8 +177,8 @@ def dump_dataset_sql( def get_sql_strings(query_engine, dataset): - results_query = query_engine.get_query(dataset) - setup_queries, cleanup_queries = get_setup_and_cleanup_queries(results_query) + results_queries = query_engine.get_queries(dataset) + setup_queries, cleanup_queries = get_setup_and_cleanup_queries(results_queries) dialect = query_engine.sqlalchemy_dialect() sql_strings = [] @@ -184,8 +186,11 @@ def get_sql_strings(query_engine, dataset): sql = clause_as_str(query, dialect) sql_strings.append(f"-- Setup query {i:03} / {len(setup_queries):03}\n{sql}") - sql = clause_as_str(results_query, dialect) - sql_strings.append(f"-- Results query\n{sql}") + for i, query in enumerate(results_queries, start=1): + sql = clause_as_str(query, dialect) + sql_strings.append( + f"-- Results query {i:03} / {len(results_queries):03}\n{sql}" + ) for i, query in enumerate(cleanup_queries, start=1): sql = clause_as_str(query, dialect) diff --git a/ehrql/query_engines/base.py b/ehrql/query_engines/base.py index beeb8cba4..895779dd8 100644 --- a/ehrql/query_engines/base.py +++ b/ehrql/query_engines/base.py @@ -2,6 +2,10 @@ from typing import Any from ehrql.query_model import nodes as qm +from ehrql.utils.itertools_utils import iter_groups + + +class Marker: ... class BaseQueryEngine: @@ -12,6 +16,9 @@ class BaseQueryEngine: flavour of tables and query language (SQL, pandas dataframes etc). """ + # Sentinel value used to mark the start of a new results table in a stream of results + RESULTS_START = Marker() + def __init__(self, dsn: str, backend: Any = None, config: dict | None = None): """ `dsn` is Data Source Name — a string (usually a URL) which provides connection @@ -25,12 +32,42 @@ def __init__(self, dsn: str, backend: Any = None, config: dict | None = None): self.backend = backend self.config = config or {} - def get_results(self, dataset: qm.Dataset) -> Iterator[Sequence]: + def get_results_tables(self, dataset: qm.Dataset) -> Iterator[Iterator[Sequence]]: + """ + Given a query model `Dataset` return an iterator of "results tables", where each + table is an iterator of rows (usually tuples, but any sequence type will do) + + This is the primary interface to query engines and the one required method. + + Typically however, query engine subclasses will implement `get_results_stream` + instead which yields a flat sequence of rows, with tables separated by the + `RESULTS_START` marker value. This is converted into the appropriate structure + by `iter_groups` which also enforces that the caller interacts with it safely. """ - Given a query model `Dataset` return the results as an iterator of "rows" (which - are usually tuples, but any sequence type will do) + return iter_groups(self.get_results_stream(dataset), self.RESULTS_START) + + def get_results_stream(self, dataset: qm.Dataset) -> Iterator[Sequence | Marker]: + """ + Given a query model `Dataset` return an iterator of rows over all the results + tables, with each table's results separated by the `RESULTS_START` marker value Override this method to do the things necessary to generate query code and execute it against a particular backend. + + Emitting results in a flat sequence like this with separators between the tables + ends up making the query code _much_ easier to reason about because everything + happens in a clear linear sequence rather than inside nested generators. 
This
+        makes things like transaction management and error handling much more
+        straightforward.
         """
         raise NotImplementedError()
+
+    def get_results(self, dataset: qm.Dataset) -> Iterator[Sequence]:
+        """
+        Temporary method to continue to support code which assumes only a single results
+        table
+        """
+        tables = self.get_results_tables(dataset)
+        yield from next(tables)
+        for remaining in tables:
+            assert False, "Expected only one results table"
diff --git a/ehrql/query_engines/base_sql.py b/ehrql/query_engines/base_sql.py
index 8d567ec2c..ed2939fa3 100644
--- a/ehrql/query_engines/base_sql.py
+++ b/ehrql/query_engines/base_sql.py
@@ -84,9 +84,9 @@ def get_next_id(self):
         self.counter += 1
         return self.counter
 
-    def get_query(self, dataset):
+    def get_queries(self, dataset):
         """
-        Return the SQL query to fetch the results for `dataset`
+        Return the SQL queries to fetch the results for `dataset`
 
         Note that this query might make use of intermediate tables. The SQL queries
         needed to create these tables and clean them up can be retrieved by calling
@@ -127,7 +127,9 @@ def get_query(self, dataset):
         self.get_sql.cache_clear()
         self.get_table.cache_clear()
 
-        return query
+        # At the moment we only support a single results table and so we'll only ever
+        # have a single query
+        return [query]
 
     def select_patient_id_for_population(self, population_expression):
         """
@@ -835,26 +837,29 @@ def get_select_query_for_node_domain(self, node):
             query = query.where(sqlalchemy.and_(*where_clauses))
         return query
 
-    def get_results(self, dataset):
-        results_query = self.get_query(dataset)
-        setup_queries, cleanup_queries = get_setup_and_cleanup_queries(results_query)
+    def get_results_stream(self, dataset):
+        results_queries = self.get_queries(dataset)
+        setup_queries, cleanup_queries = get_setup_and_cleanup_queries(results_queries)
+
         with self.engine.connect() as connection:
             for i, setup_query in enumerate(setup_queries, start=1):
                 log.info(f"Running setup query {i:03} / {len(setup_queries):03}")
                 connection.execute(setup_query)
 
-            log.info("Fetching results")
-            cursor_result = connection.execute(results_query)
-            try:
-                yield from cursor_result
-            except Exception:  # pragma: no cover
-                # If we hit an error part way through fetching results then we should
-                # close the cursor to make it clear we're not going to be fetching any
-                # more (only really relevant for the in-memory SQLite tests, but good
-                # hygiene in any case)
-                cursor_result.close()
-                # Make sure the cleanup happens before raising the error
-                raise
+            for i, results_query in enumerate(results_queries, start=1):
+                log.info(f"Fetching results {i:03} / {len(results_queries):03}")
+                cursor_result = connection.execute(results_query)
+                yield self.RESULTS_START
+                try:
+                    yield from cursor_result
+                except Exception:  # pragma: no cover
+                    # If we hit an error part way through fetching results then we should
+                    # close the cursor to make it clear we're not going to be fetching any
+                    # more (only really relevant for the in-memory SQLite tests, but good
+                    # hygiene in any case)
+                    cursor_result.close()
+                    # Make sure the cleanup happens before raising the error
+                    raise
 
             for i, cleanup_query in enumerate(cleanup_queries, start=1):
                 log.info(f"Running cleanup query {i:03} / {len(cleanup_queries):03}")
diff --git a/ehrql/query_engines/in_memory.py b/ehrql/query_engines/in_memory.py
index 649b2ec31..f857379da 100644
--- a/ehrql/query_engines/in_memory.py
+++ b/ehrql/query_engines/in_memory.py
@@ -28,9 +28,10 @@ class InMemoryQueryEngine(BaseQueryEngine):
     tests, and to provide a 
reference implementation for other engines. """ - def get_results(self, dataset): + def get_results_stream(self, dataset): table = self.get_results_as_patient_table(dataset) Row = namedtuple("Row", table.name_to_col.keys()) + yield self.RESULTS_START for record in table.to_records(): yield Row(**record) diff --git a/ehrql/query_engines/local_file.py b/ehrql/query_engines/local_file.py index c2019f979..c65ebc6fc 100644 --- a/ehrql/query_engines/local_file.py +++ b/ehrql/query_engines/local_file.py @@ -14,14 +14,14 @@ class LocalFileQueryEngine(InMemoryQueryEngine): database = None - def get_results(self, dataset): + def get_results_stream(self, dataset): # Given the dataset supplied determine the tables used and load the associated # data into the database self.populate_database( get_table_nodes(dataset), ) # Run the query as normal - return super().get_results(dataset) + return super().get_results_stream(dataset) def populate_database(self, table_nodes, allow_missing_columns=True): table_specs = { diff --git a/ehrql/query_engines/mssql.py b/ehrql/query_engines/mssql.py index 1b9c803b7..4adb37d6c 100644 --- a/ehrql/query_engines/mssql.py +++ b/ehrql/query_engines/mssql.py @@ -152,25 +152,31 @@ def create_inline_table(self, columns, rows): ] return table - def get_query(self, dataset): - results_query = super().get_query(dataset) - # Write results to a temporary table and select them from there. This allows us + def get_queries(self, dataset): + results_queries = super().get_queries(dataset) + # Write results to temporary tables and select them from there. This allows us # to use more efficient/robust mechanisms to retrieve the results. - results_table = temporary_table_from_query( - "#results", results_query, index_col="patient_id" - ) - return sqlalchemy.select(results_table) + select_queries = [] + for n, results_query in enumerate(results_queries, start=1): + results_table = temporary_table_from_query( + f"#results_{n}", results_query, index_col="patient_id" + ) + select_queries.append(sqlalchemy.select(results_table)) + return select_queries - def get_results(self, dataset): - results_query = self.get_query(dataset) + def get_results_stream(self, dataset): + results_queries = self.get_queries(dataset) - # We're expecting a query in a very specific form which is "select everything - # from one table"; so we assert that it has this form and retrieve a reference - # to the table - results_table = results_query.get_final_froms()[0] - assert str(results_query) == str(sqlalchemy.select(results_table)) + # We're expecting queries in a very specific form which is "select everything + # from one table"; so we assert that they have this form and retrieve references + # to the tables + results_tables = [] + for results_query in results_queries: + results_table = results_query.get_final_froms()[0] + assert str(results_query) == str(sqlalchemy.select(results_table)) + results_tables.append(results_table) - setup_queries, cleanup_queries = get_setup_and_cleanup_queries(results_query) + setup_queries, cleanup_queries = get_setup_and_cleanup_queries(results_queries) with self.engine.connect() as connection: # All our queries are either (a) read-only queries against static data, or @@ -193,16 +199,26 @@ def get_results(self, dataset): log=log.info, ) - yield from fetch_table_in_batches( - execute_with_retry, - results_table, - key_column=results_table.c.patient_id, - key_is_unique=True, - # This value was copied from the previous cohortextractor. I suspect it - # has no real scientific basis. 
- batch_size=32000, - log=log.info, - ) + for i, results_table in enumerate(results_tables): + yield self.RESULTS_START + yield from fetch_table_in_batches( + execute_with_retry, + results_table, + key_column=results_table.c.patient_id, + # TODO: We need to find a better way to identify which tables have a + # unique `patient_id` column because it lets the batch fetcher use a + # more efficient algorithm. At present, we know that the first + # results table does but this isn't a very comfortable approach. The + # other option is to just always use the non-unique algorithm on the + # basis that the lost efficiency probably isn't noticeable. But + # until we're supporting event-level data for real I'm reluctant to + # make things worse for the currently supported case. + key_is_unique=(i == 0), + # This value was copied from the previous cohortextractor. I suspect + # it has no real scientific basis. + batch_size=32000, + log=log.info, + ) for i, cleanup_query in enumerate(cleanup_queries, start=1): query_id = f"cleanup query {i:03} / {len(cleanup_queries):03}" diff --git a/ehrql/query_model/column_specs.py b/ehrql/query_model/column_specs.py index 52ead2529..648e7fe3e 100644 --- a/ehrql/query_model/column_specs.py +++ b/ehrql/query_model/column_specs.py @@ -6,7 +6,6 @@ AggregateByPatient, Case, Constraint, - Dataset, SelectColumn, SelectPatientTable, Value, @@ -27,16 +26,25 @@ class ColumnSpec: max_value: T | None = None -def get_column_specs(dataset): +def get_table_specs(dataset): """ - Given a dataset return a dict of ColumnSpec objects, given the types (and other - associated metadata) of all the columns in the output + Return the specifications for all the results tables this Dataset will produce + """ + # At present, Datasets only ever produce a single results table (which we call + # `dataset`) but this gives us the API we need for future expansion + return {"dataset": get_column_specs_from_variables(dataset.variables)} + + +def get_column_specs_from_variables(variables): + """ + Given a dict of dataset variables return a dict of ColumnSpec objects, given + the types (and other associated metadata) of all the columns in the output """ # TODO: It may not be universally true that IDs are ints. Internally the EMIS IDs # are SHA512 hashes stored as hex strings which we convert to ints. But we can't # guarantee always to be able to pull this trick. column_specs = {"patient_id": ColumnSpec(int, nullable=False)} - for name, series in dataset.variables.items(): + for name, series in variables.items(): column_specs[name] = get_column_spec_from_series(series) return column_specs @@ -46,14 +54,11 @@ def get_column_specs_from_schema(schema): # reusing all the logic above: we create a table node and then create some variables # by selecting each column in the schema from it. 
table = SelectPatientTable(name="table", schema=schema) - dataset = Dataset( - population=Value(False), - variables={ - column_name: SelectColumn(source=table, name=column_name) - for column_name in schema.column_names - }, - ) - return get_column_specs(dataset) + variables = { + column_name: SelectColumn(source=table, name=column_name) + for column_name in schema.column_names + } + return get_column_specs_from_variables(variables) def get_column_spec_from_series(series): diff --git a/ehrql/utils/itertools_utils.py b/ehrql/utils/itertools_utils.py index 4b5faffa2..56b62f19d 100644 --- a/ehrql/utils/itertools_utils.py +++ b/ehrql/utils/itertools_utils.py @@ -27,3 +27,77 @@ def iter_flatten(iterable, iter_classes=(list, tuple, GeneratorType)): yield from iter_flatten(item, iter_classes) else: yield item + + +def iter_groups(iterable, separator): + """ + Split a flat iterator of items into a nested iterator of groups of items. Groups are + delineated by the presence of a sentinel `separator` value which marks the start of + each group. + + For example, the iterator: + + - SEPARATOR + - 1 + - 2 + - SEPARATOR + - 3 + - 4 + + Will be transformed into the nested iterator: + + - + - 1 + - 2 + - + - 3 + - 4 + + This is useful for situations where a nested iterator is the natural API for + representing the data but the flat iterator is much easier to generate correctly. + """ + iterator = iter(iterable) + try: + first_item = next(iterator) + except StopIteration: + return + assert first_item is separator, ( + f"Invalid iterator: does not start with `separator` value {separator!r}" + ) + while True: + group_iter = GroupIterator(iterator, separator) + yield group_iter + # Prevent the caller from trying to consume the next group before they've + # finished consuming the current one (as would happen if you naively called + # `list()` on the result of `iter_groups()`) + assert group_iter._group_complete, ( + "Cannot consume next group until current group has been exhausted" + ) + if group_iter._exhausted: + break + + +class GroupIterator: + def __init__(self, iterator, separator): + self._iterator = iterator + self._separator = separator + self._group_complete = False + self._exhausted = False + + def __iter__(self): + return self + + def __next__(self): + if self._group_complete: + raise StopIteration() + try: + value = next(self._iterator) + except StopIteration: + self._group_complete = True + self._exhausted = True + raise + if value is self._separator: + self._group_complete = True + raise StopIteration() + else: + return value diff --git a/ehrql/utils/sqlalchemy_query_utils.py b/ehrql/utils/sqlalchemy_query_utils.py index 29ef72704..4ab99a294 100644 --- a/ehrql/utils/sqlalchemy_query_utils.py +++ b/ehrql/utils/sqlalchemy_query_utils.py @@ -72,9 +72,10 @@ def from_query(cls, name, query, metadata=None, **kwargs): return cls(name, metadata, *columns, **kwargs) -def get_setup_and_cleanup_queries(query): +def get_setup_and_cleanup_queries(queries): """ - Given a SQLAlchemy query find all GeneratedTables embeded in it and return a pair: + Given a list of SQLAlchemy queries find all GeneratedTables embedded in them and + return a pair: setup_queries, cleanup_queries @@ -90,7 +91,7 @@ def get_setup_and_cleanup_queries(query): # give it a sequence of pairs of tables (A, B) indicating that A depends on B and it # returns a suitable ordering over the tables. 
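For readers unfamiliar with `graphlib`, the ordering described in the comment above amounts to feeding (A, B) pairs into `graphlib.TopologicalSorter` and reading back `static_order()`. A standalone sketch with invented table names, not ehrql's actual bookkeeping:

```python
import graphlib

# Hypothetical (A, B) pairs meaning "table A depends on table B"; the names are
# invented and stand in for the GeneratedTable objects used above
dependencies = [
    ("results", "temp_table2"),
    ("temp_table2", "temp_table1"),
]

sorter = graphlib.TopologicalSorter()
for table, depends_on in dependencies:
    # `depends_on` must appear before `table` in the final ordering
    sorter.add(table, depends_on)

order = list(sorter.static_order())
print(order)  # ['temp_table1', 'temp_table2', 'results']

# Setup queries run in dependency order; cleanup queries run in reverse
setup_order = order
cleanup_order = list(reversed(order))
```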
sorter = graphlib.TopologicalSorter() - for parent_table, table in get_generated_table_dependencies(query): + for parent_table, table in get_generated_table_dependencies(queries): # A parent_table of None indicates a root table (i.e. one with no dependants) so # we record its existence without specifying any dependencies if parent_table is None: @@ -118,10 +119,10 @@ def get_setup_and_cleanup_queries(query): return setup_queries, cleanup_queries -def get_generated_table_dependencies(query, parent_table=None, seen_tables=None): +def get_generated_table_dependencies(queries, parent_table=None, seen_tables=None): """ - Recursively find all GeneratedTable objects referenced by `query` and yield as pairs - of dependencies: + Recursively find all GeneratedTable objects referenced by any query in `queries` and + yield as pairs of dependencies: table_X, table_Y_which_is_referenced_by_X @@ -130,14 +131,15 @@ def get_generated_table_dependencies(query, parent_table=None, seen_tables=None) if seen_tables is None: seen_tables = set() - for table in get_generated_tables(query): - yield parent_table, table - # Don't recurse into the same table twice - if table not in seen_tables: - seen_tables.add(table) - for child_query in [*table.setup_queries, *table.cleanup_queries]: + for query in queries: + for table in get_generated_tables(query): + yield parent_table, table + # Don't recurse into the same table twice + if table not in seen_tables: + seen_tables.add(table) + child_queries = [*table.setup_queries, *table.cleanup_queries] yield from get_generated_table_dependencies( - child_query, parent_table=table, seen_tables=seen_tables + child_queries, parent_table=table, seen_tables=seen_tables ) diff --git a/tests/functional/test_create_dummy_tables.py b/tests/functional/test_create_dummy_tables.py index 5e0e5649d..2bda96a0f 100644 --- a/tests/functional/test_create_dummy_tables.py +++ b/tests/functional/test_create_dummy_tables.py @@ -51,3 +51,12 @@ def test_create_dummy_tables( lines = (dummy_tables_path / "patients.csv").read_text().splitlines() assert lines[0] == expected_columns assert len(lines) == 11 # 1 header, 10 rows + + +def test_create_dummy_tables_console_output(call_cli, tmp_path): + dataset_definition_path = tmp_path / "dataset_definition.py" + dataset_definition_path.write_text(trivial_dataset_definition) + captured = call_cli("create-dummy-tables", dataset_definition_path) + + assert "patient_id" in captured.out + assert "date_of_birth" in captured.out diff --git a/tests/integration/backends/test_emis.py b/tests/integration/backends/test_emis.py index f55d2b1fe..ba3944121 100644 --- a/tests/integration/backends/test_emis.py +++ b/tests/integration/backends/test_emis.py @@ -583,15 +583,16 @@ class t(PatientFrame): variables = dataset._compile() - results_query = query_engine.get_query(variables) + results_queries = query_engine.get_queries(variables) + assert len(results_queries) == 1 inline_tables = [ ch - for ch in results_query.get_children() + for ch in results_queries[0].get_children() if isinstance(ch, GeneratedTable) and "inline_data" in ch.name ] assert len(set(inline_tables)) == 1 inline_table = inline_tables[0] - setup_queries, cleanup_queries = get_setup_and_cleanup_queries(results_query) + setup_queries, cleanup_queries = get_setup_and_cleanup_queries(results_queries) with query_engine.engine.connect() as connection: for setup_query in setup_queries: @@ -638,15 +639,16 @@ def test_temp_table_includes_organisation_hash(trino_database): ) variables = dataset._compile() - 
results_query = query_engine.get_query(variables) + results_queries = query_engine.get_queries(variables) + assert len(results_queries) == 1 temp_tables = [ ch - for ch in results_query.get_children() + for ch in results_queries[0].get_children() if isinstance(ch, GeneratedTable) and "tmp" in ch.name ] assert len(set(temp_tables)) == 1 temp_table = temp_tables[0] - setup_queries, cleanup_queries = get_setup_and_cleanup_queries(results_query) + setup_queries, cleanup_queries = get_setup_and_cleanup_queries(results_queries) with query_engine.engine.connect() as connection: for setup_query in setup_queries: diff --git a/tests/integration/file_formats/test_main.py b/tests/integration/file_formats/test_main.py index 5c0770c93..880390702 100644 --- a/tests/integration/file_formats/test_main.py +++ b/tests/integration/file_formats/test_main.py @@ -1,4 +1,6 @@ +import contextlib import datetime +import textwrap import pytest @@ -230,3 +232,109 @@ def test_read_and_write_tables_roundtrip(tmp_path, extension): results = read_tables(tmp_path / "output", table_specs) assert [list(rows) for rows in results] == tables + + +def test_read_tables_allows_single_table_format_if_only_one_table(tmp_path): + filename = tmp_path / "file.csv" + filename.write_text("i,s\n1,a\n2,b\n3,c\n") + table_specs = { + "table_1": {"i": ColumnSpec(int), "s": ColumnSpec(str)}, + } + results = read_tables(filename, table_specs) + assert [list(rows) for rows in results] == [ + [(1, "a"), (2, "b"), (3, "c")], + ] + + +def test_write_tables_allows_single_table_format_if_only_one_table(tmp_path): + filename = tmp_path / "file.csv" + table_specs = { + "table_1": {"i": ColumnSpec(int), "s": ColumnSpec(str)}, + } + table_data = [ + [(1, "a"), (2, "b"), (3, "c")], + ] + write_tables(filename, table_data, table_specs) + assert filename.read_text() == "i,s\n1,a\n2,b\n3,c\n" + + +def test_read_tables_rejects_single_table_format_if_multiple_tables(tmp_path): + filename = tmp_path / "input.csv" + filename.touch() + table_specs = { + "table_1": {"i": ColumnSpec(int), "s": ColumnSpec(str)}, + "table_2": {"j": ColumnSpec(int), "k": ColumnSpec(float)}, + "table_3": {"l": ColumnSpec(int), "m": ColumnSpec(float)}, + } + expected_error = textwrap.dedent( + """\ + Attempting to read 3 tables, but input only provides a single table + Try moving -> input.csv + to -> input/table_1.csv + adding -> table_2.csv, table_3.csv + and using path -> input/ + """ + ) + with contextlib.chdir(tmp_path): + # Use relative paths to get predictable error message + relpath = filename.relative_to(tmp_path) + with pytest.raises(FileValidationError, match=expected_error.rstrip()): + list(read_tables(relpath, table_specs)) + + +def test_write_tables_rejects_single_table_format_if_multiple_tables(tmp_path): + filename = tmp_path / "output.csv" + table_specs = { + "table_1": {"i": ColumnSpec(int), "s": ColumnSpec(str)}, + "table_2": {"j": ColumnSpec(int), "k": ColumnSpec(float)}, + "table_3": {"l": ColumnSpec(int), "m": ColumnSpec(float)}, + } + table_data = [[], []] + expected_error = textwrap.dedent( + """\ + Attempting to write 3 tables, but output only supports a single table + Instead of -> output.csv + try -> output/:csv + """ + ) + with contextlib.chdir(tmp_path): + # Use relative paths to get predictable error message + relpath = filename.relative_to(tmp_path) + with pytest.raises(FileValidationError, match=expected_error.rstrip()): + write_tables(relpath, table_data, table_specs) + + +def test_read_tables_with_missing_file_raises_appropriate_error(tmp_path): + 
missing_file = tmp_path / "aint-no-such-file" + table_specs = { + "table_1": {"i": ColumnSpec(int), "s": ColumnSpec(str)}, + "table_2": {"j": ColumnSpec(int), "k": ColumnSpec(float)}, + "table_3": {"l": ColumnSpec(int), "m": ColumnSpec(float)}, + } + with pytest.raises(FileValidationError, match="Missing file or directory"): + next(read_tables(missing_file, table_specs)) + + +def test_write_rows_without_filename_writes_to_console(capsys): + write_rows(None, TEST_FILE_DATA, TEST_FILE_SPECS) + output = capsys.readouterr().out + # The exact content here is tested elsewhere, we just want to make sure things are + # wired up correctly + assert "patient_id" in output + + +def test_write_tables_without_filename_writes_to_console(capsys): + table_specs = { + "table_1": TEST_FILE_SPECS, + "table_2": TEST_FILE_SPECS, + } + table_data = [ + TEST_FILE_DATA, + TEST_FILE_DATA, + ] + write_tables(None, table_data, table_specs) + output = capsys.readouterr().out + # The exact content here is tested elsewhere, we just want to make sure things are + # wired up correctly + assert "patient_id" in output + assert "table_2" in output diff --git a/tests/unit/dummy_data_nextgen/test_specific_datasets.py b/tests/unit/dummy_data_nextgen/test_specific_datasets.py index 7993f9324..fcf14de7c 100644 --- a/tests/unit/dummy_data_nextgen/test_specific_datasets.py +++ b/tests/unit/dummy_data_nextgen/test_specific_datasets.py @@ -325,7 +325,7 @@ def test_will_raise_if_all_data_is_impossible(patched_time, query): generator.timeout = 1 patched_time.time.side_effect = [0.0, 20.0] with pytest.raises(CannotGenerate): - generator.get_results() + next(generator.get_results()) def test_generates_events_starting_from_birthdate(): diff --git a/tests/unit/file_formats/test_console.py b/tests/unit/file_formats/test_console.py new file mode 100644 index 000000000..c76415598 --- /dev/null +++ b/tests/unit/file_formats/test_console.py @@ -0,0 +1,85 @@ +import datetime +import textwrap + +from ehrql.file_formats.console import write_rows_console, write_tables_console +from ehrql.query_model.column_specs import ColumnSpec +from ehrql.sqlalchemy_types import TYPE_MAP + + +def test_write_rows_console(capsys): + column_specs = { + "patient_id": ColumnSpec(int), + "b": ColumnSpec(bool), + "i": ColumnSpec(int), + "f": ColumnSpec(float), + "s": ColumnSpec(str), + "c": ColumnSpec(str, categories=("A", "B")), + "d": ColumnSpec(datetime.date), + } + + rows = [ + (123, True, 1, 2.3, "a", "A", datetime.date(2020, 1, 1)), + (456, False, -5, -0.4, "b", "B", datetime.date(2022, 12, 31)), + (789, None, None, None, None, None, None), + ] + + # Check the example uses at least one of every supported type + assert {spec.type for spec in column_specs.values()} == set(TYPE_MAP) + + write_rows_console(rows, column_specs) + output = capsys.readouterr().out + + # The CSV module does its own newline handling, hence the carriage returns below + assert output == textwrap.dedent( + """\ + patient_id,b,i,f,s,c,d\r + 123,T,1,2.3,a,A,2020-01-01\r + 456,F,-5,-0.4,b,B,2022-12-31\r + 789,,,,,,\r + """ + ) + + +def test_write_tables_console(capsys): + table_specs = { + "table_1": { + "patient_id": ColumnSpec(int), + "b": ColumnSpec(bool), + "i": ColumnSpec(int), + }, + "table_2": { + "patient_id": ColumnSpec(int), + "s": ColumnSpec(str), + "d": ColumnSpec(datetime.date), + }, + } + + tables = [ + [ + (123, True, 1), + (456, False, 2), + ], + [ + (789, "a", datetime.date(2025, 1, 1)), + (987, "B", datetime.date(2025, 2, 3)), + ], + ] + + write_tables_console(tables, 
table_specs) + output = capsys.readouterr().out + + # The CSV module does its own newline handling, hence the carriage returns below + assert output == textwrap.dedent( + """\ + table_1 + patient_id,b,i\r + 123,T,1\r + 456,F,2\r + + + table_2 + patient_id,s,d\r + 789,a,2025-01-01\r + 987,B,2025-02-03\r + """ + ) diff --git a/tests/unit/file_formats/test_main.py b/tests/unit/file_formats/test_main.py index 8ac33c074..aeb20cb23 100644 --- a/tests/unit/file_formats/test_main.py +++ b/tests/unit/file_formats/test_main.py @@ -17,7 +17,6 @@ @pytest.mark.parametrize( "filename,extension", [ - (None, ".csv"), (Path("a/b.c/file.txt"), ".txt"), (Path("a/b.c/file.txt.foo"), ".foo"), (Path("a/b.c/file.txt.gz"), ".txt.gz"), diff --git a/tests/unit/query_model/test_column_specs.py b/tests/unit/query_model/test_column_specs.py index 3ab52506f..14f59edc9 100644 --- a/tests/unit/query_model/test_column_specs.py +++ b/tests/unit/query_model/test_column_specs.py @@ -4,8 +4,8 @@ from ehrql.query_model.column_specs import ( ColumnSpec, get_categories, - get_column_specs, get_range, + get_table_specs, ) from ehrql.query_model.nodes import ( AggregateByPatient, @@ -46,14 +46,16 @@ def test_get_column_specs(): category=SelectColumn(patients, "category"), ), ) - column_specs = get_column_specs(dataset) - assert column_specs == { - "patient_id": ColumnSpec(type=int, nullable=False, categories=None), - "dob": ColumnSpec(type=datetime.date, nullable=True, categories=None), - "code": ColumnSpec(type=str, nullable=True, categories=None), - "category": ColumnSpec( - type=str, nullable=True, categories=("123000", "456000") - ), + table_specs = get_table_specs(dataset) + assert table_specs == { + "dataset": { + "patient_id": ColumnSpec(type=int, nullable=False, categories=None), + "dob": ColumnSpec(type=datetime.date, nullable=True, categories=None), + "code": ColumnSpec(type=str, nullable=True, categories=None), + "category": ColumnSpec( + type=str, nullable=True, categories=("123000", "456000") + ), + } } diff --git a/tests/unit/test___main__.py b/tests/unit/test___main__.py index 141979299..5ade7a7ea 100644 --- a/tests/unit/test___main__.py +++ b/tests/unit/test___main__.py @@ -1,3 +1,5 @@ +from pathlib import Path + import pytest from ehrql.__main__ import ( @@ -10,6 +12,7 @@ import_string, main, query_engine_from_id, + valid_output_path, ) from ehrql.backends.base import SQLBackend from ehrql.query_engines.base import BaseQueryEngine @@ -248,7 +251,7 @@ def test_import_string_no_such_attribute(): class DummyQueryEngine: - def get_results(self): + def get_results_tables(self): raise NotImplementedError() @@ -323,3 +326,30 @@ def test_all_backends_have_an_alias(): def test_all_backend_aliases_match_display_names(): for alias in BACKEND_ALIASES.keys(): assert backend_from_id(alias).display_name.lower() == alias + + +@pytest.mark.parametrize( + "path", + [ + "some/path/file.csv", + "some/path/dir:csv", + "some/path/dir/:csv", + "some/path/dir.foo:csv", + ], +) +def test_valid_output_path(path): + assert valid_output_path(path) == Path(path) + + +@pytest.mark.parametrize( + "path, message", + [ + ("no/extension", "No file format supplied"), + ("some/path.badfile", "'.badfile' is not a supported format"), + ("some/path:baddir", "':baddir' is not a supported format"), + ("some/path/:baddir", "':baddir' is not a supported format"), + ], +) +def test_valid_output_path_errors(path, message): + with pytest.raises(ArgumentTypeError, match=message): + valid_output_path(path) diff --git a/tests/unit/utils/test_itertools_utils.py 
b/tests/unit/utils/test_itertools_utils.py index 04756d64e..2d805d16b 100644 --- a/tests/unit/utils/test_itertools_utils.py +++ b/tests/unit/utils/test_itertools_utils.py @@ -1,6 +1,8 @@ +import hypothesis as hyp +import hypothesis.strategies as st import pytest -from ehrql.utils.itertools_utils import eager_iterator, iter_flatten +from ehrql.utils.itertools_utils import eager_iterator, iter_flatten, iter_groups def test_eager_iterator(): @@ -49,3 +51,71 @@ def test_iter_flatten(): ] flattened = list(iter_flatten(nested)) assert flattened == [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, "foo"] + + +SEPARATOR = object() + + +@pytest.mark.parametrize( + "stream,expected", + [ + ( + [], + [], + ), + ( + [SEPARATOR], + [[]], + ), + ( + [SEPARATOR, 1, 2, SEPARATOR, SEPARATOR, 3, 4], + [[1, 2], [], [3, 4]], + ), + ], +) +def test_iter_groups(stream, expected): + results = [list(group) for group in iter_groups(stream, SEPARATOR)] + assert results == expected + + +@hyp.given( + nested=st.lists( + st.lists(st.integers(), max_size=5), + max_size=5, + ) +) +def test_iter_groups_roundtrips(nested): + flattened = [] + for inner in nested: + flattened.append(SEPARATOR) + for item in inner: + flattened.append(item) + + results = [list(group) for group in iter_groups(flattened, SEPARATOR)] + assert results == nested + + +def test_iter_groups_rejects_invalid_stream(): + stream_no_separator = [1, 2] + with pytest.raises( + AssertionError, + match="Invalid iterator: does not start with `separator` value", + ): + list(iter_groups(stream_no_separator, SEPARATOR)) + + +def test_iter_groups_rejects_out_of_order_reads(): + stream = [SEPARATOR, 1, 2, SEPARATOR, 3, 4] + with pytest.raises( + AssertionError, + match="Cannot consume next group until current group has been exhausted", + ): + list(iter_groups(stream, SEPARATOR)) + + +def test_iter_groups_allows_overreading_groups(): + stream = [SEPARATOR, 1, 2, SEPARATOR, 3, 4] + # We call `list` on each group twice: this should make no difference because on the + # second call the group should be exhausted and so result in an empty list + results = [list(group) + list(group) for group in iter_groups(stream, SEPARATOR)] + assert results == [[1, 2], [3, 4]] diff --git a/tests/unit/utils/test_sqlalchemy_query_utils.py b/tests/unit/utils/test_sqlalchemy_query_utils.py index 5c53551c8..f6dfb606b 100644 --- a/tests/unit/utils/test_sqlalchemy_query_utils.py +++ b/tests/unit/utils/test_sqlalchemy_query_utils.py @@ -72,7 +72,7 @@ def test_get_setup_and_cleanup_queries_basic(): query = sqlalchemy.select(temp_table.c.foo) # Check that we get the right queries in the right order - assert _queries_as_strs(query) == [ + assert _queries_as_strs([query]) == [ "CREATE TABLE temp_table (\n\tfoo NULL\n)", "INSERT INTO temp_table (foo) VALUES (:foo)", "SELECT temp_table.foo \nFROM temp_table", @@ -100,7 +100,7 @@ def test_get_setup_and_cleanup_queries_nested(): query = sqlalchemy.select(temp_table2.c.baz) # Check that we create and drop the temporary tables in the right order - assert _queries_as_strs(query) == [ + assert _queries_as_strs([query]) == [ "CREATE TABLE temp_table1 (\n\tfoo NULL\n)", "INSERT INTO temp_table1 (foo) VALUES (:foo)", "CREATE TABLE temp_table2 (\n\tbaz NULL\n)", @@ -111,6 +111,41 @@ def test_get_setup_and_cleanup_queries_nested(): ] +def test_get_setup_and_cleanup_queries_multiple(): + # Make a temporary table + temp_table1 = _make_temp_table("temp_table1", "foo") + temp_table1.setup_queries.append( + temp_table1.insert().values(foo="bar"), + ) + + # Make a second temporary 
table ... + temp_table2 = _make_temp_table("temp_table2", "baz") + temp_table2.setup_queries.append( + # ... populated by a SELECT query against the first table + temp_table2.insert().from_select( + [temp_table2.c.baz], sqlalchemy.select(temp_table1.c.foo) + ), + ) + + # Select something from the second table + query_1 = sqlalchemy.select(temp_table2.c.baz) + + # Select something from the first table + query_2 = sqlalchemy.select(temp_table1.c.foo) + + # Check that we create and drop the temporary tables in the right order + assert _queries_as_strs([query_1, query_2]) == [ + "CREATE TABLE temp_table1 (\n\tfoo NULL\n)", + "INSERT INTO temp_table1 (foo) VALUES (:foo)", + "CREATE TABLE temp_table2 (\n\tbaz NULL\n)", + "INSERT INTO temp_table2 (baz) SELECT temp_table1.foo \nFROM temp_table1", + "SELECT temp_table2.baz \nFROM temp_table2", + "SELECT temp_table1.foo \nFROM temp_table1", + "DROP TABLE temp_table2", + "DROP TABLE temp_table1", + ] + + def _make_temp_table(name, *columns): table = GeneratedTable( name, sqlalchemy.MetaData(), *[sqlalchemy.Column(c) for c in columns] @@ -122,11 +157,11 @@ def _make_temp_table(name, *columns): return table -def _queries_as_strs(query): - setup_queries, cleanup_queries = get_setup_and_cleanup_queries(query) +def _queries_as_strs(queries): + setup_queries, cleanup_queries = get_setup_and_cleanup_queries(queries) return ( [str(q).strip() for q in setup_queries] - + [str(query).strip()] + + [str(q).strip() for q in queries] + [str(q).strip() for q in cleanup_queries] ) @@ -193,7 +228,7 @@ def test_get_setup_and_cleanup_queries_with_insert_many(): sqlalchemy.Column("i", sqlalchemy.Integer()), ) statement = InsertMany(table, rows=[]) - setup_cleanup = get_setup_and_cleanup_queries(statement) + setup_cleanup = get_setup_and_cleanup_queries([statement]) assert setup_cleanup == ([], [])
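Taken together, `get_results_stream`, the `RESULTS_START` sentinel and `iter_groups` mean a query engine subclass only has to emit one flat stream, while callers still receive one iterator per table via `get_results_tables`. A minimal sketch against the patched `BaseQueryEngine` API above; the toy engine, its DSN and its row data are invented for illustration:

```python
from ehrql.query_engines.base import BaseQueryEngine


class ToyQueryEngine(BaseQueryEngine):
    # Invented data: rows for two "results tables"
    TABLE_DATA = [
        [(1, "a"), (2, "b")],
        [(1, 10), (1, 20)],
    ]

    def get_results_stream(self, dataset):
        # Emit a flat stream: each table's rows preceded by the RESULTS_START sentinel
        for rows in self.TABLE_DATA:
            yield self.RESULTS_START
            yield from rows


engine = ToyQueryEngine(dsn="unused://")
# get_results_tables() wraps the flat stream with iter_groups(), giving one
# iterator per table; each group is exhausted before the loop moves on, which
# is exactly the access pattern iter_groups() enforces
for table in engine.get_results_tables(dataset=None):
    print(list(table))
# [(1, 'a'), (2, 'b')]
# [(1, 10), (1, 20)]
```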