diff --git a/modules/sparse_ref_panel.py b/modules/sparse_ref_panel.py index 0ce7ced..6faecb2 100644 --- a/modules/sparse_ref_panel.py +++ b/modules/sparse_ref_panel.py @@ -107,14 +107,14 @@ def __getitem__(self, key: Tuple[Union[int, list, slice]]) -> sparse.csc_matrix: ] ) row_stop = ( - min([key[0].stop, self.n_variants - 1]) + min([key[0].stop, self.n_variants]) if key[0].stop is not None - else self.n_variants - 1 + else self.n_variants ) chunks = list( range( (key[0].start or 0) // self.chunk_size, - row_stop // self.chunk_size + 1, + (row_stop - 1) // self.chunk_size + 1, chunk_step, ) ) @@ -123,16 +123,19 @@ def __getitem__(self, key: Tuple[Union[int, list, slice]]) -> sparse.csc_matrix: if len(chunks) == 1: return self._load_haplotypes(chunks[0])[ - key[0].start - % self.chunk_size : row_stop - % self.chunk_size : key[0].step, + key[0].start % self.chunk_size : row_stop % self.chunk_size + or self.chunk_size : key[0].step, key[1], ] slices = ( [slice(key[0].start % self.chunk_size, None, key[0].step)] + [slice(None, None, key[0].step)] * (len(chunks) - 2) - + [slice(None, row_stop % self.chunk_size, key[0].step)] + + [ + slice( + None, row_stop % self.chunk_size or self.chunk_size, key[0].step + ) + ] ) return sparse.vstack( [ @@ -238,7 +241,7 @@ def _load_sample_ids(self) -> List[str]: print(f"An error occurred: {str(e)}") return [] - def determine_start_position(self, vcf_path) -> int: + def _determine_start_position(self, vcf_path) -> int: cmd = f'bcftools query -f "%POS\n" {vcf_path} | head -1' result = subprocess.run( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, shell=True @@ -247,12 +250,12 @@ def determine_start_position(self, vcf_path) -> int: raise ValueError(f"Error executing bcftools: {result.stderr}") return int(result.stdout.strip()) - def determine_chunk_ranges(self, vcf_path, chr_length, num_variants): + def _determine_chunk_ranges(self, vcf_path, chr_length, num_variants): chr_length = int(chr_length) num_variants = int(num_variants) bp_per_variant = chr_length / num_variants bp_per_chunk = bp_per_variant * self.chunk_size - current_start = self.determine_start_position(vcf_path) + current_start = self._determine_start_position(vcf_path) ranges = [] while current_start < chr_length: end = min(current_start + bp_per_chunk, chr_length) @@ -263,7 +266,7 @@ def determine_chunk_ranges(self, vcf_path, chr_length, num_variants): ranges[-1] = (ranges[-1][0], 999_999_999) return ranges - def get_vcf_stats(self, vcf_path): + def _get_vcf_stats(self, vcf_path): cmd = ["bcftools", "index", "--stats", vcf_path] result = subprocess.run( cmd, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True @@ -279,8 +282,8 @@ def get_vcf_stats(self, vcf_path): return chrom, chr_length, num_variants def _ingest_variants(self, vcf_path: str, threads: int = os.cpu_count()): - chrom, chr_length, num_variants = self.get_vcf_stats(vcf_path) - chunk_ranges = self.determine_chunk_ranges(vcf_path, chr_length, num_variants) + chrom, chr_length, num_variants = self._get_vcf_stats(vcf_path) + chunk_ranges = self._determine_chunk_ranges(vcf_path, chr_length, num_variants) def process_chunk(args): start, end, chrom, vcf_path = args