ekalda commented on code in PR #16862:
URL: https://github.com/apache/tvm/pull/16862#discussion_r1560717574


##########
tests/python/arith/test_arith_simplify.py:
##########
@@ -53,6 +56,33 @@ def test_simplify_symbolic_comparison():
     assert ana.can_prove((n + 31) // 32 * 32 >= i0 * 32 + i1, 
PS.SYMBOLIC_BOUND)
 
 
+def test_simplify_vscale_comparison_with_sve_target():
+    ana = tvm.arith.Analyzer()
+    vs = tvm.tir.vscale()
+
+    with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+sve"):
+        assert ana.can_prove(vs * 32 < vs * 64)

Review Comment:
   Nit: Maybe these test cases could be broken up into separate tests, to help 
with debugging if one of the cases starts to fail. Also, it may be worth adding 
one case with a modulo expression as well.



##########
src/arith/scalable_expression.h:
##########
@@ -25,27 +25,51 @@
 #ifndef TVM_ARITH_SCALABLE_EXPRESSION_H_
 #define TVM_ARITH_SCALABLE_EXPRESSION_H_
 
+#include <tvm/arith/analyzer.h>
 #include <tvm/ir/expr.h>
 
 #include <optional>
+#include <vector>
 
 namespace tvm {
 namespace arith {
 
+/*! \brief A list of known vscale values to try for an AArch64 SVE target. */
+static const std::vector<unsigned int> kAArch64VScaleValues = {1, 2,  3,  4,  
5,  6,  7,  8,
+                                                               9, 10, 11, 12, 
13, 14, 15, 16};
+
 /*!
  * \brief Check if an expr is a call to the vscale intrinsic.
  * \param expr The expr to check
  * \return True if the expr is a call to the vscale intrinsic, false if not.
  */
 bool IsVScaleCall(const PrimExpr& expr);
 
+/*!
+ * \brief Substitute a vscale intrinsic call with a known scalar value.
+ * \param expr The expr to apply substitutions to.
+ * \param vscale_value The scalar value to replace vscale with.
+ * \return A rewritten expression with vscale values replaced with a scalar 
value.
+ */
+PrimExpr SubstituteVScaleWithKnownValue(const PrimExpr& expr, unsigned int 
vscale_value);
+
 /*!
  * \brief Returns the vscale multiplier as a nullable type
  * \param lanes The scalable lanes as a PrimExpr
  * \return vscale multiplier as std::optional<int>
  */
 std::optional<int> ExtractVscaleFactor(const PrimExpr& lanes);
 
+/*!
+ * \brief Check if the expression can be proven when evaluating it on all 
possible values
+           of vscale.
+ * \param analyzer An analyzer instance.
+ * \param expr The expression to try to prove.
+ * \return Whether or not the expression can be proven with this technique.

Review Comment:
   Nit: Missing `\param` documentation for the `vscale_values` parameter.



##########
src/arith/scalable_expression.cc:
##########
@@ -50,5 +66,20 @@ std::optional<int> ExtractVscaleFactor(const PrimExpr& 
lanes) {
   }
 }
 
+bool CanProveVscaleExpressionFromKnownValues(arith::Analyzer* analyzer, const 
PrimExpr& expr,
+                                             const std::vector<unsigned int>& 
vscale_values) {
+  bool can_prove_expr = true;

Review Comment:
   Out of interest, what happens if the `expr` is not an inequality? Should we 
`ICHECK` for it? 



##########
tests/python/tir-schedule/test_tir_schedule_split_fuse.py:
##########
@@ -653,5 +654,138 @@ def test_split_int64_factors():
     assert_structural_equal_ignore_global_symbol(elementwise_symbolic_split, 
sch.mod["main"])
 
 
+@pytest.mark.parametrize("num_elements", [128, 115])
+def test_sve_scalable_split_predicated(num_elements):
+    """
+    By default, splitting with by vscale factors over a fixed-length loop will
+    result in loop-level predication being inserted. This is because, at
+    compile-time, we don't know if vscale is a multiple of the extent of the
+    loop to be split.
+    """
+
+    @T.prim_func
+    def before(a: T.handle):
+        A = T.match_buffer(a, (num_elements,), "float32")
+        T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
+        for i in T.serial(num_elements):
+            with T.block("A"):
+                v_i = T.axis.remap("S", [i])
+                A[v_i] = 1.0
+
+    @T.prim_func
+    def after(a: T.handle):
+        A = T.match_buffer(a, (num_elements,), "float32")
+        T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
+        for i_0, i_1 in T.grid(
+            (T.vscale() * 4 + (num_elements - 1)) // (T.vscale() * 4), 
T.vscale() * 4
+        ):
+            with T.block("A"):
+                v_i = T.axis.spatial(num_elements, i_0 * (T.vscale() * 4) + 
i_1)
+                T.where(i_0 * (T.vscale() * 4) + i_1 < num_elements)
+                A[v_i] = 1.0
+
+    with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+sve"):
+        sch = tvm.tir.Schedule(before)
+        (a,) = sch.get_loops("A")
+        sch.split(a, factors=[T.ceildiv(num_elements, 4 * T.vscale()), 4 * 
T.vscale()])
+
+    tvm.ir.assert_structural_equal(sch.mod["main"], after)
+
+
+def test_sve_scalable_split_assume_exact_multiple():
+    """
+    If the schedule writer knows the extent of the loop to be split will always
+    be a multiple of vscale, they may use `disable_predication=True` to ensure
+    a predicate is not created. This can be used to ensure predication is not
+    inserted where current analysis is not powerful enough to recognise this.
+    """
+
+    @T.prim_func
+    def before(a: T.handle):
+        A = T.match_buffer(a, (128,), "float32")
+        T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
+        for i in T.serial(128):
+            with T.block("A"):
+                v_i = T.axis.remap("S", [i])
+                A[v_i] = 1.0
+
+    @T.prim_func
+    def after(a: T.handle):
+        A = T.match_buffer(a, (128,), "float32")
+        T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
+        for i_0, i_1 in T.grid((T.vscale() * 4 + (128 - 1)) // (T.vscale() * 
4), T.vscale() * 4):
+            with T.block("A"):
+                v_i = T.axis.spatial(128, i_0 * (T.vscale() * 4) + i_1)
+                A[v_i] = 1.0
+
+    with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+sve"):
+        sch = tvm.tir.Schedule(before)
+        (a,) = sch.get_loops("A")
+        sch.split(
+            a,
+            factors=[T.ceildiv(128, 4 * T.vscale()), 4 * T.vscale()],
+            disable_predication=True,
+        )
+
+    tvm.ir.assert_structural_equal(sch.mod["main"], after)
+
+
+def test_sve_split_over_scalable_loop():
+    @T.prim_func
+    def before(a: T.handle):
+        A = T.match_buffer(a, (128,), "float32")
+        T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
+        for i in T.serial(4 * T.vscale()):
+            with T.block("A"):
+                v_i = T.axis.remap("S", [i])
+                A[v_i] = 1.0
+
+    @T.prim_func
+    def after(a: T.handle):
+        A = T.match_buffer(a, (128,), "float32")
+        T.func_attr({"global_symbol": "my_module", "tir.noalias": True})
+        for i_0, i_1 in T.grid(T.vscale() * 2, T.vscale() * 2):
+            with T.block("A"):
+                v_i = T.axis.spatial(T.vscale() * 4, i_0 * (T.vscale() * 2) + 
i_1)
+                T.where(i_0 * (T.vscale() * 2) + i_1 < T.vscale() * 4)
+                A[v_i] = 1.0
+
+    with tvm.target.Target("llvm -mtriple=aarch64-linux-gnu -mattr=+sve"):
+        sch = tvm.tir.Schedule(before)
+        (a,) = sch.get_loops("A")
+        sch.split(
+            a,
+            factors=[2 * T.vscale(), 2 * T.vscale()],
+        )
+
+    tvm.ir.assert_structural_equal(sch.mod["main"], after)
+
+
+def test_default_scalable_split(capfd):

Review Comment:
   Nit:
   
   Does the splitting in this test fail because we didn't provide the target 
information? Maybe it can be made clearer in the test name or a comment. 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to