Skip to content

Commit

Permalink
Support any user token types (#1053)
Browse files Browse the repository at this point in the history
**Pull Request Checklist**
- [x] Fixes #1016
- [x] Tests added
- [x] Documentation/examples added
- [x] [Good commit messages](https://cbea.ms/git-commit/) and/or PR
title

See #1016 for a description provided by users. This PR adds support for
user specified Bearer tokens while providing a backwards compatible
experience.

---------

Signed-off-by: Flaviu Vadan <[email protected]>
Co-authored-by: Sambhav Kothari <[email protected]>
  • Loading branch information
flaviuvadan and sambhav authored May 7, 2024
1 parent c761cbe commit b841f6c
Show file tree
Hide file tree
Showing 4 changed files with 172 additions and 78 deletions.
24 changes: 22 additions & 2 deletions scripts/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,7 @@ def __str__(self) -> str:
params = "None"

# headers
headers = "{'Authorization': f'Bearer {self.token}'"
headers = "{'Authorization': self.token"
if self.method.lower() == "post" or self.method.lower() == "put":
headers += f", 'Content-Type': '{self.consumes}'"
headers += "}"
Expand Down Expand Up @@ -424,7 +424,27 @@ def __init__(
\"\"\"{models_type} service constructor.\"\"\"
self.host = cast(str, host or global_config.host)
self.verify_ssl = verify_ssl if verify_ssl is not None else global_config.verify_ssl
self.token = token or global_config.token
# some users reported in https://github.com/argoproj-labs/hera/issues/1016 that it can be a bit awkward for
# Hera to assume a `Bearer` prefix on behalf of users. Some might pass it and some might not. Therefore, Hera
# only prefixes the token with `Bearer ` if it's not already specified and lets the uses specify it otherwise.
# Note that the `Bearer` token can be specified through the global configuration as well. In order to deliver
# a fix on Hera V5 without introducing breaking changes, we have to support both
global_config_token = global_config.token # call only once because it can be a user specified function!
def format_token(t):
parts = t.strip().split()
if len(parts) == 1:
return "Bearer " + t
return t
if token:
self.token: Optional[str] = format_token(token)
elif global_config_token:
self.token = format_token(global_config_token)
else:
self.token = None
self.namespace = namespace or global_config.namespace
"""

Expand Down
70 changes: 45 additions & 25 deletions src/hera/events/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,27 @@ def __init__(
"""Events service constructor."""
self.host = cast(str, host or global_config.host)
self.verify_ssl = verify_ssl if verify_ssl is not None else global_config.verify_ssl
self.token = token or global_config.token

# some users reported in https://github.com/argoproj-labs/hera/issues/1016 that it can be a bit awkward for
# Hera to assume a `Bearer` prefix on behalf of users. Some might pass it and some might not. Therefore, Hera
# only prefixes the token with `Bearer ` if it's not already specified and lets the uses specify it otherwise.
# Note that the `Bearer` token can be specified through the global configuration as well. In order to deliver
# a fix on Hera V5 without introducing breaking changes, we have to support both
global_config_token = global_config.token # call only once because it can be a user specified function!

def format_token(t):
parts = t.strip().split()
if len(parts) == 1:
return "Bearer " + t
return t

if token:
self.token: Optional[str] = format_token(token)
elif global_config_token:
self.token = format_token(global_config_token)
else:
self.token = None

self.namespace = namespace or global_config.namespace

def list_event_sources(
Expand Down Expand Up @@ -83,7 +103,7 @@ def list_event_sources(
"listOptions.limit": limit,
"listOptions.continue": continue_,
},
headers={"Authorization": f"Bearer {self.token}"},
headers={"Authorization": self.token},
data=None,
verify=self.verify_ssl,
)
Expand All @@ -101,7 +121,7 @@ def create_event_source(self, req: CreateEventSourceRequest, namespace: Optional
namespace=namespace if namespace is not None else self.namespace
),
params=None,
headers={"Authorization": f"Bearer {self.token}", "Content-Type": "application/json"},
headers={"Authorization": self.token, "Content-Type": "application/json"},
data=req.json(
exclude_none=True, by_alias=True, skip_defaults=True, exclude_unset=True, exclude_defaults=True
),
Expand All @@ -121,7 +141,7 @@ def get_event_source(self, name: str, namespace: Optional[str] = None) -> EventS
name=name, namespace=namespace if namespace is not None else self.namespace
),
params=None,
headers={"Authorization": f"Bearer {self.token}"},
headers={"Authorization": self.token},
data=None,
verify=self.verify_ssl,
)
Expand All @@ -141,7 +161,7 @@ def update_event_source(
name=name, namespace=namespace if namespace is not None else self.namespace
),
params=None,
headers={"Authorization": f"Bearer {self.token}", "Content-Type": "application/json"},
headers={"Authorization": self.token, "Content-Type": "application/json"},
data=req.json(
exclude_none=True, by_alias=True, skip_defaults=True, exclude_unset=True, exclude_defaults=True
),
Expand Down Expand Up @@ -178,7 +198,7 @@ def delete_event_source(
"deleteOptions.propagationPolicy": propagation_policy,
"deleteOptions.dryRun": dry_run,
},
headers={"Authorization": f"Bearer {self.token}"},
headers={"Authorization": self.token},
data=None,
verify=self.verify_ssl,
)
Expand All @@ -196,7 +216,7 @@ def receive_event(self, discriminator: str, req: Item, namespace: Optional[str]
discriminator=discriminator, namespace=namespace if namespace is not None else self.namespace
),
params=None,
headers={"Authorization": f"Bearer {self.token}", "Content-Type": "application/json"},
headers={"Authorization": self.token, "Content-Type": "application/json"},
data=req.json(
exclude_none=True, by_alias=True, skip_defaults=True, exclude_unset=True, exclude_defaults=True
),
Expand All @@ -214,7 +234,7 @@ def get_info(self) -> InfoResponse:
resp = requests.get(
url=urljoin(self.host, "api/v1/info"),
params=None,
headers={"Authorization": f"Bearer {self.token}"},
headers={"Authorization": self.token},
data=None,
verify=self.verify_ssl,
)
Expand Down Expand Up @@ -254,7 +274,7 @@ def list_sensors(
"listOptions.limit": limit,
"listOptions.continue": continue_,
},
headers={"Authorization": f"Bearer {self.token}"},
headers={"Authorization": self.token},
data=None,
verify=self.verify_ssl,
)
Expand All @@ -272,7 +292,7 @@ def create_sensor(self, req: CreateSensorRequest, namespace: Optional[str] = Non
namespace=namespace if namespace is not None else self.namespace
),
params=None,
headers={"Authorization": f"Bearer {self.token}", "Content-Type": "application/json"},
headers={"Authorization": self.token, "Content-Type": "application/json"},
data=req.json(
exclude_none=True, by_alias=True, skip_defaults=True, exclude_unset=True, exclude_defaults=True
),
Expand All @@ -292,7 +312,7 @@ def get_sensor(self, name: str, namespace: Optional[str] = None, resource_versio
name=name, namespace=namespace if namespace is not None else self.namespace
),
params={"getOptions.resourceVersion": resource_version},
headers={"Authorization": f"Bearer {self.token}"},
headers={"Authorization": self.token},
data=None,
verify=self.verify_ssl,
)
Expand All @@ -310,7 +330,7 @@ def update_sensor(self, name: str, req: UpdateSensorRequest, namespace: Optional
name=name, namespace=namespace if namespace is not None else self.namespace
),
params=None,
headers={"Authorization": f"Bearer {self.token}", "Content-Type": "application/json"},
headers={"Authorization": self.token, "Content-Type": "application/json"},
data=req.json(
exclude_none=True, by_alias=True, skip_defaults=True, exclude_unset=True, exclude_defaults=True
),
Expand Down Expand Up @@ -347,7 +367,7 @@ def delete_sensor(
"deleteOptions.propagationPolicy": propagation_policy,
"deleteOptions.dryRun": dry_run,
},
headers={"Authorization": f"Bearer {self.token}"},
headers={"Authorization": self.token},
data=None,
verify=self.verify_ssl,
)
Expand Down Expand Up @@ -387,7 +407,7 @@ def watch_event_sources(
"listOptions.limit": limit,
"listOptions.continue": continue_,
},
headers={"Authorization": f"Bearer {self.token}"},
headers={"Authorization": self.token},
data=None,
verify=self.verify_ssl,
)
Expand Down Expand Up @@ -437,7 +457,7 @@ def event_sources_logs(
"podLogOptions.limitBytes": limit_bytes,
"podLogOptions.insecureSkipTLSVerifyBackend": insecure_skip_tls_verify_backend,
},
headers={"Authorization": f"Bearer {self.token}"},
headers={"Authorization": self.token},
data=None,
verify=self.verify_ssl,
)
Expand Down Expand Up @@ -477,7 +497,7 @@ def watch_events(
"listOptions.limit": limit,
"listOptions.continue": continue_,
},
headers={"Authorization": f"Bearer {self.token}"},
headers={"Authorization": self.token},
data=None,
verify=self.verify_ssl,
)
Expand Down Expand Up @@ -517,7 +537,7 @@ def watch_sensors(
"listOptions.limit": limit,
"listOptions.continue": continue_,
},
headers={"Authorization": f"Bearer {self.token}"},
headers={"Authorization": self.token},
data=None,
verify=self.verify_ssl,
)
Expand Down Expand Up @@ -565,7 +585,7 @@ def sensors_logs(
"podLogOptions.limitBytes": limit_bytes,
"podLogOptions.insecureSkipTLSVerifyBackend": insecure_skip_tls_verify_backend,
},
headers={"Authorization": f"Bearer {self.token}"},
headers={"Authorization": self.token},
data=None,
verify=self.verify_ssl,
)
Expand All @@ -581,7 +601,7 @@ def get_user_info(self) -> GetUserInfoResponse:
resp = requests.get(
url=urljoin(self.host, "api/v1/userinfo"),
params=None,
headers={"Authorization": f"Bearer {self.token}"},
headers={"Authorization": self.token},
data=None,
verify=self.verify_ssl,
)
Expand All @@ -597,7 +617,7 @@ def get_version(self) -> Version:
resp = requests.get(
url=urljoin(self.host, "api/v1/version"),
params=None,
headers={"Authorization": f"Bearer {self.token}"},
headers={"Authorization": self.token},
data=None,
verify=self.verify_ssl,
)
Expand Down Expand Up @@ -631,7 +651,7 @@ def get_artifact_file(
namespace=namespace if namespace is not None else self.namespace,
),
params=None,
headers={"Authorization": f"Bearer {self.token}"},
headers={"Authorization": self.token},
data=None,
verify=self.verify_ssl,
)
Expand All @@ -649,7 +669,7 @@ def get_output_artifact_by_uid(self, uid: str, node_id: str, artifact_name: str)
uid=uid, nodeId=node_id, artifactName=artifact_name
),
params=None,
headers={"Authorization": f"Bearer {self.token}"},
headers={"Authorization": self.token},
data=None,
verify=self.verify_ssl,
)
Expand All @@ -670,7 +690,7 @@ def get_output_artifact(self, name: str, node_id: str, artifact_name: str, names
namespace=namespace if namespace is not None else self.namespace,
),
params=None,
headers={"Authorization": f"Bearer {self.token}"},
headers={"Authorization": self.token},
data=None,
verify=self.verify_ssl,
)
Expand All @@ -688,7 +708,7 @@ def get_input_artifact_by_uid(self, uid: str, node_id: str, artifact_name: str)
uid=uid, nodeId=node_id, artifactName=artifact_name
),
params=None,
headers={"Authorization": f"Bearer {self.token}"},
headers={"Authorization": self.token},
data=None,
verify=self.verify_ssl,
)
Expand All @@ -709,7 +729,7 @@ def get_input_artifact(self, name: str, node_id: str, artifact_name: str, namesp
namespace=namespace if namespace is not None else self.namespace,
),
params=None,
headers={"Authorization": f"Bearer {self.token}"},
headers={"Authorization": self.token},
data=None,
verify=self.verify_ssl,
)
Expand Down
Loading

0 comments on commit b841f6c

Please sign in to comment.