From: Arthur Cohen <arthur.co...@embecosm.com>

gcc/rust/ChangeLog:

        * expand/rust-derive-ord.cc (DeriveOrd::make_cmp_arms): New function.
        (is_last): Likewise.
        (recursive_match): Likewise.
        (DeriveOrd::recursive_match): Likewise.
        (DeriveOrd::visit_struct): Add proper implementation.
        (DeriveOrd::visit_union): Likewise.
        * expand/rust-derive-ord.h: Declare these new functions.
---
 gcc/rust/expand/rust-derive-ord.cc | 90 +++++++++++++++++++++++++++---
 gcc/rust/expand/rust-derive-ord.h  | 37 +++++++++++-
 2 files changed, 116 insertions(+), 11 deletions(-)

diff --git a/gcc/rust/expand/rust-derive-ord.cc b/gcc/rust/expand/rust-derive-ord.cc
index 7eaaa474d1b..2403e9c2a33 100644
--- a/gcc/rust/expand/rust-derive-ord.cc
+++ b/gcc/rust/expand/rust-derive-ord.cc
@@ -92,15 +92,85 @@ DeriveOrd::cmp_fn (std::unique_ptr<BlockExpr> &&block, Identifier type_name)
                            builder.reference_type (ptrify (
                              builder.type_path (type_name.as_string ())))));
 
-  auto function_name = ordering == Ordering::Partial ? "partial_cmp" : "cmp";
+  auto function_name = fn (ordering);
 
   return builder.function (function_name, std::move (params),
                           ptrify (return_type), std::move (block));
 }
+
+// Build the (equal, non-equal) arm pair for one level of the nested match.
+std::pair<MatchArm, MatchArm>
+DeriveOrd::make_cmp_arms ()
+{
+  // Catch-all arm: binds any comparison result other than "equal"
+  auto non_equal = builder.identifier_pattern (DeriveOrd::not_equal);
+
+  std::unique_ptr<Pattern> equal = ptrify (
+    builder.path_in_expression ({"core", "cmp", "Ordering", "Equal"}, true));
+
+  // partial_cmp returns Option<Ordering>, so only partial ordering needs Some
+  if (ordering == Ordering::Partial)
+    {
+      auto pattern_items = std::unique_ptr<TupleStructItems> (
+       new TupleStructItemsNoRange (vec (std::move (equal))));
+      equal
+       = std::make_unique<TupleStructPattern> (builder.path_in_expression (
+                                                 LangItem::Kind::OPTION_SOME),
+                                               std::move (pattern_items));
+    }
+
+  return {builder.match_arm (std::move (equal)),
+         builder.match_arm (std::move (non_equal))};
+}
+
+template <typename T>
+inline bool
+is_last (T &elt, std::vector<T> &vec)
+{
+  rust_assert (!vec.empty ());
+
+  return &elt == &vec.back ();
+}
+
 std::unique_ptr<Expr>
-recursive_match ()
+DeriveOrd::recursive_match (std::vector<SelfOther> &&members)
 {
-  return nullptr;
+  std::unique_ptr<Expr> final_expr = nullptr;
+
+  for (auto it = members.rbegin (); it != members.rend (); it++)
+    {
+      auto &member = *it;
+
+      auto cmp_fn_path = builder.path_in_expression (
+       {"core", "cmp", trait (ordering), fn (ordering)}, true);
+
+      auto cmp_call = builder.call (ptrify (cmp_fn_path),
+                                   vec (std::move (member.self_expr),
+                                        std::move (member.other_expr)));
+
+      // For the last member (so the first iterator), we just create a call
+      // expression
+      if (it == members.rbegin ())
+       {
+         final_expr = std::move (cmp_call);
+         continue;
+       }
+
+      // If we aren't dealing with the last member, then we need to wrap all of
+      // that in a big match expression and keep going
+      auto match_arms = make_cmp_arms ();
+
+      auto match_cases
+       = {builder.match_case (std::move (match_arms.first),
+                              std::move (final_expr)),
+          builder.match_case (std::move (match_arms.second),
+                              builder.identifier (DeriveOrd::not_equal))};
+
+      final_expr
+       = builder.match (std::move (cmp_call), std::move (match_cases));
+    }
+
+  return final_expr;
 }
 
 // we need to do a recursive match expression for all of the fields used in a
@@ -128,10 +198,12 @@ recursive_match ()
 void
 DeriveOrd::visit_struct (StructStruct &item)
 {
-  // FIXME: Put cmp_fn call inside cmp_impl, pass a block to cmp_impl instead -
-  // this avoids repeating the same parameter twice (the type name)
-  expanded = cmp_impl (builder.block (), item.get_identifier (),
-                      item.get_generic_params ());
+  auto fields = SelfOther::fields (builder, item.get_fields ());
+
+  auto match_expr = recursive_match (std::move (fields));
+
+  expanded = cmp_impl (builder.block (std::move (match_expr)),
+                      item.get_identifier (), item.get_generic_params ());
 }
 
 // same as structs, but for each field index instead of each field name -
@@ -162,10 +234,10 @@ DeriveOrd::visit_enum (Enum &item)
 void
 DeriveOrd::visit_union (Union &item)
 {
-  auto trait_name = ordering == Ordering::Total ? "Ord" : "PartialOrd";
+  auto trait_name = trait (ordering);
 
   rust_error_at (item.get_locus (), "derive(%s) cannot be used on unions",
-                trait_name);
+                trait_name.c_str ());
 }
 
 } // namespace AST
diff --git a/gcc/rust/expand/rust-derive-ord.h b/gcc/rust/expand/rust-derive-ord.h
index fae13261e7c..047ebfb0c01 100644
--- a/gcc/rust/expand/rust-derive-ord.h
+++ b/gcc/rust/expand/rust-derive-ord.h
@@ -20,6 +20,7 @@
 #define RUST_DERIVE_ORD_H
 
 #include "rust-ast.h"
+#include "rust-derive-cmp-common.h"
 #include "rust-derive.h"
 
 namespace Rust {
@@ -42,6 +43,22 @@ public:
     Partial
   };
 
+  std::string fn (Ordering ordering)
+  {
+    if (ordering == Ordering::Total)
+      return "cmp";
+    else
+      return "partial_cmp";
+  }
+
+  std::string trait (Ordering ordering)
+  {
+    if (ordering == Ordering::Total)
+      return "Ord";
+    else
+      return "PartialOrd";
+  }
+
   DeriveOrd (Ordering ordering, location_t loc);
 
   std::unique_ptr<Item> go (Item &item);
@@ -51,10 +68,26 @@ private:
 
   Ordering ordering;
 
+  /* Identifier patterns for the non-equal match arms */
+  constexpr static const char *not_equal = "non_eq";
+
   /**
    * Create the recursive matching structure used when implementing the
-   * comparison function on multiple sub items (fields, tuple indexes...) */
-  std::unique_ptr<Expr> recursive_match ();
+   * comparison function on multiple sub items (fields, tuple indexes...)
+   */
+  std::unique_ptr<Expr> recursive_match (std::vector<SelfOther> &&members);
+
+  /**
+   * Make the match arms for one inner match in a comparison function block.
+   * This returns the "equal" match arm and the "rest" match arm, so something
+   * like `Ordering::Equal` and `non_eq` in the following match expression:
+   *
+   * match cmp(...) {
+   *     Ordering::Equal => match cmp(...) { ... }
+   *     non_eq => non_eq,
+   * }
+   */
+  std::pair<MatchArm, MatchArm> make_cmp_arms ();
 
   std::unique_ptr<Item>
   cmp_impl (std::unique_ptr<BlockExpr> &&fn_block, Identifier type_name,
-- 
2.49.0

Reply via email to