Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions include/swift/AST/ASTBridging.h
Original file line number Diff line number Diff line change
Expand Up @@ -2404,13 +2404,14 @@ BridgedFallthroughStmt_createParsed(swift::SourceLoc loc,
BridgedDeclContext cDC);

SWIFT_NAME("BridgedForEachStmt.createParsed(_:labelInfo:forLoc:tryLoc:awaitLoc:"
"unsafeLoc:pattern:inLoc:sequence:whereLoc:whereExpr:body:)")
"unsafeLoc:pattern:inLoc:sequence:whereLoc:whereExpr:body:declContext:)")
BridgedForEachStmt BridgedForEachStmt_createParsed(
BridgedASTContext cContext, BridgedLabeledStmtInfo cLabelInfo,
swift::SourceLoc forLoc, swift::SourceLoc tryLoc, swift::SourceLoc awaitLoc,
swift::SourceLoc unsafeLoc, BridgedPattern cPat, swift::SourceLoc inLoc,
BridgedExpr cSequence, swift::SourceLoc whereLoc,
BridgedNullableExpr cWhereExpr, BridgedBraceStmt cBody);
BridgedNullableExpr cWhereExpr, BridgedBraceStmt cBody,
BridgedDeclContext cDeclContext);

SWIFT_NAME("BridgedGuardStmt.createParsed(_:guardLoc:conds:body:)")
BridgedGuardStmt BridgedGuardStmt_createParsed(BridgedASTContext cContext,
Expand Down
21 changes: 21 additions & 0 deletions include/swift/AST/Expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -6724,6 +6724,27 @@ class MacroExpansionExpr final : public Expr,
}
};

/// OpaqueExpr - created to serve as an indirection to a ForEachStmt's sequence
/// expr and where clause to avoid visiting it twice in the ASTWalker after
/// having desugared the loop. This will only be processed in SILGen to emit
/// the underlying expression.
class OpaqueExpr final : public Expr {
Expr *OriginalExpr;

public:
OpaqueExpr(Expr* originalExpr)
: Expr(ExprKind::Opaque, /*implicit*/ true, originalExpr->getType()),
OriginalExpr(originalExpr) {}

Expr *getOriginalExpr() const { return OriginalExpr; }
SourceLoc getStartLoc() const { return OriginalExpr->getStartLoc(); }
SourceLoc getEndLoc() const { return OriginalExpr->getEndLoc(); }

static bool classof(const Expr *E) {
return E->getKind() == ExprKind::Opaque;
}
};

inline bool Expr::isInfixOperator() const {
return isa<BinaryExpr>(this) || isa<TernaryExpr>(this) ||
isa<AssignExpr>(this) || isa<ExplicitCastExpr>(this);
Expand Down
1 change: 1 addition & 0 deletions include/swift/AST/ExprNodes.def
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,7 @@ EXPR(Tap, Expr)
UNCHECKED_EXPR(TypeJoin, Expr)
EXPR(MacroExpansion, Expr)
EXPR(TypeValue, Expr)
EXPR(Opaque, Expr)
// Don't forget to update the LAST_EXPR below when adding a new Expr here.
LAST_EXPR(TypeValue)

Expand Down
48 changes: 35 additions & 13 deletions include/swift/AST/Stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1003,21 +1003,21 @@ class ForEachStmt : public LabeledStmt {
SourceLoc WhereLoc;
Expr *WhereExpr = nullptr;
BraceStmt *Body;
DeclContext* DC = nullptr;

// Set by Sema:
ProtocolConformanceRef sequenceConformance = ProtocolConformanceRef();
Type sequenceType;
PatternBindingDecl *iteratorVar = nullptr;
Expr *nextCall = nullptr;
OpaqueValueExpr *elementExpr = nullptr;
BraceStmt *desugaredStmt = nullptr;
Expr *convertElementExpr = nullptr;

public:
ForEachStmt(LabeledStmtInfo LabelInfo, SourceLoc ForLoc, SourceLoc TryLoc,
SourceLoc AwaitLoc, SourceLoc UnsafeLoc, Pattern *Pat,
SourceLoc InLoc, Expr *Sequence,
SourceLoc WhereLoc, Expr *WhereExpr, BraceStmt *Body,
std::optional<bool> implicit = std::nullopt)
DeclContext* DC, std::optional<bool> implicit = std::nullopt)
: LabeledStmt(StmtKind::ForEach, getDefaultImplicitFlag(implicit, ForLoc),
LabelInfo),
ForLoc(ForLoc), TryLoc(TryLoc), AwaitLoc(AwaitLoc), UnsafeLoc(UnsafeLoc),
Expand All @@ -1026,15 +1026,9 @@ class ForEachStmt : public LabeledStmt {
setPattern(Pat);
}

void setIteratorVar(PatternBindingDecl *var) { iteratorVar = var; }
PatternBindingDecl *getIteratorVar() const { return iteratorVar; }

void setNextCall(Expr *next) { nextCall = next; }
Expr *getNextCall() const { return nextCall; }

void setElementExpr(OpaqueValueExpr *expr) { elementExpr = expr; }
OpaqueValueExpr *getElementExpr() const { return elementExpr; }

void setConvertElementExpr(Expr *expr) { convertElementExpr = expr; }
Expr *getConvertElementExpr() const { return convertElementExpr; }

Expand Down Expand Up @@ -1076,20 +1070,23 @@ class ForEachStmt : public LabeledStmt {
Expr *getParsedSequence() const { return Sequence; }
void setParsedSequence(Expr *S) { Sequence = S; }

/// Type-checked version of the sequence or nullptr if this statement
/// yet to be type-checked.
Expr *getTypeCheckedSequence() const;

/// getBody - Retrieve the body of the loop.
BraceStmt *getBody() const { return Body; }
void setBody(BraceStmt *B) { Body = B; }

SourceLoc getStartLoc() const { return getLabelLocOrKeywordLoc(ForLoc); }
SourceLoc getEndLoc() const { return Body->getEndLoc(); }

DeclContext *getDeclContext() const { return DC; }
void setDeclContext(DeclContext *newDC) { DC = newDC; }

static bool classof(const Stmt *S) {
return S->getKind() == StmtKind::ForEach;
}

BraceStmt* desugar();
BraceStmt* getDesugaredStmt() const { return desugaredStmt; }
void setDesugaredStmt(BraceStmt* newStmt) { desugaredStmt = newStmt; }
};

/// A pattern and an optional guard expression used in a 'case' statement.
Expand Down Expand Up @@ -1541,6 +1538,31 @@ class DoCatchStmt final
}
};

/// OpaqueStmt - created to serve as an indirection to a ForEachStmt's body
/// to avoid visiting it twice in the ASTWalker after having desugared the loop.
/// This ensures we only visit the body once, and this OpaqueStmt will only be
/// visited to emit the underlying statement in SILGen.
class OpaqueStmt final : public Stmt {
SourceLoc StartLoc;
SourceLoc EndLoc;
BraceStmt *Body; // FIXME: should I just use Stmt * so that this is more versatile?
// If not, should the class be renamed to be more specific?
public:
OpaqueStmt(BraceStmt* body, SourceLoc startLoc, SourceLoc endLoc)
: Stmt(StmtKind::Opaque, true /*always implicit*/),
StartLoc(startLoc), EndLoc(endLoc), Body(body) {}

SourceLoc getLoc() const { return StartLoc; }
SourceLoc getStartLoc() const { return StartLoc; }
SourceLoc getEndLoc() const { return EndLoc; }

BraceStmt* getUnderlyingStmt() { return Body; }

static bool classof(const Stmt *S) {
return S->getKind() == StmtKind::Opaque;
}
};

/// BreakStmt - The "break" and "break label" statement.
class BreakStmt : public Stmt {
SourceLoc Loc;
Expand Down
1 change: 1 addition & 0 deletions include/swift/AST/StmtNodes.def
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ ABSTRACT_STMT(Labeled, Stmt)
LABELED_STMT(ForEach, LabeledStmt)
LABELED_STMT(Switch, LabeledStmt)
STMT_RANGE(Labeled, If, Switch)
STMT(Opaque, Stmt)
STMT(Case, Stmt)
STMT(Break, Stmt)
STMT(Continue, Stmt)
Expand Down
19 changes: 19 additions & 0 deletions include/swift/AST/TypeCheckRequests.h
Original file line number Diff line number Diff line change
Expand Up @@ -5591,6 +5591,25 @@ class IsCustomAvailabilityDomainPermanentlyEnabled
}
};

class DesugarForEachStmtRequest
: public SimpleRequest<DesugarForEachStmtRequest,
BraceStmt *(ForEachStmt*),
RequestFlags::SeparatelyCached> {
public:
using SimpleRequest::SimpleRequest;

private:
friend SimpleRequest;

// Evaluation.
BraceStmt *evaluate(Evaluator &evaluator, ForEachStmt *FES) const;

public:
bool isCached() const { return true; }
std::optional<BraceStmt*> getCachedResult() const;
void cacheResult(BraceStmt *stmt) const;
};

#define SWIFT_TYPEID_ZONE TypeChecker
#define SWIFT_TYPEID_HEADER "swift/AST/TypeCheckerTypeIDZone.def"
#include "swift/Basic/DefineTypeIDZone.h"
Expand Down
4 changes: 4 additions & 0 deletions include/swift/AST/TypeCheckerTypeIDZone.def
Original file line number Diff line number Diff line change
Expand Up @@ -674,3 +674,7 @@ SWIFT_REQUEST(TypeChecker, IsCustomAvailabilityDomainPermanentlyEnabled,
SWIFT_REQUEST(TypeChecker, EmitPerformanceHints,
evaluator::SideEffect(SourceFile *),
Cached, NoLocationInfo)

SWIFT_REQUEST(TypeChecker, DesugarForEachStmtRequest,
Stmt*(const ForEachStmt*),
Cached, NoLocationInfo)
2 changes: 2 additions & 0 deletions include/swift/Sema/ConstraintLocator.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,8 @@ enum ContextualTypePurpose : uint8_t {

CTP_ExprPattern, ///< `~=` operator application associated with expression
/// pattern.

CTP_ForEachElement, ///< Element expression associated with `for-in` loop.
};

namespace constraints {
Expand Down
7 changes: 1 addition & 6 deletions include/swift/Sema/SyntacticElementTarget.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,6 @@ struct SequenceIterationInfo {

/// The type of the pattern that matches the elements.
Type initType;

/// Implicit `$iterator = <sequence>.makeIterator()`
PatternBindingDecl *makeIteratorVar;

/// Implicit `$iterator.next()` call.
Expr *nextCall;
};

/// Describes information about a for-in loop over a pack that needs to be
Expand Down Expand Up @@ -605,6 +599,7 @@ class SyntacticElementTarget {
case CTP_Initialization:
case CTP_ForEachSequence:
case CTP_ExprPattern:
case CTP_ForEachElement:
break;
default:
assert(false && "Unexpected contextual type purpose");
Expand Down
21 changes: 12 additions & 9 deletions lib/AST/ASTDumper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3252,6 +3252,10 @@ class PrintStmt : public StmtVisitor<PrintStmt, void, Label>,
printFlag(S->TrailingSemiLoc.isValid(), "trailing_semi");
}

void visitOpaqueStmt(OpaqueStmt *S, Label label){
visitBraceStmt(S->getUnderlyingStmt(), label);
}

void visitBraceStmt(BraceStmt *S, Label label) {
printCommon(S, "brace_stmt", label);
printList(S->getElements(), [&](auto &Elt, Label label) {
Expand Down Expand Up @@ -3332,20 +3336,15 @@ class PrintStmt : public StmtVisitor<PrintStmt, void, Label>,
printRec(S->getWhere(), Label::always("where"));
}
printRec(S->getParsedSequence(), Label::optional("parsed_sequence"));
if (S->getIteratorVar()) {
printRec(S->getIteratorVar(), Label::optional("iterator_var"));
}
if (S->getNextCall()) {
printRec(S->getNextCall(), Label::optional("next_call"));
}
if (S->getConvertElementExpr()) {
printRec(S->getConvertElementExpr(),
Label::optional("convert_element_expr"));
}
if (S->getElementExpr()) {
printRec(S->getElementExpr(), Label::optional("element_expr"));
}

printRec(S->getBody(), Label::optional("body"));

printRec(S->getDesugaredStmt(), Label::optional("desugared_loop"));

printFoot();
}
void visitBreakStmt(BreakStmt *S, Label label) {
Expand Down Expand Up @@ -4237,6 +4236,10 @@ class PrintExpr : public ExprVisitor<PrintExpr, void, Label>,
printFoot();
}

void visitOpaqueExpr(OpaqueExpr *E, Label label){
visit(E->getOriginalExpr(), label);
}

void visitPropertyWrapperValuePlaceholderExpr(
PropertyWrapperValuePlaceholderExpr *E, Label label) {
printCommon(E, "property_wrapper_value_placeholder_expr", label);
Expand Down
12 changes: 11 additions & 1 deletion lib/AST/ASTPrinter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5633,6 +5633,16 @@ void PrintAST::visitTypeValueExpr(TypeValueExpr *expr) {
expr->getType()->print(Printer, Options);
}

void PrintAST::visitOpaqueExpr(OpaqueExpr *expr) {
// FIXME: unsure about this, maybe do nothing?
visit(expr->getOriginalExpr());
}

void PrintAST::visitOpaqueStmt(OpaqueStmt *stmt) {
// FIXME: unsure about this, maybe do nothing?
printBraceStmt(stmt->getUnderlyingStmt());
}

void PrintAST::visitBraceStmt(BraceStmt *stmt) {
printBraceStmt(stmt);
}
Expand Down Expand Up @@ -5810,7 +5820,7 @@ void PrintAST::visitForEachStmt(ForEachStmt *stmt) {
printPattern(stmt->getPattern());
Printer << " " << tok::kw_in << " ";
// FIXME: print container
if (auto *seq = stmt->getTypeCheckedSequence()) {
if (auto *seq = stmt->getParsedSequence()) {
// Look through the call to '.makeIterator()'

if (auto *CE = dyn_cast<CallExpr>(seq)) {
Expand Down
1 change: 1 addition & 0 deletions lib/AST/ASTScopeCreation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -414,6 +414,7 @@ class NodeAdder
VISIT_AND_IGNORE(ContinueStmt)
VISIT_AND_IGNORE(FallthroughStmt)
VISIT_AND_IGNORE(FailStmt)
VISIT_AND_IGNORE(OpaqueStmt)

#undef VISIT_AND_IGNORE

Expand Down
11 changes: 0 additions & 11 deletions lib/AST/ASTVerifier.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -802,11 +802,6 @@ class Verifier : public ASTWalker {
ForEachPatternSequences.insert(expansion);
}

if (!S->getElementExpr())
return true;

assert(!OpaqueValues.count(S->getElementExpr()));
OpaqueValues[S->getElementExpr()] = 0;
return true;
}

Expand All @@ -819,12 +814,6 @@ class Verifier : public ASTWalker {
// Clean up for real.
cleanup(expansion);
}

if (!S->getElementExpr())
return;

assert(OpaqueValues.count(S->getElementExpr()));
OpaqueValues.erase(S->getElementExpr());
}

bool shouldVerify(InterpolatedStringLiteralExpr *expr) {
Expand Down
33 changes: 15 additions & 18 deletions lib/AST/ASTWalker.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,8 @@ class Traversal : public ASTVisitor<Traversal, Expr*, Stmt*,

Expr *visitOpaqueValueExpr(OpaqueValueExpr *E) { return E; }

Expr *visitOpaqueExpr(OpaqueExpr *E) { return E; }

Expr *visitPropertyWrapperValuePlaceholderExpr(
PropertyWrapperValuePlaceholderExpr *E) {
if (E->getOpaqueValuePlaceholder()) {
Expand Down Expand Up @@ -1896,6 +1898,11 @@ Stmt *Traversal::visitPoundAssertStmt(PoundAssertStmt *S) {
return S;
}

Stmt* Traversal::visitOpaqueStmt(OpaqueStmt* OS){
// We do not want to visit it.
return OS;
}

Stmt *Traversal::visitBraceStmt(BraceStmt *BS) {
for (auto &Elem : BS->getElements()) {
if (auto *SubExpr = Elem.dyn_cast<Expr*>()) {
Expand Down Expand Up @@ -2066,28 +2073,11 @@ Stmt *Traversal::visitForEachStmt(ForEachStmt *S) {
return nullptr;
}

// The iterator decl is built directly on top of the sequence
// expression, so don't visit both.
//
// If for-in is already type-checked, the type-checked version
// of the sequence is going to be visited as part of `iteratorVar`.
if (auto IteratorVar = S->getIteratorVar()) {
if (doIt(IteratorVar))
return nullptr;

if (auto NextCall = S->getNextCall()) {
if ((NextCall = doIt(NextCall)))
S->setNextCall(NextCall);
else
return nullptr;
}
} else {
if (Expr *Sequence = S->getParsedSequence()) {
if (Expr *Sequence = S->getParsedSequence()) {
if ((Sequence = doIt(Sequence)))
S->setParsedSequence(Sequence);
else
return nullptr;
}
}

if (Expr *Where = S->getWhere()) {
Expand All @@ -2111,6 +2101,13 @@ Stmt *Traversal::visitForEachStmt(ForEachStmt *S) {
return nullptr;
}

if (Stmt *Desugared = S->getDesugaredStmt()) {
if ((Desugared = doIt(Desugared)))
S->setDesugaredStmt(cast<BraceStmt>(Desugared));
else
return nullptr;
}

return S;
}

Expand Down
Loading