getInputsEdges() {
+ return inputsEdges;
}
public StreamProcessor getStreamProcessor() {
@@ -75,9 +85,6 @@ public NodeType getNodeType() {
public void setNodeType(VertexType vertexType) {
switch (vertexType) {
- case MASTER:
- this.nodeType = NodeType.MASTER;
- break;
case SOURCE:
this.nodeType = NodeType.SOURCE;
break;
@@ -89,8 +96,18 @@ public void setNodeType(VertexType vertexType) {
}
}
+ @Override
+ public String toString() {
+ final StringBuilder sb = new StringBuilder("ExecutionNode{");
+ sb.append("nodeId=").append(nodeId);
+ sb.append(", parallelism=").append(parallelism);
+ sb.append(", nodeType=").append(nodeType);
+ sb.append(", streamProcessor=").append(streamProcessor);
+ sb.append('}');
+ return sb.toString();
+ }
+
public enum NodeType {
- MASTER,
SOURCE,
PROCESS,
SINK,
diff --git a/java/streaming/src/main/java/org/ray/streaming/core/graph/ExecutionTask.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/graph/ExecutionTask.java
similarity index 67%
rename from java/streaming/src/main/java/org/ray/streaming/core/graph/ExecutionTask.java
rename to streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/graph/ExecutionTask.java
index 72d3eaa6fd12..7e205d51f285 100644
--- a/java/streaming/src/main/java/org/ray/streaming/core/graph/ExecutionTask.java
+++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/graph/ExecutionTask.java
@@ -1,21 +1,22 @@
-package org.ray.streaming.core.graph;
+package org.ray.streaming.runtime.core.graph;
import java.io.Serializable;
+
import org.ray.api.RayActor;
-import org.ray.streaming.core.runtime.StreamWorker;
+import org.ray.streaming.runtime.worker.JobWorker;
/**
* ExecutionTask is minimal execution unit.
- *
+ *
* An ExecutionNode has n ExecutionTasks if parallelism is n.
*/
public class ExecutionTask implements Serializable {
private int taskId;
private int taskIndex;
- private RayActor worker;
+ private RayActor worker;
- public ExecutionTask(int taskId, int taskIndex, RayActor worker) {
+ public ExecutionTask(int taskId, int taskIndex, RayActor worker) {
this.taskId = taskId;
this.taskIndex = taskIndex;
this.worker = worker;
@@ -37,11 +38,11 @@ public void setTaskIndex(int taskIndex) {
this.taskIndex = taskIndex;
}
- public RayActor getWorker() {
+ public RayActor getWorker() {
return worker;
}
- public void setWorker(RayActor worker) {
+ public void setWorker(RayActor worker) {
this.worker = worker;
}
}
diff --git a/java/streaming/src/main/java/org/ray/streaming/core/processor/OneInputProcessor.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/OneInputProcessor.java
similarity index 93%
rename from java/streaming/src/main/java/org/ray/streaming/core/processor/OneInputProcessor.java
rename to streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/OneInputProcessor.java
index 3d675aa1604b..7b9fbc4c1e3d 100644
--- a/java/streaming/src/main/java/org/ray/streaming/core/processor/OneInputProcessor.java
+++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/OneInputProcessor.java
@@ -1,4 +1,4 @@
-package org.ray.streaming.core.processor;
+package org.ray.streaming.runtime.core.processor;
import org.ray.streaming.message.Record;
import org.ray.streaming.operator.OneInputOperator;
diff --git a/java/streaming/src/main/java/org/ray/streaming/core/processor/ProcessBuilder.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/ProcessBuilder.java
similarity index 91%
rename from java/streaming/src/main/java/org/ray/streaming/core/processor/ProcessBuilder.java
rename to streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/ProcessBuilder.java
index 2a8c1e63437b..07dfda0f8a25 100644
--- a/java/streaming/src/main/java/org/ray/streaming/core/processor/ProcessBuilder.java
+++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/ProcessBuilder.java
@@ -1,4 +1,4 @@
-package org.ray.streaming.core.processor;
+package org.ray.streaming.runtime.core.processor;
import org.ray.streaming.operator.OneInputOperator;
import org.ray.streaming.operator.OperatorType;
@@ -18,8 +18,6 @@ public static StreamProcessor buildProcessor(StreamOperator streamOperator) {
LOGGER.info("Building StreamProcessor, operator type = {}, operator = {}.", type,
streamOperator.getClass().getSimpleName().toString());
switch (type) {
- case MASTER:
- return new MasterProcessor(null);
case SOURCE:
return new SourceProcessor<>((SourceOperator) streamOperator);
case ONE_INPUT:
diff --git a/java/streaming/src/main/java/org/ray/streaming/core/processor/Processor.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/Processor.java
similarity index 72%
rename from java/streaming/src/main/java/org/ray/streaming/core/processor/Processor.java
rename to streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/Processor.java
index 02b38f08ad81..0578bc6cabbb 100644
--- a/java/streaming/src/main/java/org/ray/streaming/core/processor/Processor.java
+++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/Processor.java
@@ -1,9 +1,9 @@
-package org.ray.streaming.core.processor;
+package org.ray.streaming.runtime.core.processor;
import java.io.Serializable;
import java.util.List;
import org.ray.streaming.api.collector.Collector;
-import org.ray.streaming.core.runtime.context.RuntimeContext;
+import org.ray.streaming.api.context.RuntimeContext;
public interface Processor extends Serializable {
diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/SourceProcessor.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/SourceProcessor.java
new file mode 100644
index 000000000000..1e36ae3f3872
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/SourceProcessor.java
@@ -0,0 +1,30 @@
+package org.ray.streaming.runtime.core.processor;
+
+import org.ray.streaming.message.Record;
+import org.ray.streaming.operator.impl.SourceOperator;
+
+/**
+ * The processor for the stream sources, containing a SourceOperator.
+ *
+ * @param The type of source data.
+ */
+public class SourceProcessor extends StreamProcessor> {
+
+ public SourceProcessor(SourceOperator operator) {
+ super(operator);
+ }
+
+ @Override
+ public void process(Record record) {
+ throw new UnsupportedOperationException("SourceProcessor should not process record");
+ }
+
+ public void run() {
+ operator.run();
+ }
+
+ @Override
+ public void close() {
+
+ }
+}
diff --git a/java/streaming/src/main/java/org/ray/streaming/core/processor/StreamProcessor.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/StreamProcessor.java
similarity index 68%
rename from java/streaming/src/main/java/org/ray/streaming/core/processor/StreamProcessor.java
rename to streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/StreamProcessor.java
index 3dc307e01873..a2ecef2633f5 100644
--- a/java/streaming/src/main/java/org/ray/streaming/core/processor/StreamProcessor.java
+++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/StreamProcessor.java
@@ -1,9 +1,11 @@
-package org.ray.streaming.core.processor;
+package org.ray.streaming.runtime.core.processor;
import java.util.List;
import org.ray.streaming.api.collector.Collector;
-import org.ray.streaming.core.runtime.context.RuntimeContext;
+import org.ray.streaming.api.context.RuntimeContext;
import org.ray.streaming.operator.Operator;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
/**
* StreamingProcessor is a process unit for a operator.
@@ -12,6 +14,7 @@
* @param Type of the specific operator class.
*/
public abstract class StreamProcessor implements Processor {
+ private static final Logger LOGGER = LoggerFactory.getLogger(StreamProcessor.class);
protected List collectors;
protected RuntimeContext runtimeContext;
@@ -28,6 +31,11 @@ public void open(List collectors, RuntimeContext runtimeContext) {
if (operator != null) {
this.operator.open(collectors, runtimeContext);
}
+ LOGGER.info("opened {}", this);
}
+ @Override
+ public String toString() {
+ return this.getClass().getSimpleName();
+ }
}
diff --git a/java/streaming/src/main/java/org/ray/streaming/core/processor/TwoInputProcessor.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/TwoInputProcessor.java
similarity index 68%
rename from java/streaming/src/main/java/org/ray/streaming/core/processor/TwoInputProcessor.java
rename to streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/TwoInputProcessor.java
index 88094b0c8cbb..fbaf84a16a84 100644
--- a/java/streaming/src/main/java/org/ray/streaming/core/processor/TwoInputProcessor.java
+++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/TwoInputProcessor.java
@@ -1,4 +1,4 @@
-package org.ray.streaming.core.processor;
+package org.ray.streaming.runtime.core.processor;
import org.ray.streaming.message.Record;
import org.ray.streaming.operator.TwoInputOperator;
@@ -6,12 +6,10 @@
import org.slf4j.LoggerFactory;
public class TwoInputProcessor extends StreamProcessor> {
-
private static final Logger LOGGER = LoggerFactory.getLogger(TwoInputProcessor.class);
- // TODO(zhenxuanpan): Set leftStream and rightStream.
private String leftStream;
- private String rigthStream;
+ private String rightStream;
public TwoInputProcessor(TwoInputOperator operator) {
super(operator);
@@ -34,4 +32,20 @@ public void process(Record record) {
public void close() {
this.operator.close();
}
+
+ public String getLeftStream() {
+ return leftStream;
+ }
+
+ public void setLeftStream(String leftStream) {
+ this.leftStream = leftStream;
+ }
+
+ public String getRightStream() {
+ return rightStream;
+ }
+
+ public void setRightStream(String rightStream) {
+ this.rightStream = rightStream;
+ }
}
diff --git a/java/streaming/src/main/java/org/ray/streaming/schedule/ITaskAssign.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/ITaskAssign.java
similarity index 56%
rename from java/streaming/src/main/java/org/ray/streaming/schedule/ITaskAssign.java
rename to streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/ITaskAssign.java
index d9c7cd507863..9fc1d0cdb6c8 100644
--- a/java/streaming/src/main/java/org/ray/streaming/schedule/ITaskAssign.java
+++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/ITaskAssign.java
@@ -1,11 +1,11 @@
-package org.ray.streaming.schedule;
+package org.ray.streaming.runtime.schedule;
import java.io.Serializable;
import java.util.List;
import org.ray.api.RayActor;
-import org.ray.streaming.core.graph.ExecutionGraph;
-import org.ray.streaming.core.runtime.StreamWorker;
import org.ray.streaming.plan.Plan;
+import org.ray.streaming.runtime.core.graph.ExecutionGraph;
+import org.ray.streaming.runtime.worker.JobWorker;
/**
* Interface of the task assigning strategy.
@@ -15,6 +15,6 @@ public interface ITaskAssign extends Serializable {
/**
* Assign logical plan to physical execution graph.
*/
- ExecutionGraph assign(Plan plan, List> workers);
+ ExecutionGraph assign(Plan plan, List> workers);
}
diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/JobSchedulerImpl.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/JobSchedulerImpl.java
new file mode 100644
index 000000000000..8a30c75fa419
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/JobSchedulerImpl.java
@@ -0,0 +1,65 @@
+package org.ray.streaming.runtime.schedule;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
+import org.ray.api.Ray;
+import org.ray.api.RayActor;
+import org.ray.api.RayObject;
+import org.ray.streaming.plan.Plan;
+import org.ray.streaming.plan.PlanVertex;
+import org.ray.streaming.runtime.cluster.ResourceManager;
+import org.ray.streaming.runtime.core.graph.ExecutionGraph;
+import org.ray.streaming.runtime.core.graph.ExecutionNode;
+import org.ray.streaming.runtime.core.graph.ExecutionTask;
+import org.ray.streaming.runtime.worker.JobWorker;
+import org.ray.streaming.runtime.worker.context.WorkerContext;
+import org.ray.streaming.schedule.JobScheduler;
+
+/**
+ * JobSchedulerImpl schedules workers by the Plan and the resource information
+ * from ResourceManager.
+ */
+public class JobSchedulerImpl implements JobScheduler {
+ private Plan plan;
+ private Map jobConfig;
+ private ResourceManager resourceManager;
+ private ITaskAssign taskAssign;
+
+ public JobSchedulerImpl() {
+ this.resourceManager = new ResourceManager();
+ this.taskAssign = new TaskAssignImpl();
+ }
+
+ /**
+ * Schedule physical plan to execution graph, and call streaming worker to init and run.
+ */
+ @Override
+ public void schedule(Plan plan, Map jobConfig) {
+ this.jobConfig = jobConfig;
+ this.plan = plan;
+ System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");
+ Ray.init();
+
+ List> workers = this.resourceManager.createWorkers(getPlanWorker());
+ ExecutionGraph executionGraph = this.taskAssign.assign(this.plan, workers);
+
+ List executionNodes = executionGraph.getExecutionNodeList();
+ List> waits = new ArrayList<>();
+ for (ExecutionNode executionNode : executionNodes) {
+ List executionTasks = executionNode.getExecutionTasks();
+ for (ExecutionTask executionTask : executionTasks) {
+ int taskId = executionTask.getTaskId();
+ RayActor streamWorker = executionTask.getWorker();
+ waits.add(Ray.call(JobWorker::init, streamWorker,
+ new WorkerContext(taskId, executionGraph, jobConfig)));
+ }
+ }
+ Ray.wait(waits);
+ }
+
+ private int getPlanWorker() {
+ List planVertexList = plan.getPlanVertexList();
+ return planVertexList.stream().map(PlanVertex::getParallelism).reduce(0, Integer::sum);
+ }
+}
diff --git a/java/streaming/src/main/java/org/ray/streaming/schedule/impl/TaskAssignImpl.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/TaskAssignImpl.java
similarity index 73%
rename from java/streaming/src/main/java/org/ray/streaming/schedule/impl/TaskAssignImpl.java
rename to streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/TaskAssignImpl.java
index be3f6ae35113..20b3e89712d6 100644
--- a/java/streaming/src/main/java/org/ray/streaming/schedule/impl/TaskAssignImpl.java
+++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/TaskAssignImpl.java
@@ -1,4 +1,4 @@
-package org.ray.streaming.schedule.impl;
+package org.ray.streaming.runtime.schedule;
import java.util.ArrayList;
import java.util.HashMap;
@@ -6,29 +6,28 @@
import java.util.Map;
import java.util.stream.Collectors;
import org.ray.api.RayActor;
-import org.ray.streaming.core.graph.ExecutionEdge;
-import org.ray.streaming.core.graph.ExecutionGraph;
-import org.ray.streaming.core.graph.ExecutionNode;
-import org.ray.streaming.core.graph.ExecutionTask;
-import org.ray.streaming.core.processor.ProcessBuilder;
-import org.ray.streaming.core.processor.StreamProcessor;
-import org.ray.streaming.core.runtime.StreamWorker;
import org.ray.streaming.plan.Plan;
import org.ray.streaming.plan.PlanEdge;
import org.ray.streaming.plan.PlanVertex;
-import org.ray.streaming.schedule.ITaskAssign;
+import org.ray.streaming.runtime.core.graph.ExecutionEdge;
+import org.ray.streaming.runtime.core.graph.ExecutionGraph;
+import org.ray.streaming.runtime.core.graph.ExecutionNode;
+import org.ray.streaming.runtime.core.graph.ExecutionTask;
+import org.ray.streaming.runtime.core.processor.ProcessBuilder;
+import org.ray.streaming.runtime.core.processor.StreamProcessor;
+import org.ray.streaming.runtime.worker.JobWorker;
public class TaskAssignImpl implements ITaskAssign {
/**
* Assign an optimized logical plan to execution graph.
*
- * @param plan The logical plan.
+ * @param plan The logical plan.
* @param workers The worker actors.
* @return The physical execution graph.
*/
@Override
- public ExecutionGraph assign(Plan plan, List> workers) {
+ public ExecutionGraph assign(Plan plan, List> workers) {
List planVertices = plan.getPlanVertexList();
List planEdges = plan.getPlanEdgeList();
@@ -45,7 +44,7 @@ public ExecutionGraph assign(Plan plan, List> workers) {
}
StreamProcessor streamProcessor = ProcessBuilder
.buildProcessor(planVertex.getStreamOperator());
- executionNode.setExecutionTaskList(vertexTasks);
+ executionNode.setExecutionTasks(vertexTasks);
executionNode.setStreamProcessor(streamProcessor);
idToExecutionNode.put(executionNode.getNodeId(), executionNode);
}
@@ -57,6 +56,7 @@ public ExecutionGraph assign(Plan plan, List> workers) {
ExecutionEdge executionEdge = new ExecutionEdge(srcNodeId, targetNodeId,
planEdge.getPartition());
idToExecutionNode.get(srcNodeId).addExecutionEdge(executionEdge);
+ idToExecutionNode.get(targetNodeId).addInputEdge(executionEdge);
}
List executionNodes = idToExecutionNode.values().stream()
diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/ChannelID.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/ChannelID.java
new file mode 100644
index 000000000000..1549408627e1
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/ChannelID.java
@@ -0,0 +1,182 @@
+package org.ray.streaming.runtime.transfer;
+
+import com.google.common.base.FinalizablePhantomReference;
+import com.google.common.base.FinalizableReferenceQueue;
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Sets;
+import com.google.common.io.BaseEncoding;
+import java.lang.ref.Reference;
+import java.nio.ByteBuffer;
+import java.util.Random;
+import java.util.Set;
+import sun.nio.ch.DirectBuffer;
+
+/**
+ * ChannelID is used to identify a transfer channel between a upstream worker
+ * and downstream worker.
+ */
+public class ChannelID {
+ public static final int ID_LENGTH = 20;
+ private static final FinalizableReferenceQueue REFERENCE_QUEUE = new FinalizableReferenceQueue();
+ // This ensures that the FinalizablePhantomReference itself is not garbage-collected.
+ private static final Set> references = Sets.newConcurrentHashSet();
+
+ private final byte[] bytes;
+ private final String strId;
+ private final ByteBuffer buffer;
+ private final long address;
+ private final long nativeIdPtr;
+
+ private ChannelID(String strId, byte[] idBytes) {
+ this.strId = strId;
+ this.bytes = idBytes;
+ ByteBuffer directBuffer = ByteBuffer.allocateDirect(ID_LENGTH);
+ directBuffer.put(bytes);
+ directBuffer.rewind();
+ this.buffer = directBuffer;
+ this.address = ((DirectBuffer) (buffer)).address();
+ long nativeIdPtr = 0;
+ nativeIdPtr = createNativeID(address);
+ this.nativeIdPtr = nativeIdPtr;
+ }
+
+ public byte[] getBytes() {
+ return bytes;
+ }
+
+ public ByteBuffer getBuffer() {
+ return buffer;
+ }
+
+ public long getAddress() {
+ return address;
+ }
+
+ public long getNativeIdPtr() {
+ if (nativeIdPtr == 0) {
+ throw new IllegalStateException("native ID not available");
+ }
+ return nativeIdPtr;
+ }
+
+ @Override
+ public String toString() {
+ return strId;
+ }
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ ChannelID that = (ChannelID) o;
+ return strId.equals(that.strId);
+ }
+
+ @Override
+ public int hashCode() {
+ return strId.hashCode();
+ }
+
+ private static native long createNativeID(long idAddress);
+
+ private static native void destroyNativeID(long nativeIdPtr);
+
+ /**
+ * @param id hex string representation of channel id
+ */
+ public static ChannelID from(String id) {
+ return from(id, ChannelID.idStrToBytes(id));
+ }
+
+ /**
+ * @param idBytes bytes representation of channel id
+ */
+ public static ChannelID from(byte[] idBytes) {
+ return from(idBytesToStr(idBytes), idBytes);
+ }
+
+ private static ChannelID from(String strID, byte[] idBytes) {
+ ChannelID id = new ChannelID(strID, idBytes);
+ long nativeIdPtr = id.nativeIdPtr;
+ if (nativeIdPtr != 0) {
+ Reference reference =
+ new FinalizablePhantomReference(id, REFERENCE_QUEUE) {
+ @Override
+ public void finalizeReferent() {
+ destroyNativeID(nativeIdPtr);
+ references.remove(this);
+ }
+ };
+ references.add(reference);
+ }
+ return id;
+ }
+
+ /**
+ * @return a random channel id string
+ */
+ public static String genRandomIdStr() {
+ StringBuilder sb = new StringBuilder();
+ Random random = new Random();
+ for (int i = 0; i < ChannelID.ID_LENGTH * 2; ++i) {
+ sb.append((char) (random.nextInt(6) + 'A'));
+ }
+ return sb.toString();
+ }
+
+ /**
+ * Generate channel name, which will be 20 character
+ *
+ * @param fromTaskId upstream task id
+ * @param toTaskId downstream task id
+ * @return channel name
+ */
+ public static String genIdStr(int fromTaskId, int toTaskId, long ts) {
+ /*
+ | Head | Timestamp | Empty | From | To |
+ | 8 bytes | 4bytes | 4bytes| 2bytes| 2bytes |
+ */
+ Preconditions.checkArgument(fromTaskId < Short.MAX_VALUE,
+ "fromTaskId %d is larger than %d", fromTaskId, Short.MAX_VALUE);
+ Preconditions.checkArgument(toTaskId < Short.MAX_VALUE,
+ "toTaskId %d is larger than %d", fromTaskId, Short.MAX_VALUE);
+ byte[] channelName = new byte[20];
+
+ for (int i = 11; i >= 8; i--) {
+ channelName[i] = (byte) (ts & 0xff);
+ ts >>= 8;
+ }
+
+ channelName[16] = (byte) ((fromTaskId & 0xffff) >> 8);
+ channelName[17] = (byte) (fromTaskId & 0xff);
+ channelName[18] = (byte) ((toTaskId & 0xffff) >> 8);
+ channelName[19] = (byte) (toTaskId & 0xff);
+
+ return ChannelID.idBytesToStr(channelName);
+ }
+
+ /**
+ * @param id hex string representation of channel id
+ * @return bytes representation of channel id
+ */
+ static byte[] idStrToBytes(String id) {
+ byte[] idBytes = BaseEncoding.base16().decode(id.toUpperCase());
+ assert idBytes.length == ChannelID.ID_LENGTH;
+ return idBytes;
+ }
+
+ /**
+ * @param id bytes representation of channel id
+ * @return hex string representation of channel id
+ */
+ static String idBytesToStr(byte[] id) {
+ assert id.length == ChannelID.ID_LENGTH;
+ return BaseEncoding.base16().encode(id).toLowerCase();
+ }
+
+}
+
diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/ChannelInitException.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/ChannelInitException.java
new file mode 100644
index 000000000000..9c5206ba3308
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/ChannelInitException.java
@@ -0,0 +1,24 @@
+package org.ray.streaming.runtime.transfer;
+
+import java.util.ArrayList;
+import java.util.List;
+
+public class ChannelInitException extends Exception {
+
+ private final List abnormalQueues;
+
+ public ChannelInitException(String message, List abnormalQueues) {
+ super(message);
+ this.abnormalQueues = abnormalQueues;
+ }
+
+ public List getAbnormalChannels() {
+ return abnormalQueues;
+ }
+
+ public List getAbnormalChannelsString() {
+ List res = new ArrayList<>();
+ abnormalQueues.forEach(ele -> res.add(ChannelID.idBytesToStr(ele)));
+ return res;
+ }
+}
diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/ChannelInterruptException.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/ChannelInterruptException.java
new file mode 100644
index 000000000000..b922bc415516
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/ChannelInterruptException.java
@@ -0,0 +1,11 @@
+package org.ray.streaming.runtime.transfer;
+
+public class ChannelInterruptException extends RuntimeException {
+ public ChannelInterruptException() {
+ super();
+ }
+
+ public ChannelInterruptException(String message) {
+ super(message);
+ }
+}
diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/ChannelUtils.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/ChannelUtils.java
new file mode 100644
index 000000000000..893c213978be
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/ChannelUtils.java
@@ -0,0 +1,40 @@
+package org.ray.streaming.runtime.transfer;
+
+import java.util.Map;
+import org.ray.streaming.runtime.generated.Streaming;
+import org.ray.streaming.util.Config;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class ChannelUtils {
+ private static final Logger LOGGER = LoggerFactory.getLogger(ChannelUtils.class);
+
+ static byte[] toNativeConf(Map conf) {
+ Streaming.StreamingConfig.Builder builder = Streaming.StreamingConfig.newBuilder();
+ if (conf.containsKey(Config.STREAMING_JOB_NAME)) {
+ builder.setJobName(conf.get(Config.STREAMING_JOB_NAME));
+ }
+ if (conf.containsKey(Config.TASK_JOB_ID)) {
+ builder.setTaskJobId(conf.get(Config.TASK_JOB_ID));
+ }
+ if (conf.containsKey(Config.STREAMING_WORKER_NAME)) {
+ builder.setWorkerName(conf.get(Config.STREAMING_WORKER_NAME));
+ }
+ if (conf.containsKey(Config.STREAMING_OP_NAME)) {
+ builder.setOpName(conf.get(Config.STREAMING_OP_NAME));
+ }
+ if (conf.containsKey(Config.STREAMING_RING_BUFFER_CAPACITY)) {
+ builder.setRingBufferCapacity(
+ Integer.parseInt(conf.get(Config.STREAMING_RING_BUFFER_CAPACITY)));
+ }
+ if (conf.containsKey(Config.STREAMING_EMPTY_MESSAGE_INTERVAL)) {
+ builder.setEmptyMessageInterval(
+ Integer.parseInt(conf.get(Config.STREAMING_EMPTY_MESSAGE_INTERVAL)));
+ }
+ Streaming.StreamingConfig streamingConf = builder.build();
+ LOGGER.info("Streaming native conf {}", streamingConf.toString());
+ return streamingConf.toByteArray();
+ }
+
+}
+
diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/DataMessage.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/DataMessage.java
new file mode 100644
index 000000000000..2fb80b09eb64
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/DataMessage.java
@@ -0,0 +1,54 @@
+package org.ray.streaming.runtime.transfer;
+
+import java.nio.ByteBuffer;
+
+/**
+ * DataMessage represents data between upstream and downstream operator
+ */
+public class DataMessage implements Message {
+ private final ByteBuffer body;
+ private final long msgId;
+ private final long timestamp;
+ private final String channelId;
+
+ public DataMessage(ByteBuffer body, long timestamp, long msgId, String channelId) {
+ this.body = body;
+ this.timestamp = timestamp;
+ this.msgId = msgId;
+ this.channelId = channelId;
+ }
+
+ @Override
+ public ByteBuffer body() {
+ return body;
+ }
+
+ @Override
+ public long timestamp() {
+ return timestamp;
+ }
+
+ /**
+ * @return message id
+ */
+ public long msgId() {
+ return msgId;
+ }
+
+ /**
+ * @return string id of channel where data is coming from
+ */
+ public String channelId() {
+ return channelId;
+ }
+
+ @Override
+ public String toString() {
+ return "DataMessage{" +
+ "body=" + body +
+ ", msgId=" + msgId +
+ ", timestamp=" + timestamp +
+ ", channelId='" + channelId + '\'' +
+ '}';
+ }
+}
diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/DataReader.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/DataReader.java
new file mode 100644
index 000000000000..d1ce9327d54f
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/DataReader.java
@@ -0,0 +1,258 @@
+package org.ray.streaming.runtime.transfer;
+
+import com.google.common.base.Preconditions;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Map;
+import java.util.Queue;
+import org.ray.api.id.ActorId;
+import org.ray.streaming.runtime.util.Platform;
+import org.ray.streaming.util.Config;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * DataReader is wrapper of streaming c++ DataReader, which read data
+ * from channels of upstream workers
+ */
+public class DataReader {
+ private static final Logger LOGGER = LoggerFactory.getLogger(DataReader.class);
+
+ private long nativeReaderPtr;
+ private Queue buf = new LinkedList<>();
+
+ public DataReader(List inputChannels,
+ List fromActors,
+ Map conf) {
+ Preconditions.checkArgument(inputChannels.size() > 0);
+ Preconditions.checkArgument(inputChannels.size() == fromActors.size());
+ byte[][] inputChannelsBytes = inputChannels.stream()
+ .map(ChannelID::idStrToBytes).toArray(byte[][]::new);
+ byte[][] fromActorsBytes = fromActors.stream()
+ .map(ActorId::getBytes).toArray(byte[][]::new);
+ long[] seqIds = new long[inputChannels.size()];
+ long[] msgIds = new long[inputChannels.size()];
+ for (int i = 0; i < inputChannels.size(); i++) {
+ seqIds[i] = 0;
+ msgIds[i] = 0;
+ }
+ long timerInterval = Long.parseLong(
+ conf.getOrDefault(Config.TIMER_INTERVAL_MS, "-1"));
+ String channelType = conf.getOrDefault(Config.CHANNEL_TYPE, Config.DEFAULT_CHANNEL_TYPE);
+ boolean isMock = false;
+ if (Config.MEMORY_CHANNEL.equals(channelType)) {
+ isMock = true;
+ }
+ boolean isRecreate = Boolean.parseBoolean(
+ conf.getOrDefault(Config.IS_RECREATE, "false"));
+ this.nativeReaderPtr = createDataReaderNative(
+ inputChannelsBytes,
+ fromActorsBytes,
+ seqIds,
+ msgIds,
+ timerInterval,
+ isRecreate,
+ ChannelUtils.toNativeConf(conf),
+ isMock
+ );
+ LOGGER.info("create DataReader succeed");
+ }
+
+ // params set by getBundleNative: bundle data address + size
+ private final ByteBuffer getBundleParams = ByteBuffer.allocateDirect(24);
+ // We use direct buffer to reduce gc overhead and memory copy.
+ private final ByteBuffer bundleData = Platform.wrapDirectBuffer(0, 0);
+ private final ByteBuffer bundleMeta = ByteBuffer.allocateDirect(BundleMeta.LENGTH);
+
+ {
+ getBundleParams.order(ByteOrder.nativeOrder());
+ bundleData.order(ByteOrder.nativeOrder());
+ bundleMeta.order(ByteOrder.nativeOrder());
+ }
+
+ /**
+ * Read message from input channels, if timeout, return null.
+ *
+ * @param timeoutMillis timeout
+ * @return message or null
+ */
+ public DataMessage read(long timeoutMillis) {
+ if (buf.isEmpty()) {
+ getBundle(timeoutMillis);
+ // if bundle not empty. empty message still has data size + seqId + msgId
+ if (bundleData.position() < bundleData.limit()) {
+ BundleMeta bundleMeta = new BundleMeta(this.bundleMeta);
+ // barrier
+ if (bundleMeta.getBundleType() == DataBundleType.BARRIER) {
+ throw new UnsupportedOperationException(
+ "Unsupported bundle type " + bundleMeta.getBundleType());
+ } else if (bundleMeta.getBundleType() == DataBundleType.BUNDLE) {
+ String channelID = bundleMeta.getChannelID();
+ long timestamp = bundleMeta.getBundleTs();
+ for (int i = 0; i < bundleMeta.getMessageListSize(); i++) {
+ buf.offer(getDataMessage(bundleData, channelID, timestamp));
+ }
+ } else if (bundleMeta.getBundleType() == DataBundleType.EMPTY) {
+ long messageId = bundleMeta.getLastMessageId();
+ buf.offer(new DataMessage(null, bundleMeta.getBundleTs(),
+ messageId, bundleMeta.getChannelID()));
+ }
+ }
+ }
+ if (buf.isEmpty()) {
+ return null;
+ }
+ return buf.poll();
+ }
+
+ private DataMessage getDataMessage(ByteBuffer bundleData, String channelID, long timestamp) {
+ int dataSize = bundleData.getInt();
+ // msgId
+ long msgId = bundleData.getLong();
+ // msgType
+ bundleData.getInt();
+ // make `data.capacity() == data.remaining()`, because some code used `capacity()`
+ // rather than `remaining()`
+ int position = bundleData.position();
+ int limit = bundleData.limit();
+ bundleData.limit(position + dataSize);
+ ByteBuffer data = bundleData.slice();
+ bundleData.limit(limit);
+ bundleData.position(position + dataSize);
+ return new DataMessage(data, timestamp, msgId, channelID);
+ }
+
+ private void getBundle(long timeoutMillis) {
+ getBundleNative(nativeReaderPtr, timeoutMillis,
+ Platform.getAddress(getBundleParams), Platform.getAddress(bundleMeta));
+ bundleMeta.rewind();
+ long bundleAddress = getBundleParams.getLong(0);
+ int bundleSize = getBundleParams.getInt(8);
+ // This has better performance than NewDirectBuffer or set address/capacity in jni.
+ Platform.wrapDirectBuffer(bundleData, bundleAddress, bundleSize);
+ }
+
+ /**
+ * Stop reader
+ */
+ public void stop() {
+ stopReaderNative(nativeReaderPtr);
+ }
+
+ /**
+ * Close reader to release resource
+ */
+ public void close() {
+ if (nativeReaderPtr == 0) {
+ return;
+ }
+ LOGGER.info("closing DataReader.");
+ closeReaderNative(nativeReaderPtr);
+ nativeReaderPtr = 0;
+ LOGGER.info("closing DataReader done.");
+ }
+
+ private static native long createDataReaderNative(
+ byte[][] inputChannels,
+ byte[][] inputActorIds,
+ long[] seqIds,
+ long[] msgIds,
+ long timerInterval,
+ boolean isRecreate,
+ byte[] configBytes,
+ boolean isMock);
+
+ private native void getBundleNative(long nativeReaderPtr,
+ long timeoutMillis,
+ long params,
+ long metaAddress);
+
+ private native void stopReaderNative(long nativeReaderPtr);
+
+ private native void closeReaderNative(long nativeReaderPtr);
+
+ enum DataBundleType {
+ EMPTY(1),
+ BARRIER(2),
+ BUNDLE(3);
+
+ int code;
+
+ DataBundleType(int code) {
+ this.code = code;
+ }
+ }
+
+ static class BundleMeta {
+ // kMessageBundleHeaderSize + kUniqueIDSize:
+ // magicNum(4b) + bundleTs(8b) + lastMessageId(8b) + messageListSize(4b)
+ // + bundleType(4b) + rawBundleSize(4b) + channelID(20b)
+ static final int LENGTH = 4 + 8 + 8 + 4 + 4 + 4 + 20;
+ private int magicNum;
+ private long bundleTs;
+ private long lastMessageId;
+ private int messageListSize;
+ private DataBundleType bundleType;
+ private String channelID;
+ private int rawBundleSize;
+
+ BundleMeta(ByteBuffer buffer) {
+ // StreamingMessageBundleMeta Deserialization
+ // magicNum
+ magicNum = buffer.getInt();
+ // messageBundleTs
+ bundleTs = buffer.getLong();
+ // lastOffsetSeqId
+ lastMessageId = buffer.getLong();
+ messageListSize = buffer.getInt();
+ int typeInt = buffer.getInt();
+ if (DataBundleType.BUNDLE.code == typeInt) {
+ bundleType = DataBundleType.BUNDLE;
+ } else if (DataBundleType.BARRIER.code == typeInt) {
+ bundleType = DataBundleType.BARRIER;
+ } else {
+ bundleType = DataBundleType.EMPTY;
+ }
+ // rawBundleSize
+ rawBundleSize = buffer.getInt();
+ channelID = getQidString(buffer);
+ }
+
+ private String getQidString(ByteBuffer buffer) {
+ byte[] bytes = new byte[ChannelID.ID_LENGTH];
+ buffer.get(bytes);
+ return ChannelID.idBytesToStr(bytes);
+ }
+
+ public int getMagicNum() {
+ return magicNum;
+ }
+
+ public long getBundleTs() {
+ return bundleTs;
+ }
+
+ public long getLastMessageId() {
+ return lastMessageId;
+ }
+
+ public int getMessageListSize() {
+ return messageListSize;
+ }
+
+ public DataBundleType getBundleType() {
+ return bundleType;
+ }
+
+ public String getChannelID() {
+ return channelID;
+ }
+
+ public int getRawBundleSize() {
+ return rawBundleSize;
+ }
+ }
+
+}
diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/DataWriter.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/DataWriter.java
new file mode 100644
index 000000000000..b0c943b0eb7b
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/DataWriter.java
@@ -0,0 +1,140 @@
+package org.ray.streaming.runtime.transfer;
+
+import com.google.common.base.Preconditions;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+import org.ray.api.id.ActorId;
+import org.ray.streaming.runtime.util.Platform;
+import org.ray.streaming.util.Config;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * DataWriter is a wrapper of streaming c++ DataWriter, which sends data
+ * to downstream workers
+ */
+public class DataWriter {
+ private static final Logger LOGGER = LoggerFactory.getLogger(DataWriter.class);
+
+ private long nativeWriterPtr;
+ private ByteBuffer buffer = ByteBuffer.allocateDirect(0);
+ private long bufferAddress;
+
+ {
+ ensureBuffer(0);
+ }
+
+ /**
+ * @param outputChannels output channels ids
+ * @param toActors downstream output actors
+ * @param conf configuration
+ */
+ public DataWriter(List outputChannels,
+ List toActors,
+ Map conf) {
+ Preconditions.checkArgument(!outputChannels.isEmpty());
+ Preconditions.checkArgument(outputChannels.size() == toActors.size());
+ byte[][] outputChannelsBytes = outputChannels.stream()
+ .map(ChannelID::idStrToBytes).toArray(byte[][]::new);
+ byte[][] toActorsBytes = toActors.stream()
+ .map(ActorId::getBytes).toArray(byte[][]::new);
+ long channelSize = Long.parseLong(
+ conf.getOrDefault(Config.CHANNEL_SIZE, Config.CHANNEL_SIZE_DEFAULT));
+ long[] msgIds = new long[outputChannels.size()];
+ for (int i = 0; i < outputChannels.size(); i++) {
+ msgIds[i] = 0;
+ }
+ String channelType = conf.getOrDefault(Config.CHANNEL_TYPE, Config.DEFAULT_CHANNEL_TYPE);
+ boolean isMock = false;
+ if (Config.MEMORY_CHANNEL.equals(channelType)) {
+ isMock = true;
+ }
+ this.nativeWriterPtr = createWriterNative(
+ outputChannelsBytes,
+ toActorsBytes,
+ msgIds,
+ channelSize,
+ ChannelUtils.toNativeConf(conf),
+ isMock
+ );
+ LOGGER.info("create DataWriter succeed");
+ }
+
+ /**
+ * Write msg into the specified channel
+ *
+ * @param id channel id
+ * @param item message item data section is specified by [position, limit).
+ */
+ public void write(ChannelID id, ByteBuffer item) {
+ int size = item.remaining();
+ ensureBuffer(size);
+ buffer.clear();
+ buffer.put(item);
+ writeMessageNative(nativeWriterPtr, id.getNativeIdPtr(), bufferAddress, size);
+ }
+
+ /**
+ * Write msg into the specified channels
+ *
+ * @param ids channel ids
+ * @param item message item data section is specified by [position, limit).
+ * item doesn't have to be a direct buffer.
+ */
+ public void write(Set ids, ByteBuffer item) {
+ int size = item.remaining();
+ ensureBuffer(size);
+ for (ChannelID id : ids) {
+ buffer.clear();
+ buffer.put(item.duplicate());
+ writeMessageNative(nativeWriterPtr, id.getNativeIdPtr(), bufferAddress, size);
+ }
+ }
+
+ private void ensureBuffer(int size) {
+ if (buffer.capacity() < size) {
+ buffer = ByteBuffer.allocateDirect(size);
+ buffer.order(ByteOrder.nativeOrder());
+ bufferAddress = Platform.getAddress(buffer);
+ }
+ }
+
+ /**
+ * stop writer
+ */
+ public void stop() {
+ stopWriterNative(nativeWriterPtr);
+ }
+
+ /**
+ * close writer to release resources
+ */
+ public void close() {
+ if (nativeWriterPtr == 0) {
+ return;
+ }
+ LOGGER.info("closing data writer.");
+ closeWriterNative(nativeWriterPtr);
+ nativeWriterPtr = 0;
+ LOGGER.info("closing data writer done.");
+ }
+
+ private static native long createWriterNative(
+ byte[][] outputQueueIds,
+ byte[][] outputActorIds,
+ long[] msgIds,
+ long channelSize,
+ byte[] confBytes,
+ boolean isMock);
+
+ private native long writeMessageNative(
+ long nativeQueueProducerPtr, long nativeIdPtr, long address, int size);
+
+ private native void stopWriterNative(long nativeQueueProducerPtr);
+
+ private native void closeWriterNative(long nativeQueueProducerPtr);
+
+}
diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/Message.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/Message.java
new file mode 100644
index 000000000000..b43e713e7bb4
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/Message.java
@@ -0,0 +1,22 @@
+package org.ray.streaming.runtime.transfer;
+
+import java.nio.ByteBuffer;
+
+public interface Message {
+
+ /**
+ * Message data
+ *
+ * Message body is a direct byte buffer, which may be invalid after call next
+ * DataReader#getBundleNative
. Please consume this buffer fully
+ * before next call getBundleNative
.
+ *
+ * @return message body
+ */
+ ByteBuffer body();
+
+ /**
+ * @return timestamp when item is written by upstream DataWriter
+ */
+ long timestamp();
+}
\ No newline at end of file
diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/TransferHandler.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/TransferHandler.java
new file mode 100644
index 000000000000..2307b64e40e9
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/TransferHandler.java
@@ -0,0 +1,72 @@
+package org.ray.streaming.runtime.transfer;
+
+import com.google.common.base.Preconditions;
+import org.ray.runtime.RayNativeRuntime;
+import org.ray.runtime.functionmanager.FunctionDescriptor;
+import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
+import org.ray.runtime.util.JniUtils;
+
+/**
+ * TransferHandler is used for handle direct call based data transfer between workers.
+ * TransferHandler is used by streaming queue for data transfer.
+ */
+public class TransferHandler {
+
+ static {
+ try {
+ Class.forName(RayNativeRuntime.class.getName());
+ } catch (ClassNotFoundException e) {
+ throw new RuntimeException(e);
+ }
+ JniUtils.loadLibrary("streaming_java");
+ }
+
+ private long writerClientNative;
+ private long readerClientNative;
+
+ public TransferHandler(long coreWorkerNative,
+ JavaFunctionDescriptor writerAsyncFunc,
+ JavaFunctionDescriptor writerSyncFunc,
+ JavaFunctionDescriptor readerAsyncFunc,
+ JavaFunctionDescriptor readerSyncFunc) {
+ Preconditions.checkArgument(coreWorkerNative != 0);
+ writerClientNative = createWriterClientNative(
+ coreWorkerNative, writerAsyncFunc, writerSyncFunc);
+ readerClientNative = createReaderClientNative(
+ coreWorkerNative, readerAsyncFunc, readerSyncFunc);
+ }
+
+ public void onWriterMessage(byte[] buffer) {
+ handleWriterMessageNative(writerClientNative, buffer);
+ }
+
+ public byte[] onWriterMessageSync(byte[] buffer) {
+ return handleWriterMessageSyncNative(writerClientNative, buffer);
+ }
+
+ public void onReaderMessage(byte[] buffer) {
+ handleReaderMessageNative(readerClientNative, buffer);
+ }
+
+ public byte[] onReaderMessageSync(byte[] buffer) {
+ return handleReaderMessageSyncNative(readerClientNative, buffer);
+ }
+
+ private native long createWriterClientNative(
+ long coreWorkerNative,
+ FunctionDescriptor asyncFunc,
+ FunctionDescriptor syncFunc);
+
+ private native long createReaderClientNative(
+ long coreWorkerNative,
+ FunctionDescriptor asyncFunc,
+ FunctionDescriptor syncFunc);
+
+ private native void handleWriterMessageNative(long handler, byte[] buffer);
+
+ private native byte[] handleWriterMessageSyncNative(long handler, byte[] buffer);
+
+ private native void handleReaderMessageNative(long handler, byte[] buffer);
+
+ private native byte[] handleReaderMessageSyncNative(long handler, byte[] buffer);
+}
diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/util/EnvUtil.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/util/EnvUtil.java
new file mode 100644
index 000000000000..caf47d4894d8
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/util/EnvUtil.java
@@ -0,0 +1,19 @@
+package org.ray.streaming.runtime.util;
+
+import org.ray.runtime.RayNativeRuntime;
+import org.ray.runtime.util.JniUtils;
+
+public class EnvUtil {
+
+ public static void loadNativeLibraries() {
+ // Explicitly load `RayNativeRuntime`, to make sure `core_worker_library_java`
+ // is loaded before `streaming_java`.
+ try {
+ Class.forName(RayNativeRuntime.class.getName());
+ } catch (ClassNotFoundException e) {
+ throw new RuntimeException(e);
+ }
+ JniUtils.loadLibrary("streaming_java");
+ }
+
+}
diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/util/Platform.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/util/Platform.java
new file mode 100644
index 000000000000..21cda5d6fae6
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/util/Platform.java
@@ -0,0 +1,91 @@
+package org.ray.streaming.runtime.util;
+
+import com.google.common.base.Preconditions;
+import java.lang.reflect.Constructor;
+import java.lang.reflect.Field;
+import java.lang.reflect.InvocationTargetException;
+import java.nio.Buffer;
+import java.nio.ByteBuffer;
+import sun.misc.Unsafe;
+import sun.nio.ch.DirectBuffer;
+
+/**
+ * Based on org.apache.spark.unsafe.Platform
+ */
+public final class Platform {
+
+ public static final Unsafe UNSAFE;
+
+ static {
+ Unsafe unsafe;
+ try {
+ Field unsafeField = Unsafe.class.getDeclaredField("theUnsafe");
+ unsafeField.setAccessible(true);
+ unsafe = (Unsafe) unsafeField.get(null);
+ } catch (Throwable cause) {
+ throw new UnsupportedOperationException("Unsafe is not supported in this platform.");
+ }
+ UNSAFE = unsafe;
+ }
+
+ // Access fields and constructors once and store them, for performance:
+ private static final Constructor> DBB_CONSTRUCTOR;
+ private static final long BUFFER_ADDRESS_FIELD_OFFSET;
+ private static final long BUFFER_CAPACITY_FIELD_OFFSET;
+
+ static {
+ try {
+ Class> cls = Class.forName("java.nio.DirectByteBuffer");
+ Constructor> constructor = cls.getDeclaredConstructor(Long.TYPE, Integer.TYPE);
+ constructor.setAccessible(true);
+ DBB_CONSTRUCTOR = constructor;
+ Field addressField = Buffer.class.getDeclaredField("address");
+ BUFFER_ADDRESS_FIELD_OFFSET = UNSAFE.objectFieldOffset(addressField);
+ Preconditions.checkArgument(BUFFER_ADDRESS_FIELD_OFFSET != 0);
+ Field capacityField = Buffer.class.getDeclaredField("capacity");
+ BUFFER_CAPACITY_FIELD_OFFSET = UNSAFE.objectFieldOffset(capacityField);
+ Preconditions.checkArgument(BUFFER_CAPACITY_FIELD_OFFSET != 0);
+ } catch (ClassNotFoundException | NoSuchMethodException | NoSuchFieldException e) {
+ throw new IllegalStateException(e);
+ }
+ }
+
+ private static final ThreadLocal localEmptyBuffer =
+ ThreadLocal.withInitial(() -> {
+ try {
+ return (ByteBuffer) DBB_CONSTRUCTOR.newInstance(0, 0);
+ } catch (InstantiationException | IllegalAccessException | InvocationTargetException e) {
+ UNSAFE.throwException(e);
+ }
+ throw new IllegalStateException("unreachable");
+ });
+
+ /**
+ * Wrap a buffer [address, address + size) as a DirectByteBuffer.
+ */
+ public static ByteBuffer wrapDirectBuffer(long address, int size) {
+ ByteBuffer buffer = localEmptyBuffer.get().duplicate();
+ UNSAFE.putLong(buffer, BUFFER_ADDRESS_FIELD_OFFSET, address);
+ UNSAFE.putInt(buffer, BUFFER_CAPACITY_FIELD_OFFSET, size);
+ buffer.clear();
+ return buffer;
+ }
+
+ /**
+ * Wrap a buffer [address, address + size) into provided buffer
.
+ */
+ public static void wrapDirectBuffer(ByteBuffer buffer, long address, int size) {
+ UNSAFE.putLong(buffer, BUFFER_ADDRESS_FIELD_OFFSET, address);
+ UNSAFE.putInt(buffer, BUFFER_CAPACITY_FIELD_OFFSET, size);
+ buffer.clear();
+ }
+
+ /**
+ * @param buffer a DirectBuffer backed by off-heap memory
+ * @return address of off-heap memory
+ */
+ public static long getAddress(ByteBuffer buffer) {
+ return ((DirectBuffer) buffer).address();
+ }
+
+}
diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/JobWorker.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/JobWorker.java
new file mode 100644
index 000000000000..bb7e607b03c5
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/JobWorker.java
@@ -0,0 +1,159 @@
+package org.ray.streaming.runtime.worker;
+
+import java.io.Serializable;
+import java.util.Map;
+import org.ray.api.Ray;
+import org.ray.api.annotation.RayRemote;
+import org.ray.runtime.RayMultiWorkerNativeRuntime;
+import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
+import org.ray.streaming.runtime.core.graph.ExecutionGraph;
+import org.ray.streaming.runtime.core.graph.ExecutionNode;
+import org.ray.streaming.runtime.core.graph.ExecutionNode.NodeType;
+import org.ray.streaming.runtime.core.graph.ExecutionTask;
+import org.ray.streaming.runtime.core.processor.OneInputProcessor;
+import org.ray.streaming.runtime.core.processor.SourceProcessor;
+import org.ray.streaming.runtime.core.processor.StreamProcessor;
+import org.ray.streaming.runtime.transfer.TransferHandler;
+import org.ray.streaming.runtime.util.EnvUtil;
+import org.ray.streaming.runtime.worker.context.WorkerContext;
+import org.ray.streaming.runtime.worker.tasks.OneInputStreamTask;
+import org.ray.streaming.runtime.worker.tasks.SourceStreamTask;
+import org.ray.streaming.runtime.worker.tasks.StreamTask;
+import org.ray.streaming.util.Config;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * The stream job worker, it is a ray actor.
+ */
+@RayRemote
+public class JobWorker implements Serializable {
+ private static final Logger LOGGER = LoggerFactory.getLogger(JobWorker.class);
+
+ static {
+ EnvUtil.loadNativeLibraries();
+ }
+
+ private int taskId;
+ private Map config;
+ private WorkerContext workerContext;
+ private ExecutionNode executionNode;
+ private ExecutionTask executionTask;
+ private ExecutionGraph executionGraph;
+ private StreamProcessor streamProcessor;
+ private NodeType nodeType;
+ private StreamTask task;
+ private TransferHandler transferHandler;
+
+ public Boolean init(WorkerContext workerContext) {
+ this.workerContext = workerContext;
+ this.taskId = workerContext.getTaskId();
+ this.config = workerContext.getConfig();
+ this.executionGraph = this.workerContext.getExecutionGraph();
+ this.executionTask = executionGraph.getExecutionTaskByTaskId(taskId);
+ this.executionNode = executionGraph.getExecutionNodeByTaskId(taskId);
+
+ this.nodeType = executionNode.getNodeType();
+ this.streamProcessor = executionNode.getStreamProcessor();
+ LOGGER.debug("Initializing StreamWorker, taskId: {}, operator: {}.", taskId, streamProcessor);
+
+ String channelType = (String) this.config.getOrDefault(
+ Config.CHANNEL_TYPE, Config.DEFAULT_CHANNEL_TYPE);
+ if (channelType.equals(Config.NATIVE_CHANNEL)) {
+ transferHandler = new TransferHandler(
+ getNativeCoreWorker(),
+ new JavaFunctionDescriptor(JobWorker.class.getName(), "onWriterMessage", "([B)V"),
+ new JavaFunctionDescriptor(JobWorker.class.getName(), "onWriterMessageSync", "([B)[B"),
+ new JavaFunctionDescriptor(JobWorker.class.getName(), "onReaderMessage", "([B)V"),
+ new JavaFunctionDescriptor(JobWorker.class.getName(), "onReaderMessageSync", "([B)[B"));
+ }
+ task = createStreamTask();
+ task.start();
+ return true;
+ }
+
+ private StreamTask createStreamTask() {
+ if (streamProcessor instanceof OneInputProcessor) {
+ return new OneInputStreamTask(taskId, streamProcessor, this);
+ } else if (streamProcessor instanceof SourceProcessor) {
+ return new SourceStreamTask(taskId, streamProcessor, this);
+ } else {
+ throw new RuntimeException("Unsupported type: " + streamProcessor);
+ }
+ }
+
+ public int getTaskId() {
+ return taskId;
+ }
+
+ public Map getConfig() {
+ return config;
+ }
+
+ public WorkerContext getWorkerContext() {
+ return workerContext;
+ }
+
+ public NodeType getNodeType() {
+ return nodeType;
+ }
+
+ public ExecutionNode getExecutionNode() {
+ return executionNode;
+ }
+
+ public ExecutionTask getExecutionTask() {
+ return executionTask;
+ }
+
+ public ExecutionGraph getExecutionGraph() {
+ return executionGraph;
+ }
+
+ public StreamProcessor getStreamProcessor() {
+ return streamProcessor;
+ }
+
+ public StreamTask getTask() {
+ return task;
+ }
+
+ /**
+ * Used by upstream streaming queue to send data to this actor
+ */
+ public void onReaderMessage(byte[] buffer) {
+ transferHandler.onReaderMessage(buffer);
+ }
+
+ /**
+ * Used by upstream streaming queue to send data to this actor
+ * and receive result from this actor
+ */
+ public byte[] onReaderMessageSync(byte[] buffer) {
+ return transferHandler.onReaderMessageSync(buffer);
+ }
+
+ /**
+ * Used by downstream streaming queue to send data to this actor
+ */
+ public void onWriterMessage(byte[] buffer) {
+ transferHandler.onWriterMessage(buffer);
+ }
+
+ /**
+ * Used by downstream streaming queue to send data to this actor
+ * and receive result from this actor
+ */
+ public byte[] onWriterMessageSync(byte[] buffer) {
+ return transferHandler.onWriterMessageSync(buffer);
+ }
+
+ private static long getNativeCoreWorker() {
+ long pointer = 0;
+ if (Ray.internal() instanceof RayMultiWorkerNativeRuntime) {
+ pointer = ((RayMultiWorkerNativeRuntime) Ray.internal())
+ .getCurrentRuntime().getNativeCoreWorkerPointer();
+ }
+ return pointer;
+ }
+}
diff --git a/java/streaming/src/main/java/org/ray/streaming/core/runtime/context/RayRuntimeContext.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/context/RayRuntimeContext.java
similarity index 66%
rename from java/streaming/src/main/java/org/ray/streaming/core/runtime/context/RayRuntimeContext.java
rename to streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/context/RayRuntimeContext.java
index 6d796c30eff5..e6779733c235 100644
--- a/java/streaming/src/main/java/org/ray/streaming/core/runtime/context/RayRuntimeContext.java
+++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/context/RayRuntimeContext.java
@@ -1,19 +1,21 @@
-package org.ray.streaming.core.runtime.context;
+package org.ray.streaming.runtime.worker.context;
-import static org.ray.streaming.util.ConfigKey.STREAMING_MAX_BATCH_COUNT;
+import static org.ray.streaming.util.Config.STREAMING_BATCH_MAX_COUNT;
import java.util.Map;
-import org.ray.streaming.core.graph.ExecutionTask;
+
+import org.ray.streaming.api.context.RuntimeContext;
+import org.ray.streaming.runtime.core.graph.ExecutionTask;
/**
* Use Ray to implement RuntimeContext.
*/
public class RayRuntimeContext implements RuntimeContext {
-
private int taskId;
private int taskIndex;
private int parallelism;
private Long batchId;
+ private final Long maxBatch;
private Map config;
public RayRuntimeContext(ExecutionTask executionTask, Map config,
@@ -22,6 +24,11 @@ public RayRuntimeContext(ExecutionTask executionTask, Map config
this.config = config;
this.taskIndex = executionTask.getTaskIndex();
this.parallelism = parallelism;
+ if (config.containsKey(STREAMING_BATCH_MAX_COUNT)) {
+ this.maxBatch = Long.valueOf(String.valueOf(config.get(STREAMING_BATCH_MAX_COUNT)));
+ } else {
+ this.maxBatch = Long.MAX_VALUE;
+ }
}
@Override
@@ -46,10 +53,7 @@ public Long getBatchId() {
@Override
public Long getMaxBatch() {
- if (config.containsKey(STREAMING_MAX_BATCH_COUNT)) {
- return Long.valueOf(String.valueOf(config.get(STREAMING_MAX_BATCH_COUNT)));
- }
- return Long.MAX_VALUE;
+ return maxBatch;
}
public void setBatchId(Long batchId) {
diff --git a/java/streaming/src/main/java/org/ray/streaming/core/runtime/context/WorkerContext.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/context/WorkerContext.java
similarity index 88%
rename from java/streaming/src/main/java/org/ray/streaming/core/runtime/context/WorkerContext.java
rename to streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/context/WorkerContext.java
index 39117fd7b225..567909f81179 100644
--- a/java/streaming/src/main/java/org/ray/streaming/core/runtime/context/WorkerContext.java
+++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/context/WorkerContext.java
@@ -1,8 +1,8 @@
-package org.ray.streaming.core.runtime.context;
+package org.ray.streaming.runtime.worker.context;
import java.io.Serializable;
import java.util.Map;
-import org.ray.streaming.core.graph.ExecutionGraph;
+import org.ray.streaming.runtime.core.graph.ExecutionGraph;
/**
* Encapsulate the context information for worker initialization.
diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/tasks/InputStreamTask.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/tasks/InputStreamTask.java
new file mode 100644
index 000000000000..eed12f705cbc
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/tasks/InputStreamTask.java
@@ -0,0 +1,53 @@
+package org.ray.streaming.runtime.worker.tasks;
+
+import org.ray.runtime.util.Serializer;
+import org.ray.streaming.runtime.core.processor.Processor;
+import org.ray.streaming.runtime.transfer.Message;
+import org.ray.streaming.runtime.worker.JobWorker;
+import org.ray.streaming.util.Config;
+
+public abstract class InputStreamTask extends StreamTask {
+ private volatile boolean running = true;
+ private volatile boolean stopped = false;
+ private long readTimeoutMillis;
+
+ public InputStreamTask(int taskId, Processor processor, JobWorker streamWorker) {
+ super(taskId, processor, streamWorker);
+ readTimeoutMillis = Long.parseLong((String) streamWorker.getConfig()
+ .getOrDefault(Config.READ_TIMEOUT_MS, Config.DEFAULT_READ_TIMEOUT_MS));
+ }
+
+ @Override
+ protected void init() {
+ }
+
+ @Override
+ public void run() {
+ while (running) {
+ Message item = reader.read(readTimeoutMillis);
+ if (item != null) {
+ byte[] bytes = new byte[item.body().remaining()];
+ item.body().get(bytes);
+ Object obj = Serializer.decode(bytes);
+ processor.process(obj);
+ }
+ }
+ stopped = true;
+ }
+
+ @Override
+ protected void cancelTask() throws Exception {
+ running = false;
+ while (!stopped) {
+ }
+ }
+
+ @Override
+ public String toString() {
+ final StringBuilder sb = new StringBuilder("InputStreamTask{");
+ sb.append("taskId=").append(taskId);
+ sb.append(", processor=").append(processor);
+ sb.append('}');
+ return sb.toString();
+ }
+}
diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/tasks/OneInputStreamTask.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/tasks/OneInputStreamTask.java
new file mode 100644
index 000000000000..0b9491f76679
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/tasks/OneInputStreamTask.java
@@ -0,0 +1,11 @@
+package org.ray.streaming.runtime.worker.tasks;
+
+import org.ray.streaming.runtime.core.processor.Processor;
+import org.ray.streaming.runtime.worker.JobWorker;
+
+public class OneInputStreamTask extends InputStreamTask {
+
+ public OneInputStreamTask(int taskId, Processor processor, JobWorker streamWorker) {
+ super(taskId, processor, streamWorker);
+ }
+}
diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/tasks/SourceStreamTask.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/tasks/SourceStreamTask.java
new file mode 100644
index 000000000000..74a197708e0c
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/tasks/SourceStreamTask.java
@@ -0,0 +1,30 @@
+package org.ray.streaming.runtime.worker.tasks;
+
+import org.ray.streaming.runtime.core.processor.Processor;
+import org.ray.streaming.runtime.core.processor.SourceProcessor;
+import org.ray.streaming.runtime.worker.JobWorker;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class SourceStreamTask extends StreamTask {
+ private static final Logger LOGGER = LoggerFactory.getLogger(SourceStreamTask.class);
+
+ public SourceStreamTask(int taskId, Processor processor, JobWorker worker) {
+ super(taskId, processor, worker);
+ }
+
+ @Override
+ protected void init() {
+ }
+
+ @Override
+ public void run() {
+ final SourceProcessor sourceProcessor = (SourceProcessor) this.processor;
+ sourceProcessor.run();
+ }
+
+ @Override
+ protected void cancelTask() throws Exception {
+ }
+
+}
diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/tasks/StreamTask.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/tasks/StreamTask.java
new file mode 100644
index 000000000000..7d4f397ba04f
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/tasks/StreamTask.java
@@ -0,0 +1,134 @@
+package org.ray.streaming.runtime.worker.tasks;
+
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import org.ray.api.Ray;
+import org.ray.api.RayActor;
+import org.ray.api.id.ActorId;
+import org.ray.streaming.api.collector.Collector;
+import org.ray.streaming.api.context.RuntimeContext;
+import org.ray.streaming.runtime.core.collector.OutputCollector;
+import org.ray.streaming.runtime.core.graph.ExecutionEdge;
+import org.ray.streaming.runtime.core.graph.ExecutionGraph;
+import org.ray.streaming.runtime.core.graph.ExecutionNode;
+import org.ray.streaming.runtime.core.processor.Processor;
+import org.ray.streaming.runtime.transfer.ChannelID;
+import org.ray.streaming.runtime.transfer.DataReader;
+import org.ray.streaming.runtime.transfer.DataWriter;
+import org.ray.streaming.runtime.worker.JobWorker;
+import org.ray.streaming.runtime.worker.context.RayRuntimeContext;
+import org.ray.streaming.util.Config;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public abstract class StreamTask implements Runnable {
+ private static final Logger LOG = LoggerFactory.getLogger(StreamTask.class);
+
+ protected int taskId;
+ protected Processor processor;
+ protected JobWorker worker;
+ protected DataReader reader;
+ private Map writers;
+ private Thread thread;
+
+ public StreamTask(int taskId, Processor processor, JobWorker worker) {
+ this.taskId = taskId;
+ this.processor = processor;
+ this.worker = worker;
+ prepareTask();
+
+ this.thread = new Thread(Ray.wrapRunnable(this), this.getClass().getName()
+ + "-" + System.currentTimeMillis());
+ this.thread.setDaemon(true);
+ }
+
+ private void prepareTask() {
+ Map queueConf = new HashMap<>();
+ worker.getConfig().forEach((k, v) -> queueConf.put(k, String.valueOf(v)));
+ String queueSize = (String) worker.getConfig()
+ .getOrDefault(Config.CHANNEL_SIZE, Config.CHANNEL_SIZE_DEFAULT);
+ queueConf.put(Config.CHANNEL_SIZE, queueSize);
+ queueConf.put(Config.TASK_JOB_ID, Ray.getRuntimeContext().getCurrentJobId().toString());
+ String channelType = (String) worker.getConfig()
+ .getOrDefault(Config.CHANNEL_TYPE, Config.MEMORY_CHANNEL);
+ queueConf.put(Config.CHANNEL_TYPE, channelType);
+
+ ExecutionGraph executionGraph = worker.getExecutionGraph();
+ ExecutionNode executionNode = worker.getExecutionNode();
+
+ // writers
+ writers = new HashMap<>();
+ List outputEdges = executionNode.getOutputEdges();
+ List collectors = new ArrayList<>();
+ for (ExecutionEdge edge : outputEdges) {
+ Map outputActorIds = new HashMap<>();
+ Map> taskId2Worker = executionGraph
+ .getTaskId2WorkerByNodeId(edge.getTargetNodeId());
+ taskId2Worker.forEach((targetTaskId, targetActor) -> {
+ String queueName = ChannelID.genIdStr(taskId, targetTaskId, executionGraph.getBuildTime());
+ outputActorIds.put(queueName, targetActor.getId());
+ });
+
+ if (!outputActorIds.isEmpty()) {
+ List channelIDs = new ArrayList<>();
+ List toActorIds = new ArrayList<>();
+ outputActorIds.forEach((k, v) -> {
+ channelIDs.add(k);
+ toActorIds.add(v);
+ });
+ DataWriter writer = new DataWriter(channelIDs, toActorIds, queueConf);
+ LOG.info("Create DataWriter succeed.");
+ writers.put(edge, writer);
+ collectors.add(new OutputCollector(channelIDs, writer, edge.getPartition()));
+ }
+ }
+
+ // consumer
+ List inputEdges = executionNode.getInputsEdges();
+ Map inputActorIds = new HashMap<>();
+ for (ExecutionEdge edge : inputEdges) {
+ Map> taskId2Worker = executionGraph
+ .getTaskId2WorkerByNodeId(edge.getSrcNodeId());
+ taskId2Worker.forEach((srcTaskId, srcActor) -> {
+ String queueName = ChannelID.genIdStr(srcTaskId, taskId, executionGraph.getBuildTime());
+ inputActorIds.put(queueName, srcActor.getId());
+ });
+ }
+ if (!inputActorIds.isEmpty()) {
+ List channelIDs = new ArrayList<>();
+ List fromActorIds = new ArrayList<>();
+ inputActorIds.forEach((k, v) -> {
+ channelIDs.add(k);
+ fromActorIds.add(v);
+ });
+ LOG.info("Register queue consumer, queues {}.", channelIDs);
+ reader = new DataReader(channelIDs, fromActorIds, queueConf);
+ }
+
+ RuntimeContext runtimeContext = new RayRuntimeContext(
+ worker.getExecutionTask(), worker.getConfig(), executionNode.getParallelism());
+
+ processor.open(collectors, runtimeContext);
+
+ Runtime.getRuntime().addShutdownHook(new Thread(() -> {
+ try {
+ // Make DataReader stop read data when MockQueue destructor gets called to avoid crash
+ StreamTask.this.cancelTask();
+ } catch (Exception e) {
+ e.printStackTrace();
+ }
+ }));
+ }
+
+ protected abstract void init() throws Exception;
+
+ protected abstract void cancelTask() throws Exception;
+
+ public void start() {
+ this.thread.start();
+ LOG.info("started {}-{}", this.getClass().getSimpleName(), taskId);
+ }
+
+}
diff --git a/streaming/java/streaming-runtime/src/main/resources/META-INF/services/org.ray.streaming.schedule.JobScheduler b/streaming/java/streaming-runtime/src/main/resources/META-INF/services/org.ray.streaming.schedule.JobScheduler
new file mode 100644
index 000000000000..53719a32af25
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/main/resources/META-INF/services/org.ray.streaming.schedule.JobScheduler
@@ -0,0 +1 @@
+org.ray.streaming.runtime.schedule.JobSchedulerImpl
\ No newline at end of file
diff --git a/java/streaming/src/test/java/org/ray/streaming/demo/WordCountTest.java b/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/demo/WordCountTest.java
similarity index 88%
rename from java/streaming/src/test/java/org/ray/streaming/demo/WordCountTest.java
rename to streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/demo/WordCountTest.java
index d104c8e6d1af..e6427120a52c 100644
--- a/java/streaming/src/test/java/org/ray/streaming/demo/WordCountTest.java
+++ b/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/demo/WordCountTest.java
@@ -1,4 +1,4 @@
-package org.ray.streaming.demo;
+package org.ray.streaming.runtime.demo;
import com.google.common.collect.ImmutableMap;
import org.ray.streaming.api.context.StreamingContext;
@@ -6,7 +6,7 @@
import org.ray.streaming.api.function.impl.ReduceFunction;
import org.ray.streaming.api.function.impl.SinkFunction;
import org.ray.streaming.api.stream.StreamSource;
-import org.ray.streaming.util.ConfigKey;
+import org.ray.streaming.util.Config;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
@@ -31,7 +31,8 @@ public class WordCountTest implements Serializable {
public void testWordCount() {
StreamingContext streamingContext = StreamingContext.buildContext();
Map config = new HashMap<>();
- config.put(ConfigKey.STREAMING_MAX_BATCH_COUNT, 1);
+ config.put(Config.STREAMING_BATCH_MAX_COUNT, 1);
+ config.put(Config.CHANNEL_TYPE, Config.MEMORY_CHANNEL);
streamingContext.withConfig(config);
List text = new ArrayList<>();
text.add("hello world eagle eagle eagle");
@@ -46,7 +47,8 @@ public void testWordCount() {
.keyBy(pair -> pair.word)
.reduce((ReduceFunction) (oldValue, newValue) ->
new WordAndCount(oldValue.word, oldValue.count + newValue.count))
- .sink((SinkFunction) result -> wordCount.put(result.word, result.count));
+ .sink((SinkFunction)
+ result -> wordCount.put(result.word, result.count));
streamingContext.execute();
diff --git a/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/schedule/TaskAssignImplTest.java b/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/schedule/TaskAssignImplTest.java
new file mode 100644
index 000000000000..f6057f28984f
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/schedule/TaskAssignImplTest.java
@@ -0,0 +1,75 @@
+package org.ray.streaming.runtime.schedule;
+
+import java.util.ArrayList;
+import java.util.List;
+
+import com.google.common.collect.Lists;
+import org.ray.api.RayActor;
+import org.ray.api.id.ActorId;
+import org.ray.api.id.ObjectId;
+import org.ray.runtime.actor.LocalModeRayActor;
+import org.ray.streaming.api.context.StreamingContext;
+import org.ray.streaming.api.partition.impl.RoundRobinPartition;
+import org.ray.streaming.api.stream.DataStream;
+import org.ray.streaming.api.stream.StreamSink;
+import org.ray.streaming.api.stream.StreamSource;
+import org.ray.streaming.runtime.core.graph.ExecutionEdge;
+import org.ray.streaming.runtime.core.graph.ExecutionGraph;
+import org.ray.streaming.runtime.core.graph.ExecutionNode;
+import org.ray.streaming.runtime.core.graph.ExecutionNode.NodeType;
+import org.ray.streaming.runtime.worker.JobWorker;
+import org.ray.streaming.plan.Plan;
+import org.ray.streaming.plan.PlanBuilder;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.testng.Assert;
+import org.testng.annotations.Test;
+
+public class TaskAssignImplTest {
+
+ private static final Logger LOGGER = LoggerFactory.getLogger(TaskAssignImplTest.class);
+
+ @Test
+ public void testTaskAssignImpl() {
+ Plan plan = buildDataSyncPlan();
+
+ List> workers = new ArrayList<>();
+ for(int i = 0; i < plan.getPlanVertexList().size(); i++) {
+ workers.add(new LocalModeRayActor(ActorId.fromRandom(), ObjectId.fromRandom()));
+ }
+
+ ITaskAssign taskAssign = new TaskAssignImpl();
+ ExecutionGraph executionGraph = taskAssign.assign(plan, workers);
+
+ List executionNodeList = executionGraph.getExecutionNodeList();
+
+ Assert.assertEquals(executionNodeList.size(), 2);
+ ExecutionNode sourceNode = executionNodeList.get(0);
+ Assert.assertEquals(sourceNode.getNodeType(), NodeType.SOURCE);
+ Assert.assertEquals(sourceNode.getExecutionTasks().size(), 1);
+ Assert.assertEquals(sourceNode.getOutputEdges().size(), 1);
+
+ List sourceExecutionEdges = sourceNode.getOutputEdges();
+
+ Assert.assertEquals(sourceExecutionEdges.size(), 1);
+ ExecutionEdge source2Sink = sourceExecutionEdges.get(0);
+
+ Assert.assertEquals(source2Sink.getPartition().getClass(), RoundRobinPartition.class);
+
+ ExecutionNode sinkNode = executionNodeList.get(1);
+ Assert.assertEquals(sinkNode.getNodeType(), NodeType.SINK);
+ Assert.assertEquals(sinkNode.getExecutionTasks().size(), 1);
+ Assert.assertEquals(sinkNode.getOutputEdges().size(), 0);
+ }
+
+ public Plan buildDataSyncPlan() {
+ StreamingContext streamingContext = StreamingContext.buildContext();
+ DataStream dataStream = StreamSource.buildSource(streamingContext,
+ Lists.newArrayList("a", "b", "c"));
+ StreamSink streamSink = dataStream.sink(x -> LOGGER.info(x));
+ PlanBuilder planBuilder = new PlanBuilder(Lists.newArrayList(streamSink));
+
+ Plan plan = planBuilder.buildPlan();
+ return plan;
+ }
+}
diff --git a/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/streamingqueue/StreamingQueueTest.java b/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/streamingqueue/StreamingQueueTest.java
new file mode 100644
index 000000000000..0a65b1abc387
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/streamingqueue/StreamingQueueTest.java
@@ -0,0 +1,234 @@
+package org.ray.streaming.runtime.streamingqueue;
+
+import com.google.common.collect.ImmutableMap;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.FileOutputStream;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.io.Serializable;
+import java.lang.management.ManagementFactory;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import org.ray.api.Ray;
+import org.ray.api.RayActor;
+import org.ray.api.options.ActorCreationOptions;
+import org.ray.api.options.ActorCreationOptions.Builder;
+import org.ray.streaming.api.context.StreamingContext;
+import org.ray.streaming.api.function.impl.FlatMapFunction;
+import org.ray.streaming.api.function.impl.ReduceFunction;
+import org.ray.streaming.api.stream.StreamSource;
+import org.ray.streaming.runtime.transfer.ChannelID;
+import org.ray.streaming.runtime.util.EnvUtil;
+import org.ray.streaming.util.Config;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.testng.Assert;
+import org.testng.annotations.AfterMethod;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.BeforeMethod;
+import org.testng.annotations.Test;
+
+public class StreamingQueueTest implements Serializable {
+ private static Logger LOGGER = LoggerFactory.getLogger(StreamingQueueTest.class);
+
+ static {
+ EnvUtil.loadNativeLibraries();
+ }
+
+ @org.testng.annotations.BeforeSuite
+ public void suiteSetUp() throws Exception {
+ LOGGER.info("Do set up");
+ String management = ManagementFactory.getRuntimeMXBean().getName();
+ String pid = management.split("@")[0];
+
+ LOGGER.info("StreamingQueueTest pid: {}", pid);
+ LOGGER.info("java.library.path = {}", System.getProperty("java.library.path"));
+ }
+
+ @org.testng.annotations.AfterSuite
+ public void suiteTearDown() throws Exception {
+ LOGGER.warn("Do tear down");
+ }
+
+ @BeforeClass
+ public void setUp() {
+ }
+
+ @BeforeMethod
+ void beforeMethod() {
+
+ LOGGER.info("beforeTest");
+ Ray.shutdown();
+ System.setProperty("ray.resources", "CPU:4,RES-A:4");
+ System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");
+ System.setProperty("ray.run-mode", "CLUSTER");
+ System.setProperty("ray.redirect-output", "true");
+ // ray init
+ Ray.init();
+ }
+
+ @AfterMethod
+ void afterMethod() {
+ LOGGER.info("afterTest");
+ Ray.shutdown();
+ System.clearProperty("ray.run-mode");
+ }
+
+ @Test(timeOut = 3000000)
+ public void testReaderWriter() {
+ LOGGER.info("StreamingQueueTest.testReaderWriter run-mode: {}",
+ System.getProperty("ray.run-mode"));
+ Ray.shutdown();
+ System.setProperty("ray.resources", "CPU:4,RES-A:4");
+ System.setProperty("ray.raylet.config.num_workers_per_process_java", "1");
+
+ System.setProperty("ray.run-mode", "CLUSTER");
+ System.setProperty("ray.redirect-output", "true");
+ // ray init
+ Ray.init();
+
+ ActorCreationOptions.Builder builder = new Builder();
+
+ RayActor writerActor = Ray.createActor(WriterWorker::new, "writer",
+ builder.createActorCreationOptions());
+ RayActor readerActor = Ray.createActor(ReaderWorker::new, "reader",
+ builder.createActorCreationOptions());
+
+ LOGGER.info("call getName on writerActor: {}",
+ Ray.call(WriterWorker::getName, writerActor).get());
+ LOGGER.info("call getName on readerActor: {}",
+ Ray.call(ReaderWorker::getName, readerActor).get());
+
+ // LOGGER.info(Ray.call(WriterWorker::testCallReader, writerActor, readerActor).get());
+ List outputQueueList = new ArrayList<>();
+ List inputQueueList = new ArrayList<>();
+ int queueNum = 2;
+ for (int i = 0; i < queueNum; ++i) {
+ String qid = ChannelID.genRandomIdStr();
+ LOGGER.info("getRandomQueueId: {}", qid);
+ inputQueueList.add(qid);
+ outputQueueList.add(qid);
+ readerActor.getId();
+ }
+
+ final int msgCount = 100;
+ Ray.call(ReaderWorker::init, readerActor, inputQueueList, writerActor, msgCount);
+ try {
+ Thread.sleep(1000);
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ Ray.call(WriterWorker::init, writerActor, outputQueueList, readerActor, msgCount);
+
+ long time = 0;
+ while (time < 20000 &&
+ Ray.call(ReaderWorker::getTotalMsg, readerActor).get() < msgCount * queueNum) {
+ try {
+ Thread.sleep(1000);
+ time += 1000;
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ }
+
+ Assert.assertEquals(
+ Ray.call(ReaderWorker::getTotalMsg, readerActor).get().intValue(),
+ msgCount * queueNum);
+ }
+
+ @Test(timeOut = 60000)
+ public void testWordCount() {
+ LOGGER.info("StreamingQueueTest.testWordCount run-mode: {}",
+ System.getProperty("ray.run-mode"));
+ String resultFile = "/tmp/org.ray.streaming.runtime.streamingqueue.testWordCount.txt";
+ deleteResultFile(resultFile);
+
+ Map wordCount = new ConcurrentHashMap<>();
+ StreamingContext streamingContext = StreamingContext.buildContext();
+ Map config = new HashMap<>();
+ config.put(Config.STREAMING_BATCH_MAX_COUNT, 1);
+ config.put(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL);
+ config.put(Config.CHANNEL_SIZE, "100000");
+ streamingContext.withConfig(config);
+ List text = new ArrayList<>();
+ text.add("hello world eagle eagle eagle");
+ StreamSource streamSource = StreamSource.buildSource(streamingContext, text);
+ streamSource
+ .flatMap((FlatMapFunction) (value, collector) -> {
+ String[] records = value.split(" ");
+ for (String record : records) {
+ collector.collect(new WordAndCount(record, 1));
+ }
+ })
+ .keyBy(pair -> pair.word)
+ .reduce((ReduceFunction) (oldValue, newValue) -> {
+ LOGGER.info("reduce: {} {}", oldValue, newValue);
+ return new WordAndCount(oldValue.word, oldValue.count + newValue.count);
+ })
+ .sink(s -> {
+ LOGGER.info("sink {} {}", s.word, s.count);
+ wordCount.put(s.word, s.count);
+ serializeResultToFile(resultFile, wordCount);
+ });
+
+ streamingContext.execute();
+
+ Map checkWordCount =
+ (Map) deserializeResultFromFile(resultFile);
+ // Sleep until the count for every word is computed.
+ while (checkWordCount == null || checkWordCount.size() < 3) {
+ LOGGER.info("sleep");
+ try {
+ Thread.sleep(1000);
+ } catch (InterruptedException e) {
+ LOGGER.warn("Got an exception while sleeping.", e);
+ }
+ checkWordCount = (Map) deserializeResultFromFile(resultFile);
+ }
+ LOGGER.info("check");
+ Assert.assertEquals(checkWordCount,
+ ImmutableMap.of("eagle", 3, "hello", 1, "world", 1));
+ }
+
+ private void serializeResultToFile(String fileName, Object obj) {
+ try {
+ ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(fileName));
+ out.writeObject(obj);
+ } catch (Exception e) {
+ LOGGER.error(String.valueOf(e));
+ }
+ }
+
+ private Object deserializeResultFromFile(String fileName) {
+ Map checkWordCount = null;
+ try {
+ ObjectInputStream in = new ObjectInputStream(new FileInputStream(fileName));
+ checkWordCount = (Map) in.readObject();
+ Assert.assertEquals(checkWordCount,
+ ImmutableMap.of("eagle", 3, "hello", 1, "world", 1));
+ } catch (Exception e) {
+ LOGGER.error(String.valueOf(e));
+ }
+ return checkWordCount;
+ }
+
+ private static class WordAndCount implements Serializable {
+
+ public final String word;
+ public final Integer count;
+
+ public WordAndCount(String key, Integer count) {
+ this.word = key;
+ this.count = count;
+ }
+ }
+
+ private void deleteResultFile(String path) {
+ File file = new File(path);
+ file.deleteOnExit();
+ }
+}
diff --git a/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/streamingqueue/Worker.java b/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/streamingqueue/Worker.java
new file mode 100644
index 000000000000..2c105576df73
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/streamingqueue/Worker.java
@@ -0,0 +1,280 @@
+package org.ray.streaming.runtime.streamingqueue;
+
+import java.lang.management.ManagementFactory;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Random;
+import org.ray.api.Ray;
+import org.ray.api.RayActor;
+import org.ray.api.annotation.RayRemote;
+import org.ray.api.id.ActorId;
+import org.ray.runtime.RayMultiWorkerNativeRuntime;
+import org.ray.runtime.actor.NativeRayActor;
+import org.ray.runtime.functionmanager.JavaFunctionDescriptor;
+import org.ray.streaming.runtime.transfer.ChannelID;
+import org.ray.streaming.runtime.transfer.DataMessage;
+import org.ray.streaming.runtime.transfer.DataReader;
+import org.ray.streaming.runtime.transfer.DataWriter;
+import org.ray.streaming.runtime.transfer.TransferHandler;
+import org.ray.streaming.util.Config;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import org.testng.Assert;
+
+public class Worker {
+ private static final Logger LOGGER = LoggerFactory.getLogger(Worker.class);
+
+ protected TransferHandler transferHandler = null;
+
+ public Worker() {
+ transferHandler = new TransferHandler(((RayMultiWorkerNativeRuntime) Ray.internal())
+ .getCurrentRuntime().getNativeCoreWorkerPointer(),
+ new JavaFunctionDescriptor(Worker.class.getName(),
+ "onWriterMessage", "([B)V"),
+ new JavaFunctionDescriptor(Worker.class.getName(),
+ "onWriterMessageSync", "([B)[B"),
+ new JavaFunctionDescriptor(Worker.class.getName(),
+ "onReaderMessage", "([B)V"),
+ new JavaFunctionDescriptor(Worker.class.getName(),
+ "onReaderMessageSync", "([B)[B"));
+ }
+
+ public void onReaderMessage(byte[] buffer) {
+ transferHandler.onReaderMessage(buffer);
+ }
+
+ public byte[] onReaderMessageSync(byte[] buffer) {
+ return transferHandler.onReaderMessageSync(buffer);
+ }
+
+ public void onWriterMessage(byte[] buffer) {
+ transferHandler.onWriterMessage(buffer);
+ }
+
+ public byte[] onWriterMessageSync(byte[] buffer) {
+ return transferHandler.onWriterMessageSync(buffer);
+ }
+}
+
+@RayRemote
+class ReaderWorker extends Worker {
+ private static final Logger LOGGER = LoggerFactory.getLogger(ReaderWorker.class);
+
+ private String name = null;
+ private List inputQueueList = null;
+ private List inputActorIds = new ArrayList<>();
+ private DataReader dataReader = null;
+ private long handler = 0;
+ private RayActor peerActor = null;
+ private int msgCount = 0;
+ private int totalMsg = 0;
+
+ public ReaderWorker(String name) {
+ LOGGER.info("ReaderWorker constructor");
+ this.name = name;
+ }
+
+ public String getName() {
+ String management = ManagementFactory.getRuntimeMXBean().getName();
+ String pid = management.split("@")[0];
+
+ LOGGER.info("pid: {} name: {}", pid, name);
+ return name;
+ }
+
+ public String testRayCall() {
+ LOGGER.info("testRayCall called");
+ return "testRayCall";
+ }
+
+ public boolean init(List inputQueueList, RayActor peer, int msgCount) {
+
+ this.inputQueueList = inputQueueList;
+ this.peerActor = peer;
+ this.msgCount = msgCount;
+
+ LOGGER.info("ReaderWorker init");
+ LOGGER.info("java.library.path = {}", System.getProperty("java.library.path"));
+
+ for (String queue : this.inputQueueList) {
+ inputActorIds.add(this.peerActor.getId());
+ LOGGER.info("ReaderWorker actorId: {}", this.peerActor.getId());
+ }
+
+ Map conf = new HashMap<>();
+
+ conf.put(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL);
+ conf.put(Config.CHANNEL_SIZE, "100000");
+ conf.put(Config.STREAMING_JOB_NAME, "integrationTest1");
+ dataReader = new DataReader(inputQueueList, inputActorIds, conf);
+
+ // Should not GetBundle in RayCall thread
+ Thread readThread = new Thread(Ray.wrapRunnable(new Runnable() {
+ @Override
+ public void run() {
+ consume();
+ }
+ }));
+ readThread.start();
+
+ LOGGER.info("ReaderWorker init done");
+
+ return true;
+ }
+
+ public final void consume() {
+
+ int checkPointId = 1;
+ for (int i = 0; i < msgCount * inputQueueList.size(); ++i) {
+ DataMessage dataMessage = dataReader.read(100);
+
+ if (dataMessage == null) {
+ LOGGER.error("dataMessage is null");
+ i--;
+ continue;
+ }
+
+ int bufferSize = dataMessage.body().remaining();
+ int dataSize = dataMessage.body().getInt();
+
+ // check size
+ LOGGER.info("capacity {} bufferSize {} dataSize {}",
+ dataMessage.body().capacity(), bufferSize, dataSize);
+ Assert.assertEquals(bufferSize, dataSize);
+ if (dataMessage instanceof DataMessage) {
+ if (LOGGER.isInfoEnabled()) {
+ LOGGER.info("{} : {} message.", i, dataMessage.toString());
+ }
+ // check content
+ for (int j = 0; j < dataSize - 4; ++j) {
+ Assert.assertEquals(dataMessage.body().get(), (byte) j);
+ }
+ } else {
+ LOGGER.error("unknown message type");
+ Assert.fail();
+ }
+
+ totalMsg++;
+ }
+
+ LOGGER.info("ReaderWorker consume data done.");
+ }
+
+ void onQueueTransfer(long handler, byte[] buffer) {
+ }
+
+
+ public boolean done() {
+ return totalMsg == msgCount;
+ }
+
+ public int getTotalMsg() {
+ return totalMsg;
+ }
+}
+
+@RayRemote
+class WriterWorker extends Worker {
+ private static final Logger LOGGER = LoggerFactory.getLogger(WriterWorker.class);
+
+ private String name = null;
+ private List outputQueueList = null;
+ private List outputActorIds = new ArrayList<>();
+ DataWriter dataWriter = null;
+ RayActor peerActor = null;
+ int msgCount = 0;
+
+ public WriterWorker(String name) {
+ this.name = name;
+ }
+
+ public String getName() {
+ String management = ManagementFactory.getRuntimeMXBean().getName();
+ String pid = management.split("@")[0];
+
+ LOGGER.info("pid: {} name: {}", pid, name);
+ return name;
+ }
+
+ public String testCallReader(RayActor readerActor) {
+ String name = (String) Ray.call(ReaderWorker::getName, readerActor).get();
+ LOGGER.info("testCallReader: {}", name);
+ return name;
+ }
+
+ public boolean init(List outputQueueList, RayActor peer, int msgCount) {
+
+ this.outputQueueList = outputQueueList;
+ this.peerActor = peer;
+ this.msgCount = msgCount;
+
+ LOGGER.info("WriterWorker init:");
+
+ for (String queue : this.outputQueueList) {
+ outputActorIds.add(this.peerActor.getId());
+ LOGGER.info("WriterWorker actorId: {}", this.peerActor.getId());
+ }
+
+ LOGGER.info("Peer isDirectActorCall: {}", ((NativeRayActor) peer).isDirectCallActor());
+ int count = 3;
+ while (count-- != 0) {
+ Ray.call(ReaderWorker::testRayCall, peer).get();
+ }
+
+ try {
+ Thread.sleep(2 * 1000);
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ Map conf = new HashMap<>();
+
+ conf.put(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL);
+ conf.put(Config.CHANNEL_SIZE, "100000");
+ conf.put(Config.STREAMING_JOB_NAME, "integrationTest1");
+
+ dataWriter = new DataWriter(this.outputQueueList, this.outputActorIds, conf);
+ Thread writerThread = new Thread(Ray.wrapRunnable(new Runnable() {
+ @Override
+ public void run() {
+ produce();
+ }
+ }));
+ writerThread.start();
+
+ LOGGER.info("WriterWorker init done");
+ return true;
+ }
+
+ public final void produce() {
+
+ int checkPointId = 1;
+ Random random = new Random();
+ this.msgCount = 100;
+ for (int i = 0; i < this.msgCount; ++i) {
+ for (int j = 0; j < outputQueueList.size(); ++j) {
+ LOGGER.info("WriterWorker produce");
+ int dataSize = (random.nextInt(100)) + 10;
+ if (LOGGER.isInfoEnabled()) {
+ LOGGER.info("dataSize: {}", dataSize);
+ }
+ ByteBuffer bb = ByteBuffer.allocate(dataSize);
+ bb.putInt(dataSize);
+ for (int k = 0; k < dataSize - 4; ++k) {
+ bb.put((byte) k);
+ }
+
+ bb.clear();
+ ChannelID qid = ChannelID.from(outputQueueList.get(j));
+ dataWriter.write(qid, bb);
+ }
+ }
+ try {
+ Thread.sleep(20 * 1000);
+ } catch (InterruptedException e) {
+ e.printStackTrace();
+ }
+ }
+}
\ No newline at end of file
diff --git a/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/transfer/ChannelIDTest.java b/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/transfer/ChannelIDTest.java
new file mode 100644
index 000000000000..654447f95774
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/transfer/ChannelIDTest.java
@@ -0,0 +1,22 @@
+package org.ray.streaming.runtime.transfer;
+
+import static org.testng.Assert.assertEquals;
+
+
+import org.ray.streaming.runtime.util.EnvUtil;
+import org.testng.annotations.Test;
+
+public class ChannelIDTest {
+
+ static {
+ EnvUtil.loadNativeLibraries();
+ }
+
+ @Test
+ public void testIdStrToBytes() {
+ String idStr = ChannelID.genRandomIdStr();
+ assertEquals(idStr.length(), ChannelID.ID_LENGTH * 2);
+ assertEquals(ChannelID.idStrToBytes(idStr).length, ChannelID.ID_LENGTH);
+ }
+
+}
\ No newline at end of file
diff --git a/streaming/java/streaming-runtime/src/test/resources/log4j.properties b/streaming/java/streaming-runtime/src/test/resources/log4j.properties
new file mode 100644
index 000000000000..30d876aec121
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/test/resources/log4j.properties
@@ -0,0 +1,6 @@
+log4j.rootLogger=INFO, stdout
+# Direct log messages to stdout
+log4j.appender.stdout=org.apache.log4j.ConsoleAppender
+log4j.appender.stdout.Target=System.out
+log4j.appender.stdout.layout=org.apache.log4j.PatternLayout
+log4j.appender.stdout.layout.ConversionPattern=%d{yyyy-MM-dd HH:mm:ss} %-5p %c{1}:%L - %m%n
diff --git a/streaming/java/streaming-runtime/src/test/resources/ray.conf b/streaming/java/streaming-runtime/src/test/resources/ray.conf
new file mode 100644
index 000000000000..fdc897fa624e
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/test/resources/ray.conf
@@ -0,0 +1,3 @@
+ray {
+ run-mode = SINGLE_PROCESS
+}
diff --git a/streaming/java/test.sh b/streaming/java/test.sh
new file mode 100755
index 000000000000..c58f28f88aee
--- /dev/null
+++ b/streaming/java/test.sh
@@ -0,0 +1,39 @@
+#!/usr/bin/env bash
+
+# Cause the script to exit if a single command fails.
+set -e
+# Show explicitly which commands are currently running.
+set -x
+
+ROOT_DIR=$(cd "$(dirname "${BASH_SOURCE:-$0}")"; pwd)
+
+run_testng() {
+ "$@" || exit_code=$?
+ # exit_code == 2 means there are skipped tests.
+ if [ $exit_code -ne 2 ] && [ $exit_code -ne 0 ] ; then
+ exit $exit_code
+ fi
+}
+
+echo "build ray streaming"
+bazel build //streaming/java:all
+
+echo "Linting Java code with checkstyle."
+bazel test //streaming/java:all --test_tag_filters="checkstyle" --build_tests_only
+
+echo "Running streaming tests."
+run_testng java -cp $ROOT_DIR/../../bazel-bin/streaming/java/all_streaming_tests_deploy.jar\
+ org.testng.TestNG -d /tmp/ray_streaming_java_test_output $ROOT_DIR/testng.xml
+
+echo "Streaming TestNG results"
+cat /tmp/ray_streaming_java_test_output/testng-results.xml
+
+echo "Testing maven install."
+cd $ROOT_DIR/../../java
+echo "build ray maven deps"
+bazel build gen_maven_deps
+echo "maven install ray"
+mvn clean install -DskipTests
+cd $ROOT_DIR
+echo "maven install ray streaming"
+mvn clean install -DskipTests
diff --git a/java/streaming/testng.xml b/streaming/java/testng.xml
similarity index 100%
rename from java/streaming/testng.xml
rename to streaming/java/testng.xml
diff --git a/streaming/src/channel.cc b/streaming/src/channel.cc
index de7c99f8e0da..1edd99c1b433 100644
--- a/streaming/src/channel.cc
+++ b/streaming/src/channel.cc
@@ -205,15 +205,23 @@ struct MockQueueItem {
std::shared_ptr data;
};
-struct MockQueue {
+class MockQueue {
+ public:
std::unordered_map>>
message_buffer_;
std::unordered_map>>
consumed_buffer_;
+ static std::mutex mutex;
+ static MockQueue &GetMockQueue() {
+ static MockQueue mock_queue;
+ return mock_queue;
+ }
};
-static MockQueue mock_queue;
+std::mutex MockQueue::mutex;
StreamingStatus MockProducer::CreateTransferChannel() {
+ std::unique_lock lock(MockQueue::mutex);
+ MockQueue &mock_queue = MockQueue::GetMockQueue();
mock_queue.message_buffer_[channel_info.channel_id] =
std::make_shared>(500);
mock_queue.consumed_buffer_[channel_info.channel_id] =
@@ -222,12 +230,16 @@ StreamingStatus MockProducer::CreateTransferChannel() {
}
StreamingStatus MockProducer::DestroyTransferChannel() {
+ std::unique_lock lock(MockQueue::mutex);
+ MockQueue &mock_queue = MockQueue::GetMockQueue();
mock_queue.message_buffer_.erase(channel_info.channel_id);
mock_queue.consumed_buffer_.erase(channel_info.channel_id);
return StreamingStatus::OK;
}
StreamingStatus MockProducer::ProduceItemToChannel(uint8_t *data, uint32_t data_size) {
+ std::unique_lock lock(MockQueue::mutex);
+ MockQueue &mock_queue = MockQueue::GetMockQueue();
auto &ring_buffer = mock_queue.message_buffer_[channel_info.channel_id];
if (ring_buffer->Full()) {
return StreamingStatus::OutOfMemory;
@@ -244,6 +256,8 @@ StreamingStatus MockProducer::ProduceItemToChannel(uint8_t *data, uint32_t data_
StreamingStatus MockConsumer::ConsumeItemFromChannel(uint64_t &offset_id, uint8_t *&data,
uint32_t &data_size,
uint32_t timeout) {
+ std::unique_lock lock(MockQueue::mutex);
+ MockQueue &mock_queue = MockQueue::GetMockQueue();
auto &channel_id = channel_info.channel_id;
if (mock_queue.message_buffer_.find(channel_id) == mock_queue.message_buffer_.end()) {
return StreamingStatus::NoSuchItem;
@@ -262,6 +276,8 @@ StreamingStatus MockConsumer::ConsumeItemFromChannel(uint64_t &offset_id, uint8_
}
StreamingStatus MockConsumer::NotifyChannelConsumed(uint64_t offset_id) {
+ std::unique_lock lock(MockQueue::mutex);
+ MockQueue &mock_queue = MockQueue::GetMockQueue();
auto &channel_id = channel_info.channel_id;
auto &ring_buffer = mock_queue.consumed_buffer_[channel_id];
while (!ring_buffer->Empty() && ring_buffer->Front().seq_id <= offset_id) {
diff --git a/streaming/src/lib/java/org_ray_streaming_runtime_transfer_ChannelID.cc b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_ChannelID.cc
new file mode 100644
index 000000000000..364d0af9f861
--- /dev/null
+++ b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_ChannelID.cc
@@ -0,0 +1,17 @@
+#include "org_ray_streaming_runtime_transfer_ChannelID.h"
+#include "streaming_jni_common.h"
+using namespace ray::streaming;
+
+JNIEXPORT jlong JNICALL Java_org_ray_streaming_runtime_transfer_ChannelID_createNativeID(
+ JNIEnv *env, jclass cls, jlong qid_address) {
+ auto id = ray::ObjectID::FromBinary(
+ std::string(reinterpret_cast(qid_address), ray::ObjectID::Size()));
+ return reinterpret_cast(new ray::ObjectID(id));
+}
+
+JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_ChannelID_destroyNativeID(
+ JNIEnv *env, jclass cls, jlong native_id_ptr) {
+ auto id = reinterpret_cast(native_id_ptr);
+ STREAMING_CHECK(id != nullptr);
+ delete id;
+}
diff --git a/streaming/src/lib/java/org_ray_streaming_runtime_transfer_ChannelID.h b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_ChannelID.h
new file mode 100644
index 000000000000..c6353e9b1778
--- /dev/null
+++ b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_ChannelID.h
@@ -0,0 +1,31 @@
+/* DO NOT EDIT THIS FILE - it is machine generated */
+#include
+/* Header for class org_ray_streaming_runtime_transfer_ChannelID */
+
+#ifndef _Included_org_ray_streaming_runtime_transfer_ChannelID
+#define _Included_org_ray_streaming_runtime_transfer_ChannelID
+#ifdef __cplusplus
+extern "C" {
+#endif
+#undef org_ray_streaming_runtime_transfer_ChannelID_ID_LENGTH
+#define org_ray_streaming_runtime_transfer_ChannelID_ID_LENGTH 20L
+/*
+ * Class: org_ray_streaming_runtime_transfer_ChannelID
+ * Method: createNativeID
+ * Signature: (J)J
+ */
+JNIEXPORT jlong JNICALL Java_org_ray_streaming_runtime_transfer_ChannelID_createNativeID
+ (JNIEnv *, jclass, jlong);
+
+/*
+ * Class: org_ray_streaming_runtime_transfer_ChannelID
+ * Method: destroyNativeID
+ * Signature: (J)V
+ */
+JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_ChannelID_destroyNativeID
+ (JNIEnv *, jclass, jlong);
+
+#ifdef __cplusplus
+}
+#endif
+#endif
diff --git a/streaming/src/lib/java/org_ray_streaming_runtime_transfer_DataReader.cc b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_DataReader.cc
new file mode 100644
index 000000000000..651ef6b32e54
--- /dev/null
+++ b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_DataReader.cc
@@ -0,0 +1,88 @@
+#include "org_ray_streaming_runtime_transfer_DataReader.h"
+#include
+#include "data_reader.h"
+#include "runtime_context.h"
+#include "streaming_jni_common.h"
+
+using namespace ray;
+using namespace ray::streaming;
+
+JNIEXPORT jlong JNICALL
+Java_org_ray_streaming_runtime_transfer_DataReader_createDataReaderNative(
+ JNIEnv *env, jclass, jobjectArray input_channels, jobjectArray input_actor_ids,
+ jlongArray seq_id_array, jlongArray msg_id_array, jlong timer_interval,
+ jboolean isRecreate, jbyteArray config_bytes, jboolean is_mock) {
+ STREAMING_LOG(INFO) << "[JNI]: create DataReader.";
+ std::vector input_channels_ids =
+ jarray_to_object_id_vec(env, input_channels);
+ std::vector actor_ids = jarray_to_actor_id_vec(env, input_actor_ids);
+ std::vector seq_ids = LongVectorFromJLongArray(env, seq_id_array).data;
+ std::vector msg_ids = LongVectorFromJLongArray(env, msg_id_array).data;
+
+ auto ctx = std::make_shared();
+ RawDataFromJByteArray conf(env, config_bytes);
+ if (conf.data_size > 0) {
+ STREAMING_LOG(INFO) << "load config, config bytes size: " << conf.data_size;
+ ctx->SetConfig(conf.data, conf.data_size);
+ }
+ if (is_mock) {
+ ctx->MarkMockTest();
+ }
+ auto reader = new DataReader(ctx);
+ reader->Init(input_channels_ids, actor_ids, seq_ids, msg_ids, timer_interval);
+ STREAMING_LOG(INFO) << "create native DataReader succeed";
+ return reinterpret_cast(reader);
+}
+
+JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_DataReader_getBundleNative(
+ JNIEnv *env, jobject, jlong reader_ptr, jlong timeout_millis, jlong out,
+ jlong meta_addr) {
+ std::shared_ptr bundle;
+ auto reader = reinterpret_cast(reader_ptr);
+ auto status = reader->GetBundle((uint32_t)timeout_millis, bundle);
+
+ // over timeout, return empty array.
+ if (StreamingStatus::Interrupted == status) {
+ throwChannelInterruptException(env, "reader interrupted.");
+ } else if (StreamingStatus::GetBundleTimeOut == status) {
+ } else if (StreamingStatus::InitQueueFailed == status) {
+ throwRuntimeException(env, "init channel failed");
+ } else if (StreamingStatus::WaitQueueTimeOut == status) {
+ throwRuntimeException(env, "wait channel object timeout");
+ }
+
+ if (StreamingStatus::OK != status) {
+ *reinterpret_cast(out) = 0;
+ *reinterpret_cast(out + 8) = 0;
+ return;
+ }
+
+ // bundle data
+ // In streaming queue, bundle data and metadata will be different args of direct call,
+ // so we separate it here for future extensibility.
+ *reinterpret_cast(out) =
+ reinterpret_cast(bundle->data + kMessageBundleHeaderSize);
+ *reinterpret_cast(out + 8) = bundle->data_size - kMessageBundleHeaderSize;
+
+ // bundle metadata
+ auto meta = reinterpret_cast(meta_addr);
+ // bundle header written by writer
+ std::memcpy(meta, bundle->data, kMessageBundleHeaderSize);
+ // append qid
+ std::memcpy(meta + kMessageBundleHeaderSize, bundle->from.Data(), kUniqueIDSize);
+}
+
+JNIEXPORT void JNICALL
+Java_org_ray_streaming_runtime_transfer_DataReader_stopReaderNative(JNIEnv *env,
+ jobject thisObj,
+ jlong ptr) {
+ auto reader = reinterpret_cast(ptr);
+ reader->Stop();
+}
+
+JNIEXPORT void JNICALL
+Java_org_ray_streaming_runtime_transfer_DataReader_closeReaderNative(JNIEnv *env,
+ jobject thisObj,
+ jlong ptr) {
+ delete reinterpret_cast(ptr);
+}
diff --git a/streaming/src/lib/java/org_ray_streaming_runtime_transfer_DataReader.h b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_DataReader.h
new file mode 100644
index 000000000000..f9f266a3d018
--- /dev/null
+++ b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_DataReader.h
@@ -0,0 +1,45 @@
+/* DO NOT EDIT THIS FILE - it is machine generated */
+#include
+/* Header for class org_ray_streaming_runtime_transfer_DataReader */
+
+#ifndef _Included_org_ray_streaming_runtime_transfer_DataReader
+#define _Included_org_ray_streaming_runtime_transfer_DataReader
+#ifdef __cplusplus
+extern "C" {
+#endif
+/*
+ * Class: org_ray_streaming_runtime_transfer_DataReader
+ * Method: createDataReaderNative
+ * Signature: ([[B[[B[J[JJZ[BZ)J
+ */
+JNIEXPORT jlong JNICALL Java_org_ray_streaming_runtime_transfer_DataReader_createDataReaderNative
+ (JNIEnv *, jclass, jobjectArray, jobjectArray, jlongArray, jlongArray, jlong, jboolean, jbyteArray, jboolean);
+
+/*
+ * Class: org_ray_streaming_runtime_transfer_DataReader
+ * Method: getBundleNative
+ * Signature: (JJJJ)V
+ */
+JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_DataReader_getBundleNative
+ (JNIEnv *, jobject, jlong, jlong, jlong, jlong);
+
+/*
+ * Class: org_ray_streaming_runtime_transfer_DataReader
+ * Method: stopReaderNative
+ * Signature: (J)V
+ */
+JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_DataReader_stopReaderNative
+ (JNIEnv *, jobject, jlong);
+
+/*
+ * Class: org_ray_streaming_runtime_transfer_DataReader
+ * Method: closeReaderNative
+ * Signature: (J)V
+ */
+JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_DataReader_closeReaderNative
+ (JNIEnv *, jobject, jlong);
+
+#ifdef __cplusplus
+}
+#endif
+#endif
diff --git a/streaming/src/lib/java/org_ray_streaming_runtime_transfer_DataWriter.cc b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_DataWriter.cc
new file mode 100644
index 000000000000..439cd89a9974
--- /dev/null
+++ b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_DataWriter.cc
@@ -0,0 +1,82 @@
+#include "org_ray_streaming_runtime_transfer_DataWriter.h"
+#include "config/streaming_config.h"
+#include "data_writer.h"
+#include "streaming_jni_common.h"
+
+using namespace ray::streaming;
+
+JNIEXPORT jlong JNICALL
+Java_org_ray_streaming_runtime_transfer_DataWriter_createWriterNative(
+ JNIEnv *env, jclass, jobjectArray output_queue_ids, jobjectArray output_actor_ids,
+ jlongArray msg_ids, jlong channel_size, jbyteArray conf_bytes_array,
+ jboolean is_mock) {
+ STREAMING_LOG(INFO) << "[JNI]: createDataWriterNative.";
+ std::vector queue_id_vec =
+ jarray_to_object_id_vec(env, output_queue_ids);
+ for (auto id : queue_id_vec) {
+ STREAMING_LOG(INFO) << "output channel id: " << id.Hex();
+ }
+ STREAMING_LOG(INFO) << "total channel size: " << channel_size << "*"
+ << queue_id_vec.size() << "=" << queue_id_vec.size() * channel_size;
+ LongVectorFromJLongArray long_array_obj(env, msg_ids);
+ std::vector msg_ids_vec = LongVectorFromJLongArray(env, msg_ids).data;
+ std::vector queue_size_vec(long_array_obj.data.size(), channel_size);
+ std::vector remain_id_vec;
+ std::vector actor_ids = jarray_to_actor_id_vec(env, output_actor_ids);
+
+ STREAMING_LOG(INFO) << "actor_ids: " << actor_ids[0];
+
+ RawDataFromJByteArray conf(env, conf_bytes_array);
+ STREAMING_CHECK(conf.data != nullptr);
+ auto runtime_context = std::make_shared();
+ if (conf.data_size > 0) {
+ runtime_context->SetConfig(conf.data, conf.data_size);
+ }
+ if (is_mock) {
+ runtime_context->MarkMockTest();
+ }
+ auto *data_writer = new DataWriter(runtime_context);
+ auto status = data_writer->Init(queue_id_vec, actor_ids, msg_ids_vec, queue_size_vec);
+ if (status != StreamingStatus::OK) {
+ STREAMING_LOG(WARNING) << "DataWriter init failed.";
+ } else {
+ STREAMING_LOG(INFO) << "DataWriter init success";
+ }
+
+ data_writer->Run();
+ return reinterpret_cast(data_writer);
+}
+
+JNIEXPORT jlong JNICALL
+Java_org_ray_streaming_runtime_transfer_DataWriter_writeMessageNative(
+ JNIEnv *env, jobject, jlong writer_ptr, jlong qid_ptr, jlong address, jint size) {
+ auto *data_writer = reinterpret_cast(writer_ptr);
+ auto qid = *reinterpret_cast(qid_ptr);
+ auto data = reinterpret_cast(address);
+ auto data_size = static_cast(size);
+ jlong result = data_writer->WriteMessageToBufferRing(qid, data, data_size,
+ StreamingMessageType::Message);
+
+ if (result == 0) {
+ STREAMING_LOG(INFO) << "writer interrupted, return 0.";
+ throwChannelInterruptException(env, "writer interrupted.");
+ }
+ return result;
+}
+
+JNIEXPORT void JNICALL
+Java_org_ray_streaming_runtime_transfer_DataWriter_stopWriterNative(JNIEnv *env,
+ jobject thisObj,
+ jlong ptr) {
+ STREAMING_LOG(INFO) << "jni: stop writer.";
+ auto *data_writer = reinterpret_cast(ptr);
+ data_writer->Stop();
+}
+
+JNIEXPORT void JNICALL
+Java_org_ray_streaming_runtime_transfer_DataWriter_closeWriterNative(JNIEnv *env,
+ jobject thisObj,
+ jlong ptr) {
+ auto *data_writer = reinterpret_cast(ptr);
+ delete data_writer;
+}
\ No newline at end of file
diff --git a/streaming/src/lib/java/org_ray_streaming_runtime_transfer_DataWriter.h b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_DataWriter.h
new file mode 100644
index 000000000000..a6fdf533c4d5
--- /dev/null
+++ b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_DataWriter.h
@@ -0,0 +1,45 @@
+/* DO NOT EDIT THIS FILE - it is machine generated */
+#include
+/* Header for class org_ray_streaming_runtime_transfer_DataWriter */
+
+#ifndef _Included_org_ray_streaming_runtime_transfer_DataWriter
+#define _Included_org_ray_streaming_runtime_transfer_DataWriter
+#ifdef __cplusplus
+extern "C" {
+#endif
+/*
+ * Class: org_ray_streaming_runtime_transfer_DataWriter
+ * Method: createWriterNative
+ * Signature: ([[B[[B[JJ[BZ)J
+ */
+JNIEXPORT jlong JNICALL Java_org_ray_streaming_runtime_transfer_DataWriter_createWriterNative
+ (JNIEnv *, jclass, jobjectArray, jobjectArray, jlongArray, jlong, jbyteArray, jboolean);
+
+/*
+ * Class: org_ray_streaming_runtime_transfer_DataWriter
+ * Method: writeMessageNative
+ * Signature: (JJJI)J
+ */
+JNIEXPORT jlong JNICALL Java_org_ray_streaming_runtime_transfer_DataWriter_writeMessageNative
+ (JNIEnv *, jobject, jlong, jlong, jlong, jint);
+
+/*
+ * Class: org_ray_streaming_runtime_transfer_DataWriter
+ * Method: stopWriterNative
+ * Signature: (J)V
+ */
+JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_DataWriter_stopWriterNative
+ (JNIEnv *, jobject, jlong);
+
+/*
+ * Class: org_ray_streaming_runtime_transfer_DataWriter
+ * Method: closeWriterNative
+ * Signature: (J)V
+ */
+JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_DataWriter_closeWriterNative
+ (JNIEnv *, jobject, jlong);
+
+#ifdef __cplusplus
+}
+#endif
+#endif
diff --git a/streaming/src/lib/java/org_ray_streaming_runtime_transfer_TransferHandler.cc b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_TransferHandler.cc
new file mode 100644
index 000000000000..43f5e4a087f4
--- /dev/null
+++ b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_TransferHandler.cc
@@ -0,0 +1,75 @@
+#include "org_ray_streaming_runtime_transfer_TransferHandler.h"
+#include "queue/queue_client.h"
+#include "streaming_jni_common.h"
+
+using namespace ray::streaming;
+
+static std::shared_ptr JByteArrayToBuffer(JNIEnv *env,
+ jbyteArray bytes) {
+ RawDataFromJByteArray buf(env, bytes);
+ STREAMING_CHECK(buf.data != nullptr);
+
+ return std::make_shared(buf.data, buf.data_size, true);
+}
+
+JNIEXPORT jlong JNICALL
+Java_org_ray_streaming_runtime_transfer_TransferHandler_createWriterClientNative(
+ JNIEnv *env, jobject this_obj, jlong core_worker_ptr, jobject async_func,
+ jobject sync_func) {
+ auto ray_async_func = FunctionDescriptorToRayFunction(env, async_func);
+ auto ray_sync_func = FunctionDescriptorToRayFunction(env, sync_func);
+ auto *writer_client =
+ new WriterClient(reinterpret_cast(core_worker_ptr),
+ ray_async_func, ray_sync_func);
+ return reinterpret_cast(writer_client);
+}
+
+JNIEXPORT jlong JNICALL
+Java_org_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative(
+ JNIEnv *env, jobject this_obj, jlong core_worker_ptr, jobject async_func,
+ jobject sync_func) {
+ ray::RayFunction ray_async_func = FunctionDescriptorToRayFunction(env, async_func);
+ ray::RayFunction ray_sync_func = FunctionDescriptorToRayFunction(env, sync_func);
+ auto *reader_client =
+ new ReaderClient(reinterpret_cast(core_worker_ptr),
+ ray_async_func, ray_sync_func);
+ return reinterpret_cast(reader_client);
+}
+
+JNIEXPORT void JNICALL
+Java_org_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageNative(
+ JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) {
+ auto *writer_client = reinterpret_cast(ptr);
+ writer_client->OnWriterMessage(JByteArrayToBuffer(env, bytes));
+}
+
+JNIEXPORT jbyteArray JNICALL
+Java_org_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageSyncNative(
+ JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) {
+ auto *writer_client = reinterpret_cast