[WHLSL] Standard library is too big to directly include in WebCore
[WebKit-https.git] / Source / WebCore / Modules / webgpu / WHLSL / WHLSLPreserveVariableLifetimes.cpp
1 /*
2  * Copyright (C) 2019 Apple Inc. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  * 1. Redistributions of source code must retain the above copyright
8  *    notice, this list of conditions and the following disclaimer.
9  * 2. Redistributions in binary form must reproduce the above copyright
10  *    notice, this list of conditions and the following disclaimer in the
11  *    documentation and/or other materials provided with the distribution.
12  *
13  * THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS ``AS IS''
14  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
15  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS
17  * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23  * THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 #include "config.h"
27 #include "WHLSLPreserveVariableLifetimes.h"
28
29 #if ENABLE(WEBGPU)
30
31 #include "WHLSLAST.h"
32 #include "WHLSLASTDumper.h"
33 #include "WHLSLReplaceWith.h"
34 #include "WHLSLVisitor.h"
35
36 namespace WebCore {
37
38 namespace WHLSL {
39
// This pass preserves variable lifetimes. In WHLSL, every variable has global
// lifetime, so returning a pointer to a local variable is a legitimate and
// well-specified thing to do.
//
// We implement this as follows:
// - We note every variable whose address we take (by making a pointer or an
//   array reference to it).
// - Each such variable becomes a field in a single wrapper struct.
// - Each entry point function declares a local variable of this struct type.
// - Each function that is not an entry point takes a pointer to this struct as
//   its final parameter.
// - Each call to a non-native function is rewritten to pass a pointer to the
//   struct as its last argument.
// - Each variable reference to "x", where "x" ends up in the struct, is
//   rewritten to instead be "struct->x". We store to "struct->x" right after
//   declaring "x". If "x" is a function parameter, we store to "struct->x" as
//   the first thing we do in the function body (in an entry point, right after
//   the struct itself has been set up).
55
56 class EscapedVariableCollector : public Visitor {
57     using Base = Visitor;
58 public:
59
60     void escapeVariableUse(AST::Expression& expression)
61     {
62         if (!is<AST::VariableReference>(expression)) {
63             // FIXME: Are we missing any interesting productions here?
64             // https://bugs.webkit.org/show_bug.cgi?id=198311
65             Base::visit(expression);
66             return;
67         }
68
69         auto* variable = downcast<AST::VariableReference>(expression).variable();
70         ASSERT(variable);
71         // FIXME: We could skip this if we mark all internal variables with a bit, since we
72         // never make any internal variable escape the current scope it is defined in:
73         // https://bugs.webkit.org/show_bug.cgi?id=198383
74         m_escapedVariables.add(variable, makeString("_", variable->name(), "_", m_count++));
75     }
76
77     void visit(AST::MakePointerExpression& makePointerExpression) override
78     {
79         escapeVariableUse(makePointerExpression.leftValue());
80     }
81
82     void visit(AST::MakeArrayReferenceExpression& makeArrayReferenceExpression) override
83     {
84         escapeVariableUse(makeArrayReferenceExpression.leftValue());
85     }
86
87     HashMap<AST::VariableDeclaration*, String> takeEscapedVariables() { return WTFMove(m_escapedVariables); }
88
89 private:
90     size_t m_count { 1 };
91     HashMap<AST::VariableDeclaration*, String> m_escapedVariables;
92 };
93
94 static ALWAYS_INLINE Lexer::Token anonymousToken(Lexer::Token::Type type)
95 {
96     return Lexer::Token { 0, 0, type };
97 }
98
99 class PreserveLifetimes : public Visitor {
100     using Base = Visitor;
101 public:
102     PreserveLifetimes(UniqueRef<AST::TypeReference>& structType, const HashMap<AST::VariableDeclaration*, AST::StructureElement*>& variableMapping)
103         : m_structType(structType)
104         , m_pointerToStructType(makeUniqueRef<AST::PointerType>(anonymousToken(Lexer::Token::Type::Identifier), AST::AddressSpace::Thread, m_structType->clone()))
105         , m_variableMapping(variableMapping)
106     { }
107
108     UniqueRef<AST::VariableReference> makeStructVariableReference()
109     {
110         auto structVariableReference = makeUniqueRef<AST::VariableReference>(AST::VariableReference::wrap(*m_structVariable));
111         structVariableReference->setType(m_structVariable->type()->clone());
112         structVariableReference->setTypeAnnotation(AST::LeftValue { AST::AddressSpace::Thread });
113         return structVariableReference;
114     }
115
116     UniqueRef<AST::AssignmentExpression> assignVariableIntoStruct(AST::VariableDeclaration& variable, AST::StructureElement* element)
117     {
118         auto lhs = makeUniqueRef<AST::GlobalVariableReference>(variable.origin(), makeStructVariableReference(), element);
119         lhs->setType(variable.type()->clone());
120         lhs->setTypeAnnotation(AST::LeftValue { AST::AddressSpace::Thread });
121
122         auto rhs = makeUniqueRef<AST::VariableReference>(AST::VariableReference::wrap(variable));
123         rhs->setType(variable.type()->clone());
124         rhs->setTypeAnnotation(AST::LeftValue { AST::AddressSpace::Thread });
125
126         auto assignment = makeUniqueRef<AST::AssignmentExpression>(variable.origin(), WTFMove(lhs), WTFMove(rhs));
127         assignment->setType(variable.type()->clone());
128         assignment->setTypeAnnotation(AST::RightValue());
129
130         return assignment;
131     }
132
133     void visit(AST::FunctionDefinition& functionDefinition) override
134     {
135         bool isEntryPoint = !!functionDefinition.entryPointType();
136         if (isEntryPoint) {
137             auto structVariableDeclaration = makeUniqueRef<AST::VariableDeclaration>(functionDefinition.origin(), AST::Qualifiers(),
138                 m_structType->clone(), String(), nullptr, nullptr);
139
140             auto structVariableReference = makeUniqueRef<AST::VariableReference>(AST::VariableReference::wrap(structVariableDeclaration));
141             structVariableReference->setType(m_structType->clone());
142             structVariableReference->setTypeAnnotation(AST::LeftValue { AST::AddressSpace::Thread });
143
144             AST::VariableDeclarations structVariableDeclarations;
145             structVariableDeclarations.append(WTFMove(structVariableDeclaration));
146             auto structDeclarationStatement = makeUniqueRef<AST::VariableDeclarationsStatement>(functionDefinition.origin(), WTFMove(structVariableDeclarations));
147
148             auto makePointerExpression = std::make_unique<AST::MakePointerExpression>(functionDefinition.origin(), WTFMove(structVariableReference));
149             makePointerExpression->setType(m_pointerToStructType->clone());
150             makePointerExpression->setTypeAnnotation(AST::RightValue());
151
152             auto pointerDeclaration = makeUniqueRef<AST::VariableDeclaration>(functionDefinition.origin(), AST::Qualifiers(),
153                 m_pointerToStructType->clone(), "wrapper"_s, nullptr, WTFMove(makePointerExpression));
154             m_structVariable = &pointerDeclaration;
155
156             AST::VariableDeclarations pointerVariableDeclarations;
157             pointerVariableDeclarations.append(WTFMove(pointerDeclaration));
158             auto pointerDeclarationStatement = makeUniqueRef<AST::VariableDeclarationsStatement>(functionDefinition.origin(), WTFMove(pointerVariableDeclarations));
159
160             functionDefinition.block().statements().insert(0, WTFMove(structDeclarationStatement));
161             functionDefinition.block().statements().insert(1, WTFMove(pointerDeclarationStatement));
162         } else {
163             auto pointerDeclaration = makeUniqueRef<AST::VariableDeclaration>(functionDefinition.origin(), AST::Qualifiers(),
164                 m_pointerToStructType->clone(), "wrapper"_s, nullptr, nullptr);
165             m_structVariable = &pointerDeclaration;
166             functionDefinition.parameters().append(WTFMove(pointerDeclaration));
167         }
168
169         Base::visit(functionDefinition);
170
171         for (auto& parameter : functionDefinition.parameters()) {
172             auto iter = m_variableMapping.find(&parameter);
173             if (iter == m_variableMapping.end())
174                 continue;
175
176             functionDefinition.block().statements().insert(isEntryPoint ? 2 : 0,
177                 makeUniqueRef<AST::EffectfulExpressionStatement>(assignVariableIntoStruct(parameter, iter->value)));
178         }
179
180         // Inner functions are not allowed in WHLSL. So this is fine.
181         m_structVariable = nullptr;
182     }
183
184     void visit(AST::CallExpression& callExpression) override
185     {
186         RELEASE_ASSERT(m_structVariable);
187
188         Base::visit(callExpression);
189
190         // This works because it's illegal to call an entrypoint. Therefore, we can only
191         // call functions where we've already appended this struct as its final parameter.
192         if (!callExpression.function().isNativeFunctionDeclaration())
193             callExpression.arguments().append(makeStructVariableReference());
194     }
195
196     void visit(AST::VariableReference& variableReference) override
197     {
198         RELEASE_ASSERT(m_structVariable);
199
200         auto iter = m_variableMapping.find(variableReference.variable());
201         if (iter == m_variableMapping.end())
202             return;
203
204         auto type = variableReference.variable()->type()->clone();
205         AST::TypeAnnotation typeAnnotation = variableReference.typeAnnotation();
206         auto* internalField = AST::replaceWith<AST::GlobalVariableReference>(variableReference, variableReference.origin(), makeStructVariableReference(), iter->value);
207         internalField->setType(WTFMove(type));
208         internalField->setTypeAnnotation(WTFMove(typeAnnotation));
209     }
210
211     void visit(AST::VariableDeclarationsStatement& variableDeclarationsStatement) override
212     {
213         RELEASE_ASSERT(m_structVariable);
214
215         Base::visit(variableDeclarationsStatement);
216
217         AST::Statements statements;
218         for (UniqueRef<AST::VariableDeclaration>& variableDeclaration : variableDeclarationsStatement.variableDeclarations()) {
219             AST::VariableDeclaration& variable = variableDeclaration.get();
220
221             {
222                 AST::VariableDeclarations declarations;
223                 declarations.append(WTFMove(variableDeclaration));
224                 statements.append(makeUniqueRef<AST::VariableDeclarationsStatement>(variable.origin(), WTFMove(declarations)));
225             }
226
227             auto iter = m_variableMapping.find(&variable);
228             if (iter != m_variableMapping.end())
229                 statements.append(makeUniqueRef<AST::EffectfulExpressionStatement>(assignVariableIntoStruct(variable, iter->value)));
230         }
231
232         auto origin = Lexer::Token(variableDeclarationsStatement.origin());
233         AST::replaceWith<AST::StatementList>(variableDeclarationsStatement, WTFMove(origin), WTFMove(statements));
234     }
235
236 private:
237     AST::VariableDeclaration* m_structVariable { nullptr };
238
239     UniqueRef<AST::TypeReference>& m_structType;
240     UniqueRef<AST::PointerType> m_pointerToStructType;
241     // If this mapping contains the variable, it means that the variable's canonical location
242     // is in the struct we use to preserve variable lifetimes.
243     const HashMap<AST::VariableDeclaration*, AST::StructureElement*>& m_variableMapping;
244 };
245
246 void preserveVariableLifetimes(Program& program)
247 {
248     HashMap<AST::VariableDeclaration*, String> escapedVariables;
249     {
250         EscapedVariableCollector collector;
251         collector.Visitor::visit(program);
252         escapedVariables = collector.takeEscapedVariables();
253     }
254
255     AST::StructureElements elements;
256     for (auto& pair : escapedVariables) {
257         auto* variable = pair.key;
258         String name = pair.value;
259         elements.append(AST::StructureElement { Lexer::Token(variable->origin()), { }, variable->type()->clone(), WTFMove(name), nullptr });
260     }
261
262     // Name of this doesn't matter, since we don't use struct names when
263     // generating Metal type names. We just pick something here to make it
264     // easy to read in AST dumps.
265     auto wrapperStructDefinition = makeUniqueRef<AST::StructureDefinition>(anonymousToken(Lexer::Token::Type::Struct), "__WrapperStruct__"_s, WTFMove(elements));
266
267     HashMap<AST::VariableDeclaration*, AST::StructureElement*> variableMapping;
268     unsigned index = 0;
269     for (auto& pair : escapedVariables)
270         variableMapping.add(pair.key, &wrapperStructDefinition->structureElements()[index++]);
271
272     {
273         auto wrapperStructType = AST::TypeReference::wrap(anonymousToken(Lexer::Token::Type::Identifier), wrapperStructDefinition);
274         PreserveLifetimes preserveLifetimes(wrapperStructType, variableMapping);
275         preserveLifetimes.Visitor::visit(program);
276     }
277
278     program.structureDefinitions().append(WTFMove(wrapperStructDefinition));
279 }
280
281 } // namespace WHLSL
282
283 } // namespace WebCore
284
285 #endif // ENABLE(WEBGPU)