Modified: trunk/Source/WebCore/Modules/webgpu/WHLSL/WHLSLChecker.cpp (248338 => 248339)
--- trunk/Source/WebCore/Modules/webgpu/WHLSL/WHLSLChecker.cpp 2019-08-07 03:15:19 UTC (rev 248338)
+++ trunk/Source/WebCore/Modules/webgpu/WHLSL/WHLSLChecker.cpp 2019-08-07 04:17:29 UTC (rev 248339)
@@ -118,6 +118,87 @@
}
};
+// Key for the checker's overload table (Checker::m_functions): a function name, its parameter
+// types as normalized by the caller, and, for casts to a named type, the cast target type.
+class FunctionKey {
+public:
+    // The empty value: null name and null cast return type.
+    FunctionKey() = default;
+
+    // The deleted value is marked by a sentinel cast return type (see deletedValueMarker()).
+    FunctionKey(WTF::HashTableDeletedValueType)
+        : m_castReturnType(deletedValueMarker())
+    {
+    }
+
+    FunctionKey(String name, Vector<std::reference_wrapper<AST::UnnamedType>> types, AST::NamedType* castReturnType = nullptr)
+        : m_name(WTFMove(name))
+        , m_types(WTFMove(types))
+        , m_castReturnType(castReturnType)
+    { }
+
+    // The deleted value also has a null name, so it must be excluded here; otherwise the hash
+    // table could treat deleted buckets as empty ones and end a probe sequence too early.
+    bool isEmptyValue() const { return m_name.isNull() && !isHashTableDeletedValue(); }
+    bool isHashTableDeletedValue() const { return m_castReturnType == deletedValueMarker(); }
+
+    // Combines the arity, the name, each parameter type's hash (offset by its position) and,
+    // for casts, the identity of the cast target's unify node.
+    // NOTE(review): this relies on AST::UnnamedType::hash() being consistent with matches(),
+    // which operator== uses to compare parameter types — confirm.
+    unsigned hash() const
+    {
+        unsigned hash = IntHash<size_t>::hash(m_types.size());
+        hash ^= m_name.hash();
+        for (size_t i = 0; i < m_types.size(); ++i)
+            hash ^= m_types[i].get().hash() + i;
+
+        if (m_castReturnType)
+            hash ^= WTF::PtrHash<AST::Type*>::hash(&m_castReturnType->unifyNode());
+
+        return hash;
+    }
+
+    // Two keys are equal when they have the same name, pairwise-matching parameter types, and
+    // either no cast target or cast targets whose unify nodes are the same object.
+    bool operator==(const FunctionKey& other) const
+    {
+        if (m_types.size() != other.m_types.size())
+            return false;
+
+        if (m_name != other.m_name)
+            return false;
+
+        for (size_t i = 0; i < m_types.size(); ++i) {
+            if (!matches(m_types[i].get(), other.m_types[i].get()))
+                return false;
+        }
+
+        if (!!m_castReturnType != !!other.m_castReturnType)
+            return false;
+
+        return !m_castReturnType || &m_castReturnType->unifyNode() == &other.m_castReturnType->unifyNode();
+    }
+
+    struct Hash {
+        static unsigned hash(const FunctionKey& key) { return key.hash(); }
+        static bool equal(const FunctionKey& a, const FunctionKey& b) { return a == b; }
+        static const bool safeToCompareToEmptyOrDeleted = false;
+        // NOTE(review): WTF::HashTable appears to read emptyValueIsZero from the key traits, not
+        // from the hash functor, so this declaration likely has no effect — confirm.
+        static const bool emptyValueIsZero = false;
+    };
+
+    // Lets the hash table recognize empty buckets through isEmptyValue().
+    struct Traits : public WTF::SimpleClassHashTraits<FunctionKey> {
+        static const bool hasIsEmptyValueFunction = true;
+        static bool isEmptyValue(const FunctionKey& key) { return key.isEmptyValue(); }
+    };
+
+private:
+    // Sentinel cast return type marking the hash table's deleted value; never a real pointer.
+    static AST::NamedType* deletedValueMarker() { return bitwise_cast<AST::NamedType*>(static_cast<uintptr_t>(1)); }
+
+    String m_name;
+    Vector<std::reference_wrapper<AST::UnnamedType>> m_types;
+    // Only non-null for cast keys (or the deleted value). Initialized so that default-constructed
+    // keys have a well-defined value when isHashTableDeletedValue() reads it.
+    AST::NamedType* m_castReturnType { nullptr };
+};
+
static AST::NativeFunctionDeclaration resolveWithOperatorAnderIndexer(CodeLocation location, AST::ArrayReferenceType& firstArgument, const Intrinsics& intrinsics)
{
const bool isOperator = true;
@@ -217,21 +298,6 @@
return WTF::nullopt;
}
-static AST::FunctionDeclaration* resolveFunction(Program& program, Vector<std::reference_wrapper<AST::FunctionDeclaration>, 1>* possibleOverloads, Vector<std::reference_wrapper<ResolvingType>>& types, const String& name, CodeLocation location, const Intrinsics& intrinsics, AST::NamedType* castReturnType = nullptr)
-{
- if (possibleOverloads) {
- if (AST::FunctionDeclaration* function = resolveFunctionOverload(*possibleOverloads, types, castReturnType))
- return function;
- }
-
- if (auto newFunction = resolveByInstantiation(name, location, types, intrinsics)) {
- program.append(WTFMove(*newFunction));
- return &program.nativeFunctionDeclarations().last();
- }
-
- return nullptr;
-}
-
static bool checkSemantics(Vector<EntryPointItem>& inputItems, Vector<EntryPointItem>& outputItems, const Optional<AST::EntryPointType>& entryPointType, const Intrinsics& intrinsics)
{
{
@@ -454,6 +520,25 @@
: m_intrinsics(intrinsics)
, m_program(program)
{
+        // Index every function the program declares by name, normalized parameter types and,
+        // for casts whose return type unifies to a named type, the cast target type.
+        auto addFunction = [&] (AST::FunctionDeclaration& function) {
+            AST::NamedType* castTarget = nullptr;
+            if (function.isCast()) {
+                auto& unifiedReturnType = function.type().unifyNode();
+                if (is<AST::NamedType>(unifiedReturnType))
+                    castTarget = &downcast<AST::NamedType>(unifiedReturnType);
+            }
+
+            Vector<std::reference_wrapper<AST::UnnamedType>> keyTypes;
+            keyTypes.reserveInitialCapacity(function.parameters().size());
+            for (auto& parameter : function.parameters())
+                keyTypes.uncheckedAppend(normalizedTypeForFunctionKey(*parameter->type()));
+
+            // Overloads that normalize to the same key share one bucket.
+            auto& overloads = m_functions.add(FunctionKey { function.name(), WTFMove(keyTypes), castTarget }, Vector<std::reference_wrapper<AST::FunctionDeclaration>, 1>()).iterator->value;
+            overloads.append(function);
+        };
+
+        for (auto& definition : m_program.functionDefinitions())
+            addFunction(definition.get());
+        for (auto& declaration : m_program.nativeFunctionDeclarations())
+            addFunction(declaration.get());
}
virtual ~Checker() = default;
@@ -511,6 +596,36 @@
void finishVisiting(AST::PropertyAccessExpression&, ResolvingType* additionalArgumentType = nullptr);
+    // Resolves a call to `name` with the given argument types: first among the declared overloads
+    // whose FunctionKey matches, then by instantiating a native function. Returns nullptr when
+    // nothing matches; may also set an error, which callers check with hasError().
+    AST::FunctionDeclaration* resolveFunction(Vector<std::reference_wrapper<ResolvingType>>& types, const String& name, CodeLocation location, AST::NamedType* castReturnType = nullptr);
+
+    // Lazily creates, then reuses, a reference to the intrinsic `float` type. FunctionKeys use it
+    // in place of int, uint and unresolved numeric-literal types so those share one bucket.
+    AST::UnnamedType& wrappedFloatType()
+    {
+        if (m_wrappedFloatType)
+            return *m_wrappedFloatType;
+        m_wrappedFloatType = AST::TypeReference::wrap({ }, m_intrinsics.floatType());
+        return *m_wrappedFloatType;
+    }
+
+    // Lazily creates, then reuses, a thread-address-space pointer to `float`. FunctionKeys use it
+    // for every reference-type parameter and for unresolved null literals.
+    AST::UnnamedType& genericPointerType()
+    {
+        if (m_genericPointerType)
+            return *m_genericPointerType;
+        auto elementType = AST::TypeReference::wrap({ }, m_intrinsics.floatType());
+        m_genericPointerType = AST::PointerType::create({ }, AST::AddressSpace::Thread, WTFMove(elementType));
+        return *m_genericPointerType;
+    }
+
+    // Maps a parameter/argument type to the coarser type stored in a FunctionKey: types unifying
+    // to int or uint become the shared `float` reference, and any reference type becomes the
+    // generic pointer. This lets unresolved literal arguments (keyed with those same placeholder
+    // types) land in the bucket of the overloads they could resolve to. Other types pass through.
+    AST::UnnamedType& normalizedTypeForFunctionKey(AST::UnnamedType& type)
+    {
+        auto& unified = type.unifyNode();
+        if (&unified == &m_intrinsics.intType() || &unified == &m_intrinsics.uintType())
+            return wrappedFloatType();
+
+        return is<AST::ReferenceType>(type) ? genericPointerType() : type;
+    }
+
+ // Placeholder types created on first use by wrappedFloatType() and genericPointerType(), then reused.
+ RefPtr<AST::TypeReference> m_wrappedFloatType;
+ RefPtr<AST::UnnamedType> m_genericPointerType;
HashMap<AST::_expression_*, std::unique_ptr<ResolvingType>> m_typeMap;
HashSet<String> m_vertexEntryPoints;
HashSet<String> m_fragmentEntryPoints;
@@ -518,6 +633,7 @@
const Intrinsics& m_intrinsics;
Program& m_program;
AST::FunctionDefinition* m_currentFunction { nullptr };
+ // Overload table filled in the constructor from the program's function definitions and native
+ // declarations, grouped by FunctionKey. Functions later created by resolveByInstantiation()
+ // in resolveFunction() are appended to the program but not added here.
+ HashMap<FunctionKey, Vector<std::reference_wrapper<AST::FunctionDeclaration>, 1>, FunctionKey::Hash, FunctionKey::Traits> m_functions;
};
void Checker::visit(Program& program)
@@ -638,6 +754,53 @@
}));
}
+// Resolves a call to `name` with argument `types` (and, for casts, the target `castReturnType`).
+// The arguments are first reduced to a FunctionKey using the same normalization that was applied
+// when m_functions was built, and resolveFunctionOverload() then picks the actual match from the
+// bucket found. If no declared overload fits, a native function is instantiated on demand.
+// Returns nullptr when nothing matches. If an argument's type cannot be determined, an error
+// (carrying `location`) is set and nullptr is returned; callers must check hasError().
+AST::FunctionDeclaration* Checker::resolveFunction(Vector<std::reference_wrapper<ResolvingType>>& types, const String& name, CodeLocation location, AST::NamedType* castReturnType)
+{
+    Vector<std::reference_wrapper<AST::UnnamedType>> unnamedTypes;
+    unnamedTypes.reserveInitialCapacity(types.size());
+
+    for (auto resolvingType : types) {
+        AST::UnnamedType* type = resolvingType.get().visit(WTF::makeVisitor([&](Ref<AST::UnnamedType>& unnamedType) -> AST::UnnamedType* {
+            return unnamedType.ptr();
+        }, [&](RefPtr<ResolvableTypeReference>& resolvableTypeReference) -> AST::UnnamedType* {
+            auto& resolvableType = resolvableTypeReference->resolvableType();
+            if (resolvableType.maybeResolvedType())
+                return &resolvableType.resolvedType();
+
+            // Unresolved numeric literals are keyed as the shared float type, which is also what
+            // int and uint parameters normalize to.
+            if (resolvableType.isFloatLiteralType() || resolvableType.isIntegerLiteralType() || resolvableType.isUnsignedIntegerLiteralType())
+                return &wrappedFloatType();
+
+            // Unresolved null literals are keyed as the generic pointer type, which is also what
+            // reference-type parameters normalize to.
+            if (resolvableType.isNullLiteralType())
+                return &genericPointerType();
+
+            return commit(resolvableType).get();
+        }));
+
+        if (!type) {
+            setError(Error("Could not resolve the type of a constant.", location));
+            return nullptr;
+        }
+
+        unnamedTypes.uncheckedAppend(normalizedTypeForFunctionKey(*type));
+    }
+
+    auto iter = m_functions.find(FunctionKey { name, WTFMove(unnamedTypes), castReturnType });
+    if (iter != m_functions.end()) {
+        if (AST::FunctionDeclaration* function = resolveFunctionOverload(iter->value, types, castReturnType))
+            return function;
+    }
+
+    if (auto newFunction = resolveByInstantiation(name, location, types, m_intrinsics)) {
+        m_program.append(WTFMove(*newFunction));
+        return &m_program.nativeFunctionDeclarations().last();
+    }
+
+    return nullptr;
+}
+
void Checker::visit(AST::EnumerationDefinition& enumerationDefinition)
{
bool isSigned;
@@ -961,8 +1124,9 @@
if (additionalArgumentType)
getterArgumentTypes.append(*additionalArgumentType);
auto getterName = propertyAccessExpression.getterFunctionName();
- auto* getterFunctions = m_program.nameContext().getFunctions(getterName);
- getterFunction = resolveFunction(m_program, getterFunctions, getterArgumentTypes, getterName, propertyAccessExpression.codeLocation(), m_intrinsics);
+ getterFunction = resolveFunction(getterArgumentTypes, getterName, propertyAccessExpression.codeLocation());
+ if (hasError())
+ return;
if (getterFunction)
getterReturnType = &getterFunction->type();
}
@@ -977,8 +1141,9 @@
if (additionalArgumentType)
anderArgumentTypes.append(*additionalArgumentType);
auto anderName = propertyAccessExpression.anderFunctionName();
- auto* anderFunctions = m_program.nameContext().getFunctions(anderName);
- anderFunction = resolveFunction(m_program, anderFunctions, anderArgumentTypes, anderName, propertyAccessExpression.codeLocation(), m_intrinsics);
+ anderFunction = resolveFunction(anderArgumentTypes, anderName, propertyAccessExpression.codeLocation());
+ if (hasError())
+ return;
if (anderFunction)
anderReturnType = &downcast<AST::PointerType>(anderFunction->type()).elementType(); // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198164 Enforce the return of anders will always be a pointer
}
@@ -992,8 +1157,9 @@
if (additionalArgumentType)
threadAnderArgumentTypes.append(*additionalArgumentType);
auto anderName = propertyAccessExpression.anderFunctionName();
- auto* anderFunctions = m_program.nameContext().getFunctions(anderName);
- threadAnderFunction = resolveFunction(m_program, anderFunctions, threadAnderArgumentTypes, anderName, propertyAccessExpression.codeLocation(), m_intrinsics);
+ threadAnderFunction = resolveFunction(threadAnderArgumentTypes, anderName, propertyAccessExpression.codeLocation());
+ if (hasError())
+ return;
if (threadAnderFunction)
threadAnderReturnType = &downcast<AST::PointerType>(threadAnderFunction->type()).elementType(); // FIXME: https://bugs.webkit.org/show_bug.cgi?id=198164 Enforce the return of anders will always be a pointer
}
@@ -1039,8 +1205,9 @@
setterArgumentTypes.append(*additionalArgumentType);
setterArgumentTypes.append(fieldResolvingType);
auto setterName = propertyAccessExpression.setterFunctionName();
- auto* setterFunctions = m_program.nameContext().getFunctions(setterName);
- setterFunction = resolveFunction(m_program, setterFunctions, setterArgumentTypes, setterName, propertyAccessExpression.codeLocation(), m_intrinsics);
+ setterFunction = resolveFunction(setterArgumentTypes, setterName, propertyAccessExpression.codeLocation());
+ if (hasError())
+ return;
if (setterFunction)
setterReturnType = &setterFunction->type();
}
@@ -1420,22 +1587,24 @@
// Don't recurse on the castReturnType, because it's guaranteed to be a NamedType, which will get visited later.
// We don't want to recurse to the same node twice.
- NameContext& nameContext = m_program.nameContext();
- auto* functions = nameContext.getFunctions(callExpression.name());
- if (!functions) {
- if (auto* types = nameContext.getTypes(callExpression.name())) {
- if (types->size() == 1) {
- if ((functions = nameContext.getFunctions("operator cast"_str)))
- callExpression.setCastData((*types)[0].get());
+ auto* function = resolveFunction(types, callExpression.name(), callExpression.codeLocation());
+ if (hasError())
+ return;
+
+ if (!function) {
+ NameContext& nameContext = m_program.nameContext();
+ if (auto* castTypes = nameContext.getTypes(callExpression.name())) {
+ if (castTypes->size() == 1) {
+ AST::NamedType& castType = (*castTypes)[0].get();
+ function = resolveFunction(types, "operator cast"_str, callExpression.codeLocation(), &castType);
+ if (hasError())
+ return;
+ if (function)
+ callExpression.setCastData(castType);
}
}
}
- if (!functions) {
- setError(Error("Could not find any functions with appropriate name.", callExpression.codeLocation()));
- return;
- }
- auto* function = resolveFunction(m_program, functions, types, callExpression.name(), callExpression.codeLocation(), m_intrinsics, callExpression.castReturnType());
if (!function) {
// FIXME: Add better error messages for why we can't resolve to one of the overrides.
// https://bugs.webkit.org/show_bug.cgi?id=200133