diff --git a/conftest.py b/conftest.py index 22002bb..9b757a2 100644 --- a/conftest.py +++ b/conftest.py @@ -89,8 +89,27 @@ def start_consul_container(version, acl_master_token=None): merged_config = {**base_config, **acl_config} docker_config["environment"]["CONSUL_LOCAL_CONFIG"] = json.dumps(merged_config) - container = client.containers.run( - f"hashicorp/consul:{version}", command="agent -dev -client=0.0.0.0 -log-level trace", **docker_config + def start_consul_container_with_retry(client, command, version, docker_config, max_retries=3, retry_delay=2): # pylint: disable=inconsistent-return-statements + """ + Start a Consul container with retries as a few initial attempts sometimes fail. + """ + for attempt in range(max_retries): + try: + container = client.containers.run(f"hashicorp/consul:{version}", command=command, **docker_config) + return container + except docker.errors.APIError: + # Cleanup that stray container as it might cause a naming conflict + try: + container = client.containers.get(docker_config["name"]) + container.remove(force=True) + except docker.errors.NotFound: + pass + if attempt == max_retries - 1: + raise + time.sleep(retry_delay) + + container = start_consul_container_with_retry( + client, command="agent -dev -client=0.0.0.0 -log-level trace", version=version, docker_config=docker_config ) # Wait for Consul to be ready diff --git a/consul/aio.py b/consul/aio.py index d3e1ae3..7bcf335 100644 --- a/consul/aio.py +++ b/consul/aio.py @@ -1,4 +1,5 @@ import asyncio +from typing import Dict, Optional import aiohttp @@ -23,33 +24,39 @@ def __init__(self, *args, loop=None, connections_limit=None, connections_timeout session_kwargs["timeout"] = timeout self._session = aiohttp.ClientSession(connector=connector, **session_kwargs) - async def _request(self, callback, method, uri, data=None, connections_timeout=None): + async def _request( + self, callback, method, uri, headers: Optional[Dict[str, str]], data=None, connections_timeout=None + ): session_kwargs = {} if connections_timeout: timeout = aiohttp.ClientTimeout(total=connections_timeout) session_kwargs["timeout"] = timeout - resp = await self._session.request(method, uri, data=data, **session_kwargs) + resp = await self._session.request(method, uri, headers=headers, data=data, **session_kwargs) body = await resp.text(encoding="utf-8") if resp.status == 599: raise Timeout r = base.Response(resp.status, resp.headers, body) return callback(r) - def get(self, callback, path, params=None, connections_timeout=None): + def get(self, callback, path, params=None, headers: Optional[Dict[str, str]] = None, connections_timeout=None): uri = self.uri(path, params) - return self._request(callback, "GET", uri, connections_timeout=connections_timeout) + return self._request(callback, "GET", uri, headers=headers, connections_timeout=connections_timeout) - def put(self, callback, path, params=None, data="", connections_timeout=None): + def put( + self, callback, path, params=None, data="", headers: Optional[Dict[str, str]] = None, connections_timeout=None + ): uri = self.uri(path, params) - return self._request(callback, "PUT", uri, data=data, connections_timeout=connections_timeout) + return self._request(callback, "PUT", uri, headers=headers, data=data, connections_timeout=connections_timeout) - def delete(self, callback, path, params=None, connections_timeout=None): + def delete(self, callback, path, params=None, headers: Optional[Dict[str, str]] = None, connections_timeout=None): uri = self.uri(path, params) - return self._request(callback, "DELETE", uri, connections_timeout=connections_timeout) + return self._request(callback, "DELETE", uri, headers=headers, connections_timeout=connections_timeout) - def post(self, callback, path, params=None, data="", connections_timeout=None): + def post( + self, callback, path, params=None, data="", headers: Optional[Dict[str, str]] = None, connections_timeout=None + ): uri = self.uri(path, params) - return self._request(callback, "POST", uri, data=data, connections_timeout=connections_timeout) + return self._request(callback, "POST", uri, headers=headers, data=data, connections_timeout=connections_timeout) def close(self): return self._session.close() diff --git a/consul/api/acl/policy.py b/consul/api/acl/policy.py index 479707f..75f9941 100644 --- a/consul/api/acl/policy.py +++ b/consul/api/acl/policy.py @@ -15,10 +15,9 @@ def list(self, token=None): Requires a token with acl:read capability. ACLPermissionDenied raised otherwise """ params = [] - token = token or self.agent.token - if token: - params.append(("token", token)) - return self.agent.http.get(CB.json(), "/v1/acl/policies", params=params) + + headers = self.agent.prepare_headers(token) + return self.agent.http.get(CB.json(), "/v1/acl/policies", params=params, headers=headers) def read(self, uuid, token=None): """ @@ -28,10 +27,8 @@ def read(self, uuid, token=None): :return: selected Polic information """ params = [] - token = token or self.agent.token - if token: - params.append(("token", token)) - return self.agent.http.get(CB.json(), f"/v1/acl/policy/{uuid}", params=params) + headers = self.agent.prepare_headers(token) + return self.agent.http.get(CB.json(), f"/v1/acl/policy/{uuid}", params=params, headers=headers) def create(self, name, token=None, description=None, rules=None): """ @@ -44,17 +41,16 @@ def create(self, name, token=None, description=None, rules=None): :return: The cloned token information """ params = [] - token = token or self.agent.token - if token: - params.append(("token", token)) json_data = {"name": name} if rules: json_data["rules"] = json.dumps(rules) if description: json_data["Description"] = description + headers = self.agent.prepare_headers(token) return self.agent.http.put( CB.json(), "/v1/acl/policy", params=params, + headers=headers, data=json.dumps(json_data), ) diff --git a/consul/api/acl/token.py b/consul/api/acl/token.py index 1a686bb..0f698d2 100644 --- a/consul/api/acl/token.py +++ b/consul/api/acl/token.py @@ -15,10 +15,8 @@ def list(self, token=None): Requires a token with acl:read capability. ACLPermissionDenied raised otherwise """ params = [] - token = token or self.agent.token - if token: - params.append(("token", token)) - return self.agent.http.get(CB.json(), "/v1/acl/tokens", params=params) + headers = self.agent.prepare_headers(token) + return self.agent.http.get(CB.json(), "/v1/acl/tokens", params=params, headers=headers) def read(self, accessor_id, token=None): """ @@ -28,10 +26,8 @@ def read(self, accessor_id, token=None): :return: selected token information """ params = [] - token = token or self.agent.token - if token: - params.append(("token", token)) - return self.agent.http.get(CB.json(), f"/v1/acl/token/{accessor_id}", params=params) + headers = self.agent.prepare_headers(token) + return self.agent.http.get(CB.json(), f"/v1/acl/token/{accessor_id}", params=params, headers=headers) def delete(self, accessor_id, token=None): """ @@ -41,10 +37,8 @@ def delete(self, accessor_id, token=None): :return: True if the token was deleted """ params = [] - token = token or self.agent.token - if token: - params.append(("token", token)) - return self.agent.http.delete(CB.bool(), f"/v1/acl/token/{accessor_id}", params=params) + headers = self.agent.prepare_headers(token) + return self.agent.http.delete(CB.bool(), f"/v1/acl/token/{accessor_id}", params=params, headers=headers) def clone(self, accessor_id, token=None, description=""): """ @@ -55,15 +49,14 @@ def clone(self, accessor_id, token=None, description=""): :return: The cloned token information """ params = [] - token = token or self.agent.token - if token: - params.append(("token", token)) json_data = {"Description": description} + headers = self.agent.prepare_headers(token) return self.agent.http.put( CB.json(), f"/v1/acl/token/{accessor_id}/clone", params=params, + headers=headers, data=json.dumps(json_data), ) @@ -79,9 +72,6 @@ def create(self, token=None, accessor_id=None, secret_id=None, policies_id=None, :return: The cloned token information """ params = [] - token = token or self.agent.token - if token: - params.append(("token", token)) json_data = {} if accessor_id: @@ -93,10 +83,12 @@ def create(self, token=None, accessor_id=None, secret_id=None, policies_id=None, if policies_id: json_data["Policies"] = [{"ID": policy} for policy in policies_id] + headers = self.agent.prepare_headers(token) return self.agent.http.put( CB.json(), "/v1/acl/token", params=params, + headers=headers, data=json.dumps(json_data), ) @@ -111,18 +103,17 @@ def update(self, accessor_id, token=None, secret_id=None, description=""): :return: The updated token information """ params = [] - token = token or self.agent.token - if token: - params.append(("token", token)) json_data = {"AccessorID": accessor_id} if secret_id: json_data["SecretID"] = secret_id if description: json_data["Description"] = description + headers = self.agent.prepare_headers(token) return self.agent.http.put( CB.json(), f"/v1/acl/token/{accessor_id}", params=params, + headers=headers, data=json.dumps(json_data), ) diff --git a/consul/api/agent.py b/consul/api/agent.py index ba02389..df999b2 100644 --- a/consul/api/agent.py +++ b/consul/api/agent.py @@ -89,11 +89,9 @@ def maintenance(self, enable, reason=None, token=None): params.append(("enable", enable)) if reason: params.append(("reason", reason)) - token = token or self.agent.token - if token: - params.append(("token", token)) - return self.agent.http.put(CB.bool(), "/v1/agent/maintenance", params=params) + headers = self.agent.prepare_headers(token) + return self.agent.http.put(CB.bool(), "/v1/agent/maintenance", params=params, headers=headers) def join(self, address, wan=False, token=None): """ @@ -111,11 +109,8 @@ def join(self, address, wan=False, token=None): if wan: params.append(("wan", 1)) - token = token or self.agent.token - if token: - params.append(("token", token)) - - return self.agent.http.put(CB.bool(), f"/v1/agent/join/{address}", params=params) + headers = self.agent.prepare_headers(token) + return self.agent.http.put(CB.bool(), f"/v1/agent/join/{address}", params=params, headers=headers) def force_leave(self, node, token=None): """ @@ -131,11 +126,8 @@ def force_leave(self, node, token=None): params = [] - token = token or self.agent.token - if token: - params.append(("token", token)) - - return self.agent.http.put(CB.bool(), f"/v1/agent/force-leave/{node}", params=params) + headers = self.agent.prepare_headers(token) + return self.agent.http.put(CB.bool(), f"/v1/agent/force-leave/{node}", params=params, headers=headers) class Service: def __init__(self, agent): @@ -231,11 +223,10 @@ def register( ) params = [] - token = token or self.agent.token - if token: - params.append(("token", token)) - - return self.agent.http.put(CB.bool(), "/v1/agent/service/register", params=params, data=json.dumps(payload)) + headers = self.agent.prepare_headers(token) + return self.agent.http.put( + CB.bool(), "/v1/agent/service/register", params=params, headers=headers, data=json.dumps(payload) + ) def deregister(self, service_id, token=None): """ @@ -244,11 +235,11 @@ def deregister(self, service_id, token=None): there is an associated check, that is also deregistered. """ params = [] - token = token or self.agent.token - if token: - params.append(("token", token)) + headers = self.agent.prepare_headers(token) - return self.agent.http.put(CB.bool(), f"/v1/agent/service/deregister/{service_id}", params=params) + return self.agent.http.put( + CB.bool(), f"/v1/agent/service/deregister/{service_id}", params=params, headers=headers + ) def maintenance(self, service_id, enable, reason=None, token=None): """ @@ -271,11 +262,11 @@ def maintenance(self, service_id, enable, reason=None, token=None): if reason: params.append(("reason", reason)) - token = token or self.agent.token - if token: - params.append(("token", token)) + headers = self.agent.prepare_headers(token) - return self.agent.http.put(CB.bool(), f"/v1/agent/service/maintenance/{service_id}", params=params) + return self.agent.http.put( + CB.bool(), f"/v1/agent/service/maintenance/{service_id}", params=params, headers=headers + ) class Check: def __init__(self, agent): @@ -344,22 +335,21 @@ def register( payload["serviceid"] = service_id params = [] - token = token or self.agent.token - if token: - params.append(("token", token)) - - return self.agent.http.put(CB.bool(), "/v1/agent/check/register", params=params, data=json.dumps(payload)) + headers = self.agent.prepare_headers(token) + return self.agent.http.put( + CB.bool(), "/v1/agent/check/register", params=params, headers=headers, data=json.dumps(payload) + ) def deregister(self, check_id, token=None): """ Remove a check from the local agent. """ params = [] - token = token or self.agent.token - if token: - params.append(("token", token)) + headers = self.agent.prepare_headers(token) - return self.agent.http.put(CB.bool(), f"/v1/agent/check/deregister/{check_id}", params=params) + return self.agent.http.put( + CB.bool(), f"/v1/agent/check/deregister/{check_id}", params=params, headers=headers + ) def ttl_pass(self, check_id, notes=None, token=None): """ @@ -369,11 +359,9 @@ def ttl_pass(self, check_id, notes=None, token=None): params = [] if notes: params.append(("note", notes)) - token = token or self.agent.token - if token: - params.append(("token", token)) + headers = self.agent.prepare_headers(token) - return self.agent.http.put(CB.bool(), f"/v1/agent/check/pass/{check_id}", params=params) + return self.agent.http.put(CB.bool(), f"/v1/agent/check/pass/{check_id}", params=params, headers=headers) def ttl_fail(self, check_id, notes=None, token=None): """ @@ -384,11 +372,9 @@ def ttl_fail(self, check_id, notes=None, token=None): params = [] if notes: params.append(("note", notes)) - token = token or self.agent.token - if token: - params.append(("token", token)) + headers = self.agent.prepare_headers(token) - return self.agent.http.put(CB.bool(), f"/v1/agent/check/fail/{check_id}", params=params) + return self.agent.http.put(CB.bool(), f"/v1/agent/check/fail/{check_id}", params=params, headers=headers) def ttl_warn(self, check_id, notes=None, token=None): """ @@ -399,11 +385,9 @@ def ttl_warn(self, check_id, notes=None, token=None): params = [] if notes: params.append(("note", notes)) - token = token or self.agent.token - if token: - params.append(("token", token)) + headers = self.agent.prepare_headers(token) - return self.agent.http.put(CB.bool(), f"/v1/agent/check/warn/{check_id}", params=params) + return self.agent.http.put(CB.bool(), f"/v1/agent/check/warn/{check_id}", params=params, headers=headers) class Connect: def __init__(self, agent): @@ -429,12 +413,10 @@ def authorize(self, target, client_cert_uri, client_cert_serial, token=None): payload = {"Target": target, "ClientCertURI": client_cert_uri, "ClientCertSerial": client_cert_serial} params = [] - token = token or self.agent.token - if token: - params.append(("token", token)) + headers = self.agent.prepare_headers(token) return self.agent.http.put( - CB.json(), "/v1/agent/connect/authorize", params=params, data=json.dumps(payload) + CB.json(), "/v1/agent/connect/authorize", params=params, headers=headers, data=json.dumps(payload) ) class CA: @@ -446,8 +428,8 @@ def roots(self): def leaf(self, service, token=None): params = [] - token = token or self.agent.token - if token: - params.append(("token", token)) + headers = self.agent.prepare_headers(token) - return self.agent.http.get(CB.json(), f"/v1/agent/connect/ca/leaf/{service}", params=params) + return self.agent.http.get( + CB.json(), f"/v1/agent/connect/ca/leaf/{service}", params=params, headers=headers + ) diff --git a/consul/api/catalog.py b/consul/api/catalog.py index e868811..8bba8f3 100644 --- a/consul/api/catalog.py +++ b/consul/api/catalog.py @@ -82,7 +82,11 @@ def register(self, node, address, service=None, check=None, dc=None, token=None, if node_meta: for nodemeta_name, nodemeta_value in node_meta.items(): params.append(("node-meta", f"{nodemeta_name}:{nodemeta_value}")) - return self.agent.http.put(CB.bool(), "/v1/catalog/register", data=json.dumps(data), params=params) + + headers = self.agent.prepare_headers(token) + return self.agent.http.put( + CB.bool(), "/v1/catalog/register", data=json.dumps(data), params=params, headers=headers + ) def deregister(self, node, service_id=None, check_id=None, dc=None, token=None): """ @@ -112,7 +116,8 @@ def deregister(self, node, service_id=None, check_id=None, dc=None, token=None): token = token or self.agent.token if token: data["WriteRequest"] = {"Token": token} - return self.agent.http.put(CB.bool(), "/v1/catalog/deregister", data=json.dumps(data)) + headers = self.agent.prepare_headers(token) + return self.agent.http.put(CB.bool(), "/v1/catalog/deregister", headers=headers, data=json.dumps(data)) def datacenters(self): """ @@ -168,16 +173,15 @@ def nodes(self, index=None, wait=None, consistency=None, dc=None, near=None, tok params.append(("wait", wait)) if near: params.append(("near", near)) - token = token or self.agent.token - if token: - params.append(("token", token)) + consistency = consistency or self.agent.consistency if consistency in ("consistent", "stale"): params.append((consistency, "1")) if node_meta: for nodemeta_name, nodemeta_value in node_meta.items(): params.append(("node-meta", f"{nodemeta_name}:{nodemeta_value}")) - return self.agent.http.get(CB.json(index=True), "/v1/catalog/nodes", params=params) + headers = self.agent.prepare_headers(token) + return self.agent.http.get(CB.json(index=True), "/v1/catalog/nodes", params=params, headers=headers) def services(self, index=None, wait=None, consistency=None, dc=None, token=None, node_meta=None): """ @@ -223,16 +227,14 @@ def services(self, index=None, wait=None, consistency=None, dc=None, token=None, params.append(("index", index)) if wait: params.append(("wait", wait)) - token = token or self.agent.token - if token: - params.append(("token", token)) consistency = consistency or self.agent.consistency if consistency in ("consistent", "stale"): params.append((consistency, "1")) if node_meta: for nodemeta_name, nodemeta_value in node_meta.items(): params.append(("node-meta", f"{nodemeta_name}:{nodemeta_value}")) - return self.agent.http.get(CB.json(index=True), "/v1/catalog/services", params=params) + headers = self.agent.prepare_headers(token) + return self.agent.http.get(CB.json(index=True), "/v1/catalog/services", params=params, headers=headers) def node(self, node, index=None, wait=None, consistency=None, dc=None, token=None): """ @@ -288,13 +290,11 @@ def node(self, node, index=None, wait=None, consistency=None, dc=None, token=Non params.append(("index", index)) if wait: params.append(("wait", wait)) - token = token or self.agent.token - if token: - params.append(("token", token)) consistency = consistency or self.agent.consistency if consistency in ("consistent", "stale"): params.append((consistency, "1")) - return self.agent.http.get(CB.json(index=True), f"/v1/catalog/node/{node}", params=params) + headers = self.agent.prepare_headers(token) + return self.agent.http.get(CB.json(index=True), f"/v1/catalog/node/{node}", params=params, headers=headers) def _service( self, @@ -320,16 +320,14 @@ def _service( params.append(("wait", wait)) if near: params.append(("near", near)) - token = token or self.agent.token - if token: - params.append(("token", token)) consistency = consistency or self.agent.consistency if consistency in ("consistent", "stale"): params.append((consistency, "1")) if node_meta: for nodemeta_name, nodemeta_value in node_meta.items(): params.append(("node-meta", f"{nodemeta_name}:{nodemeta_value}")) - return self.agent.http.get(CB.json(index=True), internal_uri, params=params) + headers = self.agent.prepare_headers(token) + return self.agent.http.get(CB.json(index=True), internal_uri, params=params, headers=headers) def service(self, service, **kwargs): """ diff --git a/consul/api/connect.py b/consul/api/connect.py index c29bd29..768d6ee 100644 --- a/consul/api/connect.py +++ b/consul/api/connect.py @@ -13,16 +13,12 @@ def __init__(self, agent): def roots(self, pem=False, token=None): params = [] params.append(("pem", int(pem))) - token = token or self.agent.token - if token: - params.append(("token", token)) - return self.agent.http.get(CB.json(), "/v1/connect/ca/roots", params=params) + headers = self.agent.prepare_headers(token) + return self.agent.http.get(CB.json(), "/v1/connect/ca/roots", params=params, headers=headers) def configuration(self, token=None): params = [] - token = token or self.agent.token - if token: - params.append(("token", token)) - return self.agent.http.get(CB.json(), "/v1/connect/ca/configuration", params=params) + headers = self.agent.prepare_headers(token) + return self.agent.http.get(CB.json(), "/v1/connect/ca/configuration", params=params, headers=headers) diff --git a/consul/api/event.py b/consul/api/event.py index f548a26..9bf985f 100644 --- a/consul/api/event.py +++ b/consul/api/event.py @@ -50,11 +50,9 @@ def fire(self, name, body="", node=None, service=None, tag=None, token=None): params.append(("service", service)) if tag is not None: params.append(("tag", tag)) - token = token or self.agent.token - if token: - params.append(("token", token)) - return self.agent.http.put(CB.json(), f"/v1/event/fire/{name}", params=params, data=body) + headers = self.agent.prepare_headers(token) + return self.agent.http.put(CB.json(), f"/v1/event/fire/{name}", params=params, headers=headers, data=body) def list(self, name=None, index=None, wait=None): """ diff --git a/consul/api/health.py b/consul/api/health.py index 1d40686..97ba229 100644 --- a/consul/api/health.py +++ b/consul/api/health.py @@ -35,13 +35,11 @@ def _service( params.append(("dc", dc)) if near: params.append(("near", near)) - token = token or self.agent.token - if token: - params.append(("token", token)) if node_meta: for nodemeta_name, nodemeta_value in node_meta.items(): params.append(("node-meta", f"{nodemeta_name}:{nodemeta_value}")) - return self.agent.http.get(CB.json(index=True), internal_uri, params=params) + headers = self.agent.prepare_headers(token) + return self.agent.http.get(CB.json(index=True), internal_uri, params=params, headers=headers) def service(self, service, **kwargs): """ @@ -122,13 +120,11 @@ def checks(self, service, index=None, wait=None, dc=None, near=None, token=None, params.append(("dc", dc)) if near: params.append(("near", near)) - token = token or self.agent.token - if token: - params.append(("token", token)) if node_meta: for nodemeta_name, nodemeta_value in node_meta.items(): params.append(("node-meta", f"{nodemeta_name}:{nodemeta_value}")) - return self.agent.http.get(CB.json(index=True), f"/v1/health/checks/{service}", params=params) + headers = self.agent.prepare_headers(token) + return self.agent.http.get(CB.json(index=True), f"/v1/health/checks/{service}", params=params, headers=headers) def state(self, name, index=None, wait=None, dc=None, near=None, token=None, node_meta=None): """ @@ -171,13 +167,11 @@ def state(self, name, index=None, wait=None, dc=None, near=None, token=None, nod params.append(("dc", dc)) if near: params.append(("near", near)) - token = token or self.agent.token - if token: - params.append(("token", token)) if node_meta: for nodemeta_name, nodemeta_value in node_meta.items(): params.append(("node-meta", f"{nodemeta_name}:{nodemeta_value}")) - return self.agent.http.get(CB.json(index=True), f"/v1/health/state/{name}", params=params) + headers = self.agent.prepare_headers(token) + return self.agent.http.get(CB.json(index=True), f"/v1/health/state/{name}", params=params, headers=headers) def node(self, node, index=None, wait=None, dc=None, token=None): """ @@ -205,8 +199,6 @@ def node(self, node, index=None, wait=None, dc=None, token=None): dc = dc or self.agent.dc if dc: params.append(("dc", dc)) - token = token or self.agent.token - if token: - params.append(("token", token)) - return self.agent.http.get(CB.json(index=True), f"/v1/health/node/{node}", params=params) + headers = self.agent.prepare_headers(token) + return self.agent.http.get(CB.json(index=True), f"/v1/health/node/{node}", params=params, headers=headers) diff --git a/consul/api/kv.py b/consul/api/kv.py index 567366a..b162d3f 100644 --- a/consul/api/kv.py +++ b/consul/api/kv.py @@ -71,9 +71,6 @@ def get( params.append(("wait", wait)) if recurse: params.append(("recurse", "1")) - token = token or self.agent.token - if token: - params.append(("token", token)) dc = dc or self.agent.dc if dc: params.append(("dc", dc)) @@ -95,8 +92,10 @@ def get( http_kwargs = {} if connections_timeout: http_kwargs["connections_timeout"] = connections_timeout + + headers = self.agent.prepare_headers(token) return self.agent.http.get( - CB.json(index=True, decode=decode, one=one), f"/v1/kv/{key}", params=params, **http_kwargs + CB.json(index=True, decode=decode, one=one), f"/v1/kv/{key}", params=params, headers=headers, **http_kwargs ) def put( @@ -156,16 +155,16 @@ def put( params.append(("acquire", acquire)) if release: params.append(("release", release)) - token = token or self.agent.token - if token: - params.append(("token", token)) dc = dc or self.agent.dc if dc: params.append(("dc", dc)) http_kwargs = {} if connections_timeout: http_kwargs["connections_timeout"] = connections_timeout - return self.agent.http.put(CB.json(), f"/v1/kv/{key}", params=params, data=value, **http_kwargs) + headers = self.agent.prepare_headers(token) + return self.agent.http.put( + CB.json(), f"/v1/kv/{key}", params=params, headers=headers, data=value, **http_kwargs + ) def delete(self, key, recurse=None, cas=None, token=None, dc=None, connections_timeout=None): """ @@ -193,13 +192,11 @@ def delete(self, key, recurse=None, cas=None, token=None, dc=None, connections_t params.append(("recurse", "1")) if cas is not None: params.append(("cas", cas)) - token = token or self.agent.token - if token: - params.append(("token", token)) dc = dc or self.agent.dc if dc: params.append(("dc", dc)) http_kwargs = {} if connections_timeout: http_kwargs["connections_timeout"] = connections_timeout - return self.agent.http.delete(CB.json(), f"/v1/kv/{key}", params=params, **http_kwargs) + headers = self.agent.prepare_headers(token) + return self.agent.http.delete(CB.json(), f"/v1/kv/{key}", params=params, headers=headers, **http_kwargs) diff --git a/consul/api/query.py b/consul/api/query.py index b34a8f3..b9d01c0 100644 --- a/consul/api/query.py +++ b/consul/api/query.py @@ -19,13 +19,11 @@ def list(self, dc=None, token=None): *token* is an optional `ACL token`_ to apply to this request. """ params = [] - token = token or self.agent.token - if token: - params.append(("token", token)) if dc: params.append(("dc", dc)) - return self.agent.http.get(CB.json(), "/v1/query", params=params) + headers = self.agent.prepare_headers(token) + return self.agent.http.get(CB.json(), "/v1/query", params=params, headers=headers) def _query_data( self, @@ -165,12 +163,10 @@ def get(self, query_id, token=None, dc=None): default the datacenter of the host is used. """ params = [] - token = token or self.agent.token - if token: - params.append(("token", token)) if dc: params.append(("dc", dc)) - return self.agent.http.get(CB.json(), f"/v1/query/{query_id}", params=params) + headers = self.agent.prepare_headers(token) + return self.agent.http.get(CB.json(), f"/v1/query/{query_id}", params=params, headers=headers) def delete(self, query_id, token=None, dc=None): """ @@ -184,12 +180,10 @@ def delete(self, query_id, token=None, dc=None): default the datacenter of the host is used. """ params = [] - token = token or self.agent.token - if token: - params.append(("token", token)) if dc: params.append(("dc", dc)) - return self.agent.http.delete(CB.bool(), f"/v1/query/{query_id}", params=params) + headers = self.agent.prepare_headers(token) + return self.agent.http.delete(CB.bool(), f"/v1/query/{query_id}", params=params, headers=headers) def execute(self, query, token=None, dc=None, near=None, limit=None): """ @@ -209,16 +203,14 @@ def execute(self, query, token=None, dc=None, near=None, limit=None): of nodes. This is applied after any sorting or shuffling. """ params = [] - token = token or self.agent.token - if token: - params.append(("token", token)) if dc: params.append(("dc", dc)) if near: params.append(("near", near)) if limit: params.append(("limit", limit)) - return self.agent.http.get(CB.json(), f"/v1/query/{query}/execute", params=params) + headers = self.agent.prepare_headers(token) + return self.agent.http.get(CB.json(), f"/v1/query/{query}/execute", params=params, headers=headers) def explain(self, query, token=None, dc=None): """ @@ -232,9 +224,7 @@ def explain(self, query, token=None, dc=None): default the datacenter of the host is used. """ params = [] - token = token or self.agent.token - if token: - params.append(("token", token)) if dc: params.append(("dc", dc)) - return self.agent.http.get(CB.json(), f"/v1/query/{query}/explain", params=params) + headers = self.agent.prepare_headers(token) + return self.agent.http.get(CB.json(), f"/v1/query/{query}/explain", params=params, headers=headers) diff --git a/consul/base.py b/consul/base.py index dd62550..1c9df53 100644 --- a/consul/base.py +++ b/consul/base.py @@ -3,6 +3,7 @@ import logging import os import urllib +from typing import Dict, Optional from consul.api.acl import ACL from consul.api.agent import Agent @@ -45,19 +46,19 @@ def uri(self, path, params=None): return uri @abc.abstractmethod - def get(self, callback, path, params=None): + def get(self, callback, path, params=None, headers: Optional[Dict[str, str]] = None): raise NotImplementedError @abc.abstractmethod - def put(self, callback, path, params=None, data=""): + def put(self, callback, path, params=None, data="", headers: Optional[Dict[str, str]] = None): raise NotImplementedError @abc.abstractmethod - def delete(self, callback, path, params=None): + def delete(self, callback, path, params=None, headers: Optional[Dict[str, str]] = None): raise NotImplementedError @abc.abstractmethod - def post(self, callback, path, params=None, data=""): + def post(self, callback, path, params=None, data="", headers: Optional[Dict[str, str]] = None): raise NotImplementedError @abc.abstractmethod @@ -151,3 +152,9 @@ async def __aexit__(self, exc_type, exc, tb): @abc.abstractmethod def http_connect(self, host, port, scheme, verify=True, cert=None): pass + + def prepare_headers(self, token: Optional[str] = None) -> Dict[str, str]: + headers = {} + if token or self.token: + headers["X-Consul-Token"] = token or self.token + return headers diff --git a/consul/std.py b/consul/std.py index c6f03fb..5c2e159 100644 --- a/consul/std.py +++ b/consul/std.py @@ -1,3 +1,5 @@ +from typing import Dict, Optional + import requests from consul import base @@ -14,21 +16,25 @@ def response(self, response): response.encoding = "utf-8" return base.Response(response.status_code, response.headers, response.text) - def get(self, callback, path, params=None): + def get(self, callback, path, params=None, headers: Optional[Dict[str, str]] = None): uri = self.uri(path, params) - return callback(self.response(self.session.get(uri, verify=self.verify, cert=self.cert))) + return callback(self.response(self.session.get(uri, headers=headers, verify=self.verify, cert=self.cert))) - def put(self, callback, path, params=None, data=""): + def put(self, callback, path, params=None, data="", headers: Optional[Dict[str, str]] = None): uri = self.uri(path, params) - return callback(self.response(self.session.put(uri, data=data, verify=self.verify, cert=self.cert))) + return callback( + self.response(self.session.put(uri, headers=headers, data=data, verify=self.verify, cert=self.cert)) + ) - def delete(self, callback, path, params=None): + def delete(self, callback, path, params=None, headers: Optional[Dict[str, str]] = None): uri = self.uri(path, params) - return callback(self.response(self.session.delete(uri, verify=self.verify, cert=self.cert))) + return callback(self.response(self.session.delete(uri, headers=headers, verify=self.verify, cert=self.cert))) - def post(self, callback, path, params=None, data=""): + def post(self, callback, path, params=None, data="", headers: Optional[Dict[str, str]] = None): uri = self.uri(path, params) - return callback(self.response(self.session.post(uri, data=data, verify=self.verify, cert=self.cert))) + return callback( + self.response(self.session.post(uri, headers=headers, data=data, verify=self.verify, cert=self.cert)) + ) def close(self): pass diff --git a/tests/test_base.py b/tests/test_base.py index 079c63d..f448701 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -6,21 +6,21 @@ import consul import consul.check -Request = collections.namedtuple("Request", ["method", "path", "params", "data"]) +Request = collections.namedtuple("Request", ["method", "path", "params", "headers", "data"]) class HTTPClient: def __init__(self, host=None, port=None, scheme=None, verify=True, cert=None): pass - def get(self, callback, path, params=None): # pylint: disable=unused-argument - return Request("get", path, params, None) + def get(self, callback, path, params=None, headers=None): # pylint: disable=unused-argument + return Request("get", path, params, headers, None) - def put(self, callback, path, params=None, data=""): # pylint: disable=unused-argument - return Request("put", path, params, data) + def put(self, callback, path, params=None, headers=None, data=""): # pylint: disable=unused-argument + return Request("put", path, params, headers, data) - def delete(self, callback, path, params=None): # pylint: disable=unused-argument - return Request("delete", path, params, None) + def delete(self, callback, path, params=None, headers=None): # pylint: disable=unused-argument + return Request("delete", path, params, headers, None) class Consul(consul.base.Consul):