This is an automated email from the ASF dual-hosted git repository.
syfeng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 889fc6b27d [Marvell BYOC]: global_max_pool2d and squeeze op support
(#17481)
889fc6b27d is described below
commit 889fc6b27d200bf38a03bef532e046a7a977d136
Author: PRINCE KUMAR <[email protected]>
AuthorDate: Thu Oct 24 14:17:32 2024 +0530
[Marvell BYOC]: global_max_pool2d and squeeze op support (#17481)
Co-authored-by: princek <[email protected]>
---
python/tvm/relay/op/contrib/mrvl.py | 54 +++++++++++++-
src/relay/backend/contrib/mrvl/codegen.cc | 102 ++++++++++++++++++++++++++
tests/python/contrib/test_mrvl/test_mrvl.py | 108 ++++++++++++++++++++++++++++
3 files changed, 263 insertions(+), 1 deletion(-)
diff --git a/python/tvm/relay/op/contrib/mrvl.py
b/python/tvm/relay/op/contrib/mrvl.py
index 75041fbc8c..b13cf3d953 100644
--- a/python/tvm/relay/op/contrib/mrvl.py
+++ b/python/tvm/relay/op/contrib/mrvl.py
@@ -535,7 +535,6 @@ def mrvl_pattern_table():
def globalavgpool2d_pattern():
"""Create a globalavgpool2d pattern.
- review tvm/tests/python/relay/test_dataflow_pattern.py for examples
Returns
-------
pattern : dataflow_pattern.AltPattern
@@ -544,6 +543,17 @@ def mrvl_pattern_table():
pattern = is_op("nn.global_avg_pool2d")(wildcard())
return pattern
+ def globalmaxpool2d_pattern():
+ """Create a globalmaxpool2d pattern.
+ review tvm/tests/python/relay/test_dataflow_pattern.py for examples
+ Returns
+ -------
+ pattern : dataflow_pattern.AltPattern
+ Denotes the globalmaxpool2d pattern.
+ """
+ pattern = is_op("nn.global_max_pool2d")(wildcard())
+ return pattern
+
def reshape_pattern():
pattern = is_op("reshape")(wildcard())
return pattern
@@ -552,6 +562,10 @@ def mrvl_pattern_table():
pattern = is_op("nn.batch_flatten")(wildcard())
return pattern
+ def squeeze_pattern():
+ pattern = is_op("squeeze")(wildcard())
+ return pattern
+
def layout_transform_nchw2nhwc_pattern():
pattern = is_op("layout_transform")(is_var(), wildcard(),
wildcard()).has_attr(
{"src_layout": "NCHW", "dst_layout": "NHWC"}
@@ -596,6 +610,13 @@ def mrvl_pattern_table():
call = call.args[0]
return globalavgpool2d_nhwc2nhwc(call)
+ def check_globalmaxpool2d(extract):
+ """Check globalmaxpool2d pattern is supported by Mrvl."""
+ call = extract
+ while call.op.name != "nn.global_max_pool2d":
+ call = call.args[0]
+ return globalmaxpool2d_nhwc2nhwc(call)
+
def check_reshape(extract):
call = extract
while call.op.name != "reshape":
@@ -608,6 +629,12 @@ def mrvl_pattern_table():
call = call.args[0]
return batch_flatten_mrvl(call)
+ def check_squeeze(extract):
+ call = extract
+ while call.op.name != "squeeze":
+ call = call.args[0]
+ return squeeze_mrvl(call)
+
def check_layout_transform_nchw2nhwc(extract):
call = extract
while call.op.name != "layout_transform":
@@ -634,6 +661,7 @@ def mrvl_pattern_table():
("mrvl.maxpool2d_nhwc2nhwc", maxpool2d_pattern(), check_maxpool2d),
("mrvl.avgpool2d_nhwc2nhwc", avgpool2d_pattern(), check_avgpool2d),
("mrvl.globalavgpool2d_nhwc2nhwc", globalavgpool2d_pattern(),
check_globalavgpool2d),
+ ("mrvl.globalmaxpool2d_nhwc2nhwc", globalmaxpool2d_pattern(),
check_globalmaxpool2d),
("mrvl.sum", sum_pattern(), check_sum),
("mrvl.concat", concat_pattern(), check_concat),
(
@@ -643,6 +671,7 @@ def mrvl_pattern_table():
),
("mrvl.reshape", reshape_pattern(), check_reshape),
("mrvl.batch_flatten", batch_flatten_pattern(), check_batch_flatten),
+ ("mrvl.squeeze", squeeze_pattern(), check_squeeze),
]
@@ -813,6 +842,21 @@ def globalavgpool2d_nhwc2nhwc(expr):
return True
+# register a helper function to indicate that the given operator can be
supported by Mrvl.
+@tvm.ir.register_op_attr("nn.global_max_pool2d", "target.mrvl")
+def globalmaxpool2d_nhwc2nhwc(expr):
+ """Check if the external Mrvl codegen for globalmaxpool2d_nhwc2nhwc should
be used."""
+ attrs, args = expr.attrs, expr.args
+ if attrs.layout != "NHWC":
+ return False
+ data_type = args[0].checked_type
+ if not (len(data_type.shape) == 4 or len(data_type.shape) == 2):
+ return False
+ if (len(data_type.shape) != 4) or (data_type.dtype not in ["float32"]):
+ return False
+ return True
+
+
@tvm.ir.register_op_attr("reshape", "target.mrvl")
def reshape_mrvl(expr):
"""Check if the external Mrvl codegen for reshape should be used."""
@@ -846,6 +890,14 @@ def batch_flatten_mrvl(expr):
return True
+@tvm.ir.register_op_attr("squeeze", "target.mrvl")
+def squeeze_mrvl(expr):
+ """Check if the external Mrvl codegen for squeeze should be used."""
+ if expr.op.name != "squeeze":
+ return False
+ return True
+
+
# register a helper function to indicate that the given operator can be
supported by Mrvl.
@tvm.ir.register_op_attr("layout_transform", "target.mrvl")
def layout_transform_nchw2nhwc(expr):
diff --git a/src/relay/backend/contrib/mrvl/codegen.cc
b/src/relay/backend/contrib/mrvl/codegen.cc
index 6d7e593b9b..96121e4b4b 100644
--- a/src/relay/backend/contrib/mrvl/codegen.cc
+++ b/src/relay/backend/contrib/mrvl/codegen.cc
@@ -225,6 +225,13 @@ class MrvlJSONSerializer : public
backend::contrib::JSONSerializer {
const CallNode* batch_flatten = nullptr;
};
+ /*!
+ * \brief A series of operators that form a Squeeze node.
+ */
+ struct CompositeSqueezeNode {
+ const CallNode* squeeze = nullptr;
+ };
+
/*!
* \brief A series of operators that form a composite
* fc layer. Supports both nn.fc_ni2no and qnn.fc_ni2no.
@@ -278,6 +285,8 @@ class MrvlJSONSerializer : public
backend::contrib::JSONSerializer {
json_kernel_node = CreateCompositeMrvlAvgpool2DLayer(cn);
} else if (name == "mrvl.globalavgpool2d_nhwc2nhwc") {
json_kernel_node = CreateCompositeMrvlGlobalAvgpool2DLayer(cn);
+ } else if (name == "mrvl.globalmaxpool2d_nhwc2nhwc") {
+ json_kernel_node = CreateCompositeMrvlGlobalMaxpool2DLayer(cn);
} else if (name == "mrvl.sum") {
json_kernel_node = CreateCompositeMrvlSumLayer(cn);
} else if (name == "mrvl.concat") {
@@ -286,6 +295,8 @@ class MrvlJSONSerializer : public
backend::contrib::JSONSerializer {
json_kernel_node = CreateMrvlReshapeLayer(cn);
} else if (name == "mrvl.batch_flatten") {
json_kernel_node = CreateMrvlBatchFlattenLayer(cn);
+ } else if (name == "mrvl.squeeze") {
+ json_kernel_node = CreateMrvlSqueezeLayer(cn);
} else {
LOG(FATAL) << "Unrecognized Mrvl pattern: " << name;
}
@@ -511,6 +522,22 @@ class MrvlJSONSerializer : public
backend::contrib::JSONSerializer {
return nodes;
}
+ /*!
+ * \brief Extract squeeze nodes from a composite function.
+ * \param call The call node of the composite function.
+ * \return Extracted composite squeeze nodes.
+ */
+ CompositeSqueezeNode UnpackCompositeSqueeze(const CallNode* call) {
+ CompositeSqueezeNode nodes{};
+ const auto* fn = call->op.as<FunctionNode>();
+ ICHECK(fn) << "Marvell-Compiler-ERROR-Internal::Downcast to FunctionNode
failed.";
+ const auto* current_call = fn->body.as<CallNode>();
+ ICHECK(backend::IsOp(current_call, "squeeze"))
+ << "Marvell-Compiler-ERROR-Internal::squeeze missing.";
+ nodes.squeeze = current_call;
+ return nodes;
+ }
+
/*!
* \brief Extract maxpool nodes from a composite function.
*
@@ -533,6 +560,11 @@ class MrvlJSONSerializer : public
backend::contrib::JSONSerializer {
<< "Marvell-Compiler-ERROR-Internal::nn.avg_pool2d Op missing.";
ICHECK(backend::IsOp(current_call, "nn.avg_pool2d"))
<< "Marvell-Compiler-ERROR-Internal::nn.avg_pool2d Op missing.";
+ } else if (mrvlLayerName == "GlobalMaxpool2D") {
+ ICHECK(mrvlLayerName == "GlobalMaxpool2D")
+ << "Marvell-Compiler-ERROR-Internal::nn.global_max_pool2d Op
missing.";
+ ICHECK(backend::IsOp(current_call, "nn.global_max_pool2d"))
+ << "Marvell-Compiler-ERROR-Internal::nn.global_max_pool2d Op
missing.";
} else {
ICHECK(mrvlLayerName == "GlobalAvgpool2D")
<< "Marvell-Compiler-ERROR-Internal::nn.global_avg_pool2d Op
missing.";
@@ -1115,6 +1147,34 @@ class MrvlJSONSerializer : public
backend::contrib::JSONSerializer {
return json_node;
}
+ /*!
+ * \brief Create a JSON representation of a composite Squeeze.
+ *
+ * \param cn The call to be represented.
+ * \return A JSON representation of a specific operator.
+ */
+ std::shared_ptr<JSONGraphNode> CreateMrvlSqueezeLayer(const CallNode* cn) {
+ CompositeSqueezeNode nodes = UnpackCompositeSqueeze(cn);
+ std::vector<JSONGraphNodeEntry> inputs;
+ std::string name = "squeeze";
+ inputs.push_back(VisitExpr(cn->args[0])[0]);
+ std::vector<int64_t> layout_vec;
+ GetInputTensorShapeViaArgN(nodes.squeeze, &layout_vec);
+ std::string data_layout;
+ if (layout_vec.size() == 4) {
+ data_layout = "NHWC";
+ } else {
+ data_layout = "NC";
+ }
+ layout_vec.clear();
+ std::string out_layout = "NC";
+ auto json_node = std::make_shared<JSONGraphNode>(name, "kernel", inputs,
1);
+ SetMrvlLayerCommonAttrs(json_node, cn, layer_name_, name, data_layout,
+ "" /* no kernel_layout */, out_layout);
+ SetMrvlQuantAttrs(json_node, nodes.instrument_1, "1");
+ return json_node;
+ }
+
/*!
* \brief Create a JSON representation of a composite concat.
*
@@ -1304,6 +1364,48 @@ class MrvlJSONSerializer : public
backend::contrib::JSONSerializer {
return json_node;
}
+ /*!
+ * \brief Create a JSON representation of a composite globalmaxpooling
operator.
+ *
+ * A composite function is only created when using the uint8 datatype for
these operators.
+ *
+ * \param cn The call to be represented.
+ * \return A JSON representation of a specific operator.
+ */
+ std::shared_ptr<JSONGraphNode> CreateCompositeMrvlGlobalMaxpool2DLayer(const
CallNode* cn) {
+ std::string mrvlLayerName = "GlobalMaxpool2D";
+ std::string name = "nn.globalmaxpool2d_nhwc2nhwc";
+ CompositePoolNode nodes = UnpackCompositePool(cn, mrvlLayerName);
+
+ const auto* globalmaxpool_attr = nodes.pool->attrs.as<GlobalPool2DAttrs>();
+ ICHECK(globalmaxpool_attr)
+ << "Marvell-Compiler-ERROR-Internal::Downcast to GlobalPool2DAttrs
failed.";
+ ICHECK(globalmaxpool_attr->layout == "NHWC")
+ << "Marvell-Compiler-ERROR-Internal::"
+ << "Layout must be NHWC, has the module been pre-processed correctly?";
+
+ std::string data_layout = globalmaxpool_attr->layout;
+ std::string out_layout = globalmaxpool_attr->layout;
+ std::vector<JSONGraphNodeEntry> inputs;
+ std::vector<int64_t> kernel_layout_vec;
+ std::vector<int64_t> data_layout_vec;
+ GetInputTensorShapeViaArgN(cn, &data_layout_vec);
+ ICHECK(data_layout_vec.size() == 4);
+ kernel_layout_vec.push_back(data_layout_vec[1]);
+ kernel_layout_vec.push_back(data_layout_vec[2]);
+ inputs.push_back(VisitExpr(cn->args[0])[0]);
+
+ // op_type_ is "kernel"
+ auto json_node = std::make_shared<JSONGraphNode>(name, "kernel", inputs,
1);
+ SetCallNodeAttribute(json_node, nodes.pool);
+ JsonNodeSetVecAttr(json_node, "kernel_layout_shape", kernel_layout_vec);
+ if (nodes.pad) SetMrvlLayerPadAttrs(json_node, nodes.pad);
+
+ SetMrvlLayerCommonAttrs(json_node, cn, layer_name_, mrvlLayerName,
data_layout, "HW",
+ out_layout);
+ return json_node;
+ }
+
/*!
* \brief Create a JSON representation of an OpNode layer.
*
diff --git a/tests/python/contrib/test_mrvl/test_mrvl.py
b/tests/python/contrib/test_mrvl/test_mrvl.py
index 26956c97c5..cd3f343c2d 100644
--- a/tests/python/contrib/test_mrvl/test_mrvl.py
+++ b/tests/python/contrib/test_mrvl/test_mrvl.py
@@ -181,7 +181,115 @@ def test_dense():
run_and_verify_func(get_graph())
+@requires_mrvl
+def test_maxpool2d():
+ """Test maxpool2d operator for "mrvl" targets"""
+
+ def get_graph():
+ x = relay.var("x", shape=(1, 3, 224, 224))
+ arr = np.random.rand(16, 3, 3, 3).astype("float32")
+ w = relay.const(arr)
+ y = relay.nn.conv2d(x, w, strides=[2, 2], padding=[1, 1, 1, 1],
kernel_size=[3, 3])
+ y = relay.nn.max_pool2d(y)
+ func = relay.Function([x], y)
+ mod = tvm.IRModule()
+ mod["main"] = func
+ option_dict = {"num_tiles": 1}
+ verify_codegen(mod, params={}, tvm_ops=1,
contains="mrvl.maxpool2d_nhwc2nhwc")
+ return func, {"x": (1, 3, 224, 224)}, [], option_dict
+
+ run_and_verify_func(get_graph())
+
+
+@requires_mrvl
+def test_avgpool2d():
+ """Test avgpool2d operator for "mrvl" targets"""
+
+ def get_graph():
+ x = relay.var("x", shape=(1, 3, 224, 224))
+ arr = np.random.rand(16, 3, 3, 3).astype("float32")
+ w = relay.const(arr)
+ y = relay.nn.conv2d(x, w, strides=[2, 2], padding=[1, 1, 1, 1],
kernel_size=[3, 3])
+ y = relay.nn.avg_pool2d(y)
+ func = relay.Function([x], y)
+ mod = tvm.IRModule()
+ mod["main"] = func
+ option_dict = {"num_tiles": 1}
+ verify_codegen(mod, params={}, tvm_ops=1,
contains="mrvl.avgpool2d_nhwc2nhwc")
+ return func, {"x": (1, 3, 224, 224)}, [], option_dict
+
+ run_and_verify_func(get_graph())
+
+
+@requires_mrvl
+def test_globalavgpool2d():
+ """Test globalavgpool2d operator for "mrvl" targets"""
+
+ def get_graph():
+ x = relay.var("x", shape=(1, 3, 224, 224))
+ arr = np.random.rand(16, 3, 3, 3).astype("float32")
+ w = relay.const(arr)
+ y = relay.nn.conv2d(x, w, strides=[2, 2], padding=[1, 1, 1, 1],
kernel_size=[3, 3])
+ y = relay.nn.global_avg_pool2d(y)
+ func = relay.Function([x], y)
+ mod = tvm.IRModule()
+ mod["main"] = func
+ option_dict = {"num_tiles": 1}
+ verify_codegen(mod, params={}, tvm_ops=1,
contains="mrvl.globalavgpool2d_nhwc2nhwc")
+ return func, {"x": (1, 3, 224, 224)}, [], option_dict
+
+ run_and_verify_func(get_graph())
+
+
+@requires_mrvl
+def test_globalmaxpool2d():
+ """Test globalmaxpool2d operator for "mrvl" targets"""
+
+ def get_graph():
+ x = relay.var("x", shape=(1, 3, 224, 224))
+ arr = np.random.rand(16, 3, 3, 3).astype("float32")
+ w = relay.const(arr)
+ y = relay.nn.conv2d(x, w, strides=[2, 2], padding=[1, 1, 1, 1],
kernel_size=[3, 3])
+ y = relay.nn.global_max_pool2d(y)
+ func = relay.Function([x], y)
+ params = {}
+ params["w"] = arr
+ mod = tvm.IRModule()
+ mod["main"] = func
+ option_dict = {"num_tiles": 1}
+ verify_codegen(mod, params=params, tvm_ops=2,
contains="mrvl.globalmaxpool2d_nhwc2nhwc")
+ return func, {"x": (1, 3, 224, 224), "w": (16, 3, 3, 3)}, ["w"],
option_dict
+
+ run_and_verify_func(get_graph())
+
+
+@requires_mrvl
+def test_squeeze():
+ """Test squeeze operator for "mrvl" targets"""
+
+ def get_graph():
+ x = relay.var("x", shape=(1, 3, 224, 224))
+ arr = np.random.rand(16, 3, 3, 3).astype("float32")
+ w = relay.const(arr)
+ y = relay.nn.conv2d(x, w, strides=[2, 2], padding=[1, 1, 1, 1],
kernel_size=[3, 3])
+ y = relay.reshape(y, newshape=(1, 1, 16, 112, 112))
+ y = relay.squeeze(y, axis=[0, 1])
+ func = relay.Function([x], y)
+ mod = tvm.IRModule()
+ mod["main"] = func
+ option_dict = {"num_tiles": 1}
+ verify_codegen(mod, params={}, tvm_ops=3, contains="mrvl.squeeze")
+ return func, {"x": (1, 3, 224, 224)}, [], option_dict
+
+ run_and_verify_func(get_graph())
+
+
if __name__ == "__main__":
test_mrvl_fuse()
test_conv2d()
test_dense()
+ test_maxpool2d()
+ test_avgpool2d()
+ test_globalavgpool2d()
+ test_globalmaxpool2d()
+ test_squeeze()