From: Arthur Cohen <arthur.co...@embecosm.com>

gcc/rust/ChangeLog:

	* expand/rust-derive-ord.cc (DeriveOrd::cmp_fn): Use DeriveOrd::fn.
	(DeriveOrd::make_cmp_arms): New function.
	(is_last): Likewise.
	(recursive_match): Remove.
	(DeriveOrd::recursive_match): New function.
	(DeriveOrd::visit_struct): Add proper implementation.
	(DeriveOrd::visit_union): Use DeriveOrd::trait.
	* expand/rust-derive-ord.h: Include rust-derive-cmp-common.h.
	(DeriveOrd::fn): New function.
	(DeriveOrd::trait): Likewise.
	(DeriveOrd::not_equal): New member.
	(DeriveOrd::recursive_match): Take the members to compare.
	(DeriveOrd::make_cmp_arms): Declare.
---
 gcc/rust/expand/rust-derive-ord.cc | 112 +++++++++++++++++++++++++---
 gcc/rust/expand/rust-derive-ord.h  |  37 +++++++++-
 2 files changed, 138 insertions(+), 11 deletions(-)

diff --git a/gcc/rust/expand/rust-derive-ord.cc b/gcc/rust/expand/rust-derive-ord.cc
index 7eaaa474d1b..2403e9c2a33 100644
--- a/gcc/rust/expand/rust-derive-ord.cc
+++ b/gcc/rust/expand/rust-derive-ord.cc
@@ -92,15 +92,107 @@ DeriveOrd::cmp_fn (std::unique_ptr<BlockExpr> &&block, Identifier type_name)
 			    builder.reference_type (ptrify (
 			      builder.type_path (type_name.as_string ()))))));
 
-  auto function_name = ordering == Ordering::Partial ? "partial_cmp" : "cmp";
+  auto function_name = fn (ordering);
 
   return builder.function (function_name, std::move (params),
 			   ptrify (return_type), std::move (block));
 }
+
+std::pair<MatchArm, MatchArm>
+DeriveOrd::make_cmp_arms ()
+{
+  // All comparison results other than Ordering::Equal
+  auto non_equal = builder.identifier_pattern (DeriveOrd::not_equal);
+
+  std::unique_ptr<Pattern> equal = ptrify (
+    builder.path_in_expression ({"core", "cmp", "Ordering", "Equal"}, true));
+
+  // PartialOrd::partial_cmp returns an Option<Ordering>, so the Equal pattern
+  // must be wrapped in Option::Some when doing partial ordering
+  if (ordering == Ordering::Partial)
+    {
+      auto pattern_items = std::unique_ptr<TupleStructItems> (
+	new TupleStructItemsNoRange (vec (std::move (equal))));
+
+      equal
+	= std::make_unique<TupleStructPattern> (builder.path_in_expression (
+						  LangItem::Kind::OPTION_SOME),
+						std::move (pattern_items));
+    }
+
+  return {builder.match_arm (std::move (equal)),
+	  builder.match_arm (std::move (non_equal))};
+}
+
+template <typename T>
+inline bool
+is_last (T &elt, std::vector<T> &vec)
+{
+  rust_assert (!vec.empty ());
+
+  return &elt == &vec.back ();
+}
+
 std::unique_ptr<Expr>
-recursive_match ()
+DeriveOrd::recursive_match (std::vector<SelfOther> &&members)
 {
-  return nullptr;
+  // With no members to compare (e.g. an empty struct) both values are always
+  // equal: return `Ordering::Equal` directly (wrapped in `Option::Some` for
+  // `partial_cmp`) instead of producing a block with no tail expression
+  if (members.empty ())
+    {
+      auto equal_path = builder.path_in_expression (
+	{"core", "cmp", "Ordering", "Equal"}, true);
+      std::unique_ptr<Expr> equal = ptrify (equal_path);
+
+      if (ordering == Ordering::Total)
+	return equal;
+
+      auto some_path
+	= builder.path_in_expression (LangItem::Kind::OPTION_SOME);
+      return builder.call (ptrify (some_path), vec (std::move (equal)));
+    }
+
+  std::unique_ptr<Expr> final_expr = nullptr;
+
+  for (auto it = members.rbegin (); it != members.rend (); it++)
+    {
+      auto &member = *it;
+
+      auto cmp_fn_path = builder.path_in_expression (
+	{"core", "cmp", trait (ordering), fn (ordering)}, true);
+
+      auto cmp_call = builder.call (ptrify (cmp_fn_path),
+				    vec (std::move (member.self_expr),
+					 std::move (member.other_expr)));
+
+      // For the last member (so the first iterator), we just create a call
+      // expression
+      if (it == members.rbegin ())
+	{
+	  final_expr = std::move (cmp_call);
+	  continue;
+	}
+
+      // If we aren't dealing with the last member, then we need to wrap all of
+      // that in a big match expression and keep going
+      auto match_arms = make_cmp_arms ();
+
+      // Collect the cases in a vector: initializer_list elements are const, so
+      // converting one would deep-clone the whole nested match built so far
+      std::vector<MatchCase> match_cases;
+      match_cases.reserve (2);
+      match_cases.push_back (builder.match_case (std::move (match_arms.first),
+						 std::move (final_expr)));
+      match_cases.push_back (
+	builder.match_case (std::move (match_arms.second),
+			    builder.identifier (DeriveOrd::not_equal)));
+
+      final_expr
+	= builder.match (std::move (cmp_call), std::move (match_cases));
+    }
+
+  return final_expr;
 }
 
 // we need to do a recursive match expression for all of the fields used in a
@@ -128,10 +220,12 @@ recursive_match ()
 void
 DeriveOrd::visit_struct (StructStruct &item)
 {
-  // FIXME: Put cmp_fn call inside cmp_impl, pass a block to cmp_impl instead
-  // this avoids repeating the same parameter twice (the type name)
-  expanded = cmp_impl (builder.block (), item.get_identifier (),
-		       item.get_generic_params ());
+  auto fields = SelfOther::fields (builder, item.get_fields ());
+
+  auto match_expr = recursive_match (std::move (fields));
+
+  expanded = cmp_impl (builder.block (std::move (match_expr)),
+		       item.get_identifier (), item.get_generic_params ());
 }
 
 // same as structs, but for each field index instead of each field name -
@@ -162,10 +256,10 @@ DeriveOrd::visit_enum (Enum &item)
 void
 DeriveOrd::visit_union (Union &item)
 {
-  auto trait_name = ordering == Ordering::Total ? "Ord" : "PartialOrd";
+  auto trait_name = trait (ordering);
 
   rust_error_at (item.get_locus (), "derive(%s) cannot be used on unions",
-		 trait_name);
+		 trait_name.c_str ());
 }
 
 } // namespace AST
diff --git a/gcc/rust/expand/rust-derive-ord.h b/gcc/rust/expand/rust-derive-ord.h
index fae13261e7c..047ebfb0c01 100644
--- a/gcc/rust/expand/rust-derive-ord.h
+++ b/gcc/rust/expand/rust-derive-ord.h
@@ -20,6 +20,7 @@
 #define RUST_DERIVE_ORD_H
 
 #include "rust-ast.h"
+#include "rust-derive-cmp-common.h"
 #include "rust-derive.h"
 
 namespace Rust {
@@ -42,6 +43,22 @@ public:
     Partial
   };
 
+  std::string fn (Ordering ordering)
+  {
+    if (ordering == Ordering::Total)
+      return "cmp";
+    else
+      return "partial_cmp";
+  }
+
+  std::string trait (Ordering ordering)
+  {
+    if (ordering == Ordering::Total)
+      return "Ord";
+    else
+      return "PartialOrd";
+  }
+
   DeriveOrd (Ordering ordering, location_t loc);
 
   std::unique_ptr<Item> go (Item &item);
@@ -51,10 +68,26 @@ private:
 
   Ordering ordering;
 
+  /* Identifier patterns for the non-equal match arms */
+  constexpr static const char *not_equal = "non_eq";
+
   /**
    * Create the recursive matching structure used when implementing the
-   * comparison function on multiple sub items (fields, tuple indexes...) */
-  std::unique_ptr<Expr> recursive_match ();
+   * comparison function on multiple sub items (fields, tuple indexes...)
+   */
+  std::unique_ptr<Expr> recursive_match (std::vector<SelfOther> &&members);
+
+  /**
+   * Make the match arms for one inner match in a comparison function block.
+   * This returns the "equal" match arm and the "rest" match arm, so something
+   * like `Ordering::Equal` and `non_eq` in the following match expression:
+   *
+   * match cmp(...) {
+   *   Ordering::Equal => match cmp(...) { ... }
+   *   non_eq => non_eq,
+   * }
+   */
+  std::pair<MatchArm, MatchArm> make_cmp_arms ();
 
   std::unique_ptr<Item> cmp_impl (std::unique_ptr<BlockExpr> &&fn_block,
 				  Identifier type_name,
-- 
2.49.0