This is an automated email from the ASF dual-hosted git repository.
takezoe pushed a commit to branch develop
in repository https://gitbox.apache.org/repos/asf/predictionio.git
The following commit(s) were added to refs/heads/develop by this push:
new 9bf29cd [PIO-192] Enhance PySpark support (#494)
9bf29cd is described below
commit 9bf29cdeaf872e7b01198c4cc09bb2d3a06d1f1f
Author: takako shimamoto <[email protected]>
AuthorDate: Mon Dec 10 21:58:17 2018 +0900
[PIO-192] Enhance PySpark support (#494)
---
bin/pio-shell | 1 -
build.sbt | 14 +--
.../workflow/EngineServerPluginContext.scala | 19 ++--
.../predictionio/workflow/JsonExtractor.scala | 9 +-
.../predictionio/e2/engine/PythonEngine.scala | 96 ++++++++++++++++++
python/pypio/__init__.py | 9 ++
python/pypio/data/eventstore.py | 4 +-
python/pypio/pypio.py | 111 +++++++++++++++++++++
python/pypio/shell.py | 23 -----
python/pypio/utils.py | 40 ++++++--
.../apache/predictionio/tools/RunWorkflow.scala | 32 +++---
.../predictionio/tools/commands/Engine.scala | 27 ++---
12 files changed, 303 insertions(+), 82 deletions(-)
diff --git a/bin/pio-shell b/bin/pio-shell
index cd119cd..e23041a 100755
--- a/bin/pio-shell
+++ b/bin/pio-shell
@@ -65,7 +65,6 @@ then
# Get paths of assembly jars to pass to pyspark
. ${PIO_HOME}/bin/compute-classpath.sh
shift
- export PYTHONSTARTUP=${PIO_HOME}/python/pypio/shell.py
export PYTHONPATH=${PIO_HOME}/python
${SPARK_HOME}/bin/pyspark --jars ${ASSEMBLY_JARS} $@
else
diff --git a/build.sbt b/build.sbt
index 9efed21..65c4ca3 100644
--- a/build.sbt
+++ b/build.sbt
@@ -151,19 +151,19 @@ val core = (project in file("core")).
enablePlugins(SbtTwirl).
disablePlugins(sbtassembly.AssemblyPlugin)
-val tools = (project in file("tools")).
+val e2 = (project in file("e2")).
dependsOn(core).
- dependsOn(data).
settings(commonSettings: _*).
- settings(commonTestSettings: _*).
- settings(skip in publish := true).
enablePlugins(GenJavadocPlugin).
- enablePlugins(SbtTwirl)
+ disablePlugins(sbtassembly.AssemblyPlugin)
-val e2 = (project in file("e2")).
+val tools = (project in file("tools")).
+ dependsOn(e2).
settings(commonSettings: _*).
+ settings(commonTestSettings: _*).
+ settings(skip in publish := true).
enablePlugins(GenJavadocPlugin).
- disablePlugins(sbtassembly.AssemblyPlugin)
+ enablePlugins(SbtTwirl)
val dataEs = if (majorVersion(es) == 1) dataElasticsearch1 else dataElasticsearch
diff --git a/core/src/main/scala/org/apache/predictionio/workflow/EngineServerPluginContext.scala b/core/src/main/scala/org/apache/predictionio/workflow/EngineServerPluginContext.scala
index cfc83eb..011cd95 100644
--- a/core/src/main/scala/org/apache/predictionio/workflow/EngineServerPluginContext.scala
+++ b/core/src/main/scala/org/apache/predictionio/workflow/EngineServerPluginContext.scala
@@ -55,9 +55,10 @@ object EngineServerPluginContext extends Logging {
EngineServerPlugin.outputSniffer -> mutable.Map())
val pluginParams = mutable.Map[String, JValue]()
val serviceLoader = ServiceLoader.load(classOf[EngineServerPlugin])
- val variantJson = parse(stringFromFile(engineVariant))
- (variantJson \ "plugins").extractOpt[JObject].foreach { pluginDefs =>
- pluginDefs.obj.foreach { pluginParams += _ }
+ stringFromFile(engineVariant).foreach { variantJson =>
+ (parse(variantJson) \ "plugins").extractOpt[JObject].foreach { pluginDefs =>
+ pluginDefs.obj.foreach { pluginParams += _ }
+ }
}
serviceLoader foreach { service =>
pluginParams.get(service.pluginName) map { params =>
@@ -77,11 +78,15 @@ object EngineServerPluginContext extends Logging {
log)
}
- private def stringFromFile(filePath: String): String = {
+ private def stringFromFile(filePath: String): Option[String] = {
try {
- val uri = new URI(filePath)
- val fs = FileSystem.get(uri, new Configuration())
- new String(ByteStreams.toByteArray(fs.open(new Path(uri))).map(_.toChar))
+ val fs = FileSystem.get(new Configuration())
+ val path = new Path(new URI(filePath))
+ if (fs.exists(path)) {
+ Some(new String(ByteStreams.toByteArray(fs.open(path)).map(_.toChar)))
+ } else {
+ None
+ }
} catch {
case e: java.io.IOException =>
error(s"Error reading from file: ${e.getMessage}. Aborting.")
diff --git a/core/src/main/scala/org/apache/predictionio/workflow/JsonExtractor.scala b/core/src/main/scala/org/apache/predictionio/workflow/JsonExtractor.scala
index cb71f14..3aafe67 100644
--- a/core/src/main/scala/org/apache/predictionio/workflow/JsonExtractor.scala
+++ b/core/src/main/scala/org/apache/predictionio/workflow/JsonExtractor.scala
@@ -32,7 +32,6 @@ import org.json4s.native.JsonMethods.compact
import org.json4s.native.JsonMethods.pretty
import org.json4s.native.JsonMethods.parse
import org.json4s.native.JsonMethods.render
-import org.json4s.reflect.TypeInfo
object JsonExtractor {
@@ -144,7 +143,13 @@ object JsonExtractor {
formats: Formats,
clazz: Class[T]): T = {
- Extraction.extract(parse(json), TypeInfo(clazz, None))(formats).asInstanceOf[T]
+ implicit val f = formats
+ implicit val m = if (clazz == classOf[Map[_, _]]) {
+ Manifest.classType(clazz, manifest[String], manifest[Any])
+ } else {
+ Manifest.classType(clazz)
+ }
+ Extraction.extract(parse(json))
}
private def extractWithGson[T](
diff --git a/e2/src/main/scala/org/apache/predictionio/e2/engine/PythonEngine.scala b/e2/src/main/scala/org/apache/predictionio/e2/engine/PythonEngine.scala
new file mode 100644
index 0000000..a2c7282
--- /dev/null
+++ b/e2/src/main/scala/org/apache/predictionio/e2/engine/PythonEngine.scala
@@ -0,0 +1,96 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.predictionio.e2.engine
+
+import java.util.Arrays
+
+import org.apache.predictionio.controller._
+import org.apache.predictionio.workflow.KryoInstantiator
+import org.apache.spark.SparkContext
+import org.apache.spark.ml.PipelineModel
+import org.apache.spark.sql.catalyst.expressions.Literal
+import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.{Row, SparkSession}
+
+
+object PythonEngine extends EngineFactory {
+
+ private[engine] type Query = Map[String, Any]
+
+ def apply(): Engine[EmptyTrainingData, EmptyEvaluationInfo, EmptyPreparedData,
+ Query, Row, EmptyActualResult] = {
+ new Engine(
+ classOf[PythonDataSource],
+ classOf[PythonPreparator],
+ Map("default" -> classOf[PythonAlgorithm]),
+ classOf[PythonServing])
+ }
+
+ def models(model: PipelineModel): Array[Byte] = {
+ val kryo = KryoInstantiator.newKryoInjection
+ kryo(Seq(model))
+ }
+
+}
+
+import PythonEngine.Query
+
+class PythonDataSource extends
+ PDataSource[EmptyTrainingData, EmptyEvaluationInfo, Query, EmptyActualResult] {
+ def readTraining(sc: SparkContext): EmptyTrainingData = new SerializableClass()
+}
+
+class PythonPreparator extends PPreparator[EmptyTrainingData, EmptyPreparedData] {
+ def prepare(sc: SparkContext, trainingData: EmptyTrainingData): EmptyPreparedData =
+ new SerializableClass()
+}
+
+object PythonServing {
+ private[engine] val columns = "PythonPredictColumns"
+
+ case class Params(columns: Seq[String]) extends org.apache.predictionio.controller.Params
+}
+
+class PythonServing(params: PythonServing.Params) extends LFirstServing[Query, Row] {
+ override def supplement(q: Query): Query = {
+ q + (PythonServing.columns -> params.columns)
+ }
+}
+
+class PythonAlgorithm extends
+ P2LAlgorithm[EmptyPreparedData, PipelineModel, Query, Row] {
+
+ def train(sc: SparkContext, data: EmptyPreparedData): PipelineModel = ???
+
+ def predict(model: PipelineModel, query: Query): Row = {
+ val selectCols = query(PythonServing.columns).asInstanceOf[Seq[String]]
+ val (colNames, data) = (query - PythonServing.columns).toList.unzip
+
+ val rows = Arrays.asList(Row.fromSeq(data))
+ val schema = StructType(colNames.zipWithIndex.map { case (col, i) =>
+ StructField(col, Literal(data(i)).dataType)
+ })
+
+ val spark = SparkSession.builder.getOrCreate()
+ val df = spark.createDataFrame(rows, schema)
+ model.transform(df)
+ .select(selectCols.head, selectCols.tail: _*)
+ .first()
+ }
+
+}
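
For orientation: PythonAlgorithm.predict turns the incoming query map into a one-row DataFrame, runs the stored PipelineModel over it, and returns only the columns named in PythonServing.Params. A minimal sketch of querying an engine deployed this way, assuming the default engine server port 8000 and hypothetical feature names:

import json
import urllib.request

query = {"sepal_length": 5.1, "sepal_width": 3.5}  # hypothetical feature columns
req = urllib.request.Request(
    "http://localhost:8000/queries.json",
    data=json.dumps(query).encode("utf-8"),
    headers={"Content-Type": "application/json"})
with urllib.request.urlopen(req) as resp:
    # the response carries the columns configured in PythonServing.Params(columns)
    print(json.load(resp))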
diff --git a/python/pypio/__init__.py b/python/pypio/__init__.py
index 04d8ac3..e0ca788 100644
--- a/python/pypio/__init__.py
+++ b/python/pypio/__init__.py
@@ -18,3 +18,12 @@
"""
PyPIO is the Python API for PredictionIO.
"""
+
+from __future__ import absolute_import
+
+from pypio.pypio import *
+
+
+__all__ = [
+ 'pypio'
+]
diff --git a/python/pypio/data/eventstore.py b/python/pypio/data/eventstore.py
index 4eb73df..58f09d1 100644
--- a/python/pypio/data/eventstore.py
+++ b/python/pypio/data/eventstore.py
@@ -17,8 +17,8 @@
from __future__ import absolute_import
-from pypio.utils import new_string_array
from pyspark.sql.dataframe import DataFrame
+from pyspark.sql import utils
__all__ = ["PEventStore"]
@@ -43,7 +43,7 @@ class PEventStore(object):
pes = self._sc._jvm.org.apache.predictionio.data.store.python.PPythonEventStore
jdf = pes.aggregateProperties(app_name, entity_type, channel_name,
start_time, until_time,
- new_string_array(required, self._sc._gateway),
+ utils.toJArray(self._sc._gateway, self._sc._gateway.jvm.String, required),
self._jss)
return DataFrame(jdf, self.sql_ctx)
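
The change above drops pypio's own new_string_array helper in favor of PySpark's pyspark.sql.utils.toJArray, which allocates a Java array through the Py4J gateway. A quick sketch of the same conversion in isolation, assuming an active SparkContext named sc:

from pyspark.sql import utils

required = ["attr_a", "attr_b"]  # hypothetical property names
# builds a java.lang.String[] on the JVM side and copies the Python list into it
jarr = utils.toJArray(sc._gateway, sc._gateway.jvm.String, required)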
diff --git a/python/pypio/pypio.py b/python/pypio/pypio.py
new file mode 100644
index 0000000..3a32cd8
--- /dev/null
+++ b/python/pypio/pypio.py
@@ -0,0 +1,111 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from __future__ import absolute_import
+
+import atexit
+import json
+import os
+import sys
+
+from pypio.data import PEventStore
+from pypio.utils import dict_to_scalamap, list_to_dict
+from pypio.workflow import CleanupFunctions
+from pyspark.sql import SparkSession
+
+
+def init():
+ global spark
+ spark = SparkSession.builder.getOrCreate()
+ global sc
+ sc = spark.sparkContext
+ global sqlContext
+ sqlContext = spark._wrapped
+ global p_event_store
+ p_event_store = PEventStore(spark._jsparkSession, sqlContext)
+
+ cleanup_functions = CleanupFunctions(sqlContext)
+ atexit.register(lambda: cleanup_functions.run())
+ atexit.register(lambda: sc.stop())
+ print("Initialized pypio")
+
+
+def find_events(app_name):
+ """
+ Returns the events of the specified app as a DataFrame.
+
+ :param app_name: app name
+ :return: :py:class:`pyspark.sql.DataFrame`
+ """
+ return p_event_store.find(app_name)
+
+
+def save_model(model, predict_columns):
+ """
+ Save a PipelineModel object to storage.
+
+ :param model: :py:class:`pyspark.ml.pipeline.PipelineModel`
+ :param predict_columns: prediction columns
+ :return: identifier for the trained model to use for predict
+ """
+ if not predict_columns:
+ raise ValueError("predict_columns must contain at least one value")
+ if os.environ.get('PYSPARK_PYTHON') is None:
+ # spark-submit
+ d = list_to_dict(sys.argv[1:])
+ pio_env = list_to_dict([v for e in d['--env'].split(',') for v in e.split('=')])
+ else:
+ # pyspark
+ pio_env = {k: v for k, v in os.environ.items() if k.startswith('PIO_')}
+
+ meta_storage = sc._jvm.org.apache.predictionio.data.storage.Storage.getMetaDataEngineInstances()
+
+ meta = sc._jvm.org.apache.predictionio.data.storage.EngineInstance.apply(
+ "",
+ "INIT", # status
+ sc._jvm.org.joda.time.DateTime.now(), # startTime
+ sc._jvm.org.joda.time.DateTime.now(), # endTime
+ "org.apache.predictionio.e2.engine.PythonEngine", # engineId
+ "1", # engineVersion
+ "default", # engineVariant
+ "org.apache.predictionio.e2.engine.PythonEngine", # engineFactory
+ "", # batch
+ dict_to_scalamap(sc._jvm, pio_env), # env
+ sc._jvm.scala.Predef.Map().empty(), # sparkConf
+ "{\"\":{}}", # dataSourceParams
+ "{\"\":{}}", # preparatorParams
+ "[{\"default\":{}}]", # algorithmsParams
+ json.dumps({"":{"columns":[v for v in predict_columns]}}) # servingParams
+ )
+ id = meta_storage.insert(meta)
+
+ engine = sc._jvm.org.apache.predictionio.e2.engine.PythonEngine
+ data = sc._jvm.org.apache.predictionio.data.storage.Model(id, engine.models(model._to_java()))
+ model_storage = sc._jvm.org.apache.predictionio.data.storage.Storage.getModelDataModels()
+ model_storage.insert(data)
+
+ meta_storage.update(
+ sc._jvm.org.apache.predictionio.data.storage.EngineInstance.apply(
+ id, "COMPLETED", meta.startTime(),
sc._jvm.org.joda.time.DateTime.now(),
+ meta.engineId(), meta.engineVersion(), meta.engineVariant(),
+ meta.engineFactory(), meta.batch(), meta.env(), meta.sparkConf(),
+ meta.dataSourceParams(), meta.preparatorParams(), meta.algorithmsParams(), meta.servingParams()
+ )
+ )
+
+ return id
+
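
Taken together, pypio.py replaces the removed PYTHONSTARTUP shell script with an explicit init() and adds save_model(), which registers an EngineInstance backed by PythonEngine and stores the Kryo-serialized PipelineModel. A minimal end-to-end sketch inside pio-shell (PySpark), with a hypothetical app name, feature columns, and pipeline:

import pypio
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.feature import VectorAssembler

pypio.init()                     # creates spark/sc and the PEventStore
df = pypio.find_events("MyApp")  # hypothetical app name
assembler = VectorAssembler(inputCols=["f1", "f2"], outputCol="features")
lr = LogisticRegression(labelCol="label")
model = Pipeline(stages=[assembler, lr]).fit(df)
instance_id = pypio.save_model(model, ["prediction"])  # engine instance id for deploy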
diff --git a/python/pypio/shell.py b/python/pypio/shell.py
deleted file mode 100644
index b0295d3..0000000
--- a/python/pypio/shell.py
+++ /dev/null
@@ -1,23 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-from pypio.data import PEventStore
-from pypio.workflow import CleanupFunctions
-
-p_event_store = PEventStore(spark._jsparkSession, sqlContext)
-cleanup_functions = CleanupFunctions(sqlContext)
-
diff --git a/python/pypio/utils.py b/python/pypio/utils.py
index 76900c3..4155efb 100644
--- a/python/pypio/utils.py
+++ b/python/pypio/utils.py
@@ -16,12 +16,38 @@
#
-def new_string_array(list_data, gateway):
- if list_data is None:
+def dict_to_scalamap(jvm, d):
+ """
+ Convert a Python dictionary to a Scala immutable map
+
+ :param jvm: sc._jvm
+ :param d: Python dictionary
+ """
+ if d is None:
+ return None
+ sm = jvm.scala.Predef.Map().empty()
+ for k, v in d.items():
+ sm = sm.updated(k, v)
+ return sm
+
+def list_to_dict(l):
+ """
+ Convert a flat Python list of alternating keys and values to a dictionary
+
+ :param l: Python list
+
+ >>> list = ["key1", 1, "key2", 2, "key3", 3]
+ >>> list_to_dict(list) == {'key1': 1, 'key2': 2, 'key3': 3}
+ True
+ """
+ if l is None:
return None
- string_class = gateway.jvm.String
- args = gateway.new_array(string_class, len(list_data))
- for i in range(len(list_data)):
- args[i] = list_data[i]
- return args
+ return dict(zip(l[0::2], l[1::2]))
+
+if __name__ == "__main__":
+ import doctest
+ import sys
+ (failure_count, test_count) = doctest.testmod()
+ if failure_count:
+ sys.exit(-1)
\ No newline at end of file
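
The two helpers above are small Py4J conveniences: list_to_dict pairs up a flat argv-style list, and dict_to_scalamap folds a Python dict into a scala.collection.immutable.Map via updated(). A quick sketch, assuming an active SparkContext named sc and hypothetical values:

from pypio.utils import dict_to_scalamap, list_to_dict

env = list_to_dict(["PIO_HOME", "/opt/pio", "PIO_FS_BASEDIR", "/tmp/pio"])
# {'PIO_HOME': '/opt/pio', 'PIO_FS_BASEDIR': '/tmp/pio'}
scala_env = dict_to_scalamap(sc._jvm, env)  # immutable Map on the JVM side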
diff --git a/tools/src/main/scala/org/apache/predictionio/tools/RunWorkflow.scala b/tools/src/main/scala/org/apache/predictionio/tools/RunWorkflow.scala
index 236d3ba..50b9337 100644
--- a/tools/src/main/scala/org/apache/predictionio/tools/RunWorkflow.scala
+++ b/tools/src/main/scala/org/apache/predictionio/tools/RunWorkflow.scala
@@ -51,19 +51,15 @@ object RunWorkflow extends Logging {
verbose: Boolean = false): Expected[(Process, () => Unit)] = {
val jarFiles = jarFilesForScala(engineDirPath).map(_.toURI)
- val variantJson = wa.variantJson.getOrElse(new File(engineDirPath, "engine.json"))
- val ei = Console.getEngineInfo(
- variantJson,
- engineDirPath)
- val args = Seq(
- "--engine-id",
- ei.engineId,
- "--engine-version",
- ei.engineVersion,
- "--engine-variant",
- variantJson.toURI.toString,
- "--verbosity",
- wa.verbosity.toString) ++
+ val args =
+ (if (wa.mainPyFile.isEmpty) {
+ val variantJson = wa.variantJson.getOrElse(new File(engineDirPath, "engine.json"))
+ val ei = Console.getEngineInfo(variantJson, engineDirPath)
+ Seq(
+ "--engine-id", ei.engineId,
+ "--engine-version", ei.engineVersion,
+ "--engine-variant", variantJson.toURI.toString)
+ } else Nil) ++
wa.engineFactory.map(
x => Seq("--engine-factory", x)).getOrElse(Nil) ++
wa.engineParamsKey.map(
@@ -72,19 +68,15 @@ object RunWorkflow extends Logging {
(if (verbose) Seq("--verbose") else Nil) ++
(if (wa.skipSanityCheck) Seq("--skip-sanity-check") else Nil) ++
(if (wa.stopAfterRead) Seq("--stop-after-read") else Nil) ++
- (if (wa.stopAfterPrepare) {
- Seq("--stop-after-prepare")
- } else {
- Nil
- }) ++
+ (if (wa.stopAfterPrepare) Seq("--stop-after-prepare") else Nil) ++
wa.evaluation.map(x => Seq("--evaluation-class", x)).
getOrElse(Nil) ++
// If engineParamsGenerator is specified, it overrides the evaluation.
wa.engineParamsGenerator.orElse(wa.evaluation)
.map(x => Seq("--engine-params-generator-class", x))
.getOrElse(Nil) ++
- (if (wa.batch != "") Seq("--batch", wa.batch) else Nil) ++
- Seq("--json-extractor", wa.jsonExtractor.toString)
+ Seq("--json-extractor", wa.jsonExtractor.toString,
+ "--verbosity", wa.verbosity.toString)
val resourceName = wa.mainPyFile match {
case Some(x) => x
diff --git a/tools/src/main/scala/org/apache/predictionio/tools/commands/Engine.scala b/tools/src/main/scala/org/apache/predictionio/tools/commands/Engine.scala
index 8380695..20d0ed9 100644
--- a/tools/src/main/scala/org/apache/predictionio/tools/commands/Engine.scala
+++ b/tools/src/main/scala/org/apache/predictionio/tools/commands/Engine.scala
@@ -218,22 +218,23 @@ object Engine extends EitherLogging {
if (verifyResult.isLeft) {
return Left(verifyResult.left.get)
}
- val ei = Console.getEngineInfo(
- serverArgs.variantJson.getOrElse(new File(engineDirPath, "engine.json")),
- engineDirPath)
val engineInstances = storage.Storage.getMetaDataEngineInstances
- val engineInstance = engineInstanceId map { eid =>
- engineInstances.get(eid)
+ engineInstanceId map { eid =>
+ engineInstances.get(eid).map { r =>
+ RunServer.runServer(
+ r.id, serverArgs, sparkArgs, pioHome, engineDirPath, verbose)
+ } getOrElse {
+ logAndFail(s"Invalid engine instance ID ${eid}. Aborting.")
+ }
} getOrElse {
+ val ei = Console.getEngineInfo(
+ serverArgs.variantJson.getOrElse(new File(engineDirPath, "engine.json")),
+ engineDirPath)
+
engineInstances.getLatestCompleted(
- ei.engineId, ei.engineVersion, ei.variantId)
- }
- engineInstance map { r =>
- RunServer.runServer(
- r.id, serverArgs, sparkArgs, pioHome, engineDirPath, verbose)
- } getOrElse {
- engineInstanceId map { eid =>
- logAndFail(s"Invalid engine instance ID ${eid}. Aborting.")
+ ei.engineId, ei.engineVersion, ei.variantId).map { r =>
+ RunServer.runServer(
+ r.id, serverArgs, sparkArgs, pioHome, engineDirPath, verbose)
} getOrElse {
logAndFail(s"No valid engine instance found for engine ${ei.engineId}
" +
s"${ei.engineVersion}.\nTry running 'train' before 'deploy'.
Aborting.")