sunggg commented on code in PR #15842:
URL: https://github.com/apache/tvm/pull/15842#discussion_r1358446279


##########
tests/python/relax/test_transform_legalize_ops.py:
##########
@@ -282,5 +284,77 @@ def main(A: R.Tensor([16, 32]), B: R.Tensor([32, 8])) -> R.Tensor([16, 8]):
     assert err_message.startswith("To legalize R.matmul")
 
 
+emit_legalization_through_builder = tvm.testing.parameter(
+    by_dict={
+        "return_relax_expr": False,
+        "return_relax_var": True,
+    }
+)
+
+
[email protected]
+def custom_op(emit_legalization_through_builder):
+    op_name = "custom_op.matmul_bias_add"
+
+    def infer_struct_info(call: relax.Call, context):
+        activations, weight, bias = call.args
+
+        matmul_call = relax.op.matmul(activations, weight)
+        matmul_sinfo = tvm.ir.Op.get("relax.matmul").get_attr("FInferStructInfo")(
+            matmul_call, context
+        )
+
+        matmul_var = relax.Var("dummy_var", matmul_sinfo)
+        add_call = matmul_var + bias
+        add_sinfo = tvm.ir.Op.get("relax.add").get_attr("FInferStructInfo")(add_call, context)
+
+        return add_sinfo
+
+    def legalize(bb: relax.BlockBuilder, call: relax.Call):
+        activations, weight, bias = call.args
+        legalized = relax.op.matmul(activations, weight) + bias
+        if emit_legalization_through_builder:
+            legalized = bb.emit(legalized)
+        return legalized
+
+    op_attrs = {
+        "FInferStructInfo": infer_struct_info,
+        "FLegalize": legalize,
+        "FPurity": True,
+    }
+
+    for key, value in op_attrs.items():
+        tvm.ir.register_op_attr(op_name, key, value)
+
+    op = tvm.ir.Op.get(op_name)
+    yield op
+
+    for key in op_attrs:
+        op.reset_attr(key)
+
+
+def test_recursive_legalization(custom_op):
+    """Legalization of an operator may produce new operators requiring legalization"""
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            A: R.Tensor([16, 32, 64], "float32"),
+            Weight: R.Tensor([64, 128], "float32"),
+            Bias: R.Tensor([16, 32, 128], "float32"),
+        ):
+            return relax.Call(custom_op, [A, Weight, Bias])
+
+    AfterFirstIter = LegalizeOps()(Before)

Review Comment:
   Does the user need to run the `LegalizeOps` pass multiple times depending on their custom
ops? For example, would the user need to call it twice?



##########
tests/python/relax/test_transform_legalize_ops.py:
##########
@@ -282,5 +284,77 @@ def main(A: R.Tensor([16, 32]), B: R.Tensor([32, 8])) -> R.Tensor([16, 8]):
     assert err_message.startswith("To legalize R.matmul")
 
 
+emit_legalization_through_builder = tvm.testing.parameter(
+    by_dict={
+        "return_relax_expr": False,
+        "return_relax_var": True,
+    }
+)
+
+
[email protected]
+def custom_op(emit_legalization_through_builder):
+    op_name = "custom_op.matmul_bias_add"
+
+    def infer_struct_info(call: relax.Call, context):
+        activations, weight, bias = call.args
+
+        matmul_call = relax.op.matmul(activations, weight)
+        matmul_sinfo = tvm.ir.Op.get("relax.matmul").get_attr("FInferStructInfo")(
+            matmul_call, context
+        )
+
+        matmul_var = relax.Var("dummy_var", matmul_sinfo)
+        add_call = matmul_var + bias
+        add_sinfo = tvm.ir.Op.get("relax.add").get_attr("FInferStructInfo")(add_call, context)
+
+        return add_sinfo
+
+    def legalize(bb: relax.BlockBuilder, call: relax.Call):

Review Comment:
   We have a similar pass, although it does not support any recursion: 
https://github.com/apache/tvm/blob/unity/python/tvm/relax/transform/transform.py#L994
   
   Is there an actual use-case for recursion, or is it more of a future-proofing measure?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to