gemini-code-assist[bot] commented on code in PR #19587:
URL: https://github.com/apache/tvm/pull/19587#discussion_r3263875310
##########
python/tvm/relax/frontend/tflite/tflite_frontend.py:
##########
@@ -1483,6 +1493,419 @@ def _get_stablehlo_options(self, op, options_cls):
result.Init(op_options.Bytes, op_options.Pos)
return result
+ def _get_static_tensor_shape(self, tensor, op_name):
+ """Return a statically-known TFLite tensor shape as Python ints."""
+ try:
+ return [int(dim) for dim in self.get_tensor_shape(tensor)]
+ except (TypeError, ValueError) as err:
+ raise tvm.error.OpNotImplemented(
+ f"{op_name} requires statically-known tensor shapes"
+ ) from err
+
+ def _get_stablehlo_i64_vector(self, vector, default):
+ """Convert an optional StableHLO int64 vector field to a Python int list."""
+ if vector is None or isinstance(vector, int):
+ return list(default)
+ return [int(v) for v in vector]
+
+ def _ensure_stablehlo_float_dtype(self, expr, op_name):
+ """Return expr dtype if the StableHLO subset supports it."""
+ dtype = expr.struct_info.dtype
+ if not dtype.startswith("float"):
+ raise tvm.error.OpNotImplemented(f"{op_name} with dtype {dtype} is not supported")
+ return dtype
+
+ def _convert_stablehlo_cbrt(self, op):
+ """Convert STABLEHLO_CBRT to a sign-preserving Relax expression."""
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 1, "input tensors length should be 1"
+ assert len(self.get_output_tensors(op)) == 1
+
+ data = self.get_tensor_expr(input_tensors[0])
+ dtype = self._ensure_stablehlo_float_dtype(data, "STABLEHLO_CBRT")
+ zero = relax.const(0, dtype)
+ exponent = relax.const(1.0 / 3.0, dtype)
+
+ is_negative = self.bb.normalize(relax.op.less(data, zero))
+ negative_base = self.bb.normalize(relax.op.negative(data))
+ negative_root = self.bb.normalize(relax.op.power(negative_base, exponent))
+ negative_result = self.bb.normalize(relax.op.negative(negative_root))
+ positive_result = self.bb.normalize(relax.op.power(data, exponent))
+ return self.bb.normalize(relax.op.where(is_negative, negative_result, positive_result))
Review Comment:

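The implementation of sign-preserving cube root can be simplified and made
more concise by using `relax.op.sign` and `relax.op.abs` instead of a `where`
condition. It needs four operators (`abs`, `power`, `sign`, `multiply`) instead
of six, and it avoids evaluating `power` on negative bases, whose NaN results
the `where` would otherwise have to discard, so it may also be more efficient
on some backends.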
```suggestion
abs_data = self.bb.normalize(relax.op.abs(data))
root = self.bb.normalize(relax.op.power(abs_data, exponent))
sign = self.bb.normalize(relax.op.sign(data))
return self.bb.normalize(relax.op.multiply(sign, root))
```
##########
python/tvm/relax/frontend/tflite/tflite_frontend.py:
##########
@@ -1483,6 +1493,419 @@ def _get_stablehlo_options(self, op, options_cls):
result.Init(op_options.Bytes, op_options.Pos)
return result
+ def _get_static_tensor_shape(self, tensor, op_name):
+ """Return a statically-known TFLite tensor shape as Python ints."""
+ try:
+ return [int(dim) for dim in self.get_tensor_shape(tensor)]
+ except (TypeError, ValueError) as err:
+ raise tvm.error.OpNotImplemented(
+ f"{op_name} requires statically-known tensor shapes"
+ ) from err
+
+ def _get_stablehlo_i64_vector(self, vector, default):
+ """Convert an optional StableHLO int64 vector field to a Python int list."""
+ if vector is None or isinstance(vector, int):
+ return list(default)
+ return [int(v) for v in vector]
+
+ def _ensure_stablehlo_float_dtype(self, expr, op_name):
+ """Return expr dtype if the StableHLO subset supports it."""
+ dtype = expr.struct_info.dtype
+ if not dtype.startswith("float"):
+ raise tvm.error.OpNotImplemented(f"{op_name} with dtype {dtype} is not supported")
+ return dtype
+
+ def _convert_stablehlo_cbrt(self, op):
+ """Convert STABLEHLO_CBRT to a sign-preserving Relax expression."""
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 1, "input tensors length should be 1"
+ assert len(self.get_output_tensors(op)) == 1
+
+ data = self.get_tensor_expr(input_tensors[0])
+ dtype = self._ensure_stablehlo_float_dtype(data, "STABLEHLO_CBRT")
+ zero = relax.const(0, dtype)
+ exponent = relax.const(1.0 / 3.0, dtype)
+
+ is_negative = self.bb.normalize(relax.op.less(data, zero))
+ negative_base = self.bb.normalize(relax.op.negative(data))
+ negative_root = self.bb.normalize(relax.op.power(negative_base, exponent))
+ negative_result = self.bb.normalize(relax.op.negative(negative_root))
+ positive_result = self.bb.normalize(relax.op.power(data, exponent))
+ return self.bb.normalize(relax.op.where(is_negative, negative_result, positive_result))
+
+ def _convert_stablehlo_remainder(self, op):
+ """Convert STABLEHLO_REMAINDER to truncating remainder for float tensors."""
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 2, "input tensors length should be 2"
+ assert len(self.get_output_tensors(op)) == 1
+
+ lhs = self.get_tensor_expr(input_tensors[0])
+ rhs = self.get_tensor_expr(input_tensors[1])
+ self._ensure_stablehlo_float_dtype(lhs, "STABLEHLO_REMAINDER")
+ self._ensure_stablehlo_float_dtype(rhs, "STABLEHLO_REMAINDER")
+
+ quotient = self.bb.normalize(relax.op.divide(lhs, rhs))
+ truncated = self.bb.normalize(relax.op.trunc(quotient))
+ product = self.bb.normalize(relax.op.multiply(rhs, truncated))
+ return self.bb.normalize(relax.op.subtract(lhs, product))
+
+ def _get_stablehlo_simple_body_op(self, body_subgraph_index, parent_op_name, input_count):
+ """Return the single operator from a simple StableHLO body subgraph."""
+ if body_subgraph_index <= 0 or body_subgraph_index >= self.model.SubgraphsLength():
+ raise tvm.error.OpNotImplemented(
+ f"{parent_op_name} requires a valid non-main body subgraph"
+ )
+
+ body_subgraph = self.model.Subgraphs(body_subgraph_index)
+ if (
+ body_subgraph.InputsLength() != input_count
+ or body_subgraph.OutputsLength() != 1
+ or body_subgraph.OperatorsLength() != 1
+ ):
+ raise tvm.error.OpNotImplemented(
+ f"{parent_op_name} only supports single-op body subgraphs"
+ )
+
+ return body_subgraph.Operators(0)
+
+ def _check_stablehlo_reduce_init(
+ self, init_tensor, reducer_name, parent_op_name="STABLEHLO_REDUCE"
+ ):
+ """Validate that the StableHLO reduce init value matches the Relax identity."""
+ if self.has_expr(init_tensor.tensor_idx):
+ raise tvm.error.OpNotImplemented(
+ f"{parent_op_name} with dynamic init values is not supported"
+ )
+
+ init_value = np.asarray(self.get_tensor_value(init_tensor))
+ if init_value.shape not in [(), (1,)]:
+ raise tvm.error.OpNotImplemented(f"{parent_op_name} requires scalar init values")
+
+ dtype = init_value.dtype
+ scalar = init_value.item()
+ if reducer_name == "STABLEHLO_ADD":
+ is_identity = bool(np.isclose(scalar, 0))
+ elif reducer_name == "STABLEHLO_MULTIPLY":
+ is_identity = bool(np.isclose(scalar, 1))
+ elif reducer_name == "STABLEHLO_MAXIMUM":
+ if np.issubdtype(dtype, np.floating):
+ is_identity = bool(np.isneginf(scalar))
+ elif np.issubdtype(dtype, np.integer):
+ is_identity = scalar == np.iinfo(dtype).min
+ else:
+ is_identity = False
+ elif reducer_name == "STABLEHLO_MINIMUM":
+ if np.issubdtype(dtype, np.floating):
+ is_identity = bool(np.isposinf(scalar))
+ elif np.issubdtype(dtype, np.integer):
+ is_identity = scalar == np.iinfo(dtype).max
+ else:
+ is_identity = False
+ else:
+ raise tvm.error.OpNotImplemented(
+ f"{parent_op_name} reducer {reducer_name} is not supported"
+ )
+
+ if not is_identity:
+ raise tvm.error.OpNotImplemented(
+ f"{parent_op_name} init value must match the reducer identity"
+ )
+
+ def _convert_stablehlo_reduce(self, op):
+ """Convert the single-input STABLEHLO_REDUCE subset to Relax reductions."""
+ from tflite.StablehloReduceOptions import StablehloReduceOptions
+
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 2, "input tensors length should be 2"
+ assert len(self.get_output_tensors(op)) == 1
+
+ opts = self._get_stablehlo_options(op, StablehloReduceOptions)
+ dimensions = self._get_stablehlo_i64_vector(opts.DimensionsAsNumpy(), [])
+ body_op = self._get_stablehlo_simple_body_op(
+ int(opts.BodySubgraphIndex()), "STABLEHLO_REDUCE", 2
+ )
+ reducer_name = self.get_op_code_str(body_op)
+
+ reducers = {
+ "STABLEHLO_ADD": relax.op.sum,
+ "STABLEHLO_MAXIMUM": relax.op.max,
+ "STABLEHLO_MINIMUM": relax.op.min,
+ "STABLEHLO_MULTIPLY": relax.op.prod,
+ }
+ if reducer_name not in reducers:
+ raise tvm.error.OpNotImplemented(
+ f"STABLEHLO_REDUCE reducer {reducer_name} is not supported"
+ )
+
+ self._check_stablehlo_reduce_init(input_tensors[1], reducer_name)
+ data = self.get_tensor_expr(input_tensors[0])
+ return self.bb.normalize(reducers[reducer_name](data, axis=dimensions, keepdims=False))
+
+ def _convert_stablehlo_reduce_window(self, op):
+ """Convert the NHWC 2D max-pool STABLEHLO_REDUCE_WINDOW subset."""
+ from tflite.StablehloReduceWindowOptions import StablehloReduceWindowOptions
+
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 2, "input tensors length should be 2"
+ assert len(self.get_output_tensors(op)) == 1
+
+ opts = self._get_stablehlo_options(op, StablehloReduceWindowOptions)
+ body_op = self._get_stablehlo_simple_body_op(
+ int(opts.BodySubgraphIndex()), "STABLEHLO_REDUCE_WINDOW", 2
+ )
+ reducer_name = self.get_op_code_str(body_op)
+ if reducer_name != "STABLEHLO_MAXIMUM":
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_REDUCE_WINDOW only supports MAXIMUM reducer windows"
+ )
+ self._check_stablehlo_reduce_init(
+ input_tensors[1], reducer_name, "STABLEHLO_REDUCE_WINDOW"
+ )
+
+ data_shape = self._get_static_tensor_shape(input_tensors[0], "STABLEHLO_REDUCE_WINDOW")
+ if len(data_shape) != 4:
+ raise tvm.error.OpNotImplemented("STABLEHLO_REDUCE_WINDOW only supports 4D input")
+
+ window_dimensions = self._get_stablehlo_i64_vector(opts.WindowDimensionsAsNumpy(), [])
+ window_strides = self._get_stablehlo_i64_vector(
+ opts.WindowStridesAsNumpy(), [1] * len(window_dimensions)
+ )
+ base_dilations = self._get_stablehlo_i64_vector(
+ opts.BaseDilationsAsNumpy(), [1] * len(window_dimensions)
+ )
+ window_dilations = self._get_stablehlo_i64_vector(
+ opts.WindowDilationsAsNumpy(), [1] * len(window_dimensions)
+ )
+ padding = self._get_stablehlo_i64_vector(
+ opts.PaddingAsNumpy(), [0] * (2 * len(window_dimensions))
+ )
+
+ if (
+ len(window_dimensions) != 4
+ or len(window_strides) != 4
+ or len(base_dilations) != 4
+ or len(window_dilations) != 4
+ or len(padding) != 8
+ ):
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_REDUCE_WINDOW only supports rank-4 window attributes"
+ )
+ if window_dimensions[0] != 1 or window_dimensions[3] != 1:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_REDUCE_WINDOW only supports pooling over spatial dimensions"
+ )
+ if window_strides[0] != 1 or window_strides[3] != 1:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_REDUCE_WINDOW only supports unit batch/channel strides"
+ )
+ if base_dilations != [1, 1, 1, 1]:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_REDUCE_WINDOW with base dilation is not supported"
+ )
+ if padding[0] != 0 or padding[1] != 0 or padding[6] != 0 or padding[7] != 0:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_REDUCE_WINDOW only supports spatial padding"
+ )
+
+ data = self.get_tensor_expr(input_tensors[0])
+ return self.bb.normalize(
+ relax.op.nn.max_pool2d(
+ data,
+ pool_size=[window_dimensions[1], window_dimensions[2]],
+ strides=[window_strides[1], window_strides[2]],
+ padding=[padding[2], padding[4], padding[3], padding[5]],
+ dilation=[window_dilations[1], window_dilations[2]],
+ layout="NHWC",
+ out_layout="NHWC",
+ )
+ )
+
+ def _convert_stablehlo_scatter(self, op):
+ """Convert the canonical point-update STABLEHLO_SCATTER subset."""
+ from tflite.StablehloScatterOptions import StablehloScatterOptions
+
+ input_tensors = self.get_input_tensors(op)
+ assert len(input_tensors) == 3, "input tensors length should be 3"
+ assert len(self.get_output_tensors(op)) == 1
+
+ opts = self._get_stablehlo_options(op, StablehloScatterOptions)
+ operand_shape = self._get_static_tensor_shape(input_tensors[0], "STABLEHLO_SCATTER")
+ indices_shape = self._get_static_tensor_shape(input_tensors[1], "STABLEHLO_SCATTER")
+ updates_shape = self._get_static_tensor_shape(input_tensors[2], "STABLEHLO_SCATTER")
+ operand_rank = len(operand_shape)
+ indices_rank = len(indices_shape)
+
+ update_window_dims = self._get_stablehlo_i64_vector(opts.UpdateWindowDimsAsNumpy(), [])
+ inserted_window_dims = self._get_stablehlo_i64_vector(
+ opts.InsertedWindowDimsAsNumpy(), []
+ )
+ scatter_dims_to_operand_dims = self._get_stablehlo_i64_vector(
+ opts.ScatterDimsToOperandDimsAsNumpy(), []
+ )
+ index_vector_dim = int(opts.IndexVectorDim())
+
+ if indices_rank == 0 or index_vector_dim != indices_rank - 1:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_SCATTER only supports trailing index-vector dimensions"
+ )
+ if update_window_dims:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_SCATTER only supports point updates without update windows"
+ )
+ if inserted_window_dims != list(range(operand_rank)):
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_SCATTER only supports point updates for every operand dimension"
+ )
+ if scatter_dims_to_operand_dims != list(range(operand_rank)):
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_SCATTER only supports canonical scatter-to-operand dimensions"
+ )
+ if indices_shape[-1] != operand_rank or updates_shape != indices_shape[:-1]:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_SCATTER requires point update shapes to match scatter indices"
+ )
+
+ body_op = self._get_stablehlo_simple_body_op(
+ int(opts.UpdateComputationSubgraphIndex()), "STABLEHLO_SCATTER", 2
+ )
+ reducer_name = self.get_op_code_str(body_op)
+ reductions = {
+ "STABLEHLO_ADD": "add",
+ "STABLEHLO_MAXIMUM": "max",
+ "STABLEHLO_MINIMUM": "min",
+ "STABLEHLO_MULTIPLY": "mul",
+ }
+ if reducer_name not in reductions:
+ raise tvm.error.OpNotImplemented(
+ f"STABLEHLO_SCATTER reducer {reducer_name} is not supported"
+ )
+
+ operand = self.get_tensor_expr(input_tensors[0])
+ indices = self.get_tensor_expr(input_tensors[1])
+ updates = self.get_tensor_expr(input_tensors[2])
+ return self.bb.normalize(
+ relax.op.scatter_nd(operand, indices, updates, reductions[reducer_name])
+ )
+
+ def _convert_stablehlo_composite(self, op):
+ """Convert STABLEHLO_COMPOSITE by inlining a simple decomposition subgraph."""
+ from tflite.StableHLOCompositeOptions import StableHLOCompositeOptions
+
+ input_tensors = self.get_input_tensors(op)
+ output_tensors = self.get_output_tensors(op)
+ if len(output_tensors) != 1:
+ raise tvm.error.OpNotImplemented(
+ "STABLEHLO_COMPOSITE only supports single-output decompositions"
+ )
+
+ opts = self._get_stablehlo_options(op, StableHLOCompositeOptions)
+ composite_name = opts.Name()
+ composite_name = (
+ composite_name.decode("utf-8") if composite_name is not None else "<unnamed>"
+ )
+ if opts.CompositeAttributesLength() != 0:
+ raise tvm.error.OpNotImplemented(
+ f"STABLEHLO_COMPOSITE {composite_name} with composite attributes is not supported"
+ )
+
+ decomposition_subgraph_index = int(opts.DecompositionSubgraphIndex())
+ if (
+ decomposition_subgraph_index <= 0
+ or decomposition_subgraph_index >= self.model.SubgraphsLength()
+ ):
+ raise tvm.error.OpNotImplemented(
+ f"STABLEHLO_COMPOSITE {composite_name} requires a valid decomposition subgraph"
+ )
+ decomposition_subgraph = self.model.Subgraphs(decomposition_subgraph_index)
+ if decomposition_subgraph.InputsLength() != len(input_tensors):
+ raise tvm.error.OpNotImplemented(
+ f"STABLEHLO_COMPOSITE {composite_name} decomposition input count mismatch"
+ )
+ if decomposition_subgraph.OutputsLength() != 1:
+ raise tvm.error.OpNotImplemented(
+ f"STABLEHLO_COMPOSITE {composite_name} only supports single-output decompositions"
+ )
+
+ decomposition_exp_tab = ExprTable()
+ decomposition_converter = OperatorConverter(
+ self.model, decomposition_subgraph, decomposition_exp_tab, self.bb
+ )
+ for decomposition_input_idx, composite_input in zip(
+ decomposition_subgraph.InputsAsNumpy(), input_tensors
+ ):
+ decomposition_input_name = get_tensor_name(
+ decomposition_subgraph, int(decomposition_input_idx)
+ )
+ decomposition_exp_tab.set_expr(
+ decomposition_input_name,
+ self.get_tensor_expr(composite_input),
+ force_override=True,
+ )
+
+ decomposition_converter.check_unsupported_ops()
+ decomposition_converter.convert_op_to_relax()
+ decomposition_output_idx = int(decomposition_subgraph.Outputs(0))
+ decomposition_output_tensor = decomposition_converter.get_tensors(
+ [decomposition_output_idx]
+ )[0]
+ for _, value in decomposition_exp_tab.params.values():
+ param_name = f"_param_{self.exp_tab.const_ctr}"
+ self.exp_tab.const_ctr += 1
+ self.exp_tab.params[param_name] = (relax.const(value), value)
Review Comment:

The loop to copy constants from the decomposition subgraph's `ExprTable` to
the parent's `ExprTable` re-creates `relax.const` objects. This is redundant
because the constants are already created and part of the graph through the
shared `BlockBuilder`. You can avoid creating duplicate constants by reusing
the existing `relax.Expr` from the `decomposition_exp_tab`.
```suggestion
for const_expr, value in decomposition_exp_tab.params.values():
param_name = f"_param_{self.exp_tab.const_ctr}"
self.exp_tab.const_ctr += 1
self.exp_tab.params[param_name] = (const_expr, value)
```
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]