masahi commented on a change in pull request #10455:
URL: https://github.com/apache/tvm/pull/10455#discussion_r822188815
##########
File path: src/auto_scheduler/feature.cc
##########
@@ -659,38 +697,86 @@ class PerStoreFeatureExtractor : public StmtExprVisitor {
}
}
+ void VisitExpr_(const BufferLoadNode* node) final {
+ // Store buffer shape/dtype. It may already be stored.
+ buffer_shapes[node->buffer->data] = node->buffer->shape;
+ buffer_dtypes[node->buffer->data] = node->buffer->dtype;
+ StmtExprVisitor::VisitExpr_(node);
+ }
+
void VisitStmt_(const BufferStoreNode* node) final {
+ // Store buffer shape/dtype. It may already be stored.
+ buffer_shapes[node->buffer->data] = node->buffer->shape;
+ buffer_dtypes[node->buffer->data] = node->buffer->dtype;
+
MathOpCounter math_op_counter;
math_op_counter(node->value);
std::vector<float> mem_bytes_list;
std::vector<float> compute_ops_list;
double cur_compute_ops;
// Group 1: Computation related features
- ExtractComputationFeature(node, math_op_counter);
+ ExtractComputationFeature(node->buffer->data, node->indices,
math_op_counter);
// Group 2: Buffer access related features (per buffer)
- ExtractBufferAccessFeature(node, math_op_counter, &cur_compute_ops,
&compute_ops_list,
- &mem_bytes_list);
+ ExtractBufferAccessFeature(node->buffer->data, node->indices, node->value,
math_op_counter,
+ &cur_compute_ops, &compute_ops_list,
&mem_bytes_list);
// Group 3: Arithmetic intensity related features
- ExtractArithmeticIntensityFeature(node, cur_compute_ops, compute_ops_list,
mem_bytes_list);
+ ExtractArithmeticIntensityFeature(node->buffer->data, cur_compute_ops,
compute_ops_list,
+ mem_bytes_list);
// Group 4: Allocation related features
- ExtractOuterScopeFeature(node);
+ ExtractOuterScopeFeature(node->buffer->data);
}
void VisitStmt_(const BufferRealizeNode* node) final {
+ // Store buffer shape/dtype. It may already be stored.
+ buffer_shapes[node->buffer->data] = node->buffer->shape;
+ buffer_dtypes[node->buffer->data] = node->buffer->dtype;
StmtExprVisitor::VisitStmt_(node);
// Group 5: Outer scope related features
ExtractAllocationFeature(node);
}
+ void VisitStmt_(const AllocateNode* node) final {
+ buffer_dtypes[node->buffer_var] = node->dtype;
+ buffer_shapes[node->buffer_var] = node->extents;
+ StmtExprVisitor::VisitStmt_(node);
+
+ // Group 5: Outer scope related features
+ ExtractAllocationFeature(node);
+ }
+
+ void VisitStmt_(const StoreNode* node) final {
+ MathOpCounter math_op_counter;
Review comment:
Do we still hit this function? Note that Load/Store nodes are deprecated
after Eric's PR. cc @Lunderberg
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]