Skip to content

Commit

Permalink
Merge pull request #4329 from vgteam/haplotype-sampling-improvements
Browse files Browse the repository at this point in the history
Haplotype sampling improvements
  • Loading branch information
jltsiren authored Jul 4, 2024
2 parents a049c6b + 333f269 commit 7009793
Show file tree
Hide file tree
Showing 9 changed files with 599 additions and 465 deletions.
676 changes: 299 additions & 377 deletions src/gbwt_extender.cpp

Large diffs are not rendered by default.

19 changes: 15 additions & 4 deletions src/gbwt_extender.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -366,13 +366,27 @@ class WFAExtender {
return std::min(max, (int32_t)(per_base * length) + min);
}
};

/// Limits for mismatches
Event mismatches;
/// Limits for total gaps (*not* gap opens; a gap open uses 1 gap and 1 gap length)
Event gaps;
/// Limits for total gap length (gap extends plus gap opens)
Event gap_length;
/// Limits for alignments that have fallen too far behind.
Event distance;

/// Default error model for mismatches.
constexpr static Event default_mismatches() { return { 0.03, 1, 6 }; }

/// Default error model for gaps.
constexpr static Event default_gaps() { return { 0.05, 1, 10 }; }

/// Default error model for gap length.
constexpr static Event default_gap_length() { return { 0.1, 1, 20 }; }

/// Default error model for distance.
constexpr static Event default_distance() { return { 0.1, 10, 200 }; }
};

/// If not specified, we use this default error model.
Expand Down Expand Up @@ -438,9 +452,6 @@ class WFAExtender {
ReadMasker mask;
const Aligner* aligner;
const ErrorModel* error_model;

/// TODO: Remove when unnecessary.
bool debug = false;
};

//------------------------------------------------------------------------------
Expand Down
191 changes: 151 additions & 40 deletions src/recombinator.cpp

Large diffs are not rendered by default.

65 changes: 53 additions & 12 deletions src/recombinator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,14 @@ class HaplotypePartitioner {
/// End node.
handle_t end;

/// Shortest distance from the last base of `start` to the first base of `end`,
/// if both are present.
std::uint32_t length;

/// Number of additional snarls included in the subchain to keep reversals
/// within the subchain.
std::uint32_t extra_snarls;

/// Returns `true` if the subchain has a start node.
bool has_start() const { return (this->type == Haplotypes::Subchain::normal || this->type == Haplotypes::Subchain::suffix); }

Expand All @@ -288,8 +296,15 @@ class HaplotypePartitioner {
/// Target length for subchains (in bp).
size_t subchain_length = SUBCHAIN_LENGTH;

/// Generate approximately this many jobs.
/// Generate approximately this many jobs.
size_t approximate_jobs = APPROXIMATE_JOBS;

/// Avoid placing subchain boundaries in places where haplotypes would
/// cross them multiple times.
bool linear_structure = false;

/// Print a description of the parameters.
void print(std::ostream& out) const;
};

/**
Expand All @@ -302,9 +317,13 @@ class HaplotypePartitioner {
* Each top-level chain is partitioned into subchains that consist of one or
* more snarls. Multiple snarls are combined into the same subchain if the
* minimum distance over the subchain is at most the target length and there
* are GBWT haplotypes that cross the subchain. If there are no snarls in a
* top-level chain, it is represented as a single subchain without boundary
* nodes.
* are GBWT haplotypes that cross the subchain. We also keep extending the
* subchain if a haplotype would cross the end in both directions. By doing
* this, we can avoid sequence loss with haplotypes reversing their direction,
* while keeping kmers specific to each subchain.
*
* If there are no snarls in a top-level chain, it is represented as a single
* subchain without boundary nodes.
*
* Haplotypes crossing each subchain are represented using minimizers with a
* single occurrence in the graph.
Expand All @@ -325,6 +344,9 @@ class HaplotypePartitioner {
// Return the minimum distance from the last base of `from` to the first base of `to`.
size_t get_distance(handle_t from, handle_t to) const;

// Returns true if a haplotype visits the node in both orientations.
bool contains_reversals(handle_t handle) const;

// Partition the top-level chain into subchains.
std::vector<Subchain> get_subchains(const gbwtgraph::TopLevelChain& chain, const Parameters& parameters) const;

Expand Down Expand Up @@ -432,7 +454,7 @@ class Recombinator {
};

/// Creates a new `Recombinator`.
Recombinator(const gbwtgraph::GBZ& gbz, Verbosity verbosity);
Recombinator(const gbwtgraph::GBZ& gbz, const Haplotypes& haplotypes, Verbosity verbosity);

/// Parameters for `generate_haplotypes()`.
struct Parameters {
Expand Down Expand Up @@ -460,17 +482,36 @@ class Recombinator {
/// the wrong variants out.
double absent_score = ABSENT_SCORE;

/// Use the haploid scoring model. The most common kmer count is used as
/// the coverage estimate. Kmers that would be classified as heterozygous
/// are treated as homozygous.
bool haploid_scoring = false;

/// After selecting the initial `num_haplotypes` haplotypes, choose the
/// highest-scoring pair out of them.
bool diploid_sampling = false;

/// Include named and reference paths.
bool include_reference = false;

/// Preset parameters for common use cases.
enum preset_t {
/// Default parameters.
preset_default,
/// Best practices for haploid sampling.
preset_haploid,
/// Best practices for diploid sampling.
preset_diploid
};

explicit Parameters(preset_t preset = preset_default);

/// Print a description of the parameters.
void print(std::ostream& out) const;
};

/**
* Generates haplotypes based on the given `Haplotypes` representation and
* the kmer counts in the given KFF file.
* Generates haplotypes based on the kmer counts in the given KFF file.
*
* Runs multiple GBWT construction jobs in parallel using OpenMP threads and
* generates the specified number of haplotypes in each top-level chain
Expand All @@ -485,7 +526,7 @@ class Recombinator {
* Throws `std::runtime_error` on error in single-threaded parts and exits
* with `std::exit(EXIT_FAILURE)` in multi-threaded parts.
*/
gbwt::GBWT generate_haplotypes(const Haplotypes& haplotypes, const std::string& kff_file, const Parameters& parameters) const;
gbwt::GBWT generate_haplotypes(const std::string& kff_file, const Parameters& parameters) const;

/// A local haplotype sequence within a single subchain.
struct LocalHaplotype {
Expand All @@ -510,23 +551,23 @@ class Recombinator {
*
* Throws `std::runtime_error` on error.
*/
std::vector<char> classify_kmers(const Haplotypes& haplotypes, const std::string& kff_file, const Parameters& parameters) const;
std::vector<char> classify_kmers(const std::string& kff_file, const Parameters& parameters) const;

/**
* Extracts the local haplotypes in the given subchain. In addition to the
* haplotype sequence, this also reports the name of the corresponding path
* as well as (rank, score) for the haplotype in each round of haplotype
* selection. The number of rounds is `parameters.num_haplotyeps`, but if
* selection. The number of rounds is `parameters.num_haplotypes`, but if
* the haplotype is selected earlier, it will not get further scores.
*
* Throws `std::runtime_error` on error.
*/
std::vector<LocalHaplotype> extract_sequences(
const Haplotypes& haplotypes, const std::string& kff_file,
size_t chain_id, size_t subchain_id, const Parameters& parameters
const std::string& kff_file, size_t chain_id, size_t subchain_id, const Parameters& parameters
) const;

const gbwtgraph::GBZ& gbz;
const Haplotypes& haplotypes;
Verbosity verbosity;

private:
Expand Down
7 changes: 6 additions & 1 deletion src/subcommand/gbwt_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1049,7 +1049,12 @@ void validate_gbwt_config(GBWTConfig& config) {
}
}

if (!config.graph_output.empty()) {
if (config.graph_output.empty()) {
if (config.gbz_format) {
std::cerr << "error: [vg gbwt] GBZ format requires graph output" << std::endl;
std::exit(EXIT_FAILURE);
}
} else {
if (!has_graph_input || !one_input_gbwt) {
std::cerr << "error: [vg gbwt] GBWTGraph construction requires an input graph and and one input GBWT" << std::endl;
std::exit(EXIT_FAILURE);
Expand Down
9 changes: 3 additions & 6 deletions src/subcommand/giraffe_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1669,14 +1669,11 @@ string sample_haplotypes(const vector<pair<string, string>>& indexes, string& ba

// Sample haplotypes.
Haplotypes::Verbosity verbosity = (progress ? Haplotypes::verbosity_basic : Haplotypes::verbosity_silent);
Recombinator recombinator(gbz, verbosity);
Recombinator::Parameters parameters;
parameters.num_haplotypes = Recombinator::NUM_CANDIDATES;
parameters.diploid_sampling = true;
parameters.include_reference = true;
Recombinator recombinator(gbz, haplotypes, verbosity);
Recombinator::Parameters parameters(Recombinator::Parameters::preset_diploid);
gbwt::GBWT sampled_gbwt;
try {
sampled_gbwt = recombinator.generate_haplotypes(haplotypes, kff_file, parameters);
sampled_gbwt = recombinator.generate_haplotypes(kff_file, parameters);
} catch (const std::runtime_error& e) {
std::cerr << "error:[vg giraffe] Haplotype sampling failed: " << e.what() << std::endl;
std::exit(EXIT_FAILURE);
Expand Down
62 changes: 47 additions & 15 deletions src/subcommand/haplotypes_main.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -217,15 +217,18 @@ void help_haplotypes(char** argv, bool developer_options) {
std::cerr << " --kmer-length N kmer length for building the minimizer index (default: " << haplotypes_default_k() << ")" << std::endl;
std::cerr << " --window-length N window length for building the minimizer index (default: " << haplotypes_default_w() << ")" << std::endl;
std::cerr << " --subchain-length N target length (in bp) for subchains (default: " << haplotypes_default_subchain_length() << ")" << std::endl;
std::cerr << " --linear-structure extend subchains to avoid haplotypes visiting them multiple times" << std::endl;
std::cerr << std::endl;
std::cerr << "Options for sampling haplotypes:" << std::endl;
std::cerr << " --preset X use preset X (default, haploid, diploid)" << std::endl;
std::cerr << " --coverage N kmer coverage in the KFF file (default: estimate)" << std::endl;
std::cerr << " --num-haplotypes N generate N haplotypes (default: " << haplotypes_default_n() << ")" << std::endl;
std::cerr << " sample from N candidates (with --diploid-sampling; default: " << haplotypes_default_candidates() << ")" << std::endl;
std::cerr << " --present-discount F discount scores for present kmers by factor F (default: " << haplotypes_default_discount() << ")" << std::endl;
std::cerr << " --het-adjustment F adjust scores for heterozygous kmers by F (default: " << haplotypes_default_adjustment() << ")" << std::endl;
std::cerr << " --absent-score F score absent kmers -F/+F (default: " << haplotypes_default_absent() << ")" << std::endl;
std::cerr << " --diploid-sampling choose the best pair from the greedily selected haplotypes" << std::endl;
std::cerr << " --haploid-scoring use a scoring model without heterozygous kmers" << std::endl;
std::cerr << " --diploid-sampling choose the best pair from the sampled haplotypes" << std::endl;
std::cerr << " --include-reference include named and reference paths in the output" << std::endl;
std::cerr << std::endl;
std::cerr << "Other options:" << std::endl;
Expand All @@ -250,13 +253,16 @@ HaplotypesConfig::HaplotypesConfig(int argc, char** argv, size_t max_threads) {
constexpr int OPT_KMER_LENGTH = 1200;
constexpr int OPT_WINDOW_LENGTH = 1201;
constexpr int OPT_SUBCHAIN_LENGTH = 1202;
constexpr int OPT_COVERAGE = 1300;
constexpr int OPT_NUM_HAPLOTYPES = 1301;
constexpr int OPT_PRESENT_DISCOUNT = 1302;
constexpr int OPT_HET_ADJUSTMENT = 1303;
constexpr int OPT_ABSENT_SCORE = 1304;
constexpr int OPT_DIPLOID_SAMPLING = 1305;
constexpr int OPT_INCLUDE_REFERENCE = 1306;
constexpr int OPT_LINEAR_STRUCTURE = 1203;
constexpr int OPT_PRESET = 1300;
constexpr int OPT_COVERAGE = 1301;
constexpr int OPT_NUM_HAPLOTYPES = 1302;
constexpr int OPT_PRESENT_DISCOUNT = 1303;
constexpr int OPT_HET_ADJUSTMENT = 1304;
constexpr int OPT_ABSENT_SCORE = 1305;
constexpr int OPT_HAPLOID_SCORING = 1306;
constexpr int OPT_DIPLOID_SAMPLING = 1307;
constexpr int OPT_INCLUDE_REFERENCE = 1308;
constexpr int OPT_VALIDATE = 1400;
constexpr int OPT_VCF_INPUT = 1500;
constexpr int OPT_CONTIG_PREFIX = 1501;
Expand All @@ -275,11 +281,14 @@ HaplotypesConfig::HaplotypesConfig(int argc, char** argv, size_t max_threads) {
{ "kmer-length", required_argument, 0, OPT_KMER_LENGTH },
{ "window-length", required_argument, 0, OPT_WINDOW_LENGTH },
{ "subchain-length", required_argument, 0, OPT_SUBCHAIN_LENGTH },
{ "linear-structure", no_argument, 0, OPT_LINEAR_STRUCTURE },
{ "preset", required_argument, 0, OPT_PRESET },
{ "coverage", required_argument, 0, OPT_COVERAGE },
{ "num-haplotypes", required_argument, 0, OPT_NUM_HAPLOTYPES },
{ "present-discount", required_argument, 0, OPT_PRESENT_DISCOUNT },
{ "het-adjustment", required_argument, 0, OPT_HET_ADJUSTMENT },
{ "absent-score", required_argument, 0, OPT_ABSENT_SCORE },
{ "haploid-scoring", no_argument, 0, OPT_HAPLOID_SCORING },
{ "diploid-sampling", no_argument, 0, OPT_DIPLOID_SAMPLING },
{ "include-reference", no_argument, 0, OPT_INCLUDE_REFERENCE },
{ "verbosity", required_argument, 0, 'v' },
Expand Down Expand Up @@ -345,6 +354,27 @@ HaplotypesConfig::HaplotypesConfig(int argc, char** argv, size_t max_threads) {
std::exit(EXIT_FAILURE);
}
break;
case OPT_LINEAR_STRUCTURE:
this->partitioner_parameters.linear_structure = true;
break;

case OPT_PRESET:
{
Recombinator::Parameters::preset_t preset;
if (std::string(optarg) == "default") {
preset = Recombinator::Parameters::preset_default;
} else if (std::string(optarg) == "haploid") {
preset = Recombinator::Parameters::preset_haploid;
} else if (std::string(optarg) == "diploid") {
preset = Recombinator::Parameters::preset_diploid;
} else {
std::cerr << "error: [vg haplotypes] unknown preset: " << optarg << std::endl;
std::exit(EXIT_FAILURE);
}
this->recombinator_parameters = Recombinator::Parameters(preset);
num_haplotypes_set = true; // The preset is assumed to include the number of haplotypes.
break;
}
case OPT_COVERAGE:
this->recombinator_parameters.coverage = parse<size_t>(optarg);
break;
Expand Down Expand Up @@ -377,6 +407,9 @@ HaplotypesConfig::HaplotypesConfig(int argc, char** argv, size_t max_threads) {
std::exit(EXIT_FAILURE);
}
break;
case OPT_HAPLOID_SCORING:
this->recombinator_parameters.haploid_scoring = true;
break;
case OPT_DIPLOID_SAMPLING:
this->recombinator_parameters.diploid_sampling = true;
break;
Expand Down Expand Up @@ -574,10 +607,10 @@ void validate_subgraph(const gbwtgraph::GBWTGraph& graph, const gbwtgraph::GBWTG

void sample_haplotypes(const gbwtgraph::GBZ& gbz, const Haplotypes& haplotypes, const HaplotypesConfig& config) {
omp_set_num_threads(threads_to_jobs(config.threads));
Recombinator recombinator(gbz, config.verbosity);
Recombinator recombinator(gbz, haplotypes, config.verbosity);
gbwt::GBWT merged;
try {
merged = recombinator.generate_haplotypes(haplotypes, config.kmer_input, config.recombinator_parameters);
merged = recombinator.generate_haplotypes(config.kmer_input, config.recombinator_parameters);
} catch (const std::runtime_error& e) {
std::cerr << "error: [vg haplotypes] " << e.what() << std::endl;
std::exit(EXIT_FAILURE);
Expand Down Expand Up @@ -896,12 +929,11 @@ void extract_haplotypes(const gbwtgraph::GBZ& gbz, const Haplotypes& haplotypes,
std::cerr << "Extracting haplotypes from chain " << config.chain_id << ", subchain " << config.subchain_id << std::endl;
}

Recombinator recombinator(gbz, config.verbosity);
Recombinator recombinator(gbz, haplotypes, config.verbosity);
std::vector<Recombinator::LocalHaplotype> result;
try {
result = recombinator.extract_sequences(
haplotypes, config.kmer_input,
config.chain_id, config.subchain_id, config.recombinator_parameters
config.kmer_input, config.chain_id, config.subchain_id, config.recombinator_parameters
);
} catch (const std::runtime_error& e) {
std::cerr << "error: [vg haplotypes] " << e.what() << std::endl;
Expand Down Expand Up @@ -940,8 +972,8 @@ void classify_kmers(const gbwtgraph::GBZ& gbz, const Haplotypes& haplotypes, con
if (config.verbosity >= Haplotypes::verbosity_basic) {
std::cerr << "Classifying kmers" << std::endl;
}
Recombinator recombinator(gbz, config.verbosity);
std::vector<char> classifications = recombinator.classify_kmers(haplotypes, config.kmer_input, config.recombinator_parameters);
Recombinator recombinator(gbz, haplotypes, config.verbosity);
std::vector<char> classifications = recombinator.classify_kmers(config.kmer_input, config.recombinator_parameters);

if (config.verbosity >= Haplotypes::verbosity_basic) {
std::cerr << "Writing " << classifications.size() << " classifications to " << config.kmer_output << std::endl;
Expand Down
Loading

1 comment on commit 7009793

@adamnovak
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

vg CI tests complete for merge to master. View the full report here.

16 tests passed, 0 tests failed and 0 tests skipped in 17433 seconds

Please sign in to comment.