Hi, I have attempted to add syntax for a symbol that denotes multiple operators.
I tried it with a few bogus patterns and it appears to work... hopefully :-)
For example, the (bogus) pattern:

(for op in plus mult
  (match_and_simplify
    (op @0 @1)
    (op @0 @0)))

generates the following patterns:

(plus @0 @1) -> (plus @0 @0)  // simplify_0
(plus @0 @1) -> (mult @0 @0)  // simplify_1
(mult @0 @1) -> (plus @0 @0)  // simplify_2
(mult @0 @1) -> (mult @0 @0)  // simplify_3

and the following decision tree:

root (0xab6b10), 0, 2
|--(PLUS_EXPR) (0xab6b30), 1, 1
|----true (0xab6ba0), 2, 1
|------true (0xab6c10), 3, 2
|--------simplify_0 { 0xab6ba0, 0xab6c10, (nil), (nil), }  (0xab6c80), 4, 0
|--------simplify_1 { 0xab6ba0, 0xab6c10, (nil), (nil), }  (0xab6d40), 4, 0
|--(MULT_EXPR) (0xab6d00), 1, 1
|----true (0xab6d90), 2, 1
|------true (0xab6e00), 3, 2
|--------simplify_2 { 0xab6d90, 0xab6e00, (nil), (nil), }  (0xab6e70), 4, 0
|--------simplify_3 { 0xab6d90, 0xab6e00, (nil), (nil), }  (0xab6f30), 4, 0

* Changes to the rest of the code:
a) Commutating patterns interfered with this, because parse_match_and_simplify
   immediately commutated the match operand. The symbol must be replaced by
   the operators before commutating. This required changing simplify back to a
   single operand *match; commutating is now done in a new function,
   lower_commutative. Ideally, should this be done during insertion into the
   decision tree?
b) The e_operation constructor was adjusted so that it does not call fatal
   when it does not find the id in the hash table.

* Caveats:
a) There is a new e_operation constructor taking an id_base * argument. I am
   not sure whether it is required.
b) e_operation::user_id denotes a user-defined identifier (<opname>); a rather
   apologetic name...
c) Like commutate(), replace_user_id() does not clone ASTs, so multiple ASTs
   share the same nodes.

* Adding multiple symbols? Should we have
  (for <opname> in operator-list1, <opname2> in operator-list2
    (match_and_simplify ...))
  or a nested for:
(for <opname> in operator-list1
  (for <opname2> in operator-list2
    (match_and_simplify ...)))

* We don't detect functions called with the wrong number of arguments; for
  example, we don't give an error on:
    (built_in_sqrt @0 @1)
  I guess that's because we don't have an easy way to figure out the number
  of arguments a function expects. (Is there a built-in equivalent of
  tree_code_length[]?)

	* genmatch.c (e_operation::e_operation): New constructor.
	(e_operation::user_id): New member.
	(e_operation::get_op): New member function.
	(simplify::matchers): Remove.
	(simplify::match): New member.
	(print_matches): Adjust to changes in simplify.
	(lower_commutative): New function.
	(check_operator): Likewise.
	(replace_user_id): Likewise.
	(decision_tree::insert): Adjust to changes in simplify.
	(eat_ident): New function.
	(parse_expr): Call check_operator.
	(parse_match_and_simplify): Do not commutate the match operand.
	(parse_for): New function.
	(main): Add calls to parse_for, lower_commutative.

Thanks and Regards,
Prathamesh
Index: gcc/genmatch.c =================================================================== --- gcc/genmatch.c (revision 212366) +++ gcc/genmatch.c (working copy) @@ -135,6 +135,7 @@ id_base::hash (const value_type *op) { return op->hashval; } + inline int id_base::equal (const value_type *op1, const compare_type *op2) @@ -169,6 +170,7 @@ struct fn_id : public id_base enum built_in_function fn; }; + static void add_operator (enum tree_code code, const char *id, const char *tcc, unsigned nargs) @@ -218,8 +220,12 @@ struct predicate : public operand struct e_operation { e_operation (const char *id, bool is_commutative_ = false); + e_operation (id_base *op, bool is_commutative_ = false); id_base *op; bool is_commutative; + const char *user_id; + + static id_base *get_op (const char *id); }; @@ -255,14 +261,15 @@ struct capture : public operand }; -e_operation::e_operation (const char *id, bool is_commutative_) +id_base * +e_operation::get_op (const char *id) { id_base tem (id_base::CODE, id); - is_commutative = is_commutative_; + id_base *op; op = operators.find_with_hash (&tem, tem.hashval); if (op) - return; + return op; /* Try all-uppercase. */ char *id2 = xstrdup (id); @@ -273,7 +280,7 @@ e_operation::e_operation (const char *id if (op) { free (id2); - return; + return op; } /* Try _EXPR appended. 
*/ @@ -284,22 +291,41 @@ e_operation::e_operation (const char *id if (op) { free (id2); - return; + return op; } - fatal ("expected operator, got %s", id); + return 0; +} + +e_operation::e_operation (id_base *op_, bool is_commutative_) +{ + gcc_assert (op_); + op = op_; + is_commutative = is_commutative_; +} + +e_operation::e_operation (const char *id, bool is_commutative_) +{ + is_commutative = is_commutative_; + user_id = 0; + + id_base *op_ = e_operation::get_op (id); + if (op_) + op = op_; + else + user_id = id; } struct simplify { simplify (const char *name_, - vec<operand *> matchers_, source_location match_location_, + operand *match_, source_location match_location_, struct operand *ifexpr_, source_location ifexpr_location_, struct operand *result_, source_location result_location_) - : name (name_), matchers (matchers_), match_location (match_location_), + : name (name_), match (match_), match_location (match_location_), ifexpr (ifexpr_), ifexpr_location (ifexpr_location_), result (result_), result_location (result_location_) {} const char *name; - vec<operand *> matchers; // vector to hold commutative expressions + operand *match; source_location match_location; struct operand *ifexpr; source_location ifexpr_location; @@ -452,19 +478,9 @@ print_operand (operand *o, FILE *f = std void print_matches (struct simplify *s, FILE *f = stderr) { - if (s->matchers.length () == 1) - return; - fprintf (f, "for expression: "); - print_operand (s->matchers[0], f); // s->matchers[0] is equivalent to original expression + print_operand (s->match, f); putc ('\n', f); - - fprintf (f, "commutative expressions:\n"); - for (unsigned i = 0; i < s->matchers.length (); ++i) - { - print_operand (s->matchers[i], f); - putc ('\n', f); - } } void @@ -552,6 +568,110 @@ commutate (operand *op) return ret; } +void +lower_commutative (simplify *s, vec<simplify *>& simplifiers) +{ + vec<operand *> matchers = commutate (s->match); + for (unsigned i = 0; i < matchers.length (); ++i) + { + 
simplify *ns = new simplify (s->name, matchers[i], s->match_location, + s->ifexpr, s->ifexpr_location, + s->result, s->result_location); + simplifiers.safe_push (ns); + } +} + +void +check_operator (id_base *op, unsigned n_ops, const cpp_token *token = 0) +{ + if (!op) + return; + + if (op->kind != id_base::CODE) + return; + + operator_id *opr = static_cast<operator_id *> (op); + if (opr->get_required_nargs () == n_ops) + return; + + if (token) + fatal_at (token, "%s expects %u operands, got %u operands", opr->id, opr->get_required_nargs (), n_ops); + else + fatal ("%s expects %u operands, got %u operands", opr->id, opr->get_required_nargs (), n_ops); +} + + +vec<operand *> +replace_user_id (operand *o, const char *user_id, vec<id_base *>& ids) +{ + vec<operand *> ret = vNULL; + + if (o->type == operand::OP_CAPTURE) + { + capture *c = static_cast<capture *> (o); + if (!c->what) + { + ret.safe_push (o); + return ret; + } + expr *e = static_cast<expr *> (c->what); + vec<operand *> v = replace_user_id (c->what, user_id, ids); + for (unsigned i = 0; i < v.length (); ++i) + { + capture *nc = new capture (c->where, v[i]); + ret.safe_push (nc); + } + return ret; + } + + if (o->type != operand::OP_EXPR) + { + ret.safe_push (o); + return ret; + } + + expr *e = static_cast <expr *> (o); + + vec< vec<operand *> > ops_vector = vNULL; + + for (unsigned i = 0; i < e->ops.length (); ++i) + { + vec<operand *> r = replace_user_id (e->ops[i], user_id, ids); + ops_vector.safe_push (r); + } + + vec < vec<operand *> > result = vNULL; + cartesian_product (ops_vector, result, e->ops.length ()); + + if (e->operation->user_id == 0) + { + for (unsigned i = 0; i < result.length (); ++i) + { + expr *ne = new expr (e->operation); + for (unsigned j = 0; j < result[i].length (); ++j) + ne->append_op (result[i][j]); + ret.safe_push (ne); + } + return ret; + } + + for (unsigned i = 0; i < ids.length (); ++i) + { + struct e_operation *e_op = new e_operation (ids[i], e->operation->is_commutative); + 
check_operator (ids[i], e->ops.length ()); + + for (unsigned j = 0; j < result.length (); ++j) + { + expr *ne = new expr (e_op); + for (unsigned k = 0; k < result[j].length (); ++k) + ne->append_op (result[j][k]); + ret.safe_push (ne); + } + } + + return ret; +} + /* Code gen off the AST. */ void @@ -828,17 +948,14 @@ decision_tree::insert (struct simplify * { dt_operand *indexes[dt_simplify::capture_max]; - for (unsigned i = 0; i < s->matchers.length (); ++i) - { - if (s->matchers[i]->type != operand::OP_EXPR) - continue; + if (s->match->type != operand::OP_EXPR) + return; - for (unsigned j = 0; j < dt_simplify::capture_max; ++j) - indexes[j] = 0; + for (unsigned j = 0; j < dt_simplify::capture_max; ++j) + indexes[j] = 0; - dt_node *p = decision_tree::insert_operand (root, s->matchers[i], indexes); - p->append_simplify (s, pattern_no, indexes); - } + dt_node *p = decision_tree::insert_operand (root, s->match, indexes); + p->append_simplify (s, pattern_no, indexes); } void @@ -1707,6 +1824,16 @@ get_ident (cpp_reader *r) return (const char *)CPP_HASHNODE (token->val.node.node)->ident.str; } +static void +eat_ident (cpp_reader *r, const char *s) +{ + const cpp_token *token = expect (r, CPP_NAME); + const char *t = (const char *) CPP_HASHNODE (token->val.node.node)->ident.str; + + if (strcmp (s, t)) + fatal_at (token, "expected %s got %s\n", s, t); +} + /* Read the next token from R and assert it is of type CPP_NUMBER and return its value. */ @@ -1735,6 +1862,7 @@ parse_capture (cpp_reader *r, operand *o return new capture (get_number (r), op); } + /* Parse expr = (operation[capture] op...) 
*/ static struct operand * @@ -1774,13 +1902,7 @@ parse_expr (cpp_reader *r) const cpp_token *token = peek (r); if (token->type == CPP_CLOSE_PAREN) { - if (e->operation->op->kind == id_base::CODE) - { - operator_id *opr = static_cast <operator_id *> (e->operation->op); - if (e->ops.length () != opr->get_required_nargs ()) - fatal_at (token, "got %d operands instead of the required %d", - e->ops.length (), opr->get_required_nargs ()); - } + check_operator (e->operation->op, e->ops.length (), token); if (is_commutative) { if (e->ops.length () == 2) @@ -1912,11 +2034,50 @@ parse_match_and_simplify (cpp_reader *r, ifexpr = parse_c_expr (r, CPP_OPEN_PAREN); } token = peek (r); - return new simplify (id, commutate (match), match_location, + return new simplify (id, match, match_location, ifexpr, ifexpr_location, parse_op (r), token->src_loc); } +void +parse_for (cpp_reader *r, source_location match_location, vec<simplify *>& simplifiers) +{ + const char *user_id = get_ident (r); + eat_ident (r, "in"); + + vec<id_base *> ids = vNULL; + while (1) + { + const cpp_token *token = peek (r); + if (token->type != CPP_NAME) + break; + const char *id = get_ident (r); + id_base *op = e_operation::get_op (id); + if (!op) + fatal_at (token, "expect operator got %s", id); + + ids.safe_push (op); + } + + eat_token (r, CPP_OPEN_PAREN); + eat_ident (r, "match_and_simplify"); + + simplify *s = parse_match_and_simplify (r, match_location); + eat_token (r, CPP_CLOSE_PAREN); + + vec<operand *> matchers = replace_user_id (s->match, user_id, ids); + vec<operand *> transforms = replace_user_id (s->result, user_id, ids); + + for (unsigned i = 0; i < matchers.length (); ++i) + for (unsigned j = 0; j < transforms.length (); ++j) + { + simplify *ns = new simplify (s->name, matchers[i], s->match_location, + s->ifexpr, s->ifexpr_location, + transforms[j], s->result_location); + simplifiers.safe_push (ns); + } +} + static size_t round_alloc_size (size_t s) { @@ -1986,30 +2147,36 @@ main(int argc, char 
**argv) const char *id = get_ident (r); if (strcmp (id, "match_and_simplify") == 0) simplifiers.safe_push (parse_match_and_simplify (r, token->src_loc)); + else if (strcmp (id, "for") == 0) + parse_for (r, token->src_loc, simplifiers); else - fatal_at (token, "expected 'match_and_simplify'"); + fatal_at (token, "expected 'match_and_simplify' or 'for'"); eat_token (r, CPP_CLOSE_PAREN); } while (1); + vec<simplify *> out_simplifiers = vNULL; for (unsigned i = 0; i < simplifiers.length (); ++i) - print_matches (simplifiers[i]); + lower_commutative (simplifiers[i], out_simplifiers); + + for (unsigned i = 0; i < out_simplifiers.length (); ++i) + print_matches (out_simplifiers[i]); decision_tree dt; - for (unsigned i = 0; i < simplifiers.length (); ++i) - dt.insert (simplifiers[i], i); + for (unsigned i = 0; i < out_simplifiers.length (); ++i) + dt.insert (out_simplifiers[i], i); dt.print (stderr); if (gimple) { - write_header (stdout, simplifiers, "gimple-match-head.c"); + write_header (stdout, out_simplifiers, "gimple-match-head.c"); dt.gen_gimple (stdout); } else { - write_header (stdout, simplifiers, "generic-match-head.c"); + write_header (stdout, out_simplifiers, "generic-match-head.c"); dt.gen_generic (stdout); }