[WHLSL] Hook up common texture functions
[WebKit-https.git] / Source / WebCore / Modules / webgpu / WHLSL / WHLSLNameResolver.cpp
1 /*
2  * Copyright (C) 2019 Apple Inc. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  * 1. Redistributions of source code must retain the above copyright
8  *    notice, this list of conditions and the following disclaimer.
9  * 2. Redistributions in binary form must reproduce the above copyright
10  *    notice, this list of conditions and the following disclaimer in the
11  *    documentation and/or other materials provided with the distribution.
12  *
13  * THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS ``AS IS''
14  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
15  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS
17  * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23  * THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 #include "config.h"
27 #include "WHLSLNameResolver.h"
28
29 #if ENABLE(WEBGPU)
30
31 #include "WHLSLCallExpression.h"
32 #include "WHLSLDoWhileLoop.h"
33 #include "WHLSLDotExpression.h"
34 #include "WHLSLEnumerationDefinition.h"
35 #include "WHLSLEnumerationMemberLiteral.h"
36 #include "WHLSLForLoop.h"
37 #include "WHLSLFunctionDefinition.h"
38 #include "WHLSLIfStatement.h"
39 #include "WHLSLNameContext.h"
40 #include "WHLSLProgram.h"
41 #include "WHLSLPropertyAccessExpression.h"
42 #include "WHLSLResolveOverloadImpl.h"
43 #include "WHLSLReturn.h"
44 #include "WHLSLScopedSetAdder.h"
45 #include "WHLSLTypeReference.h"
46 #include "WHLSLVariableDeclaration.h"
47 #include "WHLSLVariableReference.h"
48 #include "WHLSLWhileLoop.h"
49
50 namespace WebCore {
51
52 namespace WHLSL {
53
54 NameResolver::NameResolver(NameContext& nameContext)
55     : m_nameContext(nameContext)
56 {
57 }
58
59 NameResolver::NameResolver(NameResolver& parentResolver, NameContext& nameContext)
60     : m_nameContext(nameContext)
61     , m_parentNameResolver(&parentResolver)
62 {
63     m_isResolvingCalls = parentResolver.m_isResolvingCalls;
64     setCurrentFunctionDefinition(parentResolver.m_currentFunction);
65 }
66
67 NameResolver::~NameResolver()
68 {
69     if (error() && m_parentNameResolver)
70         m_parentNameResolver->setError();
71 }
72
73 void NameResolver::visit(AST::TypeReference& typeReference)
74 {
75     if (m_isResolvingCalls)
76         return;
77
78     ScopedSetAdder<AST::TypeReference*> adder(m_typeReferences, &typeReference);
79     if (!adder.isNewEntry()) {
80         setError();
81         return;
82     }
83
84     Visitor::visit(typeReference);
85     if (typeReference.maybeResolvedType()) // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198161 Shouldn't we know by now whether the type has been resolved or not?
86         return;
87
88     auto* candidates = m_nameContext.getTypes(typeReference.name());
89     if (candidates == nullptr) {
90         setError();
91         return;
92     }
93     for (auto& candidate : *candidates)
94         Visitor::visit(candidate);
95     if (auto result = resolveTypeOverloadImpl(*candidates, typeReference.typeArguments()))
96         typeReference.setResolvedType(*result);
97     else {
98         setError();
99         return;
100     }
101 }
102
103 void NameResolver::visit(AST::FunctionDefinition& functionDefinition)
104 {
105     NameContext newNameContext(&m_nameContext);
106     NameResolver newNameResolver(*this, newNameContext);
107     checkErrorAndVisit(functionDefinition.type());
108     for (auto& parameter : functionDefinition.parameters())
109         newNameResolver.checkErrorAndVisit(parameter);
110     newNameResolver.checkErrorAndVisit(functionDefinition.block());
111 }
112
113 void NameResolver::visit(AST::Block& block)
114 {
115     NameContext nameContext(&m_nameContext);
116     NameResolver newNameResolver(*this, nameContext);
117     newNameResolver.Visitor::visit(block);
118 }
119
120 void NameResolver::visit(AST::IfStatement& ifStatement)
121 {
122     checkErrorAndVisit(ifStatement.conditional());
123     NameContext nameContext(&m_nameContext);
124     NameResolver newNameResolver(*this, nameContext);
125     newNameResolver.checkErrorAndVisit(ifStatement.body());
126     if (newNameResolver.error())
127         setError();
128     else if (ifStatement.elseBody()) {
129         NameContext nameContext(&m_nameContext);
130         NameResolver newNameResolver(*this, nameContext);
131         newNameResolver.checkErrorAndVisit(*ifStatement.elseBody());
132     }
133 }
134
135 void NameResolver::visit(AST::WhileLoop& whileLoop)
136 {
137     checkErrorAndVisit(whileLoop.conditional());
138     NameContext nameContext(&m_nameContext);
139     NameResolver newNameResolver(*this, nameContext);
140     newNameResolver.checkErrorAndVisit(whileLoop.body());
141 }
142
143 void NameResolver::visit(AST::DoWhileLoop& whileLoop)
144 {
145     NameContext nameContext(&m_nameContext);
146     NameResolver newNameResolver(*this, nameContext);
147     newNameResolver.checkErrorAndVisit(whileLoop.body());
148     checkErrorAndVisit(whileLoop.conditional());
149 }
150
151 void NameResolver::visit(AST::ForLoop& forLoop)
152 {
153     NameContext nameContext(&m_nameContext);
154     NameResolver newNameResolver(*this, nameContext);
155     newNameResolver.Visitor::visit(forLoop);
156 }
157
158 void NameResolver::visit(AST::VariableDeclaration& variableDeclaration)
159 {
160     if (!m_nameContext.add(variableDeclaration)) {
161         setError();
162         return;
163     }
164     Visitor::visit(variableDeclaration);
165 }
166
167 void NameResolver::visit(AST::VariableReference& variableReference)
168 {
169     if (variableReference.variable())
170         return;
171
172     if (auto* variable = m_nameContext.getVariable(variableReference.name()))
173         variableReference.setVariable(*variable);
174     else {
175         setError();
176         return;
177     }
178 }
179
180 void NameResolver::visit(AST::Return& returnStatement)
181 {
182     ASSERT(m_currentFunction);
183     returnStatement.setFunction(m_currentFunction);
184     Visitor::visit(returnStatement);
185 }
186
187 void NameResolver::visit(AST::PropertyAccessExpression& propertyAccessExpression)
188 {
189     if (m_isResolvingCalls) {
190         if (auto* getterFunctions = m_nameContext.getFunctions(propertyAccessExpression.getterFunctionName()))
191             propertyAccessExpression.setPossibleGetterOverloads(*getterFunctions);
192         if (auto* setterFunctions = m_nameContext.getFunctions(propertyAccessExpression.setterFunctionName()))
193             propertyAccessExpression.setPossibleSetterOverloads(*setterFunctions);
194         if (auto* anderFunctions = m_nameContext.getFunctions(propertyAccessExpression.anderFunctionName()))
195             propertyAccessExpression.setPossibleAnderOverloads(*anderFunctions);
196     }
197     Visitor::visit(propertyAccessExpression);
198 }
199
// Rewrites `Name.member` into an EnumerationMemberLiteral when the base is a bare
// name that resolves to an enumeration type. Every other dot expression is
// visited normally.
// Note: the type lookup happens before any variable lookup, so a type with the
// base's name takes precedence over a variable with the same name.
void NameResolver::visit(AST::DotExpression& dotExpression)
{
    if (is<AST::VariableReference>(dotExpression.base())) {
        auto baseName = downcast<AST::VariableReference>(dotExpression.base()).name();
        if (auto enumerationTypes = m_nameContext.getTypes(baseName)) {
            // Assumes a name that denotes an enumeration is not overloaded; in release
            // builds only the first candidate is inspected.
            ASSERT(enumerationTypes->size() == 1);
            AST::NamedType& type = (*enumerationTypes)[0];
            if (is<AST::EnumerationDefinition>(type)) {
                AST::EnumerationDefinition& enumerationDefinition = downcast<AST::EnumerationDefinition>(type);
                auto memberName = dotExpression.fieldName();
                if (auto* member = enumerationDefinition.memberByName(memberName)) {
                    Lexer::Token origin = dotExpression.origin();
                    auto enumerationMemberLiteral = AST::EnumerationMemberLiteral::wrap(WTFMove(origin), WTFMove(baseName), WTFMove(memberName), enumerationDefinition, *member);
                    // NOTE(review): replaceWith presumably rebuilds the node in place, so
                    // |dotExpression| must not be used after this call; hence the immediate return.
                    AST::replaceWith<AST::EnumerationMemberLiteral>(dotExpression, WTFMove(enumerationMemberLiteral));
                    return;
                }
                // The enumeration exists but has no member with this name.
                setError();
                return;
            }
        }
    }

    // Not an enumeration member: an ordinary field access.
    Visitor::visit(dotExpression);
}
224
225 void NameResolver::visit(AST::CallExpression& callExpression)
226 {
227     if (m_isResolvingCalls) {
228         if (!callExpression.hasOverloads()) {
229             if (auto* functions = m_nameContext.getFunctions(callExpression.name()))
230                 callExpression.setOverloads(*functions);
231             else {
232                 if (auto* types = m_nameContext.getTypes(callExpression.name())) {
233                     if (types->size() == 1) {
234                         if (auto* functions = m_nameContext.getFunctions("operator cast"_str)) {
235                             callExpression.setCastData((*types)[0].get());
236                             callExpression.setOverloads(*functions);
237                         }
238                     }
239                 }
240             }
241         }
242         if (!callExpression.hasOverloads()) {
243             setError();
244             return;
245         }
246     }
247     Visitor::visit(callExpression);
248 }
249
250 void NameResolver::visit(AST::EnumerationMemberLiteral& enumerationMemberLiteral)
251 {
252     if (enumerationMemberLiteral.enumerationMember())
253         return;
254
255     if (auto enumerationTypes = m_nameContext.getTypes(enumerationMemberLiteral.left())) {
256         ASSERT(enumerationTypes->size() == 1);
257         AST::NamedType& type = (*enumerationTypes)[0];
258         if (is<AST::EnumerationDefinition>(type)) {
259             AST::EnumerationDefinition& enumerationDefinition = downcast<AST::EnumerationDefinition>(type);
260             if (auto* member = enumerationDefinition.memberByName(enumerationMemberLiteral.right())) {
261                 enumerationMemberLiteral.setEnumerationMember(enumerationDefinition, *member);
262                 return;
263             }
264         }
265     }
266     
267     setError();
268 }
269
270 void NameResolver::visit(AST::NativeFunctionDeclaration& nativeFunctionDeclaration)
271 {
272     NameContext newNameContext(&m_nameContext);
273     NameResolver newNameResolver(newNameContext);
274     newNameResolver.Visitor::visit(nativeFunctionDeclaration);
275 }
276
277 // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198167 Make sure all the names have been resolved.
278
279 bool resolveNamesInTypes(Program& program, NameResolver& nameResolver)
280 {
281     for (auto& typeDefinition : program.typeDefinitions()) {
282         nameResolver.checkErrorAndVisit(typeDefinition);
283         if (nameResolver.error())
284             return false;
285     }
286     for (auto& structureDefinition : program.structureDefinitions()) {
287         nameResolver.checkErrorAndVisit(structureDefinition);
288         if (nameResolver.error())
289             return false;
290     }
291     for (auto& enumerationDefinition : program.enumerationDefinitions()) {
292         nameResolver.checkErrorAndVisit(enumerationDefinition);
293         if (nameResolver.error())
294             return false;
295     }
296     for (auto& nativeTypeDeclaration : program.nativeTypeDeclarations()) {
297         nameResolver.checkErrorAndVisit(nativeTypeDeclaration);
298         if (nameResolver.error())
299             return false;
300     }
301     return true;
302 }
303
304 bool resolveTypeNamesInFunctions(Program& program, NameResolver& nameResolver)
305 {
306     for (auto& functionDefinition : program.functionDefinitions()) {
307         nameResolver.setCurrentFunctionDefinition(&functionDefinition);
308         nameResolver.checkErrorAndVisit(functionDefinition);
309         if (nameResolver.error())
310             return false;
311     }
312     nameResolver.setCurrentFunctionDefinition(nullptr);
313     for (auto& nativeFunctionDeclaration : program.nativeFunctionDeclarations()) {
314         nameResolver.checkErrorAndVisit(nativeFunctionDeclaration);
315         if (nameResolver.error())
316             return false;
317     }
318     return true;
319 }
320
321 bool resolveCallsInFunctions(Program& program, NameResolver& nameResolver)
322 {
323     nameResolver.setIsResolvingCalls(true);
324     for (auto& functionDefinition : program.functionDefinitions()) {
325         nameResolver.setCurrentFunctionDefinition(&functionDefinition);
326         nameResolver.checkErrorAndVisit(functionDefinition);
327         if (nameResolver.error())
328             return false;
329     }
330     nameResolver.setCurrentFunctionDefinition(nullptr);
331     nameResolver.setIsResolvingCalls(false);
332     return true;
333 }
334
335 } // namespace WHLSL
336
337 } // namespace WebCore
338
339 #endif // ENABLE(WEBGPU)