adstraw commented on a change in pull request #9287:
URL: https://github.com/apache/tvm/pull/9287#discussion_r732307896



##########
File path: tests/python/contrib/test_hexagon/README.md
##########
@@ -118,173 +128,220 @@ primfn(input_handle: handle, kernel_handle: handle, 
output_handle: handle) -> ()
             for (wi.c: int32, 0, 8) {
               for (ki.c: int32, 0, 32) {
                 for (rc.inner: int32, 0, 32) {
-                  output.cache[(((((wo.c*4096) + (ko.c*2048)) + (hi.c*256)) + 
(wi.c*32)) + ki.c)] = 
+                  output.cache[((((wo.c*2048) + (hi.c*256)) + (wi.c*32)) + 
ki.c)] = 
                   (
-                    (float32*)output.cache[(((((wo.c*4096) + (ko.c*2048)) + 
(hi.c*256)) + (wi.c*32)) + ki.c)] + 
+                    (float32*)output.cache[((((wo.c*2048) + (hi.c*256)) + 
(wi.c*32)) + ki.c)] + 
                     (
                       (float32*)input.cache[(((((wo.c*4096) + (rc.outer*2048)) 
+ (hi.c*256)) + (wi.c*32)) + rc.inner)] *
-                      (float32*)kernel_pointer[(((((ko.c*2048) + 
(rc.outer*1024)) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + 
floormod(rc.inner, 4))]
+                      (float32*)filter.cache[((((rc.outer*1024) + 
(floordiv(rc.inner, 4)*128)) + (ki.c*4)) + floormod(rc.inner, 4))]
                     )
                   )
                 }
               }
             }
           }
-        } // end rc.outer
-      } // end ko.c
-    } // end wo.c
+        }
+      } // end wo.c
 
-    // cache write
-    for (wo: int32, 0, 8) {
-      for (ko: int32, 0, 2) {
+      // cache write
+      for (wo: int32, 0, 8) {
         for (hi: int32, 0, 8) {
           for (wi: int32, 0, 8) {
             for (ki: int32, 0, 32) {
-              output_pointer[((((((ho.outer*32768) + (wo*4096)) + (ko*2048)) + 
(hi*256)) + (wi*32)) + ki)] = 
-                (float32*)output.cache[(((((wo*4096) + (ko*2048)) + (hi*256)) 
+ (wi*32)) + ki)]
+              output_pointer[((((((ho.outer*65536) + (wo*8192)) + 
(ko.outer*2048)) + (hi*256)) + (wi*32)) + ki)] = 
+                (float32*)output.cache[((((wo*2048) + (hi*256)) + (wi*32)) + 
ki)]
             }
           }
         }
       }
-    }
-  }
+    } // end ho.outer
+  } // end ko.outer
 }
 ```
 
-# Split on Height - "Full Output Slice"
+# Split on Channel Out and Height - "Full Output Slice"
 
-Adds a new parameter `h_split` which creates a loop split on the height `h` 
dimension.  The cache reads and writes are moved to the outer of the two loops 
created by that split - the loop over `ho.outer`.  This increases cache usage 
by a factor equivalent to `h_split`.  The compute is still "full width" and 
"full depth" in the channel-out dimension and now over multiple slices in the 
height `h` dimension.  
+Adds new parameters `k_split` and `h_split`, which create loop splits on the outer channel-out `ko` and height `ho` loops, producing `outer` and `inner` loops for each split.  The cache reads and writes are computed at `ho.outer`, which means that the cache allocations grow in proportion to the `k_split` and `h_split` factors.
 
-The key changes in TIR versus the baseline are ...
+The key changes in TIR versus the above are...
 
 1) Increased cache allocations:
 
 ```
+  // input cache grows by factor of h_split = 2
   allocate(input.cache: Pointer(global float32), float32, [65536]), 
storage_scope = global;
+
+  // filter cache grows by factor of k_split = 2
+  allocate(filter.cache: Pointer(global float32), float32, [4096]), 
storage_scope = global;
+
+  // output cache grows by factor of h_split * k_split = 4
   allocate(output.cache: Pointer(global float32), float32, [65536]), 
storage_scope = global;
 ```
 
-2) The loop split on the `h` dimension:
+2) Outer loop splits using the `k_split` and `h_split` factors:
 
 ```
-  for (ho.outer: int32, 0, 4) {
-    for (ho.inner: int32, 0, 2) {
+  // ko.outer = outer loop split on ko using k_split factor
+  for (ko.outer: int32, 0, 2) {
+    // ho.outer = outer loop split on ho using h_split factor
+    for (ho.outer: int32, 0, 4) {
+```
+
+3) Inner loop splits in both the cache read/write and compute schedules.  For example, from the compute schedule:
+```
+      for (ko.c.inner: int32, 0, 2) {
+        for (ho.c.inner: int32, 0, 2) {
 ```
 
 ## Command
 
-pytest -sv 
"tests/python/contrib/test_hexagon/test_conv2d_blocked.py::TestConv2dPackedFilter::test_conv2d[conv2d_packed_filter-1-1-0-float32-2-1-64-64-64-llvm]"
+pytest -sv 
"tests/python/contrib/test_hexagon/test_conv2d_blocked.py::TestConv2dPackedFilter::test_conv2d[conv2d_packed_filter-1-1-0-float32-2-2-1-64-64-128-llvm]"
 
 ## Parameters
 
 | Parameter | Value       |
 | --------- | ----------- |
 | Batch     | 1           |
-| Kernel    | 1x1         |
+| Filter    | 1x1         |
 | Spatial   | 64x64       |
 | Input Ch  | 64          |
-| Output Ch | 64          |
+| Output Ch | 128         |
 | Stride    | 1           |
 | Padding   | 0           |
 | Layout    | NHWC8h8w32c |
+| k_split   | 2           |
 | h_split   | 2           |
 
 ## Assumptions
 
-Same as baseline
+* n/a - With the loop splits on `ko` and `ho`, the compute schedule is now over `ko.inner`, `ho.inner`, `wo`, etc.  This should fit the pattern matching for microkernels.
 
 ## To Do
 
-Same as baseline
+* n/a
 
 ## Annotated TIR
 
 ```
-primfn(input_handle: handle, kernel_handle: handle, output_handle: handle) -> 
()
+primfn(input_handle: handle, filter_handle: handle, output_handle: handle) -> 
()
   attr = {"from_legacy_te_schedule": True, "global_symbol": 
"default_function", "tir.noalias": True, "target": meta[Target][0]}
-  buffers = {output_buffer: Buffer(output_pointer: Pointer(float32), float32, 
[1, 8, 8, 2, 8, 8, 32], []),
-             kernel_buffer: Buffer(kernel_pointer: Pointer(float32), float32, 
[2, 2, 1, 1, 8, 32, 4], []),
-             input_buffer: Buffer(input_pointer: Pointer(float32), float32, 
[1, 64, 64, 64], [])}
-  buffer_map = {input_handle: input_buffer, kernel_handle: kernel_buffer, 
output_handle: output_buffer} {
-  
-  // increased cache usage due to h_split parameter
+  buffers = {output_buffer: Buffer(output_pointer: Pointer(float32), float32, 
[1, 8, 8, 4, 8, 8, 32], []), // NHWC8h8w32c
+             filter_buffer: Buffer(filter_pointer: Pointer(float32), float32, 
[4, 2, 1, 1, 8, 32, 4], []), // OIHW8i32o4i
+             input_buffer: Buffer(input_pointer: Pointer(float32), float32, 
[1, 64, 64, 64], [])} // NHWC (pending RFC)
+  buffer_map = {input_handle: input_buffer, filter_handle: filter_buffer, 
output_handle: output_buffer} {
+
+  // input cache grows by factor of h_split = 2
   allocate(input.cache: Pointer(global float32), float32, [65536]), 
storage_scope = global;
+
+  // filter cache grows by factor of k_split = 2
+  allocate(filter.cache: Pointer(global float32), float32, [4096]), 
storage_scope = global;
+
+  // output cache grows by factor of h_split * k_split = 4
   allocate(output.cache: Pointer(global float32), float32, [65536]), 
storage_scope = global;
+  
+  // ko.outer = outer loop split on ko using k_split factor
+  for (ko.outer: int32, 0, 2) {
+    // ho.outer = outer loop split on ho using h_split factor
+    for (ho.outer: int32, 0, 4) {
+
+      // input cache read
+      // NHWC -> NHWC8h8w32c (pending RFC)
+      for (ho.inner: int32, 0, 2) {
+        for (wo: int32, 0, 8) {
+          for (co: int32, 0, 2) {
+            for (hi: int32, 0, 8) {
+              for (wi: int32, 0, 8) {
+                for (ci: int32, 0, 32) {
+                  input.cache[((((((ho.inner*32768) + (wo*4096)) + (co*2048)) 
+ (hi*256)) + (wi*32)) + ci)] = 
+                    (float32*)input_pointer[(((((((ho.outer*65536) + 
(ho.inner*32768)) + (hi*4096)) + (wo*512)) + (wi*64)) + (co*32)) + ci)]
+                }
+              }
+            }
+          }
+        }
+      } // end ho.inner
 
-  // loop split ho.outer vs. ho.inner based on h_split parameter
-  for (ho.outer: int32, 0, 4) {
-    for (ho.inner: int32, 0, 2) {
-      for (wo: int32, 0, 8) {
+      // filter cache read
+      for (ko.inner: int32, 0, 2) {
         for (co: int32, 0, 2) {
-          for (hi: int32, 0, 8) {
-            for (wi: int32, 0, 8) {
-              for (ci: int32, 0, 32) {
-                input.cache[((((((ho.inner*32768) + (wo*4096)) + (co*2048)) + 
(hi*256)) + (wi*32)) + ci)] = 
-                  (float32*)input_pointer[(((((((ho.outer*65536) + 
(ho.inner*32768)) + (hi*4096)) + (wo*512)) + (wi*64)) + (co*32)) + ci)]
+          for (ci8: int32, 0, 8) {
+            for (ki: int32, 0, 32) {
+              for (ci4: int32, 0, 4) {
+                filter.cache[(((((ko.inner*2048) + (co*1024)) + (ci8*128)) + 
(ki*4)) + ci4)] = 
+                  (float32*)filter_pointer[((((((ko.outer*4096) + 
(ko.inner*2048)) + (co*1024)) + (ci8*128)) + (ki*4)) + ci4)]
               }
             }
           }
         }
-      }
-    }
-    for (ho.c.inner: int32, 0, 2) {
-      for (wo.c: int32, 0, 8) {
-        for (ko.c: int32, 0, 2) {
-          for (hi.c.init: int32, 0, 8) {
-            for (wi.c.init: int32, 0, 8) {
-              for (ki.c.init: int32, 0, 32) {
-                output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + 
(ko.c*2048)) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32
+      } // end ko.inner
+
+      // compute
+      for (ko.c.inner: int32, 0, 2) {
+        for (ho.c.inner: int32, 0, 2) {
+          for (wo.c: int32, 0, 8) {
+
+            // init output cache
+            for (hi.c.init: int32, 0, 8) {
+              for (wi.c.init: int32, 0, 8) {
+                for (ki.c.init: int32, 0, 32) {
+                  output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + 
(ko.c.inner*2048)) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32
+                }
               }
             }
-          }
-          for (rc.outer: int32, 0, 2) {
-            for (hi.c: int32, 0, 8) {
-              for (wi.c: int32, 0, 8) {
-                for (ki.c: int32, 0, 32) {
-                  for (rc.inner: int32, 0, 32) {
-                    output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + 
(ko.c*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] = 
-                    (
-                      (float32*)output.cache[((((((ho.c.inner*32768) + 
(wo.c*4096)) + (ko.c*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] + 
+
+            // convolution
+            for (rc.outer: int32, 0, 2) {
+              for (hi.c: int32, 0, 8) {
+                for (wi.c: int32, 0, 8) {
+                  for (ki.c: int32, 0, 32) {
+                    for (rc.inner: int32, 0, 32) {
+                      output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + 
(ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] = 
                       (
-                        (float32*)input.cache[((((((ho.c.inner*32768) + 
(wo.c*4096)) + (rc.outer*2048)) + (hi.c*256)) + (wi.c*32)) + rc.inner)] *
-                        (float32*)kernel_pointer[(((((ko.c*2048) + 
(rc.outer*1024)) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + 
floormod(rc.inner, 4))]
+                        (float32*)output.cache[((((((ho.c.inner*32768) + 
(wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] + 
+                        (
+                          (float32*)input.cache[((((((ho.c.inner*32768) + 
(wo.c*4096)) + (rc.outer*2048)) + (hi.c*256)) + (wi.c*32)) + rc.inner)] *
+                          (float32*)filter.cache[(((((ko.c.inner*2048) + 
(rc.outer*1024)) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + 
floormod(rc.inner, 4))]
+                        )
                       )
-                    )
+                    }
                   }
                 }
               }
             }
-          }
-        }
-      }
-    }
-    for (ho.inner: int32, 0, 2) {
-      for (wo: int32, 0, 8) {
-        for (ko: int32, 0, 2) {
-          for (hi: int32, 0, 8) {
-            for (wi: int32, 0, 8) {
-              for (ki: int32, 0, 32) {
-                output_pointer[(((((((ho.outer*65536) + (ho.inner*32768)) + 
(wo*4096)) + (ko*2048)) + (hi*256)) + (wi*32)) + ki)] = 
-                  (float32*)output.cache[((((((ho.inner*32768) + (wo*4096)) + 
(ko*2048)) + (hi*256)) + (wi*32)) + ki)]
+          } // end wo.c
+        } // end ho.c.inner
+      } // end ko.c.inner
+
+      // cache write
+      for (ko.inner: int32, 0, 2) {
+        for (ho.inner: int32, 0, 2) {
+          for (wo: int32, 0, 8) {
+            for (hi: int32, 0, 8) {
+              for (wi: int32, 0, 8) {
+                for (ki: int32, 0, 32) {
+                  output_pointer[((((((((ho.outer*131072) + (ho.inner*65536)) 
+ (wo*8192)) + (ko.outer*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + 
ki)] = 
+                    (float32*)output.cache[((((((ho.inner*32768) + (wo*4096)) 
+ (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)]
+                }
               }
             }
           }
-        }
-      }
-    }
-  }
+        } // end ho.inner
+      } // end ko.inner
+    } // end ho.outer
+  } // end ko.outer
 }
 ```
 
 # 3x3 conv2d (no padding)
 
-Change from a 1x1 kernel to a 3x3 kernel.  The implication of this change is 
that `h_split + 1` rather than just `h_split` "full width" slices of the input 
are required to compute the output.  This is due to the fact that the 3x3 
kernel will "fall off the bottom" of the input and thus the vertically adjacent 
"full width" slice must be prefetched into the input cache.
+Change from a 1x1 filter to a 3x3 filter.  The implication of this change is 
that `h_split + 1` rather than just `h_split` "full width" slices of the input 
are required to compute the output.  This is due to the fact that the 3x3 
filter will "fall off the bottom" of the input and thus the vertically adjacent 
"full width" slice must be prefetched into the input cache.
 
 The key changes in TIR versus the above are...
 
 1) Increased input cache size to hold the vertically adjacent slice
 
 ```
+  // input cache grows to hold vertically adjacent slice

Review comment:
      A full-width, full-channel-in-depth slice.  The explanation for this is above (line 337): you need `h_split + 1` vertical slices to calculate `h_split` output slices.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to