This is an automated email from the ASF dual-hosted git repository.

syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 889fc6b27d [Marvell BYOC]: global_max_pool2d and squeeze op support (#17481)
889fc6b27d is described below

commit 889fc6b27d200bf38a03bef532e046a7a977d136
Author: PRINCE KUMAR <[email protected]>
AuthorDate: Thu Oct 24 14:17:32 2024 +0530

    [Marvell BYOC]: global_max_pool2d and squeeze op support (#17481)
    
    Co-authored-by: princek <[email protected]>
---
 python/tvm/relay/op/contrib/mrvl.py         |  54 +++++++++++++-
 src/relay/backend/contrib/mrvl/codegen.cc   | 102 ++++++++++++++++++++++++++
 tests/python/contrib/test_mrvl/test_mrvl.py | 108 ++++++++++++++++++++++++++++
 3 files changed, 263 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relay/op/contrib/mrvl.py b/python/tvm/relay/op/contrib/mrvl.py
index 75041fbc8c..b13cf3d953 100644
--- a/python/tvm/relay/op/contrib/mrvl.py
+++ b/python/tvm/relay/op/contrib/mrvl.py
@@ -535,7 +535,6 @@ def mrvl_pattern_table():
 
     def globalavgpool2d_pattern():
         """Create a globalavgpool2d pattern.
-           review tvm/tests/python/relay/test_dataflow_pattern.py for examples
         Returns
         -------
         pattern : dataflow_pattern.AltPattern
@@ -544,6 +543,17 @@ def mrvl_pattern_table():
         pattern = is_op("nn.global_avg_pool2d")(wildcard())
         return pattern
 
+    def globalmaxpool2d_pattern():
+        """Create a globalmaxpool2d pattern.
+           review tvm/tests/python/relay/test_dataflow_pattern.py for examples
+        Returns
+        -------
+        pattern : dataflow_pattern.AltPattern
+            Denotes the globalmaxpool2d pattern.
+        """
+        pattern = is_op("nn.global_max_pool2d")(wildcard())
+        return pattern
+
     def reshape_pattern():
         pattern = is_op("reshape")(wildcard())
         return pattern
@@ -552,6 +562,10 @@ def mrvl_pattern_table():
         pattern = is_op("nn.batch_flatten")(wildcard())
         return pattern
 
+    def squeeze_pattern():
+        pattern = is_op("squeeze")(wildcard())
+        return pattern
+
     def layout_transform_nchw2nhwc_pattern():
         pattern = is_op("layout_transform")(is_var(), wildcard(), wildcard()).has_attr(
             {"src_layout": "NCHW", "dst_layout": "NHWC"}
@@ -596,6 +610,13 @@ def mrvl_pattern_table():
             call = call.args[0]
         return globalavgpool2d_nhwc2nhwc(call)
 
+    def check_globalmaxpool2d(extract):
+        """Check globalmaxpool2d pattern is supported by Mrvl."""
+        call = extract
+        while call.op.name != "nn.global_max_pool2d":
+            call = call.args[0]
+        return globalmaxpool2d_nhwc2nhwc(call)
+
     def check_reshape(extract):
         call = extract
         while call.op.name != "reshape":
@@ -608,6 +629,12 @@ def mrvl_pattern_table():
             call = call.args[0]
         return batch_flatten_mrvl(call)
 
+    def check_squeeze(extract):
+        call = extract
+        while call.op.name != "squeeze":
+            call = call.args[0]
+        return squeeze_mrvl(call)
+
     def check_layout_transform_nchw2nhwc(extract):
         call = extract
         while call.op.name != "layout_transform":
@@ -634,6 +661,7 @@ def mrvl_pattern_table():
         ("mrvl.maxpool2d_nhwc2nhwc", maxpool2d_pattern(), check_maxpool2d),
         ("mrvl.avgpool2d_nhwc2nhwc", avgpool2d_pattern(), check_avgpool2d),
         ("mrvl.globalavgpool2d_nhwc2nhwc", globalavgpool2d_pattern(), check_globalavgpool2d),
+        ("mrvl.globalmaxpool2d_nhwc2nhwc", globalmaxpool2d_pattern(), check_globalmaxpool2d),
         ("mrvl.sum", sum_pattern(), check_sum),
         ("mrvl.concat", concat_pattern(), check_concat),
         (
@@ -643,6 +671,7 @@ def mrvl_pattern_table():
         ),
         ("mrvl.reshape", reshape_pattern(), check_reshape),
         ("mrvl.batch_flatten", batch_flatten_pattern(), check_batch_flatten),
+        ("mrvl.squeeze", squeeze_pattern(), check_squeeze),
     ]
 
 
@@ -813,6 +842,21 @@ def globalavgpool2d_nhwc2nhwc(expr):
     return True
 
 
+# register a helper function to indicate that the given operator can be supported by Mrvl.
+@tvm.ir.register_op_attr("nn.global_max_pool2d", "target.mrvl")
+def globalmaxpool2d_nhwc2nhwc(expr):
+    """Check if the external Mrvl codegen for globalmaxpool2d_nhwc2nhwc should be used."""
+    attrs, args = expr.attrs, expr.args
+    if attrs.layout != "NHWC":
+        return False
+    data_type = args[0].checked_type
+    if not (len(data_type.shape) == 4 or len(data_type.shape) == 2):
+        return False
+    if (len(data_type.shape) != 4) or (data_type.dtype not in ["float32"]):
+        return False
+    return True
+
+
 @tvm.ir.register_op_attr("reshape", "target.mrvl")
 def reshape_mrvl(expr):
     """Check if the external Mrvl codegen for reshape should be used."""
@@ -846,6 +890,14 @@ def batch_flatten_mrvl(expr):
         return True
 
 
+@tvm.ir.register_op_attr("squeeze", "target.mrvl")
+def squeeze_mrvl(expr):
+    """Check if the external Mrvl codegen for squeeze should be used."""
+    if expr.op.name != "squeeze":
+        return False
+    return True
+
+
 # register a helper function to indicate that the given operator can be supported by Mrvl.
 @tvm.ir.register_op_attr("layout_transform", "target.mrvl")
 def layout_transform_nchw2nhwc(expr):
diff --git a/src/relay/backend/contrib/mrvl/codegen.cc b/src/relay/backend/contrib/mrvl/codegen.cc
index 6d7e593b9b..96121e4b4b 100644
--- a/src/relay/backend/contrib/mrvl/codegen.cc
+++ b/src/relay/backend/contrib/mrvl/codegen.cc
@@ -225,6 +225,13 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer {
     const CallNode* batch_flatten = nullptr;
   };
 
+  /*!
+   * \brief A series of operators that form a Squeeze node.
+   */
+  struct CompositeSqueezeNode {
+    const CallNode* squeeze = nullptr;
+  };
+
   /*!
    * \brief A series of operators that form a composite
    * fc layer. Supports both nn.fc_ni2no and qnn.fc_ni2no.
@@ -278,6 +285,8 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer {
       json_kernel_node = CreateCompositeMrvlAvgpool2DLayer(cn);
     } else if (name == "mrvl.globalavgpool2d_nhwc2nhwc") {
       json_kernel_node = CreateCompositeMrvlGlobalAvgpool2DLayer(cn);
+    } else if (name == "mrvl.globalmaxpool2d_nhwc2nhwc") {
+      json_kernel_node = CreateCompositeMrvlGlobalMaxpool2DLayer(cn);
     } else if (name == "mrvl.sum") {
       json_kernel_node = CreateCompositeMrvlSumLayer(cn);
     } else if (name == "mrvl.concat") {
@@ -286,6 +295,8 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer {
       json_kernel_node = CreateMrvlReshapeLayer(cn);
     } else if (name == "mrvl.batch_flatten") {
       json_kernel_node = CreateMrvlBatchFlattenLayer(cn);
+    } else if (name == "mrvl.squeeze") {
+      json_kernel_node = CreateMrvlSqueezeLayer(cn);
     } else {
       LOG(FATAL) << "Unrecognized Mrvl pattern: " << name;
     }
@@ -511,6 +522,22 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer {
     return nodes;
   }
 
+  /*!
+   * \brief Extract squeeze nodes from a composite function.
+   * \param call The call node of the composite function.
+   * \return Extracted composite squeeze nodes.
+   */
+  CompositeSqueezeNode UnpackCompositeSqueeze(const CallNode* call) {
+    CompositeSqueezeNode nodes{};
+    const auto* fn = call->op.as<FunctionNode>();
+    ICHECK(fn) << "Marvell-Compiler-ERROR-Internal::Downcast to FunctionNode failed.";
+    const auto* current_call = fn->body.as<CallNode>();
+    ICHECK(backend::IsOp(current_call, "squeeze"))
+        << "Marvell-Compiler-ERROR-Internal::squeeze missing.";
+    nodes.squeeze = current_call;
+    return nodes;
+  }
+
   /*!
    * \brief Extract maxpool nodes from a composite function.
    *
@@ -533,6 +560,11 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer {
           << "Marvell-Compiler-ERROR-Internal::nn.avg_pool2d Op missing.";
       ICHECK(backend::IsOp(current_call, "nn.avg_pool2d"))
           << "Marvell-Compiler-ERROR-Internal::nn.avg_pool2d Op missing.";
+    } else if (mrvlLayerName == "GlobalMaxpool2D") {
+      ICHECK(mrvlLayerName == "GlobalMaxpool2D")
+          << "Marvell-Compiler-ERROR-Internal::nn.global_max_pool2d Op missing.";
+      ICHECK(backend::IsOp(current_call, "nn.global_max_pool2d"))
+          << "Marvell-Compiler-ERROR-Internal::nn.global_max_pool2d Op missing.";
     } else {
       ICHECK(mrvlLayerName == "GlobalAvgpool2D")
           << "Marvell-Compiler-ERROR-Internal::nn.global_avg_pool2d Op missing.";
@@ -1115,6 +1147,34 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer {
     return json_node;
   }
 
+  /*!
+   * \brief Create a JSON representation of a composite Squeeze.
+   *
+   * \param cn The call to be represented.
+   * \return A JSON representation of a specific operator.
+   */
+  std::shared_ptr<JSONGraphNode> CreateMrvlSqueezeLayer(const CallNode* cn) {
+    CompositeSqueezeNode nodes = UnpackCompositeSqueeze(cn);
+    std::vector<JSONGraphNodeEntry> inputs;
+    std::string name = "squeeze";
+    inputs.push_back(VisitExpr(cn->args[0])[0]);
+    std::vector<int64_t> layout_vec;
+    GetInputTensorShapeViaArgN(nodes.squeeze, &layout_vec);
+    std::string data_layout;
+    if (layout_vec.size() == 4) {
+      data_layout = "NHWC";
+    } else {
+      data_layout = "NC";
+    }
+    layout_vec.clear();
+    std::string out_layout = "NC";
+    auto json_node = std::make_shared<JSONGraphNode>(name, "kernel", inputs, 1);
+    SetMrvlLayerCommonAttrs(json_node, cn, layer_name_, name, data_layout,
+                            "" /* no kernel_layout */, out_layout);
+    SetMrvlQuantAttrs(json_node, nodes.instrument_1, "1");
+    return json_node;
+  }
+
   /*!
    * \brief Create a JSON representation of a composite concat.
    *
@@ -1304,6 +1364,48 @@ class MrvlJSONSerializer : public backend::contrib::JSONSerializer {
     return json_node;
   }
 
+  /*!
+   * \brief Create a JSON representation of a composite globalmaxpooling operator.
+   *
+   * A composite function is only created when using the uint8 datatype for these operators.
+   *
+   * \param cn The call to be represented.
+   * \return A JSON representation of a specific operator.
+   */
+  std::shared_ptr<JSONGraphNode> CreateCompositeMrvlGlobalMaxpool2DLayer(const CallNode* cn) {
+    std::string mrvlLayerName = "GlobalMaxpool2D";
+    std::string name = "nn.globalmaxpool2d_nhwc2nhwc";
+    CompositePoolNode nodes = UnpackCompositePool(cn, mrvlLayerName);
+
+    const auto* globalmaxpool_attr = nodes.pool->attrs.as<GlobalPool2DAttrs>();
+    ICHECK(globalmaxpool_attr)
+        << "Marvell-Compiler-ERROR-Internal::Downcast to GlobalPool2DAttrs failed.";
+    ICHECK(globalmaxpool_attr->layout == "NHWC")
+        << "Marvell-Compiler-ERROR-Internal::"
+        << "Layout must be NHWC, has the module been pre-processed correctly?";
+
+    std::string data_layout = globalmaxpool_attr->layout;
+    std::string out_layout = globalmaxpool_attr->layout;
+    std::vector<JSONGraphNodeEntry> inputs;
+    std::vector<int64_t> kernel_layout_vec;
+    std::vector<int64_t> data_layout_vec;
+    GetInputTensorShapeViaArgN(cn, &data_layout_vec);
+    ICHECK(data_layout_vec.size() == 4);
+    kernel_layout_vec.push_back(data_layout_vec[1]);
+    kernel_layout_vec.push_back(data_layout_vec[2]);
+    inputs.push_back(VisitExpr(cn->args[0])[0]);
+
+    // op_type_ is "kernel"
+    auto json_node = std::make_shared<JSONGraphNode>(name, "kernel", inputs, 1);
+    SetCallNodeAttribute(json_node, nodes.pool);
+    JsonNodeSetVecAttr(json_node, "kernel_layout_shape", kernel_layout_vec);
+    if (nodes.pad) SetMrvlLayerPadAttrs(json_node, nodes.pad);
+
+    SetMrvlLayerCommonAttrs(json_node, cn, layer_name_, mrvlLayerName, data_layout, "HW",
+                            out_layout);
+    return json_node;
+  }
+
   /*!
    * \brief Create a JSON representation of an OpNode layer.
    *
diff --git a/tests/python/contrib/test_mrvl/test_mrvl.py b/tests/python/contrib/test_mrvl/test_mrvl.py
index 26956c97c5..cd3f343c2d 100644
--- a/tests/python/contrib/test_mrvl/test_mrvl.py
+++ b/tests/python/contrib/test_mrvl/test_mrvl.py
@@ -181,7 +181,115 @@ def test_dense():
     run_and_verify_func(get_graph())
 
 
+@requires_mrvl
+def test_maxpool2d():
+    """Test maxpool2d operator for "mrvl" targets"""
+
+    def get_graph():
+        x = relay.var("x", shape=(1, 3, 224, 224))
+        arr = np.random.rand(16, 3, 3, 3).astype("float32")
+        w = relay.const(arr)
+        y = relay.nn.conv2d(x, w, strides=[2, 2], padding=[1, 1, 1, 1], kernel_size=[3, 3])
+        y = relay.nn.max_pool2d(y)
+        func = relay.Function([x], y)
+        mod = tvm.IRModule()
+        mod["main"] = func
+        option_dict = {"num_tiles": 1}
+        verify_codegen(mod, params={}, tvm_ops=1, contains="mrvl.maxpool2d_nhwc2nhwc")
+        return func, {"x": (1, 3, 224, 224)}, [], option_dict
+
+    run_and_verify_func(get_graph())
+
+
+@requires_mrvl
+def test_avgpool2d():
+    """Test avgpool2d operator for "mrvl" targets"""
+
+    def get_graph():
+        x = relay.var("x", shape=(1, 3, 224, 224))
+        arr = np.random.rand(16, 3, 3, 3).astype("float32")
+        w = relay.const(arr)
+        y = relay.nn.conv2d(x, w, strides=[2, 2], padding=[1, 1, 1, 1], kernel_size=[3, 3])
+        y = relay.nn.avg_pool2d(y)
+        func = relay.Function([x], y)
+        mod = tvm.IRModule()
+        mod["main"] = func
+        option_dict = {"num_tiles": 1}
+        verify_codegen(mod, params={}, tvm_ops=1, contains="mrvl.avgpool2d_nhwc2nhwc")
+        return func, {"x": (1, 3, 224, 224)}, [], option_dict
+
+    run_and_verify_func(get_graph())
+
+
+@requires_mrvl
+def test_globalavgpool2d():
+    """Test globalavgpool2d operator for "mrvl" targets"""
+
+    def get_graph():
+        x = relay.var("x", shape=(1, 3, 224, 224))
+        arr = np.random.rand(16, 3, 3, 3).astype("float32")
+        w = relay.const(arr)
+        y = relay.nn.conv2d(x, w, strides=[2, 2], padding=[1, 1, 1, 1], kernel_size=[3, 3])
+        y = relay.nn.global_avg_pool2d(y)
+        func = relay.Function([x], y)
+        mod = tvm.IRModule()
+        mod["main"] = func
+        option_dict = {"num_tiles": 1}
+        verify_codegen(mod, params={}, tvm_ops=1, contains="mrvl.globalavgpool2d_nhwc2nhwc")
+        return func, {"x": (1, 3, 224, 224)}, [], option_dict
+
+    run_and_verify_func(get_graph())
+
+
+@requires_mrvl
+def test_globalmaxpool2d():
+    """Test globalmaxpool2d operator for "mrvl" targets"""
+
+    def get_graph():
+        x = relay.var("x", shape=(1, 3, 224, 224))
+        arr = np.random.rand(16, 3, 3, 3).astype("float32")
+        w = relay.const(arr)
+        y = relay.nn.conv2d(x, w, strides=[2, 2], padding=[1, 1, 1, 1], kernel_size=[3, 3])
+        y = relay.nn.global_max_pool2d(y)
+        func = relay.Function([x], y)
+        params = {}
+        params["w"] = arr
+        mod = tvm.IRModule()
+        mod["main"] = func
+        option_dict = {"num_tiles": 1}
+        verify_codegen(mod, params=params, tvm_ops=2, contains="mrvl.globalmaxpool2d_nhwc2nhwc")
+        return func, {"x": (1, 3, 224, 224), "w": (16, 3, 3, 3)}, ["w"], option_dict
+
+    run_and_verify_func(get_graph())
+
+
+@requires_mrvl
+def test_squeeze():
+    """Test squeeze operator for "mrvl" targets"""
+
+    def get_graph():
+        x = relay.var("x", shape=(1, 3, 224, 224))
+        arr = np.random.rand(16, 3, 3, 3).astype("float32")
+        w = relay.const(arr)
+        y = relay.nn.conv2d(x, w, strides=[2, 2], padding=[1, 1, 1, 1], kernel_size=[3, 3])
+        y = relay.reshape(y, newshape=(1, 1, 16, 112, 112))
+        y = relay.squeeze(y, axis=[0, 1])
+        func = relay.Function([x], y)
+        mod = tvm.IRModule()
+        mod["main"] = func
+        option_dict = {"num_tiles": 1}
+        verify_codegen(mod, params={}, tvm_ops=3, contains="mrvl.squeeze")
+        return func, {"x": (1, 3, 224, 224)}, [], option_dict
+
+    run_and_verify_func(get_graph())
+
+
 if __name__ == "__main__":
     test_mrvl_fuse()
     test_conv2d()
     test_dense()
+    test_maxpool2d()
+    test_avgpool2d()
+    test_globalavgpool2d()
+    test_globalmaxpool2d()
+    test_squeeze()

Reply via email to