diff --git a/src/main/scala/com/fulcrumgenomics/alignment/Aligner.scala b/src/main/scala/com/fulcrumgenomics/alignment/Aligner.scala
index ad3a3886b..2b3d06de1 100644
--- a/src/main/scala/com/fulcrumgenomics/alignment/Aligner.scala
+++ b/src/main/scala/com/fulcrumgenomics/alignment/Aligner.scala
@@ -50,7 +50,6 @@ object Mode extends FgBioEnum[Mode] {
   override def values: immutable.IndexedSeq[Mode] = findValues
 }
 
-
 object Aligner {
   /** Creates a NW aligner with fixed match and mismatch scores. */
   def apply(matchScore: Int, mismatchScore: Int, gapOpen: Int, gapExtend: Int, mode: Mode = Global): Aligner = {
@@ -87,19 +86,120 @@ object Aligner {
 
   object AlignmentMatrix {
     def apply(direction: Direction, queryLength: Int, targetLength: Int): AlignmentMatrix = {
-      AlignmentMatrix(
+      SimpleAlignmentMatrix(
         direction = direction,
         scoring   = Matrix(queryLength+1, targetLength+1),
         trace     = Matrix(queryLength+1, targetLength+1))
     }
   }
 
+  /** The members common to all alignment matrices: the `Direction`, the scoring and traceback matrices, and the
+    * lengths of the query and target sequences being aligned. */
+  abstract class AlignmentMatrix {
+    def direction: Direction
+    def scoring: Matrix[Int]
+    def trace: Matrix[Direction]
+    def queryLength: Int
+    def targetLength: Int
+  }
+
-  /** A single alignment matrix for a given `Direction` storing both the scoring and traceback matrices produce by the aligner. */
-  case class AlignmentMatrix(direction: Direction, scoring: Matrix[Int], trace: Matrix[Direction]) {
+  /** A single alignment matrix for a given `Direction` storing both the scoring and traceback matrices produced by the aligner. */
+  private case class SimpleAlignmentMatrix(direction: Direction, scoring: Matrix[Int], trace: Matrix[Direction]) extends AlignmentMatrix {
     val queryLength: Int  = scoring.x - 1
     val targetLength: Int = scoring.y - 1
   }
 
+  object CachedAligner {
+
+    /** An alignment matrix whose backing `scoring` and `trace` matrices may be larger than the sequences being
+      * aligned require, so that the same backing matrices can be re-used across alignments.
+      *
+      * As with the matrices built by `AlignmentMatrix.apply`, the backing matrices must have at least
+      * `queryLength + 1` by `targetLength + 1` cells.
+      */
+    case class CachedAlignmentMatrix
+    (direction: Direction,
+     scoring: Matrix[Int],
+     trace: Matrix[Direction],
+     queryLength: Int,
+     targetLength: Int) extends AlignmentMatrix {
+      require(queryLength  < scoring.x, s"queryLength ($queryLength) must be less than scoring.x (${scoring.x})")
+      require(queryLength  < trace.x,   s"queryLength ($queryLength) must be less than trace.x (${trace.x})")
+      require(targetLength < scoring.y, s"targetLength ($targetLength) must be less than scoring.y (${scoring.y})")
+      require(targetLength < trace.y,   s"targetLength ($targetLength) must be less than trace.y (${trace.y})")
+    }
+
+    /** Creates a caching NW aligner with fixed match and mismatch scores. */
+    def apply(matchScore: Int, mismatchScore: Int, gapOpen: Int, gapExtend: Int, mode: Mode = Global): CachedAligner = {
+      val aligner = Aligner(matchScore, mismatchScore, gapOpen, gapExtend, mode)
+      new CachedAligner(scorer=aligner.scorer, mode=aligner.mode)
+    }
+  }
+
+  /** An aligner that keeps hold of its alignment matrices and re-uses them for subsequent alignments, allocating
+    * new (larger) matrices only when a query or target does not fit in the cached ones.
+    *
+    * NOTE: the cached matrices are mutable state shared by all alignments performed with an instance, so an
+    * instance presumably should not be used by multiple threads at the same time.
+    *
+    * @param scorer the scorer passed through to [[Aligner]]
+    * @param useEqualsAndX passed through to [[Aligner]]
+    * @param mode the alignment mode passed through to [[Aligner]]
+    * @param initQueryLength the length of the longest query that fits in the initially allocated matrices
+    * @param initTargetLength the length of the longest target that fits in the initially allocated matrices
+    */
+  class CachedAligner(scorer: AlignmentScorer,
+                      useEqualsAndX: Boolean = true,
+                      mode: Mode = Global,
+                      initQueryLength: Int = 1024,
+                      initTargetLength: Int = 1024) extends Aligner(scorer=scorer, useEqualsAndX=useEqualsAndX, mode=mode) {
+    import CachedAligner.CachedAlignmentMatrix
+
+    /** The cached matrices, one per direction and in the order given by `AllDirections.sorted`. */
+    private var matrices: Array[CachedAlignmentMatrix] = allocateCachedMatrices(
+      xLength      = initQueryLength + 1,
+      yLength      = initTargetLength + 1,
+      queryLength  = initQueryLength,
+      targetLength = initTargetLength
+    )
+
+    /** Returns the cached matrices for the given query and target, first replacing them with larger matrices if
+      * either sequence does not fit. */
+    override protected def allocMatrices(query: Array[Byte], target: Array[Byte]): Array[AlignmentMatrix] = {
+      val current = this.matrices(0)
+      // As in `AlignmentMatrix.apply`, the matrices need one more cell than the sequence length in each dimension
+      val xNeeded = query.length + 1
+      val yNeeded = target.length + 1
+
+      if (xNeeded > current.scoring.x || yNeeded > current.scoring.y) {
+        this.matrices = allocateCachedMatrices(
+          xLength      = math.max(xNeeded, current.scoring.x),
+          yLength      = math.max(yNeeded, current.scoring.y),
+          queryLength  = query.length,
+          targetLength = target.length
+        )
+      } else if (current.queryLength != query.length || current.targetLength != target.length) {
+        // The backing matrices are large enough, so keep them and update only the recorded lengths
+        this.matrices = this.matrices.map(_.copy(queryLength=query.length, targetLength=target.length))
+      }
+
+      this.matrices.map(matrix => matrix: AlignmentMatrix)
+    }
+
+    /** Allocates a new matrix per direction with the given backing dimensions and sequence lengths. */
+    private def allocateCachedMatrices(xLength: Int, yLength: Int, queryLength: Int, targetLength: Int): Array[CachedAlignmentMatrix] = {
+      AllDirections.sorted.map { dir =>
+        CachedAlignmentMatrix(
+          direction    = dir,
+          scoring      = Matrix[Int](xLength, yLength),
+          trace        = Matrix[Direction](xLength, yLength),
+          queryLength  = queryLength,
+          targetLength = targetLength
+        )
+      }.toArray
+    }
+  }
+
   /** Represents a cell within the set of matrices used for alignment. */
   private case class MatrixLocation(queryIndex: Int, targetIndex: Int, direction: Direction)
 
@@ -218,6 +318,12 @@ class Aligner(val scorer: AlignmentScorer,
     locations.map(l => generateAlignment(query, target, matrices, l))
   }
 
+  /** Allocates the alignment matrices, one per direction, for the given query and target.  Sub-classes may
+    * override this method to change how the matrices are obtained. */
+  protected def allocMatrices(query: Array[Byte], target: Array[Byte]): Array[AlignmentMatrix] = {
+    AllDirections.sorted.map(dir => AlignmentMatrix(direction=dir, queryLength=query.length, targetLength=target.length)).toArray
+  }
+
   /**
     * Constructs both the scoring and traceback matrices.
     *
@@ -234,8 +340,8 @@ class Aligner(val scorer: AlignmentScorer,
     * @param target the target sequence
     * @return an array of alignment matrices, where the indices to the array are the Directions
     */
   protected def buildMatrices(query: Array[Byte], target: Array[Byte]): Array[AlignmentMatrix] = {
-    val matrices = AllDirections.sorted.map(dir => AlignmentMatrix(direction=dir, queryLength=query.length, targetLength=target.length)).toArray
+    val matrices = allocMatrices(query=query, target=target)
 
     // While we have `matrices` above, it's useful to unpack all the matrices for direct access
     // in the core loop; when we know the exact matrix we need at compile time, it's faster
@@ -258,7 +364,7 @@ class Aligner(val scorer: AlignmentScorer,
   }
 
   /** Fills in the leftmost column of the matrices. */
-  private final def fillLeftmostColumn(query: Array[Byte],
+  protected final def fillLeftmostColumn(query: Array[Byte],
                                        target: Array[Byte],
                                        leftScoreMatrix: Matrix[Int],
                                        leftTraceMatrix: Matrix[Direction],
@@ -289,7 +395,7 @@ class Aligner(val scorer: AlignmentScorer,
   }
 
   /** Fills in the top row of the matrices. */
-  private final def fillTopRow(query: Array[Byte],
+  protected final def fillTopRow(query: Array[Byte],
                                target: Array[Byte],
                                leftScoreMatrix: Matrix[Int],
                                leftTraceMatrix: Matrix[Direction],
@@ -321,7 +427,7 @@ class Aligner(val scorer: AlignmentScorer,
   }
 
   /** Fills the interior of the matrix. */
-  private final def fillInterior(query: Array[Byte],
+  protected final def fillInterior(query: Array[Byte],
                                  target: Array[Byte],
                                  leftScoreMatrix: Matrix[Int],
                                  leftTraceMatrix: Matrix[Direction],
diff --git a/src/test/scala/com/fulcrumgenomics/alignment/AlignerTest.scala b/src/test/scala/com/fulcrumgenomics/alignment/AlignerTest.scala
index 732a155db..0e5fa7562 100644
--- a/src/test/scala/com/fulcrumgenomics/alignment/AlignerTest.scala
+++ b/src/test/scala/com/fulcrumgenomics/alignment/AlignerTest.scala
@@ -25,6 +25,7 @@
 package com.fulcrumgenomics.alignment
 
 import com.fulcrumgenomics.alignment.Aligner.AlignmentScorer
+import com.fulcrumgenomics.alignment.Aligner.CachedAligner
 import com.fulcrumgenomics.alignment.Mode.{Global, Glocal, Local}
 import com.fulcrumgenomics.commons.util.NumericCounter
 import com.fulcrumgenomics.testing.UnitSpec
@@ -550,4 +551,48 @@ class AlignerTest extends UnitSpec {
 
     System.out.println(s"Run median=${counter.median()}, mean=${counter.mean()}")
   }
+
+  "CachedAligner.align(Global)" should "align two identical sequences with all matches" in {
+    val aligner = CachedAligner(1, -1, -3, -1)
+
+    val result0 = aligner.align(s("ACGTAACC"), s("ACGTAACC"))
+    assertValidGlobalAlignment(result0)
+    result0.cigar.toString() shouldBe "8="
+    result0.score shouldBe 8
+
+    // same sequence
+    val result1 = aligner.align(s("ACGTAACC"), s("ACGTAACC"))
+    assertValidGlobalAlignment(result1)
+    result1.cigar.toString() shouldBe "8="
+    result1.score shouldBe 8
+
+    // longer
+    val result2 = aligner.align(s("ACGTAACCACGTAACC"), s("ACGTAACCACGTAACC"))
+    assertValidGlobalAlignment(result2)
+    result2.cigar.toString() shouldBe "16="
+    result2.score shouldBe 16
+
+    // shorter
+    val result3 = aligner.align(s("ACGT"), s("ACGT"))
+    assertValidGlobalAlignment(result3)
+    result3.cigar.toString() shouldBe "4="
+    result3.score shouldBe 4
+  }
+
+  it should "align sequences whose lengths equal the initial lengths of the cached matrices" in {
+    val scorer  = Aligner(1, -1, -3, -1).scorer
+    val aligner = new CachedAligner(scorer=scorer, initQueryLength=4, initTargetLength=4)
+
+    // exactly the initial lengths
+    val result0 = aligner.align(s("ACGT"), s("ACGT"))
+    assertValidGlobalAlignment(result0)
+    result0.cigar.toString() shouldBe "4="
+    result0.score shouldBe 4
+
+    // one base longer than the cached matrices were last sized for
+    val result1 = aligner.align(s("ACGTA"), s("ACGTA"))
+    assertValidGlobalAlignment(result1)
+    result1.cigar.toString() shouldBe "5="
+    result1.score shouldBe 5
+  }
 }