sunggg commented on code in PR #15842:
URL: https://github.com/apache/tvm/pull/15842#discussion_r1364127634
##########
tests/python/relax/test_transform_legalize_ops.py:
##########
@@ -282,5 +284,77 @@ def main(A: R.Tensor([16, 32]), B: R.Tensor([32, 8])) ->
R.Tensor([16, 8]):
assert err_message.startswith("To legalize R.matmul")
+emit_legalization_through_builder = tvm.testing.parameter(
+ by_dict={
+ "return_relax_expr": False,
+ "return_relax_var": True,
+ }
+)
+
+
+@tvm.testing.fixture
+def custom_op(emit_legalization_through_builder):
+ op_name = "custom_op.matmul_bias_add"
+
+ def infer_struct_info(call: relax.Call, context):
+ activations, weight, bias = call.args
+
+ matmul_call = relax.op.matmul(activations, weight)
+ matmul_sinfo =
tvm.ir.Op.get("relax.matmul").get_attr("FInferStructInfo")(
+ matmul_call, context
+ )
+
+ matmul_var = relax.Var("dummy_var", matmul_sinfo)
+ add_call = matmul_var + bias
+ add_sinfo =
tvm.ir.Op.get("relax.add").get_attr("FInferStructInfo")(add_call, context)
+
+ return add_sinfo
+
+ def legalize(bb: relax.BlockBuilder, call: relax.Call):
Review Comment:
Interesting! I did not know there was a training-specific version of batch
norm. SGTM. Let's discuss it in the follow-up PR.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]