9732bfc0dd4698024637e1a9dc1e744c3f2d71ad
[WebKit-https.git] / Source / WebCore / Modules / webgpu / WHLSL / Metal / WHLSLFunctionWriter.cpp
1 /*
2  * Copyright (C) 2019 Apple Inc. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  * 1. Redistributions of source code must retain the above copyright
8  *    notice, this list of conditions and the following disclaimer.
9  * 2. Redistributions in binary form must reproduce the above copyright
10  *    notice, this list of conditions and the following disclaimer in the
11  *    documentation and/or other materials provided with the distribution.
12  *
13  * THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS ``AS IS''
14  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
15  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS
17  * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23  * THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 #include "config.h"
27 #include "WHLSLFunctionWriter.h"
28
29 #if ENABLE(WEBGPU)
30
31 #include "NotImplemented.h"
32 #include "WHLSLAST.h"
33 #include "WHLSLEntryPointScaffolding.h"
34 #include "WHLSLInferTypes.h"
35 #include "WHLSLNativeFunctionWriter.h"
36 #include "WHLSLProgram.h"
37 #include "WHLSLTypeNamer.h"
38 #include "WHLSLVisitor.h"
39 #include <wtf/HashMap.h>
40 #include <wtf/HashSet.h>
41 #include <wtf/SetForScope.h>
42 #include <wtf/text/StringBuilder.h>
43
44 namespace WebCore {
45
46 namespace WHLSL {
47
48 namespace Metal {
49
50 class FunctionDeclarationWriter : public Visitor {
51 public:
52     FunctionDeclarationWriter(TypeNamer& typeNamer, HashMap<AST::FunctionDeclaration*, String>& functionMapping)
53         : m_typeNamer(typeNamer)
54         , m_functionMapping(functionMapping)
55     {
56     }
57
58     virtual ~FunctionDeclarationWriter() = default;
59
60     String toString() { return m_stringBuilder.toString(); }
61
62     void visit(AST::FunctionDeclaration&) override;
63
64 private:
65     TypeNamer& m_typeNamer;
66     HashMap<AST::FunctionDeclaration*, String>& m_functionMapping;
67     StringBuilder m_stringBuilder;
68 };
69
70 void FunctionDeclarationWriter::visit(AST::FunctionDeclaration& functionDeclaration)
71 {
72     if (functionDeclaration.entryPointType())
73         return;
74
75     auto iterator = m_functionMapping.find(&functionDeclaration);
76     ASSERT(iterator != m_functionMapping.end());
77     m_stringBuilder.flexibleAppend(m_typeNamer.mangledNameForType(functionDeclaration.type()), ' ', iterator->value, '(');
78     for (size_t i = 0; i < functionDeclaration.parameters().size(); ++i) {
79         if (i)
80             m_stringBuilder.append(", ");
81         m_stringBuilder.append(m_typeNamer.mangledNameForType(*functionDeclaration.parameters()[i]->type()));
82     }
83     m_stringBuilder.append(");\n");
84 }
85
86 class FunctionDefinitionWriter : public Visitor {
87 public:
88     FunctionDefinitionWriter(Intrinsics& intrinsics, TypeNamer& typeNamer, HashMap<AST::FunctionDeclaration*, String>& functionMapping, Layout& layout)
89         : m_intrinsics(intrinsics)
90         , m_typeNamer(typeNamer)
91         , m_functionMapping(functionMapping)
92         , m_layout(layout)
93     {
94     }
95
96     virtual ~FunctionDefinitionWriter() = default;
97
98     String toString() { return m_stringBuilder.toString(); }
99
100     void visit(AST::NativeFunctionDeclaration&) override;
101     void visit(AST::FunctionDefinition&) override;
102
103 protected:
104     virtual std::unique_ptr<EntryPointScaffolding> createEntryPointScaffolding(AST::FunctionDefinition&) = 0;
105
106     void visit(AST::FunctionDeclaration&) override;
107     void visit(AST::Statement&) override;
108     void visit(AST::Block&) override;
109     void visit(AST::Break&) override;
110     void visit(AST::Continue&) override;
111     void visit(AST::DoWhileLoop&) override;
112     void visit(AST::EffectfulExpressionStatement&) override;
113     void visit(AST::Fallthrough&) override;
114     void visit(AST::ForLoop&) override;
115     void visit(AST::IfStatement&) override;
116     void visit(AST::Return&) override;
117     void visit(AST::SwitchStatement&) override;
118     void visit(AST::SwitchCase&) override;
119     void visit(AST::VariableDeclarationsStatement&) override;
120     void visit(AST::WhileLoop&) override;
121     void visit(AST::IntegerLiteral&) override;
122     void visit(AST::UnsignedIntegerLiteral&) override;
123     void visit(AST::FloatLiteral&) override;
124     void visit(AST::NullLiteral&) override;
125     void visit(AST::BooleanLiteral&) override;
126     void visit(AST::EnumerationMemberLiteral&) override;
127     void visit(AST::Expression&) override;
128     void visit(AST::DotExpression&) override;
129     void visit(AST::GlobalVariableReference&) override;
130     void visit(AST::IndexExpression&) override;
131     void visit(AST::PropertyAccessExpression&) override;
132     void visit(AST::VariableDeclaration&) override;
133     void visit(AST::AssignmentExpression&) override;
134     void visit(AST::CallExpression&) override;
135     void visit(AST::CommaExpression&) override;
136     void visit(AST::DereferenceExpression&) override;
137     void visit(AST::LogicalExpression&) override;
138     void visit(AST::LogicalNotExpression&) override;
139     void visit(AST::MakeArrayReferenceExpression&) override;
140     void visit(AST::MakePointerExpression&) override;
141     void visit(AST::ReadModifyWriteExpression&) override;
142     void visit(AST::TernaryExpression&) override;
143     void visit(AST::VariableReference&) override;
144
145     enum class LoopConditionLocation {
146         BeforeBody,
147         AfterBody
148     };
149     void emitLoop(LoopConditionLocation, AST::Expression* conditionExpression, AST::Expression* increment, AST::Statement& body);
150
151     String constantExpressionString(AST::ConstantExpression&);
152
153     String generateNextVariableName()
154     {
155         return makeString("variable", m_variableCount++);
156     }
157
158     enum class Nullability : uint8_t {
159         NotNull,
160         CanBeNull
161     };
162
163     struct StackItem {
164         String value;
165         String leftValue;
166         Nullability valueNullability;
167         Nullability leftValueNullability;
168     };
169
170     struct StackValue {
171         String value;
172         Nullability nullability;
173     };
174
175     // This is the important data flow step where we can take the nullability of an lvalue
176     // and transfer it into the nullability of an rvalue. This is conveyed in MakePointerExpression
177     // and DereferenceExpression. MakePointerExpression will try to produce rvalues which are
178     // non-null, and DereferenceExpression will take a non-null rvalue and try to produce
179     // a non-null lvalue.
180     void appendRightValueWithNullability(AST::Expression&, String value, Nullability nullability)
181     {
182         m_stack.append({ WTFMove(value), String(), nullability, Nullability::CanBeNull });
183     }
184
185     void appendRightValue(AST::Expression& expression, String value)
186     {
187         appendRightValueWithNullability(expression, WTFMove(value), Nullability::CanBeNull);
188     }
189
190     void appendLeftValue(AST::Expression& expression, String value, String leftValue, Nullability nullability)
191     {
192         ASSERT_UNUSED(expression, expression.typeAnnotation().leftAddressSpace());
193         m_stack.append({ WTFMove(value), WTFMove(leftValue), Nullability::CanBeNull, nullability });
194     }
195
196     String takeLastValue()
197     {
198         ASSERT(m_stack.last().value);
199         return m_stack.takeLast().value;
200     }
201
202     StackValue takeLastValueAndNullability()
203     {
204         ASSERT(m_stack.last().value);
205         auto last = m_stack.takeLast();
206         return { last.value, last.valueNullability };
207     }
208
209     StackValue takeLastLeftValue()
210     {
211         ASSERT(m_stack.last().leftValue);
212         auto last = m_stack.takeLast();
213         return { last.leftValue, last.leftValueNullability };
214     }
215
216     enum class BreakContext {
217         Loop,
218         Switch
219     };
220
221     Optional<BreakContext> m_currentBreakContext;
222
223     Intrinsics& m_intrinsics;
224     TypeNamer& m_typeNamer;
225     HashMap<AST::FunctionDeclaration*, String>& m_functionMapping;
226     HashMap<AST::VariableDeclaration*, String> m_variableMapping;
227     StringBuilder m_stringBuilder;
228
229     Vector<StackItem> m_stack;
230     std::unique_ptr<EntryPointScaffolding> m_entryPointScaffolding;
231     Layout& m_layout;
232     unsigned m_variableCount { 0 };
233     String m_breakOutOfCurrentLoopEarlyVariable;
234 };
235
236 void FunctionDefinitionWriter::visit(AST::NativeFunctionDeclaration& nativeFunctionDeclaration)
237 {
238     auto iterator = m_functionMapping.find(&nativeFunctionDeclaration);
239     ASSERT(iterator != m_functionMapping.end());
240     m_stringBuilder.append(writeNativeFunction(nativeFunctionDeclaration, iterator->value, m_intrinsics, m_typeNamer));
241 }
242
243 void FunctionDefinitionWriter::visit(AST::FunctionDefinition& functionDefinition)
244 {
245     auto iterator = m_functionMapping.find(&functionDefinition);
246     ASSERT(iterator != m_functionMapping.end());
247     if (functionDefinition.entryPointType()) {
248         auto entryPointScaffolding = createEntryPointScaffolding(functionDefinition);
249         if (!entryPointScaffolding)
250             return;
251         m_entryPointScaffolding = WTFMove(entryPointScaffolding);
252         m_stringBuilder.flexibleAppend(
253             m_entryPointScaffolding->helperTypes(), '\n',
254             m_entryPointScaffolding->signature(iterator->value), " {\n",
255             m_entryPointScaffolding->unpack()
256         );
257         for (size_t i = 0; i < functionDefinition.parameters().size(); ++i) {
258             auto addResult = m_variableMapping.add(&functionDefinition.parameters()[i], m_entryPointScaffolding->parameterVariables()[i]);
259             ASSERT_UNUSED(addResult, addResult.isNewEntry);
260         }
261         checkErrorAndVisit(functionDefinition.block());
262         ASSERT(m_stack.isEmpty());
263         m_stringBuilder.append("}\n");
264         m_entryPointScaffolding = nullptr;
265     } else {
266         ASSERT(m_entryPointScaffolding == nullptr);
267         m_stringBuilder.flexibleAppend(m_typeNamer.mangledNameForType(functionDefinition.type()), ' ', iterator->value, '(');
268         for (size_t i = 0; i < functionDefinition.parameters().size(); ++i) {
269             auto& parameter = functionDefinition.parameters()[i];
270             if (i)
271                 m_stringBuilder.append(", ");
272             auto parameterName = generateNextVariableName();
273             auto addResult = m_variableMapping.add(&parameter, parameterName);
274             ASSERT_UNUSED(addResult, addResult.isNewEntry);
275             m_stringBuilder.flexibleAppend(m_typeNamer.mangledNameForType(*parameter->type()), ' ', parameterName);
276         }
277         m_stringBuilder.append(") {\n");
278         checkErrorAndVisit(functionDefinition.block());
279         ASSERT(m_stack.isEmpty());
280         m_stringBuilder.append("}\n");
281     }
282 }
283
284 void FunctionDefinitionWriter::visit(AST::FunctionDeclaration&)
285 {
286     ASSERT_NOT_REACHED();
287 }
288
289 void FunctionDefinitionWriter::visit(AST::Statement& statement)
290 {
291     Visitor::visit(statement);
292 }
293
294 void FunctionDefinitionWriter::visit(AST::Block& block)
295 {
296     m_stringBuilder.append("{\n");
297     for (auto& statement : block.statements())
298         checkErrorAndVisit(statement);
299     m_stringBuilder.append("}\n");
300 }
301
302 void FunctionDefinitionWriter::visit(AST::Break&)
303 {
304     ASSERT(m_currentBreakContext);
305     switch (*m_currentBreakContext) {
306     case BreakContext::Switch:
307         m_stringBuilder.append("break;\n");
308         break;
309     case BreakContext::Loop:
310         ASSERT(m_breakOutOfCurrentLoopEarlyVariable.length());
311         m_stringBuilder.flexibleAppend(
312             m_breakOutOfCurrentLoopEarlyVariable, " = true;\n"
313             "break;\n"
314         );
315         break;
316     }
317 }
318
319 void FunctionDefinitionWriter::visit(AST::Continue&)
320 {
321     ASSERT(m_breakOutOfCurrentLoopEarlyVariable.length());
322     m_stringBuilder.append("break;\n");
323 }
324
325 void FunctionDefinitionWriter::visit(AST::EffectfulExpressionStatement& effectfulExpressionStatement)
326 {
327     checkErrorAndVisit(effectfulExpressionStatement.effectfulExpression());
328     takeLastValue(); // The statement is already effectful, so we don't need to do anything with the result.
329 }
330
331 void FunctionDefinitionWriter::visit(AST::Fallthrough&)
332 {
333     m_stringBuilder.append("[[clang::fallthrough]];\n"); // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195808 Make sure this is okay. Alternatively, we could do nothing and just return here instead.
334 }
335
336 void FunctionDefinitionWriter::emitLoop(LoopConditionLocation loopConditionLocation, AST::Expression* conditionExpression, AST::Expression* increment, AST::Statement& body)
337 {
338     SetForScope<String> loopVariableScope(m_breakOutOfCurrentLoopEarlyVariable, generateNextVariableName());
339
340     m_stringBuilder.flexibleAppend(
341         "bool ", m_breakOutOfCurrentLoopEarlyVariable, " = false;\n",
342         "while (true) {\n"
343     );
344
345     if (loopConditionLocation == LoopConditionLocation::BeforeBody && conditionExpression) {
346         checkErrorAndVisit(*conditionExpression);
347         m_stringBuilder.flexibleAppend("if (!", takeLastValue(), ") break;\n");
348     }
349
350     m_stringBuilder.append("do {\n");
351     SetForScope<Optional<BreakContext>> breakContext(m_currentBreakContext, BreakContext::Loop);
352     checkErrorAndVisit(body);
353     m_stringBuilder.flexibleAppend(
354         "} while(false); \n"
355         "if (", m_breakOutOfCurrentLoopEarlyVariable, ") break;\n"
356     );
357
358     if (increment) {
359         checkErrorAndVisit(*increment);
360         // Expression results get pushed to m_stack. We don't use the result
361         // of increment, so we dispense of that now.
362         takeLastValue();
363     }
364
365     if (loopConditionLocation == LoopConditionLocation::AfterBody && conditionExpression) {
366         checkErrorAndVisit(*conditionExpression);
367         m_stringBuilder.flexibleAppend("if (!", takeLastValue(), ") break;\n");
368     }
369
370     m_stringBuilder.append("} \n");
371 }
372
373 void FunctionDefinitionWriter::visit(AST::DoWhileLoop& doWhileLoop)
374 {
375     emitLoop(LoopConditionLocation::AfterBody, &doWhileLoop.conditional(), nullptr, doWhileLoop.body());
376 }
377
378 void FunctionDefinitionWriter::visit(AST::WhileLoop& whileLoop)
379 {
380     emitLoop(LoopConditionLocation::BeforeBody, &whileLoop.conditional(), nullptr, whileLoop.body());
381 }
382
383 void FunctionDefinitionWriter::visit(AST::ForLoop& forLoop)
384 {
385     m_stringBuilder.append("{\n");
386
387     WTF::visit(WTF::makeVisitor([&](AST::Statement& statement) {
388         checkErrorAndVisit(statement);
389     }, [&](UniqueRef<AST::Expression>& expression) {
390         checkErrorAndVisit(expression);
391         takeLastValue(); // We don't need to do anything with the result.
392     }), forLoop.initialization());
393
394     emitLoop(LoopConditionLocation::BeforeBody, forLoop.condition(), forLoop.increment(), forLoop.body());
395     m_stringBuilder.append("}\n");
396 }
397
398 void FunctionDefinitionWriter::visit(AST::IfStatement& ifStatement)
399 {
400     checkErrorAndVisit(ifStatement.conditional());
401     m_stringBuilder.flexibleAppend("if (", takeLastValue(), ") {\n");
402     checkErrorAndVisit(ifStatement.body());
403     if (ifStatement.elseBody()) {
404         m_stringBuilder.append("} else {\n");
405         checkErrorAndVisit(*ifStatement.elseBody());
406     }
407     m_stringBuilder.append("}\n");
408 }
409
410 void FunctionDefinitionWriter::visit(AST::Return& returnStatement)
411 {
412     if (returnStatement.value()) {
413         checkErrorAndVisit(*returnStatement.value());
414         if (m_entryPointScaffolding) {
415             auto variableName = generateNextVariableName();
416             m_stringBuilder.flexibleAppend(
417                 m_entryPointScaffolding->pack(takeLastValue(), variableName),
418                 "return ", variableName, ";\n"
419             );
420         } else
421             m_stringBuilder.flexibleAppend("return ", takeLastValue(), ";\n");
422     } else
423         m_stringBuilder.append("return;\n");
424 }
425
426 void FunctionDefinitionWriter::visit(AST::SwitchStatement& switchStatement)
427 {
428     checkErrorAndVisit(switchStatement.value());
429
430     m_stringBuilder.flexibleAppend("switch (", takeLastValue(), ") {");
431     for (auto& switchCase : switchStatement.switchCases())
432         checkErrorAndVisit(switchCase);
433     m_stringBuilder.append("}\n");
434 }
435
436 void FunctionDefinitionWriter::visit(AST::SwitchCase& switchCase)
437 {
438     if (switchCase.value())
439         m_stringBuilder.flexibleAppend("case ", constantExpressionString(*switchCase.value()), ":\n");
440     else
441         m_stringBuilder.append("default:\n");
442     SetForScope<Optional<BreakContext>> breakContext(m_currentBreakContext, BreakContext::Switch);
443     checkErrorAndVisit(switchCase.block());
444     // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195812 Figure out whether we need to break or fallthrough.
445 }
446
447 void FunctionDefinitionWriter::visit(AST::VariableDeclarationsStatement& variableDeclarationsStatement)
448 {
449     Visitor::visit(variableDeclarationsStatement);
450 }
451
452 void FunctionDefinitionWriter::visit(AST::IntegerLiteral& integerLiteral)
453 {
454     auto variableName = generateNextVariableName();
455     auto mangledTypeName = m_typeNamer.mangledNameForType(integerLiteral.resolvedType());
456     m_stringBuilder.flexibleAppend(mangledTypeName, ' ', variableName, " = static_cast<", mangledTypeName, ">(", integerLiteral.value(), ");\n");
457     appendRightValue(integerLiteral, variableName);
458 }
459
460 void FunctionDefinitionWriter::visit(AST::UnsignedIntegerLiteral& unsignedIntegerLiteral)
461 {
462     auto variableName = generateNextVariableName();
463     auto mangledTypeName = m_typeNamer.mangledNameForType(unsignedIntegerLiteral.resolvedType());
464     m_stringBuilder.flexibleAppend(mangledTypeName, ' ', variableName, " = static_cast<", mangledTypeName, ">(", unsignedIntegerLiteral.value(), ");\n");
465     appendRightValue(unsignedIntegerLiteral, variableName);
466 }
467
468 void FunctionDefinitionWriter::visit(AST::FloatLiteral& floatLiteral)
469 {
470     auto variableName = generateNextVariableName();
471     auto mangledTypeName = m_typeNamer.mangledNameForType(floatLiteral.resolvedType());
472     m_stringBuilder.flexibleAppend(mangledTypeName, ' ', variableName, " = static_cast<", mangledTypeName, ">(", floatLiteral.value(), ");\n");
473     appendRightValue(floatLiteral, variableName);
474 }
475
476 void FunctionDefinitionWriter::visit(AST::NullLiteral& nullLiteral)
477 {
478     auto& unifyNode = nullLiteral.resolvedType().unifyNode();
479     auto& unnamedType = downcast<AST::UnnamedType>(unifyNode);
480     bool isArrayReferenceType = is<AST::ArrayReferenceType>(unnamedType);
481
482     auto variableName = generateNextVariableName();
483     m_stringBuilder.flexibleAppend(m_typeNamer.mangledNameForType(nullLiteral.resolvedType()), ' ', variableName, " = ");
484     if (isArrayReferenceType)
485         m_stringBuilder.append("{ nullptr, 0 };\n");
486     else
487         m_stringBuilder.append("nullptr;\n");
488     appendRightValue(nullLiteral, variableName);
489 }
490
491 void FunctionDefinitionWriter::visit(AST::BooleanLiteral& booleanLiteral)
492 {
493     auto variableName = generateNextVariableName();
494     auto mangledTypeName = m_typeNamer.mangledNameForType(booleanLiteral.resolvedType());
495     m_stringBuilder.flexibleAppend(mangledTypeName, ' ', variableName, " = static_cast<", mangledTypeName, ">(", booleanLiteral.value() ? "true" : "false", ");\n");
496     appendRightValue(booleanLiteral, variableName);
497 }
498
499 void FunctionDefinitionWriter::visit(AST::EnumerationMemberLiteral& enumerationMemberLiteral)
500 {
501     ASSERT(enumerationMemberLiteral.enumerationDefinition());
502     ASSERT(enumerationMemberLiteral.enumerationDefinition());
503     auto variableName = generateNextVariableName();
504     auto mangledTypeName = m_typeNamer.mangledNameForType(enumerationMemberLiteral.resolvedType());
505     m_stringBuilder.flexibleAppend(mangledTypeName, ' ', variableName, " = ", mangledTypeName, "::", m_typeNamer.mangledNameForEnumerationMember(*enumerationMemberLiteral.enumerationMember()), ";\n");
506     appendRightValue(enumerationMemberLiteral, variableName);
507 }
508
509 void FunctionDefinitionWriter::visit(AST::Expression& expression)
510 {
511     Visitor::visit(expression);
512 }
513
514 void FunctionDefinitionWriter::visit(AST::DotExpression& dotExpression)
515 {
516     // This should be lowered already.
517     // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195788 Replace this with ASSERT_NOT_REACHED().
518     notImplemented();
519     appendRightValue(dotExpression, "dummy");
520 }
521
522 void FunctionDefinitionWriter::visit(AST::GlobalVariableReference& globalVariableReference)
523 {
524     auto valueName = generateNextVariableName();
525     auto pointerName = generateNextVariableName();
526     auto mangledTypeName = m_typeNamer.mangledNameForType(globalVariableReference.resolvedType());
527     checkErrorAndVisit(globalVariableReference.base());
528     m_stringBuilder.flexibleAppend(
529         "thread ", mangledTypeName, "* ", pointerName, " = &", takeLastValue(), "->", m_typeNamer.mangledNameForStructureElement(globalVariableReference.structField()), ";\n",
530         mangledTypeName, ' ', valueName, " = ", "*", pointerName, ";\n"
531     );
532     appendLeftValue(globalVariableReference, valueName, pointerName, Nullability::NotNull);
533 }
534
535 void FunctionDefinitionWriter::visit(AST::IndexExpression& indexExpression)
536 {
537     // This should be lowered already.
538     // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195788 Replace this with ASSERT_NOT_REACHED().
539     notImplemented();
540     appendRightValue(indexExpression, "dummy");
541 }
542
543 void FunctionDefinitionWriter::visit(AST::PropertyAccessExpression& propertyAccessExpression)
544 {
545     // This should be lowered already.
546     // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195788 Replace this with ASSERT_NOT_REACHED().
547     notImplemented();
548     appendRightValue(propertyAccessExpression, "dummy");
549 }
550
551 void FunctionDefinitionWriter::visit(AST::VariableDeclaration& variableDeclaration)
552 {
553     ASSERT(variableDeclaration.type());
554     auto variableName = generateNextVariableName();
555     auto addResult = m_variableMapping.add(&variableDeclaration, variableName);
556     ASSERT_UNUSED(addResult, addResult.isNewEntry);
557     // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198160 Implement qualifiers.
558     if (variableDeclaration.initializer()) {
559         checkErrorAndVisit(*variableDeclaration.initializer());
560         m_stringBuilder.flexibleAppend(m_typeNamer.mangledNameForType(*variableDeclaration.type()), ' ', variableName, " = ", takeLastValue(), ";\n");
561     } else
562         m_stringBuilder.flexibleAppend(m_typeNamer.mangledNameForType(*variableDeclaration.type()), ' ', variableName, ";\n");
563 }
564
565 void FunctionDefinitionWriter::visit(AST::AssignmentExpression& assignmentExpression)
566 {
567     checkErrorAndVisit(assignmentExpression.left());
568     auto [pointerName, nullability] = takeLastLeftValue();
569     checkErrorAndVisit(assignmentExpression.right());
570     auto [rightName, rightNullability] = takeLastValueAndNullability();
571     if (nullability == Nullability::CanBeNull)
572         m_stringBuilder.flexibleAppend("if (", pointerName, ") *", pointerName, " = ", rightName, ";\n");
573     else
574         m_stringBuilder.flexibleAppend("*", pointerName, " = ", rightName, ";\n");
575     appendRightValueWithNullability(assignmentExpression, rightName, rightNullability);
576 }
577
578 void FunctionDefinitionWriter::visit(AST::CallExpression& callExpression)
579 {
580     Vector<String> argumentNames;
581     for (auto& argument : callExpression.arguments()) {
582         checkErrorAndVisit(argument);
583         argumentNames.append(takeLastValue());
584     }
585     auto iterator = m_functionMapping.find(&callExpression.function());
586     ASSERT(iterator != m_functionMapping.end());
587     auto variableName = generateNextVariableName();
588     if (!matches(callExpression.resolvedType(), m_intrinsics.voidType()))
589         m_stringBuilder.flexibleAppend(m_typeNamer.mangledNameForType(callExpression.resolvedType()), ' ', variableName, " = ");
590     m_stringBuilder.flexibleAppend(iterator->value, '(');
591     for (size_t i = 0; i < argumentNames.size(); ++i) {
592         if (i)
593             m_stringBuilder.append(", ");
594         m_stringBuilder.append(argumentNames[i]);
595     }
596     m_stringBuilder.append(");\n");
597     appendRightValue(callExpression, variableName);
598 }
599
600 void FunctionDefinitionWriter::visit(AST::CommaExpression& commaExpression)
601 {
602     String result;
603     for (auto& expression : commaExpression.list()) {
604         checkErrorAndVisit(expression);
605         result = takeLastValue();
606     }
607     appendRightValue(commaExpression, result);
608 }
609
610 void FunctionDefinitionWriter::visit(AST::DereferenceExpression& dereferenceExpression)
611 {
612     checkErrorAndVisit(dereferenceExpression.pointer());
613     auto [inputPointer, nullability] = takeLastValueAndNullability();
614     auto resultValue = generateNextVariableName();
615     auto resultPointer = generateNextVariableName();
616     m_stringBuilder.flexibleAppend(
617         m_typeNamer.mangledNameForType(dereferenceExpression.pointer().resolvedType()), ' ', resultPointer, " = ", inputPointer, ";\n",
618         m_typeNamer.mangledNameForType(dereferenceExpression.resolvedType()), ' ', resultValue, ";\n");
619     if (nullability == Nullability::CanBeNull) {
620         m_stringBuilder.flexibleAppend(
621             "if (", resultPointer, ") ", resultValue, " = *", inputPointer, ";\n",
622             "else ", resultValue, " = { };\n");
623     } else
624         m_stringBuilder.flexibleAppend(resultValue, " = *", inputPointer, ";\n");
625     appendLeftValue(dereferenceExpression, resultValue, resultPointer, nullability);
626 }
627
628 void FunctionDefinitionWriter::visit(AST::LogicalExpression& logicalExpression)
629 {
630     checkErrorAndVisit(logicalExpression.left());
631     auto left = takeLastValue();
632     checkErrorAndVisit(logicalExpression.right());
633     auto right = takeLastValue();
634     auto variableName = generateNextVariableName();
635     m_stringBuilder.flexibleAppend(m_typeNamer.mangledNameForType(logicalExpression.resolvedType()), ' ', variableName, " = ", left);
636     switch (logicalExpression.type()) {
637     case AST::LogicalExpression::Type::And:
638         m_stringBuilder.append(" && ");
639         break;
640     default:
641         ASSERT(logicalExpression.type() == AST::LogicalExpression::Type::Or);
642         m_stringBuilder.append(" || ");
643         break;
644     }
645     m_stringBuilder.flexibleAppend(right, ";\n");
646     appendRightValue(logicalExpression, variableName);
647 }
648
649 void FunctionDefinitionWriter::visit(AST::LogicalNotExpression& logicalNotExpression)
650 {
651     checkErrorAndVisit(logicalNotExpression.operand());
652     auto operand = takeLastValue();
653     auto variableName = generateNextVariableName();
654     m_stringBuilder.flexibleAppend(m_typeNamer.mangledNameForType(logicalNotExpression.resolvedType()), ' ', variableName, " = !", operand, ";\n");
655     appendRightValue(logicalNotExpression, variableName);
656 }
657
658 void FunctionDefinitionWriter::visit(AST::MakeArrayReferenceExpression& makeArrayReferenceExpression)
659 {
660     checkErrorAndVisit(makeArrayReferenceExpression.leftValue());
661     // FIXME: This needs to be made to work. It probably should be using the last leftValue too.
662     // https://bugs.webkit.org/show_bug.cgi?id=198838
663     auto variableName = generateNextVariableName();
664
665     auto mangledTypeName = m_typeNamer.mangledNameForType(makeArrayReferenceExpression.resolvedType());
666     if (is<AST::PointerType>(makeArrayReferenceExpression.leftValue().resolvedType())) {
667         auto ptrValue = takeLastValue();
668         m_stringBuilder.flexibleAppend(
669             mangledTypeName, ' ', variableName, ";\n",
670             "if (", ptrValue, ") ", variableName, " = { ", ptrValue, ", 1};\n",
671             "else ", variableName, " = { nullptr, 0 };\n"
672         );
673     } else if (is<AST::ArrayType>(makeArrayReferenceExpression.leftValue().resolvedType())) {
674         auto lValue = takeLastLeftValue().value;
675         auto& arrayType = downcast<AST::ArrayType>(makeArrayReferenceExpression.leftValue().resolvedType());
676         m_stringBuilder.flexibleAppend(mangledTypeName, ' ', variableName, " = { ", lValue, "->data(), ", arrayType.numElements(), " };\n");
677     } else {
678         auto lValue = takeLastLeftValue().value;
679         m_stringBuilder.flexibleAppend(mangledTypeName, ' ', variableName, " = { ", lValue, ", 1 };\n");
680     }
681     appendRightValue(makeArrayReferenceExpression, variableName);
682 }
683
684 void FunctionDefinitionWriter::visit(AST::MakePointerExpression& makePointerExpression)
685 {
686     checkErrorAndVisit(makePointerExpression.leftValue());
687     auto [pointer, nullability] = takeLastLeftValue();
688     auto variableName = generateNextVariableName();
689     m_stringBuilder.flexibleAppend(m_typeNamer.mangledNameForType(makePointerExpression.resolvedType()), ' ', variableName, " = ", pointer, ";\n");
690     appendRightValueWithNullability(makePointerExpression, variableName, nullability);
691 }
692
693 void FunctionDefinitionWriter::visit(AST::ReadModifyWriteExpression&)
694 {
695     // This should be lowered already.
696     ASSERT_NOT_REACHED();
697 }
698
699 void FunctionDefinitionWriter::visit(AST::TernaryExpression& ternaryExpression)
700 {
701     checkErrorAndVisit(ternaryExpression.predicate());
702     auto check = takeLastValue();
703
704     auto variableName = generateNextVariableName();
705     m_stringBuilder.flexibleAppend(
706         m_typeNamer.mangledNameForType(ternaryExpression.resolvedType()), ' ', variableName, ";\n"
707         "if (", check, ") {\n"
708     );
709     checkErrorAndVisit(ternaryExpression.bodyExpression());
710     m_stringBuilder.flexibleAppend(
711         variableName, " = ", takeLastValue(), ";\n"
712         "} else {\n"
713     );
714     checkErrorAndVisit(ternaryExpression.elseExpression());
715     m_stringBuilder.flexibleAppend(
716         variableName, " = ", takeLastValue(), ";\n"
717         "}\n"
718     );
719     appendRightValue(ternaryExpression, variableName);
720 }
721
722 void FunctionDefinitionWriter::visit(AST::VariableReference& variableReference)
723 {
724     ASSERT(variableReference.variable());
725     auto iterator = m_variableMapping.find(variableReference.variable());
726     ASSERT(iterator != m_variableMapping.end());
727     auto pointerName = generateNextVariableName();
728     m_stringBuilder.flexibleAppend("thread ", m_typeNamer.mangledNameForType(variableReference.resolvedType()), "* ", pointerName, " = &", iterator->value, ";\n");
729     appendLeftValue(variableReference, iterator->value, pointerName, Nullability::NotNull);
730 }
731
732 String FunctionDefinitionWriter::constantExpressionString(AST::ConstantExpression& constantExpression)
733 {
734     return constantExpression.visit(WTF::makeVisitor([&](AST::IntegerLiteral& integerLiteral) -> String {
735         return makeString("", integerLiteral.value());
736     }, [&](AST::UnsignedIntegerLiteral& unsignedIntegerLiteral) -> String {
737         return makeString("", unsignedIntegerLiteral.value());
738     }, [&](AST::FloatLiteral& floatLiteral) -> String {
739         return makeString("", floatLiteral.value());
740     }, [&](AST::NullLiteral&) -> String {
741         return "nullptr"_str;
742     }, [&](AST::BooleanLiteral& booleanLiteral) -> String {
743         return booleanLiteral.value() ? "true"_str : "false"_str;
744     }, [&](AST::EnumerationMemberLiteral& enumerationMemberLiteral) -> String {
745         ASSERT(enumerationMemberLiteral.enumerationDefinition());
746         ASSERT(enumerationMemberLiteral.enumerationDefinition());
747         return makeString(m_typeNamer.mangledNameForType(*enumerationMemberLiteral.enumerationDefinition()), "::", m_typeNamer.mangledNameForEnumerationMember(*enumerationMemberLiteral.enumerationMember()));
748     }));
749 }
750
751 class RenderFunctionDefinitionWriter : public FunctionDefinitionWriter {
752 public:
753     RenderFunctionDefinitionWriter(Intrinsics& intrinsics, TypeNamer& typeNamer, HashMap<AST::FunctionDeclaration*, String>& functionMapping, MatchedRenderSemantics&& matchedSemantics, Layout& layout)
754         : FunctionDefinitionWriter(intrinsics, typeNamer, functionMapping, layout)
755         , m_matchedSemantics(WTFMove(matchedSemantics))
756     {
757     }
758
759 private:
760     std::unique_ptr<EntryPointScaffolding> createEntryPointScaffolding(AST::FunctionDefinition&) override;
761
762     MatchedRenderSemantics m_matchedSemantics;
763 };
764
765 std::unique_ptr<EntryPointScaffolding> RenderFunctionDefinitionWriter::createEntryPointScaffolding(AST::FunctionDefinition& functionDefinition)
766 {
767     auto generateNextVariableName = [this]() -> String {
768         return this->generateNextVariableName();
769     };
770     if (&functionDefinition == m_matchedSemantics.vertexShader)
771         return std::make_unique<VertexEntryPointScaffolding>(functionDefinition, m_intrinsics, m_typeNamer, m_matchedSemantics.vertexShaderEntryPointItems, m_matchedSemantics.vertexShaderResourceMap, m_layout, WTFMove(generateNextVariableName), m_matchedSemantics.matchedVertexAttributes);
772     if (&functionDefinition == m_matchedSemantics.fragmentShader)
773         return std::make_unique<FragmentEntryPointScaffolding>(functionDefinition, m_intrinsics, m_typeNamer, m_matchedSemantics.fragmentShaderEntryPointItems, m_matchedSemantics.fragmentShaderResourceMap, m_layout, WTFMove(generateNextVariableName), m_matchedSemantics.matchedColorAttachments);
774     return nullptr;
775 }
776
777 class ComputeFunctionDefinitionWriter : public FunctionDefinitionWriter {
778 public:
779     ComputeFunctionDefinitionWriter(Intrinsics& intrinsics, TypeNamer& typeNamer, HashMap<AST::FunctionDeclaration*, String>& functionMapping, MatchedComputeSemantics&& matchedSemantics, Layout& layout)
780         : FunctionDefinitionWriter(intrinsics, typeNamer, functionMapping, layout)
781         , m_matchedSemantics(WTFMove(matchedSemantics))
782     {
783     }
784
785 private:
786     std::unique_ptr<EntryPointScaffolding> createEntryPointScaffolding(AST::FunctionDefinition&) override;
787
788     MatchedComputeSemantics m_matchedSemantics;
789 };
790
791 std::unique_ptr<EntryPointScaffolding> ComputeFunctionDefinitionWriter::createEntryPointScaffolding(AST::FunctionDefinition& functionDefinition)
792 {
793     auto generateNextVariableName = [this]() -> String {
794         return this->generateNextVariableName();
795     };
796     if (&functionDefinition == m_matchedSemantics.shader)
797         return std::make_unique<ComputeEntryPointScaffolding>(functionDefinition, m_intrinsics, m_typeNamer, m_matchedSemantics.entryPointItems, m_matchedSemantics.resourceMap, m_layout, WTFMove(generateNextVariableName));
798     return nullptr;
799 }
800
801 struct SharedMetalFunctionsResult {
802     HashMap<AST::FunctionDeclaration*, String> functionMapping;
803     String metalFunctions;
804 };
805 static SharedMetalFunctionsResult sharedMetalFunctions(Program& program, TypeNamer& typeNamer, const HashSet<AST::FunctionDeclaration*>& reachableFunctions)
806 {
807     StringBuilder stringBuilder;
808
809     unsigned numFunctions = 0;
810     HashMap<AST::FunctionDeclaration*, String> functionMapping;
811     for (auto& nativeFunctionDeclaration : program.nativeFunctionDeclarations()) {
812         auto addResult = functionMapping.add(&nativeFunctionDeclaration, makeString("function", numFunctions++));
813         ASSERT_UNUSED(addResult, addResult.isNewEntry);
814     }
815     for (auto& functionDefinition : program.functionDefinitions()) {
816         auto addResult = functionMapping.add(&functionDefinition, makeString("function", numFunctions++));
817         ASSERT_UNUSED(addResult, addResult.isNewEntry);
818     }
819
820     {
821         FunctionDeclarationWriter functionDeclarationWriter(typeNamer, functionMapping);
822         for (auto& nativeFunctionDeclaration : program.nativeFunctionDeclarations()) {
823             if (reachableFunctions.contains(&nativeFunctionDeclaration))
824                 functionDeclarationWriter.visit(nativeFunctionDeclaration);
825         }
826         for (auto& functionDefinition : program.functionDefinitions()) {
827             if (!functionDefinition->entryPointType() && reachableFunctions.contains(&functionDefinition))
828                 functionDeclarationWriter.visit(functionDefinition);
829         }
830         stringBuilder.append(functionDeclarationWriter.toString());
831     }
832
833     stringBuilder.append('\n');
834     return { WTFMove(functionMapping), stringBuilder.toString() };
835 }
836
837 class ReachableFunctionsGatherer : public Visitor {
838 public:
839     void visit(AST::FunctionDeclaration& functionDeclaration) override
840     {
841         Visitor::visit(functionDeclaration);
842         m_reachableFunctions.add(&functionDeclaration);
843     }
844
845     void visit(AST::CallExpression& callExpression) override
846     {
847         Visitor::visit(callExpression);
848         if (is<AST::FunctionDefinition>(callExpression.function()))
849             checkErrorAndVisit(downcast<AST::FunctionDefinition>(callExpression.function()));
850         else
851             checkErrorAndVisit(downcast<AST::NativeFunctionDeclaration>(callExpression.function()));
852     }
853
854     HashSet<AST::FunctionDeclaration*> takeReachableFunctions() { return WTFMove(m_reachableFunctions); }
855
856 private:
857     HashSet<AST::FunctionDeclaration*> m_reachableFunctions;
858 };
859
860 RenderMetalFunctions metalFunctions(Program& program, TypeNamer& typeNamer, MatchedRenderSemantics&& matchedSemantics, Layout& layout)
861 {
862     auto& vertexShaderEntryPoint = *matchedSemantics.vertexShader;
863     auto& fragmentShaderEntryPoint = *matchedSemantics.fragmentShader;
864
865     ReachableFunctionsGatherer reachableFunctionsGatherer;
866     reachableFunctionsGatherer.Visitor::visit(vertexShaderEntryPoint);
867     reachableFunctionsGatherer.Visitor::visit(fragmentShaderEntryPoint);
868     auto reachableFunctions = reachableFunctionsGatherer.takeReachableFunctions();
869
870     auto sharedMetalFunctions = Metal::sharedMetalFunctions(program, typeNamer, reachableFunctions);
871
872     StringBuilder stringBuilder;
873     stringBuilder.append(sharedMetalFunctions.metalFunctions);
874
875     RenderFunctionDefinitionWriter functionDefinitionWriter(program.intrinsics(), typeNamer, sharedMetalFunctions.functionMapping, WTFMove(matchedSemantics), layout);
876     for (auto& nativeFunctionDeclaration : program.nativeFunctionDeclarations()) {
877         if (reachableFunctions.contains(&nativeFunctionDeclaration))
878             functionDefinitionWriter.visit(nativeFunctionDeclaration);
879     }
880     for (auto& functionDefinition : program.functionDefinitions()) {
881         if (reachableFunctions.contains(&functionDefinition))
882             functionDefinitionWriter.visit(functionDefinition);
883     }
884     stringBuilder.append(functionDefinitionWriter.toString());
885
886     RenderMetalFunctions result;
887     result.metalSource = stringBuilder.toString();
888     result.mangledVertexEntryPointName = sharedMetalFunctions.functionMapping.get(&vertexShaderEntryPoint);
889     result.mangledFragmentEntryPointName = sharedMetalFunctions.functionMapping.get(&fragmentShaderEntryPoint);
890     return result;
891 }
892
893 ComputeMetalFunctions metalFunctions(Program& program, TypeNamer& typeNamer, MatchedComputeSemantics&& matchedSemantics, Layout& layout)
894 {
895     auto& entryPoint = *matchedSemantics.shader;
896
897     ReachableFunctionsGatherer reachableFunctionsGatherer;
898     reachableFunctionsGatherer.Visitor::visit(entryPoint);
899     auto reachableFunctions = reachableFunctionsGatherer.takeReachableFunctions();
900
901     auto sharedMetalFunctions = Metal::sharedMetalFunctions(program, typeNamer, reachableFunctions);
902
903     StringBuilder stringBuilder;
904     stringBuilder.append(sharedMetalFunctions.metalFunctions);
905
906     ComputeFunctionDefinitionWriter functionDefinitionWriter(program.intrinsics(), typeNamer, sharedMetalFunctions.functionMapping, WTFMove(matchedSemantics), layout);
907     for (auto& nativeFunctionDeclaration : program.nativeFunctionDeclarations()) {
908         if (reachableFunctions.contains(&nativeFunctionDeclaration))
909             functionDefinitionWriter.visit(nativeFunctionDeclaration);
910     }
911     for (auto& functionDefinition : program.functionDefinitions()) {
912         if (reachableFunctions.contains(&functionDefinition))
913             functionDefinitionWriter.visit(functionDefinition);
914     }
915     stringBuilder.append(functionDefinitionWriter.toString());
916
917     ComputeMetalFunctions result;
918     result.metalSource = stringBuilder.toString();
919     result.mangledEntryPointName = sharedMetalFunctions.functionMapping.get(&entryPoint);
920     return result;
921 }
922
923 } // namespace Metal
924
925 } // namespace WHLSL
926
927 } // namespace WebCore
928
929 #endif