mbaret commented on a change in pull request #7925:
URL: https://github.com/apache/tvm/pull/7925#discussion_r648128260



##########
File path: python/tvm/tir/transform/inject_rolling_buffer.py
##########
@@ -0,0 +1,238 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Inject rolling buffers through a TIR transformation."""
+# pylint: disable=invalid-name,unused-argument,inconsistent-return-statements
+from collections import defaultdict, namedtuple
+import math
+
+import tvm
+from tvm import arith
+
+
+def InjectRollingBuffer():
+    """Inject rolling buffer statements.
+
+    Rolling buffers are buffers where one of the dimensions has been made into
+    a circular buffer. Two optimizations are implemented in order to accomplish
+    this: sliding window and storage folding. In particular, the sliding window
+    optimization is applied to the entire buffer (to avoid recomputing 
elements)
+    and storage folding is then applied to just the rolling dimension.
+
+    Rolling buffers must be inside a loop with only part of the buffer used per
+    iteration. The outermost axis will be rolled over.
+
+    For more information, see the RFC:
+    
https://discuss.tvm.apache.org/t/rfc-introducing-a-rolling-buffer-scheduling-primitive/9836
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The pass
+    """
+    buffer_to_attrs = defaultdict(list)
+    rolling_buffers = set()
+    rolling_buffer_to_info = dict()
+    iter_vars = list()
+    hoist_buffer_to_for = defaultdict(list)
+
+    RollingBufferInfo = namedtuple(
+        "RollingBufferInfo", ["rolling_axis", "rolling_extent", 
"axis_overlaps", "axis_iter_vars"]
+    )
+
+    def _pre_visit(stmt):
+        if isinstance(stmt, tvm.tir.For):
+            # Manage the stack of iter_vars
+            iter_vars.append(stmt)
+
+        elif isinstance(stmt, tvm.tir.AttrStmt):
+            if isinstance(stmt.node, tvm.tir.Buffer):
+                if stmt.attr_key == "rolling_buffer_scope" and 
stmt.value.value:
+                    # If the attribute is indicating that a buffer should be a 
rolling
+                    # buffer, then update the rolling_buffers set to include 
the bufffer
+                    rolling_buffers.add(stmt.node)
+                # Keep a dictionary associating attribute statements with the 
buffers
+                # they reference. We'll need this if the buffer gets hoisted 
and we
+                # need to hoist all of its attributes at the same time.
+                buffer_to_attrs[stmt.node].append(stmt)
+
+        elif isinstance(stmt, tvm.tir.BufferRealize):
+            if stmt.buffer in rolling_buffers:
+                # If a BufferRealize has been identified as needing to be made 
into
+                # a rolling buffer, begin the analysis...
+                bound_iter_vars = []
+                bound_overlaps = []
+                # We use the bound information of the BufferRealize to 
calculate
+                # how we can legally roll
+                for bound in stmt.bounds:
+                    divisor = 1
+                    # Handle the case of fractional strides
+                    # They take this form: floordiv(hh.outer, 2)
+                    # Strip the floordiv and keep track of the divisor
+                    if isinstance(bound.min, tvm.tir.FloorDiv):
+                        divisor = bound.min.b.value
+                        bound.min = bound.min.a
+                    # If the bound is an int, we can't roll over it
+                    if isinstance(bound.min, tvm.tir.IntImm):
+                        iter_var = None
+                        stride = 0
+                    # If the bound is just a Var, that implies the stride is 1
+                    elif isinstance(bound.min, tvm.tir.Var):
+                        iter_var = bound.min
+                        stride = 1
+                    # Otherwise, it's the iter var multiplied by the stride
+                    # If not we're in unknown behaviour, so assert
+                    else:
+                        assert isinstance(
+                            bound.min, tvm.tir.Mul
+                        ), "Rolling buffer injection failed: the buffer 
striding is unsupported"
+                        assert isinstance(
+                            bound.min.a, tvm.tir.Var
+                        ), "Rolling buffer injection failed: the buffer 
striding is unsupported"
+                        assert isinstance(
+                            bound.min.b, tvm.tir.IntImm
+                        ), "Rolling buffer injection failed: the buffer 
striding is unsupported"
+                        iter_var = bound.min.a
+                        stride = bound.min.b.value
+                    stride = math.ceil(stride / divisor)
+                    bound_iter_vars.append(iter_var)
+                    if iter_var is not None:
+                        bound_overlaps.append(bound.extent.value - stride)
+                    else:
+                        bound_overlaps.append(0)
+
+                # Pick the outermost iter_var that's mentioned in the bounds
+                # to be the rolling axis
+                roll_iter_var = None
+                roll_axis = -1
+                for loop in iter_vars:
+                    iter_var = loop.loop_var
+                    if iter_var in bound_iter_vars:

Review comment:
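       This is because the loop nest doesn't necessarily iterate over a tensor in
the same order as its bounds (e.g. the loops don't have to go axis 0, 1, 2, ...).
That's why we walk the enclosing loops (`iter_vars`) rather than the bounds: the
first loop variable found that appears in the bounds is the outermost one, and
that becomes the rolling axis.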




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to