Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add keycloak SSO #5711

Merged
Show file tree
Hide file tree
Changes from 17 commits
Commits
Show all changes
19 commits
Select commit Hold shift + click to select a range
ab86419
Add keycloak SSO
paulbauriegel Nov 22, 2024
8fd3b3b
Merge branch 'refactor/argilla-server/better-oauth2-integration' into…
paulbauriegel Nov 26, 2024
41bf238
Merge branch 'refactor/argilla-server/better-oauth2-integration' into…
frascuchon Dec 2, 2024
03b10e7
chore: Create keycloack backend
frascuchon Dec 2, 2024
7f319d2
Update docs for new backend
paulbauriegel Dec 2, 2024
0da655e
Merge branch 'refactor/argilla-server/better-oauth2-integration' into…
frascuchon Dec 3, 2024
bf2a0f6
feat: Configure role and workspaces from realm access roles
frascuchon Dec 3, 2024
174b9b9
Update User role after login
paulbauriegel Jan 16, 2025
8cddf2f
Update Keycloak Docs w. Mapper
paulbauriegel Jan 16, 2025
a474061
Use always the max rights role from the realm roles
paulbauriegel Jan 16, 2025
a25938e
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2025
c3c20eb
Add last_name to userinfo.py
paulbauriegel Jan 16, 2025
55c64b3
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 16, 2025
58bb93a
Update oauth2 handler to update workspaces if needed
paulbauriegel Jan 16, 2025
bc1084a
Merge branch 'refactor/argilla-server/better-oauth2-integration' into…
paulbauriegel Jan 22, 2025
7898ebb
Using flag to sync user from oauth info
frascuchon Jan 28, 2025
38ee8b7
Merge branch 'refactor/argilla-server/better-oauth2-integration' into…
frascuchon Jan 28, 2025
42fc36e
Update argilla-server/src/argilla_server/security/authentication/user…
paulbauriegel Jan 28, 2025
8395cb9
Update argilla-server/src/argilla_server/security/authentication/user…
paulbauriegel Jan 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
164 changes: 164 additions & 0 deletions argilla-frontend/components/features/login/components/KeycloakLogo.vue
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
<template>
<!--https://github.com/keycloak/keycloak-misc/blob/main/logo/icon.svg-->
<svg
width="256"
height="256"
viewBox="0 0 44.216 39.861"
fill="none"
xmlns="http://www.w3.org/2000/svg"
>
<path
d="m88.61 138.456 5.716-9.865 23.018-.004 5.686 9.965.007 19.932-5.691 9.957-23.012.008-5.782-9.965z"
style="
display: inline;
fill: #4d4d4d;
fill-opacity: 1;
stroke-width: 0.264583;
"
transform="translate(-82.815 -128.588)"
/>
<path
d="M88.552 158.481h10.375l-5.699-10.041 4.634-9.982-9.252-.002-5.795 10.065"
style="
fill: #ededed;
fill-opacity: 1;
fill-rule: nonzero;
stroke: none;
stroke-width: 0.330729;
"
transform="translate(-82.815 -128.588)"
/>
<path
d="M102.073 158.481h7.582l6.706-9.773-6.589-10.156h-8.921l-5.373 9.814z"
style="
fill: #e0e0e0;
fill-opacity: 1;
fill-rule: nonzero;
stroke: none;
stroke-width: 0.330729;
"
transform="translate(-82.815 -128.588)"
/>
<path
d="m82.815 148.52 5.738 9.964h10.374l-5.636-9.93z"
style="
fill: #acacac;
fill-opacity: 1;
fill-rule: nonzero;
stroke: none;
stroke-width: 0.330729;
"
transform="translate(-82.815 -128.588)"
/>
<path
d="m95.589 148.522 6.484 9.963h7.582l6.601-9.959z"
style="
fill: #9e9e9e;
fill-opacity: 1;
fill-rule: nonzero;
stroke: none;
stroke-width: 0.330729;
"
transform="translate(-82.815 -128.588)"
/>
<path
d="m98.157 148.529-1.958.569-1.877-.572 7.667-13.288 1.918 3.316"
style="
fill: #00b8e3;
fill-opacity: 1;
fill-rule: nonzero;
stroke: none;
stroke-width: 0.330729;
"
transform="translate(-82.815 -128.588)"
/>
<path
d="m103.9 158.482-1.909 3.332-5.093-5.487-2.58-7.797v-.004h3.838"
style="
fill: #33c6e9;
fill-opacity: 1;
fill-rule: nonzero;
stroke: none;
stroke-width: 0.330729;
"
transform="translate(-82.815 -128.588)"
/>
<path
d="M94.322 148.526h-.003v.003l-1.918 3.322-1.925-3.307 1.952-3.386 5.728-9.92h3.834"
style="
fill: #008aaa;
fill-opacity: 1;
fill-rule: nonzero;
stroke: none;
stroke-width: 0.330729;
"
transform="translate(-82.815 -128.588)"
/>
<path
d="M115.42 158.481h11.611l-.007-19.93h-11.605z"
style="
fill: #d4d4d4;
fill-opacity: 1;
fill-rule: nonzero;
stroke: none;
stroke-width: 0.330729;
"
transform="translate(-82.815 -128.588)"
/>
<path
d="M115.42 148.554v9.93h11.59v-9.93z"
style="
fill: #919191;
fill-opacity: 1;
fill-rule: nonzero;
stroke: none;
stroke-width: 0.330729;
"
transform="translate(-82.815 -128.588)"
/>
<path
d="M101.992 161.817h-3.836l-5.755-9.966 1.918-3.321z"
style="
fill: #00b8e3;
fill-opacity: 1;
fill-rule: nonzero;
stroke: none;
stroke-width: 0.330729;
"
transform="translate(-82.815 -128.588)"
/>
<path
d="m117.333 148.526-7.669 13.289c-.705-1.036-1.913-3.331-1.913-3.331l5.753-9.959z"
style="
fill: #008aaa;
fill-opacity: 1;
fill-rule: nonzero;
stroke: none;
stroke-width: 0.330729;
"
transform="translate(-82.815 -128.588)"
/>
<path
d="m113.495 161.815-3.831-.001 7.67-13.288 1.917-3.317 1.921 3.34m-3.839-.023h-3.828l-5.755-9.973 1.905-3.314 4.658 5.922z"
style="
fill: #00b8e3;
fill-opacity: 1;
fill-rule: nonzero;
stroke: none;
stroke-width: 0.330729;
"
transform="translate(-82.815 -128.588)"
/>
<path
d="M119.25 145.205v.003l-1.917 3.318-7.677-13.286 3.841.002z"
style="
fill: #33c6e9;
fill-opacity: 1;
fill-rule: nonzero;
stroke: none;
stroke-width: 0.330729;
"
transform="translate(-82.815 -128.588)"
/>
</svg>
</template>
Original file line number Diff line number Diff line change
@@ -1,11 +1,18 @@
<template>
<BaseButton class="sign-in-button" @click="$emit('click')">
<KeycloakLogo v-if="provider === 'keycloak'" />
{{ signinText }}
</BaseButton>
</template>

<script>
import KeycloakLogo from "./KeycloakLogo.vue";

export default {
name: "OAuthLoginButton",
components: {
KeycloakLogo,
},
props: {
provider: {
type: String,
Expand Down
2 changes: 1 addition & 1 deletion argilla-frontend/translation/de.js
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ export default {
button: {
ignore_and_continue: "Ignorieren und fortfahren",
login: "Anmelden",
signin_with_provider: "Anmeldung bei {provider} starten",
signin_with_provider: "Mit {provider} anmelden",
"hf-login": "Mit Hugging Face anmelden",
sign_in_with_username: "Mit Benutzername anmelden",
cancel: "Abbrechen",
Expand Down
49 changes: 43 additions & 6 deletions argilla-server/src/argilla_server/api/handlers/v1/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from argilla_server.contexts import accounts
from argilla_server.database import get_async_db
from argilla_server.errors.future import NotFoundError
from argilla_server.models import User
from argilla_server.models import Workspace, WorkspaceUser
from argilla_server.security.authentication.oauth2 import OAuth2ClientProvider
from argilla_server.security.authentication.userinfo import UserInfo
from argilla_server.security.settings import settings
Expand Down Expand Up @@ -61,14 +61,51 @@ async def get_access_token(
if not userinfo.username:
raise RuntimeError("OAuth error: Missing username")

user = await User.get_by(db, username=userinfo.username)
if user is None:
user = await accounts.create_user_with_random_password(
default_available_workspaces = [workspace.name for workspace in settings.oauth.allowed_workspaces]
available_workspaces = userinfo.available_workspaces or default_available_workspaces

oauth_user = await accounts.get_user_by_username(db, username=userinfo.username)

if oauth_user is None:
for workspace_name in available_workspaces:
if await Workspace.get_by(db, name=workspace_name) is None:
await Workspace.create(db, name=workspace_name, autocommit=False)

oauth_user = await accounts.create_user_with_random_password(
db,
username=userinfo.username,
first_name=userinfo.first_name,
role=userinfo.role,
workspaces=[workspace.name for workspace in settings.oauth.allowed_workspaces],
workspaces=available_workspaces,
)
elif provider.sync_user:
oauth_role = oauth_user.role
oauth_workspaces = oauth_user.workspaces or []

# Sync user role
if oauth_role != userinfo.role:
await accounts.update_user(db, user=oauth_user, user_attrs={"role": userinfo.role})
# Sync removed workspaces
for workspace in oauth_workspaces:
if workspace.name not in available_workspaces:
ws_user = await WorkspaceUser.get_by(db, workspace_id=workspace.id, user_id=oauth_user.id)
await ws_user.delete(db, autocommit=False)
# Sync added workspaces
for workspace_name in available_workspaces:
if workspace_name in [ws.name for ws in oauth_workspaces]:
continue

workspace = await Workspace.get_by(db, name=workspace_name)
if not workspace:
workspace = await Workspace.create(db, name=workspace_name, autocommit=False)

if not await WorkspaceUser.get_by(db, workspace_id=workspace.id, user_id=oauth_user.id):
await WorkspaceUser.create(
db,
workspace_id=workspace.id,
user_id=oauth_user.id,
autocommit=False,
)
await db.commit()

return Token(access_token=accounts.generate_user_token(user))
return Token(access_token=accounts.generate_user_token(oauth_user))
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,15 @@
# limitations under the License.

import os
from typing import Type, Dict, Any
from typing import Type, Dict, Any, Optional, List

from social_core.backends.oauth import BaseOAuth2
from social_core.backends.open_id_connect import OpenIdConnectAuth
from social_core.backends.utils import load_backends
from social_core.strategy import BaseStrategy

from argilla_server.errors.future import NotFoundError
from argilla_server.models import UserRole


class Strategy(BaseStrategy):
Expand Down Expand Up @@ -48,6 +49,61 @@ class HuggingfaceOpenId(OpenIdConnectAuth):
DEFAULT_SCOPE = ["openid", "profile"]


class KeycloakOpenId(OpenIdConnectAuth):
"""Huggingface OpenID Connect authentication backend."""

name = "keycloak"

def oidc_endpoint(self) -> str:
frascuchon marked this conversation as resolved.
Show resolved Hide resolved
value = super().oidc_endpoint()

if value is None:
from social_core.utils import setting_name

name = setting_name("OIDC_ENDPOINT")
raise ValueError(
"oidc_endpoint needs to be set in the Keycloak configuration. "
f"Please set the {name} environment variable."
)

return value

def get_user_details(self, response: Dict[str, Any]) -> Dict[str, Any]:
user = super().get_user_details(response)

if role := self._extract_role(response):
user["role"] = role

if available_workspaces := self._extract_available_workspaces(response):
user["available_workspaces"] = available_workspaces

return user

def _extract_role(self, response: Dict[str, Any]) -> Optional[str]:
roles = self._read_realm_roles(response)
role_to_value = {UserRole.owner: 3, UserRole.admin: 2, UserRole.annotator: 1}
role_list = [role.split(":")[1] for role in roles if role.startswith("argilla_role:")]
if role_list:
max_role = max(role_list, key=lambda s: role_to_value.get(s, 0))
return max_role

def _extract_available_workspaces(self, response: Dict[str, Any]) -> List[str]:
roles = self._read_realm_roles(response)

workspaces = []
for role in roles:
if role.startswith("argilla_workspace:"):
workspace = role.split(":")[1]
workspaces.append(workspace)

return workspaces

@classmethod
def _read_realm_roles(cls, response) -> List[str]:
realm_access = response.get("realm_access") or {}
return realm_access.get("roles") or []


_SUPPORTED_BACKENDS = {}


Expand All @@ -56,6 +112,7 @@ def load_supported_backends(extra_backends: list = None) -> Dict[str, Type[BaseO

backends = [
"argilla_server.security.authentication.oauth2._backends.HuggingfaceOpenId",
"argilla_server.security.authentication.oauth2._backends.KeycloakOpenId",
"social_core.backends.github.GithubOAuth2",
"social_core.backends.google.GoogleOAuth2",
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
client_secret: str = None,
scope: Optional[List[str]] = None,
redirect_uri: str = None,
sync_user: bool = False,
) -> None:
self.name = backend_class.name
self._backend = backend_class(strategy=self.backend_strategy)
Expand All @@ -74,6 +75,7 @@ def __init__(
self.scope = self.scope.split(" ")

self.redirect_uri = redirect_uri or f"/oauth/{self.name}/callback"
self.sync_user = sync_user

@classmethod
def from_dict(cls, provider: dict, backend_class: Type[BaseOAuth2]) -> "OAuth2ClientProvider":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,19 @@ def username(self) -> str:
def first_name(self) -> str:
return self.get("first_name") or self.username

@property
def last_name(self) -> str:
paulbauriegel marked this conversation as resolved.
Show resolved Hide resolved
return self.get("last_name") or ""
paulbauriegel marked this conversation as resolved.
Show resolved Hide resolved

@property
def role(self) -> UserRole:
role = self.get("role") or self._parse_role_from_environment()
return UserRole(role)

@property
def available_workspaces(self) -> Optional[list]:
return self.get("available_workspaces")

def _parse_role_from_environment(self) -> Optional[UserRole]:
"""This is a temporal solution, and it will be replaced by a proper Sign up process"""
if self["username"] == os.getenv("USERNAME"):
Expand Down
Loading