[
https://issues.apache.org/jira/browse/SPARK-6932?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel
]
Reynold Xin updated SPARK-6932:
-------------------------------
Target Version/s: (was: 2+)
> A Prototype of Parameter Server
> -------------------------------
>
> Key: SPARK-6932
> URL: https://issues.apache.org/jira/browse/SPARK-6932
> Project: Spark
> Issue Type: New Feature
> Components: ML, MLlib, Spark Core
> Reporter: Qiping Li
>
> h2. Introduction
> As specified in
> [SPARK-4590|https://issues.apache.org/jira/browse/SPARK-4590], it would be
> very helpful to integrate a parameter server into Spark for machine learning
> algorithms, especially for those with ultra-high-dimensional features.
> After carefully studying the design doc of [Parameter
> Servers|https://docs.google.com/document/d/1SX3nkmF41wFXAAIr9BgqvrHSS5mW362fJ7roBXJm06o/edit?usp=sharing]
> and the paper on [Factorbird|http://stanford.edu/~rezab/papers/factorbird.pdf],
> we propose a prototype of a Parameter Server on Spark (PS-on-Spark), with
> several key design concerns:
> * *User friendly interface*
> We carefully investigated most existing Parameter Server
> systems (including [petuum|http://petuum.github.io], [parameter
> server|http://parameterserver.org], and
> [paracel|https://github.com/douban/paracel]) and designed a user-friendly
> interface by absorbing the essence of all these systems.
> * *Prototype of distributed array*
> IndexedRDD (see
> [SPARK-4590|https://issues.apache.org/jira/browse/SPARK-4590]) doesn't seem
> to be a good option for a distributed array, because in most cases the number
> of key updates per second is not very high.
> So we implement a distributed HashMap to store the parameters, which can
> be easily extended for better performance (see the sketch after this list).
>
> * *Minimal code change*
> Quite a lot of effort is put into avoiding changes to Spark core. Tasks
> which need the parameter server are still created and scheduled by Spark's
> scheduler. Tasks communicate with the parameter server through a client
> object, over *akka* or *netty*.
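>
> To make the second concern concrete, below is a minimal, single-node sketch of
> the kind of HashMap-based parameter store we have in mind. The class name
> `PSStore` and its methods are illustrative only, not the actual prototype code;
> thread-safety, serialization and networking are omitted for brevity.
> {code}
> import scala.collection.mutable
>
> // Sketch of a HashMap-backed parameter store. The real store lives inside
> // the parameter server process and would be extended with sharding,
> // serialization and checkpointing.
> class PSStore[T] {
>   private val params = mutable.HashMap.empty[String, T]
>
>   // look up a parameter by key
>   def get(key: String): Option[T] = params.get(key)
>
>   // merge `delta` into the stored value with `reduceFunc`,
>   // or insert it if the key is not present yet
>   def update(key: String, delta: T, reduceFunc: (T, T) => T): Unit = {
>     val merged = params.get(key).map(reduceFunc(_, delta)).getOrElse(delta)
>     params.put(key, merged)
>   }
> }
> {code}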
> With all these concerns in mind, we propose the following architecture:
> h2. Architecture
> !https://cloud.githubusercontent.com/assets/1285855/7158179/f2d25cc4-e3a9-11e4-835e-89681596c478.jpg!
> Data is stored in RDDs and partitioned across workers. During each
> iteration, each worker gets the parameters from the parameter server, then
> computes new parameters based on the old parameters and the data in its
> partition. Finally, each worker pushes the parameter updates back to the
> parameter server. A worker communicates with the parameter server through a
> parameter server client, which is initialized in the `TaskContext` of that
> worker.
> The current implementation is based on YARN cluster mode,
> but it should not be a problem to port it to other modes.
> h3. Interface
> We referred to existing parameter server systems (petuum, parameter server,
> paracel) when designing the interface of the parameter server.
> *`PSClient` provides the following interface for workers to use:*
> {code}
> // get parameter indexed by key from parameter server
> def get[T](key: String): T
> // get multiple parameters from parameter server
> def multiGet[T](keys: Array[String]): Array[T]
> // add `delta` to the parameter indexed by `key`;
> // if there are multiple `delta`s to apply to the same parameter,
> // use `reduceFunc` to reduce these `delta`s first.
> def update[T](key: String, delta: T, reduceFunc: (T, T) => T): Unit
> // update multiple parameters at the same time, using the same `reduceFunc`.
> def multiUpdate[T](keys: Array[String], deltas: Array[T],
>     reduceFunc: (T, T) => T): Unit
>
> // advance clock to indicate that current iteration is finished.
> def clock(): Unit
>
> // block until all workers have reached this line of code.
> def sync(): Unit
> {code}
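>
> For illustration, a typical worker-side loop built on this interface might look
> as follows. This is only a sketch: `computeDelta` and `mergeDeltas` are
> hypothetical helpers, not part of the proposed API.
> {code}
> import org.apache.spark.mllib.regression.LabeledPoint
>
> // Sketch of the per-iteration pattern a worker follows with PSClient.
> def workerLoop(partition: Array[LabeledPoint],
>     client: PSClient,
>     numIterations: Int): Unit = {
>   for (i <- 0 until numIterations) {
>     // 1. pull the current parameters from the parameter server
>     val weights = client.get[Array[Double]]("w")
>     // 2. compute an update from the local partition (hypothetical helper)
>     val delta = computeDelta(partition, weights)
>     // 3. push the update; concurrent deltas are merged with the reduce function
>     client.update("w", delta, mergeDeltas)
>     // 4. signal that this iteration is finished
>     client.clock()
>   }
>   // wait until all workers have finished their last iteration
>   client.sync()
> }
> {code}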
> *`PSContext` provides the following functions for use on the driver:*
> {code}
> // load parameters from an existing RDD of (key, value) pairs.
> def loadPSModel[T](model: RDD[(String, T)]): Unit
> // fetch parameters from the parameter server to construct a model.
> def fetchPSModel[T](keys: Array[String]): Array[T]
> {code}
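>
> Driver-side usage of these two functions could look roughly like the following
> sketch (the key name "w" and the variable names are arbitrary; `sc` is the
> SparkContext and `numFeatures` is the model dimension):
> {code}
> // load an initial model into the parameter server, run training,
> // then fetch the trained parameters back to the driver
> val pssc = new PSContext(sc)
> val initial = sc.parallelize(Seq(("w", Array.fill(numFeatures)(0.0))), 1)
> pssc.loadPSModel(initial)
>
> // ... run parameter server tasks with RDD.runWithPS here ...
>
> val trained = pssc.fetchPSModel[Array[Double]](Array("w"))(0)
> {code}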
>
> *A new function has been added to `RDD` to run parameter server tasks:*
> {code}
> // run the provided `func` on each partition of this RDD.
> // This function can use the data of this partition (the first argument)
> // and a parameter server client (the second argument).
> // See the following Logistic Regression for an example.
> def runWithPS[U: ClassTag](func: (Array[T], PSClient) => U): Array[U]
>
> {code}
> h2. Example
> Here is an example of using our prototype to implement logistic regression:
> {code:title=LogisticRegression.scala|borderStyle=solid}
> def train(
>     sc: SparkContext,
>     input: RDD[LabeledPoint],
>     numIterations: Int,
>     stepSize: Double,
>     miniBatchFraction: Double): LogisticRegressionModel = {
>
>   // initialize weights
>   val numFeatures = input.map(_.features.size).first()
>   val initialWeights = new Array[Double](numFeatures)
>
>   // initialize parameter server context
>   val pssc = new PSContext(sc)
>
>   // load initialized weights into parameter server
>   val initialModelRDD = sc.parallelize(Array(("w", initialWeights)), 1)
>   pssc.loadPSModel(initialModelRDD)
>
>   // run logistic regression algorithm on input data
>   input.runWithPS((arr, client) => {
>     val sampler = new BernoulliSampler[LabeledPoint](miniBatchFraction)
>
>     // for each iteration, compute deltas and update weights
>     for (i <- 0 until numIterations) {
>       // get weights from parameter server
>       val weights = Vectors.dense(client.get[Array[Double]]("w"))
>       sampler.setSeed(i + 42)
>
>       // for each sampled point, compute a delta and update the weights
>       sampler.sample(arr.toIterator).foreach { point =>
>         // compute delta
>         val data = point.features
>         val label = point.label
>         val margin = -1.0 * dot(data, weights)
>         val multiplier = (1.0 / (1.0 + math.exp(margin))) - label
>         val delta = Vectors.dense(new Array[Double](numFeatures))
>         axpy(-stepSize / math.sqrt(i + 1) * multiplier, data, delta)
>
>         // update weights
>         client.update("w", delta.toArray, (d1, d2) => {
>           d1.zip(d2).map { case (a, b) => a + b }
>         })
>       }
>
>       // end of current iteration
>       client.clock()
>     }
>   })
>
>   // fetch weights from parameter server
>   val weights =
>     Vectors.dense(pssc.fetchPSModel[Array[Double]](Array("w"))(0))
>   val intercept = 0.0
>
>   // construct LogisticRegressionModel
>   new LogisticRegressionModel(weights, intercept).clearThreshold()
> }
> {code}
> The above code can be run on the current PS-on-Spark implementation.
> h2. Other considerations
> The current implementation is just a prototype and we will try to improve it
> in the following directions:
> h3. Consistency protocol
> Currently we have implemented only the BSP protocol; SSP consistency will be
> added soon.
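> For reference, a minimal sketch of how the server side could enforce SSP is
> shown below; the `staleness` parameter and the clock-tracking fields are
> assumptions, not part of the current code:
> {code}
> // Sketch of the Stale Synchronous Parallel (SSP) admission check.
> // A worker at clock c may proceed only if the slowest worker's clock
> // is at least c - staleness; otherwise it blocks.
> class SSPController(numWorkers: Int, staleness: Int) {
>   private val clocks = Array.fill(numWorkers)(0)
>
>   def clock(workerId: Int): Unit = synchronized {
>     clocks(workerId) += 1
>     notifyAll() // wake up workers waiting on slower peers
>   }
>
>   def waitForTurn(workerId: Int): Unit = synchronized {
>     // block while this worker is more than `staleness` clocks
>     // ahead of the slowest worker
>     while (clocks(workerId) - clocks.min > staleness) {
>       wait()
>     }
>   }
> }
> {code}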
> h3. Model partition across servers
> Currently all the parameters are stored on a single server. Parameters should
> be partitioned across multiple servers when the parameter size gets large.
> The parameter server client should route requests to different servers
> accordingly.
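> A simple way to route requests once parameters are partitioned is to hash each
> key onto a server; a sketch follows (the server handle type `S` is left
> abstract, since it would be whatever akka/netty endpoint the client holds):
> {code}
> // Sketch: route a parameter key to one of several parameter servers
> // by hashing the key.
> class PSRouter[S](servers: IndexedSeq[S]) {
>   def serverFor(key: String): S = {
>     // non-negative modulo so negative hash codes are handled correctly
>     val idx = ((key.hashCode % servers.size) + servers.size) % servers.size
>     servers(idx)
>   }
> }
> {code}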
> h3. Performance optimization
> To get better performance, the client can cache parameters fetched from the
> parameter server and store updates in an operation log (as petuum does). There
> may be other ways to improve performance.
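> As an illustration of the caching idea, a client could keep a local copy of
> parameters plus a log of pending deltas that is flushed in one batch at each
> clock; a rough sketch (class and method names are assumptions):
> {code}
> import scala.collection.mutable
> import scala.reflect.ClassTag
>
> // Sketch: client-side parameter cache with an operation log.
> // Reads hit the local cache; updates are appended to the log and
> // flushed to the server in one batch at clock() time.
> class CachingClient[T: ClassTag](client: PSClient, reduceFunc: (T, T) => T) {
>   private val cache = mutable.Map.empty[String, T]
>   private val opLog = mutable.Map.empty[String, T]
>
>   def get(key: String): T =
>     cache.getOrElseUpdate(key, client.get[T](key))
>
>   def update(key: String, delta: T): Unit =
>     opLog(key) = opLog.get(key).map(reduceFunc(_, delta)).getOrElse(delta)
>
>   def clock(): Unit = {
>     if (opLog.nonEmpty) {
>       val keys = opLog.keys.toArray
>       val deltas = keys.map(opLog)   // deltas aligned with keys
>       client.multiUpdate(keys, deltas, reduceFunc)
>       opLog.clear()
>     }
>     cache.clear() // parameters may change after the clock boundary
>     client.clock()
>   }
> }
> {code}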
> h3. Fault Recovery
> When a parameter server crashes, it should be restarted on another node. The
> data of a parameter server should be periodically checkpointed so it can be
> transferred to the new node when the server is restarted. When a task is
> restarted, it should not rerun finished iterations.
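> A rough sketch of such a checkpoint, assuming the parameters are kept in a
> simple key/value map and written to a shared file system (the path and the
> serialization format are assumptions):
> {code}
> import java.io.{ObjectInputStream, ObjectOutputStream}
> import org.apache.hadoop.conf.Configuration
> import org.apache.hadoop.fs.{FileSystem, Path}
>
> // periodically write the parameter map so a restarted server can reload it
> def checkpoint(params: Map[String, Array[Double]], dir: String): Unit = {
>   val fs = FileSystem.get(new Configuration())
>   val out = new ObjectOutputStream(fs.create(new Path(dir, "ps-checkpoint")))
>   try out.writeObject(params) finally out.close()
> }
>
> // reload the most recent checkpoint on the node that takes over
> def restore(dir: String): Map[String, Array[Double]] = {
>   val fs = FileSystem.get(new Configuration())
>   val in = new ObjectInputStream(fs.open(new Path(dir, "ps-checkpoint")))
>   try in.readObject().asInstanceOf[Map[String, Array[Double]]] finally in.close()
> }
> {code}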
> We would like to see the parameter server integrated into Spark soon and hope
> this helps other Spark users who need a parameter server. As noted above,
> there is still much work to be done, so any comments are welcome.