This is an automated email from the ASF dual-hosted git repository.
anijain2305 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0778afd Use channels from attrs if possible (#7011)
0778afd is described below
commit 0778afd6d0fb0283fba5d4839f27e2ac548a3284
Author: Trevor Morris <[email protected]>
AuthorDate: Tue Dec 1 22:04:43 2020 -0800
Use channels from attrs if possible (#7011)
---
src/runtime/contrib/tensorrt/tensorrt_ops.cc | 4 ++++
tests/python/contrib/test_tensorrt.py | 5 +++++
2 files changed, 9 insertions(+)
diff --git a/src/runtime/contrib/tensorrt/tensorrt_ops.cc b/src/runtime/contrib/tensorrt/tensorrt_ops.cc
index 057743c..c3ff1c4 100644
--- a/src/runtime/contrib/tensorrt/tensorrt_ops.cc
+++ b/src/runtime/contrib/tensorrt/tensorrt_ops.cc
@@ -243,6 +243,10 @@ class Conv2DOpConverter : public TensorRTOpConverter {
auto str_padding =
params->node.GetAttr<std::vector<std::string>>("padding");
int groups =
std::stoi(params->node.GetAttr<std::vector<std::string>>("groups")[0]);
int channels = weight_shape[0];
+ if (params->node.HasAttr("channels") &&
+
!params->node.GetAttr<std::vector<std::string>>("channels")[0].empty()) {
+ channels =
std::stoi(params->node.GetAttr<std::vector<std::string>>("channels")[0]);
+ }
// TRT conv2d op doesn't support asymmetric padding before 5.1, so we
// workaround by adding a padding layer before the pooling op.
nvinfer1::DimsHW prepadding, postpadding;
diff --git a/tests/python/contrib/test_tensorrt.py b/tests/python/contrib/test_tensorrt.py
index 10c311a..de98222 100644
--- a/tests/python/contrib/test_tensorrt.py
+++ b/tests/python/contrib/test_tensorrt.py
@@ -352,6 +352,7 @@ def test_conv2d():
padding=(0, 0),
strides=(1, 1),
dilation=(1, 1),
+ channels=None,
):
x = relay.var("x", shape=(x_shape), dtype="float32")
kernel = relay.var("kernel", shape=(k_shape), dtype="float32")
@@ -363,6 +364,7 @@ def test_conv2d():
padding=padding,
strides=strides,
dilation=dilation,
+ channels=channels,
)
f = relay.Function([x, kernel], out)
return f, {"x": x_shape, "kernel": k_shape}, ["kernel"]
@@ -380,6 +382,9 @@ def test_conv2d():
dilation=dilation,
)
)
+ run_and_verify_func(
+ get_graph((1, 3, 16, 16), (3, 8, 7, 7), 3, [2, 2, 3, 3], [2, 2], [1,
1], 24)
+ )
def test_conv2d_nhwc():