Lunderberg commented on PR #77:
URL: https://github.com/apache/tvm-rfcs/pull/77#issuecomment-1163620046

   Writing out some of my thoughts, to see if there's a way to express the 
constraints while only using existing TIR features.  The main goals would be as 
follows.
   
   1. Allow simplification of expressions based on the values present in the 
padding.
   2. Allow local simplifications to take advantage of non-local constraints, 
without requiring a full end-to-end analysis.
   3. Specify the non-local constraints in some deducible manner that doesn't 
impose a runtime performance penalty.
      
   Next, working through various options for how the constraints could be 
stored.  The examples below sketch how each option would apply to the 
element-wise operation, which starts as follows.
   
   ```python
   @T.prim_func
    def func(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
       for i in T.serial(14):
           B[i] = 2 * A[i]
   ```
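
   All of the options below rely on the same padded layout transform: the 
14-element buffer viewed as a 4x4 tile, with the out-of-bounds positions 
filled by a pad value.  As a plain-Python sketch of that transform (for 
illustration only, not TVM API):

   ```python
   def transform(A, pad_value=0):
       """View a 14-element list as a 4x4 tile, padding the tail."""
       AC = [[pad_value] * 4 for _ in range(4)]
       for io in range(4):
           for ii in range(4):
               if 4 * io + ii < 14:
                   AC[io][ii] = A[4 * io + ii]
       return AC

   AC = transform(list(range(14)))
   assert AC[3][1] == 13  # last valid element
   assert AC[3][2] == 0   # padding
   assert AC[3][3] == 0   # padding
   ```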
   
   1. Apply layout transforms on local caches.  Here, the full lifetime of a 
buffer is known, and all TIR optimizations are done prior to hoisting the 
cache and layout transformation into the graph level.
      
      - For read caches, the pad value is whatever gets conditionally written 
to the padding while generating it.  In the example below, `AC` could be 
recognized as being padded.
        
        ```python
        @T.prim_func
        def func(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
            AC = T.alloc_buffer([4, 4], "int32")
            for io, ii in T.grid(4, 4):
                if 4 * io + ii < 14:
                    AC[io, ii] = A[4 * io + ii]
                else:
                    AC[io, ii] = 0
        
            for i in T.serial(14):
                B[i] = 2 * AC[i // 4, i % 4]
        ```
        
      - For write caches, the pad value is whatever is left in the padding 
after the last write to the cache.  In the example below, `BC` could be 
recognized as being padded.
   
        ```python
        @T.prim_func
        def func(A: T.Buffer[14, "int32"], B: T.Buffer[14, "int32"]):
            BC = T.alloc_buffer([4, 4], "int32")
            for io, ii in T.grid(4, 4):
                if 4 * io + ii < 14:
                    BC[io, ii] = 2 * A[4*io + ii]
                else:
                    BC[io, ii] = 0
        
            for io, ii in T.grid(4, 4):
                if 4 * io + ii < 14:
                    B[4 * io + ii] = BC[io, ii]
        ```
   
      - Downside, either of the `else` branches could be eliminated as a 
no-op, since it doesn't contribute to the output `B`.  After that 
elimination, there would no longer be any way to reconstruct the pad value.
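
        Concretely, in a plain-Python emulation of the read-cache case 
(illustrative only, not TVM API), the pad value never reaches `B`, which is 
exactly what makes the `else` branch eligible for dead-code elimination:

        ```python
        def consumer(A, pad_value):
            # Read cache: valid positions copy A, padding gets pad_value.
            AC = [[0] * 4 for _ in range(4)]
            for io in range(4):
                for ii in range(4):
                    if 4 * io + ii < 14:
                        AC[io][ii] = A[4 * io + ii]
                    else:
                        AC[io][ii] = pad_value

            # The consumer's predicate only ever reads valid positions.
            B = [0] * 14
            for i in range(14):
                B[i] = 2 * AC[i // 4][i % 4]
            return B

        # B is independent of the pad value, so a dead-code-elimination
        # pass is free to drop the `else` branch -- and with it the pad value.
        A = list(range(14))
        assert consumer(A, pad_value=0) == consumer(A, pad_value=999)
        ```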
        
   2. When hoisting an allocation+transformation, write the pad value to the 
buffer at the start of the function from which it was hoisted. This way, the 
pad value can still be used in local reasoning.
      
      - No change needed in producers, since they would already write the pad 
value to the buffer.
      
      - For consumers, would be represented as writing `pad_value` into the 
padding at the start of the function.
      
        ```python
        @T.prim_func
        def func(AC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
            for io, ii in T.grid(4, 4):
                if 4 * io + ii >= 14:
                    AC[io, ii] = 0
        
            for io, ii in T.grid(4, 4):
                if 4 * io + ii < 14:
                    B[4 * io + ii] = 2 * AC[io, ii]
        ```
        
      - Downside, repeated unnecessary effort at the beginning of each 
consumer.  Avoiding it with this representation would require knowing that the 
producer had written `pad_value` already, which is exactly the information 
we're trying to avoid.
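
      - For contrast, a case where goal 1 pays off once the pad value is 
locally known (a plain-Python sketch using a reduction rather than the 
element-wise op, for illustration only): padding with the additive identity 
lets the reduction drop its bounds check entirely.

        ```python
        def sum_guarded(AC):
            # Without knowing the pad value, the reduction must skip the padding.
            total = 0
            for io in range(4):
                for ii in range(4):
                    if 4 * io + ii < 14:
                        total += AC[io][ii]
            return total

        def sum_unguarded(AC):
            # Knowing the padding holds 0, the guard can be dropped.
            return sum(AC[io][ii] for io in range(4) for ii in range(4))

        # Build AC with zero padding, as the producer would.
        AC = [[(4 * io + ii) if 4 * io + ii < 14 else 0 for ii in range(4)]
              for io in range(4)]
        assert sum_guarded(AC) == sum_unguarded(AC) == sum(range(14))
        ```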
        
   3. When hoisting an allocation+transformation, write the pad value to the 
buffer at the start of the function from which it was hoisted, and write 
`T.undef()` at the end.  This way, the pad value can still be used in local 
reasoning, and no-op removal can remove the repeated writes when lowering.
      
      - No change needed in producers, since they would already write the pad 
value to the buffer.
        
      - For consumers, would be like option 2, but with an additional write of 
`T.undef()` at the end of the function.  When lowering, the write of 
`T.undef()` would allow the first write to be removed as a no-op because it is 
overwritten.  The `T.undef()` can then be removed as described in the RFC.
      
        ```python
        @T.prim_func
        def func(AC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
            for io, ii in T.grid(4, 4):
                if 4 * io + ii >= 14:
                    AC[io, ii] = 0
        
            for io, ii in T.grid(4, 4):
                if 4 * io + ii < 14:
                    B[4 * io + ii] = 2 * AC[io, ii]
        
            for io, ii in T.grid(4, 4):
                if 4 * io + ii >= 14:
                    AC[io, ii] = T.undef()
        ```
        
      - Downside, no way to distinguish between "can assume the pad value is 
zero" and "can overwrite the pad value at will".  The writing of `T.undef()` 
would allow any writes to the padding to be inserted as a no-op.
        
      - Downside, wouldn't actually simplify out in cases where the pad value 
is used.  The first in a pair of repeated writes to the same location can only 
be removed if there are no reads between the writes.  After using the pad value 
to eliminate `if 4 * io + ii < 14` from the compute, the dummy loop that writes 
the padding could no longer be removed.
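
        This interaction can be shown with a toy dead-store analysis over a 
straight-line trace of reads and writes (a sketch for illustration, not 
TVM's actual pass):

        ```python
        def dead_stores(trace):
            # A write is dead if the same location is written again
            # with no read of that location in between.
            dead = []
            for i, (op, loc) in enumerate(trace):
                if op != "write":
                    continue
                for later_op, later_loc in trace[i + 1:]:
                    if later_loc != loc:
                        continue
                    if later_op == "read":
                        break  # intervening read: the store is live
                    if later_op == "write":
                        dead.append(i)
                        break
            return dead

        # Before simplification: write pad, then write T.undef(); the
        # consumer's predicate never reads the padding, so the first
        # write is dead and both dummy loops fold away.
        assert dead_stores([("write", "pad"), ("write", "pad")]) == [0]

        # After the pad value has been used to remove the consumer's
        # predicate, the padding is read between the two writes, so the
        # first write can no longer be eliminated.
        assert dead_stores(
            [("write", "pad"), ("read", "pad"), ("write", "pad")]
        ) == []
        ```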
        
   4. Use `AssertStmt` in a loop to declare known information about the buffers.
   
      - No change needed in producers, since the pad value is already written 
out.
        
      - For consumers, would have an initial loop that asserts the pad value is 
correct.
   
        ```python
        @T.prim_func
        def func(AC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
            for io, ii in T.grid(4, 4):
                if 4 * io + ii >= 14:
                    assert AC[io, ii] == 0, "padding"
        
            for io, ii in T.grid(4, 4):
                if 4 * io + ii < 14:
                    B[4 * io + ii] = 2 * AC[io, ii]
        ```
        
      - Downside, assert statements have target-dependent handling.  In 
`CodeGenLLVM` and `CodeGenSPIRV`, they are treated as no-ops.  In `CodeGenCPU` 
and `CodeGenC`, they generate asserts.  In `CodeGenCUDA`, they aren't handled 
at all and would error out.
        
        Could work around this with a lowering pass, but identifying these 
conditions would require having a special string in the message, and packing 
structured data into strings makes me wary.
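
      - What the asserted fact would buy (a plain-Python check, illustrative 
only): any guarded load whose fallback equals the asserted pad value 
collapses to a plain load.

        ```python
        def build_AC(A, pad_value=0):
            # Producer side: valid positions copy A, padding gets pad_value.
            AC = [[pad_value] * 4 for _ in range(4)]
            for io in range(4):
                for ii in range(4):
                    if 4 * io + ii < 14:
                        AC[io][ii] = A[4 * io + ii]
            return AC

        def guarded_load(AC, io, ii, fallback):
            # Stands in for "AC[io, ii] if in-bounds else fallback".
            return AC[io][ii] if 4 * io + ii < 14 else fallback

        # With the padding asserted to hold 0, the guarded load agrees with
        # the plain load at every index, so the conditional simplifies away.
        AC = build_AC(list(range(14)))
        assert all(
            guarded_load(AC, io, ii, 0) == AC[io][ii]
            for io in range(4)
            for ii in range(4)
        )
        ```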
        
   5. Use `AssertStmt` with implicitly-defined variables to declare known 
information about the buffers.
      
      ```python
      @T.prim_func
      def func(AC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
          a = T.var("int32")
          b = T.var("int32")
          assert (
              AC[a, b] == 0
              or (4 * a + b < 14)
              or (a < 0) or (a >= 4)
              or (b < 0) or (b >= 4)
          ), "padding"
      
          for io, ii in T.grid(4, 4):
              if 4 * io + ii < 14:
                  B[4 * io + ii] = 2 * AC[io, ii]
      ```
      
      - Can apply to clamped texture memory, since the variables in the 
assertion aren't restricted to the buffer bounds.
        
      - Would need to recognize specific pattern of `BufferLoad` being used to 
define variables used in constraint.
        
      - The implicitly-defined variables can be written in current TIR, and 
because they are never given a definition, this isn't something that could 
ever make it into generated code at runtime.
      
      - Downside, implicitly-defined variables are something of a red flag.
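
      - The quantified assertion can be checked by enumeration in plain 
Python (illustrative only), including indices outside the buffer bounds, 
where the extra disjuncts make the assertion vacuously true:

        ```python
        def make_padded():
            # Producer output: element value equals its flat index, zero padding.
            return [[(4 * io + ii) if 4 * io + ii < 14 else 0
                     for ii in range(4)]
                    for io in range(4)]

        def constraint_holds(AC, a, b):
            def load(a, b):
                if 0 <= a < 4 and 0 <= b < 4:
                    return AC[a][b]
                return 12345  # arbitrary; the bounds disjuncts cover it

            return (
                load(a, b) == 0
                or 4 * a + b < 14
                or a < 0 or a >= 4
                or b < 0 or b >= 4
            )

        # Holds at every in-bounds index and vacuously out of bounds.
        AC = make_padded()
        assert all(
            constraint_holds(AC, a, b)
            for a in range(-1, 5)
            for b in range(-1, 5)
        )
        ```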
   
   6. Store constraints in the function attributes, either as a dictionary or 
as a structured object.
      
      ```python
      @T.prim_func
      def func(AC: T.Buffer[(4, 4), "int32"], B: T.Buffer[14, "int32"]):
          T.func_attr(
              {
                  "buffer_constraints": [
                      {
                          "buffer": AC,
                          "predicate": lambda io, ii: 4 * io + ii < 14,
                          "pad_value": lambda io, ii: 0,
                      },
                  ],
              }
          )
      
          for io, ii in T.grid(4, 4):
              if 4 * io + ii < 14:
                  B[4 * io + ii] = 2 * AC[io, ii]
      ```
      
      - Downside, requires transformations that change a buffer to be aware 
that other structures will also need to be replaced.
        
      - Downside, requires simplifications to either be passed the entire 
`PrimFunc`, or to be explicitly passed the `"buffer_constraints"` list.
        
      - Downside, would break the expectations of `IRMutatorWithAnalyzer`.  
Any entry point that currently takes a `Stmt` or `Expr` would also need to 
be passed the `"buffer_constraints"` information.
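
      - As a plain-Python stand-in for the attribute and the lookup a 
simplifier would need (illustrative only; the helper must be handed the 
whole constraint list, which is the second downside in miniature):

        ```python
        # Stand-in for the "buffer_constraints" function attribute.
        buffer_constraints = [
            {
                "buffer": "AC",
                "predicate": lambda io, ii: 4 * io + ii < 14,
                "pad_value": lambda io, ii: 0,
            },
        ]

        def known_value(buffer, io, ii, constraints):
            # Return the statically-known value of buffer[io, ii] if the
            # constraints mark that element as padding, else None.
            for c in constraints:
                if c["buffer"] == buffer and not c["predicate"](io, ii):
                    return c["pad_value"](io, ii)
            return None

        assert known_value("AC", 3, 3, buffer_constraints) == 0    # padding
        assert known_value("AC", 0, 0, buffer_constraints) is None  # real data
        ```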
        
   
   7. Store constraints in the `Buffer` object, either as a dictionary or as a 
structured object.
      
      ```python
      @T.prim_func
      def func(ac: T.handle, B: T.Buffer[14, "int32"]):
          AC = T.match_buffer(
              ac,
              shape=(4, 4),
              dtype="int32",
              constraints=[
                  T.BufferConstraints(
                      predicate=lambda io, ii: 4 * io + ii < 14,
                      pad_value=0,
                  )
              ],
          )
      
          for io, ii in T.grid(4, 4):
              if 4 * io + ii < 14:
                  B[4 * io + ii] = 2 * AC[io, ii]
      ```
      
      - Downside, introduces additional data structure in TIR.
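
      - As a plain-Python sketch of that structure (hypothetical names; 
`T.BufferConstraints` doesn't exist yet): because the constraint travels 
with the buffer, a simplifier holding only the load can query it directly, 
with no function-level lookup.

        ```python
        from dataclasses import dataclass, field
        from typing import Callable, List

        @dataclass
        class BufferConstraint:
            # Hypothetical structured constraint attached to a buffer.
            predicate: Callable[[int, int], bool]
            pad_value: int

        @dataclass
        class Buffer:
            shape: tuple
            dtype: str
            constraints: List[BufferConstraint] = field(default_factory=list)

        AC = Buffer(
            shape=(4, 4),
            dtype="int32",
            constraints=[
                BufferConstraint(
                    predicate=lambda io, ii: 4 * io + ii < 14,
                    pad_value=0,
                )
            ],
        )

        def known_value(buf, io, ii):
            # The constraint is reachable from the buffer itself.
            for c in buf.constraints:
                if not c.predicate(io, ii):
                    return c.pad_value
            return None

        assert known_value(AC, 3, 3) == 0    # padding
        assert known_value(AC, 1, 2) is None  # real data
        ```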

