Hi,
    The attached patch attempts to generate commutative variants for
a given expression.

Example:
For the AST: (PLUS_EXPR (PLUS_EXPR @0 @1) @2),

the commutative variants are:
(PLUS_EXPR (PLUS_EXPR @0 @1 ) @2 )
(PLUS_EXPR (PLUS_EXPR @1 @0 ) @2 )
(PLUS_EXPR @2 (PLUS_EXPR @0 @1 ) )
(PLUS_EXPR @2 (PLUS_EXPR @1 @0 ) )


* Basic Idea:
Consider an expression e with two operands, o0 and o1,
and let expr-code denote the expression's code (plus, mult, etc.).

Commutative variants are stored in vector (vec<operand *>).

vec<operand *>
commutative (e)
{
  if (e is not commutative)
    return [e];  // vector with only one expression

  v1 = commutative (o0);
  v2 = commutative (o1);
  ret = []

  for i = 0 ... v1.length ()
    for j = 0 ... v2.length ()
      {
        ne = new expr with <expr-code> and operands: v1[i], v2[j];
        append ne to ret;
      }

  for i = 0 ... v2.length ()
    for j = 0 ... v1.length ()
      {
        ne = new expr with <expr-code> and operands: v2[i], v1[j];
        append ne to ret;
      }

  return ret;
}

Example:
(plus (plus @0 @1) (plus @2 @3))
generates following commutative variants:

(PLUS_EXPR (PLUS_EXPR @0 @1 ) (PLUS_EXPR @2 @3 ) )
(PLUS_EXPR (PLUS_EXPR @0 @1 ) (PLUS_EXPR @3 @2 ) )
(PLUS_EXPR (PLUS_EXPR @1 @0 ) (PLUS_EXPR @2 @3 ) )
(PLUS_EXPR (PLUS_EXPR @1 @0 ) (PLUS_EXPR @3 @2 ) )
(PLUS_EXPR (PLUS_EXPR @2 @3 ) (PLUS_EXPR @0 @1 ) )
(PLUS_EXPR (PLUS_EXPR @2 @3 ) (PLUS_EXPR @1 @0 ) )
(PLUS_EXPR (PLUS_EXPR @3 @2 ) (PLUS_EXPR @0 @1 ) )
(PLUS_EXPR (PLUS_EXPR @3 @2 ) (PLUS_EXPR @1 @0 ) )


* Decide which operators are commutative.
Currently I treat PLUS_EXPR and MULT_EXPR (and only those) as commutative.
Maybe we should add syntax to mark a particular operator as commutative?

* Cloning AST nodes
While creating another AST that represents one of
the commutative variants, should we clone the AST nodes,
so that all commutative variants have distinct AST nodes?
That's not done currently: AST nodes are shared among
the different commutative variants, so for a set of
commutative variants we end up with a DAG rather than separate trees.

Thanks and Regards,
Prathamesh
Index: gcc/genmatch.c
===================================================================
--- gcc/genmatch.c	(revision 211732)
+++ gcc/genmatch.c	(working copy)
@@ -293,14 +293,14 @@ e_operation::e_operation (const char *id
 
 struct simplify {
   simplify (const char *name_,
-	    struct operand *match_, source_location match_location_,
+	    vec<operand *> matchers_, source_location match_location_,
 	    struct operand *ifexpr_, source_location ifexpr_location_,
 	    struct operand *result_, source_location result_location_)
-      : name (name_), match (match_), match_location (match_location_),
+      : name (name_), matchers (matchers_), match_location (match_location_),
       ifexpr (ifexpr_), ifexpr_location (ifexpr_location_),
       result (result_), result_location (result_location_) {}
   const char *name;
-  struct operand *match;
+  vec<operand *> matchers;  // vector to hold commutative expressions
   source_location match_location;
   struct operand *ifexpr;
   source_location ifexpr_location;
@@ -308,7 +308,108 @@ struct simplify {
   source_location result_location;
 };
 
+/* Write a human-readable form of the matcher operand O to F (stderr by
+   default).  Captures print as "@WHERE", predicates by their identifier,
+   C expressions as "c_expr", and expressions in prefix form with a space
+   after every operand, e.g. "(PLUS_EXPR @0 @1 )".  */
+
+void
+print_operand (operand *o, FILE *f = stderr)
+{
+  switch (o->type)
+    {
+    case operand::OP_CAPTURE:
+      fprintf (f, "@%s", static_cast<capture *> (o)->where);
+      break;
+
+    case operand::OP_PREDICATE:
+      fprintf (f, "%s", static_cast<predicate *> (o)->ident);
+      break;
+
+    case operand::OP_C_EXPR:
+      fputs ("c_expr", f);
+      break;
+
+    case operand::OP_EXPR:
+      {
+	expr *ex = static_cast<expr *> (o);
+	fprintf (f, "(%s ", ex->operation->op->id);
+	for (unsigned k = 0; k < ex->ops.length (); ++k)
+	  {
+	    print_operand (ex->ops[k], f);
+	    fputc (' ', f);
+	  }
+	fputc (')', f);
+	break;
+      }
+
+    default:
+      gcc_unreachable ();
+    }
+}
+
+/* Write to F (stderr by default) the matchers recorded for simplifier S:
+   first the original expression, then every commutative variant
+   (including the original again).  Nothing is written when S has exactly
+   one matcher.  */
+
+void
+print_matches (struct simplify *s, FILE *f = stderr)
+{
+  unsigned count = s->matchers.length ();
+  if (count != 1)
+    {
+      /* Element 0 has the same shape as the parsed expression.  */
+      fputs ("for expression: ", f);
+      print_operand (s->matchers[0], f);
+      fputc ('\n', f);
+
+      fputs ("commutative expressions:\n", f);
+      for (unsigned k = 0; k < count; ++k)
+	{
+	  print_operand (s->matchers[k], f);
+	  fputc ('\n', f);
+	}
+    }
+}
+
+/* Return true if OP is an expression with exactly two operands whose
+   operation is a tree code that is treated as commutative, i.e. whose
+   operands may be swapped.  For now only PLUS_EXPR and MULT_EXPR are
+   treated as commutative.  */
+
+bool
+is_commutative (operand *op)
+{
+  if (op->type != operand::OP_EXPR)
+    return false;
+
+  expr *e = static_cast<expr *> (op);
+
+  /* The operation may name a builtin function instead of a tree code;
+     only an operator_id carries a tree code, so check the kind before
+     downcasting.  */
+  if (e->operation->op->kind != id_base::CODE)
+    return false;
+
+  /* Swapping is only defined for exactly two operands; callers index
+     ops[0] and ops[1] when this returns true.  */
+  if (e->ops.length () != 2)
+    return false;
+
+  operator_id *op_id = static_cast <operator_id *> (e->operation->op);
 
+  return op_id->code == PLUS_EXPR || op_id->code == MULT_EXPR;
+}
+
+/* Helper for commutate.  Append to RESULT one new expression with
+   operation OPERATION for every way of picking one operand from each of
+   VARIANTS[I], VARIANTS[I+1], ...  PREFIX holds the operands already
+   picked for positions 0 .. I-1 and is restored before returning.
+   Earlier positions vary slowest, so the first expression appended uses
+   element 0 of every VARIANTS entry.  */
+
+static void
+commutate_product (e_operation *operation, vec<vec<operand *> > &variants,
+		   unsigned i, vec<operand *> &prefix, vec<operand *> &result)
+{
+  if (i == variants.length ())
+    {
+      expr *ne = new expr (operation);  // FIXME: should operation be cloned ?
+      for (unsigned k = 0; k < prefix.length (); ++k)
+	ne->append_op (prefix[k]);
+      result.safe_push (ne);
+      return;
+    }
+
+  for (unsigned k = 0; k < variants[i].length (); ++k)
+    {
+      prefix.safe_push (variants[i][k]);
+      commutate_product (operation, variants, i + 1, prefix, result);
+      prefix.pop ();
+    }
+}
+
+/* Return a newly allocated vector (owned by the caller) holding every
+   commutative variant of OP.  Element 0 always has the same shape as OP.
+
+   An operand that is not an expression has a single variant, OP itself.
+   For an expression, the variants of every operand are computed
+   recursively and combined in every possible order-preserving way.  This
+   includes operands of non-commutative codes, so (MINUS_EXPR (PLUS_EXPR
+   @0 @1) @2) yields two variants.  When is_commutative (OP) holds and OP
+   has two operands, the combinations with the operands swapped are
+   appended after the unswapped ones.
+
+   The result is not deduplicated.  Its size is 2^K for K commutative
+   nodes.  AST nodes are shared (not cloned) between variants, and OP
+   itself is returned when it has no other variant.  */
+
+vec<operand *>
+commutate (operand *op)
+{
+  vec<operand *> ret = vNULL;
+
+  if (op->type != operand::OP_EXPR)
+    {
+      ret.safe_push (op);  // FIXME: should we clone op ? ret.safe_push (op->clone())
+      return ret;
+    }
+
+  expr *e = static_cast<expr *> (op);
+
+  /* VARIANTS[I] holds the variants of operand I.  */
+  vec<vec<operand *> > variants = vNULL;
+  bool has_variants = false;
+  for (unsigned i = 0; i < e->ops.length (); ++i)
+    {
+      vec<operand *> v = commutate (e->ops[i]);
+      if (v.length () > 1)
+	has_variants = true;
+      variants.safe_push (v);
+    }
+
+  /* Swapping is only meaningful for exactly two operands.  */
+  bool swap = is_commutative (op) && e->ops.length () == 2;
+
+  if (!has_variants && !swap)
+    ret.safe_push (op);
+  else
+    {
+      vec<operand *> prefix = vNULL;
+      commutate_product (e->operation, variants, 0, prefix, ret);
+      if (swap)
+	{
+	  /* Same combinations with the two operands exchanged.  */
+	  vec<operand *> tmp = variants[0];
+	  variants[0] = variants[1];
+	  variants[1] = tmp;
+	  commutate_product (e->operation, variants, 0, prefix, ret);
+	}
+      prefix.release ();
+    }
+
+  /* The per-operand vectors were only scratch space; the operands they
+     point to are now referenced from RET, so only the vectors are freed.  */
+  for (unsigned i = 0; i < variants.length (); ++i)
+    variants[i].release ();
+  variants.release ();
+
+  return ret;
+}
 
 /* Code gen off the AST.  */
 
@@ -574,11 +675,15 @@ write_nary_simplifiers (FILE *f, vec<sim
     {
       simplify *s = simplifiers[i];
       /* ???  This means we can't capture the outermost expression.  */
-      if (s->match->type != operand::OP_EXPR)
+      for (unsigned i = 0; i < s->matchers.length (); ++i)
+	{
+	  operand *match = s->matchers[i];
+	  if (match->type != operand::OP_EXPR)
 	continue;
-      expr *e = static_cast <expr *> (s->match);
+	  expr *e = static_cast <expr *> (match);
       if (e->ops.length () != n)
 	continue;
+
       char fail_label[16];
       snprintf (fail_label, 16, "fail%d", label_cnt++);
       output_line_directive (f, s->match_location);
@@ -627,6 +735,7 @@ write_nary_simplifiers (FILE *f, vec<sim
       fprintf (f, "    }\n");
       fprintf (f, "%s:\n", fail_label);
     }
+    }
   fprintf (f, "  return false;\n");
   fprintf (f, "}\n");
 }
@@ -971,7 +1080,7 @@ parse_match_and_simplify (cpp_reader *r,
       ifexpr = parse_c_expr (r, CPP_OPEN_PAREN);
     }
   token = peek (r);
-  return new simplify (id, match, match_location,
+  return new simplify (id, commutate (match), match_location,
 		       ifexpr, ifexpr_location, parse_op (r), token->src_loc);
 }
 
@@ -1043,6 +1152,9 @@ main(int argc, char **argv)
     }
   while (1);
 
+  for (unsigned i = 0; i < simplifiers.length (); ++i)
+    print_matches (simplifiers[i]);
+
   write_gimple (stdout, simplifiers);
 
   cpp_finish (r, NULL);

Reply via email to