This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 395e91ff54 [MetaSchedule] Extract workload embedding (#11975)
395e91ff54 is described below

commit 395e91ff54543864a90240d18c8efd8c277c758b
Author: Kathryn (Jinqi) Chen <[email protected]>
AuthorDate: Thu Jun 30 19:36:13 2022 -0700

    [MetaSchedule] Extract workload embedding (#11975)
    
    This PR enables extracting workload embeddings from the IR module in a
    tuning context, which further strengthens the feature extraction process.
    Workload embeddings are derived from the names of the blocks in the IR
    module. If `extract_workload` is enabled, each extracted feature vector
    has length 164 + 8 = 172.
---
 include/tvm/meta_schedule/feature_extractor.h      |  4 +-
 .../feature_extractor/per_store_feature.py         |  6 ++
 .../feature_extractor/per_store_feature.cc         | 79 +++++++++++++++++++++-
 3 files changed, 85 insertions(+), 4 deletions(-)

diff --git a/include/tvm/meta_schedule/feature_extractor.h 
b/include/tvm/meta_schedule/feature_extractor.h
index 02e9f26b2a..4165e5efe0 100644
--- a/include/tvm/meta_schedule/feature_extractor.h
+++ b/include/tvm/meta_schedule/feature_extractor.h
@@ -101,11 +101,13 @@ class FeatureExtractor : public runtime::ObjectRef {
    * \param arith_intensity_curve_num_samples The number of samples used in 
the arithmetic intensity
    * curve.
    * \param cache_line_bytes The number of bytes in a cache line.
+   * \param extract_workload Whether to extract features in the workload in 
tuning context or not.
    * \return The feature extractor created.
    */
   TVM_DLL static FeatureExtractor PerStoreFeature(int buffers_per_store = 5,
                                                   int 
arith_intensity_curve_num_samples = 10,
-                                                  int cache_line_bytes = 64);
+                                                  int cache_line_bytes = 64,
+                                                  bool extract_workload = 
false);
   /*!
    * \brief Create a feature extractor with customized methods on the 
python-side.
    * \param f_extract_from The packed function of `ExtractFrom`.
diff --git a/python/tvm/meta_schedule/feature_extractor/per_store_feature.py 
b/python/tvm/meta_schedule/feature_extractor/per_store_feature.py
index 306934d5f9..078a4af0e3 100644
--- a/python/tvm/meta_schedule/feature_extractor/per_store_feature.py
+++ b/python/tvm/meta_schedule/feature_extractor/per_store_feature.py
@@ -35,6 +35,8 @@ class PerStoreFeature(FeatureExtractor):
         The number of samples used in the arithmetic intensity curve.
     cache_line_bytes : int
         The number of bytes in a cache line.
+    extract_workload : bool
+        Whether to extract features in the workload in tuning context or not.
     """
 
     buffers_per_store: int
@@ -43,6 +45,8 @@ class PerStoreFeature(FeatureExtractor):
     """The number of samples used in the arithmetic intensity curve."""
     cache_line_bytes: int
     """The number of bytes in a cache line."""
+    extract_workload: bool
+    """Whether to extract features in the workload in tuning context or not."""
     feature_vector_length: int
     """Length of the feature vector."""
 
@@ -51,10 +55,12 @@ class PerStoreFeature(FeatureExtractor):
         buffers_per_store: int = 5,
         arith_intensity_curve_num_samples: int = 10,
         cache_line_bytes: int = 64,
+        extract_workload: bool = False,
     ):
         self.__init_handle_by_constructor__(
             _ffi_api.FeatureExtractorPerStoreFeature,  # type: ignore # 
pylint: disable=no-member
             buffers_per_store,
             arith_intensity_curve_num_samples,
             cache_line_bytes,
+            extract_workload,
         )
diff --git a/src/meta_schedule/feature_extractor/per_store_feature.cc 
b/src/meta_schedule/feature_extractor/per_store_feature.cc
index 93f6767b11..c29e5d61f0 100644
--- a/src/meta_schedule/feature_extractor/per_store_feature.cc
+++ b/src/meta_schedule/feature_extractor/per_store_feature.cc
@@ -21,6 +21,7 @@
 #include <cmath>
 #include <memory>
 #include <numeric>
+#include <string>
 #include <unordered_map>
 #include <unordered_set>
 #include <vector>
@@ -1169,6 +1170,64 @@ struct Feature {
 
 }  // namespace group5
 
+namespace group6 {
+
+/*! \brief The auxiliary feature extractor for workloads: sets 8 category flags (0.0/1.0) by keyword-matching every block's lowercased name_hint */
+class WorkloadEmbeddingExtractor : private StmtVisitor {
+ public:
+  static std::vector<double> Extract(const IRModule& mod) {  // visits the body of every PrimFunc in mod; flags accumulate across all of them
+    WorkloadEmbeddingExtractor self;
+    for (const auto& kv : mod->functions) {
+      if (const PrimFuncNode* func = kv.second.as<PrimFuncNode>()) {  // non-PrimFunc entries are skipped
+        self(func->body);
+      }
+    }
+    return self.embedding;
+  }
+
+ private:
+  void VisitStmt_(const BlockNode* block) final {
+    StmtVisitor::VisitStmt_(block);  // recurse first so nested blocks are matched too
+    std::string name = block->name_hint;
+    std::for_each(name.begin(), name.end(), [](char& c) { c = ::tolower(c); });  // NOTE(review): ::tolower on a negative char is UB; consider casting to unsigned char
+    if (name.find("softmax") != std::string::npos) {  // first match wins (at most one flag per block), so "softmax" must precede "max"
+      embedding[0] = 1.0;
+    } else if ((name.find("max") != std::string::npos) || (name.find("min") != 
std::string::npos)) {
+      embedding[1] = 1.0;
+    } else if (name.find("add") != std::string::npos) {  // NOTE(review): plain substring test, so e.g. "padding" would also match
+      embedding[2] = 1.0;
+    } else if (name.find("batch_matmul") != std::string::npos) {  // must precede the more general "matmul" test
+      embedding[3] = 1.0;
+    } else if (name.find("matmul") != std::string::npos) {
+      embedding[4] = 1.0;
+    } else if (name.find("depthwiseconv2d") != std::string::npos) {
+      embedding[5] = 1.0;
+    } else if (name.find("conv2d_winograd") != std::string::npos) {  // must precede the more general "conv2d" test
+      embedding[6] = 1.0;
+    } else if (name.find("conv2d") != std::string::npos) {
+      embedding[7] = 1.0;
+    }
+  }
+
+  std::vector<double> embedding = std::vector<double>(8, 0.0);  // slots 0-7 as assigned above; never reset, so flags are OR-ed over all visited blocks
+};
+
+/*! \brief Group 6 feature: the workload embedding of a whole IRModule, computed once at construction */
+struct Feature {
+  explicit Feature(const IRModule& mod) {
+    this->feature = WorkloadEmbeddingExtractor::Extract(mod);
+  }
+
+  void Export(std::vector<double>* v) const {  // appends the embedding to the end of *v
+    v->insert(v->end(), std::begin(feature), std::end(feature));
+  }
+
+  std::vector<double> feature;  // The workload embedding (8 values, each 0.0 or 1.0)
+  static constexpr int64_t kCount = 8;  // must match the size of WorkloadEmbeddingExtractor's embedding vector
+};
+
+}  // namespace group6
+
 /*! \brief The feature extracted */
 struct Feature {
   const BufferNode* buffer = nullptr;
@@ -1178,6 +1237,7 @@ struct Feature {
   std::unique_ptr<group3::Feature> group3 = nullptr;
   std::unique_ptr<group4::Feature> group4 = nullptr;
   std::unique_ptr<group5::Feature> group5 = nullptr;
+  std::shared_ptr<group6::Feature> group6 = nullptr;
 
   bool operator<(const Feature& other) const { return buffer_order < 
other.buffer_order; }
 };
@@ -1283,6 +1343,7 @@ class PerStoreFeatureNode : public FeatureExtractorNode {
   int buffers_per_store;
   int arith_intensity_curve_num_samples;
   int cache_line_bytes;
+  bool extract_workload;
   int feature_vector_length;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
@@ -1308,7 +1369,6 @@ class PerStoreFeatureNode : public FeatureExtractorNode {
       feature.group3->Export(&result);
       feature.group4->Export(&result, feature.group5->outer_prod);
       feature.group5->Export(&result);
-      ICHECK_EQ(static_cast<int>(result.size()), feature_vector_length);
     }
   }
 
@@ -1317,10 +1377,19 @@ class PerStoreFeatureNode : public FeatureExtractorNode 
{
     bool is_gpu = tune_context->target.value()->kind->name == "cuda";
     std::vector<runtime::NDArray> results;
     results.resize(candidates.size());
-    auto f = [this, is_gpu, &candidates, &results](int, int task_id) -> void {
+    std::unique_ptr<tir::group6::Feature> feature_group6 = nullptr;
+    if (extract_workload) {
+      feature_group6 = 
std::make_unique<tir::group6::Feature>(tune_context->mod.value());
+    }
+    auto f = [this, is_gpu, &feature_group6, &candidates, &results](int, int 
task_id) -> void {
       const auto& candidate = candidates[task_id];
       std::vector<std::vector<double>> features;
       ExtractSingle(DeepCopyIRModule(candidate->sch->mod()), is_gpu, 
&features);
+      if (extract_workload) {
+        for (auto& feature : features) {
+          feature_group6->Export(&feature);
+        }
+      }
       results[task_id] = tir::utils::AsNDArray(features);
     };
     support::parallel_for_dynamic(0, candidates.size(), 
tune_context->num_threads, f);
@@ -1333,16 +1402,20 @@ class PerStoreFeatureNode : public FeatureExtractorNode 
{
 
 FeatureExtractor FeatureExtractor::PerStoreFeature(int buffers_per_store,
                                                    int 
arith_intensity_curve_num_samples,
-                                                   int cache_line_bytes) {
+                                                   int cache_line_bytes, bool 
extract_workload) {
   ObjectPtr<PerStoreFeatureNode> n = make_object<PerStoreFeatureNode>();
   n->buffers_per_store = buffers_per_store;
   n->arith_intensity_curve_num_samples = arith_intensity_curve_num_samples;
   n->cache_line_bytes = cache_line_bytes;
+  n->extract_workload = extract_workload;
   n->feature_vector_length = tir::group1::Feature::kCount +                    
              //
                              tir::group2::Feature::SubFeature::kCount * 
buffers_per_store +  //
                              arith_intensity_curve_num_samples +               
              //
                              tir::group4::Feature::kCount +                    
              //
                              tir::group5::Feature::kCount;
+  if (extract_workload) {
+    n->feature_vector_length += tir::group6::Feature::kCount;
+  }
   return FeatureExtractor(n);
 }
 

Reply via email to