ekalda commented on a change in pull request #10344:
URL: https://github.com/apache/tvm/pull/10344#discussion_r817652817



##########
File path: python/tvm/relay/backend/contrib/ethosu/tir_to_cs_translator.py
##########
@@ -398,11 +398,11 @@ def assign_addresses(buffer_info, npu_ops, 
scratch_region_map):
 
     def replace_npu_fm_with_address(npu_fm):
         assert isinstance(npu_fm.tiles.addresses[0], tvm.tir.Load)
-        # We currently does not support tiles
-        # Change this when tiles are needed
-        # (i.e. when using rolling buffers)
-        assert npu_fm.tiles.addresses[1:] == [0, 0, 0]
-        npu_fm.tiles.addresses[1:] = [0, 0, 0]
+        for i in range(1, 4):
+            address = npu_fm.tiles.addresses[i]
+            if isinstance(address, tvm.tir.expr.Load):
+                address = address.index
+            npu_fm.tiles.addresses[i] = int(address)

Review comment:
       Ok cool... I'm still a bit confused about what is going on in that change: 
first it converts `IntImm` into `int` in that block, using the addresses already 
in `npu_fm.tiles.addresses`, but then at the end of that function it overwrites 
the middle two addresses with `address`. What's the reason for that 
overwriting at the end of the function? Maybe a comment would help there (can 
be done in a follow-up though). 

##########
File path: python/tvm/relay/backend/contrib/ethosu/tir/transform.py
##########
@@ -21,19 +21,16 @@
 from .utils import get_base_address, get_op_attrs
 
 
-def get_copy_params(stmt, producers, consumers):
+def get_copy_params(stmt, producers_consumers):

Review comment:
       Ah yes, that's right...

##########
File path: python/tvm/relay/backend/contrib/ethosu/tir/dma.py
##########
@@ -287,31 +321,69 @@ def get_ifm_params(pointer, producers):
         The serializable padding.
 
     """
-    pad = producers[pointer]
+    pad = producers_consumers.get_producer(pointer, stmt)
     serial_padding, input_pointer, _ = get_pad_params(pad)
-    upscale = producers[input_pointer]
+    upscale = producers_consumers.get_producer(input_pointer, pad)
     input_pointer, _ = get_upscale_params(upscale)
-    convert_to_nhwc = producers[input_pointer]
+    convert_to_nhwc = producers_consumers.get_producer(input_pointer, upscale)
     in_channels, input_pointer, _ = get_convert_to_nhwc_params(convert_to_nhwc)
-    read = producers[input_pointer]
+    read = producers_consumers.get_producer(input_pointer, convert_to_nhwc)
     serial_ifm, _, _ = get_read_params(read)
     serial_ifm.channels = in_channels
+
+    floor_mod_stmt = None
+    for_stmt = None
+
+    def _get_buffer_var(stmt):
+        nonlocal for_stmt
+        nonlocal floor_mod_stmt
+        if isinstance(stmt, tvm.tir.For):
+            for_stmt = stmt
+        if isinstance(stmt, tvm.tir.FloorMod):
+            floor_mod_stmt = stmt
+
+    tvm.tir.stmt_functor.post_order_visit(stmt, _get_buffer_var)
+
+    if floor_mod_stmt is not None:
+        layout = get_op_attrs(read)[0]["layout"]
+        channels = serial_ifm.channels
+        if for_stmt.body.loop_var == floor_mod_stmt.a.a.a:

Review comment:
       Ok cool, that makes sense! :) 




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to