9277f84b841dd87d0ebf503889db590c40a22d08
[WebKit-https.git] / Source / WebCore / Modules / webgpu / WHLSL / WHLSLVisitor.cpp
1 /*
2  * Copyright (C) 2019 Apple Inc. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  * 1. Redistributions of source code must retain the above copyright
8  *    notice, this list of conditions and the following disclaimer.
9  * 2. Redistributions in binary form must reproduce the above copyright
10  *    notice, this list of conditions and the following disclaimer in the
11  *    documentation and/or other materials provided with the distribution.
12  *
13  * THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS ``AS IS''
14  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
15  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS
17  * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23  * THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 #include "config.h"
27 #include "WHLSLVisitor.h"
28
29 #if ENABLE(WEBGPU)
30
31 #include "WHLSLAST.h"
32
33 namespace WebCore {
34
35 namespace WHLSL {
36
37 void Visitor::visit(Program& program)
38 {
39     // These visiting functions might add new global statements, so don't use foreach syntax.
40     for (size_t i = 0; i < program.typeDefinitions().size(); ++i)
41         checkErrorAndVisit(program.typeDefinitions()[i]);
42     for (size_t i = 0; i < program.structureDefinitions().size(); ++i)
43         checkErrorAndVisit(program.structureDefinitions()[i]);
44     for (size_t i = 0; i < program.enumerationDefinitions().size(); ++i)
45         checkErrorAndVisit(program.enumerationDefinitions()[i]);
46     for (size_t i = 0; i < program.functionDefinitions().size(); ++i)
47         checkErrorAndVisit(program.functionDefinitions()[i]);
48     for (size_t i = 0; i < program.nativeFunctionDeclarations().size(); ++i)
49         checkErrorAndVisit(program.nativeFunctionDeclarations()[i]);
50     for (size_t i = 0; i < program.nativeTypeDeclarations().size(); ++i)
51         checkErrorAndVisit(program.nativeTypeDeclarations()[i]);
52 }
53
54 void Visitor::visit(AST::UnnamedType& unnamedType)
55 {
56     if (is<AST::TypeReference>(unnamedType))
57         checkErrorAndVisit(downcast<AST::TypeReference>(unnamedType));
58     else if (is<AST::PointerType>(unnamedType))
59         checkErrorAndVisit(downcast<AST::PointerType>(unnamedType));
60     else if (is<AST::ArrayReferenceType>(unnamedType))
61         checkErrorAndVisit(downcast<AST::ArrayReferenceType>(unnamedType));
62     else {
63         ASSERT(is<AST::ArrayType>(unnamedType));
64         checkErrorAndVisit(downcast<AST::ArrayType>(unnamedType));
65     }
66 }
67
68 void Visitor::visit(AST::NamedType& namedType)
69 {
70     if (is<AST::TypeDefinition>(namedType))
71         checkErrorAndVisit(downcast<AST::TypeDefinition>(namedType));
72     else if (is<AST::StructureDefinition>(namedType))
73         checkErrorAndVisit(downcast<AST::StructureDefinition>(namedType));
74     else if (is<AST::EnumerationDefinition>(namedType))
75         checkErrorAndVisit(downcast<AST::EnumerationDefinition>(namedType));
76     else {
77         ASSERT(is<AST::NativeTypeDeclaration>(namedType));
78         checkErrorAndVisit(downcast<AST::NativeTypeDeclaration>(namedType));
79     }
80 }
81
82 void Visitor::visit(AST::TypeDefinition& typeDefinition)
83 {
84     checkErrorAndVisit(typeDefinition.type());
85 }
86
87 void Visitor::visit(AST::StructureDefinition& structureDefinition)
88 {
89     for (auto& structureElement : structureDefinition.structureElements())
90         checkErrorAndVisit(structureElement);
91 }
92
93 void Visitor::visit(AST::EnumerationDefinition& enumerationDefinition)
94 {
95     checkErrorAndVisit(enumerationDefinition.type());
96     for (auto& enumerationMember : enumerationDefinition.enumerationMembers())
97         checkErrorAndVisit(enumerationMember);
98 }
99
100 void Visitor::visit(AST::FunctionDefinition& functionDefinition)
101 {
102     checkErrorAndVisit(static_cast<AST::FunctionDeclaration&>(functionDefinition));
103     checkErrorAndVisit(functionDefinition.block());
104 }
105
106 void Visitor::visit(AST::NativeFunctionDeclaration& nativeFunctionDeclaration)
107 {
108     checkErrorAndVisit(static_cast<AST::FunctionDeclaration&>(nativeFunctionDeclaration));
109 }
110
111 void Visitor::visit(AST::NativeTypeDeclaration& nativeTypeDeclaration)
112 {
113     for (auto& typeArgument : nativeTypeDeclaration.typeArguments())
114         checkErrorAndVisit(typeArgument);
115 }
116
117 void Visitor::visit(AST::TypeReference& typeReference)
118 {
119     for (auto& typeArgument : typeReference.typeArguments())
120         checkErrorAndVisit(typeArgument);
121     if (typeReference.maybeResolvedType() && is<AST::TypeDefinition>(typeReference.resolvedType())) {
122         auto& typeDefinition = downcast<AST::TypeDefinition>(typeReference.resolvedType());
123         checkErrorAndVisit(typeDefinition.type());
124     }
125 }
126
127 void Visitor::visit(AST::PointerType& pointerType)
128 {
129     checkErrorAndVisit(static_cast<AST::ReferenceType&>(pointerType));
130 }
131
132 void Visitor::visit(AST::ArrayReferenceType& arrayReferenceType)
133 {
134     checkErrorAndVisit(static_cast<AST::ReferenceType&>(arrayReferenceType));
135 }
136
137 void Visitor::visit(AST::ArrayType& arrayType)
138 {
139     checkErrorAndVisit(arrayType.type());
140 }
141
142 void Visitor::visit(AST::StructureElement& structureElement)
143 {
144     checkErrorAndVisit(structureElement.type());
145     if (structureElement.semantic())
146         checkErrorAndVisit(*structureElement.semantic());
147 }
148
149 void Visitor::visit(AST::EnumerationMember& enumerationMember)
150 {
151     if (enumerationMember.value())
152         checkErrorAndVisit(*enumerationMember.value());
153 }
154
155 void Visitor::visit(AST::FunctionDeclaration& functionDeclaration)
156 {
157     checkErrorAndVisit(functionDeclaration.attributeBlock());
158     checkErrorAndVisit(functionDeclaration.type());
159     for (auto& parameter : functionDeclaration.parameters())
160         checkErrorAndVisit(parameter);
161     if (functionDeclaration.semantic())
162         checkErrorAndVisit(*functionDeclaration.semantic());
163 }
164
165 void Visitor::visit(AST::TypeArgument& typeArgument)
166 {
167     WTF::visit(WTF::makeVisitor([&](AST::ConstantExpression& constantExpression) {
168         checkErrorAndVisit(constantExpression);
169     }, [&](UniqueRef<AST::TypeReference>& typeReference) {
170         checkErrorAndVisit(typeReference);
171     }), typeArgument);
172 }
173
174 void Visitor::visit(AST::ReferenceType& referenceType)
175 {
176     checkErrorAndVisit(referenceType.elementType());
177 }
178
179 void Visitor::visit(AST::Semantic& semantic)
180 {
181     WTF::visit(WTF::makeVisitor([&](AST::BuiltInSemantic& builtInSemantic) {
182         checkErrorAndVisit(builtInSemantic);
183     }, [&](AST::ResourceSemantic& resourceSemantic) {
184         checkErrorAndVisit(resourceSemantic);
185     }, [&](AST::SpecializationConstantSemantic& specializationConstantSemantic) {
186         checkErrorAndVisit(specializationConstantSemantic);
187     }, [&](AST::StageInOutSemantic& stageInOutSemantic) {
188         checkErrorAndVisit(stageInOutSemantic);
189     }), semantic);
190 }
191
192 void Visitor::visit(AST::ConstantExpression& constantExpression)
193 {
194     constantExpression.visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) {
195         checkErrorAndVisit(integerLiteral);
196     }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) {
197         checkErrorAndVisit(unsignedIntegerLiteral);
198     }, [&](AST::FloatLiteral& floatLiteral) {
199         checkErrorAndVisit(floatLiteral);
200     }, [&](AST::NullLiteral& nullLiteral) {
201         checkErrorAndVisit(nullLiteral);
202     }, [&](AST::BooleanLiteral& booleanLiteral) {
203         checkErrorAndVisit(booleanLiteral);
204     }, [&](AST::EnumerationMemberLiteral& enumerationMemberLiteral) {
205         checkErrorAndVisit(enumerationMemberLiteral);
206     }));
207 }
208
209 void Visitor::visit(AST::AttributeBlock& attributeBlock)
210 {
211     for (auto& functionAttribute : attributeBlock)
212         checkErrorAndVisit(functionAttribute);
213 }
214
215 void Visitor::visit(AST::BuiltInSemantic&)
216 {
217 }
218
219 void Visitor::visit(AST::ResourceSemantic&)
220 {
221 }
222
223 void Visitor::visit(AST::SpecializationConstantSemantic&)
224 {
225 }
226
227 void Visitor::visit(AST::StageInOutSemantic&)
228 {
229 }
230
231 void Visitor::visit(AST::IntegerLiteral& integerLiteral)
232 {
233     checkErrorAndVisit(integerLiteral.type());
234 }
235
236 void Visitor::visit(AST::UnsignedIntegerLiteral& unsignedIntegerLiteral)
237 {
238     checkErrorAndVisit(unsignedIntegerLiteral.type());
239 }
240
241 void Visitor::visit(AST::FloatLiteral& floatLiteral)
242 {
243     checkErrorAndVisit(floatLiteral.type());
244 }
245
246 void Visitor::visit(AST::NullLiteral& nullLiteral)
247 {
248     checkErrorAndVisit(nullLiteral.type());
249 }
250
251 void Visitor::visit(AST::BooleanLiteral&)
252 {
253 }
254
255 void Visitor::visit(AST::IntegerLiteralType& integerLiteralType)
256 {
257     if (integerLiteralType.maybeResolvedType())
258         checkErrorAndVisit(integerLiteralType.resolvedType());
259     checkErrorAndVisit(integerLiteralType.preferredType());
260 }
261
262 void Visitor::visit(AST::UnsignedIntegerLiteralType& unsignedIntegerLiteralType)
263 {
264     if (unsignedIntegerLiteralType.maybeResolvedType())
265         checkErrorAndVisit(unsignedIntegerLiteralType.resolvedType());
266     checkErrorAndVisit(unsignedIntegerLiteralType.preferredType());
267 }
268
269 void Visitor::visit(AST::FloatLiteralType& floatLiteralType)
270 {
271     if (floatLiteralType.maybeResolvedType())
272         checkErrorAndVisit(floatLiteralType.resolvedType());
273     checkErrorAndVisit(floatLiteralType.preferredType());
274 }
275
276 void Visitor::visit(AST::NullLiteralType& nullLiteralType)
277 {
278     if (nullLiteralType.maybeResolvedType())
279         checkErrorAndVisit(nullLiteralType.resolvedType());
280 }
281
282 void Visitor::visit(AST::EnumerationMemberLiteral&)
283 {
284 }
285
286 void Visitor::visit(AST::FunctionAttribute& functionAttribute)
287 {
288     WTF::visit(WTF::makeVisitor([&](AST::NumThreadsFunctionAttribute& numThreadsFunctionAttribute) {
289         checkErrorAndVisit(numThreadsFunctionAttribute);
290     }), functionAttribute);
291 }
292
293 void Visitor::visit(AST::NumThreadsFunctionAttribute&)
294 {
295 }
296
297 void Visitor::visit(AST::Block& block)
298 {
299     for (auto& statement : block.statements())
300         checkErrorAndVisit(statement);
301 }
302
303 void Visitor::visit(AST::Statement& statement)
304 {
305     if (is<AST::Block>(statement))
306         checkErrorAndVisit(downcast<AST::Block>(statement));
307     else if (is<AST::Break>(statement))
308         checkErrorAndVisit(downcast<AST::Break>(statement));
309     else if (is<AST::Continue>(statement))
310         checkErrorAndVisit(downcast<AST::Continue>(statement));
311     else if (is<AST::DoWhileLoop>(statement))
312         checkErrorAndVisit(downcast<AST::DoWhileLoop>(statement));
313     else if (is<AST::EffectfulExpressionStatement>(statement))
314         checkErrorAndVisit(downcast<AST::EffectfulExpressionStatement>(statement));
315     else if (is<AST::Fallthrough>(statement))
316         checkErrorAndVisit(downcast<AST::Fallthrough>(statement));
317     else if (is<AST::ForLoop>(statement))
318         checkErrorAndVisit(downcast<AST::ForLoop>(statement));
319     else if (is<AST::IfStatement>(statement))
320         checkErrorAndVisit(downcast<AST::IfStatement>(statement));
321     else if (is<AST::Return>(statement))
322         checkErrorAndVisit(downcast<AST::Return>(statement));
323     else if (is<AST::SwitchCase>(statement))
324         checkErrorAndVisit(downcast<AST::SwitchCase>(statement));
325     else if (is<AST::SwitchStatement>(statement))
326         checkErrorAndVisit(downcast<AST::SwitchStatement>(statement));
327     else if (is<AST::Trap>(statement))
328         checkErrorAndVisit(downcast<AST::Trap>(statement));
329     else if (is<AST::VariableDeclarationsStatement>(statement))
330         checkErrorAndVisit(downcast<AST::VariableDeclarationsStatement>(statement));
331     else {
332         ASSERT(is<AST::WhileLoop>(statement));
333         checkErrorAndVisit(downcast<AST::WhileLoop>(statement));
334     }
335 }
336
337 void Visitor::visit(AST::Break&)
338 {
339 }
340
341 void Visitor::visit(AST::Continue&)
342 {
343 }
344
345 void Visitor::visit(AST::DoWhileLoop& doWhileLoop)
346 {
347     checkErrorAndVisit(doWhileLoop.body());
348     checkErrorAndVisit(doWhileLoop.conditional());
349 }
350
351 void Visitor::visit(AST::Expression& expression)
352 {
353     if (is<AST::AssignmentExpression>(expression))
354         checkErrorAndVisit(downcast<AST::AssignmentExpression>(expression));
355     else if (is<AST::BooleanLiteral>(expression))
356         checkErrorAndVisit(downcast<AST::BooleanLiteral>(expression));
357     else if (is<AST::CallExpression>(expression))
358         checkErrorAndVisit(downcast<AST::CallExpression>(expression));
359     else if (is<AST::CommaExpression>(expression))
360         checkErrorAndVisit(downcast<AST::CommaExpression>(expression));
361     else if (is<AST::DereferenceExpression>(expression))
362         checkErrorAndVisit(downcast<AST::DereferenceExpression>(expression));
363     else if (is<AST::FloatLiteral>(expression))
364         checkErrorAndVisit(downcast<AST::FloatLiteral>(expression));
365     else if (is<AST::IntegerLiteral>(expression))
366         checkErrorAndVisit(downcast<AST::IntegerLiteral>(expression));
367     else if (is<AST::LogicalExpression>(expression))
368         checkErrorAndVisit(downcast<AST::LogicalExpression>(expression));
369     else if (is<AST::LogicalNotExpression>(expression))
370         checkErrorAndVisit(downcast<AST::LogicalNotExpression>(expression));
371     else if (is<AST::MakeArrayReferenceExpression>(expression))
372         checkErrorAndVisit(downcast<AST::MakeArrayReferenceExpression>(expression));
373     else if (is<AST::MakePointerExpression>(expression))
374         checkErrorAndVisit(downcast<AST::MakePointerExpression>(expression));
375     else if (is<AST::NullLiteral>(expression))
376         checkErrorAndVisit(downcast<AST::NullLiteral>(expression));
377     else if (is<AST::DotExpression>(expression))
378         checkErrorAndVisit(downcast<AST::DotExpression>(expression));
379     else if (is<AST::IndexExpression>(expression))
380         checkErrorAndVisit(downcast<AST::IndexExpression>(expression));
381     else if (is<AST::ReadModifyWriteExpression>(expression))
382         checkErrorAndVisit(downcast<AST::ReadModifyWriteExpression>(expression));
383     else if (is<AST::TernaryExpression>(expression))
384         checkErrorAndVisit(downcast<AST::TernaryExpression>(expression));
385     else if (is<AST::UnsignedIntegerLiteral>(expression))
386         checkErrorAndVisit(downcast<AST::UnsignedIntegerLiteral>(expression));
387     else if (is<AST::EnumerationMemberLiteral>(expression))
388         checkErrorAndVisit(downcast<AST::EnumerationMemberLiteral>(expression));
389     else {
390         ASSERT(is<AST::VariableReference>(expression));
391         checkErrorAndVisit(downcast<AST::VariableReference>(expression));
392     }
393 }
394
395 void Visitor::visit(AST::DotExpression& dotExpression)
396 {
397     checkErrorAndVisit(static_cast<AST::PropertyAccessExpression&>(dotExpression));
398 }
399
400 void Visitor::visit(AST::IndexExpression& indexExpression)
401 {
402     checkErrorAndVisit(indexExpression.indexExpression());
403     checkErrorAndVisit(static_cast<AST::PropertyAccessExpression&>(indexExpression));
404 }
405
406 void Visitor::visit(AST::PropertyAccessExpression& expression)
407 {
408     checkErrorAndVisit(expression.base());
409 }
410
411 void Visitor::visit(AST::EffectfulExpressionStatement& effectfulExpressionStatement)
412 {
413     checkErrorAndVisit(effectfulExpressionStatement.effectfulExpression());
414 }
415
416 void Visitor::visit(AST::Fallthrough&)
417 {
418 }
419
420 void Visitor::visit(AST::ForLoop& forLoop)
421 {
422     WTF::visit(WTF::makeVisitor([&](AST::VariableDeclarationsStatement& variableDeclarationsStatement) {
423         checkErrorAndVisit(variableDeclarationsStatement);
424     }, [&](UniqueRef<AST::Expression>& expression) {
425         checkErrorAndVisit(expression);
426     }), forLoop.initialization());
427     if (forLoop.condition())
428         checkErrorAndVisit(*forLoop.condition());
429     if (forLoop.increment())
430         checkErrorAndVisit(*forLoop.increment());
431     checkErrorAndVisit(forLoop.body());
432 }
433
434 void Visitor::visit(AST::IfStatement& ifStatement)
435 {
436     checkErrorAndVisit(ifStatement.conditional());
437     checkErrorAndVisit(ifStatement.body());
438     if (ifStatement.elseBody())
439         checkErrorAndVisit(*ifStatement.elseBody());
440 }
441
442 void Visitor::visit(AST::Return& returnStatement)
443 {
444     if (returnStatement.value())
445         checkErrorAndVisit(*returnStatement.value());
446 }
447
448 void Visitor::visit(AST::SwitchCase& switchCase)
449 {
450     if (switchCase.value())
451         checkErrorAndVisit(*switchCase.value());
452     checkErrorAndVisit(switchCase.block());
453 }
454
455 void Visitor::visit(AST::SwitchStatement& switchStatement)
456 {
457     checkErrorAndVisit(switchStatement.value());
458     for (auto& switchCase : switchStatement.switchCases())
459         checkErrorAndVisit(switchCase);
460 }
461
462 void Visitor::visit(AST::Trap&)
463 {
464 }
465
466 void Visitor::visit(AST::VariableDeclarationsStatement& variableDeclarationsStatement)
467 {
468     for (auto& variableDeclaration : variableDeclarationsStatement.variableDeclarations())
469         checkErrorAndVisit(variableDeclaration.get());
470 }
471
472 void Visitor::visit(AST::WhileLoop& whileLoop)
473 {
474     checkErrorAndVisit(whileLoop.conditional());
475     checkErrorAndVisit(whileLoop.body());
476 }
477
478 void Visitor::visit(AST::VariableDeclaration& variableDeclaration)
479 {
480     if (variableDeclaration.type())
481         checkErrorAndVisit(*variableDeclaration.type());
482     if (variableDeclaration.semantic())
483         checkErrorAndVisit(*variableDeclaration.semantic());
484     if (variableDeclaration.initializer())
485         checkErrorAndVisit(*variableDeclaration.initializer());
486 }
487
488 void Visitor::visit(AST::AssignmentExpression& assignmentExpression)
489 {
490     checkErrorAndVisit(assignmentExpression.left());
491     checkErrorAndVisit(assignmentExpression.right());
492 }
493
494 void Visitor::visit(AST::CallExpression& callExpression)
495 {
496     for (auto& argument : callExpression.arguments())
497         checkErrorAndVisit(argument);
498     if (callExpression.castReturnType())
499         checkErrorAndVisit(*callExpression.castReturnType());
500 }
501
502 void Visitor::visit(AST::CommaExpression& commaExpression)
503 {
504     for (auto& expression : commaExpression.list())
505         checkErrorAndVisit(expression);
506 }
507
508 void Visitor::visit(AST::DereferenceExpression& dereferenceExpression)
509 {
510     checkErrorAndVisit(dereferenceExpression.pointer());
511 }
512
513 void Visitor::visit(AST::LogicalExpression& logicalExpression)
514 {
515     checkErrorAndVisit(logicalExpression.left());
516     checkErrorAndVisit(logicalExpression.right());
517 }
518
519 void Visitor::visit(AST::LogicalNotExpression& logicalNotExpression)
520 {
521     checkErrorAndVisit(logicalNotExpression.operand());
522 }
523
524 void Visitor::visit(AST::MakeArrayReferenceExpression& makeArrayReferenceExpression)
525 {
526     checkErrorAndVisit(makeArrayReferenceExpression.leftValue());
527 }
528
529 void Visitor::visit(AST::MakePointerExpression& makePointerExpression)
530 {
531     checkErrorAndVisit(makePointerExpression.leftValue());
532 }
533
534 void Visitor::visit(AST::ReadModifyWriteExpression& readModifyWriteExpression)
535 {
536     checkErrorAndVisit(readModifyWriteExpression.leftValue());
537     checkErrorAndVisit(readModifyWriteExpression.oldValue());
538     checkErrorAndVisit(readModifyWriteExpression.newValue());
539     checkErrorAndVisit(readModifyWriteExpression.newValueExpression());
540     checkErrorAndVisit(readModifyWriteExpression.resultExpression());
541 }
542
543 void Visitor::visit(AST::TernaryExpression& ternaryExpression)
544 {
545     checkErrorAndVisit(ternaryExpression.predicate());
546     checkErrorAndVisit(ternaryExpression.bodyExpression());
547     checkErrorAndVisit(ternaryExpression.elseExpression());
548 }
549
550 void Visitor::visit(AST::VariableReference&)
551 {
552 }
553
554 } // namespace WHLSL
555
556 } // namespace WebCore
557
558 #endif // ENABLE(WEBGPU)