nswamy commented on a change in pull request #11045: [MXNET-471] Add Builder class for Scala Module and DataBatch to simplify construction
URL: https://github.com/apache/incubator-mxnet/pull/11045#discussion_r190488267
##########
File path: scala-package/core/src/main/scala/org/apache/mxnet/IO.scala
##########
@@ -160,6 +161,100 @@ class DataBatch(val data: IndexedSeq[NDArray],
def provideLabel: ListMap[String, Shape] = providedLabel
}
+object DataBatch {
+ /**
+ * Builder class for DataBatch.
+ */
+ class Builder() {
+ private var data: IndexedSeq[NDArray] = null
+ private var label: IndexedSeq[NDArray] = null
+ private var index: IndexedSeq[Long] = null
+ private var pad: Int = 0
+ private var bucketKey: AnyRef = null
+ private var providedData: ListMap[String, Shape] = ListMap.empty
+ private var providedLabel: ListMap[String, Shape] = ListMap.empty
+
+ /**
+ * Set the input data of the batch.
+ * Each call replaces any data set by an earlier call. The `@varargs`
+ * annotation also exposes a Java-style varargs signature for Java callers.
+ * @param data the data arrays of the batch.
+ * @return this builder, to allow call chaining.
+ */
+ @varargs def setData(data: NDArray*): Builder = {
+ this.data = data.toIndexedSeq
+ this
+ }
+
+ /**
+ * Set the labels of the batch, in the same order as the data.
+ * Each call replaces any labels set by an earlier call.
+ * @param label the label arrays of the batch.
+ * @return this builder, to allow call chaining.
+ */
+ @varargs def setLabel(label: NDArray*): Builder = {
+ this.label = label.toIndexedSeq
+ this
+ }
+
+ /**
+ * Set the indices of the examples in this batch.
+ * Each call replaces any indices set by an earlier call.
+ * @param index example indices, in the same order as the data.
+ * @return this builder, to allow call chaining.
+ */
+ @varargs def setIndex(index: Long*): Builder = {
+ this.index = index.toIndexedSeq
+ this
+ }
+
+ /**
+ * Set the number of padding examples in this batch.
+ * @param pad The number of examples padded at the end of a batch. It is used when the
+ *            total number of examples read is not divisible by the `batch_size`.
+ *            These extra padded examples are ignored in prediction.
+ * @return this builder, to allow call chaining.
+ */
+ def setPad(pad: Int): Builder = {
+ this.pad = pad
+ this
+ }
+
+ /**
+ * Set the bucket key, used for the bucketing module.
+ * Each call replaces any key set by an earlier call.
+ * @param bucketKey the bucket key associated with this batch.
+ * @return this builder, to allow call chaining.
+ */
+ def setBucketKey(bucketKey: AnyRef): Builder = {
+ this.bucketKey = bucketKey
+ this
+ }
+
+ /**
+ * Record the shape of a named data input.
+ * Calling this again with the same name replaces the previously recorded shape.
+ * @param name data name.
+ * @param shape data shape.
+ * @return this builder, to allow call chaining.
+ */
+ def provideData(name: String, shape: Shape): Builder = {
+ providedData = providedData.updated(name, shape)
+ this
+ }
+
+ /**
+ * Provide the shape of a label.
+ * @param name label name.
+ * @param shape label shape.
+ * @return this.
+ */
+ def provideLabel(name: String, shape: Shape): Builder = {
Review comment:
same as above
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org
With regards,
Apache Git Services