This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a393b47563 [Relax] Add FInferMixedPrecision and FRelaxInferLayout for conv transpose ops (#18629)
a393b47563 is described below
commit a393b4756368d26db927085cb1de028b567e78c0
Author: Guan-Ming (Wesley) Chiu <[email protected]>
AuthorDate: Sat Jan 3 13:26:20 2026 +0800
[Relax] Add FInferMixedPrecision and FRelaxInferLayout for conv transpose ops (#18629)
## Why
The `conv1d_transpose` and `conv2d_transpose` operators were missing FInferMixedPrecision and FRelaxInferLayout attribute implementations, which are needed for:
- Mixed precision training/inference support (e.g., float16 inputs with float32 outputs)
- Layout transformation optimizations during compilation
- Consistency with the conv1d and conv2d operators, which already have these attributes
## How
- Implemented InferLayoutConv1dTranspose and InferMixedPrecisionConv1dTranspose
- Implemented InferLayoutConv2dTranspose and InferMixedPrecisionConv2dTranspose (both hooks are exercised in the sketch below)
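For reference, a hedged sketch of how these attributes get exercised through the standard `tvm.relax.transform` passes. The module, shapes, and the "NHWC"/"OHWI" layout strings below are illustrative assumptions, not part of this patch; `ConvertLayout` consults FRelaxInferLayout, while `ToMixedPrecision` consults TMixedPrecisionPolicy and FInferMixedPrecision.

```python
# Hedged sketch, not part of this patch: module, shapes, and desired
# layouts are illustrative assumptions.
from tvm import relax
from tvm.script import ir_module
from tvm.script import relax as R


@ir_module
class Before:
    @R.function
    def main(
        x: R.Tensor((2, 3, 28, 28), "float32"),
        w: R.Tensor((3, 4, 3, 3), "float32"),
    ) -> R.Tensor((2, 4, 30, 30), "float32"):
        with R.dataflow():
            gv = R.nn.conv2d_transpose(x, w)
            R.output(gv)
        return gv


# FRelaxInferLayout drives this rewrite; "NHWC"/"OHWI" are illustrative
# permutations of the default "NCHW"/"IOHW" layouts.
mod_layout = relax.transform.ConvertLayout(
    {"relax.nn.conv2d_transpose": ["NHWC", "OHWI"]}
)(Before)

# TMixedPrecisionPolicy/FInferMixedPrecision drive this rewrite: the fp32
# convolution is recomputed with float16 inputs while accumulating in
# float32 via out_dtype.
mod_amp = relax.transform.ToMixedPrecision(out_dtype="float32")(Before)
```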
---
src/relax/op/nn/convolution.cc | 137 ++++++++++++++++++++++++++-
tests/python/relax/test_op_nn_convolution.py | 38 ++++++++
2 files changed, 171 insertions(+), 4 deletions(-)
diff --git a/src/relax/op/nn/convolution.cc b/src/relax/op/nn/convolution.cc
index 49e92719ba..ca09c0f1cb 100644
--- a/src/relax/op/nn/convolution.cc
+++ b/src/relax/op/nn/convolution.cc
@@ -707,14 +707,62 @@ StructInfo InferStructInfoConv1dTranspose(const Call& call, const BlockBuilder&
return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice);
}
-// TODO(relax-team): implement FInferMixedPrecision and FRelaxInferLayout for conv1d_transpose
-// and unit test for mixed_precision
+InferLayoutOutput InferLayoutConv1dTranspose(
+    const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
+    const VarLayoutMap& var_layout_map) {
+  const auto* attrs = call->attrs.as<Conv1DTransposeAttrs>();
+  LayoutDecision data_layout, weight_layout, output_layout;
+  ObjectPtr<Conv1DTransposeAttrs> new_attrs = ffi::make_object<Conv1DTransposeAttrs>(*attrs);
+
+ auto it = desired_layouts.find("relax.nn.conv1d_transpose");
+ if (it != desired_layouts.end()) {
+ Layout desired_data_layout = (*it).second[0];
+ Layout desired_weight_layout = (*it).second[1];
+    Layout desired_output_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
+    ICHECK_EQ(desired_data_layout.ndim(), desired_data_layout.ndim_primal()) << "Axis swap only";
+    ICHECK_EQ(desired_weight_layout.ndim(), desired_weight_layout.ndim_primal())
+        << "Axis swap only";
+    ICHECK_EQ(desired_output_layout.ndim(), desired_output_layout.ndim_primal())
+        << "Axis swap only";
+    data_layout = TransposeLike(InitialLayout(3), attrs->data_layout, desired_data_layout);
+    weight_layout = TransposeLike(InitialLayout(3), attrs->kernel_layout, desired_weight_layout);
+    output_layout = TransposeLike(InitialLayout(3), attrs->out_layout, desired_output_layout);
+    new_attrs->data_layout = (*it).second[0];
+    new_attrs->kernel_layout = (*it).second[1];
+    new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
+ } else {
+ data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+ weight_layout = GetLayoutDecision(var_layout_map, call->args[1]);
+ output_layout = data_layout;
+ new_attrs->data_layout =
+        TransposeLike(attrs->data_layout, InitialLayout(3), data_layout->layout).name();
+    new_attrs->kernel_layout =
+        TransposeLike(attrs->kernel_layout, InitialLayout(3), weight_layout->layout).name();
+    new_attrs->out_layout =
+        TransposeLike(attrs->out_layout, InitialLayout(3), output_layout->layout).name();
+  }
+  return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs));
+}
+
+Call InferMixedPrecisionConv1dTranspose(const Call& call, const DataType& out_dtype) {
+  const auto* conv1d_transpose_attrs = call->attrs.as<Conv1DTransposeAttrs>();
+  return Downcast<Call>(
+      conv1d_transpose(call->args[0], call->args[1], conv1d_transpose_attrs->strides,
+                       conv1d_transpose_attrs->padding, conv1d_transpose_attrs->output_padding,
+                       conv1d_transpose_attrs->dilation, conv1d_transpose_attrs->groups,
+                       conv1d_transpose_attrs->data_layout, conv1d_transpose_attrs->kernel_layout,
+                       conv1d_transpose_attrs->out_layout, out_dtype));
+}
+
TVM_REGISTER_OP("relax.nn.conv1d_transpose")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_attrs_type<Conv1DTransposeAttrs>()
    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoConv1dTranspose)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutConv1dTranspose)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways)
+    .set_attr<FInferMixedPrecision>("FInferMixedPrecision", InferMixedPrecisionConv1dTranspose)
.set_attr<Bool>("FPurity", Bool(true));
/* relax.nn.conv2d_transpose */
@@ -857,14 +905,95 @@ StructInfo InferStructInfoConv2dTranspose(const Call& call, const BlockBuilder&
return TensorStructInfo(ShapeExpr(out_shape), out_dtype, vdevice);
}
-// TODO(relax-team): implement FInferMixedPrecision and FRelaxInferLayout for conv2d_transpose
-// and unit test for mixed_precision
+InferLayoutOutput InferLayoutConv2dTranspose(
+    const Call& call, const ffi::Map<ffi::String, ffi::Array<ffi::String>>& desired_layouts,
+    const VarLayoutMap& var_layout_map) {
+  const auto* attrs = call->attrs.as<Conv2DTransposeAttrs>();
+  LayoutDecision data_layout = GetLayoutDecision(var_layout_map, call->args[0]);
+  LayoutDecision weight_layout = GetLayoutDecision(var_layout_map, call->args[1]);
+  LayoutDecision output_layout;
+  ObjectPtr<Conv2DTransposeAttrs> new_attrs = ffi::make_object<Conv2DTransposeAttrs>(*attrs);
+
+ auto it = desired_layouts.find("relax.nn.conv2d_transpose");
+ if (it != desired_layouts.end()) {
+ Layout desired_data_layout = (*it).second[0];
+ Layout desired_weight_layout = (*it).second[1];
+    Layout desired_output_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
+
+ Layout input_layout = Layout(attrs->data_layout);
+ Layout kernel_layout = Layout(attrs->kernel_layout);
+ Layout out_layout = Layout(attrs->out_layout);
+
+ if (desired_data_layout.ndim_primal() == input_layout.ndim() &&
+ desired_weight_layout.ndim_primal() == kernel_layout.ndim() &&
+ desired_output_layout.ndim_primal() == out_layout.ndim()) {
+      data_layout = TransposeLike(InitialLayout(4), attrs->data_layout, desired_data_layout);
+      weight_layout = TransposeLike(InitialLayout(4), attrs->kernel_layout, desired_weight_layout);
+      output_layout = TransposeLike(InitialLayout(4), attrs->out_layout, desired_output_layout);
+      new_attrs->data_layout = (*it).second[0];
+      new_attrs->kernel_layout = (*it).second[1];
+      new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
+      return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs));
+ } else {
+ auto data_si = GetStructInfo(call->args[0]);
+ auto kernel_si = GetStructInfo(call->args[1]);
+ TensorStructInfo data_sinfo = data_si.as<TensorStructInfo>().value();
+ TensorStructInfo kernel_sinfo = kernel_si.as<TensorStructInfo>().value();
+ ffi::Optional<ShapeExpr> data_shape =
+ ffi::GetRef<ShapeExpr>(data_sinfo->shape.as<ShapeExprNode>());
+ ffi::Optional<ShapeExpr> kernel_shape =
+ ffi::GetRef<ShapeExpr>(kernel_sinfo->shape.as<ShapeExprNode>());
+
+ bool can_data_proved =
+          CanProveLayoutTransform(input_layout, desired_data_layout, data_shape.value()->values);
+      bool can_kernel_proved = CanProveLayoutTransform(kernel_layout, desired_weight_layout,
+                                                       kernel_shape.value()->values);
+
+ if (can_data_proved && can_kernel_proved) {
+        data_layout = TransposeSubLayoutLike(InitialLayout(4), input_layout, desired_data_layout);
+        weight_layout =
+            TransposeSubLayoutLike(InitialLayout(4), kernel_layout, desired_weight_layout);
+        output_layout = TransposeSubLayoutLike(InitialLayout(4), out_layout, desired_output_layout);
+        new_attrs->data_layout = (*it).second[0];
+        new_attrs->kernel_layout = (*it).second[1];
+        new_attrs->out_layout = (*it).second.size() == 3 ? (*it).second[2] : (*it).second[0];
+        return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs));
+ } else {
+ data_layout = LayoutDecision(InitialLayout(4));
+ weight_layout = LayoutDecision(InitialLayout(4));
+ }
+ }
+ }
+
+ output_layout = data_layout;
+ new_attrs->data_layout =
+      TransposeLike(attrs->data_layout, InitialLayout(4), data_layout->layout).name();
+  new_attrs->kernel_layout =
+      TransposeLike(attrs->kernel_layout, InitialLayout(4), weight_layout->layout).name();
+  new_attrs->out_layout =
+      TransposeLike(attrs->out_layout, InitialLayout(4), output_layout->layout).name();
+  return InferLayoutOutput({data_layout, weight_layout}, {output_layout}, Attrs(new_attrs));
+}
+
+Call InferMixedPrecisionConv2dTranspose(const Call& call, const DataType& out_dtype) {
+  const auto* conv2d_transpose_attrs = call->attrs.as<Conv2DTransposeAttrs>();
+  return Downcast<Call>(
+      conv2d_transpose(call->args[0], call->args[1], conv2d_transpose_attrs->strides,
+                       conv2d_transpose_attrs->padding, conv2d_transpose_attrs->output_padding,
+                       conv2d_transpose_attrs->dilation, conv2d_transpose_attrs->groups,
+                       conv2d_transpose_attrs->data_layout, conv2d_transpose_attrs->kernel_layout,
+                       conv2d_transpose_attrs->out_layout, out_dtype));
+}
+
TVM_REGISTER_OP("relax.nn.conv2d_transpose")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("weight", "Tensor", "The weight tensor.")
.set_attrs_type<Conv2DTransposeAttrs>()
    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoConv2dTranspose)
+    .set_attr<FRelaxInferLayout>("FRelaxInferLayout", InferLayoutConv2dTranspose)
+    .set_attr<TMixedPrecisionPolicy>("TMixedPrecisionPolicy", MixedPrecisionPolicyKind::kAlways)
+    .set_attr<FInferMixedPrecision>("FInferMixedPrecision", InferMixedPrecisionConv2dTranspose)
.set_attr<Bool>("FPurity", Bool(true));
} // namespace relax
diff --git a/tests/python/relax/test_op_nn_convolution.py b/tests/python/relax/test_op_nn_convolution.py
index 588dc9b1b1..9b913138df 100644
--- a/tests/python/relax/test_op_nn_convolution.py
+++ b/tests/python/relax/test_op_nn_convolution.py
@@ -782,6 +782,25 @@ def test_conv1d_transpose_infer_struct_info_wrong_input_type():
bb.normalize(relax.op.nn.conv1d_transpose(x1, w0))
+def test_conv1d_transpose_infer_struct_info_mixed_precision():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 28), "float16"))
+ w0 = relax.Var("w", R.Tensor((3, 4, 3), "float16"))
+ x1 = relax.Var("x", R.Tensor((2, 3, 28), "int8"))
+ w1 = relax.Var("w", R.Tensor((3, 4, 3), "int8"))
+
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d_transpose(x0, w0, out_dtype="float32"),
+ relax.TensorStructInfo((2, 4, 30), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv1d_transpose(x1, w1, out_dtype="int32"),
+ relax.TensorStructInfo((2, 4, 30), "int32"),
+ )
+
+
def test_conv2d_infer_struct_info():
bb = relax.BlockBuilder()
vdev0 = VDevice("llvm")
@@ -1571,6 +1590,25 @@ def test_conv2d_transpose_infer_struct_info_wrong_input_type():
bb.normalize(relax.op.nn.conv2d_transpose(x1, w0))
+def test_conv2d_transpose_infer_struct_info_mixed_precision():
+ bb = relax.BlockBuilder()
+ x0 = relax.Var("x", R.Tensor((2, 3, 28, 28), "float16"))
+ w0 = relax.Var("w", R.Tensor((3, 4, 3, 3), "float16"))
+ x1 = relax.Var("x", R.Tensor((2, 3, 28, 28), "int8"))
+ w1 = relax.Var("w", R.Tensor((3, 4, 3, 3), "int8"))
+
+ _check_inference(
+ bb,
+ relax.op.nn.conv2d_transpose(x0, w0, out_dtype="float32"),
+ relax.TensorStructInfo((2, 4, 30, 30), "float32"),
+ )
+ _check_inference(
+ bb,
+ relax.op.nn.conv2d_transpose(x1, w1, out_dtype="int32"),
+ relax.TensorStructInfo((2, 4, 30, 30), "int32"),
+ )
+
+
def test_conv3d_infer_struct_info():
bb = relax.BlockBuilder()
vdev0 = VDevice("llvm")