diff --git a/src/main/scala/com/lucidchart/open/relate/SqlQuery.scala b/src/main/scala/com/lucidchart/open/relate/SqlQuery.scala index f5c9e77..fa85869 100644 --- a/src/main/scala/com/lucidchart/open/relate/SqlQuery.scala +++ b/src/main/scala/com/lucidchart/open/relate/SqlQuery.scala @@ -10,6 +10,8 @@ private[relate] case class ExpandableQuery( listParams: mutable.Map[String, ListParam] = mutable.Map[String, ListParam]() ) extends ParameterizableSql with Expandable { + val timeout: Option[Int] = None + val params = Nil protected[relate] def queryParams = QueryParams( query, @@ -17,11 +19,18 @@ private[relate] case class ExpandableQuery( listParams ) + protected def setTimeout(stmt: PreparedStatement): Unit = { + for { + seconds <- timeout + stmt <- Option(stmt) + } yield (stmt.setQueryTimeout(seconds)) + } + def withTimeout(seconds: Int): ExpandableQuery = new ExpandableQuery(query, listParams) { + override val timeout: Option[Int] = Some(seconds) + override def applyParams(stmt: PreparedStatement) { - if (stmt != null) { - stmt.setQueryTimeout(seconds) - } + setTimeout(stmt) super.applyParams(stmt) } diff --git a/src/test/scala/SqlQuerySpec.scala b/src/test/scala/SqlQuerySpec.scala new file mode 100644 index 0000000..004c907 --- /dev/null +++ b/src/test/scala/SqlQuerySpec.scala @@ -0,0 +1,24 @@ +package com.lucidchart.open.relate + +import java.sql.Connection +import java.sql.PreparedStatement +import org.specs2.mutable._ +import org.specs2.mock.Mockito + +class SqlQuerySpec extends Specification with Mockito { + + + "ExpandableQuery.withTimeout" should { + class TestEq extends ExpandableQuery("") + + "set the timeout" in { + val eq = new TestEq().withTimeout(10) + eq.timeout must beSome(10) + } + + "not set the timtout" in { + val eq = new TestEq() + eq.timeout must beNone + } + } +} \ No newline at end of file