diff --git a/src/main/scala/com/lucidchart/open/relate/SqlQuery.scala b/src/main/scala/com/lucidchart/open/relate/SqlQuery.scala index 2342094..fa85869 100644 --- a/src/main/scala/com/lucidchart/open/relate/SqlQuery.scala +++ b/src/main/scala/com/lucidchart/open/relate/SqlQuery.scala @@ -10,6 +10,8 @@ private[relate] case class ExpandableQuery( listParams: mutable.Map[String, ListParam] = mutable.Map[String, ListParam]() ) extends ParameterizableSql with Expandable { + val timeout: Option[Int] = None + val params = Nil protected[relate] def queryParams = QueryParams( query, @@ -17,7 +19,24 @@ private[relate] case class ExpandableQuery( listParams ) - /** + protected def setTimeout(stmt: PreparedStatement): Unit = { + for { + seconds <- timeout + stmt <- Option(stmt) + } yield (stmt.setQueryTimeout(seconds)) + } + + def withTimeout(seconds: Int): ExpandableQuery = new ExpandableQuery(query, listParams) { + override val timeout: Option[Int] = Some(seconds) + + override def applyParams(stmt: PreparedStatement) { + setTimeout(stmt) + + super.applyParams(stmt) + } + } + + /** * The copy method used by the Sql Trait * Returns a SqlQuery object so that expansion can only occur before the 'on' method */ @@ -258,16 +277,16 @@ trait Sql { protected val parsedQuery: String protected def applyParams(stmt: PreparedStatement) - private class BaseStatement(val connection: Connection) { + protected[relate] class BaseStatement(val connection: Connection) { protected val parsedQuery = self.parsedQuery protected def applyParams(stmt: PreparedStatement) = self.applyParams(stmt) } - private def normalStatement(implicit connection: Connection) = new BaseStatement(connection) with NormalStatementPreparer + protected def normalStatement(implicit connection: Connection) = new BaseStatement(connection) with NormalStatementPreparer - private def insertionStatement(implicit connection: Connection) = new BaseStatement(connection) with InsertionStatementPreparer + protected def insertionStatement(implicit connection: Connection) = new BaseStatement(connection) with InsertionStatementPreparer - private def streamedStatement(fetchSize: Int)(implicit connection: Connection) = { + protected def streamedStatement(fetchSize: Int)(implicit connection: Connection) = { val fetchSize_ = fetchSize new BaseStatement(connection) with StreamedStatementPreparer { protected val fetchSize = fetchSize_ diff --git a/src/main/scala/com/lucidchart/open/relate/StatementPreparer.scala b/src/main/scala/com/lucidchart/open/relate/StatementPreparer.scala index bf92f77..35768ef 100644 --- a/src/main/scala/com/lucidchart/open/relate/StatementPreparer.scala +++ b/src/main/scala/com/lucidchart/open/relate/StatementPreparer.scala @@ -61,6 +61,12 @@ private[relate] sealed trait StatementPreparer { private[relate] trait BaseStatementPreparer extends StatementPreparer { protected def applyParams(stmt: PreparedStatement) protected def parsedQuery: String + protected def timeout: Option[Int] = None + + protected def setTimeout(stmt: PreparedStatement): Unit = for { + seconds <- timeout + stmt <- Option(stmt) + } yield (stmt.setQueryTimeout(seconds)) } private[relate] trait NormalStatementPreparer extends BaseStatementPreparer { @@ -71,6 +77,7 @@ private[relate] trait NormalStatementPreparer extends BaseStatementPreparer { */ protected override def prepare(): PreparedStatement = { val stmt = connection.prepareStatement(parsedQuery) + setTimeout(stmt) applyParams(stmt) stmt } @@ -91,6 +98,7 @@ private[relate] trait InsertionStatementPreparer extends BaseStatementPreparer { */ protected override def prepare(): PreparedStatement = { val stmt = connection.prepareStatement(parsedQuery, Statement.RETURN_GENERATED_KEYS) + setTimeout(stmt) applyParams(stmt) stmt } @@ -118,6 +126,7 @@ private[relate] trait StreamedStatementPreparer extends BaseStatementPreparer { ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY ) + setTimeout(stmt) val driver = connection.getMetaData().getDriverName() if (driver.toLowerCase.contains("mysql")) { stmt.setFetchSize(Int.MinValue) diff --git a/src/main/scala/com/lucidchart/open/relate/interp/InterpolatedQuery.scala b/src/main/scala/com/lucidchart/open/relate/interp/InterpolatedQuery.scala index 477d952..8e262db 100644 --- a/src/main/scala/com/lucidchart/open/relate/interp/InterpolatedQuery.scala +++ b/src/main/scala/com/lucidchart/open/relate/interp/InterpolatedQuery.scala @@ -1,6 +1,8 @@ package com.lucidchart.open.relate.interp +import com.lucidchart.open.relate._ import com.lucidchart.open.relate.Sql +import java.sql.Connection import java.sql.PreparedStatement class InterpolatedQuery(protected val parsedQuery: String, protected val params: Seq[Parameter]) extends Sql with MultipleParameter { @@ -11,6 +13,24 @@ class InterpolatedQuery(protected val parsedQuery: String, protected val params: def appendPlaceholders(stringBuilder: StringBuilder) = stringBuilder ++= parsedQuery + def withTimeout(seconds: Int): InterpolatedQuery = new InterpolatedQuery(parsedQuery, params) { + override protected def normalStatement(implicit conn: Connection) = new BaseStatement(conn) with NormalStatementPreparer { + override def timeout = Some(seconds) + } + + override protected def insertionStatement(implicit conn: Connection) = new BaseStatement(conn) with InsertionStatementPreparer { + override def timeout = Some(seconds) + } + + override protected def streamedStatement(fetchSize: Int)(implicit conn: Connection) = { + val fetchSize_ = fetchSize + new BaseStatement(conn) with StreamedStatementPreparer { + protected val fetchSize = fetchSize_ + override def timeout = Some(seconds) + } + } + } + } object InterpolatedQuery { diff --git a/src/test/scala/SqlQuerySpec.scala b/src/test/scala/SqlQuerySpec.scala new file mode 100644 index 0000000..004c907 --- /dev/null +++ b/src/test/scala/SqlQuerySpec.scala @@ -0,0 +1,24 @@ +package com.lucidchart.open.relate + +import java.sql.Connection +import java.sql.PreparedStatement +import org.specs2.mutable._ +import org.specs2.mock.Mockito + +class SqlQuerySpec extends Specification with Mockito { + + + "ExpandableQuery.withTimeout" should { + class TestEq extends ExpandableQuery("") + + "set the timeout" in { + val eq = new TestEq().withTimeout(10) + eq.timeout must beSome(10) + } + + "not set the timtout" in { + val eq = new TestEq() + eq.timeout must beNone + } + } +} \ No newline at end of file