This is an automated email from the ASF dual-hosted git repository. samt pushed a commit to branch cassandra-3.0 in repository https://gitbox.apache.org/repos/asf/cassandra.git
The following commit(s) were added to refs/heads/cassandra-3.0 by this push: new 0388d89 Allow max protocol version to be capped 0388d89 is described below commit 0388d89e29393d0b1f50baa24848bc8cb0a7c9a3 Author: Sam Tunnicliffe <s...@beobal.com> AuthorDate: Tue Jul 9 12:51:16 2019 +0100 Allow max protocol version to be capped Patch by Sam Tunnicliffe; reviewed by Alex Petrov and Aleksey Yeschenko for CASSANDRA-15193 --- CHANGES.txt | 1 + NEWS.txt | 5 + bin/cqlsh.py | 47 +++--- pylib/cqlshlib/test/cassconnect.py | 8 +- pylib/cqlshlib/test/test_cqlsh_completion.py | 2 +- pylib/cqlshlib/test/test_cqlsh_output.py | 86 +++++------ src/java/org/apache/cassandra/config/Config.java | 2 +- .../cassandra/config/DatabaseDescriptor.java | 20 +++ .../org/apache/cassandra/db/SystemKeyspace.java | 43 ++++++ .../apache/cassandra/service/CassandraDaemon.java | 10 ++ .../cassandra/service/NativeTransportService.java | 19 +++ .../apache/cassandra/service/StorageService.java | 21 ++- .../cassandra/service/StorageServiceMBean.java | 3 + .../cassandra/transport/ConfiguredLimit.java | 117 +++++++++++++++ src/java/org/apache/cassandra/transport/Frame.java | 10 +- .../org/apache/cassandra/transport/Message.java | 9 +- .../cassandra/transport/ProtocolVersionLimit.java | 27 ++++ .../org/apache/cassandra/transport/Server.java | 22 ++- .../apache/cassandra/transport/SimpleClient.java | 4 +- test/unit/org/apache/cassandra/cql3/CQLTester.java | 49 +++++- .../cassandra/transport/DynamicLimitTest.java | 111 ++++++++++++++ .../cassandra/transport/ProtocolErrorTest.java | 8 +- .../transport/ProtocolNegotiationTest.java | 166 +++++++++++++++++++++ .../cassandra/transport/ProtocolTestHelper.java | 95 ++++++++++++ 24 files changed, 794 insertions(+), 91 deletions(-) diff --git a/CHANGES.txt b/CHANGES.txt index ca6ea2e..925a90a 100644 --- a/CHANGES.txt +++ b/CHANGES.txt @@ -1,4 +1,5 @@ 3.0.19 + * Add ability to cap max negotiable protocol version (CASSANDRA-15193) * Gossip tokens on startup if available (CASSANDRA-15335) * Fix resource leak in CompressedSequentialWriter (CASSANDRA-15340) * Fix merge which reverted CASSANDRA-14993 (CASSANDRA-15289) diff --git a/NEWS.txt b/NEWS.txt index 704fde1..c03284b 100644 --- a/NEWS.txt +++ b/NEWS.txt @@ -49,6 +49,11 @@ Upgrading --------- - repair_session_max_tree_depth setting has been added to cassandra.yaml to allow operators to reduce merkle tree size if repair is creating too much heap pressure. See CASSANDRA-14096 for details. + - native_transport_max_negotiable_protocol_version has been added to cassandra.yaml to allow operators to + enforce an upper limit on the version of the native protocol that servers will negotiate with clients. + This can be used during upgrades from 2.1 to 3.0 to prevent errors due to incompatible paging state formats + between the two versions. See CASSANDRA-15193 for details. + 3.0.18 ====== diff --git a/bin/cqlsh.py b/bin/cqlsh.py index 1f1fa47..08b026c 100644 --- a/bin/cqlsh.py +++ b/bin/cqlsh.py @@ -174,8 +174,6 @@ from cqlshlib.util import get_file_encoding_bomsize, trim_if_present DEFAULT_HOST = '127.0.0.1' DEFAULT_PORT = 9042 -DEFAULT_CQLVER = '3.4.0' -DEFAULT_PROTOCOL_VERSION = 4 DEFAULT_CONNECT_TIMEOUT_SECONDS = 5 DEFAULT_REQUEST_TIMEOUT_SECONDS = 10 @@ -216,9 +214,13 @@ parser.add_option('--debug', action='store_true', parser.add_option("--encoding", help="Specify a non-default encoding for output." + " (Default: %s)" % (UTF8,)) parser.add_option("--cqlshrc", help="Specify an alternative cqlshrc file location.") -parser.add_option('--cqlversion', default=DEFAULT_CQLVER, - help='Specify a particular CQL version (default: %default).' +parser.add_option('--cqlversion', default=None, + help='Specify a particular CQL version, ' + 'by default the highest version supported by the server will be used.' ' Examples: "3.0.3", "3.1.0"') +parser.add_option("--protocol-version", type="int", default=None, + help='Specify a specific protcol version otherwise the client will default and downgrade as necessary') + parser.add_option("-e", "--execute", help='Execute the statement and quit.') parser.add_option("--connect-timeout", default=DEFAULT_CONNECT_TIMEOUT_SECONDS, dest='connect_timeout', help='Specify the connection timeout in seconds (default: %default seconds).') @@ -704,7 +706,7 @@ class Shell(cmd.Cmd): def __init__(self, hostname, port, color=False, username=None, password=None, encoding=None, stdin=None, tty=True, completekey=DEFAULT_COMPLETEKEY, browser=None, use_conn=None, - cqlver=DEFAULT_CQLVER, keyspace=None, + cqlver=None, keyspace=None, tracing_enabled=False, expand_enabled=False, no_compact=False, display_nanotime_format=DEFAULT_NANOTIME_FORMAT, @@ -716,7 +718,7 @@ class Shell(cmd.Cmd): ssl=False, single_statement=None, request_timeout=DEFAULT_REQUEST_TIMEOUT_SECONDS, - protocol_version=DEFAULT_PROTOCOL_VERSION, + protocol_version=None, connect_timeout=DEFAULT_CONNECT_TIMEOUT_SECONDS): cmd.Cmd.__init__(self, completekey=completekey) self.hostname = hostname @@ -735,15 +737,19 @@ class Shell(cmd.Cmd): if use_conn: self.conn = use_conn else: - self.conn = Cluster(contact_points=(self.hostname,), port=self.port, cql_version=cqlver, - protocol_version=protocol_version, + kwargs = {} + if protocol_version is not None: + kwargs['protocol_version'] = protocol_version + if cqlver is not None: + kwargs['cql_version'] = cqlver + self.conn = Cluster(contact_points=(self.hostname,), port=self.port, auth_provider=self.auth_provider, no_compact=no_compact, ssl_options=sslhandling.ssl_settings(hostname, CONFIG_FILE) if ssl else None, load_balancing_policy=WhiteListRoundRobinPolicy([self.hostname]), control_connection_timeout=connect_timeout, - connect_timeout=connect_timeout) + connect_timeout=connect_timeout, + **kwargs) self.owns_connection = not use_conn - self.set_expanded_cql_version(cqlver) if keyspace: self.session = self.conn.connect(keyspace) @@ -767,6 +773,7 @@ class Shell(cmd.Cmd): self.session.row_factory = ordered_dict_factory self.session.default_consistency_level = cassandra.ConsistencyLevel.ONE self.get_connection_versions() + self.set_expanded_cql_version(self.connection_versions['cql']) self.current_keyspace = keyspace @@ -877,9 +884,9 @@ class Shell(cmd.Cmd): result, = self.session.execute("select * from system.local where key = 'local'") vers = { 'build': result['release_version'], - 'protocol': result['native_protocol_version'], 'cql': result['cql_version'], } + vers['protocol'] = self.conn.protocol_version self.connection_versions = vers def get_keyspace_names(self): @@ -1933,9 +1940,9 @@ class Shell(cmd.Cmd): direction = parsed.get_binding('dir').upper() if direction == 'FROM': - task = ImportTask(self, ks, table, columns, fname, opts, DEFAULT_PROTOCOL_VERSION, CONFIG_FILE) + task = ImportTask(self, ks, table, columns, fname, opts, self.conn.protocol_version, CONFIG_FILE) elif direction == 'TO': - task = ExportTask(self, ks, table, columns, fname, opts, DEFAULT_PROTOCOL_VERSION, CONFIG_FILE) + task = ExportTask(self, ks, table, columns, fname, opts, self.conn.protocol_version, CONFIG_FILE) else: raise SyntaxError("Unknown direction %s" % direction) @@ -2495,7 +2502,8 @@ def read_options(cmdlineargs, environment): optvalues.encoding = option_with_default(configs.get, 'ui', 'encoding', UTF8) optvalues.tty = option_with_default(configs.getboolean, 'ui', 'tty', sys.stdin.isatty()) - optvalues.cqlversion = option_with_default(configs.get, 'cql', 'version', DEFAULT_CQLVER) + optvalues.cqlversion = option_with_default(configs.get, 'cql', 'version', None) + optvalues.protocol_version = option_with_default(configs.getint, 'protocol', 'version', None) optvalues.connect_timeout = option_with_default(configs.getint, 'connection', 'timeout', DEFAULT_CONNECT_TIMEOUT_SECONDS) optvalues.request_timeout = option_with_default(configs.getint, 'connection', 'request_timeout', DEFAULT_REQUEST_TIMEOUT_SECONDS) optvalues.execute = None @@ -2539,11 +2547,11 @@ def read_options(cmdlineargs, environment): else: options.color = should_use_color() - options.cqlversion, cqlvertup = full_cql_version(options.cqlversion) - if cqlvertup[0] < 3: - parser.error('%r is not a supported CQL version.' % options.cqlversion) - else: - options.cqlmodule = cql3handling + if options.cqlversion is not None: + options.cqlversion, cqlvertup = full_cql_version(options.cqlversion) + if cqlvertup[0] < 3: + parser.error('%r is not a supported CQL version.' % options.cqlversion) + options.cqlmodule = cql3handling try: port = int(port) @@ -2647,6 +2655,7 @@ def main(options, hostname, port): tty=options.tty, completekey=options.completekey, browser=options.browser, + protocol_version=options.protocol_version, cqlver=options.cqlversion, keyspace=options.keyspace, no_compact=options.no_compact, diff --git a/pylib/cqlshlib/test/cassconnect.py b/pylib/cqlshlib/test/cassconnect.py index 71f7565..501850c 100644 --- a/pylib/cqlshlib/test/cassconnect.py +++ b/pylib/cqlshlib/test/cassconnect.py @@ -24,15 +24,13 @@ from .run_cqlsh import run_cqlsh, call_cqlsh test_keyspace_init = os.path.join(rundir, 'test_keyspace_init.cql') -def get_cassandra_connection(cql_version=cqlsh.DEFAULT_CQLVER): - if cql_version is None: - cql_version = cqlsh.DEFAULT_CQLVER +def get_cassandra_connection(cql_version=None): conn = cql((TEST_HOST,), TEST_PORT, cql_version=cql_version, load_balancing_policy=policy) # until the cql lib does this for us conn.cql_version = cql_version return conn -def get_cassandra_cursor(cql_version=cqlsh.DEFAULT_CQLVER): +def get_cassandra_cursor(cql_version=None): return get_cassandra_connection(cql_version=cql_version).cursor() TEST_KEYSPACES_CREATED = [] @@ -83,7 +81,7 @@ def remove_db(): c.execute('DROP KEYSPACE %s' % quote_name(TEST_KEYSPACES_CREATED.pop(-1))) @contextlib.contextmanager -def cassandra_connection(cql_version=cqlsh.DEFAULT_CQLVER): +def cassandra_connection(cql_version=None): """ Make a Cassandra CQL connection with the given CQL version and get a cursor for it, and optionally connect to a given keyspace. diff --git a/pylib/cqlshlib/test/test_cqlsh_completion.py b/pylib/cqlshlib/test/test_cqlsh_completion.py index e736ea7..75198b6 100644 --- a/pylib/cqlshlib/test/test_cqlsh_completion.py +++ b/pylib/cqlshlib/test/test_cqlsh_completion.py @@ -42,7 +42,7 @@ completion_separation_re = re.compile(r'\s+') class CqlshCompletionCase(BaseTestCase): def setUp(self): - self.cqlsh_runner = testrun_cqlsh(cqlver=cqlsh.DEFAULT_CQLVER, env={'COLUMNS': '100000'}) + self.cqlsh_runner = testrun_cqlsh(cqlver=None, env={'COLUMNS': '100000'}) self.cqlsh = self.cqlsh_runner.__enter__() def tearDown(self): diff --git a/pylib/cqlshlib/test/test_cqlsh_output.py b/pylib/cqlshlib/test/test_cqlsh_output.py index d905095..50849d4 100644 --- a/pylib/cqlshlib/test/test_cqlsh_output.py +++ b/pylib/cqlshlib/test/test_cqlsh_output.py @@ -67,13 +67,6 @@ class TestCqlshOutput(BaseTestCase): 'Actually got: %s\ncolor code: %s' % (tags, coloredtext.colored_version(), coloredtext.colortags())) - def assertCqlverQueriesGiveColoredOutput(self, queries_and_expected_outputs, - cqlver=(cqlsh.DEFAULT_CQLVER,), **kwargs): - if not isinstance(cqlver, (tuple, list)): - cqlver = (cqlver,) - for ver in cqlver: - self.assertQueriesGiveColoredOutput(queries_and_expected_outputs, cqlver=ver, **kwargs) - def assertQueriesGiveColoredOutput(self, queries_and_expected_outputs, **kwargs): """ Allow queries and expected output to be specified in structured tuples, @@ -133,7 +126,7 @@ class TestCqlshOutput(BaseTestCase): self.assertHasColors(c.read_to_next_prompt()) def test_count_output(self): - self.assertCqlverQueriesGiveColoredOutput(( + self.assertQueriesGiveColoredOutput(( ('select count(*) from has_all_types;', """ count MMMMM @@ -198,7 +191,7 @@ class TestCqlshOutput(BaseTestCase): (1 rows) nnnnnnnn """), - ), cqlver=cqlsh.DEFAULT_CQLVER) + )) q = 'select COUNT(*) FROM twenty_rows_composite_table limit 1000000;' self.assertQueriesGiveColoredOutput(( @@ -214,10 +207,10 @@ class TestCqlshOutput(BaseTestCase): (1 rows) nnnnnnnn """), - ), cqlver=cqlsh.DEFAULT_CQLVER) + )) def test_static_cf_output(self): - self.assertCqlverQueriesGiveColoredOutput(( + self.assertQueriesGiveColoredOutput(( ("select a, b from twenty_rows_table where a in ('1', '13', '2');", """ a | b RR MM @@ -234,7 +227,7 @@ class TestCqlshOutput(BaseTestCase): (3 rows) nnnnnnnn """), - ), cqlver=cqlsh.DEFAULT_CQLVER) + )) self.assertQueriesGiveColoredOutput(( ('select * from dynamic_columns;', """ @@ -257,11 +250,11 @@ class TestCqlshOutput(BaseTestCase): (5 rows) nnnnnnnn """), - ), cqlver=cqlsh.DEFAULT_CQLVER) + )) def test_empty_cf_output(self): # we print the header after CASSANDRA-6910 - self.assertCqlverQueriesGiveColoredOutput(( + self.assertQueriesGiveColoredOutput(( ('select * from empty_table;', """ lonelykey | lonelycol RRRRRRRRR MMMMMMMMM @@ -270,7 +263,7 @@ class TestCqlshOutput(BaseTestCase): (0 rows) """), - ), cqlver=cqlsh.DEFAULT_CQLVER) + )) q = 'select * from has_all_types where num = 999;' @@ -284,7 +277,7 @@ class TestCqlshOutput(BaseTestCase): (0 rows) """), - ), cqlver=cqlsh.DEFAULT_CQLVER) + )) def test_columnless_key_output(self): q = "select a from twenty_rows_table where a in ('1', '2', '-9192');" @@ -304,10 +297,10 @@ class TestCqlshOutput(BaseTestCase): (2 rows) nnnnnnnn """), - ), cqlver=cqlsh.DEFAULT_CQLVER) + )) def test_numeric_output(self): - self.assertCqlverQueriesGiveColoredOutput(( + self.assertQueriesGiveColoredOutput(( ('''select intcol, bigintcol, varintcol \ from has_all_types \ where num in (0, 1, 2, 3, 4);''', """ @@ -353,7 +346,7 @@ class TestCqlshOutput(BaseTestCase): (5 rows) nnnnnnnn """), - ), cqlver=cqlsh.DEFAULT_CQLVER) + )) def test_timestamp_output(self): self.assertQueriesGiveColoredOutput(( @@ -390,7 +383,7 @@ class TestCqlshOutput(BaseTestCase): pass def test_boolean_output(self): - self.assertCqlverQueriesGiveColoredOutput(( + self.assertQueriesGiveColoredOutput(( ('select num, booleancol from has_all_types where num in (0, 1, 2, 3);', """ num | booleancol RRR MMMMMMMMMM @@ -409,11 +402,11 @@ class TestCqlshOutput(BaseTestCase): (4 rows) nnnnnnnn """), - ), cqlver=cqlsh.DEFAULT_CQLVER) + )) def test_null_output(self): # column with metainfo but no values - self.assertCqlverQueriesGiveColoredOutput(( + self.assertQueriesGiveColoredOutput(( ("select k, c, notthere from undefined_values_table where k in ('k1', 'k2');", """ k | c | notthere R M MMMMMMMM @@ -428,7 +421,7 @@ class TestCqlshOutput(BaseTestCase): (2 rows) nnnnnnnn """), - ), cqlver=cqlsh.DEFAULT_CQLVER) + )) # all-columns, including a metainfo column has no values (cql3) self.assertQueriesGiveColoredOutput(( @@ -446,10 +439,10 @@ class TestCqlshOutput(BaseTestCase): (2 rows) nnnnnnnn """), - ), cqlver=cqlsh.DEFAULT_CQLVER) + )) def test_string_output_ascii(self): - self.assertCqlverQueriesGiveColoredOutput(( + self.assertQueriesGiveColoredOutput(( ("select * from ascii_with_special_chars where k in (0, 1, 2, 3);", r""" k | val R MMM @@ -468,7 +461,7 @@ class TestCqlshOutput(BaseTestCase): (4 rows) nnnnnnnn """), - ), cqlver=cqlsh.DEFAULT_CQLVER) + )) def test_string_output_utf8(self): # many of these won't line up visually here, to keep the source code @@ -477,7 +470,7 @@ class TestCqlshOutput(BaseTestCase): # terminals, but the color-checking machinery here will still treat # it as one character, so those won't seem to line up visually either. - self.assertCqlverQueriesGiveColoredOutput(( + self.assertQueriesGiveColoredOutput(( ("select * from utf8_with_special_chars where k in (0, 1, 2, 3, 4, 5, 6);", u""" k | val R MMM @@ -502,10 +495,10 @@ class TestCqlshOutput(BaseTestCase): (7 rows) nnnnnnnn """.encode('utf-8')), - ), cqlver=cqlsh.DEFAULT_CQLVER, env={'LANG': 'en_US.UTF-8'}) + ), env={'LANG': 'en_US.UTF-8'}) def test_blob_output(self): - self.assertCqlverQueriesGiveColoredOutput(( + self.assertQueriesGiveColoredOutput(( ("select num, blobcol from has_all_types where num in (0, 1, 2, 3);", r""" num | blobcol RRR MMMMMMM @@ -524,10 +517,10 @@ class TestCqlshOutput(BaseTestCase): (4 rows) nnnnnnnn """), - ), cqlver=cqlsh.DEFAULT_CQLVER) + )) def test_prompt(self): - with testrun_cqlsh(tty=True, keyspace=None, cqlver=cqlsh.DEFAULT_CQLVER) as c: + with testrun_cqlsh(tty=True, keyspace=None) as c: self.assertTrue(c.output_header.splitlines()[-1].endswith('cqlsh> ')) c.send('\n') @@ -559,8 +552,7 @@ class TestCqlshOutput(BaseTestCase): "RRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRRR") def test_describe_keyspace_output(self): - fullcqlver = cqlsh.DEFAULT_CQLVER - with testrun_cqlsh(tty=True, cqlver=fullcqlver) as c: + with testrun_cqlsh(tty=True) as c: ks = get_keyspace() qks = quote_name(ks) for cmd in ('describe keyspace', 'desc keyspace'): @@ -568,7 +560,7 @@ class TestCqlshOutput(BaseTestCase): for semicolon in ('', ';'): fullcmd = cmd + (' ' if givename else '') + givename + semicolon desc = c.cmd_and_response(fullcmd) - self.check_describe_keyspace_output(desc, givename or qks, fullcqlver) + self.check_describe_keyspace_output(desc, givename or qks) # try to actually execute that last keyspace description, with a # new keyspace name @@ -577,7 +569,7 @@ class TestCqlshOutput(BaseTestCase): statements = split_cql_commands(copy_desc) do_drop = True - with cassandra_cursor(cql_version=fullcqlver) as curs: + with cassandra_cursor() as curs: try: for stmt in statements: cqlshlog.debug('TEST EXEC: %s' % stmt) @@ -587,7 +579,7 @@ class TestCqlshOutput(BaseTestCase): if do_drop: curs.execute('drop keyspace %s' % quote_name(new_ks_name)) - def check_describe_keyspace_output(self, output, qksname, fullcqlver): + def check_describe_keyspace_output(self, output, qksname): expected_bits = [r'(?im)^CREATE KEYSPACE %s WITH\b' % re.escape(qksname), r';\s*$', r'\breplication = {\'class\':'] @@ -635,7 +627,7 @@ class TestCqlshOutput(BaseTestCase): """ % quote_name(get_keyspace())) - with testrun_cqlsh(tty=True, cqlver=cqlsh.DEFAULT_CQLVER) as c: + with testrun_cqlsh(tty=True) as c: for cmdword in ('describe table', 'desc columnfamily'): for semicolon in (';', ''): output = c.cmd_and_response('%s has_all_types%s' % (cmdword, semicolon)) @@ -653,7 +645,7 @@ class TestCqlshOutput(BaseTestCase): ks = get_keyspace() - with testrun_cqlsh(tty=True, keyspace=None, cqlver=cqlsh.DEFAULT_CQLVER) as c: + with testrun_cqlsh(tty=True, keyspace=None) as c: # when not in a keyspace for cmdword in ('DESCRIBE COLUMNFAMILIES', 'desc tables'): @@ -704,7 +696,7 @@ class TestCqlshOutput(BaseTestCase): \n ''' - with testrun_cqlsh(tty=True, keyspace=None, cqlver=cqlsh.DEFAULT_CQLVER) as c: + with testrun_cqlsh(tty=True, keyspace=None) as c: # not in a keyspace for semicolon in ('', ';'): @@ -792,7 +784,7 @@ class TestCqlshOutput(BaseTestCase): pass def test_user_types_output(self): - self.assertCqlverQueriesGiveColoredOutput(( + self.assertQueriesGiveColoredOutput(( ("select addresses from users;", r""" addresses MMMMMMMMM @@ -807,8 +799,8 @@ class TestCqlshOutput(BaseTestCase): (2 rows) nnnnnnnn """), - ), cqlver=cqlsh.DEFAULT_CQLVER) - self.assertCqlverQueriesGiveColoredOutput(( + )) + self.assertQueriesGiveColoredOutput(( ("select phone_numbers from users;", r""" phone_numbers MMMMMMMMMMMMM @@ -823,10 +815,10 @@ class TestCqlshOutput(BaseTestCase): (2 rows) nnnnnnnn """), - ), cqlver=cqlsh.DEFAULT_CQLVER) + )) def test_user_types_with_collections(self): - self.assertCqlverQueriesGiveColoredOutput(( + self.assertQueriesGiveColoredOutput(( ("select info from songs;", r""" info MMMM @@ -839,8 +831,8 @@ class TestCqlshOutput(BaseTestCase): (1 rows) nnnnnnnn """), - ), cqlver=cqlsh.DEFAULT_CQLVER) - self.assertCqlverQueriesGiveColoredOutput(( + )) + self.assertQueriesGiveColoredOutput(( ("select tags from songs;", r""" tags MMMM @@ -853,4 +845,4 @@ class TestCqlshOutput(BaseTestCase): (1 rows) nnnnnnnn """), - ), cqlver=cqlsh.DEFAULT_CQLVER) + )) diff --git a/src/java/org/apache/cassandra/config/Config.java b/src/java/org/apache/cassandra/config/Config.java index 830d3e1..bc3e3bf 100644 --- a/src/java/org/apache/cassandra/config/Config.java +++ b/src/java/org/apache/cassandra/config/Config.java @@ -156,7 +156,7 @@ public class Config public boolean native_transport_flush_in_batches_legacy = true; public volatile long native_transport_max_concurrent_requests_in_bytes_per_ip = -1L; public volatile long native_transport_max_concurrent_requests_in_bytes = -1L; - + public Integer native_transport_max_negotiable_protocol_version = Integer.MIN_VALUE; @Deprecated public Integer thrift_max_message_length_in_mb = 16; diff --git a/src/java/org/apache/cassandra/config/DatabaseDescriptor.java b/src/java/org/apache/cassandra/config/DatabaseDescriptor.java index 8417c39..a161a2a 100644 --- a/src/java/org/apache/cassandra/config/DatabaseDescriptor.java +++ b/src/java/org/apache/cassandra/config/DatabaseDescriptor.java @@ -52,6 +52,7 @@ import org.apache.cassandra.scheduler.IRequestScheduler; import org.apache.cassandra.scheduler.NoScheduler; import org.apache.cassandra.service.CacheService; import org.apache.cassandra.thrift.ThriftServer; +import org.apache.cassandra.transport.Server; import org.apache.cassandra.utils.FBUtilities; import org.apache.cassandra.utils.memory.*; @@ -764,6 +765,20 @@ public class DatabaseDescriptor throw new ConfigurationException("Encryption must be enabled in client_encryption_options for native_transport_port_ssl", false); } + // If max protocol version has been set, just validate it's within an acceptable range + if (conf.native_transport_max_negotiable_protocol_version != Integer.MIN_VALUE) + { + if (conf.native_transport_max_negotiable_protocol_version < Server.MIN_SUPPORTED_VERSION + || conf.native_transport_max_negotiable_protocol_version > Server.CURRENT_VERSION) + { + throw new ConfigurationException(String.format("Invalid setting for native_transport_max_negotiable_version (%d); " + + "Values between %s and %s are supported", + conf.native_transport_max_negotiable_protocol_version, + Server.MIN_SUPPORTED_VERSION, + Server.CURRENT_VERSION)); + } + } + if (conf.max_value_size_in_mb == null || conf.max_value_size_in_mb <= 0) throw new ConfigurationException("max_value_size_in_mb must be positive", false); else if (conf.max_value_size_in_mb >= 2048) @@ -1525,6 +1540,11 @@ public class DatabaseDescriptor return conf.native_transport_flush_in_batches_legacy; } + public static int getNativeProtocolMaxVersionOverride() + { + return conf.native_transport_max_negotiable_protocol_version; + } + public static double getCommitLogSyncBatchWindow() { return conf.commitlog_sync_batch_window_in_ms; diff --git a/src/java/org/apache/cassandra/db/SystemKeyspace.java b/src/java/org/apache/cassandra/db/SystemKeyspace.java index 541dd34..7c222dd 100644 --- a/src/java/org/apache/cassandra/db/SystemKeyspace.java +++ b/src/java/org/apache/cassandra/db/SystemKeyspace.java @@ -716,6 +716,18 @@ public final class SystemKeyspace return executorService.submit((Runnable) () -> executeInternal(String.format(req, PEERS, columnName), ep, value)); } + public static void updatePeerReleaseVersion(final InetAddress ep, final Object value, Runnable postUpdateTask, ExecutorService executorService) + { + if (ep.equals(FBUtilities.getBroadcastAddress())) + return; + + String req = "INSERT INTO system.%s (peer, %s) VALUES (?, ?)"; + executorService.execute(() -> { + executeInternal(String.format(req, PEERS, "release_version"), ep, value); + postUpdateTask.run(); + }); + } + public static synchronized void updateHintsDropped(InetAddress ep, UUID timePeriod, int value) { // with 30 day TTL @@ -812,6 +824,37 @@ public final class SystemKeyspace } /** + * Return a map of IP address to C* version. If an invalid version string, or no version + * at all is stored for a given peer IP, then NULL_VERSION will be reported for that peer + */ + public static Map<InetAddress, CassandraVersion> loadPeerVersions() + { + Map<InetAddress, CassandraVersion> releaseVersionMap = new HashMap<>(); + for (UntypedResultSet.Row row : executeInternal("SELECT peer, release_version FROM system." + PEERS)) + { + InetAddress peer = row.getInetAddress("peer"); + if (row.has("release_version")) + { + try + { + releaseVersionMap.put(peer, new CassandraVersion(row.getString("release_version"))); + } + catch (IllegalArgumentException e) + { + logger.info("Invalid version string found for {}", peer); + releaseVersionMap.put(peer, NULL_VERSION); + } + } + else + { + logger.info("No version string found for {}", peer); + releaseVersionMap.put(peer, NULL_VERSION); + } + } + return releaseVersionMap; + } + + /** * Get preferred IP for given endpoint if it is known. Otherwise this returns given endpoint itself. * * @param ep endpoint address to check diff --git a/src/java/org/apache/cassandra/service/CassandraDaemon.java b/src/java/org/apache/cassandra/service/CassandraDaemon.java index ad4a344..cc8b2ae 100644 --- a/src/java/org/apache/cassandra/service/CassandraDaemon.java +++ b/src/java/org/apache/cassandra/service/CassandraDaemon.java @@ -668,6 +668,16 @@ public class CassandraDaemon return nativeTransportService != null ? nativeTransportService.isRunning() : false; } + public int getMaxNativeProtocolVersion() + { + return nativeTransportService.getMaxProtocolVersion(); + } + + public void refreshMaxNativeProtocolVersion() + { + if (nativeTransportService != null) + nativeTransportService.refreshMaxNegotiableProtocolVersion(); + } /** * A convenience method to stop and destroy the daemon in one shot. diff --git a/src/java/org/apache/cassandra/service/NativeTransportService.java b/src/java/org/apache/cassandra/service/NativeTransportService.java index 2280818..587f781 100644 --- a/src/java/org/apache/cassandra/service/NativeTransportService.java +++ b/src/java/org/apache/cassandra/service/NativeTransportService.java @@ -33,6 +33,7 @@ import io.netty.channel.epoll.EpollEventLoopGroup; import io.netty.channel.nio.NioEventLoopGroup; import org.apache.cassandra.config.DatabaseDescriptor; import org.apache.cassandra.metrics.ClientMetrics; +import org.apache.cassandra.transport.ConfiguredLimit; import org.apache.cassandra.transport.Message; import org.apache.cassandra.transport.Server; @@ -48,6 +49,7 @@ public class NativeTransportService private boolean initialized = false; private EventLoopGroup workerGroup; + private ConfiguredLimit protocolVersionLimit; /** * Creates netty thread pools and event loops. @@ -69,12 +71,15 @@ public class NativeTransportService logger.info("Netty using Java NIO event loop"); } + protocolVersionLimit = ConfiguredLimit.newLimit(); + int nativePort = DatabaseDescriptor.getNativeTransportPort(); int nativePortSSL = DatabaseDescriptor.getNativeTransportPortSSL(); InetAddress nativeAddr = DatabaseDescriptor.getRpcAddress(); org.apache.cassandra.transport.Server.Builder builder = new org.apache.cassandra.transport.Server.Builder() .withEventLoopGroup(workerGroup) + .withProtocolVersionLimit(protocolVersionLimit) .withHost(nativeAddr); if (!DatabaseDescriptor.getClientEncryptionOptions().enabled) @@ -137,6 +142,20 @@ public class NativeTransportService Message.Dispatcher.shutdown(); } + public int getMaxProtocolVersion() + { + return protocolVersionLimit.getMaxVersion(); + } + + public void refreshMaxNegotiableProtocolVersion() + { + // lowering the max negotiable protocol version is only safe if we haven't already + // allowed clients to connect with a higher version. This still allows the max + // version to be raised, as that is safe. + if (initialized) + protocolVersionLimit.updateMaxSupportedVersion(); + } + /** * @return intend to use epoll bassed event looping */ diff --git a/src/java/org/apache/cassandra/service/StorageService.java b/src/java/org/apache/cassandra/service/StorageService.java index 2af7fb7..8c29601 100644 --- a/src/java/org/apache/cassandra/service/StorageService.java +++ b/src/java/org/apache/cassandra/service/StorageService.java @@ -442,6 +442,23 @@ public class StorageService extends NotificationBroadcasterSupport implements IE return daemon.isNativeTransportRunning(); } + public int getMaxNativeProtocolVersion() + { + if (daemon == null) + { + throw new IllegalStateException("No configured daemon"); + } + return daemon.getMaxNativeProtocolVersion(); + } + + private void refreshMaxNativeProtocolVersion() + { + if (daemon != null) + { + daemon.refreshMaxNativeProtocolVersion(); + } + } + public void stopTransports() { if (isInitialized()) @@ -1797,7 +1814,7 @@ public class StorageService extends NotificationBroadcasterSupport implements IE switch (state) { case RELEASE_VERSION: - SystemKeyspace.updatePeerInfo(endpoint, "release_version", value.value, executor); + SystemKeyspace.updatePeerReleaseVersion(endpoint, value.value, this::refreshMaxNativeProtocolVersion, executor); break; case DC: updateTopology(endpoint); @@ -1874,7 +1891,7 @@ public class StorageService extends NotificationBroadcasterSupport implements IE switch (entry.getKey()) { case RELEASE_VERSION: - SystemKeyspace.updatePeerInfo(endpoint, "release_version", entry.getValue().value, executor); + SystemKeyspace.updatePeerReleaseVersion(endpoint, entry.getValue().value, this::refreshMaxNativeProtocolVersion, executor); break; case DC: SystemKeyspace.updatePeerInfo(endpoint, "data_center", entry.getValue().value, executor); diff --git a/src/java/org/apache/cassandra/service/StorageServiceMBean.java b/src/java/org/apache/cassandra/service/StorageServiceMBean.java index ddd2da0..e22b094 100644 --- a/src/java/org/apache/cassandra/service/StorageServiceMBean.java +++ b/src/java/org/apache/cassandra/service/StorageServiceMBean.java @@ -619,4 +619,7 @@ public interface StorageServiceMBean extends NotificationEmitter * @return true if the node successfully starts resuming. (this does not mean bootstrap streaming was success.) */ public boolean resumeBootstrap(); + + /** Returns the max version that this node will negotiate for native protocol connections */ + public int getMaxNativeProtocolVersion(); } diff --git a/src/java/org/apache/cassandra/transport/ConfiguredLimit.java b/src/java/org/apache/cassandra/transport/ConfiguredLimit.java new file mode 100644 index 0000000..98518b8 --- /dev/null +++ b/src/java/org/apache/cassandra/transport/ConfiguredLimit.java @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.transport; + +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.db.SystemKeyspace; +import org.apache.cassandra.utils.CassandraVersion; + +public abstract class ConfiguredLimit implements ProtocolVersionLimit +{ + private static final Logger logger = LoggerFactory.getLogger(ConfiguredLimit.class); + static final String DISABLE_MAX_PROTOCOL_AUTO_OVERRIDE = "cassandra.disable_max_protocol_auto_override"; + static final CassandraVersion MIN_VERSION_FOR_V4 = new CassandraVersion("3.0.0"); + + public abstract int getMaxVersion(); + public abstract void updateMaxSupportedVersion(); + + public static ConfiguredLimit newLimit() + { + if (Boolean.getBoolean(DISABLE_MAX_PROTOCOL_AUTO_OVERRIDE)) + return new StaticLimit(Server.CURRENT_VERSION); + + int fromConfig = DatabaseDescriptor.getNativeProtocolMaxVersionOverride(); + return fromConfig != Integer.MIN_VALUE + ? new StaticLimit(fromConfig) + : new DynamicLimit(Server.CURRENT_VERSION); + } + + private static class StaticLimit extends ConfiguredLimit + { + private final int maxVersion; + private StaticLimit(int maxVersion) + { + if (maxVersion < Server.MIN_SUPPORTED_VERSION || maxVersion > Server.CURRENT_VERSION) + throw new IllegalArgumentException(String.format("Invalid max protocol version supplied (%s); " + + "Values between %s and %s are supported", + maxVersion, + Server.MIN_SUPPORTED_VERSION, + Server.CURRENT_VERSION)); + this.maxVersion = maxVersion; + logger.info("Native transport max negotiable version statically limited to {}", maxVersion); + } + + public int getMaxVersion() + { + return maxVersion; + } + + public void updateMaxSupportedVersion() + { + // statically configured, so this is a no-op + } + } + + private static class DynamicLimit extends ConfiguredLimit + { + private volatile int maxVersion; + private DynamicLimit(int initialLimit) + { + maxVersion = initialLimit; + maybeUpdateVersion(true); + } + + public int getMaxVersion() + { + return maxVersion; + } + + public void updateMaxSupportedVersion() + { + maybeUpdateVersion(false); + } + + private void maybeUpdateVersion(boolean allowLowering) + { + boolean enforceV3Cap = SystemKeyspace.loadPeerVersions() + .values() + .stream() + .anyMatch(v -> v.compareTo(MIN_VERSION_FOR_V4) < 0); + + if (!enforceV3Cap) + { + maxVersion = Server.CURRENT_VERSION; + return; + } + + if (maxVersion > Server.VERSION_3 && !allowLowering) + { + logger.info("Detected peers which do not fully support protocol V4, but V4 was previously negotiable. " + + "Not enforcing cap as this can cause issues for older client versions. After the next " + + "restart the server will apply the cap"); + return; + } + logger.info("Detected peers which do not fully support protocol V4. Capping max negotiable version to V3"); + maxVersion = Server.VERSION_3; + } + } +} diff --git a/src/java/org/apache/cassandra/transport/Frame.java b/src/java/org/apache/cassandra/transport/Frame.java index c28be9f..a07551f 100644 --- a/src/java/org/apache/cassandra/transport/Frame.java +++ b/src/java/org/apache/cassandra/transport/Frame.java @@ -145,10 +145,12 @@ public class Frame private int tooLongStreamId; private final Connection.Factory factory; + private final ProtocolVersionLimit versionCap; - public Decoder(Connection.Factory factory) + public Decoder(Connection.Factory factory, ProtocolVersionLimit versionCap) { this.factory = factory; + this.versionCap = versionCap; } @Override @@ -175,10 +177,10 @@ public class Frame int firstByte = buffer.getByte(idx++); Message.Direction direction = Message.Direction.extractFromVersion(firstByte); int version = firstByte & PROTOCOL_VERSION_MASK; - if (version < Server.MIN_SUPPORTED_VERSION || version > Server.CURRENT_VERSION) + if (version < Server.MIN_SUPPORTED_VERSION || version > versionCap.getMaxVersion()) throw new ProtocolException(String.format("Invalid or unsupported protocol version (%d); the lowest supported version is %d and the greatest is %d", - version, Server.MIN_SUPPORTED_VERSION, Server.CURRENT_VERSION), - version); + version, Server.MIN_SUPPORTED_VERSION, versionCap.getMaxVersion()), + version < Server.MIN_SUPPORTED_VERSION ? version : null); // Wait until we have the complete header if (readableBytes < Header.LENGTH) diff --git a/src/java/org/apache/cassandra/transport/Message.java b/src/java/org/apache/cassandra/transport/Message.java index 08a8600..5202578 100644 --- a/src/java/org/apache/cassandra/transport/Message.java +++ b/src/java/org/apache/cassandra/transport/Message.java @@ -322,11 +322,18 @@ public abstract class Message @ChannelHandler.Sharable public static class ProtocolEncoder extends MessageToMessageEncoder<Message> { + private final ProtocolVersionLimit versionCap; + + ProtocolEncoder(ProtocolVersionLimit versionCap) + { + this.versionCap = versionCap; + } + public void encode(ChannelHandlerContext ctx, Message message, List results) { Connection connection = ctx.channel().attr(Connection.attributeKey).get(); // The only case the connection can be null is when we send the initial STARTUP message (client side thus) - int version = connection == null ? Server.CURRENT_VERSION : connection.getVersion(); + int version = connection == null ? versionCap.getMaxVersion() : connection.getVersion(); EnumSet<Frame.Header.Flag> flags = EnumSet.noneOf(Frame.Header.Flag.class); diff --git a/src/java/org/apache/cassandra/transport/ProtocolVersionLimit.java b/src/java/org/apache/cassandra/transport/ProtocolVersionLimit.java new file mode 100644 index 0000000..c476efb --- /dev/null +++ b/src/java/org/apache/cassandra/transport/ProtocolVersionLimit.java @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.transport; + +@FunctionalInterface +public interface ProtocolVersionLimit +{ + public int getMaxVersion(); + + public static final ProtocolVersionLimit SERVER_DEFAULT = () -> Server.CURRENT_VERSION; +} diff --git a/src/java/org/apache/cassandra/transport/Server.java b/src/java/org/apache/cassandra/transport/Server.java index 83a676c..012b326 100644 --- a/src/java/org/apache/cassandra/transport/Server.java +++ b/src/java/org/apache/cassandra/transport/Server.java @@ -87,11 +87,14 @@ public class Server implements CassandraDaemon.Server private final AtomicBoolean isRunning = new AtomicBoolean(false); private EventLoopGroup workerGroup; + private final ProtocolVersionLimit protocolVersionLimit; private Server (Builder builder) { this.socket = builder.getSocket(); this.useSSL = builder.useSSL; + this.protocolVersionLimit = builder.getProtocolVersionLimit(); + if (builder.workerGroup != null) { workerGroup = builder.workerGroup; @@ -188,6 +191,7 @@ public class Server implements CassandraDaemon.Server private InetAddress hostAddr; private int port = -1; private InetSocketAddress socket; + private ProtocolVersionLimit versionLimit; public Builder withSSL(boolean useSSL) { @@ -215,6 +219,19 @@ public class Server implements CassandraDaemon.Server return this; } + public Builder withProtocolVersionLimit(ProtocolVersionLimit limit) + { + this.versionLimit = limit; + return this; + } + + ProtocolVersionLimit getProtocolVersionLimit() + { + if (versionLimit == null) + throw new IllegalArgumentException("Missing protocol version limiter"); + return versionLimit; + } + public Server build() { return new Server(this); @@ -327,7 +344,6 @@ public class Server implements CassandraDaemon.Server { // Stateless handlers private static final Message.ProtocolDecoder messageDecoder = new Message.ProtocolDecoder(); - private static final Message.ProtocolEncoder messageEncoder = new Message.ProtocolEncoder(); private static final Frame.Decompressor frameDecompressor = new Frame.Decompressor(); private static final Frame.Compressor frameCompressor = new Frame.Compressor(); private static final Frame.Encoder frameEncoder = new Frame.Encoder(); @@ -355,14 +371,14 @@ public class Server implements CassandraDaemon.Server //pipeline.addLast("debug", new LoggingHandler()); - pipeline.addLast("frameDecoder", new Frame.Decoder(server.connectionFactory)); + pipeline.addLast("frameDecoder", new Frame.Decoder(server.connectionFactory, server.protocolVersionLimit)); pipeline.addLast("frameEncoder", frameEncoder); pipeline.addLast("frameDecompressor", frameDecompressor); pipeline.addLast("frameCompressor", frameCompressor); pipeline.addLast("messageDecoder", messageDecoder); - pipeline.addLast("messageEncoder", messageEncoder); + pipeline.addLast("messageEncoder", new Message.ProtocolEncoder(server.protocolVersionLimit)); pipeline.addLast("executor", new Message.Dispatcher(DatabaseDescriptor.useNativeTransportLegacyFlusher(), EndpointPayloadTracker.get(((InetSocketAddress) channel.remoteAddress()).getAddress()))); diff --git a/src/java/org/apache/cassandra/transport/SimpleClient.java b/src/java/org/apache/cassandra/transport/SimpleClient.java index 7916deb..7d34d98 100644 --- a/src/java/org/apache/cassandra/transport/SimpleClient.java +++ b/src/java/org/apache/cassandra/transport/SimpleClient.java @@ -251,7 +251,7 @@ public class SimpleClient implements Closeable // Stateless handlers private static final Message.ProtocolDecoder messageDecoder = new Message.ProtocolDecoder(); - private static final Message.ProtocolEncoder messageEncoder = new Message.ProtocolEncoder(); + private static final Message.ProtocolEncoder messageEncoder = new Message.ProtocolEncoder(ProtocolVersionLimit.SERVER_DEFAULT); private static final Frame.Decompressor frameDecompressor = new Frame.Decompressor(); private static final Frame.Compressor frameCompressor = new Frame.Compressor(); private static final Frame.Encoder frameEncoder = new Frame.Encoder(); @@ -274,7 +274,7 @@ public class SimpleClient implements Closeable channel.attr(Connection.attributeKey).set(connection); ChannelPipeline pipeline = channel.pipeline(); - pipeline.addLast("frameDecoder", new Frame.Decoder(connectionFactory)); + pipeline.addLast("frameDecoder", new Frame.Decoder(connectionFactory, ProtocolVersionLimit.SERVER_DEFAULT)); pipeline.addLast("frameEncoder", frameEncoder); pipeline.addLast("frameDecompressor", frameDecompressor); diff --git a/test/unit/org/apache/cassandra/cql3/CQLTester.java b/test/unit/org/apache/cassandra/cql3/CQLTester.java index 999404e..95366c2 100644 --- a/test/unit/org/apache/cassandra/cql3/CQLTester.java +++ b/test/unit/org/apache/cassandra/cql3/CQLTester.java @@ -62,6 +62,7 @@ import org.apache.cassandra.serializers.TypeSerializer; import org.apache.cassandra.service.ClientState; import org.apache.cassandra.service.QueryState; import org.apache.cassandra.service.StorageService; +import org.apache.cassandra.transport.ConfiguredLimit; import org.apache.cassandra.transport.Event; import org.apache.cassandra.transport.Server; import org.apache.cassandra.transport.messages.ResultMessage; @@ -88,6 +89,7 @@ public abstract class CQLTester private static org.apache.cassandra.transport.Server server; protected static final int nativePort; protected static final InetAddress nativeAddr; + protected static ConfiguredLimit protocolVersionLimit; private static final Map<Integer, Cluster> clusters = new HashMap<>(); private static final Map<Integer, Session> sessions = new HashMap<>(); @@ -330,11 +332,43 @@ public abstract class CQLTester if (server != null) return; + prepareNetwork(); + initializeNetwork(); + } + + protected static void prepareNetwork() + { SystemKeyspace.finishStartup(); StorageService.instance.initServer(); SchemaLoader.startGossiper(); + } + + protected static void reinitializeNetwork() + { + if (server != null && server.isRunning()) + { + server.stop(); + server = null; + } + List<CloseFuture> futures = new ArrayList<>(); + for (Cluster cluster : clusters.values()) + futures.add(cluster.closeAsync()); + for (Session session : sessions.values()) + futures.add(session.closeAsync()); + FBUtilities.waitOnFutures(futures); + clusters.clear(); + sessions.clear(); + + initializeNetwork(); + } - server = new Server.Builder().withHost(nativeAddr).withPort(nativePort).build(); + private static void initializeNetwork() + { + protocolVersionLimit = ConfiguredLimit.newLimit(); + server = new Server.Builder().withHost(nativeAddr) + .withPort(nativePort) + .withProtocolVersionLimit(protocolVersionLimit) + .build(); ClientMetrics.instance.init(Collections.singleton(server)); server.start(); @@ -343,9 +377,12 @@ public abstract class CQLTester if (clusters.containsKey(version)) continue; + if (version > protocolVersionLimit.getMaxVersion()) + continue; + Cluster cluster = Cluster.builder() .addContactPoints(nativeAddr) - .withClusterName("Test Cluster") + .withClusterName("Test Cluster-v" + version) .withPort(nativePort) .withProtocolVersion(ProtocolVersion.fromInt(version)) .build(); @@ -356,6 +393,14 @@ public abstract class CQLTester } } + protected void updateMaxNegotiableProtocolVersion() + { + if (protocolVersionLimit == null) + throw new IllegalStateException("Native transport server has not been initialized"); + + protocolVersionLimit.updateMaxSupportedVersion(); + } + protected void dropPerTestKeyspace() throws Throwable { execute(String.format("DROP KEYSPACE IF EXISTS %s", KEYSPACE_PER_TEST)); diff --git a/test/unit/org/apache/cassandra/transport/DynamicLimitTest.java b/test/unit/org/apache/cassandra/transport/DynamicLimitTest.java new file mode 100644 index 0000000..83a0dd9 --- /dev/null +++ b/test/unit/org/apache/cassandra/transport/DynamicLimitTest.java @@ -0,0 +1,111 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.transport; + +import java.net.InetAddress; + +import org.junit.BeforeClass; +import org.junit.Test; + +import org.apache.cassandra.cql3.CQLTester; + +import static org.apache.cassandra.transport.ProtocolTestHelper.cleanupPeers; +import static org.apache.cassandra.transport.ProtocolTestHelper.setStaticLimitInConfig; +import static org.apache.cassandra.transport.ProtocolTestHelper.setupPeer; +import static org.apache.cassandra.transport.ProtocolTestHelper.updatePeerInfo; +import static org.junit.Assert.assertEquals; + +public class DynamicLimitTest +{ + @BeforeClass + public static void setup() + { + CQLTester.prepareServer(); + } + + @Test + public void disableDynamicLimitWithSystemProperty() throws Throwable + { + // Dynamic limiting of the max negotiable protocol version can be + // disabled with a system property + + // ensure that no static limit is configured + setStaticLimitInConfig(null); + + // set the property which disables dynamic limiting + System.setProperty(ConfiguredLimit.DISABLE_MAX_PROTOCOL_AUTO_OVERRIDE, "true"); + // insert a legacy peer into system.peers and also + InetAddress peer = null; + try + { + peer = setupPeer("127.1.0.1", "2.2.0"); + ConfiguredLimit limit = ConfiguredLimit.newLimit(); + assertEquals(Server.CURRENT_VERSION, limit.getMaxVersion()); + + // clearing the property after the limit has been returned has no effect + System.clearProperty(ConfiguredLimit.DISABLE_MAX_PROTOCOL_AUTO_OVERRIDE); + limit.updateMaxSupportedVersion(); + assertEquals(Server.CURRENT_VERSION, limit.getMaxVersion()); + + // a new limit should now be dynamic + limit = ConfiguredLimit.newLimit(); + assertEquals(Server.VERSION_3, limit.getMaxVersion()); + } + finally + { + System.clearProperty(ConfiguredLimit.DISABLE_MAX_PROTOCOL_AUTO_OVERRIDE); + cleanupPeers(peer); + } + } + + @Test + public void disallowLoweringMaxVersion() throws Throwable + { + // Lowering the max version once connections have been established is a problem + // for some clients. So for a dynamic limit, if notifications of peer versions + // trigger a change to the max version, it's only allowed to increase the max + // negotiable version + + InetAddress peer = null; + try + { + // ensure that no static limit is configured + setStaticLimitInConfig(null); + ConfiguredLimit limit = ConfiguredLimit.newLimit(); + assertEquals(Server.CURRENT_VERSION, limit.getMaxVersion()); + + peer = setupPeer("127.1.0.1", "3.0.0"); + limit.updateMaxSupportedVersion(); + assertEquals(Server.CURRENT_VERSION, limit.getMaxVersion()); + + // learn that peer doesn't actually fully support V4, behaviour should remain the same + updatePeerInfo(peer, "2.2.0"); + limit.updateMaxSupportedVersion(); + assertEquals(Server.CURRENT_VERSION, limit.getMaxVersion()); + + // finally learn that peer2 has been upgraded, just for completeness + updatePeerInfo(peer, "3.3.0"); + limit.updateMaxSupportedVersion(); + assertEquals(Server.CURRENT_VERSION, limit.getMaxVersion()); + + } finally { + cleanupPeers(peer); + } + } +} diff --git a/test/unit/org/apache/cassandra/transport/ProtocolErrorTest.java b/test/unit/org/apache/cassandra/transport/ProtocolErrorTest.java index 599087c..e212c4c 100644 --- a/test/unit/org/apache/cassandra/transport/ProtocolErrorTest.java +++ b/test/unit/org/apache/cassandra/transport/ProtocolErrorTest.java @@ -43,7 +43,7 @@ public class ProtocolErrorTest { public void testInvalidProtocolVersion(int version) throws Exception { - Frame.Decoder dec = new Frame.Decoder(null); + Frame.Decoder dec = new Frame.Decoder(null, ProtocolVersionLimit.SERVER_DEFAULT); List<Object> results = new ArrayList<>(); byte[] frame = new byte[] { @@ -71,7 +71,7 @@ public class ProtocolErrorTest { public void testInvalidProtocolVersionShortFrame() throws Exception { // test for CASSANDRA-11464 - Frame.Decoder dec = new Frame.Decoder(null); + Frame.Decoder dec = new Frame.Decoder(null, ProtocolVersionLimit.SERVER_DEFAULT); List<Object> results = new ArrayList<>(); byte[] frame = new byte[] { @@ -93,7 +93,7 @@ public class ProtocolErrorTest { @Test public void testInvalidDirection() throws Exception { - Frame.Decoder dec = new Frame.Decoder(null); + Frame.Decoder dec = new Frame.Decoder(null, ProtocolVersionLimit.SERVER_DEFAULT); List<Object> results = new ArrayList<>(); // should generate a protocol exception for using a response frame with @@ -124,7 +124,7 @@ public class ProtocolErrorTest { @Test public void testBodyLengthOverLimit() throws Exception { - Frame.Decoder dec = new Frame.Decoder(null); + Frame.Decoder dec = new Frame.Decoder(null, ProtocolVersionLimit.SERVER_DEFAULT); List<Object> results = new ArrayList<>(); byte[] frame = new byte[] { diff --git a/test/unit/org/apache/cassandra/transport/ProtocolNegotiationTest.java b/test/unit/org/apache/cassandra/transport/ProtocolNegotiationTest.java new file mode 100644 index 0000000..91c1d6a --- /dev/null +++ b/test/unit/org/apache/cassandra/transport/ProtocolNegotiationTest.java @@ -0,0 +1,166 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.transport; + +import java.net.InetAddress; +import java.util.concurrent.TimeUnit; + +import org.junit.Before; +import org.junit.BeforeClass; +import org.junit.Test; + +import com.datastax.driver.core.Cluster; +import com.datastax.driver.core.ProtocolVersion; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.cql3.CQLTester; + +import static org.apache.cassandra.transport.ProtocolTestHelper.cleanupPeers; +import static org.apache.cassandra.transport.ProtocolTestHelper.setStaticLimitInConfig; +import static org.apache.cassandra.transport.ProtocolTestHelper.setupPeer; +import static org.apache.cassandra.transport.ProtocolTestHelper.updatePeerInfo; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class ProtocolNegotiationTest extends CQLTester +{ + // to avoid JMX naming clashes between cluster metrics + private int clusterId = 0; + + @BeforeClass + public static void setup() + { + prepareNetwork(); + } + + @Before + public void clearConfig() + { + setStaticLimitInConfig(null); + } + + @Test + public void serverSupportsV3AndV4ByDefault() throws Throwable + { + reinitializeNetwork(); + // client can explicitly request either V3 or V4 + testConnection(ProtocolVersion.V3, ProtocolVersion.V3); + testConnection(ProtocolVersion.V4, ProtocolVersion.V4); + + // if not specified, V4 is the default + testConnection(null, ProtocolVersion.V4); + } + + @Test + public void testStaticLimit() throws Throwable + { + try + { + reinitializeNetwork(); + // No limit enforced to start + assertEquals(Integer.MIN_VALUE, DatabaseDescriptor.getNativeProtocolMaxVersionOverride()); + testConnection(null, ProtocolVersion.V4); + + // Update DatabaseDescriptor, then re-initialise the server to force it to read it + setStaticLimitInConfig(ProtocolVersion.V3.toInt()); + reinitializeNetwork(); + assertEquals(3, DatabaseDescriptor.getNativeProtocolMaxVersionOverride()); + testConnection(ProtocolVersion.V4, ProtocolVersion.V3); + testConnection(ProtocolVersion.V3, ProtocolVersion.V3); + testConnection(null, ProtocolVersion.V3); + } finally { + setStaticLimitInConfig(null); + } + } + + @Test + public void testDynamicLimit() throws Throwable + { + InetAddress peer1 = setupPeer("127.1.0.1", "2.2.0"); + InetAddress peer2 = setupPeer("127.1.0.2", "2.2.0"); + InetAddress peer3 = setupPeer("127.1.0.3", "2.2.0"); + reinitializeNetwork(); + try + { + // legacy peers means max negotiable version is V3 + testConnection(ProtocolVersion.V4, ProtocolVersion.V3); + testConnection(ProtocolVersion.V3, ProtocolVersion.V3); + testConnection(null, ProtocolVersion.V3); + + // receive notification that 2 peers have upgraded to a version that fully supports V4 + updatePeerInfo(peer1, "3.0.0"); + updatePeerInfo(peer2, "3.0.0"); + updateMaxNegotiableProtocolVersion(); + // version should still be capped + testConnection(ProtocolVersion.V4, ProtocolVersion.V3); + testConnection(ProtocolVersion.V3, ProtocolVersion.V3); + testConnection(null, ProtocolVersion.V3); + + // no legacy peers so V4 is negotiable + // after the last peer upgrades, cap should be lifted + updatePeerInfo(peer3, "3.0.0"); + updateMaxNegotiableProtocolVersion(); + testConnection(ProtocolVersion.V4, ProtocolVersion.V4); + testConnection(ProtocolVersion.V3, ProtocolVersion.V3); + testConnection(null, ProtocolVersion.V4); + } finally { + cleanupPeers(peer1, peer2, peer3); + } + } + + private void testConnection(com.datastax.driver.core.ProtocolVersion requestedVersion, + com.datastax.driver.core.ProtocolVersion expectedVersion) + { + long start = System.nanoTime(); + boolean expectError = requestedVersion != null && requestedVersion != expectedVersion; + Cluster.Builder builder = Cluster.builder() + .addContactPoints(nativeAddr) + .withClusterName("Test Cluster" + clusterId++) + .withPort(nativePort); + + if (requestedVersion != null) + builder = builder.withProtocolVersion(requestedVersion) ; + + Cluster cluster = builder.build(); + logger.info("Setting up cluster took {}ms", TimeUnit.MILLISECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS)); + start = System.nanoTime(); + try { + cluster.connect(); + if (expectError) + fail("Expected a protocol exception"); + } + catch (Exception e) + { + if (!expectError) + { + e.printStackTrace(); + fail("Did not expect any exception"); + } + + assertTrue(e.getMessage().contains(String.format("Host does not support protocol version %s but %s", requestedVersion, expectedVersion))); + } finally { + logger.info("Testing connection took {}ms", TimeUnit.MILLISECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS)); + start = System.nanoTime(); + cluster.closeAsync(); + logger.info("Tearing down cluster connection took {}ms", TimeUnit.MILLISECONDS.convert(System.nanoTime() - start, TimeUnit.NANOSECONDS)); + + } + } + +} diff --git a/test/unit/org/apache/cassandra/transport/ProtocolTestHelper.java b/test/unit/org/apache/cassandra/transport/ProtocolTestHelper.java new file mode 100644 index 0000000..90a2801 --- /dev/null +++ b/test/unit/org/apache/cassandra/transport/ProtocolTestHelper.java @@ -0,0 +1,95 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.cassandra.transport; + +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.net.InetAddress; +import java.net.UnknownHostException; +import java.util.concurrent.ExecutorService; + +import com.google.common.util.concurrent.MoreExecutors; + +import org.apache.cassandra.config.Config; +import org.apache.cassandra.config.DatabaseDescriptor; +import org.apache.cassandra.db.SystemKeyspace; +import org.apache.cassandra.gms.VersionedValue; +import org.apache.cassandra.utils.FBUtilities; + +public class ProtocolTestHelper +{ + static ExecutorService executor = MoreExecutors.newDirectExecutorService(); + static InetAddress setupPeer(String address, String version) throws Throwable + { + InetAddress peer = peer(address); + updatePeerInfo(peer, version); + return peer; + } + + static void updatePeerInfo(InetAddress peer, String version) throws Throwable + { + SystemKeyspace.updatePeerInfo(peer, "release_version", version, executor); + } + + static InetAddress peer(String address) + { + try + { + return InetAddress.getByName(address); + } + catch (UnknownHostException e) + { + throw new RuntimeException("Error creating peer", e); + } + } + + static void cleanupPeers(InetAddress...peers) throws Throwable + { + for (InetAddress peer : peers) + if (peer != null) + SystemKeyspace.removeEndpoint(peer); + } + + static void setStaticLimitInConfig(Integer version) + { + try + { + Field field = FBUtilities.getProtectedField(DatabaseDescriptor.class, "conf"); + ((Config)field.get(null)).native_transport_max_negotiable_protocol_version = version == null ? Integer.MIN_VALUE : version; + } + catch (IllegalAccessException e) + { + throw new RuntimeException("Error setting native_transport_max_protocol_version on Config", e); + } + } + + static VersionedValue releaseVersion(String versionString) + { + try + { + Constructor<VersionedValue> ctor = VersionedValue.class.getDeclaredConstructor(String.class); + ctor.setAccessible(true); + return ctor.newInstance(versionString); + } + catch (Exception e) + { + throw new RuntimeException("Error constructing VersionedValue for release version", e); + } + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@cassandra.apache.org For additional commands, e-mail: commits-h...@cassandra.apache.org