This is an automated email from the ASF dual-hosted git repository.

nswamy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git


The following commit(s) were added to refs/heads/master by this push:
     new abbe283  [MXNET-689] add DataDesc type for the Scala Package (#11844)
abbe283 is described below

commit abbe283ef8b1d78b002cb492651f002ae27ba544
Author: Lanking <[email protected]>
AuthorDate: Fri Aug 17 10:07:30 2018 -0700

    [MXNET-689] add DataDesc type for the Scala Package (#11844)
    
    * add dataDesc
    
    * Add amend
    
    * add changes with dataLayout and labelLayout
    
    * add deprecation annotations and example changes
    
    * Gan and Customop fixes
    
    * change the DType
    
    * add one more class to convert Strings to DTypes
    
    * convert layout to global
    
    * scala style fix
    
    * Revert to 8c7d1f8
    
    * fix coding style issue
    
    * print full stacktraces
    
    * apply changes to new constructor
    
    * add databatch bcc
    
    * introduce undefined field
    
    * Fix crashes when change provideData to provideDataDesc
    
    It looks like forcing a conversion from Float32 to Int32 will
    cause a crash on the JVM. This still needs to be addressed.
    
    * change spacing and revert test
    
    * apply DataDesc on DataBatch
    
    * unit test for NDArrayIter and MXDataiter
    
    * apply changes on CR
    
    * change NDArrayIter and revert the rest
    
    * revert change on examples
    
    * apply final changes
    
    * remove the provideLabelShape
    
    * add TODO about the findings
---
 .../src/main/scala/org/apache/mxnet/DType.scala    |  11 ++
 .../core/src/main/scala/org/apache/mxnet/IO.scala  | 121 +++++++++++++++------
 .../src/main/scala/org/apache/mxnet/Layout.scala}  |  28 ++---
 .../src/main/scala/org/apache/mxnet/RecordIO.scala |   5 +-
 .../scala/org/apache/mxnet/io/MXDataIter.scala     |  35 ++++--
 .../scala/org/apache/mxnet/io/NDArrayIter.scala    |  91 +++++++++++-----
 .../org/apache/mxnet/io/PrefetchingIter.scala      |  69 ++++++++++--
 .../scala/org/apache/mxnet/io/ResizeIter.scala     |  15 ++-
 .../src/test/scala/org/apache/mxnet/IOSuite.scala  |  18 ++-
 .../test/scala/org/apache/mxnet/ModuleSuite.scala  |   6 +-
 .../org/apache/mxnetexamples/multitask/Data.scala  |   3 -
 .../mxnetexamples/multitask/ExampleMultiTask.scala |  31 ++++--
 .../org/apache/mxnetexamples/rnn/BucketIo.scala    |  54 ++++++---
 .../apache/mxnet/infer/ObjectDetectorSuite.scala   |   8 +-
 .../org/apache/mxnet/infer/PredictorSuite.scala    |  16 ++-
 scala-package/pom.xml                              |   1 +
 .../apache/mxnet/spark/io/LabeledPointIter.scala   |  16 ++-
 .../mxnet/spark/io/LongLivingDataBatch.scala       |   6 +-
 .../org/apache/mxnet/spark/io/PointIter.scala      |  16 ++-
 19 files changed, 405 insertions(+), 145 deletions(-)

diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala
index 4458a7c..f3a8e8e 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala
@@ -35,4 +35,15 @@ object DType extends Enumeration {
       case DType.Unknown => 0
     }
   }
+  private[mxnet] def getType(dtypeStr: String): DType = {
+    dtypeStr match {
+      case "UInt8" => DType.UInt8
+      case "Int32" => DType.Int32
+      case "Float16" => DType.Float16
+      case "Float32" => DType.Float32
+      case "Float64" => DType.Float64
+      case _ => throw new IllegalArgumentException(
+        s"DType: $dtypeStr not found! please set it in DType.scala")
+    }
+  }
 }
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
index 47fd4ee..a1095cf 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
@@ -25,7 +25,6 @@ import org.slf4j.LoggerFactory
 import scala.annotation.varargs
 import scala.collection.immutable.ListMap
 import scala.collection.mutable.ListBuffer
-
 /**
  * IO iterators for loading training & validation data
  */
@@ -110,18 +109,22 @@ object IO {
   }
 
   // Convert data into canonical form.
-  private[mxnet] def initData(data: IndexedSeq[NDArray],
-                              allowEmpty: Boolean,
-                              defaultName: String): IndexedSeq[(String, 
NDArray)] = {
+  private[mxnet] def initDataDesc(data: IndexedSeq[NDArray],
+                                  allowEmpty: Boolean,
+                                  defaultName: String,
+                                  defaultDType: DType,
+                                  defaultLayout: String): 
IndexedSeq[(DataDesc, NDArray)] = {
     require(data != null)
     require(data != IndexedSeq.empty || allowEmpty)
     if (data == IndexedSeq.empty) {
       IndexedSeq()
     } else if (data.length == 1) {
-      IndexedSeq((defaultName, data(0)))
+      IndexedSeq((new DataDesc(defaultName, data(0).shape,
+        defaultDType, defaultLayout), data(0)))
     } else {
       data.zipWithIndex.map(item => {
-        (defaultName + "_" + item._2, item._1)
+        (new DataDesc(defaultName + "_" + item._2, item._1.shape,
+          defaultDType, defaultLayout), item._1)
       }).toIndexedSeq
     }
   }
@@ -136,11 +139,28 @@ class DataBatch(val data: IndexedSeq[NDArray],
                 val pad: Int,
                 // the key for the bucket that should be used for this batch,
                 // for bucketing io only
-                val bucketKey: AnyRef = null,
-                // use ListMap to indicate the order of data/label loading
+                val bucketKey: AnyRef,
+                // use DataDesc to indicate the order of data/label loading
                 // (must match the order of input data/label)
-                private val providedData: ListMap[String, Shape] = null,
-                private val providedLabel: ListMap[String, Shape] = null) {
+                private val providedDataDesc: IndexedSeq[DataDesc],
+                private val providedLabelDesc: IndexedSeq[DataDesc]) {
+  // TODO: change the data/label type into IndexedSeq[(NDArray, DataDesc)]
+  // However, since the data and label can be accessed publicly (no getter and 
setter)
+  // the change on this will break BC
+  def this(data: IndexedSeq[NDArray],
+            label: IndexedSeq[NDArray],
+            index: IndexedSeq[Long],
+            pad: Int,
+            // the key for the bucket that should be used for this batch,
+            // for bucketing io only
+            bucketKey: AnyRef = null,
+            // use ListMap to indicate the order of data/label loading
+            // (must match the order of input data/label)
+            providedData: ListMap[String, Shape] = null,
+            providedLabel: ListMap[String, Shape] = null) {
+    this(data, label, index, pad, bucketKey,
+      DataDesc.ListMap2Descs(providedData), 
DataDesc.ListMap2Descs(providedLabel))
+  }
   /**
    * Dispose its data and labels
    * The object shall never be used after it is disposed.
@@ -155,10 +175,29 @@ class DataBatch(val data: IndexedSeq[NDArray],
   }
 
   // The name and shape of data
-  def provideData: ListMap[String, Shape] = providedData
+  def provideData: ListMap[String, Shape] = {
+    // Backward-compatible (name -> shape) view of providedDataDesc, in
+    // descriptor order; stays null when no descriptors were supplied.
+    if (providedDataDesc == null) null
+    else {
+      ListMap(providedDataDesc.map(desc => desc.name -> desc.shape): _*)
+    }
+  }
 
   // The name and shape of label
-  def provideLabel: ListMap[String, Shape] = providedLabel
+  def provideLabel: ListMap[String, Shape] = {
+    // Backward-compatible (name -> shape) view of providedLabelDesc, in
+    // descriptor order; stays null when no descriptors were supplied.
+    if (providedLabelDesc == null) null
+    else {
+      ListMap(providedLabelDesc.map(desc => desc.name -> desc.shape): _*)
+    }
+  }
+
+  def provideDataDesc: IndexedSeq[DataDesc] = providedDataDesc
+
+  def provideLabelDesc: IndexedSeq[DataDesc] = providedLabelDesc
+
 }
 
 object DataBatch {
@@ -171,8 +210,8 @@ object DataBatch {
     private var index: IndexedSeq[Long] = null
     private var pad: Int = 0
     private var bucketKey: AnyRef = null
-    private var datatShapes: ListMap[String, Shape] = null
-    private var labelShapes: ListMap[String, Shape] = null
+    private var dataDesc: IndexedSeq[DataDesc] = null
+    private var labelDesc: IndexedSeq[DataDesc] = null
 
     /**
      * Set the input data.
@@ -228,37 +267,27 @@ object DataBatch {
 
     /**
      * Provide the shape of a data.
-     * @param name data name.
-     * @param shape data shape.
+     * @param dataDesc DataDescriptor
      * @return this.
      */
-    def provideDataShape(name: String, shape: Shape): Builder = {
-      if (datatShapes == null) {
-        datatShapes = ListMap((name, shape))
-      } else {
-        datatShapes = datatShapes.updated(name, shape)
-      }
+    def provideDataDesc(dataDesc: IndexedSeq[DataDesc]): Builder = {
+      this.dataDesc = dataDesc
       this
     }
 
     /**
      * Provide the shape of a label.
-     * @param name label name.
-     * @param shape label shape.
+     * @param labelDesc LabelDescriptor
      * @return this.
      */
-    def provideLabelShape(name: String, shape: Shape): Builder = {
-      if (labelShapes == null) {
-        labelShapes = ListMap((name, shape))
-      } else {
-        labelShapes = labelShapes.updated(name, shape)
-      }
+    def provideLabelDesc(labelDesc: IndexedSeq[DataDesc]): Builder = {
+      this.labelDesc = labelDesc
       this
     }
 
     def build(): DataBatch = {
       require(data != null, "data is required.")
-      new DataBatch(data, label, index, pad, bucketKey, datatShapes, 
labelShapes)
+      new DataBatch(data, label, index, pad, bucketKey, dataDesc, labelDesc)
     }
   }
 }
@@ -280,7 +309,8 @@ abstract class DataIter extends Iterator[DataBatch] {
    */
   @throws(classOf[NoSuchElementException])
   def next(): DataBatch = {
-    new DataBatch(getData(), getLabel(), getIndex(), getPad())
+    new DataBatch(getData(), getLabel(), getIndex(), getPad(),
+      null, null, null)
   }
 
   /**
@@ -309,11 +339,19 @@ abstract class DataIter extends Iterator[DataBatch] {
   def getIndex(): IndexedSeq[Long]
 
   // The name and shape of data provided by this iterator
+  @deprecated
   def provideData: ListMap[String, Shape]
 
   // The name and shape of label provided by this iterator
+  @deprecated
   def provideLabel: ListMap[String, Shape]
 
+  // Provide type:DataDesc of the data
+  def provideDataDesc: IndexedSeq[DataDesc]
+
+  // Provide type:DataDesc of the label
+  def provideLabelDesc: IndexedSeq[DataDesc]
+
   // For bucketing io only
   // The bucket key for the default symbol.
   def defaultBucketKey: AnyRef = null
@@ -332,8 +370,9 @@ abstract class DataPack() extends Iterable[DataBatch] {
 
 // Named data desc description contains name, shape, type and other extended 
attributes.
 case class DataDesc(name: String, shape: Shape,
-                    dtype: DType = Base.MX_REAL_TYPE, layout: String = "NCHW") 
{
-  require(shape.length == layout.length, ("number of dimensions in shape :%d 
with" +
+                    dtype: DType = DType.Float32, layout: String = 
Layout.UNDEFINED) {
+  require(layout == Layout.UNDEFINED || shape.length == layout.length,
+    ("number of dimensions in shape :%d with" +
     " shape: %s should match the length of the layout: %d with layout: %s").
     format(shape.length, shape.toString, layout.length, layout))
 
@@ -343,6 +382,8 @@ case class DataDesc(name: String, shape: Shape,
 }
 
 object DataDesc {
+
+  private val logger = LoggerFactory.getLogger(classOf[DataDesc])
   /**
    * Get the dimension that corresponds to the batch size.
    * @param layout layout string. For example, "NCHW".
@@ -352,9 +393,19 @@ object DataDesc {
    *         for each data-parallelism device.
    */
   def getBatchAxis(layout: Option[String]): Int = {
-    layout.map(_.indexOf('N')).getOrElse(0)
+    if (layout.isEmpty || layout.get == Layout.UNDEFINED) {
+      logger.warn("Found Undefined Layout, will use default index 0 for batch axis")
+      0
+    } else {
+      val batchAxis = layout.get.indexOf('N')
+      if (batchAxis < 0) {
+        throw new IllegalArgumentException("no Batch Axis('N') found in Layout!")
+      }
+      batchAxis
+    }
   }
 
+  @deprecated
   implicit def ListMap2Descs(shapes: ListMap[String, Shape]): 
IndexedSeq[DataDesc] = {
     if (shapes != null) {
       shapes.map { case (k, s) => new DataDesc(k, s) }.toIndexedSeq
diff --git 
a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala
 b/scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala
similarity index 64%
copy from 
scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala
copy to scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala
index 339f7e2..cb75dbc 100644
--- 
a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala
@@ -15,19 +15,21 @@
  * limitations under the License.
  */
 
-package org.apache.mxnet.spark.io
-
-import org.apache.mxnet.{NDArray, DataBatch}
+package org.apache.mxnet
 
 /**
- * Dispose only when 'disposeForce' called
- * @author Yizhi Liu
- */
-class LongLivingDataBatch(
-  override val data: IndexedSeq[NDArray],
-  override val label: IndexedSeq[NDArray],
-  override val index: IndexedSeq[Long],
-  override val pad: Int) extends DataBatch(data, label, index, pad) {
-  override def dispose(): Unit = {}
-  def disposeForce(): Unit = super.dispose()
+  * Layout definitions for DataDesc; each letter names one axis of the shape.
+  * N Batch size
+  * C Channels
+  * H Height
+  * W Width
+  * T Sequence length
+  * __undefined__ Default layout; its length is not checked against the shape
+  */
+object Layout {
+  val UNDEFINED = "__undefined__"
+  val NCHW = "NCHW"
+  val NTC = "NTC"
+  val NT = "NT"
+  val N = "N"
 }
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/RecordIO.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/RecordIO.scala
index ee3e950..578f00a 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/RecordIO.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/RecordIO.scala
@@ -28,9 +28,6 @@ import java.io.ByteArrayInputStream
 
 /**
  * Scala interface for read/write RecordIO data format
- *
- * @author Depeng Liang
- *
  * @param uri, path to recordIO file.
  * @param flag, RecordIO.IORead for reading or RecordIO.Write for writing.
  */
@@ -144,7 +141,7 @@ object MXRecordIO {
  *
  * @author Depeng Liang
  *
- * @param idx_path, path to index file
+ * @param idxPath, path to index file
  * @param uri, path to recordIO file.
  * @param flag, RecordIO.IORead for reading or RecordIO.Write for writing.
  * @param keyType, data type for keys.
diff --git 
a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
index 2a0c333..f7f858d 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
@@ -18,7 +18,8 @@
 package org.apache.mxnet.io
 
 import org.apache.mxnet.Base._
-import org.apache.mxnet.{DataBatch, DataIter, DataPack, NDArray, Shape, 
WarnIfNotDisposed}
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet._
 import org.apache.mxnet.IO._
 import org.slf4j.LoggerFactory
 
@@ -41,21 +42,31 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: 
DataIterHandle,
   // fix me if any better way found)
   private var currentBatch: DataBatch = null
 
-  private val (_provideData: ListMap[String, Shape],
+  private val (_provideDataDesc: IndexedSeq[DataDesc],
+               _provideLabelDesc: IndexedSeq[DataDesc],
+               _provideData: ListMap[String, Shape],
                _provideLabel: ListMap[String, Shape],
-               _batchSize: Int) =
+               _batchSize: Int) = {
     if (hasNext) {
       iterNext()
       val data = currentBatch.data(0)
       val label = currentBatch.label(0)
       // properties
-      val res = (ListMap(dataName -> data.shape), ListMap(labelName -> 
label.shape), data.shape(0))
+      val res = (
+        // TODO: need to allow user to specify DType and Layout
+        IndexedSeq(new DataDesc(dataName, data.shape, DType.Float32, 
Layout.UNDEFINED)),
+        IndexedSeq(new DataDesc(labelName, label.shape, DType.Float32, 
Layout.UNDEFINED)),
+        ListMap(dataName -> data.shape),
+        ListMap(labelName -> label.shape),
+        data.shape(0))
       currentBatch.dispose()
       reset()
       res
     } else {
-      (null, null, 0)
+      (null, null, null, null, 0)
     }
+  }
+
 
   private var disposed = false
   protected def isDisposed = disposed
@@ -101,10 +112,12 @@ private[mxnet] class MXDataIter(private[mxnet] val 
handle: DataIterHandle,
   private def iterNext(): Boolean = {
     val next = new RefInt
     checkCall(_LIB.mxDataIterNext(handle, next))
-    currentBatch = null
     if (next.value > 0) {
       currentBatch = new DataBatch(data = getData(), label = getLabel(),
-        index = getIndex(), pad = getPad())
+        index = getIndex(), pad = getPad(),
+        null, null, null)
+    } else {
+      currentBatch = null
     }
     next.value > 0
   }
@@ -152,11 +165,19 @@ private[mxnet] class MXDataIter(private[mxnet] val 
handle: DataIterHandle,
   }
 
   // The name and shape of data provided by this iterator
+  @deprecated
   override def provideData: ListMap[String, Shape] = _provideData
 
   // The name and shape of label provided by this iterator
+  @deprecated
   override def provideLabel: ListMap[String, Shape] = _provideLabel
 
+  // Provide type:DataDesc of the data
+  override def provideDataDesc: IndexedSeq[DataDesc] = _provideDataDesc
+
+  // Provide type:DataDesc of the label
+  override def provideLabelDesc: IndexedSeq[DataDesc] = _provideLabelDesc
+
   override def hasNext: Boolean = {
     if (currentBatch != null) {
       true
diff --git 
a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
index 1046131..e6be0ad 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
@@ -20,6 +20,7 @@ package org.apache.mxnet.io
 import java.util.NoSuchElementException
 
 import org.apache.mxnet.Base._
+import org.apache.mxnet.DType.DType
 import org.apache.mxnet._
 import org.slf4j.LoggerFactory
 
@@ -39,35 +40,35 @@ import scala.collection.immutable.ListMap
  * the size of data does not match batch_size. Roll over is intended
  * for training and can cause problems if used for prediction.
  */
-class NDArrayIter(data: IndexedSeq[(String, NDArray)],
-                  label: IndexedSeq[(String, NDArray)],
+class NDArrayIter(data: IndexedSeq[(DataDesc, NDArray)],
+                  label: IndexedSeq[(DataDesc, NDArray)],
                   private val dataBatchSize: Int, shuffle: Boolean,
                   lastBatchHandle: String) extends DataIter {
 
   /**
-   * @param data Specify the data. Data names will be data_0, data_1, ..., etc.
-   * @param label Same as data, but is not fed to the model during testing.
-   *              Label names will be label_0, label_1, ..., etc.
-   * @param dataBatchSize Batch Size
-   * @param shuffle Whether to shuffle the data
-   * @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the 
last batch
-   *
-   * This iterator will pad, discard or roll over the last batch if
-   * the size of data does not match batch_size. Roll over is intended
-   * for training and can cause problems if used for prediction.
-   */
+  * @param data Specify the data. Data names will be data_0, data_1, ..., etc.
+  * @param label Same as data, but is not fed to the model during testing.
+  *              Label names will be label_0, label_1, ..., etc.
+  * @param dataBatchSize Batch Size
+  * @param shuffle Whether to shuffle the data
+  * @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the 
last batch
+  *
+  * This iterator will pad, discard or roll over the last batch if
+  * the size of data does not match batch_size. Roll over is intended
+  * for training and can cause problems if used for prediction.
+  */
   def this(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = 
IndexedSeq.empty,
            dataBatchSize: Int = 1, shuffle: Boolean = false,
            lastBatchHandle: String = "pad",
            dataName: String = "data", labelName: String = "label") {
-    this(IO.initData(data, allowEmpty = false, dataName),
-      IO.initData(label, allowEmpty = true, labelName),
+    this(IO.initDataDesc(data, allowEmpty = false, dataName, MX_REAL_TYPE, 
Layout.UNDEFINED),
+      IO.initDataDesc(label, allowEmpty = true, labelName, MX_REAL_TYPE, 
Layout.UNDEFINED),
       dataBatchSize, shuffle, lastBatchHandle)
   }
 
   private val logger = LoggerFactory.getLogger(classOf[NDArrayIter])
 
-  val (initData: IndexedSeq[(String, NDArray)], initLabel: IndexedSeq[(String, 
NDArray)]) = {
+  val (initData: IndexedSeq[(DataDesc, NDArray)], initLabel: 
IndexedSeq[(DataDesc, NDArray)]) = {
     // data should not be null and size > 0
     require(data != null && data.size > 0,
       "data should not be null and data.size should not be zero")
@@ -101,20 +102,30 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)],
   private var cursor = -dataBatchSize
 
   private val (_provideData: ListMap[String, Shape],
-               _provideLabel: ListMap[String, Shape]) = {
+               _provideLabel: ListMap[String, Shape],
+               _provideDataDesc: IndexedSeq[DataDesc],
+               _provideLabelDesc: IndexedSeq[DataDesc]) = {
     val pData = ListMap.empty[String, Shape] ++ initData.map(getShape)
     val pLabel = ListMap.empty[String, Shape] ++ initLabel.map(getShape)
-    (pData, pLabel)
+    val pDData = IndexedSeq.empty[DataDesc] ++ initData.map(ele => {
+      val temp = getShape(ele)
+      new DataDesc(temp._1, temp._2, ele._1.dtype, ele._1.layout)
+    })
+    val pDLabel = IndexedSeq.empty[DataDesc] ++ initLabel.map(ele => {
+      val temp = getShape(ele)
+      new DataDesc(temp._1, temp._2, ele._1.dtype, ele._1.layout)
+    })
+    (pData, pLabel, pDData, pDLabel)
   }
 
   /**
    * get shape via dataBatchSize
    * @param dataItem
    */
-  private def getShape(dataItem: (String, NDArray)): (String, Shape) = {
+  private def getShape(dataItem: (DataDesc, NDArray)): (String, Shape) = {
     val len = dataItem._2.shape.size
     val newShape = dataItem._2.shape.slice(1, len)
-    (dataItem._1, Shape(Array[Int](dataBatchSize)) ++ newShape)
+    (dataItem._1.name, Shape(Array[Int](dataBatchSize)) ++ newShape)
   }
 
 
@@ -148,7 +159,8 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)],
   override def next(): DataBatch = {
     if (hasNext) {
       cursor += dataBatchSize
-      new DataBatch(getData(), getLabel(), getIndex(), getPad())
+      new DataBatch(getData(), getLabel(), getIndex(), getPad(),
+        null, null, null)
     } else {
       throw new NoSuchElementException
     }
@@ -172,7 +184,7 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)],
     }
   }
 
-  private def _getData(data: IndexedSeq[(String, NDArray)]): 
IndexedSeq[NDArray] = {
+  private def _getData(data: IndexedSeq[(DataDesc, NDArray)]): 
IndexedSeq[NDArray] = {
     require(cursor < numData, "DataIter needs reset.")
     if (data == null) {
       null
@@ -223,12 +235,21 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)],
     }
   }
 
+
   // The name and shape of data provided by this iterator
+  @deprecated
   override def provideData: ListMap[String, Shape] = _provideData
 
   // The name and shape of label provided by this iterator
+  @deprecated
   override def provideLabel: ListMap[String, Shape] = _provideLabel
 
+  // Provide type:DataDesc of the data
+  override def provideDataDesc: IndexedSeq[DataDesc] = _provideDataDesc
+
+  // Provide type:DataDesc of the label
+  override def provideLabelDesc: IndexedSeq[DataDesc] = _provideLabelDesc
+
   override def batchSize: Int = dataBatchSize
 }
 
@@ -238,8 +259,8 @@ object NDArrayIter {
    * Builder class for NDArrayIter.
    */
   class Builder() {
-    private var data: IndexedSeq[(String, NDArray)] = IndexedSeq.empty
-    private var label: IndexedSeq[(String, NDArray)] = IndexedSeq.empty
+    private var data: IndexedSeq[(DataDesc, NDArray)] = IndexedSeq.empty
+    private var label: IndexedSeq[(DataDesc, NDArray)] = IndexedSeq.empty
     private var dataBatchSize: Int = 1
     private var lastBatchHandle: String = "pad"
 
@@ -250,7 +271,8 @@ object NDArrayIter {
      * @return The builder object itself.
      */
     def addData(name: String, data: NDArray): Builder = {
-      this.data = this.data ++ IndexedSeq((name, data))
+      this.data = this.data ++ IndexedSeq((new DataDesc(name,
+        data.shape, DType.Float32, Layout.UNDEFINED), data))
       this
     }
 
@@ -261,7 +283,24 @@ object NDArrayIter {
      * @return The builder object itself.
      */
     def addLabel(name: String, label: NDArray): Builder = {
-      this.label = this.label ++ IndexedSeq((name, label))
+      this.label = this.label ++ IndexedSeq((new DataDesc(name,
+        label.shape, DType.Float32, Layout.UNDEFINED), label))
+      this
+    }
+
+    /**
+     * Add one data input described by its DataDesc (name, shape, dtype, layout).
+     */
+    def addDataWithDesc(dataDesc: DataDesc, data: NDArray): Builder = {
+      this.data = this.data ++ IndexedSeq((dataDesc, data))
+      this
+    }
+
+    /**
+     * Add one label input described by its DataDesc (stored with the labels).
+     */
+    def addLabelWithDesc(labelDesc: DataDesc, label: NDArray): Builder = {
+      this.label = this.label ++ IndexedSeq((labelDesc, label))
       this
     }
 
diff --git 
a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala
index c0c0d17..e59e370 100644
--- 
a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala
+++ 
b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala
@@ -17,10 +17,12 @@
 
 package org.apache.mxnet.io
 
-import org.apache.mxnet.{DataBatch, DataIter, NDArray, Shape}
+import org.apache.mxnet._
 import org.slf4j.LoggerFactory
 import java.util.concurrent.Semaphore
 
+import org.apache.mxnet.DType.DType
+
 import scala.collection.immutable.ListMap
 
 /**
@@ -68,6 +70,42 @@ class PrefetchingIter(
     }
   }
 
+  private val _provideDataDesc: IndexedSeq[DataDesc] = {
+    if (dataNames == null) {
+      iters.map(_.provideDataDesc).foldLeft(IndexedSeq[DataDesc]()) { (acc, 
elem) =>
+        acc ++ elem
+      }
+    } else {
+      iters.zipWithIndex.map(tu => (tu._1.provideDataDesc, tu._2))
+        .map(m =>
+          m._1.map(t =>
+            new DataDesc(dataNames(m._2)(t.name), t.shape, t.dtype, t.layout)
+          )
+        )
+        .foldLeft(IndexedSeq[DataDesc]()) { (acc, elem) =>
+          acc ++ elem
+        }
+    }
+  }
+
+  private val _provideLabelDesc: IndexedSeq[DataDesc] = {
+    if (labelNames == null) {
+      iters.map(_.provideLabelDesc).foldLeft(IndexedSeq[DataDesc]()) { (acc, 
elem) =>
+        acc ++ elem
+      }
+    } else {
+      iters.zipWithIndex.map(tu => (tu._1.provideLabelDesc, tu._2))
+        .map(m =>
+          m._1.map(t =>
+            new DataDesc(labelNames(m._2)(t.name), t.shape, t.dtype, t.layout)
+          )
+        )
+        .foldLeft(IndexedSeq[DataDesc]()) { (acc, elem) =>
+          acc ++ elem
+        }
+    }
+  }
+
   private val _batchSize: Int = this._provideData.toList(0)._2(0)
   private val dataReady: IndexedSeq[Semaphore] =
                                         (0 until iters.length).map(i => new 
Semaphore(0))
@@ -132,19 +170,27 @@ class PrefetchingIter(
    */
   override def getIndex(): IndexedSeq[Long] = currentBatch.index
 
-  // The name and shape of label provided by this iterator
-  override def provideLabel: ListMap[String, Shape] = this._provideLabel
-
   /**
-   * get the number of padding examples
-   * in current batch
-   * @return number of padding examples in current batch
-   */
+    * get the number of padding examples
+    * in current batch
+    * @return number of padding examples in current batch
+    */
   override def getPad(): Int = this.currentBatch.pad
 
+  // The name and shape of label provided by this iterator
+  @deprecated
+  override def provideLabel: ListMap[String, Shape] = this._provideLabel
+
   // The name and shape of data provided by this iterator
+  @deprecated
   override def provideData: ListMap[String, Shape] = this._provideData
 
+  // Provide type:DataDesc of the data
+  override def provideDataDesc: IndexedSeq[DataDesc] = _provideDataDesc
+
+  // Provide type:DataDesc of the label
+  override def provideLabelDesc: IndexedSeq[DataDesc] = _provideLabelDesc
+
   override def hasNext: Boolean = {
     for (e <- dataReady) e.acquire()
     if (nextBatch(0) == null) {
@@ -161,9 +207,10 @@ class PrefetchingIter(
       val datas = for (batch <- nextBatch) yield batch.data
       val labels = for (batch <- nextBatch) yield batch.label
       currentBatch = new DataBatch(datas.toIndexedSeq.flatten,
-                                      labels.toIndexedSeq.flatten,
-                                      nextBatch(0).index,
-                                      nextBatch(0).pad)
+        labels.toIndexedSeq.flatten,
+        nextBatch(0).index,
+        nextBatch(0).pad,
+        null, null, null)
       for (e <- dataTaken) e.release()
       true
     }
diff --git 
a/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala
index 75d88d1..e840af9 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala
@@ -19,7 +19,8 @@ package org.apache.mxnet.io
 
 import java.util.NoSuchElementException
 
-import org.apache.mxnet.{DataBatch, DataIter, NDArray, Shape}
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet._
 import org.slf4j.LoggerFactory
 
 import scala.collection.immutable.ListMap
@@ -133,12 +134,24 @@ class ResizeIter(
   }
 
   // The name and shape of data provided by this iterator
+  @deprecated
   override def provideData: ListMap[String, Shape] = {
     dataIter.provideData
   }
 
   // The name and shape of label provided by this iterator
+  @deprecated
   override def provideLabel: ListMap[String, Shape] = {
     dataIter.provideLabel
   }
+
+  // The name and shape of data provided by this iterator
+  override def provideDataDesc: IndexedSeq[DataDesc] = {
+    dataIter.provideDataDesc
+  }
+
+  // The name and shape of label provided by this iterator
+  override def provideLabelDesc: IndexedSeq[DataDesc] = {
+    dataIter.provideLabelDesc
+  }
 }
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala 
b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
index 1b922b3..2ec6f66 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
@@ -243,7 +243,8 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
     val batchLabel = NDArray.ones(Shape(Array(128, 1)))
 
     // test pad
-    val dataIter0 = new NDArrayIter(data, label, 128, false, "pad")
+    val dataIter0 = new NDArrayIter(data, label, 128, false, "pad",
+      dataName = "data", labelName = "label")
     var batchCount = 0
     val nBatch0 = 8
     while(dataIter0.hasNext) {
@@ -277,7 +278,8 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
     assert(batchCount === nBatch1)
 
     // test empty label (for prediction)
-    val dataIter2 = new NDArrayIter(data = data, dataBatchSize = 128, 
lastBatchHandle = "discard")
+    val dataIter2 = new NDArrayIter(data = data, dataBatchSize = 128, shuffle 
= false,
+      lastBatchHandle = "discard")
     batchCount = 0
     while(dataIter2.hasNext) {
       val tBatch = dataIter2.next()
@@ -289,5 +291,17 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
 
     assert(batchCount === nBatch1)
     assert(dataIter2.initLabel == IndexedSeq.empty)
+
+    // test implementation with DataDesc
+    val dataIter3 = new NDArrayIter(
+      IO.initDataDesc(data, false, "data", DType.Float32, Layout.NTC),
+      IO.initDataDesc(label, false, "label", DType.Int32, Layout.NT),
+      128, false, "pad")
+    val dataDesc = dataIter3.provideDataDesc
+    val labelDesc = dataIter3.provideLabelDesc
+    assert(dataDesc(0).dtype == DType.Float32)
+    assert(dataDesc(0).layout == Layout.NTC)
+    assert(labelDesc(0).dtype == DType.Int32)
+    assert(labelDesc(0).layout == Layout.NT)
   }
 }
diff --git 
a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala 
b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
index 8234568..88e314e 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
@@ -24,7 +24,7 @@ import org.apache.mxnet.io._
 
 class ModuleSuite extends FunSuite with BeforeAndAfterAll {
   test ("model dtype") {
-    val dType = DType.Float16
+    val dType = DType.Float32
     val dShape = Shape(3, 8, 7)
 
     var sym = Symbol.Variable("data")
@@ -196,8 +196,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
 
     // create module
     val mod = new Module(x, contexts = Array(Context.cpu()))
-    mod.bind(dataShapes = trainData.provideData,
-      Option(trainData.provideLabel))
+    mod.bind(dataShapes = trainData.provideDataDesc,
+      Option(trainData.provideLabelDesc))
     val argParamsCorrect = Map(
       "fc_0_weight" -> NDArray.array(Array(0.15f, 0.2f, 0.25f, 0.3f), Shape(2, 
2)),
       "fc_0_bias" -> NDArray.array(Array(0.35f, 0.35f), Shape(2)),
diff --git 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/Data.scala
 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/Data.scala
index bb17046..068aa63 100644
--- 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/Data.scala
+++ 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/Data.scala
@@ -21,9 +21,6 @@ import org.apache.mxnet.Shape
 import org.apache.mxnet.IO
 import org.apache.mxnet.DataIter
 
-/**
- * @author Depeng Liang
- */
 object Data {
 
   // return train and val iterators for mnist
diff --git 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala
 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala
index 9df2bcc..825e465 100644
--- 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala
+++ 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala
@@ -25,14 +25,9 @@ import org.slf4j.LoggerFactory
 
 import scala.collection.JavaConverters._
 import org.apache.commons.io.FileUtils
-import org.apache.mxnet.Symbol
-import org.apache.mxnet.DataIter
-import org.apache.mxnet.DataBatch
-import org.apache.mxnet.NDArray
-import org.apache.mxnet.Shape
-import org.apache.mxnet.EvalMetric
-import org.apache.mxnet.Context
-import org.apache.mxnet.Xavier
+
+import org.apache.mxnet.{Context, DataBatch, DataDesc, DataIter, EvalMetric, 
NDArray, Shape, Symbol, Xavier}
+import org.apache.mxnet.DType.DType
 import org.apache.mxnet.optimizer.RMSProp
 import org.apache.mxnet.Executor
 import org.apache.mxnetexamples.Util
@@ -70,9 +65,9 @@ object ExampleMultiTask {
         val batch = this.dataIter.next()
         val label = batch.label(0)
         new DataBatch(batch.data,
-                                     IndexedSeq(label, label),
-                                     batch.index,
-                                     batch.pad)
+          IndexedSeq(label, label),
+          batch.index,
+          batch.pad, null, null, null)
       } else {
         throw new NoSuchElementException
       }
@@ -107,6 +102,7 @@ object ExampleMultiTask {
     override def getIndex(): IndexedSeq[Long] = this.dataIter.getIndex()
 
     // The name and shape of label provided by this iterator
+    @deprecated
     override def provideLabel: ListMap[String, Shape] = {
       val provideLabel = this.dataIter.provideLabel.toArray
       // Different labels should be used here for actual application
@@ -114,6 +110,16 @@ object ExampleMultiTask {
               "softmax2_label" -> provideLabel(0)._2)
     }
 
+    // The name and shape of label provided by this iterator
+    override def provideLabelDesc: IndexedSeq[DataDesc] = {
+      val head = this.dataIter.provideLabelDesc(0)
+      // Different labels should be used here for actual application
+      IndexedSeq(
+        new DataDesc("softmax1_label", head.shape, head.dtype, head.layout),
+        new DataDesc("softmax2_label", head.shape, head.dtype, head.layout)
+      )
+    }
+
     /**
      * get the number of padding examples
      * in current batch
@@ -122,8 +128,11 @@ object ExampleMultiTask {
     override def getPad(): Int = this.dataIter.getPad()
 
     // The name and shape of data provided by this iterator
+    @deprecated
     override def provideData: ListMap[String, Shape] = 
this.dataIter.provideData
 
+    override def provideDataDesc: IndexedSeq[DataDesc] = 
this.dataIter.provideDataDesc
+
     override def hasNext: Boolean = this.dataIter.hasNext
   }
 
diff --git 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala
 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala
index f0eae68..d4b1707 100644
--- 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala
+++ 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala
@@ -18,17 +18,16 @@
 
 package org.apache.mxnetexamples.rnn
 
-import org.apache.mxnet.{DataBatch, DataIter, NDArray, Shape}
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet._
 import org.slf4j.LoggerFactory
+
 import scala.collection.immutable.ListMap
 import scala.collection.mutable.ArrayBuffer
 import scala.io.Source
 import scala.util.Random
 import scala.collection.mutable
 
-/**
- * @author Depeng Liang
- */
 object BucketIo {
 
   type Text2Id = (String, Map[String, Int]) => Array[Int]
@@ -92,11 +91,14 @@ object BucketIo {
   }
 
   class BucketSentenceIter(
-      path: String, vocab: Map[String, Int], var buckets: IndexedSeq[Int],
-      _batchSize: Int, private val initStates: IndexedSeq[(String, (Int, 
Int))],
-      seperateChar: String = " <eos> ", text2Id: Text2Id = defaultText2Id,
+      path: String,
+      vocab: Map[String, Int],
+      var buckets: IndexedSeq[Int],
+      _batchSize: Int,
+      private val initStates: IndexedSeq[(String, (Int, Int))],
+      seperateChar: String = " <eos> ",
+      text2Id: Text2Id = defaultText2Id,
       readContent: ReadContent = defaultReadContent) extends DataIter {
-
     private val logger = LoggerFactory.getLogger(classOf[BucketSentenceIter])
 
     private val content = readContent(path)
@@ -165,8 +167,22 @@ object BucketIo {
     private val _provideData = { val tmp = ListMap("data" -> Shape(_batchSize, 
_defaultBucketKey))
       tmp ++ initStates.map(x => x._1 -> Shape(x._2._1, x._2._2))
     }
+
     private val _provideLabel = ListMap("softmax_label" -> Shape(_batchSize, 
_defaultBucketKey))
 
+    private val _provideDataDesc = {
+      // TODO: need to allow user to specify DType and Layout
+      val tmp = IndexedSeq(new DataDesc("data",
+        Shape(_batchSize, _defaultBucketKey), DType.Float32, Layout.UNDEFINED))
+      tmp ++ initStates.map(x => new DataDesc(x._1, Shape(x._2._1, x._2._2),
+        DType.Float32, Layout.UNDEFINED))
+    }
+
+    private val _provideLabelDesc = IndexedSeq(
+      // TODO: need to allow user to specify DType and Layout
+      new DataDesc("softmax_label",
+      Shape(_batchSize, _defaultBucketKey), DType.Float32, Layout.UNDEFINED))
+
     private var iBucket = 0
 
     override def next(): DataBatch = {
@@ -228,19 +244,27 @@ object BucketIo {
      */
     override def getIndex(): IndexedSeq[Long] = IndexedSeq[Long]()
 
-    // The name and shape of label provided by this iterator
-    override def provideLabel: ListMap[String, Shape] = this._provideLabel
-
     /**
-     * get the number of padding examples
-     * in current batch
-     * @return number of padding examples in current batch
-     */
+      * get the number of padding examples
+      * in current batch
+      * @return number of padding examples in current batch
+      */
     override def getPad(): Int = 0
 
+    // The name and shape of label provided by this iterator
+    @deprecated
+    override def provideLabel: ListMap[String, Shape] = this._provideLabel
+
     // The name and shape of data provided by this iterator
+    @deprecated
     override def provideData: ListMap[String, Shape] = this._provideData
 
+    // Provide type:DataDesc of the data
+    override def provideDataDesc: IndexedSeq[DataDesc] = _provideDataDesc
+
+    // Provide type:DataDesc of the label
+    override def provideLabelDesc: IndexedSeq[DataDesc] = _provideLabelDesc
+
     override def hasNext: Boolean = {
       iBucket < bucketPlan.length
     }
diff --git 
a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ObjectDetectorSuite.scala
 
b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ObjectDetectorSuite.scala
index 8160f0f..39139f8 100644
--- 
a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ObjectDetectorSuite.scala
+++ 
b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ObjectDetectorSuite.scala
@@ -19,6 +19,8 @@ package org.apache.mxnet.infer
 
 // scalastyle:off
 import java.awt.image.BufferedImage
+
+import org.apache.mxnet.{DType, Layout}
 // scalastyle:on
 import org.apache.mxnet.Context
 import org.apache.mxnet.DataDesc
@@ -69,7 +71,8 @@ class ObjectDetectorSuite extends ClassifierSuite with 
BeforeAndAfterAll {
   }
 
   test("objectDetectWithInputImage") {
-    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath, 
Shape(1, 3, 512, 512)))
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath, 
Shape(1, 3, 512, 512),
+      DType.Float32, Layout.NCHW))
     val inputImage = new BufferedImage(512, 512, BufferedImage.TYPE_INT_RGB)
     val testObjectDetector: ObjectDetector =
       new MyObjectDetector(modelPath, inputDescriptor)
@@ -109,7 +112,8 @@ class ObjectDetectorSuite extends ClassifierSuite with 
BeforeAndAfterAll {
   }
 
   test("objectDetectWithBatchImages") {
-    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath, 
Shape(1, 3, 512, 512)))
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath, 
Shape(1, 3, 512, 512),
+      DType.Float32, Layout.NCHW))
     val inputImage = new BufferedImage(224, 224, BufferedImage.TYPE_INT_RGB)
     val imageBatch = IndexedSeq[BufferedImage](inputImage, inputImage)
 
diff --git 
a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala
 
b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala
index 53fd7f3..509ffb3 100644
--- 
a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala
+++ 
b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala
@@ -19,7 +19,7 @@ package org.apache.mxnet.infer
 
 import org.apache.mxnet.io.NDArrayIter
 import org.apache.mxnet.module.{BaseModule, Module}
-import org.apache.mxnet.{DataDesc, NDArray, Shape}
+import org.apache.mxnet.{DataDesc, Layout, NDArray, Shape}
 import org.mockito.Matchers._
 import org.mockito.Mockito
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
@@ -40,15 +40,17 @@ class PredictorSuite extends FunSuite with 
BeforeAndAfterAll {
   }
 
   test("PredictorSuite-testPredictorConstruction") {
-    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(1, 
3, 2, 2)))
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(1, 
3, 2, 2),
+      layout = Layout.NCHW))
 
     val mockPredictor = new MyPredictor("xyz", inputDescriptor)
 
     assert(mockPredictor.getBatchSize == 1)
     assert(mockPredictor.getBatchIndex == 
inputDescriptor(0).layout.indexOf('N'))
 
-    val inputDescriptor2 = IndexedSeq[DataDesc](new DataDesc("data", Shape(1, 
3, 2, 2)),
-      new DataDesc("data", Shape(2, 3, 2, 2)))
+    val inputDescriptor2 = IndexedSeq[DataDesc](new DataDesc("data", Shape(1, 
3, 2, 2),
+      layout = Layout.NCHW),
+      new DataDesc("data", Shape(2, 3, 2, 2), layout = Layout.NCHW))
 
     assertThrows[IllegalArgumentException] {
       val mockPredictor = new MyPredictor("xyz", inputDescriptor2)
@@ -63,7 +65,8 @@ class PredictorSuite extends FunSuite with BeforeAndAfterAll {
 
   test("PredictorSuite-testWithFlatArrays") {
 
-    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 
3, 2, 2)))
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 
3, 2, 2),
+      layout = Layout.NCHW))
     val inputData = Array.fill[Float](12)(1)
 
     // this will disposed at the end of the predict call on Predictor.
@@ -89,7 +92,8 @@ class PredictorSuite extends FunSuite with BeforeAndAfterAll {
   }
 
   test("PredictorSuite-testWithNDArray") {
-    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 
3, 2, 2)))
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 
3, 2, 2),
+      layout = Layout.NCHW))
     val inputData = NDArray.ones(Shape(1, 3, 2, 2))
 
     // this will disposed at the end of the predict call on Predictor.
diff --git a/scala-package/pom.xml b/scala-package/pom.xml
index 3511f4a..c221b47 100644
--- a/scala-package/pom.xml
+++ b/scala-package/pom.xml
@@ -231,6 +231,7 @@
           <skipTests>${skipTests}</skipTests>
           
<reportsDirectory>${project.build.directory}/surefire-reports</reportsDirectory>
           <junitxml>.</junitxml>
+          <stdout>F</stdout>
           <filereports>WDF TestSuite.txt</filereports>
         </configuration>
         <executions>
diff --git 
a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala
 
b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala
index adc723e..bf1b26e 100644
--- 
a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala
+++ 
b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala
@@ -17,7 +17,8 @@
 
 package org.apache.mxnet.spark.io
 
-import org.apache.mxnet.{DataBatch, NDArray, Shape, DataIter}
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet._
 import org.apache.spark.mllib.regression.LabeledPoint
 
 import scala.collection.immutable.ListMap
@@ -25,7 +26,6 @@ import scala.collection.mutable.ArrayBuffer
 
 /**
  * A helper converter for LabeledPoint
- * @author Yizhi Liu
  */
 class LabeledPointIter private[mxnet](
   private val points: Iterator[LabeledPoint],
@@ -115,15 +115,27 @@ class LabeledPointIter private[mxnet](
   }
 
   // The name and shape of label provided by this iterator
+  @deprecated
   override def provideLabel: ListMap[String, Shape] = {
     ListMap(labelName -> Shape(_batchSize))
   }
 
   // The name and shape of data provided by this iterator
+  @deprecated
   override def provideData: ListMap[String, Shape] = {
     ListMap(dataName -> dataShape)
   }
 
+  override def provideDataDesc: IndexedSeq[DataDesc] = {
+    // TODO: need to allow user to specify DType and Layout
+    IndexedSeq(new DataDesc(dataName, dataShape, DType.Float32, 
Layout.UNDEFINED))
+  }
+
+  override def provideLabelDesc: IndexedSeq[DataDesc] = {
+    // TODO: need to allow user to specify DType and Layout
+    IndexedSeq(new DataDesc(labelName, Shape(_batchSize), DType.Float32, 
Layout.UNDEFINED))
+  }
+
   /**
    * Get the number of padding examples
    * in current batch
diff --git 
a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala
 
b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala
index 339f7e2..e3272a4 100644
--- 
a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala
+++ 
b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala
@@ -17,7 +17,8 @@
 
 package org.apache.mxnet.spark.io
 
-import org.apache.mxnet.{NDArray, DataBatch}
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet.{DataBatch, NDArray}
 
 /**
  * Dispose only when 'disposeForce' called
@@ -27,7 +28,8 @@ class LongLivingDataBatch(
   override val data: IndexedSeq[NDArray],
   override val label: IndexedSeq[NDArray],
   override val index: IndexedSeq[Long],
-  override val pad: Int) extends DataBatch(data, label, index, pad) {
+  override val pad: Int) extends DataBatch(data, label, index, pad,
+  null, null, null) {
   override def dispose(): Unit = {}
   def disposeForce(): Unit = super.dispose()
 }
diff --git 
a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala 
b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala
index 2132929..a955ee7 100644
--- 
a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala
+++ 
b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala
@@ -17,7 +17,8 @@
 
 package org.apache.mxnet.spark.io
 
-import org.apache.mxnet.{NDArray, DataBatch, DataIter, Shape}
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet._
 import org.apache.spark.mllib.linalg.Vector
 
 import scala.collection.immutable.ListMap
@@ -25,7 +26,6 @@ import scala.collection.mutable.ArrayBuffer
 
 /**
  * A temporary helper implementation for predicting Vectors
- * @author Yizhi Liu
  */
 class PointIter private[mxnet](
   private val points: Iterator[Vector],
@@ -114,15 +114,27 @@ class PointIter private[mxnet](
   }
 
   // The name and shape of label provided by this iterator
+  @deprecated
   override def provideLabel: ListMap[String, Shape] = {
     ListMap(labelName -> Shape(_batchSize))
   }
 
   // The name and shape of data provided by this iterator
+  @deprecated
   override def provideData: ListMap[String, Shape] = {
     ListMap(dataName -> dataShape)
   }
 
+  override def provideDataDesc: IndexedSeq[DataDesc] = {
+    // TODO: Make DType, Layout configurable
+    IndexedSeq(new DataDesc(dataName, dataShape, DType.Float32, 
Layout.UNDEFINED))
+  }
+
+  override def provideLabelDesc: IndexedSeq[DataDesc] = {
+    IndexedSeq(new DataDesc(labelName, Shape(_batchSize),
+      DType.Float32, Layout.UNDEFINED))
+  }
+
   /**
    * Get the number of padding examples
    * in current batch

Reply via email to