[WHLSL] Standard library is too big to directly include in WebCore
[WebKit-https.git] / Source / WebCore / Modules / webgpu / WHLSL / WHLSLNameResolver.cpp
index 48fa930..2c3a2d7 100644 (file)
 #include "WHLSLNameContext.h"
 #include "WHLSLProgram.h"
 #include "WHLSLPropertyAccessExpression.h"
+#include "WHLSLReplaceWith.h"
 #include "WHLSLResolveOverloadImpl.h"
 #include "WHLSLReturn.h"
+#include "WHLSLScopedSetAdder.h"
 #include "WHLSLTypeReference.h"
 #include "WHLSLVariableDeclaration.h"
 #include "WHLSLVariableReference.h"
@@ -55,10 +57,31 @@ NameResolver::NameResolver(NameContext& nameContext)
 {
 }
 
+NameResolver::NameResolver(NameResolver& parentResolver, NameContext& nameContext)
+    : m_nameContext(nameContext)
+    , m_parentNameResolver(&parentResolver)
+{
+    setCurrentFunctionDefinition(parentResolver.m_currentFunction);
+}
+
+NameResolver::~NameResolver()
+{
+    if (error() && m_parentNameResolver)
+        m_parentNameResolver->setError();
+}
+
 void NameResolver::visit(AST::TypeReference& typeReference)
 {
-    checkErrorAndVisit(typeReference);
-    if (typeReference.resolvedType())
+    ScopedSetAdder<AST::TypeReference*> adder(m_typeReferences, &typeReference);
+    if (!adder.isNewEntry()) {
+        setError();
+        return;
+    }
+
+    Visitor::visit(typeReference);
+    if (error())
+        return;
+    if (typeReference.maybeResolvedType()) // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198161 Shouldn't we know by now whether the type has been resolved or not?
         return;
 
     auto* candidates = m_nameContext.getTypes(typeReference.name());
@@ -66,6 +89,8 @@ void NameResolver::visit(AST::TypeReference& typeReference)
         setError();
         return;
     }
+    for (auto& candidate : *candidates)
+        Visitor::visit(candidate);
     if (auto result = resolveTypeOverloadImpl(*candidates, typeReference.typeArguments()))
         typeReference.setResolvedType(*result);
     else {
@@ -77,60 +102,79 @@ void NameResolver::visit(AST::TypeReference& typeReference)
 void NameResolver::visit(AST::FunctionDefinition& functionDefinition)
 {
     NameContext newNameContext(&m_nameContext);
-    NameResolver newNameResolver(newNameContext);
+    NameResolver newNameResolver(*this, newNameContext);
     checkErrorAndVisit(functionDefinition.type());
-    for (auto& parameter : functionDefinition.parameters()) {
+    if (error())
+        return;
+    for (auto& parameter : functionDefinition.parameters())
         newNameResolver.checkErrorAndVisit(parameter);
-        auto success = newNameContext.add(parameter);
-        if (!success) {
-            setError();
-            return;
-        }
-    }
     newNameResolver.checkErrorAndVisit(functionDefinition.block());
 }
 
 void NameResolver::visit(AST::Block& block)
 {
     NameContext nameContext(&m_nameContext);
-    NameResolver(nameContext).checkErrorAndVisit(block);
+    NameResolver newNameResolver(*this, nameContext);
+    newNameResolver.Visitor::visit(block);
 }
 
 void NameResolver::visit(AST::IfStatement& ifStatement)
 {
     checkErrorAndVisit(ifStatement.conditional());
-    NameContext nameContext(&m_nameContext);
-    NameResolver(nameContext).checkErrorAndVisit(ifStatement.body());
+    if (error())
+        return;
+
+    {
+        NameContext nameContext(&m_nameContext);
+        NameResolver newNameResolver(*this, nameContext);
+        newNameResolver.checkErrorAndVisit(ifStatement.body());
+    }
+    if (error())
+        return;
+
     if (ifStatement.elseBody()) {
         NameContext nameContext(&m_nameContext);
-        NameResolver(nameContext).checkErrorAndVisit(static_cast<AST::Statement&>(*ifStatement.elseBody()));
+        NameResolver newNameResolver(*this, nameContext);
+        newNameResolver.checkErrorAndVisit(*ifStatement.elseBody());
     }
 }
 
 void NameResolver::visit(AST::WhileLoop& whileLoop)
 {
     checkErrorAndVisit(whileLoop.conditional());
+    if (error())
+        return;
+
     NameContext nameContext(&m_nameContext);
-    NameResolver(nameContext).checkErrorAndVisit(whileLoop.body());
+    NameResolver newNameResolver(*this, nameContext);
+    newNameResolver.checkErrorAndVisit(whileLoop.body());
 }
 
 void NameResolver::visit(AST::DoWhileLoop& whileLoop)
 {
-    NameContext nameContext(&m_nameContext);
-    NameResolver(nameContext).checkErrorAndVisit(whileLoop.body());
+    {
+        NameContext nameContext(&m_nameContext);
+        NameResolver newNameResolver(*this, nameContext);
+        newNameResolver.checkErrorAndVisit(whileLoop.body());
+    }
+
     checkErrorAndVisit(whileLoop.conditional());
 }
 
 void NameResolver::visit(AST::ForLoop& forLoop)
 {
     NameContext nameContext(&m_nameContext);
-    NameResolver(nameContext).checkErrorAndVisit(forLoop);
+    NameResolver newNameResolver(*this, nameContext);
+    newNameResolver.Visitor::visit(forLoop);
 }
 
 void NameResolver::visit(AST::VariableDeclaration& variableDeclaration)
 {
-    m_nameContext.add(variableDeclaration);
-    checkErrorAndVisit(variableDeclaration);
+    if (!m_nameContext.add(variableDeclaration)) {
+        setError();
+        return;
+    }
+    Visitor::visit(variableDeclaration);
 }
 
 void NameResolver::visit(AST::VariableReference& variableReference)
@@ -150,34 +194,28 @@ void NameResolver::visit(AST::Return& returnStatement)
 {
     ASSERT(m_currentFunction);
     returnStatement.setFunction(m_currentFunction);
-    checkErrorAndVisit(returnStatement);
+    Visitor::visit(returnStatement);
 }
 
 void NameResolver::visit(AST::PropertyAccessExpression& propertyAccessExpression)
 {
-    if (auto* getFunctions = m_nameContext.getFunctions(propertyAccessExpression.getFunctionName()))
-        propertyAccessExpression.setPossibleGetOverloads(*getFunctions);
-    if (auto* setFunctions = m_nameContext.getFunctions(propertyAccessExpression.setFunctionName()))
-        propertyAccessExpression.setPossibleSetOverloads(*setFunctions);
-    if (auto* andFunctions = m_nameContext.getFunctions(propertyAccessExpression.andFunctionName()))
-        propertyAccessExpression.setPossibleAndOverloads(*andFunctions);
-    checkErrorAndVisit(propertyAccessExpression);
+    Visitor::visit(propertyAccessExpression);
 }
 
 void NameResolver::visit(AST::DotExpression& dotExpression)
 {
     if (is<AST::VariableReference>(dotExpression.base())) {
-        if (auto enumerationTypes = m_nameContext.getTypes(downcast<AST::VariableReference>(dotExpression.base()).name())) {
+        auto baseName = downcast<AST::VariableReference>(dotExpression.base()).name();
+        if (auto enumerationTypes = m_nameContext.getTypes(baseName)) {
             ASSERT(enumerationTypes->size() == 1);
             AST::NamedType& type = (*enumerationTypes)[0];
             if (is<AST::EnumerationDefinition>(type)) {
                 AST::EnumerationDefinition& enumerationDefinition = downcast<AST::EnumerationDefinition>(type);
-                if (auto* member = enumerationDefinition.memberByName(dotExpression.fieldName())) {
-                    static_assert(sizeof(AST::EnumerationMemberLiteral) <= sizeof(AST::DotExpression), "Dot expressions need to be able to become EnumerationMemberLiterals without updating backreferences");
+                auto memberName = dotExpression.fieldName();
+                if (auto* member = enumerationDefinition.memberByName(memberName)) {
                     Lexer::Token origin = dotExpression.origin();
-                    // FIXME: Perhaps do this with variants or a Rewriter instead.
-                    dotExpression.~DotExpression();
-                    new (&dotExpression) AST::EnumerationMemberLiteral(WTFMove(origin), *member);
+                    auto enumerationMemberLiteral = AST::EnumerationMemberLiteral::wrap(WTFMove(origin), WTFMove(baseName), WTFMove(memberName), enumerationDefinition, *member);
+                    AST::replaceWith<AST::EnumerationMemberLiteral>(dotExpression, WTFMove(enumerationMemberLiteral));
                     return;
                 }
                 setError();
@@ -186,41 +224,26 @@ void NameResolver::visit(AST::DotExpression& dotExpression)
         }
     }
 
-    checkErrorAndVisit(dotExpression);
+    Visitor::visit(dotExpression);
 }
 
 void NameResolver::visit(AST::CallExpression& callExpression)
 {
-    if (!callExpression.hasOverloads()) {
-        if (auto* functions = m_nameContext.getFunctions(callExpression.name()))
-            callExpression.setOverloads(*functions);
-        else {
-            if (auto* types = m_nameContext.getTypes(callExpression.name())) {
-                if (types->size() == 1) {
-                    if (auto* functions = m_nameContext.getFunctions("operator cast"_str)) {
-                        callExpression.setCastData((*types)[0].get());
-                        callExpression.setOverloads(*functions);
-                    }
-                }
-            }
-        }
-    }
-    if (!callExpression.hasOverloads()) {
-        setError();
-        return;
-    }
-    checkErrorAndVisit(callExpression);
+    Visitor::visit(callExpression);
 }
 
-void NameResolver::visit(AST::ConstantExpressionEnumerationMemberReference& constantExpressionEnumerationMemberReference)
+void NameResolver::visit(AST::EnumerationMemberLiteral& enumerationMemberLiteral)
 {
-    if (auto enumerationTypes = m_nameContext.getTypes(constantExpressionEnumerationMemberReference.left())) {
+    if (enumerationMemberLiteral.enumerationMember())
+        return;
+
+    if (auto enumerationTypes = m_nameContext.getTypes(enumerationMemberLiteral.left())) {
         ASSERT(enumerationTypes->size() == 1);
         AST::NamedType& type = (*enumerationTypes)[0];
         if (is<AST::EnumerationDefinition>(type)) {
             AST::EnumerationDefinition& enumerationDefinition = downcast<AST::EnumerationDefinition>(type);
-            if (auto* member = enumerationDefinition.memberByName(constantExpressionEnumerationMemberReference.right())) {
-                constantExpressionEnumerationMemberReference.setEnumerationMember(enumerationDefinition, *member);
+            if (auto* member = enumerationDefinition.memberByName(enumerationMemberLiteral.right())) {
+                enumerationMemberLiteral.setEnumerationMember(enumerationDefinition, *member);
                 return;
             }
         }
@@ -229,52 +252,59 @@ void NameResolver::visit(AST::ConstantExpressionEnumerationMemberReference& cons
     setError();
 }
 
-// FIXME: Make sure all the names have been resolved.
+void NameResolver::visit(AST::NativeFunctionDeclaration& nativeFunctionDeclaration)
+{
+    NameContext newNameContext(&m_nameContext);
+    NameResolver newNameResolver(newNameContext);
+    newNameResolver.Visitor::visit(nativeFunctionDeclaration);
+}
+
+// FIXME: https://bugs.webkit.org/show_bug.cgi?id=198167 Make sure all the names have been resolved.
 
 bool resolveNamesInTypes(Program& program, NameResolver& nameResolver)
 {
     for (auto& typeDefinition : program.typeDefinitions()) {
-        nameResolver.checkErrorAndVisit(static_cast<AST::TypeDefinition&>(typeDefinition));
+        nameResolver.checkErrorAndVisit(typeDefinition);
         if (nameResolver.error())
             return false;
     }
     for (auto& structureDefinition : program.structureDefinitions()) {
-        nameResolver.checkErrorAndVisit(static_cast<AST::StructureDefinition&>(structureDefinition));
+        nameResolver.checkErrorAndVisit(structureDefinition);
         if (nameResolver.error())
             return false;
     }
     for (auto& enumerationDefinition : program.enumerationDefinitions()) {
-        nameResolver.checkErrorAndVisit(static_cast<AST::EnumerationDefinition&>(enumerationDefinition));
+        nameResolver.checkErrorAndVisit(enumerationDefinition);
         if (nameResolver.error())
             return false;
     }
     for (auto& nativeTypeDeclaration : program.nativeTypeDeclarations()) {
-        nameResolver.checkErrorAndVisit(static_cast<AST::NativeTypeDeclaration&>(nativeTypeDeclaration));
+        nameResolver.checkErrorAndVisit(nativeTypeDeclaration);
         if (nameResolver.error())
             return false;
     }
     return true;
 }
 
-bool resolveNamesInFunctions(Program& program, NameResolver& nameResolver)
+bool resolveTypeNamesInFunctions(Program& program, NameResolver& nameResolver)
 {
     for (auto& functionDefinition : program.functionDefinitions()) {
-        nameResolver.setCurrentFunctionDefinition(&static_cast<AST::FunctionDefinition&>(functionDefinition));
-        nameResolver.checkErrorAndVisit(static_cast<AST::FunctionDefinition&>(functionDefinition));
+        nameResolver.setCurrentFunctionDefinition(&functionDefinition);
+        nameResolver.checkErrorAndVisit(functionDefinition);
         if (nameResolver.error())
             return false;
     }
     nameResolver.setCurrentFunctionDefinition(nullptr);
     for (auto& nativeFunctionDeclaration : program.nativeFunctionDeclarations()) {
-        nameResolver.checkErrorAndVisit(static_cast<AST::FunctionDeclaration&>(nativeFunctionDeclaration));
+        nameResolver.checkErrorAndVisit(nativeFunctionDeclaration);
         if (nameResolver.error())
             return false;
     }
     return true;
 }
 
-}
+} // namespace WHLSL
 
-}
+} // namespace WebCore
 
-#endif
+#endif // ENABLE(WEBGPU)