This is an automated email from the ASF dual-hosted git repository.
lanking pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new b482a44 [MXNET-1379] update reshape operator (#14600)
b482a44 is described below
commit b482a44fa8cd932ed48d62bafadb11299c7cd381
Author: Lanking <[email protected]>
AuthorDate: Wed Apr 3 10:07:38 2019 -0700
[MXNET-1379] update reshape operator (#14600)
* update reshape operator
* Satisfy the Lint God =v=
* update the jni header signature
---
.../core/src/main/scala/org/apache/mxnet/LibInfo.scala | 5 +++--
.../core/src/main/scala/org/apache/mxnet/NDArray.scala | 13 ++++++++++++-
.../core/src/test/scala/org/apache/mxnet/NDArraySuite.scala | 8 ++++++--
.../native/src/main/native/org_apache_mxnet_native_c_api.cc | 13 +++++++------
.../native/src/main/native/org_apache_mxnet_native_c_api.h | 8 ++++----
5 files changed, 32 insertions(+), 15 deletions(-)
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
index 20b6ed9..40fc095 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
@@ -86,9 +86,10 @@ private[mxnet] class LibInfo {
@native def mxNDArrayAt(handle: NDArrayHandle,
idx: MXUint,
out: NDArrayHandleRef): Int
- @native def mxNDArrayReshape(handle: NDArrayHandle,
+ @native def mxNDArrayReshape64(handle: NDArrayHandle,
nDim: Int,
- dims: Array[Int],
+ dims: Array[Long],
+ reverse: Boolean,
reshapeHandle: NDArrayHandleRef): Int
@native def mxNDArraySyncCopyFromCPU(handle: NDArrayHandle,
source: Array[MXFloat],
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index 915e4c6..ab42265 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -950,8 +950,19 @@ class NDArray private[mxnet](private[mxnet] val handle: NDArrayHandle,
* @return a reshaped NDArray that shares memory with current one.
*/
def reshape(dims: Array[Int]): NDArray = {
+ reshape(dims.map(_.toLong))
+ }
+
+ /**
+ * Return a reshaped NDArray that shares memory with current one.
+ * @param dims New shape.
+ * @param reverse if true, special values in dims (e.g. 0, -1, -3) are inferred from right to left; defaults to false
+ * @return a reshaped NDArray that shares memory with current one.
+ */
+ def reshape(dims: Array[Long], reverse: Option[Boolean] = None): NDArray = {
val reshapeHandle = new NDArrayHandleRef
- checkCall(_LIB.mxNDArrayReshape(handle, dims.length, dims, reshapeHandle))
+ checkCall(_LIB.mxNDArrayReshape64(handle,
+ dims.length, dims, reverse.getOrElse(false), reshapeHandle))
new NDArray(handle = reshapeHandle.value, writable = this.writable)
}
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
index 206094c..c2ef641 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
@@ -878,14 +878,18 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll with Matchers {
}
test("reshape") {
- val arr = NDArray.array(Array(1f, 2f, 3f, 4f, 5f, 6f), shape = Shape(3, 2))
+ var arr = NDArray.array(Array(1f, 2f, 3f, 4f, 5f, 6f), shape = Shape(3, 2))
- val arr1 = arr.reshape(Array(2, 3))
+ var arr1 = arr.reshape(Array(2, 3))
assert(arr1.shape === Shape(2, 3))
assert(arr1.toArray === Array(1f, 2f, 3f, 4f, 5f, 6f))
arr.set(1f)
assert(arr1.toArray === Array(1f, 1f, 1f, 1f, 1f, 1f))
+
+ arr = NDArray.ones(1, 384, 1)
+ arr1 = arr.reshape(Array(0, -3))
+ assert(arr1.shape === Shape(1, 384))
}
test("dispose deps") {
diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
index ea6e9c8..33e4cca 100644
--- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
+++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
@@ -404,14 +404,15 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayAt
return ret;
}
-JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayReshape
- (JNIEnv *env, jobject obj, jlong ndArrayPtr, jint ndim, jintArray dims, jobject reshapedHandle) {
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayReshape64
+ (JNIEnv *env, jobject obj, jlong ndArrayPtr, jint ndim,
+ jlongArray dims, jboolean reverse, jobject reshapedHandle) {
NDArrayHandle out;
- jint *pdims = env->GetIntArrayElements(dims, NULL);
- int ret = MXNDArrayReshape(reinterpret_cast<NDArrayHandle>(ndArrayPtr), ndim,
- reinterpret_cast<int *>(pdims), &out);
+ jlong *pdims = env->GetLongArrayElements(dims, NULL);
+ int ret = MXNDArrayReshape64(reinterpret_cast<NDArrayHandle>(ndArrayPtr), ndim,
+ reinterpret_cast<dim_t *>(pdims), reverse, &out);
SetLongField(env, reshapedHandle, reinterpret_cast<jlong>(out));
- env->ReleaseIntArrayElements(dims, pdims, 0);
+ env->ReleaseLongArrayElements(dims, pdims, 0);
return ret;
}
diff --git a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
index 7e8e03d..b8a9b3b 100644
--- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
+++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
@@ -161,11 +161,11 @@ JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayAt
/*
* Class: org_apache_mxnet_LibInfo
- * Method: mxNDArrayReshape
- * Signature: (JI[ILorg/apache/mxnet/Base/RefLong;)I
+ * Method: mxNDArrayReshape64
+ * Signature: (JI[JZLorg/apache/mxnet/Base/RefLong;)I
*/
-JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayReshape
- (JNIEnv *, jobject, jlong, jint, jintArray, jobject);
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayReshape64
+ (JNIEnv *, jobject, jlong, jint, jlongArray, jboolean, jobject);
/*
* Class: org_apache_mxnet_LibInfo