Initialize cluster provider / jedis pool only once (#157)
Signed-off-by: khorshuheng <[email protected]>

Co-authored-by: khorshuheng <[email protected]>
khorshuheng and khorshuheng authored Jul 7, 2022
1 parent 96b9336 commit 2be80cd
Showing 5 changed files with 117 additions and 64 deletions.
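The heart of the change: pipeline providers are now created once per Redis endpoint and cached in PipelineProviderFactory, instead of a new cluster connection provider or Jedis client being built for every partition write. Below is a minimal usage sketch, assuming a RedisEndpoint(host, port, password) shaped as the diff implies and that Jedis' binary set command is available on PipelineBinaryCommands; the endpoint values and keys are placeholders, not taken from the commit.

// Hypothetical caller; endpoint values are illustrative only.
val endpoint = RedisEndpoint(host = "localhost", port = 6379, password = "")

// The first call probes the node (CLUSTER INFO) and builds either a
// ClusterPipelineProvider or a SingleNodePipelineProvider; later calls for
// the same endpoint reuse the cached instance.
val provider = PipelineProviderFactory.provider(endpoint)

// Pipelined commands run through withPipeline, which acquires and releases
// the underlying connection around the supplied operations.
val reply = provider.withPipeline(pipeline => pipeline.set("key".getBytes, "value".getBytes))
println(reply.get()) // "OK" once the pipeline has been flushed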
@@ -16,7 +16,8 @@
*/
package feast.ingestion.stores.redis

import redis.clients.jedis.{ClusterPipeline, DefaultJedisClientConfig, HostAndPort}
import redis.clients.jedis.commands.PipelineBinaryCommands
import redis.clients.jedis.{ClusterPipeline, DefaultJedisClientConfig, HostAndPort, Response}
import redis.clients.jedis.providers.ClusterConnectionProvider

import scala.collection.JavaConverters._
@@ -34,9 +35,14 @@ case class ClusterPipelineProvider(endpoint: RedisEndpoint) extends PipelineProv
val provider = new ClusterConnectionProvider(nodes, DEFAULT_CLIENT_CONFIG)

/**
* @return a cluster pipeline
* @return the result of executing the supplied operations on a cluster pipeline
*/
override def pipeline(): UnifiedPipeline = new ClusterPipeline(provider)
override def withPipeline[T](ops: PipelineBinaryCommands => T): T = {
val pipeline = new ClusterPipeline(provider)
val response = ops(pipeline)
pipeline.close()
response
}

/**
* Close client connection
@@ -16,6 +16,7 @@
*/
package feast.ingestion.stores.redis

import redis.clients.jedis.Response
import redis.clients.jedis.commands.PipelineBinaryCommands

import java.io.Closeable
@@ -25,12 +26,7 @@ import java.io.Closeable
*/
trait PipelineProvider {

type UnifiedPipeline = PipelineBinaryCommands with Closeable

/**
* @return an interface for executing pipeline commands
*/
def pipeline(): UnifiedPipeline
def withPipeline[T](ops: PipelineBinaryCommands => T): T

/**
* Close client connection
@@ -0,0 +1,61 @@
/*
* SPDX-License-Identifier: Apache-2.0
* Copyright 2018-2022 The Feast Authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package feast.ingestion.stores.redis

import redis.clients.jedis.Jedis

import scala.collection.mutable
import scala.util.Try

object PipelineProviderFactory {

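// One provider is cached per endpoint, so the cluster connection provider or
// Jedis pool is initialized only once per JVM and shared across partitions.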
private lazy val providers: mutable.Map[RedisEndpoint, PipelineProvider] = mutable.Map.empty

private def newJedisClient(endpoint: RedisEndpoint): Jedis = {
val jedis = new Jedis(endpoint.host, endpoint.port)
if (endpoint.password.nonEmpty) {
jedis.auth(endpoint.password)
}
jedis
}

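// The endpoint is treated as a Redis cluster if CLUSTER INFO succeeds;
// standalone instances reject the command, so the Try fails.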
private def checkIfInClusterMode(endpoint: RedisEndpoint): Boolean = {
val jedis = newJedisClient(endpoint)
val isCluster = Try(jedis.clusterInfo()).isSuccess
jedis.close()
isCluster
}

private def clusterPipelineProvider(endpoint: RedisEndpoint): PipelineProvider = {
ClusterPipelineProvider(endpoint)
}

private def singleNodePipelineProvider(endpoint: RedisEndpoint): PipelineProvider = {
SingleNodePipelineProvider(endpoint)
}

def newProvider(endpoint: RedisEndpoint): PipelineProvider = {
if (checkIfInClusterMode(endpoint)) {
clusterPipelineProvider(endpoint)
} else {
singleNodePipelineProvider(endpoint)
}
}

def provider(endpoint: RedisEndpoint): PipelineProvider = {
providers.getOrElseUpdate(endpoint, newProvider(endpoint))
}
}
@@ -65,23 +65,6 @@ class RedisSinkRelation(override val sqlContext: SQLContext, config: SparkRedisC
pipelineSize = sparkConf.get("spark.redis.properties.pipelineSize").toInt
)

lazy val isClusterMode: Boolean = checkIfInClusterMode(endpoint)

def newJedisClient(endpoint: RedisEndpoint): Jedis = {
val jedis = new Jedis(endpoint.host, endpoint.port)
if (endpoint.password.nonEmpty) {
jedis.auth(endpoint.password)
}
jedis
}

def checkIfInClusterMode(endpoint: RedisEndpoint): Boolean = {
val jedis = newJedisClient(endpoint)
val isCluster = Try(jedis.clusterInfo()).isSuccess
jedis.close()
isCluster
}

override def insert(data: DataFrame, overwrite: Boolean): Unit = {
// repartition for deduplication
val dataToStore =
@@ -95,23 +78,19 @@ class RedisSinkRelation(override val sqlContext: SQLContext, config: SparkRedisC
java.security.Security.setProperty("networkaddress.cache.ttl", "3");
java.security.Security.setProperty("networkaddress.cache.negative.ttl", "0");

val pipelineProvider = if (isClusterMode) {
ClusterPipelineProvider(endpoint)
} else {
SingleNodePipelineProvider(newJedisClient(endpoint))
}
val pipelineProvider = PipelineProviderFactory.provider(endpoint)

// grouped iterator to only allocate memory for a portion of rows
partition.grouped(properties.pipelineSize).foreach { batch =>
// group by key and keep only latest row per each key
val rowsWithKey: Map[String, Row] =
compactRowsToLatestTimestamp(batch.map(row => dataKeyId(row) -> row)).toMap

val keys = rowsWithKey.keysIterator.toList
val readPipeline = pipelineProvider.pipeline()
val readResponses =
keys.map(key => persistence.get(readPipeline, key.getBytes()))
readPipeline.close()
val keys = rowsWithKey.keysIterator.toList
val readResponses = pipelineProvider.withPipeline(pipeline => {
keys.map(key => persistence.get(pipeline, key.getBytes()))
})

val storedValues = readResponses.map(_.get())
val timestamps = storedValues.map(persistence.storedTimestamp)
val timestampByKey = keys.zip(timestamps).toMap
@@ -122,31 +101,30 @@ class RedisSinkRelation(override val sqlContext: SQLContext, config: SparkRedisC
}
.toMap

val writePipeline = pipelineProvider.pipeline()
rowsWithKey.foreach { case (key, row) =>
timestampByKey(key) match {
case Some(t) if (t.after(row.getAs[java.sql.Timestamp](config.timestampColumn))) =>
()
case _ =>
if (metricSource.nonEmpty) {
val lag = System.currentTimeMillis() - row
.getAs[java.sql.Timestamp](config.timestampColumn)
.getTime

metricSource.get.METRIC_TOTAL_ROWS_INSERTED.inc()
metricSource.get.METRIC_ROWS_LAG.update(lag)
}
persistence.save(
writePipeline,
key.getBytes(),
row,
expiryTimestampByKey(key)
)
pipelineProvider.withPipeline(pipeline => {
rowsWithKey.foreach { case (key, row) =>
timestampByKey(key) match {
case Some(t) if (t.after(row.getAs[java.sql.Timestamp](config.timestampColumn))) =>
()
case _ =>
if (metricSource.nonEmpty) {
val lag = System.currentTimeMillis() - row
.getAs[java.sql.Timestamp](config.timestampColumn)
.getTime

metricSource.get.METRIC_TOTAL_ROWS_INSERTED.inc()
metricSource.get.METRIC_ROWS_LAG.update(lag)
}
persistence.save(
pipeline,
key.getBytes(),
row,
expiryTimestampByKey(key)
)
}
}
}
writePipeline.close()
})
}
pipelineProvider.close()
}
dataToStore.unpersist()
}
@@ -16,20 +16,32 @@
*/
package feast.ingestion.stores.redis

import redis.clients.jedis.Jedis
import redis.clients.jedis.commands.PipelineBinaryCommands
import redis.clients.jedis.{JedisPool, Response}

/**
* Provide pipeline for single node Redis.
*/
case class SingleNodePipelineProvider(jedis: Jedis) extends PipelineProvider {
case class SingleNodePipelineProvider(endpoint: RedisEndpoint) extends PipelineProvider {

val jedisPool = new JedisPool(endpoint.host, endpoint.port)

/**
* @return a single node redis pipeline
* @return the result of executing the supplied operations within a pipeline
*/
override def pipeline(): UnifiedPipeline = jedis.pipelined()
override def withPipeline[T](ops: PipelineBinaryCommands => T): T = {
val jedis = jedisPool.getResource
if (endpoint.password.nonEmpty) {
jedis.auth(endpoint.password)
}
val pipeline = jedis.pipelined()
val result = ops(pipeline)
// flush queued commands so their responses are populated before the
// connection is returned to the pool
pipeline.sync()
jedis.close()
result
}

/**
* Close client connection
*/
override def close(): Unit = jedis.close()
override def close(): Unit = jedisPool.close()

}
