Merge branch 'ziemin/disable_upgrade' into develop
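
Summary of this merge (drawn from the diff below): it disables the automatic
upgrade check. Removed are the checkUpgrade() calls in CoreWorkflow.runTrain
and CoreWorkflow.runEvaluation, the UpgradeActor and its daily UpgradeCheck
schedule in CreateServer, and the checkUpgrade/UpgradeCheckRunner
implementation in WorkflowUtils, which fetched update metadata from
https://direct.prediction.io/. For reference, the removed call sites were:

    WorkflowUtils.checkUpgrade(mode, engineInstance.engineFactory)  // CoreWorkflow.runTrain
    WorkflowUtils.checkUpgrade(mode, engine.getClass.getName)       // CoreWorkflow.runEvaluation
    WorkflowUtils.checkUpgrade("deployment", engineClass)           // CreateServer (UpgradeActor)

tools/console/Console.scala changes as part of the same cleanup (its diff is
not included in this excerpt; see the file stats below).
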
Project: http://git-wip-us.apache.org/repos/asf/incubator-predictionio/repo
Commit: http://git-wip-us.apache.org/repos/asf/incubator-predictionio/commit/c508c791
Tree: http://git-wip-us.apache.org/repos/asf/incubator-predictionio/tree/c508c791
Diff: http://git-wip-us.apache.org/repos/asf/incubator-predictionio/diff/c508c791

Branch: refs/heads/develop
Commit: c508c791d85061bc053fbffc704ac953be814234
Parents: 02a5655 3b87275
Author: Donald Szeto <[email protected]>
Authored: Mon Jul 18 13:28:06 2016 -0700
Committer: Donald Szeto <[email protected]>
Committed: Mon Jul 18 13:28:06 2016 -0700

----------------------------------------------------------------------
 .../predictionio/workflow/CoreWorkflow.scala    |  3 --
 .../predictionio/workflow/CreateServer.scala    | 18 -----------
 .../predictionio/workflow/WorkflowUtils.scala   | 32 --------------------
 .../predictionio/tools/console/Console.scala    |  9 +-----
 4 files changed, 1 insertion(+), 61 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/c508c791/core/src/main/scala/org/apache/predictionio/workflow/CoreWorkflow.scala
----------------------------------------------------------------------
diff --cc core/src/main/scala/org/apache/predictionio/workflow/CoreWorkflow.scala
index 6a27e87,0000000..513b22b
mode 100644,000000..100644
--- a/core/src/main/scala/org/apache/predictionio/workflow/CoreWorkflow.scala
+++ b/core/src/main/scala/org/apache/predictionio/workflow/CoreWorkflow.scala
@@@ -1,163 -1,0 +1,160 @@@
 +/** Copyright 2015 TappingStone, Inc.
 + *
 + * Licensed under the Apache License, Version 2.0 (the "License");
 + * you may not use this file except in compliance with the License.
 + * You may obtain a copy of the License at
 + *
 + *     http://www.apache.org/licenses/LICENSE-2.0
 + *
 + * Unless required by applicable law or agreed to in writing, software
 + * distributed under the License is distributed on an "AS IS" BASIS,
 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 + * See the License for the specific language governing permissions and
 + * limitations under the License.
 + */
 +
 +package org.apache.predictionio.workflow
 +
 +import org.apache.predictionio.controller.EngineParams
 +import org.apache.predictionio.controller.Evaluation
 +import org.apache.predictionio.core.BaseEngine
 +import org.apache.predictionio.core.BaseEvaluator
 +import org.apache.predictionio.core.BaseEvaluatorResult
 +import org.apache.predictionio.data.storage.EngineInstance
 +import org.apache.predictionio.data.storage.EvaluationInstance
 +import org.apache.predictionio.data.storage.Model
 +import org.apache.predictionio.data.storage.Storage
 +
 +import com.github.nscala_time.time.Imports.DateTime
 +import grizzled.slf4j.Logger
 +
 +import scala.language.existentials
 +
 +/** CoreWorkflow handles PredictionIO metadata and environment variables of
 +  * training and evaluation.
 +  */
 +object CoreWorkflow {
 +  @transient lazy val logger = Logger[this.type]
 +  @transient lazy val engineInstances = Storage.getMetaDataEngineInstances
 +  @transient lazy val evaluationInstances =
 +    Storage.getMetaDataEvaluationInstances()
 +
 +  def runTrain[EI, Q, P, A](
 +      engine: BaseEngine[EI, Q, P, A],
 +      engineParams: EngineParams,
 +      engineInstance: EngineInstance,
 +      env: Map[String, String] = WorkflowUtils.pioEnvVars,
 +      params: WorkflowParams = WorkflowParams()) {
 +    logger.debug("Starting SparkContext")
 +    val mode = "training"
-     WorkflowUtils.checkUpgrade(mode, engineInstance.engineFactory)
 +
 +    val batch = if (params.batch.nonEmpty) {
 +      s"${engineInstance.engineFactory} (${params.batch})"
 +    } else {
 +      engineInstance.engineFactory
 +    }
 +    val sc = WorkflowContext(
 +      batch,
 +      env,
 +      params.sparkEnv,
 +      mode.capitalize)
 +
 +    try {
 +
 +      val models: Seq[Any] = engine.train(
 +        sc = sc,
 +        engineParams = engineParams,
 +        engineInstanceId = engineInstance.id,
 +        params = params
 +      )
 +
 +      val instanceId = Storage.getMetaDataEngineInstances
 +
 +      val kryo = KryoInstantiator.newKryoInjection
 +
 +      logger.info("Inserting persistent model")
 +      Storage.getModelDataModels.insert(Model(
 +        id = engineInstance.id,
 +        models = kryo(models)))
 +
 +      logger.info("Updating engine instance")
 +      val engineInstances = Storage.getMetaDataEngineInstances
 +      engineInstances.update(engineInstance.copy(
 +        status = "COMPLETED",
 +        endTime = DateTime.now
 +      ))
 +
 +      logger.info("Training completed successfully.")
 +    } catch {
 +      case e @(
 +          _: StopAfterReadInterruption |
 +          _: StopAfterPrepareInterruption) => {
 +        logger.info(s"Training interrupted by $e.")
 +      }
 +    } finally {
 +      logger.debug("Stopping SparkContext")
 +      sc.stop()
 +    }
 +  }
 +
 +  def runEvaluation[EI, Q, P, A, R <: BaseEvaluatorResult](
 +      evaluation: Evaluation,
 +      engine: BaseEngine[EI, Q, P, A],
 +      engineParamsList: Seq[EngineParams],
 +      evaluationInstance: EvaluationInstance,
 +      evaluator: BaseEvaluator[EI, Q, P, A, R],
 +      env: Map[String, String] = WorkflowUtils.pioEnvVars,
 +      params: WorkflowParams = WorkflowParams()) {
 +    logger.info("runEvaluation started")
 +    logger.debug("Start SparkContext")
 +
 +    val mode = "evaluation"
 +
-     WorkflowUtils.checkUpgrade(mode, engine.getClass.getName)
- 
 +    val batch = if (params.batch.nonEmpty) {
 +      s"${evaluation.getClass.getName} (${params.batch})"
 +    } else {
 +      evaluation.getClass.getName
 +    }
 +    val sc = WorkflowContext(
 +      batch,
 +      env,
 +      params.sparkEnv,
 +      mode.capitalize)
 +    val evaluationInstanceId = evaluationInstances.insert(evaluationInstance)
 +
 +    logger.info(s"Starting evaluation instance ID: $evaluationInstanceId")
 +
 +    val evaluatorResult: BaseEvaluatorResult = EvaluationWorkflow.runEvaluation(
 +      sc,
 +      evaluation,
 +      engine,
 +      engineParamsList,
 +      evaluator,
 +      params)
 +
 +    if (evaluatorResult.noSave) {
 +      logger.info(s"This evaluation result is not inserted into database: $evaluatorResult")
 +    } else {
 +      val evaluatedEvaluationInstance = evaluationInstance.copy(
 +        status = "EVALCOMPLETED",
 +        id = evaluationInstanceId,
 +        endTime = DateTime.now,
 +        evaluatorResults = evaluatorResult.toOneLiner,
 +        evaluatorResultsHTML = evaluatorResult.toHTML,
 +        evaluatorResultsJSON = evaluatorResult.toJSON
 +      )
 +
 +      logger.info(s"Updating evaluation instance with result: $evaluatorResult")
 +
 +      evaluationInstances.update(evaluatedEvaluationInstance)
 +    }
 +
 +    logger.debug("Stop SparkContext")
 +
 +    sc.stop()
 +
 +    logger.info("runEvaluation completed")
 +  }
 +}
 +
http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/c508c791/core/src/main/scala/org/apache/predictionio/workflow/CreateServer.scala
----------------------------------------------------------------------
diff --cc core/src/main/scala/org/apache/predictionio/workflow/CreateServer.scala
index d4f6323,0000000..9e12b35
mode 100644,000000..100644
--- a/core/src/main/scala/org/apache/predictionio/workflow/CreateServer.scala
+++ b/core/src/main/scala/org/apache/predictionio/workflow/CreateServer.scala
@@@ -1,737 -1,0 +1,719 @@@
 +/** Copyright 2015 TappingStone, Inc.
 + *
 + * Licensed under the Apache License, Version 2.0 (the "License");
 + * you may not use this file except in compliance with the License.
 + * You may obtain a copy of the License at
 + *
 + *     http://www.apache.org/licenses/LICENSE-2.0
 + *
 + * Unless required by applicable law or agreed to in writing, software
 + * distributed under the License is distributed on an "AS IS" BASIS,
 + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 + * See the License for the specific language governing permissions and
 + * limitations under the License.
 + */
 +
 +package org.apache.predictionio.workflow
 +
 +import java.io.PrintWriter
 +import java.io.Serializable
 +import java.io.StringWriter
 +import java.util.concurrent.TimeUnit
 +
 +import akka.actor._
 +import akka.event.Logging
 +import akka.io.IO
 +import akka.pattern.ask
 +import akka.util.Timeout
 +import com.github.nscala_time.time.Imports.DateTime
 +import com.twitter.bijection.Injection
 +import com.twitter.chill.KryoBase
 +import com.twitter.chill.KryoInjection
 +import com.twitter.chill.ScalaKryoInstantiator
 +import com.typesafe.config.ConfigFactory
 +import de.javakaffee.kryoserializers.SynchronizedCollectionsSerializer
 +import grizzled.slf4j.Logging
 +import org.apache.predictionio.authentication.KeyAuthentication
 +import org.apache.predictionio.configuration.SSLConfiguration
 +import org.apache.predictionio.controller.Engine
 +import org.apache.predictionio.controller.Params
 +import org.apache.predictionio.controller.Utils
 +import org.apache.predictionio.controller.WithPrId
 +import org.apache.predictionio.core.BaseAlgorithm
 +import org.apache.predictionio.core.BaseServing
 +import org.apache.predictionio.core.Doer
 +import org.apache.predictionio.data.storage.EngineInstance
 +import org.apache.predictionio.data.storage.EngineManifest
 +import org.apache.predictionio.data.storage.Storage
 +import org.apache.predictionio.workflow.JsonExtractorOption.JsonExtractorOption
 +import org.json4s._
 +import org.json4s.native.JsonMethods._
 +import org.json4s.native.Serialization.write
 +import spray.can.Http
 +import spray.can.server.ServerSettings
 +import spray.http.MediaTypes._
 +import spray.http._
 +import spray.httpx.Json4sSupport
 +import spray.routing._
 +import spray.routing.authentication.{UserPass, BasicAuth}
 +
 +import scala.concurrent.ExecutionContext.Implicits.global
 +import scala.concurrent.Future
 +import scala.concurrent.duration._
 +import scala.concurrent.future
 +import scala.language.existentials
 +import scala.util.Failure
 +import scala.util.Random
 +import scala.util.Success
 +import scalaj.http.HttpOptions
 +
 +class KryoInstantiator(classLoader: ClassLoader) extends ScalaKryoInstantiator {
 +  override def newKryo(): KryoBase = {
 +    val kryo = super.newKryo()
 +    kryo.setClassLoader(classLoader)
 +    SynchronizedCollectionsSerializer.registerSerializers(kryo)
 +    kryo
 +  }
 +}
 +
 +object KryoInstantiator extends Serializable {
 +  def newKryoInjection : Injection[Any, Array[Byte]] = {
 +    val kryoInstantiator = new KryoInstantiator(getClass.getClassLoader)
 +    KryoInjection.instance(kryoInstantiator)
 +  }
 +}
 +
 +case class ServerConfig(
 +  batch: String = "",
 +  engineInstanceId: String = "",
 +  engineId: Option[String] = None,
 +  engineVersion: Option[String] = None,
 +  engineVariant: String = "",
 +  env: Option[String] = None,
 +  ip: String = "0.0.0.0",
 +  port: Int = 8000,
 +  feedback: Boolean = false,
 +  eventServerIp: String = "0.0.0.0",
 +  eventServerPort: Int = 7070,
 +  accessKey: Option[String] = None,
 +  logUrl: Option[String] = None,
 +  logPrefix: Option[String] = None,
 +  logFile: Option[String] = None,
 +  verbose: Boolean = false,
 +  debug: Boolean = false,
 +  jsonExtractor: JsonExtractorOption = JsonExtractorOption.Both)
 +
 +case class StartServer()
 +case class BindServer()
 +case class StopServer()
 +case class ReloadServer()
- case class UpgradeCheck()
 +
 +
 +object CreateServer extends Logging {
 +  val actorSystem = ActorSystem("pio-server")
 +  val engineInstances = Storage.getMetaDataEngineInstances
 +  val engineManifests = Storage.getMetaDataEngineManifests
 +  val modeldata = Storage.getModelDataModels
 +
 +  def main(args: Array[String]): Unit = {
 +    val parser = new scopt.OptionParser[ServerConfig]("CreateServer") {
 +      opt[String]("batch") action { (x, c) =>
 +        c.copy(batch = x)
 +      } text("Batch label of the deployment.")
 +      opt[String]("engineId") action { (x, c) =>
 +        c.copy(engineId = Some(x))
 +      } text("Engine ID.")
 +      opt[String]("engineVersion") action { (x, c) =>
 +        c.copy(engineVersion = Some(x))
 +      } text("Engine version.")
 +      opt[String]("engine-variant") required() action { (x, c) =>
 +        c.copy(engineVariant = x)
 +      } text("Engine variant JSON.")
 +      opt[String]("ip") action { (x, c) =>
 +        c.copy(ip = x)
 +      }
 +      opt[String]("env") action { (x, c) =>
 +        c.copy(env = Some(x))
 +      } text("Comma-separated list of environmental variables (in 'FOO=BAR' " +
 +        "format) to pass to the Spark execution environment.")
 +      opt[Int]("port") action { (x, c) =>
 +        c.copy(port = x)
 +      } text("Port to bind to (default: 8000).")
 +      opt[String]("engineInstanceId") required() action { (x, c) =>
 +        c.copy(engineInstanceId = x)
 +      } text("Engine instance ID.")
 +      opt[Unit]("feedback") action { (_, c) =>
 +        c.copy(feedback = true)
 +      } text("Enable feedback loop to event server.")
 +      opt[String]("event-server-ip") action { (x, c) =>
 +        c.copy(eventServerIp = x)
 +      }
 +      opt[Int]("event-server-port") action { (x, c) =>
 +        c.copy(eventServerPort = x)
 +      } text("Event server port. Default: 7070")
Default: 7070") + opt[String]("accesskey") action { (x, c) => + c.copy(accessKey = Some(x)) + } text("Event server access key.") + opt[String]("log-url") action { (x, c) => + c.copy(logUrl = Some(x)) + } + opt[String]("log-prefix") action { (x, c) => + c.copy(logPrefix = Some(x)) + } + opt[String]("log-file") action { (x, c) => + c.copy(logFile = Some(x)) + } + opt[Unit]("verbose") action { (x, c) => + c.copy(verbose = true) + } text("Enable verbose output.") + opt[Unit]("debug") action { (x, c) => + c.copy(debug = true) + } text("Enable debug output.") + opt[String]("json-extractor") action { (x, c) => + c.copy(jsonExtractor = JsonExtractorOption.withName(x)) + } + } + + parser.parse(args, ServerConfig()) map { sc => + WorkflowUtils.modifyLogging(sc.verbose) + engineInstances.get(sc.engineInstanceId) map { engineInstance => + val engineId = sc.engineId.getOrElse(engineInstance.engineId) + val engineVersion = sc.engineVersion.getOrElse( + engineInstance.engineVersion) + engineManifests.get(engineId, engineVersion) map { manifest => + val engineFactoryName = engineInstance.engineFactory - val upgrade = actorSystem.actorOf(Props( - classOf[UpgradeActor], - engineFactoryName)) - actorSystem.scheduler.schedule( - 0.seconds, - 1.days, - upgrade, - UpgradeCheck()) + val master = actorSystem.actorOf(Props( + classOf[MasterActor], + sc, + engineInstance, + engineFactoryName, + manifest), + "master") + implicit val timeout = Timeout(5.seconds) + master ? StartServer() + actorSystem.awaitTermination + } getOrElse { + error(s"Invalid engine ID or version. Aborting server.") + } + } getOrElse { + error(s"Invalid engine instance ID. Aborting server.") + } + } + } + + def createServerActorWithEngine[TD, EIN, PD, Q, P, A]( + sc: ServerConfig, + engineInstance: EngineInstance, + engine: Engine[TD, EIN, PD, Q, P, A], + engineLanguage: EngineLanguage.Value, + manifest: EngineManifest): ActorRef = { + + val engineParams = engine.engineInstanceToEngineParams(engineInstance, sc.jsonExtractor) + + val kryo = KryoInstantiator.newKryoInjection + + val modelsFromEngineInstance = + kryo.invert(modeldata.get(engineInstance.id).get.models).get. 
 +        asInstanceOf[Seq[Any]]
 +
 +    val batch = if (engineInstance.batch.nonEmpty) {
 +      s"${engineInstance.engineFactory} (${engineInstance.batch})"
 +    } else {
 +      engineInstance.engineFactory
 +    }
 +
 +    val sparkContext = WorkflowContext(
 +      batch = batch,
 +      executorEnv = engineInstance.env,
 +      mode = "Serving",
 +      sparkEnv = engineInstance.sparkConf)
 +
 +    val models = engine.prepareDeploy(
 +      sparkContext,
 +      engineParams,
 +      engineInstance.id,
 +      modelsFromEngineInstance,
 +      params = WorkflowParams()
 +    )
 +
 +    val algorithms = engineParams.algorithmParamsList.map { case (n, p) =>
 +      Doer(engine.algorithmClassMap(n), p)
 +    }
 +
 +    val servingParamsWithName = engineParams.servingParams
 +
 +    val serving = Doer(engine.servingClassMap(servingParamsWithName._1),
 +      servingParamsWithName._2)
 +
 +    actorSystem.actorOf(
 +      Props(
 +        classOf[ServerActor[Q, P]],
 +        sc,
 +        engineInstance,
 +        engine,
 +        engineLanguage,
 +        manifest,
 +        engineParams.dataSourceParams._2,
 +        engineParams.preparatorParams._2,
 +        algorithms,
 +        engineParams.algorithmParamsList.map(_._2),
 +        models,
 +        serving,
 +        engineParams.servingParams._2))
 +  }
 +}
 +
- class UpgradeActor(engineClass: String) extends Actor {
-   val log = Logging(context.system, this)
-   implicit val system = context.system
-   def receive: Actor.Receive = {
-     case x: UpgradeCheck =>
-       WorkflowUtils.checkUpgrade("deployment", engineClass)
-   }
- }
- 
 +class MasterActor (
 +    sc: ServerConfig,
 +    engineInstance: EngineInstance,
 +    engineFactoryName: String,
 +    manifest: EngineManifest) extends Actor with SSLConfiguration with KeyAuthentication {
 +  val log = Logging(context.system, this)
 +  implicit val system = context.system
 +  var sprayHttpListener: Option[ActorRef] = None
 +  var currentServerActor: Option[ActorRef] = None
 +  var retry = 3
 +
 +  def undeploy(ip: String, port: Int): Unit = {
 +    val serverUrl = s"https://${ip}:${port}"
 +    log.info(
 +      s"Undeploying any existing engine instance at $serverUrl")
 +    try {
 +      val code = scalaj.http.Http(s"$serverUrl/stop")
 +        .option(HttpOptions.allowUnsafeSSL)
 +        .param(ServerKey.param, ServerKey.get)
 +        .method("POST").asString.code
 +      code match {
 +        case 200 => Unit
 +        case 404 => log.error(
 +          s"Another process is using $serverUrl. Unable to undeploy.")
 +        case _ => log.error(
 +          s"Another process is using $serverUrl, or an existing " +
 +          s"engine server is not responding properly (HTTP $code). " +
 +          "Unable to undeploy.")
 +      }
 +    } catch {
 +      case e: java.net.ConnectException =>
 +        log.warning(s"Nothing at $serverUrl")
 +      case _: Throwable =>
 +        log.error("Another process might be occupying " +
 +          s"$ip:$port. Unable to undeploy.")
 +    }
 +  }
 +
 +  def receive: Actor.Receive = {
 +    case x: StartServer =>
 +      val actor = createServerActor(
 +        sc,
 +        engineInstance,
 +        engineFactoryName,
 +        manifest)
 +      currentServerActor = Some(actor)
 +      undeploy(sc.ip, sc.port)
 +      self ! BindServer()
 +    case x: BindServer =>
 +      currentServerActor map { actor =>
 +        val settings = ServerSettings(system)
 +        IO(Http) ! Http.Bind(
 +          actor,
 +          interface = sc.ip,
 +          port = sc.port,
 +          settings = Some(settings.copy(sslEncryption = true)))
 +      } getOrElse {
 +        log.error("Cannot bind a non-existing server backend.")
 +      }
 +    case x: StopServer =>
 +      log.info(s"Stop server command received.")
 +      sprayHttpListener.map { l =>
 +        log.info("Server is shutting down.")
 +        l ! Http.Unbind(5.seconds)
 +        system.shutdown
 +      } getOrElse {
 +        log.warning("No active server is running.")
 +      }
 +    case x: ReloadServer =>
 +      log.info("Reload server command received.")
 +      val latestEngineInstance =
 +        CreateServer.engineInstances.getLatestCompleted(
 +          manifest.id,
 +          manifest.version,
 +          engineInstance.engineVariant)
 +      latestEngineInstance map { lr =>
 +        val actor = createServerActor(sc, lr, engineFactoryName, manifest)
 +        sprayHttpListener.map { l =>
 +          l ! Http.Unbind(5.seconds)
 +          val settings = ServerSettings(system)
 +          IO(Http) ! Http.Bind(
 +            actor,
 +            interface = sc.ip,
 +            port = sc.port,
 +            settings = Some(settings.copy(sslEncryption = true)))
 +          currentServerActor.get ! Kill
 +          currentServerActor = Some(actor)
 +        } getOrElse {
 +          log.warning("No active server is running. Abort reloading.")
 +        }
 +      } getOrElse {
 +        log.warning(
 +          s"No latest completed engine instance for ${manifest.id} " +
 +          s"${manifest.version}. Abort reloading.")
 +      }
 +    case x: Http.Bound =>
 +      val serverUrl = s"https://${sc.ip}:${sc.port}"
 +      log.info(s"Engine is deployed and running. Engine API is live at ${serverUrl}.")
 +      sprayHttpListener = Some(sender)
 +    case x: Http.CommandFailed =>
 +      if (retry > 0) {
 +        retry -= 1
 +        log.error(s"Bind failed. Retrying... ($retry more trial(s))")
 +        context.system.scheduler.scheduleOnce(1.seconds) {
 +          self ! BindServer()
 +        }
 +      } else {
 +        log.error("Bind failed. Shutting down.")
 +        system.shutdown
 +      }
 +  }
 +
 +  def createServerActor(
 +    sc: ServerConfig,
 +    engineInstance: EngineInstance,
 +    engineFactoryName: String,
 +    manifest: EngineManifest): ActorRef = {
 +    val (engineLanguage, engineFactory) =
 +      WorkflowUtils.getEngine(engineFactoryName, getClass.getClassLoader)
 +    val engine = engineFactory()
 +
 +    // EngineFactory returns a base engine, which may not be deployable.
 +    if (!engine.isInstanceOf[Engine[_,_,_,_,_,_]]) {
 +      throw new NoSuchMethodException(s"Engine $engine is not deployable")
 +    }
 +
 +    val deployableEngine = engine.asInstanceOf[Engine[_,_,_,_,_,_]]
 +
 +    CreateServer.createServerActorWithEngine(
 +      sc,
 +      engineInstance,
 +      // engine,
 +      deployableEngine,
 +      engineLanguage,
 +      manifest)
 +  }
 +}
 +
 +class ServerActor[Q, P](
 +    val args: ServerConfig,
 +    val engineInstance: EngineInstance,
 +    val engine: Engine[_, _, _, Q, P, _],
 +    val engineLanguage: EngineLanguage.Value,
 +    val manifest: EngineManifest,
 +    val dataSourceParams: Params,
 +    val preparatorParams: Params,
 +    val algorithms: Seq[BaseAlgorithm[_, _, Q, P]],
 +    val algorithmsParams: Seq[Params],
 +    val models: Seq[Any],
 +    val serving: BaseServing[Q, P],
 +    val servingParams: Params) extends Actor with HttpService with KeyAuthentication {
 +  val serverStartTime = DateTime.now
 +  val log = Logging(context.system, this)
 +
 +  var requestCount: Int = 0
 +  var avgServingSec: Double = 0.0
 +  var lastServingSec: Double = 0.0
 +
 +  /** The following is required by HttpService */
 +  def actorRefFactory: ActorContext = context
 +
 +  implicit val timeout = Timeout(5, TimeUnit.SECONDS)
 +  val pluginsActorRef =
 +    context.actorOf(Props(classOf[PluginsActor], args.engineVariant), "PluginsActor")
 +  val pluginContext = EngineServerPluginContext(log, args.engineVariant)
 +
 +  def receive: Actor.Receive = runRoute(myRoute)
 +
 +  val feedbackEnabled = if (args.feedback) {
 +    if (args.accessKey.isEmpty) {
 +      log.error("Feedback loop cannot be enabled because accessKey is empty.")
 +      false
 +    } else {
 +      true
 +    }
 +  } else false
 +
 +  def remoteLog(logUrl: String, logPrefix: String, message: String): Unit = {
 +    implicit val formats = Utils.json4sDefaultFormats
 +    try {
 +      scalaj.http.Http(logUrl).postData(
 +        logPrefix + write(Map(
 +          "engineInstance" -> engineInstance,
 +          "message" -> message))).asString
 +    } catch {
 +      case e: Throwable =>
 +        log.error(s"Unable to send remote log: ${e.getMessage}")
 +    }
 +  }
 +
 +  def getStackTraceString(e: Throwable): String = {
 +    val writer = new StringWriter()
 +    val printWriter = new PrintWriter(writer)
 +    e.printStackTrace(printWriter)
 +    writer.toString
 +  }
 +
 +  val myRoute =
 +    path("") {
 +      get {
 +        respondWithMediaType(`text/html`) {
 +          detach() {
 +            complete {
 +              html.index(
 +                args,
 +                manifest,
 +                engineInstance,
 +                algorithms.map(_.toString),
 +                algorithmsParams.map(_.toString),
 +                models.map(_.toString),
 +                dataSourceParams.toString,
 +                preparatorParams.toString,
 +                servingParams.toString,
 +                serverStartTime,
 +                feedbackEnabled,
 +                args.eventServerIp,
 +                args.eventServerPort,
 +                requestCount,
 +                avgServingSec,
 +                lastServingSec
 +              ).toString
 +            }
 +          }
 +        }
 +      }
 +    } ~
 +    path("queries.json") {
 +      post {
 +        detach() {
 +          entity(as[String]) { queryString =>
 +            try {
 +              val servingStartTime = DateTime.now
 +              val jsonExtractorOption = args.jsonExtractor
 +              val queryTime = DateTime.now
 +              // Extract Query from Json
 +              val query = JsonExtractor.extract(
 +                jsonExtractorOption,
 +                queryString,
 +                algorithms.head.queryClass,
 +                algorithms.head.querySerializer,
 +                algorithms.head.gsonTypeAdapterFactories
 +              )
 +              val queryJValue = JsonExtractor.toJValue(
 +                jsonExtractorOption,
 +                query,
 +                algorithms.head.querySerializer,
 +                algorithms.head.gsonTypeAdapterFactories)
 +              // Deploy logic. First call Serving.supplement, then Algo.predict,
 +              // finally Serving.serve.
 +              val supplementedQuery = serving.supplementBase(query)
 +              // TODO: Parallelize the following.
 +              val predictions = algorithms.zipWithIndex.map { case (a, ai) =>
 +                a.predictBase(models(ai), supplementedQuery)
 +              }
 +              // Notice that it is by design to call Serving.serve with the
 +              // *original* query.
 +              val prediction = serving.serveBase(query, predictions)
 +              val predictionJValue = JsonExtractor.toJValue(
 +                jsonExtractorOption,
 +                prediction,
 +                algorithms.head.querySerializer,
 +                algorithms.head.gsonTypeAdapterFactories)
 +              /** Handle feedback to Event Server
 +                * Send the following back to the Event Server
 +                * - appId
 +                * - engineInstanceId
 +                * - query
 +                * - prediction
 +                * - prId
 +                */
 +              val result = if (feedbackEnabled) {
 +                implicit val formats =
 +                  algorithms.headOption map { alg =>
 +                    alg.querySerializer
 +                  } getOrElse {
 +                    Utils.json4sDefaultFormats
 +                  }
 +                // val genPrId = Random.alphanumeric.take(64).mkString
 +                def genPrId: String = Random.alphanumeric.take(64).mkString
 +                val newPrId = prediction match {
 +                  case id: WithPrId =>
 +                    val org = id.prId
 +                    if (org.isEmpty) genPrId else org
 +                  case _ => genPrId
 +                }
 +
 +                // also save the Query's prId as the prId of this pio_pr predict event
 +                val queryPrId =
 +                  query match {
 +                    case id: WithPrId =>
 +                      Map("prId" -> id.prId)
 +                    case _ =>
 +                      Map()
 +                  }
 +                val data = Map(
 +                  // "appId" -> dataSourceParams.asInstanceOf[ParamsWithAppId].appId,
 +                  "event" -> "predict",
 +                  "eventTime" -> queryTime.toString(),
 +                  "entityType" -> "pio_pr", // prediction result
 +                  "entityId" -> newPrId,
 +                  "properties" -> Map(
 +                    "engineInstanceId" -> engineInstance.id,
 +                    "query" -> query,
 +                    "prediction" -> prediction)) ++ queryPrId
 +                // At this point args.accessKey should be Some(String).
 +                val accessKey = args.accessKey.getOrElse("")
 +                val f: Future[Int] = future {
 +                  scalaj.http.Http(
 +                    s"http://${args.eventServerIp}:${args.eventServerPort}/" +
 +                    s"events.json?accessKey=$accessKey").postData(
 +                    write(data)).header(
 +                    "content-type", "application/json").asString.code
 +                }
 +                f onComplete {
 +                  case Success(code) => {
 +                    if (code != 201) {
 +                      log.error(s"Feedback event failed. Status code: $code. " +
 +                        s"Data: ${write(data)}.")
 +                    }
 +                  }
 +                  case Failure(t) => {
 +                    log.error(s"Feedback event failed: ${t.getMessage}")
 +                  }
 +                }
 +                // overwrite prId in predictedResult
 +                // - if it is WithPrId,
 +                //   then overwrite with new prId
 +                // - if it is not WithPrId, no prId injection
 +                if (prediction.isInstanceOf[WithPrId]) {
 +                  predictionJValue merge parse(s"""{"prId" : "$newPrId"}""")
 +                } else {
 +                  predictionJValue
 +                }
 +              } else predictionJValue
 +
 +              val pluginResult =
 +                pluginContext.outputBlockers.values.foldLeft(result) { case (r, p) =>
 +                  p.process(engineInstance, queryJValue, r, pluginContext)
 +                }
 +
 +              // Bookkeeping
 +              val servingEndTime = DateTime.now
 +              lastServingSec =
 +                (servingEndTime.getMillis - servingStartTime.getMillis) / 1000.0
 +              avgServingSec =
 +                ((avgServingSec * requestCount) + lastServingSec) /
 +                (requestCount + 1)
 +              requestCount += 1
 +
 +              respondWithMediaType(`application/json`) {
 +                complete(compact(render(pluginResult)))
 +              }
 +            } catch {
 +              case e: MappingException =>
 +                log.error(
 +                  s"Query '$queryString' is invalid. Reason: ${e.getMessage}")
Reason: ${e.getMessage}") + args.logUrl map { url => + remoteLog( + url, + args.logPrefix.getOrElse(""), + s"Query:\n$queryString\n\nStack Trace:\n" + + s"${getStackTraceString(e)}\n\n") + } + complete(StatusCodes.BadRequest, e.getMessage) + case e: Throwable => + val msg = s"Query:\n$queryString\n\nStack Trace:\n" + + s"${getStackTraceString(e)}\n\n" + log.error(msg) + args.logUrl map { url => + remoteLog( + url, + args.logPrefix.getOrElse(""), + msg) + } + complete(StatusCodes.InternalServerError, msg) + } + } + } + } + } ~ + path("reload") { + authenticate(withAccessKeyFromFile) { request => + post { + complete { + context.actorSelection("/user/master") ! ReloadServer() + "Reloading..." + } + } + } + } ~ + path("stop") { + authenticate(withAccessKeyFromFile) { request => + post { + complete { + context.system.scheduler.scheduleOnce(1.seconds) { + context.actorSelection("/user/master") ! StopServer() + } + "Shutting down..." + } + } + } + } ~ + pathPrefix("assets") { + getFromResourceDirectory("assets") + } ~ + path("plugins.json") { + import EngineServerJson4sSupport._ + get { + respondWithMediaType(MediaTypes.`application/json`) { + complete { + Map("plugins" -> Map( + "outputblockers" -> pluginContext.outputBlockers.map { case (n, p) => + n -> Map( + "name" -> p.pluginName, + "description" -> p.pluginDescription, + "class" -> p.getClass.getName, + "params" -> pluginContext.pluginParams(p.pluginName)) + }, + "outputsniffers" -> pluginContext.outputSniffers.map { case (n, p) => + n -> Map( + "name" -> p.pluginName, + "description" -> p.pluginDescription, + "class" -> p.getClass.getName, + "params" -> pluginContext.pluginParams(p.pluginName)) + } + )) + } + } + } + } ~ + path("plugins" / Segments) { segments => + import EngineServerJson4sSupport._ + get { + respondWithMediaType(MediaTypes.`application/json`) { + complete { + val pluginArgs = segments.drop(2) + val pluginType = segments(0) + val pluginName = segments(1) + pluginType match { + case EngineServerPlugin.outputSniffer => + pluginsActorRef ? PluginsActor.HandleREST( + pluginName = pluginName, + pluginArgs = pluginArgs) map { + _.asInstanceOf[String] + } + } + } + } + } + } +} + +object EngineServerJson4sSupport extends Json4sSupport { + implicit def json4sFormats: Formats = DefaultFormats +} http://git-wip-us.apache.org/repos/asf/incubator-predictionio/blob/c508c791/core/src/main/scala/org/apache/predictionio/workflow/WorkflowUtils.scala ---------------------------------------------------------------------- diff --cc core/src/main/scala/org/apache/predictionio/workflow/WorkflowUtils.scala index cd80fd9,0000000..0df8db5 mode 100644,000000..100644 --- a/core/src/main/scala/org/apache/predictionio/workflow/WorkflowUtils.scala +++ b/core/src/main/scala/org/apache/predictionio/workflow/WorkflowUtils.scala @@@ -1,419 -1,0 +1,387 @@@ +/** Copyright 2015 TappingStone, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
 + */
 +
 +package org.apache.predictionio.workflow
 +
 +import java.io.File
 +import java.io.FileNotFoundException
 +
 +import org.apache.predictionio.controller.EmptyParams
 +import org.apache.predictionio.controller.EngineFactory
 +import org.apache.predictionio.controller.EngineParamsGenerator
 +import org.apache.predictionio.controller.Evaluation
 +import org.apache.predictionio.controller.Params
 +import org.apache.predictionio.controller.PersistentModelLoader
 +import org.apache.predictionio.controller.Utils
 +import org.apache.predictionio.core.BuildInfo
 +
 +import com.google.gson.Gson
 +import com.google.gson.JsonSyntaxException
 +import grizzled.slf4j.Logging
 +import org.apache.predictionio.workflow.JsonExtractorOption.JsonExtractorOption
 +import org.apache.log4j.Level
 +import org.apache.log4j.LogManager
 +import org.apache.spark.SparkContext
 +import org.apache.spark.api.java.JavaRDDLike
 +import org.apache.spark.rdd.RDD
 +import org.json4s.JsonAST.JValue
 +import org.json4s.MappingException
 +import org.json4s._
 +import org.json4s.native.JsonMethods._
 +
- import scala.io.Source
 +import scala.language.existentials
 +import scala.reflect.runtime.universe
 +
 +/** Collection of reusable workflow related utilities. */
 +object WorkflowUtils extends Logging {
 +  @transient private lazy val gson = new Gson
 +
 +  /** Obtains an Engine object in Scala, or instantiates an Engine in Java.
 +    *
 +    * @param engine Engine factory name.
 +    * @param cl A Java ClassLoader to look for engine-related classes.
 +    *
 +    * @throws ClassNotFoundException
 +    *         Thrown when engine factory class does not exist.
 +    * @throws NoSuchMethodException
 +    *         Thrown when engine factory's apply() method is not implemented.
 +    */
 +  def getEngine(engine: String, cl: ClassLoader): (EngineLanguage.Value, EngineFactory) = {
 +    val runtimeMirror = universe.runtimeMirror(cl)
 +    val engineModule = runtimeMirror.staticModule(engine)
 +    val engineObject = runtimeMirror.reflectModule(engineModule)
 +    try {
 +      (
 +        EngineLanguage.Scala,
 +        engineObject.instance.asInstanceOf[EngineFactory]
 +      )
 +    } catch {
 +      case e @ (_: NoSuchFieldException | _: ClassNotFoundException) => try {
 +        (
 +          EngineLanguage.Java,
 +          Class.forName(engine).newInstance.asInstanceOf[EngineFactory]
 +        )
 +      }
 +    }
 +  }
 +
 +  def getEngineParamsGenerator(epg: String, cl: ClassLoader):
 +  (EngineLanguage.Value, EngineParamsGenerator) = {
 +    val runtimeMirror = universe.runtimeMirror(cl)
 +    val epgModule = runtimeMirror.staticModule(epg)
 +    val epgObject = runtimeMirror.reflectModule(epgModule)
 +    try {
 +      (
 +        EngineLanguage.Scala,
 +        epgObject.instance.asInstanceOf[EngineParamsGenerator]
 +      )
 +    } catch {
 +      case e @ (_: NoSuchFieldException | _: ClassNotFoundException) => try {
 +        (
 +          EngineLanguage.Java,
 +          Class.forName(epg).newInstance.asInstanceOf[EngineParamsGenerator]
 +        )
 +      }
 +    }
 +  }
 +
 +  def getEvaluation(evaluation: String, cl: ClassLoader): (EngineLanguage.Value, Evaluation) = {
 +    val runtimeMirror = universe.runtimeMirror(cl)
 +    val evaluationModule = runtimeMirror.staticModule(evaluation)
 +    val evaluationObject = runtimeMirror.reflectModule(evaluationModule)
 +    try {
 +      (
 +        EngineLanguage.Scala,
 +        evaluationObject.instance.asInstanceOf[Evaluation]
 +      )
 +    } catch {
 +      case e @ (_: NoSuchFieldException | _: ClassNotFoundException) => try {
 +        (
 +          EngineLanguage.Java,
 +          Class.forName(evaluation).newInstance.asInstanceOf[Evaluation]
 +        )
 +      }
 +    }
 +  }
 +
 +  /** Converts a JSON document to an instance of Params.
 +    *
 +    * @param language Engine's programming language.
 +    * @param json JSON document.
 +    * @param clazz Class of the component that is going to receive the resulting
 +    *        Params instance as a constructor argument.
 +    * @param jsonExtractor JSON extractor option.
 +    * @param formats JSON4S serializers for deserialization.
 +    *
 +    * @throws MappingException Thrown when JSON4S fails to perform conversion.
 +    * @throws JsonSyntaxException Thrown when GSON fails to perform conversion.
 +    */
 +  def extractParams(
 +      language: EngineLanguage.Value = EngineLanguage.Scala,
 +      json: String,
 +      clazz: Class[_],
 +      jsonExtractor: JsonExtractorOption,
 +      formats: Formats = Utils.json4sDefaultFormats): Params = {
 +    implicit val f = formats
 +    val pClass = clazz.getConstructors.head.getParameterTypes
 +    if (pClass.size == 0) {
 +      if (json != "") {
 +        warn(s"Non-empty parameters supplied to ${clazz.getName}, but its " +
 +          "constructor does not accept any arguments. Stubbing with empty " +
 +          "parameters.")
 +      }
 +      EmptyParams()
 +    } else {
 +      val apClass = pClass.head
 +      try {
 +        JsonExtractor.extract(jsonExtractor, json, apClass, f).asInstanceOf[Params]
 +      } catch {
 +        case e@(_: MappingException | _: JsonSyntaxException) =>
 +          error(
 +            s"Unable to extract parameters for ${apClass.getName} from " +
 +            s"JSON string: $json. Aborting workflow.",
 +            e)
 +          throw e
 +      }
 +    }
 +  }
 +
 +  def getParamsFromJsonByFieldAndClass(
 +      variantJson: JValue,
 +      field: String,
 +      classMap: Map[String, Class[_]],
 +      engineLanguage: EngineLanguage.Value,
 +      jsonExtractor: JsonExtractorOption): (String, Params) = {
 +    variantJson findField {
 +      case JField(f, _) => f == field
 +      case _ => false
 +    } map { jv =>
 +      implicit lazy val formats = Utils.json4sDefaultFormats +
 +        new NameParamsSerializer
 +      val np: NameParams = try {
 +        jv._2.extract[NameParams]
 +      } catch {
 +        case e: Exception =>
 +          error(s"Unable to extract $field name and params $jv")
 +          throw e
 +      }
 +      val extractedParams = np.params.map { p =>
 +        try {
 +          if (!classMap.contains(np.name)) {
 +            error(s"Unable to find $field class with name '${np.name}'" +
 +              " defined in Engine.")
 +            sys.exit(1)
 +          }
 +          WorkflowUtils.extractParams(
 +            engineLanguage,
 +            compact(render(p)),
 +            classMap(np.name),
 +            jsonExtractor,
 +            formats)
 +        } catch {
 +          case e: Exception =>
 +            error(s"Unable to extract $field params $p")
 +            throw e
 +        }
 +      }.getOrElse(EmptyParams())
 +
 +      (np.name, extractedParams)
 +    } getOrElse("", EmptyParams())
 +  }
 +
 +  /** Grab environment variables that start with 'PIO_'. */
 +  def pioEnvVars: Map[String, String] =
 +    sys.env.filter(kv => kv._1.startsWith("PIO_"))
 +
 +  /** Converts Java (non-Scala) objects to a JSON4S JValue.
 +    *
 +    * @param params The Java object to be converted.
 +    */
 +  def javaObjectToJValue(params: AnyRef): JValue = parse(gson.toJson(params))
 +
-   private[predictionio] def checkUpgrade(
-     component: String = "core",
-     engine: String = ""): Unit = {
-     val runner = new Thread(new UpgradeCheckRunner(component, engine))
-     runner.start()
-   }
- 
 +  // Extract debug string by recursively traversing the data.
 +  def debugString[D](data: D): String = {
 +    val s: String = data match {
 +      case rdd: RDD[_] => {
 +        debugString(rdd.collect())
 +      }
 +      case javaRdd: JavaRDDLike[_, _] => {
 +        debugString(javaRdd.collect())
 +      }
 +      case array: Array[_] => {
 +        "[" + array.map(debugString).mkString(",") + "]"
 +      }
 +      case d: AnyRef => {
 +        d.toString
 +      }
 +      case null => "null"
 +    }
 +    s
 +  }
 +
 +  /** Detect third party software configuration files to be submitted as
 +    * extras to Apache Spark. This makes sure all executors receive the same
 +    * configuration.
 +    */
 +  def thirdPartyConfFiles: Seq[String] = {
 +    val thirdPartyFiles = Map(
 +      "PIO_CONF_DIR" -> "log4j.properties",
 +      "ES_CONF_DIR" -> "elasticsearch.yml",
 +      "HADOOP_CONF_DIR" -> "core-site.xml",
 +      "HBASE_CONF_DIR" -> "hbase-site.xml")
 +
 +    thirdPartyFiles.keys.toSeq.map { k: String =>
 +      sys.env.get(k) map { x =>
 +        val p = Seq(x, thirdPartyFiles(k)).mkString(File.separator)
 +        if (new File(p).exists) Seq(p) else Seq[String]()
 +      } getOrElse Seq[String]()
 +    }.flatten
 +  }
 +
 +  def thirdPartyClasspaths: Seq[String] = {
 +    val thirdPartyPaths = Seq(
 +      "PIO_CONF_DIR",
 +      "ES_CONF_DIR",
 +      "POSTGRES_JDBC_DRIVER",
 +      "MYSQL_JDBC_DRIVER",
 +      "HADOOP_CONF_DIR",
 +      "HBASE_CONF_DIR")
 +    thirdPartyPaths.map(p =>
 +      sys.env.get(p).map(Seq(_)).getOrElse(Seq[String]())
 +    ).flatten
 +  }
 +
 +  def modifyLogging(verbose: Boolean): Unit = {
 +    val rootLoggerLevel = if (verbose) Level.TRACE else Level.INFO
 +    val chattyLoggerLevel = if (verbose) Level.INFO else Level.WARN
 +
 +    LogManager.getRootLogger.setLevel(rootLoggerLevel)
 +
 +    LogManager.getLogger("org.elasticsearch").setLevel(chattyLoggerLevel)
 +    LogManager.getLogger("org.apache.hadoop").setLevel(chattyLoggerLevel)
 +    LogManager.getLogger("org.apache.spark").setLevel(chattyLoggerLevel)
 +    LogManager.getLogger("org.eclipse.jetty").setLevel(chattyLoggerLevel)
 +    LogManager.getLogger("akka").setLevel(chattyLoggerLevel)
 +  }
 +
 +  def extractNameParams(jv: JValue): NameParams = {
 +    implicit val formats = Utils.json4sDefaultFormats
 +    val nameOpt = (jv \ "name").extract[Option[String]]
 +    val paramsOpt = (jv \ "params").extract[Option[JValue]]
 +
 +    if (nameOpt.isEmpty && paramsOpt.isEmpty) {
 +      error("Unable to find 'name' or 'params' fields in" +
 +        s" ${compact(render(jv))}.\n" +
 +        "Since 0.8.4, the 'params' field is required in engine.json" +
 +        " in order to specify parameters for DataSource, Preparator or" +
 +        " Serving.\n" +
 +        "Please go to https://docs.prediction.io/resources/upgrade/" +
 +        " for detailed instructions on how to change engine.json.")
 +      sys.exit(1)
 +    }
 +
 +    if (nameOpt.isEmpty) {
 +      info(s"No 'name' is found. Default empty String will be used.")
 +    }
 +
 +    if (paramsOpt.isEmpty) {
 +      info(s"No 'params' is found. Default EmptyParams will be used.")
 +    }
 +
 +    NameParams(
 +      name = nameOpt.getOrElse(""),
 +      params = paramsOpt
 +    )
 +  }
 +
 +  def extractSparkConf(root: JValue): List[(String, String)] = {
 +    def flatten(jv: JValue): List[(List[String], String)] = {
 +      jv match {
 +        case JObject(fields) =>
 +          for ((namePrefix, childJV) <- fields;
 +               (name, value) <- flatten(childJV))
 +            yield (namePrefix :: name) -> value
 +        case JArray(_) => {
 +          error("Arrays are not allowed in the sparkConf section of engine.json.")
 +          sys.exit(1)
 +        }
 +        case JNothing => List()
 +        case _ => List(List() -> jv.values.toString)
 +      }
 +    }
 +
 +    flatten(root \ "sparkConf").map(x =>
 +      (x._1.reduce((a, b) => s"$a.$b"), x._2))
 +  }
 +}
 +
 +case class NameParams(name: String, params: Option[JValue])
 +
 +class NameParamsSerializer extends CustomSerializer[NameParams](format => ( {
 +  case jv: JValue => WorkflowUtils.extractNameParams(jv)
 +}, {
 +  case x: NameParams =>
 +    JObject(JField("name", JString(x.name)) ::
 +      JField("params", x.params.getOrElse(JNothing)) :: Nil)
 +}
 +  ))
 +
 +/** Collection of reusable workflow related utilities that touch on Apache
 +  * Spark. They are separated to avoid compilation problems with certain code.
 +  */
 +object SparkWorkflowUtils extends Logging {
 +  def getPersistentModel[AP <: Params, M](
 +      pmm: PersistentModelManifest,
 +      runId: String,
 +      params: AP,
 +      sc: Option[SparkContext],
 +      cl: ClassLoader): M = {
 +    val runtimeMirror = universe.runtimeMirror(cl)
 +    val pmmModule = runtimeMirror.staticModule(pmm.className)
 +    val pmmObject = runtimeMirror.reflectModule(pmmModule)
 +    try {
 +      pmmObject.instance.asInstanceOf[PersistentModelLoader[AP, M]](
 +        runId,
 +        params,
 +        sc)
 +    } catch {
 +      case e @ (_: NoSuchFieldException | _: ClassNotFoundException) => try {
 +        val loadMethod = Class.forName(pmm.className).getMethod(
 +          "load",
 +          classOf[String],
 +          classOf[Params],
 +          classOf[SparkContext])
 +        loadMethod.invoke(null, runId, params, sc.orNull).asInstanceOf[M]
 +      } catch {
 +        case e: ClassNotFoundException =>
 +          error(s"Model class ${pmm.className} cannot be found.")
 +          throw e
 +        case e: NoSuchMethodException =>
 +          error(
 +            "The load(String, Params, SparkContext) method cannot be found.")
 +          throw e
 +      }
 +    }
 +  }
 +}
 +
- class UpgradeCheckRunner(
-   val component: String,
-   val engine: String) extends Runnable with Logging {
-   val version = BuildInfo.version
-   val versionsHost = "https://direct.prediction.io/"
- 
-   def run(): Unit = {
-     val url = if (engine == "") {
-       s"$versionsHost$version/$component.json"
-     } else {
-       s"$versionsHost$version/$component/$engine.json"
-     }
-     try {
-       val upgradeData = Source.fromURL(url)
-     } catch {
-       case e: FileNotFoundException =>
-         debug(s"Update metainfo not found. $url")
-       case e: java.net.UnknownHostException =>
-         debug(s"${e.getClass.getName}: {e.getMessage}")
-     }
-     // TODO: Implement upgrade logic
-   }
- }
- 
 +class WorkflowInterruption() extends Exception
 +
 +case class StopAfterReadInterruption() extends WorkflowInterruption
 +
 +case class StopAfterPrepareInterruption() extends WorkflowInterruption
 +
 +object EngineLanguage extends Enumeration {
 +  val Scala, Java = Value
 +}
