yubing updated this revision to Diff 327362.
yubing edited the summary of this revision.
yubing added a comment.

Address the review comments above.


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D93594/new/

https://reviews.llvm.org/D93594

Files:
  llvm/include/llvm/CodeGen/Passes.h
  llvm/lib/Target/X86/CMakeLists.txt
  llvm/lib/Target/X86/X86.h
  llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp
  llvm/lib/Target/X86/X86LowerAMXType.cpp
  llvm/lib/Target/X86/X86TargetMachine.cpp
  llvm/test/CodeGen/X86/AMX/amx-low-intrinsics-no-amx-bitcast.ll
  llvm/test/CodeGen/X86/AMX/amx-low-intrinsics.ll
  llvm/test/CodeGen/X86/AMX/amx-type.ll
  llvm/test/CodeGen/X86/O0-pipeline.ll
  llvm/test/CodeGen/X86/opt-pipeline.ll
  llvm/tools/opt/opt.cpp

Index: llvm/tools/opt/opt.cpp
===================================================================
--- llvm/tools/opt/opt.cpp
+++ llvm/tools/opt/opt.cpp
@@ -497,7 +497,7 @@
       "expand-reductions",    "indirectbr-expand",
       "generic-to-nvvm",      "expandmemcmp",
       "loop-reduce",          "lower-amx-type",
-      "polyhedral-info"};
+      "lower-amx-intrinsics", "polyhedral-info"};
   for (const auto &P : PassNamePrefix)
     if (Pass.startswith(P))
       return true;
Index: llvm/test/CodeGen/X86/opt-pipeline.ll
===================================================================
--- llvm/test/CodeGen/X86/opt-pipeline.ll
+++ llvm/test/CodeGen/X86/opt-pipeline.ll
@@ -24,11 +24,12 @@
 ; CHECK-NEXT:     Pre-ISel Intrinsic Lowering
 ; CHECK-NEXT:     FunctionPass Manager
 ; CHECK-NEXT:       Expand Atomic instructions
+; CHECK-NEXT:       Dominator Tree Construction
+; CHECK-NEXT:       Natural Loop Information
+; CHECK-NEXT:       Lower AMX intrinsics
 ; CHECK-NEXT:       Lower AMX type for load/store
 ; CHECK-NEXT:       Module Verifier
-; CHECK-NEXT:       Dominator Tree Construction
 ; CHECK-NEXT:       Basic Alias Analysis (stateless AA impl)
-; CHECK-NEXT:       Natural Loop Information
 ; CHECK-NEXT:       Canonicalize natural loops
 ; CHECK-NEXT:       Scalar Evolution Analysis
 ; CHECK-NEXT:       Loop Pass Manager
Index: llvm/test/CodeGen/X86/O0-pipeline.ll
===================================================================
--- llvm/test/CodeGen/X86/O0-pipeline.ll
+++ llvm/test/CodeGen/X86/O0-pipeline.ll
@@ -18,6 +18,9 @@
 ; CHECK-NEXT:     Pre-ISel Intrinsic Lowering
 ; CHECK-NEXT:     FunctionPass Manager
 ; CHECK-NEXT:       Expand Atomic instructions
+; CHECK-NEXT:       Dominator Tree Construction
+; CHECK-NEXT:       Natural Loop Information
+; CHECK-NEXT:       Lower AMX intrinsics
 ; CHECK-NEXT:       Lower AMX type for load/store
 ; CHECK-NEXT:       Module Verifier
 ; CHECK-NEXT:       Lower Garbage Collection Instructions
Index: llvm/test/CodeGen/X86/AMX/amx-type.ll
===================================================================
--- llvm/test/CodeGen/X86/AMX/amx-type.ll
+++ llvm/test/CodeGen/X86/AMX/amx-type.ll
@@ -1,5 +1,5 @@
 ; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
-; RUN: opt -lower-amx-type %s -S | FileCheck %s
+; RUN: opt --codegen-opt-level=2 -mtriple=x86_64 -lower-amx-type %s -S | FileCheck %s
 
 %struct.__tile_str = type { i16, i16, <256 x i32> }
 
Index: llvm/test/CodeGen/X86/AMX/amx-low-intrinsics.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/X86/AMX/amx-low-intrinsics.ll
@@ -0,0 +1,237 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -mtriple=x86_64 -lower-amx-intrinsics %s -S | FileCheck %s
+
+define dso_local void @test_amx_load_non_O0(i16 signext %row, i16 signext %col, i8 *%ptr, i64 %stride, <256 x i32>* %vptr) {
+; CHECK-LABEL: @test_amx_load_non_O0(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = lshr i16 [[COL:%.*]], 2
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i64 [[STRIDE:%.*]], 2
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_ROWS_HEADER:%.*]]
+; CHECK:       tileload.scalarize.rows.header:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_ROWS_IV:%.*]] = phi i16 [ 0, [[ENTRY:%.*]] ], [ [[TILELOAD_SCALARIZE_ROWS_STEP:%.*]], [[TILELOAD_SCALARIZE_ROWS_LATCH:%.*]] ]
+; CHECK-NEXT:    [[VEC_PHI_ROW:%.*]] = phi <256 x i32> [ zeroinitializer, [[ENTRY]] ], [ [[TMP11:%.*]], [[TILELOAD_SCALARIZE_ROWS_LATCH]] ]
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_ROWS_BODY:%.*]]
+; CHECK:       tileload.scalarize.rows.body:
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_COLS_HEADER:%.*]]
+; CHECK:       tileload.scalarize.cols.header:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_COLS_IV:%.*]] = phi i16 [ 0, [[TILELOAD_SCALARIZE_ROWS_BODY]] ], [ [[TILELOAD_SCALARIZE_COLS_STEP:%.*]], [[TILELOAD_SCALARIZE_COLS_LATCH:%.*]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <256 x i32> [ [[VEC_PHI_ROW]], [[TILELOAD_SCALARIZE_ROWS_BODY]] ], [ [[TMP11]], [[TILELOAD_SCALARIZE_COLS_LATCH]] ]
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_COLS_BODY:%.*]]
+; CHECK:       tileload.scalarize.cols.body:
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i16 [[TILELOAD_SCALARIZE_ROWS_IV]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = zext i16 [[TILELOAD_SCALARIZE_COLS_IV]] to i64
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i64 [[TMP2]], [[TMP1]]
+; CHECK-NEXT:    [[TMP5:%.*]] = add i64 [[TMP4]], [[TMP3]]
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast i8* [[PTR:%.*]] to i32*
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i32, i32* [[TMP6]], i64 [[TMP5]]
+; CHECK-NEXT:    [[TMP8:%.*]] = mul i16 [[TILELOAD_SCALARIZE_ROWS_IV]], 16
+; CHECK-NEXT:    [[TMP9:%.*]] = add i16 [[TMP8]], [[TILELOAD_SCALARIZE_COLS_IV]]
+; CHECK-NEXT:    [[TMP10:%.*]] = load i32, i32* [[TMP7]], align 4
+; CHECK-NEXT:    [[TMP11]] = insertelement <256 x i32> [[VEC_PHI]], i32 [[TMP10]], i16 [[TMP9]]
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_COLS_LATCH]]
+; CHECK:       tileload.scalarize.cols.latch:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_COLS_STEP]] = add i16 [[TILELOAD_SCALARIZE_COLS_IV]], 1
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_COLS_COND:%.*]] = icmp ne i16 [[TILELOAD_SCALARIZE_COLS_STEP]], [[TMP0]]
+; CHECK-NEXT:    br i1 [[TILELOAD_SCALARIZE_COLS_COND]], label [[TILELOAD_SCALARIZE_COLS_HEADER]], label [[TILELOAD_SCALARIZE_ROWS_LATCH]]
+; CHECK:       tileload.scalarize.rows.latch:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_ROWS_STEP]] = add i16 [[TILELOAD_SCALARIZE_ROWS_IV]], 1
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_ROWS_COND:%.*]] = icmp ne i16 [[TILELOAD_SCALARIZE_ROWS_STEP]], [[ROW:%.*]]
+; CHECK-NEXT:    br i1 [[TILELOAD_SCALARIZE_ROWS_COND]], label [[TILELOAD_SCALARIZE_ROWS_HEADER]], label [[CONTINUE:%.*]]
+; CHECK:       continue:
+; CHECK-NEXT:    [[TMP12:%.*]] = bitcast <256 x i32> [[TMP11]] to x86_amx
+; CHECK-NEXT:    store <256 x i32> [[TMP11]], <256 x i32>* [[VPTR:%.*]], align 64
+; CHECK-NEXT:    ret void
+;
+entry:
+  %amx = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, i8* %ptr, i64 %stride)
+  %vec = bitcast x86_amx %amx to <256 x i32>
+  store <256 x i32> %vec, <256 x i32>* %vptr, align 64
+  ret void
+}
+
+define dso_local void @test_amx_load(i16 signext %row, i16 signext %col, i8 *%ptr, i64 %stride, <256 x i32>* %vptr) #0 {
+; CHECK-LABEL: @test_amx_load(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = lshr i16 [[COL:%.*]], 2
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i64 [[STRIDE:%.*]], 2
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_ROWS_HEADER:%.*]]
+; CHECK:       tileload.scalarize.rows.header:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_ROWS_IV:%.*]] = phi i16 [ 0, [[ENTRY:%.*]] ], [ [[TILELOAD_SCALARIZE_ROWS_STEP:%.*]], [[TILELOAD_SCALARIZE_ROWS_LATCH:%.*]] ]
+; CHECK-NEXT:    [[VEC_PHI_ROW:%.*]] = phi <256 x i32> [ zeroinitializer, [[ENTRY]] ], [ [[TMP11:%.*]], [[TILELOAD_SCALARIZE_ROWS_LATCH]] ]
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_ROWS_BODY:%.*]]
+; CHECK:       tileload.scalarize.rows.body:
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_COLS_HEADER:%.*]]
+; CHECK:       tileload.scalarize.cols.header:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_COLS_IV:%.*]] = phi i16 [ 0, [[TILELOAD_SCALARIZE_ROWS_BODY]] ], [ [[TILELOAD_SCALARIZE_COLS_STEP:%.*]], [[TILELOAD_SCALARIZE_COLS_LATCH:%.*]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <256 x i32> [ [[VEC_PHI_ROW]], [[TILELOAD_SCALARIZE_ROWS_BODY]] ], [ [[TMP11]], [[TILELOAD_SCALARIZE_COLS_LATCH]] ]
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_COLS_BODY:%.*]]
+; CHECK:       tileload.scalarize.cols.body:
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i16 [[TILELOAD_SCALARIZE_ROWS_IV]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = zext i16 [[TILELOAD_SCALARIZE_COLS_IV]] to i64
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i64 [[TMP2]], [[TMP1]]
+; CHECK-NEXT:    [[TMP5:%.*]] = add i64 [[TMP4]], [[TMP3]]
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast i8* [[PTR:%.*]] to i32*
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i32, i32* [[TMP6]], i64 [[TMP5]]
+; CHECK-NEXT:    [[TMP8:%.*]] = mul i16 [[TILELOAD_SCALARIZE_ROWS_IV]], 16
+; CHECK-NEXT:    [[TMP9:%.*]] = add i16 [[TMP8]], [[TILELOAD_SCALARIZE_COLS_IV]]
+; CHECK-NEXT:    [[TMP10:%.*]] = load i32, i32* [[TMP7]], align 4
+; CHECK-NEXT:    [[TMP11]] = insertelement <256 x i32> [[VEC_PHI]], i32 [[TMP10]], i16 [[TMP9]]
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_COLS_LATCH]]
+; CHECK:       tileload.scalarize.cols.latch:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_COLS_STEP]] = add i16 [[TILELOAD_SCALARIZE_COLS_IV]], 1
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_COLS_COND:%.*]] = icmp ne i16 [[TILELOAD_SCALARIZE_COLS_STEP]], [[TMP0]]
+; CHECK-NEXT:    br i1 [[TILELOAD_SCALARIZE_COLS_COND]], label [[TILELOAD_SCALARIZE_COLS_HEADER]], label [[TILELOAD_SCALARIZE_ROWS_LATCH]]
+; CHECK:       tileload.scalarize.rows.latch:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_ROWS_STEP]] = add i16 [[TILELOAD_SCALARIZE_ROWS_IV]], 1
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_ROWS_COND:%.*]] = icmp ne i16 [[TILELOAD_SCALARIZE_ROWS_STEP]], [[ROW:%.*]]
+; CHECK-NEXT:    br i1 [[TILELOAD_SCALARIZE_ROWS_COND]], label [[TILELOAD_SCALARIZE_ROWS_HEADER]], label [[CONTINUE:%.*]]
+; CHECK:       continue:
+; CHECK-NEXT:    [[TMP12:%.*]] = bitcast <256 x i32> [[TMP11]] to x86_amx
+; CHECK-NEXT:    store <256 x i32> [[TMP11]], <256 x i32>* [[VPTR:%.*]], align 64
+; CHECK-NEXT:    ret void
+;
+entry:
+  %amx = call x86_amx @llvm.x86.tileloadd64.internal(i16 %row, i16 %col, i8* %ptr, i64 %stride)
+  %vec = bitcast x86_amx %amx to <256 x i32>
+  store <256 x i32> %vec, <256 x i32>* %vptr, align 64
+  ret void
+}
+
+define dso_local void @test_amx_dp(i16 signext %row, i16 signext %col, i16 signext %k, <256 x i32> %c, <256 x i32> %a, <256 x i32> %b, <256 x i32>* %vptr) #0 {
+; CHECK-LABEL: @test_amx_dp(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[A_AMX:%.*]] = bitcast <256 x i32> [[A:%.*]] to x86_amx
+; CHECK-NEXT:    [[B_AMX:%.*]] = bitcast <256 x i32> [[B:%.*]] to x86_amx
+; CHECK-NEXT:    [[C_AMX:%.*]] = bitcast <256 x i32> [[C:%.*]] to x86_amx
+; CHECK-NEXT:    [[TMP0:%.*]] = lshr i16 [[COL:%.*]], 2
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i16 [[K:%.*]], 2
+; CHECK-NEXT:    br label [[TILEDPBSSD_SCALARIZE_ROWS_HEADER:%.*]]
+; CHECK:       tiledpbssd.scalarize.rows.header:
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_ROWS_IV:%.*]] = phi i16 [ 0, [[ENTRY:%.*]] ], [ [[TILEDPBSSD_SCALARIZE_ROWS_STEP:%.*]], [[TILEDPBSSD_SCALARIZE_ROWS_LATCH:%.*]] ]
+; CHECK-NEXT:    [[VEC_C_PHI_ROW:%.*]] = phi <256 x i32> [ [[C]], [[ENTRY]] ], [ [[TMP18:%.*]], [[TILEDPBSSD_SCALARIZE_ROWS_LATCH]] ]
+; CHECK-NEXT:    [[VEC_D_PHI_ROW:%.*]] = phi <256 x i32> [ zeroinitializer, [[ENTRY]] ], [ [[TMP20:%.*]], [[TILEDPBSSD_SCALARIZE_ROWS_LATCH]] ]
+; CHECK-NEXT:    br label [[TILEDPBSSD_SCALARIZE_ROWS_BODY:%.*]]
+; CHECK:       tiledpbssd.scalarize.rows.body:
+; CHECK-NEXT:    br label [[TILEDPBSSD_SCALARIZE_COLS_HEADER:%.*]]
+; CHECK:       tiledpbssd.scalarize.cols.header:
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_COLS_IV:%.*]] = phi i16 [ 0, [[TILEDPBSSD_SCALARIZE_ROWS_BODY]] ], [ [[TILEDPBSSD_SCALARIZE_COLS_STEP:%.*]], [[TILEDPBSSD_SCALARIZE_COLS_LATCH:%.*]] ]
+; CHECK-NEXT:    [[VEC_C_PHI_COL:%.*]] = phi <256 x i32> [ [[VEC_C_PHI_ROW]], [[TILEDPBSSD_SCALARIZE_ROWS_BODY]] ], [ [[TMP18]], [[TILEDPBSSD_SCALARIZE_COLS_LATCH]] ]
+; CHECK-NEXT:    [[VEC_D_PHI_COL:%.*]] = phi <256 x i32> [ [[VEC_D_PHI_ROW]], [[TILEDPBSSD_SCALARIZE_ROWS_BODY]] ], [ [[TMP20]], [[TILEDPBSSD_SCALARIZE_COLS_LATCH]] ]
+; CHECK-NEXT:    [[TMP2:%.*]] = mul i16 [[TILEDPBSSD_SCALARIZE_ROWS_IV]], 16
+; CHECK-NEXT:    [[TMP3:%.*]] = add i16 [[TMP2]], [[TILEDPBSSD_SCALARIZE_COLS_IV]]
+; CHECK-NEXT:    br label [[TILEDPBSSD_SCALARIZE_COLS_BODY:%.*]]
+; CHECK:       tiledpbssd.scalarize.cols.body:
+; CHECK-NEXT:    br label [[TILEDPBSSD_SCALARIZE_INNER_HEADER:%.*]]
+; CHECK:       tiledpbssd.scalarize.inner.header:
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_INNER_IV:%.*]] = phi i16 [ 0, [[TILEDPBSSD_SCALARIZE_COLS_BODY]] ], [ [[TILEDPBSSD_SCALARIZE_INNER_STEP:%.*]], [[TILEDPBSSD_SCALARIZE_INNER_LATCH:%.*]] ]
+; CHECK-NEXT:    [[VEC_C_INNER_PHI:%.*]] = phi <256 x i32> [ [[VEC_C_PHI_COL]], [[TILEDPBSSD_SCALARIZE_COLS_BODY]] ], [ [[TMP18]], [[TILEDPBSSD_SCALARIZE_INNER_LATCH]] ]
+; CHECK-NEXT:    br label [[TILEDPBSSD_SCALARIZE_INNER_BODY:%.*]]
+; CHECK:       tiledpbssd.scalarize.inner.body:
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i16 [[TILEDPBSSD_SCALARIZE_ROWS_IV]], 16
+; CHECK-NEXT:    [[TMP5:%.*]] = add i16 [[TMP4]], [[TILEDPBSSD_SCALARIZE_INNER_IV]]
+; CHECK-NEXT:    [[TMP6:%.*]] = mul i16 [[TILEDPBSSD_SCALARIZE_INNER_IV]], 16
+; CHECK-NEXT:    [[TMP7:%.*]] = add i16 [[TMP6]], [[TILEDPBSSD_SCALARIZE_COLS_IV]]
+; CHECK-NEXT:    [[TMP8:%.*]] = extractelement <256 x i32> [[VEC_C_INNER_PHI]], i16 [[TMP3]]
+; CHECK-NEXT:    [[TMP9:%.*]] = extractelement <256 x i32> [[A]], i16 [[TMP5]]
+; CHECK-NEXT:    [[TMP10:%.*]] = bitcast i32 [[TMP9]] to <4 x i8>
+; CHECK-NEXT:    [[TMP11:%.*]] = extractelement <256 x i32> [[B]], i16 [[TMP7]]
+; CHECK-NEXT:    [[TMP12:%.*]] = bitcast i32 [[TMP11]] to <4 x i8>
+; CHECK-NEXT:    [[TMP13:%.*]] = sext <4 x i8> [[TMP12]] to <4 x i32>
+; CHECK-NEXT:    [[TMP14:%.*]] = sext <4 x i8> [[TMP10]] to <4 x i32>
+; CHECK-NEXT:    [[TMP15:%.*]] = mul <4 x i32> [[TMP14]], [[TMP13]]
+; CHECK-NEXT:    [[TMP16:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP15]])
+; CHECK-NEXT:    [[TMP17:%.*]] = add i32 [[TMP8]], [[TMP16]]
+; CHECK-NEXT:    [[TMP18]] = insertelement <256 x i32> [[VEC_C_INNER_PHI]], i32 [[TMP17]], i16 [[TMP3]]
+; CHECK-NEXT:    br label [[TILEDPBSSD_SCALARIZE_INNER_LATCH]]
+; CHECK:       tiledpbssd.scalarize.inner.latch:
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_INNER_STEP]] = add i16 [[TILEDPBSSD_SCALARIZE_INNER_IV]], 1
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_INNER_COND:%.*]] = icmp ne i16 [[TILEDPBSSD_SCALARIZE_INNER_STEP]], [[TMP1]]
+; CHECK-NEXT:    br i1 [[TILEDPBSSD_SCALARIZE_INNER_COND]], label [[TILEDPBSSD_SCALARIZE_INNER_HEADER]], label [[TILEDPBSSD_SCALARIZE_COLS_LATCH]]
+; CHECK:       tiledpbssd.scalarize.cols.latch:
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_COLS_STEP]] = add i16 [[TILEDPBSSD_SCALARIZE_COLS_IV]], 1
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_COLS_COND:%.*]] = icmp ne i16 [[TILEDPBSSD_SCALARIZE_COLS_STEP]], [[TMP0]]
+; CHECK-NEXT:    [[TMP19:%.*]] = extractelement <256 x i32> [[TMP18]], i16 [[TMP3]]
+; CHECK-NEXT:    [[TMP20]] = insertelement <256 x i32> [[VEC_D_PHI_COL]], i32 [[TMP19]], i16 [[TMP3]]
+; CHECK-NEXT:    br i1 [[TILEDPBSSD_SCALARIZE_COLS_COND]], label [[TILEDPBSSD_SCALARIZE_COLS_HEADER]], label [[TILEDPBSSD_SCALARIZE_ROWS_LATCH]]
+; CHECK:       tiledpbssd.scalarize.rows.latch:
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_ROWS_STEP]] = add i16 [[TILEDPBSSD_SCALARIZE_ROWS_IV]], 1
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_ROWS_COND:%.*]] = icmp ne i16 [[TILEDPBSSD_SCALARIZE_ROWS_STEP]], [[ROW:%.*]]
+; CHECK-NEXT:    br i1 [[TILEDPBSSD_SCALARIZE_ROWS_COND]], label [[TILEDPBSSD_SCALARIZE_ROWS_HEADER]], label [[CONTINUE:%.*]]
+; CHECK:       continue:
+; CHECK-NEXT:    [[TMP21:%.*]] = bitcast <256 x i32> [[TMP20]] to x86_amx
+; CHECK-NEXT:    store <256 x i32> [[TMP20]], <256 x i32>* [[VPTR:%.*]], align 64
+; CHECK-NEXT:    ret void
+;
+entry:
+  %a.amx = bitcast <256 x i32> %a to x86_amx
+  %b.amx = bitcast <256 x i32> %b to x86_amx
+  %c.amx = bitcast <256 x i32> %c to x86_amx
+  %acc = call x86_amx @llvm.x86.tdpbssd.internal(i16 %row, i16 %col, i16 %k, x86_amx %c.amx, x86_amx %a.amx, x86_amx %b.amx)
+  %vec = bitcast x86_amx %acc to <256 x i32>
+  store <256 x i32> %vec, <256 x i32>* %vptr, align 64
+  ret void
+}
+
+define dso_local void @test_amx_store(i16 signext %row, i16 signext %col, i8 *%ptr, i64 %stride, <256 x i32>* %vptr, <256 x i32> %vec) #0 {
+; CHECK-LABEL: @test_amx_store(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[AMX:%.*]] = bitcast <256 x i32> [[VEC:%.*]] to x86_amx
+; CHECK-NEXT:    [[TMP0:%.*]] = lshr i16 [[COL:%.*]], 2
+; CHECK-NEXT:    [[TMP1:%.*]] = lshr i64 [[STRIDE:%.*]], 2
+; CHECK-NEXT:    br label [[TILESTORE_SCALARIZE_ROWS_HEADER:%.*]]
+; CHECK:       tilestore.scalarize.rows.header:
+; CHECK-NEXT:    [[TILESTORE_SCALARIZE_ROWS_IV:%.*]] = phi i16 [ 0, [[ENTRY:%.*]] ], [ [[TILESTORE_SCALARIZE_ROWS_STEP:%.*]], [[TILESTORE_SCALARIZE_ROWS_LATCH:%.*]] ]
+; CHECK-NEXT:    br label [[TILESTORE_SCALARIZE_ROWS_BODY:%.*]]
+; CHECK:       tilestore.scalarize.rows.body:
+; CHECK-NEXT:    br label [[TILESTORE_SCALARIZE_COLS_HEADER:%.*]]
+; CHECK:       tilestore.scalarize.cols.header:
+; CHECK-NEXT:    [[TILESTORE_SCALARIZE_COLS_IV:%.*]] = phi i16 [ 0, [[TILESTORE_SCALARIZE_ROWS_BODY]] ], [ [[TILESTORE_SCALARIZE_COLS_STEP:%.*]], [[TILESTORE_SCALARIZE_COLS_LATCH:%.*]] ]
+; CHECK-NEXT:    br label [[TILESTORE_SCALARIZE_COLS_BODY:%.*]]
+; CHECK:       tilestore.scalarize.cols.body:
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i16 [[TILESTORE_SCALARIZE_ROWS_IV]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = zext i16 [[TILESTORE_SCALARIZE_COLS_IV]] to i64
+; CHECK-NEXT:    [[TMP4:%.*]] = mul i64 [[TMP2]], [[TMP1]]
+; CHECK-NEXT:    [[TMP5:%.*]] = add i64 [[TMP4]], [[TMP3]]
+; CHECK-NEXT:    [[TMP6:%.*]] = bitcast i8* [[PTR:%.*]] to i32*
+; CHECK-NEXT:    [[TMP7:%.*]] = getelementptr i32, i32* [[TMP6]], i64 [[TMP5]]
+; CHECK-NEXT:    [[TMP8:%.*]] = mul i16 [[TILESTORE_SCALARIZE_ROWS_IV]], 16
+; CHECK-NEXT:    [[TMP9:%.*]] = add i16 [[TMP8]], [[TILESTORE_SCALARIZE_COLS_IV]]
+; CHECK-NEXT:    [[TMP10:%.*]] = extractelement <256 x i32> [[VEC]], i16 [[TMP9]]
+; CHECK-NEXT:    store i32 [[TMP10]], i32* [[TMP7]], align 4
+; CHECK-NEXT:    br label [[TILESTORE_SCALARIZE_COLS_LATCH]]
+; CHECK:       tilestore.scalarize.cols.latch:
+; CHECK-NEXT:    [[TILESTORE_SCALARIZE_COLS_STEP]] = add i16 [[TILESTORE_SCALARIZE_COLS_IV]], 1
+; CHECK-NEXT:    [[TILESTORE_SCALARIZE_COLS_COND:%.*]] = icmp ne i16 [[TILESTORE_SCALARIZE_COLS_STEP]], [[TMP0]]
+; CHECK-NEXT:    br i1 [[TILESTORE_SCALARIZE_COLS_COND]], label [[TILESTORE_SCALARIZE_COLS_HEADER]], label [[TILESTORE_SCALARIZE_ROWS_LATCH]]
+; CHECK:       tilestore.scalarize.rows.latch:
+; CHECK-NEXT:    [[TILESTORE_SCALARIZE_ROWS_STEP]] = add i16 [[TILESTORE_SCALARIZE_ROWS_IV]], 1
+; CHECK-NEXT:    [[TILESTORE_SCALARIZE_ROWS_COND:%.*]] = icmp ne i16 [[TILESTORE_SCALARIZE_ROWS_STEP]], [[ROW:%.*]]
+; CHECK-NEXT:    br i1 [[TILESTORE_SCALARIZE_ROWS_COND]], label [[TILESTORE_SCALARIZE_ROWS_HEADER]], label [[CONTINUE:%.*]]
+; CHECK:       continue:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %amx = bitcast <256 x i32> %vec to x86_amx
+  call void @llvm.x86.tilestored64.internal(i16 %row, i16 %col, i8* %ptr, i64 %stride, x86_amx %amx)
+  ret void
+}
+
+define dso_local void @test_amx_zero(i16 signext %row, i16 signext %col, <256 x i32>* %vptr) #0 {
+; CHECK-LABEL: @test_amx_zero(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    store <256 x i32> zeroinitializer, <256 x i32>* [[VPTR:%.*]], align 64
+; CHECK-NEXT:    ret void
+;
+entry:
+  %amx = call x86_amx @llvm.x86.tilezero.internal(i16 %row, i16 %col)
+  %vec = bitcast x86_amx %amx to <256 x i32>
+  store <256 x i32> %vec, <256 x i32>* %vptr, align 64
+  ret void
+}
+
+declare x86_amx @llvm.x86.tilezero.internal(i16, i16)
+declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
+declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
+declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx)
+
+attributes #0 = { noinline nounwind optnone }
Index: llvm/test/CodeGen/X86/AMX/amx-low-intrinsics-no-amx-bitcast.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/X86/AMX/amx-low-intrinsics-no-amx-bitcast.ll
@@ -0,0 +1,211 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -mtriple=x86_64 -lower-amx-intrinsics %s -S | FileCheck %s
+
+define dso_local void @test_no_bitcast(i32* %A_mem, i32* %B_mem, i32* %C_mem) local_unnamed_addr #0 {
+; CHECK-LABEL: @test_no_bitcast(
+; CHECK-NEXT:  entry:
+; CHECK-NEXT:    [[TMP0:%.*]] = bitcast i32* [[C_MEM:%.*]] to i8*
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_ROWS_HEADER:%.*]]
+; CHECK:       tileload.scalarize.rows.header:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_ROWS_IV:%.*]] = phi i16 [ 0, [[ENTRY:%.*]] ], [ [[TILELOAD_SCALARIZE_ROWS_STEP:%.*]], [[TILELOAD_SCALARIZE_ROWS_LATCH:%.*]] ]
+; CHECK-NEXT:    [[VEC_PHI_ROW:%.*]] = phi <256 x i32> [ zeroinitializer, [[ENTRY]] ], [ [[TMP10:%.*]], [[TILELOAD_SCALARIZE_ROWS_LATCH]] ]
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_ROWS_BODY:%.*]]
+; CHECK:       tileload.scalarize.rows.body:
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_COLS_HEADER:%.*]]
+; CHECK:       tileload.scalarize.cols.header:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_COLS_IV:%.*]] = phi i16 [ 0, [[TILELOAD_SCALARIZE_ROWS_BODY]] ], [ [[TILELOAD_SCALARIZE_COLS_STEP:%.*]], [[TILELOAD_SCALARIZE_COLS_LATCH:%.*]] ]
+; CHECK-NEXT:    [[VEC_PHI:%.*]] = phi <256 x i32> [ [[VEC_PHI_ROW]], [[TILELOAD_SCALARIZE_ROWS_BODY]] ], [ [[TMP10]], [[TILELOAD_SCALARIZE_COLS_LATCH]] ]
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_COLS_BODY:%.*]]
+; CHECK:       tileload.scalarize.cols.body:
+; CHECK-NEXT:    [[TMP1:%.*]] = zext i16 [[TILELOAD_SCALARIZE_ROWS_IV]] to i64
+; CHECK-NEXT:    [[TMP2:%.*]] = zext i16 [[TILELOAD_SCALARIZE_COLS_IV]] to i64
+; CHECK-NEXT:    [[TMP3:%.*]] = mul i64 [[TMP1]], 4
+; CHECK-NEXT:    [[TMP4:%.*]] = add i64 [[TMP3]], [[TMP2]]
+; CHECK-NEXT:    [[TMP5:%.*]] = bitcast i8* [[TMP0]] to i32*
+; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr i32, i32* [[TMP5]], i64 [[TMP4]]
+; CHECK-NEXT:    [[TMP7:%.*]] = mul i16 [[TILELOAD_SCALARIZE_ROWS_IV]], 16
+; CHECK-NEXT:    [[TMP8:%.*]] = add i16 [[TMP7]], [[TILELOAD_SCALARIZE_COLS_IV]]
+; CHECK-NEXT:    [[TMP9:%.*]] = load i32, i32* [[TMP6]], align 4
+; CHECK-NEXT:    [[TMP10]] = insertelement <256 x i32> [[VEC_PHI]], i32 [[TMP9]], i16 [[TMP8]]
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_COLS_LATCH]]
+; CHECK:       tileload.scalarize.cols.latch:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_COLS_STEP]] = add i16 [[TILELOAD_SCALARIZE_COLS_IV]], 1
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_COLS_COND:%.*]] = icmp ne i16 [[TILELOAD_SCALARIZE_COLS_STEP]], 4
+; CHECK-NEXT:    br i1 [[TILELOAD_SCALARIZE_COLS_COND]], label [[TILELOAD_SCALARIZE_COLS_HEADER]], label [[TILELOAD_SCALARIZE_ROWS_LATCH]]
+; CHECK:       tileload.scalarize.rows.latch:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_ROWS_STEP]] = add i16 [[TILELOAD_SCALARIZE_ROWS_IV]], 1
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_ROWS_COND:%.*]] = icmp ne i16 [[TILELOAD_SCALARIZE_ROWS_STEP]], 4
+; CHECK-NEXT:    br i1 [[TILELOAD_SCALARIZE_ROWS_COND]], label [[TILELOAD_SCALARIZE_ROWS_HEADER]], label [[CONTINUE:%.*]]
+; CHECK:       continue:
+; CHECK-NEXT:    [[TMP11:%.*]] = bitcast <256 x i32> [[TMP10]] to x86_amx
+; CHECK-NEXT:    [[TMP12:%.*]] = bitcast i32* [[A_MEM:%.*]] to i8*
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_ROWS_HEADER2:%.*]]
+; CHECK:       tileload.scalarize.rows.header2:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_ROWS_IV5:%.*]] = phi i16 [ 0, [[CONTINUE]] ], [ [[TILELOAD_SCALARIZE_ROWS_STEP6:%.*]], [[TILELOAD_SCALARIZE_ROWS_LATCH4:%.*]] ]
+; CHECK-NEXT:    [[VEC_PHI_ROW14:%.*]] = phi <256 x i32> [ zeroinitializer, [[CONTINUE]] ], [ [[TMP22:%.*]], [[TILELOAD_SCALARIZE_ROWS_LATCH4]] ]
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_ROWS_BODY3:%.*]]
+; CHECK:       tileload.scalarize.rows.body3:
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_COLS_HEADER8:%.*]]
+; CHECK:       tileload.scalarize.cols.header8:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_COLS_IV11:%.*]] = phi i16 [ 0, [[TILELOAD_SCALARIZE_ROWS_BODY3]] ], [ [[TILELOAD_SCALARIZE_COLS_STEP12:%.*]], [[TILELOAD_SCALARIZE_COLS_LATCH10:%.*]] ]
+; CHECK-NEXT:    [[VEC_PHI15:%.*]] = phi <256 x i32> [ [[VEC_PHI_ROW14]], [[TILELOAD_SCALARIZE_ROWS_BODY3]] ], [ [[TMP22]], [[TILELOAD_SCALARIZE_COLS_LATCH10]] ]
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_COLS_BODY9:%.*]]
+; CHECK:       tileload.scalarize.cols.body9:
+; CHECK-NEXT:    [[TMP13:%.*]] = zext i16 [[TILELOAD_SCALARIZE_ROWS_IV5]] to i64
+; CHECK-NEXT:    [[TMP14:%.*]] = zext i16 [[TILELOAD_SCALARIZE_COLS_IV11]] to i64
+; CHECK-NEXT:    [[TMP15:%.*]] = mul i64 [[TMP13]], 4
+; CHECK-NEXT:    [[TMP16:%.*]] = add i64 [[TMP15]], [[TMP14]]
+; CHECK-NEXT:    [[TMP17:%.*]] = bitcast i8* [[TMP12]] to i32*
+; CHECK-NEXT:    [[TMP18:%.*]] = getelementptr i32, i32* [[TMP17]], i64 [[TMP16]]
+; CHECK-NEXT:    [[TMP19:%.*]] = mul i16 [[TILELOAD_SCALARIZE_ROWS_IV5]], 16
+; CHECK-NEXT:    [[TMP20:%.*]] = add i16 [[TMP19]], [[TILELOAD_SCALARIZE_COLS_IV11]]
+; CHECK-NEXT:    [[TMP21:%.*]] = load i32, i32* [[TMP18]], align 4
+; CHECK-NEXT:    [[TMP22]] = insertelement <256 x i32> [[VEC_PHI15]], i32 [[TMP21]], i16 [[TMP20]]
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_COLS_LATCH10]]
+; CHECK:       tileload.scalarize.cols.latch10:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_COLS_STEP12]] = add i16 [[TILELOAD_SCALARIZE_COLS_IV11]], 1
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_COLS_COND13:%.*]] = icmp ne i16 [[TILELOAD_SCALARIZE_COLS_STEP12]], 4
+; CHECK-NEXT:    br i1 [[TILELOAD_SCALARIZE_COLS_COND13]], label [[TILELOAD_SCALARIZE_COLS_HEADER8]], label [[TILELOAD_SCALARIZE_ROWS_LATCH4]]
+; CHECK:       tileload.scalarize.rows.latch4:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_ROWS_STEP6]] = add i16 [[TILELOAD_SCALARIZE_ROWS_IV5]], 1
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_ROWS_COND7:%.*]] = icmp ne i16 [[TILELOAD_SCALARIZE_ROWS_STEP6]], 4
+; CHECK-NEXT:    br i1 [[TILELOAD_SCALARIZE_ROWS_COND7]], label [[TILELOAD_SCALARIZE_ROWS_HEADER2]], label [[CONTINUE1:%.*]]
+; CHECK:       continue1:
+; CHECK-NEXT:    [[TMP23:%.*]] = bitcast <256 x i32> [[TMP22]] to x86_amx
+; CHECK-NEXT:    [[TMP24:%.*]] = bitcast i32* [[B_MEM:%.*]] to i8*
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_ROWS_HEADER17:%.*]]
+; CHECK:       tileload.scalarize.rows.header17:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_ROWS_IV20:%.*]] = phi i16 [ 0, [[CONTINUE1]] ], [ [[TILELOAD_SCALARIZE_ROWS_STEP21:%.*]], [[TILELOAD_SCALARIZE_ROWS_LATCH19:%.*]] ]
+; CHECK-NEXT:    [[VEC_PHI_ROW29:%.*]] = phi <256 x i32> [ zeroinitializer, [[CONTINUE1]] ], [ [[TMP34:%.*]], [[TILELOAD_SCALARIZE_ROWS_LATCH19]] ]
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_ROWS_BODY18:%.*]]
+; CHECK:       tileload.scalarize.rows.body18:
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_COLS_HEADER23:%.*]]
+; CHECK:       tileload.scalarize.cols.header23:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_COLS_IV26:%.*]] = phi i16 [ 0, [[TILELOAD_SCALARIZE_ROWS_BODY18]] ], [ [[TILELOAD_SCALARIZE_COLS_STEP27:%.*]], [[TILELOAD_SCALARIZE_COLS_LATCH25:%.*]] ]
+; CHECK-NEXT:    [[VEC_PHI30:%.*]] = phi <256 x i32> [ [[VEC_PHI_ROW29]], [[TILELOAD_SCALARIZE_ROWS_BODY18]] ], [ [[TMP34]], [[TILELOAD_SCALARIZE_COLS_LATCH25]] ]
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_COLS_BODY24:%.*]]
+; CHECK:       tileload.scalarize.cols.body24:
+; CHECK-NEXT:    [[TMP25:%.*]] = zext i16 [[TILELOAD_SCALARIZE_ROWS_IV20]] to i64
+; CHECK-NEXT:    [[TMP26:%.*]] = zext i16 [[TILELOAD_SCALARIZE_COLS_IV26]] to i64
+; CHECK-NEXT:    [[TMP27:%.*]] = mul i64 [[TMP25]], 4
+; CHECK-NEXT:    [[TMP28:%.*]] = add i64 [[TMP27]], [[TMP26]]
+; CHECK-NEXT:    [[TMP29:%.*]] = bitcast i8* [[TMP24]] to i32*
+; CHECK-NEXT:    [[TMP30:%.*]] = getelementptr i32, i32* [[TMP29]], i64 [[TMP28]]
+; CHECK-NEXT:    [[TMP31:%.*]] = mul i16 [[TILELOAD_SCALARIZE_ROWS_IV20]], 16
+; CHECK-NEXT:    [[TMP32:%.*]] = add i16 [[TMP31]], [[TILELOAD_SCALARIZE_COLS_IV26]]
+; CHECK-NEXT:    [[TMP33:%.*]] = load i32, i32* [[TMP30]], align 4
+; CHECK-NEXT:    [[TMP34]] = insertelement <256 x i32> [[VEC_PHI30]], i32 [[TMP33]], i16 [[TMP32]]
+; CHECK-NEXT:    br label [[TILELOAD_SCALARIZE_COLS_LATCH25]]
+; CHECK:       tileload.scalarize.cols.latch25:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_COLS_STEP27]] = add i16 [[TILELOAD_SCALARIZE_COLS_IV26]], 1
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_COLS_COND28:%.*]] = icmp ne i16 [[TILELOAD_SCALARIZE_COLS_STEP27]], 4
+; CHECK-NEXT:    br i1 [[TILELOAD_SCALARIZE_COLS_COND28]], label [[TILELOAD_SCALARIZE_COLS_HEADER23]], label [[TILELOAD_SCALARIZE_ROWS_LATCH19]]
+; CHECK:       tileload.scalarize.rows.latch19:
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_ROWS_STEP21]] = add i16 [[TILELOAD_SCALARIZE_ROWS_IV20]], 1
+; CHECK-NEXT:    [[TILELOAD_SCALARIZE_ROWS_COND22:%.*]] = icmp ne i16 [[TILELOAD_SCALARIZE_ROWS_STEP21]], 4
+; CHECK-NEXT:    br i1 [[TILELOAD_SCALARIZE_ROWS_COND22]], label [[TILELOAD_SCALARIZE_ROWS_HEADER17]], label [[CONTINUE16:%.*]]
+; CHECK:       continue16:
+; CHECK-NEXT:    [[TMP35:%.*]] = bitcast <256 x i32> [[TMP34]] to x86_amx
+; CHECK-NEXT:    br label [[TILEDPBSSD_SCALARIZE_ROWS_HEADER:%.*]]
+; CHECK:       tiledpbssd.scalarize.rows.header:
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_ROWS_IV:%.*]] = phi i16 [ 0, [[CONTINUE16]] ], [ [[TILEDPBSSD_SCALARIZE_ROWS_STEP:%.*]], [[TILEDPBSSD_SCALARIZE_ROWS_LATCH:%.*]] ]
+; CHECK-NEXT:    [[VEC_C_PHI_ROW:%.*]] = phi <256 x i32> [ [[TMP10]], [[CONTINUE16]] ], [ [[TMP52:%.*]], [[TILEDPBSSD_SCALARIZE_ROWS_LATCH]] ]
+; CHECK-NEXT:    [[VEC_D_PHI_ROW:%.*]] = phi <256 x i32> [ zeroinitializer, [[CONTINUE16]] ], [ [[TMP54:%.*]], [[TILEDPBSSD_SCALARIZE_ROWS_LATCH]] ]
+; CHECK-NEXT:    br label [[TILEDPBSSD_SCALARIZE_ROWS_BODY:%.*]]
+; CHECK:       tiledpbssd.scalarize.rows.body:
+; CHECK-NEXT:    br label [[TILEDPBSSD_SCALARIZE_COLS_HEADER:%.*]]
+; CHECK:       tiledpbssd.scalarize.cols.header:
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_COLS_IV:%.*]] = phi i16 [ 0, [[TILEDPBSSD_SCALARIZE_ROWS_BODY]] ], [ [[TILEDPBSSD_SCALARIZE_COLS_STEP:%.*]], [[TILEDPBSSD_SCALARIZE_COLS_LATCH:%.*]] ]
+; CHECK-NEXT:    [[VEC_C_PHI_COL:%.*]] = phi <256 x i32> [ [[VEC_C_PHI_ROW]], [[TILEDPBSSD_SCALARIZE_ROWS_BODY]] ], [ [[TMP52]], [[TILEDPBSSD_SCALARIZE_COLS_LATCH]] ]
+; CHECK-NEXT:    [[VEC_D_PHI_COL:%.*]] = phi <256 x i32> [ [[VEC_D_PHI_ROW]], [[TILEDPBSSD_SCALARIZE_ROWS_BODY]] ], [ [[TMP54]], [[TILEDPBSSD_SCALARIZE_COLS_LATCH]] ]
+; CHECK-NEXT:    [[TMP36:%.*]] = mul i16 [[TILEDPBSSD_SCALARIZE_ROWS_IV]], 16
+; CHECK-NEXT:    [[TMP37:%.*]] = add i16 [[TMP36]], [[TILEDPBSSD_SCALARIZE_COLS_IV]]
+; CHECK-NEXT:    br label [[TILEDPBSSD_SCALARIZE_COLS_BODY:%.*]]
+; CHECK:       tiledpbssd.scalarize.cols.body:
+; CHECK-NEXT:    br label [[TILEDPBSSD_SCALARIZE_INNER_HEADER:%.*]]
+; CHECK:       tiledpbssd.scalarize.inner.header:
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_INNER_IV:%.*]] = phi i16 [ 0, [[TILEDPBSSD_SCALARIZE_COLS_BODY]] ], [ [[TILEDPBSSD_SCALARIZE_INNER_STEP:%.*]], [[TILEDPBSSD_SCALARIZE_INNER_LATCH:%.*]] ]
+; CHECK-NEXT:    [[VEC_C_INNER_PHI:%.*]] = phi <256 x i32> [ [[VEC_C_PHI_COL]], [[TILEDPBSSD_SCALARIZE_COLS_BODY]] ], [ [[TMP52]], [[TILEDPBSSD_SCALARIZE_INNER_LATCH]] ]
+; CHECK-NEXT:    br label [[TILEDPBSSD_SCALARIZE_INNER_BODY:%.*]]
+; CHECK:       tiledpbssd.scalarize.inner.body:
+; CHECK-NEXT:    [[TMP38:%.*]] = mul i16 [[TILEDPBSSD_SCALARIZE_ROWS_IV]], 16
+; CHECK-NEXT:    [[TMP39:%.*]] = add i16 [[TMP38]], [[TILEDPBSSD_SCALARIZE_INNER_IV]]
+; CHECK-NEXT:    [[TMP40:%.*]] = mul i16 [[TILEDPBSSD_SCALARIZE_INNER_IV]], 16
+; CHECK-NEXT:    [[TMP41:%.*]] = add i16 [[TMP40]], [[TILEDPBSSD_SCALARIZE_COLS_IV]]
+; CHECK-NEXT:    [[TMP42:%.*]] = extractelement <256 x i32> [[VEC_C_INNER_PHI]], i16 [[TMP37]]
+; CHECK-NEXT:    [[TMP43:%.*]] = extractelement <256 x i32> [[TMP22]], i16 [[TMP39]]
+; CHECK-NEXT:    [[TMP44:%.*]] = bitcast i32 [[TMP43]] to <4 x i8>
+; CHECK-NEXT:    [[TMP45:%.*]] = extractelement <256 x i32> [[TMP34]], i16 [[TMP41]]
+; CHECK-NEXT:    [[TMP46:%.*]] = bitcast i32 [[TMP45]] to <4 x i8>
+; CHECK-NEXT:    [[TMP47:%.*]] = sext <4 x i8> [[TMP46]] to <4 x i32>
+; CHECK-NEXT:    [[TMP48:%.*]] = sext <4 x i8> [[TMP44]] to <4 x i32>
+; CHECK-NEXT:    [[TMP49:%.*]] = mul <4 x i32> [[TMP48]], [[TMP47]]
+; CHECK-NEXT:    [[TMP50:%.*]] = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> [[TMP49]])
+; CHECK-NEXT:    [[TMP51:%.*]] = add i32 [[TMP42]], [[TMP50]]
+; CHECK-NEXT:    [[TMP52]] = insertelement <256 x i32> [[VEC_C_INNER_PHI]], i32 [[TMP51]], i16 [[TMP37]]
+; CHECK-NEXT:    br label [[TILEDPBSSD_SCALARIZE_INNER_LATCH]]
+; CHECK:       tiledpbssd.scalarize.inner.latch:
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_INNER_STEP]] = add i16 [[TILEDPBSSD_SCALARIZE_INNER_IV]], 1
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_INNER_COND:%.*]] = icmp ne i16 [[TILEDPBSSD_SCALARIZE_INNER_STEP]], 4
+; CHECK-NEXT:    br i1 [[TILEDPBSSD_SCALARIZE_INNER_COND]], label [[TILEDPBSSD_SCALARIZE_INNER_HEADER]], label [[TILEDPBSSD_SCALARIZE_COLS_LATCH]]
+; CHECK:       tiledpbssd.scalarize.cols.latch:
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_COLS_STEP]] = add i16 [[TILEDPBSSD_SCALARIZE_COLS_IV]], 1
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_COLS_COND:%.*]] = icmp ne i16 [[TILEDPBSSD_SCALARIZE_COLS_STEP]], 4
+; CHECK-NEXT:    [[TMP53:%.*]] = extractelement <256 x i32> [[TMP52]], i16 [[TMP37]]
+; CHECK-NEXT:    [[TMP54]] = insertelement <256 x i32> [[VEC_D_PHI_COL]], i32 [[TMP53]], i16 [[TMP37]]
+; CHECK-NEXT:    br i1 [[TILEDPBSSD_SCALARIZE_COLS_COND]], label [[TILEDPBSSD_SCALARIZE_COLS_HEADER]], label [[TILEDPBSSD_SCALARIZE_ROWS_LATCH]]
+; CHECK:       tiledpbssd.scalarize.rows.latch:
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_ROWS_STEP]] = add i16 [[TILEDPBSSD_SCALARIZE_ROWS_IV]], 1
+; CHECK-NEXT:    [[TILEDPBSSD_SCALARIZE_ROWS_COND:%.*]] = icmp ne i16 [[TILEDPBSSD_SCALARIZE_ROWS_STEP]], 4
+; CHECK-NEXT:    br i1 [[TILEDPBSSD_SCALARIZE_ROWS_COND]], label [[TILEDPBSSD_SCALARIZE_ROWS_HEADER]], label [[CONTINUE31:%.*]]
+; CHECK:       continue31:
+; CHECK-NEXT:    [[TMP55:%.*]] = bitcast <256 x i32> [[TMP54]] to x86_amx
+; CHECK-NEXT:    br label [[TILESTORE_SCALARIZE_ROWS_HEADER:%.*]]
+; CHECK:       tilestore.scalarize.rows.header:
+; CHECK-NEXT:    [[TILESTORE_SCALARIZE_ROWS_IV:%.*]] = phi i16 [ 0, [[CONTINUE31]] ], [ [[TILESTORE_SCALARIZE_ROWS_STEP:%.*]], [[TILESTORE_SCALARIZE_ROWS_LATCH:%.*]] ]
+; CHECK-NEXT:    br label [[TILESTORE_SCALARIZE_ROWS_BODY:%.*]]
+; CHECK:       tilestore.scalarize.rows.body:
+; CHECK-NEXT:    br label [[TILESTORE_SCALARIZE_COLS_HEADER:%.*]]
+; CHECK:       tilestore.scalarize.cols.header:
+; CHECK-NEXT:    [[TILESTORE_SCALARIZE_COLS_IV:%.*]] = phi i16 [ 0, [[TILESTORE_SCALARIZE_ROWS_BODY]] ], [ [[TILESTORE_SCALARIZE_COLS_STEP:%.*]], [[TILESTORE_SCALARIZE_COLS_LATCH:%.*]] ]
+; CHECK-NEXT:    br label [[TILESTORE_SCALARIZE_COLS_BODY:%.*]]
+; CHECK:       tilestore.scalarize.cols.body:
+; CHECK-NEXT:    [[TMP56:%.*]] = zext i16 [[TILESTORE_SCALARIZE_ROWS_IV]] to i64
+; CHECK-NEXT:    [[TMP57:%.*]] = zext i16 [[TILESTORE_SCALARIZE_COLS_IV]] to i64
+; CHECK-NEXT:    [[TMP58:%.*]] = mul i64 [[TMP56]], 4
+; CHECK-NEXT:    [[TMP59:%.*]] = add i64 [[TMP58]], [[TMP57]]
+; CHECK-NEXT:    [[TMP60:%.*]] = bitcast i8* [[TMP0]] to i32*
+; CHECK-NEXT:    [[TMP61:%.*]] = getelementptr i32, i32* [[TMP60]], i64 [[TMP59]]
+; CHECK-NEXT:    [[TMP62:%.*]] = mul i16 [[TILESTORE_SCALARIZE_ROWS_IV]], 16
+; CHECK-NEXT:    [[TMP63:%.*]] = add i16 [[TMP62]], [[TILESTORE_SCALARIZE_COLS_IV]]
+; CHECK-NEXT:    [[TMP64:%.*]] = extractelement <256 x i32> [[TMP54]], i16 [[TMP63]]
+; CHECK-NEXT:    store i32 [[TMP64]], i32* [[TMP61]], align 4
+; CHECK-NEXT:    br label [[TILESTORE_SCALARIZE_COLS_LATCH]]
+; CHECK:       tilestore.scalarize.cols.latch:
+; CHECK-NEXT:    [[TILESTORE_SCALARIZE_COLS_STEP]] = add i16 [[TILESTORE_SCALARIZE_COLS_IV]], 1
+; CHECK-NEXT:    [[TILESTORE_SCALARIZE_COLS_COND:%.*]] = icmp ne i16 [[TILESTORE_SCALARIZE_COLS_STEP]], 4
+; CHECK-NEXT:    br i1 [[TILESTORE_SCALARIZE_COLS_COND]], label [[TILESTORE_SCALARIZE_COLS_HEADER]], label [[TILESTORE_SCALARIZE_ROWS_LATCH]]
+; CHECK:       tilestore.scalarize.rows.latch:
+; CHECK-NEXT:    [[TILESTORE_SCALARIZE_ROWS_STEP]] = add i16 [[TILESTORE_SCALARIZE_ROWS_IV]], 1
+; CHECK-NEXT:    [[TILESTORE_SCALARIZE_ROWS_COND:%.*]] = icmp ne i16 [[TILESTORE_SCALARIZE_ROWS_STEP]], 4
+; CHECK-NEXT:    br i1 [[TILESTORE_SCALARIZE_ROWS_COND]], label [[TILESTORE_SCALARIZE_ROWS_HEADER]], label [[CONTINUE32:%.*]]
+; CHECK:       continue32:
+; CHECK-NEXT:    ret void
+;
+entry:
+  %0 = bitcast i32* %C_mem to i8*
+  %1 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 4, i16 16, i8* %0, i64 16)
+  %2 = bitcast i32* %A_mem to i8*
+  %3 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 4, i16 16, i8* %2, i64 16)
+  %4 = bitcast i32* %B_mem to i8*
+  %5 = tail call x86_amx @llvm.x86.tileloadd64.internal(i16 4, i16 16, i8* %4, i64 16)
+  %6 = tail call x86_amx @llvm.x86.tdpbssd.internal(i16 4, i16 16, i16 16, x86_amx %1, x86_amx %3, x86_amx %5)
+  tail call void @llvm.x86.tilestored64.internal(i16 4, i16 16, i8* %0, i64 16, x86_amx %6)
+  ret void
+}
+
+declare x86_amx @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64)
+declare x86_amx @llvm.x86.tdpbssd.internal(i16, i16, i16, x86_amx, x86_amx, x86_amx)
+declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, x86_amx)
+
+attributes #0 = { noinline nounwind optnone }
Index: llvm/lib/Target/X86/X86TargetMachine.cpp
===================================================================
--- llvm/lib/Target/X86/X86TargetMachine.cpp
+++ llvm/lib/Target/X86/X86TargetMachine.cpp
@@ -62,6 +62,7 @@
   RegisterTargetMachine<X86TargetMachine> Y(getTheX86_64Target());
 
   PassRegistry &PR = *PassRegistry::getPassRegistry();
+  initializeX86LowerAMXIntrinsicsLegacyPassPass(PR);
   initializeX86LowerAMXTypeLegacyPassPass(PR);
   initializeGlobalISel(PR);
   initializeWinEHStatePassPass(PR);
@@ -410,6 +411,10 @@
 
 void X86PassConfig::addIRPasses() {
   addPass(createAtomicExpandPass());
+
+  // We add both passes unconditionally; when they run, each pass skips itself
+  // based on the optimization level and the function's optnone attribute.
+  addPass(createX86LowerAMXIntrinsicsPass());
   addPass(createX86LowerAMXTypePass());
 
   TargetPassConfig::addIRPasses();
Index: llvm/lib/Target/X86/X86LowerAMXType.cpp
===================================================================
--- llvm/lib/Target/X86/X86LowerAMXType.cpp
+++ llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -23,6 +23,7 @@
 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
 #include "llvm/Analysis/TargetTransformInfo.h"
 #include "llvm/CodeGen/Passes.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
 #include "llvm/CodeGen/ValueTypes.h"
 #include "llvm/IR/DataLayout.h"
 #include "llvm/IR/Function.h"
@@ -33,6 +34,7 @@
 #include "llvm/IR/PatternMatch.h"
 #include "llvm/InitializePasses.h"
 #include "llvm/Pass.h"
+#include "llvm/Target/TargetMachine.h"
 
 using namespace llvm;
 using namespace PatternMatch;
@@ -327,6 +329,10 @@
   }
 
   bool runOnFunction(Function &F) override {
+    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
+    if (F.hasFnAttribute(Attribute::OptimizeNone) ||
+        TM->getOptLevel() == CodeGenOpt::None)
+      return false;
     X86LowerAMXType LAT(F);
     bool C = LAT.visit();
     return C;
@@ -334,6 +340,7 @@
 
   void getAnalysisUsage(AnalysisUsage &AU) const override {
     AU.setPreservesCFG();
+    AU.addRequired<TargetPassConfig>();
   }
 };
 
@@ -343,6 +350,7 @@
 char X86LowerAMXTypeLegacyPass::ID = 0;
 INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                       false)
+INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
 INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, PassName, false,
                     false)
 
Index: llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp
===================================================================
--- /dev/null
+++ llvm/lib/Target/X86/X86LowerAMXIntrinsics.cpp
@@ -0,0 +1,539 @@
+//===-- X86LowerAMXIntrinsics.cpp -X86 Scalarize AMX Intrinsics------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+/// \file Pass to transform AMX intrinsics into scalar operations.
+/// This pass only runs at -O0 or on optnone functions. At -O0, the def of a
+/// shape is near the AMX intrinsic that uses it, so we are not able to find a
+/// point which post-dominates all the shapes and dominates all AMX intrinsics.
+/// To decouple the dependency on the shape, we transform AMX intrinsics
+/// to scalar operations, so that compiling doesn't fail. In the long term, we
+/// should improve fast register allocation to allocate AMX registers.
+//===----------------------------------------------------------------------===//
+//
+#include "X86.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/ADT/PostOrderIterator.h"
+#include "llvm/Analysis/DomTreeUpdater.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/CodeGen/TargetPassConfig.h"
+#include "llvm/CodeGen/ValueTypes.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IntrinsicsX86.h"
+#include "llvm/IR/PatternMatch.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+#include "llvm/Target/TargetMachine.h"
+#include "llvm/Transforms/Utils/BasicBlockUtils.h"
+#include "llvm/Transforms/Utils/LoopUtils.h"
+
+using namespace llvm;
+using namespace PatternMatch;
+
+#define DEBUG_TYPE "lower-amx-intrinsics"
+
+static bool isV256I32Ty(Type *Ty) {
+  if (auto *FVT = dyn_cast<FixedVectorType>(Ty)) {
+    return FVT->getNumElements() == 256 &&
+           FVT->getElementType()->isIntegerTy(32);
+  }
+  return false;
+}
+
+static BasicBlock *createLoop(BasicBlock *Preheader, BasicBlock *Exit,
+                              Value *Bound, Value *Step, StringRef Name,
+                              IRBuilderBase &B, DomTreeUpdater &DTU, Loop *L,
+                              LoopInfo &LI) {
+  LLVMContext &Ctx = Preheader->getContext();
+  BasicBlock *Header =
+      BasicBlock::Create(Ctx, Name + ".header", Preheader->getParent(), Exit);
+  BasicBlock *Body =
+      BasicBlock::Create(Ctx, Name + ".body", Header->getParent(), Exit);
+  BasicBlock *Latch =
+      BasicBlock::Create(Ctx, Name + ".latch", Header->getParent(), Exit);
+
+  Type *I16Ty = Type::getInt16Ty(Ctx);
+  BranchInst::Create(Body, Header);
+  BranchInst::Create(Latch, Body);
+  PHINode *IV =
+      PHINode::Create(I16Ty, 2, Name + ".iv", Header->getTerminator());
+  IV->addIncoming(ConstantInt::get(I16Ty, 0), Preheader);
+
+  B.SetInsertPoint(Latch);
+  Value *Inc = B.CreateAdd(IV, Step, Name + ".step");
+  Value *Cond = B.CreateICmpNE(Inc, Bound, Name + ".cond");
+  BranchInst::Create(Header, Exit, Cond, Latch);
+  IV->addIncoming(Inc, Latch);
+
+  BranchInst *PreheaderBr = cast<BranchInst>(Preheader->getTerminator());
+  BasicBlock *Tmp = PreheaderBr->getSuccessor(0);
+  PreheaderBr->setSuccessor(0, Header);
+  DTU.applyUpdatesPermissive({
+      {DominatorTree::Delete, Preheader, Tmp},
+      {DominatorTree::Insert, Header, Body},
+      {DominatorTree::Insert, Body, Latch},
+      {DominatorTree::Insert, Latch, Header},
+      {DominatorTree::Insert, Latch, Exit},
+      {DominatorTree::Insert, Preheader, Header},
+  });
+
+  L->addBasicBlockToLoop(Header, LI);
+  L->addBasicBlockToLoop(Body, LI);
+  L->addBasicBlockToLoop(Latch, LI);
+  return Body;
+}
+
+template <bool isTileLoad>
+static Value *createTileLoadStoreLoops(BasicBlock *Start, BasicBlock *End,
+                                       IRBuilderBase &B, DomTreeUpdater &DTU,
+                                       LoopInfo &LI, Value *Row, Value *Col,
+                                       Value *Ptr, Value *Stride, Value *Tile) {
+  std::string IntrinName = isTileLoad ? "tileload" : "tilestore";
+  Loop *RowLoop = LI.AllocateLoop();
+  Loop *ColLoop = LI.AllocateLoop();
+  RowLoop->addChildLoop(ColLoop);
+  if (Loop *ParentL = LI.getLoopFor(Start))
+    ParentL->addChildLoop(RowLoop);
+  else
+    LI.addTopLevelLoop(RowLoop);
+
+  BasicBlock *RowBody =
+      createLoop(Start, End, Row, B.getInt16(1), IntrinName + ".scalarize.rows",
+                 B, DTU, RowLoop, LI);
+  BasicBlock *RowLatch = RowBody->getSingleSuccessor();
+
+  uint16_t ColStep = 1;
+  BasicBlock *ColBody =
+      createLoop(RowBody, RowLatch, Col, B.getInt16(ColStep),
+                 IntrinName + ".scalarize.cols", B, DTU, ColLoop, LI);
+
+  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
+  BasicBlock *ColumnLoopHeader = ColBody->getSinglePredecessor();
+  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
+  Value *CurrentRow = &*RowLoopHeader->begin();
+  Value *CurrentCol = &*ColumnLoopHeader->begin();
+  FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256);
+  Type *EltTy = V256I32Ty->getElementType();
+
+  // Common part for tileload and tilestore
+  // *.scalarize.cols.body:
+  // Calculate %idxmem and %idxvec
+  B.SetInsertPoint(ColBody->getTerminator());
+  Value *CurrentRowZExt = B.CreateZExt(CurrentRow, Stride->getType());
+  Value *CurrentColZExt = B.CreateZExt(CurrentCol, Stride->getType());
+  Value *Offset =
+      B.CreateAdd(B.CreateMul(CurrentRowZExt, Stride), CurrentColZExt);
+  unsigned AS = cast<PointerType>(Ptr->getType())->getAddressSpace();
+  Value *EltBasePtr = B.CreatePointerCast(Ptr, PointerType::get(EltTy, AS));
+  Value *EltPtr = B.CreateGEP(EltTy, EltBasePtr, Offset);
+  Value *Idx = B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
+  if (isTileLoad) {
+    // tileload.scalarize.rows.header:
+    // %vec.phi.row = phi <256 x i32> [ zeroinitializer, %entry ], [ %ResVec,
+    // %tileload.scalarize.rows.latch ]
+    B.SetInsertPoint(RowLoopHeader->getTerminator());
+    Value *VecZero = Constant::getNullValue(V256I32Ty);
+    PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.phi.row");
+    VecCPhiRowLoop->addIncoming(VecZero, Start);
+
+    // tileload.scalarize.cols.header:
+    // %vec.phi = phi <256 x i32> [ %vec.phi.row, %tileload.scalarize.rows.body
+    // ], [ %ResVec, %tileload.scalarize.cols.latch ]
+    B.SetInsertPoint(ColumnLoopHeader->getTerminator());
+    PHINode *VecPhi = B.CreatePHI(V256I32Ty, 2, "vec.phi");
+    VecPhi->addIncoming(VecCPhiRowLoop, RowBody);
+
+    // tileload.scalarize.cols.body:
+    // Calculate %idxmem and %idxvec
+    // %eltptr = getelementptr i32, i32* %base, i64 %idxmem
+    // %elt = load i32, i32* %eltptr
+    // %ResVec = insertelement <256 x i32> %vec.phi, i32 %elt, i16 %idxvec
+    B.SetInsertPoint(ColBody->getTerminator());
+    Value *Elt = B.CreateLoad(EltTy, EltPtr);
+    Value *ResVec = B.CreateInsertElement(VecPhi, Elt, Idx);
+    VecPhi->addIncoming(ResVec, ColLoopLatch);
+    VecCPhiRowLoop->addIncoming(ResVec, RowLatch);
+
+    return ResVec;
+  } else {
+    auto *BitCast = cast<BitCastInst>(Tile);
+    Value *Vec = BitCast->getOperand(0);
+    assert(isV256I32Ty(Vec->getType()) && "bitcast from non-v256i32 to x86amx");
+    // tilestore.scalarize.cols.body:
+    // %mul = mul i16 %row.iv, i16 16
+    // %idx = add i16 %mul, i16 %col.iv
+    // %elt = extractelement <256 x i32> %vec, i16 %idx
+    // store i32 %elt, i32* %ptr
+    B.SetInsertPoint(ColBody->getTerminator());
+    Value *Elt = B.CreateExtractElement(Vec, Idx);
+
+    B.CreateStore(Elt, EltPtr);
+    return nullptr;
+  }
+}
+
+static Value *createTileDPBSSDLoops(BasicBlock *Start, BasicBlock *End,
+                                    IRBuilderBase &B, DomTreeUpdater &DTU,
+                                    LoopInfo &LI, Value *Row, Value *Col,
+                                    Value *K, Value *Acc, Value *LHS,
+                                    Value *RHS) {
+  Loop *RowLoop = LI.AllocateLoop();
+  Loop *ColLoop = LI.AllocateLoop();
+  Loop *InnerLoop = LI.AllocateLoop();
+  ColLoop->addChildLoop(InnerLoop);
+  RowLoop->addChildLoop(ColLoop);
+  if (Loop *ParentL = LI.getLoopFor(Start))
+    ParentL->addChildLoop(RowLoop);
+  else
+    LI.addTopLevelLoop(RowLoop);
+
+  BasicBlock *RowBody =
+      createLoop(Start, End, Row, B.getInt16(1), "tiledpbssd.scalarize.rows", B,
+                 DTU, RowLoop, LI);
+  BasicBlock *RowLatch = RowBody->getSingleSuccessor();
+
+  BasicBlock *ColBody =
+      createLoop(RowBody, RowLatch, Col, B.getInt16(1),
+                 "tiledpbssd.scalarize.cols", B, DTU, ColLoop, LI);
+  BasicBlock *ColLoopLatch = ColBody->getSingleSuccessor();
+
+  B.SetInsertPoint(ColBody->getTerminator());
+  BasicBlock *InnerBody =
+      createLoop(ColBody, ColLoopLatch, K, B.getInt16(1),
+                 "tiledpbssd.scalarize.inner", B, DTU, InnerLoop, LI);
+
+  BasicBlock *ColumnLoopHeader = ColBody->getSinglePredecessor();
+  BasicBlock *RowLoopHeader = RowBody->getSinglePredecessor();
+  BasicBlock *InnerLoopHeader = InnerBody->getSinglePredecessor();
+  BasicBlock *InnerLoopLatch = InnerBody->getSingleSuccessor();
+  Value *CurrentRow = &*RowLoopHeader->begin();
+  Value *CurrentCol = &*ColumnLoopHeader->begin();
+  Value *CurrentInner = &*InnerLoopHeader->begin();
+
+  FixedVectorType *V256I32Ty = FixedVectorType::get(B.getInt32Ty(), 256);
+  auto *BitCastAcc = cast<BitCastInst>(Acc);
+  Value *VecC = BitCastAcc->getOperand(0);
+  assert(isV256I32Ty(VecC->getType()) && "bitcast from non-v256i32 to x86amx");
+  // TODO: otherwise, create a bitcast from x86_amx to v256i32:
+  // store the x86_amx value to memory, and reload it from memory
+  // as a vector. However, with -O0 this doesn't happen.
+  auto *BitCastLHS = cast<BitCastInst>(LHS);
+  Value *VecA = BitCastLHS->getOperand(0);
+  assert(isV256I32Ty(VecA->getType()) && "bitcast from non-v256i32 to x86amx");
+  auto *BitCastRHS = cast<BitCastInst>(RHS);
+  Value *VecB = BitCastRHS->getOperand(0);
+  assert(isV256I32Ty(VecB->getType()) && "bitcast from non-v256i32 to x86amx");
+
+  // tiledpbssd.scalarize.rows.header:
+  // %vec.c.phi.row = phi <256 x i32> [ %VecC, %continue ], [ %NewVecC,
+  // %tiledpbssd.scalarize.rows.latch ]
+
+  // %vec.d.phi.row = phi <256 x i32> [ zeroinitializer, %continue ], [
+  // %NewVecD, %tiledpbssd.scalarize.rows.latch ]
+  B.SetInsertPoint(RowLoopHeader->getTerminator());
+  PHINode *VecCPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.row");
+  VecCPhiRowLoop->addIncoming(VecC, Start);
+  Value *VecZero = Constant::getNullValue(V256I32Ty);
+  PHINode *VecDPhiRowLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.row");
+  VecDPhiRowLoop->addIncoming(VecZero, Start);
+
+  // tiledpbssd.scalarize.cols.header:
+  // %vec.c.phi.col = phi <256 x i32> [ %vec.c.phi.row,
+  // %tiledpbssd.scalarize.rows.body ], [ %NewVecC,
+  // %tiledpbssd.scalarize.cols.latch ]
+
+  // %vec.d.phi.col = phi <256 x i32> [
+  // %vec.d.phi.row, %tiledpbssd.scalarize.rows.body ], [ %NewVecD,
+  // %tiledpbssd.scalarize.cols.latch ]
+
+  // calculate idxc.
+  B.SetInsertPoint(ColumnLoopHeader->getTerminator());
+  PHINode *VecCPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.c.phi.col");
+  VecCPhiColLoop->addIncoming(VecCPhiRowLoop, RowBody);
+  PHINode *VecDPhiColLoop = B.CreatePHI(V256I32Ty, 2, "vec.d.phi.col");
+  VecDPhiColLoop->addIncoming(VecDPhiRowLoop, RowBody);
+  Value *IdxC =
+      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentCol);
+
+  // tiledpbssd.scalarize.inner.header:
+  // %vec.c.inner.phi = phi <256 x i32> [ %vec.c.phi.col,
+  // %tiledpbssd.scalarize.cols.body ], [ %NewVecC,
+  // %tiledpbssd.scalarize.inner.latch ]
+
+  B.SetInsertPoint(InnerLoopHeader->getTerminator());
+  PHINode *VecCPhi = B.CreatePHI(V256I32Ty, 2, "vec.c.inner.phi");
+  VecCPhi->addIncoming(VecCPhiColLoop, ColBody);
+
+  // tiledpbssd.scalarize.inner.body:
+  // calculate idxa, idxb
+  // %eltc = extractelement <256 x i32> %vec.c.inner.phi, i16 %idxc
+  // %elta = extractelement <256 x i32> %veca, i16 %idxa
+  // %eltav4i8 = bitcast i32 %elta to <4 x i8>
+  // %eltb = extractelement <256 x i32> %vecb, i16 %idxb
+  // %eltbv4i8 = bitcast i32 %eltb to <4 x i8>
+  // %eltav4i32 = sext <4 x i8> %eltav4i8 to <4 x i32>
+  // %eltbv4i32 = sext <4 x i8> %eltbv4i8 to <4 x i32>
+  // %mulab = mul <4 x i32> %eltbv4i32, %eltav4i32
+  // %acc = call i32 @llvm.vector.reduce.add.v4i32(<4 x i32> %mulab)
+  // %neweltc = add i32 %eltc, %acc
+  // %NewVecC = insertelement <256 x i32> %vec.c.inner.phi, i32 %neweltc,
+  // i16 %idxc
+
+  B.SetInsertPoint(InnerBody->getTerminator());
+  Value *IdxA =
+      B.CreateAdd(B.CreateMul(CurrentRow, B.getInt16(16)), CurrentInner);
+  Value *IdxB =
+      B.CreateAdd(B.CreateMul(CurrentInner, B.getInt16(16)), CurrentCol);
+
+  FixedVectorType *V4I8Ty = FixedVectorType::get(B.getInt8Ty(), 4);
+  FixedVectorType *V4I32Ty = FixedVectorType::get(B.getInt32Ty(), 4);
+  Value *EltC = B.CreateExtractElement(VecCPhi, IdxC);
+  Value *EltA = B.CreateExtractElement(VecA, IdxA);
+  Value *SubVecA = B.CreateBitCast(EltA, V4I8Ty);
+  Value *EltB = B.CreateExtractElement(VecB, IdxB);
+  Value *SubVecB = B.CreateBitCast(EltB, V4I8Ty);
+  Value *SubVecR = B.CreateAddReduce(B.CreateMul(
+      B.CreateSExt(SubVecA, V4I32Ty), B.CreateSExt(SubVecB, V4I32Ty)));
+  Value *ResElt = B.CreateAdd(EltC, SubVecR);
+  Value *NewVecC = B.CreateInsertElement(VecCPhi, ResElt, IdxC);
+
+  // tiledpbssd.scalarize.cols.latch:
+  // %NewEltC = extractelement <256 x i32> %NewVecC, i16 %idxc
+  // %NewVecD = insertelement <256 x i32> %vec.d.phi.col, i32 %NewEltC,
+  // i16 %idxc
+  B.SetInsertPoint(ColLoopLatch->getTerminator());
+  Value *NewEltC = B.CreateExtractElement(NewVecC, IdxC);
+  Value *NewVecD = B.CreateInsertElement(VecDPhiColLoop, NewEltC, IdxC);
+
+  VecCPhi->addIncoming(NewVecC, InnerLoopLatch);
+  VecCPhiRowLoop->addIncoming(NewVecC, RowLatch);
+  VecCPhiColLoop->addIncoming(NewVecC, ColLoopLatch);
+  VecDPhiRowLoop->addIncoming(NewVecD, RowLatch);
+  VecDPhiColLoop->addIncoming(NewVecD, ColLoopLatch);
+
+  return NewVecD;
+}
+
+namespace {
+class X86LowerAMXIntrinsics {
+  Function &Func;
+
+public:
+  X86LowerAMXIntrinsics(Function &F, DominatorTree *DT, LoopInfo *LI)
+      : Func(F), DT(DT), LI(LI) {}
+  bool visit();
+
+private:
+  DominatorTree *DT;
+  LoopInfo *LI;
+  template <bool isTileLoad>
+  bool lowerTileLoadStore(Instruction *TileLoadStore);
+  bool lowerTileDPBSSD(Instruction *TileDPBSSD);
+  bool lowerTileZero(Instruction *TileZero);
+};
+
+bool X86LowerAMXIntrinsics::lowerTileDPBSSD(Instruction *TileDPBSSD) {
+  Value *M, *N, *K, *C, *A, *B;
+  match(TileDPBSSD, m_Intrinsic<Intrinsic::x86_tdpbssd_internal>(
+                        m_Value(M), m_Value(N), m_Value(K), m_Value(C),
+                        m_Value(A), m_Value(B)));
+  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
+  Instruction *InsertI = TileDPBSSD;
+  IRBuilder<> PreBuilder(TileDPBSSD);
+  PreBuilder.SetInsertPoint(TileDPBSSD);
+  // We visit the loop with (m, n/4, k/4):
+  // %n_dword = lshr i16 %n, 2
+  // %k_dword = lshr i16 %k, 2
+  Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
+  Value *KDWord = PreBuilder.CreateLShr(K, PreBuilder.getInt16(2));
+  BasicBlock *Start = InsertI->getParent();
+  BasicBlock *End =
+      SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
+  IRBuilder<> Builder(TileDPBSSD);
+  Value *ResVec = createTileDPBSSDLoops(Start, End, Builder, DTU, *LI, M,
+                                        NDWord, KDWord, C, A, B);
+  // We cannot assume there is always a bitcast after tiledpbssd, so we need
+  // to insert one bitcast as required.
+  Builder.SetInsertPoint(End->getFirstNonPHI());
+  Value *ResAMX =
+      Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
+  // Delete tiledpbssd intrinsic and do some clean-up.
+  for (auto UI = TileDPBSSD->use_begin(), UE = TileDPBSSD->use_end();
+       UI != UE;) {
+    Instruction *I = cast<Instruction>((UI++)->getUser());
+    Value *Vec;
+    if (match(I, m_BitCast(m_Value(Vec)))) {
+      I->replaceAllUsesWith(ResVec);
+      I->eraseFromParent();
+    }
+  }
+  TileDPBSSD->replaceAllUsesWith(ResAMX);
+  TileDPBSSD->eraseFromParent();
+  return true;
+}
+
+template <bool isTileLoad>
+bool X86LowerAMXIntrinsics::lowerTileLoadStore(Instruction *TileLoadStore) {
+  Value *M, *N, *Ptr, *Stride, *Tile;
+  if (isTileLoad)
+    match(TileLoadStore,
+          m_Intrinsic<Intrinsic::x86_tileloadd64_internal>(
+              m_Value(M), m_Value(N), m_Value(Ptr), m_Value(Stride)));
+  else
+    match(TileLoadStore, m_Intrinsic<Intrinsic::x86_tilestored64_internal>(
+                             m_Value(M), m_Value(N), m_Value(Ptr),
+                             m_Value(Stride), m_Value(Tile)));
+
+  DomTreeUpdater DTU(DT, DomTreeUpdater::UpdateStrategy::Lazy);
+  Instruction *InsertI = TileLoadStore;
+  IRBuilder<> PreBuilder(TileLoadStore);
+  PreBuilder.SetInsertPoint(TileLoadStore);
+  Value *NDWord = PreBuilder.CreateLShr(N, PreBuilder.getInt16(2));
+  Value *StrideDWord = PreBuilder.CreateLShr(Stride, PreBuilder.getInt64(2));
+  BasicBlock *Start = InsertI->getParent();
+  BasicBlock *End =
+      SplitBlock(InsertI->getParent(), InsertI, DT, LI, nullptr, "continue");
+  IRBuilder<> Builder(TileLoadStore);
+  Value *ResVec = createTileLoadStoreLoops<isTileLoad>(
+      Start, End, Builder, DTU, *LI, M, NDWord, Ptr, StrideDWord,
+      isTileLoad ? nullptr : Tile);
+  if (isTileLoad) {
+    // We cannot assume there is always a bitcast after tileload, so we need
+    // to insert one bitcast as required.
+    Builder.SetInsertPoint(End->getFirstNonPHI());
+    Value *ResAMX =
+        Builder.CreateBitCast(ResVec, Type::getX86_AMXTy(Builder.getContext()));
+    // Delete tileloadd64 intrinsic and do some clean-up.
+    for (auto UI = TileLoadStore->use_begin(), UE = TileLoadStore->use_end();
+         UI != UE;) {
+      Instruction *I = cast<Instruction>((UI++)->getUser());
+      Value *Vec;
+      if (match(I, m_BitCast(m_Value(Vec)))) {
+        I->replaceAllUsesWith(ResVec);
+        I->eraseFromParent();
+      }
+    }
+    TileLoadStore->replaceAllUsesWith(ResAMX);
+  }
+  TileLoadStore->eraseFromParent();
+  return true;
+}
+
+bool X86LowerAMXIntrinsics::lowerTileZero(Instruction *TileZero) {
+  IRBuilder<> Builder(TileZero);
+  FixedVectorType *V256I32Ty = FixedVectorType::get(Builder.getInt32Ty(), 256);
+  Value *VecZero = Constant::getNullValue(V256I32Ty);
+  for (auto UI = TileZero->use_begin(), UE = TileZero->use_end(); UI != UE;) {
+    Instruction *I = cast<Instruction>((UI++)->getUser());
+    Value *Vec;
+    if (match(I, m_BitCast(m_Value(Vec)))) {
+      I->replaceAllUsesWith(VecZero);
+      I->eraseFromParent();
+    }
+  }
+  TileZero->eraseFromParent();
+  return true;
+}
+
+bool X86LowerAMXIntrinsics::visit() {
+  bool C = false;
+  SmallVector<IntrinsicInst *, 8> WorkList;
+  for (BasicBlock *BB : depth_first(&Func)) {
+    for (BasicBlock::iterator II = BB->begin(), IE = BB->end(); II != IE;) {
+      if (auto *Inst = dyn_cast<IntrinsicInst>(&*II++)) {
+        switch (Inst->getIntrinsicID()) {
+        case Intrinsic::x86_tdpbssd_internal:
+        case Intrinsic::x86_tileloadd64_internal:
+        case Intrinsic::x86_tilestored64_internal:
+        case Intrinsic::x86_tilezero_internal:
+          WorkList.push_back(Inst);
+          break;
+        default:
+          break;
+        }
+      }
+    }
+  }
+
+  for (auto *Inst : WorkList) {
+    switch (Inst->getIntrinsicID()) {
+    case Intrinsic::x86_tdpbssd_internal:
+      C = lowerTileDPBSSD(Inst) || C;
+      break;
+    case Intrinsic::x86_tileloadd64_internal:
+      C = lowerTileLoadStore<true>(Inst) || C;
+      break;
+    case Intrinsic::x86_tilestored64_internal:
+      C = lowerTileLoadStore<false>(Inst) || C;
+      break;
+    case Intrinsic::x86_tilezero_internal:
+      C = lowerTileZero(Inst) || C;
+      break;
+    default:
+      llvm_unreachable("invalid amx intrinsics!");
+    }
+  }
+
+  return C;
+}
+} // anonymous namespace
+
+namespace {
+
+class X86LowerAMXIntrinsicsLegacyPass : public FunctionPass {
+public:
+  static char ID;
+
+  X86LowerAMXIntrinsicsLegacyPass() : FunctionPass(ID) {
+    initializeX86LowerAMXIntrinsicsLegacyPassPass(
+        *PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override {
+    TargetMachine *TM = &getAnalysis<TargetPassConfig>().getTM<TargetMachine>();
+    if (!F.hasFnAttribute(Attribute::OptimizeNone) &&
+        TM->getOptLevel() != CodeGenOpt::None)
+      return false;
+
+    auto &DT = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
+    auto &LI = getAnalysis<LoopInfoWrapperPass>().getLoopInfo();
+
+    X86LowerAMXIntrinsics LAT(F, &DT, &LI);
+    return LAT.visit();
+  }
+  StringRef getPassName() const override { return "Lower AMX intrinsics"; }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.addRequired<DominatorTreeWrapperPass>();
+    AU.addPreserved<DominatorTreeWrapperPass>();
+    AU.addRequired<LoopInfoWrapperPass>();
+    AU.addPreserved<LoopInfoWrapperPass>();
+    AU.addRequired<TargetPassConfig>();
+  }
+};
+
+} // anonymous namespace
+
+static const char PassName[] = "Lower AMX intrinsics";
+char X86LowerAMXIntrinsicsLegacyPass::ID = 0;
+INITIALIZE_PASS_BEGIN(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
+                      false, false)
+INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
+INITIALIZE_PASS_DEPENDENCY(TargetPassConfig)
+INITIALIZE_PASS_END(X86LowerAMXIntrinsicsLegacyPass, DEBUG_TYPE, PassName,
+                    false, false)
+
+FunctionPass *llvm::createX86LowerAMXIntrinsicsPass() {
+  return new X86LowerAMXIntrinsicsLegacyPass();
+}
Index: llvm/lib/Target/X86/X86.h
===================================================================
--- llvm/lib/Target/X86/X86.h
+++ llvm/lib/Target/X86/X86.h
@@ -169,6 +169,7 @@
 void initializeX86PreTileConfigPass(PassRegistry &);
 void initializeX86TileConfigPass(PassRegistry &);
 void initializeX86LowerAMXTypeLegacyPassPass(PassRegistry &);
+void initializeX86LowerAMXIntrinsicsLegacyPassPass(PassRegistry &);
 
 namespace X86AS {
 enum : unsigned {
Index: llvm/lib/Target/X86/CMakeLists.txt
===================================================================
--- llvm/lib/Target/X86/CMakeLists.txt
+++ llvm/lib/Target/X86/CMakeLists.txt
@@ -33,6 +33,7 @@
   X86DomainReassignment.cpp
   X86DiscriminateMemOps.cpp
   X86LowerAMXType.cpp
+  X86LowerAMXIntrinsics.cpp
   X86TileConfig.cpp
   X86PreTileConfig.cpp
   X86ExpandPseudo.cpp
Index: llvm/include/llvm/CodeGen/Passes.h
===================================================================
--- llvm/include/llvm/CodeGen/Passes.h
+++ llvm/include/llvm/CodeGen/Passes.h
@@ -489,9 +489,13 @@
   /// caller saved registers with stack slots.
   extern char &FixupStatepointCallerSavedID;
 
-  /// The pass transform load/store <256 x i32> to AMX load/store intrinsics
+  /// The pass transforms load/store <256 x i32> to AMX load/store intrinsics
   /// or split the data to two <128 x i32>.
   FunctionPass *createX86LowerAMXTypePass();
+
+  /// The pass transforms AMX intrinsics to scalar operations if the function
+  /// has the optnone attribute or the optimization level is O0.
+  FunctionPass *createX86LowerAMXIntrinsicsPass();
 } // End llvm namespace
 
 #endif
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to