This is an automated email from the ASF dual-hosted git repository.
lukhut pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new a6157a6369 [SVE] Change the dtype of Ramp and Broadcast lanes to
PrimExpr (#16523)
a6157a6369 is described below
commit a6157a6369c184b6fa5f66654feb685e58726737
Author: Elen Kalda <[email protected]>
AuthorDate: Tue Feb 20 09:13:30 2024 +0000
[SVE] Change the dtype of Ramp and Broadcast lanes to PrimExpr (#16523)
This change will allow us to express scalable vectors through Ramp and
Broadcast nodes, e.g.
```
vec = tvm.tir.expr.Ramp(0, 1, 4 * tvm.tir.vscale())
```
We will use negative values for `runtime::DataType` to encode the scalable
lane values, e.g.
the above example would result in `lanes` = -4. That's because the lanes in
`runtime::DataType`
are tied to the DLPack standard, which uses `uint16_t` for the lanes. The
conversion happens in the
node definitions and `runtime::DataType`, so the `int` and `uint16_t`
values should never be
exposed to the API user, especially after the string support has been added.
Also include the TVMScript support for scalable Ramp and Broadcasts.
Note that this patch doesn't include lowering to the appropriate LLVM
vectors, support for
data type string representation or `LoopVectorizer` support. All of these
will be part of
future patches.
Co-authored-by: Luke Hutton <[email protected]>
Co-authored-by: Neil Hickey <[email protected]>
---
include/tvm/runtime/data_type.h | 43 ++++++-
include/tvm/tir/expr.h | 8 +-
python/tvm/ir/base.py | 7 +-
python/tvm/ir/json_compact.py | 57 +++++++--
python/tvm/script/ir_builder/tir/ir.py | 2 +-
python/tvm/tir/expr.py | 14 +--
src/arith/const_fold.h | 3 +-
src/arith/int_set.cc | 27 ++--
src/arith/pattern_match.h | 15 ++-
src/arith/rewrite_simplify.cc | 138 ++++++++++++---------
src/arith/scalable_expression.cc | 54 ++++++++
src/arith/scalable_expression.h | 52 ++++++++
src/ir/expr.cc | 4 +-
src/relay/printer/tir_text_printer.cc | 4 +-
src/relay/printer/tvmscript_printer.cc | 4 +-
src/script/ir_builder/tir/ir.cc | 39 +++++-
src/script/printer/tir/expr.cc | 4 +-
src/target/llvm/codegen_arm.cc | 2 +-
src/target/llvm/codegen_hexagon.cc | 2 +-
src/target/llvm/codegen_llvm.cc | 16 ++-
src/target/llvm/codegen_nvptx.cc | 2 +-
src/target/source/codegen_c.cc | 5 +-
src/target/source/codegen_c_host.cc | 3 +-
src/target/source/codegen_cuda.cc | 28 +++--
src/target/source/codegen_metal.cc | 3 +-
src/target/source/codegen_opencl.cc | 8 +-
src/target/source/codegen_webgpu.cc | 3 +-
src/target/spirv/codegen_spirv.cc | 6 +-
src/tir/ir/expr.cc | 68 +++++++---
src/tir/ir/expr_functor.cc | 10 +-
src/tir/ir/stmt.cc | 40 +++++-
src/tir/ir/tir_visitor_with_path.cc | 2 +
src/tir/op/op.cc | 61 +++++++--
src/tir/schedule/analysis/reducer.cc | 32 +++--
src/tir/transforms/bound_checker.cc | 15 ++-
src/tir/transforms/lower_thread_allreduce.cc | 4 +-
src/tir/transforms/renormalize_split_pattern.cc | 2 +-
src/tir/transforms/storage_rewrite.cc | 23 ++--
src/tir/transforms/vectorize_loop.cc | 35 ++++--
tests/cpp/pattern_match_test.cc | 12 +-
tests/cpp/tir_scalable_datatype.cc | 125 +++++++++++++++++++
tests/python/arith/test_arith_intset.py | 8 ++
tests/python/arith/test_arith_rewrite_simplify.py | 36 ++++++
tests/python/relay/test_json_compact.py | 71 +++++++++++
tests/python/tir-base/test_tir_nodes.py | 76 ++++++++++++
.../tvmscript/test_tvmscript_ir_builder_tir.py | 18 +++
.../python/tvmscript/test_tvmscript_printer_tir.py | 27 ++--
tests/python/tvmscript/test_tvmscript_roundtrip.py | 9 ++
48 files changed, 989 insertions(+), 238 deletions(-)
diff --git a/include/tvm/runtime/data_type.h b/include/tvm/runtime/data_type.h
index fcd35f1e2a..5efa5f3b90 100644
--- a/include/tvm/runtime/data_type.h
+++ b/include/tvm/runtime/data_type.h
@@ -71,11 +71,15 @@ class DataType {
* \param code The type code.
* \param bits The number of bits in the type.
* \param lanes The number of lanes.
+ * \param is_scalable Whether the data type is scalable.
*/
- DataType(int code, int bits, int lanes) {
+ DataType(int code, int bits, int lanes, bool is_scalable = false) {
data_.code = static_cast<uint8_t>(code);
data_.bits = static_cast<uint8_t>(bits);
- data_.lanes = static_cast<uint16_t>(lanes);
+ if (is_scalable) {
+ ICHECK(lanes > 1) << "Invalid value for vscale factor" << lanes;
+ }
+ data_.lanes = is_scalable ? static_cast<uint16_t>(-lanes) :
static_cast<uint16_t>(lanes);
if (code == kBFloat) {
ICHECK_EQ(bits, 16);
}
@@ -90,7 +94,21 @@ class DataType {
/*! \return number of bytes to store each scalar. */
int bytes() const { return (bits() + 7) / 8; }
/*! \return number of lanes in the data. */
- int lanes() const { return static_cast<int>(data_.lanes); }
+ int lanes() const {
+ int lanes_as_int = static_cast<int16_t>(data_.lanes);
+ if (lanes_as_int < 0) {
+ LOG(FATAL) << "Can't fetch the lanes of a scalable vector at a compile
time.";
+ }
+ return lanes_as_int;
+ }
+ /*! \return the integer multiplier of vscale in a scalable vector. */
+ int vscale_factor() const {
+ int lanes_as_int = static_cast<int16_t>(data_.lanes);
+ if (lanes_as_int >= -1) {
+ LOG(FATAL) << "A fixed length vector doesn't have a vscale factor.";
+ }
+ return -lanes_as_int;
+ }
/*! \return whether type is a scalar type. */
bool is_scalar() const { return lanes() == 1; }
/*! \return whether type is a scalar type. */
@@ -114,9 +132,16 @@ class DataType {
/*! \return whether type is a handle type. */
bool is_handle() const { return code() == DataType::kHandle && !is_void(); }
/*! \return whether type is a vector type. */
- bool is_vector() const { return lanes() > 1; }
+ bool is_scalable_or_fixed_length_vector() const {
+ int encoded_lanes = static_cast<int16_t>(data_.lanes);
+ return (encoded_lanes < -1) || (1 < encoded_lanes);
+ }
+ /*! \return Whether the type is a fixed length vector. */
+ bool is_fixed_length_vector() const { return
static_cast<int16_t>(data_.lanes) > 1; }
+ /*! \return Whether the type is a scalable vector. */
+ bool is_scalable_vector() const { return static_cast<int16_t>(data_.lanes) <
-1; }
/*! \return whether type is a bool vector type. */
- bool is_vector_bool() const { return is_vector() && bits() == 1; }
+ bool is_vector_bool() const { return is_scalable_or_fixed_length_vector() &&
bits() == 1; }
/*! \return whether type is a Void type. */
bool is_void() const { return code() == DataType::kHandle && bits() == 0 &&
lanes() == 0; }
/*!
@@ -125,6 +150,14 @@ class DataType {
* \return the result type.
*/
DataType with_lanes(int lanes) const { return DataType(data_.code,
data_.bits, lanes); }
+ /*!
+ * \brief Create a new scalable vector data type by changing the vscale
multiplier to a specified
+ * value. We'll use the data_.lanes field for this value. \param
vscale_factor The vscale
+ * multiplier. \return A copy of the old DataType with the number of
scalable lanes.
+ */
+ DataType with_scalable_vscale_factor(int vscale_factor) const {
+ return DataType(data_.code, data_.bits, -vscale_factor);
+ }
/*!
* \brief Create a new data type by change bits to a specified value.
* \param bits The target number of bits.
diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h
index 4e29eddadd..39b32f5633 100644
--- a/include/tvm/tir/expr.h
+++ b/include/tvm/tir/expr.h
@@ -746,7 +746,7 @@ class RampNode : public PrimExprNode {
/*! \brief The stride of each step. */
PrimExpr stride;
/*! \brief Total number of lanes. */
- int lanes;
+ PrimExpr lanes;
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
@@ -778,7 +778,7 @@ class RampNode : public PrimExprNode {
*/
class Ramp : public PrimExpr {
public:
- TVM_DLL Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span = Span());
+ TVM_DLL Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span =
Span());
TVM_DEFINE_OBJECT_REF_METHODS(Ramp, PrimExpr, RampNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(RampNode);
};
@@ -789,7 +789,7 @@ class BroadcastNode : public PrimExprNode {
/*! \brief The base value. */
PrimExpr value;
/*! \brief The number of lanes. */
- int lanes;
+ PrimExpr lanes;
void VisitAttrs(AttrVisitor* v) {
v->Visit("dtype", &dtype);
@@ -818,7 +818,7 @@ class BroadcastNode : public PrimExprNode {
*/
class Broadcast : public PrimExpr {
public:
- TVM_DLL Broadcast(PrimExpr value, int lanes, Span span = Span());
+ TVM_DLL Broadcast(PrimExpr value, PrimExpr lanes, Span span = Span());
TVM_DEFINE_OBJECT_REF_METHODS(Broadcast, PrimExpr, BroadcastNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(BroadcastNode);
};
diff --git a/python/tvm/ir/base.py b/python/tvm/ir/base.py
index 21a5ed6576..535b97e62d 100644
--- a/python/tvm/ir/base.py
+++ b/python/tvm/ir/base.py
@@ -126,11 +126,8 @@ def load_json(json_str) -> Object:
The loaded tvm node.
"""
- try:
- return _ffi_node_api.LoadJSON(json_str)
- except tvm.error.TVMError:
- json_str = json_compact.upgrade_json(json_str)
- return _ffi_node_api.LoadJSON(json_str)
+ json_str = json_compact.upgrade_json(json_str)
+ return _ffi_node_api.LoadJSON(json_str)
def save_json(node) -> str:
diff --git a/python/tvm/ir/json_compact.py b/python/tvm/ir/json_compact.py
index 224932b00c..cb6e031667 100644
--- a/python/tvm/ir/json_compact.py
+++ b/python/tvm/ir/json_compact.py
@@ -57,6 +57,32 @@ def create_updater(node_map, from_ver, to_ver):
return _updater
+def create_updater_15_to_16():
+ """
+ Create an update to upgrade json from v0.15 to v0.16
+
+ Returns
+ -------
+ fupdater : function
+ The updater function
+ """
+
+ def _update_lanes_obj(item, nodes):
+ lanes = item["attrs"]["lanes"]
+ new_idx = len(nodes)
+ item["attrs"]["lanes"] = str(new_idx)
+ lanes_node = {
+ "type_key": "IntImm",
+ "attrs": {"dtype": "int32", "span": "0", "value": lanes},
+ }
+ nodes.append(lanes_node)
+ return item
+
+ node_map = {"tir.Ramp": _update_lanes_obj, "tir.Broadcast":
_update_lanes_obj}
+
+ return create_updater(node_map, "0.15", "0.16")
+
+
def create_updater_13_to_14():
"""Create an update to upgrade json from v0.13 to v0.14 for TVM Unity"""
@@ -266,16 +292,29 @@ def upgrade_json(json_str):
The updated version.
"""
data = json.loads(json_str)
- from_version = data["attrs"]["tvm_version"]
- if from_version.startswith("0.6"):
- data =
create_updater_08_to_09()(create_updater_07_to_08()(create_updater_06_to_07()(data)))
- elif from_version.startswith("0.7"):
- data = create_updater_08_to_09()(create_updater_07_to_08()(data))
- elif from_version.startswith("0.8"):
+ def _from_version(data):
+ return data["attrs"]["tvm_version"]
+
+ if _from_version(data).startswith("0.6"):
+ data = create_updater_06_to_07()(data)
+ if _from_version(data).startswith("0.7"):
+ data = create_updater_07_to_08()(data)
+ if _from_version(data).startswith("0.8"):
data = create_updater_08_to_09()(data)
- elif from_version.startswith("0.13"):
+ if _from_version(data).startswith("0.9"):
+ data = create_updater({}, "0.9", "0.10")(data)
+ if _from_version(data).startswith("0.10"):
+ data = create_updater({}, "0.10", "0.11")(data)
+ if _from_version(data).startswith("0.11"):
+ data = create_updater({}, "0.11", "0.12")(data)
+ if _from_version(data).startswith("0.12"):
+ data = create_updater({}, "0.12", "0.13")(data)
+ if _from_version(data).startswith("0.13"):
data = create_updater_13_to_14()(data)
- else:
- raise ValueError(f"Cannot update from version {from_version}")
+ if _from_version(data).startswith("0.14"):
+ data = create_updater({}, "0.14", "0.15")(data)
+ if _from_version(data).startswith("0.15"):
+ data = create_updater_15_to_16()(data)
+
return json.dumps(data, indent=2)
diff --git a/python/tvm/script/ir_builder/tir/ir.py
b/python/tvm/script/ir_builder/tir/ir.py
index 8a93537f77..a5c09cf1a3 100644
--- a/python/tvm/script/ir_builder/tir/ir.py
+++ b/python/tvm/script/ir_builder/tir/ir.py
@@ -1289,7 +1289,7 @@ def buffer_store(
if lanes == 1:
expr_indices.append(index.start)
else:
- expr_indices.append(ramp(index.start, step, int(lanes)))
+ expr_indices.append(ramp(index.start, step, lanes))
else:
expr_indices.append(index)
if isinstance(value, bool) and buffer.dtype == "bool":
diff --git a/python/tvm/tir/expr.py b/python/tvm/tir/expr.py
index fad9fca083..fca501874d 100644
--- a/python/tvm/tir/expr.py
+++ b/python/tvm/tir/expr.py
@@ -1146,10 +1146,10 @@ class Ramp(PrimExprWithOp):
base : PrimExpr
The base expression.
- stride : ramp stride
+ stride : PrimExpr
The stride of the ramp.
- lanes : int
+ lanes : PrimExpr
The lanes of the expression.
span : Optional[Span]
@@ -1158,10 +1158,10 @@ class Ramp(PrimExprWithOp):
base: PrimExpr
stride: PrimExpr
- lanes: int
+ lanes: PrimExpr
def __init__(
- self, base: PrimExpr, stride: PrimExpr, lanes: int, span:
Optional[Span] = None
+ self, base: PrimExpr, stride: PrimExpr, lanes: PrimExpr, span:
Optional[Span] = None
) -> None:
self.__init_handle_by_constructor__(
_ffi_api.Ramp, base, stride, lanes, span # type: ignore
@@ -1177,7 +1177,7 @@ class Broadcast(PrimExprWithOp):
value : PrimExpr
The value of the expression.
- lanes : int
+ lanes : PrimExpr
The lanes of the expression.
span : Optional[Span]
@@ -1185,9 +1185,9 @@ class Broadcast(PrimExprWithOp):
"""
value: PrimExpr
- lanes: int
+ lanes: PrimExpr
- def __init__(self, value: PrimExpr, lanes: int, span: Optional[Span] =
None) -> None:
+ def __init__(self, value: PrimExpr, lanes: PrimExpr, span: Optional[Span]
= None) -> None:
self.__init_handle_by_constructor__(_ffi_api.Broadcast, value, lanes,
span) # type: ignore
diff --git a/src/arith/const_fold.h b/src/arith/const_fold.h
index e66991d70e..65ac749d45 100644
--- a/src/arith/const_fold.h
+++ b/src/arith/const_fold.h
@@ -72,7 +72,8 @@ inline Optional<PrimExpr> TryConstFold(PrimExpr a);
* \return the checked result.
*/
inline bool IsIndexType(const DataType& type) {
- return type.is_int() && type.lanes() == 1 && (type.bits() == 32 ||
type.bits() == 64);
+ return type.is_int() && !type.is_scalable_or_fixed_length_vector() &&
+ (type.bits() == 32 || type.bits() == 64);
}
/*! \brief Helper to get const folding result repr in int64. */
diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc
index 625488430b..579870e5f5 100644
--- a/src/arith/int_set.cc
+++ b/src/arith/int_set.cc
@@ -466,14 +466,23 @@ class IntervalSetEvaluator : public
ExprFunctor<IntervalSet(const PrimExpr&)> {
if (stride.Match(op->stride)) {
DataType t = op->base.dtype();
int64_t vstride = stride.Eval()->value;
- if (vstride > 0) {
- return Combine<Add>(analyzer_, base,
- IntervalSet(make_zero(t), make_const(t, vstride *
(op->lanes - 1))),
- op->dtype);
- } else {
- return Combine<Add>(analyzer_, base,
- IntervalSet(make_const(t, vstride * (op->lanes -
1)), make_zero(t)),
- op->dtype);
+ if (op->lanes->IsInstance<IntImmNode>()) {
+ int lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
+ if (vstride > 0) {
+ return Combine<Add>(analyzer_, base,
+ IntervalSet(make_zero(t), make_const(t, vstride
* (lanes - 1))),
+ op->dtype);
+ } else {
+ return Combine<Add>(analyzer_, base,
+ IntervalSet(make_const(t, vstride * (lanes -
1)), make_zero(t)),
+ op->dtype);
+ }
+ } else { /* Scalable vector */
+ if (vstride > 0) {
+ return Combine<Add>(analyzer_, base, IntervalSet(make_zero(t),
pos_inf()), op->dtype);
+ } else {
+ return Combine<Add>(analyzer_, base, IntervalSet(neg_inf(),
make_zero(t)), op->dtype);
+ }
}
}
DLOG(WARNING) << "cannot evaluate set on expression " <<
GetRef<PrimExpr>(op);
@@ -957,7 +966,7 @@ IntSet EvalSet(PrimExpr e, const Map<Var, IntSet>& dom_map)
{
IntSet IntSet::Vector(PrimExpr x) {
// short cut: simply get single point
- if (x.dtype().lanes() == 1) {
+ if (!x.dtype().is_scalable_or_fixed_length_vector()) {
return IntSet::SinglePoint(x);
} else {
// vector case.
diff --git a/src/arith/pattern_match.h b/src/arith/pattern_match.h
index d057a840e8..98cf61990d 100644
--- a/src/arith/pattern_match.h
+++ b/src/arith/pattern_match.h
@@ -628,10 +628,11 @@ inline PRampExpr<TBase, TStride, TLanes> ramp(const
Pattern<TBase>& base,
}
template <typename TBase>
-inline PRampExpr<TBase, PConstWithTypeLike<TBase>, PConst<int>> ramp(const
Pattern<TBase>& base,
- int
stride, int lanes) {
- return PRampExpr<TBase, PConstWithTypeLike<TBase>, PConst<int>>(
- base.derived(), PConstWithTypeLike<TBase>(base.derived(), stride),
PConst<int>(lanes));
+inline PRampExpr<TBase, PConstWithTypeLike<TBase>, PConstWithTypeLike<TBase>>
ramp(
+ const Pattern<TBase>& base, int stride, int lanes) {
+ return PRampExpr<TBase, PConstWithTypeLike<TBase>,
PConstWithTypeLike<TBase>>(
+ base.derived(), PConstWithTypeLike<TBase>(base.derived(), stride),
+ PConstWithTypeLike<TBase>(base.derived(), lanes));
}
/*!
@@ -835,6 +836,12 @@ inline PCallExpr<PIfThenElseOp, TCond, TA, TB>
if_then_else(const Pattern<TCond>
false_value.derived());
}
+// vscale
+struct PVscaleOp {
+ static PrimExpr Eval() { return tir::Call(DataType::Int(32), GetOp(), {}); }
+ static const Op& GetOp() { return tir::builtin::vscale(); }
+};
+
template <typename... TPattern>
class PMatchesOneOf {
public:
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index d5f946fca0..0eaaff5ba8 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -37,6 +37,7 @@
#include "const_fold.h"
#include "constraint_extract.h"
#include "pattern_match.h"
+#include "scalable_expression.h"
namespace tvm {
namespace arith {
@@ -247,9 +248,9 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const AddNode*
op) {
// Pattern var match FloatImm
PVar<FloatImm> c4;
// Pattern var for lanes in broadcast and ramp
- PVar<int> lanes;
+ PVar<PrimExpr> lanes;
// Vector rules
- if (op->dtype.lanes() != 1) {
+ if (op->dtype.is_scalable_or_fixed_length_vector()) {
TVM_TRY_REWRITE(ramp(b1, s1, lanes) + ramp(b2, s2, lanes), ramp(b1 + b2,
s1 + s2, lanes));
TVM_TRY_REWRITE(ramp(b1, s1, lanes) + broadcast(x, lanes), ramp(b1 + x,
s1, lanes));
TVM_TRY_REWRITE(broadcast(x, lanes) + ramp(b1, s1, lanes), ramp(x + b1,
s1, lanes));
@@ -396,9 +397,9 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const SubNode*
op) {
// Pattern var match IntImm
PVar<IntImm> c1, c2, c3;
// Pattern var for lanes in broadcast and ramp
- PVar<int> lanes;
+ PVar<PrimExpr> lanes;
// Vector rules
- if (op->dtype.lanes() != 1) {
+ if (op->dtype.is_scalable_or_fixed_length_vector()) {
TVM_TRY_REWRITE(ramp(b1, s1, lanes) - ramp(b2, s2, lanes), ramp(b1 - b2,
s1 - s2, lanes));
TVM_TRY_REWRITE(ramp(b1, s1, lanes) - broadcast(x, lanes), ramp(b1 - x,
s1, lanes));
TVM_TRY_REWRITE(broadcast(x, lanes) - ramp(b1, s1, lanes), ramp(x - b1, 0
- s1, lanes));
@@ -580,9 +581,9 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const MulNode*
op) {
// Pattern var match FloatImm
PVar<FloatImm> c3;
// Pattern var for lanes in broadcast and ramp
- PVar<int> lanes;
+ PVar<PrimExpr> lanes;
// Vector rules
- if (op->dtype.lanes() != 1) {
+ if (op->dtype.is_scalable_or_fixed_length_vector()) {
TVM_TRY_REWRITE(broadcast(x, lanes) * broadcast(y, lanes), broadcast(x *
y, lanes));
TVM_TRY_REWRITE(matches_one_of(ramp(b1, s1, lanes) * broadcast(x, lanes),
broadcast(x, lanes) * ramp(b1, s1, lanes)),
@@ -617,7 +618,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode*
op) {
// Pattern var match IntImm
PVar<IntImm> c1, c2, c3;
// Pattern var for lanes in broadcast and ramp
- PVar<int> lanes;
+ PVar<PrimExpr> lanes;
// x / 2.0 = x * 0.5
if (const FloatImmNode* ptr = op->b.as<FloatImmNode>()) {
@@ -627,7 +628,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const DivNode*
op) {
}
// Vector rules
- if (op->dtype.lanes() != 1) {
+ if (op->dtype.is_scalable_or_fixed_length_vector()) {
// NOTE: use div as the pattern also works for float.
TVM_TRY_REWRITE(div(broadcast(x, lanes), broadcast(y, lanes)),
broadcast(div(x, y), lanes));
// ramp / bcast
@@ -639,10 +640,11 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const
DivNode* op) {
return ramp(div(b1, c2), div(c1, c2), lanes).Eval();
}
// If all possible indices in ramp are the same.
- if (CanProveGreaterEqual(b1.Eval(), 0)) {
+ if (CanProveGreaterEqual(b1.Eval(), 0) &&
!arith::ExtractVscaleFactor(lanes.Eval())) {
ModularSet bmod = analyzer_->modular_set(b1.Eval());
int64_t ramp_min = bmod->base / c2val;
- int64_t ramp_max = (bmod->base + (lanes.Eval() - 1) * c1val) / c2val;
+ auto lanes_int = lanes.Eval().as<IntImmNode>()->value;
+ int64_t ramp_max = (bmod->base + (lanes_int - 1) * c1val) / c2val;
if (bmod->coeff % c2val == 0 && ramp_min == ramp_max) {
return broadcast(div(b1, c2), lanes).Eval();
}
@@ -777,10 +779,10 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const
ModNode* op) {
// Pattern var match IntImm
PVar<IntImm> c1, c2;
// Pattern var for lanes in broadcast and ramp
- PVar<int> lanes;
+ PVar<PrimExpr> lanes;
// Vector rules
- if (op->dtype.lanes() != 1) {
+ if (op->dtype.is_scalable_or_fixed_length_vector()) {
TVM_TRY_REWRITE(truncmod(broadcast(x, lanes), broadcast(y, lanes)),
broadcast(truncmod(x, y), lanes));
@@ -795,12 +797,21 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const
ModNode* op) {
// If all possible indices in ramp are the same.
if (CanProveGreaterEqual(b1.Eval(), 0)) {
ModularSet bmod = analyzer_->modular_set(b1.Eval());
- int64_t ramp_min = bmod->base / c2val;
- int64_t ramp_max = (bmod->base + (lanes.Eval() - 1) * c1val) / c2val;
- if (bmod->coeff % c2val == 0) {
- if (ramp_min == ramp_max) {
- return ramp(truncmod(bmod->base, c2), c1, lanes).Eval();
- } else {
+ if (!arith::ExtractVscaleFactor(lanes.Eval())) {
+ auto lanes_int = lanes.Eval().as<IntImmNode>()->value;
+ int64_t ramp_min = bmod->base / c2val;
+ int64_t ramp_max = (bmod->base + (lanes_int - 1) * c1val) / c2val;
+ if (bmod->coeff % c2val == 0) {
+ if (ramp_min == ramp_max) {
+ return ramp(truncmod(bmod->base, c2), c1, lanes).Eval();
+ } else {
+ return truncmod(ramp(truncmod(bmod->base, c2), c1, lanes),
broadcast(c2, lanes))
+ .Eval();
+ }
+ }
+ } else { /* Special case for scalable vectors */
+ ModularSet bmod = analyzer_->modular_set(b1.Eval());
+ if (bmod->coeff % c2val == 0) {
return truncmod(ramp(truncmod(bmod->base, c2), c1, lanes),
broadcast(c2, lanes)).Eval();
}
}
@@ -857,10 +868,10 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const
FloorDivNode* op) {
// Pattern var match IntImm
PVar<IntImm> c1, c2, c3;
// Pattern var for lanes in broadcast and ramp
- PVar<int> lanes;
+ PVar<PrimExpr> lanes;
// Vector rules
- if (op->dtype.lanes() != 1) {
+ if (op->dtype.is_scalable_or_fixed_length_vector()) {
TVM_TRY_REWRITE(floordiv(broadcast(x, lanes), broadcast(y, lanes)),
broadcast(floordiv(x, y), lanes));
// ramp // bcast
@@ -872,17 +883,20 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const
FloorDivNode* op) {
return ramp(floordiv(b1, c2), floordiv(c1, c2), lanes).Eval();
}
// If all possible indices in ramp are the same.
- ModularSet bmod = analyzer_->modular_set(b1.Eval());
- int64_t ramp_min = floordiv(bmod->base, c2val);
- int64_t ramp_max = floordiv(bmod->base + (lanes.Eval() - 1) * c1val,
c2val);
- if (ramp_min == ramp_max) {
- // If b1 can devide c2
- if (bmod->coeff % c2val == 0) {
- return broadcast(floordiv(b1, c2), lanes).Eval();
- }
- // If all indices can be guaranteed to settle inside a coeff range
- if (c2val % bmod->coeff == 0 && bmod->base + (lanes.Eval() - 1) *
c1val < bmod->coeff) {
- return broadcast(floordiv(b1, c2), lanes).Eval();
+ if (!arith::ExtractVscaleFactor(lanes.Eval())) {
+ ModularSet bmod = analyzer_->modular_set(b1.Eval());
+ int64_t ramp_min = floordiv(bmod->base, c2val);
+ auto lanes_int = lanes.Eval().as<IntImmNode>()->value;
+ int64_t ramp_max = floordiv(bmod->base + (lanes_int - 1) * c1val,
c2val);
+ if (ramp_min == ramp_max) {
+ // If b1 can divide c2
+ if (bmod->coeff % c2val == 0) {
+ return broadcast(floordiv(b1, c2), lanes).Eval();
+ }
+ // If all indices can be guaranteed to settle inside a coeff range
+ if (c2val % bmod->coeff == 0 && bmod->base + (lanes_int - 1) * c1val
< bmod->coeff) {
+ return broadcast(floordiv(b1, c2), lanes).Eval();
+ }
}
}
}
@@ -993,10 +1007,10 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const
FloorModNode* op) {
// Pattern var match IntImm
PVar<IntImm> c1, c2;
// Pattern var for lanes in broadcast and ramp
- PVar<int> lanes;
+ PVar<PrimExpr> lanes;
// Vector rules
- if (op->dtype.lanes() != 1) {
+ if (op->dtype.is_scalable_or_fixed_length_vector()) {
TVM_TRY_REWRITE(floormod(broadcast(x, lanes), broadcast(y, lanes)),
broadcast(floormod(x, y), lanes));
@@ -1010,21 +1024,29 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const
FloorModNode* op) {
}
// If all possible indices in ramp are the same.
ModularSet bmod = analyzer_->modular_set(b1.Eval());
- int64_t ramp_min = floordiv(bmod->base, c2val);
- int64_t ramp_max = floordiv(bmod->base + (lanes.Eval() - 1) * c1val,
c2val);
- if (ramp_min == ramp_max) {
- // If b1 can devide c2
+ if (!arith::ExtractVscaleFactor(lanes.Eval())) {
+ int64_t ramp_min = floordiv(bmod->base, c2val);
+ auto lanes_int = lanes.Eval().as<IntImmNode>()->value;
+ int64_t ramp_max = floordiv(bmod->base + (lanes_int - 1) * c1val,
c2val);
+ if (ramp_min == ramp_max) {
+ // If b1 can divide c2
+ if (bmod->coeff % c2val == 0) {
+ return ramp(floormod(bmod->base, c2), c1, lanes).Eval();
+ }
+ // If all indices can be guaranteed to settle inside a coeff range
+ if (c2val % bmod->coeff == 0 && bmod->base + (lanes_int - 1) * c1val
< bmod->coeff) {
+ return ramp(floormod(b1, c2), c1, lanes).Eval();
+ }
+ }
+ // If b1 can divide c2
if (bmod->coeff % c2val == 0) {
- return ramp(floormod(bmod->base, c2), c1, lanes).Eval();
+ return floormod(ramp(floormod(bmod->base, c2), c1, lanes),
broadcast(c2, lanes)).Eval();
}
- // If all indices can be guaranteed to settle inside a coeff range
- if (c2val % bmod->coeff == 0 && bmod->base + (lanes.Eval() - 1) *
c1val < bmod->coeff) {
- return ramp(floormod(b1, c2), c1, lanes).Eval();
+ } else { /* scalable vectors */
+ if (bmod->coeff % c2val == 0) {
+ return floormod(ramp(floormod(bmod->base, c2), c1, lanes),
broadcast(c2, lanes)).Eval();
}
}
- if (bmod->coeff % c2val == 0) {
- return floormod(ramp(floormod(bmod->base, c2), c1, lanes),
broadcast(c2, lanes)).Eval();
- }
}
}
@@ -1093,10 +1115,10 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const
MinNode* op) {
PVar<PrimExpr> x, y, z, s1, s2;
// Pattern var match IntImm
PVar<IntImm> c1, c2;
- PVar<int> lanes;
+ PVar<PrimExpr> lanes;
// vector rule
- if (op->dtype.lanes() != 1) {
+ if (op->dtype.is_scalable_or_fixed_length_vector()) {
TVM_TRY_REWRITE(min(broadcast(x, lanes), broadcast(y, lanes)),
broadcast(min(x, y), lanes));
TVM_TRY_REWRITE(min(min(x, broadcast(y, lanes)), broadcast(z, lanes)),
min(x, broadcast(min(y, z), lanes)));
@@ -1267,10 +1289,10 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const
MaxNode* op) {
PVar<PrimExpr> x, y, z, s1, s2;
// Pattern var match IntImm
PVar<IntImm> c1, c2;
- PVar<int> lanes;
+ PVar<PrimExpr> lanes;
// vector rule
- if (op->dtype.lanes() != 1) {
+ if (op->dtype.is_scalable_or_fixed_length_vector()) {
TVM_TRY_REWRITE(max(broadcast(x, lanes), broadcast(y, lanes)),
broadcast(max(x, y), lanes));
TVM_TRY_REWRITE(max(max(x, broadcast(y, lanes)), broadcast(z, lanes)),
max(x, broadcast(max(y, z), lanes)));
@@ -1475,10 +1497,10 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(EQ
ret) {
PVar<PrimExpr> x, y;
// Pattern var match IntImm
PVar<IntImm> c1, c2;
- PVar<int> lanes;
+ PVar<PrimExpr> lanes;
// vector rule
- if (ret->dtype.lanes() != 1) {
+ if (ret->dtype.is_scalable_or_fixed_length_vector()) {
TVM_TRY_REWRITE(broadcast(x, lanes) == broadcast(y, lanes), broadcast(x ==
y, lanes));
}
@@ -1603,10 +1625,10 @@ PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(LT
ret) {
PVar<PrimExpr> x, y, z, s1, s2;
// Pattern var match IntImm
PVar<IntImm> c1, c2;
- PVar<int> lanes;
+ PVar<PrimExpr> lanes;
// vector rule
- if (ret->dtype.lanes() != 1) {
+ if (ret->dtype.is_scalable_or_fixed_length_vector()) {
TVM_TRY_REWRITE(broadcast(x, lanes) < broadcast(y, lanes), broadcast(x <
y, lanes));
TVM_TRY_REWRITE(ramp(x, s1, lanes) < ramp(y, s1, lanes), broadcast(x < y,
lanes));
}
@@ -1761,8 +1783,8 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const
NotNode* op) {
PrimExpr RewriteSimplifier::Impl::ApplyRewriteRules(Not ret) {
// Pattern var to match any expression
PVar<PrimExpr> x, y;
- PVar<int> lanes;
- if (ret->dtype.lanes() != 1) {
+ PVar<PrimExpr> lanes;
+ if (ret->dtype.is_scalable_or_fixed_length_vector()) {
TVM_TRY_REWRITE(!broadcast(x, lanes), broadcast(!x, lanes));
}
@@ -1836,9 +1858,9 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const
AndNode* op) {
PVar<PrimExpr> x, y, z;
// Pattern var match IntImm
PVar<IntImm> c1, c2, c3;
- PVar<int> lanes;
+ PVar<PrimExpr> lanes;
- if (op->dtype.lanes() != 1) {
+ if (op->dtype.is_scalable_or_fixed_length_vector()) {
TVM_TRY_REWRITE(broadcast(x, lanes) && broadcast(y, lanes), broadcast(x &&
y, lanes));
}
@@ -1984,9 +2006,9 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const
OrNode* op) {
PVar<PrimExpr> x, y, z;
// Pattern var match IntImm
PVar<IntImm> c1, c2;
- PVar<int> lanes;
+ PVar<PrimExpr> lanes;
- if (op->dtype.lanes() != 1) {
+ if (op->dtype.is_scalable_or_fixed_length_vector()) {
TVM_TRY_REWRITE(broadcast(x, lanes) || broadcast(y, lanes), broadcast(x ||
y, lanes));
}
diff --git a/src/arith/scalable_expression.cc b/src/arith/scalable_expression.cc
new file mode 100644
index 0000000000..85fd149e04
--- /dev/null
+++ b/src/arith/scalable_expression.cc
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/arith/scalable_expression.cc
+ * \brief Analyze scalable expressions.
+ */
+
+#include "scalable_expression.h"
+
+#include <tvm/tir/expr.h>
+#include <tvm/tir/op.h>
+
+#include "./pattern_match.h"
+
+namespace tvm {
+namespace arith {
+
+bool IsVScaleCall(const PrimExpr& expr) {
+ if (auto call = expr.as<tir::CallNode>()) {
+ return call->op.same_as(tir::builtin::vscale());
+ }
+ return false;
+}
+
+std::optional<int> ExtractVscaleFactor(const PrimExpr& lanes) {
+ PVar<IntImm> multiplier;
+ PCallExpr<PVscaleOp> vscale;
+
+ if (PMatchesOneOf(multiplier * vscale, vscale * multiplier).Match(lanes)) {
+ return multiplier.Eval()->value;
+ } else {
+ return std::nullopt;
+ }
+}
+
+} // namespace arith
+} // namespace tvm
diff --git a/src/arith/scalable_expression.h b/src/arith/scalable_expression.h
new file mode 100644
index 0000000000..3c7fb0bb26
--- /dev/null
+++ b/src/arith/scalable_expression.h
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/arith/scalable_expression.h
+ * \brief Analyze scalable expressions.
+ */
+
+#ifndef TVM_ARITH_SCALABLE_EXPRESSION_H_
+#define TVM_ARITH_SCALABLE_EXPRESSION_H_
+
+#include <tvm/ir/expr.h>
+
+#include <optional>
+
+namespace tvm {
+namespace arith {
+
+/*!
+ * \brief Check if an expr is a call to the vscale intrinsic.
+ * \param expr The expr to check
+ * \return True if the expr is a call to the vscale intrinsic, false if not.
+ */
+bool IsVScaleCall(const PrimExpr& expr);
+
+/*!
+ * \brief Returns the vscale multiplier as a nullable type
+ * \param lanes The scalable lanes as a PrimExpr
+ * \return vscale multiplier as std::optional<int>
+ */
+std::optional<int> ExtractVscaleFactor(const PrimExpr& lanes);
+
+} // namespace arith
+} // namespace tvm
+
+#endif // TVM_ARITH_SCALABLE_EXPRESSION_H_
diff --git a/src/ir/expr.cc b/src/ir/expr.cc
index fdd8c2cd8b..596805f74b 100644
--- a/src/ir/expr.cc
+++ b/src/ir/expr.cc
@@ -53,8 +53,8 @@ PrimExpr PrimExpr::FromObject_(ObjectRef ref) {
for (const Range& r : buffer_region->region) {
if (tvm::tir::is_one(r->extent)) {
indices.push_back(r->min);
- } else if (const auto* extent = r->extent.as<IntImmNode>()) {
- indices.push_back(tir::Ramp(r->min,
tvm::tir::make_const(r->min->dtype, 1), extent->value));
+ } else if (r->extent.as<IntImmNode>()) {
+ indices.push_back(tir::Ramp(r->min,
tvm::tir::make_const(r->min->dtype, 1), r->extent));
} else {
LOG(FATAL) << "ValueError: Cannot convert to BufferLoad: " << ref;
}
diff --git a/src/relay/printer/tir_text_printer.cc
b/src/relay/printer/tir_text_printer.cc
index e9a9ee2313..c34788be91 100644
--- a/src/relay/printer/tir_text_printer.cc
+++ b/src/relay/printer/tir_text_printer.cc
@@ -381,13 +381,13 @@ Doc TIRTextPrinter::VisitExpr_(const ProducerLoadNode*
op) {
Doc TIRTextPrinter::VisitExpr_(const RampNode* op) {
Doc doc;
- doc << "ramp(" << Print(op->base) << ", " << Print(op->stride) << ", " <<
op->lanes << ")";
+ doc << "ramp(" << Print(op->base) << ", " << Print(op->stride) << ", " <<
Print(op->lanes) << ")";
return doc;
}
Doc TIRTextPrinter::VisitExpr_(const BroadcastNode* op) {
Doc doc;
- doc << "broadcast(" << Print(op->value) << ", " << op->lanes << ")";
+ doc << "broadcast(" << Print(op->value) << ", " << Print(op->lanes) << ")";
return doc;
}
diff --git a/src/relay/printer/tvmscript_printer.cc
b/src/relay/printer/tvmscript_printer.cc
index b0085b8242..1126e633d6 100644
--- a/src/relay/printer/tvmscript_printer.cc
+++ b/src/relay/printer/tvmscript_printer.cc
@@ -912,14 +912,14 @@ Doc TVMScriptPrinter::VisitExpr_(const RampNode* op,
ExprPrecedence* out_precede
*out_precedence = ExprPrecedence::kIdentity;
Doc doc;
doc << tir_prefix_ << ".ramp(" << Print(op->base) << ", " <<
Print(op->stride) << ", "
- << op->lanes << ")";
+ << Print(op->lanes) << ")";
return doc;
}
Doc TVMScriptPrinter::VisitExpr_(const BroadcastNode* op, ExprPrecedence*
out_precedence) {
*out_precedence = ExprPrecedence::kIdentity;
Doc doc;
- doc << tir_prefix_ << ".broadcast(" << Print(op->value) << ", " << op->lanes
<< ")";
+ doc << tir_prefix_ << ".broadcast(" << Print(op->value) << ", " <<
Print(op->lanes) << ")";
return doc;
}
diff --git a/src/script/ir_builder/tir/ir.cc b/src/script/ir_builder/tir/ir.cc
index d6554fc371..cf73ffa0ee 100644
--- a/src/script/ir_builder/tir/ir.cc
+++ b/src/script/ir_builder/tir/ir.cc
@@ -518,11 +518,44 @@ Var EnvThread(String thread_tag) {
void BufferStore(Buffer buffer, PrimExpr value, Array<PrimExpr> indices) {
runtime::DataType buffer_dtype = buffer->dtype;
- int index_lanes = indices.size() ? indices.back().dtype().lanes() : 1;
- runtime::DataType lhs_dtype = buffer_dtype.with_lanes(buffer_dtype.lanes() *
index_lanes);
+ bool is_index_scalable = indices.empty() ? false :
indices.back().dtype().is_scalable_vector();
+ bool is_buffer_dtype_scalable = buffer_dtype.is_scalable_vector();
+
+ ICHECK(!(is_index_scalable && is_buffer_dtype_scalable))
+ << "Index dtype and buffer dtype can't both be scalable.";
+
+ int index_lanes;
+ if (indices.empty()) {
+ index_lanes = 1;
+ } else if (is_index_scalable) {
+ index_lanes = indices.back().dtype().vscale_factor();
+ } else {
+ index_lanes = indices.back().dtype().lanes();
+ }
+
+ int buffer_lanes = is_buffer_dtype_scalable ? buffer_dtype.vscale_factor() :
buffer_dtype.lanes();
+
+ runtime::DataType lhs_dtype;
+ if (is_buffer_dtype_scalable || is_index_scalable) {
+ lhs_dtype = buffer_dtype.with_scalable_vscale_factor(buffer_lanes *
index_lanes);
+ } else {
+ lhs_dtype = buffer_dtype.with_lanes(buffer_dtype.lanes() * index_lanes);
+ }
+
runtime::DataType rhs_dtype = value->dtype;
+
if (lhs_dtype != rhs_dtype) {
- if (lhs_dtype.lanes() != rhs_dtype.lanes()) {
+ ICHECK(lhs_dtype.is_scalable_vector() == rhs_dtype.is_scalable_vector())
+ << "Can't mix scalable and fixed length vectors in a statement";
+
+ bool lanes_match = false;
+ if (lhs_dtype.is_scalable_vector()) {
+ lanes_match = lhs_dtype.vscale_factor() == rhs_dtype.vscale_factor();
+ } else {
+ lanes_match = lhs_dtype.lanes() == rhs_dtype.lanes();
+ }
+
+ if (!lanes_match) {
LOG(FATAL) << "TypeError: Incompatible types in BufferStore"
<< ": LHS is `" << lhs_dtype << "`, RHS is `" << rhs_dtype
<< "`, indexing lanes: " << index_lanes;
diff --git a/src/script/printer/tir/expr.cc b/src/script/printer/tir/expr.cc
index e25b074401..8268e6b35e 100644
--- a/src/script/printer/tir/expr.cc
+++ b/src/script/printer/tir/expr.cc
@@ -137,7 +137,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
return TIR(d, "Ramp")->Call({
d->AsDoc<ExprDoc>(ramp->base, ramp_p->Attr("base")),
d->AsDoc<ExprDoc>(ramp->stride, ramp_p->Attr("stride")),
- LiteralDoc::Int(ramp->lanes, ramp_p->Attr("lanes")),
+ d->AsDoc<ExprDoc>(ramp->lanes, ramp_p->Attr("lanes")),
});
});
@@ -146,7 +146,7 @@ TVM_STATIC_IR_FUNCTOR(IRDocsifier, vtable)
return TIR(d, "Broadcast")
->Call({
d->AsDoc<ExprDoc>(bc->value, bc_p->Attr("value")),
- LiteralDoc::Int(bc->lanes, bc_p->Attr("lanes")),
+ d->AsDoc<ExprDoc>(bc->lanes, bc_p->Attr("lanes")),
});
});
diff --git a/src/target/llvm/codegen_arm.cc b/src/target/llvm/codegen_arm.cc
index 15d1699b3b..ba9641bb06 100644
--- a/src/target/llvm/codegen_arm.cc
+++ b/src/target/llvm/codegen_arm.cc
@@ -72,7 +72,7 @@ PrimExpr CodeGenARM::ARMPopcount(const CallNode* call) {
// Fallback to default llvm lowering rule if input type not a full vector or
half vector length
int total_size = call->dtype.bits() * call->dtype.lanes();
- if (!call->dtype.is_vector() || call->dtype.bits() == 8 ||
+ if (!call->dtype.is_fixed_length_vector() || call->dtype.bits() == 8 ||
(total_size != 128 && total_size != 64)) {
Array<PrimExpr> vcnt_args;
vcnt_args.push_back(IntImm(DataType::UInt(32), ctpop_id));
diff --git a/src/target/llvm/codegen_hexagon.cc
b/src/target/llvm/codegen_hexagon.cc
index 8b884669c3..6ef5e064c0 100644
--- a/src/target/llvm/codegen_hexagon.cc
+++ b/src/target/llvm/codegen_hexagon.cc
@@ -393,7 +393,7 @@ llvm::Value* CodeGenHexagon::Intrinsic(llvm::Intrinsic::ID
IntID,
llvm::Value* CodeGenHexagon::VectorLookupLoad(Buffer buffer, DataType
buffer_type,
Array<PrimExpr> indices) {
PrimExpr index = indices[0];
- if (!index.dtype().is_vector()) {
+ if (!index.dtype().is_fixed_length_vector()) {
return nullptr;
}
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index a689183d8f..60c102ceaa 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -641,13 +641,13 @@ void CodeGenLLVM::AddAliasInfo(llvm::Instruction* inst,
const VarNode* buffer_va
int64_t base = 0, width = 0;
arith::PVar<IntImm> pbase, pstride;
- arith::PVar<int> planes;
+ arith::PVar<IntImm> planes;
// create meta-data for alias analysis
// Use a group of binary tree ranges of memory banks.
int64_t xwith = 0;
if (arith::ramp(pbase, pstride, planes).Match(index)) {
base = pbase.Eval()->value;
- xwith = planes.Eval() * pstride.Eval()->value;
+ xwith = planes.Eval()->value * pstride.Eval()->value;
} else if (auto* ptr = index.as<tir::IntImmNode>()) {
base = ptr->value;
xwith = 1;
@@ -1730,6 +1730,8 @@ void CodeGenLLVM::BufferAccessHelper(
llvm::Value* last_index_value;
int subelement_i = i;
if (const RampNode* ramp = last_index.as<RampNode>()) {
+ // TODO(ekalda): P4 in https://github.com/apache/tvm/issues/16455
+ ICHECK(!last_index.dtype().is_scalable_vector());
PrimExpr offset = ramp->base + (ramp->stride * i);
last_index_value = MakeValue(offset);
} else if (last_index.dtype().lanes() > 1) {
@@ -1827,7 +1829,10 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const CallNode* op)
{
llvm::Value* CodeGenLLVM::VisitExpr_(const RampNode* op) {
llvm::Value* vec = llvm::UndefValue::get(DTypeToLLVMType(op->dtype));
- for (int i = 0; i < op->lanes; ++i) {
+ // TODO(ekalda): P4 in https://github.com/apache/tvm/issues/16455
+ ICHECK(!op->dtype.is_scalable_vector());
+ int lanes = op->dtype.lanes();
+ for (int i = 0; i < lanes; ++i) {
vec = builder_->CreateInsertElement(
vec, MakeValue(op->base + op->stride * make_const(op->stride.dtype(),
i)), ConstInt32(i));
}
@@ -1859,7 +1864,10 @@ llvm::Value* CodeGenLLVM::VisitExpr_(const ShuffleNode*
op) {
}
llvm::Value* CodeGenLLVM::VisitExpr_(const BroadcastNode* op) {
- return CreateBroadcast(MakeValue(op->value), op->lanes);
+ // TODO(ekalda): P4 in https://github.com/apache/tvm/issues/16455
+ ICHECK(!op->dtype.is_scalable_vector());
+ int lanes = op->dtype.lanes();
+ return CreateBroadcast(MakeValue(op->value), lanes);
}
void CodeGenLLVM::VisitStmt_(const BufferStoreNode* op) {
diff --git a/src/target/llvm/codegen_nvptx.cc b/src/target/llvm/codegen_nvptx.cc
index 0b13fce678..b500958a8a 100644
--- a/src/target/llvm/codegen_nvptx.cc
+++ b/src/target/llvm/codegen_nvptx.cc
@@ -223,7 +223,7 @@ class CodeGenNVPTX : public CodeGenLLVM {
// corresponding nvvm intrinsic. Return true if the match is successful.
static bool GetWarpShuffleIntrinsic(const CallNode* op, llvm::Intrinsic::ID*
id) {
// Only 32 bit data type is supported.
- if (op->dtype.is_vector() || op->dtype.bits() != 32) {
+ if (op->dtype.is_fixed_length_vector() || op->dtype.bits() != 32) {
return false;
}
diff --git a/src/target/source/codegen_c.cc b/src/target/source/codegen_c.cc
index dd83fbdcbd..abb62f2faf 100644
--- a/src/target/source/codegen_c.cc
+++ b/src/target/source/codegen_c.cc
@@ -888,11 +888,12 @@ void CodeGenC::VisitExpr_(const RampNode* op,
std::ostream& os) { // NOLINT(*)
// NOTE: C have comma expression so cannot use (int2)(v0, v1)
// instead should use int2(v0, v1)
PrintType(op->dtype, os);
+ int lanes = op->dtype.lanes();
os << "(";
- for (int i = 0; i < op->lanes; i++) {
+ for (int i = 0; i < lanes; i++) {
os << "(" << PrintExpr(op->base) << ")"
<< "+(" << PrintExpr(op->stride) << "*" << i << ")";
- if (i != op->lanes - 1) os << ", ";
+ if (i != lanes - 1) os << ", ";
}
os << ")";
}
diff --git a/src/target/source/codegen_c_host.cc
b/src/target/source/codegen_c_host.cc
index d16d749e9b..b22d32d6c5 100644
--- a/src/target/source/codegen_c_host.cc
+++ b/src/target/source/codegen_c_host.cc
@@ -208,10 +208,11 @@ void CodeGenCHost::PrintType(DataType t, std::ostream&
os) { // NOLINT(*)
void CodeGenCHost::VisitExpr_(const BroadcastNode* op, std::ostream& os) { //
NOLINT(*)
std::string v = PrintExpr(op->value);
+ int lanes = op->dtype.lanes();
os << "((";
PrintType(op->dtype, os);
os << ")(";
- for (int i = 0; i < op->lanes; ++i) {
+ for (int i = 0; i < lanes; ++i) {
if (i != 0) os << ", ";
os << v;
}
diff --git a/src/target/source/codegen_cuda.cc
b/src/target/source/codegen_cuda.cc
index efed5c02f1..15905b0304 100644
--- a/src/target/source/codegen_cuda.cc
+++ b/src/target/source/codegen_cuda.cc
@@ -672,7 +672,7 @@ void CodeGenCUDA::VisitExpr_(const CastNode* op,
std::ostream& os) {
void CodeGenCUDA::PrintCallExtern(Type ret_type, String global_symbol, const
Array<PrimExpr>& args,
bool skip_first_arg, std::ostream& os) { //
NOLINT(*)
DataType ret_dtype = GetRuntimeDataType(ret_type);
- if (ret_dtype.is_vector()) {
+ if (ret_dtype.is_fixed_length_vector()) {
//
// Emit an unsupported vector call
//
@@ -1162,19 +1162,21 @@ void CodeGenCUDA::VisitStmt_(const EvaluateNode* op) {
}
void CodeGenCUDA::VisitExpr_(const RampNode* op, std::ostream& os) {
- CHECK_LE(op->lanes, 4) << "ValueError: Ramp of more than 4 lanes is not
allowed.";
+ int lanes = op->dtype.lanes();
+ CHECK_LE(lanes, 4) << "ValueError: Ramp of more than 4 lanes is not
allowed.";
PrintVecConstructor(op->dtype, os);
os << "(";
- for (int i = 0; i < op->lanes; i++) {
+ for (int i = 0; i < lanes; i++) {
os << "(" << PrintExpr(op->base) << ")"
<< "+(" << PrintExpr(op->stride) << "*" << i << ")";
- if (i != op->lanes - 1) os << ", ";
+ if (i != lanes - 1) os << ", ";
}
os << ")";
}
void CodeGenCUDA::VisitExpr_(const BroadcastNode* op, std::ostream& os) { //
NOLINT(*)
- if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 &&
op->lanes == 4) {
+ int lanes = op->dtype.lanes();
+ if ((op->dtype.is_int() || op->dtype.is_uint()) && op->dtype.bits() == 8 &&
lanes == 4) {
// make_int8x4
const int64_t* p = as_const_int(op->value);
ICHECK(p);
@@ -1192,7 +1194,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op,
std::ostream& os) { // NO
std::string v = PrintExpr(op->value);
PrintVecConstructor(op->dtype, os);
os << '(';
- for (int i = 0; i < op->lanes / 2; ++i) {
+ for (int i = 0; i < lanes / 2; ++i) {
if (i != 0) os << ", ";
os << "__pack_half2(" << v << ", " << v << ")";
}
@@ -1204,7 +1206,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op,
std::ostream& os) { // NO
std::string v = PrintExpr(op->value);
PrintVecConstructor(op->dtype, os);
os << '(';
- for (int i = 0; i < op->lanes / 2; ++i) {
+ for (int i = 0; i < lanes / 2; ++i) {
if (i != 0) os << ", ";
os << "__pack_nv_bfloat162(" << v << ", " << v << ")";
}
@@ -1218,7 +1220,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op,
std::ostream& os) { // NO
ICHECK(p);
int64_t v = *p & 0xF;
- if (op->lanes == 4) {
+ if (lanes == 4) {
v = (v << 12) | (v << 8) | (v << 4) | v;
if (op->dtype.is_uint()) {
os << "(uint16_t)" << v;
@@ -1227,16 +1229,16 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op,
std::ostream& os) { // NO
}
} else {
v = (v << 28) | (v << 24) | (v << 20) | (v << 16) | (v << 12) | (v << 8)
| (v << 4) | v;
- if (op->lanes == 8) {
+ if (lanes == 8) {
if (op->dtype.is_uint()) {
os << "(uint)" << v;
} else {
os << "(int)" << v;
}
- } else if (op->lanes == 16 || op->lanes == 32) {
+ } else if (lanes == 16 || lanes == 32) {
PrintVecConstructor(op->dtype, os);
os << '(';
- for (int i = 0; i < op->lanes / 8; ++i) {
+ for (int i = 0; i < lanes / 8; ++i) {
if (i != 0) os << ", ";
if (op->dtype.is_uint()) {
os << "(uint)" << v;
@@ -1258,7 +1260,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op,
std::ostream& os) { // NO
std::string v = PrintExpr(op->value);
PrintVecConstructor(op->dtype, os);
os << '(';
- for (int i = 0; i < op->lanes; ++i) {
+ for (int i = 0; i < lanes; ++i) {
if (i != 0) os << ", ";
os << v;
}
@@ -1267,7 +1269,7 @@ void CodeGenCUDA::VisitExpr_(const BroadcastNode* op,
std::ostream& os) { // NO
void CodeGenCUDA::VisitExpr_(const SelectNode* op, std::ostream& os) {
// Non-vector cases.
- if (!op->dtype.is_vector()) {
+ if (!op->dtype.is_fixed_length_vector()) {
CodeGenC::VisitExpr_(op, os);
return;
}
diff --git a/src/target/source/codegen_metal.cc
b/src/target/source/codegen_metal.cc
index 86d5956dec..e729af417c 100644
--- a/src/target/source/codegen_metal.cc
+++ b/src/target/source/codegen_metal.cc
@@ -308,9 +308,10 @@ void CodeGenMetal::VisitExpr_(const SelectNode* op,
std::ostream& os) { // NOLI
void CodeGenMetal::VisitExpr_(const BroadcastNode* op, std::ostream& os) { //
NOLINT(*)
std::string v = PrintExpr(op->value);
+ int lanes = op->dtype.lanes();
PrintType(op->dtype, os);
os << "(";
- for (int i = 0; i < op->lanes; ++i) {
+ for (int i = 0; i < lanes; ++i) {
if (i != 0) os << ", ";
os << v;
}
diff --git a/src/target/source/codegen_opencl.cc
b/src/target/source/codegen_opencl.cc
index da6a4de619..f17a452d5c 100644
--- a/src/target/source/codegen_opencl.cc
+++ b/src/target/source/codegen_opencl.cc
@@ -472,10 +472,11 @@ void CodeGenOpenCL::VisitExpr_(const CallNode* op,
std::ostream& os) {
void CodeGenOpenCL::VisitExpr_(const BroadcastNode* op, std::ostream& os) {
// NOLINT(*)
std::string v = PrintExpr(op->value);
+ int lanes = op->dtype.lanes();
os << "((";
PrintType(op->dtype, os);
os << ")(";
- for (int i = 0; i < op->lanes; ++i) {
+ for (int i = 0; i < lanes; ++i) {
if (i != 0) os << ", ";
os << v;
}
@@ -486,10 +487,11 @@ void CodeGenOpenCL::VisitExpr_(const RampNode* op,
std::ostream& os) { // NOLIN
os << "((";
PrintType(op->dtype, os);
os << ")(";
- for (int i = 0; i < op->lanes; i++) {
+ int lanes = op->dtype.lanes();
+ for (int i = 0; i < lanes; i++) {
os << "(" << PrintExpr(op->base) << ")"
<< "+(" << PrintExpr(op->stride) << "*" << i << ")";
- if (i != op->lanes - 1) os << ", ";
+ if (i != lanes - 1) os << ", ";
}
os << "))";
}
diff --git a/src/target/source/codegen_webgpu.cc
b/src/target/source/codegen_webgpu.cc
index 5ede16d2f4..a9a23fb999 100644
--- a/src/target/source/codegen_webgpu.cc
+++ b/src/target/source/codegen_webgpu.cc
@@ -344,9 +344,10 @@ void CodeGenWebGPU::PrintSSAAssign(const std::string&
target, const std::string&
void CodeGenWebGPU::VisitExpr_(const BroadcastNode* op, std::ostream& os) {
// NOLINT(*)
std::string v = PrintExpr(op->value);
+ int lanes = op->dtype.lanes();
PrintType(op->dtype, os);
os << "(";
- for (int i = 0; i < op->lanes; ++i) {
+ for (int i = 0; i < lanes; ++i) {
if (i != 0) os << ", ";
os << v;
}
diff --git a/src/target/spirv/codegen_spirv.cc
b/src/target/spirv/codegen_spirv.cc
index aca504b94b..ddbc22d88a 100644
--- a/src/target/spirv/codegen_spirv.cc
+++ b/src/target/spirv/codegen_spirv.cc
@@ -519,7 +519,8 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const CallNode* op) {
spirv::Value CodeGenSPIRV::VisitExpr_(const RampNode* op) {
std::vector<spirv::Value> values;
spirv::Value base = MakeValue(op->base);
- for (int i = 0; i < op->lanes; ++i) {
+ int lanes = op->dtype.lanes();
+ for (int i = 0; i < lanes; ++i) {
spirv::Value v = base;
if (i != 0) {
spirv::Value offset = MakeValue(make_const(op->stride.dtype(), i) *
op->stride);
@@ -533,7 +534,8 @@ spirv::Value CodeGenSPIRV::VisitExpr_(const RampNode* op) {
spirv::Value CodeGenSPIRV::VisitExpr_(const BroadcastNode* op) {
std::vector<spirv::Value> values;
spirv::Value v = MakeValue(op->value);
- for (int i = 0; i < op->lanes; i++) {
+ int lanes = op->dtype.lanes();
+ for (int i = 0; i < lanes; i++) {
values.push_back(v);
}
return builder_->Concat(values);
diff --git a/src/tir/ir/expr.cc b/src/tir/ir/expr.cc
index 41500051fa..1b611d4534 100644
--- a/src/tir/ir/expr.cc
+++ b/src/tir/ir/expr.cc
@@ -21,10 +21,14 @@
* \file expr.cc
*/
#include <tvm/runtime/registry.h>
+#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
+#include <optional>
+
+#include "../../arith/scalable_expression.h"
#include "../../support/str_escape.h"
#include "buffer_common.h"
@@ -427,18 +431,28 @@ TVM_REGISTER_GLOBAL("tir.Select")
TVM_REGISTER_NODE_TYPE(SelectNode);
// Ramp
-Ramp::Ramp(PrimExpr base, PrimExpr stride, int lanes, Span span) {
+Ramp::Ramp(PrimExpr base, PrimExpr stride, PrimExpr lanes, Span span) {
ICHECK(base.defined());
ICHECK(stride.defined());
ICHECK(base.dtype().is_scalar());
ICHECK(stride.dtype().is_scalar());
- ICHECK_GT(lanes, 1);
if (stride.dtype() != base.dtype()) {
stride = cast(base.dtype(), stride);
}
ObjectPtr<RampNode> node = make_object<RampNode>();
- node->dtype = base.dtype().with_lanes(lanes);
+ auto* lanes_as_int = lanes.as<IntImmNode>();
+ if (lanes_as_int) {
+ int lanes = static_cast<int>(lanes_as_int->value);
+ ICHECK_GT(lanes, 1);
+ node->dtype = base.dtype().with_lanes(lanes);
+ } else { /* scalable vector */
+ std::optional<int> vscale_factor = arith::ExtractVscaleFactor(lanes);
+ ICHECK(vscale_factor) << "Invalid expression for scalable lanes " << lanes;
+
+ node->dtype =
base.dtype().with_scalable_vscale_factor(vscale_factor.value());
+ lanes = Mul(Call(DataType::Int(32), tir::builtin::vscale(), {}),
vscale_factor.value());
+ }
node->base = base;
node->stride = stride;
node->lanes = lanes;
@@ -447,27 +461,37 @@ Ramp::Ramp(PrimExpr base, PrimExpr stride, int lanes,
Span span) {
}
TVM_REGISTER_GLOBAL("tir.Ramp")
- .set_body_typed([](PrimExpr base, PrimExpr stride, int lanes, Span span) {
+ .set_body_typed([](PrimExpr base, PrimExpr stride, PrimExpr lanes, Span
span) {
return Ramp(base, stride, lanes, span);
});
TVM_REGISTER_NODE_TYPE(RampNode);
// Broadcast
-Broadcast::Broadcast(PrimExpr value, int lanes, Span span) {
+Broadcast::Broadcast(PrimExpr value, PrimExpr lanes, Span span) {
ICHECK(value.defined());
ICHECK(value.dtype().is_scalar());
- ICHECK_GT(lanes, 1);
ObjectPtr<BroadcastNode> node = make_object<BroadcastNode>();
- node->dtype = value.dtype().with_lanes(lanes);
+ auto* lanes_int = lanes.as<IntImmNode>();
+ if (lanes_int) {
+ int lanes = static_cast<int>(lanes_int->value);
+ ICHECK_GT(lanes, 1);
+ node->dtype = value.dtype().with_lanes(lanes);
+ } else { /* scalable vector */
+ std::optional<int> vscale_factor = arith::ExtractVscaleFactor(lanes);
+ ICHECK(vscale_factor) << "Invalid expression for scalable lanes " << lanes;
+
+ node->dtype =
value.dtype().with_scalable_vscale_factor(vscale_factor.value());
+ lanes = Mul(Call(DataType::Int(32), tir::builtin::vscale(), {}),
vscale_factor.value());
+ }
node->value = std::move(value);
node->lanes = lanes;
node->span = std::move(span);
data_ = node;
}
-TVM_REGISTER_GLOBAL("tir.Broadcast").set_body_typed([](PrimExpr value, int
lanes, Span span) {
+TVM_REGISTER_GLOBAL("tir.Broadcast").set_body_typed([](PrimExpr value,
PrimExpr lanes, Span span) {
return Broadcast(value, lanes, span);
});
@@ -525,8 +549,8 @@ TVM_REGISTER_GLOBAL("tir.Call")
for (Range r : br->region) {
if (is_one(r->extent)) {
indices.push_back(r->min);
- } else if (const auto* extent = r->extent.as<IntImmNode>()) {
- indices.push_back(tir::Ramp(r->min, make_const(r->min->dtype,
1), extent->value));
+ } else if (r->extent.as<IntImmNode>()) {
+ indices.push_back(tir::Ramp(r->min, make_const(r->min->dtype,
1), r->extent));
} else {
LOG(FATAL) << "ValueError: Cannot convert to BufferLoad: "
<< GetRef<BufferRegion>(br);
@@ -714,10 +738,26 @@ void BufferLoadNode::LegalizeDType() {
<< "Only the last index of a buffer access may be a vector type.";
}
- int index_lanes = indices.size() ? indices.back().dtype().lanes() : 1;
- int buffer_lanes = buffer->dtype.lanes();
-
- this->dtype = buffer->dtype.with_lanes(index_lanes * buffer_lanes);
+ if (indices.empty()) {
+ this->dtype = buffer->dtype;
+ } else {
+ auto index_dtype = indices.back().dtype();
+ bool is_buffer_dtype_scalable = buffer->dtype.is_scalable_vector();
+ bool is_index_scalable = index_dtype.is_scalable_vector();
+
+ ICHECK(!(is_index_scalable && is_buffer_dtype_scalable))
+ << "Index dtype and buffer dtype can't both be scalable.";
+
+ if (is_index_scalable) {
+ this->dtype =
buffer->dtype.with_scalable_vscale_factor(index_dtype.vscale_factor() *
+
buffer->dtype.lanes());
+ } else if (is_buffer_dtype_scalable) {
+ this->dtype =
buffer->dtype.with_scalable_vscale_factor(buffer->dtype.vscale_factor() *
+
index_dtype.lanes());
+ } else {
+ this->dtype = buffer->dtype.with_lanes(index_dtype.lanes() *
buffer->dtype.lanes());
+ }
+ }
}
BufferLoad::BufferLoad(Buffer buffer, Array<PrimExpr> indices, Span span) {
diff --git a/src/tir/ir/expr_functor.cc b/src/tir/ir/expr_functor.cc
index 8a93d9dd82..089a1d31e7 100644
--- a/src/tir/ir/expr_functor.cc
+++ b/src/tir/ir/expr_functor.cc
@@ -258,19 +258,21 @@ PrimExpr ExprMutator::VisitExpr_(const SelectNode* op) {
PrimExpr ExprMutator::VisitExpr_(const RampNode* op) {
PrimExpr base = this->VisitExpr(op->base);
PrimExpr stride = this->VisitExpr(op->stride);
- if (base.same_as(op->base) && stride.same_as(op->stride)) {
+ PrimExpr lanes = this->VisitExpr(op->lanes);
+ if (base.same_as(op->base) && stride.same_as(op->stride) &&
lanes.same_as(op->lanes)) {
return GetRef<PrimExpr>(op);
} else {
- return Ramp(base, stride, op->lanes);
+ return Ramp(base, stride, lanes);
}
}
PrimExpr ExprMutator::VisitExpr_(const BroadcastNode* op) {
PrimExpr value = this->VisitExpr(op->value);
- if (value.same_as(op->value)) {
+ PrimExpr lanes = this->VisitExpr(op->lanes);
+ if (value.same_as(op->value) && lanes.same_as(op->lanes)) {
return GetRef<PrimExpr>(op);
} else {
- return Broadcast(value, op->lanes);
+ return Broadcast(value, lanes);
}
}
diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc
index 1d1e674a9d..4774471afc 100644
--- a/src/tir/ir/stmt.cc
+++ b/src/tir/ir/stmt.cc
@@ -469,17 +469,47 @@ BufferStore::BufferStore(Buffer buffer, PrimExpr value,
Array<PrimExpr> indices,
<< "Only the last index of a buffer access may be a vector type.";
}
- int index_lanes = indices.size() ? indices.back().dtype().lanes() : 1;
- int buffer_lanes = buffer->dtype.lanes();
+ bool is_index_scalable = indices.empty() ? false :
indices.back().dtype().is_scalable_vector();
+ bool is_buffer_dtype_scalable = buffer->dtype.is_scalable_vector();
+ bool is_value_dtype_scalable = value.dtype().is_scalable_vector();
- ICHECK_EQ(index_lanes * buffer_lanes, value.dtype().lanes())
- << "Cannot store value with " << value.dtype().lanes() << ", expected
value with "
+ ICHECK(!(is_index_scalable && is_buffer_dtype_scalable))
+ << "Index dtype and buffer dtype can't both be scalable.";
+
+ if (is_index_scalable || is_buffer_dtype_scalable) {
+ ICHECK(is_value_dtype_scalable) << "Can't store non-scalable data into
scalable buffer";
+ }
+
+ int index_lanes;
+ if (indices.empty()) {
+ index_lanes = 1;
+ } else if (is_index_scalable) {
+ index_lanes = indices.back().dtype().vscale_factor();
+ } else {
+ index_lanes = indices.back().dtype().lanes();
+ }
+
+ int buffer_lanes =
+ is_buffer_dtype_scalable ? buffer->dtype.vscale_factor() :
buffer->dtype.lanes();
+ int value_dtype_lanes =
+ is_value_dtype_scalable ? value.dtype().vscale_factor() :
value.dtype().lanes();
+
+ ICHECK_EQ(index_lanes * buffer_lanes, value_dtype_lanes)
+ << "Cannot store value with " << value_dtype_lanes << ", expected value
with "
<< index_lanes * buffer_lanes << " (" << index_lanes << " index lanes *
" << buffer_lanes
<< " buffer element lanes)";
- if (buffer->dtype.with_lanes(buffer_lanes * index_lanes) != value.dtype()) {
+
+ runtime::DataType buffer_dtype;
+ if (is_index_scalable || is_buffer_dtype_scalable) {
+ buffer_dtype = buffer->dtype.with_scalable_vscale_factor(buffer_lanes *
index_lanes);
+ } else {
+ buffer_dtype = buffer->dtype.with_lanes(buffer_lanes * index_lanes);
+ }
+ if (buffer_dtype != value.dtype()) {
LOG(FATAL) << "TypeError: dtype mismatch on BufferStore: " //
<< "buffer's dtype is `" << buffer->dtype //
<< "`, the lanes of indexing are: `" << index_lanes //
+ << "`, the scalability is: `" <<
buffer_dtype.is_scalable_vector()
<< "`, but RHS's dtype is `" << value.dtype() << "`";
}
diff --git a/src/tir/ir/tir_visitor_with_path.cc
b/src/tir/ir/tir_visitor_with_path.cc
index e0996cd72f..a80f2300e2 100644
--- a/src/tir/ir/tir_visitor_with_path.cc
+++ b/src/tir/ir/tir_visitor_with_path.cc
@@ -429,6 +429,7 @@ void TIRVisitorWithPath::VisitExpr_(const SelectNode* op,
ObjectPath path) {
void TIRVisitorWithPath::VisitExpr_(const RampNode* op, ObjectPath path) {
Visit(op->base, path->Attr("base"));
Visit(op->stride, path->Attr("stride"));
+ Visit(op->lanes, path->Attr("lanes"));
}
void TIRVisitorWithPath::VisitExpr_(const ShuffleNode* op, ObjectPath path) {
@@ -438,6 +439,7 @@ void TIRVisitorWithPath::VisitExpr_(const ShuffleNode* op,
ObjectPath path) {
void TIRVisitorWithPath::VisitExpr_(const BroadcastNode* op, ObjectPath path) {
Visit(op->value, path->Attr("value"));
+ Visit(op->lanes, path->Attr("lanes"));
}
} // namespace tir
diff --git a/src/tir/op/op.cc b/src/tir/op/op.cc
index 9f35f73a62..b329d25b54 100644
--- a/src/tir/op/op.cc
+++ b/src/tir/op/op.cc
@@ -32,6 +32,7 @@
#include <cmath>
// Centralized header for constant folders.
#include "../../arith/const_fold.h"
+#include "../../arith/scalable_expression.h"
#include "../../target/datatype/registry.h"
namespace tvm {
@@ -122,20 +123,45 @@ PrimExpr q_multiply_shift(PrimExpr x, PrimExpr y,
PrimExpr q, PrimExpr s, Span s
{x, y, q, s}, span);
}
+void BroadcastToMatchLanes(PrimExpr& op_a, PrimExpr& op_b) { // NOLINT(*)
+ DataType dtype_a = op_a.dtype();
+ DataType dtype_b = op_b.dtype();
+
+ if (!dtype_a.is_scalable_or_fixed_length_vector() &&
+ dtype_b.is_scalable_or_fixed_length_vector()) {
+ if (dtype_b.is_scalable_vector()) {
+ op_a = tir::Broadcast(
+ op_a, tir::Mul(dtype_b.vscale_factor(), Call(DataType::Int(32),
builtin::vscale(), {})));
+ } else {
+ op_a = tir::Broadcast(op_a, dtype_b.lanes());
+ }
+ }
+}
+
// The public function with a quick checking path.
void BinaryOpMatchTypes(PrimExpr& lhs, PrimExpr& rhs, Span span) { //
NOLINT(*)
CHECK(lhs.defined()) << "ValueError: `lhs` is null in the binary operator";
CHECK(rhs.defined()) << "ValueError: `rhs` is null in the binary operator";
if (lhs.dtype() == rhs.dtype()) return;
+
+ BroadcastToMatchLanes(lhs, rhs);
+ BroadcastToMatchLanes(rhs, lhs);
+
DataType ltype = lhs.dtype();
DataType rtype = rhs.dtype();
- if (ltype.lanes() == 1 && rtype.lanes() != 1) {
- lhs = tir::Broadcast(lhs, rtype.lanes());
- } else if (rtype.lanes() == 1 && ltype.lanes() != 1) {
- rhs = tir::Broadcast(rhs, ltype.lanes());
+
+ ICHECK(ltype.is_scalable_vector() == rtype.is_scalable_vector())
+ << "Can't match scalable and fixed length vectors";
+
+ bool lanes_match = false;
+
+ if (ltype.is_scalable_vector()) {
+ lanes_match = ltype.vscale_factor() == rtype.vscale_factor();
} else {
- ICHECK(ltype.lanes() == rtype.lanes()) << "Cannot match type " << ltype <<
" vs " << rtype;
+ lanes_match = ltype.lanes() == rtype.lanes();
}
+
+ ICHECK(lanes_match) << "Cannot match type " << ltype << " vs " << rtype;
if (lhs.dtype() == rhs.dtype()) return;
ltype = lhs.dtype();
@@ -326,7 +352,7 @@ PrimExpr cast(const DataType& t, PrimExpr value, Span span)
{
return tir::Cast(t, value, span);
} else {
DataType vtype = t.element_of();
- if (value.dtype().lanes() == 1) {
+ if (!value.dtype().is_scalable_or_fixed_length_vector()) {
// manually unroll cast
if (value.dtype() != vtype) {
if (const IntImmNode* op = value.as<IntImmNode>()) {
@@ -337,11 +363,25 @@ PrimExpr cast(const DataType& t, PrimExpr value, Span
span) {
value = tir::Cast(vtype, value, span);
}
}
- return tir::Broadcast(value, t.lanes(), span);
- } else {
- ICHECK(value.dtype().lanes() == t.lanes());
+ if (t.is_scalable_vector()) {
+ return tir::Broadcast(
+ value, tir::Mul(t.vscale_factor(), Call(DataType::Int(32),
builtin::vscale(), {})),
+ span);
+ } else {
+ return tir::Broadcast(value, t.lanes(), span);
+ }
+ } else { /* value is a vector */
+ ICHECK(value.dtype().is_scalable_vector() == t.is_scalable_vector());
+
+ bool lanes_match = false;
+ if (value.dtype().is_scalable_vector()) {
+ lanes_match = value.dtype().vscale_factor() == t.vscale_factor();
+ } else {
+ lanes_match = value.dtype().lanes() == t.lanes();
+ }
+ ICHECK(lanes_match);
if (const auto* broadcast = value.as<tir::BroadcastNode>()) {
- return tir::Broadcast(cast(vtype, broadcast->value, span), t.lanes(),
span);
+ return tir::Broadcast(cast(vtype, broadcast->value, span),
broadcast->lanes, span);
} else if (const auto* ramp = value.as<tir::RampNode>()) {
if (t.is_int() || t.is_uint()) {
// only cast to index data type can be folded to ramp
@@ -534,6 +574,7 @@ PrimExpr operator==(PrimExpr a, PrimExpr b) { return
equal(a, b); }
PrimExpr equal(PrimExpr a, PrimExpr b, Span span) {
BinaryOpMatchTypes(a, b, span);
if (auto ret = arith::TryConstFold<tir::EQ>(a, b)) return ret.value();
+ if (arith::IsVScaleCall(a) && arith::IsVScaleCall(b)) return true;
return tir::EQ(a, b, span);
}
diff --git a/src/tir/schedule/analysis/reducer.cc
b/src/tir/schedule/analysis/reducer.cc
index d8d1e8fc25..5f8af84186 100644
--- a/src/tir/schedule/analysis/reducer.cc
+++ b/src/tir/schedule/analysis/reducer.cc
@@ -178,16 +178,14 @@ class PatternMatcher : public ExprVisitor {
if (ptr == nullptr) {
match_success_ = false;
} else {
- if (op->lanes != ptr->lanes) {
- match_success_ = false;
- } else {
- PrimExpr tmp = expr_to_match_;
- expr_to_match_ = ptr->base;
- VisitExpr(op->base);
- expr_to_match_ = ptr->stride;
- VisitExpr(op->stride);
- std::swap(expr_to_match_, tmp);
- }
+ PrimExpr tmp = expr_to_match_;
+ expr_to_match_ = ptr->base;
+ VisitExpr(op->base);
+ expr_to_match_ = ptr->stride;
+ VisitExpr(op->stride);
+ expr_to_match_ = ptr->lanes;
+ VisitExpr(op->lanes);
+ std::swap(expr_to_match_, tmp);
}
}
@@ -196,14 +194,12 @@ class PatternMatcher : public ExprVisitor {
if (ptr == nullptr) {
match_success_ = false;
} else {
- if (op->lanes != ptr->lanes) {
- match_success_ = false;
- } else {
- PrimExpr tmp = expr_to_match_;
- expr_to_match_ = ptr->value;
- VisitExpr(op->value);
- std::swap(expr_to_match_, tmp);
- }
+ PrimExpr tmp = expr_to_match_;
+ expr_to_match_ = ptr->value;
+ VisitExpr(op->value);
+ expr_to_match_ = ptr->lanes;
+ VisitExpr(op->lanes);
+ std::swap(expr_to_match_, tmp);
}
}
diff --git a/src/tir/transforms/bound_checker.cc
b/src/tir/transforms/bound_checker.cc
index f5aa6773e6..358f864d3a 100644
--- a/src/tir/transforms/bound_checker.cc
+++ b/src/tir/transforms/bound_checker.cc
@@ -34,6 +34,8 @@
#include <utility>
#include <vector>
+#include "../../arith/unwrap_vector_expr.h"
+
namespace tvm {
namespace tir {
@@ -156,7 +158,12 @@ class BoundChecker : public StmtExprMutator {
if (!IsValidScalar(ramp_index->stride)) {
return false;
}
- if (ramp_index->lanes <= 0) {
+ bool lanes_int = ramp_index->lanes->IsInstance<IntImmNode>();
+ if (!lanes_int) {
+ return false;
+ }
+ int lanes = static_cast<int>(Downcast<IntImm>(ramp_index->lanes)->value);
+ if (lanes <= 0) {
return false;
}
}
@@ -192,11 +199,7 @@ class BoundChecker : public StmtExprMutator {
PrimExpr upper_bound = shape[i];
if (const RampNode* ramp_index = index.as<RampNode>()) {
- // In case index is base + stride * i.
- // Non inclusive range.
- index = Add(ramp_index->base,
- Mul(ramp_index->stride,
- make_const(ramp_index->stride.dtype(), ramp_index->lanes - 1)));
+ index = arith::UnwrapVectorExpr(GetRef<Ramp>(ramp_index), ramp_index->lanes);
}
// Try to simplify index and bound.
diff --git a/src/tir/transforms/lower_thread_allreduce.cc
b/src/tir/transforms/lower_thread_allreduce.cc
index 7094d6adaf..37d8f67580 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -730,7 +730,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator
{
// rocm only supports 32 bit operands for shuffling at the moment
if ((target_->kind->name == "rocm") &&
(std::any_of(types.begin(), types.end(), [](DataType ty) {
- if (ty.is_vector()) return ty.bits() * ty.lanes() != 32;
+ if (ty.is_fixed_length_vector()) return ty.bits() * ty.lanes() != 32;
return ty.bits() != 32;
}))) {
return false;
@@ -740,7 +740,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator
{
// {u}int, {u}long, {u}long long, float, double, half/half2
if (std::any_of(types.begin(), types.end(), [](DataType ty) {
if (ty.is_float16()) return ty.lanes() > 2;
- if (ty.is_vector()) return true;
+ if (ty.is_fixed_length_vector()) return true;
return ty.bytes() < 4 || ty.bytes() > 8;
})) {
return false;
diff --git a/src/tir/transforms/renormalize_split_pattern.cc
b/src/tir/transforms/renormalize_split_pattern.cc
index eb596beb18..beb5997d49 100644
--- a/src/tir/transforms/renormalize_split_pattern.cc
+++ b/src/tir/transforms/renormalize_split_pattern.cc
@@ -63,7 +63,7 @@ class SplitPatternReNormalizer : public IRMutatorWithAnalyzer
{
// Pattern var match IntImm
PVar<IntImm> c1, c2, c3;
// Pattern var for lanes in broadcast and ramp
- PVar<int> lanes;
+ PVar<PrimExpr> lanes;
// floordiv(floormod(x, c1 * c2), c2) = floormod(floordiv(x, c2), c1)
TRY_RECURSIVE_REWRITE_IF(floordiv(floormod(x, c3), c2),
diff --git a/src/tir/transforms/storage_rewrite.cc
b/src/tir/transforms/storage_rewrite.cc
index dd27397f36..e40f683e21 100644
--- a/src/tir/transforms/storage_rewrite.cc
+++ b/src/tir/transforms/storage_rewrite.cc
@@ -1327,9 +1327,12 @@ class VectorTypeAccessChecker : public StmtExprVisitor {
if (indices.size()) {
const RampNode* ramp_index = indices[indices.size() - 1].as<RampNode>();
if (ramp_index && is_one(ramp_index->stride)) {
- arith::ModularSet me = analyzer_.modular_set(ramp_index->base);
- if ((me->coeff % ramp_index->lanes == 0) && (me->base % ramp_index->lanes == 0)) {
- lanes_used = ramp_index->lanes;
+ if (ramp_index->lanes->IsInstance<IntImmNode>()) {
+ int lanes = static_cast<int>(Downcast<IntImm>(ramp_index->lanes)->value);
+ arith::ModularSet me = analyzer_.modular_set(ramp_index->base);
+ if ((me->coeff % lanes == 0) && (me->base % lanes == 0)) {
+ lanes_used = lanes;
+ }
}
}
}
@@ -1462,13 +1465,13 @@ class VectorTypeRewriter : public StmtExprMutator {
Array<PrimExpr> indices = node->indices;
const PrimExpr& last_dim_index = indices[indices.size() - 1];
- if (const RampNode* ramp_index = last_dim_index.as<RampNode>();
- ramp_index && is_one(ramp_index->stride)) {
- PrimExpr new_index =
- ramp_index->base / make_const(ramp_index->base.dtype(),
ramp_index->lanes);
- if (ramp_index->lanes != info.factor()) {
- ICHECK(info.factor() && ramp_index->lanes % info.factor() == 0);
- int new_lanes = ramp_index->lanes / info.factor();
+ const RampNode* ramp_index = indices[indices.size() - 1].as<RampNode>();
+ if (ramp_index && is_one(ramp_index->stride) && ramp_index->lanes->IsInstance<IntImmNode>()) {
+ int lanes = static_cast<int>(Downcast<IntImm>(ramp_index->lanes)->value);
+ PrimExpr new_index = ramp_index->base / make_const(ramp_index->base.dtype(), lanes);
+ if (lanes != info.factor()) {
+ ICHECK(info.factor() && lanes % info.factor() == 0);
+ int new_lanes = lanes / info.factor();
new_index = Ramp(new_index * new_lanes, ramp_index->stride, new_lanes,
ramp_index->span);
}
indices.Set(indices.size() - 1, new_index);
diff --git a/src/tir/transforms/vectorize_loop.cc
b/src/tir/transforms/vectorize_loop.cc
index b80a71aa31..fe589bede6 100644
--- a/src/tir/transforms/vectorize_loop.cc
+++ b/src/tir/transforms/vectorize_loop.cc
@@ -38,10 +38,13 @@
namespace tvm {
namespace tir {
+// TODO(ekalda): P5 in https://github.com/apache/tvm/issues/16455
inline PrimExpr BroadcastTo(PrimExpr e, int lanes) {
if (e.dtype().lanes() == lanes) return e;
if (const BroadcastNode* op = e.as<BroadcastNode>()) {
- if (lanes % op->lanes == 0) {
+ ICHECK(!e.dtype().is_scalable_vector());
+ int broadcast_lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
+ if (lanes % broadcast_lanes == 0) {
return Broadcast(op->value, lanes);
}
}
@@ -180,15 +183,18 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
if (a.same_as(op->a) && b.same_as(op->b)) {
return GetRef<PrimExpr>(op);
} else {
+ // TODO(ekalda): P5 in https://github.com/apache/tvm/issues/16455
int lanes = std::max(a.dtype().lanes(), b.dtype().lanes());
if (lanes != 1) {
const RampNode* b_ramp = b.as<RampNode>();
const RampNode* a_ramp = a.as<RampNode>();
if (a_ramp && b.dtype().lanes() == 1 && analyzer_.CanProve(b > 0)) {
- return Ramp(a_ramp->base * b, a_ramp->stride * b, a_ramp->lanes);
+ int lanes = static_cast<int>(Downcast<IntImm>(a_ramp->lanes)->value);
+ return Ramp(a_ramp->base * b, a_ramp->stride * b, lanes);
}
if (b_ramp && a.dtype().lanes() == 1 && analyzer_.CanProve(a > 0)) {
- return Ramp(b_ramp->base * a, b_ramp->stride * a, b_ramp->lanes);
+ int lanes = static_cast<int>(Downcast<IntImm>(b_ramp->lanes)->value);
+ return Ramp(b_ramp->base * a, b_ramp->stride * a, lanes);
}
}
return Mul(BroadcastTo(a, lanes), BroadcastTo(b, lanes));
@@ -222,10 +228,13 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
PrimExpr VisitExpr_(const RampNode* op) final {
PrimExpr base = this->VisitExpr(op->base);
PrimExpr stride = this->VisitExpr(op->stride);
+ // TODO(ekalda): P5 in https://github.com/apache/tvm/issues/16455
+ int op_lanes = static_cast<int>(Downcast<IntImm>(op->lanes)->value);
if (base.dtype().lanes() > 1 && stride.dtype().lanes() == 1) {
const RampNode* base_ramp = base.as<RampNode>();
- if (analyzer_.CanProve(base_ramp->stride == stride *
make_const(stride.dtype(), op->lanes))) {
- return Ramp(base_ramp->base, stride, op->lanes * base_ramp->lanes);
+ int base_ramp_lanes =
static_cast<int>(Downcast<IntImm>(base_ramp->lanes)->value);
+ if (analyzer_.CanProve(base_ramp->stride == stride *
make_const(stride.dtype(), op_lanes))) {
+ return Ramp(base_ramp->base, stride, op_lanes * base_ramp_lanes);
}
}
int lanes = std::max(base.dtype().lanes(), stride.dtype().lanes());
@@ -295,7 +304,7 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
// IfThenElse expr
PrimExpr MutateIfThenElseExpr_(const CallNode* op) {
PrimExpr cond = this->VisitExpr(op->args[0]);
- if (cond.dtype().is_vector()) {
+ if (cond.dtype().is_scalable_or_fixed_length_vector()) {
need_scalarize_ = true;
return GetRef<PrimExpr>(op);
}
@@ -337,7 +346,7 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
Array<PrimExpr> new_args;
for (auto arg : op->args) {
auto new_arg = this->VisitExpr(arg);
- if (new_arg.dtype().is_vector()) {
+ if (new_arg.dtype().is_scalable_or_fixed_length_vector()) {
need_scalarize_ = true;
return GetRef<PrimExpr>(op);
}
@@ -449,9 +458,9 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
LOG(WARNING) << "Detect vectorize inside vectorized loop, ignoring...";
}
ICHECK(is_zero(op->min));
- ICHECK(!op->extent.dtype().is_vector());
+ ICHECK(!op->extent.dtype().is_scalable_or_fixed_length_vector());
PrimExpr extent = this->VisitExpr(op->extent);
- if (extent.dtype().is_vector()) {
+ if (extent.dtype().is_scalable_or_fixed_length_vector()) {
return Scalarize(GetRef<Stmt>(op));
}
Stmt body = this->VisitStmt(op->body);
@@ -464,9 +473,9 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
}
// IfThenElse
Stmt VisitStmt_(const IfThenElseNode* op) final {
- ICHECK(!op->condition.dtype().is_vector());
+ ICHECK(!op->condition.dtype().is_scalable_or_fixed_length_vector());
PrimExpr condition = this->VisitExpr(op->condition);
- if (condition.dtype().is_vector()) {
+ if (condition.dtype().is_scalable_or_fixed_length_vector()) {
return Scalarize(GetRef<Stmt>(op));
}
Stmt then_case = this->VisitStmt(op->then_case);
@@ -509,7 +518,7 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
Stmt VisitStmt_(const AllocateNode* op) final {
// Mutate the condition
PrimExpr condition = this->VisitExpr(op->condition);
- if (condition.dtype().is_vector()) {
+ if (condition.dtype().is_scalable_or_fixed_length_vector()) {
LOG(WARNING) << "Cannot handle vector extent in alloc of " <<
op->buffer_var->name_hint;
return Scalarize(GetRef<Stmt>(op));
}
@@ -518,7 +527,7 @@ class Vectorizer : public StmtMutator, public
ExprFunctor<PrimExpr(const PrimExp
Array<PrimExpr> extents;
for (const auto& extent : op->extents) {
PrimExpr new_ext = this->VisitExpr(extent);
- if (new_ext.dtype().is_vector()) {
+ if (new_ext.dtype().is_scalable_or_fixed_length_vector()) {
LOG(WARNING) << "Cannot handle vector extent in alloc of " <<
op->buffer_var->name_hint;
return Scalarize(GetRef<Stmt>(op));
}
diff --git a/tests/cpp/pattern_match_test.cc b/tests/cpp/pattern_match_test.cc
index 2e386c48b7..70f806e241 100644
--- a/tests/cpp/pattern_match_test.cc
+++ b/tests/cpp/pattern_match_test.cc
@@ -27,9 +27,11 @@ TEST(Pattern, Basic) {
using namespace tvm::tir;
using namespace tvm::arith;
tvm::tir::Var x("x"), y("y"), z("z");
+ PrimExpr scalable_lanes = Mul(Call(DataType::Int(32), builtin::vscale(),
{}), 4);
arith::PVar<PrimExpr> px, py, pz;
arith::PVar<DataType> pt;
- arith::PVar<int> planes;
+ arith::PVar<PrimExpr> planes;
+ arith::PCallExpr<PVscaleOp> vscale;
// arithmetics
auto r = 1 + (y + 1);
@@ -110,14 +112,18 @@ TEST(Pattern, Basic) {
// ramp pattern
{
ICHECK(ramp(px, PConst<PrimExpr>(1), planes).Match(tir::Ramp(x, 1, 10)));
- ICHECK(planes.Eval() == 10);
+ ICHECK(planes.Eval().as<IntImmNode>()->value == 10);
+ ICHECK(ramp(px, PConst<PrimExpr>(1), planes).Match(tir::Ramp(x, 1,
scalable_lanes)));
+ ICHECK((vscale * PConst<PrimExpr>(4)).Match(planes.Eval()));
ICHECK(!ramp(px, PConst<PrimExpr>(1), planes).Match(tir::Ramp(x, 2, 10)));
}
// broadcast pattern
{
ICHECK(broadcast(px, planes).Match(tir::Broadcast(x, 10)));
- ICHECK(planes.Eval() == 10);
+ ICHECK(planes.Eval().as<IntImmNode>()->value == 10);
ICHECK(broadcast(px * py, planes).Match(tir::Broadcast(x * 10, 10)));
+ ICHECK(broadcast(px, planes).Match(tir::Broadcast(x, scalable_lanes)));
+ ICHECK((vscale * PConst<PrimExpr>(4)).Match(planes.Eval()));
}
}
diff --git a/tests/cpp/tir_scalable_datatype.cc
b/tests/cpp/tir_scalable_datatype.cc
new file mode 100644
index 0000000000..daa4dfe729
--- /dev/null
+++ b/tests/cpp/tir_scalable_datatype.cc
@@ -0,0 +1,125 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <gmock/gmock.h>
+#include <gtest/gtest.h>
+#include <llvm/IR/Intrinsics.h>
+#include <tvm/runtime/data_type.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/expr.h>
+
+using ::testing::HasSubstr;
+
+// ---------
+// Data Type
+// ---------
+TEST(TIR, TestCreateScalableType) {
+ tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 4, true);
+ ASSERT_EQ(scalable_type.code(), kDLInt);
+ ASSERT_EQ(scalable_type.bits(), 32);
+ ASSERT_EQ(scalable_type.vscale_factor(), 4);
+ ASSERT_TRUE(scalable_type.is_scalable_vector());
+ ASSERT_TRUE(scalable_type.is_scalable_or_fixed_length_vector());
+}
+
+TEST(TIR, TestScalableWithBits) {
+ tvm::DataType scalable_type = tvm::DataType(kDLInt, 1, 8, true);
+ scalable_type = scalable_type.with_bits(32);
+ ASSERT_EQ(scalable_type.bits(), 32);
+ ASSERT_TRUE(scalable_type.is_scalable_vector());
+ ASSERT_TRUE(scalable_type.is_scalable_or_fixed_length_vector());
+}
+
+TEST(TIR, TestScalableWithVscaleFactor) {
+ tvm::DataType type = tvm::DataType(kDLInt, 32, 1);
+ tvm::DataType scalable_type = type.with_scalable_vscale_factor(4);
+ ASSERT_EQ(scalable_type.vscale_factor(), 4);
+ ASSERT_TRUE(scalable_type.is_scalable_vector());
+ ASSERT_TRUE(scalable_type.is_scalable_or_fixed_length_vector());
+}
+
+TEST(TIR, TestAssignScalableDataType) {
+ tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 2, true);
+ tvm::DataType scalable_type_copy = scalable_type;
+ ASSERT_TRUE(scalable_type_copy.is_scalable_vector());
+ ASSERT_TRUE(scalable_type_copy.is_scalable_or_fixed_length_vector());
+}
+
+TEST(TIR, TestScalableDataTypeAndNonScalableDataTypeInequality) {
+ ASSERT_FALSE(tvm::DataType(kDLInt, 32, 4, true) == tvm::DataType(kDLInt, 32,
4));
+}
+
+TEST(TIR, TestGetScalableVectorBytesError) {
+ tvm::DataType scalable_type = tvm::DataType(kDLInt, 32, 4, true);
+ EXPECT_THROW(
+ {
+ try {
+ tvm::runtime::GetVectorBytes(scalable_type);
+ } catch (const tvm::InternalError& e) {
+ EXPECT_THAT(e.what(),
+ HasSubstr("Can't fetch the lanes of a scalable vector at
a compile time"));
+ throw;
+ }
+ },
+ tvm::InternalError);
+}
+
+TEST(TIR, TestScalableDataTypeInvalidLanesError) {
+ EXPECT_THROW(
+ {
+ try {
+ tvm::DataType(kDLFloat, 62, 1, true);
+ } catch (const tvm::InternalError& e) {
+ EXPECT_THAT(e.what(), HasSubstr("Invalid value for vscale factor"));
+ throw;
+ }
+ },
+ tvm::InternalError);
+}
+
+TEST(TIR, TestScalableDataTypeInvalidVscaleFactorAccess) {
+ tvm::DataType fixed_length_type = tvm::DataType(kDLFloat, 32, 4);
+ ASSERT_TRUE(fixed_length_type.is_fixed_length_vector());
+ ASSERT_TRUE(fixed_length_type.is_scalable_or_fixed_length_vector());
+ EXPECT_THROW(
+ {
+ try {
+ fixed_length_type.vscale_factor();
+ } catch (const tvm::InternalError& e) {
+ EXPECT_THAT(e.what(), HasSubstr("A fixed length vector doesn't have
a vscale factor"));
+ throw;
+ }
+ },
+ tvm::InternalError);
+}
+
+TEST(TIR, TestScalableDataTypeInvalidLanesAccess) {
+ tvm::DataType scalable_type = tvm::DataType(kDLFloat, 32, 4, true);
+ EXPECT_THROW(
+ {
+ try {
+ scalable_type.lanes();
+ } catch (const tvm::InternalError& e) {
+ EXPECT_THAT(e.what(),
+ HasSubstr("Can't fetch the lanes of a scalable vector at
a compile time"));
+ throw;
+ }
+ },
+ tvm::InternalError);
+}
diff --git a/tests/python/arith/test_arith_intset.py
b/tests/python/arith/test_arith_intset.py
index 5b99115148..18865a73df 100644
--- a/tests/python/arith/test_arith_intset.py
+++ b/tests/python/arith/test_arith_intset.py
@@ -54,6 +54,14 @@ def test_vector():
assert s.max_value.value == base + stride * (lanes - 1)
+def test_scalable_vector():
+ base = 5
+ s = tvm.arith.IntSet.vector(tvm.tir.Ramp(base, 2, tvm.tir.vscale() * 4))
+
+ assert s.min_value.value == base
+ assert s.max_value.same_as(tvm.arith.int_set.pos_inf())
+
+
def test_add_sub():
ck = IntSetChecker()
x, y = te.var("x"), te.var("y")
diff --git a/tests/python/arith/test_arith_rewrite_simplify.py
b/tests/python/arith/test_arith_rewrite_simplify.py
index 5b06275422..6433dc2dec 100644
--- a/tests/python/arith/test_arith_rewrite_simplify.py
+++ b/tests/python/arith/test_arith_rewrite_simplify.py
@@ -80,8 +80,15 @@ class TestVector(BaseCompare):
TestCase(tvm.tir.Ramp(x, 1, 4) + tvm.tir.Ramp(y, 2, 4), tvm.tir.Ramp(x
+ y, 3, 4)),
TestCase(tvm.tir.Ramp(x, 1, 2) + y, tvm.tir.Ramp(x + y, 1, 2)),
TestCase(y + tvm.tir.Ramp(x, 1, 2), tvm.tir.Ramp(y + x, 1, 2)),
+ TestCase(
+ tvm.tir.Ramp(x, 1, tir.vscale() * 4) + tvm.tir.Ramp(y, 2, tir.vscale() * 4),
+ tvm.tir.Ramp(x + y, 3, tir.vscale() * 4),
+ ),
TestCase(y.astype("int32x2") + x.astype("int32x2"), (y +
x).astype("int32x2")),
TestCase(tvm.tir.Broadcast(0, 4) + y, tvm.tir.Broadcast(y, 4)),
+ TestCase(
+ tvm.tir.Broadcast(0, tir.vscale() * 8) + y, tvm.tir.Broadcast(y,
tir.vscale() * 8)
+ ),
TestCase(
tvm.tir.Ramp(x, 1, 4).astype("float32x4") + tvm.tir.Broadcast(0.0,
4),
tvm.tir.Ramp(x, 1, 4).astype("float32x4"),
@@ -101,21 +108,38 @@ class TestVector(BaseCompare):
# trunc div
TestCase(tdiv(y.astype("int32x2"), x.astype("int32x2")), tdiv(y,
x).astype("int32x2")),
TestCase(tdiv(tvm.tir.Ramp(x, 4, 4), 2), tvm.tir.Ramp(tdiv(x, 2), 2,
4)),
+ TestCase(
+ tdiv(tvm.tir.Ramp(x, 4, tir.vscale() * 5), 2),
+ tvm.tir.Ramp(tdiv(x, 2), 2, tir.vscale() * 5),
+ ),
TestCase(tdiv(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8), x.astype("int32x4"),
x >= 0),
TestCase(tdiv(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8), tdiv(tvm.tir.Ramp(x
* 8 + 15, 1, 4), 8)),
# trunc mod
TestCase(tmod(y.astype("int32x2"), x.astype("int32x2")), tmod(y,
x).astype("int32x2")),
TestCase(tmod(tvm.tir.Ramp(x, 4, 4), 2), tvm.tir.Broadcast(tmod(x, 2),
4)),
TestCase(tmod(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8), tvm.tir.Ramp(1, 1,
4), x >= 0),
+ TestCase(
+ tmod(tvm.tir.Ramp(x * 8 + 1, 1, tir.vscale() * 4), 8),
+ tmod(tvm.tir.Ramp(1, 1, tir.vscale() * 4), 8),
+ x >= 0,
+ ),
TestCase(tmod(tvm.tir.Ramp(x * 8 + 1, 15, 4), 8), tmod(tvm.tir.Ramp(1,
15, 4), 8), x >= 0),
# floor div
TestCase(fld(y.astype("int32x2"), x.astype("int32x2")), fld(y,
x).astype("int32x2")),
TestCase(fld(tvm.tir.Ramp(x, 4, 4), 2), tvm.tir.Ramp(fld(x, 2), 2, 4)),
+ TestCase(
+ fld(tvm.tir.Ramp(x, 4, tir.vscale() * 4), 2),
+ tvm.tir.Ramp(fld(x, 2), 2, tir.vscale() * 4),
+ ),
TestCase(fld(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8), (x).astype("int32x4")),
TestCase(fld(tvm.tir.Ramp(x * 8 + 15, 1, 4), 8), fld(tvm.tir.Ramp(x *
8 + 15, 1, 4), 8)),
TestCase(
fld(tvm.tir.Ramp(x, 8, 5), tvm.tir.Broadcast(4, 5)),
tvm.tir.Ramp(fld(x, 4), 2, 5)
),
+ TestCase(
+ fld(tvm.tir.Ramp(x, 8, tir.vscale() * 4), tvm.tir.Broadcast(4,
tir.vscale() * 4)),
+ tvm.tir.Ramp(fld(x, 4), 2, tir.vscale() * 4),
+ ),
TestCase(
fld(tvm.tir.Ramp(flm(x * 4, 256), 1, 4), tvm.tir.Broadcast(8, 4)),
tvm.tir.Broadcast(fld(flm(x * 4, 256), 8), 4),
@@ -127,6 +151,10 @@ class TestVector(BaseCompare):
TestCase(
fld(tvm.tir.Ramp(x * 8, 1, 4), tvm.tir.Broadcast(4, 4)),
tvm.tir.Broadcast(x * 2, 4)
),
+ TestCase(
+ fld(tvm.tir.Ramp(x * 8, 1, tir.vscale() * 4), tvm.tir.Broadcast(4,
tir.vscale() * 4)),
+ fld(tvm.tir.Ramp(x * 8, 1, tir.vscale() * 4), tvm.tir.Broadcast(4,
tir.vscale() * 4)),
+ ),
TestCase(
fld(tvm.tir.Ramp(x * 8, 3, 4), tvm.tir.Broadcast(4, 4)),
fld(tvm.tir.Ramp(x * 8, 3, 4), tvm.tir.Broadcast(4, 4)),
@@ -158,7 +186,15 @@ class TestVector(BaseCompare):
# floor mod
TestCase(flm(y.astype("int32x2"), x.astype("int32x2")), flm(y,
x).astype("int32x2")),
TestCase(flm(tvm.tir.Ramp(x, 4, 4), 2), tvm.tir.Broadcast(flm(x, 2),
4)),
+ TestCase(
+ flm(tvm.tir.Ramp(x, 4, tir.vscale() * 8), 2),
+ tvm.tir.Broadcast(flm(x, 2), tir.vscale() * 8),
+ ),
TestCase(flm(tvm.tir.Ramp(x * 8 + 1, 1, 4), 8), tvm.tir.Ramp(1, 1, 4)),
+ TestCase(
+ flm(tvm.tir.Ramp(x * 8 + 1, 1, tir.vscale() * 4), 8),
+ flm(tvm.tir.Ramp(1, 1, tir.vscale() * 4), 8),
+ ),
TestCase(flm(tvm.tir.Ramp(x * 8 + 1, 15, 4), 8), flm(tvm.tir.Ramp(1,
15, 4), 8)),
TestCase(
flm(tvm.tir.Ramp(x, 8, 4), tvm.tir.Broadcast(4, 4)),
tvm.tir.Broadcast(flm(x, 4), 4)
diff --git a/tests/python/relay/test_json_compact.py
b/tests/python/relay/test_json_compact.py
index 148587e3c9..d4fa17bf8f 100644
--- a/tests/python/relay/test_json_compact.py
+++ b/tests/python/relay/test_json_compact.py
@@ -277,5 +277,76 @@ def test_virtual_device():
assert not func.virtual_device_
+def test_v0_16_ramp_broadcast_lanes():
+ json_graph_v0_15 = {
+ "root": 1,
+ "nodes": [
+ {"type_key": ""},
+ {
+ "type_key": "tir.BufferStore",
+ "attrs": {"buffer": "2", "indices": "16", "span": "0",
"value": "14"},
+ },
+ {
+ "type_key": "tir.Buffer",
+ "attrs": {
+ "axis_separators": "11",
+ "buffer_type": "1",
+ "data": "3",
+ "data_alignment": "64",
+ "dtype": "int32",
+ "elem_offset": "12",
+ "name": "13",
+ "offset_factor": "1",
+ "shape": "8",
+ "span": "0",
+ "strides": "10",
+ },
+ },
+ {
+ "type_key": "tir.Var",
+ "attrs": {"dtype": "handle", "name": "4", "span": "0",
"type_annotation": "5"},
+ },
+ {"type_key": "runtime.String", "repr_str": "buffer"},
+ {"type_key": "PointerType", "attrs": {"element_type": "6",
"storage_scope": "7"}},
+ {"type_key": "PrimType", "attrs": {"dtype": "int32"}},
+ {"type_key": "runtime.String"},
+ {"type_key": "Array", "data": [9]},
+ {"type_key": "IntImm", "attrs": {"dtype": "int32", "span": "0",
"value": "50"}},
+ {"type_key": "Array"},
+ {"type_key": "Array"},
+ {"type_key": "IntImm", "attrs": {"dtype": "int32", "span": "0",
"value": "0"}},
+ {"type_key": "runtime.String", "repr_str": "buffer"},
+ {
+ "type_key": "tir.Broadcast",
+ "attrs": {"dtype": "int32x12", "lanes": "12", "span": "0",
"value": "15"},
+ },
+ {"type_key": "IntImm", "attrs": {"dtype": "int32", "span": "0",
"value": "3"}},
+ {"type_key": "Array", "data": [17]},
+ {
+ "type_key": "tir.Ramp",
+ "attrs": {
+ "base": "18",
+ "dtype": "int32x12",
+ "lanes": "12",
+ "span": "0",
+ "stride": "19",
+ },
+ },
+ {"type_key": "IntImm", "attrs": {"dtype": "int32", "span": "0",
"value": "11"}},
+ {"type_key": "IntImm", "attrs": {"dtype": "int32", "span": "0",
"value": "1"}},
+ ],
+ "b64ndarrays": [],
+ "attrs": {"tvm_version": "0.15.dev0"},
+ }
+ graph = tvm.ir.load_json(json.dumps(json_graph_v0_15))
+
+ # Ramp
+ assert graph.indices[0].base == 11
+ assert graph.indices[0].lanes == 12
+ # Broadcast
+ assert graph.value.value == 3
+ assert graph.value.lanes == 12
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/tir-base/test_tir_nodes.py
b/tests/python/tir-base/test_tir_nodes.py
index 49816778f1..5b55c432b0 100644
--- a/tests/python/tir-base/test_tir_nodes.py
+++ b/tests/python/tir-base/test_tir_nodes.py
@@ -401,5 +401,81 @@ def test_intimm_cond():
assert x == 1
+def _create_ramp(lanes):
+ return tvm.tir.Ramp(0, 1, lanes)
+
+
+def _create_broadcast(lanes):
+ return tvm.tir.Broadcast(0, lanes)
+
+
[email protected]("lanes", [(11 * tvm.tir.vscale()), (tvm.tir.vscale() * 11)])
[email protected]("node_func", [_create_ramp, _create_broadcast])
+def test_scalable_vec(lanes, node_func):
+ def _check_dtype(node):
+ assert node.lanes.a.equal(tvm.tir.vscale())
+ assert node.lanes.b == 11
+
+ _check_dtype(node_func(lanes))
+
+
[email protected](
+ "lanes", [(tvm.tir.vscale()), (tvm.tir.vscale() + 3), (tvm.tir.vscale() *
2 + 5)]
+)
[email protected]("node_func", [_create_ramp, _create_broadcast])
+def test_scalable_vec_error(lanes, node_func):
+
+ with pytest.raises(tvm.error.TVMError):
+ node_func(lanes)
+
+
+def test_broadcast_to_scalable_vec():
+ vec = tvm.tir.expr.Ramp(0, 1, 4 * tvm.tir.vscale()) + 3
+ broadcast = vec.b
+
+ assert isinstance(broadcast, tvm.tir.expr.Broadcast)
+ assert broadcast.value == 3
+ assert broadcast.lanes.a.equal(tvm.tir.vscale())
+ assert broadcast.lanes.b == 4
+
+
[email protected](
+ reason="Support for scalable data type string will be added in P3 of
https://github.com/apache/tvm/issues/16455"
+)
+def test_buffer_load_scalable_vec():
+ buf = tvm.tir.decl_buffer((24,), "float32")
+ index = tvm.tir.expr.Ramp(1, 1, 8 * tvm.tir.vscale())
+ load = tvm.tir.BufferLoad(buf, [index])
+
+ assert isinstance(load, tvm.tir.BufferLoad)
+ assert load.dtype == "float32x8xvscale"
+
+
[email protected](
+ reason="Support for scalable data type string will be added in P3 of
https://github.com/apache/tvm/issues/16455"
+)
+def test_buffer_store_scalable_vec():
+ b = tvm.tir.decl_buffer((24,), "int32")
+ value = tvm.tir.expr.Broadcast(1, 4 * tvm.tir.vscale())
+ index = tvm.tir.expr.Ramp(0, 1, 4 * tvm.tir.vscale())
+ store = tvm.tir.BufferStore(b, value, [index])
+
+ assert isinstance(store, tvm.tir.BufferStore)
+ assert store.value.dtype == "int32x4xvscale"
+
+
[email protected](
+ reason="Support for scalable data type string will be added in P3 of
https://github.com/apache/tvm/issues/16455"
+)
+def test_scalable_vec_cast():
+ b = tvm.tir.decl_buffer((24,), "float32")
+ value = tvm.tir.expr.Broadcast(1, 12 * tvm.tir.vscale()).astype("float32x12xvscale")
+ index = tvm.tir.expr.Ramp(0, 1, 12 * tvm.tir.vscale())
+
+ store = tvm.tir.BufferStore(b, value, [index])
+
+ assert isinstance(store.value.value, tvm.tir.expr.FloatImm)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
index 5362dae303..c20784b4bf 100644
--- a/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_ir_builder_tir.py
@@ -450,6 +450,24 @@ def test_ir_builder_tir_buffer_store():
assert_structural_equal(ir_actual, ir_expected, map_free_vars=True)
+def test_ir_builder_tir_buffer_store_scalable_vec():
+ buffer_a = T.Buffer((30,), "float32")
+ value = T.broadcast(0.11, 4 * tvm.tir.vscale())
+ index = T.ramp(0, 1, 4 * tvm.tir.vscale())
+
+ with IRBuilder() as ib:
+ T.buffer_store(buffer_a, value, [index])
+
+ # the buffer store generated by IRBuilder
+ ir_actual = ib.get()
+
+ # the expected buffer store
+ ir_expected = tir.BufferStore(buffer_a, value, [index])
+
+ # Check if the generated ir is expected
+ assert_structural_equal(ir_actual, ir_expected, map_free_vars=True)
+
+
def test_ir_builder_tir_prefetch():
with IRBuilder() as ib:
buffer_a = T.Buffer((128, 128), "float32")
diff --git a/tests/python/tvmscript/test_tvmscript_printer_tir.py
b/tests/python/tvmscript/test_tvmscript_printer_tir.py
index 18f4e153bf..4c862e75a6 100644
--- a/tests/python/tvmscript/test_tvmscript_printer_tir.py
+++ b/tests/python/tvmscript/test_tvmscript_printer_tir.py
@@ -17,6 +17,7 @@
# pylint: disable=missing-docstring
import re
+import pytest
import tvm.testing
from tvm import ir, tir
@@ -622,25 +623,35 @@ def test_select():
)
-def test_ramp():
[email protected](
+ "lanes, scripted_lanes", [(32, "32"), (tvm.tir.vscale() * 8, "T.vscale() *
8")]
+)
+def test_ramp(lanes, scripted_lanes):
a = tir.Var("a", "int32")
- obj = tir.Ramp(a, 1, 32)
+ obj = tir.Ramp(a, 1, lanes)
_assert_print(
obj,
"""
a = T.int32()
-T.Ramp(a, 1, 32)
-""",
+T.Ramp(a, 1, {})
+""".format(
+ scripted_lanes
+ ),
)
-def test_broadcast():
- obj = tir.Broadcast(0, 4)
[email protected](
+ "lanes, scripted_lanes", [(4, "4"), (tvm.tir.vscale() * 4, "T.vscale() *
4")]
+)
+def test_broadcast(lanes, scripted_lanes):
+ obj = tir.Broadcast(0, lanes)
_assert_print(
obj,
"""
-T.Broadcast(0, 4)
-""",
+T.Broadcast(0, {})
+""".format(
+ scripted_lanes
+ ),
)
diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py
b/tests/python/tvmscript/test_tvmscript_roundtrip.py
index 66eef5ad81..c0947f93af 100644
--- a/tests/python/tvmscript/test_tvmscript_roundtrip.py
+++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py
@@ -3339,6 +3339,15 @@ def ramp_int64():
return func
+def scalable_vectors():
+ @T.prim_func
+ def func(a: T.handle):
+ A = T.match_buffer(a, (200,), "float32")
+ A[T.Ramp(11, 2, 4 * tir.vscale())] = T.Broadcast(125, 4 * tir.vscale())
+
+ return func
+
+
def let_expression():
@T.prim_func
def func():