junrushao1994 commented on a change in pull request #6073:
URL: https://github.com/apache/incubator-tvm/pull/6073#discussion_r456206595



##########
File path: tests/python/unittest/test_auto_scheduler_loop_state.py
##########
@@ -61,5 +61,79 @@ def test_split_fuse_reorder():
     assert s1[C].iters[4].range.extent == 8
     assert s1[C].iters[5].range.extent == 2
 
+    s1.parallel(C, j1)
+    s1.unroll(C, j2)
+    s1.vectorize(C, j3)
+    s1.bind(C, i1, "blockIdx.x")
+    s1.bind(C, i2, "vthread")
+    s1.bind(C, i3, "threadIdx.y")
+
+
+def test_compute_at_root_inline():
+    dag = auto_scheduler.ComputeDAG(conv2d_nchw_bn_relu(1, 224, 224, 3, 64, 7, 
2, 3))
+    s0 = dag.get_init_state()
+
+    # data, padding, kernel = 0, 1, 2
+    conv = s0.stage_ops[3]
+    # bias = 4
+    bias_add = s0.stage_ops[5]
+    # bn_scale = 6
+    bn_mul = s0.stage_ops[7]
+    # bn_offset = 8
+    bn_add = s0.stage_ops[9]
+    relu = s0.stage_ops[10]
+
+    s0.compute_inline(bn_add)
+    s0.compute_inline(bn_mul)
+    s0.compute_inline(bias_add)
+    s0.compute_at(conv, relu, s0[relu].iters[2])
+    print(s0)
+    assert str(s0) == \
+        "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + \
+        "for i1 (0,3)\n" + \
+        "  for i2 (0,230)\n" + \
+        "    for i3 (0,230)\n" + \
+        "      pad_temp = ...\n" + \
+        "for i1 (0,64)\n" + \
+        "  for i2 (0,112)\n" + \
+        "    for nn (None)\n" + \
+        "      for ff (None)\n" + \
+        "        for yy (None)\n" + \
+        "          for xx (None)\n" + \
+        "            for rc (None)\n" + \
+        "              for ry (None)\n" + \
+        "                for rx (None)\n" + \
+        "                  compute = ...\n" + \
+        "    for i3 (0,112)\n" + \
+        "      compute = ...\n"

Review comment:
       Use a triple-quoted string (`"""`) instead of concatenating per-line string literals?

##########
File path: tests/python/unittest/test_auto_scheduler_loop_state.py
##########
@@ -61,5 +61,79 @@ def test_split_fuse_reorder():
     assert s1[C].iters[4].range.extent == 8
     assert s1[C].iters[5].range.extent == 2
 
+    s1.parallel(C, j1)
+    s1.unroll(C, j2)
+    s1.vectorize(C, j3)
+    s1.bind(C, i1, "blockIdx.x")
+    s1.bind(C, i2, "vthread")
+    s1.bind(C, i3, "threadIdx.y")
+
+
+def test_compute_at_root_inline():
+    dag = auto_scheduler.ComputeDAG(conv2d_nchw_bn_relu(1, 224, 224, 3, 64, 7, 
2, 3))
+    s0 = dag.get_init_state()
+
+    # data, padding, kernel = 0, 1, 2
+    conv = s0.stage_ops[3]
+    # bias = 4
+    bias_add = s0.stage_ops[5]
+    # bn_scale = 6
+    bn_mul = s0.stage_ops[7]
+    # bn_offset = 8
+    bn_add = s0.stage_ops[9]
+    relu = s0.stage_ops[10]
+
+    s0.compute_inline(bn_add)
+    s0.compute_inline(bn_mul)
+    s0.compute_inline(bias_add)
+    s0.compute_at(conv, relu, s0[relu].iters[2])
+    print(s0)
+    assert str(s0) == \
+        "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + \
+        "for i1 (0,3)\n" + \
+        "  for i2 (0,230)\n" + \
+        "    for i3 (0,230)\n" + \
+        "      pad_temp = ...\n" + \
+        "for i1 (0,64)\n" + \
+        "  for i2 (0,112)\n" + \
+        "    for nn (None)\n" + \
+        "      for ff (None)\n" + \
+        "        for yy (None)\n" + \
+        "          for xx (None)\n" + \
+        "            for rc (None)\n" + \
+        "              for ry (None)\n" + \
+        "                for rx (None)\n" + \
+        "                  compute = ...\n" + \
+        "    for i3 (0,112)\n" + \
+        "      compute = ...\n"
+
+    s0.compute_root(conv)
+    s0.compute_root(bn_mul)
+    assert str(s0) == \
+        "Placeholder: Data, Kernel, Bias, Bn_scale, Bn_offset\n" + \
+        "for i1 (0,3)\n" + \
+        "  for i2 (0,230)\n" + \
+        "    for i3 (0,230)\n" + \
+        "      pad_temp = ...\n" + \
+        "for nn (None)\n" + \
+        "  for ff (None)\n" + \
+        "    for yy (None)\n" + \
+        "      for xx (None)\n" + \
+        "        for rc (None)\n" + \
+        "          for ry (None)\n" + \
+        "            for rx (None)\n" + \
+        "              compute = ...\n" + \
+        "for i (None)\n" + \
+        "  for j (None)\n" + \
+        "    for k (None)\n" + \
+        "      for l (None)\n" + \
+        "        Bn_mul = ...\n" + \
+        "for i1 (0,64)\n" + \
+        "  for i2 (0,112)\n" + \
+        "    for i3 (0,112)\n" + \
+        "      compute = ...\n"

Review comment:
       Ditto — use a triple-quoted string (`"""`) here as well.

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -82,6 +82,83 @@ String ReorderStepNode::PrintAsPythonAPI(Array<te::Stage>* 
stages,
   return ss.str();
 }
 
+/********** Compute At **********/
+ComputeAtStep::ComputeAtStep(int stage_id, int target_stage_id, int 
target_iter_id) {
+  auto node = make_object<ComputeAtStepNode>();
+  node->stage_id = stage_id;
+  node->target_stage_id = target_stage_id;
+  node->target_iter_id = target_iter_id;
+  data_ = std::move(node);
+}
+
+void ComputeAtStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                        StageToAxesMap* stage_to_axes) const {
+  te::Stage stage = (*stages)[stage_id];
+  const IterVar& target_axis = 
(*stage_to_axes)[(*stages)[target_stage_id]][target_iter_id];
+  stage.compute_at((*stages)[target_stage_id], target_axis);
+
+  stages->Set(stage_id, std::move(stage));
+}
+
+String ComputeAtStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+                                           StageToAxesMap* stage_to_axes) 
const {
+  std::stringstream ss;
+  const auto& stage = (*stages)[stage_id];
+  const auto& target_stage = (*stages)[target_stage_id];
+  ss << "s[" << CleanName(stage->op->name) << "].compute_at(s[" << 
CleanName(target_stage->op->name)
+     << "], " << 
CleanName((*stage_to_axes)[target_stage][target_iter_id]->var->name_hint) << 
")\n";
+  ApplyToSchedule(stages, stage_to_axes);
+  return ss.str();
+}
+
+/********** Compute Root **********/
+ComputeRootStep::ComputeRootStep(int stage_id) {
+  auto node = make_object<ComputeRootStepNode>();
+  node->stage_id = stage_id;
+  data_ = std::move(node);
+}
+
+void ComputeRootStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                          StageToAxesMap* stage_to_axes) const 
{
+  auto stage = (*stages)[stage_id];
+  stage.compute_root();
+
+  stages->Set(stage_id, std::move(stage));
+}
+
+String ComputeRootStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+                                             StageToAxesMap* stage_to_axes) 
const {
+  std::stringstream ss;
+  const auto& stage = (*stages)[stage_id];
+  ss << "s[" << CleanName(stage->op->name) << "].compute_root()\n";
+  ApplyToSchedule(stages, stage_to_axes);
+  return ss.str();
+}
+
+/********** Compute Inline **********/
+ComputeInlineStep::ComputeInlineStep(int stage_id) {
+  auto node = make_object<ComputeInlineStepNode>();
+  node->stage_id = stage_id;
+  data_ = std::move(node);
+}
+
+void ComputeInlineStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                            StageToAxesMap* stage_to_axes) 
const {
+  auto stage = (*stages)[stage_id];
+  stage.compute_inline();
+

Review comment:
       Ditto — remove the trailing blank line here as well.

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -82,6 +82,83 @@ String ReorderStepNode::PrintAsPythonAPI(Array<te::Stage>* 
stages,
   return ss.str();
 }
 
+/********** Compute At **********/
+ComputeAtStep::ComputeAtStep(int stage_id, int target_stage_id, int 
target_iter_id) {
+  auto node = make_object<ComputeAtStepNode>();
+  node->stage_id = stage_id;
+  node->target_stage_id = target_stage_id;
+  node->target_iter_id = target_iter_id;
+  data_ = std::move(node);
+}
+
+void ComputeAtStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                        StageToAxesMap* stage_to_axes) const {
+  te::Stage stage = (*stages)[stage_id];
+  const IterVar& target_axis = 
(*stage_to_axes)[(*stages)[target_stage_id]][target_iter_id];
+  stage.compute_at((*stages)[target_stage_id], target_axis);
+
+  stages->Set(stage_id, std::move(stage));
+}
+
+String ComputeAtStepNode::PrintAsPythonAPI(Array<te::Stage>* stages,
+                                           StageToAxesMap* stage_to_axes) 
const {
+  std::stringstream ss;
+  const auto& stage = (*stages)[stage_id];
+  const auto& target_stage = (*stages)[target_stage_id];
+  ss << "s[" << CleanName(stage->op->name) << "].compute_at(s[" << 
CleanName(target_stage->op->name)
+     << "], " << 
CleanName((*stage_to_axes)[target_stage][target_iter_id]->var->name_hint) << 
")\n";
+  ApplyToSchedule(stages, stage_to_axes);
+  return ss.str();
+}
+
+/********** Compute Root **********/
+ComputeRootStep::ComputeRootStep(int stage_id) {
+  auto node = make_object<ComputeRootStepNode>();
+  node->stage_id = stage_id;
+  data_ = std::move(node);
+}
+
+void ComputeRootStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                          StageToAxesMap* stage_to_axes) const 
{
+  auto stage = (*stages)[stage_id];
+  stage.compute_root();
+

Review comment:
       nit: remove this blank line.

##########
File path: src/auto_scheduler/transform_step.cc
##########
@@ -82,6 +82,83 @@ String ReorderStepNode::PrintAsPythonAPI(Array<te::Stage>* 
stages,
   return ss.str();
 }
 
+/********** Compute At **********/
+ComputeAtStep::ComputeAtStep(int stage_id, int target_stage_id, int 
target_iter_id) {
+  auto node = make_object<ComputeAtStepNode>();
+  node->stage_id = stage_id;
+  node->target_stage_id = target_stage_id;
+  node->target_iter_id = target_iter_id;
+  data_ = std::move(node);
+}
+
+void ComputeAtStepNode::ApplyToSchedule(Array<te::Stage>* stages,
+                                        StageToAxesMap* stage_to_axes) const {
+  te::Stage stage = (*stages)[stage_id];
+  const IterVar& target_axis = 
(*stage_to_axes)[(*stages)[target_stage_id]][target_iter_id];
+  stage.compute_at((*stages)[target_stage_id], target_axis);

Review comment:
       nit: make it slightly more succinct
   ```suggestion
     const te::Stage& target_stage = (*stages)[target_stage_id];
     const IterVar& target_axis = 
(*stage_to_axes)[target_stage][target_iter_id];
     stage.compute_at(target_stage, target_axis);
   ```

##########
File path: tests/python/unittest/test_auto_scheduler_loop_state.py
##########
@@ -61,5 +61,79 @@ def test_split_fuse_reorder():
     assert s1[C].iters[4].range.extent == 8
     assert s1[C].iters[5].range.extent == 2
 
+    s1.parallel(C, j1)
+    s1.unroll(C, j2)
+    s1.vectorize(C, j3)
+    s1.bind(C, i1, "blockIdx.x")
+    s1.bind(C, i2, "vthread")
+    s1.bind(C, i3, "threadIdx.y")
+
+
+def test_compute_at_root_inline():
+    dag = auto_scheduler.ComputeDAG(conv2d_nchw_bn_relu(1, 224, 224, 3, 64, 7, 
2, 3))

Review comment:
       ```suggestion
       dag = auto_scheduler.ComputeDAG(conv2d_nchw_bn_relu(N=1, H=224, W=224, 
CI=3, CO=64, kernel_size=7, strides=2, padding=3))
   ```




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to