lanking520 closed pull request #11844: [MXNET-689] add DataDesc type for the 
Scala Package
URL: https://github.com/apache/incubator-mxnet/pull/11844
 
 
   

This is a PR merged from a forked repository.
As GitHub hides the original diff on merge, it is displayed below for
the sake of provenance:

As this is a foreign pull request (from a fork), the diff is supplied
below (as it won't show otherwise due to GitHub magic):

diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala
index 4458a7c7aeb..b015bd2169b 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/DType.scala
@@ -35,4 +35,13 @@ object DType extends Enumeration {
       case DType.Unknown => 0
     }
   }
+  private[mxnet] def getType(dtypeStr: String): DType = {
+    dtypeStr match {
+      case "UInt8" => DType.UInt8
+      case "Int32" => DType.Int32
+      case "Float16" => DType.Float16
+      case "Float32" => DType.Float32
+      case "Float64" => DType.Float64
+    }
+  }
 }
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
index 47fd4eee939..928606db8b7 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
@@ -19,13 +19,13 @@ package org.apache.mxnet
 
 import org.apache.mxnet.Base._
 import org.apache.mxnet.DType.DType
+import org.apache.mxnet.Layout.Layout
 import org.apache.mxnet.io.{MXDataIter, MXDataPack}
 import org.slf4j.LoggerFactory
 
 import scala.annotation.varargs
 import scala.collection.immutable.ListMap
 import scala.collection.mutable.ListBuffer
-
 /**
  * IO iterators for loading training & validation data
  */
@@ -106,7 +106,15 @@ object IO {
     checkCall(_LIB.mxDataIterCreateIter(handle, keys, vals, out))
     val dataName = params.getOrElse("data_name", "data")
     val labelName = params.getOrElse("label_name", "label")
-    new MXDataIter(out.value, dataName, labelName)
+    val dataLayout = params.getOrElse("dataLayout", "NCHW")
+    val labelLayout = params.getOrElse("labelLayout", "N")
+    val dataDType = params.getOrElse("dataDType", "Float32")
+    val labelDType = params.getOrElse("labelDType", "Int32")
+    new MXDataIter(out.value, dataName, labelName,
+      dataLayout = Layout.getLayout(dataLayout),
+      labelLayout = Layout.getLayout(labelLayout),
+      dataDType = DType.getType(dataDType),
+      labelDType = DType.getType(labelDType))
   }
 
   // Convert data into canonical form.
@@ -140,7 +148,11 @@ class DataBatch(val data: IndexedSeq[NDArray],
                 // use ListMap to indicate the order of data/label loading
                 // (must match the order of input data/label)
                 private val providedData: ListMap[String, Shape] = null,
-                private val providedLabel: ListMap[String, Shape] = null) {
+                private val providedLabel: ListMap[String, Shape] = null,
+                val dataDType: DType = Base.MX_REAL_TYPE,
+                val labelDType: DType = DType.Int32,
+                val dataLayout: Layout = Layout.NCHW,
+                val labelLayout: Layout = Layout.N) {
   /**
    * Dispose its data and labels
    * The object shall never be used after it is disposed.
@@ -170,6 +182,10 @@ object DataBatch {
     private var label: IndexedSeq[NDArray] = null
     private var index: IndexedSeq[Long] = null
     private var pad: Int = 0
+    private var dataLayout: Layout = Layout.NCHW
+    private var labelLayout: Layout = Layout.N
+    private var dataDType: DType = Base.MX_REAL_TYPE
+    private var labelDType: DType = DType.Int32
     private var bucketKey: AnyRef = null
     private var datatShapes: ListMap[String, Shape] = null
     private var labelShapes: ListMap[String, Shape] = null
@@ -216,6 +232,30 @@ object DataBatch {
       this
     }
 
+    /**
+      * Set the dtype.
+      * @param dataDType The dtype of the data, default is Float32
+      * @param labelDType The dtype of the label, default is Int32
+      * @return this
+      */
+    def setDType(dataDType: DType, labelDType: DType): Builder = {
+      this.dataDType = dataDType
+      this.labelDType = labelDType
+      this
+    }
+
+    /**
+      * Set the layout.
+      * @param dataLayout The layout of the data, default is NCHW
+      * @param labelLayout The layout of the label, default is N
+      * @return this
+      */
+    def setLayout(dataLayout: Layout, labelLayout: Layout): Builder = {
+      this.dataLayout = dataLayout
+      this.labelLayout = labelLayout
+      this
+    }
+
     /**
      * Set the bucket key, used for bucketing module.
      * @param bucketKey the bucket key related to this batch.
@@ -258,7 +298,8 @@ object DataBatch {
 
     def build(): DataBatch = {
       require(data != null, "data is required.")
-      new DataBatch(data, label, index, pad, bucketKey, datatShapes, 
labelShapes)
+      new DataBatch(data, label, index, pad, bucketKey, datatShapes, 
labelShapes,
+        dataDType, labelDType, dataLayout, labelLayout)
     }
   }
 }
@@ -280,7 +321,9 @@ abstract class DataIter extends Iterator[DataBatch] {
    */
   @throws(classOf[NoSuchElementException])
   def next(): DataBatch = {
-    new DataBatch(getData(), getLabel(), getIndex(), getPad())
+    new DataBatch(getData(), getLabel(), getIndex(), getPad(),
+      dataDType = getDType()._1, labelDType = getDType()._2,
+      dataLayout = getLayout()._1, labelLayout = getLayout()._2)
   }
 
   /**
@@ -302,6 +345,18 @@ abstract class DataIter extends Iterator[DataBatch] {
    */
   def getPad(): Int
 
+  /**
+    * Get the DType
+    * @return data and label DType of the DataIter
+    */
+  def getDType(): (DType, DType)
+
+  /**
+    * Get the layout
+    * @return data and label layout of the DataIter
+    */
+  def getLayout(): (Layout, Layout)
+
   /**
    * Get the index of current batch
    * @return the index of current batch
@@ -309,11 +364,19 @@ abstract class DataIter extends Iterator[DataBatch] {
   def getIndex(): IndexedSeq[Long]
 
   // The name and shape of data provided by this iterator
+  @deprecated
   def provideData: ListMap[String, Shape]
 
   // The name and shape of label provided by this iterator
+  @deprecated
   def provideLabel: ListMap[String, Shape]
 
+  // Provide type:DataDesc of the data
+  def provideDataDesc: IndexedSeq[DataDesc]
+
+  // Provide type:DataDesc of the label
+  def provideLabelDesc: IndexedSeq[DataDesc]
+
   // For bucketing io only
   // The bucket key for the default symbol.
   def defaultBucketKey: AnyRef = null
@@ -332,10 +395,11 @@ abstract class DataPack() extends Iterable[DataBatch] {
 
 // Named data desc description contains name, shape, type and other extended 
attributes.
 case class DataDesc(name: String, shape: Shape,
-                    dtype: DType = Base.MX_REAL_TYPE, layout: String = "NCHW") 
{
-  require(shape.length == layout.length, ("number of dimensions in shape :%d 
with" +
+                    dtype: DType = Base.MX_REAL_TYPE, layout: Layout = 
Layout.NCHW) {
+  val layoutStr = layout.toString
+  require(shape.length == layoutStr.length, ("number of dimensions in shape 
:%d with" +
     " shape: %s should match the length of the layout: %d with layout: %s").
-    format(shape.length, shape.toString, layout.length, layout))
+    format(shape.length, shape.toString, layoutStr.length, layoutStr))
 
   override def toString(): String = {
     s"DataDesc[$name,$shape,$dtype,$layout]"
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala
new file mode 100644
index 00000000000..510cf88b59f
--- /dev/null
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/Layout.scala
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mxnet
+
+/**
+  * Layout type that represents what each dimension of a shape means
+  * N Batch size
+  * C number of channels
+  * H height (image)
+  * W width (image)
+  * T temporal axis representing time (NLP)
+  */
+
+object Layout extends Enumeration {
+  type Layout = Value
+  val NCHW = Value("NCHW")
+  val TNC = Value("TNC")
+  val CHW = Value("CHW")
+  val NT = Value("NT")
+  val N = Value("N")
+
+  private[mxnet] def getLayout(layoutStr: String): Layout = {
+    layoutStr match {
+      case "NCHW" => NCHW
+      case "TNC" => TNC
+      case "CHW" => CHW
+      case "NT" => NT
+      case "N" => N
+      case _ => throw new RuntimeException(
+        s"Unknown layout $layoutStr, please check Layout.scala")
+    }
+  }
+}
diff --git a/scala-package/core/src/main/scala/org/apache/mxnet/RecordIO.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/RecordIO.scala
index ee3e950512e..578f00a76f9 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/RecordIO.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/RecordIO.scala
@@ -28,9 +28,6 @@ import java.io.ByteArrayInputStream
 
 /**
  * Scala interface for read/write RecordIO data format
- *
- * @author Depeng Liang
- *
  * @param uri, path to recordIO file.
  * @param flag, RecordIO.IORead for reading or RecordIO.Write for writing.
  */
@@ -144,7 +141,7 @@ object MXRecordIO {
  *
  * @author Depeng Liang
  *
- * @param idx_path, path to index file
+ * @param idxPath, path to index file
  * @param uri, path to recordIO file.
  * @param flag, RecordIO.IORead for reading or RecordIO.Write for writing.
  * @param keyType, data type for keys.
diff --git 
a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
index 2a0c333ebf1..0bc61241aa6 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/MXDataIter.scala
@@ -18,8 +18,10 @@
 package org.apache.mxnet.io
 
 import org.apache.mxnet.Base._
-import org.apache.mxnet.{DataBatch, DataIter, DataPack, NDArray, Shape, 
WarnIfNotDisposed}
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet._
 import org.apache.mxnet.IO._
+import org.apache.mxnet.Layout.Layout
 import org.slf4j.LoggerFactory
 
 import scala.collection.immutable.ListMap
@@ -31,7 +33,11 @@ import scala.collection.mutable.ListBuffer
  */
 private[mxnet] class MXDataIter(private[mxnet] val handle: DataIterHandle,
                                 dataName: String = "data",
-                                labelName: String = "label")
+                                labelName: String = "label",
+                                dataLayout: Layout = Layout.NCHW,
+                                labelLayout: Layout = Layout.N,
+                                dataDType: DType = DType.Float32,
+                                labelDType: DType = DType.Int32)
   extends DataIter with WarnIfNotDisposed {
 
   private val logger = LoggerFactory.getLogger(classOf[MXDataIter])
@@ -41,21 +47,33 @@ private[mxnet] class MXDataIter(private[mxnet] val handle: 
DataIterHandle,
   // fix me if any better way found)
   private var currentBatch: DataBatch = null
 
-  private val (_provideData: ListMap[String, Shape],
-               _provideLabel: ListMap[String, Shape],
-               _batchSize: Int) =
+  private val (_provideDataDesc: IndexedSeq[DataDesc],
+               _provideLabelDesc: IndexedSeq[DataDesc],
+                _provideData: ListMap[String, Shape],
+                _provideLabel: ListMap[String, Shape],
+               _batchSize: Int) = {
     if (hasNext) {
       iterNext()
       val data = currentBatch.data(0)
       val label = currentBatch.label(0)
+      val dataDType = currentBatch.dataDType
+      val labelDType = currentBatch.labelDType
+      val dataLayout = currentBatch.dataLayout
+      val labelLayout = currentBatch.labelLayout
       // properties
-      val res = (ListMap(dataName -> data.shape), ListMap(labelName -> 
label.shape), data.shape(0))
+      val res = (IndexedSeq(new DataDesc(dataName, data.shape, dataDType, 
dataLayout)),
+        IndexedSeq(new DataDesc(labelName, label.shape, labelDType, 
labelLayout)),
+        ListMap(dataName -> data.shape),
+        ListMap(labelName -> label.shape),
+        data.shape(0))
       currentBatch.dispose()
       reset()
       res
     } else {
-      (null, null, 0)
+      (null, null, null, null, 0)
     }
+  }
+
 
   private var disposed = false
   protected def isDisposed = disposed
@@ -101,10 +119,13 @@ private[mxnet] class MXDataIter(private[mxnet] val 
handle: DataIterHandle,
   private def iterNext(): Boolean = {
     val next = new RefInt
     checkCall(_LIB.mxDataIterNext(handle, next))
-    currentBatch = null
     if (next.value > 0) {
       currentBatch = new DataBatch(data = getData(), label = getLabel(),
-        index = getIndex(), pad = getPad())
+        index = getIndex(), pad = getPad(),
+        dataDType = getDType()._1, labelDType = getDType()._2,
+        dataLayout = getLayout()._1, labelLayout = getLayout()._2)
+    } else {
+      currentBatch = null
     }
     next.value > 0
   }
@@ -151,12 +172,30 @@ private[mxnet] class MXDataIter(private[mxnet] val 
handle: DataIterHandle,
     out.value
   }
 
+  /**
+    * Get the DType
+    * @return DType
+    */
+  def getDType(): (DType, DType) = (dataDType, labelDType)
+
+  /**
+    * Get the layout
+    * @return layout
+    */
+  def getLayout(): (Layout, Layout) = (dataLayout, labelLayout)
+
   // The name and shape of data provided by this iterator
   override def provideData: ListMap[String, Shape] = _provideData
 
   // The name and shape of label provided by this iterator
   override def provideLabel: ListMap[String, Shape] = _provideLabel
 
+  // Provide type:DataDesc of the data
+  override def provideDataDesc: IndexedSeq[DataDesc] = _provideDataDesc
+
+  // Provide type:DataDesc of the label
+  override def provideLabelDesc: IndexedSeq[DataDesc] = _provideLabelDesc
+
   override def hasNext: Boolean = {
     if (currentBatch != null) {
       true
diff --git 
a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
index 10461315c19..e0ac7e6d526 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/NDArrayIter.scala
@@ -20,6 +20,8 @@ package org.apache.mxnet.io
 import java.util.NoSuchElementException
 
 import org.apache.mxnet.Base._
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet.Layout.Layout
 import org.apache.mxnet._
 import org.slf4j.LoggerFactory
 
@@ -42,29 +44,36 @@ import scala.collection.immutable.ListMap
 class NDArrayIter(data: IndexedSeq[(String, NDArray)],
                   label: IndexedSeq[(String, NDArray)],
                   private val dataBatchSize: Int, shuffle: Boolean,
-                  lastBatchHandle: String) extends DataIter {
-
+                  lastBatchHandle: String,
+                  dataDType: DType, labelDType: DType,
+                  dataLayout: Layout, labelLayout: Layout) extends DataIter {
+// scalastyle:off
   /**
-   * @param data Specify the data. Data names will be data_0, data_1, ..., etc.
-   * @param label Same as data, but is not fed to the model during testing.
-   *              Label names will be label_0, label_1, ..., etc.
-   * @param dataBatchSize Batch Size
-   * @param shuffle Whether to shuffle the data
-   * @param lastBatchHandle "pad", "discard" or "roll_over". How to handle the 
last batch
-   *
-   * This iterator will pad, discard or roll over the last batch if
-   * the size of data does not match batch_size. Roll over is intended
-   * for training and can cause problems if used for prediction.
-   */
+    * @param data Specify the data. Data names will be data_0, data_1, ..., 
etc.
+    * @param label Same as data, but is not fed to the model during testing.
+    *              Label names will be label_0, label_1, ..., etc.
+    * @param dataBatchSize Batch Size
+    * @param shuffle Whether to shuffle the data
+    * @param lastBatchHandle "pad", "discard" or "roll_over". How to handle 
the last batch
+    *
+    * This iterator will pad, discard or roll over the last batch if
+    * the size of data does not match batch_size. Roll over is intended
+    * for training and can cause problems if used for prediction.
+    */
   def this(data: IndexedSeq[NDArray], label: IndexedSeq[NDArray] = 
IndexedSeq.empty,
            dataBatchSize: Int = 1, shuffle: Boolean = false,
            lastBatchHandle: String = "pad",
-           dataName: String = "data", labelName: String = "label") {
+           dataName: String = "data", labelName: String = "label",
+           dataDType: DType = Base.MX_REAL_TYPE,
+           labelDType: DType = DType.Int32,
+           dataLayout: Layout = Layout.NCHW,
+           labelLayout : Layout = Layout.N) {
     this(IO.initData(data, allowEmpty = false, dataName),
       IO.initData(label, allowEmpty = true, labelName),
-      dataBatchSize, shuffle, lastBatchHandle)
+      dataBatchSize, shuffle, lastBatchHandle, dataDType, labelDType,
+      dataLayout, labelLayout)
   }
-
+// scalastyle:on
   private val logger = LoggerFactory.getLogger(classOf[NDArrayIter])
 
   val (initData: IndexedSeq[(String, NDArray)], initLabel: IndexedSeq[(String, 
NDArray)]) = {
@@ -107,6 +116,14 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)],
     (pData, pLabel)
   }
 
+  private val (_provideDataDesc: IndexedSeq[DataDesc],
+  _provideLabelDesc: IndexedSeq[DataDesc]) = {
+    val pData = initData.map(ele => new DataDesc(ele._1, getShape(ele)._2, 
dataDType, dataLayout))
+    val pLabel = initLabel.map(ele =>
+      new DataDesc(ele._1, getShape(ele)._2, labelDType, labelLayout))
+    (pData, pLabel)
+  }
+
   /**
    * get shape via dataBatchSize
    * @param dataItem
@@ -148,7 +165,9 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)],
   override def next(): DataBatch = {
     if (hasNext) {
       cursor += dataBatchSize
-      new DataBatch(getData(), getLabel(), getIndex(), getPad())
+      new DataBatch(getData(), getLabel(), getIndex(), getPad(),
+        dataDType = getDType()._1, labelDType = getDType()._2,
+        dataLayout = getLayout()._1, labelLayout = getLayout()._2)
     } else {
       throw new NoSuchElementException
     }
@@ -223,12 +242,34 @@ class NDArrayIter(data: IndexedSeq[(String, NDArray)],
     }
   }
 
+  /**
+    * Get the DType
+    * @return DType
+    */
+  def getDType(): (DType, DType) = {
+    (dataDType, labelDType)
+  }
+
+  /**
+    * Get the layout
+    * @return layout
+    */
+  def getLayout(): (Layout, Layout) = {
+    (dataLayout, labelLayout)
+  }
+
   // The name and shape of data provided by this iterator
   override def provideData: ListMap[String, Shape] = _provideData
 
   // The name and shape of label provided by this iterator
   override def provideLabel: ListMap[String, Shape] = _provideLabel
 
+  // Provide type:DataDesc of the data
+  override def provideDataDesc: IndexedSeq[DataDesc] = _provideDataDesc
+
+  // Provide type:DataDesc of the label
+  override def provideLabelDesc: IndexedSeq[DataDesc] = _provideLabelDesc
+
   override def batchSize: Int = dataBatchSize
 }
 
@@ -242,6 +283,10 @@ object NDArrayIter {
     private var label: IndexedSeq[(String, NDArray)] = IndexedSeq.empty
     private var dataBatchSize: Int = 1
     private var lastBatchHandle: String = "pad"
+    private var dataLayout: Layout = Layout.NCHW
+    private var labelLayout: Layout = Layout.N
+    private var dataDType: DType = Base.MX_REAL_TYPE
+    private var labelDType: DType = DType.Int32
 
     /**
      * Add one data input with its name.
@@ -285,12 +330,37 @@ object NDArrayIter {
       this
     }
 
+    /**
+      * Set the dtype.
+      * @param dataDType The dtype of the data, default is Float32
+      * @param labelDType The dtype of the label, default is Int32
+      * @return this
+      */
+    def setDType(dataDType: DType, labelDType: DType): Builder = {
+      this.dataDType = dataDType
+      this.labelDType = labelDType
+      this
+    }
+
+    /**
+      * Set the layout.
+      * @param dataLayout The layout of the data, default is NCHW
+      * @param labelLayout The layout of the label, default is N
+      * @return this
+      */
+    def setLayout(dataLayout: Layout, labelLayout: Layout): Builder = {
+      this.dataLayout = dataLayout
+      this.labelLayout = labelLayout
+      this
+    }
+
     /**
      * Build the NDArrayIter object.
      * @return the built object.
      */
     def build(): NDArrayIter = {
-      new NDArrayIter(data, label, dataBatchSize, false, lastBatchHandle)
+      new NDArrayIter(data, label, dataBatchSize, false, lastBatchHandle,
+        dataDType, labelDType, dataLayout, labelLayout)
     }
   }
 }
diff --git 
a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala
index c0c0d1793b5..bcfb1d04327 100644
--- 
a/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala
+++ 
b/scala-package/core/src/main/scala/org/apache/mxnet/io/PrefetchingIter.scala
@@ -17,10 +17,13 @@
 
 package org.apache.mxnet.io
 
-import org.apache.mxnet.{DataBatch, DataIter, NDArray, Shape}
+import org.apache.mxnet._
 import org.slf4j.LoggerFactory
 import java.util.concurrent.Semaphore
 
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet.Layout.Layout
+
 import scala.collection.immutable.ListMap
 
 /**
@@ -68,6 +71,42 @@ class PrefetchingIter(
     }
   }
 
+  private val _provideDataDesc: IndexedSeq[DataDesc] = {
+    if (dataNames == null) {
+      iters.map(_.provideDataDesc).foldLeft(IndexedSeq[DataDesc]()) { (acc, 
elem) =>
+        acc ++ elem
+      }
+    } else {
+      iters.zipWithIndex.map(tu => (tu._1.provideDataDesc, tu._2))
+        .map(m =>
+          m._1.map(t =>
+            new DataDesc(dataNames(m._2)(t.name), t.shape, t.dtype, t.layout)
+          )
+        )
+        .foldLeft(IndexedSeq[DataDesc]()) { (acc, elem) =>
+          acc ++ elem
+        }
+    }
+  }
+
+  private val _provideLabelDesc: IndexedSeq[DataDesc] = {
+    if (dataNames == null) {
+      iters.map(_.provideLabelDesc).foldLeft(IndexedSeq[DataDesc]()) { (acc, 
elem) =>
+        acc ++ elem
+      }
+    } else {
+      iters.zipWithIndex.map(tu => (tu._1.provideDataDesc, tu._2))
+        .map(m =>
+          m._1.map(t =>
+            new DataDesc(dataNames(m._2)(t.name), t.shape, t.dtype, t.layout)
+          )
+        )
+        .foldLeft(IndexedSeq[DataDesc]()) { (acc, elem) =>
+          acc ++ elem
+        }
+    }
+  }
+
   private val _batchSize: Int = this._provideData.toList(0)._2(0)
   private val dataReady: IndexedSeq[Semaphore] =
                                         (0 until iters.length).map(i => new 
Semaphore(0))
@@ -132,19 +171,41 @@ class PrefetchingIter(
    */
   override def getIndex(): IndexedSeq[Long] = currentBatch.index
 
-  // The name and shape of label provided by this iterator
-  override def provideLabel: ListMap[String, Shape] = this._provideLabel
-
   /**
-   * get the number of padding examples
-   * in current batch
-   * @return number of padding examples in current batch
-   */
+    * get the number of padding examples
+    * in current batch
+    * @return number of padding examples in current batch
+    */
   override def getPad(): Int = this.currentBatch.pad
 
+  /**
+    * Get the DType
+    * @return DType
+    */
+  def getDType(): (DType, DType) = {
+    (currentBatch.dataDType, currentBatch.labelDType)
+  }
+
+  /**
+    * Get the layout
+    * @return layout
+    */
+  def getLayout(): (Layout, Layout) = {
+    (currentBatch.dataLayout, currentBatch.labelLayout)
+  }
+
+  // The name and shape of label provided by this iterator
+  override def provideLabel: ListMap[String, Shape] = this._provideLabel
+
   // The name and shape of data provided by this iterator
   override def provideData: ListMap[String, Shape] = this._provideData
 
+  // Provide type:DataDesc of the data
+  override def provideDataDesc: IndexedSeq[DataDesc] = _provideDataDesc
+
+  // Provide type:DataDesc of the label
+  override def provideLabelDesc: IndexedSeq[DataDesc] = _provideLabelDesc
+
   override def hasNext: Boolean = {
     for (e <- dataReady) e.acquire()
     if (nextBatch(0) == null) {
@@ -161,9 +222,13 @@ class PrefetchingIter(
       val datas = for (batch <- nextBatch) yield batch.data
       val labels = for (batch <- nextBatch) yield batch.label
       currentBatch = new DataBatch(datas.toIndexedSeq.flatten,
-                                      labels.toIndexedSeq.flatten,
-                                      nextBatch(0).index,
-                                      nextBatch(0).pad)
+        labels.toIndexedSeq.flatten,
+        nextBatch(0).index,
+        nextBatch(0).pad,
+        dataLayout = nextBatch(0).dataLayout,
+        labelLayout = nextBatch(0).labelLayout,
+        dataDType = nextBatch(0).dataDType,
+        labelDType = nextBatch(0).labelDType)
       for (e <- dataTaken) e.release()
       true
     }
diff --git 
a/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala 
b/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala
index 75d88d1ae72..5de42290154 100644
--- a/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala
+++ b/scala-package/core/src/main/scala/org/apache/mxnet/io/ResizeIter.scala
@@ -19,7 +19,9 @@ package org.apache.mxnet.io
 
 import java.util.NoSuchElementException
 
-import org.apache.mxnet.{DataBatch, DataIter, NDArray, Shape}
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet.Layout.Layout
+import org.apache.mxnet._
 import org.slf4j.LoggerFactory
 
 import scala.collection.immutable.ListMap
@@ -128,6 +130,22 @@ class ResizeIter(
     currentBatch.pad
   }
 
+  /**
+    * Get the DType
+    * @return DType
+    */
+  def getDType(): (DType, DType) = {
+    (currentBatch.dataDType, currentBatch.labelDType)
+  }
+
+  /**
+    * Get the layout
+    * @return layout
+    */
+  def getLayout(): (Layout, Layout) = {
+    (currentBatch.dataLayout, currentBatch.labelLayout)
+  }
+
   override def batchSize: Int = {
     dataIter.batchSize
   }
@@ -141,4 +159,14 @@ class ResizeIter(
   override def provideLabel: ListMap[String, Shape] = {
     dataIter.provideLabel
   }
+
+  // Provide type:DataDesc of the data
+  override def provideDataDesc: IndexedSeq[DataDesc] = {
+    dataIter.provideDataDesc
+  }
+
+  // Provide type:DataDesc of the label
+  override def provideLabelDesc: IndexedSeq[DataDesc] = {
+    dataIter.provideLabelDesc
+  }
 }
diff --git 
a/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala
 
b/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala
index 1494dc84035..bb497e47d13 100644
--- 
a/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala
+++ 
b/scala-package/core/src/main/scala/org/apache/mxnet/module/DataParallelExecutorGroup.scala
@@ -330,7 +330,7 @@ class DataParallelExecutorGroup private[module](
    */
   private def decideSlices(dataShapes: Seq[DataDesc]): Seq[Int] = {
     require(dataShapes.size > 0)
-    val majorAxis = dataShapes.map(data => 
DataDesc.getBatchAxis(Option(data.layout)))
+    val majorAxis = dataShapes.map(data => 
DataDesc.getBatchAxis(Option(data.layoutStr)))
 
     for ((dataDesc, axis) <- dataShapes.zip(majorAxis)) {
       if (axis != -1) {
diff --git a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala 
b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
index 1b922b3c05b..919c94b4b81 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/IOSuite.scala
@@ -38,7 +38,9 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
       "shuffle" -> "1",
       "flat" -> "1",
       "silent" -> "0",
-      "seed" -> "10"
+      "seed" -> "10",
+      "dataLayout" -> "NT",
+      "labelLayout" -> "N"
     )
 
     val mnistPack = IO.MNISTPack(params)
@@ -99,7 +101,9 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
       "data_shape" -> "(3,28,28)",
       "batch_size" -> "100",
       "preprocess_threads" -> "4",
-      "prefetch_buffer" -> "1"
+      "prefetch_buffer" -> "1",
+      "dataLayout" -> "NCHW",
+      "labelLayout" -> "N"
     )
     val imgRecIter = IO.ImageRecordIter(params)
     val nBatch = 500
@@ -145,7 +149,9 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
       "shuffle" -> "1",
       "flat" -> "1",
       "silent" -> "0",
-      "seed" -> "10"
+      "seed" -> "10",
+      "dataLayout" -> "NT",
+      "labelLayout" -> "N"
     )
 
     val mnistIter = IO.MNISTIter(params)
@@ -182,7 +188,9 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
       "shuffle" -> "1",
       "flat" -> "1",
       "silent" -> "0",
-      "seed" -> "10"
+      "seed" -> "10",
+      "dataLayout" -> "NT",
+      "labelLayout" -> "N"
     )
 
     val mnistPack1 = IO.MNISTPack(params)
@@ -243,7 +251,8 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
     val batchLabel = NDArray.ones(Shape(Array(128, 1)))
 
     // test pad
-    val dataIter0 = new NDArrayIter(data, label, 128, false, "pad")
+    val dataIter0 = new NDArrayIter(data, label, 128, false, "pad",
+      dataLayout = Layout.TNC, labelLayout = Layout.NT)
     var batchCount = 0
     val nBatch0 = 8
     while(dataIter0.hasNext) {
@@ -262,6 +271,7 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
       .addData("data0", data(0)).addData("data1", data(1))
       .addLabel("label", label(0))
       .setBatchSize(128)
+      .setLayout(Layout.TNC, Layout.NT)
       .setLastBatchHandle("discard").build()
     val nBatch1 = 7
     batchCount = 0
@@ -277,7 +287,8 @@ class IOSuite extends FunSuite with BeforeAndAfterAll {
     assert(batchCount === nBatch1)
 
     // test empty label (for prediction)
-    val dataIter2 = new NDArrayIter(data = data, dataBatchSize = 128, 
lastBatchHandle = "discard")
+    val dataIter2 = new NDArrayIter(data = data, dataBatchSize = 128, 
lastBatchHandle = "discard",
+      dataLayout = Layout.TNC)
     batchCount = 0
     while(dataIter2.hasNext) {
       val tBatch = dataIter2.next()
diff --git 
a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala 
b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
index 8234568d7d9..10c547c5a9b 100644
--- a/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
+++ b/scala-package/core/src/test/scala/org/apache/mxnet/ModuleSuite.scala
@@ -33,7 +33,7 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
 
     val mod = new Module(sym, IndexedSeq("data"), null,
       contexts = Array(Context.cpu(0), Context.cpu(1)))
-    mod.bind(dataShapes = IndexedSeq(DataDesc("data", dShape, dType, "TNC")))
+    mod.bind(dataShapes = IndexedSeq(DataDesc("data", dShape, dType, Layout.TNC)))
     mod.initParams()
     mod.forward(new DataBatch(
       data = IndexedSeq(NDArray.ones(dShape, dtype = dType)),
@@ -57,9 +57,9 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
       .setContext(Context.cpu(0), Context.cpu(1))
       .build()
     mod.bind(dataShapes = IndexedSeq(
-      DataDesc("b", Shape(5, 5), layout = "NT"),
-      DataDesc("c", Shape(5, 5), layout = "NT"),
-      DataDesc("a", Shape(5, 5), layout = "NT")),
+      DataDesc("b", Shape(5, 5), layout = Layout.NT),
+      DataDesc("c", Shape(5, 5), layout = Layout.NT),
+      DataDesc("a", Shape(5, 5), layout = Layout.NT)),
       inputsNeedGrad = true
     )
     mod.initParams()
@@ -87,7 +87,7 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
     val dShape = Shape(3, 8, 7)
     val mod = new Module(sym, IndexedSeq("data"), null,
       contexts = Array(Context.cpu(0), Context.cpu(1)))
-    mod.bind(dataShapes = IndexedSeq(DataDesc("data", dShape, layout = "TNC")))
+    mod.bind(dataShapes = IndexedSeq(DataDesc("data", dShape, layout = Layout.TNC)))
     mod.initParams()
     mod.forward(new DataBatch(
       data = IndexedSeq(NDArray.ones(dShape)),
@@ -110,14 +110,14 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
 
     // single device
     var mod = new Module(sym, IndexedSeq("data"), null)
-    mod.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = "NT")))
+    mod.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = Layout.NT)))
     mod.initParams()
     mod.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f))
     mod.update()
     mod.saveCheckpoint("test", 0, saveOptStates = true)
 
     var mod2 = Module.loadCheckpoint("test", 0, loadOptimizerStates = true)
-    mod2.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = "NT")))
+    mod2.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = Layout.NT)))
     mod2.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f))
     assert(mod.getSymbol.toJson == mod2.getSymbol.toJson)
     mapEqu(mod.getParams._1, mod2.getParams._1)
@@ -125,14 +125,14 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
     // multi device
     mod = new Module(sym, IndexedSeq("data"), null,
       contexts = Array(Context.cpu(0), Context.cpu(1)))
-    mod.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = "NT" )))
+    mod.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = Layout.NT)))
     mod.initParams()
     mod.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f))
     mod.update()
     mod.saveCheckpoint("test", 0, saveOptStates = true)
 
     mod2 = Module.loadCheckpoint("test", 0, loadOptimizerStates = true)
-    mod2.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = "NT")))
+    mod2.bind(dataShapes = IndexedSeq(DataDesc("data", Shape(10, 10), layout = Layout.NT)))
     mod2.initOptimizer(optimizer = new SGD(learningRate = 0.1f, momentum = 0.9f))
     assert(mod.getSymbol.toJson == mod2.getSymbol.toJson)
     mapEqu(mod.getParams._1, mod2.getParams._1)
@@ -145,7 +145,7 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
     var dShape = Shape(7, 20)
     val mod = new Module(sym, IndexedSeq("data"), null,
       contexts = Array(Context.cpu(0), Context.cpu(1)))
-    mod.bind(dataShapes = IndexedSeq(DataDesc("data", dShape, layout = "NT")))
+    mod.bind(dataShapes = IndexedSeq(DataDesc("data", dShape, layout = Layout.NT)))
     mod.initParams()
     mod.initOptimizer(optimizer = new SGD(learningRate = 1f))
 
@@ -159,7 +159,7 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
 
     // reshape module
     dShape = Shape(14, 20)
-    mod.reshape(IndexedSeq(DataDesc("data", dShape, layout = "NT")))
+    mod.reshape(IndexedSeq(DataDesc("data", dShape, layout = Layout.NT)))
     mod.forward(new DataBatch(
       data = IndexedSeq(NDArray.ones(dShape)),
       label = null, index = null, pad = 0))
@@ -170,7 +170,7 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
 
     // return to original binded shape
     dShape = Shape(7, 20)
-    mod.reshape(IndexedSeq(DataDesc("data", dShape, layout = "NT")))
+    mod.reshape(IndexedSeq(DataDesc("data", dShape, layout = Layout.NT)))
     mod.forward(new DataBatch(
       data = IndexedSeq(NDArray.ones(dShape)),
       label = null, index = null, pad = 0))
@@ -184,7 +184,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
     val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 1, 1, 2))
     val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 1, 1, 2))
     val trainData = new NDArrayIter(
-      IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label")
+      IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label",
+      dataLayout = Layout.NCHW, labelLayout = Layout.NCHW)
 
     // symbols
     var x = Symbol.Variable("data")
@@ -234,7 +235,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
     val data = NDArray.array(Array(0.05f, 0.1f), Shape(1, 1, 1, 2))
     val label = NDArray.array(Array(0.01f, 0.99f), Shape(1, 1, 1, 2))
     val trainData = new NDArrayIter(
-      IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label")
+      IndexedSeq(data), IndexedSeq(label), labelName = "softmax_label",
+      dataLayout = Layout.NCHW, labelLayout = Layout.NCHW)
 
     // symbols
     var x = Symbol.Variable("data")
@@ -309,8 +311,8 @@ class ModuleSuite extends FunSuite with BeforeAndAfterAll {
 
     val mod = new Module(sym, IndexedSeq("data1", "data2"))
     mod.bind(dataShapes = IndexedSeq(
-      DataDesc("data1", dShape1), DataDesc("data2", dShape2, layout = "NCHW")),
-      labelShapes = Option(IndexedSeq(DataDesc("softmax_label", lShape, layout = "N")))
+      DataDesc("data1", dShape1), DataDesc("data2", dShape2, layout = Layout.NCHW)),
+      labelShapes = Option(IndexedSeq(DataDesc("softmax_label", lShape, layout = Layout.N)))
     )
     mod.initParams()
     mod.initOptimizer(optimizer = new SGD(learningRate = 0.01f))
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/Data.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/Data.scala
index d61269c131f..230c56e3867 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/Data.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/customop/Data.scala
@@ -20,6 +20,7 @@ package org.apache.mxnetexamples.customop
 import org.apache.mxnet.{DataIter, IO, Shape}
 
 object Data {
+
   // return train and val iterators for mnist
   def mnistIterator(dataPath: String, batchSize: Int, inputShape: Shape): (DataIter, DataIter) = {
     val flat = if (inputShape.length == 3) "False" else "True"
@@ -29,7 +30,9 @@ object Data {
       "input_shape" -> inputShape.toString(),
       "batch_size" -> s"$batchSize",
       "shuffle" -> "True",
-      "flat" -> flat
+      "flat" -> flat,
+      "dataLayout" -> "NT",
+      "labelLayout" -> "N"
     )
     val trainDataIter = IO.MNISTIter(trainParams)
     val testParams = Map(
@@ -37,7 +40,9 @@ object Data {
       "label" -> s"$dataPath/t10k-labels-idx1-ubyte",
       "input_shape" -> inputShape.toString(),
       "batch_size" -> s"$batchSize",
-      "flat" -> flat
+      "flat" -> flat,
+      "dataLayout" -> "NT",
+      "labelLayout" -> "N"
     )
     val testDataIter = IO.MNISTIter(testParams)
     (trainDataIter, testDataIter)
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala
index 6186989b74f..f2efcfbd799 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/gan/GanMnist.scala
@@ -130,7 +130,9 @@ object GanMnist {
       "label" -> s"$dataPath/train-labels-idx1-ubyte",
       "input_shape" -> s"(1, 28, 28)",
       "batch_size" -> s"$batchSize",
-      "shuffle" -> "True"
+      "shuffle" -> "True",
+      "dataLayout" -> "NCHW",
+      "labelLayout" -> "N"
     )
 
     val mnistIter = IO.MNISTIter(params)
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala
index b0ecc7d29cc..d8b79dc2b73 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/imclassification/TrainMnist.scala
@@ -77,7 +77,9 @@ object TrainMnist {
       "shuffle" -> "True",
       "flat" -> flat,
       "num_parts" -> kv.numWorkers.toString,
-      "part_index" -> kv.`rank`.toString))
+      "part_index" -> kv.`rank`.toString,
+      "dataLayout" -> "NT",
+      "labelLayout" -> "N"))
 
     val eval = IO.MNISTIter(Map(
       "image" -> (dataDir + "t10k-images-idx3-ubyte"),
@@ -87,7 +89,9 @@ object TrainMnist {
       "batch_size" -> batchSize.toString,
       "flat" -> flat,
       "num_parts" -> kv.numWorkers.toString,
-      "part_index" -> kv.`rank`.toString))
+      "part_index" -> kv.`rank`.toString,
+      "dataLayout" -> "NT",
+      "labelLayout" -> "N"))
 
     (train, eval)
   }
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala
index e886b908ba2..e234e090267 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/imageclassifier/ImageClassifierExample.scala
@@ -17,10 +17,9 @@
 
 package org.apache.mxnetexamples.infer.imageclassifier
 
-import org.apache.mxnet.Shape
+import org.apache.mxnet._
 import org.kohsuke.args4j.{CmdLineParser, Option}
 import org.slf4j.LoggerFactory
-import org.apache.mxnet.{DType, DataDesc, Context}
 import org.apache.mxnet.infer.ImageClassifier
 
 import scala.collection.JavaConverters._
@@ -46,7 +45,7 @@ object ImageClassifierExample {
     val dType = DType.Float32
     val inputShape = Shape(1, 3, 224, 224)
 
-    val inputDescriptor = IndexedSeq(DataDesc("data", inputShape, dType, "NCHW"))
+    val inputDescriptor = IndexedSeq(DataDesc("data", inputShape, dType, Layout.NCHW))
 
     // Create object of ImageClassifier class
     val imgClassifier: ImageClassifier = new
@@ -67,7 +66,7 @@ object ImageClassifierExample {
     val dType = DType.Float32
     val inputShape = Shape(1, 3, 224, 224)
 
-    val inputDescriptor = IndexedSeq(DataDesc("data", inputShape, dType, "NCHW"))
+    val inputDescriptor = IndexedSeq(DataDesc("data", inputShape, dType, Layout.NCHW))
 
     // Create object of ImageClassifier class
     val imgClassifier: ImageClassifier = new
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala
index c9707cb3ff6..2b8e49b8042 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/infer/objectdetector/SSDClassifierExample.scala
@@ -19,7 +19,7 @@ package org.apache.mxnetexamples.infer.objectdetector
 
 import java.io.File
 
-import org.apache.mxnet.{Context, DType, DataDesc, Shape}
+import org.apache.mxnet._
 import org.apache.mxnet.infer._
 import org.kohsuke.args4j.{CmdLineParser, Option}
 import org.slf4j.LoggerFactory
@@ -58,7 +58,7 @@ object SSDClassifierExample {
     val inputShape = Shape(1, 3, 512, 512)
     // ssd detections, numpy.array([[id, score, x1, y1, x2, y2]...])
     val outputShape = Shape(1, 6132, 6)
-    val inputDescriptors = IndexedSeq(DataDesc("data", inputShape, dType, "NCHW"))
+    val inputDescriptors = IndexedSeq(DataDesc("data", inputShape, dType, Layout.NCHW))
     val img = ImageClassifier.loadImageFromFile(inputImagePath)
     val objDetector = new ObjectDetector(modelPathPrefix, inputDescriptors, context)
     val output = objDetector.imageObjectDetect(img, Some(3))
@@ -73,7 +73,7 @@ object SSDClassifierExample {
     val inputShape = Shape(1, 3, 512, 512)
     // ssd detections, numpy.array([[id, score, x1, y1, x2, y2]...])
     val outputShape = Shape(1, 6132, 6)
-    val inputDescriptors = IndexedSeq(DataDesc("data", inputShape, dType, "NCHW"))
+    val inputDescriptors = IndexedSeq(DataDesc("data", inputShape, dType, Layout.NCHW))
     val objDetector = new ObjectDetector(modelPathPrefix, inputDescriptors, context)
     // Loading batch of images from the directory path
     val batchFiles = generateBatches(inputImageDir, 20)
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/Data.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/Data.scala
index bb17046b8b2..2b0a20b40e7 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/Data.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/Data.scala
@@ -21,9 +21,6 @@ import org.apache.mxnet.Shape
 import org.apache.mxnet.IO
 import org.apache.mxnet.DataIter
 
-/**
- * @author Depeng Liang
- */
 object Data {
 
   // return train and val iterators for mnist
@@ -35,7 +32,9 @@ object Data {
       "input_shape" -> inputShape.toString(),
       "batch_size" -> s"$batchSize",
       "shuffle" -> "True",
-      "flat" -> flat
+      "flat" -> flat,
+      "dataLayout" -> "NT",
+      "labelLayout" -> "N"
     )
     val trainDataIter = IO.MNISTIter(trainParams)
     val testParams = Map(
@@ -43,7 +42,9 @@ object Data {
       "label" -> s"$dataPath/t10k-labels-idx1-ubyte",
       "input_shape" -> inputShape.toString(),
       "batch_size" -> s"$batchSize",
-      "flat" -> flat
+      "flat" -> flat,
+      "dataLayout" -> "NT",
+      "labelLayout" -> "N"
     )
     val testDataIter = IO.MNISTIter(testParams)
     (trainDataIter, testDataIter)
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala
index 1270af3c45b..323dd5a95d4 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/multitask/ExampleMultiTask.scala
@@ -24,26 +24,18 @@ import org.kohsuke.args4j.{CmdLineParser, Option}
 import org.slf4j.LoggerFactory
 
 import scala.collection.JavaConverters._
-
 import org.apache.commons.io.FileUtils
-
-import org.apache.mxnet.Symbol
-import org.apache.mxnet.DataIter
-import org.apache.mxnet.DataBatch
-import org.apache.mxnet.NDArray
-import org.apache.mxnet.Shape
-import org.apache.mxnet.EvalMetric
-import org.apache.mxnet.Context
-import org.apache.mxnet.Xavier
+import org.apache.mxnet.{Context, DataBatch, DataDesc, DataIter, EvalMetric, NDArray, Shape, Symbol, Xavier}
+import org.apache.mxnet.DType.DType
 import org.apache.mxnet.optimizer.RMSProp
 import org.apache.mxnet.Executor
+import org.apache.mxnet.Layout.Layout
 
 import scala.collection.immutable.ListMap
 import scala.sys.process.Process
 
 /**
  * Example of multi-task
- * @author Depeng Liang
  */
 object ExampleMultiTask {
   private val logger = LoggerFactory.getLogger(classOf[ExampleMultiTask])
@@ -72,9 +64,10 @@ object ExampleMultiTask {
         val batch = this.dataIter.next()
         val label = batch.label(0)
         new DataBatch(batch.data,
-                                     IndexedSeq(label, label),
-                                     batch.index,
-                                     batch.pad)
+          IndexedSeq(label, label),
+          batch.index,
+          batch.pad, dataDType = batch.dataDType, labelDType = batch.labelDType,
+          dataLayout = batch.dataLayout, labelLayout = batch.labelLayout)
       } else {
         throw new NoSuchElementException
       }
@@ -116,6 +109,16 @@ object ExampleMultiTask {
               "softmax2_label" -> provideLabel(0)._2)
     }
 
+    // The name and shape of label provided by this iterator
+    override def provideLabelDesc: IndexedSeq[DataDesc] = {
+      val head = this.dataIter.provideLabelDesc(0)
+      // Different labels should be used here for actual application
+      IndexedSeq(
+        new DataDesc("softmax1_label", head.shape, head.dtype, head.layout),
+        new DataDesc("softmax2_label", head.shape, head.dtype, head.layout)
+      )
+    }
+
     /**
      * get the number of padding examples
      * in current batch
@@ -123,9 +126,15 @@ object ExampleMultiTask {
      */
     override def getPad(): Int = this.dataIter.getPad()
 
+    override def getDType(): (DType, DType) = this.dataIter.getDType()
+
+    override def getLayout(): (Layout, Layout) = this.dataIter.getLayout()
+
     // The name and shape of data provided by this iterator
     override def provideData: ListMap[String, Shape] = this.dataIter.provideData
 
+    override def provideDataDesc: IndexedSeq[DataDesc] = this.dataIter.provideDataDesc
+
     override def hasNext: Boolean = this.dataIter.hasNext
   }
 
diff --git a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala
index f0eae6890c5..2bf9654802e 100644
--- a/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala
+++ b/scala-package/examples/src/main/scala/org/apache/mxnetexamples/rnn/BucketIo.scala
@@ -18,8 +18,11 @@
 
 package org.apache.mxnetexamples.rnn
 
-import org.apache.mxnet.{DataBatch, DataIter, NDArray, Shape}
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet.Layout.Layout
+import org.apache.mxnet._
 import org.slf4j.LoggerFactory
+
 import scala.collection.immutable.ListMap
 import scala.collection.mutable.ArrayBuffer
 import scala.io.Source
@@ -95,7 +98,11 @@ object BucketIo {
       path: String, vocab: Map[String, Int], var buckets: IndexedSeq[Int],
       _batchSize: Int, private val initStates: IndexedSeq[(String, (Int, Int))],
       seperateChar: String = " <eos> ", text2Id: Text2Id = defaultText2Id,
-      readContent: ReadContent = defaultReadContent) extends DataIter {
+      readContent: ReadContent = defaultReadContent,
+      dataLayout: Layout = Layout.NT,
+      labelLayout: Layout = Layout.N,
+      dataDType : DType = DType.Float32,
+      labelDType: DType = DType.Int32) extends DataIter {
 
     private val logger = LoggerFactory.getLogger(classOf[BucketSentenceIter])
 
@@ -165,8 +172,18 @@ object BucketIo {
     private val _provideData = { val tmp = ListMap("data" -> Shape(_batchSize, _defaultBucketKey))
       tmp ++ initStates.map(x => x._1 -> Shape(x._2._1, x._2._2))
     }
+
     private val _provideLabel = ListMap("softmax_label" -> Shape(_batchSize, _defaultBucketKey))
 
+    private val _provideDataDesc = {
+      val tmp = IndexedSeq(new DataDesc("data",
+        Shape(_batchSize, _defaultBucketKey), dataDType, dataLayout))
+      tmp ++ initStates.map(x => new DataDesc(x._1, Shape(x._2._1, x._2._2), dataDType, dataLayout))
+    }
+
+    private val _provideLabelDesc = IndexedSeq(new DataDesc("softmax_label",
+      Shape(_batchSize, _defaultBucketKey), labelDType, labelLayout))
+
     private var iBucket = 0
 
     override def next(): DataBatch = {
@@ -197,7 +214,9 @@ object BucketIo {
                     getIndex(),
                     getPad(),
                     this.buckets(bucketIdx).asInstanceOf[AnyRef],
-                    batchProvideData, batchProvideLabel)
+                    batchProvideData, batchProvideLabel,
+                    getDType()._1, getDType()._2,
+                    getLayout()._1, getLayout()._2)
     }
 
     /**
@@ -228,19 +247,29 @@ object BucketIo {
      */
     override def getIndex(): IndexedSeq[Long] = IndexedSeq[Long]()
 
-    // The name and shape of label provided by this iterator
-    override def provideLabel: ListMap[String, Shape] = this._provideLabel
-
     /**
-     * get the number of padding examples
-     * in current batch
-     * @return number of padding examples in current batch
-     */
+      * get the number of padding examples
+      * in current batch
+      * @return number of padding examples in current batch
+      */
     override def getPad(): Int = 0
 
+    override def getDType(): (DType, DType) = (dataDType, labelDType)
+
+    override def getLayout(): (Layout, Layout) = (dataLayout, labelLayout)
+
+    // The name and shape of label provided by this iterator
+    override def provideLabel: ListMap[String, Shape] = this._provideLabel
+
     // The name and shape of data provided by this iterator
     override def provideData: ListMap[String, Shape] = this._provideData
 
+    // Provide type:DataDesc of the data
+    override def provideDataDesc: IndexedSeq[DataDesc] = _provideDataDesc
+
+    // Provide type:DataDesc of the label
+    override def provideLabelDesc: IndexedSeq[DataDesc] = _provideLabelDesc
+
     override def hasNext: Boolean = {
       iBucket < bucketPlan.length
     }
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala
index 8d31d1f6b3d..db5923efe5c 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/ImageClassifier.scala
@@ -49,10 +49,11 @@ class ImageClassifier(modelPathPrefix: String,
                       extends Classifier(modelPathPrefix,
                       inputDescriptors, contexts, epoch) {
 
-  protected[infer] val inputLayout = inputDescriptors.head.layout
+  protected[infer] val inputLayout = inputDescriptors.head.layout.toString
 
   require(inputDescriptors.nonEmpty, "Please provide input descriptor")
-  require(inputDescriptors.head.layout == "NCHW", "Provided layout doesn't match NCHW format")
+  require(inputDescriptors.head.layout.toString == "NCHW",
+    "Provided layout doesn't match NCHW format")
 
   protected[infer] val inputShape = inputDescriptors.head.shape
 
diff --git a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala
index 2a4f0305637..75b55209b1d 100644
--- a/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala
+++ b/scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala
@@ -18,7 +18,7 @@
 package org.apache.mxnet.infer
 
 import org.apache.mxnet.io.NDArrayIter
-import org.apache.mxnet.{Context, DataDesc, NDArray, Shape}
+import org.apache.mxnet._
 import org.apache.mxnet.module.Module
 
 import scala.collection.mutable.ListBuffer
@@ -76,15 +76,15 @@ class Predictor(modelPathPrefix: String,
 
   private val logger = LoggerFactory.getLogger(classOf[Predictor])
 
-  require(inputDescriptors.head.layout.size != 0, "layout size should not be zero")
+  require(inputDescriptors.head.layout.toString.size != 0, "layout size should not be zero")
 
-  protected[infer] var batchIndex = inputDescriptors(0).layout.indexOf('N')
+  protected[infer] var batchIndex = inputDescriptors(0).layout.toString.indexOf('N')
   protected[infer] var batchSize = if (batchIndex != -1) inputDescriptors(0).shape(batchIndex)
     else 1
 
   protected[infer] var iDescriptors = inputDescriptors
 
-  inputDescriptors.foreach((f: DataDesc) => require(f.layout.indexOf('N') == batchIndex,
+  inputDescriptors.foreach((f: DataDesc) => require(f.layout.toString.indexOf('N') == batchIndex,
     "batch size should be in the same index for all inputs"))
 
   if (batchIndex != -1) {
@@ -94,7 +94,7 @@ class Predictor(modelPathPrefix: String,
     // Note: this is assuming that the input needs a batch
     logger.warn("InputDescriptor does not have batchSize, using 1 as the default batchSize")
     iDescriptors = inputDescriptors.map((f: DataDesc) => new DataDesc(f.name,
-      Shape(1 +: f.shape.toVector), f.dtype, 'N' +: f.layout))
+      Shape(1 +: f.shape.toVector), f.dtype, Layout.getLayout('N' +: f.layout.toString)))
     batchIndex = 1
   }
 
diff --git a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ImageClassifierSuite.scala b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ImageClassifierSuite.scala
index 948764ee804..1eb3bd91e37 100644
--- a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ImageClassifierSuite.scala
+++ b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/ImageClassifierSuite.scala
@@ -17,8 +17,7 @@
 
 package org.apache.mxnet.infer
 
-import org.apache.mxnet.{DType, DataDesc, Shape, NDArray, Context}
-
+import org.apache.mxnet._
 import org.mockito.Matchers._
 import org.mockito.Mockito
 import org.scalatest.BeforeAndAfterAll
@@ -60,7 +59,7 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
   test("ImageClassifierSuite-testConvertBufferedImageToNDArray") {
     val dType = DType.Float32
     val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath, Shape(1, 3, 2, 2),
-      dType, "NCHW"))
+      dType, Layout.NCHW))
 
     val image1 = new BufferedImage(100, 200, BufferedImage.TYPE_BYTE_GRAY)
     val image2 = ImageClassifier.reshapeImage(image1, 2, 2)
@@ -73,7 +72,7 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
   test("ImageClassifierSuite-testWithInputImage") {
     val dType = DType.Float32
     val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath, Shape(1, 3, 512, 512),
-      dType, "NCHW"))
+      dType, Layout.NCHW))
 
     val inputImage = new BufferedImage(224, 224, BufferedImage.TYPE_INT_RGB)
 
@@ -111,7 +110,7 @@ class ImageClassifierSuite extends ClassifierSuite with BeforeAndAfterAll {
   test("ImageClassifierSuite-testWithInputBatchImage") {
     val dType = DType.Float32
     val inputDescriptor = IndexedSeq[DataDesc](new DataDesc(modelPath, Shape(1, 3, 512, 512),
-      dType, "NCHW"))
+      dType, Layout.NCHW))
 
     val inputImage = new BufferedImage(224, 224, BufferedImage.TYPE_INT_RGB)
     val imageBatch = IndexedSeq[BufferedImage](inputImage, inputImage)
diff --git a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala
index 53fd7f31068..cdd5146d0a0 100644
--- a/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala
+++ b/scala-package/infer/src/test/scala/org/apache/mxnet/infer/PredictorSuite.scala
@@ -19,7 +19,7 @@ package org.apache.mxnet.infer
 
 import org.apache.mxnet.io.NDArrayIter
 import org.apache.mxnet.module.{BaseModule, Module}
-import org.apache.mxnet.{DataDesc, NDArray, Shape}
+import org.apache.mxnet.{DataDesc, Layout, NDArray, Shape}
 import org.mockito.Matchers._
 import org.mockito.Mockito
 import org.scalatest.{BeforeAndAfterAll, FunSuite}
@@ -45,7 +45,7 @@ class PredictorSuite extends FunSuite with BeforeAndAfterAll {
     val mockPredictor = new MyPredictor("xyz", inputDescriptor)
 
     assert(mockPredictor.getBatchSize == 1)
-    assert(mockPredictor.getBatchIndex == inputDescriptor(0).layout.indexOf('N'))
+    assert(mockPredictor.getBatchIndex == inputDescriptor(0).layout.toString.indexOf('N'))
 
     val inputDescriptor2 = IndexedSeq[DataDesc](new DataDesc("data", Shape(1, 3, 2, 2)),
       new DataDesc("data", Shape(2, 3, 2, 2)))
@@ -55,7 +55,7 @@ class PredictorSuite extends FunSuite with BeforeAndAfterAll {
     }
 
     // batchsize is defaulted to 1
-    val iDesc2 = IndexedSeq[DataDesc](new DataDesc("data", Shape(3, 2, 2), layout = "CHW"))
+    val iDesc2 = IndexedSeq[DataDesc](new DataDesc("data", Shape(3, 2, 2), layout = Layout.CHW))
     val p2 = new MyPredictor("xyz", inputDescriptor)
     assert(p2.getBatchSize == 1, "should use a default batch size of 1")
 
diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala
index adc723ecdac..b84c9180733 100644
--- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala
+++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LabeledPointIter.scala
@@ -17,7 +17,9 @@
 
 package org.apache.mxnet.spark.io
 
-import org.apache.mxnet.{DataBatch, NDArray, Shape, DataIter}
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet.Layout.Layout
+import org.apache.mxnet._
 import org.apache.spark.mllib.regression.LabeledPoint
 
 import scala.collection.immutable.ListMap
@@ -32,7 +34,11 @@ class LabeledPointIter private[mxnet](
   private val dimension: Shape,
   private val _batchSize: Int,
   private val dataName: String = "data",
-  private val labelName: String = "label") extends DataIter {
+  private val labelName: String = "label",
+  private val dataDType: DType = DType.Float32,
+  private val labelDType: DType = DType.Int32,
+  private val dataLayout: Layout = Layout.NCHW,
+  private val labelLayout: Layout = Layout.N) extends DataIter {
 
   private val cache: ArrayBuffer[DataBatch] = ArrayBuffer.empty[DataBatch]
   private var index: Int = -1
@@ -72,7 +78,8 @@ class LabeledPointIter private[mxnet](
       }
       val pad = batchSize - instNum
       val dataBatch = new LongLivingDataBatch(
-        IndexedSeq(dataBuilder), IndexedSeq(labelBuilder), null, pad)
+        IndexedSeq(dataBuilder), IndexedSeq(labelBuilder), null, pad,
+        dataLayout, labelLayout, dataDType, labelDType)
       cache += dataBatch
       dataBatch
     }
@@ -124,6 +131,14 @@ class LabeledPointIter private[mxnet](
     ListMap(dataName -> dataShape)
   }
 
+  override def provideDataDesc: IndexedSeq[DataDesc] = {
+    IndexedSeq(new DataDesc(dataName, dataShape, dataDType, dataLayout))
+  }
+
+  override def provideLabelDesc: IndexedSeq[DataDesc] = {
+    IndexedSeq(new DataDesc(labelName, Shape(_batchSize), labelDType, labelLayout))
+  }
+
   /**
    * Get the number of padding examples
    * in current batch
@@ -131,6 +146,10 @@ class LabeledPointIter private[mxnet](
    */
   override def getPad(): Int = 0
 
+  override def getDType(): (DType, DType) = (dataDType, labelDType)
+
+  override def getLayout(): (Layout, Layout) = (dataLayout, labelLayout)
+
   override def batchSize: Int = _batchSize
 
   override def hasNext: Boolean = {
diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala
index 339f7e2d76c..0d5068544ad 100644
--- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala
+++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/LongLivingDataBatch.scala
@@ -17,7 +17,9 @@
 
 package org.apache.mxnet.spark.io
 
-import org.apache.mxnet.{NDArray, DataBatch}
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet.Layout.Layout
+import org.apache.mxnet.{DataBatch, NDArray}
 
 /**
  * Dispose only when 'disposeForce' called
@@ -27,7 +29,13 @@ class LongLivingDataBatch(
   override val data: IndexedSeq[NDArray],
   override val label: IndexedSeq[NDArray],
   override val index: IndexedSeq[Long],
-  override val pad: Int) extends DataBatch(data, label, index, pad) {
+  override val pad: Int,
+  override val dataLayout: Layout,
+  override val labelLayout: Layout,
+  override val dataDType: DType,
+  override val labelDType: DType) extends DataBatch(data, label, index, pad,
+  dataLayout = dataLayout, labelLayout = labelLayout,
+  dataDType = dataDType, labelDType = labelDType) {
   override def dispose(): Unit = {}
   def disposeForce(): Unit = super.dispose()
 }
diff --git a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala
index 21329291cfb..96eacdab9ab 100644
--- a/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala
+++ b/scala-package/spark/src/main/scala/org/apache/mxnet/spark/io/PointIter.scala
@@ -17,7 +17,9 @@
 
 package org.apache.mxnet.spark.io
 
-import org.apache.mxnet.{NDArray, DataBatch, DataIter, Shape}
+import org.apache.mxnet.DType.DType
+import org.apache.mxnet.Layout.Layout
+import org.apache.mxnet._
 import org.apache.spark.mllib.linalg.Vector
 
 import scala.collection.immutable.ListMap
@@ -32,7 +34,11 @@ class PointIter private[mxnet](
   private val dimension: Shape,
   private val _batchSize: Int,
   private val dataName: String = "data",
-  private val labelName: String = "label") extends DataIter {
+  private val labelName: String = "label",
+  private val dataDType: DType = DType.Float32,
+  private val labelDType: DType = DType.Int32,
+  private val dataLayout: Layout = Layout.NCHW,
+  private val labelLayout: Layout = Layout.N) extends DataIter {
 
   private val cache: ArrayBuffer[DataBatch] = ArrayBuffer.empty[DataBatch]
   private var index: Int = -1
@@ -71,7 +77,8 @@ class PointIter private[mxnet](
       }
       val pad = batchSize - instNum
       val dataBatch = new LongLivingDataBatch(
-        IndexedSeq(dataBuilder), IndexedSeq(labelBuilder), null, pad)
+        IndexedSeq(dataBuilder), IndexedSeq(labelBuilder), null, pad,
+        dataLayout, labelLayout, dataDType, labelDType)
       cache += dataBatch
       dataBatch
     }
@@ -123,6 +130,14 @@ class PointIter private[mxnet](
     ListMap(dataName -> dataShape)
   }
 
+  override def provideDataDesc: IndexedSeq[DataDesc] = {
+    IndexedSeq(new DataDesc(dataName, dataShape, dataDType, dataLayout))
+  }
+
+  override def provideLabelDesc: IndexedSeq[DataDesc] = {
+    IndexedSeq(new DataDesc(labelName, Shape(_batchSize), labelDType, labelLayout))
+  }
+
   /**
    * Get the number of padding examples
    * in current batch
@@ -130,6 +145,10 @@ class PointIter private[mxnet](
    */
   override def getPad(): Int = 0
 
+  override def getDType(): (DType, DType) = (dataDType, labelDType)
+
+  override def getLayout(): (Layout, Layout) = (dataLayout, labelLayout)
+
   override def batchSize: Int = _batchSize
 
   override def hasNext: Boolean = {


 

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to