diff --git a/src/xmpp_sasl.erl b/src/xmpp_sasl.erl index f392c0b..aaf36db 100644 --- a/src/xmpp_sasl.erl +++ b/src/xmpp_sasl.erl @@ -105,7 +105,7 @@ server_new(ServerHost, GetPassword, CheckPassword, CheckPasswordDigest) -> check_password_digest = CheckPasswordDigest}. -spec server_start(sasl_state(), mechanism(), binary(), channel_bindings(), - list(binary())) -> sasl_return(). + list(binary()) | undefined) -> sasl_return(). server_start(State, Mech, ClientIn, ChannelBindings, Mechs) -> case get_mod(Mech) of undefined -> diff --git a/src/xmpp_sasl_scram.erl b/src/xmpp_sasl_scram.erl index cc69bca..c5aee21 100644 --- a/src/xmpp_sasl_scram.erl +++ b/src/xmpp_sasl_scram.erl @@ -33,7 +33,7 @@ {step = 2 :: 2 | 4, algo = sha :: sha | sha256 | sha512, channel_bindings = none :: none | #{atom() => binary()}, - ssdp :: binary(), + ssdp :: undefined | binary(), stored_key = <<"">> :: binary(), server_key = <<"">> :: binary(), username = <<"">> :: binary(), @@ -84,13 +84,17 @@ mech_new(Mech, ChannelBindings, Mechs, _Host, GetPassword, _CheckPassword, _Chec <<"SCRAM-SHA-512">> -> {sha512, none}; <<"SCRAM-SHA-512-PLUS">> -> {sha512, ChannelBindings} end, - Ssdp = base64:encode(crypto:hash(Algo, [ - lists:join(<<",">>, lists:sort(Mechs)), - case ChannelBindings of - none -> []; - _ when map_size(ChannelBindings) == 0 -> []; - _ -> [<<"|">>, lists:join(<<",">>, lists:sort(maps:keys(ChannelBindings)))] - end])), + Ssdp = case Mechs of + undefined -> undefined; + _ -> + base64:encode(crypto:hash(Algo, [ + lists:join(<<",">>, lists:sort(Mechs)), + case ChannelBindings of + none -> []; + _ when map_size(ChannelBindings) == 0 -> []; + _ -> [<<"|">>, lists:join(<<",">>, lists:sort(maps:keys(ChannelBindings)))] + end])) + end, #state{step = 2, get_password = GetPassword, algo = Algo, channel_bindings = CB, ssdp = Ssdp}. @@ -148,6 +152,10 @@ mech_step(#state{step = 2, algo = Algo, ssdp = Ssdp} = State, ClientIn) -> str(ClientIn, <<"n=">>)), ServerNonce = base64:encode(p1_rand:bytes(?NONCE_LENGTH)), + SsdpPart = case Ssdp of + undefined -> []; + _ -> [",d=", Ssdp] + end, ServerFirstMessage = iolist_to_binary( ["r=", @@ -157,7 +165,7 @@ mech_step(#state{step = 2, algo = Algo, ssdp = Ssdp} = State, ClientIn) -> base64:encode(Salt), ",", "i=", integer_to_list(IterationCount), - ",d=", Ssdp]), + SsdpPart]), {continue, ServerFirstMessage, State#state{step = 4, stored_key = StoredKey, server_key = ServerKey, diff --git a/src/xmpp_stream_in.erl b/src/xmpp_stream_in.erl index c6808ec..8139034 100644 --- a/src/xmpp_stream_in.erl +++ b/src/xmpp_stream_in.erl @@ -110,6 +110,7 @@ -callback tls_required(state()) -> boolean(). -callback tls_enabled(state()) -> boolean(). -callback sasl_mechanisms([xmpp_sasl:mechanism()], state()) -> [xmpp_sasl:mechanism()]. +-callback sasl_options(state()) -> [tuple()]. -callback unauthenticated_stream_features(state()) -> [xmpp_element()]. -callback authenticated_stream_features(state()) -> [xmpp_element()]. @@ -142,6 +143,7 @@ tls_required/1, tls_enabled/1, sasl_mechanisms/2, + sasl_options/1, unauthenticated_stream_features/1, authenticated_stream_features/1]). @@ -955,9 +957,17 @@ process_sasl_request(#sasl_auth{mechanism = Mech, text = ClientIn}, GetPW = get_password_fun(Mech, State1), CheckPW = check_password_fun(Mech, State1), CheckPWDigest = check_password_digest_fun(Mech, State1), + Mechs2 = try callback(sasl_options, State) of + Opts -> + case lists:keyfind(scram_downgrade_protection, 1, Opts) of + {_, false} -> undefined; + _ -> Mechs + end + catch _:{?MODULE, undef} -> Mechs + end, SASLState = xmpp_sasl:server_new(LServer, GetPW, CheckPW, CheckPWDigest), CB = maps:get(sasl_channel_bindings, State1, none), - Res = xmpp_sasl:server_start(SASLState, Mech, ClientIn, CB, Mechs), + Res = xmpp_sasl:server_start(SASLState, Mech, ClientIn, CB, Mechs2), process_sasl_result(Res, disable_sasl2(State1#{sasl_state => SASLState})); false -> process_sasl_result({error, unsupported_mechanism, <<"">>}, disable_sasl2(State1)) @@ -1068,9 +1078,17 @@ process_sasl2_request(#sasl2_authenticate{mechanism = Mech, initial_response = C GetPW = get_password_fun(Mech, State1), CheckPW = check_password_fun(Mech, State1), CheckPWDigest = check_password_digest_fun(Mech, State1), + Mechs2 = try callback(sasl_options, State) of + Opts -> + case lists:keyfind(scram_downgrade_protection, 1, Opts) of + {_, false} -> undefined; + _ -> Mechs + end + catch _:{?MODULE, undef} -> Mechs + end, SASLState = xmpp_sasl:server_new(LServer, GetPW, CheckPW, CheckPWDigest), CB = maps:get(sasl_channel_bindings, State1, none), - Res = xmpp_sasl:server_start(SASLState, Mech, ClientIn, CB, Mechs), + Res = xmpp_sasl:server_start(SASLState, Mech, ClientIn, CB, Mechs2), process_sasl2_result(Res, State1#{sasl_state => SASLState, sasl2_inline_els => SaslInline, sasl2_ua_id => UAId});