From: Philip Herron <[email protected]>

This handles the awkward case of the desugared `try` code (`Try::from_ok` /
`Try::from_error`), where everything needs to be constrained by the return
type of the block it is part of.

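gcc/rust/ChangeLog:

        * typecheck/rust-hir-type-check-expr.cc (TypeCheckExpr::visit): Push
        the expected type for return expressions and block tail expressions.
        * typecheck/rust-hir-type-check-stmt.cc (TypeCheckStmt::visit): Apply
        the annotated let type as the expected type of a block initializer.
        * typecheck/rust-hir-type-check.h: Declare expected-type helpers and
        the expected_type_stack member.
        * typecheck/rust-typecheck-context.cc
        (TypeCheckContext::push_expected_type): New function.
        (TypeCheckContext::pop_expected_type): Likewise.
        (TypeCheckContext::peek_expected_type): Likewise.
        * typecheck/rust-tyty-call.cc (TypeCheckCallExpr::visit): Unify the
        expected type with the callee's return type before checking
        arguments, and push each parameter type as its argument's expected
        type.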

Signed-off-by: Philip Herron <[email protected]>
---
This change was merged into the gccrs repository and is posted here for
upstream visibility and potential drive-by review, as requested by GCC
release managers.
Each commit email contains a link to its details on github from where you can
find the Pull-Request and associated discussions.


Commit on github: 
https://github.com/Rust-GCC/gccrs/commit/a24d5b4ac10b6d327d3cc4ad61bff50431093036

The commit has NOT been mentioned in any issue.

The commit has been mentioned in the following pull-request(s):
 - https://github.com/Rust-GCC/gccrs/pull/4594

 .../typecheck/rust-hir-type-check-expr.cc     | 35 ++++++++++++++++---
 .../typecheck/rust-hir-type-check-stmt.cc     | 33 ++++++++++++-----
 gcc/rust/typecheck/rust-hir-type-check.h      |  5 +++
 gcc/rust/typecheck/rust-typecheck-context.cc  | 21 +++++++++++
 gcc/rust/typecheck/rust-tyty-call.cc          | 29 ++++++++++++++-
 5 files changed, 109 insertions(+), 14 deletions(-)

diff --git a/gcc/rust/typecheck/rust-hir-type-check-expr.cc b/gcc/rust/typecheck/rust-hir-type-check-expr.cc
index 07ac5c4ea..f3d2c3a27 100644
--- a/gcc/rust/typecheck/rust-hir-type-check-expr.cc
+++ b/gcc/rust/typecheck/rust-hir-type-check-expr.cc
@@ -200,9 +200,23 @@ TypeCheckExpr::visit (HIR::ReturnExpr &expr)
                            ? expr.get_expr ().get_locus ()
                            : expr.get_locus ();
 
-  TyTy::BaseType *expr_ty = expr.has_return_expr ()
-                             ? TypeCheckExpr::Resolve (expr.get_expr ())
-                             : TyTy::TupleType::get_unit_type ();
+  // Push the fn return type as the expected type so that resolving the
+  // return expression can drive inference before its arguments are
+  // checked. This is needed for things like:
+  //
+  //    return Try::from_error(...)
+  //
+  // where Self has to bind from the fn return type before the param
+  // projection can be normalized.
+  TyTy::BaseType *expr_ty;
+  if (expr.has_return_expr ())
+    {
+      context->push_expected_type (fn_return_tyty);
+      expr_ty = TypeCheckExpr::Resolve (expr.get_expr ());
+      context->pop_expected_type ();
+    }
+  else
+    expr_ty = TyTy::TupleType::get_unit_type ();
 
   coercion_site (expr.get_mappings ().get_hirid (),
                 TyTy::TyWithLocation (fn_return_tyty),
@@ -624,6 +638,10 @@ TypeCheckExpr::visit (HIR::BlockExpr &expr)
     context->push_new_loop_context (expr.get_mappings ().get_hirid (),
                                    expr.get_locus ());
 
+  // Only the tail expression sees the caller's expected type.
+  TyTy::BaseType *outer_expected = context->peek_expected_type ();
+  context->push_expected_type (nullptr);
+
   for (auto &s : expr.get_statements ())
     {
       if (!s->is_item ())
@@ -641,6 +659,7 @@ TypeCheckExpr::visit (HIR::BlockExpr &expr)
       if (resolved == nullptr)
        {
          rust_error_at (s->get_locus (), "failure to resolve type");
+         context->pop_expected_type ();
          return;
        }
 
@@ -653,8 +672,14 @@ TypeCheckExpr::visit (HIR::BlockExpr &expr)
        }
     }
 
+  context->pop_expected_type ();
+
   if (expr.has_expr ())
-    infered = TypeCheckExpr::Resolve (expr.get_final_expr ())->clone ();
+    {
+      context->push_expected_type (outer_expected);
+      infered = TypeCheckExpr::Resolve (expr.get_final_expr ())->clone ();
+      context->pop_expected_type ();
+    }
   else if (expr.is_tail_reachable ())
     infered = TyTy::TupleType::get_unit_type ();
   else if (expr.has_label ())
@@ -1869,7 +1894,7 @@ TypeCheckExpr::visit (HIR::ClosureExpr &expr)
   TyTy::TyVar result_type
     = expr.has_return_type ()
        ? TyTy::TyVar (
-         TypeCheckType::Resolve (expr.get_return_type ())->get_ref ())
+           TypeCheckType::Resolve (expr.get_return_type ())->get_ref ())
        : TyTy::TyVar::get_implicit_infer_var (expr.get_locus ());
 
   // resolve the block
diff --git a/gcc/rust/typecheck/rust-hir-type-check-stmt.cc b/gcc/rust/typecheck/rust-hir-type-check-stmt.cc
index e3d92a53c..5e7d8b4e7 100644
--- a/gcc/rust/typecheck/rust-hir-type-check-stmt.cc
+++ b/gcc/rust/typecheck/rust-hir-type-check-stmt.cc
@@ -82,12 +82,37 @@ TypeCheckStmt::visit (HIR::LetStmt &stmt)
   infered = TyTy::TupleType::get_unit_type ();
 
   auto &stmt_pattern = stmt.get_pattern ();
+
+  // Resolve the type annotation before the init expression so the normal
+  // coercion site below can check the init against the declared type.
+  TyTy::BaseType *specified_ty = nullptr;
+  location_t specified_ty_locus = UNKNOWN_LOCATION;
+  if (stmt.has_type ())
+    {
+      specified_ty = TypeCheckType::Resolve (stmt.get_type ());
+      specified_ty_locus = stmt.get_type ().get_locus ();
+    }
+
   TyTy::BaseType *init_expr_ty = nullptr;
   location_t init_expr_locus = UNKNOWN_LOCATION;
   if (stmt.has_init_expr ())
     {
       init_expr_locus = stmt.get_init_expr ().get_locus ();
+
+      // Try blocks have a block whose tail expression is:
+      //
+      //     Try::from_ok(tail)
+      //
+      // Forward the annotated let type to block initializers so that
+      // the block can pass it on to its tail expression.
+      bool push_expected = specified_ty != nullptr
+                          && stmt.get_init_expr ().get_expression_type ()
+                               == HIR::Expr::ExprType::Block;
+      if (push_expected)
+       context->push_expected_type (specified_ty);
       init_expr_ty = TypeCheckExpr::Resolve (stmt.get_init_expr ());
+      if (push_expected)
+       context->pop_expected_type ();
       if (init_expr_ty->get_kind () == TyTy::TypeKind::ERROR)
        return;
 
@@ -95,14 +120,6 @@ TypeCheckStmt::visit (HIR::LetStmt &stmt)
        stmt_pattern.get_mappings ().get_hirid ());
     }
 
-  TyTy::BaseType *specified_ty = nullptr;
-  location_t specified_ty_locus;
-  if (stmt.has_type ())
-    {
-      specified_ty = TypeCheckType::Resolve (stmt.get_type ());
-      specified_ty_locus = stmt.get_type ().get_locus ();
-    }
-
   // let x:i32 = 123;
   if (specified_ty != nullptr && init_expr_ty != nullptr)
     {
diff --git a/gcc/rust/typecheck/rust-hir-type-check.h b/gcc/rust/typecheck/rust-hir-type-check.h
index e356f05c0..38dbfa62c 100644
--- a/gcc/rust/typecheck/rust-hir-type-check.h
+++ b/gcc/rust/typecheck/rust-hir-type-check.h
@@ -220,6 +220,10 @@ public:
                         TyTy::BaseType *return_type);
   void pop_return_type ();
 
+  void push_expected_type (TyTy::BaseType *expected);
+  void pop_expected_type ();
+  TyTy::BaseType *peek_expected_type () const;
+
   StackedContexts<TypeCheckBlockContextItem> &block_context ();
 
   void iterate (std::function<bool (HirId, TyTy::BaseType *)> cb);
@@ -319,6 +323,7 @@ private:
   std::vector<std::unique_ptr<TyTy::BaseType>> builtins;
   std::vector<std::pair<TypeCheckContextItem, TyTy::BaseType *>>
     return_type_stack;
+  std::vector<TyTy::BaseType *> expected_type_stack;
   std::vector<TyTy::BaseType *> loop_type_stack;
   StackedContexts<TypeCheckBlockContextItem> block_stack;
   std::map<DefId, TraitReference> trait_context;
diff --git a/gcc/rust/typecheck/rust-typecheck-context.cc b/gcc/rust/typecheck/rust-typecheck-context.cc
index f2c186fb0..fa7d5efe5 100644
--- a/gcc/rust/typecheck/rust-typecheck-context.cc
+++ b/gcc/rust/typecheck/rust-typecheck-context.cc
@@ -171,6 +171,27 @@ TypeCheckContext::peek_context ()
   return return_type_stack.back ().first;
 }
 
+void
+TypeCheckContext::push_expected_type (TyTy::BaseType *expected)
+{
+  expected_type_stack.push_back (expected);
+}
+
+void
+TypeCheckContext::pop_expected_type ()
+{
+  rust_assert (!expected_type_stack.empty ());
+  expected_type_stack.pop_back ();
+}
+
+TyTy::BaseType *
+TypeCheckContext::peek_expected_type () const
+{
+  if (expected_type_stack.empty ())
+    return nullptr;
+  return expected_type_stack.back ();
+}
+
 StackedContexts<TypeCheckBlockContextItem> &
 TypeCheckContext::block_context ()
 {
diff --git a/gcc/rust/typecheck/rust-tyty-call.cc b/gcc/rust/typecheck/rust-tyty-call.cc
index ad03f034a..0b8d424d6 100644
--- a/gcc/rust/typecheck/rust-tyty-call.cc
+++ b/gcc/rust/typecheck/rust-tyty-call.cc
@@ -18,6 +18,7 @@
 
 #include "rust-tyty-call.h"
 #include "rust-hir-type-check-expr.h"
+#include "rust-hir-type-check.h"
 #include "rust-type-util.h"
 
 namespace Rust {
@@ -135,11 +136,38 @@ TypeCheckCallExpr::visit (FnType &type)
        }
     }
 
+  // If the surrounding context has pushed an expected type, try unifying it
+  // with the fn's return type before checking arguments. This lets the callee
+  // result constrain inference variables that may appear in parameter
+  // projections.
+  auto *ctx = Resolver::TypeCheckContext::get ();
+  TyTy::BaseType *expected = ctx->peek_expected_type ();
+  const TyTy::BaseType *return_infer
+    = type.get_return_type ()->contains_infer ();
+  if (expected != nullptr && return_infer != nullptr)
+    {
+      Resolver::unify_site_and (call.get_mappings ().get_hirid (),
+                               TyWithLocation (expected),
+                               TyWithLocation (type.get_return_type ()),
+                               call.get_locus (), false /*emit_errors*/,
+                               true /*commit_if_ok*/,
+                               true /*implicit_infer_vars*/, true /*cleanup*/);
+    }
+
   size_t i = 0;
   for (auto &argument : call.get_arguments ())
     {
       location_t arg_locus = argument->get_locus ();
+
+      TyTy::BaseType *param_ty = nullptr;
+      if (i < type.num_params ())
+       param_ty = type.param_at (i).get_type ();
+
+      if (param_ty != nullptr)
+       ctx->push_expected_type (param_ty);
       auto argument_expr_tyty = Resolver::TypeCheckExpr::Resolve (*argument);
+      if (param_ty != nullptr)
+       ctx->pop_expected_type ();
       if (argument_expr_tyty->is<TyTy::ErrorType> ())
        return;
 
@@ -147,7 +175,6 @@ TypeCheckCallExpr::visit (FnType &type)
       if (i < type.num_params ())
        {
          auto &fnparam = type.param_at (i);
-         BaseType *param_ty = fnparam.get_type ();
          location_t param_locus
            = fnparam.has_pattern ()
                ? fnparam.get_pattern ().get_locus ()
-- 
2.54.0

Reply via email to