https://gcc.gnu.org/g:9076a8f688886de933d5f953855a3431d3e0a922

commit r16-2943-g9076a8f688886de933d5f953855a3431d3e0a922
Author: Philip Herron <herron.phi...@googlemail.com>
Date:   Fri Jul 18 15:46:59 2025 +0100

    gccrs: Add initial support for deferred operator overload resolution
    
    In the test case:
    
      fn test (len: usize) -> u64 {
         let mut i = 0;
         let mut out = 0;
         if i + 3 < len {
            out = 123;
         }
         out
      }
    
    The issue is to determine the correct type of 'i'. 'out' is simple because it hits a
    coercion site in the return position for u64. But in 'i + 3', 'i' is an integer inference
    variable, and the same is true of the literal '3'. So when it comes to resolving the type of
    the Add expression we hit the operator overload resolution code, and because of this:
    
      macro_rules! add_impl {
          ($($t:ty)*) => ($(
              impl Add for $t {
                  type Output = $t;
    
                  #[inline]
                  #[rustc_inherit_overflow_checks]
                  fn add(self, other: $t) -> $t { self + other }
              }
          )*)
      }
    
      add_impl! { usize u8 u16 u32 u64 u128 isize i8 i16 i32 i64 i128 f32 f64 }
    
    This means the resolution for 'i + 3' is ambiguous, because it could be any of these Add
    implementations. But because we unify against '< len', where len is defined as usize,
    later in the resolution we determine that 'i' is actually a usize. This means that if we
    defer the resolution of this operator overload in the ambiguous case, we can simply
    resolve it at the end.
    
    Fixes Rust-GCC#3916
    
    gcc/rust/ChangeLog:
    
            * hir/tree/rust-hir-expr.cc (OperatorExprMeta::OperatorExprMeta):
            track the rhs
            * hir/tree/rust-hir-expr.h: likewise
            * hir/tree/rust-hir-path.h: get rid of old comments; add copy
            constructor and assignment operator to PathIdentSegment
            * typecheck/rust-hir-trait-reference.cc
            (TraitReference::get_trait_substs): return references instead of
            copies; add non-const overload
            * typecheck/rust-hir-trait-reference.h: update header
            * typecheck/rust-hir-type-check-expr.cc
            (TypeCheckExpr::ResolveOpOverload): new static helper
            (TypeCheckExpr::resolve_operator_overload): write ambiguous
            operator overloads to a table and try to resolve them at the end
            * typecheck/rust-hir-type-check-expr.h: new static helper
            * typecheck/rust-hir-type-check.h (struct DeferredOpOverload): new
            model to defer resolution
            * typecheck/rust-typecheck-context.cc
            (TypeCheckContext::insert_deferred_operator_overload): new
            (TypeCheckContext::lookup_deferred_operator_overload): likewise
            (TypeCheckContext::iterate_deferred_operator_overloads): likewise
            (TypeCheckContext::compute_ambigious_op_overload): likewise
            (TypeCheckContext::compute_inference_variables): resolve deferred
            operator overloads before defaulting inference variables
    
    gcc/testsuite/ChangeLog:
    
            * rust/compile/issue-3916.rs: New test.
    
    Signed-off-by: Philip Herron <herron.phi...@googlemail.com>

Diff:
---
 gcc/rust/hir/tree/rust-hir-expr.cc             | 10 ++--
 gcc/rust/hir/tree/rust-hir-expr.h              | 32 +++++++++++-
 gcc/rust/hir/tree/rust-hir-path.h              | 12 +++--
 gcc/rust/typecheck/rust-hir-trait-reference.cc |  8 ++-
 gcc/rust/typecheck/rust-hir-trait-reference.h  |  4 +-
 gcc/rust/typecheck/rust-hir-type-check-expr.cc | 71 +++++++++++++++++++++++---
 gcc/rust/typecheck/rust-hir-type-check-expr.h  |  5 ++
 gcc/rust/typecheck/rust-hir-type-check.h       | 45 ++++++++++++++++
 gcc/rust/typecheck/rust-typecheck-context.cc   | 63 ++++++++++++++++++++++-
 gcc/testsuite/rust/compile/issue-3916.rs       | 36 +++++++++++++
 10 files changed, 267 insertions(+), 19 deletions(-)

diff --git a/gcc/rust/hir/tree/rust-hir-expr.cc 
b/gcc/rust/hir/tree/rust-hir-expr.cc
index 93dcec2c8d79..038bfc77f94a 100644
--- a/gcc/rust/hir/tree/rust-hir-expr.cc
+++ b/gcc/rust/hir/tree/rust-hir-expr.cc
@@ -17,6 +17,7 @@
 // <http://www.gnu.org/licenses/>.
 
 #include "rust-hir-expr.h"
+#include "rust-hir-map.h"
 #include "rust-operators.h"
 #include "rust-hir-stmt.h"
 
@@ -1321,37 +1322,40 @@ AsyncBlockExpr::operator= (AsyncBlockExpr const &other)
 OperatorExprMeta::OperatorExprMeta (HIR::CompoundAssignmentExpr &expr)
   : node_mappings (expr.get_mappings ()),
     lvalue_mappings (expr.get_expr ().get_mappings ()),
-    locus (expr.get_locus ())
+    rvalue_mappings (expr.get_rhs ().get_mappings ()), locus (expr.get_locus 
())
 {}
 
 OperatorExprMeta::OperatorExprMeta (HIR::ArithmeticOrLogicalExpr &expr)
   : node_mappings (expr.get_mappings ()),
     lvalue_mappings (expr.get_expr ().get_mappings ()),
-    locus (expr.get_locus ())
+    rvalue_mappings (expr.get_rhs ().get_mappings ()), locus (expr.get_locus 
())
 {}
 
 OperatorExprMeta::OperatorExprMeta (HIR::NegationExpr &expr)
   : node_mappings (expr.get_mappings ()),
     lvalue_mappings (expr.get_expr ().get_mappings ()),
+    rvalue_mappings (Analysis::NodeMapping::get_error ()),
     locus (expr.get_locus ())
 {}
 
 OperatorExprMeta::OperatorExprMeta (HIR::DereferenceExpr &expr)
   : node_mappings (expr.get_mappings ()),
     lvalue_mappings (expr.get_expr ().get_mappings ()),
+    rvalue_mappings (Analysis::NodeMapping::get_error ()),
     locus (expr.get_locus ())
 {}
 
 OperatorExprMeta::OperatorExprMeta (HIR::ArrayIndexExpr &expr)
   : node_mappings (expr.get_mappings ()),
     lvalue_mappings (expr.get_array_expr ().get_mappings ()),
+    rvalue_mappings (expr.get_index_expr ().get_mappings ()),
     locus (expr.get_locus ())
 {}
 
 OperatorExprMeta::OperatorExprMeta (HIR::ComparisonExpr &expr)
   : node_mappings (expr.get_mappings ()),
     lvalue_mappings (expr.get_expr ().get_mappings ()),
-    locus (expr.get_locus ())
+    rvalue_mappings (expr.get_rhs ().get_mappings ()), locus (expr.get_locus 
())
 {}
 
 InlineAsmOperand::In::In (
diff --git a/gcc/rust/hir/tree/rust-hir-expr.h 
b/gcc/rust/hir/tree/rust-hir-expr.h
index fcb4744fef4c..028455b98702 100644
--- a/gcc/rust/hir/tree/rust-hir-expr.h
+++ b/gcc/rust/hir/tree/rust-hir-expr.h
@@ -27,6 +27,7 @@
 #include "rust-hir-attrs.h"
 #include "rust-expr.h"
 #include "rust-hir-map.h"
+#include "rust-mapping-common.h"
 
 namespace Rust {
 namespace HIR {
@@ -2892,6 +2893,22 @@ public:
 
   OperatorExprMeta (HIR::ComparisonExpr &expr);
 
+  OperatorExprMeta (const OperatorExprMeta &other)
+    : node_mappings (other.node_mappings),
+      lvalue_mappings (other.lvalue_mappings),
+      rvalue_mappings (other.rvalue_mappings), locus (other.locus)
+  {}
+
+  OperatorExprMeta &operator= (const OperatorExprMeta &other)
+  {
+    node_mappings = other.node_mappings;
+    lvalue_mappings = other.lvalue_mappings;
+    rvalue_mappings = other.rvalue_mappings;
+    locus = other.locus;
+
+    return *this;
+  }
+
   const Analysis::NodeMapping &get_mappings () const { return node_mappings; }
 
   const Analysis::NodeMapping &get_lvalue_mappings () const
@@ -2899,11 +2916,22 @@ public:
     return lvalue_mappings;
   }
 
+  const Analysis::NodeMapping &get_rvalue_mappings () const
+  {
+    return rvalue_mappings;
+  }
+
+  bool has_rvalue_mappings () const
+  {
+    return rvalue_mappings.get_hirid () != UNKNOWN_HIRID;
+  }
+
   location_t get_locus () const { return locus; }
 
 private:
-  const Analysis::NodeMapping node_mappings;
-  const Analysis::NodeMapping lvalue_mappings;
+  Analysis::NodeMapping node_mappings;
+  Analysis::NodeMapping lvalue_mappings;
+  Analysis::NodeMapping rvalue_mappings;
   location_t locus;
 };
 
diff --git a/gcc/rust/hir/tree/rust-hir-path.h 
b/gcc/rust/hir/tree/rust-hir-path.h
index 3ce2662c8024..5f88c6827bb1 100644
--- a/gcc/rust/hir/tree/rust-hir-path.h
+++ b/gcc/rust/hir/tree/rust-hir-path.h
@@ -41,11 +41,15 @@ public:
     : segment_name (std::move (segment_name))
   {}
 
-  /* TODO: insert check in constructor for this? Or is this a semantic error
-   * best handled then? */
+  PathIdentSegment (const PathIdentSegment &other)
+    : segment_name (other.segment_name)
+  {}
 
-  /* TODO: does this require visitor? pretty sure this isn't polymorphic, but
-   * not entirely sure */
+  PathIdentSegment &operator= (PathIdentSegment const &other)
+  {
+    segment_name = other.segment_name;
+    return *this;
+  }
 
   // Creates an error PathIdentSegment.
   static PathIdentSegment create_error () { return PathIdentSegment (""); }
diff --git a/gcc/rust/typecheck/rust-hir-trait-reference.cc 
b/gcc/rust/typecheck/rust-hir-trait-reference.cc
index 88e270d510d2..74856f098fa0 100644
--- a/gcc/rust/typecheck/rust-hir-trait-reference.cc
+++ b/gcc/rust/typecheck/rust-hir-trait-reference.cc
@@ -432,7 +432,13 @@ TraitReference::trait_has_generics () const
   return !trait_substs.empty ();
 }
 
-std::vector<TyTy::SubstitutionParamMapping>
+std::vector<TyTy::SubstitutionParamMapping> &
+TraitReference::get_trait_substs ()
+{
+  return trait_substs;
+}
+
+const std::vector<TyTy::SubstitutionParamMapping> &
 TraitReference::get_trait_substs () const
 {
   return trait_substs;
diff --git a/gcc/rust/typecheck/rust-hir-trait-reference.h 
b/gcc/rust/typecheck/rust-hir-trait-reference.h
index 8b1ac7daf7f1..473513ea75ff 100644
--- a/gcc/rust/typecheck/rust-hir-trait-reference.h
+++ b/gcc/rust/typecheck/rust-hir-trait-reference.h
@@ -224,7 +224,9 @@ public:
 
   bool trait_has_generics () const;
 
-  std::vector<TyTy::SubstitutionParamMapping> get_trait_substs () const;
+  std::vector<TyTy::SubstitutionParamMapping> &get_trait_substs ();
+
+  const std::vector<TyTy::SubstitutionParamMapping> &get_trait_substs () const;
 
   bool satisfies_bound (const TraitReference &reference) const;
 
diff --git a/gcc/rust/typecheck/rust-hir-type-check-expr.cc 
b/gcc/rust/typecheck/rust-hir-type-check-expr.cc
index c1404561f4d0..5db0e5690c97 100644
--- a/gcc/rust/typecheck/rust-hir-type-check-expr.cc
+++ b/gcc/rust/typecheck/rust-hir-type-check-expr.cc
@@ -17,6 +17,7 @@
 // <http://www.gnu.org/licenses/>.
 
 #include "optional.h"
+#include "rust-common.h"
 #include "rust-hir-expr.h"
 #include "rust-system.h"
 #include "rust-tyty-call.h"
@@ -59,6 +60,19 @@ TypeCheckExpr::Resolve (HIR::Expr &expr)
   return resolver.infered;
 }
 
+TyTy::BaseType *
+TypeCheckExpr::ResolveOpOverload (LangItem::Kind lang_item_type,
+                                 HIR::OperatorExprMeta expr,
+                                 TyTy::BaseType *lhs, TyTy::BaseType *rhs,
+                                 HIR::PathIdentSegment specified_segment)
+{
+  TypeCheckExpr resolver;
+
+  resolver.resolve_operator_overload (lang_item_type, expr, lhs, rhs,
+                                     specified_segment);
+  return resolver.infered;
+}
+
 void
 TypeCheckExpr::visit (HIR::TupleIndexExpr &expr)
 {
@@ -1885,7 +1899,16 @@ TypeCheckExpr::resolve_operator_overload (
   // probe for the lang-item
   if (!lang_item_defined)
     return false;
+
   DefId &respective_lang_item_id = lang_item_defined.value ();
+  auto def_lookup = mappings.lookup_defid (respective_lang_item_id);
+  rust_assert (def_lookup.has_value ());
+
+  HIR::Item *def_item = def_lookup.value ();
+  rust_assert (def_item->get_item_kind () == HIR::Item::ItemKind::Trait);
+  HIR::Trait &trait = *static_cast<HIR::Trait *> (def_item);
+  TraitReference *defid_trait_reference = TraitResolver::Resolve (trait);
+  rust_assert (!defid_trait_reference->is_error ());
 
   // we might be in a static or const context and unknown is fine
   TypeCheckContextItem current_context = TypeCheckContextItem::get_error ();
@@ -1929,15 +1952,49 @@ TypeCheckExpr::resolve_operator_overload (
 
   if (selected_candidates.size () > 1)
     {
-      // mutliple candidates
-      rich_location r (line_table, expr.get_locus ());
-      for (auto &c : resolved_candidates)
-       r.add_range (c.candidate.locus);
+      auto infer
+       = TyTy::TyVar::get_implicit_infer_var (expr.get_locus ()).get_tyty ();
+      auto trait_subst = defid_trait_reference->get_trait_substs ();
+      rust_assert (trait_subst.size () > 0);
 
-      rust_error_at (
-       r, "multiple candidates found for possible operator overload");
+      TyTy::TypeBoundPredicate pred (respective_lang_item_id, trait_subst,
+                                    BoundPolarity::RegularBound,
+                                    expr.get_locus ());
 
-      return false;
+      std::vector<TyTy::SubstitutionArg> mappings;
+      auto &self_param_mapping = trait_subst[0];
+      mappings.push_back (TyTy::SubstitutionArg (&self_param_mapping, lhs));
+
+      if (rhs != nullptr)
+       {
+         rust_assert (trait_subst.size () == 2); // Self plus the rhs param
+         auto &rhs_param_mapping = trait_subst[1];
+         mappings.push_back (TyTy::SubstitutionArg (&rhs_param_mapping, rhs));
+       }
+
+      std::map<std::string, TyTy::BaseType *> binding_args;
+      binding_args["Output"] = infer;
+
+      TyTy::SubstitutionArgumentMappings arg_mappings (mappings, binding_args,
+                                                      TyTy::RegionParamList (
+                                                        trait_subst.size ()),
+                                                      expr.get_locus ());
+      pred.apply_argument_mappings (arg_mappings, false);
+
+      infer->inherit_bounds ({pred});
+      DeferredOpOverload defer (expr.get_mappings ().get_hirid (),
+                               lang_item_type, specified_segment, pred, expr);
+      context->insert_deferred_operator_overload (std::move (defer));
+
+      if (rhs != nullptr)
+       lhs = unify_site (expr.get_mappings ().get_hirid (),
+                         TyTy::TyWithLocation (lhs),
+                         TyTy::TyWithLocation (rhs), expr.get_locus ());
+
+      infered = unify_site (expr.get_mappings ().get_hirid (),
+                           TyTy::TyWithLocation (lhs),
+                           TyTy::TyWithLocation (infer), expr.get_locus ());
+      return true;
     }
 
   // Get the adjusted self
diff --git a/gcc/rust/typecheck/rust-hir-type-check-expr.h 
b/gcc/rust/typecheck/rust-hir-type-check-expr.h
index 531197436853..48f28c700795 100644
--- a/gcc/rust/typecheck/rust-hir-type-check-expr.h
+++ b/gcc/rust/typecheck/rust-hir-type-check-expr.h
@@ -31,6 +31,11 @@ class TypeCheckExpr : private TypeCheckBase, private 
HIR::HIRExpressionVisitor
 public:
   static TyTy::BaseType *Resolve (HIR::Expr &expr);
 
+  static TyTy::BaseType *
+  ResolveOpOverload (LangItem::Kind lang_item_type, HIR::OperatorExprMeta expr,
+                    TyTy::BaseType *lhs, TyTy::BaseType *rhs,
+                    HIR::PathIdentSegment specified_segment);
+
   void visit (HIR::TupleIndexExpr &expr) override;
   void visit (HIR::TupleExpr &expr) override;
   void visit (HIR::ReturnExpr &expr) override;
diff --git a/gcc/rust/typecheck/rust-hir-type-check.h 
b/gcc/rust/typecheck/rust-hir-type-check.h
index 356c55803ed6..80e403448359 100644
--- a/gcc/rust/typecheck/rust-hir-type-check.h
+++ b/gcc/rust/typecheck/rust-hir-type-check.h
@@ -20,6 +20,7 @@
 #define RUST_HIR_TYPE_CHECK
 
 #include "rust-hir-map.h"
+#include "rust-mapping-common.h"
 #include "rust-tyty.h"
 #include "rust-hir-trait-reference.h"
 #include "rust-stacked-contexts.h"
@@ -157,6 +158,39 @@ public:
   WARN_UNUSED_RESULT Lifetime next () { return Lifetime (interner_index++); }
 };
 
+struct DeferredOpOverload
+{
+  HirId expr_id; // key in TypeCheckContext's deferred overload table
+  LangItem::Kind lang_item_type;
+  HIR::PathIdentSegment specified_segment;
+  TyTy::TypeBoundPredicate predicate; // bound attached to the result infer var
+  HIR::OperatorExprMeta op;
+
+  DeferredOpOverload (HirId expr_id, LangItem::Kind lang_item_type,
+                     HIR::PathIdentSegment specified_segment,
+                     TyTy::TypeBoundPredicate &predicate,
+                     HIR::OperatorExprMeta op)
+    : expr_id (expr_id), lang_item_type (lang_item_type),
+      specified_segment (specified_segment), predicate (predicate), op (op)
+  {}
+
+  DeferredOpOverload (const struct DeferredOpOverload &other)
+    : expr_id (other.expr_id), lang_item_type (other.lang_item_type),
+      specified_segment (other.specified_segment), predicate (other.predicate),
+      op (other.op)
+  {}
+
+  DeferredOpOverload &operator= (struct DeferredOpOverload const &other)
+  {
+    expr_id = other.expr_id;
+    lang_item_type = other.lang_item_type;
+    specified_segment = other.specified_segment;
+    predicate = other.predicate; // copied too, matching the copy ctor
+    op = other.op;
+    return *this;
+  }
+};
+
 class TypeCheckContext
 {
 public:
@@ -237,6 +271,13 @@ public:
   void insert_operator_overload (HirId id, TyTy::FnType *call_site);
   bool lookup_operator_overload (HirId id, TyTy::FnType **call);
 
+  void insert_deferred_operator_overload (DeferredOpOverload deferred);
+  bool lookup_deferred_operator_overload (HirId id,
+                                         DeferredOpOverload *deferred);
+
+  void iterate_deferred_operator_overloads (
+    std::function<bool (HirId, DeferredOpOverload &)> cb);
+
   void insert_unconstrained_check_marker (HirId id, bool status);
   bool have_checked_for_unconstrained (HirId id, bool *result);
 
@@ -271,6 +312,7 @@ private:
   TypeCheckContext ();
 
   bool compute_infer_var (HirId id, TyTy::BaseType *ty, bool emit_error);
+  bool compute_ambigious_op_overload (HirId id, DeferredOpOverload &op);
 
   std::map<NodeId, HirId> node_id_refs;
   std::map<HirId, TyTy::BaseType *> resolved;
@@ -308,6 +350,9 @@ private:
   std::set<HirId> querys_in_progress;
   std::set<DefId> trait_queries_in_progress;
 
+  // deferred operator overload
+  std::map<HirId, DeferredOpOverload> deferred_operator_overloads;
+
   // variance analysis
   TyTy::VarianceAnalysis::CrateCtx variance_analysis_ctx;
 
diff --git a/gcc/rust/typecheck/rust-typecheck-context.cc 
b/gcc/rust/typecheck/rust-typecheck-context.cc
index 7b3584823e44..83b17612d5e3 100644
--- a/gcc/rust/typecheck/rust-typecheck-context.cc
+++ b/gcc/rust/typecheck/rust-typecheck-context.cc
@@ -18,6 +18,7 @@
 
 #include "rust-hir-type-check.h"
 #include "rust-type-util.h"
+#include "rust-hir-type-check-expr.h"
 
 namespace Rust {
 namespace Resolver {
@@ -408,6 +409,38 @@ TypeCheckContext::lookup_operator_overload (HirId id, 
TyTy::FnType **call)
   return true;
 }
 
+void
+TypeCheckContext::insert_deferred_operator_overload (
+  DeferredOpOverload deferred)
+{
+  HirId expr_id = deferred.expr_id;
+  deferred_operator_overloads.emplace (std::make_pair (expr_id, deferred));
+}
+
+bool
+TypeCheckContext::lookup_deferred_operator_overload (
+  HirId id, DeferredOpOverload *deferred)
+{
+  auto it = deferred_operator_overloads.find (id);
+  if (it == deferred_operator_overloads.end ())
+    return false;
+
+  *deferred = it->second;
+  return true;
+}
+
+void
+TypeCheckContext::iterate_deferred_operator_overloads (
+  std::function<bool (HirId, DeferredOpOverload &)> cb)
+{
+  for (auto it = deferred_operator_overloads.begin ();
+       it != deferred_operator_overloads.end (); it++)
+    {
+      if (!cb (it->first, it->second))
+       return;
+    }
+}
+
 void
 TypeCheckContext::insert_unconstrained_check_marker (HirId id, bool status)
 {
@@ -574,10 +607,38 @@ TypeCheckContext::regions_from_generic_args (const 
HIR::GenericArgs &args) const
   return regions;
 }
 
+bool
+TypeCheckContext::compute_ambigious_op_overload (HirId id,
+                                                DeferredOpOverload &op)
+{
+  rust_debug ("attempting resolution of op overload: %s",
+             op.predicate.as_string ().c_str ());
+
+  TyTy::BaseType *lhs = nullptr;
+  bool ok = lookup_type (op.op.get_lvalue_mappings ().get_hirid (), &lhs);
+  rust_assert (ok);
+
+  TyTy::BaseType *rhs = nullptr;
+  if (op.op.has_rvalue_mappings ())
+    {
+      bool ok = lookup_type (op.op.get_rvalue_mappings ().get_hirid (), &rhs);
+      rust_assert (ok);
+    }
+
+  TypeCheckExpr::ResolveOpOverload (op.lang_item_type, op.op, lhs, rhs,
+                                   op.specified_segment);
+
+  return true;
+}
+
 void
 TypeCheckContext::compute_inference_variables (bool emit_error)
 {
-  // default inference variables if possible
+  iterate_deferred_operator_overloads (
+    [&] (HirId id, DeferredOpOverload &op) mutable -> bool {
+      return compute_ambigious_op_overload (id, op);
+    });
+
   iterate ([&] (HirId id, TyTy::BaseType *ty) mutable -> bool {
     return compute_infer_var (id, ty, emit_error);
   });
diff --git a/gcc/testsuite/rust/compile/issue-3916.rs 
b/gcc/testsuite/rust/compile/issue-3916.rs
new file mode 100644
index 000000000000..59b522b4ed5c
--- /dev/null
+++ b/gcc/testsuite/rust/compile/issue-3916.rs
@@ -0,0 +1,36 @@
+#![feature(rustc_attrs)]
+
+#[lang = "sized"]
+trait Sized {}
+
+#[lang = "add"]
+trait Add<Rhs = Self> {
+    type Output;
+
+    fn add(self, rhs: Rhs) -> Self::Output;
+}
+
+macro_rules! add_impl {
+    ($($t:ty)*) => ($(
+        impl Add for $t {
+            type Output = $t;
+
+            #[inline]
+            #[rustc_inherit_overflow_checks]
+            fn add(self, other: $t) -> $t { self + other }
+        }
+    )*)
+}
+
+add_impl! { usize u8 u16 u32 u64 u128 isize i8 i16 i32 i64 i128 f32 f64 }
+
+pub fn test(len: usize) -> u64 {
+    let mut i = 0;
+    let mut out = 0;
+    if i + 3 < len {
+        out = 123;
+    } else {
+        out = 456;
+    }
+    out
+}

Reply via email to