This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4ab312bf12 [Relax][Frontend][TFLite] Add Conv3D support (#19523)
4ab312bf12 is described below
commit 4ab312bf12cbe1f1da87c7a9687287c938503aab
Author: Wei-Cheng Hsu <[email protected]>
AuthorDate: Sat May 9 12:17:31 2026 +0800
[Relax][Frontend][TFLite] Add Conv3D support (#19523)
Description
This PR adds support for the CONV_3D operator in the TFLite frontend for
Relax.
Key Changes
- Operator Mapping: Added CONV_3D to the OperatorConverter mapping in
tflite_frontend.py.
- Implementation:
- Implemented convert_conv3d to handle 3D convolution attributes such as
StrideD/H/W, DilationD/H/W, and Padding.
- Correctly handled the TFLite 3D kernel layout, which is expected to be
DHWIO (Depth, Height, Width, Input Channels, Output Channels).
- Integrated support for fused activation functions (ReLU, ReLU6, etc.)
directly following the convolution.
- Unit Tests:
- Added comprehensive tests in
tests/python/relax/test_frontend_tflite.py covering:
- VALID and SAME padding modes.
- Various stride and dilation configurations.
- Verification against expected Relax IR structure.
Testing:
- `python3 -m pytest tests/python/relax/test_frontend_tflite.py -k "test_conv3d"`
Notes for Reviewers
The implementation follows the existing pattern used for CONV_2D but
extends it to the 5D case (NDHWC layout). I've ensured that the kernel
layout mapping aligns with TVM's R.nn.conv3d requirements.
Related to: https://github.com/apache/tvm/issues/19519
---
.../tvm/relax/frontend/tflite/tflite_frontend.py | 137 +++++++++++++++++++++
tests/python/relax/test_frontend_tflite.py | 83 +++++++++++++
2 files changed, 220 insertions(+)
diff --git a/python/tvm/relax/frontend/tflite/tflite_frontend.py
b/python/tvm/relax/frontend/tflite/tflite_frontend.py
index f5b88b0c6a..d70f5d837e 100644
--- a/python/tvm/relax/frontend/tflite/tflite_frontend.py
+++ b/python/tvm/relax/frontend/tflite/tflite_frontend.py
@@ -132,6 +132,7 @@ class OperatorConverter:
"CEIL": functools.partial(self._convert_unary_elemwise,
relax_op=_op.ceil),
"CONCATENATION": self.convert_concatenation,
"CONV_2D": functools.partial(self.convert_conv,
conv_type="conv2d"),
+ "CONV_3D": self.convert_conv3d,
"COS": functools.partial(self._convert_unary_elemwise,
relax_op=_op.cos),
"CUMSUM": self.convert_cumsum,
"DENSIFY": self.convert_densify,
@@ -2449,6 +2450,142 @@ class OperatorConverter:
out = self.convert_fused_activation_function(out,
fused_activation_fn)
return out
+ def convert_conv3d(self, op):
+     """Convert a TFLite CONV_3D operator into a Relax ``nn.conv3d`` call.
+
+     The data tensor is consumed as NDHWC and the kernel as DHWIO, so neither
+     operand is transposed. When a third input with a valid tensor index is
+     present it is added as bias, and the fused activation is applied last.
+
+     Parameters
+     ----------
+     op : object
+         The TFLite operator to convert. Its builtin options must be
+         ``Conv3DOptions``; it needs at least two inputs (data, weight) and
+         exactly one output.
+
+     Returns
+     -------
+     out : object
+         Whatever ``convert_fused_activation_function`` returns for the
+         (optionally biased) convolution result.
+
+     Raises
+     ------
+     tvm.error.OpNotImplemented
+         If the input or output tensor carries quantization parameters.
+     tvm.error.OpAttributeUnImplemented
+         If the padding mode is neither ``VALID`` nor ``SAME``.
+     """
+
+     from tflite.BuiltinOptions import BuiltinOptions
+     from tflite.Conv3DOptions import Conv3DOptions
+     from tflite.Padding import Padding
+     from tflite.TensorType import TensorType
+
+     input_tensors = self.get_input_tensors(op)
+     assert len(input_tensors) >= 2, "input tensors length should be >= 2"
+     input_tensor = input_tensors[0]
+     weight_tensor = input_tensors[1]
+
+     output_tensors = self.get_output_tensors(op)
+     assert len(output_tensors) == 1, "output tensors length should be 1"
+     output_tensor = output_tensors[0]
+
+     # Fail fast on quantized tensors, before any constant is created through
+     # self.exp_tab.
+     if input_tensor.qnn_params or output_tensor.qnn_params:
+         raise tvm.error.OpNotImplemented(
+             "Quantized Conv3D is not yet supported in the Relax frontend."
+         )
+
+     assert op.BuiltinOptionsType() == BuiltinOptions.Conv3DOptions
+     op_options = op.BuiltinOptions()
+     conv3d_options = Conv3DOptions()
+     conv3d_options.Init(op_options.Bytes, op_options.Pos)
+
+     stride_d = conv3d_options.StrideD()
+     stride_h = conv3d_options.StrideH()
+     stride_w = conv3d_options.StrideW()
+     dilation_d = conv3d_options.DilationDFactor()
+     dilation_h = conv3d_options.DilationHFactor()
+     dilation_w = conv3d_options.DilationWFactor()
+     padding = conv3d_options.Padding()
+     fused_activation_fn = conv3d_options.FusedActivationFunction()
+
+     _, input_d, input_h, input_w, input_c = to_int_list(self.get_tensor_shape(input_tensor))
+     # TFLite Conv3D kernel layout is already DHWIO:
+     #   KD KH KW IC OC
+     kernel_d, kernel_h, kernel_w, in_channels, _ = to_int_list(
+         self.get_tensor_shape(weight_tensor)
+     )
+
+     # Effective kernel extent along each axis once dilation is applied.
+     dilated_kernel_d = dilation_d * (kernel_d - 1) + 1
+     dilated_kernel_h = dilation_h * (kernel_h - 1) + 1
+     dilated_kernel_w = dilation_w * (kernel_w - 1) + 1
+
+     params = {
+         "strides": [stride_d, stride_h, stride_w],
+         "dilation": [dilation_d, dilation_h, dilation_w],
+         "padding": [0, 0, 0, 0, 0, 0],
+         "data_layout": "NDHWC",
+         "kernel_layout": "DHWIO",
+     }
+
+     if input_c != in_channels:
+         assert input_c % in_channels == 0, (
+             "Input channels is not divisible by kernel in_channels."
+         )
+         # Integer division: the assert above guarantees an exact quotient.
+         params["groups"] = input_c // in_channels
+
+     # weight tensor type should be INT8/UINT8 (quantization) or FLOAT32
+     weight_tensor_type = weight_tensor.tensor.Type()
+     assert weight_tensor_type in (
+         TensorType.INT8,
+         TensorType.UINT8,
+         TensorType.FLOAT32,
+     )
+     weight_tensor_type_str = self.get_tensor_type_str(weight_tensor_type)
+
+     in_expr = self.get_expr(input_tensor.tensor_idx)
+
+     # TFLite Conv3D kernel is already in DHWIO layout, no transpose needed.
+     if self.has_expr(weight_tensor.tensor_idx):
+         weight_expr = self.get_expr(weight_tensor.tensor_idx)
+     else:
+         if self.is_prefetched(weight_tensor.tensor_idx):
+             weight_value = self.get_prefetched_node(weight_tensor.tensor_idx)
+         else:
+             weight_value = self.get_tensor_value(weight_tensor)
+
+         weight_expr = self.exp_tab.new_const(
+             weight_value,
+             dtype=weight_tensor_type_str,
+             source_name=weight_tensor.tensor.Name(),
+         )
+
+     if padding == Padding.SAME:
+         pad_front, pad_back = get_pad_value(input_d, dilated_kernel_d, stride_d)
+         pad_top, pad_bottom = get_pad_value(input_h, dilated_kernel_h, stride_h)
+         pad_left, pad_right = get_pad_value(input_w, dilated_kernel_w, stride_w)
+         # Order: (front, top, left, back, bottom, right). All-zero pads equal
+         # the default set above, so no separate "do we pad" flag is needed.
+         params["padding"] = [pad_front, pad_top, pad_left, pad_back, pad_bottom, pad_right]
+     elif padding != Padding.VALID:
+         raise tvm.error.OpAttributeUnImplemented(
+             f"Padding format {padding} is not supported for operator Conv3D."
+         )
+
+     out = relax.op.nn.conv3d(in_expr, weight_expr, **params)
+
+     # if we have bias
+     if len(input_tensors) == 3:
+         bias_tensor = input_tensors[2]
+         if bias_tensor.tensor_idx != -1:
+             bias_tensor_type = bias_tensor.tensor.Type()
+             # bias tensor type should be INT32 (int8 qnn) or INT64 (int16 qnn) or FLOAT32
+             assert bias_tensor_type in (TensorType.INT32, TensorType.INT64, TensorType.FLOAT32)
+             bias_tensor_type_str = self.get_tensor_type_str(bias_tensor_type)
+             if self.has_expr(bias_tensor.tensor_idx):
+                 bias_expr = self.get_expr(bias_tensor.tensor_idx)
+             else:
+                 bias_expr = self.exp_tab.new_const(
+                     self.get_tensor_value(bias_tensor),
+                     dtype=bias_tensor_type_str,
+                     source_name=bias_tensor.tensor.Name(),
+                 )
+             out = relax.op.add(out, bias_expr)
+
+     # Handle fused activation.
+     out = self.convert_fused_activation_function(out, fused_activation_fn)
+     return out
+
def convert_split(self, op):
"""split implementation."""
diff --git a/tests/python/relax/test_frontend_tflite.py
b/tests/python/relax/test_frontend_tflite.py
index e4c237887e..d0401e4649 100644
--- a/tests/python/relax/test_frontend_tflite.py
+++ b/tests/python/relax/test_frontend_tflite.py
@@ -1611,6 +1611,89 @@ def test_conv2d_valid():
verify(Conv2DModule, Expected)
+def _make_conv3d_module(data_shape, kernel_shape, strides, padding):
+    """Return a ``tf.Module`` subclass whose ``func`` applies ``tf.nn.conv3d``.
+
+    ``func`` takes float32 ``data`` and ``kernel`` tensors of the given shapes
+    and forwards ``strides`` and ``padding`` to ``tf.nn.conv3d`` unchanged.
+    """
+    signature = [
+        tf.TensorSpec(shape=data_shape, dtype=tf.float32),
+        tf.TensorSpec(shape=kernel_shape, dtype=tf.float32),
+    ]
+
+    class Conv3DModule(tf.Module):
+        @tf.function(input_signature=signature)
+        def func(self, data, kernel):
+            return tf.nn.conv3d(input=data, filters=kernel, strides=strides, padding=padding)
+
+    return Conv3DModule
+
+
+def test_conv3d_valid():
+ # VALID padding with unit strides: an 8x8x8 volume convolved with a 3x3x3
+ # kernel is expected to shrink to 6x6x6, with all six pad values zero.
+ Conv3DModule = _make_conv3d_module(
+ (1, 8, 8, 8, 3), (3, 3, 3, 3, 16), (1, 1, 1, 1, 1), "VALID"
+ )
+
+ # Expected Relax module; it is handed to verify() together with the
+ # TensorFlow module built above.
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ data: R.Tensor((1, 8, 8, 8, 3), dtype="float32"),
+ kernel: R.Tensor((3, 3, 3, 3, 16), dtype="float32"),
+ ) -> R.Tensor((1, 6, 6, 6, 16), dtype="float32"):
+ R.func_attr({"num_input": 2})
+ with R.dataflow():
+ gv: R.Tensor((1, 6, 6, 6, 16), dtype="float32") = R.nn.conv3d(
+ data,
+ kernel,
+ strides=[1, 1, 1],
+ padding=[0, 0, 0, 0, 0, 0],
+ dilation=[1, 1, 1],
+ groups=1,
+ data_layout="NDHWC",
+ kernel_layout="DHWIO",
+ out_layout="NDHWC",
+ out_dtype="void",
+ )
+ R.output(gv)
+ return gv
+
+ verify(Conv3DModule, Expected)
+
+
+def test_conv3d_same():
+ # SAME padding with unit strides: the 8x8x8 spatial extent is expected to be
+ # preserved, with a pad of 1 on every side for the 3x3x3 kernel.
+ Conv3DModule = _make_conv3d_module(
+ (1, 8, 8, 8, 3), (3, 3, 3, 3, 16), (1, 1, 1, 1, 1), "SAME"
+ )
+
+ # Expected Relax module; it is handed to verify() together with the
+ # TensorFlow module built above.
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ data: R.Tensor((1, 8, 8, 8, 3), dtype="float32"),
+ kernel: R.Tensor((3, 3, 3, 3, 16), dtype="float32"),
+ ) -> R.Tensor((1, 8, 8, 8, 16), dtype="float32"):
+ R.func_attr({"num_input": 2})
+ with R.dataflow():
+ gv: R.Tensor((1, 8, 8, 8, 16), dtype="float32") = R.nn.conv3d(
+ data,
+ kernel,
+ strides=[1, 1, 1],
+ padding=[1, 1, 1, 1, 1, 1],
+ dilation=[1, 1, 1],
+ groups=1,
+ data_layout="NDHWC",
+ kernel_layout="DHWIO",
+ out_layout="NDHWC",
+ out_dtype="void",
+ )
+ R.output(gv)
+ return gv
+
+ verify(Conv3DModule, Expected)
+
+
def _make_pool2d_module(pool, data_shape, ksize, data_format, strides,
padding):
class Pool2DModule(tf.Module):
@tf.function(