Repository: spark Updated Branches: refs/heads/branch-2.0 e38ff70e6 -> d07bce49f
[SPARK-15721][ML] Make DefaultParamsReadable, DefaultParamsWritable public ## What changes were proposed in this pull request? Made DefaultParamsReadable, DefaultParamsWritable public. Also added relevant documentation and annotations. Added UnaryTransformerExample to demonstrate use of UnaryTransformer and DefaultParamsReadable/DefaultParamsWritable. ## How was this patch tested? Wrote example making use of the now-public APIs. Compiled and ran it locally. Author: Joseph K. Bradley <jos...@databricks.com> Closes #13461 from jkbradley/defaultparamswritable. (cherry picked from commit 4c74ee8d8e1c3139d3d322ae68977f2ab53295df) Signed-off-by: Joseph K. Bradley <jos...@databricks.com> Project: http://git-wip-us.apache.org/repos/asf/spark/repo Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d07bce49 Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d07bce49 Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d07bce49 Branch: refs/heads/branch-2.0 Commit: d07bce49fc77aff25330920cc55b7079a3a2995c Parents: e38ff70 Author: Joseph K. Bradley <jos...@databricks.com> Authored: Mon Jun 6 09:49:45 2016 -0700 Committer: Joseph K.
Bradley <jos...@databricks.com> Committed: Mon Jun 6 09:49:56 2016 -0700 ---------------------------------------------------------------------- .../examples/ml/UnaryTransformerExample.scala | 122 +++++++++++++++++++ .../org/apache/spark/ml/util/ReadWrite.scala | 44 ++++++- 2 files changed, 163 insertions(+), 3 deletions(-) ---------------------------------------------------------------------- http://git-wip-us.apache.org/repos/asf/spark/blob/d07bce49/examples/src/main/scala/org/apache/spark/examples/ml/UnaryTransformerExample.scala ---------------------------------------------------------------------- diff --git a/examples/src/main/scala/org/apache/spark/examples/ml/UnaryTransformerExample.scala b/examples/src/main/scala/org/apache/spark/examples/ml/UnaryTransformerExample.scala new file mode 100644 index 0000000..13c72f8 --- /dev/null +++ b/examples/src/main/scala/org/apache/spark/examples/ml/UnaryTransformerExample.scala @@ -0,0 +1,122 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +// scalastyle:off println +package org.apache.spark.examples.ml + +// $example on$ +import org.apache.spark.ml.UnaryTransformer +import org.apache.spark.ml.param.DoubleParam +import org.apache.spark.ml.util.{DefaultParamsReadable, DefaultParamsWritable, Identifiable} +import org.apache.spark.sql.functions.col +// $example off$ +import org.apache.spark.sql.SparkSession +// $example on$ +import org.apache.spark.sql.types.{DataType, DataTypes} +import org.apache.spark.util.Utils +// $example off$ + +/** + * An example demonstrating creating a custom [[org.apache.spark.ml.Transformer]] using + * the [[UnaryTransformer]] abstraction. + * + * Run with + * {{{ + * bin/run-example ml.UnaryTransformerExample + * }}} + */ +object UnaryTransformerExample { + + // $example on$ + /** + * Simple Transformer which adds a constant value to input Doubles. + * + * [[UnaryTransformer]] can be used to create a stage usable within Pipelines. + * It defines parameters for specifying input and output columns: + * [[UnaryTransformer.inputCol]] and [[UnaryTransformer.outputCol]]. + * It can optionally handle schema validation. + * + * [[DefaultParamsWritable]] provides a default implementation for persisting instances + * of this Transformer. + */ + class MyTransformer(override val uid: String) + extends UnaryTransformer[Double, Double, MyTransformer] with DefaultParamsWritable { + + final val shift: DoubleParam = new DoubleParam(this, "shift", "Value added to input") + + def getShift: Double = $(shift) + + def setShift(value: Double): this.type = set(shift, value) + + def this() = this(Identifiable.randomUID("myT")) + + override protected def createTransformFunc: Double => Double = (input: Double) => { + input + $(shift) + } + + override protected def outputDataType: DataType = DataTypes.DoubleType + + override protected def validateInputType(inputType: DataType): Unit = { + require(inputType == DataTypes.DoubleType, s"Bad input type: $inputType. 
Requires Double.") + } + } + + /** + * Companion object for our simple Transformer. + * + * [[DefaultParamsReadable]] provides a default implementation for loading instances + * of this Transformer which were persisted using [[DefaultParamsWritable]]. + */ + object MyTransformer extends DefaultParamsReadable[MyTransformer] + // $example off$ + + def main(args: Array[String]) { + val spark = SparkSession + .builder() + .appName("UnaryTransformerExample") + .getOrCreate() + + // $example on$ + val myTransformer = new MyTransformer() + .setShift(0.5) + .setInputCol("input") + .setOutputCol("output") + + // Create data, transform, and display it. + val data = spark.range(0, 5).toDF("input") + .select(col("input").cast("double").as("input")) + val result = myTransformer.transform(data) + result.show() + + // Save and load the Transformer. + val tmpDir = Utils.createTempDir() + val dirName = tmpDir.getCanonicalPath + myTransformer.write.overwrite().save(dirName) + val sameTransformer = MyTransformer.load(dirName) + + // Transform the data to show the results are identical. + val sameResult = sameTransformer.transform(data) + sameResult.show() + + Utils.deleteRecursively(tmpDir) + // $example off$ + + spark.stop() + } +} +// scalastyle:on println + http://git-wip-us.apache.org/repos/asf/spark/blob/d07bce49/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala ---------------------------------------------------------------------- diff --git a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala index 8ed40c3..90b8d7d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/util/ReadWrite.scala @@ -68,6 +68,8 @@ private[util] sealed trait BaseReadWrite { } /** + * :: Experimental :: + * * Abstract class for utility classes that can save ML instances. 
*/ @Experimental @@ -120,8 +122,11 @@ abstract class MLWriter extends BaseReadWrite with Logging { } /** + * :: Experimental :: + * * Trait for classes that provide [[MLWriter]]. */ +@Experimental @Since("1.6.0") trait MLWritable { @@ -139,12 +144,27 @@ trait MLWritable { def save(path: String): Unit = write.save(path) } -private[ml] trait DefaultParamsWritable extends MLWritable { self: Params => +/** + * :: Experimental :: + * + * Helper trait for making simple [[Params]] types writable. If a [[Params]] class stores + * all data as [[org.apache.spark.ml.param.Param]] values, then extending this trait will provide + * a default implementation of writing saved instances of the class. + * This only handles simple [[org.apache.spark.ml.param.Param]] types; e.g., it will not handle + * [[org.apache.spark.sql.Dataset]]. + * + * @see [[DefaultParamsReadable]], the counterpart to this trait + */ +@Experimental +@Since("2.0.0") +trait DefaultParamsWritable extends MLWritable { self: Params => override def write: MLWriter = new DefaultParamsWriter(this) } /** + * :: Experimental :: + * * Abstract class for utility classes that can load ML instances. * * @tparam T ML instance type @@ -164,6 +184,8 @@ abstract class MLReader[T] extends BaseReadWrite { } /** + * :: Experimental :: + * * Trait for objects that provide [[MLReader]]. * * @tparam T ML instance type @@ -187,9 +209,25 @@ trait MLReadable[T] { def load(path: String): T = read.load(path) } -private[ml] trait DefaultParamsReadable[T] extends MLReadable[T] { - override def read: MLReader[T] = new DefaultParamsReader +/** + * :: Experimental :: + * + * Helper trait for making simple [[Params]] types readable. If a [[Params]] class stores + * all data as [[org.apache.spark.ml.param.Param]] values, then extending this trait will provide + * a default implementation of reading saved instances of the class. 
+ * This only handles simple [[org.apache.spark.ml.param.Param]] types; e.g., it will not handle + * [[org.apache.spark.sql.Dataset]]. + * + * @tparam T ML instance type + * + * @see [[DefaultParamsWritable]], the counterpart to this trait + */ +@Experimental +@Since("2.0.0") +trait DefaultParamsReadable[T] extends MLReadable[T] { + + override def read: MLReader[T] = new DefaultParamsReader[T] } /** --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org