Aharrypotter commented on code in PR #19587:
URL: https://github.com/apache/tvm/pull/19587#discussion_r3264880012


##########
python/tvm/relax/frontend/tflite/tflite_frontend.py:
##########
@@ -1483,6 +1493,419 @@ def _get_stablehlo_options(self, op, options_cls):
         result.Init(op_options.Bytes, op_options.Pos)
         return result
 
+    def _get_static_tensor_shape(self, tensor, op_name):
+        """Return a statically-known TFLite tensor shape as Python ints."""
+        try:
+            return [int(dim) for dim in self.get_tensor_shape(tensor)]
+        except (TypeError, ValueError) as err:
+            raise tvm.error.OpNotImplemented(
+                f"{op_name} requires statically-known tensor shapes"
+            ) from err
+
+    def _get_stablehlo_i64_vector(self, vector, default):
+        """Convert an optional StableHLO int64 vector field to a Python int 
list."""
+        if vector is None or isinstance(vector, int):
+            return list(default)
+        return [int(v) for v in vector]
+
+    def _ensure_stablehlo_float_dtype(self, expr, op_name):
+        """Return expr dtype if the StableHLO subset supports it."""
+        dtype = expr.struct_info.dtype
+        if not dtype.startswith("float"):
+            raise tvm.error.OpNotImplemented(f"{op_name} with dtype {dtype} is 
not supported")
+        return dtype
+
+    def _convert_stablehlo_cbrt(self, op):
+        """Convert STABLEHLO_CBRT to a sign-preserving Relax expression."""
+        input_tensors = self.get_input_tensors(op)
+        assert len(input_tensors) == 1, "input tensors length should be 1"
+        assert len(self.get_output_tensors(op)) == 1
+
+        data = self.get_tensor_expr(input_tensors[0])
+        dtype = self._ensure_stablehlo_float_dtype(data, "STABLEHLO_CBRT")
+        zero = relax.const(0, dtype)
+        exponent = relax.const(1.0 / 3.0, dtype)
+
+        is_negative = self.bb.normalize(relax.op.less(data, zero))
+        negative_base = self.bb.normalize(relax.op.negative(data))
+        negative_root = self.bb.normalize(relax.op.power(negative_base, 
exponent))
+        negative_result = self.bb.normalize(relax.op.negative(negative_root))
+        positive_result = self.bb.normalize(relax.op.power(data, exponent))
+        return self.bb.normalize(relax.op.where(is_negative, negative_result, 
positive_result))

Review Comment:
   Thanks for the suggestion. The `sign(x) * power(abs(x), 1 / 3)` formulation 
is more compact, but I prefer to keep the current expansion in this PR.
   
   The main reason is semantic conservatism: StableHLO `cbrt` should behave 
like a real cube root, including edge cases such as signed zero and NaN 
propagation. Using `relax.op.sign` would make this converter depend on Relax 
`sign` behavior for those edge cases, and I do not want to introduce that 
semantic dependency in the same PR that adds the StableHLO op coverage.
   
   The current implementation keeps the negative-input handling explicit and 
local to this converter. We can revisit the shorter 
`sign(x) * power(abs(x), 1 / 3)` formulation in a follow-up once signed-zero 
behavior is validated against StableHLO/XLA `cbrt` semantics.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to