jroesch commented on a change in pull request #5144:
URL: https://github.com/apache/incubator-tvm/pull/5144#discussion_r424823515



##########
File path: python/tvm/relay/transform/memory_plan.py
##########
@@ -0,0 +1,353 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=no-else-return,invalid-name,len-as-condition,too-many-nested-blocks
+"""
+A pass for manifesting explicit memory allocations.
+"""
+from typing import Optional, Dict, List, Tuple
+from collections import defaultdict
+import attr
+
+from ..expr_functor import ExprMutator
+from .. import op, expr
+from ..function import Function
+from ... import register_func, ir, cpu
+from ..._ffi.runtime_ctypes import TVMContext
+from ... import IRModule
+from .. import transform
+from . import function_pass
+
+
+def is_primitive(call):
+    return (
+        hasattr(call, "op")
+        and hasattr(call.op, "attrs")
+        and hasattr(call.op.attrs, "Primitive")
+        and int(call.op.attrs.Primitive) == 1
+    )
+
+
[email protected](auto_attribs=True)
+class Region:
+    """
+    Represents a control-free allocation region.
+
+    The below pass groups sets of allocations into regions,
+    then replaces the region with a single allocation.
+    """
+    var: expr.Var
+    size: expr.Expr
+    alignment: Optional[expr.Expr]
+    dtype: Optional[str]
+    ctx: TVMContext
+    offsets: Dict[expr.Var, Tuple[expr.Expr, expr.Expr]]
+
+    @staticmethod
+    def empty(region_no):
+        zero = expr.const(0, dtype="int64")
+        assert len(zero.data.shape) == 0
+        region_var = expr.var(f"region{region_no}")
+        return Region(region_var, zero, None, None, None, {})
+
+    def grow(
+            self, old_storage: expr.Var,
+            size: expr.Expr, alignment: expr.Expr,
+            ctx: TVMContext,
+            dtype: str) -> None:
+        """Grow the region by a given allocation as well as track the old 
storage
+           for later rewriting the program to use the allocated region.
+        """
+        if self.dtype:
+            assert self.dtype == dtype, "must have matching dtypes in a region"
+        else:
+            self.dtype = dtype
+
+        if self.alignment:
+            assert ir.structural_equal(
+                self.alignment, alignment
+            ), "must have matching alignments in a region"
+        else:
+            self.alignment = alignment
+
+        if self.ctx:
+            assert (self.ctx.device_type == ctx.device_type and
+                    self.ctx.device_id == ctx.device_id), "must have matching 
context"
+        else:
+            assert ctx
+            self.ctx = ctx
+
+        new_size = (size + self.alignment - expr.const(1, "int64")) \
+            / self.alignment * self.alignment
+
+        # Record the offset at which we allocate the storage.
+        offset_var: expr.RelayExpr = expr.var(f"offset{len(self.offsets)}")
+        self.offsets[old_storage] = (offset_var, self.size)
+
+        self.size = self.size + new_size
+
+    def offset_for(self, alloc: expr.Expr) -> expr.Expr:
+        return self.offsets.get(alloc, [None])[0]
+
+    def to_expr(self, body: expr.Expr) -> expr.Expr:
+        """
+        Generate the prelude code for a region, wrapping the body in it.
+
+        The prelude contains the single allocation for a region, and
+        all offset computations.
+        """
+
+        if self.ctx is None:
+            self.ctx = cpu(0)
+
+        # Generate bindings for each and every size computation
+        # we must do this to maintain ANF.
+        bindings: List[Tuple[expr.Expr, expr.Expr]] = []
+
+        # First compute the total size.
+        total_size = expr.var(f"total_size{hash(body)}")
+        bindings.append((total_size, self.size))
+
+        # Allocate the entire region with a single call.
+        alloc = op.memory.alloc_storage(total_size, self.alignment, self.ctx, 
self.dtype)
+        bindings.append((self.var, alloc))
+
+        # Generate variables which contain all of the offset math.
+        # Ensure we constant evaluate away all the math here.
+        #
+        # In theory we can support dynamic offsets but this
+        # requires another round of memory planning and
+        # potentially colaescing.
+        for alloc in self.offsets:
+            (var, offset) = self.offsets[alloc]
+            bindings.append((var, offset))
+
+        body = mk_let(bindings, body)
+        return body
+
+
+def iterative_let(let, each_binding, kont):
+    bindings = []
+    while isinstance(let, expr.Let):
+        lhs = let.var
+        rhs = let.value
+        bindings.append(each_binding(lhs, rhs))
+        let = let.body
+
+    return kont(bindings, let)
+
+
+
+def mk_let(bindings, body):
+    for var, value in reversed(bindings):
+        assert var
+        assert value
+        assert body
+        body = expr.Let(var, value, body)
+
+    return body
+
+def const_eval(mod, exp):
+    mod = IRModule.from_expr(exp, type_defs=mod.type_definitions)
+    mod = transform.FoldConstant()(mod)
+    return mod["main"]
+
+class StorageCoalesce(ExprMutator):
+    """
+    A pass for coalescing allocations into region/arena allocations.
+
+    After this pass each allocation comes from the same backing storage,
+    but will never overlap even in time, i.e. the allocations are just
+    packed into a contiguous block of memory.
+
+    A secondary part of memory planning will perform liveness analysis to
+    overlap these in time, i.e when an early tensor dies we will attempt
+    to reuse its slot.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.regions = []
+
+    def enter_scope(self) -> None:
+        region_no = len(self.regions)
+        self.regions.append(defaultdict(lambda: Region.empty(region_no)))
+
+    def exit_scope(self, body: expr.Expr) -> expr.Expr:
+        """When leaving a scope build a region allocation for the scope."""
+        dtype_region = self.regions.pop()
+        for _, region in reversed(list(dtype_region.items())):
+            if len(region.offsets) != 0:
+                body = region.to_expr(body)
+
+        return body
+
+    def current_region(self, dtype) -> Region:
+        current_scope = self.regions[-1]
+        return current_scope[dtype]
+
+    def new_region_and_offset(self, old_storage):
+        for dtype_region in reversed(self.regions):
+            for dtype in dtype_region:
+                region = dtype_region[dtype]
+                offset = region.offset_for(old_storage)
+                if offset:
+                    return region, offset
+
+        raise Exception("could not find offset in any valid region")
+
+    def visit_function(self, fn):
+        """Transform the function body to use region allocation scheme."""
+        func = fn
+        if getattr(func.attrs, "Primitive", 0) == 1:
+            return super().visit_function(func)
+        else:
+            self.enter_scope()
+            body = self.visit(func.body)
+            body = self.exit_scope(body)
+            return Function(
+                func.params,
+                body,
+                func.ret_type,
+                func.type_params,
+                func.attrs,
+            )
+
+    def visit_if(self, ite):
+        self.enter_scope()
+        true_branch = self.visit(ite.true_branch)
+        true_branch = self.exit_scope(true_branch)
+
+        self.enter_scope()
+        false_branch = self.visit(ite.false_branch)
+        false_branch = self.exit_scope(false_branch)
+
+        return expr.If(ite.cond, true_branch, false_branch)
+
+
+    def mk_let(self, dynamic_regions):
+        """Let bind the dynamic regions"""
+        def _mk_let(bindings, body):
+            for var, value in reversed(bindings):
+                assert var
+                assert value
+                assert body
+                body = expr.Let(var, value, body)
+                if var in dynamic_regions:
+                    body = self.exit_scope(body)
+
+            return body
+
+        return _mk_let
+
+    def visit_let(self, let):
+        dynamic_regions = []
+        def _each_binding(lhs, rhs):
+            if isinstance(rhs, expr.Call) and rhs.op == op.op.get(
+                    "memory.alloc_storage"
+            ):
+                return self.process_alloc_storage(dynamic_regions, lhs, rhs)
+            elif isinstance(rhs, expr.Call) and rhs.op == op.op.get(
+                    "memory.alloc_tensor"
+            ):
+                return self.process_alloc_tensor(lhs, rhs)
+            else:
+                return lhs, rhs
+
+        result = iterative_let(let, _each_binding, 
self.mk_let(dynamic_regions))
+        assert result
+        return result
+
+    def process_alloc_storage(self, dynamic_regions, lhs, call):
+        """Process alloc_storage"""
+        size, alignment = call.args
+        dtype = call.attrs.dtype
+        ctx = TVMContext(call.attrs.device_type, call.attrs.device_id)
+
+        if not isinstance(size, expr.Constant):
+            self.enter_scope()
+            dynamic_regions.append(lhs)
+
+        region = self.current_region(dtype)
+        region.grow(lhs, size, alignment, ctx, dtype)
+        return lhs, region.var
+
+    def process_alloc_tensor(self, lhs, call):
+        """Process alloc tensor. Region and offset are computed"""
+        storage, old_offset, shape = call.args
+        region, offset = self.new_region_and_offset(storage)
+
+        assert (
+            old_offset.data.asnumpy().item() == 0
+        ), "no offsets should yet be allocated"
+        return (
+            lhs,
+            expr.Call(call.op, [region.var, offset, shape], call.attrs),
+        )
+
+class LiftConst(ExprMutator):
+    """A internal pass to lift constants to the top level of function."""
+    def __init__(self):
+        self.i = 0
+        self.constants = []
+        self.top_level = True
+        super().__init__()
+
+    def visit_constant(self, const):
+        var = expr.var(f"const{self.i}")
+        self.i += 1
+        self.constants.append((var, const))
+        return var
+
+    def visit_function(self, fn):
+        if int(getattr(fn.attrs, "Primitive", 0)) == 1:
+            return fn
+
+        outer_constant = self.constants
+        self.constants = []
+        body = mk_let(self.constants, self.visit(fn.body))

Review comment:
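       Oh, this is relying on a very subtle side effect: `self.constants` is evaluated first as the argument to `mk_let`. The `self.visit(fn.body)` call on the right then repopulates that same list before `mk_let` reads it.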




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org


Reply via email to