Skip to content

Commit

Permalink
Merge pull request #24 from lucidsoftware/query-timeouts
Browse files Browse the repository at this point in the history
add withTimeout method to sql"" and SQL("") queries.
  • Loading branch information
pauldraper committed Jun 22, 2015
2 parents c3ffdba + f4b2d13 commit 1c2f8ab
Show file tree
Hide file tree
Showing 4 changed files with 77 additions and 5 deletions.
29 changes: 24 additions & 5 deletions src/main/scala/com/lucidchart/open/relate/SqlQuery.scala
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,33 @@ private[relate] case class ExpandableQuery(
listParams: mutable.Map[String, ListParam] = mutable.Map[String, ListParam]()
) extends ParameterizableSql with Expandable {

val timeout: Option[Int] = None

val params = Nil
protected[relate] def queryParams = QueryParams(
query,
params,
listParams
)

/**
protected def setTimeout(stmt: PreparedStatement): Unit = {
for {
seconds <- timeout
stmt <- Option(stmt)
} yield (stmt.setQueryTimeout(seconds))
}

def withTimeout(seconds: Int): ExpandableQuery = new ExpandableQuery(query, listParams) {
override val timeout: Option[Int] = Some(seconds)

override def applyParams(stmt: PreparedStatement) {
setTimeout(stmt)

super.applyParams(stmt)
}
}

/**
* The copy method used by the Sql Trait
* Returns a SqlQuery object so that expansion can only occur before the 'on' method
*/
Expand Down Expand Up @@ -258,16 +277,16 @@ trait Sql {
protected val parsedQuery: String
protected def applyParams(stmt: PreparedStatement)

private class BaseStatement(val connection: Connection) {
protected[relate] class BaseStatement(val connection: Connection) {
protected val parsedQuery = self.parsedQuery
protected def applyParams(stmt: PreparedStatement) = self.applyParams(stmt)
}

private def normalStatement(implicit connection: Connection) = new BaseStatement(connection) with NormalStatementPreparer
protected def normalStatement(implicit connection: Connection) = new BaseStatement(connection) with NormalStatementPreparer

private def insertionStatement(implicit connection: Connection) = new BaseStatement(connection) with InsertionStatementPreparer
protected def insertionStatement(implicit connection: Connection) = new BaseStatement(connection) with InsertionStatementPreparer

private def streamedStatement(fetchSize: Int)(implicit connection: Connection) = {
protected def streamedStatement(fetchSize: Int)(implicit connection: Connection) = {
val fetchSize_ = fetchSize
new BaseStatement(connection) with StreamedStatementPreparer {
protected val fetchSize = fetchSize_
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ private[relate] sealed trait StatementPreparer {
private[relate] trait BaseStatementPreparer extends StatementPreparer {
protected def applyParams(stmt: PreparedStatement)
protected def parsedQuery: String
protected def timeout: Option[Int] = None

protected def setTimeout(stmt: PreparedStatement): Unit = for {
seconds <- timeout
stmt <- Option(stmt)
} yield (stmt.setQueryTimeout(seconds))
}

private[relate] trait NormalStatementPreparer extends BaseStatementPreparer {
Expand All @@ -71,6 +77,7 @@ private[relate] trait NormalStatementPreparer extends BaseStatementPreparer {
*/
protected override def prepare(): PreparedStatement = {
val stmt = connection.prepareStatement(parsedQuery)
setTimeout(stmt)
applyParams(stmt)
stmt
}
Expand All @@ -91,6 +98,7 @@ private[relate] trait InsertionStatementPreparer extends BaseStatementPreparer {
*/
protected override def prepare(): PreparedStatement = {
val stmt = connection.prepareStatement(parsedQuery, Statement.RETURN_GENERATED_KEYS)
setTimeout(stmt)
applyParams(stmt)
stmt
}
Expand Down Expand Up @@ -118,6 +126,7 @@ private[relate] trait StreamedStatementPreparer extends BaseStatementPreparer {
ResultSet.TYPE_FORWARD_ONLY,
ResultSet.CONCUR_READ_ONLY
)
setTimeout(stmt)
val driver = connection.getMetaData().getDriverName()
if (driver.toLowerCase.contains("mysql")) {
stmt.setFetchSize(Int.MinValue)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package com.lucidchart.open.relate.interp

import com.lucidchart.open.relate._
import com.lucidchart.open.relate.Sql
import java.sql.Connection
import java.sql.PreparedStatement

class InterpolatedQuery(protected val parsedQuery: String, protected val params: Seq[Parameter]) extends Sql with MultipleParameter {
Expand All @@ -11,6 +13,24 @@ class InterpolatedQuery(protected val parsedQuery: String, protected val params:

def appendPlaceholders(stringBuilder: StringBuilder) = stringBuilder ++= parsedQuery

def withTimeout(seconds: Int): InterpolatedQuery = new InterpolatedQuery(parsedQuery, params) {
override protected def normalStatement(implicit conn: Connection) = new BaseStatement(conn) with NormalStatementPreparer {
override def timeout = Some(seconds)
}

override protected def insertionStatement(implicit conn: Connection) = new BaseStatement(conn) with InsertionStatementPreparer {
override def timeout = Some(seconds)
}

override protected def streamedStatement(fetchSize: Int)(implicit conn: Connection) = {
val fetchSize_ = fetchSize
new BaseStatement(conn) with StreamedStatementPreparer {
protected val fetchSize = fetchSize_
override def timeout = Some(seconds)
}
}
}

}

object InterpolatedQuery {
Expand Down
24 changes: 24 additions & 0 deletions src/test/scala/SqlQuerySpec.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package com.lucidchart.open.relate

import java.sql.Connection
import java.sql.PreparedStatement
import org.specs2.mutable._
import org.specs2.mock.Mockito

class SqlQuerySpec extends Specification with Mockito {


"ExpandableQuery.withTimeout" should {
class TestEq extends ExpandableQuery("")

"set the timeout" in {
val eq = new TestEq().withTimeout(10)
eq.timeout must beSome(10)
}

"not set the timtout" in {
val eq = new TestEq()
eq.timeout must beNone
}
}
}

0 comments on commit 1c2f8ab

Please sign in to comment.