[WHLSL] Generated Metal code should be indented properly to ease reading while...
[WebKit.git] / Source / WebCore / Modules / webgpu / WHLSL / Metal / WHLSLEntryPointScaffolding.cpp
1 /*
2  * Copyright (C) 2019 Apple Inc. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  * 1. Redistributions of source code must retain the above copyright
8  *    notice, this list of conditions and the following disclaimer.
9  * 2. Redistributions in binary form must reproduce the above copyright
10  *    notice, this list of conditions and the following disclaimer in the
11  *    documentation and/or other materials provided with the distribution.
12  *
13  * THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS ``AS IS''
14  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
15  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS
17  * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23  * THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 #include "config.h"
27 #include "WHLSLEntryPointScaffolding.h"
28
29 #if ENABLE(WEBGPU)
30
31 #include "WHLSLBuiltInSemantic.h"
32 #include "WHLSLFunctionDefinition.h"
33 #include "WHLSLGatherEntryPointItems.h"
34 #include "WHLSLPipelineDescriptor.h"
35 #include "WHLSLReferenceType.h"
36 #include "WHLSLResourceSemantic.h"
37 #include "WHLSLStageInOutSemantic.h"
38 #include "WHLSLStructureDefinition.h"
39 #include "WHLSLTypeNamer.h"
40 #include <algorithm>
41 #include <wtf/Optional.h>
42 #include <wtf/text/StringBuilder.h>
43 #include <wtf/text/StringConcatenateNumbers.h>
44
45 namespace WebCore {
46
47 namespace WHLSL {
48
49 namespace Metal {
50
51 static String attributeForSemantic(AST::BuiltInSemantic& builtInSemantic)
52 {
53     switch (builtInSemantic.variable()) {
54     case AST::BuiltInSemantic::Variable::SVInstanceID:
55         return "[[instance_id]]"_str;
56     case AST::BuiltInSemantic::Variable::SVVertexID:
57         return "[[vertex_id]]"_str;
58     case AST::BuiltInSemantic::Variable::PSize:
59         return "[[point_size]]"_str;
60     case AST::BuiltInSemantic::Variable::SVPosition:
61         return "[[position]]"_str;
62     case AST::BuiltInSemantic::Variable::SVIsFrontFace:
63         return "[[front_facing]]"_str;
64     case AST::BuiltInSemantic::Variable::SVSampleIndex:
65         return "[[sample_id]]"_str;
66     case AST::BuiltInSemantic::Variable::SVInnerCoverage:
67         return "[[sample_mask]]"_str;
68     case AST::BuiltInSemantic::Variable::SVTarget:
69         return makeString("[[color(", *builtInSemantic.targetIndex(), ")]]");
70     case AST::BuiltInSemantic::Variable::SVDepth:
71         return "[[depth(any)]]"_str;
72     case AST::BuiltInSemantic::Variable::SVCoverage:
73         return "[[sample_mask]]"_str;
74     case AST::BuiltInSemantic::Variable::SVDispatchThreadID:
75         return "[[thread_position_in_grid]]"_str;
76     case AST::BuiltInSemantic::Variable::SVGroupID:
77         return "[[threadgroup_position_in_grid]]"_str;
78     case AST::BuiltInSemantic::Variable::SVGroupIndex:
79         return "[[thread_index_in_threadgroup]]"_str;
80     default:
81         ASSERT(builtInSemantic.variable() == AST::BuiltInSemantic::Variable::SVGroupThreadID);
82         return "[[thread_position_in_threadgroup]]"_str;
83     }
84 }
85
86 static String attributeForSemantic(AST::Semantic& semantic)
87 {
88     if (WTF::holds_alternative<AST::BuiltInSemantic>(semantic))
89         return attributeForSemantic(WTF::get<AST::BuiltInSemantic>(semantic));
90     auto& stageInOutSemantic = WTF::get<AST::StageInOutSemantic>(semantic);
91     return makeString("[[user(user", stageInOutSemantic.index(), ")]]");
92 }
93
94 EntryPointScaffolding::EntryPointScaffolding(AST::FunctionDefinition& functionDefinition, Intrinsics& intrinsics, TypeNamer& typeNamer, EntryPointItems& entryPointItems, HashMap<Binding*, size_t>& resourceMap, Layout& layout, std::function<MangledVariableName()>&& generateNextVariableName)
95     : m_functionDefinition(functionDefinition)
96     , m_intrinsics(intrinsics)
97     , m_typeNamer(typeNamer)
98     , m_entryPointItems(entryPointItems)
99     , m_resourceMap(resourceMap)
100     , m_layout(layout)
101     , m_generateNextVariableName(generateNextVariableName)
102 {
103     m_namedBindGroups.reserveInitialCapacity(m_layout.size());
104     for (size_t i = 0; i < m_layout.size(); ++i) {
105         NamedBindGroup namedBindGroup;
106         namedBindGroup.structName = m_typeNamer.generateNextTypeName();
107         namedBindGroup.variableName = m_generateNextVariableName();
108         namedBindGroup.argumentBufferIndex = m_layout[i].name; // convertLayout() in GPURenderPipelineMetal.mm makes sure these don't collide.
109         namedBindGroup.namedBindings.reserveInitialCapacity(m_layout[i].bindings.size());
110         for (size_t j = 0; j < m_layout[i].bindings.size(); ++j) {
111             NamedBinding namedBinding;
112             namedBinding.elementName = m_typeNamer.generateNextStructureElementName();
113             namedBinding.index = m_layout[i].bindings[j].internalName;
114             WTF::visit(WTF::makeVisitor([&](UniformBufferBinding& uniformBufferBinding) {
115                 LengthInformation lengthInformation { m_typeNamer.generateNextStructureElementName(), m_generateNextVariableName(), uniformBufferBinding.lengthName };
116                 namedBinding.lengthInformation = lengthInformation;
117             }, [&](SamplerBinding&) {
118             }, [&](TextureBinding&) {
119             }, [&](StorageBufferBinding& storageBufferBinding) {
120                 LengthInformation lengthInformation { m_typeNamer.generateNextStructureElementName(), m_generateNextVariableName(), storageBufferBinding.lengthName };
121                 namedBinding.lengthInformation = lengthInformation;
122             }), m_layout[i].bindings[j].binding);
123             namedBindGroup.namedBindings.uncheckedAppend(WTFMove(namedBinding));
124         }
125         m_namedBindGroups.uncheckedAppend(WTFMove(namedBindGroup));
126     }
127
128     for (size_t i = 0; i < m_entryPointItems.inputs.size(); ++i) {
129         if (!WTF::holds_alternative<AST::BuiltInSemantic>(*m_entryPointItems.inputs[i].semantic))
130             continue;
131         NamedBuiltIn namedBuiltIn;
132         namedBuiltIn.indexInEntryPointItems = i;
133         namedBuiltIn.variableName = m_generateNextVariableName();
134         m_namedBuiltIns.append(WTFMove(namedBuiltIn));
135     }
136
137     m_parameterVariables.reserveInitialCapacity(m_functionDefinition.parameters().size());
138     for (size_t i = 0; i < m_functionDefinition.parameters().size(); ++i)
139         m_parameterVariables.uncheckedAppend(m_generateNextVariableName());
140 }
141
142 void EntryPointScaffolding::emitResourceHelperTypes(StringBuilder& stringBuilder, Indentation<4> indent)
143 {
144     for (size_t i = 0; i < m_layout.size(); ++i) {
145         stringBuilder.append(indent, "struct ", m_namedBindGroups[i].structName, " {\n");
146         {
147             IndentationScope scope(indent);
148             Vector<std::pair<unsigned, String>> structItems;
149             for (size_t j = 0; j < m_layout[i].bindings.size(); ++j) {
150                 auto iterator = m_resourceMap.find(&m_layout[i].bindings[j]);
151                 if (iterator == m_resourceMap.end())
152                     continue;
153                 auto& type = m_entryPointItems.inputs[iterator->value].unnamedType->unifyNode();
154                 if (is<AST::UnnamedType>(type) && is<AST::ReferenceType>(downcast<AST::UnnamedType>(type))) {
155                     auto& referenceType = downcast<AST::ReferenceType>(downcast<AST::UnnamedType>(type));
156                     auto mangledTypeName = m_typeNamer.mangledNameForType(referenceType.elementType());
157                     auto addressSpace = toString(referenceType.addressSpace());
158                     auto elementName = m_namedBindGroups[i].namedBindings[j].elementName;
159                     auto index = m_namedBindGroups[i].namedBindings[j].index;
160                     structItems.append(std::make_pair(index, makeString(addressSpace, " ", mangledTypeName, "* ", elementName, " [[id(", index, ")]];")));
161                     if (auto lengthInformation = m_namedBindGroups[i].namedBindings[j].lengthInformation)
162                         structItems.append(std::make_pair(lengthInformation->index, makeString("uint2 ", lengthInformation->elementName, " [[id(", lengthInformation->index, ")]];")));
163                 } else if (is<AST::NamedType>(type) && is<AST::NativeTypeDeclaration>(downcast<AST::NamedType>(type))) {
164                     auto& namedType = downcast<AST::NativeTypeDeclaration>(downcast<AST::NamedType>(type));
165                     auto mangledTypeName = m_typeNamer.mangledNameForType(namedType);
166                     auto elementName = m_namedBindGroups[i].namedBindings[j].elementName;
167                     auto index = m_namedBindGroups[i].namedBindings[j].index;
168                     structItems.append(std::make_pair(index, makeString(mangledTypeName, ' ', elementName, " [[id(", index, ")]];")));
169                 }
170             }
171             std::sort(structItems.begin(), structItems.end(), [](const std::pair<unsigned, String>& left, const std::pair<unsigned, String>& right) {
172                 return left.first < right.first;
173             });
174             for (const auto& structItem : structItems)
175                 stringBuilder.append(indent, structItem.second, '\n');
176         }
177         stringBuilder.append(indent, "};\n\n");
178     }
179 }
180
181 bool EntryPointScaffolding::emitResourceSignature(StringBuilder& stringBuilder, IncludePrecedingComma includePrecedingComma)
182 {
183     if (!m_layout.size())
184         return false;
185
186     if (includePrecedingComma == IncludePrecedingComma::Yes)
187         stringBuilder.append(", ");
188
189     for (size_t i = 0; i < m_layout.size(); ++i) {
190         if (i)
191             stringBuilder.append(", ");
192         auto& namedBindGroup = m_namedBindGroups[i];
193         stringBuilder.append("device ", namedBindGroup.structName, "& ", namedBindGroup.variableName, " [[buffer(", namedBindGroup.argumentBufferIndex, ")]]");
194     }
195     return true;
196 }
197
198 static StringView internalTypeForSemantic(const AST::BuiltInSemantic& builtInSemantic)
199 {
200     switch (builtInSemantic.variable()) {
201     case AST::BuiltInSemantic::Variable::SVInstanceID:
202         return "uint";
203     case AST::BuiltInSemantic::Variable::SVVertexID:
204         return "uint";
205     case AST::BuiltInSemantic::Variable::PSize:
206         return "float";
207     case AST::BuiltInSemantic::Variable::SVPosition:
208         return "float4";
209     case AST::BuiltInSemantic::Variable::SVIsFrontFace:
210         return "bool";
211     case AST::BuiltInSemantic::Variable::SVSampleIndex:
212         return "uint";
213     case AST::BuiltInSemantic::Variable::SVInnerCoverage:
214         return "uint";
215     case AST::BuiltInSemantic::Variable::SVTarget:
216         return { };
217     case AST::BuiltInSemantic::Variable::SVDepth:
218         return "float";
219     case AST::BuiltInSemantic::Variable::SVCoverage:
220         return "uint";
221     case AST::BuiltInSemantic::Variable::SVDispatchThreadID:
222         return "uint3";
223     case AST::BuiltInSemantic::Variable::SVGroupID:
224         return "uint3";
225     case AST::BuiltInSemantic::Variable::SVGroupIndex:
226         return "uint";
227     default:
228         ASSERT(builtInSemantic.variable() == AST::BuiltInSemantic::Variable::SVGroupThreadID);
229         return "uint3";
230     }
231 }
232
233 bool EntryPointScaffolding::emitBuiltInsSignature(StringBuilder& stringBuilder, IncludePrecedingComma includePrecedingComma)
234 {
235     if (!m_namedBuiltIns.size())
236         return false;
237
238     if (includePrecedingComma == IncludePrecedingComma::Yes)
239         stringBuilder.append(", ");
240
241     for (size_t i = 0; i < m_namedBuiltIns.size(); ++i) {
242         if (i)
243             stringBuilder.append(", ");
244         auto& namedBuiltIn = m_namedBuiltIns[i];
245         auto& item = m_entryPointItems.inputs[namedBuiltIn.indexInEntryPointItems];
246         auto& builtInSemantic = WTF::get<AST::BuiltInSemantic>(*item.semantic);
247         auto internalType = internalTypeForSemantic(builtInSemantic);
248         if (!internalType.isNull())
249             stringBuilder.append(internalType);
250         else
251             stringBuilder.append(m_typeNamer.mangledNameForType(*item.unnamedType));
252         stringBuilder.append(' ', namedBuiltIn.variableName, ' ', attributeForSemantic(builtInSemantic));
253     }
254     return true;
255 }
256
257 void EntryPointScaffolding::emitMangledInputPath(StringBuilder& stringBuilder, Vector<String>& path)
258 {
259     ASSERT(!path.isEmpty());
260     bool found = false;
261     AST::StructureDefinition* structureDefinition = nullptr;
262     for (size_t i = 0; i < m_functionDefinition.parameters().size(); ++i) {
263         if (m_functionDefinition.parameters()[i]->name() == path[0]) {
264             stringBuilder.append(m_parameterVariables[i]);
265             auto& unifyNode = m_functionDefinition.parameters()[i]->type()->unifyNode();
266             if (is<AST::NamedType>(unifyNode)) {
267                 auto& namedType = downcast<AST::NamedType>(unifyNode);
268                 if (is<AST::StructureDefinition>(namedType))
269                     structureDefinition = &downcast<AST::StructureDefinition>(namedType);
270             }
271             found = true;
272             break;
273         }
274     }
275     ASSERT(found);
276     for (size_t i = 1; i < path.size(); ++i) {
277         ASSERT(structureDefinition);
278         auto* next = structureDefinition->find(path[i]);
279         ASSERT(next);
280         stringBuilder.append('.', m_typeNamer.mangledNameForStructureElement(*next));
281         structureDefinition = nullptr;
282         auto& unifyNode = next->type().unifyNode();
283         if (is<AST::NamedType>(unifyNode)) {
284             auto& namedType = downcast<AST::NamedType>(unifyNode);
285             if (is<AST::StructureDefinition>(namedType))
286                 structureDefinition = &downcast<AST::StructureDefinition>(namedType);
287         }
288     }
289 }
290
291 void EntryPointScaffolding::emitMangledOutputPath(StringBuilder& stringBuilder, Vector<String>& path)
292 {
293     AST::StructureDefinition* structureDefinition = nullptr;
294     auto& unifyNode = m_functionDefinition.type().unifyNode();
295     structureDefinition = &downcast<AST::StructureDefinition>(downcast<AST::NamedType>(unifyNode));
296     for (auto& component : path) {
297         ASSERT(structureDefinition);
298         auto* next = structureDefinition->find(component);
299         ASSERT(next);
300         stringBuilder.append('.', m_typeNamer.mangledNameForStructureElement(*next));
301         structureDefinition = nullptr;
302         auto& unifyNode = next->type().unifyNode();
303         if (is<AST::NamedType>(unifyNode)) {
304             auto& namedType = downcast<AST::NamedType>(unifyNode);
305             if (is<AST::StructureDefinition>(namedType))
306                 structureDefinition = &downcast<AST::StructureDefinition>(namedType);
307         }
308     }
309 }
310
311 void EntryPointScaffolding::emitUnpackResourcesAndNamedBuiltIns(StringBuilder& stringBuilder, Indentation<4> indent)
312 {
313     for (size_t i = 0; i < m_functionDefinition.parameters().size(); ++i)
314         stringBuilder.append(indent, m_typeNamer.mangledNameForType(*m_functionDefinition.parameters()[i]->type()), ' ', m_parameterVariables[i], ";\n");
315
316     for (size_t i = 0; i < m_layout.size(); ++i) {
317         auto variableName = m_namedBindGroups[i].variableName;
318         for (size_t j = 0; j < m_layout[i].bindings.size(); ++j) {
319             auto iterator = m_resourceMap.find(&m_layout[i].bindings[j]);
320             if (iterator == m_resourceMap.end())
321                 continue;
322             if (m_namedBindGroups[i].namedBindings[j].lengthInformation) {
323                 auto& path = m_entryPointItems.inputs[iterator->value].path;
324                 auto elementName = m_namedBindGroups[i].namedBindings[j].elementName;
325                 auto lengthElementName = m_namedBindGroups[i].namedBindings[j].lengthInformation->elementName;
326                 auto lengthTemporaryName = m_namedBindGroups[i].namedBindings[j].lengthInformation->temporaryName;
327
328                 auto& unnamedType = *m_entryPointItems.inputs[iterator->value].unnamedType;
329                 auto mangledTypeName = m_typeNamer.mangledNameForType(downcast<AST::ReferenceType>(unnamedType).elementType());
330
331                 stringBuilder.append(
332                     indent, "size_t ", lengthTemporaryName, " = ", variableName, '.', lengthElementName, ".y;\n",
333                     indent, lengthTemporaryName, " = ", lengthTemporaryName, " << 32;\n",
334                     indent, lengthTemporaryName, " = ", lengthTemporaryName, " | ", variableName, '.', lengthElementName, ".x;\n",
335                     indent, lengthTemporaryName, " = ", lengthTemporaryName, " / sizeof(", mangledTypeName, ");\n",
336                     indent, "if (", lengthTemporaryName, " > 0xFFFFFFFF)\n",
337                     indent, "    ", lengthTemporaryName, " = 0xFFFFFFFF;\n"
338                 );
339
340                 stringBuilder.append(indent);
341                 emitMangledInputPath(stringBuilder, path);
342                 stringBuilder.append(
343                     " = { ", variableName, '.', elementName, ", static_cast<uint32_t>(", lengthTemporaryName, ") };\n"
344                 );
345             } else {
346                 auto& path = m_entryPointItems.inputs[iterator->value].path;
347                 auto elementName = m_namedBindGroups[i].namedBindings[j].elementName;
348                 
349                 stringBuilder.append(indent);
350                 emitMangledInputPath(stringBuilder, path);
351                 stringBuilder.append(" = ", variableName, '.', elementName, ";\n");
352             }
353         }
354     }
355
356     for (auto& namedBuiltIn : m_namedBuiltIns) {
357         auto& item = m_entryPointItems.inputs[namedBuiltIn.indexInEntryPointItems];
358         auto& path = item.path;
359         auto& variableName = namedBuiltIn.variableName;
360         auto mangledTypeName = m_typeNamer.mangledNameForType(*item.unnamedType);
361
362         stringBuilder.append(indent);
363         emitMangledInputPath(stringBuilder, path);
364         stringBuilder.append(" = ", mangledTypeName, '(', variableName, ");\n");
365     }
366 }
367
368 VertexEntryPointScaffolding::VertexEntryPointScaffolding(AST::FunctionDefinition& functionDefinition, Intrinsics& intrinsics, TypeNamer& typeNamer, EntryPointItems& entryPointItems, HashMap<Binding*, size_t>& resourceMap, Layout& layout, std::function<MangledVariableName()>&& generateNextVariableName, HashMap<VertexAttribute*, size_t>& matchedVertexAttributes)
369     : EntryPointScaffolding(functionDefinition, intrinsics, typeNamer, entryPointItems, resourceMap, layout, WTFMove(generateNextVariableName))
370     , m_matchedVertexAttributes(matchedVertexAttributes)
371     , m_stageInStructName(typeNamer.generateNextTypeName())
372     , m_returnStructName(typeNamer.generateNextTypeName())
373     , m_stageInParameterName(m_generateNextVariableName())
374 {
375     m_namedStageIns.reserveInitialCapacity(m_matchedVertexAttributes.size());
376     for (auto& keyValuePair : m_matchedVertexAttributes) {
377         NamedStageIn namedStageIn;
378         namedStageIn.indexInEntryPointItems = keyValuePair.value;
379         namedStageIn.elementName = m_typeNamer.generateNextStructureElementName();
380         namedStageIn.attributeIndex = keyValuePair.key->metalLocation;
381         m_namedStageIns.uncheckedAppend(WTFMove(namedStageIn));
382     }
383
384     m_namedOutputs.reserveInitialCapacity(m_entryPointItems.outputs.size());
385     for (size_t i = 0; i < m_entryPointItems.outputs.size(); ++i) {
386         auto& outputItem = m_entryPointItems.outputs[i];
387         NamedOutput namedOutput;
388         namedOutput.elementName = m_typeNamer.generateNextStructureElementName();
389         StringView internalType;
390         if (WTF::holds_alternative<AST::BuiltInSemantic>(*outputItem.semantic))
391             internalType = internalTypeForSemantic(WTF::get<AST::BuiltInSemantic>(*outputItem.semantic));
392         if (!internalType.isNull())
393             namedOutput.internalTypeName = internalType.toString();
394         else
395             namedOutput.internalTypeName = m_typeNamer.mangledNameForType(*outputItem.unnamedType);
396         m_namedOutputs.uncheckedAppend(WTFMove(namedOutput));
397     }
398 }
399
400 void VertexEntryPointScaffolding::emitHelperTypes(StringBuilder& stringBuilder, Indentation<4> indent)
401 {
402     stringBuilder.append(indent, "struct ", m_stageInStructName, " {\n");
403     {
404         IndentationScope scope(indent);
405         for (auto& namedStageIn : m_namedStageIns) {
406             auto mangledTypeName = m_typeNamer.mangledNameForType(*m_entryPointItems.inputs[namedStageIn.indexInEntryPointItems].unnamedType);
407             auto elementName = namedStageIn.elementName;
408             auto attributeIndex = namedStageIn.attributeIndex;
409             stringBuilder.append(indent, mangledTypeName, ' ', elementName, " [[attribute(", attributeIndex, ")]];\n");
410         }
411     }
412     stringBuilder.append(
413         indent, "};\n\n",
414         indent, "struct ", m_returnStructName, " {\n"
415     );
416     {
417         IndentationScope scope(indent);
418         for (size_t i = 0; i < m_entryPointItems.outputs.size(); ++i) {
419             auto& outputItem = m_entryPointItems.outputs[i];
420             auto& internalTypeName = m_namedOutputs[i].internalTypeName;
421             auto elementName = m_namedOutputs[i].elementName;
422             auto attribute = attributeForSemantic(*outputItem.semantic);
423             stringBuilder.append(indent, internalTypeName, ' ', elementName, ' ', attribute, ";\n");
424         }
425     }
426     stringBuilder.append(indent, "};\n\n");
427     
428     emitResourceHelperTypes(stringBuilder, indent);
429 }
430
431 void VertexEntryPointScaffolding::emitSignature(StringBuilder& stringBuilder, MangledFunctionName functionName, Indentation<4> indent)
432 {
433     stringBuilder.append(indent, "vertex ", m_returnStructName, ' ', functionName, '(', m_stageInStructName, ' ', m_stageInParameterName, " [[stage_in]]");
434     emitResourceSignature(stringBuilder, IncludePrecedingComma::Yes);
435     emitBuiltInsSignature(stringBuilder, IncludePrecedingComma::Yes);
436     stringBuilder.append(")\n");
437 }
438
439 void VertexEntryPointScaffolding::emitUnpack(StringBuilder& stringBuilder, Indentation<4> indent)
440 {
441     emitUnpackResourcesAndNamedBuiltIns(stringBuilder, indent);
442
443     for (auto& namedStageIn : m_namedStageIns) {
444         auto& path = m_entryPointItems.inputs[namedStageIn.indexInEntryPointItems].path;
445         auto& elementName = namedStageIn.elementName;
446         
447         stringBuilder.append(indent);
448         emitMangledInputPath(stringBuilder, path);
449         stringBuilder.append(" = ", m_stageInParameterName, '.', elementName, ";\n");
450     }
451 }
452
453 void VertexEntryPointScaffolding::emitPack(StringBuilder& stringBuilder, MangledVariableName inputVariableName, MangledVariableName outputVariableName, Indentation<4> indent)
454 {
455     stringBuilder.append(indent, m_returnStructName, ' ', outputVariableName, ";\n");
456     if (m_entryPointItems.outputs.size() == 1 && !m_entryPointItems.outputs[0].path.size()) {
457         auto& elementName = m_namedOutputs[0].elementName;
458         stringBuilder.append(indent, outputVariableName, '.', elementName, " = ", inputVariableName, ";\n");
459         return;
460     }
461     for (size_t i = 0; i < m_entryPointItems.outputs.size(); ++i) {
462         auto& elementName = m_namedOutputs[i].elementName;
463         auto& internalTypeName = m_namedOutputs[i].internalTypeName;
464         auto& path = m_entryPointItems.outputs[i].path;
465         stringBuilder.append(indent, outputVariableName, '.', elementName, " = ", internalTypeName, '(', inputVariableName);
466         emitMangledOutputPath(stringBuilder, path);
467         stringBuilder.append(");\n");
468     }
469 }
470
471 FragmentEntryPointScaffolding::FragmentEntryPointScaffolding(AST::FunctionDefinition& functionDefinition, Intrinsics& intrinsics, TypeNamer& typeNamer, EntryPointItems& entryPointItems, HashMap<Binding*, size_t>& resourceMap, Layout& layout, std::function<MangledVariableName()>&& generateNextVariableName, HashMap<AttachmentDescriptor*, size_t>&)
472     : EntryPointScaffolding(functionDefinition, intrinsics, typeNamer, entryPointItems, resourceMap, layout, WTFMove(generateNextVariableName))
473     , m_stageInStructName(typeNamer.generateNextTypeName())
474     , m_returnStructName(typeNamer.generateNextTypeName())
475     , m_stageInParameterName(m_generateNextVariableName())
476 {
477     for (size_t i = 0; i < m_entryPointItems.inputs.size(); ++i) {
478         auto& inputItem = m_entryPointItems.inputs[i];
479         if (!WTF::holds_alternative<AST::StageInOutSemantic>(*inputItem.semantic))
480             continue;
481         auto& stageInOutSemantic = WTF::get<AST::StageInOutSemantic>(*inputItem.semantic);
482         NamedStageIn namedStageIn;
483         namedStageIn.indexInEntryPointItems = i;
484         namedStageIn.elementName = m_typeNamer.generateNextStructureElementName();
485         namedStageIn.attributeIndex = stageInOutSemantic.index();
486         m_namedStageIns.append(WTFMove(namedStageIn));
487     }
488
489     m_namedOutputs.reserveInitialCapacity(m_entryPointItems.outputs.size());
490     for (size_t i = 0; i < m_entryPointItems.outputs.size(); ++i) {
491         auto& outputItem = m_entryPointItems.outputs[i];
492         NamedOutput namedOutput;
493         namedOutput.elementName = m_typeNamer.generateNextStructureElementName();
494         StringView internalType;
495         if (WTF::holds_alternative<AST::BuiltInSemantic>(*outputItem.semantic))
496             internalType = internalTypeForSemantic(WTF::get<AST::BuiltInSemantic>(*outputItem.semantic));
497         if (!internalType.isNull())
498             namedOutput.internalTypeName = internalType.toString();
499         else
500             namedOutput.internalTypeName = m_typeNamer.mangledNameForType(*outputItem.unnamedType);
501         m_namedOutputs.uncheckedAppend(WTFMove(namedOutput));
502     }
503 }
504
505 void FragmentEntryPointScaffolding::emitHelperTypes(StringBuilder& stringBuilder, Indentation<4> indent)
506 {
507     stringBuilder.append(indent, "struct ", m_stageInStructName, " {\n");
508     {
509         IndentationScope scope(indent);
510         for (auto& namedStageIn : m_namedStageIns) {
511             auto mangledTypeName = m_typeNamer.mangledNameForType(*m_entryPointItems.inputs[namedStageIn.indexInEntryPointItems].unnamedType);
512             auto elementName = namedStageIn.elementName;
513             auto attributeIndex = namedStageIn.attributeIndex;
514             stringBuilder.append(indent, mangledTypeName, ' ', elementName, " [[user(user", attributeIndex, ")]];\n");
515         }
516     }
517     stringBuilder.append(
518         indent, "};\n\n",
519         indent, "struct ", m_returnStructName, " {\n"
520     );
521     {
522         IndentationScope scope(indent);
523         for (size_t i = 0; i < m_entryPointItems.outputs.size(); ++i) {
524             auto& outputItem = m_entryPointItems.outputs[i];
525             auto& internalTypeName = m_namedOutputs[i].internalTypeName;
526             auto elementName = m_namedOutputs[i].elementName;
527             auto attribute = attributeForSemantic(*outputItem.semantic);
528             stringBuilder.append(indent, internalTypeName, ' ', elementName, ' ', attribute, ";\n");
529         }
530     }
531     stringBuilder.append(indent, "};\n\n");
532
533     emitResourceHelperTypes(stringBuilder, indent);
534 }
535
536 void FragmentEntryPointScaffolding::emitSignature(StringBuilder& stringBuilder, MangledFunctionName functionName, Indentation<4> indent)
537 {
538     stringBuilder.append(indent, "fragment ", m_returnStructName, ' ', functionName, '(', m_stageInStructName, ' ', m_stageInParameterName, " [[stage_in]]");
539     emitResourceSignature(stringBuilder, IncludePrecedingComma::Yes);
540     emitBuiltInsSignature(stringBuilder, IncludePrecedingComma::Yes);
541     stringBuilder.append(")\n");
542 }
543
544 void FragmentEntryPointScaffolding::emitUnpack(StringBuilder& stringBuilder, Indentation<4> indent)
545 {
546     emitUnpackResourcesAndNamedBuiltIns(stringBuilder, indent);
547
548     for (auto& namedStageIn : m_namedStageIns) {
549         auto& path = m_entryPointItems.inputs[namedStageIn.indexInEntryPointItems].path;
550         auto& elementName = namedStageIn.elementName;
551
552         stringBuilder.append(indent);
553         emitMangledInputPath(stringBuilder, path);
554         stringBuilder.append(" = ", m_stageInParameterName, '.', elementName, ";\n");
555     }
556 }
557
558 void FragmentEntryPointScaffolding::emitPack(StringBuilder& stringBuilder, MangledVariableName inputVariableName, MangledVariableName outputVariableName, Indentation<4> indent)
559 {
560     stringBuilder.append(indent, m_returnStructName, ' ', outputVariableName, ";\n");
561     if (m_entryPointItems.outputs.size() == 1 && !m_entryPointItems.outputs[0].path.size()) {
562         auto& elementName = m_namedOutputs[0].elementName;
563         stringBuilder.append(indent, outputVariableName, '.', elementName, " = ", inputVariableName, ";\n");
564         return;
565     }
566     for (size_t i = 0; i < m_entryPointItems.outputs.size(); ++i) {
567         auto& elementName = m_namedOutputs[i].elementName;
568         auto& internalTypeName = m_namedOutputs[i].internalTypeName;
569         auto& path = m_entryPointItems.outputs[i].path;
570         stringBuilder.append(indent, outputVariableName, '.', elementName, " = ", internalTypeName, '(', inputVariableName);
571         emitMangledOutputPath(stringBuilder, path);
572         stringBuilder.append(");\n");
573     }
574 }
575
576 ComputeEntryPointScaffolding::ComputeEntryPointScaffolding(AST::FunctionDefinition& functionDefinition, Intrinsics& intrinsics, TypeNamer& typeNamer, EntryPointItems& entryPointItems, HashMap<Binding*, size_t>& resourceMap, Layout& layout, std::function<MangledVariableName()>&& generateNextVariableName)
577     : EntryPointScaffolding(functionDefinition, intrinsics, typeNamer, entryPointItems, resourceMap, layout, WTFMove(generateNextVariableName))
578 {
579 }
580
581 void ComputeEntryPointScaffolding::emitHelperTypes(StringBuilder& stringBuilder, Indentation<4> indent)
582 {
583     emitResourceHelperTypes(stringBuilder, indent);
584 }
585
586 void ComputeEntryPointScaffolding::emitSignature(StringBuilder& stringBuilder, MangledFunctionName functionName, Indentation<4> indent)
587 {
588     stringBuilder.append(indent, "kernel void ", functionName, '(');
589     bool addedToSignature = emitResourceSignature(stringBuilder, IncludePrecedingComma::No);
590     emitBuiltInsSignature(stringBuilder, addedToSignature ? IncludePrecedingComma::Yes : IncludePrecedingComma::No);
591     stringBuilder.append(")\n");
592 }
593
594 void ComputeEntryPointScaffolding::emitUnpack(StringBuilder& stringBuilder, Indentation<4> indent)
595 {
596     emitUnpackResourcesAndNamedBuiltIns(stringBuilder, indent);
597 }
598
599 void ComputeEntryPointScaffolding::emitPack(StringBuilder&, MangledVariableName, MangledVariableName, Indentation<4>)
600 {
601     ASSERT_NOT_REACHED();
602 }
603
604 }
605
606 }
607
608 }
609
610 #endif