Unreviewed, rolling out r249369.
[WebKit-https.git] / Source / WebCore / Modules / webgpu / WHLSL / WHLSLSemanticMatcher.cpp
1 /*
2  * Copyright (C) 2019 Apple Inc. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  * 1. Redistributions of source code must retain the above copyright
8  *    notice, this list of conditions and the following disclaimer.
9  * 2. Redistributions in binary form must reproduce the above copyright
10  *    notice, this list of conditions and the following disclaimer in the
11  *    documentation and/or other materials provided with the distribution.
12  *
13  * THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS ``AS IS''
14  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
15  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS
17  * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23  * THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 #include "config.h"
27 #include "WHLSLSemanticMatcher.h"
28
29 #if ENABLE(WEBGPU)
30
31 #include "WHLSLBuiltInSemantic.h"
32 #include "WHLSLFunctionDefinition.h"
33 #include "WHLSLGatherEntryPointItems.h"
34 #include "WHLSLInferTypes.h"
35 #include "WHLSLPipelineDescriptor.h"
36 #include "WHLSLProgram.h"
37 #include "WHLSLResourceSemantic.h"
38 #include "WHLSLStageInOutSemantic.h"
39 #include <wtf/HashMap.h>
40 #include <wtf/HashSet.h>
41 #include <wtf/Optional.h>
42 #include <wtf/text/WTFString.h>
43
44 namespace WebCore {
45
46 namespace WHLSL {
47
48 static bool matchMode(Binding::BindingDetails bindingType, AST::ResourceSemantic::Mode mode)
49 {
50     return WTF::visit(WTF::makeVisitor([&](UniformBufferBinding) -> bool {
51         return mode == AST::ResourceSemantic::Mode::Buffer;
52     }, [&](SamplerBinding) -> bool {
53         return mode == AST::ResourceSemantic::Mode::Sampler;
54     }, [&](TextureBinding) -> bool {
55         return mode == AST::ResourceSemantic::Mode::Texture;
56     }, [&](StorageBufferBinding) -> bool {
57         return mode == AST::ResourceSemantic::Mode::UnorderedAccessView;
58     }), bindingType);
59 }
60
61 static Optional<HashMap<Binding*, size_t>> matchResources(Vector<EntryPointItem>& entryPointItems, Layout& layout, ShaderStage shaderStage)
62 {
63     HashMap<Binding*, size_t> result;
64     HashSet<size_t> itemIndices;
65     if (entryPointItems.size() == std::numeric_limits<size_t>::max())
66         return WTF::nullopt; // Work around the fact that HashSet's keys are restricted.
67     for (auto& bindGroup : layout) {
68         auto space = bindGroup.name;
69         for (auto& binding : bindGroup.bindings) {
70             if (!binding.visibility.contains(shaderStage))
71                 continue;
72             for (size_t i = 0; i < entryPointItems.size(); ++i) {
73                 auto& item = entryPointItems[i];
74                 auto& semantic = *item.semantic;
75                 if (!WTF::holds_alternative<AST::ResourceSemantic>(semantic))
76                     continue;
77                 auto& resourceSemantic = WTF::get<AST::ResourceSemantic>(semantic);
78                 if (!matchMode(binding.binding, resourceSemantic.mode()))
79                     continue;
80                 if (binding.externalName != resourceSemantic.index())
81                     continue;
82                 if (space != resourceSemantic.space())
83                     continue;
84                 result.add(&binding, i);
85                 itemIndices.add(i + 1); // Work around the fact that HashSet's keys are restricted.
86             }
87         }
88     }
89
90     for (size_t i = 0; i < entryPointItems.size(); ++i) {
91         auto& item = entryPointItems[i];
92         auto& semantic = *item.semantic;
93         if (!WTF::holds_alternative<AST::ResourceSemantic>(semantic))
94             continue;
95         if (!itemIndices.contains(i + 1))
96             return WTF::nullopt;
97     }
98
99     return result;
100 }
101
102 static bool matchInputsOutputs(Vector<EntryPointItem>& vertexOutputs, Vector<EntryPointItem>& fragmentInputs)
103 {
104     for (auto& fragmentInput : fragmentInputs) {
105         if (!WTF::holds_alternative<AST::StageInOutSemantic>(*fragmentInput.semantic))
106             continue;
107         auto& fragmentInputStageInOutSemantic = WTF::get<AST::StageInOutSemantic>(*fragmentInput.semantic);
108         bool found = false;
109         for (auto& vertexOutput : vertexOutputs) {
110             if (!WTF::holds_alternative<AST::StageInOutSemantic>(*vertexOutput.semantic))
111                 continue;
112             auto& vertexOutputStageInOutSemantic = WTF::get<AST::StageInOutSemantic>(*vertexOutput.semantic);
113             if (fragmentInputStageInOutSemantic.index() == vertexOutputStageInOutSemantic.index()) {
114                 if (matches(*fragmentInput.unnamedType, *vertexOutput.unnamedType)) {
115                     found = true;
116                     break;
117                 }
118                 return false;
119             }
120         }
121         if (!found)
122             return false;
123     }
124     return true;
125 }
126
127 static bool isAcceptableFormat(VertexFormat vertexFormat, AST::UnnamedType& unnamedType, Intrinsics& intrinsics)
128 {
129     switch (vertexFormat) {
130     case VertexFormat::FloatR32G32B32A32:
131         return matches(unnamedType, intrinsics.float4Type());
132     case VertexFormat::FloatR32G32B32:
133         return matches(unnamedType, intrinsics.float3Type());
134     case VertexFormat::FloatR32G32:
135         return matches(unnamedType, intrinsics.float2Type());
136     default:
137         ASSERT(vertexFormat == VertexFormat::FloatR32);
138         return matches(unnamedType, intrinsics.floatType());
139     }
140 }
141
142 static Optional<HashMap<VertexAttribute*, size_t>> matchVertexAttributes(Vector<EntryPointItem>& vertexInputs, VertexAttributes& vertexAttributes, Intrinsics& intrinsics)
143 {
144     HashMap<VertexAttribute*, size_t> result;
145     HashSet<size_t> itemIndices;
146     if (vertexInputs.size() == std::numeric_limits<size_t>::max())
147         return WTF::nullopt; // Work around the fact that HashSet's keys are restricted.
148     for (auto& vertexAttribute : vertexAttributes) {
149         for (size_t i = 0; i < vertexInputs.size(); ++i) {
150             auto& item = vertexInputs[i];
151             auto& semantic = *item.semantic;
152             if (!WTF::holds_alternative<AST::StageInOutSemantic>(semantic))
153                 continue;
154             auto& stageInOutSemantic = WTF::get<AST::StageInOutSemantic>(semantic);
155             if (stageInOutSemantic.index() != vertexAttribute.shaderLocation)
156                 continue;
157             if (!isAcceptableFormat(vertexAttribute.vertexFormat, *item.unnamedType, intrinsics))
158                 return WTF::nullopt;
159             result.add(&vertexAttribute, i);
160             itemIndices.add(i + 1); // Work around the fact that HashSet's keys are restricted.
161         }
162     }
163
164     for (size_t i = 0; i < vertexInputs.size(); ++i) {
165         auto& item = vertexInputs[i];
166         auto& semantic = *item.semantic;
167         if (!WTF::holds_alternative<AST::StageInOutSemantic>(semantic))
168             continue;
169         if (!itemIndices.contains(i + 1))
170             return WTF::nullopt;
171     }
172
173     return result;
174 }
175
176 static bool isAcceptableFormat(TextureFormat textureFormat, AST::UnnamedType& unnamedType, Intrinsics& intrinsics, bool isColor)
177 {
178     if (isColor) {
179         switch (textureFormat) {
180         case TextureFormat::R8Unorm:
181         case TextureFormat::R8UnormSrgb:
182         case TextureFormat::R8Snorm:
183         case TextureFormat::R16Unorm:
184         case TextureFormat::R16Snorm:
185         case TextureFormat::R16Float:
186         case TextureFormat::R32Float:
187             return matches(unnamedType, intrinsics.floatType());
188         case TextureFormat::RG8Unorm:
189         case TextureFormat::RG8UnormSrgb:
190         case TextureFormat::RG8Snorm:
191         case TextureFormat::RG16Unorm:
192         case TextureFormat::RG16Snorm:
193         case TextureFormat::RG16Float:
194         case TextureFormat::RG32Float:
195             return matches(unnamedType, intrinsics.float2Type());
196         case TextureFormat::B5G6R5Unorm:
197         case TextureFormat::RG11B10Float:
198             return matches(unnamedType, intrinsics.float3Type());
199         case TextureFormat::RGBA8Unorm:
200         case TextureFormat::RGBA8UnormSrgb:
201         case TextureFormat::BGRA8Unorm:
202         case TextureFormat::BGRA8UnormSrgb:
203         case TextureFormat::RGBA8Snorm:
204         case TextureFormat::RGB10A2Unorm:
205         case TextureFormat::RGBA16Unorm:
206         case TextureFormat::RGBA16Snorm:
207         case TextureFormat::RGBA16Float:
208         case TextureFormat::RGBA32Float:
209             return matches(unnamedType, intrinsics.float4Type());
210         case TextureFormat::R32Uint:
211             return matches(unnamedType, intrinsics.uintType());
212         case TextureFormat::R32Sint:
213             return matches(unnamedType, intrinsics.intType());
214         case TextureFormat::RG32Uint:
215             return matches(unnamedType, intrinsics.uint2Type());
216         case TextureFormat::RG32Sint:
217             return matches(unnamedType, intrinsics.int2Type());
218         case TextureFormat::RGBA32Uint:
219             return matches(unnamedType, intrinsics.uint4Type());
220         case TextureFormat::RGBA32Sint:
221             return matches(unnamedType, intrinsics.int4Type());
222         default:
223             ASSERT_NOT_REACHED();
224             return false;
225         }
226     }
227     return false;
228 }
229
230 static Optional<HashMap<AttachmentDescriptor*, size_t>> matchColorAttachments(Vector<EntryPointItem>& fragmentOutputs, Vector<AttachmentDescriptor>& attachmentDescriptors, Intrinsics& intrinsics)
231 {
232     HashMap<AttachmentDescriptor*, size_t> result;
233     HashSet<size_t> itemIndices;
234     if (attachmentDescriptors.size() == std::numeric_limits<size_t>::max())
235         return WTF::nullopt; // Work around the fact that HashSet's keys are restricted.
236     for (auto& attachmentDescriptor : attachmentDescriptors) {
237         for (size_t i = 0; i < fragmentOutputs.size(); ++i) {
238             auto& item = fragmentOutputs[i];
239             auto& semantic = *item.semantic;
240             if (!WTF::holds_alternative<AST::StageInOutSemantic>(semantic))
241                 continue;
242             auto& stageInOutSemantic = WTF::get<AST::StageInOutSemantic>(semantic);
243             if (stageInOutSemantic.index() != attachmentDescriptor.name)
244                 continue;
245             if (!isAcceptableFormat(attachmentDescriptor.textureFormat, *item.unnamedType, intrinsics, true))
246                 return WTF::nullopt;
247             result.add(&attachmentDescriptor, i);
248             itemIndices.add(i + 1); // Work around the fact that HashSet's keys are restricted.
249         }
250     }
251
252     for (size_t i = 0; i < fragmentOutputs.size(); ++i) {
253         auto& item = fragmentOutputs[i];
254         auto& semantic = *item.semantic;
255         if (!WTF::holds_alternative<AST::StageInOutSemantic>(semantic))
256             continue;
257         if (!itemIndices.contains(i + 1))
258             return WTF::nullopt;
259     }
260
261     return result;
262 }
263
264 static bool matchDepthAttachment(Vector<EntryPointItem>& fragmentOutputs, Optional<AttachmentDescriptor>& depthStencilAttachmentDescriptor, Intrinsics& intrinsics)
265 {
266     auto iterator = std::find_if(fragmentOutputs.begin(), fragmentOutputs.end(), [&](EntryPointItem& item) {
267         auto& semantic = *item.semantic;
268         if (!WTF::holds_alternative<AST::BuiltInSemantic>(semantic))
269             return false;
270         auto& builtInSemantic = WTF::get<AST::BuiltInSemantic>(semantic);
271         return builtInSemantic.variable() == AST::BuiltInSemantic::Variable::SVDepth;
272     });
273     if (iterator == fragmentOutputs.end())
274         return true;
275
276     if (depthStencilAttachmentDescriptor) {
277         ASSERT(!depthStencilAttachmentDescriptor->name);
278         return isAcceptableFormat(depthStencilAttachmentDescriptor->textureFormat, *iterator->unnamedType, intrinsics, false);
279     }
280     return false;
281 }
282
283 Optional<MatchedRenderSemantics> matchSemantics(Program& program, RenderPipelineDescriptor& renderPipelineDescriptor, bool distinctFragmentShader, bool fragmentShaderExists)
284 {
285     auto vertexFunctions = program.nameContext().getFunctions(renderPipelineDescriptor.vertexEntryPointName, AST::NameSpace::NameSpace1);
286     if (vertexFunctions.size() != 1 || !vertexFunctions[0].get().entryPointType() || !is<AST::FunctionDefinition>(vertexFunctions[0].get()))
287         return WTF::nullopt;
288     auto& vertexShaderEntryPoint = downcast<AST::FunctionDefinition>(vertexFunctions[0].get());
289     auto vertexShaderEntryPointItems = gatherEntryPointItems(program.intrinsics(), vertexShaderEntryPoint);
290     if (!vertexShaderEntryPointItems)
291         return WTF::nullopt;
292     auto vertexShaderResourceMap = matchResources(vertexShaderEntryPointItems->inputs, renderPipelineDescriptor.layout, ShaderStage::Vertex);
293     if (!vertexShaderResourceMap)
294         return WTF::nullopt;
295     auto matchedVertexAttributes = matchVertexAttributes(vertexShaderEntryPointItems->inputs, renderPipelineDescriptor.vertexAttributes, program.intrinsics());
296     if (!matchedVertexAttributes)
297         return WTF::nullopt;
298     if (!fragmentShaderExists)
299         return {{ &vertexShaderEntryPoint, nullptr, *vertexShaderEntryPointItems, EntryPointItems(), *vertexShaderResourceMap, HashMap<Binding*, size_t>(), *matchedVertexAttributes, HashMap<AttachmentDescriptor*, size_t>() }};
300
301     auto fragmentNameSpace = distinctFragmentShader ? AST::NameSpace::NameSpace2 : AST::NameSpace::NameSpace1;
302     auto fragmentFunctions = program.nameContext().getFunctions(renderPipelineDescriptor.fragmentEntryPointName, fragmentNameSpace);
303     if (fragmentFunctions.size() != 1 || !fragmentFunctions[0].get().entryPointType() || !is<AST::FunctionDefinition>(fragmentFunctions[0].get()))
304         return WTF::nullopt;
305     auto& fragmentShaderEntryPoint = downcast<AST::FunctionDefinition>(fragmentFunctions[0].get());
306     auto fragmentShaderEntryPointItems = gatherEntryPointItems(program.intrinsics(), fragmentShaderEntryPoint);
307     if (!fragmentShaderEntryPointItems)
308         return WTF::nullopt;
309     auto fragmentShaderResourceMap = matchResources(fragmentShaderEntryPointItems->inputs, renderPipelineDescriptor.layout, ShaderStage::Fragment);
310     if (!fragmentShaderResourceMap)
311         return WTF::nullopt;
312     if (!matchInputsOutputs(vertexShaderEntryPointItems->outputs, fragmentShaderEntryPointItems->inputs))
313         return WTF::nullopt;
314     auto matchedColorAttachments = matchColorAttachments(fragmentShaderEntryPointItems->outputs, renderPipelineDescriptor.attachmentsStateDescriptor.attachmentDescriptors, program.intrinsics());
315     if (!matchedColorAttachments)
316         return WTF::nullopt;
317     if (!matchDepthAttachment(fragmentShaderEntryPointItems->outputs, renderPipelineDescriptor.attachmentsStateDescriptor.depthStencilAttachmentDescriptor, program.intrinsics()))
318         return WTF::nullopt;
319     return {{ &vertexShaderEntryPoint, &fragmentShaderEntryPoint, *vertexShaderEntryPointItems, *fragmentShaderEntryPointItems, *vertexShaderResourceMap, *fragmentShaderResourceMap, *matchedVertexAttributes, *matchedColorAttachments }};
320 }
321
322 Optional<MatchedComputeSemantics> matchSemantics(Program& program, ComputePipelineDescriptor& computePipelineDescriptor)
323 {
324     auto functions = program.nameContext().getFunctions(computePipelineDescriptor.entryPointName, AST::NameSpace::NameSpace1);
325     if (functions.size() != 1 || !functions[0].get().entryPointType() || !is<AST::FunctionDefinition>(functions[0].get()))
326         return WTF::nullopt;
327     auto& entryPoint = downcast<AST::FunctionDefinition>(functions[0].get());
328     auto entryPointItems = gatherEntryPointItems(program.intrinsics(), entryPoint);
329     if (!entryPointItems)
330         return WTF::nullopt;
331     ASSERT(entryPointItems->outputs.isEmpty());
332     auto resourceMap = matchResources(entryPointItems->inputs, computePipelineDescriptor.layout, ShaderStage::Compute);
333     if (!resourceMap)
334         return WTF::nullopt;
335     return {{ &entryPoint, *entryPointItems, *resourceMap }};
336 }
337
338 } // namespace WHLSL
339
340 } // namespace WebCore
341
342 #endif // ENABLE(WEBGPU)