From c62665f4f34d5452134f4429eaa88e9aa0bee548 Mon Sep 17 00:00:00 2001 From: Sylwia Szunejko Date: Thu, 13 Jun 2024 09:28:16 +0200 Subject: [PATCH] Add RackAwareRoundRobinPolicy for host selection --- cassandra/cluster.py | 9 +- cassandra/metadata.py | 2 +- cassandra/policies.py | 152 +++++++++++++- docs/api/cassandra/policies.rst | 3 + .../standard/test_rack_aware_policy.py | 89 ++++++++ tests/unit/test_policies.py | 198 +++++++++++------- 6 files changed, 369 insertions(+), 84 deletions(-) create mode 100644 tests/integration/standard/test_rack_aware_policy.py diff --git a/cassandra/cluster.py b/cassandra/cluster.py index 71be215ab1..06e6293ef8 100644 --- a/cassandra/cluster.py +++ b/cassandra/cluster.py @@ -492,7 +492,8 @@ def _profiles_without_explicit_lbps(self): def distance(self, host): distances = set(p.load_balancing_policy.distance(host) for p in self.profiles.values()) - return HostDistance.LOCAL if HostDistance.LOCAL in distances else \ + return HostDistance.LOCAL_RACK if HostDistance.LOCAL_RACK in distances else \ + HostDistance.LOCAL if HostDistance.LOCAL in distances else \ HostDistance.REMOTE if HostDistance.REMOTE in distances else \ HostDistance.IGNORED @@ -609,7 +610,7 @@ class Cluster(object): Defaults to loopback interface. - Note: When using :class:`.DCAwareLoadBalancingPolicy` with no explicit + Note: When using :class:`.DCAwareRoundRobinPolicy` with no explicit local_dc set (as is the default), the DC is chosen from an arbitrary host in contact_points. In this case, contact_points should contain only nodes from a single, local DC. 
@@ -1369,21 +1370,25 @@ def __init__(self, self._user_types = defaultdict(dict) self._min_requests_per_connection = { + HostDistance.LOCAL_RACK: DEFAULT_MIN_REQUESTS, HostDistance.LOCAL: DEFAULT_MIN_REQUESTS, HostDistance.REMOTE: DEFAULT_MIN_REQUESTS } self._max_requests_per_connection = { + HostDistance.LOCAL_RACK: DEFAULT_MAX_REQUESTS, HostDistance.LOCAL: DEFAULT_MAX_REQUESTS, HostDistance.REMOTE: DEFAULT_MAX_REQUESTS } self._core_connections_per_host = { + HostDistance.LOCAL_RACK: DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST, HostDistance.LOCAL: DEFAULT_MIN_CONNECTIONS_PER_LOCAL_HOST, HostDistance.REMOTE: DEFAULT_MIN_CONNECTIONS_PER_REMOTE_HOST } self._max_connections_per_host = { + HostDistance.LOCAL_RACK: DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST, HostDistance.LOCAL: DEFAULT_MAX_CONNECTIONS_PER_LOCAL_HOST, HostDistance.REMOTE: DEFAULT_MAX_CONNECTIONS_PER_REMOTE_HOST } diff --git a/cassandra/metadata.py b/cassandra/metadata.py index d30e6a1925..edee822e40 100644 --- a/cassandra/metadata.py +++ b/cassandra/metadata.py @@ -3436,7 +3436,7 @@ def group_keys_by_replica(session, keyspace, table, keys): all_replicas = cluster.metadata.get_replicas(keyspace, routing_key) # First check if there are local replicas valid_replicas = [host for host in all_replicas if - host.is_up and distance(host) == HostDistance.LOCAL] + host.is_up and distance(host) in [HostDistance.LOCAL, HostDistance.LOCAL_RACK]] if not valid_replicas: valid_replicas = [host for host in all_replicas if host.is_up] diff --git a/cassandra/policies.py b/cassandra/policies.py index a1495f3660..d9d3da7980 100644 --- a/cassandra/policies.py +++ b/cassandra/policies.py @@ -46,7 +46,18 @@ class HostDistance(object): connections opened to it. """ - LOCAL = 0 + LOCAL_RACK = 0 + """ + Nodes with ``LOCAL_RACK`` distance will be preferred for operations + under some load balancing policies (such as :class:`.RackAwareRoundRobinPolicy`) + and will have a greater number of connections opened against + them by default. 
+ + This distance is typically used for nodes within the same + datacenter and the same rack as the client. + """ + + LOCAL = 1 """ Nodes with ``LOCAL`` distance will be preferred for operations under some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`) @@ -57,12 +68,12 @@ class HostDistance(object): datacenter as the client. """ - REMOTE = 1 + REMOTE = 2 """ Nodes with ``REMOTE`` distance will be treated as a last resort - by some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy`) - and will have a smaller number of connections opened against - them by default. + by some load balancing policies (such as :class:`.DCAwareRoundRobinPolicy` + and :class:`.RackAwareRoundRobinPolicy`) and will have a smaller number of + connections opened against them by default. This distance is typically used for nodes outside of the datacenter that the client is running in. @@ -102,6 +113,11 @@ class LoadBalancingPolicy(HostStateListener): You may also use subclasses of :class:`.LoadBalancingPolicy` for custom behavior. + + You should always use immutable collections (e.g., tuples or + frozensets) to store information about hosts to prevent accidental + modification. When there are changes to the hosts (e.g., a host is + down or up), the old collection should be replaced with a new one.
""" _hosts_lock = None @@ -316,6 +332,130 @@ def on_add(self, host): def on_remove(self, host): self.on_down(host) +class RackAwareRoundRobinPolicy(LoadBalancingPolicy): + """ + Similar to :class:`.DCAwareRoundRobinPolicy`, but prefers hosts + in the local rack, before hosts in the local datacenter but a + different rack, before hosts in all other datercentres + """ + + local_dc = None + local_rack = None + used_hosts_per_remote_dc = 0 + + def __init__(self, local_dc, local_rack, used_hosts_per_remote_dc=0): + """ + The `local_dc` and `local_rack` parameters should be the name of the + datacenter and rack (such as is reported by ``nodetool ring``) that + should be considered local. + + `used_hosts_per_remote_dc` controls how many nodes in + each remote datacenter will have connections opened + against them. In other words, `used_hosts_per_remote_dc` hosts + will be considered :attr:`~.HostDistance.REMOTE` and the + rest will be considered :attr:`~.HostDistance.IGNORED`. + By default, all remote hosts are ignored. 
+ """ + self.local_rack = local_rack + self.local_dc = local_dc + self.used_hosts_per_remote_dc = used_hosts_per_remote_dc + self._live_hosts = {} + self._dc_live_hosts = {} + self._endpoints = [] + self._position = 0 + LoadBalancingPolicy.__init__(self) + + def _rack(self, host): + return host.rack or self.local_rack + + def _dc(self, host): + return host.datacenter or self.local_dc + + def populate(self, cluster, hosts): + for (dc, rack), rack_hosts in groupby(hosts, lambda host: (self._dc(host), self._rack(host))): + self._live_hosts[(dc, rack)] = tuple(set(rack_hosts)) + for dc, dc_hosts in groupby(hosts, lambda host: self._dc(host)): + self._dc_live_hosts[dc] = tuple(set(dc_hosts)) + + self._position = randint(0, len(hosts) - 1) if hosts else 0 + + def distance(self, host): + rack = self._rack(host) + dc = self._dc(host) + if rack == self.local_rack and dc == self.local_dc: + return HostDistance.LOCAL_RACK + + if dc == self.local_dc: + return HostDistance.LOCAL + + if not self.used_hosts_per_remote_dc: + return HostDistance.IGNORED + + dc_hosts = self._dc_live_hosts.get(dc, ()) + if not dc_hosts: + return HostDistance.IGNORED + if host in dc_hosts and dc_hosts.index(host) < self.used_hosts_per_remote_dc: + return HostDistance.REMOTE + else: + return HostDistance.IGNORED + + def make_query_plan(self, working_keyspace=None, query=None): + pos = self._position + self._position += 1 + + local_rack_live = self._live_hosts.get((self.local_dc, self.local_rack), ()) + pos = (pos % len(local_rack_live)) if local_rack_live else 0 + # Slice the cyclic iterator to start from pos and include the next len(local_rack_live) elements + # This ensures we get exactly one full cycle starting from pos + for host in islice(cycle(local_rack_live), pos, pos + len(local_rack_live)): + yield host + + local_live = [host for host in self._dc_live_hosts.get(self.local_dc, ()) if host.rack != self.local_rack] + pos = (pos % len(local_live)) if local_live else 0 + for host in
islice(cycle(local_live), pos, pos + len(local_live)): + yield host + + # the dict can change, so get candidate DCs iterating over keys of a copy + for dc, remote_live in self._dc_live_hosts.copy().items(): + if dc != self.local_dc: + for host in remote_live[:self.used_hosts_per_remote_dc]: + yield host + + def on_up(self, host): + dc = self._dc(host) + rack = self._rack(host) + with self._hosts_lock: + current_rack_hosts = self._live_hosts.get((dc, rack), ()) + if host not in current_rack_hosts: + self._live_hosts[(dc, rack)] = current_rack_hosts + (host, ) + current_dc_hosts = self._dc_live_hosts.get(dc, ()) + if host not in current_dc_hosts: + self._dc_live_hosts[dc] = current_dc_hosts + (host, ) + + def on_down(self, host): + dc = self._dc(host) + rack = self._rack(host) + with self._hosts_lock: + current_rack_hosts = self._live_hosts.get((dc, rack), ()) + if host in current_rack_hosts: + hosts = tuple(h for h in current_rack_hosts if h != host) + if hosts: + self._live_hosts[(dc, rack)] = hosts + else: + del self._live_hosts[(dc, rack)] + current_dc_hosts = self._dc_live_hosts.get(dc, ()) + if host in current_dc_hosts: + hosts = tuple(h for h in current_dc_hosts if h != host) + if hosts: + self._dc_live_hosts[dc] = hosts + else: + del self._dc_live_hosts[dc] + + def on_add(self, host): + self.on_up(host) + + def on_remove(self, host): + self.on_down(host) class TokenAwarePolicy(LoadBalancingPolicy): """ @@ -390,7 +530,7 @@ def make_query_plan(self, working_keyspace=None, query=None): shuffle(replicas) for replica in replicas: - if replica.is_up and child.distance(replica) == HostDistance.LOCAL: + if replica.is_up and child.distance(replica) in [HostDistance.LOCAL, HostDistance.LOCAL_RACK]: yield replica for host in child.make_query_plan(keyspace, query): diff --git a/docs/api/cassandra/policies.rst b/docs/api/cassandra/policies.rst index 387b19ed95..ea3b19d796 100644 --- a/docs/api/cassandra/policies.rst +++ b/docs/api/cassandra/policies.rst @@ -18,6 +18,9 @@ 
Load Balancing .. autoclass:: DCAwareRoundRobinPolicy :members: +.. autoclass:: RackAwareRoundRobinPolicy + :members: + .. autoclass:: WhiteListRoundRobinPolicy :members: diff --git a/tests/integration/standard/test_rack_aware_policy.py b/tests/integration/standard/test_rack_aware_policy.py new file mode 100644 index 0000000000..5d7a69642f --- /dev/null +++ b/tests/integration/standard/test_rack_aware_policy.py @@ -0,0 +1,89 @@ +import logging +import unittest + +from cassandra.cluster import Cluster +from cassandra.policies import ConstantReconnectionPolicy, RackAwareRoundRobinPolicy + +from tests.integration import PROTOCOL_VERSION, get_cluster, use_multidc + +LOGGER = logging.getLogger(__name__) + +def setup_module(): + use_multidc({'DC1': {'RC1': 2, 'RC2': 2}, 'DC2': {'RC1': 3}}) + +class RackAwareRoundRobinPolicyTests(unittest.TestCase): + @classmethod + def setup_class(cls): + cls.cluster = Cluster(contact_points=[node.address() for node in get_cluster().nodelist()], protocol_version=PROTOCOL_VERSION, + load_balancing_policy=RackAwareRoundRobinPolicy("DC1", "RC1", used_hosts_per_remote_dc=0), + reconnection_policy=ConstantReconnectionPolicy(1)) + cls.session = cls.cluster.connect() + cls.create_ks_and_cf(cls) + cls.create_data(cls.session) + cls.node1, cls.node2, cls.node3, cls.node4, cls.node5, cls.node6, cls.node7 = get_cluster().nodes.values() + + @classmethod + def teardown_class(cls): + cls.cluster.shutdown() + + def create_ks_and_cf(self): + self.session.execute( + """ + DROP KEYSPACE IF EXISTS test1 + """ + ) + self.session.execute( + """ + CREATE KEYSPACE test1 + WITH replication = { + 'class': 'NetworkTopologyStrategy', + 'replication_factor': 3 + } + """) + + self.session.execute( + """ + CREATE TABLE test1.table1 (pk int, ck int, v int, PRIMARY KEY (pk, ck)); + """) + + @staticmethod + def create_data(session): + prepared = session.prepare( + """ + INSERT INTO test1.table1 (pk, ck, v) VALUES (?, ?, ?) 
+ """) + + for i in range(50): + bound = prepared.bind((i, i%5, i%2)) + session.execute(bound) + + def test_rack_aware(self): + prepared = self.session.prepare( + """ + SELECT pk, ck, v FROM test1.table1 WHERE pk = ? + """) + + for i in range (10): + bound = prepared.bind([i]) + results = self.session.execute(bound) + self.assertEqual(results, [(i, i%5, i%2)]) + coordinator = str(results.response_future.coordinator_host.endpoint) + self.assertTrue(coordinator in set(["127.0.0.1:9042", "127.0.0.2:9042"])) + + self.node2.stop(wait_other_notice=True, gently=True) + + for i in range (10): + bound = prepared.bind([i]) + results = self.session.execute(bound) + self.assertEqual(results, [(i, i%5, i%2)]) + coordinator =str(results.response_future.coordinator_host.endpoint) + self.assertEqual(coordinator, "127.0.0.1:9042") + + self.node1.stop(wait_other_notice=True, gently=True) + + for i in range (10): + bound = prepared.bind([i]) + results = self.session.execute(bound) + self.assertEqual(results, [(i, i%5, i%2)]) + coordinator = str(results.response_future.coordinator_host.endpoint) + self.assertTrue(coordinator in set(["127.0.0.3:9042", "127.0.0.4:9042"])) diff --git a/tests/unit/test_policies.py b/tests/unit/test_policies.py index 877731dc08..15bd1ea95b 100644 --- a/tests/unit/test_policies.py +++ b/tests/unit/test_policies.py @@ -17,6 +17,7 @@ from itertools import islice, cycle from mock import Mock, patch, call from random import randint +import pytest from _thread import LockType import sys import struct @@ -25,7 +26,7 @@ from cassandra import ConsistencyLevel from cassandra.cluster import Cluster, ControlConnection from cassandra.metadata import Metadata -from cassandra.policies import (RoundRobinPolicy, WhiteListRoundRobinPolicy, DCAwareRoundRobinPolicy, +from cassandra.policies import (RackAwareRoundRobinPolicy, RoundRobinPolicy, WhiteListRoundRobinPolicy, DCAwareRoundRobinPolicy, TokenAwarePolicy, SimpleConvictionPolicy, HostDistance, 
ExponentialReconnectionPolicy, RetryPolicy, WriteType, @@ -177,75 +178,107 @@ def test_no_live_nodes(self): qplan = list(policy.make_query_plan()) self.assertEqual(qplan, []) +@pytest.mark.parametrize("policy_specialization, constructor_args", [(DCAwareRoundRobinPolicy, ("dc1", )), (RackAwareRoundRobinPolicy, ("dc1", "rack1"))]) +class TestRackOrDCAwareRoundRobinPolicy: -class DCAwareRoundRobinPolicyTest(unittest.TestCase): - - def test_no_remote(self): + def test_no_remote(self, policy_specialization, constructor_args): hosts = [] - for i in range(4): + for i in range(2): h = Host(DefaultEndPoint(i), SimpleConvictionPolicy) + h.set_location_info("dc1", "rack2") + hosts.append(h) + for i in range(2): + h = Host(DefaultEndPoint(i + 2), SimpleConvictionPolicy) h.set_location_info("dc1", "rack1") hosts.append(h) - policy = DCAwareRoundRobinPolicy("dc1") + policy = policy_specialization(*constructor_args) policy.populate(None, hosts) qplan = list(policy.make_query_plan()) - self.assertEqual(sorted(qplan), sorted(hosts)) + assert sorted(qplan) == sorted(hosts) - def test_with_remotes(self): - hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i in range(4)] + def test_with_remotes(self, policy_specialization, constructor_args): + hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i in range(6)] for h in hosts[:2]: h.set_location_info("dc1", "rack1") - for h in hosts[2:]: + for h in hosts[2:4]: + h.set_location_info("dc1", "rack2") + for h in hosts[4:]: h.set_location_info("dc2", "rack1") - local_hosts = set(h for h in hosts if h.datacenter == "dc1") + local_rack_hosts = set(h for h in hosts if h.datacenter == "dc1" and h.rack == "rack1") + local_hosts = set(h for h in hosts if h.datacenter == "dc1" and h.rack != "rack1") remote_hosts = set(h for h in hosts if h.datacenter != "dc1") # allow all of the remote hosts to be used - policy = DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=2) + policy = policy_specialization(*constructor_args, 
used_hosts_per_remote_dc=2) policy.populate(Mock(), hosts) qplan = list(policy.make_query_plan()) - self.assertEqual(set(qplan[:2]), local_hosts) - self.assertEqual(set(qplan[2:]), remote_hosts) + if isinstance(policy_specialization, DCAwareRoundRobinPolicy): + assert set(qplan[:4]) == local_rack_hosts + local_hosts + elif isinstance(policy_specialization, RackAwareRoundRobinPolicy): + assert set(qplan[:2]) == local_rack_hosts + assert set(qplan[2:4]) == local_hosts + assert set(qplan[4:]) == remote_hosts # allow only one of the remote hosts to be used - policy = DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=1) + policy = policy_specialization(*constructor_args, used_hosts_per_remote_dc=1) policy.populate(Mock(), hosts) qplan = list(policy.make_query_plan()) - self.assertEqual(set(qplan[:2]), local_hosts) + if isinstance(policy_specialization, DCAwareRoundRobinPolicy): + assert set(qplan[:4]) == local_rack_hosts + local_hosts + elif isinstance(policy_specialization, RackAwareRoundRobinPolicy): + assert set(qplan[:2]) == local_rack_hosts + assert set(qplan[2:4]) == local_hosts - used_remotes = set(qplan[2:]) - self.assertEqual(1, len(used_remotes)) - self.assertIn(qplan[2], remote_hosts) + used_remotes = set(qplan[4:]) + assert 1 == len(used_remotes) + assert qplan[4] in remote_hosts # allow no remote hosts to be used - policy = DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=0) + policy = policy_specialization(*constructor_args, used_hosts_per_remote_dc=0) policy.populate(Mock(), hosts) qplan = list(policy.make_query_plan()) - self.assertEqual(2, len(qplan)) - self.assertEqual(local_hosts, set(qplan)) - def test_get_distance(self): - policy = DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=0) + assert 4 == len(qplan) + if isinstance(policy_specialization, DCAwareRoundRobinPolicy): + assert set(qplan) == local_rack_hosts + local_hosts + elif isinstance(policy_specialization, RackAwareRoundRobinPolicy): + assert set(qplan[:2]) == 
local_rack_hosts + assert set(qplan[2:4]) == local_hosts + + def test_get_distance(self, policy_specialization, constructor_args): + policy = policy_specialization(*constructor_args, used_hosts_per_remote_dc=0) + + # same dc, same rack host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy) host.set_location_info("dc1", "rack1") policy.populate(Mock(), [host]) - self.assertEqual(policy.distance(host), HostDistance.LOCAL) + if isinstance(policy_specialization, DCAwareRoundRobinPolicy): + assert policy.distance(host) == HostDistance.LOCAL + elif isinstance(policy_specialization, RackAwareRoundRobinPolicy): + assert policy.distance(host) == HostDistance.LOCAL_RACK + + # same dc different rack + host = Host(DefaultEndPoint("ip1"), SimpleConvictionPolicy) + host.set_location_info("dc1", "rack2") + policy.populate(Mock(), [host]) + + assert policy.distance(host) == HostDistance.LOCAL # used_hosts_per_remote_dc is set to 0, so ignore it remote_host = Host(DefaultEndPoint("ip2"), SimpleConvictionPolicy) remote_host.set_location_info("dc2", "rack1") - self.assertEqual(policy.distance(remote_host), HostDistance.IGNORED) + assert policy.distance(remote_host) == HostDistance.IGNORED # dc2 isn't registered in the policy's live_hosts dict policy.used_hosts_per_remote_dc = 1 - self.assertEqual(policy.distance(remote_host), HostDistance.IGNORED) + assert policy.distance(remote_host) == HostDistance.IGNORED # make sure the policy has both dcs registered policy.populate(Mock(), [host, remote_host]) - self.assertEqual(policy.distance(remote_host), HostDistance.REMOTE) + assert policy.distance(remote_host) == HostDistance.REMOTE # since used_hosts_per_remote_dc is set to 1, only the first # remote host in dc2 will be REMOTE, the rest are IGNORED @@ -253,54 +286,58 @@ def test_get_distance(self): second_remote_host.set_location_info("dc2", "rack1") policy.populate(Mock(), [host, remote_host, second_remote_host]) distances = set([policy.distance(remote_host), 
policy.distance(second_remote_host)]) - self.assertEqual(distances, set([HostDistance.REMOTE, HostDistance.IGNORED])) + assert distances == set([HostDistance.REMOTE, HostDistance.IGNORED]) - def test_status_updates(self): - hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i in range(4)] + def test_status_updates(self, policy_specialization, constructor_args): + hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i in range(5)] for h in hosts[:2]: h.set_location_info("dc1", "rack1") - for h in hosts[2:]: + for h in hosts[2:4]: + h.set_location_info("dc1", "rack2") + for h in hosts[4:]: h.set_location_info("dc2", "rack1") - policy = DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=1) + policy = policy_specialization(*constructor_args, used_hosts_per_remote_dc=1) policy.populate(Mock(), hosts) policy.on_down(hosts[0]) policy.on_remove(hosts[2]) - new_local_host = Host(DefaultEndPoint(4), SimpleConvictionPolicy) + new_local_host = Host(DefaultEndPoint(5), SimpleConvictionPolicy) new_local_host.set_location_info("dc1", "rack1") policy.on_up(new_local_host) - new_remote_host = Host(DefaultEndPoint(5), SimpleConvictionPolicy) + new_remote_host = Host(DefaultEndPoint(6), SimpleConvictionPolicy) new_remote_host.set_location_info("dc9000", "rack1") policy.on_add(new_remote_host) - # we now have two local hosts and two remote hosts in separate dcs + # we now have three local hosts and two remote hosts in separate dcs qplan = list(policy.make_query_plan()) - self.assertEqual(set(qplan[:2]), set([hosts[1], new_local_host])) - self.assertEqual(set(qplan[2:]), set([hosts[3], new_remote_host])) + + assert set(qplan[:3]) == set([hosts[1], new_local_host, hosts[3]]) + assert set(qplan[3:]) == set([hosts[4], new_remote_host]) # since we have hosts in dc9000, the distance shouldn't be IGNORED - self.assertEqual(policy.distance(new_remote_host), HostDistance.REMOTE) + assert policy.distance(new_remote_host), HostDistance.REMOTE 
policy.on_down(new_local_host) policy.on_down(hosts[1]) qplan = list(policy.make_query_plan()) - self.assertEqual(set(qplan), set([hosts[3], new_remote_host])) + assert set(qplan) == set([hosts[3], hosts[4], new_remote_host]) policy.on_down(new_remote_host) policy.on_down(hosts[3]) + policy.on_down(hosts[4]) qplan = list(policy.make_query_plan()) - self.assertEqual(qplan, []) + assert qplan == [] - def test_modification_during_generation(self): + def test_modification_during_generation(self, policy_specialization, constructor_args): hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i in range(4)] for h in hosts[:2]: h.set_location_info("dc1", "rack1") for h in hosts[2:]: h.set_location_info("dc2", "rack1") - policy = DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=3) + policy = policy_specialization(*constructor_args, used_hosts_per_remote_dc=3) policy.populate(Mock(), hosts) # The general concept here is to change thee internal state of the @@ -315,20 +352,20 @@ def test_modification_during_generation(self): plan = policy.make_query_plan() policy.on_up(new_host) # local list is not bound yet, so we get to see that one - self.assertEqual(len(list(plan)), 3 + 2) + assert len(list(plan)) == 3 + 2 # remove local before iteration plan = policy.make_query_plan() policy.on_down(new_host) # local list is not bound yet, so we don't see it - self.assertEqual(len(list(plan)), 2 + 2) + assert len(list(plan)) == 2 + 2 # new local after starting iteration plan = policy.make_query_plan() next(plan) policy.on_up(new_host) # local list was is bound, and one consumed, so we only see the other original - self.assertEqual(len(list(plan)), 1 + 2) + assert len(list(plan)) == 1 + 2 # remove local after traversing available plan = policy.make_query_plan() @@ -336,7 +373,7 @@ def test_modification_during_generation(self): next(plan) policy.on_down(new_host) # we should be past the local list - self.assertEqual(len(list(plan)), 0 + 2) + assert len(list(plan)) == 0 + 2 # 
REMOTES CHANGE new_host.set_location_info("dc2", "rack1") @@ -347,7 +384,7 @@ def test_modification_during_generation(self): next(plan) policy.on_up(new_host) # list is updated before we get to it - self.assertEqual(len(list(plan)), 0 + 3) + assert len(list(plan)) == 0 + 3 # remove remote after traversing local, but not starting remote plan = policy.make_query_plan() @@ -355,7 +392,7 @@ def test_modification_during_generation(self): next(plan) policy.on_down(new_host) # list is updated before we get to it - self.assertEqual(len(list(plan)), 0 + 2) + assert len(list(plan)) == 0 + 2 # new remote after traversing local, and starting remote plan = policy.make_query_plan() @@ -363,7 +400,7 @@ def test_modification_during_generation(self): next(plan) policy.on_up(new_host) # slice is already made, and we've consumed one - self.assertEqual(len(list(plan)), 0 + 1) + assert len(list(plan)) == 0 + 1 # remove remote after traversing local, and starting remote plan = policy.make_query_plan() @@ -371,7 +408,7 @@ def test_modification_during_generation(self): next(plan) policy.on_down(new_host) # slice is created with all present, and we've consumed one - self.assertEqual(len(list(plan)), 0 + 2) + assert len(list(plan)) == 0 + 2 # local DC disappears after finishing it, but not starting remote plan = policy.make_query_plan() @@ -380,7 +417,7 @@ def test_modification_during_generation(self): policy.on_down(hosts[0]) policy.on_down(hosts[1]) # dict traversal starts as normal - self.assertEqual(len(list(plan)), 0 + 2) + assert len(list(plan)) == 0 + 2 policy.on_up(hosts[0]) policy.on_up(hosts[1]) @@ -393,7 +430,7 @@ def test_modification_during_generation(self): policy.on_down(hosts[0]) policy.on_down(hosts[1]) # dict traversal has begun and consumed one - self.assertEqual(len(list(plan)), 0 + 1) + assert len(list(plan)) == 0 + 1 policy.on_up(hosts[0]) policy.on_up(hosts[1]) @@ -404,7 +441,7 @@ def test_modification_during_generation(self): policy.on_down(hosts[2]) 
policy.on_down(hosts[3]) # nothing left - self.assertEqual(len(list(plan)), 0 + 0) + assert len(list(plan)) == 0 + 0 policy.on_up(hosts[2]) policy.on_up(hosts[3]) @@ -415,7 +452,7 @@ def test_modification_during_generation(self): policy.on_down(hosts[2]) policy.on_down(hosts[3]) # we continue with remainder of original list - self.assertEqual(len(list(plan)), 0 + 1) + assert len(list(plan)) == 0 + 1 policy.on_up(hosts[2]) policy.on_up(hosts[3]) @@ -430,7 +467,7 @@ def test_modification_during_generation(self): policy.on_up(new_host) policy.on_up(another_host) # we continue with remainder of original list - self.assertEqual(len(list(plan)), 0 + 1) + assert len(list(plan)), 0 + 1 # remote DC disappears after finishing it plan = policy.make_query_plan() @@ -444,9 +481,9 @@ def test_modification_during_generation(self): for h in down_hosts: policy.on_down(h) # the last DC has two - self.assertEqual(len(list(plan)), 0 + 2) + assert len(list(plan)), 0 + 2 - def test_no_live_nodes(self): + def test_no_live_nodes(self, policy_specialization, constructor_args): """ Ensure query plan for a downed cluster will execute without errors """ @@ -457,25 +494,37 @@ def test_no_live_nodes(self): h.set_location_info("dc1", "rack1") hosts.append(h) - policy = DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=1) + policy = policy_specialization(*constructor_args, used_hosts_per_remote_dc=1) policy.populate(Mock(), hosts) for host in hosts: policy.on_down(host) qplan = list(policy.make_query_plan()) - self.assertEqual(qplan, []) + assert qplan == [] - def test_no_nodes(self): + def test_no_nodes(self, policy_specialization, constructor_args): """ Ensure query plan for an empty cluster will execute without errors """ - policy = DCAwareRoundRobinPolicy("dc1", used_hosts_per_remote_dc=1) + policy = policy_specialization(*constructor_args, used_hosts_per_remote_dc=1) policy.populate(None, []) qplan = list(policy.make_query_plan()) - self.assertEqual(qplan, []) + assert qplan == [] + + 
def test_wrong_dc(self, policy_specialization, constructor_args): + hosts = [Host(DefaultEndPoint(i), SimpleConvictionPolicy) for i in range(3)] + for h in hosts[:3]: + h.set_location_info("dc2", "rack2") + + policy = policy_specialization(*constructor_args, used_hosts_per_remote_dc=0) + policy.populate(Mock(), hosts) + qplan = list(policy.make_query_plan()) + assert len(qplan) == 0 + +class DCAwareRoundRobinPolicyTest(unittest.TestCase): def test_default_dc(self): host_local = Host(DefaultEndPoint(1), SimpleConvictionPolicy, 'local') @@ -488,35 +537,34 @@ def test_default_dc(self): # contact DC first policy = DCAwareRoundRobinPolicy() policy.populate(cluster, [host_none]) - self.assertFalse(policy.local_dc) + assert not policy.local_dc policy.on_add(host_local) policy.on_add(host_remote) - self.assertNotEqual(policy.local_dc, host_remote.datacenter) - self.assertEqual(policy.local_dc, host_local.datacenter) + assert policy.local_dc != host_remote.datacenter + assert policy.local_dc == host_local.datacenter # contact DC second policy = DCAwareRoundRobinPolicy() policy.populate(cluster, [host_none]) - self.assertFalse(policy.local_dc) + assert not policy.local_dc policy.on_add(host_remote) policy.on_add(host_local) - self.assertNotEqual(policy.local_dc, host_remote.datacenter) - self.assertEqual(policy.local_dc, host_local.datacenter) + assert policy.local_dc != host_remote.datacenter + assert policy.local_dc == host_local.datacenter # no DC policy = DCAwareRoundRobinPolicy() policy.populate(cluster, [host_none]) - self.assertFalse(policy.local_dc) + assert not policy.local_dc policy.on_add(host_none) - self.assertFalse(policy.local_dc) + assert not policy.local_dc # only other DC policy = DCAwareRoundRobinPolicy() policy.populate(cluster, [host_none]) - self.assertFalse(policy.local_dc) + assert not policy.local_dc policy.on_add(host_remote) - self.assertFalse(policy.local_dc) - + assert not policy.local_dc class TokenAwarePolicyTest(unittest.TestCase): @@ -1274,7 
+1322,7 @@ def test_hosts_with_hostname(self): self.assertEqual(sorted(qplan), [host]) self.assertEqual(policy.distance(host), HostDistance.LOCAL) - + def test_hosts_with_socket_hostname(self): hosts = [UnixSocketEndPoint('/tmp/scylla-workdir/cql.m')] policy = WhiteListRoundRobinPolicy(hosts)