Skip to content

Commit

Permalink
add withTimeout method to sql"" and SQL("") queries.
Browse files Browse the repository at this point in the history
  • Loading branch information
Gregg Hernandez committed Jun 16, 2015
1 parent 6e255f7 commit bc73aa4
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 5 deletions.
20 changes: 15 additions & 5 deletions src/main/scala/com/lucidchart/open/relate/SqlQuery.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,17 @@ private[relate] case class ExpandableQuery(
listParams
)

/**
def withTimeout(seconds: Int): ExpandableQuery = new ExpandableQuery(query, listParams) {
override def applyParams(stmt: PreparedStatement) {
if (stmt != null) {
stmt.setQueryTimeout(seconds)
}

super.applyParams(stmt)
}
}

/**
* The copy method used by the Sql Trait
* Returns a SqlQuery object so that expansion can only occur before the 'on' method
*/
Expand Down Expand Up @@ -258,16 +268,16 @@ trait Sql {
protected val parsedQuery: String
protected def applyParams(stmt: PreparedStatement)

private class BaseStatement(val connection: Connection) {
protected[relate] class BaseStatement(val connection: Connection) {
protected val parsedQuery = self.parsedQuery
protected def applyParams(stmt: PreparedStatement) = self.applyParams(stmt)
}

private def normalStatement(implicit connection: Connection) = new BaseStatement(connection) with NormalStatementPreparer
protected def normalStatement(implicit connection: Connection) = new BaseStatement(connection) with NormalStatementPreparer

private def insertionStatement(implicit connection: Connection) = new BaseStatement(connection) with InsertionStatementPreparer
protected def insertionStatement(implicit connection: Connection) = new BaseStatement(connection) with InsertionStatementPreparer

private def streamedStatement(fetchSize: Int)(implicit connection: Connection) = {
protected def streamedStatement(fetchSize: Int)(implicit connection: Connection) = {
val fetchSize_ = fetchSize
new BaseStatement(connection) with StreamedStatementPreparer {
protected val fetchSize = fetchSize_
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,12 @@ private[relate] sealed trait StatementPreparer {
private[relate] trait BaseStatementPreparer extends StatementPreparer {
protected def applyParams(stmt: PreparedStatement)
protected def parsedQuery: String
protected def timeout: Option[Int] = None

protected def setTimeout(stmt: PreparedStatement): Unit = for {
seconds <- timeout
stmt <- Option(stmt)
} yield (stmt.setQueryTimeout(seconds))
}

private[relate] trait NormalStatementPreparer extends BaseStatementPreparer {
Expand All @@ -71,6 +77,7 @@ private[relate] trait NormalStatementPreparer extends BaseStatementPreparer {
*/
protected override def prepare(): PreparedStatement = {
val stmt = connection.prepareStatement(parsedQuery)
setTimeout(stmt)
applyParams(stmt)
stmt
}
Expand All @@ -91,6 +98,7 @@ private[relate] trait InsertionStatementPreparer extends BaseStatementPreparer {
*/
protected override def prepare(): PreparedStatement = {
val stmt = connection.prepareStatement(parsedQuery, Statement.RETURN_GENERATED_KEYS)
setTimeout(stmt)
applyParams(stmt)
stmt
}
Expand Down Expand Up @@ -118,6 +126,7 @@ private[relate] trait StreamedStatementPreparer extends BaseStatementPreparer {
ResultSet.TYPE_FORWARD_ONLY,
ResultSet.CONCUR_READ_ONLY
)
setTimeout(stmt)
val driver = connection.getMetaData().getDriverName()
if (driver.toLowerCase.contains("mysql")) {
stmt.setFetchSize(Int.MinValue)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
package com.lucidchart.open.relate.interp

import com.lucidchart.open.relate._
import com.lucidchart.open.relate.Sql
import java.sql.Connection
import java.sql.PreparedStatement

class InterpolatedQuery(protected val parsedQuery: String, protected val params: Seq[Parameter]) extends Sql with MultipleParameter {
Expand All @@ -11,6 +13,24 @@ class InterpolatedQuery(protected val parsedQuery: String, protected val params:

def appendPlaceholders(stringBuilder: StringBuilder) = stringBuilder ++= parsedQuery

def withTimeout(seconds: Int): InterpolatedQuery = new InterpolatedQuery(parsedQuery, params) {
override protected def normalStatement(implicit conn: Connection) = new BaseStatement(conn) with NormalStatementPreparer {
override def timeout = Some(seconds)
}

override protected def insertionStatement(implicit conn: Connection) = new BaseStatement(conn) with InsertionStatementPreparer {
override def timeout = Some(seconds)
}

override protected def streamedStatement(fetchSize: Int)(implicit conn: Connection) = {
val fetchSize_ = fetchSize
new BaseStatement(conn) with StreamedStatementPreparer {
protected val fetchSize = fetchSize_
override def timeout = Some(seconds)
}
}
}

}

object InterpolatedQuery {
Expand Down

0 comments on commit bc73aa4

Please sign in to comment.