vinx13 commented on code in PR #14182:
URL: https://github.com/apache/tvm/pull/14182#discussion_r1123927489


##########
include/tvm/relax/transform.h:
##########
@@ -272,6 +273,12 @@ TVM_DLL Pass RemoveUnusedFunctions(Array<runtime::String> 
entry_functions);
 TVM_DLL Pass RunCodegen(Optional<Map<String, Map<String, ObjectRef>>> 
target_options,
                         Array<runtime::String> entry_functions);
 
+/*!
+ * \brief Create default schedule for PrimFuncs.
+ * \return The Pass.
+ */
+TVM_DLL Pass DefaultSchedule(tvm::Target target);

Review Comment:
   Shall we get the target using `Target::Current()` instead?



##########
src/relax/transform/default_schedule.cc:
##########
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/relax/transform.h>
+#include <tvm/tir/schedule/schedule.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "../../meta_schedule/utils.h"
+#include "../../relay/analysis/graph_partitioner.h"
+#include "../../support/arena.h"
+#include "../../tir/ir/functor_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief The helper class to schedule functions and build a new module which 
calls the new TIR
+ * function.
+ */
+class ThreadBindMutator : public ExprMutator {
+ public:
+  static IRModule Transform(const IRModule& mod, int64_t max_thread_per_block) 
{
+    ThreadBindMutator mutator(mod);
+
+    for (const auto& kv : mod->functions) {
+      const GlobalVar& gv = kv.first;
+      const BaseFunc& func = kv.second;
+
+      if (func->IsInstance<tir::PrimFuncNode>()) {
+        IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({kv}));
+        tir::Schedule sch = tir::Schedule::Traced(mod, /*seed=*/-1, 
/*debug_mask=*/0,
+                                                  
tir::ScheduleErrorRenderLevel::kDetail);
+        Array<tir::BlockRV> blocks = 
meta_schedule::BlockCollector::Collect(sch);
+        for (const tir::BlockRV& block : blocks) {
+          // fetch the loops
+          Array<tir::LoopRV> loops = sch->GetLoops(block);
+          bool scheduled = false;
+          for (const tir::LoopRV& loop : loops) {
+            if (sch->Get(loop)->thread_binding.defined()) {
+              scheduled = true;
+              break;
+            }
+          }
+          // skip if already scheduled
+          if (scheduled) {
+            continue;
+          }
+          Array<tir::IterVar> iters = sch->Get(block)->iter_vars;
+          ICHECK_EQ(loops.size(), iters.size());
+          Array<tir::LoopRV> data_parallel_loops;
+          // only fuse data parallel loops
+          for (size_t i = 0; i < loops.size(); ++i) {
+            if (iters[i]->iter_type == tir::IterVarType::kDataPar) {
+              data_parallel_loops.push_back(loops[i]);
+            }
+          }
+          if (data_parallel_loops.size() == 0) {
+            continue;
+          }
+          // fuse all data parallel loops
+          tir::LoopRV fused = sch->Fuse(data_parallel_loops, 
/*preserve_unit_iters=*/false);
+          int64_t product = std::numeric_limits<int64_t>::max();
+          if (sch->Get(fused)->extent->IsInstance<tir::IntImmNode>()) {
+            product = sch->Get(fused)->extent.as<tir::IntImmNode>()->value;
+          }
+          static const int64_t max_threadblocks = 256;
+          // schedule the fused loop
+          if (product > max_thread_per_block * max_threadblocks) {
+            Array<tir::LoopRV> splits = sch->Split(
+                fused,
+                /*factors=*/{NullOpt, Integer(max_threadblocks), 
Integer(max_thread_per_block)});
+            sch->Reorder(/*ordered_loop_rvs=*/{splits[1], splits[2], 
splits[0]});
+            sch->Bind(splits[1], "blockIdx.x");
+            sch->Bind(splits[2], "threadIdx.x");
+          } else {
+            Array<tir::LoopRV> splits = sch->Split(
+                fused, /*factors=*/{NullOpt, Integer(std::min(product, 
max_thread_per_block))});
+            sch->Bind(splits[0], "blockIdx.x");
+            sch->Bind(splits[1], "threadIdx.x");
+          }
+        }
+        mutator.builder_->AddFunction(sch->mod()->Lookup(gv->name_hint), 
gv->name_hint);
+      } else {
+        mutator.builder_->AddFunction(func, gv->name_hint);

Review Comment:
   This is not needed, since the function is already in the module and is not changed.



##########
src/relax/transform/default_schedule.cc:
##########
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/relax/transform.h>
+#include <tvm/tir/schedule/schedule.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "../../meta_schedule/utils.h"
+#include "../../relay/analysis/graph_partitioner.h"
+#include "../../support/arena.h"
+#include "../../tir/ir/functor_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief The helper class to schedule functions and build a new module which 
calls the new TIR
+ * function.
+ */
+class ThreadBindMutator : public ExprMutator {
+ public:
+  static IRModule Transform(const IRModule& mod, int64_t max_thread_per_block) 
{
+    ThreadBindMutator mutator(mod);
+
+    for (const auto& kv : mod->functions) {
+      const GlobalVar& gv = kv.first;
+      const BaseFunc& func = kv.second;
+
+      if (func->IsInstance<tir::PrimFuncNode>()) {
+        IRModule mod = IRModule(Map<GlobalVar, BaseFunc>({kv}));

Review Comment:
   It is possible to skip creating a new module if we use `Schedule::WorkOn` to 
directly schedule the original module (after copy-on-write). See 
https://discuss.tvm.apache.org/t/manual-scheduling-of-call-tir-functions-in-relax/14446/7?u=vinx13



##########
src/relax/transform/default_schedule.cc:
##########
@@ -0,0 +1,143 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/relax/transform.h>
+#include <tvm/tir/schedule/schedule.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include "../../meta_schedule/utils.h"
+#include "../../relay/analysis/graph_partitioner.h"
+#include "../../support/arena.h"
+#include "../../tir/ir/functor_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief The helper class to schedule functions and build a new module which 
calls the new TIR
+ * function.
+ */
+class ThreadBindMutator : public ExprMutator {

Review Comment:
   Since we are not visiting the function bodies, it's probably not necessary to 
use an `ExprMutator`; we can directly update the module using `IRModule::Add`.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to