Hi,

The attached patch attempts to generate the commutative variants of a given expression.
Example: for the AST (PLUS_EXPR (PLUS_EXPR @0 @1) @2), the commutative variants are:

(PLUS_EXPR (PLUS_EXPR @0 @1 ) @2 )
(PLUS_EXPR (PLUS_EXPR @1 @0 ) @2 )
(PLUS_EXPR @2 (PLUS_EXPR @0 @1 ) )
(PLUS_EXPR @2 (PLUS_EXPR @1 @0 ) )

* Basic idea:
Consider an expression e with two operands, o0 and o1, and expr-code denoting the expression's code (plus, mult, etc.). The commutative variants are stored in a vector (vec<operand *>).

vec<operand *> commutative (e)
{
  if (e is not commutative)
    return [e];  // vector with only one expression

  v1 = commutative (o0);
  v2 = commutative (o1);
  ret = []

  for i = 0 ... v1.length ()
    for j = 0 ... v2.length ()
      {
        ne = new expr with <expr-code> and operands: v1[i], v2[j];
        append ne to ret;
      }

  for i = 0 ... v2.length ()
    for j = 0 ... v1.length ()
      {
        ne = new expr with <expr-code> and operands: v2[i], v1[j];
        append ne to ret;
      }

  return ret;
}

Example: (plus (plus @0 @1) (plus @2 @3)) generates the following commutative variants:

(PLUS_EXPR (PLUS_EXPR @0 @1 ) (PLUS_EXPR @2 @3 ) )
(PLUS_EXPR (PLUS_EXPR @0 @1 ) (PLUS_EXPR @3 @2 ) )
(PLUS_EXPR (PLUS_EXPR @1 @0 ) (PLUS_EXPR @2 @3 ) )
(PLUS_EXPR (PLUS_EXPR @1 @0 ) (PLUS_EXPR @3 @2 ) )
(PLUS_EXPR (PLUS_EXPR @2 @3 ) (PLUS_EXPR @0 @1 ) )
(PLUS_EXPR (PLUS_EXPR @2 @3 ) (PLUS_EXPR @1 @0 ) )
(PLUS_EXPR (PLUS_EXPR @3 @2 ) (PLUS_EXPR @0 @1 ) )
(PLUS_EXPR (PLUS_EXPR @3 @2 ) (PLUS_EXPR @1 @0 ) )

* Deciding which operators are commutative:
Currently I assume that all PLUS_EXPRs and MULT_EXPRs are commutative. Maybe we should add syntax to mark a particular operator as commutative?

* Cloning AST nodes:
When creating another AST that represents one of the commutative variants, should we clone the AST nodes so that all commutative variants have distinct AST nodes? That is not done currently. AST nodes are shared among the different commutative expressions, so a set of commutative expressions ends up forming a DAG.

Thanks and Regards,
Prathamesh
Index: gcc/genmatch.c
===================================================================
--- gcc/genmatch.c	(revision 211732)
+++ gcc/genmatch.c	(working copy)
@@ -293,14 +293,16 @@ e_operation::e_operation (const char *id
 struct simplify {
   simplify (const char *name_,
-            struct operand *match_, source_location match_location_,
+            vec<operand *> matchers_, source_location match_location_,
             struct operand *ifexpr_, source_location ifexpr_location_,
             struct operand *result_, source_location result_location_)
-      : name (name_), match (match_), match_location (match_location_),
+      : name (name_), matchers (matchers_), match_location (match_location_),
       ifexpr (ifexpr_), ifexpr_location (ifexpr_location_),
       result (result_), result_location (result_location_) {}
 
   const char *name;
-  struct operand *match;
+  /* The match expression and its commutative variants, as built by
+     commutate; matchers[0] is structurally identical to the original.  */
+  vec<operand *> matchers;
   source_location match_location;
   struct operand *ifexpr;
   source_location ifexpr_location;
@@ -308,6 +310,175 @@ struct simplify {
   source_location result_location;
 };
 
+/* Print operand O to F as a parenthesized prefix expression, for debugging.  */
+
+void
+print_operand (operand *o, FILE *f = stderr)
+{
+  if (o->type == operand::OP_CAPTURE)
+    fprintf (f, "@%s", (static_cast<capture *> (o))->where);
+
+  else if (o->type == operand::OP_PREDICATE)
+    fprintf (f, "%s", (static_cast<predicate *> (o))->ident);
+
+  else if (o->type == operand::OP_C_EXPR)
+    fprintf (f, "c_expr");
+
+  else if (o->type == operand::OP_EXPR)
+    {
+      expr *e = static_cast<expr *> (o);
+      fprintf (f, "(%s ", e->operation->op->id);
+
+      for (unsigned i = 0; i < e->ops.length (); ++i)
+        {
+          print_operand (e->ops[i], f);
+          putc (' ', f);
+        }
+
+      putc (')', f);
+    }
+
+  else
+    gcc_unreachable ();
+}
+
+/* Print S's match expression and all of its commutative variants to F,
+   for debugging.  Nothing is printed when S has a single variant.  */
+
+void
+print_matches (struct simplify *s, FILE *f = stderr)
+{
+  if (s->matchers.length () == 1)
+    return;
+
+  fprintf (f, "for expression: ");
+  print_operand (s->matchers[0], f); // s->matchers[0] is equivalent to original expression
+  putc ('\n', f);
+
+  fprintf (f, "commutative expressions:\n");
+  for (unsigned i = 0; i < s->matchers.length (); ++i)
+    {
+      print_operand (s->matchers[i], f);
+      putc ('\n', f);
+    }
+}
+
+/* Return true if OP is an expression whose two operands may be swapped.
+   For now only PLUS_EXPR and MULT_EXPR are treated as commutative.  */
+
+bool
+is_commutative (operand *op)
+{
+  if (op->type != operand::OP_EXPR)
+    return false;
+
+  expr *e = static_cast<expr *> (op);
+
+  /* commutate swaps ops[0] and ops[1], so require exactly two operands.  */
+  if (e->ops.length () != 2)
+    return false;
+
+  /* Only operator identifiers carry a tree code; function identifiers
+     must not be cast to operator_id.  */
+  if (e->operation->op->kind != id_base::CODE)
+    return false;
+
+  operator_id *op_id = static_cast <operator_id *> (e->operation->op);
+  enum tree_code code = op_id->code;
+  return code == PLUS_EXPR || code == MULT_EXPR;
+}
+
+/* Return every variant of OP obtained by swapping the operands of its
+   commutative sub-expressions.  The first element is structurally
+   identical to OP.  Operand nodes are shared between the variants
+   rather than cloned, so together they form a DAG.  The caller owns
+   the returned vector.  */
+
+vec<operand *>
+commutate (operand *op)
+{
+  vec<operand *> ret = vNULL;
+
+  /* Captures, predicates and C expressions have exactly one variant.  */
+  if (op->type != operand::OP_EXPR)
+    {
+      ret.safe_push (op); // FIXME: should we clone op ? ret.safe_push (op->clone())
+      return ret;
+    }
+
+  expr *e = static_cast<expr *> (op);
+
+  if (is_commutative (op))
+    {
+      vec<operand *> v1 = commutate (e->ops[0]);
+      vec<operand *> v2 = commutate (e->ops[1]);
+      unsigned i, j;
+
+      for (i = 0; i < v1.length (); ++i)
+        for (j = 0; j < v2.length (); ++j)
+          {
+            expr *ne = new expr (e->operation); // FIXME: e->operation should be cloned ?
+            ne->append_op (v1[i]);
+            ne->append_op (v2[j]);
+            ret.safe_push (ne);
+          }
+
+      for (i = 0; i < v2.length (); ++i)
+        for (j = 0; j < v1.length (); ++j)
+          {
+            expr *ne = new expr (e->operation);
+            ne->append_op (v2[i]);
+            ne->append_op (v1[j]);
+            ret.safe_push (ne);
+          }
+
+      v1.release ();
+      v2.release ();
+      return ret;
+    }
+
+  /* OP is not commutative itself, but its operands may contain commutative
+     sub-expressions: build one expression per combination of operand
+     variants.  COMBOS holds the partial operand lists built so far.  */
+  vec< vec<operand *> > combos = vNULL;
+  vec<operand *> empty = vNULL;
+  combos.safe_push (empty);
+  for (unsigned i = 0; i < e->ops.length (); ++i)
+    {
+      vec<operand *> variants = commutate (e->ops[i]);
+      vec< vec<operand *> > next = vNULL;
+      for (unsigned j = 0; j < combos.length (); ++j)
+        {
+          for (unsigned k = 0; k < variants.length (); ++k)
+            {
+              vec<operand *> ops = combos[j].copy ();
+              ops.safe_push (variants[k]);
+              next.safe_push (ops);
+            }
+          combos[j].release ();
+        }
+      combos.release ();
+      variants.release ();
+      combos = next;
+    }
+
+  /* A single combination means no operand changed, so keep OP itself.  */
+  if (combos.length () == 1)
+    ret.safe_push (op);
+  else
+    for (unsigned i = 0; i < combos.length (); ++i)
+      {
+        expr *ne = new expr (e->operation);
+        for (unsigned j = 0; j < combos[i].length (); ++j)
+          ne->append_op (combos[i][j]);
+        ret.safe_push (ne);
+      }
+
+  for (unsigned i = 0; i < combos.length (); ++i)
+    combos[i].release ();
+  combos.release ();
+  return ret;
+}
 
 /* Code gen off the AST.  */
 
@@ -574,11 +745,16 @@ write_nary_simplifiers (FILE *f, vec<sim
     {
       simplify *s = simplifiers[i];
       /* ??? This means we can't capture the outermost expression.  */
-      if (s->match->type != operand::OP_EXPR)
+      for (unsigned j = 0; j < s->matchers.length (); ++j)
+      {
+      /* One matcher per commutative variant; J avoids shadowing I.  */
+      operand *match = s->matchers[j];
+      if (match->type != operand::OP_EXPR)
         continue;
-      expr *e = static_cast <expr *> (s->match);
+      expr *e = static_cast <expr *> (match);
       if (e->ops.length () != n)
         continue;
+
       char fail_label[16];
       snprintf (fail_label, 16, "fail%d", label_cnt++);
       output_line_directive (f, s->match_location);
@@ -627,6 +803,7 @@ write_nary_simplifiers (FILE *f, vec<sim
       fprintf (f, " }\n");
       fprintf (f, "%s:\n", fail_label);
     }
+      }
   fprintf (f, " return false;\n");
   fprintf (f, "}\n");
 }
@@ -971,7 +1148,7 @@ parse_match_and_simplify (cpp_reader *r,
       ifexpr = parse_c_expr (r, CPP_OPEN_PAREN);
     }
   token = peek (r);
-  return new simplify (id, match, match_location,
+  return new simplify (id, commutate (match), match_location,
                        ifexpr, ifexpr_location, parse_op (r), token->src_loc);
 }
 
@@ -1043,6 +1220,10 @@ main(int argc, char **argv)
     }
   while (1);
 
+  /* Debugging aid: dump the commutative variants of each simplifier.  */
+  for (unsigned i = 0; i < simplifiers.length (); ++i)
+    print_matches (simplifiers[i]);
+
   write_gimple (stdout, simplifiers);
 
   cpp_finish (r, NULL);