comaniac commented on a change in pull request #4741: [External codegen] Add 
test cases for fused ops with manual annotation
URL: https://github.com/apache/incubator-tvm/pull/4741#discussion_r368361331
 
 

 ##########
 File path: tests/python/relay/test_pass_partition_graph.py
 ##########
 @@ -425,10 +470,120 @@ def test_extern_dnnl_mobilenet():
                  (1, 1000), ref_res.asnumpy(), tol=1e-5, params=params)
 
 
+def test_partition_conv_bias_relu():
+    if not tvm.get_global_func("relay.ext.dnnl", True):
+        print("skip because DNNL codegen is not available")
+        return
+
+    def get_layers(prefix, data, in_channel, out_channel,
+                   include_bn=True, include_sigmoid=False):
+        weight = relay.var(prefix + "weight")
+        bn_gamma = relay.var(prefix + "bn_gamma")
+        bn_beta = relay.var(prefix + "bn_beta")
+        bn_mmean = relay.var(prefix + "bn_mean")
+        bn_mvar = relay.var(prefix + "bn_var")
+
+        layer = relay.nn.conv2d(data=data, weight=weight, kernel_size=(3, 3),
+                                channels=out_channel, padding=(1, 1))
+        if include_bn:
+            bn_output = relay.nn.batch_norm(layer, bn_gamma, bn_beta,
+                                            bn_mmean, bn_mvar)
+            layer = bn_output[0]
+        if include_sigmoid:
+            # dummy layer to prevent pattern detection
+            layer = relay.sigmoid(layer)
+        layer = relay.nn.relu(layer)
+        return layer
+
+    def get_net(include_bn=True, include_sigmoid=False):
+        data = relay.var("data", relay.TensorType((1, 3, 224, 224), "float32"))
+        layer1 = get_layers("layer1_", data, 3, 16, include_bn, 
include_sigmoid)
+        layer2 = get_layers("layer2_", layer1, 16, 16, include_bn, 
include_sigmoid)
+        last = layer2
+        return relay.Function(relay.analysis.free_vars(last), last)
+
+    def pre_optimize(mod, params):
+        remove_bn_pass = transform.Sequential([
+            relay.transform.InferType(),
+            relay.transform.SimplifyInference(),
+            relay.transform.FoldConstant(),
+            relay.transform.FoldScaleAxis(),
+        ])
+
+        if params != {}:
+            # This is required for constant folding
+            mod["main"] = bind_params_by_name(mod["main"], params)
+
+        with relay.build_config(opt_level=3, disabled_pass=["AlterOpLayout"]):
+            mod = remove_bn_pass(mod)
+
+        return mod
+
+    def get_partitoned_mod(mod):
+        mod["main"] = ConvBiasAddReLUAnnotator("dnnl").visit(mod["main"])
+        mod = transform.PartitionGraph()(mod)
+        return mod
+
+    def get_partitions(mod):
+        partitions = []
+
+        def visit_func(expr):
+            if isinstance(expr, _expr.Function) and expr != mod["main"]:
+                partitions.append(expr)
+        analysis.post_order_visit(mod["main"], visit_func)
+        return partitions
+
+    def test_detect_pattern(include_bn, include_sigmoid, 
num_expected_partition):
+        net = get_net(include_bn, include_sigmoid)
+        mod, params = tvm.relay.testing.create_workload(net)
+        mod = pre_optimize(mod, params)
+        mod = get_partitoned_mod(mod)
+        assert(len(get_partitions(mod)) == num_expected_partition)
+
+    def test_partition():
+        # conv + bn + relu -> detection succeed
+        test_detect_pattern(True, False, 2)
+        # conv + relu -> fail
+        test_detect_pattern(False, False, 0)
+        # conv + bn + sigmoid + relu -> fail
+        test_detect_pattern(True, True, 0)
+
+    def test_partition_mobilenet():
+        mod, params = relay.testing.mobilenet.get_workload()
+        mod = pre_optimize(mod, params)
+        mod = get_partitoned_mod(mod)
+        assert(len(get_partitions(mod)) == 27)
+
+    def test_exec(mod, params, ref_mod, ref_params, out_shape):
+        ishape = (1, 3, 224, 224)
+        i_data = np.random.randn(*ishape).astype(np.float32)
+        ref_ex = relay.create_executor("graph", mod=ref_mod, ctx=tvm.cpu(0))
+        ref_res = ref_ex.evaluate()(i_data, **ref_params)
 
 Review comment:
  It's better to add `compile_engine.get().clear()` here. Otherwise, because the 
ops in `mod` and `ref_mod` are identical, the compile engine may serve them from 
its cache, so the module might not actually be run with DNNL in `check_result`.

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
[email protected]


With regards,
Apache Git Services

Reply via email to