lhutton1 commented on code in PR #16782:
URL: https://github.com/apache/tvm/pull/16782#discussion_r1539232691


##########
src/tir/ir/expr.cc:
##########
@@ -196,7 +196,9 @@ TVM_REGISTER_NODE_TYPE(StringImmNode);
 // Cast
 Cast::Cast(DataType t, PrimExpr value, Span span) {
   ICHECK(value.defined());
-  ICHECK_EQ(t.lanes(), value.dtype().lanes());
+  ICHECK_EQ(t.get_lanes_or_vscale_factor(), 
value.dtype().get_lanes_or_vscale_factor());
+  ICHECK((t.is_scalable_vector() == value.dtype().is_scalable_vector()) ||
+         (!t.is_scalable_vector() && !value.dtype().is_scalable_vector()));

Review Comment:
   I think `!a && !b` already implies `a == b`, so the second clause is
redundant and the expression could be simplified to just
`t.is_scalable_vector() == value.dtype().is_scalable_vector()`.



##########
src/tir/transforms/vectorize_loop.cc:
##########
@@ -37,19 +37,36 @@
 namespace tvm {
 namespace tir {
 
-// TODO(ekalda): P5 in https://github.com/apache/tvm/issues/16455
-inline PrimExpr BroadcastTo(PrimExpr e, int lanes) {
-  if (e.dtype().lanes() == lanes) return e;
+inline PrimExpr CreateNewLanes(bool is_scalable, int lanes_or_vscale_factor) {
+  if (is_scalable) {
+    return Mul(Call(DataType::Int(32), builtin::vscale(), {}), 
lanes_or_vscale_factor);
+  } else {
+    return lanes_or_vscale_factor;
+  }
+}
+
+inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) {
+  // Check if e is already in the expected form
+  if (e.dtype().get_lanes_or_vscale_factor() == lanes &&
+      e.dtype().is_scalable_vector() == is_scalable)
+    return e;
+
   if (const BroadcastNode* op = e.as<BroadcastNode>()) {
-    ICHECK(!e.dtype().is_scalable_vector());
-    int broadcast_lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
-    if (lanes % broadcast_lanes == 0) {
-      return Broadcast(op->value, lanes);
+    ICHECK(op->dtype.is_scalable_vector() == is_scalable)
+        << "Can't broadcast between scalable and fixed length vectors.";
+    int e_lanes = is_scalable ? op->dtype.vscale_factor() : op->dtype.lanes();

Review Comment:
   nit: this could use `op->dtype.get_lanes_or_vscale_factor()` instead of the ternary.



##########
src/tir/transforms/vectorize_loop.cc:
##########
@@ -433,20 +488,27 @@ class Vectorizer : public StmtMutator, public 
ExprFunctor<PrimExpr(const PrimExp
     PrimExpr value = this->VisitExpr(op->value);
 
     if (!indices.same_as(op->indices) || !value.same_as(op->value)) {
+      ICHECK(!op->buffer->dtype.is_scalable_vector())
+          << "Vectorizing over scalable buffer elements is not supported in 
vectorizer.";
       // How many lanes of indexing are present in the index and
-      // buffer element type, excluding the last index.  T
+      // buffer element type, excluding the last index.
       int other_index_lanes = op->buffer->dtype.lanes();
       for (size_t i = 0; i < indices.size() - 1; i++) {
         other_index_lanes *= indices[i].dtype().lanes();
+        // Only allow the last index to be scalable
+        ICHECK(!indices[i].dtype().is_scalable_vector()) << "Only the last 
index can be scalable.";
       }
 
       // The total number of lanes of indexing, including the last index.
-      int index_lanes = other_index_lanes * indices[indices.size() - 
1].dtype().lanes();
+      int lanes_in_last_index = indices[indices.size() - 
1].dtype().get_lanes_or_vscale_factor();
+      int index_lanes = other_index_lanes * lanes_in_last_index;
 
       // The total number of lanes in this store operation.  Either
       // the index or the value will be broadcast out to this number
       // of lanes, depending on which has more lanes.
-      int total_lanes = std::max(index_lanes, value.dtype().lanes());
+      int value_dtype_lanes = value.dtype().get_lanes_or_vscale_factor();
+      bool is_last_index_scalable = indices[indices.size() - 
1].dtype().is_scalable_vector();

Review Comment:
   nit: might be nicer to replace uses of `indices[indices.size() - 1].dtype()` 
with a `last_index_dtype` variable



##########
src/tir/transforms/vectorize_loop.cc:
##########
@@ -635,19 +701,22 @@ class Vectorizer : public StmtMutator, public 
ExprFunctor<PrimExpr(const PrimExp
     if (a.same_as(op->a) && b.same_as(op->b)) {
       return GetRef<PrimExpr>(op);
     } else {
-      int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
+      int a_lanes = a.dtype().get_lanes_or_vscale_factor();
+      int b_lanes = b.dtype().get_lanes_or_vscale_factor();
+      int lanes = std::max(a_lanes, b_lanes);
       if (lanes != 1) {
         const RampNode* b_ramp = b.as<RampNode>();
         const RampNode* a_ramp = a.as<RampNode>();
-        if (a.dtype().lanes() == 1 && b_ramp) {
+        if (!a.dtype().is_scalable_or_fixed_length_vector() && b_ramp) {

Review Comment:
   Could this use `a.dtype().is_scalar()` instead?



##########
tests/python/tir-transform/test_tir_transform_vectorize.py:
##########
@@ -64,28 +61,86 @@ def test_vectorize_vector():
     assert isinstance(stmt.body.value, tvm.tir.Broadcast)
 
 
-def test_vectorize_with_if():
-    n = te.var("n")
-    x = te.var("x")
-    ib = tvm.tir.ir_builder.create()
-    A = ib.pointer("float32", name="A")
-    with ib.for_range(0, 4, kind="vectorize") as i:
-        with ib.if_scope(x < n):
-            A[i] = A[i] + 1
-        with ib.else_scope():
-            with ib.if_scope(i < n):
-                A[i] = 2.0
-    stmt = ib.get()
+def test_vectorize_vector_scalable_error():
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def main(A: T.Buffer((25,), "float32")):
+            for j in T.vectorized(T.vscale() * 4):
+                A[j * 4 : j * 4 + 4] = T.Broadcast(T.float32(1), 4)
 
-    mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([A, n, x], stmt))
-    stmt = tvm.tir.transform.VectorizeLoop()(mod)["main"].body
+    with pytest.raises(tvm.error.InternalError):
+        tvm.tir.transform.VectorizeLoop()(Module)
+
+
+def test_vectorize_vector_scalable_error2():
+    @I.ir_module
+    class Module:
+        @T.prim_func
+        def main(A: T.Buffer((25,), "float32xvscalex4")):
+            for j in T.vectorized(4):
+                A[j] = T.Broadcast(T.float32(1), T.vscale() * 4)
+
+    with pytest.raises(tvm.error.InternalError):

Review Comment:
   nit: for these negative tests, is it possible to check the error message as 
well? It can sometimes help reason about why the test is expected to fail



##########
src/tir/transforms/vectorize_loop.cc:
##########
@@ -182,21 +199,29 @@ class Vectorizer : public StmtMutator, public 
ExprFunctor<PrimExpr(const PrimExp
     if (a.same_as(op->a) && b.same_as(op->b)) {
       return GetRef<PrimExpr>(op);
     } else {
-      // TODO(ekalda): P5 in https://github.com/apache/tvm/issues/16455
-      int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
-      if (lanes != 1) {
+      bool is_vec_a = a.dtype().is_scalable_or_fixed_length_vector();
+      bool is_vec_b = b.dtype().is_scalable_or_fixed_length_vector();
+      if (is_vec_a && is_vec_b) {
+        // Let's not multiply scalable and fixed length vectors
+        ICHECK(a.dtype().is_scalable_vector() == 
b.dtype().is_scalable_vector());
+      }
+      if (is_vec_a || is_vec_b) {
         const RampNode* b_ramp = b.as<RampNode>();
         const RampNode* a_ramp = a.as<RampNode>();
-        if (a_ramp && b.dtype().lanes() == 1 && analyzer_.CanProve(b > 0)) {
-          int lanes = static_cast<int>(Downcast<IntImm>(a_ramp->lanes)->value);
+        if (a_ramp && !b_ramp && analyzer_.CanProve(b > 0)) {

Review Comment:
   Would it be better to use `b.dtype().is_scalar()` here? (just because it's 
closer to the original statement... I'm unsure if `!b_ramp` relaxes this check)



##########
src/tir/transforms/vectorize_loop.cc:
##########
@@ -182,21 +199,29 @@ class Vectorizer : public StmtMutator, public 
ExprFunctor<PrimExpr(const PrimExp
     if (a.same_as(op->a) && b.same_as(op->b)) {
       return GetRef<PrimExpr>(op);
     } else {
-      // TODO(ekalda): P5 in https://github.com/apache/tvm/issues/16455
-      int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
-      if (lanes != 1) {
+      bool is_vec_a = a.dtype().is_scalable_or_fixed_length_vector();
+      bool is_vec_b = b.dtype().is_scalable_or_fixed_length_vector();
+      if (is_vec_a && is_vec_b) {
+        // Let's not multiply scalable and fixed length vectors
+        ICHECK(a.dtype().is_scalable_vector() == 
b.dtype().is_scalable_vector());
+      }
+      if (is_vec_a || is_vec_b) {
         const RampNode* b_ramp = b.as<RampNode>();
         const RampNode* a_ramp = a.as<RampNode>();
-        if (a_ramp && b.dtype().lanes() == 1 && analyzer_.CanProve(b > 0)) {
-          int lanes = static_cast<int>(Downcast<IntImm>(a_ramp->lanes)->value);
+        if (a_ramp && !b_ramp && analyzer_.CanProve(b > 0)) {
+          PrimExpr lanes = a_ramp->lanes;
           return Ramp(a_ramp->base * b, a_ramp->stride * b, lanes);
         }
-        if (b_ramp && a.dtype().lanes() == 1 && analyzer_.CanProve(a > 0)) {
-          int lanes = static_cast<int>(Downcast<IntImm>(b_ramp->lanes)->value);
+        if (b_ramp && !a_ramp && analyzer_.CanProve(a > 0)) {

Review Comment:
   same as above (i.e. would `a.dtype().is_scalar()` be better than `!a_ramp` here?)



##########
src/tir/transforms/vectorize_loop.cc:
##########
@@ -635,19 +701,22 @@ class Vectorizer : public StmtMutator, public 
ExprFunctor<PrimExpr(const PrimExp
     if (a.same_as(op->a) && b.same_as(op->b)) {
       return GetRef<PrimExpr>(op);
     } else {
-      int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
+      int a_lanes = a.dtype().get_lanes_or_vscale_factor();
+      int b_lanes = b.dtype().get_lanes_or_vscale_factor();
+      int lanes = std::max(a_lanes, b_lanes);
       if (lanes != 1) {
         const RampNode* b_ramp = b.as<RampNode>();
         const RampNode* a_ramp = a.as<RampNode>();
-        if (a.dtype().lanes() == 1 && b_ramp) {
+        if (!a.dtype().is_scalable_or_fixed_length_vector() && b_ramp) {
           return Ramp(fcompute(a, b_ramp->base),
                       fcompute(make_zero(b_ramp->stride.dtype()), 
b_ramp->stride), b_ramp->lanes);
         }
-        if (b.dtype().lanes() == 1 && a_ramp) {
+        if (!b.dtype().is_scalable_or_fixed_length_vector() && a_ramp) {

Review Comment:
   same as above (i.e. could this use `b.dtype().is_scalar()`?)



##########
src/tir/transforms/vectorize_loop.cc:
##########
@@ -182,21 +199,29 @@ class Vectorizer : public StmtMutator, public 
ExprFunctor<PrimExpr(const PrimExp
     if (a.same_as(op->a) && b.same_as(op->b)) {
       return GetRef<PrimExpr>(op);
     } else {
-      // TODO(ekalda): P5 in https://github.com/apache/tvm/issues/16455
-      int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
-      if (lanes != 1) {
+      bool is_vec_a = a.dtype().is_scalable_or_fixed_length_vector();
+      bool is_vec_b = b.dtype().is_scalable_or_fixed_length_vector();
+      if (is_vec_a && is_vec_b) {
+        // Let's not multiply scalable and fixed length vectors
+        ICHECK(a.dtype().is_scalable_vector() == 
b.dtype().is_scalable_vector());

Review Comment:
   Worth adding a message to this?



##########
src/tir/transforms/vectorize_loop.cc:
##########
@@ -37,19 +37,36 @@
 namespace tvm {
 namespace tir {
 
-// TODO(ekalda): P5 in https://github.com/apache/tvm/issues/16455
-inline PrimExpr BroadcastTo(PrimExpr e, int lanes) {
-  if (e.dtype().lanes() == lanes) return e;
+inline PrimExpr CreateNewLanes(bool is_scalable, int lanes_or_vscale_factor) {
+  if (is_scalable) {
+    return Mul(Call(DataType::Int(32), builtin::vscale(), {}), 
lanes_or_vscale_factor);
+  } else {
+    return lanes_or_vscale_factor;
+  }
+}
+
+inline PrimExpr BroadcastTo(PrimExpr e, int lanes, bool is_scalable) {
+  // Check if e is already in the expected form
+  if (e.dtype().get_lanes_or_vscale_factor() == lanes &&
+      e.dtype().is_scalable_vector() == is_scalable)

Review Comment:
   The logic for fixed-width vectors seems to have changed here; is that
intended?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to