This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c481950807 [Relax][Refactor] Phase out FewShotTuning (#18864)
c481950807 is described below

commit c481950807791f0d3c9e005381f76555b0ceb5aa
Author: Tianqi Chen <[email protected]>
AuthorDate: Mon Mar 2 12:49:42 2026 -0500

    [Relax][Refactor] Phase out FewShotTuning (#18864)
    
    ## Summary
    
    - Remove `FewShotTuning` pass from Relax transform (C++ implementation,
    Python bindings, and test file)
    - The pass is unused in the current codebase and can be safely removed
    
    ## Files Changed
    
    - `include/tvm/relax/transform.h` — Remove declaration
    - `python/tvm/relax/transform/__init__.py` — Remove from imports
    - `python/tvm/relax/transform/transform.py` — Remove Python function
    - `src/relax/transform/few_shot_tuning.cc` — Delete (C++ implementation)
    - `tests/python/relax/test_transform_few_shot_tuning.py` — Delete (test
    file)
---
 include/tvm/relax/transform.h                      |  12 -
 python/tvm/relax/transform/__init__.py             |   1 -
 python/tvm/relax/transform/transform.py            |  24 --
 src/relax/transform/few_shot_tuning.cc             | 188 ----------
 tests/lint/check_asf_header.py                     |  22 +-
 .../python/relax/test_transform_few_shot_tuning.py | 392 ---------------------
 6 files changed, 18 insertions(+), 621 deletions(-)

diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 0e660292c4..9ffeb05f8f 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -673,18 +673,6 @@ ToMixedPrecision(const DataType& out_dtype,
  */
 TVM_DLL Pass RewriteCUDAGraph();
 
-/*!
- * \brief The pass is designed for few shot tuning for static shape PrimFuncs. 
It examines all the
- *  blocks within the PrimFunc and conducts loop fusion, splitting, and other 
transformations based
- *  on MetaSchedule schedule rules but directly samples from the search space 
instead of using the
- *  tuning algorithm. User can specify the number of valid counts to try and 
whether to use runner
- *  for benchmarking.
- * \param valid_count The number of valid counts to try.
- * \param benchmark Whether to use runner for benchmarking.
- * \return The Pass.
- */
-TVM_DLL Pass FewShotTuning(int valid_count, bool benchmark);
-
 /*!
  * \brief This pass updates the var_buffer mapping of PrimFunctions from the 
call_tir info.
  * Primarily used to update the VDevice information if any changes occured 
from the caller.
diff --git a/python/tvm/relax/transform/__init__.py 
b/python/tvm/relax/transform/__init__.py
index 5bf79dc7c8..c3188adf50 100644
--- a/python/tvm/relax/transform/__init__.py
+++ b/python/tvm/relax/transform/__init__.py
@@ -41,7 +41,6 @@ from .transform import (
     EliminateCommonSubexpr,
     ExpandMatmulOfSum,
     ExpandTupleArguments,
-    FewShotTuning,
     FoldConstant,
     FunctionPass,
     FuseOps,
diff --git a/python/tvm/relax/transform/transform.py 
b/python/tvm/relax/transform/transform.py
index 8911423c65..e70392e88d 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -1264,30 +1264,6 @@ def MetaScheduleTuneIRMod(
     )  # type: ignore
 
 
-def FewShotTuning(
-    valid_count: int = 1,
-    benchmark: bool = False,
-) -> tvm.ir.transform.Pass:
-    """The pass is designed for few shot tuning for static shape PrimFuncs. It 
examines all the
-    blocks within the PrimFunc and conducts loop fusion, splitting, and other 
transformations based
-    on MetaSchedule schedule rules but directly samples from the search space 
instead of using the
-    tuning algorithm. User can specify the number of valid counts to try and 
whether to use runner
-    for benchmarking.
-
-    Parameters
-    ----------
-    valid_count: int
-        The number of valid counts to try.
-    benchmark: bool
-        Whether to use runner for benchmarking.
-
-    Returns
-    -------
-    ret: tvm.ir.transform.Pass
-    """
-    return _ffi_api.FewShotTuning(valid_count, benchmark)  # type: ignore
-
-
 def DecomposeOpsForInference(func_name: str | None = None) -> 
tvm.ir.transform.Pass:
     """Decompose composite operators that are composed by other operators 
during inference.
     For example, the result of batch norm (a triple) will be simplified. 
Attention, tensor_to_shape,
diff --git a/src/relax/transform/few_shot_tuning.cc 
b/src/relax/transform/few_shot_tuning.cc
deleted file mode 100644
index a88b92e8e4..0000000000
--- a/src/relax/transform/few_shot_tuning.cc
+++ /dev/null
@@ -1,188 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#include <tvm/ffi/reflection/registry.h>
-#include <tvm/relax/transform.h>
-
-#include "../../s_tir/meta_schedule/utils.h"
-
-namespace tvm {
-namespace relax {
-namespace transform {
-
-tir::PrimFunc FewShotTunePrimFunc(const tir::PrimFunc& prim_func, const 
Target& target,
-                                  int64_t valid_count, bool benchmark) {
-  // fetch a local builder
-  static const auto f_get_local_builder =
-      
tvm::ffi::Function::GetGlobalRequired("s_tir.meta_schedule.builder.get_local_builder");
-  s_tir::meta_schedule::Builder builder =
-      f_get_local_builder().cast<s_tir::meta_schedule::Builder>();
-  TVM_FFI_CHECK(builder.defined(), ValueError) << "The local builder is not 
defined!";
-  // fetch a local runner
-  s_tir::meta_schedule::Runner runner{ffi::UnsafeInit()};
-  if (benchmark) {
-    static const auto f_get_local_runner =
-        
tvm::ffi::Function::GetGlobalRequired("s_tir.meta_schedule.runner.get_local_runner");
-    runner = f_get_local_runner().cast<s_tir::meta_schedule::Runner>();
-    TVM_FFI_CHECK(runner.defined(), ValueError) << "The local runner is not 
defined!";
-  }
-  // create an IRModule
-  IRModule mod = IRModule(ffi::Map<GlobalVar, BaseFunc>(
-      {{GlobalVar("main"), WithAttr(prim_func, tvm::attr::kGlobalSymbol, 
ffi::String("main"))}}));
-  // fetch the number of physical cores
-  static const auto f_cpu_count =
-      tvm::ffi::Function::GetGlobalRequired("s_tir.meta_schedule.cpu_count");
-  int num_threads = f_cpu_count(false).cast<int>();
-  // store the results
-  ffi::Array<IRModule> results;
-  std::vector<double> costs;
-  // create a TuneContext
-  s_tir::meta_schedule::TuneContext task = s_tir::meta_schedule::TuneContext(
-      /*mod=*/mod,
-      /*target=*/target,
-      /*space_generator=*/
-      
s_tir::meta_schedule::SpaceGenerator::PostOrderApply(/*f_block_filter=*/nullptr,
-                                                           
/*sch_rules=*/std::nullopt,
-                                                           
/*postprocs=*/std::nullopt,
-                                                           
/*mutator_probs=*/std::nullopt),
-      
/*search_strategy=*/s_tir::meta_schedule::SearchStrategy::ReplayTrace(/*max_fail_count=*/100),
-      /*task_name=*/std::nullopt,
-      /*num_threads=*/num_threads,  // use all available local threads
-      /*rand_state=*/-1,            // -1 means use random seed
-      /*logger=*/nullptr);
-  task->Initialize();
-  task->search_strategy.value()->PreTuning(
-      /*max_trials=*/valid_count, /*num_trials_per_iter=*/valid_count,
-      
/*design_spaces=*/task->space_generator.value()->GenerateDesignSpace(mod),
-      /*database=*/std::nullopt,
-      /*cost_model=*/std::nullopt);
-  int fail_count = 0, max_fail_count = 100;
-  while (valid_count > 0 && fail_count < max_fail_count) {
-    ffi::Optional<ffi::Array<s_tir::meta_schedule::MeasureCandidate>> 
candidates =
-        task->search_strategy.value()->GenerateMeasureCandidates();
-    if (!candidates.defined()) break;
-    ffi::Array<s_tir::meta_schedule::BuilderInput> builder_inputs;
-    for (const s_tir::meta_schedule::MeasureCandidate& candidate : 
candidates.value()) {
-      builder_inputs.push_back(s_tir::meta_schedule::BuilderInput(
-          /*mod=*/candidate->sch->mod(),
-          /*target=*/target));
-    }
-    ffi::Array<s_tir::meta_schedule::BuilderResult> builder_results =
-        builder->Build(builder_inputs);
-    TVM_FFI_ICHECK_EQ(builder_results.size(), candidates.value().size());
-    int idx = 0;
-    bool no_valid = true;  // whether there is no valid schedule in this 
iteration
-    for (const s_tir::meta_schedule::BuilderResult& builder_result : 
builder_results) {
-      if (!builder_result->error_msg.has_value()) {
-        results.push_back(candidates.value()[idx]->sch->mod());
-        valid_count--;
-        no_valid = false;
-      }
-      idx++;
-    }
-    fail_count += no_valid;  // increase fail_count if there is no valid 
schedule
-    if (benchmark) {
-      ffi::Array<s_tir::meta_schedule::RunnerInput> runner_inputs;
-      int idx = 0;
-      for (const s_tir::meta_schedule::BuilderResult& builder_result : 
builder_results) {
-        if (!builder_result->error_msg.has_value()) {
-          runner_inputs.push_back(s_tir::meta_schedule::RunnerInput(
-              /*artifact_path=*/builder_result->artifact_path.value(),
-              /*device_type=*/target->kind->name,
-              /*args_info=*/candidates.value()[idx]->args_info));
-        }
-        idx++;
-      }
-      ffi::Array<s_tir::meta_schedule::RunnerFuture> runner_futures = 
runner->Run(runner_inputs);
-      for (const s_tir::meta_schedule::RunnerFuture& runner_future : 
runner_futures) {
-        s_tir::meta_schedule::RunnerResult runner_result = 
runner_future->Result();
-        if (runner_result->error_msg.has_value()) {
-          costs.push_back(1e10);
-        } else {
-          double sum = 0;
-          for (const FloatImm& cost : runner_result->run_secs.value()) {
-            sum += cost->value;
-          }
-          costs.push_back(sum / runner_result->run_secs.value().size());
-        }
-      }
-      TVM_FFI_ICHECK_EQ(costs.size(), results.size());
-    }
-  }
-  if (results.size() == 0) {
-    LOG(WARNING) << "No valid schedule found";
-    return prim_func;
-  }
-  if (fail_count >= max_fail_count) {
-    LOG(WARNING) << "Reached the maximum number of failed trials";
-  }
-  int best_idx = 0;
-  if (benchmark) {
-    for (size_t i = 1; i < costs.size(); ++i) {
-      if (costs[i] < costs[best_idx]) {
-        best_idx = i;
-      }
-    }
-  } else {
-    best_idx = results.size() - 1;
-  }
-  return WithAttr(Downcast<tir::PrimFunc>(results[best_idx]->Lookup("main")),
-                  tvm::tir::attr::kIsScheduled, Bool(true));
-}
-
-Pass FewShotTuning(int valid_count, bool benchmark) {
-  auto pass_func =  //
-      [=](IRModule m, PassContext pc) {
-        // input check
-        TVM_FFI_ICHECK(valid_count > 0) << "Valid_count must be positive.";
-        TVM_FFI_ICHECK(valid_count > 1 || !benchmark)
-            << "Benchmarking requires at least two valid trials.";
-        // get the target from context.
-        tvm::Target target = tvm::Target::Current();
-        TVM_FFI_ICHECK(target.defined()) << "Target is not set in current 
context";
-        // generate the few shot tuned prim funcs.
-        ffi::Map<GlobalVar, BaseFunc> result;
-        for (const auto& [gv, func] : m->functions) {
-          if (func->IsInstance<tir::PrimFuncNode>() &&
-              !func->HasNonzeroAttr(tir::attr::kIsScheduled)) {
-            result.Set(gv,
-                       
FewShotTunePrimFunc(ffi::GetRef<tir::PrimFunc>(func.as<tir::PrimFuncNode>()),
-                                           target, valid_count, benchmark));
-          } else {
-            result.Set(gv, func);
-          }
-        }
-        return IRModule(result,         // functions
-                        m->source_map,  // map
-                        m->attrs);      // attrs);
-      };
-  return CreateModulePass(/*pass_function=*/pass_func,    //
-                          /*opt_level=*/0,                //
-                          /*pass_name=*/"FewShotTuning",  //
-                          /*required=*/{});
-}
-
-TVM_FFI_STATIC_INIT_BLOCK() {
-  namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("relax.transform.FewShotTuning", FewShotTuning);
-}
-
-}  // namespace transform
-}  // namespace relax
-}  // namespace tvm
diff --git a/tests/lint/check_asf_header.py b/tests/lint/check_asf_header.py
index f5dcf22fcd..f0bfdc6a87 100644
--- a/tests/lint/check_asf_header.py
+++ b/tests/lint/check_asf_header.py
@@ -216,17 +216,31 @@ def should_skip_file(filepath: str) -> bool:
 
 
 def get_git_files() -> list[str] | None:
-    """Get list of files tracked by git."""
+    """Get list of files tracked by git (excluding files staged for 
deletion)."""
     try:
         result = subprocess.run(
             ["git", "ls-files"], check=False, capture_output=True, text=True, 
cwd=Path.cwd()
         )
-        if result.returncode == 0:
-            return [line.strip() for line in result.stdout.split("\n") if 
line.strip()]
-        else:
+        if result.returncode != 0:
             print("Error: Could not get git files. Make sure you're in a git 
repository.")
             print("Git command failed:", result.stderr.strip())
             return None
+        all_files = {line.strip() for line in result.stdout.split("\n") if 
line.strip()}
+        # Exclude files staged for deletion so the header check does not
+        # report errors for files that are intentionally being removed.
+        deleted_result = subprocess.run(
+            ["git", "ls-files", "--deleted"],
+            check=False,
+            capture_output=True,
+            text=True,
+            cwd=Path.cwd(),
+        )
+        if deleted_result.returncode == 0:
+            deleted = {line.strip() for line in 
deleted_result.stdout.split("\n") if line.strip()}
+            all_files -= deleted
+        elif deleted_result.stderr:
+            print(f"Warning: 'git ls-files --deleted' failed: 
{deleted_result.stderr.strip()}")
+        return sorted(all_files)
     except FileNotFoundError:
         print("Error: Git not found. This tool requires git to be installed.")
         return None
diff --git a/tests/python/relax/test_transform_few_shot_tuning.py 
b/tests/python/relax/test_transform_few_shot_tuning.py
deleted file mode 100644
index 6c8ee37290..0000000000
--- a/tests/python/relax/test_transform_few_shot_tuning.py
+++ /dev/null
@@ -1,392 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-# pylint: disable=invalid-name,,missing-function-docstring
-# ruff: noqa: E501, F403
-
-import numpy as np
-import pytest
-
-import tvm
-import tvm.testing
-from tvm.relax.transform import FewShotTuning
-from tvm.s_tir.meta_schedule.arg_info import ArgInfo
-from tvm.s_tir.meta_schedule.testing.tune_utils import generate_input_data
-from tvm.s_tir.tensor_intrin.cuda import *  # pylint: 
disable=wildcard-import,unused-wildcard-import
-from tvm.s_tir.tensor_intrin.x86 import *  # pylint: 
disable=wildcard-import,unused-wildcard-import
-from tvm.script import tir as T
-
-
-# pylint: disable=no-self-argument,missing-class-docstring,line-too-long
-# fmt: off
-@tvm.script.ir_module
-class MatMul:
-    @T.prim_func
-    def matmul(
-        A: T.Buffer((32, 32), "float16"),
-        B: T.Buffer((32, 32), "float16"),
-        C: T.Buffer((32, 32), "float16"),
-    ):
-        T.func_attr({"tir.noalias": True})
-        # with T.sblock("root"):
-        for i, j, k in T.grid(32, 32, 32):
-            with T.sblock("C"):
-                v_i, v_j, v_k = T.axis.remap("SSR", [i, j, k])
-                T.reads(A[v_i, v_k], B[v_k, v_j])
-                T.writes(C[v_i, v_j])
-                with T.init():
-                    C[v_i, v_j] = T.float16(0)
-                C[v_i, v_j] = C[v_i, v_j] + A[v_i, v_k] * B[v_k, v_j]
-
-@tvm.script.ir_module
-class Softmax:
-    @T.prim_func
-    def softmax(rxplaceholder: T.Buffer((T.int64(8), T.int64(3456), 
T.int64(3456)), "float32"), T_softmax_norm: T.Buffer((T.int64(8), 
T.int64(3456), T.int64(3456)), "float32")):
-        T.func_attr({"op_pattern": 4, "tir.noalias": True})
-        # with T.sblock("root"):
-        T_softmax_maxelem = T.alloc_buffer((T.int64(8), T.int64(3456)), 
"float32")
-        T_softmax_exp = T.alloc_buffer((T.int64(8), T.int64(3456), 
T.int64(3456)), "float32")
-        T_softmax_expsum = T.alloc_buffer((T.int64(8), T.int64(3456)), 
"float32")
-        for i0, i1, k in T.grid(T.int64(8), T.int64(3456), T.int64(3456)):
-            with T.sblock("T_softmax_maxelem"):
-                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
-                T.reads(rxplaceholder[v_i0, v_i1, v_k])
-                T.writes(T_softmax_maxelem[v_i0, v_i1])
-                with T.init():
-                    T_softmax_maxelem[v_i0, v_i1] = T.float16(-65504)
-                T_softmax_maxelem[v_i0, v_i1] = T.max(T_softmax_maxelem[v_i0, 
v_i1], rxplaceholder[v_i0, v_i1, v_k])
-        for i0, i1, i2 in T.grid(T.int64(8), T.int64(3456), T.int64(3456)):
-            with T.sblock("T_softmax_exp"):
-                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
-                T.reads(rxplaceholder[v_i0, v_i1, v_i2], 
T_softmax_maxelem[v_i0, v_i1])
-                T.writes(T_softmax_exp[v_i0, v_i1, v_i2])
-                T_softmax_exp[v_i0, v_i1, v_i2] = T.exp(rxplaceholder[v_i0, 
v_i1, v_i2] - T_softmax_maxelem[v_i0, v_i1])
-        for i0, i1, k in T.grid(T.int64(8), T.int64(3456), T.int64(3456)):
-            with T.sblock("T_softmax_expsum"):
-                v_i0, v_i1, v_k = T.axis.remap("SSR", [i0, i1, k])
-                T.reads(T_softmax_exp[v_i0, v_i1, v_k])
-                T.writes(T_softmax_expsum[v_i0, v_i1])
-                with T.init():
-                    T_softmax_expsum[v_i0, v_i1] = T.float16(0)
-                T_softmax_expsum[v_i0, v_i1] = T_softmax_expsum[v_i0, v_i1] + 
T_softmax_exp[v_i0, v_i1, v_k]
-        for i0, i1, i2 in T.grid(T.int64(8), T.int64(3456), T.int64(3456)):
-            with T.sblock("T_softmax_norm"):
-                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
-                T.reads(T_softmax_exp[v_i0, v_i1, v_i2], 
T_softmax_expsum[v_i0, v_i1])
-                T.writes(T_softmax_norm[v_i0, v_i1, v_i2])
-                T.sblock_attr({"axis": 2})
-                T_softmax_norm[v_i0, v_i1, v_i2] = T_softmax_exp[v_i0, v_i1, 
v_i2] / T_softmax_expsum[v_i0, v_i1]
-
-@tvm.script.ir_module
-class Fused_Variance_Cast1:
-    @T.prim_func
-    def main(lv3: T.Buffer((T.int64(1), T.int64(32), T.int64(34560)), 
"float32"), compute: T.Buffer((T.int64(1), T.int64(32), T.int64(1)), 
"float16")):
-        T.func_attr({"tir.noalias": True})
-        # with T.sblock("root"):
-        rxplaceholder_red = T.alloc_buffer((T.int64(1), T.int64(32), 
T.int64(1)))
-        T_divide = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)))
-        T_subtract = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(34560)))
-        T_multiply = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(34560)))
-        T_multiply_red = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)))
-        T_divide_1 = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)))
-        for ax0, ax1, ax2, k2 in T.grid(T.int64(1), T.int64(32), T.int64(1), 
T.int64(34560)):
-            with T.sblock("rxplaceholder_red"):
-                v_ax0, v_ax1, v_ax2, v_k2 = T.axis.remap("SSSR", [ax0, ax1, 
ax2, k2])
-                T.reads(lv3[v_ax0, v_ax1, v_k2])
-                T.writes(rxplaceholder_red[v_ax0, v_ax1, v_ax2])
-                with T.init():
-                    rxplaceholder_red[v_ax0, v_ax1, v_ax2] = T.float32(0)
-                rxplaceholder_red[v_ax0, v_ax1, v_ax2] = 
rxplaceholder_red[v_ax0, v_ax1, v_ax2] + lv3[v_ax0, v_ax1, v_k2]
-        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(1)):
-            with T.sblock("T_divide"):
-                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
-                T.reads(rxplaceholder_red[v_ax0, v_ax1, v_ax2])
-                T.writes(T_divide[v_ax0, v_ax1, v_ax2])
-                T_divide[v_ax0, v_ax1, v_ax2] = rxplaceholder_red[v_ax0, 
v_ax1, v_ax2] * T.float32(2.8935185185185186e-05)
-        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(34560)):
-            with T.sblock("T_subtract"):
-                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
-                T.reads(lv3[v_ax0, v_ax1, v_ax2], T_divide[v_ax0, v_ax1, 
T.int64(0)])
-                T.writes(T_subtract[v_ax0, v_ax1, v_ax2])
-                T_subtract[v_ax0, v_ax1, v_ax2] = lv3[v_ax0, v_ax1, v_ax2] - 
T_divide[v_ax0, v_ax1, T.int64(0)]
-        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(34560)):
-            with T.sblock("T_multiply"):
-                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
-                T.reads(T_subtract[v_ax0, v_ax1, v_ax2])
-                T.writes(T_multiply[v_ax0, v_ax1, v_ax2])
-                T_multiply[v_ax0, v_ax1, v_ax2] = T_subtract[v_ax0, v_ax1, 
v_ax2] * T_subtract[v_ax0, v_ax1, v_ax2]
-        for ax0, ax1, ax2, k2 in T.grid(T.int64(1), T.int64(32), T.int64(1), 
T.int64(34560)):
-            with T.sblock("T_multiply_red"):
-                v_ax0, v_ax1, v_ax2, v_k2 = T.axis.remap("SSSR", [ax0, ax1, 
ax2, k2])
-                T.reads(T_multiply[v_ax0, v_ax1, v_k2])
-                T.writes(T_multiply_red[v_ax0, v_ax1, v_ax2])
-                with T.init():
-                    T_multiply_red[v_ax0, v_ax1, v_ax2] = T.float32(0)
-                T_multiply_red[v_ax0, v_ax1, v_ax2] = T_multiply_red[v_ax0, 
v_ax1, v_ax2] + T_multiply[v_ax0, v_ax1, v_k2]
-        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(1)):
-            with T.sblock("T_divide_1"):
-                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
-                T.reads(T_multiply_red[v_ax0, v_ax1, v_ax2])
-                T.writes(T_divide_1[v_ax0, v_ax1, v_ax2])
-                T_divide_1[v_ax0, v_ax1, v_ax2] = T_multiply_red[v_ax0, v_ax1, 
v_ax2] * T.float32(2.8935185185185186e-05)
-        for i0, i1, i2 in T.grid(T.int64(1), T.int64(32), T.int64(1)):
-            with T.sblock("compute"):
-                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
-                T.reads(T_divide_1[v_i0, v_i1, v_i2])
-                T.writes(compute[v_i0, v_i1, v_i2])
-                compute[v_i0, v_i1, v_i2] = T.Cast("float16", T_divide_1[v_i0, 
v_i1, v_i2])
-
-@tvm.script.ir_module
-class Fuse_Mean_Cast1:
-    @T.prim_func
-    def main(lv: T.Buffer((T.int64(1), T.int64(32), T.int64(34560)), 
"float32"), compute: T.Buffer((T.int64(1), T.int64(32), T.int64(1)), 
"float16")):
-        T.func_attr({"tir.noalias": True})
-        # with T.sblock("root"):
-        rxplaceholder_red = T.alloc_buffer((T.int64(1), T.int64(32), 
T.int64(1)))
-        T_divide = T.alloc_buffer((T.int64(1), T.int64(32), T.int64(1)))
-        for ax0, ax1, ax2, k2 in T.grid(T.int64(1), T.int64(32), T.int64(1), 
T.int64(34560)):
-            with T.sblock("rxplaceholder_red"):
-                v_ax0, v_ax1, v_ax2, v_k2 = T.axis.remap("SSSR", [ax0, ax1, 
ax2, k2])
-                T.reads(lv[v_ax0, v_ax1, v_k2])
-                T.writes(rxplaceholder_red[v_ax0, v_ax1, v_ax2])
-                with T.init():
-                    rxplaceholder_red[v_ax0, v_ax1, v_ax2] = T.float32(0)
-                rxplaceholder_red[v_ax0, v_ax1, v_ax2] = 
rxplaceholder_red[v_ax0, v_ax1, v_ax2] + lv[v_ax0, v_ax1, v_k2]
-        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(32), T.int64(1)):
-            with T.sblock("T_divide"):
-                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
-                T.reads(rxplaceholder_red[v_ax0, v_ax1, v_ax2])
-                T.writes(T_divide[v_ax0, v_ax1, v_ax2])
-                T_divide[v_ax0, v_ax1, v_ax2] = rxplaceholder_red[v_ax0, 
v_ax1, v_ax2] * T.float32(2.8935185185185186e-05)
-        for i0, i1, i2 in T.grid(T.int64(1), T.int64(32), T.int64(1)):
-            with T.sblock("compute"):
-                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
-                T.reads(T_divide[v_i0, v_i1, v_i2])
-                T.writes(compute[v_i0, v_i1, v_i2])
-                compute[v_i0, v_i1, v_i2] = T.Cast("float16", T_divide[v_i0, 
v_i1, v_i2])
-
-@tvm.script.ir_module
-class Module:
-    @T.prim_func
-    def main(lv26: T.Buffer((T.int64(1), T.int64(3456), T.int64(2560)), 
"float16"), T_multiply: T.Buffer((T.int64(1), T.int64(3456), T.int64(1280)), 
"float16")):
-        T.func_attr({"tir.noalias": True})
-        # with T.sblock("root"):
-        T_strided_slice_with_axes = T.alloc_buffer((T.int64(1), T.int64(3456), 
T.int64(1280)), "float16")
-        T_divide = T.alloc_buffer((T.int64(1), T.int64(3456), T.int64(1280)), 
"float16")
-        T_multiply_1 = T.alloc_buffer((T.int64(1), T.int64(3456), 
T.int64(1280)), "float16")
-        T_multiply_2 = T.alloc_buffer((T.int64(1), T.int64(3456), 
T.int64(1280)), "float16")
-        compute = T.alloc_buffer((T.int64(1), T.int64(3456), T.int64(1280)))
-        compute_1 = T.alloc_buffer((T.int64(1), T.int64(3456), T.int64(1280)))
-        compute_2 = T.alloc_buffer((T.int64(1), T.int64(3456), T.int64(1280)), 
"float16")
-        T_multiply_3 = T.alloc_buffer((T.int64(1), T.int64(3456), 
T.int64(1280)), "float16")
-        T_add = T.alloc_buffer((T.int64(1), T.int64(3456), T.int64(1280)), 
"float16")
-        T_multiply_4 = T.alloc_buffer((T.int64(1), T.int64(3456), 
T.int64(1280)), "float16")
-        T_multiply_5 = T.alloc_buffer((T.int64(1), T.int64(3456), 
T.int64(1280)), "float16")
-        T_divide_1 = T.alloc_buffer((T.int64(1), T.int64(3456), 
T.int64(1280)), "float16")
-        T_add_1 = T.alloc_buffer((T.int64(1), T.int64(3456), T.int64(1280)), 
"float16")
-        T_add_2 = T.alloc_buffer((T.int64(1), T.int64(3456), T.int64(1280)), 
"float16")
-        T_multiply_6 = T.alloc_buffer((T.int64(1), T.int64(3456), 
T.int64(1280)), "float16")
-        T_strided_slice_with_axes_1 = T.alloc_buffer((T.int64(1), 
T.int64(3456), T.int64(1280)), "float16")
-        T_multiply_7 = T.alloc_buffer((T.int64(1), T.int64(3456), 
T.int64(1280)), "float16")
-        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
-            with T.sblock("T_strided_slice_with_axes"):
-                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
-                T.reads(lv26[v_ax0, v_ax1, v_ax2 + T.int64(1280)])
-                T.writes(T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2])
-                T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2] = lv26[v_ax0, 
v_ax1, v_ax2 + T.int64(1280)]
-        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
-            with T.sblock("T_divide"):
-                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
-                T.reads(T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2])
-                T.writes(T_divide[v_ax0, v_ax1, v_ax2])
-                T_divide[v_ax0, v_ax1, v_ax2] = 
T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2] * T.float16(0.70718232044198892)
-        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
-            with T.sblock("T_multiply"):
-                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
-                T.reads(T_divide[v_ax0, v_ax1, v_ax2])
-                T.writes(T_multiply_1[v_ax0, v_ax1, v_ax2])
-                T_multiply_1[v_ax0, v_ax1, v_ax2] = T_divide[v_ax0, v_ax1, 
v_ax2] * T.float16(1.4140625)
-        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
-            with T.sblock("T_multiply_1"):
-                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
-                T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2])
-                T.writes(T_multiply_2[v_ax0, v_ax1, v_ax2])
-                T_multiply_2[v_ax0, v_ax1, v_ax2] = T_multiply_1[v_ax0, v_ax1, 
v_ax2] * T.float16(0.70710678118654757)
-        for i0, i1, i2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
-            with T.sblock("compute"):
-                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
-                T.reads(T_multiply_2[v_i0, v_i1, v_i2])
-                T.writes(compute[v_i0, v_i1, v_i2])
-                compute[v_i0, v_i1, v_i2] = T.Cast("float32", 
T_multiply_2[v_i0, v_i1, v_i2])
-        for i0, i1, i2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
-            with T.sblock("compute_1"):
-                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
-                T.reads(compute[v_i0, v_i1, v_i2])
-                T.writes(compute_1[v_i0, v_i1, v_i2])
-                compute_1[v_i0, v_i1, v_i2] = T.erf(compute[v_i0, v_i1, v_i2])
-        for i0, i1, i2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
-            with T.sblock("compute_2"):
-                v_i0, v_i1, v_i2 = T.axis.remap("SSS", [i0, i1, i2])
-                T.reads(compute_1[v_i0, v_i1, v_i2])
-                T.writes(compute_2[v_i0, v_i1, v_i2])
-                compute_2[v_i0, v_i1, v_i2] = T.Cast("float16", 
compute_1[v_i0, v_i1, v_i2])
-        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
-            with T.sblock("T_multiply_1_1"):
-                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
-                T.reads(compute_2[v_ax0, v_ax1, v_ax2])
-                T.writes(T_multiply_3[v_ax0, v_ax1, v_ax2])
-                T_multiply_3[v_ax0, v_ax1, v_ax2] = compute_2[v_ax0, v_ax1, 
v_ax2] * T.float16(0.5)
-        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
-            with T.sblock("T_add"):
-                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
-                T.reads(T_multiply_3[v_ax0, v_ax1, v_ax2])
-                T.writes(T_add[v_ax0, v_ax1, v_ax2])
-                T_add[v_ax0, v_ax1, v_ax2] = T.float16(0.5) + 
T_multiply_3[v_ax0, v_ax1, v_ax2]
-        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
-            with T.sblock("T_multiply_2"):
-                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
-                T.reads(T_multiply_1[v_ax0, v_ax1, v_ax2], T_add[v_ax0, v_ax1, v_ax2])
-                T.writes(T_multiply_4[v_ax0, v_ax1, v_ax2])
-                T_multiply_4[v_ax0, v_ax1, v_ax2] = T_multiply_1[v_ax0, v_ax1, v_ax2] * T_add[v_ax0, v_ax1, v_ax2]
-        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
-            with T.sblock("T_multiply_3"):
-                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
-                T.reads(T_multiply_4[v_ax0, v_ax1, v_ax2])
-                T.writes(T_multiply_5[v_ax0, v_ax1, v_ax2])
-                T_multiply_5[v_ax0, v_ax1, v_ax2] = T_multiply_4[v_ax0, v_ax1, v_ax2] * T.float16(1.4140625)
-        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
-            with T.sblock("T_divide_1"):
-                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
-                T.reads(T_multiply_5[v_ax0, v_ax1, v_ax2], T_divide[v_ax0, v_ax1, v_ax2])
-                T.writes(T_divide_1[v_ax0, v_ax1, v_ax2])
-                T_divide_1[v_ax0, v_ax1, v_ax2] = T_multiply_5[v_ax0, v_ax1, v_ax2] / T_divide[v_ax0, v_ax1, v_ax2]
-        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
-            with T.sblock("T_add_1"):
-                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
-                T.reads(T_divide_1[v_ax0, v_ax1, v_ax2])
-                T.writes(T_add_1[v_ax0, v_ax1, v_ax2])
-                T_add_1[v_ax0, v_ax1, v_ax2] = T_divide_1[v_ax0, v_ax1, v_ax2] + T.float16(-1)
-        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
-            with T.sblock("T_add_2"):
-                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
-                T.reads(T_add_1[v_ax0, v_ax1, v_ax2])
-                T.writes(T_add_2[v_ax0, v_ax1, v_ax2])
-                T_add_2[v_ax0, v_ax1, v_ax2] = T_add_1[v_ax0, v_ax1, v_ax2] + T.float16(1)
-        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
-            with T.sblock("T_multiply_4"):
-                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
-                T.reads(T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2], T_add_2[v_ax0, v_ax1, v_ax2])
-                T.writes(T_multiply_6[v_ax0, v_ax1, v_ax2])
-                T_multiply_6[v_ax0, v_ax1, v_ax2] = T_strided_slice_with_axes[v_ax0, v_ax1, v_ax2] * T_add_2[v_ax0, v_ax1, v_ax2]
-        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
-            with T.sblock("T_strided_slice_with_axes_1"):
-                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
-                T.reads(lv26[v_ax0, v_ax1, v_ax2])
-                T.writes(T_strided_slice_with_axes_1[v_ax0, v_ax1, v_ax2])
-                T_strided_slice_with_axes_1[v_ax0, v_ax1, v_ax2] = lv26[v_ax0, v_ax1, v_ax2]
-        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
-            with T.sblock("T_multiply_5"):
-                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
-                T.reads(T_multiply_6[v_ax0, v_ax1, v_ax2])
-                T.writes(T_multiply_7[v_ax0, v_ax1, v_ax2])
-                T_multiply_7[v_ax0, v_ax1, v_ax2] = T_multiply_6[v_ax0, v_ax1, v_ax2] * T.float16(0.5)
-        for ax0, ax1, ax2 in T.grid(T.int64(1), T.int64(3456), T.int64(1280)):
-            with T.sblock("T_multiply_6"):
-                v_ax0, v_ax1, v_ax2 = T.axis.remap("SSS", [ax0, ax1, ax2])
-                T.reads(T_strided_slice_with_axes_1[v_ax0, v_ax1, v_ax2], T_multiply_7[v_ax0, v_ax1, v_ax2])
-                T.writes(T_multiply[v_ax0, v_ax1, v_ax2])
-                T_multiply[v_ax0, v_ax1, v_ax2] = T_strided_slice_with_axes_1[v_ax0, v_ax1, v_ax2] * T_multiply_7[v_ax0, v_ax1, v_ax2]
-# fmt: on
-# pylint: enable=no-self-argument,missing-class-docstring,line-too-long
-
-
-def _target() -> tvm.target.Target:
-    return tvm.target.Target({"kind": "llvm", "num-cores": 4})
-    # for local testing only
-    # return tvm.target.Target("nvidia/geforce-rtx-3070")
-
-
-def _acc() -> float:
-    return 1e-2 if _target().kind.name == "cuda" else 1e-7
-
-
-def _get_single_prim_func(mod: tvm.ir.IRModule) -> tvm.tir.PrimFunc:
-    funcs = [func for func in mod.functions.values()]
-    assert len(funcs) == 1, "Only one function is supported."
-    return funcs[0]
-
-
-def _get_input_output_info(func: tvm.tir.PrimFunc) -> tuple[list[np.ndarray], tuple, str]:
-    args = ArgInfo.from_prim_func(func)
-    inputs = [generate_input_data(x.shape, x.dtype) for x in args[:-1]]
-    output_shape = args[-1].shape
-    output_dtype = args[-1].dtype
-    return inputs, output_shape, output_dtype
-
-
-def _expected_results(
-    mod: tvm.ir.IRModule, inputs: list[np.ndarray], output_shape: tuple, output_dtype: str
-) -> np.ndarray:
-    func = _get_single_prim_func(mod)
-    func = func.with_attr("global_symbol", "main")
-    rt_mod = tvm.compile(func, target="llvm")
-    data = [
-        tvm.runtime.tensor(x)
-        for x in [
-            *inputs,
-            np.zeros(output_shape, dtype=output_dtype),
-        ]
-    ]
-    rt_mod(*data)
-    return data[-1].numpy()
-
-
-def _actual_results(
-    actual: tvm.ir.IRModule, inputs: list[np.ndarray], output_shape: tuple, output_dtype: str
-):
-    target = _target()
-    actual_rt_mod = tvm.compile(actual, target=target)
-    actual_data = [
-        tvm.runtime.tensor(x, device=tvm.cuda() if target.kind.name == "cuda" else tvm.cpu())
-        for x in [
-            *inputs,
-            np.zeros(output_shape, dtype=output_dtype),
-        ]
-    ]
-    actual_rt_mod(*actual_data)
-    return actual_data[-1].numpy()
-
-
-def _assert_allclose(mod: tvm.ir.IRModule, actual: tvm.ir.IRModule) -> None:
-    inputs, output_shape, output_dtype = _get_input_output_info(_get_single_prim_func(mod))
-    expected_output = _expected_results(mod, inputs, output_shape, output_dtype)
-    actual_output = _actual_results(actual, inputs, output_shape, output_dtype)
-    tvm.testing.assert_allclose(expected_output, actual_output, rtol=1e-3, atol=1e-3)
-
-
-# Fused_Variance_Cast1 not added due to https://github.com/apache/tvm/issues/14791
-@pytest.mark.parametrize("mod", [Softmax, MatMul, Fuse_Mean_Cast1, Module])
-@pytest.mark.parametrize("benchmark", [False, True])
-def test_funcs(mod: tvm.ir.IRModule, benchmark: bool) -> None:
-    valid_count = 10 if benchmark else 1
-    with _target(), tvm.transform.PassContext(opt_level=3):
-        actual = FewShotTuning(valid_count=valid_count)(mod)
-    assert _get_single_prim_func(actual).attrs["tir.is_scheduled"], "Schedule is not applied."
-    _assert_allclose(mod, actual)
-
-
-if __name__ == "__main__":
-    tvm.testing.main()


Reply via email to