This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 5efa4b72dc [Test][TFLite] Add unit tests for `PRELU` (#19402)
5efa4b72dc is described below

commit 5efa4b72dca7c33b5e4f6658720ba2e15dae28be
Author: Felix Hirwa Nshuti <[email protected]>
AuthorDate: Tue Apr 14 07:55:38 2026 +0200

    [Test][TFLite] Add unit tests for `PRELU` (#19402)
    
    This PR adds unit test coverage for `PRELU` activation
    in the Relax TFLite frontend, as part of
    https://github.com/apache/tvm/issues/18971
    
    - Added a unit test for `PRELU`.
    - Enabled the converter to handle alpha broadcasting more cleanly across
      constant and expression-backed alpha inputs.
---
 .../tvm/relax/frontend/tflite/tflite_frontend.py   | 24 ++------
 tests/python/relax/test_frontend_tflite.py         | 70 ++++++++++++++++++++--
 2 files changed, 70 insertions(+), 24 deletions(-)

diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index 9c99e98e01..584a65e1f4 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -2259,7 +2259,7 @@ class OperatorConverter:
         # Create axes list for all dimensions being sliced
         axes = list(range(input_tensor_rank))
         begin = [int(v) for v in begin]
-        end   = [int(v) for v in end]
+        end = [int(v) for v in end]
         out = relax.op.strided_slice(in_expr, axes=axes, begin=begin, end=end)
         return out
 
@@ -2840,9 +2840,7 @@ class OperatorConverter:
             new_b_shape = [1] * max(0, rank_a - rank_b) + [int(s) for s in shape_b]
             max_rank = max(rank_a, rank_b)
 
-            batch_shape = [
-                max(new_a_shape[i], new_b_shape[i]) for i in range(max_rank - 2)
-            ]
+            batch_shape = [max(new_a_shape[i], new_b_shape[i]) for i in range(max_rank - 2)]
 
             a_broadcast = batch_shape + [int(shape_a[-2]), int(shape_a[-1])]
             b_broadcast = batch_shape + [int(shape_b[-2]), int(shape_b[-1])]
@@ -2987,21 +2985,11 @@ class OperatorConverter:
 
         input_tensor = input_tensors[0]
         alpha_tensor = input_tensors[1]
-        if self.has_expr(alpha_tensor.tensor_idx):
-            alpha_expr = self.get_expr(alpha_tensor.tensor_idx)
-        else:
-            alpha_tensor_type = alpha_tensor.tensor.Type()
-            alpha_tensor_type_str = self.get_tensor_type_str(alpha_tensor_type)
-            alpha_expr = self.exp_tab.new_const(
-                self.get_tensor_value(alpha_tensor),
-                dtype=alpha_tensor_type_str,
-                source_name=alpha_tensor.tensor.Name(),
-            )
-        in_expr = self.get_expr(input_tensor.tensor_idx)
         data_shape = to_int_list(self.get_tensor_shape(input_tensor))
-
-        alpha_expr = relax.op.broadcast_to(alpha_expr, data_shape)
-        alpha_expr = relax.op.reshape(alpha_expr, [-1])
+        alpha_expr = self.get_tensor_expr(alpha_tensor)
+        alpha_expr = self.bb.normalize(relax.op.broadcast_to(alpha_expr, data_shape))
+        alpha_expr = self.bb.normalize(relax.op.reshape(alpha_expr, [-1]))
+        in_expr = self.get_tensor_expr(input_tensor)
         out = relax.op.nn.prelu(_op.reshape(in_expr, [-1]), alpha_expr, axis=0)
         out = relax.op.reshape(out, data_shape)
         return out
diff --git a/tests/python/relax/test_frontend_tflite.py b/tests/python/relax/test_frontend_tflite.py
index 37a6b9cd93..bf6ef8e819 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -322,6 +322,7 @@ def test_tile(input_shape, multiples, dtype):
 
     verify(Tile)
 
+
 def test_concat_v2():
     class ConcatV2(tf.Module):
         @tf.function(input_signature=[tf.TensorSpec(shape=(1, 30), dtype=tf.float32)])
@@ -804,6 +805,7 @@ def test_transpose_conv():
 
     verify(TransposeConv)
 
+
 def test_l2_pool2d():
     class L2Pool2D(tf.Module):
         @tf.function(input_signature=[tf.TensorSpec(shape=(1, 8, 8, 2), dtype=tf.float32)])
@@ -815,9 +817,9 @@ def test_l2_pool2d():
     @I.ir_module
     class Expected:
         @R.function
-        def main(
-            data: R.Tensor((1, 8, 8, 2), dtype="float32")
-        ) -> R.Tensor((1, 8, 8, 2), dtype="float32"):
+        def main(data: R.Tensor((1, 8, 8, 2), dtype="float32")) -> R.Tensor(
+            (1, 8, 8, 2), dtype="float32"
+        ):
             R.func_attr({"num_input": 1})
             with R.dataflow():
                 squared = R.power(data, R.const(2.0, "float32"))
@@ -883,6 +885,7 @@ def test_reverse_v2():
 
     verify(ReverseV2, Expected)
 
+
 def _make_conv2d_module(data_shape, kernel_shape, data_format, strides, padding):
     class Conv2DModule(tf.Module):
         @tf.function(
@@ -1590,9 +1593,7 @@ _DETECTION_POSTPROCESS_SHAPE_CASES = [
     "build_kwargs,expected_topk_count,expected_keep_background",
     _DETECTION_POSTPROCESS_SMOKE_CASES,
 )
-def test_detection_postprocess_smoke(
-    build_kwargs, expected_topk_count, expected_keep_background
-):
+def test_detection_postprocess_smoke(build_kwargs, expected_topk_count, expected_keep_background):
     mod = _build_detection_postprocess_mod(**build_kwargs)
     ir = mod.script()
 
@@ -1649,6 +1650,7 @@ def test_detection_postprocess_shape_variations(build_kwargs):
         ),
     )
 
+
 def _make_resize_expected(
     input_shape, output_size, method, coordinate_transformation_mode, rounding_method
 ):
@@ -2109,5 +2111,61 @@ def test_relu_n1_to_1():
     verify(ReLU_N1_to_1, Expected)
 
 
+@pytest.mark.parametrize(
+    "shared_axes",
+    [
+        pytest.param([1, 2], id="channelwise_shared_axes"),
+        pytest.param([1, 2, 3], id="scalar_shared_axes"),
+        pytest.param(None, id="elementwise_no_shared_axes"),
+    ],
+)
+def test_prelu(shared_axes):
+    inputs = tf.keras.Input(shape=(4, 4, 3), batch_size=1, dtype=tf.float32)
+    prelu_kwargs = {
+        "alpha_initializer": tf.initializers.constant(0.25),
+    }
+    if shared_axes is not None:
+        prelu_kwargs["shared_axes"] = shared_axes
+    outputs = tf.keras.layers.PReLU(**prelu_kwargs)(inputs)
+    keras_model = tf.keras.Model(inputs, outputs)
+
+    converter = tf.lite.TFLiteConverter.from_keras_model(keras_model)
+    tflite_model_buf = converter.convert()
+    if hasattr(tflite.Model, "Model"):
+        tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
+    else:
+        tflite_model = tflite.Model.GetRootAsModel(tflite_model_buf, 0)
+
+    mod = from_tflite(tflite_model)
+    mod["main"] = mod["main"].without_attr("params")
+
+    if shared_axes == [1, 2]:
+        alpha_const = np.full((1, 1, 3), 0.25, dtype=np.float32)
+    elif shared_axes == [1, 2, 3]:
+        alpha_const = np.full((1, 1, 1), 0.25, dtype=np.float32)
+    else:
+        alpha_const = np.full((4, 4, 3), 0.25, dtype=np.float32)
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor((1, 4, 4, 3), dtype="float32")) -> R.Tensor(
+            (1, 4, 4, 3), dtype="float32"
+        ):
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                lv: R.Tensor((1, 4, 4, 3), dtype="float32") = R.broadcast_to(
+                    R.const(alpha_const), R.shape([1, 4, 4, 3])
+                )
+                lv1: R.Tensor((48,), dtype="float32") = R.reshape(x, R.shape([48]))
+                lv2: R.Tensor((48,), dtype="float32") = R.reshape(lv, R.shape([48]))
+                lv3: R.Tensor((48,), dtype="float32") = R.nn.prelu(lv1, lv2, axis=0)
+                gv: R.Tensor((1, 4, 4, 3), dtype="float32") = R.reshape(lv3, R.shape([1, 4, 4, 3]))
+                R.output(gv)
+            return gv
+
+    tvm.ir.assert_structural_equal(mod, Expected)
+
+
 if __name__ == "__main__":
     pytest.main(["-s", __file__])

Reply via email to