Title: [248339] trunk/Source/WebCore
Revision: 248339
Author: sbar...@apple.com
Date: 2019-08-06 21:17:29 -0700 (Tue, 06 Aug 2019)

Log Message

[WHLSL] Make resolveFunction in Checker faster
https://bugs.webkit.org/show_bug.cgi?id=200287

Reviewed by Robin Morisset.

This patch makes compute_boids faster by speeding up function overload
resolution inside the Checker. It's a ~6ms speedup in the
checker. The main idea is to limit the number of overloads we need
to consider by keying the hash table on a function's type (its name,
its parameter types, and, for casts, its return type) instead of on
the function's name alone.
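
To make the difference concrete, here is a minimal C++ sketch that contrasts a name-keyed overload table with a signature-keyed one. It is not WebKit code. It assumes a hypothetical model in which types are plain strings. Overload, SignatureKey, SignatureKeyHash, and the helper functions are illustrative names. The hash combination only loosely follows the one in FunctionKey::hash.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical model: a type is represented by its spelling ("float",
// "float3", ...) instead of an AST node.
struct Overload {
    std::string name;
    std::vector<std::string> parameterTypes;
};

// Key made of the name plus the parameter types, so overloads of the same
// name with different parameter lists land in different buckets.
struct SignatureKey {
    std::string name;
    std::vector<std::string> parameterTypes;

    bool operator==(const SignatureKey& other) const
    {
        return name == other.name && parameterTypes == other.parameterTypes;
    }
};

// Combines the arity, the name, and each parameter type (offset by its
// position), loosely following FunctionKey::hash in the patch.
struct SignatureKeyHash {
    std::size_t operator()(const SignatureKey& key) const
    {
        std::size_t hash = std::hash<std::size_t>()(key.parameterTypes.size());
        hash ^= std::hash<std::string>()(key.name);
        for (std::size_t i = 0; i < key.parameterTypes.size(); ++i)
            hash ^= std::hash<std::string>()(key.parameterTypes[i]) + i;
        return hash;
    }
};

using NameTable = std::unordered_map<std::string, std::vector<const Overload*>>;
using SignatureTable = std::unordered_map<SignatureKey, std::vector<const Overload*>, SignatureKeyHash>;

void addOverload(const Overload& overload, NameTable& byName, SignatureTable& bySignature)
{
    byName[overload.name].push_back(&overload);
    bySignature[SignatureKey { overload.name, overload.parameterTypes }].push_back(&overload);
}

// Number of candidates the exact (slow) overload resolution would examine.
std::size_t candidatesByName(const NameTable& table, const std::string& name)
{
    auto it = table.find(name);
    return it == table.end() ? 0 : it->second.size();
}

std::size_t candidatesBySignature(const SignatureTable& table, const SignatureKey& key)
{
    auto it = table.find(key);
    return it == table.end() ? 0 : it->second.size();
}
```

Looking up "mul" by name returns every "mul" overload. The signature key narrows the candidates to overloads whose parameter types match, so exact resolution only runs on that smaller set. A call with integer-literal arguments, such as mul(1, 2), would find nothing under this naive key. That is the problem the normalization described next addresses.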

The interesting implementation detail here is that we must construct entries
in the hash table such that constants can still be resolved to
various types. This means that the hash table key must normalize
the vector of types it uses to express a function's identity. The normalization
rules are:
- int => float
- uint => float
- T* => float*
- T[] => float*
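
The rules above can be sketched as a normalization step applied to each parameter type before a key is built. This uses a hypothetical string-based model: normalizeForKey and normalizeSignature are illustrative names, not WebKit functions. The patch itself does not inspect spellings. It compares a type's unify node against the intrinsic int and uint types and tests for AST::ReferenceType.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical string model of types: "int", "uint", "float", "T*" for
// pointers and "T[]" for array references. Only the scalar int and uint
// types are collapsed; vector types such as "int3" are left unchanged.
std::string normalizeForKey(const std::string& type)
{
    if (type == "int" || type == "uint")
        return "float";
    // Any pointer or array reference collapses to one representative pointer
    // type, so a null literal argument can still reach its candidates.
    bool isPointer = type.size() > 1 && type.back() == '*';
    bool isArrayReference = type.size() > 2 && type.compare(type.size() - 2, 2, "[]") == 0;
    if (isPointer || isArrayReference)
        return "float*";
    return type;
}

// Applies the rules to every parameter type of a signature.
std::vector<std::string> normalizeSignature(const std::vector<std::string>& types)
{
    std::vector<std::string> result;
    result.reserve(types.size());
    for (const auto& type : types)
        result.push_back(normalizeForKey(type));
    return result;
}
```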

The first two rules exist because integer constants can be matched against
the int, uint, and float types. The latter two rules exist because the null
literal can be matched against any pointer or any array reference
(we pick float* arbitrarily). Even though these normalization rules might
seem to drastically reduce the hash table's selectivity, we still see a 100x
reduction in the number of overloads we must resolve inside compute_boids:
we go from resolving 400,000 overloads to resolving just 4,000.
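
The next sketch shows why normalization keeps literal arguments resolvable. Like the others, it is a hypothetical string-based model rather than WebKit code, and OverloadTable, makeKey, and argumentMatches are illustrative names. The caller's argument types are normalized the same way as the declared parameter types. An integer literal therefore lands in the bucket that holds f(int), f(uint), and f(float), and an exact per-argument check then runs only on that bucket. The patch also chooses the best match among several survivors in resolveFunctionOverload; this sketch omits that step.

```cpp
#include <cassert>
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical model: types are strings; "T*" is a pointer, "T[]" an array
// reference. Two pseudo-types stand for unresolved literals: "intLiteral"
// (may become int, uint or float) and "null" (any pointer or array reference).
struct Overload {
    std::string name;
    std::vector<std::string> parameters;
};

static bool isPointerLike(const std::string& type)
{
    bool isPointer = type.size() > 1 && type.back() == '*';
    bool isArrayReference = type.size() > 2 && type.compare(type.size() - 2, 2, "[]") == 0;
    return isPointer || isArrayReference;
}

// Same normalization for declared parameters and for the caller's arguments.
static std::string normalize(const std::string& type)
{
    if (type == "int" || type == "uint" || type == "intLiteral")
        return "float";
    if (type == "null" || isPointerLike(type))
        return "float*";
    return type;
}

// The key encodes the name and the normalized types (and hence the arity).
static std::string makeKey(const std::string& name, const std::vector<std::string>& types)
{
    std::string key = name + "(";
    for (const auto& type : types)
        key += normalize(type) + ",";
    return key + ")";
}

// Exact check, performed only on overloads that share the bucket.
static bool argumentMatches(const std::string& argument, const std::string& parameter)
{
    if (argument == "intLiteral")
        return parameter == "int" || parameter == "uint" || parameter == "float";
    if (argument == "null")
        return isPointerLike(parameter);
    return argument == parameter;
}

class OverloadTable {
public:
    void add(const Overload& overload)
    {
        m_buckets[makeKey(overload.name, overload.parameters)].push_back(overload);
    }

    // Stage 1: a single hash lookup on the normalized key.
    std::size_t bucketSize(const std::string& name, const std::vector<std::string>& arguments) const
    {
        auto it = m_buckets.find(makeKey(name, arguments));
        return it == m_buckets.end() ? 0 : it->second.size();
    }

    // Stage 2: exact matching within that bucket only.
    std::vector<Overload> resolve(const std::string& name, const std::vector<std::string>& arguments) const
    {
        std::vector<Overload> matches;
        auto it = m_buckets.find(makeKey(name, arguments));
        if (it == m_buckets.end())
            return matches;
        for (const auto& overload : it->second) {
            bool ok = true;
            for (std::size_t i = 0; i < arguments.size(); ++i)
                ok = ok && argumentMatches(arguments[i], overload.parameters[i]);
            if (ok)
                matches.push_back(overload);
        }
        return matches;
    }

private:
    std::unordered_map<std::string, std::vector<Overload>> m_buckets;
};
```

The normalized key deliberately over-approximates: a concrete uint argument shares the bucket with f(int) and f(float), but the exact check keeps only f(uint).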

* Modules/webgpu/WHLSL/WHLSLChecker.cpp:
(WebCore::WHLSL::FunctionKey::FunctionKey):
(WebCore::WHLSL::FunctionKey::isEmptyValue const):
(WebCore::WHLSL::FunctionKey::isHashTableDeletedValue const):
(WebCore::WHLSL::FunctionKey::hash const):
(WebCore::WHLSL::FunctionKey::operator== const):
(WebCore::WHLSL::FunctionKey::Hash::hash):
(WebCore::WHLSL::FunctionKey::Hash::equal):
(WebCore::WHLSL::FunctionKey::Traits::isEmptyValue):
(WebCore::WHLSL::Checker::Checker):
(WebCore::WHLSL::Checker::wrappedFloatType):
(WebCore::WHLSL::Checker::genericPointerType):
(WebCore::WHLSL::Checker::normalizedTypeForFunctionKey):
(WebCore::WHLSL::Checker::resolveFunction):
(WebCore::WHLSL::Checker::finishVisiting):
(WebCore::WHLSL::Checker::visit):
(WebCore::WHLSL::resolveFunction): Deleted.
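
FunctionKey defines isEmptyValue, isHashTableDeletedValue, and a Traits struct because WTF's HashMap uses open addressing. It stores its empty and deleted bucket markers in-band, as special values of the key type itself. Below is a minimal, hypothetical linear-probing set illustrating why the two markers must stay distinguishable: an empty bucket ends a lookup, while a deleted bucket must be skipped. MiniKey and MiniSet are illustrative names, and WTF's actual probing and resizing differ.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <optional>
#include <string>
#include <utility>
#include <vector>

// Hypothetical, simplified key. An absent name plays the role of the empty
// value and the `deleted` flag the role of the deleted value; the two states
// are kept disjoint so a lookup can tell "stop probing" from "keep probing".
struct MiniKey {
    std::optional<std::string> name;
    bool deleted { false };

    bool isEmptyValue() const { return !name && !deleted; }
    bool isDeletedValue() const { return deleted; }
};

// Minimal open-addressing (linear probing) set with a fixed capacity.
class MiniSet {
public:
    MiniSet(std::size_t capacity, std::function<std::size_t(const std::string&)> hash)
        : m_buckets(capacity)
        , m_hash(std::move(hash))
    {
    }

    bool add(const std::string& name)
    {
        MiniKey* firstDeleted = nullptr;
        for (std::size_t i = 0; i < m_buckets.size(); ++i) {
            MiniKey& bucket = m_buckets[(m_hash(name) + i) % m_buckets.size()];
            if (bucket.isEmptyValue()) {
                // Reuse an earlier deleted bucket if the probe passed one.
                MiniKey& target = firstDeleted ? *firstDeleted : bucket;
                target = MiniKey { name, false };
                return true;
            }
            if (bucket.isDeletedValue()) {
                if (!firstDeleted)
                    firstDeleted = &bucket;
                continue;
            }
            if (*bucket.name == name)
                return false; // already present
        }
        if (firstDeleted) {
            *firstDeleted = MiniKey { name, false };
            return true;
        }
        return false; // table full; a real table would grow instead
    }

    bool contains(const std::string& name) const { return findIndex(name) != npos; }

    void remove(const std::string& name)
    {
        std::size_t index = findIndex(name);
        if (index != npos)
            m_buckets[index] = MiniKey { std::nullopt, true };
    }

private:
    static constexpr std::size_t npos = static_cast<std::size_t>(-1);

    std::size_t findIndex(const std::string& name) const
    {
        for (std::size_t i = 0; i < m_buckets.size(); ++i) {
            std::size_t index = (m_hash(name) + i) % m_buckets.size();
            const MiniKey& bucket = m_buckets[index];
            if (bucket.isEmptyValue())
                return npos; // an empty bucket ends the probe sequence
            if (!bucket.isDeletedValue() && *bucket.name == name)
                return index;
        }
        return npos;
    }

    std::vector<MiniKey> m_buckets;
    std::function<std::size_t(const std::string&)> m_hash;
};
```

If the deleted marker also satisfied isEmptyValue(), removing "mul" from a chain like mul, dot would make "dot" unreachable. Checker only ever adds to and looks up m_functions, but a HashMap key type must still provide both markers.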

Modified Paths

trunk/Source/WebCore/ChangeLog
trunk/Source/WebCore/Modules/webgpu/WHLSL/WHLSLChecker.cpp

Diff

Modified: trunk/Source/WebCore/ChangeLog (248338 => 248339)


--- trunk/Source/WebCore/ChangeLog	2019-08-07 03:15:19 UTC (rev 248338)
+++ trunk/Source/WebCore/ChangeLog	2019-08-07 04:17:29 UTC (rev 248339)
@@ -1,3 +1,53 @@
+2019-08-06  Saam Barati  <sbar...@apple.com>
+
+        [WHLSL] Make resolveFunction in Checker faster
+        https://bugs.webkit.org/show_bug.cgi?id=200287
+
+        Reviewed by Robin Morisset.
+
+        This patch makes compute_boids faster by making function overload
+        resolution faster inside the Checker. It's a ~6ms speedup in the
+        checker. The main idea is to limit the number of overloads we need
+        to look for by using a hash table that describes a function's type
+        instead of just using a hash table keyed by a function's name.
+        
+        The interesting implementation detail here is we must construct entries
+        in the hash table such that they still allow constants to be resolved to
+        various types. This means that the key in the hash table must normalize
+        the vector of types it uses to express a function's identity. The normalization
+        rules are:
+        - int => float
+        - uint => float
+        - T* => float*
+        - T[] => float*
+        
+        The first two rules are because int constants can be matched against
+        the float and uint types. The latter two rules are because the null
+        literal can be matched against any pointer or any array reference
+        (we pick float* arbitrarily). Even though it seems like these
+        normalization rules would drastically broaden the efficacy of the hash
+        table, we still see a 100x reduction in the number of overloads we must
+        resolve inside compute_boids. We go from having to resolve 400,000
+        overloads to just resolving 4,000.
+
+        * Modules/webgpu/WHLSL/WHLSLChecker.cpp:
+        (WebCore::WHLSL::FunctionKey::FunctionKey):
+        (WebCore::WHLSL::FunctionKey::isEmptyValue const):
+        (WebCore::WHLSL::FunctionKey::isHashTableDeletedValue const):
+        (WebCore::WHLSL::FunctionKey::hash const):
+        (WebCore::WHLSL::FunctionKey::operator== const):
+        (WebCore::WHLSL::FunctionKey::Hash::hash):
+        (WebCore::WHLSL::FunctionKey::Hash::equal):
+        (WebCore::WHLSL::FunctionKey::Traits::isEmptyValue):
+        (WebCore::WHLSL::Checker::Checker):
+        (WebCore::WHLSL::Checker::wrappedFloatType):
+        (WebCore::WHLSL::Checker::genericPointerType):
+        (WebCore::WHLSL::Checker::normalizedTypeForFunctionKey):
+        (WebCore::WHLSL::Checker::resolveFunction):
+        (WebCore::WHLSL::Checker::finishVisiting):
+        (WebCore::WHLSL::Checker::visit):
+        (WebCore::WHLSL::resolveFunction): Deleted.
+
 2019-08-06  Loïc Yhuel  <loic.yh...@softathome.com>
 
         Fix 32-bit Linux build after r248282

Modified: trunk/Source/WebCore/Modules/webgpu/WHLSL/WHLSLChecker.cpp (248338 => 248339)


--- trunk/Source/WebCore/Modules/webgpu/WHLSL/WHLSLChecker.cpp	2019-08-07 03:15:19 UTC (rev 248338)
+++ trunk/Source/WebCore/Modules/webgpu/WHLSL/WHLSLChecker.cpp	2019-08-07 04:17:29 UTC (rev 248339)
@@ -118,6 +118,87 @@
     }
 };
 
+class FunctionKey {
+public:
+    FunctionKey() = default;
+    FunctionKey(WTF::HashTableDeletedValueType)
+    {
+        m_castReturnType = bitwise_cast<AST::NamedType*>(static_cast<uintptr_t>(1));
+    }
+
+    FunctionKey(String name, Vector<std::reference_wrapper<AST::UnnamedType>> types, AST::NamedType* castReturnType = nullptr)
+        : m_name(WTFMove(name))
+        , m_types(WTFMove(types))
+        , m_castReturnType(castReturnType)
+    { }
+
+    bool isEmptyValue() const { return m_name.isNull(); }
+    bool isHashTableDeletedValue() const { return m_castReturnType == bitwise_cast<AST::NamedType*>(static_cast<uintptr_t>(1)); }
+
+    unsigned hash() const
+    {
+        unsigned hash = IntHash<size_t>::hash(m_types.size());
+        hash ^= m_name.hash();
+        for (size_t i = 0; i < m_types.size(); ++i)
+            hash ^= m_types[i].get().hash() + i;
+
+        if (m_castReturnType)
+            hash ^= WTF::PtrHash<AST::Type*>::hash(&m_castReturnType->unifyNode());
+
+        return hash;
+    }
+
+    bool operator==(const FunctionKey& other) const
+    {
+        if (m_types.size() != other.m_types.size())
+            return false;
+
+        if (m_name != other.m_name)
+            return false;
+
+        for (size_t i = 0; i < m_types.size(); ++i) {
+            if (!matches(m_types[i].get(), other.m_types[i].get()))
+                return false;
+        }
+
+        if (!!m_castReturnType != !!other.m_castReturnType)
+            return false;
+
+        if (!m_castReturnType)
+            return true;
+
+        if (&m_castReturnType->unifyNode() == &other.m_castReturnType->unifyNode())
+            return true;
+
+        return false;
+    }
+
+    struct Hash {
+        static unsigned hash(const FunctionKey& key)
+        {
+            return key.hash();
+        }
+
+        static bool equal(const FunctionKey& a, const FunctionKey& b)
+        {
+            return a == b;
+        }
+
+        static const bool safeToCompareToEmptyOrDeleted = false;
+        static const bool emptyValueIsZero = false;
+    };
+
+    struct Traits : public WTF::SimpleClassHashTraits<FunctionKey> {
+        static const bool hasIsEmptyValueFunction = true;
+        static bool isEmptyValue(const FunctionKey& key) { return key.isEmptyValue(); }
+    };
+
+private:
+    String m_name;
+    Vector<std::reference_wrapper<AST::UnnamedType>> m_types;
+    AST::NamedType* m_castReturnType;
+};
+
 static AST::NativeFunctionDeclaration resolveWithOperatorAnderIndexer(CodeLocation location, AST::ArrayReferenceType& firstArgument, const Intrinsics& intrinsics)
 {
     const bool isOperator = true;
@@ -217,21 +298,6 @@
     return WTF::nullopt;
 }
 
-static AST::FunctionDeclaration* resolveFunction(Program& program, Vector<std::reference_wrapper<AST::FunctionDeclaration>, 1>* possibleOverloads, Vector<std::reference_wrapper<ResolvingType>>& types, const String& name, CodeLocation location, const Intrinsics& intrinsics, AST::NamedType* castReturnType = nullptr)
-{
-    if (possibleOverloads) {
-        if (AST::FunctionDeclaration* function = resolveFunctionOverload(*possibleOverloads, types, castReturnType))
-            return function;
-    }
-
-    if (auto newFunction = resolveByInstantiation(name, location, types, intrinsics)) {
-        program.append(WTFMove(*newFunction));
-        return &program.nativeFunctionDeclarations().last();
-    }
-
-    return nullptr;
-}
-
 static bool checkSemantics(Vector<EntryPointItem>& inputItems, Vector<EntryPointItem>& outputItems, const Optional<AST::EntryPointType>& entryPointType, const Intrinsics& intrinsics)
 {
     {
@@ -454,6 +520,25 @@
         : m_intrinsics(intrinsics)
         , m_program(program)
     {
+        auto addFunction = [&] (AST::FunctionDeclaration& function) {
+            AST::NamedType* castReturnType = nullptr;
+            if (function.isCast() && is<AST::NamedType>(function.type().unifyNode()))
+                castReturnType = &downcast<AST::NamedType>(function.type().unifyNode());
+
+            Vector<std::reference_wrapper<AST::UnnamedType>> types;
+            types.reserveInitialCapacity(function.parameters().size());
+
+            for (auto& param : function.parameters())
+                types.uncheckedAppend(normalizedTypeForFunctionKey(*param->type()));
+
+            auto addResult = m_functions.add(FunctionKey { function.name(), WTFMove(types), castReturnType }, Vector<std::reference_wrapper<AST::FunctionDeclaration>, 1>());
+            addResult.iterator->value.append(function);
+        };
+
+        for (auto& function : m_program.functionDefinitions())
+            addFunction(function.get());
+        for (auto& function : m_program.nativeFunctionDeclarations())
+            addFunction(function.get());
     }
 
     virtual ~Checker() = default;
@@ -511,6 +596,36 @@
 
     void finishVisiting(AST::PropertyAccessExpression&, ResolvingType* additionalArgumentType = nullptr);
 
+    AST::FunctionDeclaration* resolveFunction(Vector<std::reference_wrapper<ResolvingType>>& types, const String& name, CodeLocation, AST::NamedType* castReturnType = nullptr);
+
+    AST::UnnamedType& wrappedFloatType()
+    {
+        if (!m_wrappedFloatType)
+            m_wrappedFloatType = AST::TypeReference::wrap({ }, m_intrinsics.floatType());
+        return *m_wrappedFloatType;
+    }
+
+    AST::UnnamedType& genericPointerType()
+    {
+        if (!m_genericPointerType)
+            m_genericPointerType = AST::PointerType::create({ }, AST::AddressSpace::Thread, AST::TypeReference::wrap({ }, m_intrinsics.floatType()));
+        return *m_genericPointerType;
+    }
+
+    AST::UnnamedType& normalizedTypeForFunctionKey(AST::UnnamedType& type)
+    {
+        auto* unifyNode = &type.unifyNode();
+        if (unifyNode == &m_intrinsics.uintType() || unifyNode == &m_intrinsics.intType())
+            return wrappedFloatType();
+
+        if (is<AST::ReferenceType>(type))
+            return genericPointerType();
+
+        return type;
+    }
+
+    RefPtr<AST::TypeReference> m_wrappedFloatType;
+    RefPtr<AST::UnnamedType> m_genericPointerType;
     HashMap<AST::_expression_*, std::unique_ptr<ResolvingType>> m_typeMap;
     HashSet<String> m_vertexEntryPoints;
     HashSet<String> m_fragmentEntryPoints;
@@ -518,6 +633,7 @@
     const Intrinsics& m_intrinsics;
     Program& m_program;
     AST::FunctionDefinition* m_currentFunction { nullptr };
+    HashMap<FunctionKey, Vector<std::reference_wrapper<AST::FunctionDeclaration>, 1>, FunctionKey::Hash, FunctionKey::Traits> m_functions;
 };
 
 void Checker::visit(Program& program)
@@ -638,6 +754,53 @@
     }));
 }
 
+AST::FunctionDeclaration* Checker::resolveFunction(Vector<std::reference_wrapper<ResolvingType>>& types, const String& name, CodeLocation location, AST::NamedType* castReturnType)
+{
+    Vector<std::reference_wrapper<AST::UnnamedType>> unnamedTypes;
+    unnamedTypes.reserveInitialCapacity(types.size());
+
+    for (auto resolvingType : types) {
+        AST::UnnamedType* type = resolvingType.get().visit(WTF::makeVisitor([&](Ref<AST::UnnamedType>& unnamedType) -> AST::UnnamedType* {
+            return unnamedType.ptr();
+        }, [&](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> AST::UnnamedType* {
+            if (resolvableTypeReference->resolvableType().maybeResolvedType())
+                return &resolvableTypeReference->resolvableType().resolvedType();
+
+            if (resolvableTypeReference->resolvableType().isFloatLiteralType()
+                || resolvableTypeReference->resolvableType().isIntegerLiteralType()
+                || resolvableTypeReference->resolvableType().isUnsignedIntegerLiteralType())
+                return &wrappedFloatType();
+
+            if (resolvableTypeReference->resolvableType().isNullLiteralType())
+                return &genericPointerType();
+
+            return commit(resolvableTypeReference->resolvableType()).get();
+        }));
+
+        if (!type) {
+            setError(Error("Could not resolve the type of a constant."));
+            return nullptr;
+        }
+
+        unnamedTypes.uncheckedAppend(normalizedTypeForFunctionKey(*type));
+    }
+
+    {
+        auto iter = m_functions.find(FunctionKey { name, WTFMove(unnamedTypes), castReturnType });
+        if (iter != m_functions.end()) {
+            if (AST::FunctionDeclaration* function = resolveFunctionOverload(iter->value, types, castReturnType))
+                return function;
+        }
+    }
+
+    if (auto newFunction = resolveByInstantiation(name, location, types, m_intrinsics)) {
+        m_program.append(WTFMove(*newFunction));
+        return &m_program.nativeFunctionDeclarations().last();
+    }
+
+    return nullptr;
+}
+
 void Checker::visit(AST::EnumerationDefinition& enumerationDefinition)
 {
     bool isSigned;
@@ -961,8 +1124,9 @@
         if (additionalArgumentType)
             getterArgumentTypes.append(*additionalArgumentType);
         auto getterName = propertyAccessExpression.getterFunctionName();
-        auto* getterFunctions = m_program.nameContext().getFunctions(getterName);
-        getterFunction = resolveFunction(m_program, getterFunctions, getterArgumentTypes, getterName, propertyAccessExpression.codeLocation(), m_intrinsics);
+        getterFunction = resolveFunction(getterArgumentTypes, getterName, propertyAccessExpression.codeLocation());
+        if (hasError())
+            return;
         if (getterFunction)
             getterReturnType = &getterFunction->type();
     }
@@ -977,8 +1141,9 @@
             if (additionalArgumentType)
                 anderArgumentTypes.append(*additionalArgumentType);
             auto anderName = propertyAccessExpression.anderFunctionName();
-            auto* anderFunctions = m_program.nameContext().getFunctions(anderName);
-            anderFunction = resolveFunction(m_program, anderFunctions, anderArgumentTypes, anderName, propertyAccessExpression.codeLocation(), m_intrinsics);
+            anderFunction = resolveFunction(anderArgumentTypes, anderName, propertyAccessExpression.codeLocation());
+            if (hasError())
+                return;
             if (anderFunction)
                 anderReturnType = &downcast<AST::PointerType>(anderFunction->type()).elementType(); // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198164 Enforce the return of anders will always be a pointer
         }
@@ -992,8 +1157,9 @@
         if (additionalArgumentType)
             threadAnderArgumentTypes.append(*additionalArgumentType);
         auto anderName = propertyAccessExpression.anderFunctionName();
-        auto* anderFunctions = m_program.nameContext().getFunctions(anderName);
-        threadAnderFunction = resolveFunction(m_program, anderFunctions, threadAnderArgumentTypes, anderName, propertyAccessExpression.codeLocation(), m_intrinsics);
+        threadAnderFunction = resolveFunction(threadAnderArgumentTypes, anderName, propertyAccessExpression.codeLocation());
+        if (hasError())
+            return;
         if (threadAnderFunction)
             threadAnderReturnType = &downcast<AST::PointerType>(threadAnderFunction->type()).elementType(); // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198164 Enforce the return of anders will always be a pointer
     }
@@ -1039,8 +1205,9 @@
             setterArgumentTypes.append(*additionalArgumentType);
         setterArgumentTypes.append(fieldResolvingType);
         auto setterName = propertyAccessExpression.setterFunctionName();
-        auto* setterFunctions = m_program.nameContext().getFunctions(setterName);
-        setterFunction = resolveFunction(m_program, setterFunctions, setterArgumentTypes, setterName, propertyAccessExpression.codeLocation(), m_intrinsics);
+        setterFunction = resolveFunction(setterArgumentTypes, setterName, propertyAccessExpression.codeLocation());
+        if (hasError())
+            return;
         if (setterFunction)
             setterReturnType = &setterFunction->type();
     }
@@ -1420,22 +1587,24 @@
     // Don't recurse on the castReturnType, because it's guaranteed to be a NamedType, which will get visited later.
     // We don't want to recurse to the same node twice.
 
-    NameContext& nameContext = m_program.nameContext();
-    auto* functions = nameContext.getFunctions(callExpression.name());
-    if (!functions) {
-        if (auto* types = nameContext.getTypes(callExpression.name())) {
-            if (types->size() == 1) {
-                if ((functions = nameContext.getFunctions("operator cast"_str)))
-                    callExpression.setCastData((*types)[0].get());
+    auto* function = resolveFunction(types, callExpression.name(), callExpression.codeLocation());
+    if (hasError())
+        return;
+
+    if (!function) {
+        NameContext& nameContext = m_program.nameContext();
+        if (auto* castTypes = nameContext.getTypes(callExpression.name())) {
+            if (castTypes->size() == 1) {
+                AST::NamedType& castType = (*castTypes)[0].get();
+                function = resolveFunction(types, "operator cast"_str, callExpression.codeLocation(), &castType);
+                if (hasError())
+                    return;
+                if (function)
+                    callExpression.setCastData(castType);
             }
         }
     }
-    if (!functions) {
-        setError(Error("Could not find any functions with appropriate name.", callExpression.codeLocation()));
-        return;
-    }
 
-    auto* function = resolveFunction(m_program, functions, types, callExpression.name(), callExpression.codeLocation(), m_intrinsics, callExpression.castReturnType());
     if (!function) {
         // FIXME: Add better error messages for why we can't resolve to one of the overrides.
         // https://bugs.webkit.org/show_bug.cgi?id=200133
_______________________________________________
webkit-changes mailing list
webkit-changes@lists.webkit.org
https://lists.webkit.org/mailman/listinfo/webkit-changes
