This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new dda158ca10 [Fix][Relax] Support ND batched matmul chains in
AdjustMatmulOrder pass (#19650)
dda158ca10 is described below
commit dda158ca1056cd5b8a0b5ff5b550faade0daf6f7
Author: ConvolutedDog <[email protected]>
AuthorDate: Wed Jun 3 01:39:44 2026 +0800
[Fix][Relax] Support ND batched matmul chains in AdjustMatmulOrder pass
(#19650)
Fix a crash (https://github.com/apache/tvm/issues/19576) when
AdjustMatmulOrder encounters mixed-dimension matmul chains, which are common
in transformer models (e.g., matmul(attn_output[B,S,D], W_o[D,D])). When
rewriting a chain that involves permute_dims, the pass previously assumed
the permuted operands were 2D and asserted it (e.g., shape_c.size() == 2),
so the assertion failed on 3D intermediate results.
Changes:
- Replace full 2D transpose with permute_last_two_dims for permuted
matmul patterns, swapping only the last two axes for ND tensors.
- Remove hard ndim==2 checks in the permuted rewrite path.
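- Account for the broadcast batch prefix of each matmul when comparing
naive matmul FLOPs, so reorder decisions distinguish batched inner
matmuls from weight-only (unbatched) ones.
- Keep the existing order when the operands' batch prefixes cannot be
proven broadcast-compatible, or (as before) when neither evaluation
order is provably cheaper.
- Factor a diagnostic-free InferBinaryBroadcastShape overload (taking an
arith::Analyzer* and reporting kSuccess/kUnknown/kConflict) out of the
Call/BlockBuilder version, so the pass can infer broadcast batch
prefixes without a diagnostic context.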
- Add regression tests for symbolic/concrete batched LoRA shapes.
- Add a numerics test covering a minimal attention block with ND
permute_dims.
---
src/relax/op/op_common.cc | 48 ++-
src/relax/op/op_common.h | 30 ++
src/relax/transform/adjust_matmul_order.cc | 132 +++++--
.../relax/test_transform_adjust_matmul_order.py | 408 +++++++++++++++++++--
4 files changed, 552 insertions(+), 66 deletions(-)
diff --git a/src/relax/op/op_common.cc b/src/relax/op/op_common.cc
index 61485b0911..a019b87f3a 100644
--- a/src/relax/op/op_common.cc
+++ b/src/relax/op/op_common.cc
@@ -22,6 +22,7 @@
#include <tvm/ffi/cast.h>
#include <algorithm>
+#include <sstream>
namespace tvm {
namespace relax {
@@ -108,10 +109,10 @@ ffi::Array<TensorStructInfo>
GetTensorStructInfoFromTuple(const Call& call, cons
return tensor_sinfo;
}
-ffi::Optional<ffi::Array<PrimExpr>> InferBinaryBroadcastShape(
- const Call& call, const BlockBuilder& ctx, const ffi::Array<PrimExpr>&
x1_shape,
- const ffi::Array<PrimExpr>& x2_shape) {
- arith::Analyzer* analyzer = ctx->GetAnalyzer();
+BinaryBroadcastShapeInferResult InferBinaryBroadcastShape(arith::Analyzer*
analyzer,
+ const
ffi::Array<PrimExpr>& x1_shape,
+ const
ffi::Array<PrimExpr>& x2_shape) {
+ BinaryBroadcastShapeInferResult result;
int x1_ndim = x1_shape.size();
int x2_ndim = x2_shape.size();
int max_ndim = std::max(x1_ndim, x2_ndim);
@@ -132,20 +133,45 @@ ffi::Optional<ffi::Array<PrimExpr>>
InferBinaryBroadcastShape(
} else if (analyzer->CanProveEqual(dim0, dim1)) {
output_shape.push_back(dim0);
} else if (int_dim0 && int_dim1 && int_dim0->value != int_dim1->value) {
- ctx->ReportFatal(Diagnostic::Error(call)
- << "In " << call->op << ", the first input shape at dim
" << x1_ndim - i
- << " is " << dim0 << " and the second input shape at
dim " << x2_ndim - i
- << " is " << dim1 << ", which are not broadcastable.");
+ result.status = BinaryBroadcastShapeInferResult::Status::kConflict;
+ result.message = [&]() {
+ std::ostringstream os;
+ os << "the first input shape at dim " << x1_ndim - i << " is " << dim0
+ << " and the second input shape at dim " << x2_ndim - i << " is "
<< dim1
+ << ", which are not broadcastable.";
+ return ffi::String(os.str());
+ }();
+ return result;
} else {
- // Use simple fallback when shape mismatch.
- return std::nullopt;
+ result.status = BinaryBroadcastShapeInferResult::Status::kUnknown;
+ return result;
}
}
auto& longer_shape = (x1_ndim > x2_ndim) ? x1_shape : x2_shape;
for (; i <= max_ndim; ++i) {
output_shape.push_back(longer_shape[max_ndim - i]);
}
- return ffi::Array<PrimExpr>(output_shape.rbegin(), output_shape.rend());
+ result.status = BinaryBroadcastShapeInferResult::Status::kSuccess;
+ result.shape = ffi::Array<PrimExpr>(output_shape.rbegin(),
output_shape.rend());
+ return result;
+}
+
+ffi::Optional<ffi::Array<PrimExpr>> InferBinaryBroadcastShape(
+ const Call& call, const BlockBuilder& ctx, const ffi::Array<PrimExpr>&
x1_shape,
+ const ffi::Array<PrimExpr>& x2_shape) {
+ auto infer_result = InferBinaryBroadcastShape(ctx->GetAnalyzer(), x1_shape,
x2_shape);
+ if (infer_result.status ==
BinaryBroadcastShapeInferResult::Status::kConflict) {
+ TVM_FFI_ICHECK(infer_result.message.has_value());
+ ctx->ReportFatal(Diagnostic::Error(call)
+ << "In " << call->op << ", " <<
infer_result.message.value());
+ } else if (infer_result.status ==
BinaryBroadcastShapeInferResult::Status::kSuccess) {
+ TVM_FFI_ICHECK(infer_result.shape.has_value());
+ return infer_result.shape.value();
+ } else {
+ // Unknown status, use simple fallback when shape mismatch.
+ return std::nullopt;
+ }
+ TVM_FFI_UNREACHABLE();
}
std::vector<int> NormalizeAxes(const Call& call, const BlockBuilder& ctx, int
ndim,
diff --git a/src/relax/op/op_common.h b/src/relax/op/op_common.h
index 774eccfd58..6f7de974cb 100644
--- a/src/relax/op/op_common.h
+++ b/src/relax/op/op_common.h
@@ -387,6 +387,36 @@ inline ffi::Optional<VDevice>
InferBinaryArithOpOutVDevice(const Call& call,
return lhs_vdevice;
}
+/*! \brief Result of binary broadcast shape inference without diagnostic
context. */
+struct BinaryBroadcastShapeInferResult {
+ enum class Status {
+ /*! \brief Broadcast output shape is known. */
+ kSuccess,
+ /*! \brief Shapes may be broadcastable but cannot be proved symbolically.
*/
+ kUnknown,
+ /*! \brief Concrete shapes are not broadcastable. */
+ kConflict,
+ };
+
+ /*! \brief Inference status. */
+ Status status = Status::kUnknown;
+ /*! \brief Broadcasted shape if status is kSuccess. */
+ ffi::Optional<ffi::Array<PrimExpr>> shape;
+ /*! \brief Human-readable conflict description if status is kConflict. */
+ ffi::Optional<ffi::String> message;
+};
+
+/*!
+ * \brief Infer the output shape for binary broadcast operators.
+ * \param analyzer The arithmetic analyzer used to prove shape equality.
+ * \param x1_shape The shape of the first operand.
+ * \param x2_shape The shape of the second operand.
+ * \return Inference status and broadcasted shape, or a conflict message.
+ */
+BinaryBroadcastShapeInferResult InferBinaryBroadcastShape(arith::Analyzer*
analyzer,
+ const
ffi::Array<PrimExpr>& x1_shape,
+ const
ffi::Array<PrimExpr>& x2_shape);
+
/*!
* \brief Infer the output shape for binary broadcast operators.
* \param call The context Call to the operator.
diff --git a/src/relax/transform/adjust_matmul_order.cc
b/src/relax/transform/adjust_matmul_order.cc
index 9ea47aa648..012c8ce5b7 100644
--- a/src/relax/transform/adjust_matmul_order.cc
+++ b/src/relax/transform/adjust_matmul_order.cc
@@ -34,6 +34,7 @@
#include <unordered_set>
#include <vector>
+#include "../op/op_common.h"
#include "../op/tensor/linear_algebra.h"
#include "../op/tensor/manipulate.h"
@@ -41,6 +42,27 @@ namespace tvm {
namespace relax {
namespace {
+
+ffi::Array<PrimExpr> GetBatchPrefix(const ffi::Array<PrimExpr>& shape) {
+ if (shape.size() <= 2) return {};
+ return {shape.begin(), shape.end() - 2};
+}
+
+PrimExpr ProductDims(const ffi::Array<PrimExpr>& dims) {
+ PrimExpr product = IntImm(DataType::Int(64), 1);
+ for (const auto& dim : dims) product = product * dim;
+ return product;
+}
+
+ffi::Optional<ffi::Array<PrimExpr>> InferBatchedMatmulBroadcastPrefix(
+ arith::Analyzer* analyzer, const ffi::Array<PrimExpr>& x1, const
ffi::Array<PrimExpr>& x2) {
+ auto infer_result = InferBinaryBroadcastShape(analyzer, x1, x2);
+ if (infer_result.status ==
BinaryBroadcastShapeInferResult::Status::kSuccess) {
+ return infer_result.shape;
+ }
+ return std::nullopt;
+}
+
std::tuple<DFPattern, ffi::TypedFunction<Expr(Expr, ffi::Map<DFPattern,
Expr>)>> CreatePatterns(
const Function& func) {
auto compile_time_arr = ComputableAtCompileTime(func);
@@ -141,20 +163,46 @@ std::tuple<DFPattern, ffi::TypedFunction<Expr(Expr,
ffi::Map<DFPattern, Expr>)>>
auto shape_b = opt_shape_b.value();
auto shape_c = opt_shape_c.value();
+ auto permute_last_two_dims = [&](Expr expr) -> Expr {
+ auto opt_shape = get_shape(expr);
+ if (!opt_shape) return expr;
+
+ size_t ndim = opt_shape.value().size();
+ TVM_FFI_ICHECK_GE(ndim, 2);
+
+ ffi::Optional<ffi::Array<int64_t>> axes;
+
+ if (ndim == 2) {
+ // Pass none axes to permute_dims for simple transpose of 2D tensors.
+ axes = std::nullopt;
+ } else {
+ ffi::Array<int64_t> axes_array;
+ for (size_t i = 0; i < ndim; ++i) axes_array.push_back(i);
+ axes_array.Set(ndim - 1, ndim - 2);
+ axes_array.Set(ndim - 2, ndim - 1);
+ axes = ffi::Optional<ffi::Array<int64_t>>(axes_array);
+ }
+ return permute_dims(std::move(expr), axes);
+ };
+
+ auto transpose_shape_last_two_dims = [&](ffi::Array<PrimExpr>& shape) {
+ PrimExpr last_dim_shape = shape[shape.size() - 1];
+ shape.Set(shape.size() - 1, shape[shape.size() - 2]);
+ shape.Set(shape.size() - 2, last_dim_shape);
+ };
+
if (matches.count(pat_permuted_matmul_on_lhs)) {
- expr_a = permute_dims(expr_a, std::nullopt);
- expr_b = permute_dims(expr_b, std::nullopt);
- TVM_FFI_ICHECK_EQ(shape_a.size(), 2);
- TVM_FFI_ICHECK_EQ(shape_b.size(), 2);
- shape_a = {shape_a[1], shape_a[0]};
- shape_b = {shape_b[1], shape_b[0]};
+ if (shape_a.size() < 2 || shape_b.size() < 2) return expr;
+ expr_a = permute_last_two_dims(expr_a);
+ expr_b = permute_last_two_dims(expr_b);
+ transpose_shape_last_two_dims(shape_a);
+ transpose_shape_last_two_dims(shape_b);
} else if (matches.count(pat_permuted_matmul_on_rhs)) {
- expr_b = permute_dims(expr_b, std::nullopt);
- expr_c = permute_dims(expr_c, std::nullopt);
- TVM_FFI_ICHECK_EQ(shape_b.size(), 2);
- TVM_FFI_ICHECK_EQ(shape_c.size(), 2);
- shape_b = {shape_b[1], shape_b[0]};
- shape_c = {shape_c[1], shape_c[0]};
+ if (shape_b.size() < 2 || shape_c.size() < 2) return expr;
+ expr_b = permute_last_two_dims(expr_b);
+ expr_c = permute_last_two_dims(expr_c);
+ transpose_shape_last_two_dims(shape_b);
+ transpose_shape_last_two_dims(shape_c);
}
// If two of the three are compile-time, group those two values
@@ -166,13 +214,7 @@ std::tuple<DFPattern, ffi::TypedFunction<Expr(Expr,
ffi::Map<DFPattern, Expr>)>>
}
// Otherwise, select the order that reduces the total number of
- // operations required, assuming a naive matmul.
-
- // Matmul on LHS: ([N,R]*[R,M]) * [M,batch]
- // Matmul on RHS: [N,R] * ([R,M]*[M,batch])
- //
- // LHS first: `N*R*M + N*M*batch = N*M*(R+batch)`
- // RHS first: `N*R*batch + R*M*batch = (N+M)*R*batch`
+ // operations required, assuming a naive matmul (see below).
if (shape_a.size() == 1) {
shape_a = {IntImm(shape_a[0].dtype(), 1), shape_a[0]};
@@ -192,21 +234,54 @@ std::tuple<DFPattern, ffi::TypedFunction<Expr(Expr,
ffi::Map<DFPattern, Expr>)>>
shape_c = {shape_c[0], IntImm(shape_c[0].dtype(), 1)};
}
- auto size_N = shape_a[shape_a.size() - 2];
- auto size_R = shape_a[shape_a.size() - 1];
- auto size_M = shape_c[shape_c.size() - 2];
- auto size_B = shape_c[shape_c.size() - 1];
-
- auto ops_with_lhs_first = (size_R + size_B) * size_N * size_M;
- auto ops_with_rhs_first = (size_M + size_N) * size_R * size_B;
+ PrimExpr size_N = shape_a[shape_a.size() - 2]; // row of A
+ PrimExpr size_R = shape_a[shape_a.size() - 1]; // col of A and row of B
+ PrimExpr size_M = shape_c[shape_c.size() - 2]; // row of C and col of B
+ PrimExpr size_B = shape_c[shape_c.size() - 1]; // col of C
arith::Analyzer analyzer;
+ auto prefix_a = GetBatchPrefix(shape_a);
+ auto prefix_b = GetBatchPrefix(shape_b);
+ auto prefix_c = GetBatchPrefix(shape_c);
+
+ auto opt_prefix_ab = InferBatchedMatmulBroadcastPrefix(&analyzer,
prefix_a, prefix_b);
+ if (!opt_prefix_ab) return expr;
+ auto opt_prefix_bc = InferBatchedMatmulBroadcastPrefix(&analyzer,
prefix_b, prefix_c);
+ if (!opt_prefix_bc) return expr;
+ auto opt_prefix_outer_lhs =
+ InferBatchedMatmulBroadcastPrefix(&analyzer, opt_prefix_ab.value(),
prefix_c);
+ if (!opt_prefix_outer_lhs) return expr;
+ auto opt_prefix_outer_rhs =
+ InferBatchedMatmulBroadcastPrefix(&analyzer, prefix_a,
opt_prefix_bc.value());
+ if (!opt_prefix_outer_rhs) return expr;
+
+ PrimExpr batch_ab = ProductDims(opt_prefix_ab.value());
+ PrimExpr batch_bc = ProductDims(opt_prefix_bc.value());
+ PrimExpr batch_outer_lhs = ProductDims(opt_prefix_outer_lhs.value());
+ PrimExpr batch_outer_rhs = ProductDims(opt_prefix_outer_rhs.value());
+
+ // Compare naive matmul FLOPs for two evaluation orders of
+ // matmul(A, matmul(B, C)) vs matmul(matmul(A, B), C)
+ //
+ // Matrix dims (last two axes): A [N, R], B [R, M], C [M, B_last]
+ // Each matmul uses the broadcasted batch prefix of its operands.
+ //
+ // LHS first — matmul(matmul(A, B), C):
+ // batch_ab * N * R * M + batch_outer_lhs * N * M * B_last
+ PrimExpr ops_with_lhs_first =
+ batch_ab * size_N * size_R * size_M + batch_outer_lhs * size_N *
size_M * size_B;
+ // RHS first — matmul(A, matmul(B, C)):
+ // batch_bc * R * M * B_last + batch_outer_rhs * N * R * B_last
+ PrimExpr ops_with_rhs_first =
+ batch_bc * size_R * size_M * size_B + batch_outer_rhs * size_N *
size_R * size_B;
+
analyzer.rewrite_simplify.SetEnabledExtensions(static_cast<arith::RewriteSimplifier::Extension>(
analyzer.rewrite_simplify.GetEnabledExtensions() |
arith::RewriteSimplifier::Extension::kComparisonOfProductAndSum));
With<arith::ConstraintContext> func_attr_constraint(&analyzer,
symbolic_var_constraints);
With<arith::ConstraintContext> analyzer_constraint(
- &analyzer, size_N > 0 && size_R > 0 && size_M > 0 && size_B > 0);
+ &analyzer, batch_ab > 0 && batch_bc > 0 && batch_outer_lhs > 0 &&
batch_outer_rhs > 0 &&
+ size_N > 0 && size_R > 0 && size_M > 0 && size_B > 0);
if (analyzer.CanProve(ops_with_lhs_first < ops_with_rhs_first)) {
return matmul(matmul(expr_a, expr_b, DataType::Void()), expr_c,
DataType::Void());
@@ -214,8 +289,7 @@ std::tuple<DFPattern, ffi::TypedFunction<Expr(Expr,
ffi::Map<DFPattern, Expr>)>>
return matmul(expr_a, matmul(expr_b, expr_c, DataType::Void()),
DataType::Void());
}
- // If we cannot determine which order is best, keep the existing
- // order.
+ // If we cannot determine which order is best, keep the existing order.
return expr;
};
diff --git a/tests/python/relax/test_transform_adjust_matmul_order.py
b/tests/python/relax/test_transform_adjust_matmul_order.py
index a086f3abdb..9600c97bda 100644
--- a/tests/python/relax/test_transform_adjust_matmul_order.py
+++ b/tests/python/relax/test_transform_adjust_matmul_order.py
@@ -17,8 +17,11 @@
import inspect
+import numpy as np
import pytest
+import torch
+import tvm
import tvm.testing
from tvm import relax
from tvm.script import ir as I
@@ -39,7 +42,13 @@ class Base:
class TestLHS(Base):
- """Prefer (x*A)*B instead of x*(A*B)"""
+ """Prefer (x*A)*B instead of x*(A*B)
+
+ LHS first - (x*A)*B:
+ ops = 1*16*2 + 1*2*32 = 96
+ RHS first - x*(A*B):
+ ops = 16*2*32 + 1*16*32 = 1536
+ """
@I.ir_module
class Before:
@@ -67,7 +76,13 @@ class TestLHS(Base):
class TestRHS(Base):
- """Prefer A*(B*x) instead of (A*B)*x"""
+ """Prefer A*(B*x) instead of (A*B)*x
+
+ LHS first - (A*B)*x:
+ ops = 32*2*16 + 32*16*1 = 1536
+ RHS first - A*(B*x):
+ ops = 2*16*1 + 32*2*1 = 96
+ """
@I.ir_module
class Before:
@@ -163,6 +178,13 @@ class TestLHSDynamic(Base):
This case appears when evaluating LoRA-tuned models with a dynamic
rank.
+
+ LHS first - (x*A)*B:
+ ops = 1*16*lora_r + 1*lora_r*32 = 48*lora_r
+ RHS first - x*(A*B):
+ ops = 16*lora_r*32 + 1*16*32 = 512*lora_r + 512
+
+ 48*lora_r can be proved to be less than 512*lora_r + 512, so the LHS first
is preferred.
"""
@I.ir_module
@@ -192,7 +214,15 @@ class TestLHSDynamic(Base):
class TestRHSDynamic(Base):
- """Prefer A*(B*x) instead of (A*B)*x"""
+ """Prefer A*(B*x) instead of (A*B)*x
+
+ LHS first - (A*B)*x:
+ ops = 32*lora_r*16 + 32*16*1 = 512*lora_r + 512
+ RHS first - A*(B*x):
+ ops = lora_r*16*1 + 32*lora_r*1 = 48*lora_r
+
+ 48*lora_r can be proved to be less than 512*lora_r + 512, so the RHS first
is preferred.
+ """
@I.ir_module
class Before:
@@ -234,8 +264,27 @@ class TestIdempotentRHSDynamic(Base):
Expected = TestRHSDynamic.Expected
-class TestLHSDynamicWithBatch(Base):
- """Prefer (x*A)*B instead of x*(A*B)"""
+class TestDynamicWithBatchSymbolic1(Base):
+ """When both batch_size and lora_r are symbolic and it cannot be proven
which
+ is cheaper, LHS or RHS, maintain the existing order.
+
+ `Before` computes `x * (A * B)` with
+ `x: [batch_size, 1, 16]`, `A: [16, lora_r]`, `B: [lora_r, 32]`.
+
+ RHS first - x * (A * B):
+ 16*lora_r*32 + batch_size*1*16*32 = 512*(lora_r + batch_size)
+
+ LHS first - (x * A) * B:
+ batch_size*1*16*lora_r + batch_size*1*lora_r*32 = 48*batch_size*lora_r
+
+ When `batch_size` and `lora_r` are known at compile-time:
+ - satisfy the inequality 48*batch_size*lora_r < 512*(lora_r +
batch_size),
+ the LHS first is preferred.
+ - satisfy the inequality 512*(lora_r + batch_size) <
48*batch_size*lora_r,
+ the RHS first is preferred.
+
+ Without bounds on `batch_size` and `lora_r`, neither side is provably
cheaper.
+ """
@I.ir_module
class Before:
@@ -250,6 +299,31 @@ class TestLHSDynamicWithBatch(Base):
out: R.Tensor([batch_size, 1, 32]) = R.matmul(x, weight)
return out
+ Expected = Before
+
+
+class TestDynamicWithBatchConcrete1LHSFirst(Base):
+ """With concrete shapes, LHS first is provably cheaper.
+
+ batch_size=4, lora_r=16:
+ LHS first: 48*4*16 = 3072
+ RHS first: 512*(16 + 4) = 10240
+ """
+
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(
+ x: R.Tensor(["batch_size", 1, 16]),
+ A: R.Tensor([16, "lora_r"]),
+ B: R.Tensor(["lora_r", 32]),
+ ) -> R.Tensor(["batch_size", 1, 32]):
+ batch_size = T.int64(4)
+ lora_r = T.int64(16) # noqa: F841
+ weight: R.Tensor([16, 32]) = R.matmul(A, B)
+ out: R.Tensor([batch_size, 1, 32]) = R.matmul(x, weight)
+ return out
+
@I.ir_module
class Expected:
@R.function
@@ -258,15 +332,71 @@ class TestLHSDynamicWithBatch(Base):
A: R.Tensor([16, "lora_r"]),
B: R.Tensor(["lora_r", 32]),
) -> R.Tensor(["batch_size", 1, 32]):
- lora_r = T.int64()
- batch_size = T.int64()
- x: R.Tensor([batch_size, 1, lora_r]) = R.matmul(x, A)
- x: R.Tensor([batch_size, 1, 32]) = R.matmul(x, B)
- return x
+ batch_size = T.int64(4)
+ lora_r = T.int64(16)
+ weight: R.Tensor([batch_size, 1, lora_r]) = R.matmul(x, A)
+ out: R.Tensor([batch_size, 1, 32]) = R.matmul(weight, B)
+ return out
-class TestRHSDynamicWithBatch(Base):
- """Prefer A*(B*x) instead of (A*B)*x"""
+class TestDynamicWithBatchConcrete1RHSFirst(Base):
+ """With concrete shapes, RHS first is provably cheaper.
+
+ batch_size=64, lora_r=16:
+ LHS first: 48*64*16 = 49152
+ RHS first: 512*(16 + 64) = 40960
+ """
+
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(
+ x: R.Tensor(["batch_size", 1, 16]),
+ A: R.Tensor([16, "lora_r"]),
+ B: R.Tensor(["lora_r", 32]),
+ ) -> R.Tensor(["batch_size", 1, 32]):
+ batch_size = T.int64(64)
+ lora_r = T.int64(16)
+ weight: R.Tensor([batch_size, 1, lora_r]) = R.matmul(x, A)
+ out: R.Tensor([batch_size, 1, 32]) = R.matmul(weight, B)
+ return out
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor(["batch_size", 1, 16]),
+ A: R.Tensor([16, "lora_r"]),
+ B: R.Tensor(["lora_r", 32]),
+ ) -> R.Tensor(["batch_size", 1, 32]):
+ batch_size = T.int64(64)
+ lora_r = T.int64(16) # noqa: F841
+ weight: R.Tensor([16, 32]) = R.matmul(A, B)
+ out: R.Tensor([batch_size, 1, 32]) = R.matmul(x, weight)
+ return out
+
+
+class TestDynamicWithBatchSymbolic2(Base):
+ """When both batch_size and lora_r are symbolic and it cannot be proven
which
+ is cheaper, LHS or RHS, maintain the existing order.
+
+ `Before` computes `(A * B) * x` with
+ `A: [32, lora_r]`, `B: [lora_r, 16]`, `x: [batch_size, 16, 1]`.
+
+ LHS first - (A * B) * x:
+ 32*lora_r*16 + batch_size*32*16*1 = 512*(lora_r + batch_size)
+
+ RHS first - A * (B * x):
+ batch_size*lora_r*16*1 + batch_size*32*lora_r*1 = 48*batch_size*lora_r
+
+ When `batch_size` and `lora_r` are known at compile-time:
+ - satisfy the inequality 48*batch_size*lora_r < 512*(lora_r +
batch_size),
+ the RHS first is preferred.
+ - satisfy the inequality 512*(lora_r + batch_size) <
48*batch_size*lora_r,
+ the LHS first is preferred.
+
+ Without bounds on `batch_size` and `lora_r`, neither side is provably
cheaper.
+ """
@I.ir_module
class Before:
@@ -281,6 +411,31 @@ class TestRHSDynamicWithBatch(Base):
out: R.Tensor([batch_size, 32, 1]) = R.matmul(weight, x)
return out
+ Expected = Before
+
+
+class TestDynamicWithBatchConcrete2RHSFirst(Base):
+ """With concrete shapes, RHS first is provably cheaper.
+
+ batch_size=4, lora_r=16:
+ RHS first: 48*4*16 = 3072
+ LHS first: 512*(16 + 4) = 10240
+ """
+
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(
+ x: R.Tensor(["batch_size", 16, 1]),
+ A: R.Tensor([32, "lora_r"]),
+ B: R.Tensor(["lora_r", 16]),
+ ) -> R.Tensor(["batch_size", 32, 1]):
+ batch_size = T.int64(4)
+ lora_r = T.int64(16) # noqa: F841
+ weight: R.Tensor([32, 16]) = R.matmul(A, B)
+ out: R.Tensor([batch_size, 32, 1]) = R.matmul(weight, x)
+ return out
+
@I.ir_module
class Expected:
@R.function
@@ -289,11 +444,48 @@ class TestRHSDynamicWithBatch(Base):
A: R.Tensor([32, "lora_r"]),
B: R.Tensor(["lora_r", 16]),
) -> R.Tensor(["batch_size", 32, 1]):
- lora_r = T.int64()
- batch_size = T.int64()
- x: R.Tensor([batch_size, lora_r, 1]) = R.matmul(B, x)
- x: R.Tensor([batch_size, 32, 1]) = R.matmul(A, x)
- return x
+ batch_size = T.int64(4)
+ lora_r = T.int64(16)
+ weight: R.Tensor([batch_size, lora_r, 1]) = R.matmul(B, x)
+ out: R.Tensor([batch_size, 32, 1]) = R.matmul(A, weight)
+ return out
+
+
+class TestDynamicWithBatchConcrete2LHSFirst(Base):
+ """With concrete shapes, LHS first is provably cheaper.
+
+ batch_size=64, lora_r=16:
+ RHS first: 48*64*16 = 49152
+ LHS first: 512*(16 + 64) = 40960
+ """
+
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(
+ x: R.Tensor(["batch_size", 16, 1]),
+ A: R.Tensor([32, "lora_r"]),
+ B: R.Tensor(["lora_r", 16]),
+ ) -> R.Tensor(["batch_size", 32, 1]):
+ batch_size = T.int64(64)
+ lora_r = T.int64(16)
+ weight: R.Tensor([batch_size, lora_r, 1]) = R.matmul(B, x)
+ out: R.Tensor([batch_size, 32, 1]) = R.matmul(A, weight)
+ return out
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ x: R.Tensor(["batch_size", 16, 1]),
+ A: R.Tensor([32, "lora_r"]),
+ B: R.Tensor(["lora_r", 16]),
+ ) -> R.Tensor(["batch_size", 32, 1]):
+ batch_size = T.int64(64)
+ lora_r = T.int64(16) # noqa: F841
+ weight: R.Tensor([32, 16]) = R.matmul(A, B)
+ out: R.Tensor([batch_size, 32, 1]) = R.matmul(weight, x)
+ return out
class TestNoOpForFullyDynamicOnLHS(Base):
@@ -353,6 +545,11 @@ class TestRHSPermuteDims(Base):
"""Prefer (x*A)*B instead of x*(A*B)
Like `TestRHS`, but the weights on the RHS are transposed.
+
+ Before: x * (BT * AT)
+ ops = 16*2*32 + 1*16*32 = 1536
+ After: (x * BT) * AT
+ ops = 1*16*2 + 1*2*32 = 96
"""
@I.ir_module
@@ -388,6 +585,13 @@ class TestRHSPermuteDimsDynamic(Base):
Like `TestRHSPermuteDims`, but the weights on the RHS have a
dynamic shape.
+
+ Before: x * (BT * AT)
+ ops = 16*lora_r*32 + 1*16*32 = 512*lora_r + 512
+ After: (x * BT) * AT
+ ops = 1*16*lora_r + 1*lora_r*32 = 48*lora_r
+
+ 48*lora_r can be proved to be less than 512*lora_r + 512, so the After is
preferred.
"""
@I.ir_module
@@ -433,15 +637,15 @@ class TestRHSPermuteDimsWithDynamicBatch(Base):
ops_left_to_right = (batch_size + lora_r)*4096*4096
ops_right_to_left = (4096 + 4096)*batch_size*lora_r
- Without an upper bound on `lora_r`, we cannot prove which of these
- is the preferred execution order. With the upper bound, TVM can
- determine the preferred order using the following arithmethic
- reasoning.
+ Without upper bounds on `batch_size` and `lora_r`, we cannot prove which
+ of these is the preferred execution order.
- (batch_size + lora_r)*4096*4096 < (4096 + 4096)*batch_size*lora_r
- (batch_size + lora_r)*2048 < batch_size*lora_r
- 1/batch_size + 1/lora_r < 1/2048
+ With the upper bound, TVM can determine the preferred order using
+ the following arithmetic reasoning.
+ (batch_size + lora_r)*4096*4096 > (4096 + 4096)*batch_size*lora_r
+ (batch_size + lora_r)*2048 > batch_size*lora_r
+ 1/batch_size + 1/lora_r > 1/2048
"""
@I.ir_module
@@ -452,7 +656,12 @@ class TestRHSPermuteDimsWithDynamicBatch(Base):
A: R.Tensor([4096, "lora_r"]),
B: R.Tensor(["lora_r", 4096]),
) -> R.Tensor(["batch_size", 4096]):
- R.func_attr({"tir_var_upper_bound": {"lora_r": 2048}})
+ R.func_attr(
+ {
+ "tir_var_upper_bound": {"lora_r": 2048, "batch_size":
2048},
+ }
+ )
+ lora_r = T.int64() # noqa: F841
batch_size = T.int64()
linear_weight: R.Tensor([4096, 4096]) = R.matmul(A, B)
matmul_weight: R.Tensor([4096, 4096]) =
R.permute_dims(linear_weight)
@@ -467,7 +676,11 @@ class TestRHSPermuteDimsWithDynamicBatch(Base):
A: R.Tensor([4096, "lora_r"]),
B: R.Tensor(["lora_r", 4096]),
) -> R.Tensor(["batch_size", 4096]):
- R.func_attr({"tir_var_upper_bound": {"lora_r": 2048}})
+ R.func_attr(
+ {
+ "tir_var_upper_bound": {"lora_r": 2048, "batch_size":
2048},
+ }
+ )
lora_r = T.int64()
batch_size = T.int64()
B_transpose = R.permute_dims(B)
@@ -482,6 +695,11 @@ class TestRHSPermuteDimsDynamicWithSquareMatrix(Base):
Like `TestRHSPermuteDims`, but the weights on the RHS have a
dynamic shape.
+
+ Before: x * (BT * AT)
+ ops = 32*lora_r*32 + 1*32*32 = 1024*lora_r + 1024
+ After: (x * BT) * AT
+ ops = 1*32*lora_r + 1*lora_r*32 = 64*lora_r
"""
@I.ir_module
@@ -513,5 +731,143 @@ class TestRHSPermuteDimsDynamicWithSquareMatrix(Base):
return x
+class TestBatchedBroadcastPreferLHSFirst(Base):
+ """Use broadcasted batch prefix per matmul, not independent prefix
products.
+
+ Example with broadcast batch axes: A:[2,1,1], B:[2,1,2], C:[2,2,3].
+
+ LHS first: (A * B) * C
+ ops = 2*1*1*2 + 2*1*2*3 = 16
+ RHS first: A * (B * C)
+ ops = 2*1*2*3 + 2*1*1*3 = 18
+ """
+
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(
+ A: R.Tensor([2, 1, 1]),
+ B: R.Tensor([2, 1, 2]),
+ C: R.Tensor([2, 2, 3]),
+ ) -> R.Tensor([2, 1, 3]):
+ out: R.Tensor([2, 1, 3]) = R.matmul(A, R.matmul(B, C))
+ return out
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ A: R.Tensor([2, 1, 1]),
+ B: R.Tensor([2, 1, 2]),
+ C: R.Tensor([2, 2, 3]),
+ ) -> R.Tensor([2, 1, 3]):
+ temp: R.Tensor([2, 1, 2]) = R.matmul(A, B)
+ out: R.Tensor([2, 1, 3]) = R.matmul(temp, C)
+ return out
+
+
+class TestBatchedSharedPrefixPreferLHSFirst(Base):
+ """All operands share a nontrivial batch prefix [2, 3].
+
+ Shapes: A:[2,3,4,5], B:[2,3,5,6], C:[2,3,6,7]
+
+ LHS first:
+ ops = 6*4*5*6 + 6*4*6*7 = 1728
+ RHS first:
+ ops = 6*5*6*7 + 6*4*5*7 = 2100
+ """
+
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(
+ A: R.Tensor([2, 3, 4, 5]),
+ B: R.Tensor([2, 3, 5, 6]),
+ C: R.Tensor([2, 3, 6, 7]),
+ ) -> R.Tensor([2, 3, 4, 7]):
+ out: R.Tensor([2, 3, 4, 7]) = R.matmul(A, R.matmul(B, C))
+ return out
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ A: R.Tensor([2, 3, 4, 5]),
+ B: R.Tensor([2, 3, 5, 6]),
+ C: R.Tensor([2, 3, 6, 7]),
+ ) -> R.Tensor([2, 3, 4, 7]):
+ temp: R.Tensor([2, 3, 4, 6]) = R.matmul(A, B)
+ out: R.Tensor([2, 3, 4, 7]) = R.matmul(temp, C)
+ return out
+
+
+class TestAdjustMatmulOrderAttentionBlock:
+ """AdjustMatmulOrder preserves numerics on a batched attention block.
+
+ Covers ND `permute_dims` (swap last two axes) inside `matmul(q, kt)`,
+ regression for issue #19576.
+ """
+
+ def _build_attention_module(self, batch, seq, dim):
+ """Minimal batched attention block exercising ND permute_dims +
matmul."""
+ bb = relax.BlockBuilder()
+ x = relax.Var("x", relax.TensorStructInfo((batch, seq, dim),
"float32"))
+ wq = relax.Var("wq", relax.TensorStructInfo((dim, dim), "float32"))
+ wk = relax.Var("wk", relax.TensorStructInfo((dim, dim), "float32"))
+ wv = relax.Var("wv", relax.TensorStructInfo((dim, dim), "float32"))
+ wo = relax.Var("wo", relax.TensorStructInfo((dim, dim), "float32"))
+ with bb.function("main", [x, wq, wk, wv, wo]):
+ with bb.dataflow():
+ q = bb.emit(relax.op.matmul(x, wq))
+ k = bb.emit(relax.op.matmul(x, wk))
+ v = bb.emit(relax.op.matmul(x, wv))
+ kt = bb.emit(relax.op.permute_dims(k, axes=[0, 2, 1]))
+ scores = bb.emit(relax.op.matmul(q, kt))
+ scale = bb.emit(relax.const(1.0 / np.sqrt(dim), "float32"))
+ scores = bb.emit(relax.op.multiply(scores, scale))
+ attn = bb.emit(relax.op.nn.softmax(scores, axis=-1))
+ out = bb.emit(relax.op.matmul(attn, v))
+ proj = bb.emit_output(relax.op.matmul(out, wo))
+ bb.emit_func_output(proj)
+ return bb.finalize()
+
+ def _run_relax_main(self, mod, inputs):
+ exe = relax.build(mod, target="llvm")
+ vm = relax.VirtualMachine(exe, device=tvm.cpu())
+ args = [tvm.runtime.tensor(arr, device=tvm.cpu()) for arr in inputs]
+ return vm["main"](*args).numpy()
+
+ def _torch_attention_ref(self, x_np, w_np, dim):
+ x = torch.from_numpy(x_np)
+ w = torch.from_numpy(w_np)
+ with torch.no_grad():
+ q = torch.matmul(x, w)
+ k = torch.matmul(x, w)
+ v = torch.matmul(x, w)
+ scores = torch.matmul(q, k.transpose(-2, -1))
+ scores = scores * (1.0 / np.sqrt(dim))
+ attn = torch.nn.functional.softmax(scores, dim=-1)
+ out = torch.matmul(attn, v)
+ out = torch.matmul(out, w)
+ return out.detach().numpy()
+
+ @pytest.mark.parametrize("batch,seq,dim", [(2, 16, 64)])
+ def test_attention_block_numerics(self, batch, seq, dim):
+ mod = self._build_attention_module(batch, seq, dim)
+ mod_opt = relax.transform.AdjustMatmulOrder()(mod)
+
+ x_np = np.random.randn(batch, seq, dim).astype("float32")
+ w_np = np.random.randn(dim, dim).astype("float32")
+ inputs = [x_np, w_np, w_np, w_np, w_np]
+
+ ref = self._torch_attention_ref(x_np, w_np, dim)
+ out_before = self._run_relax_main(mod, inputs)
+ out_after = self._run_relax_main(mod_opt, inputs)
+
+ tvm.testing.assert_allclose(out_before, ref, rtol=1e-3, atol=1e-3)
+ tvm.testing.assert_allclose(out_after, ref, rtol=1e-3, atol=1e-3)
+ tvm.testing.assert_allclose(out_before, out_after, rtol=1e-5,
atol=1e-5)
+
+
if __name__ == "__main__":
tvm.testing.main()