TODO: move the lowering code somewhere else. For a few ops we do the same thing as from_tgsi, and that lowering could be moved further down so the input IR doesn't have to deal with things like SLCT and MIN/MAX with 64-bit destination types.
TODO: move DEFAULT_HANDLER into its own function
TODO: check if some code duplication can be eliminated through templates

Signed-off-by: Karol Herbst <[email protected]>
---
 .../drivers/nouveau/codegen/nv50_ir_from_nir.cpp | 524 ++++++++++++++++++++-
 1 file changed, 523 insertions(+), 1 deletion(-)

diff --git a/src/gallium/drivers/nouveau/codegen/nv50_ir_from_nir.cpp b/src/gallium/drivers/nouveau/codegen/nv50_ir_from_nir.cpp
index fe11280537..d2b2236c17 100644
--- a/src/gallium/drivers/nouveau/codegen/nv50_ir_from_nir.cpp
+++ b/src/gallium/drivers/nouveau/codegen/nv50_ir_from_nir.cpp
@@ -64,6 +64,7 @@ public:
    Value* getSrc(nir_src *, uint8_t);
    Value* getSrc(nir_ssa_def *, uint8_t);
 
+   bool visit(nir_alu_instr *);
    bool visit(nir_block *);
    bool visit(nir_cf_node *);
    bool visit(nir_function *);
@@ -86,6 +87,10 @@ public:
    std::vector<DataType> getSTypes(nir_alu_instr*);
    DataType getSType(nir_src&, bool isFloat, bool isSigned);
 
+   operation getOperation(nir_op);
+   operation preOperationNeeded(nir_op);
+   int getSubOp(nir_op);
+   CondCode getCondCode(nir_op);
 private:
    nir_shader *nir;
 
@@ -95,6 +100,7 @@ private:
    unsigned int curLoopDepth;
 
    BasicBlock *exit;
+   Value *zero; // u32 immediate 0, created in the constructor
 
    union {
       struct {
@@ -106,7 +112,10 @@ private:
 Converter::Converter(Program *prog, nir_shader *nir, nv50_ir_prog_info *info)
    : ConverterCommon(prog, info),
      nir(nir),
-     curLoopDepth(0) {}
+     curLoopDepth(0)
+{
+   zero = mkImm((uint32_t)0);
+}
 
 BasicBlock *
 Converter::convert(nir_block *block)
@@ -224,6 +233,157 @@ Converter::getSType(nir_src &src, bool isFloat, bool isSigned)
    return typeOfSize(bitSize / 8, isFloat, isSigned);
 }
 
+#define CASE_OP(ni, no) \
+   case nir_op_f ## ni : \
+   case nir_op_i ## ni : \
+      return OP_ ## no
+#define CASE_OP3(ni, no) \
+   case nir_op_f ## ni : \
+   case nir_op_i ## ni : \
+   case nir_op_u ## ni : \
+      return OP_ ## no
+#define CASE_OPIU(ni, no) \
+   case nir_op_i ## ni : \
+   case nir_op_u ## ni : \
+      return OP_ ## no
+
+operation
+Converter::getOperation(nir_op op)
+{
+   switch (op) {
+   // basic ops with float and int variants
+   CASE_OP(abs, ABS);
+   CASE_OP(add, ADD);
+   CASE_OP(and, AND);
+   CASE_OP3(div, DIV);
+   CASE_OPIU(find_msb, BFIND);
+   CASE_OP3(max, MAX);
+   CASE_OP3(min, MIN);
+   CASE_OP3(mod, MOD);
+   CASE_OP(mul, MUL);
+   CASE_OPIU(mul_high, MUL);
+   CASE_OP(neg, NEG);
+   CASE_OP(not, NOT);
+   CASE_OP(or, OR);
+   CASE_OP(eq, SET);
+   CASE_OP3(ge, SET);
+   CASE_OP3(lt, SET);
+   CASE_OP(ne, SET);
+   CASE_OPIU(shr, SHR);
+   CASE_OP(sub, SUB);
+   CASE_OP(xor, XOR);
+   case nir_op_fceil:
+      return OP_CEIL;
+   case nir_op_fcos:
+      return OP_COS;
+   case nir_op_f2f32:
+   case nir_op_f2f64:
+   case nir_op_f2i32:
+   case nir_op_f2i64:
+   case nir_op_f2u32:
+   case nir_op_f2u64:
+   case nir_op_i2f32:
+   case nir_op_i2f64:
+   case nir_op_u2f32:
+   case nir_op_u2f64:
+      return OP_CVT;
+   case nir_op_fddx:
+      return OP_DFDX;
+   case nir_op_fddy:
+      return OP_DFDY;
+   case nir_op_fexp2:
+      return OP_EX2;
+   case nir_op_ffloor:
+      return OP_FLOOR;
+   case nir_op_ffma:
+      return OP_FMA;
+   case nir_op_flog2:
+      return OP_LG2;
+   case nir_op_frcp:
+      return OP_RCP;
+   case nir_op_frsq:
+      return OP_RSQ;
+   case nir_op_fsat:
+      return OP_SAT;
+   case nir_op_ishl:
+      return OP_SHL;
+   case nir_op_fsin:
+      return OP_SIN;
+   case nir_op_fsqrt:
+      return OP_SQRT;
+   case nir_op_ftrunc:
+      return OP_TRUNC;
+   default:
+      ERROR("couldn't get operation for op %s\n", nir_op_infos[op].name);
+      assert(false);
+      return OP_NOP;
+   }
+}
+#undef CASE_OP
+#undef CASE_OP3
+#undef CASE_OPIU
+
+operation
+Converter::preOperationNeeded(nir_op op)
+{
+   switch (op) {
+   case nir_op_fcos:
+   case nir_op_fsin:
+      return OP_PRESIN;
+   default:
+      return OP_NOP;
+   }
+}
+
+#define CASE_OPIU(ni, no) \
+   case nir_op_i ## ni : \
+   case nir_op_u ## ni : \
+      return NV50_IR_SUBOP_ ## no
+int
+Converter::getSubOp(nir_op op)
+{
+   switch (op) {
+   CASE_OPIU(mul_high, MUL_HIGH);
+   default:
+      return 0;
+   }
+}
+#undef CASE_OPIU
+
+#define CASE_OP(ni) \
+   case nir_op_f ## ni : \
+   case nir_op_i ## ni
+#define CASE_OP3(ni) \
+   case nir_op_f ## ni : \
+   case nir_op_i ## ni : \
+   case nir_op_u ## ni
+CondCode
+Converter::getCondCode(nir_op op)
+{
+   switch (op) {
+   CASE_OP(eq):
+      return CC_EQ;
+   CASE_OP3(ge):
+      return CC_GE;
+   CASE_OP3(lt):
+      return CC_LT;
+   CASE_OP(ne):
+      return CC_NEU;
+   default:
+      ERROR("couldn't get CondCode for op %s\n", nir_op_infos[op].name);
+      assert(false);
+      return CC_FL;
+   }
+}
+#undef CASE_OP
+#undef CASE_OP3
+
+Converter::LValues&
+Converter::convert(nir_alu_dest *dest)
+{
+   return convert(&dest->dest);
+}
+
 Converter::LValues&
 Converter::convert(nir_dest *dest)
 {
@@ -486,6 +646,10 @@ bool
 Converter::visit(nir_instr *insn)
 {
    switch (insn->type) {
+   case nir_instr_type_alu:
+      if (!visit(nir_instr_as_alu(insn)))
+         return false;
+      break;
    case nir_instr_type_intrinsic:
       if (!visit(nir_instr_as_intrinsic(insn)))
          return false;
@@ -559,6 +723,364 @@ Converter::visit(nir_intrinsic_instr *insn)
    return true;
 }
 
+#define CASE_OP(ni) \
+   case nir_op_f ## ni : \
+   case nir_op_i ## ni
+#define CASE_OP3(ni) \
+   case nir_op_f ## ni : \
+   case nir_op_i ## ni : \
+   case nir_op_u ## ni
+#define CASE_OPIU(ni) \
+   case nir_op_i ## ni : \
+   case nir_op_u ## ni
+#define DEFAULT_CHECKS \
+      if (insn->dest.dest.ssa.num_components > 1) { \
+         ERROR("nir_alu_instr only supported with 1 component!\n"); \
+         return false; \
+      } \
+      if (insn->dest.write_mask != 1) { \
+         ERROR("nir_alu_instr only with write_mask of 1 supported!\n"); \
+         return false; \
+      }
+#define DEFAULT_HANDLER \
+   do { \
+      LValues &newDefs = convert(&insn->dest); \
+      operation preOp = preOperationNeeded(op); \
+      if (preOp != OP_NOP) { \
+         assert(info.num_inputs < 2); \
+         Instruction *i0 = mkOp(preOp, dType, getSSA(typeSizeof(dType))); /* own temp: newDefs[0] is defined only once */ \
+         Instruction *i1 = mkOp(getOperation(op), dType, newDefs[0]); \
+         if (info.num_inputs) { \
+            i0->setSrc(0, getSrc(&insn->src[0])); \
+            i1->setSrc(0, i0->getDef(0)); \
+         } \
+         i1->subOp = getSubOp(op); \
+      } else { \
+         Instruction *i = mkOp(getOperation(op), dType, newDefs[0]); \
+         for (auto s = 0u; s < info.num_inputs; ++s) { \
+            i->setSrc(s, getSrc(&insn->src[s])); \
+         } \
+         i->subOp = getSubOp(op); \
+      } \
+   } while (false)
+
+bool
+Converter::visit(nir_alu_instr *insn)
+{
+   // some helper variables
+   const nir_op op = insn->op;
+   const nir_op_info &info = nir_op_infos[op];
+   DataType dType = getDType(insn);
+   const std::vector<DataType> sTypes = getSTypes(insn);
+   // remember the current last instruction; everything emitted after it is marked below
+   Instruction *oldPos = this->bb->getExit();
+
+   switch (op) {
+   CASE_OP(abs):
+   CASE_OP(add):
+   CASE_OP(and):
+   case nir_op_fceil:
+   case nir_op_fcos:
+   case nir_op_fddx:
+   case nir_op_fddy:
+   CASE_OP3(div):
+   case nir_op_fexp2:
+   case nir_op_ffloor:
+   case nir_op_ffma:
+   case nir_op_flog2:
+   CASE_OP3(mod):
+   CASE_OP(mul):
+   CASE_OPIU(mul_high):
+   CASE_OP(neg):
+   CASE_OP(not):
+   CASE_OP(or):
+   case nir_op_frcp:
+   case nir_op_frsq:
+   case nir_op_fsat:
+   CASE_OPIU(shr):
+   case nir_op_fsin:
+   case nir_op_fsqrt:
+   CASE_OP(sub):
+   case nir_op_ftrunc:
+   case nir_op_ishl:
+   CASE_OP(xor): {
+      DEFAULT_CHECKS;
+      DEFAULT_HANDLER;
+      break;
+   }
+   CASE_OPIU(find_msb): {
+      DEFAULT_CHECKS;
+      dType = sTypes[0];
+      DEFAULT_HANDLER;
+      break;
+   }
+   CASE_OP3(max):
+   CASE_OP3(min): {
+      DEFAULT_CHECKS;
+      if (dType == TYPE_U64 || dType == TYPE_S64) {
+         operation op = getOperation(insn->op);
+         LValues &newDefs = convert(&insn->dest);
+         DataType sdType = typeOfSize(4, false, isSignedIntType(dType));
+         Value *flag = getSSA(1, FILE_FLAGS);
+
+         Value *split0[2];
+         Value *split1[2];
+         Value *merge[2];
+
+         merge[0] = getScratch();
+         merge[1] = getScratch();
+
+         mkSplit(split0, 4, getSrc(&insn->src[0]));
+         mkSplit(split1, 4, getSrc(&insn->src[1]));
+
+         Instruction *hi = mkOp2(op, sdType, merge[1], split0[1], split1[1]);
+         hi->subOp = NV50_IR_SUBOP_MINMAX_HIGH;
+         hi->setFlagsDef(1, flag);
+
+         Instruction *low = mkOp2(op, sdType, merge[0], split0[0], split1[0]);
+         low->subOp = NV50_IR_SUBOP_MINMAX_LOW;
+         low->setFlagsSrc(2, flag);
+
+         mkOp2(OP_MERGE, dType, newDefs[0], merge[0], merge[1]);
+      } else
+         DEFAULT_HANDLER;
+      break;
+   }
+   case nir_op_fround_even: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      mkCvt(OP_CVT, dType, newDefs[0], dType, getSrc(&insn->src[0]))->rnd = ROUND_NI;
+      break;
+   }
+   // convert instructions
+   CASE_OP3(2f32):
+   CASE_OP3(2f64):
+   case nir_op_f2i32:
+   case nir_op_f2i64:
+   case nir_op_f2u32:
+   case nir_op_f2u64: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      Instruction *i = mkOp1(getOperation(op), dType, newDefs[0], getSrc(&insn->src[0]));
+      if (op == nir_op_f2i32 || op == nir_op_f2i64 || op == nir_op_f2u32 || op == nir_op_f2u64)
+         i->rnd = ROUND_Z;
+      i->sType = sTypes[0];
+      break;
+   }
+   // compare instructions
+   CASE_OP(eq):
+   CASE_OP3(ge):
+   CASE_OP3(lt):
+   CASE_OP(ne): {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      Instruction *i = mkCmp(getOperation(op),
+                             getCondCode(op),
+                             dType,
+                             newDefs[0],
+                             dType,
+                             getSrc(&insn->src[0]),
+                             getSrc(&insn->src[1]));
+      if (info.num_inputs == 3)
+         i->setSrc(2, getSrc(&insn->src[2]));
+      i->sType = sTypes[0];
+      break;
+   }
+   /* those are weird ALU ops and need special handling, because
+    * 1. they are always component based
+    * 2. they basically just merge multiple values into one data type
+    */
+   CASE_OP(mov):
+   case nir_op_vec2:
+   case nir_op_vec3:
+   case nir_op_vec4: {
+      LValues &newDefs = convert(&insn->dest);
+      for (LValues::size_type c = 0u; c < newDefs.size(); ++c) {
+         mkMov(newDefs[c], info.num_inputs == 1 ? getSrc(&insn->src[0], c) : getSrc(&insn->src[c]), dType); // mov: component c of its only source
+      }
+      break;
+   }
+   // (un)pack
+   case nir_op_pack_64_2x32: {
+      LValues &newDefs = convert(&insn->dest);
+      Instruction *merge = mkOp(OP_MERGE, dType, newDefs[0]);
+      merge->setSrc(0, getSrc(&insn->src[0], 0));
+      merge->setSrc(1, getSrc(&insn->src[0], 1));
+      break;
+   }
+   case nir_op_unpack_64_2x32: {
+      LValues &newDefs = convert(&insn->dest);
+      mkOp1(OP_SPLIT, dType, newDefs[0], getSrc(&insn->src[0]))->setDef(1, newDefs[1]);
+      break;
+   }
+   // special instructions
+   CASE_OP(sign): {
+      DEFAULT_CHECKS;
+      DataType iType;
+      if (::isFloatType(dType))
+         iType = TYPE_F32;
+      else
+         iType = TYPE_S32;
+
+      LValues &newDefs = convert(&insn->dest);
+      LValue *val0 = getScratch();
+      LValue *val1 = getScratch();
+      mkCmp(OP_SET, CC_GT, iType, val0, dType, getSrc(&insn->src[0]), zero);
+      mkCmp(OP_SET, CC_LT, iType, val1, dType, getSrc(&insn->src[0]), zero);
+
+      if (dType == TYPE_F64) {
+         mkOp2(OP_SUB, iType, val0, val0, val1);
+         mkCvt(OP_CVT, TYPE_F64, newDefs[0], iType, val0);
+      } else if (dType == TYPE_S64 || dType == TYPE_U64) {
+         mkOp2(OP_SUB, iType, val0, val1, val0);
+         mkOp2(OP_SHR, iType, val1, val0, loadImm(nullptr, 31));
+         mkOp2(OP_MERGE, dType, newDefs[0], val0, val1);
+      } else if (::isFloatType(dType))
+         mkOp2(OP_SUB, iType, newDefs[0], val0, val1);
+      else
+         mkOp2(OP_SUB, iType, newDefs[0], val1, val0);
+      break;
+   }
+   case nir_op_fcsel:
+   case nir_op_bcsel: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      if (typeSizeof(dType) > 4) {
+         Value *split0[2];
+         Value *split1[2];
+         Value *merge[2];
+         merge[0] = getScratch();
+         merge[1] = getScratch();
+         mkSplit(split0, 4, getSrc(&insn->src[1]));
+         mkSplit(split1, 4, getSrc(&insn->src[2]));
+         mkCmp(OP_SLCT, CC_NE, typeOfSize(4, ::isFloatType(dType)), merge[0], sTypes[0], split0[0], split1[0], getSrc(&insn->src[0]));
+         mkCmp(OP_SLCT, CC_NE, typeOfSize(4, ::isFloatType(dType)), merge[1], sTypes[0], split0[1], split1[1], getSrc(&insn->src[0]));
+         mkOp2(OP_MERGE, dType, newDefs[0], merge[0], merge[1]);
+      } else
+         mkCmp(OP_SLCT, CC_NE, dType, newDefs[0], sTypes[0], getSrc(&insn->src[1]), getSrc(&insn->src[2]), getSrc(&insn->src[0]));
+      break;
+   }
+   CASE_OPIU(bfe):
+   CASE_OPIU(bitfield_extract): {
+      DEFAULT_CHECKS;
+      Value *tmp = getScratch();
+      LValues &newDefs = convert(&insn->dest);
+      mkOp3(OP_INSBF, dType, tmp, getSrc(&insn->src[2]), loadImm(getScratch(), 0x808), getSrc(&insn->src[1]));
+      mkOp2(OP_EXTBF, dType, newDefs[0], getSrc(&insn->src[0]), tmp);
+      break;
+   }
+   case nir_op_bfm: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      mkOp3(OP_INSBF, dType, newDefs[0], getSrc(&insn->src[0]), loadImm(getScratch(), 0x808), getSrc(&insn->src[1]));
+      break;
+   }
+   case nir_op_bfi: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      mkOp3(OP_INSBF, dType, newDefs[0], getSrc(&insn->src[1]), getSrc(&insn->src[0]), getSrc(&insn->src[2]));
+      break;
+   }
+   case nir_op_bit_count: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      mkOp2(OP_POPCNT, dType, newDefs[0], getSrc(&insn->src[0]), getSrc(&insn->src[0]));
+      break;
+   }
+   case nir_op_bitfield_reverse: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      mkOp2(OP_EXTBF, TYPE_U32, newDefs[0], getSrc(&insn->src[0]), mkImm(0x2000))->subOp = NV50_IR_SUBOP_EXTBF_REV;
+      break;
+   }
+   case nir_op_find_lsb: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      Value *tmp = getScratch();
+      mkOp2(OP_EXTBF, TYPE_U32, tmp, getSrc(&insn->src[0]), mkImm(0x2000))->subOp = NV50_IR_SUBOP_EXTBF_REV;
+      mkOp1(OP_BFIND, TYPE_U32, newDefs[0], tmp)->subOp = NV50_IR_SUBOP_BFIND_SAMT;
+      break;
+   }
+   // boolean conversions
+   case nir_op_b2f: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      mkOp2(OP_AND, TYPE_U32, newDefs[0], getSrc(&insn->src[0]), loadImm(getScratch(), 1.0f));
+      break;
+   }
+   CASE_OP(2b): {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      Value *src1;
+      if (typeSizeof(sTypes[0]) == 8) {
+         src1 = loadImm(getScratch(8), 0.0);
+      } else {
+         src1 = zero;
+      }
+      mkCmp(OP_SET, CC_NEU, TYPE_U32, newDefs[0], sTypes[0], getSrc(&insn->src[0]), src1);
+      break;
+   }
+   case nir_op_b2i: {
+      DEFAULT_CHECKS;
+      LValues &newDefs = convert(&insn->dest);
+      LValue *def;
+      if (typeSizeof(dType) == 8)
+         def = getScratch();
+      else
+         def = newDefs[0];
+
+      // bools are always 32bit values
+      mkOp2(OP_AND, TYPE_U32, def, getSrc(&insn->src[0]), loadImm(getScratch(), 1));
+      if (typeSizeof(dType) == 8)
+         mkOp2(OP_MERGE, TYPE_S64, newDefs[0], def, loadImm(getScratch(), 0));
+
+      break;
+   }
+   case nir_op_i2i32:
+   case nir_op_u2u32: {
+      DEFAULT_CHECKS;
+      Value *src[2];
+      LValues &newDefs = convert(&insn->dest);
+      mkSplit(src, 4, getSrc(&insn->src[0]));
+      mkMov(newDefs[0], src[0]);
+      break;
+   }
+   case nir_op_i2i64: {
+      LValues &newDefs = convert(&insn->dest);
+      Value *dst0 = getSrc(&insn->src[0]);
+      Value *dst1 = getScratch();
+      mkOp2(OP_SHR, TYPE_S32, dst1, dst0, loadImm(NULL, 31));
+      mkOp2(OP_MERGE, TYPE_S64, newDefs[0], dst0, dst1);
+      break;
+   }
+   case nir_op_u2u64: {
+      LValues &newDefs = convert(&insn->dest);
+      mkOp2(OP_MERGE, TYPE_U64, newDefs[0], getSrc(&insn->src[0]), loadImm(getScratch(), 0));
+      break;
+   }
+   default:
+      ERROR("unknown nir_op %s\n", info.name);
+      return false;
+   }
+
+   if (!oldPos) {
+      oldPos = this->bb->getEntry(); // block was empty: start at its first instruction, not its last
+      oldPos->precise = insn->exact;
+   }
+
+   while (oldPos->next) {
+      oldPos = oldPos->next;
+      oldPos->precise = insn->exact;
+   }
+   oldPos->saturate = insn->dest.saturate;
+
+   return true;
+}
+#undef CASE_OP
+#undef CASE_OP3
+#undef CASE_OPIU
+#undef DEFAULT_CHECKS
+
 bool
 Converter::run()
 {
-- 
2.14.3

_______________________________________________
mesa-dev mailing list
[email protected]
https://lists.freedesktop.org/mailman/listinfo/mesa-dev
