http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/core/src/main/scala/org/apache/predictionio/core/BasePreparator.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/predictionio/core/BasePreparator.scala b/core/src/main/scala/org/apache/predictionio/core/BasePreparator.scala
new file mode 100644
index 0000000..2075bbb
--- /dev/null
+++ b/core/src/main/scala/org/apache/predictionio/core/BasePreparator.scala
@@ -0,0 +1,42 @@
+/** Copyright 2015 TappingStone, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.predictionio.core
+
+import org.apache.predictionio.annotation.DeveloperApi
+import org.apache.spark.SparkContext
+
+/** :: DeveloperApi ::
+ * Base class of all preparator controller classes.
+ *
+ * Dev note: an extra parameter for an ad hoc JSON formatter will probably be
+ * added later.
+ *
+ * @tparam TD Training data class
+ * @tparam PD Prepared data class
+ */
+@DeveloperApi
+abstract class BasePreparator[TD, PD]
+  extends AbstractDoer {
+  /** :: DeveloperApi ::
+   * Engine developers should not use this directly. This is called by the
+   * training workflow to prepare data before handing it over to the algorithm.
+   *
+   * @param sc Spark context
+   * @param td Training data
+   * @return Prepared data
+   */
+  @DeveloperApi
+  def prepareBase(sc: SparkContext, td: TD): PD
+}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/core/src/main/scala/org/apache/predictionio/core/BaseServing.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/predictionio/core/BaseServing.scala b/core/src/main/scala/org/apache/predictionio/core/BaseServing.scala
new file mode 100644
index 0000000..bf1c842
--- /dev/null
+++ b/core/src/main/scala/org/apache/predictionio/core/BaseServing.scala
@@ -0,0 +1,51 @@
+/** Copyright 2015 TappingStone, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.predictionio.core
+
+import org.apache.predictionio.annotation.DeveloperApi
+import org.apache.predictionio.annotation.Experimental
+
+/** :: DeveloperApi ::
+ * Base class of all serving controller classes.
+ *
+ * @tparam Q Query class
+ * @tparam P Predicted result class
+ */
+@DeveloperApi
+abstract class BaseServing[Q, P]
+  extends AbstractDoer {
+  /** :: Experimental ::
+   * Engine developers should not use this directly. This is called by the
+   * serving layer to supplement (pre-process) the query before sending it to
+   * the algorithms.
+   *
+   * @param q Query
+   * @return A supplemented query
+   */
+  @Experimental
+  def supplementBase(q: Q): Q
+
+  /** :: DeveloperApi ::
+   * Engine developers should not use this directly. This is called by the
+   * serving layer to combine multiple predicted results from multiple
+   * algorithms, and to apply custom business logic, before serving to the end
+   * user.
+   *
+   * @param q Query
+   * @param ps List of predicted results
+   * @return A single predicted result
+   */
+  @DeveloperApi
+  def serveBase(q: Q, ps: Seq[P]): P
+}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/core/src/main/scala/org/apache/predictionio/core/package.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/predictionio/core/package.scala b/core/src/main/scala/org/apache/predictionio/core/package.scala
new file mode 100644
index 0000000..0f3098c
--- /dev/null
+++ b/core/src/main/scala/org/apache/predictionio/core/package.scala
@@ -0,0 +1,21 @@
+/** Copyright 2015 TappingStone, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.predictionio
+
+/** Core base classes of PredictionIO controller components. Engine developers
+ * should not use these directly.
+ */
+package object core {}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/core/src/main/scala/org/apache/predictionio/package.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/predictionio/package.scala b/core/src/main/scala/org/apache/predictionio/package.scala
new file mode 100644
index 0000000..7b1989f
--- /dev/null
+++ b/core/src/main/scala/org/apache/predictionio/package.scala
@@ -0,0 +1,19 @@
+/** Copyright 2015 TappingStone, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache
+
+/** PredictionIO Scala API */
+package object predictionio {}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/core/src/main/scala/org/apache/predictionio/workflow/CoreWorkflow.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/predictionio/workflow/CoreWorkflow.scala b/core/src/main/scala/org/apache/predictionio/workflow/CoreWorkflow.scala
new file mode 100644
index 0000000..6a27e87
--- /dev/null
+++ b/core/src/main/scala/org/apache/predictionio/workflow/CoreWorkflow.scala
@@ -0,0 +1,163 @@
+/** Copyright 2015 TappingStone, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.predictionio.workflow
+
+import org.apache.predictionio.controller.EngineParams
+import org.apache.predictionio.controller.Evaluation
+import org.apache.predictionio.core.BaseEngine
+import org.apache.predictionio.core.BaseEvaluator
+import org.apache.predictionio.core.BaseEvaluatorResult
+import org.apache.predictionio.data.storage.EngineInstance
+import org.apache.predictionio.data.storage.EvaluationInstance
+import org.apache.predictionio.data.storage.Model
+import org.apache.predictionio.data.storage.Storage
+
+import com.github.nscala_time.time.Imports.DateTime
+import grizzled.slf4j.Logger
+
+import scala.language.existentials
+
+/** CoreWorkflow handles PredictionIO metadata and environment variables of
+ * training and evaluation.
+ */
+object CoreWorkflow {
+  @transient lazy val logger = Logger[this.type]
+  @transient lazy val engineInstances = Storage.getMetaDataEngineInstances
+  @transient lazy val evaluationInstances =
+    Storage.getMetaDataEvaluationInstances()
+
+  /** Runs a training workflow: trains models with the given engine and
+   * parameters, persists them, and marks the engine instance COMPLETED.
+   *
+   * @param engine Engine to train
+   * @param engineParams Engine parameters to train with
+   * @param engineInstance Metadata record of this training run
+   * @param env Environment variables passed to the Spark execution environment
+   * @param params Workflow-level parameters
+   */
+  def runTrain[EI, Q, P, A](
+    engine: BaseEngine[EI, Q, P, A],
+    engineParams: EngineParams,
+    engineInstance: EngineInstance,
+    env: Map[String, String] = WorkflowUtils.pioEnvVars,
+    params: WorkflowParams = WorkflowParams()) {
+    logger.debug("Starting SparkContext")
+    val mode = "training"
+    WorkflowUtils.checkUpgrade(mode, engineInstance.engineFactory)
+
+    // Fixed: the original interpolation was missing the leading '$' and had an
+    // unbalanced '}', so the app name rendered as the literal text
+    // "{engineInstance.engineFactory} (batch})".
+    val batch = if (params.batch.nonEmpty) {
+      s"${engineInstance.engineFactory} (${params.batch})"
+    } else {
+      engineInstance.engineFactory
+    }
+    val sc = WorkflowContext(
+      batch,
+      env,
+      params.sparkEnv,
+      mode.capitalize)
+
+    try {
+      val models: Seq[Any] = engine.train(
+        sc = sc,
+        engineParams = engineParams,
+        engineInstanceId = engineInstance.id,
+        params = params
+      )
+
+      val kryo = KryoInstantiator.newKryoInjection
+
+      logger.info("Inserting persistent model")
+      Storage.getModelDataModels.insert(Model(
+        id = engineInstance.id,
+        models = kryo(models)))
+
+      logger.info("Updating engine instance")
+      // Reuse the object-level lazy val instead of re-fetching the metadata
+      // client (the original also left an unused `instanceId` local behind).
+      engineInstances.update(engineInstance.copy(
+        status = "COMPLETED",
+        endTime = DateTime.now
+      ))
+
+      logger.info("Training completed successfully.")
+    } catch {
+      case e @ (
+          _: StopAfterReadInterruption |
+          _: StopAfterPrepareInterruption) =>
+        // --stop-after-read / --stop-after-prepare are expected interruptions,
+        // not failures.
+        logger.info(s"Training interrupted by $e.")
+    } finally {
+      logger.debug("Stopping SparkContext")
+      sc.stop()
+    }
+  }
+
+  /** Runs an evaluation workflow and records the evaluator result (unless the
+   * result opts out via `noSave`).
+   *
+   * @param evaluation Evaluation definition
+   * @param engine Engine under evaluation
+   * @param engineParamsList Engine parameter sets to evaluate
+   * @param evaluationInstance Metadata record of this evaluation run
+   * @param evaluator Evaluator to apply
+   * @param env Environment variables passed to the Spark execution environment
+   * @param params Workflow-level parameters
+   */
+  def runEvaluation[EI, Q, P, A, R <: BaseEvaluatorResult](
+    evaluation: Evaluation,
+    engine: BaseEngine[EI, Q, P, A],
+    engineParamsList: Seq[EngineParams],
+    evaluationInstance: EvaluationInstance,
+    evaluator: BaseEvaluator[EI, Q, P, A, R],
+    env: Map[String, String] = WorkflowUtils.pioEnvVars,
+    params: WorkflowParams = WorkflowParams()) {
+    logger.info("runEvaluation started")
+    logger.debug("Start SparkContext")
+
+    val mode = "evaluation"
+
+    WorkflowUtils.checkUpgrade(mode, engine.getClass.getName)
+
+    // Fixed: same broken interpolation as in runTrain (missing '$', extra '}').
+    val batch = if (params.batch.nonEmpty) {
+      s"${evaluation.getClass.getName} (${params.batch})"
+    } else {
+      evaluation.getClass.getName
+    }
+    val sc = WorkflowContext(
+      batch,
+      env,
+      params.sparkEnv,
+      mode.capitalize)
+    val evaluationInstanceId = evaluationInstances.insert(evaluationInstance)
+
+    logger.info(s"Starting evaluation instance ID: $evaluationInstanceId")
+
+    val evaluatorResult: BaseEvaluatorResult = EvaluationWorkflow.runEvaluation(
+      sc,
+      evaluation,
+      engine,
+      engineParamsList,
+      evaluator,
+      params)
+
+    if (evaluatorResult.noSave) {
+      logger.info(s"This evaluation result is not inserted into database: $evaluatorResult")
+    } else {
+      val evaluatedEvaluationInstance = evaluationInstance.copy(
+        status = "EVALCOMPLETED",
+        id = evaluationInstanceId,
+        endTime = DateTime.now,
+        evaluatorResults = evaluatorResult.toOneLiner,
+        evaluatorResultsHTML = evaluatorResult.toHTML,
+        evaluatorResultsJSON = evaluatorResult.toJSON
+      )
+
+      logger.info(s"Updating evaluation instance with result: $evaluatorResult")
+
+      evaluationInstances.update(evaluatedEvaluationInstance)
+    }
+
+    logger.debug("Stop SparkContext")
+
+    sc.stop()
+
+    logger.info("runEvaluation completed")
+  }
+}
+
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/core/src/main/scala/org/apache/predictionio/workflow/CreateServer.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/predictionio/workflow/CreateServer.scala b/core/src/main/scala/org/apache/predictionio/workflow/CreateServer.scala
new file mode 100644
index 0000000..d4f6323
--- /dev/null
+++ b/core/src/main/scala/org/apache/predictionio/workflow/CreateServer.scala
@@ -0,0 +1,737 @@
+/** Copyright 2015
TappingStone, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.predictionio.workflow
+
+import java.io.PrintWriter
+import java.io.Serializable
+import java.io.StringWriter
+import java.util.concurrent.TimeUnit
+
+import akka.actor._
+import akka.event.Logging
+import akka.io.IO
+import akka.pattern.ask
+import akka.util.Timeout
+import com.github.nscala_time.time.Imports.DateTime
+import com.twitter.bijection.Injection
+import com.twitter.chill.KryoBase
+import com.twitter.chill.KryoInjection
+import com.twitter.chill.ScalaKryoInstantiator
+import com.typesafe.config.ConfigFactory
+import de.javakaffee.kryoserializers.SynchronizedCollectionsSerializer
+import grizzled.slf4j.Logging
+import org.apache.predictionio.authentication.KeyAuthentication
+import org.apache.predictionio.configuration.SSLConfiguration
+import org.apache.predictionio.controller.Engine
+import org.apache.predictionio.controller.Params
+import org.apache.predictionio.controller.Utils
+import org.apache.predictionio.controller.WithPrId
+import org.apache.predictionio.core.BaseAlgorithm
+import org.apache.predictionio.core.BaseServing
+import org.apache.predictionio.core.Doer
+import org.apache.predictionio.data.storage.EngineInstance
+import org.apache.predictionio.data.storage.EngineManifest
+import org.apache.predictionio.data.storage.Storage
+import org.apache.predictionio.workflow.JsonExtractorOption.JsonExtractorOption
+import org.json4s._
+import org.json4s.native.JsonMethods._
+import org.json4s.native.Serialization.write
+import spray.can.Http
+import spray.can.server.ServerSettings
+import spray.http.MediaTypes._
+import spray.http._
+import spray.httpx.Json4sSupport
+import spray.routing._
+import spray.routing.authentication.{UserPass, BasicAuth}
+
+import scala.concurrent.ExecutionContext.Implicits.global
+import scala.concurrent.Future
+import scala.concurrent.duration._
+import scala.concurrent.future
+import scala.language.existentials
+import scala.util.Failure
+import scala.util.Random
+import scala.util.Success
+import scalaj.http.HttpOptions
+
+/** Kryo instantiator that pins deserialization to the supplied class loader
+ * (so engine classes loaded from the engine JAR resolve correctly) and
+ * registers serializers for synchronized Java collections.
+ */
+class KryoInstantiator(classLoader: ClassLoader) extends ScalaKryoInstantiator {
+  override def newKryo(): KryoBase = {
+    val kryo = super.newKryo()
+    kryo.setClassLoader(classLoader)
+    SynchronizedCollectionsSerializer.registerSerializers(kryo)
+    kryo
+  }
+}
+
+object KryoInstantiator extends Serializable {
+  /** Injection used to serialize/deserialize trained models to/from bytes. */
+  def newKryoInjection : Injection[Any, Array[Byte]] = {
+    val kryoInstantiator = new KryoInstantiator(getClass.getClassLoader)
+    KryoInjection.instance(kryoInstantiator)
+  }
+}
+
+/** Command-line configuration of the engine server process. Defaults mirror
+ * the documented CLI defaults (port 8000, event server port 7070).
+ */
+case class ServerConfig(
+  batch: String = "",
+  engineInstanceId: String = "",
+  engineId: Option[String] = None,
+  engineVersion: Option[String] = None,
+  engineVariant: String = "",
+  env: Option[String] = None,
+  ip: String = "0.0.0.0",
+  port: Int = 8000,
+  feedback: Boolean = false,
+  eventServerIp: String = "0.0.0.0",
+  eventServerPort: Int = 7070,
+  accessKey: Option[String] = None,
+  logUrl: Option[String] = None,
+  logPrefix: Option[String] = None,
+  logFile: Option[String] = None,
+  verbose: Boolean = false,
+  debug: Boolean = false,
+  jsonExtractor: JsonExtractorOption = JsonExtractorOption.Both)
+
+// Messages exchanged between the master actor and its callers.
+case class StartServer()
+case class BindServer()
+case class StopServer()
+case class ReloadServer()
+case class UpgradeCheck()
+
+
+object CreateServer extends Logging {
+  val actorSystem = ActorSystem("pio-server")
+  val engineInstances =
Storage.getMetaDataEngineInstances
+  val engineManifests = Storage.getMetaDataEngineManifests
+  val modeldata = Storage.getModelDataModels
+
+  /** Entry point of the engine server: parses CLI options, locates the engine
+   * instance and manifest, schedules a daily upgrade check, and starts the
+   * master actor.
+   */
+  def main(args: Array[String]): Unit = {
+    val parser = new scopt.OptionParser[ServerConfig]("CreateServer") {
+      opt[String]("batch") action { (x, c) =>
+        c.copy(batch = x)
+      } text("Batch label of the deployment.")
+      opt[String]("engineId") action { (x, c) =>
+        c.copy(engineId = Some(x))
+      } text("Engine ID.")
+      opt[String]("engineVersion") action { (x, c) =>
+        c.copy(engineVersion = Some(x))
+      } text("Engine version.")
+      opt[String]("engine-variant") required() action { (x, c) =>
+        c.copy(engineVariant = x)
+      } text("Engine variant JSON.")
+      opt[String]("ip") action { (x, c) =>
+        c.copy(ip = x)
+      }
+      opt[String]("env") action { (x, c) =>
+        c.copy(env = Some(x))
+      } text("Comma-separated list of environmental variables (in 'FOO=BAR' " +
+        "format) to pass to the Spark execution environment.")
+      opt[Int]("port") action { (x, c) =>
+        c.copy(port = x)
+      } text("Port to bind to (default: 8000).")
+      opt[String]("engineInstanceId") required() action { (x, c) =>
+        c.copy(engineInstanceId = x)
+      } text("Engine instance ID.")
+      opt[Unit]("feedback") action { (_, c) =>
+        c.copy(feedback = true)
+      } text("Enable feedback loop to event server.")
+      opt[String]("event-server-ip") action { (x, c) =>
+        c.copy(eventServerIp = x)
+      }
+      opt[Int]("event-server-port") action { (x, c) =>
+        c.copy(eventServerPort = x)
+      } text("Event server port. Default: 7070")
+      opt[String]("accesskey") action { (x, c) =>
+        c.copy(accessKey = Some(x))
+      } text("Event server access key.")
+      opt[String]("log-url") action { (x, c) =>
+        c.copy(logUrl = Some(x))
+      }
+      opt[String]("log-prefix") action { (x, c) =>
+        c.copy(logPrefix = Some(x))
+      }
+      opt[String]("log-file") action { (x, c) =>
+        c.copy(logFile = Some(x))
+      }
+      // Fixed: unused lambda parameter `x` replaced with `_` for consistency
+      // with the "feedback" option above.
+      opt[Unit]("verbose") action { (_, c) =>
+        c.copy(verbose = true)
+      } text("Enable verbose output.")
+      opt[Unit]("debug") action { (_, c) =>
+        c.copy(debug = true)
+      } text("Enable debug output.")
+      opt[String]("json-extractor") action { (x, c) =>
+        c.copy(jsonExtractor = JsonExtractorOption.withName(x))
+      }
+    }
+
+    parser.parse(args, ServerConfig()) map { sc =>
+      WorkflowUtils.modifyLogging(sc.verbose)
+      engineInstances.get(sc.engineInstanceId) map { engineInstance =>
+        val engineId = sc.engineId.getOrElse(engineInstance.engineId)
+        val engineVersion = sc.engineVersion.getOrElse(
+          engineInstance.engineVersion)
+        engineManifests.get(engineId, engineVersion) map { manifest =>
+          val engineFactoryName = engineInstance.engineFactory
+          // Schedule a daily upgrade check in the background.
+          val upgrade = actorSystem.actorOf(Props(
+            classOf[UpgradeActor],
+            engineFactoryName))
+          actorSystem.scheduler.schedule(
+            0.seconds,
+            1.days,
+            upgrade,
+            UpgradeCheck())
+          val master = actorSystem.actorOf(Props(
+            classOf[MasterActor],
+            sc,
+            engineInstance,
+            engineFactoryName,
+            manifest),
+            "master")
+          implicit val timeout = Timeout(5.seconds)
+          master ? StartServer()
+          actorSystem.awaitTermination
+        } getOrElse {
+          // Fixed: dropped spurious `s` interpolators on constant strings.
+          error("Invalid engine ID or version. Aborting server.")
+        }
+      } getOrElse {
+        error("Invalid engine instance ID. Aborting server.")
+      }
+    }
+  }
+
+  /** Builds a [[ServerActor]] for a deployable engine: restores persisted
+   * models, prepares them for deployment, and instantiates the configured
+   * algorithm and serving controllers.
+   *
+   * @param sc Server configuration
+   * @param engineInstance Completed engine instance to deploy
+   * @param engine Deployable engine
+   * @param engineLanguage Language the engine was written in
+   * @param manifest Engine manifest
+   * @return Reference to the newly created server actor
+   */
+  def createServerActorWithEngine[TD, EIN, PD, Q, P, A](
+    sc: ServerConfig,
+    engineInstance: EngineInstance,
+    engine: Engine[TD, EIN, PD, Q, P, A],
+    engineLanguage: EngineLanguage.Value,
+    manifest: EngineManifest): ActorRef = {
+
+    val engineParams = engine.engineInstanceToEngineParams(engineInstance, sc.jsonExtractor)
+
+    val kryo = KryoInstantiator.newKryoInjection
+
+    // NOTE(review): the two `.get` calls assume the model record exists and
+    // deserializes cleanly; a missing model fails fast here by design.
+    val modelsFromEngineInstance =
+      kryo.invert(modeldata.get(engineInstance.id).get.models).get.
+        asInstanceOf[Seq[Any]]
+
+    val batch = if (engineInstance.batch.nonEmpty) {
+      s"${engineInstance.engineFactory} (${engineInstance.batch})"
+    } else {
+      engineInstance.engineFactory
+    }
+
+    val sparkContext = WorkflowContext(
+      batch = batch,
+      executorEnv = engineInstance.env,
+      mode = "Serving",
+      sparkEnv = engineInstance.sparkConf)
+
+    val models = engine.prepareDeploy(
+      sparkContext,
+      engineParams,
+      engineInstance.id,
+      modelsFromEngineInstance,
+      params = WorkflowParams()
+    )
+
+    val algorithms = engineParams.algorithmParamsList.map { case (n, p) =>
+      Doer(engine.algorithmClassMap(n), p)
+    }
+
+    val servingParamsWithName = engineParams.servingParams
+
+    val serving = Doer(engine.servingClassMap(servingParamsWithName._1),
+      servingParamsWithName._2)
+
+    actorSystem.actorOf(
+      Props(
+        classOf[ServerActor[Q, P]],
+        sc,
+        engineInstance,
+        engine,
+        engineLanguage,
+        manifest,
+        engineParams.dataSourceParams._2,
+        engineParams.preparatorParams._2,
+        algorithms,
+        engineParams.algorithmParamsList.map(_._2),
+        models,
+        serving,
+        // Consistency: reuse the already-extracted serving params.
+        servingParamsWithName._2))
+  }
+}
+
+/** Periodically checks for PredictionIO upgrades on behalf of a deployment. */
+class UpgradeActor(engineClass: String) extends Actor {
+  val log = Logging(context.system, this)
+  implicit val system = context.system
+  def receive: Actor.Receive = {
+    case x: UpgradeCheck =>
+      WorkflowUtils.checkUpgrade("deployment", engineClass)
+  }
+}
+
+/** Supervises the engine server: binds/unbinds the HTTP listener, undeploys
+ * conflicting servers, and hot-reloads the serving actor.
+ */
+class MasterActor (
+    sc: ServerConfig,
+    engineInstance: EngineInstance,
+    engineFactoryName: String,
+    manifest: EngineManifest)
extends Actor with SSLConfiguration with KeyAuthentication {
+  val log = Logging(context.system, this)
+  implicit val system = context.system
+  var sprayHttpListener: Option[ActorRef] = None
+  var currentServerActor: Option[ActorRef] = None
+  var retry = 3
+
+  /** Asks any engine server already bound at ip:port to stop itself. */
+  def undeploy(ip: String, port: Int): Unit = {
+    val serverUrl = s"https://${ip}:${port}"
+    log.info(
+      s"Undeploying any existing engine instance at $serverUrl")
+    try {
+      val code = scalaj.http.Http(s"$serverUrl/stop")
+        .option(HttpOptions.allowUnsafeSSL)
+        .param(ServerKey.param, ServerKey.get)
+        .method("POST").asString.code
+      code match {
+        // Fixed: `case 200 => Unit` returned the Unit companion object; the
+        // intended no-op is the unit value ().
+        case 200 => ()
+        case 404 => log.error(
+          s"Another process is using $serverUrl. Unable to undeploy.")
+        case _ => log.error(
+          s"Another process is using $serverUrl, or an existing " +
+          s"engine server is not responding properly (HTTP $code). " +
+          "Unable to undeploy.")
+      }
+    } catch {
+      case e: java.net.ConnectException =>
+        // Nothing was deployed there; this is the normal first-deploy path.
+        log.warning(s"Nothing at $serverUrl")
+      // NOTE(review): catching Throwable also swallows fatal errors; consider
+      // scala.util.control.NonFatal here.
+      case _: Throwable =>
+        log.error("Another process might be occupying " +
+          s"$ip:$port. Unable to undeploy.")
+    }
+  }
+
+  def receive: Actor.Receive = {
+    case x: StartServer =>
+      val actor = createServerActor(
+        sc,
+        engineInstance,
+        engineFactoryName,
+        manifest)
+      currentServerActor = Some(actor)
+      undeploy(sc.ip, sc.port)
+      self ! BindServer()
+    case x: BindServer =>
+      currentServerActor map { actor =>
+        val settings = ServerSettings(system)
+        IO(Http) ! Http.Bind(
+          actor,
+          interface = sc.ip,
+          port = sc.port,
+          settings = Some(settings.copy(sslEncryption = true)))
+      } getOrElse {
+        log.error("Cannot bind a non-existing server backend.")
+      }
+    case x: StopServer =>
+      // Fixed: dropped spurious `s` interpolator on a constant string.
+      log.info("Stop server command received.")
+      sprayHttpListener.map { l =>
+        log.info("Server is shutting down.")
+        l ! Http.Unbind(5.seconds)
+        system.shutdown
+      } getOrElse {
+        log.warning("No active server is running.")
+      }
+    case x: ReloadServer =>
+      log.info("Reload server command received.")
+      val latestEngineInstance =
+        CreateServer.engineInstances.getLatestCompleted(
+          manifest.id,
+          manifest.version,
+          engineInstance.engineVariant)
+      latestEngineInstance map { lr =>
+        val actor = createServerActor(sc, lr, engineFactoryName, manifest)
+        sprayHttpListener.map { l =>
+          l ! Http.Unbind(5.seconds)
+          val settings = ServerSettings(system)
+          IO(Http) ! Http.Bind(
+            actor,
+            interface = sc.ip,
+            port = sc.port,
+            settings = Some(settings.copy(sslEncryption = true)))
+          currentServerActor.get ! Kill
+          currentServerActor = Some(actor)
+        } getOrElse {
+          log.warning("No active server is running. Abort reloading.")
+        }
+      } getOrElse {
+        log.warning(
+          s"No latest completed engine instance for ${manifest.id} " +
+          s"${manifest.version}. Abort reloading.")
+      }
+    case x: Http.Bound =>
+      val serverUrl = s"https://${sc.ip}:${sc.port}"
+      log.info(s"Engine is deployed and running. Engine API is live at ${serverUrl}.")
+      // Fixed: use the sender() accessor rather than the deprecated bare
+      // `sender` reference.
+      sprayHttpListener = Some(sender())
+    case x: Http.CommandFailed =>
+      if (retry > 0) {
+        retry -= 1
+        log.error(s"Bind failed. Retrying... ($retry more trial(s))")
+        context.system.scheduler.scheduleOnce(1.seconds) {
+          self ! BindServer()
+        }
+      } else {
+        log.error("Bind failed. Shutting down.")
+        system.shutdown
+      }
+  }
+
+  /** Instantiates the engine factory and builds a server actor for it. */
+  def createServerActor(
+    sc: ServerConfig,
+    engineInstance: EngineInstance,
+    engineFactoryName: String,
+    manifest: EngineManifest): ActorRef = {
+    val (engineLanguage, engineFactory) =
+      WorkflowUtils.getEngine(engineFactoryName, getClass.getClassLoader)
+    val engine = engineFactory()
+
+    // EngineFactory returns a base engine, which may not be deployable.
+    if (!engine.isInstanceOf[Engine[_,_,_,_,_,_]]) {
+      throw new NoSuchMethodException(s"Engine $engine is not deployable")
+    }
+
+    val deployableEngine = engine.asInstanceOf[Engine[_,_,_,_,_,_]]
+
+    CreateServer.createServerActorWithEngine(
+      sc,
+      engineInstance,
+      deployableEngine,
+      engineLanguage,
+      manifest)
+  }
+}
+
+/** HTTP-facing actor that serves queries, the status page, and plugin
+ * endpoints for one deployed engine instance.
+ */
+class ServerActor[Q, P](
+    val args: ServerConfig,
+    val engineInstance: EngineInstance,
+    val engine: Engine[_, _, _, Q, P, _],
+    val engineLanguage: EngineLanguage.Value,
+    val manifest: EngineManifest,
+    val dataSourceParams: Params,
+    val preparatorParams: Params,
+    val algorithms: Seq[BaseAlgorithm[_, _, Q, P]],
+    val algorithmsParams: Seq[Params],
+    val models: Seq[Any],
+    val serving: BaseServing[Q, P],
+    val servingParams: Params) extends Actor with HttpService with KeyAuthentication {
+  val serverStartTime = DateTime.now
+  val log = Logging(context.system, this)
+
+  // Simple in-actor serving statistics shown on the status page.
+  var requestCount: Int = 0
+  var avgServingSec: Double = 0.0
+  var lastServingSec: Double = 0.0
+
+  /** The following is required by HttpService */
+  def actorRefFactory: ActorContext = context
+
+  implicit val timeout = Timeout(5, TimeUnit.SECONDS)
+  val pluginsActorRef =
+    context.actorOf(Props(classOf[PluginsActor], args.engineVariant), "PluginsActor")
+  val pluginContext = EngineServerPluginContext(log, args.engineVariant)
+
+  def receive: Actor.Receive = runRoute(myRoute)
+
+  // Feedback requires an access key; refuse to enable it otherwise.
+  val feedbackEnabled = if (args.feedback) {
+    if (args.accessKey.isEmpty) {
+      log.error("Feedback loop cannot be enabled because accessKey is empty.")
+      false
+    } else {
+      true
+    }
+  } else false
+
+  /** Best-effort POST of a message to a remote log collector. */
+  def remoteLog(logUrl: String, logPrefix: String, message: String): Unit = {
+    implicit val formats = Utils.json4sDefaultFormats
+    try {
+      scalaj.http.Http(logUrl).postData(
+        logPrefix + write(Map(
+          "engineInstance" -> engineInstance,
+          "message" -> message))).asString
+    } catch {
+      case e: Throwable =>
+        log.error(s"Unable to send remote log: ${e.getMessage}")
+    }
+  }
+
+  def
getStackTraceString(e: Throwable): String = {
+    val writer = new StringWriter()
+    val printWriter = new PrintWriter(writer)
+    e.printStackTrace(printWriter)
+    writer.toString
+  }
+
+  // Route table: status page, query endpoint, admin endpoints, plugins.
+  val myRoute =
+    path("") {
+      get {
+        respondWithMediaType(`text/html`) {
+          detach() {
+            complete {
+              html.index(
+                args,
+                manifest,
+                engineInstance,
+                algorithms.map(_.toString),
+                algorithmsParams.map(_.toString),
+                models.map(_.toString),
+                dataSourceParams.toString,
+                preparatorParams.toString,
+                servingParams.toString,
+                serverStartTime,
+                feedbackEnabled,
+                args.eventServerIp,
+                args.eventServerPort,
+                requestCount,
+                avgServingSec,
+                lastServingSec
+              ).toString
+            }
+          }
+        }
+      }
+    } ~
+    path("queries.json") {
+      post {
+        detach() {
+          entity(as[String]) { queryString =>
+            try {
+              val servingStartTime = DateTime.now
+              val jsonExtractorOption = args.jsonExtractor
+              val queryTime = DateTime.now
+              // Extract Query from Json
+              val query = JsonExtractor.extract(
+                jsonExtractorOption,
+                queryString,
+                algorithms.head.queryClass,
+                algorithms.head.querySerializer,
+                algorithms.head.gsonTypeAdapterFactories
+              )
+              val queryJValue = JsonExtractor.toJValue(
+                jsonExtractorOption,
+                query,
+                algorithms.head.querySerializer,
+                algorithms.head.gsonTypeAdapterFactories)
+              // Deploy logic. First call Serving.supplement, then Algo.predict,
+              // finally Serving.serve.
+              val supplementedQuery = serving.supplementBase(query)
+              // TODO: Parallelize the following.
+              val predictions = algorithms.zipWithIndex.map { case (a, ai) =>
+                a.predictBase(models(ai), supplementedQuery)
+              }
+              // Notice that it is by design to call Serving.serve with the
+              // *original* query.
+              val prediction = serving.serveBase(query, predictions)
+              // NOTE(review): prediction is serialized with the head
+              // algorithm's querySerializer — confirm this covers the
+              // prediction class as well.
+              val predictionJValue = JsonExtractor.toJValue(
+                jsonExtractorOption,
+                prediction,
+                algorithms.head.querySerializer,
+                algorithms.head.gsonTypeAdapterFactories)
+              /** Handle feedback to Event Server
+                * Send the following back to the Event Server
+                * - appId
+                * - engineInstanceId
+                * - query
+                * - prediction
+                * - prId
+                */
+              val result = if (feedbackEnabled) {
+                implicit val formats =
+                  algorithms.headOption map { alg =>
+                    alg.querySerializer
+                  } getOrElse {
+                    Utils.json4sDefaultFormats
+                  }
+                def genPrId: String = Random.alphanumeric.take(64).mkString
+                val newPrId = prediction match {
+                  case id: WithPrId =>
+                    val org = id.prId
+                    if (org.isEmpty) genPrId else org
+                  case _ => genPrId
+                }
+
+                // also save Query's prId as prId of this pio_pr predict events
+                val queryPrId =
+                  query match {
+                    case id: WithPrId =>
+                      Map("prId" -> id.prId)
+                    case _ =>
+                      Map()
+                  }
+                val data = Map(
+                  // "appId" -> dataSourceParams.asInstanceOf[ParamsWithAppId].appId,
+                  "event" -> "predict",
+                  "eventTime" -> queryTime.toString(),
+                  "entityType" -> "pio_pr", // prediction result
+                  "entityId" -> newPrId,
+                  "properties" -> Map(
+                    "engineInstanceId" -> engineInstance.id,
+                    "query" -> query,
+                    "prediction" -> prediction)) ++ queryPrId
+                // At this point args.accessKey should be Some(String).
+                val accessKey = args.accessKey.getOrElse("")
+                val f: Future[Int] = future {
+                  scalaj.http.Http(
+                    s"http://${args.eventServerIp}:${args.eventServerPort}/" +
+                    s"events.json?accessKey=$accessKey").postData(
+                    write(data)).header(
+                    "content-type", "application/json").asString.code
+                }
+                f onComplete {
+                  case Success(code) => {
+                    if (code != 201) {
+                      // Fixed: the two concatenated fragments ran together
+                      // ("...code: 400.Data: ..."); added the missing space.
+                      log.error(s"Feedback event failed. Status code: $code. " +
+                        s"Data: ${write(data)}.")
+                    }
+                  }
+                  case Failure(t) => {
+                    log.error(s"Feedback event failed: ${t.getMessage}") }
+                }
+                // overwrite prId in predictedResult
+                // - if it is WithPrId,
+                //   then overwrite with new prId
+                // - if it is not WithPrId, no prId injection
+                if (prediction.isInstanceOf[WithPrId]) {
+                  predictionJValue merge parse(s"""{"prId" : "$newPrId"}""")
+                } else {
+                  predictionJValue
+                }
+              } else predictionJValue
+
+              val pluginResult =
+                pluginContext.outputBlockers.values.foldLeft(result) { case (r, p) =>
+                  p.process(engineInstance, queryJValue, r, pluginContext)
+                }
+
+              // Bookkeeping
+              val servingEndTime = DateTime.now
+              lastServingSec =
+                (servingEndTime.getMillis - servingStartTime.getMillis) / 1000.0
+              avgServingSec =
+                ((avgServingSec * requestCount) + lastServingSec) /
+                (requestCount + 1)
+              requestCount += 1
+
+              respondWithMediaType(`application/json`) {
+                complete(compact(render(pluginResult)))
+              }
+            } catch {
+              case e: MappingException =>
+                log.error(
+                  s"Query '$queryString' is invalid. Reason: ${e.getMessage}")
+                args.logUrl map { url =>
+                  remoteLog(
+                    url,
+                    args.logPrefix.getOrElse(""),
+                    s"Query:\n$queryString\n\nStack Trace:\n" +
+                      s"${getStackTraceString(e)}\n\n")
+                  }
+                complete(StatusCodes.BadRequest, e.getMessage)
+              case e: Throwable =>
+                val msg = s"Query:\n$queryString\n\nStack Trace:\n" +
+                  s"${getStackTraceString(e)}\n\n"
+                log.error(msg)
+                args.logUrl map { url =>
+                  remoteLog(
+                    url,
+                    args.logPrefix.getOrElse(""),
+                    msg)
+                  }
+                complete(StatusCodes.InternalServerError, msg)
+            }
+          }
+        }
+      }
+    } ~
+    path("reload") {
+      authenticate(withAccessKeyFromFile) { request =>
+        post {
+          complete {
+            context.actorSelection("/user/master") ! ReloadServer()
+            "Reloading..."
+          }
+        }
+      }
+    } ~
+    path("stop") {
+      authenticate(withAccessKeyFromFile) { request =>
+        post {
+          complete {
+            // Delay shutdown so this response can still be delivered.
+            context.system.scheduler.scheduleOnce(1.seconds) {
+              context.actorSelection("/user/master") ! StopServer()
+            }
+            "Shutting down..."
+          }
+        }
+      }
+    } ~
+    pathPrefix("assets") {
+      getFromResourceDirectory("assets")
+    } ~
+    path("plugins.json") {
+      import EngineServerJson4sSupport._
+      get {
+        respondWithMediaType(MediaTypes.`application/json`) {
+          complete {
+            Map("plugins" -> Map(
+              "outputblockers" -> pluginContext.outputBlockers.map { case (n, p) =>
+                n -> Map(
+                  "name" -> p.pluginName,
+                  "description" -> p.pluginDescription,
+                  "class" -> p.getClass.getName,
+                  "params" -> pluginContext.pluginParams(p.pluginName))
+              },
+              "outputsniffers" -> pluginContext.outputSniffers.map { case (n, p) =>
+                n -> Map(
+                  "name" -> p.pluginName,
+                  "description" -> p.pluginDescription,
+                  "class" -> p.getClass.getName,
+                  "params" -> pluginContext.pluginParams(p.pluginName))
+              }
+            ))
+          }
+        }
+      }
+    } ~
+    path("plugins" / Segments) { segments =>
+      import EngineServerJson4sSupport._
+      get {
+        respondWithMediaType(MediaTypes.`application/json`) {
+          complete {
+            val pluginArgs = segments.drop(2)
+            val pluginType = segments(0)
+            val pluginName = segments(1)
+            // NOTE(review): this match only handles outputSniffer; any other
+            // pluginType raises a MatchError (HTTP 500) — consider rejecting
+            // with 404 instead.
+            pluginType match {
+              case EngineServerPlugin.outputSniffer =>
+                pluginsActorRef ? PluginsActor.HandleREST(
+                  pluginName = pluginName,
+                  pluginArgs = pluginArgs) map {
+                  _.asInstanceOf[String]
+                }
+            }
+          }
+        }
+      }
+    }
+}
+
+object EngineServerJson4sSupport extends Json4sSupport {
+  implicit def json4sFormats: Formats = DefaultFormats
+}
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/core/src/main/scala/org/apache/predictionio/workflow/CreateWorkflow.scala
----------------------------------------------------------------------
diff --git a/core/src/main/scala/org/apache/predictionio/workflow/CreateWorkflow.scala b/core/src/main/scala/org/apache/predictionio/workflow/CreateWorkflow.scala
new file mode 100644
index 0000000..a4f3227
--- /dev/null
+++ b/core/src/main/scala/org/apache/predictionio/workflow/CreateWorkflow.scala
@@ -0,0 +1,274 @@
+/** Copyright 2015 TappingStone, Inc.
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.predictionio.workflow + +import java.net.URI + +import com.github.nscala_time.time.Imports._ +import com.google.common.io.ByteStreams +import grizzled.slf4j.Logging +import org.apache.predictionio.controller.Engine +import org.apache.predictionio.core.BaseEngine +import org.apache.predictionio.data.storage.EngineInstance +import org.apache.predictionio.data.storage.EvaluationInstance +import org.apache.predictionio.data.storage.Storage +import org.apache.predictionio.workflow.JsonExtractorOption.JsonExtractorOption +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.fs.Path +import org.json4s.JValue +import org.json4s.JString +import org.json4s.native.JsonMethods.parse + +import scala.language.existentials + +object CreateWorkflow extends Logging { + + case class WorkflowConfig( + deployMode: String = "", + batch: String = "", + engineId: String = "", + engineVersion: String = "", + engineVariant: String = "", + engineFactory: String = "", + engineParamsKey: String = "", + evaluationClass: Option[String] = None, + engineParamsGeneratorClass: Option[String] = None, + env: Option[String] = None, + skipSanityCheck: Boolean = false, + stopAfterRead: Boolean = false, + stopAfterPrepare: Boolean = false, + verbosity: Int = 0, + verbose: Boolean = false, + debug: Boolean = false, + logFile: Option[String] = None, + jsonExtractor: 
JsonExtractorOption = JsonExtractorOption.Both) + + case class AlgorithmParams(name: String, params: JValue) + + private def stringFromFile(filePath: String): String = { + try { + val uri = new URI(filePath) + val fs = FileSystem.get(uri, new Configuration()) + new String(ByteStreams.toByteArray(fs.open(new Path(uri))).map(_.toChar)) + } catch { + case e: java.io.IOException => + error(s"Error reading from file: ${e.getMessage}. Aborting workflow.") + sys.exit(1) + } + } + + val parser = new scopt.OptionParser[WorkflowConfig]("CreateWorkflow") { + override def errorOnUnknownArgument: Boolean = false + opt[String]("batch") action { (x, c) => + c.copy(batch = x) + } text("Batch label of the workflow run.") + opt[String]("engine-id") required() action { (x, c) => + c.copy(engineId = x) + } text("Engine's ID.") + opt[String]("engine-version") required() action { (x, c) => + c.copy(engineVersion = x) + } text("Engine's version.") + opt[String]("engine-variant") required() action { (x, c) => + c.copy(engineVariant = x) + } text("Engine variant JSON.") + opt[String]("evaluation-class") action { (x, c) => + c.copy(evaluationClass = Some(x)) + } text("Class name of the run's evaluator.") + opt[String]("engine-params-generator-class") action { (x, c) => + c.copy(engineParamsGeneratorClass = Some(x)) + } text("Path to evaluator parameters") + opt[String]("env") action { (x, c) => + c.copy(env = Some(x)) + } text("Comma-separated list of environmental variables (in 'FOO=BAR' " + + "format) to pass to the Spark execution environment.") + opt[Unit]("verbose") action { (x, c) => + c.copy(verbose = true) + } text("Enable verbose output.") + opt[Unit]("debug") action { (x, c) => + c.copy(debug = true) + } text("Enable debug output.") + opt[Unit]("skip-sanity-check") action { (x, c) => + c.copy(skipSanityCheck = true) + } + opt[Unit]("stop-after-read") action { (x, c) => + c.copy(stopAfterRead = true) + } + opt[Unit]("stop-after-prepare") action { (x, c) => + c.copy(stopAfterPrepare 
= true) + } + opt[String]("deploy-mode") action { (x, c) => + c.copy(deployMode = x) + } + opt[Int]("verbosity") action { (x, c) => + c.copy(verbosity = x) + } + opt[String]("engine-factory") action { (x, c) => + c.copy(engineFactory = x) + } + opt[String]("engine-params-key") action { (x, c) => + c.copy(engineParamsKey = x) + } + opt[String]("log-file") action { (x, c) => + c.copy(logFile = Some(x)) + } + opt[String]("json-extractor") action { (x, c) => + c.copy(jsonExtractor = JsonExtractorOption.withName(x)) + } + } + + def main(args: Array[String]): Unit = { + val wfcOpt = parser.parse(args, WorkflowConfig()) + if (wfcOpt.isEmpty) { + logger.error("WorkflowConfig is empty. Quitting") + return + } + + val wfc = wfcOpt.get + + WorkflowUtils.modifyLogging(wfc.verbose) + + val evaluation = wfc.evaluationClass.map { ec => + try { + WorkflowUtils.getEvaluation(ec, getClass.getClassLoader)._2 + } catch { + case e @ (_: ClassNotFoundException | _: NoSuchMethodException) => + error(s"Unable to obtain evaluation $ec. Aborting workflow.", e) + sys.exit(1) + } + } + + val engineParamsGenerator = wfc.engineParamsGeneratorClass.map { epg => + try { + WorkflowUtils.getEngineParamsGenerator(epg, getClass.getClassLoader)._2 + } catch { + case e @ (_: ClassNotFoundException | _: NoSuchMethodException) => + error(s"Unable to obtain engine parameters generator $epg. " + + "Aborting workflow.", e) + sys.exit(1) + } + } + + val pioEnvVars = wfc.env.map(e => + e.split(',').flatMap(p => + p.split('=') match { + case Array(k, v) => List(k -> v) + case _ => Nil + } + ).toMap + ).getOrElse(Map()) + + if (evaluation.isEmpty) { + val variantJson = parse(stringFromFile(wfc.engineVariant)) + val engineFactory = if (wfc.engineFactory == "") { + variantJson \ "engineFactory" match { + case JString(s) => s + case _ => + error("Unable to read engine factory class name from " + + s"${wfc.engineVariant}. 
Aborting.") + sys.exit(1) + } + } else wfc.engineFactory + val variantId = variantJson \ "id" match { + case JString(s) => s + case _ => + error("Unable to read engine variant ID from " + + s"${wfc.engineVariant}. Aborting.") + sys.exit(1) + } + val (engineLanguage, engineFactoryObj) = try { + WorkflowUtils.getEngine(engineFactory, getClass.getClassLoader) + } catch { + case e @ (_: ClassNotFoundException | _: NoSuchMethodException) => + error(s"Unable to obtain engine: ${e.getMessage}. Aborting workflow.") + sys.exit(1) + } + + val engine: BaseEngine[_, _, _, _] = engineFactoryObj() + + val customSparkConf = WorkflowUtils.extractSparkConf(variantJson) + val workflowParams = WorkflowParams( + verbose = wfc.verbosity, + skipSanityCheck = wfc.skipSanityCheck, + stopAfterRead = wfc.stopAfterRead, + stopAfterPrepare = wfc.stopAfterPrepare, + sparkEnv = WorkflowParams().sparkEnv ++ customSparkConf) + + // Evaluator Not Specified. Do training. + if (!engine.isInstanceOf[Engine[_,_,_,_,_,_]]) { + throw new NoSuchMethodException(s"Engine $engine is not trainable") + } + + val trainableEngine = engine.asInstanceOf[Engine[_, _, _, _, _, _]] + + val engineParams = if (wfc.engineParamsKey == "") { + trainableEngine.jValueToEngineParams(variantJson, wfc.jsonExtractor) + } else { + engineFactoryObj.engineParams(wfc.engineParamsKey) + } + + val engineInstance = EngineInstance( + id = "", + status = "INIT", + startTime = DateTime.now, + endTime = DateTime.now, + engineId = wfc.engineId, + engineVersion = wfc.engineVersion, + engineVariant = variantId, + engineFactory = engineFactory, + batch = wfc.batch, + env = pioEnvVars, + sparkConf = workflowParams.sparkEnv, + dataSourceParams = + JsonExtractor.paramToJson(wfc.jsonExtractor, engineParams.dataSourceParams), + preparatorParams = + JsonExtractor.paramToJson(wfc.jsonExtractor, engineParams.preparatorParams), + algorithmsParams = + JsonExtractor.paramsToJson(wfc.jsonExtractor, engineParams.algorithmParamsList), + servingParams = + 
JsonExtractor.paramToJson(wfc.jsonExtractor, engineParams.servingParams)) + + val engineInstanceId = Storage.getMetaDataEngineInstances.insert( + engineInstance) + + CoreWorkflow.runTrain( + env = pioEnvVars, + params = workflowParams, + engine = trainableEngine, + engineParams = engineParams, + engineInstance = engineInstance.copy(id = engineInstanceId)) + } else { + val workflowParams = WorkflowParams( + verbose = wfc.verbosity, + skipSanityCheck = wfc.skipSanityCheck, + stopAfterRead = wfc.stopAfterRead, + stopAfterPrepare = wfc.stopAfterPrepare, + sparkEnv = WorkflowParams().sparkEnv) + val evaluationInstance = EvaluationInstance( + evaluationClass = wfc.evaluationClass.get, + engineParamsGeneratorClass = wfc.engineParamsGeneratorClass.get, + batch = wfc.batch, + env = pioEnvVars, + sparkConf = workflowParams.sparkEnv + ) + Workflow.runEvaluation( + evaluation = evaluation.get, + engineParamsGenerator = engineParamsGenerator.get, + evaluationInstance = evaluationInstance, + params = workflowParams) + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/core/src/main/scala/org/apache/predictionio/workflow/EngineServerPlugin.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/predictionio/workflow/EngineServerPlugin.scala b/core/src/main/scala/org/apache/predictionio/workflow/EngineServerPlugin.scala new file mode 100644 index 0000000..5393e71 --- /dev/null +++ b/core/src/main/scala/org/apache/predictionio/workflow/EngineServerPlugin.scala @@ -0,0 +1,40 @@ +/** Copyright 2015 TappingStone, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.predictionio.workflow + +import org.apache.predictionio.data.storage.EngineInstance +import org.json4s._ + +trait EngineServerPlugin { + val pluginName: String + val pluginDescription: String + val pluginType: String + + def start(context: EngineServerPluginContext): Unit + + def process( + engineInstance: EngineInstance, + query: JValue, + prediction: JValue, + context: EngineServerPluginContext): JValue + + def handleREST(arguments: Seq[String]): String +} + +object EngineServerPlugin { + val outputBlocker = "outputblocker" + val outputSniffer = "outputsniffer" +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/core/src/main/scala/org/apache/predictionio/workflow/EngineServerPluginContext.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/predictionio/workflow/EngineServerPluginContext.scala b/core/src/main/scala/org/apache/predictionio/workflow/EngineServerPluginContext.scala new file mode 100644 index 0000000..3742e01 --- /dev/null +++ b/core/src/main/scala/org/apache/predictionio/workflow/EngineServerPluginContext.scala @@ -0,0 +1,88 @@ +/** Copyright 2015 TappingStone, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.predictionio.workflow + +import java.net.URI +import java.util.ServiceLoader + +import akka.event.LoggingAdapter +import com.google.common.io.ByteStreams +import grizzled.slf4j.Logging +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.fs.FileSystem +import org.apache.hadoop.fs.Path +import org.json4s.DefaultFormats +import org.json4s.Formats +import org.json4s.JObject +import org.json4s.JValue +import org.json4s.native.JsonMethods._ + +import scala.collection.JavaConversions._ +import scala.collection.mutable + +class EngineServerPluginContext( + val plugins: mutable.Map[String, mutable.Map[String, EngineServerPlugin]], + val pluginParams: mutable.Map[String, JValue], + val log: LoggingAdapter) { + def outputBlockers: Map[String, EngineServerPlugin] = + plugins.getOrElse(EngineServerPlugin.outputBlocker, Map()).toMap + def outputSniffers: Map[String, EngineServerPlugin] = + plugins.getOrElse(EngineServerPlugin.outputSniffer, Map()).toMap +} + +object EngineServerPluginContext extends Logging { + implicit val formats: Formats = DefaultFormats + + def apply(log: LoggingAdapter, engineVariant: String): EngineServerPluginContext = { + val plugins = mutable.Map[String, mutable.Map[String, EngineServerPlugin]]( + EngineServerPlugin.outputBlocker -> mutable.Map(), + EngineServerPlugin.outputSniffer -> mutable.Map()) + val pluginParams = mutable.Map[String, JValue]() + val serviceLoader = ServiceLoader.load(classOf[EngineServerPlugin]) + val variantJson = parse(stringFromFile(engineVariant)) + (variantJson \ 
"plugins").extractOpt[JObject].foreach { pluginDefs => + pluginDefs.obj.foreach { pluginParams += _ } + } + serviceLoader foreach { service => + pluginParams.get(service.pluginName) map { params => + if ((params \ "enabled").extractOrElse(false)) { + info(s"Plugin ${service.pluginName} is enabled.") + plugins(service.pluginType) += service.pluginName -> service + } else { + info(s"Plugin ${service.pluginName} is disabled.") + } + } getOrElse { + info(s"Plugin ${service.pluginName} is disabled.") + } + } + new EngineServerPluginContext( + plugins, + pluginParams, + log) + } + + private def stringFromFile(filePath: String): String = { + try { + val uri = new URI(filePath) + val fs = FileSystem.get(uri, new Configuration()) + new String(ByteStreams.toByteArray(fs.open(new Path(uri))).map(_.toChar)) + } catch { + case e: java.io.IOException => + error(s"Error reading from file: ${e.getMessage}. Aborting.") + sys.exit(1) + } + } +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/core/src/main/scala/org/apache/predictionio/workflow/EngineServerPluginsActor.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/predictionio/workflow/EngineServerPluginsActor.scala b/core/src/main/scala/org/apache/predictionio/workflow/EngineServerPluginsActor.scala new file mode 100644 index 0000000..0068751 --- /dev/null +++ b/core/src/main/scala/org/apache/predictionio/workflow/EngineServerPluginsActor.scala @@ -0,0 +1,46 @@ +/** Copyright 2015 TappingStone, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.predictionio.workflow + +import akka.actor.Actor +import akka.event.Logging +import org.apache.predictionio.data.storage.EngineInstance +import org.json4s.JValue + +class PluginsActor(engineVariant: String) extends Actor { + implicit val system = context.system + val log = Logging(system, this) + + val pluginContext = EngineServerPluginContext(log, engineVariant) + + def receive: PartialFunction[Any, Unit] = { + case (ei: EngineInstance, q: JValue, p: JValue) => + pluginContext.outputSniffers.values.foreach(_.process(ei, q, p, pluginContext)) + case h: PluginsActor.HandleREST => + try { + sender() ! pluginContext.outputSniffers(h.pluginName).handleREST(h.pluginArgs) + } catch { + case e: Exception => + sender() ! 
s"""{"message":"${e.getMessage}"}""" + } + case _ => + log.error("Unknown message sent to the Engine Server output sniffer plugin host.") + } +} + +object PluginsActor { + case class HandleREST(pluginName: String, pluginArgs: Seq[String]) +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/core/src/main/scala/org/apache/predictionio/workflow/EvaluationWorkflow.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/predictionio/workflow/EvaluationWorkflow.scala b/core/src/main/scala/org/apache/predictionio/workflow/EvaluationWorkflow.scala new file mode 100644 index 0000000..6c7e731 --- /dev/null +++ b/core/src/main/scala/org/apache/predictionio/workflow/EvaluationWorkflow.scala @@ -0,0 +1,42 @@ +/** Copyright 2015 TappingStone, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.predictionio.workflow + +import org.apache.predictionio.controller.EngineParams +import org.apache.predictionio.controller.Evaluation +import org.apache.predictionio.core.BaseEvaluator +import org.apache.predictionio.core.BaseEvaluatorResult +import org.apache.predictionio.core.BaseEngine + +import grizzled.slf4j.Logger +import org.apache.spark.SparkContext + +import scala.language.existentials + +object EvaluationWorkflow { + @transient lazy val logger = Logger[this.type] + def runEvaluation[EI, Q, P, A, R <: BaseEvaluatorResult]( + sc: SparkContext, + evaluation: Evaluation, + engine: BaseEngine[EI, Q, P, A], + engineParamsList: Seq[EngineParams], + evaluator: BaseEvaluator[EI, Q, P, A, R], + params: WorkflowParams) + : R = { + val engineEvalDataSet = engine.batchEval(sc, engineParamsList, params) + evaluator.evaluateBase(sc, evaluation, engineEvalDataSet, params) + } +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/core/src/main/scala/org/apache/predictionio/workflow/FakeWorkflow.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/predictionio/workflow/FakeWorkflow.scala b/core/src/main/scala/org/apache/predictionio/workflow/FakeWorkflow.scala new file mode 100644 index 0000000..f11ea2e --- /dev/null +++ b/core/src/main/scala/org/apache/predictionio/workflow/FakeWorkflow.scala @@ -0,0 +1,106 @@ +/** Copyright 2015 TappingStone, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.predictionio.workflow + +import org.apache.predictionio.annotation.Experimental +// FIXME(yipjustin): Remove wildcard import. +import org.apache.predictionio.core._ +import org.apache.predictionio.controller._ + +import grizzled.slf4j.Logger +import org.apache.spark.SparkContext +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd.RDD + + +@Experimental +private[prediction] class FakeEngine +extends BaseEngine[EmptyParams, EmptyParams, EmptyParams, EmptyParams] { + @transient lazy val logger = Logger[this.type] + + def train( + sc: SparkContext, + engineParams: EngineParams, + engineInstanceId: String, + params: WorkflowParams): Seq[Any] = { + throw new StopAfterReadInterruption() + } + + def eval( + sc: SparkContext, + engineParams: EngineParams, + params: WorkflowParams) + : Seq[(EmptyParams, RDD[(EmptyParams, EmptyParams, EmptyParams)])] = { + return Seq[(EmptyParams, RDD[(EmptyParams, EmptyParams, EmptyParams)])]() + } +} + +@Experimental +private[prediction] class FakeRunner(f: (SparkContext => Unit)) + extends BaseEvaluator[EmptyParams, EmptyParams, EmptyParams, EmptyParams, + FakeEvalResult] { + @transient private lazy val logger = Logger[this.type] + def evaluateBase( + sc: SparkContext, + evaluation: Evaluation, + engineEvalDataSet: + Seq[(EngineParams, Seq[(EmptyParams, RDD[(EmptyParams, EmptyParams, EmptyParams)])])], + params: WorkflowParams): FakeEvalResult = { + f(sc) + FakeEvalResult() + } +} + +@Experimental +private[prediction] case class FakeEvalResult() extends BaseEvaluatorResult { + override val noSave: Boolean = true +} + +/** FakeRun allows user to implement custom function under the exact enviroment + * as other PredictionIO workflow. + * + * Useful for developing new features. Only need to extend this trait and + * implement a function: (SparkContext => Unit). 
For example, the code below + * can be run with `pio eval HelloWorld`. + * + * {{{ + * object HelloWorld extends FakeRun { + * // func defines the function pio runs, must have signature (SparkContext => Unit). + * func = f + * + * def f(sc: SparkContext): Unit { + * val logger = Logger[this.type] + * logger.info("HelloWorld") + * } + * } + * }}} + * + */ +@Experimental +trait FakeRun extends Evaluation with EngineParamsGenerator { + private[this] var _runner: FakeRunner = _ + + def runner: FakeRunner = _runner + def runner_=(r: FakeRunner) { + engineEvaluator = (new FakeEngine(), r) + engineParamsList = Seq(new EngineParams()) + } + + def func: (SparkContext => Unit) = { (sc: SparkContext) => Unit } + def func_=(f: SparkContext => Unit) { + runner = new FakeRunner(f) + } +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/core/src/main/scala/org/apache/predictionio/workflow/JsonExtractor.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/predictionio/workflow/JsonExtractor.scala b/core/src/main/scala/org/apache/predictionio/workflow/JsonExtractor.scala new file mode 100644 index 0000000..b9737a6 --- /dev/null +++ b/core/src/main/scala/org/apache/predictionio/workflow/JsonExtractor.scala @@ -0,0 +1,164 @@ +/** Copyright 2015 TappingStone, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.predictionio.workflow + +import com.google.gson.Gson +import com.google.gson.GsonBuilder +import com.google.gson.TypeAdapterFactory +import org.apache.predictionio.controller.EngineParams +import org.apache.predictionio.controller.Params +import org.apache.predictionio.controller.Utils +import org.apache.predictionio.workflow.JsonExtractorOption.JsonExtractorOption +import org.json4s.Extraction +import org.json4s.Formats +import org.json4s.JsonAST.{JArray, JValue} +import org.json4s.native.JsonMethods.compact +import org.json4s.native.JsonMethods.pretty +import org.json4s.native.JsonMethods.parse +import org.json4s.native.JsonMethods.render +import org.json4s.reflect.TypeInfo + +object JsonExtractor { + + def toJValue( + extractorOption: JsonExtractorOption, + o: Any, + json4sFormats: Formats = Utils.json4sDefaultFormats, + gsonTypeAdapterFactories: Seq[TypeAdapterFactory] = Seq.empty[TypeAdapterFactory]): JValue = { + + extractorOption match { + case JsonExtractorOption.Both => + + val json4sResult = Extraction.decompose(o)(json4sFormats) + json4sResult.children.size match { + case 0 => parse(gson(gsonTypeAdapterFactories).toJson(o)) + case _ => json4sResult + } + case JsonExtractorOption.Json4sNative => + Extraction.decompose(o)(json4sFormats) + case JsonExtractorOption.Gson => + parse(gson(gsonTypeAdapterFactories).toJson(o)) + } + } + + def extract[T]( + extractorOption: JsonExtractorOption, + json: String, + clazz: Class[T], + json4sFormats: Formats = Utils.json4sDefaultFormats, + gsonTypeAdapterFactories: Seq[TypeAdapterFactory] = Seq.empty[TypeAdapterFactory]): T = { + + extractorOption match { + case JsonExtractorOption.Both => + try { + extractWithJson4sNative(json, json4sFormats, clazz) + } catch { + case e: Exception => + extractWithGson(json, clazz, gsonTypeAdapterFactories) + } + case JsonExtractorOption.Json4sNative => + extractWithJson4sNative(json, json4sFormats, clazz) + case JsonExtractorOption.Gson => + 
extractWithGson(json, clazz, gsonTypeAdapterFactories) + } + } + + def paramToJson(extractorOption: JsonExtractorOption, param: (String, Params)): String = { + // to be replaced JValue needs to be done by Json4s, otherwise the tuple JValue will be wrong + val toBeReplacedJValue = + JsonExtractor.toJValue(JsonExtractorOption.Json4sNative, (param._1, null)) + val paramJValue = JsonExtractor.toJValue(extractorOption, param._2) + + compact(render(toBeReplacedJValue.replace(param._1 :: Nil, paramJValue))) + } + + def paramsToJson(extractorOption: JsonExtractorOption, params: Seq[(String, Params)]): String = { + compact(render(paramsToJValue(extractorOption, params))) + } + + def engineParamsToJson(extractorOption: JsonExtractorOption, params: EngineParams) : String = { + compact(render(engineParamsToJValue(extractorOption, params))) + } + + def engineParamstoPrettyJson( + extractorOption: JsonExtractorOption, + params: EngineParams) : String = { + + pretty(render(engineParamsToJValue(extractorOption, params))) + } + + private def engineParamsToJValue(extractorOption: JsonExtractorOption, params: EngineParams) = { + var jValue = toJValue(JsonExtractorOption.Json4sNative, params) + + val dataSourceParamsJValue = toJValue(extractorOption, params.dataSourceParams._2) + jValue = jValue.replace( + "dataSourceParams" :: params.dataSourceParams._1 :: Nil, + dataSourceParamsJValue) + + val preparatorParamsJValue = toJValue(extractorOption, params.preparatorParams._2) + jValue = jValue.replace( + "preparatorParams" :: params.preparatorParams._1 :: Nil, + preparatorParamsJValue) + + val algorithmParamsJValue = paramsToJValue(extractorOption, params.algorithmParamsList) + jValue = jValue.replace("algorithmParamsList" :: Nil, algorithmParamsJValue) + + val servingParamsJValue = toJValue(extractorOption, params.servingParams._2) + jValue = jValue.replace("servingParams" :: params.servingParams._1 :: Nil, servingParamsJValue) + + jValue + } + + private + def 
paramsToJValue(extractorOption: JsonExtractorOption, params: Seq[(String, Params)]) = { + val jValues = params.map { case (name, param) => + // to be replaced JValue needs to be done by Json4s, otherwise the tuple JValue will be wrong + val toBeReplacedJValue = + JsonExtractor.toJValue(JsonExtractorOption.Json4sNative, (name, null)) + val paramJValue = JsonExtractor.toJValue(extractorOption, param) + + toBeReplacedJValue.replace(name :: Nil, paramJValue) + } + + JArray(jValues.toList) + } + + private def extractWithJson4sNative[T]( + json: String, + formats: Formats, + clazz: Class[T]): T = { + + Extraction.extract(parse(json), TypeInfo(clazz, None))(formats).asInstanceOf[T] + } + + private def extractWithGson[T]( + json: String, + clazz: Class[T], + gsonTypeAdapterFactories: Seq[TypeAdapterFactory]): T = { + + gson(gsonTypeAdapterFactories).fromJson(json, clazz) + } + + private def gson(gsonTypeAdapterFactories: Seq[TypeAdapterFactory]): Gson = { + val gsonBuilder = new GsonBuilder() + gsonTypeAdapterFactories.foreach { typeAdapterFactory => + gsonBuilder.registerTypeAdapterFactory(typeAdapterFactory) + } + + gsonBuilder.create() + } + +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/core/src/main/scala/org/apache/predictionio/workflow/JsonExtractorOption.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/predictionio/workflow/JsonExtractorOption.scala b/core/src/main/scala/org/apache/predictionio/workflow/JsonExtractorOption.scala new file mode 100644 index 0000000..a7915a6 --- /dev/null +++ b/core/src/main/scala/org/apache/predictionio/workflow/JsonExtractorOption.scala @@ -0,0 +1,23 @@ +/** Copyright 2015 TappingStone, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.predictionio.workflow + +object JsonExtractorOption extends Enumeration { + type JsonExtractorOption = Value + val Json4sNative = Value + val Gson = Value + val Both = Value +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/core/src/main/scala/org/apache/predictionio/workflow/PersistentModelManifest.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/predictionio/workflow/PersistentModelManifest.scala b/core/src/main/scala/org/apache/predictionio/workflow/PersistentModelManifest.scala new file mode 100644 index 0000000..7cf7ede --- /dev/null +++ b/core/src/main/scala/org/apache/predictionio/workflow/PersistentModelManifest.scala @@ -0,0 +1,18 @@ +/** Copyright 2015 TappingStone, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.predictionio.workflow + +case class PersistentModelManifest(className: String) http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/core/src/main/scala/org/apache/predictionio/workflow/Workflow.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/predictionio/workflow/Workflow.scala b/core/src/main/scala/org/apache/predictionio/workflow/Workflow.scala new file mode 100644 index 0000000..d88c8d0 --- /dev/null +++ b/core/src/main/scala/org/apache/predictionio/workflow/Workflow.scala @@ -0,0 +1,135 @@ +/** Copyright 2015 TappingStone, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.predictionio.workflow + +import org.apache.predictionio.annotation.Experimental +import org.apache.predictionio.controller.EngineParams +import org.apache.predictionio.controller.EngineParamsGenerator +import org.apache.predictionio.controller.Evaluation +import org.apache.predictionio.core.BaseEngine +import org.apache.predictionio.core.BaseEvaluator +import org.apache.predictionio.core.BaseEvaluatorResult +import org.apache.predictionio.data.storage.EvaluationInstance + +/** Collection of workflow creation methods. + * @group Workflow + */ +object Workflow { + // evaluator is already instantiated. + // This is an undocumented way of using evaluator. Still experimental. 
+ // evaluatorParams is used to write into EngineInstance, will be shown in + // dashboard. + /* + def runEval[EI, Q, P, A, ER <: AnyRef]( + engine: BaseEngine[EI, Q, P, A], + engineParams: EngineParams, + evaluator: BaseEvaluator[EI, Q, P, A, ER], + evaluatorParams: Params, + env: Map[String, String] = WorkflowUtils.pioEnvVars, + params: WorkflowParams = WorkflowParams()) { + + implicit lazy val formats = Utils.json4sDefaultFormats + + new NameParamsSerializer + + val engineInstance = EngineInstance( + id = "", + status = "INIT", + startTime = DateTime.now, + endTime = DateTime.now, + engineId = "", + engineVersion = "", + engineVariant = "", + engineFactory = "FIXME", + evaluatorClass = evaluator.getClass.getName(), + batch = params.batch, + env = env, + sparkConf = params.sparkEnv, + dataSourceParams = write(engineParams.dataSourceParams), + preparatorParams = write(engineParams.preparatorParams), + algorithmsParams = write(engineParams.algorithmParamsList), + servingParams = write(engineParams.servingParams), + evaluatorParams = write(evaluatorParams), + evaluatorResults = "", + evaluatorResultsHTML = "", + evaluatorResultsJSON = "") + + CoreWorkflow.runEval( + engine = engine, + engineParams = engineParams, + engineInstance = engineInstance, + evaluator = evaluator, + evaluatorParams = evaluatorParams, + env = env, + params = params) + } + */ + + def runEvaluation( + evaluation: Evaluation, + engineParamsGenerator: EngineParamsGenerator, + env: Map[String, String] = WorkflowUtils.pioEnvVars, + evaluationInstance: EvaluationInstance = EvaluationInstance(), + params: WorkflowParams = WorkflowParams()) { + runEvaluationTypeless( + evaluation = evaluation, + engine = evaluation.engine, + engineParamsList = engineParamsGenerator.engineParamsList, + evaluationInstance = evaluationInstance, + evaluator = evaluation.evaluator, + env = env, + params = params + ) + } + + def runEvaluationTypeless[ + EI, Q, P, A, EEI, EQ, EP, EA, ER <: BaseEvaluatorResult]( + evaluation: 
Evaluation, + engine: BaseEngine[EI, Q, P, A], + engineParamsList: Seq[EngineParams], + evaluationInstance: EvaluationInstance, + evaluator: BaseEvaluator[EEI, EQ, EP, EA, ER], + env: Map[String, String] = WorkflowUtils.pioEnvVars, + params: WorkflowParams = WorkflowParams()) { + runEvaluationViaCoreWorkflow( + evaluation = evaluation, + engine = engine, + engineParamsList = engineParamsList, + evaluationInstance = evaluationInstance, + evaluator = evaluator.asInstanceOf[BaseEvaluator[EI, Q, P, A, ER]], + env = env, + params = params) + } + + /** :: Experimental :: */ + @Experimental + def runEvaluationViaCoreWorkflow[EI, Q, P, A, R <: BaseEvaluatorResult]( + evaluation: Evaluation, + engine: BaseEngine[EI, Q, P, A], + engineParamsList: Seq[EngineParams], + evaluationInstance: EvaluationInstance, + evaluator: BaseEvaluator[EI, Q, P, A, R], + env: Map[String, String] = WorkflowUtils.pioEnvVars, + params: WorkflowParams = WorkflowParams()) { + CoreWorkflow.runEvaluation( + evaluation = evaluation, + engine = engine, + engineParamsList = engineParamsList, + evaluationInstance = evaluationInstance, + evaluator = evaluator, + env = env, + params = params) + } +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/core/src/main/scala/org/apache/predictionio/workflow/WorkflowContext.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/predictionio/workflow/WorkflowContext.scala b/core/src/main/scala/org/apache/predictionio/workflow/WorkflowContext.scala new file mode 100644 index 0000000..2abb79a --- /dev/null +++ b/core/src/main/scala/org/apache/predictionio/workflow/WorkflowContext.scala @@ -0,0 +1,45 @@ +/** Copyright 2015 TappingStone, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.predictionio.workflow + +import grizzled.slf4j.Logging +import org.apache.spark.SparkContext +import org.apache.spark.SparkConf + +import scala.language.existentials + +// FIXME: move to better location. +object WorkflowContext extends Logging { + def apply( + batch: String = "", + executorEnv: Map[String, String] = Map(), + sparkEnv: Map[String, String] = Map(), + mode: String = "" + ): SparkContext = { + val conf = new SparkConf() + val prefix = if (mode == "") "PredictionIO" else s"PredictionIO ${mode}" + conf.setAppName(s"${prefix}: ${batch}") + debug(s"Executor environment received: ${executorEnv}") + executorEnv.map(kv => conf.setExecutorEnv(kv._1, kv._2)) + debug(s"SparkConf executor environment: ${conf.getExecutorEnv}") + debug(s"Application environment received: ${sparkEnv}") + conf.setAll(sparkEnv) + val sparkConfString = conf.getAll.toSeq + debug(s"SparkConf environment: $sparkConfString") + new SparkContext(conf) + } +} + http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/4f03388e/core/src/main/scala/org/apache/predictionio/workflow/WorkflowParams.scala ---------------------------------------------------------------------- diff --git a/core/src/main/scala/org/apache/predictionio/workflow/WorkflowParams.scala b/core/src/main/scala/org/apache/predictionio/workflow/WorkflowParams.scala new file mode 100644 index 0000000..8727a50 --- /dev/null +++ b/core/src/main/scala/org/apache/predictionio/workflow/WorkflowParams.scala @@ -0,0 +1,42 @@ +/** Copyright 2015 TappingStone, Inc. 
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.predictionio.workflow + +/** Workflow parameters. + * + * @param batch Batch label of the run. + * @param verbose Verbosity level. + * @param saveModel Controls whether trained models are persisted. + * @param sparkEnv Spark properties that will be set in SparkConf.setAll(). + * @param skipSanityCheck Skips all data sanity check. + * @param stopAfterRead Stops workflow after reading from data source. + * @param stopAfterPrepare Stops workflow after data preparation. + * @group Workflow + */ +case class WorkflowParams( + batch: String = "", + verbose: Int = 2, + saveModel: Boolean = true, + sparkEnv: Map[String, String] = + Map[String, String]("spark.executor.extraClassPath" -> "."), + skipSanityCheck: Boolean = false, + stopAfterRead: Boolean = false, + stopAfterPrepare: Boolean = false) { + // Temporary workaround for WorkflowParamsBuilder for Java. It doesn't support + // custom spark environment yet. + def this(batch: String, verbose: Int, saveModel: Boolean) + = this(batch, verbose, saveModel, Map[String, String]()) +}
