This is an automated email from the ASF dual-hosted git repository.

ekalda pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new bbfe481784 [microNPU][ETHOSU] Channel pad offloaded to NPU (#14765)
bbfe481784 is described below

commit bbfe4817848f0f42057d00f4219b5fc7391d9d11
Author: sergio-grovety <[email protected]>
AuthorDate: Fri May 19 18:20:23 2023 +0300

    [microNPU][ETHOSU] Channel pad offloaded to NPU (#14765)
    
    A separate channel-dimension nn.pad Relay operator is rewritten as a Relay concatenate operation.
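
    For intuition: padding only the channel (innermost NHWC) axis with the
    quantization zero point is equivalent to concatenating zero-point-filled
    constants along that axis, which is what the new rewriter emits. A minimal
    NumPy sketch of this equivalence (shapes and the zero point here are
    illustrative, not taken from the patch):

        import numpy as np

        ifm = np.random.randint(-128, 127, size=(1, 4, 4, 3), dtype=np.int8)
        zero_point = 0                 # quantization zero point used as pad value
        pad_before, pad_after = 1, 2   # channel padding (before, after)

        # nn.pad restricted to the channel axis
        padded = np.pad(
            ifm,
            [(0, 0), (0, 0), (0, 0), (pad_before, pad_after)],
            constant_values=zero_point,
        )

        # the equivalent concatenation of zero-point constants along axis 3
        before = np.full(ifm.shape[:3] + (pad_before,), zero_point, dtype=ifm.dtype)
        after = np.full(ifm.shape[:3] + (pad_after,), zero_point, dtype=ifm.dtype)
        assert (padded == np.concatenate([before, ifm, after], axis=3)).all()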
    
    ---------
    
    Co-authored-by: Sergey Smirnov <[email protected]>
    Co-authored-by: arina.naumova <[email protected]>
---
 .../tvm/relay/backend/contrib/ethosu/legalize.py   |  77 +++++++
 python/tvm/relay/backend/contrib/ethosu/util.py    |   2 +-
 python/tvm/relay/op/contrib/ethosu.py              |  90 +++++++-
 tests/python/contrib/test_ethosu/infra.py          |   6 +-
 tests/python/contrib/test_ethosu/test_codegen.py   |  23 ++
 tests/python/contrib/test_ethosu/test_legalize.py  | 241 ++++++++++++++++++++-
 6 files changed, 429 insertions(+), 10 deletions(-)

diff --git a/python/tvm/relay/backend/contrib/ethosu/legalize.py b/python/tvm/relay/backend/contrib/ethosu/legalize.py
index ffb09f4e2e..b4e8124d14 100644
--- a/python/tvm/relay/backend/contrib/ethosu/legalize.py
+++ b/python/tvm/relay/backend/contrib/ethosu/legalize.py
@@ -1448,6 +1448,82 @@ class PadRewriter(DFPatternCallback):
         )
 
 
+class ChannelPadRewriter(DFPatternCallback):
+    """Convert ethos-u.channel-pad composite function to the Relay concatenate 
operation"""
+
+    def __init__(self):
+        super().__init__(require_type=True)
+        self.pattern = (
+            wildcard().has_attr({"Composite": ethosu_patterns.ChannelPadParams.composite_name})
+        )(wildcard())
+
+    def callback(
+        self, pre: tvm.relay.Expr, post: tvm.relay.Expr, node_map: tvm.ir.container.Map
+    ) -> tvm.relay.Expr:
+        params = ethosu_patterns.ChannelPadParams(post.op.body)
+        params.ifm.tensor = post.args[0]
+
+        concat_args = list()
+        lut = relay.const([], dtype="int8")
+        # pad channels before
+        if params.ch_padding[0] > 0:
+            shape1 = list(params.ifm.shape)
+            shape1[3] = params.ch_padding[0].value
+            pad_channels = relay.Constant(
+                tvm.nd.array(
+                    np.full(
+                        shape=shape1,
+                        fill_value=int(params.ifm.q_params.zero_point),
+                        dtype=params.ifm.dtype,
+                    )
+                )
+            )
+            identity1 = ethosu_ops.ethosu_identity(
+                ifm=pad_channels,
+                lut=lut,
+                ifm_scale=float(params.ifm.q_params.scale_f32),
+                ifm_zero_point=int(params.ifm.q_params.zero_point),
+                ofm_scale=float(params.ofm.q_params.scale_f32),
+                ofm_zero_point=int(params.ofm.q_params.zero_point),
+            )
+            concat_args.append(identity1)
+
+        identity2 = ethosu_ops.ethosu_identity(
+            ifm=params.ifm.tensor,
+            lut=lut,
+            ifm_scale=float(params.ifm.q_params.scale_f32),
+            ifm_zero_point=int(params.ifm.q_params.zero_point),
+            ofm_scale=float(params.ofm.q_params.scale_f32),
+            ofm_zero_point=int(params.ofm.q_params.zero_point),
+        )
+        concat_args.append(identity2)
+
+        # pad channels after
+        if params.ch_padding[1] > 0:
+            shape3 = list(params.ifm.shape)
+            shape3[3] = params.ch_padding[1].value
+            pad_channels3 = relay.Constant(
+                tvm.nd.array(
+                    np.full(
+                        shape=shape3,
+                        fill_value=int(params.ifm.q_params.zero_point),
+                        dtype=params.ifm.dtype,
+                    )
+                )
+            )
+            identity3 = ethosu_ops.ethosu_identity(
+                ifm=pad_channels3,
+                lut=lut,
+                ifm_scale=float(params.ifm.q_params.scale_f32),
+                ifm_zero_point=int(params.ifm.q_params.zero_point),
+                ofm_scale=float(params.ofm.q_params.scale_f32),
+                ofm_zero_point=int(params.ofm.q_params.zero_point),
+            )
+            concat_args.append(identity3)
+
+        return relay.op.concatenate(relay.Tuple(concat_args), axis=3)
+
+
 @util.create_npu_function_pass(opt_level=1)
 class LegalizeEthosU:
     """This is the pass to call graph-rewrites to perform graph transformation
@@ -1462,6 +1538,7 @@ class LegalizeEthosU:
         rewriters = [
             PartitionedSplitRewriter(),
             SplitRewriter(),
+            ChannelPadRewriter(),
             Conv2DRewriter(),
             Conv2DTransposeRewriter(),
             DepthwiseConv2DRewriter(),
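
For readers unfamiliar with the dataflow-pattern machinery used above: a
DFPatternCallback pairs a pattern with a callback that returns the replacement
expression, and dataflow_pattern.rewrite() applies it over a function. A toy
sketch with the same structure, folding x + 0 into x (illustrative only, not
part of this patch):

    import numpy as np
    from tvm import relay
    from tvm.relay.dataflow_pattern import (
        DFPatternCallback, is_constant, is_op, rewrite, wildcard,
    )

    class AddZeroRewriter(DFPatternCallback):
        """Toy rewriter: rewrite x + 0 to x (illustrative only)."""

        def __init__(self):
            super().__init__(require_type=False)
            self.x = wildcard()
            self.const = is_constant()
            self.pattern = is_op("add")(self.x, self.const)

        def callback(self, pre, post, node_map):
            const = node_map[self.const][0]
            if np.all(const.data.numpy() == 0):
                return node_map[self.x][0]
            return post  # leave non-zero additions untouched

    x = relay.var("x", shape=(4,), dtype="float32")
    func = relay.Function([x], x + relay.const(np.zeros(4, "float32")))
    print(rewrite(AddZeroRewriter(), func))
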
diff --git a/python/tvm/relay/backend/contrib/ethosu/util.py b/python/tvm/relay/backend/contrib/ethosu/util.py
index bbc43395c3..f0c753814f 100644
--- a/python/tvm/relay/backend/contrib/ethosu/util.py
+++ b/python/tvm/relay/backend/contrib/ethosu/util.py
@@ -144,7 +144,7 @@ class QDenseArgs(Enum):
     WEIGHTS_SCALE = 5
 
 
-class QPad2DArgs(Enum):
+class QPadArgs(Enum):
     """
     This is a helper enum to obtain the correct index
     of nn.pad arguments.
diff --git a/python/tvm/relay/op/contrib/ethosu.py b/python/tvm/relay/op/contrib/ethosu.py
index acf3fc1174..71b419507b 100644
--- a/python/tvm/relay/op/contrib/ethosu.py
+++ b/python/tvm/relay/op/contrib/ethosu.py
@@ -1940,15 +1940,15 @@ class PadParams:
     padding_bounds = [31, 31, 32, 32]
 
     def __init__(self, func_body: Call):
-        from tvm.relay.backend.contrib.ethosu.util import QPad2DArgs
+        from tvm.relay.backend.contrib.ethosu.util import QPadArgs
 
         # there is no 'layout' attribute in nn.pad
         layout = "NHWC"
         self.ifm = TensorParams(
-            tensor=func_body.args[QPad2DArgs.IFM.value],
+            tensor=func_body.args[QPadArgs.IFM.value],
             layout=layout,
             scale=tvm.relay.Constant(tvm.nd.array(np.array(1.0, dtype="float32"))),
-            zero_point=func_body.args[QPad2DArgs.IFM_ZERO_POINT.value],
+            zero_point=func_body.args[QPadArgs.IFM_ZERO_POINT.value],
         )
 
         self.padding = self.extract_padding(func_body)
@@ -1956,7 +1956,7 @@ class PadParams:
             tensor=func_body,
             layout=layout,
             scale=tvm.relay.Constant(tvm.nd.array(np.array(1.0, dtype="float32"))),
-            zero_point=func_body.args[QPad2DArgs.IFM_ZERO_POINT.value],
+            zero_point=func_body.args[QPadArgs.IFM_ZERO_POINT.value],
         )
 
     @staticmethod
@@ -1964,8 +1964,8 @@ class PadParams:
         padding: relay.Call,
     ) -> Optional[Tuple[int, int, int, int]]:
         """
-        Here we check whether a separate padding operation can be rewritten
-        as NPU depthwise convolution. If the padding specified by the
+        Here we check whether a separate spatial-dimension padding operation can be
+        rewritten as NPU depthwise convolution. If the padding specified by the
         separate nn.pad operation is not supported by NPU depthwise convolution,
         None will be returned. This will cause the nn.pad not to be offloaded to NPU.
         """
@@ -2000,6 +2000,79 @@ class PadParams:
         return True
 
 
+class ChannelPadParams:
+    """
+    This class will parse a call to an ethos-u.channel-pad composite function
+    and extract the parameter information.
+    """
+
+    composite_name = "ethos-u.channel-pad"
+    # The ethos-u.channel-pad composite function will be transformed
+    # to the Relay concatenate operation.
+
+    def __init__(self, func_body: Call):
+        from tvm.relay.backend.contrib.ethosu.util import QPadArgs
+
+        # there is no 'layout' attribute in nn.pad
+        layout = "NHWC"
+        self.ifm = TensorParams(
+            tensor=func_body.args[QPadArgs.IFM.value],
+            layout=layout,
+            scale=tvm.relay.Constant(tvm.nd.array(np.array(1.0, dtype="float32"))),
+            zero_point=func_body.args[QPadArgs.IFM_ZERO_POINT.value],
+        )
+
+        self.ch_padding = self.extract_ch_padding(func_body)
+        self.ofm = TensorParams(
+            tensor=func_body,
+            layout=layout,
+            scale=tvm.relay.Constant(tvm.nd.array(np.array(1.0, dtype="float32"))),
+            zero_point=func_body.args[QPadArgs.IFM_ZERO_POINT.value],
+        )
+
+    @staticmethod
+    def extract_ch_padding(
+        padding: relay.Call,
+    ) -> Optional[Tuple[int, int]]:
+        """
+        Here we check whether a separate channel-dimension padding operation can be
+        rewritten as a Relay concatenate operation. If the padding specified by the
+        separate nn.pad operation is not supported by the NPU, None will be returned.
+        This will cause the nn.pad not to be offloaded to the NPU.
+        """
+        pad_width = padding.attrs["pad_width"]
+        if len(pad_width) != 4:
+            return None
+        if (
+            list(pad_width[0]) != [0, 0]
+            or list(pad_width[1]) != [0, 0]
+            or list(pad_width[2]) != [0, 0]
+        ):
+            return None
+        return [
+            pad_width[3][0],
+            pad_width[3][1],
+        ]
+
+    def is_valid(self):
+        """
+        This function checks whether pad has compatible attributes
+        with the Relay concatenate operation
+        """
+        tensor_params = [self.ifm, self.ofm]
+        if not check_valid_dtypes(tensor_params, supported_dtypes=[np.uint8, np.int8]):
+            return False
+        if self.ifm.dtype != self.ofm.dtype:
+            return False
+        if not check_batch_size(self.ifm):
+            return False
+        if not self.ch_padding:
+            return False
+        if not check_dimensions(self.ifm) or not check_dimensions(self.ofm):
+            return False
+        return True
+
+
 def pad_pattern():
     """Create pattern for pad"""
     pattern = is_op("nn.pad")(wildcard(), is_constant())
@@ -2066,6 +2139,11 @@ def softmax_pattern() -> tvm.relay.dataflow_pattern.DFPattern:
 @register_pattern_table("ethos-u")
 def pattern_table() -> List[Tuple[str, tvm.relay.dataflow_pattern.DFPattern, Callable]]:
     return [
+        (
+            ChannelPadParams.composite_name,
+            pad_pattern(),
+            lambda pat: ChannelPadParams(pat).is_valid(),
+        ),
         (
             QnnConv2DParams.composite_name,
             qnn_conv2d_pattern(),
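
The new pattern-table entry is registered ahead of the existing ones,
presumably so that a standalone channel pad is matched before the
spatial-pad/conv2d fusions. The gating itself is simple; a pure-Python sketch
of the extract_ch_padding check (assuming pad_width mirrors nn.pad's
[before, after] pairs in NHWC order, as in the hunk above):

    def extract_ch_padding(pad_width):
        # reject anything that is not pure channel padding
        if len(pad_width) != 4:
            return None
        if any(list(p) != [0, 0] for p in pad_width[:3]):
            return None
        return [pad_width[3][0], pad_width[3][1]]

    assert extract_ch_padding([[0, 0], [0, 0], [0, 0], [1, 2]]) == [1, 2]
    assert extract_ch_padding([[0, 0], [1, 1], [1, 1], [0, 0]]) is None
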
diff --git a/tests/python/contrib/test_ethosu/infra.py b/tests/python/contrib/test_ethosu/infra.py
index c621155827..71e7e029c1 100644
--- a/tests/python/contrib/test_ethosu/infra.py
+++ b/tests/python/contrib/test_ethosu/infra.py
@@ -475,7 +475,9 @@ def get_convolutional_args(call, include_buffers=False, remove_constants=False):
     return conv_args
 
 
-def compute_ofm_shape(ifm_shape, padding, kernel_shape, strides, dilation=[1, 1]):
+def compute_ofm_shape(
+    ifm_shape, padding, kernel_shape, strides, dilation=[1, 1], channel_padding=[0, 0]
+):
     assert len(strides) == 2
     assert len(dilation) == 2
     assert len(kernel_shape) == 2
@@ -492,7 +494,7 @@ def compute_ofm_shape(ifm_shape, padding, kernel_shape, strides, dilation=[1, 1]
     elif padding.lower() == "same":
         h = math.ceil(ifm_shape[1] / strides[0])
         w = math.ceil(ifm_shape[2] / strides[1])
-    ofm_shape = [ifm_shape[0], h, w, ifm_shape[3]]
+    ofm_shape = [ifm_shape[0], h, w, ifm_shape[3] + channel_padding[0] + channel_padding[1]]
     return ofm_shape
 
 
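The updated helper only widens the channel dimension by the total channel
padding; the spatial dimensions follow the existing padding/stride arithmetic.
A worked example using one of the test parametrizations below:

    # ifm_shape = (1, 55, 55, 3), channel_padding = (5, 2)
    ifm_channels = 3
    channel_padding = (5, 2)
    assert ifm_channels + channel_padding[0] + channel_padding[1] == 10
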
diff --git a/tests/python/contrib/test_ethosu/test_codegen.py b/tests/python/contrib/test_ethosu/test_codegen.py
index e9a3e82a28..ef91b75efa 100644
--- a/tests/python/contrib/test_ethosu/test_codegen.py
+++ b/tests/python/contrib/test_ethosu/test_codegen.py
@@ -281,6 +281,29 @@ def test_tflite_separate_pad(
     infra.compare_tvm_with_tflite(pad2d, [ifm_shape], "ethos-u55-256")
 
 
[email protected]("ifm_shape", [(1, 55, 55, 3), (1, 23, 32, 7)])
[email protected]("channel_padding", [(0, 1), (1, 1), (5, 2)])
[email protected]("const_value", [0, 5, 125, -5])
+def test_tflite_separate_channel_pad(
+    ifm_shape,
+    channel_padding,
+    const_value,
+):
+    np.random.seed(0)
+
+    @tf.function
+    def concat_func(x):
+        x = tf.pad(
+            x,
+            [[0, 0], [0, 0], [0, 0], [channel_padding[0], channel_padding[1]]],
+            "CONSTANT",
+            const_value,
+        )
+        return x
+
+    infra.compare_tvm_with_tflite(concat_func, [ifm_shape], "ethos-u55-256", enable_cascader=False)
+
+
 @pytest.mark.parametrize(
     "accel_type",
     ACCEL_TYPES,
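
The new codegen test checks TVM output against the TFLite reference across the
parametrized shapes, paddings, and constant values. To run just these cases
locally (assuming an Ethos-U-enabled TVM build with the test dependencies
installed), something like the following should work:

    import pytest
    pytest.main([
        "tests/python/contrib/test_ethosu/test_codegen.py",
        "-k", "separate_channel_pad",
    ])
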
diff --git a/tests/python/contrib/test_ethosu/test_legalize.py b/tests/python/contrib/test_ethosu/test_legalize.py
index d1d0befcee..f87b2da983 100644
--- a/tests/python/contrib/test_ethosu/test_legalize.py
+++ b/tests/python/contrib/test_ethosu/test_legalize.py
@@ -34,6 +34,7 @@ from tvm.relay.op.contrib import ethosu
 from tvm.relay.backend.contrib.ethosu import util
 from tvm.relay.build_module import bind_params_by_name
 from tvm.relay.frontend.tflite import get_pad_value
+from tvm.relay.expr_functor import ExprVisitor
 
 from . import infra
 
@@ -462,6 +463,118 @@ def test_tflite_conv2d_with_separate_padding_legalize():
     verify(mod["tvmgen_default_ethos_u_main_0"])
 
 
+def test_tflite_conv2d_with_separate_channel_padding_legalize():
+    dtype = "int8"
+    ifm_shape = (1, 55, 34, 3)
+    kernel_shape = (3, 2)
+    strides = (1, 1)
+    dilation = (2, 1)
+    padding_ch = (1, 1)
+
+    class ArePadOnGraph(ExprVisitor):
+        """
+        Visits the graph recursively and checks whether it contains the 'nn.pad' op
+        """
+
+        def __init__(self):
+            ExprVisitor.__init__(self)
+            self.on_graph = False
+
+        def visit_call(self, call):
+            if isinstance(call.op, tvm.ir.Op):
+                if str(call.op.name) == "nn.pad":
+                    self.on_graph = True
+
+            return super().visit_call(call)
+
+        def are_pad_on_graph(self, subgraph) -> bool:
+            """
+            This function recursively visits the graph and checks whether the 'nn.pad' op is in the graph
+            """
+            self.visit(subgraph)
+            return self.on_graph
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x):
+                tf_strides = [1, strides[0], strides[1], 1]
+                op = tf.pad(
+                    x,
+                    [[0, 0], [0, 0], [0, 0], [padding_ch[0], padding_ch[1]]],
+                    "CONSTANT",
+                )
+                # HWIO
+                weight_shape = [
+                    kernel_shape[0],
+                    kernel_shape[1],
+                    ifm_shape[3] + padding_ch[0] + padding_ch[1],
+                    3,
+                ]
+                weight = tf.constant(np.random.uniform(size=weight_shape), dtype=tf.float32)
+                return tf.nn.conv2d(
+                    op,
+                    weight,
+                    strides=tf_strides,
+                    padding="VALID",
+                    dilations=dilation,
+                )
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32)
+        )
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    def verify(ext_func):
+
+        assert ArePadOnGraph().are_pad_on_graph(ext_func.body)
+
+    conv2d_pattern_table = [
+        (
+            ethosu.ChannelPadParams.composite_name,
+            ethosu.pad_pattern(),
+            lambda pat: ethosu.ChannelPadParams(pat).is_valid(),
+        ),
+        (
+            ethosu.QnnConv2DParams.composite_name,
+            ethosu.qnn_conv2d_pattern(),
+            lambda pat: ethosu.QnnConv2DParams(pat).is_valid(),
+        ),
+    ]
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, conv_params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"input": ifm_shape},
+        dtype_dict={"input": dtype},
+    )
+
+    mod["main"] = bind_params_by_name(mod["main"], conv_params)
+    mod = partition_ethosu_by_table(mod, conv2d_pattern_table)
+
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        legalize.Conv2DRewriter(), mod["tvmgen_default_ethos_u_main_0"]
+    )
+
+    verify(mod["tvmgen_default_ethos_u_main_0"])
+
+
 @pytest.mark.parametrize("ifm_shape", [(1, 299, 299, 3), (1, 123, 17, 7)])
 @pytest.mark.parametrize("kernel_shape", [(7, 3), (22, 5)])
 @pytest.mark.parametrize("padding", ["SAME", "VALID"])
@@ -760,7 +873,7 @@ def test_tflite_separate_padding_legalize(ifm_shape, padding, const_value):
             ethosu.PadParams.composite_name,
             ethosu.pad_pattern(),
             lambda pat: ethosu.PadParams(pat).is_valid(),
-        )
+        ),
     ]
 
     tflite_graph = create_tflite_graph()
@@ -781,6 +894,132 @@ def test_tflite_separate_padding_legalize(ifm_shape, padding, const_value):
     verify(mod["tvmgen_default_ethos_u_main_0"])
 
 
[email protected]("ifm_shape", [(1, 55, 55, 3), (1, 23, 32, 7)])
[email protected]("channel_padding", [(0, 1), (1, 1), (5, 2)])
[email protected]("const_value", [0, 5, 125, -5])
+def test_tflite_separate_channel_padding_legalize(ifm_shape, channel_padding, const_value):
+    dtype = "int8"
+    padding = (0, 0, 0, 0)
+
+    class AreConcatenateOnGraph(ExprVisitor):
+        """
+        Visits the graph recursively and checks whether it contains the 'concatenate' op
+        """
+
+        def __init__(self):
+            ExprVisitor.__init__(self)
+            self.on_graph = False
+
+        def visit_call(self, call):
+            if isinstance(call.op, tvm.ir.Op):
+                if str(call.op.name) == "concatenate":
+                    self.on_graph = True
+
+            return super().visit_call(call)
+
+        def are_concatenate_on_graph(self, subgraph) -> bool:
+            """
+            This function recursively visits the graph and checks whether the 'concatenate' op is in the graph
+            """
+            self.visit(subgraph)
+            return self.on_graph
+
+    def create_tflite_graph():
+        class Model(tf.Module):
+            @tf.function
+            def tf_function(self, x):
+                return tf.pad(
+                    x,
+                    [
+                        [0, 0],
+                        [padding[0], padding[2]],
+                        [padding[1], padding[3]],
+                        [channel_padding[0], channel_padding[1]],
+                    ],
+                    "CONSTANT",
+                    const_value,
+                )
+
+        model = Model()
+        concrete_func = model.tf_function.get_concrete_function(
+            tf.TensorSpec(ifm_shape, dtype=tf.float32)
+        )
+        # Convert the model
+        def representative_dataset():
+            for _ in range(100):
+                data = np.random.rand(*tuple(ifm_shape))
+                yield [data.astype(np.float32)]
+
+        converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
+        converter.optimizations = [tf.lite.Optimize.DEFAULT]
+        converter.representative_dataset = representative_dataset
+        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+        converter.inference_input_type = tf.int8
+        converter.inference_output_type = tf.int8
+        tflite_model = converter.convert()
+        return tflite_model
+
+    def verify(ext_func, channel_padding):
+
+        op = ext_func.body
+
+        pad_before = 0
+        pad_after = 0
+        if channel_padding[0] == 0 and channel_padding[1] > 0:
+            pad_after = ext_func.body.args[0][1].args[0].checked_type.shape[3]
+            ifm = ext_func.body.args[0][0].args[0].checked_type
+        if channel_padding[0] > 0 and channel_padding[1] == 0:
+            pad_before = ext_func.body.args[0][0].args[0].checked_type.shape[3]
+            ifm = ext_func.body.args[0][1].args[0].checked_type
+        if channel_padding[0] > 0 and channel_padding[1] > 0:
+            pad_before = ext_func.body.args[0][0].args[0].checked_type.shape[3]
+            ifm = ext_func.body.args[0][1].args[0].checked_type
+            pad_after = ext_func.body.args[0][2].args[0].checked_type.shape[3]
+
+        # check IFM
+        assert list(ifm.shape) == list(ifm_shape)
+        assert str(ifm.dtype) == dtype
+        assert ifm.shape[3] == ifm_shape[3]
+
+        # check OFM
+        ofm = op.checked_type
+        expected_ofm_shape = list(ifm_shape)
+        expected_ofm_shape[3] = channel_padding[0] + ifm_shape[3] + channel_padding[1]
+        assert list(ofm.shape) == expected_ofm_shape
+        assert str(ofm.dtype) == dtype
+
+        # check padding
+        assert [pad_before, pad_after] == list(channel_padding)
+
+        # check if relay contains 'concatenate' op
+        assert AreConcatenateOnGraph().are_concatenate_on_graph(ext_func.body)
+
+    pad_pattern_table = [
+        (
+            ethosu.ChannelPadParams.composite_name,
+            ethosu.pad_pattern(),
+            lambda pat: ethosu.ChannelPadParams(pat).is_valid(),
+        ),
+    ]
+
+    tflite_graph = create_tflite_graph()
+    tflite_model = tflite.Model.Model.GetRootAsModel(tflite_graph, 0)
+
+    mod, params = relay.frontend.from_tflite(
+        tflite_model,
+        shape_dict={"input": ifm_shape},
+        dtype_dict={"input": dtype},
+    )
+
+    mod["main"] = bind_params_by_name(mod["main"], params)
+    mod = partition_ethosu_by_table(mod, pad_pattern_table)
+
+    mod["tvmgen_default_ethos_u_main_0"] = dataflow_pattern.rewrite(
+        legalize.ChannelPadRewriter(), mod["tvmgen_default_ethos_u_main_0"]
+    )
+    verify(mod["tvmgen_default_ethos_u_main_0"], channel_padding)
+
+
 @pytest.mark.parametrize("pooling_type", ["MAX", "AVG"])
 @pytest.mark.parametrize("ifm_shape", [[1, 3, 4, 3], [1, 4, 5, 2]])
 @pytest.mark.parametrize(

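A note on the verify() navigation above: after the rewrite, the function body
is a concatenate over (optional before-pad identity, IFM identity, optional
after-pad identity) along axis 3, so ext_func.body.args[0] is the relay.Tuple
of concatenate inputs and [i].args[0] reaches each ethosu_identity's input;
the asserts recover the channel counts of the zero-point constants from their
checked types.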