53e3788304517a1bc8c4ee871cf4aa538e824bd7
[WebKit-https.git] / Source / WebCore / Modules / webgpu / WHLSL / WHLSLPreserveVariableLifetimes.cpp
1 /*
2  * Copyright (C) 2019 Apple Inc. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  * 1. Redistributions of source code must retain the above copyright
8  *    notice, this list of conditions and the following disclaimer.
9  * 2. Redistributions in binary form must reproduce the above copyright
10  *    notice, this list of conditions and the following disclaimer in the
11  *    documentation and/or other materials provided with the distribution.
12  *
13  * THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS ``AS IS''
14  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
15  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS
17  * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23  * THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 #include "config.h"
27 #include "WHLSLPreserveVariableLifetimes.h"
28
29 #if ENABLE(WEBGPU)
30
31 #include "WHLSLAST.h"
32 #include "WHLSLASTDumper.h"
33 #include "WHLSLProgram.h"
34 #include "WHLSLReplaceWith.h"
35 #include "WHLSLVisitor.h"
36
37 namespace WebCore {
38
39 namespace WHLSL {
40
41 // This pass works by ensuring proper variable lifetimes. In WHLSL, each variable
42 // has global lifetime. So returning a pointer to a local variable is a totally
43 // legitimate and well-specified thing to do.
44 //
45 // We implement this by:
46 // - We note every variable whose address we take.
47 // - Each such variable gets defined as a field in a struct.
48 // - Each function which is an entry point defines this struct.
49 // - Each non entry point takes a pointer to this struct as its final parameter.
50 // - Each call to a non-native function is rewritten to pass a pointer to the
51 //   struct as the last call argument.
52 // - Each variable reference to "x", where "x" ends up in the struct, is
53 //   modified to instead be "struct->x". We store to "struct->x" after declaring
54 //   "x". If "x" is a function parameter, we store to "struct->x" as the first
55 //   thing we do in the function body.
56
57 class EscapedVariableCollector final : public Visitor {
58     using Base = Visitor;
59 public:
60
61     void escapeVariableUse(AST::Expression& expression)
62     {
63         if (!is<AST::VariableReference>(expression)) {
64             // FIXME: Are we missing any interesting productions here?
65             // https://bugs.webkit.org/show_bug.cgi?id=198311
66             Base::visit(expression);
67             return;
68         }
69
70         auto* variable = downcast<AST::VariableReference>(expression).variable();
71         ASSERT(variable);
72         // FIXME: We could skip this if we mark all internal variables with a bit, since we
73         // never make any internal variable escape the current scope it is defined in:
74         // https://bugs.webkit.org/show_bug.cgi?id=198383
75         m_escapedVariables.add(variable, makeString("_", variable->name(), "_", m_count++));
76     }
77
78     void visit(AST::MakePointerExpression& makePointerExpression) override
79     {
80         if (makePointerExpression.mightEscape())
81             escapeVariableUse(makePointerExpression.leftValue());
82     }
83
84     void visit(AST::MakeArrayReferenceExpression& makeArrayReferenceExpression) override
85     {
86         if (makeArrayReferenceExpression.mightEscape())
87             escapeVariableUse(makeArrayReferenceExpression.leftValue());
88     }
89
90     void visit(AST::FunctionDefinition& functionDefinition) override
91     {
92         if (functionDefinition.parsingMode() != ParsingMode::StandardLibrary)
93             Base::visit(functionDefinition);
94     }
95
96     HashMap<AST::VariableDeclaration*, String> takeEscapedVariables() { return WTFMove(m_escapedVariables); }
97
98 private:
99     size_t m_count { 1 };
100     HashMap<AST::VariableDeclaration*, String> m_escapedVariables;
101 };
102
103 static ALWAYS_INLINE Token anonymousToken(Token::Type type)
104 {
105     return Token { { }, type };
106 }
107
108 class PreserveLifetimes : public Visitor {
109     using Base = Visitor;
110 public:
111     PreserveLifetimes(Ref<AST::TypeReference> structType, const HashMap<AST::VariableDeclaration*, AST::StructureElement*>& variableMapping)
112         : m_structType(WTFMove(structType))
113         , m_pointerToStructType(AST::PointerType::create(anonymousToken(Token::Type::Identifier), AST::AddressSpace::Thread, m_structType.copyRef()))
114         , m_variableMapping(variableMapping)
115     { }
116
117     UniqueRef<AST::VariableReference> makeStructVariableReference()
118     {
119         auto structVariableReference = makeUniqueRef<AST::VariableReference>(AST::VariableReference::wrap(*m_structVariable));
120         structVariableReference->setType(*m_structVariable->type());
121         structVariableReference->setTypeAnnotation(AST::LeftValue { AST::AddressSpace::Thread });
122         return structVariableReference;
123     }
124
125     UniqueRef<AST::AssignmentExpression> assignVariableIntoStruct(AST::VariableDeclaration& variable, AST::StructureElement* element)
126     {
127         auto lhs = makeUniqueRef<AST::GlobalVariableReference>(variable.codeLocation(), makeStructVariableReference(), element);
128         lhs->setType(*variable.type());
129         lhs->setTypeAnnotation(AST::LeftValue { AST::AddressSpace::Thread });
130
131         auto rhs = makeUniqueRef<AST::VariableReference>(AST::VariableReference::wrap(variable));
132         rhs->setType(*variable.type());
133         rhs->setTypeAnnotation(AST::LeftValue { AST::AddressSpace::Thread });
134
135         auto assignment = makeUniqueRef<AST::AssignmentExpression>(variable.codeLocation(), WTFMove(lhs), WTFMove(rhs));
136         assignment->setType(*variable.type());
137         assignment->setTypeAnnotation(AST::RightValue());
138
139         return assignment;
140     }
141
142     void visit(AST::FunctionDefinition& functionDefinition) override
143     {
144         if (functionDefinition.parsingMode() == ParsingMode::StandardLibrary)
145             return;
146
147         bool isEntryPoint = !!functionDefinition.entryPointType();
148         if (isEntryPoint) {
149             auto structVariableDeclaration = makeUniqueRef<AST::VariableDeclaration>(functionDefinition.codeLocation(), AST::Qualifiers(),
150                 m_structType.copyRef(), String(), nullptr, nullptr);
151
152             auto structVariableReference = makeUniqueRef<AST::VariableReference>(AST::VariableReference::wrap(structVariableDeclaration));
153             structVariableReference->setType(m_structType.copyRef());
154             structVariableReference->setTypeAnnotation(AST::LeftValue { AST::AddressSpace::Thread });
155
156             AST::VariableDeclarations structVariableDeclarations;
157             structVariableDeclarations.append(WTFMove(structVariableDeclaration));
158             auto structDeclarationStatement = makeUniqueRef<AST::VariableDeclarationsStatement>(functionDefinition.codeLocation(), WTFMove(structVariableDeclarations));
159
160             std::unique_ptr<AST::Expression> makePointerExpression(new AST::MakePointerExpression(functionDefinition.codeLocation(), WTFMove(structVariableReference), AST::AddressEscapeMode::DoesNotEscape));
161             makePointerExpression->setType(m_pointerToStructType.copyRef());
162             makePointerExpression->setTypeAnnotation(AST::RightValue());
163
164             auto pointerDeclaration = makeUniqueRef<AST::VariableDeclaration>(functionDefinition.codeLocation(), AST::Qualifiers(),
165                 m_pointerToStructType.copyRef(), "wrapper"_s, nullptr, WTFMove(makePointerExpression));
166             m_structVariable = &pointerDeclaration;
167
168             AST::VariableDeclarations pointerVariableDeclarations;
169             pointerVariableDeclarations.append(WTFMove(pointerDeclaration));
170             auto pointerDeclarationStatement = makeUniqueRef<AST::VariableDeclarationsStatement>(functionDefinition.codeLocation(), WTFMove(pointerVariableDeclarations));
171
172             functionDefinition.block().statements().insert(0, WTFMove(structDeclarationStatement));
173             functionDefinition.block().statements().insert(1, WTFMove(pointerDeclarationStatement));
174         } else {
175             auto pointerDeclaration = makeUniqueRef<AST::VariableDeclaration>(functionDefinition.codeLocation(), AST::Qualifiers(),
176                 m_pointerToStructType.copyRef(), "wrapper"_s, nullptr, nullptr);
177             m_structVariable = &pointerDeclaration;
178             functionDefinition.parameters().append(WTFMove(pointerDeclaration));
179         }
180
181         Base::visit(functionDefinition);
182
183         for (auto& parameter : functionDefinition.parameters()) {
184             auto iter = m_variableMapping.find(&parameter);
185             if (iter == m_variableMapping.end())
186                 continue;
187
188             functionDefinition.block().statements().insert(isEntryPoint ? 2 : 0,
189                 makeUniqueRef<AST::EffectfulExpressionStatement>(assignVariableIntoStruct(parameter, iter->value)));
190         }
191
192         // Inner functions are not allowed in WHLSL. So this is fine.
193         m_structVariable = nullptr;
194     }
195
196     void visit(AST::CallExpression& callExpression) override
197     {
198         RELEASE_ASSERT(m_structVariable);
199
200         Base::visit(callExpression);
201
202         // This works because it's illegal to call an entrypoint. Therefore, we can only
203         // call functions where we've already appended this struct as its final parameter.
204         if (!callExpression.function().isNativeFunctionDeclaration() && callExpression.function().parsingMode() != ParsingMode::StandardLibrary)
205             callExpression.arguments().append(makeStructVariableReference());
206     }
207
208     void visit(AST::VariableReference& variableReference) override
209     {
210         RELEASE_ASSERT(m_structVariable);
211
212         auto iter = m_variableMapping.find(variableReference.variable());
213         if (iter == m_variableMapping.end())
214             return;
215
216         Ref<AST::UnnamedType> type = *variableReference.variable()->type();
217         AST::TypeAnnotation typeAnnotation = variableReference.typeAnnotation();
218         auto* internalField = AST::replaceWith<AST::GlobalVariableReference>(variableReference, variableReference.codeLocation(), makeStructVariableReference(), iter->value);
219         internalField->setType(WTFMove(type));
220         internalField->setTypeAnnotation(WTFMove(typeAnnotation));
221     }
222
223     void visit(AST::VariableDeclarationsStatement& variableDeclarationsStatement) override
224     {
225         RELEASE_ASSERT(m_structVariable);
226
227         Base::visit(variableDeclarationsStatement);
228
229         AST::Statements statements;
230         for (UniqueRef<AST::VariableDeclaration>& variableDeclaration : variableDeclarationsStatement.variableDeclarations()) {
231             AST::VariableDeclaration& variable = variableDeclaration.get();
232
233             {
234                 AST::VariableDeclarations declarations;
235                 declarations.append(WTFMove(variableDeclaration));
236                 statements.append(makeUniqueRef<AST::VariableDeclarationsStatement>(variable.codeLocation(), WTFMove(declarations)));
237             }
238
239             auto iter = m_variableMapping.find(&variable);
240             if (iter != m_variableMapping.end())
241                 statements.append(makeUniqueRef<AST::EffectfulExpressionStatement>(assignVariableIntoStruct(variable, iter->value)));
242         }
243
244         AST::replaceWith<AST::StatementList>(variableDeclarationsStatement, variableDeclarationsStatement.codeLocation(), WTFMove(statements));
245     }
246
247 private:
248     AST::VariableDeclaration* m_structVariable { nullptr };
249
250     Ref<AST::TypeReference> m_structType;
251     Ref<AST::PointerType> m_pointerToStructType;
252     // If this mapping contains the variable, it means that the variable's canonical location
253     // is in the struct we use to preserve variable lifetimes.
254     const HashMap<AST::VariableDeclaration*, AST::StructureElement*>& m_variableMapping;
255 };
256
257 void preserveVariableLifetimes(Program& program)
258 {
259     HashMap<AST::VariableDeclaration*, String> escapedVariables;
260     {
261         EscapedVariableCollector collector;
262         for (size_t i = 0; i < program.functionDefinitions().size(); ++i)
263             collector.visit(program.functionDefinitions()[i]);
264         escapedVariables = collector.takeEscapedVariables();
265     }
266
267     AST::StructureElements elements;
268     for (auto& pair : escapedVariables) {
269         auto* variable = pair.key;
270         String name = pair.value;
271         elements.append(AST::StructureElement { variable->codeLocation(), { }, *variable->type(), WTFMove(name), nullptr });
272     }
273
274     // Name of this doesn't matter, since we don't use struct names when
275     // generating Metal type names. We just pick something here to make it
276     // easy to read in AST dumps.
277     auto wrapperStructDefinition = makeUniqueRef<AST::StructureDefinition>(anonymousToken(Token::Type::Struct), "__WrapperStruct__"_s, WTFMove(elements));
278
279     HashMap<AST::VariableDeclaration*, AST::StructureElement*> variableMapping;
280     unsigned index = 0;
281     for (auto& pair : escapedVariables)
282         variableMapping.add(pair.key, &wrapperStructDefinition->structureElements()[index++]);
283
284     {
285         auto wrapperStructType = AST::TypeReference::wrap(anonymousToken(Token::Type::Identifier), wrapperStructDefinition);
286         PreserveLifetimes preserveLifetimes(WTFMove(wrapperStructType), variableMapping);
287         preserveLifetimes.Visitor::visit(program);
288     }
289
290     program.structureDefinitions().append(WTFMove(wrapperStructDefinition));
291 }
292
293 } // namespace WHLSL
294
295 } // namespace WebCore
296
297 #endif // ENABLE(WEBGPU)