yzhliu closed pull request #11751: [MXNET-600][Scala] NDArray auto-collector
URL: https://github.com/apache/incubator-mxnet/pull/11751
This is a PR merged from a forked repository.
Because GitHub hides the original diff of a merged pull request that comes from a fork,
the diff is reproduced below for the sake of provenance:
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
index 2f79b58a52c..181b2328ddc 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
@@ -167,7 +167,7 @@ class Executor private[mxnet](private[mxnet] val handle:
ExecutorHandle,
private def getOutputs: Array[NDArray] = {
val ndHandles = ArrayBuffer[NDArrayHandle]()
checkCall(_LIB.mxExecutorOutputs(handle, ndHandles))
- ndHandles.toArray.map(new NDArray(_))
+ // addToCollector = false: these wrappers are not registered with the current
+ // NDArrayCollector, so an enclosing withScope will never dispose them.
+ // NOTE(review): presumably the Executor owns these output arrays -- confirm.
+ ndHandles.toArray.map(new NDArray(_, addToCollector = false))
}
/**
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Monitor.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/Monitor.scala
index 8e53d652fde..c8a251d03a6 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Monitor.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Monitor.scala
@@ -51,7 +51,7 @@ class Monitor(
override def invoke(name: String, arr: NDArrayHandle): Unit = {
// wrapper for executor callback
if (activated) {
- val array = new NDArray(arr, writable = false)
+ val array = new NDArray(arr, writable = false, addToCollector = false)
val elem = (step, name, statFunc(array))
queue += elem
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index c2de6ea43f2..58ab5cadd9d 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -554,11 +554,16 @@ object NDArray extends NDArrayBase {
* </b>
*/
class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
- val writable: Boolean = true) extends
WarnIfNotDisposed {
+ val writable: Boolean = true,
+ addToCollector: Boolean = true) extends
WarnIfNotDisposed {
+ if (addToCollector) {
+ NDArrayCollector.collect(this)
+ }
+
// record arrays who construct this array instance
// we use weak reference to prevent gc blocking
private[mxnet] val dependencies = mutable.HashMap.empty[Long,
WeakReference[NDArray]]
- private var disposed = false
+ @volatile private var disposed = false
def isDisposed: Boolean = disposed
def serialize(): Array[Byte] = {
diff --git
a/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayCollector.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayCollector.scala
new file mode 100644
index 00000000000..ea21cff9ebc
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArrayCollector.scala
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import org.apache.mxnet.Base.CPtrAddress
+import org.slf4j.LoggerFactory
+
+import scala.annotation.varargs
+import scala.collection.mutable
+
+/**
+ * A collector to store NDArrays.
+ * It provides a scope; NDArrays allocated in the scope can either <br />
+ * - be disposed automatically when the code block finishes, or <br />
+ * - simply be collected for future usage.
+ * <br />
+ * If the return type of scope is <em>NDArray</em> or <em>NDArrayFuncReturn</em>,
+ * the collector is smart enough NOT to collect or dispose the returned NDArray. <br />
+ * However in other cases, it is users' responsibility NOT to leak allocated NDArrays outside
+ * (e.g., store to a global variable and use later, pass to another thread, etc.) <br />
+ * Usage Example:
+ * <pre>
+ *  val a = NDArray.array(Array(-1f, 0f, 1f, 2f, 3f, 4f), shape = Shape(2, 3))
+ *  val res = NDArrayCollector.auto().withScope {
+ *    (NDArray.relu(a) + a).toArray
+ *  }
+ * </pre>
+ * In the case above, the intermediate NDArrays
+ * (created by <em>NDArray.relu</em> and <em>+</em>) will be disposed automatically. <br />
+ * User can also decide to dispose the collected NDArrays later: <br />
+ * <pre>
+ *  val collector = NDArrayCollector.manual()
+ *  val res = collector.withScope {
+ *    (NDArray.relu(a) + a).toArray
+ *  }
+ *  collector.foreach(_.dispose())
+ * </pre>
+ * For Java users (the by-name scope body is passed as a <em>Function0</em>): <br />
+ * <pre>
+ *  final NDArray a = NDArray.array(new float[]{-1f, 0f, 1f, 2f, 3f, 4f},
+ *                                  Shape.create(2, 3), Context.cpu(0));
+ *  float[] sliced = NDArrayCollector.auto().withScope(
+ *    new scala.runtime.AbstractFunction0<float[]>() {
+ *      public float[] apply() {
+ *        return a.slice(0, 1).toArray();
+ *      }
+ *    });
+ * </pre>
+ */
+object NDArrayCollector {
+  private val logger = LoggerFactory.getLogger(classOf[NDArrayCollector])
+
+  // The collector that NDArrays created on the current thread are registered with.
+  // Outside of any withScope call it is a non-collecting placeholder
+  // (doCollect = false), so NDArrays created there are not tracked at all.
+  private val currCollector = new ThreadLocal[NDArrayCollector] {
+    override def initialValue(): NDArrayCollector = new NDArrayCollector(false, false)
+  }
+
+  /**
+   * Create a collector which will dispose the collected NDArrays automatically
+   * when its <em>withScope</em> block finishes.
+   * @return an auto-disposable collector.
+   */
+  def auto(): NDArrayCollector = new NDArrayCollector(true)
+
+  /**
+   * Create a collector that allows users to dispose the collected NDArrays manually later.
+   * @return a manually-disposable collector.
+   */
+  def manual(): NDArrayCollector = new NDArrayCollector(false)
+
+  /**
+   * Collect the NDArrays into the collector of the current thread.
+   * When no <em>withScope</em> is active on this thread, this is a no-op.
+   * @param ndArray NDArrays need to be collected.
+   */
+  @varargs def collect(ndArray: NDArray*): Unit = {
+    currCollector.get().add(ndArray: _*)
+  }
+}
+
+class NDArrayCollector private(private val autoDispose: Boolean = true,
+                               private val doCollect: Boolean = true) {
+  // native ptr (handle) of the NDArray -> NDArray
+  // In some rare situations multiple NDArrays share the same native ptr;
+  // keying the map by handle prevents disposing the same native array more than once.
+  private val arrays = mutable.HashMap.empty[CPtrAddress, NDArray]
+
+  // Register NDArrays with this collector. A no-op when doCollect is false
+  // (the per-thread default collector that is current outside of any scope).
+  private def add(nd: NDArray*): Unit = {
+    if (doCollect) nd.foreach(arr => arrays.put(arr.handle, arr))
+  }
+
+  // Stop tracking an NDArray returned from this collector's scope. If this collector
+  // had collected it, hand it over to the enclosing collector `outer`, so that an
+  // enclosing scope still owns (and can dispose) it instead of it leaking.
+  // Arrays this scope never collected (e.g. allocated outside it) are left untouched.
+  private def releaseToOuter(nd: NDArray, outer: NDArrayCollector): Unit = {
+    arrays.remove(nd.handle).foreach { collected =>
+      if (outer ne this) outer.add(collected)
+    }
+  }
+
+  /**
+   * Clear the collector (forget all collected NDArrays without disposing them).
+   */
+  def clear(): Unit = {
+    arrays.clear()
+  }
+
+  /**
+   * Iterate over the collected NDArrays and apply the user-defined function to each NDArray.
+   * @param f the function that is applied for its side-effect to every NDArray.
+   *          The result of function <em>f</em> is discarded.
+   */
+  def foreach(f: NDArray => Unit): Unit = {
+    arrays.values.foreach(f(_))
+  }
+
+  /**
+   * @return how many unique NDArrays are collected.
+   */
+  def size: Int = arrays.size
+
+  /**
+   * Create a code scope, NDArrays allocated within this scope will be collected.
+   * The collected NDArrays will be either <br />
+   * - disposed automatically when the code block finishes (when using <em>auto</em>) or <br />
+   * - stored for later access (when using <em>manual</em>) <br />
+   * If the return type of scope is <em>NDArray</em> or <em>NDArrayFuncReturn</em>,
+   * it is smart enough NOT to collect or dispose the returned NDArray. <br />
+   * When scopes are nested, a returned NDArray that this scope had collected is handed
+   * over to the enclosing scope's collector, so the enclosing scope can still dispose it. <br />
+   * However in other cases, it is users' responsibility NOT to leak allocated NDArrays outside.
+   * The enclosing collector is always restored, even if <em>codeBlock</em> or a dispose throws.
+   * @param codeBlock code block to be executed within the scope.
+   * @tparam T return type of the function <em>codeBlock</em>.
+   * @return The result of function <em>codeBlock</em>.
+   */
+  def withScope[T](codeBlock: => T): T = {
+    val outer = NDArrayCollector.currCollector.get()
+    NDArrayCollector.currCollector.set(this)
+    try {
+      val ret = codeBlock
+      ret match {
+        case ndRet: NDArray =>
+          releaseToOuter(ndRet, outer)
+        case ndarrays: NDArrayFuncReturn =>
+          ndarrays.arr.foreach(nd => releaseToOuter(nd, outer))
+        case _ => // do nothing
+      }
+      ret
+    } finally {
+      try {
+        if (autoDispose) {
+          foreach(_.dispose())
+          clear()
+        }
+      } finally {
+        // Restore unconditionally: if a dispose() threw and this were skipped, the
+        // thread would keep registering new NDArrays with this finished collector.
+        NDArrayCollector.currCollector.set(outer)
+      }
+    }
+  }
+}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Operator.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/Operator.scala
index 6630d5ff53d..f2abe5e4515 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Operator.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Operator.scala
@@ -72,9 +72,9 @@ abstract class CustomOp {
val tensors = (0 until 5).toArray.map( x => ArrayBuffer[NDArray]() )
for (i <- 0 until numNdarray) {
if (tags(i) == 1 || tags(i) == 4) {
- tensors(tags(i)) += new NDArray(ndarraies(i), writable = true)
+ tensors(tags(i)) += new NDArray(ndarraies(i), writable = true,
addToCollector = false)
} else {
- tensors(tags(i)) += new NDArray(ndarraies(i), writable = false)
+ tensors(tags(i)) += new NDArray(ndarraies(i), writable = false,
addToCollector = false)
}
}
val reqEnum = Array("null", "write", "inplace", "add")
diff --git
a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
index 70c64877887..10461315c19 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
@@ -161,14 +161,15 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)],
*/
private def _padData(ndArray: NDArray): NDArray = {
val padNum = cursor + dataBatchSize - numData
- val newArray = NDArray.zeros(ndArray.slice(0, dataBatchSize).shape)
- val batch = ndArray.slice(cursor, numData)
- val padding = ndArray.slice(0, padNum)
- newArray.slice(0, dataBatchSize - padNum).set(batch).dispose()
- newArray.slice(dataBatchSize - padNum,
dataBatchSize).set(padding).dispose()
- batch.dispose()
- padding.dispose()
- newArray
+ // Shape of one full batch: dataBatchSize along the first axis, the remaining
+ // axes copied from ndArray.
+ val shape = Shape(dataBatchSize) ++ ndArray.shape.slice(1,
ndArray.shape.size)
+ val newArray = NDArray.zeros(shape)
+ // Fill the first (dataBatchSize - padNum) rows with rows [cursor, numData) of
+ // ndArray, then the last padNum rows with rows [0, padNum) from its start.
+ // The temporary slices are collected by this auto scope and disposed when it
+ // ends; newArray was allocated before the scope, so the scope never disposes it.
+ // NOTE(review): relies on slice() sharing storage with newArray so that set()
+ // writes into the returned batch -- confirm against NDArray.slice.
+ NDArrayCollector.auto().withScope {
+ val batch = ndArray.slice(cursor, numData)
+ val padding = ndArray.slice(0, padNum)
+ newArray.slice(0, dataBatchSize - padNum).set(batch)
+ newArray.slice(dataBatchSize - padNum, dataBatchSize).set(padding)
+ newArray
+ }
}
private def _getData(data: IndexedSeq[(String, NDArray)]):
IndexedSeq[NDArray] = {
diff --git
a/scala-package/core/src/test/scala/org/apache/mxnet/NDArrayCollectorSuite.scala
b/scala-package/core/src/test/scala/org/apache/mxnet/NDArrayCollectorSuite.scala
new file mode 100644
index 00000000000..f361ee1e4ea
--- /dev/null
+++
b/scala-package/core/src/test/scala/org/apache/mxnet/NDArrayCollectorSuite.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite, Matchers}
+
+class NDArrayCollectorSuite extends FunSuite with BeforeAndAfterAll with Matchers {
+
+  test("auto dispose") {
+    val a = NDArray.array(Array(-1f, 0f, 1f, 2f, 3f, 4f), shape = Shape(2, 3))
+    var b, c: NDArray = null
+
+    val res = NDArrayCollector.auto().withScope {
+      b = NDArray.relu(a) // [0, 0, 1, 2, 3, 4]
+      c = a + b // [-1, 0, 2, 4, 6, 8]
+      c.slice(0, 1)
+    }
+
+    assert(b.isDisposed)
+    assert(c.isDisposed)
+    assert(!res.isDisposed) // smart enough not to dispose the returned NDArray
+    assert(!a.isDisposed) // allocated outside any scope, so it is never collected
+
+    assert(res.toArray === Array(-1f, 0f, 2f))
+
+    res.dispose()
+    a.dispose()
+  }
+
+  test("manually dispose") {
+    val a = NDArray.array(Array(-1f, 0f, 1f, 2f, 3f, 4f), shape = Shape(2, 3))
+    var b, c: NDArray = null
+
+    val collector = NDArrayCollector.manual()
+    val res = collector.withScope {
+      b = NDArray.relu(a) // [0, 0, 1, 2, 3, 4]
+      c = a + b // [-1, 0, 2, 4, 6, 8]
+      c.slice(0, 1)
+    }
+
+    assert(res.toArray === Array(-1f, 0f, 2f))
+
+    assert(collector.size === 2) // smart enough not to collect the returned NDArray
+    assert(!b.isDisposed)
+    assert(!c.isDisposed)
+    assert(!res.isDisposed)
+
+    collector.foreach(_.dispose())
+    assert(b.isDisposed)
+    assert(c.isDisposed)
+    assert(!res.isDisposed)
+
+    collector.clear()
+    assert(collector.size === 0)
+
+    res.dispose()
+    a.dispose()
+  }
+
+  test("auto dispose when the scope throws") {
+    val a = NDArray.array(Array(-1f, 0f, 1f, 2f, 3f, 4f), shape = Shape(2, 3))
+    var b: NDArray = null
+
+    intercept[IllegalStateException] {
+      NDArrayCollector.auto().withScope {
+        b = NDArray.relu(a)
+        throw new IllegalStateException("boom")
+      }
+    }
+
+    assert(b.isDisposed) // collected arrays are disposed even on failure
+    assert(!a.isDisposed)
+
+    a.dispose()
+  }
+}
diff --git
a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
index caf7e56cd3b..d944a8d049c 100644
--- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
+++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
@@ -612,7 +612,7 @@ extern "C" void KVStoreUpdaterCallbackFunc
// find java NDArray constructor
jclass ndObjClass = env->FindClass("org/apache/mxnet/NDArray");
- jmethodID ndObjConstructor = env->GetMethodID(ndObjClass, "<init>", "(JZ)V");
+ jmethodID ndObjConstructor = env->GetMethodID(ndObjClass, "<init>",
"(JZZ)V");
jobject ndRecv = env->NewObject(ndObjClass, ndObjConstructor,
reinterpret_cast<jlong>(recv), true);
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services