Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Proof of concept: Allow selecting of slices of assay meta in graphql #827

Draft
wants to merge 1 commit into
base: dev
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions api/graphql/loaders.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,11 @@ async def load_audit_logs_by_analysis_ids(

@connected_data_loader_with_params(LoaderKeys.ASSAYS_FOR_SAMPLES, default_factory=list)
async def load_assays_by_samples(
connection, ids, filter: AssayFilter
connection,
ids,
filter: AssayFilter,
include_meta: bool,
meta_slices: list[dict[str, Any]] | None,
) -> dict[int, list[AssayInternal]]:
"""
DataLoader: get_assays_for_sample_ids
Expand All @@ -191,7 +195,9 @@ async def load_assays_by_samples(
assaylayer = AssayLayer(connection)
# maybe this is dangerous, but I don't think it should matter
filter.sample_id = GenericFilter(in_=ids)
assays = await assaylayer.query(filter)
assays = await assaylayer.query(
filter, include_meta=include_meta, meta_slices=meta_slices
)
assay_map = group_by(assays, lambda a: a.sample_id)
return assay_map

Expand Down
27 changes: 26 additions & 1 deletion api/graphql/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"""
import datetime
from inspect import isclass
from typing import Any

import strawberry
from strawberry.extensions import QueryDepthLimiter
Expand Down Expand Up @@ -754,13 +755,28 @@ async def assays(
type: GraphQLFilter[str] | None = None,
meta: GraphQLMetaFilter | None = None,
) -> list['GraphQLAssay']:
selected_fields = info.selected_fields[0].selections
has_meta_selected = any(f.name == 'meta' for f in selected_fields)

# Find if there are any slices of meta selected, we can load these as part of the same query
meta_slices = [
{'path': f.arguments['path'], 'alias': f.alias}
for f in selected_fields
if f.name == 'metaValue'
]

loader_assays_for_sample_ids = info.context[LoaderKeys.ASSAYS_FOR_SAMPLES]
filter_ = AssayFilter(
type=type.to_internal_filter() if type else None,
meta=meta,
)
assays = await loader_assays_for_sample_ids.load(
{'id': root.internal_id, 'filter': filter_}
{
'id': root.internal_id,
'filter': filter_,
'include_meta': has_meta_selected,
'meta_slices': meta_slices,
}
)
return [GraphQLAssay.from_internal(assay) for assay in assays]

Expand Down Expand Up @@ -907,6 +923,7 @@ class GraphQLAssay:
external_ids: strawberry.scalars.JSON

sample_id: strawberry.Private[int]
meta_slices: strawberry.Private[dict[str, Any] | None]

@staticmethod
def from_internal(internal: AssayInternal) -> 'GraphQLAssay':
Expand All @@ -917,6 +934,7 @@ def from_internal(internal: AssayInternal) -> 'GraphQLAssay':
id=internal.id,
type=internal.type,
meta=internal.meta,
meta_slices=internal.meta_slices,
external_ids=internal.external_ids or {},
# internal
sample_id=internal.sample_id,
Expand All @@ -928,6 +946,13 @@ async def sample(self, info: Info, root: 'GraphQLAssay') -> GraphQLSample:
sample = await loader.load(root.sample_id)
return GraphQLSample.from_internal(sample)

@strawberry.field
async def metaValue(
self, info: Info, root: 'GraphQLAssay', path: str
) -> strawberry.scalars.JSON:
alias = info.selected_fields[0].alias
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

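This would also need to handle un-aliased fields here: when the field is selected without an alias, `alias` is presumably `None`, so the `'_meta_' + alias` lookup below would fail.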

return root.meta_slices.get('_meta_' + alias) if root.meta_slices else None


@strawberry.type
class GraphQLAnalysisRunner:
Expand Down
12 changes: 10 additions & 2 deletions db/python/layers/assay.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,18 @@ def __init__(self, connection: Connection):
self.sampt: SampleTable = SampleTable(connection)

# GET
async def query(self, filter_: AssayFilter = None, check_project_id=True):
async def query(
self,
filter_: AssayFilter = None,
check_project_id=True,
include_meta: bool = True,
meta_slices: list[dict[str, Any]] | None = None,
):
"""Query for samples"""

projects, assays = await self.seqt.query(filter_)
projects, assays = await self.seqt.query(
filter_, include_meta=include_meta, meta_slices=meta_slices
)

if not assays:
return []
Expand Down
26 changes: 22 additions & 4 deletions db/python/tables/assay.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,10 @@ class AssayTable(DbBase):
# region GETS

async def query(
self, filter_: AssayFilter
self,
filter_: AssayFilter,
include_meta: bool = True,
meta_slices: list[dict[str, Any]] | None = None,
) -> tuple[set[ProjectId], list[AssayInternal]]:
"""Query assays"""
sql_overides = {
Expand All @@ -83,15 +86,30 @@ async def query(
raise ValueError('Must provide a project if filtering by external_id')

conditions, values = filter_.to_sql(sql_overides)
keys = ', '.join(self.COMMON_GET_KEYS)
keys = ', '.join(
[k for k in self.COMMON_GET_KEYS if include_meta or k != 'a.meta']
)
meta_slice_keys = (
','.join(
[
f'JSON_VALUE(a.meta, \'{m["path"]}\') as _meta_{m["alias"]}'
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

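There is definitely a SQL injection issue here, since the client-supplied `path` and `alias` are interpolated directly into the query string. This would need fixing if we want to do this for real.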

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This should also use JSON_EXTRACT rather than JSON_VALUE, since JSON_VALUE only returns scalar values and would not return nested objects or arrays.

for m in meta_slices
]
)
if meta_slices
else None
)

if meta_slice_keys:
meta_slice_keys = ',' + meta_slice_keys

_query = f"""
SELECT {keys}
SELECT {keys} {meta_slice_keys}
FROM assay a
LEFT JOIN sample s ON s.id = a.sample_id
LEFT JOIN assay_external_id aeid ON aeid.assay_id = a.id
WHERE {conditions}
"""

assay_rows = await self.connection.fetch_all(_query, values)

# this will unique on the id, which we want due to joining on 1:many eid table
Expand Down
9 changes: 7 additions & 2 deletions models/models/assay.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ class AssayInternal(SMBase):
id: int | None
sample_id: int
meta: dict[str, Any] | None
meta_slices: dict[str, Any] | None
type: str
external_ids: dict[str, str] | None = {}

Expand All @@ -23,16 +24,18 @@ def __eq__(self, other):
return False

@staticmethod
def from_db(d: dict):
def from_db(d: dict[str, Any]):
"""Take DB mapping object, and return SampleSequencing"""
meta = d.pop('meta', None)
keys = [k for k in d]
meta_slices = {k: d.pop(k) for k in keys if k.startswith('_meta')}

if meta:
if isinstance(meta, bytes):
meta = meta.decode()
if isinstance(meta, str):
meta = json.loads(meta)
return AssayInternal(meta=meta, **d)
return AssayInternal(meta=meta, meta_slices=meta_slices, **d)

def to_external(self):
"""Convert to transport model"""
Expand Down Expand Up @@ -72,13 +75,15 @@ class Assay(SMBase):
external_ids: dict[str, str]
sample_id: str
meta: dict[str, Any]

type: str

def to_internal(self):
"""Convert to internal model"""
return AssayInternal(
id=self.id,
type=self.type,
meta_slices=None,
external_ids=self.external_ids,
sample_id=sample_id_transform_to_raw(self.sample_id),
meta=self.meta,
Expand Down
Loading