Unsized conversion opcodes require special handling in opt_algebraic because they follow different bit-size rules than regular opcodes. In particular, they introduce a new case: an opcode with variable-size inputs and outputs that share no common bit size. --- src/compiler/nir/nir_algebraic.py | 68 ++++++++++++++++++++++++------- src/compiler/nir/nir_search.c | 19 +++++---- 2 files changed, 65 insertions(+), 22 deletions(-)
diff --git a/src/compiler/nir/nir_algebraic.py b/src/compiler/nir/nir_algebraic.py index f9ee637830c..837ef114349 100644 --- a/src/compiler/nir/nir_algebraic.py +++ b/src/compiler/nir/nir_algebraic.py @@ -406,6 +406,9 @@ class BitSizeValidator(object): 'Source {0} of nir_op_{1} must be a {2}-bit value but ' \ 'the only possible matched values are {3}-bit: {4}' \ .format(i, val.opcode, src_type_bits, src_bits, str(val)) + elif nir_op.is_unsized_conversion: + # Nothing to do here + pass else: assert val.common_size == 0 or src_bits == val.common_size, \ 'Expression cannot have both {0}-bit and {1}-bit ' \ @@ -420,6 +423,10 @@ class BitSizeValidator(object): 'result was requested' \ .format(val.opcode, dst_type_bits, val.bit_size) return dst_type_bits + elif nir_op.is_unsized_conversion: + # No validation to do here. If we have a non-zero bit size, + # that's the bit size of the result of the expression; return it. + return val.bit_size else: if val.common_size != 0: assert val.bit_size == 0 or val.bit_size == val.common_size, \ @@ -432,12 +439,16 @@ class BitSizeValidator(object): def _propagate_bit_class_down(self, val, bit_class): if isinstance(val, Constant): - assert val.bit_size == 0 or val.bit_size == bit_class, \ + assert bit_class == 0 or val.bit_size == 0 or \ + val.bit_size == bit_class, \ 'Constant is {0}-bit but a {1}-bit value is required: {2}' \ .format(val.bit_size, bit_class, str(val)) elif isinstance(val, Variable): - self._set_var_bit_class(val, bit_class) + if bit_class == 0: + self._set_var_bit_class(val, self._new_class()) + else: + self._set_var_bit_class(val, bit_class) elif isinstance(val, Expression): nir_op = opcodes[val.opcode] @@ -448,14 +459,19 @@ class BitSizeValidator(object): 'expression wants a {2} value' \ .format(val.opcode, dst_type_bits, self._bit_class_to_str(bit_class)) + elif nir_op.is_unsized_conversion: + # Nothing to do here + pass else: - assert val.common_size == 0 or val.common_size == bit_class, \ + assert bit_class == 0 
or val.common_size == 0 or \ + val.common_size == bit_class, \ 'Variable-width expression produces a {0}-bit result ' \ 'based on the source widths but the parent expression ' \ 'wants a {1} value: {2}' \ .format(val.common_size, self._bit_class_to_str(bit_class), str(val)) - val.common_size = bit_class + if val.common_size == 0: + val.common_size = bit_class if val.common_size: common_class = val.common_size @@ -468,6 +484,8 @@ class BitSizeValidator(object): src_type_bits = type_bits(nir_op.input_types[i]) if src_type_bits != 0: self._propagate_bit_class_down(val.sources[i], src_type_bits) + elif nir_op.is_unsized_conversion: + self._propagate_bit_class_down(val.sources[i], 0) else: self._propagate_bit_class_down(val.sources[i], common_class) @@ -507,6 +525,9 @@ class BitSizeValidator(object): 'the constructed value would be {}: {}' \ .format(i, val.opcode, src_type_bits, self._bit_class_to_str(src_class), str(val)) + elif nir_op.is_unsized_conversion: + # Nothing to do here + pass else: assert val.common_class == 0 or src_class == val.common_class, \ 'Source {} of nir_op_{} must be a {} value based ' \ @@ -524,6 +545,8 @@ class BitSizeValidator(object): 'expression explicitly requests a {}-bit value' \ .format(val.opcode, dst_type_bits, val.bit_size) return dst_type_bits + elif nir_op.is_unsized_conversion: + return 0 else: if val.common_class != 0: assert val.bit_size == 0 or val.bit_size == val.common_class, \ @@ -538,21 +561,21 @@ class BitSizeValidator(object): return val.common_class def _validate_bit_class_down(self, val, bit_class): - # At this point, everything *must* have a bit class. Otherwise, we have - # a value we don't know how to define. 
- assert bit_class != 0, \ - 'Value cannot be constructed because no bit-size is implied '\ - '{}'.format(str(val)) - if isinstance(val, Constant): - assert val.bit_size == 0 or val.bit_size == bit_class, \ + assert bit_class != 0 or val.bit_size != 0, \ + 'Constant value {} cannot be constructed because ' \ + 'nothing provides or implies a bit size'.format(val) + + assert val.bit_size == 0 or bit_class == 0 or \ + val.bit_size == bit_class, \ 'Constant value {} explicitly requests being {}-bit but ' \ 'must be {} thanks to its consumer' \ .format(str(val), val.bit_size, self._bit_class_to_str(bit_class)) elif isinstance(val, Variable): - assert val.bit_size == 0 or val.bit_size == bit_class, \ + assert val.bit_size == 0 or bit_class == 0 or \ + val.bit_size == bit_class, \ 'Variable {} explicitly only matches {}-bit values but ' \ 'must be {} thanks to its consumer' \ .format(str(val), val.bit_size, @@ -562,24 +585,39 @@ class BitSizeValidator(object): nir_op = opcodes[val.opcode] dst_type_bits = type_bits(nir_op.output_type) if dst_type_bits != 0: - assert bit_class == dst_type_bits, \ + assert bit_class == 0 or bit_class == dst_type_bits, \ 'Result of nir_op_{} must be a {}-bit value but the ' \ 'consumer requires a {} value: {}' \ .format(val.opcode, dst_type_bits, self._bit_class_to_str(bit_class), str(val)) + elif nir_op.is_unsized_conversion: + # Nothing to do here + pass else: - assert val.common_class == 0 or val.common_class == bit_class, \ + assert bit_class != 0 or val.common_class != 0, \ + 'Expression cannot be constructed because nothing ' \ + 'provides or implies a result bit size: {}' \ + .format(str(val)) + + assert bit_class == 0 or val.common_class == 0 or \ + val.common_class == bit_class, \ 'Result of nir_op_{} must be a {} value but based on ' \ 'the sources but the consumer requires a {} value: {}' \ .format(val.opcode, self._bit_class_to_str(val.common_class), self._bit_class_to_str(bit_class), str(val)) - val.common_class = bit_class + if 
val.common_class == 0: + val.common_class = bit_class for i in range(nir_op.num_inputs): src_type_bits = type_bits(nir_op.input_types[i]) if src_type_bits != 0: self._validate_bit_class_down(val.sources[i], src_type_bits) + elif nir_op.is_unsized_conversion: + # We can't imply a bit size and nothing is coming up the chain + # so we just pass it it's own bit size. If it's zero, it will + # trigger the assert at the top of this function. + self._validate_bit_class_down(val.sources[i], 0) else: self._validate_bit_class_down(val.sources[i], val.common_class) diff --git a/src/compiler/nir/nir_search.c b/src/compiler/nir/nir_search.c index 0270302fd3d..838031e700d 100644 --- a/src/compiler/nir/nir_search.c +++ b/src/compiler/nir/nir_search.c @@ -297,6 +297,7 @@ typedef struct bitsize_tree { unsigned common_size; bool is_src_sized[4]; bool is_dest_sized; + bool is_unsized_conversion; unsigned dest_size; unsigned src_size[4]; @@ -323,6 +324,7 @@ build_bitsize_tree(void *mem_ctx, struct match_state *state, tree->is_dest_sized = !!nir_alu_type_get_type_size(info.output_type); if (tree->is_dest_sized) tree->dest_size = nir_alu_type_get_type_size(info.output_type); + tree->is_unsized_conversion = info.is_unsized_conversion; break; } @@ -360,6 +362,9 @@ bitsize_tree_filter_up(bitsize_tree *tree) if (tree->is_src_sized[i]) { assert(src_size == tree->src_size[i]); + } else if (tree->is_unsized_conversion) { + assert(src_size); + tree->src_size[i] = src_size; } else if (tree->common_size != 0) { assert(src_size == tree->common_size); tree->src_size[i] = src_size; @@ -369,11 +374,11 @@ bitsize_tree_filter_up(bitsize_tree *tree) } } - if (tree->num_srcs && tree->common_size) { - if (tree->dest_size == 0) + if (tree->common_size) { + if (!tree->is_dest_sized && !tree->is_unsized_conversion) { + assert(tree->dest_size == 0 || tree->dest_size == tree->common_size); tree->dest_size = tree->common_size; - else if (!tree->is_dest_sized) - assert(tree->dest_size == tree->common_size); + } 
for (unsigned i = 0; i < tree->num_srcs; i++) { if (!tree->src_size[i]) @@ -388,13 +393,13 @@ static void bitsize_tree_filter_down(bitsize_tree *tree, unsigned size) { if (tree->dest_size) - assert(tree->dest_size == size); + assert(size == 0 || tree->dest_size == size); else tree->dest_size = size; - if (!tree->is_dest_sized) { + if (!tree->is_dest_sized && !tree->is_unsized_conversion) { if (tree->common_size) - assert(tree->common_size == size); + assert(size == 0 || tree->common_size == size); else tree->common_size = size; } -- 2.19.1 _______________________________________________ mesa-dev mailing list mesa-dev@lists.freedesktop.org https://lists.freedesktop.org/mailman/listinfo/mesa-dev