gemini-code-assist[bot] commented on code in PR #19580:
URL: https://github.com/apache/tvm/pull/19580#discussion_r3254240964


##########
python/tvm/relax/backend/xnnpack.py:
##########
@@ -0,0 +1,2194 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Pattern table for the XNNPACK Relax backend."""
+
+from collections.abc import Callable
+
+import numpy as np
+import tvm
+from tvm.ir import IRModule
+from tvm import relax
+from tvm.relax.dpl.pattern import is_const, is_op, wildcard
+from tvm.relax.transform import FuseOpsByPattern, FusionPattern, PatternCheckContext
+
+from .pattern_registry import get_patterns_with_prefix, register_patterns
+from .utils import has_leaking_intermediate_variables
+
+_SUPPORTED_PRECISIONS = ("fp32", "fp16_hint", "fp16_force")
+_SUPPORTED_PARTITION_POLICIES = ("greedy", "cost", "debug_all_supported")
+_SUPPORTED_LAYOUT_POLICIES = ("auto", "NHWC", "preserve")
+_SUPPORTED_QUANTIZATIONS = ("none", "dynamic_range")
+_SUPPORTED_DYNAMIC_SHAPE_POLICIES = ("none", "batch_only")
+_XNN_EXTRA_BYTES = 16
+_DTYPE_BYTES = {"float16": 2, "float32": 4, "int8": 1, "uint8": 1, "int32": 4}
+_QPARAM_SCALE_RTOL = 1e-5
+_QPARAM_SCALE_ATOL = 1e-8
+
+
+def _get_static_shape(expr: relax.Expr) -> list[int] | None:
+    sinfo = expr.struct_info
+    if not isinstance(sinfo, relax.TensorStructInfo):
+        return None
+    if sinfo.shape is None or not hasattr(sinfo.shape, "values"):
+        return None
+
+    shape = []
+    for dim in sinfo.shape.values:
+        if not isinstance(dim, (tvm.tirx.expr.IntImm, int)):

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   The submodule `tvm.tirx` does not exist in the standard TVM library. This 
appears to be a typo for `tvm.tir`.
   
   ```suggestion
           if not isinstance(dim, (tvm.tir.IntImm, int)):
   ```



##########
python/tvm/relax/backend/xnnpack.py:
##########
@@ -0,0 +1,2194 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Pattern table for the XNNPACK Relax backend."""
+
+from collections.abc import Callable
+
+import numpy as np
+import tvm
+from tvm.ir import IRModule
+from tvm import relax
+from tvm.relax.dpl.pattern import is_const, is_op, wildcard
+from tvm.relax.transform import FuseOpsByPattern, FusionPattern, PatternCheckContext
+
+from .pattern_registry import get_patterns_with_prefix, register_patterns
+from .utils import has_leaking_intermediate_variables
+
+_SUPPORTED_PRECISIONS = ("fp32", "fp16_hint", "fp16_force")
+_SUPPORTED_PARTITION_POLICIES = ("greedy", "cost", "debug_all_supported")
+_SUPPORTED_LAYOUT_POLICIES = ("auto", "NHWC", "preserve")
+_SUPPORTED_QUANTIZATIONS = ("none", "dynamic_range")
+_SUPPORTED_DYNAMIC_SHAPE_POLICIES = ("none", "batch_only")
+_XNN_EXTRA_BYTES = 16
+_DTYPE_BYTES = {"float16": 2, "float32": 4, "int8": 1, "uint8": 1, "int32": 4}
+_QPARAM_SCALE_RTOL = 1e-5
+_QPARAM_SCALE_ATOL = 1e-8
+
+
+def _get_static_shape(expr: relax.Expr) -> list[int] | None:
+    sinfo = expr.struct_info
+    if not isinstance(sinfo, relax.TensorStructInfo):
+        return None
+    if sinfo.shape is None or not hasattr(sinfo.shape, "values"):
+        return None
+
+    shape = []
+    for dim in sinfo.shape.values:
+        if not isinstance(dim, (tvm.tirx.expr.IntImm, int)):
+            return None
+        dim = int(dim)
+        if dim <= 0:
+            return None
+        shape.append(dim)
+    return shape
+
+
+def _shape_dims(expr: relax.Expr) -> list[object] | None:
+    sinfo = expr.struct_info
+    if not isinstance(sinfo, relax.TensorStructInfo):
+        return None
+    if sinfo.shape is None or not hasattr(sinfo.shape, "values"):
+        return None
+    return list(sinfo.shape.values)
+
+
+def _symbol_name(dim) -> str | None:
+    if isinstance(dim, (tvm.tirx.expr.IntImm, int)):

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   The submodule `tvm.tirx` does not exist in the standard TVM library. This 
appears to be a typo for `tvm.tir`.
   
   ```suggestion
       if isinstance(dim, (tvm.tir.IntImm, int)):
   ```



##########
python/tvm/relax/backend/xnnpack.py:
##########
@@ -0,0 +1,2194 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Pattern table for the XNNPACK Relax backend."""
+
+from collections.abc import Callable
+
+import numpy as np
+import tvm
+from tvm.ir import IRModule
+from tvm import relax
+from tvm.relax.dpl.pattern import is_const, is_op, wildcard
+from tvm.relax.transform import FuseOpsByPattern, FusionPattern, PatternCheckContext
+
+from .pattern_registry import get_patterns_with_prefix, register_patterns
+from .utils import has_leaking_intermediate_variables
+
+_SUPPORTED_PRECISIONS = ("fp32", "fp16_hint", "fp16_force")
+_SUPPORTED_PARTITION_POLICIES = ("greedy", "cost", "debug_all_supported")
+_SUPPORTED_LAYOUT_POLICIES = ("auto", "NHWC", "preserve")
+_SUPPORTED_QUANTIZATIONS = ("none", "dynamic_range")
+_SUPPORTED_DYNAMIC_SHAPE_POLICIES = ("none", "batch_only")
+_XNN_EXTRA_BYTES = 16
+_DTYPE_BYTES = {"float16": 2, "float32": 4, "int8": 1, "uint8": 1, "int32": 4}
+_QPARAM_SCALE_RTOL = 1e-5
+_QPARAM_SCALE_ATOL = 1e-8
+
+
+def _get_static_shape(expr: relax.Expr) -> list[int] | None:
+    sinfo = expr.struct_info
+    if not isinstance(sinfo, relax.TensorStructInfo):
+        return None
+    if sinfo.shape is None or not hasattr(sinfo.shape, "values"):
+        return None
+
+    shape = []
+    for dim in sinfo.shape.values:
+        if not isinstance(dim, (tvm.tirx.expr.IntImm, int)):
+            return None
+        dim = int(dim)
+        if dim <= 0:
+            return None
+        shape.append(dim)
+    return shape
+
+
+def _shape_dims(expr: relax.Expr) -> list[object] | None:
+    sinfo = expr.struct_info
+    if not isinstance(sinfo, relax.TensorStructInfo):
+        return None
+    if sinfo.shape is None or not hasattr(sinfo.shape, "values"):
+        return None
+    return list(sinfo.shape.values)
+
+
+def _symbol_name(dim) -> str | None:
+    if isinstance(dim, (tvm.tirx.expr.IntImm, int)):
+        return None
+    if hasattr(dim, "name"):
+        return str(dim.name)
+    if hasattr(dim, "name_hint"):
+        return str(dim.name_hint)
+    text = str(dim)
+    return text if text else None
+
+
+def _get_batch_only_shape(expr: relax.Expr) -> tuple[str, list[int | None]] | None:
+    dims = _shape_dims(expr)
+    if dims is None or len(dims) == 0:
+        return None
+    result: list[int | None] = []
+    symbol: str | None = None
+    for index, dim in enumerate(dims):
+        if isinstance(dim, (tvm.tirx.expr.IntImm, int)):
+            value = int(dim)
+            if value <= 0:
+                return None
+            result.append(value)
+            continue
+        name = _symbol_name(dim)
+        if index != 0 or name is None:
+            return None
+        symbol = name
+        result.append(None)
+    if symbol is None:
+        return None
+    return symbol, result
+
+
+def _same_batch_only_shape(lhs: relax.Expr, rhs: relax.Expr) -> bool:
+    lhs_info = _get_batch_only_shape(lhs)
+    rhs_info = _get_batch_only_shape(rhs)
+    return lhs_info is not None and lhs_info == rhs_info
+
+
+def _batch_bounds_from_attrs(func: relax.Function) -> dict[str, tuple[int, int]]:
+    result: dict[str, tuple[int, int]] = {}
+    if not func.attrs:
+        return result
+    upper = func.attrs.get("tir_var_upper_bound")
+    lower = func.attrs.get("tir_var_lower_bound")
+    if upper is None:
+        return result
+    for key, value in upper.items():
+        upper_value = _as_bound_int(value)
+        lower_value = 1
+        if lower is not None and key in lower:
+            lower_value = _as_bound_int(lower[key])
+        result[str(key)] = (lower_value, upper_value)
+    return result
+
+
+def _as_bound_int(value) -> int:
+    if hasattr(value, "value"):
+        return int(value.value)
+    return int(value)
+
+
+def _normalize_dynamic_batch_bounds(
+    mod: IRModule, dynamic_batch_bounds
+) -> dict[str, tuple[int, int]]:
+    result: dict[str, tuple[int, int]] = {}
+    for func in mod.functions.values():
+        if isinstance(func, relax.Function):
+            result.update(_batch_bounds_from_attrs(func))
+    if dynamic_batch_bounds:
+        for key, value in dynamic_batch_bounds.items():
+            if isinstance(value, tuple):
+                lower, upper = value
+            elif isinstance(value, list):
+                if len(value) != 2:
+                    raise ValueError("XNNPACK dynamic_batch_bounds list values must have 2 items.")
+                lower, upper = value
+            else:
+                lower, upper = 1, value
+            result[str(key)] = (int(lower), int(upper))
+    for symbol, (lower, upper) in result.items():
+        if lower <= 0 or upper < lower:
+            raise ValueError(
+                f"Invalid XNNPACK dynamic batch bounds for {symbol!r}: "
+                f"expected 0 < lower <= upper, got ({lower}, {upper})."
+            )
+    return result
+
+
+def _is_float32_tensor(expr: relax.Expr) -> bool:
+    sinfo = expr.struct_info
+    return isinstance(sinfo, relax.TensorStructInfo) and sinfo.dtype == "float32"
+
+
+def _is_static_float32(expr: relax.Expr) -> bool:
+    return _is_float32_tensor(expr) and _get_static_shape(expr) is not None
+
+
+def _tensor_dtype(expr: relax.Expr) -> str | None:
+    sinfo = expr.struct_info
+    if isinstance(sinfo, relax.TensorStructInfo):
+        return str(sinfo.dtype)
+    return None
+
+
+def _num_elements(expr: relax.Expr) -> int | None:
+    shape = _get_static_shape(expr)
+    if shape is None:
+        return None
+    result = 1
+    for dim in shape:
+        result *= dim
+    return result
+
+
+def _tensor_nbytes(expr: relax.Expr) -> int:
+    num_elements = _num_elements(expr)
+    dtype = _tensor_dtype(expr)
+    if num_elements is None or dtype not in _DTYPE_BYTES:
+        return 0
+    return num_elements * _DTYPE_BYTES[dtype]
+
+
+def _const_numpy(expr: relax.Expr) -> np.ndarray | None:
+    if not isinstance(expr, relax.Constant):
+        return None
+    return expr.data.numpy()
+
+
+def _const_scalar_float(expr: relax.Expr) -> float | None:
+    arr = _const_numpy(expr)
+    if arr is None or arr.size != 1:
+        return None
+    value = float(arr.reshape(-1)[0])
+    if not np.isfinite(value):
+        return None
+    return value
+
+
+def _const_int_array(expr: relax.Expr) -> np.ndarray | None:
+    arr = _const_numpy(expr)
+    if arr is None:
+        return None
+    if not np.issubdtype(arr.dtype, np.integer):
+        return None
+    return arr.astype("int64")
+
+
+def _const_float_array(expr: relax.Expr) -> np.ndarray | None:
+    arr = _const_numpy(expr)
+    if arr is None:
+        return None
+    if not np.issubdtype(arr.dtype, np.floating):
+        return None
+    arr = arr.astype("float64")
+    if not np.all(np.isfinite(arr)):
+        return None
+    return arr
+
+
+def _const_scalar_int(expr: relax.Expr) -> int | None:
+    arr = _const_int_array(expr)
+    if arr is None or arr.size != 1:
+        return None
+    return int(arr.reshape(-1)[0])
+
+
+def _same_static_shape(lhs: relax.Expr, rhs: relax.Expr) -> bool:
+    lhs_shape = _get_static_shape(lhs)
+    rhs_shape = _get_static_shape(rhs)
+    return lhs_shape is not None and lhs_shape == rhs_shape
+
+
+def _is_external_input(expr: relax.Expr) -> bool:
+    return not isinstance(expr, relax.Constant)
+
+
+def _as_float_prim_value(expr: relax.Expr) -> float | None:
+    if not isinstance(expr, relax.PrimValue):
+        return None
+    value = expr.value
+    if isinstance(value, tvm.tirx.expr.FloatImm):

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   The submodule `tvm.tirx` does not exist in the standard TVM library. This 
appears to be a typo for `tvm.tir`.
   
   ```suggestion
       if isinstance(value, tvm.tir.FloatImm):
   ```



##########
src/relax/backend/contrib/xnnpack/codegen.cc:
##########
@@ -0,0 +1,997 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/backend/contrib/xnnpack/codegen.cc
+ * \brief Minimal XNNPACK Relax external codegen.
+ */
+
+#include <tvm/ffi/cast.h>
+#include <tvm/ffi/extra/module.h>
+#include <tvm/ffi/reflection/registry.h>
+#include <tvm/ir/module.h>
+#include <tvm/relax/attrs/qdq.h>
+#include <tvm/relax/attrs/nn.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/struct_info.h>
+
+#include <limits>
+#include <optional>
+#include <sstream>
+#include <string>
+#include <unordered_set>
+#include <vector>
+
+#include "../codegen_json/codegen_json.h"
+#include "../utils.h"
+
+namespace tvm {
+namespace relax {
+namespace contrib {
+
+using JSONGraphNode = tvm::runtime::json::JSONGraphNode;
+using JSONGraphObjectPtr = backend::contrib::JSONGraphObjectPtr;
+using JSONSerializer = backend::contrib::JSONSerializer;
+using backend::contrib::NodeEntries;
+
+struct XNNPACKRuntimeOptions {
+  bool use_weights_cache{false};
+  bool use_workspace{false};
+  bool profile{false};
+  bool dont_spin_workers{false};
+  bool transient_indirection_buffer{false};
+  int64_t num_threads{1};
+  std::string precision{"fp32"};
+  std::string dynamic_shape_policy{"none"};
+  std::string dynamic_batch_symbol{""};
+  int64_t dynamic_batch_lower{1};
+  int64_t dynamic_batch_upper{-1};
+
+  std::string Serialize() const {
+    std::ostringstream os;
+    os << "use_weights_cache=" << (use_weights_cache ? 1 : 0) << ";";
+    os << "use_workspace=" << (use_workspace ? 1 : 0) << ";";
+    os << "profile=" << (profile ? 1 : 0) << ";";
+    os << "dont_spin_workers=" << (dont_spin_workers ? 1 : 0) << ";";
+    os << "transient_indirection_buffer=" << (transient_indirection_buffer ? 1 : 0) << ";";
+    os << "num_threads=" << num_threads << ";";
+    os << "precision=" << precision << ";";
+    os << "dynamic_shape_policy=" << dynamic_shape_policy << ";";
+    os << "dynamic_batch_symbol=" << dynamic_batch_symbol << ";";
+    os << "dynamic_batch_lower=" << dynamic_batch_lower << ";";
+    os << "dynamic_batch_upper=" << dynamic_batch_upper << ";";
+    return os.str();
+  }
+};
+
+bool GetBoolOption(const ffi::Map<ffi::String, ffi::Any>& options, const std::string& key,
+                   bool default_value) {
+  auto it = options.find(key);
+  if (it == options.end()) return default_value;
+  const ffi::Any& value = (*it).second;
+  if (auto opt_bool = value.try_cast<bool>()) return opt_bool.value();
+  if (auto opt_int = value.try_cast<int64_t>()) return opt_int.value() != 0;
+  TVM_FFI_THROW(ValueError) << "XNNPACK RunCodegen option '" << key << "' must be a boolean value.";
+}
+
+int64_t GetIntOption(const ffi::Map<ffi::String, ffi::Any>& options, const std::string& key,
+                     int64_t default_value) {
+  auto it = options.find(key);
+  if (it == options.end()) return default_value;
+  const ffi::Any& value = (*it).second;
+  if (auto opt_int = value.try_cast<int64_t>()) return opt_int.value();
+  TVM_FFI_THROW(ValueError) << "XNNPACK RunCodegen option '" << key
+                            << "' must be an integer value.";
+}
+
+ffi::Optional<ffi::String> GetStringOption(const ffi::Map<ffi::String, ffi::Any>& options,
+                                           const std::string& key) {
+  auto it = options.find(key);
+  if (it == options.end()) return std::nullopt;
+  const ffi::Any& value = (*it).second;
+  if (auto opt_string = value.try_cast<ffi::String>()) return opt_string.value();
+  TVM_FFI_THROW(ValueError) << "XNNPACK RunCodegen option '" << key << "' must be a string value.";
+}
+
+void ValidatePrecision(const std::string& precision) {
+  static const std::unordered_set<std::string> supported = {"fp32", "fp16_hint", "fp16_force"};
+  TVM_FFI_ICHECK(supported.count(precision)) << "Unsupported XNNPACK precision: " << precision;
+}
+
+int64_t GetIntAttr(const Function& func, const std::string& key, int64_t default_value) {
+  auto value = func->GetAttr<IntImm>(key);
+  return value ? value.value()->value : default_value;
+}
+
+XNNPACKRuntimeOptions ParseRuntimeOptions(const ffi::Map<ffi::String, ffi::Any>& options,
+                                          const ffi::Optional<ffi::String>& annotated_precision) {
+  static const std::unordered_set<std::string> supported = {
+      "use_weights_cache",
+      "use_workspace",
+      "profile",
+      "dont_spin_workers",
+      "transient_indirection_buffer",
+      "num_threads",
+      "precision",
+  };
+  for (const auto& kv : options) {
+    const std::string key = kv.first;
+    TVM_FFI_ICHECK(supported.count(key)) << "Unsupported XNNPACK RunCodegen option: " << key;
+  }
+
+  XNNPACKRuntimeOptions parsed;
+  parsed.use_weights_cache = GetBoolOption(options, "use_weights_cache", false);
+  parsed.use_workspace = GetBoolOption(options, "use_workspace", false);
+  parsed.profile = GetBoolOption(options, "profile", false);
+  parsed.dont_spin_workers = GetBoolOption(options, "dont_spin_workers", false);
+  parsed.transient_indirection_buffer =
+      GetBoolOption(options, "transient_indirection_buffer", false);
+  parsed.num_threads = GetIntOption(options, "num_threads", 1);
+  if (annotated_precision.has_value()) {
+    parsed.precision = annotated_precision.value();
+  }
+  if (auto option_precision = GetStringOption(options, "precision")) {
+    ValidatePrecision(option_precision.value());
+    if (annotated_precision.has_value()) {
+      TVM_FFI_ICHECK_EQ(std::string(annotated_precision.value()),
+                        std::string(option_precision.value()))
+          << "XNNPACK precision from partition_for_xnnpack and RunCodegen options must match.";
+    }
+    parsed.precision = option_precision.value();
+  }
+  ValidatePrecision(parsed.precision);
+  parsed.dynamic_shape_policy = "none";
+  TVM_FFI_ICHECK_GE(parsed.num_threads, 1)
+      << "XNNPACK RunCodegen option 'num_threads' must be >= 1.";
+  return parsed;
+}
+
+class XNNPACKJSONSerializer : public JSONSerializer {
+ public:
+  XNNPACKJSONSerializer(ffi::Map<Constant, ffi::String> constant_names,
+                        ffi::Map<Var, Expr> bindings)
+      : JSONSerializer(constant_names), bindings_(bindings) {}
+
+  using JSONSerializer::VisitExpr_;
+
+  NodeEntries VisitExpr_(const CallNode* call_node) final {
+    const auto* fn_var = call_node->op.as<VarNode>();
+    TVM_FFI_ICHECK(fn_var) << "XNNPACK codegen expects calls to composite functions.";
+
+    const auto fn = Downcast<Function>(bindings_[ffi::GetRef<Var>(fn_var)]);
+    TVM_FFI_ICHECK(fn.defined()) << "Expects the callee to be a function.";
+
+    auto composite_opt = fn->GetAttr<ffi::String>(attr::kComposite);
+    TVM_FFI_ICHECK(composite_opt.has_value()) << "Only composite functions are supported.";
+
+    std::string composite_name = composite_opt.value();
+    TVM_FFI_ICHECK(IsSupportedComposite(composite_name))
+        << "Unsupported XNNPACK composite pattern: " << composite_name;
+
+    if (IsDynamicRangeComposite(composite_name)) {
+      return VisitDynamicRangeComposite(call_node, fn, composite_name);
+    }
+    if (IsQuantizedComposite(composite_name)) {
+      return VisitQuantizedComposite(call_node, fn, composite_name);
+    }
+    if (composite_name == "xnnpack.fully_connected_bias_gelu" ||
+        composite_name == "xnnpack.fully_connected_bias_approx_gelu") {
+      return VisitFullyConnectedGeluComposite(call_node, fn, composite_name);
+    }
+
+    NodeEntries inputs;
+    for (const auto& arg : call_node->args) {
+      auto res = VisitExpr(arg);
+      inputs.insert(inputs.end(), res.begin(), res.end());
+    }
+    for (const auto& constant : CollectConstants(fn)) {
+      auto res = VisitExpr(constant);
+      inputs.insert(inputs.end(), res.begin(), res.end());
+    }
+
+    auto node = std::make_shared<JSONGraphNode>(composite_name, "kernel", inputs, 1);
+    SetCompositeAttrs(node, fn, composite_name, inputs.size());
+    return AddNode(node, ffi::GetRef<Expr>(call_node));
+  }
+
+ private:
+  static constexpr double kXNNPACKInfinity = 3.4028234663852886e38;

Review Comment:
   ![medium](https://www.gstatic.com/codereviewagent/medium-priority.svg)
   
   Using a hardcoded literal for the maximum float value is less robust than 
using standard library constants. Consider using 
`std::numeric_limits<float>::max()`.
   
   ```suggestion
     static constexpr double kXNNPACKInfinity = std::numeric_limits<float>::max();
   ```



##########
python/tvm/relax/backend/xnnpack.py:
##########
@@ -0,0 +1,2194 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Pattern table for the XNNPACK Relax backend."""
+
+from collections.abc import Callable
+
+import numpy as np
+import tvm
+from tvm.ir import IRModule
+from tvm import relax
+from tvm.relax.dpl.pattern import is_const, is_op, wildcard
+from tvm.relax.transform import FuseOpsByPattern, FusionPattern, PatternCheckContext
+
+from .pattern_registry import get_patterns_with_prefix, register_patterns
+from .utils import has_leaking_intermediate_variables
+
+_SUPPORTED_PRECISIONS = ("fp32", "fp16_hint", "fp16_force")
+_SUPPORTED_PARTITION_POLICIES = ("greedy", "cost", "debug_all_supported")
+_SUPPORTED_LAYOUT_POLICIES = ("auto", "NHWC", "preserve")
+_SUPPORTED_QUANTIZATIONS = ("none", "dynamic_range")
+_SUPPORTED_DYNAMIC_SHAPE_POLICIES = ("none", "batch_only")
+_XNN_EXTRA_BYTES = 16
+_DTYPE_BYTES = {"float16": 2, "float32": 4, "int8": 1, "uint8": 1, "int32": 4}
+_QPARAM_SCALE_RTOL = 1e-5
+_QPARAM_SCALE_ATOL = 1e-8
+
+
+def _get_static_shape(expr: relax.Expr) -> list[int] | None:
+    sinfo = expr.struct_info
+    if not isinstance(sinfo, relax.TensorStructInfo):
+        return None
+    if sinfo.shape is None or not hasattr(sinfo.shape, "values"):
+        return None
+
+    shape = []
+    for dim in sinfo.shape.values:
+        if not isinstance(dim, (tvm.tirx.expr.IntImm, int)):
+            return None
+        dim = int(dim)
+        if dim <= 0:
+            return None
+        shape.append(dim)
+    return shape
+
+
+def _shape_dims(expr: relax.Expr) -> list[object] | None:
+    sinfo = expr.struct_info
+    if not isinstance(sinfo, relax.TensorStructInfo):
+        return None
+    if sinfo.shape is None or not hasattr(sinfo.shape, "values"):
+        return None
+    return list(sinfo.shape.values)
+
+
+def _symbol_name(dim) -> str | None:
+    if isinstance(dim, (tvm.tirx.expr.IntImm, int)):
+        return None
+    if hasattr(dim, "name"):
+        return str(dim.name)
+    if hasattr(dim, "name_hint"):
+        return str(dim.name_hint)
+    text = str(dim)
+    return text if text else None
+
+
+def _get_batch_only_shape(expr: relax.Expr) -> tuple[str, list[int | None]] | None:
+    dims = _shape_dims(expr)
+    if dims is None or len(dims) == 0:
+        return None
+    result: list[int | None] = []
+    symbol: str | None = None
+    for index, dim in enumerate(dims):
+        if isinstance(dim, (tvm.tirx.expr.IntImm, int)):

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   The submodule `tvm.tirx` does not exist in the standard TVM library. This 
appears to be a typo for `tvm.tir`.
   
   ```suggestion
           if isinstance(dim, (tvm.tir.IntImm, int)):
   ```



##########
python/tvm/relax/backend/xnnpack.py:
##########
@@ -0,0 +1,2194 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Pattern table for the XNNPACK Relax backend."""
+
+from collections.abc import Callable
+
+import numpy as np
+import tvm
+from tvm.ir import IRModule
+from tvm import relax
+from tvm.relax.dpl.pattern import is_const, is_op, wildcard
+from tvm.relax.transform import FuseOpsByPattern, FusionPattern, 
PatternCheckContext
+
+from .pattern_registry import get_patterns_with_prefix, register_patterns
+from .utils import has_leaking_intermediate_variables
+
+_SUPPORTED_PRECISIONS = ("fp32", "fp16_hint", "fp16_force")
+_SUPPORTED_PARTITION_POLICIES = ("greedy", "cost", "debug_all_supported")
+_SUPPORTED_LAYOUT_POLICIES = ("auto", "NHWC", "preserve")
+_SUPPORTED_QUANTIZATIONS = ("none", "dynamic_range")
+_SUPPORTED_DYNAMIC_SHAPE_POLICIES = ("none", "batch_only")
+_XNN_EXTRA_BYTES = 16
+_DTYPE_BYTES = {"float16": 2, "float32": 4, "int8": 1, "uint8": 1, "int32": 4}
+_QPARAM_SCALE_RTOL = 1e-5
+_QPARAM_SCALE_ATOL = 1e-8
+
+
+def _get_static_shape(expr: relax.Expr) -> list[int] | None:
+    sinfo = expr.struct_info
+    if not isinstance(sinfo, relax.TensorStructInfo):
+        return None
+    if sinfo.shape is None or not hasattr(sinfo.shape, "values"):
+        return None
+
+    shape = []
+    for dim in sinfo.shape.values:
+        if not isinstance(dim, (tvm.tirx.expr.IntImm, int)):
+            return None
+        dim = int(dim)
+        if dim <= 0:
+            return None
+        shape.append(dim)
+    return shape
+
+
+def _shape_dims(expr: relax.Expr) -> list[object] | None:
+    sinfo = expr.struct_info
+    if not isinstance(sinfo, relax.TensorStructInfo):
+        return None
+    if sinfo.shape is None or not hasattr(sinfo.shape, "values"):
+        return None
+    return list(sinfo.shape.values)
+
+
+def _symbol_name(dim) -> str | None:
+    if isinstance(dim, (tvm.tirx.expr.IntImm, int)):
+        return None
+    if hasattr(dim, "name"):
+        return str(dim.name)
+    if hasattr(dim, "name_hint"):
+        return str(dim.name_hint)
+    text = str(dim)
+    return text if text else None
+
+
+def _get_batch_only_shape(expr: relax.Expr) -> tuple[str, list[int | None]] | 
None:
+    dims = _shape_dims(expr)
+    if dims is None or len(dims) == 0:
+        return None
+    result: list[int | None] = []
+    symbol: str | None = None
+    for index, dim in enumerate(dims):
+        if isinstance(dim, (tvm.tirx.expr.IntImm, int)):
+            value = int(dim)
+            if value <= 0:
+                return None
+            result.append(value)
+            continue
+        name = _symbol_name(dim)
+        if index != 0 or name is None:
+            return None
+        symbol = name
+        result.append(None)
+    if symbol is None:
+        return None
+    return symbol, result
+
+
+def _same_batch_only_shape(lhs: relax.Expr, rhs: relax.Expr) -> bool:
+    lhs_info = _get_batch_only_shape(lhs)
+    rhs_info = _get_batch_only_shape(rhs)
+    return lhs_info is not None and lhs_info == rhs_info
+
+
+def _batch_bounds_from_attrs(func: relax.Function) -> dict[str, tuple[int, 
int]]:
+    result: dict[str, tuple[int, int]] = {}
+    if not func.attrs:
+        return result
+    upper = func.attrs.get("tir_var_upper_bound")
+    lower = func.attrs.get("tir_var_lower_bound")
+    if upper is None:
+        return result
+    for key, value in upper.items():
+        upper_value = _as_bound_int(value)
+        lower_value = 1
+        if lower is not None and key in lower:
+            lower_value = _as_bound_int(lower[key])
+        result[str(key)] = (lower_value, upper_value)
+    return result
+
+
+def _as_bound_int(value) -> int:
+    if hasattr(value, "value"):
+        return int(value.value)
+    return int(value)
+
+
+def _normalize_dynamic_batch_bounds(
+    mod: IRModule, dynamic_batch_bounds
+) -> dict[str, tuple[int, int]]:
+    result: dict[str, tuple[int, int]] = {}
+    for func in mod.functions.values():
+        if isinstance(func, relax.Function):
+            result.update(_batch_bounds_from_attrs(func))
+    if dynamic_batch_bounds:
+        for key, value in dynamic_batch_bounds.items():
+            if isinstance(value, tuple):
+                lower, upper = value
+            elif isinstance(value, list):
+                if len(value) != 2:
+                    raise ValueError("XNNPACK dynamic_batch_bounds list values 
must have 2 items.")
+                lower, upper = value
+            else:
+                lower, upper = 1, value
+            result[str(key)] = (int(lower), int(upper))
+    for symbol, (lower, upper) in result.items():
+        if lower <= 0 or upper < lower:
+            raise ValueError(
+                f"Invalid XNNPACK dynamic batch bounds for {symbol!r}: "
+                f"expected 0 < lower <= upper, got ({lower}, {upper})."
+            )
+    return result
+
+
+def _is_float32_tensor(expr: relax.Expr) -> bool:
+    sinfo = expr.struct_info
+    return isinstance(sinfo, relax.TensorStructInfo) and sinfo.dtype == 
"float32"
+
+
+def _is_static_float32(expr: relax.Expr) -> bool:
+    return _is_float32_tensor(expr) and _get_static_shape(expr) is not None
+
+
+def _tensor_dtype(expr: relax.Expr) -> str | None:
+    sinfo = expr.struct_info
+    if isinstance(sinfo, relax.TensorStructInfo):
+        return str(sinfo.dtype)
+    return None
+
+
+def _num_elements(expr: relax.Expr) -> int | None:
+    shape = _get_static_shape(expr)
+    if shape is None:
+        return None
+    result = 1
+    for dim in shape:
+        result *= dim
+    return result
+
+
+def _tensor_nbytes(expr: relax.Expr) -> int:
+    num_elements = _num_elements(expr)
+    dtype = _tensor_dtype(expr)
+    if num_elements is None or dtype not in _DTYPE_BYTES:
+        return 0
+    return num_elements * _DTYPE_BYTES[dtype]
+
+
+def _const_numpy(expr: relax.Expr) -> np.ndarray | None:
+    if not isinstance(expr, relax.Constant):
+        return None
+    return expr.data.numpy()
+
+
+def _const_scalar_float(expr: relax.Expr) -> float | None:
+    arr = _const_numpy(expr)
+    if arr is None or arr.size != 1:
+        return None
+    value = float(arr.reshape(-1)[0])
+    if not np.isfinite(value):
+        return None
+    return value
+
+
+def _const_int_array(expr: relax.Expr) -> np.ndarray | None:
+    arr = _const_numpy(expr)
+    if arr is None:
+        return None
+    if not np.issubdtype(arr.dtype, np.integer):
+        return None
+    return arr.astype("int64")
+
+
+def _const_float_array(expr: relax.Expr) -> np.ndarray | None:
+    arr = _const_numpy(expr)
+    if arr is None:
+        return None
+    if not np.issubdtype(arr.dtype, np.floating):
+        return None
+    arr = arr.astype("float64")
+    if not np.all(np.isfinite(arr)):
+        return None
+    return arr
+
+
+def _const_scalar_int(expr: relax.Expr) -> int | None:
+    arr = _const_int_array(expr)
+    if arr is None or arr.size != 1:
+        return None
+    return int(arr.reshape(-1)[0])
+
+
+def _same_static_shape(lhs: relax.Expr, rhs: relax.Expr) -> bool:
+    lhs_shape = _get_static_shape(lhs)
+    rhs_shape = _get_static_shape(rhs)
+    return lhs_shape is not None and lhs_shape == rhs_shape
+
+
+def _is_external_input(expr: relax.Expr) -> bool:
+    return not isinstance(expr, relax.Constant)
+
+
+def _as_float_prim_value(expr: relax.Expr) -> float | None:
+    if not isinstance(expr, relax.PrimValue):
+        return None
+    value = expr.value
+    if isinstance(value, tvm.tirx.expr.FloatImm):
+        return float(value.value)
+    if isinstance(value, tvm.tirx.expr.IntImm):

Review Comment:
   ![high](https://www.gstatic.com/codereviewagent/high-priority.svg)
   
   The submodule `tvm.tirx` does not exist in the standard TVM library. This 
appears to be a typo for `tvm.tir`.
   
   ```suggestion
       if isinstance(value, tvm.tir.IntImm):
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to