Skip to content

Commit

Permalink
Add ability to disable scram downgrade protection
Browse files Browse the repository at this point in the history
  • Loading branch information
prefiks committed Jan 16, 2024
1 parent 7cc3741 commit db6d730
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/xmpp_sasl.erl
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ server_new(ServerHost, GetPassword, CheckPassword, CheckPasswordDigest) ->
check_password_digest = CheckPasswordDigest}.

-spec server_start(sasl_state(), mechanism(), binary(), channel_bindings(),
list(binary())) -> sasl_return().
list(binary()) | undefined) -> sasl_return().
server_start(State, Mech, ClientIn, ChannelBindings, Mechs) ->
case get_mod(Mech) of
undefined ->
Expand Down
26 changes: 17 additions & 9 deletions src/xmpp_sasl_scram.erl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
{step = 2 :: 2 | 4,
algo = sha :: sha | sha256 | sha512,
channel_bindings = none :: none | #{atom() => binary()},
ssdp :: binary(),
ssdp :: undefined | binary(),
stored_key = <<"">> :: binary(),
server_key = <<"">> :: binary(),
username = <<"">> :: binary(),
Expand Down Expand Up @@ -84,13 +84,17 @@ mech_new(Mech, ChannelBindings, Mechs, _Host, GetPassword, _CheckPassword, _Chec
<<"SCRAM-SHA-512">> -> {sha512, none};
<<"SCRAM-SHA-512-PLUS">> -> {sha512, ChannelBindings}
end,
Ssdp = base64:encode(crypto:hash(Algo, [
lists:join(<<",">>, lists:sort(Mechs)),
case ChannelBindings of
none -> [];
_ when map_size(ChannelBindings) == 0 -> [];
_ -> [<<"|">>, lists:join(<<",">>, lists:sort(maps:keys(ChannelBindings)))]
end])),
Ssdp = case Mechs of
undefined -> undefined;
_ ->
base64:encode(crypto:hash(Algo, [
lists:join(<<",">>, lists:sort(Mechs)),
case ChannelBindings of
none -> [];
_ when map_size(ChannelBindings) == 0 -> [];
_ -> [<<"|">>, lists:join(<<",">>, lists:sort(maps:keys(ChannelBindings)))]
end]))
end,
#state{step = 2, get_password = GetPassword, algo = Algo,
channel_bindings = CB, ssdp = Ssdp}.

Expand Down Expand Up @@ -148,6 +152,10 @@ mech_step(#state{step = 2, algo = Algo, ssdp = Ssdp} = State, ClientIn) ->
str(ClientIn, <<"n=">>)),
ServerNonce =
base64:encode(p1_rand:bytes(?NONCE_LENGTH)),
SsdpPart = case Ssdp of
undefined -> [];
_ -> [",d=", Ssdp]
end,
ServerFirstMessage =
iolist_to_binary(
["r=",
Expand All @@ -157,7 +165,7 @@ mech_step(#state{step = 2, algo = Algo, ssdp = Ssdp} = State, ClientIn) ->
base64:encode(Salt),
",", "i=",
integer_to_list(IterationCount),
",d=", Ssdp]),
SsdpPart]),
{continue, ServerFirstMessage,
State#state{step = 4, stored_key = StoredKey,
server_key = ServerKey,
Expand Down
22 changes: 20 additions & 2 deletions src/xmpp_stream_in.erl
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@
-callback tls_required(state()) -> boolean().
-callback tls_enabled(state()) -> boolean().
-callback sasl_mechanisms([xmpp_sasl:mechanism()], state()) -> [xmpp_sasl:mechanism()].
-callback sasl_options(state()) -> [tuple()].
-callback unauthenticated_stream_features(state()) -> [xmpp_element()].
-callback authenticated_stream_features(state()) -> [xmpp_element()].

Expand Down Expand Up @@ -142,6 +143,7 @@
tls_required/1,
tls_enabled/1,
sasl_mechanisms/2,
sasl_options/1,
unauthenticated_stream_features/1,
authenticated_stream_features/1]).

Expand Down Expand Up @@ -955,9 +957,17 @@ process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn},
GetPW = get_password_fun(Mech, State1),
CheckPW = check_password_fun(Mech, State1),
CheckPWDigest = check_password_digest_fun(Mech, State1),
Mechs2 = try callback(sasl_options, State) of
Opts ->
case lists:keyfind(scram_downgrade_protection, 1, Opts) of
{_, false} -> undefined;
_ -> Mechs
end
catch _:{?MODULE, undef} -> Mechs
end,
SASLState = xmpp_sasl:server_new(LServer, GetPW, CheckPW, CheckPWDigest),
CB = maps:get(sasl_channel_bindings, State1, none),
Res = xmpp_sasl:server_start(SASLState, Mech, ClientIn, CB, Mechs),
Res = xmpp_sasl:server_start(SASLState, Mech, ClientIn, CB, Mechs2),
process_sasl_result(Res, disable_sasl2(State1#{sasl_state => SASLState}));
false ->
process_sasl_result({error, unsupported_mechanism, <<"">>}, disable_sasl2(State1))
Expand Down Expand Up @@ -1068,9 +1078,17 @@ process_sasl2_request(#sasl2_authenticate{mechanism = Mech, initial_response = C
GetPW = get_password_fun(Mech, State1),
CheckPW = check_password_fun(Mech, State1),
CheckPWDigest = check_password_digest_fun(Mech, State1),
Mechs2 = try callback(sasl_options, State) of
Opts ->
case lists:keyfind(scram_downgrade_protection, 1, Opts) of
{_, false} -> undefined;
_ -> Mechs
end
catch _:{?MODULE, undef} -> Mechs
end,
SASLState = xmpp_sasl:server_new(LServer, GetPW, CheckPW, CheckPWDigest),
CB = maps:get(sasl_channel_bindings, State1, none),
Res = xmpp_sasl:server_start(SASLState, Mech, ClientIn, CB, Mechs),
Res = xmpp_sasl:server_start(SASLState, Mech, ClientIn, CB, Mechs2),
process_sasl2_result(Res, State1#{sasl_state => SASLState,
sasl2_inline_els => SaslInline,
sasl2_ua_id => UAId});
Expand Down

0 comments on commit db6d730

Please sign in to comment.