diff --git a/cassandra/cluster.py b/cassandra/cluster.py
index 14c3074e85..cab311bf3d 100644
--- a/cassandra/cluster.py
+++ b/cassandra/cluster.py
@@ -3641,6 +3641,7 @@ def _set_new_connection(self, conn):
         with self._lock:
             old = self._connection
             self._connection = conn
+            self.refresh_schema()
 
         if old:
             log.debug("[control connection] Closing old connection %r, replacing with %r", old, conn)
diff --git a/cassandra/metadata.py b/cassandra/metadata.py
index 18d4249780..d4a13cd249 100644
--- a/cassandra/metadata.py
+++ b/cassandra/metadata.py
@@ -168,10 +168,14 @@ def _rebuild_all(self, parser):
         current_keyspaces = set()
         for keyspace_meta in parser.get_all_keyspaces():
             current_keyspaces.add(keyspace_meta.name)
-            old_keyspace_meta = self.keyspaces.get(keyspace_meta.name, None)
+            old_keyspace_meta: KeyspaceMetadata = self.keyspaces.get(keyspace_meta.name, None)
             self.keyspaces[keyspace_meta.name] = keyspace_meta
             if old_keyspace_meta:
                 self._keyspace_updated(keyspace_meta.name)
+                for table_name, old_table_meta in old_keyspace_meta.tables.items():
+                    new_table_meta: TableMetadata = keyspace_meta.tables.get(table_name, None)
+                    if new_table_meta is None:
+                        self._table_removed(keyspace_meta.name, table_name)
             else:
                 self._keyspace_added(keyspace_meta.name)
 
@@ -265,6 +269,9 @@ def _drop_aggregate(self, keyspace, aggregate):
         except KeyError:
             pass
 
+    def _table_removed(self, keyspace, table):
+        self._tablets.drop_tablet(keyspace, table)
+
     def _keyspace_added(self, ksname):
         if self.token_map:
             self.token_map.rebuild_keyspace(ksname, build_if_absent=False)
@@ -272,10 +279,12 @@ def _keyspace_added(self, ksname):
     def _keyspace_updated(self, ksname):
         if self.token_map:
             self.token_map.rebuild_keyspace(ksname, build_if_absent=False)
+        self._tablets.drop_tablet(ksname)
 
     def _keyspace_removed(self, ksname):
         if self.token_map:
             self.token_map.remove_keyspace(ksname)
+        self._tablets.drop_tablet(ksname)
 
     def rebuild_token_map(self, partitioner, token_map):
         """
@@ -340,11 +349,13 @@ def add_or_return_host(self, host):
             return host, True
 
     def remove_host(self, host):
+        self._tablets.drop_tablet_by_host_id(host.host_id)
         with self._hosts_lock:
             self._host_id_by_endpoint.pop(host.endpoint, False)
             return bool(self._hosts.pop(host.host_id, False))
 
     def remove_host_by_host_id(self, host_id, endpoint=None):
+        self._tablets.drop_tablet_by_host_id(host_id)
         with self._hosts_lock:
             if endpoint and self._host_id_by_endpoint[endpoint] == host_id:
                 self._host_id_by_endpoint.pop(endpoint, False)
diff --git a/cassandra/tablets.py b/cassandra/tablets.py
index 61394eace5..d58c55c9a9 100644
--- a/cassandra/tablets.py
+++ b/cassandra/tablets.py
@@ -1,4 +1,5 @@
 from threading import Lock
+from uuid import UUID
 
 
 class Tablet(object):
@@ -32,6 +33,12 @@ def from_row(first_token, last_token, replicas):
             return tablet
         return None
 
+    def replica_contains_host_id(self, uuid: UUID) -> bool:
+        for replica in self.replicas:
+            if replica[0] == uuid:
+                return True
+        return False
+
 
 class Tablets(object):
     _lock = None
@@ -51,6 +58,38 @@ def get_tablet_for_key(self, keyspace, table, t):
             return tablet[id]
         return None
 
+    def drop_tablet(self, keyspace: str, table: str = None):
+        with self._lock:
+            if table is not None:
+                self._tablets.pop((keyspace, table), None)
+                return
+
+            to_be_deleted = []
+            for key in self._tablets.keys():
+                if key[0] == keyspace:
+                    to_be_deleted.append(key)
+
+            for key in to_be_deleted:
+                del self._tablets[key]
+
+    def drop_tablet_by_host_id(self, host_id: UUID):
+        if host_id is None:
+            return
+        with self._lock:
+            for key, tablets in self._tablets.items():
+                to_be_deleted = []
+                for tablet_id, tablet in enumerate(tablets):
+                    if tablet.replica_contains_host_id(host_id):
+                        to_be_deleted.append(tablet_id)
+
+                if len(to_be_deleted) == 0:
+                    continue
+
+                for tablet_id in reversed(to_be_deleted):
+                    tablets.pop(tablet_id)
+
+                self._tablets[key] = tablets
+
     def add_tablet(self, keyspace, table, tablet):
         with self._lock:
             tablets_for_table = self._tablets.setdefault((keyspace, table), [])
diff --git a/tests/integration/experiments/test_tablets.py b/tests/integration/experiments/test_tablets.py
index 98e65c5383..725aa77c88 100644
--- a/tests/integration/experiments/test_tablets.py
+++ b/tests/integration/experiments/test_tablets.py
@@ -1,15 +1,18 @@
 import time
 import unittest
-import pytest
-import os
+
 from cassandra.cluster import Cluster
 from cassandra.policies import ConstantReconnectionPolicy, RoundRobinPolicy, TokenAwarePolicy
 
 from tests.integration import PROTOCOL_VERSION, use_cluster
 from tests.unit.test_host_connection_pool import LOGGER
 
+CCM_CLUSTER = None
+
 def setup_module():
-    use_cluster('tablets', [3], start=True)
+    global CCM_CLUSTER
+
+    CCM_CLUSTER = use_cluster('tablets', [3], start=True)
 
 class TestTabletsIntegration(unittest.TestCase):
     @classmethod
@@ -18,7 +21,7 @@ def setup_class(cls):
                               load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()),
                               reconnection_policy=ConstantReconnectionPolicy(1))
         cls.session = cls.cluster.connect()
-        cls.create_ks_and_cf(cls)
+        cls.create_ks_and_cf(cls.session)
         cls.create_data(cls.session)
 
     @classmethod
@@ -47,6 +50,10 @@ def verify_same_host_in_tracing(self, results):
         self.assertEqual(len(host_set), 1)
         self.assertIn('locally', "\n".join([event.activity for event in events]))
 
+    def get_tablet_record(self, query):
+        metadata = self.session.cluster.metadata
+        return metadata._tablets.get_tablet_for_key(query.keyspace, query.table, metadata.token_map.token_class.from_key(query.routing_key))
+
     def verify_same_shard_in_tracing(self, results):
         traces = results.get_query_trace()
         events = traces.events
@@ -69,24 +76,25 @@ def verify_same_shard_in_tracing(self, results):
         self.assertEqual(len(shard_set), 1)
         self.assertIn('locally', "\n".join([event.activity for event in events]))
 
-    def create_ks_and_cf(self):
-        self.session.execute(
+    @classmethod
+    def create_ks_and_cf(cls, session):
+        session.execute(
             """
             DROP KEYSPACE IF EXISTS test1
             """
         )
-        self.session.execute(
+        session.execute(
             """
             CREATE KEYSPACE test1
             WITH replication = {
                 'class': 'NetworkTopologyStrategy',
-                'replication_factor': 1
+                'replication_factor': 2
             } AND tablets = {
                 'initial': 8
             }
             """)
 
-        self.session.execute(
+        session.execute(
             """
             CREATE TABLE test1.table1 (pk int, ck int, v int, PRIMARY KEY (pk, ck));
             """)
@@ -109,6 +117,8 @@ def query_data_shard_select(self, session, verify_in_tracing=True):
             """)
 
         bound = prepared.bind([(2)])
+        assert self.get_tablet_record(bound) is not None
+
         results = session.execute(bound, trace=True)
         self.assertEqual(results, [(2, 2, 0)])
         if verify_in_tracing:
@@ -121,6 +131,8 @@ def query_data_host_select(self, session, verify_in_tracing=True):
             """)
 
         bound = prepared.bind([(2)])
+        assert self.get_tablet_record(bound) is not None
+
         results = session.execute(bound, trace=True)
         self.assertEqual(results, [(2, 2, 0)])
         if verify_in_tracing:
@@ -133,6 +145,8 @@ def query_data_shard_insert(self, session, verify_in_tracing=True):
             """)
 
         bound = prepared.bind([(51), (1), (2)])
+        assert self.get_tablet_record(bound) is not None
+
         results = session.execute(bound, trace=True)
         if verify_in_tracing:
             self.verify_same_shard_in_tracing(results)
@@ -144,6 +158,8 @@ def query_data_host_insert(self, session, verify_in_tracing=True):
             """)
 
         bound = prepared.bind([(52), (1), (2)])
+        assert self.get_tablet_record(bound) is not None
+
         results = session.execute(bound, trace=True)
         if verify_in_tracing:
             self.verify_same_host_in_tracing(results)
@@ -155,3 +171,61 @@ def test_tablets(self):
     def test_tablets_shard_awareness(self):
         self.query_data_shard_select(self.session)
         self.query_data_shard_insert(self.session)
+
+    def test_tablets_invalidation(self):
+        global CCM_CLUSTER
+
+        def recreate_and_wait(_):
+            # Drop and recreate ks and table to trigger tablets invalidation
+            self.create_ks_and_cf(self.cluster.connect())
+            time.sleep(3)
+
+        def recreate_while_reconnecting(_):
+            # Kill control connection
+            conn = self.session.cluster.control_connection._connection
+            self.session.cluster.control_connection._connection = None
+            conn.close()
+
+            # Drop and recreate ks and table to trigger tablets invalidation
+            self.create_ks_and_cf(self.cluster.connect())
+
+            # Start control connection
+            self.session.cluster.control_connection._reconnect()
+
+        def decommission_non_cc_node(rec):
+            # Decommission a node that replicates the tablet but is not the control connection host
+            for node in CCM_CLUSTER.nodes.values():
+                if self.cluster.control_connection._connection.endpoint.address == node.network_interfaces["storage"][0]:
+                    # Ignore node that control connection is connected to
+                    continue
+                for replica in rec.replicas:
+                    if str(replica[0]) == str(node.node_hostid):
+                        node.decommission()
+                        break
+                else:
+                    continue
+                break
+            else:
+                assert False, "failed to find node to decommission"
+            time.sleep(10)
+
+        # self.run_tablets_invalidation_test(recreate_and_wait)
+        # self.run_tablets_invalidation_test(recreate_while_reconnecting)
+        self.run_tablets_invalidation_test(decommission_non_cc_node)
+
+
+    def run_tablets_invalidation_test(self, invalidate):
+        # Make sure driver holds tablet info
+        bound = self.session.prepare(
+            """
+            SELECT pk, ck, v FROM test1.table1 WHERE pk = ?
+            """).bind([(2)])
+        self.session.execute(bound)
+
+        rec = self.get_tablet_record(bound)
+        assert rec is not None
+
+        invalidate(rec)
+
+        # Check if tablets information was purged
+        assert self.get_tablet_record(bound) is None