cc63884e5e0e7de794b9d3104e10f90b22753dfc
[WebKit-https.git] / Source / WebCore / Modules / webgpu / WHLSL / WHLSLChecker.cpp
1 /*
2  * Copyright (C) 2019 Apple Inc. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  * 1. Redistributions of source code must retain the above copyright
8  *    notice, this list of conditions and the following disclaimer.
9  * 2. Redistributions in binary form must reproduce the above copyright
10  *    notice, this list of conditions and the following disclaimer in the
11  *    documentation and/or other materials provided with the distribution.
12  *
13  * THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS ``AS IS''
14  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
15  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS
17  * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23  * THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 #include "config.h"
27 #include "WHLSLChecker.h"
28
29 #if ENABLE(WEBGPU)
30
31 #include "WHLSLArrayReferenceType.h"
32 #include "WHLSLArrayType.h"
33 #include "WHLSLAssignmentExpression.h"
34 #include "WHLSLCallExpression.h"
35 #include "WHLSLCommaExpression.h"
36 #include "WHLSLDereferenceExpression.h"
37 #include "WHLSLDoWhileLoop.h"
38 #include "WHLSLDotExpression.h"
39 #include "WHLSLEntryPointType.h"
40 #include "WHLSLForLoop.h"
41 #include "WHLSLGatherEntryPointItems.h"
42 #include "WHLSLIfStatement.h"
43 #include "WHLSLIndexExpression.h"
44 #include "WHLSLInferTypes.h"
45 #include "WHLSLLogicalExpression.h"
46 #include "WHLSLLogicalNotExpression.h"
47 #include "WHLSLMakeArrayReferenceExpression.h"
48 #include "WHLSLMakePointerExpression.h"
49 #include "WHLSLNameContext.h"
50 #include "WHLSLPointerType.h"
51 #include "WHLSLProgram.h"
52 #include "WHLSLReadModifyWriteExpression.h"
53 #include "WHLSLResolvableType.h"
54 #include "WHLSLResolveOverloadImpl.h"
55 #include "WHLSLResolvingType.h"
56 #include "WHLSLReturn.h"
57 #include "WHLSLSwitchStatement.h"
58 #include "WHLSLTernaryExpression.h"
59 #include "WHLSLVisitor.h"
60 #include "WHLSLWhileLoop.h"
61 #include <wtf/HashMap.h>
62 #include <wtf/HashSet.h>
63 #include <wtf/Ref.h>
64 #include <wtf/Vector.h>
65 #include <wtf/text/WTFString.h>
66
67 namespace WebCore {
68
69 namespace WHLSL {
70
71 class PODChecker : public Visitor {
72 public:
73     PODChecker() = default;
74
75     virtual ~PODChecker() = default;
76
77     void visit(AST::EnumerationDefinition& enumerationDefinition) override
78     {
79         Visitor::visit(enumerationDefinition);
80     }
81
82     void visit(AST::NativeTypeDeclaration& nativeTypeDeclaration) override
83     {
84         if (!nativeTypeDeclaration.isNumber()
85             && !nativeTypeDeclaration.isVector()
86             && !nativeTypeDeclaration.isMatrix())
87             setError();
88     }
89
90     void visit(AST::StructureDefinition& structureDefinition) override
91     {
92         Visitor::visit(structureDefinition);
93     }
94
95     void visit(AST::TypeDefinition& typeDefinition) override
96     {
97         Visitor::visit(typeDefinition);
98     }
99
100     void visit(AST::ArrayType& arrayType) override
101     {
102         Visitor::visit(arrayType);
103     }
104
105     void visit(AST::PointerType&) override
106     {
107         setError();
108     }
109
110     void visit(AST::ArrayReferenceType&) override
111     {
112         setError();
113     }
114
115     void visit(AST::TypeReference& typeReference) override
116     {
117         checkErrorAndVisit(typeReference.resolvedType());
118     }
119 };
120
121 static AST::NativeFunctionDeclaration resolveWithOperatorAnderIndexer(Lexer::Token origin, AST::ArrayReferenceType& firstArgument, const Intrinsics& intrinsics)
122 {
123     const bool isOperator = true;
124     auto returnType = makeUniqueRef<AST::PointerType>(Lexer::Token(origin), firstArgument.addressSpace(), firstArgument.elementType().clone());
125     AST::VariableDeclarations parameters;
126     parameters.append(makeUniqueRef<AST::VariableDeclaration>(Lexer::Token(origin), AST::Qualifiers(), firstArgument.clone(), String(), nullptr, nullptr));
127     parameters.append(makeUniqueRef<AST::VariableDeclaration>(Lexer::Token(origin), AST::Qualifiers(), UniqueRef<AST::UnnamedType>(AST::TypeReference::wrap(Lexer::Token(origin), intrinsics.uintType())), String(), nullptr, nullptr));
128     return AST::NativeFunctionDeclaration(AST::FunctionDeclaration(Lexer::Token(origin), AST::AttributeBlock(), WTF::nullopt, WTFMove(returnType), String("operator&[]", String::ConstructFromLiteral), WTFMove(parameters), nullptr, isOperator));
129 }
130
131 static AST::NativeFunctionDeclaration resolveWithOperatorLength(Lexer::Token origin, AST::UnnamedType& firstArgument, const Intrinsics& intrinsics)
132 {
133     const bool isOperator = true;
134     auto returnType = AST::TypeReference::wrap(Lexer::Token(origin), intrinsics.uintType());
135     AST::VariableDeclarations parameters;
136     parameters.append(makeUniqueRef<AST::VariableDeclaration>(Lexer::Token(origin), AST::Qualifiers(), firstArgument.clone(), String(), nullptr, nullptr));
137     return AST::NativeFunctionDeclaration(AST::FunctionDeclaration(Lexer::Token(origin), AST::AttributeBlock(), WTF::nullopt, WTFMove(returnType), String("operator.length", String::ConstructFromLiteral), WTFMove(parameters), nullptr, isOperator));
138 }
139
140 static AST::NativeFunctionDeclaration resolveWithReferenceComparator(Lexer::Token origin, ResolvingType& firstArgument, ResolvingType& secondArgument, const Intrinsics& intrinsics)
141 {
142     const bool isOperator = true;
143     auto returnType = AST::TypeReference::wrap(Lexer::Token(origin), intrinsics.boolType());
144     auto argumentType = firstArgument.visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> UniqueRef<AST::UnnamedType> {
145         return unnamedType->clone();
146     }, [&](RefPtr<ResolvableTypeReference>&) -> UniqueRef<AST::UnnamedType> {
147         return secondArgument.visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> UniqueRef<AST::UnnamedType> {
148             return unnamedType->clone();
149         }, [&](RefPtr<ResolvableTypeReference>&) -> UniqueRef<AST::UnnamedType> {
150             // We encountered "null == null".
151             // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198162 This can probably be generalized, using the "preferred type" infrastructure used by generic literals
152             ASSERT_NOT_REACHED();
153             return AST::TypeReference::wrap(Lexer::Token(origin), intrinsics.intType());
154         }));
155     }));
156     AST::VariableDeclarations parameters;
157     parameters.append(makeUniqueRef<AST::VariableDeclaration>(Lexer::Token(origin), AST::Qualifiers(), argumentType->clone(), String(), nullptr, nullptr));
158     parameters.append(makeUniqueRef<AST::VariableDeclaration>(Lexer::Token(origin), AST::Qualifiers(), UniqueRef<AST::UnnamedType>(WTFMove(argumentType)), String(), nullptr, nullptr));
159     return AST::NativeFunctionDeclaration(AST::FunctionDeclaration(Lexer::Token(origin), AST::AttributeBlock(), WTF::nullopt, WTFMove(returnType), String("operator==", String::ConstructFromLiteral), WTFMove(parameters), nullptr, isOperator));
160 }
161
// How well an argument type fits a synthesized reference "operator==":
// see resolveByInstantiation(), which is the only user in this file section.
enum class Acceptability {
    Yes, // The argument is a concrete reference type.
    Maybe, // The argument is a still-unresolved null literal.
    No // Anything else.
};
167
168 static Optional<AST::NativeFunctionDeclaration> resolveByInstantiation(const String& name, Lexer::Token origin, const Vector<std::reference_wrapper<ResolvingType>>& types, const Intrinsics& intrinsics)
169 {
170     if (name == "operator&[]" && types.size() == 2) {
171         auto* firstArgumentArrayRef = types[0].get().visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> AST::ArrayReferenceType* {
172             if (is<AST::ArrayReferenceType>(static_cast<AST::UnnamedType&>(unnamedType)))
173                 return &downcast<AST::ArrayReferenceType>(static_cast<AST::UnnamedType&>(unnamedType));
174             return nullptr;
175         }, [](RefPtr<ResolvableTypeReference>&) -> AST::ArrayReferenceType* {
176             return nullptr;
177         }));
178         bool secondArgumentIsUint = types[1].get().visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& unnamedType) -> bool {
179             return matches(unnamedType, intrinsics.uintType());
180         }, [&](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> bool {
181             return resolvableTypeReference->resolvableType().canResolve(intrinsics.uintType());
182         }));
183         if (firstArgumentArrayRef && secondArgumentIsUint)
184             return resolveWithOperatorAnderIndexer(origin, *firstArgumentArrayRef, intrinsics);
185     } else if (name == "operator.length" && types.size() == 1) {
186         auto* firstArgumentReference = types[0].get().visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> AST::UnnamedType* {
187             if (is<AST::ArrayReferenceType>(static_cast<AST::UnnamedType&>(unnamedType)) || is<AST::ArrayType>(static_cast<AST::UnnamedType&>(unnamedType)))
188                 return &unnamedType;
189             return nullptr;
190         }, [](RefPtr<ResolvableTypeReference>&) -> AST::UnnamedType* {
191             return nullptr;
192         }));
193         if (firstArgumentReference)
194             return resolveWithOperatorLength(origin, *firstArgumentReference, intrinsics);
195     } else if (name == "operator==" && types.size() == 2) {
196         auto acceptability = [](ResolvingType& resolvingType) -> Acceptability {
197             return resolvingType.visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> Acceptability {
198                 return is<AST::ReferenceType>(static_cast<AST::UnnamedType&>(unnamedType)) ? Acceptability::Yes : Acceptability::No;
199             }, [](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> Acceptability {
200                 return is<AST::NullLiteralType>(resolvableTypeReference->resolvableType()) ? Acceptability::Maybe : Acceptability::No;
201             }));
202         };
203         auto leftAcceptability = acceptability(types[0].get());
204         auto rightAcceptability = acceptability(types[1].get());
205         bool success = false;
206         if (leftAcceptability == Acceptability::Yes && rightAcceptability == Acceptability::Yes) {
207             auto& unnamedType1 = *types[0].get().getUnnamedType();
208             auto& unnamedType2 = *types[1].get().getUnnamedType();
209             success = matches(unnamedType1, unnamedType2);
210         } else if ((leftAcceptability == Acceptability::Maybe && rightAcceptability == Acceptability::Yes)
211             || (leftAcceptability == Acceptability::Yes && rightAcceptability == Acceptability::Maybe))
212             success = true;
213         if (success)
214             return resolveWithReferenceComparator(origin, types[0].get(), types[1].get(), intrinsics);
215     }
216     return WTF::nullopt;
217 }
218
219 static AST::FunctionDeclaration* resolveFunction(Program& program, Vector<std::reference_wrapper<AST::FunctionDeclaration>, 1>* possibleOverloads, Vector<std::reference_wrapper<ResolvingType>>& types, const String& name, Lexer::Token origin, const Intrinsics& intrinsics, AST::NamedType* castReturnType = nullptr)
220 {
221     if (possibleOverloads) {
222         if (AST::FunctionDeclaration* function = resolveFunctionOverload(*possibleOverloads, types, castReturnType))
223             return function;
224     }
225
226     if (auto newFunction = resolveByInstantiation(name, origin, types, intrinsics)) {
227         program.append(WTFMove(*newFunction));
228         return &program.nativeFunctionDeclarations().last();
229     }
230
231     return nullptr;
232 }
233
234 static bool checkSemantics(Vector<EntryPointItem>& inputItems, Vector<EntryPointItem>& outputItems, const Optional<AST::EntryPointType>& entryPointType, const Intrinsics& intrinsics)
235 {
236     {
237         auto checkDuplicateSemantics = [&](const Vector<EntryPointItem>& items) -> bool {
238             for (size_t i = 0; i < items.size(); ++i) {
239                 for (size_t j = i + 1; j < items.size(); ++j) {
240                     if (items[i].semantic == items[j].semantic)
241                         return false;
242                 }
243             }
244             return true;
245         };
246         if (!checkDuplicateSemantics(inputItems))
247             return false;
248         if (!checkDuplicateSemantics(outputItems))
249             return false;
250     }
251
252     {
253         auto checkSemanticTypes = [&](const Vector<EntryPointItem>& items) -> bool {
254             for (auto& item : items) {
255                 auto acceptable = WTF::visit(WTF::makeVisitor([&](const AST::BaseSemantic& semantic) -> bool {
256                     return semantic.isAcceptableType(*item.unnamedType, intrinsics);
257                 }), *item.semantic);
258                 if (!acceptable)
259                     return false;
260             }
261             return true;
262         };
263         if (!checkSemanticTypes(inputItems))
264             return false;
265         if (!checkSemanticTypes(outputItems))
266             return false;
267     }
268
269     {
270         auto checkSemanticForShaderType = [&](const Vector<EntryPointItem>& items, AST::BaseSemantic::ShaderItemDirection direction) -> bool {
271             for (auto& item : items) {
272                 auto acceptable = WTF::visit(WTF::makeVisitor([&](const AST::BaseSemantic& semantic) -> bool {
273                     return semantic.isAcceptableForShaderItemDirection(direction, entryPointType);
274                 }), *item.semantic);
275                 if (!acceptable)
276                     return false;
277             }
278             return true;
279         };
280         if (!checkSemanticForShaderType(inputItems, AST::BaseSemantic::ShaderItemDirection::Input))
281             return false;
282         if (!checkSemanticForShaderType(outputItems, AST::BaseSemantic::ShaderItemDirection::Output))
283             return false;
284     }
285
286     {
287         auto checkPODData = [&](const Vector<EntryPointItem>& items) -> bool {
288             for (auto& item : items) {
289                 PODChecker podChecker;
290                 if (is<AST::PointerType>(item.unnamedType))
291                     podChecker.checkErrorAndVisit(downcast<AST::PointerType>(*item.unnamedType).elementType());
292                 else if (is<AST::ArrayReferenceType>(item.unnamedType))
293                     podChecker.checkErrorAndVisit(downcast<AST::ArrayReferenceType>(*item.unnamedType).elementType());
294                 else if (is<AST::ArrayType>(item.unnamedType))
295                     podChecker.checkErrorAndVisit(downcast<AST::ArrayType>(*item.unnamedType).type());
296                 else
297                     continue;
298                 if (podChecker.error())
299                     return false;
300             }
301             return true;
302         };
303         if (!checkPODData(inputItems))
304             return false;
305         if (!checkPODData(outputItems))
306             return false;
307     }
308
309     return true;
310 }
311
312 static bool checkOperatorOverload(const AST::FunctionDefinition& functionDefinition, const Intrinsics& intrinsics, NameContext& nameContext)
313 {
314     enum class CheckKind {
315         Index,
316         Dot
317     };
318
319     auto checkGetter = [&](CheckKind kind) -> bool {
320         size_t numExpectedParameters = kind == CheckKind::Index ? 2 : 1;
321         if (functionDefinition.parameters().size() != numExpectedParameters)
322             return false;
323         auto& firstParameterUnifyNode = (*functionDefinition.parameters()[0]->type())->unifyNode();
324         if (is<AST::UnnamedType>(firstParameterUnifyNode)) {
325             auto& unnamedType = downcast<AST::UnnamedType>(firstParameterUnifyNode);
326             if (is<AST::PointerType>(unnamedType) || is<AST::ArrayReferenceType>(unnamedType) || is<AST::ArrayType>(unnamedType))
327                 return false;
328         }
329         if (kind == CheckKind::Index) {
330             auto& secondParameterUnifyNode = (*functionDefinition.parameters()[1]->type())->unifyNode();
331             if (!is<AST::NamedType>(secondParameterUnifyNode))
332                 return false;
333             auto& namedType = downcast<AST::NamedType>(secondParameterUnifyNode);
334             if (!is<AST::NativeTypeDeclaration>(namedType))
335                 return false;
336             auto& nativeTypeDeclaration = downcast<AST::NativeTypeDeclaration>(namedType);
337             if (!nativeTypeDeclaration.isInt())
338                 return false;
339         }
340         return true;
341     };
342
343     auto checkSetter = [&](CheckKind kind) -> bool {
344         size_t numExpectedParameters = kind == CheckKind::Index ? 3 : 2;
345         if (functionDefinition.parameters().size() != numExpectedParameters)
346             return false;
347         auto& firstArgumentUnifyNode = (*functionDefinition.parameters()[0]->type())->unifyNode();
348         if (is<AST::UnnamedType>(firstArgumentUnifyNode)) {
349             auto& unnamedType = downcast<AST::UnnamedType>(firstArgumentUnifyNode);
350             if (is<AST::PointerType>(unnamedType) || is<AST::ArrayReferenceType>(unnamedType) || is<AST::ArrayType>(unnamedType))
351                 return false;
352         }
353         if (kind == CheckKind::Index) {
354             auto& secondParameterUnifyNode = (*functionDefinition.parameters()[1]->type())->unifyNode();
355             if (!is<AST::NamedType>(secondParameterUnifyNode))
356                 return false;
357             auto& namedType = downcast<AST::NamedType>(secondParameterUnifyNode);
358             if (!is<AST::NativeTypeDeclaration>(namedType))
359                 return false;
360             auto& nativeTypeDeclaration = downcast<AST::NativeTypeDeclaration>(namedType);
361             if (!nativeTypeDeclaration.isInt())
362                 return false;
363         }
364         if (!matches(functionDefinition.type(), *functionDefinition.parameters()[0]->type()))
365             return false;
366         auto& valueType = *functionDefinition.parameters()[numExpectedParameters - 1]->type();
367         auto getterName = functionDefinition.name().substring(0, functionDefinition.name().length() - 1);
368         auto* getterFuncs = nameContext.getFunctions(getterName);
369         if (!getterFuncs)
370             return false;
371         Vector<ResolvingType> argumentTypes;
372         Vector<std::reference_wrapper<ResolvingType>> argumentTypeReferences;
373         for (size_t i = 0; i < numExpectedParameters - 1; ++i)
374             argumentTypes.append((*functionDefinition.parameters()[i]->type())->clone());
375         for (auto& argumentType : argumentTypes)
376             argumentTypeReferences.append(argumentType);
377         auto* overload = resolveFunctionOverload(*getterFuncs, argumentTypeReferences);
378         if (!overload)
379             return false;
380         auto& resultType = overload->type();
381         return matches(resultType, valueType);
382     };
383
384     auto checkAnder = [&](CheckKind kind) -> bool {
385         size_t numExpectedParameters = kind == CheckKind::Index ? 2 : 1;
386         if (functionDefinition.parameters().size() != numExpectedParameters)
387             return false;
388         {
389             auto& unifyNode = functionDefinition.type().unifyNode();
390             if (!is<AST::UnnamedType>(unifyNode))
391                 return false;
392             auto& unnamedType = downcast<AST::UnnamedType>(unifyNode);
393             if (!is<AST::PointerType>(unnamedType))
394                 return false;
395         }
396         {
397             auto& unifyNode = (*functionDefinition.parameters()[0]->type())->unifyNode();
398             if (!is<AST::UnnamedType>(unifyNode))
399                 return false;
400             auto& unnamedType = downcast<AST::UnnamedType>(unifyNode);
401             return is<AST::PointerType>(unnamedType) || is<AST::ArrayReferenceType>(unnamedType);
402         }
403     };
404
405     if (!functionDefinition.isOperator())
406         return true;
407     if (functionDefinition.isCast())
408         return true;
409     if (functionDefinition.name() == "operator++" || functionDefinition.name() == "operator--") {
410         return functionDefinition.parameters().size() == 1
411             && matches(*functionDefinition.parameters()[0]->type(), functionDefinition.type());
412     }
413     if (functionDefinition.name() == "operator+" || functionDefinition.name() == "operator-")
414         return functionDefinition.parameters().size() == 1 || functionDefinition.parameters().size() == 2;
415     if (functionDefinition.name() == "operator*"
416         || functionDefinition.name() == "operator/"
417         || functionDefinition.name() == "operator%"
418         || functionDefinition.name() == "operator&"
419         || functionDefinition.name() == "operator|"
420         || functionDefinition.name() == "operator^"
421         || functionDefinition.name() == "operator<<"
422         || functionDefinition.name() == "opreator>>")
423         return functionDefinition.parameters().size() == 2;
424     if (functionDefinition.name() == "operator~")
425         return functionDefinition.parameters().size() == 1;
426     if (functionDefinition.name() == "operator=="
427         || functionDefinition.name() == "operator<"
428         || functionDefinition.name() == "operator<="
429         || functionDefinition.name() == "operator>"
430         || functionDefinition.name() == "operator>=") {
431         return functionDefinition.parameters().size() == 2
432             && matches(functionDefinition.type(), intrinsics.boolType());
433     }
434     if (functionDefinition.name() == "operator[]")
435         return checkGetter(CheckKind::Index);
436     if (functionDefinition.name() == "operator[]=")
437         return checkSetter(CheckKind::Index);
438     if (functionDefinition.name() == "operator&[]")
439         return checkAnder(CheckKind::Index);
440     if (functionDefinition.name().startsWith("operator.")) {
441         if (functionDefinition.name().endsWith("="))
442             return checkSetter(CheckKind::Dot);
443         return checkGetter(CheckKind::Dot);
444     }
445     if (functionDefinition.name().startsWith("operator&."))
446         return checkAnder(CheckKind::Dot);
447     return false;
448 }
449
450 class Checker : public Visitor {
451 public:
452     Checker(const Intrinsics& intrinsics, Program& program)
453         : m_intrinsics(intrinsics)
454         , m_program(program)
455     {
456     }
457
458     virtual ~Checker() = default;
459
460     void visit(Program&) override;
461
462     bool assignTypes();
463
464 private:
465     bool checkShaderType(const AST::FunctionDefinition&);
466     bool isBoolType(ResolvingType&);
467     struct RecurseInfo {
468         ResolvingType& resolvingType;
469         const AST::TypeAnnotation typeAnnotation;
470     };
471     Optional<RecurseInfo> recurseAndGetInfo(AST::Expression&, bool requiresLeftValue = false);
472     Optional<RecurseInfo> getInfo(AST::Expression&, bool requiresLeftValue = false);
473     Optional<UniqueRef<AST::UnnamedType>> recurseAndWrapBaseType(AST::PropertyAccessExpression&);
474     bool recurseAndRequireBoolType(AST::Expression&);
475     void assignType(AST::Expression&, UniqueRef<AST::UnnamedType>&&, AST::TypeAnnotation);
476     void assignType(AST::Expression&, RefPtr<ResolvableTypeReference>&&, AST::TypeAnnotation);
477     void forwardType(AST::Expression&, ResolvingType&, AST::TypeAnnotation);
478
479     void visit(AST::FunctionDefinition&) override;
480     void visit(AST::EnumerationDefinition&) override;
481     void visit(AST::TypeReference&) override;
482     void visit(AST::VariableDeclaration&) override;
483     void visit(AST::AssignmentExpression&) override;
484     void visit(AST::ReadModifyWriteExpression&) override;
485     void visit(AST::DereferenceExpression&) override;
486     void visit(AST::MakePointerExpression&) override;
487     void visit(AST::MakeArrayReferenceExpression&) override;
488     void visit(AST::DotExpression&) override;
489     void visit(AST::IndexExpression&) override;
490     void visit(AST::VariableReference&) override;
491     void visit(AST::Return&) override;
492     void visit(AST::PointerType&) override;
493     void visit(AST::ArrayReferenceType&) override;
494     void visit(AST::IntegerLiteral&) override;
495     void visit(AST::UnsignedIntegerLiteral&) override;
496     void visit(AST::FloatLiteral&) override;
497     void visit(AST::NullLiteral&) override;
498     void visit(AST::BooleanLiteral&) override;
499     void visit(AST::EnumerationMemberLiteral&) override;
500     void visit(AST::LogicalNotExpression&) override;
501     void visit(AST::LogicalExpression&) override;
502     void visit(AST::IfStatement&) override;
503     void visit(AST::WhileLoop&) override;
504     void visit(AST::DoWhileLoop&) override;
505     void visit(AST::ForLoop&) override;
506     void visit(AST::SwitchStatement&) override;
507     void visit(AST::CommaExpression&) override;
508     void visit(AST::TernaryExpression&) override;
509     void visit(AST::CallExpression&) override;
510
511     void finishVisiting(AST::PropertyAccessExpression&, ResolvingType* additionalArgumentType = nullptr);
512
513     HashMap<AST::Expression*, std::unique_ptr<ResolvingType>> m_typeMap;
514     HashMap<AST::Expression*, AST::TypeAnnotation> m_typeAnnotations;
515     HashSet<String> m_vertexEntryPoints;
516     HashSet<String> m_fragmentEntryPoints;
517     HashSet<String> m_computeEntryPoints;
518     const Intrinsics& m_intrinsics;
519     Program& m_program;
520 };
521
522 void Checker::visit(Program& program)
523 {
524     // These visiting functions might add new global statements, so don't use foreach syntax.
525     for (size_t i = 0; i < program.typeDefinitions().size(); ++i)
526         checkErrorAndVisit(program.typeDefinitions()[i]);
527     for (size_t i = 0; i < program.structureDefinitions().size(); ++i)
528         checkErrorAndVisit(program.structureDefinitions()[i]);
529     for (size_t i = 0; i < program.enumerationDefinitions().size(); ++i)
530         checkErrorAndVisit(program.enumerationDefinitions()[i]);
531     for (size_t i = 0; i < program.nativeTypeDeclarations().size(); ++i)
532         checkErrorAndVisit(program.nativeTypeDeclarations()[i]);
533
534     for (size_t i = 0; i < program.functionDefinitions().size(); ++i)
535         checkErrorAndVisit(program.functionDefinitions()[i]);
536     for (size_t i = 0; i < program.nativeFunctionDeclarations().size(); ++i)
537         checkErrorAndVisit(program.nativeFunctionDeclarations()[i]);
538 }
539
540 bool Checker::assignTypes()
541 {
542     for (auto& keyValuePair : m_typeMap) {
543         auto success = keyValuePair.value->visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& unnamedType) -> bool {
544             keyValuePair.key->setType(unnamedType->clone());
545             return true;
546         }, [&](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> bool {
547             if (!resolvableTypeReference->resolvableType().maybeResolvedType()) {
548                 if (!static_cast<bool>(commit(resolvableTypeReference->resolvableType())))
549                     return false;
550             }
551             keyValuePair.key->setType(resolvableTypeReference->resolvableType().resolvedType().clone());
552             return true;
553         }));
554         if (!success)
555             return false;
556     }
557
558     for (auto& keyValuePair : m_typeAnnotations)
559         keyValuePair.key->setTypeAnnotation(WTFMove(keyValuePair.value));
560     return true;
561 }
562
563 bool Checker::checkShaderType(const AST::FunctionDefinition& functionDefinition)
564 {
565     switch (*functionDefinition.entryPointType()) {
566     case AST::EntryPointType::Vertex:
567         return static_cast<bool>(m_vertexEntryPoints.add(functionDefinition.name()));
568     case AST::EntryPointType::Fragment:
569         return static_cast<bool>(m_fragmentEntryPoints.add(functionDefinition.name()));
570     case AST::EntryPointType::Compute:
571         return static_cast<bool>(m_computeEntryPoints.add(functionDefinition.name()));
572     }
573 }
574
575 void Checker::visit(AST::FunctionDefinition& functionDefinition)
576 {
577     if (functionDefinition.entryPointType()) {
578         if (!checkShaderType(functionDefinition)) {
579             setError();
580             return;
581         }
582         auto entryPointItems = gatherEntryPointItems(m_intrinsics, functionDefinition);
583         if (!entryPointItems) {
584             setError();
585             return;
586         }
587         if (!checkSemantics(entryPointItems->inputs, entryPointItems->outputs, functionDefinition.entryPointType(), m_intrinsics)) {
588             setError();
589             return;
590         }
591     }
592     if (!checkOperatorOverload(functionDefinition, m_intrinsics, m_program.nameContext())) {
593         setError();
594         return;
595     }
596
597     Visitor::visit(functionDefinition);
598 }
599
600 static Optional<UniqueRef<AST::UnnamedType>> matchAndCommit(ResolvingType& left, ResolvingType& right)
601 {
602     return left.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& left) -> Optional<UniqueRef<AST::UnnamedType>> {
603         return right.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& right) -> Optional<UniqueRef<AST::UnnamedType>> {
604             if (matches(left, right))
605                 return left->clone();
606             return WTF::nullopt;
607         }, [&](RefPtr<ResolvableTypeReference>& right) -> Optional<UniqueRef<AST::UnnamedType>> {
608             return matchAndCommit(left, right->resolvableType());
609         }));
610     }, [&](RefPtr<ResolvableTypeReference>& left) -> Optional<UniqueRef<AST::UnnamedType>> {
611         return right.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& right) -> Optional<UniqueRef<AST::UnnamedType>> {
612             return matchAndCommit(right, left->resolvableType());
613         }, [&](RefPtr<ResolvableTypeReference>& right) -> Optional<UniqueRef<AST::UnnamedType>> {
614             return matchAndCommit(left->resolvableType(), right->resolvableType());
615         }));
616     }));
617 }
618
619 static Optional<UniqueRef<AST::UnnamedType>> matchAndCommit(ResolvingType& resolvingType, AST::UnnamedType& unnamedType)
620 {
621     return resolvingType.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& resolvingType) -> Optional<UniqueRef<AST::UnnamedType>> {
622         if (matches(unnamedType, resolvingType))
623             return unnamedType.clone();
624         return WTF::nullopt;
625     }, [&](RefPtr<ResolvableTypeReference>& resolvingType) -> Optional<UniqueRef<AST::UnnamedType>> {
626         return matchAndCommit(unnamedType, resolvingType->resolvableType());
627     }));
628 }
629
630 static Optional<UniqueRef<AST::UnnamedType>> matchAndCommit(ResolvingType& resolvingType, AST::NamedType& namedType)
631 {
632     return resolvingType.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& resolvingType) -> Optional<UniqueRef<AST::UnnamedType>> {
633         if (matches(resolvingType, namedType))
634             return resolvingType->clone();
635         return WTF::nullopt;
636     }, [&](RefPtr<ResolvableTypeReference>& resolvingType) -> Optional<UniqueRef<AST::UnnamedType>> {
637         return matchAndCommit(namedType, resolvingType->resolvableType());
638     }));
639 }
640
641 static Optional<UniqueRef<AST::UnnamedType>> commit(ResolvingType& resolvingType)
642 {
643     return resolvingType.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& unnamedType) -> Optional<UniqueRef<AST::UnnamedType>> {
644         return unnamedType->clone();
645     }, [&](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> Optional<UniqueRef<AST::UnnamedType>> {
646         if (!resolvableTypeReference->resolvableType().maybeResolvedType())
647             return commit(resolvableTypeReference->resolvableType());
648         return resolvableTypeReference->resolvableType().resolvedType().clone();
649     }));
650 }
651
652 void Checker::visit(AST::EnumerationDefinition& enumerationDefinition)
653 {
654     auto* baseType = ([&]() -> AST::NativeTypeDeclaration* {
655         checkErrorAndVisit(enumerationDefinition.type());
656         auto& baseType = enumerationDefinition.type().unifyNode();
657         if (!is<AST::NamedType>(baseType))
658             return nullptr;
659         auto& namedType = downcast<AST::NamedType>(baseType);
660         if (!is<AST::NativeTypeDeclaration>(namedType))
661             return nullptr;
662         auto& nativeTypeDeclaration = downcast<AST::NativeTypeDeclaration>(namedType);
663         if (!nativeTypeDeclaration.isInt())
664             return nullptr;
665         return &nativeTypeDeclaration;
666     })();
667     if (!baseType) {
668         setError();
669         return;
670     }
671
672     auto enumerationMembers = enumerationDefinition.enumerationMembers();
673
674     auto matchAndCommitMember = [&](AST::EnumerationMember& member) -> bool {
675         return member.value()->visit(WTF::makeVisitor([&](AST::Expression& value) -> bool {
676             auto valueInfo = recurseAndGetInfo(value);
677             if (!valueInfo)
678                 return false;
679             return static_cast<bool>(matchAndCommit(valueInfo->resolvingType, *baseType));
680         }));
681     };
682
683     for (auto& member : enumerationMembers) {
684         if (!member.get().value())
685             continue;
686
687         if (!matchAndCommitMember(member)) {
688             setError();
689             return;
690         }
691     }
692
693     int64_t nextValue = 0;
694     for (auto& member : enumerationMembers) {
695         if (member.get().value()) {
696             auto value = member.get().value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) -> int64_t {
697                 return integerLiteral.valueForSelectedType();
698             }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) -> int64_t {
699                 return unsignedIntegerLiteral.valueForSelectedType();
700             }, [&](auto&) -> int64_t {
701                 ASSERT_NOT_REACHED();
702                 return 0;
703             }));
704             nextValue = baseType->successor()(value);
705         } else {
706             if (nextValue > std::numeric_limits<int>::max()) {
707                 ASSERT(nextValue <= std::numeric_limits<unsigned>::max());
708                 member.get().setValue(AST::ConstantExpression(AST::UnsignedIntegerLiteral(Lexer::Token(member.get().origin()), static_cast<unsigned>(nextValue))));
709             }
710             ASSERT(nextValue >= std::numeric_limits<int>::min());
711             member.get().setValue(AST::ConstantExpression(AST::IntegerLiteral(Lexer::Token(member.get().origin()), static_cast<int>(nextValue))));
712
713             if (!matchAndCommitMember(member)) {
714                 setError();
715                 return;
716             }
717
718             nextValue = baseType->successor()(nextValue);
719         }
720     }
721
722     auto getValue = [&](AST::EnumerationMember& member) -> int64_t {
723         ASSERT(member.value());
724         auto value = member.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) -> int64_t {
725             return integerLiteral.value();
726         }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) -> int64_t {
727             return unsignedIntegerLiteral.value();
728         }, [&](auto&) -> int64_t {
729             ASSERT_NOT_REACHED();
730             return 0;
731         }));
732         return value;
733     };
734
735     for (size_t i = 0; i < enumerationMembers.size(); ++i) {
736         auto value = getValue(enumerationMembers[i].get());
737         for (size_t j = i + 1; j < enumerationMembers.size(); ++j) {
738             auto otherValue = getValue(enumerationMembers[j].get());
739             if (value == otherValue) {
740                 setError();
741                 return;
742             }
743         }
744     }
745
746     bool foundZero = false;
747     for (auto& member : enumerationMembers) {
748         if (!getValue(member.get())) {
749             foundZero = true;
750             break;
751         }
752     }
753     if (!foundZero) {
754         setError();
755         return;
756     }
757 }
758
759 void Checker::visit(AST::TypeReference& typeReference)
760 {
761     ASSERT(typeReference.maybeResolvedType());
762
763     for (auto& typeArgument : typeReference.typeArguments())
764         checkErrorAndVisit(typeArgument);
765 }
766
767 auto Checker::recurseAndGetInfo(AST::Expression& expression, bool requiresLeftValue) -> Optional<RecurseInfo>
768 {
769     Visitor::visit(expression);
770     if (error())
771         return WTF::nullopt;
772     return getInfo(expression, requiresLeftValue);
773 }
774
775 auto Checker::getInfo(AST::Expression& expression, bool requiresLeftValue) -> Optional<RecurseInfo>
776 {
777     auto typeIterator = m_typeMap.find(&expression);
778     ASSERT(typeIterator != m_typeMap.end());
779
780     auto typeAnnotationIterator = m_typeAnnotations.find(&expression);
781     ASSERT(typeAnnotationIterator != m_typeAnnotations.end());
782     if (requiresLeftValue && typeAnnotationIterator->value.isRightValue()) {
783         setError();
784         return WTF::nullopt;
785     }
786     return {{ *typeIterator->value, typeAnnotationIterator->value }};
787 }
788
789 void Checker::visit(AST::VariableDeclaration& variableDeclaration)
790 {
791     // ReadModifyWriteExpressions are the only place where anonymous variables exist,
792     // and that doesn't recurse on the anonymous variables, so we can assume the variable has a type.
793     checkErrorAndVisit(*variableDeclaration.type());
794     if (variableDeclaration.initializer()) {
795         auto& lhsType = *variableDeclaration.type();
796         auto initializerInfo = recurseAndGetInfo(*variableDeclaration.initializer());
797         if (!initializerInfo)
798             return;
799         if (!matchAndCommit(initializerInfo->resolvingType, lhsType)) {
800             setError();
801             return;
802         }
803     }
804 }
805
806 void Checker::assignType(AST::Expression& expression, UniqueRef<AST::UnnamedType>&& unnamedType, AST::TypeAnnotation typeAnnotation = AST::RightValue())
807 {
808     auto addResult = m_typeMap.add(&expression, std::make_unique<ResolvingType>(WTFMove(unnamedType)));
809     ASSERT_UNUSED(addResult, addResult.isNewEntry);
810     auto typeAnnotationAddResult = m_typeAnnotations.add(&expression, WTFMove(typeAnnotation));
811     ASSERT_UNUSED(typeAnnotationAddResult, typeAnnotationAddResult.isNewEntry);
812 }
813
814 void Checker::assignType(AST::Expression& expression, RefPtr<ResolvableTypeReference>&& resolvableTypeReference, AST::TypeAnnotation typeAnnotation = AST::RightValue())
815 {
816     auto addResult = m_typeMap.add(&expression, std::make_unique<ResolvingType>(WTFMove(resolvableTypeReference)));
817     ASSERT_UNUSED(addResult, addResult.isNewEntry);
818     auto typeAnnotationAddResult = m_typeAnnotations.add(&expression, WTFMove(typeAnnotation));
819     ASSERT_UNUSED(typeAnnotationAddResult, typeAnnotationAddResult.isNewEntry);
820 }
821
822 void Checker::forwardType(AST::Expression& expression, ResolvingType& resolvingType, AST::TypeAnnotation typeAnnotation = AST::RightValue())
823 {
824     resolvingType.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& result) {
825         auto addResult = m_typeMap.add(&expression, std::make_unique<ResolvingType>(result->clone()));
826         ASSERT_UNUSED(addResult, addResult.isNewEntry);
827     }, [&](RefPtr<ResolvableTypeReference>& result) {
828         auto addResult = m_typeMap.add(&expression, std::make_unique<ResolvingType>(result.copyRef()));
829         ASSERT_UNUSED(addResult, addResult.isNewEntry);
830     }));
831     auto typeAnnotationAddResult = m_typeAnnotations.add(&expression, WTFMove(typeAnnotation));
832     ASSERT_UNUSED(typeAnnotationAddResult, typeAnnotationAddResult.isNewEntry);
833 }
834
835 void Checker::visit(AST::AssignmentExpression& assignmentExpression)
836 {
837     auto leftInfo = recurseAndGetInfo(assignmentExpression.left(), true);
838     if (!leftInfo)
839         return;
840
841     if (leftInfo->typeAnnotation.isRightValue()) {
842         setError();
843         return;
844     }
845
846     auto rightInfo = recurseAndGetInfo(assignmentExpression.right());
847     if (!rightInfo)
848         return;
849
850     auto resultType = matchAndCommit(leftInfo->resolvingType, rightInfo->resolvingType);
851     if (!resultType) {
852         setError();
853         return;
854     }
855
856     assignType(assignmentExpression, WTFMove(*resultType));
857 }
858
859 void Checker::visit(AST::ReadModifyWriteExpression& readModifyWriteExpression)
860 {
861     auto leftValueInfo = recurseAndGetInfo(readModifyWriteExpression.leftValue(), true);
862     if (!leftValueInfo)
863         return;
864
865     readModifyWriteExpression.oldValue().setType(leftValueInfo->resolvingType.getUnnamedType()->clone());
866
867     auto newValueInfo = recurseAndGetInfo(readModifyWriteExpression.newValueExpression());
868     if (!newValueInfo)
869         return;
870
871     if (Optional<UniqueRef<AST::UnnamedType>> matchedType = matchAndCommit(leftValueInfo->resolvingType, newValueInfo->resolvingType))
872         readModifyWriteExpression.newValue().setType(WTFMove(matchedType.value()));
873     else {
874         setError();
875         return;
876     }
877
878     auto resultInfo = recurseAndGetInfo(readModifyWriteExpression.resultExpression());
879     if (!resultInfo)
880         return;
881
882     forwardType(readModifyWriteExpression, resultInfo->resolvingType);
883 }
884
885 static AST::UnnamedType* getUnnamedType(ResolvingType& resolvingType)
886 {
887     return resolvingType.visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& type) -> AST::UnnamedType* {
888         return &type;
889     }, [](RefPtr<ResolvableTypeReference>& type) -> AST::UnnamedType* {
890         // FIXME: If the type isn't committed, should we just commit() it now?
891         return type->resolvableType().maybeResolvedType();
892     }));
893 }
894
895 void Checker::visit(AST::DereferenceExpression& dereferenceExpression)
896 {
897     auto pointerInfo = recurseAndGetInfo(dereferenceExpression.pointer());
898     if (!pointerInfo)
899         return;
900
901     auto* unnamedType = getUnnamedType(pointerInfo->resolvingType);
902
903     auto* pointerType = ([&](AST::UnnamedType* unnamedType) -> AST::PointerType* {
904         if (!unnamedType)
905             return nullptr;
906         auto& unifyNode = unnamedType->unifyNode();
907         if (!is<AST::UnnamedType>(unifyNode))
908             return nullptr;
909         auto& unnamedUnifyType = downcast<AST::UnnamedType>(unifyNode);
910         if (!is<AST::PointerType>(unnamedUnifyType))
911             return nullptr;
912         return &downcast<AST::PointerType>(unnamedUnifyType);
913     })(unnamedType);
914     if (!pointerType) {
915         setError();
916         return;
917     }
918
919     assignType(dereferenceExpression, pointerType->elementType().clone(), AST::LeftValue { pointerType->addressSpace() });
920 }
921
922 void Checker::visit(AST::MakePointerExpression& makePointerExpression)
923 {
924     auto leftValueInfo = recurseAndGetInfo(makePointerExpression.leftValue(), true);
925     if (!leftValueInfo)
926         return;
927
928     auto leftAddressSpace = leftValueInfo->typeAnnotation.leftAddressSpace();
929     if (!leftAddressSpace) {
930         setError();
931         return;
932     }
933
934     auto* leftValueType = getUnnamedType(leftValueInfo->resolvingType);
935     if (!leftValueType) {
936         setError();
937         return;
938     }
939
940     assignType(makePointerExpression, makeUniqueRef<AST::PointerType>(Lexer::Token(makePointerExpression.origin()), *leftAddressSpace, leftValueType->clone()));
941 }
942
943 void Checker::visit(AST::MakeArrayReferenceExpression& makeArrayReferenceExpression)
944 {
945     auto leftValueInfo = recurseAndGetInfo(makeArrayReferenceExpression.leftValue());
946     if (!leftValueInfo)
947         return;
948
949     auto* leftValueType = getUnnamedType(leftValueInfo->resolvingType);
950     if (!leftValueType) {
951         setError();
952         return;
953     }
954
955     auto& unifyNode = leftValueType->unifyNode();
956     if (is<AST::UnnamedType>(unifyNode)) {
957         auto& unnamedType = downcast<AST::UnnamedType>(unifyNode);
958         if (is<AST::PointerType>(unnamedType)) {
959             auto& pointerType = downcast<AST::PointerType>(unnamedType);
960             // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198163 Save the fact that we're not targetting the item; we're targetting the item's inner element.
961             assignType(makeArrayReferenceExpression, makeUniqueRef<AST::ArrayReferenceType>(Lexer::Token(makeArrayReferenceExpression.origin()), pointerType.addressSpace(), pointerType.elementType().clone()));
962             return;
963         }
964
965         auto leftAddressSpace = leftValueInfo->typeAnnotation.leftAddressSpace();
966         if (!leftAddressSpace) {
967             setError();
968             return;
969         }
970
971         if (is<AST::ArrayType>(unnamedType)) {
972             auto& arrayType = downcast<AST::ArrayType>(unnamedType);
973             // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198163 Save the number of elements.
974             assignType(makeArrayReferenceExpression, makeUniqueRef<AST::ArrayReferenceType>(Lexer::Token(makeArrayReferenceExpression.origin()), *leftAddressSpace, arrayType.type().clone()));
975             return;
976         }
977     }
978
979     auto leftAddressSpace = leftValueInfo->typeAnnotation.leftAddressSpace();
980     if (!leftAddressSpace) {
981         setError();
982         return;
983     }
984
985     assignType(makeArrayReferenceExpression, makeUniqueRef<AST::ArrayReferenceType>(Lexer::Token(makeArrayReferenceExpression.origin()), *leftAddressSpace, leftValueType->clone()));
986 }
987
988 static Optional<UniqueRef<AST::UnnamedType>> argumentTypeForAndOverload(AST::UnnamedType& baseType, AST::AddressSpace addressSpace)
989 {
990     auto& unifyNode = baseType.unifyNode();
991     if (is<AST::NamedType>(unifyNode)) {
992         auto& namedType = downcast<AST::NamedType>(unifyNode);
993         return { makeUniqueRef<AST::PointerType>(Lexer::Token(namedType.origin()), addressSpace, AST::TypeReference::wrap(Lexer::Token(namedType.origin()), namedType)) };
994     }
995
996     auto& unnamedType = downcast<AST::UnnamedType>(unifyNode);
997
998     if (is<AST::ArrayReferenceType>(unnamedType))
999         return unnamedType.clone();
1000
1001     if (is<AST::ArrayType>(unnamedType))
1002         return { makeUniqueRef<AST::ArrayReferenceType>(Lexer::Token(unnamedType.origin()), addressSpace, downcast<AST::ArrayType>(unnamedType).type().clone()) };
1003
1004     if (is<AST::PointerType>(unnamedType))
1005         return WTF::nullopt;
1006
1007     return { makeUniqueRef<AST::PointerType>(Lexer::Token(unnamedType.origin()), addressSpace, unnamedType.clone()) };
1008 }
1009
1010 void Checker::finishVisiting(AST::PropertyAccessExpression& propertyAccessExpression, ResolvingType* additionalArgumentType)
1011 {
1012     auto baseInfo = recurseAndGetInfo(propertyAccessExpression.base());
1013     if (!baseInfo)
1014         return;
1015     auto baseUnnamedType = commit(baseInfo->resolvingType);
1016     if (!baseUnnamedType)
1017         return;
1018
1019     AST::FunctionDeclaration* getterFunction = nullptr;
1020     AST::UnnamedType* getterReturnType = nullptr;
1021     {
1022         Vector<std::reference_wrapper<ResolvingType>> getterArgumentTypes { baseInfo->resolvingType };
1023         if (additionalArgumentType)
1024             getterArgumentTypes.append(*additionalArgumentType);
1025         auto getterName = propertyAccessExpression.getterFunctionName();
1026         auto* getterFunctions = m_program.nameContext().getFunctions(getterName);
1027         getterFunction = resolveFunction(m_program, getterFunctions, getterArgumentTypes, getterName, propertyAccessExpression.origin(), m_intrinsics);
1028         if (getterFunction)
1029             getterReturnType = &getterFunction->type();
1030     }
1031
1032     AST::FunctionDeclaration* anderFunction = nullptr;
1033     AST::UnnamedType* anderReturnType = nullptr;
1034     auto leftAddressSpace = baseInfo->typeAnnotation.leftAddressSpace();
1035     if (leftAddressSpace) {
1036         if (auto argumentTypeForAndOverload = WHLSL::argumentTypeForAndOverload(*baseUnnamedType, *leftAddressSpace)) {
1037             ResolvingType argumentType = { WTFMove(*argumentTypeForAndOverload) };
1038             Vector<std::reference_wrapper<ResolvingType>> anderArgumentTypes { argumentType };
1039             if (additionalArgumentType)
1040                 anderArgumentTypes.append(*additionalArgumentType);
1041             auto anderName = propertyAccessExpression.anderFunctionName();
1042             auto* anderFunctions = m_program.nameContext().getFunctions(anderName);
1043             anderFunction = resolveFunction(m_program, anderFunctions, anderArgumentTypes, anderName, propertyAccessExpression.origin(), m_intrinsics);
1044             if (anderFunction)
1045                 anderReturnType = &downcast<AST::PointerType>(anderFunction->type()).elementType(); // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198164 Enforce the return of anders will always be a pointer
1046         }
1047     }
1048
1049     AST::FunctionDeclaration* threadAnderFunction = nullptr;
1050     AST::UnnamedType* threadAnderReturnType = nullptr;
1051     if (auto argumentTypeForAndOverload = WHLSL::argumentTypeForAndOverload(*baseUnnamedType, AST::AddressSpace::Thread)) {
1052         ResolvingType argumentType = { makeUniqueRef<AST::PointerType>(Lexer::Token(propertyAccessExpression.origin()), AST::AddressSpace::Thread, baseUnnamedType->get().clone()) };
1053         Vector<std::reference_wrapper<ResolvingType>> threadAnderArgumentTypes { argumentType };
1054         if (additionalArgumentType)
1055             threadAnderArgumentTypes.append(*additionalArgumentType);
1056         auto anderName = propertyAccessExpression.anderFunctionName();
1057         auto* anderFunctions = m_program.nameContext().getFunctions(anderName);
1058         threadAnderFunction = resolveFunction(m_program, anderFunctions, threadAnderArgumentTypes, anderName, propertyAccessExpression.origin(), m_intrinsics);
1059         if (threadAnderFunction)
1060             threadAnderReturnType = &downcast<AST::PointerType>(threadAnderFunction->type()).elementType(); // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198164 Enforce the return of anders will always be a pointer
1061     }
1062
1063     if (leftAddressSpace && !anderFunction && !getterFunction) {
1064         setError();
1065         return;
1066     }
1067
1068     if (!leftAddressSpace && !threadAnderFunction && !getterFunction) {
1069         setError();
1070         return;
1071     }
1072
1073     if (threadAnderFunction && getterFunction) {
1074         setError();
1075         return;
1076     }
1077
1078     if (anderFunction && threadAnderFunction && !matches(*anderReturnType, *threadAnderReturnType)) {
1079         setError();
1080         return;
1081     }
1082
1083     if (getterFunction && anderFunction && !matches(*getterReturnType, *anderReturnType)) {
1084         setError();
1085         return;
1086     }
1087
1088     if (getterFunction && threadAnderFunction && !matches(*getterReturnType, *threadAnderReturnType)) {
1089         setError();
1090         return;
1091     }
1092
1093     AST::UnnamedType* fieldType = getterReturnType ? getterReturnType : anderReturnType ? anderReturnType : threadAnderReturnType;
1094
1095     AST::FunctionDeclaration* setterFunction = nullptr;
1096     AST::UnnamedType* setterReturnType = nullptr;
1097     {
1098         ResolvingType fieldResolvingType(fieldType->clone());
1099         Vector<std::reference_wrapper<ResolvingType>> setterArgumentTypes { baseInfo->resolvingType };
1100         if (additionalArgumentType)
1101             setterArgumentTypes.append(*additionalArgumentType);
1102         setterArgumentTypes.append(fieldResolvingType);
1103         auto setterName = propertyAccessExpression.setterFunctionName();
1104         auto* setterFunctions = m_program.nameContext().getFunctions(setterName);
1105         setterFunction = resolveFunction(m_program, setterFunctions, setterArgumentTypes, setterName, propertyAccessExpression.origin(), m_intrinsics);
1106         if (setterFunction)
1107             setterReturnType = &setterFunction->type();
1108     }
1109
1110     if (setterFunction && !getterFunction) {
1111         setError();
1112         return;
1113     }
1114
1115     propertyAccessExpression.setGetterFunction(getterFunction);
1116     propertyAccessExpression.setAnderFunction(anderFunction);
1117     propertyAccessExpression.setThreadAnderFunction(threadAnderFunction);
1118     propertyAccessExpression.setSetterFunction(setterFunction);
1119
1120     AST::TypeAnnotation typeAnnotation = AST::RightValue();
1121     if (auto leftAddressSpace = baseInfo->typeAnnotation.leftAddressSpace()) {
1122         if (anderFunction)
1123             typeAnnotation = AST::LeftValue { *leftAddressSpace };
1124         else if (setterFunction)
1125             typeAnnotation = AST::AbstractLeftValue();
1126     } else if (!baseInfo->typeAnnotation.isRightValue() && (setterFunction || threadAnderFunction))
1127         typeAnnotation = AST::AbstractLeftValue();
1128     assignType(propertyAccessExpression, fieldType->clone(), WTFMove(typeAnnotation));
1129 }
1130
1131 void Checker::visit(AST::DotExpression& dotExpression)
1132 {
1133     finishVisiting(dotExpression);
1134 }
1135
1136 void Checker::visit(AST::IndexExpression& indexExpression)
1137 {
1138     auto baseInfo = recurseAndGetInfo(indexExpression.indexExpression());
1139     if (!baseInfo)
1140         return;
1141     finishVisiting(indexExpression, &baseInfo->resolvingType);
1142 }
1143
1144 void Checker::visit(AST::VariableReference& variableReference)
1145 {
1146     ASSERT(variableReference.variable());
1147     ASSERT(variableReference.variable()->type());
1148     
1149     assignType(variableReference, variableReference.variable()->type()->clone(), AST::LeftValue { AST::AddressSpace::Thread });
1150 }
1151
1152 void Checker::visit(AST::Return& returnStatement)
1153 {
1154     ASSERT(returnStatement.function());
1155     if (returnStatement.value()) {
1156         auto valueInfo = recurseAndGetInfo(*returnStatement.value());
1157         if (!valueInfo)
1158             return;
1159         if (!matchAndCommit(valueInfo->resolvingType, returnStatement.function()->type()))
1160             setError();
1161         return;
1162     }
1163
1164     if (!matches(returnStatement.function()->type(), m_intrinsics.voidType()))
1165         setError();
1166 }
1167
1168 void Checker::visit(AST::PointerType&)
1169 {
1170     // Following pointer types can cause infinite loops because of data structures
1171     // like linked lists.
1172     // FIXME: Make sure this function should be empty
1173 }
1174
1175 void Checker::visit(AST::ArrayReferenceType&)
1176 {
1177     // Following array reference types can cause infinite loops because of data
1178     // structures like linked lists.
1179     // FIXME: Make sure this function should be empty
1180 }
1181
1182 void Checker::visit(AST::IntegerLiteral& integerLiteral)
1183 {
1184     assignType(integerLiteral, adoptRef(*new ResolvableTypeReference(integerLiteral.type())));
1185 }
1186
1187 void Checker::visit(AST::UnsignedIntegerLiteral& unsignedIntegerLiteral)
1188 {
1189     assignType(unsignedIntegerLiteral, adoptRef(*new ResolvableTypeReference(unsignedIntegerLiteral.type())));
1190 }
1191
1192 void Checker::visit(AST::FloatLiteral& floatLiteral)
1193 {
1194     assignType(floatLiteral, adoptRef(*new ResolvableTypeReference(floatLiteral.type())));
1195 }
1196
1197 void Checker::visit(AST::NullLiteral& nullLiteral)
1198 {
1199     assignType(nullLiteral, adoptRef(*new ResolvableTypeReference(nullLiteral.type())));
1200 }
1201
1202 void Checker::visit(AST::BooleanLiteral& booleanLiteral)
1203 {
1204     assignType(booleanLiteral, AST::TypeReference::wrap(Lexer::Token(booleanLiteral.origin()), m_intrinsics.boolType()));
1205 }
1206
1207 void Checker::visit(AST::EnumerationMemberLiteral& enumerationMemberLiteral)
1208 {
1209     ASSERT(enumerationMemberLiteral.enumerationDefinition());
1210     auto& enumerationDefinition = *enumerationMemberLiteral.enumerationDefinition();
1211     assignType(enumerationMemberLiteral, AST::TypeReference::wrap(Lexer::Token(enumerationMemberLiteral.origin()), enumerationDefinition));
1212 }
1213
1214 bool Checker::isBoolType(ResolvingType& resolvingType)
1215 {
1216     return resolvingType.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& left) -> bool {
1217         return matches(left, m_intrinsics.boolType());
1218     }, [&](RefPtr<ResolvableTypeReference>& left) -> bool {
1219         return static_cast<bool>(matchAndCommit(m_intrinsics.boolType(), left->resolvableType()));
1220     }));
1221 }
1222
1223 bool Checker::recurseAndRequireBoolType(AST::Expression& expression)
1224 {
1225     auto expressionInfo = recurseAndGetInfo(expression);
1226     if (!expressionInfo)
1227         return false;
1228     if (!isBoolType(expressionInfo->resolvingType)) {
1229         setError();
1230         return false;
1231     }
1232     return true;
1233 }
1234
1235 void Checker::visit(AST::LogicalNotExpression& logicalNotExpression)
1236 {
1237     if (!recurseAndRequireBoolType(logicalNotExpression.operand()))
1238         return;
1239     assignType(logicalNotExpression, AST::TypeReference::wrap(Lexer::Token(logicalNotExpression.origin()), m_intrinsics.boolType()));
1240 }
1241
1242 void Checker::visit(AST::LogicalExpression& logicalExpression)
1243 {
1244     if (!recurseAndRequireBoolType(logicalExpression.left()))
1245         return;
1246     if (!recurseAndRequireBoolType(logicalExpression.right()))
1247         return;
1248     assignType(logicalExpression, AST::TypeReference::wrap(Lexer::Token(logicalExpression.origin()), m_intrinsics.boolType()));
1249 }
1250
1251 void Checker::visit(AST::IfStatement& ifStatement)
1252 {
1253     if (!recurseAndRequireBoolType(ifStatement.conditional()))
1254         return;
1255     checkErrorAndVisit(ifStatement.body());
1256     if (ifStatement.elseBody())
1257         checkErrorAndVisit(*ifStatement.elseBody());
1258 }
1259
1260 void Checker::visit(AST::WhileLoop& whileLoop)
1261 {
1262     if (!recurseAndRequireBoolType(whileLoop.conditional()))
1263         return;
1264     checkErrorAndVisit(whileLoop.body());
1265 }
1266
1267 void Checker::visit(AST::DoWhileLoop& doWhileLoop)
1268 {
1269     checkErrorAndVisit(doWhileLoop.body());
1270     recurseAndRequireBoolType(doWhileLoop.conditional());
1271 }
1272
1273 void Checker::visit(AST::ForLoop& forLoop)
1274 {
1275     WTF::visit(WTF::makeVisitor([&](UniqueRef<AST::Statement>& statement) {
1276         checkErrorAndVisit(statement);
1277     }, [&](UniqueRef<AST::Expression>& expression) {
1278         checkErrorAndVisit(expression);
1279     }), forLoop.initialization());
1280     if (error())
1281         return;
1282     if (forLoop.condition()) {
1283         if (!recurseAndRequireBoolType(*forLoop.condition()))
1284             return;
1285     }
1286     if (forLoop.increment())
1287         checkErrorAndVisit(*forLoop.increment());
1288     checkErrorAndVisit(forLoop.body());
1289 }
1290
1291 void Checker::visit(AST::SwitchStatement& switchStatement)
1292 {
1293     auto* valueType = ([&]() -> AST::NamedType* {
1294         auto valueInfo = recurseAndGetInfo(switchStatement.value());
1295         if (!valueInfo)
1296             return nullptr;
1297         auto* valueType = getUnnamedType(valueInfo->resolvingType);
1298         if (!valueType)
1299             return nullptr;
1300         auto& valueUnifyNode = valueType->unifyNode();
1301         if (!is<AST::NamedType>(valueUnifyNode))
1302             return nullptr;
1303         auto& valueNamedUnifyNode = downcast<AST::NamedType>(valueUnifyNode);
1304         if (!(is<AST::NativeTypeDeclaration>(valueNamedUnifyNode) && downcast<AST::NativeTypeDeclaration>(valueNamedUnifyNode).isInt())
1305             && !is<AST::EnumerationDefinition>(valueNamedUnifyNode))
1306             return nullptr;
1307         return &valueNamedUnifyNode;
1308     })();
1309     if (!valueType) {
1310         setError();
1311         return;
1312     }
1313
1314     bool hasDefault = false;
1315     for (auto& switchCase : switchStatement.switchCases()) {
1316         checkErrorAndVisit(switchCase.block());
1317         if (!switchCase.value()) {
1318             hasDefault = true;
1319             continue;
1320         }
1321         auto success = switchCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) -> bool {
1322             return static_cast<bool>(matchAndCommit(*valueType, integerLiteral.type()));
1323         }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) -> bool {
1324             return static_cast<bool>(matchAndCommit(*valueType, unsignedIntegerLiteral.type()));
1325         }, [&](AST::FloatLiteral& floatLiteral) -> bool {
1326             return static_cast<bool>(matchAndCommit(*valueType, floatLiteral.type()));
1327         }, [&](AST::NullLiteral& nullLiteral) -> bool {
1328             return static_cast<bool>(matchAndCommit(*valueType, nullLiteral.type()));
1329         }, [&](AST::BooleanLiteral&) -> bool {
1330             return matches(*valueType, m_intrinsics.boolType());
1331         }, [&](AST::EnumerationMemberLiteral& enumerationMemberLiteral) -> bool {
1332             ASSERT(enumerationMemberLiteral.enumerationDefinition());
1333             return matches(*valueType, *enumerationMemberLiteral.enumerationDefinition());
1334         }));
1335         if (!success) {
1336             setError();
1337             return;
1338         }
1339     }
1340
1341     for (size_t i = 0; i < switchStatement.switchCases().size(); ++i) {
1342         auto& firstCase = switchStatement.switchCases()[i];
1343         for (size_t j = i + 1; j < switchStatement.switchCases().size(); ++j) {
1344             auto& secondCase = switchStatement.switchCases()[j];
1345             
1346             if (static_cast<bool>(firstCase.value()) != static_cast<bool>(secondCase.value()))
1347                 continue;
1348
1349             if (!static_cast<bool>(firstCase.value())) {
1350                 setError();
1351                 return;
1352             }
1353
1354             auto success = firstCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& firstIntegerLiteral) -> bool {
1355                 return secondCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& secondIntegerLiteral) -> bool {
1356                     return firstIntegerLiteral.value() != secondIntegerLiteral.value();
1357                 }, [&](AST::UnsignedIntegerLiteral& secondUnsignedIntegerLiteral) -> bool {
1358                     return static_cast<int64_t>(firstIntegerLiteral.value()) != static_cast<int64_t>(secondUnsignedIntegerLiteral.value());
1359                 }, [](auto&) -> bool {
1360                     return true;
1361                 }));
1362             }, [&](AST::UnsignedIntegerLiteral& firstUnsignedIntegerLiteral) -> bool {
1363                 return secondCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& secondIntegerLiteral) -> bool {
1364                     return static_cast<int64_t>(firstUnsignedIntegerLiteral.value()) != static_cast<int64_t>(secondIntegerLiteral.value());
1365                 }, [&](AST::UnsignedIntegerLiteral& secondUnsignedIntegerLiteral) -> bool {
1366                     return firstUnsignedIntegerLiteral.value() != secondUnsignedIntegerLiteral.value();
1367                 }, [](auto&) -> bool {
1368                     return true;
1369                 }));
1370             }, [&](AST::EnumerationMemberLiteral& firstEnumerationMemberLiteral) -> bool {
1371                 return secondCase.value()->visit(WTF::makeVisitor([&](AST::EnumerationMemberLiteral& secondEnumerationMemberLiteral) -> bool {
1372                     ASSERT(firstEnumerationMemberLiteral.enumerationMember());
1373                     ASSERT(secondEnumerationMemberLiteral.enumerationMember());
1374                     return firstEnumerationMemberLiteral.enumerationMember() != secondEnumerationMemberLiteral.enumerationMember();
1375                 }, [](auto&) -> bool {
1376                     return true;
1377                 }));
1378             }, [](auto&) -> bool {
1379                 return true;
1380             }));
1381             if (!success) {
1382                 setError();
1383                 return;
1384             }
1385         }
1386     }
1387
1388     if (!hasDefault) {
1389         if (is<AST::NativeTypeDeclaration>(*valueType)) {
1390             HashSet<int64_t> values;
1391             bool zeroValueExists;
1392             for (auto& switchCase : switchStatement.switchCases()) {
1393                 auto value = switchCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) -> int64_t {
1394                     return integerLiteral.valueForSelectedType();
1395                 }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) -> int64_t {
1396                     return unsignedIntegerLiteral.valueForSelectedType();
1397                 }, [](auto&) -> int64_t {
1398                     ASSERT_NOT_REACHED();
1399                     return 0;
1400                 }));
1401                 if (!value)
1402                     zeroValueExists = true;
1403                 else
1404                     values.add(value);
1405             }
1406             bool success = true;
1407             downcast<AST::NativeTypeDeclaration>(*valueType).iterateAllValues([&](int64_t value) -> bool {
1408                 if (!value) {
1409                     if (!zeroValueExists) {
1410                         success = false;
1411                         return true;
1412                     }
1413                     return false;
1414                 }
1415                 if (!values.contains(value)) {
1416                     success = false;
1417                     return true;
1418                 }
1419                 return false;
1420             });
1421             if (!success) {
1422                 setError();
1423                 return;
1424             }
1425         } else {
1426             HashSet<AST::EnumerationMember*> values;
1427             for (auto& switchCase : switchStatement.switchCases()) {
1428                 switchCase.value()->visit(WTF::makeVisitor([&](AST::EnumerationMemberLiteral& enumerationMemberLiteral) {
1429                     ASSERT(enumerationMemberLiteral.enumerationMember());
1430                     values.add(enumerationMemberLiteral.enumerationMember());
1431                 }, [](auto&) {
1432                     ASSERT_NOT_REACHED();
1433                 }));
1434             }
1435             for (auto& enumerationMember : downcast<AST::EnumerationDefinition>(*valueType).enumerationMembers()) {
1436                 if (!values.contains(&enumerationMember.get())) {
1437                     setError();
1438                     return;
1439                 }
1440             }
1441         }
1442     }
1443 }
1444
1445 void Checker::visit(AST::CommaExpression& commaExpression)
1446 {
1447     ASSERT(commaExpression.list().size() > 0);
1448     Visitor::visit(commaExpression);
1449     if (error())
1450         return;
1451     auto lastInfo = getInfo(commaExpression.list().last());
1452     forwardType(commaExpression, lastInfo->resolvingType);
1453 }
1454
1455 void Checker::visit(AST::TernaryExpression& ternaryExpression)
1456 {
1457     auto predicateInfo = recurseAndRequireBoolType(ternaryExpression.predicate());
1458     if (!predicateInfo)
1459         return;
1460
1461     auto bodyInfo = recurseAndGetInfo(ternaryExpression.bodyExpression());
1462     auto elseInfo = recurseAndGetInfo(ternaryExpression.elseExpression());
1463     
1464     auto resultType = matchAndCommit(bodyInfo->resolvingType, elseInfo->resolvingType);
1465     if (!resultType) {
1466         setError();
1467         return;
1468     }
1469
1470     assignType(ternaryExpression, WTFMove(*resultType));
1471 }
1472
1473 void Checker::visit(AST::CallExpression& callExpression)
1474 {
1475     Vector<std::reference_wrapper<ResolvingType>> types;
1476     types.reserveInitialCapacity(callExpression.arguments().size());
1477     for (auto& argument : callExpression.arguments()) {
1478         auto argumentInfo = recurseAndGetInfo(argument);
1479         if (!argumentInfo)
1480             return;
1481         types.uncheckedAppend(argumentInfo->resolvingType);
1482     }
1483     // Don't recurse on the castReturnType, because it's guaranteed to be a NamedType, which will get visited later.
1484     // We don't want to recurse to the same node twice.
1485
1486     NameContext& nameContext = m_program.nameContext();
1487     auto* functions = nameContext.getFunctions(callExpression.name());
1488     if (!functions) {
1489         if (auto* types = nameContext.getTypes(callExpression.name())) {
1490             if (types->size() == 1) {
1491                 if ((functions = nameContext.getFunctions("operator cast"_str)))
1492                     callExpression.setCastData((*types)[0].get());
1493             }
1494         }
1495     }
1496     if (!functions) {
1497         setError();
1498         return;
1499     }
1500
1501     auto* function = resolveFunction(m_program, functions, types, callExpression.name(), callExpression.origin(), m_intrinsics, callExpression.castReturnType());
1502     if (!function) {
1503         setError();
1504         return;
1505     }
1506
1507     for (size_t i = 0; i < function->parameters().size(); ++i) {
1508         if (!matchAndCommit(types[i].get(), *function->parameters()[i]->type())) {
1509             setError();
1510             return;
1511         }
1512     }
1513
1514     callExpression.setFunction(*function);
1515
1516     assignType(callExpression, function->type().clone());
1517 }
1518
1519 bool check(Program& program)
1520 {
1521     Checker checker(program.intrinsics(), program);
1522     checker.checkErrorAndVisit(program);
1523     if (checker.error())
1524         return false;
1525     return checker.assignTypes();
1526 }
1527
1528 } // namespace WHLSL
1529
1530 } // namespace WebCore
1531
1532 #endif // ENABLE(WEBGPU)