This is an automated email from the ASF dual-hosted git repository.
lukhut pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 4d4f0508a2 [SVE] Support scalable vectors in LoopVectorizer (#16782)
4d4f0508a2 is described below
commit 4d4f0508a2fd903d95ae46472d830cde84e9ce9e
Author: Elen Kalda <[email protected]>
AuthorDate: Tue Apr 9 14:36:32 2024 +0100
[SVE] Support scalable vectors in LoopVectorizer (#16782)
This patch adds support for turning loops marked for vectorizing into
scalable vectors if the extent of the loop is a vscale dependent
expression in a correct form.
The testing for both scalable and fixed length vectors in
test_tir_transform.py has been extended and most of the tests
have been converted to TVMScript based testing against expected
output.
Co-authored-by: Luke Hutton <[email protected]>
Co-authored-by: Neil Hickey <[email protected]>
---
include/tvm/runtime/data_type.h | 4 +-
include/tvm/tir/op.h | 11 +-
src/tir/ir/expr.cc | 13 +-
src/tir/transforms/vectorize_loop.cc | 187 +++++++----
.../tir-transform/test_tir_transform_vectorize.py | 361 ++++++++++++++++-----
5 files changed, 428 insertions(+), 148 deletions(-)
diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index 8f3ae9b424..f7284ec690 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -111,7 +111,9 @@ class DataType {
return -lanes_as_int;
}
/*! \return get vscale factor or lanes depending on scalability of the
vector. */
- int get_lanes_or_vscale_factor() { return is_scalable_vector() ?
vscale_factor() : lanes(); }
+ int get_lanes_or_vscale_factor() const {
+ return is_scalable_vector() ? vscale_factor() : lanes();
+ }
/*! \return whether type is a scalar type. */
bool is_scalar() const { return !is_scalable_vector() && lanes() == 1; }
/*! \return whether type is a scalar type. */
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index ce4a4d6a28..d06bb779d0 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -31,6 +31,7 @@
#include <tvm/ir/expr.h>
#include <tvm/ir/op.h>
#include <tvm/ir/type.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt.h>
@@ -959,10 +960,16 @@ inline PrimExpr MakeConstScalar(DataType t, bool value,
Span span) {
template <typename ValueType, typename>
inline PrimExpr make_const(DataType t, ValueType value, Span span) {
- if (t.lanes() == 1) {
+ if (t.is_scalar()) {
return MakeConstScalar(t, value, span);
} else {
- return tir::Broadcast(MakeConstScalar(t.element_of(), value, span),
t.lanes(), span);
+ if (t.is_fixed_length_vector()) {
+ return tir::Broadcast(MakeConstScalar(t.element_of(), value, span),
t.lanes(), span);
+ } else {
+ PrimExpr lanes =
+ tir::Mul(tir::Call(DataType::Int(32), tir::builtin::vscale(), {}),
t.vscale_factor());
+ return tir::Broadcast(MakeConstScalar(t.element_of(), value, span),
lanes, span);
+ }
}
}
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 90dad72039..2cd2a698de 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -196,7 +196,8 @@ TVM_REGISTER_NODE_TYPE(StringImmNode);
// Cast
Cast::Cast(DataType t, PrimExpr value, Span span) {
ICHECK(value.defined());
- ICHECK_EQ(t.lanes(), value.dtype().lanes());
+ ICHECK_EQ(t.get_lanes_or_vscale_factor(),
value.dtype().get_lanes_or_vscale_factor());
+ ICHECK(t.is_scalable_vector() == value.dtype().is_scalable_vector());
ObjectPtr<CastNode> node = make_object<CastNode>();
node->dtype = t;
node->value = std::move(value);
@@ -354,7 +355,8 @@ And::And(PrimExpr a, PrimExpr b, Span span) {
ICHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types";
ObjectPtr<AndNode> node = make_object<AndNode>();
- node->dtype = DataType::Bool(a.dtype().lanes());
+ node->dtype =
+ DataType::Bool(a.dtype().get_lanes_or_vscale_factor(),
a.dtype().is_scalable_vector());
node->a = std::move(a);
node->b = std::move(b);
node->span = std::move(span);
@@ -376,7 +378,8 @@ Or::Or(PrimExpr a, PrimExpr b, Span span) {
ICHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types";
ObjectPtr<OrNode> node = make_object<OrNode>();
- node->dtype = DataType::Bool(a.dtype().lanes());
+ node->dtype =
+ DataType::Bool(a.dtype().get_lanes_or_vscale_factor(),
a.dtype().is_scalable_vector());
node->a = std::move(a);
node->b = std::move(b);
node->span = std::move(span);
@@ -412,7 +415,9 @@ Select::Select(PrimExpr condition, PrimExpr true_value,
PrimExpr false_value, Sp
ICHECK(true_value.defined()) << "ValueError: true_value is undefined";
ICHECK(false_value.defined()) << "ValueError: true_value is undefined";
ICHECK(condition.dtype().is_bool());
- ICHECK(condition.dtype().lanes() == true_value.dtype().lanes() ||
condition.dtype().lanes() == 1);
+ ICHECK(condition.dtype().get_lanes_or_vscale_factor() ==
+ true_value.dtype().get_lanes_or_vscale_factor() ||
+ condition.dtype().is_scalar());
ICHECK(false_value.dtype() == true_value.dtype())
<< "TypeError: mismatched types. "
<< "False type: " << false_value.dtype() << "; True type: " <<
true_value.dtype();
diff --git a/src/tir/transforms/vectorize_loop.cc
b/src/tir/transforms/vectorize_loop.cc
index 57536422cf..a9cc497580 100644
--- a/src/tir/transforms/vectorize_loop.cc
+++ b/src/tir/transforms/vectorize_loop.cc
@@ -37,19 +37,36 @@
namespace tvm {
namespace tir {
-// TODO(ekalda): P5 in https://github.com/apache/tvm/issues/16455
-inline PrimExpr BroadcastTo(PrimExpr e, int lanes) {
- if (e.dtype().lanes() == lanes) return e;
+inline PrimExpr CreateNewLanes(bool is_scalable, int lanes_or_vscale_factor) {
+ if (is_scalable) {
+ return Mul(Call(DataType::Int(32), builtin::vscale(), {}),
lanes_or_vscale_factor);
+ } else {
+ return lanes_or_vscale_factor;
+ }
+}
+
+inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) {
+ // Check if e is already in the expected form
+ if (e.dtype().get_lanes_or_vscale_factor() == lanes &&
+ e.dtype().is_scalable_vector() == is_scalable)
+ return e;
+
if (const BroadcastNode* op = e.as<BroadcastNode>()) {
- ICHECK(!e.dtype().is_scalable_vector());
- int broadcast_lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
- if (lanes % broadcast_lanes == 0) {
- return Broadcast(op->value, lanes);
+ ICHECK(op->dtype.is_scalable_vector() == is_scalable)
+ << "Can't broadcast between scalable and fixed length vectors.";
+ int e_lanes = op->dtype.get_lanes_or_vscale_factor();
+
+ if (lanes % e_lanes == 0) {
+ return Broadcast(op->value, CreateNewLanes(is_scalable, lanes));
}
}
- ICHECK_EQ(e.dtype().lanes(), 1) << "Cannot broadcast lane=" <<
e.dtype().lanes() << " to "
- << lanes;
- return Broadcast(e, lanes);
+
+ ICHECK(e.dtype().is_scalar()) << "Cannot broadcast lanes="
+ << e.dtype().get_lanes_or_vscale_factor()
+ << " is_scalable=" <<
e.dtype().is_scalable_vector() << " to "
+ << lanes;
+
+ return Broadcast(e, CreateNewLanes(is_scalable, lanes));
}
// Rewrite vectorized allocation access
@@ -62,7 +79,7 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes) {
//
class VecAllocAccess : public StmtExprMutator {
public:
- VecAllocAccess(const VarNode* buf, Var var, int var_lanes)
+ VecAllocAccess(const VarNode* buf, Var var, PrimExpr var_lanes)
: buf_(buf), var_(var), var_lanes_(var_lanes) {}
PrimExpr VisitExpr_(const BufferLoadNode* op) final {
@@ -138,7 +155,7 @@ class VecAllocAccess : public StmtExprMutator {
// variable to be replaced
Var var_;
// the lanes.
- int var_lanes_;
+ PrimExpr var_lanes_;
// Analyzer for simplifications
arith::Analyzer analyzer_;
};
@@ -151,7 +168,7 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
using ExprFunctor::VisitExpr;
using StmtMutator::operator();
- Vectorizer(Var var, int var_lanes) : var_(var), var_lanes_(var_lanes) {
+ Vectorizer(Var var, PrimExpr var_lanes) : var_(var), var_lanes_(var_lanes) {
ramp_ = Ramp(IntImm(var->dtype, 0), IntImm(var->dtype, 1), var_lanes);
}
@@ -182,21 +199,30 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op);
} else {
- // TODO(ekalda): P5 in https://github.com/apache/tvm/issues/16455
- int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
- if (lanes != 1) {
+ bool is_vec_a = a.dtype().is_scalable_or_fixed_length_vector();
+ bool is_vec_b = b.dtype().is_scalable_or_fixed_length_vector();
+ if (is_vec_a && is_vec_b) {
+ // Let's not multiply scalable and fixed length vectors
+ ICHECK(a.dtype().is_scalable_vector() ==
b.dtype().is_scalable_vector())
+ << "Fixed length and scalable vectors can't be mixed in
multiplication.";
+ }
+ if (is_vec_a || is_vec_b) {
const RampNode* b_ramp = b.as<RampNode>();
const RampNode* a_ramp = a.as<RampNode>();
- if (a_ramp && b.dtype().lanes() == 1 && analyzer_.CanProve(b > 0)) {
- int lanes = static_cast<int>(Downcast<IntImm>(a_ramp->lanes)->value);
+ if (a_ramp && b.dtype().is_scalar() && analyzer_.CanProve(b > 0)) {
+ PrimExpr lanes = a_ramp->lanes;
return Ramp(a_ramp->base * b, a_ramp->stride * b, lanes);
}
- if (b_ramp && a.dtype().lanes() == 1 && analyzer_.CanProve(a > 0)) {
- int lanes = static_cast<int>(Downcast<IntImm>(b_ramp->lanes)->value);
+ if (b_ramp && a.dtype().is_scalar() && analyzer_.CanProve(a > 0)) {
+ PrimExpr lanes = b_ramp->lanes;
return Ramp(b_ramp->base * a, b_ramp->stride * a, lanes);
}
+ int a_lanes = a.dtype().get_lanes_or_vscale_factor();
+ int b_lanes = b.dtype().get_lanes_or_vscale_factor();
+ int max_lanes = std::max(a_lanes, b_lanes);
+ bool is_scalable = a.dtype().is_scalable_vector() ||
b.dtype().is_scalable_vector();
+ return Mul(BroadcastTo(a, max_lanes, is_scalable), BroadcastTo(b,
max_lanes, is_scalable));
}
- return Mul(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
}
return BinaryVec<Mul>(op);
}
@@ -227,18 +253,24 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
PrimExpr VisitExpr_(const RampNode* op) final {
PrimExpr base = this->VisitExpr(op->base);
PrimExpr stride = this->VisitExpr(op->stride);
- // TODO(ekalda): P5 in https://github.com/apache/tvm/issues/16455
- int op_lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
- if (base.dtype().lanes() > 1 && stride.dtype().lanes() == 1) {
+ ICHECK(!base.dtype().is_scalable_vector())
+ << "Creating scalable vectors from existing vectors is not supported.";
+ ICHECK(!stride.dtype().is_scalable_vector())
+ << "Ramp stride with scalable dtype is not supported";
+ if (base.dtype().is_fixed_length_vector() && stride.dtype().is_scalar()) {
+ ICHECK(op->lanes->IsInstance<IntImmNode>())
+ << "Vectorizing over existing scalable vectors is not supported.";
const RampNode* base_ramp = base.as<RampNode>();
+ int op_lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
int base_ramp_lanes =
static_cast<int>(Downcast<IntImm>(base_ramp->lanes)->value);
- if (analyzer_.CanProve(base_ramp->stride == stride *
make_const(stride.dtype(), op_lanes))) {
+ if (analyzer_.CanProve(base_ramp->stride ==
+ stride * make_const(stride.dtype(),
base_ramp_lanes))) {
return Ramp(base_ramp->base, stride, op_lanes * base_ramp_lanes);
}
}
int lanes = std::max(base.dtype().lanes(), stride.dtype().lanes());
- base = BroadcastTo(base, lanes);
- stride = BroadcastTo(stride, lanes);
+ base = BroadcastTo(base, lanes, false);
+ stride = BroadcastTo(stride, lanes, false);
Array<PrimExpr> elems;
for (int i = 0; i < lanes; ++i) {
elems.push_back(
@@ -249,7 +281,7 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
PrimExpr VisitExpr_(const BroadcastNode* op) final {
PrimExpr value = this->VisitExpr(op->value);
- if (value.dtype().lanes() != 1) {
+ if (value.dtype().is_scalable_or_fixed_length_vector()) {
need_scalarize_ = true;
return GetRef<PrimExpr>(op);
}
@@ -267,16 +299,27 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
if (cond.same_as(op->condition) && t.same_as(op->true_value) &&
f.same_as(op->false_value)) {
return GetRef<PrimExpr>(op);
} else {
- int lanes = std::max(std::max(cond.dtype().lanes(), t.dtype().lanes()),
f.dtype().lanes());
- return Select(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes));
+ int cond_lanes = cond.dtype().get_lanes_or_vscale_factor();
+ int t_lanes = t.dtype().get_lanes_or_vscale_factor();
+ int f_lanes = f.dtype().get_lanes_or_vscale_factor();
+ int lanes = std::max(std::max(cond_lanes, t_lanes), f_lanes);
+ bool is_scalable = cond.dtype().is_scalable_vector() ||
t.dtype().is_scalable_vector() ||
+ f.dtype().is_scalable_vector();
+ return Select(BroadcastTo(cond, lanes, is_scalable), BroadcastTo(t,
lanes, is_scalable),
+ BroadcastTo(f, lanes, is_scalable));
}
}
+
PrimExpr VisitExpr_(const CastNode* op) final {
PrimExpr value = this->VisitExpr(op->value);
if (value.same_as(op->value)) {
return GetRef<PrimExpr>(op);
} else {
- return Cast(op->dtype.with_lanes(value.dtype().lanes()), value);
+ if (value.dtype().is_scalable_vector()) {
+ return
Cast(op->dtype.with_scalable_vscale_factor(value.dtype().vscale_factor()),
value);
+ } else {
+ return Cast(op->dtype.with_lanes(value.dtype().lanes()), value);
+ }
}
}
@@ -312,10 +355,17 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) &&
f.same_as(op->args[2])) {
return GetRef<PrimExpr>(op);
} else {
- int lanes = std::max(t.dtype().lanes(), f.dtype().lanes());
- t = BroadcastTo(t, lanes);
- f = BroadcastTo(f, lanes);
- return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f});
+ int t_lanes = t.dtype().get_lanes_or_vscale_factor();
+ int f_lanes = f.dtype().get_lanes_or_vscale_factor();
+ int lanes = std::max(t_lanes, f_lanes);
+ bool is_scalable = t.dtype().is_scalable_vector() ||
f.dtype().is_scalable_vector();
+ t = BroadcastTo(t, lanes, is_scalable);
+ f = BroadcastTo(f, lanes, is_scalable);
+ if (is_scalable) {
+ return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op,
{cond, t, f});
+ } else {
+ return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f});
+ }
}
}
// Reinterpret expr
@@ -325,8 +375,12 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
if (value.same_as(op->args[0])) {
return GetRef<PrimExpr>(op);
} else {
- int lanes = value.dtype().lanes();
- return Call(op->dtype.with_lanes(lanes), op->op, {value});
+ int lanes = value.dtype().get_lanes_or_vscale_factor();
+ if (value.dtype().is_scalable_vector()) {
+ return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op,
{value});
+ } else {
+ return Call(op->dtype.with_lanes(lanes), op->op, {value});
+ }
}
}
// Call
@@ -351,7 +405,8 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
return MutateReinterpretExpr_(op);
}
auto optional_op = op->op.as<Op>();
- bool vectorizable = optional_op &&
op_vectorizable_.get(optional_op.value(), false);
+ bool vectorizable = optional_op &&
op_vectorizable_.get(optional_op.value(), false) &&
+ !op->dtype.is_scalable_vector();
if (!vectorizable) {
// Cannot vectorize this op
@@ -409,7 +464,8 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
ICHECK(deep_equal_(it->second, value))
<< "Let cannot bind the same var to two different values";
}
- if (value.dtype().lanes() != op->value.dtype().lanes()) {
+ if (value.dtype().get_lanes_or_vscale_factor() !=
+ op->value.dtype().get_lanes_or_vscale_factor()) {
Var new_var(op->var->name_hint, value.dtype());
let_binding_[op->var] = new_var;
return Let(new_var, value, this->VisitExpr(op->body));
@@ -433,20 +489,28 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
PrimExpr value = this->VisitExpr(op->value);
if (!indices.same_as(op->indices) || !value.same_as(op->value)) {
+ ICHECK(!op->buffer->dtype.is_scalable_vector())
+ << "Vectorizing over scalable buffer elements is not supported in
vectorizer.";
// How many lanes of indexing are present in the index and
- // buffer element type, excluding the last index. T
+ // buffer element type, excluding the last index.
int other_index_lanes = op->buffer->dtype.lanes();
for (size_t i = 0; i < indices.size() - 1; i++) {
other_index_lanes *= indices[i].dtype().lanes();
+ // Only allow the last index to be scalable
+ ICHECK(!indices[i].dtype().is_scalable_vector()) << "Only the last
index can be scalable.";
}
// The total number of lanes of indexing, including the last index.
- int index_lanes = other_index_lanes * indices[indices.size() -
1].dtype().lanes();
+ auto last_index_dtype = indices[indices.size() - 1].dtype();
+ int lanes_in_last_index = last_index_dtype.get_lanes_or_vscale_factor();
+ int index_lanes = other_index_lanes * lanes_in_last_index;
// The total number of lanes in this store operation. Either
// the index or the value will be broadcast out to this number
// of lanes, depending on which has more lanes.
- int total_lanes = std::max(index_lanes, value.dtype().lanes());
+ int value_dtype_lanes = value.dtype().get_lanes_or_vscale_factor();
+ bool is_last_index_scalable = last_index_dtype.is_scalable_vector();
+ int total_lanes = std::max(index_lanes, value_dtype_lanes);
ICHECK_EQ(total_lanes % other_index_lanes, 0)
<< "When storing to buffer " << op->buffer->name << ", cannot
produce " << total_lanes
@@ -455,11 +519,12 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
// Broadcast the last index such that the total number of index
// lanes matches the desired number.
- indices.Set(indices.size() - 1, BroadcastTo(indices[indices.size() - 1],
last_index_lanes));
+ indices.Set(indices.size() - 1, BroadcastTo(indices[indices.size() - 1],
last_index_lanes,
+ is_last_index_scalable));
auto writer = store.CopyOnWrite();
writer->indices = indices;
- writer->value = BroadcastTo(value, total_lanes);
+ writer->value = BroadcastTo(value, total_lanes, is_last_index_scalable);
}
return std::move(store);
@@ -512,7 +577,8 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
ICHECK(!let_binding_.count(op->var)) << "SSA violation, a single var is
binded twice";
let_binding_[op->var] = value;
- if (value.dtype().lanes() != op->value.dtype().lanes()) {
+ if (value.dtype().get_lanes_or_vscale_factor() !=
+ op->value.dtype().get_lanes_or_vscale_factor()) {
Var new_var(op->var->name_hint, value.dtype());
let_binding_[op->var] = new_var;
return LetStmt(new_var, value, this->VisitStmt(op->body));
@@ -566,8 +632,7 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
Stmt Scalarize(Stmt stmt) {
Var idx(var_->name_hint + ".s", var_->dtype);
stmt = Substitute(stmt, {{var_, idx}});
- return For(idx, IntImm(var_->dtype, 0), IntImm(var_->dtype, var_lanes_),
ForKind::kSerial,
- stmt);
+ return For(idx, IntImm(var_->dtype, 0), var_lanes_, ForKind::kSerial,
stmt);
}
// ProducerStore
Stmt VisitStmt_(const ProducerStoreNode* op) final {
@@ -582,7 +647,7 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
// variable to be replaced
Var var_;
// the lanes.
- int var_lanes_;
+ PrimExpr var_lanes_;
// ramp representing the var.
PrimExpr ramp_;
// flag to mark requirment of scalarization.
@@ -609,7 +674,7 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
for (size_t i = 0; i < arr.size(); ++i) {
if (new_arr[i].dtype().lanes() != lanes) {
- new_arr[i] = BroadcastTo(new_arr[i], lanes);
+ new_arr[i] = BroadcastTo(new_arr[i], lanes, false);
changed = true;
}
}
@@ -624,8 +689,11 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op);
} else {
- int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
- return TOp(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
+ int a_lanes = a.dtype().get_lanes_or_vscale_factor();
+ int b_lanes = b.dtype().get_lanes_or_vscale_factor();
+ int lanes = std::max(a_lanes, b_lanes);
+ bool is_scalable = a.dtype().is_scalable_vector() ||
b.dtype().is_scalable_vector();
+ return TOp(BroadcastTo(a, lanes, is_scalable), BroadcastTo(b, lanes,
is_scalable));
}
}
template <typename T, typename FCompute>
@@ -635,19 +703,22 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op);
} else {
- int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
+ int a_lanes = a.dtype().get_lanes_or_vscale_factor();
+ int b_lanes = b.dtype().get_lanes_or_vscale_factor();
+ int lanes = std::max(a_lanes, b_lanes);
if (lanes != 1) {
const RampNode* b_ramp = b.as<RampNode>();
const RampNode* a_ramp = a.as<RampNode>();
- if (a.dtype().lanes() == 1 && b_ramp) {
+ if (a.dtype().is_scalar() && b_ramp) {
return Ramp(fcompute(a, b_ramp->base),
fcompute(make_zero(b_ramp->stride.dtype()),
b_ramp->stride), b_ramp->lanes);
}
- if (b.dtype().lanes() == 1 && a_ramp) {
+ if (b.dtype().is_scalar() && a_ramp) {
return Ramp(fcompute(a_ramp->base, b), a_ramp->stride,
a_ramp->lanes);
}
}
- return fcompute(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
+ bool is_scalable = a.dtype().is_scalable_vector() ||
b.dtype().is_scalable_vector();
+ return fcompute(BroadcastTo(a, lanes, is_scalable), BroadcastTo(b,
lanes, is_scalable));
}
}
};
@@ -657,11 +728,7 @@ class LoopVectorizer : public StmtMutator {
Stmt VisitStmt_(const ForNode* op) final {
if (op->kind == ForKind::kVectorized) {
ICHECK(is_zero(op->min));
- auto* extent_as_int = op->extent.as<IntImmNode>();
- if (!extent_as_int || extent_as_int->value < 1) {
- LOG(FATAL) << "Failed to vectorize loop with extent " << op->extent;
- }
- return Vectorizer(op->loop_var,
static_cast<int>(extent_as_int->value))(op->body);
+ return Vectorizer(op->loop_var, op->extent)(op->body);
} else {
return StmtMutator::VisitStmt_(op);
}
diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py
b/tests/python/tir-transform/test_tir_transform_vectorize.py
index 7d0fac2423..dbca006b19 100644
--- a/tests/python/tir-transform/test_tir_transform_vectorize.py
+++ b/tests/python/tir-transform/test_tir_transform_vectorize.py
@@ -19,32 +19,29 @@ import tvm.testing
from tvm import te
from tvm.script import ir as I
from tvm.script import tir as T
+import pytest
-def test_vectorize_loop():
- dtype = "int64"
- n = te.var("n")
- ib = tvm.tir.ir_builder.create()
- A = ib.pointer("float32", name="A")
- with ib.for_range(0, n) as i:
- with ib.for_range(0, 4, kind="vectorize") as j:
- A[j] = tvm.tir.const(1, A.dtype)
- stmt = ib.get()
-
- assert isinstance(stmt.body, tvm.tir.For)
[email protected]("extent", (4, T.vscale() * 4))
+def test_vectorize_loop(extent):
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def main(A: T.Buffer((16,), "float32")):
+ for j in T.vectorized(0, extent):
+ A[j] = 1
- mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
- stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
+ @I.ir_module
+ class After:
+ @T.prim_func
+ def main(A: T.Buffer((16,), "float32")):
+ A[T.Ramp(0, 1, extent)] = T.Broadcast(1, extent)
- assert isinstance(stmt, tvm.tir.For)
- assert not isinstance(stmt.body, tvm.tir.For)
- assert len(stmt.body.indices) == 1
- assert isinstance(stmt.body.indices[0], tvm.tir.Ramp)
- assert isinstance(stmt.body.value, tvm.tir.Broadcast)
+ mod = tvm.tir.transform.VectorizeLoop()(Before)
+ tvm.ir.assert_structural_equal(mod, After)
def test_vectorize_vector():
- dtype = "int64"
n = te.var("n")
ib = tvm.tir.ir_builder.create()
A = ib.pointer("float32x4", name="A")
@@ -64,28 +61,90 @@ def test_vectorize_vector():
assert isinstance(stmt.body.value, tvm.tir.Broadcast)
-def test_vectorize_with_if():
- n = te.var("n")
- x = te.var("x")
- ib = tvm.tir.ir_builder.create()
- A = ib.pointer("float32", name="A")
- with ib.for_range(0, 4, kind="vectorize") as i:
- with ib.if_scope(x < n):
- A[i] = A[i] + 1
- with ib.else_scope():
- with ib.if_scope(i < n):
- A[i] = 2.0
- stmt = ib.get()
+def test_vectorize_vector_scalable_error():
+ @I.ir_module
+ class Module:
+ @T.prim_func
+ def main(A: T.Buffer((25,), "float32")):
+ for j in T.vectorized(T.vscale() * 4):
+ A[j * 4 : j * 4 + 4] = T.Broadcast(T.float32(1), 4)
+
+ error_msg = f"Creating scalable vectors from existing vectors is not
supported."
+ with pytest.raises(tvm.error.InternalError, match=error_msg):
+ tvm.tir.transform.VectorizeLoop()(Module)
+
+
+def test_vectorize_vector_scalable_error2():
+ @I.ir_module
+ class Module:
+ @T.prim_func
+ def main(A: T.Buffer((25,), "float32xvscalex4")):
+ for j in T.vectorized(4):
+ A[j] = T.Broadcast(T.float32(1), T.vscale() * 4)
+
+ error_msg = f"Vectorizing over scalable buffer elements is not supported
in vectorizer."
+ with pytest.raises(tvm.error.InternalError, match=error_msg):
+ tvm.tir.transform.VectorizeLoop()(Module)
- mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n, x], stmt))
- stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
- assert isinstance(stmt, tvm.tir.IfThenElse)
- assert len(stmt.then_case.indices) == 1
- assert isinstance(stmt.then_case.indices[0], tvm.tir.Ramp)
- assert isinstance(stmt.then_case.value, tvm.tir.Add)
- assert stmt.then_case.value.dtype == "float32x4"
- assert isinstance(stmt.else_case, tvm.tir.For)
+def test_vectorize_vector_scalable_error3():
+ @I.ir_module
+ class Module:
+ @T.prim_func
+ def main(A: T.Buffer((25,), "float32")):
+ for j in T.vectorized(4):
+ A[j * T.vscale() * 4 : j * T.vscale() * 4 + T.vscale() * 4] =
T.Broadcast(
+ T.float32(1), T.vscale() * 4
+ )
+
+ error_msg = f"Vectorizing over existing scalable vectors is not supported."
+ with pytest.raises(tvm.error.InternalError, match=error_msg):
+ tvm.tir.transform.VectorizeLoop()(Module)
+
+
+def test_vectorize_vector_scalable_error4():
+ @I.ir_module
+ class Module:
+ @T.prim_func(private=True)
+ def main(A: T.Buffer((25,), "float32")):
+ for j in T.vectorized(T.vscale() * 4):
+ A[j * T.vscale() * 4 : j * T.vscale() * 4 + T.vscale() * 4] =
T.Broadcast(
+ T.float32(1), T.vscale() * 4
+ )
+
+ error_msg = f"Creating scalable vectors from existing vectors is not
supported."
+ with pytest.raises(tvm.error.InternalError, match=error_msg):
+ tvm.tir.transform.VectorizeLoop()(Module)
+
+
[email protected]("extent", (4, T.vscale() * 4))
+def test_vectorize_with_if(extent):
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32):
+ for i in T.vectorized(extent):
+ if x < n:
+ A[i] = A[i] + T.float32(1)
+ else:
+ if i < n:
+ A[i] = T.float32(2)
+
+ @I.ir_module
+ class After:
+ @T.prim_func
+ def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32):
+ if x < n:
+ A[T.Ramp(0, 1, extent)] = A[T.Ramp(0, 1, extent)] +
T.Broadcast(
+ T.float32(1), extent
+ )
+ else:
+ for i_s in range(extent):
+ if i_s < n:
+ A[i_s] = T.float32(2)
+
+ mod = tvm.tir.transform.VectorizeLoop()(Before)
+ tvm.ir.assert_structural_equal(mod, After)
def test_vectorize_with_if_cond_int64():
@@ -98,25 +157,33 @@ def test_vectorize_with_if_cond_int64():
f = tvm.build(s, [A, B], "llvm")
-def test_vectorize_let():
- v = tvm.tir.Var("v", "float32")
- ib = tvm.tir.ir_builder.create()
- A = ib.pointer("float32", name="A")
- with ib.for_range(0, 4, kind="vectorize") as i:
- ib.emit(lambda body: tvm.tir.LetStmt(v, A[i] + 1, body))
- A[i] = v + 2
[email protected]("extent", (4, T.vscale() * 4))
+def test_vectorize_let(extent):
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def main(A: T.Buffer((25,), "float32")):
+ for i in T.vectorized(extent):
+ v = A[i] + T.float32(1)
+ A[i] = v + T.float32(2)
- mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A], ib.get()))
- stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
- assert isinstance(stmt, tvm.tir.LetStmt)
- assert stmt.value.dtype == "float32x4"
+ @I.ir_module
+ class After:
+ @T.prim_func
+ def main(A: T.Buffer((25,), "float32")):
+ v = A[T.Ramp(0, 1, extent)] + T.Broadcast(T.float32(1), extent)
+ A[T.Ramp(0, 1, extent)] = v + T.Broadcast(T.float32(2), extent)
+
+ mod = tvm.tir.transform.VectorizeLoop()(Before)
+ tvm.ir.assert_structural_equal(mod, After)
-def test_vectorize_with_le_cond():
[email protected]("extent", (4, tvm.tir.vscale() * 4))
+def test_vectorize_with_le_cond(extent):
n = te.var("n")
ib = tvm.tir.ir_builder.create()
A = ib.pointer("float32", name="A")
- with ib.for_range(0, 4, kind="vectorize") as i:
+ with ib.for_range(0, extent, kind="vectorize") as i:
with ib.if_scope(i <= n):
A[i] = A[i] + 1
stmt = ib.get()
@@ -124,14 +191,16 @@ def test_vectorize_with_le_cond():
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
+ # Check that the loop wasn't vectorised
assert isinstance(stmt, tvm.tir.For)
-def test_vectorize_with_ge_cond():
[email protected]("extent", (4, tvm.tir.vscale() * 4))
+def test_vectorize_with_ge_cond(extent):
n = te.var("n")
ib = tvm.tir.ir_builder.create()
A = ib.pointer("float32", name="A")
- with ib.for_range(0, 4, kind="vectorize") as i:
+ with ib.for_range(0, extent, kind="vectorize") as i:
with ib.if_scope(i >= n):
A[i] = A[i] + 1
stmt = ib.get()
@@ -139,39 +208,51 @@ def test_vectorize_with_ge_cond():
mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
+ # Check that the loop wasn't vectorised
assert isinstance(stmt, tvm.tir.For)
-def test_vectorize_if_then_else():
- n = te.var("n")
- x = te.var("x")
- ib = tvm.tir.ir_builder.create()
- A = ib.pointer("float32", name="A")
- with ib.for_range(0, 4, kind="vectorize") as i:
- A[i] = tvm.tir.call_intrin("float32", "tir.if_then_else", i > 0, A[i]
+ 1, A[i])
- stmt = ib.get()
[email protected]("extent", (4, T.vscale() * 4))
+def test_vectorize_if_then_else_scalarize(extent):
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def main(A: T.Buffer((25,), "float32")):
+ for i in T.vectorized(extent):
+ A[i] = T.if_then_else(i > 0, A[i] + T.float32(1), A[i])
- mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n, x], stmt))
- stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
+ @I.ir_module
+ class After:
+ @T.prim_func
+ def main(A: T.Buffer((25,), "float32")):
+ for i_s in range(extent):
+ A[i_s] = T.if_then_else(i_s > 0, A[i_s] + T.float32(1), A[i_s])
- assert isinstance(stmt, tvm.tir.For)
+ mod = tvm.tir.transform.VectorizeLoop()(Before)
+ tvm.ir.assert_structural_equal(mod, After)
- ib = tvm.tir.ir_builder.create()
- A = ib.pointer("float32", name="A")
- with ib.for_range(0, n) as k:
- with ib.for_range(0, 4, kind="vectorize") as i:
- A[k * 4 + i] = tvm.tir.call_intrin(
- "float32", "tir.if_then_else", k > 0, A[k * 4 + i], 0
- )
- stmt = ib.get()
- assert isinstance(stmt.body, tvm.tir.For)
[email protected]("extent", (4, T.vscale() * 4))
+def test_vectorize_if_then_else_vector(extent):
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def main(A: T.Buffer((25,), "float32"), n: T.int32):
+ for i in range(n):
+ for j in T.vectorized(extent):
+ A[i * extent + j] = T.if_then_else(i > 0, A[i * extent +
j], 0)
- mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
- stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
+ @I.ir_module
+ class After:
+ @T.prim_func
+ def main(A: T.Buffer((25,), "float32"), n: T.int32):
+ for i in range(n):
+ A[T.Ramp(i * extent, 1, extent)] = T.if_then_else(
+ i > 0, A[T.Ramp(i * extent, 1, extent)], T.Broadcast(0,
extent)
+ )
- assert not isinstance(stmt.body, tvm.tir.For)
- assert isinstance(stmt.body.value.args[2], tvm.tir.Broadcast)
+ mod = tvm.tir.transform.VectorizeLoop()(Before)
+ tvm.ir.assert_structural_equal(mod, After)
def test_vectorize_while_fail():
@@ -229,23 +310,141 @@ def test_vectorize_dtype_mismatch():
tvm.lower(s, [A], "llvm", simple_mode=True)
-def test_vectorize_with_reinterpret():
[email protected](
+ "extent, vec_str", [(16, "float32x16"), (T.vscale() * 8,
"float32xvscalex8")]
+)
+def test_vectorize_with_reinterpret(extent, vec_str):
@I.ir_module
class Before:
@T.prim_func
def main(A: T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")):
- for i in T.vectorized(0, 16):
+ for i in T.vectorized(0, extent):
B[i] = T.reinterpret("float32", A[i])
@I.ir_module
class After:
@T.prim_func
def main(A: T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")):
- B[0:16] = T.reinterpret("float32x16", A[0:16])
+ B[T.Ramp(0, 1, extent)] = T.reinterpret(vec_str, A[T.Ramp(0, 1,
extent)])
mod = tvm.tir.transform.VectorizeLoop()(Before)
tvm.ir.assert_structural_equal(mod, After)
[email protected]("extent", (4, T.vscale() * 4))
[email protected](
+ "op",
+ (
+ T.Mul,
+ T.Add,
+ T.Sub,
+ T.Div,
+ T.Mod,
+ T.FloorDiv,
+ T.FloorMod,
+ T.Min,
+ T.Max,
+ T.EQ,
+ T.LT,
+ T.LE,
+ T.GE,
+ T.GT,
+ T.NE,
+ ),
+)
+def test_vectorize_binary(op, extent):
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
+ for j in T.vectorized(extent):
+ A[j] = op(T.float32(3), B[j])
+
+ @I.ir_module
+ class After:
+ @T.prim_func
+ def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
+ A[T.Ramp(0, 1, extent)] = op(T.Broadcast(T.float32(3), extent),
B[T.Ramp(0, 1, extent)])
+
+ mod = tvm.tir.transform.VectorizeLoop()(Before)
+ tvm.ir.assert_structural_equal(mod, After)
+
+
[email protected]("extent", (4, T.vscale() * 4))
[email protected]("op", (T.And, T.Or))
+def test_vectorize_logical(op, extent):
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def main(A: T.Buffer((25,), "bool"), B: T.Buffer((25,), "bool")):
+ for j in T.vectorized(extent):
+ A[j] = op(T.bool(1), B[j])
+
+ @I.ir_module
+ class After:
+ @T.prim_func
+ def main(A: T.Buffer((25,), "bool"), B: T.Buffer((25,), "bool")):
+ A[T.Ramp(0, 1, extent)] = op(T.Broadcast(T.bool(1), extent),
B[T.Ramp(0, 1, extent)])
+
+ mod = tvm.tir.transform.VectorizeLoop()(Before)
+ tvm.ir.assert_structural_equal(mod, After)
+
+
[email protected]("extent", (4, T.vscale() * 4))
+def test_vectorize_select(extent):
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
+ for j in T.vectorized(extent):
+ A[j] = T.Select(T.bool(True), A[j], B[j])
+
+ @I.ir_module
+ class After:
+ @T.prim_func
+ def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
+ A[T.Ramp(0, 1, extent)] = T.Select(
+ T.Broadcast(T.bool(True), extent),
+ A[T.Ramp(0, 1, extent)],
+ B[T.Ramp(0, 1, extent)],
+ )
+
+ mod = tvm.tir.transform.VectorizeLoop()(Before)
+ tvm.ir.assert_structural_equal(mod, After)
+
+
[email protected]("extent, vec_str", [(4, "int32x4"), (T.vscale() * 4,
"int32xvscalex4")])
+def test_vectorize_cast(extent, vec_str):
+ @I.ir_module
+ class Before:
+ @T.prim_func
+ def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")):
+ for j in T.vectorized(extent):
+ A[j] = T.Cast("int32", B[j])
+
+ @I.ir_module
+ class After:
+ @T.prim_func
+ def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")):
+ A[T.Ramp(0, 1, extent)] = T.Cast(vec_str, B[T.Ramp(0, 1, extent)])
+
+ mod = tvm.tir.transform.VectorizeLoop()(Before)
+ tvm.ir.assert_structural_equal(mod, After)
+
+
+def test_illegal_extent():
+ @I.ir_module(check_well_formed=False)
+ class Mod:
+ @T.prim_func
+ def main(A: T.Buffer((25,), "int32")):
+ n = T.Var("n", dtype="int32")
+ for j in T.vectorized(n):
+ A[j] = 3
+
+ error_msg = f"Invalid expression for scalable lanes n"
+ with pytest.raises(tvm.error.InternalError, match=error_msg):
+ tvm.tir.transform.VectorizeLoop()(Mod)
+
+
if __name__ == "__main__":
tvm.testing.main()