Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modify genes table converting the ensembl_id string field to ensembl_ids: an array of strings #394

Merged
merged 19 commits into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
### Changed
- Updated several libraries including schug (now v.7)
- Update project's Python version to 3.9
- Modified the structure of the database table `genes`, converting the `ensembl_id` string field to `ensembl_ids`, an array of strings. This change addresses a recent change in MySQL behavior (see https://bugs.mysql.com/bug.php?id=114838)
### Fixed
- The MariaDB healthcheck step in docker-compose-mysql.yml, which prevented the demo app from starting

Expand Down
50 changes: 37 additions & 13 deletions src/chanjo2/crud/intervals.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
from typing import List, Optional, Union

from sqlalchemy import delete, or_
from sqlalchemy import delete, or_, text
from sqlalchemy.orm import Session, query
from sqlalchemy.sql.expression import Delete

Expand Down Expand Up @@ -71,9 +71,19 @@ def get_genes(
) -> List[SQLGene]:
"""Return genes according to specified fields."""
genes: query.Query = db.query(SQLGene)

if ensembl_ids:
genes: query.Query = genes.filter(SQLGene.ensembl_id.in_(ensembl_ids))
ensembl_ids_placeholder = ", ".join(f"'{e}'" for e in ensembl_ids)
genes = genes.filter(
text(
f"""
EXISTS (
SELECT 1
FROM json_each(genes.ensembl_ids)
WHERE value IN ({ensembl_ids_placeholder})
)
"""
)
)
elif hgnc_ids:
genes: query.Query = genes.filter(SQLGene.hgnc_id.in_(hgnc_ids))
elif hgnc_symbols:
Expand Down Expand Up @@ -151,7 +161,9 @@ def set_sql_intervals(
ensembl_ids=None,
hgnc_ids=None,
hgnc_symbols=None,
ensembl_gene_ids=[gene.ensembl_id for gene in genes],
ensembl_gene_ids=[
ensembl_id for gene in genes for ensembl_id in gene.ensembl_ids
],
limit=None,
transcript_tags=transcript_tags,
)
Expand All @@ -171,27 +183,39 @@ def get_gene_intervals(
) -> List[Union[SQLTranscript, SQLExon]]:
"""Retrieve transcripts or exons from a list of genes."""

intervals: query.Query = db.query(interval_type).join(SQLGene)
intervals = db.query(interval_type).filter(interval_type.build == build)

def get_ensembl_gene_ids_from_gene_filter(
    filter_value: List[Union[str, int]], filter_column: str
) -> List[str]:
    """Collect the Ensembl gene IDs of all genes matching a column filter.

    Queries the ``ensembl_ids`` of every gene whose ``filter_column`` value is
    contained in ``filter_value`` and flattens them into a single list.
    """
    # NOTE(review): `filter_column` is annotated as `str`, but `.in_()` is
    # called on it and callers pass SQLGene column attributes -- confirm and
    # adjust the hint.
    matching_rows = (
        db.query(SQLGene.ensembl_ids).filter(filter_column.in_(filter_value)).all()
    )
    flattened_ids = []
    for row in matching_rows:
        # Each row carries the list of Ensembl IDs stored for one gene.
        flattened_ids.extend(row.ensembl_ids)
    return flattened_ids

if ensembl_ids:
intervals: query.Query = intervals.filter(
interval_type.ensembl_id.in_(ensembl_ids)
)
elif ensembl_gene_ids:
intervals: query.Query = intervals.filter(
interval_type.ensembl_gene_id.in_(ensembl_gene_ids)
)
elif hgnc_ids:
intervals: query.Query = intervals.filter(SQLGene.hgnc_id.in_(hgnc_ids))
ensembl_gene_ids = get_ensembl_gene_ids_from_gene_filter(
hgnc_ids, SQLGene.hgnc_id
)
elif hgnc_symbols:
intervals: query.Query = intervals.filter(SQLGene.hgnc_symbol.in_(hgnc_symbols))
ensembl_gene_ids = get_ensembl_gene_ids_from_gene_filter(
hgnc_symbols, SQLGene.hgnc_symbol
)
if ensembl_gene_ids:
intervals = intervals.filter(
interval_type.ensembl_gene_id.in_(ensembl_gene_ids)
)

if interval_type == SQLTranscript and transcript_tags:
intervals = _filter_transcripts_by_tag(
transcripts=intervals, transcript_tags=transcript_tags
)

intervals: query.Query = intervals.filter(interval_type.build == build)

if limit:
return intervals.limit(limit).all()

Expand Down
10 changes: 5 additions & 5 deletions src/chanjo2/endpoints/coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
get_d4tools_chromosome_mean_coverage,
get_d4tools_intervals_mean_coverage,
)
from chanjo2.meta.handle_d4 import get_samples_sex_metrics
from chanjo2.meta.handle_d4 import get_samples_sex_metrics, set_interval_ids_coords
from chanjo2.meta.handle_report_contents import INTERVAL_TYPE_SQL_TYPE, get_mean
from chanjo2.models import SQLGene
from chanjo2.models.pydantic_models import (
Expand Down Expand Up @@ -167,10 +167,10 @@ def d4_genes_condensed_summary(
detail=WRONG_COVERAGE_FILE_MSG,
)

interval_ids_coords: List[Tuple[str, Tuple[str, int, int]]] = [
(interval.ensembl_id, (interval.chromosome, interval.start, interval.stop))
for interval in sql_intervals
]
interval_ids_coords: List[Tuple[str, Tuple[str, int, int]]] = (
set_interval_ids_coords(sql_intervals=sql_intervals)
)

# Sort intervals by chrom, start & stop
interval_ids_coords = sort_interval_ids_coords(interval_ids_coords)

Expand Down
129 changes: 78 additions & 51 deletions src/chanjo2/meta/handle_d4.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,26 @@
LOG = logging.getLogger(__name__)


def set_interval_ids_coords(
    sql_intervals: List[Union[SQLGene, SQLTranscript, SQLExon]]
) -> List[Tuple[str, Tuple[str, int, int]]]:
    """Return (ensembl_id, (chromosome, start, stop)) tuples for SQL intervals.

    Intervals exposing an ``ensembl_ids`` list (genes) yield one tuple per
    Ensembl ID, all sharing the same coordinates. Intervals exposing a single
    ``ensembl_id`` (transcripts, exons) yield exactly one tuple.

    Args:
        sql_intervals: genes, transcripts or exons; may be empty.

    Returns:
        A list of ``(ensembl_id, (chromosome, start, stop))`` tuples, in the
        order of the input intervals (and of each gene's ``ensembl_ids``).
    """
    interval_ids_coords: List[Tuple[str, Tuple[str, int, int]]] = []

    for interval in sql_intervals or []:
        coords = (interval.chromosome, interval.start, interval.stop)

        # Decide per interval rather than from the first element only, so a
        # mixed list cannot raise AttributeError. This mirrors the
        # hasattr-based dispatch in get_report_sample_interval_coverage.
        if hasattr(interval, "ensembl_ids"):
            ensembl_ids = interval.ensembl_ids or []  # tolerate a NULL array
        else:
            ensembl_ids = [interval.ensembl_id]

        interval_ids_coords.extend(
            (ensembl_id, coords) for ensembl_id in ensembl_ids
        )

    return interval_ids_coords


def get_report_sample_interval_coverage(
d4_file_path: str,
sample_name: str,
Expand All @@ -31,11 +51,9 @@ def get_report_sample_interval_coverage(
"""Compute stats to populate a coverage report for one sample."""

# Compute intervals coverage completeness
interval_ids_coords: List[Tuple[str, Tuple[str, int, int]]] = [
(interval.ensembl_id, (interval.chromosome, interval.start, interval.stop))
for interval in sql_intervals
]

interval_ids_coords: List[Tuple[str, Tuple[str, int, int]]] = (
set_interval_ids_coords(sql_intervals=sql_intervals)
)
interval_ids_coords = sort_interval_ids_coords(interval_ids_coords)

# Compute intervals coverage
Expand All @@ -57,51 +75,61 @@ def get_report_sample_interval_coverage(
nr_intervals_covered_under_custom_threshold: int = 0
genes_covered_under_custom_threshold = set()

for interval_nr, interval in enumerate(sql_intervals):

if interval.ensembl_id in interval_ids:
continue
for threshold in completeness_thresholds:
interval_coverage_at_threshold: float = intervals_coverage_completeness[
interval.ensembl_id
][threshold]
thresholds_dict[threshold].append(interval_coverage_at_threshold)

# Collect intervals which are not completely covered at the custom threshold
if threshold == default_threshold and interval_coverage_at_threshold < 1:
nr_intervals_covered_under_custom_threshold += 1
interval_ensembl_gene: str = (
interval.ensembl_id
if interval.ensembl_id.startswith("ENSG")
else interval.ensembl_gene_id
)
interval_hgnc_id: int = gene_ids_mapping[interval_ensembl_gene][
"hgnc_id"
]
interval_hgnc_symbol: str = gene_ids_mapping[interval_ensembl_gene][
"hgnc_symbol"
]
genes_covered_under_custom_threshold.add(interval_hgnc_symbol)
incomplete_coverages_rows.append(
(
interval_hgnc_symbol,
interval_hgnc_id,
interval.ensembl_id,
for interval in sql_intervals:

if hasattr(interval, "ensembl_ids"):
ensembl_ids = interval.ensembl_ids
else:
ensembl_ids = [interval.ensembl_id]

for ensembl_id in ensembl_ids:

if ensembl_id in interval_ids:
continue
for threshold in completeness_thresholds:
interval_coverage_at_threshold: float = intervals_coverage_completeness[
ensembl_id
][threshold]
thresholds_dict[threshold].append(interval_coverage_at_threshold)

# Collect intervals which are not completely covered at the custom threshold
if (
threshold == default_threshold
and interval_coverage_at_threshold < 1
):
nr_intervals_covered_under_custom_threshold += 1
interval_ensembl_gene: str = (
ensembl_id
if ensembl_id.startswith("ENSG")
else interval.ensembl_gene_id
)
interval_hgnc_id: int = gene_ids_mapping[interval_ensembl_gene][
"hgnc_id"
]
interval_hgnc_symbol: str = gene_ids_mapping[interval_ensembl_gene][
"hgnc_symbol"
]
genes_covered_under_custom_threshold.add(interval_hgnc_symbol)
incomplete_coverages_rows.append(
(
{
"mane_select": interval.refseq_mane_select,
"mane_plus_clinical": interval.refseq_mane_plus_clinical,
"mrna": interval.refseq_mrna,
}
if isinstance(interval, SQLTranscript)
else {}
),
sample_name,
round(interval_coverage_at_threshold * 100, 2),
interval_hgnc_symbol,
interval_hgnc_id,
ensembl_id,
(
{
"mane_select": interval.refseq_mane_select,
"mane_plus_clinical": interval.refseq_mane_plus_clinical,
"mrna": interval.refseq_mrna,
}
if isinstance(interval, SQLTranscript)
else {}
),
sample_name,
round(interval_coverage_at_threshold * 100, 2),
)
)
)

interval_ids.add(interval.ensembl_id)
interval_ids.add(ensembl_id)

for threshold in completeness_thresholds:
if thresholds_dict[threshold]:
Expand Down Expand Up @@ -166,10 +194,9 @@ def get_gene_overview_stats(
completeness_thresholds: List[int],
) -> Dict[str, list]:
"""Returns stats to be included in the gene overview page."""
interval_ids_coords: List[Tuple[str, Tuple[str, int, int]]] = [
(interval.ensembl_id, (interval.chromosome, interval.start, interval.stop))
for interval in sql_intervals
]
interval_ids_coords: List[Tuple[str, Tuple[str, int, int]]] = (
set_interval_ids_coords(sql_intervals=sql_intervals)
)
interval_ids_coords = tuple(
sort_interval_ids_coords(set(interval_ids_coords))
) # removes duplicates and orders intervals by chromosome, start and stop
Expand Down
33 changes: 27 additions & 6 deletions src/chanjo2/meta/handle_load_intervals.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
from chanjo2.models.pydantic_models import (
Builds,
ExonBase,
GeneBase,
IntervalType,
TranscriptBase,
)
Expand All @@ -45,8 +44,11 @@ def read_resource_lines(build: Builds, interval_type: IntervalType) -> Iterator[


def _replace_empty_cols(line: str, nr_expected_columns: int) -> List[Union[str, None]]:
"""Split gene line into columns, replacing empty columns with None values."""
cols = [None if col == "" else col.replace("HGNC:", "") for col in line.split("\t")]
"""Split line into columns, replacing empty columns with None values."""
cols = [
None if cell == "" else cell.replace("HGNC:", "") for cell in line.split("\t")
]

# Make sure that expected nr of cols are returned if last cols are blank
cols += [None] * (nr_expected_columns - len(cols))
return cols
Expand All @@ -57,6 +59,24 @@ async def update_genes(
) -> Optional[int]:
"""Loads genes into the database."""

def update_or_insert_gene(session, sql_gene):
    """Insert ``sql_gene``, or merge its Ensembl IDs into an existing gene.

    A gene is considered already present when a record with the same build,
    chromosome, start and stop exists. In that case the Ensembl IDs carried
    by ``sql_gene`` that are not yet stored are added to the existing record.
    Otherwise ``sql_gene`` is added to the session as a new record.
    """
    # Match on build as well as coordinates, so genes sharing coordinates
    # across genome builds are not merged into one record.
    existing_gene = (
        session.query(SQLGene)
        .filter_by(
            build=sql_gene.build,
            chromosome=sql_gene.chromosome,
            start=sql_gene.start,
            stop=sql_gene.stop,
        )
        .first()
    )

    if existing_gene is None:
        # Gene does not exist, add a new record
        session.add(sql_gene)
        return

    known_ids = list(existing_gene.ensembl_ids or [])
    new_ids = [
        ensembl_id
        for ensembl_id in sql_gene.ensembl_ids
        if ensembl_id not in known_ids
    ]
    if new_ids:
        # Assign a new list instead of appending in place: in-place mutation
        # of a JSON/array column is only persisted when the column uses a
        # Mutable wrapper. NOTE(review): column definition not visible here.
        existing_gene.ensembl_ids = known_ids + new_ids

LOG.info(f"Loading gene intervals. Genome build --> {build}")
if lines is None:
lines: Iterator[str] = read_resource_lines(
Expand All @@ -82,16 +102,17 @@ async def update_genes(
items: List = _replace_empty_cols(line=line, nr_expected_columns=len(header))

try:
gene: GeneBase = GeneBase(
sql_gene = SQLGene(
build=build,
chromosome=items[0],
start=int(items[1]),
stop=int(items[2]),
ensembl_id=items[3],
ensembl_ids=[items[3]],
hgnc_symbol=items[4],
hgnc_id=items[5],
)
genes_bulk.append(gene)

update_or_insert_gene(session, sql_gene) # Update or insert the gene

if len(genes_bulk) > MAX_NR_OF_RECORDS:
bulk_insert_genes(db=session, genes=genes_bulk)
Expand Down
3 changes: 2 additions & 1 deletion src/chanjo2/meta/handle_report_contents.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,8 +123,9 @@ def get_report_data(
]

gene_ids_mapping: Dict[str, dict] = {
gene.ensembl_id: {"hgnc_id": gene.hgnc_id, "hgnc_symbol": gene.hgnc_symbol}
ensembl_id: {"hgnc_id": gene.hgnc_id, "hgnc_symbol": gene.hgnc_symbol}
for gene in genes
for ensembl_id in gene.ensembl_ids
}

sql_intervals: list = []
Expand Down
5 changes: 4 additions & 1 deletion src/chanjo2/models/pydantic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,13 @@ class Interval(IntervalBase):

class GeneBase(IntervalBase):
build: Builds
ensembl_id: str
ensembl_ids: List[str]
hgnc_id: Optional[int]
hgnc_symbol: Optional[str]

class Config:
orm_mode = True


class GeneQuery(BaseModel):
build: Builds
Expand Down
Loading
Loading