543f8c0be7009bc990a35be677198c617bfcd0ac
[WebKit-https.git] / Source / WebCore / Modules / webgpu / WHLSL / Metal / WHLSLFunctionWriter.cpp
1 /*
2  * Copyright (C) 2019 Apple Inc. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  * 1. Redistributions of source code must retain the above copyright
8  *    notice, this list of conditions and the following disclaimer.
9  * 2. Redistributions in binary form must reproduce the above copyright
10  *    notice, this list of conditions and the following disclaimer in the
11  *    documentation and/or other materials provided with the distribution.
12  *
13  * THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS ``AS IS''
14  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
15  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS
17  * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23  * THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 #include "config.h"
27 #include "WHLSLFunctionWriter.h"
28
29 #if ENABLE(WEBGPU)
30
31 #include "NotImplemented.h"
32 #include "WHLSLArrayReferenceType.h"
33 #include "WHLSLArrayType.h"
34 #include "WHLSLAssignmentExpression.h"
35 #include "WHLSLBooleanLiteral.h"
36 #include "WHLSLBuiltInSemantic.h"
37 #include "WHLSLCallExpression.h"
38 #include "WHLSLCommaExpression.h"
39 #include "WHLSLDereferenceExpression.h"
40 #include "WHLSLDoWhileLoop.h"
41 #include "WHLSLEffectfulExpressionStatement.h"
42 #include "WHLSLEntryPointScaffolding.h"
43 #include "WHLSLEntryPointType.h"
44 #include "WHLSLFloatLiteral.h"
45 #include "WHLSLForLoop.h"
46 #include "WHLSLFunctionDeclaration.h"
47 #include "WHLSLFunctionDefinition.h"
48 #include "WHLSLIfStatement.h"
49 #include "WHLSLIntegerLiteral.h"
50 #include "WHLSLLogicalExpression.h"
51 #include "WHLSLLogicalNotExpression.h"
52 #include "WHLSLMakeArrayReferenceExpression.h"
53 #include "WHLSLMakePointerExpression.h"
54 #include "WHLSLNativeFunctionDeclaration.h"
55 #include "WHLSLNativeFunctionWriter.h"
56 #include "WHLSLNativeTypeDeclaration.h"
57 #include "WHLSLPointerType.h"
58 #include "WHLSLProgram.h"
59 #include "WHLSLReturn.h"
60 #include "WHLSLSwitchCase.h"
61 #include "WHLSLSwitchStatement.h"
62 #include "WHLSLTernaryExpression.h"
63 #include "WHLSLTypeNamer.h"
64 #include "WHLSLUnsignedIntegerLiteral.h"
65 #include "WHLSLVariableDeclaration.h"
66 #include "WHLSLVariableDeclarationsStatement.h"
67 #include "WHLSLVariableReference.h"
68 #include "WHLSLVisitor.h"
69 #include "WHLSLWhileLoop.h"
70 #include <wtf/HashMap.h>
71 #include <wtf/text/StringBuilder.h>
72
73 namespace WebCore {
74
75 namespace WHLSL {
76
77 namespace Metal {
78
79 class FunctionDeclarationWriter : public Visitor {
80 public:
81     FunctionDeclarationWriter(TypeNamer& typeNamer, HashMap<AST::FunctionDeclaration*, String>& functionMapping)
82         : m_typeNamer(typeNamer)
83         , m_functionMapping(functionMapping)
84     {
85     }
86
87     virtual ~FunctionDeclarationWriter() = default;
88
89     String toString() { return m_stringBuilder.toString(); }
90
91     void visit(AST::FunctionDeclaration&) override;
92
93 private:
94     TypeNamer& m_typeNamer;
95     HashMap<AST::FunctionDeclaration*, String>& m_functionMapping;
96     StringBuilder m_stringBuilder;
97 };
98
99 void FunctionDeclarationWriter::visit(AST::FunctionDeclaration& functionDeclaration)
100 {
101     if (functionDeclaration.entryPointType())
102         return;
103
104     auto iterator = m_functionMapping.find(&functionDeclaration);
105     ASSERT(iterator != m_functionMapping.end());
106     m_stringBuilder.append(makeString(m_typeNamer.mangledNameForType(functionDeclaration.type()), ' ', iterator->value, '('));
107     for (size_t i = 0; i < functionDeclaration.parameters().size(); ++i) {
108         if (i)
109             m_stringBuilder.append(", ");
110         m_stringBuilder.append(m_typeNamer.mangledNameForType(*functionDeclaration.parameters()[i]->type()));
111     }
112     m_stringBuilder.append(");\n");
113 }
114
115 class FunctionDefinitionWriter : public Visitor {
116 public:
117     FunctionDefinitionWriter(Intrinsics& intrinsics, TypeNamer& typeNamer, HashMap<AST::FunctionDeclaration*, String>& functionMapping, Layout& layout)
118         : m_intrinsics(intrinsics)
119         , m_typeNamer(typeNamer)
120         , m_functionMapping(functionMapping)
121         , m_layout(layout)
122     {
123     }
124
125     virtual ~FunctionDefinitionWriter() = default;
126
127     String toString() { return m_stringBuilder.toString(); }
128
129     void visit(AST::NativeFunctionDeclaration&) override;
130     void visit(AST::FunctionDefinition&) override;
131
132 protected:
133     virtual std::unique_ptr<EntryPointScaffolding> createEntryPointScaffolding(AST::FunctionDefinition&) = 0;
134
135     void visit(AST::FunctionDeclaration&) override;
136     void visit(AST::Statement&) override;
137     void visit(AST::Block&) override;
138     void visit(AST::Break&) override;
139     void visit(AST::Continue&) override;
140     void visit(AST::DoWhileLoop&) override;
141     void visit(AST::EffectfulExpressionStatement&) override;
142     void visit(AST::Fallthrough&) override;
143     void visit(AST::ForLoop&) override;
144     void visit(AST::IfStatement&) override;
145     void visit(AST::Return&) override;
146     void visit(AST::SwitchStatement&) override;
147     void visit(AST::SwitchCase&) override;
148     void visit(AST::Trap&) override;
149     void visit(AST::VariableDeclarationsStatement&) override;
150     void visit(AST::WhileLoop&) override;
151     void visit(AST::IntegerLiteral&) override;
152     void visit(AST::UnsignedIntegerLiteral&) override;
153     void visit(AST::FloatLiteral&) override;
154     void visit(AST::NullLiteral&) override;
155     void visit(AST::BooleanLiteral&) override;
156     void visit(AST::EnumerationMemberLiteral&) override;
157     void visit(AST::Expression&) override;
158     void visit(AST::DotExpression&) override;
159     void visit(AST::IndexExpression&) override;
160     void visit(AST::PropertyAccessExpression&) override;
161     void visit(AST::VariableDeclaration&) override;
162     void visit(AST::AssignmentExpression&) override;
163     void visit(AST::CallExpression&) override;
164     void visit(AST::CommaExpression&) override;
165     void visit(AST::DereferenceExpression&) override;
166     void visit(AST::LogicalExpression&) override;
167     void visit(AST::LogicalNotExpression&) override;
168     void visit(AST::MakeArrayReferenceExpression&) override;
169     void visit(AST::MakePointerExpression&) override;
170     void visit(AST::ReadModifyWriteExpression&) override;
171     void visit(AST::TernaryExpression&) override;
172     void visit(AST::VariableReference&) override;
173
174     String constantExpressionString(AST::ConstantExpression&);
175
176     String generateNextVariableName()
177     {
178         return makeString("variable", m_variableCount++);
179     }
180
181     Intrinsics& m_intrinsics;
182     TypeNamer& m_typeNamer;
183     HashMap<AST::FunctionDeclaration*, String>& m_functionMapping;
184     HashMap<AST::VariableDeclaration*, String> m_variableMapping;
185     StringBuilder m_stringBuilder;
186     Vector<String> m_stack;
187     std::unique_ptr<EntryPointScaffolding> m_entryPointScaffolding;
188     Layout& m_layout;
189     unsigned m_variableCount { 0 };
190 };
191
192 void FunctionDefinitionWriter::visit(AST::NativeFunctionDeclaration& nativeFunctionDeclaration)
193 {
194     auto iterator = m_functionMapping.find(&nativeFunctionDeclaration);
195     ASSERT(iterator != m_functionMapping.end());
196     m_stringBuilder.append(writeNativeFunction(nativeFunctionDeclaration, iterator->value, m_intrinsics, m_typeNamer));
197 }
198
199 void FunctionDefinitionWriter::visit(AST::FunctionDefinition& functionDefinition)
200 {
201     auto iterator = m_functionMapping.find(&functionDefinition);
202     ASSERT(iterator != m_functionMapping.end());
203     if (functionDefinition.entryPointType()) {
204         auto entryPointScaffolding = createEntryPointScaffolding(functionDefinition);
205         if (!entryPointScaffolding)
206             return;
207         m_entryPointScaffolding = WTFMove(entryPointScaffolding);
208         m_stringBuilder.append(m_entryPointScaffolding->helperTypes());
209         m_stringBuilder.append('\n');
210         m_stringBuilder.append(makeString(m_entryPointScaffolding->signature(iterator->value), " {\n"));
211         m_stringBuilder.append(m_entryPointScaffolding->unpack());
212         for (size_t i = 0; i < functionDefinition.parameters().size(); ++i) {
213             auto addResult = m_variableMapping.add(&functionDefinition.parameters()[i], m_entryPointScaffolding->parameterVariables()[i]);
214             ASSERT_UNUSED(addResult, addResult.isNewEntry);
215         }
216         checkErrorAndVisit(functionDefinition.block());
217         ASSERT(m_stack.isEmpty());
218         m_stringBuilder.append("}\n");
219         m_entryPointScaffolding = nullptr;
220     } else {
221         ASSERT(m_entryPointScaffolding == nullptr);
222         m_stringBuilder.append(makeString(m_typeNamer.mangledNameForType(functionDefinition.type()), ' ', iterator->value, '('));
223         for (size_t i = 0; i < functionDefinition.parameters().size(); ++i) {
224             auto& parameter = functionDefinition.parameters()[i];
225             if (i)
226                 m_stringBuilder.append(", ");
227             auto parameterName = generateNextVariableName();
228             auto addResult = m_variableMapping.add(&parameter, parameterName);
229             ASSERT_UNUSED(addResult, addResult.isNewEntry);
230             m_stringBuilder.append(makeString(m_typeNamer.mangledNameForType(*parameter->type()), ' ', parameterName));
231         }
232         m_stringBuilder.append(") {\n");
233         checkErrorAndVisit(functionDefinition.block());
234         ASSERT(m_stack.isEmpty());
235         m_stringBuilder.append("}\n");
236     }
237 }
238
239 void FunctionDefinitionWriter::visit(AST::FunctionDeclaration&)
240 {
241     ASSERT_NOT_REACHED();
242 }
243
244 void FunctionDefinitionWriter::visit(AST::Statement& statement)
245 {
246     Visitor::visit(statement);
247 }
248
249 void FunctionDefinitionWriter::visit(AST::Block& block)
250 {
251     m_stringBuilder.append("{\n");
252     for (auto& statement : block.statements())
253         checkErrorAndVisit(statement);
254     m_stringBuilder.append("}\n");
255 }
256
257 void FunctionDefinitionWriter::visit(AST::Break&)
258 {
259     m_stringBuilder.append("break;\n");
260 }
261
262 void FunctionDefinitionWriter::visit(AST::Continue&)
263 {
264     // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195808 Figure out which loop we're in, and run the increment code
265     notImplemented();
266 }
267
268 void FunctionDefinitionWriter::visit(AST::DoWhileLoop& doWhileLoop)
269 {
270     m_stringBuilder.append("do {\n");
271     checkErrorAndVisit(doWhileLoop.body());
272     checkErrorAndVisit(doWhileLoop.conditional());
273     m_stringBuilder.append(makeString("if (!", m_stack.takeLast(), ") break;\n"));
274     m_stringBuilder.append(makeString("} while(true);\n"));
275 }
276
277 void FunctionDefinitionWriter::visit(AST::EffectfulExpressionStatement& effectfulExpressionStatement)
278 {
279     checkErrorAndVisit(effectfulExpressionStatement.effectfulExpression());
280     m_stack.takeLast(); // The statement is already effectful, so we don't need to do anything with the result.
281 }
282
283 void FunctionDefinitionWriter::visit(AST::Fallthrough&)
284 {
285     m_stringBuilder.append("[[clang::fallthrough]];\n"); // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195808 Make sure this is okay. Alternatively, we could do nothing and just return here instead.
286 }
287
288 void FunctionDefinitionWriter::visit(AST::ForLoop& forLoop)
289 {
290     WTF::visit(WTF::makeVisitor([&](AST::VariableDeclarationsStatement& variableDeclarationsStatement) {
291         checkErrorAndVisit(variableDeclarationsStatement);
292     }, [&](UniqueRef<AST::Expression>& expression) {
293         checkErrorAndVisit(expression);
294         m_stack.takeLast(); // We don't need to do anything with the result.
295     }), forLoop.initialization());
296
297     m_stringBuilder.append("for ( ; ; ) {\n");
298     if (forLoop.condition()) {
299         checkErrorAndVisit(*forLoop.condition());
300         m_stringBuilder.append(makeString("if (!", m_stack.takeLast(), ") break;\n"));
301     }
302     checkErrorAndVisit(forLoop.body());
303     if (forLoop.increment()) {
304         checkErrorAndVisit(*forLoop.increment());
305         m_stack.takeLast();
306     }
307     m_stringBuilder.append("}\n");
308 }
309
310 void FunctionDefinitionWriter::visit(AST::IfStatement& ifStatement)
311 {
312     checkErrorAndVisit(ifStatement.conditional());
313     m_stringBuilder.append(makeString("if (", m_stack.takeLast(), ") {\n"));
314     checkErrorAndVisit(ifStatement.body());
315     if (ifStatement.elseBody()) {
316         m_stringBuilder.append("} else {\n");
317         checkErrorAndVisit(*ifStatement.elseBody());
318     }
319     m_stringBuilder.append("}\n");
320 }
321
322 void FunctionDefinitionWriter::visit(AST::Return& returnStatement)
323 {
324     if (returnStatement.value()) {
325         checkErrorAndVisit(*returnStatement.value());
326         if (m_entryPointScaffolding) {
327             auto variableName = generateNextVariableName();
328             m_stringBuilder.append(m_entryPointScaffolding->pack(m_stack.takeLast(), variableName));
329             m_stringBuilder.append(makeString("return ", variableName, ";\n"));
330         } else
331             m_stringBuilder.append(makeString("return ", m_stack.takeLast(), ";\n"));
332     } else
333         m_stringBuilder.append("return;\n");
334 }
335
336 void FunctionDefinitionWriter::visit(AST::SwitchStatement& switchStatement)
337 {
338     checkErrorAndVisit(switchStatement.value());
339
340     m_stringBuilder.append(makeString("switch (", m_stack.takeLast(), ") {"));
341     for (auto& switchCase : switchStatement.switchCases())
342         checkErrorAndVisit(switchCase);
343     m_stringBuilder.append("}\n");
344 }
345
346 void FunctionDefinitionWriter::visit(AST::SwitchCase& switchCase)
347 {
348     if (switchCase.value())
349         m_stringBuilder.append(makeString("case ", constantExpressionString(*switchCase.value()), ":\n"));
350     else
351         m_stringBuilder.append("default:\n");
352     checkErrorAndVisit(switchCase.block());
353     // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195812 Figure out whether we need to break or fallthrough.
354     notImplemented();
355 }
356
357 void FunctionDefinitionWriter::visit(AST::Trap&)
358 {
359     // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195811 Implement this
360     notImplemented();
361 }
362
363 void FunctionDefinitionWriter::visit(AST::VariableDeclarationsStatement& variableDeclarationsStatement)
364 {
365     Visitor::visit(variableDeclarationsStatement);
366 }
367
368 void FunctionDefinitionWriter::visit(AST::WhileLoop& whileLoop)
369 {
370     m_stringBuilder.append(makeString("while (true) {\n"));
371     checkErrorAndVisit(whileLoop.conditional());
372     m_stringBuilder.append(makeString("if (!", m_stack.takeLast(), ") break;\n"));
373     checkErrorAndVisit(whileLoop.body());
374     m_stringBuilder.append("}\n");
375 }
376
377 void FunctionDefinitionWriter::visit(AST::IntegerLiteral& integerLiteral)
378 {
379     auto variableName = generateNextVariableName();
380     auto mangledTypeName = m_typeNamer.mangledNameForType(integerLiteral.resolvedType());
381     m_stringBuilder.append(makeString(mangledTypeName, ' ', variableName, " = static_cast<", mangledTypeName, ">(", integerLiteral.value(), ");\n"));
382     m_stack.append(variableName);
383 }
384
385 void FunctionDefinitionWriter::visit(AST::UnsignedIntegerLiteral& unsignedIntegerLiteral)
386 {
387     auto variableName = generateNextVariableName();
388     auto mangledTypeName = m_typeNamer.mangledNameForType(unsignedIntegerLiteral.resolvedType());
389     m_stringBuilder.append(makeString(mangledTypeName, ' ', variableName, " = static_cast<", mangledTypeName, ">(", unsignedIntegerLiteral.value(), ");\n"));
390     m_stack.append(variableName);
391 }
392
393 void FunctionDefinitionWriter::visit(AST::FloatLiteral& floatLiteral)
394 {
395     auto variableName = generateNextVariableName();
396     auto mangledTypeName = m_typeNamer.mangledNameForType(floatLiteral.resolvedType());
397     m_stringBuilder.append(makeString(mangledTypeName, ' ', variableName, " = static_cast<", mangledTypeName, ">(", floatLiteral.value(), ");\n"));
398     m_stack.append(variableName);
399 }
400
401 void FunctionDefinitionWriter::visit(AST::NullLiteral& nullLiteral)
402 {
403     auto& unifyNode = nullLiteral.resolvedType().unifyNode();
404     ASSERT(is<AST::UnnamedType>(unifyNode));
405     auto& unnamedType = downcast<AST::UnnamedType>(unifyNode);
406     bool isArrayReferenceType = is<AST::ArrayReferenceType>(unnamedType);
407
408     auto variableName = generateNextVariableName();
409     m_stringBuilder.append(makeString(m_typeNamer.mangledNameForType(nullLiteral.resolvedType()), ' ', variableName, " = "));
410     if (isArrayReferenceType)
411         m_stringBuilder.append("{ nullptr, 0 }");
412     else
413         m_stringBuilder.append("nullptr");
414     m_stringBuilder.append(";\n");
415     m_stack.append(variableName);
416 }
417
418 void FunctionDefinitionWriter::visit(AST::BooleanLiteral& booleanLiteral)
419 {
420     auto variableName = generateNextVariableName();
421     auto mangledTypeName = m_typeNamer.mangledNameForType(booleanLiteral.resolvedType());
422     m_stringBuilder.append(makeString(mangledTypeName, ' ', variableName, " = static_cast<", mangledTypeName, ">(", booleanLiteral.value() ? "true" : "false", ");\n"));
423     m_stack.append(variableName);
424 }
425
426 void FunctionDefinitionWriter::visit(AST::EnumerationMemberLiteral& enumerationMemberLiteral)
427 {
428     ASSERT(enumerationMemberLiteral.enumerationDefinition());
429     ASSERT(enumerationMemberLiteral.enumerationDefinition());
430     auto variableName = generateNextVariableName();
431     auto mangledTypeName = m_typeNamer.mangledNameForType(enumerationMemberLiteral.resolvedType());
432     m_stringBuilder.append(makeString(mangledTypeName, ' ', variableName, " = ", mangledTypeName, '.', m_typeNamer.mangledNameForEnumerationMember(*enumerationMemberLiteral.enumerationMember()), ";\n"));
433     m_stack.append(variableName);
434 }
435
436 void FunctionDefinitionWriter::visit(AST::Expression& expression)
437 {
438     Visitor::visit(expression);
439 }
440
441 void FunctionDefinitionWriter::visit(AST::DotExpression&)
442 {
443     // This should be lowered already.
444     // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195788 Replace this with ASSERT_NOT_REACHED().
445     notImplemented();
446     m_stack.append("dummy");
447 }
448
449 void FunctionDefinitionWriter::visit(AST::IndexExpression&)
450 {
451     // This should be lowered already.
452     // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195788 Replace this with ASSERT_NOT_REACHED().
453     notImplemented();
454     m_stack.append("dummy");
455 }
456
457 void FunctionDefinitionWriter::visit(AST::PropertyAccessExpression&)
458 {
459     // This should be lowered already.
460     // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195788 Replace this with ASSERT_NOT_REACHED().
461     notImplemented();
462     m_stack.append("dummy");
463 }
464
465 void FunctionDefinitionWriter::visit(AST::VariableDeclaration& variableDeclaration)
466 {
467     ASSERT(variableDeclaration.type());
468     auto variableName = generateNextVariableName();
469     auto addResult = m_variableMapping.add(&variableDeclaration, variableName);
470     ASSERT_UNUSED(addResult, addResult.isNewEntry);
471     // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198160 Implement qualifiers.
472     if (variableDeclaration.initializer()) {
473         checkErrorAndVisit(*variableDeclaration.initializer());
474         m_stringBuilder.append(makeString(m_typeNamer.mangledNameForType(*variableDeclaration.type()), ' ', variableName, " = ", m_stack.takeLast(), ";\n"));
475     } else {
476         // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195771 Zero-fill the variable.
477         m_stringBuilder.append(makeString(m_typeNamer.mangledNameForType(*variableDeclaration.type()), ' ', variableName, ";\n"));
478     }
479 }
480
481 void FunctionDefinitionWriter::visit(AST::AssignmentExpression& assignmentExpression)
482 {
483     if (is<AST::DereferenceExpression>(assignmentExpression.left())) {
484         checkErrorAndVisit(downcast<AST::DereferenceExpression>(assignmentExpression.left()).pointer());
485         auto leftName = m_stack.takeLast();
486         checkErrorAndVisit(assignmentExpression.right());
487         auto rightName = m_stack.takeLast();
488         m_stringBuilder.append(makeString('*', leftName, " = ", rightName, ";\n"));
489         m_stack.append(rightName);
490         return;
491     }
492     checkErrorAndVisit(assignmentExpression.left());
493     auto leftName = m_stack.takeLast();
494     checkErrorAndVisit(assignmentExpression.right());
495     auto rightName = m_stack.takeLast();
496     m_stringBuilder.append(makeString(leftName, " = ", rightName, ";\n"));
497     m_stack.append(rightName);
498 }
499
500 void FunctionDefinitionWriter::visit(AST::CallExpression& callExpression)
501 {
502     Vector<String> argumentNames;
503     for (auto& argument : callExpression.arguments()) {
504         checkErrorAndVisit(argument);
505         argumentNames.append(m_stack.takeLast());
506     }
507     ASSERT(callExpression.function());
508     auto iterator = m_functionMapping.find(callExpression.function());
509     ASSERT(iterator != m_functionMapping.end());
510     auto variableName = generateNextVariableName();
511     m_stringBuilder.append(makeString(m_typeNamer.mangledNameForType(callExpression.resolvedType()), ' ', variableName, " = ", iterator->value, '('));
512     for (size_t i = 0; i < argumentNames.size(); ++i) {
513         if (i)
514             m_stringBuilder.append(", ");
515         m_stringBuilder.append(argumentNames[i]);
516     }
517     m_stringBuilder.append(");\n");
518     m_stack.append(variableName);
519 }
520
521 void FunctionDefinitionWriter::visit(AST::CommaExpression& commaExpression)
522 {
523     String result;
524     for (auto& expression : commaExpression.list()) {
525         checkErrorAndVisit(expression);
526         result = m_stack.takeLast();
527     }
528     m_stack.append(result);
529 }
530
531 void FunctionDefinitionWriter::visit(AST::DereferenceExpression& dereferenceExpression)
532 {
533     checkErrorAndVisit(dereferenceExpression.pointer());
534     auto right = m_stack.takeLast();
535     auto variableName = generateNextVariableName();
536     m_stringBuilder.append(makeString(m_typeNamer.mangledNameForType(dereferenceExpression.resolvedType()), ' ', variableName, " = *", right, ";\n"));
537     m_stack.append(variableName);
538 }
539
540 void FunctionDefinitionWriter::visit(AST::LogicalExpression& logicalExpression)
541 {
542     checkErrorAndVisit(logicalExpression.left());
543     auto left = m_stack.takeLast();
544     checkErrorAndVisit(logicalExpression.right());
545     auto right = m_stack.takeLast();
546     auto variableName = generateNextVariableName();
547     m_stringBuilder.append(makeString(m_typeNamer.mangledNameForType(logicalExpression.resolvedType()), ' ', variableName, " = ", left));
548     switch (logicalExpression.type()) {
549     case AST::LogicalExpression::Type::And:
550         m_stringBuilder.append(" && ");
551         break;
552     default:
553         ASSERT(logicalExpression.type() == AST::LogicalExpression::Type::Or);
554         m_stringBuilder.append(" || ");
555         break;
556     }
557     m_stringBuilder.append(makeString(right, ";\n"));
558     m_stack.append(variableName);
559 }
560
561 void FunctionDefinitionWriter::visit(AST::LogicalNotExpression& logicalNotExpression)
562 {
563     checkErrorAndVisit(logicalNotExpression.operand());
564     auto operand = m_stack.takeLast();
565     auto variableName = generateNextVariableName();
566     m_stringBuilder.append(makeString(m_typeNamer.mangledNameForType(logicalNotExpression.resolvedType()), ' ', variableName, " = !", operand, ";\n"));
567     m_stack.append(variableName);
568 }
569
570 void FunctionDefinitionWriter::visit(AST::MakeArrayReferenceExpression& makeArrayReferenceExpression)
571 {
572     checkErrorAndVisit(makeArrayReferenceExpression.leftValue());
573     auto lValue = m_stack.takeLast();
574     auto variableName = generateNextVariableName();
575     auto mangledTypeName = m_typeNamer.mangledNameForType(makeArrayReferenceExpression.resolvedType());
576     if (is<AST::PointerType>(makeArrayReferenceExpression.resolvedType()))
577         m_stringBuilder.append(makeString(mangledTypeName, ' ', variableName, " = { ", lValue, ", 1 };\n"));
578     else if (is<AST::ArrayType>(makeArrayReferenceExpression.resolvedType())) {
579         auto& arrayType = downcast<AST::ArrayType>(makeArrayReferenceExpression.resolvedType());
580         m_stringBuilder.append(makeString(mangledTypeName, ' ', variableName, " = { &(", lValue, "[0]), ", arrayType.numElements(), " };\n"));
581     } else
582         m_stringBuilder.append(makeString(mangledTypeName, ' ', variableName, " = { &", lValue, ", 1 };\n"));
583     m_stack.append(variableName);
584 }
585
586 void FunctionDefinitionWriter::visit(AST::MakePointerExpression& makePointerExpression)
587 {
588     checkErrorAndVisit(makePointerExpression.leftValue());
589     auto lValue = m_stack.takeLast();
590     auto variableName = generateNextVariableName();
591     m_stringBuilder.append(makeString(m_typeNamer.mangledNameForType(makePointerExpression.resolvedType()), ' ', variableName, " = &", lValue, ";\n"));
592     m_stack.append(variableName);
593 }
594
595 void FunctionDefinitionWriter::visit(AST::ReadModifyWriteExpression&)
596 {
597     // This should be lowered already.
598     ASSERT_NOT_REACHED();
599 }
600
601 void FunctionDefinitionWriter::visit(AST::TernaryExpression& ternaryExpression)
602 {
603     checkErrorAndVisit(ternaryExpression.predicate());
604     auto check = m_stack.takeLast();
605
606     auto variableName = generateNextVariableName();
607     m_stringBuilder.append(makeString(m_typeNamer.mangledNameForType(ternaryExpression.resolvedType()), ' ', variableName, ";\n"));
608
609     m_stringBuilder.append(makeString("if (", check, ") {\n"));
610     checkErrorAndVisit(ternaryExpression.bodyExpression());
611     m_stringBuilder.append(makeString(variableName, " = ", m_stack.takeLast(), ";\n"));
612     m_stringBuilder.append("} else {\n");
613     checkErrorAndVisit(ternaryExpression.elseExpression());
614     m_stringBuilder.append(makeString(variableName, " = ", m_stack.takeLast(), ";\n"));
615     m_stringBuilder.append("}\n");
616     m_stack.append(variableName);
617 }
618
619 void FunctionDefinitionWriter::visit(AST::VariableReference& variableReference)
620 {
621     ASSERT(variableReference.variable());
622     auto iterator = m_variableMapping.find(variableReference.variable());
623     ASSERT(iterator != m_variableMapping.end());
624     m_stack.append(iterator->value);
625 }
626
627 String FunctionDefinitionWriter::constantExpressionString(AST::ConstantExpression& constantExpression)
628 {
629     String result;
630     constantExpression.visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) {
631         result = makeString("", integerLiteral.value());
632     }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) {
633         result = makeString("", unsignedIntegerLiteral.value());
634     }, [&](AST::FloatLiteral& floatLiteral) {
635         result = makeString("", floatLiteral.value());
636     }, [&](AST::NullLiteral&) {
637         result = "nullptr"_str;
638     }, [&](AST::BooleanLiteral& booleanLiteral) {
639         result = booleanLiteral.value() ? "true"_str : "false"_str;
640     }, [&](AST::EnumerationMemberLiteral& enumerationMemberLiteral) {
641         ASSERT(enumerationMemberLiteral.enumerationDefinition());
642         ASSERT(enumerationMemberLiteral.enumerationDefinition());
643         result = makeString(m_typeNamer.mangledNameForType(*enumerationMemberLiteral.enumerationDefinition()), '.', m_typeNamer.mangledNameForEnumerationMember(*enumerationMemberLiteral.enumerationMember()));
644     }));
645     return result;
646 }
647
648 class RenderFunctionDefinitionWriter : public FunctionDefinitionWriter {
649 public:
650     RenderFunctionDefinitionWriter(Intrinsics& intrinsics, TypeNamer& typeNamer, HashMap<AST::FunctionDeclaration*, String>& functionMapping, MatchedRenderSemantics&& matchedSemantics, Layout& layout)
651         : FunctionDefinitionWriter(intrinsics, typeNamer, functionMapping, layout)
652         , m_matchedSemantics(WTFMove(matchedSemantics))
653     {
654     }
655
656 private:
657     std::unique_ptr<EntryPointScaffolding> createEntryPointScaffolding(AST::FunctionDefinition&) override;
658
659     MatchedRenderSemantics m_matchedSemantics;
660 };
661
662 std::unique_ptr<EntryPointScaffolding> RenderFunctionDefinitionWriter::createEntryPointScaffolding(AST::FunctionDefinition& functionDefinition)
663 {
664     auto generateNextVariableName = [this]() -> String {
665         return this->generateNextVariableName();
666     };
667     if (&functionDefinition == m_matchedSemantics.vertexShader)
668         return std::make_unique<VertexEntryPointScaffolding>(functionDefinition, m_intrinsics, m_typeNamer, m_matchedSemantics.vertexShaderEntryPointItems, m_matchedSemantics.vertexShaderResourceMap, m_layout, WTFMove(generateNextVariableName), m_matchedSemantics.matchedVertexAttributes);
669     if (&functionDefinition == m_matchedSemantics.fragmentShader)
670         return std::make_unique<FragmentEntryPointScaffolding>(functionDefinition, m_intrinsics, m_typeNamer, m_matchedSemantics.fragmentShaderEntryPointItems, m_matchedSemantics.fragmentShaderResourceMap, m_layout, WTFMove(generateNextVariableName), m_matchedSemantics.matchedColorAttachments);
671     return nullptr;
672 }
673
674 class ComputeFunctionDefinitionWriter : public FunctionDefinitionWriter {
675 public:
676     ComputeFunctionDefinitionWriter(Intrinsics& intrinsics, TypeNamer& typeNamer, HashMap<AST::FunctionDeclaration*, String>& functionMapping, MatchedComputeSemantics&& matchedSemantics, Layout& layout)
677         : FunctionDefinitionWriter(intrinsics, typeNamer, functionMapping, layout)
678         , m_matchedSemantics(WTFMove(matchedSemantics))
679     {
680     }
681
682 private:
683     std::unique_ptr<EntryPointScaffolding> createEntryPointScaffolding(AST::FunctionDefinition&) override;
684
685     MatchedComputeSemantics m_matchedSemantics;
686 };
687
688 std::unique_ptr<EntryPointScaffolding> ComputeFunctionDefinitionWriter::createEntryPointScaffolding(AST::FunctionDefinition& functionDefinition)
689 {
690     auto generateNextVariableName = [this]() -> String {
691         return this->generateNextVariableName();
692     };
693     if (&functionDefinition == m_matchedSemantics.shader)
694         return std::make_unique<ComputeEntryPointScaffolding>(functionDefinition, m_intrinsics, m_typeNamer, m_matchedSemantics.entryPointItems, m_matchedSemantics.resourceMap, m_layout, WTFMove(generateNextVariableName));
695     return nullptr;
696 }
697
// Output of sharedMetalFunctions(): the generated name for every function plus
// the declaration text that is common to render and compute compilation.
struct SharedMetalFunctionsResult {
    // Each native function declaration and function definition -> its generated
    // Metal name ("function" followed by a sequential index).
    HashMap<AST::FunctionDeclaration*, String> functionMapping;
    // Text produced by FunctionDeclarationWriter (entry points excluded), followed by a newline.
    String metalFunctions;
};
702 static SharedMetalFunctionsResult sharedMetalFunctions(Program& program, TypeNamer& typeNamer)
703 {
704     StringBuilder stringBuilder;
705
706     unsigned numFunctions = 0;
707     HashMap<AST::FunctionDeclaration*, String> functionMapping;
708     for (auto& nativeFunctionDeclaration : program.nativeFunctionDeclarations()) {
709         auto addResult = functionMapping.add(&nativeFunctionDeclaration, makeString("function", numFunctions++));
710         ASSERT_UNUSED(addResult, addResult.isNewEntry);
711     }
712     for (auto& functionDefinition : program.functionDefinitions()) {
713         auto addResult = functionMapping.add(&functionDefinition, makeString("function", numFunctions++));
714         ASSERT_UNUSED(addResult, addResult.isNewEntry);
715     }
716
717     {
718         FunctionDeclarationWriter functionDeclarationWriter(typeNamer, functionMapping);
719         for (auto& nativeFunctionDeclaration : program.nativeFunctionDeclarations())
720             functionDeclarationWriter.visit(nativeFunctionDeclaration);
721         for (auto& functionDefinition : program.functionDefinitions()) {
722             if (!functionDefinition->entryPointType())
723                 functionDeclarationWriter.visit(functionDefinition);
724         }
725         stringBuilder.append(functionDeclarationWriter.toString());
726     }
727
728     stringBuilder.append('\n');
729     return { WTFMove(functionMapping), stringBuilder.toString() };
730 }
731
732 RenderMetalFunctions metalFunctions(Program& program, TypeNamer& typeNamer, MatchedRenderSemantics&& matchedSemantics, Layout& layout)
733 {
734     auto sharedMetalFunctions = Metal::sharedMetalFunctions(program, typeNamer);
735
736     StringBuilder stringBuilder;
737     stringBuilder.append(sharedMetalFunctions.metalFunctions);
738
739     auto* vertexShaderEntryPoint = matchedSemantics.vertexShader;
740     auto* fragmentShaderEntryPoint = matchedSemantics.fragmentShader;
741
742     RenderFunctionDefinitionWriter functionDefinitionWriter(program.intrinsics(), typeNamer, sharedMetalFunctions.functionMapping, WTFMove(matchedSemantics), layout);
743     for (auto& nativeFunctionDeclaration : program.nativeFunctionDeclarations())
744         functionDefinitionWriter.visit(nativeFunctionDeclaration);
745     for (auto& functionDefinition : program.functionDefinitions())
746         functionDefinitionWriter.visit(functionDefinition);
747     stringBuilder.append(functionDefinitionWriter.toString());
748
749     RenderMetalFunctions result;
750     result.metalSource = stringBuilder.toString();
751     result.mangledVertexEntryPointName = sharedMetalFunctions.functionMapping.get(vertexShaderEntryPoint);
752     result.mangledFragmentEntryPointName = sharedMetalFunctions.functionMapping.get(fragmentShaderEntryPoint);
753     return result;
754 }
755
756 ComputeMetalFunctions metalFunctions(Program& program, TypeNamer& typeNamer, MatchedComputeSemantics&& matchedSemantics, Layout& layout)
757 {
758     auto sharedMetalFunctions = Metal::sharedMetalFunctions(program, typeNamer);
759
760     StringBuilder stringBuilder;
761     stringBuilder.append(sharedMetalFunctions.metalFunctions);
762
763     auto* entryPoint = matchedSemantics.shader;
764
765     ComputeFunctionDefinitionWriter functionDefinitionWriter(program.intrinsics(), typeNamer, sharedMetalFunctions.functionMapping, WTFMove(matchedSemantics), layout);
766     for (auto& nativeFunctionDeclaration : program.nativeFunctionDeclarations())
767         functionDefinitionWriter.visit(nativeFunctionDeclaration);
768     for (auto& functionDefinition : program.functionDefinitions())
769         functionDefinitionWriter.visit(functionDefinition);
770     stringBuilder.append(functionDefinitionWriter.toString());
771
772     ComputeMetalFunctions result;
773     result.metalSource = stringBuilder.toString();
774     result.mangledEntryPointName = sharedMetalFunctions.functionMapping.get(entryPoint);
775     return result;
776 }
777
778 } // namespace Metal
779
780 } // namespace WHLSL
781
782 } // namespace WebCore
783
784 #endif