## Description

Fitting a model with the Scala Module API throws an `IllegalArgumentException`.
The label shape `(50)` apparently does not match the expected `NCHW` layout.

## Environment info (Required)

macOS 10.13.6
IntelliJ 2018.2.2
Scala 2.11.12
Java 1.8.0_121
MXNet 1.2.1

## Error Message:

```
Exception in thread "main" java.lang.IllegalArgumentException: requirement failed: number of dimensions in shape :1 with shape: (50) should match the length of the layout: 4 with layout: NCHW
        at scala.Predef$.require(Predef.scala:224)
        at org.apache.mxnet.DataDesc.<init>(IO.scala:233)
        at org.apache.mxnet.DataDesc$$anonfun$ListMap2Descs$1.apply(IO.scala:256)
        at org.apache.mxnet.DataDesc$$anonfun$ListMap2Descs$1.apply(IO.scala:256)
        at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
        at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:234)
        at scala.collection.Iterator$class.foreach(Iterator.scala:891)
        at scala.collection.AbstractIterator.foreach(Iterator.scala:1334)
        at scala.collection.IterableLike$class.foreach(IterableLike.scala:72)
        at scala.collection.AbstractIterable.foreach(Iterable.scala:54)
        at scala.collection.TraversableLike$class.map(TraversableLike.scala:234)
        at scala.collection.AbstractTraversable.map(Traversable.scala:104)
        at org.apache.mxnet.DataDesc$.ListMap2Descs(IO.scala:256)
        at org.apache.mxnet.module.BaseModule.fit(BaseModule.scala:399)
```


## Minimum reproducible example

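```scala
import org.apache.mxnet.IO
import org.apache.mxnet.module.{FitParams, Module}
import org.apache.mxnet.optimizer.SGD

// `dataName`, `mlp` (the network Symbol) and `testDataIter` are defined elsewhere.
val trainDataIter = IO.ImageRecordIter(Map(
  "data_name" -> dataName,
  "path_imgrec" -> this.getClass.getResource("/data/mydata.rec").getFile,
  "data_shape" -> "(3,128,128)",
  "batch_size" -> "50"
))

val mod = new Module(mlp)
// Throws the IllegalArgumentException shown above.
mod.fit(
  trainDataIter,
  Some(testDataIter),
  numEpoch = 10,
  fitParams = new FitParams()
    .setOptimizer(new SGD(0.1f, 0.9f, 0.0001f))
)
```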

I tried debugging this, but it is pretty difficult to find out what's going on
with a stringly typed API. Printing the provided shapes gives:

- `println(trainDataIter.provideData)` -> `Map(data -> (50,3,128,128))`
- `println(trainDataIter.provideLabel)` -> `Map(label -> (50))`
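These shapes line up with the error message: the check appears to require one dimension per layout character, so the data shape `(50,3,128,128)` fits `NCHW` while the label shape `(50)` does not. The following is a self-contained sketch of that check, reconstructed only from the error text; `LayoutCheck`, `matchesLayout` and `checkLayout` are hypothetical names, not MXNet's actual `DataDesc` code.

```scala
// Hypothetical sketch: mirrors the rank-vs-layout check implied by the error
// message. It is NOT the MXNet DataDesc source.
object LayoutCheck {
  // True when the shape has exactly one dimension per layout character.
  def matchesLayout(shape: Seq[Int], layout: String): Boolean =
    shape.length == layout.length

  // Throws IllegalArgumentException (via require) with a message shaped like the reported one.
  def checkLayout(shape: Seq[Int], layout: String): Unit =
    require(
      matchesLayout(shape, layout),
      s"number of dimensions in shape :${shape.length} with shape: ${shape.mkString("(", ",", ")")} " +
        s"should match the length of the layout: ${layout.length} with layout: $layout"
    )

  def main(args: Array[String]): Unit = {
    val dataShape = Seq(50, 3, 128, 128) // provideData: data -> (50,3,128,128)
    val labelShape = Seq(50)             // provideLabel: label -> (50)

    checkLayout(dataShape, "NCHW") // passes: 4 dims, 4 layout characters

    // Fails: 1 dim vs 4 layout characters, as in the reported exception.
    try checkLayout(labelShape, "NCHW")
    catch { case e: IllegalArgumentException => println(e.getMessage) }
  }
}
```

Running `main` prints the same `requirement failed: ...` message as the trace. Under the same assumption, a 1-D label would only pass with a one-character layout such as `N`; whether the Module API in MXNet 1.2.1 lets a caller choose that layout for the label is not something this sketch establishes.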


[ Full content available at: https://github.com/apache/incubator-mxnet/issues/12409 ]