This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new dda158ca10 [Fix][Relax] Support ND batched matmul chains in 
AdjustMatmulOrder pass (#19650)
dda158ca10 is described below

commit dda158ca1056cd5b8a0b5ff5b550faade0daf6f7
Author: ConvolutedDog <[email protected]>
AuthorDate: Wed Jun 3 01:39:44 2026 +0800

    [Fix][Relax] Support ND batched matmul chains in AdjustMatmulOrder pass 
(#19650)
    
    Fix a crash (https://github.com/apache/tvm/issues/19576) when
    AdjustMatmulOrder encounters mixed-dimension matmul chains common in
    transformer models (e.g. matmul(attn_output[B,S,D], W_o[D,D])). The pass
    previously assumed all operands in a chained rewrite were 2D and
    asserted shape_c.size() == 2, failing on 3D intermediate results.
    
    Changes:
    - Replace full 2D transpose with permute_last_two_dims for permuted
    matmul patterns, swapping only the last two axes for ND tensors.
    - Remove hard ndim==2 checks in the permuted rewrite path.
    - Account for batch prefixes when comparing naive matmul FLOPs, so
    reorder decisions reflect batched vs. weight-only inner matmuls.
    - Skip reorder when neither evaluation order is provably cheaper.
    - Add regression tests for symbolic/concrete batched LoRA shapes.
    - Add a numerics test covering a minimal attention block with ND
    permute_dims.
---
 src/relax/op/op_common.cc                          |  48 ++-
 src/relax/op/op_common.h                           |  30 ++
 src/relax/transform/adjust_matmul_order.cc         | 132 +++++--
 .../relax/test_transform_adjust_matmul_order.py    | 408 +++++++++++++++++++--
 4 files changed, 552 insertions(+), 66 deletions(-)

diff --git a/src/relax/op/op_common.cc b/src/relax/op/op_common.cc
index 61485b0911..a019b87f3a 100644
--- a/src/relax/op/op_common.cc
+++ b/src/relax/op/op_common.cc
@@ -22,6 +22,7 @@
 #include <tvm/ffi/cast.h>
 
 #include <algorithm>
+#include <sstream>
 
 namespace tvm {
 namespace relax {
@@ -108,10 +109,10 @@ ffi::Array<TensorStructInfo> 
GetTensorStructInfoFromTuple(const Call& call, cons
   return tensor_sinfo;
 }
 
-ffi::Optional<ffi::Array<PrimExpr>> InferBinaryBroadcastShape(
-    const Call& call, const BlockBuilder& ctx, const ffi::Array<PrimExpr>& 
x1_shape,
-    const ffi::Array<PrimExpr>& x2_shape) {
-  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+BinaryBroadcastShapeInferResult InferBinaryBroadcastShape(arith::Analyzer* 
analyzer,
+                                                          const 
ffi::Array<PrimExpr>& x1_shape,
+                                                          const 
ffi::Array<PrimExpr>& x2_shape) {
+  BinaryBroadcastShapeInferResult result;
   int x1_ndim = x1_shape.size();
   int x2_ndim = x2_shape.size();
   int max_ndim = std::max(x1_ndim, x2_ndim);
@@ -132,20 +133,45 @@ ffi::Optional<ffi::Array<PrimExpr>> 
InferBinaryBroadcastShape(
     } else if (analyzer->CanProveEqual(dim0, dim1)) {
       output_shape.push_back(dim0);
     } else if (int_dim0 && int_dim1 && int_dim0->value != int_dim1->value) {
-      ctx->ReportFatal(Diagnostic::Error(call)
-                       << "In " << call->op << ", the first input shape at dim 
" << x1_ndim - i
-                       << " is " << dim0 << " and the second input shape at 
dim " << x2_ndim - i
-                       << " is " << dim1 << ", which are not broadcastable.");
+      result.status = BinaryBroadcastShapeInferResult::Status::kConflict;
+      result.message = [&]() {
+        std::ostringstream os;
+        os << "the first input shape at dim " << x1_ndim - i << " is " << dim0
+           << " and the second input shape at dim " << x2_ndim - i << " is " 
<< dim1
+           << ", which are not broadcastable.";
+        return ffi::String(os.str());
+      }();
+      return result;
     } else {
-      // Use simple fallback when shape mismatch.
-      return std::nullopt;
+      result.status = BinaryBroadcastShapeInferResult::Status::kUnknown;
+      return result;
     }
   }
   auto& longer_shape = (x1_ndim > x2_ndim) ? x1_shape : x2_shape;
   for (; i <= max_ndim; ++i) {
     output_shape.push_back(longer_shape[max_ndim - i]);
   }
-  return ffi::Array<PrimExpr>(output_shape.rbegin(), output_shape.rend());
+  result.status = BinaryBroadcastShapeInferResult::Status::kSuccess;
+  result.shape = ffi::Array<PrimExpr>(output_shape.rbegin(), 
output_shape.rend());
+  return result;
+}
+
+ffi::Optional<ffi::Array<PrimExpr>> InferBinaryBroadcastShape(
+    const Call& call, const BlockBuilder& ctx, const ffi::Array<PrimExpr>& 
x1_shape,
+    const ffi::Array<PrimExpr>& x2_shape) {
+  auto infer_result = InferBinaryBroadcastShape(ctx->GetAnalyzer(), x1_shape, 
x2_shape);
+  if (infer_result.status == 
BinaryBroadcastShapeInferResult::Status::kConflict) {
+    TVM_FFI_ICHECK(infer_result.message.has_value());
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "In " << call->op << ", " << 
infer_result.message.value());
+  } else if (infer_result.status == 
BinaryBroadcastShapeInferResult::Status::kSuccess) {
+    TVM_FFI_ICHECK(infer_result.shape.has_value());
+    return infer_result.shape.value();
+  } else {
+    // Unknown status, use simple fallback when shape mismatch.
+    return std::nullopt;
+  }
+  TVM_FFI_UNREACHABLE();
 }
 
 std::vector<int> NormalizeAxes(const Call& call, const BlockBuilder& ctx, int 
ndim,
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index 774eccfd58..6f7de974cb 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -387,6 +387,36 @@ inline ffi::Optional<VDevice> 
InferBinaryArithOpOutVDevice(const Call& call,
   return lhs_vdevice;
 }
 
+/*! \brief Result of binary broadcast shape inference without diagnostic 
context. */
+struct BinaryBroadcastShapeInferResult {
+  enum class Status {
+    /*! \brief Broadcast output shape is known. */
+    kSuccess,
+    /*! \brief Shapes may be broadcastable but cannot be proved symbolically. 
*/
+    kUnknown,
+    /*! \brief Concrete shapes are not broadcastable. */
+    kConflict,
+  };
+
+  /*! \brief Inference status. */
+  Status status = Status::kUnknown;
+  /*! \brief Broadcasted shape if status is kSuccess. */
+  ffi::Optional<ffi::Array<PrimExpr>> shape;
+  /*! \brief Human-readable conflict description if status is kConflict. */
+  ffi::Optional<ffi::String> message;
+};
+
+/*!
+ * \brief Infer the output shape for binary broadcast operators.
+ * \param analyzer The arithmetic analyzer used to prove shape equality.
+ * \param x1_shape The shape of the first operand.
+ * \param x2_shape The shape of the second operand.
+ * \return Inference status and broadcasted shape, or a conflict message.
+ */
+BinaryBroadcastShapeInferResult InferBinaryBroadcastShape(arith::Analyzer* 
analyzer,
+                                                          const 
ffi::Array<PrimExpr>& x1_shape,
+                                                          const 
ffi::Array<PrimExpr>& x2_shape);
+
 /*!
  * \brief Infer the output shape for binary broadcast operators.
  * \param call The context Call to the operator.
diff --git a/src/relax/transform/adjust_matmul_order.cc 
b/src/relax/transform/adjust_matmul_order.cc
index 9ea47aa648..012c8ce5b7 100644
--- a/src/relax/transform/adjust_matmul_order.cc
+++ b/src/relax/transform/adjust_matmul_order.cc
@@ -34,6 +34,7 @@
 #include <unordered_set>
 #include <vector>
 
+#include "../op/op_common.h"
 #include "../op/tensor/linear_algebra.h"
 #include "../op/tensor/manipulate.h"
 
@@ -41,6 +42,27 @@ namespace tvm {
 namespace relax {
 
 namespace {
+
+ffi::Array<PrimExpr> GetBatchPrefix(const ffi::Array<PrimExpr>& shape) {
+  if (shape.size() <= 2) return {};
+  return {shape.begin(), shape.end() - 2};
+}
+
+PrimExpr ProductDims(const ffi::Array<PrimExpr>& dims) {
+  PrimExpr product = IntImm(DataType::Int(64), 1);
+  for (const auto& dim : dims) product = product * dim;
+  return product;
+}
+
+ffi::Optional<ffi::Array<PrimExpr>> InferBatchedMatmulBroadcastPrefix(
+    arith::Analyzer* analyzer, const ffi::Array<PrimExpr>& x1, const 
ffi::Array<PrimExpr>& x2) {
+  auto infer_result = InferBinaryBroadcastShape(analyzer, x1, x2);
+  if (infer_result.status == 
BinaryBroadcastShapeInferResult::Status::kSuccess) {
+    return infer_result.shape;
+  }
+  return std::nullopt;
+}
+
 std::tuple<DFPattern, ffi::TypedFunction<Expr(Expr, ffi::Map<DFPattern, 
Expr>)>> CreatePatterns(
     const Function& func) {
   auto compile_time_arr = ComputableAtCompileTime(func);
@@ -141,20 +163,46 @@ std::tuple<DFPattern, ffi::TypedFunction<Expr(Expr, 
ffi::Map<DFPattern, Expr>)>>
     auto shape_b = opt_shape_b.value();
     auto shape_c = opt_shape_c.value();
 
+    auto permute_last_two_dims = [&](Expr expr) -> Expr {
+      auto opt_shape = get_shape(expr);
+      if (!opt_shape) return expr;
+
+      size_t ndim = opt_shape.value().size();
+      TVM_FFI_ICHECK_GE(ndim, 2);
+
+      ffi::Optional<ffi::Array<int64_t>> axes;
+
+      if (ndim == 2) {
+        // For 2D tensors, pass std::nullopt axes so permute_dims performs a plain transpose.
+        axes = std::nullopt;
+      } else {
+        ffi::Array<int64_t> axes_array;
+        for (size_t i = 0; i < ndim; ++i) axes_array.push_back(i);
+        axes_array.Set(ndim - 1, ndim - 2);
+        axes_array.Set(ndim - 2, ndim - 1);
+        axes = ffi::Optional<ffi::Array<int64_t>>(axes_array);
+      }
+      return permute_dims(std::move(expr), axes);
+    };
+
+    auto transpose_shape_last_two_dims = [&](ffi::Array<PrimExpr>& shape) {
+      PrimExpr last_dim_shape = shape[shape.size() - 1];
+      shape.Set(shape.size() - 1, shape[shape.size() - 2]);
+      shape.Set(shape.size() - 2, last_dim_shape);
+    };
+
     if (matches.count(pat_permuted_matmul_on_lhs)) {
-      expr_a = permute_dims(expr_a, std::nullopt);
-      expr_b = permute_dims(expr_b, std::nullopt);
-      TVM_FFI_ICHECK_EQ(shape_a.size(), 2);
-      TVM_FFI_ICHECK_EQ(shape_b.size(), 2);
-      shape_a = {shape_a[1], shape_a[0]};
-      shape_b = {shape_b[1], shape_b[0]};
+      if (shape_a.size() < 2 || shape_b.size() < 2) return expr;
+      expr_a = permute_last_two_dims(expr_a);
+      expr_b = permute_last_two_dims(expr_b);
+      transpose_shape_last_two_dims(shape_a);
+      transpose_shape_last_two_dims(shape_b);
     } else if (matches.count(pat_permuted_matmul_on_rhs)) {
-      expr_b = permute_dims(expr_b, std::nullopt);
-      expr_c = permute_dims(expr_c, std::nullopt);
-      TVM_FFI_ICHECK_EQ(shape_b.size(), 2);
-      TVM_FFI_ICHECK_EQ(shape_c.size(), 2);
-      shape_b = {shape_b[1], shape_b[0]};
-      shape_c = {shape_c[1], shape_c[0]};
+      if (shape_b.size() < 2 || shape_c.size() < 2) return expr;
+      expr_b = permute_last_two_dims(expr_b);
+      expr_c = permute_last_two_dims(expr_c);
+      transpose_shape_last_two_dims(shape_b);
+      transpose_shape_last_two_dims(shape_c);
     }
 
     // If two of the three are compile-time, group those two values
@@ -166,13 +214,7 @@ std::tuple<DFPattern, ffi::TypedFunction<Expr(Expr, 
ffi::Map<DFPattern, Expr>)>>
     }
 
     // Otherwise, select the order that reduces the total number of
-    // operations required, assuming a naive matmul.
-
-    // Matmul on LHS: ([N,R]*[R,M]) * [M,batch]
-    // Matmul on RHS: [N,R] * ([R,M]*[M,batch])
-    //
-    // LHS first: `N*R*M + N*M*batch = N*M*(R+batch)`
-    // RHS first: `N*R*batch + R*M*batch = (N+M)*R*batch`
+    // operations required, assuming a naive matmul (see below).
 
     if (shape_a.size() == 1) {
       shape_a = {IntImm(shape_a[0].dtype(), 1), shape_a[0]};
@@ -192,21 +234,54 @@ std::tuple<DFPattern, ffi::TypedFunction<Expr(Expr, 
ffi::Map<DFPattern, Expr>)>>
       shape_c = {shape_c[0], IntImm(shape_c[0].dtype(), 1)};
     }
 
-    auto size_N = shape_a[shape_a.size() - 2];
-    auto size_R = shape_a[shape_a.size() - 1];
-    auto size_M = shape_c[shape_c.size() - 2];
-    auto size_B = shape_c[shape_c.size() - 1];
-
-    auto ops_with_lhs_first = (size_R + size_B) * size_N * size_M;
-    auto ops_with_rhs_first = (size_M + size_N) * size_R * size_B;
+    PrimExpr size_N = shape_a[shape_a.size() - 2];  // row of A
+    PrimExpr size_R = shape_a[shape_a.size() - 1];  // col of A and row of B
+    PrimExpr size_M = shape_c[shape_c.size() - 2];  // row of C and col of B
+    PrimExpr size_B = shape_c[shape_c.size() - 1];  // col of C
 
     arith::Analyzer analyzer;
+    auto prefix_a = GetBatchPrefix(shape_a);
+    auto prefix_b = GetBatchPrefix(shape_b);
+    auto prefix_c = GetBatchPrefix(shape_c);
+
+    auto opt_prefix_ab = InferBatchedMatmulBroadcastPrefix(&analyzer, 
prefix_a, prefix_b);
+    if (!opt_prefix_ab) return expr;
+    auto opt_prefix_bc = InferBatchedMatmulBroadcastPrefix(&analyzer, 
prefix_b, prefix_c);
+    if (!opt_prefix_bc) return expr;
+    auto opt_prefix_outer_lhs =
+        InferBatchedMatmulBroadcastPrefix(&analyzer, opt_prefix_ab.value(), 
prefix_c);
+    if (!opt_prefix_outer_lhs) return expr;
+    auto opt_prefix_outer_rhs =
+        InferBatchedMatmulBroadcastPrefix(&analyzer, prefix_a, 
opt_prefix_bc.value());
+    if (!opt_prefix_outer_rhs) return expr;
+
+    PrimExpr batch_ab = ProductDims(opt_prefix_ab.value());
+    PrimExpr batch_bc = ProductDims(opt_prefix_bc.value());
+    PrimExpr batch_outer_lhs = ProductDims(opt_prefix_outer_lhs.value());
+    PrimExpr batch_outer_rhs = ProductDims(opt_prefix_outer_rhs.value());
+
+    // Compare naive matmul FLOPs for two evaluation orders of
+    //   matmul(A, matmul(B, C))  vs  matmul(matmul(A, B), C)
+    //
+    // Matrix dims (last two axes): A [N, R], B [R, M], C [M, B_last]
+    // Each matmul uses the broadcasted batch prefix of its operands.
+    //
+    // LHS first — matmul(matmul(A, B), C):
+    //   batch_ab * N * R * M + batch_outer_lhs * N * M * B_last
+    PrimExpr ops_with_lhs_first =
+        batch_ab * size_N * size_R * size_M + batch_outer_lhs * size_N * 
size_M * size_B;
+    // RHS first — matmul(A, matmul(B, C)):
+    //   batch_bc * R * M * B_last + batch_outer_rhs * N * R * B_last
+    PrimExpr ops_with_rhs_first =
+        batch_bc * size_R * size_M * size_B + batch_outer_rhs * size_N * 
size_R * size_B;
+
     
analyzer.rewrite_simplify.SetEnabledExtensions(static_cast<arith::RewriteSimplifier::Extension>(
         analyzer.rewrite_simplify.GetEnabledExtensions() |
         arith::RewriteSimplifier::Extension::kComparisonOfProductAndSum));
     With<arith::ConstraintContext> func_attr_constraint(&analyzer, 
symbolic_var_constraints);
     With<arith::ConstraintContext> analyzer_constraint(
-        &analyzer, size_N > 0 && size_R > 0 && size_M > 0 && size_B > 0);
+        &analyzer, batch_ab > 0 && batch_bc > 0 && batch_outer_lhs > 0 && 
batch_outer_rhs > 0 &&
+                       size_N > 0 && size_R > 0 && size_M > 0 && size_B > 0);
 
     if (analyzer.CanProve(ops_with_lhs_first < ops_with_rhs_first)) {
       return matmul(matmul(expr_a, expr_b, DataType::Void()), expr_c, 
DataType::Void());
@@ -214,8 +289,7 @@ std::tuple<DFPattern, ffi::TypedFunction<Expr(Expr, 
ffi::Map<DFPattern, Expr>)>>
       return matmul(expr_a, matmul(expr_b, expr_c, DataType::Void()), 
DataType::Void());
     }
 
-    // If we cannot determine which order is best, keep the existing
-    // order.
+    // If we cannot determine which order is best, keep the existing order.
     return expr;
   };
 
diff --git a/tests/python/relax/test_transform_adjust_matmul_order.py 
b/tests/python/relax/test_transform_adjust_matmul_order.py
index a086f3abdb..9600c97bda 100644
--- a/tests/python/relax/test_transform_adjust_matmul_order.py
+++ b/tests/python/relax/test_transform_adjust_matmul_order.py
@@ -17,8 +17,11 @@
 
 import inspect
 
+import numpy as np
 import pytest
+import torch
 
+import tvm
 import tvm.testing
 from tvm import relax
 from tvm.script import ir as I
@@ -39,7 +42,13 @@ class Base:
 
 
 class TestLHS(Base):
-    """Prefer (x*A)*B instead of x*(A*B)"""
+    """Prefer (x*A)*B instead of x*(A*B)
+
+    LHS first - (x*A)*B:
+        ops = 1*16*2 + 1*2*32 = 96
+    RHS first - x*(A*B):
+        ops = 16*2*32 + 1*16*32 = 1536
+    """
 
     @I.ir_module
     class Before:
@@ -67,7 +76,13 @@ class TestLHS(Base):
 
 
 class TestRHS(Base):
-    """Prefer A*(B*x) instead of (A*B)*x"""
+    """Prefer A*(B*x) instead of (A*B)*x
+
+    LHS first - (A*B)*x:
+        ops = 32*2*16 + 32*16*1 = 1536
+    RHS first - A*(B*x):
+        ops = 2*16*1 + 32*2*1 = 96
+    """
 
     @I.ir_module
     class Before:
@@ -163,6 +178,13 @@ class TestLHSDynamic(Base):
 
     This case appears when evaluating LoRA-tuned models with a dynamic
     rank.
+
+    LHS first - (x*A)*B:
+        ops = 1*16*lora_r + 1*lora_r*32 = 48*lora_r
+    RHS first - x*(A*B):
+        ops = 16*lora_r*32 + 1*16*32 = 512*lora_r + 512
+
+    48*lora_r can be proved to be less than 512*lora_r + 512, so the LHS first 
is preferred.
     """
 
     @I.ir_module
@@ -192,7 +214,15 @@ class TestLHSDynamic(Base):
 
 
 class TestRHSDynamic(Base):
-    """Prefer A*(B*x) instead of (A*B)*x"""
+    """Prefer A*(B*x) instead of (A*B)*x
+
+    LHS first - (A*B)*x:
+        ops = 32*lora_r*16 + 32*16*1 = 512*lora_r + 512
+    RHS first - A*(B*x):
+        ops = lora_r*16*1 + 32*lora_r*1 = 48*lora_r
+
+    48*lora_r can be proved to be less than 512*lora_r + 512, so the RHS first 
is preferred.
+    """
 
     @I.ir_module
     class Before:
@@ -234,8 +264,27 @@ class TestIdempotentRHSDynamic(Base):
     Expected = TestRHSDynamic.Expected
 
 
-class TestLHSDynamicWithBatch(Base):
-    """Prefer (x*A)*B instead of x*(A*B)"""
+class TestDynamicWithBatchSymbolic1(Base):
+    """When both batch_size and lora_r are symbolic and it cannot be proven 
which
+    is cheaper, LHS or RHS, maintain the existing order.
+
+    `Before` computes `x * (A * B)` with
+    `x: [batch_size, 1, 16]`, `A: [16, lora_r]`, `B: [lora_r, 32]`.
+
+    RHS first - x * (A * B):
+        16*lora_r*32 + batch_size*1*16*32 = 512*(lora_r + batch_size)
+
+    LHS first - (x * A) * B:
+        batch_size*1*16*lora_r + batch_size*1*lora_r*32 = 48*batch_size*lora_r
+
+    When `batch_size` and `lora_r` are known at compile time:
+        - if they satisfy 48*batch_size*lora_r < 512*(lora_r + batch_size),
+          LHS first is preferred;
+        - if they satisfy 512*(lora_r + batch_size) < 48*batch_size*lora_r,
+          RHS first is preferred.
+
+    Without bounds on `batch_size` and `lora_r`, neither side is provably 
cheaper.
+    """
 
     @I.ir_module
     class Before:
@@ -250,6 +299,31 @@ class TestLHSDynamicWithBatch(Base):
             out: R.Tensor([batch_size, 1, 32]) = R.matmul(x, weight)
             return out
 
+    Expected = Before
+
+
+class TestDynamicWithBatchConcrete1LHSFirst(Base):
+    """With concrete shapes, LHS first is provably cheaper.
+
+    batch_size=4, lora_r=16:
+        LHS first: 48*4*16 = 3072
+        RHS first: 512*(16 + 4) = 10240
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor(["batch_size", 1, 16]),
+            A: R.Tensor([16, "lora_r"]),
+            B: R.Tensor(["lora_r", 32]),
+        ) -> R.Tensor(["batch_size", 1, 32]):
+            batch_size = T.int64(4)
+            lora_r = T.int64(16)  # noqa: F841
+            weight: R.Tensor([16, 32]) = R.matmul(A, B)
+            out: R.Tensor([batch_size, 1, 32]) = R.matmul(x, weight)
+            return out
+
     @I.ir_module
     class Expected:
         @R.function
@@ -258,15 +332,71 @@ class TestLHSDynamicWithBatch(Base):
             A: R.Tensor([16, "lora_r"]),
             B: R.Tensor(["lora_r", 32]),
         ) -> R.Tensor(["batch_size", 1, 32]):
-            lora_r = T.int64()
-            batch_size = T.int64()
-            x: R.Tensor([batch_size, 1, lora_r]) = R.matmul(x, A)
-            x: R.Tensor([batch_size, 1, 32]) = R.matmul(x, B)
-            return x
+            batch_size = T.int64(4)
+            lora_r = T.int64(16)
+            weight: R.Tensor([batch_size, 1, lora_r]) = R.matmul(x, A)
+            out: R.Tensor([batch_size, 1, 32]) = R.matmul(weight, B)
+            return out
 
 
-class TestRHSDynamicWithBatch(Base):
-    """Prefer A*(B*x) instead of (A*B)*x"""
+class TestDynamicWithBatchConcrete1RHSFirst(Base):
+    """With concrete shapes, RHS first is provably cheaper.
+
+    batch_size=64, lora_r=16:
+        LHS first: 48*64*16 = 49152
+        RHS first: 512*(16 + 64) = 40960
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor(["batch_size", 1, 16]),
+            A: R.Tensor([16, "lora_r"]),
+            B: R.Tensor(["lora_r", 32]),
+        ) -> R.Tensor(["batch_size", 1, 32]):
+            batch_size = T.int64(64)
+            lora_r = T.int64(16)
+            weight: R.Tensor([batch_size, 1, lora_r]) = R.matmul(x, A)
+            out: R.Tensor([batch_size, 1, 32]) = R.matmul(weight, B)
+            return out
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor(["batch_size", 1, 16]),
+            A: R.Tensor([16, "lora_r"]),
+            B: R.Tensor(["lora_r", 32]),
+        ) -> R.Tensor(["batch_size", 1, 32]):
+            batch_size = T.int64(64)
+            lora_r = T.int64(16)  # noqa: F841
+            weight: R.Tensor([16, 32]) = R.matmul(A, B)
+            out: R.Tensor([batch_size, 1, 32]) = R.matmul(x, weight)
+            return out
+
+
+class TestDynamicWithBatchSymbolic2(Base):
+    """When both batch_size and lora_r are symbolic and it cannot be proven 
which
+    is cheaper, LHS or RHS, maintain the existing order.
+
+    `Before` computes `(A * B) * x` with
+    `A: [32, lora_r]`, `B: [lora_r, 16]`, `x: [batch_size, 16, 1]`.
+
+    LHS first - (A * B) * x:
+        32*lora_r*16 + batch_size*32*16*1 = 512*(lora_r + batch_size)
+
+    RHS first - A * (B * x):
+        batch_size*lora_r*16*1 + batch_size*32*lora_r*1 = 48*batch_size*lora_r
+
+    When `batch_size` and `lora_r` are known at compile time:
+        - if they satisfy 48*batch_size*lora_r < 512*(lora_r + batch_size),
+          RHS first is preferred;
+        - if they satisfy 512*(lora_r + batch_size) < 48*batch_size*lora_r,
+          LHS first is preferred.
+
+    Without bounds on `batch_size` and `lora_r`, neither side is provably 
cheaper.
+    """
 
     @I.ir_module
     class Before:
@@ -281,6 +411,31 @@ class TestRHSDynamicWithBatch(Base):
             out: R.Tensor([batch_size, 32, 1]) = R.matmul(weight, x)
             return out
 
+    Expected = Before
+
+
+class TestDynamicWithBatchConcrete2RHSFirst(Base):
+    """With concrete shapes, RHS first is provably cheaper.
+
+    batch_size=4, lora_r=16:
+        RHS first: 48*4*16 = 3072
+        LHS first: 512*(16 + 4) = 10240
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor(["batch_size", 16, 1]),
+            A: R.Tensor([32, "lora_r"]),
+            B: R.Tensor(["lora_r", 16]),
+        ) -> R.Tensor(["batch_size", 32, 1]):
+            batch_size = T.int64(4)
+            lora_r = T.int64(16)  # noqa: F841
+            weight: R.Tensor([32, 16]) = R.matmul(A, B)
+            out: R.Tensor([batch_size, 32, 1]) = R.matmul(weight, x)
+            return out
+
     @I.ir_module
     class Expected:
         @R.function
@@ -289,11 +444,48 @@ class TestRHSDynamicWithBatch(Base):
             A: R.Tensor([32, "lora_r"]),
             B: R.Tensor(["lora_r", 16]),
         ) -> R.Tensor(["batch_size", 32, 1]):
-            lora_r = T.int64()
-            batch_size = T.int64()
-            x: R.Tensor([batch_size, lora_r, 1]) = R.matmul(B, x)
-            x: R.Tensor([batch_size, 32, 1]) = R.matmul(A, x)
-            return x
+            batch_size = T.int64(4)
+            lora_r = T.int64(16)
+            weight: R.Tensor([batch_size, lora_r, 1]) = R.matmul(B, x)
+            out: R.Tensor([batch_size, 32, 1]) = R.matmul(A, weight)
+            return out
+
+
+class TestDynamicWithBatchConcrete2LHSFirst(Base):
+    """With concrete shapes, LHS first is provably cheaper.
+
+    batch_size=64, lora_r=16:
+        RHS first: 48*64*16 = 49152
+        LHS first: 512*(16 + 64) = 40960
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor(["batch_size", 16, 1]),
+            A: R.Tensor([32, "lora_r"]),
+            B: R.Tensor(["lora_r", 16]),
+        ) -> R.Tensor(["batch_size", 32, 1]):
+            batch_size = T.int64(64)
+            lora_r = T.int64(16)
+            weight: R.Tensor([batch_size, lora_r, 1]) = R.matmul(B, x)
+            out: R.Tensor([batch_size, 32, 1]) = R.matmul(A, weight)
+            return out
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor(["batch_size", 16, 1]),
+            A: R.Tensor([32, "lora_r"]),
+            B: R.Tensor(["lora_r", 16]),
+        ) -> R.Tensor(["batch_size", 32, 1]):
+            batch_size = T.int64(64)
+            lora_r = T.int64(16)  # noqa: F841
+            weight: R.Tensor([32, 16]) = R.matmul(A, B)
+            out: R.Tensor([batch_size, 32, 1]) = R.matmul(weight, x)
+            return out
 
 
 class TestNoOpForFullyDynamicOnLHS(Base):
@@ -353,6 +545,11 @@ class TestRHSPermuteDims(Base):
     """Prefer (x*A)*B instead of x*(A*B)
 
     Like `TestRHS`, but the weights on the RHS are transposed.
+
+    Before: x * (BT * AT)
+        ops = 16*2*32 + 1*16*32 = 1536
+    After: (x * BT) * AT
+        ops = 1*16*2 + 1*2*32 = 96
     """
 
     @I.ir_module
@@ -388,6 +585,13 @@ class TestRHSPermuteDimsDynamic(Base):
 
     Like `TestRHSPermuteDims`, but the weights on the RHS have a
     dynamic shape.
+
+    Before: x * (BT * AT)
+        ops = 16*lora_r*32 + 1*16*32 = 512*lora_r + 512
+    After: (x * BT) * AT
+        ops = 1*16*lora_r + 1*lora_r*32 = 48*lora_r
+
+    48*lora_r can be proved to be less than 512*lora_r + 512, so the After form is preferred.
     """
 
     @I.ir_module
@@ -433,15 +637,15 @@ class TestRHSPermuteDimsWithDynamicBatch(Base):
         ops_left_to_right = (batch_size + lora_r)*4096*4096
         ops_right_to_left = (4096 + 4096)*batch_size*lora_r
 
-    Without an upper bound on `lora_r`, we cannot prove which of these
-    is the preferred execution order.  With the upper bound, TVM can
-    determine the preferred order using the following arithmethic
-    reasoning.
+    Without upper bounds on `batch_size` and `lora_r`, we cannot prove which
+    of these is the preferred execution order.
 
-        (batch_size + lora_r)*4096*4096 < (4096 + 4096)*batch_size*lora_r
-        (batch_size + lora_r)*2048 < batch_size*lora_r
-        1/batch_size + 1/lora_r < 1/2048
+    With the upper bound, TVM can determine the preferred order using
+    the following arithmetic reasoning.
 
+        (batch_size + lora_r)*4096*4096 > (4096 + 4096)*batch_size*lora_r
+        (batch_size + lora_r)*2048 > batch_size*lora_r
+        1/batch_size + 1/lora_r > 1/2048
     """
 
     @I.ir_module
@@ -452,7 +656,12 @@ class TestRHSPermuteDimsWithDynamicBatch(Base):
             A: R.Tensor([4096, "lora_r"]),
             B: R.Tensor(["lora_r", 4096]),
         ) -> R.Tensor(["batch_size", 4096]):
-            R.func_attr({"tir_var_upper_bound": {"lora_r": 2048}})
+            R.func_attr(
+                {
+                    "tir_var_upper_bound": {"lora_r": 2048, "batch_size": 
2048},
+                }
+            )
+            lora_r = T.int64()  # noqa: F841
             batch_size = T.int64()
             linear_weight: R.Tensor([4096, 4096]) = R.matmul(A, B)
             matmul_weight: R.Tensor([4096, 4096]) = 
R.permute_dims(linear_weight)
@@ -467,7 +676,11 @@ class TestRHSPermuteDimsWithDynamicBatch(Base):
             A: R.Tensor([4096, "lora_r"]),
             B: R.Tensor(["lora_r", 4096]),
         ) -> R.Tensor(["batch_size", 4096]):
-            R.func_attr({"tir_var_upper_bound": {"lora_r": 2048}})
+            R.func_attr(
+                {
+                    "tir_var_upper_bound": {"lora_r": 2048, "batch_size": 
2048},
+                }
+            )
             lora_r = T.int64()
             batch_size = T.int64()
             B_transpose = R.permute_dims(B)
@@ -482,6 +695,11 @@ class TestRHSPermuteDimsDynamicWithSquareMatrix(Base):
 
     Like `TestRHSPermuteDims`, but the weights on the RHS have a
     dynamic shape.
+
+    Before: x * (BT * AT)
+        ops = 32*lora_r*32 + 1*32*32 = 1024*lora_r + 1024
+    After: (x * BT) * AT
+        ops = 1*32*lora_r + 1*lora_r*32 = 64*lora_r
     """
 
     @I.ir_module
@@ -513,5 +731,143 @@ class TestRHSPermuteDimsDynamicWithSquareMatrix(Base):
             return x
 
 
+class TestBatchedBroadcastPreferLHSFirst(Base):
+    """Use broadcasted batch prefix per matmul, not independent prefix 
products.
+
+    Example with broadcast batch axes: A:[2,1,1], B:[2,1,2], C:[2,2,3].
+
+    LHS first: (A * B) * C
+        ops = 2*1*1*2 + 2*1*2*3 = 16
+    RHS first: A * (B * C)
+        ops = 2*1*2*3 + 2*1*1*3 = 18
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            A: R.Tensor([2, 1, 1]),
+            B: R.Tensor([2, 1, 2]),
+            C: R.Tensor([2, 2, 3]),
+        ) -> R.Tensor([2, 1, 3]):
+            out: R.Tensor([2, 1, 3]) = R.matmul(A, R.matmul(B, C))
+            return out
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            A: R.Tensor([2, 1, 1]),
+            B: R.Tensor([2, 1, 2]),
+            C: R.Tensor([2, 2, 3]),
+        ) -> R.Tensor([2, 1, 3]):
+            temp: R.Tensor([2, 1, 2]) = R.matmul(A, B)
+            out: R.Tensor([2, 1, 3]) = R.matmul(temp, C)
+            return out
+
+
+class TestBatchedSharedPrefixPreferLHSFirst(Base):
+    """All operands share a nontrivial batch prefix [2, 3].
+
+    Shapes: A:[2,3,4,5], B:[2,3,5,6], C:[2,3,6,7]
+
+    LHS first:
+        ops = 6*4*5*6 + 6*4*6*7 = 1728
+    RHS first:
+        ops = 6*5*6*7 + 6*4*5*7 = 2100
+    """
+
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            A: R.Tensor([2, 3, 4, 5]),
+            B: R.Tensor([2, 3, 5, 6]),
+            C: R.Tensor([2, 3, 6, 7]),
+        ) -> R.Tensor([2, 3, 4, 7]):
+            out: R.Tensor([2, 3, 4, 7]) = R.matmul(A, R.matmul(B, C))
+            return out
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            A: R.Tensor([2, 3, 4, 5]),
+            B: R.Tensor([2, 3, 5, 6]),
+            C: R.Tensor([2, 3, 6, 7]),
+        ) -> R.Tensor([2, 3, 4, 7]):
+            temp: R.Tensor([2, 3, 4, 6]) = R.matmul(A, B)
+            out: R.Tensor([2, 3, 4, 7]) = R.matmul(temp, C)
+            return out
+
+
+class TestAdjustMatmulOrderAttentionBlock:
+    """AdjustMatmulOrder preserves numerics on a batched attention block.
+
+    Covers ND `permute_dims` (swap last two axes) inside `matmul(q, kt)`,
+    regression for issue #19576.
+    """
+
+    def _build_attention_module(self, batch, seq, dim):
+        """Minimal batched attention block exercising ND permute_dims + 
matmul."""
+        bb = relax.BlockBuilder()
+        x = relax.Var("x", relax.TensorStructInfo((batch, seq, dim), 
"float32"))
+        wq = relax.Var("wq", relax.TensorStructInfo((dim, dim), "float32"))
+        wk = relax.Var("wk", relax.TensorStructInfo((dim, dim), "float32"))
+        wv = relax.Var("wv", relax.TensorStructInfo((dim, dim), "float32"))
+        wo = relax.Var("wo", relax.TensorStructInfo((dim, dim), "float32"))
+        with bb.function("main", [x, wq, wk, wv, wo]):
+            with bb.dataflow():
+                q = bb.emit(relax.op.matmul(x, wq))
+                k = bb.emit(relax.op.matmul(x, wk))
+                v = bb.emit(relax.op.matmul(x, wv))
+                kt = bb.emit(relax.op.permute_dims(k, axes=[0, 2, 1]))
+                scores = bb.emit(relax.op.matmul(q, kt))
+                scale = bb.emit(relax.const(1.0 / np.sqrt(dim), "float32"))
+                scores = bb.emit(relax.op.multiply(scores, scale))
+                attn = bb.emit(relax.op.nn.softmax(scores, axis=-1))
+                out = bb.emit(relax.op.matmul(attn, v))
+                proj = bb.emit_output(relax.op.matmul(out, wo))
+            bb.emit_func_output(proj)
+        return bb.finalize()
+
+    def _run_relax_main(self, mod, inputs):
+        exe = relax.build(mod, target="llvm")
+        vm = relax.VirtualMachine(exe, device=tvm.cpu())
+        args = [tvm.runtime.tensor(arr, device=tvm.cpu()) for arr in inputs]
+        return vm["main"](*args).numpy()
+
+    def _torch_attention_ref(self, x_np, w_np, dim):
+        x = torch.from_numpy(x_np)
+        w = torch.from_numpy(w_np)
+        with torch.no_grad():
+            q = torch.matmul(x, w)
+            k = torch.matmul(x, w)
+            v = torch.matmul(x, w)
+            scores = torch.matmul(q, k.transpose(-2, -1))
+            scores = scores * (1.0 / np.sqrt(dim))
+            attn = torch.nn.functional.softmax(scores, dim=-1)
+            out = torch.matmul(attn, v)
+            out = torch.matmul(out, w)
+        return out.detach().numpy()
+
+    @pytest.mark.parametrize("batch,seq,dim", [(2, 16, 64)])
+    def test_attention_block_numerics(self, batch, seq, dim):
+        mod = self._build_attention_module(batch, seq, dim)
+        mod_opt = relax.transform.AdjustMatmulOrder()(mod)
+
+        x_np = np.random.randn(batch, seq, dim).astype("float32")
+        w_np = np.random.randn(dim, dim).astype("float32")
+        inputs = [x_np, w_np, w_np, w_np, w_np]
+
+        ref = self._torch_attention_ref(x_np, w_np, dim)
+        out_before = self._run_relax_main(mod, inputs)
+        out_after = self._run_relax_main(mod_opt, inputs)
+
+        tvm.testing.assert_allclose(out_before, ref, rtol=1e-3, atol=1e-3)
+        tvm.testing.assert_allclose(out_after, ref, rtol=1e-3, atol=1e-3)
+        tvm.testing.assert_allclose(out_before, out_after, rtol=1e-5, 
atol=1e-5)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to