[WHLSL] Enforce variable lifetimes
[WebKit-https.git] / Source / WebCore / Modules / webgpu / WHLSL / WHLSLPreserveVariableLifetimes.cpp
1 /*
2  * Copyright (C) 2019 Apple Inc. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  * 1. Redistributions of source code must retain the above copyright
8  *    notice, this list of conditions and the following disclaimer.
9  * 2. Redistributions in binary form must reproduce the above copyright
10  *    notice, this list of conditions and the following disclaimer in the
11  *    documentation and/or other materials provided with the distribution.
12  *
13  * THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS ``AS IS''
14  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
15  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS
17  * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23  * THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 #include "config.h"
27 #include "WHLSLPreserveVariableLifetimes.h"
28
29 #if ENABLE(WEBGPU)
30
31 #include "WHLSLAST.h"
32 #include "WHLSLASTDumper.h"
33 #include "WHLSLVisitor.h"
34
35 namespace WebCore {
36
37 namespace WHLSL {
38
39 // This pass works by ensuring proper variable lifetimes. In WHLSL, each variable
40 // has global lifetime. So returning a pointer to a local variable is a totally
41 // legitimate and well-specified thing to do.
42 //
43 // We implement this by:
44 // - We note every variable whose address we take.
45 // - Each such variable gets defined as a field in a struct.
46 // - Each function which is an entry point defines this struct.
47 // - Each non entry point takes a pointer to this struct as its final parameter.
48 // - Each call to a non-native function is rewritten to pass a pointer to the
49 //   struct as the last call argument.
50 // - Each variable reference to "x", where "x" ends up in the struct, is
51 //   modified to instead be "struct->x". We store to "struct->x" after declaring
52 //   "x". If "x" is a function parameter, we store to "struct->x" as the first
53 //   thing we do in the function body.
54
55 class EscapedVariableCollector : public Visitor {
56     using Base = Visitor;
57 public:
58     void visit(AST::MakePointerExpression& makePointerExpression) override
59     {
60         if (!is<AST::VariableReference>(makePointerExpression.leftValue())) {
61             // FIXME: Are we missing any interesting productions here?
62             // https://bugs.webkit.org/show_bug.cgi?id=198311
63             Base::visit(makePointerExpression.leftValue());
64             return;
65         }
66
67         auto& variableReference = downcast<AST::VariableReference>(makePointerExpression.leftValue());
68         auto* variable = variableReference.variable();
69         ASSERT(variable);
70         // FIXME: We could skip this if we mark all internal variables with a bit, since we
71         // never make any internal variable escape the current scope it is defined in:
72         // https://bugs.webkit.org/show_bug.cgi?id=198383
73         m_escapedVariables.add(variable, makeString("_", variable->name(), "_", m_count++));
74     }
75
76     HashMap<AST::VariableDeclaration*, String> takeEscapedVariables() { return WTFMove(m_escapedVariables); }
77
78 private:
79     size_t m_count { 1 };
80     HashMap<AST::VariableDeclaration*, String> m_escapedVariables;
81 };
82
83 static ALWAYS_INLINE Lexer::Token anonymousToken(Lexer::Token::Type type)
84 {
85     return Lexer::Token { { }, 0, type };
86 }
87
// AST visitor that rewrites function definitions, call expressions, variable
// references and variable declaration statements so that every variable found
// in the supplied mapping is accessed through its field in the wrapper struct.
// The struct type and the mapping are held by reference, so both must outlive
// this visitor.
88 class PreserveLifetimes : public Visitor {
89     using Base = Visitor;
90 public:
       // structType: type reference to the wrapper struct (stored by reference).
       // variableMapping: escaped variable -> struct element that now backs it.
       // Also pre-builds a thread-address-space pointer type to a clone of structType.
91     PreserveLifetimes(UniqueRef<AST::TypeReference>& structType, const HashMap<AST::VariableDeclaration*, AST::StructureElement*>& variableMapping)
92         : m_structType(structType)
93         , m_pointerToStructType(makeUniqueRef<AST::PointerType>(anonymousToken(Lexer::Token::Type::Identifier), AST::AddressSpace::Thread, m_structType->clone()))
94         , m_variableMapping(variableMapping)
95     { }
96
       // Returns a new reference to the current function's wrapper-pointer variable,
       // typed as that variable and annotated as a thread-address-space left value.
       // Dereferences m_structVariable, so it is only valid while a function
       // definition is being visited.
97     UniqueRef<AST::VariableReference> makeStructVariableReference()
98     {
99         auto structVariableReference = makeUniqueRef<AST::VariableReference>(AST::VariableReference::wrap(*m_structVariable));
100         structVariableReference->setType(m_structVariable->type()->clone());
101         structVariableReference->setTypeAnnotation(AST::LeftValue { AST::AddressSpace::Thread });
102         return structVariableReference;
103     }
104
        // Builds the expression "<struct field for element> = variable".
        // The left side is a GlobalVariableReference made from a fresh wrapper-pointer
        // reference plus `element`; the right side is a new reference to `variable`.
        // All three nodes are given clones of the variable's type.
105     UniqueRef<AST::AssignmentExpression> assignVariableIntoStruct(AST::VariableDeclaration& variable, AST::StructureElement* element)
106     {
107         auto lhs = makeUniqueRef<AST::GlobalVariableReference>(variable.origin(), makeStructVariableReference(), element);
108         lhs->setType(variable.type()->clone());
109         lhs->setTypeAnnotation(AST::LeftValue { AST::AddressSpace::Thread });
110
111         auto rhs = makeUniqueRef<AST::VariableReference>(AST::VariableReference::wrap(variable));
112         rhs->setType(variable.type()->clone());
113         rhs->setTypeAnnotation(AST::RightValue());
114
115         auto assignment = makeUniqueRef<AST::AssignmentExpression>(variable.origin(), WTFMove(lhs), WTFMove(rhs));
116         assignment->setType(variable.type()->clone());
117         assignment->setTypeAnnotation(AST::RightValue());
118
119         return assignment;
120     }
121
        // Gives the function access to the wrapper struct, rewrites its body, then
        // stores every escaped parameter into the struct at the top of the body.
        // Entry points declare the struct and a "wrapper" pointer to it as their first
        // two statements; every other function receives "wrapper" as a trailing parameter.
122     void visit(AST::FunctionDefinition& functionDefinition) override
123     {
124         bool isEntryPoint = !!functionDefinition.entryPointType();
125         if (isEntryPoint) {
                // The struct variable itself; it is declared with an empty name.
126             auto structVariableDeclaration = makeUniqueRef<AST::VariableDeclaration>(functionDefinition.origin(), AST::Qualifiers(),
127                 m_structType->clone(), String(), WTF::nullopt, WTF::nullopt);
128
                // Created before the declaration is moved into the statement below.
129             auto structVariableReference = makeUniqueRef<AST::VariableReference>(AST::VariableReference::wrap(structVariableDeclaration));
130             structVariableReference->setType(m_structType->clone());
131             structVariableReference->setTypeAnnotation(AST::LeftValue { AST::AddressSpace::Thread });
132
133             AST::VariableDeclarations structVariableDeclarations;
134             structVariableDeclarations.append(WTFMove(structVariableDeclaration));
135             auto structDeclarationStatement = makeUniqueRef<AST::VariableDeclarationsStatement>(functionDefinition.origin(), WTFMove(structVariableDeclarations));
136
                // "&struct", used as the initializer of the "wrapper" pointer.
                // NOTE(review): the address-of expression is annotated as a left value here,
                // whereas the assignment built in assignVariableIntoStruct uses RightValue --
                // confirm this annotation is intended.
137             auto makePointerExpression = makeUniqueRef<AST::MakePointerExpression>(functionDefinition.origin(), WTFMove(structVariableReference));
138             makePointerExpression->setType(m_pointerToStructType->clone());
139             makePointerExpression->setTypeAnnotation(AST::LeftValue { AST::AddressSpace::Thread });
140
141             auto pointerDeclaration = makeUniqueRef<AST::VariableDeclaration>(functionDefinition.origin(), AST::Qualifiers(),
142                 m_pointerToStructType->clone(), "wrapper"_s, WTF::nullopt, Optional<UniqueRef<AST::Expression>>(WTFMove(makePointerExpression)));
                // Remember the declaration before its ownership moves into the statement below.
143             m_structVariable = &pointerDeclaration;
144
145             AST::VariableDeclarations pointerVariableDeclarations;
146             pointerVariableDeclarations.append(WTFMove(pointerDeclaration));
147             auto pointerDeclarationStatement = makeUniqueRef<AST::VariableDeclarationsStatement>(functionDefinition.origin(), WTFMove(pointerVariableDeclarations));
148
                // Struct declaration first, then the pointer that refers to it.
149             functionDefinition.block().statements().insert(0, WTFMove(structDeclarationStatement));
150             functionDefinition.block().statements().insert(1, WTFMove(pointerDeclarationStatement));
151         } else {
                // Non-entry points take the "wrapper" pointer as a new last parameter.
152             auto pointerDeclaration = makeUniqueRef<AST::VariableDeclaration>(functionDefinition.origin(), AST::Qualifiers(),
153                 m_pointerToStructType->clone(), "wrapper"_s, WTF::nullopt, WTF::nullopt);
                // Remember the declaration before its ownership moves into the parameter list.
154             m_structVariable = &pointerDeclaration;
155             functionDefinition.parameters().append(WTFMove(pointerDeclaration));
156         }
157
            // Rewrite the body; the overrides below rely on m_structVariable being set.
158         Base::visit(functionDefinition);
159
            // Store each escaped parameter into the struct. The stores are inserted after
            // the body has been visited. For entry points they go at index 2, i.e. after
            // the two statements inserted above; otherwise at index 0. Every insertion uses
            // the same index, so the stores end up in reverse parameter order.
160         for (auto& parameter : functionDefinition.parameters()) {
161             auto iter = m_variableMapping.find(&parameter);
162             if (iter == m_variableMapping.end())
163                 continue;
164
165             functionDefinition.block().statements().insert(isEntryPoint ? 2 : 0,
166                 makeUniqueRef<AST::EffectfulExpressionStatement>(assignVariableIntoStruct(parameter, iter->value)));
167         }
168
169         // Inner functions are not allowed in WHLSL. So this is fine.
170         m_structVariable = nullptr;
171     }
172
        // Visits the call's children, then appends the wrapper pointer as the final
        // argument when the callee is not a native function declaration.
173     void visit(AST::CallExpression& callExpression) override
174     {
175         RELEASE_ASSERT(m_structVariable);
176
177         Base::visit(callExpression);
178
179         // This works because it's illegal to call an entrypoint. Therefore, we can only
180         // call functions where we've already appended this struct as its final parameter.
181         if (!callExpression.function()->isNativeFunctionDeclaration())
182             callExpression.arguments().append(makeStructVariableReference());
183     }
184
        // Replaces a reference to a mapped variable with a GlobalVariableReference to
        // its struct field, keeping the reference's original type annotation.
        // References to variables that are not in the mapping are left untouched.
185     void visit(AST::VariableReference& variableReference) override
186     {
187         RELEASE_ASSERT(m_structVariable);
188
189         auto iter = m_variableMapping.find(variableReference.variable());
190         if (iter == m_variableMapping.end())
191             return;
192
            // Type and annotation are read before replaceWith, which presumably
            // overwrites variableReference in place -- confirm against AST::replaceWith.
193         auto type = variableReference.variable()->type()->clone();
194         AST::TypeAnnotation typeAnnotation = variableReference.typeAnnotation();
195         auto* internalField = AST::replaceWith<AST::GlobalVariableReference>(variableReference, variableReference.origin(), makeStructVariableReference(), iter->value);
196         internalField->setType(WTFMove(type));
197         internalField->setTypeAnnotation(WTFMove(typeAnnotation));
198     }
199
        // Visits the declarations, then splits the statement into one declaration
        // statement per variable, each followed by a store into the struct when the
        // variable is mapped. The original statement is replaced by a StatementList
        // holding the new statements.
200     void visit(AST::VariableDeclarationsStatement& variableDeclarationsStatement) override
201     {
202         RELEASE_ASSERT(m_structVariable);
203
204         Base::visit(variableDeclarationsStatement);
205
206         AST::Statements statements;
207         for (UniqueRef<AST::VariableDeclaration>& variableDeclaration : variableDeclarationsStatement.variableDeclarations()) {
                // Take a plain reference first; the owning UniqueRef is moved just below.
208             AST::VariableDeclaration& variable = variableDeclaration.get();
209
210             {
211                 AST::VariableDeclarations declarations;
212                 declarations.append(WTFMove(variableDeclaration));
213                 statements.append(makeUniqueRef<AST::VariableDeclarationsStatement>(variable.origin(), WTFMove(declarations)));
214             }
215
216             auto iter = m_variableMapping.find(&variable);
217             if (iter != m_variableMapping.end())
218                 statements.append(makeUniqueRef<AST::EffectfulExpressionStatement>(assignVariableIntoStruct(variable, iter->value)));
219         }
220
            // Copy the origin token before the statement is replaced.
221         auto origin = Lexer::Token(variableDeclarationsStatement.origin());
222         AST::replaceWith<AST::StatementList>(variableDeclarationsStatement, WTFMove(origin), WTFMove(statements));
223     }
224
225 private:
        // The "wrapper" pointer declaration of the function currently being visited;
        // null outside of visit(AST::FunctionDefinition&).
226     AST::VariableDeclaration* m_structVariable { nullptr };
227
        // Type of the wrapper struct (owned by the caller) and a pointer type to it.
228     UniqueRef<AST::TypeReference>& m_structType;
229     UniqueRef<AST::PointerType> m_pointerToStructType;
230     // If this mapping contains the variable, it means that the variable's canonical location
231     // is in the struct we use to preserve variable lifetimes.
232     const HashMap<AST::VariableDeclaration*, AST::StructureElement*>& m_variableMapping;
233 };
234
235 void preserveVariableLifetimes(Program& program)
236 {
237     HashMap<AST::VariableDeclaration*, String> escapedVariables;
238     {
239         EscapedVariableCollector collector;
240         collector.Visitor::visit(program);
241         escapedVariables = collector.takeEscapedVariables();
242     }
243
244     AST::StructureElements elements;
245     for (auto& pair : escapedVariables) {
246         auto* variable = pair.key;
247         String name = pair.value;
248         elements.append(AST::StructureElement { Lexer::Token(variable->origin()), { }, variable->type()->clone(), WTFMove(name), WTF::nullopt });
249     }
250
251     // Name of this doesn't matter, since we don't use struct names when
252     // generating Metal type names. We just pick something here to make it
253     // easy to read in AST dumps.
254     auto wrapperStructDefinition = makeUniqueRef<AST::StructureDefinition>(anonymousToken(Lexer::Token::Type::Struct), "__WrapperStruct__"_s, WTFMove(elements));
255
256     HashMap<AST::VariableDeclaration*, AST::StructureElement*> variableMapping;
257     unsigned index = 0;
258     for (auto& pair : escapedVariables)
259         variableMapping.add(pair.key, &wrapperStructDefinition->structureElements()[index++]);
260
261     {
262         auto wrapperStructType = AST::TypeReference::wrap(anonymousToken(Lexer::Token::Type::Identifier), wrapperStructDefinition);
263         PreserveLifetimes preserveLifetimes(wrapperStructType, variableMapping);
264         preserveLifetimes.Visitor::visit(program);
265     }
266
267     program.structureDefinitions().append(WTFMove(wrapperStructDefinition));
268 }
269
270 } // namespace WHLSL
271
272 } // namespace WebCore
273
274 #endif // ENABLE(WEBGPU)