From: Junyan He <[email protected]> We use SLM to store the value that will be broadcast to the whole work group.
Signed-off-by: Junyan He <[email protected]> --- backend/src/backend/gen_insn_selection.cpp | 92 ++++++++++++++++++++++++++++ 1 file changed, 92 insertions(+) diff --git a/backend/src/backend/gen_insn_selection.cpp b/backend/src/backend/gen_insn_selection.cpp index c240261..f8f1d29 100644 --- a/backend/src/backend/gen_insn_selection.cpp +++ b/backend/src/backend/gen_insn_selection.cpp @@ -4107,6 +4107,97 @@ namespace gbe DECL_CTOR(AtomicInstruction, 1, 1); }; + /*! WorkGroup instruction pattern */ + DECL_PATTERN(WorkGroupInstruction) + { + INLINE bool emitWGBroadcast(Selection::Opaque &sel, const ir::WorkGroupInstruction &insn) const { + /* 1. BARRIER Ensure all the threads have set the correct value for the var which will be broadcasted. + 2. CMP IDs Compare the local IDs with the specified ones in the function call. + 3. STORE Use flag to control the store of the var. Only the specified item will execute the store. + 4. BARRIER Ensure the specified value has been stored. + 5. LOAD Load the stored value to all the dst value, the dst of all the items will have same value, + so broadcasted. */ + using namespace ir; + const Type type = insn.getType(); + const GenRegister src = sel.selReg(insn.getSrc(0), type); + const GenRegister dst = sel.selReg(insn.getDst(0), type); + const uint32_t srcNum = insn.getSrcNum(); + const uint32_t simdWidth = sel.ctx.getSimdWidth(); + const uint32_t slmAddr = insn.getSlmAddr(); + GenRegister addr = sel.selReg(sel.reg(FAMILY_DWORD), ir::TYPE_U32); + + /* Then we insert a barrier to make sure all the var we are interested in + have been assigned the final value. 
*/ + sel.BARRIER(GenRegister::ud8grf(sel.reg(FAMILY_DWORD)), sel.selReg(sel.reg(FAMILY_DWORD)), syncLocalBarrier); + + GBE_ASSERT(srcNum >= 2); + GenRegister coords[3]; + for (uint32_t i = 1; i < srcNum; i++) { + coords[i - 1] = sel.selReg(insn.getSrc(i), TYPE_U32); + } + + sel.push(); { + sel.curr.predicate = GEN_PREDICATE_NONE; + sel.curr.noMask = 1; + sel.MOV(addr, GenRegister::immud(slmAddr)); + } sel.pop(); + + sel.push(); { + sel.curr.flag = 0; + sel.curr.subFlag = 1; + sel.curr.predicate = GEN_PREDICATE_NONE; + sel.curr.noMask = 1; + GenRegister lid0, lid1, lid2; + uint32_t dim = srcNum - 1; + if (simdWidth == 16) { + lid0 = GenRegister::ud16grf(ir::ocl::lid0); + lid1 = GenRegister::ud16grf(ir::ocl::lid1); + lid2 = GenRegister::ud16grf(ir::ocl::lid2); + } else { + lid0 = GenRegister::ud8grf(ir::ocl::lid0); + lid1 = GenRegister::ud8grf(ir::ocl::lid1); + lid2 = GenRegister::ud8grf(ir::ocl::lid2); + } + + sel.CMP(GEN_CONDITIONAL_EQ, coords[0], lid0, GenRegister::retype(GenRegister::null(), GEN_TYPE_UD)); + sel.curr.predicate = GEN_PREDICATE_NORMAL; + if (dim >= 2) + sel.CMP(GEN_CONDITIONAL_EQ, coords[1], lid1, GenRegister::retype(GenRegister::null(), GEN_TYPE_UD)); + if (dim >= 3) + sel.CMP(GEN_CONDITIONAL_EQ, coords[2], lid2, GenRegister::retype(GenRegister::null(), GEN_TYPE_UD)); + + if (typeSize(src.type) == 4) { + GenRegister _addr = GenRegister::retype(addr, GEN_TYPE_F); + GenRegister _src = GenRegister::retype(src, GEN_TYPE_F); + sel.UNTYPED_WRITE(_addr, &_src, 1, 0xfe); + } + } sel.pop(); + + /* Make sure the slm var have the valid value now */ + sel.BARRIER(GenRegister::ud8grf(sel.reg(FAMILY_DWORD)), sel.selReg(sel.reg(FAMILY_DWORD)), syncLocalBarrier); + + if (typeSize(src.type) == 4) { + sel.UNTYPED_READ(addr, &dst, 1, 0xfe); + } + + return true; + } + + INLINE bool emitOne(Selection::Opaque &sel, const ir::WorkGroupInstruction &insn, bool &markChildren) const + { + using namespace ir; + const WorkGroupOps workGroupOp = insn.getWorkGroupOpcode(); + + 
if (workGroupOp == WORKGROUP_OP_BROADCAST) { + return emitWGBroadcast(sel, insn); + } else { + GBE_ASSERT(0); + } + return true; + } + DECL_CTOR(WorkGroupInstruction, 1, 1); + }; + /*! Select instruction pattern */ class SelectInstructionPattern : public SelectionPattern { @@ -4789,6 +4880,7 @@ namespace gbe this->insert<GetImageInfoInstructionPattern>(); this->insert<ReadARFInstructionPattern>(); this->insert<RegionInstructionPattern>(); + this->insert<WorkGroupInstructionPattern>(); // Sort all the patterns with the number of instructions they output for (uint32_t op = 0; op < ir::OP_INVALID; ++op) -- 1.7.9.5 _______________________________________________ Beignet mailing list [email protected] http://lists.freedesktop.org/mailman/listinfo/beignet
