LuoYuanke updated this revision to Diff 301172.
LuoYuanke added a comment.
Herald added a subscriber: dexonsmith.

Rebase and model the tile config register.
Every AMX instruction uses the tile config register, so it must be
configured before any AMX instruction executes. Add the
X86PreTileConfig pass to insert a pseudo ldtilecfg and build the
data dependency between the config register and the AMX
instructions.
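
For illustration, the machine IR after X86PreTileConfig looks roughly like
this (operand order and register names are a sketch, not copied verbatim
from the patch):

  %cfg:tilecfg = PLDTILECFG %stack.0, ...
  %t:tile = PTILELOADDV %row:gr16, %col:gr16, %ptr, %stride, %cfg
  %d:tile = PTDPBSSDV %row, %col, %k, %t, %a, %b, %cfg

Every AMX pseudo takes the config register as an extra use, so no tile
instruction can be moved above the configuration point. After register
allocation, the Tile Register Configure pass stores the row/column shape
of each assigned physical tile register into the config stack slot, and
the pseudo is expanded to a real ldtilecfg (X86ExpandPseudo.cpp is part
of this change).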


Repository:
  rG LLVM Github Monorepo

CHANGES SINCE LAST ACTION
  https://reviews.llvm.org/D87981/new/

https://reviews.llvm.org/D87981

Files:
  clang/include/clang/Basic/BuiltinsX86_64.def
  clang/lib/Headers/amxintrin.h
  clang/test/CodeGen/AMX/amx_api.c
  llvm/include/llvm/CodeGen/LiveIntervalUnion.h
  llvm/include/llvm/CodeGen/LiveRegMatrix.h
  llvm/include/llvm/CodeGen/Passes.h
  llvm/include/llvm/CodeGen/TileShapeInfo.h
  llvm/include/llvm/CodeGen/VirtRegMap.h
  llvm/include/llvm/IR/Intrinsics.td
  llvm/include/llvm/IR/IntrinsicsX86.td
  llvm/lib/CodeGen/InlineSpiller.cpp
  llvm/lib/CodeGen/LiveIntervalUnion.cpp
  llvm/lib/CodeGen/LiveRegMatrix.cpp
  llvm/lib/CodeGen/VirtRegMap.cpp
  llvm/lib/IR/Function.cpp
  llvm/lib/Target/X86/CMakeLists.txt
  llvm/lib/Target/X86/X86.h
  llvm/lib/Target/X86/X86ExpandPseudo.cpp
  llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
  llvm/lib/Target/X86/X86ISelLowering.cpp
  llvm/lib/Target/X86/X86InstrAMX.td
  llvm/lib/Target/X86/X86InstrInfo.cpp
  llvm/lib/Target/X86/X86LowerAMXType.cpp
  llvm/lib/Target/X86/X86PreTileConfig.cpp
  llvm/lib/Target/X86/X86RegisterInfo.cpp
  llvm/lib/Target/X86/X86RegisterInfo.h
  llvm/lib/Target/X86/X86RegisterInfo.td
  llvm/lib/Target/X86/X86Subtarget.h
  llvm/lib/Target/X86/X86TargetMachine.cpp
  llvm/lib/Target/X86/X86TileConfig.cpp
  llvm/test/CodeGen/X86/AMX/amx-across-func.ll
  llvm/test/CodeGen/X86/AMX/amx-config.ll
  llvm/test/CodeGen/X86/AMX/amx-spill.ll
  llvm/test/CodeGen/X86/AMX/amx-type.ll
  llvm/test/CodeGen/X86/O0-pipeline.ll
  llvm/test/CodeGen/X86/ipra-reg-usage.ll
  llvm/test/CodeGen/X86/opt-pipeline.ll
  llvm/test/CodeGen/X86/statepoint-fixup-invoke.mir
  llvm/test/CodeGen/X86/statepoint-fixup-shared-ehpad.mir
  llvm/utils/TableGen/IntrinsicEmitter.cpp

Index: llvm/utils/TableGen/IntrinsicEmitter.cpp
===================================================================
--- llvm/utils/TableGen/IntrinsicEmitter.cpp
+++ llvm/utils/TableGen/IntrinsicEmitter.cpp
@@ -197,25 +197,25 @@
 enum IIT_Info {
   // Common values should be encoded with 0-15.
   IIT_Done = 0,
-  IIT_I1   = 1,
-  IIT_I8   = 2,
-  IIT_I16  = 3,
-  IIT_I32  = 4,
-  IIT_I64  = 5,
-  IIT_F16  = 6,
-  IIT_F32  = 7,
-  IIT_F64  = 8,
-  IIT_V2   = 9,
-  IIT_V4   = 10,
-  IIT_V8   = 11,
-  IIT_V16  = 12,
-  IIT_V32  = 13,
-  IIT_PTR  = 14,
-  IIT_ARG  = 15,
+  IIT_I1 = 1,
+  IIT_I8 = 2,
+  IIT_I16 = 3,
+  IIT_I32 = 4,
+  IIT_I64 = 5,
+  IIT_F16 = 6,
+  IIT_F32 = 7,
+  IIT_F64 = 8,
+  IIT_V2 = 9,
+  IIT_V4 = 10,
+  IIT_V8 = 11,
+  IIT_V16 = 12,
+  IIT_V32 = 13,
+  IIT_PTR = 14,
+  IIT_ARG = 15,
 
   // Values from 16+ are only encodable with the inefficient encoding.
-  IIT_V64  = 16,
-  IIT_MMX  = 17,
+  IIT_V64 = 16,
+  IIT_MMX = 17,
   IIT_TOKEN = 18,
   IIT_METADATA = 19,
   IIT_EMPTYSTRUCT = 20,
@@ -226,7 +226,7 @@
   IIT_EXTEND_ARG = 25,
   IIT_TRUNC_ARG = 26,
   IIT_ANYPTR = 27,
-  IIT_V1   = 28,
+  IIT_V1 = 28,
   IIT_VARARG = 29,
   IIT_HALF_VEC_ARG = 30,
   IIT_SAME_VEC_WIDTH_ARG = 31,
Index: llvm/test/CodeGen/X86/statepoint-fixup-shared-ehpad.mir
===================================================================
--- llvm/test/CodeGen/X86/statepoint-fixup-shared-ehpad.mir
+++ llvm/test/CodeGen/X86/statepoint-fixup-shared-ehpad.mir
@@ -108,7 +108,7 @@
   ; CHECK:   ADJCALLSTACKDOWN64 0, 0, 0, implicit-def dead $rsp, implicit-def dead $eflags, implicit-def dead $ssp, implicit $rsp, implicit $ssp
   ; CHECK:   MOV64mr [[STACK0:%stack.[0-9]+]], 1, $noreg, 0, $noreg, killed $rbx :: (store 8 into [[STACK0]])
   ; CHECK:   MOV64mr [[STACK1:%stack.[0-9]+]], 1, $noreg, 0, $noreg, killed $r14 :: (store 8 into [[STACK1]])
-  ; CHECK:   STATEPOINT 0, 0, 0, @foo, 2, 0, 2, 0, 2, 0, 2, 2, 1, 8, [[STACK0]], 0, 1, 8, [[STACK1]], 0, 2, 0, 2, 2, 0, 0, 1, 1, csr_64, implicit-def $rsp, implicit-def $ssp :: (load store 8 on [[STACK0]]), (load store 8 on [[STACK1]])
+  ; CHECK:   STATEPOINT 0, 0, 0, @foo, 2, 0, 2, 0, 2, 0, 2, 2, 1, 8, [[STACK0]], 0, 1, 8, [[STACK1]], 0, 2, 0, 2, 2, 0, 0, 1, 1, csr_64, implicit-def $rsp, implicit-def $ssp :: (load store 8 on [[STACK1]]), (load store 8 on [[STACK0]])
   ; CHECK-DAG:   $rbx = MOV64rm [[STACK0]], 1, $noreg, 0, $noreg :: (load 8 from [[STACK0]])
   ; CHECK-DAG:   $r14 = MOV64rm [[STACK1]], 1, $noreg, 0, $noreg :: (load 8 from [[STACK1]])
   ; CHECK:   ADJCALLSTACKUP64 0, 0, implicit-def dead $rsp, implicit-def dead $eflags, implicit-def dead $ssp, implicit $rsp, implicit $ssp
@@ -121,7 +121,7 @@
   ; CHECK:   ADJCALLSTACKDOWN64 0, 0, 0, implicit-def dead $rsp, implicit-def dead $eflags, implicit-def dead $ssp, implicit $rsp, implicit $ssp
   ; CHECK-DAG:   MOV64mr [[STACK0]], 1, $noreg, 0, $noreg, killed $rbx :: (store 8 into [[STACK0]])
   ; CHECK-DAG:   MOV64mr [[STACK1]], 1, $noreg, 0, $noreg, killed $r14 :: (store 8 into [[STACK1]])
-  ; CHECK:   STATEPOINT 0, 0, 0, @bar, 2, 0, 2, 0, 2, 0, 2, 2, 1, 8, %stack.0, 0, 1, 8, [[STACK1]], 0, 2, 0, 2, 2, 0, 0, 1, 1, csr_64, implicit-def $rsp, implicit-def $ssp :: (load store 8 on [[STACK0]]), (load store 8 on [[STACK1]])
+  ; CHECK:   STATEPOINT 0, 0, 0, @bar, 2, 0, 2, 0, 2, 0, 2, 2, 1, 8, %stack.0, 0, 1, 8, [[STACK1]], 0, 2, 0, 2, 2, 0, 0, 1, 1, csr_64, implicit-def $rsp, implicit-def $ssp :: (load store 8 on [[STACK1]]), (load store 8 on [[STACK0]])
   ; CHECK-DAG:   $rbx = MOV64rm [[STACK0]], 1, $noreg, 0, $noreg :: (load 8 from [[STACK0]])
   ; CHECK-DAG:   $r14 = MOV64rm [[STACK1]], 1, $noreg, 0, $noreg :: (load 8 from [[STACK1]])
   ; CHECK:   ADJCALLSTACKUP64 0, 0, implicit-def dead $rsp, implicit-def dead $eflags, implicit-def dead $ssp, implicit $rsp, implicit $ssp
Index: llvm/test/CodeGen/X86/statepoint-fixup-invoke.mir
===================================================================
--- llvm/test/CodeGen/X86/statepoint-fixup-invoke.mir
+++ llvm/test/CodeGen/X86/statepoint-fixup-invoke.mir
@@ -91,7 +91,7 @@
   ; CHECK-DAG:   MOV64mr %stack.1, 1, $noreg, 0, $noreg, $rdi :: (store 8 into %stack.1)
   ; CHECK:   EH_LABEL <mcsymbol .Ltmp0>
   ; CHECK:   ADJCALLSTACKDOWN64 0, 0, 0, implicit-def dead $rsp, implicit-def dead $eflags, implicit-def dead $ssp, implicit $rsp, implicit $ssp
-  ; CHECK:   STATEPOINT 0, 0, 1, @some_call, $rdi, 2, 0, 2, 0, 2, 5, 2, 0, 2, -1, 2, 0, 2, 0, 2, 0, 2, 2, 1, 8, %stack.0, 0, 1, 8, %stack.1, 0, 2, 0, 2, 2, 0, 0, 1, 1, csr_64, implicit-def $rsp, implicit-def $ssp :: (load store 8 on %stack.1), (load store 8 on %stack.0)
+  ; CHECK:   STATEPOINT 0, 0, 1, @some_call, $rdi, 2, 0, 2, 0, 2, 5, 2, 0, 2, -1, 2, 0, 2, 0, 2, 0, 2, 2, 1, 8, %stack.0, 0, 1, 8, %stack.1, 0, 2, 0, 2, 2, 0, 0, 1, 1, csr_64, implicit-def $rsp, implicit-def $ssp :: (load store 8 on %stack.0), (load store 8 on %stack.1)
   ; CHECK-DAG:   $r14 = MOV64rm %stack.0, 1, $noreg, 0, $noreg :: (load 8 from %stack.0)
   ; CHECK-DAG:   $rbx = MOV64rm %stack.1, 1, $noreg, 0, $noreg :: (load 8 from %stack.1)
   ; CHECK:   ADJCALLSTACKUP64 0, 0, implicit-def dead $rsp, implicit-def dead $eflags, implicit-def dead $ssp, implicit $rsp, implicit $ssp
Index: llvm/test/CodeGen/X86/opt-pipeline.ll
===================================================================
--- llvm/test/CodeGen/X86/opt-pipeline.ll
+++ llvm/test/CodeGen/X86/opt-pipeline.ll
@@ -24,6 +24,7 @@
 ; CHECK-NEXT:     Pre-ISel Intrinsic Lowering
 ; CHECK-NEXT:     FunctionPass Manager
 ; CHECK-NEXT:       Expand Atomic instructions
+; CHECK-NEXT:       Lower AMX type for load/store
 ; CHECK-NEXT:       Module Verifier
 ; CHECK-NEXT:       Dominator Tree Construction
 ; CHECK-NEXT:       Basic Alias Analysis (stateless AA impl)
@@ -118,11 +119,12 @@
 ; CHECK-NEXT:       MachineDominator Tree Construction
 ; CHECK-NEXT:       X86 EFLAGS copy lowering
 ; CHECK-NEXT:       X86 WinAlloca Expander
+; CHECK-NEXT:       MachineDominator Tree Construction
+; CHECK-NEXT:       Tile Register Pre-configure
 ; CHECK-NEXT:       Detect Dead Lanes
 ; CHECK-NEXT:       Process Implicit Definitions
 ; CHECK-NEXT:       Remove unreachable machine basic blocks
 ; CHECK-NEXT:       Live Variable Analysis
-; CHECK-NEXT:       MachineDominator Tree Construction
 ; CHECK-NEXT:       Machine Natural Loop Construction
 ; CHECK-NEXT:       Eliminate PHI nodes for register allocation
 ; CHECK-NEXT:       Two-Address instruction pass
@@ -141,6 +143,7 @@
 ; CHECK-NEXT:       Lazy Machine Block Frequency Analysis
 ; CHECK-NEXT:       Machine Optimization Remark Emitter
 ; CHECK-NEXT:       Greedy Register Allocator
+; CHECK-NEXT:       Tile Register Configure
 ; CHECK-NEXT:       Virtual Register Rewriter
 ; CHECK-NEXT:       Stack Slot Coloring
 ; CHECK-NEXT:       Machine Copy Propagation Pass
Index: llvm/test/CodeGen/X86/ipra-reg-usage.ll
===================================================================
--- llvm/test/CodeGen/X86/ipra-reg-usage.ll
+++ llvm/test/CodeGen/X86/ipra-reg-usage.ll
@@ -3,7 +3,7 @@
 target triple = "x86_64-unknown-unknown"
 declare void @bar1()
 define preserve_allcc void @foo()#0 {
-; CHECK: foo Clobbered Registers: $cs $df $ds $eflags $eip $eiz $es $fpcw $fpsw $fs $gs $hip $ip $mxcsr $rip $riz $ss $ssp $bnd0 $bnd1 $bnd2 $bnd3 $cr0 $cr1 $cr2 $cr3 $cr4 $cr5 $cr6 $cr7 $cr8 $cr9 $cr10 $cr11 $cr12 $cr13 $cr14 $cr15 $dr0 $dr1 $dr2 $dr3 $dr4 $dr5 $dr6 $dr7 $dr8 $dr9 $dr10 $dr11 $dr12 $dr13 $dr14 $dr15 $fp0 $fp1 $fp2 $fp3 $fp4 $fp5 $fp6 $fp7 $k0 $k1 $k2 $k3 $k4 $k5 $k6 $k7 $mm0 $mm1 $mm2 $mm3 $mm4 $mm5 $mm6 $mm7 $r11 $st0 $st1 $st2 $st3 $st4 $st5 $st6 $st7 $tmm0 $tmm1 $tmm2 $tmm3 $tmm4 $tmm5 $tmm6 $tmm7 $xmm16 $xmm17 $xmm18 $xmm19 $xmm20 $xmm21 $xmm22 $xmm23 $xmm24 $xmm25 $xmm26 $xmm27 $xmm28 $xmm29 $xmm30 $xmm31 $ymm0 $ymm1 $ymm2 $ymm3 $ymm4 $ymm5 $ymm6 $ymm7 $ymm8 $ymm9 $ymm10 $ymm11 $ymm12 $ymm13 $ymm14 $ymm15 $ymm16 $ymm17 $ymm18 $ymm19 $ymm20 $ymm21 $ymm22 $ymm23 $ymm24 $ymm25 $ymm26 $ymm27 $ymm28 $ymm29 $ymm30 $ymm31 $zmm0 $zmm1 $zmm2 $zmm3 $zmm4 $zmm5 $zmm6 $zmm7 $zmm8 $zmm9 $zmm10 $zmm11 $zmm12 $zmm13 $zmm14 $zmm15 $zmm16 $zmm17 $zmm18 $zmm19 $zmm20 $zmm21 $zmm22 $zmm23 $zmm24 $zmm25 $zmm26 $zmm27 $zmm28 $zmm29 $zmm30 $zmm31 $r11b $r11bh $r11d $r11w $r11wh
+; CHECK: foo Clobbered Registers: $cs $df $ds $eflags $eip $eiz $es $fpcw $fpsw $fs $gs $hip $ip $mxcsr $rip $riz $ss $ssp $tmmcfg $bnd0 $bnd1 $bnd2 $bnd3 $cr0 $cr1 $cr2 $cr3 $cr4 $cr5 $cr6 $cr7 $cr8 $cr9 $cr10 $cr11 $cr12 $cr13 $cr14 $cr15 $dr0 $dr1 $dr2 $dr3 $dr4 $dr5 $dr6 $dr7 $dr8 $dr9 $dr10 $dr11 $dr12 $dr13 $dr14 $dr15 $fp0 $fp1 $fp2 $fp3 $fp4 $fp5 $fp6 $fp7 $k0 $k1 $k2 $k3 $k4 $k5 $k6 $k7 $mm0 $mm1 $mm2 $mm3 $mm4 $mm5 $mm6 $mm7 $r11 $st0 $st1 $st2 $st3 $st4 $st5 $st6 $st7 $tmm0 $tmm1 $tmm2 $tmm3 $tmm4 $tmm5 $tmm6 $tmm7 $xmm16 $xmm17 $xmm18 $xmm19 $xmm20 $xmm21 $xmm22 $xmm23 $xmm24 $xmm25 $xmm26 $xmm27 $xmm28 $xmm29 $xmm30 $xmm31 $ymm0 $ymm1 $ymm2 $ymm3 $ymm4 $ymm5 $ymm6 $ymm7 $ymm8 $ymm9 $ymm10 $ymm11 $ymm12 $ymm13 $ymm14 $ymm15 $ymm16 $ymm17 $ymm18 $ymm19 $ymm20 $ymm21 $ymm22 $ymm23 $ymm24 $ymm25 $ymm26 $ymm27 $ymm28 $ymm29 $ymm30 $ymm31 $zmm0 $zmm1 $zmm2 $zmm3 $zmm4 $zmm5 $zmm6 $zmm7 $zmm8 $zmm9 $zmm10 $zmm11 $zmm12 $zmm13 $zmm14 $zmm15 $zmm16 $zmm17 $zmm18 $zmm19 $zmm20 $zmm21 $zmm22 $zmm23 $zmm24 $zmm25 $zmm26 $zmm27 $zmm28 $zmm29 $zmm30 $zmm31 $r11b $r11bh $r11d $r11w $r11wh $k0_k1 $k2_k3 $k4_k5 $k6_k7
   call void @bar1()
   call void @bar2()
   ret void
Index: llvm/test/CodeGen/X86/O0-pipeline.ll
===================================================================
--- llvm/test/CodeGen/X86/O0-pipeline.ll
+++ llvm/test/CodeGen/X86/O0-pipeline.ll
@@ -18,6 +18,7 @@
 ; CHECK-NEXT:     Pre-ISel Intrinsic Lowering
 ; CHECK-NEXT:     FunctionPass Manager
 ; CHECK-NEXT:       Expand Atomic instructions
+; CHECK-NEXT:       Lower AMX type for load/store
 ; CHECK-NEXT:       Module Verifier
 ; CHECK-NEXT:       Lower Garbage Collection Instructions
 ; CHECK-NEXT:       Shadow Stack GC Lowering
Index: llvm/test/CodeGen/X86/AMX/amx-type.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/X86/AMX/amx-type.ll
@@ -0,0 +1,143 @@
+; NOTE: Assertions have been autogenerated by utils/update_test_checks.py
+; RUN: opt -lower-amx-type %s -S | FileCheck %s
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+%struct.__tile_str = type { i16, i16, <256 x i32> }
+
+@buf = dso_local global [1024 x i8] zeroinitializer, align 16
+@buf2 = dso_local global [1024 x i8] zeroinitializer, align 16
+
+define dso_local void @test_load(i8* %in, i8* %out) local_unnamed_addr #2 {
+; CHECK-LABEL: @test_load(
+; CHECK-NEXT:    [[TMP1:%.*]] = bitcast i8* [[IN:%.*]] to <256 x i32>*
+; CHECK-NEXT:    [[TMP2:%.*]] = bitcast i8* [[OUT:%.*]] to <256 x i32>*
+; CHECK-NEXT:    [[TMP3:%.*]] = bitcast <256 x i32>* [[TMP1]] to <128 x i32>*
+; CHECK-NEXT:    [[TMP4:%.*]] = load <128 x i32>, <128 x i32>* [[TMP3]], align 64
+; CHECK-NEXT:    [[TMP5:%.*]] = getelementptr <128 x i32>, <128 x i32>* [[TMP3]], i32 1
+; CHECK-NEXT:    [[TMP6:%.*]] = load <128 x i32>, <128 x i32>* [[TMP5]], align 64
+; CHECK-NEXT:    [[TMP7:%.*]] = bitcast <256 x i32>* [[TMP2]] to <128 x i32>*
+; CHECK-NEXT:    store <128 x i32> [[TMP4]], <128 x i32>* [[TMP7]], align 64
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr <128 x i32>, <128 x i32>* [[TMP7]], i32 1
+; CHECK-NEXT:    store <128 x i32> [[TMP6]], <128 x i32>* [[TMP8]], align 64
+; CHECK-NEXT:    ret void
+;
+  %1 = bitcast i8* %in to <256 x i32>*
+  %2 = bitcast i8* %out to <256 x i32>*
+  %3 = load <256 x i32>, <256 x i32>* %1, align 64, !tbaa !8
+  store <256 x i32> %3, <256 x i32>* %2, align 64, !tbaa !8
+  ret void
+}
+
+define dso_local void @__tile_loadd(%struct.__tile_str* nocapture %0, i8* %1, i64 %2) local_unnamed_addr #0 {
+; CHECK-LABEL: @__tile_loadd(
+; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP0:%.*]], i64 0, i32 0
+; CHECK-NEXT:    [[TMP5:%.*]] = load i16, i16* [[TMP4]], align 64, [[TBAA2:!tbaa !.*]]
+; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0]], i64 0, i32 1
+; CHECK-NEXT:    [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2, [[TBAA7:!tbaa !.*]]
+; CHECK-NEXT:    [[TMP8:%.*]] = shl i64 [[TMP2:%.*]], 32
+; CHECK-NEXT:    [[TMP9:%.*]] = ashr exact i64 [[TMP8]], 32
+; CHECK-NEXT:    [[TMP10:%.*]] = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP1:%.*]], i64 [[TMP9]]) [[ATTR3:#.*]]
+; CHECK-NEXT:    [[TMP11:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0]], i64 0, i32 2
+; CHECK-NEXT:    [[TMP12:%.*]] = bitcast <256 x i32>* [[TMP11]] to i8*
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP12]], i64 64, <256 x i32> [[TMP10]])
+; CHECK-NEXT:    ret void
+;
+  %4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 0
+  %5 = load i16, i16* %4, align 64, !tbaa !2
+  %6 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 1
+  %7 = load i16, i16* %6, align 2, !tbaa !7
+  %8 = shl i64 %2, 32
+  %9 = ashr exact i64 %8, 32
+  %10 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %5, i16 %7, i8* %1, i64 %9) #3
+  %11 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 2
+  store <256 x i32> %10, <256 x i32>* %11, align 64, !tbaa !8
+  ret void
+}
+
+define dso_local void @__tile_dpbsud(%struct.__tile_str* nocapture %0, %struct.__tile_str* nocapture readonly byval(%struct.__tile_str) align 64 %1, %struct.__tile_str* nocapture readonly byval(%struct.__tile_str) align 64 %2) local_unnamed_addr #0 {
+; CHECK-LABEL: @__tile_dpbsud(
+; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP1:%.*]], i64 0, i32 0
+; CHECK-NEXT:    [[TMP5:%.*]] = load i16, i16* [[TMP4]], align 64, [[TBAA2]]
+; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2:%.*]], i64 0, i32 1
+; CHECK-NEXT:    [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2, [[TBAA7]]
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP1]], i64 0, i32 1
+; CHECK-NEXT:    [[TMP9:%.*]] = load i16, i16* [[TMP8]], align 2, [[TBAA7]]
+; CHECK-NEXT:    [[TMP10:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP0:%.*]], i64 0, i32 2
+; CHECK-NEXT:    [[TMP11:%.*]] = bitcast <256 x i32>* [[TMP10]] to i8*
+; CHECK-NEXT:    [[TMP12:%.*]] = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP11]], i64 64)
+; CHECK-NEXT:    [[TMP13:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP1]], i64 0, i32 2
+; CHECK-NEXT:    [[TMP14:%.*]] = bitcast <256 x i32>* [[TMP13]] to i8*
+; CHECK-NEXT:    [[TMP15:%.*]] = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP9]], i8* [[TMP14]], i64 64)
+; CHECK-NEXT:    [[TMP16:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 2
+; CHECK-NEXT:    [[TMP17:%.*]] = bitcast <256 x i32>* [[TMP16]] to i8*
+; CHECK-NEXT:    [[TMP18:%.*]] = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 [[TMP9]], i16 [[TMP7]], i8* [[TMP17]], i64 64)
+; CHECK-NEXT:    [[TMP19:%.*]] = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 [[TMP5]], i16 [[TMP7]], i16 [[TMP9]], <256 x i32> [[TMP12]], <256 x i32> [[TMP15]], <256 x i32> [[TMP18]]) [[ATTR3]]
+; CHECK-NEXT:    [[TMP20:%.*]] = bitcast <256 x i32>* [[TMP10]] to i8*
+; CHECK-NEXT:    call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP20]], i64 64, <256 x i32> [[TMP19]])
+; CHECK-NEXT:    ret void
+;
+  %4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %1, i64 0, i32 0
+  %5 = load i16, i16* %4, align 64, !tbaa !2
+  %6 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 1
+  %7 = load i16, i16* %6, align 2, !tbaa !7
+  %8 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %1, i64 0, i32 1
+  %9 = load i16, i16* %8, align 2, !tbaa !7
+  %10 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %0, i64 0, i32 2
+  %11 = load <256 x i32>, <256 x i32>* %10, align 64, !tbaa !8
+  %12 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %1, i64 0, i32 2
+  %13 = load <256 x i32>, <256 x i32>* %12, align 64, !tbaa !8
+  %14 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 2
+  %15 = load <256 x i32>, <256 x i32>* %14, align 64, !tbaa !8
+  %16 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %5, i16 %7, i16 %9, <256 x i32> %11, <256 x i32> %13, <256 x i32> %15) #3
+  store <256 x i32> %16, <256 x i32>* %10, align 64, !tbaa !8
+  ret void
+}
+
+define dso_local void @__tile_stored(i8* %0, i64 %1, %struct.__tile_str* nocapture readonly byval(%struct.__tile_str) align 64 %2) local_unnamed_addr #1 {
+; CHECK-LABEL: @__tile_stored(
+; CHECK-NEXT:    [[TMP4:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR:%.*]], %struct.__tile_str* [[TMP2:%.*]], i64 0, i32 0
+; CHECK-NEXT:    [[TMP5:%.*]] = load i16, i16* [[TMP4]], align 64, [[TBAA2]]
+; CHECK-NEXT:    [[TMP6:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 1
+; CHECK-NEXT:    [[TMP7:%.*]] = load i16, i16* [[TMP6]], align 2, [[TBAA7]]
+; CHECK-NEXT:    [[TMP8:%.*]] = getelementptr inbounds [[STRUCT___TILE_STR]], %struct.__tile_str* [[TMP2]], i64 0, i32 2
+; CHECK-NEXT:    [[TMP9:%.*]] = bitcast <256 x i32>* [[TMP8]] to i8*
+; CHECK-NEXT:    [[TMP10:%.*]] = call <256 x i32> @llvm.x86.tileloadd64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP9]], i64 64)
+; CHECK-NEXT:    [[TMP11:%.*]] = shl i64 [[TMP1:%.*]], 32
+; CHECK-NEXT:    [[TMP12:%.*]] = ashr exact i64 [[TMP11]], 32
+; CHECK-NEXT:    tail call void @llvm.x86.tilestored64.internal(i16 [[TMP5]], i16 [[TMP7]], i8* [[TMP0:%.*]], i64 [[TMP12]], <256 x i32> [[TMP10]]) [[ATTR3]]
+; CHECK-NEXT:    ret void
+;
+  %4 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 0
+  %5 = load i16, i16* %4, align 64, !tbaa !2
+  %6 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 1
+  %7 = load i16, i16* %6, align 2, !tbaa !7
+  %8 = getelementptr inbounds %struct.__tile_str, %struct.__tile_str* %2, i64 0, i32 2
+  %9 = load <256 x i32>, <256 x i32>* %8, align 64, !tbaa !8
+  %10 = shl i64 %1, 32
+  %11 = ashr exact i64 %10, 32
+  tail call void @llvm.x86.tilestored64.internal(i16 %5, i16 %7, i8* %0, i64 %11, <256 x i32> %9) #3
+  ret void
+}
+
+declare <256 x i32> @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #3
+declare <256 x i32> @llvm.x86.tdpbssd.internal(i16, i16, i16, <256 x i32>, <256 x i32>, <256 x i32>) #3
+declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, <256 x i32>) #3
+
+attributes #0 = { alwaysinline nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="8192" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
+attributes #1 = { alwaysinline nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
+attributes #2 = { alwaysinline nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="0" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+avx,+avx2,+avx512f,+cx8,+f16c,+fma,+fxsr,+mmx,+popcnt,+sse,+sse2,+sse3,+sse4.1,+sse4.2,+ssse3,+x87,+xsave" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
+attributes #3 = { nounwind }
+
+!llvm.module.flags = !{!0}
+!llvm.ident = !{!1}
+
+!0 = !{i32 1, !"wchar_size", i32 4}
+!1 = !{!"clang version 12.0.0 (ssh://git-amr-1.devtools.intel.com:29418/dpd_icl-llvm_project_worldread f3c78a3f053379a2511e00e9ce2c13383ea3f835)"}
+!2 = !{!3, !4, i64 0}
+!3 = !{!"__tile_str", !4, i64 0, !4, i64 2, !5, i64 1024}
+!4 = !{!"short", !5, i64 0}
+!5 = !{!"omnipotent char", !6, i64 0}
+!6 = !{!"Simple C/C++ TBAA"}
+!7 = !{!3, !4, i64 2}
+!8 = !{!5, !5, i64 0}
Index: llvm/test/CodeGen/X86/AMX/amx-spill.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/X86/AMX/amx-spill.ll
@@ -0,0 +1,110 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+amx-int8 -mattr=+avx512f -verify-machineinstrs | FileCheck %s
+
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+@buf = dso_local global [1024 x i8] zeroinitializer, align 16
+@buf2 = dso_local global [1024 x i8] zeroinitializer, align 16
+
+define dso_local void @test_api(i32 %0, i16 signext %1, i16 signext %2) local_unnamed_addr #2 {
+; CHECK-LABEL: test_api:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    subq $2936, %rsp # imm = 0xB78
+; CHECK-NEXT:    .cfi_def_cfa_offset 2944
+; CHECK-NEXT:    vpxord %zmm0, %zmm0, %zmm0
+; CHECK-NEXT:    vmovdqu64 %zmm0, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movb %dl, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movw %dx, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movb %dl, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movw %dx, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movb %sil, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movw %dx, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movb %sil, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movw %dx, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movb %dl, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movw %dx, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movb %dl, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movw %dx, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movb %sil, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movw %si, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movb %sil, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movw %dx, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    ldtilecfg {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movl $buf, %r8d
+; CHECK-NEXT:    movl $32, %eax
+; CHECK-NEXT:    tileloadd (%r8,%rax), %tmm1
+; CHECK-NEXT:    tileloadd (%r8,%rax), %tmm1
+; CHECK-NEXT:    movabsq $64, %rcx
+; CHECK-NEXT:    tilestored %tmm1, 896(%rsp,%rcx) # 1024-byte Folded Spill
+; CHECK-NEXT:    tileloadd (%r8,%rax), %tmm3
+; CHECK-NEXT:    tileloadd (%r8,%rax), %tmm4
+; CHECK-NEXT:    tileloadd (%r8,%rax), %tmm2
+; CHECK-NEXT:    tileloadd (%r8,%rax), %tmm5
+; CHECK-NEXT:    tileloadd (%r8,%rax), %tmm0
+; CHECK-NEXT:    testl %edi, %edi
+; CHECK-NEXT:    je .LBB0_2
+; CHECK-NEXT:  # %bb.1:
+; CHECK-NEXT:    tileloadd (%r8,%rax), %tmm6
+; CHECK-NEXT:    tileloadd (%r8,%rax), %tmm7
+; CHECK-NEXT:    tileloadd (%r8,%rax), %tmm1
+; CHECK-NEXT:    jmp .LBB0_3
+; CHECK-NEXT:  .LBB0_2:
+; CHECK-NEXT:    movl $buf2, %ecx
+; CHECK-NEXT:    tileloadd (%rcx,%rax), %tmm6
+; CHECK-NEXT:    tileloadd (%rcx,%rax), %tmm7
+; CHECK-NEXT:    tileloadd (%rcx,%rax), %tmm1
+; CHECK-NEXT:  .LBB0_3:
+; CHECK-NEXT:    tdpbssd %tmm7, %tmm6, %tmm1
+; CHECK-NEXT:    movabsq $64, %rax
+; CHECK-NEXT:    tileloadd 896(%rsp,%rax), %tmm7 # 1024-byte Folded Reload
+; CHECK-NEXT:    tdpbssd %tmm7, %tmm1, %tmm3
+; CHECK-NEXT:    tdpbssd %tmm4, %tmm3, %tmm2
+; CHECK-NEXT:    tdpbssd %tmm5, %tmm2, %tmm0
+; CHECK-NEXT:    movl $buf, %eax
+; CHECK-NEXT:    movl $32, %ecx
+; CHECK-NEXT:    tilestored %tmm0, (%rax,%rcx)
+; CHECK-NEXT:    addq $2936, %rsp # imm = 0xB78
+; CHECK-NEXT:    .cfi_def_cfa_offset 8
+; CHECK-NEXT:    vzeroupper
+; CHECK-NEXT:    retq
+  %4 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %5 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %6 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %7 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %8 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %9 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %10 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %11 = icmp eq i32 %0, 0
+  br i1 %11, label %16, label %12
+
+12:                                               ; preds = %3
+  %13 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %14 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %15 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  br label %20
+
+16:                                               ; preds = %3
+  %17 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
+  %18 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
+  %19 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %1, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
+  br label %20
+
+20:                                               ; preds = %16, %12
+  %21 = phi <256 x i32> [ %17, %16 ], [ %13, %12 ]
+  %22 = phi <256 x i32> [ %18, %16 ], [ %14, %12 ]
+  %23 = phi <256 x i32> [ %19, %16 ], [ %15, %12 ]
+  %24 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %1, i16 %2, i16 %1, <256 x i32> %23, <256 x i32> %21, <256 x i32> %22) #3
+  %25 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %1, i16 %2, i16 %2, <256 x i32> %6, <256 x i32> %24, <256 x i32> %5) #3
+  %26 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %1, i16 %2, i16 %2, <256 x i32> %8, <256 x i32> %25, <256 x i32> %7) #3
+  %27 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %2, i16 %2, i16 %2, <256 x i32> %10, <256 x i32> %26, <256 x i32> %9) #3
+  tail call void @llvm.x86.tilestored64.internal(i16 %2, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32, <256 x i32> %27) #3
+  ret void
+}
+
+declare <256 x i32> @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #3
+declare <256 x i32> @llvm.x86.tdpbssd.internal(i16, i16, i16, <256 x i32>, <256 x i32>, <256 x i32>) #3
+declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, <256 x i32>) #3
+
+attributes #2 = { nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="8192" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
+attributes #3 = { nounwind }
Index: llvm/test/CodeGen/X86/AMX/amx-config.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/X86/AMX/amx-config.ll
@@ -0,0 +1,75 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+amx-int8 -verify-machineinstrs | FileCheck %s
+
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+@buf = dso_local global [1024 x i8] zeroinitializer, align 16
+@buf2 = dso_local global [1024 x i8] zeroinitializer, align 16
+
+; Function Attrs: nounwind uwtable
+define dso_local void @test_api(i32 %0, i16 signext %1, i16 signext %2) local_unnamed_addr #2 {
+; CHECK-LABEL: test_api:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    movsbl %sil, %eax
+; CHECK-NEXT:    vpxord %zmm0, %zmm0, %zmm0
+; CHECK-NEXT:    vmovdqu64 %zmm0, -{{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movb %al, -{{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movw %si, -{{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movb %al, -{{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movw %dx, -{{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movb %al, -{{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movw %dx, -{{[0-9]+}}(%rsp)
+; CHECK-NEXT:    ldtilecfg -{{[0-9]+}}(%rsp)
+; CHECK-NEXT:    testl %edi, %edi
+; CHECK-NEXT:    je .LBB0_2
+; CHECK-NEXT:  # %bb.1:
+; CHECK-NEXT:    movl $buf, %ecx
+; CHECK-NEXT:    jmp .LBB0_3
+; CHECK-NEXT:  .LBB0_2:
+; CHECK-NEXT:    movl $buf2, %ecx
+; CHECK-NEXT:  .LBB0_3:
+; CHECK-NEXT:    movl $32, %edi
+; CHECK-NEXT:    tileloadd (%rcx,%rdi), %tmm0
+; CHECK-NEXT:    tileloadd (%rcx,%rdi), %tmm2
+; CHECK-NEXT:    tileloadd (%rcx,%rdi), %tmm1
+; CHECK-NEXT:    tdpbssd %tmm2, %tmm0, %tmm1
+; CHECK-NEXT:    movl $buf, %ecx
+; CHECK-NEXT:    movl $32, %esi
+; CHECK-NEXT:    tilestored %tmm1, (%rcx,%rsi)
+; CHECK-NEXT:    vzeroupper
+; CHECK-NEXT:    retq
+  %4 = icmp eq i32 %0, 0
+  %5 = shl i16 %1, 8
+  %6 = ashr exact i16 %5, 8
+  br i1 %4, label %11, label %7
+
+7:                                                ; preds = %3
+  %8 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %9 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  %10 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32) #3
+  br label %15
+
+11:                                               ; preds = %3
+  %12 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %1, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
+  %13 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
+  %14 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf2, i64 0, i64 0), i64 32) #3
+  br label %15
+
+15:                                               ; preds = %11, %7
+  %16 = phi <256 x i32> [ %12, %11 ], [ %8, %7 ]
+  %17 = phi <256 x i32> [ %13, %11 ], [ %9, %7 ]
+  %18 = phi <256 x i32> [ %14, %11 ], [ %10, %7 ]
+  %19 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %6, i16 %2, i16 %1, <256 x i32> %18, <256 x i32> %16, <256 x i32> %17) #3
+  tail call void @llvm.x86.tilestored64.internal(i16 %6, i16 %2, i8* getelementptr inbounds ([1024 x i8], [1024 x i8]* @buf, i64 0, i64 0), i64 32, <256 x i32> %19) #3
+  ret void
+}
+
+declare <256 x i32> @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #3
+
+declare <256 x i32> @llvm.x86.tdpbssd.internal(i16, i16, i16, <256 x i32>, <256 x i32>, <256 x i32>) #3
+
+declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, <256 x i32>) #3
+
+attributes #2 = { nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="8192" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+avx,+avx2,+avx512f,+cx8,+f16c,+fma,+fxsr,+mmx,+popcnt,+sse,+sse2,+sse3,+sse4.1,+sse4.2,+ssse3,+x87,+xsave" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
+attributes #3 = { nounwind }
Index: llvm/test/CodeGen/X86/AMX/amx-across-func.ll
===================================================================
--- /dev/null
+++ llvm/test/CodeGen/X86/AMX/amx-across-func.ll
@@ -0,0 +1,89 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py
+; RUN: llc < %s -mtriple=x86_64-unknown-unknown -mattr=+amx-int8 -mattr=+avx512f -verify-machineinstrs | FileCheck %s
+target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128"
+target triple = "x86_64-unknown-linux-gnu"
+
+%struct.__tile_str = type <{ i16, i16, [60 x i8], <256 x i32> }>
+
+@buf = dso_local global [3072 x i8] zeroinitializer, align 16
+
+define dso_local void @test_api(i16 signext %0, i16 signext %1) local_unnamed_addr #2 {
+; CHECK-LABEL: test_api:
+; CHECK:       # %bb.0:
+; CHECK-NEXT:    pushq %rbp
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    pushq %r15
+; CHECK-NEXT:    .cfi_def_cfa_offset 24
+; CHECK-NEXT:    pushq %r14
+; CHECK-NEXT:    .cfi_def_cfa_offset 32
+; CHECK-NEXT:    pushq %rbx
+; CHECK-NEXT:    .cfi_def_cfa_offset 40
+; CHECK-NEXT:    subq $4056, %rsp # imm = 0xFD8
+; CHECK-NEXT:    .cfi_def_cfa_offset 4096
+; CHECK-NEXT:    .cfi_offset %rbx, -40
+; CHECK-NEXT:    .cfi_offset %r14, -32
+; CHECK-NEXT:    .cfi_offset %r15, -24
+; CHECK-NEXT:    .cfi_offset %rbp, -16
+; CHECK-NEXT:    movl %esi, %ebx
+; CHECK-NEXT:    movl %edi, %ebp
+; CHECK-NEXT:    vpxord %zmm0, %zmm0, %zmm0
+; CHECK-NEXT:    vmovdqu64 %zmm0, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movb %bpl, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movw %bx, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movb %bpl, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movw $8, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movb $8, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    movw %bx, {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    ldtilecfg {{[0-9]+}}(%rsp)
+; CHECK-NEXT:    sttilecfg {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Folded Spill
+; CHECK-NEXT:    movl $buf, %eax
+; CHECK-NEXT:    movl $32, %r14d
+; CHECK-NEXT:    movw $8, %r15w
+; CHECK-NEXT:    tileloadd (%rax,%r14), %tmm1
+; CHECK-NEXT:    movabsq $64, %rax
+; CHECK-NEXT:    tilestored %tmm1, 2048(%rsp,%rax) # 1024-byte Folded Spill
+; CHECK-NEXT:    movl $buf+1024, %eax
+; CHECK-NEXT:    tileloadd (%rax,%r14), %tmm2
+; CHECK-NEXT:    movabsq $64, %rax
+; CHECK-NEXT:    tilestored %tmm2, 1024(%rsp,%rax) # 1024-byte Folded Spill
+; CHECK-NEXT:    xorl %eax, %eax
+; CHECK-NEXT:    vzeroupper
+; CHECK-NEXT:    callq foo
+; CHECK-NEXT:    movl $buf+2048, %eax
+; CHECK-NEXT:    ldtilecfg {{[-0-9]+}}(%r{{[sb]}}p) # 64-byte Folded Reload
+; CHECK-NEXT:    tileloadd (%rax,%r14), %tmm0
+; CHECK-NEXT:    movabsq $64, %rcx
+; CHECK-NEXT:    tileloadd 2048(%rsp,%rcx), %tmm1 # 1024-byte Folded Reload
+; CHECK-NEXT:    movabsq $64, %rcx
+; CHECK-NEXT:    tileloadd 1024(%rsp,%rcx), %tmm2 # 1024-byte Folded Reload
+; CHECK-NEXT:    tdpbssd %tmm2, %tmm1, %tmm0
+; CHECK-NEXT:    tilestored %tmm0, (%rax,%r14)
+; CHECK-NEXT:    addq $4056, %rsp # imm = 0xFD8
+; CHECK-NEXT:    .cfi_def_cfa_offset 40
+; CHECK-NEXT:    popq %rbx
+; CHECK-NEXT:    .cfi_def_cfa_offset 32
+; CHECK-NEXT:    popq %r14
+; CHECK-NEXT:    .cfi_def_cfa_offset 24
+; CHECK-NEXT:    popq %r15
+; CHECK-NEXT:    .cfi_def_cfa_offset 16
+; CHECK-NEXT:    popq %rbp
+; CHECK-NEXT:    .cfi_def_cfa_offset 8
+; CHECK-NEXT:    retq
+  %3 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %0, i16 8, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 0), i64 32) #4
+  %4 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 8, i16 %1, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 1024), i64 32) #4
+  tail call void (...) @foo() #4
+  %5 = tail call <256 x i32> @llvm.x86.tileloadd64.internal(i16 %0, i16 %1, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 2048), i64 32) #4
+  %6 = tail call <256 x i32> @llvm.x86.tdpbssd.internal(i16 %0, i16 %1, i16 8, <256 x i32> %5, <256 x i32> %3, <256 x i32> %4) #4
+  tail call void @llvm.x86.tilestored64.internal(i16 %0, i16 %1, i8* getelementptr inbounds ([3072 x i8], [3072 x i8]* @buf, i64 0, i64 2048), i64 32, <256 x i32> %6) #4
+  ret void
+}
+
+declare dso_local void @foo(...) local_unnamed_addr #3
+
+declare <256 x i32> @llvm.x86.tileloadd64.internal(i16, i16, i8*, i64) #4
+declare <256 x i32> @llvm.x86.tdpbssd.internal(i16, i16, i16, <256 x i32>, <256 x i32>, <256 x i32>) #4
+declare void @llvm.x86.tilestored64.internal(i16, i16, i8*, i64, <256 x i32>) #4
+
+attributes #2 = { nounwind uwtable "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "min-legal-vector-width"="8192" "no-infs-fp-math"="false" "no-jump-tables"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
+attributes #3 = { "correctly-rounded-divide-sqrt-fp-math"="false" "disable-tail-calls"="false" "frame-pointer"="none" "less-precise-fpmad"="false" "no-infs-fp-math"="false" "no-nans-fp-math"="false" "no-signed-zeros-fp-math"="false" "no-trapping-math"="true" "stack-protector-buffer-size"="8" "target-cpu"="x86-64" "target-features"="+amx-int8,+amx-tile,+cx8,+fxsr,+mmx,+sse,+sse2,+x87" "tune-cpu"="generic" "unsafe-fp-math"="false" "use-soft-float"="false" }
+attributes #4 = { nounwind }
Index: llvm/lib/Target/X86/X86TileConfig.cpp
===================================================================
--- /dev/null
+++ llvm/lib/Target/X86/X86TileConfig.cpp
@@ -0,0 +1,228 @@
+//===-- X86TileConfig.cpp - Tile Register Configure ------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
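+// This pass fills in the shape (row and column) of each assigned physical
+// tile register into the stack slot that the inserted ldtilecfg reads.
+//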
+
+#include "X86.h"
+#include "X86InstrBuilder.h"
+#include "X86MachineFunctionInfo.h"
+#include "X86RegisterInfo.h"
+#include "X86Subtarget.h"
+#include "llvm/CodeGen/LiveIntervals.h"
+#include "llvm/CodeGen/MachineDominators.h"
+#include "llvm/CodeGen/MachineFrameInfo.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/CodeGen/MachineInstr.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/CodeGen/TargetInstrInfo.h"
+#include "llvm/CodeGen/TargetRegisterInfo.h"
+#include "llvm/CodeGen/TileShapeInfo.h"
+#include "llvm/CodeGen/VirtRegMap.h"
+#include "llvm/InitializePasses.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "tile-config"
+
+namespace {
+
+class X86TileConfig : public MachineFunctionPass {
+  // context
+  MachineFunction *MF = nullptr;
+  const X86Subtarget *ST = nullptr;
+  const TargetRegisterInfo *TRI = nullptr;
+  const TargetInstrInfo *TII = nullptr;
+  MachineDominatorTree *DomTree = nullptr;
+  MachineRegisterInfo *MRI = nullptr;
+  VirtRegMap *VRM = nullptr;
+  LiveIntervals *LIS = nullptr;
+
+  MachineInstr *getTileConfigPoint();
+  void tileConfig();
+
+public:
+  X86TileConfig() : MachineFunctionPass(ID) {}
+
+  /// Return the pass name.
+  StringRef getPassName() const override { return "Tile Register Configure"; }
+
+  /// X86TileConfig analysis usage.
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
+
+  /// Configure the physical tile registers.
+  bool runOnMachineFunction(MachineFunction &mf) override;
+
+  MachineFunctionProperties getRequiredProperties() const override {
+    return MachineFunctionProperties().set(
+        MachineFunctionProperties::Property::NoPHIs);
+  }
+
+  static char ID;
+};
+
+} // end anonymous namespace
+
+char X86TileConfig::ID = 0;
+
+INITIALIZE_PASS_BEGIN(X86TileConfig, "tileconfig", "Tile Register Configure",
+                      false, false)
+INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree)
+INITIALIZE_PASS_DEPENDENCY(VirtRegMap)
+INITIALIZE_PASS_END(X86TileConfig, "tileconfig", "Tile Register Configure",
+                    false, false)
+
+void X86TileConfig::getAnalysisUsage(AnalysisUsage &AU) const {
+  AU.addRequired<MachineDominatorTree>();
+  AU.addRequired<LiveIntervals>();
+  AU.addPreserved<SlotIndexes>();
+  AU.addRequired<VirtRegMap>();
+  AU.setPreservesAll();
+  MachineFunctionPass::getAnalysisUsage(AU);
+}
+
+static unsigned getTilePhysRegIndex(Register PhysReg) {
+  assert((PhysReg >= X86::TMM0 && PhysReg <= X86::TMM7) &&
+         "Tile register number is invalid");
+  return (PhysReg - X86::TMM0);
+}
+
+static MachineInstr *
+storeRegToStackSlot(MachineBasicBlock &MBB, MachineBasicBlock::iterator MI,
+                    Register SrcReg, unsigned BitSize, int FrameIdx, int Offset,
+                    const TargetInstrInfo *TII, const TargetRegisterClass *RC,
+                    const TargetRegisterInfo *TRI) {
+
+  unsigned SubIdx = (BitSize == 8) ? X86::sub_8bit : X86::sub_16bit;
+  unsigned Opc = (BitSize == 8) ? X86::MOV8mr : X86::MOV16mr;
+  MachineInstr *NewMI =
+      addFrameReference(BuildMI(MBB, MI, DebugLoc(), TII->get(Opc)), FrameIdx,
+                        Offset)
+          .addReg(SrcReg);
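+  // Operand 5 is SrcReg, after the five frame-reference memory operands.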
+  MachineOperand &MO = NewMI->getOperand(5);
+  if (BitSize < TRI->getRegSizeInBits(*RC))
+    MO.setSubReg(SubIdx);
+  return NewMI;
+}
+
+static MachineInstr *storeImmToStackSlot(MachineBasicBlock &MBB,
+                                         MachineBasicBlock::iterator MI,
+                                         int64_t Imm, unsigned BitSize,
+                                         int FrameIdx, int Offset,
+                                         const TargetInstrInfo *TII) {
+  unsigned Opc = (BitSize == 8) ? X86::MOV8mi : X86::MOV16mi;
+  return addFrameReference(BuildMI(MBB, MI, DebugLoc(), TII->get(Opc)),
+                           FrameIdx, Offset)
+      .addImm(Imm);
+}
+
+MachineInstr *X86TileConfig::getTileConfigPoint() {
+  for (MachineBasicBlock &MBB : *MF) {
+
+    // Traverse the basic block.
+    for (MachineInstr &MI : MBB)
+      // Refer to X86PreTileConfig.cpp.
+      // We only support one tile config for now.
+      if (MI.getOpcode() == X86::PLDTILECFG)
+        return &MI;
+  }
+
+  return nullptr;
+}
+
+void X86TileConfig::tileConfig() {
+  MachineInstr *MI = getTileConfigPoint();
+  if (!MI)
+    return;
+  MachineBasicBlock *MBB = MI->getParent();
+  int SS = MI->getOperand(1).getIndex();
+  BitVector PhysRegs(TRI->getNumRegs());
+
+  // Record the shape of each virtual tile register in the config stack slot.
+  for (unsigned i = 0, e = MRI->getNumVirtRegs(); i != e; ++i) {
+    unsigned VirtReg = Register::index2VirtReg(i);
+    if (MRI->reg_nodbg_empty(VirtReg))
+      continue;
+    const TargetRegisterClass &RC = *MRI->getRegClass(VirtReg);
+    if (RC.getID() != X86::TILERegClassID)
+      continue;
+    Register PhysReg = VRM->getPhys(VirtReg);
+    if (PhysRegs.test(PhysReg))
+      continue;
+    PhysRegs.set(PhysReg);
+    ShapeT Shape = VRM->getShape(VirtReg);
+    Register RowReg = Shape.getRow()->getReg();
+    Register ColReg = Shape.getCol()->getReg();
+
+    unsigned Index = getTilePhysRegIndex(PhysReg);
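+    // In the 64-byte tile config, each tile's row count (1 byte) lives at
+    // byte 48 + Index and its column byte-size (2 bytes) at byte 16 + 2*Index.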
+    int RowOffset = 48 + Index;
+    int ColOffset = 16 + Index * 2;
+
+    unsigned BitSize = 8;
+    for (const auto &Pair : {std::make_pair(RowReg, RowOffset),
+                             std::make_pair(ColReg, ColOffset)}) {
+      int64_t Imm;
+      int ImmCount = 0;
+      // All defs must be the same value; otherwise the MIs are invalid.
+      // An immediate is preferred.
+      for (const MachineOperand &MO : MRI->def_operands(Pair.first)) {
+        const auto *Inst = MO.getParent();
+        if (Inst->isMoveImmediate()) {
+          ImmCount++;
+          Imm = Inst->getOperand(1).getImm();
+          break;
+        }
+      }
+      auto StoreConfig = [&](int Offset) {
+        MachineInstr *NewMI = nullptr;
+        if (ImmCount)
+          NewMI = storeImmToStackSlot(*MBB, *MI, Imm, BitSize, SS, Offset, TII);
+        else {
+          const TargetRegisterClass *RC = MRI->getRegClass(Pair.first);
+          NewMI = storeRegToStackSlot(*MBB, *MI, Pair.first, BitSize, SS,
+                                      Offset, TII, RC, TRI);
+        }
+        SlotIndex SIdx = LIS->InsertMachineInstrInMaps(*NewMI);
+        if (!ImmCount) {
+          // Extend the live interval.
+          SmallVector<SlotIndex, 8> EndPoints = {SIdx.getRegSlot()};
+          LiveInterval &Int = LIS->getInterval(Pair.first);
+          LIS->extendToIndices(Int, EndPoints);
+        }
+      };
+      StoreConfig(Pair.second);
+      BitSize += 8;
+    }
+  }
+}
+
+bool X86TileConfig::runOnMachineFunction(MachineFunction &mf) {
+  LLVM_DEBUG(dbgs() << "********** TILE REGISTER CONFIGURE **********\n"
+                    << "********** Function: " << mf.getName() << '\n');
+  MF = &mf;
+  MRI = &mf.getRegInfo();
+  ST = &mf.getSubtarget<X86Subtarget>();
+  TRI = ST->getRegisterInfo();
+  TII = mf.getSubtarget().getInstrInfo();
+  DomTree = &getAnalysis<MachineDominatorTree>();
+  VRM = &getAnalysis<VirtRegMap>();
+  LIS = &getAnalysis<LiveIntervals>();
+
+  if (VRM->isShapeMapEmpty())
+    return false;
+
+  tileConfig();
+  return true;
+}
+
+FunctionPass *llvm::createX86TileConfigPass() { return new X86TileConfig(); }
Index: llvm/lib/Target/X86/X86TargetMachine.cpp
===================================================================
--- llvm/lib/Target/X86/X86TargetMachine.cpp
+++ llvm/lib/Target/X86/X86TargetMachine.cpp
@@ -62,6 +62,7 @@
   RegisterTargetMachine<X86TargetMachine> Y(getTheX86_64Target());
 
   PassRegistry &PR = *PassRegistry::getPassRegistry();
+  initializeX86LowerAMXTypeLegacyPassPass(PR);
   initializeGlobalISel(PR);
   initializeWinEHStatePassPass(PR);
   initializeFixupBWInstPassPass(PR);
@@ -71,6 +72,7 @@
   initializeX86FixupSetCCPassPass(PR);
   initializeX86CallFrameOptimizationPass(PR);
   initializeX86CmovConverterPassPass(PR);
+  initializeX86TileConfigPass(PR);
   initializeX86ExpandPseudoPass(PR);
   initializeX86ExecutionDomainFixPass(PR);
   initializeX86DomainReassignmentPass(PR);
@@ -378,6 +380,7 @@
   void addPreEmitPass() override;
   void addPreEmitPass2() override;
   void addPreSched2() override;
+  bool addPreRewrite() override;
 
   std::unique_ptr<CSEConfigBase> getCSEConfig() const override;
 };
@@ -406,6 +409,7 @@
 
 void X86PassConfig::addIRPasses() {
   addPass(createAtomicExpandPass());
+  addPass(createX86LowerAMXTypePass());
 
   TargetPassConfig::addIRPasses();
 
@@ -491,7 +495,11 @@
   addPass(createX86SpeculativeLoadHardeningPass());
   addPass(createX86FlagsCopyLoweringPass());
   addPass(createX86WinAllocaExpander());
+
+  if (getOptLevel() != CodeGenOpt::None)
+    addPass(createX86PreTileConfigPass());
 }
+
 void X86PassConfig::addMachineSSAOptimization() {
   addPass(createX86DomainReassignmentPass());
   TargetPassConfig::addMachineSSAOptimization();
@@ -564,6 +573,11 @@
   addPass(createX86LoadValueInjectionRetHardeningPass());
 }
 
+bool X86PassConfig::addPreRewrite() {
+  addPass(createX86TileConfigPass());
+  return true;
+}
+
 std::unique_ptr<CSEConfigBase> X86PassConfig::getCSEConfig() const {
   return getStandardCSEConfigForOpt(TM->getOptLevel());
 }
Index: llvm/lib/Target/X86/X86Subtarget.h
===================================================================
--- llvm/lib/Target/X86/X86Subtarget.h
+++ llvm/lib/Target/X86/X86Subtarget.h
@@ -469,6 +469,8 @@
   /// entry to the function and which must be maintained by every function.
   Align stackAlignment = Align(4);
 
+  Align tileConfigAlignment = Align(4);
+
   /// Max. memset / memcpy size that is turned into rep/movs, rep/stos ops.
   ///
   // FIXME: this is a known good value for Yonah. How about others?
@@ -552,6 +554,9 @@
     return &getInstrInfo()->getRegisterInfo();
   }
 
+  unsigned getTileConfigSize() const { return 64; }
+  Align getTileConfigAlignment() const { return tileConfigAlignment; }
+
   /// Returns the minimum alignment known to hold of the
   /// stack frame on entry to the function and which must be maintained by every
   /// function for this subtarget.
Index: llvm/lib/Target/X86/X86RegisterInfo.td
===================================================================
--- llvm/lib/Target/X86/X86RegisterInfo.td
+++ llvm/lib/Target/X86/X86RegisterInfo.td
@@ -265,6 +265,9 @@
   }
 }
 
+// Tile config register.
+def TMMCFG: X86Reg<"tmmcfg", 0>;
+
 // Tile "registers".
 def TMM0:  X86Reg<"tmm0",   0>;
 def TMM1:  X86Reg<"tmm1",   1>;
@@ -633,6 +636,11 @@
 def BNDR : RegisterClass<"X86", [v2i64], 128, (sequence "BND%u", 0, 3)>;
 
 // Tiles
-let isAllocatable = 0 in
-def TILE : RegisterClass<"X86", [untyped], 0,
+let CopyCost = -1 in // Don't allow copying of tile registers.
+def TILE : RegisterClass<"X86", [v256i32], 8192,
                          (sequence "TMM%u", 0, 7)> {let Size = 8192;}
+def TILECFG : RegisterClass<"X86", [v512i1], 512, (add TMMCFG)> {
+  let CopyCost = -1;  // Don't allow copying of tile config registers.
+  let isAllocatable = 1;
+  let Size = 512;
+}
Index: llvm/lib/Target/X86/X86RegisterInfo.h
===================================================================
--- llvm/lib/Target/X86/X86RegisterInfo.h
+++ llvm/lib/Target/X86/X86RegisterInfo.h
@@ -144,6 +144,11 @@
   Register getFramePtr() const { return FramePtr; }
   // FIXME: Move to FrameInfok
   unsigned getSlotSize() const { return SlotSize; }
+
+  bool getRegAllocationHints(Register VirtReg, ArrayRef<MCPhysReg> Order,
+                             SmallVectorImpl<MCPhysReg> &Hints,
+                             const MachineFunction &MF, const VirtRegMap *VRM,
+                             const LiveRegMatrix *Matrix) const override;
 };
 
 } // End llvm namespace
Index: llvm/lib/Target/X86/X86RegisterInfo.cpp
===================================================================
--- llvm/lib/Target/X86/X86RegisterInfo.cpp
+++ llvm/lib/Target/X86/X86RegisterInfo.cpp
@@ -19,6 +19,7 @@
 #include "llvm/ADT/BitVector.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallSet.h"
+#include "llvm/CodeGen/LiveRegMatrix.h"
 #include "llvm/CodeGen/MachineFrameInfo.h"
 #include "llvm/CodeGen/MachineFunction.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
@@ -855,3 +856,79 @@
     StackReg = getX86SubSuperRegister(StackReg, 32);
   return StackReg;
 }
+
+static ShapeT getTileShape(Register VirtReg, VirtRegMap *VRM,
+                           const MachineRegisterInfo *MRI) {
+  if (VRM->hasShape(VirtReg))
+    return VRM->getShape(VirtReg);
+
+  const MachineOperand &Def = *MRI->def_begin(VirtReg);
+  MachineInstr *MI = const_cast<MachineInstr *>(Def.getParent());
+  unsigned OpCode = MI->getOpcode();
+  switch (OpCode) {
+  default:
+    llvm_unreachable("Unexpected machine instruction on tile register!");
+    break;
+  // We only collect the tile shape at the tile's definition.
+  case X86::PTILELOADDV:
+  case X86::PTDPBSSDV:
+    MachineOperand &MO1 = MI->getOperand(1);
+    MachineOperand &MO2 = MI->getOperand(2);
+    ShapeT Shape(&MO1, &MO2, MRI);
+    VRM->assignVirt2Shape(VirtReg, Shape);
+    return Shape;
+  }
+}
+
+bool X86RegisterInfo::getRegAllocationHints(Register VirtReg,
+                                            ArrayRef<MCPhysReg> Order,
+                                            SmallVectorImpl<MCPhysReg> &Hints,
+                                            const MachineFunction &MF,
+                                            const VirtRegMap *VRM,
+                                            const LiveRegMatrix *Matrix) const {
+  const MachineRegisterInfo *MRI = &MF.getRegInfo();
+  const TargetRegisterClass &RC = *MRI->getRegClass(VirtReg);
+  bool BaseImplRetVal = TargetRegisterInfo::getRegAllocationHints(
+      VirtReg, Order, Hints, MF, VRM, Matrix);
+
+  if (RC.getID() != X86::TILERegClassID)
+    return BaseImplRetVal;
+
+  ShapeT VirtShape = getTileShape(VirtReg, const_cast<VirtRegMap *>(VRM), MRI);
+  auto AddHint = [&](MCPhysReg PhysReg) {
+    Register VReg = Matrix->getOneVReg(PhysReg);
+    if (VReg == MCRegister::NoRegister) { // Not allocated yet
+      Hints.push_back(PhysReg);
+      return;
+    }
+    ShapeT PhysShape = getTileShape(VReg, const_cast<VirtRegMap *>(VRM), MRI);
+    if (PhysShape == VirtShape)
+      Hints.push_back(PhysReg);
+  };
+
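+  // Re-filter the copy hints from the base implementation: keep only
+  // unreserved tile registers with a matching shape, then try the order.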
+  SmallSet<unsigned, 4> CopyHints;
+  CopyHints.insert(Hints.begin(), Hints.end());
+  Hints.clear();
+  for (auto Hint : CopyHints) {
+    if (RC.contains(Hint) && !MRI->isReserved(Hint))
+      AddHint(Hint);
+  }
+  for (MCPhysReg PhysReg : Order) {
+    if (!MRI->isReserved(PhysReg))
+      AddHint(PhysReg);
+  }
+
+#define DEBUG_TYPE "tile-hint"
+  LLVM_DEBUG({
+    dbgs() << "Hints for virtual register " << format_hex(VirtReg, 8) << "\n";
+    for (auto Hint : Hints) {
+      dbgs() << "tmm" << Hint << ",";
+    }
+    dbgs() << "\n";
+  });
+#undef DEBUG_TYPE
+
+  return true;
+}
Index: llvm/lib/Target/X86/X86PreTileConfig.cpp
===================================================================
--- /dev/null
+++ llvm/lib/Target/X86/X86PreTileConfig.cpp
@@ -0,0 +1,228 @@
+//===-- X86PreTileConfig.cpp - Tile Register Pre-configure -----------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements the X86PreTileConfig pass, which inserts the pseudo
+// ldtilecfg instruction and makes the resulting config register an operand of
+// every AMX instruction, so that the configuration dominates all tile uses.
+//
+//===----------------------------------------------------------------------===//
+
+#include "X86.h"
+#include "X86InstrBuilder.h"
+#include "X86RegisterInfo.h"
+#include "X86Subtarget.h"
+#include "llvm/CodeGen/MachineDominators.h"
+#include "llvm/CodeGen/MachineFunctionPass.h"
+#include "llvm/CodeGen/MachineInstr.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/CodeGen/TargetInstrInfo.h"
+#include "llvm/CodeGen/TargetRegisterInfo.h"
+#include "llvm/CodeGen/TileShapeInfo.h"
+#include "llvm/InitializePasses.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "tile-pre-config"
+
+namespace {
+
+class X86PreTileConfig : public MachineFunctionPass {
+  // Context.
+  MachineFunction *MF = nullptr;
+  const X86Subtarget *ST = nullptr;
+  const TargetRegisterInfo *TRI = nullptr;
+  const TargetInstrInfo *TII = nullptr;
+  MachineDominatorTree *DomTree = nullptr;
+  MachineRegisterInfo *MRI = nullptr;
+
+  MachineInstr *getTileConfigPoint();
+
+public:
+  X86PreTileConfig() : MachineFunctionPass(ID) {}
+
+  /// Return the pass name.
+  StringRef getPassName() const override {
+    return "Tile Register Pre-configure";
+  }
+
+  /// X86PreTileConfig analysis usage.
+  void getAnalysisUsage(AnalysisUsage &AU) const override;
+
+  /// Perform the tile pre-configuration.
+  bool runOnMachineFunction(MachineFunction &mf) override;
+
+  static char ID;
+};
+
+} // end anonymous namespace
+
+char X86PreTileConfig::ID = 0;
+
+INITIALIZE_PASS_BEGIN(X86PreTileConfig, "tilepreconfig",
+                      "Tile Register Pre-configure", false, false)
+INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree)
+INITIALIZE_PASS_END(X86PreTileConfig, "tilepreconfig",
+                    "Tile Register Pre-configure", false, false)
+
+void X86PreTileConfig::getAnalysisUsage(AnalysisUsage &AU) const {
+  AU.setPreservesAll();
+  AU.addRequired<MachineDominatorTree>();
+  MachineFunctionPass::getAnalysisUsage(AU);
+}
+
+static Register buildConfigMI(MachineBasicBlock::iterator MI, int FrameIdx,
+                              const TargetInstrInfo *TII,
+                              MachineRegisterInfo *MRI) {
+  auto *MBB = MI->getParent();
+
+  // Zero the stack slot. We assume AVX-512 is available here.
+  Register Zmm = MRI->createVirtualRegister(&X86::VR512RegClass);
+  BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VPXORDZrr), Zmm)
+      .addReg(Zmm, RegState::Undef)
+      .addReg(Zmm, RegState::Undef);
+  addFrameReference(BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::VMOVUPSZmr)),
+                    FrameIdx)
+      .addReg(Zmm);
+
+  // Build the pseudo ldtilecfg instruction.
+  Register VReg = MRI->createVirtualRegister(&X86::TILECFGRegClass);
+
+  addFrameReference(
+      BuildMI(*MBB, MI, DebugLoc(), TII->get(X86::PLDTILECFG), VReg), FrameIdx);
+
+  return VReg;
+}
+
+static ShapeT getShape(const MachineInstr &MI, MachineRegisterInfo *MRI) {
+  unsigned Opcode = MI.getOpcode();
+  switch (Opcode) {
+  default:
+    llvm_unreachable("Unexpected machine instruction on tile");
+  case X86::PTILELOADDV:
+  case X86::PTDPBSSDV:
+    MachineOperand &MO1 = const_cast<MachineOperand &>(MI.getOperand(1));
+    MachineOperand &MO2 = const_cast<MachineOperand &>(MI.getOperand(2));
+    ShapeT Shape(&MO1, &MO2, MRI);
+    return Shape;
+  }
+}
+
+MachineInstr *X86PreTileConfig::getTileConfigPoint() {
+  DenseMap<Register, ShapeT> PhysShapeInfo;
+  MachineBasicBlock *MBB = nullptr;
+  DenseSet<const MachineInstr *> MIs;
+  for (unsigned i = 0, e = MRI->getNumVirtRegs(); i != e; ++i) {
+    unsigned VirtReg = Register::index2VirtReg(i);
+    if (MRI->reg_nodbg_empty(VirtReg))
+      continue;
+    const TargetRegisterClass &RC = *MRI->getRegClass(VirtReg);
+    if (RC.getID() != X86::TILERegClassID)
+      continue;
+
+    // Find the common dominator for all MIs that define a tile register.
+    for (const MachineOperand &MO : MRI->def_operands(VirtReg)) {
+      if (MO.isUndef())
+        continue;
+      const auto *MI = MO.getParent();
+      // Skip PHI or IMPLICIT_DEF instructions; an input tile must be
+      // defined before any PHI instruction.
+      if (MI->isTransient())
+        continue;
+      if (!MBB)
+        MBB = const_cast<MachineBasicBlock *>(MI->getParent());
+      MBB = DomTree->findNearestCommonDominator(
+          MBB, const_cast<MachineBasicBlock *>(MI->getParent()));
+
+      // Collect the instructions that define shape.
+      ShapeT Shape = getShape(*MI, MRI);
+      std::array<MachineOperand *, 2> ShapeMOs = {Shape.getRow(),
+                                                  Shape.getCol()};
+      for (auto *ShapeMO : ShapeMOs) {
+        Register ShapeReg = ShapeMO->getReg();
+        for (const MachineOperand &MO : MRI->def_operands(ShapeReg)) {
+          const auto *ShapeMI = MO.getParent();
+          MIs.insert(ShapeMI);
+        }
+      }
+    }
+  }
+  if (!MBB)
+    return nullptr;
+  // The shape defs should dominate the tile-config MBB.
+  // TODO: Improve for shapes that are immediates.
+  for (const auto *MI : MIs) {
+    const MachineBasicBlock *ShapeMBB = MI->getParent();
+    if (DomTree->dominates(ShapeMBB, MBB))
+      continue;
+    if (MI->isMoveImmediate())
+      continue;
+    report_fatal_error("Failed to config tile register, "
+                       "please define the shape earlier");
+  }
+
+  // ldtilecfg should be inserted after the MIs that define the shapes.
+  MachineBasicBlock::reverse_instr_iterator I, E;
+  for (I = MBB->instr_rbegin(), E = MBB->instr_rend(); I != E; ++I) {
+    auto *MI = &*I;
+    if (MIs.count(MI) && (!MI->isMoveImmediate()))
+      break;
+  }
+  MachineBasicBlock::iterator MII;
+  if (I == E)
+    MII = MBB->getFirstNonPHI();
+  else {
+    MII = MachineBasicBlock::iterator(&*I);
+    MII++;
+  }
+  return &*MII;
+}
+
+static void addTileCFGUse(MachineFunction &mf, Register cfg) {
+  for (MachineBasicBlock &MBB : mf) {
+
+    // Traverse the basic block.
+    for (MachineInstr &MI : MBB) {
+      unsigned Opcode = MI.getOpcode();
+      switch (Opcode) {
+      default:
+        break;
+      case X86::PTILELOADDV:
+      case X86::PTILESTOREDV:
+      case X86::PTDPBSSDV:
+        unsigned NumOperands = MI.getNumOperands();
+        MI.RemoveOperand(NumOperands - 1);
+        MI.addOperand(mf, MachineOperand::CreateReg(cfg, false));
+        break;
+      }
+    }
+  }
+}
+
+bool X86PreTileConfig::runOnMachineFunction(MachineFunction &mf) {
+  LLVM_DEBUG(dbgs() << "********** TILE REGISTER CONFIGURE **********\n"
+                    << "********** Function: " << mf.getName() << '\n');
+  MF = &mf;
+  MRI = &mf.getRegInfo();
+  ST = &mf.getSubtarget<X86Subtarget>();
+  TRI = ST->getRegisterInfo();
+  TII = mf.getSubtarget().getInstrInfo();
+  DomTree = &getAnalysis<MachineDominatorTree>();
+
+  MachineInstr *MI = getTileConfigPoint();
+  if (!MI)
+    return false;
+  unsigned Size = ST->getTileConfigSize();
+  Align Alignment = ST->getTileConfigAlignment();
+  int SS = mf.getFrameInfo().CreateStackObject(Size, Alignment, false);
+  Register CFG = buildConfigMI(MI, SS, TII, MRI);
+  addTileCFGUse(mf, CFG);
+  return true;
+}
+
+FunctionPass *llvm::createX86PreTileConfigPass() {
+  return new X86PreTileConfig();
+}
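
The placement rule implemented by getTileConfigPoint() above, condensed: the
config block is the nearest common dominator of all non-PHI tile defs; every
instruction defining a shape register must dominate that block (move
immediates are exempt), otherwise the pass reports a fatal error; within the
block, the pseudo ldtilecfg goes right after the last shape def, or at the
first position when the block contains no shape def. An illustrative C++
sketch over toy types (Block, Inst and DomTree are stand-ins, not LLVM API;
PHIs are ignored for simplicity):

  #include <cstddef>
  #include <utility>
  #include <vector>

  struct Block;
  struct Inst {
    Block *Parent;
    bool DefinesShape;
    bool IsMoveImm;
  };
  struct Block {
    std::vector<Inst *> Insts;
  };
  struct DomTree { // Declarations only; a real dominator tree goes here.
    Block *commonDominator(Block *A, Block *B);
    bool dominates(Block *A, Block *B);
  };

  // Returns the block and the index to insert the config pseudo before, or
  // {nullptr, 0} if there are no tile defs or a shape def fails to dominate.
  std::pair<Block *, size_t>
  configPoint(const std::vector<Inst *> &TileDefs,
              const std::vector<Inst *> &ShapeDefs, DomTree &DT) {
    Block *MBB = nullptr;
    for (Inst *I : TileDefs) // 1. Nearest common dominator of tile defs.
      MBB = MBB ? DT.commonDominator(MBB, I->Parent) : I->Parent;
    if (!MBB)
      return {nullptr, 0};
    for (Inst *I : ShapeDefs) // 2. Shape defs must dominate the block.
      if (!I->IsMoveImm && !DT.dominates(I->Parent, MBB))
        return {nullptr, 0};  // The real pass emits a fatal error here.
    size_t Idx = 0;           // 3. Insert after the last shape def in MBB.
    for (size_t i = 0; i != MBB->Insts.size(); ++i)
      if (MBB->Insts[i]->DefinesShape && !MBB->Insts[i]->IsMoveImm)
        Idx = i + 1;
    return {MBB, Idx};
  }
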
Index: llvm/lib/Target/X86/X86LowerAMXType.cpp
===================================================================
--- /dev/null
+++ llvm/lib/Target/X86/X86LowerAMXType.cpp
@@ -0,0 +1,291 @@
+#include "X86.h"
+#include "llvm/ADT/DenseSet.h"
+#include "llvm/Analysis/OptimizationRemarkEmitter.h"
+#include "llvm/Analysis/TargetTransformInfo.h"
+#include "llvm/CodeGen/Passes.h"
+#include "llvm/CodeGen/ValueTypes.h"
+#include "llvm/IR/DataLayout.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Instructions.h"
+#include "llvm/IR/IntrinsicInst.h"
+#include "llvm/IR/IntrinsicsX86.h"
+#include "llvm/InitializePasses.h"
+#include "llvm/Pass.h"
+
+using namespace llvm;
+
+#define DEBUG_TYPE "lower-amx-type"
+
+namespace {
+class X86LowerAMXType {
+  Function &Func;
+  const DataLayout &DL;
+  DenseSet<Instruction *> LDSet;
+  DenseSet<Instruction *> STSet;
+  DenseMap<Value *, std::pair<LoadInst *, LoadInst *>> LoadMap;
+
+public:
+  X86LowerAMXType(Function &F) : Func(F), DL(F.getParent()->getDataLayout()) {}
+  bool Visit();
+  bool VisitLD();
+  bool VisitST();
+  void SplitST(Instruction *Inst);
+  void SplitLD(Instruction *Inst);
+};
+
+// Split a v256i32 load/store into two v128i32 operations, so that ISel can
+// lower them to proper vector sizes.
+void X86LowerAMXType::SplitST(Instruction *Inst) {
+  StoreInst *ST = dyn_cast<StoreInst>(Inst);
+  IRBuilder<> Builder(ST);
+  LLVMContext &Ctx = Builder.getContext();
+  Type *Ty = ST->getValueOperand()->getType();
+  EVT VT = EVT::getEVT(Ty);
+  EVT HalfVT = VT.getHalfNumVectorElementsVT(Ctx);
+  Type *HalfTy = HalfVT.getTypeForEVT(Ctx);
+
+  LoadInst *Lo, *Hi;
+  std::tie(Lo, Hi) = LoadMap[ST->getValueOperand()];
+  Value *Ptr = ST->getPointerOperand();
+  PointerType *HalfPtrTy = HalfTy->getPointerTo(ST->getPointerAddressSpace());
+  Value *HalfPtr = Builder.CreateBitCast(Ptr, HalfPtrTy);
+  // The HW requires the alignment of an AMX tile to be 64 bytes, but the
+  // front end generates code with the vector alignment, i.e. the vector size.
+  uint64_t HalfTySize = HalfTy->getPrimitiveSizeInBits().getFixedSize() / 8;
+  Align Alignment = std::min(Lo->getAlign(), Align(HalfTySize));
+  Builder.CreateAlignedStore(Lo, HalfPtr, Alignment, ST->isVolatile());
+
+  HalfPtr = Builder.CreateGEP(HalfTy, HalfPtr, Builder.getInt32(1));
+  Builder.CreateAlignedStore(Hi, HalfPtr, Alignment, ST->isVolatile());
+}
+
+bool X86LowerAMXType::VisitST() {
+  if (STSet.empty())
+    return false;
+  for (auto *Inst : STSet) {
+    Value *Row, *Col;
+    const IntrinsicInst *II = dyn_cast<IntrinsicInst>(Inst->getOperand(0));
+    if (!II)
+      Row = Col = nullptr;
+    else {
+      switch (II->getIntrinsicID()) {
+      default:
+        Row = Col = nullptr;
+        break;
+      case Intrinsic::x86_tileloadd64_internal:
+      case Intrinsic::x86_tdpbssd_internal: {
+        Row = II->getArgOperand(0);
+        Col = II->getArgOperand(1);
+        break;
+      }
+      }
+    }
+    if (!Row) {
+      SplitST(Inst);
+      continue;
+    }
+    IRBuilder<> Builder(Inst);
+    LLVMContext &Ctx = Builder.getContext();
+    // Use the maximum column size as the stride. It must match the load stride.
+    Value *Stride = Builder.getInt64(64);
+    Value *I8Ptr =
+        Builder.CreateBitCast(Inst->getOperand(1), Type::getInt8PtrTy(Ctx));
+    std::array<Value *, 5> Args = {Row, Col, I8Ptr, Stride,
+                                   Inst->getOperand(0)};
+
+    Builder.CreateIntrinsic(Intrinsic::x86_tilestored64_internal, None, Args);
+  }
+  return true;
+}
+
+void X86LowerAMXType::SplitLD(Instruction *Inst) {
+  LoadInst *LD = dyn_cast<LoadInst>(Inst);
+  IRBuilder<> Builder(LD);
+  LLVMContext &Ctx = Builder.getContext();
+  Type *Ty = LD->getType();
+  EVT VT = EVT::getEVT(Ty);
+  EVT HalfVT = VT.getHalfNumVectorElementsVT(Ctx);
+  Type *HalfTy = HalfVT.getTypeForEVT(Ctx);
+
+  Value *Ptr = LD->getPointerOperand();
+  PointerType *HalfPtrTy = HalfTy->getPointerTo(LD->getPointerAddressSpace());
+  Value *HalfPtr = Builder.CreateBitCast(Ptr, HalfPtrTy);
+  // The HW requires the alignment of an AMX tile to be 64 bytes, but the
+  // front end generates code with the vector alignment, i.e. the vector size.
+  uint64_t HalfTySize = HalfTy->getPrimitiveSizeInBits().getFixedSize() / 8;
+  Align Alignment = std::min(LD->getAlign(), Align(HalfTySize));
+  auto *Lo =
+      Builder.CreateAlignedLoad(HalfTy, HalfPtr, Alignment, LD->isVolatile());
+
+  HalfPtr = Builder.CreateGEP(HalfTy, HalfPtr, Builder.getInt32(1));
+  auto *Hi =
+      Builder.CreateAlignedLoad(HalfTy, HalfPtr, Alignment, LD->isVolatile());
+
+  LoadMap[Inst] = std::make_pair(Lo, Hi);
+}
+
+bool X86LowerAMXType::VisitLD() {
+  if (LDSet.empty())
+    return false;
+  for (auto &Inst : LDSet) {
+    int Count = 0;
+    Value *NewInst = nullptr;
+    // The users should be all AMX intrinsics or all LLVM instructions;
+    // mixing AMX intrinsic and LLVM instruction users is not supported.
+    for (auto I = Inst->use_begin(), E = Inst->use_end(); I != E;) {
+      Use &U = *I++;
+      const IntrinsicInst *II = dyn_cast<IntrinsicInst>(U.getUser());
+      if (!II) {
+        Count++;
+        continue;
+      }
+      if (NewInst)
+        continue;
+      Value *Row = nullptr, *Col = nullptr;
+      switch (II->getIntrinsicID()) {
+      default:
+        report_fatal_error("Non-AMX intrinsic use tile type.");
+        break;
+      case Intrinsic::x86_tdpbssd_internal: {
+        unsigned OpNo = U.getOperandNo();
+        switch (OpNo) {
+        case 3:
+          Row = II->getArgOperand(0);
+          Col = II->getArgOperand(1);
+          break;
+        case 4:
+          Row = II->getArgOperand(0);
+          Col = II->getArgOperand(2);
+          break;
+        case 5:
+          Row = II->getArgOperand(2);
+          Col = II->getArgOperand(1);
+          break;
+        }
+        break;
+      }
+      case Intrinsic::x86_tilestored64_internal: {
+        Row = II->getArgOperand(0);
+        Col = II->getArgOperand(1);
+        break;
+      }
+      }
+      assert(Count == 0 && "Cannot mix AMX intrinsics and LLVM instructions");
+      // FIXME: The shape def should be ahead of load.
+      IRBuilder<> Builder(Inst);
+      LLVMContext &Ctx = Builder.getContext();
+      // Use the maximum column size as the stride.
+      Value *Stride = Builder.getInt64(64);
+      Value *I8Ptr =
+          Builder.CreateBitCast(Inst->getOperand(0), Type::getInt8PtrTy(Ctx));
+      std::array<Value *, 4> Args = {Row, Col, I8Ptr, Stride};
+
+      NewInst = Builder.CreateIntrinsic(Intrinsic::x86_tileloadd64_internal,
+                                        None, Args);
+
+      Inst->replaceAllUsesWith(NewInst);
+    }
+    if (!NewInst)
+      SplitLD(Inst);
+  }
+  return true;
+}
+
+bool X86LowerAMXType::Visit() {
+  bool C;
+  auto IsAMXType = [](FixedVectorType *VTy) {
+    if (!VTy)
+      return false;
+    if (!VTy->getScalarType()->isIntegerTy(32))
+      return false;
+    if (VTy->getNumElements() != 256)
+      return false;
+
+    return true;
+  };
+
+  for (BasicBlock &BB : Func) {
+    for (Instruction &Inst : BB) {
+      LoadInst *LD = dyn_cast<LoadInst>(&Inst);
+      // Check load instruction.
+      // %3 = load <256 x i32>, <256 x i32>* %1, align 64
+      if (LD) {
+        FixedVectorType *VTy = dyn_cast<FixedVectorType>(Inst.getType());
+        if (!IsAMXType(VTy))
+          continue;
+        LDSet.insert(&Inst);
+        continue;
+      }
+      // Check store instruction.
+      // store <256 x i32> %3, <256 x i32>* %2, align 64
+      StoreInst *ST = dyn_cast<StoreInst>(&Inst);
+      if (!ST)
+        continue;
+      FixedVectorType *VTy =
+          dyn_cast<FixedVectorType>(ST->getOperand(0)->getType());
+      if (!IsAMXType(VTy))
+        continue;
+      STSet.insert(&Inst);
+    }
+  }
+
+  C = VisitLD() | VisitST();
+  for (auto *Inst : STSet)
+    Inst->eraseFromParent();
+  for (auto *Inst : LDSet)
+    Inst->eraseFromParent();
+  return C;
+}
+} // anonymous namespace
+
+namespace {
+
+class X86LowerAMXTypeLegacyPass : public FunctionPass {
+public:
+  static char ID;
+
+  X86LowerAMXTypeLegacyPass() : FunctionPass(ID) {
+    initializeX86LowerAMXTypeLegacyPassPass(*PassRegistry::getPassRegistry());
+  }
+
+  bool runOnFunction(Function &F) override {
+    X86LowerAMXType LAT(F);
+    bool C = LAT.Visit();
+    return C;
+  }
+
+  void getAnalysisUsage(AnalysisUsage &AU) const override {
+    AU.setPreservesCFG();
+  }
+
+private:
+  Function *F;
+};
+
+} // anonymous namespace
+
+static const char pass_name[] = "Lower AMX type for load/store";
+char X86LowerAMXTypeLegacyPass::ID = 0;
+INITIALIZE_PASS_BEGIN(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, pass_name, false,
+                      false)
+INITIALIZE_PASS_END(X86LowerAMXTypeLegacyPass, DEBUG_TYPE, pass_name, false,
+                    false)
+
+FunctionPass *llvm::createX86LowerAMXTypePass() {
+  return new X86LowerAMXTypeLegacyPass();
+}
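
For tdpbssd(m, n, k, dst, src1, src2), the shape of a tile loaded for a given
operand is deduced from the intrinsic's own scalar arguments: the accumulator
(operand 3) is m x n, the first source (operand 4) is m x k, and the second
source (operand 5) is k x n, which is what the OpNo switch above encodes. A
minimal illustrative sketch of that mapping (toy function, not LLVM API):

  #include <cassert>
  #include <utility>

  // Returns the {row, col} argument indices of tdpbssd(m=0, n=1, k=2, ...)
  // for the tile passed as operand OpNo.
  std::pair<int, int> tdpbssdShapeArgs(unsigned OpNo) {
    switch (OpNo) {
    case 3: return {0, 1}; // dst  is m x n.
    case 4: return {0, 2}; // src1 is m x k.
    case 5: return {2, 1}; // src2 is k x n.
    default: assert(false && "not a tile operand"); return {-1, -1};
    }
  }
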
Index: llvm/lib/Target/X86/X86InstrInfo.cpp
===================================================================
--- llvm/lib/Target/X86/X86InstrInfo.cpp
+++ llvm/lib/Target/X86/X86InstrInfo.cpp
@@ -3797,13 +3797,31 @@
   const MachineFunction &MF = *MBB.getParent();
   assert(MF.getFrameInfo().getObjectSize(FrameIdx) >= TRI->getSpillSize(*RC) &&
          "Stack slot too small for store");
-  unsigned Alignment = std::max<uint32_t>(TRI->getSpillSize(*RC), 16);
-  bool isAligned =
-      (Subtarget.getFrameLowering()->getStackAlign() >= Alignment) ||
-      RI.canRealignStack(MF);
-  unsigned Opc = getStoreRegOpcode(SrcReg, RC, isAligned, Subtarget);
-  addFrameReference(BuildMI(MBB, MI, DebugLoc(), get(Opc)), FrameIdx)
-    .addReg(SrcReg, getKillRegState(isKill));
+  if (RC->getID() == X86::TILERegClassID) {
+    unsigned Opc = X86::TILESTORED;
+    // tilestored %tmm, (%sp, %idx)
+    MachineRegisterInfo &RegInfo = MBB.getParent()->getRegInfo();
+    Register VirtReg = RegInfo.createVirtualRegister(&X86::GR64_NOSPRegClass);
+    MachineInstr *NewMI =
+        BuildMI(MBB, MI, DebugLoc(), get(X86::MOV64ri), VirtReg).addImm(64);
+    NewMI = addFrameReference(BuildMI(MBB, MI, DebugLoc(), get(Opc)), FrameIdx)
+                .addReg(SrcReg, getKillRegState(isKill));
+    MachineOperand &MO = NewMI->getOperand(2);
+    MO.setReg(VirtReg);
+    MO.setIsKill(true);
+  } else if (RC->getID() == X86::TILECFGRegClassID) {
+    unsigned Opc = X86::PSTTILECFG;
+    addFrameReference(BuildMI(MBB, MI, DebugLoc(), get(Opc)), FrameIdx)
+        .addReg(SrcReg, getKillRegState(isKill));
+  } else {
+    unsigned Alignment = std::max<uint32_t>(TRI->getSpillSize(*RC), 16);
+    bool isAligned =
+        (Subtarget.getFrameLowering()->getStackAlign() >= Alignment) ||
+        RI.canRealignStack(MF);
+    unsigned Opc = getStoreRegOpcode(SrcReg, RC, isAligned, Subtarget);
+    addFrameReference(BuildMI(MBB, MI, DebugLoc(), get(Opc)), FrameIdx)
+        .addReg(SrcReg, getKillRegState(isKill));
+  }
 }
 
 void X86InstrInfo::loadRegFromStackSlot(MachineBasicBlock &MBB,
@@ -3811,13 +3829,32 @@
                                         Register DestReg, int FrameIdx,
                                         const TargetRegisterClass *RC,
                                         const TargetRegisterInfo *TRI) const {
-  const MachineFunction &MF = *MBB.getParent();
-  unsigned Alignment = std::max<uint32_t>(TRI->getSpillSize(*RC), 16);
-  bool isAligned =
-      (Subtarget.getFrameLowering()->getStackAlign() >= Alignment) ||
-      RI.canRealignStack(MF);
-  unsigned Opc = getLoadRegOpcode(DestReg, RC, isAligned, Subtarget);
-  addFrameReference(BuildMI(MBB, MI, DebugLoc(), get(Opc), DestReg), FrameIdx);
+  if (RC->getID() == X86::TILERegClassID) {
+    unsigned Opc = X86::TILELOADD;
+    // tileloadd (%sp, %idx), %tmm
+    MachineRegisterInfo &RegInfo = MBB.getParent()->getRegInfo();
+    Register VirtReg = RegInfo.createVirtualRegister(&X86::GR64_NOSPRegClass);
+    MachineInstr *NewMI =
+        BuildMI(MBB, MI, DebugLoc(), get(X86::MOV64ri), VirtReg).addImm(64);
+    NewMI = addFrameReference(BuildMI(MBB, MI, DebugLoc(), get(Opc), DestReg),
+                              FrameIdx);
+    MachineOperand &MO = NewMI->getOperand(3);
+    MO.setReg(VirtReg);
+    MO.setIsKill(true);
+  } else if (RC->getID() == X86::TILECFGRegClassID) {
+    unsigned Opc = X86::PLDTILECFG;
+    addFrameReference(BuildMI(MBB, MI, DebugLoc(), get(Opc), DestReg),
+                      FrameIdx);
+  } else {
+    const MachineFunction &MF = *MBB.getParent();
+    unsigned Alignment = std::max<uint32_t>(TRI->getSpillSize(*RC), 16);
+    bool isAligned =
+        (Subtarget.getFrameLowering()->getStackAlign() >= Alignment) ||
+        RI.canRealignStack(MF);
+    unsigned Opc = getLoadRegOpcode(DestReg, RC, isAligned, Subtarget);
+    addFrameReference(BuildMI(MBB, MI, DebugLoc(), get(Opc), DestReg),
+                      FrameIdx);
+  }
 }
 
 bool X86InstrInfo::analyzeCompare(const MachineInstr &MI, Register &SrcReg,
Index: llvm/lib/Target/X86/X86InstrAMX.td
===================================================================
--- llvm/lib/Target/X86/X86InstrAMX.td
+++ llvm/lib/Target/X86/X86InstrAMX.td
@@ -16,17 +16,21 @@
 
 let Predicates = [HasAMXTILE, In64BitMode] in {
   let SchedRW = [WriteSystem] in {
-    let Defs = [TMM0,TMM1,TMM2,TMM3,TMM4,TMM5,TMM6,TMM7] in
+    let hasSideEffects = 1,
+        Defs = [TMM0,TMM1,TMM2,TMM3,TMM4,TMM5,TMM6,TMM7] in
     def LDTILECFG : I <0x49, MRM0m, (outs), (ins opaquemem:$src),
                        "ldtilecfg\t$src",
                        [(int_x86_ldtilecfg addr:$src)]>, VEX, T8PS;
+    let hasSideEffects = 1 in
     def STTILECFG : I <0x49, MRM0m, (outs), (ins opaquemem:$src),
                        "sttilecfg\t$src",
                        [(int_x86_sttilecfg addr:$src)]>, VEX, T8PD;
+    let mayLoad = 1 in
     def TILELOADD : I<0x4b, MRMSrcMemFSIB, (outs TILE:$dst),
                       (ins sibmem:$src),
                       "tileloadd\t{$src, $dst|$dst, $src}", []>,
                       VEX, T8XD;
+    let mayLoad = 1 in
     def TILELOADDT1 : I<0x4b, MRMSrcMemFSIB, (outs TILE:$dst),
                         (ins sibmem:$src),
                         "tileloaddt1\t{$src, $dst|$dst, $src}", []>,
@@ -34,6 +38,7 @@
     let Defs = [TMM0,TMM1,TMM2,TMM3,TMM4,TMM5,TMM6,TMM7] in
     def TILERELEASE : I<0x49, MRM_C0, (outs), (ins),
                         "tilerelease", [(int_x86_tilerelease)]>, VEX, T8PS;
+    let mayStore = 1 in
     def TILESTORED : I<0x4b, MRMDestMemFSIB, (outs),
                        (ins sibmem:$dst, TILE:$src),
                        "tilestored\t{$src, $dst|$dst, $src}", []>,
@@ -42,6 +47,22 @@
                      "tilezero\t$dst", []>,
                      VEX, T8XD;
 
+    // Pseudo instructions for RA.
+    let hasSideEffects = 1, mayLoad = 1,
+        Defs = [TMM0,TMM1,TMM2,TMM3,TMM4,TMM5,TMM6,TMM7] in
+    def PLDTILECFG : PseudoI <(outs TILECFG:$cfg), (ins opaquemem:$src), []>;
+
+    let hasSideEffects = 1, mayStore = 1 in
+    def PSTTILECFG : PseudoI<(outs), (ins opaquemem:$dst, TILECFG:$cfg), []>;
+
+    def PTILELOADDV : PseudoI<(outs TILE:$dst), (ins GR16:$src1,
+                                                     GR16:$src2,
+                                                     opaquemem:$src3,
+                                                     TILECFG:$cfg), []>;
+    def PTILESTOREDV : PseudoI<(outs), (ins GR16:$src1,
+                                            GR16:$src2, opaquemem:$src3,
+                                            TILE:$src4, TILECFG:$cfg), []>;
+
     let usesCustomInserter = 1 in {
       // Pseudo instructions, using immediates instead of tile registers.
       // To be translated to the actual instructions in X86ISelLowering.cpp
@@ -76,6 +97,12 @@
                       VEX_4V, T8PS;
     }
 
+    // Pseudo instruction for RA.
+    let Constraints = "$src4 = $dst" in
+    def PTDPBSSDV : PseudoI<(outs TILE:$dst), (ins GR16:$src1,
+                            GR16:$src2, GR16:$src3, TILE:$src4,
+                            TILE:$src5, TILE:$src6, TILECFG:$cfg), []>;
+
     let usesCustomInserter = 1 in {
       // Pseudo instructions, using immediates instead of tile registers.
       // To be translated to the actual instructions in X86ISelLowering.cpp
Index: llvm/lib/Target/X86/X86ISelLowering.cpp
===================================================================
--- llvm/lib/Target/X86/X86ISelLowering.cpp
+++ llvm/lib/Target/X86/X86ISelLowering.cpp
@@ -1885,6 +1885,10 @@
     setOperationAction(ISD::TRUNCATE, MVT::v16i64, Custom);
   }
 
+  if (Subtarget.hasAMXTILE()) {
+    addRegisterClass(MVT::v256i32, &X86::TILERegClass);
+  }
+
   // We want to custom lower some of our intrinsics.
   setOperationAction(ISD::INTRINSIC_WO_CHAIN, MVT::Other, Custom);
   setOperationAction(ISD::INTRINSIC_W_CHAIN, MVT::Other, Custom);
@@ -5324,6 +5328,12 @@
   // width.
   if (MemVT.getSizeInBits() > Subtarget.getPreferVectorWidth())
     return false;
+
+  // Don't merge to the x86 AMX tile type, as we only map MVT::v256i32
+  // to the AMX tile for the AMX intrinsics.
+  if (MemVT == MVT::v256i32)
+    return false;
+
   return true;
 }
 
Index: llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
===================================================================
--- llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
+++ llvm/lib/Target/X86/X86ISelDAGToDAG.cpp
@@ -4572,6 +4572,49 @@
       ReplaceNode(Node, Res);
       return;
     }
+    case Intrinsic::x86_tileloadd64_internal: {
+      if (!Subtarget->hasAMXTILE())
+        break;
+      unsigned Opc = X86::PTILELOADDV;
+      // _tile_loadd_internal(row, col, buf, STRIDE)
+      SDValue Base = Node->getOperand(4);
+      SDValue Scale = getI8Imm(1, dl);
+      SDValue Index = Node->getOperand(5);
+      SDValue Disp = CurDAG->getTargetConstant(0, dl, MVT::i32);
+      SDValue Segment = CurDAG->getRegister(0, MVT::i16);
+      SDValue CFG = CurDAG->getRegister(0, MVT::v512i1);
+      SDValue Chain = Node->getOperand(0);
+      MachineSDNode *CNode;
+      SDValue Ops[] = {Node->getOperand(2),
+                       Node->getOperand(3),
+                       Base,
+                       Scale,
+                       Index,
+                       Disp,
+                       Segment,
+                       CFG,
+                       Chain};
+      CNode = CurDAG->getMachineNode(Opc, dl, {MVT::v256i32, MVT::Other}, Ops);
+      ReplaceNode(Node, CNode);
+      return;
+    }
+    case Intrinsic::x86_tdpbssd_internal: {
+      if (!Subtarget->hasAMXTILE())
+        break;
+      unsigned Opc = X86::PTDPBSSDV;
+      SDValue CFG = CurDAG->getRegister(0, MVT::v512i1);
+      SDValue Ops[] = {Node->getOperand(2),
+                       Node->getOperand(3),
+                       Node->getOperand(4),
+                       Node->getOperand(5),
+                       Node->getOperand(6),
+                       Node->getOperand(7),
+                       CFG};
+      MachineSDNode *CNode =
+          CurDAG->getMachineNode(Opc, dl, {MVT::v256i32, MVT::Other}, Ops);
+      ReplaceNode(Node, CNode);
+      return;
+    }
     }
     break;
   }
@@ -4629,6 +4672,31 @@
 
       break;
     }
+    case Intrinsic::x86_tilestored64_internal: {
+      unsigned Opc = X86::PTILESTOREDV;
+      // _tile_stored_internal(row, col, buf, STRIDE, c)
+      SDValue Base = Node->getOperand(4);
+      SDValue Scale = getI8Imm(1, dl);
+      SDValue Index = Node->getOperand(5);
+      SDValue Disp = CurDAG->getTargetConstant(0, dl, MVT::i32);
+      SDValue Segment = CurDAG->getRegister(0, MVT::i16);
+      SDValue CFG = CurDAG->getRegister(0, MVT::v512i1);
+      SDValue Chain = Node->getOperand(0);
+      MachineSDNode *CNode;
+      SDValue Ops[] = {Node->getOperand(2),
+                       Node->getOperand(3),
+                       Base,
+                       Scale,
+                       Index,
+                       Disp,
+                       Segment,
+                       Node->getOperand(6),
+                       CFG,
+                       Chain};
+      CNode = CurDAG->getMachineNode(Opc, dl, MVT::Other, Ops);
+      ReplaceNode(Node, CNode);
+      return;
+    }
     case Intrinsic::x86_tileloadd64:
     case Intrinsic::x86_tileloaddt164:
     case Intrinsic::x86_tilestored64: {
Index: llvm/lib/Target/X86/X86ExpandPseudo.cpp
===================================================================
--- llvm/lib/Target/X86/X86ExpandPseudo.cpp
+++ llvm/lib/Target/X86/X86ExpandPseudo.cpp
@@ -461,6 +461,39 @@
   case TargetOpcode::ICALL_BRANCH_FUNNEL:
     ExpandICallBranchFunnel(&MBB, MBBI);
     return true;
+  case X86::PLDTILECFG: {
+    MI.RemoveOperand(0);
+    MI.setDesc(TII->get(X86::LDTILECFG));
+    return true;
+  }
+  case X86::PSTTILECFG: {
+    MI.RemoveOperand(MI.getNumOperands() - 1); // Remove $tmmcfg
+    MI.setDesc(TII->get(X86::STTILECFG));
+    return true;
+  }
+  case X86::PTILELOADDV: {
+    MI.RemoveOperand(8); // Remove $tmmcfg
+    for (unsigned i = 2; i > 0; --i)
+      MI.RemoveOperand(i);
+    MI.setDesc(TII->get(X86::TILELOADD));
+    return true;
+  }
+  case X86::PTDPBSSDV: {
+    MI.RemoveOperand(7); // Remove $tmmcfg
+    MI.untieRegOperand(4);
+    for (unsigned i = 3; i > 0; --i)
+      MI.RemoveOperand(i);
+    MI.setDesc(TII->get(X86::TDPBSSD));
+    MI.tieOperands(0, 1);
+    return true;
+  }
+  case X86::PTILESTOREDV: {
+    MI.RemoveOperand(8); // Remove $tmmcfg
+    for (int i = 1; i >= 0; --i)
+      MI.RemoveOperand(i);
+    MI.setDesc(TII->get(X86::TILESTORED));
+    return true;
+  }
   }
   llvm_unreachable("Previous switch has a fallthrough?");
 }
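
The operand indices hard-coded in the expansions above follow from the pseudo
definitions in X86InstrAMX.td. Spelled out, per our reading of this patch:

  // PTILELOADDV:  $dst(0), $row(1), $col(2), mem(3-7), $cfg(8)
  //   -> drop $cfg and $row/$col; TILELOADD keeps $dst plus the five
  //      memory operands.
  // PTDPBSSDV:    $dst(0), $m(1), $n(2), $k(3), $src4(4, tied to $dst),
  //               $src5(5), $src6(6), $cfg(7)
  //   -> drop $cfg, untie, drop the three GR16 shapes, re-tie $dst to the
  //      operand now at index 1; TDPBSSD keeps $dst, $src4, $src5, $src6.
  // PTILESTOREDV: $row(0), $col(1), mem(2-6), $tile(7), $cfg(8)
  //   -> drop $cfg and $row/$col; TILESTORED keeps the memory operands
  //      plus $tile.
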
Index: llvm/lib/Target/X86/X86.h
===================================================================
--- llvm/lib/Target/X86/X86.h
+++ llvm/lib/Target/X86/X86.h
@@ -76,6 +76,10 @@
 /// Return a pass that expands WinAlloca pseudo-instructions.
 FunctionPass *createX86WinAllocaExpander();
 
+FunctionPass *createX86TileConfigPass();
+
+FunctionPass *createX86PreTileConfigPass();
+
 /// Return a pass that inserts int3 at the end of the function if it ends with a
 /// CALL instruction. The pass does the same for each funclet as well. This
 /// ensures that the open interval of function start and end PCs contains all
@@ -162,6 +166,9 @@
 void initializeX86PartialReductionPass(PassRegistry &);
 void initializeX86SpeculativeLoadHardeningPassPass(PassRegistry &);
 void initializeX86SpeculativeExecutionSideEffectSuppressionPass(PassRegistry &);
+void initializeX86PreTileConfigPass(PassRegistry &);
+void initializeX86TileConfigPass(PassRegistry &);
+void initializeX86LowerAMXTypeLegacyPassPass(PassRegistry &);
 
 namespace X86AS {
 enum : unsigned {
Index: llvm/lib/Target/X86/CMakeLists.txt
===================================================================
--- llvm/lib/Target/X86/CMakeLists.txt
+++ llvm/lib/Target/X86/CMakeLists.txt
@@ -30,6 +30,9 @@
   X86CmovConversion.cpp
   X86DomainReassignment.cpp
   X86DiscriminateMemOps.cpp
+  X86LowerAMXType.cpp
+  X86TileConfig.cpp
+  X86PreTileConfig.cpp
   X86ExpandPseudo.cpp
   X86FastISel.cpp
   X86FixupBWInsts.cpp
Index: llvm/lib/IR/Function.cpp
===================================================================
--- llvm/lib/IR/Function.cpp
+++ llvm/lib/IR/Function.cpp
@@ -784,25 +784,25 @@
 enum IIT_Info {
   // Common values should be encoded with 0-15.
   IIT_Done = 0,
-  IIT_I1   = 1,
-  IIT_I8   = 2,
-  IIT_I16  = 3,
-  IIT_I32  = 4,
-  IIT_I64  = 5,
-  IIT_F16  = 6,
-  IIT_F32  = 7,
-  IIT_F64  = 8,
-  IIT_V2   = 9,
-  IIT_V4   = 10,
-  IIT_V8   = 11,
-  IIT_V16  = 12,
-  IIT_V32  = 13,
-  IIT_PTR  = 14,
-  IIT_ARG  = 15,
+  IIT_I1 = 1,
+  IIT_I8 = 2,
+  IIT_I16 = 3,
+  IIT_I32 = 4,
+  IIT_I64 = 5,
+  IIT_F16 = 6,
+  IIT_F32 = 7,
+  IIT_F64 = 8,
+  IIT_V2 = 9,
+  IIT_V4 = 10,
+  IIT_V8 = 11,
+  IIT_V16 = 12,
+  IIT_V32 = 13,
+  IIT_PTR = 14,
+  IIT_ARG = 15,
 
   // Values from 16+ are only encodable with the inefficient encoding.
-  IIT_V64  = 16,
-  IIT_MMX  = 17,
+  IIT_V64 = 16,
+  IIT_MMX = 17,
   IIT_TOKEN = 18,
   IIT_METADATA = 19,
   IIT_EMPTYSTRUCT = 20,
@@ -813,7 +813,7 @@
   IIT_EXTEND_ARG = 25,
   IIT_TRUNC_ARG = 26,
   IIT_ANYPTR = 27,
-  IIT_V1   = 28,
+  IIT_V1 = 28,
   IIT_VARARG = 29,
   IIT_HALF_VEC_ARG = 30,
   IIT_SAME_VEC_WIDTH_ARG = 31,
Index: llvm/lib/CodeGen/VirtRegMap.cpp
===================================================================
--- llvm/lib/CodeGen/VirtRegMap.cpp
+++ llvm/lib/CodeGen/VirtRegMap.cpp
@@ -68,6 +68,7 @@
   Virt2PhysMap.clear();
   Virt2StackSlotMap.clear();
   Virt2SplitMap.clear();
+  Virt2ShapeMap.clear();
 
   grow();
   return false;
Index: llvm/lib/CodeGen/LiveRegMatrix.cpp
===================================================================
--- llvm/lib/CodeGen/LiveRegMatrix.cpp
+++ llvm/lib/CodeGen/LiveRegMatrix.cpp
@@ -54,6 +54,7 @@
 
 bool LiveRegMatrix::runOnMachineFunction(MachineFunction &MF) {
   TRI = MF.getSubtarget().getRegisterInfo();
+  MRI = &MF.getRegInfo();
   LIS = &getAnalysis<LiveIntervals>();
   VRM = &getAnalysis<VirtRegMap>();
 
@@ -221,3 +222,13 @@
   }
   return false;
 }
+
+Register LiveRegMatrix::getOneVReg(unsigned PhysReg) const {
+  LiveInterval *VRegInterval = nullptr;
+  for (MCRegUnitIterator Unit(PhysReg, TRI); Unit.isValid(); ++Unit) {
+    if ((VRegInterval = Matrix[*Unit].getOneVReg()))
+      return VRegInterval->reg();
+  }
+
+  return MCRegister::NoRegister;
+}
Index: llvm/lib/CodeGen/LiveIntervalUnion.cpp
===================================================================
--- llvm/lib/CodeGen/LiveIntervalUnion.cpp
+++ llvm/lib/CodeGen/LiveIntervalUnion.cpp
@@ -99,6 +99,16 @@
 }
 #endif //!NDEBUG
 
+LiveInterval *LiveIntervalUnion::getOneVReg() const {
+  if (empty())
+    return nullptr;
+  for (LiveSegments::const_iterator SI = Segments.begin(); SI.valid(); ++SI) {
+    // Return the first valid live interval.
+    return SI.value();
+  }
+  return nullptr;
+}
+
 // Scan the vector of interfering virtual registers in this union. Assume it's
 // quite small.
 bool LiveIntervalUnion::Query::isSeenInterference(LiveInterval *VirtReg) const {
Index: llvm/lib/CodeGen/InlineSpiller.cpp
===================================================================
--- llvm/lib/CodeGen/InlineSpiller.cpp
+++ llvm/lib/CodeGen/InlineSpiller.cpp
@@ -1556,6 +1556,8 @@
     VRM.assignVirt2Phys(New, VRM.getPhys(Old));
   else if (VRM.getStackSlot(Old) != VirtRegMap::NO_STACK_SLOT)
     VRM.assignVirt2StackSlot(New, VRM.getStackSlot(Old));
+  else if (VRM.hasShape(Old))
+    VRM.assignVirt2Shape(New, VRM.getShape(Old));
   else
     llvm_unreachable("VReg should be assigned either physreg or stackslot");
 }
Index: llvm/include/llvm/IR/IntrinsicsX86.td
===================================================================
--- llvm/include/llvm/IR/IntrinsicsX86.td
+++ llvm/include/llvm/IR/IntrinsicsX86.td
@@ -5056,4 +5056,24 @@
               Intrinsic<[llvm_i8_ty], [], []>;
   def int_x86_senduipi : GCCBuiltin<"__builtin_ia32_senduipi">,
               Intrinsic<[], [llvm_i64_ty], []>;
+  // AMX - internal intrinsics
+  def int_x86_tileloadd64_internal :
+              GCCBuiltin<"__builtin_ia32_tileloadd64_internal">,
+              Intrinsic<[llvm_v256i32_ty],
+                        [llvm_i16_ty, llvm_i16_ty, llvm_ptr_ty, llvm_i64_ty],
+                        []>;
+  def int_x86_tilezero_internal :
+              GCCBuiltin<"__builtin_ia32_tilezero_internal">,
+              Intrinsic<[llvm_v256i32_ty],
+                        [llvm_i16_ty, llvm_i16_ty, llvm_v256i32_ty], []>;
+  def int_x86_tdpbssd_internal :
+              GCCBuiltin<"__builtin_ia32_tdpbssd_internal">,
+              Intrinsic<[llvm_v256i32_ty],
+                        [llvm_i16_ty, llvm_i16_ty, llvm_i16_ty,
+                         llvm_v256i32_ty, llvm_v256i32_ty,
+                         llvm_v256i32_ty], []>;
+  def int_x86_tilestored64_internal :
+              GCCBuiltin<"__builtin_ia32_tilestored64_internal">,
+              Intrinsic<[], [llvm_i16_ty, llvm_i16_ty, llvm_ptr_ty,
+                             llvm_i64_ty, llvm_v256i32_ty], []>;
 }
Index: llvm/include/llvm/IR/Intrinsics.td
===================================================================
--- llvm/include/llvm/IR/Intrinsics.td
+++ llvm/include/llvm/IR/Intrinsics.td
@@ -292,6 +292,7 @@
 def llvm_v16i32_ty     : LLVMType<v16i32>;   // 16 x i32
 def llvm_v32i32_ty     : LLVMType<v32i32>;   // 32 x i32
 def llvm_v64i32_ty     : LLVMType<v64i32>;   // 64 x i32
+def llvm_v256i32_ty    : LLVMType<v256i32>;  // 256 x i32
 
 def llvm_v1i64_ty      : LLVMType<v1i64>;    //  1 x i64
 def llvm_v2i64_ty      : LLVMType<v2i64>;    //  2 x i64
Index: llvm/include/llvm/CodeGen/VirtRegMap.h
===================================================================
--- llvm/include/llvm/CodeGen/VirtRegMap.h
+++ llvm/include/llvm/CodeGen/VirtRegMap.h
@@ -19,6 +19,7 @@
 #include "llvm/ADT/IndexedMap.h"
 #include "llvm/CodeGen/MachineFunctionPass.h"
 #include "llvm/CodeGen/TargetRegisterInfo.h"
+#include "llvm/CodeGen/TileShapeInfo.h"
 #include "llvm/Pass.h"
 #include <cassert>
 
@@ -60,6 +61,10 @@
     /// mapping.
     IndexedMap<unsigned, VirtReg2IndexFunctor> Virt2SplitMap;
 
+    /// Virt2ShapeMap - For X86 AMX, maps a virtual register to the tile
+    /// shape information bound to it.
+    DenseMap<unsigned, ShapeT> Virt2ShapeMap;
+
     /// createSpillSlot - Allocate a spill slot for RC from MFI.
     unsigned createSpillSlot(const TargetRegisterClass *RC);
 
@@ -107,6 +112,21 @@
     /// the specified physical register
     void assignVirt2Phys(Register virtReg, MCPhysReg physReg);
 
+    bool isShapeMapEmpty() const { return Virt2ShapeMap.empty(); }
+
+    bool hasShape(Register virtReg) const {
+      return getShape(virtReg).isValid();
+    }
+
+    ShapeT getShape(Register virtReg) const {
+      assert(virtReg.isVirtual());
+      return Virt2ShapeMap.lookup(virtReg);
+    }
+
+    void assignVirt2Shape(Register virtReg, ShapeT shape) {
+      Virt2ShapeMap[virtReg.id()] = shape;
+    }
+
     /// clears the specified virtual register's, physical
     /// register mapping
     void clearVirt(Register virtReg) {
@@ -133,6 +153,9 @@
     /// records virtReg is a split live interval from SReg.
     void setIsSplitFromReg(Register virtReg, unsigned SReg) {
       Virt2SplitMap[virtReg.id()] = SReg;
+      if (hasShape(SReg)) {
+        Virt2ShapeMap[virtReg.id()] = getShape(SReg);
+      }
     }
 
     /// returns the live interval virtReg is split from.
Index: llvm/include/llvm/CodeGen/TileShapeInfo.h
===================================================================
--- /dev/null
+++ llvm/include/llvm/CodeGen/TileShapeInfo.h
@@ -0,0 +1,102 @@
+//===- llvm/CodeGen/TileShapeInfo.h - Tile Shape Info ----------*- C++ -*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file defines TileShapeInfo for AMX.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef LLVM_CODEGEN_TILESHAPEINFO_H
+#define LLVM_CODEGEN_TILESHAPEINFO_H
+
+#include "llvm/ADT/DenseMapInfo.h"
+#include "llvm/CodeGen/MachineInstr.h"
+#include "llvm/CodeGen/MachineOperand.h"
+#include "llvm/CodeGen/MachineRegisterInfo.h"
+#include "llvm/CodeGen/Register.h"
+#include <utility>
+
+namespace llvm {
+
+class ShapeT {
+public:
+  ShapeT(MachineOperand *Row, MachineOperand *Col,
+         const MachineRegisterInfo *MRI = nullptr)
+      : Row(Row), Col(Col), RowImm(InvalidImmShape),
+        ColImm(InvalidImmShape) {
+    if (MRI)
+      deduceImm(MRI);
+  }
+  ShapeT()
+      : Row(nullptr), Col(nullptr), RowImm(InvalidImmShape),
+        ColImm(InvalidImmShape) {}
+  bool operator==(const ShapeT &Shape) const {
+    MachineOperand *R = Shape.Row;
+    MachineOperand *C = Shape.Col;
+    if (!R || !C)
+      return false;
+    if (!Row || !Col)
+      return false;
+    if (Row->getReg() == R->getReg() && Col->getReg() == C->getReg())
+      return true;
+    if ((RowImm != InvalidImmShape) && (Shape.getRowImm() != InvalidImmShape) &&
+        (ColImm != InvalidImmShape) && (Shape.getColImm() != InvalidImmShape)) {
+      return RowImm == Shape.getRowImm() && ColImm == Shape.getColImm();
+    }
+    return false;
+  }
+
+  bool operator!=(const ShapeT &Shape) const { return !(*this == Shape); }
+
+  ShapeT &operator=(const ShapeT &RHS) {
+    Row = RHS.Row;
+    Col = RHS.Col;
+    RowImm = RHS.RowImm;
+    ColImm = RHS.ColImm;
+    return *this;
+  }
+
+  MachineOperand *getRow() const { return Row; }
+
+  MachineOperand *getCol() const { return Col; }
+
+  int64_t getRowImm() const { return RowImm; }
+
+  int64_t getColImm() const { return ColImm; }
+
+  bool isValid() const { return (Row != nullptr) && (Col != nullptr); }
+
+  void deduceImm(const MachineRegisterInfo *MRI) {
+    // All defs must be the same value; otherwise the MIs are invalid.
+    // Find the immediate.
+    // TODO: Handle copy propagation.
+    auto GetImm = [&](Register Reg) {
+      int64_t Imm = InvalidImmShape;
+      for (const MachineOperand &DefMO : MRI->def_operands(Reg)) {
+        const auto *MI = DefMO.getParent();
+        if (MI->isMoveImmediate()) {
+          Imm = MI->getOperand(1).getImm();
+          break;
+        }
+      }
+      return Imm;
+    };
+    RowImm = GetImm(Row->getReg());
+    ColImm = GetImm(Col->getReg());
+  }
+
+private:
+  static constexpr int64_t InvalidImmShape = -1;
+  MachineOperand *Row;
+  MachineOperand *Col;
+  int64_t RowImm;
+  int64_t ColImm;
+};
+
+} // namespace llvm
+
+#endif // LLVM_CODEGEN_TILESHAPEINFO_H
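
The shape equality above is deliberately conservative: two shapes compare
equal either when row and column come from the same virtual registers, or
when both rows and both columns deduce to the same immediates through their
move-immediate defs; anything unknown compares unequal. A reduced model of
operator== over a toy struct (illustrative only, not this class):

  #include <cstdint>

  struct ToyShape {
    unsigned RowReg = 0, ColReg = 0;  // 0 means no register operand.
    int64_t RowImm = -1, ColImm = -1; // -1 means not deduced.
  };

  static bool equalShape(const ToyShape &A, const ToyShape &B) {
    if (!A.RowReg || !A.ColReg || !B.RowReg || !B.ColReg)
      return false;                   // Invalid shapes never match.
    if (A.RowReg == B.RowReg && A.ColReg == B.ColReg)
      return true;                    // Same registers: trivially equal.
    if (A.RowImm != -1 && A.ColImm != -1 && B.RowImm != -1 && B.ColImm != -1)
      return A.RowImm == B.RowImm && A.ColImm == B.ColImm;
    return false;                     // Unknown: stay conservative.
  }
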
Index: llvm/include/llvm/CodeGen/Passes.h
===================================================================
--- llvm/include/llvm/CodeGen/Passes.h
+++ llvm/include/llvm/CodeGen/Passes.h
@@ -490,6 +490,8 @@
   /// The pass fixups statepoint machine instruction to replace usage of
   /// caller saved registers with stack slots.
   extern char &FixupStatepointCallerSavedID;
+
+  FunctionPass *createX86LowerAMXTypePass();
 } // End llvm namespace
 
 #endif
Index: llvm/include/llvm/CodeGen/LiveRegMatrix.h
===================================================================
--- llvm/include/llvm/CodeGen/LiveRegMatrix.h
+++ llvm/include/llvm/CodeGen/LiveRegMatrix.h
@@ -41,6 +41,7 @@
   const TargetRegisterInfo *TRI;
   LiveIntervals *LIS;
   VirtRegMap *VRM;
+  MachineRegisterInfo *MRI;
 
   // UserTag changes whenever virtual registers have been modified.
   unsigned UserTag = 0;
@@ -153,6 +154,8 @@
   /// Directly access the live interval unions per regunit.
   /// This returns an array indexed by the regunit number.
   LiveIntervalUnion *getLiveUnions() { return &Matrix[0]; }
+
+  Register getOneVReg(unsigned PhysReg) const;
 };
 
 } // end namespace llvm
Index: llvm/include/llvm/CodeGen/LiveIntervalUnion.h
===================================================================
--- llvm/include/llvm/CodeGen/LiveIntervalUnion.h
+++ llvm/include/llvm/CodeGen/LiveIntervalUnion.h
@@ -104,6 +104,9 @@
   void verify(LiveVirtRegBitSet& VisitedVRegs);
 #endif
 
+  // Get any virtual register that is assigned to this physical register unit.
+  LiveInterval *getOneVReg() const;
+
   /// Query interferences between a single live virtual register and a live
   /// interval union.
   class Query {
Index: clang/test/CodeGen/AMX/amx_api.c
===================================================================
--- /dev/null
+++ clang/test/CodeGen/AMX/amx_api.c
@@ -0,0 +1,31 @@
+// RUN: %clang_cc1 %s -ffreestanding -triple=x86_64-unknown-unknown  -target-feature +avx512f  -target-feature +amx-int8  \
+// RUN: -target-feature +amx-bf16 -emit-llvm -o - -Werror -pedantic | FileCheck %s --check-prefixes=CHECK
+
+#include <immintrin.h>
+
+char buf[1024];
+#define STRIDE 32
+
+char buf2[1024];
+
+void test_api(int cond, short row, short col) {
+  //CHECK-LABEL: @test_api
+  //CHECK: call <256 x i32> @llvm.x86.tileloadd64.internal
+  //CHECK: call <256 x i32> @llvm.x86.tdpbssd.internal
+  //CHECK: call void @llvm.x86.tilestored64.internal
+  __tile a = {row, 8};
+  __tile b = {8, col};
+  __tile c = {row, col};
+
+  if (cond) {
+    __tile_loadd(&a, buf, STRIDE);
+    __tile_loadd(&b, buf, STRIDE);
+    __tile_loadd(&c, buf, STRIDE);
+  } else {
+    __tile_loadd(&a, buf2, STRIDE);
+    __tile_loadd(&b, buf2, STRIDE);
+    __tile_loadd(&c, buf2, STRIDE);
+  }
+  __tile_dpbsud(&c, a, b);
+  __tile_stored(buf, STRIDE, c);
+}
Index: clang/lib/Headers/amxintrin.h
===================================================================
--- clang/lib/Headers/amxintrin.h
+++ clang/lib/Headers/amxintrin.h
@@ -66,6 +66,8 @@
   __builtin_ia32_tilerelease();
 }
 
+#undef __DEFAULT_FN_ATTRS
+
 /// Load tile rows from memory specifieid by "base" address and "stride" into
 /// destination tile "dst" using the tile configuration previously configured
 /// via "_tile_loadconfig".
@@ -219,6 +221,56 @@
 #define _tile_dpbf16ps(dst, src0, src1) \
   __builtin_ia32_tdpbf16ps((dst), (src0), (src1))
 
+#define __DEFAULT_FN_ATTRS                                                     \
+  __attribute__((__always_inline__, __nodebug__, __target__("amx-int8")))
+
+/// This is the new intrinsic interface.
+typedef int _tile_data __attribute__((__vector_size__(1024), __aligned__(64)));
+static __inline__ _tile_data __DEFAULT_FN_ATTRS
+_tile_loadd_internal(short m, short n, const void *base, int stride) {
+  return __builtin_ia32_tileloadd64_internal(m, n, base,
+                                             (__SIZE_TYPE__)(stride));
+}
+
+static __inline__ _tile_data __DEFAULT_FN_ATTRS
+_tile_zero_internal(short m, short n, _tile_data tile) {
+  return __builtin_ia32_tilezero_internal(m, n, tile);
+}
+
+static __inline__ _tile_data __DEFAULT_FN_ATTRS
+_tile_dpbssd_internal(short m, short n, short k, _tile_data dst,
+                      _tile_data src1, _tile_data src2) {
+  return __builtin_ia32_tdpbssd_internal(m, n, k, dst, src1, src2);
+}
+
+static __inline__ void __DEFAULT_FN_ATTRS _tile_stored_internal(
+    short m, short n, void *base, int stride, _tile_data tile) {
+  return __builtin_ia32_tilestored64_internal(m, n, base,
+                                              (__SIZE_TYPE__)(stride), tile);
+}
+
+typedef struct __tile_str {
+  const short row;
+  const short col;
+  _tile_data tile;
+} __tile;
+
+__DEFAULT_FN_ATTRS
+void __tile_loadd(__tile *dst, const void *base, long stride) {
+  dst->tile = _tile_loadd_internal(dst->row, dst->col, base, stride);
+}
+
+__DEFAULT_FN_ATTRS
+void __tile_dpbsud(__tile *dst, __tile src1, __tile src2) {
+  dst->tile = _tile_dpbssd_internal(src1.row, src2.col, src1.col, dst->tile,
+                                    src1.tile, src2.tile);
+}
+
+__DEFAULT_FN_ATTRS
+void __tile_stored(void *base, long stride, __tile src) {
+  _tile_stored_internal(src.row, src.col, base, stride, src.tile);
+}
+
 #undef __DEFAULT_FN_ATTRS
 
 #endif /* __x86_64__ */
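
One convention worth calling out in the wrappers above: for C += A * B with A
of shape m x k, B of shape k x n and C of shape m x n, __tile_dpbsud forwards
(src1.row, src2.col, src1.col) as (m, n, k). Illustrative usage, with shapes
chosen arbitrarily:

  __tile A = {4, 16}; // m x k
  __tile B = {16, 8}; // k x n
  __tile C = {4, 8};  // m x n
  // __tile_dpbsud(&C, A, B) calls
  // _tile_dpbssd_internal(/*m=*/4, /*n=*/8, /*k=*/16, C.tile, A.tile, B.tile).
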
Index: clang/include/clang/Basic/BuiltinsX86_64.def
===================================================================
--- clang/include/clang/Basic/BuiltinsX86_64.def
+++ clang/include/clang/Basic/BuiltinsX86_64.def
@@ -100,6 +100,11 @@
 TARGET_BUILTIN(__builtin_ia32_testui, "Uc", "n", "uintr")
 TARGET_BUILTIN(__builtin_ia32_senduipi, "vUWi", "n", "uintr")
 
+// AMX internal builtins
+TARGET_BUILTIN(__builtin_ia32_tileloadd64_internal, "V256iUsUsvC*z", "n", "amx-tile")
+TARGET_BUILTIN(__builtin_ia32_tilezero_internal, "V256iUsUsV256i", "n", "amx-tile")
+TARGET_BUILTIN(__builtin_ia32_tdpbssd_internal, "V256iUsUsUsV256iV256iV256i", "n", "amx-int8")
+TARGET_BUILTIN(__builtin_ia32_tilestored64_internal, "vUsUsv*zV256i", "n", "amx-tile")
 // AMX
 TARGET_BUILTIN(__builtin_ia32_tile_loadconfig, "vvC*", "n", "amx-tile")
 TARGET_BUILTIN(__builtin_ia32_tile_storeconfig, "vvC*", "n", "amx-tile")