Lunderberg commented on code in PR #15842:
URL: https://github.com/apache/tvm/pull/15842#discussion_r1362092455


##########
tests/python/relax/test_transform_legalize_ops.py:
##########
@@ -282,5 +284,77 @@ def main(A: R.Tensor([16, 32]), B: R.Tensor([32, 8])) -> R.Tensor([16, 8]):
     assert err_message.startswith("To legalize R.matmul")
 
 
+emit_legalization_through_builder = tvm.testing.parameter(
+    by_dict={
+        "return_relax_expr": False,
+        "return_relax_var": True,
+    }
+)
+
+
[email protected]
+def custom_op(emit_legalization_through_builder):
+    op_name = "custom_op.matmul_bias_add"
+
+    def infer_struct_info(call: relax.Call, context):
+        activations, weight, bias = call.args
+
+        matmul_call = relax.op.matmul(activations, weight)
+        matmul_sinfo = tvm.ir.Op.get("relax.matmul").get_attr("FInferStructInfo")(
+            matmul_call, context
+        )
+
+        matmul_var = relax.Var("dummy_var", matmul_sinfo)
+        add_call = matmul_var + bias
+        add_sinfo = tvm.ir.Op.get("relax.add").get_attr("FInferStructInfo")(add_call, context)
+
+        return add_sinfo
+
+    def legalize(bb: relax.BlockBuilder, call: relax.Call):

Review Comment:
   Taking a look, the `DecomposeOpsFor*` passes currently serve two distinct 
roles.  The first role is to lower the `relax.nn.batch_norm`, 
`relax.nn.layer_norm`, and `relax.tensor_to_shape` operators into lower-level 
Relax implementations.  The second role is to mutate the `relax.nn.batch_norm` 
operator into a training-specific version.
   
   I think the first role, lowering complex Relax operators into simpler Relax 
operators, will be supported by the partial lowering intended for `LegalizeOps`. 
 The second role is independent of legalization, and would be best kept as 
a standalone pass.  That second role would also become much simpler: the 
`relax.nn.batch_norm(data, gamma, beta, prev_mean, prev_var)` call could be 
rewritten to `relax.nn.batch_norm(data, gamma, beta, weighted_avg(mean(data), 
prev_mean), weighted_avg(var(data), prev_var))`, rather than needing a full 
definition of `relax.nn.batch_norm`.
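   
   To illustrate the numerics of that rewrite, here is a minimal NumPy sketch of 
the running-statistics update.  The `weighted_avg` helper and the `momentum` 
parameter are hypothetical names for illustration, not part of the TVM API; the 
actual pass would express the same update in Relax operators.
   
   ```python
   import numpy as np
   
   def weighted_avg(new, prev, momentum=0.1):
       # Exponential moving average: blend the batch statistic into the
       # running statistic.  `momentum` is an illustrative default.
       return momentum * new + (1 - momentum) * prev
   
   def batch_norm_running_stats(data, prev_mean, prev_var, momentum=0.1):
       # Per-feature statistics of the current batch (axis 0 = batch dim).
       batch_mean = data.mean(axis=0)
       batch_var = data.var(axis=0)
       # The training-mode rewrite feeds these blended values back into
       # the inference-mode batch_norm, instead of redefining the op.
       new_mean = weighted_avg(batch_mean, prev_mean, momentum)
       new_var = weighted_avg(batch_var, prev_var, momentum)
       return new_mean, new_var
   ```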
   
   Though, those are probably changes that would be best for a follow-up PR.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]
