diff --git a/topic/pom.xml b/topic/pom.xml
index 7518677c..cc193eb6 100644
--- a/topic/pom.xml
+++ b/topic/pom.xml
@@ -40,14 +40,21 @@
zstd-jni
1.5.2-5
+
junit
junit
test
+
+ org.mockito
+ mockito-inline
+ test
+
tech.ydb.test
ydb-junit4-support
+ test
org.apache.logging.log4j
diff --git a/topic/src/main/java/tech/ydb/topic/impl/GrpcStreamRetrier.java b/topic/src/main/java/tech/ydb/topic/impl/GrpcStreamRetrier.java
index 27817bbf..3b00c96e 100644
--- a/topic/src/main/java/tech/ydb/topic/impl/GrpcStreamRetrier.java
+++ b/topic/src/main/java/tech/ydb/topic/impl/GrpcStreamRetrier.java
@@ -6,7 +6,6 @@
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
-import java.util.concurrent.atomic.AtomicInteger;
import org.slf4j.Logger;
@@ -27,18 +26,21 @@ public abstract class GrpcStreamRetrier {
protected final AtomicBoolean isReconnecting = new AtomicBoolean(false);
protected final AtomicBoolean isStopped = new AtomicBoolean(false);
+ private final Logger logger;
private final ScheduledExecutorService scheduler;
private final RetryMode retryMode;
private final RetryPolicy retryPolicy = new DefaultRetryPolicy();
- private final AtomicInteger retry = new AtomicInteger(-1);
- protected GrpcStreamRetrier(RetryMode retryMode, ScheduledExecutorService scheduler) {
+ private volatile boolean connected = false;
+ private volatile int retryNumber = 0;
+
+ protected GrpcStreamRetrier(Logger logger, RetryMode retryMode, ScheduledExecutorService scheduler) {
+ this.logger = logger;
this.retryMode = retryMode;
this.scheduler = scheduler;
this.id = generateRandomId(ID_LENGTH);
}
- protected abstract Logger getLogger();
protected abstract String getStreamName();
protected abstract void onStreamReconnect();
protected abstract void onShutdown(String reason);
@@ -51,34 +53,34 @@ protected static String generateRandomId(int length) {
.toString();
}
- private void tryScheduleReconnect(int retryNumber) {
+ private void tryScheduleReconnect() {
if (!isReconnecting.compareAndSet(false, true)) {
- getLogger().info("[{}] should reconnect {} stream, but reconnect is already in progress", id,
+ logger.info("[{}] should reconnect {} stream, but reconnect is already in progress", id,
getStreamName());
return;
}
- retry.set(retryNumber);
long delay = retryPolicy.nextRetryMs(retryNumber, 0);
- getLogger().warn("[{}] Retry #{}. Scheduling {} reconnect in {}ms...", id, retryNumber, getStreamName(), delay);
+ logger.warn("[{}] Retry #{}. Scheduling {} reconnect in {}ms...", id, retryNumber, getStreamName(), delay);
try {
scheduler.schedule(this::reconnect, delay, TimeUnit.MILLISECONDS);
} catch (RejectedExecutionException exception) {
String errorMessage = "[" + id + "] Couldn't schedule reconnect: scheduler is already shut down. " +
"Shutting down " + getStreamName();
- getLogger().error(errorMessage);
+ logger.error(errorMessage);
shutdownImpl(errorMessage);
}
}
protected void resetRetries() {
- retry.set(0);
+ retryNumber = 0;
+ connected = true;
}
void reconnect() {
- getLogger().info("[{}] {} reconnect #{} started", id, getStreamName(), retry.get());
+ logger.info("[{}] {} reconnect #{} started", id, getStreamName(), retryNumber);
if (!isReconnecting.compareAndSet(true, false)) {
- getLogger().warn("Couldn't reset reconnect flag. Shouldn't happen");
+ logger.warn("Couldn't reset reconnect flag. Shouldn't happen");
}
onStreamReconnect();
}
@@ -88,7 +90,7 @@ protected CompletableFuture shutdownImpl() {
}
protected CompletableFuture shutdownImpl(String reason) {
- getLogger().info("[{}] Shutting down {}"
+ logger.info("[{}] Shutting down {}"
+ (reason == null || reason.isEmpty() ? "" : " with reason: " + reason), id, getStreamName());
isStopped.set(true);
return CompletableFuture.runAsync(() -> {
@@ -97,58 +99,47 @@ protected CompletableFuture shutdownImpl(String reason) {
}
protected void onSessionClosed(Status status, Throwable th) {
- getLogger().info("[{}] onSessionClosed called", id);
+ logger.info("[{}] onSessionClosed called", id);
if (th != null) {
- getLogger().error("[{}] Exception in {} stream session: ", id, getStreamName(), th);
+ logger.warn("[{}] Exception in {} stream session: ", id, getStreamName(), th);
} else {
if (status.isSuccess()) {
if (isStopped.get()) {
- getLogger().info("[{}] {} stream session closed successfully", id, getStreamName());
+ logger.info("[{}] {} stream session closed successfully", id, getStreamName());
return;
} else {
- getLogger().warn("[{}] {} stream session was closed on working {}", id, getStreamName(),
+ logger.warn("[{}] {} stream session was closed on working {}", id, getStreamName(),
getStreamName());
}
} else {
- getLogger().warn("[{}] Error in {} stream session: {}", id, getStreamName(), status);
+ logger.warn("[{}] Error in {} stream session: {}", id, getStreamName(), status);
}
}
if (isStopped.get()) {
- getLogger().info("[{}] {} is already stopped, no need to schedule reconnect", id, getStreamName());
+ logger.info("[{}] {} is already stopped, no need to schedule reconnect", id, getStreamName());
return;
}
- int currentRetry = nextRetryNumber();
- if (currentRetry > 0) {
- tryScheduleReconnect(currentRetry);
+ if (retryMode == RetryMode.ALWAYS || (retryMode == RetryMode.RECOVER && connected)) {
+ retryNumber++;
+ tryScheduleReconnect();
return;
}
if (!isStopped.compareAndSet(false, true)) {
- getLogger().warn("[{}] Stopped by retry mode {} after {} retries. But {} is already shut down.", id,
- retryMode, currentRetry, getStreamName());
+ logger.warn("[{}] Stopped by retry mode {} after {} retries. But {} is already shut down.", id,
+ retryMode, retryNumber, getStreamName());
return;
}
- String errorMessage = "[" + id + "] Stopped by retry mode " + retryMode + " after " + currentRetry +
+ String errorMessage = "[" + id + "] Stopped by retry mode " + retryMode + " after " + retryNumber +
" retries. Shutting down " + getStreamName();
- getLogger().error(errorMessage);
+ logger.warn(errorMessage);
shutdownImpl(errorMessage);
}
- private int nextRetryNumber() {
- int next = retry.get() + 1;
- switch (retryMode) {
- case RECOVER: return next;
- case ALWAYS: return Math.max(1, next);
- case NONE:
- default:
- return 0;
- }
- }
-
private static class DefaultRetryPolicy extends ExponentialBackoffRetry {
private static final int EXP_BACKOFF_BASE_MS = 256;
diff --git a/topic/src/main/java/tech/ydb/topic/read/impl/ReaderImpl.java b/topic/src/main/java/tech/ydb/topic/read/impl/ReaderImpl.java
index 6c4c6ac6..e37fc0e4 100644
--- a/topic/src/main/java/tech/ydb/topic/read/impl/ReaderImpl.java
+++ b/topic/src/main/java/tech/ydb/topic/read/impl/ReaderImpl.java
@@ -55,7 +55,7 @@ public abstract class ReaderImpl extends GrpcStreamRetrier {
private final String consumerName;
public ReaderImpl(TopicRpc topicRpc, ReaderSettings settings) {
- super(settings.getRetryMode(), topicRpc.getScheduler());
+ super(logger, settings.getRetryMode(), topicRpc.getScheduler());
this.topicRpc = topicRpc;
this.settings = settings;
this.session = new ReadSessionImpl();
@@ -88,11 +88,6 @@ public ReaderImpl(TopicRpc topicRpc, ReaderSettings settings) {
logger.info(message.toString());
}
- @Override
- protected Logger getLogger() {
- return logger;
- }
-
@Override
protected String getStreamName() {
return "Reader";
diff --git a/topic/src/main/java/tech/ydb/topic/write/impl/WriterImpl.java b/topic/src/main/java/tech/ydb/topic/write/impl/WriterImpl.java
index 313f7811..fd49d8ba 100644
--- a/topic/src/main/java/tech/ydb/topic/write/impl/WriterImpl.java
+++ b/topic/src/main/java/tech/ydb/topic/write/impl/WriterImpl.java
@@ -66,7 +66,7 @@ public abstract class WriterImpl extends GrpcStreamRetrier {
private CompletableFuture lastAcceptedMessageFuture;
public WriterImpl(TopicRpc topicRpc, WriterSettings settings, Executor compressionExecutor) {
- super(settings.getRetryMode(), topicRpc.getScheduler());
+ super(logger, settings.getRetryMode(), topicRpc.getScheduler());
this.topicRpc = topicRpc;
this.settings = settings;
this.session = new WriteSessionImpl();
@@ -81,11 +81,6 @@ public WriterImpl(TopicRpc topicRpc, WriterSettings settings, Executor compressi
logger.info(message);
}
- @Override
- protected Logger getLogger() {
- return logger;
- }
-
@Override
protected String getStreamName() {
return "Writer";
diff --git a/topic/src/test/java/tech/ydb/topic/impl/BaseMockedTest.java b/topic/src/test/java/tech/ydb/topic/impl/BaseMockedTest.java
new file mode 100644
index 00000000..9deeb4f1
--- /dev/null
+++ b/topic/src/test/java/tech/ydb/topic/impl/BaseMockedTest.java
@@ -0,0 +1,229 @@
+package tech.ydb.topic.impl;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Queue;
+import java.util.concurrent.CompletableFuture;
+import java.util.concurrent.ConcurrentLinkedQueue;
+import java.util.concurrent.ScheduledExecutorService;
+import java.util.concurrent.ScheduledFuture;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import java.util.stream.LongStream;
+
+import org.junit.Assert;
+import org.junit.Before;
+import org.mockito.Mockito;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import org.mockito.stubbing.OngoingStubbing;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+import tech.ydb.core.Status;
+import tech.ydb.core.StatusCode;
+import tech.ydb.core.grpc.GrpcReadStream;
+import tech.ydb.core.grpc.GrpcReadWriteStream;
+import tech.ydb.core.grpc.GrpcTransport;
+import tech.ydb.proto.StatusCodesProtos;
+import tech.ydb.proto.topic.YdbTopic;
+import tech.ydb.proto.topic.v1.TopicServiceGrpc;
+import tech.ydb.topic.TopicClient;
+
+/**
+ *
+ * @author Aleksandr Gorshenin
+ */
+public class BaseMockedTest {
+ private static final Logger logger = LoggerFactory.getLogger(BaseMockedTest.class);
+
+ private interface WriteStream extends
+ GrpcReadWriteStream {
+ }
+
+ private final GrpcTransport transport = Mockito.mock(GrpcTransport.class);
+ private final ScheduledExecutorService scheduler = Mockito.mock(ScheduledExecutorService.class);
+ private final ScheduledFuture> emptyFuture = Mockito.mock(ScheduledFuture.class);
+ private final WriteStream writeStream = Mockito.mock(WriteStream.class);
+ private final SchedulerAssert schedulerHelper = new SchedulerAssert();
+
+ protected final TopicClient client = TopicClient.newClient(transport)
+ .setCompressionExecutor(Runnable::run) // Disable compression in separate executors
+ .build();
+
+ private volatile MockedWriteStream streamMock = null;
+
+ @Before
+ public void beforeEach() {
+ streamMock = null;
+
+ Mockito.when(transport.getScheduler()).thenReturn(scheduler);
+ Mockito.when(transport.readWriteStreamCall(Mockito.eq(TopicServiceGrpc.getStreamWriteMethod()), Mockito.any()))
+ .thenReturn(writeStream);
+
+ // Every writeStream.start updates mockedWriteStream
+ Mockito.when(writeStream.start(Mockito.any())).thenAnswer(defaultStreamMockAnswer());
+
+ // Every writeStream.senbNext add message from client to mockedWriteStream.sent list
+ Mockito.doAnswer((Answer) (InvocationOnMock iom) -> {
+ streamMock.sent.add(iom.getArgument(0, YdbTopic.StreamWriteMessage.FromClient.class));
+ return null;
+ }).when(writeStream).sendNext(Mockito.any());
+
+ Mockito.when(scheduler.schedule(Mockito.any(Runnable.class), Mockito.anyLong(), Mockito.any()))
+ .thenAnswer((InvocationOnMock iom) -> {
+ logger.debug("mock scheduled task");
+ schedulerHelper.tasks.add(iom.getArgument(0, Runnable.class));
+ return emptyFuture;
+ });
+ }
+
+ protected MockedWriteStream currentStream() {
+ return streamMock;
+ }
+
+ protected SchedulerAssert getScheduler() {
+ return schedulerHelper;
+ }
+
+ protected OngoingStubbing> mockStreams() {
+ return Mockito.when(writeStream.start(Mockito.any()));
+ }
+
+ protected Answer> defaultStreamMockAnswer() {
+ return (InvocationOnMock iom) -> {
+ streamMock = new MockedWriteStream(iom.getArgument(0));
+ return streamMock.streamFuture;
+ };
+ }
+
+ protected Answer> errorStreamMockAnswer(StatusCode code) {
+ return (iom) -> {
+ streamMock = null;
+ return CompletableFuture.completedFuture(Status.of(code));
+ };
+ }
+
+ protected static class SchedulerAssert {
+ private final Queue tasks = new ConcurrentLinkedQueue<>();
+
+ public SchedulerAssert hasNoTasks() {
+ Assert.assertTrue(tasks.isEmpty());
+ return this;
+ }
+
+ public SchedulerAssert hasTasks(int count) {
+ Assert.assertEquals(count, tasks.size());
+ return this;
+ }
+
+ public SchedulerAssert executeNextTasks(int count) {
+ Assert.assertTrue(count <= tasks.size());
+
+ CompletableFuture.runAsync(() -> {
+ logger.debug("execute {} scheduled tasks", count);
+ for (int idx = 0; idx < count; idx++) {
+ tasks.poll().run();
+ }
+ }).join();
+ return this;
+ }
+ }
+
+ protected static class MockedWriteStream {
+ private final GrpcReadWriteStream.Observer observer;
+ private final CompletableFuture streamFuture = new CompletableFuture<>();
+ private final List sent = new ArrayList<>();
+ private volatile int sentIdx = 0;
+
+ public MockedWriteStream(GrpcReadStream.Observer observer) {
+ this.observer = observer;
+ }
+
+ public void complete(Status status) {
+ streamFuture.complete(status);
+ }
+
+ public void complete(Throwable th) {
+ streamFuture.completeExceptionally(th);
+ }
+
+ public void hasNoNewMessages() {
+ Assert.assertTrue(sentIdx >= sent.size());
+ }
+
+ public Checker nextMsg() {
+ Assert.assertTrue(sentIdx < sent.size());
+ return new Checker(sent.get(sentIdx++));
+ }
+
+ public void responseErrorBadRequest() {
+ YdbTopic.StreamWriteMessage.FromServer msg = YdbTopic.StreamWriteMessage.FromServer.newBuilder()
+ .setStatus(StatusCodesProtos.StatusIds.StatusCode.BAD_REQUEST)
+ .build();
+ observer.onNext(msg);
+ }
+
+ public void responseInit(long lastSeqNo) {
+ responseInit(lastSeqNo, 123, "mocked", new int[] { 0, 1, 2});
+ }
+
+ public void responseInit(long lastSeqNo, long partitionId, String sessionId, int[] codecs) {
+ YdbTopic.StreamWriteMessage.FromServer msg = YdbTopic.StreamWriteMessage.FromServer.newBuilder()
+ .setStatus(StatusCodesProtos.StatusIds.StatusCode.SUCCESS)
+ .setInitResponse(YdbTopic.StreamWriteMessage.InitResponse.newBuilder()
+ .setLastSeqNo(lastSeqNo)
+ .setPartitionId(partitionId)
+ .setSessionId(sessionId)
+ .setSupportedCodecs(YdbTopic.SupportedCodecs.newBuilder()
+ .addAllCodecs(IntStream.of(codecs).boxed().collect(Collectors.toList())))
+ ).build();
+ observer.onNext(msg);
+ }
+
+ public void responseWriteWritten(long firstSeqNo, int messagesCount) {
+ List acks = LongStream
+ .range(firstSeqNo, firstSeqNo + messagesCount)
+ .mapToObj(seqNo -> YdbTopic.StreamWriteMessage.WriteResponse.WriteAck.newBuilder()
+ .setSeqNo(seqNo)
+ .setWritten(YdbTopic.StreamWriteMessage.WriteResponse.WriteAck.Written.newBuilder())
+ .build())
+ .collect(Collectors.toList());
+
+ YdbTopic.StreamWriteMessage.FromServer msg = YdbTopic.StreamWriteMessage.FromServer.newBuilder()
+ .setStatus(StatusCodesProtos.StatusIds.StatusCode.SUCCESS)
+ .setWriteResponse(YdbTopic.StreamWriteMessage.WriteResponse.newBuilder().addAllAcks(acks))
+ .build();
+ observer.onNext(msg);
+ }
+
+ protected class Checker {
+ private final YdbTopic.StreamWriteMessage.FromClient msg;
+
+ public Checker(YdbTopic.StreamWriteMessage.FromClient msg) {
+ this.msg = msg;
+ }
+
+ public Checker isInit() {
+ Assert.assertTrue(msg.hasInitRequest());
+ return this;
+ }
+
+ public Checker hasInitPath(String path) {
+ Assert.assertEquals(path, msg.getInitRequest().getPath());
+ return this;
+ }
+
+ public Checker isWrite() {
+ Assert.assertTrue(msg.hasWriteRequest());
+ return this;
+ }
+
+ public Checker hasWrite(int codec, int messagesCount) {
+ Assert.assertEquals(codec, msg.getWriteRequest().getCodec());
+ Assert.assertEquals(messagesCount, msg.getWriteRequest().getMessagesCount());
+ return this;
+ }
+ }
+ }
+}
diff --git a/topic/src/test/resources/log4j2.xml b/topic/src/test/resources/log4j2.xml
index c799da30..c59b5d2a 100644
--- a/topic/src/test/resources/log4j2.xml
+++ b/topic/src/test/resources/log4j2.xml
@@ -2,7 +2,7 @@
-
+