Skip to content

Commit

Permalink
allow customized gene_id and gene_name keys in make_gene_matrix.
Browse files Browse the repository at this point in the history
  • Loading branch information
kaizhang committed Dec 28, 2023
1 parent b5376c5 commit 4ab7506
Show file tree
Hide file tree
Showing 7 changed files with 118 additions and 96 deletions.
2 changes: 1 addition & 1 deletion snapatac2-core/src/preprocessing/count_data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ pub use crate::preprocessing::qc;
pub use import::{import_fragments, import_contacts};
pub use coverage::{GenomeCount, ContactMap, FragmentType, fragments_to_insertions};
pub use genome::{
Transcript, Promoters, FeatureCounter, TranscriptCount, GeneCount,
TranscriptParserOptions, Transcript, Promoters, FeatureCounter, TranscriptCount, GeneCount,
read_transcripts_from_gff, read_transcripts_from_gtf,
ChromSizes, ChromValueIter, ChromValues, GenomeBaseIndex,
};
Expand Down
164 changes: 79 additions & 85 deletions snapatac2-core/src/preprocessing/count_data/genome.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
//! genomic feature counts in Rust.
use noodles::{core::Position, gff, gff::record::Strand, gtf};
use bed_utils::bed::tree::GenomeRegions;
use anyhow::Result;
use anyhow::{Result, bail};
use std::{collections::{BTreeMap, HashMap}, fmt::Debug, io::BufRead};
use indexmap::map::IndexMap;
use bed_utils::bed::{GenomicRange, BEDLike, tree::SparseCoverage};
Expand Down Expand Up @@ -49,90 +49,84 @@ pub struct Transcript {
pub strand: Strand,
}

impl TryFrom<gtf::Record> for Transcript {
type Error = anyhow::Error;
pub struct TranscriptParserOptions {
pub transcript_name_key: String,
pub transcript_id_key: String,
pub gene_name_key: String,
pub gene_id_key: String,
}

fn try_from(record: gtf::Record) -> Result<Self, Self::Error> {
if record.ty() != "transcript" {
return Err(anyhow::anyhow!("record is not a transcript"));
impl<'a> Default for TranscriptParserOptions {
fn default() -> Self {
Self {
transcript_name_key: "transcript_name".to_string(),
transcript_id_key: "transcript_id".to_string(),
gene_name_key: "gene_name".to_string(),
gene_id_key: "gene_id".to_string(),
}

let err_msg =
|x: &str| -> String { format!("failed to find '{}' in record: {}", x, record) };

let left = record.start();
let right = record.end();
let attributes: HashMap<&str, &str> = record
.attributes()
.iter()
.map(|x| (x.key(), x.value()))
.collect();
Ok(Transcript {
transcript_name: attributes.get("transcript_name").map(|x| x.to_string()),
transcript_id: attributes
.get("transcript_id")
.expect(&err_msg("transcript_id"))
.to_string(),
gene_name: attributes
.get("gene_name")
.expect(&err_msg("gene_name"))
.to_string(),
gene_id: attributes
.get("gene_id")
.expect(&err_msg("gene_id"))
.to_string(),
is_coding: attributes
.get("transcript_type")
.map(|x| *x == "protein_coding"),
chrom: record.reference_sequence_name().to_string(),
left,
right,
strand: match record.strand() {
None => Strand::None,
Some(gtf::record::Strand::Forward) => Strand::Forward,
Some(gtf::record::Strand::Reverse) => Strand::Reverse,
},
})
}
}

impl TryFrom<gff::Record> for Transcript {
type Error = anyhow::Error;

fn try_from(record: gff::Record) -> Result<Self, Self::Error> {
if record.ty() != "transcript" {
return Err(anyhow::anyhow!("record is not a transcript"));
}
fn from_gtf(record: gtf::Record, options: &TranscriptParserOptions) -> Result<Transcript> {
if record.ty() != "transcript" {
bail!("record is not a transcript");
}

let left = record.start();
let right = record.end();
let attributes: HashMap<&str, &str> = record
.attributes()
.iter()
.map(|x| (x.key(), x.value()))
.collect();
let get_attr = |key: &str| -> String {
attributes.get(key).expect(&format!("failed to find '{}' in record: {}", key, record)) .to_string()
};

Ok(Transcript {
transcript_name: attributes.get(options.transcript_name_key.as_str()).map(|x| x.to_string()),
transcript_id: get_attr(options.transcript_id_key.as_str()),
gene_name: get_attr(options.gene_name_key.as_str()),
gene_id: get_attr(options.gene_id_key.as_str()),
is_coding: attributes
.get("transcript_type")
.map(|x| *x == "protein_coding"),
chrom: record.reference_sequence_name().to_string(),
left,
right,
strand: match record.strand() {
None => Strand::None,
Some(gtf::record::Strand::Forward) => Strand::Forward,
Some(gtf::record::Strand::Reverse) => Strand::Reverse,
},
})
}

let err_msg =
|x: &str| -> String { format!("failed to find '{}' in record: {}", x, record) };

let left = record.start();
let right = record.end();
let attributes = record.attributes();
Ok(Transcript {
transcript_name: attributes.get("transcript_name").map(|x| x.to_string()),
transcript_id: attributes
.get("transcript_id")
.expect(&err_msg("transcript_id"))
.to_string(),
gene_name: attributes
.get("gene_name")
.expect(&err_msg("gene_name"))
.to_string(),
gene_id: attributes
.get("gene_id")
.expect(&err_msg("gene_id"))
.to_string(),
is_coding: attributes
.get("transcript_type")
.map(|x| x.as_string() == Some("protein_coding")),
chrom: record.reference_sequence_name().to_string(),
left,
right,
strand: record.strand(),
})
}
fn from_gff(record: gff::Record, options: &TranscriptParserOptions) -> Result<Transcript> {
if record.ty() != "transcript" {
bail!("record is not a transcript");
}

let left = record.start();
let right = record.end();
let attributes = record.attributes();
let get_attr = |key: &str| -> String {
attributes.get(key).expect(&format!("failed to find '{}' in record: {}", key, record)) .to_string()
};

Ok(Transcript {
transcript_name: attributes.get(options.transcript_name_key.as_str()).map(|x| x.to_string()),
transcript_id: get_attr(options.transcript_id_key.as_str()),
gene_name: get_attr(options.gene_name_key.as_str()),
gene_id: get_attr(options.gene_id_key.as_str()),
is_coding: attributes
.get("transcript_type")
.map(|x| x.as_string() == Some("protein_coding")),
chrom: record.reference_sequence_name().to_string(),
left,
right,
strand: record.strand(),
})
}

impl Transcript {
Expand All @@ -147,28 +141,28 @@ impl Transcript {
}
}

pub fn read_transcripts_from_gtf<R>(input: R) -> Result<Vec<Transcript>>
pub fn read_transcripts_from_gtf<R>(input: R, options: &TranscriptParserOptions) -> Result<Vec<Transcript>>
where
R: BufRead,
{
gtf::Reader::new(input)
.records()
.try_fold(Vec::new(), |mut acc, rec| {
if let Ok(transcript) = rec?.try_into() {
if let Ok(transcript) = from_gtf(rec?, options) {
acc.push(transcript);
}
Ok(acc)
})
}

pub fn read_transcripts_from_gff<R>(input: R) -> Result<Vec<Transcript>>
pub fn read_transcripts_from_gff<R>(input: R, options: &TranscriptParserOptions) -> Result<Vec<Transcript>>
where
R: BufRead,
{
gff::Reader::new(input)
.records()
.try_fold(Vec::new(), |mut acc, rec| {
if let Ok(transcript) = rec?.try_into() {
if let Ok(transcript) = from_gff(rec?, options) {
acc.push(transcript);
}
Ok(acc)
Expand Down Expand Up @@ -824,11 +818,11 @@ mod tests {
strand: Strand::Forward,
};
assert_eq!(
read_transcripts_from_gff(gff.as_bytes()).unwrap()[0],
read_transcripts_from_gff(gff.as_bytes(), &Default::default()).unwrap()[0],
expected
);
assert_eq!(
read_transcripts_from_gtf(gtf.as_bytes()).unwrap()[0],
read_transcripts_from_gtf(gtf.as_bytes(), &Default::default()).unwrap()[0],
expected
);
}
Expand Down
20 changes: 18 additions & 2 deletions snapatac2-python/snapatac2/preprocessing/_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -551,6 +551,10 @@ def make_gene_matrix(
chunk_size: int = 500,
use_x: bool = False,
id_type: Literal['gene', 'transcript'] = "gene",
transcript_name_key: str = "transcript_name",
transcript_id_key: str = "transcript_id",
gene_name_key: str = "gene_name",
gene_id_key: str = "gene_id",
min_frag_size: int | None = None,
max_frag_size: int | None = None,
) -> internal.AnnData:
Expand Down Expand Up @@ -582,6 +586,14 @@ def make_gene_matrix(
Otherwise the `.obsm['insertion']` is used.
id_type
"gene" or "transcript".
transcript_name_key
The key of the transcript name in the gene annotation file.
transcript_id_key
The key of the transcript id in the gene annotation file.
gene_name_key
The key of the gene name in the gene annotation file.
gene_id_key
The key of the gene id in the gene annotation file.
min_frag_size
Minimum fragment size to include.
max_frag_size
Expand Down Expand Up @@ -612,7 +624,9 @@ def make_gene_matrix(
gene_anno = gene_anno.annotation

if inplace:
internal.mk_gene_matrix(adata, gene_anno, chunk_size, use_x, id_type, min_frag_size, max_frag_size, None)
internal.mk_gene_matrix(adata, gene_anno, chunk_size, use_x, id_type,
transcript_name_key, transcript_id_key, gene_name_key, gene_id_key,
min_frag_size, max_frag_size, None)
else:
if file is None:
if adata.isbacked:
Expand All @@ -621,7 +635,9 @@ def make_gene_matrix(
out = AnnData(obs=adata.obs[:])
else:
out = internal.AnnData(filename=file, backend=backend, obs=adata.obs[:])
internal.mk_gene_matrix(adata, gene_anno, chunk_size, use_x, id_type, min_frag_size, max_frag_size, out)
internal.mk_gene_matrix(adata, gene_anno, chunk_size, use_x, id_type,
transcript_name_key, transcript_id_key, gene_name_key, gene_id_key,
min_frag_size, max_frag_size, out)
return out

def filter_cells(
Expand Down
2 changes: 1 addition & 1 deletion snapatac2-python/src/network.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ pub(crate) fn link_region_to_gene(
) -> HashMap<(String, String), Vec<(String, String, u64)>>
{
let promoters = Promoters::new(
read_transcripts(annot_fl).into_iter()
read_transcripts(annot_fl, &Default::default()).into_iter()
.filter(|x| if coding_gene_only { x.is_coding.unwrap_or(true) } else { true })
.collect(),
upstream,
Expand Down
13 changes: 12 additions & 1 deletion snapatac2-python/src/preprocessing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ use crate::utils::*;

use anndata::Backend;
use anndata_hdf5::H5;
use snapatac2_core::preprocessing::count_data::TranscriptParserOptions;
use std::io::{BufRead, BufReader};
use std::path::PathBuf;
use std::{str::FromStr, collections::BTreeMap, ops::Deref, collections::HashSet};
Expand Down Expand Up @@ -263,12 +264,22 @@ pub(crate) fn mk_gene_matrix(
chunk_size: usize,
use_x: bool,
id_type: &str,
transcript_name_key: String,
transcript_id_key: String,
gene_name_key: String,
gene_id_key: String,
min_fragment_size: Option<u64>,
max_fragment_size: Option<u64>,
out: Option<AnnDataLike>,
) -> Result<()>
{
let transcripts = read_transcripts(gff_file);
let options = TranscriptParserOptions {
transcript_name_key,
transcript_id_key,
gene_name_key,
gene_id_key,
};
let transcripts = read_transcripts(gff_file, &options);
macro_rules! run {
($data:expr) => {
if let Some(out) = out {
Expand Down
11 changes: 6 additions & 5 deletions snapatac2-python/src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ use pyo3::{
PyResult, Python,
};
use numpy::{Element, PyReadonlyArrayDyn, PyReadonlyArray, Ix1, Ix2, PyArray, IntoPyArray};
use snapatac2_core::preprocessing::count_data::TranscriptParserOptions;
use snapatac2_core::preprocessing::{Transcript, read_transcripts_from_gff, read_transcripts_from_gtf};
use snapatac2_core::utils;

Expand Down Expand Up @@ -257,18 +258,18 @@ pub(crate) fn kmeans<'py>(
Ok(model.predict(observations).targets.into_pyarray(py))
}

pub fn read_transcripts<P: AsRef<std::path::Path>>(file_path: P) -> Vec<Transcript> {
pub fn read_transcripts<P: AsRef<std::path::Path>>(file_path: P, options: &TranscriptParserOptions) -> Vec<Transcript> {
let path = if file_path.as_ref().extension().unwrap() == "gz" {
file_path.as_ref().file_stem().unwrap().as_ref()
} else {
file_path.as_ref()
};
if path.extension().unwrap() == "gff" {
read_transcripts_from_gff(BufReader::new(utils::open_file_for_read(file_path))).unwrap()
read_transcripts_from_gff(BufReader::new(utils::open_file_for_read(file_path)), options).unwrap()
} else if path.extension().unwrap() == "gtf" {
read_transcripts_from_gtf(BufReader::new(utils::open_file_for_read(file_path))).unwrap()
read_transcripts_from_gtf(BufReader::new(utils::open_file_for_read(file_path)), options).unwrap()
} else {
read_transcripts_from_gff(BufReader::new(utils::open_file_for_read(file_path.as_ref())))
.unwrap_or_else(|_| read_transcripts_from_gtf(BufReader::new(utils::open_file_for_read(file_path))).unwrap())
read_transcripts_from_gff(BufReader::new(utils::open_file_for_read(file_path.as_ref())), options)
.unwrap_or_else(|_| read_transcripts_from_gtf(BufReader::new(utils::open_file_for_read(file_path)), options).unwrap())
}
}
2 changes: 1 addition & 1 deletion snapatac2-python/tests/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,4 +66,4 @@ def test_in_memory():
chrom_sizes=snap.genome.hg38,
sorted_by_barcode=False,
)
pipeline(data)
pipeline(data)

0 comments on commit 4ab7506

Please sign in to comment.