This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 395e91ff54 [MetaSchedule] Extract workload embedding (#11975)
395e91ff54 is described below
commit 395e91ff54543864a90240d18c8efd8c277c758b
Author: Kathryn (Jinqi) Chen <[email protected]>
AuthorDate: Thu Jun 30 19:36:13 2022 -0700
[MetaSchedule] Extract workload embedding (#11975)
This PR enables extracting an embedding of the workload in a tuning
context, which further strengthens the feature extraction process. The
workload embedding is derived from the name of each block in the IR module.
If `extract_workload` is enabled, the 8-entry workload embedding is appended
to every extracted feature vector, so the vectors will have length
164 + 8 = 172.
---
include/tvm/meta_schedule/feature_extractor.h | 4 +-
.../feature_extractor/per_store_feature.py | 6 ++
.../feature_extractor/per_store_feature.cc | 79 +++++++++++++++++++++-
3 files changed, 85 insertions(+), 4 deletions(-)
diff --git a/include/tvm/meta_schedule/feature_extractor.h
b/include/tvm/meta_schedule/feature_extractor.h
index 02e9f26b2a..4165e5efe0 100644
--- a/include/tvm/meta_schedule/feature_extractor.h
+++ b/include/tvm/meta_schedule/feature_extractor.h
@@ -101,11 +101,13 @@ class FeatureExtractor : public runtime::ObjectRef {
* \param arith_intensity_curve_num_samples The number of samples used in
the arithmetic intensity
* curve.
* \param cache_line_bytes The number of bytes in a cache line.
+ * \param extract_workload Whether to extract features in the workload in
tuning context or not.
* \return The feature extractor created.
*/
TVM_DLL static FeatureExtractor PerStoreFeature(int buffers_per_store = 5,
int
arith_intensity_curve_num_samples = 10,
- int cache_line_bytes = 64);
+ int cache_line_bytes = 64,
+ bool extract_workload =
false);
/*!
* \brief Create a feature extractor with customized methods on the
python-side.
* \param f_extract_from The packed function of `ExtractFrom`.
diff --git a/python/tvm/meta_schedule/feature_extractor/per_store_feature.py
b/python/tvm/meta_schedule/feature_extractor/per_store_feature.py
index 306934d5f9..078a4af0e3 100644
--- a/python/tvm/meta_schedule/feature_extractor/per_store_feature.py
+++ b/python/tvm/meta_schedule/feature_extractor/per_store_feature.py
@@ -35,6 +35,8 @@ class PerStoreFeature(FeatureExtractor):
The number of samples used in the arithmetic intensity curve.
cache_line_bytes : int
The number of bytes in a cache line.
+ extract_workload : bool
+ Whether to extract features in the workload in tuning context or not.
"""
buffers_per_store: int
@@ -43,6 +45,8 @@ class PerStoreFeature(FeatureExtractor):
"""The number of samples used in the arithmetic intensity curve."""
cache_line_bytes: int
"""The number of bytes in a cache line."""
+ extract_workload: bool
+ """Whether to extract features in the workload in tuning context or not."""
feature_vector_length: int
"""Length of the feature vector."""
@@ -51,10 +55,12 @@ class PerStoreFeature(FeatureExtractor):
buffers_per_store: int = 5,
arith_intensity_curve_num_samples: int = 10,
cache_line_bytes: int = 64,
+ extract_workload: bool = False,
):
self.__init_handle_by_constructor__(
_ffi_api.FeatureExtractorPerStoreFeature, # type: ignore #
pylint: disable=no-member
buffers_per_store,
arith_intensity_curve_num_samples,
cache_line_bytes,
+ extract_workload,
)
diff --git a/src/meta_schedule/feature_extractor/per_store_feature.cc
b/src/meta_schedule/feature_extractor/per_store_feature.cc
index 93f6767b11..c29e5d61f0 100644
--- a/src/meta_schedule/feature_extractor/per_store_feature.cc
+++ b/src/meta_schedule/feature_extractor/per_store_feature.cc
@@ -21,6 +21,7 @@
#include <cmath>
#include <memory>
#include <numeric>
+#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
@@ -1169,6 +1170,64 @@ struct Feature {
} // namespace group5
+namespace group6 {
+
+/*! \brief The auxiliary feature extractor for workloads */
+class WorkloadEmbeddingExtractor : private StmtVisitor {
+ public:
+ static std::vector<double> Extract(const IRModule& mod) {
+ WorkloadEmbeddingExtractor self;
+ for (const auto& kv : mod->functions) {
+ if (const PrimFuncNode* func = kv.second.as<PrimFuncNode>()) {
+ self(func->body);
+ }
+ }
+ return self.embedding;
+ }
+
+ private:
+ void VisitStmt_(const BlockNode* block) final {
+ StmtVisitor::VisitStmt_(block);
+ std::string name = block->name_hint;
+ std::for_each(name.begin(), name.end(), [](char& c) { c = ::tolower(c); });
+ if (name.find("softmax") != std::string::npos) {
+ embedding[0] = 1.0;
+ } else if ((name.find("max") != std::string::npos) || (name.find("min") !=
std::string::npos)) {
+ embedding[1] = 1.0;
+ } else if (name.find("add") != std::string::npos) {
+ embedding[2] = 1.0;
+ } else if (name.find("batch_matmul") != std::string::npos) {
+ embedding[3] = 1.0;
+ } else if (name.find("matmul") != std::string::npos) {
+ embedding[4] = 1.0;
+ } else if (name.find("depthwiseconv2d") != std::string::npos) {
+ embedding[5] = 1.0;
+ } else if (name.find("conv2d_winograd") != std::string::npos) {
+ embedding[6] = 1.0;
+ } else if (name.find("conv2d") != std::string::npos) {
+ embedding[7] = 1.0;
+ }
+ }
+
+ std::vector<double> embedding = std::vector<double>(8, 0.0);
+};
+
+/*! \brief Group 6 feature */
+struct Feature {
+ explicit Feature(const IRModule& mod) {
+ this->feature = WorkloadEmbeddingExtractor::Extract(mod);
+ }
+
+ void Export(std::vector<double>* v) const {
+ v->insert(v->end(), std::begin(feature), std::end(feature));
+ }
+
+ std::vector<double> feature; // The workload embedding
+ static constexpr int64_t kCount = 8;
+};
+
+} // namespace group6
+
/*! \brief The feature extracted */
struct Feature {
const BufferNode* buffer = nullptr;
@@ -1178,6 +1237,7 @@ struct Feature {
std::unique_ptr<group3::Feature> group3 = nullptr;
std::unique_ptr<group4::Feature> group4 = nullptr;
std::unique_ptr<group5::Feature> group5 = nullptr;
+ std::shared_ptr<group6::Feature> group6 = nullptr;
bool operator<(const Feature& other) const { return buffer_order <
other.buffer_order; }
};
@@ -1283,6 +1343,7 @@ class PerStoreFeatureNode : public FeatureExtractorNode {
int buffers_per_store;
int arith_intensity_curve_num_samples;
int cache_line_bytes;
+ bool extract_workload;
int feature_vector_length;
void VisitAttrs(tvm::AttrVisitor* v) {
@@ -1308,7 +1369,6 @@ class PerStoreFeatureNode : public FeatureExtractorNode {
feature.group3->Export(&result);
feature.group4->Export(&result, feature.group5->outer_prod);
feature.group5->Export(&result);
- ICHECK_EQ(static_cast<int>(result.size()), feature_vector_length);
}
}
@@ -1317,10 +1377,19 @@ class PerStoreFeatureNode : public FeatureExtractorNode
{
bool is_gpu = tune_context->target.value()->kind->name == "cuda";
std::vector<runtime::NDArray> results;
results.resize(candidates.size());
- auto f = [this, is_gpu, &candidates, &results](int, int task_id) -> void {
+ std::unique_ptr<tir::group6::Feature> feature_group6 = nullptr;
+ if (extract_workload) {
+ feature_group6 =
std::make_unique<tir::group6::Feature>(tune_context->mod.value());
+ }
+ auto f = [this, is_gpu, &feature_group6, &candidates, &results](int, int
task_id) -> void {
const auto& candidate = candidates[task_id];
std::vector<std::vector<double>> features;
ExtractSingle(DeepCopyIRModule(candidate->sch->mod()), is_gpu,
&features);
+ if (extract_workload) {
+ for (auto& feature : features) {
+ feature_group6->Export(&feature);
+ }
+ }
results[task_id] = tir::utils::AsNDArray(features);
};
support::parallel_for_dynamic(0, candidates.size(),
tune_context->num_threads, f);
@@ -1333,16 +1402,20 @@ class PerStoreFeatureNode : public FeatureExtractorNode
{
FeatureExtractor FeatureExtractor::PerStoreFeature(int buffers_per_store,
int
arith_intensity_curve_num_samples,
- int cache_line_bytes) {
+ int cache_line_bytes, bool
extract_workload) {
ObjectPtr<PerStoreFeatureNode> n = make_object<PerStoreFeatureNode>();
n->buffers_per_store = buffers_per_store;
n->arith_intensity_curve_num_samples = arith_intensity_curve_num_samples;
n->cache_line_bytes = cache_line_bytes;
+ n->extract_workload = extract_workload;
n->feature_vector_length = tir::group1::Feature::kCount +
//
tir::group2::Feature::SubFeature::kCount *
buffers_per_store + //
arith_intensity_curve_num_samples +
//
tir::group4::Feature::kCount +
//
tir::group5::Feature::kCount;
+ if (extract_workload) {
+ n->feature_vector_length += tir::group6::Feature::kCount;
+ }
return FeatureExtractor(n);
}