Repository: spark
Updated Branches:
  refs/heads/branch-2.0 e38ff70e6 -> d07bce49f


[SPARK-15721][ML] Make DefaultParamsReadable, DefaultParamsWritable public

## What changes were proposed in this pull request?

Made DefaultParamsReadable and DefaultParamsWritable public.  Also added relevant
documentation and annotations.  Added UnaryTransformerExample to demonstrate use of
UnaryTransformer together with DefaultParamsReadable and DefaultParamsWritable.
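
For context, the pattern these now-public traits enable for a user-defined Transformer looks roughly like the sketch below. This is illustrative only: `MyShifter`, its UID prefix, and the save path are made-up names; the full, runnable version added by this commit is `UnaryTransformerExample` in the diff below.

```scala
import org.apache.spark.ml.UnaryTransformer
import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
import org.apache.spark.sql.types.{DataType, DataTypes}

// Illustrative Transformer that adds 1.0 to a Double input column.
class MyShifter(override val uid: String)
  extends UnaryTransformer[Double, Double, MyShifter] with DefaultParamsWritable {

  def this() = this(Identifiable.randomUID("myShifter"))

  override protected def createTransformFunc: Double => Double = _ + 1.0

  override protected def outputDataType: DataType = DataTypes.DoubleType
}

// The companion object gains load() by extending DefaultParamsReadable.
object MyShifter extends DefaultParamsReadable[MyShifter]

// Round-trip persistence, now possible outside of org.apache.spark.ml:
//   new MyShifter().write.overwrite().save("/tmp/myShifter")  // path is illustrative
//   val restored = MyShifter.load("/tmp/myShifter")
```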

## How was this patch tested?

Wrote an example making use of the now-public APIs.  Compiled and ran it locally.
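(The example can be run locally with `bin/run-example ml.UnaryTransformerExample`, as noted in its scaladoc below.)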

Author: Joseph K. Bradley <jos...@databricks.com>

Closes #13461 from jkbradley/defaultparamswritable.

(cherry picked from commit 4c74ee8d8e1c3139d3d322ae68977f2ab53295df)
Signed-off-by: Joseph K. Bradley <jos...@databricks.com>


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d07bce49
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d07bce49
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d07bce49

Branch: refs/heads/branch-2.0
Commit: d07bce49fc77aff25330920cc55b7079a3a2995c
Parents: e38ff70
Author: Joseph K. Bradley <jos...@databricks.com>
Authored: Mon Jun 6 09:49:45 2016 -0700
Committer: Joseph K. Bradley <jos...@databricks.com>
Committed: Mon Jun 6 09:49:56 2016 -0700

----------------------------------------------------------------------
 .../examples/ml/UnaryTransformerExample.scala   | 122 +++++++++++++++++++
 .../org/apache/spark/ml/util/ReadWrite.scala    |  44 ++++++-
 2 files changed, 163 insertions(+), 3 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d07bce49/examples/src/main/scala/org/apache/spark/examples/ml/UnaryTransformerExample.scala
----------------------------------------------------------------------
diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/UnaryTransformerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/UnaryTransformerExample.scala
new file mode 100644
index 0000000..13c72f8
--- /dev/null
+++ b/examples/src/main/scala/org/apache/spark/examples/ml/UnaryTransformerExample.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// scalastyle:off println
+package org.apache.spark.examples.ml
+
+// $example on$
+import org.apache.spark.ml.UnaryTransformer
+import org.apache.spark.ml.param.DoubleParam
+import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable}
+import org.apache.spark.sql.functions.col
+// $example off$
+import org.apache.spark.sql.SparkSession
+// $example on$
+import org.apache.spark.sql.types.{DataType, DataTypes}
+import org.apache.spark.util.Utils
+// $example off$
+
+/**
+ * An example demonstrating creating a custom [[org.apache.spark.ml.Transformer]] using
+ * the [[UnaryTransformer]] abstraction.
+ *
+ * Run with
+ * {{{
+ * bin/run-example ml.UnaryTransformerExample
+ * }}}
+ */
+object UnaryTransformerExample {
+
+  // $example on$
+  /**
+   * Simple Transformer which adds a constant value to input Doubles.
+   *
+   * [[UnaryTransformer]] can be used to create a stage usable within Pipelines.
+   * It defines parameters for specifying input and output columns:
+   * [[UnaryTransformer.inputCol]] and [[UnaryTransformer.outputCol]].
+   * It can optionally handle schema validation.
+   *
+   * [[DefaultParamsWritable]] provides a default implementation for persisting instances
+   * of this Transformer.
+   */
+  class MyTransformer(override val uid: String)
+    extends UnaryTransformer[Double, Double, MyTransformer] with DefaultParamsWritable {
+
+    final val shift: DoubleParam = new DoubleParam(this, "shift", "Value added to input")
+
+    def getShift: Double = $(shift)
+
+    def setShift(value: Double): this.type = set(shift, value)
+
+    def this() = this(Identifiable.randomUID("myT"))
+
+    override protected def createTransformFunc: Double => Double = (input: Double) => {
+      input + $(shift)
+    }
+
+    override protected def outputDataType: DataType = DataTypes.DoubleType
+
+    override protected def validateInputType(inputType: DataType): Unit = {
+      require(inputType == DataTypes.DoubleType, s"Bad input type: $inputType. Requires Double.")
+    }
+  }
+
+  /**
+   * Companion object for our simple Transformer.
+   *
+   * [[DefaultParamsReadable]] provides a default implementation for loading instances
+   * of this Transformer which were persisted using [[DefaultParamsWritable]].
+   */
+  object MyTransformer extends DefaultParamsReadable[MyTransformer]
+  // $example off$
+
+  def main(args: Array[String]) {
+    val spark = SparkSession
+      .builder()
+      .appName("UnaryTransformerExample")
+      .getOrCreate()
+
+    // $example on$
+    val myTransformer = new MyTransformer()
+      .setShift(0.5)
+      .setInputCol("input")
+      .setOutputCol("output")
+
+    // Create data, transform, and display it.
+    val data = spark.range(0, 5).toDF("input")
+      .select(col("input").cast("double").as("input"))
+    val result = myTransformer.transform(data)
+    result.show()
+
+    // Save and load the Transformer.
+    val tmpDir = Utils.createTempDir()
+    val dirName = tmpDir.getCanonicalPath
+    myTransformer.write.overwrite().save(dirName)
+    val sameTransformer = MyTransformer.load(dirName)
+
+    // Transform the data to show the results are identical.
+    val sameResult = sameTransformer.transform(data)
+    sameResult.show()
+
+    Utils.deleteRecursively(tmpDir)
+    // $example off$
+
+    spark.stop()
+  }
+}
+// scalastyle:on println
+

http://git-wip-us.apache.org/repos/asf/spark/blob/d07bce49/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
index 8ed40c3..90b8d7d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala
@@ -68,6 +68,8 @@ private[util] sealed trait BaseReadWrite {
 }
 
 /**
+ * :: Experimental ::
+ *
  * Abstract class for utility classes that can save ML instances.
  */
 @Experimental
@@ -120,8 +122,11 @@ abstract class MLWriter extends BaseReadWrite with Logging {
 }
 
 /**
+ * :: Experimental ::
+ *
  * Trait for classes that provide [[MLWriter]].
  */
+@Experimental
 @Since("1.6.0")
 trait MLWritable {
 
@@ -139,12 +144,27 @@ trait MLWritable {
   def save(path: String): Unit = write.save(path)
 }
 
-private[ml] trait DefaultParamsWritable extends MLWritable { self: Params =>
+/**
+ * :: Experimental ::
+ *
+ * Helper trait for making simple [[Params]] types writable.  If a [[Params]] class stores
+ * all data as [[org.apache.spark.ml.param.Param]] values, then extending this trait will provide
+ * a default implementation of writing saved instances of the class.
+ * This only handles simple [[org.apache.spark.ml.param.Param]] types; e.g., it will not handle
+ * [[org.apache.spark.sql.Dataset]].
+ *
+ * @see  [[DefaultParamsReadable]], the counterpart to this trait
+ */
+@Experimental
+@Since("2.0.0")
+trait DefaultParamsWritable extends MLWritable { self: Params =>
 
   override def write: MLWriter = new DefaultParamsWriter(this)
 }
 
 /**
+ * :: Experimental ::
+ *
  * Abstract class for utility classes that can load ML instances.
  *
  * @tparam T ML instance type
@@ -164,6 +184,8 @@ abstract class MLReader[T] extends BaseReadWrite {
 }
 
 /**
+ * :: Experimental ::
+ *
  * Trait for objects that provide [[MLReader]].
  *
  * @tparam T ML instance type
@@ -187,9 +209,25 @@ trait MLReadable[T] {
   def load(path: String): T = read.load(path)
 }
 
-private[ml] trait DefaultParamsReadable[T] extends MLReadable[T] {
 
-  override def read: MLReader[T] = new DefaultParamsReader
+/**
+ * :: Experimental ::
+ *
+ * Helper trait for making simple [[Params]] types readable.  If a [[Params]] class stores
+ * all data as [[org.apache.spark.ml.param.Param]] values, then extending this trait will provide
+ * a default implementation of reading saved instances of the class.
+ * This only handles simple [[org.apache.spark.ml.param.Param]] types; e.g., it will not handle
+ * [[org.apache.spark.sql.Dataset]].
+ *
+ * @tparam T ML instance type
+ *
+ * @see  [[DefaultParamsWritable]], the counterpart to this trait
+ */
+@Experimental
+@Since("2.0.0")
+trait DefaultParamsReadable[T] extends MLReadable[T] {
+
+  override def read: MLReader[T] = new DefaultParamsReader[T]
 }
 
 /**

