gemini-code-assist[bot] commented on code in PR #18857:
URL: https://github.com/apache/tvm/pull/18857#discussion_r2868359822


##########
python/tvm/script/parser/tir/parser.py:
##########
@@ -537,9 +537,23 @@ def visit_assert(self: Parser, node: doc.Assert) -> None:
 
     node : doc.Assert
         The doc AST assert node.
+
+    The assert message can be either:
+    - A plain string: ``assert cond, "message"``
+    - A tuple of (kind, [parts...]): ``assert cond, ("ValueError", ["part0", "part1"])``
     """
     cond = self.eval_expr(node.test)
     msg = self.eval_expr(node.msg)
+    if isinstance(msg, tuple) and len(msg) == 2:
+        kind_str, parts = msg
+        if isinstance(kind_str, tvm.tir.StringImm):
+            kind_str = kind_str.value
+        if isinstance(parts, list | tuple):
+            parts_str = [p.value if isinstance(p, tvm.tir.StringImm) else str(p) for p in parts]
+            frame = T.Assert(cond, parts_str, kind=str(kind_str))
+            frame.add_callback(partial(frame.__exit__, None, None, None))
+            frame.__enter__()
+            return
     frame = T.Assert(cond, msg)
     frame.add_callback(partial(frame.__exit__, None, None, None))
     frame.__enter__()

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   The logic for parsing `assert` statements with structured messages has a subtle bug. If the message is a tuple whose second element is not a list or tuple (e.g., `("ValueError", "a single string")`), the `kind` is not propagated. The code falls through to `T.Assert(cond, msg)`, which passes the whole tuple as the message and uses the default `kind="RuntimeError"`, silently discarding the `"ValueError"` from the message tuple.
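
A minimal pure-Python model of the branching in the diff makes the fall-through concrete. It is a sketch, not TVM code: `StringImm` is a hypothetical stand-in for `tvm.tir.StringImm` (only `.value` is modelled), `dispatch_as_in_diff` is a hypothetical helper that returns the `(message, kind)` pair the diff would hand to `T.Assert`, and the fall-through assumes `T.Assert` defaults to `kind="RuntimeError"` as described above.

```python
from dataclasses import dataclass


@dataclass
class StringImm:
    """Hypothetical stand-in for tvm.tir.StringImm; only `.value` is modelled."""
    value: str


def dispatch_as_in_diff(msg):
    """Return the (message, kind) pair the diff's branching would pass to T.Assert.

    Mirrors the control flow in the diff. The fall-through return assumes
    T.Assert's default kind is "RuntimeError".
    """
    if isinstance(msg, tuple) and len(msg) == 2:
        kind_str, parts = msg
        if isinstance(kind_str, StringImm):
            kind_str = kind_str.value
        if isinstance(parts, (list, tuple)):
            parts_str = [p.value if isinstance(p, StringImm) else str(p) for p in parts]
            return parts_str, str(kind_str)
    # Fall-through: the whole msg is passed unchanged and the default kind is used.
    return msg, "RuntimeError"


# Structured message with a list of parts: the kind is honoured.
print(dispatch_as_in_diff((StringImm("ValueError"), [StringImm("a"), StringImm("b")])))
# Structured message with a single string part: "ValueError" is lost.
print(dispatch_as_in_diff((StringImm("ValueError"), StringImm("a single string"))))
```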
   
   I suggest refactoring this logic to make it clearer and to handle all cases correctly. The refactored code below extracts the `kind` and `message` parts and passes them to `T.Assert`.
   
   ```python
       cond = self.eval_expr(node.test)
       msg = self.eval_expr(node.msg)
   
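       # Default: a plain message with the default error kind.
       kind = "RuntimeError"
       message = msg
   
       # Structured form (kind, parts): take the kind whatever the type of `parts`.
       if isinstance(msg, tuple) and len(msg) == 2:
           kind_str, parts = msg
           if isinstance(kind_str, tvm.tir.StringImm):
               kind_str = kind_str.value
           kind = str(kind_str)
           message = parts
   
       # Normalize a list/tuple of parts into plain Python strings.
       if isinstance(message, list | tuple):
           message = [p.value if isinstance(p, tvm.tir.StringImm) else str(p) for p in message]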
   
       frame = T.Assert(cond, message, kind=kind)
       frame.add_callback(partial(frame.__exit__, None, None, None))
       frame.__enter__()
   ```
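
The suggested control flow can be checked in isolation with the same kind of pure-Python sketch. Again, `StringImm` is a hypothetical stand-in for `tvm.tir.StringImm`, and `normalize_assert_msg` is a hypothetical helper that returns the `(kind, message)` pair the refactor would pass to `T.Assert`; the real parser calls `T.Assert` and enters the frame instead.

```python
from dataclasses import dataclass


@dataclass
class StringImm:
    """Hypothetical stand-in for tvm.tir.StringImm; only `.value` is modelled."""
    value: str


def normalize_assert_msg(msg):
    """Return the (kind, message) pair the suggested refactor would pass to T.Assert."""
    # Default: treat msg as a plain message with the default kind.
    kind = "RuntimeError"
    message = msg

    # Structured form (kind, parts): the kind is extracted whatever `parts` is.
    if isinstance(msg, tuple) and len(msg) == 2:
        kind_str, parts = msg
        if isinstance(kind_str, StringImm):
            kind_str = kind_str.value
        kind = str(kind_str)
        message = parts

    # Normalize a list/tuple of parts into plain Python strings.
    if isinstance(message, (list, tuple)):
        message = [p.value if isinstance(p, StringImm) else str(p) for p in message]

    return kind, message


print(normalize_assert_msg(StringImm("plain message")))
# ('RuntimeError', StringImm(value='plain message'))
print(normalize_assert_msg((StringImm("ValueError"), [StringImm("a"), 42])))
# ('ValueError', ['a', '42'])
print(normalize_assert_msg((StringImm("ValueError"), StringImm("a single string"))))
# ('ValueError', StringImm(value='a single string'))
```

One consequence of this design: any two-element tuple is read as `(kind, parts)`, so a message written as a bare pair of parts such as `("part0", "part1")` would have `"part0"` taken as the kind.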



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

