Merge pull request #51 from tomerk/better-broadcasts
Added a broadcast via spark & switched to a single long-lived spark context
dcrankshaw committed Apr 21, 2015
2 parents ef4e5fb + 12c3ab8 commit 3eff393
Showing 11 changed files with 99 additions and 103 deletions.
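
The substance of the change: rather than each servlet spinning up and tearing down its own SparkContext per request, the application now creates one context at startup and shares it with every servlet and with a new broadcast provider backed by that context. Below is a minimal sketch of that startup wiring, assuming the configuration fields used in the diff (sparkMaster, sparkDataLocation); it is illustrative, not the exact application code.

import org.apache.spark.{SparkConf, SparkContext}
import edu.berkeley.veloxms.storage.SparkVersionedBroadcastProvider

object StartupWiringSketch {
  def main(args: Array[String]): Unit = {
    // Hypothetical values; VeloxApplication reads these from VeloxConfiguration.
    val sparkMaster = "spark://master:7077"
    val sparkDataLocation = "hdfs:///velox"

    val sparkConf = new SparkConf()
      .setMaster(sparkMaster)
      .setAppName("VeloxOnSpark!")
      .setJars(SparkContext.jarOfObject(this).toSeq)

    // One long-lived context for the lifetime of the server process...
    val sparkContext = new SparkContext(sparkConf)
    // ...handed to the retrain/write/load servlets and to the models' broadcasts.
    val broadcastProvider = new SparkVersionedBroadcastProvider(sparkContext, sparkDataLocation)
  }
}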
@@ -11,6 +11,8 @@ import io.dropwizard.setup.Bootstrap
import io.dropwizard.setup.Environment
import com.fasterxml.jackson.annotation.JsonProperty
import javax.validation.constraints.NotNull
import org.apache.spark.{SparkContext, SparkConf}

import scala.collection.mutable
import com.fasterxml.jackson.module.scala.DefaultScalaModule

@@ -54,9 +56,17 @@ class VeloxApplication extends Application[VeloxConfiguration] with Logging {

// this assumes that etcd is running on each velox server
val etcdClient = new EtcdClient(conf.hostname, 4001, conf.hostname, new DispatchUtil)
logInfo("Starting spark context")
val sparkConf = new SparkConf()
.setMaster(conf.sparkMaster)
.setAppName("VeloxOnSpark!")
.setJars(SparkContext.jarOfObject(this).toSeq)

val sparkContext = new SparkContext(sparkConf)
val broadcastProvider = new SparkVersionedBroadcastProvider(sparkContext, conf.sparkDataLocation)

conf.modelFactories.foreach { case (name, modelFactory) => {
val (model, partition, partitionMap) = modelFactory.build(env, name, conf.hostname, etcdClient)
val (model, partition, partitionMap) = modelFactory.build(env, name, conf.hostname, broadcastProvider)

val predictServlet = new PointPredictionServlet(model, env.metrics().timer(name + "/predict/"))
val topKServlet = new TopKPredictionServlet(model, env.metrics().timer(name + "/predict_top_k/"))
@@ -67,12 +77,12 @@ class VeloxApplication extends Application[VeloxConfiguration] with Logging {
val writeHdfsServlet = new WriteToHDFSServlet(
model,
env.metrics().timer(name + "/observe/"),
conf.sparkMaster,
sparkContext,
conf.sparkDataLocation,
partition)
val retrainServlet = new RetrainServlet(
model,
conf.sparkMaster,
sparkContext,
conf.sparkDataLocation,
env.metrics().timer(name + "/retrain/"),
etcdClient,
@@ -81,7 +91,7 @@ class VeloxApplication extends Application[VeloxConfiguration] with Logging {
val loadNewModelServlet = new LoadNewModelServlet(
model,
env.metrics().timer(name + "/loadmodel/"),
conf.sparkMaster,
sparkContext,
conf.sparkDataLocation)
env.getApplicationContext.addServlet(new ServletHolder(predictServlet), "/predict/" + name)
env.getApplicationContext.addServlet(new ServletHolder(topKServlet), "/predict_top_k/" + name)
@@ -1,13 +1,13 @@
package edu.berkeley.veloxms.models

import edu.berkeley.veloxms._
import edu.berkeley.veloxms.util.EtcdClient
import edu.berkeley.veloxms.storage.BroadcastProvider
import org.apache.spark.rdd._
import org.apache.spark.mllib.recommendation.{ALS,Rating}

class MatrixFactorizationModel(
val name: String,
val etcdClient: EtcdClient,
val broadcastProvider: BroadcastProvider,
val numFeatures: Int,
val averageUser: WeightVector,
val cacheResults: Boolean,
31 changes: 10 additions & 21 deletions veloxms-core/src/main/scala/edu/berkeley/veloxms/models/Model.scala
@@ -33,9 +33,9 @@ import scala.util.Sorting
abstract class Model[T:ClassTag, U] extends Logging {

val name: String
val etcdClient: EtcdClient
val broadcastProvider: BroadcastProvider

private var version: Version = new Date(0)
private var version: Version = new Date(0).getTime
def currentVersion: Version = version
def useVersion(version: Version): Unit = {
// TODO: Implement cache invalidation!
@@ -76,10 +76,9 @@ abstract class Model[T:ClassTag, U] extends Logging {
**/
val averageUser: WeightVector

// FIXME: Add some sort of Broadcast provider instead of hardcoding the EtcdBroadcast
val broadcasts = new ConcurrentLinkedQueue[VersionedBroadcast[_]]()
protected def broadcast[V](id: String): VersionedBroadcast[V] = {
val b = new VersionedEtcdBroadcast[V](s"$name/$id", etcdClient)
protected def broadcast[V: ClassTag](id: String): VersionedBroadcast[V] = {
val b = broadcastProvider.get[V](s"$name/$id")
broadcasts.add(b)
b
}
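
With the provider in place, a concrete model obtains versioned broadcasts through the helper above instead of a hard-coded etcd broadcast. A hedged fragment of how a Model subclass might use it; the broadcast id "itemFeatures" and the payload type are invented for illustration.

// Sketch only, inside a Model[T, U] subclass; the id and payload type are illustrative.
private val itemFeatures: VersionedBroadcast[Map[UserID, FeatureVector]] =
  broadcast[Map[UserID, FeatureVector]]("itemFeatures")

def publishFeatures(features: Map[UserID, FeatureVector], nextVersion: Version): Unit = {
  itemFeatures.put(features, nextVersion)  // persisted through the provider under this version
  itemFeatures.cache(nextVersion)          // pre-warm the local copy before serving
}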
@@ -129,21 +128,10 @@ abstract class Model[T:ClassTag, U] extends Logging {

// TODO: probably want to elect a leader to initiate the Spark retraining
// once we are running a Spark cluster
def retrainInSpark(sparkMaster: String, trainingDataDir: String, newModelsDir: String, nextVersion: Version) {
// This is installation specific
val sparkHome = "/root/spark-1.3.0-bin-hadoop1"
logWarning("Starting spark context")
val conf = new SparkConf()
.setMaster(sparkMaster)
.setAppName("VeloxOnSpark!")
.setJars(SparkContext.jarOfObject(this).toSeq)
.setSparkHome(sparkHome)

val sc = new SparkContext(conf)

def retrainInSpark(sparkContext: SparkContext, trainingDataDir: String, newModelsDir: String, nextVersion: Version) {
// TODO: Have to make sure this trainingData contains observations from ALL nodes!!
// TODO: This could be made better
val trainingData: RDD[(UserID, T, Double)] = sc.objectFile(s"$trainingDataDir/*/*")
val trainingData: RDD[(UserID, T, Double)] = sparkContext.objectFile(s"$trainingDataDir/*/*")

val itemFeatures = retrainFeatureModelsInSpark(trainingData, nextVersion)
val userWeights = retrainUserWeightsInSpark(itemFeatures, trainingData).map({
@@ -153,7 +141,6 @@ abstract class Model[T:ClassTag, U] extends Logging {

userWeights.saveAsTextFile(newModelsDir + "/users")

sc.stop()
logInfo("Finished retraining new model")
}

@@ -253,11 +240,13 @@ abstract class Model[T:ClassTag, U] extends Logging {
val partialScoresSum = precomputed.map(_._2).getOrElse(DenseVector.zeros[Double](k))

val allScores: Seq[(T, Double)] = if (newData) {
val scores = observations.putIfAbsent(uid, mutable.Map())
observations.putIfAbsent(uid, mutable.Map())
val scores = observations.get(uid)
scores.put(context, score)
scores.toSeq
} else {
observations.putIfAbsent(uid, mutable.Map()).toSeq
observations.putIfAbsent(uid, mutable.Map())
observations.get(uid).toSeq
}

val newScores: Seq[(T, Double)] = if (precomputed == None) {
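
A note on the two-step putIfAbsent-then-get above: on a Java ConcurrentMap, putIfAbsent returns the previous mapping, which is null when the key was absent, so the old single-step code dereferenced null on a user's first observation. A small sketch of the corrected pattern, assuming observations is a ConcurrentHashMap keyed by user id (the exact field type is not shown in this diff).

import java.util.concurrent.ConcurrentHashMap
import scala.collection.mutable

object PutIfAbsentSketch {
  val observations = new ConcurrentHashMap[Long, mutable.Map[String, Double]]()

  def record(uid: Long, item: String, score: Double): Unit = {
    // putIfAbsent returns the *previous* value (null on first insert),
    // so its result cannot be used as the current mapping.
    observations.putIfAbsent(uid, mutable.Map())
    val scores = observations.get(uid) // non-null after the putIfAbsent above
    scores.put(item, score)
  }
}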
@@ -4,6 +4,7 @@ import javax.validation.constraints.NotNull

import com.fasterxml.jackson.annotation.JsonProperty
import edu.berkeley.veloxms._
import edu.berkeley.veloxms.storage.BroadcastProvider
import edu.berkeley.veloxms.util.{EtcdClient, Logging}
import io.dropwizard.setup.Environment
import org.hibernate.validator.constraints.NotEmpty
@@ -21,7 +22,7 @@ class ModelFactory extends Logging {
val cachePredictions: Boolean = false
val cacheFeatures: Boolean = false

def build(env: Environment, modelName: String, hostname: String, etcdClient: EtcdClient): (Model[_, _], Int, Map[String, Int]) = {
def build(env: Environment, modelName: String, hostname: String, broadcastProvider: BroadcastProvider): (Model[_, _], Int, Map[String, Int]) = {
modelType match {
case "MatrixFactorizationModel" => {
require(partitionFile != "")
@@ -31,7 +32,7 @@
val averageUser = Array.fill[Double](modelSize)(1.0)
val model = new MatrixFactorizationModel(
modelName,
etcdClient,
broadcastProvider,
modelSize,
averageUser,
cachePartialSums,
@@ -54,7 +55,7 @@
val averageUser = Array.fill[Double](modelSize)(1.0)
val model = new NewsgroupsModel(
modelName,
etcdClient,
broadcastProvider,
modelLoc,
modelSize,
averageUser,
@@ -6,6 +6,7 @@ import edu.berkeley.veloxms._
import edu.berkeley.veloxms.pipelines.transformers.{SimpleNGramTokenizer, MulticlassClassifierEvaluator}
import edu.berkeley.veloxms.pipelines._
import edu.berkeley.veloxms.pipelines.estimators.{MulticlassNaiveBayesEstimator, MostFrequentSparseFeatureSelector}
import edu.berkeley.veloxms.storage.BroadcastProvider
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD
import org.apache.spark.{SparkContext, SparkConf}
@@ -15,7 +16,7 @@ import edu.berkeley.veloxms.util.{EtcdClient, Logging}

class NewsgroupsModel(
val name: String,
val etcdClient: EtcdClient,
val broadcastProvider: BroadcastProvider,
val modelLoc: String,
val numFeatures: Int,
val averageUser: WeightVector,
@@ -13,7 +13,7 @@ package object veloxms {
type FeatureVector = Array[Double]
type WeightVector = Array[Double]
type UserID = Long
type Version = Date
type Version = Long

val jsonMapper = new ObjectMapper().registerModule(new DefaultScalaModule)
}
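
Version is now a plain Long rather than java.util.Date; elsewhere in this commit versions are minted as epoch milliseconds and used to build per-version storage paths. A tiny sketch (the model name is illustrative):

import java.util.Date

object VersionSketch {
  // type Version = Long: versions are epoch-millisecond timestamps.
  val nextVersion: Long = new Date().getTime
  val obsDataPath = s"myModel/observations/$nextVersion"
  val newModelPath = s"myModel/retrained_model/$nextVersion"
}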
@@ -3,6 +3,8 @@ package edu.berkeley.veloxms.resources
import java.util.Date
import javax.servlet.http.{HttpServletResponse, HttpServletRequest, HttpServlet}

import org.apache.spark.SparkContext

import scala.concurrent.Await
import scala.concurrent.duration.Duration
import com.codahale.metrics.Timer
@@ -30,7 +32,7 @@ import org.apache.hadoop.fs._

class RetrainServlet(
model: Model[_, _],
sparkMaster: String,
sparkContext: SparkContext,
sparkDataLocation: String,
timer: Timer,
etcdClient: EtcdClient,
@@ -48,9 +50,9 @@ class RetrainServlet(
val lockAcquired = etcdClient.acquireRetrainLock(modelName)

if (lockAcquired) {
val nextVersion = new Date()
val obsDataLocation = HDFSLocation(s"$modelName/observations/${nextVersion.getTime}")
val newModelLocation = LoadModelParameters(s"$modelName/retrained_model/${nextVersion.getTime}", nextVersion)
val nextVersion = new Date().getTime
val obsDataLocation = HDFSLocation(s"$modelName/observations/$nextVersion")
val newModelLocation = LoadModelParameters(s"$modelName/retrained_model/$nextVersion", nextVersion)

val hosts = hostPartitionMap.map({
case(h, _) => host(h, veloxPort).setContentType("application/json", "UTF-8")
@@ -68,7 +70,7 @@
val writeResponses = Await.result(writeResponseFutures, Duration.Inf)
logInfo(s"Write to hdfs responses: ${writeResponses.mkString("\n")}")
model.retrainInSpark(
sparkMaster,
sparkContext,
s"$sparkDataLocation/${obsDataLocation.loc}",
s"$sparkDataLocation/${newModelLocation.userWeightsLoc}",
nextVersion)
@@ -16,7 +16,7 @@ import org.apache.spark.{SparkContext, SparkConf}
case class HDFSLocation(loc: String)
case class LoadModelParameters(userWeightsLoc: String, version: Version)

class WriteToHDFSServlet(model: Model[_, _], timer: Timer, sparkMaster: String, sparkDataLocation: String, partition: Int) extends HttpServlet
class WriteToHDFSServlet(model: Model[_, _], timer: Timer, sparkContext: SparkContext, sparkDataLocation: String, partition: Int) extends HttpServlet
with Logging {

override def doPost(req: HttpServletRequest, resp: HttpServletResponse) {
@@ -25,22 +25,10 @@ class WriteToHDFSServlet(model: Model[_, _], timer: Timer, sparkMaster: String,
val obsLocation = jsonMapper.readValue(req.getInputStream, classOf[HDFSLocation])
val uri = s"$sparkDataLocation/${obsLocation.loc}/part_$partition"

val sparkHome = "/root/spark-1.3.0-bin-hadoop1"
logWarning("Starting spark context")
val sparkConf = new SparkConf()
.setMaster(sparkMaster)
.setAppName("VeloxOnSpark!")
.setJars(SparkContext.jarOfObject(this).toSeq)
.setSparkHome(sparkHome)
// .set("spark.akka.logAkkaConfig", "true")
val sc = new SparkContext(sparkConf)

val observations = model.getObservationsAsRDD(sc)
val observations = model.getObservationsAsRDD(sparkContext)
observations.saveAsObjectFile(uri)

sc.stop()

resp.setContentType("application/json");
resp.setContentType("application/json")
jsonMapper.writeValue(resp.getOutputStream, "success")
} finally {
timeContext.stop()
@@ -50,7 +38,7 @@ class WriteToHDFSServlet(model: Model[_, _], timer: Timer, sparkMaster: String,
}


class LoadNewModelServlet(model: Model[_, _], timer: Timer, sparkMaster: String, sparkDataLocation: String)
class LoadNewModelServlet(model: Model[_, _], timer: Timer, sparkContext: SparkContext, sparkDataLocation: String)
extends HttpServlet with Logging {

override def doPost(req: HttpServletRequest, resp: HttpServletResponse) {
@@ -59,18 +47,8 @@ class LoadNewModelServlet(model: Model[_, _], timer: Timer, sparkMaster: String,
try {
val uri = s"$sparkDataLocation/${modelLocation.userWeightsLoc}"

val sparkHome = "/root/spark-1.3.0-bin-hadoop1"
logWarning("Starting spark context")
val sparkConf = new SparkConf()
.setMaster(sparkMaster)
.setAppName("VeloxOnSpark!")
.setJars(SparkContext.jarOfObject(this).toSeq)
.setSparkHome(sparkHome)

val sc = new SparkContext(sparkConf)

// TODO only add users in this partition: if (userId % partNum == 0)
val users = sc.textFile(s"$uri/users/*").map(line => {
val users = sparkContext.textFile(s"$uri/users/*").map(line => {
val userSplits = line.split(", ")
val userId = userSplits(0).toLong
val userFeatures: Array[Double] = userSplits.drop(1).map(_.toDouble)
@@ -0,0 +1,18 @@
package edu.berkeley.veloxms.storage

import org.apache.spark.SparkContext

import scala.reflect.ClassTag

/**
* This trait provides Versioned Broadcasts.
* TODO: Add HTTP broadcast
* TODO: Add Torrent broadcast
*/
trait BroadcastProvider {
def get[T: ClassTag](id: String): VersionedBroadcast[T]
}

class SparkVersionedBroadcastProvider(sparkContext: SparkContext, path: String) extends BroadcastProvider {
override def get[T: ClassTag](id: String): VersionedBroadcast[T] = new SparkVersionedBroadcast(sparkContext, s"$path/$id")
}
@@ -0,0 +1,36 @@
package edu.berkeley.veloxms.storage

import edu.berkeley.veloxms._
import edu.berkeley.veloxms.util.{EtcdClient, KryoThreadLocal}
import org.apache.spark.{SparkContext, SparkConf}
import sun.misc.{BASE64Decoder, BASE64Encoder}

import scala.collection.mutable
import scala.reflect.ClassTag

/**
* A versioned broadcast that works via reading & writing to a global filesystem via the spark cluster
* @tparam T Has ClassTag because sc.objectFile (to load the broadcast) requires a classtag
*/
class SparkVersionedBroadcast[T: ClassTag](sc: SparkContext, path: String) extends VersionedBroadcast[T] {
private val cachedValues: mutable.Map[Version, T] = mutable.Map()

override def put(value: T, version: Version): Unit = this.synchronized {
sc.parallelize(Seq(value)).saveAsObjectFile(s"$path/$version")
}

override def get(version: Version): Option[T] = this.synchronized {
val out = cachedValues.get(version).orElse(fetch(version))
out.foreach(x => cachedValues.put(version, x))
out
}

override def cache(version: Version): Unit = this.synchronized {
fetch(version).foreach(x => cachedValues.put(version, x))
}

private def fetch(version: Version): Option[T] = {
val location = s"$path/$version"
Some(sc.objectFile(location).first())
}
}
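
Putting the two new files together: a writer publishes a value under a fresh version through the provider, and any reader sharing the same storage path can fetch (and locally cache) it by that version. A hedged, self-contained sketch; the master URL, storage path, model/broadcast ids, and payload are all illustrative.

import org.apache.spark.{SparkConf, SparkContext}
import edu.berkeley.veloxms.storage.SparkVersionedBroadcastProvider

object BroadcastRoundTripSketch {
  def main(args: Array[String]): Unit = {
    // Illustrative settings; in Velox the context and data location come from the application config.
    val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("broadcast-sketch"))
    val provider = new SparkVersionedBroadcastProvider(sc, "hdfs:///velox/broadcasts")

    val featureBroadcast = provider.get[Array[Double]]("demoModel/itemFeatures")
    val version: Long = System.currentTimeMillis()

    featureBroadcast.put(Array(0.1, 0.2, 0.3), version) // saveAsObjectFile under .../demoModel/itemFeatures/<version>
    featureBroadcast.cache(version)                     // optional pre-warm of the in-memory cache
    val restored = featureBroadcast.get(version)        // Some(Array(0.1, 0.2, 0.3)) on any reader of the same path

    sc.stop()
  }
}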