Rename StringBuilder::flexibleAppend(...) to StringBuilder::append(...)
[WebKit-https.git] / Source / WebCore / Modules / webgpu / WHLSL / Metal / WHLSLFunctionWriter.cpp
1 /*
2  * Copyright (C) 2019 Apple Inc. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  * 1. Redistributions of source code must retain the above copyright
8  *    notice, this list of conditions and the following disclaimer.
9  * 2. Redistributions in binary form must reproduce the above copyright
10  *    notice, this list of conditions and the following disclaimer in the
11  *    documentation and/or other materials provided with the distribution.
12  *
13  * THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS ``AS IS''
14  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
15  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS
17  * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23  * THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 #include "config.h"
27 #include "WHLSLFunctionWriter.h"
28
29 #if ENABLE(WEBGPU)
30
31 #include "NotImplemented.h"
32 #include "WHLSLAST.h"
33 #include "WHLSLEntryPointScaffolding.h"
34 #include "WHLSLInferTypes.h"
35 #include "WHLSLNativeFunctionWriter.h"
36 #include "WHLSLProgram.h"
37 #include "WHLSLTypeNamer.h"
38 #include "WHLSLVisitor.h"
39 #include <wtf/HashMap.h>
40 #include <wtf/HashSet.h>
41 #include <wtf/SetForScope.h>
42 #include <wtf/text/StringBuilder.h>
43
44 namespace WebCore {
45
46 namespace WHLSL {
47
48 namespace Metal {
49
50 static void declareFunction(StringBuilder& stringBuilder, AST::FunctionDeclaration& functionDeclaration, TypeNamer& typeNamer, HashMap<AST::FunctionDeclaration*, MangledFunctionName>& functionMapping)
51 {
52     if (functionDeclaration.entryPointType())
53         return;
54
55     auto iterator = functionMapping.find(&functionDeclaration);
56     ASSERT(iterator != functionMapping.end());
57     stringBuilder.append(typeNamer.mangledNameForType(functionDeclaration.type()), ' ', iterator->value, '(');
58     for (size_t i = 0; i < functionDeclaration.parameters().size(); ++i) {
59         if (i)
60             stringBuilder.append(", ");
61         stringBuilder.append(typeNamer.mangledNameForType(*functionDeclaration.parameters()[i]->type()));
62     }
63     stringBuilder.append(");\n");
64 }
65
66 class FunctionDefinitionWriter : public Visitor {
67 public:
68     FunctionDefinitionWriter(StringBuilder& stringBuilder, Intrinsics& intrinsics, TypeNamer& typeNamer, HashMap<AST::FunctionDeclaration*, MangledFunctionName>& functionMapping, Layout& layout)
69         : m_stringBuilder(stringBuilder)
70         , m_intrinsics(intrinsics)
71         , m_typeNamer(typeNamer)
72         , m_functionMapping(functionMapping)
73         , m_layout(layout)
74     {
75     }
76
77     virtual ~FunctionDefinitionWriter() = default;
78
79     void visit(AST::NativeFunctionDeclaration&) override;
80     void visit(AST::FunctionDefinition&) override;
81
82 protected:
83     virtual std::unique_ptr<EntryPointScaffolding> createEntryPointScaffolding(AST::FunctionDefinition&) = 0;
84
85     void visit(AST::FunctionDeclaration&) override;
86     void visit(AST::Statement&) override;
87     void visit(AST::Block&) override;
88     void visit(AST::Break&) override;
89     void visit(AST::Continue&) override;
90     void visit(AST::DoWhileLoop&) override;
91     void visit(AST::EffectfulExpressionStatement&) override;
92     void visit(AST::Fallthrough&) override;
93     void visit(AST::ForLoop&) override;
94     void visit(AST::IfStatement&) override;
95     void visit(AST::Return&) override;
96     void visit(AST::SwitchStatement&) override;
97     void visit(AST::SwitchCase&) override;
98     void visit(AST::VariableDeclarationsStatement&) override;
99     void visit(AST::WhileLoop&) override;
100     void visit(AST::IntegerLiteral&) override;
101     void visit(AST::UnsignedIntegerLiteral&) override;
102     void visit(AST::FloatLiteral&) override;
103     void visit(AST::NullLiteral&) override;
104     void visit(AST::BooleanLiteral&) override;
105     void visit(AST::EnumerationMemberLiteral&) override;
106     void visit(AST::Expression&) override;
107     void visit(AST::DotExpression&) override;
108     void visit(AST::GlobalVariableReference&) override;
109     void visit(AST::IndexExpression&) override;
110     void visit(AST::PropertyAccessExpression&) override;
111     void visit(AST::VariableDeclaration&) override;
112     void visit(AST::AssignmentExpression&) override;
113     void visit(AST::CallExpression&) override;
114     void visit(AST::CommaExpression&) override;
115     void visit(AST::DereferenceExpression&) override;
116     void visit(AST::LogicalExpression&) override;
117     void visit(AST::LogicalNotExpression&) override;
118     void visit(AST::MakeArrayReferenceExpression&) override;
119     void visit(AST::MakePointerExpression&) override;
120     void visit(AST::ReadModifyWriteExpression&) override;
121     void visit(AST::TernaryExpression&) override;
122     void visit(AST::VariableReference&) override;
123
124     enum class LoopConditionLocation {
125         BeforeBody,
126         AfterBody
127     };
128     void emitLoop(LoopConditionLocation, AST::Expression* conditionExpression, AST::Expression* increment, AST::Statement& body);
129
130     void emitConstantExpressionString(AST::ConstantExpression&);
131
132     MangledVariableName generateNextVariableName() { return { m_variableCount++ }; }
133
134     enum class Nullability : uint8_t {
135         NotNull,
136         CanBeNull
137     };
138
139     struct StackItem {
140         MangledVariableName value;
141         Optional<MangledVariableName> leftValue;
142         Nullability valueNullability;
143         Nullability leftValueNullability;
144     };
145
146     struct StackValue {
147         MangledVariableName value;
148         Nullability nullability;
149     };
150
151     // This is the important data flow step where we can take the nullability of an lvalue
152     // and transfer it into the nullability of an rvalue. This is conveyed in MakePointerExpression
153     // and DereferenceExpression. MakePointerExpression will try to produce rvalues which are
154     // non-null, and DereferenceExpression will take a non-null rvalue and try to produce
155     // a non-null lvalue.
156     void appendRightValueWithNullability(AST::Expression&, MangledVariableName value, Nullability nullability)
157     {
158         m_stack.append({ WTFMove(value), WTF::nullopt, nullability, Nullability::CanBeNull });
159     }
160
161     void appendRightValue(AST::Expression& expression, MangledVariableName value)
162     {
163         appendRightValueWithNullability(expression, WTFMove(value), Nullability::CanBeNull);
164     }
165
166     void appendLeftValue(AST::Expression& expression, MangledVariableName value, MangledVariableName leftValue, Nullability nullability)
167     {
168         ASSERT_UNUSED(expression, expression.typeAnnotation().leftAddressSpace());
169         m_stack.append({ WTFMove(value), WTFMove(leftValue), Nullability::CanBeNull, nullability });
170     }
171
172     MangledVariableName takeLastValue()
173     {
174         return m_stack.takeLast().value;
175     }
176
177     StackValue takeLastValueAndNullability()
178     {
179         auto last = m_stack.takeLast();
180         return { last.value, last.valueNullability };
181     }
182
183     StackValue takeLastLeftValue()
184     {
185         ASSERT(m_stack.last().leftValue);
186         auto last = m_stack.takeLast();
187         return { *last.leftValue, last.leftValueNullability };
188     }
189
190     enum class BreakContext {
191         Loop,
192         Switch
193     };
194
195     Optional<BreakContext> m_currentBreakContext;
196
197     StringBuilder& m_stringBuilder;
198     Intrinsics& m_intrinsics;
199     TypeNamer& m_typeNamer;
200     HashMap<AST::FunctionDeclaration*, MangledFunctionName>& m_functionMapping;
201     HashMap<AST::VariableDeclaration*, MangledVariableName> m_variableMapping;
202
203     Vector<StackItem> m_stack;
204     std::unique_ptr<EntryPointScaffolding> m_entryPointScaffolding;
205     Layout& m_layout;
206     unsigned m_variableCount { 0 };
207     Optional<MangledVariableName> m_breakOutOfCurrentLoopEarlyVariable;
208 };
209
210 void FunctionDefinitionWriter::visit(AST::NativeFunctionDeclaration&)
211 {
212     // We inline native function calls.
213 }
214
215 void FunctionDefinitionWriter::visit(AST::FunctionDefinition& functionDefinition)
216 {
217     auto iterator = m_functionMapping.find(&functionDefinition);
218     ASSERT(iterator != m_functionMapping.end());
219     if (functionDefinition.entryPointType()) {
220         auto entryPointScaffolding = createEntryPointScaffolding(functionDefinition);
221         if (!entryPointScaffolding)
222             return;
223         m_entryPointScaffolding = WTFMove(entryPointScaffolding);
224         m_entryPointScaffolding->emitHelperTypes(m_stringBuilder);
225         m_stringBuilder.append('\n');
226         m_entryPointScaffolding->emitSignature(m_stringBuilder, iterator->value);
227         m_stringBuilder.append(" {\n");
228         m_entryPointScaffolding->emitUnpack(m_stringBuilder);
229     
230         for (size_t i = 0; i < functionDefinition.parameters().size(); ++i) {
231             auto addResult = m_variableMapping.add(&functionDefinition.parameters()[i], m_entryPointScaffolding->parameterVariables()[i]);
232             ASSERT_UNUSED(addResult, addResult.isNewEntry);
233         }
234         checkErrorAndVisit(functionDefinition.block());
235         ASSERT(m_stack.isEmpty());
236         m_stringBuilder.append("}\n");
237         m_entryPointScaffolding = nullptr;
238     } else {
239         ASSERT(m_entryPointScaffolding == nullptr);
240         m_stringBuilder.append(m_typeNamer.mangledNameForType(functionDefinition.type()), ' ', iterator->value, '(');
241         for (size_t i = 0; i < functionDefinition.parameters().size(); ++i) {
242             auto& parameter = functionDefinition.parameters()[i];
243             if (i)
244                 m_stringBuilder.append(", ");
245             auto parameterName = generateNextVariableName();
246             auto addResult = m_variableMapping.add(&parameter, parameterName);
247             ASSERT_UNUSED(addResult, addResult.isNewEntry);
248             m_stringBuilder.append(m_typeNamer.mangledNameForType(*parameter->type()), ' ', parameterName);
249         }
250         m_stringBuilder.append(")\n");
251         checkErrorAndVisit(functionDefinition.block());
252         ASSERT(m_stack.isEmpty());
253         m_stringBuilder.append('\n');
254     }
255 }
256
257 void FunctionDefinitionWriter::visit(AST::FunctionDeclaration&)
258 {
259     ASSERT_NOT_REACHED();
260 }
261
262 void FunctionDefinitionWriter::visit(AST::Statement& statement)
263 {
264     Visitor::visit(statement);
265 }
266
267 void FunctionDefinitionWriter::visit(AST::Block& block)
268 {
269     m_stringBuilder.append("{\n");
270     for (auto& statement : block.statements())
271         checkErrorAndVisit(statement);
272     m_stringBuilder.append("}\n");
273 }
274
275 void FunctionDefinitionWriter::visit(AST::Break&)
276 {
277     ASSERT(m_currentBreakContext);
278     switch (*m_currentBreakContext) {
279     case BreakContext::Switch:
280         m_stringBuilder.append("break;\n");
281         break;
282     case BreakContext::Loop:
283         ASSERT(m_breakOutOfCurrentLoopEarlyVariable);
284         m_stringBuilder.append(
285             *m_breakOutOfCurrentLoopEarlyVariable, " = true;\n"
286             "break;\n"
287         );
288         break;
289     }
290 }
291
292 void FunctionDefinitionWriter::visit(AST::Continue&)
293 {
294     ASSERT(m_breakOutOfCurrentLoopEarlyVariable);
295     m_stringBuilder.append("break;\n");
296 }
297
298 void FunctionDefinitionWriter::visit(AST::EffectfulExpressionStatement& effectfulExpressionStatement)
299 {
300     checkErrorAndVisit(effectfulExpressionStatement.effectfulExpression());
301     takeLastValue(); // The statement is already effectful, so we don't need to do anything with the result.
302 }
303
304 void FunctionDefinitionWriter::visit(AST::Fallthrough&)
305 {
306     m_stringBuilder.append("[[clang::fallthrough]];\n"); // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195808 Make sure this is okay. Alternatively, we could do nothing and just return here instead.
307 }
308
309 void FunctionDefinitionWriter::emitLoop(LoopConditionLocation loopConditionLocation, AST::Expression* conditionExpression, AST::Expression* increment, AST::Statement& body)
310 {
311     SetForScope<Optional<MangledVariableName>> loopVariableScope(m_breakOutOfCurrentLoopEarlyVariable, generateNextVariableName());
312
313     m_stringBuilder.append(
314         "bool ", *m_breakOutOfCurrentLoopEarlyVariable, " = false;\n",
315         "while (true) {\n"
316     );
317
318     if (loopConditionLocation == LoopConditionLocation::BeforeBody && conditionExpression) {
319         checkErrorAndVisit(*conditionExpression);
320         m_stringBuilder.append("if (!", takeLastValue(), ") break;\n");
321     }
322
323     m_stringBuilder.append("do {\n");
324     SetForScope<Optional<BreakContext>> breakContext(m_currentBreakContext, BreakContext::Loop);
325     checkErrorAndVisit(body);
326     m_stringBuilder.append(
327         "} while(false); \n"
328         "if (", *m_breakOutOfCurrentLoopEarlyVariable, ") break;\n"
329     );
330
331     if (increment) {
332         checkErrorAndVisit(*increment);
333         // Expression results get pushed to m_stack. We don't use the result
334         // of increment, so we dispense of that now.
335         takeLastValue();
336     }
337
338     if (loopConditionLocation == LoopConditionLocation::AfterBody && conditionExpression) {
339         checkErrorAndVisit(*conditionExpression);
340         m_stringBuilder.append("if (!", takeLastValue(), ") break;\n");
341     }
342
343     m_stringBuilder.append("} \n");
344 }
345
346 void FunctionDefinitionWriter::visit(AST::DoWhileLoop& doWhileLoop)
347 {
348     emitLoop(LoopConditionLocation::AfterBody, &doWhileLoop.conditional(), nullptr, doWhileLoop.body());
349 }
350
351 void FunctionDefinitionWriter::visit(AST::WhileLoop& whileLoop)
352 {
353     emitLoop(LoopConditionLocation::BeforeBody, &whileLoop.conditional(), nullptr, whileLoop.body());
354 }
355
356 void FunctionDefinitionWriter::visit(AST::ForLoop& forLoop)
357 {
358     m_stringBuilder.append("{\n");
359     checkErrorAndVisit(forLoop.initialization());
360     emitLoop(LoopConditionLocation::BeforeBody, forLoop.condition(), forLoop.increment(), forLoop.body());
361     m_stringBuilder.append("}\n");
362 }
363
364 void FunctionDefinitionWriter::visit(AST::IfStatement& ifStatement)
365 {
366     checkErrorAndVisit(ifStatement.conditional());
367     m_stringBuilder.append("if (", takeLastValue(), ") {\n");
368     checkErrorAndVisit(ifStatement.body());
369     if (ifStatement.elseBody()) {
370         m_stringBuilder.append("} else {\n");
371         checkErrorAndVisit(*ifStatement.elseBody());
372     }
373     m_stringBuilder.append("}\n");
374 }
375
376 void FunctionDefinitionWriter::visit(AST::Return& returnStatement)
377 {
378     if (returnStatement.value()) {
379         checkErrorAndVisit(*returnStatement.value());
380         if (m_entryPointScaffolding) {
381             auto variableName = generateNextVariableName();
382             m_entryPointScaffolding->emitPack(m_stringBuilder, takeLastValue(), variableName);
383             m_stringBuilder.append("return ", variableName, ";\n");
384         } else
385             m_stringBuilder.append("return ", takeLastValue(), ";\n");
386     } else
387         m_stringBuilder.append("return;\n");
388 }
389
390 void FunctionDefinitionWriter::visit(AST::SwitchStatement& switchStatement)
391 {
392     checkErrorAndVisit(switchStatement.value());
393
394     m_stringBuilder.append("switch (", takeLastValue(), ") {");
395     for (auto& switchCase : switchStatement.switchCases())
396         checkErrorAndVisit(switchCase);
397     m_stringBuilder.append("}\n");
398 }
399
400 void FunctionDefinitionWriter::visit(AST::SwitchCase& switchCase)
401 {
402     if (switchCase.value()) {
403         m_stringBuilder.append("case ");
404         emitConstantExpressionString(*switchCase.value());
405         m_stringBuilder.append(":\n");
406     } else
407         m_stringBuilder.append("default:\n");
408     SetForScope<Optional<BreakContext>> breakContext(m_currentBreakContext, BreakContext::Switch);
409     checkErrorAndVisit(switchCase.block());
410     // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195812 Figure out whether we need to break or fallthrough.
411 }
412
413 void FunctionDefinitionWriter::visit(AST::VariableDeclarationsStatement& variableDeclarationsStatement)
414 {
415     Visitor::visit(variableDeclarationsStatement);
416 }
417
418 void FunctionDefinitionWriter::visit(AST::IntegerLiteral& integerLiteral)
419 {
420     auto variableName = generateNextVariableName();
421     auto mangledTypeName = m_typeNamer.mangledNameForType(integerLiteral.resolvedType());
422     m_stringBuilder.append(mangledTypeName, ' ', variableName, " = static_cast<", mangledTypeName, ">(", integerLiteral.value(), ");\n");
423     appendRightValue(integerLiteral, variableName);
424 }
425
426 void FunctionDefinitionWriter::visit(AST::UnsignedIntegerLiteral& unsignedIntegerLiteral)
427 {
428     auto variableName = generateNextVariableName();
429     auto mangledTypeName = m_typeNamer.mangledNameForType(unsignedIntegerLiteral.resolvedType());
430     m_stringBuilder.append(mangledTypeName, ' ', variableName, " = static_cast<", mangledTypeName, ">(", unsignedIntegerLiteral.value(), ");\n");
431     appendRightValue(unsignedIntegerLiteral, variableName);
432 }
433
434 void FunctionDefinitionWriter::visit(AST::FloatLiteral& floatLiteral)
435 {
436     auto variableName = generateNextVariableName();
437     auto mangledTypeName = m_typeNamer.mangledNameForType(floatLiteral.resolvedType());
438     m_stringBuilder.append(mangledTypeName, ' ', variableName, " = static_cast<", mangledTypeName, ">(", floatLiteral.value(), ");\n");
439     appendRightValue(floatLiteral, variableName);
440 }
441
442 void FunctionDefinitionWriter::visit(AST::NullLiteral& nullLiteral)
443 {
444     auto& unifyNode = nullLiteral.resolvedType().unifyNode();
445     auto& unnamedType = downcast<AST::UnnamedType>(unifyNode);
446     bool isArrayReferenceType = is<AST::ArrayReferenceType>(unnamedType);
447
448     auto variableName = generateNextVariableName();
449     m_stringBuilder.append(m_typeNamer.mangledNameForType(nullLiteral.resolvedType()), ' ', variableName, " = ");
450     if (isArrayReferenceType)
451         m_stringBuilder.append("{ nullptr, 0 };\n");
452     else
453         m_stringBuilder.append("nullptr;\n");
454     appendRightValue(nullLiteral, variableName);
455 }
456
457 void FunctionDefinitionWriter::visit(AST::BooleanLiteral& booleanLiteral)
458 {
459     auto variableName = generateNextVariableName();
460     auto mangledTypeName = m_typeNamer.mangledNameForType(booleanLiteral.resolvedType());
461     m_stringBuilder.append(mangledTypeName, ' ', variableName, " = static_cast<", mangledTypeName, ">(", booleanLiteral.value() ? "true" : "false", ");\n");
462     appendRightValue(booleanLiteral, variableName);
463 }
464
465 void FunctionDefinitionWriter::visit(AST::EnumerationMemberLiteral& enumerationMemberLiteral)
466 {
467     ASSERT(enumerationMemberLiteral.enumerationDefinition());
468     ASSERT(enumerationMemberLiteral.enumerationDefinition());
469     auto variableName = generateNextVariableName();
470     auto mangledTypeName = m_typeNamer.mangledNameForType(enumerationMemberLiteral.resolvedType());
471     m_stringBuilder.append(mangledTypeName, ' ', variableName, " = ", mangledTypeName, "::", m_typeNamer.mangledNameForEnumerationMember(*enumerationMemberLiteral.enumerationMember()), ";\n");
472     appendRightValue(enumerationMemberLiteral, variableName);
473 }
474
475 void FunctionDefinitionWriter::visit(AST::Expression& expression)
476 {
477     Visitor::visit(expression);
478 }
479
480 void FunctionDefinitionWriter::visit(AST::DotExpression& dotExpression)
481 {
482     // This should be lowered already.
483     // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195788 Replace this with ASSERT_NOT_REACHED().
484     notImplemented();
485     appendRightValue(dotExpression, generateNextVariableName());
486 }
487
488 void FunctionDefinitionWriter::visit(AST::GlobalVariableReference& globalVariableReference)
489 {
490     auto valueName = generateNextVariableName();
491     auto pointerName = generateNextVariableName();
492     auto mangledTypeName = m_typeNamer.mangledNameForType(globalVariableReference.resolvedType());
493     checkErrorAndVisit(globalVariableReference.base());
494     m_stringBuilder.append(
495         "thread ", mangledTypeName, "* ", pointerName, " = &", takeLastValue(), "->", m_typeNamer.mangledNameForStructureElement(globalVariableReference.structField()), ";\n",
496         mangledTypeName, ' ', valueName, " = ", "*", pointerName, ";\n"
497     );
498     appendLeftValue(globalVariableReference, valueName, pointerName, Nullability::NotNull);
499 }
500
501 void FunctionDefinitionWriter::visit(AST::IndexExpression& indexExpression)
502 {
503     // This should be lowered already.
504     // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195788 Replace this with ASSERT_NOT_REACHED().
505     notImplemented();
506     appendRightValue(indexExpression, generateNextVariableName());
507 }
508
509 void FunctionDefinitionWriter::visit(AST::PropertyAccessExpression& propertyAccessExpression)
510 {
511     // This should be lowered already.
512     // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195788 Replace this with ASSERT_NOT_REACHED().
513     notImplemented();
514     appendRightValue(propertyAccessExpression, generateNextVariableName());
515 }
516
517 void FunctionDefinitionWriter::visit(AST::VariableDeclaration& variableDeclaration)
518 {
519     ASSERT(variableDeclaration.type());
520     auto variableName = generateNextVariableName();
521     auto addResult = m_variableMapping.add(&variableDeclaration, variableName);
522     ASSERT_UNUSED(addResult, addResult.isNewEntry);
523     // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198160 Implement qualifiers.
524     if (variableDeclaration.initializer()) {
525         checkErrorAndVisit(*variableDeclaration.initializer());
526         m_stringBuilder.append(m_typeNamer.mangledNameForType(*variableDeclaration.type()), ' ', variableName, " = ", takeLastValue(), ";\n");
527     } else
528         m_stringBuilder.append(m_typeNamer.mangledNameForType(*variableDeclaration.type()), ' ', variableName, " = { };\n");
529 }
530
531 void FunctionDefinitionWriter::visit(AST::AssignmentExpression& assignmentExpression)
532 {
533     checkErrorAndVisit(assignmentExpression.left());
534     auto [pointerName, nullability] = takeLastLeftValue();
535     checkErrorAndVisit(assignmentExpression.right());
536     auto [rightName, rightNullability] = takeLastValueAndNullability();
537     if (nullability == Nullability::CanBeNull)
538         m_stringBuilder.append("if (", pointerName, ") *", pointerName, " = ", rightName, ";\n");
539     else
540         m_stringBuilder.append("*", pointerName, " = ", rightName, ";\n");
541     appendRightValueWithNullability(assignmentExpression, rightName, rightNullability);
542 }
543
544 void FunctionDefinitionWriter::visit(AST::CallExpression& callExpression)
545 {
546     Vector<MangledVariableName> argumentNames;
547     for (auto& argument : callExpression.arguments()) {
548         checkErrorAndVisit(argument);
549         argumentNames.append(takeLastValue());
550     }
551
552     bool isVoid = matches(callExpression.resolvedType(), m_intrinsics.voidType());
553     MangledVariableName returnName;
554     if (!isVoid) {
555         returnName = generateNextVariableName();
556         m_stringBuilder.append(m_typeNamer.mangledNameForType(callExpression.resolvedType()), ' ', returnName, ";\n");
557     }
558
559     if (is<AST::NativeFunctionDeclaration>(callExpression.function()))
560         inlineNativeFunction(m_stringBuilder, downcast<AST::NativeFunctionDeclaration>(callExpression.function()), returnName, argumentNames, m_intrinsics, m_typeNamer);
561     else {
562         auto iterator = m_functionMapping.find(&callExpression.function());
563         ASSERT(iterator != m_functionMapping.end());
564         if (!isVoid)
565             m_stringBuilder.append(returnName, " = ");
566         m_stringBuilder.append(iterator->value, '(');
567         for (size_t i = 0; i < argumentNames.size(); ++i) {
568             if (i)
569                 m_stringBuilder.append(", ");
570             m_stringBuilder.append(argumentNames[i]);
571         }
572         m_stringBuilder.append(");\n");
573     }
574
575     appendRightValue(callExpression, returnName);
576 }
577
578 void FunctionDefinitionWriter::visit(AST::CommaExpression& commaExpression)
579 {
580     Optional<MangledVariableName> result;
581     for (auto& expression : commaExpression.list()) {
582         checkErrorAndVisit(expression);
583         result = takeLastValue();
584     }
585     ASSERT(result);
586     appendRightValue(commaExpression, *result);
587 }
588
589 void FunctionDefinitionWriter::visit(AST::DereferenceExpression& dereferenceExpression)
590 {
591     checkErrorAndVisit(dereferenceExpression.pointer());
592     auto [inputPointer, nullability] = takeLastValueAndNullability();
593     auto resultValue = generateNextVariableName();
594     auto resultPointer = generateNextVariableName();
595     m_stringBuilder.append(
596         m_typeNamer.mangledNameForType(dereferenceExpression.pointer().resolvedType()), ' ', resultPointer, " = ", inputPointer, ";\n",
597         m_typeNamer.mangledNameForType(dereferenceExpression.resolvedType()), ' ', resultValue, ";\n");
598     if (nullability == Nullability::CanBeNull) {
599         m_stringBuilder.append(
600             "if (", resultPointer, ") ", resultValue, " = *", inputPointer, ";\n",
601             "else ", resultValue, " = { };\n");
602     } else
603         m_stringBuilder.append(resultValue, " = *", inputPointer, ";\n");
604     appendLeftValue(dereferenceExpression, resultValue, resultPointer, nullability);
605 }
606
607 void FunctionDefinitionWriter::visit(AST::LogicalExpression& logicalExpression)
608 {
609     checkErrorAndVisit(logicalExpression.left());
610     auto left = takeLastValue();
611     checkErrorAndVisit(logicalExpression.right());
612     auto right = takeLastValue();
613     auto variableName = generateNextVariableName();
614     m_stringBuilder.append(m_typeNamer.mangledNameForType(logicalExpression.resolvedType()), ' ', variableName, " = ", left);
615     switch (logicalExpression.type()) {
616     case AST::LogicalExpression::Type::And:
617         m_stringBuilder.append(" && ");
618         break;
619     default:
620         ASSERT(logicalExpression.type() == AST::LogicalExpression::Type::Or);
621         m_stringBuilder.append(" || ");
622         break;
623     }
624     m_stringBuilder.append(right, ";\n");
625     appendRightValue(logicalExpression, variableName);
626 }
627
628 void FunctionDefinitionWriter::visit(AST::LogicalNotExpression& logicalNotExpression)
629 {
630     checkErrorAndVisit(logicalNotExpression.operand());
631     auto operand = takeLastValue();
632     auto variableName = generateNextVariableName();
633     m_stringBuilder.append(m_typeNamer.mangledNameForType(logicalNotExpression.resolvedType()), ' ', variableName, " = !", operand, ";\n");
634     appendRightValue(logicalNotExpression, variableName);
635 }
636
637 void FunctionDefinitionWriter::visit(AST::MakeArrayReferenceExpression& makeArrayReferenceExpression)
638 {
639     checkErrorAndVisit(makeArrayReferenceExpression.leftValue());
640     // FIXME: This needs to be made to work. It probably should be using the last leftValue too.
641     // https://bugs.webkit.org/show_bug.cgi?id=198838
642     auto variableName = generateNextVariableName();
643
644     auto mangledTypeName = m_typeNamer.mangledNameForType(makeArrayReferenceExpression.resolvedType());
645     if (is<AST::PointerType>(makeArrayReferenceExpression.leftValue().resolvedType())) {
646         auto ptrValue = takeLastValue();
647         m_stringBuilder.append(
648             mangledTypeName, ' ', variableName, ";\n",
649             "if (", ptrValue, ") ", variableName, " = { ", ptrValue, ", 1};\n",
650             "else ", variableName, " = { nullptr, 0 };\n"
651         );
652     } else if (is<AST::ArrayType>(makeArrayReferenceExpression.leftValue().resolvedType())) {
653         auto lValue = takeLastLeftValue().value;
654         auto& arrayType = downcast<AST::ArrayType>(makeArrayReferenceExpression.leftValue().resolvedType());
655         m_stringBuilder.append(mangledTypeName, ' ', variableName, " = { ", lValue, "->data(), ", arrayType.numElements(), " };\n");
656     } else {
657         auto lValue = takeLastLeftValue().value;
658         m_stringBuilder.append(mangledTypeName, ' ', variableName, " = { ", lValue, ", 1 };\n");
659     }
660     appendRightValue(makeArrayReferenceExpression, variableName);
661 }
662
663 void FunctionDefinitionWriter::visit(AST::MakePointerExpression& makePointerExpression)
664 {
665     checkErrorAndVisit(makePointerExpression.leftValue());
666     auto [pointer, nullability] = takeLastLeftValue();
667     auto variableName = generateNextVariableName();
668     m_stringBuilder.append(m_typeNamer.mangledNameForType(makePointerExpression.resolvedType()), ' ', variableName, " = ", pointer, ";\n");
669     appendRightValueWithNullability(makePointerExpression, variableName, nullability);
670 }
671
672 void FunctionDefinitionWriter::visit(AST::ReadModifyWriteExpression&)
673 {
674     // This should be lowered already.
675     ASSERT_NOT_REACHED();
676 }
677
678 void FunctionDefinitionWriter::visit(AST::TernaryExpression& ternaryExpression)
679 {
680     checkErrorAndVisit(ternaryExpression.predicate());
681     auto check = takeLastValue();
682     checkErrorAndVisit(ternaryExpression.bodyExpression());
683     auto body = takeLastValue();
684     checkErrorAndVisit(ternaryExpression.elseExpression());
685     auto elseBody = takeLastValue();
686
687     auto variableName = generateNextVariableName();
688     m_stringBuilder.append(m_typeNamer.mangledNameForType(ternaryExpression.resolvedType()), ' ', variableName, " = ", check, " ? ", body, " : ", elseBody, ";\n");
689     appendRightValue(ternaryExpression, variableName);
690 }
691
692 void FunctionDefinitionWriter::visit(AST::VariableReference& variableReference)
693 {
694     ASSERT(variableReference.variable());
695     auto iterator = m_variableMapping.find(variableReference.variable());
696     ASSERT(iterator != m_variableMapping.end());
697     auto pointerName = generateNextVariableName();
698     m_stringBuilder.append("thread ", m_typeNamer.mangledNameForType(variableReference.resolvedType()), "* ", pointerName, " = &", iterator->value, ";\n");
699     appendLeftValue(variableReference, iterator->value, pointerName, Nullability::NotNull);
700 }
701
702 void FunctionDefinitionWriter::emitConstantExpressionString(AST::ConstantExpression& constantExpression)
703 {
704     constantExpression.visit(WTF::makeVisitor(
705         [&](AST::IntegerLiteral& integerLiteral) {
706             m_stringBuilder.append(integerLiteral.value());
707         },
708         [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) {
709             m_stringBuilder.append(unsignedIntegerLiteral.value());
710         },
711         [&](AST::FloatLiteral& floatLiteral) {
712             m_stringBuilder.append(floatLiteral.value());
713         },
714         [&](AST::NullLiteral&) {
715             m_stringBuilder.append("nullptr");
716         },
717         [&](AST::BooleanLiteral& booleanLiteral) {
718             if (booleanLiteral.value())
719                 m_stringBuilder.append("true");
720             else
721                 m_stringBuilder.append("false");
722         },
723         [&](AST::EnumerationMemberLiteral& enumerationMemberLiteral) {
724             ASSERT(enumerationMemberLiteral.enumerationDefinition());
725             ASSERT(enumerationMemberLiteral.enumerationDefinition());
726             m_stringBuilder.append(m_typeNamer.mangledNameForType(*enumerationMemberLiteral.enumerationDefinition()), "::", m_typeNamer.mangledNameForEnumerationMember(*enumerationMemberLiteral.enumerationMember()));
727         }
728     ));
729 }
730
731 class RenderFunctionDefinitionWriter final : public FunctionDefinitionWriter {
732 public:
733     RenderFunctionDefinitionWriter(StringBuilder& stringBuilder, Intrinsics& intrinsics, TypeNamer& typeNamer, HashMap<AST::FunctionDeclaration*, MangledFunctionName>& functionMapping, MatchedRenderSemantics&& matchedSemantics, Layout& layout)
734         : FunctionDefinitionWriter(stringBuilder, intrinsics, typeNamer, functionMapping, layout)
735         , m_matchedSemantics(WTFMove(matchedSemantics))
736     {
737     }
738
739 private:
740     std::unique_ptr<EntryPointScaffolding> createEntryPointScaffolding(AST::FunctionDefinition&) override;
741
742     MatchedRenderSemantics m_matchedSemantics;
743 };
744
745 std::unique_ptr<EntryPointScaffolding> RenderFunctionDefinitionWriter::createEntryPointScaffolding(AST::FunctionDefinition& functionDefinition)
746 {
747     auto generateNextVariableName = [this]() -> MangledVariableName {
748         return this->generateNextVariableName();
749     };
750     if (&functionDefinition == m_matchedSemantics.vertexShader)
751         return std::make_unique<VertexEntryPointScaffolding>(functionDefinition, m_intrinsics, m_typeNamer, m_matchedSemantics.vertexShaderEntryPointItems, m_matchedSemantics.vertexShaderResourceMap, m_layout, WTFMove(generateNextVariableName), m_matchedSemantics.matchedVertexAttributes);
752     if (&functionDefinition == m_matchedSemantics.fragmentShader)
753         return std::make_unique<FragmentEntryPointScaffolding>(functionDefinition, m_intrinsics, m_typeNamer, m_matchedSemantics.fragmentShaderEntryPointItems, m_matchedSemantics.fragmentShaderResourceMap, m_layout, WTFMove(generateNextVariableName), m_matchedSemantics.matchedColorAttachments);
754     return nullptr;
755 }
756
757 class ComputeFunctionDefinitionWriter final : public FunctionDefinitionWriter {
758 public:
759     ComputeFunctionDefinitionWriter(StringBuilder& stringBuilder, Intrinsics& intrinsics, TypeNamer& typeNamer, HashMap<AST::FunctionDeclaration*, MangledFunctionName>& functionMapping, MatchedComputeSemantics&& matchedSemantics, Layout& layout)
760         : FunctionDefinitionWriter(stringBuilder, intrinsics, typeNamer, functionMapping, layout)
761         , m_matchedSemantics(WTFMove(matchedSemantics))
762     {
763     }
764
765 private:
766     std::unique_ptr<EntryPointScaffolding> createEntryPointScaffolding(AST::FunctionDefinition&) override;
767
768     MatchedComputeSemantics m_matchedSemantics;
769 };
770
771 std::unique_ptr<EntryPointScaffolding> ComputeFunctionDefinitionWriter::createEntryPointScaffolding(AST::FunctionDefinition& functionDefinition)
772 {
773     auto generateNextVariableName = [this]() -> MangledVariableName {
774         return this->generateNextVariableName();
775     };
776     if (&functionDefinition == m_matchedSemantics.shader)
777         return std::make_unique<ComputeEntryPointScaffolding>(functionDefinition, m_intrinsics, m_typeNamer, m_matchedSemantics.entryPointItems, m_matchedSemantics.resourceMap, m_layout, WTFMove(generateNextVariableName));
778     return nullptr;
779 }
780
781 static HashMap<AST::FunctionDeclaration*, MangledFunctionName> generateMetalFunctionsMapping(Program& program)
782 {
783     unsigned numFunctions = 0;
784     HashMap<AST::FunctionDeclaration*, MangledFunctionName> functionMapping;
785     for (auto& functionDefinition : program.functionDefinitions()) {
786         auto addResult = functionMapping.add(&functionDefinition, MangledFunctionName { numFunctions++ });
787         ASSERT_UNUSED(addResult, addResult.isNewEntry);
788     }
789
790     return functionMapping;
791 }
792
793 static void emitSharedMetalFunctions(StringBuilder& stringBuilder, Program& program, TypeNamer& typeNamer, const HashSet<AST::FunctionDeclaration*>& reachableFunctions, HashMap<AST::FunctionDeclaration*, MangledFunctionName>& functionMapping)
794 {
795     for (auto& functionDefinition : program.functionDefinitions()) {
796         if (!functionDefinition->entryPointType() && reachableFunctions.contains(&functionDefinition))
797             declareFunction(stringBuilder, functionDefinition, typeNamer, functionMapping);
798     }
799
800     stringBuilder.append('\n');
801 }
802
803 class ReachableFunctionsGatherer final : public Visitor {
804 public:
805     void visit(AST::FunctionDeclaration& functionDeclaration) override
806     {
807         auto result = m_reachableFunctions.add(&functionDeclaration);
808         if (result.isNewEntry)
809             Visitor::visit(functionDeclaration);
810     }
811
812     void visit(AST::CallExpression& callExpression) override
813     {
814         Visitor::visit(callExpression);
815         if (is<AST::FunctionDefinition>(callExpression.function()))
816             checkErrorAndVisit(downcast<AST::FunctionDefinition>(callExpression.function()));
817         else
818             checkErrorAndVisit(downcast<AST::NativeFunctionDeclaration>(callExpression.function()));
819     }
820
821     HashSet<AST::FunctionDeclaration*> takeReachableFunctions() { return WTFMove(m_reachableFunctions); }
822
823 private:
824     HashSet<AST::FunctionDeclaration*> m_reachableFunctions;
825 };
826
827 RenderMetalFunctionEntryPoints emitMetalFunctions(StringBuilder& stringBuilder, Program& program, TypeNamer& typeNamer, MatchedRenderSemantics&& matchedSemantics, Layout& layout)
828 {
829     auto& vertexShaderEntryPoint = *matchedSemantics.vertexShader;
830     auto& fragmentShaderEntryPoint = *matchedSemantics.fragmentShader;
831
832     ReachableFunctionsGatherer reachableFunctionsGatherer;
833     reachableFunctionsGatherer.Visitor::visit(vertexShaderEntryPoint);
834     reachableFunctionsGatherer.Visitor::visit(fragmentShaderEntryPoint);
835     auto reachableFunctions = reachableFunctionsGatherer.takeReachableFunctions();
836
837     auto functionMapping = generateMetalFunctionsMapping(program);
838     
839     emitSharedMetalFunctions(stringBuilder, program, typeNamer, reachableFunctions, functionMapping);
840
841     RenderFunctionDefinitionWriter functionDefinitionWriter(stringBuilder, program.intrinsics(), typeNamer, functionMapping, WTFMove(matchedSemantics), layout);
842     for (auto& functionDefinition : program.functionDefinitions()) {
843         if (reachableFunctions.contains(&functionDefinition))
844             functionDefinitionWriter.visit(functionDefinition);
845     }
846
847     return { functionMapping.get(&vertexShaderEntryPoint), functionMapping.get(&fragmentShaderEntryPoint) };
848 }
849
850 ComputeMetalFunctionEntryPoints emitMetalFunctions(StringBuilder& stringBuilder, Program& program, TypeNamer& typeNamer, MatchedComputeSemantics&& matchedSemantics, Layout& layout)
851 {
852     auto& entryPoint = *matchedSemantics.shader;
853
854     ReachableFunctionsGatherer reachableFunctionsGatherer;
855     reachableFunctionsGatherer.Visitor::visit(entryPoint);
856     auto reachableFunctions = reachableFunctionsGatherer.takeReachableFunctions();
857
858     auto functionMapping = generateMetalFunctionsMapping(program);
859     emitSharedMetalFunctions(stringBuilder, program, typeNamer, reachableFunctions, functionMapping);
860
861     ComputeFunctionDefinitionWriter functionDefinitionWriter(stringBuilder, program.intrinsics(), typeNamer, functionMapping, WTFMove(matchedSemantics), layout);
862     for (auto& functionDefinition : program.functionDefinitions()) {
863         if (reachableFunctions.contains(&functionDefinition))
864             functionDefinitionWriter.visit(functionDefinition);
865     }
866
867     return { functionMapping.get(&entryPoint) };
868 }
869
870 } // namespace Metal
871
872 } // namespace WHLSL
873
874 } // namespace WebCore
875
876 #endif