diff --git a/lib/lightning/credentials.ex b/lib/lightning/credentials.ex index a8747438ce..2607d760ba 100644 --- a/lib/lightning/credentials.ex +++ b/lib/lightning/credentials.ex @@ -187,10 +187,19 @@ defmodule Lightning.Credentials do "body" => body, "oauth_token" => token }) do - Multi.new() - |> Multi.run(:oauth_token, fn _repo, _changes -> - create_oauth_token(user_id, client_id, token) - end) + base_multi = + Multi.new() + |> Multi.run(:scopes, fn _repo, _changes -> + case OauthToken.extract_scopes(token) do + {:ok, scopes} -> {:ok, scopes} + :error -> {:error, "Missing required OAuth field: scope"} + end + end) + + token_multi = build_token_multi(user_id, client_id, token) + + base_multi + |> Multi.append(token_multi) |> Multi.insert(:credential, fn %{oauth_token: fresh_token} -> changeset |> Ecto.Changeset.put_change(:oauth_token_id, fresh_token.id) @@ -198,18 +207,44 @@ defmodule Lightning.Credentials do end) end - defp create_oauth_token(user_id, client_id, token_data) do - case OauthToken.extract_scopes(token_data) do - {:ok, scopes} -> - find_or_create_oauth_token( - user_id, - client_id, - scopes, - token_data - ) + defp build_token_multi(user_id, client_id, token) do + if token["refresh_token"] do + Multi.new() + |> Multi.insert(:oauth_token, fn %{scopes: scopes} -> + OauthToken.changeset(%{ + user_id: user_id, + oauth_client_id: client_id, + scopes: scopes, + body: token + }) + end) + else + Multi.new() + |> Multi.run(:token_changeset, fn _repo, %{scopes: scopes} -> + handle_missing_refresh_token(user_id, client_id, scopes, token) + end) + |> Multi.insert(:oauth_token, fn %{token_changeset: token_changeset} -> + token_changeset + end) + end + end + + defp handle_missing_refresh_token(user_id, client_id, scopes, token) do + case find_oauth_token_by_scopes(user_id, client_id, scopes) do + nil -> + return_error("Missing required OAuth field: refresh_token") + + oauth_token -> + refresh_token = oauth_token.body["refresh_token"] + updated_token = Map.put(token, "refresh_token", refresh_token) - :error -> - {:error, "Could not extract scopes from OAuth token"} + {:ok, + OauthToken.changeset(%{ + user_id: user_id, + oauth_client_id: client_id, + scopes: scopes, + body: updated_token + })} end end @@ -232,7 +267,6 @@ defmodule Lightning.Credentials do {:ok, Credential.t()} | {:error, any()} def update_credential(%Credential{} = credential, attrs) do attrs = normalize_keys(attrs) - changeset = change_credential(credential, attrs) build_update_multi(credential, changeset, attrs) @@ -1189,6 +1223,17 @@ defmodule Lightning.Credentials do return_error("Invalid OAuth token body") end + def validate_oauth_token_data( + _token_data, + _user_id, + _oauth_client_id, + scopes, + _is_update + ) + when is_nil(scopes) do + return_error("Missing required OAuth field: scope") + end + def validate_oauth_token_data( token_data, user_id, @@ -1242,7 +1287,7 @@ defmodule Lightning.Credentials do scopes, is_update ) do - has_refresh_token = Map.has_key?(normalized_data, "refresh_token") + has_refresh_token? = Map.has_key?(normalized_data, "refresh_token") existing_token_exists? = token_exists?(user_id, oauth_client_id, scopes) @@ -1254,7 +1299,7 @@ defmodule Lightning.Credentials do existing_token_exists? -> validate_expiration_fields(normalized_data) - has_refresh_token -> + has_refresh_token? -> validate_expiration_fields(normalized_data) true -> @@ -1289,26 +1334,10 @@ defmodule Lightning.Credentials do Enum.any?(expires_fields, &Map.has_key?(token_data, &1)) end - defp find_or_create_oauth_token(user_id, oauth_client_id, scopes, token) - when is_list(scopes) do - case find_oauth_token_by_scopes(user_id, oauth_client_id, scopes) do - nil -> - OauthToken.changeset(%{ - user_id: user_id, - oauth_client_id: oauth_client_id, - scopes: scopes, - body: token - }) - |> Lightning.Repo.insert() - - existing -> - {:ok, existing} - end - end - defp find_oauth_token_by_scopes(user_id, oauth_client_id, scopes) when is_list(scopes) do - sorted_scopes = Enum.sort(scopes) + incoming_scopes = MapSet.new(scopes) + incoming_size = MapSet.size(incoming_scopes) Ecto.Query.from(t in OauthToken, join: token_client in OauthClient, @@ -1321,9 +1350,26 @@ defmodule Lightning.Credentials do token_client.client_secret == reference_client.client_secret ) |> Lightning.Repo.all() - |> Enum.find(fn token -> - sorted_token_scopes = Enum.sort(token.scopes) - Enum.all?(sorted_scopes, &Enum.member?(sorted_token_scopes, &1)) + |> Enum.filter(fn token -> + existing_scopes = MapSet.new(token.scopes) + MapSet.intersection(existing_scopes, incoming_scopes) |> MapSet.size() > 0 end) + |> Enum.max_by( + fn token -> + existing_scopes = MapSet.new(token.scopes) + + common_count = + MapSet.intersection(existing_scopes, incoming_scopes) |> MapSet.size() + + extra_count = + MapSet.difference(existing_scopes, incoming_scopes) |> MapSet.size() + + exact_match? = common_count == incoming_size && extra_count == 0 + timestamp = DateTime.to_unix(token.updated_at) + + {if(exact_match?, do: 1, else: 0), common_count, -extra_count, timestamp} + end, + fn -> nil end + ) end end diff --git a/lib/lightning/credentials/credential.ex b/lib/lightning/credentials/credential.ex index 805570d91e..c0e5dd885e 100644 --- a/lib/lightning/credentials/credential.ex +++ b/lib/lightning/credentials/credential.ex @@ -78,7 +78,7 @@ defmodule Lightning.Credentials.Credential do scopes = case OauthToken.extract_scopes(token_data) do {:ok, extracted_scopes} -> extracted_scopes - :error -> [] + :error -> nil end case Credentials.validate_oauth_token_data( diff --git a/lib/lightning/credentials/oauth_token.ex b/lib/lightning/credentials/oauth_token.ex index ac62bd041a..b67a8b657d 100644 --- a/lib/lightning/credentials/oauth_token.ex +++ b/lib/lightning/credentials/oauth_token.ex @@ -75,6 +75,8 @@ defmodule Lightning.Credentials.OauthToken do @doc """ Creates a changeset for updating token data. + Only merges with existing token body if new_token is a map, otherwise uses new_token directly. + Preserves the refresh_token from the existing token. """ def update_token_changeset(oauth_token, new_token) do scopes = @@ -83,11 +85,25 @@ defmodule Lightning.Credentials.OauthToken do :error -> nil end - cast(oauth_token, %{body: new_token, scopes: scopes}, [:body, :scopes]) + body = ensure_refresh_token(oauth_token, new_token) + + oauth_token + |> cast(%{body: body, scopes: scopes}, [:body, :scopes]) |> validate_required([:body, :scopes]) |> validate_oauth_body() end + defp ensure_refresh_token(oauth_token, new_token) when is_map(new_token) do + Map.merge( + %{"refresh_token" => oauth_token.body["refresh_token"]}, + new_token + ) + end + + defp ensure_refresh_token(_oauth_token, new_token) do + new_token + end + @doc """ Extracts scopes from OAuth token data in various formats. diff --git a/test/lightning/credentials/credential_test.exs b/test/lightning/credentials/credential_test.exs index 86ff5fe2b4..2e185a1607 100644 --- a/test/lightning/credentials/credential_test.exs +++ b/test/lightning/credentials/credential_test.exs @@ -13,17 +13,18 @@ defmodule Lightning.Credentials.CredentialTest do end test "oauth credentials require access_token, refresh_token, and expires_in or expires_at to be valid" do - assert_invalid_oauth_credential(%{}) + assert_invalid_oauth_credential(%{}, "Missing required OAuth field: scope") assert_invalid_oauth_credential( - %{"access_token" => "access_token_123"}, + %{"access_token" => "access_token_123", "scope" => "read write"}, "Missing refresh_token for new OAuth connection" ) assert_invalid_oauth_credential( %{ "access_token" => "access_token_123", - "refresh_token" => "refresh_token_123" + "refresh_token" => "refresh_token_123", + "scope" => "read write" }, "Missing expiration field: either expires_in or expires_at is required" ) @@ -31,12 +32,14 @@ defmodule Lightning.Credentials.CredentialTest do refute_invalid_oauth_credential(%{ "access_token" => "access_token_123", "refresh_token" => "refresh_token_123", + "scope" => "read write", "expires_at" => 3245 }) refute_invalid_oauth_credential(%{ "access_token" => "access_token_123", "refresh_token" => "refresh_token_123", + "scope" => "read write", "expires_in" => 3245 }) end @@ -68,10 +71,7 @@ defmodule Lightning.Credentials.CredentialTest do end end - defp assert_invalid_oauth_credential( - body, - message \\ "Missing required OAuth field: access_token" - ) do + defp assert_invalid_oauth_credential(body, message) do errors = Credentials.change_credential(%Credential{}, %{ name: "oauth credential", diff --git a/test/lightning/credentials_test.exs b/test/lightning/credentials_test.exs index 36178a2366..8befe2a075 100644 --- a/test/lightning/credentials_test.exs +++ b/test/lightning/credentials_test.exs @@ -1247,11 +1247,13 @@ defmodule Lightning.CredentialsTest do "body" => %{"key" => "value"}, "oauth_token" => %{ "access_token" => "test_access_token", - "expires_in" => 3600 + "refresh_token" => "test_refresh_token", + "expires_in" => 3600, + "scopex" => "read write" } } - assert {:error, "Could not extract scopes from OAuth token"} = + assert {:error, "Missing required OAuth field: scope"} = Credentials.create_credential(attrs) end @@ -1511,7 +1513,7 @@ defmodule Lightning.CredentialsTest do ["read"] ) - assert {:ok, ^token_data} = + assert {:error, "Missing required OAuth field: scope"} = Credentials.validate_oauth_token_data( token_data, "user_id",