This is an automated email from the ASF dual-hosted git repository.

nswamy pushed a commit to branch java-api
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/java-api by this push:
     new 94f3665  NativeResource Management in Scala (#12647) (#12883)
94f3665 is described below

commit 94f36651233258217d5005fe0e2894a7e88ffa21
Author: Andrew Ayres <[email protected]>
AuthorDate: Fri Oct 19 16:21:06 2018 -0700

    NativeResource Management in Scala (#12647) (#12883)
    
    * add a generic MXNetHandle trait and an MXNetHandlePhantomRef class that will be used by all MXNet objects
    
    * Generic Handle with AutoCloseable
    
    * add NativeResource and NativeResourceManager with Periodic GC calling
    
    * use NativeResource trait in NDArray, Symbol and Executor
    
    * add run train mnist script
    
    * create a generic ResourceScope that can collect all NativeResources to dispose at the end
    
    * modify NativeResource and ResourceScope; extend NativeResource in NDArray, Symbol and Executor
    
    * remove GCExecutor
    
    * deregister PhantomReferences when calling dispose()
    
    * add a finalizer (temporary) to NativeResource
    
    * refactor NativeResource.dispose() method
    
    * update NativeResource/add Unit Test for NativeResource
    
    * updates to NativeResource/NativeResourceRef and to the NativeResource unit tests
    
    * remove redundant code that had been added because object equality was needed
    
    * add ResourceScope
    
    * fix NativeResource so it does not remove itself from the scope while ResourceScope is closing; add unit tests for ResourceScope
    
    * cleanup log/print debug statements
    
    * use TreeSet in place of ArrayBuffer to speed up removal of resources from ResourceScope
    Fix Executor dispose and make KVStore a NativeResource
    
    * fix a segfault caused by creating NDArrays on the fly in Optimizer
    
    * add comments for dispose(removeFromScope: Boolean)
---
 scala-package/core/pom.xml                         |   7 +
 .../src/main/scala/org/apache/mxnet/Executor.scala |  20 ++-
 .../src/main/scala/org/apache/mxnet/KVStore.scala  |  21 +--
 .../src/main/scala/org/apache/mxnet/Model.scala    | 122 ++++++-------
 .../src/main/scala/org/apache/mxnet/NDArray.scala  |  18 +-
 .../scala/org/apache/mxnet/NativeResource.scala    | 189 ++++++++++++++++++++
 .../main/scala/org/apache/mxnet/Optimizer.scala    |  22 ++-
 .../scala/org/apache/mxnet/ResourceScope.scala     | 196 +++++++++++++++++++++
 .../src/main/scala/org/apache/mxnet/Symbol.scala   |  25 ++-
 .../scala/org/apache/mxnet/io/MXDataIter.scala     |  19 +-
 .../scala/org/apache/mxnet/optimizer/SGD.scala     |  10 +-
 .../org/apache/mxnet/NativeResourceSuite.scala     |  69 ++++++++
 .../org/apache/mxnet/ResourceScopeSuite.scala      | 151 ++++++++++++++++
 scala-package/examples/scripts/run_train_mnist.sh  |  33 ++++
 14 files changed, 778 insertions(+), 124 deletions(-)

diff --git a/scala-package/core/pom.xml b/scala-package/core/pom.xml
index 6e2d8d6..d5396da 100644
--- a/scala-package/core/pom.xml
+++ b/scala-package/core/pom.xml
@@ -126,5 +126,12 @@
       <artifactId>commons-io</artifactId>
       <version>2.1</version>
     </dependency>
+    <!-- https://mvnrepository.com/artifact/org.mockito/mockito-all -->
+    <dependency>
+      <groupId>org.mockito</groupId>
+      <artifactId>mockito-all</artifactId>
+      <version>1.10.19</version>
+      <scope>test</scope>
+    </dependency>
   </dependencies>
 </project>
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
index fc791d5..19fb6fe 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
@@ -45,7 +45,7 @@ object Executor {
  * @see Symbol.bind : to create executor
  */
 class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
-                              private[mxnet] val symbol: Symbol) extends 
WarnIfNotDisposed {
+                              private[mxnet] val symbol: Symbol) extends 
NativeResource {
   private[mxnet] var argArrays: Array[NDArray] = null
   private[mxnet] var gradArrays: Array[NDArray] = null
   private[mxnet] var auxArrays: Array[NDArray] = null
@@ -59,14 +59,15 @@ class Executor private[mxnet](private[mxnet] val handle: 
ExecutorHandle,
   private[mxnet] var _group2ctx: Map[String, Context] = null
   private val logger: Logger = LoggerFactory.getLogger(classOf[Executor])
 
-  private var disposed = false
-  protected def isDisposed = disposed
-
-  def dispose(): Unit = {
-    if (!disposed) {
-      outputs.foreach(_.dispose())
-      _LIB.mxExecutorFree(handle)
-      disposed = true
+  override def nativeAddress: CPtrAddress = handle
+  override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxExecutorFree
+  // cannot determine the off-heap size of this object
+  override val bytesAllocated: Long = 0
+  override val ref: NativeResourceRef = super.register()
+  override def dispose(): Unit = {
+    if (!super.isDisposed) {
+      super.dispose()
+      outputs.foreach(o => o.dispose())
     }
   }
 
@@ -305,4 +306,5 @@ class Executor private[mxnet](private[mxnet] val handle: 
ExecutorHandle,
     checkCall(_LIB.mxExecutorPrint(handle, str))
     str.value
   }
+
 }
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/KVStore.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/KVStore.scala
index 8e89ce7..45189a1 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/KVStore.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/KVStore.scala
@@ -52,22 +52,17 @@ object KVStore {
   }
 }
 
-class KVStore(private[mxnet] val handle: KVStoreHandle) extends 
WarnIfNotDisposed {
+class KVStore(private[mxnet] val handle: KVStoreHandle) extends NativeResource 
{
   private val logger: Logger = LoggerFactory.getLogger(classOf[KVStore])
   private var updaterFunc: MXKVStoreUpdater = null
-  private var disposed = false
-  protected def isDisposed = disposed
 
-  /**
-   * Release the native memory.
-   * The object shall never be used after it is disposed.
-   */
-  def dispose(): Unit = {
-    if (!disposed) {
-      _LIB.mxKVStoreFree(handle)
-      disposed = true
-    }
-  }
+  override def nativeAddress: CPtrAddress = handle
+
+  override def nativeDeAllocator: CPtrAddress => MXUint = _LIB.mxKVStoreFree
+
+  override val ref: NativeResourceRef = super.register()
+
+  override val bytesAllocated: Long = 0L
 
   /**
    * Initialize a single or a sequence of key-value pairs into the store.
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Model.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/Model.scala
index 4bb9cdd..b835c49 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Model.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Model.scala
@@ -259,7 +259,9 @@ object Model {
                                       workLoadList: Seq[Float] = Nil,
                                       monitor: Option[Monitor] = None,
                                       symGen: SymbolGenerator = null): Unit = {
-    val executorManager = new DataParallelExecutorManager(
+    ResourceScope.using() {
+
+      val executorManager = new DataParallelExecutorManager(
         symbol = symbol,
         symGen = symGen,
         ctx = ctx,
@@ -269,17 +271,17 @@ object Model {
         auxNames = auxNames,
         workLoadList = workLoadList)
 
-    monitor.foreach(executorManager.installMonitor)
-    executorManager.setParams(argParams, auxParams)
+      monitor.foreach(executorManager.installMonitor)
+      executorManager.setParams(argParams, auxParams)
 
-    // updater for updateOnKVStore = false
-    val updaterLocal = Optimizer.getUpdater(optimizer)
+      // updater for updateOnKVStore = false
+      val updaterLocal = Optimizer.getUpdater(optimizer)
 
-    kvStore.foreach(initializeKVStore(_, executorManager.paramArrays,
-      argParams, executorManager.paramNames, updateOnKVStore))
-    if (updateOnKVStore) {
-      kvStore.foreach(_.setOptimizer(optimizer))
-    }
+      kvStore.foreach(initializeKVStore(_, executorManager.paramArrays,
+        argParams, executorManager.paramNames, updateOnKVStore))
+      if (updateOnKVStore) {
+        kvStore.foreach(_.setOptimizer(optimizer))
+      }
 
     // Now start training
     for (epoch <- beginEpoch until endEpoch) {
@@ -290,45 +292,46 @@ object Model {
       var epochDone = false
       // Iterate over training data.
       trainData.reset()
-      while (!epochDone) {
-        var doReset = true
-        while (doReset && trainData.hasNext) {
-          val dataBatch = trainData.next()
-          executorManager.loadDataBatch(dataBatch)
-          monitor.foreach(_.tic())
-          executorManager.forward(isTrain = true)
-          executorManager.backward()
-          if (updateOnKVStore) {
-            updateParamsOnKVStore(executorManager.paramArrays,
-              executorManager.gradArrays,
-              kvStore, executorManager.paramNames)
-          } else {
-            updateParams(executorManager.paramArrays,
-              executorManager.gradArrays,
-              updaterLocal, ctx.length,
-              executorManager.paramNames,
-              kvStore)
-          }
-          monitor.foreach(_.tocPrint())
-          // evaluate at end, so out_cpu_array can lazy copy
-          executorManager.updateMetric(evalMetric, dataBatch.label)
+      ResourceScope.using() {
+        while (!epochDone) {
+          var doReset = true
+          while (doReset && trainData.hasNext) {
+            val dataBatch = trainData.next()
+            executorManager.loadDataBatch(dataBatch)
+            monitor.foreach(_.tic())
+            executorManager.forward(isTrain = true)
+            executorManager.backward()
+            if (updateOnKVStore) {
+              updateParamsOnKVStore(executorManager.paramArrays,
+                executorManager.gradArrays,
+                kvStore, executorManager.paramNames)
+            } else {
+              updateParams(executorManager.paramArrays,
+                executorManager.gradArrays,
+                updaterLocal, ctx.length,
+                executorManager.paramNames,
+                kvStore)
+            }
+            monitor.foreach(_.tocPrint())
+            // evaluate at end, so out_cpu_array can lazy copy
+            executorManager.updateMetric(evalMetric, dataBatch.label)
 
-          nBatch += 1
-          batchEndCallback.foreach(_.invoke(epoch, nBatch, evalMetric))
+            nBatch += 1
+            batchEndCallback.foreach(_.invoke(epoch, nBatch, evalMetric))
 
-          // this epoch is done possibly earlier
-          if (epochSize != -1 && nBatch >= epochSize) {
-            doReset = false
+            // this epoch is done possibly earlier
+            if (epochSize != -1 && nBatch >= epochSize) {
+              doReset = false
+            }
+          }
+          if (doReset) {
+            trainData.reset()
           }
-        }
-        if (doReset) {
-          trainData.reset()
-        }
 
-        // this epoch is done
-        epochDone = (epochSize == -1 || nBatch >= epochSize)
+          // this epoch is done
+          epochDone = (epochSize == -1 || nBatch >= epochSize)
+        }
       }
-
       val (name, value) = evalMetric.get
       name.zip(value).foreach { case (n, v) =>
         logger.info(s"Epoch[$epoch] Train-$n=$v")
@@ -336,20 +339,22 @@ object Model {
       val toc = System.currentTimeMillis
       logger.info(s"Epoch[$epoch] Time cost=${toc - tic}")
 
-      evalData.foreach { evalDataIter =>
-        evalMetric.reset()
-        evalDataIter.reset()
-        // TODO: make DataIter implement Iterator
-        while (evalDataIter.hasNext) {
-          val evalBatch = evalDataIter.next()
-          executorManager.loadDataBatch(evalBatch)
-          executorManager.forward(isTrain = false)
-          executorManager.updateMetric(evalMetric, evalBatch.label)
-        }
+      ResourceScope.using() {
+        evalData.foreach { evalDataIter =>
+          evalMetric.reset()
+          evalDataIter.reset()
+          // TODO: make DataIter implement Iterator
+          while (evalDataIter.hasNext) {
+            val evalBatch = evalDataIter.next()
+            executorManager.loadDataBatch(evalBatch)
+            executorManager.forward(isTrain = false)
+            executorManager.updateMetric(evalMetric, evalBatch.label)
+          }
 
-        val (name, value) = evalMetric.get
-        name.zip(value).foreach { case (n, v) =>
-          logger.info(s"Epoch[$epoch] Train-$n=$v")
+          val (name, value) = evalMetric.get
+          name.zip(value).foreach { case (n, v) =>
+            logger.info(s"Epoch[$epoch] Validation-$n=$v")
+          }
         }
       }
 
@@ -359,8 +364,7 @@ object Model {
       epochEndCallback.foreach(_.invoke(epoch, symbol, argParams, auxParams))
     }
 
-    updaterLocal.dispose()
-    executorManager.dispose()
+    }
   }
   // scalastyle:on parameterNum
 }
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index 9b6a7dc..f2a7603 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -562,16 +562,20 @@ object NDArray extends NDArrayBase {
  */
 class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
                              val writable: Boolean = true,
-                             addToCollector: Boolean = true) extends 
WarnIfNotDisposed {
+                             addToCollector: Boolean = true) extends 
NativeResource {
   if (addToCollector) {
     NDArrayCollector.collect(this)
   }
 
+  override def nativeAddress: CPtrAddress = handle
+  override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxNDArrayFree
+  override val bytesAllocated: Long = DType.numOfBytes(this.dtype) * 
this.shape.product
+
+  override val ref: NativeResourceRef = super.register()
+
   // record arrays who construct this array instance
   // we use weak reference to prevent gc blocking
   private[mxnet] val dependencies = mutable.HashMap.empty[Long, 
WeakReference[NDArray]]
-  @volatile private var disposed = false
-  def isDisposed: Boolean = disposed
 
   def serialize(): Array[Byte] = {
     val buf = ArrayBuffer.empty[Byte]
@@ -584,11 +588,10 @@ class NDArray private[mxnet](private[mxnet] val handle: 
NDArrayHandle,
    * The NDArrays it depends on will NOT be disposed. <br />
    * The object shall never be used after it is disposed.
    */
-  def dispose(): Unit = {
-    if (!disposed) {
-      _LIB.mxNDArrayFree(handle)
+  override def dispose(): Unit = {
+    if (!super.isDisposed) {
+      super.dispose()
       dependencies.clear()
-      disposed = true
     }
   }
 
@@ -1034,6 +1037,7 @@ class NDArray private[mxnet](private[mxnet] val handle: 
NDArrayHandle,
     // TODO: naive implementation
     shape.hashCode + toArray.hashCode
   }
+
 }
 
 private[mxnet] object NDArrayConversions {
diff --git 
a/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala
new file mode 100644
index 0000000..48d4b0c
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NativeResource.scala
@@ -0,0 +1,189 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import org.apache.mxnet.Base.CPtrAddress
+import java.lang.ref.{PhantomReference, ReferenceQueue, WeakReference}
+import java.util.concurrent._
+
+import org.apache.mxnet.Base.checkCall
+import java.util.concurrent.atomic.AtomicLong
+
+
+/**
+  * NativeResource trait is used to manage MXNet Objects
+  * such as NDArray, Symbol, Executor, etc.,
+  * The MXNet Object calls NativeResource.register
+  * and assign the returned NativeResourceRef to PhantomReference
+  * NativeResource also implements AutoCloseable so MXNetObjects
+  * can be used like Resources in try-with-resources paradigm
+  */
+private[mxnet] trait NativeResource
+  extends AutoCloseable with WarnIfNotDisposed {
+
+  /**
+    * native Address associated with this object
+    */
+  def nativeAddress: CPtrAddress
+
+  /**
+    * Function Pointer to the NativeDeAllocator of nativeAddress
+    */
+  def nativeDeAllocator: (CPtrAddress => Int)
+
+  /** Call NativeResource.register to get the reference
+    */
+  val ref: NativeResourceRef
+
+  /**
+    * Off-Heap Bytes Allocated for this object
+    */
+  // intentionally making it a val, so it gets evaluated when defined
+  val bytesAllocated: Long
+
+  private[mxnet] var scope: Option[ResourceScope] = None
+
+  @volatile private var disposed = false
+
+  override def isDisposed: Boolean = disposed || isDeAllocated
+
+  /**
+    * Register this object for PhantomReference tracking and in
+    * ResourceScope if used inside ResourceScope.
+    * @return NativeResourceRef that tracks reachability of this object
+    *         using PhantomReference
+    */
+  def register(): NativeResourceRef = {
+    scope = ResourceScope.getCurrentScope()
+    if (scope.isDefined) scope.get.add(this)
+
+    NativeResource.totalBytesAllocated.getAndAdd(bytesAllocated)
+    // register with PhantomRef tracking to release incase the objects go
+    // out of reference within scope but are held for long time
+    NativeResourceRef.register(this, nativeDeAllocator)
+ }
+
+  // Implements [[@link AutoCloseable.close]]
+  override def close(): Unit = {
+    dispose()
+  }
+
+  // Implements [[@link WarnIfNotDisposed.dispose]]
+  def dispose(): Unit = dispose(true)
+
+  /**
+    * This method deAllocates nativeResource and deRegisters
+    * from PhantomRef and removes from Scope if
+    * removeFromScope is set to true.
+    * @param removeFromScope remove from the currentScope if true
+    */
+  // the parameter here controls whether to remove from current scope.
+  // [[ResourceScope.close]] calls NativeResource.dispose
+  // if we remove from the ResourceScope ie., from the container in 
ResourceScope.
+  // while iterating on the container, calling iterator.next is undefined and 
not safe.
+  // Note that ResourceScope automatically disposes all the resources within.
+  private[mxnet] def dispose(removeFromScope: Boolean = true): Unit = {
+    if (!disposed) {
+      checkCall(nativeDeAllocator(this.nativeAddress))
+      NativeResourceRef.deRegister(ref) // removes from PhantomRef tracking
+      if (removeFromScope && scope.isDefined) scope.get.remove(this)
+      NativeResource.totalBytesAllocated.getAndAdd(-1*bytesAllocated)
+      disposed = true
+    }
+  }
+
+  /*
+  this is used by the WarnIfNotDisposed finalizer,
+  the object could be disposed by the GC without the need for explicit disposal
+  but the finalizer might not have run, then the WarnIfNotDisposed throws a 
warning
+   */
+  private[mxnet] def isDeAllocated(): Boolean = 
NativeResourceRef.isDeAllocated(ref)
+
+}
+
+private[mxnet] object NativeResource {
+  var totalBytesAllocated : AtomicLong = new AtomicLong(0)
+}
+
+// Do not make [[NativeResource.resource]] a member of the class,
+// this will hold reference and GC will not clear the object.
+private[mxnet] class NativeResourceRef(resource: NativeResource,
+                                       val resourceDeAllocator: CPtrAddress => 
Int)
+        extends PhantomReference[NativeResource](resource, 
NativeResourceRef.refQ) {}
+
+private[mxnet] object NativeResourceRef {
+
+  private[mxnet] val refQ: ReferenceQueue[NativeResource]
+                = new ReferenceQueue[NativeResource]
+
+  private[mxnet] val refMap = new ConcurrentHashMap[NativeResourceRef, 
CPtrAddress]()
+
+  private[mxnet] val cleaner = new ResourceCleanupThread()
+
+  cleaner.start()
+
+  def register(resource: NativeResource, nativeDeAllocator: (CPtrAddress => 
Int)):
+  NativeResourceRef = {
+    val ref = new NativeResourceRef(resource, nativeDeAllocator)
+    refMap.put(ref, resource.nativeAddress)
+    ref
+  }
+
+  // remove from PhantomRef tracking
+  def deRegister(ref: NativeResourceRef): Unit = refMap.remove(ref)
+
+  /**
+    * This method will check if the cleaner ran and deAllocated the object
+    * As a part of GC, when the object is unreachable GC inserts a phantomRef
+    * to the ReferenceQueue which the cleaner thread will deallocate, however
+    * the finalizer runs much later depending on the GC.
+    * @param resource resource to verify if it has been deAllocated
+    * @return true if already deAllocated
+    */
+  def isDeAllocated(ref: NativeResourceRef): Boolean = {
+    !refMap.containsKey(ref)
+  }
+
+  def cleanup: Unit = {
+    // remove is a blocking call
+    val ref: NativeResourceRef = refQ.remove().asInstanceOf[NativeResourceRef]
+    // phantomRef will be removed from the map when NativeResource.close is 
called.
+    val resource = refMap.get(ref)
+    if (resource != 0L)  { // since CPtrAddress is Scala a Long, it cannot be 
null
+      ref.resourceDeAllocator(resource)
+      refMap.remove(ref)
+    }
+  }
+
+  protected class ResourceCleanupThread extends Thread {
+    setPriority(Thread.MAX_PRIORITY)
+    setName("NativeResourceDeAllocatorThread")
+    setDaemon(true)
+
+    override def run(): Unit = {
+      while (true) {
+        try {
+          NativeResourceRef.cleanup
+        }
+        catch {
+          case _: InterruptedException => Thread.currentThread().interrupt()
+        }
+      }
+    }
+  }
+}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala
index 758cbc8..c3f8aae 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Optimizer.scala
@@ -19,6 +19,8 @@ package org.apache.mxnet
 
 import java.io._
 
+import org.apache.mxnet.Base.CPtrAddress
+
 import scala.collection.mutable
 import scala.util.Either
 
@@ -38,8 +40,10 @@ object Optimizer {
       }
 
       override def dispose(): Unit = {
-        states.values.foreach(optimizer.disposeState)
-        states.clear()
+        if (!super.isDisposed) {
+          states.values.foreach(optimizer.disposeState)
+          states.clear()
+        }
       }
 
       override def serializeState(): Array[Byte] = {
@@ -285,7 +289,8 @@ abstract class Optimizer extends Serializable {
   }
 }
 
-trait MXKVStoreUpdater {
+trait MXKVStoreUpdater extends
+  NativeResource {
   /**
    * user-defined updater for the kvstore
    * It's this updater's responsibility to delete recv and local
@@ -294,9 +299,14 @@ trait MXKVStoreUpdater {
    * @param local the value stored on local on this key
    */
   def update(key: Int, recv: NDArray, local: NDArray): Unit
-  def dispose(): Unit
-  // def serializeState(): Array[Byte]
-  // def deserializeState(bytes: Array[Byte]): Unit
+
+  // This is a hack to make Optimizers work with ResourceScope
+  // otherwise the user has to manage calling dispose on this object.
+  override def nativeAddress: CPtrAddress = hashCode()
+  override def nativeDeAllocator: CPtrAddress => Int = doNothingDeAllocator
+  private def doNothingDeAllocator(dummy: CPtrAddress): Int = 0
+  override val ref: NativeResourceRef = super.register()
+  override val bytesAllocated: Long = 0L
 }
 
 trait MXKVStoreCachedStates {
diff --git 
a/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala
new file mode 100644
index 0000000..1c5782d
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/ResourceScope.scala
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import java.util.HashSet
+
+import org.slf4j.LoggerFactory
+
+import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+import scala.util.Try
+import scala.util.control.{ControlThrowable, NonFatal}
+
+/**
+  * This class manages automatically releasing of [[NativeResource]]s
+  */
+class ResourceScope extends AutoCloseable {
+
+  // HashSet does not take a custom comparator
+  private[mxnet] val resourceQ = new 
mutable.TreeSet[NativeResource]()(nativeAddressOrdering)
+
+  private object nativeAddressOrdering extends Ordering[NativeResource] {
+    def compare(a: NativeResource, b: NativeResource): Int = {
+      a.nativeAddress compare  b.nativeAddress
+    }
+  }
+
+  ResourceScope.addToThreadLocal(this)
+
+  /**
+    * Releases all the [[NativeResource]] by calling
+    * the associated [[NativeResource.close()]] method
+    */
+  override def close(): Unit = {
+    ResourceScope.removeFromThreadLocal(this)
+    resourceQ.foreach(resource => if (resource != null) 
resource.dispose(false) )
+    resourceQ.clear()
+  }
+
+  /**
+    * Add a NativeResource to the scope
+    * @param resource
+    */
+  def add(resource: NativeResource): Unit = {
+    resourceQ.+=(resource)
+  }
+
+  /**
+    * Remove NativeResource from the Scope, this uses
+    * object equality to find the resource in the stack.
+    * @param resource
+    */
+  def remove(resource: NativeResource): Unit = {
+    resourceQ.-=(resource)
+  }
+}
+
+object ResourceScope {
+
+  private val logger = LoggerFactory.getLogger(classOf[ResourceScope])
+
+  /**
+    * Captures all Native Resources created using the ResourceScope and
+    * at the end of the body, de allocates all the Native resources by calling 
close on them.
+    * This method will not deAllocate NativeResources returned from the block.
+    * @param scope (Optional). Scope in which to capture the native resources
+    * @param body  block of code to execute in this scope
+    * @tparam A return type
+    * @return result of the operation, if the result is of type 
NativeResource, it is not
+    *         de allocated so the user can use it and then de allocate 
manually by calling
+    *         close or enclose in another resourceScope.
+    */
+  // inspired from slide 21 of 
https://www.slideshare.net/Odersky/fosdem-2009-1013261
+  // and 
https://github.com/scala/scala/blob/2.13.x/src/library/scala/util/Using.scala
+  // TODO: we should move to the Scala util's Using method when we move to 
Scala 2.13
+  def using[A](scope: ResourceScope = null)(body: => A): A = {
+
+    val curScope = if (scope != null) scope else new ResourceScope()
+
+    val prevScope: Option[ResourceScope] = ResourceScope.getPrevScope()
+
+    @inline def resourceInGeneric(g: scala.collection.Iterable[_]) = {
+      g.foreach( n =>
+        n match {
+          case nRes: NativeResource => {
+            removeAndAddToPrevScope(nRes)
+          }
+          case kv: scala.Tuple2[_, _] => {
+            if (kv._1.isInstanceOf[NativeResource]) removeAndAddToPrevScope(
+              kv._1.asInstanceOf[NativeResource])
+            if (kv._2.isInstanceOf[NativeResource]) removeAndAddToPrevScope(
+              kv._2.asInstanceOf[NativeResource])
+          }
+        }
+      )
+    }
+
+    @inline def removeAndAddToPrevScope(r: NativeResource) = {
+      curScope.remove(r)
+      if (prevScope.isDefined)  {
+        prevScope.get.add(r)
+        r.scope = prevScope
+      }
+    }
+
+    @inline def safeAddSuppressed(t: Throwable, suppressed: Throwable): Unit = 
{
+      if (!t.isInstanceOf[ControlThrowable]) t.addSuppressed(suppressed)
+    }
+
+    var retThrowable: Throwable = null
+
+    try {
+      val ret = body
+       ret match {
+          // don't de-allocate if returning any collection that contains 
NativeResource.
+        case resInGeneric: scala.collection.Iterable[_] => 
resourceInGeneric(resInGeneric)
+        case nRes: NativeResource => removeAndAddToPrevScope(nRes)
+        case ndRet: NDArrayFuncReturn => ndRet.arr.foreach( nd => 
removeAndAddToPrevScope(nd) )
+        case _ => // do nothing
+      }
+      ret
+    } catch {
+      case t: Throwable =>
+        retThrowable = t
+        null.asInstanceOf[A] // we'll throw in finally
+    } finally {
+      var toThrow: Throwable = retThrowable
+      if (retThrowable eq null) curScope.close()
+      else {
+        try {
+          curScope.close
+        } catch {
+          case closeThrowable: Throwable =>
+            if (NonFatal(retThrowable) && !NonFatal(closeThrowable)) toThrow = 
closeThrowable
+            else safeAddSuppressed(retThrowable, closeThrowable)
+        } finally {
+          throw toThrow
+        }
+      }
+    }
+  }
+
+  // thread local Scopes
+  private[mxnet] val threadLocalScopes = new 
ThreadLocal[ArrayBuffer[ResourceScope]] {
+    override def initialValue(): ArrayBuffer[ResourceScope] =
+      new ArrayBuffer[ResourceScope]()
+  }
+
+  /**
+    * Add resource to current ThreadLocal DataStructure
+    * @param r ResourceScope to add.
+    */
+  private[mxnet] def addToThreadLocal(r: ResourceScope): Unit = {
+    threadLocalScopes.get() += r
+  }
+
+  /**
+    * Remove resource from current ThreadLocal DataStructure
+    * @param r ResourceScope to remove
+    */
+  private[mxnet] def removeFromThreadLocal(r: ResourceScope): Unit = {
+    threadLocalScopes.get() -= r
+  }
+
+  /**
+    * Get the latest Scope in the stack
+    * @return
+    */
+  private[mxnet] def getCurrentScope(): Option[ResourceScope] = {
+    Try(Some(threadLocalScopes.get().last)).getOrElse(None)
+  }
+
+  /**
+    * Get the Last but one Scope from threadLocal Scopes.
+    * @return n-1th scope or None when not found
+    */
+  private[mxnet] def getPrevScope(): Option[ResourceScope] = {
+    val scopes = threadLocalScopes.get()
+    Try(Some(scopes(scopes.size - 2))).getOrElse(None)
+  }
+}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
index b1a3e39..a009e7e 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
@@ -29,21 +29,15 @@ import scala.collection.mutable.{ArrayBuffer, ListBuffer}
  * WARNING: it is your responsibility to clear this object through dispose().
  * </b>
  */
-class Symbol private(private[mxnet] val handle: SymbolHandle) extends 
WarnIfNotDisposed {
+class Symbol private(private[mxnet] val handle: SymbolHandle) extends 
NativeResource {
   private val logger: Logger = LoggerFactory.getLogger(classOf[Symbol])
-  private var disposed = false
-  protected def isDisposed = disposed
 
-  /**
-   * Release the native memory.
-   * The object shall never be used after it is disposed.
-   */
-  def dispose(): Unit = {
-    if (!disposed) {
-      _LIB.mxSymbolFree(handle)
-      disposed = true
-    }
-  }
+  // unable to get the byteAllocated for Symbol
+  override val bytesAllocated: Long = 0L
+  override def nativeAddress: CPtrAddress = handle
+  override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxSymbolFree
+  override val ref: NativeResourceRef = super.register()
+
 
   def +(other: Symbol): Symbol = 
Symbol.createFromListedSymbols("_Plus")(Array(this, other))
   def +[@specialized(Int, Float, Double) V](other: V): Symbol = {
@@ -793,7 +787,7 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotD
     }
 
     val execHandle = new ExecutorHandleRef
-    val sharedHadle = if (sharedExec != null) sharedExec.handle else 0L
+    val sharedHandle = if (sharedExec != null) sharedExec.handle else 0L
     checkCall(_LIB.mxExecutorBindEX(handle,
                                    ctx.deviceTypeid,
                                    ctx.deviceId,
@@ -806,7 +800,7 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotD
                                    argsGradHandle,
                                    reqsArray,
                                    auxArgsHandle,
-                                   sharedHadle,
+                                   sharedHandle,
                                    execHandle))
     val executor = new Executor(execHandle.value, this.clone())
     executor.argArrays = argsNDArray
@@ -832,6 +826,7 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends WarnIfNotD
     checkCall(_LIB.mxSymbolSaveToJSON(handle, jsonStr))
     jsonStr.value
   }
+
 }
 
 /**
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
index f7f858d..9980177 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
@@ -33,7 +33,7 @@ import scala.collection.mutable.ListBuffer
 private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
                                 dataName: String = "data",
                                 labelName: String = "label")
-  extends DataIter with WarnIfNotDisposed {
+  extends DataIter with NativeResource {
 
   private val logger = LoggerFactory.getLogger(classOf[MXDataIter])
 
@@ -67,20 +67,13 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
     }
   }
 
+  override def nativeAddress: CPtrAddress = handle
 
-  private var disposed = false
-  protected def isDisposed = disposed
+  override def nativeDeAllocator: CPtrAddress => MXUint = _LIB.mxDataIterFree
 
-  /**
-   * Release the native memory.
-   * The object shall never be used after it is disposed.
-   */
-  def dispose(): Unit = {
-    if (!disposed) {
-      _LIB.mxDataIterFree(handle)
-      disposed = true
-    }
-  }
+  override val ref: NativeResourceRef = super.register()
+
+  override val bytesAllocated: Long = 0L
 
   /**
    * reset the iterator
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala
index e20b433..d349fea 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/optimizer/SGD.scala
@@ -17,7 +17,7 @@
 
 package org.apache.mxnet.optimizer
 
-import org.apache.mxnet.{Optimizer, LRScheduler, NDArray}
+import org.apache.mxnet._
 import org.apache.mxnet.NDArrayConversions._
 
 /**
@@ -92,7 +92,13 @@ class SGD(val learningRate: Float = 0.01f, momentum: Float = 0.0f,
     if (momentum == 0.0f) {
       null
     } else {
-      NDArray.zeros(weight.shape, weight.context)
+      val s = NDArray.zeros(weight.shape, weight.context)
+      // this is created on the fly and shared between runs,
+      // we don't want it to be dispose from the scope
+      // and should be handled by the dispose
+      val scope = ResourceScope.getCurrentScope()
+      if (scope.isDefined) scope.get.remove(s)
+      s
     }
   }
 
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NativeResourceSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NativeResourceSuite.scala
new file mode 100644
index 0000000..81a9f60
--- /dev/null
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/NativeResourceSuite.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import java.lang.ref.ReferenceQueue
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.mxnet.Base.CPtrAddress
+import org.mockito.Matchers.any
+import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers, TagAnnotation}
+import org.mockito.Mockito._
+
+@TagAnnotation("resource")
+class NativeResourceSuite extends FunSuite with BeforeAndAfterAll with Matchers {
+
+  /** Accessors for NativeResourceRef's internal bookkeeping (refQ, refMap, cleaner). */
+  object TestRef  {
+    def getRefQueue: ReferenceQueue[NativeResource] = { NativeResourceRef.refQ}
+    def getRefMap: ConcurrentHashMap[NativeResourceRef, CPtrAddress]
+    = {NativeResourceRef.refMap}
+    def getCleaner: Thread = { NativeResourceRef.cleaner }
+  }
+
+  test(testName = "test native resource setup/teardown") {
+    val a = spy(NDArray.ones(Shape(2, 3)))
+    val aRef = a.ref
+
+    assert(TestRef.getRefMap.containsKey(aRef) == true)
+    // close() is expected to call dispose() and to use nativeDeAllocator
+    a.close()
+    verify(a).dispose()
+    verify(a).nativeDeAllocator
+    // NOTE: verifying a Mockito spy of aRef would prove nothing: spy() makes a copy,
+    // while the resource keeps using the original ref. Explicit release is instead
+    // checked through deregistration from refMap and the isDisposed flag below.
+
+    assert(TestRef.getRefMap.containsKey(aRef) == false,
+      "reference should be removed from refMap after calling close")
+    assert(a.isDisposed == true, "isDisposed should be set to true after calling close")
+  }
+
+  test(testName = "test dispose") {
+    val a: NDArray = spy(NDArray.ones(Shape(3, 4)))
+    val aRef = a.ref
+    // dispose() on its own, without close(), must free and deregister the resource
+    a.dispose()
+    verify(a).nativeDeAllocator
+    assert(TestRef.getRefMap.containsKey(aRef) == false,
+      "reference should be removed from refMap after calling dispose")
+    assert(a.isDisposed == true,
+      "isDisposed should be set to true after calling dispose")
+  }
+}
+
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala
new file mode 100644
index 0000000..41dfa7d
--- /dev/null
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/ResourceScopeSuite.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import java.lang.ref.ReferenceQueue
+import java.util.concurrent.ConcurrentHashMap
+
+import org.apache.mxnet.Base.CPtrAddress
+import org.apache.mxnet.ResourceScope.logger
+import org.mockito.Matchers.any
+import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}
+import org.mockito.Mockito._
+import scala.collection.mutable.HashMap
+
+class ResourceScopeSuite extends FunSuite with BeforeAndAfterAll with Matchers {
+
+  /**
+    * Minimal NativeResource for exercising ResourceScope without native memory:
+    * its own hashCode stands in for the address and its deallocator is a no-op.
+    */
+  class TestNativeResource extends NativeResource {
+    /**
+      * native Address associated with this object
+      */
+    override def nativeAddress: CPtrAddress = hashCode()
+
+    /**
+      * Function Pointer to the NativeDeAllocator of nativeAddress
+      */
+    override def nativeDeAllocator: CPtrAddress => Int = TestNativeResource.deAllocator
+
+    /** Call NativeResource.register to get the reference
+      */
+    override val ref: NativeResourceRef = super.register()
+    /**
+      * Off-Heap Bytes Allocated for this object
+      */
+    override val bytesAllocated: Long = 0
+  }
+  object TestNativeResource {
+    /** No-op deallocator: nothing native backs a TestNativeResource. */
+    def deAllocator(handle: CPtrAddress): Int = 0
+  }
+
+  /** Accessors for NativeResourceRef's internal bookkeeping (refQ, refMap, cleaner). */
+  object TestPhantomRef  {
+    def getRefQueue: ReferenceQueue[NativeResource] = { NativeResourceRef.refQ}
+    def getRefMap: ConcurrentHashMap[NativeResourceRef, CPtrAddress]
+    = {NativeResourceRef.refMap}
+    def getCleaner: Thread = { NativeResourceRef.cleaner }
+  }
+
+  test(testName = "test NDArray Auto Release") {
+    var a: NDArray = null
+    var aRef: NativeResourceRef = null
+    var b: NDArray = null
+
+    ResourceScope.using() {
+      // a stays in the inner scope; x escapes it as the returned value b
+      b = ResourceScope.using() {
+        a = NDArray.ones(Shape(3, 4))
+        aRef = a.ref
+        val x = NDArray.ones(Shape(3, 4))
+        x
+      }
+      val bRef: NativeResourceRef = b.ref
+      assert(a.isDisposed == true,
+        "objects created within scope should have isDisposed set to true")
+      assert(b.isDisposed == false,
+        "returned NativeResource should not be released")
+      assert(TestPhantomRef.getRefMap.containsKey(aRef) == false,
+        "reference of resource in Scope should be removed refMap")
+      assert(TestPhantomRef.getRefMap.containsKey(bRef) == true,
+        "reference of resource outside scope should be not removed refMap")
+    }
+    assert(b.isDisposed, "resource returned from inner scope should be released in outer scope")
+  }
+
+  test("test return object release from outer scope") {
+    var a: TestNativeResource = null
+    ResourceScope.using() {
+      a = ResourceScope.using() {
+        new TestNativeResource()
+      }
+      assert(a.isDisposed == false, "returned object should not be disposed within Using")
+    }
+    assert(a.isDisposed == true, "returned object should be disposed in the outer scope")
+  }
+
+  test(testName = "test NativeResources in returned Lists are not disposed") {
+    var ndListRet: IndexedSeq[TestNativeResource] = null
+    ResourceScope.using() {
+      ndListRet = ResourceScope.using() {
+        val ndList: IndexedSeq[TestNativeResource] =
+          IndexedSeq(new TestNativeResource(), new TestNativeResource())
+        ndList
+      }
+      ndListRet.foreach(nd => assert(nd.isDisposed == false,
+        "NativeResources within a returned collection should not be disposed"))
+    }
+    ndListRet.foreach(nd => assert(nd.isDisposed == true,
+      "NativeResources returned from inner scope should be disposed in outer scope"))
+  }
+
+  test("test native resource inside a map") {
+    var nRInKeyOfMap: HashMap[TestNativeResource, String] = null
+    var nRInValOfMap: HashMap[String, TestNativeResource] = null
+
+    ResourceScope.using() {
+      nRInKeyOfMap = ResourceScope.using() {
+        val ret = HashMap[TestNativeResource, String]()
+        ret.put(new TestNativeResource, "hello")
+        ret
+      }
+      assert(!nRInKeyOfMap.isEmpty)
+
+      nRInKeyOfMap.keysIterator.foreach(it => assert(it.isDisposed == false,
+        "NativeResources returned in Traversable should not be disposed"))
+    }
+
+    nRInKeyOfMap.keysIterator.foreach(it => assert(it.isDisposed))
+
+    ResourceScope.using() {
+      nRInValOfMap = ResourceScope.using() {
+        val ret = HashMap[String, TestNativeResource]()
+        ret.put("world!", new TestNativeResource)
+        ret
+      }
+      assert(!nRInValOfMap.isEmpty)
+      nRInValOfMap.valuesIterator.foreach(it => assert(it.isDisposed == false,
+        "NativeResources returned in Collection should not be disposed"))
+    }
+    nRInValOfMap.valuesIterator.foreach(it => assert(it.isDisposed))
+  }
+
+}
diff --git a/scala-package/examples/scripts/run_train_mnist.sh b/scala-package/examples/scripts/run_train_mnist.sh
new file mode 100755
index 0000000..ea53c1a
--- /dev/null
+++ b/scala-package/examples/scripts/run_train_mnist.sh
@@ -0,0 +1,33 @@
+#!/bin/bash
+
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+set -e
+
+MXNET_ROOT=$(cd "$(dirname "$0")/../../.."; pwd)
+echo "$MXNET_ROOT"
+CLASS_PATH=$MXNET_ROOT/scala-package/assembly/linux-x86_64-cpu/target/*:$MXNET_ROOT/scala-package/examples/target/*:$MXNET_ROOT/scala-package/examples/target/classes/lib/*:$MXNET_ROOT/scala-package/infer/target/*
+
+# MNIST data dir: 2nd argument, defaulting to scala-package/examples/mnist/ in this checkout
+DATA_PATH=${2:-$MXNET_ROOT/scala-package/examples/mnist/}
+
+java -XX:+PrintGC -Xms256M -Xmx512M -Dmxnet.traceLeakedObjects=false -cp "$CLASS_PATH" \
+        org.apache.mxnetexamples.imclassification.TrainMnist \
+        --data-dir "$DATA_PATH" \
+        --num-epochs 10000000 \
+        --batch-size 1024
\ No newline at end of file

Reply via email to