Hi,
The attached patch attempts to generate all commutative variants of
a given expression.
Example:
For the AST: (PLUS_EXPR (PLUS_EXPR @0 @1) @2),
the commutative variants are:
(PLUS_EXPR (PLUS_EXPR @0 @1 ) @2 )
(PLUS_EXPR (PLUS_EXPR @1 @0 ) @2 )
(PLUS_EXPR @2 (PLUS_EXPR @0 @1 ) )
(PLUS_EXPR @2 (PLUS_EXPR @1 @0 ) )
* Basic Idea:
Consider an expression e with two operands, o0 and o1, and let
expr-code denote the expression's code (plus, mult, etc.).
The commutative variants are stored in a vector (vec<operand *>).
vec<operand *>
commutative (e)
{
if (e is not commutative)
return [e]; // vector with only one expression
v1 = commutative (o0);
v2 = commutative (o1);
ret = []
for i = 0 ... v1.length ()
for j = 0 ... v2.length ()
{
ne = new expr with <expr-code> and operands: v1[i], v2[j];
append ne to ret;
}
for i = 0 ... v2.length ()
for j = 0 ... v1.length ()
{
ne = new expr with <expr-code> and operands: v2[i], v1[j];
append ne to ret;
}
return ret;
}
Example:
(plus (plus @0 @1) (plus @2 @3))
generates the following commutative variants (corrected: an earlier
version of this list showed @0 in place of @2 in the second inner
PLUS_EXPR; if that was real genmatch output, it points to a bug in how
captures are parsed or printed):
(PLUS_EXPR (PLUS_EXPR @0 @1 ) (PLUS_EXPR @2 @3 ) )
(PLUS_EXPR (PLUS_EXPR @0 @1 ) (PLUS_EXPR @3 @2 ) )
(PLUS_EXPR (PLUS_EXPR @1 @0 ) (PLUS_EXPR @2 @3 ) )
(PLUS_EXPR (PLUS_EXPR @1 @0 ) (PLUS_EXPR @3 @2 ) )
(PLUS_EXPR (PLUS_EXPR @2 @3 ) (PLUS_EXPR @0 @1 ) )
(PLUS_EXPR (PLUS_EXPR @2 @3 ) (PLUS_EXPR @1 @0 ) )
(PLUS_EXPR (PLUS_EXPR @3 @2 ) (PLUS_EXPR @0 @1 ) )
(PLUS_EXPR (PLUS_EXPR @3 @2 ) (PLUS_EXPR @1 @0 ) )
* Deciding which operators are commutative
Currently I treat PLUS_EXPR and MULT_EXPR, and only those, as commutative.
Should we add syntax to mark a particular operator as commutative?
* Cloning AST nodes
When creating the AST that represents one of the commutative variants,
should we clone the AST nodes, so that every variant has its own distinct
AST nodes? That is not done currently: AST nodes are shared among the
different commutative expressions, so a set of commutative expressions
ends up forming a DAG.
Thanks and Regards,
Prathamesh
Index: gcc/genmatch.c
===================================================================
--- gcc/genmatch.c (revision 211732)
+++ gcc/genmatch.c (working copy)
@@ -293,14 +293,14 @@ e_operation::e_operation (const char *id
struct simplify {
simplify (const char *name_,
- struct operand *match_, source_location match_location_,
+ vec<operand *> matchers_, source_location match_location_,
struct operand *ifexpr_, source_location ifexpr_location_,
struct operand *result_, source_location result_location_)
- : name (name_), match (match_), match_location (match_location_),
+ : name (name_), matchers (matchers_), match_location (match_location_),
ifexpr (ifexpr_), ifexpr_location (ifexpr_location_),
result (result_), result_location (result_location_) {}
const char *name;
- struct operand *match;
+ vec<operand *> matchers; // vector to hold commutative expressions
source_location match_location;
struct operand *ifexpr;
source_location ifexpr_location;
@@ -308,7 +308,108 @@ struct simplify {
source_location result_location;
};
+/* Debug helper: write operand O to F in prefix notation.  Captures are
+   printed as @WHERE, predicates by name, C expressions as "c_expr" and
+   expressions as "(OPERATION op0 op1 ... )", each operand followed by a
+   single space.  */
+
+void
+print_operand (operand *o, FILE *f = stderr)
+{
+  switch (o->type)
+    {
+    case operand::OP_CAPTURE:
+      fprintf (f, "@%s", static_cast<capture *> (o)->where);
+      break;
+
+    case operand::OP_PREDICATE:
+      fprintf (f, "%s", static_cast<predicate *> (o)->ident);
+      break;
+
+    case operand::OP_C_EXPR:
+      fprintf (f, "c_expr");
+      break;
+
+    case operand::OP_EXPR:
+      {
+        expr *node = static_cast<expr *> (o);
+        fprintf (f, "(%s ", node->operation->op->id);
+
+        unsigned n_ops = node->ops.length ();
+        for (unsigned k = 0; k < n_ops; ++k)
+          {
+            print_operand (node->ops[k], f);
+            putc (' ', f);
+          }
+
+        putc (')', f);
+      }
+      break;
+
+    default:
+      gcc_unreachable ();
+    }
+}
+
+/* Debug helper: dump to F all matchers of simplifier S, i.e. the match
+   expression followed by every commutative variant generated for it.
+   Nothing is printed when S has at most one matcher, since then no
+   variants were generated.  */
+
+void
+print_matches (struct simplify *s, FILE *f = stderr)
+{
+  /* "<= 1" rather than "== 1" so that an empty vector cannot reach the
+     s->matchers[0] access below.  */
+  if (s->matchers.length () <= 1)
+    return;
+
+  /* commutate () builds the variant with the original operand order
+     first, so s->matchers[0] is equivalent to the original expression.  */
+  fprintf (f, "for expression: ");
+  print_operand (s->matchers[0], f);
+  putc ('\n', f);
+
+  fprintf (f, "commutative expressions:\n");
+  for (unsigned i = 0; i < s->matchers.length (); ++i)
+    {
+      print_operand (s->matchers[i], f);
+      putc ('\n', f);
+    }
+}
+
+/* Return true if OP is an expression whose two operands may be swapped
+   when generating match variants.  Only binary PLUS_EXPR and MULT_EXPR
+   are treated as commutative for now.  */
+
+bool
+is_commutative (operand *op)
+{
+  if (op->type != operand::OP_EXPR)
+    return false;
+
+  expr *e = static_cast<expr *> (op);
+
+  /* commutate () swaps exactly ops[0] and ops[1]; with any other number
+     of operands it would index past the operand vector or drop operands.  */
+  if (e->ops.length () != 2)
+    return false;
+
+  /* NOTE(review): assumes every expression operation is an operator_id
+     (a tree code); confirm this cast is safe before non-tree-code
+     operations (e.g. builtin function ids) can appear here.  */
+  operator_id *op_id = static_cast <operator_id *> (e->operation->op);
+  enum tree_code code = op_id->code;
+  return code == PLUS_EXPR || code == MULT_EXPR;
+}
+
+/* Append to RET one new expression with E's operation for every
+   combination of operand variants in VARIANTS (their cartesian product).
+   VARIANTS[i] holds the variants of E's i-th operand and CHOSEN holds the
+   variants already picked for the leading operand positions.  Operand
+   positions are kept; the first operand varies slowest.  */
+
+static void
+commutate_operands (expr *e, vec<vec<operand *> > &variants,
+                    vec<operand *> &chosen, vec<operand *> &ret)
+{
+  unsigned pos = chosen.length ();
+  if (pos == variants.length ())
+    {
+      expr *ne = new expr (e->operation);
+      for (unsigned i = 0; i < chosen.length (); ++i)
+        ne->append_op (chosen[i]);
+      ret.safe_push (ne);
+      return;
+    }
+
+  for (unsigned i = 0; i < variants[pos].length (); ++i)
+    {
+      chosen.safe_push (variants[pos][i]);
+      commutate_operands (e, variants, chosen, ret);
+      chosen.pop ();
+    }
+}
+
+/* Return all commutative variants of OP.  The operands of every
+   expression are expanded recursively, so commutative subexpressions
+   nested under non-commutative operators are expanded as well.  For a
+   commutative binary expression both operand orders are generated.
+   The first element always keeps the original operand order at every
+   level; when OP has no variants the vector holds just OP itself.
+   New expressions share operand nodes with OP and with each other
+   (nothing is cloned).  The caller owns the returned vector.  */
+
+vec<operand *>
+commutate (operand *op)
+{
+  vec<operand *> ret = vNULL;
+
+  if (op->type != operand::OP_EXPR)
+    {
+      ret.safe_push (op);
+      return ret;
+    }
+
+  expr *e = static_cast<expr *> (op);
+  /* Only swap when there are exactly two operands to swap.  */
+  bool swap = is_commutative (op) && e->ops.length () == 2;
+
+  /* Variants of each operand, in operand order.  */
+  vec<vec<operand *> > variants = vNULL;
+  bool has_variants = swap;
+  for (unsigned i = 0; i < e->ops.length (); ++i)
+    {
+      variants.safe_push (commutate (e->ops[i]));
+      if (variants[i].length () > 1)
+        has_variants = true;
+    }
+
+  if (!has_variants)
+    /* Each operand is its own only variant: reuse OP unchanged.  */
+    ret.safe_push (op);
+  else
+    {
+      /* Original operand order first, then the swapped order.  */
+      vec<operand *> chosen = vNULL;
+      commutate_operands (e, variants, chosen, ret);
+      chosen.release ();
+
+      if (swap)
+        for (unsigned i = 0; i < variants[1].length (); ++i)
+          for (unsigned j = 0; j < variants[0].length (); ++j)
+            {
+              expr *ne = new expr (e->operation);
+              ne->append_op (variants[1][i]);
+              ne->append_op (variants[0][j]);
+              ret.safe_push (ne);
+            }
+    }
+
+  /* The per-operand vectors are heap-allocated vecs; free them.  */
+  for (unsigned i = 0; i < variants.length (); ++i)
+    variants[i].release ();
+  variants.release ();
+
+  return ret;
+}
/* Code gen off the AST. */
@@ -574,11 +675,15 @@ write_nary_simplifiers (FILE *f, vec<sim
{
simplify *s = simplifiers[i];
/* ??? This means we can't capture the outermost expression. */
- if (s->match->type != operand::OP_EXPR)
+ for (unsigned i = 0; i < s->matchers.length (); ++i)
+ {
+ operand *match = s->matchers[i];
+ if (match->type != operand::OP_EXPR)
continue;
- expr *e = static_cast <expr *> (s->match);
+ expr *e = static_cast <expr *> (match);
if (e->ops.length () != n)
continue;
+
char fail_label[16];
snprintf (fail_label, 16, "fail%d", label_cnt++);
output_line_directive (f, s->match_location);
@@ -627,6 +735,7 @@ write_nary_simplifiers (FILE *f, vec<sim
fprintf (f, " }\n");
fprintf (f, "%s:\n", fail_label);
}
+ }
fprintf (f, " return false;\n");
fprintf (f, "}\n");
}
@@ -971,7 +1080,7 @@ parse_match_and_simplify (cpp_reader *r,
ifexpr = parse_c_expr (r, CPP_OPEN_PAREN);
}
token = peek (r);
- return new simplify (id, match, match_location,
+ return new simplify (id, commutate (match), match_location,
ifexpr, ifexpr_location, parse_op (r), token->src_loc);
}
@@ -1043,6 +1152,9 @@ main(int argc, char **argv)
}
while (1);
+ for (unsigned i = 0; i < simplifiers.length (); ++i)
+ print_matches (simplifiers[i]);
+
write_gimple (stdout, simplifiers);
cpp_finish (r, NULL);