nswamy closed pull request #11844: [MXNET-689] add DataDesc type for the Scala 
Package
URL: https://github.com/apache/incubator-mxnet/pull/11844
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala
index 4458a7c7aeb..f3a8e8e9a4a 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala
@@ -35,4 +35,15 @@ object DType extends Enumeration {
       case DType.Unknown => 0
     }
   }
+  private[mxnet] def getType(dtypeStr: String): DType = {
+    dtypeStr match {
+      case "UInt8" => DType.UInt8
+      case "Int32" => DType.Int32
+      case "Float16" => DType.Float16
+      case "Float32" => DType.Float32
+      case "Float64" => DType.Float64
+      case _ => throw new IllegalArgumentException(
+        s"DType: $dtypeStr not found! please set it in DType.scala")
+    }
+  }
 }
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
index 47fd4eee939..a1095cf0483 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
@@ -25,7 +25,6 @@ import org.slf4j.LoggerFactory
 import scala.annotation.varargs
 import scala.collection.immutable.ListMap
 import scala.collection.mutable.ListBuffer
-
 /**
  * IO iterators for loading training & validation data
  */
@@ -110,18 +109,22 @@ object IO {
   }
 
   // Convert data into canonical form.
-  private[mxnet] def initData(data: IndexedSeq[NDArray],
-                              allowEmpty: Boolean,
-                              defaultName: String): IndexedSeq[(String, 
NDArray)] = {
+  private[mxnet] def initDataDesc(data: IndexedSeq[NDArray],
+                                  allowEmpty: Boolean,
+                                  defaultName: String,
+                                  defaultDType: DType,
+                                  defaultLayout: String): 
IndexedSeq[(DataDesc, NDArray)] = {
     require(data != null)
     require(data != IndexedSeq.empty || allowEmpty)
     if (data == IndexedSeq.empty) {
       IndexedSeq()
     } else if (data.length == 1) {
-      IndexedSeq((defaultName, data(0)))
+      IndexedSeq((new DataDesc(defaultName, data(0).shape,
+        defaultDType, defaultLayout), data(0)))
     } else {
       data.zipWithIndex.map(item => {
-        (defaultName + "_" + item._2, item._1)
+        (new DataDesc(defaultName + "_" + item._2, item._1.shape,
+          defaultDType, defaultLayout), item._1)
       }).toIndexedSeq
     }
   }
@@ -136,11 +139,28 @@ class DataBatch(val data: IndexedSeq[NDArray],
                 val pad: Int,
                 // the key for the bucket that should be used for this batch,
                 // for bucketing io only
-                val bucketKey: AnyRef = null,
-                // use ListMap to indicate the order of data/label loading
+                val bucketKey: AnyRef,
+                // use DataDesc to indicate the order of data/label loading
                 // (must match the order of input data/label)
-                private val providedData: ListMap[String, Shape] = null,
-                private val providedLabel: ListMap[String, Shape] = null) {
+                private val providedDataDesc: IndexedSeq[DataDesc],
+                private val providedLabelDesc: IndexedSeq[DataDesc]) {
+  // TODO: change the data/label type into IndexedSeq[(NDArray, DataDesc)]
+  // However, since the data and label can be accessed publicly (no getter and 
setter)
+  // the change on this will break BC
+  def this(data: IndexedSeq[NDArray],
+            label: IndexedSeq[NDArray],
+            index: IndexedSeq[Long],
+            pad: Int,
+            // the key for the bucket that should be used for this batch,
+            // for bucketing io only
+            bucketKey: AnyRef = null,
+            // use ListMap to indicate the order of data/label loading
+            // (must match the order of input data/label)
+            providedData: ListMap[String, Shape] = null,
+            providedLabel: ListMap[String, Shape] = null) {
+    this(data, label, index, pad, bucketKey,
+      DataDesc.ListMap2Descs(providedData), 
DataDesc.ListMap2Descs(providedLabel))
+  }
   /**
    * Dispose its data and labels
    * The object shall never be used after it is disposed.
@@ -155,10 +175,29 @@ class DataBatch(val data: IndexedSeq[NDArray],
   }
 
   // The name and shape of data
-  def provideData: ListMap[String, Shape] = providedData
+  def provideData: ListMap[String, Shape] = {
+    var temp = ListMap[String, Shape]()
+    if (providedDataDesc == null) null
+    else {
+      providedDataDesc.foreach(ele => temp = temp + (ele.name -> ele.shape))
+      temp
+    }
+  }
 
   // The name and shape of label
-  def provideLabel: ListMap[String, Shape] = providedLabel
+  def provideLabel: ListMap[String, Shape] = {
+    var temp = ListMap[String, Shape]()
+    if (providedLabelDesc == null) null
+    else {
+      providedLabelDesc.foreach(ele => temp = temp + (ele.name -> ele.shape))
+      temp
+    }
+  }
+
+  def provideDataDesc: IndexedSeq[DataDesc] = providedDataDesc
+
+  def provideLabelDesc: IndexedSeq[DataDesc] = providedLabelDesc
+
 }
 
 object DataBatch {
@@ -171,8 +210,8 @@ object DataBatch {
     private var index: IndexedSeq[Long] = null
     private var pad: Int = 0
     private var bucketKey: AnyRef = null
-    private var datatShapes: ListMap[String, Shape] = null
-    private var labelShapes: ListMap[String, Shape] = null
+    private var dataDesc: IndexedSeq[DataDesc] = null
+    private var labelDesc: IndexedSeq[DataDesc] = null
 
     /**
      * Set the input data.
@@ -228,37 +267,27 @@ object DataBatch {
 
     /**
      * Provide the shape of a data.
-     * @param name data name.
-     * @param shape data shape.
+     * @param dataDesc DataDescriptor
      * @return this.
      */
-    def provideDataShape(name: String, shape: Shape): Builder = {
-      if (datatShapes == null) {
-        datatShapes = ListMap((name, shape))
-      } else {
-        datatShapes = datatShapes.updated(name, shape)
-      }
+    def provideDataDesc(dataDesc: IndexedSeq[DataDesc]): Builder = {
+      this.dataDesc = dataDesc
       this
     }
 
     /**
      * Provide the shape of a label.
-     * @param name label name.
-     * @param shape label shape.
+     * @param labelDesc LabelDescriptor
      * @return this.
      */
-    def provideLabelShape(name: String, shape: Shape): Builder = {
-      if (labelShapes == null) {
-        labelShapes = ListMap((name, shape))
-      } else {
-        labelShapes = labelShapes.updated(name, shape)
-      }
+    def provideLabelDesc(labelDesc: IndexedSeq[DataDesc]): Builder = {
+      this.labelDesc = labelDesc
       this
     }
 
     def build(): DataBatch = {
       require(data != null, "data is required.")
-      new DataBatch(data, label, index, pad, bucketKey, datatShapes, 
labelShapes)
+      new DataBatch(data, label, index, pad, bucketKey, dataDesc, labelDesc)
     }
   }
 }
@@ -280,7 +309,8 @@ abstract class DataIter extends Iterator[DataBatch] {
    */
   @throws(classOf[NoSuchElementException])
   def next(): DataBatch = {
-    new DataBatch(getData(), getLabel(), getIndex(), getPad())
+    new DataBatch(getData(), getLabel(), getIndex(), getPad(),
+      null, null, null)
   }
 
   /**
@@ -309,11 +339,19 @@ abstract class DataIter extends Iterator[DataBatch] {
   def getIndex(): IndexedSeq[Long]
 
   // The name and shape of data provided by this iterator
+  @deprecated
   def provideData: ListMap[String, Shape]
 
   // The name and shape of label provided by this iterator
+  @deprecated
   def provideLabel: ListMap[String, Shape]
 
+  // Provide type:DataDesc of the data
+  def provideDataDesc: IndexedSeq[DataDesc]
+
+  // Provide type:DataDesc of the label
+  def provideLabelDesc: IndexedSeq[DataDesc]
+
   // For bucketing io only
   // The bucket key for the default symbol.
   def defaultBucketKey: AnyRef = null
@@ -332,8 +370,9 @@ abstract class DataPack() extends Iterable[DataBatch] {
 
 // Named data desc description contains name, shape, type and other extended 
attributes.
 case class DataDesc(name: String, shape: Shape,
-                    dtype: DType = Base.MX_REAL_TYPE, layout: String = "NCHW") 
{
-  require(shape.length == layout.length, ("number of dimensions in shape :%d 
with" +
+                    dtype: DType = DType.Float32, layout: String = 
Layout.UNDEFINED) {
+  require(layout == Layout.UNDEFINED || shape.length == layout.length,
+    ("number of dimensions in shape :%d with" +
     " shape: %s should match the length of the layout: %d with layout: %s").
     format(shape.length, shape.toString, layout.length, layout))
 
@@ -343,6 +382,8 @@ case class DataDesc(name: String, shape: Shape,
 }
 
 object DataDesc {
+
+  private val logger = LoggerFactory.getLogger(classOf[DataDesc])
   /**
    * Get the dimension that corresponds to the batch size.
    * @param layout layout string. For example, "NCHW".
@@ -352,9 +393,19 @@ object DataDesc {
    *         for each data-parallelism device.
    */
   def getBatchAxis(layout: Option[String]): Int = {
-    layout.map(_.indexOf('N')).getOrElse(0)
+    if (layout.isEmpty || layout.get == Layout.UNDEFINED) {
+      logger.warn("Found Undefined Layout, will use default index 0 for batch 
axis")
+      0
+    } else {
+      if (layout.get.contains('N')) {
+        layout.get.indexOf("N")
+      } else {
+        throw new IllegalArgumentException("no Batch Axis('N') found in 
Layout!")
+      }
+    }
   }
 
+  @deprecated
   implicit def ListMap2Descs(shapes: ListMap[String, Shape]): 
IndexedSeq[DataDesc] = {
     if (shapes != null) {
       shapes.map { case (k, s) => new DataDesc(k, s) }.toIndexedSeq
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala
new file mode 100644
index 00000000000..cb75dbc4080
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala
@@ -0,0 +1,35 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+/**
+  * Layout definition of DataDesc
+  * N Batch size
+  * C channels
+  * H Height
+  * W Width
+  * T sequence length
+  * __undefined__ default value of Layout
+  */
+object Layout {
+  val UNDEFINED = "__undefined__"
+  val NCHW = "NCHW"
+  val NTC = "NTC"
+  val NT = "NT"
+  val N = "N"
+}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/RecordIO.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/RecordIO.scala
index ee3e950512e..578f00a76f9 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/RecordIO.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/RecordIO.scala
@@ -28,9 +28,6 @@ import java.io.ByteArrayInputStream
 
 /**
  * Scala interface for read/write RecordIO data format
- *
- * @author Depeng Liang
- *
  * @param uri, path to recordIO file.
  * @param flag, RecordIO.IORead for reading or RecordIO.Write for writing.
  */
@@ -144,7 +141,7 @@ object MXRecordIO {
  *
  * @author Depeng Liang
  *
- * @param idx_path, path to index file
+ * @param idxPath, path to index file
  * @param uri, path to recordIO file.
  * @param flag, RecordIO.IORead for reading or RecordIO.Write for writing.
  * @param keyType, data type for keys.
diff --git 
a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
index 2a0c333ebf1..f7f858deb82 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
@@ -18,7 +18,8 @@
 package org.apache.mxnet.io
 
 import org.apache.mxnet.Base._
-import org.apache.mxnet.{DataBatch, DataIter, DataPack, NDArray, Shape, 
WarnIfNotDisposed}
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet._
 import org.apache.mxnet.IO._
 import org.slf4j.LoggerFactory
 
@@ -41,21 +42,31 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: 
DataIterHandle,
   // fix me if any better way found)
   private var currentBatch: DataBatch = null
 
-  private val (_provideData: ListMap[String, Shape],
+  private val (_provideDataDesc: IndexedSeq[DataDesc],
+               _provideLabelDesc: IndexedSeq[DataDesc],
+               _provideData: ListMap[String, Shape],
                _provideLabel: ListMap[String, Shape],
-               _batchSize: Int) =
+               _batchSize: Int) = {
     if (hasNext) {
       iterNext()
       val data = currentBatch.data(0)
       val label = currentBatch.label(0)
       // properties
-      val res = (ListMap(dataName -> data.shape), ListMap(labelName -> 
label.shape), data.shape(0))
+      val res = (
+        // TODO: need to allow user to specify DType and Layout
+        IndexedSeq(new DataDesc(dataName, data.shape, DType.Float32, 
Layout.UNDEFINED)),
+        IndexedSeq(new DataDesc(labelName, label.shape, DType.Float32, 
Layout.UNDEFINED)),
+        ListMap(dataName -> data.shape),
+        ListMap(labelName -> label.shape),
+        data.shape(0))
       currentBatch.dispose()
       reset()
       res
     } else {
-      (null, null, 0)
+      (null, null, null, null, 0)
     }
+  }
+
 
   private var disposed = false
   protected def isDisposed = disposed
@@ -101,10 +112,12 @@ private[mxnet] class MXDataIter(private[mxnet] val 
handle: DataIterHandle,
   private def iterNext(): Boolean = {
     val next = new RefInt
     checkCall(_LIB.mxDataIterNext(handle, next))
-    currentBatch = null
     if (next.value > 0) {
       currentBatch = new DataBatch(data = getData(), label = getLabel(),
-        index = getIndex(), pad = getPad())
+        index = getIndex(), pad = getPad(),
+        null, null, null)
+    } else {
+      currentBatch = null
     }
     next.value > 0
   }
@@ -152,11 +165,19 @@ private[mxnet] class MXDataIter(private[mxnet] val 
handle: DataIterHandle,
   }
 
   // The name and shape of data provided by this iterator
+  @deprecated
   override def provideData: ListMap[String, Shape] = _provideData
 
   // The name and shape of label provided by this iterator
+  @deprecated
   override def provideLabel: ListMap[String, Shape] = _provideLabel
 
+  // Provide type:DataDesc of the data
+  override def provideDataDesc: IndexedSeq[DataDesc] = _provideDataDesc
+
+  // Provide type:DataDesc of the label
+  override def provideLabelDesc: IndexedSeq[DataDesc] = _provideLabelDesc
+
   override def hasNext: Boolean = {
     if (currentBatch != null) {
       true
diff --git 
a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
index 10461315c19..e6be0ad02f8 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
@@ -20,6 +20,7 @@ package org.apache.mxnet.io
 import java.util.NoSuchElementException
 
 import org.apache.mxnet.Base._
+import org.apache.mxnet.DType.DType
 import org.apache.mxnet._
 import org.slf4j.LoggerFactory
 
@@ -39,35 +40,35 @@ import scala.collection.immutable.ListMap
  * the size of data does not match batch_size. Roll over is intended
  * for training and can cause problems if used for prediction.
  */
-class NDArrayIter(data: IndexedSeq[(String, NDArray)],
-                  label: IndexedSeq[(String, NDArray)],
+class NDArrayIter(data: IndexedSeq[(DataDesc, NDArray)],
+                  label: IndexedSeq[(DataDesc, NDArray)],
                   private val dataBatchSize: Int, shuffle: Boolean,
                   lastBatchHandle: String) extends DataIter {
 
   /**
-   * @param data Specify the data. Data names will be data_0, data_1, ..., etc.
-   * @param label Same as data, but is not fed to the model during testing.
-   *              Label names will be label_0, label_1, ..., etc.
-   * @param dataBatchSize Batch Size
-   * @param shuffle Whether to shuffle the data
-   * @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the 
last batch
-   *
-   * This iterator will pad, discard or roll over the last batch if
-   * the size of data does not match batch_size. Roll over is intended
-   * for training and can cause problems if used for prediction.
-   */
+  * @param data Specify the data. Data names will be data_0, data_1, ..., etc.
+  * @param label Same as data, but is not fed to the model during testing.
+  *              Label names will be label_0, label_1, ..., etc.
+  * @param dataBatchSize Batch Size
+  * @param shuffle Whether to shuffle the data
+  * @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the 
last batch
+  *
+  * This iterator will pad, discard or roll over the last batch if
+  * the size of data does not match batch_size. Roll over is intended
+  * for training and can cause problems if used for prediction.
+  */
   def this(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = 
IndexedSeq.empty,
            dataBatchSize: Int = 1, shuffle: Boolean = false,
            lastBatchHandle: String = "pad",
            dataName: String = "data", labelName: String = "label") {
-    this(IO.initData(data, allowEmpty = false, dataName),
-      IO.initData(label, allowEmpty = true, labelName),
+    this(IO.initDataDesc(data, allowEmpty = false, dataName, MX_REAL_TYPE, 
Layout.UNDEFINED),
+      IO.initDataDesc(label, allowEmpty = true, labelName, MX_REAL_TYPE, 
Layout.UNDEFINED),
       dataBatchSize, shuffle, lastBatchHandle)
   }
 
   private val logger = LoggerFactory.getLogger(classOf[NDArrayIter])
 
-  val (initData: IndexedSeq[(String, NDArray)], initLabel: IndexedSeq[(String, 
NDArray)]) = {
+  val (initData: IndexedSeq[(DataDesc, NDArray)], initLabel: 
IndexedSeq[(DataDesc, NDArray)]) = {
     // data should not be null and size > 0
     require(data != null && data.size > 0,
       "data should not be null and data.size should not be zero")
@@ -101,20 +102,30 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)],
   private var cursor = -dataBatchSize
 
   private val (_provideData: ListMap[String, Shape],
-               _provideLabel: ListMap[String, Shape]) = {
+               _provideLabel: ListMap[String, Shape],
+               _provideDataDesc: IndexedSeq[DataDesc],
+               _provideLabelDesc: IndexedSeq[DataDesc]) = {
     val pData = ListMap.empty[String, Shape] ++ initData.map(getShape)
     val pLabel = ListMap.empty[String, Shape] ++ initLabel.map(getShape)
-    (pData, pLabel)
+    val pDData = IndexedSeq.empty[DataDesc] ++ initData.map(ele => {
+      val temp = getShape(ele)
+      new DataDesc(temp._1, temp._2, ele._1.dtype, ele._1.layout)
+    })
+    val pDLabel = IndexedSeq.empty[DataDesc] ++ initLabel.map(ele => {
+      val temp = getShape(ele)
+      new DataDesc(temp._1, temp._2, ele._1.dtype, ele._1.layout)
+    })
+    (pData, pLabel, pDData, pDLabel)
   }
 
   /**
    * get shape via dataBatchSize
    * @param dataItem
    */
-  private def getShape(dataItem: (String, NDArray)): (String, Shape) = {
+  private def getShape(dataItem: (DataDesc, NDArray)): (String, Shape) = {
     val len = dataItem._2.shape.size
     val newShape = dataItem._2.shape.slice(1, len)
-    (dataItem._1, Shape(Array[Int](dataBatchSize)) ++ newShape)
+    (dataItem._1.name, Shape(Array[Int](dataBatchSize)) ++ newShape)
   }
 
 
@@ -148,7 +159,8 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)],
   override def next(): DataBatch = {
     if (hasNext) {
       cursor += dataBatchSize
-      new DataBatch(getData(), getLabel(), getIndex(), getPad())
+      new DataBatch(getData(), getLabel(), getIndex(), getPad(),
+        null, null, null)
     } else {
       throw new NoSuchElementException
     }
@@ -172,7 +184,7 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)],
     }
   }
 
-  private def _getData(data: IndexedSeq[(String, NDArray)]): 
IndexedSeq[NDArray] = {
+  private def _getData(data: IndexedSeq[(DataDesc, NDArray)]): 
IndexedSeq[NDArray] = {
     require(cursor < numData, "DataIter needs reset.")
     if (data == null) {
       null
@@ -223,12 +235,21 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)],
     }
   }
 
+
   // The name and shape of data provided by this iterator
+  @deprecated
   override def provideData: ListMap[String, Shape] = _provideData
 
   // The name and shape of label provided by this iterator
+  @deprecated
   override def provideLabel: ListMap[String, Shape] = _provideLabel
 
+  // Provide type:DataDesc of the data
+  override def provideDataDesc: IndexedSeq[DataDesc] = _provideDataDesc
+
+  // Provide type:DataDesc of the label
+  override def provideLabelDesc: IndexedSeq[DataDesc] = _provideLabelDesc
+
   override def batchSize: Int = dataBatchSize
 }
 
@@ -238,8 +259,8 @@ object NDArrayIter {
    * Builder class for NDArrayIter.
    */
   class Builder() {
-    private var data: IndexedSeq[(String, NDArray)] = IndexedSeq.empty
-    private var label: IndexedSeq[(String, NDArray)] = IndexedSeq.empty
+    private var data: IndexedSeq[(DataDesc, NDArray)] = IndexedSeq.empty
+    private var label: IndexedSeq[(DataDesc, NDArray)] = IndexedSeq.empty
     private var dataBatchSize: Int = 1
     private var lastBatchHandle: String = "pad"
 
@@ -250,7 +271,8 @@ object NDArrayIter {
      * @return The builder object itself.
      */
     def addData(name: String, data: NDArray): Builder = {
-      this.data = this.data ++ IndexedSeq((name, data))
+      this.data = this.data ++ IndexedSeq((new DataDesc(name,
+        data.shape, DType.Float32, Layout.UNDEFINED), data))
       this
     }
 
@@ -261,7 +283,24 @@ object NDArrayIter {
      * @return The builder object itself.
      */
     def addLabel(name: String, label: NDArray): Builder = {
-      this.label = this.label ++ IndexedSeq((name, label))
+      this.label = this.label ++ IndexedSeq((new DataDesc(name,
+        label.shape, DType.Float32, Layout.UNDEFINED), label))
+      this
+    }
+
+    /**
+      * Add one data input with its DataDesc
+     */
+    def addDataWithDesc(dataDesc: DataDesc, data: NDArray): Builder = {
+      this.data = this.data ++ IndexedSeq((dataDesc, data))
+      this
+    }
+
+    /**
+      * Add one label input with its DataDesc
+      */
+    def addLabelWithDesc(labelDesc: DataDesc, label: NDArray): Builder = {
+      this.label = this.label ++ IndexedSeq((labelDesc, label))
+      this
     }
 
diff --git 
a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala
index c0c0d1793b5..e59e3706317 100644
--- 
a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala
+++ 
b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala
@@ -17,10 +17,12 @@
 
 package org.apache.mxnet.io
 
-import org.apache.mxnet.{DataBatch, DataIter, NDArray, Shape}
+import org.apache.mxnet._
 import org.slf4j.LoggerFactory
 import java.util.concurrent.Semaphore
 
+import org.apache.mxnet.DType.DType
+
 import scala.collection.immutable.ListMap
 
 /**
@@ -68,6 +70,42 @@ class PrefetchingIter(
     }
   }
 
+  private val _provideDataDesc: IndexedSeq[DataDesc] = {
+    if (dataNames == null) {
+      iters.map(_.provideDataDesc).foldLeft(IndexedSeq[DataDesc]()) { (acc, 
elem) =>
+        acc ++ elem
+      }
+    } else {
+      iters.zipWithIndex.map(tu => (tu._1.provideDataDesc, tu._2))
+        .map(m =>
+          m._1.map(t =>
+            new DataDesc(dataNames(m._2)(t.name), t.shape, t.dtype, t.layout)
+          )
+        )
+        .foldLeft(IndexedSeq[DataDesc]()) { (acc, elem) =>
+          acc ++ elem
+        }
+    }
+  }
+
+  private val _provideLabelDesc: IndexedSeq[DataDesc] = {
+    if (labelNames == null) {
+      iters.map(_.provideLabelDesc).foldLeft(IndexedSeq[DataDesc]()) { (acc, 
elem) =>
+        acc ++ elem
+      }
+    } else {
+      iters.zipWithIndex.map(tu => (tu._1.provideLabelDesc, tu._2))
+        .map(m =>
+          m._1.map(t =>
+            new DataDesc(labelNames(m._2)(t.name), t.shape, t.dtype, t.layout)
+          )
+        )
+        .foldLeft(IndexedSeq[DataDesc]()) { (acc, elem) =>
+          acc ++ elem
+        }
+    }
+  }
+
   private val _batchSize: Int = this._provideData.toList(0)._2(0)
   private val dataReady: IndexedSeq[Semaphore] =
                                         (0 until iters.length).map(i => new 
Semaphore(0))
@@ -132,19 +170,27 @@ class PrefetchingIter(
    */
   override def getIndex(): IndexedSeq[Long] = currentBatch.index
 
-  // The name and shape of label provided by this iterator
-  override def provideLabel: ListMap[String, Shape] = this._provideLabel
-
   /**
-   * get the number of padding examples
-   * in current batch
-   * @return number of padding examples in current batch
-   */
+    * get the number of padding examples
+    * in current batch
+    * @return number of padding examples in current batch
+    */
   override def getPad(): Int = this.currentBatch.pad
 
+  // The name and shape of label provided by this iterator
+  @deprecated
+  override def provideLabel: ListMap[String, Shape] = this._provideLabel
+
   // The name and shape of data provided by this iterator
+  @deprecated
   override def provideData: ListMap[String, Shape] = this._provideData
 
+  // Provide type:DataDesc of the data
+  override def provideDataDesc: IndexedSeq[DataDesc] = _provideDataDesc
+
+  // Provide type:DataDesc of the label
+  override def provideLabelDesc: IndexedSeq[DataDesc] = _provideLabelDesc
+
   override def hasNext: Boolean = {
     for (e <- dataReady) e.acquire()
     if (nextBatch(0) == null) {
@@ -161,9 +207,10 @@ class PrefetchingIter(
       val datas = for (batch <- nextBatch) yield batch.data
       val labels = for (batch <- nextBatch) yield batch.label
       currentBatch = new DataBatch(datas.toIndexedSeq.flatten,
-                                      labels.toIndexedSeq.flatten,
-                                      nextBatch(0).index,
-                                      nextBatch(0).pad)
+        labels.toIndexedSeq.flatten,
+        nextBatch(0).index,
+        nextBatch(0).pad,
+        null, null, null)
       for (e <- dataTaken) e.release()
       true
     }
diff --git 
a/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala
index 75d88d1ae72..e840af9395f 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala
@@ -19,7 +19,8 @@ package org.apache.mxnet.io
 
 import java.util.NoSuchElementException
 
-import org.apache.mxnet.{DataBatch, DataIter, NDArray, Shape}
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet._
 import org.slf4j.LoggerFactory
 
 import scala.collection.immutable.ListMap
@@ -133,12 +134,24 @@ class ResizeIter(
   }
 
   // The name and shape of data provided by this iterator
+  @deprecated
   override def provideData: ListMap[String, Shape] = {
     dataIter.provideData
   }
 
   // The name and shape of label provided by this iterator
+  @deprecated
   override def provideLabel: ListMap[String, Shape] = {
     dataIter.provideLabel
   }
+
+  // The name and shape of data provided by this iterator
+  override def provideDataDesc: IndexedSeq[DataDesc] = {
+    dataIter.provideDataDesc
+  }
+
+  // The name and shape of label provided by this iterator
+  override def provideLabelDesc: IndexedSeq[DataDesc] = {
+    dataIter.provideLabelDesc
+  }
 }
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala 
b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
index 1b922b3c05b..2ec6f668dbc 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
@@ -243,7 +243,8 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
     val batchLabel = NDArray.ones(Shape(Array(128, 1)))
 
     // test pad
-    val dataIter0 = new NDArrayIter(data, label, 128, false, "pad")
+    val dataIter0 = new NDArrayIter(data, label, 128, false, "pad",
+      dataName = "data", labelName = "label")
     var batchCount = 0
     val nBatch0 = 8
     while(dataIter0.hasNext) {
@@ -277,7 +278,8 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
     assert(batchCount === nBatch1)
 
     // test empty label (for prediction)
-    val dataIter2 = new NDArrayIter(data = data, dataBatchSize = 128, 
lastBatchHandle = "discard")
+    val dataIter2 = new NDArrayIter(data = data, dataBatchSize = 128, shuffle 
= false,
+      lastBatchHandle = "discard")
     batchCount = 0
     while(dataIter2.hasNext) {
       val tBatch = dataIter2.next()
@@ -289,5 +291,17 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
 
     assert(batchCount === nBatch1)
     assert(dataIter2.initLabel == IndexedSeq.empty)
+
+    // test implementation with DataDesc
+    val dataIter3 = new NDArrayIter(
+      IO.initDataDesc(data, false, "data", DType.Float32, Layout.NTC),
+      IO.initDataDesc(label, false, "label", DType.Int32, Layout.NT),
+      128, false, "pad")
+    val dataDesc = dataIter3.provideDataDesc
+    val labelDesc = dataIter3.provideLabelDesc
+    assert(dataDesc(0).dtype == DType.Float32)
+    assert(dataDesc(0).layout == Layout.NTC)
+    assert(labelDesc(0).dtype == DType.Int32)
+    assert(labelDesc(0).layout == Layout.NT)
   }
 }
diff --git 
a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala 
b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
index 8234568d7d9..88e314e2a72 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
@@ -24,7 +24,7 @@ import org.apache.mxnet.io._
 
 class ModuleSuite extends FunSuite with BeforeAndAfterAll {
   test ("model dtype") {
-    val dType = DType.Float16
+    val dType = DType.Float32
     val dShape = Shape(3, 8, 7)
 
     var sym = Symbol.Variable("data")
@@ -196,8 +196,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
 
     // create module
     val mod = new Module(x, contexts = Array(Context.cpu()))
-    mod.bind(dataShapes = trainData.provideData,
-      Option(trainData.provideLabel))
+    mod.bind(dataShapes = trainData.provideDataDesc,
+      Option(trainData.provideLabelDesc))
     val argParamsCorrect = Map(
       "fc_0_weight" -> NDArray.array(Array(0.15f, 0.2f, 0.25f, 0.3f), Shape(2, 
2)),
       "fc_0_bias" -> NDArray.array(Array(0.35f, 0.35f), Shape(2)),
diff --git 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/Data.scala
 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/Data.scala
index bb17046b8b2..068aa6314f8 100644
--- 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/Data.scala
+++ 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/Data.scala
@@ -21,9 +21,6 @@ import org.apache.mxnet.Shape
 import org.apache.mxnet.IO
 import org.apache.mxnet.DataIter
 
-/**
- * @author Depeng Liang
- */
 object Data {
 
   // return train and val iterators for mnist
diff --git 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala
 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala
index 9df2bcc0566..825e4659675 100644
--- 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala
+++ 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala
@@ -25,14 +25,9 @@ import org.slf4j.LoggerFactory
 
 import scala.collection.JavaConverters._
 import org.apache.commons.io.FileUtils
-import org.apache.mxnet.Symbol
-import org.apache.mxnet.DataIter
-import org.apache.mxnet.DataBatch
-import org.apache.mxnet.NDArray
-import org.apache.mxnet.Shape
-import org.apache.mxnet.EvalMetric
-import org.apache.mxnet.Context
-import org.apache.mxnet.Xavier
+
+import org.apache.mxnet.{Context, DataBatch, DataDesc, DataIter, EvalMetric, 
NDArray, Shape, Symbol, Xavier}
+import org.apache.mxnet.DType.DType
 import org.apache.mxnet.optimizer.RMSProp
 import org.apache.mxnet.Executor
 import org.apache.mxnetexamples.Util
@@ -70,9 +65,9 @@ object ExampleMultiTask {
         val batch = this.dataIter.next()
         val label = batch.label(0)
         new DataBatch(batch.data,
-                                     IndexedSeq(label, label),
-                                     batch.index,
-                                     batch.pad)
+          IndexedSeq(label, label),
+          batch.index,
+          batch.pad, null, null, null)
       } else {
         throw new NoSuchElementException
       }
@@ -107,6 +102,7 @@ object ExampleMultiTask {
     override def getIndex(): IndexedSeq[Long] = this.dataIter.getIndex()
 
     // The name and shape of label provided by this iterator
+    @deprecated
     override def provideLabel: ListMap[String, Shape] = {
       val provideLabel = this.dataIter.provideLabel.toArray
       // Different labels should be used here for actual application
@@ -114,6 +110,16 @@ object ExampleMultiTask {
               "softmax2_label" -> provideLabel(0)._2)
     }
 
+    // The name and shape of label provided by this iterator
+    override def provideLabelDesc: IndexedSeq[DataDesc] = {
+      val head = this.dataIter.provideLabelDesc(0)
+      // Different labels should be used here for actual application
+      IndexedSeq(
+        new DataDesc("softmax1_label", head.shape, head.dtype, head.layout),
+        new DataDesc("softmax2_label", head.shape, head.dtype, head.layout)
+      )
+    }
+
     /**
      * get the number of padding examples
      * in current batch
@@ -122,8 +128,11 @@ object ExampleMultiTask {
     override def getPad(): Int = this.dataIter.getPad()
 
     // The name and shape of data provided by this iterator
+    @deprecated
     override def provideData: ListMap[String, Shape] = 
this.dataIter.provideData
 
+    override def provideDataDesc: IndexedSeq[DataDesc] = 
this.dataIter.provideDataDesc
+
     override def hasNext: Boolean = this.dataIter.hasNext
   }
 
diff --git 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala
 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala
index f0eae6890c5..d4b17074d48 100644
--- 
a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala
+++ 
b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala
@@ -18,17 +18,16 @@
 
 package org.apache.mxnetexamples.rnn
 
-import org.apache.mxnet.{DataBatch, DataIter, NDArray, Shape}
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet._
 import org.slf4j.LoggerFactory
+
 import scala.collection.immutable.ListMap
 import scala.collection.mutable.ArrayBuffer
 import scala.io.Source
 import scala.util.Random
 import scala.collection.mutable
 
-/**
- * @author Depeng Liang
- */
 object BucketIo {
 
   type Text2Id = (String, Map[String, Int]) => Array[Int]
@@ -92,11 +91,14 @@ object BucketIo {
   }
 
   class BucketSentenceIter(
-      path: String, vocab: Map[String, Int], var buckets: IndexedSeq[Int],
-      _batchSize: Int, private val initStates: IndexedSeq[(String, (Int, 
Int))],
-      seperateChar: String = " <eos> ", text2Id: Text2Id = defaultText2Id,
+      path: String,
+      vocab: Map[String, Int],
+      var buckets: IndexedSeq[Int],
+      _batchSize: Int,
+      private val initStates: IndexedSeq[(String, (Int, Int))],
+      seperateChar: String = " <eos> ",
+      text2Id: Text2Id = defaultText2Id,
       readContent: ReadContent = defaultReadContent) extends DataIter {
-
     private val logger = LoggerFactory.getLogger(classOf[BucketSentenceIter])
 
     private val content = readContent(path)
@@ -165,8 +167,22 @@ object BucketIo {
     private val _provideData = { val tmp = ListMap("data" -> Shape(_batchSize, 
_defaultBucketKey))
       tmp ++ initStates.map(x => x._1 -> Shape(x._2._1, x._2._2))
     }
+
     private val _provideLabel = ListMap("softmax_label" -> Shape(_batchSize, 
_defaultBucketKey))
 
+    private val _provideDataDesc = {
+      // TODO: need to allow user to specify DType and Layout
+      val tmp = IndexedSeq(new DataDesc("data",
+        Shape(_batchSize, _defaultBucketKey), DType.Float32, Layout.UNDEFINED))
+      tmp ++ initStates.map(x => new DataDesc(x._1, Shape(x._2._1, x._2._2),
+        DType.Float32, Layout.UNDEFINED))
+    }
+
+    private val _provideLabelDesc = IndexedSeq(
+      // TODO: need to allow user to specify DType and Layout
+      new DataDesc("softmax_label",
+      Shape(_batchSize, _defaultBucketKey), DType.Float32, Layout.UNDEFINED))
+
     private var iBucket = 0
 
     override def next(): DataBatch = {
@@ -228,19 +244,27 @@ object BucketIo {
      */
     override def getIndex(): IndexedSeq[Long] = IndexedSeq[Long]()
 
-    // The name and shape of label provided by this iterator
-    override def provideLabel: ListMap[String, Shape] = this._provideLabel
-
     /**
-     * get the number of padding examples
-     * in current batch
-     * @return number of padding examples in current batch
-     */
+      * get the number of padding examples
+      * in current batch
+      * @return number of padding examples in current batch
+      */
     override def getPad(): Int = 0
 
+    // The name and shape of label provided by this iterator
+    @deprecated
+    override def provideLabel: ListMap[String, Shape] = this._provideLabel
+
     // The name and shape of data provided by this iterator
+    @deprecated
     override def provideData: ListMap[String, Shape] = this._provideData
 
+    // Provide type:DataDesc of the data
+    override def provideDataDesc: IndexedSeq[DataDesc] = _provideDataDesc
+
+    // Provide type:DataDesc of the label
+    override def provideLabelDesc: IndexedSeq[DataDesc] = _provideLabelDesc
+
     override def hasNext: Boolean = {
       iBucket < bucketPlan.length
     }
diff --git 
a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ObjectDetectorSuite.scala
 
b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ObjectDetectorSuite.scala
index 8160f0f6eb4..39139f8d3d2 100644
--- 
a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ObjectDetectorSuite.scala
+++ 
b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ObjectDetectorSuite.scala
@@ -19,6 +19,8 @@ package org.apache.mxnet.infer
 
 // scalastyle:off
 import java.awt.image.BufferedImage
+
+import org.apache.mxnet.{DType, Layout}
 // scalastyle:on
 import org.apache.mxnet.Context
 import org.apache.mxnet.DataDesc
@@ -69,7 +71,8 @@ class ObjectDetectorSuite extends ClassifierSuite with 
BeforeAndAfterAll {
   }
 
   test("objectDetectWithInputImage") {
-    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath, 
Shape(1, 3, 512, 512)))
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath, 
Shape(1, 3, 512, 512),
+      DType.Float32, Layout.NCHW))
     val inputImage = new BufferedImage(512, 512, BufferedImage.TYPE_INT_RGB)
     val testObjectDetector: ObjectDetector =
       new MyObjectDetector(modelPath, inputDescriptor)
@@ -109,7 +112,8 @@ class ObjectDetectorSuite extends ClassifierSuite with 
BeforeAndAfterAll {
   }
 
   test("objectDetectWithBatchImages") {
-    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath, 
Shape(1, 3, 512, 512)))
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath, 
Shape(1, 3, 512, 512),
+      DType.Float32, Layout.NCHW))
     val inputImage = new BufferedImage(224, 224, BufferedImage.TYPE_INT_RGB)
     val imageBatch = IndexedSeq[BufferedImage](inputImage, inputImage)
 
diff --git 
a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala
 
b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala
index 53fd7f31068..509ffb35db8 100644
--- 
a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala
+++ 
b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala
@@ -19,7 +19,7 @@ package org.apache.mxnet.infer
 
 import org.apache.mxnet.io.NDArrayIter
 import org.apache.mxnet.module.{BaseModule, Module}
-import org.apache.mxnet.{DataDesc, NDArray, Shape}
+import org.apache.mxnet.{DataDesc, Layout, NDArray, Shape}
 import org.mockito.Matchers._
 import org.mockito.Mockito
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
@@ -40,15 +40,17 @@ class PredictorSuite extends FunSuite with 
BeforeAndAfterAll {
   }
 
   test("PredictorSuite-testPredictorConstruction") {
-    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(1, 
3, 2, 2)))
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(1, 
3, 2, 2),
+      layout = Layout.NCHW))
 
     val mockPredictor = new MyPredictor("xyz", inputDescriptor)
 
     assert(mockPredictor.getBatchSize == 1)
     assert(mockPredictor.getBatchIndex == 
inputDescriptor(0).layout.indexOf('N'))
 
-    val inputDescriptor2 = IndexedSeq[DataDesc](new DataDesc("data", Shape(1, 
3, 2, 2)),
-      new DataDesc("data", Shape(2, 3, 2, 2)))
+    val inputDescriptor2 = IndexedSeq[DataDesc](new DataDesc("data", Shape(1, 
3, 2, 2),
+      layout = Layout.NCHW),
+      new DataDesc("data", Shape(2, 3, 2, 2), layout = Layout.NCHW))
 
     assertThrows[IllegalArgumentException] {
       val mockPredictor = new MyPredictor("xyz", inputDescriptor2)
@@ -63,7 +65,8 @@ class PredictorSuite extends FunSuite with BeforeAndAfterAll {
 
   test("PredictorSuite-testWithFlatArrays") {
 
-    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 
3, 2, 2)))
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 
3, 2, 2),
+      layout = Layout.NCHW))
     val inputData = Array.fill[Float](12)(1)
 
     // this will disposed at the end of the predict call on Predictor.
@@ -89,7 +92,8 @@ class PredictorSuite extends FunSuite with BeforeAndAfterAll {
   }
 
   test("PredictorSuite-testWithNDArray") {
-    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 
3, 2, 2)))
+    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 
3, 2, 2),
+      layout = Layout.NCHW))
     val inputData = NDArray.ones(Shape(1, 3, 2, 2))
 
     // this will disposed at the end of the predict call on Predictor.
diff --git a/scala-package/pom.xml b/scala-package/pom.xml
index 3511f4acfff..c221b4721d8 100644
--- a/scala-package/pom.xml
+++ b/scala-package/pom.xml
@@ -231,6 +231,7 @@
           <skipTests>${skipTests}</skipTests>
           
<reportsDirectory>${project.build.directory}/surefire-reports</reportsDirectory>
           <junitxml>.</junitxml>
+          <stdout>F</stdout>
           <filereports>WDF TestSuite.txt</filereports>
         </configuration>
         <executions>
diff --git 
a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala
 
b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala
index adc723ecdac..bf1b26e4b48 100644
--- 
a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala
+++ 
b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala
@@ -17,7 +17,8 @@
 
 package org.apache.mxnet.spark.io
 
-import org.apache.mxnet.{DataBatch, NDArray, Shape, DataIter}
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet._
 import org.apache.spark.mllib.regression.LabeledPoint
 
 import scala.collection.immutable.ListMap
@@ -25,7 +26,6 @@ import scala.collection.mutable.ArrayBuffer
 
 /**
  * A helper converter for LabeledPoint
- * @author Yizhi Liu
  */
 class LabeledPointIter private[mxnet](
   private val points: Iterator[LabeledPoint],
@@ -115,15 +115,27 @@ class LabeledPointIter private[mxnet](
   }
 
   // The name and shape of label provided by this iterator
+  @deprecated
   override def provideLabel: ListMap[String, Shape] = {
     ListMap(labelName -> Shape(_batchSize))
   }
 
   // The name and shape of data provided by this iterator
+  @deprecated
   override def provideData: ListMap[String, Shape] = {
     ListMap(dataName -> dataShape)
   }
 
+  override def provideDataDesc: IndexedSeq[DataDesc] = {
+    // TODO: need to allow user to specify DType and Layout
+    IndexedSeq(new DataDesc(dataName, dataShape, DType.Float32, 
Layout.UNDEFINED))
+  }
+
+  override def provideLabelDesc: IndexedSeq[DataDesc] = {
+    // TODO: need to allow user to specify DType and Layout
+    IndexedSeq(new DataDesc(labelName, Shape(_batchSize), DType.Float32, 
Layout.UNDEFINED))
+  }
+
   /**
    * Get the number of padding examples
    * in current batch
diff --git 
a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala
 
b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala
index 339f7e2d76c..e3272a4066b 100644
--- 
a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala
+++ 
b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala
@@ -17,7 +17,8 @@
 
 package org.apache.mxnet.spark.io
 
-import org.apache.mxnet.{NDArray, DataBatch}
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet.{DataBatch, NDArray}
 
 /**
  * Dispose only when 'disposeForce' called
@@ -27,7 +28,8 @@ class LongLivingDataBatch(
   override val data: IndexedSeq[NDArray],
   override val label: IndexedSeq[NDArray],
   override val index: IndexedSeq[Long],
-  override val pad: Int) extends DataBatch(data, label, index, pad) {
+  override val pad: Int) extends DataBatch(data, label, index, pad,
+  null, null, null) {
   override def dispose(): Unit = {}
   def disposeForce(): Unit = super.dispose()
 }
diff --git 
a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala 
b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala
index 21329291cfb..a955ee74e7e 100644
--- 
a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala
+++ 
b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala
@@ -17,7 +17,8 @@
 
 package org.apache.mxnet.spark.io
 
-import org.apache.mxnet.{NDArray, DataBatch, DataIter, Shape}
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet._
 import org.apache.spark.mllib.linalg.Vector
 
 import scala.collection.immutable.ListMap
@@ -25,7 +26,6 @@ import scala.collection.mutable.ArrayBuffer
 
 /**
  * A temporary helper implementation for predicting Vectors
- * @author Yizhi Liu
  */
 class PointIter private[mxnet](
   private val points: Iterator[Vector],
@@ -114,15 +114,27 @@ class PointIter private[mxnet](
   }
 
   // The name and shape of label provided by this iterator
+  @deprecated
   override def provideLabel: ListMap[String, Shape] = {
     ListMap(labelName -> Shape(_batchSize))
   }
 
   // The name and shape of data provided by this iterator
+  @deprecated
   override def provideData: ListMap[String, Shape] = {
     ListMap(dataName -> dataShape)
   }
 
+  override def provideDataDesc: IndexedSeq[DataDesc] = {
+    // TODO: Make DType, Layout configurable
+    IndexedSeq(new DataDesc(dataName, dataShape, DType.Float32, 
Layout.UNDEFINED))
+  }
+
+  override def provideLabelDesc: IndexedSeq[DataDesc] = {
+    IndexedSeq(new DataDesc(labelName, Shape(_batchSize),
+      DType.Float32, Layout.UNDEFINED))
+  }
+
   /**
    * Get the number of padding examples
    * in current batch


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to