This is an automated email from the ASF dual-hosted git repository.

lukhut pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4d4f0508a2 [SVE] Support scalable vectors in LoopVectorizer (#16782)
4d4f0508a2 is described below

commit 4d4f0508a2fd903d95ae46472d830cde84e9ce9e
Author: Elen Kalda <[email protected]>
AuthorDate: Tue Apr 9 14:36:32 2024 +0100

    [SVE] Support scalable vectors in LoopVectorizer (#16782)
    
    This patch adds support for turning loops marked for vectorizing into
    scalable vectors if the extent of the loop is a vscale-dependent
    expression of the expected form (e.g. `T.vscale() * 4`).
    
    The tests for both scalable and fixed-length vectors in
    test_tir_transform_vectorize.py have been extended, and most of them
    have been converted to TVMScript-based tests that compare against the
    expected output.
    
    Co-authored-by: Luke Hutton <[email protected]>
    Co-authored-by: Neil Hickey <[email protected]>
---
 include/tvm/runtime/data_type.h                    |   4 +-
 include/tvm/tir/op.h                               |  11 +-
 src/tir/ir/expr.cc                                 |  13 +-
 src/tir/transforms/vectorize_loop.cc               | 187 +++++++----
 .../tir-transform/test_tir_transform_vectorize.py  | 361 ++++++++++++++++-----
 5 files changed, 428 insertions(+), 148 deletions(-)

diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index 8f3ae9b424..f7284ec690 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -111,7 +111,9 @@ class DataType {
     return -lanes_as_int;
   }
   /*! \return get vscale factor or lanes depending on scalability of the 
vector. */
-  int get_lanes_or_vscale_factor() { return is_scalable_vector() ? 
vscale_factor() : lanes(); }
+  int get_lanes_or_vscale_factor() const {
+    return is_scalable_vector() ? vscale_factor() : lanes();
+  }
   /*! \return whether type is a scalar type. */
   bool is_scalar() const { return !is_scalable_vector() && lanes() == 1; }
   /*! \return whether type is a scalar type. */
diff --git a/include/tvm/tir/op.h b/include/tvm/tir/op.h
index ce4a4d6a28..d06bb779d0 100644
--- a/include/tvm/tir/op.h
+++ b/include/tvm/tir/op.h
@@ -31,6 +31,7 @@
 #include <tvm/ir/expr.h>
 #include <tvm/ir/op.h>
 #include <tvm/ir/type.h>
+#include <tvm/tir/builtin.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/stmt.h>
 
@@ -959,10 +960,16 @@ inline PrimExpr MakeConstScalar(DataType t, bool value, 
Span span) {
 
 template <typename ValueType, typename>
 inline PrimExpr make_const(DataType t, ValueType value, Span span) {
-  if (t.lanes() == 1) {
+  if (t.is_scalar()) {
     return MakeConstScalar(t, value, span);
   } else {
-    return tir::Broadcast(MakeConstScalar(t.element_of(), value, span), 
t.lanes(), span);
+    if (t.is_fixed_length_vector()) {
+      return tir::Broadcast(MakeConstScalar(t.element_of(), value, span), 
t.lanes(), span);
+    } else {
+      PrimExpr lanes =
+          tir::Mul(tir::Call(DataType::Int(32), tir::builtin::vscale(), {}), 
t.vscale_factor());
+      return tir::Broadcast(MakeConstScalar(t.element_of(), value, span), 
lanes, span);
+    }
   }
 }
 
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 90dad72039..2cd2a698de 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -196,7 +196,8 @@ TVM_REGISTER_NODE_TYPE(StringImmNode);
 // Cast
 Cast::Cast(DataType t, PrimExpr value, Span span) {
   ICHECK(value.defined());
-  ICHECK_EQ(t.lanes(), value.dtype().lanes());
+  ICHECK_EQ(t.get_lanes_or_vscale_factor(), 
value.dtype().get_lanes_or_vscale_factor());
+  ICHECK(t.is_scalable_vector() == value.dtype().is_scalable_vector());
   ObjectPtr<CastNode> node = make_object<CastNode>();
   node->dtype = t;
   node->value = std::move(value);
@@ -354,7 +355,8 @@ And::And(PrimExpr a, PrimExpr b, Span span) {
   ICHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types";
 
   ObjectPtr<AndNode> node = make_object<AndNode>();
-  node->dtype = DataType::Bool(a.dtype().lanes());
+  node->dtype =
+      DataType::Bool(a.dtype().get_lanes_or_vscale_factor(), 
a.dtype().is_scalable_vector());
   node->a = std::move(a);
   node->b = std::move(b);
   node->span = std::move(span);
@@ -376,7 +378,8 @@ Or::Or(PrimExpr a, PrimExpr b, Span span) {
   ICHECK(a.dtype() == b.dtype()) << "TypeError: mismatched types";
 
   ObjectPtr<OrNode> node = make_object<OrNode>();
-  node->dtype = DataType::Bool(a.dtype().lanes());
+  node->dtype =
+      DataType::Bool(a.dtype().get_lanes_or_vscale_factor(), 
a.dtype().is_scalable_vector());
   node->a = std::move(a);
   node->b = std::move(b);
   node->span = std::move(span);
@@ -412,7 +415,9 @@ Select::Select(PrimExpr condition, PrimExpr true_value, 
PrimExpr false_value, Sp
   ICHECK(true_value.defined()) << "ValueError: true_value is undefined";
   ICHECK(false_value.defined()) << "ValueError: true_value is undefined";
   ICHECK(condition.dtype().is_bool());
-  ICHECK(condition.dtype().lanes() == true_value.dtype().lanes() || 
condition.dtype().lanes() == 1);
+  ICHECK(condition.dtype().get_lanes_or_vscale_factor() ==
+             true_value.dtype().get_lanes_or_vscale_factor() ||
+         condition.dtype().is_scalar());
   ICHECK(false_value.dtype() == true_value.dtype())
       << "TypeError: mismatched types. "
       << "False type: " << false_value.dtype() << "; True type: " << 
true_value.dtype();
diff --git a/src/tir/transforms/vectorize_loop.cc 
b/src/tir/transforms/vectorize_loop.cc
index 57536422cf..a9cc497580 100644
--- a/src/tir/transforms/vectorize_loop.cc
+++ b/src/tir/transforms/vectorize_loop.cc
@@ -37,19 +37,36 @@
 namespace tvm {
 namespace tir {
 
-// TODO(ekalda): P5 in https://github.com/apache/tvm/issues/16455
-inline PrimExpr BroadcastTo(PrimExpr e, int lanes) {
-  if (e.dtype().lanes() == lanes) return e;
+inline PrimExpr CreateNewLanes(bool is_scalable, int lanes_or_vscale_factor) {
+  if (is_scalable) {
+    return Mul(Call(DataType::Int(32), builtin::vscale(), {}), 
lanes_or_vscale_factor);
+  } else {
+    return lanes_or_vscale_factor;
+  }
+}
+
+inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) {
+  // Check if e is already in the expected form
+  if (e.dtype().get_lanes_or_vscale_factor() == lanes &&
+      e.dtype().is_scalable_vector() == is_scalable)
+    return e;
+
   if (const BroadcastNode* op = e.as<BroadcastNode>()) {
-    ICHECK(!e.dtype().is_scalable_vector());
-    int broadcast_lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
-    if (lanes % broadcast_lanes == 0) {
-      return Broadcast(op->value, lanes);
+    ICHECK(op->dtype.is_scalable_vector() == is_scalable)
+        << "Can't broadcast between scalable and fixed length vectors.";
+    int e_lanes = op->dtype.get_lanes_or_vscale_factor();
+
+    if (lanes % e_lanes == 0) {
+      return Broadcast(op->value, CreateNewLanes(is_scalable, lanes));
     }
   }
-  ICHECK_EQ(e.dtype().lanes(), 1) << "Cannot broadcast lane=" << 
e.dtype().lanes() << " to "
-                                  << lanes;
-  return Broadcast(e, lanes);
+
+  ICHECK(e.dtype().is_scalar()) << "Cannot broadcast lanes="
+                                << e.dtype().get_lanes_or_vscale_factor()
+                                << " is_scalable=" << 
e.dtype().is_scalable_vector() << " to "
+                                << lanes;
+
+  return Broadcast(e, CreateNewLanes(is_scalable, lanes));
 }
 
 // Rewrite vectorized allocation access
@@ -62,7 +79,7 @@ inline PrimExpr BroadcastTo(PrimExpr e, int lanes) {
 //
 class VecAllocAccess : public StmtExprMutator {
  public:
-  VecAllocAccess(const VarNode* buf, Var var, int var_lanes)
+  VecAllocAccess(const VarNode* buf, Var var, PrimExpr var_lanes)
       : buf_(buf), var_(var), var_lanes_(var_lanes) {}
 
   PrimExpr VisitExpr_(const BufferLoadNode* op) final {
@@ -138,7 +155,7 @@ class VecAllocAccess : public StmtExprMutator {
   // variable to be replaced
   Var var_;
   // the lanes.
-  int var_lanes_;
+  PrimExpr var_lanes_;
   // Analyzer for simplifications
   arith::Analyzer analyzer_;
 };
@@ -151,7 +168,7 @@ class Vectorizer : public StmtMutator, public 
ExprFunctor<PrimExpr(const PrimExp
   using ExprFunctor::VisitExpr;
   using StmtMutator::operator();
 
-  Vectorizer(Var var, int var_lanes) : var_(var), var_lanes_(var_lanes) {
+  Vectorizer(Var var, PrimExpr var_lanes) : var_(var), var_lanes_(var_lanes) {
     ramp_ = Ramp(IntImm(var->dtype, 0), IntImm(var->dtype, 1), var_lanes);
   }
 
@@ -182,21 +199,30 @@ class Vectorizer : public StmtMutator, public 
ExprFunctor<PrimExpr(const PrimExp
     if (a.same_as(op->a) && b.same_as(op->b)) {
       return GetRef<PrimExpr>(op);
     } else {
-      // TODO(ekalda): P5 in https://github.com/apache/tvm/issues/16455
-      int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
-      if (lanes != 1) {
+      bool is_vec_a = a.dtype().is_scalable_or_fixed_length_vector();
+      bool is_vec_b = b.dtype().is_scalable_or_fixed_length_vector();
+      if (is_vec_a && is_vec_b) {
+        // Let's not multiply scalable and fixed length vectors
+        ICHECK(a.dtype().is_scalable_vector() == 
b.dtype().is_scalable_vector())
+            << "Fixed length and scalable vectors can't be mixed in 
multiplication.";
+      }
+      if (is_vec_a || is_vec_b) {
         const RampNode* b_ramp = b.as<RampNode>();
         const RampNode* a_ramp = a.as<RampNode>();
-        if (a_ramp && b.dtype().lanes() == 1 && analyzer_.CanProve(b > 0)) {
-          int lanes = static_cast<int>(Downcast<IntImm>(a_ramp->lanes)->value);
+        if (a_ramp && b.dtype().is_scalar() && analyzer_.CanProve(b > 0)) {
+          PrimExpr lanes = a_ramp->lanes;
           return Ramp(a_ramp->base * b, a_ramp->stride * b, lanes);
         }
-        if (b_ramp && a.dtype().lanes() == 1 && analyzer_.CanProve(a > 0)) {
-          int lanes = static_cast<int>(Downcast<IntImm>(b_ramp->lanes)->value);
+        if (b_ramp && a.dtype().is_scalar() && analyzer_.CanProve(a > 0)) {
+          PrimExpr lanes = b_ramp->lanes;
           return Ramp(b_ramp->base * a, b_ramp->stride * a, lanes);
         }
+        int a_lanes = a.dtype().get_lanes_or_vscale_factor();
+        int b_lanes = b.dtype().get_lanes_or_vscale_factor();
+        int max_lanes = std::max(a_lanes, b_lanes);
+        bool is_scalable = a.dtype().is_scalable_vector() || 
b.dtype().is_scalable_vector();
+        return Mul(BroadcastTo(a, max_lanes, is_scalable), BroadcastTo(b, 
max_lanes, is_scalable));
       }
-      return Mul(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
     }
     return BinaryVec<Mul>(op);
   }
@@ -227,18 +253,24 @@ class Vectorizer : public StmtMutator, public 
ExprFunctor<PrimExpr(const PrimExp
   PrimExpr VisitExpr_(const RampNode* op) final {
     PrimExpr base = this->VisitExpr(op->base);
     PrimExpr stride = this->VisitExpr(op->stride);
-    // TODO(ekalda): P5 in https://github.com/apache/tvm/issues/16455
-    int op_lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
-    if (base.dtype().lanes() > 1 && stride.dtype().lanes() == 1) {
+    ICHECK(!base.dtype().is_scalable_vector())
+        << "Creating scalable vectors from existing vectors is not supported.";
+    ICHECK(!stride.dtype().is_scalable_vector())
+        << "Ramp stride with scalable dtype is not supported";
+    if (base.dtype().is_fixed_length_vector() && stride.dtype().is_scalar()) {
+      ICHECK(op->lanes->IsInstance<IntImmNode>())
+          << "Vectorizing over existing scalable vectors is not supported.";
       const RampNode* base_ramp = base.as<RampNode>();
+      int op_lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
       int base_ramp_lanes = 
static_cast<int>(Downcast<IntImm>(base_ramp->lanes)->value);
-      if (analyzer_.CanProve(base_ramp->stride == stride * 
make_const(stride.dtype(), op_lanes))) {
+      if (analyzer_.CanProve(base_ramp->stride ==
+                             stride * make_const(stride.dtype(), 
base_ramp_lanes))) {
         return Ramp(base_ramp->base, stride, op_lanes * base_ramp_lanes);
       }
     }
     int lanes = std::max(base.dtype().lanes(), stride.dtype().lanes());
-    base = BroadcastTo(base, lanes);
-    stride = BroadcastTo(stride, lanes);
+    base = BroadcastTo(base, lanes, false);
+    stride = BroadcastTo(stride, lanes, false);
     Array<PrimExpr> elems;
     for (int i = 0; i < lanes; ++i) {
       elems.push_back(
@@ -249,7 +281,7 @@ class Vectorizer : public StmtMutator, public 
ExprFunctor<PrimExpr(const PrimExp
 
   PrimExpr VisitExpr_(const BroadcastNode* op) final {
     PrimExpr value = this->VisitExpr(op->value);
-    if (value.dtype().lanes() != 1) {
+    if (value.dtype().is_scalable_or_fixed_length_vector()) {
       need_scalarize_ = true;
       return GetRef<PrimExpr>(op);
     }
@@ -267,16 +299,27 @@ class Vectorizer : public StmtMutator, public 
ExprFunctor<PrimExpr(const PrimExp
     if (cond.same_as(op->condition) && t.same_as(op->true_value) && 
f.same_as(op->false_value)) {
       return GetRef<PrimExpr>(op);
     } else {
-      int lanes = std::max(std::max(cond.dtype().lanes(), t.dtype().lanes()), 
f.dtype().lanes());
-      return Select(cond, BroadcastTo(t, lanes), BroadcastTo(f, lanes));
+      int cond_lanes = cond.dtype().get_lanes_or_vscale_factor();
+      int t_lanes = t.dtype().get_lanes_or_vscale_factor();
+      int f_lanes = f.dtype().get_lanes_or_vscale_factor();
+      int lanes = std::max(std::max(cond_lanes, t_lanes), f_lanes);
+      bool is_scalable = cond.dtype().is_scalable_vector() || 
t.dtype().is_scalable_vector() ||
+                         f.dtype().is_scalable_vector();
+      return Select(BroadcastTo(cond, lanes, is_scalable), BroadcastTo(t, 
lanes, is_scalable),
+                    BroadcastTo(f, lanes, is_scalable));
     }
   }
+
   PrimExpr VisitExpr_(const CastNode* op) final {
     PrimExpr value = this->VisitExpr(op->value);
     if (value.same_as(op->value)) {
       return GetRef<PrimExpr>(op);
     } else {
-      return Cast(op->dtype.with_lanes(value.dtype().lanes()), value);
+      if (value.dtype().is_scalable_vector()) {
+        return 
Cast(op->dtype.with_scalable_vscale_factor(value.dtype().vscale_factor()), 
value);
+      } else {
+        return Cast(op->dtype.with_lanes(value.dtype().lanes()), value);
+      }
     }
   }
 
@@ -312,10 +355,17 @@ class Vectorizer : public StmtMutator, public 
ExprFunctor<PrimExpr(const PrimExp
     if (cond.same_as(op->args[0]) && t.same_as(op->args[1]) && 
f.same_as(op->args[2])) {
       return GetRef<PrimExpr>(op);
     } else {
-      int lanes = std::max(t.dtype().lanes(), f.dtype().lanes());
-      t = BroadcastTo(t, lanes);
-      f = BroadcastTo(f, lanes);
-      return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f});
+      int t_lanes = t.dtype().get_lanes_or_vscale_factor();
+      int f_lanes = f.dtype().get_lanes_or_vscale_factor();
+      int lanes = std::max(t_lanes, f_lanes);
+      bool is_scalable = t.dtype().is_scalable_vector() || 
f.dtype().is_scalable_vector();
+      t = BroadcastTo(t, lanes, is_scalable);
+      f = BroadcastTo(f, lanes, is_scalable);
+      if (is_scalable) {
+        return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op, 
{cond, t, f});
+      } else {
+        return Call(op->dtype.with_lanes(lanes), op->op, {cond, t, f});
+      }
     }
   }
   // Reinterpret expr
@@ -325,8 +375,12 @@ class Vectorizer : public StmtMutator, public 
ExprFunctor<PrimExpr(const PrimExp
     if (value.same_as(op->args[0])) {
       return GetRef<PrimExpr>(op);
     } else {
-      int lanes = value.dtype().lanes();
-      return Call(op->dtype.with_lanes(lanes), op->op, {value});
+      int lanes = value.dtype().get_lanes_or_vscale_factor();
+      if (value.dtype().is_scalable_vector()) {
+        return Call(op->dtype.with_scalable_vscale_factor(lanes), op->op, 
{value});
+      } else {
+        return Call(op->dtype.with_lanes(lanes), op->op, {value});
+      }
     }
   }
   // Call
@@ -351,7 +405,8 @@ class Vectorizer : public StmtMutator, public 
ExprFunctor<PrimExpr(const PrimExp
       return MutateReinterpretExpr_(op);
     }
     auto optional_op = op->op.as<Op>();
-    bool vectorizable = optional_op && 
op_vectorizable_.get(optional_op.value(), false);
+    bool vectorizable = optional_op && 
op_vectorizable_.get(optional_op.value(), false) &&
+                        !op->dtype.is_scalable_vector();
 
     if (!vectorizable) {
       // Cannot vectorize this op
@@ -409,7 +464,8 @@ class Vectorizer : public StmtMutator, public 
ExprFunctor<PrimExpr(const PrimExp
       ICHECK(deep_equal_(it->second, value))
           << "Let cannot bind the same var to two different values";
     }
-    if (value.dtype().lanes() != op->value.dtype().lanes()) {
+    if (value.dtype().get_lanes_or_vscale_factor() !=
+        op->value.dtype().get_lanes_or_vscale_factor()) {
       Var new_var(op->var->name_hint, value.dtype());
       let_binding_[op->var] = new_var;
       return Let(new_var, value, this->VisitExpr(op->body));
@@ -433,20 +489,28 @@ class Vectorizer : public StmtMutator, public 
ExprFunctor<PrimExpr(const PrimExp
     PrimExpr value = this->VisitExpr(op->value);
 
     if (!indices.same_as(op->indices) || !value.same_as(op->value)) {
+      ICHECK(!op->buffer->dtype.is_scalable_vector())
+          << "Vectorizing over scalable buffer elements is not supported in 
vectorizer.";
       // How many lanes of indexing are present in the index and
-      // buffer element type, excluding the last index.  T
+      // buffer element type, excluding the last index.
       int other_index_lanes = op->buffer->dtype.lanes();
       for (size_t i = 0; i < indices.size() - 1; i++) {
         other_index_lanes *= indices[i].dtype().lanes();
+        // Only allow the last index to be scalable
+        ICHECK(!indices[i].dtype().is_scalable_vector()) << "Only the last 
index can be scalable.";
       }
 
       // The total number of lanes of indexing, including the last index.
-      int index_lanes = other_index_lanes * indices[indices.size() - 
1].dtype().lanes();
+      auto last_index_dtype = indices[indices.size() - 1].dtype();
+      int lanes_in_last_index = last_index_dtype.get_lanes_or_vscale_factor();
+      int index_lanes = other_index_lanes * lanes_in_last_index;
 
       // The total number of lanes in this store operation.  Either
       // the index or the value will be broadcast out to this number
       // of lanes, depending on which has more lanes.
-      int total_lanes = std::max(index_lanes, value.dtype().lanes());
+      int value_dtype_lanes = value.dtype().get_lanes_or_vscale_factor();
+      bool is_last_index_scalable = last_index_dtype.is_scalable_vector();
+      int total_lanes = std::max(index_lanes, value_dtype_lanes);
 
       ICHECK_EQ(total_lanes % other_index_lanes, 0)
           << "When storing to buffer " << op->buffer->name << ", cannot 
produce " << total_lanes
@@ -455,11 +519,12 @@ class Vectorizer : public StmtMutator, public 
ExprFunctor<PrimExpr(const PrimExp
 
       // Broadcast the last index such that the total number of index
       // lanes matches the desired number.
-      indices.Set(indices.size() - 1, BroadcastTo(indices[indices.size() - 1], 
last_index_lanes));
+      indices.Set(indices.size() - 1, BroadcastTo(indices[indices.size() - 1], 
last_index_lanes,
+                                                  is_last_index_scalable));
 
       auto writer = store.CopyOnWrite();
       writer->indices = indices;
-      writer->value = BroadcastTo(value, total_lanes);
+      writer->value = BroadcastTo(value, total_lanes, is_last_index_scalable);
     }
 
     return std::move(store);
@@ -512,7 +577,8 @@ class Vectorizer : public StmtMutator, public 
ExprFunctor<PrimExpr(const PrimExp
     ICHECK(!let_binding_.count(op->var)) << "SSA violation, a single var is 
binded twice";
     let_binding_[op->var] = value;
 
-    if (value.dtype().lanes() != op->value.dtype().lanes()) {
+    if (value.dtype().get_lanes_or_vscale_factor() !=
+        op->value.dtype().get_lanes_or_vscale_factor()) {
       Var new_var(op->var->name_hint, value.dtype());
       let_binding_[op->var] = new_var;
       return LetStmt(new_var, value, this->VisitStmt(op->body));
@@ -566,8 +632,7 @@ class Vectorizer : public StmtMutator, public 
ExprFunctor<PrimExpr(const PrimExp
   Stmt Scalarize(Stmt stmt) {
     Var idx(var_->name_hint + ".s", var_->dtype);
     stmt = Substitute(stmt, {{var_, idx}});
-    return For(idx, IntImm(var_->dtype, 0), IntImm(var_->dtype, var_lanes_), 
ForKind::kSerial,
-               stmt);
+    return For(idx, IntImm(var_->dtype, 0), var_lanes_, ForKind::kSerial, 
stmt);
   }
   // ProducerStore
   Stmt VisitStmt_(const ProducerStoreNode* op) final {
@@ -582,7 +647,7 @@ class Vectorizer : public StmtMutator, public 
ExprFunctor<PrimExpr(const PrimExp
   // variable to be replaced
   Var var_;
   // the lanes.
-  int var_lanes_;
+  PrimExpr var_lanes_;
   // ramp representing the var.
   PrimExpr ramp_;
   // flag to mark requirment of scalarization.
@@ -609,7 +674,7 @@ class Vectorizer : public StmtMutator, public 
ExprFunctor<PrimExpr(const PrimExp
 
     for (size_t i = 0; i < arr.size(); ++i) {
       if (new_arr[i].dtype().lanes() != lanes) {
-        new_arr[i] = BroadcastTo(new_arr[i], lanes);
+        new_arr[i] = BroadcastTo(new_arr[i], lanes, false);
         changed = true;
       }
     }
@@ -624,8 +689,11 @@ class Vectorizer : public StmtMutator, public 
ExprFunctor<PrimExpr(const PrimExp
     if (a.same_as(op->a) && b.same_as(op->b)) {
       return GetRef<PrimExpr>(op);
     } else {
-      int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
-      return TOp(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
+      int a_lanes = a.dtype().get_lanes_or_vscale_factor();
+      int b_lanes = b.dtype().get_lanes_or_vscale_factor();
+      int lanes = std::max(a_lanes, b_lanes);
+      bool is_scalable = a.dtype().is_scalable_vector() || 
b.dtype().is_scalable_vector();
+      return TOp(BroadcastTo(a, lanes, is_scalable), BroadcastTo(b, lanes, 
is_scalable));
     }
   }
   template <typename T, typename FCompute>
@@ -635,19 +703,22 @@ class Vectorizer : public StmtMutator, public 
ExprFunctor<PrimExpr(const PrimExp
     if (a.same_as(op->a) && b.same_as(op->b)) {
       return GetRef<PrimExpr>(op);
     } else {
-      int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
+      int a_lanes = a.dtype().get_lanes_or_vscale_factor();
+      int b_lanes = b.dtype().get_lanes_or_vscale_factor();
+      int lanes = std::max(a_lanes, b_lanes);
       if (lanes != 1) {
         const RampNode* b_ramp = b.as<RampNode>();
         const RampNode* a_ramp = a.as<RampNode>();
-        if (a.dtype().lanes() == 1 && b_ramp) {
+        if (a.dtype().is_scalar() && b_ramp) {
           return Ramp(fcompute(a, b_ramp->base),
                       fcompute(make_zero(b_ramp->stride.dtype()), 
b_ramp->stride), b_ramp->lanes);
         }
-        if (b.dtype().lanes() == 1 && a_ramp) {
+        if (b.dtype().is_scalar() && a_ramp) {
           return Ramp(fcompute(a_ramp->base, b), a_ramp->stride, 
a_ramp->lanes);
         }
       }
-      return fcompute(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
+      bool is_scalable = a.dtype().is_scalable_vector() || 
b.dtype().is_scalable_vector();
+      return fcompute(BroadcastTo(a, lanes, is_scalable), BroadcastTo(b, 
lanes, is_scalable));
     }
   }
 };
@@ -657,11 +728,7 @@ class LoopVectorizer : public StmtMutator {
   Stmt VisitStmt_(const ForNode* op) final {
     if (op->kind == ForKind::kVectorized) {
       ICHECK(is_zero(op->min));
-      auto* extent_as_int = op->extent.as<IntImmNode>();
-      if (!extent_as_int || extent_as_int->value < 1) {
-        LOG(FATAL) << "Failed to vectorize loop with extent " << op->extent;
-      }
-      return Vectorizer(op->loop_var, 
static_cast<int>(extent_as_int->value))(op->body);
+      return Vectorizer(op->loop_var, op->extent)(op->body);
     } else {
       return StmtMutator::VisitStmt_(op);
     }
diff --git a/tests/python/tir-transform/test_tir_transform_vectorize.py 
b/tests/python/tir-transform/test_tir_transform_vectorize.py
index 7d0fac2423..dbca006b19 100644
--- a/tests/python/tir-transform/test_tir_transform_vectorize.py
+++ b/tests/python/tir-transform/test_tir_transform_vectorize.py
@@ -19,32 +19,29 @@ import tvm.testing
 from tvm import te
 from tvm.script import ir as I
 from tvm.script import tir as T
+import pytest
 
 
-def test_vectorize_loop():
-    dtype = "int64"
-    n = te.var("n")
-    ib = tvm.tir.ir_builder.create()
-    A = ib.pointer("float32", name="A")
-    with ib.for_range(0, n) as i:
-        with ib.for_range(0, 4, kind="vectorize") as j:
-            A[j] = tvm.tir.const(1, A.dtype)
-    stmt = ib.get()
-
-    assert isinstance(stmt.body, tvm.tir.For)
[email protected]("extent", (4, T.vscale() * 4))
+def test_vectorize_loop(extent):
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def main(A: T.Buffer((16,), "float32")):
+            for j in T.vectorized(0, extent):
+                A[j] = 1
 
-    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
-    stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
+    @I.ir_module
+    class After:
+        @T.prim_func
+        def main(A: T.Buffer((16,), "float32")):
+            A[T.Ramp(0, 1, extent)] = T.Broadcast(1, extent)
 
-    assert isinstance(stmt, tvm.tir.For)
-    assert not isinstance(stmt.body, tvm.tir.For)
-    assert len(stmt.body.indices) == 1
-    assert isinstance(stmt.body.indices[0], tvm.tir.Ramp)
-    assert isinstance(stmt.body.value, tvm.tir.Broadcast)
+    mod = tvm.tir.transform.VectorizeLoop()(Before)
+    tvm.ir.assert_structural_equal(mod, After)
 
 
 def test_vectorize_vector():
-    dtype = "int64"
     n = te.var("n")
     ib = tvm.tir.ir_builder.create()
     A = ib.pointer("float32x4", name="A")
@@ -64,28 +61,90 @@ def test_vectorize_vector():
     assert isinstance(stmt.body.value, tvm.tir.Broadcast)
 
 
-def test_vectorize_with_if():
-    n = te.var("n")
-    x = te.var("x")
-    ib = tvm.tir.ir_builder.create()
-    A = ib.pointer("float32", name="A")
-    with ib.for_range(0, 4, kind="vectorize") as i:
-        with ib.if_scope(x < n):
-            A[i] = A[i] + 1
-        with ib.else_scope():
-            with ib.if_scope(i < n):
-                A[i] = 2.0
-    stmt = ib.get()
+def test_vectorize_vector_scalable_error():
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def main(A: T.Buffer((25,), "float32")):
+            for j in T.vectorized(T.vscale() * 4):
+                A[j * 4 : j * 4 + 4] = T.Broadcast(T.float32(1), 4)
+
+    error_msg = f"Creating scalable vectors from existing vectors is not 
supported."
+    with pytest.raises(tvm.error.InternalError, match=error_msg):
+        tvm.tir.transform.VectorizeLoop()(Module)
+
+
+def test_vectorize_vector_scalable_error2():
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def main(A: T.Buffer((25,), "float32xvscalex4")):
+            for j in T.vectorized(4):
+                A[j] = T.Broadcast(T.float32(1), T.vscale() * 4)
+
+    error_msg = f"Vectorizing over scalable buffer elements is not supported 
in vectorizer."
+    with pytest.raises(tvm.error.InternalError, match=error_msg):
+        tvm.tir.transform.VectorizeLoop()(Module)
 
-    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n, x], stmt))
-    stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
 
-    assert isinstance(stmt, tvm.tir.IfThenElse)
-    assert len(stmt.then_case.indices) == 1
-    assert isinstance(stmt.then_case.indices[0], tvm.tir.Ramp)
-    assert isinstance(stmt.then_case.value, tvm.tir.Add)
-    assert stmt.then_case.value.dtype == "float32x4"
-    assert isinstance(stmt.else_case, tvm.tir.For)
+def test_vectorize_vector_scalable_error3():
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def main(A: T.Buffer((25,), "float32")):
+            for j in T.vectorized(4):
+                A[j * T.vscale() * 4 : j * T.vscale() * 4 + T.vscale() * 4] = 
T.Broadcast(
+                    T.float32(1), T.vscale() * 4
+                )
+
+    error_msg = f"Vectorizing over existing scalable vectors is not supported."
+    with pytest.raises(tvm.error.InternalError, match=error_msg):
+        tvm.tir.transform.VectorizeLoop()(Module)
+
+
+def test_vectorize_vector_scalable_error4():
+    @I.ir_module
+    class Module:
+        @T.prim_func(private=True)
+        def main(A: T.Buffer((25,), "float32")):
+            for j in T.vectorized(T.vscale() * 4):
+                A[j * T.vscale() * 4 : j * T.vscale() * 4 + T.vscale() * 4] = 
T.Broadcast(
+                    T.float32(1), T.vscale() * 4
+                )
+
+    error_msg = f"Creating scalable vectors from existing vectors is not 
supported."
+    with pytest.raises(tvm.error.InternalError, match=error_msg):
+        tvm.tir.transform.VectorizeLoop()(Module)
+
+
[email protected]("extent", (4, T.vscale() * 4))
+def test_vectorize_with_if(extent):
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32):
+            for i in T.vectorized(extent):
+                if x < n:
+                    A[i] = A[i] + T.float32(1)
+                else:
+                    if i < n:
+                        A[i] = T.float32(2)
+
+    @I.ir_module
+    class After:
+        @T.prim_func
+        def main(A: T.Buffer((25,), "float32"), n: T.int32, x: T.int32):
+            if x < n:
+                A[T.Ramp(0, 1, extent)] = A[T.Ramp(0, 1, extent)] + 
T.Broadcast(
+                    T.float32(1), extent
+                )
+            else:
+                for i_s in range(extent):
+                    if i_s < n:
+                        A[i_s] = T.float32(2)
+
+    mod = tvm.tir.transform.VectorizeLoop()(Before)
+    tvm.ir.assert_structural_equal(mod, After)
 
 
 def test_vectorize_with_if_cond_int64():
@@ -98,25 +157,33 @@ def test_vectorize_with_if_cond_int64():
     f = tvm.build(s, [A, B], "llvm")
 
 
-def test_vectorize_let():
-    v = tvm.tir.Var("v", "float32")
-    ib = tvm.tir.ir_builder.create()
-    A = ib.pointer("float32", name="A")
-    with ib.for_range(0, 4, kind="vectorize") as i:
-        ib.emit(lambda body: tvm.tir.LetStmt(v, A[i] + 1, body))
-        A[i] = v + 2
[email protected]("extent", (4, T.vscale() * 4))
+def test_vectorize_let(extent):
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def main(A: T.Buffer((25,), "float32")):
+            for i in T.vectorized(extent):
+                v = A[i] + T.float32(1)
+                A[i] = v + T.float32(2)
 
-    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A], ib.get()))
-    stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
-    assert isinstance(stmt, tvm.tir.LetStmt)
-    assert stmt.value.dtype == "float32x4"
+    @I.ir_module
+    class After:
+        @T.prim_func
+        def main(A: T.Buffer((25,), "float32")):
+            v = A[T.Ramp(0, 1, extent)] + T.Broadcast(T.float32(1), extent)
+            A[T.Ramp(0, 1, extent)] = v + T.Broadcast(T.float32(2), extent)
+
+    mod = tvm.tir.transform.VectorizeLoop()(Before)
+    tvm.ir.assert_structural_equal(mod, After)
 
 
-def test_vectorize_with_le_cond():
[email protected]("extent", (4, tvm.tir.vscale() * 4))
+def test_vectorize_with_le_cond(extent):
     n = te.var("n")
     ib = tvm.tir.ir_builder.create()
     A = ib.pointer("float32", name="A")
-    with ib.for_range(0, 4, kind="vectorize") as i:
+    with ib.for_range(0, extent, kind="vectorize") as i:
         with ib.if_scope(i <= n):
             A[i] = A[i] + 1
     stmt = ib.get()
@@ -124,14 +191,16 @@ def test_vectorize_with_le_cond():
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
     stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
 
+    # Check that the loop wasn't vectorised
     assert isinstance(stmt, tvm.tir.For)
 
 
-def test_vectorize_with_ge_cond():
+@pytest.mark.parametrize("extent", (4, tvm.tir.vscale() * 4))
+def test_vectorize_with_ge_cond(extent):
     n = te.var("n")
     ib = tvm.tir.ir_builder.create()
     A = ib.pointer("float32", name="A")
-    with ib.for_range(0, 4, kind="vectorize") as i:
+    with ib.for_range(0, extent, kind="vectorize") as i:
         with ib.if_scope(i >= n):
             A[i] = A[i] + 1
     stmt = ib.get()
@@ -139,39 +208,51 @@ def test_vectorize_with_ge_cond():
     mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
     stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
 
+    # Check that the loop wasn't vectorised
     assert isinstance(stmt, tvm.tir.For)
 
 
-def test_vectorize_if_then_else():
-    n = te.var("n")
-    x = te.var("x")
-    ib = tvm.tir.ir_builder.create()
-    A = ib.pointer("float32", name="A")
-    with ib.for_range(0, 4, kind="vectorize") as i:
-        A[i] = tvm.tir.call_intrin("float32", "tir.if_then_else", i > 0, A[i] + 1, A[i])
-    stmt = ib.get()
+@pytest.mark.parametrize("extent", (4, T.vscale() * 4))
+def test_vectorize_if_then_else_scalarize(extent):
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def main(A: T.Buffer((25,), "float32")):
+            for i in T.vectorized(extent):
+                A[i] = T.if_then_else(i > 0, A[i] + T.float32(1), A[i])
 
-    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n, x], stmt))
-    stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
+    @I.ir_module
+    class After:
+        @T.prim_func
+        def main(A: T.Buffer((25,), "float32")):
+            for i_s in range(extent):
+                A[i_s] = T.if_then_else(i_s > 0, A[i_s] + T.float32(1), A[i_s])
 
-    assert isinstance(stmt, tvm.tir.For)
+    mod = tvm.tir.transform.VectorizeLoop()(Before)
+    tvm.ir.assert_structural_equal(mod, After)
 
-    ib = tvm.tir.ir_builder.create()
-    A = ib.pointer("float32", name="A")
-    with ib.for_range(0, n) as k:
-        with ib.for_range(0, 4, kind="vectorize") as i:
-            A[k * 4 + i] = tvm.tir.call_intrin(
-                "float32", "tir.if_then_else", k > 0, A[k * 4 + i], 0
-            )
-    stmt = ib.get()
 
-    assert isinstance(stmt.body, tvm.tir.For)
+@pytest.mark.parametrize("extent", (4, T.vscale() * 4))
+def test_vectorize_if_then_else_vector(extent):
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def main(A: T.Buffer((25,), "float32"), n: T.int32):
+            for i in range(n):
+                for j in T.vectorized(extent):
+                    A[i * extent + j] = T.if_then_else(i > 0, A[i * extent + j], 0)
 
-    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n], stmt))
-    stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
+    @I.ir_module
+    class After:
+        @T.prim_func
+        def main(A: T.Buffer((25,), "float32"), n: T.int32):
+            for i in range(n):
+                A[T.Ramp(i * extent, 1, extent)] = T.if_then_else(
+                    i > 0, A[T.Ramp(i * extent, 1, extent)], T.Broadcast(0, extent)
+                )
 
-    assert not isinstance(stmt.body, tvm.tir.For)
-    assert isinstance(stmt.body.value.args[2], tvm.tir.Broadcast)
+    mod = tvm.tir.transform.VectorizeLoop()(Before)
+    tvm.ir.assert_structural_equal(mod, After)
 
 
 def test_vectorize_while_fail():
@@ -229,23 +310,141 @@ def test_vectorize_dtype_mismatch():
     tvm.lower(s, [A], "llvm", simple_mode=True)
 
 
-def test_vectorize_with_reinterpret():
+@pytest.mark.parametrize(
+    "extent, vec_str", [(16, "float32x16"), (T.vscale() * 8, "float32xvscalex8")]
+)
+def test_vectorize_with_reinterpret(extent, vec_str):
     @I.ir_module
     class Before:
         @T.prim_func
         def main(A: T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")):
-            for i in T.vectorized(0, 16):
+            for i in T.vectorized(0, extent):
                 B[i] = T.reinterpret("float32", A[i])
 
     @I.ir_module
     class After:
         @T.prim_func
         def main(A: T.Buffer((16,), "int32"), B: T.Buffer((16,), "float32")):
-            B[0:16] = T.reinterpret("float32x16", A[0:16])
+            B[T.Ramp(0, 1, extent)] = T.reinterpret(vec_str, A[T.Ramp(0, 1, extent)])
 
     mod = tvm.tir.transform.VectorizeLoop()(Before)
     tvm.ir.assert_structural_equal(mod, After)
 
 
+@pytest.mark.parametrize("extent", (4, T.vscale() * 4))
+@pytest.mark.parametrize(
+    "op",
+    (
+        T.Mul,
+        T.Add,
+        T.Sub,
+        T.Div,
+        T.Mod,
+        T.FloorDiv,
+        T.FloorMod,
+        T.Min,
+        T.Max,
+        T.EQ,
+        T.LT,
+        T.LE,
+        T.GE,
+        T.GT,
+        T.NE,
+    ),
+)
+def test_vectorize_binary(op, extent):
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
+            for j in T.vectorized(extent):
+                A[j] = op(T.float32(3), B[j])
+
+    @I.ir_module
+    class After:
+        @T.prim_func
+        def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
+            A[T.Ramp(0, 1, extent)] = op(T.Broadcast(T.float32(3), extent), B[T.Ramp(0, 1, extent)])
+
+    mod = tvm.tir.transform.VectorizeLoop()(Before)
+    tvm.ir.assert_structural_equal(mod, After)
+
+
+@pytest.mark.parametrize("extent", (4, T.vscale() * 4))
+@pytest.mark.parametrize("op", (T.And, T.Or))
+def test_vectorize_logical(op, extent):
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def main(A: T.Buffer((25,), "bool"), B: T.Buffer((25,), "bool")):
+            for j in T.vectorized(extent):
+                A[j] = op(T.bool(1), B[j])
+
+    @I.ir_module
+    class After:
+        @T.prim_func
+        def main(A: T.Buffer((25,), "bool"), B: T.Buffer((25,), "bool")):
+            A[T.Ramp(0, 1, extent)] = op(T.Broadcast(T.bool(1), extent), B[T.Ramp(0, 1, extent)])
+
+    mod = tvm.tir.transform.VectorizeLoop()(Before)
+    tvm.ir.assert_structural_equal(mod, After)
+
+
+@pytest.mark.parametrize("extent", (4, T.vscale() * 4))
+def test_vectorize_select(extent):
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
+            for j in T.vectorized(extent):
+                A[j] = T.Select(T.bool(True), A[j], B[j])
+
+    @I.ir_module
+    class After:
+        @T.prim_func
+        def main(A: T.Buffer((25,), "float32"), B: T.Buffer((25,), "float32")):
+            A[T.Ramp(0, 1, extent)] = T.Select(
+                T.Broadcast(T.bool(True), extent),
+                A[T.Ramp(0, 1, extent)],
+                B[T.Ramp(0, 1, extent)],
+            )
+
+    mod = tvm.tir.transform.VectorizeLoop()(Before)
+    tvm.ir.assert_structural_equal(mod, After)
+
+
+@pytest.mark.parametrize("extent, vec_str", [(4, "int32x4"), (T.vscale() * 4, "int32xvscalex4")])
+def test_vectorize_cast(extent, vec_str):
+    @I.ir_module
+    class Before:
+        @T.prim_func
+        def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")):
+            for j in T.vectorized(extent):
+                A[j] = T.Cast("int32", B[j])
+
+    @I.ir_module
+    class After:
+        @T.prim_func
+        def main(A: T.Buffer((25,), "int32"), B: T.Buffer((25,), "float32")):
+            A[T.Ramp(0, 1, extent)] = T.Cast(vec_str, B[T.Ramp(0, 1, extent)])
+
+    mod = tvm.tir.transform.VectorizeLoop()(Before)
+    tvm.ir.assert_structural_equal(mod, After)
+
+
+def test_illegal_extent():
+    @I.ir_module(check_well_formed=False)
+    class Mod:
+        @T.prim_func
+        def main(A: T.Buffer((25,), "int32")):
+            n = T.Var("n", dtype="int32")
+            for j in T.vectorized(n):
+                A[j] = 3
+
+    error_msg = f"Invalid expression for scalable lanes n"
+    with pytest.raises(tvm.error.InternalError, match=error_msg):
+        tvm.tir.transform.VectorizeLoop()(Mod)
+
+
 if __name__ == "__main__":
     tvm.testing.main()


Reply via email to