This is an automated email from the ASF dual-hosted git repository.

liuyizhi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new 5153370  [MXNET-471] Add Builder class for Scala Module and DataBatch to simplify construction (#11045)
5153370 is described below

commit 5153370e3a3d922bbccdf3d28b5c6f31995722fe
Author: Yizhi Liu <liuyi...@apache.org>
AuthorDate: Sun May 27 14:07:31 2018 -0700

    [MXNET-471] Add Builder class for Scala Module and DataBatch to simplify construction (#11045)
    
    * Add Builder class for Module and DataBatch to simplify construction. Add annotations to enable varargs calls from Java
    
    * change provideData & provideLabel to more descriptive names; add test cases
    
    * lint code
    
    * add comments about type safety
    
    * fix test for DataBatch
    
    * remove varargs
    
    * check data != null in DataBatch.Builder
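
    For illustration, a minimal sketch of the Module construction this commit
    enables (the Symbol `sym` is a hypothetical placeholder; the builder
    methods are exactly those added in the diff below):

        import org.apache.mxnet._
        import org.apache.mxnet.module.Module

        val sym: Symbol = Symbol.Variable("data")  // placeholder model definition
        val mod = new Module.Builder(sym)
          .setDataNames("data")
          .setLabelNames("softmax_label")
          .setContext(Context.cpu(0))
          .build()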
---
 .../core/src/main/scala/org/apache/mxnet/IO.scala  | 105 ++++++++++++++++++++-
 .../src/main/scala/org/apache/mxnet/Shape.scala    |   4 +
 .../src/main/scala/org/apache/mxnet/Symbol.scala   |   1 -
 .../scala/org/apache/mxnet/module/BaseModule.scala |  30 ++++++
 .../scala/org/apache/mxnet/module/Module.scala     |  73 +++++++++++++-
 .../test/scala/org/apache/mxnet/ModuleSuite.scala  |  28 +++---
 .../main/scala/org/apache/mxnet/NDArrayMacro.scala |   3 +
 7 files changed, 230 insertions(+), 14 deletions(-)

diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
index 7a9c1a7..d9c767c 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
@@ -19,9 +19,10 @@ package org.apache.mxnet
 
 import org.apache.mxnet.Base._
 import org.apache.mxnet.DType.DType
-import org.apache.mxnet.io.{MXDataPack, MXDataIter}
+import org.apache.mxnet.io.{MXDataIter, MXDataPack}
 import org.slf4j.LoggerFactory
 
+import scala.annotation.varargs
 import scala.collection.immutable.ListMap
 import scala.collection.mutable.ListBuffer
 
@@ -160,6 +161,108 @@ class DataBatch(val data: IndexedSeq[NDArray],
   def provideLabel: ListMap[String, Shape] = providedLabel
 }
 
+object DataBatch {
+  /**
+   * Builder class for DataBatch.
+   */
+  class Builder() {
+    private var data: IndexedSeq[NDArray] = null
+    private var label: IndexedSeq[NDArray] = null
+    private var index: IndexedSeq[Long] = null
+    private var pad: Int = 0
+    private var bucketKey: AnyRef = null
+    private var dataShapes: ListMap[String, Shape] = null
+    private var labelShapes: ListMap[String, Shape] = null
+
+    /**
+     * Set the input data.
+     * @param data a list of data.
+     * @return this.
+     */
+    @varargs def setData(data: NDArray*): Builder = {
+      this.data = data.toIndexedSeq
+      this
+    }
+
+    /**
+     * Set the labels, in the same order as the data.
+     * @param label a list of labels.
+     * @return this.
+     */
+    @varargs def setLabel(label: NDArray*): Builder = {
+      this.label = label.toIndexedSeq
+      this
+    }
+
+    /**
+     * Set the example indices in this batch.
+     * @param index indices in the same order as the data.
+     * @return this.
+     */
+    @varargs def setIndex(index: Long*): Builder = {
+      this.index = index.toIndexedSeq
+      this
+    }
+
+    /**
+     * Set the pad.
+     * @param pad The number of examples padded at the end of a batch. It is used when the
+     *            total number of examples read is not divisible by the `batch_size`.
+     *            These extra padded examples are ignored in prediction.
+     * @return this
+     */
+    def setPad(pad: Int): Builder = {
+      this.pad = pad
+      this
+    }
+
+    /**
+     * Set the bucket key, used by the bucketing module.
+     * @param bucketKey the bucket key related to this batch.
+     * @return this.
+     */
+    def setBucketKey(bucketKey: AnyRef): Builder = {
+      this.bucketKey = bucketKey
+      this
+    }
+
+    /**
+     * Provide the shape of a data input.
+     * @param name data name.
+     * @param shape data shape.
+     * @return this.
+     */
+    def provideDataShape(name: String, shape: Shape): Builder = {
+      if (dataShapes == null) {
+        dataShapes = ListMap((name, shape))
+      } else {
+        dataShapes = dataShapes.updated(name, shape)
+      }
+      this
+    }
+
+    /**
+     * Provide the shape of a label.
+     * @param name label name.
+     * @param shape label shape.
+     * @return this.
+     */
+    def provideLabelShape(name: String, shape: Shape): Builder = {
+      if (labelShapes == null) {
+        labelShapes = ListMap((name, shape))
+      } else {
+        labelShapes = labelShapes.updated(name, shape)
+      }
+      this
+    }
+
+    def build(): DataBatch = {
+      require(data != null, "data is required.")
+      new DataBatch(data, label, index, pad, bucketKey, dataShapes, labelShapes)
+    }
+  }
+}
+
 /**
  * DataIter object in mxnet.
  */
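
As a usage sketch of the new DataBatch.Builder (the array values and the
"data" shape name are illustrative; build() fails unless setData has been
called, per the require check above):

    val batch = new DataBatch.Builder()
      .setData(NDArray.ones(Shape(2, 3)))
      .setLabel(NDArray.zeros(Shape(2)))
      .setIndex(0L, 1L)
      .setPad(0)
      .provideDataShape("data", Shape(2, 3))
      .build()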
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala
index e632ade..6891762 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala
@@ -17,6 +17,8 @@
 
 package org.apache.mxnet
 
+import scala.annotation.varargs
+
 /**
  * Shape of [[NDArray]] or other data
  */
@@ -28,6 +30,7 @@ class Shape(dims: Traversable[Int]) extends Serializable {
   }
 
   def apply(dim: Int): Int = shape(dim)
+  def get(dim: Int): Int = apply(dim)
   def size: Int = shape.size
   def length: Int = shape.length
   def drop(dim: Int): Shape = new Shape(shape.drop(dim))
@@ -56,4 +59,5 @@ class Shape(dims: Traversable[Int]) extends Serializable {
 object Shape {
   def apply(dims: Int *): Shape = new Shape(dims: _*)
   def apply(dims: Traversable[Int]): Shape = new Shape(dims)
+  @varargs def create(dims: Int*): Shape = new Shape(dims)
 }
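
The added `get` and `create` members mirror `apply`, presumably so Java code
can avoid Scala's apply/varargs sugar; a quick sketch:

    val s = Shape.create(2, 3, 4)   // equivalent to Shape(2, 3, 4)
    assert(s.get(1) == s(1))        // get(dim) is an alias for apply(dim)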
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
index 60efd2b..a17fe57 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
@@ -101,7 +101,6 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotD
     var index: Int = -1
     for ((output, i) <- listOutputs().view.zipWithIndex) {
       if (output == name) {
-        require(index == -1, s"There are multiple outputs with name $name")
         index = i
       }
     }
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
index 108cff4..60b80f2 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala
@@ -23,6 +23,8 @@ import org.apache.mxnet.optimizer.SGD
 import org.apache.mxnet._
 import org.slf4j.LoggerFactory
 import org.slf4j.Logger
+
+import scala.annotation.varargs
 import scala.collection.mutable.ArrayBuffer
 
 object BaseModule {
@@ -469,6 +471,15 @@ abstract class BaseModule {
   def forward(dataBatch: DataBatch, isTrain: Option[Boolean] = None): Unit
 
   /**
+   * Forward computation.
+   * @param dataBatch a batch of data.
+   * @param isTrain Whether it is for training or not.
+   */
+  def forward(dataBatch: DataBatch, isTrain: Boolean): Unit = {
+    forward(dataBatch, Option(isTrain))
+  }
+
+  /**
    * Backward computation.
    * @param outGrads Gradient on the outputs to be propagated back.
    *                 This parameter is only needed when bind is called
@@ -549,6 +560,25 @@ abstract class BaseModule {
           forceRebind: Boolean = false, sharedModule: Option[BaseModule] = None,
            gradReq: String = "write"): Unit
 
+
+  /**
+   * Bind the symbols to construct executors.
+   * This is necessary before one can perform computation with the module.
+   * @param forTraining Default is `True`. Whether the executors should be bound for training.
+   * @param inputsNeedGrad Default is `False`.
+   *                       Whether the gradients to the input data need to be computed.
+   *                       Typically this is not needed.
+   *                       But this might be needed when implementing composition of modules.
+   * @param forceRebind Default is `False`. This function does nothing
+   *                    if the executors are already bound. But with this set to `True`,
+   *                    the executors will be forced to rebind.
+   * @param dataShape Typically `DataIter.provideData`.
+   */
+  @varargs def bind(forTraining: Boolean, inputsNeedGrad: Boolean,
+                    forceRebind: Boolean, dataShape: DataDesc*): Unit = {
+    bind(dataShape.toVector, None, forTraining, inputsNeedGrad, forceRebind, None)
+  }
+
   // Install and initialize optimizers.
   def initOptimizer(kvstore: String = "local", optimizer: Optimizer = new SGD(),
                     resetOptimizer: Boolean = true, forceInit: Boolean = false): Unit
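
A sketch of the two Java-friendly overloads added above (the module `mod`,
the batch `batch`, and the data shape are hypothetical placeholders):

    val mod: BaseModule = ???    // any concrete module implementation
    val batch: DataBatch = ???   // e.g. built with DataBatch.Builder
    // forTraining = true, inputsNeedGrad = false, forceRebind = false
    mod.bind(true, false, false, DataDesc("data", Shape(32, 3, 224, 224), layout = "NCHW"))
    mod.forward(batch, true)     // instead of forward(batch, Option(true))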
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala
index ac3d645..d55a426 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala
@@ -17,13 +17,16 @@
 
 package org.apache.mxnet.module
 
-import java.io.{FileInputStream, BufferedInputStream, BufferedOutputStream, FileOutputStream}
+import java.io.{BufferedInputStream, BufferedOutputStream, FileInputStream, FileOutputStream}
+
 import org.apache.mxnet.DType.DType
 import org.apache.mxnet._
 import org.apache.mxnet.module.DataParallelExecutorGroup.Builder
 import org.apache.mxnet.optimizer.SGD
 import org.slf4j.LoggerFactory
 
+import scala.annotation.varargs
+
 /**
 * Module is a basic module that wraps a `Symbol`. It is functionally the same
  * as the `FeedForward` model, except under the module API.
@@ -642,4 +645,72 @@ object Module {
     }
     mod
   }
+
+  /**
+   * Builder class for Module.
+   * @param modelDef model definition in Symbol.
+   */
+  class Builder(private val modelDef: Symbol) {
+    private var dataNames: IndexedSeq[String] = IndexedSeq("data")
+    private var labelNames: IndexedSeq[String] = IndexedSeq("softmax_label")
+    private var contexts: Array[Context] = Array(Context.cpu())
+    private var workLoadList: IndexedSeq[Float] = _
+    private var fixedParamNames: Set[String] = _
+
+    /**
+     * Set the context for execution.
+     * @param ctx a list of contexts.
+     * @return this.
+     */
+    @varargs def setContext(ctx: Context*): Builder = {
+      contexts = ctx.toArray
+      this
+    }
+
+    /**
+     * Set the input data names.
+     * @param name a list of data names. Cannot be null.
+     * @return this.
+     */
+    @varargs def setDataNames(name: String*): Builder = {
+      dataNames = name.toVector
+      this
+    }
+
+    /**
+     * Set the label names.
+     * @param name a list of label names.
+     *             Set to null if no label is required.
+     * @return this.
+     */
+    @varargs def setLabelNames(name: String*): Builder = {
+      labelNames = if (name == null) IndexedSeq.empty[String] else name.toVector
+      this
+    }
+
+    /**
+     * Set the workloads.
+     * @param workloads a list of workloads
+     * @return this.
+     */
+    @varargs def setWorkLoadList(workloads: Float*): Builder = {
+      workLoadList = workloads.toVector
+      this
+    }
+
+    /**
+     * Specify the parameters that need to be fixed.
+     * @param name a list of parameter names.
+     * @return this.
+     */
+    @varargs def setFixedParamNames(name: String*): Builder = {
+      fixedParamNames = name.toSet
+      this
+    }
+
+    def build(): Module = {
+      new Module(modelDef, dataNames, labelNames, contexts,
+        Option(workLoadList), Option(fixedParamNames))
+    }
+  }
 }
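
Given the defaults above ("data", "softmax_label", a single CPU context), the
shortest form is new Module.Builder(sym).build(). A sketch of the optional
settings (the parameter name "fc1_weight" is hypothetical):

    val mod = new Module.Builder(sym)
      .setContext(Context.cpu(0), Context.cpu(1))
      .setWorkLoadList(0.5f, 0.5f)       // one weight per context
      .setFixedParamNames("fc1_weight")  // parameters excluded from updates
      .build()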
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
index a9cac13..22b9c3b 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
@@ -18,7 +18,6 @@
 package org.apache.mxnet
 
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
-import org.apache.mxnet.CheckUtils._
 import org.apache.mxnet.module._
 import org.apache.mxnet.optimizer._
 import org.apache.mxnet.io._
@@ -52,8 +51,11 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
     import SymbolConversions._
     c = a + 2 * b + 3 * c
 
-    val mod = new Module(c, IndexedSeq("b", "c", "a"), null,
-      contexts = Array(Context.cpu(0), Context.cpu(1)))
+    val mod = new Module.Builder(c)
+      .setDataNames("b", "c", "a")
+      .setLabelNames(null)
+      .setContext(Context.cpu(0), Context.cpu(1))
+      .build()
     mod.bind(dataShapes = IndexedSeq(
       DataDesc("b", Shape(5, 5), layout = "NT"),
       DataDesc("c", Shape(5, 5), layout = "NT"),
@@ -342,11 +344,13 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
     dShape1 = Shape(20, 3, 120, 120)
     dShape2 = Shape(20, 3, 32, 64)
     lShape = Shape(20)
-    dataBatch = new DataBatch(
-      data = IndexedSeq(
+    dataBatch = new DataBatch.Builder()
+      .setData(
         NDArray.random_uniform(Map("low" -> 0, "high" -> 9, "shape" -> 
dShape1.toString()))(),
-        NDArray.random_uniform(Map("low" -> 5, "high" -> 15, "shape" -> 
dShape2.toString()))()),
-      label = IndexedSeq(NDArray.ones(lShape)), index = null, pad = 0)
+        NDArray.random_uniform(Map("low" -> 5, "high" -> 15, "shape" -> 
dShape2.toString()))())
+      .setLabel(NDArray.ones(lShape))
+      .setPad(0)
+      .build()
     mod.forward(dataBatch)
     assert(mod.getOutputsMerged()(0).shape == Shape(lShape(0), numClass))
     mod.backward()
@@ -355,11 +359,13 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
     dShape1 = Shape(5, 3, 28, 40)
     dShape2 = Shape(5, 3, 24, 16)
     lShape = Shape(5)
-    dataBatch = new DataBatch(
-      data = IndexedSeq(
+    dataBatch = new DataBatch.Builder()
+      .setData(
         NDArray.random_uniform(Map("low" -> 0, "high" -> 9, "shape" -> 
dShape1.toString()))(),
-        NDArray.random_uniform(Map("low" -> 15, "high" -> 25, "shape" -> 
dShape2.toString()))()),
-      label = IndexedSeq(NDArray.ones(lShape)), index = null, pad = 0)
+        NDArray.random_uniform(Map("low" -> 15, "high" -> 25, "shape" -> 
dShape2.toString()))())
+      .setLabel(NDArray.ones(lShape))
+      .setPad(0)
+      .build()
     mod.forward(dataBatch)
     assert(mod.getOutputsMerged()(0).shape == Shape(lShape(0), numClass))
     mod.backward()
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
index 56cc325..c1c3a42 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala
@@ -66,6 +66,9 @@ private[mxnet] object NDArrayMacro {
         if (!NDArrayfunction.name.startsWith("_") || NDArrayfunction.name.startsWith("_contrib_")) {
           Seq(
             // scalastyle:off
+            // (yizhi) We are investigating a way to make these functions type-safe
+            // and waiting to see whether the new approach is stable enough.
+            // Thus these functions may be deprecated in the future.
             // e.g def transpose(kwargs: Map[String, Any] = null)(args: Any*)
             q"def $termName(kwargs: Map[String, Any] = null)(args: Any*) = 
{genericNDArrayFunctionInvoke($funcName, args, kwargs)}".asInstanceOf[DefDef],
             // e.g def transpose(args: Any*)
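
For context on the comment above: the generated functions take a kwargs map
first and positional NDArray arguments second, which is the pattern the
updated ModuleSuite uses, e.g.:

    // kwargs map in the first list, then an (empty) positional-argument list
    val x = NDArray.random_uniform(Map("low" -> 0, "high" -> 9, "shape" -> "(2,3)"))()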

-- 
To stop receiving notification emails like this one, please contact
liuyi...@apache.org.
