Skip to content

Commit

Permalink
Merge pull request #47 from omicsedge/fix/determine-last-chunk
Browse files: browse the repository at this point in the history
fix: determine last chunk
  • Loading branch information
sandra-selfdecode authored Oct 2, 2024
2 parents cb6a966 + 16c1f62 commit 014232b
Showing 1 changed file with 16 additions and 13 deletions.
29 changes: 16 additions & 13 deletions modules/sparse_ref_panel.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -107,14 +107,14 @@ def __getitem__(self, key: Tuple[Union[int, list, slice]]) -> sparse.csc_matrix:
]
)
row_stop = (
min([key[0].stop, self.n_variants - 1])
min([key[0].stop, self.n_variants])
if key[0].stop is not None
else self.n_variants - 1
else self.n_variants
)
chunks = list(
range(
(key[0].start or 0) // self.chunk_size,
row_stop // self.chunk_size + 1,
(row_stop - 1) // self.chunk_size + 1,
chunk_step,
)
)
Expand All @@ -123,16 +123,19 @@ def __getitem__(self, key: Tuple[Union[int, list, slice]]) -> sparse.csc_matrix:

if len(chunks) == 1:
return self._load_haplotypes(chunks[0])[
key[0].start
% self.chunk_size : row_stop
% self.chunk_size : key[0].step,
key[0].start % self.chunk_size : row_stop % self.chunk_size
or self.chunk_size : key[0].step,
key[1],
]

slices = (
[slice(key[0].start % self.chunk_size, None, key[0].step)]
+ [slice(None, None, key[0].step)] * (len(chunks) - 2)
+ [slice(None, row_stop % self.chunk_size, key[0].step)]
+ [
slice(
None, row_stop % self.chunk_size or self.chunk_size, key[0].step
)
]
)
return sparse.vstack(
[
Expand Down Expand Up @@ -238,7 +241,7 @@ def _load_sample_ids(self) -> List[str]:
print(f"An error occurred: {str(e)}")
return []

def determine_start_position(self, vcf_path) -> int:
def _determine_start_position(self, vcf_path) -> int:
cmd = f'bcftools query -f "%POS\n" {vcf_path} | head -1'
result = subprocess.run(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, shell=True
Expand All @@ -247,12 +250,12 @@ def determine_start_position(self, vcf_path) -> int:
raise ValueError(f"Error executing bcftools: {result.stderr}")
return int(result.stdout.strip())

def determine_chunk_ranges(self, vcf_path, chr_length, num_variants):
def _determine_chunk_ranges(self, vcf_path, chr_length, num_variants):
chr_length = int(chr_length)
num_variants = int(num_variants)
bp_per_variant = chr_length / num_variants
bp_per_chunk = bp_per_variant * self.chunk_size
current_start = self.determine_start_position(vcf_path)
current_start = self._determine_start_position(vcf_path)
ranges = []
while current_start < chr_length:
end = min(current_start + bp_per_chunk, chr_length)
Expand All @@ -263,7 +266,7 @@ def determine_chunk_ranges(self, vcf_path, chr_length, num_variants):
ranges[-1] = (ranges[-1][0], 999_999_999)
return ranges

def get_vcf_stats(self, vcf_path):
def _get_vcf_stats(self, vcf_path):
cmd = ["bcftools", "index", "--stats", vcf_path]
result = subprocess.run(
cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True
Expand All @@ -279,8 +282,8 @@ def get_vcf_stats(self, vcf_path):
return chrom, chr_length, num_variants

def _ingest_variants(self, vcf_path: str, threads: int = os.cpu_count()):
chrom, chr_length, num_variants = self.get_vcf_stats(vcf_path)
chunk_ranges = self.determine_chunk_ranges(vcf_path, chr_length, num_variants)
chrom, chr_length, num_variants = self._get_vcf_stats(vcf_path)
chunk_ranges = self._determine_chunk_ranges(vcf_path, chr_length, num_variants)

def process_chunk(args):
start, end, chrom, vcf_path = args
Expand Down

0 comments on commit 014232b

Please sign in to comment.