diff --git a/topic/pom.xml b/topic/pom.xml index 7518677c..cc193eb6 100644 --- a/topic/pom.xml +++ b/topic/pom.xml @@ -40,14 +40,21 @@ zstd-jni 1.5.2-5 + junit junit test + + org.mockito + mockito-inline + test + tech.ydb.test ydb-junit4-support + test org.apache.logging.log4j diff --git a/topic/src/main/java/tech/ydb/topic/impl/GrpcStreamRetrier.java b/topic/src/main/java/tech/ydb/topic/impl/GrpcStreamRetrier.java index 27817bbf..3b00c96e 100644 --- a/topic/src/main/java/tech/ydb/topic/impl/GrpcStreamRetrier.java +++ b/topic/src/main/java/tech/ydb/topic/impl/GrpcStreamRetrier.java @@ -6,7 +6,6 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; -import java.util.concurrent.atomic.AtomicInteger; import org.slf4j.Logger; @@ -27,18 +26,21 @@ public abstract class GrpcStreamRetrier { protected final AtomicBoolean isReconnecting = new AtomicBoolean(false); protected final AtomicBoolean isStopped = new AtomicBoolean(false); + private final Logger logger; private final ScheduledExecutorService scheduler; private final RetryMode retryMode; private final RetryPolicy retryPolicy = new DefaultRetryPolicy(); - private final AtomicInteger retry = new AtomicInteger(-1); - protected GrpcStreamRetrier(RetryMode retryMode, ScheduledExecutorService scheduler) { + private volatile boolean connected = false; + private volatile int retryNumber = 0; + + protected GrpcStreamRetrier(Logger logger, RetryMode retryMode, ScheduledExecutorService scheduler) { + this.logger = logger; this.retryMode = retryMode; this.scheduler = scheduler; this.id = generateRandomId(ID_LENGTH); } - protected abstract Logger getLogger(); protected abstract String getStreamName(); protected abstract void onStreamReconnect(); protected abstract void onShutdown(String reason); @@ -51,34 +53,34 @@ protected static String generateRandomId(int length) { .toString(); } - private void tryScheduleReconnect(int retryNumber) { + private void tryScheduleReconnect() { if (!isReconnecting.compareAndSet(false, true)) { - getLogger().info("[{}] should reconnect {} stream, but reconnect is already in progress", id, + logger.info("[{}] should reconnect {} stream, but reconnect is already in progress", id, getStreamName()); return; } - retry.set(retryNumber); long delay = retryPolicy.nextRetryMs(retryNumber, 0); - getLogger().warn("[{}] Retry #{}. Scheduling {} reconnect in {}ms...", id, retryNumber, getStreamName(), delay); + logger.warn("[{}] Retry #{}. Scheduling {} reconnect in {}ms...", id, retryNumber, getStreamName(), delay); try { scheduler.schedule(this::reconnect, delay, TimeUnit.MILLISECONDS); } catch (RejectedExecutionException exception) { String errorMessage = "[" + id + "] Couldn't schedule reconnect: scheduler is already shut down. " + "Shutting down " + getStreamName(); - getLogger().error(errorMessage); + logger.error(errorMessage); shutdownImpl(errorMessage); } } protected void resetRetries() { - retry.set(0); + retryNumber = 0; + connected = true; } void reconnect() { - getLogger().info("[{}] {} reconnect #{} started", id, getStreamName(), retry.get()); + logger.info("[{}] {} reconnect #{} started", id, getStreamName(), retryNumber); if (!isReconnecting.compareAndSet(true, false)) { - getLogger().warn("Couldn't reset reconnect flag. Shouldn't happen"); + logger.warn("Couldn't reset reconnect flag. Shouldn't happen"); } onStreamReconnect(); } @@ -88,7 +90,7 @@ protected CompletableFuture shutdownImpl() { } protected CompletableFuture shutdownImpl(String reason) { - getLogger().info("[{}] Shutting down {}" + logger.info("[{}] Shutting down {}" + (reason == null || reason.isEmpty() ? "" : " with reason: " + reason), id, getStreamName()); isStopped.set(true); return CompletableFuture.runAsync(() -> { @@ -97,58 +99,47 @@ protected CompletableFuture shutdownImpl(String reason) { } protected void onSessionClosed(Status status, Throwable th) { - getLogger().info("[{}] onSessionClosed called", id); + logger.info("[{}] onSessionClosed called", id); if (th != null) { - getLogger().error("[{}] Exception in {} stream session: ", id, getStreamName(), th); + logger.warn("[{}] Exception in {} stream session: ", id, getStreamName(), th); } else { if (status.isSuccess()) { if (isStopped.get()) { - getLogger().info("[{}] {} stream session closed successfully", id, getStreamName()); + logger.info("[{}] {} stream session closed successfully", id, getStreamName()); return; } else { - getLogger().warn("[{}] {} stream session was closed on working {}", id, getStreamName(), + logger.warn("[{}] {} stream session was closed on working {}", id, getStreamName(), getStreamName()); } } else { - getLogger().warn("[{}] Error in {} stream session: {}", id, getStreamName(), status); + logger.warn("[{}] Error in {} stream session: {}", id, getStreamName(), status); } } if (isStopped.get()) { - getLogger().info("[{}] {} is already stopped, no need to schedule reconnect", id, getStreamName()); + logger.info("[{}] {} is already stopped, no need to schedule reconnect", id, getStreamName()); return; } - int currentRetry = nextRetryNumber(); - if (currentRetry > 0) { - tryScheduleReconnect(currentRetry); + if (retryMode == RetryMode.ALWAYS || (retryMode == RetryMode.RECOVER && connected)) { + retryNumber++; + tryScheduleReconnect(); return; } if (!isStopped.compareAndSet(false, true)) { - getLogger().warn("[{}] Stopped by retry mode {} after {} retries. But {} is already shut down.", id, - retryMode, currentRetry, getStreamName()); + logger.warn("[{}] Stopped by retry mode {} after {} retries. But {} is already shut down.", id, + retryMode, retryNumber, getStreamName()); return; } - String errorMessage = "[" + id + "] Stopped by retry mode " + retryMode + " after " + currentRetry + + String errorMessage = "[" + id + "] Stopped by retry mode " + retryMode + " after " + retryNumber + " retries. Shutting down " + getStreamName(); - getLogger().error(errorMessage); + logger.warn(errorMessage); shutdownImpl(errorMessage); } - private int nextRetryNumber() { - int next = retry.get() + 1; - switch (retryMode) { - case RECOVER: return next; - case ALWAYS: return Math.max(1, next); - case NONE: - default: - return 0; - } - } - private static class DefaultRetryPolicy extends ExponentialBackoffRetry { private static final int EXP_BACKOFF_BASE_MS = 256; diff --git a/topic/src/main/java/tech/ydb/topic/read/impl/ReaderImpl.java b/topic/src/main/java/tech/ydb/topic/read/impl/ReaderImpl.java index 6c4c6ac6..e37fc0e4 100644 --- a/topic/src/main/java/tech/ydb/topic/read/impl/ReaderImpl.java +++ b/topic/src/main/java/tech/ydb/topic/read/impl/ReaderImpl.java @@ -55,7 +55,7 @@ public abstract class ReaderImpl extends GrpcStreamRetrier { private final String consumerName; public ReaderImpl(TopicRpc topicRpc, ReaderSettings settings) { - super(settings.getRetryMode(), topicRpc.getScheduler()); + super(logger, settings.getRetryMode(), topicRpc.getScheduler()); this.topicRpc = topicRpc; this.settings = settings; this.session = new ReadSessionImpl(); @@ -88,11 +88,6 @@ public ReaderImpl(TopicRpc topicRpc, ReaderSettings settings) { logger.info(message.toString()); } - @Override - protected Logger getLogger() { - return logger; - } - @Override protected String getStreamName() { return "Reader"; diff --git a/topic/src/main/java/tech/ydb/topic/write/impl/WriterImpl.java b/topic/src/main/java/tech/ydb/topic/write/impl/WriterImpl.java index 313f7811..fd49d8ba 100644 --- a/topic/src/main/java/tech/ydb/topic/write/impl/WriterImpl.java +++ b/topic/src/main/java/tech/ydb/topic/write/impl/WriterImpl.java @@ -66,7 +66,7 @@ public abstract class WriterImpl extends GrpcStreamRetrier { private CompletableFuture lastAcceptedMessageFuture; public WriterImpl(TopicRpc topicRpc, WriterSettings settings, Executor compressionExecutor) { - super(settings.getRetryMode(), topicRpc.getScheduler()); + super(logger, settings.getRetryMode(), topicRpc.getScheduler()); this.topicRpc = topicRpc; this.settings = settings; this.session = new WriteSessionImpl(); @@ -81,11 +81,6 @@ public WriterImpl(TopicRpc topicRpc, WriterSettings settings, Executor compressi logger.info(message); } - @Override - protected Logger getLogger() { - return logger; - } - @Override protected String getStreamName() { return "Writer"; diff --git a/topic/src/test/java/tech/ydb/topic/impl/BaseMockedTest.java b/topic/src/test/java/tech/ydb/topic/impl/BaseMockedTest.java new file mode 100644 index 00000000..9deeb4f1 --- /dev/null +++ b/topic/src/test/java/tech/ydb/topic/impl/BaseMockedTest.java @@ -0,0 +1,229 @@ +package tech.ydb.topic.impl; + +import java.util.ArrayList; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.CompletableFuture; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ScheduledFuture; +import java.util.stream.Collectors; +import java.util.stream.IntStream; +import java.util.stream.LongStream; + +import org.junit.Assert; +import org.junit.Before; +import org.mockito.Mockito; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import org.mockito.stubbing.OngoingStubbing; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +import tech.ydb.core.Status; +import tech.ydb.core.StatusCode; +import tech.ydb.core.grpc.GrpcReadStream; +import tech.ydb.core.grpc.GrpcReadWriteStream; +import tech.ydb.core.grpc.GrpcTransport; +import tech.ydb.proto.StatusCodesProtos; +import tech.ydb.proto.topic.YdbTopic; +import tech.ydb.proto.topic.v1.TopicServiceGrpc; +import tech.ydb.topic.TopicClient; + +/** + * + * @author Aleksandr Gorshenin + */ +public class BaseMockedTest { + private static final Logger logger = LoggerFactory.getLogger(BaseMockedTest.class); + + private interface WriteStream extends + GrpcReadWriteStream { + } + + private final GrpcTransport transport = Mockito.mock(GrpcTransport.class); + private final ScheduledExecutorService scheduler = Mockito.mock(ScheduledExecutorService.class); + private final ScheduledFuture emptyFuture = Mockito.mock(ScheduledFuture.class); + private final WriteStream writeStream = Mockito.mock(WriteStream.class); + private final SchedulerAssert schedulerHelper = new SchedulerAssert(); + + protected final TopicClient client = TopicClient.newClient(transport) + .setCompressionExecutor(Runnable::run) // Disable compression in separate executors + .build(); + + private volatile MockedWriteStream streamMock = null; + + @Before + public void beforeEach() { + streamMock = null; + + Mockito.when(transport.getScheduler()).thenReturn(scheduler); + Mockito.when(transport.readWriteStreamCall(Mockito.eq(TopicServiceGrpc.getStreamWriteMethod()), Mockito.any())) + .thenReturn(writeStream); + + // Every writeStream.start updates mockedWriteStream + Mockito.when(writeStream.start(Mockito.any())).thenAnswer(defaultStreamMockAnswer()); + + // Every writeStream.senbNext add message from client to mockedWriteStream.sent list + Mockito.doAnswer((Answer) (InvocationOnMock iom) -> { + streamMock.sent.add(iom.getArgument(0, YdbTopic.StreamWriteMessage.FromClient.class)); + return null; + }).when(writeStream).sendNext(Mockito.any()); + + Mockito.when(scheduler.schedule(Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.any())) + .thenAnswer((InvocationOnMock iom) -> { + logger.debug("mock scheduled task"); + schedulerHelper.tasks.add(iom.getArgument(0, Runnable.class)); + return emptyFuture; + }); + } + + protected MockedWriteStream currentStream() { + return streamMock; + } + + protected SchedulerAssert getScheduler() { + return schedulerHelper; + } + + protected OngoingStubbing> mockStreams() { + return Mockito.when(writeStream.start(Mockito.any())); + } + + protected Answer> defaultStreamMockAnswer() { + return (InvocationOnMock iom) -> { + streamMock = new MockedWriteStream(iom.getArgument(0)); + return streamMock.streamFuture; + }; + } + + protected Answer> errorStreamMockAnswer(StatusCode code) { + return (iom) -> { + streamMock = null; + return CompletableFuture.completedFuture(Status.of(code)); + }; + } + + protected static class SchedulerAssert { + private final Queue tasks = new ConcurrentLinkedQueue<>(); + + public SchedulerAssert hasNoTasks() { + Assert.assertTrue(tasks.isEmpty()); + return this; + } + + public SchedulerAssert hasTasks(int count) { + Assert.assertEquals(count, tasks.size()); + return this; + } + + public SchedulerAssert executeNextTasks(int count) { + Assert.assertTrue(count <= tasks.size()); + + CompletableFuture.runAsync(() -> { + logger.debug("execute {} scheduled tasks", count); + for (int idx = 0; idx < count; idx++) { + tasks.poll().run(); + } + }).join(); + return this; + } + } + + protected static class MockedWriteStream { + private final GrpcReadWriteStream.Observer observer; + private final CompletableFuture streamFuture = new CompletableFuture<>(); + private final List sent = new ArrayList<>(); + private volatile int sentIdx = 0; + + public MockedWriteStream(GrpcReadStream.Observer observer) { + this.observer = observer; + } + + public void complete(Status status) { + streamFuture.complete(status); + } + + public void complete(Throwable th) { + streamFuture.completeExceptionally(th); + } + + public void hasNoNewMessages() { + Assert.assertTrue(sentIdx >= sent.size()); + } + + public Checker nextMsg() { + Assert.assertTrue(sentIdx < sent.size()); + return new Checker(sent.get(sentIdx++)); + } + + public void responseErrorBadRequest() { + YdbTopic.StreamWriteMessage.FromServer msg = YdbTopic.StreamWriteMessage.FromServer.newBuilder() + .setStatus(StatusCodesProtos.StatusIds.StatusCode.BAD_REQUEST) + .build(); + observer.onNext(msg); + } + + public void responseInit(long lastSeqNo) { + responseInit(lastSeqNo, 123, "mocked", new int[] { 0, 1, 2}); + } + + public void responseInit(long lastSeqNo, long partitionId, String sessionId, int[] codecs) { + YdbTopic.StreamWriteMessage.FromServer msg = YdbTopic.StreamWriteMessage.FromServer.newBuilder() + .setStatus(StatusCodesProtos.StatusIds.StatusCode.SUCCESS) + .setInitResponse(YdbTopic.StreamWriteMessage.InitResponse.newBuilder() + .setLastSeqNo(lastSeqNo) + .setPartitionId(partitionId) + .setSessionId(sessionId) + .setSupportedCodecs(YdbTopic.SupportedCodecs.newBuilder() + .addAllCodecs(IntStream.of(codecs).boxed().collect(Collectors.toList()))) + ).build(); + observer.onNext(msg); + } + + public void responseWriteWritten(long firstSeqNo, int messagesCount) { + List acks = LongStream + .range(firstSeqNo, firstSeqNo + messagesCount) + .mapToObj(seqNo -> YdbTopic.StreamWriteMessage.WriteResponse.WriteAck.newBuilder() + .setSeqNo(seqNo) + .setWritten(YdbTopic.StreamWriteMessage.WriteResponse.WriteAck.Written.newBuilder()) + .build()) + .collect(Collectors.toList()); + + YdbTopic.StreamWriteMessage.FromServer msg = YdbTopic.StreamWriteMessage.FromServer.newBuilder() + .setStatus(StatusCodesProtos.StatusIds.StatusCode.SUCCESS) + .setWriteResponse(YdbTopic.StreamWriteMessage.WriteResponse.newBuilder().addAllAcks(acks)) + .build(); + observer.onNext(msg); + } + + protected class Checker { + private final YdbTopic.StreamWriteMessage.FromClient msg; + + public Checker(YdbTopic.StreamWriteMessage.FromClient msg) { + this.msg = msg; + } + + public Checker isInit() { + Assert.assertTrue(msg.hasInitRequest()); + return this; + } + + public Checker hasInitPath(String path) { + Assert.assertEquals(path, msg.getInitRequest().getPath()); + return this; + } + + public Checker isWrite() { + Assert.assertTrue(msg.hasWriteRequest()); + return this; + } + + public Checker hasWrite(int codec, int messagesCount) { + Assert.assertEquals(codec, msg.getWriteRequest().getCodec()); + Assert.assertEquals(messagesCount, msg.getWriteRequest().getMessagesCount()); + return this; + } + } + } +} diff --git a/topic/src/test/resources/log4j2.xml b/topic/src/test/resources/log4j2.xml index c799da30..c59b5d2a 100644 --- a/topic/src/test/resources/log4j2.xml +++ b/topic/src/test/resources/log4j2.xml @@ -2,7 +2,7 @@ - +