zxybazh commented on a change in pull request #10366:
URL: https://github.com/apache/tvm/pull/10366#discussion_r826203835
##########
File path: src/meta_schedule/utils.h
##########
@@ -50,6 +51,201 @@
#include "../tir/schedule/utils.h"
namespace tvm {
+namespace tir {
+
+inline double CountFlop(const IRModule& mod) {
+ struct TResult {
+ using TTable = std::unordered_map<int32_t, double>;
+
+ TResult() = default;
+
+ explicit TResult(const tvm::DataType& dtype) { Add(dtype); }
+
+ void Add(const tvm::DataType& dtype) { data_[DataType2Int(dtype)] += 1; }
+
+ TResult operator+=(const TResult& rhs) {
+ for (const auto& kv : rhs.data_) {
+ data_[kv.first] += kv.second;
+ }
+ return *this;
+ }
+
+ TResult operator*=(int64_t rhs) {
+ for (auto& kv : data_) {
+ kv.second *= rhs;
+ }
+ return *this;
+ }
+
+ TResult MaxWith(const TResult& rhs) {
+ for (const auto& kv : rhs.data_) {
+ double& v = data_[kv.first];
+ if (v < kv.second) {
+ v = kv.second;
+ }
+ }
+ return *this;
+ }
+
+ struct DType {
+ uint8_t code : 8;
+ uint8_t bits : 8;
+ uint16_t lanes : 16;
+ };
+ static_assert(sizeof(DType) == 4, "Incorrect size of DType");
+
+ static String Int2Str(int32_t dtype) {
+ union {
+ DType dst;
+ int32_t src;
+ } converter;
+ converter.src = dtype;
+ static std::string type_code_tab[] = {"int", "uint", "float", "handle",
"bfloat"};
+ std::ostringstream os;
+ os << type_code_tab[converter.dst.code];
+ os << static_cast<int>(converter.dst.bits);
+ if (converter.dst.lanes != 1) {
+ os << "x" << static_cast<int>(converter.dst.lanes);
+ }
+ return os.str();
+ }
+
+ static int32_t DataType2Int(const tvm::DataType& dtype) {
+ union {
+ DType src;
+ int32_t dst;
+ } converter;
+ converter.src.code = dtype.code();
+ converter.src.bits = dtype.bits();
+ converter.src.lanes = dtype.lanes();
+ return converter.dst;
+ }
+
+ TTable data_;
+ };
+
+ class FlopCounter : public ExprFunctor<TResult(const PrimExpr& n)>,
+ public StmtFunctor<TResult(const Stmt& n)> {
+ public:
+ ~FlopCounter() {}
+
+ TResult VisitExpr(const PrimExpr& expr) override { return
ExprFunctor::VisitExpr(expr); }
+ TResult VisitStmt(const Stmt& stmt) override { return
StmtFunctor::VisitStmt(stmt); }
+
+ TResult VisitStmt_(const IfThenElseNode* branch) override {
+ TResult cond = VisitExpr(branch->condition);
+ cond +=
VisitStmt(branch->then_case).MaxWith(VisitStmt(branch->else_case));
+ return cond;
+ }
+
+ TResult VisitStmt_(const BufferStoreNode* store) override {
+ TResult result = VisitExpr(store->value);
+ for (const PrimExpr& e : store->indices) {
+ result += VisitExpr(e);
+ }
+ return result;
+ }
+
+ TResult VisitStmt_(const SeqStmtNode* seq) override {
+ TResult result;
+ for (const Stmt& stmt : seq->seq) {
+ result += VisitStmt(stmt);
+ }
+ return result;
+ }
+
+ TResult VisitStmt_(const BlockRealizeNode* block) override {
+ return VisitStmt(block->block->body);
+ }
+
+ TResult VisitStmt_(const BlockNode* block) override {
+ TResult result;
+ if (block->init.defined()) {
+ result += VisitStmt(block->init.value());
+ }
+ result += VisitStmt(block->body);
+ return result;
+ }
+
+ TResult VisitStmt_(const ForNode* loop) override {
+ TResult result = VisitStmt(loop->body);
+ const auto* int_imm = loop->extent.as<IntImmNode>();
+ ICHECK(int_imm) << "TypeError: Expect the extent of a loop to be IntImm,
but gets: "
+ << loop->extent->GetTypeKey();
+ result *= int_imm->value;
+ return result;
+ }
+
+#define TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(Node) \
+ TResult VisitExpr_(const Node* op) final { \
+ TResult result(op->dtype); \
+ result += VisitExpr(op->a); \
+ result += VisitExpr(op->b); \
+ return result; \
+ }
+ TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(AddNode);
+ TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(SubNode);
+ TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(MulNode);
+ TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(DivNode);
+ TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(ModNode);
+ TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(FloorDivNode);
+ TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(FloorModNode);
+ TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(MinNode);
+ TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(MaxNode);
+ TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(EQNode);
+ TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(NENode);
+ TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(LTNode);
+ TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(LENode);
+ TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(GTNode);
+ TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(GENode);
+ TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(AndNode);
+ TVM_META_SCHEDULE_FLOP_COUNTER_BINARY(OrNode);
+#undef TVM_META_SCHEDULE_FLOP_COUNTER_BINARY
+ TResult VisitExpr_(const CastNode* op) override { return
VisitExpr(op->value); }
+ TResult VisitExpr_(const VarNode* op) override { return TResult(); }
+ TResult VisitExpr_(const SizeVarNode* op) override { return TResult(); }
+ TResult VisitExpr_(const BufferLoadNode* op) override { return TResult(); }
+ TResult VisitExpr_(const IntImmNode* op) override { return TResult(); }
+ TResult VisitExpr_(const FloatImmNode* op) override { return TResult(); }
+ TResult VisitExpr_(const NotNode* op) override {
+ TResult result(op->dtype);
+ result += VisitExpr(op->a);
+ return result;
+ }
+ TResult VisitExpr_(const SelectNode* op) override {
+ TResult cond = VisitExpr(op->condition);
+ cond += VisitExpr(op->true_value).MaxWith(VisitExpr(op->false_value));
+ return cond;
+ }
+ TResult VisitExpr_(const CallNode* op) override {
+ TResult ret;
+ for (const auto& x : op->args) {
+ ret += VisitExpr(x);
+ }
+ return ret;
+ }
+ };
+ FlopCounter counter;
+ TResult result;
+ for (const auto& kv : mod->functions) {
+ const BaseFunc& base_func = kv.second;
+ if (const auto* prim_func = base_func.as<PrimFuncNode>()) {
+ result += counter.VisitStmt(prim_func->body);
+ }
+ }
+ double cnt = 0.0;
+ int i32 = TResult::DataType2Int(tvm::DataType::Int(32));
+ int i64 = TResult::DataType2Int(tvm::DataType::Int(64));
+ int u1 = TResult::DataType2Int(tvm::DataType::UInt(1));
+ for (const auto& kv : result.data_) {
+ if (kv.first != i32 && kv.first != i64 && kv.first != u1) {
Review comment:
CC @junrushao1994
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]