http://git-wip-us.apache.org/repos/asf/cassandra/blob/01115f72/test/unit/org/apache/cassandra/cql3/validation/entities/UFAuthTest.java ---------------------------------------------------------------------- diff --git a/test/unit/org/apache/cassandra/cql3/validation/entities/UFAuthTest.java b/test/unit/org/apache/cassandra/cql3/validation/entities/UFAuthTest.java new file mode 100644 index 0000000..498f0dd --- /dev/null +++ b/test/unit/org/apache/cassandra/cql3/validation/entities/UFAuthTest.java @@ -0,0 +1,728 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.cql3.validation.entities; + +import java.lang.reflect.Field; +import java.util.*; + +import com.google.common.base.Joiner; +import com.google.common.collect.ImmutableSet; +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.cassandra.auth.*; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.cql3.Attributes; +import org.apache.cassandra.cql3.CQLStatement; +import org.apache.cassandra.cql3.QueryProcessor; +import org.apache.cassandra.cql3.functions.Function; +import org.apache.cassandra.cql3.functions.FunctionName; +import org.apache.cassandra.cql3.functions.Functions; +import org.apache.cassandra.cql3.statements.BatchStatement; +import org.apache.cassandra.cql3.statements.ModificationStatement; +import org.apache.cassandra.cql3.CQLTester; +import org.apache.cassandra.exceptions.*; +import org.apache.cassandra.service.ClientState; +import org.apache.cassandra.utils.Pair; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class UFAuthTest extends CQLTester +{ + private static final Logger logger = LoggerFactory.getLogger(UFAuthTest.class); + + String roleName = "test_role"; + AuthenticatedUser user; + RoleResource role; + ClientState clientState; + + @BeforeClass + public static void setupAuthorizer() + { + try + { + IAuthorizer authorizer = new StubAuthorizer(); + Field authorizerField = DatabaseDescriptor.class.getDeclaredField("authorizer"); + authorizerField.setAccessible(true); + authorizerField.set(null, authorizer); + DatabaseDescriptor.setPermissionsValidity(0); + } + catch (IllegalAccessException | NoSuchFieldException e) + { + throw new RuntimeException(e); + } + } + + @Before + public void setup() throws Throwable + { + ((StubAuthorizer) DatabaseDescriptor.getAuthorizer()).clear(); + setupClientState(); + setupTable("CREATE TABLE %s (k int, v1 int, v2 int, PRIMARY KEY (k, v1))"); + } + + @Test + public void functionInSelection() throws Throwable + { + String functionName = createSimpleFunction(); + String cql = String.format("SELECT k, %s FROM %s WHERE k = 1;", + functionCall(functionName), + KEYSPACE + "." + currentTable()); + assertPermissionsOnFunction(cql, functionName); + } + + @Test + public void functionInSelectPKRestriction() throws Throwable + { + String functionName = createSimpleFunction(); + String cql = String.format("SELECT * FROM %s WHERE k = %s", + KEYSPACE + "." + currentTable(), + functionCall(functionName)); + assertPermissionsOnFunction(cql, functionName); + } + + @Test + public void functionInSelectClusteringRestriction() throws Throwable + { + String functionName = createSimpleFunction(); + String cql = String.format("SELECT * FROM %s WHERE k = 0 AND v1 = %s", + KEYSPACE + "." + currentTable(), + functionCall(functionName)); + assertPermissionsOnFunction(cql, functionName); + } + + @Test + public void functionInSelectInRestriction() throws Throwable + { + String functionName = createSimpleFunction(); + String cql = String.format("SELECT * FROM %s WHERE k IN (%s, %s)", + KEYSPACE + "." + currentTable(), + functionCall(functionName), + functionCall(functionName)); + assertPermissionsOnFunction(cql, functionName); + } + + @Test + public void functionInSelectMultiColumnInRestriction() throws Throwable + { + setupTable("CREATE TABLE %s (k int, v1 int, v2 int, v3 int, PRIMARY KEY (k, v1, v2))"); + String functionName = createSimpleFunction(); + String cql = String.format("SELECT * FROM %s WHERE k=0 AND (v1, v2) IN ((%s, %s))", + KEYSPACE + "." + currentTable(), + functionCall(functionName), + functionCall(functionName)); + assertPermissionsOnFunction(cql, functionName); + } + + @Test + public void functionInSelectMultiColumnEQRestriction() throws Throwable + { + setupTable("CREATE TABLE %s (k int, v1 int, v2 int, v3 int, PRIMARY KEY (k, v1, v2))"); + String functionName = createSimpleFunction(); + String cql = String.format("SELECT * FROM %s WHERE k=0 AND (v1, v2) = (%s, %s)", + KEYSPACE + "." + currentTable(), + functionCall(functionName), + functionCall(functionName)); + assertPermissionsOnFunction(cql, functionName); + } + + @Test + public void functionInSelectMultiColumnSliceRestriction() throws Throwable + { + setupTable("CREATE TABLE %s (k int, v1 int, v2 int, v3 int, PRIMARY KEY (k, v1, v2))"); + String functionName = createSimpleFunction(); + String cql = String.format("SELECT * FROM %s WHERE k=0 AND (v1, v2) < (%s, %s)", + KEYSPACE + "." + currentTable(), + functionCall(functionName), + functionCall(functionName)); + assertPermissionsOnFunction(cql, functionName); + } + + @Test + public void functionInSelectTokenEQRestriction() throws Throwable + { + String functionName = createSimpleFunction(); + String cql = String.format("SELECT * FROM %s WHERE token(k) = token(%s)", + KEYSPACE + "." + currentTable(), + functionCall(functionName)); + assertPermissionsOnFunction(cql, functionName); + } + + @Test + public void functionInSelectTokenSliceRestriction() throws Throwable + { + String functionName = createSimpleFunction(); + String cql = String.format("SELECT * FROM %s WHERE token(k) < token(%s)", + KEYSPACE + "." + currentTable(), + functionCall(functionName)); + assertPermissionsOnFunction(cql, functionName); + } + + @Test + public void functionInPKForInsert() throws Throwable + { + String functionName = createSimpleFunction(); + String cql = String.format("INSERT INTO %s (k, v1, v2) VALUES (%s, 0, 0)", + KEYSPACE + "." + currentTable(), + functionCall(functionName)); + assertPermissionsOnFunction(cql, functionName); + } + + @Test + public void functionInClusteringValuesForInsert() throws Throwable + { + String functionName = createSimpleFunction(); + String cql = String.format("INSERT INTO %s (k, v1, v2) VALUES (0, %s, 0)", + KEYSPACE + "." + currentTable(), + functionCall(functionName)); + assertPermissionsOnFunction(cql, functionName); + } + + @Test + public void functionInPKForDelete() throws Throwable + { + String functionName = createSimpleFunction(); + String cql = String.format("DELETE FROM %s WHERE k = %s", + KEYSPACE + "." + currentTable(), + functionCall(functionName)); + assertPermissionsOnFunction(cql, functionName); + } + + @Test + public void functionInClusteringValuesForDelete() throws Throwable + { + String functionName = createSimpleFunction(); + String cql = String.format("DELETE FROM %s WHERE k = 0 AND v1 = %s", + KEYSPACE + "." + currentTable(), + functionCall(functionName)); + assertPermissionsOnFunction(cql, functionName); + } + + @Test + public void testBatchStatement() throws Throwable + { + List<ModificationStatement> statements = new ArrayList<>(); + List<String> functions = new ArrayList<>(); + for (int i = 0; i < 3; i++) + { + String functionName = createSimpleFunction(); + ModificationStatement stmt = + (ModificationStatement) getStatement(String.format("INSERT INTO %s (k, v1, v2) " + + "VALUES (%s, %s, %s)", + KEYSPACE + "." + currentTable(), + i, i, functionCall(functionName))); + functions.add(functionName); + statements.add(stmt); + } + BatchStatement batch = new BatchStatement(-1, BatchStatement.Type.LOGGED, statements, Attributes.none()); + assertUnauthorized(batch, functions); + + grantExecuteOnFunction(functions.get(0)); + assertUnauthorized(batch, functions.subList(1, functions.size())); + + grantExecuteOnFunction(functions.get(1)); + assertUnauthorized(batch, functions.subList(2, functions.size())); + + grantExecuteOnFunction(functions.get(2)); + batch.checkAccess(clientState); + } + + @Test + public void testNestedFunctions() throws Throwable + { + String innerFunctionName = createSimpleFunction(); + String outerFunctionName = createFunction("int", + "CREATE FUNCTION %s(input int) " + + " CALLED ON NULL INPUT" + + " RETURNS int" + + " LANGUAGE java" + + " AS 'return Integer.valueOf(0);'"); + assertPermissionsOnNestedFunctions(innerFunctionName, outerFunctionName); + } + + @Test + public void functionInStaticColumnRestrictionInSelect() throws Throwable + { + setupTable("CREATE TABLE %s (k int, s int STATIC, v1 int, v2 int, PRIMARY KEY(k, v1))"); + String functionName = createSimpleFunction(); + String cql = String.format("SELECT k FROM %s WHERE k = 0 AND s = %s", + KEYSPACE + "." + currentTable(), + functionCall(functionName)); + assertPermissionsOnFunction(cql, functionName); + } + + @Test + public void functionInRegularCondition() throws Throwable + { + String functionName = createSimpleFunction(); + String cql = String.format("UPDATE %s SET v2 = 0 WHERE k = 0 AND v1 = 0 IF v2 = %s", + KEYSPACE + "." + currentTable(), + functionCall(functionName)); + assertPermissionsOnFunction(cql, functionName); + } + @Test + public void functionInStaticColumnCondition() throws Throwable + { + setupTable("CREATE TABLE %s (k int, s int STATIC, v1 int, v2 int, PRIMARY KEY(k, v1))"); + String functionName = createSimpleFunction(); + String cql = String.format("UPDATE %s SET v2 = 0 WHERE k = 0 AND v1 = 0 IF s = %s", + KEYSPACE + "." + currentTable(), + functionCall(functionName)); + assertPermissionsOnFunction(cql, functionName); + } + + @Test + public void functionInCollectionLiteralCondition() throws Throwable + { + setupTable("CREATE TABLE %s (k int, v1 int, m_val map<int, int>, PRIMARY KEY(k))"); + String functionName = createSimpleFunction(); + String cql = String.format("UPDATE %s SET v1 = 0 WHERE k = 0 IF m_val = {%s : %s}", + KEYSPACE + "." + currentTable(), + functionCall(functionName), + functionCall(functionName)); + assertPermissionsOnFunction(cql, functionName); + } + + @Test + public void functionInCollectionElementCondition() throws Throwable + { + setupTable("CREATE TABLE %s (k int, v1 int, m_val map<int, int>, PRIMARY KEY(k))"); + String functionName = createSimpleFunction(); + String cql = String.format("UPDATE %s SET v1 = 0 WHERE k = 0 IF m_val[%s] = %s", + KEYSPACE + "." + currentTable(), + functionCall(functionName), + functionCall(functionName)); + assertPermissionsOnFunction(cql, functionName); + } + + @Test + public void systemFunctionsRequireNoExplicitPrivileges() throws Throwable + { + // with terminal arguments, so evaluated at prepare time + String cql = String.format("UPDATE %s SET v2 = 0 WHERE k = blobasint(intasblob(0))", + KEYSPACE + "." + currentTable()); + getStatement(cql).checkAccess(clientState); + + // with non-terminal arguments, so evaluated at execution + String functionName = createSimpleFunction(); + grantExecuteOnFunction(functionName); + cql = String.format("UPDATE %s SET v2 = 0 WHERE k = blobasint(intasblob(%s))", + KEYSPACE + "." + currentTable(), + functionCall(functionName)); + getStatement(cql).checkAccess(clientState); + } + + @Test + public void requireExecutePermissionOnComponentFunctionsWhenDefiningAggregate() throws Throwable + { + String sFunc = createSimpleStateFunction(); + String fFunc = createSimpleFinalFunction(); + // aside from the component functions, we need CREATE on the keyspace's functions + DatabaseDescriptor.getAuthorizer().grant(AuthenticatedUser.SYSTEM_USER, + ImmutableSet.of(Permission.CREATE), + FunctionResource.keyspace(KEYSPACE), + role); + String aggDef = String.format(aggregateCql(sFunc, fFunc), + KEYSPACE + ".aggregate_for_permissions_test"); + + assertUnauthorized(aggDef, sFunc, "int, int"); + grantExecuteOnFunction(sFunc); + + assertUnauthorized(aggDef, fFunc, "int"); + grantExecuteOnFunction(fFunc); + + getStatement(aggDef).checkAccess(clientState); + } + + @Test + public void revokeExecutePermissionsOnAggregateComponents() throws Throwable + { + String sFunc = createSimpleStateFunction(); + String fFunc = createSimpleFinalFunction(); + String aggDef = aggregateCql(sFunc, fFunc); + grantExecuteOnFunction(sFunc); + grantExecuteOnFunction(fFunc); + + String aggregate = createAggregate(KEYSPACE, "int", aggDef); + grantExecuteOnFunction(aggregate); + + String cql = String.format("SELECT %s(v1) FROM %s", + aggregate, + KEYSPACE + "." + currentTable()); + getStatement(cql).checkAccess(clientState); + + // check that revoking EXECUTE permission on any one of the + // component functions means we lose the ability to execute it + revokeExecuteOnFunction(aggregate); + assertUnauthorized(cql, aggregate, "int"); + grantExecuteOnFunction(aggregate); + getStatement(cql).checkAccess(clientState); + + revokeExecuteOnFunction(sFunc); + assertUnauthorized(cql, sFunc, "int, int"); + grantExecuteOnFunction(sFunc); + getStatement(cql).checkAccess(clientState); + + revokeExecuteOnFunction(fFunc); + assertUnauthorized(cql, fFunc, "int"); + grantExecuteOnFunction(fFunc); + getStatement(cql).checkAccess(clientState); + } + + @Test + public void functionWrappingAggregate() throws Throwable + { + String outerFunc = createFunction("int", + "CREATE FUNCTION %s(input int) " + + "CALLED ON NULL INPUT " + + "RETURNS int " + + "LANGUAGE java " + + "AS 'return input;'"); + + String sFunc = createSimpleStateFunction(); + String fFunc = createSimpleFinalFunction(); + String aggDef = aggregateCql(sFunc, fFunc); + grantExecuteOnFunction(sFunc); + grantExecuteOnFunction(fFunc); + + String aggregate = createAggregate(KEYSPACE, "int", aggDef); + + String cql = String.format("SELECT %s(%s(v1)) FROM %s", + outerFunc, + aggregate, + KEYSPACE + "." + currentTable()); + + assertUnauthorized(cql, outerFunc, "int"); + grantExecuteOnFunction(outerFunc); + + assertUnauthorized(cql, aggregate, "int"); + grantExecuteOnFunction(aggregate); + + getStatement(cql).checkAccess(clientState); + } + + @Test + public void aggregateWrappingFunction() throws Throwable + { + String innerFunc = createFunction("int", + "CREATE FUNCTION %s(input int) " + + "CALLED ON NULL INPUT " + + "RETURNS int " + + "LANGUAGE java " + + "AS 'return input;'"); + + String sFunc = createSimpleStateFunction(); + String fFunc = createSimpleFinalFunction(); + String aggDef = aggregateCql(sFunc, fFunc); + grantExecuteOnFunction(sFunc); + grantExecuteOnFunction(fFunc); + + String aggregate = createAggregate(KEYSPACE, "int", aggDef); + + String cql = String.format("SELECT %s(%s(v1)) FROM %s", + aggregate, + innerFunc, + KEYSPACE + "." + currentTable()); + + assertUnauthorized(cql, aggregate, "int"); + grantExecuteOnFunction(aggregate); + + assertUnauthorized(cql, innerFunc, "int"); + grantExecuteOnFunction(innerFunc); + + getStatement(cql).checkAccess(clientState); + } + + private void assertPermissionsOnNestedFunctions(String innerFunction, String outerFunction) throws Throwable + { + String cql = String.format("SELECT k, %s FROM %s WHERE k=0", + functionCall(outerFunction, functionCall(innerFunction)), + KEYSPACE + "." + currentTable()); + // fail fast with an UAE on the first function + assertUnauthorized(cql, outerFunction, "int"); + grantExecuteOnFunction(outerFunction); + + // after granting execute on the first function, still fail due to the inner function + assertUnauthorized(cql, innerFunction, ""); + grantExecuteOnFunction(innerFunction); + + // now execution of both is permitted + getStatement(cql).checkAccess(clientState); + } + + private void assertPermissionsOnFunction(String cql, String functionName) throws Throwable + { + assertPermissionsOnFunction(cql, functionName, ""); + } + + private void assertPermissionsOnFunction(String cql, String functionName, String argTypes) throws Throwable + { + assertUnauthorized(cql, functionName, argTypes); + grantExecuteOnFunction(functionName); + getStatement(cql).checkAccess(clientState); + } + + private void assertUnauthorized(BatchStatement batch, Iterable<String> functionNames) throws Throwable + { + try + { + batch.checkAccess(clientState); + fail("Expected an UnauthorizedException, but none was thrown"); + } + catch (UnauthorizedException e) + { + String functions = String.format("(%s)", Joiner.on("|").join(functionNames)); + assertTrue(e.getLocalizedMessage() + .matches(String.format("User %s has no EXECUTE permission on <function %s\\(\\)> or any of its parents", + roleName, + functions))); + } + } + + private void assertUnauthorized(String cql, String functionName, String argTypes) throws Throwable + { + try + { + getStatement(cql).checkAccess(clientState); + fail("Expected an UnauthorizedException, but none was thrown"); + } + catch (UnauthorizedException e) + { + assertEquals(String.format("User %s has no EXECUTE permission on <function %s(%s)> or any of its parents", + roleName, + functionName, + argTypes), + e.getLocalizedMessage()); + } + } + + private void grantExecuteOnFunction(String functionName) + { + DatabaseDescriptor.getAuthorizer().grant(AuthenticatedUser.SYSTEM_USER, + ImmutableSet.of(Permission.EXECUTE), + functionResource(functionName), + role); + } + + private void revokeExecuteOnFunction(String functionName) + { + DatabaseDescriptor.getAuthorizer().revoke(AuthenticatedUser.SYSTEM_USER, + ImmutableSet.of(Permission.EXECUTE), + functionResource(functionName), + role); + } + + void setupClientState() + { + + try + { + role = RoleResource.role(roleName); + // use reflection to set the logged in user so that we don't need to + // bother setting up an IRoleManager + user = new AuthenticatedUser(roleName); + clientState = ClientState.forInternalCalls(); + Field userField = ClientState.class.getDeclaredField("user"); + userField.setAccessible(true); + userField.set(clientState, user); + } + catch (IllegalAccessException | NoSuchFieldException e) + { + throw new RuntimeException(e); + } + } + + private void setupTable(String tableDef) throws Throwable + { + createTable(tableDef); + // test user needs SELECT & MODIFY on the table regardless of permissions on any function + DatabaseDescriptor.getAuthorizer().grant(AuthenticatedUser.SYSTEM_USER, + ImmutableSet.of(Permission.SELECT, Permission.MODIFY), + DataResource.table(KEYSPACE, currentTable()), + RoleResource.role(user.getName())); + } + + private String aggregateCql(String sFunc, String fFunc) + { + return "CREATE AGGREGATE %s(int) " + + "SFUNC " + shortFunctionName(sFunc) + " " + + "STYPE int " + + "FINALFUNC " + shortFunctionName(fFunc) + " " + + "INITCOND 0"; + } + + private String createSimpleStateFunction() throws Throwable + { + return createFunction("int, int", + "CREATE FUNCTION %s(a int, b int) " + + "CALLED ON NULL INPUT " + + "RETURNS int " + + "LANGUAGE java " + + "AS 'return Integer.valueOf( (a != null ? a.intValue() : 0 ) + b.intValue());'"); + } + + private String createSimpleFinalFunction() throws Throwable + { + return createFunction("int", + "CREATE FUNCTION %s(a int) " + + "CALLED ON NULL INPUT " + + "RETURNS int " + + "LANGUAGE java " + + "AS 'return a;'"); + } + + private String createSimpleFunction() throws Throwable + { + return createFunction("", + "CREATE FUNCTION %s() " + + " CALLED ON NULL INPUT " + + " RETURNS int " + + " LANGUAGE java " + + " AS 'return Integer.valueOf(0);'"); + } + + private String createFunction(String argTypes, String functionDef) throws Throwable + { + return createFunction(KEYSPACE, argTypes, functionDef); + } + + private CQLStatement getStatement(String cql) + { + return QueryProcessor.getStatement(cql, clientState).statement; + } + + private FunctionResource functionResource(String functionName) + { + // Note that this is somewhat brittle as it assumes that function names are + // truly unique. As such, it will break in the face of overloading. + // It is here to avoid having to duplicate the functionality of CqlParser + // for transforming cql types into AbstractTypes + FunctionName fn = parseFunctionName(functionName); + List<Function> functions = Functions.find(fn); + assertEquals(String.format("Expected a single function definition for %s, but found %s", + functionName, + functions.size()), + 1, functions.size()); + return FunctionResource.function(fn.keyspace, fn.name, functions.get(0).argTypes()); + } + + private String functionCall(String functionName, String...args) + { + return String.format("%s(%s)", functionName, Joiner.on(",").join(args)); + } + + static class StubAuthorizer implements IAuthorizer + { + Map<Pair<String, IResource>, Set<Permission>> userPermissions = new HashMap<>(); + + private void clear() + { + userPermissions.clear(); + } + + public Set<Permission> authorize(AuthenticatedUser user, IResource resource) + { + Pair<String, IResource> key = Pair.create(user.getName(), resource); + Set<Permission> perms = userPermissions.get(key); + return perms != null ? perms : Collections.<Permission>emptySet(); + } + + public void grant(AuthenticatedUser performer, + Set<Permission> permissions, + IResource resource, + RoleResource grantee) throws RequestValidationException, RequestExecutionException + { + Pair<String, IResource> key = Pair.create(grantee.getRoleName(), resource); + Set<Permission> perms = userPermissions.get(key); + if (null == perms) + { + perms = new HashSet<>(); + userPermissions.put(key, perms); + } + perms.addAll(permissions); + } + + public void revoke(AuthenticatedUser performer, + Set<Permission> permissions, + IResource resource, + RoleResource revokee) throws RequestValidationException, RequestExecutionException + { + Pair<String, IResource> key = Pair.create(revokee.getRoleName(), resource); + Set<Permission> perms = userPermissions.get(key); + if (null != perms) + perms.removeAll(permissions); + if (perms.isEmpty()) + userPermissions.remove(key); + } + + public Set<PermissionDetails> list(AuthenticatedUser performer, + Set<Permission> permissions, + IResource resource, + RoleResource grantee) throws RequestValidationException, RequestExecutionException + { + Pair<String, IResource> key = Pair.create(grantee.getRoleName(), resource); + Set<Permission> perms = userPermissions.get(key); + if (perms == null) + return Collections.emptySet(); + + + Set<PermissionDetails> details = new HashSet<>(); + for (Permission permission : perms) + { + if (permissions.contains(permission)) + details.add(new PermissionDetails(grantee.getRoleName(), resource, permission)); + } + return details; + } + + public void revokeAllFrom(RoleResource revokee) + { + for (Pair<String, IResource> key : userPermissions.keySet()) + if (key.left.equals(revokee.getRoleName())) + userPermissions.remove(key); + } + + public void revokeAllOn(IResource droppedResource) + { + for (Pair<String, IResource> key : userPermissions.keySet()) + if (key.right.equals(droppedResource)) + userPermissions.remove(key); + + } + + public Set<? extends IResource> protectedResources() + { + return Collections.emptySet(); + } + + public void validateConfiguration() throws ConfigurationException + { + + } + + public void setup() + { + + } + } +}
http://git-wip-us.apache.org/repos/asf/cassandra/blob/01115f72/test/unit/org/apache/cassandra/cql3/validation/entities/UFIdentificationTest.java ---------------------------------------------------------------------- diff --git a/test/unit/org/apache/cassandra/cql3/validation/entities/UFIdentificationTest.java b/test/unit/org/apache/cassandra/cql3/validation/entities/UFIdentificationTest.java new file mode 100644 index 0000000..28b8afc --- /dev/null +++ b/test/unit/org/apache/cassandra/cql3/validation/entities/UFIdentificationTest.java @@ -0,0 +1,380 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.cassandra.cql3.validation.entities; + +import java.util.*; + +import com.google.common.base.Joiner; +import com.google.common.collect.Iterables; +import org.junit.Before; +import org.junit.Ignore; +import org.junit.Test; + +import org.apache.cassandra.cql3.Attributes; +import org.apache.cassandra.cql3.CQLStatement; +import org.apache.cassandra.cql3.QueryProcessor; +import org.apache.cassandra.cql3.functions.Function; +import org.apache.cassandra.cql3.statements.BatchStatement; +import org.apache.cassandra.cql3.statements.ModificationStatement; +import org.apache.cassandra.cql3.CQLTester; +import org.apache.cassandra.service.ClientState; + +import static org.junit.Assert.assertTrue; + +/** + * Checks the collection of Function objects returned by CQLStatement.getFunction + * matches expectations. This is intended to verify the various subcomponents of + * the statement (Operations, Terms, Restrictions, RestrictionSet, Selection, + * Selector, SelectorFactories etc) properly report any constituent functions. + * Some purely terminal functions are resolved at preparation, so those are not + * included in the reported list. They still need to be surveyed, to verify the + * calling client has the necessary permissions. UFAuthTest includes tests which + * verify this more thoroughly than we can here. + */ +public class UFIdentificationTest extends CQLTester +{ + private com.google.common.base.Function<Function, String> toFunctionNames = new com.google.common.base.Function<Function, String>() + { + public String apply(Function f) + { + return f.name().keyspace + "." + f.name().name; + } + }; + + String tFunc; + String iFunc; + String lFunc; + String sFunc; + String mFunc; + String uFunc; + String udtFunc; + + String userType; + + @Before + public void setup() throws Throwable + { + userType = KEYSPACE + "." + createType("CREATE TYPE %s (t text, i int)"); + + createTable("CREATE TABLE %s (" + + " key int, " + + " t_sc text STATIC," + + " i_cc int, " + + " t_cc text, " + + " i_val int," + + " l_val list<int>," + + " s_val set<int>," + + " m_val map<int, int>," + + " u_val timeuuid," + + " udt_val frozen<" + userType + ">," + + " PRIMARY KEY (key, i_cc, t_cc)" + + ")"); + + tFunc = createEchoFunction("text"); + iFunc = createEchoFunction("int"); + lFunc = createEchoFunction("list<int>"); + sFunc = createEchoFunction("set<int>"); + mFunc = createEchoFunction("map<int, int>"); + uFunc = createEchoFunction("timeuuid"); + udtFunc = createEchoFunction(userType); + } + + @Test + public void testSimpleModificationStatement() throws Throwable + { + assertFunctions(cql("INSERT INTO %s (key, t_sc) VALUES (0, %s)", functionCall(tFunc, "'foo'")), tFunc); + assertFunctions(cql("INSERT INTO %s (key, i_cc) VALUES (0, %s)", functionCall(iFunc, "1")), iFunc); + assertFunctions(cql("INSERT INTO %s (key, t_cc) VALUES (0, %s)", functionCall(tFunc, "'foo'")), tFunc); + assertFunctions(cql("INSERT INTO %s (key, i_val) VALUES (0, %s)", functionCall(iFunc, "1")), iFunc); + assertFunctions(cql("INSERT INTO %s (key, l_val) VALUES (0, %s)", functionCall(lFunc, "[1]")), lFunc); + assertFunctions(cql("INSERT INTO %s (key, s_val) VALUES (0, %s)", functionCall(sFunc, "{1}")), sFunc); + assertFunctions(cql("INSERT INTO %s (key, m_val) VALUES (0, %s)", functionCall(mFunc, "{1:1}")), mFunc); + assertFunctions(cql("INSERT INTO %s (key, udt_val) VALUES (0,%s)", functionCall(udtFunc, "{i : 1, t : 'foo'}")), udtFunc); + assertFunctions(cql("INSERT INTO %s (key, u_val) VALUES (0, %s)", functionCall(uFunc, "now()")), uFunc, "system.now"); + } + + @Test + public void testNonTerminalCollectionLiterals() throws Throwable + { + String iFunc2 = createEchoFunction("int"); + String mapValue = String.format("{%s:%s}", functionCall(iFunc, "1"), functionCall(iFunc2, "1")); + assertFunctions(cql("INSERT INTO %s (key, m_val) VALUES (0, %s)", mapValue), iFunc, iFunc2); + + String listValue = String.format("[%s]", functionCall(iFunc, "1")); + assertFunctions(cql("INSERT INTO %s (key, l_val) VALUES (0, %s)", listValue), iFunc); + + String setValue = String.format("{%s}", functionCall(iFunc, "1")); + assertFunctions(cql("INSERT INTO %s (key, s_val) VALUES (0, %s)", setValue), iFunc); + } + + @Test + public void testNonTerminalUDTLiterals() throws Throwable + { + String udtValue = String.format("{ i: %s, t : %s } ", functionCall(iFunc, "1"), functionCall(tFunc, "'foo'")); + assertFunctions(cql("INSERT INTO %s (key, udt_val) VALUES (0, %s)", udtValue), iFunc, tFunc); + } + + @Test + public void testModificationStatementWithConditions() throws Throwable + { + assertFunctions(cql("UPDATE %s SET i_val=0 WHERE key=0 IF t_sc=%s", functionCall(tFunc, "'foo'")), tFunc); + assertFunctions(cql("UPDATE %s SET i_val=0 WHERE key=0 IF i_val=%s", functionCall(iFunc, "1")), iFunc); + assertFunctions(cql("UPDATE %s SET i_val=0 WHERE key=0 IF l_val=%s", functionCall(lFunc, "[1]")), lFunc); + assertFunctions(cql("UPDATE %s SET i_val=0 WHERE key=0 IF s_val=%s", functionCall(sFunc, "{1}")), sFunc); + assertFunctions(cql("UPDATE %s SET i_val=0 WHERE key=0 IF m_val=%s", functionCall(mFunc, "{1:1}")), mFunc); + + + String iFunc2 = createEchoFunction("int"); + assertFunctions(cql("UPDATE %s SET i_val=0 WHERE key=0 IF i_val IN (%s, %S)", + functionCall(iFunc, "1"), + functionCall(iFunc2, "2")), + iFunc, iFunc2); + + assertFunctions(cql("UPDATE %s SET i_val=0 WHERE key=0 IF u_val=%s", + functionCall(uFunc, "now()")), + uFunc, "system.now"); + + // conditions on collection elements + assertFunctions(cql("UPDATE %s SET i_val=0 WHERE key=0 IF l_val[%s] = %s", + functionCall(iFunc, "1"), + functionCall(iFunc2, "1")), + iFunc, iFunc2); + assertFunctions(cql("UPDATE %s SET i_val=0 WHERE key=0 IF m_val[%s] = %s", + functionCall(iFunc, "1"), + functionCall(iFunc2, "1")), + iFunc, iFunc2); + } + + @Test @Ignore + // Technically, attributes like timestamp and ttl are Terms so could potentially + // resolve to function calls (& so you can call getFunctions on them) + // However, this is currently disallowed by CQL syntax + public void testModificationStatementWithAttributesFromFunction() throws Throwable + { + String longFunc = createEchoFunction("bigint"); + assertFunctions(cql("INSERT INTO %s (key, i_cc, t_cc, i_val) VALUES (0, 0, 'foo', 0) USING TIMESTAMP %s", + functionCall(longFunc, "9999")), + longFunc); + + assertFunctions(cql("INSERT INTO %s (key, i_cc, t_cc, i_val) VALUES (0, 0, 'foo', 0) USING TTL %s", + functionCall(iFunc, "8888")), + iFunc); + + assertFunctions(cql("INSERT INTO %s (key, i_cc, t_cc, i_val) VALUES (0, 0, 'foo', 0) USING TIMESTAMP %s AND TTL %s", + functionCall(longFunc, "9999"), functionCall(iFunc, "8888")), + longFunc, iFunc); + } + + @Test + public void testModificationStatementWithNestedFunctions() throws Throwable + { + String iFunc2 = createEchoFunction("int"); + String iFunc3 = createEchoFunction("int"); + String iFunc4 = createEchoFunction("int"); + String iFunc5 = createEchoFunction("int"); + String iFunc6 = createEchoFunction("int"); + String nestedFunctionCall = nestedFunctionCall(iFunc6, iFunc5, + nestedFunctionCall(iFunc4, iFunc3, + nestedFunctionCall(iFunc2, iFunc, "1"))); + + assertFunctions(cql("DELETE FROM %s WHERE key=%s", nestedFunctionCall), + iFunc, iFunc2, iFunc3, iFunc4, iFunc5, iFunc6); + } + + @Test + public void testSelectStatementSimpleRestrictions() throws Throwable + { + assertFunctions(cql("SELECT i_val FROM %s WHERE key=%s", functionCall(iFunc, "1")), iFunc); + assertFunctions(cql("SELECT i_val FROM %s WHERE key=0 AND t_sc=%s", functionCall(tFunc, "'foo'")), tFunc); + assertFunctions(cql("SELECT i_val FROM %s WHERE key=0 AND i_cc=%s AND t_cc='foo'", functionCall(iFunc, "1")), iFunc); + assertFunctions(cql("SELECT i_val FROM %s WHERE key=0 AND i_cc=0 AND t_cc=%s", functionCall(tFunc, "'foo'")), tFunc); + + String iFunc2 = createEchoFunction("int"); + String tFunc2 = createEchoFunction("text"); + assertFunctions(cql("SELECT i_val FROM %s WHERE key=%s AND t_sc=%s AND i_cc=%s AND t_cc=%s", + functionCall(iFunc, "1"), + functionCall(tFunc, "'foo'"), + functionCall(iFunc2, "1"), + functionCall(tFunc2, "'foo'")), + iFunc, tFunc, iFunc2, tFunc2); + } + + @Test + public void testSelectStatementRestrictionsWithNestedFunctions() throws Throwable + { + String iFunc2 = createEchoFunction("int"); + String iFunc3 = createEchoFunction("int"); + String iFunc4 = createEchoFunction("int"); + String iFunc5 = createEchoFunction("int"); + String iFunc6 = createEchoFunction("int"); + String nestedFunctionCall = nestedFunctionCall(iFunc6, iFunc5, + nestedFunctionCall(iFunc3, iFunc4, + nestedFunctionCall(iFunc, iFunc2, "1"))); + + assertFunctions(cql("SELECT i_val FROM %s WHERE key=%s", nestedFunctionCall), + iFunc, iFunc2, iFunc3, iFunc4, iFunc5, iFunc6); + } + + @Test + public void testNonTerminalTupleInSelectRestrictions() throws Throwable + { + assertFunctions(cql("SELECT i_val FROM %s WHERE key=0 AND (i_cc, t_cc) IN ((%s, %s))", + functionCall(iFunc, "1"), + functionCall(tFunc, "'foo'")), + iFunc, tFunc); + + assertFunctions(cql("SELECT i_val FROM %s WHERE key=0 AND (i_cc, t_cc) = (%s, %s)", + functionCall(iFunc, "1"), + functionCall(tFunc, "'foo'")), + iFunc, tFunc); + + assertFunctions(cql("SELECT i_val FROM %s WHERE key=0 AND (i_cc, t_cc) > (%s, %s)", + functionCall(iFunc, "1"), + functionCall(tFunc, "'foo'")), + iFunc, tFunc); + + assertFunctions(cql("SELECT i_val FROM %s WHERE key=0 AND (i_cc, t_cc) < (%s, %s)", + functionCall(iFunc, "1"), + functionCall(tFunc, "'foo'")), + iFunc, tFunc); + + assertFunctions(cql("SELECT i_val FROM %s WHERE key=0 AND (i_cc, t_cc) > (%s, %s) AND (i_cc, t_cc) < (%s, %s)", + functionCall(iFunc, "1"), + functionCall(tFunc, "'foo'"), + functionCall(iFunc, "1"), + functionCall(tFunc, "'foo'")), + iFunc, tFunc); + } + + @Test + public void testNestedFunctionInTokenRestriction() throws Throwable + { + String iFunc2 = createEchoFunction("int"); + assertFunctions(cql("SELECT i_val FROM %s WHERE token(key) = token(%s)", functionCall(iFunc, "1")), + "system.token", iFunc); + assertFunctions(cql("SELECT i_val FROM %s WHERE token(key) > token(%s)", functionCall(iFunc, "1")), + "system.token", iFunc); + assertFunctions(cql("SELECT i_val FROM %s WHERE token(key) < token(%s)", functionCall(iFunc, "1")), + "system.token", iFunc); + assertFunctions(cql("SELECT i_val FROM %s WHERE token(key) > token(%s) AND token(key) < token(%s)", + functionCall(iFunc, "1"), + functionCall(iFunc2, "1")), + "system.token", iFunc, iFunc2); + } + + @Test + public void testSelectStatementSimpleSelections() throws Throwable + { + String iFunc2 = createEchoFunction("int"); + execute("INSERT INTO %s (key, i_cc, t_cc, i_val) VALUES (0, 0, 'foo', 0)"); + assertFunctions(cql2("SELECT i_val, %s FROM %s WHERE key=0", functionCall(iFunc, "i_val")), iFunc); + assertFunctions(cql2("SELECT i_val, %s FROM %s WHERE key=0", nestedFunctionCall(iFunc, iFunc2, "i_val")), iFunc, iFunc2); + } + + @Test + public void testSelectStatementNestedSelections() throws Throwable + { + String iFunc2 = createEchoFunction("int"); + execute("INSERT INTO %s (key, i_cc, t_cc, i_val) VALUES (0, 0, 'foo', 0)"); + assertFunctions(cql2("SELECT i_val, %s FROM %s WHERE key=0", functionCall(iFunc, "i_val")), iFunc); + assertFunctions(cql2("SELECT i_val, %s FROM %s WHERE key=0", nestedFunctionCall(iFunc, iFunc2, "i_val")), iFunc, iFunc2); + } + + @Test + public void testBatchStatement() throws Throwable + { + String iFunc2 = createEchoFunction("int"); + List<ModificationStatement> statements = new ArrayList<>(); + statements.add(modificationStatement(cql("INSERT INTO %s (key, i_cc, t_cc) VALUES (%s, 0, 'foo')", + functionCall(iFunc, "0")))); + statements.add(modificationStatement(cql("INSERT INTO %s (key, i_cc, t_cc) VALUES (1, %s, 'foo')", + functionCall(iFunc2, "1")))); + statements.add(modificationStatement(cql("INSERT INTO %s (key, i_cc, t_cc) VALUES (2, 2, %s)", + functionCall(tFunc, "'foo'")))); + + BatchStatement batch = new BatchStatement(-1, BatchStatement.Type.LOGGED, statements, Attributes.none()); + assertFunctions(batch, iFunc, iFunc2, tFunc); + } + + @Test + public void testBatchStatementWithConditions() throws Throwable + { + List<ModificationStatement> statements = new ArrayList<>(); + statements.add(modificationStatement(cql("UPDATE %s SET i_val = %s WHERE key=0 AND i_cc=0 and t_cc='foo' IF l_val = %s", + functionCall(iFunc, "0"), functionCall(lFunc, "[1]")))); + statements.add(modificationStatement(cql("UPDATE %s SET i_val = %s WHERE key=0 AND i_cc=1 and t_cc='foo' IF s_val = %s", + functionCall(iFunc, "0"), functionCall(sFunc, "{1}")))); + + BatchStatement batch = new BatchStatement(-1, BatchStatement.Type.LOGGED, statements, Attributes.none()); + assertFunctions(batch, iFunc, lFunc, sFunc); + } + + private ModificationStatement modificationStatement(String cql) + { + return (ModificationStatement) QueryProcessor.getStatement(cql, ClientState.forInternalCalls()).statement; + } + + private void assertFunctions(String cql, String... function) + { + CQLStatement stmt = QueryProcessor.getStatement(cql, ClientState.forInternalCalls()).statement; + assertFunctions(stmt, function); + } + + private void assertFunctions(CQLStatement stmt, String... function) + { + Set<String> expected = com.google.common.collect.Sets.newHashSet(function); + Set<String> actual = com.google.common.collect.Sets.newHashSet(Iterables.transform(stmt.getFunctions(), + toFunctionNames)); + assertTrue(com.google.common.collect.Sets.symmetricDifference(expected, actual).isEmpty()); + } + + private String cql(String template, String... params) + { + String tableName = KEYSPACE + "." + currentTable(); + return String.format(template, com.google.common.collect.Lists.asList(tableName, params).toArray()); + } + + // Alternative query builder - appends the table name to the supplied params, + // for stmts of the form "SELECT x, %s FROM %s WHERE y=0" + private String cql2(String template, String... params) + { + Object[] args = Arrays.copyOf(params, params.length + 1); + args[params.length] = KEYSPACE + "." + currentTable(); + return String.format(template, args); + } + + private String functionCall(String fName, String... args) + { + return String.format("%s(%s)", fName, Joiner.on(",").join(args)); + } + + private String nestedFunctionCall(String outer, String inner, String innerArgs) + { + return functionCall(outer, functionCall(inner, innerArgs)); + } + + private String createEchoFunction(String type) throws Throwable + { + return createFunction(KEYSPACE, type, + "CREATE FUNCTION %s(input " + type + ")" + + " CALLED ON NULL INPUT" + + " RETURNS " + type + + " LANGUAGE java" + + " AS ' return input;'"); + } +}
