yzhliu commented on a change in pull request #11045: [MXNET-471] Add Builder 
class for Scala Module and DataBatch to simplify construction
URL: https://github.com/apache/incubator-mxnet/pull/11045#discussion_r191059071
 
 

 ##########
 File path: scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
 ##########
 @@ -160,6 +161,107 @@ class DataBatch(val data: IndexedSeq[NDArray],
   def provideLabel: ListMap[String, Shape] = providedLabel
 }
 
+object DataBatch {
+  /**
+   * Builder class for DataBatch.
+   */
+  class Builder() {
+    private var data: IndexedSeq[NDArray] = null
+    private var label: IndexedSeq[NDArray] = null
+    private var index: IndexedSeq[Long] = null
+    private var pad: Int = 0
+    private var bucketKey: AnyRef = null
+    private var datatShapes: ListMap[String, Shape] = null
+    private var labelShapes: ListMap[String, Shape] = null
+
+    /**
+     * Set the input data.
+     * @param data a list of data.
+     * @return this.
+     */
+    @varargs def setData(data: NDArray*): Builder = {
+      this.data = data.toIndexedSeq
+      this
+    }
+
+    /**
+     * Set the labels in the same order of data.
+     * @param label a list of labels.
+     * @return this.
+     */
+    @varargs def setLabel(label: NDArray*): Builder = {
+      this.label = label.toIndexedSeq
+      this
+    }
+
+    /**
+     * Set the example indices in this batch.
+     * @param index indices in the same order of data.
+     * @return this.
+     */
+    @varargs def setIndex(index: Long*): Builder = {
+      this.index = index.toIndexedSeq
+      this
+    }
+
+    /**
+     * Set the pad.
+     * @param pad The number of examples padded at the end of a batch. It is used when the
+     *            total number of examples read is not divisible by the `batch_size`.
+     *            These extra padded examples are ignored in prediction.
+     * @return this
+     */
+    def setPad(pad: Int): Builder = {
+      this.pad = pad
+      this
+    }
+
+    /**
+     * Set the bucket key, used for bucketing module.
+     * @param bucketKey the bucket key related to this batch.
+     * @return this.
+     */
+    def setBucketKey(bucketKey: AnyRef): Builder = {
+      this.bucketKey = bucketKey
+      this
+    }
+
+    /**
+     * Provide the shape of a data.
+     * @param name data name.
+     * @param shape data shape.
+     * @return this.
+     */
+    def provideDataShape(name: String, shape: Shape): Builder = {
+      if (datatShapes == null) {
+        datatShapes = ListMap((name, shape))
+      } else {
+        datatShapes = datatShapes.updated(name, shape)
+      }
+      this
+    }
+
+    /**
+     * Provide the shape of a label.
+     * @param name label name.
+     * @param shape label shape.
+     * @return this.
+     */
+    def provideLabelShape(name: String, shape: Shape): Builder = {
+      if (labelShapes == null) {
+        labelShapes = ListMap((name, shape))
+      } else {
+        labelShapes = labelShapes.updated(name, shape)
+      }
+      this
+    }
+
+    def build(): DataBatch = {
+      new DataBatch(data, label, index, pad, bucketKey, datatShapes, labelShapes)
 
 Review comment:
   I think we can let users have `data = null` if that is what they want.
   `data = null` does not cause any problem with `DataBatch` itself; we can
   check for it in `forward` instead.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to