[WHLSL] Code that accesses an undefined variable crashes
[WebKit-https.git] / Source / WebCore / Modules / webgpu / WHLSL / WHLSLNameResolver.cpp
1 /*
2  * Copyright (C) 2019 Apple Inc. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  * 1. Redistributions of source code must retain the above copyright
8  *    notice, this list of conditions and the following disclaimer.
9  * 2. Redistributions in binary form must reproduce the above copyright
10  *    notice, this list of conditions and the following disclaimer in the
11  *    documentation and/or other materials provided with the distribution.
12  *
13  * THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS ``AS IS''
14  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
15  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS
17  * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23  * THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 #include "config.h"
27 #include "WHLSLNameResolver.h"
28
29 #if ENABLE(WEBGPU)
30
31 #include "WHLSLCallExpression.h"
32 #include "WHLSLDoWhileLoop.h"
33 #include "WHLSLDotExpression.h"
34 #include "WHLSLEnumerationDefinition.h"
35 #include "WHLSLEnumerationMemberLiteral.h"
36 #include "WHLSLForLoop.h"
37 #include "WHLSLFunctionDefinition.h"
38 #include "WHLSLIfStatement.h"
39 #include "WHLSLNameContext.h"
40 #include "WHLSLProgram.h"
41 #include "WHLSLPropertyAccessExpression.h"
42 #include "WHLSLResolveOverloadImpl.h"
43 #include "WHLSLReturn.h"
44 #include "WHLSLScopedSetAdder.h"
45 #include "WHLSLTypeReference.h"
46 #include "WHLSLVariableDeclaration.h"
47 #include "WHLSLVariableReference.h"
48 #include "WHLSLWhileLoop.h"
49
50 namespace WebCore {
51
52 namespace WHLSL {
53
54 NameResolver::NameResolver(NameContext& nameContext)
55     : m_nameContext(nameContext)
56 {
57 }
58
59 NameResolver::NameResolver(NameResolver& parentResolver, NameContext& nameContext)
60     : m_nameContext(nameContext)
61     , m_parentNameResolver(&parentResolver)
62 {
63     m_isResolvingCalls = parentResolver.m_isResolvingCalls;
64     setCurrentFunctionDefinition(parentResolver.m_currentFunction);
65 }
66
67 NameResolver::~NameResolver()
68 {
69     if (error() && m_parentNameResolver)
70         m_parentNameResolver->setError();
71 }
72
73 void NameResolver::visit(AST::TypeReference& typeReference)
74 {
75     if (m_isResolvingCalls)
76         return;
77
78     ScopedSetAdder<AST::TypeReference*> adder(m_typeReferences, &typeReference);
79     if (!adder.isNewEntry()) {
80         setError();
81         return;
82     }
83
84     Visitor::visit(typeReference);
85     if (typeReference.maybeResolvedType()) // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198161 Shouldn't we know by now whether the type has been resolved or not?
86         return;
87
88     auto* candidates = m_nameContext.getTypes(typeReference.name());
89     if (candidates == nullptr) {
90         setError();
91         return;
92     }
93     for (auto& candidate : *candidates)
94         Visitor::visit(candidate);
95     if (auto result = resolveTypeOverloadImpl(*candidates, typeReference.typeArguments()))
96         typeReference.setResolvedType(*result);
97     else {
98         setError();
99         return;
100     }
101 }
102
103 void NameResolver::visit(AST::FunctionDefinition& functionDefinition)
104 {
105     NameContext newNameContext(&m_nameContext);
106     NameResolver newNameResolver(*this, newNameContext);
107     checkErrorAndVisit(functionDefinition.type());
108     for (auto& parameter : functionDefinition.parameters())
109         newNameResolver.checkErrorAndVisit(parameter);
110     newNameResolver.checkErrorAndVisit(functionDefinition.block());
111 }
112
113 void NameResolver::visit(AST::Block& block)
114 {
115     NameContext nameContext(&m_nameContext);
116     NameResolver newNameResolver(*this, nameContext);
117     newNameResolver.Visitor::visit(block);
118 }
119
120 void NameResolver::visit(AST::IfStatement& ifStatement)
121 {
122     checkErrorAndVisit(ifStatement.conditional());
123     if (error())
124         return;
125
126     {
127         NameContext nameContext(&m_nameContext);
128         NameResolver newNameResolver(*this, nameContext);
129         newNameResolver.checkErrorAndVisit(ifStatement.body());
130     }
131     if (error())
132         return;
133
134     if (ifStatement.elseBody()) {
135         NameContext nameContext(&m_nameContext);
136         NameResolver newNameResolver(*this, nameContext);
137         newNameResolver.checkErrorAndVisit(*ifStatement.elseBody());
138     }
139 }
140
141 void NameResolver::visit(AST::WhileLoop& whileLoop)
142 {
143     checkErrorAndVisit(whileLoop.conditional());
144     if (error())
145         return;
146
147     NameContext nameContext(&m_nameContext);
148     NameResolver newNameResolver(*this, nameContext);
149     newNameResolver.checkErrorAndVisit(whileLoop.body());
150 }
151
152 void NameResolver::visit(AST::DoWhileLoop& whileLoop)
153 {
154     {
155         NameContext nameContext(&m_nameContext);
156         NameResolver newNameResolver(*this, nameContext);
157         newNameResolver.checkErrorAndVisit(whileLoop.body());
158     }
159
160     checkErrorAndVisit(whileLoop.conditional());
161 }
162
163 void NameResolver::visit(AST::ForLoop& forLoop)
164 {
165     NameContext nameContext(&m_nameContext);
166     NameResolver newNameResolver(*this, nameContext);
167     newNameResolver.Visitor::visit(forLoop);
168 }
169
170 void NameResolver::visit(AST::VariableDeclaration& variableDeclaration)
171 {
172     if (!m_nameContext.add(variableDeclaration)) {
173         setError();
174         return;
175     }
176     Visitor::visit(variableDeclaration);
177 }
178
179 void NameResolver::visit(AST::VariableReference& variableReference)
180 {
181     if (variableReference.variable())
182         return;
183
184     if (auto* variable = m_nameContext.getVariable(variableReference.name()))
185         variableReference.setVariable(*variable);
186     else {
187         setError();
188         return;
189     }
190 }
191
192 void NameResolver::visit(AST::Return& returnStatement)
193 {
194     ASSERT(m_currentFunction);
195     returnStatement.setFunction(m_currentFunction);
196     Visitor::visit(returnStatement);
197 }
198
199 void NameResolver::visit(AST::PropertyAccessExpression& propertyAccessExpression)
200 {
201     if (m_isResolvingCalls) {
202         if (auto* getterFunctions = m_nameContext.getFunctions(propertyAccessExpression.getterFunctionName()))
203             propertyAccessExpression.setPossibleGetterOverloads(*getterFunctions);
204         if (auto* setterFunctions = m_nameContext.getFunctions(propertyAccessExpression.setterFunctionName()))
205             propertyAccessExpression.setPossibleSetterOverloads(*setterFunctions);
206         if (auto* anderFunctions = m_nameContext.getFunctions(propertyAccessExpression.anderFunctionName()))
207             propertyAccessExpression.setPossibleAnderOverloads(*anderFunctions);
208     }
209     Visitor::visit(propertyAccessExpression);
210 }
211
212 void NameResolver::visit(AST::DotExpression& dotExpression)
213 {
214     if (is<AST::VariableReference>(dotExpression.base())) {
215         auto baseName = downcast<AST::VariableReference>(dotExpression.base()).name();
216         if (auto enumerationTypes = m_nameContext.getTypes(baseName)) {
217             ASSERT(enumerationTypes->size() == 1);
218             AST::NamedType& type = (*enumerationTypes)[0];
219             if (is<AST::EnumerationDefinition>(type)) {
220                 AST::EnumerationDefinition& enumerationDefinition = downcast<AST::EnumerationDefinition>(type);
221                 auto memberName = dotExpression.fieldName();
222                 if (auto* member = enumerationDefinition.memberByName(memberName)) {
223                     Lexer::Token origin = dotExpression.origin();
224                     auto enumerationMemberLiteral = AST::EnumerationMemberLiteral::wrap(WTFMove(origin), WTFMove(baseName), WTFMove(memberName), enumerationDefinition, *member);
225                     AST::replaceWith<AST::EnumerationMemberLiteral>(dotExpression, WTFMove(enumerationMemberLiteral));
226                     return;
227                 }
228                 setError();
229                 return;
230             }
231         }
232     }
233
234     Visitor::visit(dotExpression);
235 }
236
237 void NameResolver::visit(AST::CallExpression& callExpression)
238 {
239     if (m_isResolvingCalls) {
240         if (!callExpression.hasOverloads()) {
241             if (auto* functions = m_nameContext.getFunctions(callExpression.name()))
242                 callExpression.setOverloads(*functions);
243             else {
244                 if (auto* types = m_nameContext.getTypes(callExpression.name())) {
245                     if (types->size() == 1) {
246                         if (auto* functions = m_nameContext.getFunctions("operator cast"_str)) {
247                             callExpression.setCastData((*types)[0].get());
248                             callExpression.setOverloads(*functions);
249                         }
250                     }
251                 }
252             }
253         }
254         if (!callExpression.hasOverloads()) {
255             setError();
256             return;
257         }
258     }
259     Visitor::visit(callExpression);
260 }
261
262 void NameResolver::visit(AST::EnumerationMemberLiteral& enumerationMemberLiteral)
263 {
264     if (enumerationMemberLiteral.enumerationMember())
265         return;
266
267     if (auto enumerationTypes = m_nameContext.getTypes(enumerationMemberLiteral.left())) {
268         ASSERT(enumerationTypes->size() == 1);
269         AST::NamedType& type = (*enumerationTypes)[0];
270         if (is<AST::EnumerationDefinition>(type)) {
271             AST::EnumerationDefinition& enumerationDefinition = downcast<AST::EnumerationDefinition>(type);
272             if (auto* member = enumerationDefinition.memberByName(enumerationMemberLiteral.right())) {
273                 enumerationMemberLiteral.setEnumerationMember(enumerationDefinition, *member);
274                 return;
275             }
276         }
277     }
278     
279     setError();
280 }
281
282 void NameResolver::visit(AST::NativeFunctionDeclaration& nativeFunctionDeclaration)
283 {
284     NameContext newNameContext(&m_nameContext);
285     NameResolver newNameResolver(newNameContext);
286     newNameResolver.Visitor::visit(nativeFunctionDeclaration);
287 }
288
289 // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198167 Make sure all the names have been resolved.
290
291 bool resolveNamesInTypes(Program& program, NameResolver& nameResolver)
292 {
293     for (auto& typeDefinition : program.typeDefinitions()) {
294         nameResolver.checkErrorAndVisit(typeDefinition);
295         if (nameResolver.error())
296             return false;
297     }
298     for (auto& structureDefinition : program.structureDefinitions()) {
299         nameResolver.checkErrorAndVisit(structureDefinition);
300         if (nameResolver.error())
301             return false;
302     }
303     for (auto& enumerationDefinition : program.enumerationDefinitions()) {
304         nameResolver.checkErrorAndVisit(enumerationDefinition);
305         if (nameResolver.error())
306             return false;
307     }
308     for (auto& nativeTypeDeclaration : program.nativeTypeDeclarations()) {
309         nameResolver.checkErrorAndVisit(nativeTypeDeclaration);
310         if (nameResolver.error())
311             return false;
312     }
313     return true;
314 }
315
316 bool resolveTypeNamesInFunctions(Program& program, NameResolver& nameResolver)
317 {
318     for (auto& functionDefinition : program.functionDefinitions()) {
319         nameResolver.setCurrentFunctionDefinition(&functionDefinition);
320         nameResolver.checkErrorAndVisit(functionDefinition);
321         if (nameResolver.error())
322             return false;
323     }
324     nameResolver.setCurrentFunctionDefinition(nullptr);
325     for (auto& nativeFunctionDeclaration : program.nativeFunctionDeclarations()) {
326         nameResolver.checkErrorAndVisit(nativeFunctionDeclaration);
327         if (nameResolver.error())
328             return false;
329     }
330     return true;
331 }
332
333 bool resolveCallsInFunctions(Program& program, NameResolver& nameResolver)
334 {
335     nameResolver.setIsResolvingCalls(true);
336     for (auto& functionDefinition : program.functionDefinitions()) {
337         nameResolver.setCurrentFunctionDefinition(&functionDefinition);
338         nameResolver.checkErrorAndVisit(functionDefinition);
339         if (nameResolver.error())
340             return false;
341     }
342     nameResolver.setCurrentFunctionDefinition(nullptr);
343     nameResolver.setIsResolvingCalls(false);
344     return true;
345 }
346
347 } // namespace WHLSL
348
349 } // namespace WebCore
350
351 #endif // ENABLE(WEBGPU)