mshr-h commented on code in PR #17864:
URL: https://github.com/apache/tvm/pull/17864#discussion_r2052162278


##########
tests/python/relax/test_frontend_from_exported_program.py:
##########
@@ -4377,5 +4377,26 @@ def main(
     verify_model(Narrow(), example_args, {}, Expected)
 
 
+def test_eye():
+    class Eye(Module):
+        def forward(self, input):
+            return torch.eye(3, 5, dtype=torch.float32)

Review Comment:
   Need a test case where only `n` is given.
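For reference, the expected output of the `n`-only case can be sketched in plain Python. `eye_reference` is a hypothetical helper, not part of TVM or PyTorch. It only mirrors the documented `torch.eye(n, m=None)` behavior: when `m` is omitted, the result is a square `n x n` identity.

```python
def eye_reference(n, m=None):
    """Return an n x m matrix (nested lists of floats) with ones on the main diagonal.

    Hypothetical reference mirroring torch.eye semantics: if m is omitted,
    it defaults to n, so the result is square.
    """
    if m is None:
        m = n
    return [[1.0 if i == j else 0.0 for j in range(m)] for i in range(n)]


# Only `n` given: a square 3 x 3 identity is expected.
square = eye_reference(3)
# Both `n` and `m` given: 3 rows, 5 columns, ones on the main diagonal only.
rect = eye_reference(3, 5)
```

A test for the `n`-only case would therefore expect an output of shape `(n, n)`, unlike the existing `torch.eye(3, 5)` test, which expects `(3, 5)`.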



##########
python/tvm/relax/frontend/torch/base_fx_graph_translator.py:
##########
@@ -1416,6 +1416,19 @@ def _empty_like(self, node: fx.Node) -> relax.Var:
         x = self.env[node.args[0]]
         return self.block_builder.emit(relax.op.zeros_like(x))
 
+    def _eye(self, node: fx.Node) -> relax.Var:
+        args = self.retrieve_args(node)
+        if len(args) == 1:
+            n = args[0]
+            m = n
+        elif len(args) == 2:
+            n = args[0]
+            m = args[1]
+        else:

Review Comment:
   I don't think we need to handle this case, because the PyTorch exporter already does that.
   So we can assume that the graph in the exported program doesn't contain invalid nodes.
   Probably this is enough?
   ```python
   n = args[0]
   m = args[1] if len(args) > 1 else n
   ```
   
   cc @Hzfengsy 
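A standalone sketch of the suggested two-liner, with its assumption made explicit. `eye_shape_from_args` is hypothetical and stands in for the body of `_eye` right after `self.retrieve_args(node)`. It assumes, as argued above, that the exported graph only ever passes one positional argument (`n`) or two (`n`, `m`) to `eye`, so no `else` branch is needed.

```python
from typing import Sequence, Tuple


def eye_shape_from_args(args: Sequence[int]) -> Tuple[int, int]:
    """Hypothetical helper mirroring the suggested `_eye` argument handling.

    Assumes `args` holds the positional arguments of an already-exported
    eye call: either (n,) or (n, m). Other arities are assumed to be
    rejected by the PyTorch exporter, so they are not checked here.
    """
    n = args[0]
    # When only `n` is given, the matrix is square, so `m` defaults to `n`.
    m = args[1] if len(args) > 1 else n
    return n, m
```

Both call shapes collapse into one expression: `eye_shape_from_args([3])` yields `(3, 3)` and `eye_shape_from_args([3, 5])` yields `(3, 5)`, which is the same behavior as the original `if`/`elif` branches.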
   


