185ba0b088a5ced4c5b6636fa9442c00340dd02a
[WebKit-https.git] / Source / WebCore / Modules / webgpu / WHLSL / Metal / WHLSLTypeNamer.cpp
1 /*
2  * Copyright (C) 2019 Apple Inc. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  * 1. Redistributions of source code must retain the above copyright
8  *    notice, this list of conditions and the following disclaimer.
9  * 2. Redistributions in binary form must reproduce the above copyright
10  *    notice, this list of conditions and the following disclaimer in the
11  *    documentation and/or other materials provided with the distribution.
12  *
13  * THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS ``AS IS''
14  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
15  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS
17  * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23  * THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 #include "config.h"
27 #include "WHLSLTypeNamer.h"
28
29 #if ENABLE(WEBGPU)
30
31 #include "WHLSLAddressSpace.h"
32 #include "WHLSLArrayReferenceType.h"
33 #include "WHLSLArrayType.h"
34 #include "WHLSLCallExpression.h"
35 #include "WHLSLEnumerationDefinition.h"
36 #include "WHLSLEnumerationMember.h"
37 #include "WHLSLNativeTypeDeclaration.h"
38 #include "WHLSLNativeTypeWriter.h"
39 #include "WHLSLPointerType.h"
40 #include "WHLSLStructureDefinition.h"
41 #include "WHLSLTypeDefinition.h"
42 #include "WHLSLTypeReference.h"
43 #include "WHLSLVisitor.h"
44 #include <algorithm>
45 #include <functional>
46 #include <wtf/FastMalloc.h>
47 #include <wtf/HashMap.h>
48 #include <wtf/HashSet.h>
49 #include <wtf/Optional.h>
50 #include <wtf/UniqueRef.h>
51 #include <wtf/Vector.h>
52 #include <wtf/text/StringBuilder.h>
53 #include <wtf/text/StringConcatenateNumbers.h>
54
55 namespace WebCore {
56
57 namespace WHLSL {
58
59 namespace Metal {
60
61 // FIXME: Look into replacing BaseTypeNameNode with a simple struct { RefPtr<UnnamedType> parent; MangledTypeName; } that UnnamedTypeKeys map to.
62 class BaseTypeNameNode {
63     WTF_MAKE_FAST_ALLOCATED;
64 public:
65     BaseTypeNameNode(BaseTypeNameNode* parent, MangledTypeName&& mangledName, AST::UnnamedType::Kind kind)
66         : m_parent(parent)
67         , m_mangledName(mangledName)
68         , m_kind(kind)
69     {
70     }
71     virtual ~BaseTypeNameNode() = default;
72     
73     AST::UnnamedType::Kind kind() { return m_kind; }
74     bool isReferenceTypeNameNode() const { return m_kind == AST::UnnamedType::Kind::TypeReference; }
75     bool isPointerTypeNameNode() const { return m_kind == AST::UnnamedType::Kind::Pointer; }
76     bool isArrayReferenceTypeNameNode() const { return m_kind == AST::UnnamedType::Kind::ArrayReference; }
77     bool isArrayTypeNameNode() const { return m_kind == AST::UnnamedType::Kind::Array; }
78
79     BaseTypeNameNode* parent() { return m_parent; }
80     MangledTypeName mangledName() const { return m_mangledName; }
81
82 private:
83     BaseTypeNameNode* m_parent;
84     MangledTypeName m_mangledName;
85     AST::UnnamedType::Kind m_kind;
86 };
87
88 class ArrayTypeNameNode final : public BaseTypeNameNode {
89     WTF_MAKE_FAST_ALLOCATED;
90 public:
91     ArrayTypeNameNode(BaseTypeNameNode* parent, MangledTypeName&& mangledName, unsigned numElements)
92         : BaseTypeNameNode(parent, WTFMove(mangledName), AST::UnnamedType::Kind::Array)
93         , m_numElements(numElements)
94     {
95     }
96     virtual ~ArrayTypeNameNode() = default;
97     unsigned numElements() const { return m_numElements; }
98
99 private:
100     unsigned m_numElements;
101 };
102
103 class ArrayReferenceTypeNameNode final : public BaseTypeNameNode {
104     WTF_MAKE_FAST_ALLOCATED;
105 public:
106     ArrayReferenceTypeNameNode(BaseTypeNameNode* parent, MangledTypeName&& mangledName, AST::AddressSpace addressSpace)
107         : BaseTypeNameNode(parent, WTFMove(mangledName), AST::UnnamedType::Kind::ArrayReference)
108         , m_addressSpace(addressSpace)
109     {
110     }
111     virtual ~ArrayReferenceTypeNameNode() = default;
112     AST::AddressSpace addressSpace() const { return m_addressSpace; }
113
114 private:
115     AST::AddressSpace m_addressSpace;
116 };
117
118 class PointerTypeNameNode final : public BaseTypeNameNode {
119     WTF_MAKE_FAST_ALLOCATED;
120 public:
121     PointerTypeNameNode(BaseTypeNameNode* parent, MangledTypeName&& mangledName, AST::AddressSpace addressSpace)
122         : BaseTypeNameNode(parent, WTFMove(mangledName), AST::UnnamedType::Kind::Pointer)
123         , m_addressSpace(addressSpace)
124     {
125     }
126     virtual ~PointerTypeNameNode() = default;
127     AST::AddressSpace addressSpace() const { return m_addressSpace; }
128
129 private:
130     AST::AddressSpace m_addressSpace;
131 };
132
133 class ReferenceTypeNameNode final : public BaseTypeNameNode {
134     WTF_MAKE_FAST_ALLOCATED;
135 public:
136     ReferenceTypeNameNode(BaseTypeNameNode* parent, MangledTypeName&& mangledName, AST::NamedType& namedType)
137         : BaseTypeNameNode(parent, WTFMove(mangledName), AST::UnnamedType::Kind::TypeReference)
138         , m_namedType(namedType)
139     {
140     }
141     virtual ~ReferenceTypeNameNode() = default;
142     AST::NamedType& namedType() { return m_namedType; }
143
144 private:
145     AST::NamedType& m_namedType;
146 };
147
148 }
149
150 }
151
152 }
153
154 #define SPECIALIZE_TYPE_TRAITS_WHLSL_BASE_TYPE_NAMED_NODE(ToValueTypeName, predicate) \
155 SPECIALIZE_TYPE_TRAITS_BEGIN(WebCore::WHLSL::Metal::ToValueTypeName) \
156     static bool isType(const WebCore::WHLSL::Metal::BaseTypeNameNode& type) { return type.predicate; } \
157 SPECIALIZE_TYPE_TRAITS_END()
158
159 SPECIALIZE_TYPE_TRAITS_WHLSL_BASE_TYPE_NAMED_NODE(ArrayTypeNameNode, isArrayTypeNameNode())
160
161 SPECIALIZE_TYPE_TRAITS_WHLSL_BASE_TYPE_NAMED_NODE(ArrayReferenceTypeNameNode, isArrayReferenceTypeNameNode())
162
163 SPECIALIZE_TYPE_TRAITS_WHLSL_BASE_TYPE_NAMED_NODE(PointerTypeNameNode, isPointerTypeNameNode())
164
165 SPECIALIZE_TYPE_TRAITS_WHLSL_BASE_TYPE_NAMED_NODE(ReferenceTypeNameNode, isReferenceTypeNameNode())
166
167 namespace WebCore {
168
169 namespace WHLSL {
170
171 namespace Metal {
172
173 TypeNamer::TypeNamer(Program& program)
174     : m_program(program)
175 {
176 }
177
178 TypeNamer::~TypeNamer() = default;
179
180 void TypeNamer::visit(AST::UnnamedType& unnamedType)
181 {
182     insert(unnamedType);
183 }
184
185 void TypeNamer::visit(AST::EnumerationDefinition& enumerationDefinition)
186 {
187     {
188         auto addResult = m_namedTypeMapping.add(&enumerationDefinition, generateNextTypeName());
189         ASSERT_UNUSED(addResult, addResult.isNewEntry);
190     }
191
192     for (auto& enumerationMember : enumerationDefinition.enumerationMembers()) {
193         auto addResult = m_enumerationMemberMapping.add(&static_cast<AST::EnumerationMember&>(enumerationMember), generateNextEnumerationMemberName());
194         ASSERT_UNUSED(addResult, addResult.isNewEntry);
195     }
196
197     Visitor::visit(enumerationDefinition);
198
199     {
200         Vector<std::reference_wrapper<BaseTypeNameNode>> neighbors = { find(enumerationDefinition.type()) };
201         auto addResult = m_dependencyGraph.add(&enumerationDefinition, WTFMove(neighbors));
202         ASSERT_UNUSED(addResult, addResult.isNewEntry);
203     }
204 }
205
206 void TypeNamer::visit(AST::NativeTypeDeclaration& nativeTypeDeclaration)
207 {
208     // Native type declarations already have names, and are already declared in Metal.
209     auto addResult = m_dependencyGraph.add(&nativeTypeDeclaration, Vector<std::reference_wrapper<BaseTypeNameNode>>());
210     ASSERT_UNUSED(addResult, addResult.isNewEntry);
211 }
212
213 void TypeNamer::visit(AST::StructureDefinition& structureDefinition)
214 {
215     {
216         auto addResult = m_namedTypeMapping.add(&structureDefinition, generateNextTypeName());
217         ASSERT_UNUSED(addResult, addResult.isNewEntry);
218     }
219     Visitor::visit(structureDefinition);
220     {
221         Vector<std::reference_wrapper<BaseTypeNameNode>> neighbors;
222         for (auto& structureElement : structureDefinition.structureElements()) {
223             auto addResult = m_structureElementMapping.add(&structureElement, generateNextStructureElementName());
224             ASSERT_UNUSED(addResult, addResult.isNewEntry);
225             neighbors.append(find(structureElement.type()));
226         }
227         auto addResult = m_dependencyGraph.add(&structureDefinition, WTFMove(neighbors));
228         ASSERT_UNUSED(addResult, addResult.isNewEntry);
229     }
230 }
231
232 void TypeNamer::visit(AST::TypeDefinition& typeDefinition)
233 {
234     {
235         auto addResult = m_namedTypeMapping.add(&typeDefinition, generateNextTypeName());
236         ASSERT_UNUSED(addResult, addResult.isNewEntry);
237     }
238     Visitor::visit(typeDefinition);
239     {
240         Vector<std::reference_wrapper<BaseTypeNameNode>> neighbors = { find(typeDefinition.type()) };
241         auto addResult = m_dependencyGraph.add(&typeDefinition, WTFMove(neighbors));
242         ASSERT_UNUSED(addResult, addResult.isNewEntry);
243     }
244 }
245
246 void TypeNamer::visit(AST::Expression& expression)
247 {
248     insert(expression.resolvedType());
249     Visitor::visit(expression);
250 }
251
252 void TypeNamer::visit(AST::CallExpression& callExpression)
253 {
254     for (auto& argument : callExpression.arguments())
255         checkErrorAndVisit(argument);
256 }
257
258 String TypeNamer::mangledNameForType(AST::NativeTypeDeclaration& nativeTypeDeclaration)
259 {
260     return writeNativeType(nativeTypeDeclaration);
261 }
262
263 BaseTypeNameNode& TypeNamer::find(AST::UnnamedType& unnamedType)
264 {
265     auto iterator = m_unnamedTypesUniquingMap.find(unnamedType);
266     ASSERT(iterator != m_unnamedTypesUniquingMap.end());
267     return *iterator->value;
268 }
269
270 std::unique_ptr<BaseTypeNameNode> TypeNamer::createNameNode(AST::UnnamedType& unnamedType, BaseTypeNameNode* parent)
271 {
272     switch (unnamedType.kind()) {
273     case AST::UnnamedType::Kind::TypeReference: {
274         auto& typeReference = downcast<AST::TypeReference>(unnamedType);
275         return std::make_unique<ReferenceTypeNameNode>(parent, generateNextTypeName(), typeReference.resolvedType());
276     }
277     case AST::UnnamedType::Kind::Pointer: {
278         auto& pointerType = downcast<AST::PointerType>(unnamedType);
279         return std::make_unique<PointerTypeNameNode>(parent, generateNextTypeName(), pointerType.addressSpace());
280     }
281     case AST::UnnamedType::Kind::ArrayReference: {
282         auto& arrayReferenceType = downcast<AST::ArrayReferenceType>(unnamedType);
283         return std::make_unique<ArrayReferenceTypeNameNode>(parent, generateNextTypeName(), arrayReferenceType.addressSpace());
284     }
285     case AST::UnnamedType::Kind::Array: {
286         auto& arrayType = downcast<AST::ArrayType>(unnamedType);
287         return std::make_unique<ArrayTypeNameNode>(parent, generateNextTypeName(), arrayType.numElements());
288     }
289     default:
290         RELEASE_ASSERT_NOT_REACHED();
291     }
292 }
293
294 static AST::UnnamedType* parent(AST::UnnamedType& unnamedType)
295 {
296     switch (unnamedType.kind()) {
297     case AST::UnnamedType::Kind::TypeReference:
298         return nullptr;
299     case AST::UnnamedType::Kind::Pointer:
300         return &downcast<AST::PointerType>(unnamedType).elementType();
301     case AST::UnnamedType::Kind::ArrayReference:
302         return &downcast<AST::ArrayReferenceType>(unnamedType).elementType();
303     case AST::UnnamedType::Kind::Array:
304         return &downcast<AST::ArrayType>(unnamedType).type();
305     default:
306         RELEASE_ASSERT_NOT_REACHED();
307     }
308 }
309
310 BaseTypeNameNode* TypeNamer::insert(AST::UnnamedType& unnamedType)
311 {
312     if (auto* result = m_unnamedTypeMapping.get(&unnamedType))
313         return result;
314
315     auto* parentUnnamedType = parent(unnamedType);
316     BaseTypeNameNode* parentNode = parentUnnamedType ? insert(*parentUnnamedType) : nullptr;
317
318     auto addResult = m_unnamedTypesUniquingMap.ensure(UnnamedTypeKey { unnamedType }, [&] {
319         return createNameNode(unnamedType, parentNode);
320     });
321
322     m_unnamedTypeMapping.add(&unnamedType, addResult.iterator->value.get());
323     return addResult.iterator->value.get();
324 }
325
326 class MetalTypeDeclarationWriter final : public Visitor {
327     WTF_MAKE_FAST_ALLOCATED;
328 public:
329     MetalTypeDeclarationWriter(StringBuilder& stringBuilder, std::function<MangledOrNativeTypeName(AST::NamedType&)>&& mangledNameForNamedType)
330         : m_mangledNameForNamedType(WTFMove(mangledNameForNamedType))
331         , m_stringBuilder(stringBuilder)
332     {
333     }
334
335 private:
336     void visit(AST::StructureDefinition& structureDefinition) override
337     {
338         m_stringBuilder.flexibleAppend("struct ", m_mangledNameForNamedType(structureDefinition), ";\n");
339     }
340
341     std::function<MangledOrNativeTypeName(AST::NamedType&)> m_mangledNameForNamedType;
342     StringBuilder& m_stringBuilder;
343 };
344
345 void TypeNamer::emitMetalTypeDeclarations(StringBuilder& stringBuilder)
346 {
347     MetalTypeDeclarationWriter metalTypeDeclarationWriter(stringBuilder, [&](AST::NamedType& namedType) -> MangledOrNativeTypeName {
348         return mangledNameForType(namedType);
349     });
350     metalTypeDeclarationWriter.Visitor::visit(m_program);
351 }
352
353 void TypeNamer::emitUnnamedTypeDefinition(StringBuilder& stringBuilder, BaseTypeNameNode& baseTypeNameNode, HashSet<AST::NamedType*>& emittedNamedTypes, HashSet<BaseTypeNameNode*>& emittedUnnamedTypes)
354 {
355     if (emittedUnnamedTypes.contains(&baseTypeNameNode))
356         return;
357
358     if (baseTypeNameNode.parent())
359         emitUnnamedTypeDefinition(stringBuilder, *baseTypeNameNode.parent(), emittedNamedTypes, emittedUnnamedTypes);
360     
361     switch (baseTypeNameNode.kind()) {
362     case AST::UnnamedType::Kind::TypeReference: {
363         auto& namedType = downcast<ReferenceTypeNameNode>(baseTypeNameNode).namedType();
364         emitNamedTypeDefinition(stringBuilder, namedType, emittedNamedTypes, emittedUnnamedTypes);
365         stringBuilder.flexibleAppend("typedef ", mangledNameForType(namedType), ' ', baseTypeNameNode.mangledName(), ";\n");
366         break;
367     }
368     case AST::UnnamedType::Kind::Pointer: {
369         auto& pointerType = downcast<PointerTypeNameNode>(baseTypeNameNode);
370         ASSERT(baseTypeNameNode.parent());
371         stringBuilder.flexibleAppend("typedef ", toString(pointerType.addressSpace()), ' ', pointerType.parent()->mangledName(), "* ", pointerType.mangledName(), ";\n");
372         break;
373     }
374     case AST::UnnamedType::Kind::ArrayReference: {
375         auto& arrayReferenceType = downcast<ArrayReferenceTypeNameNode>(baseTypeNameNode);
376         ASSERT(baseTypeNameNode.parent());
377         stringBuilder.flexibleAppend(
378             "struct ", arrayReferenceType.mangledName(), "{ \n"
379             "    ", toString(arrayReferenceType.addressSpace()), ' ', arrayReferenceType.parent()->mangledName(), "* pointer;\n"
380             "    uint32_t length;\n"
381             "};\n"
382         );
383         break;
384     }
385     case AST::UnnamedType::Kind::Array: {
386         auto& arrayType = downcast<ArrayTypeNameNode>(baseTypeNameNode);
387         ASSERT(baseTypeNameNode.parent());
388         stringBuilder.flexibleAppend("typedef array<", arrayType.parent()->mangledName(), ", ", arrayType.numElements(), "> ", arrayType.mangledName(), ";\n");
389         break;
390     }
391     default:
392         RELEASE_ASSERT_NOT_REACHED();
393     }
394
395     emittedUnnamedTypes.add(&baseTypeNameNode);
396 }
397
398 void TypeNamer::emitNamedTypeDefinition(StringBuilder& stringBuilder, AST::NamedType& namedType, HashSet<AST::NamedType*>& emittedNamedTypes, HashSet<BaseTypeNameNode*>& emittedUnnamedTypes)
399 {
400     if (emittedNamedTypes.contains(&namedType))
401         return;
402     auto iterator = m_dependencyGraph.find(&namedType);
403     ASSERT(iterator != m_dependencyGraph.end());
404     for (auto& baseTypeNameNode : iterator->value)
405         emitUnnamedTypeDefinition(stringBuilder, baseTypeNameNode, emittedNamedTypes, emittedUnnamedTypes);
406     if (is<AST::EnumerationDefinition>(namedType)) {
407         auto& enumerationDefinition = downcast<AST::EnumerationDefinition>(namedType);
408         auto& baseType = enumerationDefinition.type().unifyNode();
409         stringBuilder.flexibleAppend("enum class ", mangledNameForType(enumerationDefinition), " : ", mangledNameForType(downcast<AST::NamedType>(baseType)), " {\n");
410         for (auto& enumerationMember : enumerationDefinition.enumerationMembers())
411             stringBuilder.flexibleAppend("    ", mangledNameForEnumerationMember(enumerationMember), " = ", enumerationMember.get().value(), ",\n");
412         stringBuilder.append("};\n");
413     } else if (is<AST::NativeTypeDeclaration>(namedType)) {
414         // Native types already have definitions. There's nothing to do.
415     } else if (is<AST::StructureDefinition>(namedType)) {
416         auto& structureDefinition = downcast<AST::StructureDefinition>(namedType);
417         stringBuilder.flexibleAppend("struct ", mangledNameForType(structureDefinition), " {\n");
418         for (auto& structureElement : structureDefinition.structureElements())
419             stringBuilder.flexibleAppend("    ", mangledNameForType(structureElement.type()), ' ', mangledNameForStructureElement(structureElement), ";\n");
420         stringBuilder.append("};\n");
421     } else {
422         auto& typeDefinition = downcast<AST::TypeDefinition>(namedType);
423         stringBuilder.flexibleAppend("typedef ", mangledNameForType(typeDefinition.type()), ' ', mangledNameForType(typeDefinition), ";\n");
424     }
425     emittedNamedTypes.add(&namedType);
426 }
427
428 void TypeNamer::emitMetalTypeDefinitions(StringBuilder& stringBuilder)
429 {
430     HashSet<AST::NamedType*> emittedNamedTypes;
431     HashSet<BaseTypeNameNode*> emittedUnnamedTypes;
432     for (auto& namedType : m_dependencyGraph.keys())
433         emitNamedTypeDefinition(stringBuilder, *namedType, emittedNamedTypes, emittedUnnamedTypes);
434     for (auto& node : m_unnamedTypesUniquingMap.values())
435         emitUnnamedTypeDefinition(stringBuilder, *node, emittedNamedTypes, emittedUnnamedTypes);
436 }
437
438 MangledTypeName TypeNamer::mangledNameForType(AST::UnnamedType& unnamedType)
439 {
440     return find(unnamedType).mangledName();
441 }
442
443 MangledOrNativeTypeName TypeNamer::mangledNameForType(AST::NamedType& namedType)
444 {
445     if (is<AST::NativeTypeDeclaration>(namedType))
446         return mangledNameForType(downcast<AST::NativeTypeDeclaration>(namedType));
447     auto iterator = m_namedTypeMapping.find(&namedType);
448     ASSERT(iterator != m_namedTypeMapping.end());
449     return iterator->value;
450 }
451
452
453 MangledEnumerationMemberName TypeNamer::mangledNameForEnumerationMember(AST::EnumerationMember& enumerationMember)
454 {
455     auto iterator = m_enumerationMemberMapping.find(&enumerationMember);
456     ASSERT(iterator != m_enumerationMemberMapping.end());
457     return iterator->value;
458 }
459
460 MangledStructureElementName TypeNamer::mangledNameForStructureElement(AST::StructureElement& structureElement)
461 {
462     auto iterator = m_structureElementMapping.find(&structureElement);
463     ASSERT(iterator != m_structureElementMapping.end());
464     return iterator->value;
465 }
466
467 void TypeNamer::emitMetalTypes(StringBuilder& stringBuilder)
468 {
469     Visitor::visit(m_program);
470
471     emitMetalTypeDeclarations(stringBuilder);
472     stringBuilder.append('\n');
473     emitMetalTypeDefinitions(stringBuilder);
474 }
475
476 } // namespace Metal
477
478 } // namespace WHLSL
479
480 } // namespace WebCore
481
482 #endif