junrushao commented on code in PR #12520:
URL: https://github.com/apache/tvm/pull/12520#discussion_r956484627


##########
src/relay/backend/te_compiler_cache.cc:
##########
@@ -359,32 +350,43 @@ class ScheduleBuilder : public ExprVisitor {
           schedule = Downcast<te::Schedule>(obj);
         }
       }
-      if (meta_schedule_ctx_) {
+      if (database_) {
+        using tvm::meta_schedule::TuningRecord;
+        using tvm::tir::IndexMap;
+        using tvm::tir::Instruction;
+        using tvm::tir::InstructionKind;
+        using tvm::tir::PrimFunc;
+        using tvm::tir::Schedule;
+        backend::FTECompilerTIRConverter tir_converter = 
backend::GetTIRConverter();
         Array<te::Tensor> te_args = Concat(fn_inputs, tensor_outs);
         Array<runtime::NDArray> constants;
         for (auto [const_node, te_tensor] : 
lower_te_compute.constant_tensors_) {
           te_args.push_back(te_tensor);
           constants.push_back(const_node->data);
         }
-
-        if (Optional<tir::PrimFunc> tir_func =
-                meta_schedule_ctx_.value()->te_filter_func(te_args, 
constants)) {
-          IRModule relay_mod({{prim_fn_var, relay_func}});
-          IRModule tir_mod({{prim_fn_var, tir_func.value()}});
-          if (Optional<IRModule> opt_scheduled_mod = 
meta_schedule_ctx_.value()->Query(
-                  /*task_name=*/prim_fn_var->name_hint,     //
-                  /*mod=*/relay_mod,                        //
-                  /*target=*/target_,                       //
-                  /*dispatched=*/Array<IRModule>{tir_mod},  //
-                  /*f_take_tuning_record=*/ExtractTransformLayout)) {
-            IRModule scheduled_mod =
-                
tir::transform::RemoveWeightLayoutRewriteBlock()(opt_scheduled_mod.value());
-            ICHECK_EQ(scheduled_mod->functions.count(prim_fn_var), 1);
-            prim_func = 
Downcast<tir::PrimFunc>(scheduled_mod->functions[prim_fn_var]);
+        if (Optional<PrimFunc> f = tir_converter(te_args, constants)) {
+          if (Optional<TuningRecord> opt_record = 
database_.value()->QueryTuningRecord(
+                  /*mod=*/backend::PrimFuncToIRModule(f.value()),
+                  /*target=*/target_)) {
+            static InstructionKind kind_transform_layout = 
InstructionKind::Get("TransformLayout");
+            TuningRecord record = opt_record.value();
+            for (const Instruction& inst : record->trace->insts) {
+              if (inst->kind.same_as(kind_transform_layout)) {
+                ICHECK_EQ(inst->attrs.size(), 3);
+                
MetaScheduleLayoutRewriter::LayoutQueuePush(Downcast<IndexMap>(inst->attrs[2]));

Review Comment:
   Oh, this is an extremely hacky version of weight layout rewriting on CPU.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to