This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new d5818b4 [Frontend, pytorch] Vc/pytorch lstm (#8447)
d5818b4 is described below
commit d5818b4cd32470e2209bd3f1305d74ad9e7819d2
Author: Valery Chernov <[email protected]>
AuthorDate: Tue Jul 20 15:36:57 2021 +0300
[Frontend, pytorch] Vc/pytorch lstm (#8447)
* supported conversion of LSTM layers to Relay from PyTorch models (TorchScript)
* supported bidirectional LSTM layers from the PyTorch API
* implemented LSTM tests; fixes in the PyTorch LSTM converter
* fixed the PyTorch bidirectional LSTM; updated a test comment
* black formatting and some small fixes
* supported LSTM with projection in the PyTorch frontend; updated the test with
the new combination of LSTM types
* lint fixes
* added a bias switch to the LSTM types test; fixed the LSTM implementation in
the PyTorch frontend for the case without biases; added an exception to the test
for conversion of LSTM with projection from PyTorch to ONNX
* moved test_lstms to the pytest format
* implemented ONNX model saving through io.BytesIO; removed creation/removal of
a tmp dir; removed unnecessary comments
* added a GPU target to the test
Co-authored-by: Valery Chernov <[email protected]>
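For context, the conversion path added by this commit can be exercised roughly as
follows (a minimal sketch mirroring the new test; the wrapper class, tensor sizes,
and the "input" name are illustrative, not part of this commit):

    import torch
    from tvm import relay

    class LSTMWrapper(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.lstm = torch.nn.LSTM(input_size=5, hidden_size=10, num_layers=2)

        def forward(self, x):
            out, _ = self.lstm(x)  # drop (h_n, c_n), return per-step outputs only
            return out

    model = LSTMWrapper().eval()
    seq = torch.rand(15, 3, 5)  # (seq_len, batch, feature_size)
    traced = torch.jit.trace(model, seq)  # TorchScript module via tracing
    # aten::lstm is now handled by the PyTorch frontend:
    mod, params = relay.frontend.from_pytorch(traced, [("input", (15, 3, 5))])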
---
python/tvm/relay/frontend/pytorch.py | 294 ++++++++++++++++++++++
tests/python/frontend/pytorch/test_lstms.py | 363 ++++++++++++++++++++++++++++
2 files changed, 657 insertions(+)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 4c87467..33cb83b 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -41,6 +41,7 @@ from ..ty import Any, TensorType, TupleType
from . import qnn_torch
from .common import AttrCvt, get_relay_op
from .common import infer_value as _infer_value
+from .common import infer_shape as _infer_shape
from .common import infer_value_simulated as _infer_value_simulated
from .common import try_infer_value
from .pytorch_utils import is_version_greater_than
@@ -2329,6 +2330,298 @@ class PyTorchOpConverter:
axis = inputs[1]
return _op.transform.reverse(data, axis=axis[0])
+ def lstm_cell(self, input_seqs, hidden, weights, has_proj=False):
+ if has_proj:
+ assert len(weights) == 5
+ else:
+ assert len(weights) == 4
+ outputs_list = []
+ # Default activations types
+ f_act = _op.sigmoid
+ g_act = _op.tanh
+ h_act = _op.tanh
+
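+ # The loop below implements the standard LSTM recurrence, with the four gates
+ # packed into weights[0] (input) and weights[1] (hidden) in pytorch order i, f, g, o:
+ #   i = sigmoid(x_t @ W_i.T + H_t @ W_h.T + b_i + b_h)  [input gate]
+ #   f = sigmoid(...)                                    [forget gate]
+ #   g = tanh(...)                                       [candidate, named c below]
+ #   o = sigmoid(...)                                    [output gate]
+ #   C_t = f * C_t + i * g
+ #   H_t = o * tanh(C_t)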
+ # Input hiddens
+ H_t = hidden[0] # (batch, hidden_size)
+ C_t = hidden[1] # (batch, hidden_size)
+ for x_t in input_seqs:
+ # x_t shape = (batch, feature size)
+ # gates shape = (batch, 4 * hidden_size)
+ gates = _op.nn.dense(x_t, weights[0]) + _op.nn.dense(H_t, weights[1])
+ # Add biases
+ if weights[2] is not None:
+ gates += weights[2]
+ if weights[3] is not None:
+ gates += weights[3]
+ i, f, c, o = _op.split(gates, 4, axis=-1) # (batch, hidden_size)
+
+ i = f_act(i)
+ f = f_act(f)
+ c = g_act(c)
+ o = f_act(o)
+
+ C = f * C_t + i * c
+ H = o * h_act(C)
+
+ if has_proj:
+ H = _op.nn.dense(H, weights[4])
+
+ H_t = H
+ C_t = C
+ outputs_list.append(H) # [seq_num, (batch, hidden_size)]
+ hidden_outputs = (H_t, C_t)
+
+ return (outputs_list, hidden_outputs)
+
+ def bidir_lstm_cell(self, input_seq, hidden_pair, weights_pair, has_proj=False):
+ fw_outputs = self.lstm_cell(input_seq, hidden_pair[0], weights_pair[0], has_proj)
+
+ rev_input_seq = []
+ seq_len = len(input_seq)
+ for i in range(seq_len):
+ rev_input_seq.append(input_seq[seq_len - 1 - i]) # [seq_num, (batch, hidden_size)]
+ rev_outputs = self.lstm_cell(rev_input_seq, hidden_pair[1], weights_pair[1], has_proj)
+
+ final_outputs = [] # [seq_num, (batch, 2 * hidden_size)]
+ for j in range(seq_len):
+ final_outputs.append(
+ _op.concatenate([fw_outputs[0][j], rev_outputs[0][seq_len - 1 - j]], -1)
+ )
+
+ return final_outputs, (fw_outputs[1], rev_outputs[1])
+
+ def lstm_layers(
+ self, input_data, hiddens, weights, bidirectional, dtype, dropout_p=0.0, has_proj=False
+ ):
+ hidden_layers_num = len(hiddens)
+ assert len(weights) == hidden_layers_num
+
+ # split input sequence to samples set
+ input_seqs = self.unbind((input_data, 0), dtype) # [seq_num, (batch, feature_size)]
+ output_hiddens = []
+ for k in range(hidden_layers_num):
+ hiddens_input = hiddens[k]
+ weights_input = weights[k]
+
+ outputs = (
+ self.bidir_lstm_cell(input_seqs, hiddens_input, weights_input, has_proj)
+ if bidirectional
+ else self.lstm_cell(input_seqs, hiddens_input, weights_input, has_proj)
+ )
+
+ output_hiddens.append(outputs[1])
+ # input_seqs shape = [seq_num, (batch, hidden_size)] or
+ # [seq_num, (batch, 2 * hidden_size)] for bidirectional
+ input_seqs = outputs[0]
+
+ # TODO (vvchernov): in pytorch implementation train is also checked
+ # see https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339
+ # /aten/src/ATen/native/RNN.cpp#L1054
+ if dropout_p != 0 and k < hidden_layers_num - 1:
+ # for input in input_seqs:
+ # input = _op.dropout(input, dropout_p)
+ raise NotImplementedError("Dropout for LSTM has not been
supported yet!")
+ final_hiddens = []
+ if bidirectional:
+ for i in range(hidden_layers_num):
+ final_hiddens.append(output_hiddens[i][0])
+ final_hiddens.append(output_hiddens[i][1])
+ else:
+ final_hiddens = output_hiddens
+
+ return _op.stack(input_seqs, 0), final_hiddens
+
+ def lstm(self, inputs, input_types):
+ """
+ Description of LSTM in
pytorch:https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
+ Native implementation for torch version less than 1.8.0 (projection is
unsupported):
+
https://github.com/pytorch/pytorch/blob/70c8daf43946b53af6493d058899ef952d27d339/aten/
\
+ src/ATen/native/RNN.cpp#L1396
+ Native implementation for torch version from 1.8.0 and higher
(projection is supported):
+
https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/RNN.cpp#L1483
+ """
+ # TODO (vvchernov): support dropout
+ assert len(inputs) == 9, "Input of size 9 is expected"
+ # Unpack inputs. Note that if an optional input is not provided, its value will be None.
+ _X = inputs[0]
+ # _X shape (seq_num, batch, feature_size) or (batch, seq_num, feature_size)
+
+ hidden_states = inputs[1]
+ assert len(hidden_states) == 2, "lstm expects two hidden states"
+ h_0 = hidden_states[0]
+ c_0 = hidden_states[1]
+ # H0 shape (hidden_layers_num, batch, proj_size) if projection
+ # else (hidden_layers_num, batch, hidden_size)
+ # C0 shape (hidden_layers_num, batch, hidden_size)
+
+ _weights = inputs[2]
+ # If no projection
+ # Wi layer[0] shape (4 * hidden_size, feature_size)
+ # Wh layer[0] shape (4 * hidden_size, hidden_size)
+ # Bi layer[0] shape (4 * hidden_size)
+ # Bh layer[0] shape (4 * hidden_size)
+
+ # Wi layer[>0] shape (4 * hidden_size, hidden_size * num_directions)
+ # Wh layer[>0] shape (4 * hidden_size, hidden_size)
+ # Bi layer[>0] shape (4 * hidden_size)
+ # Bh layer[>0] shape (4 * hidden_size)
+
+ # If projection
+ # Wi layer[0] shape (4 * hidden_size, feature_size)
+ # Wh layer[0] shape (4 * hidden_size, proj_size)
+ # Bi layer[0] shape (4 * hidden_size)
+ # Bh layer[0] shape (4 * hidden_size)
+ # P layer[0] shape (proj_size, hidden_size)
+
+ # Wi layer[>0] shape (4 * hidden_size, proj_size * num_directions)
+ # Wh layer[>0] shape (4 * hidden_size, proj_size)
+ # Bi layer[>0] shape (4 * hidden_size)
+ # Bh layer[>0] shape (4 * hidden_size)
+ # P layer[>0] shape (proj_size, hidden_size)
+
+ # Scalar inputs
+ has_biases = inputs[3]
+ num_layers = inputs[4]
+ dropout_p = inputs[5] # dropout probability; 0.0 means there is no dropout
+ # train = inputs[6]
+ bidirectional = inputs[7]
+ batch_first = inputs[8]
+
+ num_directions = 1
+ if bidirectional:
+ num_directions = 2
+
+ rsd = len(_weights) % num_layers
+ assert rsd == 0, "The number of weights must be a multiple of the
number of layers!"
+ rsd = (len(_weights) / num_layers) % num_directions
+ assert (
+ rsd == 0
+ ), "The number of weights in layer must be a multiple of the number of
directions!"
+ has_proj = False
+ proj_size = 0
+ weights_num = int(len(_weights) / num_layers / num_directions)
+ if has_biases:
+ if weights_num == 5:
+ has_proj = True
+ proj_size = _infer_shape(_weights[4])[0]
+ else:
+ assert weights_num == 4, "The weights number in layer is
expected equal to 4"
+ else:
+ if weights_num == 3:
+ has_proj = True
+ proj_size = _infer_shape(_weights[2])[0]
+ else:
+ assert weights_num == 2, "The weights number in layer is
expected equal to 2"
+
+ weights = []
+ if has_biases:
+ if bidirectional:
+ rsd = len(_weights) % (2 * weights_num)
+ assert rsd == 0, "got an incorrect number of LSTM weights"
+ for i in range(0, len(_weights), 2 * weights_num):
+ fw_weights = []
+ rev_weights = []
+ for j in range(weights_num):
+ fw_weights.append(_weights[i + j])
+ rev_weights.append(_weights[i + j + weights_num])
+ weights.append((fw_weights, rev_weights))
+ else:
+ assert len(_weights) % weights_num == 0, "got an incorrect
number of LSTM weights"
+ for i in range(0, len(_weights), weights_num):
+ fw_weights = []
+ for j in range(weights_num):
+ fw_weights.append(_weights[i + j])
+ weights.append(fw_weights)
+ else:
+ if bidirectional:
+ rsd = len(_weights) % (2 * weights_num)
+ assert rsd == 0, "got an incorrect number of LSTM weights"
+ for i in range(0, len(_weights), 2 * weights_num):
+ fw_weights = []
+ rev_weights = []
+ k = i + weights_num
+ if has_proj:
+ fw_weights = [_weights[i], _weights[i + 1], None, None, _weights[i + 2]]
+ rev_weights = [_weights[k], _weights[k + 1], None, None, _weights[k + 2]]
+ else:
+ fw_weights = [_weights[i], _weights[i + 1], None, None]
+ rev_weights = [_weights[k], _weights[k + 1], None, None]
+ weights.append((fw_weights, rev_weights))
+ else:
+ assert len(_weights) % weights_num == 0, "got an incorrect
number of LSTM weights"
+ for i in range(0, len(_weights), weights_num):
+ if has_proj:
+ fw_weights = [_weights[i], _weights[i + 1], None, None, _weights[i + 2]]
+ else:
+ fw_weights = [_weights[i], _weights[i + 1], None, None]
+ weights.append(fw_weights)
+ assert (
+ len(weights) == num_layers
+ ), "For stacked LSTM number of weights tuples should be the same as
number of layers!"
+
+ X = _op.transpose(_X, (1, 0, 2)) if batch_first else _X
+ # TODO (vvchernov): Which data type should be used? from input or weights?
+ # Alternatively, _infer_type(X).checked_type.dtype can be used
+ X_dtype = input_types[0]
+ X_shape = _infer_shape(X) # (seq_num, batch, feature_size)
+
+ hidden_size = int(_infer_shape(_weights[0])[0] / 4)  # W_i stacks the 4 gate blocks
+ batch_size = X_shape[1]
+
+ # Initialize hidden states if not provided.
+ layers_h = []
+ layers_c = []
+ hidden_layers_num = num_directions * num_layers
+ if h_0 is None:
+ if has_proj:
+ h_0 = _op.zeros((batch_size, proj_size), X_dtype)
+ else:
+ h_0 = _op.zeros((batch_size, hidden_size), X_dtype)
+ for i in range(hidden_layers_num):
+ layers_h.append(h_0)
+ else:
+ layers_h = self.unbind((h_0, 0), X_dtype)
+ if c_0 is None:
+ c_0 = _op.zeros((batch_size, hidden_size), X_dtype)
+ for i in range(hidden_layers_num):
+ layers_c.append(c_0)
+ else:
+ layers_c = self.unbind((c_0, 0), X_dtype)
+
+ hiddens = []
+ for i in range(num_layers):
+ if bidirectional:
+ hiddens.append(
+ ((layers_h[2 * i], layers_c[2 * i]), (layers_h[2 * i + 1], layers_c[2 * i + 1]))
+ )
+ else:
+ hiddens.append((layers_h[i], layers_c[i]))
+
+ outputs = self.lstm_layers(
+ X,
+ hiddens,
+ weights,
+ bidirectional,
+ dtype=X_dtype,
+ dropout_p=dropout_p,
+ has_proj=has_proj,
+ )
+
+ # output shape = (seq_num, batch, hidden_size) or
+ # (seq_num, batch, 2 * hidden_size) for bidirectional
+ output = outputs[0]
+
+ hy = []
+ cy = []
+ for hidden in outputs[1]:
+ hy.append(hidden[0])
+ cy.append(hidden[1])
+
+ if batch_first:
+ output = _op.transpose(output, (1, 0, 2))
+
+ return (output, _op.stack(hy, 0), _op.stack(cy, 0))
+
# Operator mappings
def create_convert_map(self):
self.convert_map = {
@@ -2545,6 +2838,7 @@ class PyTorchOpConverter:
"aten::nll_loss": self.nll_loss,
"aten::nll_loss2d": self.nll_loss,
"aten::flip": self.flip,
+ "aten::lstm": self.lstm,
}
def update_convert_map(self, custom_map):
diff --git a/tests/python/frontend/pytorch/test_lstms.py b/tests/python/frontend/pytorch/test_lstms.py
new file mode 100644
index 0000000..e780ae7
--- /dev/null
+++ b/tests/python/frontend/pytorch/test_lstms.py
@@ -0,0 +1,363 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+import numpy as np
+import torch
+import onnx
+import io
+import sys
+import pytest
+
+from tvm import relay
+from tvm.contrib import graph_executor
+
+from torch import nn
+
+## Model parameters
+model_feature_size = 5
+model_hidden_size = 10
+model_num_layers = 2
+seqs_length = 15
+projection_size = 7
+batch_size = 3
+
+
+def check_torch_version_for_proj_in_lstm():
+ """
+ The proj_size parameter is supported in the torch.nn.LSTM layer starting from torch version 1.8.0
+ """
+ me = False
+
+ version = torch.__version__
+ major, minor, micro = version.split(".")
+
+ if int(major) > 1:
+ me = True
+ elif int(major) == 1:
+ if int(minor) >= 8:
+ me = True
+
+ return me
+
+
+class LSTM_Model(nn.Module):
+ def __init__(
+ self,
+ device,
+ batch_first=False,
+ layer_num=1,
+ bidirectional=False,
+ proj_size=0,
+ use_bias=True,
+ rnd_weights_init=False,
+ ):
+ super().__init__()
+
+ self.device = device
+ self.batch_first = batch_first
+ self.use_bias = use_bias
+
+ if check_torch_version_for_proj_in_lstm():
+ self.lstm = nn.LSTM(
+ input_size=model_feature_size,
+ hidden_size=model_hidden_size,
+ num_layers=layer_num,
+ bidirectional=bidirectional,
+ proj_size=proj_size,
+ batch_first=batch_first,
+ bias=use_bias,
+ ).to(device)
+ else:
+ if proj_size > 0:
+ print(
+ "WARNING: projection is not supported for torch version
less than 1.8.0! ",
+ "LSTM was constructed without projection!",
+ )
+ # sys.exit()
+ self.lstm = nn.LSTM(
+ input_size=model_feature_size,
+ hidden_size=model_hidden_size,
+ num_layers=layer_num,
+ bidirectional=bidirectional,
+ batch_first=batch_first,
+ bias=use_bias,
+ ).to(device)
+
+ if rnd_weights_init:
+ self.gen_rnd_weights()
+
+ def forward(self, input, hidden_init=None):
+ """
+ Computes the output tensor after input inference along LSTM layer.
+
+ :param input: batch of data as a tensor of shape (seqs_length, batch_size, model_feature_size) or (batch_size, seqs_length, model_feature_size) if self.batch_first = True
+ :param hidden_init: initial hidden state of the LSTM as a tensor of shape (num_layers, batch_size, hidden_size). Will default to a tensor of zeros if None.
+ :return: the output features for each time step as a tensor of shape (seqs_length, batch_size, num_directions * hidden_size) or (batch_size, seqs_length, num_directions * hidden_size) if self.batch_first = True
+ """
+ # Pass the input through the LSTM layers and retrieve all outputs, the final hidden state
+ # and the final cell state.
+ out, (hidden, cell) = self.lstm(input, hidden_init)
+
+ return out
+
+ def gen_rnd_weights(self):
+ """
+ Generate random weights for the model with biases
+ Without projection:
+ For first weights group:
+ Wi (4*model_hidden_size, model_feature_size)
+ Wh (4*model_hidden_size, model_hidden_size)
+ Bi (4*model_hidden_size)
+ Bh (4*model_hidden_size)
+ For first bidirectional weights group:
+ Wi (4*model_hidden_size, model_feature_size)
+ Wh (4*model_hidden_size, model_hidden_size)
+ Bi (4*model_hidden_size)
+ Bh (4*model_hidden_size)
+ For other weights group:
+ Wi (4*model_hidden_size, model_hidden_size)
+ Wh (4*model_hidden_size, model_hidden_size)
+ Bi (4*model_hidden_size)
+ Bh (4*model_hidden_size)
+ With projection:
+ For first weights group:
+ Wi (4*model_hidden_size, model_feature_size)
+ Wh (4*model_hidden_size, proj_size)
+ Bi (4*model_hidden_size)
+ Bh (4*model_hidden_size)
+ P (proj_size, model_hidden_size)
+ For first bidirectional weights group:
+ Wi (4*model_hidden_size, model_feature_size)
+ Wh (4*model_hidden_size, proj_size)
+ Bi (4*model_hidden_size)
+ Bh (4*model_hidden_size)
+ P (proj_size, model_hidden_size)
+ For other weights group:
+ Wi (4*model_hidden_size, proj_size * num_directions)
+ Wh (4*model_hidden_size, proj_size)
+ Bi (4*model_hidden_size)
+ Bh (4*model_hidden_size)
+ P (proj_size, model_hidden_size)
+ For generation of random weights for the model without biases, Bi and Bh are skipped
+ """
+ for weight_group in self.lstm.all_weights:
+ for weight in weight_group:
+ weight.data = torch.rand(weight.shape)
+
+ def get_dummy_input(self):
+ shape = [seqs_length, batch_size, model_feature_size]
+ if self.batch_first:
+ shape = [batch_size, seqs_length, model_feature_size]
+ res = torch.rand(shape)
+
+ return res, shape
+
+
+def compare(input, gold_data, rtol=1e-5, atol=1e-5):
+ tvm.testing.assert_allclose(input, gold_data, rtol=rtol, atol=atol)
+
+
+def check_lstm_with_type(
+ lstm_type, target=tvm.target.Target("llvm -mcpu=core-avx2"), dev=tvm.cpu(0)
+):
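+ # lstm_type encodes the configuration under test: "s" - stacked (multilayer),
+ # "b" - bidirectional, "p" - with projection; the letters combine (e.g. "sbp"),
+ # and "uni" is a plain unidirectional single-layer LSTM.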
+ has_proj = "p" in lstm_type
+
+ device = torch.device("cpu")
+ hidden_layers_num = 1
+ model = None
+ for batch_first in (True, False):
+ for use_bias in (True, False):
+ for rnd_weights in (True, False):
+ if lstm_type == "uni":
+ model = LSTM_Model(
+ device,
+ batch_first=batch_first,
+ rnd_weights_init=rnd_weights,
+ use_bias=use_bias,
+ )
+ elif lstm_type == "b":
+ model = LSTM_Model(
+ device,
+ batch_first=batch_first,
+ bidirectional=True,
+ rnd_weights_init=rnd_weights,
+ use_bias=use_bias,
+ )
+ hidden_layers_num = 2
+ elif lstm_type == "p":
+ model = LSTM_Model(
+ device,
+ batch_first=batch_first,
+ proj_size=projection_size,
+ rnd_weights_init=rnd_weights,
+ use_bias=use_bias,
+ )
+ elif lstm_type == "s":
+ model = LSTM_Model(
+ device,
+ batch_first=batch_first,
+ layer_num=model_num_layers,
+ rnd_weights_init=rnd_weights,
+ use_bias=use_bias,
+ )
+ hidden_layers_num = model_num_layers
+ elif lstm_type == "sb":
+ model = LSTM_Model(
+ device,
+ batch_first=batch_first,
+ bidirectional=True,
+ layer_num=model_num_layers,
+ rnd_weights_init=rnd_weights,
+ use_bias=use_bias,
+ )
+ hidden_layers_num = 2 * model_num_layers
+ elif lstm_type == "sp":
+ model = LSTM_Model(
+ device,
+ batch_first=batch_first,
+ layer_num=model_num_layers,
+ proj_size=projection_size,
+ rnd_weights_init=rnd_weights,
+ use_bias=use_bias,
+ )
+ hidden_layers_num = model_num_layers
+ elif lstm_type == "bp":
+ model = LSTM_Model(
+ device,
+ batch_first=batch_first,
+ bidirectional=True,
+ proj_size=projection_size,
+ rnd_weights_init=rnd_weights,
+ use_bias=use_bias,
+ )
+ hidden_layers_num = 2
+ elif lstm_type == "sbp":
+ model = LSTM_Model(
+ device,
+ batch_first=batch_first,
+ bidirectional=True,
+ layer_num=model_num_layers,
+ proj_size=projection_size,
+ rnd_weights_init=rnd_weights,
+ use_bias=use_bias,
+ )
+ hidden_layers_num = 2 * model_num_layers
+ else:
+ print("WARNING: LSTM type {} is not supported
here!".format(lstm_type))
+ return
+
+ model.eval()
+
+ # Get golden output from original model
+ input_hidden_shape = (hidden_layers_num, batch_size, model_hidden_size)
+ input_hidden_shape_with_proj = (hidden_layers_num, batch_size, projection_size)
+ dummy_input, input_shape = model.get_dummy_input()
+ golden_output_batch = model.forward(dummy_input.to(device)).detach().cpu().numpy()
+
+ dtype = "float32"
+ h_zeros = np.zeros(input_hidden_shape, dtype=dtype)
+ if has_proj:
+ h_zeros = np.zeros(input_hidden_shape_with_proj, dtype=dtype)
+ c_zeros = np.zeros(input_hidden_shape, dtype=dtype)
+
+ tvm_output = None
+ for format in ("ts", "onnx"):
+ if format == "ts":
+ # Use torch.jit.trace to generate a torch.jit.ScriptModule via tracing.
+ traced_script_module = torch.jit.trace(model, dummy_input).eval()
+
+ # Import model to Relay
+ shape_list = [("input", input_shape)]
+ mod, params = relay.frontend.from_pytorch(traced_script_module, shape_list)
+
+ # Model compilation by tvm
+ with tvm.transform.PassContext(opt_level=3):
+ lib = relay.build(mod, target=target, params=params)
+ elif format == "onnx":
+ if has_proj:
+ print(
+ "WARNING: torch.onnx.export does not support
conversion LSTM with projection "
+ "from pytorch! TODO: waiting for the support
and correct test after that."
+ )
+ continue
+ onnx_io = io.BytesIO()
+ with torch.no_grad():
+ h0 = torch.rand(input_hidden_shape)
+ if has_proj:
+ h0 = torch.rand(input_hidden_shape_with_proj)
+ c0 = torch.rand(input_hidden_shape)
+ input_names = ["input", "h0", "c0"]
+
+ # default export (without dynamic input)
+ torch.onnx.export(
+ model, (dummy_input, (h0, c0)), onnx_io, input_names=input_names
+ )
+ onnx_io.seek(0, 0)
+ onnx_model = onnx.load_model(onnx_io)
+
+ # Import model to Relay
+ shape_dict = {
+ "input": input_shape,
+ "h0": input_hidden_shape,
+ "c0": input_hidden_shape,
+ }
+ if has_proj:
+ shape_dict = {
+ "input": input_shape,
+ "h0": input_hidden_shape_with_proj,
+ "c0": input_hidden_shape,
+ }
+ mod, params = relay.frontend.from_onnx(onnx_model, shape_dict)
+
+ # Model compilation by tvm
+ with tvm.transform.PassContext(opt_level=1):
+ lib = relay.build(mod, target=target, params=params)
+
+ # Inference of the model with given input data
+ m = graph_executor.GraphModule(lib["default"](dev))
+
+ # Set inputs
+ m.set_input(
+ input=tvm.nd.array(dummy_input.numpy().astype(dtype)),
+ h0=tvm.nd.array(h_zeros),
+ c0=tvm.nd.array(c_zeros),
+ )
+ # Execute
+ m.run()
+ # Get outputs (converted to numpy array)
+ tvm_output = m.get_output(0).numpy()
+
+ compare(tvm_output, golden_output_batch)
+
+
+@tvm.testing.uses_gpu
+def test_lstms():
+ for target, dev in tvm.testing.enabled_targets():
+ check_lstm_with_type("uni", target, dev)
+ check_lstm_with_type("p", target, dev)
+ check_lstm_with_type("s", target, dev)
+ check_lstm_with_type("b", target, dev)
+ check_lstm_with_type("bp", target, dev)
+ check_lstm_with_type("sp", target, dev)
+ check_lstm_with_type("sb", target, dev)
+ check_lstm_with_type("sbp", target, dev)
+
+
+if __name__ == "__main__":
+ test_lstms()