haoyang9804 opened a new pull request, #15723:
URL: https://github.com/apache/tvm/pull/15723

   Fix this [issue](https://github.com/apache/tvm/issues/14868).
   
   I found that `_convert_simple_rnn` has some logical errors. I'm not entirely sure my fix is correct. With the fix applied, I ran the following script, which previously triggered the bug:
   ```Python
   import tvm
   import tvm.relay as relay
   from tensorflow import keras
   from tensorflow.keras import layers, models
   
   input_shape = (2, 2, 2)
   x = layers.Input(shape=input_shape[1:], dtype='float32')
   
   layer = keras.layers.SimpleRNN(units=2)
   layer.set_weights(layer.get_weights())
   
   y = layer(x)
   model = models.Model(x, y)
   model.summary()
   mod, params = relay.frontend.from_keras(model, {'input_1': input_shape})
   mod = relay.transform.InferType()(mod)
   
   print(mod)
   with tvm.transform.PassContext(opt_level=3):
       model = relay.build_module.create_executor(
           "vm", mod, tvm.cpu(0), 'llvm', params
       ).evaluate()
   ```
   The compilation result is:
   ```text
   Model: "model"
   _________________________________________________________________
    Layer (type)                Output Shape              Param #   
   =================================================================
    input_1 (InputLayer)        [(None, 2, 2)]            0         
                                                                    
    simple_rnn (SimpleRNN)      (None, 2)                 10        
                                                                    
   =================================================================
   Total params: 10 (40.00 Byte)
   Trainable params: 10 (40.00 Byte)
   Non-trainable params: 0 (0.00 Byte)
   _________________________________________________________________
   def @main(%input_1: Tensor[(2, 2, 2), float32] /* ty=Tensor[(2, 2, 2), float32] */, %v_param_2: Tensor[(2, 4), float32] /* ty=Tensor[(2, 4), float32] */, %v_param_4: Tensor[(2), float32] /* ty=Tensor[(2), float32] */, %v_param_1: Tensor[(1, 2), float32] /* ty=Tensor[(1, 2), float32] */, %v_param_3: Tensor[(2, 2), float32] /* ty=Tensor[(2, 2), float32] */) -> Tensor[(2, 2), float32] {
     %0 = nn.batch_flatten(%input_1) /* ty=Tensor[(2, 4), float32] */;
     %1 = nn.dense(%0, %v_param_2, units=2) /* ty=Tensor[(2, 2), float32] */;
     %2 = nn.bias_add(%1, %v_param_4) /* ty=Tensor[(2, 2), float32] */;
     %3 = split(%2, indices_or_sections=[1], axis=1) /* ty=(Tensor[(2, 1), float32], Tensor[(2, 1), float32]) */;
     %4 = nn.batch_flatten(%v_param_1) /* ty=Tensor[(1, 2), float32] */;
     %5 = %3.0 /* ty=Tensor[(2, 1), float32] */;
     %6 = nn.dense(%4, %v_param_3, units=2) /* ty=Tensor[(1, 2), float32] */;
     %7 = add(%5, %6) /* ty=Tensor[(2, 2), float32] */;
     %8 = %3.0 /* ty=Tensor[(2, 1), float32] */;
     %9 = nn.dense(%7, %v_param_3, units=2) /* ty=Tensor[(2, 2), float32] */;
     add(%8, %9) /* ty=Tensor[(2, 2), float32] */
   }
   
   One or more operators have not been tuned. Please tune your model for better performance. Use DEBUG logging level to see more details.
   ```
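The output above shows only that the model now compiles. One way to check whether the converted graph is also numerically right is to compare it against a plain reference of the SimpleRNN recurrence, h_t = tanh(x_t · kernel + bias + h_{t-1} · recurrent_kernel). The NumPy sketch below is a minimal, hypothetical helper (`simple_rnn_reference` is not part of TVM or Keras). It assumes the Keras defaults used by this layer: `tanh` activation, an all-zero initial state, and `return_sequences=False`, so only the last hidden state is returned. With `input_dim=2` and `units=2`, the kernel (2×2), recurrent kernel (2×2) and bias (2) account for the 10 parameters reported by `model.summary()`.

```python
import numpy as np


def simple_rnn_reference(x, kernel, recurrent_kernel, bias, activation=np.tanh):
    """Hypothetical reference for a Keras-style SimpleRNN forward pass.

    Assumption: mirrors Keras SimpleRNN defaults (tanh activation,
    all-zero initial state, return_sequences=False).

    x:                (batch, timesteps, input_dim)
    kernel:           (input_dim, units)
    recurrent_kernel: (units, units)
    bias:             (units,)
    returns:          last hidden state, (batch, units)
    """
    batch, timesteps, _ = x.shape
    units = kernel.shape[1]
    # Start from an all-zero hidden state.
    h = np.zeros((batch, units), dtype=x.dtype)
    for t in range(timesteps):
        # h_t = activation(x_t @ kernel + bias + h_{t-1} @ recurrent_kernel)
        h = activation(x[:, t, :] @ kernel + bias + h @ recurrent_kernel)
    return h


# Same shapes as the bug-triggering script: input (2, 2, 2), units=2.
rng = np.random.default_rng(0)
x = rng.standard_normal((2, 2, 2)).astype(np.float32)
kernel = rng.standard_normal((2, 2)).astype(np.float32)
recurrent_kernel = rng.standard_normal((2, 2)).astype(np.float32)
bias = rng.standard_normal(2).astype(np.float32)

out = simple_rnn_reference(x, kernel, recurrent_kernel, bias)
print(out.shape)  # (2, 2)
print(kernel.size + recurrent_kernel.size + bias.size)  # 10
```

To compare with the script above, one could pass `layer.get_weights()` (assumed to be ordered `[kernel, recurrent_kernel, bias]`) and the same input to both this function and the compiled executor. The two results could then be checked with `np.allclose`.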
   

