https://gcc.gnu.org/g:9076a8f688886de933d5f953855a3431d3e0a922
commit r16-2943-g9076a8f688886de933d5f953855a3431d3e0a922 Author: Philip Herron <herron.phi...@googlemail.com> Date: Fri Jul 18 15:46:59 2025 +0100 gccrs: Add initial support for deferred operator overload resolution In the test case: fn test (len: usize) -> u64 { let mut i = 0; let mut out = 0; if i + 3 < len { out = 123; } out } The issue is to determine the correct type of 'i'; 'out' is simple because it hits a coercion site in the return position for u64. But in 'i + 3', 'i' is an integer inference variable, and so is the literal '3'. So when it comes to resolving the type for the Add expression, we hit the operator overload resolution code, and because of this: macro_rules! add_impl { ($($t:ty)*) => ($( impl Add for $t { type Output = $t; #[inline] #[rustc_inherit_overflow_checks] fn add(self, other: $t) -> $t { self + other } } )*) } add_impl! { usize u8 u16 u32 u64 u128 isize i8 i16 i32 i64 i128 f32 f64 } This means the resolution for 'i + 3' is ambiguous because it could be any of these Add implementations. But because we unify against the '< len', where len is defined as usize, later in the resolution we determine that 'i' is actually a usize. This means that if we defer the resolution of this operator overload in the ambiguous case, we can simply resolve it at the end. 
Fixes Rust-GCC#3916 gcc/rust/ChangeLog: * hir/tree/rust-hir-expr.cc (OperatorExprMeta::OperatorExprMeta): track the rhs * hir/tree/rust-hir-expr.h: likewise * hir/tree/rust-hir-path.h: get rid of old comments * typecheck/rust-hir-trait-reference.cc (TraitReference::get_trait_substs): return references instead of copy * typecheck/rust-hir-trait-reference.h: update header * typecheck/rust-hir-type-check-expr.cc (TypeCheckExpr::ResolveOpOverload): write ambiguous operator overloads to a table and try to resolve them at the end * typecheck/rust-hir-type-check-expr.h: new static helper * typecheck/rust-hir-type-check.h (struct DeferredOpOverload): new model to defer resolution * typecheck/rust-typecheck-context.cc (TypeCheckContext::lookup_operator_overload): new (TypeCheckContext::compute_ambigious_op_overload): likewise (TypeCheckContext::compute_inference_variables): likewise gcc/testsuite/ChangeLog: * rust/compile/issue-3916.rs: New test. Signed-off-by: Philip Herron <herron.phi...@googlemail.com> Diff: --- gcc/rust/hir/tree/rust-hir-expr.cc | 10 ++-- gcc/rust/hir/tree/rust-hir-expr.h | 32 +++++++++++- gcc/rust/hir/tree/rust-hir-path.h | 12 +++-- gcc/rust/typecheck/rust-hir-trait-reference.cc | 8 ++- gcc/rust/typecheck/rust-hir-trait-reference.h | 4 +- gcc/rust/typecheck/rust-hir-type-check-expr.cc | 71 +++++++++++++++++++++++--- gcc/rust/typecheck/rust-hir-type-check-expr.h | 5 ++ gcc/rust/typecheck/rust-hir-type-check.h | 45 ++++++++++++++++ gcc/rust/typecheck/rust-typecheck-context.cc | 63 ++++++++++++++++++++++- gcc/testsuite/rust/compile/issue-3916.rs | 36 +++++++++++++ 10 files changed, 267 insertions(+), 19 deletions(-) diff --git a/gcc/rust/hir/tree/rust-hir-expr.cc b/gcc/rust/hir/tree/rust-hir-expr.cc index 93dcec2c8d79..038bfc77f94a 100644 --- a/gcc/rust/hir/tree/rust-hir-expr.cc +++ b/gcc/rust/hir/tree/rust-hir-expr.cc @@ -17,6 +17,7 @@ // <http://www.gnu.org/licenses/>. 
#include "rust-hir-expr.h" +#include "rust-hir-map.h" #include "rust-operators.h" #include "rust-hir-stmt.h" @@ -1321,37 +1322,40 @@ AsyncBlockExpr::operator= (AsyncBlockExpr const &other) OperatorExprMeta::OperatorExprMeta (HIR::CompoundAssignmentExpr &expr) : node_mappings (expr.get_mappings ()), lvalue_mappings (expr.get_expr ().get_mappings ()), - locus (expr.get_locus ()) + rvalue_mappings (expr.get_rhs ().get_mappings ()), locus (expr.get_locus ()) {} OperatorExprMeta::OperatorExprMeta (HIR::ArithmeticOrLogicalExpr &expr) : node_mappings (expr.get_mappings ()), lvalue_mappings (expr.get_expr ().get_mappings ()), - locus (expr.get_locus ()) + rvalue_mappings (expr.get_rhs ().get_mappings ()), locus (expr.get_locus ()) {} OperatorExprMeta::OperatorExprMeta (HIR::NegationExpr &expr) : node_mappings (expr.get_mappings ()), lvalue_mappings (expr.get_expr ().get_mappings ()), + rvalue_mappings (Analysis::NodeMapping::get_error ()), locus (expr.get_locus ()) {} OperatorExprMeta::OperatorExprMeta (HIR::DereferenceExpr &expr) : node_mappings (expr.get_mappings ()), lvalue_mappings (expr.get_expr ().get_mappings ()), + rvalue_mappings (Analysis::NodeMapping::get_error ()), locus (expr.get_locus ()) {} OperatorExprMeta::OperatorExprMeta (HIR::ArrayIndexExpr &expr) : node_mappings (expr.get_mappings ()), lvalue_mappings (expr.get_array_expr ().get_mappings ()), + rvalue_mappings (expr.get_index_expr ().get_mappings ()), locus (expr.get_locus ()) {} OperatorExprMeta::OperatorExprMeta (HIR::ComparisonExpr &expr) : node_mappings (expr.get_mappings ()), lvalue_mappings (expr.get_expr ().get_mappings ()), - locus (expr.get_locus ()) + rvalue_mappings (expr.get_rhs ().get_mappings ()), locus (expr.get_locus ()) {} InlineAsmOperand::In::In ( diff --git a/gcc/rust/hir/tree/rust-hir-expr.h b/gcc/rust/hir/tree/rust-hir-expr.h index fcb4744fef4c..028455b98702 100644 --- a/gcc/rust/hir/tree/rust-hir-expr.h +++ b/gcc/rust/hir/tree/rust-hir-expr.h @@ -27,6 +27,7 @@ #include 
"rust-hir-attrs.h" #include "rust-expr.h" #include "rust-hir-map.h" +#include "rust-mapping-common.h" namespace Rust { namespace HIR { @@ -2892,6 +2893,22 @@ public: OperatorExprMeta (HIR::ComparisonExpr &expr); + OperatorExprMeta (const OperatorExprMeta &other) + : node_mappings (other.node_mappings), + lvalue_mappings (other.lvalue_mappings), + rvalue_mappings (other.rvalue_mappings), locus (other.locus) + {} + + OperatorExprMeta &operator= (const OperatorExprMeta &other) + { + node_mappings = other.node_mappings; + lvalue_mappings = other.lvalue_mappings; + rvalue_mappings = other.rvalue_mappings; + locus = other.locus; + + return *this; + } + const Analysis::NodeMapping &get_mappings () const { return node_mappings; } const Analysis::NodeMapping &get_lvalue_mappings () const @@ -2899,11 +2916,22 @@ public: return lvalue_mappings; } + const Analysis::NodeMapping &get_rvalue_mappings () const + { + return rvalue_mappings; + } + + bool has_rvalue_mappings () const + { + return rvalue_mappings.get_hirid () != UNKNOWN_HIRID; + } + location_t get_locus () const { return locus; } private: - const Analysis::NodeMapping node_mappings; - const Analysis::NodeMapping lvalue_mappings; + Analysis::NodeMapping node_mappings; + Analysis::NodeMapping lvalue_mappings; + Analysis::NodeMapping rvalue_mappings; location_t locus; }; diff --git a/gcc/rust/hir/tree/rust-hir-path.h b/gcc/rust/hir/tree/rust-hir-path.h index 3ce2662c8024..5f88c6827bb1 100644 --- a/gcc/rust/hir/tree/rust-hir-path.h +++ b/gcc/rust/hir/tree/rust-hir-path.h @@ -41,11 +41,15 @@ public: : segment_name (std::move (segment_name)) {} - /* TODO: insert check in constructor for this? Or is this a semantic error - * best handled then? */ + PathIdentSegment (const PathIdentSegment &other) + : segment_name (other.segment_name) + {} - /* TODO: does this require visitor? 
pretty sure this isn't polymorphic, but - * not entirely sure */ + PathIdentSegment &operator= (PathIdentSegment const &other) + { + segment_name = other.segment_name; + return *this; + } // Creates an error PathIdentSegment. static PathIdentSegment create_error () { return PathIdentSegment (""); } diff --git a/gcc/rust/typecheck/rust-hir-trait-reference.cc b/gcc/rust/typecheck/rust-hir-trait-reference.cc index 88e270d510d2..74856f098fa0 100644 --- a/gcc/rust/typecheck/rust-hir-trait-reference.cc +++ b/gcc/rust/typecheck/rust-hir-trait-reference.cc @@ -432,7 +432,13 @@ TraitReference::trait_has_generics () const return !trait_substs.empty (); } -std::vector<TyTy::SubstitutionParamMapping> +std::vector<TyTy::SubstitutionParamMapping> & +TraitReference::get_trait_substs () +{ + return trait_substs; +} + +const std::vector<TyTy::SubstitutionParamMapping> & TraitReference::get_trait_substs () const { return trait_substs; diff --git a/gcc/rust/typecheck/rust-hir-trait-reference.h b/gcc/rust/typecheck/rust-hir-trait-reference.h index 8b1ac7daf7f1..473513ea75ff 100644 --- a/gcc/rust/typecheck/rust-hir-trait-reference.h +++ b/gcc/rust/typecheck/rust-hir-trait-reference.h @@ -224,7 +224,9 @@ public: bool trait_has_generics () const; - std::vector<TyTy::SubstitutionParamMapping> get_trait_substs () const; + std::vector<TyTy::SubstitutionParamMapping> &get_trait_substs (); + + const std::vector<TyTy::SubstitutionParamMapping> &get_trait_substs () const; bool satisfies_bound (const TraitReference &reference) const; diff --git a/gcc/rust/typecheck/rust-hir-type-check-expr.cc b/gcc/rust/typecheck/rust-hir-type-check-expr.cc index c1404561f4d0..5db0e5690c97 100644 --- a/gcc/rust/typecheck/rust-hir-type-check-expr.cc +++ b/gcc/rust/typecheck/rust-hir-type-check-expr.cc @@ -17,6 +17,7 @@ // <http://www.gnu.org/licenses/>. 
#include "optional.h" +#include "rust-common.h" #include "rust-hir-expr.h" #include "rust-system.h" #include "rust-tyty-call.h" @@ -59,6 +60,19 @@ TypeCheckExpr::Resolve (HIR::Expr &expr) return resolver.infered; } +TyTy::BaseType * +TypeCheckExpr::ResolveOpOverload (LangItem::Kind lang_item_type, + HIR::OperatorExprMeta expr, + TyTy::BaseType *lhs, TyTy::BaseType *rhs, + HIR::PathIdentSegment specified_segment) +{ + TypeCheckExpr resolver; + + resolver.resolve_operator_overload (lang_item_type, expr, lhs, rhs, + specified_segment); + return resolver.infered; +} + void TypeCheckExpr::visit (HIR::TupleIndexExpr &expr) { @@ -1885,7 +1899,16 @@ TypeCheckExpr::resolve_operator_overload ( // probe for the lang-item if (!lang_item_defined) return false; + DefId &respective_lang_item_id = lang_item_defined.value (); + auto def_lookup = mappings.lookup_defid (respective_lang_item_id); + rust_assert (def_lookup.has_value ()); + + HIR::Item *def_item = def_lookup.value (); + rust_assert (def_item->get_item_kind () == HIR::Item::ItemKind::Trait); + HIR::Trait &trait = *static_cast<HIR::Trait *> (def_item); + TraitReference *defid_trait_reference = TraitResolver::Resolve (trait); + rust_assert (!defid_trait_reference->is_error ()); // we might be in a static or const context and unknown is fine TypeCheckContextItem current_context = TypeCheckContextItem::get_error (); @@ -1929,15 +1952,49 @@ TypeCheckExpr::resolve_operator_overload ( if (selected_candidates.size () > 1) { - // mutliple candidates - rich_location r (line_table, expr.get_locus ()); - for (auto &c : resolved_candidates) - r.add_range (c.candidate.locus); + auto infer + = TyTy::TyVar::get_implicit_infer_var (expr.get_locus ()).get_tyty (); + auto trait_subst = defid_trait_reference->get_trait_substs (); + rust_assert (trait_subst.size () > 0); - rust_error_at ( - r, "multiple candidates found for possible operator overload"); + TyTy::TypeBoundPredicate pred (respective_lang_item_id, trait_subst, + 
BoundPolarity::RegularBound, + expr.get_locus ()); - return false; + std::vector<TyTy::SubstitutionArg> mappings; + auto &self_param_mapping = trait_subst[0]; + mappings.push_back (TyTy::SubstitutionArg (&self_param_mapping, lhs)); + + if (rhs != nullptr) + { + rust_assert (trait_subst.size () == 2); + auto &rhs_param_mapping = trait_subst[1]; + mappings.push_back (TyTy::SubstitutionArg (&rhs_param_mapping, lhs)); + } + + std::map<std::string, TyTy::BaseType *> binding_args; + binding_args["Output"] = infer; + + TyTy::SubstitutionArgumentMappings arg_mappings (mappings, binding_args, + TyTy::RegionParamList ( + trait_subst.size ()), + expr.get_locus ()); + pred.apply_argument_mappings (arg_mappings, false); + + infer->inherit_bounds ({pred}); + DeferredOpOverload defer (expr.get_mappings ().get_hirid (), + lang_item_type, specified_segment, pred, expr); + context->insert_deferred_operator_overload (std::move (defer)); + + if (rhs != nullptr) + lhs = unify_site (expr.get_mappings ().get_hirid (), + TyTy::TyWithLocation (lhs), + TyTy::TyWithLocation (rhs), expr.get_locus ()); + + infered = unify_site (expr.get_mappings ().get_hirid (), + TyTy::TyWithLocation (lhs), + TyTy::TyWithLocation (infer), expr.get_locus ()); + return true; } // Get the adjusted self diff --git a/gcc/rust/typecheck/rust-hir-type-check-expr.h b/gcc/rust/typecheck/rust-hir-type-check-expr.h index 531197436853..48f28c700795 100644 --- a/gcc/rust/typecheck/rust-hir-type-check-expr.h +++ b/gcc/rust/typecheck/rust-hir-type-check-expr.h @@ -31,6 +31,11 @@ class TypeCheckExpr : private TypeCheckBase, private HIR::HIRExpressionVisitor public: static TyTy::BaseType *Resolve (HIR::Expr &expr); + static TyTy::BaseType * + ResolveOpOverload (LangItem::Kind lang_item_type, HIR::OperatorExprMeta expr, + TyTy::BaseType *lhs, TyTy::BaseType *rhs, + HIR::PathIdentSegment specified_segment); + void visit (HIR::TupleIndexExpr &expr) override; void visit (HIR::TupleExpr &expr) override; void visit (HIR::ReturnExpr 
&expr) override; diff --git a/gcc/rust/typecheck/rust-hir-type-check.h b/gcc/rust/typecheck/rust-hir-type-check.h index 356c55803ed6..80e403448359 100644 --- a/gcc/rust/typecheck/rust-hir-type-check.h +++ b/gcc/rust/typecheck/rust-hir-type-check.h @@ -20,6 +20,7 @@ #define RUST_HIR_TYPE_CHECK #include "rust-hir-map.h" +#include "rust-mapping-common.h" #include "rust-tyty.h" #include "rust-hir-trait-reference.h" #include "rust-stacked-contexts.h" @@ -157,6 +158,39 @@ public: WARN_UNUSED_RESULT Lifetime next () { return Lifetime (interner_index++); } }; +struct DeferredOpOverload +{ + HirId expr_id; + LangItem::Kind lang_item_type; + HIR::PathIdentSegment specified_segment; + TyTy::TypeBoundPredicate predicate; + HIR::OperatorExprMeta op; + + DeferredOpOverload (HirId expr_id, LangItem::Kind lang_item_type, + HIR::PathIdentSegment specified_segment, + TyTy::TypeBoundPredicate &predicate, + HIR::OperatorExprMeta op) + : expr_id (expr_id), lang_item_type (lang_item_type), + specified_segment (specified_segment), predicate (predicate), op (op) + {} + + DeferredOpOverload (const struct DeferredOpOverload &other) + : expr_id (other.expr_id), lang_item_type (other.lang_item_type), + specified_segment (other.specified_segment), predicate (other.predicate), + op (other.op) + {} + + DeferredOpOverload &operator= (struct DeferredOpOverload const &other) + { + expr_id = other.expr_id; + lang_item_type = other.lang_item_type; + specified_segment = other.specified_segment; + op = other.op; + + return *this; + } +}; + class TypeCheckContext { public: @@ -237,6 +271,13 @@ public: void insert_operator_overload (HirId id, TyTy::FnType *call_site); bool lookup_operator_overload (HirId id, TyTy::FnType **call); + void insert_deferred_operator_overload (DeferredOpOverload deferred); + bool lookup_deferred_operator_overload (HirId id, + DeferredOpOverload *deferred); + + void iterate_deferred_operator_overloads ( + std::function<bool (HirId, DeferredOpOverload &)> cb); + void 
insert_unconstrained_check_marker (HirId id, bool status); bool have_checked_for_unconstrained (HirId id, bool *result); @@ -271,6 +312,7 @@ private: TypeCheckContext (); bool compute_infer_var (HirId id, TyTy::BaseType *ty, bool emit_error); + bool compute_ambigious_op_overload (HirId id, DeferredOpOverload &op); std::map<NodeId, HirId> node_id_refs; std::map<HirId, TyTy::BaseType *> resolved; @@ -308,6 +350,9 @@ private: std::set<HirId> querys_in_progress; std::set<DefId> trait_queries_in_progress; + // deferred operator overload + std::map<HirId, DeferredOpOverload> deferred_operator_overloads; + // variance analysis TyTy::VarianceAnalysis::CrateCtx variance_analysis_ctx; diff --git a/gcc/rust/typecheck/rust-typecheck-context.cc b/gcc/rust/typecheck/rust-typecheck-context.cc index 7b3584823e44..83b17612d5e3 100644 --- a/gcc/rust/typecheck/rust-typecheck-context.cc +++ b/gcc/rust/typecheck/rust-typecheck-context.cc @@ -18,6 +18,7 @@ #include "rust-hir-type-check.h" #include "rust-type-util.h" +#include "rust-hir-type-check-expr.h" namespace Rust { namespace Resolver { @@ -408,6 +409,38 @@ TypeCheckContext::lookup_operator_overload (HirId id, TyTy::FnType **call) return true; } +void +TypeCheckContext::insert_deferred_operator_overload ( + DeferredOpOverload deferred) +{ + HirId expr_id = deferred.expr_id; + deferred_operator_overloads.emplace (std::make_pair (expr_id, deferred)); +} + +bool +TypeCheckContext::lookup_deferred_operator_overload ( + HirId id, DeferredOpOverload *deferred) +{ + auto it = deferred_operator_overloads.find (id); + if (it == deferred_operator_overloads.end ()) + return false; + + *deferred = it->second; + return true; +} + +void +TypeCheckContext::iterate_deferred_operator_overloads ( + std::function<bool (HirId, DeferredOpOverload &)> cb) +{ + for (auto it = deferred_operator_overloads.begin (); + it != deferred_operator_overloads.end (); it++) + { + if (!cb (it->first, it->second)) + return; + } +} + void 
TypeCheckContext::insert_unconstrained_check_marker (HirId id, bool status) { @@ -574,10 +607,38 @@ TypeCheckContext::regions_from_generic_args (const HIR::GenericArgs &args) const return regions; } +bool +TypeCheckContext::compute_ambigious_op_overload (HirId id, + DeferredOpOverload &op) +{ + rust_debug ("attempting resolution of op overload: %s", + op.predicate.as_string ().c_str ()); + + TyTy::BaseType *lhs = nullptr; + bool ok = lookup_type (op.op.get_lvalue_mappings ().get_hirid (), &lhs); + rust_assert (ok); + + TyTy::BaseType *rhs = nullptr; + if (op.op.has_rvalue_mappings ()) + { + bool ok = lookup_type (op.op.get_rvalue_mappings ().get_hirid (), &rhs); + rust_assert (ok); + } + + TypeCheckExpr::ResolveOpOverload (op.lang_item_type, op.op, lhs, rhs, + op.specified_segment); + + return true; +} + void TypeCheckContext::compute_inference_variables (bool emit_error) { - // default inference variables if possible + iterate_deferred_operator_overloads ( + [&] (HirId id, DeferredOpOverload &op) mutable -> bool { + return compute_ambigious_op_overload (id, op); + }); + iterate ([&] (HirId id, TyTy::BaseType *ty) mutable -> bool { return compute_infer_var (id, ty, emit_error); }); diff --git a/gcc/testsuite/rust/compile/issue-3916.rs b/gcc/testsuite/rust/compile/issue-3916.rs new file mode 100644 index 000000000000..59b522b4ed5c --- /dev/null +++ b/gcc/testsuite/rust/compile/issue-3916.rs @@ -0,0 +1,36 @@ +#![feature(rustc_attrs)] + +#[lang = "sized"] +trait Sized {} + +#[lang = "add"] +trait Add<Rhs = Self> { + type Output; + + fn add(self, rhs: Rhs) -> Self::Output; +} + +macro_rules! add_impl { + ($($t:ty)*) => ($( + impl Add for $t { + type Output = $t; + + #[inline] + #[rustc_inherit_overflow_checks] + fn add(self, other: $t) -> $t { self + other } + } + )*) +} + +add_impl! 
{ usize u8 u16 u32 u64 u128 isize i8 i16 i32 i64 i128 f32 f64 } + +pub fn test(len: usize) -> u64 { + let mut i = 0; + let mut out = 0; + if i + 3 < len { + out = 123; + } else { + out = 456; + } + out +}