From: Philip Herron <[email protected]>
This handles the awkward case of the Try::from_ok / Try::from_error calls,
where inference needs to be constrained by the return type of the block they are part of.
gcc/rust/ChangeLog:
* typecheck/rust-hir-type-check-expr.cc (TypeCheckExpr::visit): push
expected type
* typecheck/rust-hir-type-check-stmt.cc (TypeCheckStmt::visit): apply
expected ty on tail
* typecheck/rust-hir-type-check.h: helpers
* typecheck/rust-typecheck-context.cc
(TypeCheckContext::push_expected_type): likewise
(TypeCheckContext::pop_expected_type): likewise
(TypeCheckContext::peek_expected_type): likewise
* typecheck/rust-tyty-call.cc (TypeCheckCallExpr::visit): check expected
Signed-off-by: Philip Herron <[email protected]>
---
This change was merged into the gccrs repository and is posted here for
upstream visibility and potential drive-by review, as requested by GCC
release managers.
Each commit email contains a link to its details on github from where you can
find the Pull-Request and associated discussions.
Commit on github:
https://github.com/Rust-GCC/gccrs/commit/a24d5b4ac10b6d327d3cc4ad61bff50431093036
The commit has NOT been mentioned in any issue.
The commit has been mentioned in the following pull-request(s):
- https://github.com/Rust-GCC/gccrs/pull/4594
.../typecheck/rust-hir-type-check-expr.cc | 35 ++++++++++++++++---
.../typecheck/rust-hir-type-check-stmt.cc | 33 ++++++++++++-----
gcc/rust/typecheck/rust-hir-type-check.h | 5 +++
gcc/rust/typecheck/rust-typecheck-context.cc | 21 +++++++++++
gcc/rust/typecheck/rust-tyty-call.cc | 29 ++++++++++++++-
5 files changed, 109 insertions(+), 14 deletions(-)
diff --git a/gcc/rust/typecheck/rust-hir-type-check-expr.cc
b/gcc/rust/typecheck/rust-hir-type-check-expr.cc
index 07ac5c4ea..f3d2c3a27 100644
--- a/gcc/rust/typecheck/rust-hir-type-check-expr.cc
+++ b/gcc/rust/typecheck/rust-hir-type-check-expr.cc
@@ -200,9 +200,23 @@ TypeCheckExpr::visit (HIR::ReturnExpr &expr)
? expr.get_expr ().get_locus ()
: expr.get_locus ();
- TyTy::BaseType *expr_ty = expr.has_return_expr ()
- ? TypeCheckExpr::Resolve (expr.get_expr ())
- : TyTy::TupleType::get_unit_type ();
+ // Push the expected type so the resolver of the return expression can
+ // constrain inference before checking its arguments, which is needed
+ // for things like:
+ //
+ // return Try::from_error(...)
+ //
+ // Where Self has to bind from the fn return type before the param
+ // projection can be normalized.
+ TyTy::BaseType *expr_ty;
+ if (expr.has_return_expr ())
+ {
+ context->push_expected_type (fn_return_tyty);
+ expr_ty = TypeCheckExpr::Resolve (expr.get_expr ());
+ context->pop_expected_type ();
+ }
+ else
+ expr_ty = TyTy::TupleType::get_unit_type ();
coercion_site (expr.get_mappings ().get_hirid (),
TyTy::TyWithLocation (fn_return_tyty),
@@ -624,6 +638,10 @@ TypeCheckExpr::visit (HIR::BlockExpr &expr)
context->push_new_loop_context (expr.get_mappings ().get_hirid (),
expr.get_locus ());
+ // Forward the caller's expected type to the block's tail expression only
+ TyTy::BaseType *outer_expected = context->peek_expected_type ();
+ context->push_expected_type (nullptr);
+
for (auto &s : expr.get_statements ())
{
if (!s->is_item ())
@@ -641,6 +659,7 @@ TypeCheckExpr::visit (HIR::BlockExpr &expr)
if (resolved == nullptr)
{
rust_error_at (s->get_locus (), "failure to resolve type");
+ context->pop_expected_type ();
return;
}
@@ -653,8 +672,14 @@ TypeCheckExpr::visit (HIR::BlockExpr &expr)
}
}
+ context->pop_expected_type ();
+
if (expr.has_expr ())
- infered = TypeCheckExpr::Resolve (expr.get_final_expr ())->clone ();
+ {
+ context->push_expected_type (outer_expected);
+ infered = TypeCheckExpr::Resolve (expr.get_final_expr ())->clone ();
+ context->pop_expected_type ();
+ }
else if (expr.is_tail_reachable ())
infered = TyTy::TupleType::get_unit_type ();
else if (expr.has_label ())
@@ -1869,7 +1894,7 @@ TypeCheckExpr::visit (HIR::ClosureExpr &expr)
TyTy::TyVar result_type
= expr.has_return_type ()
? TyTy::TyVar (
- TypeCheckType::Resolve (expr.get_return_type ())->get_ref ())
+ TypeCheckType::Resolve (expr.get_return_type ())->get_ref ())
: TyTy::TyVar::get_implicit_infer_var (expr.get_locus ());
// resolve the block
diff --git a/gcc/rust/typecheck/rust-hir-type-check-stmt.cc
b/gcc/rust/typecheck/rust-hir-type-check-stmt.cc
index e3d92a53c..5e7d8b4e7 100644
--- a/gcc/rust/typecheck/rust-hir-type-check-stmt.cc
+++ b/gcc/rust/typecheck/rust-hir-type-check-stmt.cc
@@ -82,12 +82,37 @@ TypeCheckStmt::visit (HIR::LetStmt &stmt)
infered = TyTy::TupleType::get_unit_type ();
auto &stmt_pattern = stmt.get_pattern ();
+
+ // Resolve the type annotation before the init expression so the normal
+ // coercion site below can check the init against the declared type.
+ TyTy::BaseType *specified_ty = nullptr;
+ location_t specified_ty_locus = UNKNOWN_LOCATION;
+ if (stmt.has_type ())
+ {
+ specified_ty = TypeCheckType::Resolve (stmt.get_type ());
+ specified_ty_locus = stmt.get_type ().get_locus ();
+ }
+
TyTy::BaseType *init_expr_ty = nullptr;
location_t init_expr_locus = UNKNOWN_LOCATION;
if (stmt.has_init_expr ())
{
init_expr_locus = stmt.get_init_expr ().get_locus ();
+
+ // Try blocks have a block whose tail expression is:
+ //
+ // Try::from_ok(tail)
+ //
+ // We can forward the annotated let type to block initializers so
+ // the block can pass it to its tail expression
+ bool push_expected = specified_ty != nullptr
+ && stmt.get_init_expr ().get_expression_type ()
+ == HIR::Expr::ExprType::Block;
+ if (push_expected)
+ context->push_expected_type (specified_ty);
init_expr_ty = TypeCheckExpr::Resolve (stmt.get_init_expr ());
+ if (push_expected)
+ context->pop_expected_type ();
if (init_expr_ty->get_kind () == TyTy::TypeKind::ERROR)
return;
@@ -95,14 +120,6 @@ TypeCheckStmt::visit (HIR::LetStmt &stmt)
stmt_pattern.get_mappings ().get_hirid ());
}
- TyTy::BaseType *specified_ty = nullptr;
- location_t specified_ty_locus;
- if (stmt.has_type ())
- {
- specified_ty = TypeCheckType::Resolve (stmt.get_type ());
- specified_ty_locus = stmt.get_type ().get_locus ();
- }
-
// let x:i32 = 123;
if (specified_ty != nullptr && init_expr_ty != nullptr)
{
diff --git a/gcc/rust/typecheck/rust-hir-type-check.h
b/gcc/rust/typecheck/rust-hir-type-check.h
index e356f05c0..38dbfa62c 100644
--- a/gcc/rust/typecheck/rust-hir-type-check.h
+++ b/gcc/rust/typecheck/rust-hir-type-check.h
@@ -220,6 +220,10 @@ public:
TyTy::BaseType *return_type);
void pop_return_type ();
+ void push_expected_type (TyTy::BaseType *expected);
+ void pop_expected_type ();
+ TyTy::BaseType *peek_expected_type () const;
+
StackedContexts<TypeCheckBlockContextItem> &block_context ();
void iterate (std::function<bool (HirId, TyTy::BaseType *)> cb);
@@ -319,6 +323,7 @@ private:
std::vector<std::unique_ptr<TyTy::BaseType>> builtins;
std::vector<std::pair<TypeCheckContextItem, TyTy::BaseType *>>
return_type_stack;
+ std::vector<TyTy::BaseType *> expected_type_stack;
std::vector<TyTy::BaseType *> loop_type_stack;
StackedContexts<TypeCheckBlockContextItem> block_stack;
std::map<DefId, TraitReference> trait_context;
diff --git a/gcc/rust/typecheck/rust-typecheck-context.cc
b/gcc/rust/typecheck/rust-typecheck-context.cc
index f2c186fb0..fa7d5efe5 100644
--- a/gcc/rust/typecheck/rust-typecheck-context.cc
+++ b/gcc/rust/typecheck/rust-typecheck-context.cc
@@ -171,6 +171,27 @@ TypeCheckContext::peek_context ()
return return_type_stack.back ().first;
}
+void
+TypeCheckContext::push_expected_type (TyTy::BaseType *expected)
+{
+ expected_type_stack.push_back (expected);
+}
+
+void
+TypeCheckContext::pop_expected_type ()
+{
+ rust_assert (!expected_type_stack.empty ());
+ expected_type_stack.pop_back ();
+}
+
+TyTy::BaseType *
+TypeCheckContext::peek_expected_type () const
+{
+ if (expected_type_stack.empty ())
+ return nullptr;
+ return expected_type_stack.back ();
+}
+
StackedContexts<TypeCheckBlockContextItem> &
TypeCheckContext::block_context ()
{
diff --git a/gcc/rust/typecheck/rust-tyty-call.cc
b/gcc/rust/typecheck/rust-tyty-call.cc
index ad03f034a..0b8d424d6 100644
--- a/gcc/rust/typecheck/rust-tyty-call.cc
+++ b/gcc/rust/typecheck/rust-tyty-call.cc
@@ -18,6 +18,7 @@
#include "rust-tyty-call.h"
#include "rust-hir-type-check-expr.h"
+#include "rust-hir-type-check.h"
#include "rust-type-util.h"
namespace Rust {
@@ -135,11 +136,38 @@ TypeCheckCallExpr::visit (FnType &type)
}
}
+ // if the surrounding context has pushed an expected type, try unifying it
+ // with the fn's return type before checking arguments. This lets the callee
+ // result constrain inference variables that may appear in parameter
+ // projections.
+ auto *ctx = Resolver::TypeCheckContext::get ();
+ TyTy::BaseType *expected = ctx->peek_expected_type ();
+ const TyTy::BaseType *return_infer
+ = type.get_return_type ()->contains_infer ();
+ if (expected != nullptr && return_infer != nullptr)
+ {
+ Resolver::unify_site_and (call.get_mappings ().get_hirid (),
+ TyWithLocation (expected),
+ TyWithLocation (type.get_return_type ()),
+ call.get_locus (), false /*emit_errors*/,
+ true /*commit_if_ok*/,
+ true /*implicit_infer_vars*/, true /*cleanup*/);
+ }
+
size_t i = 0;
for (auto &argument : call.get_arguments ())
{
location_t arg_locus = argument->get_locus ();
+
+ TyTy::BaseType *param_ty = nullptr;
+ if (i < type.num_params ())
+ param_ty = type.param_at (i).get_type ();
+
+ if (param_ty != nullptr)
+ ctx->push_expected_type (param_ty);
auto argument_expr_tyty = Resolver::TypeCheckExpr::Resolve (*argument);
+ if (param_ty != nullptr)
+ ctx->pop_expected_type ();
if (argument_expr_tyty->is<TyTy::ErrorType> ())
return;
@@ -147,7 +175,6 @@ TypeCheckCallExpr::visit (FnType &type)
if (i < type.num_params ())
{
auto &fnparam = type.param_at (i);
- BaseType *param_ty = fnparam.get_type ();
location_t param_locus
= fnparam.has_pattern ()
? fnparam.get_pattern ().get_locus ()
--
2.54.0