This is an automated email from the ASF dual-hosted git repository.

reminisce pushed a commit to branch numpy
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/numpy by this push:
     new 874695d  [Numpy] Java/Scala modification (#14625)
874695d is described below

commit 874695df9c36c66ec3eecb9cc639108f3469cd26
Author: Yizhi Liu <[email protected]>
AuthorDate: Fri Apr 5 10:21:55 2019 -0700

    [Numpy] Java/Scala modification (#14625)
    
    * Modify JNI to support zero-dim and zero-size shapes
    
    * Fix the default value of the transpose `axes` parameter
---
 .../src/main/scala/org/apache/mxnet/Executor.scala | 142 +++++-------
 .../src/main/scala/org/apache/mxnet/LibInfo.scala  |  32 ++-
 .../src/main/scala/org/apache/mxnet/NDArray.scala  |  10 +-
 .../main/scala/org/apache/mxnet/NumpyScope.scala   |  58 +++++
 .../src/main/scala/org/apache/mxnet/Symbol.scala   |  38 +++-
 .../scala/org/apache/mxnet/NumpyScopeSuite.scala   |  34 +++
 .../org/apache/mxnet/utils/CToScalaUtils.scala     |   3 +-
 .../main/native/org_apache_mxnet_native_c_api.cc   | 239 ++++++++++++++++++---
 .../main/native/org_apache_mxnet_native_c_api.h    |  32 +++
 src/operator/tensor/matrix_op-inl.h                |   6 +-
 10 files changed, 452 insertions(+), 142 deletions(-)

diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
index aec4402..f51424b 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
@@ -61,10 +61,6 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
   protected var monitorCallback: MXMonitorCallback = null
   private val logger: Logger = LoggerFactory.getLogger(classOf[Executor])
 
-  private[mxnet] var ownsArgArrays = false
-  private[mxnet] var ownsGradArrays = false
-  private[mxnet] var ownsAuxArrays = false
-
   override def nativeAddress: CPtrAddress = handle
   override def nativeDeAllocator: (CPtrAddress => Int) = _LIB.mxExecutorFree
   // cannot determine the off-heap size of this object
@@ -75,17 +71,12 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
     if (!super.isDisposed) {
       super.dispose()
       outputs.foreach(o => o.dispose())
-      // Symbol.bind clones symbol when creating the executor so we need to 
dispose of the clone
-      symbol.dispose()
-      if (ownsArgArrays && argArrays != null) {argArrays.foreach(a => 
a.dispose())}
-      if (ownsGradArrays && gradArrays != null) {gradArrays.foreach(
+      if (argArrays != null) {argArrays.foreach(a => a.dispose())}
+      if (gradArrays != null) {gradArrays.foreach(
         // Symbol will sometimes fill this with nulls so we've got to check 
the elements too
         a => if (a != null) {a.dispose()})
       }
-      if (ownsAuxArrays && auxArrays != null) {auxArrays.foreach(a => 
a.dispose())}
-      if (_argDict != null) {_argDict.foreach(a => a._2.dispose())}
-      if (_gradDict != null) {_gradDict.foreach(a => a._2.dispose())}
-      if (_auxDict != null) {_auxDict.foreach(a => a._2.dispose())}
+      if (auxArrays != null) {auxArrays.foreach(a => a.dispose())}
     }
   }
 
@@ -104,95 +95,58 @@ class Executor private[mxnet](private[mxnet] val handle: ExecutorHandle,
    */
   def reshape(partialShaping: Boolean = false, allowUpSizing: Boolean = false,
     kwargs: Map[String, Shape]): Executor = {
-    var setArgOwner = false
-    var setAuxOwner = false
-    var setGradOwner = false
-     val (argShapes, _, auxShapes) = this.symbol.inferShape(kwargs)
-    // TODO: more precise error message should be provided by backend
-    require(argShapes != null, "Shape inference failed." +
-      s"Known shapes are $kwargs for symbol arguments 
${symbol.listArguments()} " +
-      s"and aux states ${symbol.listAuxiliaryStates()}")
-
-    var newArgDict = Map[String, NDArray]()
-    var newGradDict = Map[String, NDArray]()
 
-    this.symbol.listArguments().zipWithIndex.foreach { case (name, i) =>
-      val newShape = argShapes(i)
-      val arr = this.argArrays(i)
-      val dArr = if (this.gradArrays == null) null else this.gradArrays(i)
-      if (partialShaping || kwargs.contains(name) || 
newShape.equals(arr.shape)) {
-        if (newShape.product > arr.shape.product) {
-          require(allowUpSizing, s"New shape of arg:$name larger than 
original. " +
-                        "First making a big executor and then down sizing it " 
+
-                        "is more efficient than the reverse." +
-                        "If you really want to up size, set allowUpSizing = 
true " +
-                        "to enable allocation of new arrays.")
-          newArgDict = newArgDict + (name -> NDArray.empty(newShape, 
arr.context, arr.dtype))
-          setArgOwner = true
-          if (dArr != null) {
-            newGradDict = newGradDict + (name -> NDArray.empty(newShape, 
dArr.context, dArr.dtype))
-            setGradOwner = true
-          }
-        } else {
-          newArgDict = newArgDict + (name -> arr.reshape(newShape.toArray))
-          if (dArr != null) {
-            newGradDict = newGradDict + (name -> 
dArr.reshape(newShape.toArray))
-          }
-        }
-      } else {
-        throw new  AssertionError(s"Shape of unspecified array arg:$name 
changed." +
-                    "This can cause the new executor to not share parameters " 
+
-                    "with the old one. Please check for error in network." +
-                    "If this is intended, set partialShaping = true to 
suppress this warning.")
-      }
-    }
+    val providedArgShapeNames = kwargs.keys
+    val providedArgShapeData = kwargs.values.flatMap(_.toVector)
+    val providedArgShapeIdx = kwargs.values.scanLeft(0)((sum, shape) => sum + 
shape.size)
 
-    var newAuxDict = Map[String, NDArray]()
-    val zip3 = (this.symbol.listAuxiliaryStates(), auxShapes, 
this.auxArrays).zipped
-    zip3.foreach { case (name, newShape, arr) =>
-      if (partialShaping || newShape.equals(arr.shape)) {
-        if (newShape.product > arr.shape.product) {
-          require(allowUpSizing, s"New shape of aux:$name larger than 
original. " +
-                        "First making a big executor and then down sizing it " 
+
-                        "is more efficient than the reverse." +
-                        "If you really want to up size, set allowUpSizing = 
true " +
-                        "to enable allocation of new arrays.")
-          newAuxDict = newAuxDict + (name -> NDArray.empty(newShape, 
arr.context))
-          setAuxOwner = true
-        } else {
-          newAuxDict = newAuxDict + (name -> arr.reshape(newShape.toArray))
-        }
-      } else {
-        throw new  AssertionError(s"Shape of unspecified array aux:$name 
changed." +
-                  "This can cause the new executor to not share parameters " +
-                  "with the old one. Please check for error in network." +
-                  "If this is intended, set partialShaping = true to suppress 
this warning.")
-      }
+    val ctxMapKeys = if (_group2ctx != null) _group2ctx.keys.toArray else 
Array.empty[String]
+    val ctxMapDevTypes = if (_group2ctx != null) {
+      _group2ctx.values.map(_.deviceTypeid).toArray
+    } else {
+      Array.empty[Int]
     }
-    val reshapedExecutor = if (this._gradsReq.isInstanceOf[Seq[_]]) {
-      this.symbol.bind(this._ctx,
-                          newArgDict,
-                          newGradDict,
-                          this._gradsReq.asInstanceOf[Seq[String]],
-                          newAuxDict,
-                          this._group2ctx,
-                          this)
+    val ctxMapDevIds = if (_group2ctx != null) {
+      _group2ctx.values.map(_.deviceId).toArray
     } else {
-      this.symbol.bind(this._ctx,
-                          newArgDict,
-                          newGradDict,
-                          this._gradsReq.asInstanceOf[Map[String, String]],
-                          newAuxDict,
-                          this._group2ctx,
-                          this)
+      Array.empty[Int]
     }
 
-    // This method has created new NDArrays that will need to be managed by 
the new Executor
-    if (setArgOwner) reshapedExecutor.ownsArgArrays = true
-    if (setGradOwner) reshapedExecutor.ownsGradArrays = true
-    if (setAuxOwner) reshapedExecutor.ownsAuxArrays = true
+    val inArgs = ArrayBuffer.empty[NDArrayHandle]
+    val argGrads = ArrayBuffer.empty[NDArrayHandle]
+    val auxStates = ArrayBuffer.empty[NDArrayHandle]
+    val outHandle = new ExecutorHandleRef()
+
+    checkCall(_LIB.mxExecutorReshape(
+              if (partialShaping) 1 else 0,
+              if (allowUpSizing) 1 else 0,
+              _ctx.deviceTypeid,
+              _ctx.deviceId,
+              ctxMapKeys.toArray,
+              ctxMapDevTypes.toArray,
+              ctxMapDevIds.toArray,
+              providedArgShapeNames.toArray,
+              providedArgShapeData.toArray,
+              providedArgShapeIdx.toArray,
+              inArgs,
+              argGrads,
+              auxStates,
+              this.handle,
+              outHandle))
+
+    val argArrays = inArgs.map(new NDArray(_)).toArray
+    val gradArrays = argGrads.map(handle =>
+      if (handle == 0) null else new NDArray(handle)).toArray
+    val auxArrays = auxStates.map(new NDArray(_)).toArray
 
-    reshapedExecutor
+    val executor = new Executor(outHandle.value, this.symbol)
+    executor._ctx = this._ctx
+    executor._gradsReq = this._gradsReq
+    executor._group2ctx = this._group2ctx
+    executor.argArrays = argArrays
+    executor.gradArrays = gradArrays
+    executor.auxArrays = auxArrays
+    executor
   }
 
   /**
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
index 40fc095..aba6185 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
@@ -188,6 +188,23 @@ private[mxnet] class LibInfo {
                                  grads: Array[NDArrayHandle]): Int
   @native def mxExecutorPrint(handle: ExecutorHandle, debugStr: RefString): Int
   @native def mxExecutorSetMonitorCallback(handle: ExecutorHandle, callback: 
MXMonitorCallback): Int
+  // scalastyle:off parameterNum
+  @native def mxExecutorReshape(partialShaping: Int,
+                                allowUpSizing: Int,
+                                devType: Int,
+                                devId: Int,
+                                mapKeys: Array[String],
+                                mapDevTypes: Array[Int],
+                                mapDevIds: Array[Int],
+                                providedArgShapeNames: Array[String],
+                                providedArgShapeData: Array[Int],
+                                providedArgShapeIdx: Array[Int],
+                                inArgs: ArrayBuffer[NDArrayHandle],
+                                argGrads: ArrayBuffer[NDArrayHandle],
+                                auxStates: ArrayBuffer[NDArrayHandle],
+                                sharedExec: ExecutorHandle,
+                                out: ExecutorHandleRef): Int
+  // scalastyle:on parameterNum
 
   // Symbols
   @native def mxSymbolListAtomicSymbolCreators(symbolList: 
ListBuffer[SymbolHandle]): Int
@@ -240,11 +257,20 @@ private[mxnet] class LibInfo {
                                  numArgs: MXUint,
                                  keys: Array[String],
                                  argIndPtr: Array[MXUint],
-                                 argShapeData: Array[MXUint],
+                                 argShapeData: Array[Int],
                                  inShapeData: ListBuffer[Array[Int]],
                                  outShapeData: ListBuffer[Array[Int]],
                                  auxShapeData: ListBuffer[Array[Int]],
                                  complete: RefInt): Int
+  @native def mxSymbolInferShapePartial(handle: SymbolHandle,
+                                        numArgs: MXUint,
+                                        keys: Array[String],
+                                        argIndPtr: Array[MXUint],
+                                        argShapeData: Array[Int],
+                                        inShapeData: ListBuffer[Array[Int]],
+                                        outShapeData: ListBuffer[Array[Int]],
+                                        auxShapeData: ListBuffer[Array[Int]],
+                                        complete: RefInt): Int
   @native def mxSymbolGetOutput(handle: SymbolHandle, index: Int, out: 
SymbolHandleRef): Int
   @native def mxSymbolSaveToJSON(handle: SymbolHandle, out: RefString): Int
   @native def mxSymbolCreateFromJSON(json: String, handle: SymbolHandleRef): 
Int
@@ -322,4 +348,8 @@ private[mxnet] class LibInfo {
   @native def mxSetProfilerConfig(keys: Array[String], vals: Array[String]): 
Int
   @native def mxSetProfilerState(state: Int): Int
   @native def mxDumpProfile(finished: Int): Int
+
+  // Numpy
+  @native def mxIsNumpyCompatible(compatible: RefInt): Int
+  @native def mxSetIsNumpyCompatible(isNpComp: Int, prev: RefInt): Int
 }
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index ab42265..849f456 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -1274,11 +1274,15 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
    * @return an array representing shape of current ndarray
    */
   def shape: Shape = {
-    val ndim = new MXUintRef
+    val ndim = new RefInt
     val data = ArrayBuffer[Int]()
     checkCall(_LIB.mxNDArrayGetShape(handle, ndim, data))
-    require(ndim.value == data.length, s"ndim=$ndim, while 
len(data)=${data.length}")
-    Shape(data)
+    if (ndim.value == -1) {
+      null
+    } else {
+      require(ndim.value == data.length, s"ndim=$ndim, while 
len(data)=${data.length}")
+      Shape(data)
+    }
   }
 
   // Get size of current NDArray.
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NumpyScope.scala b/scala-package/core/src/main/scala/org/apache/mxnet/NumpyScope.scala
new file mode 100644
index 0000000..ec366ea
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NumpyScope.scala
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import org.apache.mxnet.Base._
+
+object NumpyScope {
+  def setNumpyCompatible(isNpComp: Boolean): Boolean = {
+    val prev = new RefInt()
+    checkCall(_LIB.mxSetIsNumpyCompatible(if (isNpComp) 1 else 0, prev))
+    if (prev.value != 0) true else false
+  }
+
+  def isNumpyCompatible: Boolean = {
+    val curr = new RefInt
+    checkCall(_LIB.mxIsNumpyCompatible(curr))
+    if (curr.value != 0) true else false
+  }
+
+  def enableNumpyCompatible: NumpyScope = {
+    new NumpyScope(true)
+  }
+
+
+  def disableNumpyCompatible: NumpyScope = {
+    new NumpyScope(false)
+  }
+}
+
+class NumpyScope(var isCompatible: Boolean) {
+  private var prev: Boolean = false
+
+  def withScope[T](body: => T): T = {
+    prev = NumpyScope.setNumpyCompatible(isCompatible)
+    try {
+      body
+    } finally {
+      if (prev != isCompatible) {
+        NumpyScope.setNumpyCompatible(prev)
+      }
+    }
+  }
+}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
index 821e04f..808a23a 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Symbol.scala
@@ -260,17 +260,45 @@ class Symbol private(private[mxnet] val handle: SymbolHandle) extends NativeReso
 
   def inferShape(keys: Array[String], indPtr: Array[Int], values: Array[Int])
     : (IndexedSeq[Shape], IndexedSeq[Shape], IndexedSeq[Shape]) = {
+    val res = inferShapeImpl(partial = false, keys, indPtr, values)
+    if (res._2 == null) {
+      val (argShapes, _, _) = inferShapeImpl(partial = true, keys, indPtr, 
values)
+      val argNames = listArguments()
+      val unknown = (argNames zip argShapes).map { case (name, shape) =>
+        val shapeIsNone = if (NumpyScope.isNumpyCompatible) {
+          shape == null || shape.toVector.contains(-1)
+        } else {
+          shape == null || shape.toVector.contains(0)
+        }
+        if (shapeIsNone) s"$name: $shape" else ""
+      }
+      logger.warn("Cannot decide shape for the following arguments. " +
+        "Consider providing them as input: \n\t{}",
+        unknown.filter(_ != "").mkString("\n\t"))
+    }
+    res
+  }
+
+  private def inferShapeImpl(partial: Boolean,
+                             keys: Array[String],
+                             indPtr: Array[Int],
+                             values: Array[Int])
+    : (IndexedSeq[Shape], IndexedSeq[Shape], IndexedSeq[Shape]) = {
     val argShapeData = ListBuffer.empty[Array[Int]]
     val outShapeData = ListBuffer.empty[Array[Int]]
     val auxShapeData = ListBuffer.empty[Array[Int]]
     val complete = new RefInt
-
-    checkCall(_LIB.mxSymbolInferShape(handle, indPtr.length - 1, keys, indPtr, 
values,
-      argShapeData, outShapeData, auxShapeData, complete))
+    if (partial) {
+      checkCall(_LIB.mxSymbolInferShapePartial(handle, indPtr.length - 1, 
keys, indPtr, values,
+        argShapeData, outShapeData, auxShapeData, complete))
+    } else {
+      checkCall(_LIB.mxSymbolInferShape(handle, indPtr.length - 1, keys, 
indPtr, values,
+        argShapeData, outShapeData, auxShapeData, complete))
+    }
     if (complete.value != 0) {
       (argShapeData.map(s => Shape(s)).toIndexedSeq,
-       outShapeData.map(s => Shape(s)).toIndexedSeq,
-       auxShapeData.map(s => Shape(s)).toIndexedSeq)
+        outShapeData.map(s => Shape(s)).toIndexedSeq,
+        auxShapeData.map(s => Shape(s)).toIndexedSeq)
     } else {
       (null, null, null)
     }
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NumpyScopeSuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NumpyScopeSuite.scala
new file mode 100644
index 0000000..bf6627a
--- /dev/null
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/NumpyScopeSuite.scala
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import org.scalatest.{BeforeAndAfterAll, FunSuite}
+
+class NumpyScopeSuite extends FunSuite with BeforeAndAfterAll {
+  test("compatible") {
+    NumpyScope.enableNumpyCompatible.withScope {
+      assert(NumpyScope.isNumpyCompatible === true)
+    }
+  }
+
+  test("incompatible") {
+    NumpyScope.disableNumpyCompatible.withScope {
+      assert(NumpyScope.isNumpyCompatible === false)
+    }
+  }
+}
diff --git a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
index 57c4cfb..12d797f 100644
--- a/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
+++ b/scala-package/macros/src/main/scala/org/apache/mxnet/utils/CToScalaUtils.scala
@@ -47,7 +47,8 @@ private[mxnet] object CToScalaUtils {
       case "double" | "doubleorNone" => types("double")
       case "string" => "String"
       case "boolean" | "booleanorNone" => types("bool")
-      case "tupleof<float>" | "tupleof<double>" | "tupleof<>" | "ptr" | "" => 
"Any"
+      case "tupleof<int>" | "tupleof<float>" | "tupleof<double>" | 
"tupleof<intorNone>" |
+           "tupleof<>" | "ptr" | "" => "Any"
       case default => throw new IllegalArgumentException(
         s"Invalid type for args: $default\nString argType: $argType\nargName: 
$argName")
     }
diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
index 33e4cca..678dfc1 100644
--- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
+++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
@@ -354,8 +354,8 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayLoadFromRawBytes
 
 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetShape
   (JNIEnv *env, jobject obj, jlong ndArrayPtr, jobject ndimRef, jobject 
dataBuf) {
-  mx_uint ndim;
-  const mx_uint *pdata;
+  int ndim;
+  const int *pdata;
   int ret = MXNDArrayGetShape(reinterpret_cast<NDArrayHandle>(ndArrayPtr), 
&ndim, &pdata);
 
   // fill dataBuf
@@ -365,7 +365,7 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetShape
   jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
   jmethodID arrayAppend = env->GetMethodID(arrayClass,
     "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
-  for (size_t i = 0; i < ndim; ++i) {
+  for (int i = 0; i < ndim; ++i) {
     jobject data = env->NewObject(integerClass, newInteger, pdata[i]);
     env->CallObjectMethod(dataBuf, arrayAppend, data);
     env->DeleteLocalRef(data);
@@ -892,6 +892,118 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorBackward
   return ret;
 }
 
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorReshape
+  (JNIEnv * env, jobject obj,
+    jint partialReshaping, jint allowUpSizing, jint devType, jint devId,
+    jobjectArray jmapKeys, jintArray jmapDevTypes, jintArray jmapDevIds,
+    jobjectArray jprovidedArgShapeNames, jintArray jprovidedArgShapeData, 
jintArray jprovidedArgShapeIdx,
+    jobject jrefInArgs, jobject jrefArgGrads, jobject jrefAuxStates,
+    jlong jsharedExec, jobject jrefOut) {
+  CHECK(jmapKeys != NULL);
+  CHECK(jprovidedArgShapeNames != NULL);
+
+  int numMapKeys = env->GetArrayLength(jmapKeys);
+  jint *mapDevTypes = env->GetIntArrayElements(jmapDevTypes, NULL);
+  jint *mapDevIds = env->GetIntArrayElements(jmapDevIds, NULL);
+  const char **mapKeys = NULL;
+  if (numMapKeys > 0) {
+    mapKeys = new const char*[numMapKeys];
+    for (int i = 0; i < numMapKeys; ++i) {
+      jstring jkey = 
reinterpret_cast<jstring>(env->GetObjectArrayElement(jmapKeys, i));
+      mapKeys[i] = env->GetStringUTFChars(jkey, 0);
+      env->DeleteLocalRef(jkey);
+    }
+  }
+
+  int numProvidedArgShapes = env->GetArrayLength(jprovidedArgShapeNames);
+  jint *providedArgShapeData = env->GetIntArrayElements(jprovidedArgShapeData, 
NULL);
+  jint *providedArgShapeIdx = env->GetIntArrayElements(jprovidedArgShapeIdx, 
NULL);
+  const char **providedArgShapeNames = NULL;
+  if (numProvidedArgShapes > 0) {
+    providedArgShapeNames = new const char*[numProvidedArgShapes];
+    for (int i = 0; i < numProvidedArgShapes; ++i) {
+      jstring jkey = 
reinterpret_cast<jstring>(env->GetObjectArrayElement(jprovidedArgShapeNames, 
i));
+      providedArgShapeNames[i] = env->GetStringUTFChars(jkey, 0);
+      env->DeleteLocalRef(jkey);
+    }
+  }
+
+  mx_uint numInArgs = 0;
+  NDArrayHandle *inArgs;
+  NDArrayHandle *argGrads;
+
+  mx_uint numAuxStates = 0;
+  NDArrayHandle *auxStates;
+
+  ExecutorHandle out;
+
+  int ret = MXExecutorReshape(partialReshaping,
+                              allowUpSizing,
+                              devType,
+                              devId,
+                              static_cast<mx_uint>(numMapKeys),
+                              mapKeys,
+                              static_cast<const int*>(mapDevTypes),
+                              static_cast<const int*>(mapDevIds),
+                              static_cast<const mx_uint>(numProvidedArgShapes),
+                              providedArgShapeNames,
+                              static_cast<const int*>(providedArgShapeData),
+                              reinterpret_cast<const 
mx_uint*>(providedArgShapeIdx),
+                              &numInArgs,
+                              &inArgs,
+                              &argGrads,
+                              &numAuxStates,
+                              &auxStates,
+                              reinterpret_cast<ExecutorHandle>(jsharedExec),
+                              &out);
+
+  jclass longCls = env->FindClass("java/lang/Long");
+  jmethodID newLong = env->GetMethodID(longCls, "<init>", "(J)V");
+
+  jclass arrayClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
+  jmethodID arrayAppend = env->GetMethodID(arrayClass,
+    "$plus$eq", "(Ljava/lang/Object;)Lscala/collection/mutable/ArrayBuffer;");
+
+  for (size_t i = 0; i < numInArgs; ++i) {
+    jobject inArg = env->NewObject(longCls, newLong, inArgs[i]);
+    env->CallObjectMethod(jrefInArgs, arrayAppend, inArg);
+    env->DeleteLocalRef(inArg);
+
+    jobject argGrad = env->NewObject(longCls, newLong, argGrads[i]);
+    env->CallObjectMethod(jrefArgGrads, arrayAppend, argGrad);
+    env->DeleteLocalRef(argGrad);
+  }
+
+  for (size_t i = 0; i < numAuxStates; ++i) {
+    jobject auxState = env->NewObject(longCls, newLong, auxStates[i]);
+    env->CallObjectMethod(jrefAuxStates, arrayAppend, auxState);
+    env->DeleteLocalRef(auxState);
+  }
+
+  SetLongField(env, jrefOut, reinterpret_cast<jlong>(out));
+
+  // release allocated memory
+  for (int i = 0; i < numMapKeys; i++) {
+    jstring jkey = 
reinterpret_cast<jstring>(env->GetObjectArrayElement(jmapKeys, i));
+    env->ReleaseStringUTFChars(jkey, mapKeys[i]);
+    env->DeleteLocalRef(jkey);
+  }
+  if (mapKeys != NULL) {
+    delete[] mapKeys;
+  }
+
+  for (int i = 0; i < numProvidedArgShapes; i++) {
+    jstring jkey = 
reinterpret_cast<jstring>(env->GetObjectArrayElement(jprovidedArgShapeNames, 
i));
+    env->ReleaseStringUTFChars(jkey, providedArgShapeNames[i]);
+    env->DeleteLocalRef(jkey);
+  }
+  if (providedArgShapeNames != NULL) {
+    delete[] providedArgShapeNames;
+  }
+
+  return ret;
+}
+
 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorPrint
   (JNIEnv * env, jobject obj, jlong ptr, jobject debugStr) {
   const char *retDebugStr;
@@ -1530,23 +1642,26 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolCreateFromFile
 
 int FillSymbolInferShape
   (JNIEnv *env, jmethodID listAppend, jobject joutData,
-    mx_uint shapeSize, const mx_uint *shapeNdim, const mx_uint **shapeData) {
-  for (size_t i = 0; i < shapeSize; ++i) {
-    jintArray jshape = env->NewIntArray(shapeNdim[i]);
-    if (jshape == NULL) {
-      // TODO(Yizhi): out of memory error thrown, return a specific error code 
?
-      return -1;
+    int shapeSize, const int *shapeNdim, const int **shapeData) {
+  for (int i = 0; i < shapeSize; ++i) {
+    jintArray jshape = NULL;
+    if (shapeNdim[i] >= 0) {
+      jshape = env->NewIntArray(shapeNdim[i]);
+      if (jshape == NULL) {
+        // TODO(Yizhi): out of memory error thrown, return a specific error 
code ?
+        return -1;
+      }
+      env->SetIntArrayRegion(jshape, 0, shapeNdim[i], reinterpret_cast<const 
jint *>(shapeData[i]));
     }
-    env->SetIntArrayRegion(jshape, 0, shapeNdim[i], reinterpret_cast<const 
jint *>(shapeData[i]));
     env->CallObjectMethod(joutData, listAppend, jshape);
     env->DeleteLocalRef(jshape);
   }
   return 0;
 }
-JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolInferShape
-  (JNIEnv *env, jobject obj, jlong symbolPtr, jint jnumArgs, jobjectArray 
jkeys,
-    jintArray jargIndPtr, jintArray jargShapeData,
-    jobject jinShapeData, jobject joutShapeData, jobject jauxShapeData, 
jobject jcomplete) {
+
+int SymbolInferShapeHelper(JNIEnv *env, jobject obj, jlong symbolPtr, jint 
jnumArgs, jobjectArray jkeys,
+                            jintArray jargIndPtr, jintArray jargShapeData, 
jobject jinShapeData,
+                            jobject joutShapeData, jobject jauxShapeData, 
jobject jcomplete, bool partial) {
   const char **keys = NULL;
   if (jkeys != NULL) {
     keys = new const char *[jnumArgs];
@@ -1559,36 +1674,55 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolInferShape
   }
 
   mx_uint inShapeSize;
-  const mx_uint *inShapeNdim;
-  const mx_uint **inShapeData;
+  const int *inShapeNdim;
+  const int **inShapeData;
 
   mx_uint outShapeSize;
-  const mx_uint *outShapeNdim;
-  const mx_uint **outShapeData;
+  const int *outShapeNdim;
+  const int **outShapeData;
 
   mx_uint auxShapeSize;
-  const mx_uint *auxShapeNdim;
-  const mx_uint **auxShapeData;
+  const int *auxShapeNdim;
+  const int **auxShapeData;
 
   int complete;
 
   jint *argIndPtr = env->GetIntArrayElements(jargIndPtr, NULL);
   jint *argShapeData = env->GetIntArrayElements(jargShapeData, NULL);
-  int ret = MXSymbolInferShape(reinterpret_cast<SymbolHandle>(symbolPtr),
-                               static_cast<mx_uint>(jnumArgs),
-                               keys,
-                               reinterpret_cast<const mx_uint *>(argIndPtr),
-                               reinterpret_cast<const mx_uint *>(argShapeData),
-                               &inShapeSize,
-                               &inShapeNdim,
-                               &inShapeData,
-                               &outShapeSize,
-                               &outShapeNdim,
-                               &outShapeData,
-                               &auxShapeSize,
-                               &auxShapeNdim,
-                               &auxShapeData,
-                               &complete);
+  int ret;
+  if (!partial) {
+    ret = MXSymbolInferShape(reinterpret_cast<SymbolHandle>(symbolPtr),
+                              static_cast<mx_uint>(jnumArgs),
+                              keys,
+                              reinterpret_cast<mx_uint *>(argIndPtr),
+                              reinterpret_cast<const int *>(argShapeData),
+                              &inShapeSize,
+                              &inShapeNdim,
+                              &inShapeData,
+                              &outShapeSize,
+                              &outShapeNdim,
+                              &outShapeData,
+                              &auxShapeSize,
+                              &auxShapeNdim,
+                              &auxShapeData,
+                              &complete);
+  } else {
+    ret = MXSymbolInferShapePartial(reinterpret_cast<SymbolHandle>(symbolPtr),
+                                    static_cast<mx_uint>(jnumArgs),
+                                    keys,
+                                    reinterpret_cast<mx_uint *>(argIndPtr),
+                                    reinterpret_cast<const int 
*>(argShapeData),
+                                    &inShapeSize,
+                                    &inShapeNdim,
+                                    &inShapeData,
+                                    &outShapeSize,
+                                    &outShapeNdim,
+                                    &outShapeData,
+                                    &auxShapeSize,
+                                    &auxShapeNdim,
+                                    &auxShapeData,
+                                    &complete);
+  }
   env->ReleaseIntArrayElements(jargShapeData, argShapeData, 0);
   env->ReleaseIntArrayElements(jargIndPtr, argIndPtr, 0);
 
@@ -1629,6 +1763,24 @@ JNIEXPORT jint JNICALL 
Java_org_apache_mxnet_LibInfo_mxSymbolInferShape
   return ret;
 }
 
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolInferShape
+  (JNIEnv *env, jobject obj, jlong symbolPtr, jint jnumArgs, jobjectArray jkeys,
+    jintArray jargIndPtr, jintArray jargShapeData,
+    jobject jinShapeData, jobject joutShapeData, jobject jauxShapeData, jobject jcomplete) {
+  // Complete shape inference: forwards every argument unchanged to the shared
+  // helper, whose non-partial branch calls MXSymbolInferShape.
+  return SymbolInferShapeHelper(env, obj, symbolPtr, jnumArgs, jkeys,
+                                jargIndPtr, jargShapeData,
+                                jinShapeData, joutShapeData, jauxShapeData,
+                                jcomplete, /* partial = */ false);
+}
+
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolInferShapePartial
+  (JNIEnv *env, jobject obj, jlong symbolPtr, jint jnumArgs, jobjectArray jkeys,
+    jintArray jargIndPtr, jintArray jargShapeData,
+    jobject jinShapeData, jobject joutShapeData, jobject jauxShapeData, jobject jcomplete) {
+  // Partial shape inference: forwards every argument unchanged to the shared
+  // helper, whose partial branch calls MXSymbolInferShapePartial.
+  return SymbolInferShapeHelper(env, obj, symbolPtr, jnumArgs, jkeys,
+                                jargIndPtr, jargShapeData,
+                                jinShapeData, joutShapeData, jauxShapeData,
+                                jcomplete, /* partial = */ true);
+}
+
 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorBindX
   (JNIEnv *env, jobject obj, jlong symbolPtr, jint deviceTypeId, jint 
deviceID, jint numCtx,
     jobjectArray jctxMapKeys, jintArray jctxMapDevTypes, jintArray 
jctxMapDevIDs, jint numArgs,
@@ -2551,3 +2703,20 @@ JNIEXPORT jint JNICALL 
Java_org_apache_mxnet_LibInfo_mxDumpProfile
   (JNIEnv *env, jobject obj, jint finished) {
   return MXDumpProfile(finished);
 }
+
+// NumPy-compatibility mode: query and toggle the global flag.
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxIsNumpyCompatible
+  (JNIEnv *env, jobject obj, jobject compatibleRef) {
+  // Reads the NumPy-compatibility flag via MXIsNumpyCompatible and stores it
+  // (as 0 or 1) in compatibleRef; returns the C-API return code.
+  // Initialised so that a failed C-API call never copies an indeterminate
+  // value into the Java-side reference.
+  bool isCompatible = false;
+  int ret = MXIsNumpyCompatible(&isCompatible);
+  SetIntField(env, compatibleRef, static_cast<int>(isCompatible));
+  return ret;
+}
+
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSetIsNumpyCompatible
+  (JNIEnv *env, jobject obj, jint isNpComp, jobject prevRef) {
+  // Sets the NumPy-compatibility flag to isNpComp via MXSetIsNumpyCompatible
+  // and writes the previous value into prevRef; returns the C-API return code.
+  // prev is initialised so that a failed C-API call never copies an
+  // indeterminate value into the Java-side reference.
+  int prev = 0;
+  int ret = MXSetIsNumpyCompatible(isNpComp, &prev);
+  SetIntField(env, prevRef, prev);
+  return ret;
+}
diff --git 
a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h 
b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
index b8a9b3b..467272c 100644
--- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
+++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
@@ -513,6 +513,14 @@ JNIEXPORT jint JNICALL 
Java_org_apache_mxnet_LibInfo_mxExecutorSetMonitorCallbac
 
 /*
  * Class:     org_apache_mxnet_LibInfo
+ * Method:    mxExecutorReshape
+ * Signature: 
(IIII[Ljava/lang/String;[I[I[Ljava/lang/String;[I[ILscala/collection/mutable/ArrayBuffer;Lscala/collection/mutable/ArrayBuffer;Lscala/collection/mutable/ArrayBuffer;JLorg/apache/mxnet/Base/RefLong;)I
+ */
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxExecutorReshape
+  (JNIEnv *, jobject, jint, jint, jint, jint, jobjectArray, jintArray, 
jintArray, jobjectArray, jintArray, jintArray, jobject, jobject, jobject, 
jlong, jobject);
+
+/*
+ * Class:     org_apache_mxnet_LibInfo
  * Method:    mxSymbolListAtomicSymbolCreators
  * Signature: (Lscala/collection/mutable/ListBuffer;)I
  */
@@ -657,6 +665,14 @@ JNIEXPORT jint JNICALL 
Java_org_apache_mxnet_LibInfo_mxSymbolInferShape
 
 /*
  * Class:     org_apache_mxnet_LibInfo
+ * Method:    mxSymbolInferShapePartial
+ * Signature: 
(JI[Ljava/lang/String;[I[ILscala/collection/mutable/ListBuffer;Lscala/collection/mutable/ListBuffer;Lscala/collection/mutable/ListBuffer;Lorg/apache/mxnet/Base/RefInt;)I
+ */
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSymbolInferShapePartial
+  (JNIEnv *, jobject, jlong, jint, jobjectArray, jintArray, jintArray, 
jobject, jobject, jobject, jobject);
+
+/*
+ * Class:     org_apache_mxnet_LibInfo
  * Method:    mxSymbolGetOutput
  * Signature: (JILorg/apache/mxnet/Base/RefLong;)I
  */
@@ -855,6 +871,22 @@ JNIEXPORT jint JNICALL 
Java_org_apache_mxnet_LibInfo_mxSetProfilerState
 JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxDumpProfile
   (JNIEnv *, jobject, jint);
 
+/*
+ * Class:     org_apache_mxnet_LibInfo
+ * Method:    mxIsNumpyCompatible
+ * Signature: (Lorg/apache/mxnet/Base/RefInt;)I
+ */
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxIsNumpyCompatible
+  (JNIEnv *, jobject, jobject);
+
+/*
+ * Class:     org_apache_mxnet_LibInfo
+ * Method:    mxSetIsNumpyCompatible
+ * Signature: (ILorg/apache/mxnet/Base/RefInt;)I
+ */
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxSetIsNumpyCompatible
+  (JNIEnv *, jobject, jint, jobject);
+
 #ifdef __cplusplus
 }
 #endif
diff --git a/src/operator/tensor/matrix_op-inl.h 
b/src/operator/tensor/matrix_op-inl.h
index 252e0c5..efc8289 100644
--- a/src/operator/tensor/matrix_op-inl.h
+++ b/src/operator/tensor/matrix_op-inl.h
@@ -240,7 +240,7 @@ inline bool FlattenShape(const nnvm::NodeAttrs& attrs,
 struct TransposeParam : public dmlc::Parameter<TransposeParam> {
+  // Target axis order. The default, mxnet::TShape(0), has ndim() == 0, which
+  // Transpose and TransposeShape treat as "reverse all axes".
   mxnet::TShape axes;
   DMLC_DECLARE_PARAMETER(TransposeParam) {
-    DMLC_DECLARE_FIELD(axes).set_default(mxnet::TShape())
+    DMLC_DECLARE_FIELD(axes).set_default(mxnet::TShape(0))
     .describe("Target axis order. By default the axes will be inverted.");
   }
 };
@@ -314,7 +314,7 @@ void Transpose(const nnvm::NodeAttrs& attrs,
                const std::vector<TBlob>& outputs) {
   const TransposeParam& param = nnvm::get<TransposeParam>(attrs.parsed);
   CHECK_EQ(req[0], kWriteTo) << "Transpose does not support inplace";
-  if (!mxnet::ndim_is_known(param.axes)) {
+  if (param.axes.ndim() == 0) {
     mxnet::TShape axes(inputs[0].ndim(), -1);
     for (int i = 0; i < axes.ndim(); ++i) {
       axes[i] = axes.ndim() - 1 - i;
@@ -334,7 +334,7 @@ inline bool TransposeShape(const nnvm::NodeAttrs& attrs,
   mxnet::TShape& shp = (*in_attrs)[0];
   CHECK_LE(shp.ndim(), 6U) << "Transpose support at most 6 dimensions";
   mxnet::TShape ret(shp.ndim(), -1);
-  if (!mxnet::ndim_is_known(param.axes)) {
+  if (param.axes.ndim() == 0) {
     for (int i = 0; i < shp.ndim(); ++i) {
       ret[i] = shp[shp.ndim()-1-i];
     }

Reply via email to