This is an automated email from the ASF dual-hosted git repository.
lanking pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 09c71bf Add Sparse NDArray support for Scala (#15378)
09c71bf is described below
commit 09c71bf3144b09a28b9d09d33703a3dcbf4ca9a5
Author: Lanking <[email protected]>
AuthorDate: Mon Jul 8 11:48:30 2019 -0700
Add Sparse NDArray support for Scala (#15378)
* add Sparse Support
* add imperative invoke sparse support
* add retain method and comments
* add getData method
* add Sparse NDIter test
* remove debug line
---
.../src/main/scala/org/apache/mxnet/DType.scala | 17 +-
.../src/main/scala/org/apache/mxnet/Executor.scala | 9 +-
.../src/main/scala/org/apache/mxnet/LibInfo.scala | 27 ++-
.../src/main/scala/org/apache/mxnet/NDArray.scala | 65 ++++++-
.../main/scala/org/apache/mxnet/SparseFormat.scala | 25 +++
.../scala/org/apache/mxnet/SparseNDArray.scala | 196 +++++++++++++++++++++
.../test/scala/org/apache/mxnet/NDArraySuite.scala | 16 ++
.../org/apache/mxnet/SparseNDArraySuite.scala | 93 ++++++++++
.../main/native/org_apache_mxnet_native_c_api.cc | 75 +++++++-
.../main/native/org_apache_mxnet_native_c_api.h | 48 ++++-
10 files changed, 543 insertions(+), 28 deletions(-)
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala
index f3a8e8e..1d5cc28 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala
@@ -24,26 +24,17 @@ object DType extends Enumeration {
val Float16 = Value(2, "float16")
val UInt8 = Value(3, "uint8")
val Int32 = Value(4, "int32")
+ val Int8 = Value(5, "int8")
+ val Int64 = Value(6, "int64")
val Unknown = Value(-1, "unknown")
+  /**
+   * Size in bytes of one element of the given dtype; DType.Unknown maps to 0.
+   * Every DType value, including the newly added Int8 and Int64, is matched below.
+   */
private[mxnet] def numOfBytes(dtype: DType): Int = {
dtype match {
- case DType.UInt8 => 1
+      // Int8 has the same one-byte width as UInt8
+ case DType.UInt8 | DType.Int8 => 1
case DType.Int32 => 4
case DType.Float16 => 2
case DType.Float32 => 4
- case DType.Float64 => 8
+      // Int64 has the same eight-byte width as Float64
+ case DType.Float64 | DType.Int64 => 8
case DType.Unknown => 0
}
}
- private[mxnet] def getType(dtypeStr: String): DType = {
- dtypeStr match {
- case "UInt8" => DType.UInt8
- case "Int32" => DType.Int32
- case "Float16" => DType.Float16
- case "Float32" => DType.Float32
- case "Float64" => DType.Float64
- case _ => throw new IllegalArgumentException(
- s"DType: $dtypeStr not found! please set it in DType.scala")
- }
- }
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
index b0fae0f..6365f9c 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Executor.scala
@@ -159,7 +159,14 @@ class Executor private[mxnet](private[mxnet] val handle:
ExecutorHandle,
private def getOutputs: Array[NDArray] = {
val ndHandles = ArrayBuffer[NDArrayHandle]()
checkCall(_LIB.mxExecutorOutputs(handle, ndHandles))
+    // Each output handle is wrapped exactly once, outside the NDArrayCollector
+    // (addToCollector = false). A sparse output must NOT be "converted" with
+    // asInstanceOf[SparseNDArray]: the wrapper's runtime class is NDArray, so that downcast
+    // throws ClassCastException. Callers can check `isSparse` on the returned arrays.
ndHandles.toArray.map(new NDArray(_, addToCollector = false))
}
/**
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
index 640ecf5..0ee6476 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/LibInfo.scala
@@ -31,13 +31,14 @@ private[mxnet] class LibInfo {
@native def mxListAllOpNames(names: ListBuffer[String]): Int
@native def nnGetOpHandle(opName: String, opHandle: RefLong): Int
// NDArray
- @native def mxImperativeInvoke(creator: FunctionHandle,
+  // Invokes an operator. The JNI layer appends each output handle to `outputs` and the
+  // matching storage-type id (a SparseFormat id) to `outStype`.
+ @native def mxImperativeInvokeEx(creator: FunctionHandle,
inputs: Array[NDArrayHandle],
outputsGiven: Array[NDArrayHandle],
outputs: ArrayBuffer[NDArrayHandle],
numParams: Int,
paramKeys: Array[String],
- paramVals: Array[String]): Int
+ paramVals: Array[String],
+ outStype: ArrayBuffer[Int]): Int
@native def mxNDArrayFree(handle: NDArrayHandle): Int
@native def mxNDArrayCreateNone(out: NDArrayHandleRef): Int
@native def mxNDArrayCreateEx(shape: Array[Int],
@@ -47,6 +48,20 @@ private[mxnet] class LibInfo {
delayAlloc: Int,
dtype: Int,
out: NDArrayHandleRef): Int
+ // scalastyle:off parameterNum
+  // Allocates an NDArray with sparse storage (storageType is a SparseFormat id). The JNI
+  // layer forwards every argument to MXNDArrayCreateSparseEx and writes the handle to `out`.
+ @native def mxNDArrayCreateSparseEx(storageType: Int,
+ shape: Array[Int],
+ ndim: Int,
+ devType: Int,
+ devId: Int,
+ delayAlloc: Int,
+ dtype: Int,
+ numAux: Int,
+ auxTypes: Array[Int],
+ auxNdims: Array[Int],
+ auxShapes: Array[Int],
+ out: NDArrayHandleRef): Int
+ // scalastyle:on parameterNum
@native def mxNDArrayWaitAll(): Int
@native def mxNDArrayWaitToRead(handle: NDArrayHandle): Int
@native def mxListFunctions(functions: ListBuffer[FunctionHandle]): Int
@@ -76,6 +91,9 @@ private[mxnet] class LibInfo {
@native def mxNDArrayGetShape(handle: NDArrayHandle,
ndim: MXUintRef,
data: ArrayBuffer[Int]): Int
+  // Copies handleSrc into handleDst via MXNDArraySyncCopyFromNDArray. SparseNDArray passes
+  // locator -1 for the data array and an aux-array index (0 or 1) otherwise.
+ @native def mxNDArraySyncCopyFromNDArray(handleDst: NDArrayHandle,
+ handleSrc: NDArrayHandle,
+ locator: Int): Int
@native def mxNDArraySyncCopyToCPU(handle: NDArrayHandle,
data: Array[Byte],
size: Int): Int
@@ -105,10 +123,15 @@ private[mxnet] class LibInfo {
@native def mxNDArraySave(fname: String,
handles: Array[NDArrayHandle],
keys: Array[String]): Int
+  // Write into `out` a new handle for the data array / the aux array at `location`
+  // of a sparse NDArray.
+ @native def mxNDArrayGetDataNDArray(handle: NDArrayHandle, out:
NDArrayHandleRef): Int
+ @native def mxNDArrayGetAuxNDArray(handle: NDArrayHandle,
+ location: Int,
+ out: NDArrayHandleRef): Int
@native def mxNDArrayGetContext(handle: NDArrayHandle, devTypeId: RefInt,
devId: RefInt): Int
@native def mxNDArraySaveRawBytes(handle: NDArrayHandle, buf:
ArrayBuffer[Byte]): Int
@native def mxNDArrayLoadFromRawBytes(bytes: Array[Byte], handle:
NDArrayHandleRef): Int
@native def mxNDArrayGetDType(handle: NDArrayHandle, dtype: RefInt): Int
+  // Reads the storage-type id (a SparseFormat id) of `handle` into `stype`.
+ @native def mxNDArrayGetStorageType(handle: NDArrayHandle, stype: RefInt):
Int
// KVStore Server
@native def mxInitPSEnv(keys: Array[String], values: Array[String]): Int
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
index 4088801..1b7b31b 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/NDArray.scala
@@ -21,7 +21,8 @@ import java.nio.{ByteBuffer, ByteOrder}
import org.apache.mxnet.Base._
import org.apache.mxnet.DType.DType
-import org.apache.mxnet.MX_PRIMITIVES.{MX_PRIMITIVE_TYPE}
+import org.apache.mxnet.MX_PRIMITIVES.MX_PRIMITIVE_TYPE
+import org.apache.mxnet.SparseFormat.SparseFormat
import org.slf4j.LoggerFactory
import scala.collection.mutable
@@ -113,10 +114,22 @@ object NDArray extends NDArrayBase {
}
val outputs = ArrayBuffer.empty[NDArrayHandle]
- checkCall(_LIB.mxImperativeInvoke(function.handle,
ndArgs.map(_.handle).toArray, outputVars,
- outputs, updatedKwargs.size, updatedKwargs.keys.toArray,
updatedKwargs.values.toArray))
+ val outStypes = ArrayBuffer.empty[Int]
+ checkCall(_LIB.mxImperativeInvokeEx(function.handle,
+ ndArgs.map(_.handle).toArray,
+ outputVars,
+ outputs,
+ updatedKwargs.size,
+ updatedKwargs.keys.toArray,
+ updatedKwargs.values.toArray,
+ outStypes))
new NDArrayFuncReturn(Option(oriOutputs).getOrElse {
- val outputArrs = outputs.map(new NDArray(_)).toArray
+ val outputArrs = (outputs zip outStypes).map(
+ ele => ele._2 match {
+ case 0 => new NDArray(ele._1)
+ case _ => new SparseNDArray(ele._1)
+ }
+ ).toArray
addDependency(ndArgs.toArray, outputArrs)
outputArrs
})
@@ -943,6 +956,12 @@ class NDArray private[mxnet](private[mxnet] val handle:
NDArrayHandle,
DType(mxDtype.value)
}
+  /**
+   * Storage format of this array. It is queried from the native handle once, when the
+   * object is constructed; SparseFormat.DEFAULT means dense storage.
+   */
+ val sparseFormat: SparseFormat = {
+ val mxSF = new RefInt
+ checkCall(_LIB.mxNDArrayGetStorageType(handle, mxSF))
+ SparseFormat(mxSF.value)
+ }
+
/**
* Return a copied numpy array of current array with specified type.
* @param dtype Desired type of result array.
@@ -1309,6 +1328,30 @@ class NDArray private[mxnet](private[mxnet] val handle:
NDArrayHandle,
if (this.context == context) this else this.copyTo(context)
}
+  /**
+   * Check whether this array uses a sparse storage format.
+   * @return true unless the storage format is SparseFormat.DEFAULT (dense)
+   */
+  def isSparse: Boolean = {
+    this.sparseFormat != SparseFormat.DEFAULT
+  }
+
+  /**
+   * Convert this NDArray to a SparseNDArray.
+   *
+   * @param sfOption the target sparse format. When omitted, an array that is already sparse
+   *                 keeps its current format and a dense array becomes ROW_SPARSE.
+   * @return SparseNDArray
+   * @throws IllegalArgumentException if the requested format is SparseFormat.DEFAULT
+   */
+  def toSparse(sfOption : Option[SparseFormat] = None): SparseNDArray = {
+    val sf = sfOption.getOrElse(SparseFormat.ROW_SPARSE)
+    if (sf == SparseFormat.DEFAULT) throw new IllegalArgumentException("Require Sparse")
+    this match {
+      case sparse: SparseNDArray if sfOption.isEmpty && sparse.isSparse => sparse
+      case _ =>
+        // A plain NDArray wrapper around sparse storage cannot be downcast with
+        // asInstanceOf (its runtime class is NDArray, so that throws ClassCastException).
+        // cast_storage instead returns a genuine SparseNDArray, in the current format
+        // when no target was requested.
+        val target = if (sfOption.isEmpty && isSparse) sparseFormat else sf
+        NDArray.api.cast_storage(this, target.toString).head.asInstanceOf[SparseNDArray]
+    }
+  }
+
override def equals(o: Any): Boolean = o match {
case that: NDArray =>
that != null && that.shape == this.shape &&
that.toArray.sameElements(this.toArray)
@@ -1479,6 +1522,7 @@ private[mxnet] class NDArrayInternal (private val
internal: Array[Byte], private
case DType.Float32 => units.map(wrapBytes(_).getFloat.toDouble)
case DType.Float64 => units.map(wrapBytes(_).getDouble)
case DType.Int32 => units.map(wrapBytes(_).getInt.toDouble)
+ case DType.Int64 => units.map(wrapBytes(_).getLong.toDouble)
case DType.UInt8 => internal.map(_.toDouble)
}
}
@@ -1488,6 +1532,7 @@ private[mxnet] class NDArrayInternal (private val
internal: Array[Byte], private
case DType.Float32 => units.map(wrapBytes(_).getFloat)
case DType.Float64 => units.map(wrapBytes(_).getDouble.toFloat)
case DType.Int32 => units.map(wrapBytes(_).getInt.toFloat)
+ case DType.Int64 => units.map(wrapBytes(_).getLong.toFloat)
case DType.UInt8 => internal.map(_.toFloat)
}
}
@@ -1497,15 +1542,27 @@ private[mxnet] class NDArrayInternal (private val
internal: Array[Byte], private
case DType.Float32 => units.map(wrapBytes(_).getFloat.toInt)
case DType.Float64 => units.map(wrapBytes(_).getDouble.toInt)
case DType.Int32 => units.map(wrapBytes(_).getInt)
+ case DType.Int64 => units.map(wrapBytes(_).getLong.toInt)
case DType.UInt8 => internal.map(_.toInt)
}
}
+  /**
+   * Convert the raw bytes to Long values according to `dtype`.
+   * @throws IllegalArgumentException for Float16, which has no native JVM type
+   */
+  def toLongArray: Array[Long] = {
+    require(dtype != DType.Float16, "Currently cannot convert float16 to native numerical types")
+    dtype match {
+      case DType.Float32 => units.map(wrapBytes(_).getFloat.toLong)
+      case DType.Float64 => units.map(wrapBytes(_).getDouble.toLong)
+      case DType.Int32 => units.map(wrapBytes(_).getInt.toLong)
+      case DType.Int64 => units.map(wrapBytes(_).getLong)
+      // JVM bytes are signed: mask so UInt8 values 128..255 are not read back as negatives
+      case DType.UInt8 => internal.map(_ & 0xFFL)
+      // Int8 is one signed byte per element, so plain sign extension is correct
+      case DType.Int8 => internal.map(_.toLong)
+    }
+  }
def toByteArray: Array[Byte] = {
require(dtype != DType.Float16, "Currently cannot convert float16 to native numerical types")
dtype match {
case DType.Float16 | DType.Float32 =>
units.map(wrapBytes(_).getFloat.toByte)
case DType.Float64 => units.map(wrapBytes(_).getDouble.toByte)
case DType.Int32 => units.map(wrapBytes(_).getInt.toByte)
+ case DType.Int64 => units.map(wrapBytes(_).getLong.toByte)
+      // Int8 data already is one byte per element (previously a MatchError)
+      case DType.Int8 => internal.clone()
case DType.UInt8 => internal.clone()
}
}
diff --git
a/scala-package/core/src/main/scala/org/apache/mxnet/SparseFormat.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/SparseFormat.scala
new file mode 100644
index 0000000..acb0c0f
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/SparseFormat.scala
@@ -0,0 +1,25 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+/**
+ * Storage formats of an NDArray. The ids are the storage-type codes read back through
+ * mxNDArrayGetStorageType, and the names are the strings handed to cast_storage.
+ */
+object SparseFormat extends Enumeration {
+ type SparseFormat = Value
+  // dense storage
+ val DEFAULT = Value(0, "default")
+  // only some rows are stored, together with their row indices
+ val ROW_SPARSE = Value(1, "row_sparse")
+  // compressed sparse row: values plus indptr and column-index aux arrays
+ val CSR = Value(2, "csr")
+}
diff --git
a/scala-package/core/src/main/scala/org/apache/mxnet/SparseNDArray.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/SparseNDArray.scala
new file mode 100644
index 0000000..f3fe638
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/SparseNDArray.scala
@@ -0,0 +1,196 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import org.apache.mxnet.Base.{NDArrayHandle, NDArrayHandleRef, checkCall, _LIB}
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet.SparseFormat.SparseFormat
+
+object SparseNDArray {
+  /**
+   * Create a Compressed Sparse Row Storage (CSR) Format Matrix
+   * @param data the non-zero values, listed row by row
+   * @param indices The indices array stores the column index for each non-zero element in data
+   * @param indptr The indptr array identifies the rows where the data appears: the non-zeros
+   *               of row i are data(indptr(i)) until data(indptr(i + 1))
+   * @param shape the shape of CSR NDArray to be created
+   * @param ctx the context of this NDArray
+   * @return SparseNDArray
+   */
+  def csrMatrix(data: Array[Float], indices: Array[Float],
+                indptr: Array[Float], shape: Shape, ctx: Context): SparseNDArray = {
+    val fmt = SparseFormat.CSR
+    val dataND = NDArray.array(data, Shape(data.length), ctx)
+    // aux arrays are stored as Int64
+    val indicesND = NDArray.array(indices, Shape(indices.length), ctx).asType(DType.Int64)
+    val indptrND = NDArray.array(indptr, Shape(indptr.length), ctx).asType(DType.Int64)
+    // CSR aux order: 0 = indptr, 1 = indices
+    val dTypes = Array(indptrND.dtype, indicesND.dtype)
+    val shapes = Array(indptrND.shape, indicesND.shape)
+    val handle =
+      newAllocHandle(fmt, shape, ctx, false, DType.Float32, dTypes, shapes)
+    checkCall(_LIB.mxNDArraySyncCopyFromNDArray(handle, dataND.handle, -1))
+    checkCall(_LIB.mxNDArraySyncCopyFromNDArray(handle, indptrND.handle, 0))
+    checkCall(_LIB.mxNDArraySyncCopyFromNDArray(handle, indicesND.handle, 1))
+    new SparseNDArray(handle)
+  }
+
+  /**
+   * RowSparseNDArray stores the matrix in row sparse format,
+   * which is designed for arrays of which most row slices are all zeros
+   * @param data Any Array(Array(... Array(Float)))
+   * @param indices the indices to store the data
+   * @param shape shape of the NDArray
+   * @param ctx Context
+   * @return SparseNDArray
+   */
+  def rowSparseArray(data: Array[_], indices: Array[Float],
+                     shape: Shape, ctx: Context): SparseNDArray = {
+    val dataND = NDArray.toNDArray(data)
+    val indicesND = NDArray.array(indices, Shape(indices.length), ctx).asType(DType.Int64)
+    rowSparseArray(dataND, indicesND, shape, ctx)
+  }
+
+  /**
+   * RowSparseNDArray stores the matrix in row sparse format,
+   * which is designed for arrays of which most row slices are all zeros
+   * @param data NDArray input; its dtype becomes the dtype of the sparse array
+   * @param indices in NDArray. Only DType.Int64 supported
+   * @param shape shape of the NDArray
+   * @param ctx Context
+   * @return SparseNDArray
+   */
+  def rowSparseArray(data: NDArray, indices: NDArray,
+                     shape: Shape, ctx: Context): SparseNDArray = {
+    val fmt = SparseFormat.ROW_SPARSE
+    // The sparse array takes the element type of `data` rather than always Float32,
+    // so that e.g. Float64 input is stored as Float64.
+    val handle = newAllocHandle(fmt, shape, ctx, false,
+      data.dtype, Array(indices.dtype), Array(indices.shape))
+    checkCall(_LIB.mxNDArraySyncCopyFromNDArray(handle, data.handle, -1))
+    checkCall(_LIB.mxNDArraySyncCopyFromNDArray(handle, indices.handle, 0))
+    new SparseNDArray(handle)
+  }
+
+  /**
+   * Keep only the given rows of a row-sparse array (via the _sparse_retain operator).
+   * @param sparseNDArray a ROW_SPARSE array
+   * @param indices the row indices to retain
+   * @return SparseNDArray holding only the retained rows
+   * @throws IllegalArgumentException if the input is CSR
+   */
+  def retain(sparseNDArray: SparseNDArray, indices: Array[Float]): SparseNDArray = {
+    if (sparseNDArray.sparseFormat == SparseFormat.CSR) {
+      throw new IllegalArgumentException("CSR not supported")
+    }
+    // build the index array on the input's context rather than the default context
+    val indicesND = NDArray.array(indices, Shape(indices.length), sparseNDArray.context)
+    NDArray.genericNDArrayFunctionInvoke("_sparse_retain",
+      Seq(sparseNDArray, indicesND)).head.toSparse()
+  }
+
+  /**
+   * Allocate a native sparse NDArray handle.
+   * @param auxDTypes dtype of each aux array, in aux-index order
+   * @param auxShapes shape of each aux array, in aux-index order
+   */
+  private def newAllocHandle(stype : SparseFormat,
+                             shape: Shape,
+                             ctx: Context,
+                             delayAlloc: Boolean,
+                             dtype: DType = DType.Float32,
+                             auxDTypes: Array[DType],
+                             auxShapes: Array[Shape]) : NDArrayHandle = {
+    val hdl = new NDArrayHandleRef
+    checkCall(_LIB.mxNDArrayCreateSparseEx(
+      stype.id,
+      shape.toArray,
+      shape.length,
+      ctx.deviceTypeid,
+      ctx.deviceId,
+      if (delayAlloc) 1 else 0,
+      dtype.id,
+      auxDTypes.length,
+      auxDTypes.map(_.id),
+      auxShapes.map(_.length),
+      // every dimension of every aux shape, concatenated, so that it agrees with the
+      // per-array ndims above (previously only the first dimension was sent)
+      auxShapes.flatMap(_.toArray),
+      hdl)
+    )
+    hdl.value
+  }
+}
+
+/**
+ * Sparse NDArray is the child class of NDArray designed to hold the Sparse format
+ *
+ * <p> Currently, RowSparse and CSR typed NDArray are supported. Most of the Operators
+ * will convert Sparse NDArray to dense. Basic operators like <code>add</code> will
+ * have optimization for sparse operations</p>
+ * @param handle The pointer that SparseNDArray holds
+ * @param writable whether the NDArray is writable
+ */
+class SparseNDArray private[mxnet] (override private[mxnet] val handle: NDArrayHandle,
+                                    override val writable: Boolean = true)
+  extends NDArray(handle, writable) {
+
+  // NOTE: the dense form is rebuilt on every call instead of being cached in a lazy val.
+  // A cached copy would go stale once this (writable) array is modified, and it would
+  // keep a full dense copy alive for the lifetime of the sparse array.
+
+  override def toString: String = {
+    val dense = toDense
+    try dense.toString finally dense.dispose()
+  }
+
+  /**
+   * Convert a SparseNDArray to dense NDArray
+   * @return NDArray
+   */
+  def toDense: NDArray = {
+    NDArray.api.cast_storage(this, SparseFormat.DEFAULT.toString).head
+  }
+
+  override def toArray: Array[Float] = {
+    val dense = toDense
+    try dense.toArray finally dense.dispose()
+  }
+
+  override def at(idx: Int): NDArray = {
+    // the returned row comes from a fresh dense copy, so that copy is not disposed here
+    toDense.at(idx)
+  }
+
+  override def slice(start: Int, end: Int): NDArray = {
+    NDArray.api.slice(this, Shape(start), Shape(end))
+  }
+
+  /**
+   * Get the Data portion from a Row Sparse NDArray
+   * @return NDArray
+   */
+  def getData: NDArray = {
+    require(this.sparseFormat == SparseFormat.ROW_SPARSE, "Not Supported for CSR")
+    val handle = new NDArrayHandleRef
+    // checkCall surfaces native failures instead of wrapping an unset handle
+    checkCall(_LIB.mxNDArrayGetDataNDArray(this.handle, handle))
+    new NDArray(handle.value, false)
+  }
+
+  /**
+   * Get the indptr Array from a CSR NDArray
+   * @return NDArray
+   */
+  def getIndptr: NDArray = {
+    require(this.sparseFormat == SparseFormat.CSR, "Not Supported for row sparse")
+    getAuxNDArray(0)
+  }
+
+  /**
+   * Get the indice Array (aux array 0 for ROW_SPARSE, aux array 1 for CSR)
+   * @return NDArray
+   */
+  def getIndices: NDArray = {
+    if (this.sparseFormat == SparseFormat.ROW_SPARSE) {
+      getAuxNDArray(0)
+    } else {
+      getAuxNDArray(1)
+    }
+  }
+
+  private def getAuxNDArray(idx: Int): NDArray = {
+    val handle = new NDArrayHandleRef
+    checkCall(_LIB.mxNDArrayGetAuxNDArray(this.handle, idx, handle))
+    new NDArray(handle.value, false)
+  }
+
+}
diff --git
a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
index c2ef641..82b9edc 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/NDArraySuite.scala
@@ -45,6 +45,22 @@ class NDArraySuite extends FunSuite with BeforeAndAfterAll
with Matchers {
assert(ndones.toScalar === 1f)
}
+  test("to sparse") {
+    // identity-like 3x3 matrix: exactly one non-zero per row
+    val dense = NDArray.toNDArray(Array(
+      Array(1f, 0f, 0f),
+      Array(0f, 3f, 0f),
+      Array(0f, 0f, 1f)
+    ))
+    assert(!dense.isSparse)
+    // row sparse is the default target format
+    val rowSparse = dense.toSparse()
+    assert(rowSparse.getIndices.toArray sameElements Array(0f, 1f, 2f))
+    // csr
+    val csr = dense.toSparse(Some(SparseFormat.CSR))
+    assert(csr.getIndptr.toArray sameElements Array(0f, 1f, 2f, 3f))
+  }
+
test("to float 64 scalar") {
val ndzeros = NDArray.zeros(Shape(1), dtype = DType.Float64)
assert(ndzeros.toFloat64Scalar === 0d)
diff --git
a/scala-package/core/src/test/scala/org/apache/mxnet/SparseNDArraySuite.scala
b/scala-package/core/src/test/scala/org/apache/mxnet/SparseNDArraySuite.scala
new file mode 100644
index 0000000..f9968ef
--- /dev/null
+++
b/scala-package/core/src/test/scala/org/apache/mxnet/SparseNDArraySuite.scala
@@ -0,0 +1,93 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+import org.apache.mxnet.io.NDArrayIter
+import org.scalatest.FunSuite
+import org.slf4j.LoggerFactory
+
+class SparseNDArraySuite extends FunSuite {
+
+  private val logger = LoggerFactory.getLogger(classOf[SparseNDArraySuite])
+
+  test("create CSR NDArray") {
+    val data = Array(7f, 8f, 9f)
+    val indices = Array(0f, 2f, 1f)
+    val indptr = Array(0f, 2f, 2f, 3f)
+    val shape = Shape(3, 4)
+    val sparseND = SparseNDArray.csrMatrix(data, indices, indptr, shape, Context.cpu())
+    assert(sparseND.shape == Shape(3, 4))
+    assert(sparseND.toArray
+      sameElements Array(7.0f, 0.0f, 8.0f, 0.0f,
+                         0.0f, 0.0f, 0.0f, 0.0f,
+                         0.0f, 9.0f, 0.0f, 0.0f))
+    assert(sparseND.sparseFormat == SparseFormat.CSR)
+    assert(sparseND.getIndptr.toArray sameElements indptr)
+    assert(sparseND.getIndices.toArray sameElements indices)
+  }
+
+  test("create Row Sparse NDArray") {
+    val data = Array(
+      Array(1f, 2f),
+      Array(3f, 4f)
+    )
+    val indices = Array(1f, 4f)
+    val shape = Shape(6, 2)
+    val sparseND = SparseNDArray.rowSparseArray(data, indices, shape, Context.cpu())
+    assert(sparseND.sparseFormat == SparseFormat.ROW_SPARSE)
+    assert(sparseND.shape == Shape(6, 2))
+    assert(sparseND.at(1).toArray sameElements Array(1f, 2f))
+    assert(sparseND.getIndices.toArray sameElements indices)
+  }
+
+  test("Test retain") {
+    val arr = Array(
+      Array(1f, 2f),
+      Array(3f, 4f),
+      Array(5f, 6f)
+    )
+    val indices = Array(0f, 1f, 3f)
+    val rspIn = SparseNDArray.rowSparseArray(arr, indices, Shape(4, 2), Context.cpu())
+    val toRetain = Array(0f, 3f)
+    val rspOut = SparseNDArray.retain(rspIn, toRetain)
+    assert(rspOut.getData.toArray sameElements Array(1f, 2f, 5f, 6f))
+    assert(rspOut.getIndices.toArray sameElements Array(0f, 3f))
+  }
+
+  test("Test add") {
+    val nd = NDArray.array(Array(1f, 2f, 3f), Shape(3)).toSparse(Some(SparseFormat.ROW_SPARSE))
+    val nd2 = nd + nd
+    assert(nd2.isInstanceOf[SparseNDArray])
+    assert(nd2.toArray sameElements Array(2f, 4f, 6f))
+  }
+
+  test("Test DataIter") {
+    val nd = NDArray.array(Array(1f, 2f, 3f), Shape(1, 3)).toSparse(Some(SparseFormat.CSR))
+    val arr = IndexedSeq(nd, nd, nd, nd)
+    val iter = new NDArrayIter(arr)
+    var numBatches = 0
+    while (iter.hasNext) {
+      val tempArr = iter.next().data
+      tempArr.foreach(ele => {
+        assert(ele.sparseFormat == SparseFormat.CSR)
+        assert(ele.shape == Shape(1, 3))
+      })
+      numBatches += 1
+    }
+    // without this, an iterator that yields nothing would make the checks above pass vacuously
+    assert(numBatches > 0)
+  }
+}
diff --git
a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
index 9b19fd3..387a0b1 100644
--- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
+++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.cc
@@ -93,6 +93,31 @@ JNIEXPORT jint JNICALL
Java_org_apache_mxnet_LibInfo_mxNDArrayCreateEx
return ret;
}
+// Creates a sparse NDArray through MXNDArrayCreateSparseEx and stores the new handle in
+// the ndArrayHandle (RefLong) object.
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayCreateSparseEx
+  (JNIEnv *env, jobject obj, jint storageType, jintArray shape, jint ndim, jint devType,
+   jint devId, jint delayAlloc, jint dtype, jint numAux, jintArray auxTypes,
+   jintArray auxNdims, jintArray auxShapes, jobject ndArrayHandle) {
+  jint *shapeArr = env->GetIntArrayElements(shape, NULL);
+  jint *auxTypesArr = env->GetIntArrayElements(auxTypes, NULL);
+  jint *auxNdimsArr = env->GetIntArrayElements(auxNdims, NULL);
+  jint *auxShapesArr = env->GetIntArrayElements(auxShapes, NULL);
+  // Initialized so that a failed call stores a null handle instead of an indeterminate value.
+  NDArrayHandle out = NULL;
+  int ret = MXNDArrayCreateSparseEx(storageType,
+                                    reinterpret_cast<const mx_uint *>(shapeArr),
+                                    static_cast<mx_uint>(ndim),
+                                    devType, devId, delayAlloc, dtype,
+                                    static_cast<mx_uint>(numAux),
+                                    reinterpret_cast<int *>(auxTypesArr),
+                                    reinterpret_cast<mx_uint *>(auxNdimsArr),
+                                    reinterpret_cast<const mx_uint *>(auxShapesArr), &out);
+  // All four Java arrays are read-only inputs: JNI_ABORT frees the native buffers
+  // without copying them back into the Java arrays.
+  env->ReleaseIntArrayElements(shape, shapeArr, JNI_ABORT);
+  env->ReleaseIntArrayElements(auxTypes, auxTypesArr, JNI_ABORT);
+  env->ReleaseIntArrayElements(auxNdims, auxNdimsArr, JNI_ABORT);
+  env->ReleaseIntArrayElements(auxShapes, auxShapesArr, JNI_ABORT);
+  SetLongField(env, ndArrayHandle, reinterpret_cast<jlong>(out));
+  return ret;
+}
+
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayWaitAll(JNIEnv
*env, jobject obj) {
return MXNDArrayWaitAll();
}
@@ -179,10 +204,10 @@ JNIEXPORT jint JNICALL
Java_org_apache_mxnet_LibInfo_mxFuncGetInfo
return ret;
}
-JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxImperativeInvoke
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxImperativeInvokeEx
(JNIEnv *env, jobject obj, jlong funcPtr, jlongArray inputs,
jlongArray outputsGiven, jobject outputs, jint numParams,
- jobjectArray paramKeys, jobjectArray paramVals) {
+ jobjectArray paramKeys, jobjectArray paramVals, jobject outStypes) {
const char **cParamKeys = NULL;
const char **cParamVals = NULL;
@@ -204,6 +229,7 @@ JNIEXPORT jint JNICALL
Java_org_apache_mxnet_LibInfo_mxImperativeInvoke
int numOutputs = 0;
jlong *cOutputsGiven = NULL;
NDArrayHandle *cOutputs = NULL;
+ const int *cOutStypes;
if (outputsGiven) {
cOutputsGiven = env->GetLongArrayElements(outputsGiven, NULL);
cOutputs = reinterpret_cast<NDArrayHandle *>(cOutputsGiven);
@@ -211,14 +237,15 @@ JNIEXPORT jint JNICALL
Java_org_apache_mxnet_LibInfo_mxImperativeInvoke
}
jlong *cInputs = env->GetLongArrayElements(inputs, NULL);
jsize numInputs = env->GetArrayLength(inputs);
- int ret = MXImperativeInvoke(reinterpret_cast<AtomicSymbolCreator>(funcPtr),
+ int ret =
MXImperativeInvokeEx(reinterpret_cast<AtomicSymbolCreator>(funcPtr),
static_cast<int>(numInputs),
reinterpret_cast<NDArrayHandle *>(cInputs),
&numOutputs,
&cOutputs,
static_cast<int>(numParams),
cParamKeys,
- cParamVals);
+ cParamVals,
+ &cOutStypes);
env->ReleaseLongArrayElements(inputs, cInputs, 0);
if (cOutputsGiven) {
env->ReleaseLongArrayElements(outputsGiven, cOutputsGiven, 0);
@@ -240,7 +267,9 @@ JNIEXPORT jint JNICALL
Java_org_apache_mxnet_LibInfo_mxImperativeInvoke
if (cOutputs) {
jclass longCls = env->FindClass("java/lang/Long");
+ jclass intCls = env->FindClass("java/lang/Integer");
jmethodID longConst = env->GetMethodID(longCls, "<init>", "(J)V");
+ jmethodID intConst = env->GetMethodID(intCls, "<init>", "(I)V");
// scala.collection.mutable.ListBuffer append method
jclass listClass = env->FindClass("scala/collection/mutable/ArrayBuffer");
jmethodID listAppend = env->GetMethodID(listClass, "$plus$eq",
@@ -249,6 +278,9 @@ JNIEXPORT jint JNICALL
Java_org_apache_mxnet_LibInfo_mxImperativeInvoke
env->CallObjectMethod(outputs, listAppend,
env->NewObject(longCls, longConst,
reinterpret_cast<uint64_t>(cOutputs[i])));
+ env->CallObjectMethod(outStypes, listAppend,
+ env->NewObject(intCls, intConst,
+ cOutStypes[i]));
}
}
@@ -379,6 +411,14 @@ JNIEXPORT jint JNICALL
Java_org_apache_mxnet_LibInfo_mxNDArrayGetShape
return ret;
}
+// Copies the NDArray behind srcPtr into the one behind dstPtr. `locator` is passed through
+// unchanged (the Scala callers use -1 for the data array and >= 0 for an aux array).
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySyncCopyFromNDArray
+  (JNIEnv *env, jobject obj, jlong dstPtr, jlong srcPtr, jint locator) {
+  NDArrayHandle dst = reinterpret_cast<NDArrayHandle>(dstPtr);
+  NDArrayHandle src = reinterpret_cast<NDArrayHandle>(srcPtr);
+  return MXNDArraySyncCopyFromNDArray(dst, src, locator);
+}
+
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArraySyncCopyToCPU
(JNIEnv *env, jobject obj, jlong ndArrayPtr, jbyteArray data, jint size) {
jbyte *pdata = env->GetByteArrayElements(data, NULL);
@@ -434,6 +474,25 @@ JNIEXPORT jint JNICALL
Java_org_apache_mxnet_LibInfo_mxFloat64NDArraySyncCopyFro
return ret;
}
+// Stores in ndArrayHandle a new handle for the data array of the sparse NDArray at arrayPtr.
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetDataNDArray
+  (JNIEnv *env, jobject obj, jlong arrayPtr, jobject ndArrayHandle) {
+  // Initialized so that a failed call stores a null handle instead of an indeterminate value.
+  NDArrayHandle out = NULL;
+  int ret = MXNDArrayGetDataNDArray(reinterpret_cast<NDArrayHandle>(arrayPtr),
+                                    &out);
+  SetLongField(env, ndArrayHandle, reinterpret_cast<jlong>(out));
+  return ret;
+}
+
+// Stores in ndArrayHandle a new handle for aux array `location` of the sparse NDArray.
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetAuxNDArray
+  (JNIEnv *env, jobject obj, jlong arrayPtr, jint location, jobject ndArrayHandle) {
+  // Initialized so that a failed call stores a null handle instead of an indeterminate value.
+  NDArrayHandle out = NULL;
+  int ret = MXNDArrayGetAuxNDArray(reinterpret_cast<NDArrayHandle>(arrayPtr),
+                                   static_cast<mx_uint>(location),
+                                   &out);
+  SetLongField(env, ndArrayHandle, reinterpret_cast<jlong>(out));
+  return ret;
+}
+
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetContext
(JNIEnv *env, jobject obj, jlong arrayPtr, jobject devTypeId, jobject devId)
{
int outDevType;
@@ -540,6 +599,14 @@ JNIEXPORT jint JNICALL
Java_org_apache_mxnet_LibInfo_mxNDArrayGetDType
return ret;
}
+// Reads the storage-type id of the NDArray at jhandle into the RefInt jstype.
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetStorageType
+  (JNIEnv * env, jobject obj, jlong jhandle, jobject jstype) {
+  // -1 (not a valid SparseFormat id) is stored if the call fails, instead of an
+  // uninitialized value.
+  int stype = -1;
+  int ret = MXNDArrayGetStorageType(reinterpret_cast<NDArrayHandle>(jhandle), &stype);
+  SetIntField(env, jstype, stype);
+  return ret;
+}
+
JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxInitPSEnv
(JNIEnv *env, jobject obj, jobjectArray jkeys, jobjectArray jvals) {
// keys and values
diff --git
a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
index fac32bb..c8ee0ce 100644
--- a/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
+++ b/scala-package/native/src/main/native/org_apache_mxnet_native_c_api.h
@@ -41,11 +41,11 @@ JNIEXPORT jint JNICALL
Java_org_apache_mxnet_LibInfo_nnGetOpHandle
/*
* Class: org_apache_mxnet_LibInfo
- * Method: mxImperativeInvoke
- * Signature:
(J[J[JLscala/collection/mutable/ArrayBuffer;I[Ljava/lang/String;[Ljava/lang/String;)I
+ * Method: mxImperativeInvokeEx
+ * Signature:
(J[J[JLscala/collection/mutable/ArrayBuffer;I[Ljava/lang/String;[Ljava/lang/String;Lscala/collection/mutable/ArrayBuffer;)I
+ * The trailing ArrayBuffer receives one storage-type id (Integer) per output handle.
*/
-JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxImperativeInvoke
- (JNIEnv *, jobject, jlong, jlongArray, jlongArray, jobject, jint,
jobjectArray, jobjectArray);
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxImperativeInvokeEx
+ (JNIEnv *, jobject, jlong, jlongArray, jlongArray, jobject, jint,
jobjectArray, jobjectArray, jobject);
/*
* Class: org_apache_mxnet_LibInfo
@@ -73,6 +73,14 @@ JNIEXPORT jint JNICALL
Java_org_apache_mxnet_LibInfo_mxNDArrayCreateEx
/*
* Class: org_apache_mxnet_LibInfo
+ * Method: mxNDArrayCreateSparseEx
+ * Signature: (I[IIIIIII[I[I[ILorg/apache/mxnet/Base/RefLong;)I
+ * Allocates a sparse NDArray and writes its handle into the trailing RefLong.
+ */
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayCreateSparseEx
+ (JNIEnv *, jobject, jint, jintArray, jint, jint, jint, jint, jint, jint,
jintArray, jintArray, jintArray, jobject);
+
+/*
+ * Class: org_apache_mxnet_LibInfo
* Method: mxNDArrayWaitAll
* Signature: ()I
*/
@@ -137,6 +145,14 @@ JNIEXPORT jint JNICALL
Java_org_apache_mxnet_LibInfo_mxNDArrayGetShape
/*
* Class: org_apache_mxnet_LibInfo
+ * Method: mxNDArraySyncCopyFromNDArray
+ * Signature: (JJI)I
+ * Copies the source NDArray into the destination; the int locator is forwarded unchanged
+ * (the Scala callers use -1 for the data array and >= 0 for an aux array).
+ */
+JNIEXPORT jint JNICALL
Java_org_apache_mxnet_LibInfo_mxNDArraySyncCopyFromNDArray
+ (JNIEnv *, jobject, jlong, jlong, jint);
+
+/*
+ * Class: org_apache_mxnet_LibInfo
* Method: mxNDArraySyncCopyToCPU
* Signature: (J[BI)I
*/
@@ -201,6 +217,22 @@ JNIEXPORT jint JNICALL
Java_org_apache_mxnet_LibInfo_mxNDArraySave
/*
* Class: org_apache_mxnet_LibInfo
+ * Method: mxNDArrayGetDataNDArray
+ * Signature: (JLorg/apache/mxnet/Base/RefLong;)I
+ * Writes a handle to the data array of a sparse NDArray into the RefLong.
+ */
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetDataNDArray
+ (JNIEnv *, jobject, jlong, jobject);
+
+/*
+ * Class: org_apache_mxnet_LibInfo
+ * Method: mxNDArrayGetAuxNDArray
+ * Signature: (JILorg/apache/mxnet/Base/RefLong;)I
+ * Writes a handle to the aux array at the given index into the RefLong.
+ */
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetAuxNDArray
+ (JNIEnv *, jobject, jlong, jint, jobject);
+
+/*
+ * Class: org_apache_mxnet_LibInfo
* Method: mxNDArrayGetContext
* Signature: (JLorg/apache/mxnet/Base/RefInt;Lorg/apache/mxnet/Base/RefInt;)I
*/
@@ -233,6 +265,14 @@ JNIEXPORT jint JNICALL
Java_org_apache_mxnet_LibInfo_mxNDArrayGetDType
/*
* Class: org_apache_mxnet_LibInfo
+ * Method: mxNDArrayGetStorageType
+ * Signature: (JLorg/apache/mxnet/Base/RefInt;)I
+ * Writes the storage-type id of the NDArray into the RefInt.
+ */
+JNIEXPORT jint JNICALL Java_org_apache_mxnet_LibInfo_mxNDArrayGetStorageType
+ (JNIEnv *, jobject, jlong, jobject);
+
+/*
+ * Class: org_apache_mxnet_LibInfo
* Method: mxInitPSEnv
* Signature: ([Ljava/lang/String;[Ljava/lang/String;)I
*/