This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new d1e1ac4  [REFACTOR][PY] Establish tvm.arith (#4904)
d1e1ac4 is described below

commit d1e1ac49b37210334e543f6c4cd8813cbe80e26d
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Feb 18 08:14:12 2020 -0800

    [REFACTOR][PY] Establish tvm.arith (#4904)
---
 .../tvm/arith/__init__.py                          | 25 ++-----
 .../tvm/arith/_ffi_api.py                          | 22 +-----
 python/tvm/{arith.py => arith/analyzer.py}         | 72 ++++++++-----------
 .../tvm/arith/bound.py                             | 38 +++++-----
 python/tvm/arith/int_set.py                        | 80 ++++++++++++++++++++++
 python/tvm/arith/pattern.py                        | 60 ++++++++++++++++
 src/api/api_arith.cc                               | 18 +++--
 src/arith/int_set.cc                               |  2 +-
 tests/python/unittest/test_arith_deduce_bound.py   | 56 +++++++--------
 .../unittest/test_arith_detect_clip_bound.py       |  6 +-
 .../unittest/test_arith_detect_linear_equation.py  | 24 +++----
 tests/python/unittest/test_arith_domain_touched.py | 11 ++-
 tests/python/unittest/test_arith_intset.py         |  6 +-
 vta/python/vta/ir_pass.py                          | 12 ++--
 14 files changed, 268 insertions(+), 164 deletions(-)

diff --git a/tests/python/unittest/test_arith_detect_clip_bound.py b/python/tvm/arith/__init__.py
similarity index 51%
copy from tests/python/unittest/test_arith_detect_clip_bound.py
copy to python/tvm/arith/__init__.py
index 3301c24..40e977e 100644
--- a/tests/python/unittest/test_arith_detect_clip_bound.py
+++ b/python/tvm/arith/__init__.py
@@ -14,24 +14,9 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import tvm
+"""Integer bound analysis, simplification and pattern detection."""
 
-def test_basic():
-    a = tvm.var("a")
-    b = tvm.var("b")
-    c = tvm.var("c")
-    m = tvm.arith.DetectClipBound(tvm.all(a * 1 < b * 6,
-                                          a - 1 > 0), [a])
-    assert tvm.ir_pass.Simplify(m[1] - (b * 6 - 1)).value == 0
-    assert m[0].value == 2
-    m = tvm.arith.DetectClipBound(tvm.all(a * 1 < b * 6,
-                                          a - 1 > 0), [a, b])
-    assert len(m) == 0
-    m = tvm.arith.DetectClipBound(tvm.all(a + 10 * c <= 20,
-                                          b - 1 > 0), [a, b])
-    assert tvm.ir_pass.Simplify(m[1] - (20 - 10 * c)).value == 0
-    assert tvm.ir_pass.Simplify(m[2] - 2).value == 0
-
-
-if __name__ == "__main__":
-    test_basic()
+from .int_set import IntSet, IntervalSet
+from .analyzer import ModularSet, ConstIntBound, Analyzer
+from .bound import deduce_bound
+from .pattern import detect_linear_equation, detect_clip_bound
diff --git a/tests/python/unittest/test_arith_detect_clip_bound.py b/python/tvm/arith/_ffi_api.py
similarity index 52%
copy from tests/python/unittest/test_arith_detect_clip_bound.py
copy to python/tvm/arith/_ffi_api.py
index 3301c24..c551e56 100644
--- a/tests/python/unittest/test_arith_detect_clip_bound.py
+++ b/python/tvm/arith/_ffi_api.py
@@ -14,24 +14,8 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import tvm
+"""FFI APIs for tvm.arith"""
+import tvm._ffi
 
-def test_basic():
-    a = tvm.var("a")
-    b = tvm.var("b")
-    c = tvm.var("c")
-    m = tvm.arith.DetectClipBound(tvm.all(a * 1 < b * 6,
-                                          a - 1 > 0), [a])
-    assert tvm.ir_pass.Simplify(m[1] - (b * 6 - 1)).value == 0
-    assert m[0].value == 2
-    m = tvm.arith.DetectClipBound(tvm.all(a * 1 < b * 6,
-                                          a - 1 > 0), [a, b])
-    assert len(m) == 0
-    m = tvm.arith.DetectClipBound(tvm.all(a + 10 * c <= 20,
-                                          b - 1 > 0), [a, b])
-    assert tvm.ir_pass.Simplify(m[1] - (20 - 10 * c)).value == 0
-    assert tvm.ir_pass.Simplify(m[2] - 2).value == 0
 
-
-if __name__ == "__main__":
-    test_basic()
+tvm._ffi._init_api("arith", __name__)
diff --git a/python/tvm/arith.py b/python/tvm/arith/analyzer.py
similarity index 83%
rename from python/tvm/arith.py
rename to python/tvm/arith/analyzer.py
index b67e99c..382a7e0 100644
--- a/python/tvm/arith.py
+++ b/python/tvm/arith/analyzer.py
@@ -17,34 +17,7 @@
 """Arithmetic data structure and utility"""
 import tvm._ffi
 from tvm.runtime import Object
-
-
-class IntSet(Object):
-    """Represent a set of integer in one dimension."""
-    def is_nothing(self):
-        """Whether the set represent nothing"""
-        return _IntSetIsNothing(self)
-
-    def is_everything(self):
-        """Whether the set represent everything"""
-        return _IntSetIsEverything(self)
-
-
-@tvm._ffi.register_object("arith.IntervalSet")
-class IntervalSet(IntSet):
-    """Represent set of continuous interval [min_value, max_value]
-
-    Parameters
-    ----------
-    min_value : Expr
-        The minimum value in the interval.
-
-    max_value : Expr
-        The maximum value in the interval.
-    """
-    def __init__(self, min_value, max_value):
-        self.__init_handle_by_constructor__(
-            _make_IntervalSet, min_value, max_value)
+from . import _ffi_api
 
 
 @tvm._ffi.register_object("arith.ModularSet")
@@ -52,7 +25,7 @@ class ModularSet(Object):
     """Represent range of (coeff * x + base) for x in Z """
     def __init__(self, coeff, base):
         self.__init_handle_by_constructor__(
-            _make_ModularSet, coeff, base)
+            _ffi_api.ModularSet, coeff, base)
 
 
 @tvm._ffi.register_object("arith.ConstIntBound")
@@ -72,7 +45,7 @@ class ConstIntBound(Object):
 
     def __init__(self, min_value, max_value):
         self.__init_handle_by_constructor__(
-            _make_ConstIntBound, min_value, max_value)
+            _ffi_api.ConstIntBound, min_value, max_value)
 
 
 class ConstraintScope:
@@ -105,11 +78,12 @@ class Analyzer:
     be used to perform various symbolic integer analysis.
     """
     def __init__(self):
-        _mod = _CreateAnalyzer()
+        _mod = _ffi_api.CreateAnalyzer()
         self._const_int_bound = _mod("const_int_bound")
         self._const_int_bound_update = _mod("const_int_bound_update")
         self._bind = _mod("bind")
         self._modular_set = _mod("modular_set")
+        self._simplify = _mod("Simplify")
         self._rewrite_simplify = _mod("rewrite_simplify")
         self._canonical_simplify = _mod("canonical_simplify")
         self._int_set = _mod("int_set")
@@ -120,7 +94,7 @@ class Analyzer:
 
         Parameters
         ----------
-        expr : tvm.Expr
+        expr : PrimExpr
             The expression.
 
         Returns
@@ -135,7 +109,7 @@ class Analyzer:
 
         Parameters
         ----------
-        expr : tvm.Expr
+        expr : PrimExpr
             The expression.
 
         Returns
@@ -145,12 +119,27 @@ class Analyzer:
         """
         return self._modular_set(expr)
 
+    def simplify(self, expr):
+        """Simplify expression via both rewrite and canonicalization.
+
+        Parameters
+        ----------
+        expr : PrimExpr
+            The expression.
+
+        Returns
+        -------
+        result : Expr
+            The result.
+        """
+        return self._simplify(expr)
+
     def rewrite_simplify(self, expr):
         """Simplify expression via rewriting rules.
 
         Parameters
         ----------
-        expr : tvm.Expr
+        expr : PrimExpr
             The expression.
 
         Returns
@@ -165,7 +154,7 @@ class Analyzer:
 
         Parameters
         ----------
-        expr : tvm.Expr
+        expr : PrimExpr
             The expression.
 
         Returns
@@ -180,7 +169,7 @@ class Analyzer:
 
         Parameters
         ----------
-        expr : tvm.Expr
+        expr : PrimExpr
             The expression.
 
         dom_map : Dict[Var, tvm.arith.IntSet]
@@ -198,10 +187,10 @@ class Analyzer:
 
         Parameters
         ----------
-        var : tvm.Var
+        var : tvm.tir.Var
             The variable.
 
-        expr : tvm.Expr
+        expr : PrimExpr
             The expression.
         """
         return self._bind(var, expr)
@@ -211,7 +200,7 @@ class Analyzer:
 
         Parameters
         ----------
-        constraint : tvm.Expr
+        constraint : PrimExpr
             The constraint expression.
 
         returns
@@ -240,7 +229,7 @@ class Analyzer:
 
         Parameters
         ----------
-        var : tvm.Var
+        var : tvm.tir.Var
             The variable.
 
         info : tvm.Object
@@ -254,6 +243,3 @@ class Analyzer:
         else:
             raise TypeError(
                 "Do not know how to handle type {}".format(type(info)))
-
-
-tvm._ffi._init_api("tvm.arith")
diff --git a/tests/python/unittest/test_arith_detect_clip_bound.py b/python/tvm/arith/bound.py
similarity index 52%
copy from tests/python/unittest/test_arith_detect_clip_bound.py
copy to python/tvm/arith/bound.py
index 3301c24..6f4b220 100644
--- a/tests/python/unittest/test_arith_detect_clip_bound.py
+++ b/python/tvm/arith/bound.py
@@ -14,24 +14,26 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import tvm
+"""Bound deduction."""
+from . import _ffi_api
 
-def test_basic():
-    a = tvm.var("a")
-    b = tvm.var("b")
-    c = tvm.var("c")
-    m = tvm.arith.DetectClipBound(tvm.all(a * 1 < b * 6,
-                                          a - 1 > 0), [a])
-    assert tvm.ir_pass.Simplify(m[1] - (b * 6 - 1)).value == 0
-    assert m[0].value == 2
-    m = tvm.arith.DetectClipBound(tvm.all(a * 1 < b * 6,
-                                          a - 1 > 0), [a, b])
-    assert len(m) == 0
-    m = tvm.arith.DetectClipBound(tvm.all(a + 10 * c <= 20,
-                                          b - 1 > 0), [a, b])
-    assert tvm.ir_pass.Simplify(m[1] - (20 - 10 * c)).value == 0
-    assert tvm.ir_pass.Simplify(m[2] - 2).value == 0
 
+def deduce_bound(var, cond, hint_map, relax_map):
+    """Deduce the bound of the target variable in the cond.
 
-if __name__ == "__main__":
-    test_basic()
+    Parameters
+    ----------
+    var : Var
+        The target variable to be deduced.
+
+    cond : PrimExpr
+        The condition
+
+    hint_map : Map[Var, IntSet]
+        Domain of variables used to help deduction.
+
+    relax_map : Map[Var, IntSet]
+        The domain of the variables to be relaxed
+        using the provided domain.
+    """
+    return _ffi_api.DeduceBound(var, cond, hint_map, relax_map)
diff --git a/python/tvm/arith/int_set.py b/python/tvm/arith/int_set.py
new file mode 100644
index 0000000..838e8e5
--- /dev/null
+++ b/python/tvm/arith/int_set.py
@@ -0,0 +1,80 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Integer set."""
+import tvm._ffi
+from tvm.runtime import Object
+from . import _ffi_api
+
+
+class IntSet(Object):
+    """Represent a set of integers in one dimension."""
+    def is_nothing(self):
+        """Whether the set represents nothing"""
+        return _ffi_api.IntSetIsNothing(self)
+
+    def is_everything(self):
+        """Whether the set represents everything"""
+        return _ffi_api.IntSetIsEverything(self)
+
+    @staticmethod
+    def vector(vec):
+        """Construct an integer set that covers the vector expr
+
+        Parameters
+        ----------
+        vec : PrimExpr
+            The vector expression.
+
+        Returns
+        -------
+        rset : IntSet
+            The result set.
+        """
+        return _ffi_api.intset_vector(vec)
+
+    @staticmethod
+    def single_point(point):
+        """Construct a point set.
+
+        Parameters
+        ----------
+        point : PrimExpr
+            The point expression.
+
+        Returns
+        -------
+        rset : IntSet
+            The result set.
+        """
+        return _ffi_api.intset_single_point(point)
+
+
+@tvm._ffi.register_object("arith.IntervalSet")
+class IntervalSet(IntSet):
+    """Represent a continuous interval [min_value, max_value].
+
+    Parameters
+    ----------
+    min_value : PrimExpr
+        The minimum value in the interval.
+
+    max_value : PrimExpr
+        The maximum value in the interval.
+    """
+    def __init__(self, min_value, max_value):
+        self.__init_handle_by_constructor__(
+            _ffi_api.IntervalSet, min_value, max_value)
diff --git a/python/tvm/arith/pattern.py b/python/tvm/arith/pattern.py
new file mode 100644
index 0000000..2281088
--- /dev/null
+++ b/python/tvm/arith/pattern.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Detect common patterns."""
+from . import _ffi_api
+
+
+def detect_linear_equation(expr, var_list):
+    """Match `expr = sum_{i=0}^{n-1} var[i] * coeff[i] + coeff[n]`
+
+    Where coeff[i] (including coeff[n]) is invariant of var[j] for all i and j.
+
+    Parameters
+    ----------
+    expr : PrimExpr
+        The expression to be matched.
+
+    var_list : List[tvm.tir.Var]
+        A list of variables.
+
+    Returns
+    -------
+    coeff : List[PrimExpr]
+        A list of co-efficients if the match is successful.
+        An empty list if the match failed.
+    """
+    return _ffi_api.DetectLinearEquation(expr, var_list)
+
+
+def detect_clip_bound(expr, var_list):
+    """Detect if the expression corresponds to clip bounds of the vars.
+
+    Parameters
+    ----------
+    expr : PrimExpr
+        The expression to be matched.
+
+    var_list : List[tvm.tir.Var]
+        A list of variables.
+
+    Returns
+    -------
+    coeff : List[PrimExpr]
+        `concat([min_value[i], max_value[i]] for i, v in enumerate(var_list))`
+        An empty list if the match failed.
+    """
+    return _ffi_api.DetectClipBound(expr, var_list)
diff --git a/src/api/api_arith.cc b/src/api/api_arith.cc
index f996bdb..3942f6e 100644
--- a/src/api/api_arith.cc
+++ b/src/api/api_arith.cc
@@ -64,33 +64,33 @@ TVM_REGISTER_GLOBAL("arith.DeduceBound")
 TVM_REGISTER_GLOBAL("arith.DomainTouched")
 .set_body_typed(DomainTouched);
 
-TVM_REGISTER_GLOBAL("arith._IntervalSetGetMin")
+TVM_REGISTER_GLOBAL("arith.IntervalSetGetMin")
 .set_body_method(&IntSet::min);
 
-TVM_REGISTER_GLOBAL("arith._IntervalSetGetMax")
+TVM_REGISTER_GLOBAL("arith.IntervalSetGetMax")
 .set_body_method(&IntSet::max);
 
-TVM_REGISTER_GLOBAL("arith._IntSetIsNothing")
+TVM_REGISTER_GLOBAL("arith.IntSetIsNothing")
 .set_body_method(&IntSet::is_nothing);
 
-TVM_REGISTER_GLOBAL("arith._IntSetIsEverything")
+TVM_REGISTER_GLOBAL("arith.IntSetIsEverything")
 .set_body_method(&IntSet::is_everything);
 
 ConstIntBound MakeConstIntBound(int64_t min_value, int64_t max_value) {
   return ConstIntBound(min_value, max_value);
 }
 
-TVM_REGISTER_GLOBAL("arith._make_ConstIntBound")
+TVM_REGISTER_GLOBAL("arith.ConstIntBound")
 .set_body_typed(MakeConstIntBound);
 
 ModularSet MakeModularSet(int64_t coeff, int64_t base) {
   return ModularSet(coeff, base);
 }
 
-TVM_REGISTER_GLOBAL("arith._make_ModularSet")
+TVM_REGISTER_GLOBAL("arith.ModularSet")
 .set_body_typed(MakeModularSet);
 
-TVM_REGISTER_GLOBAL("arith._CreateAnalyzer")
+TVM_REGISTER_GLOBAL("arith.CreateAnalyzer")
 .set_body([](TVMArgs args, TVMRetValue* ret) {
     using runtime::PackedFunc;
     using runtime::TypedPackedFunc;
@@ -108,6 +108,10 @@ TVM_REGISTER_GLOBAL("arith._CreateAnalyzer")
         return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
             self->const_int_bound.Update(args[0], args[1], args[2]);
         });
+      } else if (name == "Simplify") {
+        return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
+            *ret = self->Simplify(args[0]);
+        });
       } else if (name == "rewrite_simplify") {
         return PackedFunc([self](TVMArgs args, TVMRetValue *ret) {
             *ret = self->rewrite_simplify(args[0]);
diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc
index 728cca1..adb3879 100644
--- a/src/arith/int_set.cc
+++ b/src/arith/int_set.cc
@@ -54,7 +54,7 @@ IntervalSet MakeIntervalSet(PrimExpr min_value, PrimExpr max_value) {
   return IntervalSet(min_value, max_value);
 }
 
-TVM_REGISTER_GLOBAL("arith._make_IntervalSet")
+TVM_REGISTER_GLOBAL("arith.IntervalSet")
 .set_body_typed(MakeIntervalSet);
 
 
diff --git a/tests/python/unittest/test_arith_deduce_bound.py b/tests/python/unittest/test_arith_deduce_bound.py
index 787dfe8..5e08635 100644
--- a/tests/python/unittest/test_arith_deduce_bound.py
+++ b/tests/python/unittest/test_arith_deduce_bound.py
@@ -38,90 +38,90 @@ def test_deduce():
     fdiv = tvm.floordiv
 
     e0 = (-b)*a+c-d
-    res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
+    res0 = tvm.arith.deduce_bound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
     ans0 = fdiv(d - c, b*-1)
     assert_expr_equal(res0.max_value, ans0)
 
     # expression containing variable a is on rhs
-    res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
+    res0 = tvm.arith.deduce_bound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
     assert_expr_equal(res0.max_value, ans0)
 
     e0 = d*a+c-d
-    res0 = tvm.arith.DeduceBound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
+    res0 = tvm.arith.deduce_bound(a, e0>=0, {b: b_s, c: c_s, d: d_s}, {})
     ans0 = fdiv(d-c, d)
     assert_expr_equal(res0.max_value, ans0)
 
     # expression containing variable a is on rhs
-    res0 = tvm.arith.DeduceBound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
+    res0 = tvm.arith.deduce_bound(a, zero <= e0, {b: b_s, c: c_s, d: d_s}, {})
     assert_expr_equal(res0.max_value, ans0)
 
 
     e1 = (a*4+b < c)
-    res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
+    res1 = tvm.arith.deduce_bound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
     ans1 = fdiv(c-1-b, 4)
     assert_expr_equal(res1.max_value, ans1)
 
 
     # expression containing variable a is on rhs
     e1 = (c > a*4+b)
-    res1 = tvm.arith.DeduceBound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
+    res1 = tvm.arith.deduce_bound(a, e1, {b: b_s, c: c_s, d: d_s}, {})
     assert_expr_equal(res1.max_value, ans1)
 
 
     e2 = (tvm.max(5, a * 4) < 0)
-    res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
+    res2 = tvm.arith.deduce_bound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
     assert str(res2.max_value) == "neg_inf"
     assert str(res2.min_value) == "pos_inf"
 
     # expression containing variable a is on rhs
     e2 = (zero < tvm.max(5, a * 4))
-    res2 = tvm.arith.DeduceBound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
+    res2 = tvm.arith.deduce_bound(a, e2, {b: b_s, c: c_s, d: d_s}, {})
     assert str(res2.max_value) == "neg_inf"
     assert str(res2.min_value) == "pos_inf"
 
     e3 = (-b)+a*c-d
-    res3 = tvm.arith.DeduceBound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
+    res3 = tvm.arith.deduce_bound(a, e3>=0, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
     ans3 = fdiv(2,c)+1
     assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3)
 
-    res3 = tvm.arith.DeduceBound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
+    res3 = tvm.arith.deduce_bound(a, zero <= e3, {b: b_s, c: c_s, d: d_s}, {b: b_s, d: d_s})
     assert str(tvm.ir_pass.Simplify(res3.min_value)) == str(ans3)
 
     # tests for `EQ` op
-    res4 = tvm.arith.DeduceBound(a, a == b, {}, {})
+    res4 = tvm.arith.deduce_bound(a, a == b, {}, {})
     assert_expr_equal(res4.max_value, b)
     assert_expr_equal(res4.min_value, b)
 
     # Unsatisfiable `EQ`, variable as one of the Operand
-    res5 = tvm.arith.DeduceBound(a, (a == b), {b: b_s}, {b: b_s})
+    res5 = tvm.arith.deduce_bound(a, (a == b), {b: b_s}, {b: b_s})
     assert str(res5.max_value) == "neg_inf"
     assert str(res5.min_value) == "pos_inf"
 
     # variable `a` on the RHS side
-    res6 = tvm.arith.DeduceBound(a, 10 == a, {}, {})
+    res6 = tvm.arith.deduce_bound(a, 10 == a, {}, {})
     assert_expr_equal(res6.max_value, 10)
     assert_expr_equal(res6.min_value, 10)
 
     # Add, Sub in `EQ`
     e4 = ((a - c) == (b + d))
     ans4 = (b + d + c)
-    res7 = tvm.arith.DeduceBound(a, e4, {b: b_s, c: c_s, d: d_s}, {})
+    res7 = tvm.arith.deduce_bound(a, e4, {b: b_s, c: c_s, d: d_s}, {})
     assert_expr_equal(res7.max_value, ans4)
     assert_expr_equal(res7.min_value, ans4)
 
     # Satisfiable Mul in `EQ` with negative sign
-    res8 = tvm.arith.DeduceBound(a, (5 * a == -10), {}, {})
+    res8 = tvm.arith.deduce_bound(a, (5 * a == -10), {}, {})
     assert_expr_equal(res8.max_value, -2)
     assert_expr_equal(res8.min_value, -2)
 
     # Unsatisfiable Mul in `EQ`
     e5 = (4 * a == b)
-    res9 = tvm.arith.DeduceBound(a, e5, {b: b_s}, {})
+    res9 = tvm.arith.deduce_bound(a, e5, {b: b_s}, {})
     assert str(res9.max_value) == "neg_inf"
     assert str(res9.min_value) == "pos_inf"
 
     # Unsatisfiable Mul in `EQ`
-    res10 = tvm.arith.DeduceBound(a, (b * a == b), {b: b_s}, {})    # simplifier is not able to prove that (b % b == 0)
+    res10 = tvm.arith.deduce_bound(a, (b * a == b), {b: b_s}, {})    # simplifier is not able to prove that (b % b == 0)
     assert str(res10.max_value) == "neg_inf"
     assert str(res10.min_value) == "pos_inf"
 
@@ -137,15 +137,15 @@ def test_check():
     d_s = tvm.arith.IntervalSet(-3, -1)
 
     # no compare operator
-    res1 = tvm.arith.DeduceBound(a, a+b, {b: b_s}, {})
+    res1 = tvm.arith.deduce_bound(a, a+b, {b: b_s}, {})
     assert res1.is_nothing()
 
     # multiple compare operators
-    res2 = tvm.arith.DeduceBound(a, (a+b>3).astype(c.dtype)>c , {b: b_s, c: c_s}, {})
+    res2 = tvm.arith.deduce_bound(a, (a+b>3).astype(c.dtype)>c , {b: b_s, c: c_s}, {})
     assert res2.is_nothing()
 
     # multiple target variable
-    res2 = tvm.arith.DeduceBound(a, a*2-a>b, {b: b_s}, {})
+    res2 = tvm.arith.deduce_bound(a, a*2-a>b, {b: b_s}, {})
     assert res2.is_nothing()
 
 def test_deduce_basic():
@@ -155,21 +155,21 @@ def test_deduce_basic():
         b_s = tvm.arith.IntervalSet(a1, a2)
         e0 = b + a*coff + 3
 
-        res1 = tvm.arith.DeduceBound(a, e0<17, {b: b_s}, {b: b_s})
+        res1 = tvm.arith.deduce_bound(a, e0<17, {b: b_s}, {b: b_s})
         [x, y] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
         assert (tvm.ir_pass.Simplify((x * coff + 3 + y) < 17)).value == 1
 
         # expression containing variable a is on rhs
-        res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32") < e0, {b: b_s}, {b: b_s})
+        res1 = tvm.arith.deduce_bound(a, tvm.const(17, "int32") < e0, {b: b_s}, {b: b_s})
         [x, y] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
         assert (tvm.ir_pass.Simplify((x * coff + 3 + y) > 17)).value == 1
 
         # expression containing variable a is on rhs
-        res1 = tvm.arith.DeduceBound(a, tvm.const(17, "int32")>= e0, {b: b_s}, {b: b_s})
+        res1 = tvm.arith.deduce_bound(a, tvm.const(17, "int32")>= e0, {b: b_s}, {b: b_s})
         [x, y] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
         assert (tvm.ir_pass.Simplify((x * coff + 3 + y) <= 17)).value == 1
 
-        res1 = tvm.arith.DeduceBound(a, e0>=17, {b: b_s}, {b: b_s})
+        res1 = tvm.arith.deduce_bound(a, e0>=17, {b: b_s}, {b: b_s})
         [x, y] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
         assert (tvm.ir_pass.Simplify((x * coff + 3 + y) >= 17)).value == 1
 
@@ -187,21 +187,21 @@ def test_deduce_complex():
         b_s = tvm.arith.IntervalSet(a1, a2)
         e0 = (b*3 + a* coff) * 4
 
-        res1 = tvm.arith.DeduceBound(a, e0<63, {b: b_s}, {b: b_s})
+        res1 = tvm.arith.deduce_bound(a, e0<63, {b: b_s}, {b: b_s})
         [t, x] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
         assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) < 63)).value == 1
 
         # expression containing variable a is on rhs
-        res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32")>= e0, {b: b_s}, {b: b_s})
+        res1 = tvm.arith.deduce_bound(a, tvm.const(63, "int32")>= e0, {b: b_s}, {b: b_s})
         [t, x] = [res1.max_value, b_s.max_value] if coff > 0 else [res1.min_value, b_s.min_value]
         assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) <= 63)).value == 1
 
-        res1 = tvm.arith.DeduceBound(a, e0>63, {b: b_s}, {b: b_s})
+        res1 = tvm.arith.deduce_bound(a, e0>63, {b: b_s}, {b: b_s})
         [t, x] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
         assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) > 63)).value == 1
 
         # expression containing variable a is on rhs
-        res1 = tvm.arith.DeduceBound(a, tvm.const(63, "int32") <= e0, {b: b_s}, {b: b_s})
+        res1 = tvm.arith.deduce_bound(a, tvm.const(63, "int32") <= e0, {b: b_s}, {b: b_s})
         [t, x] = [res1.max_value, b_s.max_value] if coff < 0 else [res1.min_value, b_s.min_value]
         assert (tvm.ir_pass.Simplify(((x*3 + t* coff) * 4) >= 63)).value == 1
 
diff --git a/tests/python/unittest/test_arith_detect_clip_bound.py b/tests/python/unittest/test_arith_detect_clip_bound.py
index 3301c24..44ae24c 100644
--- a/tests/python/unittest/test_arith_detect_clip_bound.py
+++ b/tests/python/unittest/test_arith_detect_clip_bound.py
@@ -20,14 +20,14 @@ def test_basic():
     a = tvm.var("a")
     b = tvm.var("b")
     c = tvm.var("c")
-    m = tvm.arith.DetectClipBound(tvm.all(a * 1 < b * 6,
+    m = tvm.arith.detect_clip_bound(tvm.all(a * 1 < b * 6,
                                           a - 1 > 0), [a])
     assert tvm.ir_pass.Simplify(m[1] - (b * 6 - 1)).value == 0
     assert m[0].value == 2
-    m = tvm.arith.DetectClipBound(tvm.all(a * 1 < b * 6,
+    m = tvm.arith.detect_clip_bound(tvm.all(a * 1 < b * 6,
                                           a - 1 > 0), [a, b])
     assert len(m) == 0
-    m = tvm.arith.DetectClipBound(tvm.all(a + 10 * c <= 20,
+    m = tvm.arith.detect_clip_bound(tvm.all(a + 10 * c <= 20,
                                           b - 1 > 0), [a, b])
     assert tvm.ir_pass.Simplify(m[1] - (20 - 10 * c)).value == 0
     assert tvm.ir_pass.Simplify(m[2] - 2).value == 0
diff --git a/tests/python/unittest/test_arith_detect_linear_equation.py b/tests/python/unittest/test_arith_detect_linear_equation.py
index cacb624..3b10302 100644
--- a/tests/python/unittest/test_arith_detect_linear_equation.py
+++ b/tests/python/unittest/test_arith_detect_linear_equation.py
@@ -19,50 +19,50 @@ import tvm
 def test_basic():
     a = tvm.var("a")
     b = tvm.var("b")
-    m = tvm.arith.DetectLinearEquation(a * 4 + b * 6 + 7, [a])
+    m = tvm.arith.detect_linear_equation(a * 4 + b * 6 + 7, [a])
     assert m[0].value == 4
     assert tvm.ir_pass.Simplify(m[1] - (b * 6 + 7)).value == 0
 
-    m = tvm.arith.DetectLinearEquation(a * 4 * (a+1) + b * 6 + 7, [a])
+    m = tvm.arith.detect_linear_equation(a * 4 * (a+1) + b * 6 + 7, [a])
     assert len(m) == 0
 
-    m = tvm.arith.DetectLinearEquation(a * 4  + (a+1) + b * 6 + 7, [a])
+    m = tvm.arith.detect_linear_equation(a * 4  + (a+1) + b * 6 + 7, [a])
     assert m[0].value == 5
     assert tvm.ir_pass.Simplify(m[1] - (b * 6 + 7 + 1)).value == 0
 
-    m = tvm.arith.DetectLinearEquation(a * b + 7, [a])
+    m = tvm.arith.detect_linear_equation(a * b + 7, [a])
     assert m[0] == b
 
-    m = tvm.arith.DetectLinearEquation(b * 7, [a])
+    m = tvm.arith.detect_linear_equation(b * 7, [a])
     assert m[0].value == 0
 
-    m = tvm.arith.DetectLinearEquation(b * 7, [])
+    m = tvm.arith.detect_linear_equation(b * 7, [])
     assert len(m) == 1
     assert tvm.ir_pass.Simplify(m[0] - b * 7).value == 0
 
 def test_multivariate():
     v = [tvm.var("v%d" % i) for i in range(4)]
     b = tvm.var("b")
-    m = tvm.arith.DetectLinearEquation(v[0] * (b + 4) + v[0] + v[1] * 8, v)
+    m = tvm.arith.detect_linear_equation(v[0] * (b + 4) + v[0] + v[1] * 8, v)
     assert(tvm.ir_pass.Equal(tvm.ir_pass.Simplify(m[0]), b + 5))
     assert(m[1].value == 8)
 
-    m = tvm.arith.DetectLinearEquation(v[0] * (b + 4) + v[0] + v[1] * 8 * v[2], v)
+    m = tvm.arith.detect_linear_equation(v[0] * (b + 4) + v[0] + v[1] * 8 * v[2], v)
     assert(len(m) == 0)
 
-    m = tvm.arith.DetectLinearEquation(v[0] * (b + 4) + v[0] + v[1] * 8 * v[1] + v[3], v)
+    m = tvm.arith.detect_linear_equation(v[0] * (b + 4) + v[0] + v[1] * 8 * v[1] + v[3], v)
     assert(len(m) == 0)
 
-    m = tvm.arith.DetectLinearEquation(((v[0] * b + v[1]) * 8 + v[2] + 1) * 2, v)
+    m = tvm.arith.detect_linear_equation(((v[0] * b + v[1]) * 8 + v[2] + 1) * 2, v)
     assert(m[1].value == 16)
     assert(m[2].value == 2)
     assert(m[len(m)-1].value == 2)
 
-    m = tvm.arith.DetectLinearEquation((v[0] - v[1]), [v[2]])
+    m = tvm.arith.detect_linear_equation((v[0] - v[1]), [v[2]])
     assert(m[0].value == 0)
     assert(tvm.ir_pass.Simplify(m[1] - (v[0] - v[1])).value == 0)
 
-    m = tvm.arith.DetectLinearEquation((v[0] - v[1]), [])
+    m = tvm.arith.detect_linear_equation((v[0] - v[1]), [])
     assert(len(m) == 1)
     assert(tvm.ir_pass.Simplify(m[0] - (v[0] - v[1])).value == 0)
 
diff --git a/tests/python/unittest/test_arith_domain_touched.py b/tests/python/unittest/test_arith_domain_touched.py
index 3e45d4e..7876fb6 100644
--- a/tests/python/unittest/test_arith_domain_touched.py
+++ b/tests/python/unittest/test_arith_domain_touched.py
@@ -35,19 +35,19 @@ def test_domain_touched():
                 )
             )
     )
-    a_domain_r = tvm.arith.DomainTouched(ir, a, True, False)
+    a_domain_r = tvm.arith._ffi_api.DomainTouched(ir, a, True, False)
     assert a_domain_r[0].min.value == -1
     assert a_domain_r[0].extent.value == 100
     assert a_domain_r[1].min.value == -1
     assert a_domain_r[1].extent.name == 'm'
 
-    a_domain_w = tvm.arith.DomainTouched(ir, a, False, True)
+    a_domain_w = tvm.arith._ffi_api.DomainTouched(ir, a, False, True)
     assert a_domain_w[0].min.value == 0
     assert a_domain_w[0].extent.value == 100
     assert a_domain_w[1].min.value == 0
     assert a_domain_w[1].extent.name == 'm'
 
-    a_domain_rw= tvm.arith.DomainTouched(ir, a, True, True)
+    a_domain_rw= tvm.arith._ffi_api.DomainTouched(ir, a, True, True)
     assert a_domain_rw[0].min.value == -1
     assert a_domain_rw[0].extent.value == 101
     assert a_domain_rw[1].min.value == -1
@@ -55,17 +55,16 @@ def test_domain_touched():
     assert a_domain_rw[1].extent.a.name == 'm'
     assert a_domain_rw[1].extent.b.value == 1
 
-    b_domain_r = tvm.arith.DomainTouched(ir, b, True, False)
+    b_domain_r = tvm.arith._ffi_api.DomainTouched(ir, b, True, False)
     assert b_domain_r
     assert b_domain_r[0].min.value == -1
     assert b_domain_r[0].extent.value == 100
     assert b_domain_r[1].min.value == 1
     assert b_domain_r[1].extent.name == 'm'
 
-    b_domain_w = tvm.arith.DomainTouched(ir, b, False, True)
+    b_domain_w = tvm.arith._ffi_api.DomainTouched(ir, b, False, True)
     assert isinstance(b_domain_w, tvm.container.Array)
     assert len(b_domain_w) == 0
 
 if __name__ == "__main__":
     test_domain_touched()
-
diff --git a/tests/python/unittest/test_arith_intset.py 
b/tests/python/unittest/test_arith_intset.py
index d83d33d..dad2fa7 100644
--- a/tests/python/unittest/test_arith_intset.py
+++ b/tests/python/unittest/test_arith_intset.py
@@ -36,12 +36,16 @@ def test_basic():
     assert s.min_value.value == 2
     assert s.max_value.value == 3
 
+    s = tvm.arith.IntSet.single_point(2)
+    assert s.min_value.value == 2
+    assert s.max_value.value == 2
+
 
 def test_vector():
     base = 10
     stride = 3
     lanes = 2
-    s = tvm.arith.intset_vector(tvm.tir.Ramp(base, stride, lanes))
+    s = tvm.arith.IntSet.vector(tvm.tir.Ramp(base, stride, lanes))
     assert s.min_value.value == base
     assert s.max_value.value == base + stride * lanes - 1
 
diff --git a/vta/python/vta/ir_pass.py b/vta/python/vta/ir_pass.py
index 8b8a2f0..36d8e41 100644
--- a/vta/python/vta/ir_pass.py
+++ b/vta/python/vta/ir_pass.py
@@ -76,7 +76,7 @@ def fold_uop_loop(stmt_in):
                 args = []
                 args += op.args[:base_args]
                 for i in range(3):
-                    m = tvm.arith.DetectLinearEquation(
+                    m = tvm.arith.detect_linear_equation(
                         op.args[i + base_args], [loop_var])
                     if not m:
                         fail[0] = True
@@ -867,25 +867,25 @@ def inject_alu_intrin(stmt_in):
                         type(loop_body.value), str(loop_body.value), 
str(stmt)))
 
             # Derive array index coefficients
-            dst_coeff = tvm.arith.DetectLinearEquation(dst_idx, indices)
+            dst_coeff = tvm.arith.detect_linear_equation(dst_idx, indices)
             # Check if lhs/rhs is immediate
             use_imm = False
             imm_val = None
             if isinstance(rhs, tvm.tir.IntImm):
                 assert lhs.buffer_var.same_as(dst_var)
-                src_coeff = tvm.arith.DetectLinearEquation(lhs.index, indices)
+                src_coeff = tvm.arith.detect_linear_equation(lhs.index, 
indices)
                 use_imm = True
                 imm_val = rhs
             if isinstance(lhs, tvm.tir.IntImm):
                 assert rhs.buffer_var.same_as(dst_var)
-                src_coeff = tvm.arith.DetectLinearEquation(rhs.index, indices)
+                src_coeff = tvm.arith.detect_linear_equation(rhs.index, 
indices)
                 use_imm = True
                 imm_val = lhs
             if imm_val is None:
                 imm_val = 0
                 assert lhs.buffer_var.same_as(dst_var) and 
rhs.buffer_var.same_as(dst_var)
-                src_lhs_coeff = tvm.arith.DetectLinearEquation(lhs.index, 
indices)
-                src_rhs_coeff = tvm.arith.DetectLinearEquation(rhs.index, 
indices)
+                src_lhs_coeff = tvm.arith.detect_linear_equation(lhs.index, 
indices)
+                src_rhs_coeff = tvm.arith.detect_linear_equation(rhs.index, 
indices)
                 # Determine which side has the same coefficients
                 lhs_equal = True
                 rhs_equal = True

Reply via email to