masahi commented on a change in pull request #8781:
URL: https://github.com/apache/tvm/pull/8781#discussion_r693834511
##########
File path: tests/python/frontend/pytorch/test_rnns.py
##########
@@ -173,6 +253,115 @@ def compare(input, gold_data, rtol=1e-5, atol=1e-5):
tvm.testing.assert_allclose(input, gold_data, rtol=rtol, atol=atol)
+def check_gru_with_type(gru_type, target=tvm.target.Target("llvm -mcpu=core-avx2"), dev=tvm.cpu(0)):
+ device = torch.device("cpu")
+ hidden_layers_num = 1
+ model = None
+ for batch_first in (True, False):
+ for use_bias in (True, False):
+ for rnd_weights in [True]: # (True, False):
+ if gru_type == "uni":
+ model = GRU_Model(
+ device,
+ batch_first=batch_first,
+ rnd_weights_init=rnd_weights,
+ use_bias=use_bias,
+ )
+ elif gru_type == "b":
+ model = GRU_Model(
+ device,
+ batch_first=batch_first,
+ bidirectional=True,
+ rnd_weights_init=rnd_weights,
+ use_bias=use_bias,
+ )
+ hidden_layers_num = 2
+ elif gru_type == "s":
+ model = GRU_Model(
+ device,
+ batch_first=batch_first,
+ layer_num=gru_num_layers,
+ rnd_weights_init=rnd_weights,
+ use_bias=use_bias,
+ )
+ hidden_layers_num = gru_num_layers
+ elif gru_type == "sb":
+ model = GRU_Model(
+ device,
+ batch_first=batch_first,
+ bidirectional=True,
+ layer_num=gru_num_layers,
+ rnd_weights_init=rnd_weights,
+ use_bias=use_bias,
+ )
+ hidden_layers_num = 2 * gru_num_layers
+ else:
+                print("WARNING: GRU type {} is not supported here!".format(gru_type))
+                return
+
+ model.eval()
+
+ # Get golden output from original model
+            input_hidden_shape = (hidden_layers_num, batch_size, gru_hidden_size)
+ dummy_input, input_shape = model.get_dummy_input()
+            golden_output_batch = model.forward(dummy_input.to(device)).detach().cpu().numpy()
+
+ dtype = "float32"
+ h_zeros = np.zeros(input_hidden_shape, dtype=dtype)
+
+ tvm_output = None
+ for format in ["ts"]: # ["ts", "onnx"]:
Review comment:
This looks like copy-paste from the LSTM test. Remove the ONNX stuff, or use a
common function for testing both GRU and LSTM.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]