lanking520 commented on a change in pull request #13330: [MXNET-1222] Scala 
Inference enable different shapes input
URL: https://github.com/apache/incubator-mxnet/pull/13330#discussion_r237649512
 
 

 ##########
 File path: 
scala-package/infer/src/main/scala/org/apache/mxnet/infer/Predictor.scala
 ##########
 @@ -172,18 +189,20 @@ class Predictor(modelPathPrefix: String,
     for((i, d) <- inputBatch.zip(iDescriptors)) {
        require(inputBatch(0).shape(batchIndex) == i.shape(batchIndex),
          "All inputs should be of same batch size")
-      require(i.shape.drop(batchIndex + 1) == d.shape.drop(batchIndex + 1),
-        s"Input Data Shape: ${i.shape} should match the inputDescriptor " +
-          s"shape: ${d.shape} except batchSize")
+      if (!shapeCheckDisabled) {
+        require(i.shape.drop(batchIndex + 1) == d.shape.drop(batchIndex + 1),
+          s"Input Data Shape: ${i.shape} should match the inputDescriptor " +
+            s"shape: ${d.shape} except batchSize")
+      }
     }
 
     val inputBatchSize = inputBatch(0).shape(batchIndex)
 
     // rebind with the new batchSize
     if (batchSize != inputBatchSize) {
 
 Review comment:
   @nswamy After some testing, I found that the batch size cannot be changed; the test fails with the following error:
   ```
   - Test Predictor With Different Batch size *** FAILED ***
     java.lang.AssertionError: Shape of unspecified array arg:softmax_label 
changed.This can cause the new executor to not share parameters with the old 
one. Please check for error in network.If this is intended, set partialShaping 
= true to suppress this warning.
     at org.apache.mxnet.Executor$$anonfun$reshape$2.apply(Executor.scala:120)
     at org.apache.mxnet.Executor$$anonfun$reshape$2.apply(Executor.scala:98)
     at 
scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
     at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:48)
     at org.apache.mxnet.Executor.reshape(Executor.scala:98)
     at 
org.apache.mxnet.module.DataParallelExecutorGroup$$anonfun$bindExec$1.apply$mcVI$sp(DataParallelExecutorGroup.scala:376)
     at scala.collection.immutable.Range.foreach$mVc$sp(Range.scala:160)
     at 
org.apache.mxnet.module.DataParallelExecutorGroup.bindExec(DataParallelExecutorGroup.scala:371)
     at 
org.apache.mxnet.module.DataParallelExecutorGroup.reshape(DataParallelExecutorGroup.scala:440)
     at org.apache.mxnet.module.Module.reshape(Module.scala:348)
     at org.apache.mxnet.module.Module.forward(Module.scala:453)
     at org.apache.mxnet.module.BaseModule.predict(BaseModule.scala:238)
     at 
org.apache.mxnet.module.BaseModule.predictEveryBatch(BaseModule.scala:228)
     at org.apache.mxnet.module.BaseModule.predict(BaseModule.scala:259)
     at org.apache.mxnet.infer.Predictor$$anonfun$11.apply(Predictor.scala:213)
     at org.apache.mxnet.infer.Predictor$$anonfun$11.apply(Predictor.scala:213)
     at 
org.apache.mxnet.infer.MXNetThreadPoolHandler$$anon$4.call(MXNetHandler.scala:73)
     at java.util.concurrent.FutureTask.run(FutureTask.java:266)
     at 
java.util.concurrent.ThreadPoolExecutor.runWorker(ThreadPoolExecutor.java:1149)
     at 
java.util.concurrent.ThreadPoolExecutor$Worker.run(ThreadPoolExecutor.java:624)
     at java.lang.Thread.run(Thread.java:748)
   ```

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
users@infra.apache.org


With regards,
Apache Git Services

Reply via email to