93f2bcd36867fa44ba5669063f680612ecde2b02
[WebKit-https.git] / Source / WebCore / Modules / webgpu / WHLSL / Metal / WHLSLFunctionWriter.cpp
1 /*
2  * Copyright (C) 2019 Apple Inc. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  * 1. Redistributions of source code must retain the above copyright
8  *    notice, this list of conditions and the following disclaimer.
9  * 2. Redistributions in binary form must reproduce the above copyright
10  *    notice, this list of conditions and the following disclaimer in the
11  *    documentation and/or other materials provided with the distribution.
12  *
13  * THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS ``AS IS''
14  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
15  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS
17  * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23  * THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 #include "config.h"
27 #include "WHLSLFunctionWriter.h"
28
29 #if ENABLE(WEBGPU)
30
31 #include "NotImplemented.h"
32 #include "WHLSLAST.h"
33 #include "WHLSLEntryPointScaffolding.h"
34 #include "WHLSLInferTypes.h"
35 #include "WHLSLNativeFunctionWriter.h"
36 #include "WHLSLProgram.h"
37 #include "WHLSLTypeNamer.h"
38 #include "WHLSLVisitor.h"
39 #include <wtf/HashMap.h>
40 #include <wtf/HashSet.h>
41 #include <wtf/SetForScope.h>
42 #include <wtf/text/StringBuilder.h>
43
44 namespace WebCore {
45
46 namespace WHLSL {
47
48 namespace Metal {
49
// Emits a Metal forward declaration ("ReturnType name(ParamType, ...);") for
// every non-entry-point function it visits, using the TypeNamer's mangled type
// names and the caller-supplied function-name mapping.
class FunctionDeclarationWriter : public Visitor {
public:
    FunctionDeclarationWriter(TypeNamer& typeNamer, HashMap<AST::FunctionDeclaration*, String>& functionMapping)
        : m_typeNamer(typeNamer)
        , m_functionMapping(functionMapping)
    {
    }

    virtual ~FunctionDeclarationWriter() = default;

    // Returns all declarations emitted so far.
    String toString() { return m_stringBuilder.toString(); }

    void visit(AST::FunctionDeclaration&) override;

private:
    TypeNamer& m_typeNamer;
    // Held by reference (not owned); maps each function to the name it is emitted under.
    HashMap<AST::FunctionDeclaration*, String>& m_functionMapping;
    StringBuilder m_stringBuilder;
};
69
70 void FunctionDeclarationWriter::visit(AST::FunctionDeclaration& functionDeclaration)
71 {
72     if (functionDeclaration.entryPointType())
73         return;
74
75     auto iterator = m_functionMapping.find(&functionDeclaration);
76     ASSERT(iterator != m_functionMapping.end());
77     m_stringBuilder.append(makeString(m_typeNamer.mangledNameForType(functionDeclaration.type()), ' ', iterator->value, '('));
78     for (size_t i = 0; i < functionDeclaration.parameters().size(); ++i) {
79         if (i)
80             m_stringBuilder.append(", ");
81         m_stringBuilder.append(m_typeNamer.mangledNameForType(*functionDeclaration.parameters()[i]->type()));
82     }
83     m_stringBuilder.append(");\n");
84 }
85
// Emits Metal definitions for WHLSL functions.
//
// Expressions are lowered to straight-line statements: each expression visit
// declares fresh temporaries (named "variableN" by generateNextVariableName)
// and pushes one StackItem onto m_stack describing its result, which the
// enclosing visit pops. Subclasses choose how entry points are wrapped by
// implementing createEntryPointScaffolding.
class FunctionDefinitionWriter : public Visitor {
public:
    FunctionDefinitionWriter(Intrinsics& intrinsics, TypeNamer& typeNamer, HashMap<AST::FunctionDeclaration*, String>& functionMapping, Layout& layout)
        : m_intrinsics(intrinsics)
        , m_typeNamer(typeNamer)
        , m_functionMapping(functionMapping)
        , m_layout(layout)
    {
        // The output always starts with a byte-wise zeroing helper. It is used
        // to produce the value of a dereferenced null pointer and is also handed
        // to writeNativeFunction.
        m_stringBuilder.append(makeString(
            "template <typename T>\n"
            "inline void ", memsetZeroFunctionName, "(thread T& value)\n"
            "{\n"
            "    thread char* ptr = static_cast<thread char*>(static_cast<thread void*>(&value));\n"
            "    for (size_t i = 0; i < sizeof(T); ++i)\n"
            "        ptr[i] = 0;\n"
            "}\n"));
    }

    // Name of the zeroing helper emitted by the constructor.
    static constexpr const char* memsetZeroFunctionName = "memsetZero";

    virtual ~FunctionDefinitionWriter() = default;

    // Returns everything emitted so far.
    String toString() { return m_stringBuilder.toString(); }

    void visit(AST::NativeFunctionDeclaration&) override;
    void visit(AST::FunctionDefinition&) override;

protected:
    // Returns the wrapper for an entry point; returning nullptr causes that
    // entry point to be skipped (nothing is emitted for it).
    virtual std::unique_ptr<EntryPointScaffolding> createEntryPointScaffolding(AST::FunctionDefinition&) = 0;

    void visit(AST::FunctionDeclaration&) override;
    void visit(AST::Statement&) override;
    void visit(AST::Block&) override;
    void visit(AST::Break&) override;
    void visit(AST::Continue&) override;
    void visit(AST::DoWhileLoop&) override;
    void visit(AST::EffectfulExpressionStatement&) override;
    void visit(AST::Fallthrough&) override;
    void visit(AST::ForLoop&) override;
    void visit(AST::IfStatement&) override;
    void visit(AST::Return&) override;
    void visit(AST::SwitchStatement&) override;
    void visit(AST::SwitchCase&) override;
    void visit(AST::VariableDeclarationsStatement&) override;
    void visit(AST::WhileLoop&) override;
    void visit(AST::IntegerLiteral&) override;
    void visit(AST::UnsignedIntegerLiteral&) override;
    void visit(AST::FloatLiteral&) override;
    void visit(AST::NullLiteral&) override;
    void visit(AST::BooleanLiteral&) override;
    void visit(AST::EnumerationMemberLiteral&) override;
    void visit(AST::Expression&) override;
    void visit(AST::DotExpression&) override;
    void visit(AST::GlobalVariableReference&) override;
    void visit(AST::IndexExpression&) override;
    void visit(AST::PropertyAccessExpression&) override;
    void visit(AST::VariableDeclaration&) override;
    void visit(AST::AssignmentExpression&) override;
    void visit(AST::CallExpression&) override;
    void visit(AST::CommaExpression&) override;
    void visit(AST::DereferenceExpression&) override;
    void visit(AST::LogicalExpression&) override;
    void visit(AST::LogicalNotExpression&) override;
    void visit(AST::MakeArrayReferenceExpression&) override;
    void visit(AST::MakePointerExpression&) override;
    void visit(AST::ReadModifyWriteExpression&) override;
    void visit(AST::TernaryExpression&) override;
    void visit(AST::VariableReference&) override;

    // Where the loop condition is tested relative to the body: while/for test
    // before the body, do-while tests after it.
    enum class LoopConditionLocation {
        BeforeBody,
        AfterBody
    };
    void emitLoop(LoopConditionLocation, AST::Expression* conditionExpression, AST::Expression* increment, AST::Statement& body);

    // Renders a switch-case constant as Metal source text.
    String constantExpressionString(AST::ConstantExpression&);

    // Returns a fresh, unique temporary name ("variable0", "variable1", ...).
    String generateNextVariableName()
    {
        return makeString("variable", m_variableCount++);
    }

    // Result of visiting one expression: `value` names a variable holding the
    // expression's value; `leftValue`, when non-null, names a pointer to the
    // storage the value came from (set only via appendLeftValue).
    struct StackItem {
        String value;
        String leftValue;
    };

    // Pushes a result that cannot be assigned through.
    void appendRightValue(AST::Expression&, String value)
    {
        m_stack.append({ WTFMove(value), String() });
    }

    // Pushes a result together with a pointer to its storage; the expression
    // must have a left address space.
    void appendLeftValue(AST::Expression& expression, String value, String leftValue)
    {
        ASSERT_UNUSED(expression, expression.typeAnnotation().leftAddressSpace());
        m_stack.append({ WTFMove(value), WTFMove(leftValue) });
    }

    // Pops the most recent result and returns its value name.
    String takeLastValue()
    {
        ASSERT(m_stack.last().value);
        return m_stack.takeLast().value;
    }

    // Pops the most recent result and returns its pointer name.
    String takeLastLeftValue()
    {
        ASSERT(m_stack.last().leftValue);
        return m_stack.takeLast().leftValue;
    }

    // What a WHLSL `break` currently exits: an emitted loop (which needs the
    // early-exit flag) or a switch.
    enum class BreakContext {
        Loop,
        Switch
    };

    Optional<BreakContext> m_currentBreakContext;

    Intrinsics& m_intrinsics;
    TypeNamer& m_typeNamer;
    // Held by reference (not owned); maps each function to its emitted name.
    HashMap<AST::FunctionDeclaration*, String>& m_functionMapping;
    // Generated Metal name for each WHLSL variable/parameter declaration.
    HashMap<AST::VariableDeclaration*, String> m_variableMapping;
    StringBuilder m_stringBuilder;

    Vector<StackItem> m_stack;
    // Non-null only while emitting an entry point; Return uses it to pack the result.
    std::unique_ptr<EntryPointScaffolding> m_entryPointScaffolding;
    // Not used directly here; forwarded to entry-point scaffolding by subclasses.
    Layout& m_layout;
    unsigned m_variableCount { 0 };
    // Name of the bool flag emitLoop declares for the innermost loop; a
    // loop-context `break` sets it so the whole loop is left.
    String m_breakOutOfCurrentLoopEarlyVariable;
};
215
216 void FunctionDefinitionWriter::visit(AST::NativeFunctionDeclaration& nativeFunctionDeclaration)
217 {
218     auto iterator = m_functionMapping.find(&nativeFunctionDeclaration);
219     ASSERT(iterator != m_functionMapping.end());
220     m_stringBuilder.append(writeNativeFunction(nativeFunctionDeclaration, iterator->value, m_intrinsics, m_typeNamer, memsetZeroFunctionName));
221 }
222
223 void FunctionDefinitionWriter::visit(AST::FunctionDefinition& functionDefinition)
224 {
225     auto iterator = m_functionMapping.find(&functionDefinition);
226     ASSERT(iterator != m_functionMapping.end());
227     if (functionDefinition.entryPointType()) {
228         auto entryPointScaffolding = createEntryPointScaffolding(functionDefinition);
229         if (!entryPointScaffolding)
230             return;
231         m_entryPointScaffolding = WTFMove(entryPointScaffolding);
232         m_stringBuilder.append(m_entryPointScaffolding->helperTypes());
233         m_stringBuilder.append('\n');
234         m_stringBuilder.append(makeString(m_entryPointScaffolding->signature(iterator->value), " {\n"));
235         m_stringBuilder.append(m_entryPointScaffolding->unpack());
236         for (size_t i = 0; i < functionDefinition.parameters().size(); ++i) {
237             auto addResult = m_variableMapping.add(&functionDefinition.parameters()[i], m_entryPointScaffolding->parameterVariables()[i]);
238             ASSERT_UNUSED(addResult, addResult.isNewEntry);
239         }
240         checkErrorAndVisit(functionDefinition.block());
241         ASSERT(m_stack.isEmpty());
242         m_stringBuilder.append("}\n");
243         m_entryPointScaffolding = nullptr;
244     } else {
245         ASSERT(m_entryPointScaffolding == nullptr);
246         m_stringBuilder.append(makeString(m_typeNamer.mangledNameForType(functionDefinition.type()), ' ', iterator->value, '('));
247         for (size_t i = 0; i < functionDefinition.parameters().size(); ++i) {
248             auto& parameter = functionDefinition.parameters()[i];
249             if (i)
250                 m_stringBuilder.append(", ");
251             auto parameterName = generateNextVariableName();
252             auto addResult = m_variableMapping.add(&parameter, parameterName);
253             ASSERT_UNUSED(addResult, addResult.isNewEntry);
254             m_stringBuilder.append(makeString(m_typeNamer.mangledNameForType(*parameter->type()), ' ', parameterName));
255         }
256         m_stringBuilder.append(") {\n");
257         checkErrorAndVisit(functionDefinition.block());
258         ASSERT(m_stack.isEmpty());
259         m_stringBuilder.append("}\n");
260     }
261 }
262
void FunctionDefinitionWriter::visit(AST::FunctionDeclaration&)
{
    // This writer handles functions through its NativeFunctionDeclaration and
    // FunctionDefinition overloads; this generic overload is not expected to be reached.
    ASSERT_NOT_REACHED();
}
267
void FunctionDefinitionWriter::visit(AST::Statement& statement)
{
    // Forward to the base Visitor (presumably dispatches to the overload for the
    // concrete statement kind).
    Visitor::visit(statement);
}
272
273 void FunctionDefinitionWriter::visit(AST::Block& block)
274 {
275     m_stringBuilder.append("{\n");
276     for (auto& statement : block.statements())
277         checkErrorAndVisit(statement);
278     m_stringBuilder.append("}\n");
279 }
280
281 void FunctionDefinitionWriter::visit(AST::Break&)
282 {
283     ASSERT(m_currentBreakContext);
284     switch (*m_currentBreakContext) {
285     case BreakContext::Switch:
286         m_stringBuilder.append("break;\n");
287         break;
288     case BreakContext::Loop:
289         ASSERT(m_breakOutOfCurrentLoopEarlyVariable.length());
290         m_stringBuilder.append(makeString(m_breakOutOfCurrentLoopEarlyVariable, " = true;\n"));
291         m_stringBuilder.append("break;\n");
292         break;
293     }
294 }
295
296 void FunctionDefinitionWriter::visit(AST::Continue&)
297 {
298     ASSERT(m_breakOutOfCurrentLoopEarlyVariable.length());
299     m_stringBuilder.append("break;\n");
300 }
301
void FunctionDefinitionWriter::visit(AST::EffectfulExpressionStatement& effectfulExpressionStatement)
{
    // Emit the expression's code, then pop the one stack entry it pushed so
    // m_stack stays balanced.
    checkErrorAndVisit(effectfulExpressionStatement.effectfulExpression());
    takeLastValue(); // The statement is already effectful, so we don't need to do anything with the result.
}
307
void FunctionDefinitionWriter::visit(AST::Fallthrough&)
{
    // NOTE(review): case bodies are emitted as braced blocks, so this attribute
    // ends up before the block's closing brace rather than directly before the
    // next case label -- confirm the Metal compiler accepts that placement.
    m_stringBuilder.append("[[clang::fallthrough]];\n"); // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195808 Make sure this is okay. Alternatively, we could do nothing and just return here instead.
}
312
313 void FunctionDefinitionWriter::emitLoop(LoopConditionLocation loopConditionLocation, AST::Expression* conditionExpression, AST::Expression* increment, AST::Statement& body)
314 {
315     SetForScope<String> loopVariableScope(m_breakOutOfCurrentLoopEarlyVariable, generateNextVariableName());
316
317     m_stringBuilder.append(makeString("bool ", m_breakOutOfCurrentLoopEarlyVariable, " = false;\n"));
318
319     m_stringBuilder.append("while (true) {\n");
320
321     if (loopConditionLocation == LoopConditionLocation::BeforeBody && conditionExpression) {
322         checkErrorAndVisit(*conditionExpression);
323         m_stringBuilder.append(makeString("if (!", takeLastValue(), ") break;\n"));
324     }
325
326     m_stringBuilder.append("do {\n");
327     SetForScope<Optional<BreakContext>> breakContext(m_currentBreakContext, BreakContext::Loop);
328     checkErrorAndVisit(body);
329     m_stringBuilder.append("} while(false); \n");
330     m_stringBuilder.append(makeString("if (", m_breakOutOfCurrentLoopEarlyVariable, ") break;\n"));
331
332     if (increment) {
333         checkErrorAndVisit(*increment);
334         // Expression results get pushed to m_stack. We don't use the result
335         // of increment, so we dispense of that now.
336         takeLastValue();
337     }
338
339     if (loopConditionLocation == LoopConditionLocation::AfterBody && conditionExpression) {
340         checkErrorAndVisit(*conditionExpression);
341         m_stringBuilder.append(makeString("if (!", takeLastValue(), ") break;\n"));
342     }
343
344     m_stringBuilder.append("} \n");
345 }
346
void FunctionDefinitionWriter::visit(AST::DoWhileLoop& doWhileLoop)
{
    // do-while: no increment; the condition is tested after each run of the body.
    emitLoop(LoopConditionLocation::AfterBody, &doWhileLoop.conditional(), nullptr, doWhileLoop.body());
}
351
void FunctionDefinitionWriter::visit(AST::WhileLoop& whileLoop)
{
    // while: no increment; the condition is tested before each run of the body.
    emitLoop(LoopConditionLocation::BeforeBody, &whileLoop.conditional(), nullptr, whileLoop.body());
}
356
void FunctionDefinitionWriter::visit(AST::ForLoop& forLoop)
{
    // Open an extra scope so anything the initializer declares is visible to the
    // loop but not after it.
    m_stringBuilder.append("{\n");

    // The initializer is either a statement or an expression; an expression's
    // pushed result is discarded.
    WTF::visit(WTF::makeVisitor([&](AST::Statement& statement) {
        checkErrorAndVisit(statement);
    }, [&](UniqueRef<AST::Expression>& expression) {
        checkErrorAndVisit(expression);
        takeLastValue(); // We don't need to do anything with the result.
    }), forLoop.initialization());

    // condition() and increment() are passed as pointers; emitLoop skips whichever is null.
    emitLoop(LoopConditionLocation::BeforeBody, forLoop.condition(), forLoop.increment(), forLoop.body());
    m_stringBuilder.append("}\n");
}
371
372 void FunctionDefinitionWriter::visit(AST::IfStatement& ifStatement)
373 {
374     checkErrorAndVisit(ifStatement.conditional());
375     m_stringBuilder.append(makeString("if (", takeLastValue(), ") {\n"));
376     checkErrorAndVisit(ifStatement.body());
377     if (ifStatement.elseBody()) {
378         m_stringBuilder.append("} else {\n");
379         checkErrorAndVisit(*ifStatement.elseBody());
380     }
381     m_stringBuilder.append("}\n");
382 }
383
384 void FunctionDefinitionWriter::visit(AST::Return& returnStatement)
385 {
386     if (returnStatement.value()) {
387         checkErrorAndVisit(*returnStatement.value());
388         if (m_entryPointScaffolding) {
389             auto variableName = generateNextVariableName();
390             m_stringBuilder.append(m_entryPointScaffolding->pack(takeLastValue(), variableName));
391             m_stringBuilder.append(makeString("return ", variableName, ";\n"));
392         } else
393             m_stringBuilder.append(makeString("return ", takeLastValue(), ";\n"));
394     } else
395         m_stringBuilder.append("return;\n");
396 }
397
398 void FunctionDefinitionWriter::visit(AST::SwitchStatement& switchStatement)
399 {
400     checkErrorAndVisit(switchStatement.value());
401
402     m_stringBuilder.append(makeString("switch (", takeLastValue(), ") {"));
403     for (auto& switchCase : switchStatement.switchCases())
404         checkErrorAndVisit(switchCase);
405     m_stringBuilder.append("}\n");
406 }
407
408 void FunctionDefinitionWriter::visit(AST::SwitchCase& switchCase)
409 {
410     if (switchCase.value())
411         m_stringBuilder.append(makeString("case ", constantExpressionString(*switchCase.value()), ":\n"));
412     else
413         m_stringBuilder.append("default:\n");
414     SetForScope<Optional<BreakContext>> breakContext(m_currentBreakContext, BreakContext::Switch);
415     checkErrorAndVisit(switchCase.block());
416     // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195812 Figure out whether we need to break or fallthrough.
417 }
418
void FunctionDefinitionWriter::visit(AST::VariableDeclarationsStatement& variableDeclarationsStatement)
{
    // Defer to the base Visitor, which presumably visits each contained
    // VariableDeclaration (handled by visit(AST::VariableDeclaration&)).
    Visitor::visit(variableDeclarationsStatement);
}
423
424 void FunctionDefinitionWriter::visit(AST::IntegerLiteral& integerLiteral)
425 {
426     auto variableName = generateNextVariableName();
427     auto mangledTypeName = m_typeNamer.mangledNameForType(integerLiteral.resolvedType());
428     m_stringBuilder.append(makeString(mangledTypeName, ' ', variableName, " = static_cast<", mangledTypeName, ">(", integerLiteral.value(), ");\n"));
429     appendRightValue(integerLiteral, variableName);
430 }
431
432 void FunctionDefinitionWriter::visit(AST::UnsignedIntegerLiteral& unsignedIntegerLiteral)
433 {
434     auto variableName = generateNextVariableName();
435     auto mangledTypeName = m_typeNamer.mangledNameForType(unsignedIntegerLiteral.resolvedType());
436     m_stringBuilder.append(makeString(mangledTypeName, ' ', variableName, " = static_cast<", mangledTypeName, ">(", unsignedIntegerLiteral.value(), ");\n"));
437     appendRightValue(unsignedIntegerLiteral, variableName);
438 }
439
440 void FunctionDefinitionWriter::visit(AST::FloatLiteral& floatLiteral)
441 {
442     auto variableName = generateNextVariableName();
443     auto mangledTypeName = m_typeNamer.mangledNameForType(floatLiteral.resolvedType());
444     m_stringBuilder.append(makeString(mangledTypeName, ' ', variableName, " = static_cast<", mangledTypeName, ">(", floatLiteral.value(), ");\n"));
445     appendRightValue(floatLiteral, variableName);
446 }
447
448 void FunctionDefinitionWriter::visit(AST::NullLiteral& nullLiteral)
449 {
450     auto& unifyNode = nullLiteral.resolvedType().unifyNode();
451     auto& unnamedType = downcast<AST::UnnamedType>(unifyNode);
452     bool isArrayReferenceType = is<AST::ArrayReferenceType>(unnamedType);
453
454     auto variableName = generateNextVariableName();
455     m_stringBuilder.append(makeString(m_typeNamer.mangledNameForType(nullLiteral.resolvedType()), ' ', variableName, " = "));
456     if (isArrayReferenceType)
457         m_stringBuilder.append("{ nullptr, 0 }");
458     else
459         m_stringBuilder.append("nullptr");
460     m_stringBuilder.append(";\n");
461     appendRightValue(nullLiteral, variableName);
462 }
463
464 void FunctionDefinitionWriter::visit(AST::BooleanLiteral& booleanLiteral)
465 {
466     auto variableName = generateNextVariableName();
467     auto mangledTypeName = m_typeNamer.mangledNameForType(booleanLiteral.resolvedType());
468     m_stringBuilder.append(makeString(mangledTypeName, ' ', variableName, " = static_cast<", mangledTypeName, ">(", booleanLiteral.value() ? "true" : "false", ");\n"));
469     appendRightValue(booleanLiteral, variableName);
470 }
471
472 void FunctionDefinitionWriter::visit(AST::EnumerationMemberLiteral& enumerationMemberLiteral)
473 {
474     ASSERT(enumerationMemberLiteral.enumerationDefinition());
475     ASSERT(enumerationMemberLiteral.enumerationDefinition());
476     auto variableName = generateNextVariableName();
477     auto mangledTypeName = m_typeNamer.mangledNameForType(enumerationMemberLiteral.resolvedType());
478     m_stringBuilder.append(makeString(mangledTypeName, ' ', variableName, " = ", mangledTypeName, "::", m_typeNamer.mangledNameForEnumerationMember(*enumerationMemberLiteral.enumerationMember()), ";\n"));
479     appendRightValue(enumerationMemberLiteral, variableName);
480 }
481
void FunctionDefinitionWriter::visit(AST::Expression& expression)
{
    // Forward to the base Visitor (presumably dispatches to the overload for the
    // concrete expression kind).
    Visitor::visit(expression);
}
486
void FunctionDefinitionWriter::visit(AST::DotExpression& dotExpression)
{
    // This should be lowered already.
    // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195788 Replace this with ASSERT_NOT_REACHED().
    notImplemented();
    // Push a placeholder so every expression still leaves one stack entry.
    appendRightValue(dotExpression, "dummy");
}
494
495 void FunctionDefinitionWriter::visit(AST::GlobalVariableReference& globalVariableReference)
496 {
497     auto valueName = generateNextVariableName();
498     auto pointerName = generateNextVariableName();
499     auto mangledTypeName = m_typeNamer.mangledNameForType(globalVariableReference.resolvedType());
500     checkErrorAndVisit(globalVariableReference.base());
501     m_stringBuilder.append(makeString("thread ", mangledTypeName, "* ", pointerName, " = &", takeLastValue(), "->", m_typeNamer.mangledNameForStructureElement(globalVariableReference.structField()), ";\n"));
502     m_stringBuilder.append(makeString(mangledTypeName, ' ', valueName, " = ", "*", pointerName, ";\n"));
503     appendLeftValue(globalVariableReference, valueName, pointerName);
504 }
505
void FunctionDefinitionWriter::visit(AST::IndexExpression& indexExpression)
{
    // This should be lowered already.
    // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195788 Replace this with ASSERT_NOT_REACHED().
    notImplemented();
    // Push a placeholder so every expression still leaves one stack entry.
    appendRightValue(indexExpression, "dummy");
}
513
void FunctionDefinitionWriter::visit(AST::PropertyAccessExpression& propertyAccessExpression)
{
    // This should be lowered already.
    // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195788 Replace this with ASSERT_NOT_REACHED().
    notImplemented();
    // Push a placeholder so every expression still leaves one stack entry.
    appendRightValue(propertyAccessExpression, "dummy");
}
521
522 void FunctionDefinitionWriter::visit(AST::VariableDeclaration& variableDeclaration)
523 {
524     ASSERT(variableDeclaration.type());
525     auto variableName = generateNextVariableName();
526     auto addResult = m_variableMapping.add(&variableDeclaration, variableName);
527     ASSERT_UNUSED(addResult, addResult.isNewEntry);
528     // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198160 Implement qualifiers.
529     if (variableDeclaration.initializer()) {
530         checkErrorAndVisit(*variableDeclaration.initializer());
531         m_stringBuilder.append(makeString(m_typeNamer.mangledNameForType(*variableDeclaration.type()), ' ', variableName, " = ", takeLastValue(), ";\n"));
532     } else
533         m_stringBuilder.append(makeString(m_typeNamer.mangledNameForType(*variableDeclaration.type()), ' ', variableName, ";\n"));
534 }
535
536 void FunctionDefinitionWriter::visit(AST::AssignmentExpression& assignmentExpression)
537 {
538     checkErrorAndVisit(assignmentExpression.left());
539     auto pointerName = takeLastLeftValue();
540     checkErrorAndVisit(assignmentExpression.right());
541     auto rightName = takeLastValue();
542     m_stringBuilder.append(makeString("if (", pointerName, ") *", pointerName, " = ", rightName, ";\n"));
543     appendRightValue(assignmentExpression, rightName);
544 }
545
546 void FunctionDefinitionWriter::visit(AST::CallExpression& callExpression)
547 {
548     Vector<String> argumentNames;
549     for (auto& argument : callExpression.arguments()) {
550         checkErrorAndVisit(argument);
551         argumentNames.append(takeLastValue());
552     }
553     auto iterator = m_functionMapping.find(&callExpression.function());
554     ASSERT(iterator != m_functionMapping.end());
555     auto variableName = generateNextVariableName();
556     if (!matches(callExpression.resolvedType(), m_intrinsics.voidType()))
557         m_stringBuilder.append(makeString(m_typeNamer.mangledNameForType(callExpression.resolvedType()), ' ', variableName, " = "));
558     m_stringBuilder.append(makeString(iterator->value, '('));
559     for (size_t i = 0; i < argumentNames.size(); ++i) {
560         if (i)
561             m_stringBuilder.append(", ");
562         m_stringBuilder.append(argumentNames[i]);
563     }
564     m_stringBuilder.append(");\n");
565     appendRightValue(callExpression, variableName);
566 }
567
568 void FunctionDefinitionWriter::visit(AST::CommaExpression& commaExpression)
569 {
570     String result;
571     for (auto& expression : commaExpression.list()) {
572         checkErrorAndVisit(expression);
573         result = takeLastValue();
574     }
575     appendRightValue(commaExpression, result);
576 }
577
578 void FunctionDefinitionWriter::visit(AST::DereferenceExpression& dereferenceExpression)
579 {
580     checkErrorAndVisit(dereferenceExpression.pointer());
581     auto right = takeLastValue();
582     auto variableName = generateNextVariableName();
583     auto pointerName = generateNextVariableName();
584     m_stringBuilder.append(makeString(m_typeNamer.mangledNameForType(dereferenceExpression.pointer().resolvedType()), ' ', pointerName, " = ", right, ";\n"));
585     m_stringBuilder.append(makeString(m_typeNamer.mangledNameForType(dereferenceExpression.resolvedType()), ' ', variableName, ";\n"));
586     m_stringBuilder.append(makeString("if (", pointerName, ") ", variableName, " = *", right, ";\n"));
587     m_stringBuilder.append(makeString("else ", memsetZeroFunctionName, '(', variableName, ");\n"));
588     appendLeftValue(dereferenceExpression, variableName, pointerName);
589 }
590
591 void FunctionDefinitionWriter::visit(AST::LogicalExpression& logicalExpression)
592 {
593     checkErrorAndVisit(logicalExpression.left());
594     auto left = takeLastValue();
595     checkErrorAndVisit(logicalExpression.right());
596     auto right = takeLastValue();
597     auto variableName = generateNextVariableName();
598     m_stringBuilder.append(makeString(m_typeNamer.mangledNameForType(logicalExpression.resolvedType()), ' ', variableName, " = ", left));
599     switch (logicalExpression.type()) {
600     case AST::LogicalExpression::Type::And:
601         m_stringBuilder.append(" && ");
602         break;
603     default:
604         ASSERT(logicalExpression.type() == AST::LogicalExpression::Type::Or);
605         m_stringBuilder.append(" || ");
606         break;
607     }
608     m_stringBuilder.append(makeString(right, ";\n"));
609     appendRightValue(logicalExpression, variableName);
610 }
611
612 void FunctionDefinitionWriter::visit(AST::LogicalNotExpression& logicalNotExpression)
613 {
614     checkErrorAndVisit(logicalNotExpression.operand());
615     auto operand = takeLastValue();
616     auto variableName = generateNextVariableName();
617     m_stringBuilder.append(makeString(m_typeNamer.mangledNameForType(logicalNotExpression.resolvedType()), ' ', variableName, " = !", operand, ";\n"));
618     appendRightValue(logicalNotExpression, variableName);
619 }
620
621 void FunctionDefinitionWriter::visit(AST::MakeArrayReferenceExpression& makeArrayReferenceExpression)
622 {
623     checkErrorAndVisit(makeArrayReferenceExpression.leftValue());
624     // FIXME: This needs to be made to work. It probably should be using the last leftValue too.
625     // https://bugs.webkit.org/show_bug.cgi?id=198838
626     auto variableName = generateNextVariableName();
627
628     auto mangledTypeName = m_typeNamer.mangledNameForType(makeArrayReferenceExpression.resolvedType());
629     if (is<AST::PointerType>(makeArrayReferenceExpression.leftValue().resolvedType())) {
630         auto ptrValue = takeLastValue();
631         m_stringBuilder.append(makeString(mangledTypeName, ' ', variableName, ";\n"));
632         m_stringBuilder.append(makeString("if (", ptrValue, ") ", variableName, " = { ", ptrValue, ", 1};\n"));
633         m_stringBuilder.append(makeString("else ", variableName, " = { nullptr, 0 };\n"));
634     } else if (is<AST::ArrayType>(makeArrayReferenceExpression.leftValue().resolvedType())) {
635         auto lValue = takeLastLeftValue();
636         auto& arrayType = downcast<AST::ArrayType>(makeArrayReferenceExpression.leftValue().resolvedType());
637         m_stringBuilder.append(makeString(mangledTypeName, ' ', variableName, " = { ", lValue, "->data(), ", arrayType.numElements(), " };\n"));
638     } else {
639         auto lValue = takeLastLeftValue();
640         m_stringBuilder.append(makeString(mangledTypeName, ' ', variableName, " = { ", lValue, ", 1 };\n"));
641     }
642     appendRightValue(makeArrayReferenceExpression, variableName);
643 }
644
645 void FunctionDefinitionWriter::visit(AST::MakePointerExpression& makePointerExpression)
646 {
647     checkErrorAndVisit(makePointerExpression.leftValue());
648     auto pointer = takeLastLeftValue();
649     auto variableName = generateNextVariableName();
650     m_stringBuilder.append(makeString(m_typeNamer.mangledNameForType(makePointerExpression.resolvedType()), ' ', variableName, " = ", pointer, ";\n"));
651     appendRightValue(makePointerExpression, variableName);
652 }
653
void FunctionDefinitionWriter::visit(AST::ReadModifyWriteExpression&)
{
    // This should be lowered already.
    // Note: unlike the placeholder visits, nothing is pushed onto m_stack here.
    ASSERT_NOT_REACHED();
}
659
660 void FunctionDefinitionWriter::visit(AST::TernaryExpression& ternaryExpression)
661 {
662     checkErrorAndVisit(ternaryExpression.predicate());
663     auto check = takeLastValue();
664
665     auto variableName = generateNextVariableName();
666     m_stringBuilder.append(makeString(m_typeNamer.mangledNameForType(ternaryExpression.resolvedType()), ' ', variableName, ";\n"));
667
668     m_stringBuilder.append(makeString("if (", check, ") {\n"));
669     checkErrorAndVisit(ternaryExpression.bodyExpression());
670     m_stringBuilder.append(makeString(variableName, " = ", takeLastValue(), ";\n"));
671     m_stringBuilder.append("} else {\n");
672     checkErrorAndVisit(ternaryExpression.elseExpression());
673     m_stringBuilder.append(makeString(variableName, " = ", takeLastValue(), ";\n"));
674     m_stringBuilder.append("}\n");
675     appendRightValue(ternaryExpression, variableName);
676 }
677
678 void FunctionDefinitionWriter::visit(AST::VariableReference& variableReference)
679 {
680     ASSERT(variableReference.variable());
681     auto iterator = m_variableMapping.find(variableReference.variable());
682     ASSERT(iterator != m_variableMapping.end());
683     auto pointerName = generateNextVariableName();
684     m_stringBuilder.append(makeString("thread ", m_typeNamer.mangledNameForType(variableReference.resolvedType()), "* ", pointerName, " = &", iterator->value, ";\n"));
685     appendLeftValue(variableReference, iterator->value, pointerName);
686 }
687
688 String FunctionDefinitionWriter::constantExpressionString(AST::ConstantExpression& constantExpression)
689 {
690     return constantExpression.visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) -> String {
691         return makeString("", integerLiteral.value());
692     }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) -> String {
693         return makeString("", unsignedIntegerLiteral.value());
694     }, [&](AST::FloatLiteral& floatLiteral) -> String {
695         return makeString("", floatLiteral.value());
696     }, [&](AST::NullLiteral&) -> String {
697         return "nullptr"_str;
698     }, [&](AST::BooleanLiteral& booleanLiteral) -> String {
699         return booleanLiteral.value() ? "true"_str : "false"_str;
700     }, [&](AST::EnumerationMemberLiteral& enumerationMemberLiteral) -> String {
701         ASSERT(enumerationMemberLiteral.enumerationDefinition());
702         ASSERT(enumerationMemberLiteral.enumerationDefinition());
703         return makeString(m_typeNamer.mangledNameForType(*enumerationMemberLiteral.enumerationDefinition()), "::", m_typeNamer.mangledNameForEnumerationMember(*enumerationMemberLiteral.enumerationMember()));
704     }));
705 }
706
707 class RenderFunctionDefinitionWriter : public FunctionDefinitionWriter {
708 public:
709     RenderFunctionDefinitionWriter(Intrinsics& intrinsics, TypeNamer& typeNamer, HashMap<AST::FunctionDeclaration*, String>& functionMapping, MatchedRenderSemantics&& matchedSemantics, Layout& layout)
710         : FunctionDefinitionWriter(intrinsics, typeNamer, functionMapping, layout)
711         , m_matchedSemantics(WTFMove(matchedSemantics))
712     {
713     }
714
715 private:
716     std::unique_ptr<EntryPointScaffolding> createEntryPointScaffolding(AST::FunctionDefinition&) override;
717
718     MatchedRenderSemantics m_matchedSemantics;
719 };
720
721 std::unique_ptr<EntryPointScaffolding> RenderFunctionDefinitionWriter::createEntryPointScaffolding(AST::FunctionDefinition& functionDefinition)
722 {
723     auto generateNextVariableName = [this]() -> String {
724         return this->generateNextVariableName();
725     };
726     if (&functionDefinition == m_matchedSemantics.vertexShader)
727         return std::make_unique<VertexEntryPointScaffolding>(functionDefinition, m_intrinsics, m_typeNamer, m_matchedSemantics.vertexShaderEntryPointItems, m_matchedSemantics.vertexShaderResourceMap, m_layout, WTFMove(generateNextVariableName), m_matchedSemantics.matchedVertexAttributes);
728     if (&functionDefinition == m_matchedSemantics.fragmentShader)
729         return std::make_unique<FragmentEntryPointScaffolding>(functionDefinition, m_intrinsics, m_typeNamer, m_matchedSemantics.fragmentShaderEntryPointItems, m_matchedSemantics.fragmentShaderResourceMap, m_layout, WTFMove(generateNextVariableName), m_matchedSemantics.matchedColorAttachments);
730     return nullptr;
731 }
732
733 class ComputeFunctionDefinitionWriter : public FunctionDefinitionWriter {
734 public:
735     ComputeFunctionDefinitionWriter(Intrinsics& intrinsics, TypeNamer& typeNamer, HashMap<AST::FunctionDeclaration*, String>& functionMapping, MatchedComputeSemantics&& matchedSemantics, Layout& layout)
736         : FunctionDefinitionWriter(intrinsics, typeNamer, functionMapping, layout)
737         , m_matchedSemantics(WTFMove(matchedSemantics))
738     {
739     }
740
741 private:
742     std::unique_ptr<EntryPointScaffolding> createEntryPointScaffolding(AST::FunctionDefinition&) override;
743
744     MatchedComputeSemantics m_matchedSemantics;
745 };
746
747 std::unique_ptr<EntryPointScaffolding> ComputeFunctionDefinitionWriter::createEntryPointScaffolding(AST::FunctionDefinition& functionDefinition)
748 {
749     auto generateNextVariableName = [this]() -> String {
750         return this->generateNextVariableName();
751     };
752     if (&functionDefinition == m_matchedSemantics.shader)
753         return std::make_unique<ComputeEntryPointScaffolding>(functionDefinition, m_intrinsics, m_typeNamer, m_matchedSemantics.entryPointItems, m_matchedSemantics.resourceMap, m_layout, WTFMove(generateNextVariableName));
754     return nullptr;
755 }
756
757 struct SharedMetalFunctionsResult {
758     HashMap<AST::FunctionDeclaration*, String> functionMapping;
759     String metalFunctions;
760 };
761 static SharedMetalFunctionsResult sharedMetalFunctions(Program& program, TypeNamer& typeNamer, const HashSet<AST::FunctionDeclaration*>& reachableFunctions)
762 {
763     StringBuilder stringBuilder;
764
765     unsigned numFunctions = 0;
766     HashMap<AST::FunctionDeclaration*, String> functionMapping;
767     for (auto& nativeFunctionDeclaration : program.nativeFunctionDeclarations()) {
768         auto addResult = functionMapping.add(&nativeFunctionDeclaration, makeString("function", numFunctions++));
769         ASSERT_UNUSED(addResult, addResult.isNewEntry);
770     }
771     for (auto& functionDefinition : program.functionDefinitions()) {
772         auto addResult = functionMapping.add(&functionDefinition, makeString("function", numFunctions++));
773         ASSERT_UNUSED(addResult, addResult.isNewEntry);
774     }
775
776     {
777         FunctionDeclarationWriter functionDeclarationWriter(typeNamer, functionMapping);
778         for (auto& nativeFunctionDeclaration : program.nativeFunctionDeclarations()) {
779             if (reachableFunctions.contains(&nativeFunctionDeclaration))
780                 functionDeclarationWriter.visit(nativeFunctionDeclaration);
781         }
782         for (auto& functionDefinition : program.functionDefinitions()) {
783             if (!functionDefinition->entryPointType() && reachableFunctions.contains(&functionDefinition))
784                 functionDeclarationWriter.visit(functionDefinition);
785         }
786         stringBuilder.append(functionDeclarationWriter.toString());
787     }
788
789     stringBuilder.append('\n');
790     return { WTFMove(functionMapping), stringBuilder.toString() };
791 }
792
793 class ReachableFunctionsGatherer : public Visitor {
794 public:
795     void visit(AST::FunctionDeclaration& functionDeclaration) override
796     {
797         Visitor::visit(functionDeclaration);
798         m_reachableFunctions.add(&functionDeclaration);
799     }
800
801     void visit(AST::CallExpression& callExpression) override
802     {
803         Visitor::visit(callExpression);
804         if (is<AST::FunctionDefinition>(callExpression.function()))
805             checkErrorAndVisit(downcast<AST::FunctionDefinition>(callExpression.function()));
806         else
807             checkErrorAndVisit(downcast<AST::NativeFunctionDeclaration>(callExpression.function()));
808     }
809
810     HashSet<AST::FunctionDeclaration*> takeReachableFunctions() { return WTFMove(m_reachableFunctions); }
811
812 private:
813     HashSet<AST::FunctionDeclaration*> m_reachableFunctions;
814 };
815
816 RenderMetalFunctions metalFunctions(Program& program, TypeNamer& typeNamer, MatchedRenderSemantics&& matchedSemantics, Layout& layout)
817 {
818     auto& vertexShaderEntryPoint = *matchedSemantics.vertexShader;
819     auto& fragmentShaderEntryPoint = *matchedSemantics.fragmentShader;
820
821     ReachableFunctionsGatherer reachableFunctionsGatherer;
822     reachableFunctionsGatherer.Visitor::visit(vertexShaderEntryPoint);
823     reachableFunctionsGatherer.Visitor::visit(fragmentShaderEntryPoint);
824     auto reachableFunctions = reachableFunctionsGatherer.takeReachableFunctions();
825
826     auto sharedMetalFunctions = Metal::sharedMetalFunctions(program, typeNamer, reachableFunctions);
827
828     StringBuilder stringBuilder;
829     stringBuilder.append(sharedMetalFunctions.metalFunctions);
830
831     RenderFunctionDefinitionWriter functionDefinitionWriter(program.intrinsics(), typeNamer, sharedMetalFunctions.functionMapping, WTFMove(matchedSemantics), layout);
832     for (auto& nativeFunctionDeclaration : program.nativeFunctionDeclarations()) {
833         if (reachableFunctions.contains(&nativeFunctionDeclaration))
834             functionDefinitionWriter.visit(nativeFunctionDeclaration);
835     }
836     for (auto& functionDefinition : program.functionDefinitions()) {
837         if (reachableFunctions.contains(&functionDefinition))
838             functionDefinitionWriter.visit(functionDefinition);
839     }
840     stringBuilder.append(functionDefinitionWriter.toString());
841
842     RenderMetalFunctions result;
843     result.metalSource = stringBuilder.toString();
844     result.mangledVertexEntryPointName = sharedMetalFunctions.functionMapping.get(&vertexShaderEntryPoint);
845     result.mangledFragmentEntryPointName = sharedMetalFunctions.functionMapping.get(&fragmentShaderEntryPoint);
846     return result;
847 }
848
849 ComputeMetalFunctions metalFunctions(Program& program, TypeNamer& typeNamer, MatchedComputeSemantics&& matchedSemantics, Layout& layout)
850 {
851     auto& entryPoint = *matchedSemantics.shader;
852
853     ReachableFunctionsGatherer reachableFunctionsGatherer;
854     reachableFunctionsGatherer.Visitor::visit(entryPoint);
855     auto reachableFunctions = reachableFunctionsGatherer.takeReachableFunctions();
856
857     auto sharedMetalFunctions = Metal::sharedMetalFunctions(program, typeNamer, reachableFunctions);
858
859     StringBuilder stringBuilder;
860     stringBuilder.append(sharedMetalFunctions.metalFunctions);
861
862     ComputeFunctionDefinitionWriter functionDefinitionWriter(program.intrinsics(), typeNamer, sharedMetalFunctions.functionMapping, WTFMove(matchedSemantics), layout);
863     for (auto& nativeFunctionDeclaration : program.nativeFunctionDeclarations()) {
864         if (reachableFunctions.contains(&nativeFunctionDeclaration))
865             functionDefinitionWriter.visit(nativeFunctionDeclaration);
866     }
867     for (auto& functionDefinition : program.functionDefinitions()) {
868         if (reachableFunctions.contains(&functionDefinition))
869             functionDefinitionWriter.visit(functionDefinition);
870     }
871     stringBuilder.append(functionDefinitionWriter.toString());
872
873     ComputeMetalFunctions result;
874     result.metalSource = stringBuilder.toString();
875     result.mangledEntryPointName = sharedMetalFunctions.functionMapping.get(&entryPoint);
876     return result;
877 }
878
879 } // namespace Metal
880
881 } // namespace WHLSL
882
883 } // namespace WebCore
884
885 #endif