This is an automated email from the ASF dual-hosted git repository.

takezoe pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/predictionio.git


The following commit(s) were added to refs/heads/develop by this push:
     new 9bf29cd  [PIO-192] Enhance PySpark support (#494)
9bf29cd is described below

commit 9bf29cdeaf872e7b01198c4cc09bb2d3a06d1f1f
Author: takako shimamoto <[email protected]>
AuthorDate: Mon Dec 10 21:58:17 2018 +0900

    [PIO-192] Enhance PySpark support (#494)
---
 bin/pio-shell                                      |   1 -
 build.sbt                                          |  14 +--
 .../workflow/EngineServerPluginContext.scala       |  19 ++--
 .../predictionio/workflow/JsonExtractor.scala      |   9 +-
 .../predictionio/e2/engine/PythonEngine.scala      |  96 ++++++++++++++++++
 python/pypio/__init__.py                           |   9 ++
 python/pypio/data/eventstore.py                    |   4 +-
 python/pypio/pypio.py                              | 111 +++++++++++++++++++++
 python/pypio/shell.py                              |  23 -----
 python/pypio/utils.py                              |  40 ++++++--
 .../apache/predictionio/tools/RunWorkflow.scala    |  32 +++---
 .../predictionio/tools/commands/Engine.scala       |  27 ++---
 12 files changed, 303 insertions(+), 82 deletions(-)

diff --git a/bin/pio-shell b/bin/pio-shell
index cd119cd..e23041a 100755
--- a/bin/pio-shell
+++ b/bin/pio-shell
@@ -65,7 +65,6 @@ then
   # Get paths of assembly jars to pass to pyspark
   . ${PIO_HOME}/bin/compute-classpath.sh
   shift
-  export PYTHONSTARTUP=${PIO_HOME}/python/pypio/shell.py
   export PYTHONPATH=${PIO_HOME}/python
   ${SPARK_HOME}/bin/pyspark --jars ${ASSEMBLY_JARS} $@
 else
diff --git a/build.sbt b/build.sbt
index 9efed21..65c4ca3 100644
--- a/build.sbt
+++ b/build.sbt
@@ -151,19 +151,19 @@ val core = (project in file("core")).
   enablePlugins(SbtTwirl).
   disablePlugins(sbtassembly.AssemblyPlugin)
 
-val tools = (project in file("tools")).
+val e2 = (project in file("e2")).
   dependsOn(core).
-  dependsOn(data).
   settings(commonSettings: _*).
-  settings(commonTestSettings: _*).
-  settings(skip in publish := true).
   enablePlugins(GenJavadocPlugin).
-  enablePlugins(SbtTwirl)
+  disablePlugins(sbtassembly.AssemblyPlugin)
 
-val e2 = (project in file("e2")).
+val tools = (project in file("tools")).
+  dependsOn(e2).
   settings(commonSettings: _*).
+  settings(commonTestSettings: _*).
+  settings(skip in publish := true).
   enablePlugins(GenJavadocPlugin).
-  disablePlugins(sbtassembly.AssemblyPlugin)
+  enablePlugins(SbtTwirl)
 
 val dataEs = if (majorVersion(es) == 1) dataElasticsearch1 else 
dataElasticsearch
 
diff --git a/core/src/main/scala/org/apache/predictionio/workflow/EngineServerPluginContext.scala b/core/src/main/scala/org/apache/predictionio/workflow/EngineServerPluginContext.scala
index cfc83eb..011cd95 100644
--- a/core/src/main/scala/org/apache/predictionio/workflow/EngineServerPluginContext.scala
+++ b/core/src/main/scala/org/apache/predictionio/workflow/EngineServerPluginContext.scala
@@ -55,9 +55,10 @@ object EngineServerPluginContext extends Logging {
       EngineServerPlugin.outputSniffer -> mutable.Map())
     val pluginParams = mutable.Map[String, JValue]()
     val serviceLoader = ServiceLoader.load(classOf[EngineServerPlugin])
-    val variantJson = parse(stringFromFile(engineVariant))
-    (variantJson \ "plugins").extractOpt[JObject].foreach { pluginDefs =>
-      pluginDefs.obj.foreach { pluginParams += _ }
+    stringFromFile(engineVariant).foreach { variantJson =>
+      (parse(variantJson) \ "plugins").extractOpt[JObject].foreach { 
pluginDefs =>
+        pluginDefs.obj.foreach { pluginParams += _ }
+      }
     }
     serviceLoader foreach { service =>
       pluginParams.get(service.pluginName) map { params =>
@@ -77,11 +78,15 @@ object EngineServerPluginContext extends Logging {
       log)
   }
 
-  private def stringFromFile(filePath: String): String = {
+  private def stringFromFile(filePath: String): Option[String] = {
     try {
-      val uri = new URI(filePath)
-      val fs = FileSystem.get(uri, new Configuration())
-      new String(ByteStreams.toByteArray(fs.open(new Path(uri))).map(_.toChar))
+      val fs = FileSystem.get(new Configuration())
+      val path = new Path(new URI(filePath))
+      if (fs.exists(path)) {
+        Some(new String(ByteStreams.toByteArray(fs.open(path)).map(_.toChar)))
+      } else {
+        None
+      }
     } catch {
       case e: java.io.IOException =>
         error(s"Error reading from file: ${e.getMessage}. Aborting.")
diff --git a/core/src/main/scala/org/apache/predictionio/workflow/JsonExtractor.scala b/core/src/main/scala/org/apache/predictionio/workflow/JsonExtractor.scala
index cb71f14..3aafe67 100644
--- a/core/src/main/scala/org/apache/predictionio/workflow/JsonExtractor.scala
+++ b/core/src/main/scala/org/apache/predictionio/workflow/JsonExtractor.scala
@@ -32,7 +32,6 @@ import org.json4s.native.JsonMethods.compact
 import org.json4s.native.JsonMethods.pretty
 import org.json4s.native.JsonMethods.parse
 import org.json4s.native.JsonMethods.render
-import org.json4s.reflect.TypeInfo
 
 object JsonExtractor {
 
@@ -144,7 +143,13 @@ object JsonExtractor {
     formats: Formats,
     clazz: Class[T]): T = {
 
-    Extraction.extract(parse(json), TypeInfo(clazz, 
None))(formats).asInstanceOf[T]
+    implicit val f = formats
+    implicit val m = if (clazz == classOf[Map[_, _]]) {
+      Manifest.classType(clazz, manifest[String], manifest[Any])
+    } else {
+      Manifest.classType(clazz)
+    }
+    Extraction.extract(parse(json))
   }
 
   private def extractWithGson[T](
diff --git a/e2/src/main/scala/org/apache/predictionio/e2/engine/PythonEngine.scala b/e2/src/main/scala/org/apache/predictionio/e2/engine/PythonEngine.scala
new file mode 100644
index 0000000..a2c7282
--- /dev/null
+++ b/e2/src/main/scala/org/apache/predictionio/e2/engine/PythonEngine.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.predictionio.e2.engine
+
+import java.util.Arrays
+
+import org.apache.predictionio.controller._
+import org.apache.predictionio.workflow.KryoInstantiator
+import org.apache.spark.SparkContext
+import org.apache.spark.ml.PipelineModel
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.{Row, SparkSession}
+
+
+/**
+ * Engine factory for pipelines fitted in PySpark and persisted with
+ * `pypio.save_model`.
+ *
+ * It wires the placeholder data source and preparator, the
+ * [[PythonAlgorithm]] that applies a stored `PipelineModel` to a query, and
+ * [[PythonServing]].
+ */
+object PythonEngine extends EngineFactory {
+
+  /** A query maps input column names to their values. */
+  private[engine] type Query = Map[String, Any]
+
+  /** Builds the engine with [[PythonAlgorithm]] registered as "default". */
+  def apply(): Engine[EmptyTrainingData, EmptyEvaluationInfo,
+      EmptyPreparedData, Query, Row, EmptyActualResult] = {
+    val algorithms = Map("default" -> classOf[PythonAlgorithm])
+    new Engine(
+      classOf[PythonDataSource],
+      classOf[PythonPreparator],
+      algorithms,
+      classOf[PythonServing])
+  }
+
+  /**
+   * Serializes `model` (wrapped in a one-element `Seq`) with the injection
+   * returned by `KryoInstantiator.newKryoInjection`.
+   *
+   * @param model the fitted Spark ML pipeline to persist
+   * @return the encoded bytes
+   */
+  def models(model: PipelineModel): Array[Byte] = {
+    val injection = KryoInstantiator.newKryoInjection
+    injection(Seq(model))
+  }
+
+}
+
+import PythonEngine.Query
+
+/**
+ * Placeholder data source for [[PythonEngine]]. Events are read on the
+ * Python side (see `pypio.find_events`), so this only yields an empty marker.
+ */
+class PythonDataSource
+  extends PDataSource[EmptyTrainingData, EmptyEvaluationInfo, Query,
+    EmptyActualResult] {
+
+  /** Returns an empty training-data marker; `sc` is not used. */
+  def readTraining(sc: SparkContext): EmptyTrainingData = {
+    val empty: EmptyTrainingData = new SerializableClass()
+    empty
+  }
+}
+
+/** Placeholder preparator for [[PythonEngine]] that ignores its input. */
+class PythonPreparator
+  extends PPreparator[EmptyTrainingData, EmptyPreparedData] {
+
+  /** Returns an empty prepared-data marker; neither argument is used. */
+  def prepare(sc: SparkContext, trainingData: EmptyTrainingData): EmptyPreparedData = {
+    val empty: EmptyPreparedData = new SerializableClass()
+    empty
+  }
+}
+
+/** Serving parameters and the reserved query key used by [[PythonServing]]. */
+object PythonServing {
+  /** Query key under which the requested output column names are stored. */
+  private[engine] val columns = "PythonPredictColumns"
+
+  /**
+   * @param columns names of the columns to select from the transformed
+   *                DataFrame as the prediction result
+   */
+  case class Params(columns: Seq[String])
+    extends org.apache.predictionio.controller.Params
+}
+
+/**
+ * `LFirstServing` whose `supplement` hook adds the configured output column
+ * names to every query; `PythonAlgorithm.predict` reads them back.
+ */
+class PythonServing(params: PythonServing.Params)
+  extends LFirstServing[Query, Row] {
+
+  /** Returns `q` with [[PythonServing.columns]] bound to `params.columns`. */
+  override def supplement(q: Query): Query =
+    q.updated(PythonServing.columns, params.columns)
+}
+
+/**
+ * Applies a Spark ML `PipelineModel` saved from PySpark to a single query.
+ *
+ * `train` is intentionally unimplemented: the model is fitted in Python and
+ * stored by `pypio.save_model`.
+ */
+class PythonAlgorithm extends
+  P2LAlgorithm[EmptyPreparedData, PipelineModel, Query, Row] {
+
+  /** Not implemented on the JVM side; models are trained from Python. */
+  def train(sc: SparkContext, data: EmptyPreparedData): PipelineModel = ???
+
+  /**
+   * Builds a one-row DataFrame from the query's input columns, runs it
+   * through `model` and returns the requested output columns.
+   *
+   * @param model the pipeline to apply
+   * @param query input column values plus, under [[PythonServing.columns]],
+   *              the output column names added by `PythonServing.supplement`
+   * @return the first row of the transformed DataFrame, restricted to the
+   *         requested output columns
+   * @throws IllegalArgumentException if no output columns were requested
+   */
+  def predict(model: PipelineModel, query: Query): Row = {
+    val selectCols = query(PythonServing.columns).asInstanceOf[Seq[String]]
+    require(selectCols.nonEmpty,
+      "PythonAlgorithm requires at least one prediction column")
+
+    val inputs = (query - PythonServing.columns).toList
+    val rows = Arrays.asList(Row.fromSeq(inputs.map(_._2)))
+    // Infer each column's Spark type from its value. Iterating the pairs
+    // directly avoids O(n^2) positional List lookups into a parallel list.
+    val schema = StructType(inputs.map { case (col, value) =>
+      StructField(col, Literal(value).dataType)
+    })
+
+    val spark = SparkSession.builder.getOrCreate()
+    val df = spark.createDataFrame(rows, schema)
+    model.transform(df)
+      .select(selectCols.head, selectCols.tail: _*)
+      .first()
+  }
+
+}
diff --git a/python/pypio/__init__.py b/python/pypio/__init__.py
index 04d8ac3..e0ca788 100644
--- a/python/pypio/__init__.py
+++ b/python/pypio/__init__.py
@@ -18,3 +18,12 @@
 """
 PyPIO is the Python API for PredictionIO.
 """
+
+from __future__ import absolute_import
+
+from pypio.pypio import *
+
+
+# Names exported by ``from pypio import *``: only the ``pypio`` submodule.
+# NOTE(review): the helpers pulled in by ``from pypio.pypio import *``
+# (init, find_events, save_model) are not listed here, so a star-import of
+# the package will not export them -- confirm this is intended.
+__all__ = ['pypio']
diff --git a/python/pypio/data/eventstore.py b/python/pypio/data/eventstore.py
index 4eb73df..58f09d1 100644
--- a/python/pypio/data/eventstore.py
+++ b/python/pypio/data/eventstore.py
@@ -17,8 +17,8 @@
 
 from __future__ import absolute_import
 
-from pypio.utils import new_string_array
 from pyspark.sql.dataframe import DataFrame
+from pyspark.sql import utils
 
 __all__ = ["PEventStore"]
 
@@ -43,7 +43,7 @@ class PEventStore(object):
         pes = 
self._sc._jvm.org.apache.predictionio.data.store.python.PPythonEventStore
         jdf = pes.aggregateProperties(app_name, entity_type, channel_name,
                                       start_time, until_time,
-                                      new_string_array(required, 
self._sc._gateway),
+                                      utils.toJArray(self._sc._gateway, 
self._sc._gateway.jvm.String, required),
                                       self._jss)
         return DataFrame(jdf, self.sql_ctx)
 
diff --git a/python/pypio/pypio.py b/python/pypio/pypio.py
new file mode 100644
index 0000000..3a32cd8
--- /dev/null
+++ b/python/pypio/pypio.py
@@ -0,0 +1,111 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import absolute_import
+
+import atexit
+import json
+import os
+import sys
+
+from pypio.data import PEventStore
+from pypio.utils import dict_to_scalamap, list_to_dict
+from pypio.workflow import CleanupFunctions
+from pyspark.sql import SparkSession
+
+
+def init():
+    """
+    Initialize the module-level Spark handles used by the other pypio helpers.
+
+    Sets the globals ``spark``, ``sc``, ``sqlContext`` and ``p_event_store``
+    and registers exit hooks that run the cleanup functions and stop ``sc``.
+    """
+    global spark, sc, sqlContext, p_event_store
+    spark = SparkSession.builder.getOrCreate()
+    sc = spark.sparkContext
+    sqlContext = spark._wrapped
+    p_event_store = PEventStore(spark._jsparkSession, sqlContext)
+
+    cleanup = CleanupFunctions(sqlContext)
+    # Lambdas defer the ``run`` attribute and ``sc`` global lookups until
+    # the hooks actually fire at interpreter exit.
+    atexit.register(lambda: cleanup.run())
+    atexit.register(lambda: sc.stop())
+    print("Initialized pypio")
+
+
+def find_events(app_name):
+    """
+    Return the events of the given app, as produced by ``PEventStore.find``.
+
+    Reads the module-level ``p_event_store``, so :func:`init` must be called
+    first.
+
+    :param app_name: app name
+    :return: :py:class:`pyspark.sql.DataFrame`
+    """
+    store = p_event_store
+    return store.find(app_name)
+
+
+def save_model(model, predict_columns):
+    """
+    Save a PipelineModel object to storage.
+
+    Inserts an INIT engine instance for ``PythonEngine`` into the metadata
+    store, stores the serialized model under the returned id, and then
+    rewrites the instance with status COMPLETED. Requires :func:`init`.
+
+    :param model: :py:class:`pyspark.ml.pipeline.PipelineModel`
+    :param predict_columns: names of the columns returned as the prediction;
+        must contain at least one name
+    :return: identifier for the trained model to use for predict
+    :raises ValueError: if ``predict_columns`` is empty or None
+    """
+    if not predict_columns:
+        raise ValueError("predict_columns should have at least one value")
+    if os.environ.get('PYSPARK_PYTHON') is None:
+        # spark-submit: PIO variables arrive as "--env K1=V1,K2=V2,...".
+        d = list_to_dict(sys.argv[1:])
+        # Split on the first '=' only so values containing '=' (e.g. JDBC
+        # URLs with query strings) stay intact; skip entries without '='.
+        pio_env = dict(e.split('=', 1)
+                       for e in d.get('--env', '').split(',') if '=' in e)
+    else:
+        # pyspark
+        pio_env = {k: v for k, v in os.environ.items() if k.startswith('PIO_')}
+
+    jvm = sc._jvm
+    storage = jvm.org.apache.predictionio.data.storage
+    engine_class = "org.apache.predictionio.e2.engine.PythonEngine"
+
+    meta_storage = storage.Storage.getMetaDataEngineInstances()
+    meta = storage.EngineInstance.apply(
+        "",  # id: left empty; the value returned by insert() is used below
+        "INIT",  # status
+        jvm.org.joda.time.DateTime.now(),  # startTime
+        jvm.org.joda.time.DateTime.now(),  # endTime
+        engine_class,  # engineId
+        "1",  # engineVersion
+        "default",  # engineVariant
+        engine_class,  # engineFactory
+        "",  # batch
+        dict_to_scalamap(jvm, pio_env),  # env
+        jvm.scala.Predef.Map().empty(),  # sparkConf
+        "{\"\":{}}",  # dataSourceParams
+        "{\"\":{}}",  # preparatorParams
+        "[{\"default\":{}}]",  # algorithmsParams
+        json.dumps({"": {"columns": list(predict_columns)}})  # servingParams
+    )
+    instance_id = meta_storage.insert(meta)
+
+    # Serialize the JVM-side PipelineModel and store it under the new id.
+    engine = jvm.org.apache.predictionio.e2.engine.PythonEngine
+    data = storage.Model(instance_id, engine.models(model._to_java()))
+    model_storage = storage.Storage.getModelDataModels()
+    model_storage.insert(data)
+
+    # Rewrite the instance record as COMPLETED with a fresh endTime.
+    meta_storage.update(
+        storage.EngineInstance.apply(
+            instance_id, "COMPLETED", meta.startTime(),
+            jvm.org.joda.time.DateTime.now(),
+            meta.engineId(), meta.engineVersion(), meta.engineVariant(),
+            meta.engineFactory(), meta.batch(), meta.env(), meta.sparkConf(),
+            meta.dataSourceParams(), meta.preparatorParams(),
+            meta.algorithmsParams(), meta.servingParams()
+        )
+    )
+
+    return instance_id
+
diff --git a/python/pypio/shell.py b/python/pypio/shell.py
deleted file mode 100644
index b0295d3..0000000
--- a/python/pypio/shell.py
+++ /dev/null
@@ -1,23 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements.  See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License.  You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-from pypio.data import PEventStore
-from pypio.workflow import CleanupFunctions
-
-p_event_store = PEventStore(spark._jsparkSession, sqlContext)
-cleanup_functions = CleanupFunctions(sqlContext)
-
diff --git a/python/pypio/utils.py b/python/pypio/utils.py
index 76900c3..4155efb 100644
--- a/python/pypio/utils.py
+++ b/python/pypio/utils.py
@@ -16,12 +16,38 @@
 #
 
 
-def new_string_array(list_data, gateway):
-    if list_data is None:
+def dict_to_scalamap(jvm, d):
+    """
+    Convert a Python dict into a Scala ``Map`` through py4j.
+
+    Starts from ``scala.Predef.Map().empty()`` and adds every entry with
+    ``updated``, keeping the map each call returns.
+
+    :param jvm: py4j JVM view, e.g. ``sc._jvm``
+    :param d: python type dictionary, or None
+    :return: the Scala map, or None when ``d`` is None
+    """
+    if d is not None:
+        scala_map = jvm.scala.Predef.Map().empty()
+        for key, value in d.items():
+            scala_map = scala_map.updated(key, value)
+        return scala_map
+    return None
+
+def list_to_dict(l):
+    """
+    Convert python list to python dictionary
+
+    Items at even indices become keys and the item right after each key
+    becomes its value, i.e. a flat ``[k1, v1, k2, v2, ...]`` list. If ``l``
+    has odd length, the trailing key is dropped because ``zip`` truncates.
+
+    :param l: python type list
+    :return: the resulting dict, or None when ``l`` is None
+
+    >>> list = ["key1", 1, "key2", 2, "key3", 3]
+    >>> list_to_dict(list) == {'key1': 1, 'key2': 2, 'key3': 3}
+    True
+    """
+    if l is None:
         return None
-    string_class = gateway.jvm.String
-    args = gateway.new_array(string_class, len(list_data))
-    for i in range(len(list_data)):
-        args[i] = list_data[i]
-    return args
+    # Pair even-index items (keys) with the following odd-index items (values).
+    return dict(zip(l[0::2], l[1::2]))
+
 
+if __name__ == "__main__":
+    # Running this module directly executes its doctests and exits with a
+    # non-zero status when any of them fail.
+    import doctest
+    import sys
+    results = doctest.testmod()
+    if results[0]:
+        sys.exit(-1)
\ No newline at end of file
diff --git a/tools/src/main/scala/org/apache/predictionio/tools/RunWorkflow.scala b/tools/src/main/scala/org/apache/predictionio/tools/RunWorkflow.scala
index 236d3ba..50b9337 100644
--- a/tools/src/main/scala/org/apache/predictionio/tools/RunWorkflow.scala
+++ b/tools/src/main/scala/org/apache/predictionio/tools/RunWorkflow.scala
@@ -51,19 +51,15 @@ object RunWorkflow extends Logging {
     verbose: Boolean = false): Expected[(Process, () => Unit)] = {
 
     val jarFiles = jarFilesForScala(engineDirPath).map(_.toURI)
-    val variantJson = wa.variantJson.getOrElse(new File(engineDirPath, 
"engine.json"))
-    val ei = Console.getEngineInfo(
-      variantJson,
-      engineDirPath)
-    val args = Seq(
-      "--engine-id",
-      ei.engineId,
-      "--engine-version",
-      ei.engineVersion,
-      "--engine-variant",
-      variantJson.toURI.toString,
-      "--verbosity",
-      wa.verbosity.toString) ++
+    val args =
+      (if (wa.mainPyFile.isEmpty) {
+        val variantJson = wa.variantJson.getOrElse(new File(engineDirPath, 
"engine.json"))
+        val ei = Console.getEngineInfo(variantJson, engineDirPath)
+        Seq(
+          "--engine-id", ei.engineId,
+          "--engine-version", ei.engineVersion,
+          "--engine-variant", variantJson.toURI.toString)
+      } else Nil) ++
       wa.engineFactory.map(
         x => Seq("--engine-factory", x)).getOrElse(Nil) ++
       wa.engineParamsKey.map(
@@ -72,19 +68,15 @@ object RunWorkflow extends Logging {
       (if (verbose) Seq("--verbose") else Nil) ++
       (if (wa.skipSanityCheck) Seq("--skip-sanity-check") else Nil) ++
       (if (wa.stopAfterRead) Seq("--stop-after-read") else Nil) ++
-      (if (wa.stopAfterPrepare) {
-        Seq("--stop-after-prepare")
-      } else {
-        Nil
-      }) ++
+      (if (wa.stopAfterPrepare) Seq("--stop-after-prepare") else Nil) ++
       wa.evaluation.map(x => Seq("--evaluation-class", x)).
         getOrElse(Nil) ++
       // If engineParamsGenerator is specified, it overrides the evaluation.
       wa.engineParamsGenerator.orElse(wa.evaluation)
         .map(x => Seq("--engine-params-generator-class", x))
         .getOrElse(Nil) ++
-      (if (wa.batch != "") Seq("--batch", wa.batch) else Nil) ++
-      Seq("--json-extractor", wa.jsonExtractor.toString)
+      Seq("--json-extractor", wa.jsonExtractor.toString,
+          "--verbosity", wa.verbosity.toString)
 
     val resourceName = wa.mainPyFile match {
       case Some(x) => x
diff --git a/tools/src/main/scala/org/apache/predictionio/tools/commands/Engine.scala b/tools/src/main/scala/org/apache/predictionio/tools/commands/Engine.scala
index 8380695..20d0ed9 100644
--- a/tools/src/main/scala/org/apache/predictionio/tools/commands/Engine.scala
+++ b/tools/src/main/scala/org/apache/predictionio/tools/commands/Engine.scala
@@ -218,22 +218,23 @@ object Engine extends EitherLogging {
     if (verifyResult.isLeft) {
       return Left(verifyResult.left.get)
     }
-    val ei = Console.getEngineInfo(
-      serverArgs.variantJson.getOrElse(new File(engineDirPath, "engine.json")),
-      engineDirPath)
     val engineInstances = storage.Storage.getMetaDataEngineInstances
-    val engineInstance = engineInstanceId map { eid =>
-      engineInstances.get(eid)
+    engineInstanceId map { eid =>
+      engineInstances.get(eid).map { r =>
+        RunServer.runServer(
+          r.id, serverArgs, sparkArgs, pioHome, engineDirPath, verbose)
+      } getOrElse {
+        logAndFail(s"Invalid engine instance ID ${eid}. Aborting.")
+      }
     } getOrElse {
+      val ei = Console.getEngineInfo(
+        serverArgs.variantJson.getOrElse(new File(engineDirPath, 
"engine.json")),
+        engineDirPath)
+
       engineInstances.getLatestCompleted(
-        ei.engineId, ei.engineVersion, ei.variantId)
-    }
-    engineInstance map { r =>
-      RunServer.runServer(
-        r.id, serverArgs, sparkArgs, pioHome, engineDirPath, verbose)
-    } getOrElse {
-      engineInstanceId map { eid =>
-        logAndFail(s"Invalid engine instance ID ${eid}. Aborting.")
+        ei.engineId, ei.engineVersion, ei.variantId).map { r =>
+        RunServer.runServer(
+          r.id, serverArgs, sparkArgs, pioHome, engineDirPath, verbose)
       } getOrElse {
         logAndFail(s"No valid engine instance found for engine ${ei.engineId} 
" +
           s"${ei.engineVersion}.\nTry running 'train' before 'deploy'. 
Aborting.")

Reply via email to