[WHLSL] Hook up the compiler to our WebGPU implementation
authormmaxfield@apple.com <mmaxfield@apple.com@268f45cc-cd09-0410-ab3c-d52691b4dbfc>
Mon, 18 Mar 2019 19:23:42 +0000 (19:23 +0000)
committermmaxfield@apple.com <mmaxfield@apple.com@268f45cc-cd09-0410-ab3c-d52691b4dbfc>
Mon, 18 Mar 2019 19:23:42 +0000 (19:23 +0000)
https://bugs.webkit.org/show_bug.cgi?id=195509

Reviewed by Dean Jackson.

Source/WebCore:

This represents a collection of changes necessary to compile and run the first WHLSL program in WebKit.

Because WHLSL isn't fully implemented yet, this patch doesn't remove the existing method for supplying
Metal shaders to WebGPU. Instead, it adds a new boolean to WebGPUShaderModuleDescriptor, "isWHLSL" which
causes us to run the WHLSL compiler.

More details below.

Test: webgpu/whlsl.html

* Modules/webgpu/WHLSL/AST/WHLSLCallExpression.h: Use raw pointer instead of Optional<std::reference_wrapper>.
(WebCore::WHLSL::AST::CallExpression::setCastData):
(WebCore::WHLSL::AST::CallExpression::isCast):
(WebCore::WHLSL::AST::CallExpression::castReturnType):
* Modules/webgpu/WHLSL/AST/WHLSLNativeTypeDeclaration.h:
(WebCore::WHLSL::AST::NativeTypeDeclaration::isAtomic const):
(WebCore::WHLSL::AST::NativeTypeDeclaration::setIsAtomic):
(WebCore::WHLSL::AST::NativeTypeDeclaration::name const): Deleted. The parent class already has a name string.
(WebCore::WHLSL::AST::NativeTypeDeclaration::name): Deleted.
* Modules/webgpu/WHLSL/AST/WHLSLReturn.h:
* Modules/webgpu/WHLSL/AST/WHLSLTypeReference.h:
(WebCore::WHLSL::AST::TypeReference::cloneTypeReference const): When cloning a type reference, make sure to
clone the pointer to its resolved type, too.
* Modules/webgpu/WHLSL/AST/WHLSLVariableReference.h:
* Modules/webgpu/WHLSL/Metal/WHLSLEntryPointScaffolding.cpp: Incorporate resolution from
https://github.com/gpuweb/gpuweb/pull/188.
(WebCore::WHLSL::Metal::EntryPointScaffolding::EntryPointScaffolding):
(WebCore::WHLSL::Metal::EntryPointScaffolding::resourceHelperTypes):
(WebCore::WHLSL::Metal::EntryPointScaffolding::resourceSignature):
(WebCore::WHLSL::Metal::VertexEntryPointScaffolding::helperTypes):
(WebCore::WHLSL::Metal::VertexEntryPointScaffolding::unpack):
(WebCore::WHLSL::Metal::VertexEntryPointScaffolding::pack): Support semantics being placed directly on the
entry point, instead of being placed on a structure member.
(WebCore::WHLSL::Metal::FragmentEntryPointScaffolding::helperTypes):
(WebCore::WHLSL::Metal::FragmentEntryPointScaffolding::pack): Ditto.
(WebCore::WHLSL::Metal::EntryPointScaffolding::mappedBindGroups const): Deleted.
* Modules/webgpu/WHLSL/Metal/WHLSLEntryPointScaffolding.h:
* Modules/webgpu/WHLSL/Metal/WHLSLFunctionWriter.cpp:
(WebCore::WHLSL::Metal::FunctionDefinitionWriter::visit):
(WebCore::WHLSL::Metal::RenderFunctionDefinitionWriter::createEntryPointScaffolding):
(WebCore::WHLSL::Metal::ComputeFunctionDefinitionWriter::createEntryPointScaffolding):
(WebCore::WHLSL::Metal::metalFunctions):
(WebCore::WHLSL::Metal::RenderFunctionDefinitionWriter::takeVertexMappedBindGroups): Deleted. After
https://github.com/gpuweb/gpuweb/pull/188, we don't need the mappings.
(WebCore::WHLSL::Metal::RenderFunctionDefinitionWriter::takeFragmentMappedBindGroups): Deleted. Ditto.
(WebCore::WHLSL::Metal::ComputeFunctionDefinitionWriter::takeMappedBindGroups): Deleted. Ditto.
* Modules/webgpu/WHLSL/Metal/WHLSLFunctionWriter.h: Ditto.
* Modules/webgpu/WHLSL/Metal/WHLSLMetalCodeGenerator.cpp: Ditto.
(WebCore::WHLSL::Metal::generateMetalCodeShared):
(WebCore::WHLSL::Metal::generateMetalCode):
* Modules/webgpu/WHLSL/Metal/WHLSLMetalCodeGenerator.h: Ditto.
* Modules/webgpu/WHLSL/Metal/WHLSLNativeFunctionWriter.cpp: Support compiler-generated functions. Change
CRASH() to notImplemented().
(WebCore::WHLSL::Metal::writeNativeFunction):
(WebCore::WHLSL::Metal::getNativeName): Deleted.
* Modules/webgpu/WHLSL/Metal/WHLSLNativeFunctionWriter.h:
* Modules/webgpu/WHLSL/Metal/WHLSLNativeTypeWriter.cpp:
(WebCore::WHLSL::Metal::writeNativeType):
* Modules/webgpu/WHLSL/Metal/WHLSLTypeNamer.cpp: The dependency graph needs to track all unnamed types. Also,
we need to track types that are the results of expressions (not just types literally spelled out in the
program). Enumerations need to be emitted after their base types are emitted.
(WebCore::WHLSL::Metal::TypeNamer::visit):
(WebCore::WHLSL::Metal::MetalTypeDeclarationWriter::MetalTypeDeclarationWriter):
(WebCore::WHLSL::Metal::TypeNamer::metalTypeDeclarations):
(WebCore::WHLSL::Metal::TypeNamer::emitUnnamedTypeDefinition):
(WebCore::WHLSL::Metal::TypeNamer::emitNamedTypeDefinition):
(WebCore::WHLSL::Metal::TypeNamer::emitAllUnnamedTypeDefinitions):
(WebCore::WHLSL::Metal::TypeNamer::metalTypeDefinitions):
* Modules/webgpu/WHLSL/Metal/WHLSLTypeNamer.h:
* Modules/webgpu/WHLSL/WHLSLCheckDuplicateFunctions.cpp:
(WebCore::WHLSL::checkDuplicateFunctions):
* Modules/webgpu/WHLSL/WHLSLChecker.cpp: Wrap ResolvingType in a class to make sure it plays nicely with
HashMap. Also, use raw pointers instead of Optional<std::reference_wrapper>s.
(WebCore::WHLSL::resolveWithReferenceComparator):
(WebCore::WHLSL::resolveByInstantiation):
(WebCore::WHLSL::checkOperatorOverload):
(WebCore::WHLSL::Checker::assignTypes):
(WebCore::WHLSL::Checker::checkShaderType):
(WebCore::WHLSL::Checker::visit):
(WebCore::WHLSL::matchAndCommit):
(WebCore::WHLSL::Checker::recurseAndGetInfo):
(WebCore::WHLSL::Checker::assignType):
(WebCore::WHLSL::Checker::forwardType):
(WebCore::WHLSL::getUnnamedType):
(WebCore::WHLSL::Checker::finishVisitingPropertyAccess):
(WebCore::WHLSL::Checker::isBoolType):
* Modules/webgpu/WHLSL/WHLSLGatherEntryPointItems.cpp:
(WebCore::WHLSL::Gatherer::visit):
* Modules/webgpu/WHLSL/WHLSLInferTypes.cpp:
(WebCore::WHLSL::inferTypesForCall):
* Modules/webgpu/WHLSL/WHLSLInferTypes.h:
* Modules/webgpu/WHLSL/WHLSLIntrinsics.cpp:
(WebCore::WHLSL::Intrinsics::addPrimitive):
(WebCore::WHLSL::Intrinsics::addFullTexture):
* Modules/webgpu/WHLSL/WHLSLIntrinsics.h:
(WebCore::WHLSL::Intrinsics::ucharType const):
(WebCore::WHLSL::Intrinsics::ushortType const):
(WebCore::WHLSL::Intrinsics::charType const):
(WebCore::WHLSL::Intrinsics::shortType const):
(WebCore::WHLSL::Intrinsics::intType const):
(WebCore::WHLSL::Intrinsics::uchar2Type const):
(WebCore::WHLSL::Intrinsics::uchar4Type const):
(WebCore::WHLSL::Intrinsics::ushort2Type const):
(WebCore::WHLSL::Intrinsics::ushort4Type const):
(WebCore::WHLSL::Intrinsics::uint2Type const):
(WebCore::WHLSL::Intrinsics::uint4Type const):
(WebCore::WHLSL::Intrinsics::char2Type const):
(WebCore::WHLSL::Intrinsics::char4Type const):
(WebCore::WHLSL::Intrinsics::short2Type const):
(WebCore::WHLSL::Intrinsics::short4Type const):
(WebCore::WHLSL::Intrinsics::int2Type const):
(WebCore::WHLSL::Intrinsics::int4Type const):
* Modules/webgpu/WHLSL/WHLSLLexer.cpp:
(WebCore::WHLSL::Lexer::recognizeKeyword):
* Modules/webgpu/WHLSL/WHLSLNameContext.cpp:
(WebCore::WHLSL::NameContext::add):
* Modules/webgpu/WHLSL/WHLSLNameResolver.cpp:
(WebCore::WHLSL::NameResolver::visit): Don't visit recursive types.
Also, make sure we preserve the CurrentFunction in our recursive scopes.
* Modules/webgpu/WHLSL/WHLSLNameResolver.h:
* Modules/webgpu/WHLSL/WHLSLParser.cpp:
(WebCore::WHLSL::Parser::fail):
(WebCore::WHLSL::Parser::peek):
(WebCore::WHLSL::Parser::parseType):
(WebCore::WHLSL::Parser::parseBuiltInSemantic):
* Modules/webgpu/WHLSL/WHLSLParser.h:
* Modules/webgpu/WHLSL/WHLSLPipelineDescriptor.h:
* Modules/webgpu/WHLSL/WHLSLPrepare.cpp:
(WebCore::WHLSL::prepareShared):
(WebCore::WHLSL::prepare):
* Modules/webgpu/WHLSL/WHLSLPrepare.h:
* Modules/webgpu/WHLSL/WHLSLRecursiveTypeChecker.cpp: Move big inline functions out-of-line.
(WebCore::WHLSL::RecursiveTypeChecker::visit):
(WebCore::WHLSL::checkRecursiveTypes):
(): Deleted.
* Modules/webgpu/WHLSL/WHLSLResolveOverloadImpl.cpp:
(WebCore::WHLSL::conversionCost):
(WebCore::WHLSL::resolveFunctionOverloadImpl):
* Modules/webgpu/WHLSL/WHLSLResolveOverloadImpl.h:
* Modules/webgpu/WHLSL/WHLSLResolvingType.h:
(WebCore::WHLSL::ResolvingType::ResolvingType):
(WebCore::WHLSL::ResolvingType::operator=):
(WebCore::WHLSL::ResolvingType::getUnnamedType):
(WebCore::WHLSL::ResolvingType::visit):
* Modules/webgpu/WHLSL/WHLSLScopedSetAdder.h: Renamed from Source/WebCore/Modules/webgpu/WHLSL/Metal/WHLSLMappedBindings.h.
(WebCore::WHLSL::ScopedSetAdder::ScopedSetAdder):
(WebCore::WHLSL::ScopedSetAdder::~ScopedSetAdder):
(WebCore::WHLSL::ScopedSetAdder::isNewEntry const):
* Modules/webgpu/WHLSL/WHLSLSemanticMatcher.cpp:
(WebCore::WHLSL::isAcceptableFormat):
* Modules/webgpu/WHLSL/WHLSLStandardLibrary.txt: Turns out a bunch of texture types don't exist in MSL.
* Modules/webgpu/WHLSL/WHLSLSynthesizeArrayOperatorLength.cpp:
(WebCore::WHLSL::synthesizeArrayOperatorLength):
* Modules/webgpu/WHLSL/WHLSLSynthesizeArrayOperatorLength.h:
* Modules/webgpu/WHLSL/WHLSLSynthesizeConstructors.cpp: Adding to the program can fail.
(WebCore::WHLSL::synthesizeConstructors): Some constructors shouldn't be generated for "void" and for atomic types.
* Modules/webgpu/WHLSL/WHLSLSynthesizeConstructors.h: Adding to the program can fail.
* Modules/webgpu/WHLSL/WHLSLSynthesizeEnumerationFunctions.cpp: Ditto.
(WebCore::WHLSL::synthesizeEnumerationFunctions):
* Modules/webgpu/WHLSL/WHLSLSynthesizeEnumerationFunctions.h: Ditto.
* Modules/webgpu/WHLSL/WHLSLSynthesizeStructureAccessors.cpp: Ditto.
(WebCore::WHLSL::synthesizeStructureAccessors):
* Modules/webgpu/WHLSL/WHLSLSynthesizeStructureAccessors.h: Ditto.
* Modules/webgpu/WHLSL/WHLSLVisitor.cpp:
(WebCore::WHLSL::Visitor::visit):
* Modules/webgpu/WebGPUDevice.cpp: Add flag that triggers the WHLSL compiler.
(WebCore::WebGPUDevice::createShaderModule const):
* Modules/webgpu/WebGPUShaderModuleDescriptor.h: Ditto.
* Modules/webgpu/WebGPUShaderModuleDescriptor.idl: Ditto.
* WebCore.xcodeproj/project.pbxproj:
* platform/graphics/gpu/GPUPipelineLayout.h:
(WebCore::GPUPipelineLayout::bindGroupLayouts const):
* platform/graphics/gpu/GPUShaderModule.h: Add a string that represents the WHLSL shader source. The compiler currently
needs the rest of the pipeline state descriptor, so we defer compilation until create*Pipeline().
(WebCore::GPUShaderModule::platformShaderModule const):
(WebCore::GPUShaderModule::whlslSource const):
* platform/graphics/gpu/GPUShaderModuleDescriptor.h:
* platform/graphics/gpu/cocoa/GPURenderPipelineMetal.mm: Convert GPU types into WHLSL types, and invoke the compiler.
(WebCore::convertVertexFormat):
(WebCore::convertShaderStageFlags):
(WebCore::convertBindingType):
(WebCore::convertTextureFormat):
(WebCore::convertLayout):
(WebCore::convertRenderPipelineDescriptor):
(WebCore::trySetMetalFunctionsForPipelineDescriptor):
(WebCore::trySetWHLSLFunctionsForPipelineDescriptor):
(WebCore::trySetFunctionsForPipelineDescriptor):
(WebCore::tryCreateMtlRenderPipelineState):
* platform/graphics/gpu/cocoa/GPUShaderModuleMetal.mm:
(WebCore::GPUShaderModule::create):
(WebCore::GPUShaderModule::GPUShaderModule):

LayoutTests:

* webgpu/whlsl-expected.html: Added.
* webgpu/whlsl.html: Added.

git-svn-id: https://svn.webkit.org/repository/webkit/trunk@243091 268f45cc-cd09-0410-ab3c-d52691b4dbfc

61 files changed:
LayoutTests/ChangeLog
LayoutTests/webgpu/whlsl-expected.html [new file with mode: 0644]
LayoutTests/webgpu/whlsl.html [new file with mode: 0644]
Source/WebCore/ChangeLog
Source/WebCore/Modules/webgpu/WHLSL/AST/WHLSLCallExpression.h
Source/WebCore/Modules/webgpu/WHLSL/AST/WHLSLNativeTypeDeclaration.h
Source/WebCore/Modules/webgpu/WHLSL/AST/WHLSLReturn.h
Source/WebCore/Modules/webgpu/WHLSL/AST/WHLSLTypeReference.h
Source/WebCore/Modules/webgpu/WHLSL/AST/WHLSLVariableReference.h
Source/WebCore/Modules/webgpu/WHLSL/Metal/WHLSLEntryPointScaffolding.cpp
Source/WebCore/Modules/webgpu/WHLSL/Metal/WHLSLEntryPointScaffolding.h
Source/WebCore/Modules/webgpu/WHLSL/Metal/WHLSLFunctionWriter.cpp
Source/WebCore/Modules/webgpu/WHLSL/Metal/WHLSLFunctionWriter.h
Source/WebCore/Modules/webgpu/WHLSL/Metal/WHLSLMetalCodeGenerator.cpp
Source/WebCore/Modules/webgpu/WHLSL/Metal/WHLSLMetalCodeGenerator.h
Source/WebCore/Modules/webgpu/WHLSL/Metal/WHLSLNativeFunctionWriter.cpp
Source/WebCore/Modules/webgpu/WHLSL/Metal/WHLSLNativeFunctionWriter.h
Source/WebCore/Modules/webgpu/WHLSL/Metal/WHLSLNativeTypeWriter.cpp
Source/WebCore/Modules/webgpu/WHLSL/Metal/WHLSLTypeNamer.cpp
Source/WebCore/Modules/webgpu/WHLSL/Metal/WHLSLTypeNamer.h
Source/WebCore/Modules/webgpu/WHLSL/WHLSLCheckDuplicateFunctions.cpp
Source/WebCore/Modules/webgpu/WHLSL/WHLSLChecker.cpp
Source/WebCore/Modules/webgpu/WHLSL/WHLSLGatherEntryPointItems.cpp
Source/WebCore/Modules/webgpu/WHLSL/WHLSLInferTypes.cpp
Source/WebCore/Modules/webgpu/WHLSL/WHLSLInferTypes.h
Source/WebCore/Modules/webgpu/WHLSL/WHLSLIntrinsics.cpp
Source/WebCore/Modules/webgpu/WHLSL/WHLSLIntrinsics.h
Source/WebCore/Modules/webgpu/WHLSL/WHLSLLexer.cpp
Source/WebCore/Modules/webgpu/WHLSL/WHLSLNameContext.cpp
Source/WebCore/Modules/webgpu/WHLSL/WHLSLNameResolver.cpp
Source/WebCore/Modules/webgpu/WHLSL/WHLSLNameResolver.h
Source/WebCore/Modules/webgpu/WHLSL/WHLSLParser.cpp
Source/WebCore/Modules/webgpu/WHLSL/WHLSLParser.h
Source/WebCore/Modules/webgpu/WHLSL/WHLSLPipelineDescriptor.h
Source/WebCore/Modules/webgpu/WHLSL/WHLSLPrepare.cpp
Source/WebCore/Modules/webgpu/WHLSL/WHLSLPrepare.h
Source/WebCore/Modules/webgpu/WHLSL/WHLSLRecursiveTypeChecker.cpp
Source/WebCore/Modules/webgpu/WHLSL/WHLSLResolveOverloadImpl.cpp
Source/WebCore/Modules/webgpu/WHLSL/WHLSLResolveOverloadImpl.h
Source/WebCore/Modules/webgpu/WHLSL/WHLSLResolvingType.h
Source/WebCore/Modules/webgpu/WHLSL/WHLSLScopedSetAdder.h [moved from Source/WebCore/Modules/webgpu/WHLSL/Metal/WHLSLMappedBindings.h with 72% similarity]
Source/WebCore/Modules/webgpu/WHLSL/WHLSLSemanticMatcher.cpp
Source/WebCore/Modules/webgpu/WHLSL/WHLSLStandardLibrary.txt
Source/WebCore/Modules/webgpu/WHLSL/WHLSLSynthesizeArrayOperatorLength.cpp
Source/WebCore/Modules/webgpu/WHLSL/WHLSLSynthesizeArrayOperatorLength.h
Source/WebCore/Modules/webgpu/WHLSL/WHLSLSynthesizeConstructors.cpp
Source/WebCore/Modules/webgpu/WHLSL/WHLSLSynthesizeConstructors.h
Source/WebCore/Modules/webgpu/WHLSL/WHLSLSynthesizeEnumerationFunctions.cpp
Source/WebCore/Modules/webgpu/WHLSL/WHLSLSynthesizeEnumerationFunctions.h
Source/WebCore/Modules/webgpu/WHLSL/WHLSLSynthesizeStructureAccessors.cpp
Source/WebCore/Modules/webgpu/WHLSL/WHLSLSynthesizeStructureAccessors.h
Source/WebCore/Modules/webgpu/WHLSL/WHLSLVisitor.cpp
Source/WebCore/Modules/webgpu/WebGPUDevice.cpp
Source/WebCore/Modules/webgpu/WebGPUShaderModuleDescriptor.h
Source/WebCore/Modules/webgpu/WebGPUShaderModuleDescriptor.idl
Source/WebCore/WebCore.xcodeproj/project.pbxproj
Source/WebCore/platform/graphics/gpu/GPUPipelineLayout.h
Source/WebCore/platform/graphics/gpu/GPUShaderModule.h
Source/WebCore/platform/graphics/gpu/GPUShaderModuleDescriptor.h
Source/WebCore/platform/graphics/gpu/cocoa/GPURenderPipelineMetal.mm
Source/WebCore/platform/graphics/gpu/cocoa/GPUShaderModuleMetal.mm

index 35cc046..2d1fe43 100644 (file)
@@ -1,3 +1,13 @@
+2019-03-18  Myles C. Maxfield  <mmaxfield@apple.com>
+
+        [WHLSL] Hook up the compiler to our WebGPU implementation
+        https://bugs.webkit.org/show_bug.cgi?id=195509
+
+        Reviewed by Dean Jackson.
+
+        * webgpu/whlsl-expected.html: Added.
+        * webgpu/whlsl.html: Added.
+
 2019-03-18  Justin Fan  <justin_fan@apple.com>
 
         [Web GPU] GPUAdapter.createDevice -> GPUAdapter.requestDevice
diff --git a/LayoutTests/webgpu/whlsl-expected.html b/LayoutTests/webgpu/whlsl-expected.html
new file mode 100644 (file)
index 0000000..f417050
--- /dev/null
@@ -0,0 +1,19 @@
+<!DOCTYPE html>
+<html>
+<head>
+</head>
+<body>
+<canvas id="canvas" width="400" height="400"></canvas>
+<script>
+async function start() {
+    const canvas = document.getElementById("canvas");
+    const context = canvas.getContext("2d");
+    context.fillStyle = "blue";
+    context.fillRect(0, 0, 400, 400);
+    context.fillStyle = "white";
+    context.fillRect(100, 100, 200, 200);
+}
+window.addEventListener("load", start);
+</script>
+</body>
+</html>
diff --git a/LayoutTests/webgpu/whlsl.html b/LayoutTests/webgpu/whlsl.html
new file mode 100644 (file)
index 0000000..6ce3cdc
--- /dev/null
@@ -0,0 +1,120 @@
+<!DOCTYPE html>
+<html>
+<head>
+</head>
+<body>
+<canvas id="canvas" width="400" height="400"></canvas>
+<script>
+const shaderSource = `
+vertex float4 vertexShader(float4 position : attribute(0), float i : attribute(1)) : SV_Position {
+    return position;
+}
+
+fragment float4 fragmentShader(float4 position : SV_Position) : SV_Target 0 {
+    return position;
+}
+`;
+async function start() {
+    const adapter = await window.gpu.requestAdapter();
+    const device = adapter.createDevice();
+
+    const shaderModule = device.createShaderModule({code: shaderSource, isWHLSL: true});
+    const vertexStage = {module: shaderModule, entryPoint: "vertexShader"};
+    const fragmentStage = {module: shaderModule, entryPoint: "fragmentShader"};
+    const primitiveTopology = "triangle-strip";
+    const rasterizationState = {frontFace: "cw", cullMode: "none"};
+    const alphaBlend = {srcFactor: "zero", dstFactor: "one", operation: "add"};
+    const colorBlend = {srcFactor: "zero", dstFactor: "one", operation: "add"};
+    const colorStates = [{format: "rgba8unorm", alphaBlend, colorBlend, writeMask: 15}]; // GPUColorWriteBits.ALL
+    const depthStencilState = null;
+    
+    const attribute0 = {shaderLocation: 0, inputSlot: 0, offset: 0, format: "float4"};
+    const attribute1 = {shaderLocation: 1, inputSlot: 1, offset: 0, format: "float"};
+    const attributes = [attribute0, attribute1];
+    const input0 = {inputSlot: 0, stride: 16, stepMode: "vertex"};
+    const input1 = {inputSlot: 1, stride: 4, stepMode: "vertex"};
+    const inputs = [input0, input1];
+    const inputState = {indexFormat: "uint32", attributes, inputs};
+
+    const bindGroupLayoutDescriptor = {bindings: [{binding: 0, visibility: 7, type: "uniform-buffer"}]};
+    const bindGroupLayout = device.createBindGroupLayout(bindGroupLayoutDescriptor);
+    const pipelineLayoutDescriptor = {bindGroupLayouts: [bindGroupLayout]};
+    const pipelineLayout = device.createPipelineLayout(pipelineLayoutDescriptor);
+
+    const renderPipelineDescriptor = {vertexStage, fragmentStage, primitiveTopology, rasterizationState, colorStates, depthStencilState, inputState, sampleCount: 1, layout: pipelineLayout};
+    const renderPipeline = device.createRenderPipeline(renderPipelineDescriptor);
+
+    const vertexBuffer0Descriptor = {size: Float32Array.BYTES_PER_ELEMENT * 4 * 4, usage: GPUBufferUsage.VERTEX | GPUBufferUsage.MAP_WRITE};
+    const vertexBuffer0 = device.createBuffer(vertexBuffer0Descriptor);
+    const vertexBuffer0ArrayBuffer = await vertexBuffer0.mapWriteAsync();
+    const vertexBuffer0Float32Array = new Float32Array(vertexBuffer0ArrayBuffer);
+    vertexBuffer0Float32Array[0] = -0.5;
+    vertexBuffer0Float32Array[1] = -0.5;
+    vertexBuffer0Float32Array[2] = 1.0;
+    vertexBuffer0Float32Array[3] = 1;
+    vertexBuffer0Float32Array[4] = -0.5;
+    vertexBuffer0Float32Array[5] = 0.5;
+    vertexBuffer0Float32Array[6] = 1.0;
+    vertexBuffer0Float32Array[7] = 1;
+    vertexBuffer0Float32Array[8] = 0.5;
+    vertexBuffer0Float32Array[9] = -0.5;
+    vertexBuffer0Float32Array[10] = 1.0;
+    vertexBuffer0Float32Array[11] = 1;
+    vertexBuffer0Float32Array[12] = 0.5;
+    vertexBuffer0Float32Array[13] = 0.5;
+    vertexBuffer0Float32Array[14] = 1.0;
+    vertexBuffer0Float32Array[15] = 1;
+    vertexBuffer0.unmap();
+
+    const vertexBuffer1Descriptor = {size: Float32Array.BYTES_PER_ELEMENT * 4, usage: GPUBufferUsage.VERTEX | GPUBufferUsage.MAP_WRITE};
+    const vertexBuffer1 = device.createBuffer(vertexBuffer1Descriptor);
+    const vertexBuffer1ArrayBuffer = await vertexBuffer1.mapWriteAsync();
+    const vertexBuffer1Float32Array = new Float32Array(vertexBuffer1ArrayBuffer);
+    vertexBuffer1Descriptor[0] = 1;
+    vertexBuffer1Descriptor[1] = 1;
+    vertexBuffer1Descriptor[2] = 1;
+    vertexBuffer1Descriptor[3] = 1;
+    vertexBuffer1.unmap();
+
+    const resourceBufferDescriptor = {size: Float32Array.BYTES_PER_ELEMENT, usage: GPUBufferUsage.UNIFORM | GPUBufferUsage.MAP_WRITE};
+    const resourceBuffer = device.createBuffer(resourceBufferDescriptor);
+    const resourceBufferArrayBuffer = await resourceBuffer.mapWriteAsync();
+    const resourceBufferFloat32Array = new Float32Array(resourceBufferArrayBuffer);
+    resourceBufferFloat32Array[0] = 1;
+    resourceBuffer.unmap();
+
+    const bufferBinding = {buffer: resourceBuffer, offset: 0, size: 4};
+    const bindGroupBinding = {binding: 0, resource: bufferBinding};
+    const bindGroupDescriptor = {layout: bindGroupLayout, bindings: [bindGroupBinding]};
+    const bindGroup = device.createBindGroup(bindGroupDescriptor);
+
+    const canvas = document.getElementById("canvas");
+    const context = canvas.getContext("gpu");
+    const swapChainDescriptor = {context, format: "bgra8unorm"};
+    const swapChain = device.createSwapChain(swapChainDescriptor);
+    const outputTexture = swapChain.getCurrentTexture();
+    const outputTextureView = outputTexture.createDefaultTextureView(); // createDefaultView()
+
+    const commandEncoder = device.createCommandEncoder(); // {}
+    const red = {r: 0, g: 0, b: 1, a: 1};
+    const colorAttachments = [{attachment: outputTextureView, resolveTarget: null, loadOp: "clear", storeOp: "store", clearColor: red}];
+    const depthStencilAttachment = null;
+    const renderPassDescriptor = {colorAttachments, depthStencilAttachment};
+    const renderPassEncoder = commandEncoder.beginRenderPass(renderPassDescriptor);
+    renderPassEncoder.setPipeline(renderPipeline);
+    renderPassEncoder.setBindGroup(0, bindGroup);
+    renderPassEncoder.setVertexBuffers(0, [vertexBuffer0, vertexBuffer1], [0, 0]);
+    renderPassEncoder.draw(4, 1, 0, 0);
+    renderPassEncoder.endPass();
+    const commandBuffer = commandEncoder.finish();
+    device.getQueue().submit([commandBuffer]);
+
+    if (window.testRunner)
+        testRunner.notifyDone();
+}
+if (window.testRunner)
+    testRunner.waitUntilDone();
+window.addEventListener("load", start);
+</script>
+</body>
+</html>
index e151d61..30c2dfb 100644 (file)
@@ -1,3 +1,202 @@
+2019-03-18  Myles C. Maxfield  <mmaxfield@apple.com>
+
+        [WHLSL] Hook up the compiler to our WebGPU implementation
+        https://bugs.webkit.org/show_bug.cgi?id=195509
+
+        Reviewed by Dean Jackson.
+
+        This represents a collection of changes necessary to compile and run the first WHLSL program in WebKit.
+
+        Because WHLSL isn't fully implemented yet, this patch doesn't remove the existing method for supplying
+        Metal shaders to WebGPU. Instead, it adds a new boolean to WebGPUShaderModuleDescriptor, "isWHLSL" which
+        causes us to run the WHLSL compiler.
+
+        More details below.
+
+        Test: webgpu/whlsl.html
+
+        * Modules/webgpu/WHLSL/AST/WHLSLCallExpression.h: Use raw pointer instead of Optional<std::reference_wrapper>.
+        (WebCore::WHLSL::AST::CallExpression::setCastData):
+        (WebCore::WHLSL::AST::CallExpression::isCast):
+        (WebCore::WHLSL::AST::CallExpression::castReturnType):
+        * Modules/webgpu/WHLSL/AST/WHLSLNativeTypeDeclaration.h:
+        (WebCore::WHLSL::AST::NativeTypeDeclaration::isAtomic const):
+        (WebCore::WHLSL::AST::NativeTypeDeclaration::setIsAtomic):
+        (WebCore::WHLSL::AST::NativeTypeDeclaration::name const): Deleted. The parent class already has a name string.
+        (WebCore::WHLSL::AST::NativeTypeDeclaration::name): Deleted.
+        * Modules/webgpu/WHLSL/AST/WHLSLReturn.h:
+        * Modules/webgpu/WHLSL/AST/WHLSLTypeReference.h:
+        (WebCore::WHLSL::AST::TypeReference::cloneTypeReference const): When cloning a type reference, make sure to
+        clone the pointer to its resolved type, too.
+        * Modules/webgpu/WHLSL/AST/WHLSLVariableReference.h:
+        * Modules/webgpu/WHLSL/Metal/WHLSLEntryPointScaffolding.cpp: Incorporate resolution from
+        https://github.com/gpuweb/gpuweb/pull/188.
+        (WebCore::WHLSL::Metal::EntryPointScaffolding::EntryPointScaffolding):
+        (WebCore::WHLSL::Metal::EntryPointScaffolding::resourceHelperTypes):
+        (WebCore::WHLSL::Metal::EntryPointScaffolding::resourceSignature):
+        (WebCore::WHLSL::Metal::VertexEntryPointScaffolding::helperTypes):
+        (WebCore::WHLSL::Metal::VertexEntryPointScaffolding::unpack):
+        (WebCore::WHLSL::Metal::VertexEntryPointScaffolding::pack): Support semantics being placed directly on the
+        entry point, instead of being placed on a structure member.
+        (WebCore::WHLSL::Metal::FragmentEntryPointScaffolding::helperTypes):
+        (WebCore::WHLSL::Metal::FragmentEntryPointScaffolding::pack): Ditto.
+        (WebCore::WHLSL::Metal::EntryPointScaffolding::mappedBindGroups const): Deleted.
+        * Modules/webgpu/WHLSL/Metal/WHLSLEntryPointScaffolding.h:
+        * Modules/webgpu/WHLSL/Metal/WHLSLFunctionWriter.cpp:
+        (WebCore::WHLSL::Metal::FunctionDefinitionWriter::visit):
+        (WebCore::WHLSL::Metal::RenderFunctionDefinitionWriter::createEntryPointScaffolding):
+        (WebCore::WHLSL::Metal::ComputeFunctionDefinitionWriter::createEntryPointScaffolding):
+        (WebCore::WHLSL::Metal::metalFunctions):
+        (WebCore::WHLSL::Metal::RenderFunctionDefinitionWriter::takeVertexMappedBindGroups): Deleted. After
+        https://github.com/gpuweb/gpuweb/pull/188, we don't need the mappings.
+        (WebCore::WHLSL::Metal::RenderFunctionDefinitionWriter::takeFragmentMappedBindGroups): Deleted. Ditto.
+        (WebCore::WHLSL::Metal::ComputeFunctionDefinitionWriter::takeMappedBindGroups): Deleted. Ditto.
+        * Modules/webgpu/WHLSL/Metal/WHLSLFunctionWriter.h: Ditto.
+        * Modules/webgpu/WHLSL/Metal/WHLSLMetalCodeGenerator.cpp: Ditto.
+        (WebCore::WHLSL::Metal::generateMetalCodeShared):
+        (WebCore::WHLSL::Metal::generateMetalCode):
+        * Modules/webgpu/WHLSL/Metal/WHLSLMetalCodeGenerator.h: Ditto.
+        * Modules/webgpu/WHLSL/Metal/WHLSLNativeFunctionWriter.cpp: Support compiler-generated functions. Change
+        CRASH() to notImplemented().
+        (WebCore::WHLSL::Metal::writeNativeFunction):
+        (WebCore::WHLSL::Metal::getNativeName): Deleted.
+        * Modules/webgpu/WHLSL/Metal/WHLSLNativeFunctionWriter.h:
+        * Modules/webgpu/WHLSL/Metal/WHLSLNativeTypeWriter.cpp:
+        (WebCore::WHLSL::Metal::writeNativeType): 
+        * Modules/webgpu/WHLSL/Metal/WHLSLTypeNamer.cpp: The dependency graph needs to track all unnamed types. Also,
+        we need to track types that are the results of expressions (not just types literally spelled out in the
+        program). Enumerations need to be emitted after their base types are emitted.
+        (WebCore::WHLSL::Metal::TypeNamer::visit):
+        (WebCore::WHLSL::Metal::MetalTypeDeclarationWriter::MetalTypeDeclarationWriter):
+        (WebCore::WHLSL::Metal::TypeNamer::metalTypeDeclarations):
+        (WebCore::WHLSL::Metal::TypeNamer::emitUnnamedTypeDefinition):
+        (WebCore::WHLSL::Metal::TypeNamer::emitNamedTypeDefinition):
+        (WebCore::WHLSL::Metal::TypeNamer::emitAllUnnamedTypeDefinitions):
+        (WebCore::WHLSL::Metal::TypeNamer::metalTypeDefinitions):
+        * Modules/webgpu/WHLSL/Metal/WHLSLTypeNamer.h:
+        * Modules/webgpu/WHLSL/WHLSLCheckDuplicateFunctions.cpp:
+        (WebCore::WHLSL::checkDuplicateFunctions):
+        * Modules/webgpu/WHLSL/WHLSLChecker.cpp: Wrap ResolvingType in a class to make sure it plays nicely with
+        HashMap. Also, use raw pointers instead of Optional<std::reference_wrapper>s.
+        (WebCore::WHLSL::resolveWithReferenceComparator):
+        (WebCore::WHLSL::resolveByInstantiation):
+        (WebCore::WHLSL::checkOperatorOverload):
+        (WebCore::WHLSL::Checker::assignTypes):
+        (WebCore::WHLSL::Checker::checkShaderType):
+        (WebCore::WHLSL::Checker::visit):
+        (WebCore::WHLSL::matchAndCommit):
+        (WebCore::WHLSL::Checker::recurseAndGetInfo):
+        (WebCore::WHLSL::Checker::assignType):
+        (WebCore::WHLSL::Checker::forwardType):
+        (WebCore::WHLSL::getUnnamedType):
+        (WebCore::WHLSL::Checker::finishVisitingPropertyAccess):
+        (WebCore::WHLSL::Checker::isBoolType):
+        * Modules/webgpu/WHLSL/WHLSLGatherEntryPointItems.cpp:
+        (WebCore::WHLSL::Gatherer::visit):
+        * Modules/webgpu/WHLSL/WHLSLInferTypes.cpp:
+        (WebCore::WHLSL::inferTypesForCall):
+        * Modules/webgpu/WHLSL/WHLSLInferTypes.h:
+        * Modules/webgpu/WHLSL/WHLSLIntrinsics.cpp:
+        (WebCore::WHLSL::Intrinsics::addPrimitive):
+        (WebCore::WHLSL::Intrinsics::addFullTexture):
+        * Modules/webgpu/WHLSL/WHLSLIntrinsics.h:
+        (WebCore::WHLSL::Intrinsics::ucharType const):
+        (WebCore::WHLSL::Intrinsics::ushortType const):
+        (WebCore::WHLSL::Intrinsics::charType const):
+        (WebCore::WHLSL::Intrinsics::shortType const):
+        (WebCore::WHLSL::Intrinsics::intType const):
+        (WebCore::WHLSL::Intrinsics::uchar2Type const):
+        (WebCore::WHLSL::Intrinsics::uchar4Type const):
+        (WebCore::WHLSL::Intrinsics::ushort2Type const):
+        (WebCore::WHLSL::Intrinsics::ushort4Type const):
+        (WebCore::WHLSL::Intrinsics::uint2Type const):
+        (WebCore::WHLSL::Intrinsics::uint4Type const):
+        (WebCore::WHLSL::Intrinsics::char2Type const):
+        (WebCore::WHLSL::Intrinsics::char4Type const):
+        (WebCore::WHLSL::Intrinsics::short2Type const):
+        (WebCore::WHLSL::Intrinsics::short4Type const):
+        (WebCore::WHLSL::Intrinsics::int2Type const):
+        (WebCore::WHLSL::Intrinsics::int4Type const):
+        * Modules/webgpu/WHLSL/WHLSLLexer.cpp:
+        (WebCore::WHLSL::Lexer::recognizeKeyword):
+        * Modules/webgpu/WHLSL/WHLSLNameContext.cpp:
+        (WebCore::WHLSL::NameContext::add):
+        * Modules/webgpu/WHLSL/WHLSLNameResolver.cpp:
+        (WebCore::WHLSL::NameResolver::visit): Don't visit recursive types.
+        Also, make sure we preserve the CurrentFunction in our recursive scopes.
+        * Modules/webgpu/WHLSL/WHLSLNameResolver.h:
+        * Modules/webgpu/WHLSL/WHLSLParser.cpp:
+        (WebCore::WHLSL::Parser::fail):
+        (WebCore::WHLSL::Parser::peek):
+        (WebCore::WHLSL::Parser::parseType):
+        (WebCore::WHLSL::Parser::parseBuiltInSemantic):
+        * Modules/webgpu/WHLSL/WHLSLParser.h:
+        * Modules/webgpu/WHLSL/WHLSLPipelineDescriptor.h:
+        * Modules/webgpu/WHLSL/WHLSLPrepare.cpp:
+        (WebCore::WHLSL::prepareShared):
+        (WebCore::WHLSL::prepare):
+        * Modules/webgpu/WHLSL/WHLSLPrepare.h:
+        * Modules/webgpu/WHLSL/WHLSLRecursiveTypeChecker.cpp: Move big inline functions out-of-line.
+        (WebCore::WHLSL::RecursiveTypeChecker::visit):
+        (WebCore::WHLSL::checkRecursiveTypes):
+        (): Deleted.
+        * Modules/webgpu/WHLSL/WHLSLResolveOverloadImpl.cpp:
+        (WebCore::WHLSL::conversionCost):
+        (WebCore::WHLSL::resolveFunctionOverloadImpl):
+        * Modules/webgpu/WHLSL/WHLSLResolveOverloadImpl.h:
+        * Modules/webgpu/WHLSL/WHLSLResolvingType.h:
+        (WebCore::WHLSL::ResolvingType::ResolvingType):
+        (WebCore::WHLSL::ResolvingType::operator=):
+        (WebCore::WHLSL::ResolvingType::getUnnamedType):
+        (WebCore::WHLSL::ResolvingType::visit):
+        * Modules/webgpu/WHLSL/WHLSLScopedSetAdder.h: Renamed from Source/WebCore/Modules/webgpu/WHLSL/Metal/WHLSLMappedBindings.h.
+        (WebCore::WHLSL::ScopedSetAdder::ScopedSetAdder):
+        (WebCore::WHLSL::ScopedSetAdder::~ScopedSetAdder):
+        (WebCore::WHLSL::ScopedSetAdder::isNewEntry const):
+        * Modules/webgpu/WHLSL/WHLSLSemanticMatcher.cpp:
+        (WebCore::WHLSL::isAcceptableFormat):
+        * Modules/webgpu/WHLSL/WHLSLStandardLibrary.txt: Turns out a bunch of texture types don't exist in MSL.
+        * Modules/webgpu/WHLSL/WHLSLSynthesizeArrayOperatorLength.cpp:
+        (WebCore::WHLSL::synthesizeArrayOperatorLength):
+        * Modules/webgpu/WHLSL/WHLSLSynthesizeArrayOperatorLength.h:
+        * Modules/webgpu/WHLSL/WHLSLSynthesizeConstructors.cpp: Adding to the program can fail.
+        (WebCore::WHLSL::synthesizeConstructors): Some constructors shouldn't be generated for "void" and for atomic types.
+        * Modules/webgpu/WHLSL/WHLSLSynthesizeConstructors.h: Adding to the program can fail.
+        * Modules/webgpu/WHLSL/WHLSLSynthesizeEnumerationFunctions.cpp: Ditto.
+        (WebCore::WHLSL::synthesizeEnumerationFunctions):
+        * Modules/webgpu/WHLSL/WHLSLSynthesizeEnumerationFunctions.h: Ditto.
+        * Modules/webgpu/WHLSL/WHLSLSynthesizeStructureAccessors.cpp: Ditto.
+        (WebCore::WHLSL::synthesizeStructureAccessors):
+        * Modules/webgpu/WHLSL/WHLSLSynthesizeStructureAccessors.h: Ditto.
+        * Modules/webgpu/WHLSL/WHLSLVisitor.cpp:
+        (WebCore::WHLSL::Visitor::visit):
+        * Modules/webgpu/WebGPUDevice.cpp: Add flag that triggers the WHLSL compiler.
+        (WebCore::WebGPUDevice::createShaderModule const):
+        * Modules/webgpu/WebGPUShaderModuleDescriptor.h: Ditto.
+        * Modules/webgpu/WebGPUShaderModuleDescriptor.idl: Ditto.
+        * WebCore.xcodeproj/project.pbxproj:
+        * platform/graphics/gpu/GPUPipelineLayout.h:
+        (WebCore::GPUPipelineLayout::bindGroupLayouts const):
+        * platform/graphics/gpu/GPUShaderModule.h: Add a string that represents the WHLSL shader source. The compiler currently
+        needs the rest of the pipeline state descriptor, so we defer compilation until create*Pipeline().
+        (WebCore::GPUShaderModule::platformShaderModule const):
+        (WebCore::GPUShaderModule::whlslSource const):
+        * platform/graphics/gpu/GPUShaderModuleDescriptor.h:
+        * platform/graphics/gpu/cocoa/GPURenderPipelineMetal.mm: Convert GPU types into WHLSL types, and invoke the compiler.
+        (WebCore::convertVertexFormat):
+        (WebCore::convertShaderStageFlags):
+        (WebCore::convertBindingType):
+        (WebCore::convertTextureFormat):
+        (WebCore::convertLayout):
+        (WebCore::convertRenderPipelineDescriptor):
+        (WebCore::trySetMetalFunctionsForPipelineDescriptor):
+        (WebCore::trySetWHLSLFunctionsForPipelineDescriptor):
+        (WebCore::trySetFunctionsForPipelineDescriptor):
+        (WebCore::tryCreateMtlRenderPipelineState):
+        * platform/graphics/gpu/cocoa/GPUShaderModuleMetal.mm:
+        (WebCore::GPUShaderModule::create):
+        (WebCore::GPUShaderModule::GPUShaderModule):
+
 2019-03-18  Justin Fan  <justin_fan@apple.com>
 
         [Web GPU] GPUAdapter.createDevice -> GPUAdapter.requestDevice
index bab3181..1a0d027 100644 (file)
@@ -62,11 +62,11 @@ public:
 
     void setCastData(NamedType& namedType)
     {
-        m_castReturnType = { namedType };
+        m_castReturnType = &namedType;
     }
 
-    bool isCast() { return static_cast<bool>(m_castReturnType); }
-    Optional<std::reference_wrapper<NamedType>>& castReturnType() { return m_castReturnType; }
+    bool isCast() { return m_castReturnType; }
+    NamedType* castReturnType() { return m_castReturnType; }
     bool hasOverloads() const { return static_cast<bool>(m_overloads); }
     Optional<Vector<std::reference_wrapper<FunctionDeclaration>, 1>>& overloads() { return m_overloads; }
     void setOverloads(const Vector<std::reference_wrapper<FunctionDeclaration>, 1>& overloads)
@@ -88,7 +88,7 @@ private:
     Vector<UniqueRef<Expression>> m_arguments;
     Optional<Vector<std::reference_wrapper<FunctionDeclaration>, 1>> m_overloads;
     FunctionDeclaration* m_function { nullptr };
-    Optional<std::reference_wrapper<NamedType>> m_castReturnType { WTF::nullopt };
+    NamedType* m_castReturnType { nullptr };
 };
 
 } // namespace AST
index f651880..3ae6ed6 100644 (file)
@@ -54,13 +54,12 @@ public:
 
     bool isNativeTypeDeclaration() const override { return true; }
 
-    const String& name() const { return m_name; }
-    String& name() { return m_name; }
     TypeArguments& typeArguments() { return m_typeArguments; }
 
     bool isInt() const { return m_isInt; }
     bool isNumber() const { return m_isNumber; }
     bool isFloating() const { return m_isFloating; }
+    bool isAtomic() const { return m_isAtomic; }
     bool isVector() const { return m_isVector; }
     bool isMatrix() const { return m_isMatrix; }
     bool isTexture() const { return m_isTexture; }
@@ -76,6 +75,7 @@ public:
     void setIsInt() { m_isInt = true; }
     void setIsNumber() { m_isNumber = true; }
     void setIsFloating() { m_isFloating = true; }
+    void setIsAtomic() { m_isAtomic = true; }
     void setIsVector() { m_isVector = true; }
     void setIsMatrix() { m_isMatrix = true; }
     void setIsTexture() { m_isTexture = true; }
@@ -89,7 +89,6 @@ public:
     void setIterateAllValues(std::function<void(const std::function<bool(int64_t)>&)>&& iterateAllValues) { m_iterateAllValues = WTFMove(iterateAllValues); }
 
 private:
-    String m_name;
     TypeArguments m_typeArguments;
     std::function<bool(int)> m_canRepresentInteger;
     std::function<bool(unsigned)> m_canRepresentUnsignedInteger;
@@ -101,6 +100,7 @@ private:
     bool m_isInt { false };
     bool m_isNumber { false };
     bool m_isFloating { false };
+    bool m_isAtomic { false };
     bool m_isVector { false };
     bool m_isMatrix { false };
     bool m_isTexture { false };
index f9d209d..75ace01 100644 (file)
@@ -61,7 +61,7 @@ public:
 
 private:
     Optional<UniqueRef<Expression>> m_value;
-    FunctionDefinition* m_function;
+    FunctionDefinition* m_function { nullptr };
 };
 
 } // namespace AST
index bb69f03..bf4c599 100644 (file)
@@ -83,7 +83,10 @@ public:
 
     UniqueRef<TypeReference> cloneTypeReference() const
     {
-        return makeUniqueRef<TypeReference>(Lexer::Token(origin()), String(m_name), AST::clone(m_typeArguments));
+        auto result = makeUniqueRef<TypeReference>(Lexer::Token(origin()), String(m_name), AST::clone(m_typeArguments));
+        if (m_resolvedType)
+            result->setResolvedType(*m_resolvedType);
+        return result;
     }
 
     UniqueRef<UnnamedType> clone() const override
index 5052fdf..69e79ad 100644 (file)
@@ -76,7 +76,7 @@ private:
     }
 
     String m_name;
-    VariableDeclaration* m_variable;
+    VariableDeclaration* m_variable { nullptr };
 };
 
 } // namespace AST
index 2854f8a..624c562 100644 (file)
@@ -98,19 +98,17 @@ EntryPointScaffolding::EntryPointScaffolding(AST::FunctionDefinition& functionDe
     , m_layout(layout)
     , m_generateNextVariableName(generateNextVariableName)
 {
-    unsigned argumentBufferIndex = 0;
     m_namedBindGroups.reserveInitialCapacity(m_layout.size());
     for (size_t i = 0; i < m_layout.size(); ++i) {
         NamedBindGroup namedBindGroup;
         namedBindGroup.structName = m_typeNamer.generateNextTypeName();
         namedBindGroup.variableName = m_generateNextVariableName();
-        namedBindGroup.argumentBufferIndex = argumentBufferIndex++;
+        namedBindGroup.argumentBufferIndex = m_layout[i].name; // convertLayout() in GPURenderPipelineMetal.mm makes sure these don't collide.
         namedBindGroup.namedBindings.reserveInitialCapacity(m_layout[i].bindings.size());
-        unsigned index = 0;
         for (size_t j = 0; j < m_layout[i].bindings.size(); ++j) {
             NamedBinding namedBinding;
             namedBinding.elementName = m_typeNamer.generateNextStructureElementName();
-            namedBinding.index = index++;
+            namedBinding.index = m_layout[i].bindings[j].name; // GPUBindGroupLayout::tryCreate() makes sure these don't collide.
             namedBindGroup.namedBindings.uncheckedAppend(WTFMove(namedBinding));
         }
         m_namedBindGroups.uncheckedAppend(WTFMove(namedBindGroup));
@@ -130,21 +128,6 @@ EntryPointScaffolding::EntryPointScaffolding(AST::FunctionDefinition& functionDe
         m_parameterVariables.uncheckedAppend(m_generateNextVariableName());
 }
 
-MappedBindGroups EntryPointScaffolding::mappedBindGroups() const
-{
-    MappedBindGroups result;
-    result.reserveInitialCapacity(m_layout.size());
-    for (auto& namedBindGroup : m_namedBindGroups) {
-        MappedBindGroup mappedBindGroup;
-        mappedBindGroup.argumentBufferIndex = namedBindGroup.argumentBufferIndex;
-        mappedBindGroup.bindingIndices.reserveInitialCapacity(namedBindGroup.namedBindings.size());
-        for (auto& namedBinding : namedBindGroup.namedBindings)
-            mappedBindGroup.bindingIndices.uncheckedAppend(namedBinding.index);
-        result.uncheckedAppend(WTFMove(mappedBindGroup));
-    }
-    return result;
-}
-
 String EntryPointScaffolding::resourceHelperTypes()
 {
     StringBuilder stringBuilder;
@@ -159,7 +142,7 @@ String EntryPointScaffolding::resourceHelperTypes()
             auto index = m_namedBindGroups[i].namedBindings[j].index;
             stringBuilder.append(makeString("    ", mangledTypeName, ' ', elementName, " [[id(", index, ")]];\n"));
         }
-        stringBuilder.append("}\n\n");
+        stringBuilder.append("};\n\n");
     }
     return stringBuilder.toString();
 }
@@ -174,7 +157,7 @@ Optional<String> EntryPointScaffolding::resourceSignature()
         if (i)
             stringBuilder.append(", ");
         auto& namedBindGroup = m_namedBindGroups[i];
-        stringBuilder.append(makeString(namedBindGroup.structName, "& ", namedBindGroup.variableName, " [[buffer(", namedBindGroup.argumentBufferIndex, ")]]"));
+        stringBuilder.append(makeString("device ", namedBindGroup.structName, "& ", namedBindGroup.variableName, " [[buffer(", namedBindGroup.argumentBufferIndex, ")]]"));
     }
     return stringBuilder.toString();
 }
@@ -320,10 +303,10 @@ String VertexEntryPointScaffolding::helperTypes()
     for (auto& namedStageIn : m_namedStageIns) {
         auto mangledTypeName = m_typeNamer.mangledNameForType(*m_entryPointItems.inputs[namedStageIn.indexInEntryPointItems].unnamedType);
         auto elementName = namedStageIn.elementName;
-        auto attributeIndex = namedStageIn.elementName;
+        auto attributeIndex = namedStageIn.attributeIndex;
         stringBuilder.append(makeString("    ", mangledTypeName, ' ', elementName, " [[attribute(", attributeIndex, ")]];\n"));
     }
-    stringBuilder.append("}\n\n");
+    stringBuilder.append("};\n\n");
 
     stringBuilder.append(makeString("struct ", m_returnStructName, " {\n"));
     for (size_t i = 0; i < m_entryPointItems.outputs.size(); ++i) {
@@ -333,7 +316,7 @@ String VertexEntryPointScaffolding::helperTypes()
         auto attribute = attributeForSemantic(*outputItem.semantic);
         stringBuilder.append(makeString("    ", mangledTypeName, ' ', elementName, ' ', attribute, ";\n"));
     }
-    stringBuilder.append("}\n\n");
+    stringBuilder.append("};\n\n");
 
     stringBuilder.append(resourceHelperTypes());
 
@@ -363,7 +346,7 @@ String VertexEntryPointScaffolding::unpack()
     for (auto& namedStageIn : m_namedStageIns) {
         auto& path = m_entryPointItems.inputs[namedStageIn.indexInEntryPointItems].path;
         auto& elementName = namedStageIn.elementName;
-        stringBuilder.append(makeString(mangledInputPath(path), " = ", m_stageInStructName, '.', elementName, ";\n"));
+        stringBuilder.append(makeString(mangledInputPath(path), " = ", m_stageInParameterName, '.', elementName, ";\n"));
     }
 
     return stringBuilder.toString();
@@ -372,7 +355,12 @@ String VertexEntryPointScaffolding::unpack()
 String VertexEntryPointScaffolding::pack(const String& inputVariableName, const String& outputVariableName)
 {
     StringBuilder stringBuilder;
-    stringBuilder.append(makeString(m_returnStructName, ' ', outputVariableName));
+    stringBuilder.append(makeString(m_returnStructName, ' ', outputVariableName, ";\n"));
+    if (m_entryPointItems.outputs.size() == 1 && !m_entryPointItems.outputs[0].path.size()) {
+        auto& elementName = m_namedOutputs[0].elementName;
+        stringBuilder.append(makeString(outputVariableName, '.', elementName, " = ", inputVariableName, ";\n"));
+        return stringBuilder.toString();
+    }
     for (size_t i = 0; i < m_entryPointItems.outputs.size(); ++i) {
         auto& elementName = m_namedOutputs[i].elementName;
         auto& path = m_entryPointItems.outputs[i].path;
@@ -418,7 +406,7 @@ String FragmentEntryPointScaffolding::helperTypes()
         auto attributeIndex = namedStageIn.elementName;
         stringBuilder.append(makeString("    ", mangledTypeName, ' ', elementName, " [[user(", attributeIndex, ")]];\n"));
     }
-    stringBuilder.append("}\n\n");
+    stringBuilder.append("};\n\n");
 
     stringBuilder.append(makeString("struct ", m_returnStructName, " {\n"));
     for (size_t i = 0; i < m_entryPointItems.outputs.size(); ++i) {
@@ -428,7 +416,7 @@ String FragmentEntryPointScaffolding::helperTypes()
         auto attribute = attributeForSemantic(*outputItem.semantic);
         stringBuilder.append(makeString("    ", mangledTypeName, ' ', elementName, ' ', attribute, ";\n"));
     }
-    stringBuilder.append("}\n\n");
+    stringBuilder.append("};\n\n");
 
     stringBuilder.append(resourceHelperTypes());
 
@@ -467,7 +455,12 @@ String FragmentEntryPointScaffolding::unpack()
 String FragmentEntryPointScaffolding::pack(const String& inputVariableName, const String& outputVariableName)
 {
     StringBuilder stringBuilder;
-    stringBuilder.append(makeString(m_returnStructName, ' ', outputVariableName));
+    stringBuilder.append(makeString(m_returnStructName, ' ', outputVariableName, ";\n"));
+    if (m_entryPointItems.outputs.size() == 1 && !m_entryPointItems.outputs[0].path.size()) {
+        auto& elementName = m_namedOutputs[0].elementName;
+        stringBuilder.append(makeString(outputVariableName, '.', elementName, " = ", inputVariableName, ";\n"));
+        return stringBuilder.toString();
+    }
     for (size_t i = 0; i < m_entryPointItems.outputs.size(); ++i) {
         auto& elementName = m_namedOutputs[i].elementName;
         auto& path = m_entryPointItems.outputs[i].path;
index 99adca5..2467e86 100644 (file)
@@ -27,7 +27,6 @@
 
 #if ENABLE(WEBGPU)
 
-#include "WHLSLMappedBindings.h"
 #include "WHLSLPipelineDescriptor.h"
 #include <wtf/HashMap.h>
 #include <wtf/text/WTFString.h>
@@ -59,7 +58,6 @@ public:
     virtual String unpack() = 0;
     virtual String pack(const String& existingVariableName, const String& variableName) = 0;
 
-    MappedBindGroups mappedBindGroups() const;
     Vector<String>& parameterVariables() { return m_parameterVariables; }
 
 protected:
index 1b7a42a..2fc1465 100644 (file)
@@ -28,6 +28,7 @@
 
 #if ENABLE(WEBGPU)
 
+#include "NotImplemented.h"
 #include "WHLSLArrayReferenceType.h"
 #include "WHLSLArrayType.h"
 #include "WHLSLAssignmentExpression.h"
@@ -50,7 +51,6 @@
 #include "WHLSLLogicalNotExpression.h"
 #include "WHLSLMakeArrayReferenceExpression.h"
 #include "WHLSLMakePointerExpression.h"
-#include "WHLSLMappedBindings.h"
 #include "WHLSLNativeFunctionDeclaration.h"
 #include "WHLSLNativeFunctionWriter.h"
 #include "WHLSLNativeTypeDeclaration.h"
@@ -193,7 +193,7 @@ void FunctionDefinitionWriter::visit(AST::NativeFunctionDeclaration& nativeFunct
 {
     auto iterator = m_functionMapping.find(&nativeFunctionDeclaration);
     ASSERT(iterator != m_functionMapping.end());
-    m_stringBuilder.append(writeNativeFunction(nativeFunctionDeclaration, iterator->value, m_typeNamer));
+    m_stringBuilder.append(writeNativeFunction(nativeFunctionDeclaration, iterator->value, m_intrinsics, m_typeNamer));
 }
 
 void FunctionDefinitionWriter::visit(AST::FunctionDefinition& functionDefinition)
@@ -207,7 +207,7 @@ void FunctionDefinitionWriter::visit(AST::FunctionDefinition& functionDefinition
         m_entryPointScaffolding = WTFMove(entryPointScaffolding);
         m_stringBuilder.append(m_entryPointScaffolding->helperTypes());
         m_stringBuilder.append('\n');
-        m_stringBuilder.append(makeString(m_entryPointScaffolding->signature(iterator->value), " {"));
+        m_stringBuilder.append(makeString(m_entryPointScaffolding->signature(iterator->value), " {\n"));
         m_stringBuilder.append(m_entryPointScaffolding->unpack());
         for (size_t i = 0; i < functionDefinition.parameters().size(); ++i) {
             auto addResult = m_variableMapping.add(&functionDefinition.parameters()[i], m_entryPointScaffolding->parameterVariables()[i]);
@@ -261,8 +261,8 @@ void FunctionDefinitionWriter::visit(AST::Break&)
 
 void FunctionDefinitionWriter::visit(AST::Continue&)
 {
-    // FIXME: Figure out which loop we're in, and run the increment code
-    CRASH();
+    // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195808 Figure out which loop we're in, and run the increment code
+    notImplemented();
 }
 
 void FunctionDefinitionWriter::visit(AST::DoWhileLoop& doWhileLoop)
@@ -350,14 +350,14 @@ void FunctionDefinitionWriter::visit(AST::SwitchCase& switchCase)
     else
         m_stringBuilder.append("default:\n");
     checkErrorAndVisit(switchCase.block());
-    // FIXME: Figure out whether we need to break or fallthrough.
-    CRASH();
+    // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195812 Figure out whether we need to break or fallthrough.
+    notImplemented();
 }
 
 void FunctionDefinitionWriter::visit(AST::Trap&)
 {
-    // FIXME: Implement this
-    CRASH();
+    // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195811 Implement this
+    notImplemented();
 }
 
 void FunctionDefinitionWriter::visit(AST::VariableDeclarationsStatement& variableDeclarationsStatement)
@@ -447,35 +447,41 @@ void FunctionDefinitionWriter::visit(AST::Expression& expression)
 void FunctionDefinitionWriter::visit(AST::DotExpression&)
 {
     // This should be lowered already.
-    ASSERT_NOT_REACHED();
+    // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195788 Replace this with ASSERT_NOT_REACHED().
+    notImplemented();
+    m_stack.append("dummy");
 }
 
 void FunctionDefinitionWriter::visit(AST::IndexExpression&)
 {
     // This should be lowered already.
-    ASSERT_NOT_REACHED();
+    // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195788 Replace this with ASSERT_NOT_REACHED().
+    notImplemented();
+    m_stack.append("dummy");
 }
 
 void FunctionDefinitionWriter::visit(AST::PropertyAccessExpression&)
 {
-    ASSERT_NOT_REACHED();
+    // This should be lowered already.
+    // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195788 Replace this with ASSERT_NOT_REACHED().
+    notImplemented();
+    m_stack.append("dummy");
 }
 
 void FunctionDefinitionWriter::visit(AST::VariableDeclaration& variableDeclaration)
 {
     ASSERT(variableDeclaration.type());
-    if (variableDeclaration.initializer())
-        checkErrorAndVisit(*variableDeclaration.initializer());
-    else {
-        // FIXME: Zero-fill the variable.
-        CRASH();
-    }
-    // FIXME: Implement qualifiers.
     auto variableName = generateNextVariableName();
     auto addResult = m_variableMapping.add(&variableDeclaration, variableName);
     ASSERT_UNUSED(addResult, addResult.isNewEntry);
-    m_stringBuilder.append(makeString(m_typeNamer.mangledNameForType(*variableDeclaration.type()), ' ', variableName, " = ", m_stack.takeLast(), ";\n"));
-    m_stack.append(variableName);
+    // FIXME: Implement qualifiers.
+    if (variableDeclaration.initializer()) {
+        checkErrorAndVisit(*variableDeclaration.initializer());
+        m_stringBuilder.append(makeString(m_typeNamer.mangledNameForType(*variableDeclaration.type()), ' ', variableName, " = ", m_stack.takeLast(), ";\n"));
+    } else {
+        // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195771 Zero-fill the variable.
+        m_stringBuilder.append(makeString(m_typeNamer.mangledNameForType(*variableDeclaration.type()), ' ', variableName, ";\n"));
+    }
 }
 
 void FunctionDefinitionWriter::visit(AST::AssignmentExpression& assignmentExpression)
@@ -485,6 +491,7 @@ void FunctionDefinitionWriter::visit(AST::AssignmentExpression& assignmentExpres
     checkErrorAndVisit(assignmentExpression.right());
     auto rightName = m_stack.takeLast();
     m_stringBuilder.append(makeString(leftName, " = ", rightName, ";\n"));
+    m_stack.append(leftName);
 }
 
 void FunctionDefinitionWriter::visit(AST::CallExpression& callExpression)
@@ -650,24 +657,10 @@ public:
     {
     }
 
-    MappedBindGroups&& takeVertexMappedBindGroups()
-    {
-        ASSERT(m_vertexMappedBindGroups);
-        return WTFMove(*m_vertexMappedBindGroups);
-    }
-
-    MappedBindGroups&& takeFragmentMappedBindGroups()
-    {
-        ASSERT(m_fragmentMappedBindGroups);
-        return WTFMove(*m_fragmentMappedBindGroups);
-    }
-
 private:
     std::unique_ptr<EntryPointScaffolding> createEntryPointScaffolding(AST::FunctionDefinition&) override;
 
     MatchedRenderSemantics m_matchedSemantics;
-    Optional<MappedBindGroups> m_vertexMappedBindGroups;
-    Optional<MappedBindGroups> m_fragmentMappedBindGroups;
 };
 
 std::unique_ptr<EntryPointScaffolding> RenderFunctionDefinitionWriter::createEntryPointScaffolding(AST::FunctionDefinition& functionDefinition)
@@ -675,18 +668,10 @@ std::unique_ptr<EntryPointScaffolding> RenderFunctionDefinitionWriter::createEnt
     auto generateNextVariableName = [this]() -> String {
         return this->generateNextVariableName();
     };
-    if (&functionDefinition == m_matchedSemantics.vertexShader) {
-        auto result = std::make_unique<VertexEntryPointScaffolding>(functionDefinition, m_intrinsics, m_typeNamer, m_matchedSemantics.vertexShaderEntryPointItems, m_matchedSemantics.vertexShaderResourceMap, m_layout, WTFMove(generateNextVariableName), m_matchedSemantics.matchedVertexAttributes);
-        ASSERT(!m_vertexMappedBindGroups);
-        m_vertexMappedBindGroups = result->mappedBindGroups();
-        return result;
-    }
-    if (&functionDefinition == m_matchedSemantics.fragmentShader) {
-        auto result = std::make_unique<FragmentEntryPointScaffolding>(functionDefinition, m_intrinsics, m_typeNamer, m_matchedSemantics.fragmentShaderEntryPointItems, m_matchedSemantics.fragmentShaderResourceMap, m_layout, WTFMove(generateNextVariableName), m_matchedSemantics.matchedColorAttachments);
-        ASSERT(!m_fragmentMappedBindGroups);
-        m_fragmentMappedBindGroups = result->mappedBindGroups();
-        return result;
-    }
+    if (&functionDefinition == m_matchedSemantics.vertexShader)
+        return std::make_unique<VertexEntryPointScaffolding>(functionDefinition, m_intrinsics, m_typeNamer, m_matchedSemantics.vertexShaderEntryPointItems, m_matchedSemantics.vertexShaderResourceMap, m_layout, WTFMove(generateNextVariableName), m_matchedSemantics.matchedVertexAttributes);
+    if (&functionDefinition == m_matchedSemantics.fragmentShader)
+        return std::make_unique<FragmentEntryPointScaffolding>(functionDefinition, m_intrinsics, m_typeNamer, m_matchedSemantics.fragmentShaderEntryPointItems, m_matchedSemantics.fragmentShaderResourceMap, m_layout, WTFMove(generateNextVariableName), m_matchedSemantics.matchedColorAttachments);
     return nullptr;
 }
 
@@ -698,17 +683,10 @@ public:
     {
     }
 
-    MappedBindGroups&& takeMappedBindGroups()
-    {
-        ASSERT(m_mappedBindGroups);
-        return WTFMove(*m_mappedBindGroups);
-    }
-
 private:
     std::unique_ptr<EntryPointScaffolding> createEntryPointScaffolding(AST::FunctionDefinition&) override;
 
     MatchedComputeSemantics m_matchedSemantics;
-    Optional<MappedBindGroups> m_mappedBindGroups;
 };
 
 std::unique_ptr<EntryPointScaffolding> ComputeFunctionDefinitionWriter::createEntryPointScaffolding(AST::FunctionDefinition& functionDefinition)
@@ -716,12 +694,8 @@ std::unique_ptr<EntryPointScaffolding> ComputeFunctionDefinitionWriter::createEn
     auto generateNextVariableName = [this]() -> String {
         return this->generateNextVariableName();
     };
-    if (&functionDefinition == m_matchedSemantics.shader) {
-        auto result = std::make_unique<ComputeEntryPointScaffolding>(functionDefinition, m_intrinsics, m_typeNamer, m_matchedSemantics.entryPointItems, m_matchedSemantics.resourceMap, m_layout, WTFMove(generateNextVariableName));
-        ASSERT(!m_mappedBindGroups);
-        m_mappedBindGroups = result->mappedBindGroups();
-        return result;
-    }
+    if (&functionDefinition == m_matchedSemantics.shader)
+        return std::make_unique<ComputeEntryPointScaffolding>(functionDefinition, m_intrinsics, m_typeNamer, m_matchedSemantics.entryPointItems, m_matchedSemantics.resourceMap, m_layout, WTFMove(generateNextVariableName));
     return nullptr;
 }
 
@@ -766,6 +740,9 @@ RenderMetalFunctions metalFunctions(Program& program, TypeNamer& typeNamer, Matc
     StringBuilder stringBuilder;
     stringBuilder.append(sharedMetalFunctions.metalFunctions);
 
+    auto* vertexShaderEntryPoint = matchedSemantics.vertexShader;
+    auto* fragmentShaderEntryPoint = matchedSemantics.fragmentShader;
+
     RenderFunctionDefinitionWriter functionDefinitionWriter(program.intrinsics(), typeNamer, sharedMetalFunctions.functionMapping, WTFMove(matchedSemantics), layout);
     for (auto& nativeFunctionDeclaration : program.nativeFunctionDeclarations())
         functionDefinitionWriter.visit(nativeFunctionDeclaration);
@@ -775,8 +752,8 @@ RenderMetalFunctions metalFunctions(Program& program, TypeNamer& typeNamer, Matc
 
     RenderMetalFunctions result;
     result.metalSource = stringBuilder.toString();
-    result.vertexMappedBindGroups = functionDefinitionWriter.takeVertexMappedBindGroups();
-    result.fragmentMappedBindGroups = functionDefinitionWriter.takeFragmentMappedBindGroups();
+    result.mangledVertexEntryPointName = sharedMetalFunctions.functionMapping.get(vertexShaderEntryPoint);
+    result.mangledFragmentEntryPointName = sharedMetalFunctions.functionMapping.get(fragmentShaderEntryPoint);
     return result;
 }
 
@@ -787,6 +764,8 @@ ComputeMetalFunctions metalFunctions(Program& program, TypeNamer& typeNamer, Mat
     StringBuilder stringBuilder;
     stringBuilder.append(sharedMetalFunctions.metalFunctions);
 
+    auto* entryPoint = matchedSemantics.shader;
+
     ComputeFunctionDefinitionWriter functionDefinitionWriter(program.intrinsics(), typeNamer, sharedMetalFunctions.functionMapping, WTFMove(matchedSemantics), layout);
     for (auto& nativeFunctionDeclaration : program.nativeFunctionDeclarations())
         functionDefinitionWriter.visit(nativeFunctionDeclaration);
@@ -796,7 +775,7 @@ ComputeMetalFunctions metalFunctions(Program& program, TypeNamer& typeNamer, Mat
 
     ComputeMetalFunctions result;
     result.metalSource = stringBuilder.toString();
-    result.mappedBindGroups = functionDefinitionWriter.takeMappedBindGroups();
+    result.mangledEntryPointName = sharedMetalFunctions.functionMapping.get(entryPoint);
     return result;
 }
 
index 062821f..dfa46ac 100644 (file)
@@ -27,7 +27,6 @@
 
 #if ENABLE(WEBGPU)
 
-#include "WHLSLMappedBindings.h"
 #include "WHLSLSemanticMatcher.h"
 
 namespace WebCore {
@@ -42,14 +41,14 @@ class TypeNamer;
 
 struct RenderMetalFunctions {
     String metalSource;
-    MappedBindGroups vertexMappedBindGroups;
-    MappedBindGroups fragmentMappedBindGroups;
+    String mangledVertexEntryPointName;
+    String mangledFragmentEntryPointName;
 };
 RenderMetalFunctions metalFunctions(Program&, TypeNamer&, MatchedRenderSemantics&&, Layout&);
 
 struct ComputeMetalFunctions {
     String metalSource;
-    MappedBindGroups mappedBindGroups;
+    String mangledEntryPointName;
 };
 ComputeMetalFunctions metalFunctions(Program&, TypeNamer&, MatchedComputeSemantics&&, Layout&);
 
index 305f338..e9eb9de 100644 (file)
@@ -48,7 +48,7 @@ static String generateMetalCodeShared(String&& metalTypes, String&& metalFunctio
     stringBuilder.append("#include <metal_compute>\n");
     stringBuilder.append("#include <metal_texture>\n");
     stringBuilder.append("\n");
-    stringBuilder.append("using namespace metal;\n"); // FIXME: Probably should qualify all calls to built-in functions, instead of using this line.
+    stringBuilder.append("using namespace metal;\n");
     stringBuilder.append("\n");
 
     stringBuilder.append(WTFMove(metalTypes));
@@ -62,7 +62,7 @@ RenderMetalCode generateMetalCode(Program& program, MatchedRenderSemantics&& mat
     auto metalTypes = typeNamer.metalTypes();
     auto metalFunctions = Metal::metalFunctions(program, typeNamer, WTFMove(matchedSemantics), layout);
     auto metalCode = generateMetalCodeShared(WTFMove(metalTypes), WTFMove(metalFunctions.metalSource));
-    return { WTFMove(metalCode), WTFMove(metalFunctions.vertexMappedBindGroups), WTFMove(metalFunctions.fragmentMappedBindGroups) };
+    return { WTFMove(metalCode), WTFMove(metalFunctions.mangledVertexEntryPointName), WTFMove(metalFunctions.mangledFragmentEntryPointName) };
 }
 
 ComputeMetalCode generateMetalCode(Program& program, MatchedComputeSemantics&& matchedSemantics, Layout& layout)
@@ -71,7 +71,7 @@ ComputeMetalCode generateMetalCode(Program& program, MatchedComputeSemantics&& m
     auto metalTypes = typeNamer.metalTypes();
     auto metalFunctions = Metal::metalFunctions(program, typeNamer, WTFMove(matchedSemantics), layout);
     auto metalCode = generateMetalCodeShared(WTFMove(metalTypes), WTFMove(metalFunctions.metalSource));
-    return { WTFMove(metalCode), WTFMove(metalFunctions.mappedBindGroups) };
+    return { WTFMove(metalCode), WTFMove(metalFunctions.mangledEntryPointName) };
 }
 
 }
index a6b3d31..ea2b3ef 100644 (file)
@@ -27,7 +27,6 @@
 
 #if ENABLE(WEBGPU)
 
-#include "WHLSLMappedBindings.h"
 #include "WHLSLPipelineDescriptor.h"
 #include "WHLSLSemanticMatcher.h"
 #include <wtf/Variant.h>
@@ -43,15 +42,15 @@ namespace Metal {
 
 struct RenderMetalCode {
     String metalSource;
-    MappedBindGroups vertexMappedBindGroups;
-    MappedBindGroups fragmentMappedBindGroups;
+    String mangledVertexEntryPointName;
+    String mangledFragmentEntryPointName;
 };
 // Can't fail. Any failure checks need to be done earlier, in the backend-agnostic part of the compiler.
 RenderMetalCode generateMetalCode(Program&, MatchedRenderSemantics&& matchedSemantics, Layout&);
 
 struct ComputeMetalCode {
     String metalSource;
-    MappedBindGroups bindGroups;
+    String mangledEntryPointName;
 };
 // Can't fail. Any failure checks need to be done earlier, in the backend-agnostic part of the compiler.
 ComputeMetalCode generateMetalCode(Program&, MatchedComputeSemantics&& matchedSemantics, Layout&);
index b4d3706..6972e98 100644 (file)
 
 #if ENABLE(WEBGPU)
 
+#include "NotImplemented.h"
 #include "WHLSLAddressSpace.h"
+#include "WHLSLArrayType.h"
+#include "WHLSLInferTypes.h"
+#include "WHLSLIntrinsics.h"
 #include "WHLSLNamedType.h"
 #include "WHLSLNativeFunctionDeclaration.h"
 #include "WHLSLNativeTypeDeclaration.h"
@@ -44,15 +48,6 @@ namespace WHLSL {
 
 namespace Metal {
 
-static String getNativeName(AST::UnnamedType& unnamedType, TypeNamer& typeNamer)
-{
-    ASSERT(is<AST::NamedType>(unnamedType.unifyNode()));
-    auto& namedType = downcast<AST::NamedType>(unnamedType.unifyNode());
-    ASSERT(is<AST::NativeTypeDeclaration>(namedType));
-    auto& nativeTypeDeclaration = downcast<AST::NativeTypeDeclaration>(namedType);
-    return typeNamer.mangledNameForType(nativeTypeDeclaration);
-}
-
 static String mapFunctionName(String& functionName)
 {
     if (functionName == "ddx")
@@ -101,22 +96,70 @@ static String atomicName(String input)
         return "fetch_xor"_str;
 }
 
-String writeNativeFunction(AST::NativeFunctionDeclaration& nativeFunctionDeclaration, String& outputFunctionName, TypeNamer& typeNamer)
+String writeNativeFunction(AST::NativeFunctionDeclaration& nativeFunctionDeclaration, String& outputFunctionName, Intrinsics& intrinsics, TypeNamer& typeNamer)
 {
     StringBuilder stringBuilder;
     if (nativeFunctionDeclaration.isCast()) {
+        auto metalReturnName = typeNamer.mangledNameForType(nativeFunctionDeclaration.type());
+        if (!nativeFunctionDeclaration.parameters().size()) {
+            stringBuilder.append(makeString(metalReturnName, ' ', outputFunctionName, "() {\n"));
+            stringBuilder.append(makeString("    ", metalReturnName, " x;\n"));
+            // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195771 Zero-fill
+            stringBuilder.append("    return x;\n");
+            stringBuilder.append("}\n");
+            return stringBuilder.toString();
+        }
+
+        ASSERT(nativeFunctionDeclaration.parameters().size() == 1);
+        auto metalParameterName = typeNamer.mangledNameForType(*nativeFunctionDeclaration.parameters()[0].type());
+        auto& parameterType = nativeFunctionDeclaration.parameters()[0].type()->unifyNode();
+        if (is<AST::NamedType>(parameterType)) {
+            auto& parameterNamedType = downcast<AST::NamedType>(parameterType);
+            if (is<AST::NativeTypeDeclaration>(parameterNamedType)) {
+                auto& parameterNativeTypeDeclaration = downcast<AST::NativeTypeDeclaration>(parameterNamedType);
+                if (parameterNativeTypeDeclaration.isAtomic()) {
+                    stringBuilder.append(makeString(metalReturnName, ' ', outputFunctionName, '(', metalParameterName, " x) {\n"));
+                    stringBuilder.append("    return atomic_load_explicit(&x, memory_order_relaxed);\n");
+                    stringBuilder.append("}\n");
+                    return stringBuilder.toString();
+                }
+            }
+        }
+
+        stringBuilder.append(makeString(metalReturnName, ' ', outputFunctionName, '(', metalParameterName, " x) {\n"));
+        stringBuilder.append(makeString("    return static_cast<", metalReturnName, ">(x);\n"));
+        stringBuilder.append("}\n");
+        return stringBuilder.toString();
+    }
+
+    if (nativeFunctionDeclaration.name() == "operator.value") {
         ASSERT(nativeFunctionDeclaration.parameters().size() == 1);
-        auto metalParameterName = getNativeName(*nativeFunctionDeclaration.parameters()[0].type(), typeNamer);
-        auto metalReturnName = getNativeName(nativeFunctionDeclaration.type(), typeNamer);
-        if (metalParameterName != "atomic_int"_str && metalParameterName != "atomic_uint"_str) {
-            stringBuilder.append(makeString(metalReturnName, ' ', outputFunctionName, '(', metalParameterName, "x) {\n"));
-            stringBuilder.append(makeString("    return static_cast<", metalReturnName, ">(x);\n"));
+        auto metalParameterName = typeNamer.mangledNameForType(*nativeFunctionDeclaration.parameters()[0].type());
+        auto metalReturnName = typeNamer.mangledNameForType(nativeFunctionDeclaration.type());
+        stringBuilder.append(makeString(metalReturnName, ' ', outputFunctionName, '(', metalParameterName, " x) {\n"));
+        stringBuilder.append(makeString("    return static_cast<", metalReturnName, ">(x);\n"));
+        stringBuilder.append("}\n");
+        return stringBuilder.toString();
+    }
+
+    if (nativeFunctionDeclaration.name() == "operator.length") {
+        ASSERT_UNUSED(intrinsics, matches(nativeFunctionDeclaration.type(), intrinsics.uintType()));
+        ASSERT(nativeFunctionDeclaration.parameters().size() == 1);
+        auto metalParameterName = typeNamer.mangledNameForType(*nativeFunctionDeclaration.parameters()[0].type());
+        auto& parameterType = nativeFunctionDeclaration.parameters()[0].type()->unifyNode();
+        ASSERT(is<AST::UnnamedType>(parameterType));
+        auto& unnamedParameterType = downcast<AST::UnnamedType>(parameterType);
+        if (is<AST::ArrayType>(unnamedParameterType)) {
+            auto& arrayParameterType = downcast<AST::ArrayType>(unnamedParameterType);
+            stringBuilder.append(makeString("uint ", outputFunctionName, '(', metalParameterName, " v) {\n"));
+            stringBuilder.append(makeString("    return ", arrayParameterType.numElements(), "u;\n"));
             stringBuilder.append("}\n");
             return stringBuilder.toString();
         }
 
-        stringBuilder.append(makeString(metalReturnName, ' ', outputFunctionName, '(', metalParameterName, "x) {\n"));
-        stringBuilder.append("    return atomic_load_explicit(&x, memory_order_relaxed);\n");
+        ASSERT(is<AST::ArrayReferenceType>(unnamedParameterType));
+        stringBuilder.append(makeString("uint ", outputFunctionName, '(', metalParameterName, " v) {\n"));
+        stringBuilder.append(makeString("    return v.length;\n"));
         stringBuilder.append("}\n");
         return stringBuilder.toString();
     }
@@ -124,12 +167,12 @@ String writeNativeFunction(AST::NativeFunctionDeclaration& nativeFunctionDeclara
     if (nativeFunctionDeclaration.name().startsWith("operator."_str)) {
         if (nativeFunctionDeclaration.name().endsWith("=")) {
             ASSERT(nativeFunctionDeclaration.parameters().size() == 2);
-            auto metalParameter1Name = getNativeName(*nativeFunctionDeclaration.parameters()[0].type(), typeNamer);
-            auto metalParameter2Name = getNativeName(*nativeFunctionDeclaration.parameters()[1].type(), typeNamer);
-            auto metalReturnName = getNativeName(nativeFunctionDeclaration.type(), typeNamer);
+            auto metalParameter1Name = typeNamer.mangledNameForType(*nativeFunctionDeclaration.parameters()[0].type());
+            auto metalParameter2Name = typeNamer.mangledNameForType(*nativeFunctionDeclaration.parameters()[1].type());
+            auto metalReturnName = typeNamer.mangledNameForType(nativeFunctionDeclaration.type());
             auto fieldName = nativeFunctionDeclaration.name().substring("operator."_str.length());
             fieldName = fieldName.substring(0, fieldName.length() - 1);
-            stringBuilder.append(makeString(metalReturnName, ' ', outputFunctionName, '(', metalParameter1Name, "v, ", metalParameter2Name, " n) {\n"));
+            stringBuilder.append(makeString(metalReturnName, ' ', outputFunctionName, '(', metalParameter1Name, " v, ", metalParameter2Name, " n) {\n"));
             stringBuilder.append(makeString("    v.", fieldName, " = n;\n"));
             stringBuilder.append(makeString("    return v;\n"));
             stringBuilder.append("}\n");
@@ -137,33 +180,56 @@ String writeNativeFunction(AST::NativeFunctionDeclaration& nativeFunctionDeclara
         }
 
         ASSERT(nativeFunctionDeclaration.parameters().size() == 1);
-        auto metalParameterName = getNativeName(*nativeFunctionDeclaration.parameters()[0].type(), typeNamer);
-        auto metalReturnName = getNativeName(nativeFunctionDeclaration.type(), typeNamer);
-        stringBuilder.append(makeString(metalReturnName, ' ', outputFunctionName, '(', metalParameterName, "v) {\n"));
-        stringBuilder.append(makeString("    return v.", nativeFunctionDeclaration.name().substring("operator."_str.length()), ";\n"));
+        auto metalParameterName = typeNamer.mangledNameForType(*nativeFunctionDeclaration.parameters()[0].type());
+        auto metalReturnName = typeNamer.mangledNameForType(nativeFunctionDeclaration.type());
+        auto fieldName = nativeFunctionDeclaration.name().substring("operator."_str.length());
+        stringBuilder.append(makeString(metalReturnName, ' ', outputFunctionName, '(', metalParameterName, " v) {\n"));
+        stringBuilder.append(makeString("    return v.", fieldName, ";\n"));
         stringBuilder.append("}\n");
         return stringBuilder.toString();
+    }
 
+    if (nativeFunctionDeclaration.name().startsWith("operator&."_str)) {
+        ASSERT(nativeFunctionDeclaration.parameters().size() == 1);
+        auto metalParameterName = typeNamer.mangledNameForType(*nativeFunctionDeclaration.parameters()[0].type());
+        auto metalReturnName = typeNamer.mangledNameForType(nativeFunctionDeclaration.type());
+        auto fieldName = nativeFunctionDeclaration.name().substring("operator&."_str.length());
+        stringBuilder.append(makeString(metalReturnName, ' ', outputFunctionName, '(', metalParameterName, " v) {\n"));
+        stringBuilder.append(makeString("    return &(v->", fieldName, ");\n"));
+        stringBuilder.append("}\n");
+        return stringBuilder.toString();
     }
 
     if (nativeFunctionDeclaration.name() == "operator[]") {
         ASSERT(nativeFunctionDeclaration.parameters().size() == 2);
-        auto metalParameter1Name = getNativeName(*nativeFunctionDeclaration.parameters()[0].type(), typeNamer);
-        auto metalParameter2Name = getNativeName(*nativeFunctionDeclaration.parameters()[1].type(), typeNamer);
-        auto metalReturnName = getNativeName(nativeFunctionDeclaration.type(), typeNamer);
-        stringBuilder.append(makeString(metalReturnName, ' ', outputFunctionName, '(', metalParameter1Name, "m, ", metalParameter2Name, " i) {\n"));
+        auto metalParameter1Name = typeNamer.mangledNameForType(*nativeFunctionDeclaration.parameters()[0].type());
+        auto metalParameter2Name = typeNamer.mangledNameForType(*nativeFunctionDeclaration.parameters()[1].type());
+        auto metalReturnName = typeNamer.mangledNameForType(nativeFunctionDeclaration.type());
+        stringBuilder.append(makeString(metalReturnName, ' ', outputFunctionName, '(', metalParameter1Name, " m, ", metalParameter2Name, " i) {\n"));
         stringBuilder.append(makeString("    return m[i];\n"));
         stringBuilder.append("}\n");
         return stringBuilder.toString();
     }
 
+    if (nativeFunctionDeclaration.name() == "operator&[]") {
+        ASSERT(nativeFunctionDeclaration.parameters().size() == 2);
+        auto metalParameter1Name = typeNamer.mangledNameForType(*nativeFunctionDeclaration.parameters()[0].type());
+        auto metalParameter2Name = typeNamer.mangledNameForType(*nativeFunctionDeclaration.parameters()[0].type());
+        auto metalReturnName = typeNamer.mangledNameForType(nativeFunctionDeclaration.type());
+        auto fieldName = nativeFunctionDeclaration.name().substring("operator&[]."_str.length());
+        stringBuilder.append(makeString(metalReturnName, ' ', outputFunctionName, '(', metalParameter1Name, " v, ", metalParameter2Name, " n) {\n"));
+        stringBuilder.append(makeString("    return &(v.pointer[n]);\n"));
+        stringBuilder.append("}\n");
+        return stringBuilder.toString();
+    }
+
     if (nativeFunctionDeclaration.name() == "operator[]=") {
         ASSERT(nativeFunctionDeclaration.parameters().size() == 3);
-        auto metalParameter1Name = getNativeName(*nativeFunctionDeclaration.parameters()[0].type(), typeNamer);
-        auto metalParameter2Name = getNativeName(*nativeFunctionDeclaration.parameters()[1].type(), typeNamer);
-        auto metalParameter3Name = getNativeName(*nativeFunctionDeclaration.parameters()[2].type(), typeNamer);
-        auto metalReturnName = getNativeName(nativeFunctionDeclaration.type(), typeNamer);
-        stringBuilder.append(makeString(metalReturnName, ' ', outputFunctionName, '(', metalParameter1Name, "m, ", metalParameter2Name, " i, ", metalParameter3Name, " v) {\n"));
+        auto metalParameter1Name = typeNamer.mangledNameForType(*nativeFunctionDeclaration.parameters()[0].type());
+        auto metalParameter2Name = typeNamer.mangledNameForType(*nativeFunctionDeclaration.parameters()[1].type());
+        auto metalParameter3Name = typeNamer.mangledNameForType(*nativeFunctionDeclaration.parameters()[2].type());
+        auto metalReturnName = typeNamer.mangledNameForType(nativeFunctionDeclaration.type());
+        stringBuilder.append(makeString(metalReturnName, ' ', outputFunctionName, '(', metalParameter1Name, " m, ", metalParameter2Name, " i, ", metalParameter3Name, " v) {\n"));
         stringBuilder.append(makeString("    m[i] = v;\n"));
         stringBuilder.append(makeString("    return m;\n"));
         stringBuilder.append("}\n");
@@ -173,9 +239,9 @@ String writeNativeFunction(AST::NativeFunctionDeclaration& nativeFunctionDeclara
     if (nativeFunctionDeclaration.isOperator()) {
         if (nativeFunctionDeclaration.parameters().size() == 1) {
             auto operatorName = nativeFunctionDeclaration.name().substring("operator"_str.length());
-            auto metalParameterName = getNativeName(*nativeFunctionDeclaration.parameters()[0].type(), typeNamer);
-            auto metalReturnName = getNativeName(nativeFunctionDeclaration.type(), typeNamer);
-            stringBuilder.append(makeString(metalReturnName, ' ', outputFunctionName, '(', metalParameterName, "x) {\n"));
+            auto metalParameterName = typeNamer.mangledNameForType(*nativeFunctionDeclaration.parameters()[0].type());
+            auto metalReturnName = typeNamer.mangledNameForType(nativeFunctionDeclaration.type());
+            stringBuilder.append(makeString(metalReturnName, ' ', outputFunctionName, '(', metalParameterName, " x) {\n"));
             stringBuilder.append(makeString("    return ", operatorName, "x;\n"));
             stringBuilder.append("}\n");
             return stringBuilder.toString();
@@ -183,10 +249,10 @@ String writeNativeFunction(AST::NativeFunctionDeclaration& nativeFunctionDeclara
 
         ASSERT(nativeFunctionDeclaration.parameters().size() == 2);
         auto operatorName = nativeFunctionDeclaration.name().substring("operator"_str.length());
-        auto metalParameter1Name = getNativeName(*nativeFunctionDeclaration.parameters()[0].type(), typeNamer);
-        auto metalParameter2Name = getNativeName(*nativeFunctionDeclaration.parameters()[1].type(), typeNamer);
-        auto metalReturnName = getNativeName(nativeFunctionDeclaration.type(), typeNamer);
-        stringBuilder.append(makeString(metalReturnName, ' ', outputFunctionName, '(', metalParameter1Name, "x, ", metalParameter2Name, " y) {\n"));
+        auto metalParameter1Name = typeNamer.mangledNameForType(*nativeFunctionDeclaration.parameters()[0].type());
+        auto metalParameter2Name = typeNamer.mangledNameForType(*nativeFunctionDeclaration.parameters()[1].type());
+        auto metalReturnName = typeNamer.mangledNameForType(nativeFunctionDeclaration.type());
+        stringBuilder.append(makeString(metalReturnName, ' ', outputFunctionName, '(', metalParameter1Name, " x, ", metalParameter2Name, " y) {\n"));
         stringBuilder.append(makeString("    return x ", operatorName, " y;\n"));
         stringBuilder.append("}\n");
         return stringBuilder.toString();
@@ -217,9 +283,9 @@ String writeNativeFunction(AST::NativeFunctionDeclaration& nativeFunctionDeclara
         || nativeFunctionDeclaration.name() == "asuint"
         || nativeFunctionDeclaration.name() == "asfloat") {
         ASSERT(nativeFunctionDeclaration.parameters().size() == 1);
-        auto metalParameterName = getNativeName(*nativeFunctionDeclaration.parameters()[0].type(), typeNamer);
-        auto metalReturnName = getNativeName(nativeFunctionDeclaration.type(), typeNamer);
-        stringBuilder.append(makeString(metalReturnName, ' ', outputFunctionName, '(', metalParameterName, "x) {\n"));
+        auto metalParameterName = typeNamer.mangledNameForType(*nativeFunctionDeclaration.parameters()[0].type());
+        auto metalReturnName = typeNamer.mangledNameForType(nativeFunctionDeclaration.type());
+        stringBuilder.append(makeString(metalReturnName, ' ', outputFunctionName, '(', metalParameterName, " x) {\n"));
         stringBuilder.append(makeString("    return ", mapFunctionName(nativeFunctionDeclaration.name()), "(x);\n"));
         stringBuilder.append("}\n");
         return stringBuilder.toString();
@@ -227,18 +293,18 @@ String writeNativeFunction(AST::NativeFunctionDeclaration& nativeFunctionDeclara
 
     if (nativeFunctionDeclaration.name() == "pow" || nativeFunctionDeclaration.name() == "atan2") {
         ASSERT(nativeFunctionDeclaration.parameters().size() == 2);
-        auto metalParameter1Name = getNativeName(*nativeFunctionDeclaration.parameters()[0].type(), typeNamer);
-        auto metalParameter2Name = getNativeName(*nativeFunctionDeclaration.parameters()[1].type(), typeNamer);
-        auto metalReturnName = getNativeName(nativeFunctionDeclaration.type(), typeNamer);
-        stringBuilder.append(makeString(metalReturnName, ' ', outputFunctionName, '(', metalParameter1Name, "x, ", metalParameter2Name, " y) {\n"));
+        auto metalParameter1Name = typeNamer.mangledNameForType(*nativeFunctionDeclaration.parameters()[0].type());
+        auto metalParameter2Name = typeNamer.mangledNameForType(*nativeFunctionDeclaration.parameters()[1].type());
+        auto metalReturnName = typeNamer.mangledNameForType(nativeFunctionDeclaration.type());
+        stringBuilder.append(makeString(metalReturnName, ' ', outputFunctionName, '(', metalParameter1Name, " x, ", metalParameter2Name, " y) {\n"));
         stringBuilder.append(makeString("    return ", nativeFunctionDeclaration.name(), "(x, y);\n"));
         stringBuilder.append("}\n");
         return stringBuilder.toString();
     }
 
     if (nativeFunctionDeclaration.name() == "f16tof32" || nativeFunctionDeclaration.name() == "f32tof16") {
-        // FIXME: Implement this
-        CRASH();
+        // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195813 Implement this
+        notImplemented();
     }
 
     if (nativeFunctionDeclaration.name() == "AllMemoryBarrierWithGroupSync") {
@@ -273,13 +339,13 @@ String writeNativeFunction(AST::NativeFunctionDeclaration& nativeFunctionDeclara
             ASSERT(is<AST::PointerType>(*nativeFunctionDeclaration.parameters()[0].type()));
             auto& firstArgumentPointer = downcast<AST::PointerType>(*nativeFunctionDeclaration.parameters()[0].type());
             auto firstArgumentAddressSpace = firstArgumentPointer.addressSpace();
-            auto firstArgumentPointee = getNativeName(firstArgumentPointer.elementType(), typeNamer);
-            auto secondArgument = getNativeName(*nativeFunctionDeclaration.parameters()[1].type(), typeNamer);
-            auto thirdArgument = getNativeName(*nativeFunctionDeclaration.parameters()[2].type(), typeNamer);
+            auto firstArgumentPointee = typeNamer.mangledNameForType(firstArgumentPointer.elementType());
+            auto secondArgument = typeNamer.mangledNameForType(*nativeFunctionDeclaration.parameters()[1].type());
+            auto thirdArgument = typeNamer.mangledNameForType(*nativeFunctionDeclaration.parameters()[2].type());
             ASSERT(is<AST::PointerType>(*nativeFunctionDeclaration.parameters()[3].type()));
             auto& fourthArgumentPointer = downcast<AST::PointerType>(*nativeFunctionDeclaration.parameters()[3].type());
             auto fourthArgumentAddressSpace = fourthArgumentPointer.addressSpace();
-            auto fourthArgumentPointee = getNativeName(fourthArgumentPointer.elementType(), typeNamer);
+            auto fourthArgumentPointee = typeNamer.mangledNameForType(fourthArgumentPointer.elementType());
             stringBuilder.append(makeString("void ", outputFunctionName, '(', convertAddressSpace(firstArgumentAddressSpace), ' ', firstArgumentPointee, "* object, ", secondArgument, " compare, ", thirdArgument, " desired, ", convertAddressSpace(fourthArgumentAddressSpace), ' ', fourthArgumentPointee, "* out) {\n"));
             stringBuilder.append("    atomic_compare_exchange_weak_explicit(object, &compare, desired, memory_order_relaxed);\n");
             stringBuilder.append("    *out = compare;\n");
@@ -291,12 +357,12 @@ String writeNativeFunction(AST::NativeFunctionDeclaration& nativeFunctionDeclara
         ASSERT(is<AST::PointerType>(*nativeFunctionDeclaration.parameters()[0].type()));
         auto& firstArgumentPointer = downcast<AST::PointerType>(*nativeFunctionDeclaration.parameters()[0].type());
         auto firstArgumentAddressSpace = firstArgumentPointer.addressSpace();
-        auto firstArgumentPointee = getNativeName(firstArgumentPointer.elementType(), typeNamer);
-        auto secondArgument = getNativeName(*nativeFunctionDeclaration.parameters()[1].type(), typeNamer);
+        auto firstArgumentPointee = typeNamer.mangledNameForType(firstArgumentPointer.elementType());
+        auto secondArgument = typeNamer.mangledNameForType(*nativeFunctionDeclaration.parameters()[1].type());
         ASSERT(is<AST::PointerType>(*nativeFunctionDeclaration.parameters()[2].type()));
         auto& thirdArgumentPointer = downcast<AST::PointerType>(*nativeFunctionDeclaration.parameters()[2].type());
         auto thirdArgumentAddressSpace = thirdArgumentPointer.addressSpace();
-        auto thirdArgumentPointee = getNativeName(thirdArgumentPointer.elementType(), typeNamer);
+        auto thirdArgumentPointee = typeNamer.mangledNameForType(thirdArgumentPointer.elementType());
         auto name = atomicName(nativeFunctionDeclaration.name().substring("Interlocked"_str.length()));
         stringBuilder.append(makeString("void ", outputFunctionName, '(', convertAddressSpace(firstArgumentAddressSpace), ' ', firstArgumentPointee, "* object, ", secondArgument, " operand, ", convertAddressSpace(thirdArgumentAddressSpace), ' ', thirdArgumentPointee, "* out) {\n"));
         stringBuilder.append(makeString("    *out = atomic_fetch_", name, "_explicit(object, operand, memory_order_relaxed);\n"));
@@ -305,83 +371,83 @@ String writeNativeFunction(AST::NativeFunctionDeclaration& nativeFunctionDeclara
     }
 
     if (nativeFunctionDeclaration.name() == "Sample") {
-        // FIXME: Implement this.
-        CRASH();
+        // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195813 Implement this
+        notImplemented();
     }
 
     if (nativeFunctionDeclaration.name() == "Load") {
-        // FIXME: Implement this.
-        CRASH();
+        // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195813 Implement this
+        notImplemented();
     }
 
     if (nativeFunctionDeclaration.name() == "GetDimensions") {
-        // FIXME: Implement this.
-        CRASH();
+        // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195813 Implement this
+        notImplemented();
     }
 
     if (nativeFunctionDeclaration.name() == "SampleBias") {
-        // FIXME: Implement this.
-        CRASH();
+        // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195813 Implement this
+        notImplemented();
     }
 
     if (nativeFunctionDeclaration.name() == "SampleGrad") {
-        // FIXME: Implement this.
-        CRASH();
+        // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195813 Implement this
+        notImplemented();
     }
 
     if (nativeFunctionDeclaration.name() == "SampleLevel") {
-        // FIXME: Implement this.
-        CRASH();
+        // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195813 Implement this
+        notImplemented();
     }
 
     if (nativeFunctionDeclaration.name() == "Gather") {
-        // FIXME: Implement this.
-        CRASH();
+        // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195813 Implement this
+        notImplemented();
     }
 
     if (nativeFunctionDeclaration.name() == "GatherRed") {
-        // FIXME: Implement this.
-        CRASH();
+        // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195813 Implement this
+        notImplemented();
     }
 
     if (nativeFunctionDeclaration.name() == "SampleCmp") {
-        // FIXME: Implement this.
-        CRASH();
+        // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195813 Implement this
+        notImplemented();
     }
 
     if (nativeFunctionDeclaration.name() == "SampleCmpLevelZero") {
-        // FIXME: Implement this.
-        CRASH();
+        // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195813 Implement this
+        notImplemented();
     }
 
     if (nativeFunctionDeclaration.name() == "Store") {
-        // FIXME: Implement this.
-        CRASH();
+        // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195813 Implement this
+        notImplemented();
     }
 
     if (nativeFunctionDeclaration.name() == "GatherAlpha") {
-        // FIXME: Implement this.
-        CRASH();
+        // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195813 Implement this
+        notImplemented();
     }
 
     if (nativeFunctionDeclaration.name() == "GatherBlue") {
-        // FIXME: Implement this.
-        CRASH();
+        // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195813 Implement this
+        notImplemented();
     }
 
     if (nativeFunctionDeclaration.name() == "GatherCmp") {
-        // FIXME: Implement this.
-        CRASH();
+        // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195813 Implement this
+        notImplemented();
     }
 
     if (nativeFunctionDeclaration.name() == "GatherCmpRed") {
-        // FIXME: Implement this.
-        CRASH();
+        // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195813 Implement this
+        notImplemented();
     }
 
     if (nativeFunctionDeclaration.name() == "GatherGreen") {
-        // FIXME: Implement this.
-        CRASH();
+        // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195813 Implement this
+        notImplemented();
     }
 
     // FIXME: Add all the functions that the compiler generated.
index 2e9bd4e..156245a 100644 (file)
@@ -43,7 +43,7 @@ namespace Metal {
 
 class TypeNamer;
 
-String writeNativeFunction(AST::NativeFunctionDeclaration&, String& outputFunctionName, TypeNamer&);
+String writeNativeFunction(AST::NativeFunctionDeclaration&, String& outputFunctionName, Intrinsics&, TypeNamer&);
 
 }
 
index a7c9bf5..45e60fd 100644 (file)
@@ -96,7 +96,7 @@ String writeNativeType(AST::NativeTypeDeclaration& nativeTypeDeclaration)
             return "float";
         })();
         ASSERT(WTF::holds_alternative<AST::ConstantExpression>(nativeTypeDeclaration.typeArguments()[1]));
-        auto& constantExpression = WTF::get<AST::ConstantExpression>(nativeTypeDeclaration.typeArguments()[0]);
+        auto& constantExpression = WTF::get<AST::ConstantExpression>(nativeTypeDeclaration.typeArguments()[1]);
         auto& integerLiteral = constantExpression.integerLiteral();
         auto suffix = ([&]() -> String {
             switch (integerLiteral.value()) {
@@ -127,7 +127,7 @@ String writeNativeType(AST::NativeTypeDeclaration& nativeTypeDeclaration)
             return "float";
         })();
         ASSERT(WTF::holds_alternative<AST::ConstantExpression>(nativeTypeDeclaration.typeArguments()[1]));
-        auto& constantExpression1 = WTF::get<AST::ConstantExpression>(nativeTypeDeclaration.typeArguments()[0]);
+        auto& constantExpression1 = WTF::get<AST::ConstantExpression>(nativeTypeDeclaration.typeArguments()[1]);
         auto& integerLiteral1 = constantExpression1.integerLiteral();
         auto middle = ([&]() -> String {
             switch (integerLiteral1.value()) {
@@ -141,7 +141,7 @@ String writeNativeType(AST::NativeTypeDeclaration& nativeTypeDeclaration)
             }
         })();
         ASSERT(WTF::holds_alternative<AST::ConstantExpression>(nativeTypeDeclaration.typeArguments()[2]));
-        auto& constantExpression2 = WTF::get<AST::ConstantExpression>(nativeTypeDeclaration.typeArguments()[0]);
+        auto& constantExpression2 = WTF::get<AST::ConstantExpression>(nativeTypeDeclaration.typeArguments()[2]);
         auto& integerLiteral2 = constantExpression2.integerLiteral();
         auto suffix = ([&]() -> String {
             switch (integerLiteral2.value()) {
index 6ff36f9..1b253d0 100644 (file)
@@ -31,6 +31,7 @@
 #include "WHLSLAddressSpace.h"
 #include "WHLSLArrayReferenceType.h"
 #include "WHLSLArrayType.h"
+#include "WHLSLCallExpression.h"
 #include "WHLSLEnumerationDefinition.h"
 #include "WHLSLEnumerationMember.h"
 #include "WHLSLNativeTypeDeclaration.h"
@@ -174,27 +175,6 @@ TypeNamer::TypeNamer(Program& program)
 
 TypeNamer::~TypeNamer() = default;
 
-void TypeNamer::visit(AST::UnnamedType& unnamedType)
-{
-    insert(unnamedType, m_trie);
-}
-
-void TypeNamer::visit(AST::EnumerationDefinition& enumerationDefinition)
-{
-    auto addResult = m_namedTypeMapping.add(&enumerationDefinition, generateNextTypeName());
-    ASSERT_UNUSED(addResult, addResult.isNewEntry);
-    for (auto& enumerationMember : enumerationDefinition.enumerationMembers()) {
-        auto addResult = m_enumerationMemberMapping.add(&static_cast<AST::EnumerationMember&>(enumerationMember), generateNextEnumerationMemberName());
-        ASSERT_UNUSED(addResult, addResult.isNewEntry);
-    }
-    Visitor::visit(enumerationDefinition);
-}
-
-void TypeNamer::visit(AST::NativeTypeDeclaration&)
-{
-    // Native type declarations already have names, and are already declared in Metal.
-}
-
 static Vector<UniqueRef<BaseTypeNameNode>>::iterator findInVector(AST::UnnamedType& unnamedType, Vector<UniqueRef<BaseTypeNameNode>>& types)
 {
     return std::find_if(types.begin(), types.end(), [&](BaseTypeNameNode& baseTypeNameNode) -> bool {
@@ -230,6 +210,39 @@ static BaseTypeNameNode& find(AST::UnnamedType& unnamedType, Vector<UniqueRef<Ba
     return *iterator;
 }
 
+void TypeNamer::visit(AST::UnnamedType& unnamedType)
+{
+    insert(unnamedType, m_trie);
+}
+
+void TypeNamer::visit(AST::EnumerationDefinition& enumerationDefinition)
+{
+    {
+        auto addResult = m_namedTypeMapping.add(&enumerationDefinition, generateNextTypeName());
+        ASSERT_UNUSED(addResult, addResult.isNewEntry);
+    }
+
+    for (auto& enumerationMember : enumerationDefinition.enumerationMembers()) {
+        auto addResult = m_enumerationMemberMapping.add(&static_cast<AST::EnumerationMember&>(enumerationMember), generateNextEnumerationMemberName());
+        ASSERT_UNUSED(addResult, addResult.isNewEntry);
+    }
+
+    Visitor::visit(enumerationDefinition);
+
+    {
+        Vector<std::reference_wrapper<BaseTypeNameNode>> neighbors = { find(enumerationDefinition.type(), m_trie) };
+        auto addResult = m_dependencyGraph.add(&enumerationDefinition, WTFMove(neighbors));
+        ASSERT_UNUSED(addResult, addResult.isNewEntry);
+    }
+}
+
+void TypeNamer::visit(AST::NativeTypeDeclaration& nativeTypeDeclaration)
+{
+    // Native type declarations already have names, and are already declared in Metal.
+    auto addResult = m_dependencyGraph.add(&nativeTypeDeclaration, Vector<std::reference_wrapper<BaseTypeNameNode>>());
+    ASSERT_UNUSED(addResult, addResult.isNewEntry);
+}
+
 void TypeNamer::visit(AST::StructureDefinition& structureDefinition)
 {
     {
@@ -263,6 +276,19 @@ void TypeNamer::visit(AST::TypeDefinition& typeDefinition)
     }
 }
 
+void TypeNamer::visit(AST::Expression& expression)
+{
+    ASSERT(expression.resolvedType());
+    insert(*expression.resolvedType(), m_trie);
+    Visitor::visit(expression);
+}
+
+void TypeNamer::visit(AST::CallExpression& callExpression)
+{
+    for (auto& argument : callExpression.arguments())
+        checkErrorAndVisit(argument);
+}
+
 String TypeNamer::mangledNameForType(AST::NativeTypeDeclaration& nativeTypeDeclaration)
 {
     return writeNativeType(nativeTypeDeclaration);
@@ -328,34 +354,20 @@ size_t TypeNamer::insert(AST::UnnamedType& unnamedType, Vector<UniqueRef<BaseTyp
 
 class MetalTypeDeclarationWriter : public Visitor {
 public:
-    MetalTypeDeclarationWriter(std::function<String(AST::NamedType&)>&& mangledNameForNamedType, std::function<String(AST::UnnamedType&)>&& mangledNameForUnnamedType, std::function<String(AST::EnumerationMember&)>&& mangledNameForEnumerationMember)
+    MetalTypeDeclarationWriter(std::function<String(AST::NamedType&)>&& mangledNameForNamedType)
         : m_mangledNameForNamedType(WTFMove(mangledNameForNamedType))
-        , m_mangledNameForUnnamedType(WTFMove(mangledNameForUnnamedType))
-        , m_mangledNameForEnumerationMember(WTFMove(mangledNameForEnumerationMember))
     {
     }
 
     String toString() { return m_stringBuilder.toString(); }
 
 private:
-    void visit(AST::EnumerationDefinition& enumerationDefinition)
-    {
-        auto& baseType = enumerationDefinition.type().unifyNode();
-        ASSERT(is<AST::NamedType>(baseType));
-        m_stringBuilder.append(makeString("enum class ", m_mangledNameForNamedType(enumerationDefinition), " : ", m_mangledNameForNamedType(downcast<AST::NamedType>(baseType)), " {\n"));
-        for (auto& enumerationMember : enumerationDefinition.enumerationMembers())
-            m_stringBuilder.append(makeString("    ", m_mangledNameForEnumerationMember(enumerationMember), ",\n"));
-        m_stringBuilder.append("};\n");
-    }
-
     void visit(AST::StructureDefinition& structureDefinition)
     {
         m_stringBuilder.append(makeString("struct ", m_mangledNameForNamedType(structureDefinition), ";\n"));
     }
 
     std::function<String(AST::NamedType&)> m_mangledNameForNamedType;
-    std::function<String(AST::UnnamedType&)> m_mangledNameForUnnamedType;
-    std::function<String(AST::EnumerationMember&)>&& m_mangledNameForEnumerationMember;
     StringBuilder m_stringBuilder;
 };
 
@@ -363,10 +375,6 @@ String TypeNamer::metalTypeDeclarations()
 {
     MetalTypeDeclarationWriter metalTypeDeclarationWriter([&](AST::NamedType& namedType) -> String {
         return mangledNameForType(namedType);
-    }, [&](AST::UnnamedType& unnamedType) -> String {
-        return mangledNameForType(unnamedType);
-    }, [&](AST::EnumerationMember& enumerationMember) -> String {
-        return mangledNameForEnumerationMember(enumerationMember);
     });
     metalTypeDeclarationWriter.Visitor::visit(m_program);
     return metalTypeDeclarationWriter.toString();
@@ -406,7 +414,7 @@ void TypeNamer::emitUnnamedTypeDefinition(BaseTypeNameNode& baseTypeNameNode, Ha
         ASSERT(baseTypeNameNode.parent());
         stringBuilder.append(makeString("struct ", arrayReferenceType.mangledName(), "{ \n"));
         stringBuilder.append(makeString("    ", toString(arrayReferenceType.addressSpace()), " ", arrayReferenceType.parent()->mangledName(), "* pointer;\n"));
-        stringBuilder.append("    unsigned length;\n");
+        stringBuilder.append("    uint length;\n");
         stringBuilder.append("};\n");
     } else {
         ASSERT(is<ArrayTypeNameNode>(baseTypeNameNode));
@@ -426,7 +434,13 @@ void TypeNamer::emitNamedTypeDefinition(AST::NamedType& namedType, HashSet<AST::
     for (auto& baseTypeNameNode : iterator->value)
         emitUnnamedTypeDefinition(baseTypeNameNode, emittedNamedTypes, emittedUnnamedTypes, stringBuilder);
     if (is<AST::EnumerationDefinition>(namedType)) {
-        // We already emitted this in the type declaration section. There's nothing to do.
+        auto& enumerationDefinition = downcast<AST::EnumerationDefinition>(namedType);
+        auto& baseType = enumerationDefinition.type().unifyNode();
+        ASSERT(is<AST::NamedType>(baseType));
+        stringBuilder.append(makeString("enum class ", mangledNameForType(enumerationDefinition), " : ", mangledNameForType(downcast<AST::NamedType>(baseType)), " {\n"));
+        for (auto& enumerationMember : enumerationDefinition.enumerationMembers())
+            stringBuilder.append(makeString("    ", mangledNameForEnumerationMember(enumerationMember), ",\n"));
+        stringBuilder.append("};\n");
     } else if (is<AST::NativeTypeDeclaration>(namedType)) {
         // Native types already have definitions. There's nothing to do.
     } else if (is<AST::StructureDefinition>(namedType)) {
@@ -443,6 +457,14 @@ void TypeNamer::emitNamedTypeDefinition(AST::NamedType& namedType, HashSet<AST::
     emittedNamedTypes.add(&namedType);
 }
 
+void TypeNamer::emitAllUnnamedTypeDefinitions(Vector<UniqueRef<BaseTypeNameNode>>& nodes, HashSet<AST::NamedType*>& emittedNamedTypes, HashSet<BaseTypeNameNode*>& emittedUnnamedTypes, StringBuilder& stringBuilder)
+{
+    for (auto& node : nodes) {
+        emitUnnamedTypeDefinition(node, emittedNamedTypes, emittedUnnamedTypes, stringBuilder);
+        emitAllUnnamedTypeDefinitions(node->children(), emittedNamedTypes, emittedUnnamedTypes, stringBuilder);
+    }
+}
+
 String TypeNamer::metalTypeDefinitions()
 {
     HashSet<AST::NamedType*> emittedNamedTypes;
@@ -450,8 +472,7 @@ String TypeNamer::metalTypeDefinitions()
     StringBuilder stringBuilder;
     for (auto& keyValuePair : m_dependencyGraph)
         emitNamedTypeDefinition(*keyValuePair.key, emittedNamedTypes, emittedUnnamedTypes, stringBuilder);
-    for (auto& baseTypeNameNode : m_trie)
-        emitUnnamedTypeDefinition(baseTypeNameNode, emittedNamedTypes, emittedUnnamedTypes, stringBuilder);
+    emitAllUnnamedTypeDefinitions(m_trie, emittedNamedTypes, emittedUnnamedTypes, stringBuilder);
     return stringBuilder.toString();
 }
 
index 63cd5fe..290866b 100644 (file)
@@ -85,6 +85,8 @@ private:
     void visit(AST::NativeTypeDeclaration&) override;
     void visit(AST::StructureDefinition&) override;
     void visit(AST::TypeDefinition&) override;
+    void visit(AST::Expression&) override;
+    void visit(AST::CallExpression&) override;
 
     String generateNextEnumerationMemberName()
     {
@@ -93,6 +95,7 @@ private:
 
     void emitNamedTypeDefinition(AST::NamedType&, HashSet<AST::NamedType*>& emittedNamedTypes, HashSet<BaseTypeNameNode*>& emittedUnnamedTypes, StringBuilder&);
     void emitUnnamedTypeDefinition(BaseTypeNameNode&, HashSet<AST::NamedType*>& emittedNamedTypes, HashSet<BaseTypeNameNode*>& emittedUnnamedTypes, StringBuilder&);
+    void emitAllUnnamedTypeDefinitions(Vector<UniqueRef<BaseTypeNameNode>>&, HashSet<AST::NamedType*>& emittedNamedTypes, HashSet<BaseTypeNameNode*>& emittedUnnamedTypes, StringBuilder&);
     String metalTypeDeclarations();
     String metalTypeDefinitions();
 
index 868bfcc..61d348a 100644 (file)
@@ -59,7 +59,7 @@ bool checkDuplicateFunctions(const Program& program)
         return false;
     });
     for (size_t i = 0; i < functions.size(); ++i) {
-        for (size_t j = i + 1; j < functions.size(); ++i) {
+        for (size_t j = i + 1; j < functions.size(); ++j) {
             if (functions[i].get().name() != functions[j].get().name())
                 break;
             if (is<AST::NativeFunctionDeclaration>(functions[i].get()) && is<AST::NativeFunctionDeclaration>(functions[j].get()))
index 294dc98..1365f26 100644 (file)
@@ -140,60 +140,76 @@ static AST::NativeFunctionDeclaration resolveWithReferenceComparator(AST::CallEx
 {
     const bool isOperator = true;
     auto returnType = AST::TypeReference::wrap(Lexer::Token(callExpression.origin()), intrinsics.boolType());
-    auto argumentType = WTF::visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> UniqueRef<AST::UnnamedType> {
+    auto argumentType = firstArgument.visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> UniqueRef<AST::UnnamedType> {
         return unnamedType->clone();
-    }, [&](Ref<ResolvableTypeReference>&) -> UniqueRef<AST::UnnamedType> {
-        return WTF::visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> UniqueRef<AST::UnnamedType> {
+    }, [&](RefPtr<ResolvableTypeReference>&) -> UniqueRef<AST::UnnamedType> {
+        return secondArgument.visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> UniqueRef<AST::UnnamedType> {
             return unnamedType->clone();
-        }, [&](Ref<ResolvableTypeReference>&) -> UniqueRef<AST::UnnamedType> {
+        }, [&](RefPtr<ResolvableTypeReference>&) -> UniqueRef<AST::UnnamedType> {
             // We encountered "null == null".
-            // The type isn't observable, so we can pick whatever we want.
             // FIXME: This can probably be generalized, using the "preferred type" infrastructure used by generic literals
+            ASSERT_NOT_REACHED();
             return AST::TypeReference::wrap(Lexer::Token(callExpression.origin()), intrinsics.intType());
-        }), secondArgument);
-    }), firstArgument);
+        }));
+    }));
     AST::VariableDeclarations parameters;
     parameters.append(AST::VariableDeclaration(Lexer::Token(callExpression.origin()), AST::Qualifiers(), { argumentType->clone() }, String(), WTF::nullopt, WTF::nullopt));
     parameters.append(AST::VariableDeclaration(Lexer::Token(callExpression.origin()), AST::Qualifiers(), { WTFMove(argumentType) }, String(), WTF::nullopt, WTF::nullopt));
     return AST::NativeFunctionDeclaration(AST::FunctionDeclaration(Lexer::Token(callExpression.origin()), AST::AttributeBlock(), WTF::nullopt, WTFMove(returnType), String("operator==", String::ConstructFromLiteral), WTFMove(parameters), WTF::nullopt, isOperator));
 }
 
+enum class Acceptability {
+    Yes,
+    Maybe,
+    No
+};
+
 static Optional<AST::NativeFunctionDeclaration> resolveByInstantiation(AST::CallExpression& callExpression, const Vector<std::reference_wrapper<ResolvingType>>& types, const Intrinsics& intrinsics)
 {
     if (callExpression.name() == "operator&[]" && types.size() == 2) {
-        auto* firstArgumentArrayRef = WTF::visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> AST::ArrayReferenceType* {
+        auto* firstArgumentArrayRef = types[0].get().visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> AST::ArrayReferenceType* {
             if (is<AST::ArrayReferenceType>(static_cast<AST::UnnamedType&>(unnamedType)))
                 return &downcast<AST::ArrayReferenceType>(static_cast<AST::UnnamedType&>(unnamedType));
             return nullptr;
-        }, [](Ref<ResolvableTypeReference>&) -> AST::ArrayReferenceType* {
+        }, [](RefPtr<ResolvableTypeReference>&) -> AST::ArrayReferenceType* {
             return nullptr;
-        }), types[0].get());
-        bool secondArgumentIsUint = WTF::visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& unnamedType) -> bool {
+        }));
+        bool secondArgumentIsUint = types[1].get().visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& unnamedType) -> bool {
             return matches(unnamedType, intrinsics.uintType());
-        }, [&](Ref<ResolvableTypeReference>& resolvableTypeReference) -> bool {
+        }, [&](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> bool {
             return resolvableTypeReference->resolvableType().canResolve(intrinsics.uintType());
-        }), types[1].get());
+        }));
         if (firstArgumentArrayRef && secondArgumentIsUint)
             return resolveWithOperatorAnderIndexer(callExpression, *firstArgumentArrayRef, intrinsics);
     } else if (callExpression.name() == "operator.length" && types.size() == 1) {
-        auto* firstArgumentReference = WTF::visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> AST::UnnamedType* {
-            if (is<AST::ArrayReferenceType>(static_cast<AST::UnnamedType&>(unnamedType)) || is<AST::ArrayType>(static_cast<AST::UnnamedType&>(unnamedType)))
+        auto* firstArgumentReference = types[0].get().visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> AST::UnnamedType* {
+            if (is<AST::ArrayReferenceType>(static_cast<AST::UnnamedType&>(unnamedType)))
                 return &unnamedType;
             return nullptr;
-        }, [](Ref<ResolvableTypeReference>&) -> AST::UnnamedType* {
+        }, [](RefPtr<ResolvableTypeReference>&) -> AST::UnnamedType* {
             return nullptr;
-        }), types[0].get());
+        }));
         if (firstArgumentReference)
             return resolveWithOperatorLength(callExpression, *firstArgumentReference, intrinsics);
     } else if (callExpression.name() == "operator==" && types.size() == 2) {
-        auto isAcceptable = [](ResolvingType& resolvingType) -> bool {
-            return WTF::visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> bool {
-                return is<AST::ReferenceType>(static_cast<AST::UnnamedType&>(unnamedType));
-            }, [](Ref<ResolvableTypeReference>& resolvableTypeReference) -> bool {
-                return is<AST::NullLiteralType>(resolvableTypeReference->resolvableType());
-            }), resolvingType);
+        auto acceptability = [](ResolvingType& resolvingType) -> Acceptability {
+            return resolvingType.visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& unnamedType) -> Acceptability {
+                return is<AST::ReferenceType>(static_cast<AST::UnnamedType&>(unnamedType)) ? Acceptability::Yes : Acceptability::No;
+            }, [](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> Acceptability {
+                return is<AST::NullLiteralType>(resolvableTypeReference->resolvableType()) ? Acceptability::Maybe : Acceptability::No;
+            }));
         };
-        if (isAcceptable(types[0].get()) && isAcceptable(types[1].get()))
+        auto leftAcceptability = acceptability(types[0].get());
+        auto rightAcceptability = acceptability(types[1].get());
+        bool success = false;
+        if (leftAcceptability == Acceptability::Yes && rightAcceptability == Acceptability::Yes) {
+            auto& unnamedType1 = types[0].get().getUnnamedType();
+            auto& unnamedType2 = types[1].get().getUnnamedType();
+            success = matches(unnamedType1, unnamedType2);
+        } else if ((leftAcceptability == Acceptability::Maybe && rightAcceptability == Acceptability::Yes)
+            || (leftAcceptability == Acceptability::Yes && rightAcceptability == Acceptability::Maybe))
+            success = true;
+        if (success)
             return resolveWithReferenceComparator(callExpression, types[0].get(), types[1].get(), intrinsics);
     }
     return WTF::nullopt;
@@ -342,8 +358,7 @@ static bool checkOperatorOverload(const AST::FunctionDefinition& functionDefinit
             argumentTypes.append((*functionDefinition.parameters()[0].type())->clone());
         for (auto& argumentType : argumentTypes)
             argumentTypeReferences.append(argumentType);
-        Optional<std::reference_wrapper<AST::NamedType>> castReturnType;
-        auto* overload = resolveFunctionOverloadImpl(*getterFuncs, argumentTypeReferences, castReturnType);
+        auto* overload = resolveFunctionOverloadImpl(*getterFuncs, argumentTypeReferences, nullptr);
         if (!overload)
             return false;
         auto& resultType = overload->type();
@@ -443,7 +458,7 @@ private:
     Optional<UniqueRef<AST::UnnamedType>> recurseAndWrapBaseType(AST::PropertyAccessExpression&);
     bool recurseAndRequireBoolType(AST::Expression&);
     void assignType(AST::Expression&, UniqueRef<AST::UnnamedType>&&, Optional<AST::AddressSpace> = WTF::nullopt);
-    void assignType(AST::Expression&, Ref<ResolvableTypeReference>&&, Optional<AST::AddressSpace> = WTF::nullopt);
+    void assignType(AST::Expression&, RefPtr<ResolvableTypeReference>&&, Optional<AST::AddressSpace> = WTF::nullopt);
     void forwardType(AST::Expression&, ResolvingType&, Optional<AST::AddressSpace> = WTF::nullopt);
 
     void visit(AST::FunctionDefinition&) override;
@@ -508,18 +523,17 @@ void Checker::visit(Program& program)
 bool Checker::assignTypes()
 {
     for (auto& keyValuePair : m_typeMap) {
-        auto success = WTF::visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& unnamedType) -> bool {
+        auto success = keyValuePair.value.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& unnamedType) -> bool {
             keyValuePair.key->setType(unnamedType->clone());
             return true;
-        }, [&](Ref<ResolvableTypeReference>& resolvableTypeReference) -> bool {
+        }, [&](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> bool {
             if (!resolvableTypeReference->resolvableType().resolvedType()) {
-                // FIXME: Instead of trying to commit, it might be better to just return an error instead.
                 if (!static_cast<bool>(commit(resolvableTypeReference->resolvableType())))
                     return false;
             }
             keyValuePair.key->setType(resolvableTypeReference->resolvableType().resolvedType()->clone());
             return true;
-        }), keyValuePair.value);
+        }));
         if (!success)
             return false;
     }
@@ -533,11 +547,11 @@ bool Checker::checkShaderType(const AST::FunctionDefinition& functionDefinition)
 {
     switch (*functionDefinition.entryPointType()) {
     case AST::EntryPointType::Vertex:
-        return !m_vertexEntryPoints.add(functionDefinition.name()).isNewEntry;
+        return static_cast<bool>(m_vertexEntryPoints.add(functionDefinition.name()));
     case AST::EntryPointType::Fragment:
-        return !m_fragmentEntryPoints.add(functionDefinition.name()).isNewEntry;
+        return static_cast<bool>(m_fragmentEntryPoints.add(functionDefinition.name()));
     case AST::EntryPointType::Compute:
-        return !m_computeEntryPoints.add(functionDefinition.name()).isNewEntry;
+        return static_cast<bool>(m_computeEntryPoints.add(functionDefinition.name()));
     }
 }
 
@@ -563,48 +577,48 @@ void Checker::visit(AST::FunctionDefinition& functionDefinition)
         return;
     }
 
-    checkErrorAndVisit(functionDefinition);
+    Visitor::visit(functionDefinition);
 }
 
 static Optional<UniqueRef<AST::UnnamedType>> matchAndCommit(ResolvingType& left, ResolvingType& right)
 {
-    return WTF::visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& left) -> Optional<UniqueRef<AST::UnnamedType>> {
-        return WTF::visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& right) -> Optional<UniqueRef<AST::UnnamedType>> {
+    return left.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& left) -> Optional<UniqueRef<AST::UnnamedType>> {
+        return right.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& right) -> Optional<UniqueRef<AST::UnnamedType>> {
             if (matches(left, right))
                 return left->clone();
             return WTF::nullopt;
-        }, [&](Ref<ResolvableTypeReference>& right) -> Optional<UniqueRef<AST::UnnamedType>> {
+        }, [&](RefPtr<ResolvableTypeReference>& right) -> Optional<UniqueRef<AST::UnnamedType>> {
             return matchAndCommit(left, right->resolvableType());
-        }), right);
-    }, [&](Ref<ResolvableTypeReference>& left) -> Optional<UniqueRef<AST::UnnamedType>> {
-        return WTF::visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& right) -> Optional<UniqueRef<AST::UnnamedType>> {
+        }));
+    }, [&](RefPtr<ResolvableTypeReference>& left) -> Optional<UniqueRef<AST::UnnamedType>> {
+        return right.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& right) -> Optional<UniqueRef<AST::UnnamedType>> {
             return matchAndCommit(right, left->resolvableType());
-        }, [&](Ref<ResolvableTypeReference>& right) -> Optional<UniqueRef<AST::UnnamedType>> {
+        }, [&](RefPtr<ResolvableTypeReference>& right) -> Optional<UniqueRef<AST::UnnamedType>> {
             return matchAndCommit(left->resolvableType(), right->resolvableType());
-        }), right);
-    }), left);
+        }));
+    }));
 }
 
 static Optional<UniqueRef<AST::UnnamedType>> matchAndCommit(ResolvingType& resolvingType, AST::UnnamedType& unnamedType)
 {
-    return WTF::visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& resolvingType) -> Optional<UniqueRef<AST::UnnamedType>> {
+    return resolvingType.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& resolvingType) -> Optional<UniqueRef<AST::UnnamedType>> {
         if (matches(unnamedType, resolvingType))
             return unnamedType.clone();
         return WTF::nullopt;
-    }, [&](Ref<ResolvableTypeReference>& resolvingType) -> Optional<UniqueRef<AST::UnnamedType>> {
+    }, [&](RefPtr<ResolvableTypeReference>& resolvingType) -> Optional<UniqueRef<AST::UnnamedType>> {
         return matchAndCommit(unnamedType, resolvingType->resolvableType());
-    }), resolvingType);
+    }));
 }
 
 static Optional<UniqueRef<AST::UnnamedType>> matchAndCommit(ResolvingType& resolvingType, AST::NamedType& namedType)
 {
-    return WTF::visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& resolvingType) -> Optional<UniqueRef<AST::UnnamedType>> {
+    return resolvingType.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& resolvingType) -> Optional<UniqueRef<AST::UnnamedType>> {
         if (matches(resolvingType, namedType))
             return resolvingType->clone();
         return WTF::nullopt;
-    }, [&](Ref<ResolvableTypeReference>& resolvingType) -> Optional<UniqueRef<AST::UnnamedType>> {
+    }, [&](RefPtr<ResolvableTypeReference>& resolvingType) -> Optional<UniqueRef<AST::UnnamedType>> {
         return matchAndCommit(namedType, resolvingType->resolvableType());
-    }), resolvingType);
+    }));
 }
 
 void Checker::visit(AST::EnumerationDefinition& enumerationDefinition)
@@ -710,13 +724,14 @@ void Checker::visit(AST::TypeReference& typeReference)
 {
     ASSERT(typeReference.resolvedType());
 
-    checkErrorAndVisit(typeReference);
+    for (auto& typeArgument : typeReference.typeArguments())
+        checkErrorAndVisit(typeArgument);
 }
 
 auto Checker::recurseAndGetInfo(AST::Expression& expression, bool requiresLValue) -> Optional<RecurseInfo>
 {
-    checkErrorAndVisit(expression);
-    if (!error())
+    Visitor::visit(expression);
+    if (error())
         return WTF::nullopt;
     return getInfo(expression, requiresLValue);
 }
@@ -760,7 +775,7 @@ void Checker::assignType(AST::Expression& expression, UniqueRef<AST::UnnamedType
     ASSERT_UNUSED(addressSpaceAddResult, addressSpaceAddResult.isNewEntry);
 }
 
-void Checker::assignType(AST::Expression& expression, Ref<ResolvableTypeReference>&& resolvableTypeReference, Optional<AST::AddressSpace> addressSpace)
+void Checker::assignType(AST::Expression& expression, RefPtr<ResolvableTypeReference>&& resolvableTypeReference, Optional<AST::AddressSpace> addressSpace)
 {
     auto addResult = m_typeMap.add(&expression, WTFMove(resolvableTypeReference));
     ASSERT_UNUSED(addResult, addResult.isNewEntry);
@@ -789,13 +804,13 @@ void Checker::visit(AST::AssignmentExpression& assignmentExpression)
 
 void Checker::forwardType(AST::Expression& expression, ResolvingType& resolvingType, Optional<AST::AddressSpace> addressSpace)
 {
-    WTF::visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& result) {
+    resolvingType.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& result) {
         auto addResult = m_typeMap.add(&expression, result->clone());
         ASSERT_UNUSED(addResult, addResult.isNewEntry);
-    }, [&](Ref<ResolvableTypeReference>& result) {
+    }, [&](RefPtr<ResolvableTypeReference>& result) {
         auto addResult = m_typeMap.add(&expression, result.copyRef());
         ASSERT_UNUSED(addResult, addResult.isNewEntry);
-    }), resolvingType);
+    }));
     auto addressSpaceAddResult = m_addressSpaceMap.add(&expression, addressSpace);
     ASSERT_UNUSED(addressSpaceAddResult, addressSpaceAddResult.isNewEntry);
 }
@@ -826,12 +841,12 @@ void Checker::visit(AST::ReadModifyWriteExpression& readModifyWriteExpression)
 
 static AST::UnnamedType* getUnnamedType(ResolvingType& resolvingType)
 {
-    return WTF::visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& type) -> AST::UnnamedType* {
+    return resolvingType.visit(WTF::makeVisitor([](UniqueRef<AST::UnnamedType>& type) -> AST::UnnamedType* {
         return &type;
-    }, [](Ref<ResolvableTypeReference>& type) -> AST::UnnamedType* {
+    }, [](RefPtr<ResolvableTypeReference>& type) -> AST::UnnamedType* {
         // FIXME: If the type isn't committed, should we just commit() it now?
         return type->resolvableType().resolvedType();
-    }), resolvingType);
+    }));
 }
 
 void Checker::visit(AST::DereferenceExpression& dereferenceExpression)
@@ -921,7 +936,6 @@ void Checker::visit(AST::MakeArrayReferenceExpression& makeArrayReferenceExpress
 
 void Checker::finishVisitingPropertyAccess(AST::PropertyAccessExpression& propertyAccessExpression, AST::UnnamedType& wrappedBaseType, AST::UnnamedType* extraArgumentType)
 {
-    Optional<std::reference_wrapper<AST::NamedType>> castReturnType;
     using OverloadResolution = std::tuple<AST::FunctionDeclaration*, AST::UnnamedType*>;
 
     AST::FunctionDeclaration* getFunction;
@@ -937,7 +951,7 @@ void Checker::finishVisitingPropertyAccess(AST::PropertyAccessExpression& proper
         if (getArgumentType2)
             getArgumentTypes.append(*getArgumentType2);
 
-        auto* getFunction = resolveFunctionOverloadImpl(propertyAccessExpression.possibleGetOverloads(), getArgumentTypes, castReturnType);
+        auto* getFunction = resolveFunctionOverloadImpl(propertyAccessExpression.possibleGetOverloads(), getArgumentTypes, nullptr);
         if (!getFunction)
             return std::make_pair(nullptr, nullptr);
         return std::make_pair(getFunction, &getFunction->type());
@@ -950,10 +964,10 @@ void Checker::finishVisitingPropertyAccess(AST::PropertyAccessExpression& proper
             if (is<AST::ArrayReferenceType>(unnamedType))
                 return { unnamedType.clone() };
             if (is<AST::ArrayType>(unnamedType))
-                return { makeUniqueRef<AST::ArrayReferenceType>(Lexer::Token(propertyAccessExpression.origin()), AST::AddressSpace::Thread, downcast<AST::ArrayType>(unnamedType).type().clone()) };
+                return { ResolvingType(makeUniqueRef<AST::ArrayReferenceType>(Lexer::Token(propertyAccessExpression.origin()), AST::AddressSpace::Thread, downcast<AST::ArrayType>(unnamedType).type().clone())) };
             if (is<AST::PointerType>(unnamedType))
                 return WTF::nullopt;
-            return { makeUniqueRef<AST::PointerType>(Lexer::Token(propertyAccessExpression.origin()), AST::AddressSpace::Thread, downcast<AST::ArrayType>(unnamedType).type().clone()) };
+            return { ResolvingType(makeUniqueRef<AST::PointerType>(Lexer::Token(propertyAccessExpression.origin()), AST::AddressSpace::Thread, downcast<AST::TypeReference>(unnamedType).clone())) };
         };
         auto computeAndReturnType = [&](AST::UnnamedType& unnamedType) -> AST::UnnamedType* {
             if (is<AST::PointerType>(unnamedType))
@@ -973,7 +987,7 @@ void Checker::finishVisitingPropertyAccess(AST::PropertyAccessExpression& proper
         if (andArgumentType2)
             andArgumentTypes.append(*andArgumentType2);
 
-        auto* andFunction = resolveFunctionOverloadImpl(propertyAccessExpression.possibleAndOverloads(), andArgumentTypes, castReturnType);
+        auto* andFunction = resolveFunctionOverloadImpl(propertyAccessExpression.possibleAndOverloads(), andArgumentTypes, nullptr);
         if (!andFunction)
             return std::make_pair(nullptr, nullptr);
         return std::make_pair(andFunction, computeAndReturnType(andFunction->type()));
@@ -1004,7 +1018,7 @@ void Checker::finishVisitingPropertyAccess(AST::PropertyAccessExpression& proper
             setArgumentTypes.append(*setArgumentType2);
         setArgumentTypes.append(setArgument3Type);
 
-        auto* setFunction = resolveFunctionOverloadImpl(propertyAccessExpression.possibleSetOverloads(), setArgumentTypes, castReturnType);
+        auto* setFunction = resolveFunctionOverloadImpl(propertyAccessExpression.possibleSetOverloads(), setArgumentTypes, nullptr);
         if (!setFunction)
             return std::make_pair(nullptr, nullptr);
         return std::make_pair(setFunction, &setFunction->type());
@@ -1161,11 +1175,11 @@ void Checker::visit(AST::EnumerationMemberLiteral& enumerationMemberLiteral)
 
 bool Checker::isBoolType(ResolvingType& resolvingType)
 {
-    return WTF::visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& left) -> bool {
+    return resolvingType.visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& left) -> bool {
         return matches(left, m_intrinsics.boolType());
-    }, [&](Ref<ResolvableTypeReference>& left) -> bool {
+    }, [&](RefPtr<ResolvableTypeReference>& left) -> bool {
         return static_cast<bool>(matchAndCommit(m_intrinsics.boolType(), left->resolvableType()));
-    }), resolvingType);
+    }));
 }
 
 bool Checker::recurseAndRequireBoolType(AST::Expression& expression)
@@ -1424,11 +1438,8 @@ void Checker::visit(AST::CallExpression& callExpression)
             return;
         types.uncheckedAppend(argumentInfo->resolvingType);
     }
-    if (callExpression.castReturnType()) {
-        checkErrorAndVisit(callExpression.castReturnType()->get());
-        if (error())
-            return;
-    }
+    // Don't recurse on the castReturnType, because it's guaranteed to be a NamedType, which will get visited later.
+    // We don't want to recurse to the same node twice.
 
     ASSERT(callExpression.hasOverloads());
     auto* function = resolveFunctionOverloadImpl(*callExpression.overloads(), types, callExpression.castReturnType());
index 45a62a6..a32af5e 100644 (file)
@@ -116,6 +116,7 @@ public:
         auto depth = m_typeReferences.size();
         checkErrorAndVisit(*typeReference.resolvedType());
         ASSERT_UNUSED(depth, m_typeReferences.size() == depth);
+        m_typeReferences.removeLast();
     }
 
     void visit(AST::PointerType& pointerType)
index 0a2c9eb..2fcc6f1 100644 (file)
@@ -221,20 +221,20 @@ bool inferTypesForTypeArguments(AST::NamedType& possibleType, AST::TypeArguments
     return true;
 }
 
-bool inferTypesForCall(AST::FunctionDeclaration& possibleFunction, Vector<std::reference_wrapper<ResolvingType>>& argumentTypes, Optional<std::reference_wrapper<AST::NamedType>>& castReturnType)
+bool inferTypesForCall(AST::FunctionDeclaration& possibleFunction, Vector<std::reference_wrapper<ResolvingType>>& argumentTypes, const AST::NamedType* castReturnType)
 {
     if (possibleFunction.parameters().size() != argumentTypes.size())
         return false;
     for (size_t i = 0; i < possibleFunction.parameters().size(); ++i) {
-        auto success = WTF::visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& unnamedType) -> bool {
+        auto success = argumentTypes[i].get().visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>& unnamedType) -> bool {
             return matches(*possibleFunction.parameters()[i].type(), unnamedType);
-        }, [&](Ref<ResolvableTypeReference>& resolvableTypeReference) -> bool {
-            return resolvableTypeReference->resolvableType().canResolve(*possibleFunction.parameters()[i].type());
-        }), argumentTypes[i].get());
+        }, [&](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> bool {
+            return resolvableTypeReference->resolvableType().canResolve(possibleFunction.parameters()[i].type()->unifyNode());
+        }));
         if (!success)
             return false;
     }
-    if (castReturnType && !matches(castReturnType->get(), possibleFunction.type()))
+    if (castReturnType && !matches(possibleFunction.type(), *castReturnType))
         return false;
     return true;
 }
index dfd4b54..173d301 100644 (file)
@@ -53,7 +53,7 @@ Optional<UniqueRef<AST::UnnamedType>> matchAndCommit(AST::NamedType&, AST::Resol
 Optional<UniqueRef<AST::UnnamedType>> matchAndCommit(AST::ResolvableType&, AST::ResolvableType&);
 Optional<UniqueRef<AST::UnnamedType>> commit(AST::ResolvableType&);
 bool inferTypesForTypeArguments(AST::NamedType& possibleType, AST::TypeArguments&);
-bool inferTypesForCall(AST::FunctionDeclaration& possibleFunction, Vector<std::reference_wrapper<ResolvingType>>& argumentTypes, Optional<std::reference_wrapper<AST::NamedType>>& castReturnType);
+bool inferTypesForCall(AST::FunctionDeclaration& possibleFunction, Vector<std::reference_wrapper<ResolvingType>>& argumentTypes, const AST::NamedType* castReturnType);
 
 }
 
index 8bba6b6..cb9d8d1 100644 (file)
@@ -266,11 +266,13 @@ bool Intrinsics::addPrimitive(AST::NativeTypeDeclaration& nativeTypeDeclaration)
             return true;
         });
         m_floatType = &nativeTypeDeclaration;
-    } else if (nativeTypeDeclaration.name() == "atomic_int")
+    } else if (nativeTypeDeclaration.name() == "atomic_int") {
+        nativeTypeDeclaration.setIsAtomic();
         m_atomicIntType = &nativeTypeDeclaration;
-    else if (nativeTypeDeclaration.name() == "atomic_uint")
+    } else if (nativeTypeDeclaration.name() == "atomic_uint") {
+        nativeTypeDeclaration.setIsAtomic();
         m_atomicUintType = &nativeTypeDeclaration;
-    else if (nativeTypeDeclaration.name() == "sampler")
+    else if (nativeTypeDeclaration.name() == "sampler")
         m_samplerType = &nativeTypeDeclaration;
     else
         ASSERT_NOT_REACHED();
@@ -357,7 +359,7 @@ bool Intrinsics::addFullTexture(AST::NativeTypeDeclaration& nativeTypeDeclaratio
     unsigned vectorLength;
     for (unsigned i = 0; i < WTF_ARRAY_LENGTH(m_textureInnerTypeNames); ++i) {
         if (innerType.name().startsWith(m_textureInnerTypeNames[i])) {
-            textureTypeIndex = i;
+            innerTypeIndex = i;
             if (innerType.name() == m_textureInnerTypeNames[i])
                 vectorLength = 1;
             else {
index fc47d32..d23bf97 100644 (file)
@@ -58,10 +58,16 @@ public:
         return *m_boolType;
     }
 
-    AST::NativeTypeDeclaration& intType() const
+    AST::NativeTypeDeclaration& ucharType() const
     {
-        ASSERT(m_intType);
-        return *m_intType;
+        ASSERT(m_ucharType);
+        return *m_ucharType;
+    }
+
+    AST::NativeTypeDeclaration& ushortType() const
+    {
+        ASSERT(m_ushortType);
+        return *m_ushortType;
     }
 
     AST::NativeTypeDeclaration& uintType() const
@@ -70,6 +76,96 @@ public:
         return *m_uintType;
     }
 
+    AST::NativeTypeDeclaration& charType() const
+    {
+        ASSERT(m_charType);
+        return *m_charType;
+    }
+
+    AST::NativeTypeDeclaration& shortType() const
+    {
+        ASSERT(m_shortType);
+        return *m_shortType;
+    }
+
+    AST::NativeTypeDeclaration& intType() const
+    {
+        ASSERT(m_intType);
+        return *m_intType;
+    }
+
+    AST::NativeTypeDeclaration& uchar2Type() const
+    {
+        ASSERT(m_vectorUchar[0]);
+        return *m_vectorUchar[0];
+    }
+
+    AST::NativeTypeDeclaration& uchar4Type() const
+    {
+        ASSERT(m_vectorUchar[2]);
+        return *m_vectorUchar[2];
+    }
+
+    AST::NativeTypeDeclaration& ushort2Type() const
+    {
+        ASSERT(m_vectorUshort[0]);
+        return *m_vectorUshort[0];
+    }
+
+    AST::NativeTypeDeclaration& ushort4Type() const
+    {
+        ASSERT(m_vectorUshort[2]);
+        return *m_vectorUshort[2];
+    }
+
+    AST::NativeTypeDeclaration& uint2Type() const
+    {
+        ASSERT(m_vectorUint[0]);
+        return *m_vectorUint[0];
+    }
+
+    AST::NativeTypeDeclaration& uint4Type() const
+    {
+        ASSERT(m_vectorUint[2]);
+        return *m_vectorUint[2];
+    }
+
+    AST::NativeTypeDeclaration& char2Type() const
+    {
+        ASSERT(m_vectorChar[0]);
+        return *m_vectorChar[0];
+    }
+
+    AST::NativeTypeDeclaration& char4Type() const
+    {
+        ASSERT(m_vectorChar[2]);
+        return *m_vectorChar[2];
+    }
+
+    AST::NativeTypeDeclaration& short2Type() const
+    {
+        ASSERT(m_vectorShort[0]);
+        return *m_vectorShort[0];
+    }
+
+    AST::NativeTypeDeclaration& short4Type() const
+    {
+        ASSERT(m_vectorShort[2]);
+        return *m_vectorShort[2];
+    }
+
+    AST::NativeTypeDeclaration& int2Type() const
+    {
+        ASSERT(m_vectorInt[0]);
+        return *m_vectorInt[0];
+    }
+
+    AST::NativeTypeDeclaration& int4Type() const
+    {
+        ASSERT(m_vectorInt[2]);
+        return *m_vectorInt[2];
+    }
+
     AST::NativeTypeDeclaration& samplerType() const
     {
         ASSERT(m_samplerType);
index 964bc41..b836c5d 100644 (file)
@@ -325,7 +325,7 @@ auto Lexer::recognizeKeyword(unsigned end) -> Optional<Token::Type>
         return Token::Type::SVSampleIndex;
     if (substring == "SV_InnerCoverage")
         return Token::Type::SVInnerCoverage;
-    if (substring == "SV_Target")
+    if (substring == "SV_Target") // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195807 Make this work with strings like "SV_Target0".
         return Token::Type::SVTarget;
     if (substring == "SV_Depth")
         return Token::Type::SVDepth;
index 38214a4..1990213 100644 (file)
@@ -110,6 +110,8 @@ bool NameContext::add(AST::NativeTypeDeclaration& nativeTypeDeclaration)
 
 bool NameContext::add(AST::VariableDeclaration& variableDeclaration)
 {
+    if (variableDeclaration.name().isNull())
+        return true;
     if (exists(variableDeclaration.name()))
         return false;
     auto result = m_variables.add(String(variableDeclaration.name()), &variableDeclaration);
index 9c41931..05350e0 100644 (file)
@@ -41,6 +41,7 @@
 #include "WHLSLPropertyAccessExpression.h"
 #include "WHLSLResolveOverloadImpl.h"
 #include "WHLSLReturn.h"
+#include "WHLSLScopedSetAdder.h"
 #include "WHLSLTypeReference.h"
 #include "WHLSLVariableDeclaration.h"
 #include "WHLSLVariableReference.h"
@@ -57,7 +58,13 @@ NameResolver::NameResolver(NameContext& nameContext)
 
 void NameResolver::visit(AST::TypeReference& typeReference)
 {
-    checkErrorAndVisit(typeReference);
+    ScopedSetAdder<AST::TypeReference*> adder(m_typeReferences, &typeReference);
+    if (!adder.isNewEntry()) {
+        setError();
+        return;
+    }
+
+    Visitor::visit(typeReference);
     if (typeReference.resolvedType())
         return;
 
@@ -66,6 +73,8 @@ void NameResolver::visit(AST::TypeReference& typeReference)
         setError();
         return;
     }
+    for (auto& candidate : *candidates)
+        Visitor::visit(candidate);
     if (auto result = resolveTypeOverloadImpl(*candidates, typeReference.typeArguments()))
         typeReference.setResolvedType(*result);
     else {
@@ -78,32 +87,33 @@ void NameResolver::visit(AST::FunctionDefinition& functionDefinition)
 {
     NameContext newNameContext(&m_nameContext);
     NameResolver newNameResolver(newNameContext);
+    newNameResolver.setCurrentFunctionDefinition(m_currentFunction);
     checkErrorAndVisit(functionDefinition.type());
-    for (auto& parameter : functionDefinition.parameters()) {
+    for (auto& parameter : functionDefinition.parameters())
         newNameResolver.checkErrorAndVisit(parameter);
-        auto success = newNameContext.add(parameter);
-        if (!success) {
-            setError();
-            return;
-        }
-    }
     newNameResolver.checkErrorAndVisit(functionDefinition.block());
 }
 
 void NameResolver::visit(AST::Block& block)
 {
     NameContext nameContext(&m_nameContext);
-    NameResolver(nameContext).checkErrorAndVisit(block);
+    NameResolver newNameResolver(nameContext);
+    newNameResolver.setCurrentFunctionDefinition(m_currentFunction);
+    newNameResolver.Visitor::visit(block);
 }
 
 void NameResolver::visit(AST::IfStatement& ifStatement)
 {
     checkErrorAndVisit(ifStatement.conditional());
     NameContext nameContext(&m_nameContext);
-    NameResolver(nameContext).checkErrorAndVisit(ifStatement.body());
+    NameResolver newNameResolver(nameContext);
+    newNameResolver.setCurrentFunctionDefinition(m_currentFunction);
+    newNameResolver.checkErrorAndVisit(ifStatement.body());
     if (ifStatement.elseBody()) {
         NameContext nameContext(&m_nameContext);
-        NameResolver(nameContext).checkErrorAndVisit(*ifStatement.elseBody());
+        NameResolver newNameResolver(nameContext);
+        newNameResolver.setCurrentFunctionDefinition(m_currentFunction);
+        newNameResolver.checkErrorAndVisit(*ifStatement.elseBody());
     }
 }
 
@@ -111,26 +121,35 @@ void NameResolver::visit(AST::WhileLoop& whileLoop)
 {
     checkErrorAndVisit(whileLoop.conditional());
     NameContext nameContext(&m_nameContext);
-    NameResolver(nameContext).checkErrorAndVisit(whileLoop.body());
+    NameResolver newNameResolver(nameContext);
+    newNameResolver.setCurrentFunctionDefinition(m_currentFunction);
+    newNameResolver.checkErrorAndVisit(whileLoop.body());
 }
 
 void NameResolver::visit(AST::DoWhileLoop& whileLoop)
 {
     NameContext nameContext(&m_nameContext);
-    NameResolver(nameContext).checkErrorAndVisit(whileLoop.body());
+    NameResolver newNameResolver(nameContext);
+    newNameResolver.setCurrentFunctionDefinition(m_currentFunction);
+    newNameResolver.checkErrorAndVisit(whileLoop.body());
     checkErrorAndVisit(whileLoop.conditional());
 }
 
 void NameResolver::visit(AST::ForLoop& forLoop)
 {
     NameContext nameContext(&m_nameContext);
-    NameResolver(nameContext).checkErrorAndVisit(forLoop);
+    NameResolver newNameResolver(nameContext);
+    newNameResolver.setCurrentFunctionDefinition(m_currentFunction);
+    newNameResolver.Visitor::visit(forLoop);
 }
 
 void NameResolver::visit(AST::VariableDeclaration& variableDeclaration)
 {
-    m_nameContext.add(variableDeclaration);
-    checkErrorAndVisit(variableDeclaration);
+    if (!m_nameContext.add(variableDeclaration)) {
+        setError();
+        return;
+    }
+    Visitor::visit(variableDeclaration);
 }
 
 void NameResolver::visit(AST::VariableReference& variableReference)
@@ -150,7 +169,7 @@ void NameResolver::visit(AST::Return& returnStatement)
 {
     ASSERT(m_currentFunction);
     returnStatement.setFunction(m_currentFunction);
-    checkErrorAndVisit(returnStatement);
+    Visitor::visit(returnStatement);
 }
 
 void NameResolver::visit(AST::PropertyAccessExpression& propertyAccessExpression)
@@ -161,7 +180,7 @@ void NameResolver::visit(AST::PropertyAccessExpression& propertyAccessExpression
         propertyAccessExpression.setPossibleSetOverloads(*setFunctions);
     if (auto* andFunctions = m_nameContext.getFunctions(propertyAccessExpression.andFunctionName()))
         propertyAccessExpression.setPossibleAndOverloads(*andFunctions);
-    checkErrorAndVisit(propertyAccessExpression);
+    Visitor::visit(propertyAccessExpression);
 }
 
 void NameResolver::visit(AST::DotExpression& dotExpression)
@@ -189,7 +208,7 @@ void NameResolver::visit(AST::DotExpression& dotExpression)
         }
     }
 
-    checkErrorAndVisit(dotExpression);
+    Visitor::visit(dotExpression);
 }
 
 void NameResolver::visit(AST::CallExpression& callExpression)
@@ -212,7 +231,7 @@ void NameResolver::visit(AST::CallExpression& callExpression)
         setError();
         return;
     }
-    checkErrorAndVisit(callExpression);
+    Visitor::visit(callExpression);
 }
 
 void NameResolver::visit(AST::EnumerationMemberLiteral& enumerationMemberLiteral)
index 530b39f..08f0f65 100644 (file)
@@ -64,7 +64,8 @@ private:
     void visit(AST::CallExpression&) override;
     void visit(AST::EnumerationMemberLiteral&) override;
 
-    NameContext m_nameContext;
+    NameContext& m_nameContext;
+    HashSet<AST::TypeReference*> m_typeReferences;
     AST::FunctionDefinition* m_currentFunction { nullptr };
 };
 
index baa6caf..5332567 100644 (file)
@@ -129,10 +129,12 @@ auto Parser::parse(Program& program, StringView stringView, Mode mode) -> Option
     return WTF::nullopt;
 }
 
-auto Parser::fail(const String& message) -> Unexpected<Error>
+auto Parser::fail(const String& message, TryToPeek tryToPeek) -> Unexpected<Error>
 {
-    if (auto nextToken = peek())
-        return Unexpected<Error>(Error(m_lexer.errorString(*nextToken, message)));
+    if (tryToPeek == TryToPeek::Yes) {
+        if (auto nextToken = peek())
+            return Unexpected<Error>(Error(m_lexer.errorString(*nextToken, message)));
+    }
     return Unexpected<Error>(Error(makeString("Cannot lex: ", message)));
 }
 
@@ -142,7 +144,7 @@ auto Parser::peek() -> Expected<Lexer::Token, Error>
         m_lexer.unconsumeToken(Lexer::Token(*token));
         return *token;
     }
-    return fail("Cannot consume token"_str);
+    return fail("Cannot consume token"_str, TryToPeek::No);
 }
 
 Optional<Lexer::Token> Parser::tryType(Lexer::Token::Type type)
@@ -582,13 +584,15 @@ auto Parser::parseNonAddressSpaceType() -> Expected<UniqueRef<AST::UnnamedType>,
 
 auto Parser::parseType() -> Expected<UniqueRef<AST::UnnamedType>, Error>
 {
-    auto type = backtrackingScope<Expected<UniqueRef<AST::UnnamedType>, Error>>([&]() {
-        return parseAddressSpaceType();
-    });
-    if (type)
-        return type;
+    {
+        auto type = backtrackingScope<Expected<UniqueRef<AST::UnnamedType>, Error>>([&]() {
+            return parseAddressSpaceType();
+        });
+        if (type)
+            return type;
+    }
 
-    type = backtrackingScope<Expected<UniqueRef<AST::UnnamedType>, Error>>([&]() {
+    auto type = backtrackingScope<Expected<UniqueRef<AST::UnnamedType>, Error>>([&]() {
         return parseNonAddressSpaceType();
     });
     if (type)
@@ -653,7 +657,7 @@ auto Parser::parseBuiltInSemantic() -> Expected<AST::BuiltInSemantic, Error>
     case Lexer::Token::Type::SVInnerCoverage:
         return AST::BuiltInSemantic(WTFMove(*origin), AST::BuiltInSemantic::Variable::SVInnerCoverage);
     case Lexer::Token::Type::SVTarget: {
-        auto target = consumeNonNegativeIntegralLiteral();
+        auto target = consumeNonNegativeIntegralLiteral(); // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195807 Make this work with strings like "SV_Target0".
         if (!target)
             return Unexpected<Error>(target.error());
         return AST::BuiltInSemantic(WTFMove(*origin), AST::BuiltInSemantic::Variable::SVTarget, *target);
index d37c50d..11ef5cd 100644 (file)
@@ -130,7 +130,11 @@ private:
         return result;
     }
 
-    Unexpected<Error> fail(const String& message);
+    enum class TryToPeek {
+        Yes,
+        No
+    };
+    Unexpected<Error> fail(const String& message, TryToPeek = TryToPeek::Yes);
     Expected<Lexer::Token, Error> peek();
     Optional<Lexer::Token> tryType(Lexer::Token::Type);
     Optional<Lexer::Token> tryTypes(Vector<Lexer::Token::Type>);
index 0f28692..5c4c0b5 100644 (file)
@@ -50,10 +50,50 @@ struct VertexAttribute {
 using VertexAttributes = Vector<VertexAttribute>;
 
 enum class TextureFormat {
-    R8G8B8A8Unorm,
-    R8G8B8A8Uint,
-    B8G8R8A8Unorm,
-    D32FloatS8Uint
+    R8Unorm,
+    R8UnormSrgb,
+    R8Snorm,
+    R8Uint,
+    R8Sint,
+    R16Unorm,
+    R16Snorm,
+    R16Uint,
+    R16Sint,
+    R16Float,
+    RG8Unorm,
+    RG8UnormSrgb,
+    RG8Snorm,
+    RG8Uint,
+    RG8Sint,
+    B5G6R5Unorm,
+    R32Uint,
+    R32Sint,
+    R32Float,
+    RG16Unorm,
+    RG16Snorm,
+    RG16Uint,
+    RG16Sint,
+    RG16Float,
+    RGBA8Unorm,
+    RGBA8UnormSrgb,
+    RGBA8Snorm,
+    RGBA8Uint,
+    RGBA8Sint,
+    BGRA8Unorm,
+    BGRA8UnormSrgb,
+    RGB10A2Unorm,
+    RG11B10Float,
+    RG32Uint,
+    RG32Sint,
+    RG32Float,
+    RGBA16Unorm,
+    RGBA16Snorm,
+    RGBA16Uint,
+    RGBA16Sint,
+    RGBA16Float,
+    RGBA32Uint,
+    RGBA32Sint,
+    RGBA32Float
 };
 
 struct AttachmentDescriptor {
index dcb94be..9a1035a 100644 (file)
@@ -66,18 +66,23 @@ static Optional<Program> prepareShared(String& whlslSource)
         return WTF::nullopt;
     if (!checkRecursiveTypes(program))
         return WTF::nullopt;
-    synthesizeStructureAccessors(program);
-    synthesizeEnumerationFunctions(program);
-    synthesizeArrayOperatorLength(program);
-    synthesizeConstructors(program);
-    resolveNamesInFunctions(program, nameResolver);
+    if (!synthesizeStructureAccessors(program))
+        return WTF::nullopt;
+    if (!synthesizeEnumerationFunctions(program))
+        return WTF::nullopt;
+    if (!synthesizeArrayOperatorLength(program))
+        return WTF::nullopt;
+    if (!synthesizeConstructors(program))
+        return WTF::nullopt;
+    if (!resolveNamesInFunctions(program, nameResolver))
+        return WTF::nullopt;
     if (!checkDuplicateFunctions(program))
         return WTF::nullopt;
 
     if (!check(program))
         return WTF::nullopt;
     checkLiteralTypes(program);
-    // resolveProperties(program);
+    // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195788 Resolve properties here
     findHighZombies(program);
     if (!checkStatementBehavior(program))
         return WTF::nullopt;
@@ -101,8 +106,8 @@ Optional<RenderPrepareResult> prepare(String& whlslSource, RenderPipelineDescrip
 
     RenderPrepareResult result;
     result.metalSource = WTFMove(generatedCode.metalSource);
-    result.vertexMappedBindGroups = WTFMove(generatedCode.vertexMappedBindGroups);
-    result.fragmentMappedBindGroups = WTFMove(generatedCode.fragmentMappedBindGroups);
+    result.mangledVertexEntryPointName = WTFMove(generatedCode.mangledVertexEntryPointName);
+    result.mangledFragmentEntryPointName = WTFMove(generatedCode.mangledFragmentEntryPointName);
     return result;
 }
 
@@ -119,7 +124,7 @@ Optional<ComputePrepareResult> prepare(String& whlslSource, ComputePipelineDescr
 
     ComputePrepareResult result;
     result.metalSource = WTFMove(generatedCode.metalSource);
-    result.mappedBindGroups = WTFMove(generatedCode.bindGroups);
+    result.mangledEntryPointName = WTFMove(generatedCode.mangledEntryPointName);
     return result;
 }
 
index 063b974..97490a6 100644 (file)
@@ -27,7 +27,6 @@
 
 #if ENABLE(WEBGPU)
 
-#include "WHLSLMappedBindings.h"
 #include "WHLSLPipelineDescriptor.h"
 #include <wtf/text/WTFString.h>
 
@@ -38,14 +37,14 @@ namespace WHLSL {
 // FIXME: Generate descriptive error messages and return them here.
 struct RenderPrepareResult {
     String metalSource;
-    Metal::MappedBindGroups vertexMappedBindGroups;
-    Metal::MappedBindGroups fragmentMappedBindGroups;
+    String mangledVertexEntryPointName;
+    String mangledFragmentEntryPointName;
 };
 Optional<RenderPrepareResult> prepare(String& whlslSource, RenderPipelineDescriptor&);
 
 struct ComputePrepareResult {
     String metalSource;
-    Metal::MappedBindGroups mappedBindGroups;
+    String mangledEntryPointName;
 };
 Optional<ComputePrepareResult> prepare(String& whlslSource, ComputePipelineDescriptor&);
 
index 2a6bba6..d5a7b0d 100644 (file)
@@ -28,6 +28,7 @@
 
 #if ENABLE(WEBGPU)
 
+#include "WHLSLScopedSetAdder.h"
 #include "WHLSLStructureDefinition.h"
 #include "WHLSLTypeDefinition.h"
 #include "WHLSLTypeReference.h"
@@ -42,63 +43,61 @@ class RecursiveTypeChecker : public Visitor {
 public:
     ~RecursiveTypeChecker() = default;
 
-    void visit(AST::TypeDefinition& typeDefinition) override
-    {
-        auto addResult = m_types.add(&typeDefinition);
-        if (!addResult.isNewEntry) {
-            setError();
-            return;
-        }
+    void visit(AST::TypeDefinition&) override;
+    void visit(AST::StructureDefinition&) override;
+    void visit(AST::TypeReference&) override;
+    void visit(AST::ReferenceType&) override;
 
-        Visitor::visit(typeDefinition);
+private:
+    using Adder = ScopedSetAdder<AST::Type*>;
+    HashSet<AST::Type*> m_types;
+};
 
-        auto success = m_types.remove(&typeDefinition);
-        ASSERT_UNUSED(success, success);
+void RecursiveTypeChecker::visit(AST::TypeDefinition& typeDefinition)
+{
+    Adder adder(m_types, &typeDefinition);
+    if (!adder.isNewEntry()) {
+        setError();
+        return;
     }
 
-    void visit(AST::StructureDefinition& structureDefinition) override
-    {
-        auto addResult = m_types.add(&structureDefinition);
-        if (!addResult.isNewEntry) {
-            setError();
-            return;
-        }
-
-        Visitor::visit(structureDefinition);
+    Visitor::visit(typeDefinition);
+}
 
-        auto success = m_types.remove(&structureDefinition);
-        ASSERT_UNUSED(success, success);
+void RecursiveTypeChecker::visit(AST::StructureDefinition& structureDefinition)
+{
+    Adder adder(m_types, &structureDefinition);
+    if (!adder.isNewEntry()) {
+        setError();
+        return;
     }
 
-    void visit(AST::TypeReference& typeReference) override
-    {
-        auto addResult = m_types.add(&typeReference);
-        if (!addResult.isNewEntry) {
-            setError();
-            return;
-        }
-
-        for (auto& typeArgument : typeReference.typeArguments())
-            checkErrorAndVisit(typeArgument);
-        checkErrorAndVisit(*typeReference.resolvedType());
+    Visitor::visit(structureDefinition);
+}
 
-        auto success = m_types.remove(&typeReference);
-        ASSERT_UNUSED(success, success);
+void RecursiveTypeChecker::visit(AST::TypeReference& typeReference)
+{
+    Adder adder(m_types, &typeReference);
+    if (!adder.isNewEntry()) {
+        setError();
+        return;
     }
 
-    void visit(AST::ReferenceType&) override
-    {
-    }
+    for (auto& typeArgument : typeReference.typeArguments())
+        checkErrorAndVisit(typeArgument);
+    if (typeReference.resolvedType())
+        checkErrorAndVisit(*typeReference.resolvedType());
+}
 
-private:
-    HashSet<AST::Type*> m_types;
-};
+void RecursiveTypeChecker::visit(AST::ReferenceType&)
+{
+}
 
 bool checkRecursiveTypes(Program& program)
 {
     RecursiveTypeChecker recursiveTypeChecker;
     recursiveTypeChecker.checkErrorAndVisit(program);
-    return recursiveTypeChecker.error();
+    return !recursiveTypeChecker.error();
 }
 
 } // namespace WHLSL
index 61a4251..e7838ce 100644 (file)
@@ -41,17 +41,17 @@ static unsigned conversionCost(AST::FunctionDeclaration& candidate, const Vector
 {
     unsigned conversionCost = 0;
     for (size_t i = 0; i < candidate.parameters().size(); ++i) {
-        conversionCost += WTF::visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>&) -> unsigned {
+        conversionCost += argumentTypes[i].get().visit(WTF::makeVisitor([&](UniqueRef<AST::UnnamedType>&) -> unsigned {
             return 0;
-        }, [&](Ref<ResolvableTypeReference>& resolvableTypeReference) -> unsigned {
+        }, [&](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> unsigned {
             return resolvableTypeReference->resolvableType().conversionCost(*candidate.parameters()[i].type());
-        }), argumentTypes[i].get());
+        }));
     }
     // The return type can never be a literal type, so its conversion cost is always 0.
     return conversionCost;
 }
 
-AST::FunctionDeclaration* resolveFunctionOverloadImpl(Vector<std::reference_wrapper<AST::FunctionDeclaration>, 1>& possibleFunctions, Vector<std::reference_wrapper<ResolvingType>>& argumentTypes, Optional<std::reference_wrapper<AST::NamedType>>& castReturnType)
+AST::FunctionDeclaration* resolveFunctionOverloadImpl(Vector<std::reference_wrapper<AST::FunctionDeclaration>, 1>& possibleFunctions, Vector<std::reference_wrapper<ResolvingType>>& argumentTypes, AST::NamedType* castReturnType)
 {
     Vector<std::reference_wrapper<AST::FunctionDeclaration>, 1> candidates;
     for (auto& possibleFunction : possibleFunctions) {
index 8d9346f..0653303 100644 (file)
@@ -43,7 +43,7 @@ class NamedType;
 
 }
 
-AST::FunctionDeclaration* resolveFunctionOverloadImpl(Vector<std::reference_wrapper<AST::FunctionDeclaration>, 1>& possibleFunctions, Vector<std::reference_wrapper<ResolvingType>>& argumentTypes, Optional<std::reference_wrapper<AST::NamedType>>& castReturnType);
+AST::FunctionDeclaration* resolveFunctionOverloadImpl(Vector<std::reference_wrapper<AST::FunctionDeclaration>, 1>& possibleFunctions, Vector<std::reference_wrapper<ResolvingType>>& argumentTypes, AST::NamedType* castReturnType);
 AST::NamedType* resolveTypeOverloadImpl(Vector<std::reference_wrapper<AST::NamedType>, 1>&, AST::TypeArguments&);
 
 }
index 97c0eba..ca5cbb5 100644 (file)
@@ -27,9 +27,9 @@
 
 #if ENABLE(WEBGPU)
 
-#include <memory>
-#include <wtf/Ref.h>
+#include "WHLSLUnnamedType.h"
 #include <wtf/RefCounted.h>
+#include <wtf/RefPtr.h>
 #include <wtf/UniqueRef.h>
 #include <wtf/Variant.h>
 
@@ -40,10 +40,18 @@ namespace WHLSL {
 namespace AST {
 
 class ResolvableType;
-class UnnamedType;
 
 }
 
+// There are cases where the type of one AST node should match the type of another AST node.
+// One example of this is the comma expression - the type of the comma expression should match the type of its last item.
+// If this type happens to be a resolvable type, it will get resolved later. When that happens,
+// both of the AST nodes have to be resolved to the result type. This class represents a
+// reference counted wrapper around a resolvable type so both entries in the type map can point
+// to the same resolvable type, so resolving it once resolves both the entries in the map.
+// This class could probably be represented as
+// "class ResolvableTypeReference : public std::reference_wrapper<AST::ResolvableType>, public RefCounted<ResolvableTypeReference {}"
+// but I didn't want to be too clever.
 class ResolvableTypeReference : public RefCounted<ResolvableTypeReference> {
 public:
     ResolvableTypeReference(AST::ResolvableType& resolvableType)
@@ -53,6 +61,8 @@ public:
 
     ResolvableTypeReference(const ResolvableTypeReference&) = delete;
     ResolvableTypeReference(ResolvableTypeReference&&) = delete;
+    ResolvableTypeReference& operator=(const ResolvableTypeReference&) = delete;
+    ResolvableTypeReference& operator=(ResolvableTypeReference&&) = delete;
 
     AST::ResolvableType& resolvableType() { return *m_resolvableType; }
 
@@ -60,7 +70,52 @@ private:
     AST::ResolvableType* m_resolvableType;
 };
 
-using ResolvingType = Variant<UniqueRef<AST::UnnamedType>, Ref<ResolvableTypeReference>>;
+// This is a thin wrapper around a Variant.
+// It exists so we can make sure that the default constructor does the right thing.
+class ResolvingType {
+public:
+    ResolvingType()
+        : m_inner(RefPtr<ResolvableTypeReference>())
+    {
+    }
+
+    ResolvingType(UniqueRef<AST::UnnamedType>&& v)
+        : m_inner(WTFMove(v))
+    {
+    }
+
+    ResolvingType(RefPtr<ResolvableTypeReference>&& v)
+        : m_inner(WTFMove(v))
+    {
+    }
+
+    ResolvingType(const ResolvingType&) = delete;
+    ResolvingType(ResolvingType&& other)
+        : m_inner(WTFMove(other.m_inner))
+    {
+    }
+
+    ResolvingType& operator=(const ResolvingType&) = delete;
+    ResolvingType& operator=(ResolvingType&& other)
+    {
+        m_inner = WTFMove(other.m_inner);
+        return *this;
+    }
+
+    AST::UnnamedType& getUnnamedType()
+    {
+        ASSERT(WTF::holds_alternative<UniqueRef<AST::UnnamedType>>(m_inner));
+        return WTF::get<UniqueRef<AST::UnnamedType>>(m_inner);
+    }
+
+    template <typename Visitor> auto visit(const Visitor& visitor) -> decltype(WTF::visit(visitor, std::declval<Variant<UniqueRef<AST::UnnamedType>, RefPtr<ResolvableTypeReference>>&>()))
+    {
+        return WTF::visit(visitor, m_inner);
+    }
+
+private:
+    Variant<UniqueRef<AST::UnnamedType>, RefPtr<ResolvableTypeReference>> m_inner;
+};
 
 }
 
 
 #if ENABLE(WEBGPU)
 
-#include <wtf/Vector.h>
+#include <wtf/HashSet.h>
 
 namespace WebCore {
 
 namespace WHLSL {
 
-namespace Metal {
+template<typename T> class ScopedSetAdder {
+public:
+    ScopedSetAdder(HashSet<T>& set, T&& item)
+        : m_set(set)
+        , m_item(WTFMove(item))
+    {
+        m_isNewEntry = static_cast<bool>(m_set.add(m_item));
+    }
 
-struct MappedBindGroup {
-    unsigned argumentBufferIndex;
-    Vector<unsigned> bindingIndices;
-};
+    ~ScopedSetAdder()
+    {
+        if (!m_isNewEntry)
+            return;
+        auto success = m_set.remove(m_item);
+        ASSERT_UNUSED(success, success);
+    }
 
-using MappedBindGroups = Vector<MappedBindGroup>; // Parallel to the input resource Layout.
+    bool isNewEntry() const { return m_isNewEntry; }
 
-} // namespace Metal
+private:
+    HashSet<T>& m_set;
+    T m_item;
+    bool m_isNewEntry;
+};
 
-} // namespace WHLSL
+}
 
-} // namespace WebCore
+}
 
 #endif
index 0bdcc0c..f614567 100644 (file)
@@ -189,16 +189,78 @@ static bool isAcceptableFormat(TextureFormat textureFormat, AST::UnnamedType& un
 {
     if (isColor) {
         switch (textureFormat) {
-        case TextureFormat::R8G8B8A8Unorm:
-        case TextureFormat::R8G8B8A8Uint:
-        case TextureFormat::B8G8R8A8Unorm:
+        case TextureFormat::R8Unorm:
+        case TextureFormat::R8UnormSrgb:
+        case TextureFormat::R8Snorm:
+        case TextureFormat::R16Unorm:
+        case TextureFormat::R16Snorm:
+        case TextureFormat::R16Float:
+        case TextureFormat::R32Float:
+            return matches(unnamedType, intrinsics.floatType());
+        case TextureFormat::RG8Unorm:
+        case TextureFormat::RG8UnormSrgb:
+        case TextureFormat::RG8Snorm:
+        case TextureFormat::RG16Unorm:
+        case TextureFormat::RG16Snorm:
+        case TextureFormat::RG16Float:
+        case TextureFormat::RG32Float:
+            return matches(unnamedType, intrinsics.float2Type());
+        case TextureFormat::B5G6R5Unorm:
+        case TextureFormat::RG11B10Float:
+            return matches(unnamedType, intrinsics.float3Type());
+        case TextureFormat::RGBA8Unorm:
+        case TextureFormat::RGBA8UnormSrgb:
+        case TextureFormat::BGRA8Unorm:
+        case TextureFormat::BGRA8UnormSrgb:
+        case TextureFormat::RGBA8Snorm:
+        case TextureFormat::RGB10A2Unorm:
+        case TextureFormat::RGBA16Unorm:
+        case TextureFormat::RGBA16Snorm:
+        case TextureFormat::RGBA16Float:
+        case TextureFormat::RGBA32Float:
             return matches(unnamedType, intrinsics.float4Type());
+        case TextureFormat::R8Uint:
+            return matches(unnamedType, intrinsics.ucharType());
+        case TextureFormat::R8Sint:
+            return matches(unnamedType, intrinsics.charType());
+        case TextureFormat::R16Uint:
+            return matches(unnamedType, intrinsics.ushortType());
+        case TextureFormat::R16Sint:
+            return matches(unnamedType, intrinsics.shortType());
+        case TextureFormat::R32Uint:
+            return matches(unnamedType, intrinsics.uintType());
+        case TextureFormat::R32Sint:
+            return matches(unnamedType, intrinsics.intType());
+        case TextureFormat::RG8Uint:
+            return matches(unnamedType, intrinsics.uchar2Type());
+        case TextureFormat::RG8Sint:
+            return matches(unnamedType, intrinsics.char2Type());
+        case TextureFormat::RG16Uint:
+            return matches(unnamedType, intrinsics.ushort2Type());
+        case TextureFormat::RG16Sint:
+            return matches(unnamedType, intrinsics.short2Type());
+        case TextureFormat::RG32Uint:
+            return matches(unnamedType, intrinsics.uint2Type());
+        case TextureFormat::RG32Sint:
+            return matches(unnamedType, intrinsics.int2Type());
+        case TextureFormat::RGBA8Uint:
+            return matches(unnamedType, intrinsics.uchar4Type());
+        case TextureFormat::RGBA8Sint:
+            return matches(unnamedType, intrinsics.char4Type());
+        case TextureFormat::RGBA16Uint:
+            return matches(unnamedType, intrinsics.ushort4Type());
+        case TextureFormat::RGBA16Sint:
+            return matches(unnamedType, intrinsics.short4Type());
+        case TextureFormat::RGBA32Uint:
+            return matches(unnamedType, intrinsics.uint4Type());
+        case TextureFormat::RGBA32Sint:
+            return matches(unnamedType, intrinsics.int4Type());
         default:
-            ASSERT(textureFormat == TextureFormat::D32FloatS8Uint);
+            ASSERT_NOT_REACHED();
             return false;
         }
     }
-    return textureFormat == TextureFormat::D32FloatS8Uint && matches(unnamedType, intrinsics.floatType());
+    return false;
 }
 
 static Optional<HashMap<AttachmentDescriptor*, size_t>> matchColorAttachments(Vector<EntryPointItem>& fragmentOutputs, Vector<AttachmentDescriptor>& attachmentDescriptors, Intrinsics& intrinsics)
index 570747b..79a2c60 100644 (file)
@@ -105,10 +105,6 @@ typedef float4x3 = matrix<float, 4, 3>;
 native typedef matrix<float, 4, 4>;
 typedef float4x4 = matrix<float, 4, 4>;
 native typedef sampler;
-native typedef Texture1D<uchar>;
-native typedef Texture1D<uchar2>;
-native typedef Texture1D<uchar3>;
-native typedef Texture1D<uchar4>;
 native typedef Texture1D<ushort>;
 native typedef Texture1D<ushort2>;
 native typedef Texture1D<ushort3>;
@@ -117,10 +113,6 @@ native typedef Texture1D<uint>;
 native typedef Texture1D<uint2>;
 native typedef Texture1D<uint3>;
 native typedef Texture1D<uint4>;
-native typedef Texture1D<char>;
-native typedef Texture1D<char2>;
-native typedef Texture1D<char3>;
-native typedef Texture1D<char4>;
 native typedef Texture1D<short>;
 native typedef Texture1D<short2>;
 native typedef Texture1D<short3>;
@@ -137,10 +129,6 @@ native typedef Texture1D<float>;
 native typedef Texture1D<float2>;
 native typedef Texture1D<float3>;
 native typedef Texture1D<float4>;
-native typedef RWTexture1D<uchar>;
-native typedef RWTexture1D<uchar2>;
-native typedef RWTexture1D<uchar3>;
-native typedef RWTexture1D<uchar4>;
 native typedef RWTexture1D<ushort>;
 native typedef RWTexture1D<ushort2>;
 native typedef RWTexture1D<ushort3>;
@@ -149,10 +137,6 @@ native typedef RWTexture1D<uint>;
 native typedef RWTexture1D<uint2>;
 native typedef RWTexture1D<uint3>;
 native typedef RWTexture1D<uint4>;
-native typedef RWTexture1D<char>;
-native typedef RWTexture1D<char2>;
-native typedef RWTexture1D<char3>;
-native typedef RWTexture1D<char4>;
 native typedef RWTexture1D<short>;
 native typedef RWTexture1D<short2>;
 native typedef RWTexture1D<short3>;
@@ -169,10 +153,6 @@ native typedef RWTexture1D<float>;
 native typedef RWTexture1D<float2>;
 native typedef RWTexture1D<float3>;
 native typedef RWTexture1D<float4>;
-native typedef Texture1DArray<uchar>;
-native typedef Texture1DArray<uchar2>;
-native typedef Texture1DArray<uchar3>;
-native typedef Texture1DArray<uchar4>;
 native typedef Texture1DArray<ushort>;
 native typedef Texture1DArray<ushort2>;
 native typedef Texture1DArray<ushort3>;
@@ -181,10 +161,6 @@ native typedef Texture1DArray<uint>;
 native typedef Texture1DArray<uint2>;
 native typedef Texture1DArray<uint3>;
 native typedef Texture1DArray<uint4>;
-native typedef Texture1DArray<char>;
-native typedef Texture1DArray<char2>;
-native typedef Texture1DArray<char3>;
-native typedef Texture1DArray<char4>;
 native typedef Texture1DArray<short>;
 native typedef Texture1DArray<short2>;
 native typedef Texture1DArray<short3>;
@@ -201,10 +177,6 @@ native typedef Texture1DArray<float>;
 native typedef Texture1DArray<float2>;
 native typedef Texture1DArray<float3>;
 native typedef Texture1DArray<float4>;
-native typedef RWTexture1DArray<uchar>;
-native typedef RWTexture1DArray<uchar2>;
-native typedef RWTexture1DArray<uchar3>;
-native typedef RWTexture1DArray<uchar4>;
 native typedef RWTexture1DArray<ushort>;
 native typedef RWTexture1DArray<ushort2>;
 native typedef RWTexture1DArray<ushort3>;
@@ -213,10 +185,6 @@ native typedef RWTexture1DArray<uint>;
 native typedef RWTexture1DArray<uint2>;
 native typedef RWTexture1DArray<uint3>;
 native typedef RWTexture1DArray<uint4>;
-native typedef RWTexture1DArray<char>;
-native typedef RWTexture1DArray<char2>;
-native typedef RWTexture1DArray<char3>;
-native typedef RWTexture1DArray<char4>;
 native typedef RWTexture1DArray<short>;
 native typedef RWTexture1DArray<short2>;
 native typedef RWTexture1DArray<short3>;
@@ -233,10 +201,6 @@ native typedef RWTexture1DArray<float>;
 native typedef RWTexture1DArray<float2>;
 native typedef RWTexture1DArray<float3>;
 native typedef RWTexture1DArray<float4>;
-native typedef Texture2D<uchar>;
-native typedef Texture2D<uchar2>;
-native typedef Texture2D<uchar3>;
-native typedef Texture2D<uchar4>;
 native typedef Texture2D<ushort>;
 native typedef Texture2D<ushort2>;
 native typedef Texture2D<ushort3>;
@@ -245,10 +209,6 @@ native typedef Texture2D<uint>;
 native typedef Texture2D<uint2>;
 native typedef Texture2D<uint3>;
 native typedef Texture2D<uint4>;
-native typedef Texture2D<char>;
-native typedef Texture2D<char2>;
-native typedef Texture2D<char3>;
-native typedef Texture2D<char4>;
 native typedef Texture2D<short>;
 native typedef Texture2D<short2>;
 native typedef Texture2D<short3>;
@@ -265,10 +225,6 @@ native typedef Texture2D<float>;
 native typedef Texture2D<float2>;
 native typedef Texture2D<float3>;
 native typedef Texture2D<float4>;
-native typedef RWTexture2D<uchar>;
-native typedef RWTexture2D<uchar2>;
-native typedef RWTexture2D<uchar3>;
-native typedef RWTexture2D<uchar4>;
 native typedef RWTexture2D<ushort>;
 native typedef RWTexture2D<ushort2>;
 native typedef RWTexture2D<ushort3>;
@@ -277,10 +233,6 @@ native typedef RWTexture2D<uint>;
 native typedef RWTexture2D<uint2>;
 native typedef RWTexture2D<uint3>;
 native typedef RWTexture2D<uint4>;
-native typedef RWTexture2D<char>;
-native typedef RWTexture2D<char2>;
-native typedef RWTexture2D<char3>;
-native typedef RWTexture2D<char4>;
 native typedef RWTexture2D<short>;
 native typedef RWTexture2D<short2>;
 native typedef RWTexture2D<short3>;
@@ -297,10 +249,6 @@ native typedef RWTexture2D<float>;
 native typedef RWTexture2D<float2>;
 native typedef RWTexture2D<float3>;
 native typedef RWTexture2D<float4>;
-native typedef Texture2DArray<uchar>;
-native typedef Texture2DArray<uchar2>;
-native typedef Texture2DArray<uchar3>;
-native typedef Texture2DArray<uchar4>;
 native typedef Texture2DArray<ushort>;
 native typedef Texture2DArray<ushort2>;
 native typedef Texture2DArray<ushort3>;
@@ -309,10 +257,6 @@ native typedef Texture2DArray<uint>;
 native typedef Texture2DArray<uint2>;
 native typedef Texture2DArray<uint3>;
 native typedef Texture2DArray<uint4>;
-native typedef Texture2DArray<char>;
-native typedef Texture2DArray<char2>;
-native typedef Texture2DArray<char3>;
-native typedef Texture2DArray<char4>;
 native typedef Texture2DArray<short>;
 native typedef Texture2DArray<short2>;
 native typedef Texture2DArray<short3>;
@@ -329,10 +273,6 @@ native typedef Texture2DArray<float>;
 native typedef Texture2DArray<float2>;
 native typedef Texture2DArray<float3>;
 native typedef Texture2DArray<float4>;
-native typedef RWTexture2DArray<uchar>;
-native typedef RWTexture2DArray<uchar2>;
-native typedef RWTexture2DArray<uchar3>;
-native typedef RWTexture2DArray<uchar4>;
 native typedef RWTexture2DArray<ushort>;
 native typedef RWTexture2DArray<ushort2>;
 native typedef RWTexture2DArray<ushort3>;
@@ -341,10 +281,6 @@ native typedef RWTexture2DArray<uint>;
 native typedef RWTexture2DArray<uint2>;
 native typedef RWTexture2DArray<uint3>;
 native typedef RWTexture2DArray<uint4>;
-native typedef RWTexture2DArray<char>;
-native typedef RWTexture2DArray<char2>;
-native typedef RWTexture2DArray<char3>;
-native typedef RWTexture2DArray<char4>;
 native typedef RWTexture2DArray<short>;
 native typedef RWTexture2DArray<short2>;
 native typedef RWTexture2DArray<short3>;
@@ -361,10 +297,6 @@ native typedef RWTexture2DArray<float>;
 native typedef RWTexture2DArray<float2>;
 native typedef RWTexture2DArray<float3>;
 native typedef RWTexture2DArray<float4>;
-native typedef Texture3D<uchar>;
-native typedef Texture3D<uchar2>;
-native typedef Texture3D<uchar3>;
-native typedef Texture3D<uchar4>;
 native typedef Texture3D<ushort>;
 native typedef Texture3D<ushort2>;
 native typedef Texture3D<ushort3>;
@@ -373,10 +305,6 @@ native typedef Texture3D<uint>;
 native typedef Texture3D<uint2>;
 native typedef Texture3D<uint3>;
 native typedef Texture3D<uint4>;
-native typedef Texture3D<char>;
-native typedef Texture3D<char2>;
-native typedef Texture3D<char3>;
-native typedef Texture3D<char4>;
 native typedef Texture3D<short>;
 native typedef Texture3D<short2>;
 native typedef Texture3D<short3>;
@@ -393,10 +321,6 @@ native typedef Texture3D<float>;
 native typedef Texture3D<float2>;
 native typedef Texture3D<float3>;
 native typedef Texture3D<float4>;
-native typedef RWTexture3D<uchar>;
-native typedef RWTexture3D<uchar2>;
-native typedef RWTexture3D<uchar3>;
-native typedef RWTexture3D<uchar4>;
 native typedef RWTexture3D<ushort>;
 native typedef RWTexture3D<ushort2>;
 native typedef RWTexture3D<ushort3>;
@@ -405,10 +329,6 @@ native typedef RWTexture3D<uint>;
 native typedef RWTexture3D<uint2>;
 native typedef RWTexture3D<uint3>;
 native typedef RWTexture3D<uint4>;
-native typedef RWTexture3D<char>;
-native typedef RWTexture3D<char2>;
-native typedef RWTexture3D<char3>;
-native typedef RWTexture3D<char4>;
 native typedef RWTexture3D<short>;
 native typedef RWTexture3D<short2>;
 native typedef RWTexture3D<short3>;
@@ -425,10 +345,6 @@ native typedef RWTexture3D<float>;
 native typedef RWTexture3D<float2>;
 native typedef RWTexture3D<float3>;
 native typedef RWTexture3D<float4>;
-native typedef TextureCube<uchar>;
-native typedef TextureCube<uchar2>;
-native typedef TextureCube<uchar3>;
-native typedef TextureCube<uchar4>;
 native typedef TextureCube<ushort>;
 native typedef TextureCube<ushort2>;
 native typedef TextureCube<ushort3>;
@@ -437,10 +353,6 @@ native typedef TextureCube<uint>;
 native typedef TextureCube<uint2>;
 native typedef TextureCube<uint3>;
 native typedef TextureCube<uint4>;
-native typedef TextureCube<char>;
-native typedef TextureCube<char2>;
-native typedef TextureCube<char3>;
-native typedef TextureCube<char4>;
 native typedef TextureCube<short>;
 native typedef TextureCube<short2>;
 native typedef TextureCube<short3>;
@@ -458,14 +370,24 @@ native typedef TextureCube<float2>;
 native typedef TextureCube<float3>;
 native typedef TextureCube<float4>;
 native typedef TextureDepth2D<float>;
-native typedef TextureDepth2D<half>;
 native typedef RWTextureDepth2D<float>;
-native typedef RWTextureDepth2D<half>;
 native typedef TextureDepth2DArray<float>;
-native typedef TextureDepth2DArray<half>;
 native typedef RWTextureDepth2DArray<float>;
-native typedef RWTextureDepth2DArray<half>;
 native typedef TextureDepthCube<float>;
-native typedef TextureDepthCube<half>;
 
-// FIXME: Insert the rest of the standard library once the parser is fast enough
+native float operator.x(float4);
+native float operator.y(float4);
+native float operator.z(float4);
+native float operator.w(float4);
+native float4 operator.x=(float4, float);
+native float4 operator.y=(float4, float);
+native float4 operator.z=(float4, float);
+native float4 operator.w=(float4, float);
+
+native float ddx(float);
+native float ddy(float);
+native void AllMemoryBarrierWithGroupSync();
+native void DeviceMemoryBarrierWithGroupSync();
+native void GroupMemoryBarrierWithGroupSync();
+
+// FIXME: https://bugs.webkit.org/show_bug.cgi?id=192890 Insert the rest of the standard library once the parser is fast enough
index 3bf6e87..ae3ca64 100644 (file)
@@ -44,7 +44,7 @@ public:
     void visit(AST::ArrayType& arrayType) override
     {
         m_arrayTypes.append(arrayType);
-        checkErrorAndVisit(arrayType);
+        Visitor::visit(arrayType);
     }
 
     Vector<std::reference_wrapper<AST::ArrayType>>&& takeArrayTypes()
@@ -56,7 +56,7 @@ private:
     Vector<std::reference_wrapper<AST::ArrayType>> m_arrayTypes;
 };
 
-void synthesizeArrayOperatorLength(Program& program)
+bool synthesizeArrayOperatorLength(Program& program)
 {
     FindArrayTypes findArrayTypes;
     findArrayTypes.checkErrorAndVisit(program);
@@ -69,8 +69,10 @@ void synthesizeArrayOperatorLength(Program& program)
         AST::VariableDeclarations parameters;
         parameters.append(WTFMove(variableDeclaration));
         AST::NativeFunctionDeclaration nativeFunctionDeclaration(AST::FunctionDeclaration(Lexer::Token(arrayType.get().origin()), AST::AttributeBlock(), WTF::nullopt, AST::TypeReference::wrap(Lexer::Token(arrayType.get().origin()), program.intrinsics().uintType()), "operator.length"_str, WTFMove(parameters), WTF::nullopt, isOperator));
-        program.append(WTFMove(nativeFunctionDeclaration));
+        if (!program.append(WTFMove(nativeFunctionDeclaration)))
+            return false;
     }
+    return true;
 }
 
 } // namespace WHLSL
index ed32b64..95af98c 100644 (file)
@@ -51,37 +51,37 @@ public:
     void visit(AST::PointerType& pointerType) override
     {
         m_unnamedTypes.append(pointerType);
-        checkErrorAndVisit(pointerType);
+        Visitor::visit(pointerType);
     }
 
     void visit(AST::ArrayReferenceType& arrayReferenceType) override
     {
         m_unnamedTypes.append(arrayReferenceType);
-        checkErrorAndVisit(arrayReferenceType);
+        Visitor::visit(arrayReferenceType);
     }
 
     void visit(AST::ArrayType& arrayType) override
     {
         m_unnamedTypes.append(arrayType);
-        checkErrorAndVisit(arrayType);
+        Visitor::visit(arrayType);
     }
 
     void visit(AST::EnumerationDefinition& enumerationDefinition) override
     {
         m_namedTypes.append(enumerationDefinition);
-        checkErrorAndVisit(enumerationDefinition);
+        Visitor::visit(enumerationDefinition);
     }
 
     void visit(AST::StructureDefinition& structureDefinition) override
     {
         m_namedTypes.append(structureDefinition);
-        checkErrorAndVisit(structureDefinition);
+        Visitor::visit(structureDefinition);
     }
 
     void visit(AST::NativeTypeDeclaration& nativeTypeDeclaration) override
     {
         m_namedTypes.append(nativeTypeDeclaration);
-        checkErrorAndVisit(nativeTypeDeclaration);
+        Visitor::visit(nativeTypeDeclaration);
     }
 
     Vector<std::reference_wrapper<AST::UnnamedType>>&& takeUnnamedTypes()
@@ -99,16 +99,16 @@ private:
     Vector<std::reference_wrapper<AST::NamedType>> m_namedTypes;
 };
 
-void synthesizeConstructors(Program& program)
+bool synthesizeConstructors(Program& program)
 {
     FindAllTypes findAllTypes;
     findAllTypes.checkErrorAndVisit(program);
-    auto m_unnamedTypes = findAllTypes.takeUnnamedTypes();
-    auto m_namedTypes = findAllTypes.takeNamedTypes();
+    auto unnamedTypes = findAllTypes.takeUnnamedTypes();
+    auto namedTypes = findAllTypes.takeNamedTypes();
 
     bool isOperator = true;
 
-    for (auto& unnamedType : m_unnamedTypes) {
+    for (auto& unnamedType : unnamedTypes) {
         AST::VariableDeclaration variableDeclaration(Lexer::Token(unnamedType.get().origin()), AST::Qualifiers(), { unnamedType.get().clone() }, String(), WTF::nullopt, WTF::nullopt);
         AST::VariableDeclarations parameters;
         parameters.append(WTFMove(variableDeclaration));
@@ -116,10 +116,16 @@ void synthesizeConstructors(Program& program)
         program.append(WTFMove(copyConstructor));
 
         AST::NativeFunctionDeclaration defaultConstructor(AST::FunctionDeclaration(Lexer::Token(unnamedType.get().origin()), AST::AttributeBlock(), WTF::nullopt, unnamedType.get().clone(), "operator cast"_str, AST::VariableDeclarations(), WTF::nullopt, isOperator));
-        program.append(WTFMove(defaultConstructor));
+        if (!program.append(WTFMove(defaultConstructor)))
+            return false;
     }
 
-    for (auto& namedType : m_namedTypes) {
+    for (auto& namedType : namedTypes) {
+        if (matches(namedType, program.intrinsics().voidType()))
+            continue;
+        if (is<AST::NativeTypeDeclaration>(static_cast<AST::NamedType&>(namedType)) && downcast<AST::NativeTypeDeclaration>(static_cast<AST::NamedType&>(namedType)).isAtomic())
+            continue;
+
         AST::VariableDeclaration variableDeclaration(Lexer::Token(namedType.get().origin()), AST::Qualifiers(), { AST::TypeReference::wrap(Lexer::Token(namedType.get().origin()), namedType.get()) }, String(), WTF::nullopt, WTF::nullopt);
         AST::VariableDeclarations parameters;
         parameters.append(WTFMove(variableDeclaration));
@@ -127,8 +133,10 @@ void synthesizeConstructors(Program& program)
         program.append(WTFMove(copyConstructor));
 
         AST::NativeFunctionDeclaration defaultConstructor(AST::FunctionDeclaration(Lexer::Token(namedType.get().origin()), AST::AttributeBlock(), WTF::nullopt, AST::TypeReference::wrap(Lexer::Token(namedType.get().origin()), namedType.get()), "operator cast"_str, AST::VariableDeclarations(), WTF::nullopt, isOperator));
-        program.append(WTFMove(defaultConstructor));
+        if (!program.append(WTFMove(defaultConstructor)))
+            return false;
     }
+    return true;
 }
 
 } // namespace WHLSL
index 19a05be..67a0863 100644 (file)
@@ -33,7 +33,7 @@ namespace WHLSL {
 
 class Program;
 
-void synthesizeConstructors(Program&);
+bool synthesizeConstructors(Program&);
 
 }
 
index 5529d03..1797fe5 100644 (file)
@@ -36,7 +36,7 @@ namespace WebCore {
 
 namespace WHLSL {
 
-void synthesizeEnumerationFunctions(Program& program)
+bool synthesizeEnumerationFunctions(Program& program)
 {
     bool isOperator = true;
     for (auto& enumerationDefinition : program.enumerationDefinitions()) {
@@ -47,7 +47,8 @@ void synthesizeEnumerationFunctions(Program& program)
             parameters.append(WTFMove(variableDeclaration1));
             parameters.append(WTFMove(variableDeclaration2));
             AST::NativeFunctionDeclaration nativeFunctionDeclaration(AST::FunctionDeclaration(Lexer::Token(enumerationDefinition->origin()), AST::AttributeBlock(), WTF::nullopt, AST::TypeReference::wrap(Lexer::Token(enumerationDefinition->origin()), program.intrinsics().boolType()), "operator=="_str, WTFMove(parameters), WTF::nullopt, isOperator));
-            program.append(WTFMove(nativeFunctionDeclaration));
+            if (!program.append(WTFMove(nativeFunctionDeclaration)))
+                return false;
         }
 
         {
@@ -55,7 +56,8 @@ void synthesizeEnumerationFunctions(Program& program)
             AST::VariableDeclarations parameters;
             parameters.append(WTFMove(variableDeclaration));
             AST::NativeFunctionDeclaration nativeFunctionDeclaration(AST::FunctionDeclaration(Lexer::Token(enumerationDefinition->origin()), AST::AttributeBlock(), WTF::nullopt, enumerationDefinition->type().clone(), "operator.value"_str, WTFMove(parameters), WTF::nullopt, isOperator));
-            program.append(WTFMove(nativeFunctionDeclaration));
+            if (!program.append(WTFMove(nativeFunctionDeclaration)))
+                return false;
         }
 
         {
@@ -63,7 +65,8 @@ void synthesizeEnumerationFunctions(Program& program)
             AST::VariableDeclarations parameters;
             parameters.append(WTFMove(variableDeclaration));
             AST::NativeFunctionDeclaration nativeFunctionDeclaration(AST::FunctionDeclaration(Lexer::Token(enumerationDefinition->origin()), AST::AttributeBlock(), WTF::nullopt, enumerationDefinition->type().clone(), "operator cast"_str, WTFMove(parameters), WTF::nullopt, isOperator));
-            program.append(WTFMove(nativeFunctionDeclaration));
+            if (!program.append(WTFMove(nativeFunctionDeclaration)))
+                return false;
         }
 
         {
@@ -71,9 +74,11 @@ void synthesizeEnumerationFunctions(Program& program)
             AST::VariableDeclarations parameters;
             parameters.append(WTFMove(variableDeclaration));
             AST::NativeFunctionDeclaration nativeFunctionDeclaration(AST::FunctionDeclaration(Lexer::Token(enumerationDefinition->origin()), AST::AttributeBlock(), WTF::nullopt, { AST::TypeReference::wrap(Lexer::Token(enumerationDefinition->origin()), enumerationDefinition) }, "operator cast"_str, WTFMove(parameters), WTF::nullopt, isOperator));
-            program.append(WTFMove(nativeFunctionDeclaration));
+            if (!program.append(WTFMove(nativeFunctionDeclaration)))
+                return false;
         }
     }
+    return true;
 }
 
 } // namespace WHLSL
index 53024a6..8267221 100644 (file)
@@ -39,7 +39,7 @@ namespace WebCore {
 
 namespace WHLSL {
 
-void synthesizeStructureAccessors(Program& program)
+bool synthesizeStructureAccessors(Program& program)
 {
     bool isOperator = true;
     for (auto& structureDefinition : program.structureDefinitions()) {
@@ -50,7 +50,8 @@ void synthesizeStructureAccessors(Program& program)
                 AST::VariableDeclarations parameters;
                 parameters.append(WTFMove(variableDeclaration));
                 AST::NativeFunctionDeclaration nativeFunctionDeclaration(AST::FunctionDeclaration(Lexer::Token(structureElement.origin()), AST::AttributeBlock(), WTF::nullopt, structureElement.type().clone(), makeString("operator.", structureElement.name()), WTFMove(parameters), WTF::nullopt, isOperator));
-                program.append(WTFMove(nativeFunctionDeclaration));
+                if (!program.append(WTFMove(nativeFunctionDeclaration)))
+                    return false;
             }
 
             {
@@ -61,7 +62,8 @@ void synthesizeStructureAccessors(Program& program)
                 parameters.append(WTFMove(variableDeclaration1));
                 parameters.append(WTFMove(variableDeclaration2));
                 AST::NativeFunctionDeclaration nativeFunctionDeclaration(AST::FunctionDeclaration(Lexer::Token(structureElement.origin()), AST::AttributeBlock(), WTF::nullopt, AST::TypeReference::wrap(Lexer::Token(structureElement.origin()), structureDefinition), makeString("operator.", structureElement.name(), '='), WTFMove(parameters), WTF::nullopt, isOperator));
-                program.append(WTFMove(nativeFunctionDeclaration));
+                if (!program.append(WTFMove(nativeFunctionDeclaration)))
+                    return false;
             }
 
             // The ander: operator&.field
@@ -74,12 +76,14 @@ void synthesizeStructureAccessors(Program& program)
                 AST::NativeFunctionDeclaration nativeFunctionDeclaration(AST::FunctionDeclaration(Lexer::Token(structureElement.origin()), AST::AttributeBlock(), WTF::nullopt, WTFMove(returnType), makeString("operator&.", structureElement.name()), WTFMove(parameters), WTF::nullopt, isOperator));
                 return nativeFunctionDeclaration;
             };
-            program.append(createAnder(AST::AddressSpace::Constant));
-            program.append(createAnder(AST::AddressSpace::Device));
-            program.append(createAnder(AST::AddressSpace::Threadgroup));
-            program.append(createAnder(AST::AddressSpace::Thread));
+            if (!program.append(createAnder(AST::AddressSpace::Constant))
+                || !program.append(createAnder(AST::AddressSpace::Device))
+                || !program.append(createAnder(AST::AddressSpace::Threadgroup))
+                || !program.append(createAnder(AST::AddressSpace::Thread)))
+                return false;
         }
     }
+    return true;
 }
 
 } // namespace WHLSL
index bb9f2b9..1673162 100644 (file)
@@ -33,7 +33,7 @@ namespace WHLSL {
 
 class Program;
 
-void synthesizeStructureAccessors(Program&);
+bool synthesizeStructureAccessors(Program&);
 
 }
 
index 80b4e1c..ed2e8c0 100644 (file)
@@ -562,7 +562,7 @@ void Visitor::visit(AST::CallExpression& callExpression)
     for (auto& argument : callExpression.arguments())
         checkErrorAndVisit(argument);
     if (callExpression.castReturnType())
-        checkErrorAndVisit(callExpression.castReturnType()->get());
+        checkErrorAndVisit(*callExpression.castReturnType());
 }
 
 void Visitor::visit(AST::CommaExpression& commaExpression)
index d2a4a1c..b9b5665 100644 (file)
@@ -122,7 +122,7 @@ Ref<WebGPUBindGroup> WebGPUDevice::createBindGroup(WebGPUBindGroupDescriptor&& d
 RefPtr<WebGPUShaderModule> WebGPUDevice::createShaderModule(WebGPUShaderModuleDescriptor&& descriptor) const
 {
     // FIXME: What can be validated here?
-    if (auto module = m_device->createShaderModule(GPUShaderModuleDescriptor { descriptor.code }))
+    if (auto module = m_device->createShaderModule(GPUShaderModuleDescriptor { descriptor.code, descriptor.isWHLSL }))
         return WebGPUShaderModule::create(module.releaseNonNull());
     return nullptr;
 }
index af32a98..b8d9817 100644 (file)
@@ -33,6 +33,7 @@ namespace WebCore {
 
 struct WebGPUShaderModuleDescriptor {
     String code;
+    bool isWHLSL;
 };
 
 } // namespace WebCore
index 32d707e..971352a 100644 (file)
@@ -29,4 +29,5 @@
     EnabledAtRuntime=WebGPU
 ] dictionary WebGPUShaderModuleDescriptor {
     /*ArrayBuffer*/ DOMString code; // FIXME: DOMString for MTL prototyping only.
+    boolean isWHLSL = false;
 };
index ecc6e78..67089df 100644 (file)
                1C33277121CF0BE1000DC9F2 /* WHLSLNamedType.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = WHLSLNamedType.h; sourceTree = "<group>"; };
                1C33277221CF0D2E000DC9F2 /* WHLSLUnnamedType.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = WHLSLUnnamedType.h; sourceTree = "<group>"; };
                1C3969CF1B74211E002BCFA7 /* FontCacheCoreText.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = FontCacheCoreText.cpp; sourceTree = "<group>"; };
+               1C59B0182238687900853805 /* WHLSLScopedSetAdder.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = WHLSLScopedSetAdder.h; sourceTree = "<group>"; };
                1C66260E1C6E7CA600AB527C /* FontFace.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = FontFace.cpp; sourceTree = "<group>"; };
                1C66260F1C6E7CA600AB527C /* FontFace.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = FontFace.h; sourceTree = "<group>"; };
                1C81B9560E97330800266E07 /* InspectorController.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = InspectorController.h; sourceTree = "<group>"; };
                C24A57B321FB8DDA004C6DD1 /* WHLSLSemanticMatcher.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = WHLSLSemanticMatcher.h; sourceTree = "<group>"; };
                C24A57BA21FEAFEA004C6DD1 /* WHLSLPrepare.cpp */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.cpp; path = WHLSLPrepare.cpp; sourceTree = "<group>"; };
                C24A57BB21FEAFEA004C6DD1 /* WHLSLPrepare.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = WHLSLPrepare.h; sourceTree = "<group>"; };
-               C24A57BE21FEC65C004C6DD1 /* WHLSLMappedBindings.h */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.c.h; path = WHLSLMappedBindings.h; sourceTree = "<group>"; };
                C26017A11C72DC9900F74A16 /* CSSFontFaceSet.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = CSSFontFaceSet.cpp; sourceTree = "<group>"; };
                C26017A21C72DC9900F74A16 /* CSSFontFaceSet.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = CSSFontFaceSet.h; sourceTree = "<group>"; };
                C280833C1C6DB194001451B6 /* FontFace.idl */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = FontFace.idl; sourceTree = "<group>"; };
                                1CECB3BB21F511AA00F44542 /* WHLSLEntryPointScaffolding.h */,
                                1CECB3AF21F2B98400F44542 /* WHLSLFunctionWriter.cpp */,
                                1CECB3B221F2B98600F44542 /* WHLSLFunctionWriter.h */,
-                               C24A57BE21FEC65C004C6DD1 /* WHLSLMappedBindings.h */,
                                1CECB3B521F50AC700F44542 /* WHLSLMetalCodeGenerator.cpp */,
                                1CECB3B621F50AC700F44542 /* WHLSLMetalCodeGenerator.h */,
                                1CECB3B821F50D1000F44542 /* WHLSLNativeFunctionWriter.cpp */,
                                C234A99921E90F29003C984D /* WHLSLResolveOverloadImpl.cpp */,
                                C234A99721E90F28003C984D /* WHLSLResolveOverloadImpl.h */,
                                C234A99D21E910BD003C984D /* WHLSLResolvingType.h */,
+                               1C59B0182238687900853805 /* WHLSLScopedSetAdder.h */,
                                C24A57B221FB8DDA004C6DD1 /* WHLSLSemanticMatcher.cpp */,
                                C24A57B321FB8DDA004C6DD1 /* WHLSLSemanticMatcher.h */,
                                C21BF74521CD969800227979 /* WHLSLStandardLibrary.txt */,
index 262618f..e8c6ab9 100644 (file)
@@ -38,6 +38,8 @@ class GPUPipelineLayout : public RefCounted<GPUPipelineLayout> {
 public:
     static Ref<GPUPipelineLayout> create(GPUPipelineLayoutDescriptor&&);
 
+    const Vector<RefPtr<const GPUBindGroupLayout>>& bindGroupLayouts() const { return m_bindGroupLayouts; }
+
 private:
     explicit GPUPipelineLayout(GPUPipelineLayoutDescriptor&&);
 
index cbfb3f9..e9a87bb 100644 (file)
@@ -46,12 +46,15 @@ class GPUShaderModule : public RefCounted<GPUShaderModule> {
 public:
     static RefPtr<GPUShaderModule> create(const GPUDevice&, GPUShaderModuleDescriptor&&);
 
-    PlatformShaderModule* platformShaderModule() const { return m_platformShaderModule.get(); }
+    PlatformShaderModule* platformShaderModule() const { return m_whlslSource.isNull() ? m_platformShaderModule.get() : nullptr; }
+    const String& whlslSource() const { return m_whlslSource; }
 
 private:
     GPUShaderModule(PlatformShaderModuleSmartPtr&&);
+    GPUShaderModule(String&& whlslSource);
 
     PlatformShaderModuleSmartPtr m_platformShaderModule;
+    String m_whlslSource;
 };
 
 } // namespace WebCore
index 6cfe0ad..5c3ecf6 100644 (file)
@@ -33,6 +33,7 @@ namespace WebCore {
 
 struct GPUShaderModuleDescriptor {
     String code;
+    bool isWHLSL;
 };
 
 } // namespace WebCore
index e6f4aed..a9d500a 100644 (file)
@@ -31,6 +31,7 @@
 #import "GPULimits.h"
 #import "GPUUtils.h"
 #import "Logging.h"
+#import "WHLSLPrepare.h"
 #import "WHLSLVertexBufferIndexCalculator.h"
 #import <Metal/Metal.h>
 #import <wtf/BlockObjCExceptions.h>
@@ -78,39 +79,214 @@ static RetainPtr<MTLDepthStencilState> tryCreateMtlDepthStencilState(const char*
     return state;
 }
 
-static bool trySetFunctionsForPipelineDescriptor(const char* const functionName, MTLRenderPipelineDescriptor *mtlDescriptor, const GPURenderPipelineDescriptor& descriptor)
+static WHLSL::VertexFormat convertVertexFormat(GPUVertexFormat vertexFormat)
+{
+    switch (vertexFormat) {
+    case GPUVertexFormat::Float4:
+        return WHLSL::VertexFormat::FloatR32G32B32A32;
+    case GPUVertexFormat::Float3:
+        return WHLSL::VertexFormat::FloatR32G32B32;
+    case GPUVertexFormat::Float2:
+        return WHLSL::VertexFormat::FloatR32G32;
+    default:
+        ASSERT(vertexFormat == GPUVertexFormat::Float);
+        return WHLSL::VertexFormat::FloatR32;
+    }
+}
+
+static OptionSet<WHLSL::ShaderStage> convertShaderStageFlags(GPUShaderStageFlags flags)
+{
+    OptionSet<WHLSL::ShaderStage> result;
+    if (flags & GPUShaderStageBit::Flags::Vertex)
+        result.add(WHLSL::ShaderStage::Vertex);
+    if (flags & GPUShaderStageBit::Flags::Fragment)
+        result.add(WHLSL::ShaderStage::Fragment);
+    if (flags & GPUShaderStageBit::Flags::Compute)
+        result.add(WHLSL::ShaderStage::Compute);
+    return result;
+}
+
+static Optional<WHLSL::BindingType> convertBindingType(GPUBindingType type)
+{
+    switch (type) {
+    case GPUBindingType::UniformBuffer:
+        return WHLSL::BindingType::UniformBuffer;
+    case GPUBindingType::Sampler:
+        return WHLSL::BindingType::Sampler;
+    case GPUBindingType::SampledTexture:
+        return WHLSL::BindingType::Texture;
+    case GPUBindingType::StorageBuffer:
+        return WHLSL::BindingType::StorageBuffer;
+    default:
+        return WTF::nullopt;
+    }
+}
+
+static Optional<WHLSL::TextureFormat> convertTextureFormat(GPUTextureFormat format)
+{
+    switch (format) {
+    case GPUTextureFormat::Rgba8unorm:
+        return WHLSL::TextureFormat::RGBA8Unorm;
+    case GPUTextureFormat::Rgba8uint:
+        return WHLSL::TextureFormat::RGBA8Uint;
+    case GPUTextureFormat::Bgra8unorm:
+        return WHLSL::TextureFormat::BGRA8Unorm;
+    case GPUTextureFormat::Depth32floatStencil8:
+        return WTF::nullopt; // FIXME: Figure out what to do with this.
+    case GPUTextureFormat::Bgra8unormSRGB:
+        return WHLSL::TextureFormat::BGRA8UnormSrgb;
+    case GPUTextureFormat::Rgba16float:
+        return WHLSL::TextureFormat::RGBA16Float;
+    default:
+        return WTF::nullopt;
+    }
+}
+
+static Optional<WHLSL::Layout> convertLayout(const GPUPipelineLayout& layout)
+{
+    WHLSL::Layout result;
+    if (layout.bindGroupLayouts().size() > std::numeric_limits<unsigned>::max())
+        return WTF::nullopt;
+    for (size_t i = 0; i < layout.bindGroupLayouts().size(); ++i) {
+        const auto& bindGroupLayout = layout.bindGroupLayouts()[i];
+        WHLSL::BindGroup bindGroup;
+        bindGroup.name = static_cast<unsigned>(i);
+        for (const auto& keyValuePair : bindGroupLayout->bindingsMap()) {
+            const auto& gpuBindGroupLayoutBinding = keyValuePair.value;
+            WHLSL::Binding binding;
+            binding.visibility = convertShaderStageFlags(gpuBindGroupLayoutBinding.visibility);
+            if (auto bindingType = convertBindingType(gpuBindGroupLayoutBinding.type))
+                binding.bindingType = *bindingType;
+            else
+                return WTF::nullopt;
+            if (gpuBindGroupLayoutBinding.binding > std::numeric_limits<unsigned>::max())
+                return WTF::nullopt;
+            binding.name = static_cast<unsigned>(gpuBindGroupLayoutBinding.binding);
+            bindGroup.bindings.append(WTFMove(binding));
+        }
+        result.append(WTFMove(bindGroup));
+    }
+    return result;
+}
+
+static Optional<WHLSL::RenderPipelineDescriptor> convertRenderPipelineDescriptor(const GPURenderPipelineDescriptor& descriptor)
+{
+    WHLSL::RenderPipelineDescriptor whlslDescriptor;
+    if (descriptor.inputState.attributes.size() > std::numeric_limits<unsigned>::max())
+        return WTF::nullopt;
+    if (descriptor.colorStates.size() > std::numeric_limits<unsigned>::max())
+        return WTF::nullopt;
+
+    for (size_t i = 0; i < descriptor.inputState.attributes.size(); ++i)
+        whlslDescriptor.vertexAttributes.append({ convertVertexFormat(descriptor.inputState.attributes[i].format), static_cast<unsigned>(i) });
+
+    for (size_t i = 0; i < descriptor.colorStates.size(); ++i) {
+        if (auto format = convertTextureFormat(descriptor.colorStates[i].format))
+            whlslDescriptor.attachmentsStateDescriptor.attachmentDescriptors.append({*format, static_cast<unsigned>(i)});
+        else
+            return WTF::nullopt;
+    }
+
+    // FIXME: depthStencilAttachmentDescriptor isn't implemented yet.
+
+    if (descriptor.layout) {
+        if (auto layout = convertLayout(*descriptor.layout))
+            whlslDescriptor.layout = WTFMove(*layout);
+        else
+            return WTF::nullopt;
+    }
+    whlslDescriptor.vertexEntryPointName = descriptor.vertexStage.entryPoint;
+    whlslDescriptor.fragmentEntryPointName = descriptor.fragmentStage.entryPoint;
+    return whlslDescriptor;
+}
+
+static bool trySetMetalFunctionsForPipelineDescriptor(const char* const functionName, MTLLibrary *vertexMetalLibrary, MTLLibrary *fragmentMetalLibrary, MTLRenderPipelineDescriptor *mtlDescriptor, const String& vertexEntryPointName, const String& fragmentEntryPointName)
 {
 #if LOG_DISABLED
     UNUSED_PARAM(functionName);
 #endif
-    const auto& vertexStage = descriptor.vertexStage;
-    auto mtlLibrary = vertexStage.module->platformShaderModule();
-    if (!mtlLibrary) {
-        LOG(WebGPU, "%s: MTLLibrary for vertex stage does not exist!", functionName);
-        return false;
+
+    {
+        BEGIN_BLOCK_OBJC_EXCEPTIONS;
+
+        // Metal requires a vertex shader in all render pipelines.
+        if (!vertexMetalLibrary) {
+            LOG(WebGPU, "%s: MTLLibrary for vertex stage does not exist!", functionName);
+            return false;
+        }
+
+        auto function = adoptNS([vertexMetalLibrary newFunctionWithName:vertexEntryPointName]);
+        if (!function) {
+            LOG(WebGPU, "%s: Cannot create vertex MTLFunction \"%s\"!", functionName, vertexEntryPointName.utf8().data());
+            return false;
+        }
+
+        [mtlDescriptor setVertexFunction:function.get()];
+
+        END_BLOCK_OBJC_EXCEPTIONS;
     }
 
-    auto function = adoptNS([mtlLibrary newFunctionWithName:vertexStage.entryPoint]);
-    if (!function) {
-        LOG(WebGPU, "%s: Vertex MTLFunction \"%s\" not found!", functionName, vertexStage.entryPoint.utf8().data());
-        return false;
+    {
+        BEGIN_BLOCK_OBJC_EXCEPTIONS;
+
+        // However, fragment shaders are optional.
+        if (!fragmentMetalLibrary)
+            return true;
+
+        auto function = adoptNS([fragmentMetalLibrary newFunctionWithName:fragmentEntryPointName]);
+
+        if (!function) {
+            LOG(WebGPU, "%s: Cannot create fragment MTLFunction \"%s\"!", functionName, fragmentEntryPointName.utf8().data());
+            return false;
+        }
+
+        [mtlDescriptor setFragmentFunction:function.get()];
+        return true;
+
+        END_BLOCK_OBJC_EXCEPTIONS;
     }
 
-    [mtlDescriptor setVertexFunction:function.get()];
+    return false;
+}
 
-    const auto& fragmentStage = descriptor.fragmentStage;
-    if (!(mtlLibrary = fragmentStage.module->platformShaderModule())) {
-        LOG(WebGPU, "%s: MTLLibrary for fragment stage does not exist!", functionName);
+static bool trySetWHLSLFunctionsForPipelineDescriptor(const char* const functionName, MTLRenderPipelineDescriptor *mtlDescriptor, const GPURenderPipelineDescriptor& descriptor, String whlslSource, const GPUDevice& device)
+{
+    auto whlslDescriptor = convertRenderPipelineDescriptor(descriptor);
+    if (!whlslDescriptor)
         return false;
-    }
 
-    if (!(function = adoptNS([mtlLibrary newFunctionWithName:fragmentStage.entryPoint]))) {
-        LOG(WebGPU, "%s: Fragment MTLFunction \"%s\" not found!", functionName, fragmentStage.entryPoint.utf8().data());
+    auto result = WHLSL::prepare(whlslSource, *whlslDescriptor);
+    if (!result)
         return false;
+
+    WTFLogAlways("Metal code: %s", result->metalSource.utf8().data());
+
+    NSError *error = nil;
+    auto library = adoptNS([device.platformDevice() newLibraryWithSource:result->metalSource options:nil error:&error]);
+    ASSERT(library);
+    // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195771 Once we zero-fill variables, there should be no warnings, so we should be able to ASSERT(!error) here.
+
+    return trySetMetalFunctionsForPipelineDescriptor(functionName, library.get(), library.get(), mtlDescriptor, result->mangledVertexEntryPointName, result->mangledFragmentEntryPointName);
+}
+
+static bool trySetFunctionsForPipelineDescriptor(const char* const functionName, MTLRenderPipelineDescriptor *mtlDescriptor, const GPURenderPipelineDescriptor& descriptor, const GPUDevice& device)
+{
+    const auto& vertexStage = descriptor.vertexStage;
+    const auto& fragmentStage = descriptor.fragmentStage;
+
+    if (vertexStage.module.ptr() == fragmentStage.module.ptr()) {
+        // FIXME: https://bugs.webkit.org/show_bug.cgi?id=195446 Allow WHLSL shaders to come from different programs.
+        const auto& whlslSource = vertexStage.module->whlslSource();
+        if (!whlslSource.isNull())
+            return trySetWHLSLFunctionsForPipelineDescriptor(functionName, mtlDescriptor, descriptor, whlslSource, device);
     }
 
-    [mtlDescriptor setFragmentFunction:function.get()];
-    return true;
+    auto vertexLibrary = vertexStage.module->platformShaderModule();
+    MTLLibrary *fragmentLibrary = nil;
+    if (!fragmentStage.entryPoint.isNull())
+        fragmentLibrary = fragmentStage.module->platformShaderModule();
+
+    return trySetMetalFunctionsForPipelineDescriptor(functionName, vertexLibrary, fragmentLibrary, mtlDescriptor, vertexStage.entryPoint, fragmentStage.entryPoint);
 }
 
 static MTLVertexFormat mtlVertexFormatForGPUVertexFormat(GPUVertexFormat format)
@@ -220,7 +396,7 @@ static RetainPtr<MTLRenderPipelineState> tryCreateMtlRenderPipelineState(const c
 
     BEGIN_BLOCK_OBJC_EXCEPTIONS;
 
-    didSetFunctions = trySetFunctionsForPipelineDescriptor(functionName, mtlDescriptor.get(), descriptor);
+    didSetFunctions = trySetFunctionsForPipelineDescriptor(functionName, mtlDescriptor.get(), descriptor, device);
     didSetInputState = trySetInputStateForPipelineDescriptor(functionName, mtlDescriptor.get(), descriptor.inputState);
     didSetColorStates = trySetColorStatesForColorAttachmentArray(mtlDescriptor.get().colorAttachments, descriptor.colorStates);
 
index 4553a15..1aaf6bf 100644 (file)
@@ -43,6 +43,9 @@ RefPtr<GPUShaderModule> GPUShaderModule::create(const GPUDevice& device, GPUShad
         LOG(WebGPU, "GPUShaderModule::create(): Invalid GPUDevice!");
         return nullptr;
     }
+    
+    if (descriptor.isWHLSL)
+        return adoptRef(new GPUShaderModule(String(descriptor.code)));
 
     PlatformShaderModuleSmartPtr module;
 
@@ -63,6 +66,11 @@ GPUShaderModule::GPUShaderModule(PlatformShaderModuleSmartPtr&& module)
 {
 }
 
+GPUShaderModule::GPUShaderModule(String&& whlslSource)
+    : m_whlslSource(WTFMove(whlslSource))
+{
+}
+
 }
 
 #endif // ENABLE(WEBGPU)