Hi All,
This is a respin of this patch using the new approach.
Thanks,
Tamar
gcc/ChangeLog:
* doc/md.texi: Document optabs.
* internal-fn.def (COMPLEX_MUL, COMPLEX_MUL_CONJ): New.
* optabs.def (cmul_optab, cmul_conj_optab): New.
* tree-vect-slp-patterns.c (vect_build_perm_groups,
vect_can_combine_node_p, vect_slp_make_combine_linear,
vect_match_call_complex_mla, vect_slp_matches_complex_mul,
class complex_mul_pattern, complex_mul_pattern::matches,
complex_mul_pattern::validate_p,
complex_operations_pattern::matches): Add complex_mul_pattern.
> -----Original Message-----
> From: Gcc-patches <[email protected]> On Behalf Of Tamar
> Christina
> Sent: Friday, September 25, 2020 3:29 PM
> To: [email protected]
> Cc: nd <[email protected]>; [email protected]; [email protected]
> Subject: [PATCH v2 7/16]middle-end: Add Complex Multiplication and
> Multiplication with Conjucate detection
>
> Hi All,
>
> This patch adds pattern detections for the following operation:
>
> Complex multiplication and complex multiplication by the conjugate of the
> second parameter.
>
> c = a * b and c = a * conj (b)
>
> For the conjugate cases it supports, under fast-math, that the operand that
> is being conjugated be flipped by flipping the arguments to the optab. This
> allows it to support c = conj (a) * b and c += conj (a) * b.
>
> where a, b and c are complex numbers.
>
> and provides a shared class for anything needing to recognize complex MLA
> patterns.
>
> Bootstrapped and regtested on aarch64-none-linux-gnu with no issues.
>
> Ok for master?
>
> Thanks,
> Tamar
>
> gcc/ChangeLog:
>
> * doc/md.texi: Document optabs.
> * internal-fn.def (COMPLEX_MUL, COMPLEX_MUL_CONJ): New.
> * optabs.def (cmul_optab, cmul_conj_optab): New.
> * tree-vect-slp-patterns.c (class ComplexMLAPattern,
> class ComplexMulPattern): New.
> (slp_patterns): Add ComplexMulPattern.
>
> --
diff --git a/gcc/doc/md.texi b/gcc/doc/md.texi
index 71e226505b2619d10982b59a4ebbed73a70f29be..ddaf1abaccbd44dae11ea902ec38b474aacfb8e1 100644
--- a/gcc/doc/md.texi
+++ b/gcc/doc/md.texi
@@ -6143,6 +6143,28 @@ rotations @var{m} of 90 or 270.
This pattern is not allowed to @code{FAIL}.
+@cindex @code{cmul@var{m}3} instruction pattern
+@item @samp{cmul@var{m}3}
+Perform a vector floating point multiplication of complex numbers in operand 0
+and operand 1.
+
+The instruction must perform the operation on data loaded contiguously into the
+vectors.
+The operation is only supported for vector modes @var{m}.
+
+This pattern is not allowed to @code{FAIL}.
+
+@cindex @code{cmul_conj@var{m}3} instruction pattern
+@item @samp{cmul_conj@var{m}3}
+Perform a vector floating point multiplication of complex numbers in operand 0
+and the conjugate of operand 1.
+
+The instruction must perform the operation on data loaded contiguously into the
+vectors.
+The operation is only supported for vector modes @var{m}.
+
+This pattern is not allowed to @code{FAIL}.
+
@cindex @code{ffs@var{m}2} instruction pattern
@item @samp{ffs@var{m}2}
Store into operand 0 one plus the index of the least significant 1-bit
diff --git a/gcc/internal-fn.def b/gcc/internal-fn.def
index 33c54be1e158ddea25c4cd6b1148df8cf4a509b5..cb41643f5e332518a0271bb8e1af4883c8bd6880 100644
--- a/gcc/internal-fn.def
+++ b/gcc/internal-fn.def
@@ -279,6 +279,8 @@ DEF_INTERNAL_FLT_FLOATN_FN (FMAX, ECF_CONST, fmax, binary)
DEF_INTERNAL_OPTAB_FN (XORSIGN, ECF_CONST, xorsign, binary)
DEF_INTERNAL_OPTAB_FN (COMPLEX_ADD_ROT90, ECF_CONST, cadd90, binary)
DEF_INTERNAL_OPTAB_FN (COMPLEX_ADD_ROT270, ECF_CONST, cadd270, binary)
+DEF_INTERNAL_OPTAB_FN (COMPLEX_MUL, ECF_CONST, cmul, binary)
+DEF_INTERNAL_OPTAB_FN (COMPLEX_MUL_CONJ, ECF_CONST, cmul_conj, binary)
/* FP scales. */
diff --git a/gcc/optabs.def b/gcc/optabs.def
index 2bb0bf857977035bf562a77f5f6848e80edf936d..9c267d422478d0011f288b1f5f62daabe3989ba7 100644
--- a/gcc/optabs.def
+++ b/gcc/optabs.def
@@ -292,6 +292,8 @@ OPTAB_D (copysign_optab, "copysign$F$a3")
OPTAB_D (xorsign_optab, "xorsign$F$a3")
OPTAB_D (cadd90_optab, "cadd90$a3")
OPTAB_D (cadd270_optab, "cadd270$a3")
+OPTAB_D (cmul_optab, "cmul$a3")
+OPTAB_D (cmul_conj_optab, "cmul_conj$a3")
OPTAB_D (cos_optab, "cos$a2")
OPTAB_D (cosh_optab, "cosh$a2")
OPTAB_D (exp10_optab, "exp10$a2")
diff --git a/gcc/tree-vect-slp-patterns.c b/gcc/tree-vect-slp-patterns.c
index 0732cf0a6d93be8590b84c39dff82940b280e46b..2edb0117f9cbbfc40e9ed3a96120a3c88f84a68e 100644
--- a/gcc/tree-vect-slp-patterns.c
+++ b/gcc/tree-vect-slp-patterns.c
@@ -196,6 +196,65 @@ linear_loads_p (slp_tree root, bool *linear)
return loads;
}
+/* Builds a permutation group for each operand in OPS and stores it in BLOCKS,
+ which is indexed by operand number and so needs room for OPS.length ()
+ entries. The group describes how to combine the operands to get a valid
+ linear node.
+
+ This is used when combining multiple children from a two_operators node into
+ one using a lane permute to select the appropriate lane. As an example, in
+ the group { [0 0] [1 4] [2 2] [3 3] [1 4] [5 5] } an entry naming the same
+ node twice, e.g. [0 0], means the node only needs itself to possibly be made
+ linear, whereas [1 4] means to combine the nodes 1 and 4.
+
+ An operand is a candidate for combining only if its load permutation is
+ non-empty, not linear and has equal minimum and maximum elements, i.e. it
+ selects one lane throughout. Every non-candidate is mapped to itself. Of
+ the candidates, the one with the smallest lane and the one with the largest
+ lane (the later operand wins ties) are paired with each other; entries of
+ any remaining candidates are not written. */
+
+static void
+vect_build_perm_groups (map_t *blocks, vec<slp_tree> ops)
+{
+ slp_tree op;
+ unsigned i;
+ bool is_linear = false;
+ unsigned min_eq = -1, max_eq = 0; /* -1 wraps to the largest unsigned value. */
+ unsigned min_idx = 0, max_idx = 0;
+ FOR_EACH_VEC_ELT (ops, i, op)
+ {
+ load_permutation_t perms = linear_loads_p (op, &is_linear);
+ unsigned x, imin = -1, imax = 0;
+ for (x = 0; x < perms.length () && !is_linear; x++)
+ {
+ imin = MIN (imin, perms[x]);
+ imax = MAX (imax, perms[x]);
+ }
+
+ if (imin != imax || perms.length () == 0 || is_linear)
+ blocks[i] = {i, i};
+ else
+ {
+ if (imin <= min_eq)
+ {
+ min_eq = imin;
+ min_idx = i;
+ }
+
+ if (imin >= max_eq)
+ {
+ max_eq = imin;
+ max_idx = i;
+ }
+ }
+ }
+
+ /* Now fill in the gap: pair the candidate with the smallest lane with the
+ candidate with the largest lane. If there was no candidate both indices
+ are still 0 and this simply stores {0, 0} in BLOCKS[0]. */
+ blocks[min_idx] = {min_idx, max_idx};
+ blocks[max_idx] = {min_idx, max_idx};
+
+ if (dump_enabled_p ())
+ {
+ dump_printf_loc (MSG_NOTE, vect_location, "pattern group: { ");
+ for (i = 0; i < ops.length (); i++)
+ dump_printf (MSG_NOTE,"[%d %d] ", blocks[i].a, blocks[i].b);
+ dump_printf (MSG_NOTE,"}\n");
+ }
+
+}
+
/* This function attempts to make a node rooted in NODE linear. If the node
if already linear than the node itself is returned in RESULT.
@@ -265,6 +324,85 @@ vect_slp_make_linear (slp_tree parent, slp_tree node, slp_tree *result)
return is_linear;
}
+/* Helper utility to check whether the load permutation PERM is one that can
+   be used in a node combination operation.  This is the case when the node
+   is not linear (IS_LINEAR is false) and all the elements of PERM are the
+   same, e.g. [0 0].  An empty permutation is rejected as well, since it
+   names no lane to combine.
+
+   Returns TRUE if the node can take part in a combination, FALSE
+   otherwise.  */
+
+static inline bool
+vect_can_combine_node_p (load_permutation_t perm, bool is_linear)
+{
+  /* The caller inspects the first element of an accepted permutation to
+     decide the lane order, so an empty one must not be accepted.  */
+  if (is_linear || perm.length () == 0)
+    return false;
+
+  unsigned i, x;
+  FOR_EACH_VEC_ELT (perm, i, x)
+    if (perm[0] != x)
+      return false;
+
+  return true;
+}
+
+/* This function combines the two nodes named by MAP together to make a new
+ VEC_PERM_EXPR node using a lane permute. The nodes to combine are
+ ENTRIES[MAP.a] and ENTRIES[MAP.b] and the resulting node is returned in
+ RESULT. The new node takes its vector type, number of lanes and
+ representative from PARENT.
+
+ If MAP names the same node twice, or if either node is already linear or
+ otherwise fails vect_can_combine_node_p, then this function fails, returns
+ FALSE and leaves RESULT untouched. Otherwise it returns the new node and
+ TRUE. */
+
+static bool
+vect_slp_make_combine_linear (slp_tree parent, vec<slp_tree> entries, map_t map,
+ slp_tree *result)
+{
+ if (map.a == map.b)
+ return false;
+
+ slp_tree node_a = entries[map.a];
+ slp_tree node_b = entries[map.b];
+
+ bool is_a_linear = false;
+ bool is_b_linear = false;
+
+ load_permutation_t load_perm_a = linear_loads_p (node_a, &is_a_linear);
+ if (!vect_can_combine_node_p (load_perm_a, is_a_linear))
+ return false;
+ load_permutation_t load_perm_b = linear_loads_p (node_b, &is_b_linear);
+ if (!vect_can_combine_node_p (load_perm_b, is_b_linear))
+ return false;
+
+ /* Now we need to figure out which node is first. The node whose load
+ permutation starts at the lower element is listed first in the lane
+ permutation. NOTE(review): the pairs are presumably { child index, lane }
+ with node_a as child 0 and node_b as child 1 - confirm. */
+ auto_vec<slp_tree> nodes;
+ nodes.create (2);
+ vec<std::pair<unsigned, unsigned> > perm;
+ perm.create (2);
+ if (load_perm_a[0] < load_perm_b[0])
+ {
+ perm.quick_push (std::make_pair (0, 0));
+ perm.quick_push (std::make_pair (1, 0));
+ }
+ else
+ {
+ perm.quick_push (std::make_pair (1, 0));
+ perm.quick_push (std::make_pair (0, 0));
+ }
+
+ nodes.quick_push (node_a);
+ nodes.quick_push (node_b);
+ /* Both nodes become children of the new node and both reference counts are
+ bumped. NOTE(review): an earlier wording of this comment claimed node_a
+ is already connected and only node_b needs a reference - confirm which
+ is intended. */
+ SLP_TREE_REF_COUNT (node_a)++;
+ SLP_TREE_REF_COUNT (node_b)++;
+
+ slp_tree vnode = vect_create_new_slp_node (vNULL, 1);
+ SLP_TREE_CODE (vnode) = VEC_PERM_EXPR;
+ SLP_TREE_LANE_PERMUTATION (vnode) = perm;
+ SLP_TREE_VECTYPE (vnode) = SLP_TREE_VECTYPE (parent);
+ SLP_TREE_CHILDREN (vnode).safe_splice (nodes);
+ SLP_TREE_REF_COUNT (vnode) = 1;
+ SLP_TREE_LANES (vnode) = SLP_TREE_LANES (parent);
+ SLP_TREE_REPRESENTATIVE (vnode) = SLP_TREE_REPRESENTATIVE (parent);
+ *result = vnode;
+ return true;
+}
+
/*******************************************************************************
* Simple vector pattern matcher
******************************************************************************/
@@ -727,6 +865,313 @@ complex_add_pattern::matches ()
return matches (op, this->m_ops);
}
+/*******************************************************************************
+ * complex_mul_pattern
+ ******************************************************************************/
+
+/* Helper that looks for a pair operation in child number CHILD of NODE.
+   CHILD must be a valid index into the children of NODE.  When RES is
+   non-NULL the child that was inspected is stored in it.
+
+   The detection itself is delegated to vect_detect_pair_op: if the match is
+   successful then ARGS will contain the operands matched and the
+   complex_operation_t type is returned.  If the match is not successful then
+   CMPLX_NONE is returned and ARGS is left unmodified.  */
+
+static complex_operation_t
+vect_match_call_complex_mla (slp_tree node, unsigned child,
+                             vec<slp_tree> *args = NULL, slp_tree *res = NULL)
+{
+  vec<slp_tree> kids = SLP_TREE_CHILDREN (node);
+  gcc_assert (child < kids.length ());
+
+  slp_tree candidate = kids[child];
+  if (res != NULL)
+    *res = candidate;
+
+  return vect_detect_pair_op (candidate, false, args);
+}
+
+/* This helper attempts to find a complex MUL pattern rooted in ROOT, whose
+   pair operation OP has already been detected.  If the match succeeds then
+   the pattern type is set in IFN, the four operands are appended to OPS and
+   TRUE is returned.  Otherwise IFN is set to IFN_LAST, OPS is left unchanged
+   and FALSE is returned.
+
+   This function matches both a normal complex multiply and complex conjugate
+   multiply.  Additionally it also matches the MUL part in a FMS and FMA
+   sequence.  However due to the additional TWO_OPERATORS node that an FMS
+   has the location of the negate node that denotes a conjugate changes.
+
+   In order to differentiate when and where we should check for a conjugate
+   the value MULTIPLY is set when this should only match a normal complex
+   multiply operation and INVERSE is set when we're matching a sequence for an
+   FMS where the negate node is on the other side.
+
+   Note that this function also deals with the canonicalization of the
+   sequence being off if there is a type cast in between.  This is likely a
+   mid-end bug but for now we deal with it here.  */
+
+static bool
+vect_slp_matches_complex_mul (complex_operation_t op, slp_tree root,
+                              internal_fn *ifn, vec<slp_tree> *ops,
+                              bool multiply, bool inverse = false)
+{
+  *ifn = IFN_LAST;
+
+  if (op != MINUS_PLUS)
+    return false;
+
+  /* Now operand1+3 must lead to another expression.  */
+  auto_vec<slp_tree> args0;
+  complex_operation_t op2 = vect_match_call_complex_mla (root, 0, &args0);
+
+  if (op2 != MULT_MULT)
+    return false;
+
+  /* Now operand2+4 must lead to another expression.  */
+  auto_vec<slp_tree> args1;
+  complex_operation_t op3 = vect_match_call_complex_mla (root, 1, &args1);
+
+  if (op3 != MULT_MULT)
+    return false;
+
+  vec<slp_tree> args2 = SLP_TREE_CHILDREN (args1[inverse ? 0 : 1]);
+  slp_tree neg_node = NULL;
+  bool first_neg = false, second_neg = false;
+
+  /* Now operand2+4 may lead to another expression: a negate on either
+     operand.  NEG_NODE is the operand of that negate, if there is one.  */
+  if ((first_neg = vect_match_expression_p (args2[0], NEGATE_EXPR)))
+    neg_node = SLP_TREE_CHILDREN (args2[0])[0];
+  else if ((second_neg = vect_match_expression_p (args2[1], NEGATE_EXPR)))
+    neg_node = SLP_TREE_CHILDREN (args2[1])[0];
+
+  if (first_neg && multiply)
+    return false;
+
+  /* Check if the neg node is a dup, otherwise not a pattern we want.  There
+     is nothing to check when no negate was found, so don't hand a NULL node
+     to linear_loads_p in that case.  */
+  if (neg_node)
+    {
+      bool is_linear = false;
+      load_permutation_t perm = linear_loads_p (neg_node, &is_linear);
+      for (unsigned j = 0; j < perm.length (); j++)
+        if (perm[j] != perm[0])
+          return false;
+    }
+
+  /* Check whether the conjugate is on the first or the second operand so we
+     can flip the order and get it in the right place.  We can't check the DR
+     of the new child since we may not be a load.  We can't recurse all the
+     way down because we may find a child with multiple children or external.
+     So instead just check the operands to the multiply which tell us where
+     the conjugate was and how it's interpreting the permute.  */
+  bool same_operand = true;
+  stmt_vec_info elem;
+  unsigned i;
+  vec<stmt_vec_info> stmts = SLP_TREE_SCALAR_STMTS (args0[0]);
+  tree first_op = gimple_op (STMT_VINFO_STMT (stmts[0]), 1);
+  FOR_EACH_VEC_ELT (stmts, i, elem)
+    if (first_op != gimple_op (STMT_VINFO_STMT (elem), 1))
+      {
+        same_operand = false;
+        break;
+      }
+
+  /* Reject operations that we don't have an optab for.  */
+  if (first_neg && !multiply && same_operand && !inverse)
+    return false;
+
+  if (!first_neg && !second_neg)
+    {
+      /* Indicates a rotation in the complex number, not a pattern we are
+         looking for.  */
+      vec<slp_tree> params = SLP_TREE_CHILDREN (args0[0]);
+      if (vect_match_expression_p (params[0], NEGATE_EXPR)
+          || vect_match_expression_p (params[1], NEGATE_EXPR))
+        return false;
+      *ifn = IFN_COMPLEX_MUL;
+      ops->safe_splice (params);
+      ops->safe_push (SLP_TREE_CHILDREN (args1[1])[0]);
+      ops->safe_push (SLP_TREE_CHILDREN (args1[1])[1]);
+    }
+  else
+    {
+      *ifn = IFN_COMPLEX_MUL_CONJ;
+      vec<slp_tree> params = SLP_TREE_CHILDREN (args0[inverse ? 1 : 0]);
+      slp_tree value = second_neg ? args2[0] : args2[1];
+
+      /* The conjugate is on the first or second parameter; when the
+         multiply operands are not the same the first two are swapped.  */
+      ops->safe_push (params[same_operand ? 0 : 1]);
+      ops->safe_push (params[same_operand ? 1 : 0]);
+      ops->safe_push (neg_node);
+      ops->safe_push (value);
+
+      /* The two_operators with an FMS reverse the nodes so we have to swap
+         them back to make a sensible operation.  */
+      if (inverse)
+        std::swap ((*ops)[2], (*ops)[3]);
+    }
+
+  return *ifn != IFN_LAST;
+}
+
+/* Overload of vect_slp_matches_complex_mul for callers that have not yet
+   detected the pair operation of ROOT: it is detected here and the remaining
+   arguments are passed through unchanged.  */
+
+static bool
+vect_slp_matches_complex_mul (slp_tree root, internal_fn *ifn,
+                              vec<slp_tree> *ops, bool multiply,
+                              bool inverse = false)
+{
+  complex_operation_t detected = vect_detect_pair_op (root);
+  return vect_slp_matches_complex_mul (detected, root, ifn, ops, multiply,
+                                       inverse);
+}
+
+class complex_mul_pattern : public complex_pattern
+{
+ protected:
+ /* Permutation groups for the matched operands, filled in by
+ vect_build_perm_groups when a match is found and read back by validate_p.
+ Allocate enough space for FMA as well. */
+ map_t m_blocks[6] = {};
+ bool m_inplace = false; /* When set, validate_p splices the rewritten
+ operands straight into the matched node instead of building a replacement
+ node for each child. */
+ auto_vec<slp_tree> workset; /* Children of the matched node, recorded by
+ matches and walked by validate_p. */
+ complex_mul_pattern (slp_tree *node, vec_info *vinfo)
+ : complex_pattern (node, vinfo)
+ {
+ this->m_arity = 2;
+ this->m_num_args = 2;
+ }
+
+ public:
+ static vect_pattern* create (slp_tree *node, vec_info *vinfo)
+ {
+ return new complex_mul_pattern (node, vinfo);
+ }
+
+ const char* get_name ()
+ {
+ return "Complex Multiplication";
+ }
+
+ bool validate_p ();
+ bool matches ();
+ bool matches (complex_operation_t op, vec<slp_tree> ops);
+};
+
+
+/* Pattern matcher for a complex multiply in the SLP tree rooted at the node
+   being matched, given its already detected pair operation OP.  The second
+   parameter is unused.
+
+   If the operation matches then M_IFN is set to the operation it matched and
+   the arguments to the two replacement statements are put in M_OPS.  The
+   permutation groups for those operands are then built, the children of the
+   node are added to the workset for validate_p and save_match is called.
+
+   If no match is found then M_IFN is set to IFN_LAST and M_OPS is unchanged.
+
+   This function matches the patterns shaped as:
+
+   double ax = (b[i+1] * a[i]);
+   double bx = (a[i+1] * b[i]);
+
+   c[i] = c[i] - ax;
+   c[i+1] = c[i+1] + bx;
+
+   If a match occurred then TRUE is returned, else FALSE.  */
+
+bool
+complex_mul_pattern::matches (complex_operation_t op, vec<slp_tree> /* ops */)
+{
+  if (!vect_slp_matches_complex_mul (op, *this->m_node, &this->m_ifn,
+                                     &this->m_ops, true))
+    return false;
+
+  vect_build_perm_groups (&this->m_blocks[0], this->m_ops);
+  this->workset.safe_splice (SLP_TREE_CHILDREN (*this->m_node));
+  save_match ();
+  return true;
+}
+
+/* Variant of matches for when the pair operation of the node has not been
+   detected yet: detects it and defers to the overload taking the operation.
+   Returns TRUE if a complex multiply was matched, else FALSE.  */
+
+bool
+complex_mul_pattern::matches ()
+{
+  return matches (vect_detect_pair_op (*this->m_node), this->m_ops);
+}
+
+
+/* Validates to see if the Complex MUL that we have matched is valid. This is
+ done through a combination of making nodes linear and combining nodes.
+
+ Returns FALSE straight away if no match was recorded. Otherwise, for every
+ child in the workset, each of its m_num_args operands is either made linear
+ on its own (when its permutation group names the same node twice) or
+ combined with its partner through a lane permute. If an operand can be
+ neither, FALSE is returned; when not working in place, children handled
+ before the failing one have already been replaced at that point.
+
+ The rewritten operands are then either spliced directly into the matched
+ node (m_inplace) or attached to a new node that copies the child's
+ properties and takes its place in the matched node. Returns TRUE when all
+ children were handled. */
+
+bool
+complex_mul_pattern::validate_p ()
+{
+ if (!this->m_match)
+ return false;
+
+ slp_tree node;
+ unsigned ix;
+ hash_set<slp_tree> cache;
+ FOR_EACH_VEC_ELT (this->workset, ix, node)
+ {
+ auto_vec<slp_tree> nodes;
+ nodes.create (this->m_num_args);
+ slp_tree tmp = NULL;
+
+ unsigned i;
+ for (i = 0; i < this->m_num_args; i++)
+ {
+ unsigned index = (ix * this->m_num_args) + i; /* m_num_args operands per child. */
+ map_t map = this->m_blocks[index];
+ slp_tree vnode = NULL;
+ bool needs_linear = map.a == map.b;
+ tmp = this->m_ops[index];
+ cache.add (tmp);
+ if (needs_linear && vect_slp_make_linear (node, tmp, &vnode))
+ nodes.quick_push (vnode);
+ else if (!needs_linear
+ && vect_slp_make_combine_linear (node, this->m_ops, map,
+ &vnode))
+ nodes.quick_push (vnode);
+ else
+ {
+ if (dump_enabled_p ())
+ dump_printf_loc (MSG_MISSED_OPTIMIZATION, vect_location,
+ "stmts could not be made %s %p\n",
+ needs_linear ? "linear" : "linear/combined",
+ tmp);
+ nodes.release();
+ return false;
+ }
+
+ vect_mark_stmts_as_in_pattern (&cache, node);
+ }
+
+ if (m_inplace)
+ {
+ SLP_TREE_CHILDREN (*this->m_node).truncate (0);
+ SLP_TREE_CHILDREN (*this->m_node).safe_splice (nodes);
+ }
+ else
+ {
+ slp_tree new_node
+ = vect_create_new_slp_node (SLP_TREE_SCALAR_STMTS (node),
+ SLP_TREE_CHILDREN (node).length ());
+ SLP_TREE_VECTYPE (new_node) = SLP_TREE_VECTYPE (node);
+ SLP_TREE_LANE_PERMUTATION (new_node)
+ = SLP_TREE_LANE_PERMUTATION (node);
+ SLP_TREE_CODE (new_node) = SLP_TREE_CODE (node);
+ SLP_TREE_REF_COUNT (new_node) = SLP_TREE_REF_COUNT (node);
+ SLP_TREE_REPRESENTATIVE (new_node) = SLP_TREE_REPRESENTATIVE (node);
+ SLP_TREE_CHILDREN (new_node).safe_splice (nodes);
+ SLP_TREE_LANES (new_node) = SLP_TREE_LANES (node);
+
+ SLP_TREE_CHILDREN (*this->m_node)[ix] = new_node;
+ }
+ }
+
+ return true;
+}
+
/*******************************************************************************
* complex_operations_pattern class
******************************************************************************/
@@ -776,6 +1221,12 @@ complex_operations_pattern::matches ()
return false;
/* Check which pattern this may be. Match longest pattern first. */
+ this->m_patt = complex_mul_pattern::create (this->m_node, this->m_vinfo);
+ if (this->m_patt->matches (op, this->m_ops))
+ return true;
+
+ delete this->m_patt;
+
this->m_patt = complex_add_pattern::create (this->m_node, this->m_vinfo);
if (this->m_patt->matches (op, this->m_ops))
return true;