[WHLSL] Standard library is too big to directly include in WebCore
[WebKit-https.git] / Source / WebCore / Modules / webgpu / WHLSL / WHLSLChecker.cpp
1 /*
2  * Copyright (C) 2019 Apple Inc. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  * 1. Redistributions of source code must retain the above copyright
8  *    notice, this list of conditions and the following disclaimer.
9  * 2. Redistributions in binary form must reproduce the above copyright
10  *    notice, this list of conditions and the following disclaimer in the
11  *    documentation and/or other materials provided with the distribution.
12  *
13  * THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS ``AS IS''
14  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
15  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS
17  * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23  * THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 #include "config.h"
27 #include "WHLSLChecker.h"
28
29 #if ENABLE(WEBGPU)
30
31 #include "WHLSLArrayReferenceType.h"
32 #include "WHLSLArrayType.h"
33 #include "WHLSLAssignmentExpression.h"
34 #include "WHLSLCallExpression.h"
35 #include "WHLSLCommaExpression.h"
36 #include "WHLSLDereferenceExpression.h"
37 #include "WHLSLDoWhileLoop.h"
38 #include "WHLSLDotExpression.h"
39 #include "WHLSLEntryPointType.h"
40 #include "WHLSLForLoop.h"
41 #include "WHLSLGatherEntryPointItems.h"
42 #include "WHLSLIfStatement.h"
43 #include "WHLSLIndexExpression.h"
44 #include "WHLSLInferTypes.h"
45 #include "WHLSLLogicalExpression.h"
46 #include "WHLSLLogicalNotExpression.h"
47 #include "WHLSLMakeArrayReferenceExpression.h"
48 #include "WHLSLMakePointerExpression.h"
49 #include "WHLSLNameContext.h"
50 #include "WHLSLPointerType.h"
51 #include "WHLSLProgram.h"
52 #include "WHLSLReadModifyWriteExpression.h"
53 #include "WHLSLResolvableType.h"
54 #include "WHLSLResolveOverloadImpl.h"
55 #include "WHLSLResolvingType.h"
56 #include "WHLSLReturn.h"
57 #include "WHLSLSwitchStatement.h"
58 #include "WHLSLTernaryExpression.h"
59 #include "WHLSLVisitor.h"
60 #include "WHLSLWhileLoop.h"
61 #include <wtf/HashMap.h>
62 #include <wtf/HashSet.h>
63 #include <wtf/Ref.h>
64 #include <wtf/Vector.h>
65 #include <wtf/text/WTFString.h>
66
67 namespace WebCore {
68
69 namespace WHLSL {
70
71 class PODChecker : public Visitor {
72 public:
73     PODChecker() = default;
74
75     virtual ~PODChecker() = default;
76
77     void visit(AST::EnumerationDefinition& enumerationDefinition) override
78     {
79         Visitor::visit(enumerationDefinition);
80     }
81
82     void visit(AST::NativeTypeDeclaration& nativeTypeDeclaration) override
83     {
84         if (!nativeTypeDeclaration.isNumber()
85             && !nativeTypeDeclaration.isVector()
86             && !nativeTypeDeclaration.isMatrix())
87             setError();
88     }
89
90     void visit(AST::StructureDefinition& structureDefinition) override
91     {
92         Visitor::visit(structureDefinition);
93     }
94
95     void visit(AST::TypeDefinition& typeDefinition) override
96     {
97         Visitor::visit(typeDefinition);
98     }
99
100     void visit(AST::ArrayType& arrayType) override
101     {
102         Visitor::visit(arrayType);
103     }
104
105     void visit(AST::PointerType&) override
106     {
107         setError();
108     }
109
110     void visit(AST::ArrayReferenceType&) override
111     {
112         setError();
113     }
114
115     void visit(AST::TypeReference& typeReference) override
116     {
117         checkErrorAndVisit(typeReference.resolvedType());
118     }
119 };
120
121 static AST::NativeFunctionDeclaration resolveWithOperatorAnderIndexer(Lexer::Token origin, AST::ArrayReferenceType& firstArgument, const Intrinsics& intrinsics)
122 {
123     const bool isOperator = true;
124     auto returnType = makeUniqueRef<AST::PointerType>(Lexer::Token(origin), firstArgument.addressSpace(), firstArgument.elementType().clone());
125     AST::VariableDeclarations parameters;
126     parameters.append(makeUniqueRef<AST::VariableDeclaration>(Lexer::Token(origin), AST::Qualifiers(), firstArgument.clone(), String(), nullptr, nullptr));
127     parameters.append(makeUniqueRef<AST::VariableDeclaration>(Lexer::Token(origin), AST::Qualifiers(), UniqueRef<AST::UnnamedType>(AST::TypeReference::wrap(Lexer::Token(origin), intrinsics.uintType())), String(), nullptr, nullptr));
128     return AST::NativeFunctionDeclaration(AST::FunctionDeclaration(Lexer::Token(origin), AST::AttributeBlock(), WTF::nullopt, WTFMove(returnType), String("operator&[]", String::ConstructFromLiteral), WTFMove(parameters), nullptr, isOperator));
129 }
130
131 static AST::NativeFunctionDeclaration resolveWithOperatorLength(Lexer::Token origin, AST::UnnamedType& firstArgument, const Intrinsics& intrinsics)
132 {
133     const bool isOperator = true;
134     auto returnType = AST::TypeReference::wrap(Lexer::Token(origin), intrinsics.uintType());
135     AST::VariableDeclarations parameters;
136     parameters.append(makeUniqueRef<AST::VariableDeclaration>(Lexer::Token(origin), AST::Qualifiers(), firstArgument.clone(), String(), nullptr, nullptr));
137     return AST::NativeFunctionDeclaration(AST::FunctionDeclaration(Lexer::Token(origin), AST::AttributeBlock(), WTF::nullopt, WTFMove(returnType), String("operator.length", String::ConstructFromLiteral), WTFMove(parameters), nullptr, isOperator));
138 }
139
140 static AST::NativeFunctionDeclaration resolveWithReferenceComparator(Lexer::Token origin, ResolvingType& firstArgument, ResolvingType& secondArgument, const Intrinsics& intrinsics)
141 {
142     const bool isOperator = true;
143     auto returnType = AST::TypeReference::wrap(Lexer::Token(origin), intrinsics.boolType());
144     auto argumentType = firstArgument.visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> UniqueRef<AST::UnnamedType> {
145         return unnamedType->clone();
146     }, [&](RefPtr<ResolvableTypeReference>&) -> UniqueRef<AST::UnnamedType> {
147         return secondArgument.visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> UniqueRef<AST::UnnamedType> {
148             return unnamedType->clone();
149         }, [&](RefPtr<ResolvableTypeReference>&) -> UniqueRef<AST::UnnamedType> {
150             // We encountered "null == null".
151             // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198162 This can probably be generalized, using the "preferred type" infrastructure used by generic literals
152             ASSERT_NOT_REACHED();
153             return AST::TypeReference::wrap(Lexer::Token(origin), intrinsics.intType());
154         }));
155     }));
156     AST::VariableDeclarations parameters;
157     parameters.append(makeUniqueRef<AST::VariableDeclaration>(Lexer::Token(origin), AST::Qualifiers(), argumentType->clone(), String(), nullptr, nullptr));
158     parameters.append(makeUniqueRef<AST::VariableDeclaration>(Lexer::Token(origin), AST::Qualifiers(), UniqueRef<AST::UnnamedType>(WTFMove(argumentType)), String(), nullptr, nullptr));
159     return AST::NativeFunctionDeclaration(AST::FunctionDeclaration(Lexer::Token(origin), AST::AttributeBlock(), WTF::nullopt, WTFMove(returnType), String("operator==", String::ConstructFromLiteral), WTFMove(parameters), nullptr, isOperator));
160 }
161
// Classification of an operand of a synthesized reference "operator==":
// Yes = concrete reference type, Maybe = unresolved null literal, No = anything else.
enum class Acceptability { Yes, Maybe, No };
167
168 static Optional<AST::NativeFunctionDeclaration> resolveByInstantiation(const String& name, Lexer::Token origin, const Vector<std::reference_wrapper<ResolvingType>>& types, const Intrinsics& intrinsics)
169 {
170     if (name == "operator&[]" && types.size() == 2) {
171         auto* firstArgumentArrayRef = types[0].get().visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> AST::ArrayReferenceType* {
172             if (is<AST::ArrayReferenceType>(static_cast<AST::UnnamedType&>(unnamedType)))
173                 return &downcast<AST::ArrayReferenceType>(static_cast<AST::UnnamedType&>(unnamedType));
174             return nullptr;
175         }, [](RefPtr<ResolvableTypeReference>&) -> AST::ArrayReferenceType* {
176             return nullptr;
177         }));
178         bool secondArgumentIsUint = types[1].get().visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& unnamedType) -> bool {
179             return matches(unnamedType, intrinsics.uintType());
180         }, [&](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> bool {
181             return resolvableTypeReference->resolvableType().canResolve(intrinsics.uintType());
182         }));
183         if (firstArgumentArrayRef && secondArgumentIsUint)
184             return resolveWithOperatorAnderIndexer(origin, *firstArgumentArrayRef, intrinsics);
185     } else if (name == "operator.length" && types.size() == 1) {
186         auto* firstArgumentReference = types[0].get().visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> AST::UnnamedType* {
187             if (is<AST::ArrayReferenceType>(static_cast<AST::UnnamedType&>(unnamedType)) || is<AST::ArrayType>(static_cast<AST::UnnamedType&>(unnamedType)))
188                 return &unnamedType;
189             return nullptr;
190         }, [](RefPtr<ResolvableTypeReference>&) -> AST::UnnamedType* {
191             return nullptr;
192         }));
193         if (firstArgumentReference)
194             return resolveWithOperatorLength(origin, *firstArgumentReference, intrinsics);
195     } else if (name == "operator==" && types.size() == 2) {
196         auto acceptability = [](ResolvingType& resolvingType) -> Acceptability {
197             return resolvingType.visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> Acceptability {
198                 auto& unifyNode = unnamedType->unifyNode();
199                 return is<AST::UnnamedType>(unifyNode) && is<AST::ReferenceType>(downcast<AST::UnnamedType>(unifyNode)) ? Acceptability::Yes : Acceptability::No;
200             }, [](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> Acceptability {
201                 return is<AST::NullLiteralType>(resolvableTypeReference->resolvableType()) ? Acceptability::Maybe : Acceptability::No;
202             }));
203         };
204         auto leftAcceptability = acceptability(types[0].get());
205         auto rightAcceptability = acceptability(types[1].get());
206         bool success = false;
207         if (leftAcceptability == Acceptability::Yes && rightAcceptability == Acceptability::Yes) {
208             auto& unnamedType1 = *types[0].get().getUnnamedType();
209             auto& unnamedType2 = *types[1].get().getUnnamedType();
210             success = matches(unnamedType1, unnamedType2);
211         } else if ((leftAcceptability == Acceptability::Maybe && rightAcceptability == Acceptability::Yes)
212             || (leftAcceptability == Acceptability::Yes && rightAcceptability == Acceptability::Maybe))
213             success = true;
214         if (success)
215             return resolveWithReferenceComparator(origin, types[0].get(), types[1].get(), intrinsics);
216     }
217     return WTF::nullopt;
218 }
219
220 static AST::FunctionDeclaration* resolveFunction(Program& program, Vector<std::reference_wrapper<AST::FunctionDeclaration>, 1>* possibleOverloads, Vector<std::reference_wrapper<ResolvingType>>& types, const String& name, Lexer::Token origin, const Intrinsics& intrinsics, AST::NamedType* castReturnType = nullptr)
221 {
222     if (possibleOverloads) {
223         if (AST::FunctionDeclaration* function = resolveFunctionOverload(*possibleOverloads, types, castReturnType))
224             return function;
225     }
226
227     if (auto newFunction = resolveByInstantiation(name, origin, types, intrinsics)) {
228         program.append(WTFMove(*newFunction));
229         return &program.nativeFunctionDeclarations().last();
230     }
231
232     return nullptr;
233 }
234
235 static bool checkSemantics(Vector<EntryPointItem>& inputItems, Vector<EntryPointItem>& outputItems, const Optional<AST::EntryPointType>& entryPointType, const Intrinsics& intrinsics)
236 {
237     {
238         auto checkDuplicateSemantics = [&](const Vector<EntryPointItem>& items) -> bool {
239             for (size_t i = 0; i < items.size(); ++i) {
240                 for (size_t j = i + 1; j < items.size(); ++j) {
241                     if (items[i].semantic == items[j].semantic)
242                         return false;
243                 }
244             }
245             return true;
246         };
247         if (!checkDuplicateSemantics(inputItems))
248             return false;
249         if (!checkDuplicateSemantics(outputItems))
250             return false;
251     }
252
253     {
254         auto checkSemanticTypes = [&](const Vector<EntryPointItem>& items) -> bool {
255             for (auto& item : items) {
256                 auto acceptable = WTF::visit(WTF::makeVisitor([&](const AST::BaseSemantic& semantic) -> bool {
257                     return semantic.isAcceptableType(*item.unnamedType, intrinsics);
258                 }), *item.semantic);
259                 if (!acceptable)
260                     return false;
261             }
262             return true;
263         };
264         if (!checkSemanticTypes(inputItems))
265             return false;
266         if (!checkSemanticTypes(outputItems))
267             return false;
268     }
269
270     {
271         auto checkSemanticForShaderType = [&](const Vector<EntryPointItem>& items, AST::BaseSemantic::ShaderItemDirection direction) -> bool {
272             for (auto& item : items) {
273                 auto acceptable = WTF::visit(WTF::makeVisitor([&](const AST::BaseSemantic& semantic) -> bool {
274                     return semantic.isAcceptableForShaderItemDirection(direction, entryPointType);
275                 }), *item.semantic);
276                 if (!acceptable)
277                     return false;
278             }
279             return true;
280         };
281         if (!checkSemanticForShaderType(inputItems, AST::BaseSemantic::ShaderItemDirection::Input))
282             return false;
283         if (!checkSemanticForShaderType(outputItems, AST::BaseSemantic::ShaderItemDirection::Output))
284             return false;
285     }
286
287     {
288         auto checkPODData = [&](const Vector<EntryPointItem>& items) -> bool {
289             for (auto& item : items) {
290                 PODChecker podChecker;
291                 if (is<AST::PointerType>(item.unnamedType))
292                     podChecker.checkErrorAndVisit(downcast<AST::PointerType>(*item.unnamedType).elementType());
293                 else if (is<AST::ArrayReferenceType>(item.unnamedType))
294                     podChecker.checkErrorAndVisit(downcast<AST::ArrayReferenceType>(*item.unnamedType).elementType());
295                 else if (is<AST::ArrayType>(item.unnamedType))
296                     podChecker.checkErrorAndVisit(downcast<AST::ArrayType>(*item.unnamedType).type());
297                 else
298                     continue;
299                 if (podChecker.error())
300                     return false;
301             }
302             return true;
303         };
304         if (!checkPODData(inputItems))
305             return false;
306         if (!checkPODData(outputItems))
307             return false;
308     }
309
310     return true;
311 }
312
313 static bool checkOperatorOverload(const AST::FunctionDefinition& functionDefinition, const Intrinsics& intrinsics, NameContext& nameContext)
314 {
315     enum class CheckKind {
316         Index,
317         Dot
318     };
319
320     auto checkGetter = [&](CheckKind kind) -> bool {
321         size_t numExpectedParameters = kind == CheckKind::Index ? 2 : 1;
322         if (functionDefinition.parameters().size() != numExpectedParameters)
323             return false;
324         auto& firstParameterUnifyNode = (*functionDefinition.parameters()[0]->type())->unifyNode();
325         if (is<AST::UnnamedType>(firstParameterUnifyNode)) {
326             auto& unnamedType = downcast<AST::UnnamedType>(firstParameterUnifyNode);
327             if (is<AST::PointerType>(unnamedType) || is<AST::ArrayReferenceType>(unnamedType) || is<AST::ArrayType>(unnamedType))
328                 return false;
329         }
330         if (kind == CheckKind::Index) {
331             auto& secondParameterUnifyNode = (*functionDefinition.parameters()[1]->type())->unifyNode();
332             if (!is<AST::NamedType>(secondParameterUnifyNode))
333                 return false;
334             auto& namedType = downcast<AST::NamedType>(secondParameterUnifyNode);
335             if (!is<AST::NativeTypeDeclaration>(namedType))
336                 return false;
337             auto& nativeTypeDeclaration = downcast<AST::NativeTypeDeclaration>(namedType);
338             if (!nativeTypeDeclaration.isInt())
339                 return false;
340         }
341         return true;
342     };
343
344     auto checkSetter = [&](CheckKind kind) -> bool {
345         size_t numExpectedParameters = kind == CheckKind::Index ? 3 : 2;
346         if (functionDefinition.parameters().size() != numExpectedParameters)
347             return false;
348         auto& firstArgumentUnifyNode = (*functionDefinition.parameters()[0]->type())->unifyNode();
349         if (is<AST::UnnamedType>(firstArgumentUnifyNode)) {
350             auto& unnamedType = downcast<AST::UnnamedType>(firstArgumentUnifyNode);
351             if (is<AST::PointerType>(unnamedType) || is<AST::ArrayReferenceType>(unnamedType) || is<AST::ArrayType>(unnamedType))
352                 return false;
353         }
354         if (kind == CheckKind::Index) {
355             auto& secondParameterUnifyNode = (*functionDefinition.parameters()[1]->type())->unifyNode();
356             if (!is<AST::NamedType>(secondParameterUnifyNode))
357                 return false;
358             auto& namedType = downcast<AST::NamedType>(secondParameterUnifyNode);
359             if (!is<AST::NativeTypeDeclaration>(namedType))
360                 return false;
361             auto& nativeTypeDeclaration = downcast<AST::NativeTypeDeclaration>(namedType);
362             if (!nativeTypeDeclaration.isInt())
363                 return false;
364         }
365         if (!matches(functionDefinition.type(), *functionDefinition.parameters()[0]->type()))
366             return false;
367         auto& valueType = *functionDefinition.parameters()[numExpectedParameters - 1]->type();
368         auto getterName = functionDefinition.name().substring(0, functionDefinition.name().length() - 1);
369         auto* getterFuncs = nameContext.getFunctions(getterName);
370         if (!getterFuncs)
371             return false;
372         Vector<ResolvingType> argumentTypes;
373         Vector<std::reference_wrapper<ResolvingType>> argumentTypeReferences;
374         for (size_t i = 0; i < numExpectedParameters - 1; ++i)
375             argumentTypes.append((*functionDefinition.parameters()[i]->type())->clone());
376         for (auto& argumentType : argumentTypes)
377             argumentTypeReferences.append(argumentType);
378         auto* overload = resolveFunctionOverload(*getterFuncs, argumentTypeReferences);
379         if (!overload)
380             return false;
381         auto& resultType = overload->type();
382         return matches(resultType, valueType);
383     };
384
385     auto checkAnder = [&](CheckKind kind) -> bool {
386         size_t numExpectedParameters = kind == CheckKind::Index ? 2 : 1;
387         if (functionDefinition.parameters().size() != numExpectedParameters)
388             return false;
389         {
390             auto& unifyNode = functionDefinition.type().unifyNode();
391             if (!is<AST::UnnamedType>(unifyNode))
392                 return false;
393             auto& unnamedType = downcast<AST::UnnamedType>(unifyNode);
394             if (!is<AST::PointerType>(unnamedType))
395                 return false;
396         }
397         {
398             auto& unifyNode = (*functionDefinition.parameters()[0]->type())->unifyNode();
399             if (!is<AST::UnnamedType>(unifyNode))
400                 return false;
401             auto& unnamedType = downcast<AST::UnnamedType>(unifyNode);
402             return is<AST::PointerType>(unnamedType) || is<AST::ArrayReferenceType>(unnamedType);
403         }
404     };
405
406     if (!functionDefinition.isOperator())
407         return true;
408     if (functionDefinition.isCast())
409         return true;
410     if (functionDefinition.name() == "operator++" || functionDefinition.name() == "operator--") {
411         return functionDefinition.parameters().size() == 1
412             && matches(*functionDefinition.parameters()[0]->type(), functionDefinition.type());
413     }
414     if (functionDefinition.name() == "operator+" || functionDefinition.name() == "operator-")
415         return functionDefinition.parameters().size() == 1 || functionDefinition.parameters().size() == 2;
416     if (functionDefinition.name() == "operator*"
417         || functionDefinition.name() == "operator/"
418         || functionDefinition.name() == "operator%"
419         || functionDefinition.name() == "operator&"
420         || functionDefinition.name() == "operator|"
421         || functionDefinition.name() == "operator^"
422         || functionDefinition.name() == "operator<<"
423         || functionDefinition.name() == "operator>>")
424         return functionDefinition.parameters().size() == 2;
425     if (functionDefinition.name() == "operator~")
426         return functionDefinition.parameters().size() == 1;
427     if (functionDefinition.name() == "operator=="
428         || functionDefinition.name() == "operator<"
429         || functionDefinition.name() == "operator<="
430         || functionDefinition.name() == "operator>"
431         || functionDefinition.name() == "operator>=") {
432         return functionDefinition.parameters().size() == 2
433             && matches(functionDefinition.type(), intrinsics.boolType());
434     }
435     if (functionDefinition.name() == "operator[]")
436         return checkGetter(CheckKind::Index);
437     if (functionDefinition.name() == "operator[]=")
438         return checkSetter(CheckKind::Index);
439     if (functionDefinition.name() == "operator&[]")
440         return checkAnder(CheckKind::Index);
441     if (functionDefinition.name().startsWith("operator.")) {
442         if (functionDefinition.name().endsWith("="))
443             return checkSetter(CheckKind::Dot);
444         return checkGetter(CheckKind::Dot);
445     }
446     if (functionDefinition.name().startsWith("operator&."))
447         return checkAnder(CheckKind::Dot);
448     return false;
449 }
450
451 class Checker : public Visitor {
452 public:
453     Checker(const Intrinsics& intrinsics, Program& program)
454         : m_intrinsics(intrinsics)
455         , m_program(program)
456     {
457     }
458
459     virtual ~Checker() = default;
460
461     void visit(Program&) override;
462
463     bool assignTypes();
464
465 private:
466     bool checkShaderType(const AST::FunctionDefinition&);
467     bool isBoolType(ResolvingType&);
468     struct RecurseInfo {
469         ResolvingType& resolvingType;
470         const AST::TypeAnnotation typeAnnotation;
471     };
472     Optional<RecurseInfo> recurseAndGetInfo(AST::Expression&, bool requiresLeftValue = false);
473     Optional<RecurseInfo> getInfo(AST::Expression&, bool requiresLeftValue = false);
474     Optional<UniqueRef<AST::UnnamedType>> recurseAndWrapBaseType(AST::PropertyAccessExpression&);
475     bool recurseAndRequireBoolType(AST::Expression&);
476     void assignType(AST::Expression&, UniqueRef<AST::UnnamedType>&&, AST::TypeAnnotation);
477     void assignType(AST::Expression&, RefPtr<ResolvableTypeReference>&&, AST::TypeAnnotation);
478     void forwardType(AST::Expression&, ResolvingType&, AST::TypeAnnotation);
479
480     void visit(AST::FunctionDefinition&) override;
481     void visit(AST::EnumerationDefinition&) override;
482     void visit(AST::TypeReference&) override;
483     void visit(AST::VariableDeclaration&) override;
484     void visit(AST::AssignmentExpression&) override;
485     void visit(AST::ReadModifyWriteExpression&) override;
486     void visit(AST::DereferenceExpression&) override;
487     void visit(AST::MakePointerExpression&) override;
488     void visit(AST::MakeArrayReferenceExpression&) override;
489     void visit(AST::DotExpression&) override;
490     void visit(AST::IndexExpression&) override;
491     void visit(AST::VariableReference&) override;
492     void visit(AST::Return&) override;
493     void visit(AST::PointerType&) override;
494     void visit(AST::ArrayReferenceType&) override;
495     void visit(AST::IntegerLiteral&) override;
496     void visit(AST::UnsignedIntegerLiteral&) override;
497     void visit(AST::FloatLiteral&) override;
498     void visit(AST::NullLiteral&) override;
499     void visit(AST::BooleanLiteral&) override;
500     void visit(AST::EnumerationMemberLiteral&) override;
501     void visit(AST::LogicalNotExpression&) override;
502     void visit(AST::LogicalExpression&) override;
503     void visit(AST::IfStatement&) override;
504     void visit(AST::WhileLoop&) override;
505     void visit(AST::DoWhileLoop&) override;
506     void visit(AST::ForLoop&) override;
507     void visit(AST::SwitchStatement&) override;
508     void visit(AST::CommaExpression&) override;
509     void visit(AST::TernaryExpression&) override;
510     void visit(AST::CallExpression&) override;
511
512     void finishVisiting(AST::PropertyAccessExpression&, ResolvingType* additionalArgumentType = nullptr);
513
514     HashMap<AST::Expression*, std::unique_ptr<ResolvingType>> m_typeMap;
515     HashMap<AST::Expression*, AST::TypeAnnotation> m_typeAnnotations;
516     HashSet<String> m_vertexEntryPoints;
517     HashSet<String> m_fragmentEntryPoints;
518     HashSet<String> m_computeEntryPoints;
519     const Intrinsics& m_intrinsics;
520     Program& m_program;
521 };
522
523 void Checker::visit(Program& program)
524 {
525     // These visiting functions might add new global statements, so don't use foreach syntax.
526     for (size_t i = 0; i < program.typeDefinitions().size(); ++i)
527         checkErrorAndVisit(program.typeDefinitions()[i]);
528     for (size_t i = 0; i < program.structureDefinitions().size(); ++i)
529         checkErrorAndVisit(program.structureDefinitions()[i]);
530     for (size_t i = 0; i < program.enumerationDefinitions().size(); ++i)
531         checkErrorAndVisit(program.enumerationDefinitions()[i]);
532     for (size_t i = 0; i < program.nativeTypeDeclarations().size(); ++i)
533         checkErrorAndVisit(program.nativeTypeDeclarations()[i]);
534
535     for (size_t i = 0; i < program.functionDefinitions().size(); ++i)
536         checkErrorAndVisit(program.functionDefinitions()[i]);
537     for (size_t i = 0; i < program.nativeFunctionDeclarations().size(); ++i)
538         checkErrorAndVisit(program.nativeFunctionDeclarations()[i]);
539 }
540
541 bool Checker::assignTypes()
542 {
543     for (auto& keyValuePair : m_typeMap) {
544         auto success = keyValuePair.value->visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& unnamedType) -> bool {
545             keyValuePair.key->setType(unnamedType->clone());
546             return true;
547         }, [&](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> bool {
548             if (!resolvableTypeReference->resolvableType().maybeResolvedType()) {
549                 if (!static_cast<bool>(commit(resolvableTypeReference->resolvableType())))
550                     return false;
551             }
552             keyValuePair.key->setType(resolvableTypeReference->resolvableType().resolvedType().clone());
553             return true;
554         }));
555         if (!success)
556             return false;
557     }
558
559     for (auto& keyValuePair : m_typeAnnotations)
560         keyValuePair.key->setTypeAnnotation(WTFMove(keyValuePair.value));
561     return true;
562 }
563
564 bool Checker::checkShaderType(const AST::FunctionDefinition& functionDefinition)
565 {
566     switch (*functionDefinition.entryPointType()) {
567     case AST::EntryPointType::Vertex:
568         return static_cast<bool>(m_vertexEntryPoints.add(functionDefinition.name()));
569     case AST::EntryPointType::Fragment:
570         return static_cast<bool>(m_fragmentEntryPoints.add(functionDefinition.name()));
571     case AST::EntryPointType::Compute:
572         return static_cast<bool>(m_computeEntryPoints.add(functionDefinition.name()));
573     }
574 }
575
576 void Checker::visit(AST::FunctionDefinition& functionDefinition)
577 {
578     if (functionDefinition.entryPointType()) {
579         if (!checkShaderType(functionDefinition)) {
580             setError();
581             return;
582         }
583         auto entryPointItems = gatherEntryPointItems(m_intrinsics, functionDefinition);
584         if (!entryPointItems) {
585             setError();
586             return;
587         }
588         if (!checkSemantics(entryPointItems->inputs, entryPointItems->outputs, functionDefinition.entryPointType(), m_intrinsics)) {
589             setError();
590             return;
591         }
592     }
593     if (!checkOperatorOverload(functionDefinition, m_intrinsics, m_program.nameContext())) {
594         setError();
595         return;
596     }
597
598     Visitor::visit(functionDefinition);
599 }
600
601 static Optional<UniqueRef<AST::UnnamedType>> matchAndCommit(ResolvingType& left, ResolvingType& right)
602 {
603     return left.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& left) -> Optional<UniqueRef<AST::UnnamedType>> {
604         return right.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& right) -> Optional<UniqueRef<AST::UnnamedType>> {
605             if (matches(left, right))
606                 return left->clone();
607             return WTF::nullopt;
608         }, [&](RefPtr<ResolvableTypeReference>& right) -> Optional<UniqueRef<AST::UnnamedType>> {
609             return matchAndCommit(left, right->resolvableType());
610         }));
611     }, [&](RefPtr<ResolvableTypeReference>& left) -> Optional<UniqueRef<AST::UnnamedType>> {
612         return right.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& right) -> Optional<UniqueRef<AST::UnnamedType>> {
613             return matchAndCommit(right, left->resolvableType());
614         }, [&](RefPtr<ResolvableTypeReference>& right) -> Optional<UniqueRef<AST::UnnamedType>> {
615             return matchAndCommit(left->resolvableType(), right->resolvableType());
616         }));
617     }));
618 }
619
620 static Optional<UniqueRef<AST::UnnamedType>> matchAndCommit(ResolvingType& resolvingType, AST::UnnamedType& unnamedType)
621 {
622     return resolvingType.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& resolvingType) -> Optional<UniqueRef<AST::UnnamedType>> {
623         if (matches(unnamedType, resolvingType))
624             return unnamedType.clone();
625         return WTF::nullopt;
626     }, [&](RefPtr<ResolvableTypeReference>& resolvingType) -> Optional<UniqueRef<AST::UnnamedType>> {
627         return matchAndCommit(unnamedType, resolvingType->resolvableType());
628     }));
629 }
630
631 static Optional<UniqueRef<AST::UnnamedType>> matchAndCommit(ResolvingType& resolvingType, AST::NamedType& namedType)
632 {
633     return resolvingType.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& resolvingType) -> Optional<UniqueRef<AST::UnnamedType>> {
634         if (matches(resolvingType, namedType))
635             return resolvingType->clone();
636         return WTF::nullopt;
637     }, [&](RefPtr<ResolvableTypeReference>& resolvingType) -> Optional<UniqueRef<AST::UnnamedType>> {
638         return matchAndCommit(namedType, resolvingType->resolvableType());
639     }));
640 }
641
642 static Optional<UniqueRef<AST::UnnamedType>> commit(ResolvingType& resolvingType)
643 {
644     return resolvingType.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& unnamedType) -> Optional<UniqueRef<AST::UnnamedType>> {
645         return unnamedType->clone();
646     }, [&](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> Optional<UniqueRef<AST::UnnamedType>> {
647         if (!resolvableTypeReference->resolvableType().maybeResolvedType())
648             return commit(resolvableTypeReference->resolvableType());
649         return resolvableTypeReference->resolvableType().resolvedType().clone();
650     }));
651 }
652
653 void Checker::visit(AST::EnumerationDefinition& enumerationDefinition)
654 {
655     auto* baseType = ([&]() -> AST::NativeTypeDeclaration* {
656         checkErrorAndVisit(enumerationDefinition.type());
657         auto& baseType = enumerationDefinition.type().unifyNode();
658         if (!is<AST::NamedType>(baseType))
659             return nullptr;
660         auto& namedType = downcast<AST::NamedType>(baseType);
661         if (!is<AST::NativeTypeDeclaration>(namedType))
662             return nullptr;
663         auto& nativeTypeDeclaration = downcast<AST::NativeTypeDeclaration>(namedType);
664         if (!nativeTypeDeclaration.isInt())
665             return nullptr;
666         return &nativeTypeDeclaration;
667     })();
668     if (!baseType) {
669         setError();
670         return;
671     }
672
673     auto enumerationMembers = enumerationDefinition.enumerationMembers();
674
675     auto matchAndCommitMember = [&](AST::EnumerationMember& member) -> bool {
676         return member.value()->visit(WTF::makeVisitor([&](AST::Expression& value) -> bool {
677             auto valueInfo = recurseAndGetInfo(value);
678             if (!valueInfo)
679                 return false;
680             return static_cast<bool>(matchAndCommit(valueInfo->resolvingType, *baseType));
681         }));
682     };
683
684     for (auto& member : enumerationMembers) {
685         if (!member.get().value())
686             continue;
687
688         if (!matchAndCommitMember(member)) {
689             setError();
690             return;
691         }
692     }
693
694     int64_t nextValue = 0;
695     for (auto& member : enumerationMembers) {
696         if (member.get().value()) {
697             auto value = member.get().value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) -> int64_t {
698                 return integerLiteral.valueForSelectedType();
699             }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) -> int64_t {
700                 return unsignedIntegerLiteral.valueForSelectedType();
701             }, [&](auto&) -> int64_t {
702                 ASSERT_NOT_REACHED();
703                 return 0;
704             }));
705             nextValue = baseType->successor()(value);
706         } else {
707             if (nextValue > std::numeric_limits<int>::max()) {
708                 ASSERT(nextValue <= std::numeric_limits<unsigned>::max());
709                 member.get().setValue(AST::ConstantExpression(AST::UnsignedIntegerLiteral(Lexer::Token(member.get().origin()), static_cast<unsigned>(nextValue))));
710             }
711             ASSERT(nextValue >= std::numeric_limits<int>::min());
712             member.get().setValue(AST::ConstantExpression(AST::IntegerLiteral(Lexer::Token(member.get().origin()), static_cast<int>(nextValue))));
713
714             if (!matchAndCommitMember(member)) {
715                 setError();
716                 return;
717             }
718
719             nextValue = baseType->successor()(nextValue);
720         }
721     }
722
723     auto getValue = [&](AST::EnumerationMember& member) -> int64_t {
724         ASSERT(member.value());
725         auto value = member.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) -> int64_t {
726             return integerLiteral.value();
727         }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) -> int64_t {
728             return unsignedIntegerLiteral.value();
729         }, [&](auto&) -> int64_t {
730             ASSERT_NOT_REACHED();
731             return 0;
732         }));
733         return value;
734     };
735
736     for (size_t i = 0; i < enumerationMembers.size(); ++i) {
737         auto value = getValue(enumerationMembers[i].get());
738         for (size_t j = i + 1; j < enumerationMembers.size(); ++j) {
739             auto otherValue = getValue(enumerationMembers[j].get());
740             if (value == otherValue) {
741                 setError();
742                 return;
743             }
744         }
745     }
746
747     bool foundZero = false;
748     for (auto& member : enumerationMembers) {
749         if (!getValue(member.get())) {
750             foundZero = true;
751             break;
752         }
753     }
754     if (!foundZero) {
755         setError();
756         return;
757     }
758 }
759
760 void Checker::visit(AST::TypeReference& typeReference)
761 {
762     ASSERT(typeReference.maybeResolvedType());
763
764     for (auto& typeArgument : typeReference.typeArguments())
765         checkErrorAndVisit(typeArgument);
766 }
767
768 auto Checker::recurseAndGetInfo(AST::Expression& expression, bool requiresLeftValue) -> Optional<RecurseInfo>
769 {
770     Visitor::visit(expression);
771     if (error())
772         return WTF::nullopt;
773     return getInfo(expression, requiresLeftValue);
774 }
775
776 auto Checker::getInfo(AST::Expression& expression, bool requiresLeftValue) -> Optional<RecurseInfo>
777 {
778     auto typeIterator = m_typeMap.find(&expression);
779     ASSERT(typeIterator != m_typeMap.end());
780
781     auto typeAnnotationIterator = m_typeAnnotations.find(&expression);
782     ASSERT(typeAnnotationIterator != m_typeAnnotations.end());
783     if (requiresLeftValue && typeAnnotationIterator->value.isRightValue()) {
784         setError();
785         return WTF::nullopt;
786     }
787     return {{ *typeIterator->value, typeAnnotationIterator->value }};
788 }
789
790 void Checker::visit(AST::VariableDeclaration& variableDeclaration)
791 {
792     // ReadModifyWriteExpressions are the only place where anonymous variables exist,
793     // and that doesn't recurse on the anonymous variables, so we can assume the variable has a type.
794     checkErrorAndVisit(*variableDeclaration.type());
795     if (variableDeclaration.initializer()) {
796         auto& lhsType = *variableDeclaration.type();
797         auto initializerInfo = recurseAndGetInfo(*variableDeclaration.initializer());
798         if (!initializerInfo)
799             return;
800         if (!matchAndCommit(initializerInfo->resolvingType, lhsType)) {
801             setError();
802             return;
803         }
804     }
805 }
806
807 void Checker::assignType(AST::Expression& expression, UniqueRef<AST::UnnamedType>&& unnamedType, AST::TypeAnnotation typeAnnotation = AST::RightValue())
808 {
809     auto addResult = m_typeMap.add(&expression, std::make_unique<ResolvingType>(WTFMove(unnamedType)));
810     ASSERT_UNUSED(addResult, addResult.isNewEntry);
811     auto typeAnnotationAddResult = m_typeAnnotations.add(&expression, WTFMove(typeAnnotation));
812     ASSERT_UNUSED(typeAnnotationAddResult, typeAnnotationAddResult.isNewEntry);
813 }
814
815 void Checker::assignType(AST::Expression& expression, RefPtr<ResolvableTypeReference>&& resolvableTypeReference, AST::TypeAnnotation typeAnnotation = AST::RightValue())
816 {
817     auto addResult = m_typeMap.add(&expression, std::make_unique<ResolvingType>(WTFMove(resolvableTypeReference)));
818     ASSERT_UNUSED(addResult, addResult.isNewEntry);
819     auto typeAnnotationAddResult = m_typeAnnotations.add(&expression, WTFMove(typeAnnotation));
820     ASSERT_UNUSED(typeAnnotationAddResult, typeAnnotationAddResult.isNewEntry);
821 }
822
823 void Checker::forwardType(AST::Expression& expression, ResolvingType& resolvingType, AST::TypeAnnotation typeAnnotation = AST::RightValue())
824 {
825     resolvingType.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& result) {
826         auto addResult = m_typeMap.add(&expression, std::make_unique<ResolvingType>(result->clone()));
827         ASSERT_UNUSED(addResult, addResult.isNewEntry);
828     }, [&](RefPtr<ResolvableTypeReference>& result) {
829         auto addResult = m_typeMap.add(&expression, std::make_unique<ResolvingType>(result.copyRef()));
830         ASSERT_UNUSED(addResult, addResult.isNewEntry);
831     }));
832     auto typeAnnotationAddResult = m_typeAnnotations.add(&expression, WTFMove(typeAnnotation));
833     ASSERT_UNUSED(typeAnnotationAddResult, typeAnnotationAddResult.isNewEntry);
834 }
835
836 void Checker::visit(AST::AssignmentExpression& assignmentExpression)
837 {
838     auto leftInfo = recurseAndGetInfo(assignmentExpression.left(), true);
839     if (!leftInfo)
840         return;
841
842     if (leftInfo->typeAnnotation.isRightValue()) {
843         setError();
844         return;
845     }
846
847     auto rightInfo = recurseAndGetInfo(assignmentExpression.right());
848     if (!rightInfo)
849         return;
850
851     auto resultType = matchAndCommit(leftInfo->resolvingType, rightInfo->resolvingType);
852     if (!resultType) {
853         setError();
854         return;
855     }
856
857     assignType(assignmentExpression, WTFMove(*resultType));
858 }
859
860 void Checker::visit(AST::ReadModifyWriteExpression& readModifyWriteExpression)
861 {
862     auto leftValueInfo = recurseAndGetInfo(readModifyWriteExpression.leftValue(), true);
863     if (!leftValueInfo)
864         return;
865
866     readModifyWriteExpression.oldValue().setType(leftValueInfo->resolvingType.getUnnamedType()->clone());
867
868     auto newValueInfo = recurseAndGetInfo(readModifyWriteExpression.newValueExpression());
869     if (!newValueInfo)
870         return;
871
872     if (Optional<UniqueRef<AST::UnnamedType>> matchedType = matchAndCommit(leftValueInfo->resolvingType, newValueInfo->resolvingType))
873         readModifyWriteExpression.newValue().setType(WTFMove(matchedType.value()));
874     else {
875         setError();
876         return;
877     }
878
879     auto resultInfo = recurseAndGetInfo(readModifyWriteExpression.resultExpression());
880     if (!resultInfo)
881         return;
882
883     forwardType(readModifyWriteExpression, resultInfo->resolvingType);
884 }
885
886 static AST::UnnamedType* getUnnamedType(ResolvingType& resolvingType)
887 {
888     return resolvingType.visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& type) -> AST::UnnamedType* {
889         return &type;
890     }, [](RefPtr<ResolvableTypeReference>& type) -> AST::UnnamedType* {
891         // FIXME: If the type isn't committed, should we just commit() it now?
892         return type->resolvableType().maybeResolvedType();
893     }));
894 }
895
896 void Checker::visit(AST::DereferenceExpression& dereferenceExpression)
897 {
898     auto pointerInfo = recurseAndGetInfo(dereferenceExpression.pointer());
899     if (!pointerInfo)
900         return;
901
902     auto* unnamedType = getUnnamedType(pointerInfo->resolvingType);
903
904     auto* pointerType = ([&](AST::UnnamedType* unnamedType) -> AST::PointerType* {
905         if (!unnamedType)
906             return nullptr;
907         auto& unifyNode = unnamedType->unifyNode();
908         if (!is<AST::UnnamedType>(unifyNode))
909             return nullptr;
910         auto& unnamedUnifyType = downcast<AST::UnnamedType>(unifyNode);
911         if (!is<AST::PointerType>(unnamedUnifyType))
912             return nullptr;
913         return &downcast<AST::PointerType>(unnamedUnifyType);
914     })(unnamedType);
915     if (!pointerType) {
916         setError();
917         return;
918     }
919
920     assignType(dereferenceExpression, pointerType->elementType().clone(), AST::LeftValue { pointerType->addressSpace() });
921 }
922
923 void Checker::visit(AST::MakePointerExpression& makePointerExpression)
924 {
925     auto leftValueInfo = recurseAndGetInfo(makePointerExpression.leftValue(), true);
926     if (!leftValueInfo)
927         return;
928
929     auto leftAddressSpace = leftValueInfo->typeAnnotation.leftAddressSpace();
930     if (!leftAddressSpace) {
931         setError();
932         return;
933     }
934
935     auto* leftValueType = getUnnamedType(leftValueInfo->resolvingType);
936     if (!leftValueType) {
937         setError();
938         return;
939     }
940
941     assignType(makePointerExpression, makeUniqueRef<AST::PointerType>(Lexer::Token(makePointerExpression.origin()), *leftAddressSpace, leftValueType->clone()));
942 }
943
944 void Checker::visit(AST::MakeArrayReferenceExpression& makeArrayReferenceExpression)
945 {
946     auto leftValueInfo = recurseAndGetInfo(makeArrayReferenceExpression.leftValue());
947     if (!leftValueInfo)
948         return;
949
950     auto* leftValueType = getUnnamedType(leftValueInfo->resolvingType);
951     if (!leftValueType) {
952         setError();
953         return;
954     }
955
956     auto& unifyNode = leftValueType->unifyNode();
957     if (is<AST::UnnamedType>(unifyNode)) {
958         auto& unnamedType = downcast<AST::UnnamedType>(unifyNode);
959         if (is<AST::PointerType>(unnamedType)) {
960             auto& pointerType = downcast<AST::PointerType>(unnamedType);
961             // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198163 Save the fact that we're not targetting the item; we're targetting the item's inner element.
962             assignType(makeArrayReferenceExpression, makeUniqueRef<AST::ArrayReferenceType>(Lexer::Token(makeArrayReferenceExpression.origin()), pointerType.addressSpace(), pointerType.elementType().clone()));
963             return;
964         }
965
966         auto leftAddressSpace = leftValueInfo->typeAnnotation.leftAddressSpace();
967         if (!leftAddressSpace) {
968             setError();
969             return;
970         }
971
972         if (is<AST::ArrayType>(unnamedType)) {
973             auto& arrayType = downcast<AST::ArrayType>(unnamedType);
974             // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198163 Save the number of elements.
975             assignType(makeArrayReferenceExpression, makeUniqueRef<AST::ArrayReferenceType>(Lexer::Token(makeArrayReferenceExpression.origin()), *leftAddressSpace, arrayType.type().clone()));
976             return;
977         }
978     }
979
980     auto leftAddressSpace = leftValueInfo->typeAnnotation.leftAddressSpace();
981     if (!leftAddressSpace) {
982         setError();
983         return;
984     }
985
986     assignType(makeArrayReferenceExpression, makeUniqueRef<AST::ArrayReferenceType>(Lexer::Token(makeArrayReferenceExpression.origin()), *leftAddressSpace, leftValueType->clone()));
987 }
988
989 static Optional<UniqueRef<AST::UnnamedType>> argumentTypeForAndOverload(AST::UnnamedType& baseType, AST::AddressSpace addressSpace)
990 {
991     auto& unifyNode = baseType.unifyNode();
992     if (is<AST::NamedType>(unifyNode)) {
993         auto& namedType = downcast<AST::NamedType>(unifyNode);
994         return { makeUniqueRef<AST::PointerType>(Lexer::Token(namedType.origin()), addressSpace, AST::TypeReference::wrap(Lexer::Token(namedType.origin()), namedType)) };
995     }
996
997     auto& unnamedType = downcast<AST::UnnamedType>(unifyNode);
998
999     if (is<AST::ArrayReferenceType>(unnamedType))
1000         return unnamedType.clone();
1001
1002     if (is<AST::ArrayType>(unnamedType))
1003         return { makeUniqueRef<AST::ArrayReferenceType>(Lexer::Token(unnamedType.origin()), addressSpace, downcast<AST::ArrayType>(unnamedType).type().clone()) };
1004
1005     if (is<AST::PointerType>(unnamedType))
1006         return WTF::nullopt;
1007
1008     return { makeUniqueRef<AST::PointerType>(Lexer::Token(unnamedType.origin()), addressSpace, unnamedType.clone()) };
1009 }
1010
1011 void Checker::finishVisiting(AST::PropertyAccessExpression& propertyAccessExpression, ResolvingType* additionalArgumentType)
1012 {
1013     auto baseInfo = recurseAndGetInfo(propertyAccessExpression.base());
1014     if (!baseInfo)
1015         return;
1016     auto baseUnnamedType = commit(baseInfo->resolvingType);
1017     if (!baseUnnamedType)
1018         return;
1019
1020     AST::FunctionDeclaration* getterFunction = nullptr;
1021     AST::UnnamedType* getterReturnType = nullptr;
1022     {
1023         Vector<std::reference_wrapper<ResolvingType>> getterArgumentTypes { baseInfo->resolvingType };
1024         if (additionalArgumentType)
1025             getterArgumentTypes.append(*additionalArgumentType);
1026         auto getterName = propertyAccessExpression.getterFunctionName();
1027         auto* getterFunctions = m_program.nameContext().getFunctions(getterName);
1028         getterFunction = resolveFunction(m_program, getterFunctions, getterArgumentTypes, getterName, propertyAccessExpression.origin(), m_intrinsics);
1029         if (getterFunction)
1030             getterReturnType = &getterFunction->type();
1031     }
1032
1033     AST::FunctionDeclaration* anderFunction = nullptr;
1034     AST::UnnamedType* anderReturnType = nullptr;
1035     auto leftAddressSpace = baseInfo->typeAnnotation.leftAddressSpace();
1036     if (leftAddressSpace) {
1037         if (auto argumentTypeForAndOverload = WHLSL::argumentTypeForAndOverload(*baseUnnamedType, *leftAddressSpace)) {
1038             ResolvingType argumentType = { WTFMove(*argumentTypeForAndOverload) };
1039             Vector<std::reference_wrapper<ResolvingType>> anderArgumentTypes { argumentType };
1040             if (additionalArgumentType)
1041                 anderArgumentTypes.append(*additionalArgumentType);
1042             auto anderName = propertyAccessExpression.anderFunctionName();
1043             auto* anderFunctions = m_program.nameContext().getFunctions(anderName);
1044             anderFunction = resolveFunction(m_program, anderFunctions, anderArgumentTypes, anderName, propertyAccessExpression.origin(), m_intrinsics);
1045             if (anderFunction)
1046                 anderReturnType = &downcast<AST::PointerType>(anderFunction->type()).elementType(); // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198164 Enforce the return of anders will always be a pointer
1047         }
1048     }
1049
1050     AST::FunctionDeclaration* threadAnderFunction = nullptr;
1051     AST::UnnamedType* threadAnderReturnType = nullptr;
1052     if (auto argumentTypeForAndOverload = WHLSL::argumentTypeForAndOverload(*baseUnnamedType, AST::AddressSpace::Thread)) {
1053         ResolvingType argumentType = { makeUniqueRef<AST::PointerType>(Lexer::Token(propertyAccessExpression.origin()), AST::AddressSpace::Thread, baseUnnamedType->get().clone()) };
1054         Vector<std::reference_wrapper<ResolvingType>> threadAnderArgumentTypes { argumentType };
1055         if (additionalArgumentType)
1056             threadAnderArgumentTypes.append(*additionalArgumentType);
1057         auto anderName = propertyAccessExpression.anderFunctionName();
1058         auto* anderFunctions = m_program.nameContext().getFunctions(anderName);
1059         threadAnderFunction = resolveFunction(m_program, anderFunctions, threadAnderArgumentTypes, anderName, propertyAccessExpression.origin(), m_intrinsics);
1060         if (threadAnderFunction)
1061             threadAnderReturnType = &downcast<AST::PointerType>(threadAnderFunction->type()).elementType(); // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198164 Enforce the return of anders will always be a pointer
1062     }
1063
1064     if (leftAddressSpace && !anderFunction && !getterFunction) {
1065         setError();
1066         return;
1067     }
1068
1069     if (!leftAddressSpace && !threadAnderFunction && !getterFunction) {
1070         setError();
1071         return;
1072     }
1073
1074     if (threadAnderFunction && getterFunction) {
1075         setError();
1076         return;
1077     }
1078
1079     if (anderFunction && threadAnderFunction && !matches(*anderReturnType, *threadAnderReturnType)) {
1080         setError();
1081         return;
1082     }
1083
1084     if (getterFunction && anderFunction && !matches(*getterReturnType, *anderReturnType)) {
1085         setError();
1086         return;
1087     }
1088
1089     if (getterFunction && threadAnderFunction && !matches(*getterReturnType, *threadAnderReturnType)) {
1090         setError();
1091         return;
1092     }
1093
1094     AST::UnnamedType* fieldType = getterReturnType ? getterReturnType : anderReturnType ? anderReturnType : threadAnderReturnType;
1095
1096     AST::FunctionDeclaration* setterFunction = nullptr;
1097     AST::UnnamedType* setterReturnType = nullptr;
1098     {
1099         ResolvingType fieldResolvingType(fieldType->clone());
1100         Vector<std::reference_wrapper<ResolvingType>> setterArgumentTypes { baseInfo->resolvingType };
1101         if (additionalArgumentType)
1102             setterArgumentTypes.append(*additionalArgumentType);
1103         setterArgumentTypes.append(fieldResolvingType);
1104         auto setterName = propertyAccessExpression.setterFunctionName();
1105         auto* setterFunctions = m_program.nameContext().getFunctions(setterName);
1106         setterFunction = resolveFunction(m_program, setterFunctions, setterArgumentTypes, setterName, propertyAccessExpression.origin(), m_intrinsics);
1107         if (setterFunction)
1108             setterReturnType = &setterFunction->type();
1109     }
1110
1111     if (setterFunction && !getterFunction) {
1112         setError();
1113         return;
1114     }
1115
1116     propertyAccessExpression.setGetterFunction(getterFunction);
1117     propertyAccessExpression.setAnderFunction(anderFunction);
1118     propertyAccessExpression.setThreadAnderFunction(threadAnderFunction);
1119     propertyAccessExpression.setSetterFunction(setterFunction);
1120
1121     AST::TypeAnnotation typeAnnotation = AST::RightValue();
1122     if (auto leftAddressSpace = baseInfo->typeAnnotation.leftAddressSpace()) {
1123         if (anderFunction)
1124             typeAnnotation = AST::LeftValue { *leftAddressSpace };
1125         else if (setterFunction)
1126             typeAnnotation = AST::AbstractLeftValue();
1127     } else if (!baseInfo->typeAnnotation.isRightValue() && (setterFunction || threadAnderFunction))
1128         typeAnnotation = AST::AbstractLeftValue();
1129     assignType(propertyAccessExpression, fieldType->clone(), WTFMove(typeAnnotation));
1130 }
1131
// `base.field`: all getter/ander/setter resolution is shared with index expressions in
// finishVisiting(). Unlike IndexExpression, a dot expression passes no additional argument type.
void Checker::visit(AST::DotExpression& dotExpression)
{
    finishVisiting(dotExpression);
}
1136
1137 void Checker::visit(AST::IndexExpression& indexExpression)
1138 {
1139     auto baseInfo = recurseAndGetInfo(indexExpression.indexExpression());
1140     if (!baseInfo)
1141         return;
1142     finishVisiting(indexExpression, &baseInfo->resolvingType);
1143 }
1144
1145 void Checker::visit(AST::VariableReference& variableReference)
1146 {
1147     ASSERT(variableReference.variable());
1148     ASSERT(variableReference.variable()->type());
1149     
1150     assignType(variableReference, variableReference.variable()->type()->clone(), AST::LeftValue { AST::AddressSpace::Thread });
1151 }
1152
1153 void Checker::visit(AST::Return& returnStatement)
1154 {
1155     if (returnStatement.value()) {
1156         auto valueInfo = recurseAndGetInfo(*returnStatement.value());
1157         if (!valueInfo)
1158             return;
1159         if (!matchAndCommit(valueInfo->resolvingType, returnStatement.function()->type()))
1160             setError();
1161         return;
1162     }
1163
1164     if (!matches(returnStatement.function()->type(), m_intrinsics.voidType()))
1165         setError();
1166 }
1167
void Checker::visit(AST::PointerType&)
{
    // Deliberately does not recurse into the pointee type: following pointer types can cause
    // infinite loops because of recursive data structures like linked lists.
    // FIXME: Make sure this function should be empty
}
1174
void Checker::visit(AST::ArrayReferenceType&)
{
    // Deliberately does not recurse into the element type: following array reference types can
    // cause infinite loops because of recursive data structures like linked lists.
    // FIXME: Make sure this function should be empty
}
1181
1182 void Checker::visit(AST::IntegerLiteral& integerLiteral)
1183 {
1184     assignType(integerLiteral, adoptRef(*new ResolvableTypeReference(integerLiteral.type())));
1185 }
1186
1187 void Checker::visit(AST::UnsignedIntegerLiteral& unsignedIntegerLiteral)
1188 {
1189     assignType(unsignedIntegerLiteral, adoptRef(*new ResolvableTypeReference(unsignedIntegerLiteral.type())));
1190 }
1191
1192 void Checker::visit(AST::FloatLiteral& floatLiteral)
1193 {
1194     assignType(floatLiteral, adoptRef(*new ResolvableTypeReference(floatLiteral.type())));
1195 }
1196
1197 void Checker::visit(AST::NullLiteral& nullLiteral)
1198 {
1199     assignType(nullLiteral, adoptRef(*new ResolvableTypeReference(nullLiteral.type())));
1200 }
1201
1202 void Checker::visit(AST::BooleanLiteral& booleanLiteral)
1203 {
1204     assignType(booleanLiteral, AST::TypeReference::wrap(Lexer::Token(booleanLiteral.origin()), m_intrinsics.boolType()));
1205 }
1206
1207 void Checker::visit(AST::EnumerationMemberLiteral& enumerationMemberLiteral)
1208 {
1209     ASSERT(enumerationMemberLiteral.enumerationDefinition());
1210     auto& enumerationDefinition = *enumerationMemberLiteral.enumerationDefinition();
1211     assignType(enumerationMemberLiteral, AST::TypeReference::wrap(Lexer::Token(enumerationMemberLiteral.origin()), enumerationDefinition));
1212 }
1213
1214 bool Checker::isBoolType(ResolvingType& resolvingType)
1215 {
1216     return resolvingType.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& left) -> bool {
1217         return matches(left, m_intrinsics.boolType());
1218     }, [&](RefPtr<ResolvableTypeReference>& left) -> bool {
1219         return static_cast<bool>(matchAndCommit(m_intrinsics.boolType(), left->resolvableType()));
1220     }));
1221 }
1222
1223 bool Checker::recurseAndRequireBoolType(AST::Expression& expression)
1224 {
1225     auto expressionInfo = recurseAndGetInfo(expression);
1226     if (!expressionInfo)
1227         return false;
1228     if (!isBoolType(expressionInfo->resolvingType)) {
1229         setError();
1230         return false;
1231     }
1232     return true;
1233 }
1234
1235 void Checker::visit(AST::LogicalNotExpression& logicalNotExpression)
1236 {
1237     if (!recurseAndRequireBoolType(logicalNotExpression.operand()))
1238         return;
1239     assignType(logicalNotExpression, AST::TypeReference::wrap(Lexer::Token(logicalNotExpression.origin()), m_intrinsics.boolType()));
1240 }
1241
1242 void Checker::visit(AST::LogicalExpression& logicalExpression)
1243 {
1244     if (!recurseAndRequireBoolType(logicalExpression.left()))
1245         return;
1246     if (!recurseAndRequireBoolType(logicalExpression.right()))
1247         return;
1248     assignType(logicalExpression, AST::TypeReference::wrap(Lexer::Token(logicalExpression.origin()), m_intrinsics.boolType()));
1249 }
1250
1251 void Checker::visit(AST::IfStatement& ifStatement)
1252 {
1253     if (!recurseAndRequireBoolType(ifStatement.conditional()))
1254         return;
1255     checkErrorAndVisit(ifStatement.body());
1256     if (ifStatement.elseBody())
1257         checkErrorAndVisit(*ifStatement.elseBody());
1258 }
1259
1260 void Checker::visit(AST::WhileLoop& whileLoop)
1261 {
1262     if (!recurseAndRequireBoolType(whileLoop.conditional()))
1263         return;
1264     checkErrorAndVisit(whileLoop.body());
1265 }
1266
1267 void Checker::visit(AST::DoWhileLoop& doWhileLoop)
1268 {
1269     checkErrorAndVisit(doWhileLoop.body());
1270     recurseAndRequireBoolType(doWhileLoop.conditional());
1271 }
1272
1273 void Checker::visit(AST::ForLoop& forLoop)
1274 {
1275     WTF::visit(WTF::makeVisitor([&](UniqueRef<AST::Statement>& statement) {
1276         checkErrorAndVisit(statement);
1277     }, [&](UniqueRef<AST::Expression>& expression) {
1278         checkErrorAndVisit(expression);
1279     }), forLoop.initialization());
1280     if (error())
1281         return;
1282     if (forLoop.condition()) {
1283         if (!recurseAndRequireBoolType(*forLoop.condition()))
1284             return;
1285     }
1286     if (forLoop.increment())
1287         checkErrorAndVisit(*forLoop.increment());
1288     checkErrorAndVisit(forLoop.body());
1289 }
1290
1291 void Checker::visit(AST::SwitchStatement& switchStatement)
1292 {
1293     auto* valueType = ([&]() -> AST::NamedType* {
1294         auto valueInfo = recurseAndGetInfo(switchStatement.value());
1295         if (!valueInfo)
1296             return nullptr;
1297         auto* valueType = getUnnamedType(valueInfo->resolvingType);
1298         if (!valueType)
1299             return nullptr;
1300         auto& valueUnifyNode = valueType->unifyNode();
1301         if (!is<AST::NamedType>(valueUnifyNode))
1302             return nullptr;
1303         auto& valueNamedUnifyNode = downcast<AST::NamedType>(valueUnifyNode);
1304         if (!(is<AST::NativeTypeDeclaration>(valueNamedUnifyNode) && downcast<AST::NativeTypeDeclaration>(valueNamedUnifyNode).isInt())
1305             && !is<AST::EnumerationDefinition>(valueNamedUnifyNode))
1306             return nullptr;
1307         return &valueNamedUnifyNode;
1308     })();
1309     if (!valueType) {
1310         setError();
1311         return;
1312     }
1313
1314     bool hasDefault = false;
1315     for (auto& switchCase : switchStatement.switchCases()) {
1316         checkErrorAndVisit(switchCase.block());
1317         if (!switchCase.value()) {
1318             hasDefault = true;
1319             continue;
1320         }
1321         auto success = switchCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) -> bool {
1322             return static_cast<bool>(matchAndCommit(*valueType, integerLiteral.type()));
1323         }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) -> bool {
1324             return static_cast<bool>(matchAndCommit(*valueType, unsignedIntegerLiteral.type()));
1325         }, [&](AST::FloatLiteral& floatLiteral) -> bool {
1326             return static_cast<bool>(matchAndCommit(*valueType, floatLiteral.type()));
1327         }, [&](AST::NullLiteral& nullLiteral) -> bool {
1328             return static_cast<bool>(matchAndCommit(*valueType, nullLiteral.type()));
1329         }, [&](AST::BooleanLiteral&) -> bool {
1330             return matches(*valueType, m_intrinsics.boolType());
1331         }, [&](AST::EnumerationMemberLiteral& enumerationMemberLiteral) -> bool {
1332             ASSERT(enumerationMemberLiteral.enumerationDefinition());
1333             return matches(*valueType, *enumerationMemberLiteral.enumerationDefinition());
1334         }));
1335         if (!success) {
1336             setError();
1337             return;
1338         }
1339     }
1340
1341     for (size_t i = 0; i < switchStatement.switchCases().size(); ++i) {
1342         auto& firstCase = switchStatement.switchCases()[i];
1343         for (size_t j = i + 1; j < switchStatement.switchCases().size(); ++j) {
1344             auto& secondCase = switchStatement.switchCases()[j];
1345             
1346             if (static_cast<bool>(firstCase.value()) != static_cast<bool>(secondCase.value()))
1347                 continue;
1348
1349             if (!static_cast<bool>(firstCase.value())) {
1350                 setError();
1351                 return;
1352             }
1353
1354             auto success = firstCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& firstIntegerLiteral) -> bool {
1355                 return secondCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& secondIntegerLiteral) -> bool {
1356                     return firstIntegerLiteral.value() != secondIntegerLiteral.value();
1357                 }, [&](AST::UnsignedIntegerLiteral& secondUnsignedIntegerLiteral) -> bool {
1358                     return static_cast<int64_t>(firstIntegerLiteral.value()) != static_cast<int64_t>(secondUnsignedIntegerLiteral.value());
1359                 }, [](auto&) -> bool {
1360                     return true;
1361                 }));
1362             }, [&](AST::UnsignedIntegerLiteral& firstUnsignedIntegerLiteral) -> bool {
1363                 return secondCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& secondIntegerLiteral) -> bool {
1364                     return static_cast<int64_t>(firstUnsignedIntegerLiteral.value()) != static_cast<int64_t>(secondIntegerLiteral.value());
1365                 }, [&](AST::UnsignedIntegerLiteral& secondUnsignedIntegerLiteral) -> bool {
1366                     return firstUnsignedIntegerLiteral.value() != secondUnsignedIntegerLiteral.value();
1367                 }, [](auto&) -> bool {
1368                     return true;
1369                 }));
1370             }, [&](AST::EnumerationMemberLiteral& firstEnumerationMemberLiteral) -> bool {
1371                 return secondCase.value()->visit(WTF::makeVisitor([&](AST::EnumerationMemberLiteral& secondEnumerationMemberLiteral) -> bool {
1372                     ASSERT(firstEnumerationMemberLiteral.enumerationMember());
1373                     ASSERT(secondEnumerationMemberLiteral.enumerationMember());
1374                     return firstEnumerationMemberLiteral.enumerationMember() != secondEnumerationMemberLiteral.enumerationMember();
1375                 }, [](auto&) -> bool {
1376                     return true;
1377                 }));
1378             }, [](auto&) -> bool {
1379                 return true;
1380             }));
1381             if (!success) {
1382                 setError();
1383                 return;
1384             }
1385         }
1386     }
1387
1388     if (!hasDefault) {
1389         if (is<AST::NativeTypeDeclaration>(*valueType)) {
1390             HashSet<int64_t> values;
1391             bool zeroValueExists;
1392             for (auto& switchCase : switchStatement.switchCases()) {
1393                 auto value = switchCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) -> int64_t {
1394                     return integerLiteral.valueForSelectedType();
1395                 }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) -> int64_t {
1396                     return unsignedIntegerLiteral.valueForSelectedType();
1397                 }, [](auto&) -> int64_t {
1398                     ASSERT_NOT_REACHED();
1399                     return 0;
1400                 }));
1401                 if (!value)
1402                     zeroValueExists = true;
1403                 else
1404                     values.add(value);
1405             }
1406             bool success = true;
1407             downcast<AST::NativeTypeDeclaration>(*valueType).iterateAllValues([&](int64_t value) -> bool {
1408                 if (!value) {
1409                     if (!zeroValueExists) {
1410                         success = false;
1411                         return true;
1412                     }
1413                     return false;
1414                 }
1415                 if (!values.contains(value)) {
1416                     success = false;
1417                     return true;
1418                 }
1419                 return false;
1420             });
1421             if (!success) {
1422                 setError();
1423                 return;
1424             }
1425         } else {
1426             HashSet<AST::EnumerationMember*> values;
1427             for (auto& switchCase : switchStatement.switchCases()) {
1428                 switchCase.value()->visit(WTF::makeVisitor([&](AST::EnumerationMemberLiteral& enumerationMemberLiteral) {
1429                     ASSERT(enumerationMemberLiteral.enumerationMember());
1430                     values.add(enumerationMemberLiteral.enumerationMember());
1431                 }, [](auto&) {
1432                     ASSERT_NOT_REACHED();
1433                 }));
1434             }
1435             for (auto& enumerationMember : downcast<AST::EnumerationDefinition>(*valueType).enumerationMembers()) {
1436                 if (!values.contains(&enumerationMember.get())) {
1437                     setError();
1438                     return;
1439                 }
1440             }
1441         }
1442     }
1443 }
1444
1445 void Checker::visit(AST::CommaExpression& commaExpression)
1446 {
1447     ASSERT(commaExpression.list().size() > 0);
1448     Visitor::visit(commaExpression);
1449     if (error())
1450         return;
1451     auto lastInfo = getInfo(commaExpression.list().last());
1452     forwardType(commaExpression, lastInfo->resolvingType);
1453 }
1454
1455 void Checker::visit(AST::TernaryExpression& ternaryExpression)
1456 {
1457     auto predicateInfo = recurseAndRequireBoolType(ternaryExpression.predicate());
1458     if (!predicateInfo)
1459         return;
1460
1461     auto bodyInfo = recurseAndGetInfo(ternaryExpression.bodyExpression());
1462     auto elseInfo = recurseAndGetInfo(ternaryExpression.elseExpression());
1463     
1464     auto resultType = matchAndCommit(bodyInfo->resolvingType, elseInfo->resolvingType);
1465     if (!resultType) {
1466         setError();
1467         return;
1468     }
1469
1470     assignType(ternaryExpression, WTFMove(*resultType));
1471 }
1472
1473 void Checker::visit(AST::CallExpression& callExpression)
1474 {
1475     Vector<std::reference_wrapper<ResolvingType>> types;
1476     types.reserveInitialCapacity(callExpression.arguments().size());
1477     for (auto& argument : callExpression.arguments()) {
1478         auto argumentInfo = recurseAndGetInfo(argument);
1479         if (!argumentInfo)
1480             return;
1481         types.uncheckedAppend(argumentInfo->resolvingType);
1482     }
1483     // Don't recurse on the castReturnType, because it's guaranteed to be a NamedType, which will get visited later.
1484     // We don't want to recurse to the same node twice.
1485
1486     NameContext& nameContext = m_program.nameContext();
1487     auto* functions = nameContext.getFunctions(callExpression.name());
1488     if (!functions) {
1489         if (auto* types = nameContext.getTypes(callExpression.name())) {
1490             if (types->size() == 1) {
1491                 if ((functions = nameContext.getFunctions("operator cast"_str)))
1492                     callExpression.setCastData((*types)[0].get());
1493             }
1494         }
1495     }
1496     if (!functions) {
1497         setError();
1498         return;
1499     }
1500
1501     auto* function = resolveFunction(m_program, functions, types, callExpression.name(), callExpression.origin(), m_intrinsics, callExpression.castReturnType());
1502     if (!function) {
1503         setError();
1504         return;
1505     }
1506
1507     for (size_t i = 0; i < function->parameters().size(); ++i) {
1508         if (!matchAndCommit(types[i].get(), *function->parameters()[i]->type())) {
1509             setError();
1510             return;
1511         }
1512     }
1513
1514     callExpression.setFunction(*function);
1515
1516     assignType(callExpression, function->type().clone());
1517 }
1518
1519 bool check(Program& program)
1520 {
1521     Checker checker(program.intrinsics(), program);
1522     checker.checkErrorAndVisit(program);
1523     if (checker.error())
1524         return false;
1525     return checker.assignTypes();
1526 }
1527
1528 } // namespace WHLSL
1529
1530 } // namespace WebCore
1531
1532 #endif // ENABLE(WEBGPU)