This is an automated email from the ASF dual-hosted git repository.
reminisce pushed a commit to branch numpy
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/numpy by this push:
new 874695d [Numpy] Java/Scala modification (#14625)
874695d is described below
commit 874695df9c36c66ec3eecb9cc639108f3469cd26
Author: Yizhi Liu <[email protected]>
AuthorDate: Fri Apr 5 10:21:55 2019 -0700
[Numpy] Java/Scala modification (#14625)
* modify jni to support 0 dim/shape
* fix transpose axes default value
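
In user-facing terms, a sketch of what "0 dim/shape" support means (the Shape
values below are illustrative, not taken from this commit):

    import org.apache.mxnet.Shape

    // With numpy semantics, a 0-dim shape is a valid scalar shape, while
    // "shape not yet known" is reported by the backend as ndim == -1.
    val scalarShape = Shape()   // 0-dim: a scalar
    val emptyVector = Shape(0)  // 1-dim with zero elements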
---
.../src/main/scala/org/apache/mxnet/Executor.scala | 142 +++++-------
.../src/main/scala/org/apache/mxnet/LibInfo.scala | 32 ++-
.../src/main/scala/org/apache/mxnet/NDArray.scala | 10 +-
.../main/scala/org/apache/mxnet/NumpyScope.scala | 58 +++++
.../src/main/scala/org/apache/mxnet/Symbol.scala | 38 +++-
.../scala/org/apache/mxnet/NumpyScopeSuite.scala | 34 +++
.../org/apache/mxnet/utils/CToScalaUtils.scala | 3 +-
.../main/native/org_apache_mxnet_native_c_api.cc | 239 ++++++++++++++++++---
.../main/native/org_apache_mxnet_native_c_api.h | 32 +++
src/operator/tensor/matrix_op-inl.h | 6 +-
10 files changed, 452 insertions(+), 142 deletions(-)
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
index aec4402..f51424b 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
@@ -61,10 +61,6 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
protected var monitorCallback: MXMonitorCallback = null
private val logger: Logger = LoggerFactory.getLogger(classOf[Executor])
- private[mxnet] var ownsArgArrays = false
- private[mxnet] var ownsGradArrays = false
- private[mxnet] var ownsAuxArrays = false
-
override def nativeAddress: CPtrAddress = handle
override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxExecutorFree
// cannot determine the off-heap size of this object
@@ -75,17 +71,12 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
if (!super.isDisposed) {
super.dispose()
outputs.foreach(o => o.dispose())
- // Symbol.bind clones symbol when creating the executor so we need to dispose of the clone
- symbol.dispose()
- if (ownsArgArrays && argArrays != null) {argArrays.foreach(a => a.dispose())}
- if (ownsGradArrays && gradArrays != null) {gradArrays.foreach(
+ if (argArrays != null) {argArrays.foreach(a => a.dispose())}
+ if (gradArrays != null) {gradArrays.foreach(
// Symbol will sometimes fill this with nulls so we've got to check the elements too
a => if (a != null) {a.dispose()})
}
- if (ownsAuxArrays && auxArrays != null) {auxArrays.foreach(a => a.dispose())}
- if (_argDict != null) {_argDict.foreach(a => a._2.dispose())}
- if (_gradDict != null) {_gradDict.foreach(a => a._2.dispose())}
- if (_auxDict != null) {_auxDict.foreach(a => a._2.dispose())}
+ if (auxArrays != null) {auxArrays.foreach(a => a.dispose())}
}
}
@@ -104,95 +95,58 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
*/
def reshape(partialShaping: Boolean = false, allowUpSizing: Boolean = false,
kwargs: Map[String, Shape]): Executor = {
- var setArgOwner = false
- var setAuxOwner = false
- var setGradOwner = false
- val (argShapes, _, auxShapes) = this.symbol.inferShape(kwargs)
- // TODO: more precise error message should be provided by backend
- require(argShapes != null, "Shape inference failed." +
- s"Known shapes are $kwargs for symbol arguments
${symbol.listArguments()} " +
- s"and aux states ${symbol.listAuxiliaryStates()}")
-
- var newArgDict = Map[String, NDArray]()
- var newGradDict = Map[String, NDArray]()
- this.symbol.listArguments().zipWithIndex.foreach { case (name, i) =>
- val newShape = argShapes(i)
- val arr = this.argArrays(i)
- val dArr = if (this.gradArrays == null) null else this.gradArrays(i)
- if (partialShaping || kwargs.contains(name) || newShape.equals(arr.shape)) {
- if (newShape.product > arr.shape.product) {
- require(allowUpSizing, s"New shape of arg:$name larger than original. " +
- "First making a big executor and then down sizing it " +
- "is more efficient than the reverse." +
- "If you really want to up size, set allowUpSizing = true " +
- "to enable allocation of new arrays.")
- newArgDict = newArgDict + (name -> NDArray.empty(newShape, arr.context, arr.dtype))
- setArgOwner = true
- if (dArr != null) {
- newGradDict = newGradDict + (name -> NDArray.empty(newShape, dArr.context, dArr.dtype))
- setGradOwner = true
- }
- } else {
- newArgDict = newArgDict + (name -> arr.reshape(newShape.toArray))
- if (dArr != null) {
- newGradDict = newGradDict + (name -> dArr.reshape(newShape.toArray))
- }
- }
- } else {
- throw new AssertionError(s"Shape of unspecified array arg:$name changed." +
- "This can cause the new executor to not share parameters " +
- "with the old one. Please check for error in network." +
- "If this is intended, set partialShaping = true to suppress this warning.")
- }
- }
+ val providedArgShapeNames = kwargs.keys
+ val providedArgShapeData = kwargs.values.flatMap(_.toVector)
+ val providedArgShapeIdx = kwargs.values.scanLeft(0)((sum, shape) => sum + shape.size)
- var newAuxDict = Map[String, NDArray]()
- val zip3 = (this.symbol.listAuxiliaryStates(), auxShapes, this.auxArrays).zipped
- zip3.foreach { case (name, newShape, arr) =>
- if (partialShaping || newShape.equals(arr.shape)) {
- if (newShape.product > arr.shape.product) {
- require(allowUpSizing, s"New shape of aux:$name larger than original. " +
- "First making a big executor and then down sizing it " +
- "is more efficient than the reverse." +
- "If you really want to up size, set allowUpSizing = true " +
- "to enable allocation of new arrays.")
- newAuxDict = newAuxDict + (name -> NDArray.empty(newShape, arr.context))
- setAuxOwner = true
- } else {
- newAuxDict = newAuxDict + (name -> arr.reshape(newShape.toArray))
- }
- } else {
- throw new AssertionError(s"Shape of unspecified array aux:$name changed." +
- "This can cause the new executor to not share parameters " +
- "with the old one. Please check for error in network." +
- "If this is intended, set partialShaping = true to suppress
this warning.")
- }
+ val ctxMapKeys = if (_group2ctx != null) _group2ctx.keys.toArray else Array.empty[String]
+ val ctxMapDevTypes = if (_group2ctx != null) {
+ _group2ctx.values.map(_.deviceTypeid).toArray
+ } else {
+ Array.empty[Int]
}
- val reshapedExecutor = if (this._gradsReq.isInstanceOf[Seq[_]]) {
- this.symbol.bind(this._ctx,
- newArgDict,
- newGradDict,
- this._gradsReq.asInstanceOf[Seq[String]],
- newAuxDict,
- this._group2ctx,
- this)
+ val ctxMapDevIds = if (_group2ctx != null) {
+ _group2ctx.values.map(_.deviceId).toArray
} else {
- this.symbol.bind(this._ctx,
- newArgDict,
- newGradDict,
- this._gradsReq.asInstanceOf[Map[String, String]],
- newAuxDict,
- this._group2ctx,
- this)
+ Array.empty[Int]
}
- // This method has created new NDArrays that will need to be managed by the new Executor
- if (setArgOwner) reshapedExecutor.ownsArgArrays = true
- if (setGradOwner) reshapedExecutor.ownsGradArrays = true
- if (setAuxOwner) reshapedExecutor.ownsAuxArrays = true
+ val inArgs = ArrayBuffer.empty[NDArrayHandle]
+ val argGrads = ArrayBuffer.empty[NDArrayHandle]
+ val auxStates = ArrayBuffer.empty[NDArrayHandle]
+ val outHandle = new ExecutorHandleRef()
+
+ checkCall(_LIB.mxExecutorReshape(
+ if (partialShaping) 1 else 0,
+ if (allowUpSizing) 1 else 0,
+ _ctx.deviceTypeid,
+ _ctx.deviceId,
+ ctxMapKeys.toArray,
+ ctxMapDevTypes.toArray,
+ ctxMapDevIds.toArray,
+ providedArgShapeNames.toArray,
+ providedArgShapeData.toArray,
+ providedArgShapeIdx.toArray,
+ inArgs,
+ argGrads,
+ auxStates,
+ this.handle,
+ outHandle))
+
+ val argArrays = inArgs.map(new NDArray(_)).toArray
+ val gradArrays = argGrads.map(handle => if (handle == 0) null else new NDArray(handle)).toArray
+ val auxArrays = auxStates.map(new NDArray(_)).toArray
- reshapedExecutor
+ val executor = new Executor(outHandle.value, this.symbol)
+ executor._ctx = this._ctx
+ executor._gradsReq = this._gradsReq
+ executor._group2ctx = this._group2ctx
+ executor.argArrays = argArrays
+ executor.gradArrays = gradArrays
+ executor.auxArrays = auxArrays
+ executor
}
/**
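
For reference, a usage sketch of the reworked reshape; the argument name
"data" and the helper are illustrative, real keys come from the bound symbol:

    import org.apache.mxnet.{Executor, Shape}

    // Hypothetical helper: rebind an executor to a new batch size. reshape
    // now delegates to the backend via mxExecutorReshape, and the returned
    // Executor owns the freshly created arg/grad/aux arrays.
    def withBatchSize(exec: Executor, batch: Int): Executor =
      exec.reshape(kwargs = Map("data" -> Shape(batch, 3, 224, 224)))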
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
index 40fc095..aba6185 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
@@ -188,6 +188,23 @@ private[mxnet] class LibInfo {
grads: Array[NDArrayHandle]): Int
@native def mxExecutorPrint(handle: ExecutorHandle, debugStr: RefString): Int
@native def mxExecutorSetMonitorCallback(handle: ExecutorHandle, callback: MXMonitorCallback): Int
+ // scalastyle:off parameterNum
+ @native def mxExecutorReshape(partialShaping: Int,
+ allowUpSizing: Int,
+ devType: Int,
+ devId: Int,
+ mapKeys: Array[String],
+ mapDevTypes: Array[Int],
+ mapDevIds: Array[Int],
+ providedArgShapeNames: Array[String],
+ providedArgShapeData: Array[Int],
+ providedArgShapeIdx: Array[Int],
+ inArgs: ArrayBuffer[NDArrayHandle],
+ argGrads: ArrayBuffer[NDArrayHandle],
+ auxStates: ArrayBuffer[NDArrayHandle],
+ sharedExec: ExecutorHandle,
+ out: ExecutorHandleRef): Int
+ // scalastyle:on parameterNum
// Symbols
@native def mxSymbolListAtomicSymbolCreators(symbolList: ListBuffer[SymbolHandle]): Int
@@ -240,11 +257,20 @@ private[mxnet] class LibInfo {
numArgs: MXUint,
keys: Array[String],
argIndPtr: Array[MXUint],
- argShapeData: Array[MXUint],
+ argShapeData: Array[Int],
inShapeData: ListBuffer[Array[Int]],
outShapeData: ListBuffer[Array[Int]],
auxShapeData: ListBuffer[Array[Int]],
complete: RefInt): Int
+ @native def mxSymbolInferShapePartial(handle: SymbolHandle,
+ numArgs: MXUint,
+ keys: Array[String],
+ argIndPtr: Array[MXUint],
+ argShapeData: Array[Int],
+ inShapeData: ListBuffer[Array[Int]],
+ outShapeData: ListBuffer[Array[Int]],
+ auxShapeData: ListBuffer[Array[Int]],
+ complete: RefInt): Int
@native def mxSymbolGetOutput(handle: SymbolHandle, index: Int, out: SymbolHandleRef): Int
@native def mxSymbolSaveToJSON(handle: SymbolHandle, out: RefString): Int
@native def mxSymbolCreateFromJSON(json: String, handle: SymbolHandleRef): Int
@@ -322,4 +348,8 @@ private[mxnet] class LibInfo {
@native def mxSetProfilerConfig(keys: Array[String], vals: Array[String]): Int
@native def mxSetProfilerState(state: Int): Int
@native def mxDumpProfile(finished: Int): Int
+
+ // Numpy
+ @native def mxIsNumpyCompatible(compatible: RefInt): Int
+ @native def mxSetIsNumpyCompatible(isNpComp: Int, prev: RefInt): Int
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index ab42265..849f456 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -1274,11 +1274,15 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
* @return an array representing shape of current ndarray
*/
def shape: Shape = {
- val ndim = new MXUintRef
+ val ndim = new RefInt
val data = ArrayBuffer[Int]()
checkCall(_LIB.mxNDArrayGetShape(handle, ndim, data))
- require(ndim.value == data.length, s"ndim=$ndim, while len(data)=${data.length}")
- Shape(data)
+ if (ndim.value == -1) {
+ null
+ } else {
+ require(ndim.value == data.length, s"ndim=$ndim, while len(data)=${data.length}")
+ Shape(data)
+ }
}
// Get size of current NDArray.
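
Note for callers: shape can now return null when the backend reports
ndim == -1. A defensive pattern, as a sketch (shapeOrElse is a hypothetical
helper name):

    import org.apache.mxnet.{NDArray, Shape}

    // Treat a null shape as "not inferred yet" rather than an error.
    def shapeOrElse(arr: NDArray, fallback: => Shape): Shape =
      Option(arr.shape).getOrElse(fallback)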
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NumpyScope.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NumpyScope.scala
new file mode 100644
index 0000000..ec366ea
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NumpyScope.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import org.apache.mxnet.Base._
+
+object NumpyScope {
+ def setNumpyCompatible(isNpComp: Boolean): Boolean = {
+ val prev = new RefInt()
+ checkCall(_LIB.mxSetIsNumpyCompatible(if (isNpComp) 1 else 0, prev))
+ if (prev.value != 0) true else false
+ }
+
+ def isNumpyCompatible: Boolean = {
+ val curr = new RefInt
+ checkCall(_LIB.mxIsNumpyCompatible(curr))
+ if (curr.value != 0) true else false
+ }
+
+ def enableNumpyCompatible: NumpyScope = {
+ new NumpyScope(true)
+ }
+
+
+ def disableNumpyCompatible: NumpyScope = {
+ new NumpyScope(false)
+ }
+}
+
+class NumpyScope(var isCompatible: Boolean) {
+ private var prev: Boolean = false
+
+ def withScope[T](body: => T): T = {
+ prev = NumpyScope.setNumpyCompatible(isCompatible)
+ try {
+ body
+ } finally {
+ if (prev != isCompatible) {
+ NumpyScope.setNumpyCompatible(prev)
+ }
+ }
+ }
+}
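
Usage sketch for the class above: withScope flips the global flag for the
duration of the block and restores the previous value on exit, even if the
body throws:

    import org.apache.mxnet.NumpyScope

    NumpyScope.enableNumpyCompatible.withScope {
      assert(NumpyScope.isNumpyCompatible)
    }
    // outside the block the previous global setting is back in effect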
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
index 821e04f..808a23a 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
@@ -260,17 +260,45 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends NativeReso
def inferShape(keys: Array[String], indPtr: Array[Int], values: Array[Int])
: (IndexedSeq[Shape], IndexedSeq[Shape], IndexedSeq[Shape]) = {
+ val res = inferShapeImpl(partial = false, keys, indPtr, values)
+ if (res._2 == null) {
+ val (argShapes, _, _) = inferShapeImpl(partial = true, keys, indPtr, values)
+ val argNames = listArguments()
+ val unknown = (argNames zip argShapes).map { case (name, shape) =>
+ val shapeIsNone = if (NumpyScope.isNumpyCompatible) {
+ shape == null || shape.toVector.contains(-1)
+ } else {
+ shape == null || shape.toVector.contains(0)
+ }
+ if (shapeIsNone) s"$name: $shape" else ""
+ }
+ logger.warn("Cannot decide shape for the following arguments. " +
+ "Consider providing them as input: \n\t{}",
+ unknown.filter(_ != "").mkString("\n\t"))
+ }
+ res
+ }
+
+ private def inferShapeImpl(partial: Boolean,
+ keys: Array[String],
+ indPtr: Array[Int],
+ values: Array[Int])
+ : (IndexedSeq[Shape], IndexedSeq[Shape], IndexedSeq[Shape]) = {
val argShapeData = ListBuffer.empty[Array[Int]]
val outShapeData = ListBuffer.empty[Array[Int]]
val auxShapeData = ListBuffer.empty[Array[Int]]
val complete = new RefInt
-
- checkCall(_LIB.mxSymbolInferShape(handle, indPtr.length - 1, keys, indPtr, values,
- argShapeData, outShapeData, auxShapeData, complete))
+ if (partial) {
+ checkCall(_LIB.mxSymbolInferShapePartial(handle, indPtr.length - 1, keys, indPtr, values,
+ argShapeData, outShapeData, auxShapeData, complete))
+ } else {
+ checkCall(_LIB.mxSymbolInferShape(handle, indPtr.length - 1, keys, indPtr, values,
+ argShapeData, outShapeData, auxShapeData, complete))
+ }
if (complete.value != 0) {
(argShapeData.map(s => Shape(s)).toIndexedSeq,
- outShapeData.map(s => Shape(s)).toIndexedSeq,
- auxShapeData.map(s => Shape(s)).toIndexedSeq)
+ outShapeData.map(s => Shape(s)).toIndexedSeq,
+ auxShapeData.map(s => Shape(s)).toIndexedSeq)
} else {
(null, null, null)
}
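
The unknown-shape test used in the warning above, restated as a standalone
sketch (shapeIsUnknown is a hypothetical name):

    import org.apache.mxnet.{NumpyScope, Shape}

    // -1 marks an unknown dimension in numpy-compatible mode; the legacy
    // convention uses 0 for the same purpose.
    def shapeIsUnknown(shape: Shape): Boolean =
      if (NumpyScope.isNumpyCompatible) shape == null || shape.toVector.contains(-1)
      else shape == null || shape.toVector.contains(0)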
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NumpyScopeSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NumpyScopeSuite.scala
new file mode 100644
index 0000000..bf6627a
--- /dev/null
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/NumpyScopeSuite.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+class NumpyScopeSuite extends FunSuite with BeforeAndAfterAll {
+ test("compatible") {
+ NumpyScope.enableNumpyCompatible.withScope {
+ assert(NumpyScope.isNumpyCompatible === true)
+ }
+ }
+
+ test("incompatible") {
+ NumpyScope.disableNumpyCompatible.withScope {
+ assert(NumpyScope.isNumpyCompatible === false)
+ }
+ }
+}
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
index 57c4cfb..12d797f 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
@@ -47,7 +47,8 @@ private[mxnet] object CToScalaUtils {
case "double" | "doubleorNone" => types("double")
case "string" => "String"
case "boolean" | "booleanorNone" => types("bool")
- case "tupleof<float>" | "tupleof<double>" | "tupleof<>" | "ptr" | "" =>
"Any"
+ case "tupleof<int>" | "tupleof<float>" | "tupleof<double>" |
"tupleof<intorNone>" |
+ "tupleof<>" | "ptr" | "" => "Any"
case default => throw new IllegalArgumentException(
s"Invalid type for args: $default\nString argType: $argType\nargName:
$argName")
}
diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
index 33e4cca..678dfc1 100644
--- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
+++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
@@ -354,8 +354,8 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayLoadFromRawBytes
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetShape
(JNIEnv *env, jobject obj, jlong ndArrayPtr, jobject ndimRef, jobject dataBuf) {
- mx_uint ndim;
- const mx_uint *pdata;
+ int ndim;
+ const int *pdata;
int ret = MXNDArrayGetShape(reinterpret_cast<NDArrayHandle>(ndArrayPtr),
&ndim, &pdata);
// fill dataBuf
@@ -365,7 +365,7 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetShape
jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
jmethodID arrayAppend = env->GetMethodID(arrayClass,
"$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
- for (size_t i = 0; i < ndim; ++i) {
+ for (int i = 0; i < ndim; ++i) {
jobject data = env->NewObject(integerClass, newInteger, pdata[i]);
env->CallObjectMethod(dataBuf, arrayAppend, data);
env->DeleteLocalRef(data);
@@ -892,6 +892,118 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorBackward
return ret;
}
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorReshape
+ (JNIEnv * env, jobject obj,
+ jint partialReshaping, jint allowUpSizing, jint devType, jint devId,
+ jobjectArray jmapKeys, jintArray jmapDevTypes, jintArray jmapDevIds,
+ jobjectArray jprovidedArgShapeNames, jintArray jprovidedArgShapeData, jintArray jprovidedArgShapeIdx,
+ jobject jrefInArgs, jobject jrefArgGrads, jobject jrefAuxStates,
+ jlong jsharedExec, jobject jrefOut) {
+ CHECK(jmapKeys != NULL);
+ CHECK(jprovidedArgShapeNames != NULL);
+
+ int numMapKeys = env->GetArrayLength(jmapKeys);
+ jint *mapDevTypes = env->GetIntArrayElements(jmapDevTypes, NULL);
+ jint *mapDevIds = env->GetIntArrayElements(jmapDevIds, NULL);
+ const char **mapKeys = NULL;
+ if (numMapKeys > 0) {
+ mapKeys = new const char*[numMapKeys];
+ for (int i = 0; i < numMapKeys; ++i) {
+ jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jmapKeys, i));
+ mapKeys[i] = env->GetStringUTFChars(jkey, 0);
+ env->DeleteLocalRef(jkey);
+ }
+ }
+
+ int numProvidedArgShapes = env->GetArrayLength(jprovidedArgShapeNames);
+ jint *providedArgShapeData = env->GetIntArrayElements(jprovidedArgShapeData, NULL);
+ jint *providedArgShapeIdx = env->GetIntArrayElements(jprovidedArgShapeIdx, NULL);
+ const char **providedArgShapeNames = NULL;
+ if (numProvidedArgShapes > 0) {
+ providedArgShapeNames = new const char*[numProvidedArgShapes];
+ for (int i = 0; i < numProvidedArgShapes; ++i) {
+ jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jprovidedArgShapeNames, i));
+ providedArgShapeNames[i] = env->GetStringUTFChars(jkey, 0);
+ env->DeleteLocalRef(jkey);
+ }
+ }
+
+ mx_uint numInArgs = 0;
+ NDArrayHandle *inArgs;
+ NDArrayHandle *argGrads;
+
+ mx_uint numAuxStates = 0;
+ NDArrayHandle *auxStates;
+
+ ExecutorHandle out;
+
+ int ret = MXExecutorReshape(partialReshaping,
+ allowUpSizing,
+ devType,
+ devId,
+ static_cast<mx_uint>(numMapKeys),
+ mapKeys,
+ static_cast<const int*>(mapDevTypes),
+ static_cast<const int*>(mapDevIds),
+ static_cast<const mx_uint>(numProvidedArgShapes),
+ providedArgShapeNames,
+ static_cast<const int*>(providedArgShapeData),
+ reinterpret_cast<const
mx_uint*>(providedArgShapeIdx),
+ &numInArgs,
+ &inArgs,
+ &argGrads,
+ &numAuxStates,
+ &auxStates,
+ reinterpret_cast<ExecutorHandle>(jsharedExec),
+ &out);
+
+ jclass longCls = env->FindClass("java/lang/Long");
+ jmethodID newLong = env->GetMethodID(longCls, "<init>", "(J)V");
+
+ jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
+ jmethodID arrayAppend = env->GetMethodID(arrayClass,
+ "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
+
+ for (size_t i = 0; i < numInArgs; ++i) {
+ jobject inArg = env->NewObject(longCls, newLong, inArgs[i]);
+ env->CallObjectMethod(jrefInArgs, arrayAppend, inArg);
+ env->DeleteLocalRef(inArg);
+
+ jobject argGrad = env->NewObject(longCls, newLong, argGrads[i]);
+ env->CallObjectMethod(jrefArgGrads, arrayAppend, argGrad);
+ env->DeleteLocalRef(argGrad);
+ }
+
+ for (size_t i = 0; i < numAuxStates; ++i) {
+ jobject auxState = env->NewObject(longCls, newLong, auxStates[i]);
+ env->CallObjectMethod(jrefAuxStates, arrayAppend, auxState);
+ env->DeleteLocalRef(auxState);
+ }
+
+ SetLongField(env, jrefOut, reinterpret_cast<jlong>(out));
+
+ // release allocated memory
+ for (int i = 0; i < numMapKeys; i++) {
+ jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jmapKeys, i));
+ env->ReleaseStringUTFChars(jkey, mapKeys[i]);
+ env->DeleteLocalRef(jkey);
+ }
+ if (mapKeys != NULL) {
+ delete[] mapKeys;
+ }
+
+ for (int i = 0; i < numProvidedArgShapes; i++) {
+ jstring jkey = reinterpret_cast<jstring>(env->GetObjectArrayElement(jprovidedArgShapeNames, i));
+ env->ReleaseStringUTFChars(jkey, providedArgShapeNames[i]);
+ env->DeleteLocalRef(jkey);
+ }
+ if (providedArgShapeNames != NULL) {
+ delete[] providedArgShapeNames;
+ }
+
+ return ret;
+}
+
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorPrint
(JNIEnv * env, jobject obj, jlong ptr, jobject debugStr) {
const char *retDebugStr;
@@ -1530,23 +1642,26 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolCreateFromFile
int FillSymbolInferShape
(JNIEnv *env, jmethodID listAppend, jobject joutData,
- mx_uint shapeSize, const mx_uint *shapeNdim, const mx_uint **shapeData) {
- for (size_t i = 0; i < shapeSize; ++i) {
- jintArray jshape = env->NewIntArray(shapeNdim[i]);
- if (jshape == NULL) {
- // TODO(Yizhi): out of memory error thrown, return a specific error code ?
- return -1;
+ int shapeSize, const int *shapeNdim, const int **shapeData) {
+ for (int i = 0; i < shapeSize; ++i) {
+ jintArray jshape = NULL;
+ if (shapeNdim[i] >= 0) {
+ jshape = env->NewIntArray(shapeNdim[i]);
+ if (jshape == NULL) {
+ // TODO(Yizhi): out of memory error thrown, return a specific error code ?
+ return -1;
+ }
+ env->SetIntArrayRegion(jshape, 0, shapeNdim[i], reinterpret_cast<const
jint *>(shapeData[i]));
}
- env->SetIntArrayRegion(jshape, 0, shapeNdim[i], reinterpret_cast<const
jint *>(shapeData[i]));
env->CallObjectMethod(joutData, listAppend, jshape);
env->DeleteLocalRef(jshape);
}
return 0;
}
-JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolInferShape
- (JNIEnv *env, jobject obj, jlong symbolPtr, jint jnumArgs, jobjectArray jkeys,
- jintArray jargIndPtr, jintArray jargShapeData,
- jobject jinShapeData, jobject joutShapeData, jobject jauxShapeData, jobject jcomplete) {
+
+int SymbolInferShapeHelper(JNIEnv *env, jobject obj, jlong symbolPtr, jint jnumArgs, jobjectArray jkeys,
+ jintArray jargIndPtr, jintArray jargShapeData, jobject jinShapeData,
+ jobject joutShapeData, jobject jauxShapeData, jobject jcomplete, bool partial) {
const char **keys = NULL;
if (jkeys != NULL) {
keys = new const char *[jnumArgs];
@@ -1559,36 +1674,55 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolInferShape
}
mx_uint inShapeSize;
- const mx_uint *inShapeNdim;
- const mx_uint **inShapeData;
+ const int *inShapeNdim;
+ const int **inShapeData;
mx_uint outShapeSize;
- const mx_uint *outShapeNdim;
- const mx_uint **outShapeData;
+ const int *outShapeNdim;
+ const int **outShapeData;
mx_uint auxShapeSize;
- const mx_uint *auxShapeNdim;
- const mx_uint **auxShapeData;
+ const int *auxShapeNdim;
+ const int **auxShapeData;
int complete;
jint *argIndPtr = env->GetIntArrayElements(jargIndPtr, NULL);
jint *argShapeData = env->GetIntArrayElements(jargShapeData, NULL);
- int ret = MXSymbolInferShape(reinterpret_cast<SymbolHandle>(symbolPtr),
- static_cast<mx_uint>(jnumArgs),
- keys,
- reinterpret_cast<const mx_uint *>(argIndPtr),
- reinterpret_cast<const mx_uint *>(argShapeData),
- &inShapeSize,
- &inShapeNdim,
- &inShapeData,
- &outShapeSize,
- &outShapeNdim,
- &outShapeData,
- &auxShapeSize,
- &auxShapeNdim,
- &auxShapeData,
- &complete);
+ int ret;
+ if (!partial) {
+ ret = MXSymbolInferShape(reinterpret_cast<SymbolHandle>(symbolPtr),
+ static_cast<mx_uint>(jnumArgs),
+ keys,
+ reinterpret_cast<mx_uint *>(argIndPtr),
+ reinterpret_cast<const int *>(argShapeData),
+ &inShapeSize,
+ &inShapeNdim,
+ &inShapeData,
+ &outShapeSize,
+ &outShapeNdim,
+ &outShapeData,
+ &auxShapeSize,
+ &auxShapeNdim,
+ &auxShapeData,
+ &complete);
+ } else {
+ ret = MXSymbolInferShapePartial(reinterpret_cast<SymbolHandle>(symbolPtr),
+ static_cast<mx_uint>(jnumArgs),
+ keys,
+ reinterpret_cast<mx_uint *>(argIndPtr),
+ reinterpret_cast<const int
*>(argShapeData),
+ &inShapeSize,
+ &inShapeNdim,
+ &inShapeData,
+ &outShapeSize,
+ &outShapeNdim,
+ &outShapeData,
+ &auxShapeSize,
+ &auxShapeNdim,
+ &auxShapeData,
+ &complete);
+ }
env->ReleaseIntArrayElements(jargShapeData, argShapeData, 0);
env->ReleaseIntArrayElements(jargIndPtr, argIndPtr, 0);
@@ -1629,6 +1763,24 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolInferShape
return ret;
}
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolInferShape
+ (JNIEnv *env, jobject obj, jlong symbolPtr, jint jnumArgs, jobjectArray jkeys,
+ jintArray jargIndPtr, jintArray jargShapeData,
+ jobject jinShapeData, jobject joutShapeData, jobject jauxShapeData, jobject jcomplete) {
+
+ return SymbolInferShapeHelper(env, obj, symbolPtr, jnumArgs, jkeys, jargIndPtr, jargShapeData,
+ jinShapeData, joutShapeData, jauxShapeData, jcomplete, false);
+}
+
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolInferShapePartial
+ (JNIEnv *env, jobject obj, jlong symbolPtr, jint jnumArgs, jobjectArray jkeys,
+ jintArray jargIndPtr, jintArray jargShapeData,
+ jobject jinShapeData, jobject joutShapeData, jobject jauxShapeData, jobject jcomplete) {
+
+ return SymbolInferShapeHelper(env, obj, symbolPtr, jnumArgs, jkeys, jargIndPtr, jargShapeData,
+ jinShapeData, joutShapeData, jauxShapeData, jcomplete, true);
+}
+
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorBindX
(JNIEnv *env, jobject obj, jlong symbolPtr, jint deviceTypeId, jint deviceID, jint numCtx,
jobjectArray jctxMapKeys, jintArray jctxMapDevTypes, jintArray jctxMapDevIDs, jint numArgs,
@@ -2551,3 +2703,20 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDumpProfile
(JNIEnv *env, jobject obj, jint finished) {
return MXDumpProfile(finished);
}
+
+// Numpy
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxIsNumpyCompatible
+ (JNIEnv *env, jobject obj, jobject compatibleRef) {
+ bool isCompatible;
+ int ret = MXIsNumpyCompatible(&isCompatible);
+ SetIntField(env, compatibleRef, static_cast<int>(isCompatible));
+ return ret;
+}
+
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSetIsNumpyCompatible
+ (JNIEnv *env, jobject obj, jint isNpComp, jobject prevRef) {
+ int prev;
+ int ret = MXSetIsNumpyCompatible(isNpComp, &prev);
+ SetIntField(env, prevRef, prev);
+ return ret;
+}
\ No newline at end of file
diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
index b8a9b3b..467272c 100644
--- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
+++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
@@ -513,6 +513,14 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorSetMonitorCallbac
/*
* Class: org_apache_mxnet_LibInfo
+ * Method: mxExecutorReshape
+ * Signature: (IIII[Ljava/lang/String;[I[I[Ljava/lang/String;[I[ILscala/collection/mutable/ArrayBuffer;Lscala/collection/mutable/ArrayBuffer;Lscala/collection/mutable/ArrayBuffer;JLorg/apache/mxnet/Base/RefLong;)I
+ */
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorReshape
+ (JNIEnv *, jobject, jint, jint, jint, jint, jobjectArray, jintArray, jintArray, jobjectArray, jintArray, jintArray, jobject, jobject, jobject, jlong, jobject);
+
+/*
+ * Class: org_apache_mxnet_LibInfo
* Method: mxSymbolListAtomicSymbolCreators
* Signature: (Lscala/collection/mutable/ListBuffer;)I
*/
@@ -657,6 +665,14 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolInferShape
/*
* Class: org_apache_mxnet_LibInfo
+ * Method: mxSymbolInferShapePartial
+ * Signature: (JI[Ljava/lang/String;[I[ILscala/collection/mutable/ListBuffer;Lscala/collection/mutable/ListBuffer;Lscala/collection/mutable/ListBuffer;Lorg/apache/mxnet/Base/RefInt;)I
+ */
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolInferShapePartial
+ (JNIEnv *, jobject, jlong, jint, jobjectArray, jintArray, jintArray, jobject, jobject, jobject, jobject);
+
+/*
+ * Class: org_apache_mxnet_LibInfo
* Method: mxSymbolGetOutput
* Signature: (JILorg/apache/mxnet/Base/RefLong;)I
*/
@@ -855,6 +871,22 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSetProfilerState
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDumpProfile
(JNIEnv *, jobject, jint);
+/*
+ * Class: org_apache_mxnet_LibInfo
+ * Method: mxIsNumpyCompatible
+ * Signature: (Lorg/apache/mxnet/Base/RefInt;)I
+ */
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxIsNumpyCompatible
+ (JNIEnv *, jobject, jobject);
+
+/*
+ * Class: org_apache_mxnet_LibInfo
+ * Method: mxSetIsNumpyCompatible
+ * Signature: (ILorg/apache/mxnet/Base/RefInt;)I
+ */
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSetIsNumpyCompatible
+ (JNIEnv *, jobject, jint, jobject);
+
#ifdef __cplusplus
}
#endif
diff --git a/src/operator/tensor/matrix_op-inl.h b/src/operator/tensor/matrix_op-inl.h
index 252e0c5..efc8289 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -240,7 +240,7 @@ inline bool FlattenShape(const nnvm::NodeAttrs& attrs,
struct TransposeParam : public dmlc::Parameter<TransposeParam> {
mxnet::TShape axes;
DMLC_DECLARE_PARAMETER(TransposeParam) {
- DMLC_DECLARE_FIELD(axes).set_default(mxnet::TShape())
+ DMLC_DECLARE_FIELD(axes).set_default(mxnet::TShape(0))
.describe("Target axis order. By default the axes will be inverted.");
}
};
@@ -314,7 +314,7 @@ void Transpose(const nnvm::NodeAttrs& attrs,
const std::vector<TBlob>& outputs) {
const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed);
CHECK_EQ(req[0], kWriteTo) << "Transpose does not support inplace";
- if (!mxnet::ndim_is_known(param.axes)) {
+ if (param.axes.ndim() == 0) {
mxnet::TShape axes(inputs[0].ndim(), -1);
for (int i = 0; i < axes.ndim(); ++i) {
axes[i] = axes.ndim() - 1 - i;
@@ -334,7 +334,7 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs,
mxnet::TShape& shp = (*in_attrs)[0];
CHECK_LE(shp.ndim(), 6U) << "Transpose support at most 6 dimensions";
mxnet::TShape ret(shp.ndim(), -1);
- if (!mxnet::ndim_is_known(param.axes)) {
+ if (param.axes.ndim() == 0) {
for (int i = 0; i < shp.ndim(); ++i) {
ret[i] = shp[shp.ndim()-1-i];
}
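
For reference, the default-axes rule both hunks preserve, restated as a small
Scala sketch (defaultTransposeAxes is a hypothetical name):

    // An empty axes parameter (ndim == 0) means "reverse all dimensions":
    // a (2, 3, 4) input is transposed with axes (2, 1, 0) into (4, 3, 2).
    def defaultTransposeAxes(ndim: Int): Seq[Int] = (ndim - 1) to 0 by -1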