wpan11nv commented on a change in pull request #5428:
URL: https://github.com/apache/incubator-tvm/pull/5428#discussion_r416251080



##########
File path: src/target/source/codegen_cuda.cc
##########
@@ -274,9 +274,21 @@ void CodeGenCUDA::PrintVecElemLoad(
   static const char access[] = {'x', 'y', 'z', 'w'};
   CHECK(i >= 0 && i < (t.is_float16() ? 8 : 4));
   if ((t.is_int()) && t.bits() == 8) {
-    os << "((char)(" << vec << " >> " << i * 8 << "))";
+    if (t.lanes() == 1) {
+      os << vec;
+    } else if (t.lanes() == 2) {
+      os << vec << "." << access[i % 2];
+    } else {
+      os << "((char)(" << vec << " >> " << i * 8 << "))";
+    }
   } else if ((t.is_uint()) && t.bits() == 8) {
-    os << "((unsigned char)(" << vec << " >> " << i * 8 << "))";
+    if (t.lanes() == 1) {

Review comment:
       I meant that you have two "if (t.lanes() == 1) return vec" checks, one in 
the int8 branch and one in the uint8 branch. You could hoist that logic out of 
both branches, since it holds for all inputs. 
   
   The other part is a nice-to-have: if you can naturally support char2/3/4/8, 
that will help simplify the codegen logic. 
   
   E.g., char2/3/4 can be stored as uint32_t. The only difference is that only 
the lower k bytes are meaningful. 
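
   A rough sketch of what both suggestions could look like together, outside of TVM. Everything here is hypothetical and for illustration only: `ElemType` stands in for the few `DataType` queries used in the diff, and `VecElemLoad8` and `UnpackByte` are made-up names, not functions in codegen_cuda.cc. It assumes that vectors of 2 to 4 8-bit lanes are packed into one 32-bit word with element i in byte i.

```cpp
#include <cassert>
#include <cstdint>
#include <sstream>
#include <string>

// Hypothetical stand-in for the DataType queries used in the diff; not TVM's class.
struct ElemType {
  bool is_signed;  // true for int8, false for uint8
  int lanes;       // number of 8-bit elements in the vector
};

// Builds the expression that reads element i of an 8-bit vector value `vec`.
std::string VecElemLoad8(const std::string& vec, ElemType t, int i) {
  assert(i >= 0 && i < t.lanes);
  // Hoisted check: a one-lane "vector" is the scalar itself,
  // for signed and unsigned element types alike.
  if (t.lanes == 1) return vec;
  // Assumption: lanes 2..4 share one 32-bit word, so every lane count
  // uses the same shift-and-cast form and only the lower bytes matter.
  const char* cast = t.is_signed ? "char" : "unsigned char";
  std::ostringstream os;
  os << "((" << cast << ")(" << vec << " >> " << i * 8 << "))";
  return os.str();
}

// Host-side illustration of the packing assumption: extract byte i of a word.
inline int8_t UnpackByte(uint32_t packed, int i) {
  return static_cast<int8_t>(static_cast<uint8_t>(packed >> (i * 8)));
}
```

   With this layout, a char2 value only differs from a char4 value in how many of the low bytes carry data, so the printer would not need a separate `.x`/`.y` path per lane count.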
   
    




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

