StringBuilder::append(makeString(...)) is inefficient
[WebKit-https.git] / Source / WebCore / Modules / webgpu / WHLSL / Metal / WHLSLFunctionWriter.cpp
1 /*
2  * Copyright (C) 2019 Apple Inc. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  * 1. Redistributions of source code must retain the above copyright
8  *    notice, this list of conditions and the following disclaimer.
9  * 2. Redistributions in binary form must reproduce the above copyright
10  *    notice, this list of conditions and the following disclaimer in the
11  *    documentation and/or other materials provided with the distribution.
12  *
13  * THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS ``AS IS''
14  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
15  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS
17  * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23  * THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 #include "config.h"
27 #include "WHLSLFunctionWriter.h"
28
29 #if ENABLE(WEBGPU)
30
31 #include "NotImplemented.h"
32 #include "WHLSLAST.h"
33 #include "WHLSLEntryPointScaffolding.h"
34 #include "WHLSLInferTypes.h"
35 #include "WHLSLNativeFunctionWriter.h"
36 #include "WHLSLProgram.h"
37 #include "WHLSLTypeNamer.h"
38 #include "WHLSLVisitor.h"
39 #include <wtf/HashMap.h>
40 #include <wtf/HashSet.h>
41 #include <wtf/SetForScope.h>
42 #include <wtf/text/StringBuilder.h>
43
44 namespace WebCore {
45
46 namespace WHLSL {
47
48 namespace Metal {
49
50 class FunctionDeclarationWriter : public Visitor {
51 public:
52     FunctionDeclarationWriter(TypeNamer& typeNamer, HashMap<AST::FunctionDeclaration*, String>& functionMapping)
53         : m_typeNamer(typeNamer)
54         , m_functionMapping(functionMapping)
55     {
56     }
57
58     virtual ~FunctionDeclarationWriter() = default;
59
60     String toString() { return m_stringBuilder.toString(); }
61
62     void visit(AST::FunctionDeclaration&) override;
63
64 private:
65     TypeNamer& m_typeNamer;
66     HashMap<AST::FunctionDeclaration*, String>& m_functionMapping;
67     StringBuilder m_stringBuilder;
68 };
69
70 void FunctionDeclarationWriter::visit(AST::FunctionDeclaration& functionDeclaration)
71 {
72     if (functionDeclaration.entryPointType())
73         return;
74
75     auto iterator = m_functionMapping.find(&functionDeclaration);
76     ASSERT(iterator != m_functionMapping.end());
77     m_stringBuilder.flexibleAppend(m_typeNamer.mangledNameForType(functionDeclaration.type()), ' ', iterator->value, '(');
78     for (size_t i = 0; i < functionDeclaration.parameters().size(); ++i) {
79         if (i)
80             m_stringBuilder.append(", ");
81         m_stringBuilder.append(m_typeNamer.mangledNameForType(*functionDeclaration.parameters()[i]->type()));
82     }
83     m_stringBuilder.append(");\n");
84 }
85
86 class FunctionDefinitionWriter : public Visitor {
87 public:
88     FunctionDefinitionWriter(Intrinsics& intrinsics, TypeNamer& typeNamer, HashMap<AST::FunctionDeclaration*, String>& functionMapping, Layout& layout)
89         : m_intrinsics(intrinsics)
90         , m_typeNamer(typeNamer)
91         , m_functionMapping(functionMapping)
92         , m_layout(layout)
93     {
94         m_stringBuilder.flexibleAppend(
95             "template <typename T>\n"
96             "inline void ", memsetZeroFunctionName, "(thread T& value)\n"
97             "{\n"
98             "    thread char* ptr = static_cast<thread char*>(static_cast<thread void*>(&value));\n"
99             "    for (size_t i = 0; i < sizeof(T); ++i)\n"
100             "        ptr[i] = 0;\n"
101             "}\n");
102     }
103
104     static constexpr const char* memsetZeroFunctionName = "memsetZero";
105
106     virtual ~FunctionDefinitionWriter() = default;
107
108     String toString() { return m_stringBuilder.toString(); }
109
110     void visit(AST::NativeFunctionDeclaration&) override;
111     void visit(AST::FunctionDefinition&) override;
112
113 protected:
114     virtual std::unique_ptr<EntryPointScaffolding> createEntryPointScaffolding(AST::FunctionDefinition&) = 0;
115
116     void visit(AST::FunctionDeclaration&) override;
117     void visit(AST::Statement&) override;
118     void visit(AST::Block&) override;
119     void visit(AST::Break&) override;
120     void visit(AST::Continue&) override;
121     void visit(AST::DoWhileLoop&) override;
122     void visit(AST::EffectfulExpressionStatement&) override;
123     void visit(AST::Fallthrough&) override;
124     void visit(AST::ForLoop&) override;
125     void visit(AST::IfStatement&) override;
126     void visit(AST::Return&) override;
127     void visit(AST::SwitchStatement&) override;
128     void visit(AST::SwitchCase&) override;
129     void visit(AST::VariableDeclarationsStatement&) override;
130     void visit(AST::WhileLoop&) override;
131     void visit(AST::IntegerLiteral&) override;
132     void visit(AST::UnsignedIntegerLiteral&) override;
133     void visit(AST::FloatLiteral&) override;
134     void visit(AST::NullLiteral&) override;
135     void visit(AST::BooleanLiteral&) override;
136     void visit(AST::EnumerationMemberLiteral&) override;
137     void visit(AST::Expression&) override;
138     void visit(AST::DotExpression&) override;
139     void visit(AST::GlobalVariableReference&) override;
140     void visit(AST::IndexExpression&) override;
141     void visit(AST::PropertyAccessExpression&) override;
142     void visit(AST::VariableDeclaration&) override;
143     void visit(AST::AssignmentExpression&) override;
144     void visit(AST::CallExpression&) override;
145     void visit(AST::CommaExpression&) override;
146     void visit(AST::DereferenceExpression&) override;
147     void visit(AST::LogicalExpression&) override;
148     void visit(AST::LogicalNotExpression&) override;
149     void visit(AST::MakeArrayReferenceExpression&) override;
150     void visit(AST::MakePointerExpression&) override;
151     void visit(AST::ReadModifyWriteExpression&) override;
152     void visit(AST::TernaryExpression&) override;
153     void visit(AST::VariableReference&) override;
154
155     enum class LoopConditionLocation {
156         BeforeBody,
157         AfterBody
158     };
159     void emitLoop(LoopConditionLocation, AST::Expression* conditionExpression, AST::Expression* increment, AST::Statement& body);
160
161     String constantExpressionString(AST::ConstantExpression&);
162
163     String generateNextVariableName()
164     {
165         return makeString("variable", m_variableCount++);
166     }
167
168     struct StackItem {
169         String value;
170         String leftValue;
171     };
172
173     void appendRightValue(AST::Expression&, String value)
174     {
175         m_stack.append({ WTFMove(value), String() });
176     }
177
178     void appendLeftValue(AST::Expression& expression, String value, String leftValue)
179     {
180         ASSERT_UNUSED(expression, expression.typeAnnotation().leftAddressSpace());
181         m_stack.append({ WTFMove(value), WTFMove(leftValue) });
182     }
183
184     String takeLastValue()
185     {
186         ASSERT(m_stack.last().value);
187         return m_stack.takeLast().value;
188     }
189
190     String takeLastLeftValue()
191     {
192         ASSERT(m_stack.last().leftValue);
193         return m_stack.takeLast().leftValue;
194     }
195
196     enum class BreakContext {
197         Loop,
198         Switch
199     };
200
201     Optional<BreakContext> m_currentBreakContext;
202
203     Intrinsics& m_intrinsics;
204     TypeNamer& m_typeNamer;
205     HashMap<AST::FunctionDeclaration*, String>& m_functionMapping;
206     HashMap<AST::VariableDeclaration*, String> m_variableMapping;
207     StringBuilder m_stringBuilder;
208
209     Vector<StackItem> m_stack;
210     std::unique_ptr<EntryPointScaffolding> m_entryPointScaffolding;
211     Layout& m_layout;
212     unsigned m_variableCount { 0 };
213     String m_breakOutOfCurrentLoopEarlyVariable;
214 };
215
216 void FunctionDefinitionWriter::visit(AST::NativeFunctionDeclaration& nativeFunctionDeclaration)
217 {
218     auto iterator = m_functionMapping.find(&nativeFunctionDeclaration);
219     ASSERT(iterator != m_functionMapping.end());
220     m_stringBuilder.append(writeNativeFunction(nativeFunctionDeclaration, iterator->value, m_intrinsics, m_typeNamer, memsetZeroFunctionName));
221 }
222
223 void FunctionDefinitionWriter::visit(AST::FunctionDefinition& functionDefinition)
224 {
225     auto iterator = m_functionMapping.find(&functionDefinition);
226     ASSERT(iterator != m_functionMapping.end());
227     if (functionDefinition.entryPointType()) {
228         auto entryPointScaffolding = createEntryPointScaffolding(functionDefinition);
229         if (!entryPointScaffolding)
230             return;
231         m_entryPointScaffolding = WTFMove(entryPointScaffolding);
232         m_stringBuilder.flexibleAppend(
233             m_entryPointScaffolding->helperTypes(), '\n',
234             m_entryPointScaffolding->signature(iterator->value), " {\n",
235             m_entryPointScaffolding->unpack()
236         );
237         for (size_t i = 0; i < functionDefinition.parameters().size(); ++i) {
238             auto addResult = m_variableMapping.add(&functionDefinition.parameters()[i], m_entryPointScaffolding->parameterVariables()[i]);
239             ASSERT_UNUSED(addResult, addResult.isNewEntry);
240         }
241         checkErrorAndVisit(functionDefinition.block());
242         ASSERT(m_stack.isEmpty());
243         m_stringBuilder.append("}\n");
244         m_entryPointScaffolding = nullptr;
245     } else {
246         ASSERT(m_entryPointScaffolding == nullptr);
247         m_stringBuilder.flexibleAppend(m_typeNamer.mangledNameForType(functionDefinition.type()), ' ', iterator->value, '(');
248         for (size_t i = 0; i < functionDefinition.parameters().size(); ++i) {
249             auto& parameter = functionDefinition.parameters()[i];
250             if (i)
251                 m_stringBuilder.append(", ");
252             auto parameterName = generateNextVariableName();
253             auto addResult = m_variableMapping.add(&parameter, parameterName);
254             ASSERT_UNUSED(addResult, addResult.isNewEntry);
255             m_stringBuilder.flexibleAppend(m_typeNamer.mangledNameForType(*parameter->type()), ' ', parameterName);
256         }
257         m_stringBuilder.append(") {\n");
258         checkErrorAndVisit(functionDefinition.block());
259         ASSERT(m_stack.isEmpty());
260         m_stringBuilder.append("}\n");
261     }
262 }
263
264 void FunctionDefinitionWriter::visit(AST::FunctionDeclaration&)
265 {
266     ASSERT_NOT_REACHED();
267 }
268
269 void FunctionDefinitionWriter::visit(AST::Statement& statement)
270 {
271     Visitor::visit(statement);
272 }
273
274 void FunctionDefinitionWriter::visit(AST::Block& block)
275 {
276     m_stringBuilder.append("{\n");
277     for (auto& statement : block.statements())
278         checkErrorAndVisit(statement);
279     m_stringBuilder.append("}\n");
280 }
281
282 void FunctionDefinitionWriter::visit(AST::Break&)
283 {
284     ASSERT(m_currentBreakContext);
285     switch (*m_currentBreakContext) {
286     case BreakContext::Switch:
287         m_stringBuilder.append("break;\n");
288         break;
289     case BreakContext::Loop:
290         ASSERT(m_breakOutOfCurrentLoopEarlyVariable.length());
291         m_stringBuilder.flexibleAppend(
292             m_breakOutOfCurrentLoopEarlyVariable, " = true;\n"
293             "break;\n"
294         );
295         break;
296     }
297 }
298
299 void FunctionDefinitionWriter::visit(AST::Continue&)
300 {
301     ASSERT(m_breakOutOfCurrentLoopEarlyVariable.length());
302     m_stringBuilder.append("break;\n");
303 }
304
305 void FunctionDefinitionWriter::visit(AST::EffectfulExpressionStatement& effectfulExpressionStatement)
306 {
307     checkErrorAndVisit(effectfulExpressionStatement.effectfulExpression());
308     takeLastValue(); // The statement is already effectful, so we don't need to do anything with the result.
309 }
310
311 void FunctionDefinitionWriter::visit(AST::Fallthrough&)
312 {
313     m_stringBuilder.append("[[clang::fallthrough]];\n"); // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195808 Make sure this is okay. Alternatively, we could do nothing and just return here instead.
314 }
315
316 void FunctionDefinitionWriter::emitLoop(LoopConditionLocation loopConditionLocation, AST::Expression* conditionExpression, AST::Expression* increment, AST::Statement& body)
317 {
318     SetForScope<String> loopVariableScope(m_breakOutOfCurrentLoopEarlyVariable, generateNextVariableName());
319
320     m_stringBuilder.flexibleAppend(
321         "bool ", m_breakOutOfCurrentLoopEarlyVariable, " = false;\n",
322         "while (true) {\n"
323     );
324
325     if (loopConditionLocation == LoopConditionLocation::BeforeBody && conditionExpression) {
326         checkErrorAndVisit(*conditionExpression);
327         m_stringBuilder.flexibleAppend("if (!", takeLastValue(), ") break;\n");
328     }
329
330     m_stringBuilder.append("do {\n");
331     SetForScope<Optional<BreakContext>> breakContext(m_currentBreakContext, BreakContext::Loop);
332     checkErrorAndVisit(body);
333     m_stringBuilder.flexibleAppend(
334         "} while(false); \n"
335         "if (", m_breakOutOfCurrentLoopEarlyVariable, ") break;\n"
336     );
337
338     if (increment) {
339         checkErrorAndVisit(*increment);
340         // Expression results get pushed to m_stack. We don't use the result
341         // of increment, so we dispense of that now.
342         takeLastValue();
343     }
344
345     if (loopConditionLocation == LoopConditionLocation::AfterBody && conditionExpression) {
346         checkErrorAndVisit(*conditionExpression);
347         m_stringBuilder.flexibleAppend("if (!", takeLastValue(), ") break;\n");
348     }
349
350     m_stringBuilder.append("} \n");
351 }
352
353 void FunctionDefinitionWriter::visit(AST::DoWhileLoop& doWhileLoop)
354 {
355     emitLoop(LoopConditionLocation::AfterBody, &doWhileLoop.conditional(), nullptr, doWhileLoop.body());
356 }
357
358 void FunctionDefinitionWriter::visit(AST::WhileLoop& whileLoop)
359 {
360     emitLoop(LoopConditionLocation::BeforeBody, &whileLoop.conditional(), nullptr, whileLoop.body());
361 }
362
363 void FunctionDefinitionWriter::visit(AST::ForLoop& forLoop)
364 {
365     m_stringBuilder.append("{\n");
366
367     WTF::visit(WTF::makeVisitor([&](AST::Statement& statement) {
368         checkErrorAndVisit(statement);
369     }, [&](UniqueRef<AST::Expression>& expression) {
370         checkErrorAndVisit(expression);
371         takeLastValue(); // We don't need to do anything with the result.
372     }), forLoop.initialization());
373
374     emitLoop(LoopConditionLocation::BeforeBody, forLoop.condition(), forLoop.increment(), forLoop.body());
375     m_stringBuilder.append("}\n");
376 }
377
378 void FunctionDefinitionWriter::visit(AST::IfStatement& ifStatement)
379 {
380     checkErrorAndVisit(ifStatement.conditional());
381     m_stringBuilder.flexibleAppend("if (", takeLastValue(), ") {\n");
382     checkErrorAndVisit(ifStatement.body());
383     if (ifStatement.elseBody()) {
384         m_stringBuilder.append("} else {\n");
385         checkErrorAndVisit(*ifStatement.elseBody());
386     }
387     m_stringBuilder.append("}\n");
388 }
389
390 void FunctionDefinitionWriter::visit(AST::Return& returnStatement)
391 {
392     if (returnStatement.value()) {
393         checkErrorAndVisit(*returnStatement.value());
394         if (m_entryPointScaffolding) {
395             auto variableName = generateNextVariableName();
396             m_stringBuilder.flexibleAppend(
397                 m_entryPointScaffolding->pack(takeLastValue(), variableName),
398                 "return ", variableName, ";\n"
399             );
400         } else
401             m_stringBuilder.flexibleAppend("return ", takeLastValue(), ";\n");
402     } else
403         m_stringBuilder.append("return;\n");
404 }
405
406 void FunctionDefinitionWriter::visit(AST::SwitchStatement& switchStatement)
407 {
408     checkErrorAndVisit(switchStatement.value());
409
410     m_stringBuilder.flexibleAppend("switch (", takeLastValue(), ") {");
411     for (auto& switchCase : switchStatement.switchCases())
412         checkErrorAndVisit(switchCase);
413     m_stringBuilder.append("}\n");
414 }
415
416 void FunctionDefinitionWriter::visit(AST::SwitchCase& switchCase)
417 {
418     if (switchCase.value())
419         m_stringBuilder.flexibleAppend("case ", constantExpressionString(*switchCase.value()), ":\n");
420     else
421         m_stringBuilder.append("default:\n");
422     SetForScope<Optional<BreakContext>> breakContext(m_currentBreakContext, BreakContext::Switch);
423     checkErrorAndVisit(switchCase.block());
424     // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195812 Figure out whether we need to break or fallthrough.
425 }
426
427 void FunctionDefinitionWriter::visit(AST::VariableDeclarationsStatement& variableDeclarationsStatement)
428 {
429     Visitor::visit(variableDeclarationsStatement);
430 }
431
432 void FunctionDefinitionWriter::visit(AST::IntegerLiteral& integerLiteral)
433 {
434     auto variableName = generateNextVariableName();
435     auto mangledTypeName = m_typeNamer.mangledNameForType(integerLiteral.resolvedType());
436     m_stringBuilder.flexibleAppend(mangledTypeName, ' ', variableName, " = static_cast<", mangledTypeName, ">(", integerLiteral.value(), ");\n");
437     appendRightValue(integerLiteral, variableName);
438 }
439
440 void FunctionDefinitionWriter::visit(AST::UnsignedIntegerLiteral& unsignedIntegerLiteral)
441 {
442     auto variableName = generateNextVariableName();
443     auto mangledTypeName = m_typeNamer.mangledNameForType(unsignedIntegerLiteral.resolvedType());
444     m_stringBuilder.flexibleAppend(mangledTypeName, ' ', variableName, " = static_cast<", mangledTypeName, ">(", unsignedIntegerLiteral.value(), ");\n");
445     appendRightValue(unsignedIntegerLiteral, variableName);
446 }
447
448 void FunctionDefinitionWriter::visit(AST::FloatLiteral& floatLiteral)
449 {
450     auto variableName = generateNextVariableName();
451     auto mangledTypeName = m_typeNamer.mangledNameForType(floatLiteral.resolvedType());
452     m_stringBuilder.flexibleAppend(mangledTypeName, ' ', variableName, " = static_cast<", mangledTypeName, ">(", floatLiteral.value(), ");\n");
453     appendRightValue(floatLiteral, variableName);
454 }
455
456 void FunctionDefinitionWriter::visit(AST::NullLiteral& nullLiteral)
457 {
458     auto& unifyNode = nullLiteral.resolvedType().unifyNode();
459     auto& unnamedType = downcast<AST::UnnamedType>(unifyNode);
460     bool isArrayReferenceType = is<AST::ArrayReferenceType>(unnamedType);
461
462     auto variableName = generateNextVariableName();
463     m_stringBuilder.flexibleAppend(m_typeNamer.mangledNameForType(nullLiteral.resolvedType()), ' ', variableName, " = ");
464     if (isArrayReferenceType)
465         m_stringBuilder.append("{ nullptr, 0 };\n");
466     else
467         m_stringBuilder.append("nullptr;\n");
468     appendRightValue(nullLiteral, variableName);
469 }
470
471 void FunctionDefinitionWriter::visit(AST::BooleanLiteral& booleanLiteral)
472 {
473     auto variableName = generateNextVariableName();
474     auto mangledTypeName = m_typeNamer.mangledNameForType(booleanLiteral.resolvedType());
475     m_stringBuilder.flexibleAppend(mangledTypeName, ' ', variableName, " = static_cast<", mangledTypeName, ">(", booleanLiteral.value() ? "true" : "false", ");\n");
476     appendRightValue(booleanLiteral, variableName);
477 }
478
479 void FunctionDefinitionWriter::visit(AST::EnumerationMemberLiteral& enumerationMemberLiteral)
480 {
481     ASSERT(enumerationMemberLiteral.enumerationDefinition());
482     ASSERT(enumerationMemberLiteral.enumerationDefinition());
483     auto variableName = generateNextVariableName();
484     auto mangledTypeName = m_typeNamer.mangledNameForType(enumerationMemberLiteral.resolvedType());
485     m_stringBuilder.flexibleAppend(mangledTypeName, ' ', variableName, " = ", mangledTypeName, "::", m_typeNamer.mangledNameForEnumerationMember(*enumerationMemberLiteral.enumerationMember()), ";\n");
486     appendRightValue(enumerationMemberLiteral, variableName);
487 }
488
489 void FunctionDefinitionWriter::visit(AST::Expression& expression)
490 {
491     Visitor::visit(expression);
492 }
493
494 void FunctionDefinitionWriter::visit(AST::DotExpression& dotExpression)
495 {
496     // This should be lowered already.
497     // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195788 Replace this with ASSERT_NOT_REACHED().
498     notImplemented();
499     appendRightValue(dotExpression, "dummy");
500 }
501
502 void FunctionDefinitionWriter::visit(AST::GlobalVariableReference& globalVariableReference)
503 {
504     auto valueName = generateNextVariableName();
505     auto pointerName = generateNextVariableName();
506     auto mangledTypeName = m_typeNamer.mangledNameForType(globalVariableReference.resolvedType());
507     checkErrorAndVisit(globalVariableReference.base());
508     m_stringBuilder.flexibleAppend(
509         "thread ", mangledTypeName, "* ", pointerName, " = &", takeLastValue(), "->", m_typeNamer.mangledNameForStructureElement(globalVariableReference.structField()), ";\n",
510         mangledTypeName, ' ', valueName, " = ", "*", pointerName, ";\n"
511     );
512     appendLeftValue(globalVariableReference, valueName, pointerName);
513 }
514
515 void FunctionDefinitionWriter::visit(AST::IndexExpression& indexExpression)
516 {
517     // This should be lowered already.
518     // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195788 Replace this with ASSERT_NOT_REACHED().
519     notImplemented();
520     appendRightValue(indexExpression, "dummy");
521 }
522
523 void FunctionDefinitionWriter::visit(AST::PropertyAccessExpression& propertyAccessExpression)
524 {
525     // This should be lowered already.
526     // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195788 Replace this with ASSERT_NOT_REACHED().
527     notImplemented();
528     appendRightValue(propertyAccessExpression, "dummy");
529 }
530
531 void FunctionDefinitionWriter::visit(AST::VariableDeclaration& variableDeclaration)
532 {
533     ASSERT(variableDeclaration.type());
534     auto variableName = generateNextVariableName();
535     auto addResult = m_variableMapping.add(&variableDeclaration, variableName);
536     ASSERT_UNUSED(addResult, addResult.isNewEntry);
537     // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198160 Implement qualifiers.
538     if (variableDeclaration.initializer()) {
539         checkErrorAndVisit(*variableDeclaration.initializer());
540         m_stringBuilder.flexibleAppend(m_typeNamer.mangledNameForType(*variableDeclaration.type()), ' ', variableName, " = ", takeLastValue(), ";\n");
541     } else
542         m_stringBuilder.flexibleAppend(m_typeNamer.mangledNameForType(*variableDeclaration.type()), ' ', variableName, ";\n");
543 }
544
545 void FunctionDefinitionWriter::visit(AST::AssignmentExpression& assignmentExpression)
546 {
547     checkErrorAndVisit(assignmentExpression.left());
548     auto pointerName = takeLastLeftValue();
549     checkErrorAndVisit(assignmentExpression.right());
550     auto rightName = takeLastValue();
551     m_stringBuilder.flexibleAppend("if (", pointerName, ") *", pointerName, " = ", rightName, ";\n");
552     appendRightValue(assignmentExpression, rightName);
553 }
554
555 void FunctionDefinitionWriter::visit(AST::CallExpression& callExpression)
556 {
557     Vector<String> argumentNames;
558     for (auto& argument : callExpression.arguments()) {
559         checkErrorAndVisit(argument);
560         argumentNames.append(takeLastValue());
561     }
562     auto iterator = m_functionMapping.find(&callExpression.function());
563     ASSERT(iterator != m_functionMapping.end());
564     auto variableName = generateNextVariableName();
565     if (!matches(callExpression.resolvedType(), m_intrinsics.voidType()))
566         m_stringBuilder.flexibleAppend(m_typeNamer.mangledNameForType(callExpression.resolvedType()), ' ', variableName, " = ");
567     m_stringBuilder.flexibleAppend(iterator->value, '(');
568     for (size_t i = 0; i < argumentNames.size(); ++i) {
569         if (i)
570             m_stringBuilder.append(", ");
571         m_stringBuilder.append(argumentNames[i]);
572     }
573     m_stringBuilder.append(");\n");
574     appendRightValue(callExpression, variableName);
575 }
576
577 void FunctionDefinitionWriter::visit(AST::CommaExpression& commaExpression)
578 {
579     String result;
580     for (auto& expression : commaExpression.list()) {
581         checkErrorAndVisit(expression);
582         result = takeLastValue();
583     }
584     appendRightValue(commaExpression, result);
585 }
586
587 void FunctionDefinitionWriter::visit(AST::DereferenceExpression& dereferenceExpression)
588 {
589     checkErrorAndVisit(dereferenceExpression.pointer());
590     auto right = takeLastValue();
591     auto variableName = generateNextVariableName();
592     auto pointerName = generateNextVariableName();
593     m_stringBuilder.flexibleAppend(
594         m_typeNamer.mangledNameForType(dereferenceExpression.pointer().resolvedType()), ' ', pointerName, " = ", right, ";\n",
595         m_typeNamer.mangledNameForType(dereferenceExpression.resolvedType()), ' ', variableName, ";\n",
596         "if (", pointerName, ") ", variableName, " = *", right, ";\n",
597         "else ", memsetZeroFunctionName, '(', variableName, ");\n"
598     );
599     appendLeftValue(dereferenceExpression, variableName, pointerName);
600 }
601
602 void FunctionDefinitionWriter::visit(AST::LogicalExpression& logicalExpression)
603 {
604     checkErrorAndVisit(logicalExpression.left());
605     auto left = takeLastValue();
606     checkErrorAndVisit(logicalExpression.right());
607     auto right = takeLastValue();
608     auto variableName = generateNextVariableName();
609     m_stringBuilder.flexibleAppend(m_typeNamer.mangledNameForType(logicalExpression.resolvedType()), ' ', variableName, " = ", left);
610     switch (logicalExpression.type()) {
611     case AST::LogicalExpression::Type::And:
612         m_stringBuilder.append(" && ");
613         break;
614     default:
615         ASSERT(logicalExpression.type() == AST::LogicalExpression::Type::Or);
616         m_stringBuilder.append(" || ");
617         break;
618     }
619     m_stringBuilder.flexibleAppend(right, ";\n");
620     appendRightValue(logicalExpression, variableName);
621 }
622
623 void FunctionDefinitionWriter::visit(AST::LogicalNotExpression& logicalNotExpression)
624 {
625     checkErrorAndVisit(logicalNotExpression.operand());
626     auto operand = takeLastValue();
627     auto variableName = generateNextVariableName();
628     m_stringBuilder.flexibleAppend(m_typeNamer.mangledNameForType(logicalNotExpression.resolvedType()), ' ', variableName, " = !", operand, ";\n");
629     appendRightValue(logicalNotExpression, variableName);
630 }
631
632 void FunctionDefinitionWriter::visit(AST::MakeArrayReferenceExpression& makeArrayReferenceExpression)
633 {
634     checkErrorAndVisit(makeArrayReferenceExpression.leftValue());
635     // FIXME: This needs to be made to work. It probably should be using the last leftValue too.
636     // https://bugs.webkit.org/show_bug.cgi?id=198838
637     auto variableName = generateNextVariableName();
638
639     auto mangledTypeName = m_typeNamer.mangledNameForType(makeArrayReferenceExpression.resolvedType());
640     if (is<AST::PointerType>(makeArrayReferenceExpression.leftValue().resolvedType())) {
641         auto ptrValue = takeLastValue();
642         m_stringBuilder.flexibleAppend(
643             mangledTypeName, ' ', variableName, ";\n",
644             "if (", ptrValue, ") ", variableName, " = { ", ptrValue, ", 1};\n",
645             "else ", variableName, " = { nullptr, 0 };\n"
646         );
647     } else if (is<AST::ArrayType>(makeArrayReferenceExpression.leftValue().resolvedType())) {
648         auto lValue = takeLastLeftValue();
649         auto& arrayType = downcast<AST::ArrayType>(makeArrayReferenceExpression.leftValue().resolvedType());
650         m_stringBuilder.flexibleAppend(mangledTypeName, ' ', variableName, " = { ", lValue, "->data(), ", arrayType.numElements(), " };\n");
651     } else {
652         auto lValue = takeLastLeftValue();
653         m_stringBuilder.flexibleAppend(mangledTypeName, ' ', variableName, " = { ", lValue, ", 1 };\n");
654     }
655     appendRightValue(makeArrayReferenceExpression, variableName);
656 }
657
658 void FunctionDefinitionWriter::visit(AST::MakePointerExpression& makePointerExpression)
659 {
660     checkErrorAndVisit(makePointerExpression.leftValue());
661     auto pointer = takeLastLeftValue();
662     auto variableName = generateNextVariableName();
663     m_stringBuilder.flexibleAppend(m_typeNamer.mangledNameForType(makePointerExpression.resolvedType()), ' ', variableName, " = ", pointer, ";\n");
664     appendRightValue(makePointerExpression, variableName);
665 }
666
667 void FunctionDefinitionWriter::visit(AST::ReadModifyWriteExpression&)
668 {
669     // This should be lowered already.
670     ASSERT_NOT_REACHED();
671 }
672
673 void FunctionDefinitionWriter::visit(AST::TernaryExpression& ternaryExpression)
674 {
675     checkErrorAndVisit(ternaryExpression.predicate());
676     auto check = takeLastValue();
677
678     auto variableName = generateNextVariableName();
679     m_stringBuilder.flexibleAppend(
680         m_typeNamer.mangledNameForType(ternaryExpression.resolvedType()), ' ', variableName, ";\n"
681         "if (", check, ") {\n"
682     );
683     checkErrorAndVisit(ternaryExpression.bodyExpression());
684     m_stringBuilder.flexibleAppend(
685         variableName, " = ", takeLastValue(), ";\n"
686         "} else {\n"
687     );
688     checkErrorAndVisit(ternaryExpression.elseExpression());
689     m_stringBuilder.flexibleAppend(
690         variableName, " = ", takeLastValue(), ";\n"
691         "}\n"
692     );
693     appendRightValue(ternaryExpression, variableName);
694 }
695
696 void FunctionDefinitionWriter::visit(AST::VariableReference& variableReference)
697 {
698     ASSERT(variableReference.variable());
699     auto iterator = m_variableMapping.find(variableReference.variable());
700     ASSERT(iterator != m_variableMapping.end());
701     auto pointerName = generateNextVariableName();
702     m_stringBuilder.flexibleAppend("thread ", m_typeNamer.mangledNameForType(variableReference.resolvedType()), "* ", pointerName, " = &", iterator->value, ";\n");
703     appendLeftValue(variableReference, iterator->value, pointerName);
704 }
705
706 String FunctionDefinitionWriter::constantExpressionString(AST::ConstantExpression& constantExpression)
707 {
708     return constantExpression.visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) -> String {
709         return makeString("", integerLiteral.value());
710     }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) -> String {
711         return makeString("", unsignedIntegerLiteral.value());
712     }, [&](AST::FloatLiteral& floatLiteral) -> String {
713         return makeString("", floatLiteral.value());
714     }, [&](AST::NullLiteral&) -> String {
715         return "nullptr"_str;
716     }, [&](AST::BooleanLiteral& booleanLiteral) -> String {
717         return booleanLiteral.value() ? "true"_str : "false"_str;
718     }, [&](AST::EnumerationMemberLiteral& enumerationMemberLiteral) -> String {
719         ASSERT(enumerationMemberLiteral.enumerationDefinition());
720         ASSERT(enumerationMemberLiteral.enumerationDefinition());
721         return makeString(m_typeNamer.mangledNameForType(*enumerationMemberLiteral.enumerationDefinition()), "::", m_typeNamer.mangledNameForEnumerationMember(*enumerationMemberLiteral.enumerationMember()));
722     }));
723 }
724
725 class RenderFunctionDefinitionWriter : public FunctionDefinitionWriter {
726 public:
727     RenderFunctionDefinitionWriter(Intrinsics& intrinsics, TypeNamer& typeNamer, HashMap<AST::FunctionDeclaration*, String>& functionMapping, MatchedRenderSemantics&& matchedSemantics, Layout& layout)
728         : FunctionDefinitionWriter(intrinsics, typeNamer, functionMapping, layout)
729         , m_matchedSemantics(WTFMove(matchedSemantics))
730     {
731     }
732
733 private:
734     std::unique_ptr<EntryPointScaffolding> createEntryPointScaffolding(AST::FunctionDefinition&) override;
735
736     MatchedRenderSemantics m_matchedSemantics;
737 };
738
739 std::unique_ptr<EntryPointScaffolding> RenderFunctionDefinitionWriter::createEntryPointScaffolding(AST::FunctionDefinition& functionDefinition)
740 {
741     auto generateNextVariableName = [this]() -> String {
742         return this->generateNextVariableName();
743     };
744     if (&functionDefinition == m_matchedSemantics.vertexShader)
745         return std::make_unique<VertexEntryPointScaffolding>(functionDefinition, m_intrinsics, m_typeNamer, m_matchedSemantics.vertexShaderEntryPointItems, m_matchedSemantics.vertexShaderResourceMap, m_layout, WTFMove(generateNextVariableName), m_matchedSemantics.matchedVertexAttributes);
746     if (&functionDefinition == m_matchedSemantics.fragmentShader)
747         return std::make_unique<FragmentEntryPointScaffolding>(functionDefinition, m_intrinsics, m_typeNamer, m_matchedSemantics.fragmentShaderEntryPointItems, m_matchedSemantics.fragmentShaderResourceMap, m_layout, WTFMove(generateNextVariableName), m_matchedSemantics.matchedColorAttachments);
748     return nullptr;
749 }
750
751 class ComputeFunctionDefinitionWriter : public FunctionDefinitionWriter {
752 public:
753     ComputeFunctionDefinitionWriter(Intrinsics& intrinsics, TypeNamer& typeNamer, HashMap<AST::FunctionDeclaration*, String>& functionMapping, MatchedComputeSemantics&& matchedSemantics, Layout& layout)
754         : FunctionDefinitionWriter(intrinsics, typeNamer, functionMapping, layout)
755         , m_matchedSemantics(WTFMove(matchedSemantics))
756     {
757     }
758
759 private:
760     std::unique_ptr<EntryPointScaffolding> createEntryPointScaffolding(AST::FunctionDefinition&) override;
761
762     MatchedComputeSemantics m_matchedSemantics;
763 };
764
765 std::unique_ptr<EntryPointScaffolding> ComputeFunctionDefinitionWriter::createEntryPointScaffolding(AST::FunctionDefinition& functionDefinition)
766 {
767     auto generateNextVariableName = [this]() -> String {
768         return this->generateNextVariableName();
769     };
770     if (&functionDefinition == m_matchedSemantics.shader)
771         return std::make_unique<ComputeEntryPointScaffolding>(functionDefinition, m_intrinsics, m_typeNamer, m_matchedSemantics.entryPointItems, m_matchedSemantics.resourceMap, m_layout, WTFMove(generateNextVariableName));
772     return nullptr;
773 }
774
// Result of sharedMetalFunctions(): the generated Metal name for every native and
// user-defined function declaration, plus forward-declaration source text for the
// reachable functions.
struct SharedMetalFunctionsResult {
    // Declaration -> generated name of the form "function<N>".
    HashMap<AST::FunctionDeclaration*, String> functionMapping;
    // Forward declarations of reachable non-entry-point functions, followed by '\n'.
    String metalFunctions;
};
779 static SharedMetalFunctionsResult sharedMetalFunctions(Program& program, TypeNamer& typeNamer, const HashSet<AST::FunctionDeclaration*>& reachableFunctions)
780 {
781     StringBuilder stringBuilder;
782
783     unsigned numFunctions = 0;
784     HashMap<AST::FunctionDeclaration*, String> functionMapping;
785     for (auto& nativeFunctionDeclaration : program.nativeFunctionDeclarations()) {
786         auto addResult = functionMapping.add(&nativeFunctionDeclaration, makeString("function", numFunctions++));
787         ASSERT_UNUSED(addResult, addResult.isNewEntry);
788     }
789     for (auto& functionDefinition : program.functionDefinitions()) {
790         auto addResult = functionMapping.add(&functionDefinition, makeString("function", numFunctions++));
791         ASSERT_UNUSED(addResult, addResult.isNewEntry);
792     }
793
794     {
795         FunctionDeclarationWriter functionDeclarationWriter(typeNamer, functionMapping);
796         for (auto& nativeFunctionDeclaration : program.nativeFunctionDeclarations()) {
797             if (reachableFunctions.contains(&nativeFunctionDeclaration))
798                 functionDeclarationWriter.visit(nativeFunctionDeclaration);
799         }
800         for (auto& functionDefinition : program.functionDefinitions()) {
801             if (!functionDefinition->entryPointType() && reachableFunctions.contains(&functionDefinition))
802                 functionDeclarationWriter.visit(functionDefinition);
803         }
804         stringBuilder.append(functionDeclarationWriter.toString());
805     }
806
807     stringBuilder.append('\n');
808     return { WTFMove(functionMapping), stringBuilder.toString() };
809 }
810
811 class ReachableFunctionsGatherer : public Visitor {
812 public:
813     void visit(AST::FunctionDeclaration& functionDeclaration) override
814     {
815         Visitor::visit(functionDeclaration);
816         m_reachableFunctions.add(&functionDeclaration);
817     }
818
819     void visit(AST::CallExpression& callExpression) override
820     {
821         Visitor::visit(callExpression);
822         if (is<AST::FunctionDefinition>(callExpression.function()))
823             checkErrorAndVisit(downcast<AST::FunctionDefinition>(callExpression.function()));
824         else
825             checkErrorAndVisit(downcast<AST::NativeFunctionDeclaration>(callExpression.function()));
826     }
827
828     HashSet<AST::FunctionDeclaration*> takeReachableFunctions() { return WTFMove(m_reachableFunctions); }
829
830 private:
831     HashSet<AST::FunctionDeclaration*> m_reachableFunctions;
832 };
833
834 RenderMetalFunctions metalFunctions(Program& program, TypeNamer& typeNamer, MatchedRenderSemantics&& matchedSemantics, Layout& layout)
835 {
836     auto& vertexShaderEntryPoint = *matchedSemantics.vertexShader;
837     auto& fragmentShaderEntryPoint = *matchedSemantics.fragmentShader;
838
839     ReachableFunctionsGatherer reachableFunctionsGatherer;
840     reachableFunctionsGatherer.Visitor::visit(vertexShaderEntryPoint);
841     reachableFunctionsGatherer.Visitor::visit(fragmentShaderEntryPoint);
842     auto reachableFunctions = reachableFunctionsGatherer.takeReachableFunctions();
843
844     auto sharedMetalFunctions = Metal::sharedMetalFunctions(program, typeNamer, reachableFunctions);
845
846     StringBuilder stringBuilder;
847     stringBuilder.append(sharedMetalFunctions.metalFunctions);
848
849     RenderFunctionDefinitionWriter functionDefinitionWriter(program.intrinsics(), typeNamer, sharedMetalFunctions.functionMapping, WTFMove(matchedSemantics), layout);
850     for (auto& nativeFunctionDeclaration : program.nativeFunctionDeclarations()) {
851         if (reachableFunctions.contains(&nativeFunctionDeclaration))
852             functionDefinitionWriter.visit(nativeFunctionDeclaration);
853     }
854     for (auto& functionDefinition : program.functionDefinitions()) {
855         if (reachableFunctions.contains(&functionDefinition))
856             functionDefinitionWriter.visit(functionDefinition);
857     }
858     stringBuilder.append(functionDefinitionWriter.toString());
859
860     RenderMetalFunctions result;
861     result.metalSource = stringBuilder.toString();
862     result.mangledVertexEntryPointName = sharedMetalFunctions.functionMapping.get(&vertexShaderEntryPoint);
863     result.mangledFragmentEntryPointName = sharedMetalFunctions.functionMapping.get(&fragmentShaderEntryPoint);
864     return result;
865 }
866
867 ComputeMetalFunctions metalFunctions(Program& program, TypeNamer& typeNamer, MatchedComputeSemantics&& matchedSemantics, Layout& layout)
868 {
869     auto& entryPoint = *matchedSemantics.shader;
870
871     ReachableFunctionsGatherer reachableFunctionsGatherer;
872     reachableFunctionsGatherer.Visitor::visit(entryPoint);
873     auto reachableFunctions = reachableFunctionsGatherer.takeReachableFunctions();
874
875     auto sharedMetalFunctions = Metal::sharedMetalFunctions(program, typeNamer, reachableFunctions);
876
877     StringBuilder stringBuilder;
878     stringBuilder.append(sharedMetalFunctions.metalFunctions);
879
880     ComputeFunctionDefinitionWriter functionDefinitionWriter(program.intrinsics(), typeNamer, sharedMetalFunctions.functionMapping, WTFMove(matchedSemantics), layout);
881     for (auto& nativeFunctionDeclaration : program.nativeFunctionDeclarations()) {
882         if (reachableFunctions.contains(&nativeFunctionDeclaration))
883             functionDefinitionWriter.visit(nativeFunctionDeclaration);
884     }
885     for (auto& functionDefinition : program.functionDefinitions()) {
886         if (reachableFunctions.contains(&functionDefinition))
887             functionDefinitionWriter.visit(functionDefinition);
888     }
889     stringBuilder.append(functionDefinitionWriter.toString());
890
891     ComputeMetalFunctions result;
892     result.metalSource = stringBuilder.toString();
893     result.mangledEntryPointName = sharedMetalFunctions.functionMapping.get(&entryPoint);
894     return result;
895 }
896
897 } // namespace Metal
898
899 } // namespace WHLSL
900
901 } // namespace WebCore
902
903 #endif