From: Pan Xiuli <xiuli....@intel.com> Add the sub_group_reduce_{add,min,max}, sub_group_scan_exclusive_{add,min,max} and sub_group_scan_inclusive_{add,min,max} builtin functions. They share the in-thread algorithm of the work group functions.
Signed-off-by: Pan Xiuli <xiuli....@intel.com> --- backend/src/backend/gen8_context.cpp | 23 ++++ backend/src/backend/gen8_context.hpp | 1 + backend/src/backend/gen_context.cpp | 23 ++++ backend/src/backend/gen_context.hpp | 1 + .../src/backend/gen_insn_gen7_schedule_info.hxx | 1 + backend/src/backend/gen_insn_selection.cpp | 116 +++++++++++++++++ backend/src/backend/gen_insn_selection.hxx | 1 + backend/src/ir/instruction.cpp | 144 +++++++++++++++++++++ backend/src/ir/instruction.hpp | 11 ++ backend/src/ir/instruction.hxx | 1 + backend/src/libocl/tmpl/ocl_simd.tmpl.cl | 98 ++++++++++++++ backend/src/libocl/tmpl/ocl_simd.tmpl.h | 95 ++++++++++++++ backend/src/llvm/llvm_gen_backend.cpp | 74 +++++++++++ backend/src/llvm/llvm_gen_ocl_function.hxx | 15 +++ 14 files changed, 604 insertions(+) diff --git a/backend/src/backend/gen8_context.cpp b/backend/src/backend/gen8_context.cpp index 477b22b..7ddb95a 100644 --- a/backend/src/backend/gen8_context.cpp +++ b/backend/src/backend/gen8_context.cpp @@ -1845,4 +1845,27 @@ namespace gbe } } + void Gen8Context::emitSubGroupOpInstruction(const SelectionInstruction &insn){ + const GenRegister dst = ra->genReg(insn.dst(0)); + const GenRegister tmp = GenRegister::retype(ra->genReg(insn.dst(1)), dst.type); + const GenRegister theVal = GenRegister::retype(ra->genReg(insn.src(0)), dst.type); + GenRegister threadData = ra->genReg(insn.src(1)); + + uint32_t wg_op = insn.extra.workgroupOp; + uint32_t simd = p->curr.execWidth; + + /* masked elements should be properly set to init value */ + p->push(); { + p->curr.noMask = 1; + wgOpInitValue(p, tmp, wg_op); + p->curr.noMask = 0; + p->MOV(tmp, theVal); + p->curr.noMask = 1; + p->MOV(theVal, tmp); + } p->pop(); + + /* do some calculation within each thread */ + wgOpPerformThread(dst, theVal, threadData, tmp, simd, wg_op, p); + } + } diff --git a/backend/src/backend/gen8_context.hpp b/backend/src/backend/gen8_context.hpp index 771e20b..ec1358c 100644 --- a/backend/src/backend/gen8_context.hpp 
+++ b/backend/src/backend/gen8_context.hpp @@ -77,6 +77,7 @@ namespace gbe virtual void emitF64DIVInstruction(const SelectionInstruction &insn); virtual void emitWorkGroupOpInstruction(const SelectionInstruction &insn); + virtual void emitSubGroupOpInstruction(const SelectionInstruction &insn); static GenRegister unpacked_ud(GenRegister reg, uint32_t offset = 0); diff --git a/backend/src/backend/gen_context.cpp b/backend/src/backend/gen_context.cpp index 4e24816..4d0a3f3 100644 --- a/backend/src/backend/gen_context.cpp +++ b/backend/src/backend/gen_context.cpp @@ -3374,6 +3374,29 @@ namespace gbe } } + void GenContext::emitSubGroupOpInstruction(const SelectionInstruction &insn){ + const GenRegister dst = ra->genReg(insn.dst(0)); + const GenRegister tmp = GenRegister::retype(ra->genReg(insn.dst(1)), dst.type); + const GenRegister theVal = GenRegister::retype(ra->genReg(insn.src(0)), dst.type); + GenRegister threadData = ra->genReg(insn.src(1)); + + uint32_t wg_op = insn.extra.workgroupOp; + uint32_t simd = p->curr.execWidth; + + /* masked elements should be properly set to init value */ + p->push(); { + p->curr.noMask = 1; + wgOpInitValue(p, tmp, wg_op); + p->curr.noMask = 0; + p->MOV(tmp, theVal); + p->curr.noMask = 1; + p->MOV(theVal, tmp); + } p->pop(); + + /* do some calculation within each thread */ + wgOpPerformThread(dst, theVal, threadData, tmp, simd, wg_op, p); + } + void GenContext::emitPrintfLongInstruction(GenRegister& addr, GenRegister& data, GenRegister& src, uint32_t bti) { p->MOV(GenRegister::retype(data, GEN_TYPE_UD), src.bottom_half()); diff --git a/backend/src/backend/gen_context.hpp b/backend/src/backend/gen_context.hpp index ebc55e6..4c43ccb 100644 --- a/backend/src/backend/gen_context.hpp +++ b/backend/src/backend/gen_context.hpp @@ -181,6 +181,7 @@ namespace gbe void emitCalcTimestampInstruction(const SelectionInstruction &insn); void emitStoreProfilingInstruction(const SelectionInstruction &insn); virtual void emitWorkGroupOpInstruction(const 
SelectionInstruction &insn); + virtual void emitSubGroupOpInstruction(const SelectionInstruction &insn); void emitPrintfInstruction(const SelectionInstruction &insn); void scratchWrite(const GenRegister header, uint32_t offset, uint32_t reg_num, uint32_t reg_type, uint32_t channel_mode); void scratchRead(const GenRegister dst, const GenRegister header, uint32_t offset, uint32_t reg_num, uint32_t reg_type, uint32_t channel_mode); diff --git a/backend/src/backend/gen_insn_gen7_schedule_info.hxx b/backend/src/backend/gen_insn_gen7_schedule_info.hxx index 112df32..cb5c4f1 100644 --- a/backend/src/backend/gen_insn_gen7_schedule_info.hxx +++ b/backend/src/backend/gen_insn_gen7_schedule_info.hxx @@ -48,4 +48,5 @@ DECL_GEN7_SCHEDULE(F64DIV, 20, 40, 20) DECL_GEN7_SCHEDULE(CalcTimestamp, 80, 1, 1) DECL_GEN7_SCHEDULE(StoreProfiling, 80, 1, 1) DECL_GEN7_SCHEDULE(WorkGroupOp, 80, 1, 1) +DECL_GEN7_SCHEDULE(SubGroupOp, 80, 1, 1) DECL_GEN7_SCHEDULE(Printf, 80, 1, 1) diff --git a/backend/src/backend/gen_insn_selection.cpp b/backend/src/backend/gen_insn_selection.cpp index 09f459a..855c39d 100644 --- a/backend/src/backend/gen_insn_selection.cpp +++ b/backend/src/backend/gen_insn_selection.cpp @@ -694,6 +694,9 @@ namespace gbe GenRegister tmpData2, GenRegister slmOff, vector<GenRegister> msg, uint32_t msgSizeReq, GenRegister localBarrier); + /*! Sub Group Operations */ + void SUBGROUP_OP(uint32_t wg_op, Reg dst, GenRegister src, + GenRegister tmpData1, GenRegister tmpData2); /* common functions for both binary instruction and sel_cmp and compare instruction. It will handle the IMM or normal register assignment, and will try to avoid LOADI as much as possible. 
*/ @@ -1995,6 +1998,23 @@ namespace gbe insn->src(5) = localBarrier; } + void Selection::Opaque::SUBGROUP_OP(uint32_t wg_op, + Reg dst, + GenRegister src, + GenRegister tmpData1, + GenRegister tmpData2) + { + SelectionInstruction *insn = this->appendInsn(SEL_OP_SUBGROUP_OP, 2, 2); + + insn->extra.workgroupOp = wg_op; + + insn->dst(0) = dst; + insn->dst(1) = tmpData1; + + insn->src(0) = src; + insn->src(1) = tmpData2; + } + // Boiler plate to initialize the selection library at c++ pre-main static SelectionLibrary *selLib = NULL; static void destroySelectionLibrary(void) { GBE_DELETE(selLib); } @@ -6399,6 +6419,101 @@ extern bool OCL_DEBUGINFO; // first defined by calling BVAR in program.cpp DECL_CTOR(WorkGroupInstruction, 1, 1); }; + /*! SubGroup instruction pattern */ + class SubGroupInstructionPattern : public SelectionPattern + { + public: + SubGroupInstructionPattern(void) : SelectionPattern(1,1) { + for (uint32_t op = 0; op < ir::OP_INVALID; ++op) + if (ir::isOpcodeFrom<ir::SubGroupInstruction>(ir::Opcode(op)) == true) + this->opcodes.push_back(ir::Opcode(op)); + } + + /* SUBGROUP OP: ALL, ANY, REDUCE, SCAN INCLUSIVE, SCAN EXCLUSIVE + * Shared algorithm with workgroup inthread */ + INLINE bool emitSGReduce(Selection::Opaque &sel, const ir::SubGroupInstruction &insn) const + { + using namespace ir; + + GBE_ASSERT(insn.getSrcNum() == 1); + + const WorkGroupOps workGroupOp = insn.getWorkGroupOpcode(); + const Type type = insn.getType(); + GenRegister dst = sel.selReg(insn.getDst(0), type); + GenRegister src = sel.selReg(insn.getSrc(0), type); + GenRegister tmpData1 = GenRegister::retype(sel.selReg(sel.reg(FAMILY_QWORD)), type); + GenRegister tmpData2 = GenRegister::retype(sel.selReg(sel.reg(FAMILY_QWORD)), type); + + /* Perform workgroup op */ + sel.SUBGROUP_OP(workGroupOp, dst, src, tmpData1, tmpData2); + + return true; + } + + /* SUBROUP OP: BROADCAST + * Shared algorithm with simd shuffle */ + INLINE bool emitSGBroadcast(Selection::Opaque &sel, const 
ir::SubGroupInstruction &insn, SelectionDAG &dag) const + { + using namespace ir; + + GBE_ASSERT(insn.getSrcNum() == 2); + + const Type type = insn.getType(); + const GenRegister src0 = sel.selReg(insn.getSrc(0), type); + const GenRegister dst = sel.selReg(insn.getDst(0), type); + GenRegister src1; + + SelectionDAG *dag0 = dag.child[0]; + SelectionDAG *dag1 = dag.child[1]; + if (dag1 != NULL && dag1->insn.getOpcode() == OP_LOADI && canGetRegisterFromImmediate(dag1->insn)) { + const auto &childInsn = cast<LoadImmInstruction>(dag1->insn); + src1 = getRegisterFromImmediate(childInsn.getImmediate(), TYPE_U32); + if (dag0) dag0->isRoot = 1; + } else { + markAllChildren(dag); + src1 = sel.selReg(insn.getSrc(1), TYPE_U32); + } + + sel.push(); { + if (src1.file == GEN_IMMEDIATE_VALUE) { + uint32_t offset = src1.value.ud % sel.curr.execWidth; + GenRegister reg = GenRegister::subphysicaloffset(src0, offset); + reg.vstride = GEN_VERTICAL_STRIDE_0; + reg.hstride = GEN_HORIZONTAL_STRIDE_0; + reg.width = GEN_WIDTH_1; + sel.MOV(dst, reg); + } else { + GenRegister shiftL = sel.selReg(sel.reg(FAMILY_DWORD), TYPE_U32); + sel.SHL(shiftL, src1, GenRegister::immud(0x2)); + sel.SIMD_SHUFFLE(dst, src0, shiftL); + } + } sel.pop(); + + return true; + } + + INLINE bool emit(Selection::Opaque &sel, SelectionDAG &dag) const + { + using namespace ir; + const ir::SubGroupInstruction &insn = cast<SubGroupInstruction>(dag.insn); + const WorkGroupOps workGroupOp = insn.getWorkGroupOpcode(); + + if (workGroupOp == WORKGROUP_OP_BROADCAST){ + return emitSGBroadcast(sel, insn, dag); + } + else if (workGroupOp >= WORKGROUP_OP_ANY && workGroupOp <= WORKGROUP_OP_EXCLUSIVE_MAX){ + if(emitSGReduce(sel, insn)) + markAllChildren(dag); + else + return false; + } + else + GBE_ASSERT(0); + + return true; + } + }; + /*! 
Sort patterns */ INLINE bool cmp(const SelectionPattern *p0, const SelectionPattern *p1) { if (p0->insnNum != p1->insnNum) @@ -6436,6 +6551,7 @@ extern bool OCL_DEBUGINFO; // first defined by calling BVAR in program.cpp this->insert<CalcTimestampInstructionPattern>(); this->insert<StoreProfilingInstructionPattern>(); this->insert<WorkGroupInstructionPattern>(); + this->insert<SubGroupInstructionPattern>(); this->insert<NullaryInstructionPattern>(); this->insert<WaitInstructionPattern>(); this->insert<PrintfInstructionPattern>(); diff --git a/backend/src/backend/gen_insn_selection.hxx b/backend/src/backend/gen_insn_selection.hxx index 4352490..0e11f9f 100644 --- a/backend/src/backend/gen_insn_selection.hxx +++ b/backend/src/backend/gen_insn_selection.hxx @@ -94,4 +94,5 @@ DECL_SELECTION_IR(F64DIV, F64DIVInstruction) DECL_SELECTION_IR(CALC_TIMESTAMP, CalcTimestampInstruction) DECL_SELECTION_IR(STORE_PROFILING, StoreProfilingInstruction) DECL_SELECTION_IR(WORKGROUP_OP, WorkGroupOpInstruction) +DECL_SELECTION_IR(SUBGROUP_OP, SubGroupOpInstruction) DECL_SELECTION_IR(PRINTF, PrintfInstruction) diff --git a/backend/src/ir/instruction.cpp b/backend/src/ir/instruction.cpp index d9051ab..47606b2 100644 --- a/backend/src/ir/instruction.cpp +++ b/backend/src/ir/instruction.cpp @@ -994,6 +994,33 @@ namespace ir { Register dst[1]; }; + class ALIGNED_INSTRUCTION SubGroupInstruction : + public BasePolicy, + public TupleSrcPolicy<SubGroupInstruction>, + public NDstPolicy<SubGroupInstruction, 1> + { + public: + INLINE SubGroupInstruction(WorkGroupOps opcode, Register dst, + Tuple srcTuple, uint8_t srcNum, Type type) { + this->opcode = OP_SUBGROUP; + this->workGroupOp = opcode; + this->type = type; + this->dst[0] = dst; + this->src = srcTuple; + this->srcNum = srcNum; + } + INLINE Type getType(void) const { return this->type; } + INLINE bool wellFormed(const Function &fn, std::string &whyNot) const; + INLINE void out(std::ostream &out, const Function &fn) const; + INLINE WorkGroupOps 
getWorkGroupOpcode(void) const { return this->workGroupOp; } + + WorkGroupOps workGroupOp:5; + uint32_t srcNum:3; //!< Source Number + Type type; //!< Type of the instruction + Tuple src; + Register dst[1]; + }; + class ALIGNED_INSTRUCTION PrintfInstruction : public BasePolicy, public TupleSrcPolicy<PrintfInstruction>, @@ -1505,6 +1532,52 @@ namespace ir { return true; } + INLINE bool SubGroupInstruction::wellFormed(const Function &fn, std::string &whyNot) const { + const RegisterFamily family = getFamily(this->type); + + if (UNLIKELY(checkSpecialRegForWrite(dst[0], fn, whyNot) == false)) + return false; + if (UNLIKELY(checkRegisterData(family, dst[0], fn, whyNot) == false)) + return false; + + switch (this->workGroupOp) { + case WORKGROUP_OP_ANY: + case WORKGROUP_OP_ALL: + case WORKGROUP_OP_REDUCE_ADD: + case WORKGROUP_OP_REDUCE_MIN: + case WORKGROUP_OP_REDUCE_MAX: + case WORKGROUP_OP_INCLUSIVE_ADD: + case WORKGROUP_OP_INCLUSIVE_MIN: + case WORKGROUP_OP_INCLUSIVE_MAX: + case WORKGROUP_OP_EXCLUSIVE_ADD: + case WORKGROUP_OP_EXCLUSIVE_MIN: + case WORKGROUP_OP_EXCLUSIVE_MAX: + if (this->srcNum != 1) { + whyNot = "Wrong number of source."; + return false; + } + break; + case WORKGROUP_OP_BROADCAST: + if (this->srcNum != 2) { + whyNot = "Wrong number of source."; + return false; + } else { + const RegisterFamily fam = fn.getPointerFamily(); + for (uint32_t srcID = 1; srcID < this->srcNum; ++srcID) { + const Register regID = fn.getRegister(src, srcID); + if (UNLIKELY(checkRegisterData(fam, regID, fn, whyNot) == false)) + return false; + } + } + break; + default: + whyNot = "No such sub group function."; + return false; + } + + return true; + } + INLINE bool PrintfInstruction::wellFormed(const Function &fn, std::string &whyNot) const { return true; } @@ -1739,6 +1812,67 @@ namespace ir { out << "TheadID Map at SLM: " << this->slmAddr; } + INLINE void SubGroupInstruction::out(std::ostream &out, const Function &fn) const { + this->outOpcode(out); + + switch 
(this->workGroupOp) { + case WORKGROUP_OP_ANY: + out << "_" << "ANY"; + break; + case WORKGROUP_OP_ALL: + out << "_" << "ALL"; + break; + case WORKGROUP_OP_REDUCE_ADD: + out << "_" << "REDUCE_ADD"; + break; + case WORKGROUP_OP_REDUCE_MIN: + out << "_" << "REDUCE_MIN"; + break; + case WORKGROUP_OP_REDUCE_MAX: + out << "_" << "REDUCE_MAX"; + break; + case WORKGROUP_OP_INCLUSIVE_ADD: + out << "_" << "INCLUSIVE_ADD"; + break; + case WORKGROUP_OP_INCLUSIVE_MIN: + out << "_" << "INCLUSIVE_MIN"; + break; + case WORKGROUP_OP_INCLUSIVE_MAX: + out << "_" << "INCLUSIVE_MAX"; + break; + case WORKGROUP_OP_EXCLUSIVE_ADD: + out << "_" << "EXCLUSIVE_ADD"; + break; + case WORKGROUP_OP_EXCLUSIVE_MIN: + out << "_" << "EXCLUSIVE_MIN"; + break; + case WORKGROUP_OP_EXCLUSIVE_MAX: + out << "_" << "EXCLUSIVE_MAX"; + break; + case WORKGROUP_OP_BROADCAST: + out << "_" << "BROADCAST"; + break; + default: + GBE_ASSERT(0); + } + + out << " %" << this->getDst(fn, 0); + out << " %" << this->getSrc(fn, 0); + + if (this->workGroupOp == WORKGROUP_OP_BROADCAST) { + do { + int localN = srcNum - 1; + GBE_ASSERT(localN); + out << " Local ID:"; + out << " %" << this->getSrc(fn, 1); + localN--; + if (!localN) + break; + } while(0); + } + + } + INLINE void PrintfInstruction::out(std::ostream &out, const Function &fn) const { this->outOpcode(out); } @@ -1903,6 +2037,10 @@ START_INTROSPECTION(WorkGroupInstruction) #include "ir/instruction.hxx" END_INTROSPECTION(WorkGroupInstruction) +START_INTROSPECTION(SubGroupInstruction) +#include "ir/instruction.hxx" +END_INTROSPECTION(SubGroupInstruction) + START_INTROSPECTION(PrintfInstruction) #include "ir/instruction.hxx" END_INTROSPECTION(PrintfInstruction) @@ -2117,6 +2255,8 @@ DECL_MEM_FN(StoreProfilingInstruction, uint32_t, getBTI(void), getBTI()) DECL_MEM_FN(WorkGroupInstruction, Type, getType(void), getType()) DECL_MEM_FN(WorkGroupInstruction, WorkGroupOps, getWorkGroupOpcode(void), getWorkGroupOpcode()) DECL_MEM_FN(WorkGroupInstruction, uint32_t, 
getSlmAddr(void), getSlmAddr()) +DECL_MEM_FN(SubGroupInstruction, Type, getType(void), getType()) +DECL_MEM_FN(SubGroupInstruction, WorkGroupOps, getWorkGroupOpcode(void), getWorkGroupOpcode()) DECL_MEM_FN(PrintfInstruction, uint32_t, getNum(void), getNum()) DECL_MEM_FN(PrintfInstruction, uint32_t, getBti(void), getBti()) DECL_MEM_FN(PrintfInstruction, Type, getType(const Function& fn, uint32_t ID), getType(fn, ID)) @@ -2418,6 +2558,10 @@ DECL_MEM_FN(MemInstruction, void, setBtiReg(Register reg), setBtiReg(reg)) return internal::WorkGroupInstruction(opcode, slmAddr, dst, srcTuple, srcNum, type).convert(); } + Instruction SUBGROUP(WorkGroupOps opcode, Register dst, Tuple srcTuple, uint8_t srcNum, Type type) { + return internal::SubGroupInstruction(opcode, dst, srcTuple, srcNum, type).convert(); + } + Instruction PRINTF(Register dst, Tuple srcTuple, Tuple typeTuple, uint8_t srcNum, uint8_t bti, uint16_t num) { return internal::PrintfInstruction(dst, srcTuple, typeTuple, srcNum, bti, num).convert(); } diff --git a/backend/src/ir/instruction.hpp b/backend/src/ir/instruction.hpp index bbdef91..a605f45 100644 --- a/backend/src/ir/instruction.hpp +++ b/backend/src/ir/instruction.hpp @@ -611,6 +611,15 @@ namespace ir { uint32_t getSlmAddr(void) const; }; + /*! Related to Sub Group. */ + class SubGroupInstruction : public Instruction { + public: + /*! Return true if the given instruction is an instance of this class */ + static bool isClassOf(const Instruction &insn); + Type getType(void) const; + WorkGroupOps getWorkGroupOpcode(void) const; + }; + /*! Printf instruction. */ class PrintfInstruction : public Instruction { public: @@ -850,6 +859,8 @@ namespace ir { /*! work group */ Instruction WORKGROUP(WorkGroupOps opcode, uint32_t slmAddr, Register dst, Tuple srcTuple, uint8_t srcNum, Type type); + /*! sub group */ + Instruction SUBGROUP(WorkGroupOps opcode, Register dst, Tuple srcTuple, uint8_t srcNum, Type type); /*! 
printf */ Instruction PRINTF(Register dst, Tuple srcTuple, Tuple typeTuple, uint8_t srcNum, uint8_t bti, uint16_t num); } /* namespace ir */ diff --git a/backend/src/ir/instruction.hxx b/backend/src/ir/instruction.hxx index 651ed64..57e13eb 100644 --- a/backend/src/ir/instruction.hxx +++ b/backend/src/ir/instruction.hxx @@ -112,4 +112,5 @@ DECL_INSN(CALC_TIMESTAMP, CalcTimestampInstruction) DECL_INSN(STORE_PROFILING, StoreProfilingInstruction) DECL_INSN(WAIT, WaitInstruction) DECL_INSN(WORKGROUP, WorkGroupInstruction) +DECL_INSN(SUBGROUP, SubGroupInstruction) DECL_INSN(PRINTF, PrintfInstruction) diff --git a/backend/src/libocl/tmpl/ocl_simd.tmpl.cl b/backend/src/libocl/tmpl/ocl_simd.tmpl.cl index c2e22c1..a25dcef 100644 --- a/backend/src/libocl/tmpl/ocl_simd.tmpl.cl +++ b/backend/src/libocl/tmpl/ocl_simd.tmpl.cl @@ -35,3 +35,101 @@ uint get_sub_group_size(void) else return get_max_sub_group_size(); } + +/* broadcast */ +#define BROADCAST_IMPL(GEN_TYPE) \ + OVERLOADABLE GEN_TYPE __gen_ocl_sub_group_broadcast(GEN_TYPE a, size_t local_id); \ + OVERLOADABLE GEN_TYPE sub_group_broadcast(GEN_TYPE a, size_t local_id) { \ + return __gen_ocl_sub_group_broadcast(a, local_id); \ + } \ + OVERLOADABLE GEN_TYPE __gen_ocl_sub_group_broadcast(GEN_TYPE a, size_t local_id_x, size_t local_id_y); \ + OVERLOADABLE GEN_TYPE sub_group_broadcast(GEN_TYPE a, size_t local_id_x, size_t local_id_y) { \ + return __gen_ocl_sub_group_broadcast(a, local_id_x, local_id_y); \ + } \ + OVERLOADABLE GEN_TYPE __gen_ocl_sub_group_broadcast(GEN_TYPE a, size_t local_id_x, size_t local_id_y, size_t local_id_z); \ + OVERLOADABLE GEN_TYPE sub_group_broadcast(GEN_TYPE a, size_t local_id_x, size_t local_id_y, size_t local_id_z) { \ + return __gen_ocl_sub_group_broadcast(a, local_id_x, local_id_y, local_id_z); \ + } + +BROADCAST_IMPL(int) +BROADCAST_IMPL(uint) +BROADCAST_IMPL(long) +BROADCAST_IMPL(ulong) +BROADCAST_IMPL(float) +BROADCAST_IMPL(double) +#undef BROADCAST_IMPL + + +#define RANGE_OP(RANGE, OP, 
GEN_TYPE, SIGN) \ + OVERLOADABLE GEN_TYPE __gen_ocl_sub_group_##RANGE##_##OP(bool sign, GEN_TYPE x); \ + OVERLOADABLE GEN_TYPE sub_group_##RANGE##_##OP(GEN_TYPE x) { \ + return __gen_ocl_sub_group_##RANGE##_##OP(SIGN, x); \ + } + +/* reduce add */ +RANGE_OP(reduce, add, int, true) +RANGE_OP(reduce, add, uint, false) +RANGE_OP(reduce, add, long, true) +RANGE_OP(reduce, add, ulong, false) +RANGE_OP(reduce, add, float, true) +RANGE_OP(reduce, add, double, true) +/* reduce min */ +RANGE_OP(reduce, min, int, true) +RANGE_OP(reduce, min, uint, false) +RANGE_OP(reduce, min, long, true) +RANGE_OP(reduce, min, ulong, false) +RANGE_OP(reduce, min, float, true) +RANGE_OP(reduce, min, double, true) +/* reduce max */ +RANGE_OP(reduce, max, int, true) +RANGE_OP(reduce, max, uint, false) +RANGE_OP(reduce, max, long, true) +RANGE_OP(reduce, max, ulong, false) +RANGE_OP(reduce, max, float, true) +RANGE_OP(reduce, max, double, true) + +/* scan_inclusive add */ +RANGE_OP(scan_inclusive, add, int, true) +RANGE_OP(scan_inclusive, add, uint, false) +RANGE_OP(scan_inclusive, add, long, true) +RANGE_OP(scan_inclusive, add, ulong, false) +RANGE_OP(scan_inclusive, add, float, true) +RANGE_OP(scan_inclusive, add, double, true) +/* scan_inclusive min */ +RANGE_OP(scan_inclusive, min, int, true) +RANGE_OP(scan_inclusive, min, uint, false) +RANGE_OP(scan_inclusive, min, long, true) +RANGE_OP(scan_inclusive, min, ulong, false) +RANGE_OP(scan_inclusive, min, float, true) +RANGE_OP(scan_inclusive, min, double, true) +/* scan_inclusive max */ +RANGE_OP(scan_inclusive, max, int, true) +RANGE_OP(scan_inclusive, max, uint, false) +RANGE_OP(scan_inclusive, max, long, true) +RANGE_OP(scan_inclusive, max, ulong, false) +RANGE_OP(scan_inclusive, max, float, true) +RANGE_OP(scan_inclusive, max, double, true) + +/* scan_exclusive add */ +RANGE_OP(scan_exclusive, add, int, true) +RANGE_OP(scan_exclusive, add, uint, false) +RANGE_OP(scan_exclusive, add, long, true) +RANGE_OP(scan_exclusive, add, ulong, false) 
+RANGE_OP(scan_exclusive, add, float, true) +RANGE_OP(scan_exclusive, add, double, true) +/* scan_exclusive min */ +RANGE_OP(scan_exclusive, min, int, true) +RANGE_OP(scan_exclusive, min, uint, false) +RANGE_OP(scan_exclusive, min, long, true) +RANGE_OP(scan_exclusive, min, ulong, false) +RANGE_OP(scan_exclusive, min, float, true) +RANGE_OP(scan_exclusive, min, double, true) +/* scan_exclusive max */ +RANGE_OP(scan_exclusive, max, int, true) +RANGE_OP(scan_exclusive, max, uint, false) +RANGE_OP(scan_exclusive, max, long, true) +RANGE_OP(scan_exclusive, max, ulong, false) +RANGE_OP(scan_exclusive, max, float, true) +RANGE_OP(scan_exclusive, max, double, true) + +#undef RANGE_OP diff --git a/backend/src/libocl/tmpl/ocl_simd.tmpl.h b/backend/src/libocl/tmpl/ocl_simd.tmpl.h index 96337cd..355ee30 100644 --- a/backend/src/libocl/tmpl/ocl_simd.tmpl.h +++ b/backend/src/libocl/tmpl/ocl_simd.tmpl.h @@ -34,6 +34,101 @@ uint get_num_sub_groups(void); uint get_sub_group_id(void); uint get_sub_group_local_id(void); +/* broadcast */ +OVERLOADABLE int sub_group_broadcast(int a, size_t local_id); +OVERLOADABLE uint sub_group_broadcast(uint a, size_t local_id); +OVERLOADABLE long sub_group_broadcast(long a, size_t local_id); +OVERLOADABLE ulong sub_group_broadcast(ulong a, size_t local_id); +OVERLOADABLE float sub_group_broadcast(float a, size_t local_id); +OVERLOADABLE double sub_group_broadcast(double a, size_t local_id); + +OVERLOADABLE int sub_group_broadcast(int a, size_t local_id_x, size_t local_id_y); +OVERLOADABLE uint sub_group_broadcast(uint a, size_t local_id_x, size_t local_id_y); +OVERLOADABLE long sub_group_broadcast(long a, size_t local_id_x, size_t local_id_y); +OVERLOADABLE ulong sub_group_broadcast(ulong a, size_t local_id_x, size_t local_id_y); +OVERLOADABLE float sub_group_broadcast(float a, size_t local_id_x, size_t local_id_y); +OVERLOADABLE double sub_group_broadcast(double a, size_t local_id_x, size_t local_id_y); + +OVERLOADABLE int sub_group_broadcast(int 
a, size_t local_id_x, size_t local_id_y, size_t local_id_z); +OVERLOADABLE uint sub_group_broadcast(uint a, size_t local_id_x, size_t local_id_y, size_t local_id_z); +OVERLOADABLE long sub_group_broadcast(long a, size_t local_id_x, size_t local_id_y, size_t local_id_z); +OVERLOADABLE ulong sub_group_broadcast(ulong a, size_t local_id_x, size_t local_id_y, size_t local_id_z); +OVERLOADABLE float sub_group_broadcast(float a, size_t local_id_x, size_t local_id_y, size_t local_id_z); +OVERLOADABLE double sub_group_broadcast(double a, size_t local_id_x, size_t local_id_y, size_t local_id_z); + +/* reduce add */ +OVERLOADABLE int sub_group_reduce_add(int x); +OVERLOADABLE uint sub_group_reduce_add(uint x); +OVERLOADABLE long sub_group_reduce_add(long x); +OVERLOADABLE ulong sub_group_reduce_add(ulong x); +OVERLOADABLE float sub_group_reduce_add(float x); +OVERLOADABLE double sub_group_reduce_add(double x); + +/* reduce min */ +OVERLOADABLE int sub_group_reduce_min(int x); +OVERLOADABLE uint sub_group_reduce_min(uint x); +OVERLOADABLE long sub_group_reduce_min(long x); +OVERLOADABLE ulong sub_group_reduce_min(ulong x); +OVERLOADABLE float sub_group_reduce_min(float x); +OVERLOADABLE double sub_group_reduce_min(double x); + +/* reduce max */ +OVERLOADABLE int sub_group_reduce_max(int x); +OVERLOADABLE uint sub_group_reduce_max(uint x); +OVERLOADABLE long sub_group_reduce_max(long x); +OVERLOADABLE ulong sub_group_reduce_max(ulong x); +OVERLOADABLE float sub_group_reduce_max(float x); +OVERLOADABLE double sub_group_reduce_max(double x); + +/* scan_inclusive add */ +OVERLOADABLE int sub_group_scan_inclusive_add(int x); +OVERLOADABLE uint sub_group_scan_inclusive_add(uint x); +OVERLOADABLE long sub_group_scan_inclusive_add(long x); +OVERLOADABLE ulong sub_group_scan_inclusive_add(ulong x); +OVERLOADABLE float sub_group_scan_inclusive_add(float x); +OVERLOADABLE double sub_group_scan_inclusive_add(double x); + +/* scan_inclusive min */ +OVERLOADABLE int 
sub_group_scan_inclusive_min(int x); +OVERLOADABLE uint sub_group_scan_inclusive_min(uint x); +OVERLOADABLE long sub_group_scan_inclusive_min(long x); +OVERLOADABLE ulong sub_group_scan_inclusive_min(ulong x); +OVERLOADABLE float sub_group_scan_inclusive_min(float x); +OVERLOADABLE double sub_group_scan_inclusive_min(double x); + +/* scan_inclusive max */ +OVERLOADABLE int sub_group_scan_inclusive_max(int x); +OVERLOADABLE uint sub_group_scan_inclusive_max(uint x); +OVERLOADABLE long sub_group_scan_inclusive_max(long x); +OVERLOADABLE ulong sub_group_scan_inclusive_max(ulong x); +OVERLOADABLE float sub_group_scan_inclusive_max(float x); +OVERLOADABLE double sub_group_scan_inclusive_max(double x); + +/* scan_exclusive add */ +OVERLOADABLE int sub_group_scan_exclusive_add(int x); +OVERLOADABLE uint sub_group_scan_exclusive_add(uint x); +OVERLOADABLE long sub_group_scan_exclusive_add(long x); +OVERLOADABLE ulong sub_group_scan_exclusive_add(ulong x); +OVERLOADABLE float sub_group_scan_exclusive_add(float x); +OVERLOADABLE double sub_group_scan_exclusive_add(double x); + +/* scan_exclusive min */ +OVERLOADABLE int sub_group_scan_exclusive_min(int x); +OVERLOADABLE uint sub_group_scan_exclusive_min(uint x); +OVERLOADABLE long sub_group_scan_exclusive_min(long x); +OVERLOADABLE ulong sub_group_scan_exclusive_min(ulong x); +OVERLOADABLE float sub_group_scan_exclusive_min(float x); +OVERLOADABLE double sub_group_scan_exclusive_min(double x); + +/* scan_exclusive max */ +OVERLOADABLE int sub_group_scan_exclusive_max(int x); +OVERLOADABLE uint sub_group_scan_exclusive_max(uint x); +OVERLOADABLE long sub_group_scan_exclusive_max(long x); +OVERLOADABLE ulong sub_group_scan_exclusive_max(ulong x); +OVERLOADABLE float sub_group_scan_exclusive_max(float x); +OVERLOADABLE double sub_group_scan_exclusive_max(double x); + +/* shuffle */ OVERLOADABLE float intel_sub_group_shuffle(float x, uint c); OVERLOADABLE int intel_sub_group_shuffle(int x, uint c); OVERLOADABLE uint 
intel_sub_group_shuffle(uint x, uint c); diff --git a/backend/src/llvm/llvm_gen_backend.cpp b/backend/src/llvm/llvm_gen_backend.cpp index f5228d2..a091d7c 100644 --- a/backend/src/llvm/llvm_gen_backend.cpp +++ b/backend/src/llvm/llvm_gen_backend.cpp @@ -695,6 +695,8 @@ namespace gbe void emitAtomicInst(CallInst &I, CallSite &CS, ir::AtomicOps opcode); // Emit workgroup instructions void emitWorkGroupInst(CallInst &I, CallSite &CS, ir::WorkGroupOps opcode); + // Emit subgroup instructions + void emitSubGroupInst(CallInst &I, CallSite &CS, ir::WorkGroupOps opcode); uint8_t appendSampler(CallSite::arg_iterator AI); uint8_t getImageID(CallInst &I); @@ -3729,6 +3731,16 @@ namespace gbe case GEN_OCL_WORK_GROUP_SCAN_INCLUSIVE_ADD: case GEN_OCL_WORK_GROUP_SCAN_INCLUSIVE_MAX: case GEN_OCL_WORK_GROUP_SCAN_INCLUSIVE_MIN: + case GEN_OCL_SUB_GROUP_BROADCAST: + case GEN_OCL_SUB_GROUP_REDUCE_ADD: + case GEN_OCL_SUB_GROUP_REDUCE_MAX: + case GEN_OCL_SUB_GROUP_REDUCE_MIN: + case GEN_OCL_SUB_GROUP_SCAN_EXCLUSIVE_ADD: + case GEN_OCL_SUB_GROUP_SCAN_EXCLUSIVE_MAX: + case GEN_OCL_SUB_GROUP_SCAN_EXCLUSIVE_MIN: + case GEN_OCL_SUB_GROUP_SCAN_INCLUSIVE_ADD: + case GEN_OCL_SUB_GROUP_SCAN_INCLUSIVE_MAX: + case GEN_OCL_SUB_GROUP_SCAN_INCLUSIVE_MIN: case GEN_OCL_LRP: this->newRegister(&I); break; @@ -3898,6 +3910,48 @@ namespace gbe GBE_ASSERT(AI == AE); } + void GenWriter::emitSubGroupInst(CallInst &I, CallSite &CS, ir::WorkGroupOps opcode) { + CallSite::arg_iterator AI = CS.arg_begin(); + CallSite::arg_iterator AE = CS.arg_end(); + GBE_ASSERT(AI != AE); + + if (opcode == ir::WORKGROUP_OP_ALL || opcode == ir::WORKGROUP_OP_ANY) { + GBE_ASSERT(getType(ctx, (*AI)->getType()) == ir::TYPE_S32); + ir::Register src[3]; + src[0] = this->getRegister(*(AI++)); + const ir::Tuple srcTuple = ctx.arrayTuple(&src[0], 1); + ctx.SUBGROUP(opcode, getRegister(&I), srcTuple, 1, ir::TYPE_S32); + } else if (opcode == ir::WORKGROUP_OP_BROADCAST) { + int argNum = CS.arg_size(); + std::vector<ir::Register> src(argNum); 
+ for (int i = 0; i < argNum; i++) { + src[i] = this->getRegister(*(AI++)); + } + const ir::Tuple srcTuple = ctx.arrayTuple(&src[0], argNum); + ctx.SUBGROUP(ir::WORKGROUP_OP_BROADCAST, getRegister(&I), srcTuple, argNum, + getType(ctx, (*CS.arg_begin())->getType())); + } else { + ConstantInt *sign = dyn_cast<ConstantInt>(AI); + GBE_ASSERT(sign); + bool isSign = sign->getZExtValue(); + AI++; + ir::Type ty; + if (isSign) { + ty = getType(ctx, (*AI)->getType()); + + } else { + ty = getUnsignedType(ctx, (*AI)->getType()); + } + + ir::Register src[3]; + src[0] = this->getRegister(*(AI++)); + const ir::Tuple srcTuple = ctx.arrayTuple(&src[0], 1); + ctx.SUBGROUP(opcode, getRegister(&I), srcTuple, 1, ty); + } + + GBE_ASSERT(AI == AE); + } + /* append a new sampler. should be called before any reference to * a sampler_t value. */ uint8_t GenWriter::appendSampler(CallSite::arg_iterator AI) { @@ -4690,6 +4744,26 @@ namespace gbe this->emitWorkGroupInst(I, CS, ir::WORKGROUP_OP_INCLUSIVE_MAX); break; case GEN_OCL_WORK_GROUP_SCAN_INCLUSIVE_MIN: this->emitWorkGroupInst(I, CS, ir::WORKGROUP_OP_INCLUSIVE_MIN); break; + case GEN_OCL_SUB_GROUP_BROADCAST: + this->emitSubGroupInst(I, CS, ir::WORKGROUP_OP_BROADCAST); break; + case GEN_OCL_SUB_GROUP_REDUCE_ADD: + this->emitSubGroupInst(I, CS, ir::WORKGROUP_OP_REDUCE_ADD); break; + case GEN_OCL_SUB_GROUP_REDUCE_MAX: + this->emitSubGroupInst(I, CS, ir::WORKGROUP_OP_REDUCE_MAX); break; + case GEN_OCL_SUB_GROUP_REDUCE_MIN: + this->emitSubGroupInst(I, CS, ir::WORKGROUP_OP_REDUCE_MIN); break; + case GEN_OCL_SUB_GROUP_SCAN_EXCLUSIVE_ADD: + this->emitSubGroupInst(I, CS, ir::WORKGROUP_OP_EXCLUSIVE_ADD); break; + case GEN_OCL_SUB_GROUP_SCAN_EXCLUSIVE_MAX: + this->emitSubGroupInst(I, CS, ir::WORKGROUP_OP_EXCLUSIVE_MAX); break; + case GEN_OCL_SUB_GROUP_SCAN_EXCLUSIVE_MIN: + this->emitSubGroupInst(I, CS, ir::WORKGROUP_OP_EXCLUSIVE_MIN); break; + case GEN_OCL_SUB_GROUP_SCAN_INCLUSIVE_ADD: + this->emitSubGroupInst(I, CS, ir::WORKGROUP_OP_INCLUSIVE_ADD); 
break; + case GEN_OCL_SUB_GROUP_SCAN_INCLUSIVE_MAX: + this->emitSubGroupInst(I, CS, ir::WORKGROUP_OP_INCLUSIVE_MAX); break; + case GEN_OCL_SUB_GROUP_SCAN_INCLUSIVE_MIN: + this->emitSubGroupInst(I, CS, ir::WORKGROUP_OP_INCLUSIVE_MIN); break; case GEN_OCL_LRP: { const ir::Register dst = this->getRegister(&I); diff --git a/backend/src/llvm/llvm_gen_ocl_function.hxx b/backend/src/llvm/llvm_gen_ocl_function.hxx index cff4d61..213ead0 100644 --- a/backend/src/llvm/llvm_gen_ocl_function.hxx +++ b/backend/src/llvm/llvm_gen_ocl_function.hxx @@ -202,5 +202,20 @@ DECL_LLVM_GEN_FUNCTION(WORK_GROUP_SCAN_INCLUSIVE_MIN, __gen_ocl_work_group_scan_ DECL_LLVM_GEN_FUNCTION(WORK_GROUP_ALL, __gen_ocl_work_group_all) DECL_LLVM_GEN_FUNCTION(WORK_GROUP_ANY, __gen_ocl_work_group_any) +// sub group function +DECL_LLVM_GEN_FUNCTION(SUB_GROUP_BROADCAST, __gen_ocl_sub_group_broadcast) + +DECL_LLVM_GEN_FUNCTION(SUB_GROUP_REDUCE_ADD, __gen_ocl_sub_group_reduce_add) +DECL_LLVM_GEN_FUNCTION(SUB_GROUP_REDUCE_MAX, __gen_ocl_sub_group_reduce_max) +DECL_LLVM_GEN_FUNCTION(SUB_GROUP_REDUCE_MIN, __gen_ocl_sub_group_reduce_min) + +DECL_LLVM_GEN_FUNCTION(SUB_GROUP_SCAN_EXCLUSIVE_ADD, __gen_ocl_sub_group_scan_exclusive_add) +DECL_LLVM_GEN_FUNCTION(SUB_GROUP_SCAN_EXCLUSIVE_MAX, __gen_ocl_sub_group_scan_exclusive_max) +DECL_LLVM_GEN_FUNCTION(SUB_GROUP_SCAN_EXCLUSIVE_MIN, __gen_ocl_sub_group_scan_exclusive_min) + +DECL_LLVM_GEN_FUNCTION(SUB_GROUP_SCAN_INCLUSIVE_ADD, __gen_ocl_sub_group_scan_inclusive_add) +DECL_LLVM_GEN_FUNCTION(SUB_GROUP_SCAN_INCLUSIVE_MAX, __gen_ocl_sub_group_scan_inclusive_max) +DECL_LLVM_GEN_FUNCTION(SUB_GROUP_SCAN_INCLUSIVE_MIN, __gen_ocl_sub_group_scan_inclusive_min) + // common function DECL_LLVM_GEN_FUNCTION(LRP, __gen_ocl_lrp) -- 2.7.4 _______________________________________________ Beignet mailing list Beignet@lists.freedesktop.org https://lists.freedesktop.org/mailman/listinfo/beignet