[WHLSL] Pointers should have automatically-generated equality checks
[WebKit-https.git] / Tools / WebGPUShadingLanguageRI / CallExpression.js
index b4c0da7..04be2da 100644 (file)
@@ -102,6 +102,11 @@ class CallExpression extends Expression {
             func = this._resolveWithOperatorAnderIndexer(program);
         else if (this.name == "operator.length")
             func = this._resolveWithOperatorLength(program);
+        else if (this.name == "operator==" && this.argumentTypes.length == 2
+            && (this.argumentTypes[0] instanceof NullType || this.argumentTypes[0] instanceof ReferenceType)
+            && (this.argumentTypes[1] instanceof NullType || this.argumentTypes[1] instanceof ReferenceType)
+            && this.argumentTypes[0].equals(this.argumentTypes[1]))
+                func = this._resolveWithReferenceComparator(program);
         else
             return null;
 
@@ -119,17 +124,17 @@ class CallExpression extends Expression {
         const addressSpace = arrayRefType.addressSpace;
 
         // The later checkLiteralTypes stage will verify that the literal can be represented as a uint.
-        const uintType = TypeRef.wrap(program.types.get("uint"));
+        const uintType = TypeRef.wrap(program.intrinsics.uint);
         indexType.type = uintType;
 
         const elementType = this.argumentTypes[0].elementType;
-        this.resultType = this._returnType = TypeRef.wrap(new PtrType(this.origin, addressSpace, TypeRef.wrap(elementType)))
+        this.resultType = TypeRef.wrap(new PtrType(this.origin, addressSpace, TypeRef.wrap(elementType)))
 
-        let arrayRefAccessor = new OperatorAnderIndexer(this.returnType.toString(), addressSpace);
+        let arrayRefAccessor = new OperatorAnderIndexer(this.resultType.toString(), addressSpace);
         const func = new NativeFunc(this.origin, "operator&[]", this.resultType, [
             new FuncParameter(this.origin, null, arrayRefType),
             new FuncParameter(this.origin, null, uintType)
-        ], false);
+        ]);
 
         arrayRefAccessor.instantiateImplementation(func);
 
@@ -138,14 +143,14 @@ class CallExpression extends Expression {
 
     _resolveWithOperatorLength(program)
     {
-        this.resultType = this._returnType = TypeRef.wrap(program.types.get("uint"));
+        this.resultType = TypeRef.wrap(program.intrinsics.uint);
 
         if (this.argumentTypes[0].isArray) {
             const arrayType = this.argumentTypes[0];
             const func = new NativeFunc(this.origin, "operator.length", this.resultType, [
                 new FuncParameter(this.origin, null, arrayType)
-            ], false);
-            func.implementation = (args, node) => EPtr.box(arrayType.numElementsValue);
+            ]);
+            func.implementation = (args) => EPtr.box(arrayType.numElementsValue);
             return func;
         } else if (this.argumentTypes[0].isArrayRef) {
             const arrayRefType = this.argumentTypes[0];
@@ -153,12 +158,38 @@ class CallExpression extends Expression {
             const operatorLength = new OperatorArrayRefLength(arrayRefType.toString(), addressSpace);
             const func = new NativeFunc(this.origin, "operator.length", this.resultType, [
                 new FuncParameter(this.origin, null, arrayRefType)
-            ], false);
+            ]);
             operatorLength.instantiateImplementation(func);
             return func;
         } else
             throw new WTypeError(this.origin.originString, `Expected ${this.argumentTypes[0]} to be array/array ref type for operator.length`);
     }
+
+    _resolveWithReferenceComparator(program)
+    {
+        let argumentType = this.argumentTypes[0];
+        if (argumentType instanceof NullType)
+            argumentType = this.argumentTypes[1];
+        if (argumentType instanceof NullType) {
+            // We encountered "null == null".
+            // The type isn't observable, so we can pick whatever we want.
+            // FIXME: This can probably be generalized, using the "preferred type" infrastructure used by generic literals
+            argumentType = new PtrType(this.origin, "thread", program.intrinsics.int);
+        }
+        this.resultType = TypeRef.wrap(program.intrinsics.bool);
+        const func = new NativeFunc(this.origin, "operator==", this.resultType, [
+            new FuncParameter(this.origin, null, argumentType),
+            new FuncParameter(this.origin, null, argumentType)
+        ]);
+        func.implementation = ([lhs, rhs]) => {
+            let left = lhs.loadValue();
+            let right = rhs.loadValue();
+            if (left && right)
+                return EPtr.box(left.equals(right));
+            return EPtr.box(left == right);
+        };
+        return func;
+    }
     
     resolveToOverload(overload)
     {