This is an automated email from the ASF dual-hosted git repository. liuyizhi pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push: new 5153370 [MXNET-471] Add Builder class for Scala Module and DataBatch to simplify construction (#11045) 5153370 is described below commit 5153370e3a3d922bbccdf3d28b5c6f31995722fe Author: Yizhi Liu <liuyi...@apache.org> AuthorDate: Sun May 27 14:07:31 2018 -0700 [MXNET-471] Add Builder class for Scala Module and DataBatch to simplify construction (#11045) * Add Builder class for Module and DataBatch to simplify construction. Add annotation to enable varargs in Java * change provideData & provideLabel to more proper names. add test cases * lint code * add comments for type-safe * fix test for DataBatch * remove varargs * check data != null in DataBatch.Builder --- .../core/src/main/scala/org/apache/mxnet/IO.scala | 105 ++++++++++++++++++++- .../src/main/scala/org/apache/mxnet/Shape.scala | 4 + .../src/main/scala/org/apache/mxnet/Symbol.scala | 1 - .../scala/org/apache/mxnet/module/BaseModule.scala | 30 ++++++ .../scala/org/apache/mxnet/module/Module.scala | 73 +++++++++++++- .../test/scala/org/apache/mxnet/ModuleSuite.scala | 28 +++--- .../main/scala/org/apache/mxnet/NDArrayMacro.scala | 3 + 7 files changed, 230 insertions(+), 14 deletions(-) diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala index 7a9c1a7..d9c767c 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala @@ -19,9 +19,10 @@ package org.apache.mxnet import org.apache.mxnet.Base._ import org.apache.mxnet.DType.DType -import org.apache.mxnet.io.{MXDataPack, MXDataIter} +import org.apache.mxnet.io.{MXDataIter, MXDataPack} import org.slf4j.LoggerFactory +import scala.annotation.varargs import scala.collection.immutable.ListMap import scala.collection.mutable.ListBuffer @@ -160,6 +161,108 @@ class DataBatch(val data: IndexedSeq[NDArray], def provideLabel: ListMap[String, Shape] = providedLabel } +object DataBatch { + /** + * Builder class for DataBatch. + */ + class Builder() { + private var data: IndexedSeq[NDArray] = null + private var label: IndexedSeq[NDArray] = null + private var index: IndexedSeq[Long] = null + private var pad: Int = 0 + private var bucketKey: AnyRef = null + private var datatShapes: ListMap[String, Shape] = null + private var labelShapes: ListMap[String, Shape] = null + + /** + * Set the input data. + * @param data a list of data. + * @return this. + */ + @varargs def setData(data: NDArray*): Builder = { + this.data = data.toIndexedSeq + this + } + + /** + * Set the labels in the same order of data. + * @param label a list of labels. + * @return this. + */ + @varargs def setLabel(label: NDArray*): Builder = { + this.label = label.toIndexedSeq + this + } + + /** + * Set the example indices in this batch. + * @param index indices in the same order of data. + * @return this. + */ + @varargs def setIndex(index: Long*): Builder = { + this.index = index.toIndexedSeq + this + } + + /** + * Set the pad. + * @param pad The number of examples padded at the end of a batch. It is used when the + * total number of examples read is not divisible by the `batch_size`. + * These extra padded examples are ignored in prediction. + * @return this + */ + def setPad(pad: Int): Builder = { + this.pad = pad + this + } + + /** + * Set the bucket key, used for bucketing module. + * @param bucketKey the bucket key related to this batch. + * @return this. + */ + def setBucketKey(bucketKey: AnyRef): Builder = { + this.bucketKey = bucketKey + this + } + + /** + * Provide the shape of a data. + * @param name data name. + * @param shape data shape. + * @return this. + */ + def provideDataShape(name: String, shape: Shape): Builder = { + if (datatShapes == null) { + datatShapes = ListMap((name, shape)) + } else { + datatShapes = datatShapes.updated(name, shape) + } + this + } + + /** + * Provide the shape of a label. + * @param name label name. + * @param shape label shape. + * @return this. + */ + def provideLabelShape(name: String, shape: Shape): Builder = { + if (labelShapes == null) { + labelShapes = ListMap((name, shape)) + } else { + labelShapes = labelShapes.updated(name, shape) + } + this + } + + def build(): DataBatch = { + require(data != null, "data is required.") + new DataBatch(data, label, index, pad, bucketKey, datatShapes, labelShapes) + } + } +} + /** * DataIter object in mxnet. */ diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala index e632ade..6891762 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Shape.scala @@ -17,6 +17,8 @@ package org.apache.mxnet +import scala.annotation.varargs + /** * Shape of [[NDArray]] or other data */ @@ -28,6 +30,7 @@ class Shape(dims: Traversable[Int]) extends Serializable { } def apply(dim: Int): Int = shape(dim) + def get(dim: Int): Int = apply(dim) def size: Int = shape.size def length: Int = shape.length def drop(dim: Int): Shape = new Shape(shape.drop(dim)) @@ -56,4 +59,5 @@ class Shape(dims: Traversable[Int]) extends Serializable { object Shape { def apply(dims: Int *): Shape = new Shape(dims: _*) def apply(dims: Traversable[Int]): Shape = new Shape(dims) + @varargs def create(dims: Int*): Shape = new Shape(dims) } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala index 60efd2b..a17fe57 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala @@ -101,7 +101,6 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotD var index: Int = -1 for ((output, i) <- listOutputs().view.zipWithIndex) { if (output == name) { - require(index == -1, s"There are multiple outputs with name $name") index = i } } diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala index 108cff4..60b80f2 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/BaseModule.scala @@ -23,6 +23,8 @@ import org.apache.mxnet.optimizer.SGD import org.apache.mxnet._ import org.slf4j.LoggerFactory import org.slf4j.Logger + +import scala.annotation.varargs import scala.collection.mutable.ArrayBuffer object BaseModule { @@ -469,6 +471,15 @@ abstract class BaseModule { def forward(dataBatch: DataBatch, isTrain: Option[Boolean] = None): Unit /** + * Forward computation. + * @param dataBatch a batch of data. + * @param isTrain Whether it is for training or not. + */ + def forward(dataBatch: DataBatch, isTrain: Boolean): Unit = { + forward(dataBatch, Option(isTrain)) + } + + /** * Backward computation. * @param outGrads Gradient on the outputs to be propagated back. * This parameter is only needed when bind is called @@ -549,6 +560,25 @@ abstract class BaseModule { forceRebind: Boolean = false, sharedModule: Option[BaseModule] = None, gradReq: String = "write"): Unit + + /** + * Bind the symbols to construct executors. + * This is necessary before one can perform computation with the module. + * @param forTraining Default is `True`. Whether the executors should be bind for training. + * @param inputsNeedGrad Default is `False`. + * Whether the gradients to the input data need to be computed. + * Typically this is not needed. + * But this might be needed when implementing composition of modules. + * @param forceRebind Default is `False`. This function does nothing + * if the executors are already binded. But with this `True`, + * the executors will be forced to rebind. + * @param dataShape Typically is `DataIter.provideData`. + */ + @varargs def bind(forTraining: Boolean, inputsNeedGrad: Boolean, + forceRebind: Boolean, dataShape: DataDesc*): Unit = { + bind(dataShape.toVector, None, forTraining, inputsNeedGrad, forceRebind, None) + } + // Install and initialize optimizers. def initOptimizer(kvstore: String = "local", optimizer: Optimizer = new SGD(), resetOptimizer: Boolean = true, forceInit: Boolean = false): Unit diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala b/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala index ac3d645..d55a426 100644 --- a/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala +++ b/scala-package/core/src/main/scala/org/apache/mxnet/module/Module.scala @@ -17,13 +17,16 @@ package org.apache.mxnet.module -import java.io.{FileInputStream, BufferedInputStream, BufferedOutputStream, FileOutputStream} +import java.io.{BufferedInputStream, BufferedOutputStream, FileInputStream, FileOutputStream} + import org.apache.mxnet.DType.DType import org.apache.mxnet._ import org.apache.mxnet.module.DataParallelExecutorGroup.Builder import org.apache.mxnet.optimizer.SGD import org.slf4j.LoggerFactory +import scala.annotation.varargs + /** * Module is a basic module that wrap a `Symbol`. It is functionally the same * as the `FeedForward` model, except under the module API. @@ -642,4 +645,72 @@ object Module { } mod } + + /** + * Builder class for Module. + * @param modelDef model definition in Symbol. + */ + class Builder(private val modelDef: Symbol) { + private var dataNames: IndexedSeq[String] = IndexedSeq("data") + private var labelNames: IndexedSeq[String] = IndexedSeq("softmax_label") + private var contexts: Array[Context] = Array(Context.cpu()) + private var workLoadList: IndexedSeq[Float] = _ + private var fixedParamNames: Set[String] = _ + + /** + * Set the context for execution. + * @param ctx a list of contexts. + * @return this. + */ + @varargs def setContext(ctx: Context*): Builder = { + contexts = ctx.toArray + this + } + + /** + * Set the input data names. + * @param name a list of data names. Cannot be null. + * @return this. + */ + @varargs def setDataNames(name: String*): Builder = { + dataNames = name.toVector + this + } + + /** + * Set the label names. + * @param name a list of label names. + * Set to null if no label is required. + * @return this. + */ + @varargs def setLabelNames(name: String*): Builder = { + labelNames = if (name == null) IndexedSeq.empty[String] else name.toVector + this + } + + /** + * Set the workloads. + * @param workloads a list of workloads + * @return this. + */ + @varargs def setWorkLoadList(workloads: Float*): Builder = { + workLoadList = workloads.toVector + this + } + + /** + * Specify the parameters need to be fixed. + * @param name a list of parameter names. + * @return this. + */ + @varargs def setFixedParamNames(name: String*): Builder = { + fixedParamNames = name.toSet + this + } + + def build(): Module = { + new Module(modelDef, dataNames, labelNames, contexts, + Option(workLoadList), Option(fixedParamNames)) + } + } } diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala index a9cac13..22b9c3b 100644 --- a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala +++ b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala @@ -18,7 +18,6 @@ package org.apache.mxnet import org.scalatest.{BeforeAndAfterAll, FunSuite} -import org.apache.mxnet.CheckUtils._ import org.apache.mxnet.module._ import org.apache.mxnet.optimizer._ import org.apache.mxnet.io._ @@ -52,8 +51,11 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { import SymbolConversions._ c = a + 2 * b + 3 * c - val mod = new Module(c, IndexedSeq("b", "c", "a"), null, - contexts = Array(Context.cpu(0), Context.cpu(1))) + val mod = new Module.Builder(c) + .setDataNames("b", "c", "a") + .setLabelNames(null) + .setContext(Context.cpu(0), Context.cpu(1)) + .build() mod.bind(dataShapes = IndexedSeq( DataDesc("b", Shape(5, 5), layout = "NT"), DataDesc("c", Shape(5, 5), layout = "NT"), @@ -342,11 +344,13 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { dShape1 = Shape(20, 3, 120, 120) dShape2 = Shape(20, 3, 32, 64) lShape = Shape(20) - dataBatch = new DataBatch( - data = IndexedSeq( + dataBatch = new DataBatch.Builder() + .setData( NDArray.random_uniform(Map("low" -> 0, "high" -> 9, "shape" -> dShape1.toString()))(), - NDArray.random_uniform(Map("low" -> 5, "high" -> 15, "shape" -> dShape2.toString()))()), - label = IndexedSeq(NDArray.ones(lShape)), index = null, pad = 0) + NDArray.random_uniform(Map("low" -> 5, "high" -> 15, "shape" -> dShape2.toString()))()) + .setLabel(NDArray.ones(lShape)) + .setPad(0) + .build() mod.forward(dataBatch) assert(mod.getOutputsMerged()(0).shape == Shape(lShape(0), numClass)) mod.backward() @@ -355,11 +359,13 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll { dShape1 = Shape(5, 3, 28, 40) dShape2 = Shape(5, 3, 24, 16) lShape = Shape(5) - dataBatch = new DataBatch( - data = IndexedSeq( + dataBatch = new DataBatch.Builder() + .setData( NDArray.random_uniform(Map("low" -> 0, "high" -> 9, "shape" -> dShape1.toString()))(), - NDArray.random_uniform(Map("low" -> 15, "high" -> 25, "shape" -> dShape2.toString()))()), - label = IndexedSeq(NDArray.ones(lShape)), index = null, pad = 0) + NDArray.random_uniform(Map("low" -> 15, "high" -> 25, "shape" -> dShape2.toString()))()) + .setLabel(NDArray.ones(lShape)) + .setPad(0) + .build() mod.forward(dataBatch) assert(mod.getOutputsMerged()(0).shape == Shape(lShape(0), numClass)) mod.backward() diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala index 56cc325..c1c3a42 100644 --- a/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala +++ b/scala-package/macros/src/main/scala/org/apache/mxnet/NDArrayMacro.scala @@ -66,6 +66,9 @@ private[mxnet] object NDArrayMacro { if (!NDArrayfunction.name.startsWith("_") || NDArrayfunction.name.startsWith("_contrib_")) { Seq( // scalastyle:off + // (yizhi) We are investigating a way to make these functions type-safe + // and waiting to see the new approach is stable enough. + // Thus these functions may be deprecated in the future. // e.g def transpose(kwargs: Map[String, Any] = null)(args: Any*) q"def $termName(kwargs: Map[String, Any] = null)(args: Any*) = {genericNDArrayFunctionInvoke($funcName, args, kwargs)}".asInstanceOf[DefDef], // e.g def transpose(args: Any*) -- To stop receiving notification emails like this one, please contact liuyi...@apache.org.