lanking520 commented on a change in pull request #13330: [MXNET-1222] Scala
Inference enable different shapes input
URL: https://github.com/apache/incubator-mxnet/pull/13330#discussion_r237649512
##########
File path:
scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala
##########
@@ -172,18 +189,20 @@ class Predictor(modelPathPrefix: String,
for((i, d) <- inputBatch.zip(iDescriptors)) {
require(inputBatch(0).shape(batchIndex) == i.shape(batchIndex),
"All inputs should be of same batch size")
- require(i.shape.drop(batchIndex + 1) == d.shape.drop(batchIndex + 1),
- s"Input Data Shape: ${i.shape} should match the inputDescriptor " +
- s"shape: ${d.shape} except batchSize")
+ if (!shapeCheckDisabled) {
+ require(i.shape.drop(batchIndex + 1) == d.shape.drop(batchIndex + 1),
+ s"Input Data Shape: ${i.shape} should match the inputDescriptor " +
+ s"shape: ${d.shape} except batchSize")
+ }
}
val inputBatchSize = inputBatch(0).shape(batchIndex)
// rebind with the new batchSize
if (batchSize != inputBatchSize) {
Review comment:
@nswamy After running some tests, I found that the batch size cannot be changed:
```
- Test Predictor With Different Batch size *** FAILED ***
java.lang.AssertionError: Shape of unspecified array arg:softmax_label
changed.This can cause the new executor to not share parameters with the old
one. Please check for error in network.If this is intended, set partialShaping
= true to suppress this warning.
at org.apache.mxnet.Executor$$anonfun$reshape$2.apply(Executor.scala:120)
at org.apache.mxnet.Executor$$anonfun$reshape$2.apply(Executor.scala:98)
at
scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
at org.apache.mxnet.Executor.reshape(Executor.scala:98)
at
org.apache.mxnet.module.DataParallelExecutorGroup$$anonfun$bindExec$1.apply$mcVI$sp(DataParallelExecutorGroup.scala:376)
at scala.collection.immutable.Range.foreach$mVc$sp(Range.scala:160)
at
org.apache.mxnet.module.DataParallelExecutorGroup.bindExec(DataParallelExecutorGroup.scala:371)
at
org.apache.mxnet.module.DataParallelExecutorGroup.reshape(DataParallelExecutorGroup.scala:440)
at org.apache.mxnet.module.Module.reshape(Module.scala:348)
at org.apache.mxnet.module.Module.forward(Module.scala:453)
at org.apache.mxnet.module.BaseModule.predict(BaseModule.scala:238)
at
org.apache.mxnet.module.BaseModule.predictEveryBatch(BaseModule.scala:228)
at org.apache.mxnet.module.BaseModule.predict(BaseModule.scala:259)
at org.apache.mxnet.infer.Predictor$$anonfun$11.apply(Predictor.scala:213)
at org.apache.mxnet.infer.Predictor$$anonfun$11.apply(Predictor.scala:213)
at
org.apache.mxnet.infer.MXNetThreadPoolHandler$$anon$4.call(MXNetHandler.scala:73)
at java.util.concurrent.FutureTask.run(FutureTask.java:266)
at
java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
at
java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
at java.lang.Thread.run(Thread.java:748)
```
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
[email protected]
With regards,
Apache Git Services