From bc73aa4a4c27a78fb4b71a35da516f13d1c4b12f Mon Sep 17 00:00:00 2001 From: Gregg Hernandez Date: Tue, 16 Jun 2015 16:00:27 -0600 Subject: [PATCH] add withTimeout method to sql"" and SQL("") queries. --- .../com/lucidchart/open/relate/SqlQuery.scala | 20 ++++++++++++++----- .../open/relate/StatementPreparer.scala | 9 +++++++++ .../relate/interp/InterpolatedQuery.scala | 20 +++++++++++++++++++ 3 files changed, 44 insertions(+), 5 deletions(-) diff --git a/src/main/scala/com/lucidchart/open/relate/SqlQuery.scala b/src/main/scala/com/lucidchart/open/relate/SqlQuery.scala index 2342094..f5c9e77 100644 --- a/src/main/scala/com/lucidchart/open/relate/SqlQuery.scala +++ b/src/main/scala/com/lucidchart/open/relate/SqlQuery.scala @@ -17,7 +17,17 @@ private[relate] case class ExpandableQuery( listParams ) - /** + def withTimeout(seconds: Int): ExpandableQuery = new ExpandableQuery(query, listParams) { + override def applyParams(stmt: PreparedStatement) { + if (stmt != null) { + stmt.setQueryTimeout(seconds) + } + + super.applyParams(stmt) + } + } + + /** * The copy method used by the Sql Trait * Returns a SqlQuery object so that expansion can only occur before the 'on' method */ @@ -258,16 +268,16 @@ trait Sql { protected val parsedQuery: String protected def applyParams(stmt: PreparedStatement) - private class BaseStatement(val connection: Connection) { + protected[relate] class BaseStatement(val connection: Connection) { protected val parsedQuery = self.parsedQuery protected def applyParams(stmt: PreparedStatement) = self.applyParams(stmt) } - private def normalStatement(implicit connection: Connection) = new BaseStatement(connection) with NormalStatementPreparer + protected def normalStatement(implicit connection: Connection) = new BaseStatement(connection) with NormalStatementPreparer - private def insertionStatement(implicit connection: Connection) = new BaseStatement(connection) with InsertionStatementPreparer + protected def insertionStatement(implicit connection: Connection) = new BaseStatement(connection) with InsertionStatementPreparer - private def streamedStatement(fetchSize: Int)(implicit connection: Connection) = { + protected def streamedStatement(fetchSize: Int)(implicit connection: Connection) = { val fetchSize_ = fetchSize new BaseStatement(connection) with StreamedStatementPreparer { protected val fetchSize = fetchSize_ diff --git a/src/main/scala/com/lucidchart/open/relate/StatementPreparer.scala b/src/main/scala/com/lucidchart/open/relate/StatementPreparer.scala index bf92f77..35768ef 100644 --- a/src/main/scala/com/lucidchart/open/relate/StatementPreparer.scala +++ b/src/main/scala/com/lucidchart/open/relate/StatementPreparer.scala @@ -61,6 +61,12 @@ private[relate] sealed trait StatementPreparer { private[relate] trait BaseStatementPreparer extends StatementPreparer { protected def applyParams(stmt: PreparedStatement) protected def parsedQuery: String + protected def timeout: Option[Int] = None + + protected def setTimeout(stmt: PreparedStatement): Unit = for { + seconds <- timeout + stmt <- Option(stmt) + } yield (stmt.setQueryTimeout(seconds)) } private[relate] trait NormalStatementPreparer extends BaseStatementPreparer { @@ -71,6 +77,7 @@ private[relate] trait NormalStatementPreparer extends BaseStatementPreparer { */ protected override def prepare(): PreparedStatement = { val stmt = connection.prepareStatement(parsedQuery) + setTimeout(stmt) applyParams(stmt) stmt } @@ -91,6 +98,7 @@ private[relate] trait InsertionStatementPreparer extends BaseStatementPreparer { */ protected override def prepare(): PreparedStatement = { val stmt = connection.prepareStatement(parsedQuery, Statement.RETURN_GENERATED_KEYS) + setTimeout(stmt) applyParams(stmt) stmt } @@ -118,6 +126,7 @@ private[relate] trait StreamedStatementPreparer extends BaseStatementPreparer { ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY ) + setTimeout(stmt) val driver = connection.getMetaData().getDriverName() if (driver.toLowerCase.contains("mysql")) { stmt.setFetchSize(Int.MinValue) diff --git a/src/main/scala/com/lucidchart/open/relate/interp/InterpolatedQuery.scala b/src/main/scala/com/lucidchart/open/relate/interp/InterpolatedQuery.scala index 477d952..8e262db 100644 --- a/src/main/scala/com/lucidchart/open/relate/interp/InterpolatedQuery.scala +++ b/src/main/scala/com/lucidchart/open/relate/interp/InterpolatedQuery.scala @@ -1,6 +1,8 @@ package com.lucidchart.open.relate.interp +import com.lucidchart.open.relate._ import com.lucidchart.open.relate.Sql +import java.sql.Connection import java.sql.PreparedStatement class InterpolatedQuery(protected val parsedQuery: String, protected val params: Seq[Parameter]) extends Sql with MultipleParameter { @@ -11,6 +13,24 @@ class InterpolatedQuery(protected val parsedQuery: String, protected val params: def appendPlaceholders(stringBuilder: StringBuilder) = stringBuilder ++= parsedQuery + def withTimeout(seconds: Int): InterpolatedQuery = new InterpolatedQuery(parsedQuery, params) { + override protected def normalStatement(implicit conn: Connection) = new BaseStatement(conn) with NormalStatementPreparer { + override def timeout = Some(seconds) + } + + override protected def insertionStatement(implicit conn: Connection) = new BaseStatement(conn) with InsertionStatementPreparer { + override def timeout = Some(seconds) + } + + override protected def streamedStatement(fetchSize: Int)(implicit conn: Connection) = { + val fetchSize_ = fetchSize + new BaseStatement(conn) with StreamedStatementPreparer { + protected val fetchSize = fetchSize_ + override def timeout = Some(seconds) + } + } + } + } object InterpolatedQuery {