542fa785b2808deaed05e165fa54089c8e4f8c39
[WebKit-https.git] / Source / WebCore / Modules / webgpu / WHLSL / WHLSLSemanticMatcher.cpp
1 /*
2  * Copyright (C) 2019 Apple Inc. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are permitted provided that the following conditions
6  * are met:
7  * 1. Redistributions of source code must retain the above copyright
8  *    notice, this list of conditions and the following disclaimer.
9  * 2. Redistributions in binary form must reproduce the above copyright
10  *    notice, this list of conditions and the following disclaimer in the
11  *    documentation and/or other materials provided with the distribution.
12  *
13  * THIS SOFTWARE IS PROVIDED BY APPLE INC. AND ITS CONTRIBUTORS ``AS IS''
14  * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO,
15  * THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
16  * PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL APPLE INC. OR ITS CONTRIBUTORS
17  * BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
18  * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
19  * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
20  * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
21  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
22  * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
23  * THE POSSIBILITY OF SUCH DAMAGE.
24  */
25
26 #include "config.h"
27 #include "WHLSLSemanticMatcher.h"
28
29 #if ENABLE(WEBGPU)
30
31 #include "WHLSLBuiltInSemantic.h"
32 #include "WHLSLFunctionDefinition.h"
33 #include "WHLSLGatherEntryPointItems.h"
34 #include "WHLSLInferTypes.h"
35 #include "WHLSLPipelineDescriptor.h"
36 #include "WHLSLProgram.h"
37 #include "WHLSLResourceSemantic.h"
38 #include "WHLSLStageInOutSemantic.h"
39 #include <wtf/HashMap.h>
40 #include <wtf/HashSet.h>
41 #include <wtf/Optional.h>
42 #include <wtf/text/WTFString.h>
43
44 namespace WebCore {
45
46 namespace WHLSL {
47
48 static bool matchMode(Binding::BindingDetails bindingType, AST::ResourceSemantic::Mode mode)
49 {
50     return WTF::visit(WTF::makeVisitor([&](UniformBufferBinding) -> bool {
51         return mode == AST::ResourceSemantic::Mode::Buffer;
52     }, [&](SamplerBinding) -> bool {
53         return mode == AST::ResourceSemantic::Mode::Sampler;
54     }, [&](TextureBinding) -> bool {
55         return mode == AST::ResourceSemantic::Mode::Texture;
56     }, [&](StorageBufferBinding) -> bool {
57         return mode == AST::ResourceSemantic::Mode::UnorderedAccessView;
58     }), bindingType);
59 }
60
61 static Optional<HashMap<Binding*, size_t>> matchResources(Vector<EntryPointItem>& entryPointItems, Layout& layout, ShaderStage shaderStage)
62 {
63     HashMap<Binding*, size_t> result;
64     HashSet<size_t, DefaultHash<size_t>::Hash, WTF::UnsignedWithZeroKeyHashTraits<size_t>> itemIndices;
65     for (auto& bindGroup : layout) {
66         auto space = bindGroup.name;
67         for (auto& binding : bindGroup.bindings) {
68             if (!binding.visibility.contains(shaderStage))
69                 continue;
70             for (size_t i = 0; i < entryPointItems.size(); ++i) {
71                 auto& item = entryPointItems[i];
72                 auto& semantic = *item.semantic;
73                 if (!WTF::holds_alternative<AST::ResourceSemantic>(semantic))
74                     continue;
75                 auto& resourceSemantic = WTF::get<AST::ResourceSemantic>(semantic);
76                 if (!matchMode(binding.binding, resourceSemantic.mode()))
77                     continue;
78                 if (binding.externalName != resourceSemantic.index())
79                     continue;
80                 if (space != resourceSemantic.space())
81                     continue;
82                 result.add(&binding, i);
83                 itemIndices.add(i);
84             }
85         }
86     }
87
88     for (size_t i = 0; i < entryPointItems.size(); ++i) {
89         auto& item = entryPointItems[i];
90         auto& semantic = *item.semantic;
91         if (!WTF::holds_alternative<AST::ResourceSemantic>(semantic))
92             continue;
93         if (!itemIndices.contains(i))
94             return WTF::nullopt;
95     }
96
97     return result;
98 }
99
100 static bool matchInputsOutputs(Vector<EntryPointItem>& vertexOutputs, Vector<EntryPointItem>& fragmentInputs)
101 {
102     for (auto& fragmentInput : fragmentInputs) {
103         if (!WTF::holds_alternative<AST::StageInOutSemantic>(*fragmentInput.semantic))
104             continue;
105         auto& fragmentInputStageInOutSemantic = WTF::get<AST::StageInOutSemantic>(*fragmentInput.semantic);
106         bool found = false;
107         for (auto& vertexOutput : vertexOutputs) {
108             if (!WTF::holds_alternative<AST::StageInOutSemantic>(*vertexOutput.semantic))
109                 continue;
110             auto& vertexOutputStageInOutSemantic = WTF::get<AST::StageInOutSemantic>(*vertexOutput.semantic);
111             if (fragmentInputStageInOutSemantic.index() == vertexOutputStageInOutSemantic.index()) {
112                 if (matches(*fragmentInput.unnamedType, *vertexOutput.unnamedType)) {
113                     found = true;
114                     break;
115                 }
116                 return false;
117             }
118         }
119         if (!found)
120             return false;
121     }
122     return true;
123 }
124
125 static bool isAcceptableFormat(VertexFormat vertexFormat, AST::UnnamedType& unnamedType, Intrinsics& intrinsics)
126 {
127     switch (vertexFormat) {
128     case VertexFormat::FloatR32G32B32A32:
129         return matches(unnamedType, intrinsics.float4Type());
130     case VertexFormat::FloatR32G32B32:
131         return matches(unnamedType, intrinsics.float3Type());
132     case VertexFormat::FloatR32G32:
133         return matches(unnamedType, intrinsics.float2Type());
134     default:
135         ASSERT(vertexFormat == VertexFormat::FloatR32);
136         return matches(unnamedType, intrinsics.floatType());
137     }
138 }
139
140 static Optional<HashMap<VertexAttribute*, size_t>> matchVertexAttributes(Vector<EntryPointItem>& vertexInputs, VertexAttributes& vertexAttributes, Intrinsics& intrinsics)
141 {
142     HashMap<VertexAttribute*, size_t> result;
143     HashSet<size_t, DefaultHash<size_t>::Hash, WTF::UnsignedWithZeroKeyHashTraits<size_t>> itemIndices;
144     for (auto& vertexAttribute : vertexAttributes) {
145         for (size_t i = 0; i < vertexInputs.size(); ++i) {
146             auto& item = vertexInputs[i];
147             auto& semantic = *item.semantic;
148             if (!WTF::holds_alternative<AST::StageInOutSemantic>(semantic))
149                 continue;
150             auto& stageInOutSemantic = WTF::get<AST::StageInOutSemantic>(semantic);
151             if (stageInOutSemantic.index() != vertexAttribute.shaderLocation)
152                 continue;
153             if (!isAcceptableFormat(vertexAttribute.vertexFormat, *item.unnamedType, intrinsics))
154                 return WTF::nullopt;
155             result.add(&vertexAttribute, i);
156             itemIndices.add(i);
157         }
158     }
159
160     for (size_t i = 0; i < vertexInputs.size(); ++i) {
161         auto& item = vertexInputs[i];
162         auto& semantic = *item.semantic;
163         if (!WTF::holds_alternative<AST::StageInOutSemantic>(semantic))
164             continue;
165         if (!itemIndices.contains(i))
166             return WTF::nullopt;
167     }
168
169     return result;
170 }
171
172 static bool isAcceptableFormat(TextureFormat textureFormat, AST::UnnamedType& unnamedType, Intrinsics& intrinsics, bool isColor)
173 {
174     if (isColor) {
175         switch (textureFormat) {
176         case TextureFormat::R8Unorm:
177         case TextureFormat::R8UnormSrgb:
178         case TextureFormat::R8Snorm:
179         case TextureFormat::R16Unorm:
180         case TextureFormat::R16Snorm:
181         case TextureFormat::R16Float:
182         case TextureFormat::R32Float:
183             return matches(unnamedType, intrinsics.floatType());
184         case TextureFormat::RG8Unorm:
185         case TextureFormat::RG8UnormSrgb:
186         case TextureFormat::RG8Snorm:
187         case TextureFormat::RG16Unorm:
188         case TextureFormat::RG16Snorm:
189         case TextureFormat::RG16Float:
190         case TextureFormat::RG32Float:
191             return matches(unnamedType, intrinsics.float2Type());
192         case TextureFormat::B5G6R5Unorm:
193         case TextureFormat::RG11B10Float:
194             return matches(unnamedType, intrinsics.float3Type());
195         case TextureFormat::RGBA8Unorm:
196         case TextureFormat::RGBA8UnormSrgb:
197         case TextureFormat::BGRA8Unorm:
198         case TextureFormat::BGRA8UnormSrgb:
199         case TextureFormat::RGBA8Snorm:
200         case TextureFormat::RGB10A2Unorm:
201         case TextureFormat::RGBA16Unorm:
202         case TextureFormat::RGBA16Snorm:
203         case TextureFormat::RGBA16Float:
204         case TextureFormat::RGBA32Float:
205             return matches(unnamedType, intrinsics.float4Type());
206         case TextureFormat::R32Uint:
207             return matches(unnamedType, intrinsics.uintType());
208         case TextureFormat::R32Sint:
209             return matches(unnamedType, intrinsics.intType());
210         case TextureFormat::RG32Uint:
211             return matches(unnamedType, intrinsics.uint2Type());
212         case TextureFormat::RG32Sint:
213             return matches(unnamedType, intrinsics.int2Type());
214         case TextureFormat::RGBA32Uint:
215             return matches(unnamedType, intrinsics.uint4Type());
216         case TextureFormat::RGBA32Sint:
217             return matches(unnamedType, intrinsics.int4Type());
218         default:
219             ASSERT_NOT_REACHED();
220             return false;
221         }
222     }
223     return false;
224 }
225
226 static Optional<HashMap<AttachmentDescriptor*, size_t>> matchColorAttachments(Vector<EntryPointItem>& fragmentOutputs, Vector<AttachmentDescriptor>& attachmentDescriptors, Intrinsics& intrinsics)
227 {
228     HashMap<AttachmentDescriptor*, size_t> result;
229     HashSet<size_t, DefaultHash<size_t>::Hash, WTF::UnsignedWithZeroKeyHashTraits<size_t>> itemIndices;
230     for (auto& attachmentDescriptor : attachmentDescriptors) {
231         for (size_t i = 0; i < fragmentOutputs.size(); ++i) {
232             auto& item = fragmentOutputs[i];
233             auto& semantic = *item.semantic;
234             if (!WTF::holds_alternative<AST::StageInOutSemantic>(semantic))
235                 continue;
236             auto& stageInOutSemantic = WTF::get<AST::StageInOutSemantic>(semantic);
237             if (stageInOutSemantic.index() != attachmentDescriptor.name)
238                 continue;
239             if (!isAcceptableFormat(attachmentDescriptor.textureFormat, *item.unnamedType, intrinsics, true))
240                 return WTF::nullopt;
241             result.add(&attachmentDescriptor, i);
242             itemIndices.add(i);
243         }
244     }
245
246     for (size_t i = 0; i < fragmentOutputs.size(); ++i) {
247         auto& item = fragmentOutputs[i];
248         auto& semantic = *item.semantic;
249         if (!WTF::holds_alternative<AST::StageInOutSemantic>(semantic))
250             continue;
251         if (!itemIndices.contains(i))
252             return WTF::nullopt;
253     }
254
255     return result;
256 }
257
258 static bool matchDepthAttachment(Vector<EntryPointItem>& fragmentOutputs, Optional<AttachmentDescriptor>& depthStencilAttachmentDescriptor, Intrinsics& intrinsics)
259 {
260     auto iterator = std::find_if(fragmentOutputs.begin(), fragmentOutputs.end(), [&](EntryPointItem& item) {
261         auto& semantic = *item.semantic;
262         if (!WTF::holds_alternative<AST::BuiltInSemantic>(semantic))
263             return false;
264         auto& builtInSemantic = WTF::get<AST::BuiltInSemantic>(semantic);
265         return builtInSemantic.variable() == AST::BuiltInSemantic::Variable::SVDepth;
266     });
267     if (iterator == fragmentOutputs.end())
268         return true;
269
270     if (depthStencilAttachmentDescriptor) {
271         ASSERT(!depthStencilAttachmentDescriptor->name);
272         return isAcceptableFormat(depthStencilAttachmentDescriptor->textureFormat, *iterator->unnamedType, intrinsics, false);
273     }
274     return false;
275 }
276
277 Optional<MatchedRenderSemantics> matchSemantics(Program& program, RenderPipelineDescriptor& renderPipelineDescriptor, bool distinctFragmentShader, bool fragmentShaderExists)
278 {
279     auto vertexFunctions = program.nameContext().getFunctions(renderPipelineDescriptor.vertexEntryPointName, AST::NameSpace::NameSpace1);
280     if (vertexFunctions.size() != 1 || !vertexFunctions[0].get().entryPointType() || !is<AST::FunctionDefinition>(vertexFunctions[0].get()))
281         return WTF::nullopt;
282     auto& vertexShaderEntryPoint = downcast<AST::FunctionDefinition>(vertexFunctions[0].get());
283     auto vertexShaderEntryPointItems = gatherEntryPointItems(program.intrinsics(), vertexShaderEntryPoint);
284     if (!vertexShaderEntryPointItems)
285         return WTF::nullopt;
286     auto vertexShaderResourceMap = matchResources(vertexShaderEntryPointItems->inputs, renderPipelineDescriptor.layout, ShaderStage::Vertex);
287     if (!vertexShaderResourceMap)
288         return WTF::nullopt;
289     auto matchedVertexAttributes = matchVertexAttributes(vertexShaderEntryPointItems->inputs, renderPipelineDescriptor.vertexAttributes, program.intrinsics());
290     if (!matchedVertexAttributes)
291         return WTF::nullopt;
292     if (!fragmentShaderExists)
293         return {{ &vertexShaderEntryPoint, nullptr, *vertexShaderEntryPointItems, EntryPointItems(), *vertexShaderResourceMap, HashMap<Binding*, size_t>(), *matchedVertexAttributes, HashMap<AttachmentDescriptor*, size_t>() }};
294
295     auto fragmentNameSpace = distinctFragmentShader ? AST::NameSpace::NameSpace2 : AST::NameSpace::NameSpace1;
296     auto fragmentFunctions = program.nameContext().getFunctions(renderPipelineDescriptor.fragmentEntryPointName, fragmentNameSpace);
297     if (fragmentFunctions.size() != 1 || !fragmentFunctions[0].get().entryPointType() || !is<AST::FunctionDefinition>(fragmentFunctions[0].get()))
298         return WTF::nullopt;
299     auto& fragmentShaderEntryPoint = downcast<AST::FunctionDefinition>(fragmentFunctions[0].get());
300     auto fragmentShaderEntryPointItems = gatherEntryPointItems(program.intrinsics(), fragmentShaderEntryPoint);
301     if (!fragmentShaderEntryPointItems)
302         return WTF::nullopt;
303     auto fragmentShaderResourceMap = matchResources(fragmentShaderEntryPointItems->inputs, renderPipelineDescriptor.layout, ShaderStage::Fragment);
304     if (!fragmentShaderResourceMap)
305         return WTF::nullopt;
306     if (!matchInputsOutputs(vertexShaderEntryPointItems->outputs, fragmentShaderEntryPointItems->inputs))
307         return WTF::nullopt;
308     auto matchedColorAttachments = matchColorAttachments(fragmentShaderEntryPointItems->outputs, renderPipelineDescriptor.attachmentsStateDescriptor.attachmentDescriptors, program.intrinsics());
309     if (!matchedColorAttachments)
310         return WTF::nullopt;
311     if (!matchDepthAttachment(fragmentShaderEntryPointItems->outputs, renderPipelineDescriptor.attachmentsStateDescriptor.depthStencilAttachmentDescriptor, program.intrinsics()))
312         return WTF::nullopt;
313     return {{ &vertexShaderEntryPoint, &fragmentShaderEntryPoint, *vertexShaderEntryPointItems, *fragmentShaderEntryPointItems, *vertexShaderResourceMap, *fragmentShaderResourceMap, *matchedVertexAttributes, *matchedColorAttachments }};
314 }
315
316 Optional<MatchedComputeSemantics> matchSemantics(Program& program, ComputePipelineDescriptor& computePipelineDescriptor)
317 {
318     auto functions = program.nameContext().getFunctions(computePipelineDescriptor.entryPointName, AST::NameSpace::NameSpace1);
319     if (functions.size() != 1 || !functions[0].get().entryPointType() || !is<AST::FunctionDefinition>(functions[0].get()))
320         return WTF::nullopt;
321     auto& entryPoint = downcast<AST::FunctionDefinition>(functions[0].get());
322     auto entryPointItems = gatherEntryPointItems(program.intrinsics(), entryPoint);
323     if (!entryPointItems)
324         return WTF::nullopt;
325     ASSERT(entryPointItems->outputs.isEmpty());
326     auto resourceMap = matchResources(entryPointItems->inputs, computePipelineDescriptor.layout, ShaderStage::Compute);
327     if (!resourceMap)
328         return WTF::nullopt;
329     return {{ &entryPoint, *entryPointItems, *resourceMap }};
330 }
331
332 } // namespace WHLSL
333
334 } // namespace WebCore
335
336 #endif // ENABLE(WEBGPU)