This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 246e290267 [Relax][TensorRT] Update BYOC operator converters from 
Relay to Relax (#19810)
246e290267 is described below

commit 246e290267482c4c3625f75e06ab59df7cc2afe7
Author: Shushi Hong <[email protected]>
AuthorDate: Wed Jun 17 08:03:20 2026 -0400

    [Relax][TensorRT] Update BYOC operator converters from Relay to Relax 
(#19810)
    
    This PR is a follow-up to #19789. The current TensorRT BYOC converters
    were ported from Relay and still read attribute names/shapes that no
    longer match the Relax ops, so most ops crashed ("Key: <name> is not
    found") or produced wrong results when offloaded.
    
    This PR changes:
    - Converters (tensorrt_ops.cc): port reduce, matmul, expand_dims,
    layer_norm, clip, reshape, strided_slice, split and layout_transform to
    read Relax's attributes/arguments. Notable interface changes: clip min/max
    are PrimValue arguments (not a_min/a_max attrs), reshape's shape is a
    Shape argument, matmul has no transpose flags, split is multi-output
    with no "mode", and layout_transform is an IndexMap rather than
    src/dst_layout strings. Unsupported cases (non-static reshape,
    non-permutation layout_transform) now raise a clear error instead of
    crashing.
    - Codegen (codegen.cc): serialize an op's non-tensor arguments
    (PrimValue / ShapeExpr / tuple) as "arg_"-prefixed node attributes,
    materialize a reduce op's all-axes default, and translate a
    pure-permutation layout_transform IndexMap into a transpose order.
    - Runtime: disable the TF32 builder flag so offloaded FP32 subgraphs
    match TVM's FP32 reference, and use a process-lifetime TensorRT logger
    (a per-runtime logger was left dangling once its runtime was destroyed,
    corrupting the heap during TensorRT teardown).
    
    All tests pass locally.
---
 src/relax/backend/contrib/tensorrt/codegen.cc      | 104 +++++++-
 .../extra/contrib/tensorrt/tensorrt_builder.cc     |   2 +
 .../extra/contrib/tensorrt/tensorrt_logger.h       |  10 +
 src/runtime/extra/contrib/tensorrt/tensorrt_ops.cc | 163 ++++++------
 .../extra/contrib/tensorrt/tensorrt_runtime.cc     |   7 +-
 tests/python/relax/test_codegen_tensorrt.py        | 284 ++++++++++++++++++++-
 6 files changed, 483 insertions(+), 87 deletions(-)

diff --git a/src/relax/backend/contrib/tensorrt/codegen.cc 
b/src/relax/backend/contrib/tensorrt/codegen.cc
index 07ba1c81e6..78ed6fbc4e 100644
--- a/src/relax/backend/contrib/tensorrt/codegen.cc
+++ b/src/relax/backend/contrib/tensorrt/codegen.cc
@@ -24,10 +24,15 @@
 #include <tvm/ffi/cast.h>
 #include <tvm/ffi/reflection/registry.h>
 #include <tvm/ir/module.h>
+#include <tvm/ir/op.h>
 #include <tvm/ir/transform.h>
+#include <tvm/relax/attrs/manipulate.h>
 #include <tvm/relax/attrs/nn.h>
+#include <tvm/relax/attrs/statistical.h>
+#include <tvm/relax/expr.h>
 #include <tvm/relax/type.h>
 #include <tvm/runtime/logging.h>
+#include <tvm/tirx/index_map.h>
 
 #include <memory>
 #include <string>
@@ -112,6 +117,98 @@ class CollectFromCompositeFunctionBody : public 
ExprVisitor {
     extractor.Extract(const_cast<ffi::Object*>(attr_obj));
   }
 
+  // Serialize an op's non-tensor arguments (scalars/shapes) as "arg_<name>" 
attributes; the "arg_"
+  // prefix avoids JSONGraphNode's reserved "shape"/"dtype".
+  void SetArgumentAttributes(const CallNode* call_node) {
+    const auto* op_node = call_node->op.as<OpNode>();
+    if (op_node == nullptr) return;
+    const ffi::Array<ArgumentInfo>& arg_infos = op_node->arguments;
+    for (size_t i = 0; i < call_node->args.size() && i < arg_infos.size(); 
++i) {
+      const Expr& arg = call_node->args[i];
+      const std::string key = "arg_" + std::string(arg_infos[i]->name);
+      if (const auto* prim_value = arg.as<PrimValueNode>()) {
+        if (const auto* imm = prim_value->value.as<IntImmNode>()) {
+          node_->SetAttr(key, static_cast<int64_t>(imm->value));
+        } else if (const auto* fimm = prim_value->value.as<FloatImmNode>()) {
+          node_->SetAttr(key, static_cast<double>(fimm->value));
+        }
+      } else if (const auto* shape_expr = arg.as<ShapeExprNode>()) {
+        SetIntArrayAttr(key, shape_expr->values);
+      }
+    }
+  }
+
+  // Relax reduce axis is optional; materialize the all-axes default (it 
otherwise serializes as
+  // "").
+  void MaybeFillReduceAxes(const CallNode* call_node) {
+    const auto* attrs = call_node->attrs.as<StatisticalAttrs>();
+    if (attrs == nullptr || attrs->axis.has_value()) return;
+    const auto* tensor_sinfo = 
GetStructInfo(call_node->args[0]).as<TensorStructInfoNode>();
+    if (tensor_sinfo == nullptr || !tensor_sinfo->shape.defined()) return;
+    const auto* shape = tensor_sinfo->shape.value().as<ShapeExprNode>();
+    if (shape == nullptr) return;
+    ffi::Array<int64_t> all_axes;
+    for (size_t i = 0; i < shape->values.size(); ++i) 
all_axes.push_back(static_cast<int64_t>(i));
+    node_->SetAttr("axis", std::move(all_axes));
+  }
+
+  // strided_slice's axes/begin/end/strides are tuple args the op does not 
name; serialize by
+  // position.
+  void SetStridedSliceArguments(const CallNode* call_node) {
+    const auto* op_node = call_node->op.as<OpNode>();
+    if (op_node == nullptr || op_node->name != "relax.strided_slice") return;
+    static constexpr const char* kNames[] = {"arg_axes", "arg_begin", 
"arg_end", "arg_strides"};
+    for (size_t i = 1; i < call_node->args.size() && i <= 4; ++i) {
+      const auto* tuple = call_node->args[i].as<TupleNode>();
+      if (tuple == nullptr) continue;
+      ffi::Array<PrimExpr> values;
+      for (const Expr& field : tuple->fields) {
+        if (const auto* prim_value = field.as<PrimValueNode>()) 
values.push_back(prim_value->value);
+      }
+      if (values.size() == tuple->fields.size()) SetIntArrayAttr(kNames[i - 
1], values);
+    }
+  }
+
+  // Serialize integer PrimExprs as an int64 array attribute, only if every entry is an IntImm.
+  void SetIntArrayAttr(const std::string& key, const ffi::Array<PrimExpr>& 
exprs) {
+    ffi::Array<int64_t> values;
+    for (const PrimExpr& expr : exprs) {
+      const auto* imm = expr.as<IntImmNode>();
+      if (imm == nullptr) return;
+      values.push_back(imm->value);
+    }
+    node_->SetAttr(key, std::move(values));
+  }
+
+  // layout_transform's IndexMap is not generically serializable; emit a pure 
permutation as
+  // "arg_axes". Returns true for layout_transform (so generic extraction is 
skipped for it).
+  bool TrySetLayoutTransformAttributes(const CallNode* call_node) {
+    const auto* op_node = call_node->op.as<OpNode>();
+    if (op_node == nullptr || op_node->name != "relax.layout_transform") 
return false;
+    const auto* attrs = call_node->attrs.as<LayoutTransformAttrs>();
+    if (attrs == nullptr) return true;
+    auto index_map = attrs->index_map;
+    const auto& initial = index_map->initial_indices;
+    const auto& final_indices = index_map->final_indices;
+    if (initial.size() != final_indices.size()) return true;
+    ffi::Array<int64_t> permutation;
+    for (const PrimExpr& expr : final_indices) {
+      const auto* var = expr.as<tirx::VarNode>();
+      if (var == nullptr) return true;
+      int64_t pos = -1;
+      for (size_t j = 0; j < initial.size(); ++j) {
+        if (initial[j].get() == var) {
+          pos = static_cast<int64_t>(j);
+          break;
+        }
+      }
+      if (pos < 0) return true;
+      permutation.push_back(pos);
+    }
+    node_->SetAttr("arg_axes", std::move(permutation));
+    return true;
+  }
+
   TensorRTJSONSerializer* serializer_;
   /*! \brief Accumulated translated arguments. */
   std::vector<JSONGraphNodeEntry> args_;
@@ -206,7 +303,12 @@ void CollectFromCompositeFunctionBody::VisitExpr_(const 
ConstantNode* constant_n
 }
 
 void CollectFromCompositeFunctionBody::VisitExpr_(const CallNode* call_node) {
-  SetGenericAttributes(call_node);
+  if (!TrySetLayoutTransformAttributes(call_node)) {
+    SetGenericAttributes(call_node);
+    SetArgumentAttributes(call_node);
+    SetStridedSliceArguments(call_node);
+    MaybeFillReduceAxes(call_node);
+  }
   ExprVisitor::VisitExpr_(call_node);
 }
 
diff --git a/src/runtime/extra/contrib/tensorrt/tensorrt_builder.cc 
b/src/runtime/extra/contrib/tensorrt/tensorrt_builder.cc
index 281d64cfbc..10b3bdf447 100644
--- a/src/runtime/extra/contrib/tensorrt/tensorrt_builder.cc
+++ b/src/runtime/extra/contrib/tensorrt/tensorrt_builder.cc
@@ -157,6 +157,8 @@ TensorRTEngineAndContext TensorRTBuilder::BuildEngine() {
   config_ = builder_->createBuilderConfig();
   // TensorRT 10 replaced IBuilderConfig::setMaxWorkspaceSize with a tunable 
memory pool.
   config_->setMemoryPoolLimit(nvinfer1::MemoryPoolType::kWORKSPACE, 
max_workspace_size_);
+  // Disable TF32 (on by default on Ampere+) so FP32 layers match TVM's 
full-precision reference.
+  config_->clearFlag(nvinfer1::BuilderFlag::kTF32);
   if (use_fp16_) {
     config_->setFlag(nvinfer1::BuilderFlag::kFP16);
   }
diff --git a/src/runtime/extra/contrib/tensorrt/tensorrt_logger.h 
b/src/runtime/extra/contrib/tensorrt/tensorrt_logger.h
index 5406f4c57d..0adf0d3cec 100644
--- a/src/runtime/extra/contrib/tensorrt/tensorrt_logger.h
+++ b/src/runtime/extra/contrib/tensorrt/tensorrt_logger.h
@@ -71,6 +71,16 @@ class TensorRTLogger : public nvinfer1::ILogger {
   Severity reportable_severity{Severity::kWARNING};
 };
 
+/*!
+ * \brief Process-wide TensorRT logger. TensorRT keeps a global pointer to the 
first logger it is
+ * given, so a per-runtime logger would dangle once its runtime is destroyed; 
this one is
+ * intentionally leaked to outlive all TensorRT state.
+ */
+inline TensorRTLogger& GetTensorRTLogger() {
+  static TensorRTLogger* logger = new TensorRTLogger();
+  return *logger;
+}
+
 }  // namespace contrib
 }  // namespace runtime
 }  // namespace tvm
diff --git a/src/runtime/extra/contrib/tensorrt/tensorrt_ops.cc 
b/src/runtime/extra/contrib/tensorrt/tensorrt_ops.cc
index 00ca3cea96..e7783ae80d 100644
--- a/src/runtime/extra/contrib/tensorrt/tensorrt_ops.cc
+++ b/src/runtime/extra/contrib/tensorrt/tensorrt_ops.cc
@@ -180,8 +180,9 @@ class ActivationOpConverter : public TensorRTOpConverter {
         params->network->addActivation(*params->inputs.at(0).tensor, 
it->second);
 #if TRT_VERSION_GE(5, 1, 5)
     if (op_name == "clip") {
-      float a_min = static_cast<float>(params->node.GetAttr<double>("a_min"));
-      float a_max = static_cast<float>(params->node.GetAttr<double>("a_max"));
+      // Relax clip min/max are PrimValue args (serialized as 
arg_min/arg_max), not Relay attrs.
+      float a_min = 
static_cast<float>(params->node.GetAttr<double>("arg_min"));
+      float a_max = 
static_cast<float>(params->node.GetAttr<double>("arg_max"));
       act_layer->setAlpha(a_min);
       act_layer->setBeta(a_max);
     } else if (op_name == "nn.leaky_relu") {
@@ -545,19 +546,28 @@ class LayerNormOpConverter : public TensorRTOpConverter {
     const bool scale = 
static_cast<int>(params->node.GetAttr<int64_t>("scale"));
     const bool center = 
static_cast<int>(params->node.GetAttr<int64_t>("center"));
     const int input_rank = input->getDimensions().nbDims;
-    const int original_axis = 
static_cast<int>(params->node.GetAttr<int64_t>("axis"));
-    const int axis = ConvertAxis(params, original_axis, input_rank);
-
+    auto input_dims = TrtDimsToVector(input->getDimensions());
+    // Relax layer_norm normalizes over an `axes` list (Relay used a single 
`axis`).
+    auto axes_attr = params->node.GetAttr<ffi::Array<int64_t>>("axes");
+    uint32_t reduce_axes = 0;
     std::vector<int> weight_shape(input_rank, 1);
-    weight_shape[axis] = gamma_input.count;
+    int64_t normalized_count = 1;
+    for (size_t i = 0; i < axes_attr.size(); ++i) {
+      const int axis = ConvertAxis(params, static_cast<int>(axes_attr[i]), 
input_rank);
+      reduce_axes |= 1 << axis;
+      weight_shape[axis] = input_dims[axis];
+      normalized_count *= input_dims[axis];
+    }
+    TVM_FFI_ICHECK_EQ(normalized_count, gamma_input.count)
+        << "TensorRT layer_norm expects gamma/beta to cover exactly the 
normalized axes";
     auto gamma =
         params->network->addConstant(VectorToTrtDims(weight_shape), 
gamma_input)->getOutput(0);
     auto beta =
         params->network->addConstant(VectorToTrtDims(weight_shape), 
beta_input)->getOutput(0);
 
     // Compute mean
-    auto mean_layer = params->network->addReduce(*input, 
nvinfer1::ReduceOperation::kAVG, 1 << axis,
-                                                 /*keepdims=*/true);
+    auto mean_layer = params->network->addReduce(*input, 
nvinfer1::ReduceOperation::kAVG,
+                                                 reduce_axes, 
/*keepdims=*/true);
     TVM_FFI_ICHECK(mean_layer != nullptr);
     auto mean = mean_layer->getOutput(0);
     // Compute variance
@@ -568,8 +578,9 @@ class LayerNormOpConverter : public TensorRTOpConverter {
         params->network->addElementWise(*diff_layer->getOutput(0), 
*diff_layer->getOutput(0),
                                         nvinfer1::ElementWiseOperation::kPROD);
     TVM_FFI_ICHECK(square_layer != nullptr);
-    auto var_layer = params->network->addReduce(
-        *square_layer->getOutput(0), nvinfer1::ReduceOperation::kAVG, 1 << 
axis, /*keepdims=*/true);
+    auto var_layer = params->network->addReduce(*square_layer->getOutput(0),
+                                                
nvinfer1::ReduceOperation::kAVG, reduce_axes,
+                                                /*keepdims=*/true);
     TVM_FFI_ICHECK(var_layer != nullptr);
     auto var = var_layer->getOutput(0);
     // sqrt(var + epsilon)
@@ -775,10 +786,15 @@ class ExpandDimsOpConverter : public TensorRTOpConverter {
   void Convert(TensorRTOpConverterParams* params) const {
     auto input_tensor = params->inputs.at(0).tensor;
     auto input_dims = TrtDimsToVector(input_tensor->getDimensions());
-    const int original_axis = 
static_cast<int>(params->node.GetAttr<int64_t>("axis"));
-    const int num_newaxis = 
static_cast<int>(params->node.GetAttr<int64_t>("num_newaxis"));
-    const int axis = ConvertAxis(params, original_axis, input_dims.size() + 1);
-    for (int i = 0; i < num_newaxis; ++i) {
+    // Relax expand_dims carries an `axis` list (not Relay's `axis` + 
`num_newaxis`).
+    auto axes = params->node.GetAttr<ffi::Array<int64_t>>("axis");
+    const int output_ndim = static_cast<int>(input_dims.size() + axes.size());
+    std::vector<int> new_axes;
+    for (size_t i = 0; i < axes.size(); ++i) {
+      new_axes.push_back(ConvertAxis(params, static_cast<int>(axes[i]), 
output_ndim));
+    }
+    std::sort(new_axes.begin(), new_axes.end());
+    for (int axis : new_axes) {
       input_dims.insert(input_dims.begin() + axis, 1);
     }
     params->outputs.push_back(Reshape(params, params->inputs.at(0).tensor, 
input_dims));
@@ -875,39 +891,20 @@ class SplitOpConverter : public TensorRTOpConverter {
     auto input_dims = TrtDimsToVector(input->getDimensions());
     const int original_axis = 
static_cast<int>(params->node.GetAttr<int64_t>("axis"));
     const int axis = ConvertAxis(params, original_axis, input_dims.size());
-    auto indices_or_sections = 
params->node.GetAttr<ffi::Array<int64_t>>("indices_or_sections");
-    auto mode = std::string(params->node.GetAttr<ffi::String>("mode"));
-
-    std::vector<int> split_starts;
-    std::vector<int> split_sizes;
-    if (mode == "sections") {
-      int sections = static_cast<int>(indices_or_sections[0]);
-      int size = input_dims[axis] / sections;
-      for (int i = 0; i < sections; i++) {
-        split_starts.push_back(i * size);
-        split_sizes.push_back(size);
-      }
-    } else {
-      int last_index = 0;
-      for (size_t i = 0; i < indices_or_sections.size(); ++i) {
-        int index = static_cast<int>(indices_or_sections[i]);
-        split_starts.push_back(last_index);
-        split_sizes.push_back(index - last_index);
-        last_index = index;
-      }
-      split_starts.push_back(last_index);
-      split_sizes.push_back(input_dims[axis] - last_index);
-    }
+    // No Relay "mode": derive each output's extent along `axis` from the 
per-output shapes.
+    auto output_shapes = 
params->node.GetAttr<ffi::Array<ffi::Array<int64_t>>>("shape");
 
     std::vector<int> start(input_dims.size(), 0);
     std::vector<int> size(input_dims.begin(), input_dims.end());
     std::vector<int> strides(input_dims.size(), 1);
-    for (size_t i = 0; i < split_sizes.size(); ++i) {
-      start[axis] = split_starts[i];
-      size[axis] = split_sizes[i];
+    int offset = 0;
+    for (size_t i = 0; i < output_shapes.size(); ++i) {
+      start[axis] = offset;
+      size[axis] = static_cast<int>(output_shapes[i][axis]);
       auto slice_layer = params->network->addSlice(*input, 
VectorToTrtDims(start),
                                                    VectorToTrtDims(size), 
VectorToTrtDims(strides));
       params->outputs.push_back(slice_layer->getOutput(0));
+      offset += size[axis];
     }
   }
 };
@@ -1106,17 +1103,14 @@ class LayoutTransformOpConverter : public 
TensorRTOpConverter {
 
   void Convert(TensorRTOpConverterParams* params) const {
     auto input = params->inputs.at(0).tensor;
-    auto src = params->node.GetAttr<ffi::String>("src_layout");
-    auto dst = params->node.GetAttr<ffi::String>("dst_layout");
+    // The codegen emits a pure-permutation IndexMap as "arg_axes"; a missing 
key => unsupported
+    // map.
+    TVM_FFI_ICHECK(params->node.HasAttr("arg_axes"))
+        << "TensorRT layout_transform supports only pure-permutation index 
maps";
+    auto axes = params->node.GetAttr<ffi::Array<int64_t>>("arg_axes");
     std::vector<int> order;
-    if (src == "NCHW" && dst == "NHWC") {
-      order = {0, 2, 3, 1};
-    } else if (src == "NHWC" && dst == "NCHW") {
-      order = {0, 3, 1, 2};
-    } else if (src == "NDHWC" && dst == "NCDHW") {
-      order = {0, 4, 1, 2, 3};
-    } else if (src == "NCDHW" && dst == "NDHWC") {
-      order = {0, 2, 3, 4, 1};
+    for (size_t i = 0; i < axes.size(); ++i) {
+      order.push_back(static_cast<int>(axes[i]));
     }
     params->outputs.push_back(Transpose(params, input, order));
   }
@@ -1131,7 +1125,10 @@ class ReshapeOpConverter : public TensorRTOpConverter {
   void Convert(TensorRTOpConverterParams* params) const {
     auto input = params->inputs.at(0).tensor;
     auto input_dims = TrtDimsToVector(input->getDimensions());
-    auto newshape = params->node.GetAttr<ffi::Array<int64_t>>("newshape");
+    // Relax reshape's shape is a Shape arg (serialized as arg_shape); a 
missing key => non-static.
+    TVM_FFI_ICHECK(params->node.HasAttr("arg_shape"))
+        << "TensorRT reshape supports only a fully static target shape";
+    auto newshape = params->node.GetAttr<ffi::Array<int64_t>>("arg_shape");
     std::vector<int> new_shape;
     int start_index = TRT_HAS_IMPLICIT_BATCH(params) ? 1 : 0;
     if (static_cast<int>(newshape[0]) == -1) start_index = 0;
@@ -1179,17 +1176,14 @@ class ReduceOpConverter : public TensorRTOpConverter {
     TVM_FFI_ICHECK(it != op_map.end()) << "Unsupported reduce type " << 
op_name;
 
     auto input = params->inputs.at(0).tensor;
-    
TVM_FFI_ICHECK_EQ(static_cast<int>(params->node.GetAttr<int64_t>("exclude")), 
false);
+    // No Relay "exclude"; axis is materialized to a concrete list by the 
codegen (None -> all
+    // axes).
     bool keepdims = 
static_cast<int>(params->node.GetAttr<int64_t>("keepdims"));
+    const int input_rank = input->getDimensions().nbDims;
     auto axes = params->node.GetAttr<ffi::Array<int64_t>>("axis");
-    // TODO(trevmorr): Support reduce to scalar.
-    TVM_FFI_ICHECK_GT(axes.size(), 0);
     uint32_t reduce_axes = 0;
-
     for (size_t i = 0; i < axes.size(); ++i) {
-      const int axis =
-          ConvertAxis(params, static_cast<int>(axes[i]), 
input->getDimensions().nbDims);
-      reduce_axes |= 1 << axis;
+      reduce_axes |= 1 << ConvertAxis(params, static_cast<int>(axes[i]), 
input_rank);
     }
     auto reduce_layer = params->network->addReduce(*input, it->second, 
reduce_axes, keepdims);
     params->outputs.push_back(reduce_layer->getOutput(0));
@@ -1206,20 +1200,35 @@ class StridedSliceOpConverter : public 
TensorRTOpConverter {
   void Convert(TensorRTOpConverterParams* params) const {
     auto input = params->inputs.at(0).tensor;
     auto input_dims = TrtDimsToVector(input->getDimensions());
-    auto attr_start = params->node.GetAttr<ffi::Array<int64_t>>("start");
-    auto attr_size = params->node.GetAttr<ffi::Array<int64_t>>("size");
-    auto attr_strides = params->node.GetAttr<ffi::Array<int64_t>>("strides");
-    std::vector<int> start, size, strides;
-    std::transform(attr_start.begin(), attr_start.end(), 
std::back_inserter(start),
-                   [](int64_t v) { return static_cast<int>(v); });
-    std::transform(attr_size.begin(), attr_size.end(), 
std::back_inserter(size),
-                   [](int64_t v) { return static_cast<int>(v); });
-    std::transform(attr_strides.begin(), attr_strides.end(), 
std::back_inserter(strides),
-                   [](int64_t v) { return static_cast<int>(v); });
-    if (TRT_HAS_IMPLICIT_BATCH(params)) {
-      start.erase(start.begin());
-      size.erase(size.begin());
-      strides.erase(strides.begin());
+    const int rank = static_cast<int>(input_dims.size());
+    // axes/begin/end/strides are tuple args (serialized by the codegen); only 
listed axes are
+    // sliced.
+    auto axes = params->node.GetAttr<ffi::Array<int64_t>>("arg_axes");
+    auto begin = params->node.GetAttr<ffi::Array<int64_t>>("arg_begin");
+    auto end = params->node.GetAttr<ffi::Array<int64_t>>("arg_end");
+    std::vector<int64_t> stride_values;
+    if (params->node.HasAttr("arg_strides")) {
+      auto attr_strides = 
params->node.GetAttr<ffi::Array<int64_t>>("arg_strides");
+      stride_values.assign(attr_strides.begin(), attr_strides.end());
+    }
+
+    std::vector<int> start(rank, 0);
+    std::vector<int> size(input_dims.begin(), input_dims.end());
+    std::vector<int> strides(rank, 1);
+    for (size_t i = 0; i < axes.size(); ++i) {
+      const int axis = ConvertAxis(params, static_cast<int>(axes[i]), rank);
+      const int dim = input_dims[axis];
+      const int stride = stride_values.empty() ? 1 : 
static_cast<int>(stride_values[i]);
+      TVM_FFI_ICHECK_GT(stride, 0) << "TensorRT strided_slice supports only 
positive strides";
+      int b = static_cast<int>(begin[i]);
+      int e = static_cast<int>(end[i]);
+      if (b < 0) b += dim;
+      if (e < 0) e += dim;
+      b = std::max(0, std::min(b, dim));
+      e = std::max(0, std::min(e, dim));
+      start[axis] = b;
+      strides[axis] = stride;
+      size[axis] = e > b ? (e - b + stride - 1) / stride : 0;
     }
     auto slice_layer = params->network->addSlice(*input, 
VectorToTrtDims(start),
                                                  VectorToTrtDims(size), 
VectorToTrtDims(strides));
@@ -1267,14 +1276,10 @@ class BatchMatmulOpConverter : public 
TensorRTOpConverter {
   ~BatchMatmulOpConverter() = default;
 
   void Convert(TensorRTOpConverterParams* params) const {
-    auto transa = 
static_cast<int>(params->node.GetAttr<int64_t>("transpose_a"));
-    auto transb = 
static_cast<int>(params->node.GetAttr<int64_t>("transpose_b"));
-    nvinfer1::MatrixOperation trt_transa =
-        transa ? nvinfer1::MatrixOperation::kTRANSPOSE : 
nvinfer1::MatrixOperation::kNONE;
-    nvinfer1::MatrixOperation trt_transb =
-        transb ? nvinfer1::MatrixOperation::kTRANSPOSE : 
nvinfer1::MatrixOperation::kNONE;
+    // Relax matmul has no transpose flags; multiply both operands as-is.
     nvinfer1::IMatrixMultiplyLayer* matmul_layer = 
params->network->addMatrixMultiply(
-        *params->inputs.at(0).tensor, trt_transa, 
*params->inputs.at(1).tensor, trt_transb);
+        *params->inputs.at(0).tensor, nvinfer1::MatrixOperation::kNONE,
+        *params->inputs.at(1).tensor, nvinfer1::MatrixOperation::kNONE);
     TVM_FFI_ICHECK(matmul_layer != nullptr);
     params->outputs.push_back(matmul_layer->getOutput(0));
   }
diff --git a/src/runtime/extra/contrib/tensorrt/tensorrt_runtime.cc 
b/src/runtime/extra/contrib/tensorrt/tensorrt_runtime.cc
index 932c52b394..b421e096d7 100644
--- a/src/runtime/extra/contrib/tensorrt/tensorrt_runtime.cc
+++ b/src/runtime/extra/contrib/tensorrt/tensorrt_runtime.cc
@@ -337,7 +337,7 @@ class TensorRTRuntime : public JSONRuntimeBase {
 
   void BuildEngineFromJson(int batch_size) {
     const bool use_fp16 = support::GetEnv("TVM_TENSORRT_USE_FP16", false) || 
use_fp16_;
-    TensorRTBuilder builder(&logger_, data_entry_, max_workspace_size_, 
use_fp16,
+    TensorRTBuilder builder(&GetTensorRTLogger(), data_entry_, 
max_workspace_size_, use_fp16,
                             calibrator_.get());
     for (size_t i = 0; i < input_nodes_.size(); ++i) {
       auto nid = input_nodes_[i];
@@ -386,7 +386,7 @@ class TensorRTRuntime : public JSONRuntimeBase {
     LoadBinaryFromFile(path, &serialized_engine);
     // Deserialize engine. TensorRT 10 dropped the trailing IPluginFactory* 
argument and the runtime
     // must outlive the engine, so it is owned by the cached 
TensorRTEngineAndContext.
-    nvinfer1::IRuntime* runtime = nvinfer1::createInferRuntime(logger_);
+    nvinfer1::IRuntime* runtime = 
nvinfer1::createInferRuntime(GetTensorRTLogger());
     TensorRTEngineAndContext engine_and_context;
     engine_and_context.runtime = runtime;
     engine_and_context.engine =
@@ -522,9 +522,6 @@ class TensorRTRuntime : public JSONRuntimeBase {
    * used by all engines. */
   std::unordered_map<std::string, Tensor> device_buffers_;
 
-  /*! \brief TensorRT logger. */
-  TensorRTLogger logger_;
-
 #else   // TVM_GRAPH_EXECUTOR_TENSORRT
   void Run() override {
     TVM_FFI_THROW(InternalError) << "TensorRT runtime is not enabled. "
diff --git a/tests/python/relax/test_codegen_tensorrt.py 
b/tests/python/relax/test_codegen_tensorrt.py
index 5f90f826dd..14d3394a48 100644
--- a/tests/python/relax/test_codegen_tensorrt.py
+++ b/tests/python/relax/test_codegen_tensorrt.py
@@ -122,14 +122,20 @@ def _offload_and_compare(mod, params_np, patterns, 
data_np, rtol=1e-2, atol=1e-2
     otherwise collapse repeated ops.
     """
     ref = build_and_run(mod, [data_np, *params_np.values()], "llvm", 
legalize=True)
-    offloaded = tvm.transform.Sequential(
+    partitioned = tvm.transform.Sequential(
         [
             relax.transform.BindParams("main", params_np),
             relax.transform.FuseOpsByPattern(patterns),
             relax.transform.MergeCompositeFunctions(),
-            relax.transform.RunCodegen(),
         ]
     )(mod)
+    # Guard against a silent false pass: if no pattern matched, nothing is 
offloaded and the
+    # comparison would trivially succeed via the TVM fallback without 
exercising the converter.
+    assert any(
+        isinstance(fn, relax.Function) and fn.attrs is not None and "Codegen" 
in fn.attrs
+        for fn in partitioned.functions.values()
+    ), "expected the op under test to be offloaded to TensorRT, but nothing 
was partitioned"
+    offloaded = relax.transform.RunCodegen()(partitioned)
     out = build_and_run(offloaded, [data_np], "cuda")
     tvm.testing.assert_allclose(out, ref, rtol=rtol, atol=atol)
 
@@ -316,5 +322,279 @@ def test_tensorrt_int8_calibration(monkeypatch):
     tvm.testing.assert_allclose(out, ref, rtol=0.2, atol=0.1 * 
float(np.abs(ref).max()))
 
 
+def test_tensorrt_matmul():
+    # Regression test: Relax matmul has no transpose_a/transpose_b attrs (Relay's batch_matmul did).
+    @tvm.script.ir_module
+    class Matmul:
+        @R.function
+        def main(data: R.Tensor((4, 8), "float32"), weight: R.Tensor((8, 16), "float32")):
+            with R.dataflow():
+                out = relax.op.matmul(data, weight)
+                R.output(out)
+            return out
+
+    data = np.random.randn(4, 8).astype("float32")
+    weight = np.random.randn(8, 16).astype("float32")
+    patterns = [("tensorrt.nn.batch_matmul", is_op("relax.matmul")(wildcard(), wildcard()))]
+    _offload_and_compare(Matmul, {"weight": weight}, patterns, data)
+
+
+def test_tensorrt_sum():
+    # Regression test: Relax reduce ops (StatisticalAttrs) have no "exclude" attr.
+    @tvm.script.ir_module
+    class Sum:
+        @R.function
+        def main(data: R.Tensor((2, 3, 4), "float32")):
+            with R.dataflow():
+                out = relax.op.sum(data, axis=[1], keepdims=True)
+                R.output(out)
+            return out
+
+    data = np.random.randn(2, 3, 4).astype("float32")
+    patterns = [("tensorrt.sum", is_op("relax.sum")(wildcard()))]
+    _offload_and_compare(Sum, {}, patterns, data)
+
+
+def test_tensorrt_expand_dims():
+    # Regression test: Relax expand_dims carries an `axis` list, not Relay's axis + num_newaxis.
+    @tvm.script.ir_module
+    class ExpandDims:
+        @R.function
+        def main(data: R.Tensor((2, 4), "float32")):
+            with R.dataflow():
+                out = relax.op.expand_dims(data, axis=[1, 3])
+                R.output(out)
+            return out
+
+    data = np.random.randn(2, 4).astype("float32")
+    patterns = [("tensorrt.expand_dims", is_op("relax.expand_dims")(wildcard()))]
+    _offload_and_compare(ExpandDims, {}, patterns, data)
+
+
+def test_tensorrt_layer_norm():
+    # Regression test: Relax layer_norm normalizes over an `axes` list (Relay used `axis`).
+    @tvm.script.ir_module
+    class LayerNorm:
+        @R.function
+        def main(
+            data: R.Tensor((2, 4, 8), "float32"),
+            gamma: R.Tensor((8,), "float32"),
+            beta: R.Tensor((8,), "float32"),
+        ):
+            with R.dataflow():
+                out = relax.op.nn.layer_norm(data, gamma, beta, axes=[-1])
+                R.output(out)
+            return out
+
+    data = np.random.randn(2, 4, 8).astype("float32")
+    gamma = np.random.randn(8).astype("float32")
+    beta = np.random.randn(8).astype("float32")
+    patterns = [
+        ("tensorrt.nn.layer_norm", is_op("relax.nn.layer_norm")(wildcard(), wildcard(), wildcard()))
+    ]
+    _offload_and_compare(LayerNorm, {"gamma": gamma, "beta": beta}, patterns, data)
+
+
+def test_tensorrt_clip():
+    # Regression test: Relax clip passes min/max as PrimValue arguments (Relay used a_min/a_max
+    # attributes); the codegen serializes them under the op's argument names.
+    @tvm.script.ir_module
+    class Clip:
+        @R.function
+        def main(data: R.Tensor((2, 8, 16, 16), "float32")):
+            with R.dataflow():
+                out = relax.op.clip(data, 0.0, 6.0)
+                R.output(out)
+            return out
+
+    data = (np.random.randn(2, 8, 16, 16) * 4).astype("float32")
+    patterns = [("tensorrt.clip", is_op("relax.clip")(wildcard(), wildcard(), wildcard()))]
+    _offload_and_compare(Clip, {}, patterns, data)
+
+
+def test_tensorrt_reshape():
+    # Regression test: Relax reshape takes the target shape as a Shape argument (Relay used a
+    # "newshape" attribute); the codegen serializes it under the op's argument name.
+    @tvm.script.ir_module
+    class Reshape:
+        @R.function
+        def main(data: R.Tensor((2, 8, 4, 4), "float32")):
+            with R.dataflow():
+                out = relax.op.reshape(data, (2, 8, 16))
+                R.output(out)
+            return out
+
+    data = np.random.randn(2, 8, 4, 4).astype("float32")
+    patterns = [("tensorrt.reshape", is_op("relax.reshape")(wildcard(), wildcard()))]
+    _offload_and_compare(Reshape, {}, patterns, data)
+
+
+def test_tensorrt_strided_slice():
+    # Regression test: Relax strided_slice passes axes/begin/end/strides as tuple arguments (Relay
+    # used start/size/strides attributes); the codegen serializes them positionally.
+    @tvm.script.ir_module
+    class StridedSlice:
+        @R.function
+        def main(data: R.Tensor((4, 8, 16), "float32")):
+            with R.dataflow():
+                out = relax.op.strided_slice(
+                    data, axes=[1, 2], begin=[2, 0], end=[6, 8], strides=[2, 1]
+                )
+                R.output(out)
+            return out
+
+    data = np.random.randn(4, 8, 16).astype("float32")
+    patterns = [
+        (
+            "tensorrt.strided_slice",
+            is_op("relax.strided_slice")(
+                wildcard(), wildcard(), wildcard(), wildcard(), wildcard()
+            ),
+        )
+    ]
+    _offload_and_compare(StridedSlice, {}, patterns, data)
+
+
+def test_tensorrt_split():
+    # Regression test: Relax split has no Relay-style "mode"; it is multi-output. The converter
+    # derives per-output extents from the codegen-recorded output shapes.
+    @tvm.script.ir_module
+    class Split:
+        @R.function
+        def main(data: R.Tensor((4, 8, 16), "float32")):
+            with R.dataflow():
+                parts = relax.op.split(data, 2, axis=1)
+                out = relax.op.add(parts[0], parts[1])
+                R.output(out)
+            return out
+
+    data = np.random.randn(4, 8, 16).astype("float32")
+    # Offload the add too so both split outputs are consumed inside TensorRT (and nothing is left
+    # for the VM to legalize).
+    patterns = [
+        ("tensorrt.split", is_op("relax.split")(wildcard())),
+        ("tensorrt.add", is_op("relax.add")(wildcard(), wildcard())),
+    ]
+    _offload_and_compare(Split, {}, patterns, data)
+
+
+def test_tensorrt_layout_transform():
+    # Regression test: Relax layout_transform uses an IndexMap (Relay used src_layout/dst_layout
+    # strings); the codegen translates a pure-permutation index map into a transpose. Built with the
+    # BlockBuilder because the index_map lambda cannot be expressed in TVMScript.
+    bb = relax.BlockBuilder()
+    data = relax.Var("data", relax.TensorStructInfo((1, 4, 8, 8), "float32"))
+    with bb.function("main", [data]):
+        with bb.dataflow():
+            out = bb.emit(
+                relax.op.layout_transform(data, index_map=lambda n, c, h, w: (n, h, w, c))
+            )
+            gv = bb.emit_output(out)
+        bb.emit_func_output(gv)
+    LayoutTransform = bb.finalize()
+
+    data_np = np.random.randn(1, 4, 8, 8).astype("float32")
+    patterns = [("tensorrt.layout_transform", is_op("relax.layout_transform")(wildcard()))]
+    _offload_and_compare(LayoutTransform, {}, patterns, data_np)
+
+
+def test_tensorrt_sum_all_axes():
+    # Edge case: Relax sum with no axis (StatisticalAttrs.axis = None) reduces over all axes.
+    @tvm.script.ir_module
+    class SumAll:
+        @R.function
+        def main(data: R.Tensor((2, 3, 4), "float32")):
+            with R.dataflow():
+                out = relax.op.sum(data, keepdims=True)
+                R.output(out)
+            return out
+
+    data = np.random.randn(2, 3, 4).astype("float32")
+    patterns = [("tensorrt.sum", is_op("relax.sum")(wildcard()))]
+    _offload_and_compare(SumAll, {}, patterns, data)
+
+
+def test_tensorrt_layer_norm_multi_axis():
+    # Edge case: layer_norm normalizing over more than one axis.
+    @tvm.script.ir_module
+    class LayerNorm2:
+        @R.function
+        def main(
+            data: R.Tensor((2, 3, 4, 5), "float32"),
+            gamma: R.Tensor((4, 5), "float32"),
+            beta: R.Tensor((4, 5), "float32"),
+        ):
+            with R.dataflow():
+                out = relax.op.nn.layer_norm(data, gamma, beta, axes=[-2, -1])
+                R.output(out)
+            return out
+
+    data = np.random.randn(2, 3, 4, 5).astype("float32")
+    gamma = np.random.randn(4, 5).astype("float32")
+    beta = np.random.randn(4, 5).astype("float32")
+    patterns = [
+        ("tensorrt.nn.layer_norm", is_op("relax.nn.layer_norm")(wildcard(), wildcard(), wildcard()))
+    ]
+    _offload_and_compare(LayerNorm2, {"gamma": gamma, "beta": beta}, patterns, data)
+
+
+def test_tensorrt_matmul_batched():
+    # Edge case: batched (3-D) matmul exercises TensorRT's leading-dim broadcasting.
+    @tvm.script.ir_module
+    class BatchMatmul:
+        @R.function
+        def main(data: R.Tensor((2, 4, 8), "float32"), weight: R.Tensor((2, 8, 16), "float32")):
+            with R.dataflow():
+                out = relax.op.matmul(data, weight)
+                R.output(out)
+            return out
+
+    data = np.random.randn(2, 4, 8).astype("float32")
+    weight = np.random.randn(2, 8, 16).astype("float32")
+    patterns = [("tensorrt.nn.batch_matmul", is_op("relax.matmul")(wildcard(), wildcard()))]
+    _offload_and_compare(BatchMatmul, {"weight": weight}, patterns, data)
+
+
+def test_tensorrt_strided_slice_no_strides():
+    # Edge case: strided_slice without an explicit strides argument (defaults to 1).
+    @tvm.script.ir_module
+    class StridedSliceNoStride:
+        @R.function
+        def main(data: R.Tensor((4, 8, 16), "float32")):
+            with R.dataflow():
+                out = relax.op.strided_slice(data, axes=[1], begin=[2], end=[6])
+                R.output(out)
+            return out
+
+    data = np.random.randn(4, 8, 16).astype("float32")
+    patterns = [
+        (
+            "tensorrt.strided_slice",
+            is_op("relax.strided_slice")(wildcard(), wildcard(), wildcard(), wildcard()),
+        )
+    ]
+    _offload_and_compare(StridedSliceNoStride, {}, patterns, data)
+
+
+def test_tensorrt_split_indices():
+    # Edge case: split by an explicit index list (the other indices_or_sections form).
+    @tvm.script.ir_module
+    class SplitIdx:
+        @R.function
+        def main(data: R.Tensor((4, 8, 16), "float32")):
+            with R.dataflow():
+                parts = relax.op.split(data, [4], axis=1)
+                out = relax.op.add(parts[0], parts[1])
+                R.output(out)
+            return out
+
+    data = np.random.randn(4, 8, 16).astype("float32")
+    patterns = [
+        ("tensorrt.split", is_op("relax.split")(wildcard())),
+        ("tensorrt.add", is_op("relax.add")(wildcard(), wildcard())),
+    ]
+    _offload_and_compare(SplitIdx, {}, patterns, data)
+
+
 if __name__ == "__main__":
     tvm.testing.main()


Reply via email to