masahi commented on code in PR #13710:
URL: https://github.com/apache/tvm/pull/13710#discussion_r1063081290
##########
src/tir/ir/data_type_rewriter.cc:
##########
@@ -397,6 +417,9 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const BufferStoreNode* op) {
Buffer new_buffer = GetRemappedBuffer(op->buffer);
auto value = this->VisitExpr(op->value);
+ if (new_buffer->dtype != value->dtype && value->dtype.lanes() == 1) {
+ value = cast(new_buffer->dtype, value);
Review Comment:
`new_buffer->dtype != value->dtype` is also true if `new_buffer->dtype ==
float32` and `value->dtype == float32x16`, which happens when the indices are a
vector. The `value->dtype.lanes() == 1` condition is added to avoid inserting a
cast in such cases.
For argmax, this cast can result in a redundant chain of casts of the form
`int32(select(cond, int64(tru), int64(false)))`, but hopefully LLVM or other
backends can clean them up.
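
To make the guard concrete, here is a minimal standalone sketch of the condition. It does not use TVM: `DType` and `NeedsStoreCast` are hypothetical stand-ins for `DataType` and the `if` in the diff, so treat it as an illustration of the logic rather than the actual implementation.

```cpp
#include <cassert>

// Hypothetical stand-in for a TIR data type: element kind, bit width, lanes.
// This is NOT tvm::DataType; it only models the fields the guard depends on.
struct DType {
  enum class Kind { kInt, kFloat };
  Kind kind;
  int bits;
  int lanes;  // 1 for scalars, >1 for vectors such as float32x16

  bool operator==(const DType& other) const {
    return kind == other.kind && bits == other.bits && lanes == other.lanes;
  }
  bool operator!=(const DType& other) const { return !(*this == other); }
};

// Mirrors the condition in the diff: insert a cast only when the dtypes
// differ AND the stored value is a scalar. A vector value stored through
// vector indices has a different lane count than the buffer's element
// dtype, so a plain inequality check alone would wrongly request a cast.
inline bool NeedsStoreCast(const DType& buffer_dtype, const DType& value_dtype) {
  return buffer_dtype != value_dtype && value_dtype.lanes == 1;
}
```

With this model, a `float32` buffer receiving a `float32x16` value is left alone, while an `int32` buffer receiving a scalar `int64` value (the argmax case) gets a cast.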
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.