d894b920e66746b867427f74dd9adbd3ad7dff06
[WebKit-https.git] / Source / WebCore / Modules / webgpu / WHLSL / WHLSLChecker.cpp
1 /*
2  * Copyright (C) 2019 Apple Inc. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  * 1. Redistributions of source code must retain the above copyright
8  *    notice, this list of conditions and the following disclaimer.
9  * 2. Redistributions in binary form must reproduce the above copyright
10  *    notice, this list of conditions and the following disclaimer in the
11  *    documentation and/or other materials provided with the distribution.
12  *
13  * THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS ``AS IS''
14  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
15  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS
17  * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23  * THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 #include "config.h"
27 #include "WHLSLChecker.h"
28
29 #if ENABLE(WEBGPU)
30
31 #include "WHLSLArrayReferenceType.h"
32 #include "WHLSLArrayType.h"
33 #include "WHLSLAssignmentExpression.h"
34 #include "WHLSLCallExpression.h"
35 #include "WHLSLCommaExpression.h"
36 #include "WHLSLDereferenceExpression.h"
37 #include "WHLSLDoWhileLoop.h"
38 #include "WHLSLDotExpression.h"
39 #include "WHLSLEntryPointType.h"
40 #include "WHLSLForLoop.h"
41 #include "WHLSLGatherEntryPointItems.h"
42 #include "WHLSLIfStatement.h"
43 #include "WHLSLIndexExpression.h"
44 #include "WHLSLInferTypes.h"
45 #include "WHLSLLogicalExpression.h"
46 #include "WHLSLLogicalNotExpression.h"
47 #include "WHLSLMakeArrayReferenceExpression.h"
48 #include "WHLSLMakePointerExpression.h"
49 #include "WHLSLPointerType.h"
50 #include "WHLSLProgram.h"
51 #include "WHLSLReadModifyWriteExpression.h"
52 #include "WHLSLResolvableType.h"
53 #include "WHLSLResolveOverloadImpl.h"
54 #include "WHLSLResolvingType.h"
55 #include "WHLSLReturn.h"
56 #include "WHLSLSwitchStatement.h"
57 #include "WHLSLTernaryExpression.h"
58 #include "WHLSLVisitor.h"
59 #include "WHLSLWhileLoop.h"
60 #include <wtf/HashMap.h>
61 #include <wtf/HashSet.h>
62 #include <wtf/Ref.h>
63 #include <wtf/Vector.h>
64 #include <wtf/text/WTFString.h>
65
66 namespace WebCore {
67
68 namespace WHLSL {
69
70 class PODChecker : public Visitor {
71 public:
72     PODChecker() = default;
73
74     virtual ~PODChecker() = default;
75
76     void visit(AST::EnumerationDefinition& enumerationDefinition) override
77     {
78         Visitor::visit(enumerationDefinition);
79     }
80
81     void visit(AST::NativeTypeDeclaration& nativeTypeDeclaration) override
82     {
83         if (!nativeTypeDeclaration.isNumber()
84             && !nativeTypeDeclaration.isVector()
85             && !nativeTypeDeclaration.isMatrix())
86             setError();
87     }
88
89     void visit(AST::StructureDefinition& structureDefinition) override
90     {
91         Visitor::visit(structureDefinition);
92     }
93
94     void visit(AST::TypeDefinition& typeDefinition) override
95     {
96         Visitor::visit(typeDefinition);
97     }
98
99     void visit(AST::ArrayType& arrayType) override
100     {
101         Visitor::visit(arrayType);
102     }
103
104     void visit(AST::PointerType&) override
105     {
106         setError();
107     }
108
109     void visit(AST::ArrayReferenceType&) override
110     {
111         setError();
112     }
113
114     void visit(AST::TypeReference& typeReference) override
115     {
116         checkErrorAndVisit(typeReference.resolvedType());
117     }
118 };
119
120 static AST::NativeFunctionDeclaration resolveWithOperatorAnderIndexer(Lexer::Token origin, AST::ArrayReferenceType& firstArgument, const Intrinsics& intrinsics)
121 {
122     const bool isOperator = true;
123     auto returnType = makeUniqueRef<AST::PointerType>(Lexer::Token(origin), firstArgument.addressSpace(), firstArgument.elementType().clone());
124     AST::VariableDeclarations parameters;
125     parameters.append(makeUniqueRef<AST::VariableDeclaration>(Lexer::Token(origin), AST::Qualifiers(), firstArgument.clone(), String(), WTF::nullopt, WTF::nullopt));
126     parameters.append(makeUniqueRef<AST::VariableDeclaration>(Lexer::Token(origin), AST::Qualifiers(), UniqueRef<AST::UnnamedType>(AST::TypeReference::wrap(Lexer::Token(origin), intrinsics.uintType())), String(), WTF::nullopt, WTF::nullopt));
127     return AST::NativeFunctionDeclaration(AST::FunctionDeclaration(Lexer::Token(origin), AST::AttributeBlock(), WTF::nullopt, WTFMove(returnType), String("operator&[]", String::ConstructFromLiteral), WTFMove(parameters), WTF::nullopt, isOperator));
128 }
129
130 static AST::NativeFunctionDeclaration resolveWithOperatorLength(Lexer::Token origin, AST::UnnamedType& firstArgument, const Intrinsics& intrinsics)
131 {
132     const bool isOperator = true;
133     auto returnType = AST::TypeReference::wrap(Lexer::Token(origin), intrinsics.uintType());
134     AST::VariableDeclarations parameters;
135     parameters.append(makeUniqueRef<AST::VariableDeclaration>(Lexer::Token(origin), AST::Qualifiers(), firstArgument.clone(), String(), WTF::nullopt, WTF::nullopt));
136     return AST::NativeFunctionDeclaration(AST::FunctionDeclaration(Lexer::Token(origin), AST::AttributeBlock(), WTF::nullopt, WTFMove(returnType), String("operator.length", String::ConstructFromLiteral), WTFMove(parameters), WTF::nullopt, isOperator));
137 }
138
139 static AST::NativeFunctionDeclaration resolveWithReferenceComparator(Lexer::Token origin, ResolvingType& firstArgument, ResolvingType& secondArgument, const Intrinsics& intrinsics)
140 {
141     const bool isOperator = true;
142     auto returnType = AST::TypeReference::wrap(Lexer::Token(origin), intrinsics.boolType());
143     auto argumentType = firstArgument.visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> UniqueRef<AST::UnnamedType> {
144         return unnamedType->clone();
145     }, [&](RefPtr<ResolvableTypeReference>&) -> UniqueRef<AST::UnnamedType> {
146         return secondArgument.visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> UniqueRef<AST::UnnamedType> {
147             return unnamedType->clone();
148         }, [&](RefPtr<ResolvableTypeReference>&) -> UniqueRef<AST::UnnamedType> {
149             // We encountered "null == null".
150             // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198162 This can probably be generalized, using the "preferred type" infrastructure used by generic literals
151             ASSERT_NOT_REACHED();
152             return AST::TypeReference::wrap(Lexer::Token(origin), intrinsics.intType());
153         }));
154     }));
155     AST::VariableDeclarations parameters;
156     parameters.append(makeUniqueRef<AST::VariableDeclaration>(Lexer::Token(origin), AST::Qualifiers(), argumentType->clone(), String(), WTF::nullopt, WTF::nullopt));
157     parameters.append(makeUniqueRef<AST::VariableDeclaration>(Lexer::Token(origin), AST::Qualifiers(), UniqueRef<AST::UnnamedType>(WTFMove(argumentType)), String(), WTF::nullopt, WTF::nullopt));
158     return AST::NativeFunctionDeclaration(AST::FunctionDeclaration(Lexer::Token(origin), AST::AttributeBlock(), WTF::nullopt, WTFMove(returnType), String("operator==", String::ConstructFromLiteral), WTFMove(parameters), WTF::nullopt, isOperator));
159 }
160
// How well an argument fits as an operand of a synthesized reference operator==:
// Yes for a reference type, Maybe for an unresolved null literal, No otherwise.
enum class Acceptability { Yes, Maybe, No };
166
167 static Optional<AST::NativeFunctionDeclaration> resolveByInstantiation(const String& name, Lexer::Token origin, const Vector<std::reference_wrapper<ResolvingType>>& types, const Intrinsics& intrinsics)
168 {
169     if (name == "operator&[]" && types.size() == 2) {
170         auto* firstArgumentArrayRef = types[0].get().visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> AST::ArrayReferenceType* {
171             if (is<AST::ArrayReferenceType>(static_cast<AST::UnnamedType&>(unnamedType)))
172                 return &downcast<AST::ArrayReferenceType>(static_cast<AST::UnnamedType&>(unnamedType));
173             return nullptr;
174         }, [](RefPtr<ResolvableTypeReference>&) -> AST::ArrayReferenceType* {
175             return nullptr;
176         }));
177         bool secondArgumentIsUint = types[1].get().visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& unnamedType) -> bool {
178             return matches(unnamedType, intrinsics.uintType());
179         }, [&](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> bool {
180             return resolvableTypeReference->resolvableType().canResolve(intrinsics.uintType());
181         }));
182         if (firstArgumentArrayRef && secondArgumentIsUint)
183             return resolveWithOperatorAnderIndexer(origin, *firstArgumentArrayRef, intrinsics);
184     } else if (name == "operator.length" && types.size() == 1) {
185         auto* firstArgumentReference = types[0].get().visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> AST::UnnamedType* {
186             if (is<AST::ArrayReferenceType>(static_cast<AST::UnnamedType&>(unnamedType)))
187                 return &unnamedType;
188             return nullptr;
189         }, [](RefPtr<ResolvableTypeReference>&) -> AST::UnnamedType* {
190             return nullptr;
191         }));
192         if (firstArgumentReference)
193             return resolveWithOperatorLength(origin, *firstArgumentReference, intrinsics);
194     } else if (name == "operator==" && types.size() == 2) {
195         auto acceptability = [](ResolvingType& resolvingType) -> Acceptability {
196             return resolvingType.visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> Acceptability {
197                 return is<AST::ReferenceType>(static_cast<AST::UnnamedType&>(unnamedType)) ? Acceptability::Yes : Acceptability::No;
198             }, [](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> Acceptability {
199                 return is<AST::NullLiteralType>(resolvableTypeReference->resolvableType()) ? Acceptability::Maybe : Acceptability::No;
200             }));
201         };
202         auto leftAcceptability = acceptability(types[0].get());
203         auto rightAcceptability = acceptability(types[1].get());
204         bool success = false;
205         if (leftAcceptability == Acceptability::Yes && rightAcceptability == Acceptability::Yes) {
206             auto& unnamedType1 = *types[0].get().getUnnamedType();
207             auto& unnamedType2 = *types[1].get().getUnnamedType();
208             success = matches(unnamedType1, unnamedType2);
209         } else if ((leftAcceptability == Acceptability::Maybe && rightAcceptability == Acceptability::Yes)
210             || (leftAcceptability == Acceptability::Yes && rightAcceptability == Acceptability::Maybe))
211             success = true;
212         if (success)
213             return resolveWithReferenceComparator(origin, types[0].get(), types[1].get(), intrinsics);
214     }
215     return WTF::nullopt;
216 }
217
218 static AST::FunctionDeclaration* resolveFunction(Program& program, Vector<std::reference_wrapper<AST::FunctionDeclaration>, 1>& possibleOverloads, Vector<std::reference_wrapper<ResolvingType>>& types, const String& name, Lexer::Token origin, const Intrinsics& intrinsics, AST::NamedType* castReturnType = nullptr)
219 {
220     if (AST::FunctionDeclaration* function = resolveFunctionOverload(possibleOverloads, types, castReturnType))
221         return function;
222
223     if (auto newFunction = resolveByInstantiation(name, origin, types, intrinsics)) {
224         program.append(WTFMove(*newFunction));
225         return &program.nativeFunctionDeclarations().last();
226     }
227
228     return nullptr;
229 }
230
231 static bool checkSemantics(Vector<EntryPointItem>& inputItems, Vector<EntryPointItem>& outputItems, const Optional<AST::EntryPointType>& entryPointType, const Intrinsics& intrinsics)
232 {
233     {
234         auto checkDuplicateSemantics = [&](const Vector<EntryPointItem>& items) -> bool {
235             for (size_t i = 0; i < items.size(); ++i) {
236                 for (size_t j = i + 1; j < items.size(); ++j) {
237                     if (items[i].semantic == items[j].semantic)
238                         return false;
239                 }
240             }
241             return true;
242         };
243         if (!checkDuplicateSemantics(inputItems))
244             return false;
245         if (!checkDuplicateSemantics(outputItems))
246             return false;
247     }
248
249     {
250         auto checkSemanticTypes = [&](const Vector<EntryPointItem>& items) -> bool {
251             for (auto& item : items) {
252                 auto acceptable = WTF::visit(WTF::makeVisitor([&](const AST::BaseSemantic& semantic) -> bool {
253                     return semantic.isAcceptableType(*item.unnamedType, intrinsics);
254                 }), *item.semantic);
255                 if (!acceptable)
256                     return false;
257             }
258             return true;
259         };
260         if (!checkSemanticTypes(inputItems))
261             return false;
262         if (!checkSemanticTypes(outputItems))
263             return false;
264     }
265
266     {
267         auto checkSemanticForShaderType = [&](const Vector<EntryPointItem>& items, AST::BaseSemantic::ShaderItemDirection direction) -> bool {
268             for (auto& item : items) {
269                 auto acceptable = WTF::visit(WTF::makeVisitor([&](const AST::BaseSemantic& semantic) -> bool {
270                     return semantic.isAcceptableForShaderItemDirection(direction, entryPointType);
271                 }), *item.semantic);
272                 if (!acceptable)
273                     return false;
274             }
275             return true;
276         };
277         if (!checkSemanticForShaderType(inputItems, AST::BaseSemantic::ShaderItemDirection::Input))
278             return false;
279         if (!checkSemanticForShaderType(outputItems, AST::BaseSemantic::ShaderItemDirection::Output))
280             return false;
281     }
282
283     {
284         auto checkPODData = [&](const Vector<EntryPointItem>& items) -> bool {
285             for (auto& item : items) {
286                 PODChecker podChecker;
287                 if (is<AST::PointerType>(item.unnamedType))
288                     podChecker.checkErrorAndVisit(downcast<AST::PointerType>(*item.unnamedType).elementType());
289                 else if (is<AST::ArrayReferenceType>(item.unnamedType))
290                     podChecker.checkErrorAndVisit(downcast<AST::ArrayReferenceType>(*item.unnamedType).elementType());
291                 else if (is<AST::ArrayType>(item.unnamedType))
292                     podChecker.checkErrorAndVisit(downcast<AST::ArrayType>(*item.unnamedType).type());
293                 else
294                     continue;
295                 if (podChecker.error())
296                     return false;
297             }
298             return true;
299         };
300         if (!checkPODData(inputItems))
301             return false;
302         if (!checkPODData(outputItems))
303             return false;
304     }
305
306     return true;
307 }
308
309 static bool checkOperatorOverload(const AST::FunctionDefinition& functionDefinition, const Intrinsics& intrinsics, NameContext& nameContext)
310 {
311     enum class CheckKind {
312         Index,
313         Dot
314     };
315
316     auto checkGetter = [&](CheckKind kind) -> bool {
317         size_t numExpectedParameters = kind == CheckKind::Index ? 2 : 1;
318         if (functionDefinition.parameters().size() != numExpectedParameters)
319             return false;
320         auto& firstParameterUnifyNode = (*functionDefinition.parameters()[0]->type())->unifyNode();
321         if (is<AST::UnnamedType>(firstParameterUnifyNode)) {
322             auto& unnamedType = downcast<AST::UnnamedType>(firstParameterUnifyNode);
323             if (is<AST::PointerType>(unnamedType) || is<AST::ArrayReferenceType>(unnamedType) || is<AST::ArrayType>(unnamedType))
324                 return false;
325         }
326         if (kind == CheckKind::Index) {
327             auto& secondParameterUnifyNode = (*functionDefinition.parameters()[1]->type())->unifyNode();
328             if (!is<AST::NamedType>(secondParameterUnifyNode))
329                 return false;
330             auto& namedType = downcast<AST::NamedType>(secondParameterUnifyNode);
331             if (!is<AST::NativeTypeDeclaration>(namedType))
332                 return false;
333             auto& nativeTypeDeclaration = downcast<AST::NativeTypeDeclaration>(namedType);
334             if (!nativeTypeDeclaration.isInt())
335                 return false;
336         }
337         return true;
338     };
339
340     auto checkSetter = [&](CheckKind kind) -> bool {
341         size_t numExpectedParameters = kind == CheckKind::Index ? 3 : 2;
342         if (functionDefinition.parameters().size() != numExpectedParameters)
343             return false;
344         auto& firstArgumentUnifyNode = (*functionDefinition.parameters()[0]->type())->unifyNode();
345         if (is<AST::UnnamedType>(firstArgumentUnifyNode)) {
346             auto& unnamedType = downcast<AST::UnnamedType>(firstArgumentUnifyNode);
347             if (is<AST::PointerType>(unnamedType) || is<AST::ArrayReferenceType>(unnamedType) || is<AST::ArrayType>(unnamedType))
348                 return false;
349         }
350         if (kind == CheckKind::Index) {
351             auto& secondParameterUnifyNode = (*functionDefinition.parameters()[1]->type())->unifyNode();
352             if (!is<AST::NamedType>(secondParameterUnifyNode))
353                 return false;
354             auto& namedType = downcast<AST::NamedType>(secondParameterUnifyNode);
355             if (!is<AST::NativeTypeDeclaration>(namedType))
356                 return false;
357             auto& nativeTypeDeclaration = downcast<AST::NativeTypeDeclaration>(namedType);
358             if (!nativeTypeDeclaration.isInt())
359                 return false;
360         }
361         if (!matches(functionDefinition.type(), *functionDefinition.parameters()[0]->type()))
362             return false;
363         auto& valueType = *functionDefinition.parameters()[numExpectedParameters - 1]->type();
364         auto getterName = functionDefinition.name().substring(0, functionDefinition.name().length() - 1);
365         auto* getterFuncs = nameContext.getFunctions(getterName);
366         if (!getterFuncs)
367             return false;
368         Vector<ResolvingType> argumentTypes;
369         Vector<std::reference_wrapper<ResolvingType>> argumentTypeReferences;
370         for (size_t i = 0; i < numExpectedParameters - 1; ++i)
371             argumentTypes.append((*functionDefinition.parameters()[i]->type())->clone());
372         for (auto& argumentType : argumentTypes)
373             argumentTypeReferences.append(argumentType);
374         auto* overload = resolveFunctionOverload(*getterFuncs, argumentTypeReferences);
375         if (!overload)
376             return false;
377         auto& resultType = overload->type();
378         return matches(resultType, valueType);
379     };
380
381     auto checkAnder = [&](CheckKind kind) -> bool {
382         size_t numExpectedParameters = kind == CheckKind::Index ? 2 : 1;
383         if (functionDefinition.parameters().size() != numExpectedParameters)
384             return false;
385         {
386             auto& unifyNode = functionDefinition.type().unifyNode();
387             if (!is<AST::UnnamedType>(unifyNode))
388                 return false;
389             auto& unnamedType = downcast<AST::UnnamedType>(unifyNode);
390             if (!is<AST::PointerType>(unnamedType))
391                 return false;
392         }
393         {
394             auto& unifyNode = (*functionDefinition.parameters()[0]->type())->unifyNode();
395             if (!is<AST::UnnamedType>(unifyNode))
396                 return false;
397             auto& unnamedType = downcast<AST::UnnamedType>(unifyNode);
398             return is<AST::PointerType>(unnamedType) || is<AST::ArrayReferenceType>(unnamedType);
399         }
400     };
401
402     if (!functionDefinition.isOperator())
403         return true;
404     if (functionDefinition.isCast())
405         return true;
406     if (functionDefinition.name() == "operator++" || functionDefinition.name() == "operator--") {
407         return functionDefinition.parameters().size() == 1
408             && matches(*functionDefinition.parameters()[0]->type(), functionDefinition.type());
409     }
410     if (functionDefinition.name() == "operator+" || functionDefinition.name() == "operator-")
411         return functionDefinition.parameters().size() == 1 || functionDefinition.parameters().size() == 2;
412     if (functionDefinition.name() == "operator*"
413         || functionDefinition.name() == "operator/"
414         || functionDefinition.name() == "operator%"
415         || functionDefinition.name() == "operator&"
416         || functionDefinition.name() == "operator|"
417         || functionDefinition.name() == "operator^"
418         || functionDefinition.name() == "operator<<"
419         || functionDefinition.name() == "opreator>>")
420         return functionDefinition.parameters().size() == 2;
421     if (functionDefinition.name() == "operator~")
422         return functionDefinition.parameters().size() == 1;
423     if (functionDefinition.name() == "operator=="
424         || functionDefinition.name() == "operator<"
425         || functionDefinition.name() == "operator<="
426         || functionDefinition.name() == "operator>"
427         || functionDefinition.name() == "operator>=") {
428         return functionDefinition.parameters().size() == 2
429             && matches(functionDefinition.type(), intrinsics.boolType());
430     }
431     if (functionDefinition.name() == "operator[]")
432         return checkGetter(CheckKind::Index);
433     if (functionDefinition.name() == "operator[]=")
434         return checkSetter(CheckKind::Index);
435     if (functionDefinition.name() == "operator&[]")
436         return checkAnder(CheckKind::Index);
437     if (functionDefinition.name().startsWith("operator.")) {
438         if (functionDefinition.name().endsWith("="))
439             return checkSetter(CheckKind::Dot);
440         return checkGetter(CheckKind::Dot);
441     }
442     if (functionDefinition.name().startsWith("operator&."))
443         return checkAnder(CheckKind::Dot);
444     return false;
445 }
446
447 class Checker : public Visitor {
448 public:
449     Checker(const Intrinsics& intrinsics, Program& program)
450         : m_intrinsics(intrinsics)
451         , m_program(program)
452     {
453     }
454
455     ~Checker() = default;
456
457     void visit(Program&) override;
458
459     bool assignTypes();
460
461 private:
462     bool checkShaderType(const AST::FunctionDefinition&);
463     bool isBoolType(ResolvingType&);
464     struct RecurseInfo {
465         ResolvingType& resolvingType;
466         AST::TypeAnnotation& typeAnnotation;
467     };
468     Optional<RecurseInfo> recurseAndGetInfo(AST::Expression&, bool requiresLeftValue = false);
469     Optional<RecurseInfo> getInfo(AST::Expression&, bool requiresLeftValue = false);
470     Optional<UniqueRef<AST::UnnamedType>> recurseAndWrapBaseType(AST::PropertyAccessExpression&);
471     bool recurseAndRequireBoolType(AST::Expression&);
472     void assignType(AST::Expression&, UniqueRef<AST::UnnamedType>&&, AST::TypeAnnotation);
473     void assignType(AST::Expression&, RefPtr<ResolvableTypeReference>&&, AST::TypeAnnotation);
474     void forwardType(AST::Expression&, ResolvingType&, AST::TypeAnnotation);
475
476     void visit(AST::FunctionDefinition&) override;
477     void visit(AST::EnumerationDefinition&) override;
478     void visit(AST::TypeReference&) override;
479     void visit(AST::VariableDeclaration&) override;
480     void visit(AST::AssignmentExpression&) override;
481     void visit(AST::ReadModifyWriteExpression&) override;
482     void visit(AST::DereferenceExpression&) override;
483     void visit(AST::MakePointerExpression&) override;
484     void visit(AST::MakeArrayReferenceExpression&) override;
485     void visit(AST::DotExpression&) override;
486     void visit(AST::IndexExpression&) override;
487     void visit(AST::VariableReference&) override;
488     void visit(AST::Return&) override;
489     void visit(AST::PointerType&) override;
490     void visit(AST::ArrayReferenceType&) override;
491     void visit(AST::IntegerLiteral&) override;
492     void visit(AST::UnsignedIntegerLiteral&) override;
493     void visit(AST::FloatLiteral&) override;
494     void visit(AST::NullLiteral&) override;
495     void visit(AST::BooleanLiteral&) override;
496     void visit(AST::EnumerationMemberLiteral&) override;
497     void visit(AST::LogicalNotExpression&) override;
498     void visit(AST::LogicalExpression&) override;
499     void visit(AST::IfStatement&) override;
500     void visit(AST::WhileLoop&) override;
501     void visit(AST::DoWhileLoop&) override;
502     void visit(AST::ForLoop&) override;
503     void visit(AST::SwitchStatement&) override;
504     void visit(AST::CommaExpression&) override;
505     void visit(AST::TernaryExpression&) override;
506     void visit(AST::CallExpression&) override;
507
508     void finishVisiting(AST::PropertyAccessExpression&, ResolvingType* additionalArgumentType = nullptr);
509
510     HashMap<AST::Expression*, std::unique_ptr<ResolvingType>> m_typeMap;
511     HashMap<AST::Expression*, AST::TypeAnnotation> m_typeAnnotations;
512     HashSet<String> m_vertexEntryPoints;
513     HashSet<String> m_fragmentEntryPoints;
514     HashSet<String> m_computeEntryPoints;
515     const Intrinsics& m_intrinsics;
516     Program& m_program;
517 };
518
519 void Checker::visit(Program& program)
520 {
521     // These visiting functions might add new global statements, so don't use foreach syntax.
522     for (size_t i = 0; i < program.typeDefinitions().size(); ++i)
523         checkErrorAndVisit(program.typeDefinitions()[i]);
524     for (size_t i = 0; i < program.structureDefinitions().size(); ++i)
525         checkErrorAndVisit(program.structureDefinitions()[i]);
526     for (size_t i = 0; i < program.enumerationDefinitions().size(); ++i)
527         checkErrorAndVisit(program.enumerationDefinitions()[i]);
528     for (size_t i = 0; i < program.nativeTypeDeclarations().size(); ++i)
529         checkErrorAndVisit(program.nativeTypeDeclarations()[i]);
530
531     for (size_t i = 0; i < program.functionDefinitions().size(); ++i)
532         checkErrorAndVisit(program.functionDefinitions()[i]);
533     for (size_t i = 0; i < program.nativeFunctionDeclarations().size(); ++i)
534         checkErrorAndVisit(program.nativeFunctionDeclarations()[i]);
535 }
536
537 bool Checker::assignTypes()
538 {
539     for (auto& keyValuePair : m_typeMap) {
540         auto success = keyValuePair.value->visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& unnamedType) -> bool {
541             keyValuePair.key->setType(unnamedType->clone());
542             return true;
543         }, [&](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> bool {
544             if (!resolvableTypeReference->resolvableType().maybeResolvedType()) {
545                 if (!static_cast<bool>(commit(resolvableTypeReference->resolvableType())))
546                     return false;
547             }
548             keyValuePair.key->setType(resolvableTypeReference->resolvableType().resolvedType().clone());
549             return true;
550         }));
551         if (!success)
552             return false;
553     }
554
555     for (auto& keyValuePair : m_typeAnnotations)
556         keyValuePair.key->setTypeAnnotation(WTFMove(keyValuePair.value));
557     return true;
558 }
559
560 bool Checker::checkShaderType(const AST::FunctionDefinition& functionDefinition)
561 {
562     switch (*functionDefinition.entryPointType()) {
563     case AST::EntryPointType::Vertex:
564         return static_cast<bool>(m_vertexEntryPoints.add(functionDefinition.name()));
565     case AST::EntryPointType::Fragment:
566         return static_cast<bool>(m_fragmentEntryPoints.add(functionDefinition.name()));
567     case AST::EntryPointType::Compute:
568         return static_cast<bool>(m_computeEntryPoints.add(functionDefinition.name()));
569     }
570 }
571
572 void Checker::visit(AST::FunctionDefinition& functionDefinition)
573 {
574     if (functionDefinition.entryPointType()) {
575         if (!checkShaderType(functionDefinition)) {
576             setError();
577             return;
578         }
579         auto entryPointItems = gatherEntryPointItems(m_intrinsics, functionDefinition);
580         if (!entryPointItems) {
581             setError();
582             return;
583         }
584         if (!checkSemantics(entryPointItems->inputs, entryPointItems->outputs, functionDefinition.entryPointType(), m_intrinsics)) {
585             setError();
586             return;
587         }
588     }
589     if (!checkOperatorOverload(functionDefinition, m_intrinsics, m_program.nameContext())) {
590         setError();
591         return;
592     }
593
594     Visitor::visit(functionDefinition);
595 }
596
597 static Optional<UniqueRef<AST::UnnamedType>> matchAndCommit(ResolvingType& left, ResolvingType& right)
598 {
599     return left.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& left) -> Optional<UniqueRef<AST::UnnamedType>> {
600         return right.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& right) -> Optional<UniqueRef<AST::UnnamedType>> {
601             if (matches(left, right))
602                 return left->clone();
603             return WTF::nullopt;
604         }, [&](RefPtr<ResolvableTypeReference>& right) -> Optional<UniqueRef<AST::UnnamedType>> {
605             return matchAndCommit(left, right->resolvableType());
606         }));
607     }, [&](RefPtr<ResolvableTypeReference>& left) -> Optional<UniqueRef<AST::UnnamedType>> {
608         return right.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& right) -> Optional<UniqueRef<AST::UnnamedType>> {
609             return matchAndCommit(right, left->resolvableType());
610         }, [&](RefPtr<ResolvableTypeReference>& right) -> Optional<UniqueRef<AST::UnnamedType>> {
611             return matchAndCommit(left->resolvableType(), right->resolvableType());
612         }));
613     }));
614 }
615
616 static Optional<UniqueRef<AST::UnnamedType>> matchAndCommit(ResolvingType& resolvingType, AST::UnnamedType& unnamedType)
617 {
618     return resolvingType.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& resolvingType) -> Optional<UniqueRef<AST::UnnamedType>> {
619         if (matches(unnamedType, resolvingType))
620             return unnamedType.clone();
621         return WTF::nullopt;
622     }, [&](RefPtr<ResolvableTypeReference>& resolvingType) -> Optional<UniqueRef<AST::UnnamedType>> {
623         return matchAndCommit(unnamedType, resolvingType->resolvableType());
624     }));
625 }
626
627 static Optional<UniqueRef<AST::UnnamedType>> matchAndCommit(ResolvingType& resolvingType, AST::NamedType& namedType)
628 {
629     return resolvingType.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& resolvingType) -> Optional<UniqueRef<AST::UnnamedType>> {
630         if (matches(resolvingType, namedType))
631             return resolvingType->clone();
632         return WTF::nullopt;
633     }, [&](RefPtr<ResolvableTypeReference>& resolvingType) -> Optional<UniqueRef<AST::UnnamedType>> {
634         return matchAndCommit(namedType, resolvingType->resolvableType());
635     }));
636 }
637
638 static Optional<UniqueRef<AST::UnnamedType>> commit(ResolvingType& resolvingType)
639 {
640     return resolvingType.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& unnamedType) -> Optional<UniqueRef<AST::UnnamedType>> {
641         return unnamedType->clone();
642     }, [&](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> Optional<UniqueRef<AST::UnnamedType>> {
643         if (!resolvableTypeReference->resolvableType().maybeResolvedType())
644             return commit(resolvableTypeReference->resolvableType());
645         return resolvableTypeReference->resolvableType().resolvedType().clone();
646     }));
647 }
648
649 void Checker::visit(AST::EnumerationDefinition& enumerationDefinition)
650 {
651     auto* baseType = ([&]() -> AST::NativeTypeDeclaration* {
652         checkErrorAndVisit(enumerationDefinition.type());
653         auto& baseType = enumerationDefinition.type().unifyNode();
654         if (!is<AST::NamedType>(baseType))
655             return nullptr;
656         auto& namedType = downcast<AST::NamedType>(baseType);
657         if (!is<AST::NativeTypeDeclaration>(namedType))
658             return nullptr;
659         auto& nativeTypeDeclaration = downcast<AST::NativeTypeDeclaration>(namedType);
660         if (!nativeTypeDeclaration.isInt())
661             return nullptr;
662         return &nativeTypeDeclaration;
663     })();
664     if (!baseType) {
665         setError();
666         return;
667     }
668
669     auto enumerationMembers = enumerationDefinition.enumerationMembers();
670
671     auto matchAndCommitMember = [&](AST::EnumerationMember& member) -> bool {
672         return member.value()->visit(WTF::makeVisitor([&](AST::Expression& value) -> bool {
673             auto valueInfo = recurseAndGetInfo(value);
674             if (!valueInfo)
675                 return false;
676             return static_cast<bool>(matchAndCommit(valueInfo->resolvingType, *baseType));
677         }));
678     };
679
680     for (auto& member : enumerationMembers) {
681         if (!member.get().value())
682             continue;
683
684         if (!matchAndCommitMember(member)) {
685             setError();
686             return;
687         }
688     }
689
690     int64_t nextValue = 0;
691     for (auto& member : enumerationMembers) {
692         if (member.get().value()) {
693             auto value = member.get().value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) -> int64_t {
694                 return integerLiteral.valueForSelectedType();
695             }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) -> int64_t {
696                 return unsignedIntegerLiteral.valueForSelectedType();
697             }, [&](auto&) -> int64_t {
698                 ASSERT_NOT_REACHED();
699                 return 0;
700             }));
701             nextValue = baseType->successor()(value);
702         } else {
703             if (nextValue > std::numeric_limits<int>::max()) {
704                 ASSERT(nextValue <= std::numeric_limits<unsigned>::max());
705                 member.get().setValue(AST::ConstantExpression(AST::UnsignedIntegerLiteral(Lexer::Token(member.get().origin()), static_cast<unsigned>(nextValue))));
706             }
707             ASSERT(nextValue >= std::numeric_limits<int>::min());
708             member.get().setValue(AST::ConstantExpression(AST::IntegerLiteral(Lexer::Token(member.get().origin()), static_cast<int>(nextValue))));
709
710             if (!matchAndCommitMember(member)) {
711                 setError();
712                 return;
713             }
714
715             nextValue = baseType->successor()(nextValue);
716         }
717     }
718
719     auto getValue = [&](AST::EnumerationMember& member) -> int64_t {
720         ASSERT(member.value());
721         auto value = member.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) -> int64_t {
722             return integerLiteral.value();
723         }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) -> int64_t {
724             return unsignedIntegerLiteral.value();
725         }, [&](auto&) -> int64_t {
726             ASSERT_NOT_REACHED();
727             return 0;
728         }));
729         return value;
730     };
731
732     for (size_t i = 0; i < enumerationMembers.size(); ++i) {
733         auto value = getValue(enumerationMembers[i].get());
734         for (size_t j = i + 1; j < enumerationMembers.size(); ++j) {
735             auto otherValue = getValue(enumerationMembers[j].get());
736             if (value == otherValue) {
737                 setError();
738                 return;
739             }
740         }
741     }
742
743     bool foundZero = false;
744     for (auto& member : enumerationMembers) {
745         if (!getValue(member.get())) {
746             foundZero = true;
747             break;
748         }
749     }
750     if (!foundZero) {
751         setError();
752         return;
753     }
754 }
755
756 void Checker::visit(AST::TypeReference& typeReference)
757 {
758     ASSERT(typeReference.maybeResolvedType());
759
760     for (auto& typeArgument : typeReference.typeArguments())
761         checkErrorAndVisit(typeArgument);
762 }
763
764 auto Checker::recurseAndGetInfo(AST::Expression& expression, bool requiresLeftValue) -> Optional<RecurseInfo>
765 {
766     Visitor::visit(expression);
767     if (error())
768         return WTF::nullopt;
769     return getInfo(expression, requiresLeftValue);
770 }
771
772 auto Checker::getInfo(AST::Expression& expression, bool requiresLeftValue) -> Optional<RecurseInfo>
773 {
774     auto typeIterator = m_typeMap.find(&expression);
775     ASSERT(typeIterator != m_typeMap.end());
776
777     auto typeAnnotationIterator = m_typeAnnotations.find(&expression);
778     ASSERT(typeAnnotationIterator != m_typeAnnotations.end());
779     if (requiresLeftValue && typeAnnotationIterator->value.isRightValue()) {
780         setError();
781         return WTF::nullopt;
782     }
783     return {{ *typeIterator->value, typeAnnotationIterator->value }};
784 }
785
786 void Checker::visit(AST::VariableDeclaration& variableDeclaration)
787 {
788     // ReadModifyWriteExpressions are the only place where anonymous variables exist,
789     // and that doesn't recurse on the anonymous variables, so we can assume the variable has a type.
790     checkErrorAndVisit(*variableDeclaration.type());
791     if (variableDeclaration.initializer()) {
792         auto& lhsType = *variableDeclaration.type();
793         auto initializerInfo = recurseAndGetInfo(*variableDeclaration.initializer());
794         if (!initializerInfo)
795             return;
796         if (!matchAndCommit(initializerInfo->resolvingType, lhsType)) {
797             setError();
798             return;
799         }
800     }
801 }
802
803 void Checker::assignType(AST::Expression& expression, UniqueRef<AST::UnnamedType>&& unnamedType, AST::TypeAnnotation typeAnnotation = AST::RightValue())
804 {
805     auto addResult = m_typeMap.add(&expression, std::make_unique<ResolvingType>(WTFMove(unnamedType)));
806     ASSERT_UNUSED(addResult, addResult.isNewEntry);
807     auto typeAnnotationAddResult = m_typeAnnotations.add(&expression, WTFMove(typeAnnotation));
808     ASSERT_UNUSED(typeAnnotationAddResult, typeAnnotationAddResult.isNewEntry);
809 }
810
811 void Checker::assignType(AST::Expression& expression, RefPtr<ResolvableTypeReference>&& resolvableTypeReference, AST::TypeAnnotation typeAnnotation = AST::RightValue())
812 {
813     auto addResult = m_typeMap.add(&expression, std::make_unique<ResolvingType>(WTFMove(resolvableTypeReference)));
814     ASSERT_UNUSED(addResult, addResult.isNewEntry);
815     auto typeAnnotationAddResult = m_typeAnnotations.add(&expression, WTFMove(typeAnnotation));
816     ASSERT_UNUSED(typeAnnotationAddResult, typeAnnotationAddResult.isNewEntry);
817 }
818
819 void Checker::forwardType(AST::Expression& expression, ResolvingType& resolvingType, AST::TypeAnnotation typeAnnotation = AST::RightValue())
820 {
821     resolvingType.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& result) {
822         auto addResult = m_typeMap.add(&expression, std::make_unique<ResolvingType>(result->clone()));
823         ASSERT_UNUSED(addResult, addResult.isNewEntry);
824     }, [&](RefPtr<ResolvableTypeReference>& result) {
825         auto addResult = m_typeMap.add(&expression, std::make_unique<ResolvingType>(result.copyRef()));
826         ASSERT_UNUSED(addResult, addResult.isNewEntry);
827     }));
828     auto typeAnnotationAddResult = m_typeAnnotations.add(&expression, WTFMove(typeAnnotation));
829     ASSERT_UNUSED(typeAnnotationAddResult, typeAnnotationAddResult.isNewEntry);
830 }
831
832 void Checker::visit(AST::AssignmentExpression& assignmentExpression)
833 {
834     auto leftInfo = recurseAndGetInfo(assignmentExpression.left(), true);
835     if (!leftInfo)
836         return;
837
838     if (leftInfo->typeAnnotation.isRightValue()) {
839         setError();
840         return;
841     }
842
843     auto rightInfo = recurseAndGetInfo(assignmentExpression.right());
844     if (!rightInfo)
845         return;
846
847     auto resultType = matchAndCommit(leftInfo->resolvingType, rightInfo->resolvingType);
848     if (!resultType) {
849         setError();
850         return;
851     }
852
853     assignType(assignmentExpression, WTFMove(*resultType));
854 }
855
856 void Checker::visit(AST::ReadModifyWriteExpression& readModifyWriteExpression)
857 {
858     auto leftValueInfo = recurseAndGetInfo(readModifyWriteExpression.leftValue(), true);
859     if (!leftValueInfo)
860         return;
861
862     // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198166 Figure out what to do with the ReadModifyWriteExpression's AnonymousVariables.
863
864     auto newValueInfo = recurseAndGetInfo(readModifyWriteExpression.newValueExpression());
865     if (!newValueInfo)
866         return;
867
868     if (!matchAndCommit(leftValueInfo->resolvingType, newValueInfo->resolvingType)) {
869         setError();
870         return;
871     }
872
873     auto resultInfo = recurseAndGetInfo(readModifyWriteExpression.resultExpression());
874     if (!resultInfo)
875         return;
876
877     forwardType(readModifyWriteExpression, resultInfo->resolvingType);
878 }
879
880 static AST::UnnamedType* getUnnamedType(ResolvingType& resolvingType)
881 {
882     return resolvingType.visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& type) -> AST::UnnamedType* {
883         return &type;
884     }, [](RefPtr<ResolvableTypeReference>& type) -> AST::UnnamedType* {
885         // FIXME: If the type isn't committed, should we just commit() it now?
886         return type->resolvableType().maybeResolvedType();
887     }));
888 }
889
890 void Checker::visit(AST::DereferenceExpression& dereferenceExpression)
891 {
892     auto pointerInfo = recurseAndGetInfo(dereferenceExpression.pointer());
893     if (!pointerInfo)
894         return;
895
896     auto* unnamedType = getUnnamedType(pointerInfo->resolvingType);
897
898     auto* pointerType = ([&](AST::UnnamedType* unnamedType) -> AST::PointerType* {
899         if (!unnamedType)
900             return nullptr;
901         auto& unifyNode = unnamedType->unifyNode();
902         if (!is<AST::UnnamedType>(unifyNode))
903             return nullptr;
904         auto& unnamedUnifyType = downcast<AST::UnnamedType>(unifyNode);
905         if (!is<AST::PointerType>(unnamedUnifyType))
906             return nullptr;
907         return &downcast<AST::PointerType>(unnamedUnifyType);
908     })(unnamedType);
909     if (!pointerType) {
910         setError();
911         return;
912     }
913
914     assignType(dereferenceExpression, pointerType->elementType().clone(), AST::LeftValue { pointerType->addressSpace() });
915 }
916
917 void Checker::visit(AST::MakePointerExpression& makePointerExpression)
918 {
919     auto leftValueInfo = recurseAndGetInfo(makePointerExpression.leftValue(), true);
920     if (!leftValueInfo)
921         return;
922
923     auto leftAddressSpace = leftValueInfo->typeAnnotation.leftAddressSpace();
924     if (!leftAddressSpace) {
925         setError();
926         return;
927     }
928
929     auto* leftValueType = getUnnamedType(leftValueInfo->resolvingType);
930     if (!leftValueType) {
931         setError();
932         return;
933     }
934
935     assignType(makePointerExpression, makeUniqueRef<AST::PointerType>(Lexer::Token(makePointerExpression.origin()), *leftAddressSpace, leftValueType->clone()));
936 }
937
938 void Checker::visit(AST::MakeArrayReferenceExpression& makeArrayReferenceExpression)
939 {
940     auto leftValueInfo = recurseAndGetInfo(makeArrayReferenceExpression.leftValue());
941     if (!leftValueInfo)
942         return;
943
944     auto* leftValueType = getUnnamedType(leftValueInfo->resolvingType);
945     if (!leftValueType) {
946         setError();
947         return;
948     }
949
950     auto& unifyNode = leftValueType->unifyNode();
951     if (is<AST::UnnamedType>(unifyNode)) {
952         auto& unnamedType = downcast<AST::UnnamedType>(unifyNode);
953         if (is<AST::PointerType>(unnamedType)) {
954             auto& pointerType = downcast<AST::PointerType>(unnamedType);
955             // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198163 Save the fact that we're not targetting the item; we're targetting the item's inner element.
956             assignType(makeArrayReferenceExpression, makeUniqueRef<AST::ArrayReferenceType>(Lexer::Token(makeArrayReferenceExpression.origin()), pointerType.addressSpace(), pointerType.elementType().clone()));
957             return;
958         }
959
960         auto leftAddressSpace = leftValueInfo->typeAnnotation.leftAddressSpace();
961         if (!leftAddressSpace) {
962             setError();
963             return;
964         }
965
966         if (is<AST::ArrayType>(unnamedType)) {
967             auto& arrayType = downcast<AST::ArrayType>(unnamedType);
968             // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198163 Save the number of elements.
969             assignType(makeArrayReferenceExpression, makeUniqueRef<AST::ArrayReferenceType>(Lexer::Token(makeArrayReferenceExpression.origin()), *leftAddressSpace, arrayType.type().clone()));
970             return;
971         }
972     }
973
974     auto leftAddressSpace = leftValueInfo->typeAnnotation.leftAddressSpace();
975     if (!leftAddressSpace) {
976         setError();
977         return;
978     }
979
980     assignType(makeArrayReferenceExpression, makeUniqueRef<AST::ArrayReferenceType>(Lexer::Token(makeArrayReferenceExpression.origin()), *leftAddressSpace, leftValueType->clone()));
981 }
982
983 static Optional<UniqueRef<AST::UnnamedType>> argumentTypeForAndOverload(AST::UnnamedType& baseType, AST::AddressSpace addressSpace)
984 {
985     auto& unifyNode = baseType.unifyNode();
986     if (is<AST::NamedType>(unifyNode)) {
987         auto& namedType = downcast<AST::NamedType>(unifyNode);
988         return { makeUniqueRef<AST::PointerType>(Lexer::Token(namedType.origin()), addressSpace, AST::TypeReference::wrap(Lexer::Token(namedType.origin()), namedType)) };
989     }
990
991     auto& unnamedType = downcast<AST::UnnamedType>(unifyNode);
992
993     if (is<AST::ArrayReferenceType>(unnamedType))
994         return unnamedType.clone();
995
996     if (is<AST::ArrayType>(unnamedType))
997         return { makeUniqueRef<AST::ArrayReferenceType>(Lexer::Token(unnamedType.origin()), addressSpace, downcast<AST::ArrayType>(unnamedType).type().clone()) };
998
999     if (is<AST::PointerType>(unnamedType))
1000         return WTF::nullopt;
1001
1002     return { makeUniqueRef<AST::PointerType>(Lexer::Token(unnamedType.origin()), addressSpace, unnamedType.clone()) };
1003 }
1004
1005 void Checker::finishVisiting(AST::PropertyAccessExpression& propertyAccessExpression, ResolvingType* additionalArgumentType)
1006 {
1007     auto baseInfo = recurseAndGetInfo(propertyAccessExpression.base());
1008     if (!baseInfo)
1009         return;
1010     auto baseUnnamedType = commit(baseInfo->resolvingType);
1011     if (!baseUnnamedType)
1012         return;
1013
1014     AST::FunctionDeclaration* getterFunction = nullptr;
1015     AST::UnnamedType* getterReturnType = nullptr;
1016     {
1017         Vector<std::reference_wrapper<ResolvingType>> getterArgumentTypes { baseInfo->resolvingType };
1018         if (additionalArgumentType)
1019             getterArgumentTypes.append(*additionalArgumentType);
1020         if ((getterFunction = resolveFunction(m_program, propertyAccessExpression.possibleGetterOverloads(), getterArgumentTypes, propertyAccessExpression.getterFunctionName(), propertyAccessExpression.origin(), m_intrinsics)))
1021             getterReturnType = &getterFunction->type();
1022     }
1023
1024     AST::FunctionDeclaration* anderFunction = nullptr;
1025     AST::UnnamedType* anderReturnType = nullptr;
1026     auto leftAddressSpace = baseInfo->typeAnnotation.leftAddressSpace();
1027     if (leftAddressSpace) {
1028         if (auto argumentTypeForAndOverload = WHLSL::argumentTypeForAndOverload(*baseUnnamedType, *leftAddressSpace)) {
1029             ResolvingType argumentType = { WTFMove(*argumentTypeForAndOverload) };
1030             Vector<std::reference_wrapper<ResolvingType>> anderArgumentTypes { argumentType };
1031             if (additionalArgumentType)
1032                 anderArgumentTypes.append(*additionalArgumentType);
1033             if ((anderFunction = resolveFunction(m_program, propertyAccessExpression.possibleAnderOverloads(), anderArgumentTypes, propertyAccessExpression.anderFunctionName(), propertyAccessExpression.origin(), m_intrinsics)))
1034                 anderReturnType = &downcast<AST::PointerType>(anderFunction->type()).elementType(); // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198164 Enforce the return of anders will always be a pointer
1035         }
1036     }
1037
1038     AST::FunctionDeclaration* threadAnderFunction = nullptr;
1039     AST::UnnamedType* threadAnderReturnType = nullptr;
1040     if (auto argumentTypeForAndOverload = WHLSL::argumentTypeForAndOverload(*baseUnnamedType, AST::AddressSpace::Thread)) {
1041         ResolvingType argumentType = { makeUniqueRef<AST::PointerType>(Lexer::Token(propertyAccessExpression.origin()), AST::AddressSpace::Thread, baseUnnamedType->get().clone()) };
1042         Vector<std::reference_wrapper<ResolvingType>> threadAnderArgumentTypes { argumentType };
1043         if (additionalArgumentType)
1044             threadAnderArgumentTypes.append(*additionalArgumentType);
1045         if ((threadAnderFunction = resolveFunction(m_program, propertyAccessExpression.possibleAnderOverloads(), threadAnderArgumentTypes, propertyAccessExpression.anderFunctionName(), propertyAccessExpression.origin(), m_intrinsics)))
1046             threadAnderReturnType = &downcast<AST::PointerType>(threadAnderFunction->type()).elementType(); // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198164 Enforce the return of anders will always be a pointer
1047     }
1048
1049     if (leftAddressSpace && !anderFunction && !getterFunction) {
1050         setError();
1051         return;
1052     }
1053
1054     if (!leftAddressSpace && !threadAnderFunction && !getterFunction) {
1055         setError();
1056         return;
1057     }
1058
1059     if (threadAnderFunction && getterFunction) {
1060         setError();
1061         return;
1062     }
1063
1064     if (anderFunction && threadAnderFunction && !matches(*anderReturnType, *threadAnderReturnType)) {
1065         setError();
1066         return;
1067     }
1068
1069     if (getterFunction && anderFunction && !matches(*getterReturnType, *anderReturnType)) {
1070         setError();
1071         return;
1072     }
1073
1074     if (getterFunction && threadAnderFunction && !matches(*getterReturnType, *threadAnderReturnType)) {
1075         setError();
1076         return;
1077     }
1078
1079     AST::UnnamedType* fieldType = getterReturnType ? getterReturnType : anderReturnType ? anderReturnType : threadAnderReturnType;
1080
1081     AST::FunctionDeclaration* setterFunction = nullptr;
1082     AST::UnnamedType* setterReturnType = nullptr;
1083     {
1084         ResolvingType fieldResolvingType(fieldType->clone());
1085         Vector<std::reference_wrapper<ResolvingType>> setterArgumentTypes { baseInfo->resolvingType };
1086         if (additionalArgumentType)
1087             setterArgumentTypes.append(*additionalArgumentType);
1088         setterArgumentTypes.append(fieldResolvingType);
1089         setterFunction = resolveFunction(m_program, propertyAccessExpression.possibleSetterOverloads(), setterArgumentTypes, propertyAccessExpression.setterFunctionName(), propertyAccessExpression.origin(), m_intrinsics);
1090         if (setterFunction)
1091             setterReturnType = &setterFunction->type();
1092     }
1093
1094     if (setterFunction && !getterFunction) {
1095         setError();
1096         return;
1097     }
1098
1099     propertyAccessExpression.setGetterFunction(getterFunction);
1100     propertyAccessExpression.setAnderFunction(anderFunction);
1101     propertyAccessExpression.setThreadAnderFunction(threadAnderFunction);
1102     propertyAccessExpression.setSetterFunction(setterFunction);
1103
1104     AST::TypeAnnotation typeAnnotation = AST::RightValue();
1105     if (auto leftAddressSpace = baseInfo->typeAnnotation.leftAddressSpace()) {
1106         if (anderFunction)
1107             typeAnnotation = AST::LeftValue { *leftAddressSpace };
1108         else if (setterFunction)
1109             typeAnnotation = AST::AbstractLeftValue();
1110     } else if (!baseInfo->typeAnnotation.isRightValue() && (setterFunction || threadAnderFunction))
1111         typeAnnotation = AST::AbstractLeftValue();
1112     assignType(propertyAccessExpression, fieldType->clone(), WTFMove(typeAnnotation));
1113 }
1114
1115 void Checker::visit(AST::DotExpression& dotExpression)
1116 {
1117     finishVisiting(dotExpression);
1118 }
1119
1120 void Checker::visit(AST::IndexExpression& indexExpression)
1121 {
1122     auto baseInfo = recurseAndGetInfo(indexExpression.indexExpression());
1123     if (!baseInfo)
1124         return;
1125     finishVisiting(indexExpression, &baseInfo->resolvingType);
1126 }
1127
1128 void Checker::visit(AST::VariableReference& variableReference)
1129 {
1130     ASSERT(variableReference.variable());
1131     ASSERT(variableReference.variable()->type());
1132     
1133     AST::TypeAnnotation typeAnnotation = AST::RightValue();
1134     if (!variableReference.variable()->isAnonymous()) // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198166 This doesn't seem right.
1135         typeAnnotation = AST::LeftValue { AST::AddressSpace::Thread };
1136     assignType(variableReference, variableReference.variable()->type()->clone(), WTFMove(typeAnnotation));
1137 }
1138
1139 void Checker::visit(AST::Return& returnStatement)
1140 {
1141     ASSERT(returnStatement.function());
1142     if (returnStatement.value()) {
1143         auto valueInfo = recurseAndGetInfo(*returnStatement.value());
1144         if (!valueInfo)
1145             return;
1146         if (!matchAndCommit(valueInfo->resolvingType, returnStatement.function()->type()))
1147             setError();
1148         return;
1149     }
1150
1151     if (!matches(returnStatement.function()->type(), m_intrinsics.voidType()))
1152         setError();
1153 }
1154
1155 void Checker::visit(AST::PointerType&)
1156 {
1157     // Following pointer types can cause infinite loops because of data structures
1158     // like linked lists.
1159     // FIXME: Make sure this function should be empty
1160 }
1161
1162 void Checker::visit(AST::ArrayReferenceType&)
1163 {
1164     // Following array reference types can cause infinite loops because of data
1165     // structures like linked lists.
1166     // FIXME: Make sure this function should be empty
1167 }
1168
1169 void Checker::visit(AST::IntegerLiteral& integerLiteral)
1170 {
1171     assignType(integerLiteral, adoptRef(*new ResolvableTypeReference(integerLiteral.type())));
1172 }
1173
1174 void Checker::visit(AST::UnsignedIntegerLiteral& unsignedIntegerLiteral)
1175 {
1176     assignType(unsignedIntegerLiteral, adoptRef(*new ResolvableTypeReference(unsignedIntegerLiteral.type())));
1177 }
1178
1179 void Checker::visit(AST::FloatLiteral& floatLiteral)
1180 {
1181     assignType(floatLiteral, adoptRef(*new ResolvableTypeReference(floatLiteral.type())));
1182 }
1183
1184 void Checker::visit(AST::NullLiteral& nullLiteral)
1185 {
1186     assignType(nullLiteral, adoptRef(*new ResolvableTypeReference(nullLiteral.type())));
1187 }
1188
1189 void Checker::visit(AST::BooleanLiteral& booleanLiteral)
1190 {
1191     assignType(booleanLiteral, AST::TypeReference::wrap(Lexer::Token(booleanLiteral.origin()), m_intrinsics.boolType()));
1192 }
1193
1194 void Checker::visit(AST::EnumerationMemberLiteral& enumerationMemberLiteral)
1195 {
1196     ASSERT(enumerationMemberLiteral.enumerationDefinition());
1197     auto& enumerationDefinition = *enumerationMemberLiteral.enumerationDefinition();
1198     assignType(enumerationMemberLiteral, AST::TypeReference::wrap(Lexer::Token(enumerationMemberLiteral.origin()), enumerationDefinition));
1199 }
1200
1201 bool Checker::isBoolType(ResolvingType& resolvingType)
1202 {
1203     return resolvingType.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& left) -> bool {
1204         return matches(left, m_intrinsics.boolType());
1205     }, [&](RefPtr<ResolvableTypeReference>& left) -> bool {
1206         return static_cast<bool>(matchAndCommit(m_intrinsics.boolType(), left->resolvableType()));
1207     }));
1208 }
1209
1210 bool Checker::recurseAndRequireBoolType(AST::Expression& expression)
1211 {
1212     auto expressionInfo = recurseAndGetInfo(expression);
1213     if (!expressionInfo)
1214         return false;
1215     if (!isBoolType(expressionInfo->resolvingType)) {
1216         setError();
1217         return false;
1218     }
1219     return true;
1220 }
1221
1222 void Checker::visit(AST::LogicalNotExpression& logicalNotExpression)
1223 {
1224     if (!recurseAndRequireBoolType(logicalNotExpression.operand()))
1225         return;
1226     assignType(logicalNotExpression, AST::TypeReference::wrap(Lexer::Token(logicalNotExpression.origin()), m_intrinsics.boolType()));
1227 }
1228
1229 void Checker::visit(AST::LogicalExpression& logicalExpression)
1230 {
1231     if (!recurseAndRequireBoolType(logicalExpression.left()))
1232         return;
1233     if (!recurseAndRequireBoolType(logicalExpression.right()))
1234         return;
1235     assignType(logicalExpression, AST::TypeReference::wrap(Lexer::Token(logicalExpression.origin()), m_intrinsics.boolType()));
1236 }
1237
1238 void Checker::visit(AST::IfStatement& ifStatement)
1239 {
1240     if (!recurseAndRequireBoolType(ifStatement.conditional()))
1241         return;
1242     checkErrorAndVisit(ifStatement.body());
1243     if (ifStatement.elseBody())
1244         checkErrorAndVisit(*ifStatement.elseBody());
1245 }
1246
1247 void Checker::visit(AST::WhileLoop& whileLoop)
1248 {
1249     if (!recurseAndRequireBoolType(whileLoop.conditional()))
1250         return;
1251     checkErrorAndVisit(whileLoop.body());
1252 }
1253
1254 void Checker::visit(AST::DoWhileLoop& doWhileLoop)
1255 {
1256     checkErrorAndVisit(doWhileLoop.body());
1257     recurseAndRequireBoolType(doWhileLoop.conditional());
1258 }
1259
1260 void Checker::visit(AST::ForLoop& forLoop)
1261 {
1262     WTF::visit(WTF::makeVisitor([&](UniqueRef<AST::Statement>& statement) {
1263         checkErrorAndVisit(statement);
1264     }, [&](UniqueRef<AST::Expression>& expression) {
1265         checkErrorAndVisit(expression);
1266     }), forLoop.initialization());
1267     if (error())
1268         return;
1269     if (forLoop.condition()) {
1270         if (!recurseAndRequireBoolType(*forLoop.condition()))
1271             return;
1272     }
1273     if (forLoop.increment())
1274         checkErrorAndVisit(*forLoop.increment());
1275     checkErrorAndVisit(forLoop.body());
1276 }
1277
1278 void Checker::visit(AST::SwitchStatement& switchStatement)
1279 {
1280     auto* valueType = ([&]() -> AST::NamedType* {
1281         auto valueInfo = recurseAndGetInfo(switchStatement.value());
1282         if (!valueInfo)
1283             return nullptr;
1284         auto* valueType = getUnnamedType(valueInfo->resolvingType);
1285         if (!valueType)
1286             return nullptr;
1287         auto& valueUnifyNode = valueType->unifyNode();
1288         if (!is<AST::NamedType>(valueUnifyNode))
1289             return nullptr;
1290         auto& valueNamedUnifyNode = downcast<AST::NamedType>(valueUnifyNode);
1291         if (!(is<AST::NativeTypeDeclaration>(valueNamedUnifyNode) && downcast<AST::NativeTypeDeclaration>(valueNamedUnifyNode).isInt())
1292             && !is<AST::EnumerationDefinition>(valueNamedUnifyNode))
1293             return nullptr;
1294         return &valueNamedUnifyNode;
1295     })();
1296     if (!valueType) {
1297         setError();
1298         return;
1299     }
1300
1301     bool hasDefault = false;
1302     for (auto& switchCase : switchStatement.switchCases()) {
1303         checkErrorAndVisit(switchCase.block());
1304         if (!switchCase.value()) {
1305             hasDefault = true;
1306             continue;
1307         }
1308         auto success = switchCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) -> bool {
1309             return static_cast<bool>(matchAndCommit(*valueType, integerLiteral.type()));
1310         }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) -> bool {
1311             return static_cast<bool>(matchAndCommit(*valueType, unsignedIntegerLiteral.type()));
1312         }, [&](AST::FloatLiteral& floatLiteral) -> bool {
1313             return static_cast<bool>(matchAndCommit(*valueType, floatLiteral.type()));
1314         }, [&](AST::NullLiteral& nullLiteral) -> bool {
1315             return static_cast<bool>(matchAndCommit(*valueType, nullLiteral.type()));
1316         }, [&](AST::BooleanLiteral&) -> bool {
1317             return matches(*valueType, m_intrinsics.boolType());
1318         }, [&](AST::EnumerationMemberLiteral& enumerationMemberLiteral) -> bool {
1319             ASSERT(enumerationMemberLiteral.enumerationDefinition());
1320             return matches(*valueType, *enumerationMemberLiteral.enumerationDefinition());
1321         }));
1322         if (!success) {
1323             setError();
1324             return;
1325         }
1326     }
1327
1328     for (size_t i = 0; i < switchStatement.switchCases().size(); ++i) {
1329         auto& firstCase = switchStatement.switchCases()[i];
1330         for (size_t j = i + 1; j < switchStatement.switchCases().size(); ++j) {
1331             auto& secondCase = switchStatement.switchCases()[j];
1332             
1333             if (static_cast<bool>(firstCase.value()) != static_cast<bool>(secondCase.value()))
1334                 continue;
1335
1336             if (!static_cast<bool>(firstCase.value())) {
1337                 setError();
1338                 return;
1339             }
1340
1341             auto success = firstCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& firstIntegerLiteral) -> bool {
1342                 return secondCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& secondIntegerLiteral) -> bool {
1343                     return firstIntegerLiteral.value() != secondIntegerLiteral.value();
1344                 }, [&](AST::UnsignedIntegerLiteral& secondUnsignedIntegerLiteral) -> bool {
1345                     return static_cast<int64_t>(firstIntegerLiteral.value()) != static_cast<int64_t>(secondUnsignedIntegerLiteral.value());
1346                 }, [](auto&) -> bool {
1347                     return true;
1348                 }));
1349             }, [&](AST::UnsignedIntegerLiteral& firstUnsignedIntegerLiteral) -> bool {
1350                 return secondCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& secondIntegerLiteral) -> bool {
1351                     return static_cast<int64_t>(firstUnsignedIntegerLiteral.value()) != static_cast<int64_t>(secondIntegerLiteral.value());
1352                 }, [&](AST::UnsignedIntegerLiteral& secondUnsignedIntegerLiteral) -> bool {
1353                     return firstUnsignedIntegerLiteral.value() != secondUnsignedIntegerLiteral.value();
1354                 }, [](auto&) -> bool {
1355                     return true;
1356                 }));
1357             }, [&](AST::EnumerationMemberLiteral& firstEnumerationMemberLiteral) -> bool {
1358                 return secondCase.value()->visit(WTF::makeVisitor([&](AST::EnumerationMemberLiteral& secondEnumerationMemberLiteral) -> bool {
1359                     ASSERT(firstEnumerationMemberLiteral.enumerationMember());
1360                     ASSERT(secondEnumerationMemberLiteral.enumerationMember());
1361                     return firstEnumerationMemberLiteral.enumerationMember() != secondEnumerationMemberLiteral.enumerationMember();
1362                 }, [](auto&) -> bool {
1363                     return true;
1364                 }));
1365             }, [](auto&) -> bool {
1366                 return true;
1367             }));
1368             if (!success) {
1369                 setError();
1370                 return;
1371             }
1372         }
1373     }
1374
1375     if (!hasDefault) {
1376         if (is<AST::NativeTypeDeclaration>(*valueType)) {
1377             HashSet<int64_t> values;
1378             bool zeroValueExists;
1379             for (auto& switchCase : switchStatement.switchCases()) {
1380                 auto value = switchCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) -> int64_t {
1381                     return integerLiteral.valueForSelectedType();
1382                 }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) -> int64_t {
1383                     return unsignedIntegerLiteral.valueForSelectedType();
1384                 }, [](auto&) -> int64_t {
1385                     ASSERT_NOT_REACHED();
1386                     return 0;
1387                 }));
1388                 if (!value)
1389                     zeroValueExists = true;
1390                 else
1391                     values.add(value);
1392             }
1393             bool success = true;
1394             downcast<AST::NativeTypeDeclaration>(*valueType).iterateAllValues([&](int64_t value) -> bool {
1395                 if (!value) {
1396                     if (!zeroValueExists) {
1397                         success = false;
1398                         return true;
1399                     }
1400                     return false;
1401                 }
1402                 if (!values.contains(value)) {
1403                     success = false;
1404                     return true;
1405                 }
1406                 return false;
1407             });
1408             if (!success) {
1409                 setError();
1410                 return;
1411             }
1412         } else {
1413             HashSet<AST::EnumerationMember*> values;
1414             for (auto& switchCase : switchStatement.switchCases()) {
1415                 switchCase.value()->visit(WTF::makeVisitor([&](AST::EnumerationMemberLiteral& enumerationMemberLiteral) {
1416                     ASSERT(enumerationMemberLiteral.enumerationMember());
1417                     values.add(enumerationMemberLiteral.enumerationMember());
1418                 }, [](auto&) {
1419                     ASSERT_NOT_REACHED();
1420                 }));
1421             }
1422             for (auto& enumerationMember : downcast<AST::EnumerationDefinition>(*valueType).enumerationMembers()) {
1423                 if (!values.contains(&enumerationMember.get())) {
1424                     setError();
1425                     return;
1426                 }
1427             }
1428         }
1429     }
1430 }
1431
1432 void Checker::visit(AST::CommaExpression& commaExpression)
1433 {
1434     ASSERT(commaExpression.list().size() > 0);
1435     Visitor::visit(commaExpression);
1436     if (error())
1437         return;
1438     auto lastInfo = getInfo(commaExpression.list().last());
1439     forwardType(commaExpression, lastInfo->resolvingType);
1440 }
1441
1442 void Checker::visit(AST::TernaryExpression& ternaryExpression)
1443 {
1444     auto predicateInfo = recurseAndRequireBoolType(ternaryExpression.predicate());
1445     if (!predicateInfo)
1446         return;
1447
1448     auto bodyInfo = recurseAndGetInfo(ternaryExpression.bodyExpression());
1449     auto elseInfo = recurseAndGetInfo(ternaryExpression.elseExpression());
1450     
1451     auto resultType = matchAndCommit(bodyInfo->resolvingType, elseInfo->resolvingType);
1452     if (!resultType) {
1453         setError();
1454         return;
1455     }
1456
1457     assignType(ternaryExpression, WTFMove(*resultType));
1458 }
1459
1460 void Checker::visit(AST::CallExpression& callExpression)
1461 {
1462     Vector<std::reference_wrapper<ResolvingType>> types;
1463     types.reserveInitialCapacity(callExpression.arguments().size());
1464     for (auto& argument : callExpression.arguments()) {
1465         auto argumentInfo = recurseAndGetInfo(argument);
1466         if (!argumentInfo)
1467             return;
1468         types.uncheckedAppend(argumentInfo->resolvingType);
1469     }
1470     // Don't recurse on the castReturnType, because it's guaranteed to be a NamedType, which will get visited later.
1471     // We don't want to recurse to the same node twice.
1472
1473     ASSERT(callExpression.hasOverloads());
1474     auto* function = resolveFunction(m_program, *callExpression.overloads(), types, callExpression.name(), callExpression.origin(), m_intrinsics, callExpression.castReturnType());
1475     if (!function) {
1476         setError();
1477         return;
1478     }
1479
1480     for (size_t i = 0; i < function->parameters().size(); ++i) {
1481         if (!matchAndCommit(types[i].get(), *function->parameters()[i]->type())) {
1482             setError();
1483             return;
1484         }
1485     }
1486
1487     callExpression.setFunction(*function);
1488
1489     assignType(callExpression, function->type().clone());
1490 }
1491
1492 bool check(Program& program)
1493 {
1494     Checker checker(program.intrinsics(), program);
1495     checker.checkErrorAndVisit(program);
1496     if (checker.error())
1497         return false;
1498     return checker.assignTypes();
1499 }
1500
1501 } // namespace WHLSL
1502
1503 } // namespace WebCore
1504
1505 #endif // ENABLE(WEBGPU)