Skip to content

Commit

Permalink
Merge pull request #2363 from opensafely-core/evansd/extend-query-eng…
Browse files Browse the repository at this point in the history
…ine-api

Update QueryEngine API to support multiple results tables
  • Loading branch information
evansd authored Jan 23, 2025
2 parents ac63c73 + d8cb142 commit 696c9a5
Show file tree
Hide file tree
Showing 27 changed files with 747 additions and 144 deletions.
2 changes: 1 addition & 1 deletion docs/includes/generated_docs/cli.md
Original file line number Diff line number Diff line change
Expand Up @@ -467,7 +467,7 @@ double-dash ` -- `.
create-dummy-tables
</h2>
```
ehrql create-dummy-tables DEFINITION_FILE DUMMY_TABLES_PATH [--help]
ehrql create-dummy-tables DEFINITION_FILE [DUMMY_TABLES_PATH] [--help]
[ -- ... PARAMETERS ...]
```
Generate dummy tables and write them out as files – one per table, CSV by
Expand Down
29 changes: 24 additions & 5 deletions ehrql/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,6 +293,7 @@ def add_create_dummy_tables(subparsers, environ, user_args):
add_dataset_definition_file_argument(parser, environ)
parser.add_argument(
"dummy_tables_path",
nargs="?",
help=strip_indent(
f"""
Path to directory where files (one per table) will be written.
Expand Down Expand Up @@ -662,13 +663,31 @@ def existing_python_file(value):


def valid_output_path(value):
# This can be either a single file or a directory, but either way it needs to
# specify a valid output format
path = Path(value)
extension = get_file_extension(path)
if extension not in FILE_FORMATS:
directory_ext = split_directory_and_extension(path)[1]
file_ext = get_file_extension(path)
if not directory_ext and not file_ext:
raise ArgumentTypeError(
f"'{extension}' is not a supported format, must be one of: "
f"{backtick_join(FILE_FORMATS)}"
f"No file format supplied\n"
f"To write a single file use a file extension: {backtick_join(FILE_FORMATS)}"
f"To write multiple files use a directory extension: "
f"{backtick_join(format_directory_extension(e) for e in FILE_FORMATS)}\n"
)
elif directory_ext:
if directory_ext not in FILE_FORMATS:
raise ArgumentTypeError(
f"'{format_directory_extension(directory_ext)}' is not a supported format, "
f"must be one of: "
f"{backtick_join(format_directory_extension(e) for e in FILE_FORMATS)}"
)
else:
if file_ext not in FILE_FORMATS:
raise ArgumentTypeError(
f"'{file_ext}' is not a supported format, must be one of: "
f"{backtick_join(FILE_FORMATS)}"
)
return path


Expand Down Expand Up @@ -701,7 +720,7 @@ def query_engine_from_id(str_id):
f"(or a full dotted path to a query engine class)"
)
query_engine = import_string(str_id)
assert_duck_type(query_engine, "query engine", "get_results")
assert_duck_type(query_engine, "query engine", "get_results_tables")
return query_engine


Expand Down
7 changes: 7 additions & 0 deletions ehrql/docs/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,13 @@ def get_argument(action):
"usage_long": ", ".join(action.option_strings),
"description": action.help,
}
elif isinstance(action, argparse._StoreAction) and action.nargs == "?":
return {
"id": action.dest,
"usage_short": f"[{action.dest.upper()}]",
"usage_long": action.dest.upper(),
"description": action.help,
}
elif isinstance(action, argparse._StoreAction) and not action.required:
return {
"id": action.option_strings[-1].lstrip("-"),
Expand Down
10 changes: 8 additions & 2 deletions ehrql/dummy_data/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,10 +131,16 @@ def get_patient_id_stream(self):
if i not in inline_patient_ids:
yield i

def get_results(self):
def get_results_tables(self):
database = InMemoryDatabase(self.get_data())
engine = InMemoryQueryEngine(database)
return engine.get_results(self.dataset)
return engine.get_results_tables(self.dataset)

def get_results(self):
tables = self.get_results_tables()
yield from next(tables)
for remaining in tables:
assert False, "Expected only one results table"


class DummyPatientGenerator:
Expand Down
10 changes: 8 additions & 2 deletions ehrql/dummy_data_nextgen/generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,10 +245,16 @@ def get_patient_id_stream(self):
if i not in inline_patient_ids:
yield i

def get_results(self):
def get_results_tables(self):
database = InMemoryDatabase(self.get_data())
engine = InMemoryQueryEngine(database)
return engine.get_results(self.dataset)
return engine.get_results_tables(self.dataset)

def get_results(self):
tables = self.get_results_tables()
yield from next(tables)
for remaining in tables:
assert False, "Expected only one results table"


class DummyPatientGenerator:
Expand Down
28 changes: 28 additions & 0 deletions ehrql/file_formats/console.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""
Handles writing rows/tables to the console for local development and debugging.
At present, this just uses the CSV writer but there's scope for using something a bit
prettier and more readable here in future.
"""

import sys

from ehrql.file_formats.csv import write_rows_csv_lines


def write_rows_console(rows, column_specs):
write_rows_csv_lines(sys.stdout, rows, column_specs)


def write_tables_console(tables, table_specs):
write_table_names = len(table_specs) > 1
first_table = True
for rows, (table_name, column_specs) in zip(tables, table_specs.items()):
if first_table:
first_table = False
else:
# Add whitespace between tables
sys.stdout.write("\n\n")
if write_table_names:
sys.stdout.write(f"{table_name}\n")
write_rows_console(rows, column_specs)
11 changes: 2 additions & 9 deletions ehrql/file_formats/csv.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import csv
import datetime
import gzip
import sys
from contextlib import nullcontext

from ehrql.file_formats.base import (
BaseRowsReader,
Expand All @@ -12,13 +10,8 @@


def write_rows_csv(filename, rows, column_specs):
if filename is None:
context = nullcontext(sys.stdout)
else:
# Set `newline` as per Python docs:
# https://docs.python.org/3/library/csv.html#id3
context = filename.open(mode="w", newline="")
with context as f:
# Set `newline` as per Python docs: https://docs.python.org/3/library/csv.html#id3
with filename.open(mode="w", newline="") as f:
write_rows_csv_lines(f, rows, column_specs)


Expand Down
73 changes: 66 additions & 7 deletions ehrql/file_formats/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
write_rows_arrow,
)
from ehrql.file_formats.base import FileValidationError
from ehrql.file_formats.console import write_rows_console, write_tables_console
from ehrql.file_formats.csv import (
CSVGZRowsReader,
CSVRowsReader,
Expand All @@ -23,6 +24,9 @@


def write_rows(filename, rows, column_specs):
if filename is None:
return write_rows_console(rows, column_specs)

extension = get_file_extension(filename)
writer = FILE_FORMATS[extension][0]
# `rows` is often a generator which won't actually execute until we start consuming
Expand All @@ -31,9 +35,7 @@ def write_rows(filename, rows, column_specs):
# whole thing into memory. So we wrap it in a function which draws the first item
# upfront, but doesn't consume the rest of the iterator.
rows = eager_iterator(rows)
# We use None for stdout
if filename is not None:
filename.parent.mkdir(parents=True, exist_ok=True)
filename.parent.mkdir(parents=True, exist_ok=True)
writer(filename, rows, column_specs)


Expand All @@ -48,6 +50,33 @@ def read_rows(filename, column_specs, allow_missing_columns=False):


def read_tables(filename, table_specs, allow_missing_columns=False):
if not filename.exists():
raise FileValidationError(f"Missing file or directory: {filename}")

# If we've got a single-table input file and only a single table to read then that's
# fine, but it needs slightly special handling
if not input_filename_supports_multiple_tables(filename):
if len(table_specs) == 1:
column_specs = list(table_specs.values())[0]
rows = read_rows(
filename,
column_specs,
allow_missing_columns=allow_missing_columns,
)
yield from [rows]
return
else:
files = list(table_specs.keys())
suffix = filename.suffix
raise FileValidationError(
f"Attempting to read {len(table_specs)} tables, but input only "
f"provides a single table\n"
f" Try moving -> {filename}\n"
f" to -> {filename.parent / filename.stem}/{files[0]}{suffix}\n"
f" adding -> {', '.join(f + suffix for f in files[1:])}\n"
f" and using path -> {filename.parent / filename.stem}/"
)

extension = get_extension_from_directory(filename)
# Using ExitStack here allows us to open and validate all files before emiting any
# rows while still correctly closing all open files if we raise an error part way
Expand All @@ -66,17 +95,33 @@ def read_tables(filename, table_specs, allow_missing_columns=False):


def write_tables(filename, tables, table_specs):
if filename is None:
return write_tables_console(tables, table_specs)

# If we've got a single-table output file and only a single table to write then
# that's fine, but it needs slightly special handling
if not output_filename_supports_multiple_tables(filename):
if len(table_specs) == 1:
column_specs = list(table_specs.values())[0]
rows = next(iter(tables))
return write_rows(filename, rows, column_specs)
else:
raise FileValidationError(
f"Attempting to write {len(table_specs)} tables, but output only "
f"supports a single table\n"
f" Instead of -> {filename}\n"
f" try -> "
f"{filename.parent / filename.stem}/:{filename.suffix.lstrip('.')}"
)

filename, extension = split_directory_and_extension(filename)
for rows, (table_name, column_specs) in zip(tables, table_specs.items()):
table_filename = get_table_filename(filename, table_name, extension)
write_rows(table_filename, rows, column_specs)


def get_file_extension(filename):
if filename is None:
# If we have no filename we're writing to stdout, so default to CSV
return ".csv"
elif filename.suffix == ".gz":
if filename.suffix == ".gz":
return "".join(filename.suffixes[-2:])
else:
return filename.suffix
Expand Down Expand Up @@ -121,6 +166,20 @@ def split_directory_and_extension(filename):
return filename.with_name(name), f".{extension}"


def input_filename_supports_multiple_tables(filename):
# At present, supplying a directory is the only way to provide multiple input
# tables, but it's not inconceivable that in future we might support single-file
# multiple-table formats e.g SQLite or DuckDB files. If we do then updating this
# function and its sibling below should be all that's required.
return filename.is_dir()


def output_filename_supports_multiple_tables(filename):
# Again, at present only directories support multiple output tables but see above
extension = split_directory_and_extension(filename)[1]
return extension != ""


def get_table_filename(base_filename, table_name, extension):
# Use URL quoting as an easy way of escaping any potentially problematic characters
# in filenames
Expand Down
Loading

0 comments on commit 696c9a5

Please sign in to comment.