diff --git a/src/main/java/org/signal/registration/ratelimit/SendSmsVerificationCodeRateLimiter.java b/src/main/java/org/signal/registration/ratelimit/SendSmsVerificationCodeRateLimiter.java index 7c65d938..9890d660 100644 --- a/src/main/java/org/signal/registration/ratelimit/SendSmsVerificationCodeRateLimiter.java +++ b/src/main/java/org/signal/registration/ratelimit/SendSmsVerificationCodeRateLimiter.java @@ -14,6 +14,8 @@ import java.util.Comparator; import java.util.List; import java.util.Optional; +import java.util.stream.Stream; + import org.signal.registration.rpc.MessageTransport; import org.signal.registration.session.FailedSendReason; import org.signal.registration.session.RegistrationSession; @@ -41,16 +43,22 @@ protected int getPriorAttemptCount(final RegistrationSession session) { .filter(attempt -> attempt.getMessageTransport() == MessageTransport.MESSAGE_TRANSPORT_SMS) .count() + (int) session.getFailedAttemptsList().stream() - .filter(attempt -> attempt.getMessageTransport() == MessageTransport.MESSAGE_TRANSPORT_SMS - && attempt.getFailedSendReason() != FailedSendReason.FAILED_SEND_REASON_UNAVAILABLE) + .filter(attempt -> attempt.getMessageTransport() == MessageTransport.MESSAGE_TRANSPORT_SMS) + .filter(attempt -> attempt.getFailedSendReason() != FailedSendReason.FAILED_SEND_REASON_UNAVAILABLE) .count(); } @Override protected Optional getLastAttemptTime(final RegistrationSession session) { - return session.getRegistrationAttemptsList().stream() - .filter(attempt -> attempt.getMessageTransport() == MessageTransport.MESSAGE_TRANSPORT_SMS) - .map(attempt -> Instant.ofEpochMilli(attempt.getTimestampEpochMillis())) - .max(Comparator.naturalOrder()); + return Stream.concat( + session.getRegistrationAttemptsList().stream() + .filter(attempt -> attempt.getMessageTransport() == MessageTransport.MESSAGE_TRANSPORT_SMS) + .map(attempt -> Instant.ofEpochMilli(attempt.getTimestampEpochMillis())), + session.getFailedAttemptsList().stream() + .filter(attempt -> attempt.getMessageTransport() == MessageTransport.MESSAGE_TRANSPORT_SMS) + .filter(attempt -> attempt.getFailedSendReason() != FailedSendReason.FAILED_SEND_REASON_UNAVAILABLE) + .map(attempt -> Instant.ofEpochMilli(attempt.getTimestampEpochMillis())) + ) + .max(Comparator.naturalOrder()); } } diff --git a/src/main/java/org/signal/registration/ratelimit/SendVoiceVerificationCodeRateLimiter.java b/src/main/java/org/signal/registration/ratelimit/SendVoiceVerificationCodeRateLimiter.java index c9bc5b08..2b4bdd1d 100644 --- a/src/main/java/org/signal/registration/ratelimit/SendVoiceVerificationCodeRateLimiter.java +++ b/src/main/java/org/signal/registration/ratelimit/SendVoiceVerificationCodeRateLimiter.java @@ -18,6 +18,7 @@ import java.util.List; import java.util.Optional; import java.util.concurrent.CompletableFuture; +import java.util.stream.Stream; @Singleton @Named("send-voice-verification-code") @@ -76,16 +77,21 @@ protected int getPriorAttemptCount(final RegistrationSession session) { .filter(attempt -> attempt.getMessageTransport() == MessageTransport.MESSAGE_TRANSPORT_VOICE) .count() + (int) session.getFailedAttemptsList().stream() - .filter(attempt -> attempt.getMessageTransport() == MessageTransport.MESSAGE_TRANSPORT_VOICE - && attempt.getFailedSendReason() != FailedSendReason.FAILED_SEND_REASON_UNAVAILABLE) + .filter(attempt -> attempt.getMessageTransport() == MessageTransport.MESSAGE_TRANSPORT_VOICE) + .filter(attempt -> attempt.getFailedSendReason() != FailedSendReason.FAILED_SEND_REASON_UNAVAILABLE) .count(); } @Override protected Optional getLastAttemptTime(final RegistrationSession session) { - return session.getRegistrationAttemptsList().stream() - .filter(attempt -> attempt.getMessageTransport() == MessageTransport.MESSAGE_TRANSPORT_VOICE) - .map(attempt -> Instant.ofEpochMilli(attempt.getTimestampEpochMillis())) + return Stream.concat( + session.getRegistrationAttemptsList().stream() + .filter(attempt -> attempt.getMessageTransport() == MessageTransport.MESSAGE_TRANSPORT_VOICE) + .map(attempt -> Instant.ofEpochMilli(attempt.getTimestampEpochMillis())), + session.getFailedAttemptsList().stream() + .filter(attempt -> attempt.getMessageTransport() == MessageTransport.MESSAGE_TRANSPORT_VOICE) + .filter(attempt -> attempt.getFailedSendReason() != FailedSendReason.FAILED_SEND_REASON_UNAVAILABLE) + .map(attempt -> Instant.ofEpochMilli(attempt.getTimestampEpochMillis()))) .max(Comparator.naturalOrder()); } } diff --git a/src/test/java/org/signal/registration/ratelimit/SendSmsVerificationCodeRateLimiterTest.java b/src/test/java/org/signal/registration/ratelimit/SendSmsVerificationCodeRateLimiterTest.java index abb28acc..7f791482 100644 --- a/src/test/java/org/signal/registration/ratelimit/SendSmsVerificationCodeRateLimiterTest.java +++ b/src/test/java/org/signal/registration/ratelimit/SendSmsVerificationCodeRateLimiterTest.java @@ -130,6 +130,16 @@ void getLastAttemptTime() { .setMessageTransport(MessageTransport.MESSAGE_TRANSPORT_VOICE) .setTimestampEpochMillis(System.currentTimeMillis()) .build()) + .addFailedAttempts(FailedSendAttempt.newBuilder() + .setMessageTransport(MessageTransport.MESSAGE_TRANSPORT_SMS) + .setFailedSendReason(FailedSendReason.FAILED_SEND_REASON_UNAVAILABLE) + .setTimestampEpochMillis(System.currentTimeMillis()) + .build()) + .addFailedAttempts(FailedSendAttempt.newBuilder() + .setMessageTransport(MessageTransport.MESSAGE_TRANSPORT_VOICE) + .setFailedSendReason(FailedSendReason.FAILED_SEND_REASON_REJECTED) + .setTimestampEpochMillis(System.currentTimeMillis()) + .build()) .build())); final long firstTimestamp = 37; @@ -143,6 +153,15 @@ void getLastAttemptTime() { .build()) .build())); + assertEquals(Optional.of(Instant.ofEpochMilli(firstTimestamp)), + rateLimiter.getLastAttemptTime(RegistrationSession.newBuilder() + .addFailedAttempts(FailedSendAttempt.newBuilder() + .setMessageTransport(MessageTransport.MESSAGE_TRANSPORT_SMS) + .setFailedSendReason(FailedSendReason.FAILED_SEND_REASON_REJECTED) + .setTimestampEpochMillis(firstTimestamp) + .build()) + .build())); + assertEquals(Optional.of(Instant.ofEpochMilli(firstTimestamp)), rateLimiter.getLastAttemptTime(RegistrationSession.newBuilder() .addRegistrationAttempts(RegistrationAttempt.newBuilder() @@ -153,6 +172,29 @@ void getLastAttemptTime() { .setMessageTransport(MessageTransport.MESSAGE_TRANSPORT_VOICE) .setTimestampEpochMillis(secondTimestamp) .build()) + .addFailedAttempts(FailedSendAttempt.newBuilder() + .setMessageTransport(MessageTransport.MESSAGE_TRANSPORT_SMS) + .setFailedSendReason(FailedSendReason.FAILED_SEND_REASON_UNAVAILABLE) + .setTimestampEpochMillis(secondTimestamp) + .build()) + .addFailedAttempts(FailedSendAttempt.newBuilder() + .setMessageTransport(MessageTransport.MESSAGE_TRANSPORT_VOICE) + .setFailedSendReason(FailedSendReason.FAILED_SEND_REASON_REJECTED) + .setTimestampEpochMillis(secondTimestamp) + .build()) + .build())); + + assertEquals(Optional.of(Instant.ofEpochMilli(firstTimestamp)), + rateLimiter.getLastAttemptTime(RegistrationSession.newBuilder() + .addFailedAttempts(FailedSendAttempt.newBuilder() + .setMessageTransport(MessageTransport.MESSAGE_TRANSPORT_SMS) + .setFailedSendReason(FailedSendReason.FAILED_SEND_REASON_REJECTED) + .setTimestampEpochMillis(firstTimestamp) + .build()) + .addRegistrationAttempts(RegistrationAttempt.newBuilder() + .setMessageTransport(MessageTransport.MESSAGE_TRANSPORT_VOICE) + .setTimestampEpochMillis(secondTimestamp) + .build()) .build())); assertEquals(Optional.of(Instant.ofEpochMilli(secondTimestamp)), @@ -166,6 +208,19 @@ void getLastAttemptTime() { .setTimestampEpochMillis(secondTimestamp) .build()) .build())); + + assertEquals(Optional.of(Instant.ofEpochMilli(secondTimestamp)), + rateLimiter.getLastAttemptTime(RegistrationSession.newBuilder() + .addRegistrationAttempts(RegistrationAttempt.newBuilder() + .setMessageTransport(MessageTransport.MESSAGE_TRANSPORT_SMS) + .setTimestampEpochMillis(firstTimestamp) + .build()) + .addFailedAttempts(FailedSendAttempt.newBuilder() + .setMessageTransport(MessageTransport.MESSAGE_TRANSPORT_SMS) + .setFailedSendReason(FailedSendReason.FAILED_SEND_REASON_REJECTED) + .setTimestampEpochMillis(secondTimestamp) + .build()) + .build())); } @Test diff --git a/src/test/java/org/signal/registration/ratelimit/SendVoiceVerificationCodeRateLimiterTest.java b/src/test/java/org/signal/registration/ratelimit/SendVoiceVerificationCodeRateLimiterTest.java index 29febdec..881717b3 100644 --- a/src/test/java/org/signal/registration/ratelimit/SendVoiceVerificationCodeRateLimiterTest.java +++ b/src/test/java/org/signal/registration/ratelimit/SendVoiceVerificationCodeRateLimiterTest.java @@ -221,6 +221,16 @@ void getLastAttemptTime() { .setMessageTransport(MessageTransport.MESSAGE_TRANSPORT_SMS) .setTimestampEpochMillis(System.currentTimeMillis()) .build()) + .addFailedAttempts(FailedSendAttempt.newBuilder() + .setMessageTransport(MessageTransport.MESSAGE_TRANSPORT_SMS) + .setFailedSendReason(FailedSendReason.FAILED_SEND_REASON_REJECTED) + .setTimestampEpochMillis(System.currentTimeMillis()) + .build()) + .addFailedAttempts(FailedSendAttempt.newBuilder() + .setMessageTransport(MessageTransport.MESSAGE_TRANSPORT_VOICE) + .setFailedSendReason(FailedSendReason.FAILED_SEND_REASON_UNAVAILABLE) + .setTimestampEpochMillis(System.currentTimeMillis()) + .build()) .build())); final long firstTimestamp = 37; @@ -234,6 +244,15 @@ void getLastAttemptTime() { .build()) .build())); + assertEquals(Optional.of(Instant.ofEpochMilli(firstTimestamp)), + rateLimiter.getLastAttemptTime(RegistrationSession.newBuilder() + .addFailedAttempts(FailedSendAttempt.newBuilder() + .setMessageTransport(MessageTransport.MESSAGE_TRANSPORT_VOICE) + .setFailedSendReason(FailedSendReason.FAILED_SEND_REASON_REJECTED) + .setTimestampEpochMillis(firstTimestamp) + .build()) + .build())); + assertEquals(Optional.of(Instant.ofEpochMilli(firstTimestamp)), rateLimiter.getLastAttemptTime(RegistrationSession.newBuilder() .addRegistrationAttempts(RegistrationAttempt.newBuilder() @@ -244,6 +263,29 @@ void getLastAttemptTime() { .setMessageTransport(MessageTransport.MESSAGE_TRANSPORT_SMS) .setTimestampEpochMillis(secondTimestamp) .build()) + .addFailedAttempts(FailedSendAttempt.newBuilder() + .setMessageTransport(MessageTransport.MESSAGE_TRANSPORT_VOICE) + .setFailedSendReason(FailedSendReason.FAILED_SEND_REASON_UNAVAILABLE) + .setTimestampEpochMillis(secondTimestamp) + .build()) + .addFailedAttempts(FailedSendAttempt.newBuilder() + .setMessageTransport(MessageTransport.MESSAGE_TRANSPORT_SMS) + .setFailedSendReason(FailedSendReason.FAILED_SEND_REASON_REJECTED) + .setTimestampEpochMillis(secondTimestamp) + .build()) + .build())); + + assertEquals(Optional.of(Instant.ofEpochMilli(firstTimestamp)), + rateLimiter.getLastAttemptTime(RegistrationSession.newBuilder() + .addFailedAttempts(FailedSendAttempt.newBuilder() + .setMessageTransport(MessageTransport.MESSAGE_TRANSPORT_VOICE) + .setFailedSendReason(FailedSendReason.FAILED_SEND_REASON_REJECTED) + .setTimestampEpochMillis(firstTimestamp) + .build()) + .addRegistrationAttempts(RegistrationAttempt.newBuilder() + .setMessageTransport(MessageTransport.MESSAGE_TRANSPORT_SMS) + .setTimestampEpochMillis(secondTimestamp) + .build()) .build())); assertEquals(Optional.of(Instant.ofEpochMilli(secondTimestamp)), @@ -257,5 +299,18 @@ void getLastAttemptTime() { .setTimestampEpochMillis(secondTimestamp) .build()) .build())); + + assertEquals(Optional.of(Instant.ofEpochMilli(secondTimestamp)), + rateLimiter.getLastAttemptTime(RegistrationSession.newBuilder() + .addRegistrationAttempts(RegistrationAttempt.newBuilder() + .setMessageTransport(MessageTransport.MESSAGE_TRANSPORT_VOICE) + .setTimestampEpochMillis(firstTimestamp) + .build()) + .addFailedAttempts(FailedSendAttempt.newBuilder() + .setMessageTransport(MessageTransport.MESSAGE_TRANSPORT_VOICE) + .setFailedSendReason(FailedSendReason.FAILED_SEND_REASON_REJECTED) + .setTimestampEpochMillis(secondTimestamp) + .build()) + .build())); } }