[WHLSL] The checker needs to resolve types for the anonymous variables in ReadModifyW...
[WebKit-https.git] / Source / WebCore / Modules / webgpu / WHLSL / WHLSLChecker.cpp
1 /*
2  * Copyright (C) 2019 Apple Inc. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  * 1. Redistributions of source code must retain the above copyright
8  *    notice, this list of conditions and the following disclaimer.
9  * 2. Redistributions in binary form must reproduce the above copyright
10  *    notice, this list of conditions and the following disclaimer in the
11  *    documentation and/or other materials provided with the distribution.
12  *
13  * THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS ``AS IS''
14  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
15  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS
17  * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23  * THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 #include "config.h"
27 #include "WHLSLChecker.h"
28
29 #if ENABLE(WEBGPU)
30
31 #include "WHLSLArrayReferenceType.h"
32 #include "WHLSLArrayType.h"
33 #include "WHLSLAssignmentExpression.h"
34 #include "WHLSLCallExpression.h"
35 #include "WHLSLCommaExpression.h"
36 #include "WHLSLDereferenceExpression.h"
37 #include "WHLSLDoWhileLoop.h"
38 #include "WHLSLDotExpression.h"
39 #include "WHLSLEntryPointType.h"
40 #include "WHLSLForLoop.h"
41 #include "WHLSLGatherEntryPointItems.h"
42 #include "WHLSLIfStatement.h"
43 #include "WHLSLIndexExpression.h"
44 #include "WHLSLInferTypes.h"
45 #include "WHLSLLogicalExpression.h"
46 #include "WHLSLLogicalNotExpression.h"
47 #include "WHLSLMakeArrayReferenceExpression.h"
48 #include "WHLSLMakePointerExpression.h"
49 #include "WHLSLPointerType.h"
50 #include "WHLSLProgram.h"
51 #include "WHLSLReadModifyWriteExpression.h"
52 #include "WHLSLResolvableType.h"
53 #include "WHLSLResolveOverloadImpl.h"
54 #include "WHLSLResolvingType.h"
55 #include "WHLSLReturn.h"
56 #include "WHLSLSwitchStatement.h"
57 #include "WHLSLTernaryExpression.h"
58 #include "WHLSLVisitor.h"
59 #include "WHLSLWhileLoop.h"
60 #include <wtf/HashMap.h>
61 #include <wtf/HashSet.h>
62 #include <wtf/Ref.h>
63 #include <wtf/Vector.h>
64 #include <wtf/text/WTFString.h>
65
66 namespace WebCore {
67
68 namespace WHLSL {
69
70 class PODChecker : public Visitor {
71 public:
72     PODChecker() = default;
73
74     virtual ~PODChecker() = default;
75
76     void visit(AST::EnumerationDefinition& enumerationDefinition) override
77     {
78         Visitor::visit(enumerationDefinition);
79     }
80
81     void visit(AST::NativeTypeDeclaration& nativeTypeDeclaration) override
82     {
83         if (!nativeTypeDeclaration.isNumber()
84             && !nativeTypeDeclaration.isVector()
85             && !nativeTypeDeclaration.isMatrix())
86             setError();
87     }
88
89     void visit(AST::StructureDefinition& structureDefinition) override
90     {
91         Visitor::visit(structureDefinition);
92     }
93
94     void visit(AST::TypeDefinition& typeDefinition) override
95     {
96         Visitor::visit(typeDefinition);
97     }
98
99     void visit(AST::ArrayType& arrayType) override
100     {
101         Visitor::visit(arrayType);
102     }
103
104     void visit(AST::PointerType&) override
105     {
106         setError();
107     }
108
109     void visit(AST::ArrayReferenceType&) override
110     {
111         setError();
112     }
113
114     void visit(AST::TypeReference& typeReference) override
115     {
116         checkErrorAndVisit(typeReference.resolvedType());
117     }
118 };
119
120 static AST::NativeFunctionDeclaration resolveWithOperatorAnderIndexer(Lexer::Token origin, AST::ArrayReferenceType& firstArgument, const Intrinsics& intrinsics)
121 {
122     const bool isOperator = true;
123     auto returnType = makeUniqueRef<AST::PointerType>(Lexer::Token(origin), firstArgument.addressSpace(), firstArgument.elementType().clone());
124     AST::VariableDeclarations parameters;
125     parameters.append(makeUniqueRef<AST::VariableDeclaration>(Lexer::Token(origin), AST::Qualifiers(), firstArgument.clone(), String(), WTF::nullopt, WTF::nullopt));
126     parameters.append(makeUniqueRef<AST::VariableDeclaration>(Lexer::Token(origin), AST::Qualifiers(), UniqueRef<AST::UnnamedType>(AST::TypeReference::wrap(Lexer::Token(origin), intrinsics.uintType())), String(), WTF::nullopt, WTF::nullopt));
127     return AST::NativeFunctionDeclaration(AST::FunctionDeclaration(Lexer::Token(origin), AST::AttributeBlock(), WTF::nullopt, WTFMove(returnType), String("operator&[]", String::ConstructFromLiteral), WTFMove(parameters), WTF::nullopt, isOperator));
128 }
129
130 static AST::NativeFunctionDeclaration resolveWithOperatorLength(Lexer::Token origin, AST::UnnamedType& firstArgument, const Intrinsics& intrinsics)
131 {
132     const bool isOperator = true;
133     auto returnType = AST::TypeReference::wrap(Lexer::Token(origin), intrinsics.uintType());
134     AST::VariableDeclarations parameters;
135     parameters.append(makeUniqueRef<AST::VariableDeclaration>(Lexer::Token(origin), AST::Qualifiers(), firstArgument.clone(), String(), WTF::nullopt, WTF::nullopt));
136     return AST::NativeFunctionDeclaration(AST::FunctionDeclaration(Lexer::Token(origin), AST::AttributeBlock(), WTF::nullopt, WTFMove(returnType), String("operator.length", String::ConstructFromLiteral), WTFMove(parameters), WTF::nullopt, isOperator));
137 }
138
139 static AST::NativeFunctionDeclaration resolveWithReferenceComparator(Lexer::Token origin, ResolvingType& firstArgument, ResolvingType& secondArgument, const Intrinsics& intrinsics)
140 {
141     const bool isOperator = true;
142     auto returnType = AST::TypeReference::wrap(Lexer::Token(origin), intrinsics.boolType());
143     auto argumentType = firstArgument.visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> UniqueRef<AST::UnnamedType> {
144         return unnamedType->clone();
145     }, [&](RefPtr<ResolvableTypeReference>&) -> UniqueRef<AST::UnnamedType> {
146         return secondArgument.visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> UniqueRef<AST::UnnamedType> {
147             return unnamedType->clone();
148         }, [&](RefPtr<ResolvableTypeReference>&) -> UniqueRef<AST::UnnamedType> {
149             // We encountered "null == null".
150             // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198162 This can probably be generalized, using the "preferred type" infrastructure used by generic literals
151             ASSERT_NOT_REACHED();
152             return AST::TypeReference::wrap(Lexer::Token(origin), intrinsics.intType());
153         }));
154     }));
155     AST::VariableDeclarations parameters;
156     parameters.append(makeUniqueRef<AST::VariableDeclaration>(Lexer::Token(origin), AST::Qualifiers(), argumentType->clone(), String(), WTF::nullopt, WTF::nullopt));
157     parameters.append(makeUniqueRef<AST::VariableDeclaration>(Lexer::Token(origin), AST::Qualifiers(), UniqueRef<AST::UnnamedType>(WTFMove(argumentType)), String(), WTF::nullopt, WTF::nullopt));
158     return AST::NativeFunctionDeclaration(AST::FunctionDeclaration(Lexer::Token(origin), AST::AttributeBlock(), WTF::nullopt, WTFMove(returnType), String("operator==", String::ConstructFromLiteral), WTFMove(parameters), WTF::nullopt, isOperator));
159 }
160
// Whether an argument can be an operand of a synthesized reference "operator==":
// Yes for reference types, Maybe for the null literal type (its type comes from the
// other operand), No otherwise.
enum class Acceptability { Yes, Maybe, No };
166
167 static Optional<AST::NativeFunctionDeclaration> resolveByInstantiation(const String& name, Lexer::Token origin, const Vector<std::reference_wrapper<ResolvingType>>& types, const Intrinsics& intrinsics)
168 {
169     if (name == "operator&[]" && types.size() == 2) {
170         auto* firstArgumentArrayRef = types[0].get().visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> AST::ArrayReferenceType* {
171             if (is<AST::ArrayReferenceType>(static_cast<AST::UnnamedType&>(unnamedType)))
172                 return &downcast<AST::ArrayReferenceType>(static_cast<AST::UnnamedType&>(unnamedType));
173             return nullptr;
174         }, [](RefPtr<ResolvableTypeReference>&) -> AST::ArrayReferenceType* {
175             return nullptr;
176         }));
177         bool secondArgumentIsUint = types[1].get().visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& unnamedType) -> bool {
178             return matches(unnamedType, intrinsics.uintType());
179         }, [&](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> bool {
180             return resolvableTypeReference->resolvableType().canResolve(intrinsics.uintType());
181         }));
182         if (firstArgumentArrayRef && secondArgumentIsUint)
183             return resolveWithOperatorAnderIndexer(origin, *firstArgumentArrayRef, intrinsics);
184     } else if (name == "operator.length" && types.size() == 1) {
185         auto* firstArgumentReference = types[0].get().visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> AST::UnnamedType* {
186             if (is<AST::ArrayReferenceType>(static_cast<AST::UnnamedType&>(unnamedType)))
187                 return &unnamedType;
188             return nullptr;
189         }, [](RefPtr<ResolvableTypeReference>&) -> AST::UnnamedType* {
190             return nullptr;
191         }));
192         if (firstArgumentReference)
193             return resolveWithOperatorLength(origin, *firstArgumentReference, intrinsics);
194     } else if (name == "operator==" && types.size() == 2) {
195         auto acceptability = [](ResolvingType& resolvingType) -> Acceptability {
196             return resolvingType.visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> Acceptability {
197                 return is<AST::ReferenceType>(static_cast<AST::UnnamedType&>(unnamedType)) ? Acceptability::Yes : Acceptability::No;
198             }, [](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> Acceptability {
199                 return is<AST::NullLiteralType>(resolvableTypeReference->resolvableType()) ? Acceptability::Maybe : Acceptability::No;
200             }));
201         };
202         auto leftAcceptability = acceptability(types[0].get());
203         auto rightAcceptability = acceptability(types[1].get());
204         bool success = false;
205         if (leftAcceptability == Acceptability::Yes && rightAcceptability == Acceptability::Yes) {
206             auto& unnamedType1 = *types[0].get().getUnnamedType();
207             auto& unnamedType2 = *types[1].get().getUnnamedType();
208             success = matches(unnamedType1, unnamedType2);
209         } else if ((leftAcceptability == Acceptability::Maybe && rightAcceptability == Acceptability::Yes)
210             || (leftAcceptability == Acceptability::Yes && rightAcceptability == Acceptability::Maybe))
211             success = true;
212         if (success)
213             return resolveWithReferenceComparator(origin, types[0].get(), types[1].get(), intrinsics);
214     }
215     return WTF::nullopt;
216 }
217
218 static AST::FunctionDeclaration* resolveFunction(Program& program, Vector<std::reference_wrapper<AST::FunctionDeclaration>, 1>& possibleOverloads, Vector<std::reference_wrapper<ResolvingType>>& types, const String& name, Lexer::Token origin, const Intrinsics& intrinsics, AST::NamedType* castReturnType = nullptr)
219 {
220     if (AST::FunctionDeclaration* function = resolveFunctionOverload(possibleOverloads, types, castReturnType))
221         return function;
222
223     if (auto newFunction = resolveByInstantiation(name, origin, types, intrinsics)) {
224         program.append(WTFMove(*newFunction));
225         return &program.nativeFunctionDeclarations().last();
226     }
227
228     return nullptr;
229 }
230
231 static bool checkSemantics(Vector<EntryPointItem>& inputItems, Vector<EntryPointItem>& outputItems, const Optional<AST::EntryPointType>& entryPointType, const Intrinsics& intrinsics)
232 {
233     {
234         auto checkDuplicateSemantics = [&](const Vector<EntryPointItem>& items) -> bool {
235             for (size_t i = 0; i < items.size(); ++i) {
236                 for (size_t j = i + 1; j < items.size(); ++j) {
237                     if (items[i].semantic == items[j].semantic)
238                         return false;
239                 }
240             }
241             return true;
242         };
243         if (!checkDuplicateSemantics(inputItems))
244             return false;
245         if (!checkDuplicateSemantics(outputItems))
246             return false;
247     }
248
249     {
250         auto checkSemanticTypes = [&](const Vector<EntryPointItem>& items) -> bool {
251             for (auto& item : items) {
252                 auto acceptable = WTF::visit(WTF::makeVisitor([&](const AST::BaseSemantic& semantic) -> bool {
253                     return semantic.isAcceptableType(*item.unnamedType, intrinsics);
254                 }), *item.semantic);
255                 if (!acceptable)
256                     return false;
257             }
258             return true;
259         };
260         if (!checkSemanticTypes(inputItems))
261             return false;
262         if (!checkSemanticTypes(outputItems))
263             return false;
264     }
265
266     {
267         auto checkSemanticForShaderType = [&](const Vector<EntryPointItem>& items, AST::BaseSemantic::ShaderItemDirection direction) -> bool {
268             for (auto& item : items) {
269                 auto acceptable = WTF::visit(WTF::makeVisitor([&](const AST::BaseSemantic& semantic) -> bool {
270                     return semantic.isAcceptableForShaderItemDirection(direction, entryPointType);
271                 }), *item.semantic);
272                 if (!acceptable)
273                     return false;
274             }
275             return true;
276         };
277         if (!checkSemanticForShaderType(inputItems, AST::BaseSemantic::ShaderItemDirection::Input))
278             return false;
279         if (!checkSemanticForShaderType(outputItems, AST::BaseSemantic::ShaderItemDirection::Output))
280             return false;
281     }
282
283     {
284         auto checkPODData = [&](const Vector<EntryPointItem>& items) -> bool {
285             for (auto& item : items) {
286                 PODChecker podChecker;
287                 if (is<AST::PointerType>(item.unnamedType))
288                     podChecker.checkErrorAndVisit(downcast<AST::PointerType>(*item.unnamedType).elementType());
289                 else if (is<AST::ArrayReferenceType>(item.unnamedType))
290                     podChecker.checkErrorAndVisit(downcast<AST::ArrayReferenceType>(*item.unnamedType).elementType());
291                 else if (is<AST::ArrayType>(item.unnamedType))
292                     podChecker.checkErrorAndVisit(downcast<AST::ArrayType>(*item.unnamedType).type());
293                 else
294                     continue;
295                 if (podChecker.error())
296                     return false;
297             }
298             return true;
299         };
300         if (!checkPODData(inputItems))
301             return false;
302         if (!checkPODData(outputItems))
303             return false;
304     }
305
306     return true;
307 }
308
309 static bool checkOperatorOverload(const AST::FunctionDefinition& functionDefinition, const Intrinsics& intrinsics, NameContext& nameContext)
310 {
311     enum class CheckKind {
312         Index,
313         Dot
314     };
315
316     auto checkGetter = [&](CheckKind kind) -> bool {
317         size_t numExpectedParameters = kind == CheckKind::Index ? 2 : 1;
318         if (functionDefinition.parameters().size() != numExpectedParameters)
319             return false;
320         auto& firstParameterUnifyNode = (*functionDefinition.parameters()[0]->type())->unifyNode();
321         if (is<AST::UnnamedType>(firstParameterUnifyNode)) {
322             auto& unnamedType = downcast<AST::UnnamedType>(firstParameterUnifyNode);
323             if (is<AST::PointerType>(unnamedType) || is<AST::ArrayReferenceType>(unnamedType) || is<AST::ArrayType>(unnamedType))
324                 return false;
325         }
326         if (kind == CheckKind::Index) {
327             auto& secondParameterUnifyNode = (*functionDefinition.parameters()[1]->type())->unifyNode();
328             if (!is<AST::NamedType>(secondParameterUnifyNode))
329                 return false;
330             auto& namedType = downcast<AST::NamedType>(secondParameterUnifyNode);
331             if (!is<AST::NativeTypeDeclaration>(namedType))
332                 return false;
333             auto& nativeTypeDeclaration = downcast<AST::NativeTypeDeclaration>(namedType);
334             if (!nativeTypeDeclaration.isInt())
335                 return false;
336         }
337         return true;
338     };
339
340     auto checkSetter = [&](CheckKind kind) -> bool {
341         size_t numExpectedParameters = kind == CheckKind::Index ? 3 : 2;
342         if (functionDefinition.parameters().size() != numExpectedParameters)
343             return false;
344         auto& firstArgumentUnifyNode = (*functionDefinition.parameters()[0]->type())->unifyNode();
345         if (is<AST::UnnamedType>(firstArgumentUnifyNode)) {
346             auto& unnamedType = downcast<AST::UnnamedType>(firstArgumentUnifyNode);
347             if (is<AST::PointerType>(unnamedType) || is<AST::ArrayReferenceType>(unnamedType) || is<AST::ArrayType>(unnamedType))
348                 return false;
349         }
350         if (kind == CheckKind::Index) {
351             auto& secondParameterUnifyNode = (*functionDefinition.parameters()[1]->type())->unifyNode();
352             if (!is<AST::NamedType>(secondParameterUnifyNode))
353                 return false;
354             auto& namedType = downcast<AST::NamedType>(secondParameterUnifyNode);
355             if (!is<AST::NativeTypeDeclaration>(namedType))
356                 return false;
357             auto& nativeTypeDeclaration = downcast<AST::NativeTypeDeclaration>(namedType);
358             if (!nativeTypeDeclaration.isInt())
359                 return false;
360         }
361         if (!matches(functionDefinition.type(), *functionDefinition.parameters()[0]->type()))
362             return false;
363         auto& valueType = *functionDefinition.parameters()[numExpectedParameters - 1]->type();
364         auto getterName = functionDefinition.name().substring(0, functionDefinition.name().length() - 1);
365         auto* getterFuncs = nameContext.getFunctions(getterName);
366         if (!getterFuncs)
367             return false;
368         Vector<ResolvingType> argumentTypes;
369         Vector<std::reference_wrapper<ResolvingType>> argumentTypeReferences;
370         for (size_t i = 0; i < numExpectedParameters - 1; ++i)
371             argumentTypes.append((*functionDefinition.parameters()[i]->type())->clone());
372         for (auto& argumentType : argumentTypes)
373             argumentTypeReferences.append(argumentType);
374         auto* overload = resolveFunctionOverload(*getterFuncs, argumentTypeReferences);
375         if (!overload)
376             return false;
377         auto& resultType = overload->type();
378         return matches(resultType, valueType);
379     };
380
381     auto checkAnder = [&](CheckKind kind) -> bool {
382         size_t numExpectedParameters = kind == CheckKind::Index ? 2 : 1;
383         if (functionDefinition.parameters().size() != numExpectedParameters)
384             return false;
385         {
386             auto& unifyNode = functionDefinition.type().unifyNode();
387             if (!is<AST::UnnamedType>(unifyNode))
388                 return false;
389             auto& unnamedType = downcast<AST::UnnamedType>(unifyNode);
390             if (!is<AST::PointerType>(unnamedType))
391                 return false;
392         }
393         {
394             auto& unifyNode = (*functionDefinition.parameters()[0]->type())->unifyNode();
395             if (!is<AST::UnnamedType>(unifyNode))
396                 return false;
397             auto& unnamedType = downcast<AST::UnnamedType>(unifyNode);
398             return is<AST::PointerType>(unnamedType) || is<AST::ArrayReferenceType>(unnamedType);
399         }
400     };
401
402     if (!functionDefinition.isOperator())
403         return true;
404     if (functionDefinition.isCast())
405         return true;
406     if (functionDefinition.name() == "operator++" || functionDefinition.name() == "operator--") {
407         return functionDefinition.parameters().size() == 1
408             && matches(*functionDefinition.parameters()[0]->type(), functionDefinition.type());
409     }
410     if (functionDefinition.name() == "operator+" || functionDefinition.name() == "operator-")
411         return functionDefinition.parameters().size() == 1 || functionDefinition.parameters().size() == 2;
412     if (functionDefinition.name() == "operator*"
413         || functionDefinition.name() == "operator/"
414         || functionDefinition.name() == "operator%"
415         || functionDefinition.name() == "operator&"
416         || functionDefinition.name() == "operator|"
417         || functionDefinition.name() == "operator^"
418         || functionDefinition.name() == "operator<<"
419         || functionDefinition.name() == "opreator>>")
420         return functionDefinition.parameters().size() == 2;
421     if (functionDefinition.name() == "operator~")
422         return functionDefinition.parameters().size() == 1;
423     if (functionDefinition.name() == "operator=="
424         || functionDefinition.name() == "operator<"
425         || functionDefinition.name() == "operator<="
426         || functionDefinition.name() == "operator>"
427         || functionDefinition.name() == "operator>=") {
428         return functionDefinition.parameters().size() == 2
429             && matches(functionDefinition.type(), intrinsics.boolType());
430     }
431     if (functionDefinition.name() == "operator[]")
432         return checkGetter(CheckKind::Index);
433     if (functionDefinition.name() == "operator[]=")
434         return checkSetter(CheckKind::Index);
435     if (functionDefinition.name() == "operator&[]")
436         return checkAnder(CheckKind::Index);
437     if (functionDefinition.name().startsWith("operator.")) {
438         if (functionDefinition.name().endsWith("="))
439             return checkSetter(CheckKind::Dot);
440         return checkGetter(CheckKind::Dot);
441     }
442     if (functionDefinition.name().startsWith("operator&."))
443         return checkAnder(CheckKind::Dot);
444     return false;
445 }
446
447 class Checker : public Visitor {
448 public:
449     Checker(const Intrinsics& intrinsics, Program& program)
450         : m_intrinsics(intrinsics)
451         , m_program(program)
452     {
453     }
454
455     ~Checker() = default;
456
457     void visit(Program&) override;
458
459     bool assignTypes();
460
461 private:
462     bool checkShaderType(const AST::FunctionDefinition&);
463     bool isBoolType(ResolvingType&);
464     struct RecurseInfo {
465         ResolvingType& resolvingType;
466         const AST::TypeAnnotation typeAnnotation;
467     };
468     Optional<RecurseInfo> recurseAndGetInfo(AST::Expression&, bool requiresLeftValue = false);
469     Optional<RecurseInfo> getInfo(AST::Expression&, bool requiresLeftValue = false);
470     Optional<UniqueRef<AST::UnnamedType>> recurseAndWrapBaseType(AST::PropertyAccessExpression&);
471     bool recurseAndRequireBoolType(AST::Expression&);
472     void assignType(AST::Expression&, UniqueRef<AST::UnnamedType>&&, AST::TypeAnnotation);
473     void assignType(AST::Expression&, RefPtr<ResolvableTypeReference>&&, AST::TypeAnnotation);
474     void forwardType(AST::Expression&, ResolvingType&, AST::TypeAnnotation);
475
476     void visit(AST::FunctionDefinition&) override;
477     void visit(AST::EnumerationDefinition&) override;
478     void visit(AST::TypeReference&) override;
479     void visit(AST::VariableDeclaration&) override;
480     void visit(AST::AssignmentExpression&) override;
481     void visit(AST::ReadModifyWriteExpression&) override;
482     void visit(AST::DereferenceExpression&) override;
483     void visit(AST::MakePointerExpression&) override;
484     void visit(AST::MakeArrayReferenceExpression&) override;
485     void visit(AST::DotExpression&) override;
486     void visit(AST::IndexExpression&) override;
487     void visit(AST::VariableReference&) override;
488     void visit(AST::Return&) override;
489     void visit(AST::PointerType&) override;
490     void visit(AST::ArrayReferenceType&) override;
491     void visit(AST::IntegerLiteral&) override;
492     void visit(AST::UnsignedIntegerLiteral&) override;
493     void visit(AST::FloatLiteral&) override;
494     void visit(AST::NullLiteral&) override;
495     void visit(AST::BooleanLiteral&) override;
496     void visit(AST::EnumerationMemberLiteral&) override;
497     void visit(AST::LogicalNotExpression&) override;
498     void visit(AST::LogicalExpression&) override;
499     void visit(AST::IfStatement&) override;
500     void visit(AST::WhileLoop&) override;
501     void visit(AST::DoWhileLoop&) override;
502     void visit(AST::ForLoop&) override;
503     void visit(AST::SwitchStatement&) override;
504     void visit(AST::CommaExpression&) override;
505     void visit(AST::TernaryExpression&) override;
506     void visit(AST::CallExpression&) override;
507
508     void finishVisiting(AST::PropertyAccessExpression&, ResolvingType* additionalArgumentType = nullptr);
509
510     HashMap<AST::Expression*, std::unique_ptr<ResolvingType>> m_typeMap;
511     HashMap<AST::Expression*, AST::TypeAnnotation> m_typeAnnotations;
512     HashSet<String> m_vertexEntryPoints;
513     HashSet<String> m_fragmentEntryPoints;
514     HashSet<String> m_computeEntryPoints;
515     const Intrinsics& m_intrinsics;
516     Program& m_program;
517 };
518
519 void Checker::visit(Program& program)
520 {
521     // These visiting functions might add new global statements, so don't use foreach syntax.
522     for (size_t i = 0; i < program.typeDefinitions().size(); ++i)
523         checkErrorAndVisit(program.typeDefinitions()[i]);
524     for (size_t i = 0; i < program.structureDefinitions().size(); ++i)
525         checkErrorAndVisit(program.structureDefinitions()[i]);
526     for (size_t i = 0; i < program.enumerationDefinitions().size(); ++i)
527         checkErrorAndVisit(program.enumerationDefinitions()[i]);
528     for (size_t i = 0; i < program.nativeTypeDeclarations().size(); ++i)
529         checkErrorAndVisit(program.nativeTypeDeclarations()[i]);
530
531     for (size_t i = 0; i < program.functionDefinitions().size(); ++i)
532         checkErrorAndVisit(program.functionDefinitions()[i]);
533     for (size_t i = 0; i < program.nativeFunctionDeclarations().size(); ++i)
534         checkErrorAndVisit(program.nativeFunctionDeclarations()[i]);
535 }
536
537 bool Checker::assignTypes()
538 {
539     for (auto& keyValuePair : m_typeMap) {
540         auto success = keyValuePair.value->visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& unnamedType) -> bool {
541             keyValuePair.key->setType(unnamedType->clone());
542             return true;
543         }, [&](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> bool {
544             if (!resolvableTypeReference->resolvableType().maybeResolvedType()) {
545                 if (!static_cast<bool>(commit(resolvableTypeReference->resolvableType())))
546                     return false;
547             }
548             keyValuePair.key->setType(resolvableTypeReference->resolvableType().resolvedType().clone());
549             return true;
550         }));
551         if (!success)
552             return false;
553     }
554
555     for (auto& keyValuePair : m_typeAnnotations)
556         keyValuePair.key->setTypeAnnotation(WTFMove(keyValuePair.value));
557     return true;
558 }
559
560 bool Checker::checkShaderType(const AST::FunctionDefinition& functionDefinition)
561 {
562     switch (*functionDefinition.entryPointType()) {
563     case AST::EntryPointType::Vertex:
564         return static_cast<bool>(m_vertexEntryPoints.add(functionDefinition.name()));
565     case AST::EntryPointType::Fragment:
566         return static_cast<bool>(m_fragmentEntryPoints.add(functionDefinition.name()));
567     case AST::EntryPointType::Compute:
568         return static_cast<bool>(m_computeEntryPoints.add(functionDefinition.name()));
569     }
570 }
571
572 void Checker::visit(AST::FunctionDefinition& functionDefinition)
573 {
574     if (functionDefinition.entryPointType()) {
575         if (!checkShaderType(functionDefinition)) {
576             setError();
577             return;
578         }
579         auto entryPointItems = gatherEntryPointItems(m_intrinsics, functionDefinition);
580         if (!entryPointItems) {
581             setError();
582             return;
583         }
584         if (!checkSemantics(entryPointItems->inputs, entryPointItems->outputs, functionDefinition.entryPointType(), m_intrinsics)) {
585             setError();
586             return;
587         }
588     }
589     if (!checkOperatorOverload(functionDefinition, m_intrinsics, m_program.nameContext())) {
590         setError();
591         return;
592     }
593
594     Visitor::visit(functionDefinition);
595 }
596
597 static Optional<UniqueRef<AST::UnnamedType>> matchAndCommit(ResolvingType& left, ResolvingType& right)
598 {
599     return left.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& left) -> Optional<UniqueRef<AST::UnnamedType>> {
600         return right.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& right) -> Optional<UniqueRef<AST::UnnamedType>> {
601             if (matches(left, right))
602                 return left->clone();
603             return WTF::nullopt;
604         }, [&](RefPtr<ResolvableTypeReference>& right) -> Optional<UniqueRef<AST::UnnamedType>> {
605             return matchAndCommit(left, right->resolvableType());
606         }));
607     }, [&](RefPtr<ResolvableTypeReference>& left) -> Optional<UniqueRef<AST::UnnamedType>> {
608         return right.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& right) -> Optional<UniqueRef<AST::UnnamedType>> {
609             return matchAndCommit(right, left->resolvableType());
610         }, [&](RefPtr<ResolvableTypeReference>& right) -> Optional<UniqueRef<AST::UnnamedType>> {
611             return matchAndCommit(left->resolvableType(), right->resolvableType());
612         }));
613     }));
614 }
615
616 static Optional<UniqueRef<AST::UnnamedType>> matchAndCommit(ResolvingType& resolvingType, AST::UnnamedType& unnamedType)
617 {
618     return resolvingType.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& resolvingType) -> Optional<UniqueRef<AST::UnnamedType>> {
619         if (matches(unnamedType, resolvingType))
620             return unnamedType.clone();
621         return WTF::nullopt;
622     }, [&](RefPtr<ResolvableTypeReference>& resolvingType) -> Optional<UniqueRef<AST::UnnamedType>> {
623         return matchAndCommit(unnamedType, resolvingType->resolvableType());
624     }));
625 }
626
627 static Optional<UniqueRef<AST::UnnamedType>> matchAndCommit(ResolvingType& resolvingType, AST::NamedType& namedType)
628 {
629     return resolvingType.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& resolvingType) -> Optional<UniqueRef<AST::UnnamedType>> {
630         if (matches(resolvingType, namedType))
631             return resolvingType->clone();
632         return WTF::nullopt;
633     }, [&](RefPtr<ResolvableTypeReference>& resolvingType) -> Optional<UniqueRef<AST::UnnamedType>> {
634         return matchAndCommit(namedType, resolvingType->resolvableType());
635     }));
636 }
637
638 static Optional<UniqueRef<AST::UnnamedType>> commit(ResolvingType& resolvingType)
639 {
640     return resolvingType.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& unnamedType) -> Optional<UniqueRef<AST::UnnamedType>> {
641         return unnamedType->clone();
642     }, [&](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> Optional<UniqueRef<AST::UnnamedType>> {
643         if (!resolvableTypeReference->resolvableType().maybeResolvedType())
644             return commit(resolvableTypeReference->resolvableType());
645         return resolvableTypeReference->resolvableType().resolvedType().clone();
646     }));
647 }
648
649 void Checker::visit(AST::EnumerationDefinition& enumerationDefinition)
650 {
651     auto* baseType = ([&]() -> AST::NativeTypeDeclaration* {
652         checkErrorAndVisit(enumerationDefinition.type());
653         auto& baseType = enumerationDefinition.type().unifyNode();
654         if (!is<AST::NamedType>(baseType))
655             return nullptr;
656         auto& namedType = downcast<AST::NamedType>(baseType);
657         if (!is<AST::NativeTypeDeclaration>(namedType))
658             return nullptr;
659         auto& nativeTypeDeclaration = downcast<AST::NativeTypeDeclaration>(namedType);
660         if (!nativeTypeDeclaration.isInt())
661             return nullptr;
662         return &nativeTypeDeclaration;
663     })();
664     if (!baseType) {
665         setError();
666         return;
667     }
668
669     auto enumerationMembers = enumerationDefinition.enumerationMembers();
670
671     auto matchAndCommitMember = [&](AST::EnumerationMember& member) -> bool {
672         return member.value()->visit(WTF::makeVisitor([&](AST::Expression& value) -> bool {
673             auto valueInfo = recurseAndGetInfo(value);
674             if (!valueInfo)
675                 return false;
676             return static_cast<bool>(matchAndCommit(valueInfo->resolvingType, *baseType));
677         }));
678     };
679
680     for (auto& member : enumerationMembers) {
681         if (!member.get().value())
682             continue;
683
684         if (!matchAndCommitMember(member)) {
685             setError();
686             return;
687         }
688     }
689
690     int64_t nextValue = 0;
691     for (auto& member : enumerationMembers) {
692         if (member.get().value()) {
693             auto value = member.get().value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) -> int64_t {
694                 return integerLiteral.valueForSelectedType();
695             }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) -> int64_t {
696                 return unsignedIntegerLiteral.valueForSelectedType();
697             }, [&](auto&) -> int64_t {
698                 ASSERT_NOT_REACHED();
699                 return 0;
700             }));
701             nextValue = baseType->successor()(value);
702         } else {
703             if (nextValue > std::numeric_limits<int>::max()) {
704                 ASSERT(nextValue <= std::numeric_limits<unsigned>::max());
705                 member.get().setValue(AST::ConstantExpression(AST::UnsignedIntegerLiteral(Lexer::Token(member.get().origin()), static_cast<unsigned>(nextValue))));
706             }
707             ASSERT(nextValue >= std::numeric_limits<int>::min());
708             member.get().setValue(AST::ConstantExpression(AST::IntegerLiteral(Lexer::Token(member.get().origin()), static_cast<int>(nextValue))));
709
710             if (!matchAndCommitMember(member)) {
711                 setError();
712                 return;
713             }
714
715             nextValue = baseType->successor()(nextValue);
716         }
717     }
718
719     auto getValue = [&](AST::EnumerationMember& member) -> int64_t {
720         ASSERT(member.value());
721         auto value = member.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) -> int64_t {
722             return integerLiteral.value();
723         }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) -> int64_t {
724             return unsignedIntegerLiteral.value();
725         }, [&](auto&) -> int64_t {
726             ASSERT_NOT_REACHED();
727             return 0;
728         }));
729         return value;
730     };
731
732     for (size_t i = 0; i < enumerationMembers.size(); ++i) {
733         auto value = getValue(enumerationMembers[i].get());
734         for (size_t j = i + 1; j < enumerationMembers.size(); ++j) {
735             auto otherValue = getValue(enumerationMembers[j].get());
736             if (value == otherValue) {
737                 setError();
738                 return;
739             }
740         }
741     }
742
743     bool foundZero = false;
744     for (auto& member : enumerationMembers) {
745         if (!getValue(member.get())) {
746             foundZero = true;
747             break;
748         }
749     }
750     if (!foundZero) {
751         setError();
752         return;
753     }
754 }
755
756 void Checker::visit(AST::TypeReference& typeReference)
757 {
758     ASSERT(typeReference.maybeResolvedType());
759
760     for (auto& typeArgument : typeReference.typeArguments())
761         checkErrorAndVisit(typeArgument);
762 }
763
764 auto Checker::recurseAndGetInfo(AST::Expression& expression, bool requiresLeftValue) -> Optional<RecurseInfo>
765 {
766     Visitor::visit(expression);
767     if (error())
768         return WTF::nullopt;
769     return getInfo(expression, requiresLeftValue);
770 }
771
772 auto Checker::getInfo(AST::Expression& expression, bool requiresLeftValue) -> Optional<RecurseInfo>
773 {
774     auto typeIterator = m_typeMap.find(&expression);
775     ASSERT(typeIterator != m_typeMap.end());
776
777     auto typeAnnotationIterator = m_typeAnnotations.find(&expression);
778     ASSERT(typeAnnotationIterator != m_typeAnnotations.end());
779     if (requiresLeftValue && typeAnnotationIterator->value.isRightValue()) {
780         setError();
781         return WTF::nullopt;
782     }
783     return {{ *typeIterator->value, typeAnnotationIterator->value }};
784 }
785
786 void Checker::visit(AST::VariableDeclaration& variableDeclaration)
787 {
788     // ReadModifyWriteExpressions are the only place where anonymous variables exist,
789     // and that doesn't recurse on the anonymous variables, so we can assume the variable has a type.
790     checkErrorAndVisit(*variableDeclaration.type());
791     if (variableDeclaration.initializer()) {
792         auto& lhsType = *variableDeclaration.type();
793         auto initializerInfo = recurseAndGetInfo(*variableDeclaration.initializer());
794         if (!initializerInfo)
795             return;
796         if (!matchAndCommit(initializerInfo->resolvingType, lhsType)) {
797             setError();
798             return;
799         }
800     }
801 }
802
803 void Checker::assignType(AST::Expression& expression, UniqueRef<AST::UnnamedType>&& unnamedType, AST::TypeAnnotation typeAnnotation = AST::RightValue())
804 {
805     auto addResult = m_typeMap.add(&expression, std::make_unique<ResolvingType>(WTFMove(unnamedType)));
806     ASSERT_UNUSED(addResult, addResult.isNewEntry);
807     auto typeAnnotationAddResult = m_typeAnnotations.add(&expression, WTFMove(typeAnnotation));
808     ASSERT_UNUSED(typeAnnotationAddResult, typeAnnotationAddResult.isNewEntry);
809 }
810
811 void Checker::assignType(AST::Expression& expression, RefPtr<ResolvableTypeReference>&& resolvableTypeReference, AST::TypeAnnotation typeAnnotation = AST::RightValue())
812 {
813     auto addResult = m_typeMap.add(&expression, std::make_unique<ResolvingType>(WTFMove(resolvableTypeReference)));
814     ASSERT_UNUSED(addResult, addResult.isNewEntry);
815     auto typeAnnotationAddResult = m_typeAnnotations.add(&expression, WTFMove(typeAnnotation));
816     ASSERT_UNUSED(typeAnnotationAddResult, typeAnnotationAddResult.isNewEntry);
817 }
818
819 void Checker::forwardType(AST::Expression& expression, ResolvingType& resolvingType, AST::TypeAnnotation typeAnnotation = AST::RightValue())
820 {
821     resolvingType.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& result) {
822         auto addResult = m_typeMap.add(&expression, std::make_unique<ResolvingType>(result->clone()));
823         ASSERT_UNUSED(addResult, addResult.isNewEntry);
824     }, [&](RefPtr<ResolvableTypeReference>& result) {
825         auto addResult = m_typeMap.add(&expression, std::make_unique<ResolvingType>(result.copyRef()));
826         ASSERT_UNUSED(addResult, addResult.isNewEntry);
827     }));
828     auto typeAnnotationAddResult = m_typeAnnotations.add(&expression, WTFMove(typeAnnotation));
829     ASSERT_UNUSED(typeAnnotationAddResult, typeAnnotationAddResult.isNewEntry);
830 }
831
832 void Checker::visit(AST::AssignmentExpression& assignmentExpression)
833 {
834     auto leftInfo = recurseAndGetInfo(assignmentExpression.left(), true);
835     if (!leftInfo)
836         return;
837
838     if (leftInfo->typeAnnotation.isRightValue()) {
839         setError();
840         return;
841     }
842
843     auto rightInfo = recurseAndGetInfo(assignmentExpression.right());
844     if (!rightInfo)
845         return;
846
847     auto resultType = matchAndCommit(leftInfo->resolvingType, rightInfo->resolvingType);
848     if (!resultType) {
849         setError();
850         return;
851     }
852
853     assignType(assignmentExpression, WTFMove(*resultType));
854 }
855
856 void Checker::visit(AST::ReadModifyWriteExpression& readModifyWriteExpression)
857 {
858     auto leftValueInfo = recurseAndGetInfo(readModifyWriteExpression.leftValue(), true);
859     if (!leftValueInfo)
860         return;
861
862     readModifyWriteExpression.oldValue().setType(leftValueInfo->resolvingType.getUnnamedType()->clone());
863
864     auto newValueInfo = recurseAndGetInfo(readModifyWriteExpression.newValueExpression());
865     if (!newValueInfo)
866         return;
867
868     if (Optional<UniqueRef<AST::UnnamedType>> matchedType = matchAndCommit(leftValueInfo->resolvingType, newValueInfo->resolvingType))
869         readModifyWriteExpression.newValue().setType(WTFMove(matchedType.value()));
870     else {
871         setError();
872         return;
873     }
874
875     auto resultInfo = recurseAndGetInfo(readModifyWriteExpression.resultExpression());
876     if (!resultInfo)
877         return;
878
879     forwardType(readModifyWriteExpression, resultInfo->resolvingType);
880 }
881
882 static AST::UnnamedType* getUnnamedType(ResolvingType& resolvingType)
883 {
884     return resolvingType.visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& type) -> AST::UnnamedType* {
885         return &type;
886     }, [](RefPtr<ResolvableTypeReference>& type) -> AST::UnnamedType* {
887         // FIXME: If the type isn't committed, should we just commit() it now?
888         return type->resolvableType().maybeResolvedType();
889     }));
890 }
891
892 void Checker::visit(AST::DereferenceExpression& dereferenceExpression)
893 {
894     auto pointerInfo = recurseAndGetInfo(dereferenceExpression.pointer());
895     if (!pointerInfo)
896         return;
897
898     auto* unnamedType = getUnnamedType(pointerInfo->resolvingType);
899
900     auto* pointerType = ([&](AST::UnnamedType* unnamedType) -> AST::PointerType* {
901         if (!unnamedType)
902             return nullptr;
903         auto& unifyNode = unnamedType->unifyNode();
904         if (!is<AST::UnnamedType>(unifyNode))
905             return nullptr;
906         auto& unnamedUnifyType = downcast<AST::UnnamedType>(unifyNode);
907         if (!is<AST::PointerType>(unnamedUnifyType))
908             return nullptr;
909         return &downcast<AST::PointerType>(unnamedUnifyType);
910     })(unnamedType);
911     if (!pointerType) {
912         setError();
913         return;
914     }
915
916     assignType(dereferenceExpression, pointerType->elementType().clone(), AST::LeftValue { pointerType->addressSpace() });
917 }
918
919 void Checker::visit(AST::MakePointerExpression& makePointerExpression)
920 {
921     auto leftValueInfo = recurseAndGetInfo(makePointerExpression.leftValue(), true);
922     if (!leftValueInfo)
923         return;
924
925     auto leftAddressSpace = leftValueInfo->typeAnnotation.leftAddressSpace();
926     if (!leftAddressSpace) {
927         setError();
928         return;
929     }
930
931     auto* leftValueType = getUnnamedType(leftValueInfo->resolvingType);
932     if (!leftValueType) {
933         setError();
934         return;
935     }
936
937     assignType(makePointerExpression, makeUniqueRef<AST::PointerType>(Lexer::Token(makePointerExpression.origin()), *leftAddressSpace, leftValueType->clone()));
938 }
939
940 void Checker::visit(AST::MakeArrayReferenceExpression& makeArrayReferenceExpression)
941 {
942     auto leftValueInfo = recurseAndGetInfo(makeArrayReferenceExpression.leftValue());
943     if (!leftValueInfo)
944         return;
945
946     auto* leftValueType = getUnnamedType(leftValueInfo->resolvingType);
947     if (!leftValueType) {
948         setError();
949         return;
950     }
951
952     auto& unifyNode = leftValueType->unifyNode();
953     if (is<AST::UnnamedType>(unifyNode)) {
954         auto& unnamedType = downcast<AST::UnnamedType>(unifyNode);
955         if (is<AST::PointerType>(unnamedType)) {
956             auto& pointerType = downcast<AST::PointerType>(unnamedType);
957             // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198163 Save the fact that we're not targetting the item; we're targetting the item's inner element.
958             assignType(makeArrayReferenceExpression, makeUniqueRef<AST::ArrayReferenceType>(Lexer::Token(makeArrayReferenceExpression.origin()), pointerType.addressSpace(), pointerType.elementType().clone()));
959             return;
960         }
961
962         auto leftAddressSpace = leftValueInfo->typeAnnotation.leftAddressSpace();
963         if (!leftAddressSpace) {
964             setError();
965             return;
966         }
967
968         if (is<AST::ArrayType>(unnamedType)) {
969             auto& arrayType = downcast<AST::ArrayType>(unnamedType);
970             // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198163 Save the number of elements.
971             assignType(makeArrayReferenceExpression, makeUniqueRef<AST::ArrayReferenceType>(Lexer::Token(makeArrayReferenceExpression.origin()), *leftAddressSpace, arrayType.type().clone()));
972             return;
973         }
974     }
975
976     auto leftAddressSpace = leftValueInfo->typeAnnotation.leftAddressSpace();
977     if (!leftAddressSpace) {
978         setError();
979         return;
980     }
981
982     assignType(makeArrayReferenceExpression, makeUniqueRef<AST::ArrayReferenceType>(Lexer::Token(makeArrayReferenceExpression.origin()), *leftAddressSpace, leftValueType->clone()));
983 }
984
985 static Optional<UniqueRef<AST::UnnamedType>> argumentTypeForAndOverload(AST::UnnamedType& baseType, AST::AddressSpace addressSpace)
986 {
987     auto& unifyNode = baseType.unifyNode();
988     if (is<AST::NamedType>(unifyNode)) {
989         auto& namedType = downcast<AST::NamedType>(unifyNode);
990         return { makeUniqueRef<AST::PointerType>(Lexer::Token(namedType.origin()), addressSpace, AST::TypeReference::wrap(Lexer::Token(namedType.origin()), namedType)) };
991     }
992
993     auto& unnamedType = downcast<AST::UnnamedType>(unifyNode);
994
995     if (is<AST::ArrayReferenceType>(unnamedType))
996         return unnamedType.clone();
997
998     if (is<AST::ArrayType>(unnamedType))
999         return { makeUniqueRef<AST::ArrayReferenceType>(Lexer::Token(unnamedType.origin()), addressSpace, downcast<AST::ArrayType>(unnamedType).type().clone()) };
1000
1001     if (is<AST::PointerType>(unnamedType))
1002         return WTF::nullopt;
1003
1004     return { makeUniqueRef<AST::PointerType>(Lexer::Token(unnamedType.origin()), addressSpace, unnamedType.clone()) };
1005 }
1006
1007 void Checker::finishVisiting(AST::PropertyAccessExpression& propertyAccessExpression, ResolvingType* additionalArgumentType)
1008 {
1009     auto baseInfo = recurseAndGetInfo(propertyAccessExpression.base());
1010     if (!baseInfo)
1011         return;
1012     auto baseUnnamedType = commit(baseInfo->resolvingType);
1013     if (!baseUnnamedType)
1014         return;
1015
1016     AST::FunctionDeclaration* getterFunction = nullptr;
1017     AST::UnnamedType* getterReturnType = nullptr;
1018     {
1019         Vector<std::reference_wrapper<ResolvingType>> getterArgumentTypes { baseInfo->resolvingType };
1020         if (additionalArgumentType)
1021             getterArgumentTypes.append(*additionalArgumentType);
1022         if ((getterFunction = resolveFunction(m_program, propertyAccessExpression.possibleGetterOverloads(), getterArgumentTypes, propertyAccessExpression.getterFunctionName(), propertyAccessExpression.origin(), m_intrinsics)))
1023             getterReturnType = &getterFunction->type();
1024     }
1025
1026     AST::FunctionDeclaration* anderFunction = nullptr;
1027     AST::UnnamedType* anderReturnType = nullptr;
1028     auto leftAddressSpace = baseInfo->typeAnnotation.leftAddressSpace();
1029     if (leftAddressSpace) {
1030         if (auto argumentTypeForAndOverload = WHLSL::argumentTypeForAndOverload(*baseUnnamedType, *leftAddressSpace)) {
1031             ResolvingType argumentType = { WTFMove(*argumentTypeForAndOverload) };
1032             Vector<std::reference_wrapper<ResolvingType>> anderArgumentTypes { argumentType };
1033             if (additionalArgumentType)
1034                 anderArgumentTypes.append(*additionalArgumentType);
1035             if ((anderFunction = resolveFunction(m_program, propertyAccessExpression.possibleAnderOverloads(), anderArgumentTypes, propertyAccessExpression.anderFunctionName(), propertyAccessExpression.origin(), m_intrinsics)))
1036                 anderReturnType = &downcast<AST::PointerType>(anderFunction->type()).elementType(); // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198164 Enforce the return of anders will always be a pointer
1037         }
1038     }
1039
1040     AST::FunctionDeclaration* threadAnderFunction = nullptr;
1041     AST::UnnamedType* threadAnderReturnType = nullptr;
1042     if (auto argumentTypeForAndOverload = WHLSL::argumentTypeForAndOverload(*baseUnnamedType, AST::AddressSpace::Thread)) {
1043         ResolvingType argumentType = { makeUniqueRef<AST::PointerType>(Lexer::Token(propertyAccessExpression.origin()), AST::AddressSpace::Thread, baseUnnamedType->get().clone()) };
1044         Vector<std::reference_wrapper<ResolvingType>> threadAnderArgumentTypes { argumentType };
1045         if (additionalArgumentType)
1046             threadAnderArgumentTypes.append(*additionalArgumentType);
1047         if ((threadAnderFunction = resolveFunction(m_program, propertyAccessExpression.possibleAnderOverloads(), threadAnderArgumentTypes, propertyAccessExpression.anderFunctionName(), propertyAccessExpression.origin(), m_intrinsics)))
1048             threadAnderReturnType = &downcast<AST::PointerType>(threadAnderFunction->type()).elementType(); // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198164 Enforce the return of anders will always be a pointer
1049     }
1050
1051     if (leftAddressSpace && !anderFunction && !getterFunction) {
1052         setError();
1053         return;
1054     }
1055
1056     if (!leftAddressSpace && !threadAnderFunction && !getterFunction) {
1057         setError();
1058         return;
1059     }
1060
1061     if (threadAnderFunction && getterFunction) {
1062         setError();
1063         return;
1064     }
1065
1066     if (anderFunction && threadAnderFunction && !matches(*anderReturnType, *threadAnderReturnType)) {
1067         setError();
1068         return;
1069     }
1070
1071     if (getterFunction && anderFunction && !matches(*getterReturnType, *anderReturnType)) {
1072         setError();
1073         return;
1074     }
1075
1076     if (getterFunction && threadAnderFunction && !matches(*getterReturnType, *threadAnderReturnType)) {
1077         setError();
1078         return;
1079     }
1080
1081     AST::UnnamedType* fieldType = getterReturnType ? getterReturnType : anderReturnType ? anderReturnType : threadAnderReturnType;
1082
1083     AST::FunctionDeclaration* setterFunction = nullptr;
1084     AST::UnnamedType* setterReturnType = nullptr;
1085     {
1086         ResolvingType fieldResolvingType(fieldType->clone());
1087         Vector<std::reference_wrapper<ResolvingType>> setterArgumentTypes { baseInfo->resolvingType };
1088         if (additionalArgumentType)
1089             setterArgumentTypes.append(*additionalArgumentType);
1090         setterArgumentTypes.append(fieldResolvingType);
1091         setterFunction = resolveFunction(m_program, propertyAccessExpression.possibleSetterOverloads(), setterArgumentTypes, propertyAccessExpression.setterFunctionName(), propertyAccessExpression.origin(), m_intrinsics);
1092         if (setterFunction)
1093             setterReturnType = &setterFunction->type();
1094     }
1095
1096     if (setterFunction && !getterFunction) {
1097         setError();
1098         return;
1099     }
1100
1101     propertyAccessExpression.setGetterFunction(getterFunction);
1102     propertyAccessExpression.setAnderFunction(anderFunction);
1103     propertyAccessExpression.setThreadAnderFunction(threadAnderFunction);
1104     propertyAccessExpression.setSetterFunction(setterFunction);
1105
1106     AST::TypeAnnotation typeAnnotation = AST::RightValue();
1107     if (auto leftAddressSpace = baseInfo->typeAnnotation.leftAddressSpace()) {
1108         if (anderFunction)
1109             typeAnnotation = AST::LeftValue { *leftAddressSpace };
1110         else if (setterFunction)
1111             typeAnnotation = AST::AbstractLeftValue();
1112     } else if (!baseInfo->typeAnnotation.isRightValue() && (setterFunction || threadAnderFunction))
1113         typeAnnotation = AST::AbstractLeftValue();
1114     assignType(propertyAccessExpression, fieldType->clone(), WTFMove(typeAnnotation));
1115 }
1116
1117 void Checker::visit(AST::DotExpression& dotExpression)
1118 {
1119     finishVisiting(dotExpression);
1120 }
1121
1122 void Checker::visit(AST::IndexExpression& indexExpression)
1123 {
1124     auto baseInfo = recurseAndGetInfo(indexExpression.indexExpression());
1125     if (!baseInfo)
1126         return;
1127     finishVisiting(indexExpression, &baseInfo->resolvingType);
1128 }
1129
1130 void Checker::visit(AST::VariableReference& variableReference)
1131 {
1132     ASSERT(variableReference.variable());
1133     ASSERT(variableReference.variable()->type());
1134     
1135     AST::TypeAnnotation typeAnnotation = AST::RightValue();
1136     typeAnnotation = AST::LeftValue { AST::AddressSpace::Thread };
1137     assignType(variableReference, variableReference.variable()->type()->clone(), WTFMove(typeAnnotation));
1138 }
1139
1140 void Checker::visit(AST::Return& returnStatement)
1141 {
1142     ASSERT(returnStatement.function());
1143     if (returnStatement.value()) {
1144         auto valueInfo = recurseAndGetInfo(*returnStatement.value());
1145         if (!valueInfo)
1146             return;
1147         if (!matchAndCommit(valueInfo->resolvingType, returnStatement.function()->type()))
1148             setError();
1149         return;
1150     }
1151
1152     if (!matches(returnStatement.function()->type(), m_intrinsics.voidType()))
1153         setError();
1154 }
1155
1156 void Checker::visit(AST::PointerType&)
1157 {
1158     // Following pointer types can cause infinite loops because of data structures
1159     // like linked lists.
1160     // FIXME: Make sure this function should be empty
1161 }
1162
1163 void Checker::visit(AST::ArrayReferenceType&)
1164 {
1165     // Following array reference types can cause infinite loops because of data
1166     // structures like linked lists.
1167     // FIXME: Make sure this function should be empty
1168 }
1169
1170 void Checker::visit(AST::IntegerLiteral& integerLiteral)
1171 {
1172     assignType(integerLiteral, adoptRef(*new ResolvableTypeReference(integerLiteral.type())));
1173 }
1174
1175 void Checker::visit(AST::UnsignedIntegerLiteral& unsignedIntegerLiteral)
1176 {
1177     assignType(unsignedIntegerLiteral, adoptRef(*new ResolvableTypeReference(unsignedIntegerLiteral.type())));
1178 }
1179
1180 void Checker::visit(AST::FloatLiteral& floatLiteral)
1181 {
1182     assignType(floatLiteral, adoptRef(*new ResolvableTypeReference(floatLiteral.type())));
1183 }
1184
1185 void Checker::visit(AST::NullLiteral& nullLiteral)
1186 {
1187     assignType(nullLiteral, adoptRef(*new ResolvableTypeReference(nullLiteral.type())));
1188 }
1189
1190 void Checker::visit(AST::BooleanLiteral& booleanLiteral)
1191 {
1192     assignType(booleanLiteral, AST::TypeReference::wrap(Lexer::Token(booleanLiteral.origin()), m_intrinsics.boolType()));
1193 }
1194
1195 void Checker::visit(AST::EnumerationMemberLiteral& enumerationMemberLiteral)
1196 {
1197     ASSERT(enumerationMemberLiteral.enumerationDefinition());
1198     auto& enumerationDefinition = *enumerationMemberLiteral.enumerationDefinition();
1199     assignType(enumerationMemberLiteral, AST::TypeReference::wrap(Lexer::Token(enumerationMemberLiteral.origin()), enumerationDefinition));
1200 }
1201
1202 bool Checker::isBoolType(ResolvingType& resolvingType)
1203 {
1204     return resolvingType.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& left) -> bool {
1205         return matches(left, m_intrinsics.boolType());
1206     }, [&](RefPtr<ResolvableTypeReference>& left) -> bool {
1207         return static_cast<bool>(matchAndCommit(m_intrinsics.boolType(), left->resolvableType()));
1208     }));
1209 }
1210
1211 bool Checker::recurseAndRequireBoolType(AST::Expression& expression)
1212 {
1213     auto expressionInfo = recurseAndGetInfo(expression);
1214     if (!expressionInfo)
1215         return false;
1216     if (!isBoolType(expressionInfo->resolvingType)) {
1217         setError();
1218         return false;
1219     }
1220     return true;
1221 }
1222
1223 void Checker::visit(AST::LogicalNotExpression& logicalNotExpression)
1224 {
1225     if (!recurseAndRequireBoolType(logicalNotExpression.operand()))
1226         return;
1227     assignType(logicalNotExpression, AST::TypeReference::wrap(Lexer::Token(logicalNotExpression.origin()), m_intrinsics.boolType()));
1228 }
1229
1230 void Checker::visit(AST::LogicalExpression& logicalExpression)
1231 {
1232     if (!recurseAndRequireBoolType(logicalExpression.left()))
1233         return;
1234     if (!recurseAndRequireBoolType(logicalExpression.right()))
1235         return;
1236     assignType(logicalExpression, AST::TypeReference::wrap(Lexer::Token(logicalExpression.origin()), m_intrinsics.boolType()));
1237 }
1238
1239 void Checker::visit(AST::IfStatement& ifStatement)
1240 {
1241     if (!recurseAndRequireBoolType(ifStatement.conditional()))
1242         return;
1243     checkErrorAndVisit(ifStatement.body());
1244     if (ifStatement.elseBody())
1245         checkErrorAndVisit(*ifStatement.elseBody());
1246 }
1247
1248 void Checker::visit(AST::WhileLoop& whileLoop)
1249 {
1250     if (!recurseAndRequireBoolType(whileLoop.conditional()))
1251         return;
1252     checkErrorAndVisit(whileLoop.body());
1253 }
1254
1255 void Checker::visit(AST::DoWhileLoop& doWhileLoop)
1256 {
1257     checkErrorAndVisit(doWhileLoop.body());
1258     recurseAndRequireBoolType(doWhileLoop.conditional());
1259 }
1260
1261 void Checker::visit(AST::ForLoop& forLoop)
1262 {
1263     WTF::visit(WTF::makeVisitor([&](UniqueRef<AST::Statement>& statement) {
1264         checkErrorAndVisit(statement);
1265     }, [&](UniqueRef<AST::Expression>& expression) {
1266         checkErrorAndVisit(expression);
1267     }), forLoop.initialization());
1268     if (error())
1269         return;
1270     if (forLoop.condition()) {
1271         if (!recurseAndRequireBoolType(*forLoop.condition()))
1272             return;
1273     }
1274     if (forLoop.increment())
1275         checkErrorAndVisit(*forLoop.increment());
1276     checkErrorAndVisit(forLoop.body());
1277 }
1278
1279 void Checker::visit(AST::SwitchStatement& switchStatement)
1280 {
1281     auto* valueType = ([&]() -> AST::NamedType* {
1282         auto valueInfo = recurseAndGetInfo(switchStatement.value());
1283         if (!valueInfo)
1284             return nullptr;
1285         auto* valueType = getUnnamedType(valueInfo->resolvingType);
1286         if (!valueType)
1287             return nullptr;
1288         auto& valueUnifyNode = valueType->unifyNode();
1289         if (!is<AST::NamedType>(valueUnifyNode))
1290             return nullptr;
1291         auto& valueNamedUnifyNode = downcast<AST::NamedType>(valueUnifyNode);
1292         if (!(is<AST::NativeTypeDeclaration>(valueNamedUnifyNode) && downcast<AST::NativeTypeDeclaration>(valueNamedUnifyNode).isInt())
1293             && !is<AST::EnumerationDefinition>(valueNamedUnifyNode))
1294             return nullptr;
1295         return &valueNamedUnifyNode;
1296     })();
1297     if (!valueType) {
1298         setError();
1299         return;
1300     }
1301
1302     bool hasDefault = false;
1303     for (auto& switchCase : switchStatement.switchCases()) {
1304         checkErrorAndVisit(switchCase.block());
1305         if (!switchCase.value()) {
1306             hasDefault = true;
1307             continue;
1308         }
1309         auto success = switchCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) -> bool {
1310             return static_cast<bool>(matchAndCommit(*valueType, integerLiteral.type()));
1311         }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) -> bool {
1312             return static_cast<bool>(matchAndCommit(*valueType, unsignedIntegerLiteral.type()));
1313         }, [&](AST::FloatLiteral& floatLiteral) -> bool {
1314             return static_cast<bool>(matchAndCommit(*valueType, floatLiteral.type()));
1315         }, [&](AST::NullLiteral& nullLiteral) -> bool {
1316             return static_cast<bool>(matchAndCommit(*valueType, nullLiteral.type()));
1317         }, [&](AST::BooleanLiteral&) -> bool {
1318             return matches(*valueType, m_intrinsics.boolType());
1319         }, [&](AST::EnumerationMemberLiteral& enumerationMemberLiteral) -> bool {
1320             ASSERT(enumerationMemberLiteral.enumerationDefinition());
1321             return matches(*valueType, *enumerationMemberLiteral.enumerationDefinition());
1322         }));
1323         if (!success) {
1324             setError();
1325             return;
1326         }
1327     }
1328
1329     for (size_t i = 0; i < switchStatement.switchCases().size(); ++i) {
1330         auto& firstCase = switchStatement.switchCases()[i];
1331         for (size_t j = i + 1; j < switchStatement.switchCases().size(); ++j) {
1332             auto& secondCase = switchStatement.switchCases()[j];
1333             
1334             if (static_cast<bool>(firstCase.value()) != static_cast<bool>(secondCase.value()))
1335                 continue;
1336
1337             if (!static_cast<bool>(firstCase.value())) {
1338                 setError();
1339                 return;
1340             }
1341
1342             auto success = firstCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& firstIntegerLiteral) -> bool {
1343                 return secondCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& secondIntegerLiteral) -> bool {
1344                     return firstIntegerLiteral.value() != secondIntegerLiteral.value();
1345                 }, [&](AST::UnsignedIntegerLiteral& secondUnsignedIntegerLiteral) -> bool {
1346                     return static_cast<int64_t>(firstIntegerLiteral.value()) != static_cast<int64_t>(secondUnsignedIntegerLiteral.value());
1347                 }, [](auto&) -> bool {
1348                     return true;
1349                 }));
1350             }, [&](AST::UnsignedIntegerLiteral& firstUnsignedIntegerLiteral) -> bool {
1351                 return secondCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& secondIntegerLiteral) -> bool {
1352                     return static_cast<int64_t>(firstUnsignedIntegerLiteral.value()) != static_cast<int64_t>(secondIntegerLiteral.value());
1353                 }, [&](AST::UnsignedIntegerLiteral& secondUnsignedIntegerLiteral) -> bool {
1354                     return firstUnsignedIntegerLiteral.value() != secondUnsignedIntegerLiteral.value();
1355                 }, [](auto&) -> bool {
1356                     return true;
1357                 }));
1358             }, [&](AST::EnumerationMemberLiteral& firstEnumerationMemberLiteral) -> bool {
1359                 return secondCase.value()->visit(WTF::makeVisitor([&](AST::EnumerationMemberLiteral& secondEnumerationMemberLiteral) -> bool {
1360                     ASSERT(firstEnumerationMemberLiteral.enumerationMember());
1361                     ASSERT(secondEnumerationMemberLiteral.enumerationMember());
1362                     return firstEnumerationMemberLiteral.enumerationMember() != secondEnumerationMemberLiteral.enumerationMember();
1363                 }, [](auto&) -> bool {
1364                     return true;
1365                 }));
1366             }, [](auto&) -> bool {
1367                 return true;
1368             }));
1369             if (!success) {
1370                 setError();
1371                 return;
1372             }
1373         }
1374     }
1375
1376     if (!hasDefault) {
1377         if (is<AST::NativeTypeDeclaration>(*valueType)) {
1378             HashSet<int64_t> values;
1379             bool zeroValueExists;
1380             for (auto& switchCase : switchStatement.switchCases()) {
1381                 auto value = switchCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) -> int64_t {
1382                     return integerLiteral.valueForSelectedType();
1383                 }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) -> int64_t {
1384                     return unsignedIntegerLiteral.valueForSelectedType();
1385                 }, [](auto&) -> int64_t {
1386                     ASSERT_NOT_REACHED();
1387                     return 0;
1388                 }));
1389                 if (!value)
1390                     zeroValueExists = true;
1391                 else
1392                     values.add(value);
1393             }
1394             bool success = true;
1395             downcast<AST::NativeTypeDeclaration>(*valueType).iterateAllValues([&](int64_t value) -> bool {
1396                 if (!value) {
1397                     if (!zeroValueExists) {
1398                         success = false;
1399                         return true;
1400                     }
1401                     return false;
1402                 }
1403                 if (!values.contains(value)) {
1404                     success = false;
1405                     return true;
1406                 }
1407                 return false;
1408             });
1409             if (!success) {
1410                 setError();
1411                 return;
1412             }
1413         } else {
1414             HashSet<AST::EnumerationMember*> values;
1415             for (auto& switchCase : switchStatement.switchCases()) {
1416                 switchCase.value()->visit(WTF::makeVisitor([&](AST::EnumerationMemberLiteral& enumerationMemberLiteral) {
1417                     ASSERT(enumerationMemberLiteral.enumerationMember());
1418                     values.add(enumerationMemberLiteral.enumerationMember());
1419                 }, [](auto&) {
1420                     ASSERT_NOT_REACHED();
1421                 }));
1422             }
1423             for (auto& enumerationMember : downcast<AST::EnumerationDefinition>(*valueType).enumerationMembers()) {
1424                 if (!values.contains(&enumerationMember.get())) {
1425                     setError();
1426                     return;
1427                 }
1428             }
1429         }
1430     }
1431 }
1432
1433 void Checker::visit(AST::CommaExpression& commaExpression)
1434 {
1435     ASSERT(commaExpression.list().size() > 0);
1436     Visitor::visit(commaExpression);
1437     if (error())
1438         return;
1439     auto lastInfo = getInfo(commaExpression.list().last());
1440     forwardType(commaExpression, lastInfo->resolvingType);
1441 }
1442
1443 void Checker::visit(AST::TernaryExpression& ternaryExpression)
1444 {
1445     auto predicateInfo = recurseAndRequireBoolType(ternaryExpression.predicate());
1446     if (!predicateInfo)
1447         return;
1448
1449     auto bodyInfo = recurseAndGetInfo(ternaryExpression.bodyExpression());
1450     auto elseInfo = recurseAndGetInfo(ternaryExpression.elseExpression());
1451     
1452     auto resultType = matchAndCommit(bodyInfo->resolvingType, elseInfo->resolvingType);
1453     if (!resultType) {
1454         setError();
1455         return;
1456     }
1457
1458     assignType(ternaryExpression, WTFMove(*resultType));
1459 }
1460
1461 void Checker::visit(AST::CallExpression& callExpression)
1462 {
1463     Vector<std::reference_wrapper<ResolvingType>> types;
1464     types.reserveInitialCapacity(callExpression.arguments().size());
1465     for (auto& argument : callExpression.arguments()) {
1466         auto argumentInfo = recurseAndGetInfo(argument);
1467         if (!argumentInfo)
1468             return;
1469         types.uncheckedAppend(argumentInfo->resolvingType);
1470     }
1471     // Don't recurse on the castReturnType, because it's guaranteed to be a NamedType, which will get visited later.
1472     // We don't want to recurse to the same node twice.
1473
1474     ASSERT(callExpression.hasOverloads());
1475     auto* function = resolveFunction(m_program, *callExpression.overloads(), types, callExpression.name(), callExpression.origin(), m_intrinsics, callExpression.castReturnType());
1476     if (!function) {
1477         setError();
1478         return;
1479     }
1480
1481     for (size_t i = 0; i < function->parameters().size(); ++i) {
1482         if (!matchAndCommit(types[i].get(), *function->parameters()[i]->type())) {
1483             setError();
1484             return;
1485         }
1486     }
1487
1488     callExpression.setFunction(*function);
1489
1490     assignType(callExpression, function->type().clone());
1491 }
1492
1493 bool check(Program& program)
1494 {
1495     Checker checker(program.intrinsics(), program);
1496     checker.checkErrorAndVisit(program);
1497     if (checker.error())
1498         return false;
1499     return checker.assignTypes();
1500 }
1501
1502 } // namespace WHLSL
1503
1504 } // namespace WebCore
1505
1506 #endif // ENABLE(WEBGPU)