7ee400057c6a79ba7b77f7199eb405cdfc00008c
[WebKit-https.git] / Source / WebCore / Modules / webgpu / WHLSL / WHLSLNameResolver.cpp
1 /*
2  * Copyright (C) 2019 Apple Inc. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  * 1. Redistributions of source code must retain the above copyright
8  *    notice, this list of conditions and the following disclaimer.
9  * 2. Redistributions in binary form must reproduce the above copyright
10  *    notice, this list of conditions and the following disclaimer in the
11  *    documentation and/or other materials provided with the distribution.
12  *
13  * THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS ``AS IS''
14  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
15  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS
17  * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23  * THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 #include "config.h"
27 #include "WHLSLNameResolver.h"
28
29 #if ENABLE(WEBGPU)
30
31 #include "WHLSLCallExpression.h"
32 #include "WHLSLDoWhileLoop.h"
33 #include "WHLSLDotExpression.h"
34 #include "WHLSLEnumerationDefinition.h"
35 #include "WHLSLEnumerationMemberLiteral.h"
36 #include "WHLSLForLoop.h"
37 #include "WHLSLFunctionDefinition.h"
38 #include "WHLSLIfStatement.h"
39 #include "WHLSLNameContext.h"
40 #include "WHLSLProgram.h"
41 #include "WHLSLPropertyAccessExpression.h"
42 #include "WHLSLResolveOverloadImpl.h"
43 #include "WHLSLReturn.h"
44 #include "WHLSLScopedSetAdder.h"
45 #include "WHLSLTypeReference.h"
46 #include "WHLSLVariableDeclaration.h"
47 #include "WHLSLVariableReference.h"
48 #include "WHLSLWhileLoop.h"
49
50 namespace WebCore {
51
52 namespace WHLSL {
53
54 NameResolver::NameResolver(NameContext& nameContext)
55     : m_nameContext(nameContext)
56 {
57 }
58
59 NameResolver::NameResolver(NameResolver& parentResolver, NameContext& nameContext)
60     : m_nameContext(nameContext)
61 {
62     m_isResolvingCalls = parentResolver.m_isResolvingCalls;
63     setCurrentFunctionDefinition(parentResolver.m_currentFunction);
64 }
65
66 void NameResolver::visit(AST::TypeReference& typeReference)
67 {
68     if (m_isResolvingCalls)
69         return;
70
71     ScopedSetAdder<AST::TypeReference*> adder(m_typeReferences, &typeReference);
72     if (!adder.isNewEntry()) {
73         setError();
74         return;
75     }
76
77     Visitor::visit(typeReference);
78     if (typeReference.maybeResolvedType()) // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198161 Shouldn't we know by now whether the type has been resolved or not?
79         return;
80
81     auto* candidates = m_nameContext.getTypes(typeReference.name());
82     if (candidates == nullptr) {
83         setError();
84         return;
85     }
86     for (auto& candidate : *candidates)
87         Visitor::visit(candidate);
88     if (auto result = resolveTypeOverloadImpl(*candidates, typeReference.typeArguments()))
89         typeReference.setResolvedType(*result);
90     else {
91         setError();
92         return;
93     }
94 }
95
96 void NameResolver::visit(AST::FunctionDefinition& functionDefinition)
97 {
98     NameContext newNameContext(&m_nameContext);
99     NameResolver newNameResolver(*this, newNameContext);
100     checkErrorAndVisit(functionDefinition.type());
101     for (auto& parameter : functionDefinition.parameters())
102         newNameResolver.checkErrorAndVisit(parameter);
103     newNameResolver.checkErrorAndVisit(functionDefinition.block());
104 }
105
106 void NameResolver::visit(AST::Block& block)
107 {
108     NameContext nameContext(&m_nameContext);
109     NameResolver newNameResolver(*this, nameContext);
110     newNameResolver.Visitor::visit(block);
111 }
112
113 void NameResolver::visit(AST::IfStatement& ifStatement)
114 {
115     checkErrorAndVisit(ifStatement.conditional());
116     NameContext nameContext(&m_nameContext);
117     NameResolver newNameResolver(*this, nameContext);
118     newNameResolver.checkErrorAndVisit(ifStatement.body());
119     if (ifStatement.elseBody()) {
120         NameContext nameContext(&m_nameContext);
121         NameResolver newNameResolver(*this, nameContext);
122         newNameResolver.checkErrorAndVisit(*ifStatement.elseBody());
123     }
124 }
125
126 void NameResolver::visit(AST::WhileLoop& whileLoop)
127 {
128     checkErrorAndVisit(whileLoop.conditional());
129     NameContext nameContext(&m_nameContext);
130     NameResolver newNameResolver(*this, nameContext);
131     newNameResolver.checkErrorAndVisit(whileLoop.body());
132 }
133
134 void NameResolver::visit(AST::DoWhileLoop& whileLoop)
135 {
136     NameContext nameContext(&m_nameContext);
137     NameResolver newNameResolver(*this, nameContext);
138     newNameResolver.checkErrorAndVisit(whileLoop.body());
139     checkErrorAndVisit(whileLoop.conditional());
140 }
141
142 void NameResolver::visit(AST::ForLoop& forLoop)
143 {
144     NameContext nameContext(&m_nameContext);
145     NameResolver newNameResolver(*this, nameContext);
146     newNameResolver.Visitor::visit(forLoop);
147 }
148
149 void NameResolver::visit(AST::VariableDeclaration& variableDeclaration)
150 {
151     if (!m_nameContext.add(variableDeclaration)) {
152         setError();
153         return;
154     }
155     Visitor::visit(variableDeclaration);
156 }
157
158 void NameResolver::visit(AST::VariableReference& variableReference)
159 {
160     if (variableReference.variable())
161         return;
162
163     if (auto* variable = m_nameContext.getVariable(variableReference.name()))
164         variableReference.setVariable(*variable);
165     else {
166         setError();
167         return;
168     }
169 }
170
171 void NameResolver::visit(AST::Return& returnStatement)
172 {
173     ASSERT(m_currentFunction);
174     returnStatement.setFunction(m_currentFunction);
175     Visitor::visit(returnStatement);
176 }
177
178 void NameResolver::visit(AST::PropertyAccessExpression& propertyAccessExpression)
179 {
180     if (m_isResolvingCalls) {
181         if (auto* getterFunctions = m_nameContext.getFunctions(propertyAccessExpression.getterFunctionName()))
182             propertyAccessExpression.setPossibleGetterOverloads(*getterFunctions);
183         if (auto* setterFunctions = m_nameContext.getFunctions(propertyAccessExpression.setterFunctionName()))
184             propertyAccessExpression.setPossibleSetterOverloads(*setterFunctions);
185         if (auto* anderFunctions = m_nameContext.getFunctions(propertyAccessExpression.anderFunctionName()))
186             propertyAccessExpression.setPossibleAnderOverloads(*anderFunctions);
187     }
188     Visitor::visit(propertyAccessExpression);
189 }
190
191 void NameResolver::visit(AST::DotExpression& dotExpression)
192 {
193     if (is<AST::VariableReference>(dotExpression.base())) {
194         auto baseName = downcast<AST::VariableReference>(dotExpression.base()).name();
195         if (auto enumerationTypes = m_nameContext.getTypes(baseName)) {
196             ASSERT(enumerationTypes->size() == 1);
197             AST::NamedType& type = (*enumerationTypes)[0];
198             if (is<AST::EnumerationDefinition>(type)) {
199                 AST::EnumerationDefinition& enumerationDefinition = downcast<AST::EnumerationDefinition>(type);
200                 auto memberName = dotExpression.fieldName();
201                 if (auto* member = enumerationDefinition.memberByName(memberName)) {
202                     Lexer::Token origin = dotExpression.origin();
203                     auto enumerationMemberLiteral = AST::EnumerationMemberLiteral::wrap(WTFMove(origin), WTFMove(baseName), WTFMove(memberName), enumerationDefinition, *member);
204                     AST::replaceWith<AST::EnumerationMemberLiteral>(dotExpression, WTFMove(enumerationMemberLiteral));
205                     return;
206                 }
207                 setError();
208                 return;
209             }
210         }
211     }
212
213     Visitor::visit(dotExpression);
214 }
215
216 void NameResolver::visit(AST::CallExpression& callExpression)
217 {
218     if (m_isResolvingCalls) {
219         if (!callExpression.hasOverloads()) {
220             if (auto* functions = m_nameContext.getFunctions(callExpression.name()))
221                 callExpression.setOverloads(*functions);
222             else {
223                 if (auto* types = m_nameContext.getTypes(callExpression.name())) {
224                     if (types->size() == 1) {
225                         if (auto* functions = m_nameContext.getFunctions("operator cast"_str)) {
226                             callExpression.setCastData((*types)[0].get());
227                             callExpression.setOverloads(*functions);
228                         }
229                     }
230                 }
231             }
232         }
233         if (!callExpression.hasOverloads()) {
234             setError();
235             return;
236         }
237     }
238     Visitor::visit(callExpression);
239 }
240
241 void NameResolver::visit(AST::EnumerationMemberLiteral& enumerationMemberLiteral)
242 {
243     if (enumerationMemberLiteral.enumerationMember())
244         return;
245
246     if (auto enumerationTypes = m_nameContext.getTypes(enumerationMemberLiteral.left())) {
247         ASSERT(enumerationTypes->size() == 1);
248         AST::NamedType& type = (*enumerationTypes)[0];
249         if (is<AST::EnumerationDefinition>(type)) {
250             AST::EnumerationDefinition& enumerationDefinition = downcast<AST::EnumerationDefinition>(type);
251             if (auto* member = enumerationDefinition.memberByName(enumerationMemberLiteral.right())) {
252                 enumerationMemberLiteral.setEnumerationMember(enumerationDefinition, *member);
253                 return;
254             }
255         }
256     }
257     
258     setError();
259 }
260
261 // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198167 Make sure all the names have been resolved.
262
263 bool resolveNamesInTypes(Program& program, NameResolver& nameResolver)
264 {
265     for (auto& typeDefinition : program.typeDefinitions()) {
266         nameResolver.checkErrorAndVisit(typeDefinition);
267         if (nameResolver.error())
268             return false;
269     }
270     for (auto& structureDefinition : program.structureDefinitions()) {
271         nameResolver.checkErrorAndVisit(structureDefinition);
272         if (nameResolver.error())
273             return false;
274     }
275     for (auto& enumerationDefinition : program.enumerationDefinitions()) {
276         nameResolver.checkErrorAndVisit(enumerationDefinition);
277         if (nameResolver.error())
278             return false;
279     }
280     for (auto& nativeTypeDeclaration : program.nativeTypeDeclarations()) {
281         nameResolver.checkErrorAndVisit(nativeTypeDeclaration);
282         if (nameResolver.error())
283             return false;
284     }
285     return true;
286 }
287
288 bool resolveTypeNamesInFunctions(Program& program, NameResolver& nameResolver)
289 {
290     for (auto& functionDefinition : program.functionDefinitions()) {
291         nameResolver.setCurrentFunctionDefinition(&functionDefinition);
292         nameResolver.checkErrorAndVisit(functionDefinition);
293         if (nameResolver.error())
294             return false;
295     }
296     nameResolver.setCurrentFunctionDefinition(nullptr);
297     for (auto& nativeFunctionDeclaration : program.nativeFunctionDeclarations()) {
298         nameResolver.checkErrorAndVisit(nativeFunctionDeclaration);
299         if (nameResolver.error())
300             return false;
301     }
302     return true;
303 }
304
305 bool resolveCallsInFunctions(Program& program, NameResolver& nameResolver)
306 {
307     nameResolver.setIsResolvingCalls(true);
308     for (auto& functionDefinition : program.functionDefinitions()) {
309         nameResolver.setCurrentFunctionDefinition(&functionDefinition);
310         nameResolver.checkErrorAndVisit(functionDefinition);
311         if (nameResolver.error())
312             return false;
313     }
314     nameResolver.setCurrentFunctionDefinition(nullptr);
315     nameResolver.setIsResolvingCalls(false);
316     return true;
317 }
318
319 } // namespace WHLSL
320
321 } // namespace WebCore
322
323 #endif // ENABLE(WEBGPU)