This is an automated email from the ASF dual-hosted git repository.
nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new abbe283 [MXNET-689] add DataDesc type for the Scala Package (#11844)
abbe283 is described below
commit abbe283ef8b1d78b002cb492651f002ae27ba544
Author: Lanking <[email protected]>
AuthorDate: Fri Aug 17 10:07:30 2018 -0700
[MXNET-689] add DataDesc type for the Scala Package (#11844)
* add dataDesc
* Add amend
* add changes with dataLayout and labelLayout
* add deprecation and example changes
* Gan and Customop fixes
* change the DType
* add one more class to convert Strings to DTypes
* convert layout to global
* scala style fix
* Revert to 8c7d1f8
* fix coding style issue
* print full stacktraces
* apply changes to new constructor
* add databatch bcc
* introduce undefined field
* Fix crashes when change provideData to provideDataDesc
It looks like forcing a conversion from Float32 to Int32 will
cause a crash on the JVM. This needs to be addressed.
* change spacing and revert test
* apply DataDesc on DataBatch
* unit test for NDArrayIter and MXDataIter
* apply changes on CR
* change NDArrayIter and revert the rest
* revert change on examples
* apply final changes
* remove the provideLabelShape
* add TODO about the findings
---
.../src/main/scala/org/apache/mxnet/DType.scala | 11 ++
.../core/src/main/scala/org/apache/mxnet/IO.scala | 121 +++++++++++++++------
.../src/main/scala/org/apache/mxnet/Layout.scala} | 28 ++---
.../src/main/scala/org/apache/mxnet/RecordIO.scala | 5 +-
.../scala/org/apache/mxnet/io/MXDataIter.scala | 35 ++++--
.../scala/org/apache/mxnet/io/NDArrayIter.scala | 91 +++++++++++-----
.../org/apache/mxnet/io/PrefetchingIter.scala | 69 ++++++++++--
.../scala/org/apache/mxnet/io/ResizeIter.scala | 15 ++-
.../src/test/scala/org/apache/mxnet/IOSuite.scala | 18 ++-
.../test/scala/org/apache/mxnet/ModuleSuite.scala | 6 +-
.../org/apache/mxnetexamples/multitask/Data.scala | 3 -
.../mxnetexamples/multitask/ExampleMultiTask.scala | 31 ++++--
.../org/apache/mxnetexamples/rnn/BucketIo.scala | 54 ++++++---
.../apache/mxnet/infer/ObjectDetectorSuite.scala | 8 +-
.../org/apache/mxnet/infer/PredictorSuite.scala | 16 ++-
scala-package/pom.xml | 1 +
.../apache/mxnet/spark/io/LabeledPointIter.scala | 16 ++-
.../mxnet/spark/io/LongLivingDataBatch.scala | 6 +-
.../org/apache/mxnet/spark/io/PointIter.scala | 16 ++-
19 files changed, 405 insertions(+), 145 deletions(-)
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala
index 4458a7c..f3a8e8e 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala
@@ -35,4 +35,15 @@ object DType extends Enumeration {
case DType.Unknown => 0
}
}
+ private[mxnet] def getType(dtypeStr: String): DType = {
+ dtypeStr match {
+ case "UInt8" => DType.UInt8
+ case "Int32" => DType.Int32
+ case "Float16" => DType.Float16
+ case "Float32" => DType.Float32
+ case "Float64" => DType.Float64
+ case _ => throw new IllegalArgumentException(
+ s"DType: $dtypeStr not found! please set it in DType.scala")
+ }
+ }
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
index 47fd4ee..a1095cf 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
@@ -25,7 +25,6 @@ import org.slf4j.LoggerFactory
import scala.annotation.varargs
import scala.collection.immutable.ListMap
import scala.collection.mutable.ListBuffer
-
/**
* IO iterators for loading training & validation data
*/
@@ -110,18 +109,22 @@ object IO {
}
// Convert data into canonical form.
- private[mxnet] def initData(data: IndexedSeq[NDArray],
- allowEmpty: Boolean,
- defaultName: String): IndexedSeq[(String,
NDArray)] = {
+ private[mxnet] def initDataDesc(data: IndexedSeq[NDArray],
+ allowEmpty: Boolean,
+ defaultName: String,
+ defaultDType: DType,
+ defaultLayout: String):
IndexedSeq[(DataDesc, NDArray)] = {
require(data != null)
require(data != IndexedSeq.empty || allowEmpty)
if (data == IndexedSeq.empty) {
IndexedSeq()
} else if (data.length == 1) {
- IndexedSeq((defaultName, data(0)))
+ IndexedSeq((new DataDesc(defaultName, data(0).shape,
+ defaultDType, defaultLayout), data(0)))
} else {
data.zipWithIndex.map(item => {
- (defaultName + "_" + item._2, item._1)
+ (new DataDesc(defaultName + "_" + item._2, item._1.shape,
+ defaultDType, defaultLayout), item._1)
}).toIndexedSeq
}
}
@@ -136,11 +139,28 @@ class DataBatch(val data: IndexedSeq[NDArray],
val pad: Int,
// the key for the bucket that should be used for this batch,
// for bucketing io only
- val bucketKey: AnyRef = null,
- // use ListMap to indicate the order of data/label loading
+ val bucketKey: AnyRef,
+ // use DataDesc to indicate the order of data/label loading
// (must match the order of input data/label)
- private val providedData: ListMap[String, Shape] = null,
- private val providedLabel: ListMap[String, Shape] = null) {
+ private val providedDataDesc: IndexedSeq[DataDesc],
+ private val providedLabelDesc: IndexedSeq[DataDesc]) {
+ // TODO: change the data/label type into IndexedSeq[(NDArray, DataDesc)]
+ // However, since the data and label can be accessed publicly (no getter and
setter)
+ // the change on this will break BC
+ def this(data: IndexedSeq[NDArray],
+ label: IndexedSeq[NDArray],
+ index: IndexedSeq[Long],
+ pad: Int,
+ // the key for the bucket that should be used for this batch,
+ // for bucketing io only
+ bucketKey: AnyRef = null,
+ // use ListMap to indicate the order of data/label loading
+ // (must match the order of input data/label)
+ providedData: ListMap[String, Shape] = null,
+ providedLabel: ListMap[String, Shape] = null) {
+ this(data, label, index, pad, bucketKey,
+ DataDesc.ListMap2Descs(providedData),
DataDesc.ListMap2Descs(providedLabel))
+ }
/**
* Dispose its data and labels
* The object shall never be used after it is disposed.
@@ -155,10 +175,29 @@ class DataBatch(val data: IndexedSeq[NDArray],
}
// The name and shape of data
- def provideData: ListMap[String, Shape] = providedData
+ def provideData: ListMap[String, Shape] = {
+ var temp = ListMap[String, Shape]()
+ if (providedDataDesc == null) null
+ else {
+ providedDataDesc.foreach(ele => temp = temp + (ele.name -> ele.shape))
+ temp
+ }
+ }
// The name and shape of label
- def provideLabel: ListMap[String, Shape] = providedLabel
+ def provideLabel: ListMap[String, Shape] = {
+ var temp = ListMap[String, Shape]()
+ if (providedLabelDesc == null) null
+ else {
+ providedLabelDesc.foreach(ele => temp = temp + (ele.name -> ele.shape))
+ temp
+ }
+ }
+
+ def provideDataDesc: IndexedSeq[DataDesc] = providedDataDesc
+
+ def provideLabelDesc: IndexedSeq[DataDesc] = providedLabelDesc
+
}
object DataBatch {
@@ -171,8 +210,8 @@ object DataBatch {
private var index: IndexedSeq[Long] = null
private var pad: Int = 0
private var bucketKey: AnyRef = null
- private var datatShapes: ListMap[String, Shape] = null
- private var labelShapes: ListMap[String, Shape] = null
+ private var dataDesc: IndexedSeq[DataDesc] = null
+ private var labelDesc: IndexedSeq[DataDesc] = null
/**
* Set the input data.
@@ -228,37 +267,27 @@ object DataBatch {
/**
* Provide the shape of a data.
- * @param name data name.
- * @param shape data shape.
+ * @param dataDesc DataDescriptor
* @return this.
*/
- def provideDataShape(name: String, shape: Shape): Builder = {
- if (datatShapes == null) {
- datatShapes = ListMap((name, shape))
- } else {
- datatShapes = datatShapes.updated(name, shape)
- }
+ def provideDataDesc(dataDesc: IndexedSeq[DataDesc]): Builder = {
+ this.dataDesc = dataDesc
this
}
/**
* Provide the shape of a label.
- * @param name label name.
- * @param shape label shape.
+ * @param labelDesc LabelDescriptor
* @return this.
*/
- def provideLabelShape(name: String, shape: Shape): Builder = {
- if (labelShapes == null) {
- labelShapes = ListMap((name, shape))
- } else {
- labelShapes = labelShapes.updated(name, shape)
- }
+ def provideLabelDesc(labelDesc: IndexedSeq[DataDesc]): Builder = {
+ this.labelDesc = labelDesc
this
}
def build(): DataBatch = {
require(data != null, "data is required.")
- new DataBatch(data, label, index, pad, bucketKey, datatShapes,
labelShapes)
+ new DataBatch(data, label, index, pad, bucketKey, dataDesc, labelDesc)
}
}
}
@@ -280,7 +309,8 @@ abstract class DataIter extends Iterator[DataBatch] {
*/
@throws(classOf[NoSuchElementException])
def next(): DataBatch = {
- new DataBatch(getData(), getLabel(), getIndex(), getPad())
+ new DataBatch(getData(), getLabel(), getIndex(), getPad(),
+ null, null, null)
}
/**
@@ -309,11 +339,19 @@ abstract class DataIter extends Iterator[DataBatch] {
def getIndex(): IndexedSeq[Long]
// The name and shape of data provided by this iterator
+ @deprecated
def provideData: ListMap[String, Shape]
// The name and shape of label provided by this iterator
+ @deprecated
def provideLabel: ListMap[String, Shape]
+ // Provide type:DataDesc of the data
+ def provideDataDesc: IndexedSeq[DataDesc]
+
+ // Provide type:DataDesc of the label
+ def provideLabelDesc: IndexedSeq[DataDesc]
+
// For bucketing io only
// The bucket key for the default symbol.
def defaultBucketKey: AnyRef = null
@@ -332,8 +370,9 @@ abstract class DataPack() extends Iterable[DataBatch] {
// Named data desc description contains name, shape, type and other extended
attributes.
case class DataDesc(name: String, shape: Shape,
- dtype: DType = Base.MX_REAL_TYPE, layout: String = "NCHW")
{
- require(shape.length == layout.length, ("number of dimensions in shape :%d
with" +
+ dtype: DType = DType.Float32, layout: String =
Layout.UNDEFINED) {
+ require(layout == Layout.UNDEFINED || shape.length == layout.length,
+ ("number of dimensions in shape :%d with" +
" shape: %s should match the length of the layout: %d with layout: %s").
format(shape.length, shape.toString, layout.length, layout))
@@ -343,6 +382,8 @@ case class DataDesc(name: String, shape: Shape,
}
object DataDesc {
+
+ private val logger = LoggerFactory.getLogger(classOf[DataDesc])
/**
* Get the dimension that corresponds to the batch size.
* @param layout layout string. For example, "NCHW".
@@ -352,9 +393,19 @@ object DataDesc {
* for each data-parallelism device.
*/
def getBatchAxis(layout: Option[String]): Int = {
- layout.map(_.indexOf('N')).getOrElse(0)
+ if (layout.isEmpty|| layout.get == Layout.UNDEFINED) {
+ logger.warn("Found Undefined Layout, will use default index 0 for batch
axis")
+ 0
+ } else {
+ if (layout.get.contains('N')) {
+ layout.get.indexOf("N")
+ } else {
+ throw new IllegalArgumentException("no Batch Axis('N') found in
Layout!")
+ }
+ }
}
+ @deprecated
implicit def ListMap2Descs(shapes: ListMap[String, Shape]):
IndexedSeq[DataDesc] = {
if (shapes != null) {
shapes.map { case (k, s) => new DataDesc(k, s) }.toIndexedSeq
diff --git
a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala
similarity index 64%
copy from
scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala
copy to scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala
index 339f7e2..cb75dbc 100644
---
a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala
@@ -15,19 +15,21 @@
* limitations under the License.
*/
-package org.apache.mxnet.spark.io
-
-import org.apache.mxnet.{NDArray, DataBatch}
+package org.apache.mxnet
/**
- * Dispose only when 'disposeForce' called
- * @author Yizhi Liu
- */
-class LongLivingDataBatch(
- override val data: IndexedSeq[NDArray],
- override val label: IndexedSeq[NDArray],
- override val index: IndexedSeq[Long],
- override val pad: Int) extends DataBatch(data, label, index, pad) {
- override def dispose(): Unit = {}
- def disposeForce(): Unit = super.dispose()
+ * Layout definition of DataDesc
+ * N Batch size
+ * C channels
+ * H Height
+ * W Weight
+ * T sequence length
+ * __undefined__ default value of Layout
+ */
+object Layout {
+ val UNDEFINED = "__undefined__"
+ val NCHW = "NCHW"
+ val NTC = "NTC"
+ val NT = "NT"
+ val N = "N"
}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/RecordIO.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/RecordIO.scala
index ee3e950..578f00a 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/RecordIO.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/RecordIO.scala
@@ -28,9 +28,6 @@ import java.io.ByteArrayInputStream
/**
* Scala interface for read/write RecordIO data format
- *
- * @author Depeng Liang
- *
* @param uri, path to recordIO file.
* @param flag, RecordIO.IORead for reading or RecordIO.Write for writing.
*/
@@ -144,7 +141,7 @@ object MXRecordIO {
*
* @author Depeng Liang
*
- * @param idx_path, path to index file
+ * @param idxPath, path to index file
* @param uri, path to recordIO file.
* @param flag, RecordIO.IORead for reading or RecordIO.Write for writing.
* @param keyType, data type for keys.
diff --git
a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
index 2a0c333..f7f858d 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
@@ -18,7 +18,8 @@
package org.apache.mxnet.io
import org.apache.mxnet.Base._
-import org.apache.mxnet.{DataBatch, DataIter, DataPack, NDArray, Shape,
WarnIfNotDisposed}
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet._
import org.apache.mxnet.IO._
import org.slf4j.LoggerFactory
@@ -41,21 +42,31 @@ private[mxnet] class MXDataIter(private[mxnet] val handle:
DataIterHandle,
// fix me if any better way found)
private var currentBatch: DataBatch = null
- private val (_provideData: ListMap[String, Shape],
+ private val (_provideDataDesc: IndexedSeq[DataDesc],
+ _provideLabelDesc: IndexedSeq[DataDesc],
+ _provideData: ListMap[String, Shape],
_provideLabel: ListMap[String, Shape],
- _batchSize: Int) =
+ _batchSize: Int) = {
if (hasNext) {
iterNext()
val data = currentBatch.data(0)
val label = currentBatch.label(0)
// properties
- val res = (ListMap(dataName -> data.shape), ListMap(labelName ->
label.shape), data.shape(0))
+ val res = (
+ // TODO: need to allow user to specify DType and Layout
+ IndexedSeq(new DataDesc(dataName, data.shape, DType.Float32,
Layout.UNDEFINED)),
+ IndexedSeq(new DataDesc(labelName, label.shape, DType.Float32,
Layout.UNDEFINED)),
+ ListMap(dataName -> data.shape),
+ ListMap(labelName -> label.shape),
+ data.shape(0))
currentBatch.dispose()
reset()
res
} else {
- (null, null, 0)
+ (null, null, null, null, 0)
}
+ }
+
private var disposed = false
protected def isDisposed = disposed
@@ -101,10 +112,12 @@ private[mxnet] class MXDataIter(private[mxnet] val
handle: DataIterHandle,
private def iterNext(): Boolean = {
val next = new RefInt
checkCall(_LIB.mxDataIterNext(handle, next))
- currentBatch = null
if (next.value > 0) {
currentBatch = new DataBatch(data = getData(), label = getLabel(),
- index = getIndex(), pad = getPad())
+ index = getIndex(), pad = getPad(),
+ null, null, null)
+ } else {
+ currentBatch = null
}
next.value > 0
}
@@ -152,11 +165,19 @@ private[mxnet] class MXDataIter(private[mxnet] val
handle: DataIterHandle,
}
// The name and shape of data provided by this iterator
+ @deprecated
override def provideData: ListMap[String, Shape] = _provideData
// The name and shape of label provided by this iterator
+ @deprecated
override def provideLabel: ListMap[String, Shape] = _provideLabel
+ // Provide type:DataDesc of the data
+ override def provideDataDesc: IndexedSeq[DataDesc] = _provideDataDesc
+
+ // Provide type:DataDesc of the label
+ override def provideLabelDesc: IndexedSeq[DataDesc] = _provideLabelDesc
+
override def hasNext: Boolean = {
if (currentBatch != null) {
true
diff --git
a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
index 1046131..e6be0ad 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
@@ -20,6 +20,7 @@ package org.apache.mxnet.io
import java.util.NoSuchElementException
import org.apache.mxnet.Base._
+import org.apache.mxnet.DType.DType
import org.apache.mxnet._
import org.slf4j.LoggerFactory
@@ -39,35 +40,35 @@ import scala.collection.immutable.ListMap
* the size of data does not match batch_size. Roll over is intended
* for training and can cause problems if used for prediction.
*/
-class NDArrayIter(data: IndexedSeq[(String, NDArray)],
- label: IndexedSeq[(String, NDArray)],
+class NDArrayIter(data: IndexedSeq[(DataDesc, NDArray)],
+ label: IndexedSeq[(DataDesc, NDArray)],
private val dataBatchSize: Int, shuffle: Boolean,
lastBatchHandle: String) extends DataIter {
/**
- * @param data Specify the data. Data names will be data_0, data_1, ..., etc.
- * @param label Same as data, but is not fed to the model during testing.
- * Label names will be label_0, label_1, ..., etc.
- * @param dataBatchSize Batch Size
- * @param shuffle Whether to shuffle the data
- * @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the
last batch
- *
- * This iterator will pad, discard or roll over the last batch if
- * the size of data does not match batch_size. Roll over is intended
- * for training and can cause problems if used for prediction.
- */
+ * @param data Specify the data. Data names will be data_0, data_1, ..., etc.
+ * @param label Same as data, but is not fed to the model during testing.
+ * Label names will be label_0, label_1, ..., etc.
+ * @param dataBatchSize Batch Size
+ * @param shuffle Whether to shuffle the data
+ * @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the
last batch
+ *
+ * This iterator will pad, discard or roll over the last batch if
+ * the size of data does not match batch_size. Roll over is intended
+ * for training and can cause problems if used for prediction.
+ */
def this(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] =
IndexedSeq.empty,
dataBatchSize: Int = 1, shuffle: Boolean = false,
lastBatchHandle: String = "pad",
dataName: String = "data", labelName: String = "label") {
- this(IO.initData(data, allowEmpty = false, dataName),
- IO.initData(label, allowEmpty = true, labelName),
+ this(IO.initDataDesc(data, allowEmpty = false, dataName, MX_REAL_TYPE,
Layout.UNDEFINED),
+ IO.initDataDesc(label, allowEmpty = true, labelName, MX_REAL_TYPE,
Layout.UNDEFINED),
dataBatchSize, shuffle, lastBatchHandle)
}
private val logger = LoggerFactory.getLogger(classOf[NDArrayIter])
- val (initData: IndexedSeq[(String, NDArray)], initLabel: IndexedSeq[(String,
NDArray)]) = {
+ val (initData: IndexedSeq[(DataDesc, NDArray)], initLabel:
IndexedSeq[(DataDesc, NDArray)]) = {
// data should not be null and size > 0
require(data != null && data.size > 0,
"data should not be null and data.size should not be zero")
@@ -101,20 +102,30 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)],
private var cursor = -dataBatchSize
private val (_provideData: ListMap[String, Shape],
- _provideLabel: ListMap[String, Shape]) = {
+ _provideLabel: ListMap[String, Shape],
+ _provideDataDesc: IndexedSeq[DataDesc],
+ _provideLabelDesc: IndexedSeq[DataDesc]) = {
val pData = ListMap.empty[String, Shape] ++ initData.map(getShape)
val pLabel = ListMap.empty[String, Shape] ++ initLabel.map(getShape)
- (pData, pLabel)
+ val pDData = IndexedSeq.empty[DataDesc] ++ initData.map(ele => {
+ val temp = getShape(ele)
+ new DataDesc(temp._1, temp._2, ele._1.dtype, ele._1.layout)
+ })
+ val pDLabel = IndexedSeq.empty[DataDesc] ++ initLabel.map(ele => {
+ val temp = getShape(ele)
+ new DataDesc(temp._1, temp._2, ele._1.dtype, ele._1.layout)
+ })
+ (pData, pLabel, pDData, pDLabel)
}
/**
* get shape via dataBatchSize
* @param dataItem
*/
- private def getShape(dataItem: (String, NDArray)): (String, Shape) = {
+ private def getShape(dataItem: (DataDesc, NDArray)): (String, Shape) = {
val len = dataItem._2.shape.size
val newShape = dataItem._2.shape.slice(1, len)
- (dataItem._1, Shape(Array[Int](dataBatchSize)) ++ newShape)
+ (dataItem._1.name, Shape(Array[Int](dataBatchSize)) ++ newShape)
}
@@ -148,7 +159,8 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)],
override def next(): DataBatch = {
if (hasNext) {
cursor += dataBatchSize
- new DataBatch(getData(), getLabel(), getIndex(), getPad())
+ new DataBatch(getData(), getLabel(), getIndex(), getPad(),
+ null, null, null)
} else {
throw new NoSuchElementException
}
@@ -172,7 +184,7 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)],
}
}
- private def _getData(data: IndexedSeq[(String, NDArray)]):
IndexedSeq[NDArray] = {
+ private def _getData(data: IndexedSeq[(DataDesc, NDArray)]):
IndexedSeq[NDArray] = {
require(cursor < numData, "DataIter needs reset.")
if (data == null) {
null
@@ -223,12 +235,21 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)],
}
}
+
// The name and shape of data provided by this iterator
+ @deprecated
override def provideData: ListMap[String, Shape] = _provideData
// The name and shape of label provided by this iterator
+ @deprecated
override def provideLabel: ListMap[String, Shape] = _provideLabel
+ // Provide type:DataDesc of the data
+ override def provideDataDesc: IndexedSeq[DataDesc] = _provideDataDesc
+
+ // Provide type:DataDesc of the label
+ override def provideLabelDesc: IndexedSeq[DataDesc] = _provideLabelDesc
+
override def batchSize: Int = dataBatchSize
}
@@ -238,8 +259,8 @@ object NDArrayIter {
* Builder class for NDArrayIter.
*/
class Builder() {
- private var data: IndexedSeq[(String, NDArray)] = IndexedSeq.empty
- private var label: IndexedSeq[(String, NDArray)] = IndexedSeq.empty
+ private var data: IndexedSeq[(DataDesc, NDArray)] = IndexedSeq.empty
+ private var label: IndexedSeq[(DataDesc, NDArray)] = IndexedSeq.empty
private var dataBatchSize: Int = 1
private var lastBatchHandle: String = "pad"
@@ -250,7 +271,8 @@ object NDArrayIter {
* @return The builder object itself.
*/
def addData(name: String, data: NDArray): Builder = {
- this.data = this.data ++ IndexedSeq((name, data))
+ this.data = this.data ++ IndexedSeq((new DataDesc(name,
+ data.shape, DType.Float32, Layout.UNDEFINED), data))
this
}
@@ -261,7 +283,24 @@ object NDArrayIter {
* @return The builder object itself.
*/
def addLabel(name: String, label: NDArray): Builder = {
- this.label = this.label ++ IndexedSeq((name, label))
+ this.label = this.label ++ IndexedSeq((new DataDesc(name,
+ label.shape, DType.Float32, Layout.UNDEFINED), label))
+ this
+ }
+
+ /**
+ * Add one data input with its DataDesc
+ */
+ def addDataWithDesc(dataDesc: DataDesc, data: NDArray): Builder = {
+ this.data = this.data ++ IndexedSeq((dataDesc, data))
+ this
+ }
+
+ /**
+ * Add one label input with its DataDesc
+ */
+ def addLabelWithDesc(labelDesc: DataDesc, label: NDArray): Builder = {
+ this.data = this.data ++ IndexedSeq((labelDesc, label))
this
}
diff --git
a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala
index c0c0d17..e59e370 100644
---
a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala
+++
b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala
@@ -17,10 +17,12 @@
package org.apache.mxnet.io
-import org.apache.mxnet.{DataBatch, DataIter, NDArray, Shape}
+import org.apache.mxnet._
import org.slf4j.LoggerFactory
import java.util.concurrent.Semaphore
+import org.apache.mxnet.DType.DType
+
import scala.collection.immutable.ListMap
/**
@@ -68,6 +70,42 @@ class PrefetchingIter(
}
}
+ private val _provideDataDesc: IndexedSeq[DataDesc] = {
+ if (dataNames == null) {
+ iters.map(_.provideDataDesc).foldLeft(IndexedSeq[DataDesc]()) { (acc,
elem) =>
+ acc ++ elem
+ }
+ } else {
+ iters.zipWithIndex.map(tu => (tu._1.provideDataDesc, tu._2))
+ .map(m =>
+ m._1.map(t =>
+ new DataDesc(dataNames(m._2)(t.name), t.shape, t.dtype, t.layout)
+ )
+ )
+ .foldLeft(IndexedSeq[DataDesc]()) { (acc, elem) =>
+ acc ++ elem
+ }
+ }
+ }
+
+ private val _provideLabelDesc: IndexedSeq[DataDesc] = {
+ if (labelNames == null) {
+ iters.map(_.provideLabelDesc).foldLeft(IndexedSeq[DataDesc]()) { (acc,
elem) =>
+ acc ++ elem
+ }
+ } else {
+ iters.zipWithIndex.map(tu => (tu._1.provideLabelDesc, tu._2))
+ .map(m =>
+ m._1.map(t =>
+ new DataDesc(labelNames(m._2)(t.name), t.shape, t.dtype, t.layout)
+ )
+ )
+ .foldLeft(IndexedSeq[DataDesc]()) { (acc, elem) =>
+ acc ++ elem
+ }
+ }
+ }
+
private val _batchSize: Int = this._provideData.toList(0)._2(0)
private val dataReady: IndexedSeq[Semaphore] =
(0 until iters.length).map(i => new
Semaphore(0))
@@ -132,19 +170,27 @@ class PrefetchingIter(
*/
override def getIndex(): IndexedSeq[Long] = currentBatch.index
- // The name and shape of label provided by this iterator
- override def provideLabel: ListMap[String, Shape] = this._provideLabel
-
/**
- * get the number of padding examples
- * in current batch
- * @return number of padding examples in current batch
- */
+ * get the number of padding examples
+ * in current batch
+ * @return number of padding examples in current batch
+ */
override def getPad(): Int = this.currentBatch.pad
+ // The name and shape of label provided by this iterator
+ @deprecated
+ override def provideLabel: ListMap[String, Shape] = this._provideLabel
+
// The name and shape of data provided by this iterator
+ @deprecated
override def provideData: ListMap[String, Shape] = this._provideData
+ // Provide type:DataDesc of the data
+ override def provideDataDesc: IndexedSeq[DataDesc] = _provideDataDesc
+
+ // Provide type:DataDesc of the label
+ override def provideLabelDesc: IndexedSeq[DataDesc] = _provideLabelDesc
+
override def hasNext: Boolean = {
for (e <- dataReady) e.acquire()
if (nextBatch(0) == null) {
@@ -161,9 +207,10 @@ class PrefetchingIter(
val datas = for (batch <- nextBatch) yield batch.data
val labels = for (batch <- nextBatch) yield batch.label
currentBatch = new DataBatch(datas.toIndexedSeq.flatten,
- labels.toIndexedSeq.flatten,
- nextBatch(0).index,
- nextBatch(0).pad)
+ labels.toIndexedSeq.flatten,
+ nextBatch(0).index,
+ nextBatch(0).pad,
+ null, null, null)
for (e <- dataTaken) e.release()
true
}
diff --git
a/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala
b/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala
index 75d88d1..e840af9 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala
@@ -19,7 +19,8 @@ package org.apache.mxnet.io
import java.util.NoSuchElementException
-import org.apache.mxnet.{DataBatch, DataIter, NDArray, Shape}
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet._
import org.slf4j.LoggerFactory
import scala.collection.immutable.ListMap
@@ -133,12 +134,24 @@ class ResizeIter(
}
// The name and shape of data provided by this iterator
+ @deprecated
override def provideData: ListMap[String, Shape] = {
dataIter.provideData
}
// The name and shape of label provided by this iterator
+ @deprecated
override def provideLabel: ListMap[String, Shape] = {
dataIter.provideLabel
}
+
+ // The name and shape of data provided by this iterator
+ override def provideDataDesc: IndexedSeq[DataDesc] = {
+ dataIter.provideDataDesc
+ }
+
+ // The name and shape of label provided by this iterator
+ override def provideLabelDesc: IndexedSeq[DataDesc] = {
+ dataIter.provideLabelDesc
+ }
}
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
index 1b922b3..2ec6f66 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
@@ -243,7 +243,8 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
val batchLabel = NDArray.ones(Shape(Array(128, 1)))
// test pad
- val dataIter0 = new NDArrayIter(data, label, 128, false, "pad")
+ val dataIter0 = new NDArrayIter(data, label, 128, false, "pad",
+ dataName = "data", labelName = "label")
var batchCount = 0
val nBatch0 = 8
while(dataIter0.hasNext) {
@@ -277,7 +278,8 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
assert(batchCount === nBatch1)
// test empty label (for prediction)
- val dataIter2 = new NDArrayIter(data = data, dataBatchSize = 128,
lastBatchHandle = "discard")
+ val dataIter2 = new NDArrayIter(data = data, dataBatchSize = 128, shuffle
= false,
+ lastBatchHandle = "discard")
batchCount = 0
while(dataIter2.hasNext) {
val tBatch = dataIter2.next()
@@ -289,5 +291,17 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
assert(batchCount === nBatch1)
assert(dataIter2.initLabel == IndexedSeq.empty)
+
+ // test implementation with DataDesc
+ val dataIter3 = new NDArrayIter(
+ IO.initDataDesc(data, false, "data", DType.Float32, Layout.NTC),
+ IO.initDataDesc(label, false, "label", DType.Int32, Layout.NT),
+ 128, false, "pad")
+ val dataDesc = dataIter3.provideDataDesc
+ val labelDesc = dataIter3.provideLabelDesc
+ assert(dataDesc(0).dtype == DType.Float32)
+ assert(dataDesc(0).layout == Layout.NTC)
+ assert(labelDesc(0).dtype == DType.Int32)
+ assert(labelDesc(0).layout == Layout.NT)
}
}
diff --git
a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
index 8234568..88e314e 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
@@ -24,7 +24,7 @@ import org.apache.mxnet.io._
class ModuleSuite extends FunSuite with BeforeAndAfterAll {
test ("model dtype") {
- val dType = DType.Float16
+ val dType = DType.Float32
val dShape = Shape(3, 8, 7)
var sym = Symbol.Variable("data")
@@ -196,8 +196,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
// create module
val mod = new Module(x, contexts = Array(Context.cpu()))
- mod.bind(dataShapes = trainData.provideData,
- Option(trainData.provideLabel))
+ mod.bind(dataShapes = trainData.provideDataDesc,
+ Option(trainData.provideLabelDesc))
val argParamsCorrect = Map(
"fc_0_weight" -> NDArray.array(Array(0.15f, 0.2f, 0.25f, 0.3f), Shape(2,
2)),
"fc_0_bias" -> NDArray.array(Array(0.35f, 0.35f), Shape(2)),
diff --git
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/Data.scala
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/Data.scala
index bb17046..068aa63 100644
---
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/Data.scala
+++
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/Data.scala
@@ -21,9 +21,6 @@ import org.apache.mxnet.Shape
import org.apache.mxnet.IO
import org.apache.mxnet.DataIter
-/**
- * @author Depeng Liang
- */
object Data {
// return train and val iterators for mnist
diff --git
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala
index 9df2bcc..825e465 100644
---
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala
+++
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala
@@ -25,14 +25,9 @@ import org.slf4j.LoggerFactory
import scala.collection.JavaConverters._
import org.apache.commons.io.FileUtils
-import org.apache.mxnet.Symbol
-import org.apache.mxnet.DataIter
-import org.apache.mxnet.DataBatch
-import org.apache.mxnet.NDArray
-import org.apache.mxnet.Shape
-import org.apache.mxnet.EvalMetric
-import org.apache.mxnet.Context
-import org.apache.mxnet.Xavier
+
+import org.apache.mxnet.{Context, DataBatch, DataDesc, DataIter, EvalMetric,
NDArray, Shape, Symbol, Xavier}
+import org.apache.mxnet.DType.DType
import org.apache.mxnet.optimizer.RMSProp
import org.apache.mxnet.Executor
import org.apache.mxnetexamples.Util
@@ -70,9 +65,9 @@ object ExampleMultiTask {
val batch = this.dataIter.next()
val label = batch.label(0)
new DataBatch(batch.data,
- IndexedSeq(label, label),
- batch.index,
- batch.pad)
+ IndexedSeq(label, label),
+ batch.index,
+ batch.pad, null, null, null)
} else {
throw new NoSuchElementException
}
@@ -107,6 +102,7 @@ object ExampleMultiTask {
override def getIndex(): IndexedSeq[Long] = this.dataIter.getIndex()
// The name and shape of label provided by this iterator
+ @deprecated
override def provideLabel: ListMap[String, Shape] = {
val provideLabel = this.dataIter.provideLabel.toArray
// Different labels should be used here for actual application
@@ -114,6 +110,16 @@ object ExampleMultiTask {
"softmax2_label" -> provideLabel(0)._2)
}
+ // The name and shape of the label provided by this iterator
+ override def provideLabelDesc: IndexedSeq[DataDesc] = {
+ val head = this.dataIter.provideLabelDesc(0)
+ // Different labels should be used here for actual application
+ IndexedSeq(
+ new DataDesc("softmax1_label", head.shape, head.dtype, head.layout),
+ new DataDesc("softmax2_label", head.shape, head.dtype, head.layout)
+ )
+ }
+
/**
* get the number of padding examples
* in current batch
@@ -122,8 +128,11 @@ object ExampleMultiTask {
override def getPad(): Int = this.dataIter.getPad()
// The name and shape of data provided by this iterator
+ @deprecated
override def provideData: ListMap[String, Shape] =
this.dataIter.provideData
+ override def provideDataDesc: IndexedSeq[DataDesc] =
this.dataIter.provideDataDesc
+
override def hasNext: Boolean = this.dataIter.hasNext
}
diff --git
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala
index f0eae68..d4b1707 100644
---
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala
+++
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala
@@ -18,17 +18,16 @@
package org.apache.mxnetexamples.rnn
-import org.apache.mxnet.{DataBatch, DataIter, NDArray, Shape}
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet._
import org.slf4j.LoggerFactory
+
import scala.collection.immutable.ListMap
import scala.collection.mutable.ArrayBuffer
import scala.io.Source
import scala.util.Random
import scala.collection.mutable
-/**
- * @author Depeng Liang
- */
object BucketIo {
type Text2Id = (String, Map[String, Int]) => Array[Int]
@@ -92,11 +91,14 @@ object BucketIo {
}
class BucketSentenceIter(
- path: String, vocab: Map[String, Int], var buckets: IndexedSeq[Int],
- _batchSize: Int, private val initStates: IndexedSeq[(String, (Int,
Int))],
- seperateChar: String = " <eos> ", text2Id: Text2Id = defaultText2Id,
+ path: String,
+ vocab: Map[String, Int],
+ var buckets: IndexedSeq[Int],
+ _batchSize: Int,
+ private val initStates: IndexedSeq[(String, (Int, Int))],
+ seperateChar: String = " <eos> ",
+ text2Id: Text2Id = defaultText2Id,
readContent: ReadContent = defaultReadContent) extends DataIter {
-
private val logger = LoggerFactory.getLogger(classOf[BucketSentenceIter])
private val content = readContent(path)
@@ -165,8 +167,22 @@ object BucketIo {
private val _provideData = { val tmp = ListMap("data" -> Shape(_batchSize,
_defaultBucketKey))
tmp ++ initStates.map(x => x._1 -> Shape(x._2._1, x._2._2))
}
+
private val _provideLabel = ListMap("softmax_label" -> Shape(_batchSize,
_defaultBucketKey))
+ private val _provideDataDesc = {
+ // TODO: need to allow user to specify DType and Layout
+ val tmp = IndexedSeq(new DataDesc("data",
+ Shape(_batchSize, _defaultBucketKey), DType.Float32, Layout.UNDEFINED))
+ tmp ++ initStates.map(x => new DataDesc(x._1, Shape(x._2._1, x._2._2),
+ DType.Float32, Layout.UNDEFINED))
+ }
+
+ private val _provideLabelDesc = IndexedSeq(
+ // TODO: need to allow user to specify DType and Layout
+ new DataDesc("softmax_label",
+ Shape(_batchSize, _defaultBucketKey), DType.Float32, Layout.UNDEFINED))
+
private var iBucket = 0
override def next(): DataBatch = {
@@ -228,19 +244,27 @@ object BucketIo {
*/
override def getIndex(): IndexedSeq[Long] = IndexedSeq[Long]()
- // The name and shape of label provided by this iterator
- override def provideLabel: ListMap[String, Shape] = this._provideLabel
-
/**
- * get the number of padding examples
- * in current batch
- * @return number of padding examples in current batch
- */
+ * get the number of padding examples
+ * in current batch
+ * @return number of padding examples in current batch
+ */
override def getPad(): Int = 0
+ // The name and shape of label provided by this iterator
+ @deprecated
+ override def provideLabel: ListMap[String, Shape] = this._provideLabel
+
// The name and shape of data provided by this iterator
+ @deprecated
override def provideData: ListMap[String, Shape] = this._provideData
+ // Provides the DataDesc describing the data
+ override def provideDataDesc: IndexedSeq[DataDesc] = _provideDataDesc
+
+ // Provides the DataDesc describing the label
+ override def provideLabelDesc: IndexedSeq[DataDesc] = _provideLabelDesc
+
override def hasNext: Boolean = {
iBucket < bucketPlan.length
}
diff --git
a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ObjectDetectorSuite.scala
b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ObjectDetectorSuite.scala
index 8160f0f..39139f8 100644
---
a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ObjectDetectorSuite.scala
+++
b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ObjectDetectorSuite.scala
@@ -19,6 +19,8 @@ package org.apache.mxnet.infer
// scalastyle:off
import java.awt.image.BufferedImage
+
+import org.apache.mxnet.{DType, Layout}
// scalastyle:on
import org.apache.mxnet.Context
import org.apache.mxnet.DataDesc
@@ -69,7 +71,8 @@ class ObjectDetectorSuite extends ClassifierSuite with
BeforeAndAfterAll {
}
test("objectDetectWithInputImage") {
- val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath,
Shape(1, 3, 512, 512)))
+ val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath,
Shape(1, 3, 512, 512),
+ DType.Float32, Layout.NCHW))
val inputImage = new BufferedImage(512, 512, BufferedImage.TYPE_INT_RGB)
val testObjectDetector: ObjectDetector =
new MyObjectDetector(modelPath, inputDescriptor)
@@ -109,7 +112,8 @@ class ObjectDetectorSuite extends ClassifierSuite with
BeforeAndAfterAll {
}
test("objectDetectWithBatchImages") {
- val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath,
Shape(1, 3, 512, 512)))
+ val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath,
Shape(1, 3, 512, 512),
+ DType.Float32, Layout.NCHW))
val inputImage = new BufferedImage(224, 224, BufferedImage.TYPE_INT_RGB)
val imageBatch = IndexedSeq[BufferedImage](inputImage, inputImage)
diff --git
a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala
b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala
index 53fd7f3..509ffb3 100644
---
a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala
+++
b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala
@@ -19,7 +19,7 @@ package org.apache.mxnet.infer
import org.apache.mxnet.io.NDArrayIter
import org.apache.mxnet.module.{BaseModule, Module}
-import org.apache.mxnet.{DataDesc, NDArray, Shape}
+import org.apache.mxnet.{DataDesc, Layout, NDArray, Shape}
import org.mockito.Matchers._
import org.mockito.Mockito
import org.scalatest.{BeforeAndAfterAll, FunSuite}
@@ -40,15 +40,17 @@ class PredictorSuite extends FunSuite with
BeforeAndAfterAll {
}
test("PredictorSuite-testPredictorConstruction") {
- val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(1,
3, 2, 2)))
+ val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(1,
3, 2, 2),
+ layout = Layout.NCHW))
val mockPredictor = new MyPredictor("xyz", inputDescriptor)
assert(mockPredictor.getBatchSize == 1)
assert(mockPredictor.getBatchIndex ==
inputDescriptor(0).layout.indexOf('N'))
- val inputDescriptor2 = IndexedSeq[DataDesc](new DataDesc("data", Shape(1,
3, 2, 2)),
- new DataDesc("data", Shape(2, 3, 2, 2)))
+ val inputDescriptor2 = IndexedSeq[DataDesc](new DataDesc("data", Shape(1,
3, 2, 2),
+ layout = Layout.NCHW),
+ new DataDesc("data", Shape(2, 3, 2, 2), layout = Layout.NCHW))
assertThrows[IllegalArgumentException] {
val mockPredictor = new MyPredictor("xyz", inputDescriptor2)
@@ -63,7 +65,8 @@ class PredictorSuite extends FunSuite with BeforeAndAfterAll {
test("PredictorSuite-testWithFlatArrays") {
- val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2,
3, 2, 2)))
+ val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2,
3, 2, 2),
+ layout = Layout.NCHW))
val inputData = Array.fill[Float](12)(1)
// this will disposed at the end of the predict call on Predictor.
@@ -89,7 +92,8 @@ class PredictorSuite extends FunSuite with BeforeAndAfterAll {
}
test("PredictorSuite-testWithNDArray") {
- val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2,
3, 2, 2)))
+ val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2,
3, 2, 2),
+ layout = Layout.NCHW))
val inputData = NDArray.ones(Shape(1, 3, 2, 2))
// this will disposed at the end of the predict call on Predictor.
diff --git a/scala-package/pom.xml b/scala-package/pom.xml
index 3511f4a..c221b47 100644
--- a/scala-package/pom.xml
+++ b/scala-package/pom.xml
@@ -231,6 +231,7 @@
<skipTests>${skipTests}</skipTests>
<reportsDirectory>${project.build.directory}/surefire-reports</reportsDirectory>
<junitxml>.</junitxml>
+ <stdout>F</stdout>
<filereports>WDF TestSuite.txt</filereports>
</configuration>
<executions>
diff --git
a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala
b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala
index adc723e..bf1b26e 100644
---
a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala
+++
b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala
@@ -17,7 +17,8 @@
package org.apache.mxnet.spark.io
-import org.apache.mxnet.{DataBatch, NDArray, Shape, DataIter}
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet._
import org.apache.spark.mllib.regression.LabeledPoint
import scala.collection.immutable.ListMap
@@ -25,7 +26,6 @@ import scala.collection.mutable.ArrayBuffer
/**
* A helper converter for LabeledPoint
- * @author Yizhi Liu
*/
class LabeledPointIter private[mxnet](
private val points: Iterator[LabeledPoint],
@@ -115,15 +115,27 @@ class LabeledPointIter private[mxnet](
}
// The name and shape of label provided by this iterator
+ @deprecated
override def provideLabel: ListMap[String, Shape] = {
ListMap(labelName -> Shape(_batchSize))
}
// The name and shape of data provided by this iterator
+ @deprecated
override def provideData: ListMap[String, Shape] = {
ListMap(dataName -> dataShape)
}
+ override def provideDataDesc: IndexedSeq[DataDesc] = {
+ // TODO: need to allow user to specify DType and Layout
+ IndexedSeq(new DataDesc(dataName, dataShape, DType.Float32,
Layout.UNDEFINED))
+ }
+
+ override def provideLabelDesc: IndexedSeq[DataDesc] = {
+ // TODO: need to allow user to specify DType and Layout
+ IndexedSeq(new DataDesc(labelName, Shape(_batchSize), DType.Float32,
Layout.UNDEFINED))
+ }
+
/**
* Get the number of padding examples
* in current batch
diff --git
a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala
b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala
index 339f7e2..e3272a4 100644
---
a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala
+++
b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala
@@ -17,7 +17,8 @@
package org.apache.mxnet.spark.io
-import org.apache.mxnet.{NDArray, DataBatch}
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet.{DataBatch, NDArray}
/**
* Dispose only when 'disposeForce' called
@@ -27,7 +28,8 @@ class LongLivingDataBatch(
override val data: IndexedSeq[NDArray],
override val label: IndexedSeq[NDArray],
override val index: IndexedSeq[Long],
- override val pad: Int) extends DataBatch(data, label, index, pad) {
+ override val pad: Int) extends DataBatch(data, label, index, pad,
+ null, null, null) {
override def dispose(): Unit = {}
def disposeForce(): Unit = super.dispose()
}
diff --git
a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala
b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala
index 2132929..a955ee7 100644
---
a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala
+++
b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala
@@ -17,7 +17,8 @@
package org.apache.mxnet.spark.io
-import org.apache.mxnet.{NDArray, DataBatch, DataIter, Shape}
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet._
import org.apache.spark.mllib.linalg.Vector
import scala.collection.immutable.ListMap
@@ -25,7 +26,6 @@ import scala.collection.mutable.ArrayBuffer
/**
* A temporary helper implementation for predicting Vectors
- * @author Yizhi Liu
*/
class PointIter private[mxnet](
private val points: Iterator[Vector],
@@ -114,15 +114,27 @@ class PointIter private[mxnet](
}
// The name and shape of label provided by this iterator
+ @deprecated
override def provideLabel: ListMap[String, Shape] = {
ListMap(labelName -> Shape(_batchSize))
}
// The name and shape of data provided by this iterator
+ @deprecated
override def provideData: ListMap[String, Shape] = {
ListMap(dataName -> dataShape)
}
+ override def provideDataDesc: IndexedSeq[DataDesc] = {
+ // TODO: Make DType, Layout configurable
+ IndexedSeq(new DataDesc(dataName, dataShape, DType.Float32,
Layout.UNDEFINED))
+ }
+
+ override def provideLabelDesc: IndexedSeq[DataDesc] = {
+ IndexedSeq(new DataDesc(labelName, Shape(_batchSize),
+ DType.Float32, Layout.UNDEFINED))
+ }
+
/**
* Get the number of padding examples
* in current batch