From db622175bcf810257379b932965bf93b0506ba9a Mon Sep 17 00:00:00 2001 From: Jouni Siren Date: Sat, 4 May 2024 21:45:13 -0700 Subject: [PATCH 1/8] Simplify WFAExtender --- src/gbwt_extender.cpp | 577 +++++++++++++++++------------------------- src/gbwt_extender.hpp | 3 - 2 files changed, 226 insertions(+), 354 deletions(-) diff --git a/src/gbwt_extender.cpp b/src/gbwt_extender.cpp index bec78296a8d..a3694fc8cdd 100644 --- a/src/gbwt_extender.cpp +++ b/src/gbwt_extender.cpp @@ -1,4 +1,5 @@ #include "gbwt_extender.hpp" +#include "hash_map.hpp" #include #include @@ -1385,109 +1386,79 @@ struct WFAPoint { //------------------------------------------------------------------------------ -/// Represents a node in the tree of haplotypes we are traversing and doing WFA -/// against. -/// -/// Will have WFANode::find_pos() called against it as part of a loop for each -/// diagonal, and if it doesn't answer its parent will be queried, recursively -/// back to the root. If we allow the number of WFANode objects along a -/// non-branching path to be linear in the sequence length, then we will make -/// O(sequence length) calls for each diagonal, and we end up getting O(n^2) -/// (or worse?) lookups. -/// -/// So, it is essential that we allow one WFANode to stand for a whole -/// non-branching run of haplotypes, up to about the total sequence length we -/// will be working on. This limits the number of recursive queries of parents -/// so it grows only with the number of haplotypes we are aligning against, -/// which is bounded, and not directly with the sequence length. +/* + WFANode corresponds to a path in a graph and a set of haplotypes that follow + the path. We store the path itself, the GBWT search state at the end of the + path, and the concatenation of node sequences. We additionally store links + to the parent and child nodes in WFATree, using offsets in the vector of + nodes. + + The path continues until the haplotypes branch, it exceeds 1024 bp, or it + covers the target position. +*/ struct WFANode { - /// This tracks the GBWT search states for all graph nodes we visit that - /// have been coalesced into this WFANode - std::vector states; - /// And this tracks the GBWT packed nodes (ID and orientation) that are - /// visited, and maps to start offset, for O(1) query. It cannot have - /// duplicates. - std::unordered_map starts_by_node; - /// And this tracks the start offsets of each in our sequence space. - /// TODO: Replace with something O(1) - std::map states_by_start; - /// Total length - size_t stored_length; + // The path in the graph. + gbwt::vector_type path; + + // GBWT search state at the end of the path. + gbwt::SearchState state; + + // Concatenation of node sequences. + std::string node_sequence; // Offsets in the vector of nodes. - uint32_t parent; - std::vector children; + std::uint32_t parent; + std::vector children; + + // Offset for the target position in the sequence. + std::uint32_t target_offset; // All haplotypes end here. bool dead_end; + constexpr static size_t TARGET_LENGTH = 1024; + constexpr static size_t MATCHES = 0; constexpr static size_t INSERTIONS = 1; // characters in the sequence but not in the graph constexpr static size_t DELETIONS = 2; // characters in the graph but not in the sequence // Points on the wavefronts are indexed by score, diagonal. 
- std::array, 3> wavefronts; - - WFANode(const vector& states, uint32_t parent, const gbwtgraph::GBWTGraph& graph) : - states(states), - starts_by_node(), - states_by_start(), - stored_length(0), - parent(parent), children(), - dead_end(false), - wavefronts() { - if (states.empty()) { - throw std::runtime_error("Cannot make a WFANode for nothing"); - } - - // Fill in the visited nodes set and the index from start position to node. - this->starts_by_node.reserve(states.size()); - for (size_t i = 0; i < this->states.size(); i++) { - // Remember that this node starts here - this->starts_by_node.emplace(this->states[i].node, stored_length); - // Remember that here starts this node - states_by_start[stored_length] = i; - -#ifdef debug_wfa - std::cerr << "State #" << i << " is GBWT encoded node " << this->states[i].node << " and starts at offset " << stored_length << std::endl; -#endif + std::array, 3> wavefronts; + + WFANode(const gbwtgraph::CachedGBWTGraph& graph, const gbwt::SearchState& state, pos_t target, std::uint32_t parent) : + parent(parent), target_offset(std::numeric_limits::max()), dead_end(false) + { + if(this->append_node(graph, state, target)) { + return; + } - // And up the start position - stored_length += graph.get_length(gbwtgraph::GBWTGraph::node_to_handle(this->states[i].node)); + while (this->node_sequence.length() < TARGET_LENGTH) { + size_t successors = 0; + gbwt::SearchState next_state; + graph.follow_paths(this->state, [&](const gbwt::SearchState& next) { + successors++; + next_state = next; + return true; + }); + if (successors == 0) { + this->dead_end = true; + break; + } else if (successors > 1) { + break; + } + if (this->append_node(graph, next_state, target)) { + break; + } } } + size_t length() const { return this->node_sequence.length(); } bool is_leaf() const { return (this->children.empty() || this->dead_end); } bool expanded() const { return (!this->children.empty() || this->dead_end); } - bool same_node(pos_t pos) const { - // See if we have seen anything on this node - gbwt::node_type lookup = gbwt::Node::encode(id(pos), is_rev(pos)); - bool is_here = starts_by_node.count(lookup); - - return is_here; - } - - /// Map from graph position to offset along the WFANode. - /// TODO: Having this requires that the WFANode never visits the same - /// oriented geaph node twice. Can we get away with not having this - /// somehow? - size_t node_offset_of(pos_t pos) const { - gbwt::node_type lookup = gbwt::Node::encode(id(pos), is_rev(pos)); - // Find where the referenced graph node starts in us - size_t start = starts_by_node.at(lookup); - // And then apply the offset - size_t result = start + offset(pos); - - return result; - } - - size_t length() const { - return stored_length; - } - // WFANode::find_pos - // Returns the position for the given score and diagonal with the given path, or an empty position if it does not exist. + // Returns the position for the given score and diagonal with the given path, + // or an empty position if it does not exist. 
MatchPos find_pos(size_t type, int32_t score, int32_t diagonal, const MatchPos::PathList& path) const { WFAPoint::key_type key { score, diagonal }; auto& points = this->wavefronts[type]; @@ -1516,259 +1487,157 @@ struct WFANode { auto iter = points.find(key); if (iter == points.end()) { // This is a new score and diagonal - points.emplace_hint(iter, std::move(key), std::move(value)); + points.emplace_hint(iter, key, value); } else { // This score and diagonal already exists, so overwrite the value - iter->second = std::move(value); + iter->second = value; } } - // Returns a position at the first non-match after the given position. - void match_forward(const std::string& sequence, const gbwtgraph::GBWTGraph& graph, MatchPos& pos) const { - - // Get first graph node starting after our offset. - std::map::const_iterator here = this->states_by_start.upper_bound(pos.node_offset); - if (here == this->states_by_start.begin()) { - // We are somehow starting before the first item (which should start at 0). This should never happen. - throw std::runtime_error("Offset on WFANode starts before its first graph node, which ought to be at 0"); - } - // Get last graph node starting at or before our offset. - --here; - - // We have the index of the state starting at or after the match pos. So it's the one the position is on. - while (here != this->states_by_start.end()) { - // Until we hit the end of the WFANode - - // Grab the handle for the graph node we are at - handle_t handle = gbwtgraph::GBWTGraph::node_to_handle(this->states[here->second].node); - - // And get a view of its sequence - gbwtgraph::view_type node_seq = graph.get_sequence_view(handle); - size_t graph_node_offset = pos.node_offset - here->first; - - while (pos.seq_offset < sequence.length() && graph_node_offset < node_seq.second && sequence[pos.seq_offset] == node_seq.first[graph_node_offset]) { - // Until we hit the end of the sequence, or the graph node, or a mismatch, advance - pos.seq_offset++; - pos.node_offset++; - graph_node_offset = pos.node_offset - here->first; - } - if (graph_node_offset >= node_seq.second) { - // We hit the end of a graph node. - // Advance to the next graph node. - ++here; - } else { - // We hit the end of the sequence, or a mismatch. - break; - } + // Advances the position to the first non-match at or after the current position. + void match_forward(const gbwtgraph::CachedGBWTGraph& graph, const std::string& sequence, MatchPos& pos) const { + while ( + pos.seq_offset < sequence.length() && + pos.node_offset < this->node_sequence.length() && + sequence[pos.seq_offset] == this->node_sequence[pos.node_offset] + ) { + pos.seq_offset++; + pos.node_offset++; } - } - // Returns a position at the start of the run of matches before the given position. - void match_backward(const std::string& sequence, const gbwtgraph::GBWTGraph& graph, MatchPos& pos) const { - - // Get first graph node starting after our offset. - std::map::const_iterator here = this->states_by_start.upper_bound(pos.node_offset); - if (here == this->states_by_start.begin()) { - // We are somehow starting before the first item (which should start at 0). This should never happen. - throw std::runtime_error("Offset on WFANode starts before its first graph node, which ought to be at 0"); - } - // Get last graph node starting at or before our offset. - --here; - - // We have the index of the state starting at or after the match pos. So it's the one the position is on. 
- while (pos.seq_offset > 0 && pos.node_offset > 0) { - // Until we hit the start of the WFANode - - // Grab the handle for the graph node we are at - handle_t handle = gbwtgraph::GBWTGraph::node_to_handle(this->states[here->second].node); - // And get a view of its sequence - gbwtgraph::view_type node_seq = graph.get_sequence_view(handle); - size_t graph_node_offset = pos.node_offset - here->first; - - while (pos.seq_offset > 0 && graph_node_offset > 0 && sequence[pos.seq_offset - 1] == node_seq.first[graph_node_offset - 1]) { - // Until we hit the start of the sequence, or the graph node, or a mismatch, go left - pos.seq_offset--; - pos.node_offset--; - graph_node_offset = pos.node_offset - here->first; - } - if (graph_node_offset == 0 && here->first != 0) { - // We hit the start of a graph node, but we could go left still. - // Go left to the next graph node. - --here; - } else { - // We hit the end of the sequence, or the end of the node, or a mismatch. - break; - } +private: + // Proceed to the given search state and append the path. Returns true if we reached the target. + bool append_node(const gbwtgraph::CachedGBWTGraph& graph, gbwt::SearchState next, pos_t target) { + this->state = next; + this->path.push_back(this->state.node); + gbwtgraph::view_type view = graph.get_sequence_view(gbwtgraph::GBWTGraph::node_to_handle(this->state.node)); + this->node_sequence.append(view.first, view.second); + if (gbwt::Node::encode(id(target), is_rev(target)) == this->state.node) { + this->target_offset = this->node_sequence.length() - (view.second - offset(target)); + return true; } + return false; } - }; //------------------------------------------------------------------------------ +/* + WFATree represents a trie of haplotypes starting from a given position in + the graph. The tree is expanded lazily as needed, and WFA alignment is done + over all haplotypes in the tree. Each node is a WFANode that represents a + non-branching set of haplotypes over a path in the graph. +*/ class WFATree { public: - const gbwtgraph::GBWTGraph& graph; + const gbwtgraph::CachedGBWTGraph graph; const std::string& sequence; - /// Each WFANode represents a run of graph nodes, as traversed by a set of haplotypes. + // Start and end positions in the graph (exclusive). + pos_t from, to; + + // Node identifiers are offsets in this vector. Node 0 is the root. std::vector nodes; - // Best alignment found so far. If we reached the destination in the graph, - // the score includes the implicit insertion at the end but the point itself - // does not. + // Best alignment found so far. If we reached the target position in the + // graph, the score includes the implicit insertion at the end but the + // position itself does not. WFAPoint candidate_point; - uint32_t candidate_node; + std::uint32_t candidate_node; // WFA score (penalty) parameters derived from the actual scoring parameters. - int32_t mismatch, gap_open, gap_extend; + std::int32_t mismatch, gap_open, gap_extend; // Stop if no alignment has been found with this score or less. - int32_t score_bound; + std::int32_t score_bound; struct ScoreProperties { - int32_t min_diagonal; - int32_t max_diagonal; + std::int32_t min_diagonal; + std::int32_t max_diagonal; bool reachable_with_gap; }; - // A set of possible scores and diagonals reached with them. - std::map possible_scores; + // A set of possible scores and diagonals reached with them. 
Because we + // derive the three WFA scoring parameters from four Aligner parameters + // (that include a match bonus), many small scores are impossible and can + // be skipped. + std::map possible_scores; // The overall closed range of diagonals reached. - std::pair max_diagonals; - - // TODO: Remove when unnecessary. - bool debug; - - WFATree(const gbwtgraph::GBWTGraph& graph, const std::string& sequence, const gbwt::SearchState& root, uint32_t node_offset, const Aligner& aligner, const WFAExtender::ErrorModel& error_model) : - graph(graph), sequence(sequence), + std::pair max_diagonals; + + WFATree( + const gbwtgraph::GBWTGraph& graph, const std::string& sequence, + pos_t from, pos_t to, + const Aligner& aligner, const WFAExtender::ErrorModel& error_model + ) : + graph(graph), sequence(sequence), from(from), to(to), nodes(), - candidate_point({ std::numeric_limits::max(), 0, 0, 0 }), candidate_node(0), + candidate_point({ std::numeric_limits::max(), 0, 0, 0 }), candidate_node(0), mismatch(2 * (aligner.match + aligner.mismatch)), gap_open(2 * (aligner.gap_open - aligner.gap_extension)), gap_extend(2 * aligner.gap_extension + aligner.match), score_bound(0), - possible_scores(), max_diagonals(0, 0), - debug(false) + possible_scores(), max_diagonals(0, 0) { - this->nodes.emplace_back(this->coalesce(root), 0, this->graph); - // No need to convert the node offset because it is from the root state's node start - this->nodes.front().update(WFANode::MATCHES, 0, 0, 0, node_offset); - - // Determine a reasonable upper bound for the number of edits. - int32_t max_mismatches = error_model.mismatches.evaluate(sequence.length()); - int32_t max_gaps = error_model.gaps.evaluate(sequence.length()); - int32_t max_gap_length = error_model.gap_length.evaluate(sequence.length()); + // Create the root node based on the starting position. Because the start + // is outside the alignment, we may already have exhausted the node. + handle_t handle = this->graph.get_handle(id(this->from), is_rev(this->from)); + gbwt::SearchState state = this->graph.get_state(handle); + WFANode root(this->graph, state, this->to, 0); + root.update(WFANode::MATCHES, 0, 0, 0, offset(this->from) + 1); + this->nodes.push_back(root); + + // Determine score bound based on the error model and sequence length. + std::int32_t max_mismatches = error_model.mismatches.evaluate(sequence.length()); + std::int32_t max_gaps = error_model.gaps.evaluate(sequence.length()); + std::int32_t max_gap_length = error_model.gap_length.evaluate(sequence.length()); this->score_bound = max_mismatches * this->mismatch + max_gaps * this->gap_open + max_gap_length * this->gap_extend; - possible_scores[0] = { 0, 0, false }; - } - - /// Get all the GBWT search states for a run of the same set of haplotypes - /// through nodes in the graph, without any haplotypes in the set branching - /// off, and without any visits to the same oriented graph node twice. - /// TODO: We can only visit each graph node once, or we can't map graph - /// pos_t values back to offsets along the WFANode. Do we need to be able - /// to do that, or can we try not doing that? - /// TODO: Save a scan by unifying with WFANode constructor? - vector coalesce(const gbwt::SearchState& start, size_t base_limit = 1024) { - vector coalesced {start}; - - std::unordered_set visited {start.node}; - gbwt::SearchState here = start; - gbwt::CachedGBWT cache = graph.get_cache(); // TODO: Take in cache? Is this even useful here? - // How many bases have we grabbed? 
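A concrete illustration of the penalty derivation in the WFATree constructor above (a sketch only; the Aligner values are an assumption based on vg's usual defaults of match 1, mismatch 4, gap open 6, gap extension 1, which this patch does not restate; the error-model numbers and the 150 bp length come from the defaults shown in this series):

    // Assuming Aligner defaults match = 1, mismatch = 4, gap_open = 6, gap_extension = 1:
    //   mismatch   = 2 * (1 + 4)  = 10
    //   gap_open   = 2 * (6 - 1)  = 10
    //   gap_extend = 2 * 1 + 1    = 3
    // With the default error model and a 150 bp sequence:
    //   max_mismatches = min(6,  int(0.03 * 150) + 1) = 5
    //   max_gaps       = min(10, int(0.05 * 150) + 1) = 8
    //   max_gap_length = min(20, int(0.1  * 150) + 1) = 16
    //   score_bound    = 5 * 10 + 8 * 10 + 16 * 3     = 178
    // The smallest nonzero possible scores are then 10 (one mismatch) and 13
    // (a length-1 gap), so possible_scores can skip scores 1-9, 11, and 12 entirely.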
- size_t coalesced_bases = 0; - // How many places did we have to pick from? - size_t options = 1; - while(options == 1) { - // Until we find multiple next places we could go - - // See how far we have come - handle_t node_handle = gbwtgraph::GBWTGraph::node_to_handle(here.node); - size_t node_length = graph.get_length(node_handle); - coalesced_bases += node_length; - if (coalesced_bases >= base_limit) { - // We don't want to look any more bases out; we might be - // wasting our time lloking further than the remaining read. - break; - } - - // If we want to keep going, see where we could go - options = 0; - gbwt::SearchState next; - graph.follow_paths(cache, here, [&](const gbwt::SearchState& reachable) { - options++; - if (options > 1) { - // We found bore than one place to go, so stop coalescing. - return false; - } - next = reachable; - return true; - }); - if (options == 1) { - // We found exactly one place to go. - - if (visited.count(next.node)) { - // We can't go there, we would cycle within a WFANode and - // break mapping from graph position to WFANode offset - break; - } - visited.insert(next.node); - - // Some haplotypes may have dropped out, but it is OK to keep - // coalescing because others did not. - // Go there. - here = next; - coalesced.push_back(here); - } - } - - return coalesced; + this->possible_scores[0] = { 0, 0, false }; } - uint32_t size() const { return this->nodes.size(); } - static bool is_root(uint32_t node) { return (node == 0); } - uint32_t parent(uint32_t node) const { return this->nodes[node].parent; } + std::uint32_t size() const { return this->nodes.size(); } + static bool is_root(std::uint32_t node) { return (node == 0); } + uint32_t parent(std::uint32_t node) const { return this->nodes[node].parent; } // Assumes length > 0. - int32_t gap_extend_penalty(uint32_t length) const { - return static_cast(length) * this->gap_extend; + std::int32_t gap_extend_penalty(std::uint32_t length) const { + return std::int32_t(length) * this->gap_extend; } // Assumes length > 0. - int32_t gap_penalty(uint32_t length) const { + std::int32_t gap_penalty(std::uint32_t length) const { return this->gap_open + this->gap_extend_penalty(length); } // wf_extend() in the paper. // If we reach the end of a node, we continue to the start of the next node even // if we do not use any characters in it. - void extend(int32_t score, pos_t to) { - for (int32_t diagonal = this->max_diagonals.first; diagonal <= this->max_diagonals.second; diagonal++) { - - std::vector leaves = this->get_leaves(); - this->extend_over(score, diagonal, to, leaves); + void extend(std::int32_t score) { + for (std::int32_t diagonal = this->max_diagonals.first; diagonal <= this->max_diagonals.second; diagonal++) { + std::vector leaves = this->get_leaves(); + this->extend_over(score, diagonal, leaves); } } - // Returns the next possible score after the given score. Also updates the set + // Returns the next possible score after the given score, assuming that the + // given score has already been determined possible. Also updates the set // of possible scores with those reachable from the given score but does not // set the diagonal ranges for them. - int32_t next_score(int32_t match_score) { - - int32_t mismatch_score = match_score + this->mismatch; + std::int32_t next_score(std::int32_t match_score) { + // This score + a mismatch is a possible score. 
+ std::int32_t mismatch_score = match_score + this->mismatch; if (this->possible_scores.find(mismatch_score) == this->possible_scores.end()) { - this->possible_scores[mismatch_score] = { 0, 0, false }; } - // We assume that match_score is a valid score. + // This score + gap extend is a possible score reachable by a gap, if + // this score was also reachable by a gap. auto match_iter = this->possible_scores.find(match_score); if (match_iter->second.reachable_with_gap) { - int32_t extend_score = match_score + this->gap_extend; + std::int32_t extend_score = match_score + this->gap_extend; auto extend_iter = this->possible_scores.find(extend_score); if (extend_iter != this->possible_scores.end()) { extend_iter->second.reachable_with_gap = true; @@ -1777,13 +1646,12 @@ class WFATree { } } - int32_t open_score = match_score + this->gap_open + this->gap_extend; + // This score + gap open + gap extend is a possible score reachable by a gap. + std::int32_t open_score = match_score + this->gap_open + this->gap_extend; auto open_iter = this->possible_scores.find(open_score); if (open_iter != this->possible_scores.end()) { - open_iter->second.reachable_with_gap = true; } else { - this->possible_scores[open_score] = { 0, 0, true }; } @@ -1795,12 +1663,12 @@ class WFATree { // wf_next() in the paper. // If we reach the end of a node, we continue to the start of the next node even // if we do not use any characters in it. - void next(int32_t score, pos_t to) { - std::pair diagonal_range = this->get_diagonals(score); - for (int32_t diagonal = diagonal_range.first; diagonal <= diagonal_range.second; diagonal++) { - std::vector leaves = this->get_leaves(); - // Note that we may do the same update from multiple leaves. - for (uint32_t leaf : leaves) { + void next(std::int32_t score) { + std::pair diagonal_range = this->get_diagonals(score); + for (std::int32_t diagonal = diagonal_range.first; diagonal <= diagonal_range.second; diagonal++) { + std::vector leaves = this->get_leaves(); + // NOTE: We may do the same updates from multiple leaves. + for (std::uint32_t leaf : leaves) { MatchPos ins = this->ins_predecessor(leaf, score, diagonal).first; if (!ins.empty()) { ins.seq_offset++; @@ -1838,9 +1706,10 @@ class WFATree { // If the edit is an insertion, we charge the gap open cost again, but // we already got the same insertion without the extra cost from the // match preceding the insertion. - if (this->nodes[subst.node()].same_node(to) && subst.node_offset == this->nodes[subst.node()].node_offset_of(to)) { - uint32_t gap_length = this->sequence.length() - subst.seq_offset; - int32_t gap_score = 0; + WFANode& node = this->nodes[subst.node()]; + if (subst.node_offset == node.target_offset) { + std::uint32_t gap_length = this->sequence.length() - subst.seq_offset; + std::int32_t gap_score = 0; if (gap_length > 0) { gap_score = this->gap_penalty(gap_length); } @@ -1849,7 +1718,7 @@ class WFATree { this->candidate_node = subst.node(); } } - this->nodes[subst.node()].update(WFANode::MATCHES, score, diagonal, subst); + node.update(WFANode::MATCHES, score, diagonal, subst); } } } @@ -1917,16 +1786,16 @@ class WFATree { // Replaces the candidate with the partial alignment with the highest alignment // score according to the aligner. 
void trim(const Aligner& aligner) { - this->candidate_point = { 0, 0, 0, 0}; + this->candidate_point = { 0, 0, 0, 0 }; this->candidate_node = 0; - int32_t best_score = 0; - for (uint32_t node = 0; node < this->size(); node++) { + std::int32_t best_score = 0; + for (std::uint32_t node = 0; node < this->size(); node++) { for (auto& point_entry : this->nodes[node].wavefronts[WFANode::MATCHES]) { // Scan all stored points on the node. // TODO: Does iteration order matter? // Convert map entries to points. WFAPoint point = WFAPoint::from_map_entry(point_entry); - int32_t alignment_score = point.alignment_score(aligner); + std::int32_t alignment_score = point.alignment_score(aligner); if (alignment_score > best_score) { // This is a new winner this->candidate_point = point; @@ -1940,59 +1809,62 @@ class WFATree { private: // wf_extend() on a specific diagonal for the set of (local) haplotypes corresponding to - // the given list of leaves in the tree of GBWT search states. - void extend_over(int32_t score, int32_t diagonal, pos_t to, const std::vector& leaves) { - for (uint32_t leaf : leaves) { - + // the given list of leaves in the WFATree. + void extend_over(std::int32_t score, std::int32_t diagonal, const std::vector& leaves) { + for (std::uint32_t leaf : leaves) { MatchPos pos = this->find_pos(WFANode::MATCHES, leaf, score, diagonal, false, false); if (pos.empty()) { continue; // An impossible score / diagonal combination. } while (true) { - // We want to determine if we could reach our fixed destination point, if it exists - bool may_reach_to; - // And if so, where it would be along this WFANode. - uint32_t to_offset; - if (this->nodes[pos.node()].same_node(to)) { - // Work out where we would have to go - to_offset = this->nodes[pos.node()].node_offset_of(to); - // And if we can get there - may_reach_to = this->nodes[pos.node()].same_node(to) && (pos.node_offset <= to_offset); - } else { - // We can't get there, it's not on this WFANode. - may_reach_to = false; + WFANode& node = this->nodes[pos.node()]; + bool may_reach_target = false; + if (node.target_offset >= pos.node_offset && node.target_offset < node.length()) { + // This node covers the target, and we have not matched the offset yet. + may_reach_target = true; } - - this->nodes[pos.node()].match_forward(this->sequence, this->graph, pos); + node.match_forward(this->graph, this->sequence, pos); // We got a match that reached the end or went past it. // Alternatively there is no end position and we have aligned the entire sequence. // This gives us a candidate where the rest of the sequence is an insertion. - if ((may_reach_to && pos.node_offset >= to_offset) || (no_pos(to) && pos.seq_offset >= this->sequence.length())) { - uint32_t overshoot = (no_pos(to) ? 0 : pos.node_offset - to_offset); - uint32_t gap_length = (this->sequence.length() - pos.seq_offset) + overshoot; - int32_t gap_score = 0; + if ( + (may_reach_target && pos.node_offset >= node.target_offset) || + (no_pos(this->to) && pos.seq_offset >= this->sequence.length()) + ) { + // If we managed to match the target position, it is part of the overshoot. + std::uint32_t overshoot = (no_pos(to) ? 
0 : pos.node_offset - node.target_offset); + std::uint32_t gap_length = (this->sequence.length() - pos.seq_offset) + overshoot; + std::int32_t gap_score = 0; if (gap_length > 0) { gap_score = this->gap_penalty(gap_length); } if (score + gap_score < this->candidate_point.score) { - this->candidate_point = { score + gap_score, diagonal, pos.seq_offset - overshoot, to_offset }; + this->candidate_point = { + score + gap_score, diagonal, + pos.seq_offset - overshoot, node.target_offset + }; this->candidate_node = pos.node(); } } - this->nodes[pos.node()].update(WFANode::MATCHES, score, diagonal, pos); - if (pos.node_offset < this->nodes[pos.node()].length()) { + + // We may have matched some additional bases in the node. + node.update(WFANode::MATCHES, score, diagonal, pos); + if (pos.node_offset < node.length()) { break; } - this->expand_if_necessary(pos); + + // We reached the end of the node. If the position does not specify the + // path we should follow, we continue recursively in all children. + this->expand_if_necessary(pos); // NOTE: Possibly invalidates `node`. if (pos.at_last_node()) { - // We have exhausted the path leading to the current leaf. Make a copy of the children - // of the leaf (the actual list may be invalidated by further expansions) and continue - // aligning over them. - std::vector new_leaves = this->nodes[leaf].children; - this->extend_over(score, diagonal, to, new_leaves); + assert(pos.node() == leaf); + // Create a copy of the child list, because further pushes to the + // node vector may invalidate the reference. + std::vector new_leaves = this->nodes[pos.node()].children; + this->extend_over(score, diagonal, new_leaves); break; } pos.pop(); @@ -2001,9 +1873,9 @@ class WFATree { } } - std::vector get_leaves() const { - std::vector leaves; - for (uint32_t node = 0; node < this->size(); node++) { + std::vector get_leaves() const { + std::vector leaves; + for (std::uint32_t node = 0; node < this->size(); node++) { if (this->nodes[node].is_leaf()) { leaves.push_back(node); } @@ -2011,7 +1883,10 @@ class WFATree { return leaves; } - std::pair update_range(std::pair range, int32_t score) const { + std::pair update_diagonal_range( + std::pair range, + std::int32_t score) const + { if (score >= 0) { auto iter = this->possible_scores.find(score); if (iter != this->possible_scores.end()) { @@ -2022,15 +1897,15 @@ class WFATree { return range; } - // Determines the diagonal range for the given score and store it in possible_scores. + // Determines the diagonal range for the given score and stores it in possible_scores. // Assumes that the score is valid. Updates max_diagonals. // Returns an empty range if the score is impossible. std::pair get_diagonals(int32_t score) { // Determine the diagonal range for the given score. std::pair range(1, -1); - range = this->update_range(range, score - this->mismatch); // Mismatch. - range = this->update_range(range, score - this->gap_open - this->gap_extend); // New gap. - range = this->update_range(range, score - this->gap_extend); // Extend an existing gap. + range = this->update_diagonal_range(range, score - this->mismatch); // Mismatch. + range = this->update_diagonal_range(range, score - this->gap_open - this->gap_extend); // New gap. + range = this->update_diagonal_range(range, score - this->gap_extend); // Extend an existing gap. if (range.first > range.second) { return range; } @@ -2048,18 +1923,20 @@ class WFATree { // If we have reached the end of the current node, expand its children if necessary. 
// Call this whenever the alignment advances in the node. void expand_if_necessary(const MatchPos& pos) { - if (this->nodes[pos.node()].expanded() || pos.node_offset < this->nodes[pos.node()].length()) { + // NOTE: Pushes to the node vector may invalidate references to the node. + std::uint32_t node = pos.node(); + if (this->nodes[node].expanded() || pos.node_offset < this->nodes[node].length()) { return; } bool found = false; - this->graph.follow_paths(this->nodes[pos.node()].states.back(), [&](const gbwt::SearchState& child) -> bool { - this->nodes[pos.node()].children.push_back(this->size()); - this->nodes.emplace_back(this->coalesce(child), pos.node(), this->graph); + this->graph.follow_paths(this->nodes[node].state, [&](const gbwt::SearchState& child) -> bool { + this->nodes[node].children.push_back(this->size()); + this->nodes.emplace_back(this->graph, child, this->to, node); found = true; return true; }); if (!found) { - this->nodes[pos.node()].dead_end = true; + this->nodes[node].dead_end = true; } } @@ -2078,7 +1955,6 @@ class WFATree { path.push(node); // Find a position at this node. // The MatchPos will need to know the whole path, so we can return it. - // TODO: Actually manage the moves ourselves to make this faster! MatchPos pos = this->nodes[node].find_pos(type, score, diagonal, path); if (!pos.empty()) { if (extendable_seq && pos.seq_offset >= this->sequence.length()) { @@ -2089,7 +1965,7 @@ class WFATree { } return pos; } - if (is_root(node)) { + if (this->is_root(node)) { return MatchPos(); } node = this->parent(node); @@ -2104,6 +1980,8 @@ class WFATree { //------------------------------------------------------------------------------ +//#define debug_connect + WFAAlignment WFAExtender::connect(std::string sequence, pos_t from, pos_t to) const { if (this->graph == nullptr || this->aligner == nullptr) { #ifdef debug_connect @@ -2111,24 +1989,22 @@ WFAAlignment WFAExtender::connect(std::string sequence, pos_t from, pos_t to) co #endif return WFAAlignment(); } - gbwt::SearchState root_state = this->graph->get_state(this->graph->get_handle(id(from), is_rev(from))); - if (root_state.empty()) { + if (!this->graph->has_node(id(from))) { #ifdef debug_connect - std::cerr << "No root state! Returning empty alignment!" << std::endl; + std::cerr << "No start node! Returning empty alignment!" << std::endl; #endif return WFAAlignment(); } - this->mask(sequence); - WFATree tree(*(this->graph), sequence, root_state, offset(from) + 1, *(this->aligner), *(this->error_model)); - tree.debug = this->debug; + this->mask(sequence); + WFATree tree(*(this->graph), sequence, from, to, *(this->aligner), *(this->error_model)); - int32_t score = 0; + std::int32_t score = 0; while (true) { #ifdef debug_connect std::cerr << "Extend for score " << score << std::endl; #endif - tree.extend(score, to); + tree.extend(score); if (tree.candidate_point.score <= score) { break; @@ -2146,14 +2022,14 @@ WFAAlignment WFAExtender::connect(std::string sequence, pos_t from, pos_t to) co #ifdef debug_connect std::cerr << "Next for score " << score << std::endl; #endif - tree.next(score, to); + tree.next(score); } // If we do not have a full-length alignment within the score bound, // we find the best partial alignment if there was no destination or // return an empty alignment otherwise. 
bool full_length = true; - uint32_t unaligned_tail = sequence.length() - tree.candidate_point.seq_offset; + std::uint32_t unaligned_tail = sequence.length() - tree.candidate_point.seq_offset; if (tree.candidate_point.score > tree.score_bound) { #ifdef debug_connect std::cerr << "No alignment could be found under score bound of " << tree.score_bound << "; best found was " << tree.candidate_point.score << std::endl; @@ -2170,17 +2046,16 @@ WFAAlignment WFAExtender::connect(std::string sequence, pos_t from, pos_t to) co // Start building an alignment. Store the path first. // No need to convert the node offset because it is from the root state's node start. WFAAlignment result { - {}, {}, static_cast(offset(from) + 1), 0, + {}, {}, std::uint32_t(offset(from) + 1), 0, tree.candidate_point.seq_offset + unaligned_tail, tree.candidate_point.alignment_score(*(this->aligner), unaligned_tail), true }; - uint32_t node = tree.candidate_node; + std::uint32_t node = tree.candidate_node; while (true) { // Go back up the tree and compose the path in reverse order - for (auto it = tree.nodes[node].states.rbegin(); it != tree.nodes[node].states.rend(); ++it) { - // Visit all the states in each WFANode and put their graph nodes on the path in reverse order. - result.path.push_back(gbwtgraph::GBWTGraph::node_to_handle(it->node)); + for (auto iter = tree.nodes[node].path.rbegin(); iter != tree.nodes[node].path.rend(); ++iter) { + result.path.push_back(gbwtgraph::GBWTGraph::node_to_handle(*iter)); } if (tree.is_root(node)) { // Stop when we reach the root @@ -2195,7 +2070,7 @@ WFAAlignment WFAExtender::connect(std::string sequence, pos_t from, pos_t to) co WFAPoint point = tree.candidate_point; node = tree.candidate_node; if (unaligned_tail > 0) { - uint32_t final_insertion = sequence.length() - tree.candidate_point.seq_offset; + std::uint32_t final_insertion = sequence.length() - tree.candidate_point.seq_offset; result.append(WFAAlignment::insertion, final_insertion); point.score -= tree.gap_penalty(unaligned_tail); } diff --git a/src/gbwt_extender.hpp b/src/gbwt_extender.hpp index e7286a284dc..5c8cf419b3e 100644 --- a/src/gbwt_extender.hpp +++ b/src/gbwt_extender.hpp @@ -434,9 +434,6 @@ class WFAExtender { ReadMasker mask; const Aligner* aligner; const ErrorModel* error_model; - - /// TODO: Remove when unnecessary. 
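For orientation, a minimal usage sketch of the simplified extender (illustrative only: `graph` and `aligner` are assumed to be built as in the unit tests later in this series, whose linear graph is 1: CGC, 2: GATTACA, 3: GATTA, 4: TAT; the positions and sequence below are examples, not part of this patch):

    // Hypothetical example; both endpoints are exclusive, so the aligned sequence
    // is everything strictly between `from` and `to`.
    WFAExtender extender(graph, aligner, WFAExtender::default_error_model);
    pos_t from = make_pos_t(1, false, 2);   // last matched base of node 1 (CGC)
    pos_t to = make_pos_t(3, false, 0);     // first matched base of node 3 (GATTA)
    WFAAlignment aln = extender.connect("GATTACA", from, to);
    // On success, aln.path holds the traversed handles; an empty WFAAlignment is
    // returned if nothing is found within the score bound.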
- bool debug = false; }; //------------------------------------------------------------------------------ From eb26e32a4f31396a46f59e25d66d3b188df49c86 Mon Sep 17 00:00:00 2001 From: Jouni Siren Date: Thu, 9 May 2024 20:36:56 -0700 Subject: [PATCH 2/8] Ignore points that have fallen too far behind in WFAExtender --- src/gbwt_extender.cpp | 113 +++++++++++++++++++++++---------- src/gbwt_extender.hpp | 16 ++++- src/unittest/gbwt_extender.cpp | 25 +++++--- 3 files changed, 112 insertions(+), 42 deletions(-) diff --git a/src/gbwt_extender.cpp b/src/gbwt_extender.cpp index a3694fc8cdd..9a46b21e25a 100644 --- a/src/gbwt_extender.cpp +++ b/src/gbwt_extender.cpp @@ -1209,12 +1209,10 @@ namespace vg { //------------------------------------------------------------------------------ const WFAExtender::ErrorModel WFAExtender::default_error_model { - // Mismatches (per base, plus min, cap at max) - {0.03, 1, 6}, - // Gaps - {0.05, 1, 10}, - // Gap length - {0.1, 1, 20} + WFAExtender::ErrorModel::default_mismatches(), + WFAExtender::ErrorModel::default_gaps(), + WFAExtender::ErrorModel::default_gap_length(), + WFAExtender::ErrorModel::default_distance() }; WFAExtender::WFAExtender() : @@ -1258,6 +1256,8 @@ struct MatchPos { const static size_t NUM_INLINE = 4; size_t item_count = 0; uint32_t inline_items[NUM_INLINE]; + // TODO: There are probably some obscure situations where we could have a bug + // because this is a shared pointer. Document them. std::shared_ptr> additional_items; void push(uint32_t value) { @@ -1316,11 +1316,17 @@ struct MatchPos { uint32_t node() const { return this->path.top(); } void pop() { this->path.pop(); } + // Returns the distance from the start of the alignment as seq_offset + target_offset, + // assuming that we are on the given diagonal. + std::int32_t distance(std::int32_t diagonal) { + return 2 * std::int32_t(this->seq_offset) - diagonal; + } + // Positions are ordered by sequence offsets. Empty positions are smaller than // non-empty ones. bool operator<(const MatchPos& another) { if (this->empty()) { - return true; + return !another.empty(); } if (another.empty()) { return false; @@ -1495,7 +1501,7 @@ struct WFANode { } // Advances the position to the first non-match at or after the current position. - void match_forward(const gbwtgraph::CachedGBWTGraph& graph, const std::string& sequence, MatchPos& pos) const { + void match_forward(const std::string& sequence, MatchPos& pos) const { while ( pos.seq_offset < sequence.length() && pos.node_offset < this->node_sequence.length() && @@ -1552,6 +1558,12 @@ class WFATree { // Stop if no alignment has been found with this score or less. std::int32_t score_bound; + // Furthest distance (seq_offset + target_offset) reached so far. + std::int32_t max_distance; + + // Minimum distance we are still considering. + std::int32_t min_distance; + struct ScoreProperties { std::int32_t min_diagonal; std::int32_t max_diagonal; @@ -1564,9 +1576,6 @@ class WFATree { // be skipped. std::map possible_scores; - // The overall closed range of diagonals reached. 
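The distance() helper added above follows from the usual WFA bookkeeping: with the diagonal defined as seq_offset - target_offset (which is what the surrounding code implies), the distance from the start of the alignment is seq_offset + target_offset = seq_offset + (seq_offset - diagonal) = 2 * seq_offset - diagonal, exactly the value the method returns.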
- std::pair max_diagonals; - WFATree( const gbwtgraph::GBWTGraph& graph, const std::string& sequence, pos_t from, pos_t to, @@ -1578,8 +1587,8 @@ class WFATree { mismatch(2 * (aligner.match + aligner.mismatch)), gap_open(2 * (aligner.gap_open - aligner.gap_extension)), gap_extend(2 * aligner.gap_extension + aligner.match), - score_bound(0), - possible_scores(), max_diagonals(0, 0) + score_bound(0), max_distance(0), min_distance(0), + possible_scores() { // Create the root node based on the starting position. Because the start // is outside the alignment, we may already have exhausted the node. @@ -1616,7 +1625,12 @@ class WFATree { // If we reach the end of a node, we continue to the start of the next node even // if we do not use any characters in it. void extend(std::int32_t score) { - for (std::int32_t diagonal = this->max_diagonals.first; diagonal <= this->max_diagonals.second; diagonal++) { + auto iter = this->possible_scores.find(score); + if (iter == this->possible_scores.end()) { + // TODO: This should not happen. Should there be a warning/error? + return; + } + for (std::int32_t diagonal = iter->second.min_diagonal; diagonal <= iter->second.max_diagonal; diagonal++) { std::vector leaves = this->get_leaves(); this->extend_over(score, diagonal, leaves); } @@ -1660,11 +1674,16 @@ class WFATree { return match_iter->first; } - // wf_next() in the paper. + // wf_next() in the paper, with a variant of wf_reduce(). // If we reach the end of a node, we continue to the start of the next node even // if we do not use any characters in it. void next(std::int32_t score) { + // diagonal_range is the potential range of diagonals for this score, based on the + // diagonals of previous scores. std::pair diagonal_range = this->get_diagonals(score); + // actual_range is the range of diagonals with points that reach far enough + // with this score. + std::pair actual_range = empty_diagonal_range(); for (std::int32_t diagonal = diagonal_range.first; diagonal <= diagonal_range.second; diagonal++) { std::vector leaves = this->get_leaves(); // NOTE: We may do the same updates from multiple leaves. @@ -1672,13 +1691,19 @@ class WFATree { MatchPos ins = this->ins_predecessor(leaf, score, diagonal).first; if (!ins.empty()) { ins.seq_offset++; - this->nodes[ins.node()].update(WFANode::INSERTIONS, score, diagonal, ins); + if (ins.distance(diagonal) >= this->min_distance) { + this->nodes[ins.node()].update(WFANode::INSERTIONS, score, diagonal, ins); + adjust_diagonal_range(actual_range, diagonal); + } } MatchPos del = this->del_predecessor(leaf, score, diagonal).first; if (!del.empty()) { this->successor_offset(del); - this->nodes[del.node()].update(WFANode::DELETIONS, score, diagonal, del); + if (del.distance(diagonal) >= this->min_distance) { + this->nodes[del.node()].update(WFANode::DELETIONS, score, diagonal, del); + adjust_diagonal_range(actual_range, diagonal); + } this->expand_if_necessary(del); } @@ -1690,14 +1715,11 @@ class WFATree { } // Determine the edit that reaches furthest on the diagonal. 
- bool is_insertion = false; if (subst < ins) { subst = std::move(ins); - is_insertion = true; } if (subst < del) { subst = std::move(del); - is_insertion = false; } if (!subst.empty()) { @@ -1718,10 +1740,20 @@ class WFATree { this->candidate_node = subst.node(); } } - node.update(WFANode::MATCHES, score, diagonal, subst); + if (subst.distance(diagonal) >= this->min_distance) { + node.update(WFANode::MATCHES, score, diagonal, subst); + adjust_diagonal_range(actual_range, diagonal); + } } } } + + // Set the diagonal range for the score. + auto iter = this->possible_scores.find(score); + if (iter != this->possible_scores.end()) { + iter->second.min_diagonal = actual_range.first; + iter->second.max_diagonal = actual_range.second; + } } // Returns the predecessor position for the furthest reaching insertion for @@ -1814,7 +1846,6 @@ class WFATree { for (std::uint32_t leaf : leaves) { MatchPos pos = this->find_pos(WFANode::MATCHES, leaf, score, diagonal, false, false); if (pos.empty()) { - continue; // An impossible score / diagonal combination. } while (true) { @@ -1824,7 +1855,7 @@ class WFATree { // This node covers the target, and we have not matched the offset yet. may_reach_target = true; } - node.match_forward(this->graph, this->sequence, pos); + node.match_forward(this->sequence, pos); // We got a match that reached the end or went past it. // Alternatively there is no end position and we have aligned the entire sequence. @@ -1851,6 +1882,7 @@ class WFATree { } // We may have matched some additional bases in the node. + this->max_distance = std::max(this->max_distance, pos.distance(diagonal)); node.update(WFANode::MATCHES, score, diagonal, pos); if (pos.node_offset < node.length()) { break; @@ -1873,6 +1905,7 @@ class WFATree { } } + // TODO: Can we trim subtrees with no points far enough? std::vector get_leaves() const { std::vector leaves; for (std::uint32_t node = 0; node < this->size(); node++) { @@ -1883,6 +1916,14 @@ class WFATree { return leaves; } + // Adjusts the diagonal range to include the given diagonal. + static void adjust_diagonal_range(std::pair& range, std::int32_t diagonal) { + range.first = std::min(range.first, diagonal); + range.second = std::max(range.second, diagonal); + } + + // Updates the given diagonal range to include all diagonals covered by the + // range for the given score. std::pair update_diagonal_range( std::pair range, std::int32_t score) const @@ -1897,26 +1938,23 @@ class WFATree { return range; } - // Determines the diagonal range for the given score and stores it in possible_scores. - // Assumes that the score is valid. Updates max_diagonals. + static std::pair empty_diagonal_range() { + return std::make_pair(std::numeric_limits::max(), std::numeric_limits::min()); + } + + // Determines the diagonal range for the given score. + // Assumes that the score is valid. // Returns an empty range if the score is impossible. std::pair get_diagonals(int32_t score) { // Determine the diagonal range for the given score. - std::pair range(1, -1); + std::pair range = empty_diagonal_range(); range = this->update_diagonal_range(range, score - this->mismatch); // Mismatch. range = this->update_diagonal_range(range, score - this->gap_open - this->gap_extend); // New gap. range = this->update_diagonal_range(range, score - this->gap_extend); // Extend an existing gap. 
if (range.first > range.second) { return range; } - range.first--; range.second++; - this->max_diagonals.first = std::min(this->max_diagonals.first, range.first); - this->max_diagonals.second = std::max(this->max_diagonals.second, range.second); - auto iter = this->possible_scores.find(score); - iter->second.min_diagonal = range.first; - iter->second.max_diagonal = range.second; - return range; } @@ -2006,6 +2044,15 @@ WFAAlignment WFAExtender::connect(std::string sequence, pos_t from, pos_t to) co #endif tree.extend(score); + // Update the minimum distance for points we will still consider. + std::int32_t distance_band = this->error_model->distance.evaluate(sequence.length()); + if (distance_band < tree.max_distance) { + tree.min_distance = tree.max_distance - distance_band; + } +#ifdef debug_connect + std::cerr << "Max distance is " << tree.max_distance << ", min distance is " << tree.min_distance << std::endl; +#endif + if (tree.candidate_point.score <= score) { break; } diff --git a/src/gbwt_extender.hpp b/src/gbwt_extender.hpp index 5c8cf419b3e..362b8571a5e 100644 --- a/src/gbwt_extender.hpp +++ b/src/gbwt_extender.hpp @@ -366,13 +366,27 @@ class WFAExtender { return std::min(max, (int32_t)(per_base * length) + min); } }; - + /// Limits for mismatches Event mismatches; /// Limits for total gaps (*not* gap opens; a gap open uses 1 gap and 1 gap length) Event gaps; /// Limits for total gap length (gap extends plus gap opens) Event gap_length; + /// Limits for alignments that have fallen too far behind. + Event distance; + + /// Default error model for mismatches. + constexpr static Event default_mismatches() { return { 0.03, 1, 6 }; } + + /// Default error model for gaps. + constexpr static Event default_gaps() { return { 0.05, 1, 10 }; } + + /// Default error model for gap length. + constexpr static Event default_gap_length() { return { 0.1, 1, 20 }; } + + /// Default error model for distance. + constexpr static Event default_distance() { return { 0.1, 10, 200 }; } }; /// If not specified, we use this default error model. diff --git a/src/unittest/gbwt_extender.cpp b/src/unittest/gbwt_extender.cpp index d51b356a664..6d7151c7177 100644 --- a/src/unittest/gbwt_extender.cpp +++ b/src/unittest/gbwt_extender.cpp @@ -2465,6 +2465,7 @@ TEST_CASE("Connect with a non-diverging multi-node cycle", "[wfa_extender]") { //------------------------------------------------------------------------------ +// TODO: Can we do meaningful tests with the distance parameters? 
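A worked example of the new distance band, using only the defaults shown above (the 150 bp read length is an assumption for illustration):

    // default_distance() = { 0.1, 10, 200 }; Event::evaluate caps at max:
    //   distance_band = min(200, int(0.1 * 150) + 10) = 25
    // Once tree.max_distance exceeds 25, connect() sets
    //   tree.min_distance = tree.max_distance - 25
    // and next() stops recording wavefront points whose seq_offset + target_offset
    // falls below that bound.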
TEST_CASE("WFA score caps constrain returned alignments", "[wfa_extender]") { // Create the structures for graph 1: CGC, 2: GATTACA, 3: GATTA, 4: TAT gbwt::GBWT index = wfa_linear_gbwt(); @@ -2479,7 +2480,8 @@ TEST_CASE("WFA score caps constrain returned alignments", "[wfa_extender]") { WFAExtender::ErrorModel errors { {0, 1, 1}, {0, 0, 0}, - {0, 0, 0} + {0, 0, 0}, + WFAExtender::ErrorModel::default_distance(), }; WFAExtender extender(graph, aligner, errors); @@ -2496,7 +2498,8 @@ TEST_CASE("WFA score caps constrain returned alignments", "[wfa_extender]") { WFAExtender::ErrorModel errors { {0, 1, 1}, {0, 0, 0}, - {0, 0, 0} + {0, 0, 0}, + WFAExtender::ErrorModel::default_distance(), }; WFAExtender extender(graph, aligner, errors); @@ -2512,7 +2515,8 @@ TEST_CASE("WFA score caps constrain returned alignments", "[wfa_extender]") { WFAExtender::ErrorModel errors { {0, 0, 0}, {0, 1, 1}, - {0, 0, 0} + {0, 0, 0}, + WFAExtender::ErrorModel::default_distance(), }; WFAExtender extender(graph, aligner, errors); @@ -2528,7 +2532,8 @@ TEST_CASE("WFA score caps constrain returned alignments", "[wfa_extender]") { WFAExtender::ErrorModel errors { {0, 0, 0}, {0, 1, 1}, - {0, 1, 1} + {0, 1, 1}, + WFAExtender::ErrorModel::default_distance(), }; WFAExtender extender(graph, aligner, errors); @@ -2545,7 +2550,8 @@ TEST_CASE("WFA score caps constrain returned alignments", "[wfa_extender]") { WFAExtender::ErrorModel errors { {0, 0, 0}, {0, 1, 1}, - {0, 1, 1} + {0, 1, 1}, + WFAExtender::ErrorModel::default_distance(), }; WFAExtender extender(graph, aligner, errors); @@ -2561,7 +2567,8 @@ TEST_CASE("WFA score caps constrain returned alignments", "[wfa_extender]") { WFAExtender::ErrorModel errors { {0, 0, 0}, {0, 1, 1}, - {0, 2, 2} + {0, 2, 2}, + WFAExtender::ErrorModel::default_distance(), }; WFAExtender extender(graph, aligner, errors); @@ -2578,7 +2585,8 @@ TEST_CASE("WFA score caps constrain returned alignments", "[wfa_extender]") { WFAExtender::ErrorModel errors { {0, 0, 0}, {0, 1, 1}, - {0, 2, 2} + {0, 2, 2}, + WFAExtender::ErrorModel::default_distance(), }; WFAExtender extender(graph, aligner, errors); @@ -2594,7 +2602,8 @@ TEST_CASE("WFA score caps constrain returned alignments", "[wfa_extender]") { WFAExtender::ErrorModel errors { {0, 0, 0}, {0, 2, 2}, - {0, 2, 2} + {0, 2, 2}, + WFAExtender::ErrorModel::default_distance(), }; WFAExtender extender(graph, aligner, errors); From 64fe857a4e56a5bb06a2d6a57de8c8dc247fb0dc Mon Sep 17 00:00:00 2001 From: Jouni Siren Date: Mon, 10 Jun 2024 23:43:00 -0700 Subject: [PATCH 3/8] Non-greedy subchain boundaries --- src/recombinator.cpp | 91 ++++++++++++++++++++++++++++++++++++++------ src/recombinator.hpp | 21 ++++++++-- 2 files changed, 97 insertions(+), 15 deletions(-) diff --git a/src/recombinator.cpp b/src/recombinator.cpp index b91454f07e4..9da1def2e95 100644 --- a/src/recombinator.cpp +++ b/src/recombinator.cpp @@ -430,6 +430,29 @@ size_t HaplotypePartitioner::get_distance(handle_t from, handle_t to) const { ); } +bool HaplotypePartitioner::contains_reversals(handle_t handle) const { + gbwt::node_type forward = gbwtgraph::GBWTGraph::handle_to_node(handle); + std::vector forward_da = this->r_index.decompressDA(forward); + std::sort(forward_da.begin(), forward_da.end()); + + gbwt::node_type reverse = gbwt::Node::reverse(forward); + std::vector reverse_da = this->r_index.decompressDA(reverse); + std::sort(reverse_da.begin(), reverse_da.end()); + + auto fw_iter = forward_da.begin(); + auto rv_iter = reverse_da.begin(); + while (fw_iter != forward_da.end() && rv_iter != 
reverse_da.end()) { + if (*fw_iter == *rv_iter) { + return true; + } else if (*fw_iter < *rv_iter) { + ++fw_iter; + } else { + ++rv_iter; + } + } + return false; +} + std::vector HaplotypePartitioner::get_subchains(const gbwtgraph::TopLevelChain& chain, const Parameters& parameters) const { std::vector result; @@ -447,16 +470,16 @@ HaplotypePartitioner::get_subchains(const gbwtgraph::TopLevelChain& chain, const if (was_snarl) { if (!has_start) { // If the chain starts with a snarl, we take it as a prefix. - snarls.push_back({ Haplotypes::Subchain::prefix, empty_gbwtgraph_handle(), handle }); + snarls.push_back({ Haplotypes::Subchain::prefix, empty_gbwtgraph_handle(), handle, 0, 0 }); } else { size_t distance = this->get_distance(snarl_start, handle); if (distance < std::numeric_limits::max()) { // Normal snarl with two boundary nodes. - snarls.push_back({ Haplotypes::Subchain::normal, snarl_start, handle }); + snarls.push_back({ Haplotypes::Subchain::normal, snarl_start, handle, 0, 0 }); } else { // The snarl is not connected, so we break it into two. - snarls.push_back({ Haplotypes::Subchain::suffix, snarl_start, empty_gbwtgraph_handle() }); - snarls.push_back({ Haplotypes::Subchain::prefix, empty_gbwtgraph_handle(), handle }); + snarls.push_back({ Haplotypes::Subchain::suffix, snarl_start, empty_gbwtgraph_handle(), 0, 0 }); + snarls.push_back({ Haplotypes::Subchain::prefix, empty_gbwtgraph_handle(), handle, 0, 0 }); } } } @@ -479,7 +502,7 @@ HaplotypePartitioner::get_subchains(const gbwtgraph::TopLevelChain& chain, const } if (was_snarl && has_start) { // If the chain ends with a snarl, we take it as a suffix. - snarls.push_back({ Haplotypes::Subchain::suffix, snarl_start, empty_gbwtgraph_handle() }); + snarls.push_back({ Haplotypes::Subchain::suffix, snarl_start, empty_gbwtgraph_handle(), 0, 0 }); } // Second pass: Combine snarls into subchains. @@ -492,18 +515,35 @@ HaplotypePartitioner::get_subchains(const gbwtgraph::TopLevelChain& chain, const continue; } size_t tail = head; + std::uint32_t distance = this->get_distance(snarls[head].start, snarls[tail].end); + std::uint32_t extra_snarls = 0; while (tail + 1 < snarls.size()) { if (snarls[tail + 1].type != Haplotypes::Subchain::normal) { break; } size_t candidate = this->get_distance(snarls[head].start, snarls[tail + 1].end); - if (candidate <= parameters.subchain_length) { - tail++; - } else { - break; + if (candidate > parameters.subchain_length) { + // Including the next snarl would exceed target length. But if a haplotype visits + // the tail in both orientations, it flips the orientation in a subsequent subchain, + // returns back, flips again, and eventually continues forward. In such situations, + // sampling minimal haplotypes within this subchain would lead to sequence loss, + // while sampling maximal haplotypes could make some kmers specific to the next + // subchain shared with haplotypes in this subchain. We therefore move forward until + // we can make the subchain contain the reversals. 
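For example, if the sorted document arrays are { 3, 7, 12 } for the forward orientation of the tail node and { 5, 7 } for its reverse orientation, the merge in contains_reversals() meets at sequence id 7, so the check below succeeds, the snarl is counted as an extra snarl, and the subchain keeps growing past the target length.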
+ if (this->contains_reversals(snarls[tail].end)) { + extra_snarls++; + } else { + break; + } } + tail++; + distance = candidate; } - result.push_back({ Haplotypes::Subchain::normal, snarls[head].start, snarls[tail].end }); + result.push_back({ + Haplotypes::Subchain::normal, + snarls[head].start, snarls[tail].end, + distance, extra_snarls + }); head = tail + 1; } @@ -715,15 +755,42 @@ void present_kmers(const std::vector subchains = this->get_subchains(chain, parameters); + if (this->verbosity >= Haplotypes::verbosity_debug) { + size_t long_subchains = 0, with_extra_snarls = 0, extra_snarls = 0; + for (const Subchain& subchain : subchains) { + if (subchain.length > parameters.subchain_length) { + long_subchains++; + } + if (subchain.extra_snarls > 0) { + with_extra_snarls++; + extra_snarls += subchain.extra_snarls; + } + } + #pragma omp critical + { + std::cerr << "Chain " << chain.offset << ": " << long_subchains << " long subchains (" + << with_extra_snarls << " with " << extra_snarls << " extra snarls)" << std::endl; + } + } + + // Convert the subchains to actual subchains. for (const Subchain& subchain : subchains) { std::vector>> to_process; auto sequences = this->get_sequences(subchain); if (sequences.empty()) { // There are no haplotypes crossing the subchain, so we break it into // a suffix and a prefix. - to_process.push_back({ { Haplotypes::Subchain::suffix, subchain.start, empty_gbwtgraph_handle() }, this->get_sequences(subchain.start) }); - to_process.push_back({ { Haplotypes::Subchain::prefix, empty_gbwtgraph_handle(), subchain.end }, this->get_sequences(subchain.end) }); + to_process.push_back({ + { Haplotypes::Subchain::suffix, subchain.start, empty_gbwtgraph_handle(), 0, 0 }, + this->get_sequences(subchain.start) + }); + to_process.push_back({ + { Haplotypes::Subchain::prefix, empty_gbwtgraph_handle(), subchain.end, 0, 0 }, + this->get_sequences(subchain.end) + }); } else { to_process.push_back({ subchain, std::move(sequences) }); } diff --git a/src/recombinator.hpp b/src/recombinator.hpp index 36f03ead7bd..d132446d08b 100644 --- a/src/recombinator.hpp +++ b/src/recombinator.hpp @@ -269,6 +269,14 @@ class HaplotypePartitioner { /// End node. handle_t end; + /// Shortest distance from the last base of `start` to the first base of `end`, + /// if both are present. + std::uint32_t length; + + /// Number of additional snarls included in the subchain to keep reversals + /// within the subchain. + std::uint32_t extra_snarls; + /// Returns `true` if the subchain has a start node. bool has_start() const { return (this->type == Haplotypes::Subchain::normal || this->type == Haplotypes::Subchain::suffix); } @@ -302,9 +310,13 @@ class HaplotypePartitioner { * Each top-level chain is partitioned into subchains that consist of one or * more snarls. Multiple snarls are combined into the same subchain if the * minimum distance over the subchain is at most the target length and there - * are GBWT haplotypes that cross the subchain. If there are no snarls in a - * top-level chain, it is represented as a single subchain without boundary - * nodes. + * are GBWT haplotypes that cross the subchain. We also keep extending the + * subchain if a haplotype would cross the end in both directions. By doing + * this, we can avoid sequence loss with haplotypes reversing their direction, + * while keeping kmers specific to each subchain. + * + * If there are no snarls in a top-level chain, it is represented as a single + * subchain without boundary nodes. 
* * Haplotypes crossing each subchain are represented using minimizers with a * single occurrence in the graph. @@ -325,6 +337,9 @@ class HaplotypePartitioner { // Return the minimum distance from the last base of `from` to the first base of `to`. size_t get_distance(handle_t from, handle_t to) const; + // Returns true if a haplotype visits the node in both orientations. + bool contains_reversals(handle_t handle) const; + // Partition the top-level chain into subchains. std::vector get_subchains(const gbwtgraph::TopLevelChain& chain, const Parameters& parameters) const; From 9e2ef48c20836580c9e9927c1395686ea16ca028 Mon Sep 17 00:00:00 2001 From: Jouni Siren Date: Mon, 1 Jul 2024 17:51:57 +0200 Subject: [PATCH 4/8] Move Haplotypes to Recombinator --- src/recombinator.cpp | 40 ++++++++++++++---------------- src/recombinator.hpp | 25 +++++++++++++------ src/subcommand/giraffe_main.cpp | 4 +-- src/subcommand/haplotypes_main.cpp | 13 +++++----- 4 files changed, 44 insertions(+), 38 deletions(-) diff --git a/src/recombinator.cpp b/src/recombinator.cpp index 9da1def2e95..9d451dba073 100644 --- a/src/recombinator.cpp +++ b/src/recombinator.cpp @@ -1082,8 +1082,8 @@ std::ostream& Recombinator::Statistics::print(std::ostream& out) const { //------------------------------------------------------------------------------ -Recombinator::Recombinator(const gbwtgraph::GBZ& gbz, Verbosity verbosity) : - gbz(gbz), verbosity(verbosity) +Recombinator::Recombinator(const gbwtgraph::GBZ& gbz, const Haplotypes& haplotypes, Verbosity verbosity) : + gbz(gbz), haplotypes(haplotypes), verbosity(verbosity) { } @@ -1197,13 +1197,13 @@ double get_or_estimate_coverage( return coverage; } -gbwt::GBWT Recombinator::generate_haplotypes(const Haplotypes& haplotypes, const std::string& kff_file, const Parameters& parameters) const { +gbwt::GBWT Recombinator::generate_haplotypes(const std::string& kff_file, const Parameters& parameters) const { // Sanity checks (may throw). recombinator_sanity_checks(parameters); // Get kmer counts (may throw) and determine coverage. - hash_map counts = haplotypes.kmer_counts(kff_file, this->verbosity); + hash_map counts = this->haplotypes.kmer_counts(kff_file, this->verbosity); double coverage = get_or_estimate_coverage(counts, parameters, this->verbosity); double start = gbwt::readTimer(); @@ -1212,19 +1212,19 @@ gbwt::GBWT Recombinator::generate_haplotypes(const Haplotypes& haplotypes, const } // Determine construction jobs. - std::vector> jobs(haplotypes.jobs()); - for (auto& chain : haplotypes.chains) { - if (chain.job_id < haplotypes.jobs()) { + std::vector> jobs(this->haplotypes.jobs()); + for (auto& chain : this->haplotypes.chains) { + if (chain.job_id < this->haplotypes.jobs()) { jobs[chain.job_id].push_back(chain.offset); } } // Figure out GBWT path ids for reference paths in each job. 
- std::vector> reference_paths(haplotypes.jobs()); + std::vector> reference_paths(this->haplotypes.jobs()); if (parameters.include_reference) { for (size_t i = 0; i < this->gbz.graph.named_paths.size(); i++) { - size_t job_id = haplotypes.jobs_for_cached_paths[i]; - if (job_id < haplotypes.jobs()) { + size_t job_id = this->haplotypes.jobs_for_cached_paths[i]; + if (job_id < this->haplotypes.jobs()) { reference_paths[job_id].push_back(this->gbz.graph.named_paths[i].id); } } @@ -1253,7 +1253,7 @@ gbwt::GBWT Recombinator::generate_haplotypes(const Haplotypes& haplotypes, const for (auto chain_id : jobs[job]) { try { Statistics chain_statistics = this->generate_haplotypes( - haplotypes.chains[chain_id], counts, builder, metadata, parameters, coverage + this->haplotypes.chains[chain_id], counts, builder, metadata, parameters, coverage ); job_statistics.combine(chain_statistics); } catch (const std::runtime_error& e) { @@ -1348,17 +1348,16 @@ std::vector> classify_kmers( } std::vector Recombinator::classify_kmers( - const Haplotypes& haplotypes, const std::string& kff_file, const Recombinator::Parameters& parameters ) const { // Get kmer counts (may throw) and determine coverage. - hash_map counts = haplotypes.kmer_counts(kff_file, this->verbosity); + hash_map counts = this->haplotypes.kmer_counts(kff_file, this->verbosity); double coverage = get_or_estimate_coverage(counts, parameters, this->verbosity); // Classify the kmers in each subchain. std::vector classifications; - classifications.reserve(haplotypes.kmers()); - for (const auto& chain : haplotypes.chains) { + classifications.reserve(this->haplotypes.kmers()); + for (const auto& chain : this->haplotypes.chains) { for (const auto& subchain : chain.subchains) { std::vector> kmer_types = vg::classify_kmers( subchain, counts, coverage, nullptr, parameters @@ -1613,15 +1612,14 @@ Recombinator::Statistics Recombinator::generate_haplotypes(const Haplotypes::Top //------------------------------------------------------------------------------ std::vector Recombinator::extract_sequences( - const Haplotypes& haplotypes, const std::string& kff_file, - size_t chain_id, size_t subchain_id, const Parameters& parameters + const std::string& kff_file, size_t chain_id, size_t subchain_id, const Parameters& parameters ) const { // Sanity checks. - if (chain_id >= haplotypes.chains.size()) { + if (chain_id >= this->haplotypes.chains.size()) { std::string msg = "Recombinator::extract_sequences(): invalid chain id " + std::to_string(chain_id); throw std::runtime_error(msg); } - if (subchain_id >= haplotypes.chains[chain_id].subchains.size()) { + if (subchain_id >= this->haplotypes.chains[chain_id].subchains.size()) { std::string msg = "Recombinator::extract_sequences(): invalid subchain id " + std::to_string(subchain_id) + " in chain " + std::to_string(chain_id); throw std::runtime_error(msg); @@ -1629,7 +1627,7 @@ std::vector Recombinator::extract_sequences( recombinator_sanity_checks(parameters); // Extract the haplotypes. - const Haplotypes::Subchain& subchain = haplotypes.chains[chain_id].subchains[subchain_id]; + const Haplotypes::Subchain& subchain = this->haplotypes.chains[chain_id].subchains[subchain_id]; std::vector result(subchain.sequences.size()); for (size_t i = 0; i < subchain.sequences.size(); i++) { size_t path_id = gbwt::Path::id(subchain.sequences[i].first); @@ -1648,7 +1646,7 @@ std::vector Recombinator::extract_sequences( } // Get kmer counts (may throw) and determine coverage. 
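
With this change, the Haplotypes object is bound when the Recombinator is constructed instead of being passed to every call. A minimal usage sketch of the refactored interface (this assumes the vg namespace, an already loaded GBZ, Haplotypes index, and KFF file path; the function name and the verbosity choice are illustrative only):

    #include "recombinator.hpp"

    gbwt::GBWT sample_with_kmers(const gbwtgraph::GBZ& gbz, const Haplotypes& haplotypes,
                                 const std::string& kff_file) {
        // The haplotype index is now a constructor argument and a class member.
        Recombinator recombinator(gbz, haplotypes, Haplotypes::verbosity_basic);

        Recombinator::Parameters parameters;
        parameters.diploid_sampling = true;
        parameters.include_reference = true;

        // Reads the kmer counts from the KFF file; may throw std::runtime_error.
        return recombinator.generate_haplotypes(kff_file, parameters);
    }

The haplotypes_main.cpp and giraffe_main.cpp hunks below update the command-line front ends to this pattern.
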
- hash_map counts = haplotypes.kmer_counts(kff_file, this->verbosity); + hash_map counts = this->haplotypes.kmer_counts(kff_file, this->verbosity); double coverage = get_or_estimate_coverage(counts, parameters, this->verbosity); // Fill in the scores. diff --git a/src/recombinator.hpp b/src/recombinator.hpp index d132446d08b..c673e65e53a 100644 --- a/src/recombinator.hpp +++ b/src/recombinator.hpp @@ -373,6 +373,16 @@ class HaplotypePartitioner { //------------------------------------------------------------------------------ +/* + TODO: + * Parameters should include the functionality for classifying kmers, scoring + haplotypes, and samplign haplotypes. + + Models: + * Diploid (current) + * Haploid +*/ + /** * A class that creates synthetic haplotypes from a `Haplotypes` representation of * local haplotypes. @@ -447,7 +457,7 @@ class Recombinator { }; /// Creates a new `Recombinator`. - Recombinator(const gbwtgraph::GBZ& gbz, Verbosity verbosity); + Recombinator(const gbwtgraph::GBZ& gbz, const Haplotypes& haplotypes, Verbosity verbosity); /// Parameters for `generate_haplotypes()`. struct Parameters { @@ -484,8 +494,7 @@ class Recombinator { }; /** - * Generates haplotypes based on the given `Haplotypes` representation and - * the kmer counts in the given KFF file. + * Generates haplotypes based on the kmer counts in the given KFF file. * * Runs multiple GBWT construction jobs in parallel using OpenMP threads and * generates the specified number of haplotypes in each top-level chain @@ -500,7 +509,7 @@ class Recombinator { * Throws `std::runtime_error` on error in single-threaded parts and exits * with `std::exit(EXIT_FAILURE)` in multi-threaded parts. */ - gbwt::GBWT generate_haplotypes(const Haplotypes& haplotypes, const std::string& kff_file, const Parameters& parameters) const; + gbwt::GBWT generate_haplotypes(const std::string& kff_file, const Parameters& parameters) const; /// A local haplotype sequence within a single subchain. struct LocalHaplotype { @@ -525,23 +534,23 @@ class Recombinator { * * Throws `std::runtime_error` on error. */ - std::vector classify_kmers(const Haplotypes& haplotypes, const std::string& kff_file, const Parameters& parameters) const; + std::vector classify_kmers(const std::string& kff_file, const Parameters& parameters) const; /** * Extracts the local haplotypes in the given subchain. In addition to the * haplotype sequence, this also reports the name of the corresponding path * as well as (rank, score) for the haplotype in each round of haplotype - * selection. The number of rounds is `parameters.num_haplotyeps`, but if + * selection. The number of rounds is `parameters.num_haplotypes`, but if * the haplotype is selected earlier, it will not get further scores. * * Throws `std::runtime_error` on error. */ std::vector extract_sequences( - const Haplotypes& haplotypes, const std::string& kff_file, - size_t chain_id, size_t subchain_id, const Parameters& parameters + const std::string& kff_file, size_t chain_id, size_t subchain_id, const Parameters& parameters ) const; const gbwtgraph::GBZ& gbz; + const Haplotypes& haplotypes; Verbosity verbosity; private: diff --git a/src/subcommand/giraffe_main.cpp b/src/subcommand/giraffe_main.cpp index ed7fcdbf527..3d91d0a41b4 100644 --- a/src/subcommand/giraffe_main.cpp +++ b/src/subcommand/giraffe_main.cpp @@ -1669,14 +1669,14 @@ string sample_haplotypes(const vector>& indexes, string& ba // Sample haplotypes. Haplotypes::Verbosity verbosity = (progress ? 
Haplotypes::verbosity_basic : Haplotypes::verbosity_silent); - Recombinator recombinator(gbz, verbosity); + Recombinator recombinator(gbz, haplotypes, verbosity); Recombinator::Parameters parameters; parameters.num_haplotypes = Recombinator::NUM_CANDIDATES; parameters.diploid_sampling = true; parameters.include_reference = true; gbwt::GBWT sampled_gbwt; try { - sampled_gbwt = recombinator.generate_haplotypes(haplotypes, kff_file, parameters); + sampled_gbwt = recombinator.generate_haplotypes(kff_file, parameters); } catch (const std::runtime_error& e) { std::cerr << "error:[vg giraffe] Haplotype sampling failed: " << e.what() << std::endl; std::exit(EXIT_FAILURE); diff --git a/src/subcommand/haplotypes_main.cpp b/src/subcommand/haplotypes_main.cpp index 4a26f02c6c2..74ad0b0ceb7 100644 --- a/src/subcommand/haplotypes_main.cpp +++ b/src/subcommand/haplotypes_main.cpp @@ -574,10 +574,10 @@ void validate_subgraph(const gbwtgraph::GBWTGraph& graph, const gbwtgraph::GBWTG void sample_haplotypes(const gbwtgraph::GBZ& gbz, const Haplotypes& haplotypes, const HaplotypesConfig& config) { omp_set_num_threads(threads_to_jobs(config.threads)); - Recombinator recombinator(gbz, config.verbosity); + Recombinator recombinator(gbz, haplotypes, config.verbosity); gbwt::GBWT merged; try { - merged = recombinator.generate_haplotypes(haplotypes, config.kmer_input, config.recombinator_parameters); + merged = recombinator.generate_haplotypes(config.kmer_input, config.recombinator_parameters); } catch (const std::runtime_error& e) { std::cerr << "error: [vg haplotypes] " << e.what() << std::endl; std::exit(EXIT_FAILURE); @@ -896,12 +896,11 @@ void extract_haplotypes(const gbwtgraph::GBZ& gbz, const Haplotypes& haplotypes, std::cerr << "Extracting haplotypes from chain " << config.chain_id << ", subchain " << config.subchain_id << std::endl; } - Recombinator recombinator(gbz, config.verbosity); + Recombinator recombinator(gbz, haplotypes, config.verbosity); std::vector result; try { result = recombinator.extract_sequences( - haplotypes, config.kmer_input, - config.chain_id, config.subchain_id, config.recombinator_parameters + config.kmer_input, config.chain_id, config.subchain_id, config.recombinator_parameters ); } catch (const std::runtime_error& e) { std::cerr << "error: [vg haplotypes] " << e.what() << std::endl; @@ -940,8 +939,8 @@ void classify_kmers(const gbwtgraph::GBZ& gbz, const Haplotypes& haplotypes, con if (config.verbosity >= Haplotypes::verbosity_basic) { std::cerr << "Classifying kmers" << std::endl; } - Recombinator recombinator(gbz, config.verbosity); - std::vector classifications = recombinator.classify_kmers(haplotypes, config.kmer_input, config.recombinator_parameters); + Recombinator recombinator(gbz, haplotypes, config.verbosity); + std::vector classifications = recombinator.classify_kmers(config.kmer_input, config.recombinator_parameters); if (config.verbosity >= Haplotypes::verbosity_basic) { std::cerr << "Writing " << classifications.size() << " classifications to " << config.kmer_output << std::endl; From 0bd1c84aa0383b85fe5a0de86cf84fce43f0f2c1 Mon Sep 17 00:00:00 2001 From: Jouni Siren Date: Tue, 2 Jul 2024 10:23:38 +0200 Subject: [PATCH 5/8] Haploid scoring option for haplotype sampling --- src/recombinator.cpp | 43 +++++++++++++++++++++++++----- src/recombinator.hpp | 24 +++++++++-------- src/subcommand/haplotypes_main.cpp | 12 ++++++--- 3 files changed, 58 insertions(+), 21 deletions(-) diff --git a/src/recombinator.cpp b/src/recombinator.cpp index 9d451dba073..a1bad0063b7 100644 --- 
a/src/recombinator.cpp +++ b/src/recombinator.cpp @@ -293,6 +293,10 @@ HaplotypePartitioner::HaplotypePartitioner(const gbwtgraph::GBZ& gbz, { } +void HaplotypePartitioner::Parameters::print(std::ostream& out) const { + out << "Partitioning parameters: target length " << this->subchain_length << " bp; " << this->approximate_jobs << " jobs" << std::endl; +} + //------------------------------------------------------------------------------ Haplotypes HaplotypePartitioner::partition_haplotypes(const Parameters& parameters) const { @@ -306,6 +310,9 @@ Haplotypes HaplotypePartitioner::partition_haplotypes(const Parameters& paramete std::string msg = "HaplotypePartitioner::partition_haplotypes(): number of jobs cannot be 0"; throw std::runtime_error(msg); } + if (this->verbosity >= Haplotypes::verbosity_detailed) { + parameters.print(std::cerr); + } Haplotypes result; result.header.k = this->minimizer_index.k(); @@ -523,6 +530,7 @@ HaplotypePartitioner::get_subchains(const gbwtgraph::TopLevelChain& chain, const } size_t candidate = this->get_distance(snarls[head].start, snarls[tail + 1].end); if (candidate > parameters.subchain_length) { + // TODO: We need an option for non-greedy boundaries. // Including the next snarl would exceed target length. But if a haplotype visits // the tail in both orientations, it flips the orientation in a subsequent subchain, // returns back, flips again, and eventually continues forward. In such situations, @@ -1087,6 +1095,26 @@ Recombinator::Recombinator(const gbwtgraph::GBZ& gbz, const Haplotypes& haplotyp { } +void Recombinator::Parameters::print(std::ostream& out) const { + out << "Sampling parameters:" << std::endl; + if (this->haploid_scoring) { + out << "- haploid scoring (absent " << this->absent_score << ", present " << this->present_discount << ")" << std::endl; + } else { + out << "- diploid scoring (absent " << this->absent_score << ", het " << this->het_adjustment << ", present " << this->present_discount << ")" << std::endl; + } + if (this->coverage > 0) { + out << "- kmer coverage " << this->coverage << std::endl; + } + if (this->diploid_sampling) { + out << "- diploid sampling (" << this->num_haplotypes << " candidates)" << std::endl; + } else { + out << "- heuristic sampling (" << this->num_haplotypes << " haplotypes)" << std::endl; + } + if (this->include_reference) { + out << "- include reference paths" << std::endl; + } +} + //------------------------------------------------------------------------------ void add_path(const gbwt::GBWT& source, gbwt::size_type path_id, gbwt::GBWTBuilder& builder, gbwtgraph::MetadataBuilder& metadata) { @@ -1162,9 +1190,9 @@ double get_or_estimate_coverage( << ", mode " << statistics.mode; } - // If mode < median, try to find a secondary peak at ~2x mode and use - // it if it is good enough. - if (statistics.mode < statistics.median) { + // In the default (non-haploid) scoring model, if mode < median, we try + // to find a secondary peak at ~2x mode and use it if it is good enough. + if (statistics.mode < statistics.median && !parameters.haploid_scoring) { size_t low = 1.7 * statistics.mode, high = 2.3 * statistics.mode; size_t peak = count_to_frequency[coverage]; size_t best = low, secondary = count_to_frequency[low]; @@ -1202,6 +1230,10 @@ gbwt::GBWT Recombinator::generate_haplotypes(const std::string& kff_file, const // Sanity checks (may throw). 
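
The difference between the two scoring models during coverage estimation can be condensed into a sketch. The 1.7x-2.3x search window comes from the hunk above; the acceptance rule for the secondary peak is simplified here to "at least half as frequent as the primary peak", which is an assumption for illustration rather than the exact condition used in get_or_estimate_coverage:

    #include <cstddef>
    #include <map>

    // Sketch: pick a kmer coverage estimate from a count -> frequency histogram.
    // With haploid scoring, the most common count (the mode) is used directly.
    // In the diploid model, if the mode is below the median, the mode is often
    // the heterozygous peak, so we look for a secondary peak around twice the mode.
    size_t estimate_coverage_sketch(const std::map<size_t, size_t>& count_to_frequency,
                                    size_t mode, size_t median, bool haploid_scoring) {
        if (haploid_scoring || mode >= median) {
            return mode;
        }
        size_t low = 1.7 * mode, high = 2.3 * mode;
        size_t best = mode, best_frequency = 0;
        for (size_t count = low; count <= high; count++) {
            auto iter = count_to_frequency.find(count);
            if (iter != count_to_frequency.end() && iter->second > best_frequency) {
                best = count;
                best_frequency = iter->second;
            }
        }
        // Assumed acceptance rule (illustration only): adopt the secondary peak if
        // it is at least half as frequent as the primary peak at the mode.
        auto primary = count_to_frequency.find(mode);
        size_t primary_frequency = (primary == count_to_frequency.end() ? 0 : primary->second);
        return (2 * best_frequency >= primary_frequency ? best : mode);
    }
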
recombinator_sanity_checks(parameters); + if (this->verbosity >= Haplotypes::verbosity_detailed) { + parameters.print(std::cerr); + } + // Get kmer counts (may throw) and determine coverage. hash_map counts = this->haplotypes.kmer_counts(kff_file, this->verbosity); double coverage = get_or_estimate_coverage(counts, parameters, this->verbosity); @@ -1317,9 +1349,6 @@ std::vector> classify_kmers( double heterozygous_threshold = coverage / std::log(4.0); double homozygous_threshold = coverage * 2.5; - // TODO: -log prob may be the right score once we have enough haplotypes, but - // right now +1 works better, because we don't have haplotypes with the right - // combination of rare kmers. // Determine the type of each kmer in the sample and the score for the kmer. // A haplotype with the kmer gets +1.0 * score, while a haplotype without it // gets -1.0 * score. @@ -1330,7 +1359,7 @@ std::vector> classify_kmers( if (count < absent_threshold) { kmer_types.push_back({ Recombinator::absent, -1.0 * parameters.absent_score }); selected_kmers++; - } else if (count < heterozygous_threshold) { + } else if (count < heterozygous_threshold && !parameters.haploid_scoring) { kmer_types.push_back({ Recombinator::heterozygous, 0.0 }); selected_kmers++; } else if (count < homozygous_threshold) { diff --git a/src/recombinator.hpp b/src/recombinator.hpp index c673e65e53a..90f37bb252a 100644 --- a/src/recombinator.hpp +++ b/src/recombinator.hpp @@ -296,8 +296,11 @@ class HaplotypePartitioner { /// Target length for subchains (in bp). size_t subchain_length = SUBCHAIN_LENGTH; - /// Generate approximately this many jobs. + /// Generate approximately this many jobs. size_t approximate_jobs = APPROXIMATE_JOBS; + + /// Print a description of the parameters. + void print(std::ostream& out) const; }; /** @@ -373,16 +376,6 @@ class HaplotypePartitioner { //------------------------------------------------------------------------------ -/* - TODO: - * Parameters should include the functionality for classifying kmers, scoring - haplotypes, and samplign haplotypes. - - Models: - * Diploid (current) - * Haploid -*/ - /** * A class that creates synthetic haplotypes from a `Haplotypes` representation of * local haplotypes. @@ -460,6 +453,7 @@ class Recombinator { Recombinator(const gbwtgraph::GBZ& gbz, const Haplotypes& haplotypes, Verbosity verbosity); /// Parameters for `generate_haplotypes()`. + /// TODO: We should have a single parameter for the scoring/sampling model. struct Parameters { /// Number of haplotypes to be generated, or the number of candidates /// for diploid sampling. @@ -485,12 +479,20 @@ class Recombinator { /// the wrong variants out. double absent_score = ABSENT_SCORE; + /// Use the haploid scoring model. The most common kmer count is used as + /// the coverage estimate. Kmers that would be classified as heterozygous + /// are treated as homozygous. + bool haploid_scoring = false; + /// After selecting the initial `num_haplotypes` haplotypes, choose the /// highest-scoring pair out of them. bool diploid_sampling = false; /// Include named and reference paths. bool include_reference = false; + + /// Print a description of the parameters. 
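
The classification logic above can be condensed into a sketch that shows what --haploid-scoring changes. The heterozygous and homozygous thresholds are the ones in the hunk; absent_threshold is computed earlier in classify_kmers and is taken as a parameter here, and the class names present and frequent are placeholders for the remaining categories, whose scoring is outside this hunk:

    #include <cmath>

    // Placeholder names for this sketch; only `absent` and `heterozygous` appear
    // in the hunk above.
    enum class KmerClassSketch { absent, heterozygous, present, frequent };

    // How a single kmer count from the sample is classified. With haploid scoring
    // there is no heterozygous class: counts between the absent and homozygous
    // thresholds are treated as present instead.
    KmerClassSketch classify_count_sketch(double count, double coverage,
                                          double absent_threshold, bool haploid_scoring) {
        double heterozygous_threshold = coverage / std::log(4.0);
        double homozygous_threshold = coverage * 2.5;
        if (count < absent_threshold) {
            return KmerClassSketch::absent;        // scored as -1.0 * absent_score
        } else if (count < heterozygous_threshold && !haploid_scoring) {
            return KmerClassSketch::heterozygous;  // scored as 0.0
        } else if (count < homozygous_threshold) {
            return KmerClassSketch::present;
        } else {
            return KmerClassSketch::frequent;
        }
    }

As stated in the comment above, a haplotype that contains the kmer then gets +1.0 times the kmer's score and a haplotype without it gets -1.0 times the score.
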
+ void print(std::ostream& out) const; }; /** diff --git a/src/subcommand/haplotypes_main.cpp b/src/subcommand/haplotypes_main.cpp index 74ad0b0ceb7..361a3ae0a9a 100644 --- a/src/subcommand/haplotypes_main.cpp +++ b/src/subcommand/haplotypes_main.cpp @@ -225,7 +225,8 @@ void help_haplotypes(char** argv, bool developer_options) { std::cerr << " --present-discount F discount scores for present kmers by factor F (default: " << haplotypes_default_discount() << ")" << std::endl; std::cerr << " --het-adjustment F adjust scores for heterozygous kmers by F (default: " << haplotypes_default_adjustment() << ")" << std::endl; std::cerr << " --absent-score F score absent kmers -F/+F (default: " << haplotypes_default_absent() << ")" << std::endl; - std::cerr << " --diploid-sampling choose the best pair from the greedily selected haplotypes" << std::endl; + std::cerr << " --haploid-scoring use a scoring model without heterozygous kmers" << std::endl; + std::cerr << " --diploid-sampling choose the best pair from the sampled haplotypes" << std::endl; std::cerr << " --include-reference include named and reference paths in the output" << std::endl; std::cerr << std::endl; std::cerr << "Other options:" << std::endl; @@ -255,8 +256,9 @@ HaplotypesConfig::HaplotypesConfig(int argc, char** argv, size_t max_threads) { constexpr int OPT_PRESENT_DISCOUNT = 1302; constexpr int OPT_HET_ADJUSTMENT = 1303; constexpr int OPT_ABSENT_SCORE = 1304; - constexpr int OPT_DIPLOID_SAMPLING = 1305; - constexpr int OPT_INCLUDE_REFERENCE = 1306; + constexpr int OPT_HAPLOID_SCORING = 1305; + constexpr int OPT_DIPLOID_SAMPLING = 1306; + constexpr int OPT_INCLUDE_REFERENCE = 1307; constexpr int OPT_VALIDATE = 1400; constexpr int OPT_VCF_INPUT = 1500; constexpr int OPT_CONTIG_PREFIX = 1501; @@ -280,6 +282,7 @@ HaplotypesConfig::HaplotypesConfig(int argc, char** argv, size_t max_threads) { { "present-discount", required_argument, 0, OPT_PRESENT_DISCOUNT }, { "het-adjustment", required_argument, 0, OPT_HET_ADJUSTMENT }, { "absent-score", required_argument, 0, OPT_ABSENT_SCORE }, + { "haploid-scoring", no_argument, 0, OPT_HAPLOID_SCORING }, { "diploid-sampling", no_argument, 0, OPT_DIPLOID_SAMPLING }, { "include-reference", no_argument, 0, OPT_INCLUDE_REFERENCE }, { "verbosity", required_argument, 0, 'v' }, @@ -377,6 +380,9 @@ HaplotypesConfig::HaplotypesConfig(int argc, char** argv, size_t max_threads) { std::exit(EXIT_FAILURE); } break; + case OPT_HAPLOID_SCORING: + this->recombinator_parameters.haploid_scoring = true; + break; case OPT_DIPLOID_SAMPLING: this->recombinator_parameters.diploid_sampling = true; break; From 65aaf718468f2716f2e0ba5250434a8897fbc9d6 Mon Sep 17 00:00:00 2001 From: Jouni Siren Date: Tue, 2 Jul 2024 10:59:42 +0200 Subject: [PATCH 6/8] Make non-greedy subchain boundaries optional --- src/recombinator.cpp | 10 +++++++--- src/recombinator.hpp | 4 ++++ src/subcommand/haplotypes_main.cpp | 6 ++++++ 3 files changed, 17 insertions(+), 3 deletions(-) diff --git a/src/recombinator.cpp b/src/recombinator.cpp index a1bad0063b7..6a25fba32e3 100644 --- a/src/recombinator.cpp +++ b/src/recombinator.cpp @@ -294,7 +294,12 @@ HaplotypePartitioner::HaplotypePartitioner(const gbwtgraph::GBZ& gbz, } void HaplotypePartitioner::Parameters::print(std::ostream& out) const { - out << "Partitioning parameters: target length " << this->subchain_length << " bp; " << this->approximate_jobs << " jobs" << std::endl; + out << "Partitioning parameters:" << std::endl; + out << "- target length " << this->subchain_length << " bp" << std::endl; 
+ if (this->linear_structure) { + out << "- strictly linear structure" << std::endl; + } + out << "- " << this->approximate_jobs << " jobs" << std::endl; } //------------------------------------------------------------------------------ @@ -530,7 +535,6 @@ HaplotypePartitioner::get_subchains(const gbwtgraph::TopLevelChain& chain, const } size_t candidate = this->get_distance(snarls[head].start, snarls[tail + 1].end); if (candidate > parameters.subchain_length) { - // TODO: We need an option for non-greedy boundaries. // Including the next snarl would exceed target length. But if a haplotype visits // the tail in both orientations, it flips the orientation in a subsequent subchain, // returns back, flips again, and eventually continues forward. In such situations, @@ -538,7 +542,7 @@ HaplotypePartitioner::get_subchains(const gbwtgraph::TopLevelChain& chain, const // while sampling maximal haplotypes could make some kmers specific to the next // subchain shared with haplotypes in this subchain. We therefore move forward until // we can make the subchain contain the reversals. - if (this->contains_reversals(snarls[tail].end)) { + if (parameters.linear_structure && this->contains_reversals(snarls[tail].end)) { extra_snarls++; } else { break; diff --git a/src/recombinator.hpp b/src/recombinator.hpp index 90f37bb252a..ae160907b8a 100644 --- a/src/recombinator.hpp +++ b/src/recombinator.hpp @@ -299,6 +299,10 @@ class HaplotypePartitioner { /// Generate approximately this many jobs. size_t approximate_jobs = APPROXIMATE_JOBS; + /// Avoid placing subchain boundaries in places where haplotypes would + /// cross them multiple times. + bool linear_structure = false; + /// Print a description of the parameters. void print(std::ostream& out) const; }; diff --git a/src/subcommand/haplotypes_main.cpp b/src/subcommand/haplotypes_main.cpp index 361a3ae0a9a..1b07e47a741 100644 --- a/src/subcommand/haplotypes_main.cpp +++ b/src/subcommand/haplotypes_main.cpp @@ -217,6 +217,7 @@ void help_haplotypes(char** argv, bool developer_options) { std::cerr << " --kmer-length N kmer length for building the minimizer index (default: " << haplotypes_default_k() << ")" << std::endl; std::cerr << " --window-length N window length for building the minimizer index (default: " << haplotypes_default_w() << ")" << std::endl; std::cerr << " --subchain-length N target length (in bp) for subchains (default: " << haplotypes_default_subchain_length() << ")" << std::endl; + std::cerr << " --linear-structure extend subchains to avoid haplotypes visiting them multiple times" << std::endl; std::cerr << std::endl; std::cerr << "Options for sampling haplotypes:" << std::endl; std::cerr << " --coverage N kmer coverage in the KFF file (default: estimate)" << std::endl; @@ -251,6 +252,7 @@ HaplotypesConfig::HaplotypesConfig(int argc, char** argv, size_t max_threads) { constexpr int OPT_KMER_LENGTH = 1200; constexpr int OPT_WINDOW_LENGTH = 1201; constexpr int OPT_SUBCHAIN_LENGTH = 1202; + constexpr int OPT_LINEAR_STRUCTURE = 1203; constexpr int OPT_COVERAGE = 1300; constexpr int OPT_NUM_HAPLOTYPES = 1301; constexpr int OPT_PRESENT_DISCOUNT = 1302; @@ -277,6 +279,7 @@ HaplotypesConfig::HaplotypesConfig(int argc, char** argv, size_t max_threads) { { "kmer-length", required_argument, 0, OPT_KMER_LENGTH }, { "window-length", required_argument, 0, OPT_WINDOW_LENGTH }, { "subchain-length", required_argument, 0, OPT_SUBCHAIN_LENGTH }, + { "linear-structure", no_argument, 0, OPT_LINEAR_STRUCTURE }, { "coverage", required_argument, 0, OPT_COVERAGE }, { 
"num-haplotypes", required_argument, 0, OPT_NUM_HAPLOTYPES }, { "present-discount", required_argument, 0, OPT_PRESENT_DISCOUNT }, @@ -348,6 +351,9 @@ HaplotypesConfig::HaplotypesConfig(int argc, char** argv, size_t max_threads) { std::exit(EXIT_FAILURE); } break; + case OPT_LINEAR_STRUCTURE: + this->partitioner_parameters.linear_structure = true; + break; case OPT_COVERAGE: this->recombinator_parameters.coverage = parse(optarg); break; From a614f7139416dac62837e5053b69959033c9d257 Mon Sep 17 00:00:00 2001 From: Jouni Siren Date: Tue, 2 Jul 2024 14:52:25 +0200 Subject: [PATCH 7/8] More parameter validation in vg gbwt --- src/subcommand/gbwt_main.cpp | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/subcommand/gbwt_main.cpp b/src/subcommand/gbwt_main.cpp index 94b16945984..db1175b426e 100644 --- a/src/subcommand/gbwt_main.cpp +++ b/src/subcommand/gbwt_main.cpp @@ -1049,7 +1049,12 @@ void validate_gbwt_config(GBWTConfig& config) { } } - if (!config.graph_output.empty()) { + if (config.graph_output.empty()) { + if (config.gbz_format) { + std::cerr << "error: [vg gbwt] GBZ format requires graph output" << std::endl; + std::exit(EXIT_FAILURE); + } + } else { if (!has_graph_input || !one_input_gbwt) { std::cerr << "error: [vg gbwt] GBWTGraph construction requires an input graph and and one input GBWT" << std::endl; std::exit(EXIT_FAILURE); From a018e65527bae13a78b26a0fe201efbb4b4ef672 Mon Sep 17 00:00:00 2001 From: Jouni Siren Date: Wed, 3 Jul 2024 10:08:12 +0200 Subject: [PATCH 8/8] Add presets for haplotype sampling --- src/recombinator.cpp | 13 +++++++++++ src/recombinator.hpp | 13 ++++++++++- src/subcommand/giraffe_main.cpp | 5 +--- src/subcommand/haplotypes_main.cpp | 37 +++++++++++++++++++++++------- test/t/54_vg_haplotypes.t | 10 ++++++-- 5 files changed, 63 insertions(+), 15 deletions(-) diff --git a/src/recombinator.cpp b/src/recombinator.cpp index 6a25fba32e3..a4ababbcb5f 100644 --- a/src/recombinator.cpp +++ b/src/recombinator.cpp @@ -1099,6 +1099,19 @@ Recombinator::Recombinator(const gbwtgraph::GBZ& gbz, const Haplotypes& haplotyp { } +//------------------------------------------------------------------------------ + +Recombinator::Parameters::Parameters(preset_t preset) { + if (preset == preset_haploid) { + this->haploid_scoring = true; + this->include_reference = true; + } else if (preset == preset_diploid) { + this->num_haplotypes = NUM_CANDIDATES; + this->diploid_sampling = true; + this->include_reference = true; + } +} + void Recombinator::Parameters::print(std::ostream& out) const { out << "Sampling parameters:" << std::endl; if (this->haploid_scoring) { diff --git a/src/recombinator.hpp b/src/recombinator.hpp index ae160907b8a..17e6f618b2c 100644 --- a/src/recombinator.hpp +++ b/src/recombinator.hpp @@ -457,7 +457,6 @@ class Recombinator { Recombinator(const gbwtgraph::GBZ& gbz, const Haplotypes& haplotypes, Verbosity verbosity); /// Parameters for `generate_haplotypes()`. - /// TODO: We should have a single parameter for the scoring/sampling model. struct Parameters { /// Number of haplotypes to be generated, or the number of candidates /// for diploid sampling. @@ -495,6 +494,18 @@ class Recombinator { /// Include named and reference paths. bool include_reference = false; + /// Preset parameters for common use cases. + enum preset_t { + /// Default parameters. + preset_default, + /// Best practices for haploid sampling. + preset_haploid, + /// Best practices for diploid sampling. 
+ preset_diploid + }; + + explicit Parameters(preset_t preset = preset_default); + /// Print a description of the parameters. void print(std::ostream& out) const; }; diff --git a/src/subcommand/giraffe_main.cpp b/src/subcommand/giraffe_main.cpp index 3d91d0a41b4..f59fdcd3f53 100644 --- a/src/subcommand/giraffe_main.cpp +++ b/src/subcommand/giraffe_main.cpp @@ -1670,10 +1670,7 @@ string sample_haplotypes(const vector>& indexes, string& ba // Sample haplotypes. Haplotypes::Verbosity verbosity = (progress ? Haplotypes::verbosity_basic : Haplotypes::verbosity_silent); Recombinator recombinator(gbz, haplotypes, verbosity); - Recombinator::Parameters parameters; - parameters.num_haplotypes = Recombinator::NUM_CANDIDATES; - parameters.diploid_sampling = true; - parameters.include_reference = true; + Recombinator::Parameters parameters(Recombinator::Parameters::preset_diploid); gbwt::GBWT sampled_gbwt; try { sampled_gbwt = recombinator.generate_haplotypes(kff_file, parameters); diff --git a/src/subcommand/haplotypes_main.cpp b/src/subcommand/haplotypes_main.cpp index 1b07e47a741..3a3c36a3bac 100644 --- a/src/subcommand/haplotypes_main.cpp +++ b/src/subcommand/haplotypes_main.cpp @@ -220,6 +220,7 @@ void help_haplotypes(char** argv, bool developer_options) { std::cerr << " --linear-structure extend subchains to avoid haplotypes visiting them multiple times" << std::endl; std::cerr << std::endl; std::cerr << "Options for sampling haplotypes:" << std::endl; + std::cerr << " --preset X use preset X (default, haploid, diploid)" << std::endl; std::cerr << " --coverage N kmer coverage in the KFF file (default: estimate)" << std::endl; std::cerr << " --num-haplotypes N generate N haplotypes (default: " << haplotypes_default_n() << ")" << std::endl; std::cerr << " sample from N candidates (with --diploid-sampling; default: " << haplotypes_default_candidates() << ")" << std::endl; @@ -253,14 +254,15 @@ HaplotypesConfig::HaplotypesConfig(int argc, char** argv, size_t max_threads) { constexpr int OPT_WINDOW_LENGTH = 1201; constexpr int OPT_SUBCHAIN_LENGTH = 1202; constexpr int OPT_LINEAR_STRUCTURE = 1203; - constexpr int OPT_COVERAGE = 1300; - constexpr int OPT_NUM_HAPLOTYPES = 1301; - constexpr int OPT_PRESENT_DISCOUNT = 1302; - constexpr int OPT_HET_ADJUSTMENT = 1303; - constexpr int OPT_ABSENT_SCORE = 1304; - constexpr int OPT_HAPLOID_SCORING = 1305; - constexpr int OPT_DIPLOID_SAMPLING = 1306; - constexpr int OPT_INCLUDE_REFERENCE = 1307; + constexpr int OPT_PRESET = 1300; + constexpr int OPT_COVERAGE = 1301; + constexpr int OPT_NUM_HAPLOTYPES = 1302; + constexpr int OPT_PRESENT_DISCOUNT = 1303; + constexpr int OPT_HET_ADJUSTMENT = 1304; + constexpr int OPT_ABSENT_SCORE = 1305; + constexpr int OPT_HAPLOID_SCORING = 1306; + constexpr int OPT_DIPLOID_SAMPLING = 1307; + constexpr int OPT_INCLUDE_REFERENCE = 1308; constexpr int OPT_VALIDATE = 1400; constexpr int OPT_VCF_INPUT = 1500; constexpr int OPT_CONTIG_PREFIX = 1501; @@ -280,6 +282,7 @@ HaplotypesConfig::HaplotypesConfig(int argc, char** argv, size_t max_threads) { { "window-length", required_argument, 0, OPT_WINDOW_LENGTH }, { "subchain-length", required_argument, 0, OPT_SUBCHAIN_LENGTH }, { "linear-structure", no_argument, 0, OPT_LINEAR_STRUCTURE }, + { "preset", required_argument, 0, OPT_PRESET }, { "coverage", required_argument, 0, OPT_COVERAGE }, { "num-haplotypes", required_argument, 0, OPT_NUM_HAPLOTYPES }, { "present-discount", required_argument, 0, OPT_PRESENT_DISCOUNT }, @@ -354,6 +357,24 @@ HaplotypesConfig::HaplotypesConfig(int argc, char** 
argv, size_t max_threads) { case OPT_LINEAR_STRUCTURE: this->partitioner_parameters.linear_structure = true; break; + + case OPT_PRESET: + { + Recombinator::Parameters::preset_t preset; + if (std::string(optarg) == "default") { + preset = Recombinator::Parameters::preset_default; + } else if (std::string(optarg) == "haploid") { + preset = Recombinator::Parameters::preset_haploid; + } else if (std::string(optarg) == "diploid") { + preset = Recombinator::Parameters::preset_diploid; + } else { + std::cerr << "error: [vg haplotypes] unknown preset: " << optarg << std::endl; + std::exit(EXIT_FAILURE); + } + this->recombinator_parameters = Recombinator::Parameters(preset); + num_haplotypes_set = true; // The preset is assumed to include the number of haplotypes. + break; + } case OPT_COVERAGE: this->recombinator_parameters.coverage = parse(optarg); break; diff --git a/test/t/54_vg_haplotypes.t b/test/t/54_vg_haplotypes.t index c7bec93b1c9..888c5246023 100644 --- a/test/t/54_vg_haplotypes.t +++ b/test/t/54_vg_haplotypes.t @@ -5,7 +5,7 @@ BASH_TAP_ROOT=../deps/bash-tap PATH=../bin:$PATH # for vg -plan tests 19 +plan tests 21 # The test graph consists of two subgraphs of the HPRC Minigraph-Cactus v1.1 graph: # - GRCh38#chr6:31498145-31511124 (micb) @@ -45,6 +45,12 @@ is $(vg gbwt -S -Z diploid.gbz) 3 "1 generated + 2 reference samples" is $(vg gbwt -C -Z diploid.gbz) 2 "2 contigs" is $(vg gbwt -H -Z diploid.gbz) 4 "2 generated + 2 reference haplotypes" +# Diploid sampling using a preset +vg haplotypes -i full.hapl -k haplotype-sampling/HG003.kff --preset diploid -g diploid2.gbz full.gbz +is $? 0 "diploid sampling using a preset" +cmp diploid.gbz diploid2.gbz +is $? 0 "the outputs are identical" + # Giraffe integration, guessed output name vg giraffe -Z full.gbz --haplotype-name full.hapl --kff-name haplotype-sampling/HG003.kff \ -f haplotype-sampling/HG003.fq.gz > default.gam 2> /dev/null @@ -63,6 +69,6 @@ is $? 0 "the sampled graphs are identical" # Cleanup rm -r full.gbz full.ri full.dist full.hapl rm -f indirect.gbz direct.gbz no_ref.gbz -rm -f diploid.gbz +rm -f diploid.gbz diploid2.gbz rm -f full.HG003.gbz full.HG003.dist full.HG003.min default.gam rm -f sampled.003HG.gbz sampled.003HG.dist sampled.003HG.min specified.gam
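
For reference, the presets added in the last patch are shorthand for parameter combinations that were previously set by hand. A sketch of the equivalence (this assumes the vg namespace and recombinator.hpp; the function is illustrative only):

    #include "recombinator.hpp"

    void preset_equivalence_example() {
        // --preset diploid, as used by giraffe_main.cpp above...
        Recombinator::Parameters diploid(Recombinator::Parameters::preset_diploid);

        // ...is equivalent to setting the individual options by hand:
        Recombinator::Parameters by_hand;
        by_hand.num_haplotypes = Recombinator::NUM_CANDIDATES;
        by_hand.diploid_sampling = true;
        by_hand.include_reference = true;

        // --preset haploid enables haploid scoring and reference paths and keeps
        // the default number of generated haplotypes.
        Recombinator::Parameters haploid(Recombinator::Parameters::preset_haploid);
    }

On the command line, the same configuration is selected with vg haplotypes --preset diploid, as exercised by the new test above.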