diff --git a/CHANGELOG.md b/CHANGELOG.md index 2993422e..1ba35f27 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,7 @@ ### Changed - Updated several libraries including schug (now v.7) - Update project's Python version to 3.9 +- Modified the structure of the database table `genes`, converting the `ensembl_id` string field to `ensembl_ids`: an array of strings. This change addresses recent changes in the MySQL: https://bugs.mysql.com/bug.php?id=114838 ### Fixed - The MariaDB healthcheck step in docker-compose-mysql.yml, preventing the demo app to start diff --git a/src/chanjo2/crud/intervals.py b/src/chanjo2/crud/intervals.py index a7b11f18..b51b5e7c 100644 --- a/src/chanjo2/crud/intervals.py +++ b/src/chanjo2/crud/intervals.py @@ -1,7 +1,7 @@ import logging from typing import List, Optional, Union -from sqlalchemy import delete, or_ +from sqlalchemy import delete, or_, text from sqlalchemy.orm import Session, query from sqlalchemy.sql.expression import Delete @@ -71,9 +71,19 @@ def get_genes( ) -> List[SQLGene]: """Return genes according to specified fields.""" genes: query.Query = db.query(SQLGene) - if ensembl_ids: - genes: query.Query = genes.filter(SQLGene.ensembl_id.in_(ensembl_ids)) + ensembl_ids_placeholder = ", ".join(f"'{e}'" for e in ensembl_ids) + genes = genes.filter( + text( + f""" + EXISTS ( + SELECT 1 + FROM json_each(genes.ensembl_ids) + WHERE value IN ({ensembl_ids_placeholder}) + ) + """ + ) + ) elif hgnc_ids: genes: query.Query = genes.filter(SQLGene.hgnc_id.in_(hgnc_ids)) elif hgnc_symbols: @@ -151,7 +161,9 @@ def set_sql_intervals( ensembl_ids=None, hgnc_ids=None, hgnc_symbols=None, - ensembl_gene_ids=[gene.ensembl_id for gene in genes], + ensembl_gene_ids=[ + ensembl_id for gene in genes for ensembl_id in gene.ensembl_ids + ], limit=None, transcript_tags=transcript_tags, ) @@ -171,27 +183,39 @@ def get_gene_intervals( ) -> List[Union[SQLTranscript, SQLExon]]: """Retrieve transcripts or exons from a list of genes.""" - intervals: query.Query = db.query(interval_type).join(SQLGene) + intervals = db.query(interval_type).filter(interval_type.build == build) + + def get_ensembl_gene_ids_from_gene_filter( + filter_value: List[Union[str, int]], filter_column: str + ) -> List[str]: + """Helper function to get ensembl_gene_ids from either hgnc_ids or hgnc_symbols.""" + genes = ( + db.query(SQLGene.ensembl_ids).filter(filter_column.in_(filter_value)).all() + ) + return [ensembl_id for gene in genes for ensembl_id in gene.ensembl_ids] + if ensembl_ids: intervals: query.Query = intervals.filter( interval_type.ensembl_id.in_(ensembl_ids) ) - elif ensembl_gene_ids: - intervals: query.Query = intervals.filter( - interval_type.ensembl_gene_id.in_(ensembl_gene_ids) - ) elif hgnc_ids: - intervals: query.Query = intervals.filter(SQLGene.hgnc_id.in_(hgnc_ids)) + ensembl_gene_ids = get_ensembl_gene_ids_from_gene_filter( + hgnc_ids, SQLGene.hgnc_id + ) elif hgnc_symbols: - intervals: query.Query = intervals.filter(SQLGene.hgnc_symbol.in_(hgnc_symbols)) + ensembl_gene_ids = get_ensembl_gene_ids_from_gene_filter( + hgnc_symbols, SQLGene.hgnc_symbol + ) + if ensembl_gene_ids: + intervals = intervals.filter( + interval_type.ensembl_gene_id.in_(ensembl_gene_ids) + ) if interval_type == SQLTranscript and transcript_tags: intervals = _filter_transcripts_by_tag( transcripts=intervals, transcript_tags=transcript_tags ) - intervals: query.Query = intervals.filter(interval_type.build == build) - if limit: return intervals.limit(limit).all() diff --git a/src/chanjo2/endpoints/coverage.py b/src/chanjo2/endpoints/coverage.py index 35ceca92..5d3f4936 100644 --- a/src/chanjo2/endpoints/coverage.py +++ b/src/chanjo2/endpoints/coverage.py @@ -20,7 +20,7 @@ get_d4tools_chromosome_mean_coverage, get_d4tools_intervals_mean_coverage, ) -from chanjo2.meta.handle_d4 import get_samples_sex_metrics +from chanjo2.meta.handle_d4 import get_samples_sex_metrics, set_interval_ids_coords from chanjo2.meta.handle_report_contents import INTERVAL_TYPE_SQL_TYPE, get_mean from chanjo2.models import SQLGene from chanjo2.models.pydantic_models import ( @@ -167,10 +167,10 @@ def d4_genes_condensed_summary( detail=WRONG_COVERAGE_FILE_MSG, ) - interval_ids_coords: List[Tuple[str, Tuple[str, int, int]]] = [ - (interval.ensembl_id, (interval.chromosome, interval.start, interval.stop)) - for interval in sql_intervals - ] + interval_ids_coords: List[Tuple[str, Tuple[str, int, int]]] = ( + set_interval_ids_coords(sql_intervals=sql_intervals) + ) + # Sort intervals by chrom, start & stop interval_ids_coords = sort_interval_ids_coords(interval_ids_coords) diff --git a/src/chanjo2/meta/handle_d4.py b/src/chanjo2/meta/handle_d4.py index 4a052f9b..84c43c28 100644 --- a/src/chanjo2/meta/handle_d4.py +++ b/src/chanjo2/meta/handle_d4.py @@ -19,6 +19,26 @@ LOG = logging.getLogger(__name__) +def set_interval_ids_coords( + sql_intervals: List[Union[SQLGene, SQLTranscript, SQLExon]] +) -> List[Tuple[str, Tuple[str, int, int]]]: + """Returns tuples with an ensembl_id and coordinates from a list of SQL intervals.""" + + if not sql_intervals: + return [] + if isinstance(sql_intervals[0], SQLGene): + return [ + (ensembl_id, (interval.chromosome, interval.start, interval.stop)) + for interval in sql_intervals + for ensembl_id in interval.ensembl_ids + ] + else: + return [ + (interval.ensembl_id, (interval.chromosome, interval.start, interval.stop)) + for interval in sql_intervals + ] + + def get_report_sample_interval_coverage( d4_file_path: str, sample_name: str, @@ -31,11 +51,9 @@ def get_report_sample_interval_coverage( """Compute stats to populate a coverage report for one sample.""" # Compute intervals coverage completeness - interval_ids_coords: List[Tuple[str, Tuple[str, int, int]]] = [ - (interval.ensembl_id, (interval.chromosome, interval.start, interval.stop)) - for interval in sql_intervals - ] - + interval_ids_coords: List[Tuple[str, Tuple[str, int, int]]] = ( + set_interval_ids_coords(sql_intervals=sql_intervals) + ) interval_ids_coords = sort_interval_ids_coords(interval_ids_coords) # Compute intervals coverage @@ -57,51 +75,61 @@ def get_report_sample_interval_coverage( nr_intervals_covered_under_custom_threshold: int = 0 genes_covered_under_custom_threshold = set() - for interval_nr, interval in enumerate(sql_intervals): - - if interval.ensembl_id in interval_ids: - continue - for threshold in completeness_thresholds: - interval_coverage_at_threshold: float = intervals_coverage_completeness[ - interval.ensembl_id - ][threshold] - thresholds_dict[threshold].append(interval_coverage_at_threshold) - - # Collect intervals which are not completely covered at the custom threshold - if threshold == default_threshold and interval_coverage_at_threshold < 1: - nr_intervals_covered_under_custom_threshold += 1 - interval_ensembl_gene: str = ( - interval.ensembl_id - if interval.ensembl_id.startswith("ENSG") - else interval.ensembl_gene_id - ) - interval_hgnc_id: int = gene_ids_mapping[interval_ensembl_gene][ - "hgnc_id" - ] - interval_hgnc_symbol: str = gene_ids_mapping[interval_ensembl_gene][ - "hgnc_symbol" - ] - genes_covered_under_custom_threshold.add(interval_hgnc_symbol) - incomplete_coverages_rows.append( - ( - interval_hgnc_symbol, - interval_hgnc_id, - interval.ensembl_id, + for interval in sql_intervals: + + if hasattr(interval, "ensembl_ids"): + ensembl_ids = interval.ensembl_ids + else: + ensembl_ids = [interval.ensembl_id] + + for ensembl_id in ensembl_ids: + + if ensembl_id in interval_ids: + continue + for threshold in completeness_thresholds: + interval_coverage_at_threshold: float = intervals_coverage_completeness[ + ensembl_id + ][threshold] + thresholds_dict[threshold].append(interval_coverage_at_threshold) + + # Collect intervals which are not completely covered at the custom threshold + if ( + threshold == default_threshold + and interval_coverage_at_threshold < 1 + ): + nr_intervals_covered_under_custom_threshold += 1 + interval_ensembl_gene: str = ( + ensembl_id + if ensembl_id.startswith("ENSG") + else interval.ensembl_gene_id + ) + interval_hgnc_id: int = gene_ids_mapping[interval_ensembl_gene][ + "hgnc_id" + ] + interval_hgnc_symbol: str = gene_ids_mapping[interval_ensembl_gene][ + "hgnc_symbol" + ] + genes_covered_under_custom_threshold.add(interval_hgnc_symbol) + incomplete_coverages_rows.append( ( - { - "mane_select": interval.refseq_mane_select, - "mane_plus_clinical": interval.refseq_mane_plus_clinical, - "mrna": interval.refseq_mrna, - } - if isinstance(interval, SQLTranscript) - else {} - ), - sample_name, - round(interval_coverage_at_threshold * 100, 2), + interval_hgnc_symbol, + interval_hgnc_id, + ensembl_id, + ( + { + "mane_select": interval.refseq_mane_select, + "mane_plus_clinical": interval.refseq_mane_plus_clinical, + "mrna": interval.refseq_mrna, + } + if isinstance(interval, SQLTranscript) + else {} + ), + sample_name, + round(interval_coverage_at_threshold * 100, 2), + ) ) - ) - interval_ids.add(interval.ensembl_id) + interval_ids.add(ensembl_id) for threshold in completeness_thresholds: if thresholds_dict[threshold]: @@ -166,10 +194,9 @@ def get_gene_overview_stats( completeness_thresholds: List[int], ) -> Dict[str, list]: """Returns stats to be included in the gene overview page.""" - interval_ids_coords: List[Tuple[str, Tuple[str, int, int]]] = [ - (interval.ensembl_id, (interval.chromosome, interval.start, interval.stop)) - for interval in sql_intervals - ] + interval_ids_coords: List[Tuple[str, Tuple[str, int, int]]] = ( + set_interval_ids_coords(sql_intervals=sql_intervals) + ) interval_ids_coords = tuple( sort_interval_ids_coords(set(interval_ids_coords)) ) # removes duplicates and orders intervals by chromosome, start and stop diff --git a/src/chanjo2/meta/handle_load_intervals.py b/src/chanjo2/meta/handle_load_intervals.py index ab2a377a..4b400108 100644 --- a/src/chanjo2/meta/handle_load_intervals.py +++ b/src/chanjo2/meta/handle_load_intervals.py @@ -23,7 +23,6 @@ from chanjo2.models.pydantic_models import ( Builds, ExonBase, - GeneBase, IntervalType, TranscriptBase, ) @@ -45,8 +44,11 @@ def read_resource_lines(build: Builds, interval_type: IntervalType) -> Iterator[ def _replace_empty_cols(line: str, nr_expected_columns: int) -> List[Union[str, None]]: - """Split gene line into columns, replacing empty columns with None values.""" - cols = [None if col == "" else col.replace("HGNC:", "") for col in line.split("\t")] + """Split line into columns, replacing empty columns with None values.""" + cols = [ + None if cell == "" else cell.replace("HGNC:", "") for cell in line.split("\t") + ] + # Make sure that expected nr of cols are returned if last cols are blank cols += [None] * (nr_expected_columns - len(cols)) return cols @@ -57,6 +59,24 @@ async def update_genes( ) -> Optional[int]: """Loads genes into the database.""" + def update_or_insert_gene(session, sql_gene): + # Try to find the gene in the database + + existing_gene = ( + session.query(SQLGene) + .filter_by( + chromosome=sql_gene.chromosome, start=sql_gene.start, stop=sql_gene.stop + ) + .first() + ) + + if existing_gene: + # Gene exists, append the new ensembl_id to the existing ensembl_ids + existing_gene.ensembl_ids.append(sql_gene.ensembl_ids[0]) + else: + # Gene does not exist, add a new record + session.add(sql_gene) + LOG.info(f"Loading gene intervals. Genome build --> {build}") if lines is None: lines: Iterator[str] = read_resource_lines( @@ -82,16 +102,17 @@ async def update_genes( items: List = _replace_empty_cols(line=line, nr_expected_columns=len(header)) try: - gene: GeneBase = GeneBase( + sql_gene = SQLGene( build=build, chromosome=items[0], start=int(items[1]), stop=int(items[2]), - ensembl_id=items[3], + ensembl_ids=[items[3]], hgnc_symbol=items[4], hgnc_id=items[5], ) - genes_bulk.append(gene) + + update_or_insert_gene(session, sql_gene) # Update or insert the gene if len(genes_bulk) > MAX_NR_OF_RECORDS: bulk_insert_genes(db=session, genes=genes_bulk) diff --git a/src/chanjo2/meta/handle_report_contents.py b/src/chanjo2/meta/handle_report_contents.py index c709dfa1..fa4e9cdd 100644 --- a/src/chanjo2/meta/handle_report_contents.py +++ b/src/chanjo2/meta/handle_report_contents.py @@ -123,8 +123,9 @@ def get_report_data( ] gene_ids_mapping: Dict[str, dict] = { - gene.ensembl_id: {"hgnc_id": gene.hgnc_id, "hgnc_symbol": gene.hgnc_symbol} + ensembl_id: {"hgnc_id": gene.hgnc_id, "hgnc_symbol": gene.hgnc_symbol} for gene in genes + for ensembl_id in gene.ensembl_ids } sql_intervals: list = [] diff --git a/src/chanjo2/models/pydantic_models.py b/src/chanjo2/models/pydantic_models.py index 6ae60d6e..a940ff88 100644 --- a/src/chanjo2/models/pydantic_models.py +++ b/src/chanjo2/models/pydantic_models.py @@ -63,10 +63,13 @@ class Interval(IntervalBase): class GeneBase(IntervalBase): build: Builds - ensembl_id: str + ensembl_ids: List[str] hgnc_id: Optional[int] hgnc_symbol: Optional[str] + class Config: + orm_mode = True + class GeneQuery(BaseModel): build: Builds diff --git a/src/chanjo2/models/sql_models.py b/src/chanjo2/models/sql_models.py index 2f9a3778..3048c98e 100644 --- a/src/chanjo2/models/sql_models.py +++ b/src/chanjo2/models/sql_models.py @@ -1,7 +1,6 @@ from dataclasses import dataclass -from sqlalchemy import Column, Enum, ForeignKey, Index, Integer, String -from sqlalchemy.orm import relationship +from sqlalchemy import JSON, Column, Enum, ForeignKey, Index, Integer, String from chanjo2.dbutil import Base from chanjo2.models.pydantic_models import Builds @@ -28,14 +27,13 @@ class Gene(Base): chromosome = Column(String(6), nullable=False) start = Column(Integer, nullable=False) stop = Column(Integer, nullable=False) - ensembl_id = Column(String(24), nullable=False, index=True) + ensembl_ids = Column(JSON, nullable=False) hgnc_id = Column(Integer, nullable=True, index=True) hgnc_symbol = Column(String(64), nullable=True) build = Column( Enum(Builds, values_callable=lambda x: Builds.get_enum_values()), index=True ) __table_args__ = ( - Index("gene_idx_ensembl_id_build", "ensembl_id", "build"), Index("gene_idx_hgnc_id_build", "hgnc_id", "build"), Index("gene_idx_hgnc_symbol_build", "hgnc_symbol", "build"), ) @@ -56,18 +54,11 @@ class Transcript(Base): refseq_ncrna = Column(String(24), nullable=True) refseq_mane_select = Column(String(24), nullable=True, index=True) refseq_mane_plus_clinical = Column(String(24), nullable=True, index=True) - ensembl_gene_id = Column( - String(24), ForeignKey("genes.ensembl_id"), nullable=False, index=True - ) + ensembl_gene_id = Column(String(24), nullable=False, index=True) build = Column( Enum(Builds, values_callable=lambda x: Builds.get_enum_values()), index=True ) - genes = relationship( - "Gene", - primaryjoin="Transcript.ensembl_gene_id==Gene.ensembl_id", - ) - __table_args__ = ( Index("ensembl_gene_build_id", "ensembl_gene_id", "build", "ensembl_id"), ) @@ -85,18 +76,11 @@ class Exon(Base): rank_in_transcript = Column(Integer, nullable=False) ensembl_id = Column(String(24), nullable=False) ensembl_transcript_id = Column(String(24), nullable=False, index=True) - ensembl_gene_id = Column( - String(24), ForeignKey("genes.ensembl_id"), nullable=False, index=True - ) + ensembl_gene_id = Column(String(24), nullable=False, index=True) build = Column( Enum(Builds, values_callable=lambda x: Builds.get_enum_values()), index=True ) - genes = relationship( - "Gene", - primaryjoin="Exon.ensembl_gene_id==Gene.ensembl_id", - ) - __table_args__ = ( Index( "exon_idx_ensembl_gene_build_transcript", diff --git a/tests/conftest.py b/tests/conftest.py index e26c9732..9daab8bf 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -147,25 +147,25 @@ def demo_sql_genes() -> List[SQLGene]: """Return the 4 demo genes present in the demo gene panel as SQLGenes.""" gene_dicts = [ { - "ensembl_id": "ENSG00000228716", + "ensembl_ids": ["ENSG00000228716"], "chromosome": "5", "start": 79922047, "stop": 79950802, }, { - "ensembl_id": "ENSG00000110195", + "ensembl_ids": ["ENSG00000110195"], "chromosome": "11", "start": 71900602, "stop": 71907345, }, { - "ensembl_id": "ENSG00000177000", + "ensembl_ids": ["ENSG00000177000"], "chromosome": "1", "start": 11845780, "stop": 11866977, }, { - "ensembl_id": "ENSG00000076351", + "ensembl_ids": ["ENSG00000076351"], "chromosome": "17", "start": 26721661, "stop": 26734215,