junrushao1994 commented on a change in pull request #6107:
URL: https://github.com/apache/incubator-tvm/pull/6107#discussion_r459194418



##########
File path: tests/python/unittest/test_auto_scheduler_loop_state.py
##########
@@ -143,6 +143,282 @@ def test_compute_at_root_inline():
     assert s0[conv].iters[6].range.extent == 7
 
 
+def test_cache_read_write():
+    N, H, W, CO, CI, KH, KW = 4, 7, 7, 512, 512, 3, 3
+    strides, padding = (1, 1), (1, 1)
+
+    data = te.placeholder((N, CI, H, W), name='Data')
+    kernel_data = te.placeholder((CO, CI, KH, KW), name='Kernel_data')
+    k0, k1 = te.compute(kernel_data.shape,
+                        lambda *i: (kernel_data(*i)+1, kernel_data(*i)/2),
+                        name='Kernel_split')
+    kernel = te.compute(kernel_data.shape,
+                        lambda *i: k0(*i) + k1(*i),
+                        name='Kernel')
+    conv = topi.nn.conv2d_nchw(data, kernel, strides, padding, dilation=1)
+    relu = topi.nn.relu(conv)
+    add = topi.add(data, relu)
+
+    dag = auto_scheduler.ComputeDAG([data, kernel_data, add])
+    s0 = dag.get_init_state()
+
+    pad_temp = s0.stage_ops[1]
+    kernel_split = s0.stage_ops[3]
+
+    # 0: init state
+    ori_its = s0[add].iters
+    its = s0.split(add, s0[add].iters[0], [2])
+    s0.reorder(add, [its[0], ori_its[1], its[1], ori_its[2], ori_its[3]])
+    s0.compute_inline(relu)
+
+    # 1: simple cache_write with compute_at
+    conv_global = s0.cache_write(conv, "global")
+    s0.compute_at(conv_global, conv, s0[conv].iters[3])
+
+    # 2: simple cache_read with compute_at
+    kernel_global = s0.cache_read(kernel, "global", [conv_global])
+    s0.compute_at(kernel_global, conv_global, s0[conv_global].iters[4])
+    """
+        Placeholder: Data, Kernel_data
+        for i0 (0,4)
+          for i1 (0,512)
+            for i2 (0,9)
+              for i3 (0,9)
+                pad_temp = ...
+        for i0 (0,512)
+          for i1 (0,512)
+            for i2 (0,3)
+              for i3 (0,3)
+                Kernel_split = ...
+        for i0 (0,512)
+          for i1 (0,512)
+            for i2 (0,3)
+              for i3 (0,3)
+                Kernel = ...
+        for nn (0,4)
+          for ff (0,512)
+            for yy (0,7)
+              for xx (0,7)
+                for nn_c (None)
+                  for ff_c (None)
+                    for yy_c (None)
+                      for xx_c (None)
+                        for rc (None)
+                          for ax0 (None)
+                            for ax1 (None)
+                              for ax2 (None)
+                                for ax3 (None)
+                                  Kernel.global = ...
+                          for ry (None)
+                            for rx (None)
+                              compute.global = ...
+                compute = ...
+        for ax0.0 (0,2)
+          for ax1 (0,512)
+            for ax0.1 (0,2)
+              for ax2 (0,7)
+                for ax3 (0,7)
+                  T_add = ...
+    """
+    s1 = dag.infer_bound_from_state(s0)
+    assert s1[conv].iters[0].range.extent == 4
+    assert s1[conv].iters[1].range.extent == 512
+    assert s1[conv].iters[2].range.extent == 7
+    assert s1[conv].iters[3].range.extent == 7
+    assert s1[kernel_global].iters[0].range.extent == 1
+    assert s1[kernel_global].iters[1].range.extent == 1
+    assert s1[kernel_global].iters[2].range.extent == 3
+    assert s1[kernel_global].iters[3].range.extent == 3
+    assert s1[conv_global].iters[0].range.extent == 1
+    assert s1[conv_global].iters[1].range.extent == 1
+    assert s1[conv_global].iters[2].range.extent == 1
+    assert s1[conv_global].iters[3].range.extent == 1
+    assert s1[conv_global].iters[4].range.extent == 512
+    assert s1[conv_global].iters[5].range.extent == 3
+    assert s1[conv_global].iters[6].range.extent == 3
+
+    # 3: two level cache_read with compute_at
+    #    preparing for GPU's shared memory & local memory
+    pad_temp_global = s0.cache_read(pad_temp, "global", [conv_global])
+    pad_temp_shared = s0.cache_read(pad_temp_global, "shared", [conv_global])
+    s0.compute_at(pad_temp_global, conv_global, s0[conv_global].iters[2])
+    s0.compute_at(pad_temp_shared, conv_global, s0[conv_global].iters[4])
+
+    # 4: cache_read with multi readers
+    #    A stage with multiple readers cannot be compute_at to a single consumer
+    s0.cache_read(data, "global", [pad_temp, add])
+    """
+        Placeholder: Data, Kernel_data
+        for ax0 (0,4)
+          for ax1 (0,512)
+            for ax2 (0,7)
+              for ax3 (0,7)
+                Data.global = ...
+        for i0 (0,4)
+          for i1 (0,512)
+            for i2 (0,9)
+              for i3 (0,9)
+                pad_temp = ...
+        for i0 (0,512)
+          for i1 (0,512)
+            for i2 (0,3)
+              for i3 (0,3)
+                Kernel_split = ...
+        for i0 (0,512)
+          for i1 (0,512)
+            for i2 (0,3)
+              for i3 (0,3)
+                Kernel = ...
+        for nn (0,4)
+          for ff (0,512)
+            for yy (0,7)
+              for xx (0,7)
+                for nn_c (None)
+                  for ff_c (None)
+                    for yy_c (None)
+                      for ax0 (None)
+                        for ax1 (None)
+                          for ax2 (None)
+                            for ax3 (None)
+                              pad_temp.global = ...
+                      for xx_c (None)
+                        for rc (None)
+                          for ax0 (None)
+                            for ax1 (None)
+                              for ax2 (None)
+                                for ax3 (None)
+                                  Kernel.global = ...
+                          for ax0 (None)
+                            for ax1 (None)
+                              for ax2 (None)
+                                for ax3 (None)
+                                  pad_temp.global.shared = ...
+                          for ry (None)
+                            for rx (None)
+                              compute.global = ...
+                compute = ...
+        for ax0.0 (0,2)
+          for ax1 (0,512)
+            for ax0.1 (0,2)
+              for ax2 (0,7)
+                for ax3 (0,7)
+                  T_add = ...
+    """
+    s1 = dag.infer_bound_from_state(s0)
+    assert s1[conv].iters[0].range.extent == 4
+    assert s1[conv].iters[1].range.extent == 512
+    assert s1[conv].iters[2].range.extent == 7
+    assert s1[conv].iters[3].range.extent == 7
+    assert s1[kernel_global].iters[0].range.extent == 1
+    assert s1[kernel_global].iters[1].range.extent == 1
+    assert s1[kernel_global].iters[2].range.extent == 3
+    assert s1[kernel_global].iters[3].range.extent == 3
+    assert s1[conv_global].iters[0].range.extent == 1
+    assert s1[conv_global].iters[1].range.extent == 1
+    assert s1[conv_global].iters[2].range.extent == 1
+    assert s1[conv_global].iters[3].range.extent == 1
+    assert s1[conv_global].iters[4].range.extent == 512
+    assert s1[conv_global].iters[5].range.extent == 3
+    assert s1[conv_global].iters[6].range.extent == 3
+    assert s1[pad_temp_global].iters[0].range.extent == 1
+    assert s1[pad_temp_global].iters[1].range.extent == 512
+    assert s1[pad_temp_global].iters[2].range.extent == 3
+    assert s1[pad_temp_global].iters[3].range.extent == 3
+    assert s1[pad_temp_shared].iters[0].range.extent == 1
+    assert s1[pad_temp_shared].iters[1].range.extent == 1
+    assert s1[pad_temp_shared].iters[2].range.extent == 3
+    assert s1[pad_temp_shared].iters[3].range.extent == 3
+
+    # 5: cache_write with multi outputs
+    # TVM's cache_write actually has a bug with this case:
+    #
+    # After schedule.cache_write, TVM generates one new stage:
+    #   From: kernel_data -> kernel_split -> kernel
+    #   To:   kernel_data -> kernel_split_global -> kernel_split -> kernel
+    #
+    # But with a topo-sort analysis, we get:
+    #  //   kernel_data -> kernel_split_global -> kernel_split -> kernel
+    #         \                                                /
+    #          ----------------> kernel_split ---------------->
+    #
+    # There seems to be a bug with the input/output tensors. Such a multi-output
+    # case should be unusual, so we apply a workaround in DoCacheWrite.
+    # To be fixed in the future

Review comment:
       I think maybe we should explicitly add a "TODO" mark here?
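
       As a side note, the multi-output pattern this comment describes can be reproduced with plain TE, independent of the auto-scheduler. The sketch below is illustrative only: the shapes and names are made up, and it merely mirrors the `Kernel_split` structure from the test above rather than the exact code in this PR.

```python
import tvm
from tvm import te

data = te.placeholder((4, 4), name="data")
# One compute op producing two output tensors -- the "multi outputs" case.
split0, split1 = te.compute(
    (4, 4), lambda i, j: (data[i, j] + 1, data[i, j] / 2), name="split")
result = te.compute(
    (4, 4), lambda i, j: split0[i, j] + split1[i, j], name="result")

s = te.create_schedule(result.op)
# cache_write on a multi-output op must be given all of its outputs. It inserts
# a new "split.global" stage between `data` and `split`, which is the rewrite
# that the comment above says confuses the topo-sort analysis.
split_global = s.cache_write([split0, split1], "global")
print(tvm.lower(s, [data, result], simple_mode=True))
```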

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -923,5 +958,272 @@ String ComputeRootStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
   return ss.str();
 }
 
+/********** Primitives adding new stages **********/
+
+/*!
+ * \brief Common part for steps that add new stages (e.g. CacheReadStep, CacheWriteStep,
+ * RfactorStep). This keeps only the steps that can change the stages of a ComputeDAG.
+ */
+Array<Step> GetStageModifiableSteps(Step current_step, const Array<Step>& transform_steps) {
+  Array<Step> ret_steps;
+  for (const Step& step : transform_steps) {
+    if (step->IsInstance<CacheWriteStepNode>() || step->IsInstance<CacheReadStepNode>()) {
+      ret_steps.push_back(step);
+    }
+    // TODO(jcf94): add rfactor support
+    if (step.same_as(current_step)) {

Review comment:
       Does it mean we stop at `current_step`? If so, shall we reflect this in the documentation or the function name?
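
       To make the question concrete, here is a rough Python rendition of what the loop appears to do. The stand-in step classes and the `break` once `current_step` is reached are assumptions based on the truncated hunk, not the actual C++ implementation:

```python
# Stand-in classes for illustration only; these are not the real TVM step types.
class Step: ...
class CacheReadStep(Step): ...
class CacheWriteStep(Step): ...
class AnnotationStep(Step): ...


def get_stage_modifiable_steps(current_step, transform_steps):
    """Collect the steps that add new stages, presumably stopping once
    current_step is reached (the hunk is cut off right after the same_as
    check, so the break below is an assumption)."""
    ret_steps = []
    for step in transform_steps:
        if isinstance(step, (CacheReadStep, CacheWriteStep)):
            ret_steps.append(step)
        if step is current_step:
            break  # do not replay steps recorded after current_step
    return ret_steps


# Only the stage-modifying steps up to and including `cur` are returned; the
# trailing CacheWriteStep is ignored.
cur = CacheReadStep()
history = [CacheWriteStep(), AnnotationStep(), cur, CacheWriteStep()]
assert len(get_stage_modifiable_steps(cur, history)) == 2
```

       If that is the intent, a name like `GetStageModifiableStepsUpTo` or a sentence in the doc comment would make it clearer.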




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

