ccjoechou commented on a change in pull request #9328:
URL: https://github.com/apache/tvm/pull/9328#discussion_r732356249



##########
File path: tests/python/relay/test_pass_convert_op_layout.py
##########
@@ -2039,5 +2039,337 @@ def expected():
         assert tvm.ir.structural_equal(a, b), "Actual = \n" + str(a)
 
 
+def test_conv_max_pool_uses_specified_convert_layout():
+    relay.op.get("nn.max_pool2d").reset_attr("FTVMConvertOpLayout")
+
+    @tvm.ir.register_op_attr("nn.max_pool2d", "FTVMConvertOpLayout")
+    def convert_maxpool2d(attrs, inputs, tinfos, desired_layouts):
+        # Convert both layout and out_layout to NHWC, the desired layout
+        #   passed to transforms.ConvertLayout() later in this test
+        new_attrs = dict(attrs)
+        new_attrs["layout"] = str(desired_layouts[0])
+        new_attrs["out_layout"] = str(desired_layouts[0])
+        return relay.nn.max_pool2d(*inputs, **new_attrs)

Review comment:
       As suggested, I moved the two convert functions (the one above for 
nn.max_pool2d and the one for nn.global_max_pool2d) from 
test_pass_convert_op_layout.py to _nn.py, registering them with the 
@reg.register_convert_op_layout() decorator. I also registered two more, for 
nn.avg_pool2d and nn.global_avg_pool2d.



