diff --git a/cassandra/metadata.py b/cassandra/metadata.py
index 18d424978..30bcf8165 100644
--- a/cassandra/metadata.py
+++ b/cassandra/metadata.py
@@ -26,6 +26,7 @@
 import struct
 import random
 import itertools
+from typing import Optional
 
 murmur3 = None
 try:
@@ -168,10 +169,13 @@ def _rebuild_all(self, parser):
         current_keyspaces = set()
         for keyspace_meta in parser.get_all_keyspaces():
             current_keyspaces.add(keyspace_meta.name)
-            old_keyspace_meta = self.keyspaces.get(keyspace_meta.name, None)
+            old_keyspace_meta: Optional[KeyspaceMetadata] = self.keyspaces.get(keyspace_meta.name, None)
             self.keyspaces[keyspace_meta.name] = keyspace_meta
             if old_keyspace_meta:
                 self._keyspace_updated(keyspace_meta.name)
+                for table_name in old_keyspace_meta.tables.keys():
+                    if table_name not in keyspace_meta.tables:
+                        self._table_removed(keyspace_meta.name, table_name)
             else:
                 self._keyspace_added(keyspace_meta.name)
 
@@ -265,6 +269,9 @@ def _drop_aggregate(self, keyspace, aggregate):
         except KeyError:
             pass
 
+    def _table_removed(self, keyspace, table):
+        self._tablets.drop_tablets(keyspace, table)
+
     def _keyspace_added(self, ksname):
         if self.token_map:
             self.token_map.rebuild_keyspace(ksname, build_if_absent=False)
@@ -272,10 +279,12 @@ def _keyspace_added(self, ksname):
     def _keyspace_updated(self, ksname):
         if self.token_map:
             self.token_map.rebuild_keyspace(ksname, build_if_absent=False)
+        self._tablets.drop_tablets(ksname)
 
     def _keyspace_removed(self, ksname):
         if self.token_map:
             self.token_map.remove_keyspace(ksname)
+        self._tablets.drop_tablets(ksname)
 
     def rebuild_token_map(self, partitioner, token_map):
         """
@@ -340,11 +349,13 @@ def add_or_return_host(self, host):
             return host, True
 
     def remove_host(self, host):
+        self._tablets.drop_tablets_by_host_id(host.host_id)
         with self._hosts_lock:
             self._host_id_by_endpoint.pop(host.endpoint, False)
             return bool(self._hosts.pop(host.host_id, False))
 
     def remove_host_by_host_id(self, host_id, endpoint=None):
+        self._tablets.drop_tablets_by_host_id(host_id)
         with self._hosts_lock:
             if endpoint and self._host_id_by_endpoint[endpoint] == host_id:
                 self._host_id_by_endpoint.pop(endpoint, False)
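The metadata hooks above all funnel into the same tablet-cache eviction: a dropped table evicts one `(keyspace, table)` entry, a keyspace update or removal evicts every entry for that keyspace, and a removed host evicts any tablet that lists it as a replica. The dropped-table case in `_rebuild_all` is just a dict diff between the old and the freshly parsed keyspace metadata; a minimal sketch of that check (the `tables_to_invalidate` helper and the sample dicts are illustrative, not driver API):

```python
# Sketch of the stale-table detection added to _rebuild_all: a table that
# existed in the old KeyspaceMetadata.tables mapping but is absent from the
# freshly parsed one must have been dropped, so its tablets get evicted.

def tables_to_invalidate(old_tables, new_tables):
    """Names present in the old metadata but missing from the new."""
    return [name for name in old_tables if name not in new_tables]

# KeyspaceMetadata.tables is a dict keyed by table name; stub values suffice.
old_tables = {"table1": object(), "table2": object()}
new_tables = {"table1": object()}  # table2 was dropped server-side

assert tables_to_invalidate(old_tables, new_tables) == ["table2"]
```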
diff --git a/cassandra/tablets.py b/cassandra/tablets.py
index 61394eace..457ee93ca 100644
--- a/cassandra/tablets.py
+++ b/cassandra/tablets.py
@@ -1,4 +1,6 @@
 from threading import Lock
+from typing import Optional
+from uuid import UUID
 
 
 class Tablet(object):
@@ -32,6 +34,12 @@ def from_row(first_token, last_token, replicas):
             return tablet
         return None
 
+    def replica_contains_host_id(self, uuid: UUID) -> bool:
+        for replica in self.replicas:
+            if replica[0] == uuid:
+                return True
+        return False
+
 
 class Tablets(object):
     _lock = None
@@ -51,6 +59,33 @@ def get_tablet_for_key(self, keyspace, table, t):
                 return tablet[id]
         return None
 
+    def drop_tablets(self, keyspace: str, table: Optional[str] = None):
+        with self._lock:
+            if table is not None:
+                self._tablets.pop((keyspace, table), None)
+                return
+
+            to_be_deleted = []
+            for key in self._tablets.keys():
+                if key[0] == keyspace:
+                    to_be_deleted.append(key)
+
+            for key in to_be_deleted:
+                del self._tablets[key]
+
+    def drop_tablets_by_host_id(self, host_id: Optional[UUID]):
+        if host_id is None:
+            return
+        with self._lock:
+            for key, tablets in self._tablets.items():
+                to_be_deleted = []
+                for tablet_id, tablet in enumerate(tablets):
+                    if tablet.replica_contains_host_id(host_id):
+                        to_be_deleted.append(tablet_id)
+
+                for tablet_id in reversed(to_be_deleted):
+                    tablets.pop(tablet_id)
+
     def add_tablet(self, keyspace, table, tablet):
         with self._lock:
             tablets_for_table = self._tablets.setdefault((keyspace, table), [])
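The two eviction paths differ in shape: `drop_tablets` operates on the `(keyspace, table)` keys of the cache, while `drop_tablets_by_host_id` scans each tablet's replica list and pops matching indices in reverse order so the earlier positions stay valid during deletion. A self-contained sketch of those semantics (the `FakeTablet` stand-in and the sample data are illustrative, not driver classes):

```python
from uuid import uuid4

# Minimal stand-in for cassandra.tablets.Tablet: only the replicas list of
# (host_id, shard) pairs matters for host-based eviction.
class FakeTablet:
    def __init__(self, replicas):
        self.replicas = replicas

host_a, host_b = uuid4(), uuid4()

# The cache is keyed by (keyspace, table); each value is a list of tablets.
tablets = {
    ("ks1", "t1"): [FakeTablet([(host_a, 0)]), FakeTablet([(host_b, 0)])],
    ("ks1", "t2"): [FakeTablet([(host_b, 1)])],
    ("ks2", "t1"): [FakeTablet([(host_a, 2)])],
}

# drop_tablets("ks1") removes every entry whose key starts with "ks1";
# drop_tablets("ks1", "t1") would pop only that single key.
for key in [k for k in tablets if k[0] == "ks1"]:
    del tablets[key]
assert list(tablets) == [("ks2", "t1")]

# drop_tablets_by_host_id(host_a) walks the remaining lists and evicts
# tablets referencing the host, popping indices in reverse.
for key, tablet_list in tablets.items():
    doomed = [i for i, t in enumerate(tablet_list)
              if any(r[0] == host_a for r in t.replicas)]
    for i in reversed(doomed):
        tablet_list.pop(i)
assert tablets[("ks2", "t1")] == []
```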
diff --git a/tests/integration/experiments/test_tablets.py b/tests/integration/experiments/test_tablets.py
index 98e65c538..79dd16660 100644
--- a/tests/integration/experiments/test_tablets.py
+++ b/tests/integration/experiments/test_tablets.py
@@ -1,31 +1,36 @@
 import time
-import unittest
+
 import pytest
-import os
+
 from cassandra.cluster import Cluster
 from cassandra.policies import ConstantReconnectionPolicy, RoundRobinPolicy, TokenAwarePolicy
 
 from tests.integration import PROTOCOL_VERSION, use_cluster
 from tests.unit.test_host_connection_pool import LOGGER
 
+CCM_CLUSTER = None
+
 def setup_module():
-    use_cluster('tablets', [3], start=True)
+    global CCM_CLUSTER
+
+    CCM_CLUSTER = use_cluster('tablets', [3], start=True)
 
-class TestTabletsIntegration(unittest.TestCase):
+
+class TestTabletsIntegration:
     @classmethod
     def setup_class(cls):
         cls.cluster = Cluster(contact_points=["127.0.0.1", "127.0.0.2", "127.0.0.3"], protocol_version=PROTOCOL_VERSION,
                               load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()),
                               reconnection_policy=ConstantReconnectionPolicy(1))
         cls.session = cls.cluster.connect()
-        cls.create_ks_and_cf(cls)
+        cls.create_ks_and_cf(cls.session)
         cls.create_data(cls.session)
 
     @classmethod
     def teardown_class(cls):
         cls.cluster.shutdown()
 
-    def verify_same_host_in_tracing(self, results):
+    def verify_hosts_in_tracing(self, results, expected):
         traces = results.get_query_trace()
         events = traces.events
         host_set = set()
@@ -33,8 +38,8 @@ def verify_same_host_in_tracing(self, results):
             LOGGER.info("TRACE EVENT: %s %s %s", event.source, event.thread_name, event.description)
             host_set.add(event.source)
 
-        self.assertEqual(len(host_set), 1)
-        self.assertIn('locally', "\n".join([event.description for event in events]))
+        assert len(host_set) == expected
+        assert 'locally' in "\n".join([event.description for event in events])
 
         trace_id = results.response_future.get_query_trace_ids()[0]
         traces = self.session.execute("SELECT * FROM system_traces.events WHERE session_id = %s", (trace_id,))
@@ -44,8 +49,12 @@
             LOGGER.info("TRACE EVENT: %s %s", event.source, event.activity)
             host_set.add(event.source)
 
-        self.assertEqual(len(host_set), 1)
-        self.assertIn('locally', "\n".join([event.activity for event in events]))
+        assert len(host_set) == expected
+        assert 'locally' in "\n".join([event.activity for event in events])
+
+    def get_tablet_record(self, query):
+        metadata = self.session.cluster.metadata
+        return metadata._tablets.get_tablet_for_key(query.keyspace, query.table, metadata.token_map.token_class.from_key(query.routing_key))
 
     def verify_same_shard_in_tracing(self, results):
         traces = results.get_query_trace()
@@ -55,8 +64,8 @@ def verify_same_shard_in_tracing(self, results):
             LOGGER.info("TRACE EVENT: %s %s %s", event.source, event.thread_name, event.description)
             shard_set.add(event.thread_name)
 
-        self.assertEqual(len(shard_set), 1)
-        self.assertIn('locally', "\n".join([event.description for event in events]))
+        assert len(shard_set) == 1
+        assert 'locally' in "\n".join([event.description for event in events])
 
         trace_id = results.response_future.get_query_trace_ids()[0]
         traces = self.session.execute("SELECT * FROM system_traces.events WHERE session_id = %s", (trace_id,))
@@ -66,27 +75,28 @@
             LOGGER.info("TRACE EVENT: %s %s", event.thread, event.activity)
             shard_set.add(event.thread)
 
-        self.assertEqual(len(shard_set), 1)
-        self.assertIn('locally', "\n".join([event.activity for event in events]))
+        assert len(shard_set) == 1
+        assert 'locally' in "\n".join([event.activity for event in events])
 
-    def create_ks_and_cf(self):
-        self.session.execute(
+    @classmethod
+    def create_ks_and_cf(cls, session):
+        session.execute(
             """
             DROP KEYSPACE IF EXISTS test1
             """
         )
-        self.session.execute(
+        session.execute(
             """
             CREATE KEYSPACE test1
             WITH replication = {
                 'class': 'NetworkTopologyStrategy',
-                'replication_factor': 1
+                'replication_factor': 2
             } AND tablets = {
                 'initial': 8
             }
             """)
 
-        self.session.execute(
+        session.execute(
             """
             CREATE TABLE test1.table1 (pk int, ck int, v int, PRIMARY KEY (pk, ck));
             """)
@@ -110,7 +120,7 @@ def query_data_shard_select(self, session, verify_in_tracing=True):
         bound = prepared.bind([(2)])
 
         results = session.execute(bound, trace=True)
-        self.assertEqual(results, [(2, 2, 0)])
+        assert results == [(2, 2, 0)]
         if verify_in_tracing:
             self.verify_same_shard_in_tracing(results)
 
@@ -122,9 +132,9 @@ def query_data_host_select(self, session, verify_in_tracing=True):
         bound = prepared.bind([(2)])
 
         results = session.execute(bound, trace=True)
-        self.assertEqual(results, [(2, 2, 0)])
+        assert results == [(2, 2, 0)]
         if verify_in_tracing:
-            self.verify_same_host_in_tracing(results)
+            self.verify_hosts_in_tracing(results, 1)
 
     def query_data_shard_insert(self, session, verify_in_tracing=True):
         prepared = session.prepare(
@@ -146,7 +156,7 @@ def query_data_host_insert(self, session, verify_in_tracing=True):
         bound = prepared.bind([(52), (1), (2)])
         results = session.execute(bound, trace=True)
         if verify_in_tracing:
-            self.verify_same_host_in_tracing(results)
+            self.verify_hosts_in_tracing(results, 2)
 
     def test_tablets(self):
         self.query_data_host_select(self.session)
@@ -155,3 +165,70 @@
     def test_tablets_shard_awareness(self):
         self.query_data_shard_select(self.session)
         self.query_data_shard_insert(self.session)
+
+    def test_tablets_invalidation_drop_ks_while_reconnecting(self):
+        def recreate_while_reconnecting(_):
+            # Kill the control connection
+            conn = self.session.cluster.control_connection._connection
+            self.session.cluster.control_connection._connection = None
+            conn.close()
+
+            # Drop and recreate ks and table to trigger tablets invalidation
+            self.create_ks_and_cf(self.cluster.connect())
+
+            # Restart the control connection
+            self.session.cluster.control_connection._reconnect()
+
+        self.run_tablets_invalidation_test(recreate_while_reconnecting)
+
+    def test_tablets_invalidation_drop_ks(self):
+        def drop_ks(_):
+            # Drop and recreate ks and table to trigger tablets invalidation
+            self.create_ks_and_cf(self.cluster.connect())
+            time.sleep(3)
+
+        self.run_tablets_invalidation_test(drop_ks)
+
+    @pytest.mark.last
+    def test_tablets_invalidation_decommission_non_cc_node(self):
+        def decommission_non_cc_node(rec):
+            # Decommission a node holding a tablet replica to trigger tablets invalidation
+            for node in CCM_CLUSTER.nodes.values():
+                if self.cluster.control_connection._connection.endpoint.address == node.network_interfaces["storage"][0]:
+                    # Ignore node that control connection is connected to
+                    continue
+                for replica in rec.replicas:
+                    if str(replica[0]) == str(node.node_hostid):
+                        node.decommission()
+                        break
+                else:
+                    continue
+                break
+            else:
+                assert False, "failed to find node to decommission"
+            time.sleep(10)
+
+        self.run_tablets_invalidation_test(decommission_non_cc_node)
+
+
+    def run_tablets_invalidation_test(self, invalidate):
+        # Make sure the driver holds tablet info by landing a query
+        # on a host that is not in the replica set
+        bound = self.session.prepare(
+            """
+            SELECT pk, ck, v FROM test1.table1 WHERE pk = ?
+            """).bind([(2)])
+
+        rec = None
+        for host in self.cluster.metadata.all_hosts():
+            self.session.execute(bound, host=host)
+            rec = self.get_tablet_record(bound)
+            if rec is not None:
+                break
+
+        assert rec is not None, "failed to find tablet record"
+
+        invalidate(rec)
+
+        # Check if tablets information was purged
+        assert self.get_tablet_record(bound) is None, "tablet was not deleted, invalidation did not work"
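All three invalidation tests share the probe loop in `run_tablets_invalidation_test`: tablet metadata arrives lazily, piggybacked on responses to requests that land on a non-replica, so the loop pushes a query through every host until the driver has cached a record for the statement's routing key. A condensed, manually runnable version of the scenario (a sketch assuming the tablets-enabled three-node cluster from `setup_module` and the `test1.table1` schema above; `tablet_record` is a hypothetical helper mirroring the test's `get_tablet_record`):

```python
import time

from cassandra.cluster import Cluster
from cassandra.policies import ConstantReconnectionPolicy, RoundRobinPolicy, TokenAwarePolicy

cluster = Cluster(contact_points=["127.0.0.1", "127.0.0.2", "127.0.0.3"],
                  load_balancing_policy=TokenAwarePolicy(RoundRobinPolicy()),
                  reconnection_policy=ConstantReconnectionPolicy(1))
session = cluster.connect()

bound = session.prepare("SELECT pk, ck, v FROM test1.table1 WHERE pk = ?").bind([2])

def tablet_record(metadata, stmt):
    # Same lookup the test's get_tablet_record helper performs against
    # the driver-internal tablet cache.
    token = metadata.token_map.token_class.from_key(stmt.routing_key)
    return metadata._tablets.get_tablet_for_key(stmt.keyspace, stmt.table, token)

# Tablet info is delivered with responses, so probe each host until cached.
rec = None
for host in cluster.metadata.all_hosts():
    session.execute(bound, host=host)
    rec = tablet_record(cluster.metadata, bound)
    if rec is not None:
        break
assert rec is not None, "failed to find tablet record"

# Any of the three triggers works; dropping the keyspace is the simplest.
session.execute("DROP KEYSPACE test1")
time.sleep(3)  # allow the schema event to reach the driver, as the test does
assert tablet_record(cluster.metadata, bound) is None  # cache was purged
```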