Skip to content

Commit

Permalink
Allow refresh_token reusage
Browse files Browse the repository at this point in the history
  • Loading branch information
elias-ba committed Mar 7, 2025
1 parent 47be961 commit 0e30c9a
Show file tree
Hide file tree
Showing 5 changed files with 115 additions and 51 deletions.
124 changes: 85 additions & 39 deletions lib/lightning/credentials.ex
Original file line number Diff line number Diff line change
Expand Up @@ -187,29 +187,64 @@ defmodule Lightning.Credentials do
"body" => body,
"oauth_token" => token
}) do
Multi.new()
|> Multi.run(:oauth_token, fn _repo, _changes ->
create_oauth_token(user_id, client_id, token)
end)
base_multi =
Multi.new()
|> Multi.run(:scopes, fn _repo, _changes ->
case OauthToken.extract_scopes(token) do
{:ok, scopes} -> {:ok, scopes}
:error -> {:error, "Missing required OAuth field: scope"}
end
end)

token_multi = build_token_multi(user_id, client_id, token)

base_multi
|> Multi.append(token_multi)
|> Multi.insert(:credential, fn %{oauth_token: fresh_token} ->
changeset
|> Ecto.Changeset.put_change(:oauth_token_id, fresh_token.id)
|> Ecto.Changeset.put_change(:body, body)
end)
end

defp create_oauth_token(user_id, client_id, token_data) do
case OauthToken.extract_scopes(token_data) do
{:ok, scopes} ->
find_or_create_oauth_token(
user_id,
client_id,
scopes,
token_data
)
defp build_token_multi(user_id, client_id, token) do
if token["refresh_token"] do
Multi.new()
|> Multi.insert(:oauth_token, fn %{scopes: scopes} ->
OauthToken.changeset(%{
user_id: user_id,
oauth_client_id: client_id,
scopes: scopes,
body: token
})
end)
else
Multi.new()
|> Multi.run(:token_changeset, fn _repo, %{scopes: scopes} ->
handle_missing_refresh_token(user_id, client_id, scopes, token)
end)
|> Multi.insert(:oauth_token, fn %{token_changeset: token_changeset} ->
token_changeset
end)
end
end

defp handle_missing_refresh_token(user_id, client_id, scopes, token) do
case find_oauth_token_by_scopes(user_id, client_id, scopes) do
nil ->
return_error("Missing required OAuth field: refresh_token")

oauth_token ->
refresh_token = oauth_token.body["refresh_token"]
updated_token = Map.put(token, "refresh_token", refresh_token)

:error ->
{:error, "Could not extract scopes from OAuth token"}
{:ok,
OauthToken.changeset(%{
user_id: user_id,
oauth_client_id: client_id,
scopes: scopes,
body: updated_token
})}
end
end

Expand All @@ -232,7 +267,6 @@ defmodule Lightning.Credentials do
{:ok, Credential.t()} | {:error, any()}
def update_credential(%Credential{} = credential, attrs) do
attrs = normalize_keys(attrs)

changeset = change_credential(credential, attrs)

build_update_multi(credential, changeset, attrs)
Expand Down Expand Up @@ -1189,6 +1223,17 @@ defmodule Lightning.Credentials do
return_error("Invalid OAuth token body")
end

def validate_oauth_token_data(
_token_data,
_user_id,
_oauth_client_id,
scopes,
_is_update
)
when is_nil(scopes) do
return_error("Missing required OAuth field: scope")
end

def validate_oauth_token_data(
token_data,
user_id,
Expand Down Expand Up @@ -1242,7 +1287,7 @@ defmodule Lightning.Credentials do
scopes,
is_update
) do
has_refresh_token = Map.has_key?(normalized_data, "refresh_token")
has_refresh_token? = Map.has_key?(normalized_data, "refresh_token")

existing_token_exists? =
token_exists?(user_id, oauth_client_id, scopes)
Expand All @@ -1254,7 +1299,7 @@ defmodule Lightning.Credentials do
existing_token_exists? ->
validate_expiration_fields(normalized_data)

has_refresh_token ->
has_refresh_token? ->
validate_expiration_fields(normalized_data)

true ->
Expand Down Expand Up @@ -1289,26 +1334,10 @@ defmodule Lightning.Credentials do
Enum.any?(expires_fields, &Map.has_key?(token_data, &1))
end

defp find_or_create_oauth_token(user_id, oauth_client_id, scopes, token)
when is_list(scopes) do
case find_oauth_token_by_scopes(user_id, oauth_client_id, scopes) do
nil ->
OauthToken.changeset(%{
user_id: user_id,
oauth_client_id: oauth_client_id,
scopes: scopes,
body: token
})
|> Lightning.Repo.insert()

existing ->
{:ok, existing}
end
end

defp find_oauth_token_by_scopes(user_id, oauth_client_id, scopes)
when is_list(scopes) do
sorted_scopes = Enum.sort(scopes)
incoming_scopes = MapSet.new(scopes)
incoming_size = MapSet.size(incoming_scopes)

Ecto.Query.from(t in OauthToken,
join: token_client in OauthClient,
Expand All @@ -1321,9 +1350,26 @@ defmodule Lightning.Credentials do
token_client.client_secret == reference_client.client_secret
)
|> Lightning.Repo.all()
|> Enum.find(fn token ->
sorted_token_scopes = Enum.sort(token.scopes)
Enum.all?(sorted_scopes, &Enum.member?(sorted_token_scopes, &1))
|> Enum.filter(fn token ->
existing_scopes = MapSet.new(token.scopes)
MapSet.intersection(existing_scopes, incoming_scopes) |> MapSet.size() > 0
end)
|> Enum.max_by(
fn token ->
existing_scopes = MapSet.new(token.scopes)

common_count =
MapSet.intersection(existing_scopes, incoming_scopes) |> MapSet.size()

extra_count =
MapSet.difference(existing_scopes, incoming_scopes) |> MapSet.size()

exact_match? = common_count == incoming_size && extra_count == 0
timestamp = DateTime.to_unix(token.updated_at)

{if(exact_match?, do: 1, else: 0), common_count, -extra_count, timestamp}
end,
fn -> nil end
)
end
end
2 changes: 1 addition & 1 deletion lib/lightning/credentials/credential.ex
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ defmodule Lightning.Credentials.Credential do
scopes =
case OauthToken.extract_scopes(token_data) do
{:ok, extracted_scopes} -> extracted_scopes
:error -> []
:error -> nil
end

case Credentials.validate_oauth_token_data(
Expand Down
18 changes: 17 additions & 1 deletion lib/lightning/credentials/oauth_token.ex
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ defmodule Lightning.Credentials.OauthToken do

@doc """
Creates a changeset for updating token data.
Only merges with existing token body if new_token is a map, otherwise uses new_token directly.
Preserves the refresh_token from the existing token.
"""
def update_token_changeset(oauth_token, new_token) do
scopes =
Expand All @@ -83,11 +85,25 @@ defmodule Lightning.Credentials.OauthToken do
:error -> nil
end

cast(oauth_token, %{body: new_token, scopes: scopes}, [:body, :scopes])
body = ensure_refresh_token(oauth_token, new_token)

oauth_token
|> cast(%{body: body, scopes: scopes}, [:body, :scopes])
|> validate_required([:body, :scopes])
|> validate_oauth_body()
end

defp ensure_refresh_token(oauth_token, new_token) when is_map(new_token) do
Map.merge(
%{"refresh_token" => oauth_token.body["refresh_token"]},
new_token
)
end

defp ensure_refresh_token(_oauth_token, new_token) do
new_token
end

@doc """
Extracts scopes from OAuth token data in various formats.
Expand Down
14 changes: 7 additions & 7 deletions test/lightning/credentials/credential_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,33 @@ defmodule Lightning.Credentials.CredentialTest do
end

test "oauth credentials require access_token, refresh_token, and expires_in or expires_at to be valid" do
assert_invalid_oauth_credential(%{})
assert_invalid_oauth_credential(%{}, "Missing required OAuth field: scope")

assert_invalid_oauth_credential(
%{"access_token" => "access_token_123"},
%{"access_token" => "access_token_123", "scope" => "read write"},
"Missing refresh_token for new OAuth connection"
)

assert_invalid_oauth_credential(
%{
"access_token" => "access_token_123",
"refresh_token" => "refresh_token_123"
"refresh_token" => "refresh_token_123",
"scope" => "read write"
},
"Missing expiration field: either expires_in or expires_at is required"
)

refute_invalid_oauth_credential(%{
"access_token" => "access_token_123",
"refresh_token" => "refresh_token_123",
"scope" => "read write",
"expires_at" => 3245
})

refute_invalid_oauth_credential(%{
"access_token" => "access_token_123",
"refresh_token" => "refresh_token_123",
"scope" => "read write",
"expires_in" => 3245
})
end
Expand Down Expand Up @@ -68,10 +71,7 @@ defmodule Lightning.Credentials.CredentialTest do
end
end

defp assert_invalid_oauth_credential(
body,
message \\ "Missing required OAuth field: access_token"
) do
defp assert_invalid_oauth_credential(body, message) do
errors =
Credentials.change_credential(%Credential{}, %{
name: "oauth credential",
Expand Down
8 changes: 5 additions & 3 deletions test/lightning/credentials_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -1247,11 +1247,13 @@ defmodule Lightning.CredentialsTest do
"body" => %{"key" => "value"},
"oauth_token" => %{
"access_token" => "test_access_token",
"expires_in" => 3600
"refresh_token" => "test_refresh_token",
"expires_in" => 3600,
"scopex" => "read write"
}
}

assert {:error, "Could not extract scopes from OAuth token"} =
assert {:error, "Missing required OAuth field: scope"} =
Credentials.create_credential(attrs)
end

Expand Down Expand Up @@ -1511,7 +1513,7 @@ defmodule Lightning.CredentialsTest do
["read"]
)

assert {:ok, ^token_data} =
assert {:error, "Missing required OAuth field: scope"} =
Credentials.validate_oauth_token_data(
token_data,
"user_id",
Expand Down

0 comments on commit 0e30c9a

Please sign in to comment.