This is an automated email from the ASF dual-hosted git repository.
masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new 082874c [Torch][Quantized] Fix converting serialized quantized models (#5839)
082874c is described below
commit 082874c51f728d8ff12a9cd2eed4d2734e71eb8f
Author: masahi <[email protected]>
AuthorDate: Fri Jun 19 01:24:03 2020 +0900
[Torch][Quantized] Fix converting serialized quantized models (#5839)
* [Torch] Fix converting serialized quantized models
* clean up dtype check
* comment clean up
---
python/tvm/relay/frontend/pytorch.py | 42 +++++++++++++++++------------
tests/python/frontend/pytorch/qnn_test.py | 45 ++++++++++++++++++++++++++++---
2 files changed, 67 insertions(+), 20 deletions(-)
diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index d2451cd..d3b6510 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -115,6 +115,14 @@ def _should_construct_dynamic_list(list_construct_node):
return False
+def _is_quantized_tensor(data, prelude):
+    # If a quantized Torch module is saved and loaded back, dtype will be dropped
+ # Since dtypes from Torch tensors are not reliable in such cases, we use
+ # Relay's type inference result to decide if an input tensor is quantized
+ ty = _infer_type_with_prelude(data, prelude)
+ return ty.dtype == "uint8"
+
+
# operator implementation
def _elemwise(name):
def _impl(inputs, input_types):
@@ -530,10 +538,10 @@ def _linspace():
return _impl
-def _relu():
+def _relu(prelude):
def _impl(inputs, input_types):
data = inputs[0]
- if input_types[0] == "quint8":
+ if _is_quantized_tensor(data, prelude):
assert len(inputs) == 3, "Input quant param not found in op inputs"
input_zero_point = _expr.const(inputs[2], dtype="int32")
return qnn_torch.quantized_relu(data, input_zero_point)
@@ -595,7 +603,7 @@ def _log_sigmoid():
return _op.log(_op.tensor.sigmoid(data))
return _impl
-def _adaptive_avg_pool_2d():
+def _adaptive_avg_pool_2d(prelude):
def _impl(inputs, input_types):
data = inputs[0]
output_size = _infer_shape(inputs[1])
@@ -603,7 +611,7 @@ def _adaptive_avg_pool_2d():
def func(x):
return _op.nn.adaptive_avg_pool2d(x, output_size=output_size)
- if input_types[0] == "quint8":
+ if _is_quantized_tensor(data, prelude):
return qnn_torch.apply_with_upcast(data, func)
return func(data)
@@ -1108,7 +1116,7 @@ def _softplus():
return _op.log(_op.exp(inputs[0] * beta) + _expr.const(1.)) / beta
return _impl
-def _avg_pool2d():
+def _avg_pool2d(prelude):
def _impl(inputs, input_types):
data = inputs[0]
@@ -1130,7 +1138,7 @@ def _avg_pool2d():
ceil_mode=ceil_mode,
count_include_pad=count_include_pad)
- if input_types[0] == "quint8":
+ if _is_quantized_tensor(data, prelude):
return qnn_torch.apply_with_upcast(data, func)
return func(data)
@@ -1254,7 +1262,7 @@ def _variance():
return _impl
-def _mean():
+def _mean(prelude):
def _impl(inputs, input_types):
data = inputs[0]
@@ -1274,7 +1282,7 @@ def _mean():
def func(x):
return _op.mean(x, axis, keepdims, exclude)
- if input_types[0] == "quint8":
+ if _is_quantized_tensor(data, prelude):
assert len(inputs) == 6, "Input quant param not found in op inputs"
input_scale = _expr.const(inputs[4])
input_zero_point = _expr.const(inputs[5])
@@ -1492,7 +1500,7 @@ def _to():
return _impl
-def _upsample(method):
+def _upsample(method, prelude):
def _impl(inputs, input_types):
if isinstance(inputs[1], _expr.Var):
out_size = _infer_shape(inputs[1])
@@ -1516,7 +1524,7 @@ def _upsample(method):
def func(x):
return _op.image.resize(x, out_size, "NCHW", method, coord_trans)
- if input_types[0] == "quint8":
+ if _is_quantized_tensor(data, prelude):
import torch
from packaging import version
@@ -1835,8 +1843,8 @@ def _get_convert_map(prelude):
"aten::take" : _take(),
"aten::where" : _where(),
"aten::topk" : _topk(),
- "aten::relu" : _relu(),
- "aten::relu_" : _relu(),
+ "aten::relu" : _relu(prelude),
+ "aten::relu_" : _relu(prelude),
"aten::prelu" : _prelu(),
"aten::leaky_relu" : _leaky_relu(),
"aten::elu" : _elu(),
@@ -1845,7 +1853,7 @@ def _get_convert_map(prelude):
"aten::gelu" : _gelu(),
"aten::selu" : _selu(),
"aten::log_sigmoid" : _log_sigmoid(),
- "aten::adaptive_avg_pool2d" : _adaptive_avg_pool_2d(),
+ "aten::adaptive_avg_pool2d" :
_adaptive_avg_pool_2d(prelude),
"aten::adaptive_max_pool2d" : _adaptive_max_pool_2d(),
"aten::max_pool2d" : _maxpool_2d(),
"aten::max_pool2d_with_indices" : _maxpool_2d_with_indices(),
@@ -1874,13 +1882,13 @@ def _get_convert_map(prelude):
"aten::log_softmax" : _log_softmax(),
"aten::sigmoid" : _sigmoid(),
"aten::softplus" : _softplus(),
- "aten::avg_pool2d" : _avg_pool2d(),
+ "aten::avg_pool2d" : _avg_pool2d(prelude),
"aten::avg_pool3d" : _avg_pool3d(),
"aten::dropout" : _dropout(),
"aten::dropout_" : _dropout(),
"aten::feature_dropout" : _dropout(),
"aten::alpha_dropout" : _dropout(),
- "aten::mean" : _mean(),
+ "aten::mean" : _mean(prelude),
"aten::chunk" : _chunk(prelude),
"aten::matmul" : _matmul(prelude),
"aten::expand" : _expand(),
@@ -1932,8 +1940,8 @@ def _get_convert_map(prelude):
"aten::isnan" : _unary("isnan"),
"aten::clamp" : _clamp(),
"aten::detach" : _identity(),
- "aten::upsample_bilinear2d" : _upsample("bilinear"),
- "aten::upsample_nearest2d" :
_upsample("nearest_neighbor"),
+ "aten::upsample_bilinear2d" : _upsample("bilinear",
prelude),
+ "aten::upsample_nearest2d" :
_upsample("nearest_neighbor", prelude),
"aten::upsample_trilinear3d" : _upsample3d("trilinear"),
"aten::upsample_nearest3d" :
_upsample3d("nearest_neighbor"),
"aten::expand_as" : _expand_as(),
diff --git a/tests/python/frontend/pytorch/qnn_test.py b/tests/python/frontend/pytorch/qnn_test.py
index 551cdc4..8c6c248 100644
--- a/tests/python/frontend/pytorch/qnn_test.py
+++ b/tests/python/frontend/pytorch/qnn_test.py
@@ -63,7 +63,7 @@ def get_qconfig(per_channel):
weight=default_weight_observer)
-def quantize_model(model, inp, per_channel=False, dummy=True):
+def quantize_model(model, inp, per_channel=False):
model.fuse_model()
model.qconfig = get_qconfig(per_channel)
torch.quantization.prepare(model, inplace=True)
@@ -243,6 +243,18 @@ class AvgPool2d(nn.Module):
pass
+class AdaptiveAvgPool2d(nn.Module):
+ def __init__(self):
+ super().__init__()
+ self.pool = QuantWrapper(nn.AdaptiveAvgPool2d((1, 1)))
+
+ def forward(self, x):
+ return self.pool(x)
+
+ def fuse_model(self):
+ pass
+
+
def test_quantized_modules():
imagenet_ishape = (1, 3, 224, 224)
@@ -280,7 +292,7 @@ def test_quantized_modules():
raw_module.eval()
inp = torch.rand(ishape)
- quantize_model(raw_module, inp, per_channel=per_channel, dummy=True)
+ quantize_model(raw_module, inp, per_channel=per_channel)
script_module = torch.jit.trace(raw_module, inp).eval()
with torch.no_grad():
@@ -376,7 +388,7 @@ def test_quantized_imagenet():
inp = get_imagenet_input()
pt_inp = torch.from_numpy(inp)
- quantize_model(raw_model, pt_inp, per_channel=per_channel, dummy=False)
+ quantize_model(raw_model, pt_inp, per_channel=per_channel)
script_module = torch.jit.trace(raw_model, pt_inp).eval()
with torch.no_grad():
@@ -465,3 +477,30 @@ def test_quantized_imagenet():
mean abs_diff: 0.054197952
558 in 1000 raw outputs identical.
"""
+
+
+def test_serialized_modules():
+ ishape = (1, 16, 64, 64)
+ raw_module = AdaptiveAvgPool2d().eval()
+ inp = torch.rand(ishape)
+
+ quantize_model(raw_module, inp)
+ script_module = torch.jit.trace(raw_module, inp).eval()
+
+ fname = "tmp.pt"
+ torch.jit.save(script_module, fname)
+ loaded = torch.jit.load(fname)
+ os.remove(fname)
+
+ with torch.no_grad():
+ pt_result = loaded(inp.clone()).numpy()
+
+ input_name = "input"
+ runtime = get_tvm_runtime(loaded, input_name, ishape)
+ runtime.set_input(input_name, inp.numpy().copy())
+ runtime.run()
+ tvm_result = runtime.get_output(0).asnumpy()
+
+ num_identical = np.sum(tvm_result == pt_result)
+ match_ratio = num_identical / float(np.prod(tvm_result.shape))
+ assert match_ratio > 0.2