This is an automated email from the ASF dual-hosted git repository.

liuyizhi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new ce20a60  [Scala] add BucketingModule (#7716)
ce20a60 is described below

commit ce20a60b5046bbf184bee0b0845c9c630435ad42
Author: 梁德澎 <[email protected]>
AuthorDate: Sat Sep 9 14:00:19 2017 +0800

    [Scala] add BucketingModule (#7716)
    
    * [Scala] add BucketingModule
---
 .../scala/ml/dmlc/mxnet/module/BaseModule.scala    |  29 +-
 .../ml/dmlc/mxnet/module/BucketingModule.scala     | 404 +++++++++++++++++++++
 .../main/scala/ml/dmlc/mxnet/module/Module.scala   |  37 +-
 .../examples/scripts/rnn/run_lstm_bucketing.sh     |  41 +++
 .../scala/ml/dmlc/mxnetexamples/rnn/Lstm.scala     |   8 +-
 .../ml/dmlc/mxnetexamples/rnn/LstmBucketing.scala  |  52 +--
 .../scala/ml/dmlc/mxnetexamples/rnn/RnnModel.scala |   6 +-
 .../ml/dmlc/mxnetexamples/rnn/TrainCharRnn.scala   |   6 +-
 8 files changed, 548 insertions(+), 35 deletions(-)

diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/BaseModule.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/BaseModule.scala
index 0a73e1a..f6f2e83 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/BaseModule.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/BaseModule.scala
@@ -22,9 +22,32 @@ import java.io.IOException
 import ml.dmlc.mxnet.optimizer.SGD
 import ml.dmlc.mxnet._
 import org.slf4j.LoggerFactory
-
+import org.slf4j.Logger
 import scala.collection.mutable.ArrayBuffer
 
+object BaseModule {
+  /**
+   * Check that all input names are in symbol's arguments.
+   */
+  @throws(classOf[IllegalArgumentException])
+  private[module] def _checkInputNames(symbol: Symbol, names: IndexedSeq[String],
+    typeName: String, throws: Boolean, logger: Logger): Unit = {
+    val args = symbol.listArguments()
+    for (name <- names) {
+      if (!args.contains(name)) {
+        val candidates = args.filter ( arg =>
+          !arg.endsWith("_weight") && !arg.endsWith("_bias")
+          && !arg.endsWith("_gamma") && !arg.endsWith("_beta"))
+        val msg = s"You created Module with Module(..., ${typeName}_names=${names.mkString})" +
+          s" but input with name \'${name}\' is not found in symbol.listArguments(). " +
+          s"Did you mean one of:\n${candidates.mkString("\n\t")}"
+        if (throws) throw new IllegalArgumentException(msg)
+        else logger.warn(msg)
+      }
+    }
+  }
+}
+
 /**
 * The base class of a modules. A module represents a computation component. The design
 * purpose of a module is that it abstract a computation "machine", that one can run forward,
@@ -397,7 +420,7 @@ abstract class BaseModule {
 
       // one epoch of training is finished
       val (name, value) = fitParams.evalMetric.get
-      logger.info(s"Epoch[$epoch] Train-$name=$value")
+      logger.info(s"Epoch[$epoch] Train-${name.head}=${value.head}")
       val toc = System.currentTimeMillis
       logger.info(s"Epoch[$epoch] Time cost=${toc - tic}")
 
@@ -415,7 +438,7 @@ abstract class BaseModule {
           scoreEndCallback = fitParams.evalEndCallback,
           batchEndCallback = fitParams.evalBatchEndCallback, epoch = epoch)
         val (name, value) = res.get
-        logger.info(s"Epoch[$epoch] Validation-$name=$value")
+        logger.info(s"Epoch[$epoch] Validation-${name.head}=${value.head}")
       })
 
       // end of 1 epoch, reset the data-iter for another epoch
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/BucketingModule.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/BucketingModule.scala
new file mode 100644
index 0000000..d64ccc0
--- /dev/null
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/BucketingModule.scala
@@ -0,0 +1,404 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package ml.dmlc.mxnet.module
+
+import ml.dmlc.mxnet._
+import org.slf4j.LoggerFactory
+import org.slf4j.Logger
+import scala.collection.mutable.ArrayBuffer
+import ml.dmlc.mxnet.optimizer.SGD
+import scala.collection.immutable.ListMap
+import ml.dmlc.mxnet.module.BaseModule._
+
+/**
+ * This module helps to deal efficiently with varying-length inputs.
+ * @param symGen A function when called with a bucket key, returns a triple
+ *              ``(symbol, dataNames, labelNames)``.
+ * @param defaultBucketKey The key for the default bucket.
+ * @param contexts Default is cpu().
+ * @param workLoadList Default `None`, indicating uniform workload.
+ * @param fixedParamNames Default `None`, indicating no network parameters are fixed.
+ */
+class BucketingModule(symGen: AnyRef => (Symbol, IndexedSeq[String], IndexedSeq[String]),
+                     defaultBucketKey: AnyRef, contexts: Array[Context] = Context.cpu(),
+                     workLoadList: Option[IndexedSeq[Float]] = None,
+                     fixedParamNames: Option[Set[String]] = None) extends BaseModule {
+  private val logger = LoggerFactory.getLogger(classOf[BucketingModule])
+
+  {
+    val (sym, dNames, lNames) = symGen(defaultBucketKey)
+    val dataNameList = if (dNames == null) IndexedSeq.empty[String] else dNames
+    val labelNameList = if (lNames == null) IndexedSeq.empty[String] else lNames
+    val fixedParamNameList = fixedParamNames.getOrElse(IndexedSeq.empty[String]).toIndexedSeq
+
+    _checkInputNames(sym, dataNameList, "data", true, logger)
+    _checkInputNames(sym, labelNameList, "label", false, logger)
+    _checkInputNames(sym, fixedParamNameList, "fixed_param", true, logger)
+  }
+
+  private val workLoads = workLoadList.getOrElse(contexts.map(_ => 1f).toIndexedSeq)
+  require(workLoads.size == contexts.length)
+
+  private val _buckets = scala.collection.mutable.Map[AnyRef, Module]()
+  private var _currModule: Module = null
+  private var _currBucketKey = defaultBucketKey
+
+  private var paramsDirty = false
+
+  // Internal function to reset binded state.
+  private def resetBind(): Unit = {
+    this.binded = false
+    this._buckets.clear()
+    this._currModule = null
+    this._currBucketKey = defaultBucketKey
+  }
+
+  // Symbol information
+  // A list of names for data required by this module.
+  override def dataNames: IndexedSeq[String] = {
+    if (this.binded) this._currModule.dataNames
+    else this.symGen(this.defaultBucketKey)._2
+  }
+
+  // A list of names for the outputs of this module.
+  override def outputNames: IndexedSeq[String] = {
+    if (this.binded) this._currModule.outputNames
+    else this.symGen(this.defaultBucketKey)._1.listOutputs()
+  }
+
+  // Input/Output information
+  // A list of (name, shape) pairs specifying the data inputs to this module.
+  override def dataShapes: IndexedSeq[DataDesc] = {
+    require(this.binded)
+    this._currModule.dataShapes
+  }
+
+  /**
+   * A list of (name, shape) pairs specifying the label inputs to this module.
+   * If this module does not accept labels -- either it is a module without loss
+   * function, or it is not binded for training, then this should return an empty
+   * list `[]`.
+   */
+  override def labelShapes: IndexedSeq[DataDesc] = {
+    require(this.binded)
+    this._currModule.labelShapes
+  }
+
+  // A list of (name, shape) pairs specifying the outputs of this module.
+  override def outputShapes: IndexedSeq[(String, Shape)] = {
+    require(this.binded)
+    this._currModule.outputShapes
+  }
+
+  /**
+   * Get current parameters.
+   * `(arg_params, aux_params)`, each a dictionary of name to parameters (in
+   * `NDArray`) mapping.
+   */
+  override def getParams: (Map[String, NDArray], Map[String, NDArray]) = {
+    require(binded && paramsInitialized)
+    this._currModule.paramsDirty = this.paramsDirty
+    val params = this._currModule.getParams
+    this.paramsDirty = false
+    params
+  }
+
+  /**
+   * Assign parameter and aux state values.
+   * @param argParams Dictionary of name to value (`NDArray`) mapping.
+   * @param auxParams Dictionary of name to value (`NDArray`) mapping.
+   * @param allowMissing
+   *         If true, params could contain missing values, and the initializer will be
+   *         called to fill those missing params.
+   * @param forceInit
+   *         If true, will force re-initialize even if already initialized.
+   * @param allowExtra
+   *         Whether allow extra parameters that are not needed by symbol.
+   *         If this is True, no error will be thrown when argParams or auxParams
+   *         contain extra parameters that is not needed by the executor.
+   */
+  override def setParams(argParams: Map[String, NDArray],
+                auxParams: Map[String, NDArray],
+                allowMissing: Boolean = false,
+                forceInit: Boolean = true,
+                allowExtra: Boolean = false): Unit = {
+    if (!allowMissing) {
+      this.initParams(null, argParams, auxParams, allowMissing, forceInit, allowExtra)
+    } else if (this.paramsInitialized && !forceInit) {
+      logger.warn("Parameters already initialized and forceInit=false. " +
+        "setParams call ignored.")
+    } else {
+      this._currModule.setParams(
+        argParams, auxParams, allowMissing, forceInit, allowExtra)
+
+      // because we didn't update self._arg_params, they are dirty now.
+      this.paramsDirty = true
+      this.paramsInitialized = true
+    }
+  }
+
+  /**
+   * Initialize the parameters and auxiliary states.
+   * @param initializer Called to initialize parameters if needed.
+   * @param argParams If not None, should be a dictionary of existing arg_params.
+   *                  Initialization will be copied from that.
+   * @param auxParams If not None, should be a dictionary of existing aux_params.
+   *                  Initialization will be copied from that.
+   * @param allowMissing If true, params could contain missing values,
+   *                     and the initializer will be called to fill those missing params.
+   * @param forceInit If true, will force re-initialize even if already initialized.
+   * @param allowExtra Whether allow extra parameters that are not needed by symbol.
+   *         If this is True, no error will be thrown when argParams or auxParams
+   *         contain extra parameters that is not needed by the executor.
+   */
+  override def initParams(initializer: Initializer = new Uniform(0.01f),
+                          argParams: Map[String, NDArray] = null,
+                          auxParams: Map[String, NDArray] = null,
+                          allowMissing: Boolean = false,
+                          forceInit: Boolean = false,
+                          allowExtra: Boolean = false): Unit = {
+    if (paramsInitialized && !forceInit) {
+      return
+    }
+    require(binded, "call bind before initializing the parameters")
+    this._currModule.initParams(initializer, argParams, auxParams,
+      allowMissing, forceInit, allowExtra)
+    this.paramsDirty = false
+    this.paramsInitialized = true
+  }
+
+  /**
+   * Bind the symbols to construct executors. This is necessary before one
+   * can perform computation with the module.
+   * @param dataShapes Typically is `dataIter.provideData`.
+   * @param labelShapes Typically is `dataIter.provideLabel`.
+   * @param forTraining Default is `true`. Whether the executors should be bind for training.
+   * @param inputsNeedGrad Default is `false`.
+   *                       Whether the gradients to the input data need to be computed.
+   *                       Typically this is not needed.
+   *                       But this might be needed when implementing composition of modules.
+   * @param forceRebind Default is `false`.
+   *                    This function does nothing if the executors are already binded.
+   *                    But with this `true`, the executors will be forced to rebind.
+   * @param sharedModule Default is `None`. This is used in bucketing.
+   *                     When not `None`, the shared module essentially corresponds to
+   *                     a different bucket -- a module with different symbol
+   *                     but with the same sets of parameters
+   *                     (e.g. unrolled RNNs with different lengths).
+   */
+  override def bind(dataShapes: IndexedSeq[DataDesc],
+                    labelShapes: Option[IndexedSeq[DataDesc]] = None,
+                    forTraining: Boolean = true, inputsNeedGrad: Boolean = false,
+                    forceRebind: Boolean = false, sharedModule: Option[BaseModule] = None,
+                    gradReq: String = "write"): Unit = {
+    // in case we already initialized params, keep it
+    val (argParams, auxParams) =
+      if (this.paramsInitialized) this.getParams
+      else (null, null)
+
+    // force rebinding is typically used when one want to switch from
+    // training to prediction phase.
+    if (forceRebind) this.resetBind()
+
+    if (this.binded) {
+      logger.warn("Already bound, ignoring bind()")
+      return
+    }
+
+    require(sharedModule == None,
+      "shared_module for BucketingModule is not supported")
+
+    this.forTraining = forTraining
+    this.inputsNeedGrad = inputsNeedGrad
+    this.binded = true
+
+    val (sym, dNames, lNames) = this.symGen(this.defaultBucketKey)
+    val module = new Module(sym, dNames, lNames, this.contexts,
+      this.workLoadList, this.fixedParamNames)
+    module.bind(dataShapes, labelShapes, forTraining, inputsNeedGrad,
+      forceRebind = false, sharedModule = None, gradReq)
+    this._currModule = module
+    this._currBucketKey = this.defaultBucketKey
+    this._buckets(this.defaultBucketKey) = module
+
+    // copy back saved params, if already initialized
+    if (this.paramsInitialized) {
+      this.setParams(argParams, auxParams)
+    }
+  }
+
+  /**
+   * Switches to a different bucket. This will change ``this._currModule``.
+   * @param bucketKey The key of the target bucket.
+   * @param dataShapes Typically is `dataIter.provideData`.
+   * @param labelShapes Typically is `dataIter.provideLabel`.
+   */
+  def switchBucket(bucketKey: AnyRef, dataShapes: IndexedSeq[DataDesc],
+    labelShapes: Option[IndexedSeq[DataDesc]] = None): Unit = {
+    require(this.binded, "call bind before switching bucket")
+    if (!this._buckets.contains(bucketKey)) {
+      val (sym, dNames, lNames) = this.symGen(bucketKey)
+      val module = new Module(sym, dNames, lNames, this.contexts,
+        this.workLoadList, this.fixedParamNames)
+      module.bind(dataShapes, labelShapes, this._currModule.forTraining,
+        this._currModule.inputsNeedGrad, forceRebind = false,
+        sharedModule = Option(this._buckets(this.defaultBucketKey)))
+      this._buckets(bucketKey) = module
+    }
+
+    this._currModule = this._buckets(bucketKey)
+    this._currBucketKey = bucketKey
+  }
+
+  /**
+   * Install and initialize optimizers.
+   * @param kvstore
+   * @param optimizer
+   * @param resetOptimizer Default `True`, indicating whether we should set `rescaleGrad`
+   *                       & `idx2name` for optimizer according to executorGroup
+   * @param forceInit Default `False`, indicating whether we should force re-initializing
+   *                  the optimizer in the case an optimizer is already installed.
+   */
+  override def initOptimizer(kvstore: String = "local", optimizer: Optimizer = new SGD(),
+                    resetOptimizer: Boolean = true, forceInit: Boolean = false): Unit = {
+    require(binded && paramsInitialized)
+    if (optimizerInitialized && !forceInit) {
+      logger.warn("optimizer already initialized, ignoring ...")
+    } else {
+      this._currModule.initOptimizer(kvstore, optimizer, resetOptimizer, forceInit)
+      for (mod <- this._buckets.values) {
+        if (mod != this._currModule) mod.borrowOptimizer(this._currModule)
+      }
+      this.optimizerInitialized = true
+    }
+  }
+
+  /**
+   * Prepares a data batch for forward.
+   * @param dataBatch input data
+   */
+  def prepare(dataBatch: DataBatch): Unit = {
+    // perform bind if haven't done so
+    require(this.binded && this.paramsInitialized)
+    val bucketKey = dataBatch.bucketKey
+    val originalBucketKey = this._currBucketKey
+    this.switchBucket(bucketKey, dataBatch.provideData, Option(dataBatch.provideLabel))
+    // switch back
+    this.switchBucket(originalBucketKey, null, None)
+  }
+
+  /**
+   * Forward computation.
+   * @param dataBatch input data
+   * @param isTrain Default is `None`, which means `is_train` takes the value of `for_training`.
+   */
+  override def forward(dataBatch: DataBatch, isTrain: Option[Boolean] = None): Unit = {
+    require(binded && paramsInitialized)
+    this.switchBucket(dataBatch.bucketKey, dataBatch.provideData,
+      Option(dataBatch.provideLabel))
+    this._currModule.forward(dataBatch, isTrain)
+  }
+
+  /**
+   * Backward computation.
+   * @param outGrads Gradient on the outputs to be propagated back.
+   *                 This parameter is only needed when bind is called
+   *                 on outputs that are not a loss function.
+   */
+  override def backward(outGrads: Array[NDArray] = null): Unit = {
+    require(binded && paramsInitialized)
+    this._currModule.backward(outGrads)
+  }
+
+  // Update parameters according to the installed optimizer and the gradients computed
+  // in the previous forward-backward cycle.
+  override def update(): Unit = {
+    require(binded && paramsInitialized && optimizerInitialized)
+    this.paramsDirty = true
+    this._currModule.update()
+  }
+
+  /**
+   * Get outputs of the previous forward computation.
+   * @return In the case when data-parallelism is used,
+   *         the outputs will be collected from multiple devices.
+   *         The results will look like `[[out1_dev1, out1_dev2], [out2_dev1, out2_dev2]]`,
+   *         those `NDArray` might live on different devices.
+   */
+  override def getOutputs(): IndexedSeq[IndexedSeq[NDArray]] = {
+    require(binded && paramsInitialized)
+    this._currModule.getOutputs()
+  }
+
+  /**
+   * Get outputs of the previous forward computation.
+   * @return In the case when data-parallelism is used,
+   *         the outputs will be merged from multiple devices,
+   *         as they look like from a single executor.
+   *         The results will look like `[out1, out2]`
+   */
+  override def getOutputsMerged(): IndexedSeq[NDArray] = {
+    require(binded && paramsInitialized)
+    this._currModule.getOutputsMerged()
+  }
+
+  /**
+   * Get the gradients to the inputs, computed in the previous backward computation.
+   * @return In the case when data-parallelism is used,
+   *         the grads will be collected from multiple devices.
+   *         The results will look like `[[grad1_dev1, grad1_dev2], [grad2_dev1, grad2_dev2]]`,
+   *         those `NDArray` might live on different devices.
+   */
+  override def getInputGrads(): IndexedSeq[IndexedSeq[NDArray]] = {
+    require(binded && paramsInitialized && inputsNeedGrad)
+    this._currModule.getInputGrads()
+  }
+
+  /**
+   * Get the gradients to the inputs, computed in the previous backward computation.
+   * @return In the case when data-parallelism is used,
+   *         the grads will be merged from multiple devices,
+   *         as they look like from a single executor.
+   *         The results will look like `[grad1, grad2]`
+   */
+  override def getInputGradsMerged(): IndexedSeq[NDArray] = {
+    require(binded && paramsInitialized && inputsNeedGrad)
+    this._currModule.getInputGradsMerged()
+  }
+
+  /**
+   * Evaluate and accumulate evaluation metric on outputs of the last forward computation.
+   * @param evalMetric
+   * @param labels
+   */
+  override def updateMetric(evalMetric: EvalMetric, labels: IndexedSeq[NDArray]): Unit = {
+    require(binded && paramsInitialized)
+    this._currModule.updateMetric(evalMetric, labels)
+  }
+
+  override def getSymbol: Symbol = {
+    require(binded)
+    this._currModule.symbol
+  }
+
+  // Install monitor on all executors
+  override def installMonitor(monitor: Monitor): Unit = {
+    require(binded)
+    for (mod <- this._buckets.values) mod.installMonitor(monitor)
+  }
+}
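
As a quick orientation for reviewers, here is a minimal sketch (not part of this commit) of how the new BucketingModule is wired up: the symGen callback maps a bucket key to an unrolled symbol plus its data and label names, and the module binds one Module per bucket on demand. The buildToySym helper and the bucket key below are hypothetical placeholders; the real-world pattern is the LstmBucketing example further down in this patch.

    import ml.dmlc.mxnet._
    import ml.dmlc.mxnet.module.BucketingModule

    // Hypothetical stand-in for a real unrolling function (e.g. Lstm.lstmUnroll):
    // it just builds a tiny throwaway network so the sketch is self-contained.
    def buildToySym(seqLen: Int): Symbol = {
      val data = Symbol.Variable("data")
      val label = Symbol.Variable("softmax_label")
      val fc = Symbol.FullyConnected(name = "fc")()(Map("data" -> data, "num_hidden" -> seqLen))
      Symbol.SoftmaxOutput(name = "softmax")()(Map("data" -> fc, "label" -> label))
    }

    // The callback required by BucketingModule: bucket key -> (symbol, dataNames, labelNames).
    def symGen(key: AnyRef): (Symbol, IndexedSeq[String], IndexedSeq[String]) = {
      val seqLen = key.asInstanceOf[Int]
      (buildToySym(seqLen), IndexedSeq("data"), IndexedSeq("softmax_label"))
    }

    val module = new BucketingModule(
      symGen = symGen,
      defaultBucketKey = 60.asInstanceOf[AnyRef],   // hypothetical longest bucket
      contexts = Array(Context.cpu()))
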
diff --git a/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/Module.scala b/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/Module.scala
index b9cc078..445622e 100644
--- a/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/Module.scala
+++ b/scala-package/core/src/main/scala/ml/dmlc/mxnet/module/Module.scala
@@ -56,7 +56,7 @@ class Module(symbolVar: Symbol,
   private val auxNames = symbol.listAuxiliaryStates()
   private val outputNamesVar = symbol.listOutputs()
 
-  private var paramsDirty = false
+  private[module] var paramsDirty = false
 
   private var optimizer: Optimizer = null
   private var kvstore: Option[KVStore] = None
@@ -168,6 +168,41 @@ class Module(symbolVar: Symbol,
     }
   }
 
+  /**
+   * Assign parameter and aux state values.
+   *     argParams : dict
+   *         Dictionary of name to value (`NDArray`) mapping.
+   *     auxParams : dict
+   *         Dictionary of name to value (`NDArray`) mapping.
+   *     allowMissing : bool
+   *         If true, params could contain missing values, and the initializer will be
+   *         called to fill those missing params.
+   *     forceInit : bool
+   *         If true, will force re-initialize even if already initialized.
+   *     allowExtra : bool
+   *         Whether allow extra parameters that are not needed by symbol.
+   *         If this is True, no error will be thrown when argParams or auxParams
+   *         contain extra parameters that is not needed by the executor.
+   */
+  override def setParams(argParams: Map[String, NDArray],
+                auxParams: Map[String, NDArray],
+                allowMissing: Boolean = false,
+                forceInit: Boolean = true,
+                allowExtra: Boolean = false): Unit = {
+    if (!allowMissing) {
+      this.initParams(null, argParams, auxParams, allowMissing, forceInit, allowExtra)
+    } else if (this.paramsInitialized && !forceInit) {
+      logger.warn("Parameters already initialized and forceInit=false. " +
+        "setParams call ignored.")
+    } else {
+      this.execGroup.setParams(argParams, auxParams, allowExtra)
+
+      // because we didn't update self._arg_params, they are dirty now.
+      this.paramsDirty = true
+      this.paramsInitialized = true
+    }
+  }
+
   // Internal function to reset binded state.
   private def resetBind(): Unit = {
     binded = false
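
A brief, hedged sketch of the setParams contract documented above (assuming a Module `mod` that is already bound; the parameter maps here simply come from getParams, but could equally be maps assembled elsewhere):

    // Hypothetical: `mod` is an already-bound ml.dmlc.mxnet.module.Module.
    val (argMap, auxMap) = mod.getParams

    // With the default allowMissing = false this call delegates to initParams,
    // so every argument and aux state of the symbol must be present in the maps.
    mod.setParams(argMap, auxMap, forceInit = true)
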
diff --git a/scala-package/examples/scripts/rnn/run_lstm_bucketing.sh b/scala-package/examples/scripts/rnn/run_lstm_bucketing.sh
new file mode 100644
index 0000000..3ad160e
--- /dev/null
+++ b/scala-package/examples/scripts/rnn/run_lstm_bucketing.sh
@@ -0,0 +1,41 @@
+#!/bin/bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+MXNET_ROOT=$(cd "$(dirname $0)/../../../.."; pwd)
+OS=$(uname)
+if [ "$OS" = "Darwin" ]; then
+  CLASS_PATH=$MXNET_ROOT/scala-package/assembly/osx-x86_64-gpu/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*
+else
+  CLASS_PATH=$MXNET_ROOT/scala-package/assembly/linux-x86_64-gpu/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*
+fi
+
+DATA_TRAIN=$1
+DATA_VAL=$2
+NUM_EPOCH=5
+GPUS="0"
+SAVE_MODEL_PATH=./model/lstm
+
+java -Xmx4G -cp $CLASS_PATH \
+  ml.dmlc.mxnetexamples.rnn.LstmBucketing \
+  --data-train $DATA_TRAIN \
+  --data-val $DATA_VAL \
+  --num-epoch $NUM_EPOCH \
+  --gpus $GPUS \
+  --save-model-path $SAVE_MODEL_PATH
diff --git a/scala-package/examples/src/main/scala/ml/dmlc/mxnetexamples/rnn/Lstm.scala b/scala-package/examples/src/main/scala/ml/dmlc/mxnetexamples/rnn/Lstm.scala
index fe900a3..7804fd0 100644
--- a/scala-package/examples/src/main/scala/ml/dmlc/mxnetexamples/rnn/Lstm.scala
+++ b/scala-package/examples/src/main/scala/ml/dmlc/mxnetexamples/rnn/Lstm.scala
@@ -77,8 +77,8 @@ object Lstm {
                                     i2hBias = Symbol.Variable(s"l${i}_i2h_bias"),
                                     h2hWeight = Symbol.Variable(s"l${i}_h2h_weight"),
                                     h2hBias = Symbol.Variable(s"l${i}_h2h_bias")))
-      lastStatesBuf.append(LSTMState(c = Symbol.Variable(s"l${i}_init_c"),
-                                     h = Symbol.Variable(s"l${i}_init_h")))
+      lastStatesBuf.append(LSTMState(c = Symbol.Variable(s"l${i}_init_c_beta"),
+                                     h = Symbol.Variable(s"l${i}_init_h_beta")))
     }
     val paramCells = paramCellsBuf.toArray
     val lastStates = lastStatesBuf.toArray
@@ -134,8 +134,8 @@ object Lstm {
                                           i2hBias = Symbol.Variable(s"l${i}_i2h_bias"),
                                           h2hWeight = Symbol.Variable(s"l${i}_h2h_weight"),
                                           h2hBias = Symbol.Variable(s"l${i}_h2h_bias"))
-      lastStates = lastStates :+ LSTMState(c = Symbol.Variable(s"l${i}_init_c"),
-                                           h = Symbol.Variable(s"l${i}_init_h"))
+      lastStates = lastStates :+ LSTMState(c = Symbol.Variable(s"l${i}_init_c_beta"),
+                                           h = Symbol.Variable(s"l${i}_init_h_beta"))
     }
     assert(lastStates.length == numLstmLayer)
 
diff --git a/scala-package/examples/src/main/scala/ml/dmlc/mxnetexamples/rnn/LstmBucketing.scala b/scala-package/examples/src/main/scala/ml/dmlc/mxnetexamples/rnn/LstmBucketing.scala
index a49044b..303c2fb 100644
--- a/scala-package/examples/src/main/scala/ml/dmlc/mxnetexamples/rnn/LstmBucketing.scala
+++ b/scala-package/examples/src/main/scala/ml/dmlc/mxnetexamples/rnn/LstmBucketing.scala
@@ -26,6 +26,8 @@ import org.kohsuke.args4j.{CmdLineParser, Option}
 import org.slf4j.{Logger, LoggerFactory}
 
 import scala.collection.JavaConverters._
+import ml.dmlc.mxnet.module.BucketingModule
+import ml.dmlc.mxnet.module.FitParams
 
 /**
  * Bucketing LSTM examples
@@ -50,6 +52,7 @@ object LstmBucketing {
   private val logger: Logger = LoggerFactory.getLogger(classOf[LstmBucketing])
 
   def perplexity(label: NDArray, pred: NDArray): Float = {
+    pred.waitToRead()
     val labelArr = label.T.toArray.map(_.toInt)
     var loss = .0
     (0 until pred.shape(0)).foreach(i =>
@@ -74,25 +77,22 @@ object LstmBucketing {
       val numEmbed = 200
       val numLstmLayer = 2
 
-      val learningRate = 0.01f
-      val momentum = 0.0f
-
       logger.info("Building vocab ...")
       val vocab = BucketIo.defaultBuildVocab(inst.dataTrain)
 
-      class BucketSymGen extends SymbolGenerator {
-        override def generate(key: AnyRef): Symbol = {
-          val seqLen = key.asInstanceOf[Int]
-          Lstm.lstmUnroll(numLstmLayer, seqLen, vocab.size,
-            numHidden = numHidden, numEmbed = numEmbed, numLabel = vocab.size)
-        }
+      def BucketSymGen(key: AnyRef):
+        (Symbol, IndexedSeq[String], IndexedSeq[String]) = {
+        val seqLen = key.asInstanceOf[Int]
+        val sym = Lstm.lstmUnroll(numLstmLayer, seqLen, vocab.size,
+          numHidden = numHidden, numEmbed = numEmbed, numLabel = vocab.size)
+        (sym, IndexedSeq("data"), IndexedSeq("softmax_label"))
       }
 
       val initC = (0 until numLstmLayer).map(l =>
-        (s"l${l}_init_c", (batchSize, numHidden))
+        (s"l${l}_init_c_beta", (batchSize, numHidden))
       )
       val initH = (0 until numLstmLayer).map(l =>
-        (s"l${l}_init_h", (batchSize, numHidden))
+        (s"l${l}_init_h_beta", (batchSize, numHidden))
       )
       val initStates = initC ++ initH
 
@@ -101,18 +101,26 @@ object LstmBucketing {
       val dataVal = new BucketSentenceIter(inst.dataVal, vocab,
         buckets, batchSize, initStates)
 
+      val model = new BucketingModule(
+        symGen = BucketSymGen,
+        defaultBucketKey = dataTrain.defaultBucketKey,
+        contexts = contexts)
+
+      val fitParams = new FitParams()
+      fitParams.setEvalMetric(
+        new CustomMetric(perplexity, name = "perplexity"))
+      fitParams.setKVStore("device")
+      fitParams.setOptimizer(
+        new SGD(learningRate = 0.01f, momentum = 0f, wd = 0.00001f))
+      fitParams.setInitializer(new Xavier(factorType = "in", magnitude = 2.34f))
+      fitParams.setBatchEndCallback(new Speedometer(batchSize, 50))
+
       logger.info("Start training ...")
-      val model = FeedForward.newBuilder(new BucketSymGen())
-        .setContext(contexts)
-        .setNumEpoch(inst.numEpoch)
-        .setOptimizer(new SGD(learningRate = learningRate, momentum = momentum, wd = 0.00001f))
-        .setInitializer(new Xavier(factorType = "in", magnitude = 2.34f))
-        .setTrainData(dataTrain)
-        .setEvalData(dataVal)
-        .setEvalMetric(new CustomMetric(perplexity, name = "perplexity"))
-        .setBatchEndCallback(new Speedometer(batchSize, 50))
-        .build()
-      model.save(inst.saveModelPath)
+      model.fit(
+        trainData = dataTrain,
+        evalData = Some(dataVal),
+        numEpoch = inst.numEpoch, fitParams)
+      logger.info("Finished training...")
     } catch {
       case ex: Exception =>
         logger.error(ex.getMessage, ex)
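
For completeness, a hedged sketch (not part of this patch) of scoring the validation set with the trained module above; forward() switches to each batch's bucket automatically, as implemented in BucketingModule. It assumes `model`, `dataVal`, `perplexity` and `logger` are the values defined in the LstmBucketing example.

    // Accumulate perplexity over the validation iterator with the trained model.
    val valMetric = new CustomMetric(perplexity, name = "perplexity")
    dataVal.reset()
    while (dataVal.hasNext) {
      val batch = dataVal.next()
      model.forward(batch, isTrain = Some(false))
      model.updateMetric(valMetric, batch.label)
    }
    val (names, values) = valMetric.get
    logger.info(s"Validation-${names.head}=${values.head}")
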
diff --git a/scala-package/examples/src/main/scala/ml/dmlc/mxnetexamples/rnn/RnnModel.scala b/scala-package/examples/src/main/scala/ml/dmlc/mxnetexamples/rnn/RnnModel.scala
index b2188f8..5f919f1 100644
--- a/scala-package/examples/src/main/scala/ml/dmlc/mxnetexamples/rnn/RnnModel.scala
+++ b/scala-package/examples/src/main/scala/ml/dmlc/mxnetexamples/rnn/RnnModel.scala
@@ -35,9 +35,9 @@ object RnnModel {
                                                dropout)
     private val batchSize = 1
     private val initC = (for (l <- 0 until numLstmLayer)
-                          yield (s"l${l}_init_c" -> Shape(batchSize, numHidden))).toMap
+                          yield (s"l${l}_init_c_beta" -> Shape(batchSize, numHidden))).toMap
     private val initH = (for (l <- 0 until numLstmLayer)
-                          yield (s"l${l}_init_h" -> Shape(batchSize, numHidden))).toMap
+                          yield (s"l${l}_init_h_beta" -> Shape(batchSize, numHidden))).toMap
     private val dataShape = Map("data" -> Shape(batchSize))
     private val inputShape = initC ++ initH ++ dataShape
     private val executor = sym.simpleBind(ctx = ctx, shapeDict = inputShape)
@@ -49,7 +49,7 @@ object RnnModel {
     }
 
     private var stateName = (Array[String]() /: (0 until numLstmLayer)) { (acc, i) =>
-      acc :+ s"l${i}_init_c"  :+ s"l${i}_init_h"
+      acc :+ s"l${i}_init_c_beta"  :+ s"l${i}_init_h_beta"
     }
 
     private val statesDict = stateName.zip(this.executor.outputs.drop(1)).toMap
diff --git a/scala-package/examples/src/main/scala/ml/dmlc/mxnetexamples/rnn/TrainCharRnn.scala b/scala-package/examples/src/main/scala/ml/dmlc/mxnetexamples/rnn/TrainCharRnn.scala
index 59c074e..2fe2780 100644
--- a/scala-package/examples/src/main/scala/ml/dmlc/mxnetexamples/rnn/TrainCharRnn.scala
+++ b/scala-package/examples/src/main/scala/ml/dmlc/mxnetexamples/rnn/TrainCharRnn.scala
@@ -70,8 +70,10 @@ object TrainCharRnn {
       }
 
       // initalize states for LSTM
-      val initC = for (l <- 0 until numLstmLayer) yield (s"l${l}_init_c", (batchSize, numHidden))
-      val initH = for (l <- 0 until numLstmLayer) yield (s"l${l}_init_h", (batchSize, numHidden))
+      val initC = for (l <- 0 until numLstmLayer)
+        yield (s"l${l}_init_c_beta", (batchSize, numHidden))
+      val initH = for (l <- 0 until numLstmLayer)
+        yield (s"l${l}_init_h_beta", (batchSize, numHidden))
       val initStates = initC ++ initH
 
+      val dataTrain = new BucketIo.BucketSentenceIter(incr.dataPath, vocab, buckets,

-- 
To stop receiving notification emails like this one, please contact [email protected].
