echuraev commented on code in PR #13781:
URL: https://github.com/apache/tvm/pull/13781#discussion_r1070857226


##########
python/tvm/topi/adreno/reduction.py:
##########
@@ -25,45 +25,102 @@
 
 
 def _schedule_reduce_adreno(op, sch, is_idx_reduce=False):
-    if is_idx_reduce:
-        real_output = op.output(0)
+    sch_output = sch.outputs[0].output(0)
+    use_rfactor = False
+    if not is_idx_reduce:
+        rdomain = 1
+        whole_rop_output = op.output(0)
+        for i in range(len(sch[whole_rop_output].op.reduce_axis)):
+            rdomain = rdomain * 
sch[whole_rop_output].op.reduce_axis[i].dom.extent

Review Comment:
   ```suggestion
           for axis in sch[whole_rop_output].op.reduce_axis:
               rdomain = rdomain * axis.dom.extent
   ```



##########
python/tvm/topi/adreno/reduction.py:
##########
@@ -25,45 +25,102 @@
 
 
 def _schedule_reduce_adreno(op, sch, is_idx_reduce=False):
-    if is_idx_reduce:
-        real_output = op.output(0)
+    sch_output = sch.outputs[0].output(0)
+    use_rfactor = False
+    if not is_idx_reduce:
+        rdomain = 1
+        whole_rop_output = op.output(0)
+        for i in range(len(sch[whole_rop_output].op.reduce_axis)):
+            rdomain = rdomain * 
sch[whole_rop_output].op.reduce_axis[i].dom.extent
+        if rdomain > 50:
+            use_rfactor = True
+            # shared goves better perf, but works only for rfactor flow
+            scope = "shared"
+        else:
+            # in case of direct scheduling, shared is failed to be compiled
+            scope = "local"
+        if op in sch.outputs:
+            whole_rop_output = sch.cache_write(sch_output, scope)
+        else:
+            # no change for whole_rop_output def, but need to set proper scope
+            sch[whole_rop_output].set_scope(scope)
+    else:
         temp_idx_input = op.input_tensors[0].op.output(0)
         temp_val_input = op.input_tensors[0].op.output(1)
-    else:
-        real_output = op.output(0)
-    shape = get_const_tuple(real_output.shape)
+        sch[temp_idx_input].set_scope("local")
+        sch[temp_val_input].set_scope("local")
+
+    shape = get_const_tuple(sch_output.shape)
     latest4 = shape[-1] == 4
     div4 = numpy.prod(shape) % 4 == 0
 
     # Fuse and split the axis
     if latest4:
-        fused_outer = sch[real_output].fuse(
-            *[sch[real_output].op.axis[i] for i in 
range(len(sch[real_output].op.axis) - 1)]
+        fused_outer = sch[sch_output].fuse(
+            *[sch[sch_output].op.axis[i] for i in 
range(len(sch[sch_output].op.axis) - 1)]
         )
     else:
-        fused_outer = sch[real_output].fuse(
-            *[sch[real_output].op.axis[i] for i in 
range(len(sch[real_output].op.axis))]
+        fused_outer = sch[sch_output].fuse(
+            *[sch[sch_output].op.axis[i] for i in 
range(len(sch[sch_output].op.axis))]
         )
 
     ftc = numpy.prod(shape)
     a = fused_outer
-    if latest4:
-        sch[real_output].vectorize(sch[real_output].op.axis[-1])
-    elif div4 and not is_idx_reduce:
-        a, b = sch[real_output].split(fused_outer, factor=4)
-        sch[real_output].vectorize(b)
-        ftc = ftc / 4
 
-    num_thread = get_div(ftc, 128)
+    if not is_idx_reduce:
+        if use_rfactor:
+            # below values were selected empirically assuming that we should 
have some work in each
+            # thread (currently from 25-49) and number of threads not 
exceeding some threshold that
+            # was selected as 256 from performance point of view after 
experiments on Adreno 660
+            max_threads = rdomain.value // 25 if rdomain > 25 else 1
+            max_threads = 256 if max_threads > 256 else max_threads
+            num_thread = get_div(rdomain, max_threads)
 
-    bx, outer_in = sch[real_output].split(a, factor=num_thread)
+            fused_reduce = 
sch[whole_rop_output].fuse(*sch[whole_rop_output].op.reduce_axis)
+            thread_y = te.thread_axis((0, num_thread), "threadIdx.y")
+            _, ki = sch[whole_rop_output].split(fused_reduce, 
factor=num_thread)
+            data_out_rf = sch.rfactor(whole_rop_output, ki)
+            sch[data_out_rf].compute_at(
+                sch[whole_rop_output], sch[whole_rop_output].op.reduce_axis[0]
+            )
+            
sch[whole_rop_output].bind(sch[whole_rop_output].op.reduce_axis[0], thread_y)
 
-    sch[real_output].bind(bx, te.thread_axis("blockIdx.x"))
-    sch[real_output].bind(outer_in, te.thread_axis("threadIdx.y"))
-    if is_idx_reduce:
-        sch[temp_idx_input].compute_at(sch[real_output], outer_in)
-        sch[temp_val_input].compute_at(sch[real_output], outer_in)
+    if div4:
+        if latest4:
+            b = sch[sch_output].op.axis[-1]
+        else:
+            a, b = sch[sch_output].split(fused_outer, factor=4)
+        sch[sch_output].vectorize(b)
+        if not use_rfactor:
+            if is_idx_reduce:
+                sch[temp_idx_input].compute_at(sch[sch_output], b)
+                sch[temp_val_input].compute_at(sch[sch_output], b)
+            else:
+                sch[whole_rop_output].compute_at(sch[sch_output], b)
+
+    if not use_rfactor:
+        num_thread = get_div(ftc, 128)
+        bx, outer_in = sch[sch_output].split(a, factor=num_thread)

Review Comment:
   Sorry, I probably missed something. Can we guarantee that `use_rfactor == 
False && div4 == True && latest4 == False` never happens? Because if we cannot 
guarantee that, then what is the value of `a` in this case?



##########
python/tvm/topi/adreno/reduction.py:
##########
@@ -25,45 +25,102 @@
 
 
 def _schedule_reduce_adreno(op, sch, is_idx_reduce=False):
-    if is_idx_reduce:
-        real_output = op.output(0)
+    sch_output = sch.outputs[0].output(0)
+    use_rfactor = False
+    if not is_idx_reduce:
+        rdomain = 1
+        whole_rop_output = op.output(0)
+        for i in range(len(sch[whole_rop_output].op.reduce_axis)):
+            rdomain = rdomain * 
sch[whole_rop_output].op.reduce_axis[i].dom.extent
+        if rdomain > 50:
+            use_rfactor = True
+            # shared goves better perf, but works only for rfactor flow
+            scope = "shared"
+        else:
+            # in case of direct scheduling, shared is failed to be compiled
+            scope = "local"
+        if op in sch.outputs:
+            whole_rop_output = sch.cache_write(sch_output, scope)
+        else:
+            # no change for whole_rop_output def, but need to set proper scope
+            sch[whole_rop_output].set_scope(scope)
+    else:
         temp_idx_input = op.input_tensors[0].op.output(0)
         temp_val_input = op.input_tensors[0].op.output(1)
-    else:
-        real_output = op.output(0)
-    shape = get_const_tuple(real_output.shape)
+        sch[temp_idx_input].set_scope("local")
+        sch[temp_val_input].set_scope("local")
+
+    shape = get_const_tuple(sch_output.shape)
     latest4 = shape[-1] == 4
     div4 = numpy.prod(shape) % 4 == 0
 
     # Fuse and split the axis
     if latest4:
-        fused_outer = sch[real_output].fuse(
-            *[sch[real_output].op.axis[i] for i in 
range(len(sch[real_output].op.axis) - 1)]
+        fused_outer = sch[sch_output].fuse(
+            *[sch[sch_output].op.axis[i] for i in 
range(len(sch[sch_output].op.axis) - 1)]
         )
     else:
-        fused_outer = sch[real_output].fuse(
-            *[sch[real_output].op.axis[i] for i in 
range(len(sch[real_output].op.axis))]
+        fused_outer = sch[sch_output].fuse(
+            *[sch[sch_output].op.axis[i] for i in 
range(len(sch[sch_output].op.axis))]
         )
 
     ftc = numpy.prod(shape)
     a = fused_outer
-    if latest4:
-        sch[real_output].vectorize(sch[real_output].op.axis[-1])
-    elif div4 and not is_idx_reduce:
-        a, b = sch[real_output].split(fused_outer, factor=4)
-        sch[real_output].vectorize(b)
-        ftc = ftc / 4
 
-    num_thread = get_div(ftc, 128)
+    if not is_idx_reduce:
+        if use_rfactor:
+            # below values were selected empirically assuming that we should 
have some work in each
+            # thread (currently from 25-49) and number of threads not 
exceeding some threshold that
+            # was selected as 256 from performance point of view after 
experiments on Adreno 660
+            max_threads = rdomain.value // 25 if rdomain > 25 else 1
+            max_threads = 256 if max_threads > 256 else max_threads
+            num_thread = get_div(rdomain, max_threads)
 
-    bx, outer_in = sch[real_output].split(a, factor=num_thread)
+            fused_reduce = 
sch[whole_rop_output].fuse(*sch[whole_rop_output].op.reduce_axis)
+            thread_y = te.thread_axis((0, num_thread), "threadIdx.y")
+            _, ki = sch[whole_rop_output].split(fused_reduce, 
factor=num_thread)
+            data_out_rf = sch.rfactor(whole_rop_output, ki)
+            sch[data_out_rf].compute_at(
+                sch[whole_rop_output], sch[whole_rop_output].op.reduce_axis[0]
+            )
+            
sch[whole_rop_output].bind(sch[whole_rop_output].op.reduce_axis[0], thread_y)
 
-    sch[real_output].bind(bx, te.thread_axis("blockIdx.x"))
-    sch[real_output].bind(outer_in, te.thread_axis("threadIdx.y"))
-    if is_idx_reduce:
-        sch[temp_idx_input].compute_at(sch[real_output], outer_in)
-        sch[temp_val_input].compute_at(sch[real_output], outer_in)
+    if div4:
+        if latest4:
+            b = sch[sch_output].op.axis[-1]
+        else:
+            a, b = sch[sch_output].split(fused_outer, factor=4)
+        sch[sch_output].vectorize(b)
+        if not use_rfactor:
+            if is_idx_reduce:
+                sch[temp_idx_input].compute_at(sch[sch_output], b)
+                sch[temp_val_input].compute_at(sch[sch_output], b)
+            else:
+                sch[whole_rop_output].compute_at(sch[sch_output], b)
+
+    if not use_rfactor:
+        num_thread = get_div(ftc, 128)
+        bx, outer_in = sch[sch_output].split(a, factor=num_thread)
+        sch[sch_output].bind(bx, te.thread_axis("blockIdx.x"))
+        sch[sch_output].bind(outer_in, te.thread_axis("threadIdx.x"))
+
+        if not div4:
+            if is_idx_reduce:
+                sch[temp_idx_input].compute_at(sch[sch_output], outer_in)
+                sch[temp_val_input].compute_at(sch[sch_output], outer_in)
+            else:
+                sch[whole_rop_output].compute_at(sch[sch_output], outer_in)
+    else:
+        sch[sch_output].bind(a, te.thread_axis("blockIdx.x"))
+        if not div4 or use_rfactor:

Review Comment:
   If `div4 == False`, then where will `a` be initialized?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to