Mangled WHLSL names don't need to allocate Strings
[WebKit-https.git] / Source / WebCore / Modules / webgpu / WHLSL / Metal / WHLSLFunctionWriter.cpp
1 /*
2  * Copyright (C) 2019 Apple Inc. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  * 1. Redistributions of source code must retain the above copyright
8  *    notice, this list of conditions and the following disclaimer.
9  * 2. Redistributions in binary form must reproduce the above copyright
10  *    notice, this list of conditions and the following disclaimer in the
11  *    documentation and/or other materials provided with the distribution.
12  *
13  * THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS ``AS IS''
14  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
15  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS
17  * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23  * THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 #include "config.h"
27 #include "WHLSLFunctionWriter.h"
28
29 #if ENABLE(WEBGPU)
30
31 #include "NotImplemented.h"
32 #include "WHLSLAST.h"
33 #include "WHLSLEntryPointScaffolding.h"
34 #include "WHLSLInferTypes.h"
35 #include "WHLSLNativeFunctionWriter.h"
36 #include "WHLSLProgram.h"
37 #include "WHLSLTypeNamer.h"
38 #include "WHLSLVisitor.h"
39 #include <wtf/HashMap.h>
40 #include <wtf/HashSet.h>
41 #include <wtf/SetForScope.h>
42 #include <wtf/text/StringBuilder.h>
43
44 namespace WebCore {
45
46 namespace WHLSL {
47
48 namespace Metal {
49
50 class FunctionDeclarationWriter : public Visitor {
51 public:
52     FunctionDeclarationWriter(TypeNamer& typeNamer, HashMap<AST::FunctionDeclaration*, MangledFunctionName>& functionMapping)
53         : m_typeNamer(typeNamer)
54         , m_functionMapping(functionMapping)
55     {
56     }
57
58     virtual ~FunctionDeclarationWriter() = default;
59
60     String toString() { return m_stringBuilder.toString(); }
61
62     void visit(AST::FunctionDeclaration&) override;
63
64 private:
65     TypeNamer& m_typeNamer;
66     HashMap<AST::FunctionDeclaration*, MangledFunctionName>& m_functionMapping;
67     StringBuilder m_stringBuilder;
68 };
69
70 void FunctionDeclarationWriter::visit(AST::FunctionDeclaration& functionDeclaration)
71 {
72     if (functionDeclaration.entryPointType())
73         return;
74
75     auto iterator = m_functionMapping.find(&functionDeclaration);
76     ASSERT(iterator != m_functionMapping.end());
77     m_stringBuilder.flexibleAppend(m_typeNamer.mangledNameForType(functionDeclaration.type()), ' ', iterator->value, '(');
78     for (size_t i = 0; i < functionDeclaration.parameters().size(); ++i) {
79         if (i)
80             m_stringBuilder.append(", ");
81         m_stringBuilder.flexibleAppend(m_typeNamer.mangledNameForType(*functionDeclaration.parameters()[i]->type()));
82     }
83     m_stringBuilder.append(");\n");
84 }
85
86 class FunctionDefinitionWriter : public Visitor {
87 public:
88     FunctionDefinitionWriter(Intrinsics& intrinsics, TypeNamer& typeNamer, HashMap<AST::FunctionDeclaration*, MangledFunctionName>& functionMapping, Layout& layout)
89         : m_intrinsics(intrinsics)
90         , m_typeNamer(typeNamer)
91         , m_functionMapping(functionMapping)
92         , m_layout(layout)
93     {
94     }
95
96     virtual ~FunctionDefinitionWriter() = default;
97
98     String toString() { return m_stringBuilder.toString(); }
99
100     void visit(AST::NativeFunctionDeclaration&) override;
101     void visit(AST::FunctionDefinition&) override;
102
103 protected:
104     virtual std::unique_ptr<EntryPointScaffolding> createEntryPointScaffolding(AST::FunctionDefinition&) = 0;
105
106     void visit(AST::FunctionDeclaration&) override;
107     void visit(AST::Statement&) override;
108     void visit(AST::Block&) override;
109     void visit(AST::Break&) override;
110     void visit(AST::Continue&) override;
111     void visit(AST::DoWhileLoop&) override;
112     void visit(AST::EffectfulExpressionStatement&) override;
113     void visit(AST::Fallthrough&) override;
114     void visit(AST::ForLoop&) override;
115     void visit(AST::IfStatement&) override;
116     void visit(AST::Return&) override;
117     void visit(AST::SwitchStatement&) override;
118     void visit(AST::SwitchCase&) override;
119     void visit(AST::VariableDeclarationsStatement&) override;
120     void visit(AST::WhileLoop&) override;
121     void visit(AST::IntegerLiteral&) override;
122     void visit(AST::UnsignedIntegerLiteral&) override;
123     void visit(AST::FloatLiteral&) override;
124     void visit(AST::NullLiteral&) override;
125     void visit(AST::BooleanLiteral&) override;
126     void visit(AST::EnumerationMemberLiteral&) override;
127     void visit(AST::Expression&) override;
128     void visit(AST::DotExpression&) override;
129     void visit(AST::GlobalVariableReference&) override;
130     void visit(AST::IndexExpression&) override;
131     void visit(AST::PropertyAccessExpression&) override;
132     void visit(AST::VariableDeclaration&) override;
133     void visit(AST::AssignmentExpression&) override;
134     void visit(AST::CallExpression&) override;
135     void visit(AST::CommaExpression&) override;
136     void visit(AST::DereferenceExpression&) override;
137     void visit(AST::LogicalExpression&) override;
138     void visit(AST::LogicalNotExpression&) override;
139     void visit(AST::MakeArrayReferenceExpression&) override;
140     void visit(AST::MakePointerExpression&) override;
141     void visit(AST::ReadModifyWriteExpression&) override;
142     void visit(AST::TernaryExpression&) override;
143     void visit(AST::VariableReference&) override;
144
145     enum class LoopConditionLocation {
146         BeforeBody,
147         AfterBody
148     };
149     void emitLoop(LoopConditionLocation, AST::Expression* conditionExpression, AST::Expression* increment, AST::Statement& body);
150
151     String constantExpressionString(AST::ConstantExpression&);
152
153     MangledVariableName generateNextVariableName() { return { m_variableCount++ }; }
154
155     enum class Nullability : uint8_t {
156         NotNull,
157         CanBeNull
158     };
159
160     struct StackItem {
161         MangledVariableName value;
162         Optional<MangledVariableName> leftValue;
163         Nullability valueNullability;
164         Nullability leftValueNullability;
165     };
166
167     struct StackValue {
168         MangledVariableName value;
169         Nullability nullability;
170     };
171
172     // This is the important data flow step where we can take the nullability of an lvalue
173     // and transfer it into the nullability of an rvalue. This is conveyed in MakePointerExpression
174     // and DereferenceExpression. MakePointerExpression will try to produce rvalues which are
175     // non-null, and DereferenceExpression will take a non-null rvalue and try to produce
176     // a non-null lvalue.
177     void appendRightValueWithNullability(AST::Expression&, MangledVariableName value, Nullability nullability)
178     {
179         m_stack.append({ WTFMove(value), WTF::nullopt, nullability, Nullability::CanBeNull });
180     }
181
182     void appendRightValue(AST::Expression& expression, MangledVariableName value)
183     {
184         appendRightValueWithNullability(expression, WTFMove(value), Nullability::CanBeNull);
185     }
186
187     void appendLeftValue(AST::Expression& expression, MangledVariableName value, MangledVariableName leftValue, Nullability nullability)
188     {
189         ASSERT_UNUSED(expression, expression.typeAnnotation().leftAddressSpace());
190         m_stack.append({ WTFMove(value), WTFMove(leftValue), Nullability::CanBeNull, nullability });
191     }
192
193     MangledVariableName takeLastValue()
194     {
195         return m_stack.takeLast().value;
196     }
197
198     StackValue takeLastValueAndNullability()
199     {
200         auto last = m_stack.takeLast();
201         return { last.value, last.valueNullability };
202     }
203
204     StackValue takeLastLeftValue()
205     {
206         ASSERT(m_stack.last().leftValue);
207         auto last = m_stack.takeLast();
208         return { *last.leftValue, last.leftValueNullability };
209     }
210
211     enum class BreakContext {
212         Loop,
213         Switch
214     };
215
216     Optional<BreakContext> m_currentBreakContext;
217
218     Intrinsics& m_intrinsics;
219     TypeNamer& m_typeNamer;
220     HashMap<AST::FunctionDeclaration*, MangledFunctionName>& m_functionMapping;
221     HashMap<AST::VariableDeclaration*, MangledVariableName> m_variableMapping;
222     StringBuilder m_stringBuilder;
223
224     Vector<StackItem> m_stack;
225     std::unique_ptr<EntryPointScaffolding> m_entryPointScaffolding;
226     Layout& m_layout;
227     unsigned m_variableCount { 0 };
228     Optional<MangledVariableName> m_breakOutOfCurrentLoopEarlyVariable;
229 };
230
231 void FunctionDefinitionWriter::visit(AST::NativeFunctionDeclaration& nativeFunctionDeclaration)
232 {
233     auto iterator = m_functionMapping.find(&nativeFunctionDeclaration);
234     ASSERT(iterator != m_functionMapping.end());
235     m_stringBuilder.append(writeNativeFunction(nativeFunctionDeclaration, iterator->value, m_intrinsics, m_typeNamer));
236 }
237
238 void FunctionDefinitionWriter::visit(AST::FunctionDefinition& functionDefinition)
239 {
240     auto iterator = m_functionMapping.find(&functionDefinition);
241     ASSERT(iterator != m_functionMapping.end());
242     if (functionDefinition.entryPointType()) {
243         auto entryPointScaffolding = createEntryPointScaffolding(functionDefinition);
244         if (!entryPointScaffolding)
245             return;
246         m_entryPointScaffolding = WTFMove(entryPointScaffolding);
247         m_stringBuilder.flexibleAppend(
248             m_entryPointScaffolding->helperTypes(), '\n',
249             m_entryPointScaffolding->signature(iterator->value), " {\n",
250             m_entryPointScaffolding->unpack()
251         );
252         for (size_t i = 0; i < functionDefinition.parameters().size(); ++i) {
253             auto addResult = m_variableMapping.add(&functionDefinition.parameters()[i], m_entryPointScaffolding->parameterVariables()[i]);
254             ASSERT_UNUSED(addResult, addResult.isNewEntry);
255         }
256         checkErrorAndVisit(functionDefinition.block());
257         ASSERT(m_stack.isEmpty());
258         m_stringBuilder.append("}\n");
259         m_entryPointScaffolding = nullptr;
260     } else {
261         ASSERT(m_entryPointScaffolding == nullptr);
262         m_stringBuilder.flexibleAppend(m_typeNamer.mangledNameForType(functionDefinition.type()), ' ', iterator->value, '(');
263         for (size_t i = 0; i < functionDefinition.parameters().size(); ++i) {
264             auto& parameter = functionDefinition.parameters()[i];
265             if (i)
266                 m_stringBuilder.append(", ");
267             auto parameterName = generateNextVariableName();
268             auto addResult = m_variableMapping.add(&parameter, parameterName);
269             ASSERT_UNUSED(addResult, addResult.isNewEntry);
270             m_stringBuilder.flexibleAppend(m_typeNamer.mangledNameForType(*parameter->type()), ' ', parameterName);
271         }
272         m_stringBuilder.append(") {\n");
273         checkErrorAndVisit(functionDefinition.block());
274         ASSERT(m_stack.isEmpty());
275         m_stringBuilder.append("}\n");
276     }
277 }
278
279 void FunctionDefinitionWriter::visit(AST::FunctionDeclaration&)
280 {
281     ASSERT_NOT_REACHED();
282 }
283
284 void FunctionDefinitionWriter::visit(AST::Statement& statement)
285 {
286     Visitor::visit(statement);
287 }
288
289 void FunctionDefinitionWriter::visit(AST::Block& block)
290 {
291     m_stringBuilder.append("{\n");
292     for (auto& statement : block.statements())
293         checkErrorAndVisit(statement);
294     m_stringBuilder.append("}\n");
295 }
296
297 void FunctionDefinitionWriter::visit(AST::Break&)
298 {
299     ASSERT(m_currentBreakContext);
300     switch (*m_currentBreakContext) {
301     case BreakContext::Switch:
302         m_stringBuilder.append("break;\n");
303         break;
304     case BreakContext::Loop:
305         ASSERT(m_breakOutOfCurrentLoopEarlyVariable);
306         m_stringBuilder.flexibleAppend(
307             *m_breakOutOfCurrentLoopEarlyVariable, " = true;\n"
308             "break;\n"
309         );
310         break;
311     }
312 }
313
314 void FunctionDefinitionWriter::visit(AST::Continue&)
315 {
316     ASSERT(m_breakOutOfCurrentLoopEarlyVariable);
317     m_stringBuilder.append("break;\n");
318 }
319
320 void FunctionDefinitionWriter::visit(AST::EffectfulExpressionStatement& effectfulExpressionStatement)
321 {
322     checkErrorAndVisit(effectfulExpressionStatement.effectfulExpression());
323     takeLastValue(); // The statement is already effectful, so we don't need to do anything with the result.
324 }
325
326 void FunctionDefinitionWriter::visit(AST::Fallthrough&)
327 {
328     m_stringBuilder.append("[[clang::fallthrough]];\n"); // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195808 Make sure this is okay. Alternatively, we could do nothing and just return here instead.
329 }
330
331 void FunctionDefinitionWriter::emitLoop(LoopConditionLocation loopConditionLocation, AST::Expression* conditionExpression, AST::Expression* increment, AST::Statement& body)
332 {
333     SetForScope<Optional<MangledVariableName>> loopVariableScope(m_breakOutOfCurrentLoopEarlyVariable, generateNextVariableName());
334
335     m_stringBuilder.flexibleAppend(
336         "bool ", *m_breakOutOfCurrentLoopEarlyVariable, " = false;\n",
337         "while (true) {\n"
338     );
339
340     if (loopConditionLocation == LoopConditionLocation::BeforeBody && conditionExpression) {
341         checkErrorAndVisit(*conditionExpression);
342         m_stringBuilder.flexibleAppend("if (!", takeLastValue(), ") break;\n");
343     }
344
345     m_stringBuilder.append("do {\n");
346     SetForScope<Optional<BreakContext>> breakContext(m_currentBreakContext, BreakContext::Loop);
347     checkErrorAndVisit(body);
348     m_stringBuilder.flexibleAppend(
349         "} while(false); \n"
350         "if (", *m_breakOutOfCurrentLoopEarlyVariable, ") break;\n"
351     );
352
353     if (increment) {
354         checkErrorAndVisit(*increment);
355         // Expression results get pushed to m_stack. We don't use the result
356         // of increment, so we dispense of that now.
357         takeLastValue();
358     }
359
360     if (loopConditionLocation == LoopConditionLocation::AfterBody && conditionExpression) {
361         checkErrorAndVisit(*conditionExpression);
362         m_stringBuilder.flexibleAppend("if (!", takeLastValue(), ") break;\n");
363     }
364
365     m_stringBuilder.append("} \n");
366 }
367
368 void FunctionDefinitionWriter::visit(AST::DoWhileLoop& doWhileLoop)
369 {
370     emitLoop(LoopConditionLocation::AfterBody, &doWhileLoop.conditional(), nullptr, doWhileLoop.body());
371 }
372
373 void FunctionDefinitionWriter::visit(AST::WhileLoop& whileLoop)
374 {
375     emitLoop(LoopConditionLocation::BeforeBody, &whileLoop.conditional(), nullptr, whileLoop.body());
376 }
377
378 void FunctionDefinitionWriter::visit(AST::ForLoop& forLoop)
379 {
380     m_stringBuilder.append("{\n");
381
382     WTF::visit(WTF::makeVisitor([&](AST::Statement& statement) {
383         checkErrorAndVisit(statement);
384     }, [&](UniqueRef<AST::Expression>& expression) {
385         checkErrorAndVisit(expression);
386         takeLastValue(); // We don't need to do anything with the result.
387     }), forLoop.initialization());
388
389     emitLoop(LoopConditionLocation::BeforeBody, forLoop.condition(), forLoop.increment(), forLoop.body());
390     m_stringBuilder.append("}\n");
391 }
392
393 void FunctionDefinitionWriter::visit(AST::IfStatement& ifStatement)
394 {
395     checkErrorAndVisit(ifStatement.conditional());
396     m_stringBuilder.flexibleAppend("if (", takeLastValue(), ") {\n");
397     checkErrorAndVisit(ifStatement.body());
398     if (ifStatement.elseBody()) {
399         m_stringBuilder.append("} else {\n");
400         checkErrorAndVisit(*ifStatement.elseBody());
401     }
402     m_stringBuilder.append("}\n");
403 }
404
405 void FunctionDefinitionWriter::visit(AST::Return& returnStatement)
406 {
407     if (returnStatement.value()) {
408         checkErrorAndVisit(*returnStatement.value());
409         if (m_entryPointScaffolding) {
410             auto variableName = generateNextVariableName();
411             m_stringBuilder.flexibleAppend(
412                 m_entryPointScaffolding->pack(takeLastValue(), variableName),
413                 "return ", variableName, ";\n"
414             );
415         } else
416             m_stringBuilder.flexibleAppend("return ", takeLastValue(), ";\n");
417     } else
418         m_stringBuilder.append("return;\n");
419 }
420
421 void FunctionDefinitionWriter::visit(AST::SwitchStatement& switchStatement)
422 {
423     checkErrorAndVisit(switchStatement.value());
424
425     m_stringBuilder.flexibleAppend("switch (", takeLastValue(), ") {");
426     for (auto& switchCase : switchStatement.switchCases())
427         checkErrorAndVisit(switchCase);
428     m_stringBuilder.append("}\n");
429 }
430
431 void FunctionDefinitionWriter::visit(AST::SwitchCase& switchCase)
432 {
433     if (switchCase.value())
434         m_stringBuilder.flexibleAppend("case ", constantExpressionString(*switchCase.value()), ":\n");
435     else
436         m_stringBuilder.append("default:\n");
437     SetForScope<Optional<BreakContext>> breakContext(m_currentBreakContext, BreakContext::Switch);
438     checkErrorAndVisit(switchCase.block());
439     // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195812 Figure out whether we need to break or fallthrough.
440 }
441
442 void FunctionDefinitionWriter::visit(AST::VariableDeclarationsStatement& variableDeclarationsStatement)
443 {
444     Visitor::visit(variableDeclarationsStatement);
445 }
446
447 void FunctionDefinitionWriter::visit(AST::IntegerLiteral& integerLiteral)
448 {
449     auto variableName = generateNextVariableName();
450     auto mangledTypeName = m_typeNamer.mangledNameForType(integerLiteral.resolvedType());
451     m_stringBuilder.flexibleAppend(mangledTypeName, ' ', variableName, " = static_cast<", mangledTypeName, ">(", integerLiteral.value(), ");\n");
452     appendRightValue(integerLiteral, variableName);
453 }
454
455 void FunctionDefinitionWriter::visit(AST::UnsignedIntegerLiteral& unsignedIntegerLiteral)
456 {
457     auto variableName = generateNextVariableName();
458     auto mangledTypeName = m_typeNamer.mangledNameForType(unsignedIntegerLiteral.resolvedType());
459     m_stringBuilder.flexibleAppend(mangledTypeName, ' ', variableName, " = static_cast<", mangledTypeName, ">(", unsignedIntegerLiteral.value(), ");\n");
460     appendRightValue(unsignedIntegerLiteral, variableName);
461 }
462
463 void FunctionDefinitionWriter::visit(AST::FloatLiteral& floatLiteral)
464 {
465     auto variableName = generateNextVariableName();
466     auto mangledTypeName = m_typeNamer.mangledNameForType(floatLiteral.resolvedType());
467     m_stringBuilder.flexibleAppend(mangledTypeName, ' ', variableName, " = static_cast<", mangledTypeName, ">(", floatLiteral.value(), ");\n");
468     appendRightValue(floatLiteral, variableName);
469 }
470
471 void FunctionDefinitionWriter::visit(AST::NullLiteral& nullLiteral)
472 {
473     auto& unifyNode = nullLiteral.resolvedType().unifyNode();
474     auto& unnamedType = downcast<AST::UnnamedType>(unifyNode);
475     bool isArrayReferenceType = is<AST::ArrayReferenceType>(unnamedType);
476
477     auto variableName = generateNextVariableName();
478     m_stringBuilder.flexibleAppend(m_typeNamer.mangledNameForType(nullLiteral.resolvedType()), ' ', variableName, " = ");
479     if (isArrayReferenceType)
480         m_stringBuilder.append("{ nullptr, 0 };\n");
481     else
482         m_stringBuilder.append("nullptr;\n");
483     appendRightValue(nullLiteral, variableName);
484 }
485
486 void FunctionDefinitionWriter::visit(AST::BooleanLiteral& booleanLiteral)
487 {
488     auto variableName = generateNextVariableName();
489     auto mangledTypeName = m_typeNamer.mangledNameForType(booleanLiteral.resolvedType());
490     m_stringBuilder.flexibleAppend(mangledTypeName, ' ', variableName, " = static_cast<", mangledTypeName, ">(", booleanLiteral.value() ? "true" : "false", ");\n");
491     appendRightValue(booleanLiteral, variableName);
492 }
493
494 void FunctionDefinitionWriter::visit(AST::EnumerationMemberLiteral& enumerationMemberLiteral)
495 {
496     ASSERT(enumerationMemberLiteral.enumerationDefinition());
497     ASSERT(enumerationMemberLiteral.enumerationDefinition());
498     auto variableName = generateNextVariableName();
499     auto mangledTypeName = m_typeNamer.mangledNameForType(enumerationMemberLiteral.resolvedType());
500     m_stringBuilder.flexibleAppend(mangledTypeName, ' ', variableName, " = ", mangledTypeName, "::", m_typeNamer.mangledNameForEnumerationMember(*enumerationMemberLiteral.enumerationMember()), ";\n");
501     appendRightValue(enumerationMemberLiteral, variableName);
502 }
503
504 void FunctionDefinitionWriter::visit(AST::Expression& expression)
505 {
506     Visitor::visit(expression);
507 }
508
509 void FunctionDefinitionWriter::visit(AST::DotExpression& dotExpression)
510 {
511     // This should be lowered already.
512     // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195788 Replace this with ASSERT_NOT_REACHED().
513     notImplemented();
514     appendRightValue(dotExpression, generateNextVariableName());
515 }
516
517 void FunctionDefinitionWriter::visit(AST::GlobalVariableReference& globalVariableReference)
518 {
519     auto valueName = generateNextVariableName();
520     auto pointerName = generateNextVariableName();
521     auto mangledTypeName = m_typeNamer.mangledNameForType(globalVariableReference.resolvedType());
522     checkErrorAndVisit(globalVariableReference.base());
523     m_stringBuilder.flexibleAppend(
524         "thread ", mangledTypeName, "* ", pointerName, " = &", takeLastValue(), "->", m_typeNamer.mangledNameForStructureElement(globalVariableReference.structField()), ";\n",
525         mangledTypeName, ' ', valueName, " = ", "*", pointerName, ";\n"
526     );
527     appendLeftValue(globalVariableReference, valueName, pointerName, Nullability::NotNull);
528 }
529
530 void FunctionDefinitionWriter::visit(AST::IndexExpression& indexExpression)
531 {
532     // This should be lowered already.
533     // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195788 Replace this with ASSERT_NOT_REACHED().
534     notImplemented();
535     appendRightValue(indexExpression, generateNextVariableName());
536 }
537
538 void FunctionDefinitionWriter::visit(AST::PropertyAccessExpression& propertyAccessExpression)
539 {
540     // This should be lowered already.
541     // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195788 Replace this with ASSERT_NOT_REACHED().
542     notImplemented();
543     appendRightValue(propertyAccessExpression, generateNextVariableName());
544 }
545
546 void FunctionDefinitionWriter::visit(AST::VariableDeclaration& variableDeclaration)
547 {
548     ASSERT(variableDeclaration.type());
549     auto variableName = generateNextVariableName();
550     auto addResult = m_variableMapping.add(&variableDeclaration, variableName);
551     ASSERT_UNUSED(addResult, addResult.isNewEntry);
552     // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198160 Implement qualifiers.
553     if (variableDeclaration.initializer()) {
554         checkErrorAndVisit(*variableDeclaration.initializer());
555         m_stringBuilder.flexibleAppend(m_typeNamer.mangledNameForType(*variableDeclaration.type()), ' ', variableName, " = ", takeLastValue(), ";\n");
556     } else
557         m_stringBuilder.flexibleAppend(m_typeNamer.mangledNameForType(*variableDeclaration.type()), ' ', variableName, ";\n");
558 }
559
560 void FunctionDefinitionWriter::visit(AST::AssignmentExpression& assignmentExpression)
561 {
562     checkErrorAndVisit(assignmentExpression.left());
563     auto [pointerName, nullability] = takeLastLeftValue();
564     checkErrorAndVisit(assignmentExpression.right());
565     auto [rightName, rightNullability] = takeLastValueAndNullability();
566     if (nullability == Nullability::CanBeNull)
567         m_stringBuilder.flexibleAppend("if (", pointerName, ") *", pointerName, " = ", rightName, ";\n");
568     else
569         m_stringBuilder.flexibleAppend("*", pointerName, " = ", rightName, ";\n");
570     appendRightValueWithNullability(assignmentExpression, rightName, rightNullability);
571 }
572
573 void FunctionDefinitionWriter::visit(AST::CallExpression& callExpression)
574 {
575     Vector<MangledVariableName> argumentNames;
576     for (auto& argument : callExpression.arguments()) {
577         checkErrorAndVisit(argument);
578         argumentNames.append(takeLastValue());
579     }
580     auto iterator = m_functionMapping.find(&callExpression.function());
581     ASSERT(iterator != m_functionMapping.end());
582     auto variableName = generateNextVariableName();
583     if (!matches(callExpression.resolvedType(), m_intrinsics.voidType()))
584         m_stringBuilder.flexibleAppend(m_typeNamer.mangledNameForType(callExpression.resolvedType()), ' ', variableName, " = ");
585     m_stringBuilder.flexibleAppend(iterator->value, '(');
586     for (size_t i = 0; i < argumentNames.size(); ++i) {
587         if (i)
588             m_stringBuilder.append(", ");
589         m_stringBuilder.flexibleAppend(argumentNames[i]);
590     }
591     m_stringBuilder.append(");\n");
592     appendRightValue(callExpression, variableName);
593 }
594
595 void FunctionDefinitionWriter::visit(AST::CommaExpression& commaExpression)
596 {
597     Optional<MangledVariableName> result;
598     for (auto& expression : commaExpression.list()) {
599         checkErrorAndVisit(expression);
600         result = takeLastValue();
601     }
602     ASSERT(result);
603     appendRightValue(commaExpression, *result);
604 }
605
606 void FunctionDefinitionWriter::visit(AST::DereferenceExpression& dereferenceExpression)
607 {
608     checkErrorAndVisit(dereferenceExpression.pointer());
609     auto [inputPointer, nullability] = takeLastValueAndNullability();
610     auto resultValue = generateNextVariableName();
611     auto resultPointer = generateNextVariableName();
612     m_stringBuilder.flexibleAppend(
613         m_typeNamer.mangledNameForType(dereferenceExpression.pointer().resolvedType()), ' ', resultPointer, " = ", inputPointer, ";\n",
614         m_typeNamer.mangledNameForType(dereferenceExpression.resolvedType()), ' ', resultValue, ";\n");
615     if (nullability == Nullability::CanBeNull) {
616         m_stringBuilder.flexibleAppend(
617             "if (", resultPointer, ") ", resultValue, " = *", inputPointer, ";\n",
618             "else ", resultValue, " = { };\n");
619     } else
620         m_stringBuilder.flexibleAppend(resultValue, " = *", inputPointer, ";\n");
621     appendLeftValue(dereferenceExpression, resultValue, resultPointer, nullability);
622 }
623
624 void FunctionDefinitionWriter::visit(AST::LogicalExpression& logicalExpression)
625 {
626     checkErrorAndVisit(logicalExpression.left());
627     auto left = takeLastValue();
628     checkErrorAndVisit(logicalExpression.right());
629     auto right = takeLastValue();
630     auto variableName = generateNextVariableName();
631     m_stringBuilder.flexibleAppend(m_typeNamer.mangledNameForType(logicalExpression.resolvedType()), ' ', variableName, " = ", left);
632     switch (logicalExpression.type()) {
633     case AST::LogicalExpression::Type::And:
634         m_stringBuilder.append(" && ");
635         break;
636     default:
637         ASSERT(logicalExpression.type() == AST::LogicalExpression::Type::Or);
638         m_stringBuilder.append(" || ");
639         break;
640     }
641     m_stringBuilder.flexibleAppend(right, ";\n");
642     appendRightValue(logicalExpression, variableName);
643 }
644
645 void FunctionDefinitionWriter::visit(AST::LogicalNotExpression& logicalNotExpression)
646 {
647     checkErrorAndVisit(logicalNotExpression.operand());
648     auto operand = takeLastValue();
649     auto variableName = generateNextVariableName();
650     m_stringBuilder.flexibleAppend(m_typeNamer.mangledNameForType(logicalNotExpression.resolvedType()), ' ', variableName, " = !", operand, ";\n");
651     appendRightValue(logicalNotExpression, variableName);
652 }
653
654 void FunctionDefinitionWriter::visit(AST::MakeArrayReferenceExpression& makeArrayReferenceExpression)
655 {
656     checkErrorAndVisit(makeArrayReferenceExpression.leftValue());
657     // FIXME: This needs to be made to work. It probably should be using the last leftValue too.
658     // https://bugs.webkit.org/show_bug.cgi?id=198838
659     auto variableName = generateNextVariableName();
660
661     auto mangledTypeName = m_typeNamer.mangledNameForType(makeArrayReferenceExpression.resolvedType());
662     if (is<AST::PointerType>(makeArrayReferenceExpression.leftValue().resolvedType())) {
663         auto ptrValue = takeLastValue();
664         m_stringBuilder.flexibleAppend(
665             mangledTypeName, ' ', variableName, ";\n",
666             "if (", ptrValue, ") ", variableName, " = { ", ptrValue, ", 1};\n",
667             "else ", variableName, " = { nullptr, 0 };\n"
668         );
669     } else if (is<AST::ArrayType>(makeArrayReferenceExpression.leftValue().resolvedType())) {
670         auto lValue = takeLastLeftValue().value;
671         auto& arrayType = downcast<AST::ArrayType>(makeArrayReferenceExpression.leftValue().resolvedType());
672         m_stringBuilder.flexibleAppend(mangledTypeName, ' ', variableName, " = { ", lValue, "->data(), ", arrayType.numElements(), " };\n");
673     } else {
674         auto lValue = takeLastLeftValue().value;
675         m_stringBuilder.flexibleAppend(mangledTypeName, ' ', variableName, " = { ", lValue, ", 1 };\n");
676     }
677     appendRightValue(makeArrayReferenceExpression, variableName);
678 }
679
680 void FunctionDefinitionWriter::visit(AST::MakePointerExpression& makePointerExpression)
681 {
682     checkErrorAndVisit(makePointerExpression.leftValue());
683     auto [pointer, nullability] = takeLastLeftValue();
684     auto variableName = generateNextVariableName();
685     m_stringBuilder.flexibleAppend(m_typeNamer.mangledNameForType(makePointerExpression.resolvedType()), ' ', variableName, " = ", pointer, ";\n");
686     appendRightValueWithNullability(makePointerExpression, variableName, nullability);
687 }
688
689 void FunctionDefinitionWriter::visit(AST::ReadModifyWriteExpression&)
690 {
691     // This should be lowered already.
692     ASSERT_NOT_REACHED();
693 }
694
695 void FunctionDefinitionWriter::visit(AST::TernaryExpression& ternaryExpression)
696 {
697     checkErrorAndVisit(ternaryExpression.predicate());
698     auto check = takeLastValue();
699
700     auto variableName = generateNextVariableName();
701     m_stringBuilder.flexibleAppend(
702         m_typeNamer.mangledNameForType(ternaryExpression.resolvedType()), ' ', variableName, ";\n"
703         "if (", check, ") {\n"
704     );
705     checkErrorAndVisit(ternaryExpression.bodyExpression());
706     m_stringBuilder.flexibleAppend(
707         variableName, " = ", takeLastValue(), ";\n"
708         "} else {\n"
709     );
710     checkErrorAndVisit(ternaryExpression.elseExpression());
711     m_stringBuilder.flexibleAppend(
712         variableName, " = ", takeLastValue(), ";\n"
713         "}\n"
714     );
715     appendRightValue(ternaryExpression, variableName);
716 }
717
718 void FunctionDefinitionWriter::visit(AST::VariableReference& variableReference)
719 {
720     ASSERT(variableReference.variable());
721     auto iterator = m_variableMapping.find(variableReference.variable());
722     ASSERT(iterator != m_variableMapping.end());
723     auto pointerName = generateNextVariableName();
724     m_stringBuilder.flexibleAppend("thread ", m_typeNamer.mangledNameForType(variableReference.resolvedType()), "* ", pointerName, " = &", iterator->value, ";\n");
725     appendLeftValue(variableReference, iterator->value, pointerName, Nullability::NotNull);
726 }
727
728 String FunctionDefinitionWriter::constantExpressionString(AST::ConstantExpression& constantExpression)
729 {
730     return constantExpression.visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) -> String {
731         return makeString("", integerLiteral.value());
732     }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) -> String {
733         return makeString("", unsignedIntegerLiteral.value());
734     }, [&](AST::FloatLiteral& floatLiteral) -> String {
735         return makeString("", floatLiteral.value());
736     }, [&](AST::NullLiteral&) -> String {
737         return "nullptr"_str;
738     }, [&](AST::BooleanLiteral& booleanLiteral) -> String {
739         return booleanLiteral.value() ? "true"_str : "false"_str;
740     }, [&](AST::EnumerationMemberLiteral& enumerationMemberLiteral) -> String {
741         ASSERT(enumerationMemberLiteral.enumerationDefinition());
742         ASSERT(enumerationMemberLiteral.enumerationDefinition());
743         return makeString(m_typeNamer.mangledNameForType(*enumerationMemberLiteral.enumerationDefinition()), "::", m_typeNamer.mangledNameForEnumerationMember(*enumerationMemberLiteral.enumerationMember()));
744     }));
745 }
746
747 class RenderFunctionDefinitionWriter : public FunctionDefinitionWriter {
748 public:
749     RenderFunctionDefinitionWriter(Intrinsics& intrinsics, TypeNamer& typeNamer, HashMap<AST::FunctionDeclaration*, MangledFunctionName>& functionMapping, MatchedRenderSemantics&& matchedSemantics, Layout& layout)
750         : FunctionDefinitionWriter(intrinsics, typeNamer, functionMapping, layout)
751         , m_matchedSemantics(WTFMove(matchedSemantics))
752     {
753     }
754
755 private:
756     std::unique_ptr<EntryPointScaffolding> createEntryPointScaffolding(AST::FunctionDefinition&) override;
757
758     MatchedRenderSemantics m_matchedSemantics;
759 };
760
761 std::unique_ptr<EntryPointScaffolding> RenderFunctionDefinitionWriter::createEntryPointScaffolding(AST::FunctionDefinition& functionDefinition)
762 {
763     auto generateNextVariableName = [this]() -> MangledVariableName {
764         return this->generateNextVariableName();
765     };
766     if (&functionDefinition == m_matchedSemantics.vertexShader)
767         return std::make_unique<VertexEntryPointScaffolding>(functionDefinition, m_intrinsics, m_typeNamer, m_matchedSemantics.vertexShaderEntryPointItems, m_matchedSemantics.vertexShaderResourceMap, m_layout, WTFMove(generateNextVariableName), m_matchedSemantics.matchedVertexAttributes);
768     if (&functionDefinition == m_matchedSemantics.fragmentShader)
769         return std::make_unique<FragmentEntryPointScaffolding>(functionDefinition, m_intrinsics, m_typeNamer, m_matchedSemantics.fragmentShaderEntryPointItems, m_matchedSemantics.fragmentShaderResourceMap, m_layout, WTFMove(generateNextVariableName), m_matchedSemantics.matchedColorAttachments);
770     return nullptr;
771 }
772
773 class ComputeFunctionDefinitionWriter : public FunctionDefinitionWriter {
774 public:
775     ComputeFunctionDefinitionWriter(Intrinsics& intrinsics, TypeNamer& typeNamer, HashMap<AST::FunctionDeclaration*, MangledFunctionName>& functionMapping, MatchedComputeSemantics&& matchedSemantics, Layout& layout)
776         : FunctionDefinitionWriter(intrinsics, typeNamer, functionMapping, layout)
777         , m_matchedSemantics(WTFMove(matchedSemantics))
778     {
779     }
780
781 private:
782     std::unique_ptr<EntryPointScaffolding> createEntryPointScaffolding(AST::FunctionDefinition&) override;
783
784     MatchedComputeSemantics m_matchedSemantics;
785 };
786
787 std::unique_ptr<EntryPointScaffolding> ComputeFunctionDefinitionWriter::createEntryPointScaffolding(AST::FunctionDefinition& functionDefinition)
788 {
789     auto generateNextVariableName = [this]() -> MangledVariableName {
790         return this->generateNextVariableName();
791     };
792     if (&functionDefinition == m_matchedSemantics.shader)
793         return std::make_unique<ComputeEntryPointScaffolding>(functionDefinition, m_intrinsics, m_typeNamer, m_matchedSemantics.entryPointItems, m_matchedSemantics.resourceMap, m_layout, WTFMove(generateNextVariableName));
794     return nullptr;
795 }
796
797 struct SharedMetalFunctionsResult {
798     HashMap<AST::FunctionDeclaration*, MangledFunctionName> functionMapping;
799     String metalFunctions;
800 };
801 static SharedMetalFunctionsResult sharedMetalFunctions(Program& program, TypeNamer& typeNamer, const HashSet<AST::FunctionDeclaration*>& reachableFunctions)
802 {
803     StringBuilder stringBuilder;
804
805     unsigned numFunctions = 0;
806     HashMap<AST::FunctionDeclaration*, MangledFunctionName> functionMapping;
807     for (auto& nativeFunctionDeclaration : program.nativeFunctionDeclarations()) {
808         auto addResult = functionMapping.add(&nativeFunctionDeclaration, MangledFunctionName { numFunctions++ });
809         ASSERT_UNUSED(addResult, addResult.isNewEntry);
810     }
811     for (auto& functionDefinition : program.functionDefinitions()) {
812         auto addResult = functionMapping.add(&functionDefinition, MangledFunctionName { numFunctions++ });
813         ASSERT_UNUSED(addResult, addResult.isNewEntry);
814     }
815
816     {
817         FunctionDeclarationWriter functionDeclarationWriter(typeNamer, functionMapping);
818         for (auto& nativeFunctionDeclaration : program.nativeFunctionDeclarations()) {
819             if (reachableFunctions.contains(&nativeFunctionDeclaration))
820                 functionDeclarationWriter.visit(nativeFunctionDeclaration);
821         }
822         for (auto& functionDefinition : program.functionDefinitions()) {
823             if (!functionDefinition->entryPointType() && reachableFunctions.contains(&functionDefinition))
824                 functionDeclarationWriter.visit(functionDefinition);
825         }
826         stringBuilder.append(functionDeclarationWriter.toString());
827     }
828
829     stringBuilder.append('\n');
830     return { WTFMove(functionMapping), stringBuilder.toString() };
831 }
832
833 class ReachableFunctionsGatherer : public Visitor {
834 public:
835     void visit(AST::FunctionDeclaration& functionDeclaration) override
836     {
837         Visitor::visit(functionDeclaration);
838         m_reachableFunctions.add(&functionDeclaration);
839     }
840
841     void visit(AST::CallExpression& callExpression) override
842     {
843         Visitor::visit(callExpression);
844         if (is<AST::FunctionDefinition>(callExpression.function()))
845             checkErrorAndVisit(downcast<AST::FunctionDefinition>(callExpression.function()));
846         else
847             checkErrorAndVisit(downcast<AST::NativeFunctionDeclaration>(callExpression.function()));
848     }
849
850     HashSet<AST::FunctionDeclaration*> takeReachableFunctions() { return WTFMove(m_reachableFunctions); }
851
852 private:
853     HashSet<AST::FunctionDeclaration*> m_reachableFunctions;
854 };
855
856 RenderMetalFunctions metalFunctions(Program& program, TypeNamer& typeNamer, MatchedRenderSemantics&& matchedSemantics, Layout& layout)
857 {
858     auto& vertexShaderEntryPoint = *matchedSemantics.vertexShader;
859     auto& fragmentShaderEntryPoint = *matchedSemantics.fragmentShader;
860
861     ReachableFunctionsGatherer reachableFunctionsGatherer;
862     reachableFunctionsGatherer.Visitor::visit(vertexShaderEntryPoint);
863     reachableFunctionsGatherer.Visitor::visit(fragmentShaderEntryPoint);
864     auto reachableFunctions = reachableFunctionsGatherer.takeReachableFunctions();
865
866     auto sharedMetalFunctions = Metal::sharedMetalFunctions(program, typeNamer, reachableFunctions);
867
868     StringBuilder stringBuilder;
869     stringBuilder.append(sharedMetalFunctions.metalFunctions);
870
871     RenderFunctionDefinitionWriter functionDefinitionWriter(program.intrinsics(), typeNamer, sharedMetalFunctions.functionMapping, WTFMove(matchedSemantics), layout);
872     for (auto& nativeFunctionDeclaration : program.nativeFunctionDeclarations()) {
873         if (reachableFunctions.contains(&nativeFunctionDeclaration))
874             functionDefinitionWriter.visit(nativeFunctionDeclaration);
875     }
876     for (auto& functionDefinition : program.functionDefinitions()) {
877         if (reachableFunctions.contains(&functionDefinition))
878             functionDefinitionWriter.visit(functionDefinition);
879     }
880     stringBuilder.append(functionDefinitionWriter.toString());
881
882     RenderMetalFunctions result;
883     result.metalSource = stringBuilder.toString();
884     result.mangledVertexEntryPointName = sharedMetalFunctions.functionMapping.get(&vertexShaderEntryPoint);
885     result.mangledFragmentEntryPointName = sharedMetalFunctions.functionMapping.get(&fragmentShaderEntryPoint);
886     return result;
887 }
888
889 ComputeMetalFunctions metalFunctions(Program& program, TypeNamer& typeNamer, MatchedComputeSemantics&& matchedSemantics, Layout& layout)
890 {
891     auto& entryPoint = *matchedSemantics.shader;
892
893     ReachableFunctionsGatherer reachableFunctionsGatherer;
894     reachableFunctionsGatherer.Visitor::visit(entryPoint);
895     auto reachableFunctions = reachableFunctionsGatherer.takeReachableFunctions();
896
897     auto sharedMetalFunctions = Metal::sharedMetalFunctions(program, typeNamer, reachableFunctions);
898
899     StringBuilder stringBuilder;
900     stringBuilder.append(sharedMetalFunctions.metalFunctions);
901
902     ComputeFunctionDefinitionWriter functionDefinitionWriter(program.intrinsics(), typeNamer, sharedMetalFunctions.functionMapping, WTFMove(matchedSemantics), layout);
903     for (auto& nativeFunctionDeclaration : program.nativeFunctionDeclarations()) {
904         if (reachableFunctions.contains(&nativeFunctionDeclaration))
905             functionDefinitionWriter.visit(nativeFunctionDeclaration);
906     }
907     for (auto& functionDefinition : program.functionDefinitions()) {
908         if (reachableFunctions.contains(&functionDefinition))
909             functionDefinitionWriter.visit(functionDefinition);
910     }
911     stringBuilder.append(functionDefinitionWriter.toString());
912
913     ComputeMetalFunctions result;
914     result.metalSource = stringBuilder.toString();
915     result.mangledEntryPointName = sharedMetalFunctions.functionMapping.get(&entryPoint);
916     return result;
917 }
918
919 } // namespace Metal
920
921 } // namespace WHLSL
922
923 } // namespace WebCore
924
925 #endif