[WTF] Add makeUnique<T>, which ensures T is fast-allocated, makeUnique / makeUniqueWi...
[WebKit-https.git] / Source / WebCore / Modules / webgpu / WHLSL / WHLSLChecker.cpp
1 /*
2  * Copyright (C) 2019 Apple Inc. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  * 1. Redistributions of source code must retain the above copyright
8  *    notice, this list of conditions and the following disclaimer.
9  * 2. Redistributions in binary form must reproduce the above copyright
10  *    notice, this list of conditions and the following disclaimer in the
11  *    documentation and/or other materials provided with the distribution.
12  *
13  * THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS ``AS IS''
14  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
15  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS
17  * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23  * THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 #include "config.h"
27 #include "WHLSLChecker.h"
28
29 #if ENABLE(WEBGPU)
30
31 #include "WHLSLArrayReferenceType.h"
32 #include "WHLSLArrayType.h"
33 #include "WHLSLAssignmentExpression.h"
34 #include "WHLSLCallExpression.h"
35 #include "WHLSLCommaExpression.h"
36 #include "WHLSLDereferenceExpression.h"
37 #include "WHLSLDoWhileLoop.h"
38 #include "WHLSLDotExpression.h"
39 #include "WHLSLEntryPointType.h"
40 #include "WHLSLForLoop.h"
41 #include "WHLSLGatherEntryPointItems.h"
42 #include "WHLSLIfStatement.h"
43 #include "WHLSLIndexExpression.h"
44 #include "WHLSLInferTypes.h"
45 #include "WHLSLLogicalExpression.h"
46 #include "WHLSLLogicalNotExpression.h"
47 #include "WHLSLMakeArrayReferenceExpression.h"
48 #include "WHLSLMakePointerExpression.h"
49 #include "WHLSLNameContext.h"
50 #include "WHLSLPointerType.h"
51 #include "WHLSLProgram.h"
52 #include "WHLSLReadModifyWriteExpression.h"
53 #include "WHLSLResolvableType.h"
54 #include "WHLSLResolveOverloadImpl.h"
55 #include "WHLSLResolvingType.h"
56 #include "WHLSLReturn.h"
57 #include "WHLSLSwitchStatement.h"
58 #include "WHLSLTernaryExpression.h"
59 #include "WHLSLVisitor.h"
60 #include "WHLSLWhileLoop.h"
61 #include <wtf/HashMap.h>
62 #include <wtf/HashSet.h>
63 #include <wtf/Ref.h>
64 #include <wtf/Vector.h>
65 #include <wtf/text/WTFString.h>
66
67 namespace WebCore {
68
69 namespace WHLSL {
70
71 class PODChecker : public Visitor {
72 public:
73     PODChecker() = default;
74
75     virtual ~PODChecker() = default;
76
77     void visit(AST::EnumerationDefinition& enumerationDefinition) override
78     {
79         Visitor::visit(enumerationDefinition);
80     }
81
82     void visit(AST::NativeTypeDeclaration& nativeTypeDeclaration) override
83     {
84         if (!nativeTypeDeclaration.isNumber()
85             && !nativeTypeDeclaration.isVector()
86             && !nativeTypeDeclaration.isMatrix())
87             setError(Error("Use of native type is not a POD in entrypoint semantic.", nativeTypeDeclaration.codeLocation()));
88     }
89
90     void visit(AST::StructureDefinition& structureDefinition) override
91     {
92         Visitor::visit(structureDefinition);
93     }
94
95     void visit(AST::TypeDefinition& typeDefinition) override
96     {
97         Visitor::visit(typeDefinition);
98     }
99
100     void visit(AST::ArrayType& arrayType) override
101     {
102         Visitor::visit(arrayType);
103     }
104
105     void visit(AST::PointerType& pointerType) override
106     {
107         setError(Error("Illegal use of pointer in entrypoint semantic.", pointerType.codeLocation()));
108     }
109
110     void visit(AST::ArrayReferenceType& arrayReferenceType) override
111     {
112         setError(Error("Illegal use of array reference in entrypoint semantic.", arrayReferenceType.codeLocation()));
113     }
114
115     void visit(AST::TypeReference& typeReference) override
116     {
117         checkErrorAndVisit(typeReference.resolvedType());
118     }
119 };
120
121 class FunctionKey {
122 public:
123     FunctionKey() = default;
124     FunctionKey(WTF::HashTableDeletedValueType)
125     {
126         m_castReturnType = bitwise_cast<AST::NamedType*>(static_cast<uintptr_t>(1));
127     }
128
129     FunctionKey(String name, Vector<std::reference_wrapper<AST::UnnamedType>> types, AST::NamedType* castReturnType = nullptr)
130         : m_name(WTFMove(name))
131         , m_types(WTFMove(types))
132         , m_castReturnType(castReturnType)
133     { }
134
135     bool isEmptyValue() const { return m_name.isNull(); }
136     bool isHashTableDeletedValue() const { return m_castReturnType == bitwise_cast<AST::NamedType*>(static_cast<uintptr_t>(1)); }
137
138     unsigned hash() const
139     {
140         unsigned hash = IntHash<size_t>::hash(m_types.size());
141         hash ^= m_name.hash();
142         for (size_t i = 0; i < m_types.size(); ++i)
143             hash ^= m_types[i].get().hash() + i;
144
145         if (m_castReturnType)
146             hash ^= WTF::PtrHash<AST::Type*>::hash(&m_castReturnType->unifyNode());
147
148         return hash;
149     }
150
151     bool operator==(const FunctionKey& other) const
152     {
153         if (m_types.size() != other.m_types.size())
154             return false;
155
156         if (m_name != other.m_name)
157             return false;
158
159         for (size_t i = 0; i < m_types.size(); ++i) {
160             if (!matches(m_types[i].get(), other.m_types[i].get()))
161                 return false;
162         }
163
164         if (!!m_castReturnType != !!other.m_castReturnType)
165             return false;
166
167         if (!m_castReturnType)
168             return true;
169
170         if (&m_castReturnType->unifyNode() == &other.m_castReturnType->unifyNode())
171             return true;
172
173         return false;
174     }
175
176     struct Hash {
177         static unsigned hash(const FunctionKey& key)
178         {
179             return key.hash();
180         }
181
182         static bool equal(const FunctionKey& a, const FunctionKey& b)
183         {
184             return a == b;
185         }
186
187         static const bool safeToCompareToEmptyOrDeleted = false;
188         static const bool emptyValueIsZero = false;
189     };
190
191     struct Traits : public WTF::SimpleClassHashTraits<FunctionKey> {
192         static const bool hasIsEmptyValueFunction = true;
193         static bool isEmptyValue(const FunctionKey& key) { return key.isEmptyValue(); }
194     };
195
196 private:
197     String m_name;
198     Vector<std::reference_wrapper<AST::UnnamedType>> m_types;
199     AST::NamedType* m_castReturnType;
200 };
201
202 class AndOverloadTypeKey {
203 public:
204     AndOverloadTypeKey() = default;
205     AndOverloadTypeKey(WTF::HashTableDeletedValueType)
206     {
207         m_type = bitwise_cast<AST::UnnamedType*>(static_cast<uintptr_t>(1));
208     }
209
210     AndOverloadTypeKey(AST::UnnamedType& type, AST::AddressSpace addressSpace)
211         : m_type(&type)
212         , m_addressSpace(addressSpace)
213     { }
214
215     bool isEmptyValue() const { return !m_type; }
216     bool isHashTableDeletedValue() const { return m_type == bitwise_cast<AST::UnnamedType*>(static_cast<uintptr_t>(1)); }
217
218     unsigned hash() const
219     {
220         return IntHash<uint8_t>::hash(static_cast<uint8_t>(m_addressSpace)) ^ m_type->hash();
221     }
222
223     bool operator==(const AndOverloadTypeKey& other) const
224     {
225         return m_addressSpace == other.m_addressSpace
226             && *m_type == *other.m_type;
227     }
228
229     struct Hash {
230         static unsigned hash(const AndOverloadTypeKey& key)
231         {
232             return key.hash();
233         }
234
235         static bool equal(const AndOverloadTypeKey& a, const AndOverloadTypeKey& b)
236         {
237             return a == b;
238         }
239
240         static const bool safeToCompareToEmptyOrDeleted = false;
241     };
242
243     struct Traits : public WTF::SimpleClassHashTraits<AndOverloadTypeKey> {
244         static const bool hasIsEmptyValueFunction = true;
245         static bool isEmptyValue(const AndOverloadTypeKey& key) { return key.isEmptyValue(); }
246     };
247
248 private:
249     AST::UnnamedType* m_type { nullptr };
250     AST::AddressSpace m_addressSpace;
251 };
252
253 static AST::NativeFunctionDeclaration resolveWithOperatorAnderIndexer(CodeLocation location, AST::ArrayReferenceType& firstArgument, const Intrinsics& intrinsics)
254 {
255     const bool isOperator = true;
256     auto returnType = AST::PointerType::create(location, firstArgument.addressSpace(), firstArgument.elementType());
257     AST::VariableDeclarations parameters;
258     parameters.append(makeUniqueRef<AST::VariableDeclaration>(location, AST::Qualifiers(), &firstArgument, String(), nullptr, nullptr));
259     parameters.append(makeUniqueRef<AST::VariableDeclaration>(location, AST::Qualifiers(), AST::TypeReference::wrap(location, intrinsics.uintType()), String(), nullptr, nullptr));
260     return AST::NativeFunctionDeclaration(AST::FunctionDeclaration(location, AST::AttributeBlock(), WTF::nullopt, WTFMove(returnType), String("operator&[]", String::ConstructFromLiteral), WTFMove(parameters), nullptr, isOperator, ParsingMode::StandardLibrary));
261 }
262
263 static AST::NativeFunctionDeclaration resolveWithOperatorLength(CodeLocation location, AST::UnnamedType& firstArgument, const Intrinsics& intrinsics)
264 {
265     const bool isOperator = true;
266     auto returnType = AST::TypeReference::wrap(location, intrinsics.uintType());
267     AST::VariableDeclarations parameters;
268     parameters.append(makeUniqueRef<AST::VariableDeclaration>(location, AST::Qualifiers(), &firstArgument, String(), nullptr, nullptr));
269     return AST::NativeFunctionDeclaration(AST::FunctionDeclaration(location, AST::AttributeBlock(), WTF::nullopt, WTFMove(returnType), String("operator.length", String::ConstructFromLiteral), WTFMove(parameters), nullptr, isOperator, ParsingMode::StandardLibrary));
270 }
271
272 static AST::NativeFunctionDeclaration resolveWithReferenceComparator(CodeLocation location, ResolvingType& firstArgument, ResolvingType& secondArgument, const Intrinsics& intrinsics)
273 {
274     const bool isOperator = true;
275     auto returnType = AST::TypeReference::wrap(location, intrinsics.boolType());
276     auto argumentType = firstArgument.visit(WTF::makeVisitor([](Ref<AST::UnnamedType>& unnamedType) -> Ref<AST::UnnamedType> {
277         return unnamedType.copyRef();
278     }, [&](RefPtr<ResolvableTypeReference>&) -> Ref<AST::UnnamedType> {
279         return secondArgument.visit(WTF::makeVisitor([](Ref<AST::UnnamedType>& unnamedType) -> Ref<AST::UnnamedType> {
280             return unnamedType.copyRef();
281         }, [&](RefPtr<ResolvableTypeReference>&) -> Ref<AST::UnnamedType> {
282             // We encountered "null == null".
283             // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198162 This can probably be generalized, using the "preferred type" infrastructure used by generic literals
284             ASSERT_NOT_REACHED();
285             return AST::TypeReference::wrap(location, intrinsics.intType());
286         }));
287     }));
288     AST::VariableDeclarations parameters;
289     parameters.append(makeUniqueRef<AST::VariableDeclaration>(location, AST::Qualifiers(), argumentType.copyRef(), String(), nullptr, nullptr));
290     parameters.append(makeUniqueRef<AST::VariableDeclaration>(location, AST::Qualifiers(), WTFMove(argumentType), String(), nullptr, nullptr));
291     return AST::NativeFunctionDeclaration(AST::FunctionDeclaration(location, AST::AttributeBlock(), WTF::nullopt, WTFMove(returnType), String("operator==", String::ConstructFromLiteral), WTFMove(parameters), nullptr, isOperator, ParsingMode::StandardLibrary));
292 }
293
// Classification of one operand of a candidate reference `operator==`
// (computed in resolveByInstantiation).
enum class Acceptability {
    Yes, // Concrete type whose unify node is a reference type.
    Maybe, // Unresolved null literal; usable only opposite a concrete reference.
    No // Anything else.
};
299
300 static Optional<AST::NativeFunctionDeclaration> resolveByInstantiation(const String& name, CodeLocation location, const Vector<std::reference_wrapper<ResolvingType>>& types, const Intrinsics& intrinsics)
301 {
302     if (name == "operator&[]" && types.size() == 2) {
303         auto* firstArgumentArrayRef = types[0].get().visit(WTF::makeVisitor([](Ref<AST::UnnamedType>& unnamedType) -> AST::ArrayReferenceType* {
304             if (is<AST::ArrayReferenceType>(static_cast<AST::UnnamedType&>(unnamedType)))
305                 return &downcast<AST::ArrayReferenceType>(static_cast<AST::UnnamedType&>(unnamedType));
306             return nullptr;
307         }, [](RefPtr<ResolvableTypeReference>&) -> AST::ArrayReferenceType* {
308             return nullptr;
309         }));
310         bool secondArgumentIsUint = types[1].get().visit(WTF::makeVisitor([&](Ref<AST::UnnamedType>& unnamedType) -> bool {
311             return matches(unnamedType, intrinsics.uintType());
312         }, [&](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> bool {
313             return resolvableTypeReference->resolvableType().canResolve(intrinsics.uintType());
314         }));
315         if (firstArgumentArrayRef && secondArgumentIsUint)
316             return resolveWithOperatorAnderIndexer(location, *firstArgumentArrayRef, intrinsics);
317     } else if (name == "operator.length" && types.size() == 1) {
318         auto* firstArgumentReference = types[0].get().visit(WTF::makeVisitor([](Ref<AST::UnnamedType>& unnamedType) -> AST::UnnamedType* {
319             if (is<AST::ArrayReferenceType>(static_cast<AST::UnnamedType&>(unnamedType)) || is<AST::ArrayType>(static_cast<AST::UnnamedType&>(unnamedType)))
320                 return unnamedType.ptr();
321             return nullptr;
322         }, [](RefPtr<ResolvableTypeReference>&) -> AST::UnnamedType* {
323             return nullptr;
324         }));
325         if (firstArgumentReference)
326             return resolveWithOperatorLength(location, *firstArgumentReference, intrinsics);
327     } else if (name == "operator==" && types.size() == 2) {
328         auto acceptability = [](ResolvingType& resolvingType) -> Acceptability {
329             return resolvingType.visit(WTF::makeVisitor([](Ref<AST::UnnamedType>& unnamedType) -> Acceptability {
330                 auto& unifyNode = unnamedType->unifyNode();
331                 return is<AST::UnnamedType>(unifyNode) && is<AST::ReferenceType>(downcast<AST::UnnamedType>(unifyNode)) ? Acceptability::Yes : Acceptability::No;
332             }, [](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> Acceptability {
333                 return is<AST::NullLiteralType>(resolvableTypeReference->resolvableType()) ? Acceptability::Maybe : Acceptability::No;
334             }));
335         };
336         auto leftAcceptability = acceptability(types[0].get());
337         auto rightAcceptability = acceptability(types[1].get());
338         bool success = false;
339         if (leftAcceptability == Acceptability::Yes && rightAcceptability == Acceptability::Yes) {
340             auto& unnamedType1 = *types[0].get().getUnnamedType();
341             auto& unnamedType2 = *types[1].get().getUnnamedType();
342             success = matches(unnamedType1, unnamedType2);
343         } else if ((leftAcceptability == Acceptability::Maybe && rightAcceptability == Acceptability::Yes)
344             || (leftAcceptability == Acceptability::Yes && rightAcceptability == Acceptability::Maybe))
345             success = true;
346         if (success)
347             return resolveWithReferenceComparator(location, types[0].get(), types[1].get(), intrinsics);
348     }
349     return WTF::nullopt;
350 }
351
352 static bool checkSemantics(Vector<EntryPointItem>& inputItems, Vector<EntryPointItem>& outputItems, const Optional<AST::EntryPointType>& entryPointType, const Intrinsics& intrinsics)
353 {
354     {
355         auto checkDuplicateSemantics = [&](const Vector<EntryPointItem>& items) -> bool {
356             for (size_t i = 0; i < items.size(); ++i) {
357                 for (size_t j = i + 1; j < items.size(); ++j) {
358                     if (items[i].semantic == items[j].semantic)
359                         return false;
360                 }
361             }
362             return true;
363         };
364         if (!checkDuplicateSemantics(inputItems))
365             return false;
366         if (!checkDuplicateSemantics(outputItems))
367             return false;
368     }
369
370     {
371         auto checkSemanticTypes = [&](const Vector<EntryPointItem>& items) -> bool {
372             for (auto& item : items) {
373                 auto acceptable = WTF::visit(WTF::makeVisitor([&](const AST::BaseSemantic& semantic) -> bool {
374                     return semantic.isAcceptableType(*item.unnamedType, intrinsics);
375                 }), *item.semantic);
376                 if (!acceptable)
377                     return false;
378             }
379             return true;
380         };
381         if (!checkSemanticTypes(inputItems))
382             return false;
383         if (!checkSemanticTypes(outputItems))
384             return false;
385     }
386
387     {
388         auto checkSemanticForShaderType = [&](const Vector<EntryPointItem>& items, AST::BaseSemantic::ShaderItemDirection direction) -> bool {
389             for (auto& item : items) {
390                 auto acceptable = WTF::visit(WTF::makeVisitor([&](const AST::BaseSemantic& semantic) -> bool {
391                     return semantic.isAcceptableForShaderItemDirection(direction, entryPointType);
392                 }), *item.semantic);
393                 if (!acceptable)
394                     return false;
395             }
396             return true;
397         };
398         if (!checkSemanticForShaderType(inputItems, AST::BaseSemantic::ShaderItemDirection::Input))
399             return false;
400         if (!checkSemanticForShaderType(outputItems, AST::BaseSemantic::ShaderItemDirection::Output))
401             return false;
402     }
403
404     {
405         auto checkPODData = [&](const Vector<EntryPointItem>& items) -> bool {
406             for (auto& item : items) {
407                 PODChecker podChecker;
408                 if (is<AST::PointerType>(item.unnamedType))
409                     podChecker.checkErrorAndVisit(downcast<AST::PointerType>(*item.unnamedType).elementType());
410                 else if (is<AST::ArrayReferenceType>(item.unnamedType))
411                     podChecker.checkErrorAndVisit(downcast<AST::ArrayReferenceType>(*item.unnamedType).elementType());
412                 else if (is<AST::ArrayType>(item.unnamedType))
413                     podChecker.checkErrorAndVisit(downcast<AST::ArrayType>(*item.unnamedType).type());
414                 else
415                     continue;
416                 if (podChecker.hasError())
417                     return false;
418             }
419             return true;
420         };
421         if (!checkPODData(inputItems))
422             return false;
423         if (!checkPODData(outputItems))
424             return false;
425     }
426
427     return true;
428 }
429
430 static bool checkOperatorOverload(const AST::FunctionDefinition& functionDefinition, NameContext& nameContext)
431 {
432     enum class CheckKind {
433         Index,
434         Dot
435     };
436
437     auto checkGetter = [&](CheckKind kind) -> bool {
438         size_t numExpectedParameters = kind == CheckKind::Index ? 2 : 1;
439         if (functionDefinition.parameters().size() != numExpectedParameters)
440             return false;
441         auto& firstParameterUnifyNode = functionDefinition.parameters()[0]->type()->unifyNode();
442         if (is<AST::UnnamedType>(firstParameterUnifyNode)) {
443             auto& unnamedType = downcast<AST::UnnamedType>(firstParameterUnifyNode);
444             if (is<AST::PointerType>(unnamedType) || is<AST::ArrayReferenceType>(unnamedType) || is<AST::ArrayType>(unnamedType))
445                 return false;
446         }
447         if (kind == CheckKind::Index) {
448             auto& secondParameterUnifyNode = functionDefinition.parameters()[1]->type()->unifyNode();
449             if (!is<AST::NamedType>(secondParameterUnifyNode))
450                 return false;
451             auto& namedType = downcast<AST::NamedType>(secondParameterUnifyNode);
452             if (!is<AST::NativeTypeDeclaration>(namedType))
453                 return false;
454             auto& nativeTypeDeclaration = downcast<AST::NativeTypeDeclaration>(namedType);
455             if (!nativeTypeDeclaration.isInt())
456                 return false;
457         }
458         return true;
459     };
460
461     auto checkSetter = [&](CheckKind kind) -> bool {
462         size_t numExpectedParameters = kind == CheckKind::Index ? 3 : 2;
463         if (functionDefinition.parameters().size() != numExpectedParameters)
464             return false;
465         auto& firstArgumentUnifyNode = functionDefinition.parameters()[0]->type()->unifyNode();
466         if (is<AST::UnnamedType>(firstArgumentUnifyNode)) {
467             auto& unnamedType = downcast<AST::UnnamedType>(firstArgumentUnifyNode);
468             if (is<AST::PointerType>(unnamedType) || is<AST::ArrayReferenceType>(unnamedType) || is<AST::ArrayType>(unnamedType))
469                 return false;
470         }
471         if (kind == CheckKind::Index) {
472             auto& secondParameterUnifyNode = functionDefinition.parameters()[1]->type()->unifyNode();
473             if (!is<AST::NamedType>(secondParameterUnifyNode))
474                 return false;
475             auto& namedType = downcast<AST::NamedType>(secondParameterUnifyNode);
476             if (!is<AST::NativeTypeDeclaration>(namedType))
477                 return false;
478             auto& nativeTypeDeclaration = downcast<AST::NativeTypeDeclaration>(namedType);
479             if (!nativeTypeDeclaration.isInt())
480                 return false;
481         }
482         if (!matches(functionDefinition.type(), *functionDefinition.parameters()[0]->type()))
483             return false;
484         auto& valueType = *functionDefinition.parameters()[numExpectedParameters - 1]->type();
485         auto getterName = functionDefinition.name().substring(0, functionDefinition.name().length() - 1);
486         auto* getterFuncs = nameContext.getFunctions(getterName);
487         if (!getterFuncs)
488             return false;
489         Vector<ResolvingType> argumentTypes;
490         Vector<std::reference_wrapper<ResolvingType>> argumentTypeReferences;
491         for (size_t i = 0; i < numExpectedParameters - 1; ++i)
492             argumentTypes.append(*functionDefinition.parameters()[i]->type());
493         for (auto& argumentType : argumentTypes)
494             argumentTypeReferences.append(argumentType);
495         auto* overload = resolveFunctionOverload(*getterFuncs, argumentTypeReferences);
496         if (!overload)
497             return false;
498         auto& resultType = overload->type();
499         return matches(resultType, valueType);
500     };
501
502     auto checkAnder = [&](CheckKind kind) -> bool {
503         size_t numExpectedParameters = kind == CheckKind::Index ? 2 : 1;
504         if (functionDefinition.parameters().size() != numExpectedParameters)
505             return false;
506         {
507             auto& unifyNode = functionDefinition.type().unifyNode();
508             if (!is<AST::UnnamedType>(unifyNode))
509                 return false;
510             auto& unnamedType = downcast<AST::UnnamedType>(unifyNode);
511             if (!is<AST::PointerType>(unnamedType))
512                 return false;
513         }
514         {
515             auto& unifyNode = functionDefinition.parameters()[0]->type()->unifyNode();
516             if (!is<AST::UnnamedType>(unifyNode))
517                 return false;
518             auto& unnamedType = downcast<AST::UnnamedType>(unifyNode);
519             return is<AST::PointerType>(unnamedType) || is<AST::ArrayReferenceType>(unnamedType);
520         }
521     };
522
523     if (!functionDefinition.isOperator())
524         return true;
525     if (functionDefinition.isCast())
526         return true;
527     if (functionDefinition.name() == "operator++" || functionDefinition.name() == "operator--") {
528         return functionDefinition.parameters().size() == 1
529             && matches(*functionDefinition.parameters()[0]->type(), functionDefinition.type());
530     }
531     if (functionDefinition.name() == "operator+" || functionDefinition.name() == "operator-")
532         return functionDefinition.parameters().size() == 1 || functionDefinition.parameters().size() == 2;
533     if (functionDefinition.name() == "operator*"
534         || functionDefinition.name() == "operator/"
535         || functionDefinition.name() == "operator%"
536         || functionDefinition.name() == "operator&"
537         || functionDefinition.name() == "operator|"
538         || functionDefinition.name() == "operator^"
539         || functionDefinition.name() == "operator<<"
540         || functionDefinition.name() == "operator>>")
541         return functionDefinition.parameters().size() == 2;
542     if (functionDefinition.name() == "operator~")
543         return functionDefinition.parameters().size() == 1;
544     if (functionDefinition.name() == "operator[]")
545         return checkGetter(CheckKind::Index);
546     if (functionDefinition.name() == "operator[]=")
547         return checkSetter(CheckKind::Index);
548     if (functionDefinition.name() == "operator&[]")
549         return checkAnder(CheckKind::Index);
550     if (functionDefinition.name().startsWith("operator.")) {
551         if (functionDefinition.name().endsWith("="))
552             return checkSetter(CheckKind::Dot);
553         return checkGetter(CheckKind::Dot);
554     }
555     if (functionDefinition.name().startsWith("operator&."))
556         return checkAnder(CheckKind::Dot);
557     return false;
558 }
559
560 class Checker : public Visitor {
561 public:
562     Checker(const Intrinsics& intrinsics, Program& program)
563         : m_intrinsics(intrinsics)
564         , m_program(program)
565     {
566         auto addFunction = [&] (AST::FunctionDeclaration& function) {
567             AST::NamedType* castReturnType = nullptr;
568             if (function.isCast() && is<AST::NamedType>(function.type().unifyNode()))
569                 castReturnType = &downcast<AST::NamedType>(function.type().unifyNode());
570
571             Vector<std::reference_wrapper<AST::UnnamedType>> types;
572             types.reserveInitialCapacity(function.parameters().size());
573
574             for (auto& param : function.parameters())
575                 types.uncheckedAppend(normalizedTypeForFunctionKey(*param->type()));
576
577             auto addResult = m_functions.add(FunctionKey { function.name(), WTFMove(types), castReturnType }, Vector<std::reference_wrapper<AST::FunctionDeclaration>, 1>());
578             addResult.iterator->value.append(function);
579         };
580
581         for (auto& function : m_program.functionDefinitions())
582             addFunction(function.get());
583         for (auto& function : m_program.nativeFunctionDeclarations())
584             addFunction(function.get());
585     }
586
587     virtual ~Checker() = default;
588
589     void visit(Program&) override;
590
591     Expected<void, Error> assignTypes();
592
593 private:
594     bool checkShaderType(const AST::FunctionDefinition&);
595     bool isBoolType(ResolvingType&);
596     struct RecurseInfo {
597         ResolvingType& resolvingType;
598         const AST::TypeAnnotation typeAnnotation;
599     };
600     Optional<RecurseInfo> recurseAndGetInfo(AST::Expression&, bool requiresLeftValue = false);
601     Optional<RecurseInfo> getInfo(AST::Expression&, bool requiresLeftValue = false);
602     RefPtr<AST::UnnamedType> recurseAndWrapBaseType(AST::PropertyAccessExpression&);
603     bool recurseAndRequireBoolType(AST::Expression&);
604     void assignConcreteType(AST::Expression&, Ref<AST::UnnamedType>, AST::TypeAnnotation);
605     void assignType(AST::Expression&, RefPtr<ResolvableTypeReference>, AST::TypeAnnotation);
606     void forwardType(AST::Expression&, ResolvingType&, AST::TypeAnnotation);
607
608     void visit(AST::FunctionDefinition&) override;
609     void visit(AST::EnumerationDefinition&) override;
610     void visit(AST::TypeReference&) override;
611     void visit(AST::VariableDeclaration&) override;
612     void visit(AST::AssignmentExpression&) override;
613     void visit(AST::ReadModifyWriteExpression&) override;
614     void visit(AST::DereferenceExpression&) override;
615     void visit(AST::MakePointerExpression&) override;
616     void visit(AST::MakeArrayReferenceExpression&) override;
617     void visit(AST::DotExpression&) override;
618     void visit(AST::IndexExpression&) override;
619     void visit(AST::VariableReference&) override;
620     void visit(AST::Return&) override;
621     void visit(AST::PointerType&) override;
622     void visit(AST::ArrayReferenceType&) override;
623     void visit(AST::IntegerLiteral&) override;
624     void visit(AST::UnsignedIntegerLiteral&) override;
625     void visit(AST::FloatLiteral&) override;
626     void visit(AST::NullLiteral&) override;
627     void visit(AST::BooleanLiteral&) override;
628     void visit(AST::EnumerationMemberLiteral&) override;
629     void visit(AST::LogicalNotExpression&) override;
630     void visit(AST::LogicalExpression&) override;
631     void visit(AST::IfStatement&) override;
632     void visit(AST::WhileLoop&) override;
633     void visit(AST::DoWhileLoop&) override;
634     void visit(AST::ForLoop&) override;
635     void visit(AST::SwitchStatement&) override;
636     void visit(AST::CommaExpression&) override;
637     void visit(AST::TernaryExpression&) override;
638     void visit(AST::CallExpression&) override;
639
640     void finishVisiting(AST::PropertyAccessExpression&, ResolvingType* additionalArgumentType = nullptr);
641
642     AST::FunctionDeclaration* resolveFunction(Vector<std::reference_wrapper<ResolvingType>>& types, const String& name, CodeLocation, AST::NamedType* castReturnType = nullptr);
643
644     RefPtr<AST::UnnamedType> argumentTypeForAndOverload(AST::UnnamedType& baseType, AST::AddressSpace);
645
646     AST::UnnamedType& wrappedFloatType()
647     {
648         if (!m_wrappedFloatType)
649             m_wrappedFloatType = AST::TypeReference::wrap({ }, m_intrinsics.floatType());
650         return *m_wrappedFloatType;
651     }
652
653     AST::UnnamedType& genericPointerType()
654     {
655         if (!m_genericPointerType)
656             m_genericPointerType = AST::PointerType::create({ }, AST::AddressSpace::Thread, AST::TypeReference::wrap({ }, m_intrinsics.floatType()));
657         return *m_genericPointerType;
658     }
659
660     AST::UnnamedType& normalizedTypeForFunctionKey(AST::UnnamedType& type)
661     {
662         auto* unifyNode = &type.unifyNode();
663         if (unifyNode == &m_intrinsics.uintType() || unifyNode == &m_intrinsics.intType())
664             return wrappedFloatType();
665
666         if (is<AST::ReferenceType>(type))
667             return genericPointerType();
668
669         return type;
670     }
671
672     RefPtr<AST::TypeReference> m_wrappedFloatType;
673     RefPtr<AST::UnnamedType> m_genericPointerType;
674     HashMap<AST::Expression*, std::unique_ptr<ResolvingType>> m_typeMap;
675     HashSet<String> m_vertexEntryPoints;
676     HashSet<String> m_fragmentEntryPoints;
677     HashSet<String> m_computeEntryPoints;
678     const Intrinsics& m_intrinsics;
679     Program& m_program;
680     AST::FunctionDefinition* m_currentFunction { nullptr };
681     HashMap<FunctionKey, Vector<std::reference_wrapper<AST::FunctionDeclaration>, 1>, FunctionKey::Hash, FunctionKey::Traits> m_functions;
682     HashMap<AndOverloadTypeKey, RefPtr<AST::UnnamedType>, AndOverloadTypeKey::Hash, AndOverloadTypeKey::Traits> m_andOverloadTypeMap;
683 };
684
685 void Checker::visit(Program& program)
686 {
687     // These visiting functions might add new global statements, so don't use foreach syntax.
688     for (size_t i = 0; i < program.typeDefinitions().size(); ++i)
689         checkErrorAndVisit(program.typeDefinitions()[i]);
690     for (size_t i = 0; i < program.structureDefinitions().size(); ++i)
691         checkErrorAndVisit(program.structureDefinitions()[i]);
692     for (size_t i = 0; i < program.enumerationDefinitions().size(); ++i)
693         checkErrorAndVisit(program.enumerationDefinitions()[i]);
694     for (size_t i = 0; i < program.nativeTypeDeclarations().size(); ++i)
695         checkErrorAndVisit(program.nativeTypeDeclarations()[i]);
696
697     for (size_t i = 0; i < program.functionDefinitions().size(); ++i)
698         checkErrorAndVisit(program.functionDefinitions()[i]);
699     for (size_t i = 0; i < program.nativeFunctionDeclarations().size(); ++i)
700         checkErrorAndVisit(program.nativeFunctionDeclarations()[i]);
701 }
702
703 Expected<void, Error> Checker::assignTypes()
704 {
705     for (auto& keyValuePair : m_typeMap) {
706         auto success = keyValuePair.value->visit(WTF::makeVisitor([&](Ref<AST::UnnamedType>& unnamedType) -> bool {
707             keyValuePair.key->setType(unnamedType.copyRef());
708             return true;
709         }, [&](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> bool {
710             if (!resolvableTypeReference->resolvableType().maybeResolvedType()) {
711                 if (!static_cast<bool>(commit(resolvableTypeReference->resolvableType())))
712                     return false;
713             }
714             keyValuePair.key->setType(resolvableTypeReference->resolvableType().resolvedType());
715             return true;
716         }));
717         if (!success)
718             return makeUnexpected(Error("Could not resolve the type of a constant."));
719     }
720
721     return { };
722 }
723
724 bool Checker::checkShaderType(const AST::FunctionDefinition& functionDefinition)
725 {
726     switch (*functionDefinition.entryPointType()) {
727     case AST::EntryPointType::Vertex:
728         return static_cast<bool>(m_vertexEntryPoints.add(functionDefinition.name()));
729     case AST::EntryPointType::Fragment:
730         return static_cast<bool>(m_fragmentEntryPoints.add(functionDefinition.name()));
731     case AST::EntryPointType::Compute:
732         return static_cast<bool>(m_computeEntryPoints.add(functionDefinition.name()));
733     }
734 }
735
736 void Checker::visit(AST::FunctionDefinition& functionDefinition)
737 {
738     m_currentFunction = &functionDefinition;
739     if (functionDefinition.entryPointType()) {
740         if (!checkShaderType(functionDefinition)) {
741             setError(Error("Duplicate entrypoint function.", functionDefinition.codeLocation()));
742             return;
743         }
744         auto entryPointItems = gatherEntryPointItems(m_intrinsics, functionDefinition);
745         if (!entryPointItems) {
746             setError(entryPointItems.error());
747             return;
748         }
749         if (!checkSemantics(entryPointItems->inputs, entryPointItems->outputs, functionDefinition.entryPointType(), m_intrinsics)) {
750             setError(Error("Bad semantics for entrypoint.", functionDefinition.codeLocation()));
751             return;
752         }
753     }
754     if (!checkOperatorOverload(functionDefinition, m_program.nameContext())) {
755         setError(Error("Operator does not match expected signature.", functionDefinition.codeLocation()));
756         return;
757     }
758
759     Visitor::visit(functionDefinition);
760 }
761
762 static RefPtr<AST::UnnamedType> matchAndCommit(ResolvingType& left, ResolvingType& right)
763 {
764     return left.visit(WTF::makeVisitor([&](Ref<AST::UnnamedType>& left) -> RefPtr<AST::UnnamedType> {
765         return right.visit(WTF::makeVisitor([&](Ref<AST::UnnamedType>& right) -> RefPtr<AST::UnnamedType> {
766             if (matches(left, right))
767                 return left.copyRef();
768             return nullptr;
769         }, [&](RefPtr<ResolvableTypeReference>& right) -> RefPtr<AST::UnnamedType> {
770             return matchAndCommit(left, right->resolvableType());
771         }));
772     }, [&](RefPtr<ResolvableTypeReference>& left) -> RefPtr<AST::UnnamedType> {
773         return right.visit(WTF::makeVisitor([&](Ref<AST::UnnamedType>& right) -> RefPtr<AST::UnnamedType> {
774             return matchAndCommit(right, left->resolvableType());
775         }, [&](RefPtr<ResolvableTypeReference>& right) -> RefPtr<AST::UnnamedType> {
776             return matchAndCommit(left->resolvableType(), right->resolvableType());
777         }));
778     }));
779 }
780
781 static RefPtr<AST::UnnamedType> matchAndCommit(ResolvingType& resolvingType, AST::UnnamedType& unnamedType)
782 {
783     return resolvingType.visit(WTF::makeVisitor([&](Ref<AST::UnnamedType>& resolvingType) -> RefPtr<AST::UnnamedType> {
784         if (matches(unnamedType, resolvingType))
785             return &unnamedType;
786         return nullptr;
787     }, [&](RefPtr<ResolvableTypeReference>& resolvingType) -> RefPtr<AST::UnnamedType> {
788         return matchAndCommit(unnamedType, resolvingType->resolvableType());
789     }));
790 }
791
792 static RefPtr<AST::UnnamedType> commit(ResolvingType& resolvingType)
793 {
794     return resolvingType.visit(WTF::makeVisitor([&](Ref<AST::UnnamedType>& unnamedType) -> RefPtr<AST::UnnamedType> {
795         return unnamedType.copyRef();
796     }, [&](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> RefPtr<AST::UnnamedType> {
797         if (!resolvableTypeReference->resolvableType().maybeResolvedType())
798             return commit(resolvableTypeReference->resolvableType());
799         return &resolvableTypeReference->resolvableType().resolvedType();
800     }));
801 }
802
// Resolves a call to `name` with the given argument types to a function declaration.
// First, a FunctionKey is built from normalized argument types and looked up in
// m_functions; resolveFunctionOverload() then chooses among those candidates using the
// original argument types. If that fails, a native function is instantiated on demand via
// resolveByInstantiation() and appended to the program.
// Returns null when nothing matches. Also returns null, after setting an error, when an
// argument's type cannot be resolved.
AST::FunctionDeclaration* Checker::resolveFunction(Vector<std::reference_wrapper<ResolvingType>>& types, const String& name, CodeLocation location, AST::NamedType* castReturnType)
{
    Vector<std::reference_wrapper<AST::UnnamedType>> unnamedTypes;
    unnamedTypes.reserveInitialCapacity(types.size());

    for (auto resolvingType : types) {
        // Pick a representative type for each argument. Unresolved numeric and null
        // literals get stand-in types here instead of being committed.
        AST::UnnamedType* type = resolvingType.get().visit(WTF::makeVisitor([&](Ref<AST::UnnamedType>& unnamedType) -> AST::UnnamedType* {
            return unnamedType.ptr();
        }, [&](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> AST::UnnamedType* {
            if (resolvableTypeReference->resolvableType().maybeResolvedType())
                return &resolvableTypeReference->resolvableType().resolvedType();

            // All numeric literals share the wrapped float type as their key component.
            if (resolvableTypeReference->resolvableType().isFloatLiteralType()
                || resolvableTypeReference->resolvableType().isIntegerLiteralType()
                || resolvableTypeReference->resolvableType().isUnsignedIntegerLiteralType())
                return &wrappedFloatType();

            // A null literal is keyed like any reference type.
            if (resolvableTypeReference->resolvableType().isNullLiteralType())
                return &genericPointerType();

            // NOTE(review): the RefPtr returned by commit() is a temporary; this raw pointer relies on
            // the committed type being kept alive elsewhere (presumably as the resolvable type's resolved type).
            return commit(resolvableTypeReference->resolvableType()).get();
        }));

        if (!type) {
            setError(Error("Could not resolve the type of a constant."));
            return nullptr;
        }

        unnamedTypes.uncheckedAppend(normalizedTypeForFunctionKey(*type));
    }

    {
        // The key only narrows the candidate set; overload selection uses `types`, not the normalized types.
        auto iter = m_functions.find(FunctionKey { name, WTFMove(unnamedTypes), castReturnType });
        if (iter != m_functions.end()) {
            if (AST::FunctionDeclaration* function = resolveFunctionOverload(iter->value, types, castReturnType))
                return function;
        }
    }

    // Fall back to instantiating a native function for these argument types.
    if (auto newFunction = resolveByInstantiation(name, location, types, m_intrinsics)) {
        m_program.append(WTFMove(*newFunction));
        return &m_program.nativeFunctionDeclarations().last();
    }

    return nullptr;
}
849
850 void Checker::visit(AST::EnumerationDefinition& enumerationDefinition)
851 {
852     bool isSigned;
853     auto* baseType = ([&]() -> AST::NativeTypeDeclaration* {
854         checkErrorAndVisit(enumerationDefinition.type());
855         auto& baseType = enumerationDefinition.type().unifyNode();
856         if (!is<AST::NamedType>(baseType))
857             return nullptr;
858         auto& namedType = downcast<AST::NamedType>(baseType);
859         if (!is<AST::NativeTypeDeclaration>(namedType))
860             return nullptr;
861         auto& nativeTypeDeclaration = downcast<AST::NativeTypeDeclaration>(namedType);
862         if (!nativeTypeDeclaration.isInt())
863             return nullptr;
864         isSigned = nativeTypeDeclaration.isSigned();
865         return &nativeTypeDeclaration;
866     })();
867     if (!baseType) {
868         setError(Error("Invalid base type for enum.", enumerationDefinition.codeLocation()));
869         return;
870     }
871
872     auto enumerationMembers = enumerationDefinition.enumerationMembers();
873
874     for (auto& member : enumerationMembers) {
875         int64_t value = member.get().value();
876         if (isSigned) {
877             if (static_cast<int64_t>(static_cast<int32_t>(value)) != value) {
878                 setError(Error("Invalid enumeration value.", member.get().codeLocation()));
879                 return;
880             }
881         } else {
882             if (static_cast<int64_t>(static_cast<uint32_t>(value)) != value) {
883                 setError(Error("Invalid enumeration value.", member.get().codeLocation()));
884                 return;
885             }
886         }
887     }
888
889     for (size_t i = 0; i < enumerationMembers.size(); ++i) {
890         auto value = enumerationMembers[i].get().value();
891         for (size_t j = i + 1; j < enumerationMembers.size(); ++j) {
892             auto otherValue = enumerationMembers[j].get().value();
893             if (value == otherValue) {
894                 setError(Error("Cannot declare duplicate enumeration values.", enumerationMembers[j].get().codeLocation()));
895                 return;
896             }
897         }
898     }
899
900     bool foundZero = false;
901     for (auto& member : enumerationMembers) {
902         if (!member.get().value()) {
903             foundZero = true;
904             break;
905         }
906     }
907     if (!foundZero) {
908         setError(Error("enum definition must contain a zero value.", enumerationDefinition.codeLocation()));
909         return;
910     }
911 }
912
913 void Checker::visit(AST::TypeReference& typeReference)
914 {
915     ASSERT(typeReference.maybeResolvedType());
916
917     for (auto& typeArgument : typeReference.typeArguments())
918         checkErrorAndVisit(typeArgument);
919 }
920
921 auto Checker::recurseAndGetInfo(AST::Expression& expression, bool requiresLeftValue) -> Optional<RecurseInfo>
922 {
923     Visitor::visit(expression);
924     if (hasError())
925         return WTF::nullopt;
926     return getInfo(expression, requiresLeftValue);
927 }
928
929 auto Checker::getInfo(AST::Expression& expression, bool requiresLeftValue) -> Optional<RecurseInfo>
930 {
931     auto typeIterator = m_typeMap.find(&expression);
932     ASSERT(typeIterator != m_typeMap.end());
933
934     const auto& typeAnnotation = expression.typeAnnotation();
935     if (requiresLeftValue && typeAnnotation.isRightValue()) {
936         setError(Error("Unexpected rvalue.", expression.codeLocation()));
937         return WTF::nullopt;
938     }
939     return {{ *typeIterator->value, typeAnnotation }};
940 }
941
942 void Checker::visit(AST::VariableDeclaration& variableDeclaration)
943 {
944     // ReadModifyWriteExpressions are the only place where anonymous variables exist,
945     // and that doesn't recurse on the anonymous variables, so we can assume the variable has a type.
946     checkErrorAndVisit(*variableDeclaration.type());
947     if (matches(*variableDeclaration.type(), m_intrinsics.voidType())) {
948         setError(Error("Variables can't have void type.", variableDeclaration.codeLocation()));
949         return;
950     }
951     if (variableDeclaration.initializer()) {
952         auto& lhsType = *variableDeclaration.type();
953         auto initializerInfo = recurseAndGetInfo(*variableDeclaration.initializer());
954         if (!initializerInfo)
955             return;
956         if (!matchAndCommit(initializerInfo->resolvingType, lhsType)) {
957             setError(Error("Declared variable type does not match its initializer's type.", variableDeclaration.codeLocation()));
958             return;
959         }
960     }
961 }
962
963 void Checker::assignConcreteType(AST::Expression& expression, Ref<AST::UnnamedType> unnamedType, AST::TypeAnnotation typeAnnotation = AST::RightValue())
964 {
965     auto addResult = m_typeMap.add(&expression, makeUnique<ResolvingType>(WTFMove(unnamedType)));
966     ASSERT_UNUSED(addResult, addResult.isNewEntry);
967     expression.setTypeAnnotation(WTFMove(typeAnnotation));
968 }
969
970 void Checker::assignType(AST::Expression& expression, RefPtr<ResolvableTypeReference> resolvableTypeReference, AST::TypeAnnotation typeAnnotation = AST::RightValue())
971 {
972     auto addResult = m_typeMap.add(&expression, makeUnique<ResolvingType>(WTFMove(resolvableTypeReference)));
973     ASSERT_UNUSED(addResult, addResult.isNewEntry);
974     expression.setTypeAnnotation(WTFMove(typeAnnotation));
975 }
976
977 void Checker::forwardType(AST::Expression& expression, ResolvingType& resolvingType, AST::TypeAnnotation typeAnnotation = AST::RightValue())
978 {
979     resolvingType.visit(WTF::makeVisitor([&](Ref<AST::UnnamedType>& result) {
980         auto addResult = m_typeMap.add(&expression, makeUnique<ResolvingType>(result.copyRef()));
981         ASSERT_UNUSED(addResult, addResult.isNewEntry);
982     }, [&](RefPtr<ResolvableTypeReference>& result) {
983         auto addResult = m_typeMap.add(&expression, makeUnique<ResolvingType>(result.copyRef()));
984         ASSERT_UNUSED(addResult, addResult.isNewEntry);
985     }));
986     expression.setTypeAnnotation(WTFMove(typeAnnotation));
987 }
988
989 void Checker::visit(AST::AssignmentExpression& assignmentExpression)
990 {
991     auto leftInfo = recurseAndGetInfo(assignmentExpression.left(), true);
992     if (!leftInfo)
993         return;
994
995     auto rightInfo = recurseAndGetInfo(assignmentExpression.right());
996     if (!rightInfo)
997         return;
998
999     auto resultType = matchAndCommit(leftInfo->resolvingType, rightInfo->resolvingType);
1000     if (!resultType) {
1001         setError(Error("Left hand side of assignment does not match the type of the right hand side.", assignmentExpression.codeLocation()));
1002         return;
1003     }
1004
1005     assignConcreteType(assignmentExpression, *resultType);
1006 }
1007
1008 void Checker::visit(AST::ReadModifyWriteExpression& readModifyWriteExpression)
1009 {
1010     auto leftValueInfo = recurseAndGetInfo(readModifyWriteExpression.leftValue(), true);
1011     if (!leftValueInfo)
1012         return;
1013
1014     readModifyWriteExpression.oldValue().setType(*leftValueInfo->resolvingType.getUnnamedType());
1015
1016     auto newValueInfo = recurseAndGetInfo(readModifyWriteExpression.newValueExpression());
1017     if (!newValueInfo)
1018         return;
1019
1020     if (RefPtr<AST::UnnamedType> matchedType = matchAndCommit(leftValueInfo->resolvingType, newValueInfo->resolvingType))
1021         readModifyWriteExpression.newValue().setType(*matchedType);
1022     else {
1023         setError(Error("Base of the read-modify-write expression does not match the type of the new value.", readModifyWriteExpression.codeLocation()));
1024         return;
1025     }
1026
1027     auto resultInfo = recurseAndGetInfo(readModifyWriteExpression.resultExpression());
1028     if (!resultInfo)
1029         return;
1030
1031     forwardType(readModifyWriteExpression, resultInfo->resolvingType);
1032 }
1033
1034 static AST::UnnamedType* getUnnamedType(ResolvingType& resolvingType)
1035 {
1036     return resolvingType.visit(WTF::makeVisitor([](Ref<AST::UnnamedType>& type) -> AST::UnnamedType* {
1037         return type.ptr();
1038     }, [](RefPtr<ResolvableTypeReference>& type) -> AST::UnnamedType* {
1039         // FIXME: If the type isn't committed, should we just commit() it now?
1040         return type->resolvableType().maybeResolvedType();
1041     }));
1042 }
1043
1044 void Checker::visit(AST::DereferenceExpression& dereferenceExpression)
1045 {
1046     auto pointerInfo = recurseAndGetInfo(dereferenceExpression.pointer());
1047     if (!pointerInfo)
1048         return;
1049
1050     auto* unnamedType = getUnnamedType(pointerInfo->resolvingType);
1051
1052     auto* pointerType = ([&](AST::UnnamedType* unnamedType) -> AST::PointerType* {
1053         if (!unnamedType)
1054             return nullptr;
1055         auto& unifyNode = unnamedType->unifyNode();
1056         if (!is<AST::UnnamedType>(unifyNode))
1057             return nullptr;
1058         auto& unnamedUnifyType = downcast<AST::UnnamedType>(unifyNode);
1059         if (!is<AST::PointerType>(unnamedUnifyType))
1060             return nullptr;
1061         return &downcast<AST::PointerType>(unnamedUnifyType);
1062     })(unnamedType);
1063     if (!pointerType) {
1064         setError(Error("Cannot dereference a non-pointer type.", dereferenceExpression.codeLocation()));
1065         return;
1066     }
1067
1068     assignConcreteType(dereferenceExpression, pointerType->elementType(), AST::LeftValue { pointerType->addressSpace() });
1069 }
1070
1071 void Checker::visit(AST::MakePointerExpression& makePointerExpression)
1072 {
1073     auto leftValueInfo = recurseAndGetInfo(makePointerExpression.leftValue(), true);
1074     if (!leftValueInfo)
1075         return;
1076
1077     auto leftAddressSpace = leftValueInfo->typeAnnotation.leftAddressSpace();
1078     if (!leftAddressSpace) {
1079         setError(Error("Cannot take the address of a non lvalue.", makePointerExpression.codeLocation()));
1080         return;
1081     }
1082
1083     auto* leftValueType = getUnnamedType(leftValueInfo->resolvingType);
1084     if (!leftValueType) {
1085         setError(Error("Cannot take the address of a value without a type.", makePointerExpression.codeLocation()));
1086         return;
1087     }
1088
1089     assignConcreteType(makePointerExpression, AST::PointerType::create(makePointerExpression.codeLocation(), *leftAddressSpace, *leftValueType));
1090 }
1091
1092 void Checker::visit(AST::MakeArrayReferenceExpression& makeArrayReferenceExpression)
1093 {
1094     auto leftValueInfo = recurseAndGetInfo(makeArrayReferenceExpression.leftValue());
1095     if (!leftValueInfo)
1096         return;
1097
1098     auto* leftValueType = getUnnamedType(leftValueInfo->resolvingType);
1099     if (!leftValueType) {
1100         setError(Error("Cannot make an array reference of a value without a type.", makeArrayReferenceExpression.codeLocation()));
1101         return;
1102     }
1103
1104     auto& unifyNode = leftValueType->unifyNode();
1105     if (is<AST::UnnamedType>(unifyNode)) {
1106         auto& unnamedType = downcast<AST::UnnamedType>(unifyNode);
1107         if (is<AST::PointerType>(unnamedType)) {
1108             auto& pointerType = downcast<AST::PointerType>(unnamedType);
1109             // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198163 Save the fact that we're not targetting the item; we're targetting the item's inner element.
1110             assignConcreteType(makeArrayReferenceExpression, AST::ArrayReferenceType::create(makeArrayReferenceExpression.codeLocation(), pointerType.addressSpace(), pointerType.elementType()));
1111             return;
1112         }
1113
1114         auto leftAddressSpace = leftValueInfo->typeAnnotation.leftAddressSpace();
1115         if (!leftAddressSpace) {
1116             setError(Error("Cannot make an array reference from a non-left-value.", makeArrayReferenceExpression.codeLocation()));
1117             return;
1118         }
1119
1120         if (is<AST::ArrayType>(unnamedType)) {
1121             auto& arrayType = downcast<AST::ArrayType>(unnamedType);
1122             // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198163 Save the number of elements.
1123             assignConcreteType(makeArrayReferenceExpression, AST::ArrayReferenceType::create(makeArrayReferenceExpression.codeLocation(), *leftAddressSpace, arrayType.type()));
1124             return;
1125         }
1126     }
1127
1128     auto leftAddressSpace = leftValueInfo->typeAnnotation.leftAddressSpace();
1129     if (!leftAddressSpace) {
1130         setError(Error("Cannot make an array reference from a non-left-value.", makeArrayReferenceExpression.codeLocation()));
1131         return;
1132     }
1133
1134     assignConcreteType(makeArrayReferenceExpression, AST::ArrayReferenceType::create(makeArrayReferenceExpression.codeLocation(), *leftAddressSpace, *leftValueType));
1135 }
1136
1137 RefPtr<AST::UnnamedType> Checker::argumentTypeForAndOverload(AST::UnnamedType& baseType, AST::AddressSpace addressSpace)
1138 {
1139     AndOverloadTypeKey key { baseType, addressSpace };
1140     {
1141         auto iter = m_andOverloadTypeMap.find(key);
1142         if (iter != m_andOverloadTypeMap.end())
1143             return iter->value;
1144     }
1145
1146     auto createArgumentType = [&] () -> RefPtr<AST::UnnamedType> {
1147         auto& unifyNode = baseType.unifyNode();
1148         if (is<AST::NamedType>(unifyNode)) {
1149             auto& namedType = downcast<AST::NamedType>(unifyNode);
1150             return { AST::PointerType::create(namedType.codeLocation(), addressSpace, AST::TypeReference::wrap(namedType.codeLocation(), namedType)) };
1151         }
1152
1153         auto& unnamedType = downcast<AST::UnnamedType>(unifyNode);
1154
1155         if (is<AST::ArrayReferenceType>(unnamedType))
1156             return &unnamedType;
1157
1158         if (is<AST::ArrayType>(unnamedType))
1159             return { AST::ArrayReferenceType::create(unnamedType.codeLocation(), addressSpace, downcast<AST::ArrayType>(unnamedType).type()) };
1160
1161         if (is<AST::PointerType>(unnamedType))
1162             return nullptr;
1163
1164         return { AST::PointerType::create(unnamedType.codeLocation(), addressSpace, unnamedType) };
1165     };
1166
1167     auto result = createArgumentType();
1168     m_andOverloadTypeMap.add(key, result);
1169     return result;
1170 }
1171
// Type-checks a property access (dot or index expression) by resolving the functions that
// implement it:
// - a getter, called with the base by value;
// - an ander, only when the base is an lvalue, in the base's address space;
// - a thread ander, called with a `thread` pointer to the base;
// - a setter, called with the base plus a value of the field type.
// `additionalArgumentType` is the index argument for index expressions and null for dot
// expressions; it is appended to every lookup. The resolved functions are stored on the
// expression. The expression is then typed with the field type and an lvalue/rvalue annotation.
void Checker::finishVisiting(AST::PropertyAccessExpression& propertyAccessExpression, ResolvingType* additionalArgumentType)
{
    auto baseInfo = recurseAndGetInfo(propertyAccessExpression.base());
    if (!baseInfo)
        return;
    // Overload lookup needs a concrete base type, so commit the base's type now.
    auto baseUnnamedType = commit(baseInfo->resolvingType);
    if (!baseUnnamedType) {
        setError(Error("Cannot resolve the type of the base of a property access expression.", propertyAccessExpression.codeLocation()));
        return;
    }

    // Getter lookup: (base[, index]); the field type is the getter's return type.
    AST::FunctionDeclaration* getterFunction = nullptr;
    RefPtr<AST::UnnamedType> getterReturnType = nullptr;
    {
        Vector<std::reference_wrapper<ResolvingType>> getterArgumentTypes { baseInfo->resolvingType };
        if (additionalArgumentType)
            getterArgumentTypes.append(*additionalArgumentType);
        auto getterName = propertyAccessExpression.getterFunctionName();
        getterFunction = resolveFunction(getterArgumentTypes, getterName, propertyAccessExpression.codeLocation());
        if (hasError())
            return;
        if (getterFunction)
            getterReturnType = &getterFunction->type();
    }

    // Ander lookup (lvalue bases only): the first argument is argumentTypeForAndOverload() in the
    // base's address space; the field type is the element type of the returned pointer.
    AST::FunctionDeclaration* anderFunction = nullptr;
    RefPtr<AST::UnnamedType> anderReturnType = nullptr;
    auto leftAddressSpace = baseInfo->typeAnnotation.leftAddressSpace();
    if (leftAddressSpace) {
        if (auto argumentTypeForAndOverload = this->argumentTypeForAndOverload(*baseUnnamedType, *leftAddressSpace)) {
            ResolvingType argumentType = { Ref<AST::UnnamedType>(*argumentTypeForAndOverload) };
            Vector<std::reference_wrapper<ResolvingType>> anderArgumentTypes { argumentType };
            if (additionalArgumentType)
                anderArgumentTypes.append(*additionalArgumentType);
            auto anderName = propertyAccessExpression.anderFunctionName();
            anderFunction = resolveFunction(anderArgumentTypes, anderName, propertyAccessExpression.codeLocation());
            if (hasError())
                return;
            if (anderFunction)
                anderReturnType = &downcast<AST::PointerType>(anderFunction->type()).elementType(); // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198164 Enforce the return of anders will always be a pointer
        }
    }

    // Thread ander lookup (any base): the first argument is a `thread` pointer to the base type.
    // NOTE(review): argumentTypeForAndOverload is only used as a gate here (it is null for pointer
    // bases). The argument actually passed is always a PointerType to the base, even where
    // argumentTypeForAndOverload would give an array reference (array bases) -- confirm this is intended.
    AST::FunctionDeclaration* threadAnderFunction = nullptr;
    RefPtr<AST::UnnamedType> threadAnderReturnType = nullptr;
    if (auto argumentTypeForAndOverload = this->argumentTypeForAndOverload(*baseUnnamedType, AST::AddressSpace::Thread)) {
        ResolvingType argumentType = { Ref<AST::UnnamedType>(AST::PointerType::create(propertyAccessExpression.codeLocation(), AST::AddressSpace::Thread, *baseUnnamedType)) };
        Vector<std::reference_wrapper<ResolvingType>> threadAnderArgumentTypes { argumentType };
        if (additionalArgumentType)
            threadAnderArgumentTypes.append(*additionalArgumentType);
        auto anderName = propertyAccessExpression.anderFunctionName();
        threadAnderFunction = resolveFunction(threadAnderArgumentTypes, anderName, propertyAccessExpression.codeLocation());
        if (hasError())
            return;
        if (threadAnderFunction)
            threadAnderReturnType = &downcast<AST::PointerType>(threadAnderFunction->type()).elementType(); // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198164 Enforce the return of anders will always be a pointer
    }

    // Consistency rules:
    // - at least one way of reading the property must exist for this kind of base;
    // - a thread ander and a getter are mutually exclusive;
    // - all readers that exist must agree on the field type.
    if (leftAddressSpace && !anderFunction && !getterFunction) {
        setError(Error("Property access instruction must either have an ander or a getter.", propertyAccessExpression.codeLocation()));
        return;
    }

    if (!leftAddressSpace && !threadAnderFunction && !getterFunction) {
        setError(Error("Property access instruction must either have a thread ander or a getter.", propertyAccessExpression.codeLocation()));
        return;
    }

    if (threadAnderFunction && getterFunction) {
        setError(Error("Cannot have both a thread ander and a getter.", propertyAccessExpression.codeLocation()));
        return;
    }

    if (anderFunction && threadAnderFunction && !matches(*anderReturnType, *threadAnderReturnType)) {
        setError(Error("Return type of ander must match the return type of the thread ander.", propertyAccessExpression.codeLocation()));
        return;
    }

    if (getterFunction && anderFunction && !matches(*getterReturnType, *anderReturnType)) {
        setError(Error("Return type of ander must match the return type of the getter.", propertyAccessExpression.codeLocation()));
        return;
    }

    if (getterFunction && threadAnderFunction && !matches(*getterReturnType, *threadAnderReturnType)) {
        setError(Error("Return type of the thread ander must match the return type of the getter.", propertyAccessExpression.codeLocation()));
        return;
    }

    // The checks above guarantee at least one of these is non-null: prefer getter, then ander, then thread ander.
    Ref<AST::UnnamedType> fieldType = getterReturnType ? *getterReturnType : anderReturnType ? *anderReturnType : *threadAnderReturnType;

    // Setter lookup: (base[, index], field value).
    // NOTE(review): setterReturnType is assigned but never read.
    AST::FunctionDeclaration* setterFunction = nullptr;
    AST::UnnamedType* setterReturnType = nullptr;
    {
        ResolvingType fieldResolvingType(fieldType.copyRef());
        Vector<std::reference_wrapper<ResolvingType>> setterArgumentTypes { baseInfo->resolvingType };
        if (additionalArgumentType)
            setterArgumentTypes.append(*additionalArgumentType);
        setterArgumentTypes.append(fieldResolvingType);
        auto setterName = propertyAccessExpression.setterFunctionName();
        setterFunction = resolveFunction(setterArgumentTypes, setterName, propertyAccessExpression.codeLocation());
        if (hasError())
            return;
        if (setterFunction)
            setterReturnType = &setterFunction->type();
    }

    if (setterFunction && !getterFunction) {
        setError(Error("Cannot define a setter function without a corresponding getter.", propertyAccessExpression.codeLocation()));
        return;
    }

    propertyAccessExpression.setGetterFunction(getterFunction);
    propertyAccessExpression.setAnderFunction(anderFunction);
    propertyAccessExpression.setThreadAnderFunction(threadAnderFunction);
    propertyAccessExpression.setSetterFunction(setterFunction);

    // Result annotation:
    // - lvalue base: an ander gives a LeftValue in the ander's address space; otherwise a
    //   setter gives an AbstractLeftValue;
    // - non-rvalue base without an address space (presumably an AbstractLeftValue): a setter
    //   or thread ander gives an AbstractLeftValue;
    // - anything else is an rvalue.
    // (The inner `leftAddressSpace` shadows the outer variable of the same value.)
    AST::TypeAnnotation typeAnnotation = AST::RightValue();
    if (auto leftAddressSpace = baseInfo->typeAnnotation.leftAddressSpace()) {
        if (anderFunction)
            typeAnnotation = AST::LeftValue { downcast<AST::ReferenceType>(anderFunction->type()).addressSpace() };
        else if (setterFunction)
            typeAnnotation = AST::AbstractLeftValue();
    } else if (!baseInfo->typeAnnotation.isRightValue() && (setterFunction || threadAnderFunction))
        typeAnnotation = AST::AbstractLeftValue();
    assignConcreteType(propertyAccessExpression, WTFMove(fieldType), WTFMove(typeAnnotation));
}
1298
1299 void Checker::visit(AST::DotExpression& dotExpression)
1300 {
1301     finishVisiting(dotExpression);
1302 }
1303
1304 void Checker::visit(AST::IndexExpression& indexExpression)
1305 {
1306     auto baseInfo = recurseAndGetInfo(indexExpression.indexExpression());
1307     if (!baseInfo)
1308         return;
1309     finishVisiting(indexExpression, &baseInfo->resolvingType);
1310 }
1311
1312 void Checker::visit(AST::VariableReference& variableReference)
1313 {
1314     ASSERT(variableReference.variable());
1315     ASSERT(variableReference.variable()->type());
1316     
1317     assignConcreteType(variableReference, *variableReference.variable()->type(), AST::LeftValue { AST::AddressSpace::Thread });
1318 }
1319
1320 void Checker::visit(AST::Return& returnStatement)
1321 {
1322     if (returnStatement.value()) {
1323         auto valueInfo = recurseAndGetInfo(*returnStatement.value());
1324         if (!valueInfo)
1325             return;
1326         if (!matchAndCommit(valueInfo->resolvingType, m_currentFunction->type()))
1327             setError(Error("Type of the return value must match the return type of the function.", returnStatement.codeLocation()));
1328         return;
1329     }
1330
1331     if (!matches(m_currentFunction->type(), m_intrinsics.voidType()))
1332         setError(Error("Cannot return a value from a void function.", returnStatement.codeLocation()));
1333 }
1334
// Intentionally empty override: the checker does nothing when it encounters a pointer type,
// so it does not descend into the pointee from here (presumably the base Visitor would
// otherwise recurse into it — NOTE(review): confirm against Visitor::visit(PointerType&)).
void Checker::visit(AST::PointerType&)
{
    // Following pointer types can cause infinite loops because of data structures
    // like linked lists.
    // FIXME: Make sure this function should be empty
}
1341
// Intentionally empty override: the checker does nothing when it encounters an array
// reference type, so it does not descend into the element type from here (presumably the
// base Visitor would otherwise recurse — NOTE(review): confirm against Visitor's default).
void Checker::visit(AST::ArrayReferenceType&)
{
    // Following array reference types can cause infinite loops because of data
    // structures like linked lists.
    // FIXME: Make sure this function should be empty
}
1348
1349 void Checker::visit(AST::IntegerLiteral& integerLiteral)
1350 {
1351     assignType(integerLiteral, adoptRef(*new ResolvableTypeReference(integerLiteral.type())));
1352 }
1353
1354 void Checker::visit(AST::UnsignedIntegerLiteral& unsignedIntegerLiteral)
1355 {
1356     assignType(unsignedIntegerLiteral, adoptRef(*new ResolvableTypeReference(unsignedIntegerLiteral.type())));
1357 }
1358
1359 void Checker::visit(AST::FloatLiteral& floatLiteral)
1360 {
1361     assignType(floatLiteral, adoptRef(*new ResolvableTypeReference(floatLiteral.type())));
1362 }
1363
1364 void Checker::visit(AST::NullLiteral& nullLiteral)
1365 {
1366     assignType(nullLiteral, adoptRef(*new ResolvableTypeReference(nullLiteral.type())));
1367 }
1368
1369 void Checker::visit(AST::BooleanLiteral& booleanLiteral)
1370 {
1371     assignConcreteType(booleanLiteral, AST::TypeReference::wrap(booleanLiteral.codeLocation(), m_intrinsics.boolType()));
1372 }
1373
1374 void Checker::visit(AST::EnumerationMemberLiteral& enumerationMemberLiteral)
1375 {
1376     ASSERT(enumerationMemberLiteral.enumerationDefinition());
1377     auto& enumerationDefinition = *enumerationMemberLiteral.enumerationDefinition();
1378     assignConcreteType(enumerationMemberLiteral, AST::TypeReference::wrap(enumerationMemberLiteral.codeLocation(), enumerationDefinition));
1379 }
1380
1381 bool Checker::isBoolType(ResolvingType& resolvingType)
1382 {
1383     return resolvingType.visit(WTF::makeVisitor([&](Ref<AST::UnnamedType>& left) -> bool {
1384         return matches(left, m_intrinsics.boolType());
1385     }, [&](RefPtr<ResolvableTypeReference>& left) -> bool {
1386         return static_cast<bool>(matchAndCommit(m_intrinsics.boolType(), left->resolvableType()));
1387     }));
1388 }
1389
1390 bool Checker::recurseAndRequireBoolType(AST::Expression& expression)
1391 {
1392     auto expressionInfo = recurseAndGetInfo(expression);
1393     if (!expressionInfo)
1394         return false;
1395     if (!isBoolType(expressionInfo->resolvingType)) {
1396         setError(Error("Expected bool type from expression.", expression.codeLocation()));
1397         return false;
1398     }
1399     return true;
1400 }
1401
1402 void Checker::visit(AST::LogicalNotExpression& logicalNotExpression)
1403 {
1404     if (!recurseAndRequireBoolType(logicalNotExpression.operand()))
1405         return;
1406     assignConcreteType(logicalNotExpression, AST::TypeReference::wrap(logicalNotExpression.codeLocation(), m_intrinsics.boolType()));
1407 }
1408
1409 void Checker::visit(AST::LogicalExpression& logicalExpression)
1410 {
1411     if (!recurseAndRequireBoolType(logicalExpression.left()))
1412         return;
1413     if (!recurseAndRequireBoolType(logicalExpression.right()))
1414         return;
1415     assignConcreteType(logicalExpression, AST::TypeReference::wrap(logicalExpression.codeLocation(), m_intrinsics.boolType()));
1416 }
1417
1418 void Checker::visit(AST::IfStatement& ifStatement)
1419 {
1420     if (!recurseAndRequireBoolType(ifStatement.conditional()))
1421         return;
1422     checkErrorAndVisit(ifStatement.body());
1423     if (ifStatement.elseBody())
1424         checkErrorAndVisit(*ifStatement.elseBody());
1425 }
1426
1427 void Checker::visit(AST::WhileLoop& whileLoop)
1428 {
1429     if (!recurseAndRequireBoolType(whileLoop.conditional()))
1430         return;
1431     checkErrorAndVisit(whileLoop.body());
1432 }
1433
// Do-while loop: the body is checked first, then the condition is required to be bool.
// The boolean result of recurseAndRequireBoolType() is ignored because nothing follows it;
// on failure it has either recorded its own error or found no info for the condition.
void Checker::visit(AST::DoWhileLoop& doWhileLoop)
{
    checkErrorAndVisit(doWhileLoop.body());
    recurseAndRequireBoolType(doWhileLoop.conditional());
}
1439
1440 void Checker::visit(AST::ForLoop& forLoop)
1441 {
1442     checkErrorAndVisit(forLoop.initialization());
1443     if (hasError())
1444         return;
1445     if (forLoop.condition()) {
1446         if (!recurseAndRequireBoolType(*forLoop.condition()))
1447             return;
1448     }
1449     if (forLoop.increment())
1450         checkErrorAndVisit(*forLoop.increment());
1451     checkErrorAndVisit(forLoop.body());
1452 }
1453
1454 void Checker::visit(AST::SwitchStatement& switchStatement)
1455 {
1456     auto* valueType = ([&]() -> AST::NamedType* {
1457         auto valueInfo = recurseAndGetInfo(switchStatement.value());
1458         if (!valueInfo)
1459             return nullptr;
1460         auto* valueType = getUnnamedType(valueInfo->resolvingType);
1461         if (!valueType)
1462             return nullptr;
1463         auto& valueUnifyNode = valueType->unifyNode();
1464         if (!is<AST::NamedType>(valueUnifyNode))
1465             return nullptr;
1466         auto& valueNamedUnifyNode = downcast<AST::NamedType>(valueUnifyNode);
1467         if (!(is<AST::NativeTypeDeclaration>(valueNamedUnifyNode) && downcast<AST::NativeTypeDeclaration>(valueNamedUnifyNode).isInt())
1468             && !is<AST::EnumerationDefinition>(valueNamedUnifyNode))
1469             return nullptr;
1470         return &valueNamedUnifyNode;
1471     })();
1472     if (!valueType) {
1473         setError(Error("Invalid type for the expression condition of the switch statement.", switchStatement.codeLocation()));
1474         return;
1475     }
1476
1477     bool hasDefault = false;
1478     for (auto& switchCase : switchStatement.switchCases()) {
1479         checkErrorAndVisit(switchCase.block());
1480         if (!switchCase.value()) {
1481             hasDefault = true;
1482             continue;
1483         }
1484         auto success = switchCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) -> bool {
1485             return static_cast<bool>(matchAndCommit(*valueType, integerLiteral.type()));
1486         }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) -> bool {
1487             return static_cast<bool>(matchAndCommit(*valueType, unsignedIntegerLiteral.type()));
1488         }, [&](AST::FloatLiteral& floatLiteral) -> bool {
1489             return static_cast<bool>(matchAndCommit(*valueType, floatLiteral.type()));
1490         }, [&](AST::NullLiteral& nullLiteral) -> bool {
1491             return static_cast<bool>(matchAndCommit(*valueType, nullLiteral.type()));
1492         }, [&](AST::BooleanLiteral&) -> bool {
1493             return matches(*valueType, m_intrinsics.boolType());
1494         }, [&](AST::EnumerationMemberLiteral& enumerationMemberLiteral) -> bool {
1495             ASSERT(enumerationMemberLiteral.enumerationDefinition());
1496             return matches(*valueType, *enumerationMemberLiteral.enumerationDefinition());
1497         }));
1498         if (!success) {
1499             setError(Error("Invalid type for switch case.", switchCase.codeLocation()));
1500             return;
1501         }
1502     }
1503
1504     for (size_t i = 0; i < switchStatement.switchCases().size(); ++i) {
1505         auto& firstCase = switchStatement.switchCases()[i];
1506         for (size_t j = i + 1; j < switchStatement.switchCases().size(); ++j) {
1507             auto& secondCase = switchStatement.switchCases()[j];
1508             
1509             if (static_cast<bool>(firstCase.value()) != static_cast<bool>(secondCase.value()))
1510                 continue;
1511
1512             if (!static_cast<bool>(firstCase.value())) {
1513                 setError(Error("Cannot define multiple default cases in switch statement.", secondCase.codeLocation()));
1514                 return;
1515             }
1516
1517             auto success = firstCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& firstIntegerLiteral) -> bool {
1518                 return secondCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& secondIntegerLiteral) -> bool {
1519                     return firstIntegerLiteral.value() != secondIntegerLiteral.value();
1520                 }, [&](AST::UnsignedIntegerLiteral& secondUnsignedIntegerLiteral) -> bool {
1521                     return static_cast<int64_t>(firstIntegerLiteral.value()) != static_cast<int64_t>(secondUnsignedIntegerLiteral.value());
1522                 }, [](auto&) -> bool {
1523                     return true;
1524                 }));
1525             }, [&](AST::UnsignedIntegerLiteral& firstUnsignedIntegerLiteral) -> bool {
1526                 return secondCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& secondIntegerLiteral) -> bool {
1527                     return static_cast<int64_t>(firstUnsignedIntegerLiteral.value()) != static_cast<int64_t>(secondIntegerLiteral.value());
1528                 }, [&](AST::UnsignedIntegerLiteral& secondUnsignedIntegerLiteral) -> bool {
1529                     return firstUnsignedIntegerLiteral.value() != secondUnsignedIntegerLiteral.value();
1530                 }, [](auto&) -> bool {
1531                     return true;
1532                 }));
1533             }, [&](AST::EnumerationMemberLiteral& firstEnumerationMemberLiteral) -> bool {
1534                 return secondCase.value()->visit(WTF::makeVisitor([&](AST::EnumerationMemberLiteral& secondEnumerationMemberLiteral) -> bool {
1535                     ASSERT(firstEnumerationMemberLiteral.enumerationMember());
1536                     ASSERT(secondEnumerationMemberLiteral.enumerationMember());
1537                     return firstEnumerationMemberLiteral.enumerationMember() != secondEnumerationMemberLiteral.enumerationMember();
1538                 }, [](auto&) -> bool {
1539                     return true;
1540                 }));
1541             }, [](auto&) -> bool {
1542                 return true;
1543             }));
1544             if (!success) {
1545                 setError(Error("Cannot define duplicate case statements in a switch.", secondCase.codeLocation()));
1546                 return;
1547             }
1548         }
1549     }
1550
1551     if (!hasDefault) {
1552         if (is<AST::NativeTypeDeclaration>(*valueType)) {
1553             HashSet<int64_t> values;
1554             bool zeroValueExists;
1555             for (auto& switchCase : switchStatement.switchCases()) {
1556                 auto value = switchCase.value()->visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) -> int64_t {
1557                     return integerLiteral.valueForSelectedType();
1558                 }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) -> int64_t {
1559                     return unsignedIntegerLiteral.valueForSelectedType();
1560                 }, [](auto&) -> int64_t {
1561                     ASSERT_NOT_REACHED();
1562                     return 0;
1563                 }));
1564                 if (!value)
1565                     zeroValueExists = true;
1566                 else
1567                     values.add(value);
1568             }
1569             bool success = true;
1570             downcast<AST::NativeTypeDeclaration>(*valueType).iterateAllValues([&](int64_t value) -> bool {
1571                 if (!value) {
1572                     if (!zeroValueExists) {
1573                         success = false;
1574                         return true;
1575                     }
1576                     return false;
1577                 }
1578                 if (!values.contains(value)) {
1579                     success = false;
1580                     return true;
1581                 }
1582                 return false;
1583             });
1584             if (!success) {
1585                 setError(Error("Switch cases must be exhaustive or you must define a default case.", switchStatement.codeLocation()));
1586                 return;
1587             }
1588         } else {
1589             HashSet<AST::EnumerationMember*> values;
1590             for (auto& switchCase : switchStatement.switchCases()) {
1591                 switchCase.value()->visit(WTF::makeVisitor([&](AST::EnumerationMemberLiteral& enumerationMemberLiteral) {
1592                     ASSERT(enumerationMemberLiteral.enumerationMember());
1593                     values.add(enumerationMemberLiteral.enumerationMember());
1594                 }, [](auto&) {
1595                     ASSERT_NOT_REACHED();
1596                 }));
1597             }
1598             for (auto& enumerationMember : downcast<AST::EnumerationDefinition>(*valueType).enumerationMembers()) {
1599                 if (!values.contains(&enumerationMember.get())) {
1600                     setError(Error("Switch cases over an enum must be exhaustive or you must define a default case.", switchStatement.codeLocation()));
1601                     return;
1602                 }
1603             }
1604         }
1605     }
1606 }
1607
1608 void Checker::visit(AST::CommaExpression& commaExpression)
1609 {
1610     ASSERT(commaExpression.list().size() > 0);
1611     Visitor::visit(commaExpression);
1612     if (hasError())
1613         return;
1614     auto lastInfo = getInfo(commaExpression.list().last());
1615     forwardType(commaExpression, lastInfo->resolvingType);
1616 }
1617
1618 void Checker::visit(AST::TernaryExpression& ternaryExpression)
1619 {
1620     auto predicateInfo = recurseAndRequireBoolType(ternaryExpression.predicate());
1621     if (!predicateInfo)
1622         return;
1623
1624     auto bodyInfo = recurseAndGetInfo(ternaryExpression.bodyExpression());
1625     auto elseInfo = recurseAndGetInfo(ternaryExpression.elseExpression());
1626     
1627     auto resultType = matchAndCommit(bodyInfo->resolvingType, elseInfo->resolvingType);
1628     if (!resultType) {
1629         setError(Error("lhs and rhs of a ternary expression must match.", ternaryExpression.codeLocation()));
1630         return;
1631     }
1632
1633     assignConcreteType(ternaryExpression, *resultType);
1634 }
1635
1636 void Checker::visit(AST::CallExpression& callExpression)
1637 {
1638     Vector<std::reference_wrapper<ResolvingType>> types;
1639     types.reserveInitialCapacity(callExpression.arguments().size());
1640     for (auto& argument : callExpression.arguments()) {
1641         auto argumentInfo = recurseAndGetInfo(argument);
1642         if (!argumentInfo)
1643             return;
1644         types.uncheckedAppend(argumentInfo->resolvingType);
1645     }
1646     // Don't recurse on the castReturnType, because it's guaranteed to be a NamedType, which will get visited later.
1647     // We don't want to recurse to the same node twice.
1648
1649     auto* function = resolveFunction(types, callExpression.name(), callExpression.codeLocation());
1650     if (hasError())
1651         return;
1652
1653     if (!function) {
1654         NameContext& nameContext = m_program.nameContext();
1655         if (auto* castTypes = nameContext.getTypes(callExpression.name())) {
1656             if (castTypes->size() == 1) {
1657                 AST::NamedType& castType = (*castTypes)[0].get();
1658                 function = resolveFunction(types, "operator cast"_str, callExpression.codeLocation(), &castType);
1659                 if (hasError())
1660                     return;
1661                 if (function)
1662                     callExpression.setCastData(castType);
1663             }
1664         }
1665     }
1666
1667     if (!function) {
1668         // FIXME: Add better error messages for why we can't resolve to one of the overrides.
1669         // https://bugs.webkit.org/show_bug.cgi?id=200133
1670         setError(Error("Cannot resolve function call to a concrete callee. Make sure you are using compatible types.", callExpression.codeLocation()));
1671         return;
1672     }
1673
1674     for (size_t i = 0; i < function->parameters().size(); ++i) {
1675         if (!matchAndCommit(types[i].get(), *function->parameters()[i]->type())) {
1676             setError(Error(makeString("Invalid type for parameter number ", i + 1, " in function call."), callExpression.codeLocation()));
1677             return;
1678         }
1679     }
1680
1681     callExpression.setFunction(*function);
1682
1683     assignConcreteType(callExpression, function->type());
1684 }
1685
1686 Expected<void, Error> check(Program& program)
1687 {
1688     Checker checker(program.intrinsics(), program);
1689     checker.checkErrorAndVisit(program);
1690     if (checker.hasError())
1691         return checker.result();
1692     return checker.assignTypes();
1693 }
1694
1695 } // namespace WHLSL
1696
1697 } // namespace WebCore
1698
1699 #endif // ENABLE(WEBGPU)