[WHSL] Need grammar to specify kernel group size
authormmaxfield@apple.com <mmaxfield@apple.com@268f45cc-cd09-0410-ab3c-d52691b4dbfc>
Sat, 22 Sep 2018 23:49:28 +0000 (23:49 +0000)
committermmaxfield@apple.com <mmaxfield@apple.com@268f45cc-cd09-0410-ab3c-d52691b4dbfc>
Sat, 22 Sep 2018 23:49:28 +0000 (23:49 +0000)
https://bugs.webkit.org/show_bug.cgi?id=189108

Reviewed by Dean Jackson.

In HLSL, compute functions are annotated with their workgroup size.
For example,

[numthreads(3, 4, 5)] compute void foo(...) { ... }

* WebGPUShadingLanguageRI/All.js:
* WebGPUShadingLanguageRI/Func.js:
(Func):
(Func.prototype.get attributeBlock):
* WebGPUShadingLanguageRI/FuncAttribute.js: Copied from Tools/WebGPUShadingLanguageRI/FuncDef.js.
(FuncAttribute):
* WebGPUShadingLanguageRI/FuncDef.js:
(FuncDef):
* WebGPUShadingLanguageRI/FuncNumThreadsAttribute.js: Copied from Tools/WebGPUShadingLanguageRI/FuncDef.js.
(FuncNumThreadsAttribute):
(FuncNumThreadsAttribute.prototype.get x):
(FuncNumThreadsAttribute.prototype.get y):
(FuncNumThreadsAttribute.prototype.get z):
* WebGPUShadingLanguageRI/LateChecker.js:
(LateChecker.prototype._checkShaderType):
* WebGPUShadingLanguageRI/Parse.js:
(parseAttributeBlock):
(parseFuncDecl):
(parseFuncDef):
(parseNativeFunc):
* WebGPUShadingLanguageRI/SPIRV.html:
* WebGPUShadingLanguageRI/StatementCloner.js:
(StatementCloner.prototype.visitFuncDef):
(StatementCloner.prototype.visitFuncNumThreadsAttribute):
(StatementCloner):
* WebGPUShadingLanguageRI/Test.html:
* WebGPUShadingLanguageRI/Test.js:
(tests.numThreads):
* WebGPUShadingLanguageRI/Visitor.js:
(Visitor.prototype.visitFunc):
(Visitor.prototype.visitFuncNumThreadsAttribute):
(Visitor):
* WebGPUShadingLanguageRI/index.html:

git-svn-id: https://svn.webkit.org/repository/webkit/trunk@236390 268f45cc-cd09-0410-ab3c-d52691b4dbfc

14 files changed:
Tools/ChangeLog
Tools/WebGPUShadingLanguageRI/All.js
Tools/WebGPUShadingLanguageRI/Func.js
Tools/WebGPUShadingLanguageRI/FuncAttribute.js [new file with mode: 0644]
Tools/WebGPUShadingLanguageRI/FuncDef.js
Tools/WebGPUShadingLanguageRI/FuncNumThreadsAttribute.js [new file with mode: 0644]
Tools/WebGPUShadingLanguageRI/LateChecker.js
Tools/WebGPUShadingLanguageRI/Parse.js
Tools/WebGPUShadingLanguageRI/SPIRV.html
Tools/WebGPUShadingLanguageRI/StatementCloner.js
Tools/WebGPUShadingLanguageRI/Test.html
Tools/WebGPUShadingLanguageRI/Test.js
Tools/WebGPUShadingLanguageRI/Visitor.js
Tools/WebGPUShadingLanguageRI/index.html

index a6025d8..166c0c0 100644 (file)
@@ -1,5 +1,51 @@
 2018-09-22  Myles C. Maxfield  <mmaxfield@apple.com>
 
+        [WHSL] Need grammar to specify kernel group size
+        https://bugs.webkit.org/show_bug.cgi?id=189108
+
+        Reviewed by Dean Jackson.
+
+        In HLSL, compute functions are annotated with their workgroup size.
+        For example,
+
+        [numthreads(3, 4, 5)] compute void foo(...) { ... }
+
+        * WebGPUShadingLanguageRI/All.js:
+        * WebGPUShadingLanguageRI/Func.js:
+        (Func):
+        (Func.prototype.get attributeBlock):
+        * WebGPUShadingLanguageRI/FuncAttribute.js: Copied from Tools/WebGPUShadingLanguageRI/FuncDef.js.
+        (FuncAttribute):
+        * WebGPUShadingLanguageRI/FuncDef.js:
+        (FuncDef):
+        * WebGPUShadingLanguageRI/FuncNumThreadsAttribute.js: Copied from Tools/WebGPUShadingLanguageRI/FuncDef.js.
+        (FuncNumThreadsAttribute):
+        (FuncNumThreadsAttribute.prototype.get x):
+        (FuncNumThreadsAttribute.prototype.get y):
+        (FuncNumThreadsAttribute.prototype.get z):
+        * WebGPUShadingLanguageRI/LateChecker.js:
+        (LateChecker.prototype._checkShaderType):
+        * WebGPUShadingLanguageRI/Parse.js:
+        (parseAttributeBlock):
+        (parseFuncDecl):
+        (parseFuncDef):
+        (parseNativeFunc):
+        * WebGPUShadingLanguageRI/SPIRV.html:
+        * WebGPUShadingLanguageRI/StatementCloner.js:
+        (StatementCloner.prototype.visitFuncDef):
+        (StatementCloner.prototype.visitFuncNumThreadsAttribute):
+        (StatementCloner):
+        * WebGPUShadingLanguageRI/Test.html:
+        * WebGPUShadingLanguageRI/Test.js:
+        (tests.numThreads):
+        * WebGPUShadingLanguageRI/Visitor.js:
+        (Visitor.prototype.visitFunc):
+        (Visitor.prototype.visitFuncNumThreadsAttribute):
+        (Visitor):
+        * WebGPUShadingLanguageRI/index.html:
+
+2018-09-22  Myles C. Maxfield  <mmaxfield@apple.com>
+
         Native functions which accept pointers need to do null checks
         https://bugs.webkit.org/show_bug.cgi?id=189883
 
index e2910d7..4dbeeb8 100644 (file)
@@ -90,7 +90,9 @@ load("FloatLiteralType.js");
 load("FoldConstexprs.js");
 load("ForLoop.js");
 load("Func.js");
+load("FuncAttribute.js");
 load("FuncDef.js");
+load("FuncNumThreadsAttribute.js");
 load("FuncParameter.js");
 load("FunctionLikeBlock.js");
 load("HighZombieFinder.js");
index 4f58906..56b1d2c 100644 (file)
@@ -25,7 +25,7 @@
 "use strict";
 
 class Func extends Node {
-    constructor(origin, name, returnType, parameters, isCast, shaderType)
+    constructor(origin, name, returnType, parameters, isCast, shaderType, attributeBlock = null)
     {
         if (!(origin instanceof LexerToken))
             throw new Error("Bad origin: " + origin);
@@ -42,6 +42,7 @@ class Func extends Node {
         this._parameters = parameters;
         this._isCast = isCast;
         this._shaderType = shaderType;
+        this._attributeBlock = attributeBlock;
     }
     
     get origin() { return this._origin; }
@@ -51,6 +52,7 @@ class Func extends Node {
     get parameterTypes() { return this.parameters.map(parameter => parameter.type); }
     get isCast() { return this._isCast; }
     get shaderType() { return this._shaderType; }
+    get attributeBlock() { return this._attributeBlock; }
     get isEntryPoint() { return this.shaderType != null; }
     get returnTypeForOverloadResolution() { return this.isCast ? this.returnType : null; }
     
diff --git a/Tools/WebGPUShadingLanguageRI/FuncAttribute.js b/Tools/WebGPUShadingLanguageRI/FuncAttribute.js
new file mode 100644 (file)
index 0000000..e2e267e
--- /dev/null
@@ -0,0 +1,29 @@
+/*
+ * Copyright (C) 2018 Apple Inc. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ * 1. Redistributions of source code must retain the above copyright
+ *    notice, this list of conditions and the following disclaimer.
+ * 2. Redistributions in binary form must reproduce the above copyright
+ *    notice, this list of conditions and the following disclaimer in the
+ *    documentation and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY APPLE INC. ``AS IS'' AND ANY
+ * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+ * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL APPLE INC. OR
+ * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+ * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+"use strict";
+
+class FuncAttribute extends Node {
+}
+
index 3016b13..c2d11f3 100644 (file)
@@ -25,9 +25,9 @@
 "use strict";
 
 class FuncDef extends Func {
-    constructor(origin, name, returnType, parameters, body, isCast, shaderType)
+    constructor(origin, name, returnType, parameters, body, isCast, shaderType, attributeBlock = null)
     {
-        super(origin, name, returnType, parameters, isCast, shaderType);
+        super(origin, name, returnType, parameters, isCast, shaderType, attributeBlock);
         this._body = body;
         this.isRestricted = false;
     }
diff --git a/Tools/WebGPUShadingLanguageRI/FuncNumThreadsAttribute.js b/Tools/WebGPUShadingLanguageRI/FuncNumThreadsAttribute.js
new file mode 100644 (file)
index 0000000..6abe344
--- /dev/null
@@ -0,0 +1,40 @@
+/*
+ * Copyright (C) 2018 Apple Inc. All rights reserved.
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions
+ * are met:
+ * 1. Redistributions of source code must retain the above copyright
+ *    notice, this list of conditions and the following disclaimer.
+ * 2. Redistributions in binary form must reproduce the above copyright
+ *    notice, this list of conditions and the following disclaimer in the
+ *    documentation and/or other materials provided with the distribution.
+ *
+ * THIS SOFTWARE IS PROVIDED BY APPLE INC. ``AS IS'' AND ANY
+ * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+ * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL APPLE INC. OR
+ * CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
+ * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
+ * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
+ * PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
+ * OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ */
+"use strict";
+
+class FuncNumThreadsAttribute extends FuncAttribute {
+    constructor(x, y, z)
+    {
+        super();
+        this._x = x;
+        this._y = y;
+        this._z = z;
+    }
+
+    get x() { return this._x; }
+    get y() { return this._y; }
+    get z() { return this._z; }
+}
+
index eb5ee69..eb1ce91 100644 (file)
@@ -70,6 +70,8 @@ class LateChecker extends Visitor {
                 }
             }
             break;
+        case "compute":
+            break;
         case "test":
             break;
         default:
index 7d1381f..3df6d21 100644 (file)
@@ -1155,6 +1155,50 @@ function parse(program, origin, originKind, lineNumberOffset, text)
         return maybeError.text;
     }
 
+    function parseAttributeBlock()
+    {
+        let maybeError = consume("[");
+        if (maybeError instanceof WSyntaxError)
+            return maybeError;
+        maybeError = consume("numthreads");
+        if (maybeError instanceof WSyntaxError)
+            return maybeError;
+        maybeError = consume("(");
+        if (maybeError instanceof WSyntaxError)
+            return maybeError;
+        let xDimension = consumeKind("intLiteral");
+        if (xDimension instanceof WSyntaxError)
+            return xDimension;
+        let xDimensionUintVersion = xDimension.text >>> 0;
+        if (xDimensionUintVersion.toString() !== xDimension.text)
+            return fail("Numthreads X attribute is not 32-bit unsigned integer: " + xDimension.text);
+        maybeError = consume(",");
+        if (maybeError instanceof WSyntaxError)
+            return maybeError;
+        let yDimension = consumeKind("intLiteral");
+        if (yDimension instanceof WSyntaxError)
+            return yDimension;
+        let yDimensionUintVersion = yDimension.text >>> 0;
+        if (yDimensionUintVersion.toString() !== yDimension.text)
+            return fail("Numthreads Y attribute is not 32-bit unsigned integer: " + yDimension.text);
+        maybeError = consume(",");
+        if (maybeError instanceof WSyntaxError)
+            return maybeError;
+        let zDimension = consumeKind("intLiteral");
+        if (zDimension instanceof WSyntaxError)
+            return zDimension;
+        let zDimensionUintVersion = zDimension.text >>> 0;
+        if (zDimensionUintVersion.toString() !== zDimension.text)
+            return fail("Numthreads Z attribute is not 32-bit unsigned integer: " + zDimension.text);
+        maybeError = consume(")");
+        if (maybeError instanceof WSyntaxError)
+            return maybeError;
+        maybeError = consume("]");
+        if (maybeError instanceof WSyntaxError)
+            return maybeError;
+        return [new FuncNumThreadsAttribute(xDimensionUintVersion, yDimensionUintVersion, zDimensionUintVersion)];
+    }
+
     function parseFuncDecl()
     {
         let origin;
@@ -1162,6 +1206,7 @@ function parse(program, origin, originKind, lineNumberOffset, text)
         let name;
         let isCast;
         let shaderType;
+        let attributeBlock = null;
         let operatorToken = tryConsume("operator");
         if (operatorToken) {
             origin = operatorToken;
@@ -1171,7 +1216,15 @@ function parse(program, origin, originKind, lineNumberOffset, text)
             name = "operator cast";
             isCast = true;
         } else {
-            shaderType = tryConsume("vertex", "fragment", "test");
+            if (test("[")) {
+                attributeBlock = parseAttributeBlock();
+                if (attributeBlock instanceof WSyntaxError)
+                    return attributeBlock;
+                shaderType = consume("compute");
+                if (shaderType instanceof WSyntaxError)
+                    return shaderType;
+            } else
+                shaderType = tryConsume("vertex", "fragment", "test");
             returnType = parseType();
             if (returnType instanceof WSyntaxError)
                 return returnType;
@@ -1188,7 +1241,7 @@ function parse(program, origin, originKind, lineNumberOffset, text)
         let parameters = parseParameters();
         if (parameters instanceof WSyntaxError)
             return parameters;
-        return new Func(origin, name, returnType, parameters, isCast, shaderType);
+        return new Func(origin, name, returnType, parameters, isCast, shaderType, attributeBlock);
     }
 
     function parseFuncDef()
@@ -1199,7 +1252,7 @@ function parse(program, origin, originKind, lineNumberOffset, text)
         let body = parseBlock();
         if (body instanceof WSyntaxError)
             return body;
-        return new FuncDef(func.origin, func.name, func.returnType, func.parameters, body, func.isCast, func.shaderType);
+        return new FuncDef(func.origin, func.name, func.returnType, func.parameters, body, func.isCast, func.shaderType, func.attributeBlock);
     }
 
     function parseField()
@@ -1242,6 +1295,8 @@ function parse(program, origin, originKind, lineNumberOffset, text)
         let func = parseFuncDecl();
         if (func instanceof WSyntaxError)
             return func;
+        if (func.attributeBlock)
+            return fail("Native function must not have attribute block");
         let maybeError = consume(";");
         if (maybeError instanceof WSyntaxError)
             return maybeError;
index 8918671..0a09ba0 100644 (file)
@@ -73,7 +73,9 @@ td {
     <script src="FoldConstexprs.js"></script>
     <script src="ForLoop.js"></script>
     <script src="Func.js"></script>
+    <script src="FuncAttribute.js"></script>
     <script src="FuncDef.js"></script>
+    <script src="FuncNumThreadsAttribute.js"></script>
     <script src="FuncParameter.js"></script>
     <script src="FunctionLikeBlock.js"></script>
     <script src="HighZombieFinder.js"></script>
index 0d0715a..472e7ac 100644 (file)
 class StatementCloner extends Rewriter {
     visitFuncDef(node)
     {
+        let attributeBlock = null;
+        if (node.attributeBlock)
+            attributeBlock = node.attributeBlock.map(attribute => attribute.visit(this));
         let result = new FuncDef(
             node.origin, node.name,
             node.returnType.visit(this),
             node.parameters.map(parameter => parameter.visit(this)),
             node.body.visit(this),
-            node.isCast, node.shaderType);
+            node.isCast, node.shaderType, attributeBlock);
         result.isRestricted = node.isRestricted;
         return result;
     }
@@ -78,5 +81,10 @@ class StatementCloner extends Rewriter {
             result.add(member);
         return result;
     }
+
+    visitFuncNumThreadsAttribute(node)
+    {
+        return new FuncNumThreadsAttribute(node.x, node.y, node.z);
+    }
 }
 
index 98e3851..a300dac 100644 (file)
@@ -67,7 +67,9 @@
 <script src="FoldConstexprs.js"></script>
 <script src="ForLoop.js"></script>
 <script src="Func.js"></script>
+<script src="FuncAttribute.js"></script>
 <script src="FuncDef.js"></script>
+<script src="FuncNumThreadsAttribute.js"></script>
 <script src="FuncParameter.js"></script>
 <script src="FunctionLikeBlock.js"></script>
 <script src="HighZombieFinder.js"></script>
index d562008..e92354b 100644 (file)
@@ -6766,6 +6766,120 @@ tests.arrayIndex = function() {
     checkInt(program, callFunction(program, "arrayIndexing", [ makeUint(program, 1), makeUint(program, 2) ]), 6);
 }
 
+tests.numThreads = function() {
+    let program = doPrep(`
+        [numthreads(3, 4, 5)]
+        compute void foo() {
+        }
+
+        [numthreads(6, 7, 8)]
+        compute void bar() {
+        }
+
+        [numthreads(9, 10, 11)]
+        compute void bar(device float[] buffer) {
+        }
+
+        struct R {
+            float4 position;
+        }
+        vertex R baz() {
+            R r;
+            r.position = float4(1, 2, 3, 4);
+            return r;
+        }
+    `);
+
+    if (program.functions.get("foo").length != 1)
+        throw new Error("Cannot find function named 'foo'");
+    let foo = program.functions.get("foo")[0];
+    if (foo.attributeBlock.length != 1)
+        throw new Error("'foo' doesn't have numthreads attribute");
+    if (foo.attributeBlock[0].x != 3)
+        throw new Error("'foo' numthreads x is not 3");
+    if (foo.attributeBlock[0].y != 4)
+        throw new Error("'foo' numthreads y is not 4");
+    if (foo.attributeBlock[0].z != 5)
+        throw new Error("'foo' numthreads z is not 5");
+
+    if (program.functions.get("bar").length != 2)
+        throw new Error("Cannot find function named 'bar'");
+    let bar1 = null;
+    let bar2 = null;
+    for (let bar of program.functions.get("bar")) {
+        if (bar.parameters.length == 0)
+            bar1 = bar;
+        else if (bar.parameters.length == 1)
+            bar2 = bar;
+        else
+            throw new Error("Unexpected 'bar' function.");
+    }
+    if (!bar1)
+        throw new Error("Could not find appropriate 'bar' function");
+    if (!bar2)
+        throw new Error("Could not find appropriate 'bar' function");
+
+    if (bar1.attributeBlock.length != 1)
+        throw new Error("'bar1' doesn't have numthreads attribute");
+    if (bar1.attributeBlock[0].x != 6)
+        throw new Error("'bar1' numthreads x is not 6");
+    if (bar1.attributeBlock[0].y != 7)
+        throw new Error("'bar1' numthreads y is not 7");
+    if (bar1.attributeBlock[0].z != 8)
+        throw new Error("'bar1' numthreads z is not 8");
+
+    if (bar2.attributeBlock.length != 1)
+        throw new Error("'bar2' doesn't have numthreads attribute");
+    if (bar2.attributeBlock[0].x != 9)
+        throw new Error("'bar2' numthreads x is not 9");
+    if (bar2.attributeBlock[0].y != 10)
+        throw new Error("'bar2' numthreads y is not 10");
+    if (bar2.attributeBlock[0].z != 11)
+        throw new Error("'bar2' numthreads z is not 11");
+
+    if (program.functions.get("baz").length != 1)
+        throw new Error("Cannot find function named 'baz'");
+    let baz = program.functions.get("baz")[0];
+    if (baz.attributeBlock != null)
+        throw new Error("'baz' has attribute block");
+
+    checkFail(() => doPrep(`
+        [numthreads(3, 4)]
+        compute void foo() {
+        }
+    `), e => e instanceof WSyntaxError);
+
+    checkFail(() => doPrep(`
+        []
+        compute void foo() {
+        }
+    `), e => e instanceof WSyntaxError);
+
+    checkFail(() => doPrep(`
+        [numthreads(3, 4, 5), numthreads(3, 4, 5)]
+        compute void foo() {
+        }
+    `), e => e instanceof WSyntaxError);
+
+    checkFail(() => doPrep(`
+        compute void foo() {
+        }
+    `), e => e instanceof WSyntaxError);
+
+
+    checkFail(() => doPrep(`
+        struct R {
+            float4 position;
+        }
+        [numthreads(3, 4, 5)]
+        vertex R baz() {
+            R r;
+            r.position = float4(1, 2, 3, 4);
+            return r;
+        }
+    `), e => e instanceof WSyntaxError);
+}
+
 function createTexturesForTesting(program)
 {
     let texture1D = make1DTexture(program, [[1, 7, 14, 79], [13, 16], [15]], "float");
index b9ea7e4..b8f4dbc 100644 (file)
@@ -36,6 +36,10 @@ class Visitor {
         node.returnType.visit(this);
         for (let parameter of node.parameters)
             parameter.visit(this);
+        if (node.attributeBlock) {
+            for (let attribute of node.attributeBlock)
+                attribute.visit(this);
+        }
     }
     
     visitFuncParameter(node)
@@ -347,5 +351,9 @@ class Visitor {
         node.numRows.visit(this);
         node.numColumns.visit(this);
     }
+
+    visitFuncNumThreadsAttribute(node)
+    {
+    }
 }
 
index 08f4ed5..41693d7 100644 (file)
@@ -67,7 +67,9 @@
 <script src="FoldConstexprs.js"></script>
 <script src="ForLoop.js"></script>
 <script src="Func.js"></script>
+<script src="FuncAttribute.js"></script>
 <script src="FuncDef.js"></script>
+<script src="FuncNumThreadsAttribute.js"></script>
 <script src="FuncParameter.js"></script>
 <script src="FunctionLikeBlock.js"></script>
 <script src="HighZombieFinder.js"></script>