diff --git a/.gitignore b/.gitignore index a380b4fb3c65..612f933727f3 100644 --- a/.gitignore +++ b/.gitignore @@ -148,6 +148,14 @@ java/runtime/native_dependencies/ # streaming/python streaming/python/generated/ +streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/generated/ +streaming/build/java +.clwb +streaming/**/.settings +streaming/java/**/target +streaming/java/**/.classpath +streaming/java/**/.project +streaming/java/**/*.log # python virtual env venv diff --git a/.travis.yml b/.travis.yml index 59b06a8d9d69..e8d1e5fa0536 100644 --- a/.travis.yml +++ b/.travis.yml @@ -35,11 +35,16 @@ matrix: - ./java/test.sh - os: linux - env: BAZEL_PYTHON_VERSION=PY3 PYTHON=3.5 PYTHONWARNINGS=ignore TESTSUITE=streaming + env: + - TESTSUITE=streaming + - JDK='Oracle JDK 8' + - RAY_INSTALL_JAVA=1 + - BAZEL_PYTHON_VERSION=PY3 + - PYTHON=3.5 PYTHONWARNINGS=ignore install: - python $TRAVIS_BUILD_DIR/ci/travis/determine_tests_to_run.py - eval `python $TRAVIS_BUILD_DIR/ci/travis/determine_tests_to_run.py` - - if [ $RAY_CI_STREAMING_PYTHON_AFFECTED != "1" ]; then exit; fi + - if [[ $RAY_CI_STREAMING_PYTHON_AFFECTED != "1" && $RAY_CI_STREAMING_JAVA_AFFECTED != "1" ]]; then exit; fi - ./ci/suppress_output ./ci/travis/install-bazel.sh - ./ci/suppress_output ./ci/travis/install-dependencies.sh - export PATH="$HOME/miniconda/bin:$PATH" @@ -47,7 +52,8 @@ matrix: script: # Streaming cpp test. 
- if [ $RAY_CI_STREAMING_CPP_AFFECTED == "1" ]; then ./ci/suppress_output bash streaming/src/test/run_streaming_queue_test.sh; fi - - if [ RAY_CI_STREAMING_PYTHON_AFFECTED == "1" ]; then python -m pytest -v --durations=5 --timeout=300 python/ray/streaming/tests/; fi + - if [ $RAY_CI_STREAMING_PYTHON_AFFECTED == "1" ]; then python -m pytest -v --durations=5 --timeout=300 streaming/python/tests/; fi + - if [ $RAY_CI_STREAMING_JAVA_AFFECTED == "1" ]; then ./streaming/java/test.sh; fi - os: linux env: LINT=1 PYTHONWARNINGS=ignore diff --git a/BUILD.bazel b/BUILD.bazel index 7eb7125fc354..4b689c8a6181 100644 --- a/BUILD.bazel +++ b/BUILD.bazel @@ -968,10 +968,23 @@ cc_binary( "@bazel_tools//src/conditions:darwin": ["external/bazel_tools/tools/jdk/include/darwin"], "//conditions:default": ["external/bazel_tools/tools/jdk/include/linux"], }), + # Export ray ABI symbols, which can then be used by libstreaming_java.so. see `//:_raylet` + linkopts = select({ + "@bazel_tools//src/conditions:darwin": [ + "-Wl,-exported_symbols_list,$(location //:src/ray/ray_exported_symbols.lds)", + ], + "@bazel_tools//src/conditions:windows": [ + ], + "//conditions:default": [ + "-Wl,--version-script,$(location //:src/ray/ray_version_script.lds)", + ], + }), linkshared = 1, linkstatic = 1, deps = [ "//:core_worker_lib", + "//:src/ray/ray_exported_symbols.lds", + "//:src/ray/ray_version_script.lds", ], ) diff --git a/bazel/ray.bzl b/bazel/ray.bzl index a3ef5dd57465..84eca9aa393e 100644 --- a/bazel/ray.bzl +++ b/bazel/ray.bzl @@ -47,7 +47,7 @@ def define_java_module( ) checkstyle_test( name = "org_ray_ray_" + name + "-checkstyle", - target = "//java:org_ray_ray_" + name, + target = ":org_ray_ray_" + name, config = "//java:checkstyle.xml", suppressions = "//java:checkstyle-suppressions.xml", size = "small", @@ -63,7 +63,7 @@ def define_java_module( ) checkstyle_test( name = "org_ray_ray_" + name + "_test-checkstyle", - target = "//java:org_ray_ray_" + name + "_test", + target = ":org_ray_ray_" + 
name + "_test", config = "//java:checkstyle.xml", suppressions = "//java:checkstyle-suppressions.xml", size = "small", diff --git a/bazel/ray_deps_build_all.bzl b/bazel/ray_deps_build_all.bzl index c6f187e83514..5c47301d82cd 100644 --- a/bazel/ray_deps_build_all.bzl +++ b/bazel/ray_deps_build_all.bzl @@ -1,4 +1,5 @@ load("@com_github_ray_project_ray//java:dependencies.bzl", "gen_java_deps") +load("@com_github_ray_project_ray//streaming/java:dependencies.bzl", "gen_streaming_java_deps") load("@com_github_nelhage_rules_boost//:boost/boost.bzl", "boost_deps") load("@com_github_jupp0r_prometheus_cpp//bazel:repositories.bzl", "prometheus_cpp_repositories") load("@com_github_checkstyle_java//:repo.bzl", "checkstyle_deps") @@ -9,6 +10,7 @@ load("@rules_proto_grpc//:repositories.bzl", "rules_proto_grpc_toolchains") def ray_deps_build_all(): gen_java_deps() + gen_streaming_java_deps() checkstyle_deps() boost_deps() prometheus_cpp_repositories() diff --git a/ci/travis/bazel-format.sh b/ci/travis/bazel-format.sh index 71b96357c803..75c6eb3e5db4 100755 --- a/ci/travis/bazel-format.sh +++ b/ci/travis/bazel-format.sh @@ -45,6 +45,6 @@ done pushd $ROOT_DIR/../.. 
BAZEL_FILES="bazel/BUILD bazel/BUILD.plasma bazel/ray.bzl BUILD.bazel - streaming/BUILD.bazel WORKSPACE" + streaming/BUILD.bazel streaming/java/BUILD.bazel WORKSPACE" buildifier -mode=$RUN_TYPE -diff_command="diff -u" $BAZEL_FILES popd diff --git a/ci/travis/determine_tests_to_run.py b/ci/travis/determine_tests_to_run.py index 40789b0a4c6d..3d85547c8910 100644 --- a/ci/travis/determine_tests_to_run.py +++ b/ci/travis/determine_tests_to_run.py @@ -40,6 +40,7 @@ def list_changed_files(commit_range): RAY_CI_MACOS_WHEELS_AFFECTED = 0 RAY_CI_STREAMING_CPP_AFFECTED = 0 RAY_CI_STREAMING_PYTHON_AFFECTED = 0 + RAY_CI_STREAMING_JAVA_AFFECTED = 0 if os.environ["TRAVIS_EVENT_TYPE"] == "pull_request": @@ -76,6 +77,7 @@ def list_changed_files(commit_range): RAY_CI_STREAMING_PYTHON_AFFECTED = 1 elif changed_file.startswith("java/"): RAY_CI_JAVA_AFFECTED = 1 + RAY_CI_STREAMING_JAVA_AFFECTED = 1 elif any( changed_file.startswith(prefix) for prefix in skip_prefix_list): @@ -91,11 +93,15 @@ def list_changed_files(commit_range): RAY_CI_MACOS_WHEELS_AFFECTED = 1 RAY_CI_STREAMING_CPP_AFFECTED = 1 RAY_CI_STREAMING_PYTHON_AFFECTED = 1 + RAY_CI_STREAMING_JAVA_AFFECTED = 1 elif changed_file.startswith("streaming/src"): RAY_CI_STREAMING_CPP_AFFECTED = 1 RAY_CI_STREAMING_PYTHON_AFFECTED = 1 + RAY_CI_STREAMING_JAVA_AFFECTED = 1 elif changed_file.startswith("streaming/python"): RAY_CI_STREAMING_PYTHON_AFFECTED = 1 + elif changed_file.startswith("streaming/java"): + RAY_CI_STREAMING_JAVA_AFFECTED = 1 else: RAY_CI_TUNE_AFFECTED = 1 RAY_CI_RLLIB_AFFECTED = 1 @@ -105,6 +111,8 @@ def list_changed_files(commit_range): RAY_CI_LINUX_WHEELS_AFFECTED = 1 RAY_CI_MACOS_WHEELS_AFFECTED = 1 RAY_CI_STREAMING_CPP_AFFECTED = 1 + RAY_CI_STREAMING_PYTHON_AFFECTED = 1 + RAY_CI_STREAMING_JAVA_AFFECTED = 1 else: RAY_CI_TUNE_AFFECTED = 1 RAY_CI_RLLIB_AFFECTED = 1 @@ -114,6 +122,8 @@ def list_changed_files(commit_range): RAY_CI_LINUX_WHEELS_AFFECTED = 1 RAY_CI_MACOS_WHEELS_AFFECTED = 1 RAY_CI_STREAMING_CPP_AFFECTED = 
1 + RAY_CI_STREAMING_PYTHON_AFFECTED = 1 + RAY_CI_STREAMING_JAVA_AFFECTED = 1 # Log the modified environment variables visible in console. for output_stream in [sys.stdout, sys.stderr]: @@ -132,3 +142,5 @@ def list_changed_files(commit_range): .format(RAY_CI_STREAMING_CPP_AFFECTED)) _print("export RAY_CI_STREAMING_PYTHON_AFFECTED={}" .format(RAY_CI_STREAMING_PYTHON_AFFECTED)) + _print("export RAY_CI_STREAMING_JAVA_AFFECTED={}" + .format(RAY_CI_STREAMING_JAVA_AFFECTED)) diff --git a/java/BUILD.bazel b/java/BUILD.bazel index 35953bcc1628..77865979c8fc 100644 --- a/java/BUILD.bazel +++ b/java/BUILD.bazel @@ -5,7 +5,6 @@ exports_files([ "testng.xml", "checkstyle.xml", "checkstyle-suppressions.xml", - "streaming/testng.xml", ]) all_modules = [ @@ -13,7 +12,6 @@ all_modules = [ "runtime", "test", "tutorial", - "streaming", ] java_import( @@ -25,14 +23,11 @@ java_import( ] + [ "all_tests_deploy.jar", "all_tests_deploy-src.jar", - "streaming_tests_deploy.jar", - "streaming_tests_deploy-src.jar", ], deps = [ ":org_ray_ray_" + module for module in all_modules ] + [ ":all_tests", - ":streaming_tests", ], ) @@ -45,6 +40,7 @@ define_java_module( "@maven//:com_sun_xml_bind_jaxb_core", "@maven//:com_sun_xml_bind_jaxb_impl", ], + visibility = ["//visibility:public"] ) define_java_module( @@ -79,7 +75,9 @@ define_java_module( "@maven//:org_slf4j_slf4j_api", "@maven//:org_slf4j_slf4j_log4j12", "@maven//:redis_clients_jedis", + "@maven//:net_java_dev_jna_jna", ], + visibility = ["//visibility:public"] ) define_java_module( @@ -107,28 +105,6 @@ define_java_module( ], ) -define_java_module( - name = "streaming", - deps = [ - ":org_ray_ray_api", - ":org_ray_ray_runtime", - "@maven//:com_google_guava_guava", - "@maven//:org_slf4j_slf4j_api", - "@maven//:org_slf4j_slf4j_log4j12", - ], - define_test_lib = True, - test_deps = [ - ":org_ray_ray_api", - ":org_ray_ray_runtime", - ":org_ray_ray_streaming", - "@maven//:com_beust_jcommander", - "@maven//:com_google_guava_guava", - 
"@maven//:org_slf4j_slf4j_api", - "@maven//:org_slf4j_slf4j_log4j12", - "@maven//:org_testng_testng", - ], -) - java_binary( name = "all_tests", main_class = "org.testng.TestNG", @@ -140,16 +116,6 @@ java_binary( ], ) -java_binary( - name = "streaming_tests", - main_class = "org.testng.TestNG", - data = ["streaming/testng.xml"], - args = ["java/streaming/testng.xml"], - runtime_deps = [ - ":org_ray_ray_streaming_test", - ], -) - java_proto_compile( name = "common_java_proto", deps = ["@//:common_proto"], @@ -236,7 +202,6 @@ genrule( cp -f $(location //java:org_ray_ray_runtime_pom) $$WORK_DIR/java/runtime/pom.xml cp -f $(location //java:org_ray_ray_tutorial_pom) $$WORK_DIR/java/tutorial/pom.xml cp -f $(location //java:org_ray_ray_test_pom) $$WORK_DIR/java/test/pom.xml - cp -f $(location //java:org_ray_ray_streaming_pom) $$WORK_DIR/java/streaming/pom.xml echo $$(date) > $@ """, local = 1, diff --git a/java/dependencies.bzl b/java/dependencies.bzl index d72b2463304d..1f73b9b48cb9 100644 --- a/java/dependencies.bzl +++ b/java/dependencies.bzl @@ -18,8 +18,9 @@ def gen_java_deps(): "org.slf4j:slf4j-log4j12:1.7.25", "org.testng:testng:6.9.10", "redis.clients:jedis:2.8.0", + "net.java.dev.jna:jna:5.5.0" ], repositories = [ - "https://repo1.maven.org/maven2", + "https://repo1.maven.org/maven2/", ], ) diff --git a/java/pom.xml b/java/pom.xml index 912b803dedf7..67f5ffc7095f 100644 --- a/java/pom.xml +++ b/java/pom.xml @@ -12,7 +12,6 @@ api runtime test - streaming tutorial diff --git a/java/runtime/pom.xml b/java/runtime/pom.xml index eb6c268f8455..b4aa54508415 100644 --- a/java/runtime/pom.xml +++ b/java/runtime/pom.xml @@ -52,6 +52,11 @@ fst 2.57 + + net.java.dev.jna + jna + 5.5.0 + org.apache.commons commons-lang3 diff --git a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java index 62945b5576a0..5e035f7da05a 100644 --- a/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java +++ 
b/java/runtime/src/main/java/org/ray/runtime/RayNativeRuntime.java @@ -1,10 +1,8 @@ package org.ray.runtime; import com.google.common.base.Preconditions; -import com.google.common.base.Strings; import java.io.File; import java.io.IOException; -import java.lang.reflect.Field; import java.util.HashMap; import java.util.Map; import org.apache.commons.io.FileUtils; @@ -22,7 +20,7 @@ import org.ray.runtime.task.NativeTaskExecutor; import org.ray.runtime.task.NativeTaskSubmitter; import org.ray.runtime.task.TaskExecutor; -import org.ray.runtime.util.FileUtil; +import org.ray.runtime.util.JniUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -42,16 +40,11 @@ public final class RayNativeRuntime extends AbstractRayRuntime { static { LOGGER.debug("Loading native libraries."); - // Load native libraries. - String[] libraries = new String[]{"core_worker_library_java"}; - for (String library : libraries) { - String fileName = System.mapLibraryName(library); - try (FileUtil.TempFile libFile = FileUtil.getTempFileFromResource(fileName)) { - System.load(libFile.getFile().getAbsolutePath()); - } - LOGGER.debug("Native libraries loaded."); - } - + // Expose ray ABI symbols which may be depended by other shared + // libraries such as libstreaming_java.so. 
+ // See BUILD.bazel:libcore_worker_library_java.so + JniUtils.loadLibrary("core_worker_library_java", true); + LOGGER.debug("Native libraries loaded."); RayConfig globalRayConfig = RayConfig.create(); resetLibraryPath(globalRayConfig); @@ -65,30 +58,9 @@ public final class RayNativeRuntime extends AbstractRayRuntime { } private static void resetLibraryPath(RayConfig rayConfig) { - if (rayConfig.libraryPath.isEmpty()) { - return; - } - - String path = System.getProperty("java.library.path"); - if (Strings.isNullOrEmpty(path)) { - path = ""; - } else { - path += ":"; - } - path += String.join(":", rayConfig.libraryPath); - - // This is a hack to reset library path at runtime, - // see https://stackoverflow.com/questions/15409223/. - System.setProperty("java.library.path", path); - // Set sys_paths to null so that java.library.path will be re-evaluated next time it is needed. - final Field sysPathsField; - try { - sysPathsField = ClassLoader.class.getDeclaredField("sys_paths"); - sysPathsField.setAccessible(true); - sysPathsField.set(null, null); - } catch (NoSuchFieldException | IllegalAccessException e) { - LOGGER.error("Failed to set library path.", e); - } + String separator = System.getProperty("path.separator"); + String libraryPath = String.join(separator, rayConfig.libraryPath); + JniUtils.resetLibraryPath(libraryPath); } public RayNativeRuntime(RayConfig rayConfig, FunctionManager functionManager) { diff --git a/java/runtime/src/main/java/org/ray/runtime/runner/worker/DefaultWorker.java b/java/runtime/src/main/java/org/ray/runtime/runner/worker/DefaultWorker.java index fad4ec2aa838..296fe19355c9 100644 --- a/java/runtime/src/main/java/org/ray/runtime/runner/worker/DefaultWorker.java +++ b/java/runtime/src/main/java/org/ray/runtime/runner/worker/DefaultWorker.java @@ -17,6 +17,9 @@ public class DefaultWorker { public static void main(String[] args) { try { System.setProperty("ray.worker.mode", "WORKER"); + // Set run-mode to `CLUSTER` explicitly, to prevent 
the DefaultWorker to receive + // a wrong run-mode parameter through jvm options. + System.setProperty("ray.run-mode", "CLUSTER"); Thread.setDefaultUncaughtExceptionHandler((Thread t, Throwable e) -> { LOGGER.error("Uncaught worker exception in thread {}: {}", t, e); }); diff --git a/java/runtime/src/main/java/org/ray/runtime/util/JniUtils.java b/java/runtime/src/main/java/org/ray/runtime/util/JniUtils.java new file mode 100644 index 000000000000..ccc68867cafa --- /dev/null +++ b/java/runtime/src/main/java/org/ray/runtime/util/JniUtils.java @@ -0,0 +1,84 @@ +package org.ray.runtime.util; + +import com.google.common.base.Strings; +import com.google.common.collect.Sets; +import com.sun.jna.NativeLibrary; +import java.lang.reflect.Field; +import java.util.Set; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class JniUtils { + private static final Logger LOGGER = LoggerFactory.getLogger(JniUtils.class); + private static Set loadedLibs = Sets.newHashSet(); + + /** + * Loads the native library specified by the libraryName argument. + * The libraryName argument must not contain any platform specific + * prefix, file extension or path. + * + * @param libraryName the name of the library. + */ + public static synchronized void loadLibrary(String libraryName) { + loadLibrary(libraryName, false); + } + + /** + * Loads the native library specified by the libraryName argument. + * The libraryName argument must not contain any platform specific + * prefix, file extension or path. + * + * @param libraryName the name of the library. + * @param exportSymbols export symbols of library so that it can be used by other libs. + */ + public static synchronized void loadLibrary(String libraryName, boolean exportSymbols) { + if (!loadedLibs.contains(libraryName)) { + LOGGER.debug("Loading native library {}.", libraryName); + // Load native library. 
+ String fileName = System.mapLibraryName(libraryName); + String libPath = null; + try (FileUtil.TempFile libFile = FileUtil.getTempFileFromResource(fileName)) { + libPath = libFile.getFile().getAbsolutePath(); + if (exportSymbols) { + // Expose library symbols using RTLD_GLOBAL which may be depended by other shared + // libraries. + NativeLibrary.getInstance(libFile.getFile().getAbsolutePath()); + } + System.load(libPath); + } + LOGGER.debug("Native library loaded."); + resetLibraryPath(libPath); + loadedLibs.add(libraryName); + } + } + + /** + * This is a hack to reset library path at runtime. Please don't use it outside of ray + */ + public static synchronized void resetLibraryPath(String libPath) { + if (Strings.isNullOrEmpty(libPath)) { + return; + } + String path = System.getProperty("java.library.path"); + String separator = System.getProperty("path.separator"); + if (Strings.isNullOrEmpty(path)) { + path = ""; + } else { + path += separator; + } + path += String.join(separator, libPath); + + // This is a hack to reset library path at runtime, + // see https://stackoverflow.com/questions/15409223/. + System.setProperty("java.library.path", path); + // Set sys_paths to null so that java.library.path will be re-evaluated next time it is needed. 
+ final Field sysPathsField; + try { + sysPathsField = ClassLoader.class.getDeclaredField("sys_paths"); + sysPathsField.setAccessible(true); + sysPathsField.set(null, null); + } catch (NoSuchFieldException | IllegalAccessException e) { + LOGGER.error("Failed to set library path.", e); + } + } +} diff --git a/java/streaming/src/main/java/org/ray/streaming/api/partition/Partition.java b/java/streaming/src/main/java/org/ray/streaming/api/partition/Partition.java deleted file mode 100644 index 46c8b04f3591..000000000000 --- a/java/streaming/src/main/java/org/ray/streaming/api/partition/Partition.java +++ /dev/null @@ -1,21 +0,0 @@ -package org.ray.streaming.api.partition; - -import org.ray.streaming.api.function.Function; - -/** - * Interface of the partitioning strategy. - * @param Type of the input data. - */ -@FunctionalInterface -public interface Partition extends Function { - - /** - * Given a record and downstream tasks, determine which task(s) should receive the record. - * - * @param record The record. - * @param taskIds IDs of all downstream tasks. - * @return IDs of the downstream tasks that should receive the record. - */ - int[] partition(T record, int[] taskIds); - -} diff --git a/java/streaming/src/main/java/org/ray/streaming/api/partition/impl/BroadcastPartition.java b/java/streaming/src/main/java/org/ray/streaming/api/partition/impl/BroadcastPartition.java deleted file mode 100644 index 2e415ee7e109..000000000000 --- a/java/streaming/src/main/java/org/ray/streaming/api/partition/impl/BroadcastPartition.java +++ /dev/null @@ -1,17 +0,0 @@ -package org.ray.streaming.api.partition.impl; - -import org.ray.streaming.api.partition.Partition; - -/** - * Broadcast the record to all downstream tasks. 
- */ -public class BroadcastPartition implements Partition { - - public BroadcastPartition() { - } - - @Override - public int[] partition(T value, int[] taskIds) { - return taskIds; - } -} diff --git a/java/streaming/src/main/java/org/ray/streaming/cluster/ResourceManager.java b/java/streaming/src/main/java/org/ray/streaming/cluster/ResourceManager.java deleted file mode 100644 index 3230abce5007..000000000000 --- a/java/streaming/src/main/java/org/ray/streaming/cluster/ResourceManager.java +++ /dev/null @@ -1,20 +0,0 @@ -package org.ray.streaming.cluster; - -import java.util.ArrayList; -import java.util.List; -import org.ray.api.Ray; -import org.ray.api.RayActor; -import org.ray.streaming.core.runtime.StreamWorker; - -public class ResourceManager { - - public List> createWorker(int workerNum) { - List> workers = new ArrayList<>(); - for (int i = 0; i < workerNum; i++) { - RayActor worker = Ray.createActor(StreamWorker::new); - workers.add(worker); - } - return workers; - } - -} diff --git a/java/streaming/src/main/java/org/ray/streaming/core/processor/MasterProcessor.java b/java/streaming/src/main/java/org/ray/streaming/core/processor/MasterProcessor.java deleted file mode 100644 index 3ea0813224b8..000000000000 --- a/java/streaming/src/main/java/org/ray/streaming/core/processor/MasterProcessor.java +++ /dev/null @@ -1,101 +0,0 @@ -package org.ray.streaming.core.processor; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.stream.Collectors; -import org.ray.streaming.api.collector.Collector; -import org.ray.streaming.core.command.BatchInfo; -import org.ray.streaming.core.graph.ExecutionGraph; -import org.ray.streaming.core.graph.ExecutionNode; -import org.ray.streaming.core.graph.ExecutionNode.NodeType; -import org.ray.streaming.core.graph.ExecutionTask; -import org.ray.streaming.core.runtime.context.RuntimeContext; -import 
org.ray.streaming.message.Record; -import org.ray.streaming.operator.impl.MasterOperator; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - - -/** - * MasterProcessor is responsible for overall control logic. - */ -public class MasterProcessor extends StreamProcessor { - - private static final Logger LOGGER = LoggerFactory.getLogger(MasterProcessor.class); - - private Thread batchControllerThread; - private long maxBatch; - - public MasterProcessor(MasterOperator masterOperator) { - super(masterOperator); - } - - public void open(List collectors, RuntimeContext runtimeContext, - ExecutionGraph executionGraph) { - super.open(collectors, runtimeContext); - this.maxBatch = runtimeContext.getMaxBatch(); - startBatchController(executionGraph); - - } - - private void startBatchController(ExecutionGraph executionGraph) { - BatchController batchController = new BatchController(maxBatch, collectors); - List sinkTasks = new ArrayList<>(); - for (ExecutionNode executionNode : executionGraph.getExecutionNodeList()) { - if (executionNode.getNodeType() == NodeType.SINK) { - List nodeTasks = executionNode.getExecutionTaskList().stream() - .map(ExecutionTask::getTaskId).collect(Collectors.toList()); - sinkTasks.addAll(nodeTasks); - } - } - - batchControllerThread = new Thread(batchController, "controller-thread"); - batchControllerThread.start(); - } - - @Override - public void process(BatchInfo executionGraph) { - - } - - @Override - public void close() { - - } - - static class BatchController implements Runnable, Serializable { - - private AtomicInteger batchId; - private List collectors; - private Map sinkBatchMap; - private Integer frequency; - private long maxBatch; - - public BatchController(long maxBatch, List collectors) { - this.batchId = new AtomicInteger(0); - this.maxBatch = maxBatch; - this.collectors = collectors; - // TODO(zhenxuanpan): Use config to set. 
- this.frequency = 1000; - } - - @Override - public void run() { - while (batchId.get() < maxBatch) { - try { - Record record = new Record<>(new BatchInfo(batchId.getAndIncrement())); - for (Collector collector : collectors) { - collector.collect(record); - } - Thread.sleep(frequency); - } catch (Exception e) { - LOGGER.error(e.getMessage(), e); - } - } - } - - } -} diff --git a/java/streaming/src/main/java/org/ray/streaming/core/processor/SourceProcessor.java b/java/streaming/src/main/java/org/ray/streaming/core/processor/SourceProcessor.java deleted file mode 100644 index 759a73d01087..000000000000 --- a/java/streaming/src/main/java/org/ray/streaming/core/processor/SourceProcessor.java +++ /dev/null @@ -1,25 +0,0 @@ -package org.ray.streaming.core.processor; - -import org.ray.streaming.operator.impl.SourceOperator; - -/** - * The processor for the stream sources, containing a SourceOperator. - * - * @param The type of source data. - */ -public class SourceProcessor extends StreamProcessor> { - - public SourceProcessor(SourceOperator operator) { - super(operator); - } - - @Override - public void process(Long batchId) { - this.operator.process(batchId); - } - - @Override - public void close() { - - } -} diff --git a/java/streaming/src/main/java/org/ray/streaming/core/runtime/StreamWorker.java b/java/streaming/src/main/java/org/ray/streaming/core/runtime/StreamWorker.java deleted file mode 100644 index 292d05b5b31e..000000000000 --- a/java/streaming/src/main/java/org/ray/streaming/core/runtime/StreamWorker.java +++ /dev/null @@ -1,86 +0,0 @@ -package org.ray.streaming.core.runtime; - -import java.io.Serializable; -import java.util.ArrayList; -import java.util.List; -import org.ray.api.annotation.RayRemote; -import org.ray.streaming.api.collector.Collector; -import org.ray.streaming.core.command.BatchInfo; -import org.ray.streaming.core.graph.ExecutionEdge; -import org.ray.streaming.core.graph.ExecutionGraph; -import org.ray.streaming.core.graph.ExecutionNode; 
-import org.ray.streaming.core.graph.ExecutionNode.NodeType; -import org.ray.streaming.core.graph.ExecutionTask; -import org.ray.streaming.core.processor.MasterProcessor; -import org.ray.streaming.core.processor.StreamProcessor; -import org.ray.streaming.core.runtime.collector.RayCallCollector; -import org.ray.streaming.core.runtime.context.RayRuntimeContext; -import org.ray.streaming.core.runtime.context.RuntimeContext; -import org.ray.streaming.core.runtime.context.WorkerContext; -import org.ray.streaming.message.Message; -import org.ray.streaming.message.Record; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * The stream worker, it is a ray actor. - */ -@RayRemote -public class StreamWorker implements Serializable { - - private static final Logger LOGGER = LoggerFactory.getLogger(StreamWorker.class); - - private int taskId; - private WorkerContext workerContext; - private StreamProcessor streamProcessor; - private NodeType nodeType; - - public StreamWorker() { - } - - public Boolean init(WorkerContext workerContext) { - this.workerContext = workerContext; - this.taskId = workerContext.getTaskId(); - ExecutionGraph executionGraph = this.workerContext.getExecutionGraph(); - ExecutionTask executionTask = executionGraph.getExecutionTaskByTaskId(taskId); - ExecutionNode executionNode = executionGraph.getExecutionNodeByTaskId(taskId); - - this.nodeType = executionNode.getNodeType(); - this.streamProcessor = executionNode.getStreamProcessor(); - LOGGER.debug("Initializing StreamWorker, taskId: {}, operator: {}.", taskId, streamProcessor); - - List executionEdges = executionNode.getExecutionEdgeList(); - - List collectors = new ArrayList<>(); - for (ExecutionEdge executionEdge : executionEdges) { - collectors.add(new RayCallCollector(taskId, executionEdge, executionGraph)); - } - - RuntimeContext runtimeContext = new RayRuntimeContext(executionTask, workerContext.getConfig(), - executionNode.getParallelism()); - if (this.nodeType == NodeType.MASTER) 
{ - ((MasterProcessor) streamProcessor).open(collectors, runtimeContext, executionGraph); - } else { - this.streamProcessor.open(collectors, runtimeContext); - } - return true; - } - - public Boolean process(Message message) { - LOGGER.debug("Processing message, taskId: {}, message: {}.", taskId, message); - if (nodeType == NodeType.SOURCE) { - Record record = message.getRecord(0); - BatchInfo batchInfo = (BatchInfo) record.getValue(); - this.streamProcessor.process(batchInfo.getBatchId()); - } else { - List records = message.getRecordList(); - for (Record record : records) { - record.setBatchId(message.getBatchId()); - record.setStream(message.getStream()); - this.streamProcessor.process(record); - } - } - return true; - } - -} diff --git a/java/streaming/src/main/java/org/ray/streaming/core/runtime/collector/RayCallCollector.java b/java/streaming/src/main/java/org/ray/streaming/core/runtime/collector/RayCallCollector.java deleted file mode 100644 index e5331a44e9ba..000000000000 --- a/java/streaming/src/main/java/org/ray/streaming/core/runtime/collector/RayCallCollector.java +++ /dev/null @@ -1,58 +0,0 @@ -package org.ray.streaming.core.runtime.collector; - -import java.util.Arrays; -import java.util.Map; -import org.ray.api.Ray; -import org.ray.api.RayActor; -import org.ray.streaming.api.collector.Collector; -import org.ray.streaming.api.partition.Partition; -import org.ray.streaming.core.graph.ExecutionEdge; -import org.ray.streaming.core.graph.ExecutionGraph; -import org.ray.streaming.core.runtime.StreamWorker; -import org.ray.streaming.message.Message; -import org.ray.streaming.message.Record; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; - -/** - * The collector that emits data via Ray remote calls. 
- */ -public class RayCallCollector implements Collector { - - private static final Logger LOGGER = LoggerFactory.getLogger(RayCallCollector.class); - - private int taskId; - private String stream; - private Map> taskId2Worker; - private int[] targetTaskIds; - private Partition partition; - - public RayCallCollector(int taskId, ExecutionEdge executionEdge, ExecutionGraph executionGraph) { - this.taskId = taskId; - this.stream = executionEdge.getStream(); - int targetNodeId = executionEdge.getTargetNodeId(); - taskId2Worker = executionGraph - .getTaskId2WorkerByNodeId(targetNodeId); - targetTaskIds = Arrays.stream(taskId2Worker.keySet() - .toArray(new Integer[taskId2Worker.size()])) - .mapToInt(Integer::valueOf).toArray(); - - this.partition = executionEdge.getPartition(); - LOGGER.debug("RayCallCollector constructed, taskId:{}, add stream:{}, partition:{}.", - taskId, stream, this.partition); - } - - @Override - public void collect(Record record) { - int[] taskIds = this.partition.partition(record, targetTaskIds); - LOGGER.debug("Sending data from task {} to remote tasks {}, collector stream:{}, record:{}", - taskId, taskIds, stream, record); - Message message = new Message(taskId, record.getBatchId(), stream, record); - for (int targetTaskId : taskIds) { - RayActor streamWorker = this.taskId2Worker.get(targetTaskId); - // Use ray call to send message to downstream actor. 
- Ray.call(StreamWorker::process, streamWorker, message); - } - } - -} diff --git a/java/streaming/src/main/java/org/ray/streaming/operator/impl/MasterOperator.java b/java/streaming/src/main/java/org/ray/streaming/operator/impl/MasterOperator.java deleted file mode 100644 index 963a0d11b35b..000000000000 --- a/java/streaming/src/main/java/org/ray/streaming/operator/impl/MasterOperator.java +++ /dev/null @@ -1,17 +0,0 @@ -package org.ray.streaming.operator.impl; - -import org.ray.streaming.operator.OperatorType; -import org.ray.streaming.operator.StreamOperator; - - -public class MasterOperator extends StreamOperator { - - public MasterOperator() { - super(null); - } - - @Override - public OperatorType getOpType() { - return OperatorType.MASTER; - } -} diff --git a/java/streaming/src/main/java/org/ray/streaming/schedule/impl/JobScheduleImpl.java b/java/streaming/src/main/java/org/ray/streaming/schedule/impl/JobScheduleImpl.java deleted file mode 100644 index 45a13a107bb8..000000000000 --- a/java/streaming/src/main/java/org/ray/streaming/schedule/impl/JobScheduleImpl.java +++ /dev/null @@ -1,93 +0,0 @@ -package org.ray.streaming.schedule.impl; - -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import org.ray.api.Ray; -import org.ray.api.RayActor; -import org.ray.api.RayObject; -import org.ray.streaming.api.partition.impl.BroadcastPartition; -import org.ray.streaming.cluster.ResourceManager; -import org.ray.streaming.core.graph.ExecutionGraph; -import org.ray.streaming.core.graph.ExecutionNode; -import org.ray.streaming.core.graph.ExecutionNode.NodeType; -import org.ray.streaming.core.graph.ExecutionTask; -import org.ray.streaming.core.runtime.StreamWorker; -import org.ray.streaming.core.runtime.context.WorkerContext; -import org.ray.streaming.operator.impl.MasterOperator; -import org.ray.streaming.plan.Plan; -import org.ray.streaming.plan.PlanEdge; -import org.ray.streaming.plan.PlanVertex; -import org.ray.streaming.plan.VertexType; 
-import org.ray.streaming.schedule.IJobSchedule; -import org.ray.streaming.schedule.ITaskAssign; - - -public class JobScheduleImpl implements IJobSchedule { - - private Plan plan; - private Map jobConfig; - private ResourceManager resourceManager; - private ITaskAssign taskAssign; - - public JobScheduleImpl(Map jobConfig) { - this.resourceManager = new ResourceManager(); - this.taskAssign = new TaskAssignImpl(); - this.jobConfig = jobConfig; - } - - /** - * Schedule physical plan to execution graph, and call streaming worker to init and run. - */ - @Override - public void schedule(Plan plan) { - this.plan = plan; - addJobMaster(plan); - List> workers = this.resourceManager.createWorker(getPlanWorker()); - ExecutionGraph executionGraph = this.taskAssign.assign(this.plan, workers); - - List executionNodes = executionGraph.getExecutionNodeList(); - List> waits = new ArrayList<>(); - ExecutionTask masterTask = null; - for (ExecutionNode executionNode : executionNodes) { - List executionTasks = executionNode.getExecutionTaskList(); - for (ExecutionTask executionTask : executionTasks) { - if (executionNode.getNodeType() != NodeType.MASTER) { - Integer taskId = executionTask.getTaskId(); - RayActor streamWorker = executionTask.getWorker(); - waits.add(Ray.call(StreamWorker::init, streamWorker, - new WorkerContext(taskId, executionGraph, jobConfig))); - } else { - masterTask = executionTask; - } - } - } - Ray.wait(waits); - - Integer masterId = masterTask.getTaskId(); - RayActor masterWorker = masterTask.getWorker(); - Ray.call(StreamWorker::init, masterWorker, - new WorkerContext(masterId, executionGraph, jobConfig)).get(); - } - - private void addJobMaster(Plan plan) { - int masterVertexId = 0; - int masterParallelism = 1; - PlanVertex masterVertex = new PlanVertex(masterVertexId, masterParallelism, VertexType.MASTER, - new MasterOperator()); - plan.getPlanVertexList().add(masterVertex); - List planVertices = plan.getPlanVertexList(); - for (PlanVertex planVertex : 
planVertices) { - if (planVertex.getVertexType() == VertexType.SOURCE) { - PlanEdge planEdge = new PlanEdge(masterVertexId, planVertex.getVertexId(), - new BroadcastPartition()); - plan.getPlanEdgeList().add(planEdge); - } - } - } - - private int getPlanWorker() { - List planVertexList = plan.getPlanVertexList(); - return planVertexList.stream().map(vertex -> vertex.getParallelism()).reduce(0, Integer::sum); - } -} diff --git a/java/streaming/src/main/java/org/ray/streaming/util/ConfigKey.java b/java/streaming/src/main/java/org/ray/streaming/util/ConfigKey.java deleted file mode 100644 index 3fed75654eff..000000000000 --- a/java/streaming/src/main/java/org/ray/streaming/util/ConfigKey.java +++ /dev/null @@ -1,10 +0,0 @@ -package org.ray.streaming.util; - -public class ConfigKey { - - /** - * Maximum number of batches to run in a streaming job. - */ - public static final String STREAMING_MAX_BATCH_COUNT = "streaming.max.batch.count"; - -} diff --git a/java/streaming/src/test/java/org/ray/streaming/schedule/impl/TaskAssignImplTest.java b/java/streaming/src/test/java/org/ray/streaming/schedule/impl/TaskAssignImplTest.java deleted file mode 100644 index d3604e487ff8..000000000000 --- a/java/streaming/src/test/java/org/ray/streaming/schedule/impl/TaskAssignImplTest.java +++ /dev/null @@ -1,60 +0,0 @@ -package org.ray.streaming.schedule.impl; - -import org.ray.api.id.ActorId; -import org.ray.api.id.ObjectId; -import org.ray.runtime.actor.LocalModeRayActor; -import org.ray.streaming.api.partition.impl.RoundRobinPartition; -import org.ray.streaming.core.graph.ExecutionEdge; -import org.ray.streaming.core.graph.ExecutionGraph; -import org.ray.streaming.core.graph.ExecutionNode; -import org.ray.streaming.core.graph.ExecutionNode.NodeType; -import org.ray.streaming.core.runtime.StreamWorker; -import org.ray.streaming.plan.Plan; -import org.ray.streaming.plan.PlanBuilderTest; -import org.ray.streaming.schedule.ITaskAssign; -import java.util.ArrayList; -import java.util.List; 
-import org.ray.api.RayActor; -import org.slf4j.Logger; -import org.slf4j.LoggerFactory; -import org.testng.Assert; -import org.testng.annotations.Test; - -public class TaskAssignImplTest { - - private static final Logger LOGGER = LoggerFactory.getLogger(TaskAssignImplTest.class); - - @Test - public void testTaskAssignImpl() { - PlanBuilderTest planBuilderTest = new PlanBuilderTest(); - Plan plan = planBuilderTest.buildDataSyncPlan(); - - List> workers = new ArrayList<>(); - for(int i = 0; i < plan.getPlanVertexList().size(); i++) { - workers.add(new LocalModeRayActor(ActorId.fromRandom(), ObjectId.fromRandom())); - } - - ITaskAssign taskAssign = new TaskAssignImpl(); - ExecutionGraph executionGraph = taskAssign.assign(plan, workers); - - List executionNodeList = executionGraph.getExecutionNodeList(); - - Assert.assertEquals(executionNodeList.size(), 2); - ExecutionNode sourceNode = executionNodeList.get(0); - Assert.assertEquals(sourceNode.getNodeType(), NodeType.SOURCE); - Assert.assertEquals(sourceNode.getExecutionTaskList().size(), 1); - Assert.assertEquals(sourceNode.getExecutionEdgeList().size(), 1); - - List sourceExecutionEdges = sourceNode.getExecutionEdgeList(); - - Assert.assertEquals(sourceExecutionEdges.size(), 1); - ExecutionEdge source2Sink = sourceExecutionEdges.get(0); - - Assert.assertEquals(source2Sink.getPartition().getClass(), RoundRobinPartition.class); - - ExecutionNode sinkNode = executionNodeList.get(1); - Assert.assertEquals(sinkNode.getNodeType(), NodeType.SINK); - Assert.assertEquals(sinkNode.getExecutionTaskList().size(), 1); - Assert.assertEquals(sinkNode.getExecutionEdgeList().size(), 0); - } -} diff --git a/java/test.sh b/java/test.sh index 4612bf7e35b9..4b03cf1f5b6a 100755 --- a/java/test.sh +++ b/java/test.sh @@ -34,9 +34,6 @@ echo "Running tests under single-process mode." # bazel test //java:all_tests --jvmopt="-Dray.run-mode=SINGLE_PROCESS" --test_output="errors" || single_exit_code=$? 
run_testng java -Dray.run-mode="SINGLE_PROCESS" -cp $ROOT_DIR/../bazel-bin/java/all_tests_deploy.jar org.testng.TestNG -d /tmp/ray_java_test_output $ROOT_DIR/testng.xml -echo "Running streaming tests." -run_testng java -cp $ROOT_DIR/../bazel-bin/java/streaming_tests_deploy.jar org.testng.TestNG -d /tmp/ray_java_test_output $ROOT_DIR/streaming/testng.xml - popd pushd $ROOT_DIR diff --git a/python/setup.py b/python/setup.py index a5ca7b2ebe08..e9925f9f56c1 100644 --- a/python/setup.py +++ b/python/setup.py @@ -29,12 +29,14 @@ "ray/core/src/ray/raylet/raylet_monitor", "ray/core/src/ray/raylet/raylet", "ray/dashboard/dashboard.py", + "ray/streaming/_streaming.so", ] # These are the directories where automatically generated Python protobuf # bindings are created. generated_python_directories = [ "ray/core/generated", + "ray/streaming/generated", ] optional_ray_files = [] diff --git a/src/ray/ray_exported_symbols.lds b/src/ray/ray_exported_symbols.lds index e6bc669f00c0..392a5b12dc47 100644 --- a/src/ray/ray_exported_symbols.lds +++ b/src/ray/ray_exported_symbols.lds @@ -25,3 +25,4 @@ *PyInit* *init_raylet* *Java* +*JNI_* diff --git a/src/ray/ray_version_script.lds b/src/ray/ray_version_script.lds index 9021f7abb1ea..0be55a0b9daa 100644 --- a/src/ray/ray_version_script.lds +++ b/src/ray/ray_version_script.lds @@ -27,5 +27,6 @@ VERSION_1.0 { *PyInit*; *init_raylet*; *Java*; + *JNI_*; local: *; }; diff --git a/streaming/BUILD.bazel b/streaming/BUILD.bazel index 3a6079f4e1ec..751d56ef3781 100644 --- a/streaming/BUILD.bazel +++ b/streaming/BUILD.bazel @@ -235,6 +235,7 @@ genrule( GENERATED_DIR=$$WORK_DIR/streaming/python/generated rm -rf $$GENERATED_DIR mkdir -p $$GENERATED_DIR + touch $$GENERATED_DIR/__init__.py for f in $(locations //streaming:streaming_py_proto); do cp $$f $$GENERATED_DIR done @@ -243,3 +244,43 @@ genrule( local = 1, visibility = ["//visibility:public"], ) + +# Streaming java +genrule( + name = "copy_jni_h", + srcs = 
["@bazel_tools//tools/jdk:jni_header"], + outs = ["jni.h"], + cmd = "cp -f $< $@", +) + +genrule( + name = "copy_jni_md_h", + srcs = select({ + "@bazel_tools//src/conditions:windows": ["@bazel_tools//tools/jdk:jni_md_header-windows"], + "@bazel_tools//src/conditions:darwin": ["@bazel_tools//tools/jdk:jni_md_header-darwin"], + "//conditions:default": ["@bazel_tools//tools/jdk:jni_md_header-linux"], + }), + outs = ["jni_md.h"], + cmd = "cp -f $< $@", +) + +cc_binary( + name = "libstreaming_java.so", + srcs = glob([ + "src/lib/java/*.cc", + "src/lib/java/*.h", + ]) + [ + ":jni.h", # needed for `include "jni.h"` + ":jni_md.h", + ], + includes = [ + ".", # needed for `include <jni.h>` + "src", + ], + linkshared = 1, + linkstatic = 1, + visibility = ["//visibility:public"], + deps = [ + ":streaming_lib", + ], +) diff --git a/streaming/README.md b/streaming/README.md index d7091885fc50..cb8c5d78d935 100644 --- a/streaming/README.md +++ b/streaming/README.md @@ -2,27 +2,32 @@ 1. Build streaming java * build ray - * `sh build.sh -l java` - * `cd java && mvn clean install -Dmaven.test.skip=true` + * `bazel build //java:gen_maven_deps` + * `cd java && mvn clean install -Dmaven.test.skip=true && cd ..` * build streaming - * `cd ray/streaming/java && bazel build all_modules` + * `bazel build //streaming/java:gen_maven_deps` * `mvn clean install -Dmaven.test.skip=true` -2. Build ray will build ray streaming python. +2. Building ray python also builds ray streaming python. 3. Run examples -```bash -# c++ test -cd streaming/ && bazel test ... -sh src/test/run_streaming_queue_test.sh -cd .. + ```bash + # c++ test + cd streaming/ && bazel test ... + sh src/test/run_streaming_queue_test.sh + cd .. -# python test -cd python/ray/streaming/ -pushd examples -python simple.py --input-file toy.txt -popd -pushd tests -pytest . -popd -``` \ No newline at end of file + # python test + pushd python/ray/streaming/ + pushd examples + python simple.py --input-file toy.txt + popd + pushd tests + pytest . 
+ popd + popd + + # java test + cd streaming/java/streaming-runtime + mvn test + ``` \ No newline at end of file diff --git a/streaming/java/BUILD.bazel b/streaming/java/BUILD.bazel new file mode 100644 index 000000000000..4cc708c389eb --- /dev/null +++ b/streaming/java/BUILD.bazel @@ -0,0 +1,213 @@ +load("//bazel:ray.bzl", "define_java_module") +load("@rules_proto_grpc//java:defs.bzl", "java_proto_compile") + +exports_files([ + "testng.xml", +]) + +all_modules = [ + "streaming-api", + "streaming-runtime", +] + +java_import( + name = "all_modules", + jars = [ + "liborg_ray_ray_" + module + ".jar" + for module in all_modules + ] + [ + "liborg_ray_ray_" + module + "-src.jar" + for module in all_modules + ] + [ + "all_streaming_tests_deploy.jar", + "all_streaming_tests_deploy-src.jar", + ], + deps = [ + ":org_ray_ray_" + module + for module in all_modules + ] + [ + ":all_streaming_tests", + ], +) + +define_java_module( + name = "streaming-api", + define_test_lib = True, + test_deps = [ + "//java:org_ray_ray_api", + ":org_ray_ray_streaming-api", + "@ray_streaming_maven//:com_google_guava_guava", + "@ray_streaming_maven//:org_slf4j_slf4j_api", + "@ray_streaming_maven//:org_slf4j_slf4j_log4j12", + "@ray_streaming_maven//:org_testng_testng", + ], + visibility = ["//visibility:public"], + deps = [ + "@ray_streaming_maven//:com_google_guava_guava", + "@ray_streaming_maven//:org_slf4j_slf4j_api", + "@ray_streaming_maven//:org_slf4j_slf4j_log4j12", + ], +) + +# `//streaming:streaming_java` would be packaged under the `streaming` directory of the jar, +# but we need it at the jar root path. +# Setting resource_strip_prefix = "streaming" would put the other resource files in the wrong path, +# so we copy the libs explicitly to drop the `streaming` path prefix. 
+filegroup( + name = "java_native_deps", + srcs = [":streaming_java"], +) + +filegroup( + name = "streaming_java", + srcs = select({ + "@bazel_tools//src/conditions:darwin": [":streaming_java_darwin"], + "//conditions:default": [":streaming_java_linux"], + }), + visibility = ["//visibility:public"], +) + +genrule( + name = "streaming_java_darwin", + srcs = ["//streaming:libstreaming_java.so"], + outs = ["libstreaming_java.dylib"], + cmd = "cp $< $@", + output_to_bindir = 1, +) + +genrule( + name = "streaming_java_linux", + srcs = ["//streaming:libstreaming_java.so"], + outs = ["libstreaming_java.so"], + cmd = "cp $< $@", + output_to_bindir = 1, +) + +define_java_module( + name = "streaming-runtime", + additional_resources = [ + ":java_native_deps", + ], + additional_srcs = [ + ":all_java_proto", + ], + define_test_lib = True, + exclude_srcs = [ + "streaming-runtime/src/main/java/org/ray/streaming/runtime/generated/*.java", + ], + test_deps = [ + "//java:org_ray_ray_api", + "//java:org_ray_ray_runtime", + ":org_ray_ray_streaming-api", + ":org_ray_ray_streaming-runtime", + "@ray_streaming_maven//:com_google_guava_guava", + "@ray_streaming_maven//:org_slf4j_slf4j_api", + "@ray_streaming_maven//:org_slf4j_slf4j_log4j12", + "@ray_streaming_maven//:org_testng_testng", + ], + visibility = ["//visibility:public"], + deps = [ + ":org_ray_ray_streaming-api", + "//java:org_ray_ray_api", + "//java:org_ray_ray_runtime", + "@ray_streaming_maven//:com_github_davidmoten_flatbuffers_java", + "@ray_streaming_maven//:com_google_guava_guava", + "@ray_streaming_maven//:com_google_protobuf_protobuf_java", + "@ray_streaming_maven//:org_slf4j_slf4j_api", + "@ray_streaming_maven//:org_slf4j_slf4j_log4j12", + ], +) + +java_binary( + name = "all_streaming_tests", + args = ["streaming/java/testng.xml"], + data = ["testng.xml"], + main_class = "org.testng.TestNG", + runtime_deps = [ + ":org_ray_ray_streaming-api_test", + ":org_ray_ray_streaming-runtime", + 
":org_ray_ray_streaming-runtime_test", + "//java:org_ray_ray_runtime", + "@ray_streaming_maven//:com_beust_jcommander", + "@ray_streaming_maven//:org_testng_testng", + ], +) + +# proto buffer +java_proto_compile( + name = "streaming_java_proto", + deps = ["//streaming:streaming_proto"], +) + +filegroup( + name = "all_java_proto", + srcs = [ + ":streaming_java_proto", + ], +) + +genrule( + name = "copy_pom_file", + srcs = [ + "//streaming/java:org_ray_ray_" + module + "_pom" + for module in all_modules + ], + outs = ["copy_pom_file.out"], + cmd = """ + set -x + WORK_DIR=$$(pwd) + cp -f $(location //streaming/java:org_ray_ray_streaming-api_pom) $$WORK_DIR/streaming/java/streaming-api/pom.xml + cp -f $(location //streaming/java:org_ray_ray_streaming-runtime_pom) $$WORK_DIR/streaming/java/streaming-runtime/pom.xml + echo $$(date) > $@ + """, + local = 1, + tags = ["no-cache"], +) + +genrule( + name = "cp_java_generated", + srcs = [ + ":all_java_proto", + ":copy_pom_file", + ], + outs = ["cp_java_generated.out"], + cmd = """ + set -x + WORK_DIR=$$(pwd) + GENERATED_DIR=$$WORK_DIR/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/generated + rm -rf $$GENERATED_DIR + mkdir -p $$GENERATED_DIR + # Copy protobuf-generated files. + for f in $(locations //streaming/java:all_java_proto); do + unzip $$f -x META-INF/MANIFEST.MF -d $$WORK_DIR/streaming/java/streaming-runtime/src/main/java + done + echo $$(date) > $@ + """, + local = 1, + tags = ["no-cache"], +) + +# Generates the dependencies needed by maven. +genrule( + name = "gen_maven_deps", + srcs = [ + ":java_native_deps", + ":cp_java_generated", + ], + outs = ["gen_maven_deps.out"], + cmd = """ + set -x + WORK_DIR=$$(pwd) + # Copy native dependencies. 
+ NATIVE_DEPS_DIR=$$WORK_DIR/streaming/java/streaming-runtime/native_dependencies/ + rm -rf $$NATIVE_DEPS_DIR + mkdir -p $$NATIVE_DEPS_DIR + for f in $(locations //streaming/java:java_native_deps); do + chmod +w $$f + cp $$f $$NATIVE_DEPS_DIR + done + echo $$(date) > $@ + """, + local = 1, + tags = ["no-cache"], +) diff --git a/streaming/java/checkstyle-suppressions.xml b/streaming/java/checkstyle-suppressions.xml new file mode 100644 index 000000000000..7847b6a83c71 --- /dev/null +++ b/streaming/java/checkstyle-suppressions.xml @@ -0,0 +1,14 @@ + + + + + + + + + + + + diff --git a/streaming/java/dependencies.bzl b/streaming/java/dependencies.bzl new file mode 100644 index 000000000000..61fb7418bd08 --- /dev/null +++ b/streaming/java/dependencies.bzl @@ -0,0 +1,20 @@ +load("@rules_jvm_external//:defs.bzl", "maven_install") + +def gen_streaming_java_deps(): + maven_install( + name = "ray_streaming_maven", + artifacts = [ + "com.beust:jcommander:1.72", + "com.google.guava:guava:27.0.1-jre", + "com.github.davidmoten:flatbuffers-java:1.9.0.1", + "com.google.protobuf:protobuf-java:3.8.0", + "de.ruedigermoeller:fst:2.57", + "org.slf4j:slf4j-api:1.7.12", + "org.slf4j:slf4j-log4j12:1.7.25", + "org.apache.logging.log4j:log4j-core:2.8.2", + "org.testng:testng:6.9.10", + ], + repositories = [ + "https://repo1.maven.org/maven2/", + ], + ) diff --git a/streaming/java/pom.xml b/streaming/java/pom.xml new file mode 100644 index 000000000000..0790163a6d1a --- /dev/null +++ b/streaming/java/pom.xml @@ -0,0 +1,154 @@ + + + + + + 4.0.0 + pom + + org.ray + ray-streaming + 0.1-SNAPSHOT + ray streaming + ray streaming + + streaming-api + streaming-runtime + + + + 1.8 + UTF-8 + 0.1-SNAPSHOT + + + + + com.google.guava + guava + 27.0.1-jre + + + org.slf4j + slf4j-api + 1.7.25 + + + org.slf4j + slf4j-log4j12 + 1.7.25 + + + org.testng + testng + 6.9.10 + + + + + + + + org.apache.maven.plugins + maven-compiler-plugin + 3.6.1 + + ${java.version} + ${java.version} + 
${project.build.sourceEncoding} + -parameters + -parameters + + + + + org.apache.maven.plugins + maven-source-plugin + 3.0.1 + + + attach-sources + deploy + + jar + + + + + + + org.apache.maven.plugins + maven-dependency-plugin + 2.10 + + + org.apache.maven.plugins + maven-clean-plugin + 3.0.0 + + + + org.apache.maven.plugins + maven-assembly-plugin + 2.2 + + + + org.apache.maven.plugins + maven-javadoc-plugin + 2.10.4 + + + attach-javadocs + deploy + + jar + + + + + + maven-deploy-plugin + 2.8.2 + + + deploy + deploy + + deploy + + + + + + org.apache.maven.plugins + maven-checkstyle-plugin + 3.0.0 + + + validate + validate + + check + + + + + ../../java/checkstyle.xml + checkstyle-suppressions.xml + UTF-8 + true + true + true + warning + ${project.build.directory}/checkstyle-errors.xml + false + + + + + + + diff --git a/java/streaming/pom.xml b/streaming/java/streaming-api/pom.xml old mode 100644 new mode 100755 similarity index 68% rename from java/streaming/pom.xml rename to streaming/java/streaming-api/pom.xml index e624bd6e53ae..419636d48fc9 --- a/java/streaming/pom.xml +++ b/streaming/java/streaming-api/pom.xml @@ -1,18 +1,18 @@ - + + ray-streaming org.ray - ray-superpom 0.1-SNAPSHOT 4.0.0 - streaming - ray streaming - ray streaming + streaming-api + ray streaming api + ray streaming api jar @@ -23,16 +23,6 @@ ${project.version} - org.ray - ray-runtime - ${project.version} - - - com.beust - jcommander - 1.72 - - com.google.guava guava 27.0.1-jre diff --git a/java/streaming/pom_template.xml b/streaming/java/streaming-api/pom_template.xml similarity index 63% rename from java/streaming/pom_template.xml rename to streaming/java/streaming-api/pom_template.xml index 3551e7443e5c..53c0e9d7e54b 100644 --- a/java/streaming/pom_template.xml +++ b/streaming/java/streaming-api/pom_template.xml @@ -1,18 +1,18 @@ -{auto_gen_header} + {auto_gen_header} + ray-streaming org.ray - ray-superpom 0.1-SNAPSHOT 4.0.0 - streaming - ray streaming - ray streaming + streaming-api + 
ray streaming api + ray streaming api jar @@ -22,11 +22,6 @@ ray-api ${project.version} - - org.ray - ray-runtime - ${project.version} - -{generated_bzl_deps} + {generated_bzl_deps} diff --git a/java/streaming/src/main/java/org/ray/streaming/core/runtime/collector/CollectionCollector.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/collector/CollectionCollector.java similarity index 91% rename from java/streaming/src/main/java/org/ray/streaming/core/runtime/collector/CollectionCollector.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/api/collector/CollectionCollector.java index 03ef391d222f..536d33e05960 100644 --- a/java/streaming/src/main/java/org/ray/streaming/core/runtime/collector/CollectionCollector.java +++ b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/collector/CollectionCollector.java @@ -1,4 +1,4 @@ -package org.ray.streaming.core.runtime.collector; +package org.ray.streaming.api.collector; import java.util.List; import org.ray.streaming.api.collector.Collector; diff --git a/java/streaming/src/main/java/org/ray/streaming/api/collector/Collector.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/collector/Collector.java similarity index 100% rename from java/streaming/src/main/java/org/ray/streaming/api/collector/Collector.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/api/collector/Collector.java diff --git a/java/streaming/src/main/java/org/ray/streaming/core/runtime/context/RuntimeContext.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/context/RuntimeContext.java similarity index 81% rename from java/streaming/src/main/java/org/ray/streaming/core/runtime/context/RuntimeContext.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/api/context/RuntimeContext.java index 984987e2918f..4a6ca368eb12 100644 --- 
a/java/streaming/src/main/java/org/ray/streaming/core/runtime/context/RuntimeContext.java +++ b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/context/RuntimeContext.java @@ -1,4 +1,4 @@ -package org.ray.streaming.core.runtime.context; +package org.ray.streaming.api.context; /** * Encapsulate the runtime information of a streaming task. @@ -15,5 +15,4 @@ public interface RuntimeContext { Long getMaxBatch(); - } diff --git a/java/streaming/src/main/java/org/ray/streaming/api/context/StreamingContext.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/context/StreamingContext.java similarity index 71% rename from java/streaming/src/main/java/org/ray/streaming/api/context/StreamingContext.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/api/context/StreamingContext.java index aaffa502197d..24c526e698e5 100644 --- a/java/streaming/src/main/java/org/ray/streaming/api/context/StreamingContext.java +++ b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/context/StreamingContext.java @@ -1,17 +1,18 @@ package org.ray.streaming.api.context; +import com.google.common.base.Preconditions; import java.io.Serializable; import java.util.ArrayList; import java.util.HashMap; +import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.ServiceLoader; import java.util.concurrent.atomic.AtomicInteger; -import org.ray.api.Ray; import org.ray.streaming.api.stream.StreamSink; import org.ray.streaming.plan.Plan; import org.ray.streaming.plan.PlanBuilder; -import org.ray.streaming.schedule.IJobSchedule; -import org.ray.streaming.schedule.impl.JobScheduleImpl; +import org.ray.streaming.schedule.JobScheduler; /** * Encapsulate the context information of a streaming Job. 
@@ -32,11 +33,10 @@ public class StreamingContext implements Serializable { private StreamingContext() { this.idGenerator = new AtomicInteger(0); this.streamSinks = new ArrayList<>(); - this.jobConfig = new HashMap(); + this.jobConfig = new HashMap<>(); } public static StreamingContext buildContext() { - Ray.init(); return new StreamingContext(); } @@ -48,8 +48,12 @@ public void execute() { this.plan = planBuilder.buildPlan(); plan.printPlan(); - IJobSchedule jobSchedule = new JobScheduleImpl(jobConfig); - jobSchedule.schedule(plan); + ServiceLoader serviceLoader = ServiceLoader.load(JobScheduler.class); + Iterator iterator = serviceLoader.iterator(); + Preconditions.checkArgument(iterator.hasNext(), + "No JobScheduler implementation has been provided."); + JobScheduler jobSchedule = iterator.next(); + jobSchedule.schedule(plan, jobConfig); } public int generateId() { diff --git a/java/streaming/src/main/java/org/ray/streaming/api/function/Function.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/function/Function.java similarity index 100% rename from java/streaming/src/main/java/org/ray/streaming/api/function/Function.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/api/function/Function.java diff --git a/java/streaming/src/main/java/org/ray/streaming/api/function/impl/AggregateFunction.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/function/impl/AggregateFunction.java similarity index 100% rename from java/streaming/src/main/java/org/ray/streaming/api/function/impl/AggregateFunction.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/api/function/impl/AggregateFunction.java diff --git a/java/streaming/src/main/java/org/ray/streaming/api/function/impl/FlatMapFunction.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/function/impl/FlatMapFunction.java similarity index 100% rename from 
java/streaming/src/main/java/org/ray/streaming/api/function/impl/FlatMapFunction.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/api/function/impl/FlatMapFunction.java diff --git a/java/streaming/src/main/java/org/ray/streaming/api/function/impl/JoinFunction.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/function/impl/JoinFunction.java similarity index 100% rename from java/streaming/src/main/java/org/ray/streaming/api/function/impl/JoinFunction.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/api/function/impl/JoinFunction.java diff --git a/java/streaming/src/main/java/org/ray/streaming/api/function/impl/KeyFunction.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/function/impl/KeyFunction.java similarity index 100% rename from java/streaming/src/main/java/org/ray/streaming/api/function/impl/KeyFunction.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/api/function/impl/KeyFunction.java diff --git a/java/streaming/src/main/java/org/ray/streaming/api/function/impl/MapFunction.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/function/impl/MapFunction.java similarity index 100% rename from java/streaming/src/main/java/org/ray/streaming/api/function/impl/MapFunction.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/api/function/impl/MapFunction.java diff --git a/java/streaming/src/main/java/org/ray/streaming/api/function/impl/ProcessFunction.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/function/impl/ProcessFunction.java similarity index 100% rename from java/streaming/src/main/java/org/ray/streaming/api/function/impl/ProcessFunction.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/api/function/impl/ProcessFunction.java diff --git a/java/streaming/src/main/java/org/ray/streaming/api/function/impl/ReduceFunction.java 
b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/function/impl/ReduceFunction.java similarity index 100% rename from java/streaming/src/main/java/org/ray/streaming/api/function/impl/ReduceFunction.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/api/function/impl/ReduceFunction.java diff --git a/java/streaming/src/main/java/org/ray/streaming/api/function/impl/SinkFunction.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/function/impl/SinkFunction.java similarity index 100% rename from java/streaming/src/main/java/org/ray/streaming/api/function/impl/SinkFunction.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/api/function/impl/SinkFunction.java diff --git a/java/streaming/src/main/java/org/ray/streaming/api/function/impl/SourceFunction.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/function/impl/SourceFunction.java similarity index 85% rename from java/streaming/src/main/java/org/ray/streaming/api/function/impl/SourceFunction.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/api/function/impl/SourceFunction.java index 64a410172beb..93d2030e33b5 100644 --- a/java/streaming/src/main/java/org/ray/streaming/api/function/impl/SourceFunction.java +++ b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/function/impl/SourceFunction.java @@ -11,7 +11,7 @@ public interface SourceFunction extends Function { void init(int parallel, int index); - void fetch(long batchId, SourceContext ctx) throws Exception; + void run(SourceContext ctx) throws Exception; void close(); diff --git a/java/streaming/src/main/java/org/ray/streaming/api/function/internal/CollectionSourceFunction.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/function/internal/CollectionSourceFunction.java similarity index 81% rename from java/streaming/src/main/java/org/ray/streaming/api/function/internal/CollectionSourceFunction.java 
rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/api/function/internal/CollectionSourceFunction.java index 1ad6736f7511..48c29ff7aac7 100644 --- a/java/streaming/src/main/java/org/ray/streaming/api/function/internal/CollectionSourceFunction.java +++ b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/function/internal/CollectionSourceFunction.java @@ -1,5 +1,6 @@ package org.ray.streaming.api.function.internal; +import java.util.ArrayList; import java.util.Collection; import org.ray.streaming.api.function.impl.SourceFunction; @@ -21,10 +22,12 @@ public void init(int parallel, int index) { } @Override - public void fetch(long batchId, SourceContext ctx) throws Exception { + public void run(SourceContext ctx) throws Exception { for (T value : values) { ctx.collect(value); } + // Reset to an empty collection so the values are not emitted again on a later call. + values = new ArrayList<>(); } @Override diff --git a/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/partition/Partition.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/partition/Partition.java new file mode 100644 index 000000000000..9ea5d28ed761 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/partition/Partition.java @@ -0,0 +1,23 @@ +package org.ray.streaming.api.partition; + +import org.ray.streaming.api.function.Function; + +/** + * Interface of the partitioning strategy. + * + * @param Type of the input data. + */ +@FunctionalInterface +public interface Partition extends Function { + + /** + * Given a record and downstream partitions, determine which partition(s) should receive the + * record. + * + * @param record The record. + * @param numPartition The number of downstream partitions. + * @return IDs of the downstream partitions that should receive the record. 
+ */ + int[] partition(T record, int numPartition); + +} diff --git a/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/partition/impl/BroadcastPartition.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/partition/impl/BroadcastPartition.java new file mode 100644 index 000000000000..a08ab0d9d21a --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/partition/impl/BroadcastPartition.java @@ -0,0 +1,24 @@ +package org.ray.streaming.api.partition.impl; + +import java.util.stream.IntStream; + +import org.ray.streaming.api.partition.Partition; + +/** + * Broadcast the record to all downstream partitions. + */ +public class BroadcastPartition implements Partition { + private int[] partitions = new int[0]; + + public BroadcastPartition() { + } + + @Override + public int[] partition(T value, int numPartition) { + if (partitions.length != numPartition) { + partitions = IntStream.rangeClosed(0, numPartition - 1).toArray(); + } + return partitions; + } + +} diff --git a/java/streaming/src/main/java/org/ray/streaming/api/partition/impl/KeyPartition.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/partition/impl/KeyPartition.java similarity index 63% rename from java/streaming/src/main/java/org/ray/streaming/api/partition/impl/KeyPartition.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/api/partition/impl/KeyPartition.java index 9c86def3478b..ac4d635d45fe 100644 --- a/java/streaming/src/main/java/org/ray/streaming/api/partition/impl/KeyPartition.java +++ b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/partition/impl/KeyPartition.java @@ -10,11 +10,11 @@ * @param Type of the input record. 
*/ public class KeyPartition implements Partition> { + private int[] partitions = new int[1]; @Override - public int[] partition(KeyRecord keyRecord, int[] taskIds) { - int length = taskIds.length; - int taskId = taskIds[Math.abs(keyRecord.getKey().hashCode() % length)]; - return new int[]{taskId}; + public int[] partition(KeyRecord keyRecord, int numPartition) { + partitions[0] = Math.abs(keyRecord.getKey().hashCode() % numPartition); + return partitions; } } diff --git a/java/streaming/src/main/java/org/ray/streaming/api/partition/impl/RoundRobinPartition.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/partition/impl/RoundRobinPartition.java similarity index 67% rename from java/streaming/src/main/java/org/ray/streaming/api/partition/impl/RoundRobinPartition.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/api/partition/impl/RoundRobinPartition.java index 0c821400b183..0c8f7a68cc86 100644 --- a/java/streaming/src/main/java/org/ray/streaming/api/partition/impl/RoundRobinPartition.java +++ b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/partition/impl/RoundRobinPartition.java @@ -8,17 +8,17 @@ * @param Type of the input record. 
*/ public class RoundRobinPartition implements Partition { - private int seq; + private int[] partitions = new int[1]; public RoundRobinPartition() { this.seq = 0; } @Override - public int[] partition(T value, int[] taskIds) { - int length = taskIds.length; - int taskId = taskIds[seq++ % length]; - return new int[]{taskId}; + public int[] partition(T value, int numPartition) { + seq = (seq + 1) % numPartition; + partitions[0] = seq; + return partitions; } } diff --git a/java/streaming/src/main/java/org/ray/streaming/api/stream/DataStream.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/stream/DataStream.java similarity index 100% rename from java/streaming/src/main/java/org/ray/streaming/api/stream/DataStream.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/api/stream/DataStream.java diff --git a/java/streaming/src/main/java/org/ray/streaming/api/stream/JoinStream.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/stream/JoinStream.java similarity index 78% rename from java/streaming/src/main/java/org/ray/streaming/api/stream/JoinStream.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/api/stream/JoinStream.java index 2795feadc499..69cb8fe79933 100644 --- a/java/streaming/src/main/java/org/ray/streaming/api/stream/JoinStream.java +++ b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/stream/JoinStream.java @@ -9,9 +9,9 @@ /** * Represents a DataStream of two joined DataStream. * - * @param Lype of the data in the left stream. - * @param Lype of the data in the right stream. - * @param Lype of the data in the joined stream. + * @param Type of the data in the left stream. + * @param Type of the data in the right stream. + * @param Type of the data in the joined stream. */ public class JoinStream extends DataStream { @@ -33,10 +33,10 @@ public Where where(KeyFunction keyFunction) { /** * Where clause of the join transformation. 
* - * @param Lype of the data in the left stream. - * @param Lype of the data in the right stream. - * @param Lype of the data in the joined stream. - * @param Lype of the join key. + * @param Type of the data in the left stream. + * @param Type of the data in the right stream. + * @param Type of the data in the joined stream. + * @param Type of the join key. */ class Where implements Serializable { @@ -56,10 +56,10 @@ public Equal equalLo(KeyFunction rightKeyFunction) { /** * Equal clause of the join transformation. * - * @param Lype of the data in the left stream. - * @param Lype of the data in the right stream. - * @param Lype of the data in the joined stream. - * @param Lype of the join key. + * @param Type of the data in the left stream. + * @param Type of the data in the right stream. + * @param Type of the data in the joined stream. + * @param Type of the join key. */ class Equal implements Serializable { diff --git a/java/streaming/src/main/java/org/ray/streaming/api/stream/KeyDataStream.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/stream/KeyDataStream.java similarity index 100% rename from java/streaming/src/main/java/org/ray/streaming/api/stream/KeyDataStream.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/api/stream/KeyDataStream.java diff --git a/java/streaming/src/main/java/org/ray/streaming/api/stream/Stream.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/stream/Stream.java similarity index 100% rename from java/streaming/src/main/java/org/ray/streaming/api/stream/Stream.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/api/stream/Stream.java diff --git a/java/streaming/src/main/java/org/ray/streaming/api/stream/StreamSink.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/stream/StreamSink.java similarity index 100% rename from java/streaming/src/main/java/org/ray/streaming/api/stream/StreamSink.java rename to 
streaming/java/streaming-api/src/main/java/org/ray/streaming/api/stream/StreamSink.java diff --git a/java/streaming/src/main/java/org/ray/streaming/api/stream/StreamSource.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/stream/StreamSource.java similarity index 100% rename from java/streaming/src/main/java/org/ray/streaming/api/stream/StreamSource.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/api/stream/StreamSource.java diff --git a/java/streaming/src/main/java/org/ray/streaming/api/stream/UnionStream.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/api/stream/UnionStream.java similarity index 100% rename from java/streaming/src/main/java/org/ray/streaming/api/stream/UnionStream.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/api/stream/UnionStream.java diff --git a/java/streaming/src/main/java/org/ray/streaming/message/KeyRecord.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/message/KeyRecord.java similarity index 100% rename from java/streaming/src/main/java/org/ray/streaming/message/KeyRecord.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/message/KeyRecord.java diff --git a/java/streaming/src/main/java/org/ray/streaming/message/Message.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/message/Message.java similarity index 100% rename from java/streaming/src/main/java/org/ray/streaming/message/Message.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/message/Message.java diff --git a/java/streaming/src/main/java/org/ray/streaming/message/Record.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/message/Record.java similarity index 67% rename from java/streaming/src/main/java/org/ray/streaming/message/Record.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/message/Record.java index 5898fc63ab52..d1b0184e0256 100644 --- 
a/java/streaming/src/main/java/org/ray/streaming/message/Record.java +++ b/streaming/java/streaming-api/src/main/java/org/ray/streaming/message/Record.java @@ -4,20 +4,13 @@ public class Record implements Serializable { - protected transient String stream; - protected transient long batchId; protected T value; public Record(T value) { this.value = value; } - public Record(long batchId, T value) { - this.batchId = batchId; - this.value = value; - } - public T getValue() { return value; } @@ -34,14 +27,6 @@ public void setStream(String stream) { this.stream = stream; } - public long getBatchId() { - return batchId; - } - - public void setBatchId(long batchId) { - this.batchId = batchId; - } - @Override public String toString() { return value.toString(); diff --git a/java/streaming/src/main/java/org/ray/streaming/operator/OneInputOperator.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/OneInputOperator.java similarity index 100% rename from java/streaming/src/main/java/org/ray/streaming/operator/OneInputOperator.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/OneInputOperator.java diff --git a/java/streaming/src/main/java/org/ray/streaming/operator/Operator.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/Operator.java similarity index 84% rename from java/streaming/src/main/java/org/ray/streaming/operator/Operator.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/Operator.java index 46e8f2b25721..39542def8605 100644 --- a/java/streaming/src/main/java/org/ray/streaming/operator/Operator.java +++ b/streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/Operator.java @@ -3,7 +3,7 @@ import java.io.Serializable; import java.util.List; import org.ray.streaming.api.collector.Collector; -import org.ray.streaming.core.runtime.context.RuntimeContext; +import org.ray.streaming.api.context.RuntimeContext; public interface Operator 
extends Serializable { diff --git a/java/streaming/src/main/java/org/ray/streaming/operator/OperatorType.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/OperatorType.java similarity index 91% rename from java/streaming/src/main/java/org/ray/streaming/operator/OperatorType.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/OperatorType.java index cc8f56406cbc..840372ad1fae 100644 --- a/java/streaming/src/main/java/org/ray/streaming/operator/OperatorType.java +++ b/streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/OperatorType.java @@ -2,7 +2,6 @@ public enum OperatorType { - MASTER, SOURCE, ONE_INPUT, TWO_INPUT, diff --git a/java/streaming/src/main/java/org/ray/streaming/operator/StreamOperator.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/StreamOperator.java similarity index 94% rename from java/streaming/src/main/java/org/ray/streaming/operator/StreamOperator.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/StreamOperator.java index 5294e9115c55..77e4bb2f7523 100644 --- a/java/streaming/src/main/java/org/ray/streaming/operator/StreamOperator.java +++ b/streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/StreamOperator.java @@ -2,8 +2,8 @@ import java.util.List; import org.ray.streaming.api.collector.Collector; +import org.ray.streaming.api.context.RuntimeContext; import org.ray.streaming.api.function.Function; -import org.ray.streaming.core.runtime.context.RuntimeContext; import org.ray.streaming.message.KeyRecord; import org.ray.streaming.message.Record; diff --git a/java/streaming/src/main/java/org/ray/streaming/operator/TwoInputOperator.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/TwoInputOperator.java similarity index 100% rename from java/streaming/src/main/java/org/ray/streaming/operator/TwoInputOperator.java rename to 
streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/TwoInputOperator.java diff --git a/java/streaming/src/main/java/org/ray/streaming/operator/impl/FlatMapOperator.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/impl/FlatMapOperator.java similarity index 88% rename from java/streaming/src/main/java/org/ray/streaming/operator/impl/FlatMapOperator.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/impl/FlatMapOperator.java index 556c17077ba1..f2ae5c326be5 100644 --- a/java/streaming/src/main/java/org/ray/streaming/operator/impl/FlatMapOperator.java +++ b/streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/impl/FlatMapOperator.java @@ -1,10 +1,10 @@ package org.ray.streaming.operator.impl; import java.util.List; +import org.ray.streaming.api.collector.CollectionCollector; import org.ray.streaming.api.collector.Collector; +import org.ray.streaming.api.context.RuntimeContext; import org.ray.streaming.api.function.impl.FlatMapFunction; -import org.ray.streaming.core.runtime.collector.CollectionCollector; -import org.ray.streaming.core.runtime.context.RuntimeContext; import org.ray.streaming.message.Record; import org.ray.streaming.operator.OneInputOperator; import org.ray.streaming.operator.StreamOperator; diff --git a/java/streaming/src/main/java/org/ray/streaming/operator/impl/KeyByOperator.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/impl/KeyByOperator.java similarity index 100% rename from java/streaming/src/main/java/org/ray/streaming/operator/impl/KeyByOperator.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/impl/KeyByOperator.java diff --git a/java/streaming/src/main/java/org/ray/streaming/operator/impl/MapOperator.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/impl/MapOperator.java similarity index 100% rename from 
java/streaming/src/main/java/org/ray/streaming/operator/impl/MapOperator.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/impl/MapOperator.java diff --git a/java/streaming/src/main/java/org/ray/streaming/operator/impl/ReduceOperator.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/impl/ReduceOperator.java similarity index 95% rename from java/streaming/src/main/java/org/ray/streaming/operator/impl/ReduceOperator.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/impl/ReduceOperator.java index 6aa7ff8924de..341db84aefe2 100644 --- a/java/streaming/src/main/java/org/ray/streaming/operator/impl/ReduceOperator.java +++ b/streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/impl/ReduceOperator.java @@ -4,8 +4,8 @@ import java.util.List; import java.util.Map; import org.ray.streaming.api.collector.Collector; +import org.ray.streaming.api.context.RuntimeContext; import org.ray.streaming.api.function.impl.ReduceFunction; -import org.ray.streaming.core.runtime.context.RuntimeContext; import org.ray.streaming.message.KeyRecord; import org.ray.streaming.message.Record; import org.ray.streaming.operator.OneInputOperator; diff --git a/java/streaming/src/main/java/org/ray/streaming/operator/impl/SinkOperator.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/impl/SinkOperator.java similarity index 100% rename from java/streaming/src/main/java/org/ray/streaming/operator/impl/SinkOperator.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/impl/SinkOperator.java diff --git a/java/streaming/src/main/java/org/ray/streaming/operator/impl/SourceOperator.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/impl/SourceOperator.java similarity index 79% rename from java/streaming/src/main/java/org/ray/streaming/operator/impl/SourceOperator.java rename to 
streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/impl/SourceOperator.java index 430bb78c5895..a4b42e67b6ee 100644 --- a/java/streaming/src/main/java/org/ray/streaming/operator/impl/SourceOperator.java +++ b/streaming/java/streaming-api/src/main/java/org/ray/streaming/operator/impl/SourceOperator.java @@ -2,9 +2,9 @@ import java.util.List; import org.ray.streaming.api.collector.Collector; +import org.ray.streaming.api.context.RuntimeContext; import org.ray.streaming.api.function.impl.SourceFunction; import org.ray.streaming.api.function.impl.SourceFunction.SourceContext; -import org.ray.streaming.core.runtime.context.RuntimeContext; import org.ray.streaming.message.Record; import org.ray.streaming.operator.OperatorType; import org.ray.streaming.operator.StreamOperator; @@ -24,24 +24,20 @@ public void open(List collectorList, RuntimeContext runtimeContext) { this.function.init(runtimeContext.getParallelism(), runtimeContext.getTaskIndex()); } - public void process(Long batchId) { + public void run() { try { - this.sourceContext.setBatchId(batchId); - this.function.fetch(batchId, this.sourceContext); + this.function.run(this.sourceContext); } catch (Exception e) { throw new RuntimeException(e); } } - @Override public OperatorType getOpType() { return OperatorType.SOURCE; } class SourceContextImpl implements SourceContext { - - private long batchId; private List collectors; public SourceContextImpl(List collectors) { @@ -51,12 +47,9 @@ public SourceContextImpl(List collectors) { @Override public void collect(T t) throws Exception { for (Collector collector : collectors) { - collector.collect(new Record(batchId, t)); + collector.collect(new Record(t)); } } - private void setBatchId(long batchId) { - this.batchId = batchId; - } } } diff --git a/java/streaming/src/main/java/org/ray/streaming/plan/Plan.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/plan/Plan.java similarity index 100% rename from 
java/streaming/src/main/java/org/ray/streaming/plan/Plan.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/plan/Plan.java diff --git a/java/streaming/src/main/java/org/ray/streaming/plan/PlanBuilder.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/plan/PlanBuilder.java similarity index 100% rename from java/streaming/src/main/java/org/ray/streaming/plan/PlanBuilder.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/plan/PlanBuilder.java diff --git a/java/streaming/src/main/java/org/ray/streaming/plan/PlanEdge.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/plan/PlanEdge.java similarity index 100% rename from java/streaming/src/main/java/org/ray/streaming/plan/PlanEdge.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/plan/PlanEdge.java diff --git a/java/streaming/src/main/java/org/ray/streaming/plan/PlanVertex.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/plan/PlanVertex.java similarity index 100% rename from java/streaming/src/main/java/org/ray/streaming/plan/PlanVertex.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/plan/PlanVertex.java diff --git a/java/streaming/src/main/java/org/ray/streaming/plan/VertexType.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/plan/VertexType.java similarity index 100% rename from java/streaming/src/main/java/org/ray/streaming/plan/VertexType.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/plan/VertexType.java diff --git a/java/streaming/src/main/java/org/ray/streaming/schedule/IJobSchedule.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/schedule/JobScheduler.java similarity index 69% rename from java/streaming/src/main/java/org/ray/streaming/schedule/IJobSchedule.java rename to streaming/java/streaming-api/src/main/java/org/ray/streaming/schedule/JobScheduler.java index aa57166c7826..86539b432d0f 
100644 --- a/java/streaming/src/main/java/org/ray/streaming/schedule/IJobSchedule.java +++ b/streaming/java/streaming-api/src/main/java/org/ray/streaming/schedule/JobScheduler.java @@ -1,17 +1,19 @@ package org.ray.streaming.schedule; +import java.util.Map; + import org.ray.streaming.plan.Plan; /** * Interface of the job scheduler. */ -public interface IJobSchedule { +public interface JobScheduler { /** * Assign logical plan to physical execution graph, and schedule job to run. * * @param plan The logical plan. */ - void schedule(Plan plan); + void schedule(Plan plan, Map conf); } diff --git a/streaming/java/streaming-api/src/main/java/org/ray/streaming/util/Config.java b/streaming/java/streaming-api/src/main/java/org/ray/streaming/util/Config.java new file mode 100644 index 000000000000..cad4dc99ea91 --- /dev/null +++ b/streaming/java/streaming-api/src/main/java/org/ray/streaming/util/Config.java @@ -0,0 +1,44 @@ +package org.ray.streaming.util; + +public class Config { + + /** + * Maximum number of batches to run in a streaming job. 
+ */ + public static final String STREAMING_BATCH_MAX_COUNT = "streaming.batch.max.count"; + + /** + * Batch frequency in milliseconds. + */ + public static final String STREAMING_BATCH_FREQUENCY = "streaming.batch.frequency"; + public static final long STREAMING_BATCH_FREQUENCY_DEFAULT = 1000; + + public static final String STREAMING_JOB_NAME = "streaming.job.name"; + public static final String STREAMING_OP_NAME = "streaming.op_name"; + public static final String TASK_JOB_ID = "streaming.task_job_id"; + public static final String STREAMING_WORKER_NAME = "streaming.worker_name"; + + // channel + public static final String CHANNEL_TYPE = "channel_type"; + public static final String MEMORY_CHANNEL = "memory_channel"; + public static final String NATIVE_CHANNEL = "native_channel"; + public static final String DEFAULT_CHANNEL_TYPE = NATIVE_CHANNEL; + public static final String CHANNEL_SIZE = "channel_size"; + public static final String CHANNEL_SIZE_DEFAULT = String.valueOf((long)Math.pow(10, 8)); + public static final String IS_RECREATE = "streaming.is_recreate"; + // Return from DataReader.getBundle if only empty messages were read in this interval. + public static final String TIMER_INTERVAL_MS = "timer_interval_ms"; + public static final String READ_TIMEOUT_MS = "read_timeout_ms"; + public static final String DEFAULT_READ_TIMEOUT_MS = "10"; + + + public static final String STREAMING_RING_BUFFER_CAPACITY = "streaming.ring_buffer_capacity"; + // Write an empty message if there is no data to be written in this + // interval.
+ public static final String STREAMING_EMPTY_MESSAGE_INTERVAL = "streaming.empty_message_interval"; + + // operator type + public static final String OPERATOR_TYPE = "operator_type"; + + +} diff --git a/java/streaming/src/main/resources/log4j.properties b/streaming/java/streaming-api/src/main/resources/log4j.properties similarity index 100% rename from java/streaming/src/main/resources/log4j.properties rename to streaming/java/streaming-api/src/main/resources/log4j.properties diff --git a/java/streaming/src/main/resources/ray.conf b/streaming/java/streaming-api/src/main/resources/ray.conf similarity index 100% rename from java/streaming/src/main/resources/ray.conf rename to streaming/java/streaming-api/src/main/resources/ray.conf diff --git a/java/streaming/src/test/java/org/ray/streaming/plan/PlanBuilderTest.java b/streaming/java/streaming-api/src/test/java/org/ray/streaming/plan/PlanBuilderTest.java similarity index 100% rename from java/streaming/src/test/java/org/ray/streaming/plan/PlanBuilderTest.java rename to streaming/java/streaming-api/src/test/java/org/ray/streaming/plan/PlanBuilderTest.java diff --git a/java/streaming/src/test/resources/log4j.properties b/streaming/java/streaming-api/src/test/resources/log4j.properties similarity index 100% rename from java/streaming/src/test/resources/log4j.properties rename to streaming/java/streaming-api/src/test/resources/log4j.properties diff --git a/java/streaming/src/test/resources/ray.conf b/streaming/java/streaming-api/src/test/resources/ray.conf similarity index 100% rename from java/streaming/src/test/resources/ray.conf rename to streaming/java/streaming-api/src/test/resources/ray.conf diff --git a/streaming/java/streaming-runtime/pom.xml b/streaming/java/streaming-runtime/pom.xml new file mode 100755 index 000000000000..e908dc7351ac --- /dev/null +++ b/streaming/java/streaming-runtime/pom.xml @@ -0,0 +1,106 @@ + + + + + ray-streaming + org.ray + 0.1-SNAPSHOT + + 4.0.0 + + streaming-runtime + ray streaming 
runtime + ray streaming runtime + jar + + + + org.ray + ray-api + ${project.version} + + + org.ray + ray-runtime + ${project.version} + + + org.ray + streaming-api + ${project.version} + + + com.github.davidmoten + flatbuffers-java + 1.9.0.1 + + + com.google.guava + guava + 27.0.1-jre + + + com.google.protobuf + protobuf-java + 3.8.0 + + + org.slf4j + slf4j-api + 1.7.25 + + + org.slf4j + slf4j-log4j12 + 1.7.25 + + + org.testng + testng + 6.9.10 + + + + + + + src/main/resources + + + native_dependencies + + + + + org.apache.maven.plugins + maven-dependency-plugin + + + copy-dependencies-to-build + package + + copy-dependencies + + + ${basedir}/../../build/java + false + false + true + + + + + + + org.apache.maven.plugins + maven-jar-plugin + 2.3.1 + + ${basedir}/../../build/java + + + + + diff --git a/streaming/java/streaming-runtime/pom_template.xml b/streaming/java/streaming-runtime/pom_template.xml new file mode 100644 index 000000000000..848e9fec2263 --- /dev/null +++ b/streaming/java/streaming-runtime/pom_template.xml @@ -0,0 +1,77 @@ + + {auto_gen_header} + + + ray-streaming + org.ray + 0.1-SNAPSHOT + + 4.0.0 + + streaming-runtime + ray streaming runtime + ray streaming runtime + jar + + + + org.ray + ray-api + ${project.version} + + + org.ray + ray-runtime + ${project.version} + + + org.ray + streaming-api + ${project.version} + + {generated_bzl_deps} + + + + + + src/main/resources + + + native_dependencies + + + + + org.apache.maven.plugins + maven-dependency-plugin + + + copy-dependencies-to-build + package + + copy-dependencies + + + ${basedir}/../../build/java + false + false + true + + + + + + + org.apache.maven.plugins + maven-jar-plugin + 2.3.1 + + ${basedir}/../../build/java + + + + + diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/cluster/ResourceManager.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/cluster/ResourceManager.java new file mode 100644 index 
000000000000..0a113da76038 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/cluster/ResourceManager.java @@ -0,0 +1,24 @@ +package org.ray.streaming.runtime.cluster; + +import java.util.ArrayList; +import java.util.List; + +import org.ray.api.Ray; +import org.ray.api.RayActor; +import org.ray.streaming.runtime.worker.JobWorker; + +/** + * Resource-Manager is used to do the management of resources + */ +public class ResourceManager { + + public List> createWorkers(int workerNum) { + List> workers = new ArrayList<>(); + for (int i = 0; i < workerNum; i++) { + RayActor worker = Ray.createActor(JobWorker::new); + workers.add(worker); + } + return workers; + } + +} diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/collector/OutputCollector.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/collector/OutputCollector.java new file mode 100644 index 000000000000..64f92cc13e6f --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/collector/OutputCollector.java @@ -0,0 +1,40 @@ +package org.ray.streaming.runtime.core.collector; + +import java.nio.ByteBuffer; +import java.util.Collection; +import org.ray.runtime.util.Serializer; +import org.ray.streaming.api.collector.Collector; +import org.ray.streaming.api.partition.Partition; +import org.ray.streaming.message.Record; +import org.ray.streaming.runtime.transfer.ChannelID; +import org.ray.streaming.runtime.transfer.DataWriter; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class OutputCollector implements Collector { + private static final Logger LOGGER = LoggerFactory.getLogger(OutputCollector.class); + + private Partition partition; + private DataWriter writer; + private ChannelID[] outputQueues; + + public OutputCollector(Collection outputQueueIds, + DataWriter writer, + Partition partition) { + this.outputQueues = 
outputQueueIds.stream().map(ChannelID::from).toArray(ChannelID[]::new); + this.writer = writer; + this.partition = partition; + LOGGER.debug("OutputCollector constructed, outputQueueIds:{}, partition:{}.", + outputQueueIds, this.partition); + } + + @Override + public void collect(Record record) { + int[] partitions = this.partition.partition(record, outputQueues.length); + ByteBuffer msgBuffer = ByteBuffer.wrap(Serializer.encode(record)); + for (int partition : partitions) { + writer.write(outputQueues[partition], msgBuffer); + } + } + +} diff --git a/java/streaming/src/main/java/org/ray/streaming/core/command/BatchInfo.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/command/BatchInfo.java similarity index 86% rename from java/streaming/src/main/java/org/ray/streaming/core/command/BatchInfo.java rename to streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/command/BatchInfo.java index 359a4c1f2289..d950d4450c71 100644 --- a/java/streaming/src/main/java/org/ray/streaming/core/command/BatchInfo.java +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/command/BatchInfo.java @@ -1,4 +1,4 @@ -package org.ray.streaming.core.command; +package org.ray.streaming.runtime.core.command; import java.io.Serializable; diff --git a/java/streaming/src/main/java/org/ray/streaming/core/graph/ExecutionEdge.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/graph/ExecutionEdge.java similarity index 95% rename from java/streaming/src/main/java/org/ray/streaming/core/graph/ExecutionEdge.java rename to streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/graph/ExecutionEdge.java index 42b50ead8aca..ae9ecf841171 100644 --- a/java/streaming/src/main/java/org/ray/streaming/core/graph/ExecutionEdge.java +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/graph/ExecutionEdge.java @@ -1,6 +1,7 @@ -package 
org.ray.streaming.core.graph; +package org.ray.streaming.runtime.core.graph; import java.io.Serializable; + import org.ray.streaming.api.partition.Partition; /** diff --git a/java/streaming/src/main/java/org/ray/streaming/core/graph/ExecutionGraph.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/graph/ExecutionGraph.java similarity index 57% rename from java/streaming/src/main/java/org/ray/streaming/core/graph/ExecutionGraph.java rename to streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/graph/ExecutionGraph.java index fc8d2b29ea30..a6f8503619be 100644 --- a/java/streaming/src/main/java/org/ray/streaming/core/graph/ExecutionGraph.java +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/graph/ExecutionGraph.java @@ -1,25 +1,47 @@ -package org.ray.streaming.core.graph; +package org.ray.streaming.runtime.core.graph; import java.io.Serializable; +import java.util.ArrayList; import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; + import org.ray.api.RayActor; -import org.ray.streaming.core.runtime.StreamWorker; +import org.ray.streaming.runtime.worker.JobWorker; /** * Physical execution graph. 
*/ public class ExecutionGraph implements Serializable { - + private long buildTime; private List executionNodeList; + private List> sourceWorkers = new ArrayList<>(); + private List> sinkWorkers = new ArrayList<>(); public ExecutionGraph(List executionNodes) { this.executionNodeList = executionNodes; + for (ExecutionNode executionNode : executionNodeList) { + if (executionNode.getNodeType() == ExecutionNode.NodeType.SOURCE) { + List> actors = executionNode.getExecutionTasks().stream() + .map(ExecutionTask::getWorker).collect(Collectors.toList()); + sourceWorkers.addAll(actors); + } + if (executionNode.getNodeType() == ExecutionNode.NodeType.SINK) { + List> actors = executionNode.getExecutionTasks().stream() + .map(ExecutionTask::getWorker).collect(Collectors.toList()); + sinkWorkers.addAll(actors); + } + } + buildTime = System.currentTimeMillis(); + } + + public List> getSourceWorkers() { + return sourceWorkers; } - public void addExectionNode(ExecutionNode executionNode) { - this.executionNodeList.add(executionNode); + public List> getSinkWorkers() { + return sinkWorkers; } public List getExecutionNodeList() { @@ -28,7 +50,7 @@ public List getExecutionNodeList() { public ExecutionTask getExecutionTaskByTaskId(int taskId) { for (ExecutionNode executionNode : executionNodeList) { - for (ExecutionTask executionTask : executionNode.getExecutionTaskList()) { + for (ExecutionTask executionTask : executionNode.getExecutionTasks()) { if (executionTask.getTaskId() == taskId) { return executionTask; } @@ -48,7 +70,7 @@ public ExecutionNode getExecutionNodeByNodeId(int nodeId) { public ExecutionNode getExecutionNodeByTaskId(int taskId) { for (ExecutionNode executionNode : executionNodeList) { - for (ExecutionTask executionTask : executionNode.getExecutionTaskList()) { + for (ExecutionTask executionTask : executionNode.getExecutionTasks()) { if (executionTask.getTaskId() == taskId) { return executionNode; } @@ -57,11 +79,11 @@ public ExecutionNode 
getExecutionNodeByTaskId(int taskId) { throw new RuntimeException("Task " + taskId + " does not exist!"); } - public Map> getTaskId2WorkerByNodeId(int nodeId) { + public Map> getTaskId2WorkerByNodeId(int nodeId) { for (ExecutionNode executionNode : executionNodeList) { if (executionNode.getNodeId() == nodeId) { - Map> taskId2Worker = new HashMap<>(); - for (ExecutionTask executionTask : executionNode.getExecutionTaskList()) { + Map> taskId2Worker = new HashMap<>(); + for (ExecutionTask executionTask : executionNode.getExecutionTasks()) { taskId2Worker.put(executionTask.getTaskId(), executionTask.getWorker()); } return taskId2Worker; @@ -70,4 +92,7 @@ public Map> getTaskId2WorkerByNodeId(int nodeId) throw new RuntimeException("Node " + nodeId + " does not exist!"); } + public long getBuildTime() { + return buildTime; + } } diff --git a/java/streaming/src/main/java/org/ray/streaming/core/graph/ExecutionNode.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/graph/ExecutionNode.java similarity index 52% rename from java/streaming/src/main/java/org/ray/streaming/core/graph/ExecutionNode.java rename to streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/graph/ExecutionNode.java index 32756c49c1b4..4b120ead6094 100644 --- a/java/streaming/src/main/java/org/ray/streaming/core/graph/ExecutionNode.java +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/graph/ExecutionNode.java @@ -1,10 +1,10 @@ -package org.ray.streaming.core.graph; +package org.ray.streaming.runtime.core.graph; import java.io.Serializable; import java.util.ArrayList; import java.util.List; -import org.ray.streaming.core.processor.StreamProcessor; import org.ray.streaming.plan.VertexType; +import org.ray.streaming.runtime.core.processor.StreamProcessor; /** * A node in the physical execution graph. 
@@ -15,14 +15,16 @@ public class ExecutionNode implements Serializable { private int parallelism; private NodeType nodeType; private StreamProcessor streamProcessor; - private List executionTaskList; - private List executionEdgeList; + private List executionTasks; + private List inputsEdges; + private List outputEdges; public ExecutionNode(int nodeId, int parallelism) { this.nodeId = nodeId; this.parallelism = parallelism; - this.executionTaskList = new ArrayList<>(); - this.executionEdgeList = new ArrayList<>(); + this.executionTasks = new ArrayList<>(); + this.inputsEdges = new ArrayList<>(); + this.outputEdges = new ArrayList<>(); } public int getNodeId() { @@ -41,24 +43,32 @@ public void setParallelism(int parallelism) { this.parallelism = parallelism; } - public List getExecutionTaskList() { - return executionTaskList; + public List getExecutionTasks() { + return executionTasks; } - public void setExecutionTaskList(List executionTaskList) { - this.executionTaskList = executionTaskList; + public void setExecutionTasks(List executionTasks) { + this.executionTasks = executionTasks; } - public List getExecutionEdgeList() { - return executionEdgeList; + public List getOutputEdges() { + return outputEdges; } - public void setExecutionEdgeList(List executionEdgeList) { - this.executionEdgeList = executionEdgeList; + public void setOutputEdges(List outputEdges) { + this.outputEdges = outputEdges; } public void addExecutionEdge(ExecutionEdge executionEdge) { - this.executionEdgeList.add(executionEdge); + this.outputEdges.add(executionEdge); + } + + public void addInputEdge(ExecutionEdge executionEdge) { + this.inputsEdges.add(executionEdge); + } + + public List getInputsEdges() { + return inputsEdges; } public StreamProcessor getStreamProcessor() { @@ -75,9 +85,6 @@ public NodeType getNodeType() { public void setNodeType(VertexType vertexType) { switch (vertexType) { - case MASTER: - this.nodeType = NodeType.MASTER; - break; case SOURCE: this.nodeType = 
NodeType.SOURCE; break; @@ -89,8 +96,18 @@ public void setNodeType(VertexType vertexType) { } } + @Override + public String toString() { + final StringBuilder sb = new StringBuilder("ExecutionNode{"); + sb.append("nodeId=").append(nodeId); + sb.append(", parallelism=").append(parallelism); + sb.append(", nodeType=").append(nodeType); + sb.append(", streamProcessor=").append(streamProcessor); + sb.append('}'); + return sb.toString(); + } + public enum NodeType { - MASTER, SOURCE, PROCESS, SINK, diff --git a/java/streaming/src/main/java/org/ray/streaming/core/graph/ExecutionTask.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/graph/ExecutionTask.java similarity index 67% rename from java/streaming/src/main/java/org/ray/streaming/core/graph/ExecutionTask.java rename to streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/graph/ExecutionTask.java index 72d3eaa6fd12..7e205d51f285 100644 --- a/java/streaming/src/main/java/org/ray/streaming/core/graph/ExecutionTask.java +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/graph/ExecutionTask.java @@ -1,21 +1,22 @@ -package org.ray.streaming.core.graph; +package org.ray.streaming.runtime.core.graph; import java.io.Serializable; + import org.ray.api.RayActor; -import org.ray.streaming.core.runtime.StreamWorker; +import org.ray.streaming.runtime.worker.JobWorker; /** * ExecutionTask is minimal execution unit. - * + *

* An ExecutionNode has n ExecutionTasks if parallelism is n. */ public class ExecutionTask implements Serializable { private int taskId; private int taskIndex; - private RayActor worker; + private RayActor worker; - public ExecutionTask(int taskId, int taskIndex, RayActor worker) { + public ExecutionTask(int taskId, int taskIndex, RayActor worker) { this.taskId = taskId; this.taskIndex = taskIndex; this.worker = worker; @@ -37,11 +38,11 @@ public void setTaskIndex(int taskIndex) { this.taskIndex = taskIndex; } - public RayActor getWorker() { + public RayActor getWorker() { return worker; } - public void setWorker(RayActor worker) { + public void setWorker(RayActor worker) { this.worker = worker; } } diff --git a/java/streaming/src/main/java/org/ray/streaming/core/processor/OneInputProcessor.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/OneInputProcessor.java similarity index 93% rename from java/streaming/src/main/java/org/ray/streaming/core/processor/OneInputProcessor.java rename to streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/OneInputProcessor.java index 3d675aa1604b..7b9fbc4c1e3d 100644 --- a/java/streaming/src/main/java/org/ray/streaming/core/processor/OneInputProcessor.java +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/OneInputProcessor.java @@ -1,4 +1,4 @@ -package org.ray.streaming.core.processor; +package org.ray.streaming.runtime.core.processor; import org.ray.streaming.message.Record; import org.ray.streaming.operator.OneInputOperator; diff --git a/java/streaming/src/main/java/org/ray/streaming/core/processor/ProcessBuilder.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/ProcessBuilder.java similarity index 91% rename from java/streaming/src/main/java/org/ray/streaming/core/processor/ProcessBuilder.java rename to 
streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/ProcessBuilder.java index 2a8c1e63437b..07dfda0f8a25 100644 --- a/java/streaming/src/main/java/org/ray/streaming/core/processor/ProcessBuilder.java +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/ProcessBuilder.java @@ -1,4 +1,4 @@ -package org.ray.streaming.core.processor; +package org.ray.streaming.runtime.core.processor; import org.ray.streaming.operator.OneInputOperator; import org.ray.streaming.operator.OperatorType; @@ -18,8 +18,6 @@ public static StreamProcessor buildProcessor(StreamOperator streamOperator) { LOGGER.info("Building StreamProcessor, operator type = {}, operator = {}.", type, streamOperator.getClass().getSimpleName().toString()); switch (type) { - case MASTER: - return new MasterProcessor(null); case SOURCE: return new SourceProcessor<>((SourceOperator) streamOperator); case ONE_INPUT: diff --git a/java/streaming/src/main/java/org/ray/streaming/core/processor/Processor.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/Processor.java similarity index 72% rename from java/streaming/src/main/java/org/ray/streaming/core/processor/Processor.java rename to streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/Processor.java index 02b38f08ad81..0578bc6cabbb 100644 --- a/java/streaming/src/main/java/org/ray/streaming/core/processor/Processor.java +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/Processor.java @@ -1,9 +1,9 @@ -package org.ray.streaming.core.processor; +package org.ray.streaming.runtime.core.processor; import java.io.Serializable; import java.util.List; import org.ray.streaming.api.collector.Collector; -import org.ray.streaming.core.runtime.context.RuntimeContext; +import org.ray.streaming.api.context.RuntimeContext; public interface Processor extends Serializable { diff --git 
a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/SourceProcessor.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/SourceProcessor.java new file mode 100644 index 000000000000..1e36ae3f3872 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/SourceProcessor.java @@ -0,0 +1,30 @@ +package org.ray.streaming.runtime.core.processor; + +import org.ray.streaming.message.Record; +import org.ray.streaming.operator.impl.SourceOperator; + +/** + * The processor for the stream sources, containing a SourceOperator. + * + * @param The type of source data. + */ +public class SourceProcessor extends StreamProcessor> { + + public SourceProcessor(SourceOperator operator) { + super(operator); + } + + @Override + public void process(Record record) { + throw new UnsupportedOperationException("SourceProcessor should not process record"); + } + + public void run() { + operator.run(); + } + + @Override + public void close() { + + } +} diff --git a/java/streaming/src/main/java/org/ray/streaming/core/processor/StreamProcessor.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/StreamProcessor.java similarity index 68% rename from java/streaming/src/main/java/org/ray/streaming/core/processor/StreamProcessor.java rename to streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/StreamProcessor.java index 3dc307e01873..a2ecef2633f5 100644 --- a/java/streaming/src/main/java/org/ray/streaming/core/processor/StreamProcessor.java +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/StreamProcessor.java @@ -1,9 +1,11 @@ -package org.ray.streaming.core.processor; +package org.ray.streaming.runtime.core.processor; import java.util.List; import org.ray.streaming.api.collector.Collector; -import 
org.ray.streaming.core.runtime.context.RuntimeContext; +import org.ray.streaming.api.context.RuntimeContext; import org.ray.streaming.operator.Operator; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; /** * StreamingProcessor is a process unit for a operator. @@ -12,6 +14,7 @@ * @param

Type of the specific operator class. */ public abstract class StreamProcessor implements Processor { + private static final Logger LOGGER = LoggerFactory.getLogger(StreamProcessor.class); protected List collectors; protected RuntimeContext runtimeContext; @@ -28,6 +31,11 @@ public void open(List collectors, RuntimeContext runtimeContext) { if (operator != null) { this.operator.open(collectors, runtimeContext); } + LOGGER.info("opened {}", this); } + @Override + public String toString() { + return this.getClass().getSimpleName(); + } } diff --git a/java/streaming/src/main/java/org/ray/streaming/core/processor/TwoInputProcessor.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/TwoInputProcessor.java similarity index 68% rename from java/streaming/src/main/java/org/ray/streaming/core/processor/TwoInputProcessor.java rename to streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/TwoInputProcessor.java index 88094b0c8cbb..fbaf84a16a84 100644 --- a/java/streaming/src/main/java/org/ray/streaming/core/processor/TwoInputProcessor.java +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/core/processor/TwoInputProcessor.java @@ -1,4 +1,4 @@ -package org.ray.streaming.core.processor; +package org.ray.streaming.runtime.core.processor; import org.ray.streaming.message.Record; import org.ray.streaming.operator.TwoInputOperator; @@ -6,12 +6,10 @@ import org.slf4j.LoggerFactory; public class TwoInputProcessor extends StreamProcessor> { - private static final Logger LOGGER = LoggerFactory.getLogger(TwoInputProcessor.class); - // TODO(zhenxuanpan): Set leftStream and rightStream. 
private String leftStream; - private String rigthStream; + private String rightStream; public TwoInputProcessor(TwoInputOperator operator) { super(operator); @@ -34,4 +32,20 @@ public void process(Record record) { public void close() { this.operator.close(); } + + public String getLeftStream() { + return leftStream; + } + + public void setLeftStream(String leftStream) { + this.leftStream = leftStream; + } + + public String getRightStream() { + return rightStream; + } + + public void setRightStream(String rightStream) { + this.rightStream = rightStream; + } } diff --git a/java/streaming/src/main/java/org/ray/streaming/schedule/ITaskAssign.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/ITaskAssign.java similarity index 56% rename from java/streaming/src/main/java/org/ray/streaming/schedule/ITaskAssign.java rename to streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/ITaskAssign.java index d9c7cd507863..9fc1d0cdb6c8 100644 --- a/java/streaming/src/main/java/org/ray/streaming/schedule/ITaskAssign.java +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/ITaskAssign.java @@ -1,11 +1,11 @@ -package org.ray.streaming.schedule; +package org.ray.streaming.runtime.schedule; import java.io.Serializable; import java.util.List; import org.ray.api.RayActor; -import org.ray.streaming.core.graph.ExecutionGraph; -import org.ray.streaming.core.runtime.StreamWorker; import org.ray.streaming.plan.Plan; +import org.ray.streaming.runtime.core.graph.ExecutionGraph; +import org.ray.streaming.runtime.worker.JobWorker; /** * Interface of the task assigning strategy. @@ -15,6 +15,6 @@ public interface ITaskAssign extends Serializable { /** * Assign logical plan to physical execution graph. 
*/ - ExecutionGraph assign(Plan plan, List> workers); + ExecutionGraph assign(Plan plan, List> workers); } diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/JobSchedulerImpl.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/JobSchedulerImpl.java new file mode 100644 index 000000000000..8a30c75fa419 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/JobSchedulerImpl.java @@ -0,0 +1,65 @@ +package org.ray.streaming.runtime.schedule; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import org.ray.api.Ray; +import org.ray.api.RayActor; +import org.ray.api.RayObject; +import org.ray.streaming.plan.Plan; +import org.ray.streaming.plan.PlanVertex; +import org.ray.streaming.runtime.cluster.ResourceManager; +import org.ray.streaming.runtime.core.graph.ExecutionGraph; +import org.ray.streaming.runtime.core.graph.ExecutionNode; +import org.ray.streaming.runtime.core.graph.ExecutionTask; +import org.ray.streaming.runtime.worker.JobWorker; +import org.ray.streaming.runtime.worker.context.WorkerContext; +import org.ray.streaming.schedule.JobScheduler; + +/** + * JobSchedulerImpl schedules workers by the Plan and the resource information + * from ResourceManager. + */ +public class JobSchedulerImpl implements JobScheduler { + private Plan plan; + private Map jobConfig; + private ResourceManager resourceManager; + private ITaskAssign taskAssign; + + public JobSchedulerImpl() { + this.resourceManager = new ResourceManager(); + this.taskAssign = new TaskAssignImpl(); + } + + /** + * Schedule physical plan to execution graph, and call streaming worker to init and run. 
+ */ + @Override + public void schedule(Plan plan, Map jobConfig) { + this.jobConfig = jobConfig; + this.plan = plan; + System.setProperty("ray.raylet.config.num_workers_per_process_java", "1"); + Ray.init(); + + List> workers = this.resourceManager.createWorkers(getPlanWorker()); + ExecutionGraph executionGraph = this.taskAssign.assign(this.plan, workers); + + List executionNodes = executionGraph.getExecutionNodeList(); + List> waits = new ArrayList<>(); + for (ExecutionNode executionNode : executionNodes) { + List executionTasks = executionNode.getExecutionTasks(); + for (ExecutionTask executionTask : executionTasks) { + int taskId = executionTask.getTaskId(); + RayActor streamWorker = executionTask.getWorker(); + waits.add(Ray.call(JobWorker::init, streamWorker, + new WorkerContext(taskId, executionGraph, jobConfig))); + } + } + Ray.wait(waits); + } + + private int getPlanWorker() { + List planVertexList = plan.getPlanVertexList(); + return planVertexList.stream().map(PlanVertex::getParallelism).reduce(0, Integer::sum); + } +} diff --git a/java/streaming/src/main/java/org/ray/streaming/schedule/impl/TaskAssignImpl.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/TaskAssignImpl.java similarity index 73% rename from java/streaming/src/main/java/org/ray/streaming/schedule/impl/TaskAssignImpl.java rename to streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/TaskAssignImpl.java index be3f6ae35113..20b3e89712d6 100644 --- a/java/streaming/src/main/java/org/ray/streaming/schedule/impl/TaskAssignImpl.java +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/schedule/TaskAssignImpl.java @@ -1,4 +1,4 @@ -package org.ray.streaming.schedule.impl; +package org.ray.streaming.runtime.schedule; import java.util.ArrayList; import java.util.HashMap; @@ -6,29 +6,28 @@ import java.util.Map; import java.util.stream.Collectors; import org.ray.api.RayActor; -import 
org.ray.streaming.core.graph.ExecutionEdge; -import org.ray.streaming.core.graph.ExecutionGraph; -import org.ray.streaming.core.graph.ExecutionNode; -import org.ray.streaming.core.graph.ExecutionTask; -import org.ray.streaming.core.processor.ProcessBuilder; -import org.ray.streaming.core.processor.StreamProcessor; -import org.ray.streaming.core.runtime.StreamWorker; import org.ray.streaming.plan.Plan; import org.ray.streaming.plan.PlanEdge; import org.ray.streaming.plan.PlanVertex; -import org.ray.streaming.schedule.ITaskAssign; +import org.ray.streaming.runtime.core.graph.ExecutionEdge; +import org.ray.streaming.runtime.core.graph.ExecutionGraph; +import org.ray.streaming.runtime.core.graph.ExecutionNode; +import org.ray.streaming.runtime.core.graph.ExecutionTask; +import org.ray.streaming.runtime.core.processor.ProcessBuilder; +import org.ray.streaming.runtime.core.processor.StreamProcessor; +import org.ray.streaming.runtime.worker.JobWorker; public class TaskAssignImpl implements ITaskAssign { /** * Assign an optimized logical plan to execution graph. * - * @param plan The logical plan. + * @param plan The logical plan. * @param workers The worker actors. * @return The physical execution graph. 
*/ @Override - public ExecutionGraph assign(Plan plan, List> workers) { + public ExecutionGraph assign(Plan plan, List> workers) { List planVertices = plan.getPlanVertexList(); List planEdges = plan.getPlanEdgeList(); @@ -45,7 +44,7 @@ public ExecutionGraph assign(Plan plan, List> workers) { } StreamProcessor streamProcessor = ProcessBuilder .buildProcessor(planVertex.getStreamOperator()); - executionNode.setExecutionTaskList(vertexTasks); + executionNode.setExecutionTasks(vertexTasks); executionNode.setStreamProcessor(streamProcessor); idToExecutionNode.put(executionNode.getNodeId(), executionNode); } @@ -57,6 +56,7 @@ public ExecutionGraph assign(Plan plan, List> workers) { ExecutionEdge executionEdge = new ExecutionEdge(srcNodeId, targetNodeId, planEdge.getPartition()); idToExecutionNode.get(srcNodeId).addExecutionEdge(executionEdge); + idToExecutionNode.get(targetNodeId).addInputEdge(executionEdge); } List executionNodes = idToExecutionNode.values().stream() diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/ChannelID.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/ChannelID.java new file mode 100644 index 000000000000..1549408627e1 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/ChannelID.java @@ -0,0 +1,182 @@ +package org.ray.streaming.runtime.transfer; + +import com.google.common.base.FinalizablePhantomReference; +import com.google.common.base.FinalizableReferenceQueue; +import com.google.common.base.Preconditions; +import com.google.common.collect.Sets; +import com.google.common.io.BaseEncoding; +import java.lang.ref.Reference; +import java.nio.ByteBuffer; +import java.util.Random; +import java.util.Set; +import sun.nio.ch.DirectBuffer; + +/** + * ChannelID is used to identify a transfer channel between a upstream worker + * and downstream worker. 
+ */ +public class ChannelID { + public static final int ID_LENGTH = 20; + private static final FinalizableReferenceQueue REFERENCE_QUEUE = new FinalizableReferenceQueue(); + // This ensures that the FinalizablePhantomReference itself is not garbage-collected. + private static final Set> references = Sets.newConcurrentHashSet(); + + private final byte[] bytes; + private final String strId; + private final ByteBuffer buffer; + private final long address; + private final long nativeIdPtr; + + private ChannelID(String strId, byte[] idBytes) { + this.strId = strId; + this.bytes = idBytes; + ByteBuffer directBuffer = ByteBuffer.allocateDirect(ID_LENGTH); + directBuffer.put(bytes); + directBuffer.rewind(); + this.buffer = directBuffer; + this.address = ((DirectBuffer) (buffer)).address(); + long nativeIdPtr = 0; + nativeIdPtr = createNativeID(address); + this.nativeIdPtr = nativeIdPtr; + } + + public byte[] getBytes() { + return bytes; + } + + public ByteBuffer getBuffer() { + return buffer; + } + + public long getAddress() { + return address; + } + + public long getNativeIdPtr() { + if (nativeIdPtr == 0) { + throw new IllegalStateException("native ID not available"); + } + return nativeIdPtr; + } + + @Override + public String toString() { + return strId; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + ChannelID that = (ChannelID) o; + return strId.equals(that.strId); + } + + @Override + public int hashCode() { + return strId.hashCode(); + } + + private static native long createNativeID(long idAddress); + + private static native void destroyNativeID(long nativeIdPtr); + + /** + * @param id hex string representation of channel id + */ + public static ChannelID from(String id) { + return from(id, ChannelID.idStrToBytes(id)); + } + + /** + * @param idBytes bytes representation of channel id + */ + public static ChannelID from(byte[] idBytes) { + return 
from(idBytesToStr(idBytes), idBytes); + } + + private static ChannelID from(String strID, byte[] idBytes) { + ChannelID id = new ChannelID(strID, idBytes); + long nativeIdPtr = id.nativeIdPtr; + if (nativeIdPtr != 0) { + Reference reference = + new FinalizablePhantomReference(id, REFERENCE_QUEUE) { + @Override + public void finalizeReferent() { + destroyNativeID(nativeIdPtr); + references.remove(this); + } + }; + references.add(reference); + } + return id; + } + + /** + * @return a random channel id string + */ + public static String genRandomIdStr() { + StringBuilder sb = new StringBuilder(); + Random random = new Random(); + for (int i = 0; i < ChannelID.ID_LENGTH * 2; ++i) { + sb.append((char) (random.nextInt(6) + 'A')); + } + return sb.toString(); + } + + /** + * Generate channel name, which will be 20 character + * + * @param fromTaskId upstream task id + * @param toTaskId downstream task id + * @return channel name + */ + public static String genIdStr(int fromTaskId, int toTaskId, long ts) { + /* + | Head | Timestamp | Empty | From | To | + | 8 bytes | 4bytes | 4bytes| 2bytes| 2bytes | + */ + Preconditions.checkArgument(fromTaskId < Short.MAX_VALUE, + "fromTaskId %d is larger than %d", fromTaskId, Short.MAX_VALUE); + Preconditions.checkArgument(toTaskId < Short.MAX_VALUE, + "toTaskId %d is larger than %d", fromTaskId, Short.MAX_VALUE); + byte[] channelName = new byte[20]; + + for (int i = 11; i >= 8; i--) { + channelName[i] = (byte) (ts & 0xff); + ts >>= 8; + } + + channelName[16] = (byte) ((fromTaskId & 0xffff) >> 8); + channelName[17] = (byte) (fromTaskId & 0xff); + channelName[18] = (byte) ((toTaskId & 0xffff) >> 8); + channelName[19] = (byte) (toTaskId & 0xff); + + return ChannelID.idBytesToStr(channelName); + } + + /** + * @param id hex string representation of channel id + * @return bytes representation of channel id + */ + static byte[] idStrToBytes(String id) { + byte[] idBytes = BaseEncoding.base16().decode(id.toUpperCase()); + assert idBytes.length 
== ChannelID.ID_LENGTH; + return idBytes; + } + + /** + * @param id bytes representation of channel id + * @return hex string representation of channel id + */ + static String idBytesToStr(byte[] id) { + assert id.length == ChannelID.ID_LENGTH; + return BaseEncoding.base16().encode(id).toLowerCase(); + } + +} + diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/ChannelInitException.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/ChannelInitException.java new file mode 100644 index 000000000000..9c5206ba3308 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/ChannelInitException.java @@ -0,0 +1,24 @@ +package org.ray.streaming.runtime.transfer; + +import java.util.ArrayList; +import java.util.List; + +public class ChannelInitException extends Exception { + + private final List abnormalQueues; + + public ChannelInitException(String message, List abnormalQueues) { + super(message); + this.abnormalQueues = abnormalQueues; + } + + public List getAbnormalChannels() { + return abnormalQueues; + } + + public List getAbnormalChannelsString() { + List res = new ArrayList<>(); + abnormalQueues.forEach(ele -> res.add(ChannelID.idBytesToStr(ele))); + return res; + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/ChannelInterruptException.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/ChannelInterruptException.java new file mode 100644 index 000000000000..b922bc415516 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/ChannelInterruptException.java @@ -0,0 +1,11 @@ +package org.ray.streaming.runtime.transfer; + +public class ChannelInterruptException extends RuntimeException { + public ChannelInterruptException() { + super(); + } + + public ChannelInterruptException(String message) { + super(message); + 
} +} diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/ChannelUtils.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/ChannelUtils.java new file mode 100644 index 000000000000..893c213978be --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/ChannelUtils.java @@ -0,0 +1,40 @@ +package org.ray.streaming.runtime.transfer; + +import java.util.Map; +import org.ray.streaming.runtime.generated.Streaming; +import org.ray.streaming.util.Config; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class ChannelUtils { + private static final Logger LOGGER = LoggerFactory.getLogger(ChannelUtils.class); + + static byte[] toNativeConf(Map conf) { + Streaming.StreamingConfig.Builder builder = Streaming.StreamingConfig.newBuilder(); + if (conf.containsKey(Config.STREAMING_JOB_NAME)) { + builder.setJobName(conf.get(Config.STREAMING_JOB_NAME)); + } + if (conf.containsKey(Config.TASK_JOB_ID)) { + builder.setTaskJobId(conf.get(Config.TASK_JOB_ID)); + } + if (conf.containsKey(Config.STREAMING_WORKER_NAME)) { + builder.setWorkerName(conf.get(Config.STREAMING_WORKER_NAME)); + } + if (conf.containsKey(Config.STREAMING_OP_NAME)) { + builder.setOpName(conf.get(Config.STREAMING_OP_NAME)); + } + if (conf.containsKey(Config.STREAMING_RING_BUFFER_CAPACITY)) { + builder.setRingBufferCapacity( + Integer.parseInt(conf.get(Config.STREAMING_RING_BUFFER_CAPACITY))); + } + if (conf.containsKey(Config.STREAMING_EMPTY_MESSAGE_INTERVAL)) { + builder.setEmptyMessageInterval( + Integer.parseInt(conf.get(Config.STREAMING_EMPTY_MESSAGE_INTERVAL))); + } + Streaming.StreamingConfig streamingConf = builder.build(); + LOGGER.info("Streaming native conf {}", streamingConf.toString()); + return streamingConf.toByteArray(); + } + +} + diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/DataMessage.java 
b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/DataMessage.java new file mode 100644 index 000000000000..2fb80b09eb64 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/DataMessage.java @@ -0,0 +1,54 @@ +package org.ray.streaming.runtime.transfer; + +import java.nio.ByteBuffer; + +/** + * DataMessage represents data between upstream and downstream operator + */ +public class DataMessage implements Message { + private final ByteBuffer body; + private final long msgId; + private final long timestamp; + private final String channelId; + + public DataMessage(ByteBuffer body, long timestamp, long msgId, String channelId) { + this.body = body; + this.timestamp = timestamp; + this.msgId = msgId; + this.channelId = channelId; + } + + @Override + public ByteBuffer body() { + return body; + } + + @Override + public long timestamp() { + return timestamp; + } + + /** + * @return message id + */ + public long msgId() { + return msgId; + } + + /** + * @return string id of channel where data is coming from + */ + public String channelId() { + return channelId; + } + + @Override + public String toString() { + return "DataMessage{" + + "body=" + body + + ", msgId=" + msgId + + ", timestamp=" + timestamp + + ", channelId='" + channelId + '\'' + + '}'; + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/DataReader.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/DataReader.java new file mode 100644 index 000000000000..d1ce9327d54f --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/DataReader.java @@ -0,0 +1,258 @@ +package org.ray.streaming.runtime.transfer; + +import com.google.common.base.Preconditions; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import 
java.util.Queue; +import org.ray.api.id.ActorId; +import org.ray.streaming.runtime.util.Platform; +import org.ray.streaming.util.Config; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * DataReader is wrapper of streaming c++ DataReader, which read data + * from channels of upstream workers + */ +public class DataReader { + private static final Logger LOGGER = LoggerFactory.getLogger(DataReader.class); + + private long nativeReaderPtr; + private Queue buf = new LinkedList<>(); + + public DataReader(List inputChannels, + List fromActors, + Map conf) { + Preconditions.checkArgument(inputChannels.size() > 0); + Preconditions.checkArgument(inputChannels.size() == fromActors.size()); + byte[][] inputChannelsBytes = inputChannels.stream() + .map(ChannelID::idStrToBytes).toArray(byte[][]::new); + byte[][] fromActorsBytes = fromActors.stream() + .map(ActorId::getBytes).toArray(byte[][]::new); + long[] seqIds = new long[inputChannels.size()]; + long[] msgIds = new long[inputChannels.size()]; + for (int i = 0; i < inputChannels.size(); i++) { + seqIds[i] = 0; + msgIds[i] = 0; + } + long timerInterval = Long.parseLong( + conf.getOrDefault(Config.TIMER_INTERVAL_MS, "-1")); + String channelType = conf.getOrDefault(Config.CHANNEL_TYPE, Config.DEFAULT_CHANNEL_TYPE); + boolean isMock = false; + if (Config.MEMORY_CHANNEL.equals(channelType)) { + isMock = true; + } + boolean isRecreate = Boolean.parseBoolean( + conf.getOrDefault(Config.IS_RECREATE, "false")); + this.nativeReaderPtr = createDataReaderNative( + inputChannelsBytes, + fromActorsBytes, + seqIds, + msgIds, + timerInterval, + isRecreate, + ChannelUtils.toNativeConf(conf), + isMock + ); + LOGGER.info("create DataReader succeed"); + } + + // params set by getBundleNative: bundle data address + size + private final ByteBuffer getBundleParams = ByteBuffer.allocateDirect(24); + // We use direct buffer to reduce gc overhead and memory copy. 
+ private final ByteBuffer bundleData = Platform.wrapDirectBuffer(0, 0); + private final ByteBuffer bundleMeta = ByteBuffer.allocateDirect(BundleMeta.LENGTH); + + { + getBundleParams.order(ByteOrder.nativeOrder()); + bundleData.order(ByteOrder.nativeOrder()); + bundleMeta.order(ByteOrder.nativeOrder()); + } + + /** + * Read message from input channels, if timeout, return null. + * + * @param timeoutMillis timeout + * @return message or null + */ + public DataMessage read(long timeoutMillis) { + if (buf.isEmpty()) { + getBundle(timeoutMillis); + // if bundle not empty. empty message still has data size + seqId + msgId + if (bundleData.position() < bundleData.limit()) { + BundleMeta bundleMeta = new BundleMeta(this.bundleMeta); + // barrier + if (bundleMeta.getBundleType() == DataBundleType.BARRIER) { + throw new UnsupportedOperationException( + "Unsupported bundle type " + bundleMeta.getBundleType()); + } else if (bundleMeta.getBundleType() == DataBundleType.BUNDLE) { + String channelID = bundleMeta.getChannelID(); + long timestamp = bundleMeta.getBundleTs(); + for (int i = 0; i < bundleMeta.getMessageListSize(); i++) { + buf.offer(getDataMessage(bundleData, channelID, timestamp)); + } + } else if (bundleMeta.getBundleType() == DataBundleType.EMPTY) { + long messageId = bundleMeta.getLastMessageId(); + buf.offer(new DataMessage(null, bundleMeta.getBundleTs(), + messageId, bundleMeta.getChannelID())); + } + } + } + if (buf.isEmpty()) { + return null; + } + return buf.poll(); + } + + private DataMessage getDataMessage(ByteBuffer bundleData, String channelID, long timestamp) { + int dataSize = bundleData.getInt(); + // msgId + long msgId = bundleData.getLong(); + // msgType + bundleData.getInt(); + // make `data.capacity() == data.remaining()`, because some code used `capacity()` + // rather than `remaining()` + int position = bundleData.position(); + int limit = bundleData.limit(); + bundleData.limit(position + dataSize); + ByteBuffer data = bundleData.slice(); + 
bundleData.limit(limit); + bundleData.position(position + dataSize); + return new DataMessage(data, timestamp, msgId, channelID); + } + + private void getBundle(long timeoutMillis) { + getBundleNative(nativeReaderPtr, timeoutMillis, + Platform.getAddress(getBundleParams), Platform.getAddress(bundleMeta)); + bundleMeta.rewind(); + long bundleAddress = getBundleParams.getLong(0); + int bundleSize = getBundleParams.getInt(8); + // This has better performance than NewDirectBuffer or set address/capacity in jni. + Platform.wrapDirectBuffer(bundleData, bundleAddress, bundleSize); + } + + /** + * Stop reader + */ + public void stop() { + stopReaderNative(nativeReaderPtr); + } + + /** + * Close reader to release resource + */ + public void close() { + if (nativeReaderPtr == 0) { + return; + } + LOGGER.info("closing DataReader."); + closeReaderNative(nativeReaderPtr); + nativeReaderPtr = 0; + LOGGER.info("closing DataReader done."); + } + + private static native long createDataReaderNative( + byte[][] inputChannels, + byte[][] inputActorIds, + long[] seqIds, + long[] msgIds, + long timerInterval, + boolean isRecreate, + byte[] configBytes, + boolean isMock); + + private native void getBundleNative(long nativeReaderPtr, + long timeoutMillis, + long params, + long metaAddress); + + private native void stopReaderNative(long nativeReaderPtr); + + private native void closeReaderNative(long nativeReaderPtr); + + enum DataBundleType { + EMPTY(1), + BARRIER(2), + BUNDLE(3); + + int code; + + DataBundleType(int code) { + this.code = code; + } + } + + static class BundleMeta { + // kMessageBundleHeaderSize + kUniqueIDSize: + // magicNum(4b) + bundleTs(8b) + lastMessageId(8b) + messageListSize(4b) + // + bundleType(4b) + rawBundleSize(4b) + channelID(20b) + static final int LENGTH = 4 + 8 + 8 + 4 + 4 + 4 + 20; + private int magicNum; + private long bundleTs; + private long lastMessageId; + private int messageListSize; + private DataBundleType bundleType; + private String channelID; + 
private int rawBundleSize; + + BundleMeta(ByteBuffer buffer) { + // StreamingMessageBundleMeta Deserialization + // magicNum + magicNum = buffer.getInt(); + // messageBundleTs + bundleTs = buffer.getLong(); + // lastOffsetSeqId + lastMessageId = buffer.getLong(); + messageListSize = buffer.getInt(); + int typeInt = buffer.getInt(); + if (DataBundleType.BUNDLE.code == typeInt) { + bundleType = DataBundleType.BUNDLE; + } else if (DataBundleType.BARRIER.code == typeInt) { + bundleType = DataBundleType.BARRIER; + } else { + bundleType = DataBundleType.EMPTY; + } + // rawBundleSize + rawBundleSize = buffer.getInt(); + channelID = getQidString(buffer); + } + + private String getQidString(ByteBuffer buffer) { + byte[] bytes = new byte[ChannelID.ID_LENGTH]; + buffer.get(bytes); + return ChannelID.idBytesToStr(bytes); + } + + public int getMagicNum() { + return magicNum; + } + + public long getBundleTs() { + return bundleTs; + } + + public long getLastMessageId() { + return lastMessageId; + } + + public int getMessageListSize() { + return messageListSize; + } + + public DataBundleType getBundleType() { + return bundleType; + } + + public String getChannelID() { + return channelID; + } + + public int getRawBundleSize() { + return rawBundleSize; + } + } + +} diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/DataWriter.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/DataWriter.java new file mode 100644 index 000000000000..b0c943b0eb7b --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/DataWriter.java @@ -0,0 +1,140 @@ +package org.ray.streaming.runtime.transfer; + +import com.google.common.base.Preconditions; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.List; +import java.util.Map; +import java.util.Set; +import org.ray.api.id.ActorId; +import org.ray.streaming.runtime.util.Platform; +import 
org.ray.streaming.util.Config; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * DataWriter is a wrapper of streaming c++ DataWriter, which sends data + * to downstream workers + */ +public class DataWriter { + private static final Logger LOGGER = LoggerFactory.getLogger(DataWriter.class); + + private long nativeWriterPtr; + private ByteBuffer buffer = ByteBuffer.allocateDirect(0); + private long bufferAddress; + + { + ensureBuffer(0); + } + + /** + * @param outputChannels output channels ids + * @param toActors downstream output actors + * @param conf configuration + */ + public DataWriter(List outputChannels, + List toActors, + Map conf) { + Preconditions.checkArgument(!outputChannels.isEmpty()); + Preconditions.checkArgument(outputChannels.size() == toActors.size()); + byte[][] outputChannelsBytes = outputChannels.stream() + .map(ChannelID::idStrToBytes).toArray(byte[][]::new); + byte[][] toActorsBytes = toActors.stream() + .map(ActorId::getBytes).toArray(byte[][]::new); + long channelSize = Long.parseLong( + conf.getOrDefault(Config.CHANNEL_SIZE, Config.CHANNEL_SIZE_DEFAULT)); + long[] msgIds = new long[outputChannels.size()]; + for (int i = 0; i < outputChannels.size(); i++) { + msgIds[i] = 0; + } + String channelType = conf.getOrDefault(Config.CHANNEL_TYPE, Config.DEFAULT_CHANNEL_TYPE); + boolean isMock = false; + if (Config.MEMORY_CHANNEL.equals(channelType)) { + isMock = true; + } + this.nativeWriterPtr = createWriterNative( + outputChannelsBytes, + toActorsBytes, + msgIds, + channelSize, + ChannelUtils.toNativeConf(conf), + isMock + ); + LOGGER.info("create DataWriter succeed"); + } + + /** + * Write msg into the specified channel + * + * @param id channel id + * @param item message item data section is specified by [position, limit). 
+ */ + public void write(ChannelID id, ByteBuffer item) { + int size = item.remaining(); + ensureBuffer(size); + buffer.clear(); + buffer.put(item); + writeMessageNative(nativeWriterPtr, id.getNativeIdPtr(), bufferAddress, size); + } + + /** + * Write msg into the specified channels + * + * @param ids channel ids + * @param item message item data section is specified by [position, limit). + * item doesn't have to be a direct buffer. + */ + public void write(Set ids, ByteBuffer item) { + int size = item.remaining(); + ensureBuffer(size); + for (ChannelID id : ids) { + buffer.clear(); + buffer.put(item.duplicate()); + writeMessageNative(nativeWriterPtr, id.getNativeIdPtr(), bufferAddress, size); + } + } + + private void ensureBuffer(int size) { + if (buffer.capacity() < size) { + buffer = ByteBuffer.allocateDirect(size); + buffer.order(ByteOrder.nativeOrder()); + bufferAddress = Platform.getAddress(buffer); + } + } + + /** + * stop writer + */ + public void stop() { + stopWriterNative(nativeWriterPtr); + } + + /** + * close writer to release resources + */ + public void close() { + if (nativeWriterPtr == 0) { + return; + } + LOGGER.info("closing data writer."); + closeWriterNative(nativeWriterPtr); + nativeWriterPtr = 0; + LOGGER.info("closing data writer done."); + } + + private static native long createWriterNative( + byte[][] outputQueueIds, + byte[][] outputActorIds, + long[] msgIds, + long channelSize, + byte[] confBytes, + boolean isMock); + + private native long writeMessageNative( + long nativeQueueProducerPtr, long nativeIdPtr, long address, int size); + + private native void stopWriterNative(long nativeQueueProducerPtr); + + private native void closeWriterNative(long nativeQueueProducerPtr); + +} diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/Message.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/transfer/Message.java new file mode 100644 index 000000000000..b43e713e7bb4 --- 
/**
 * A single message read from a streaming channel.
 */
public interface Message {

  /**
   * Message data.
   *
   * <p>The body is a direct byte buffer which may become invalid after the next call to
   * DataReader#getBundleNative. Consume this buffer fully before getBundleNative is
   * called again.
   *
   * @return message body
   */
  ByteBuffer body();

  /**
   * @return timestamp when item is written by upstream DataWriter
   */
  long timestamp();
}
+ */ +public class TransferHandler { + + static { + try { + Class.forName(RayNativeRuntime.class.getName()); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + JniUtils.loadLibrary("streaming_java"); + } + + private long writerClientNative; + private long readerClientNative; + + public TransferHandler(long coreWorkerNative, + JavaFunctionDescriptor writerAsyncFunc, + JavaFunctionDescriptor writerSyncFunc, + JavaFunctionDescriptor readerAsyncFunc, + JavaFunctionDescriptor readerSyncFunc) { + Preconditions.checkArgument(coreWorkerNative != 0); + writerClientNative = createWriterClientNative( + coreWorkerNative, writerAsyncFunc, writerSyncFunc); + readerClientNative = createReaderClientNative( + coreWorkerNative, readerAsyncFunc, readerSyncFunc); + } + + public void onWriterMessage(byte[] buffer) { + handleWriterMessageNative(writerClientNative, buffer); + } + + public byte[] onWriterMessageSync(byte[] buffer) { + return handleWriterMessageSyncNative(writerClientNative, buffer); + } + + public void onReaderMessage(byte[] buffer) { + handleReaderMessageNative(readerClientNative, buffer); + } + + public byte[] onReaderMessageSync(byte[] buffer) { + return handleReaderMessageSyncNative(readerClientNative, buffer); + } + + private native long createWriterClientNative( + long coreWorkerNative, + FunctionDescriptor asyncFunc, + FunctionDescriptor syncFunc); + + private native long createReaderClientNative( + long coreWorkerNative, + FunctionDescriptor asyncFunc, + FunctionDescriptor syncFunc); + + private native void handleWriterMessageNative(long handler, byte[] buffer); + + private native byte[] handleWriterMessageSyncNative(long handler, byte[] buffer); + + private native void handleReaderMessageNative(long handler, byte[] buffer); + + private native byte[] handleReaderMessageSyncNative(long handler, byte[] buffer); +} diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/util/EnvUtil.java 
b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/util/EnvUtil.java new file mode 100644 index 000000000000..caf47d4894d8 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/util/EnvUtil.java @@ -0,0 +1,19 @@ +package org.ray.streaming.runtime.util; + +import org.ray.runtime.RayNativeRuntime; +import org.ray.runtime.util.JniUtils; + +public class EnvUtil { + + public static void loadNativeLibraries() { + // Explicitly load `RayNativeRuntime`, to make sure `core_worker_library_java` + // is loaded before `streaming_java`. + try { + Class.forName(RayNativeRuntime.class.getName()); + } catch (ClassNotFoundException e) { + throw new RuntimeException(e); + } + JniUtils.loadLibrary("streaming_java"); + } + +} diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/util/Platform.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/util/Platform.java new file mode 100644 index 000000000000..21cda5d6fae6 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/util/Platform.java @@ -0,0 +1,91 @@ +package org.ray.streaming.runtime.util; + +import com.google.common.base.Preconditions; +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.lang.reflect.InvocationTargetException; +import java.nio.Buffer; +import java.nio.ByteBuffer; +import sun.misc.Unsafe; +import sun.nio.ch.DirectBuffer; + +/** + * Based on org.apache.spark.unsafe.Platform + */ +public final class Platform { + + public static final Unsafe UNSAFE; + + static { + Unsafe unsafe; + try { + Field unsafeField = Unsafe.class.getDeclaredField("theUnsafe"); + unsafeField.setAccessible(true); + unsafe = (Unsafe) unsafeField.get(null); + } catch (Throwable cause) { + throw new UnsupportedOperationException("Unsafe is not supported in this platform."); + } + UNSAFE = unsafe; + } + + // Access fields and constructors once and store 
them, for performance: + private static final Constructor DBB_CONSTRUCTOR; + private static final long BUFFER_ADDRESS_FIELD_OFFSET; + private static final long BUFFER_CAPACITY_FIELD_OFFSET; + + static { + try { + Class cls = Class.forName("java.nio.DirectByteBuffer"); + Constructor constructor = cls.getDeclaredConstructor(Long.TYPE, Integer.TYPE); + constructor.setAccessible(true); + DBB_CONSTRUCTOR = constructor; + Field addressField = Buffer.class.getDeclaredField("address"); + BUFFER_ADDRESS_FIELD_OFFSET = UNSAFE.objectFieldOffset(addressField); + Preconditions.checkArgument(BUFFER_ADDRESS_FIELD_OFFSET != 0); + Field capacityField = Buffer.class.getDeclaredField("capacity"); + BUFFER_CAPACITY_FIELD_OFFSET = UNSAFE.objectFieldOffset(capacityField); + Preconditions.checkArgument(BUFFER_CAPACITY_FIELD_OFFSET != 0); + } catch (ClassNotFoundException | NoSuchMethodException | NoSuchFieldException e) { + throw new IllegalStateException(e); + } + } + + private static final ThreadLocal localEmptyBuffer = + ThreadLocal.withInitial(() -> { + try { + return (ByteBuffer) DBB_CONSTRUCTOR.newInstance(0, 0); + } catch (InstantiationException | IllegalAccessException | InvocationTargetException e) { + UNSAFE.throwException(e); + } + throw new IllegalStateException("unreachable"); + }); + + /** + * Wrap a buffer [address, address + size) as a DirectByteBuffer. + */ + public static ByteBuffer wrapDirectBuffer(long address, int size) { + ByteBuffer buffer = localEmptyBuffer.get().duplicate(); + UNSAFE.putLong(buffer, BUFFER_ADDRESS_FIELD_OFFSET, address); + UNSAFE.putInt(buffer, BUFFER_CAPACITY_FIELD_OFFSET, size); + buffer.clear(); + return buffer; + } + + /** + * Wrap a buffer [address, address + size) into provided buffer. 
+ */ + public static void wrapDirectBuffer(ByteBuffer buffer, long address, int size) { + UNSAFE.putLong(buffer, BUFFER_ADDRESS_FIELD_OFFSET, address); + UNSAFE.putInt(buffer, BUFFER_CAPACITY_FIELD_OFFSET, size); + buffer.clear(); + } + + /** + * @param buffer a DirectBuffer backed by off-heap memory + * @return address of off-heap memory + */ + public static long getAddress(ByteBuffer buffer) { + return ((DirectBuffer) buffer).address(); + } + +} diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/JobWorker.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/JobWorker.java new file mode 100644 index 000000000000..bb7e607b03c5 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/JobWorker.java @@ -0,0 +1,159 @@ +package org.ray.streaming.runtime.worker; + +import java.io.Serializable; +import java.util.Map; +import org.ray.api.Ray; +import org.ray.api.annotation.RayRemote; +import org.ray.runtime.RayMultiWorkerNativeRuntime; +import org.ray.runtime.functionmanager.JavaFunctionDescriptor; +import org.ray.streaming.runtime.core.graph.ExecutionGraph; +import org.ray.streaming.runtime.core.graph.ExecutionNode; +import org.ray.streaming.runtime.core.graph.ExecutionNode.NodeType; +import org.ray.streaming.runtime.core.graph.ExecutionTask; +import org.ray.streaming.runtime.core.processor.OneInputProcessor; +import org.ray.streaming.runtime.core.processor.SourceProcessor; +import org.ray.streaming.runtime.core.processor.StreamProcessor; +import org.ray.streaming.runtime.transfer.TransferHandler; +import org.ray.streaming.runtime.util.EnvUtil; +import org.ray.streaming.runtime.worker.context.WorkerContext; +import org.ray.streaming.runtime.worker.tasks.OneInputStreamTask; +import org.ray.streaming.runtime.worker.tasks.SourceStreamTask; +import org.ray.streaming.runtime.worker.tasks.StreamTask; +import org.ray.streaming.util.Config; +import 
org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * The stream job worker, it is a ray actor. + */ +@RayRemote +public class JobWorker implements Serializable { + private static final Logger LOGGER = LoggerFactory.getLogger(JobWorker.class); + + static { + EnvUtil.loadNativeLibraries(); + } + + private int taskId; + private Map config; + private WorkerContext workerContext; + private ExecutionNode executionNode; + private ExecutionTask executionTask; + private ExecutionGraph executionGraph; + private StreamProcessor streamProcessor; + private NodeType nodeType; + private StreamTask task; + private TransferHandler transferHandler; + + public Boolean init(WorkerContext workerContext) { + this.workerContext = workerContext; + this.taskId = workerContext.getTaskId(); + this.config = workerContext.getConfig(); + this.executionGraph = this.workerContext.getExecutionGraph(); + this.executionTask = executionGraph.getExecutionTaskByTaskId(taskId); + this.executionNode = executionGraph.getExecutionNodeByTaskId(taskId); + + this.nodeType = executionNode.getNodeType(); + this.streamProcessor = executionNode.getStreamProcessor(); + LOGGER.debug("Initializing StreamWorker, taskId: {}, operator: {}.", taskId, streamProcessor); + + String channelType = (String) this.config.getOrDefault( + Config.CHANNEL_TYPE, Config.DEFAULT_CHANNEL_TYPE); + if (channelType.equals(Config.NATIVE_CHANNEL)) { + transferHandler = new TransferHandler( + getNativeCoreWorker(), + new JavaFunctionDescriptor(JobWorker.class.getName(), "onWriterMessage", "([B)V"), + new JavaFunctionDescriptor(JobWorker.class.getName(), "onWriterMessageSync", "([B)[B"), + new JavaFunctionDescriptor(JobWorker.class.getName(), "onReaderMessage", "([B)V"), + new JavaFunctionDescriptor(JobWorker.class.getName(), "onReaderMessageSync", "([B)[B")); + } + task = createStreamTask(); + task.start(); + return true; + } + + private StreamTask createStreamTask() { + if (streamProcessor instanceof OneInputProcessor) { + return 
new OneInputStreamTask(taskId, streamProcessor, this); + } else if (streamProcessor instanceof SourceProcessor) { + return new SourceStreamTask(taskId, streamProcessor, this); + } else { + throw new RuntimeException("Unsupported type: " + streamProcessor); + } + } + + public int getTaskId() { + return taskId; + } + + public Map getConfig() { + return config; + } + + public WorkerContext getWorkerContext() { + return workerContext; + } + + public NodeType getNodeType() { + return nodeType; + } + + public ExecutionNode getExecutionNode() { + return executionNode; + } + + public ExecutionTask getExecutionTask() { + return executionTask; + } + + public ExecutionGraph getExecutionGraph() { + return executionGraph; + } + + public StreamProcessor getStreamProcessor() { + return streamProcessor; + } + + public StreamTask getTask() { + return task; + } + + /** + * Used by upstream streaming queue to send data to this actor + */ + public void onReaderMessage(byte[] buffer) { + transferHandler.onReaderMessage(buffer); + } + + /** + * Used by upstream streaming queue to send data to this actor + * and receive result from this actor + */ + public byte[] onReaderMessageSync(byte[] buffer) { + return transferHandler.onReaderMessageSync(buffer); + } + + /** + * Used by downstream streaming queue to send data to this actor + */ + public void onWriterMessage(byte[] buffer) { + transferHandler.onWriterMessage(buffer); + } + + /** + * Used by downstream streaming queue to send data to this actor + * and receive result from this actor + */ + public byte[] onWriterMessageSync(byte[] buffer) { + return transferHandler.onWriterMessageSync(buffer); + } + + private static long getNativeCoreWorker() { + long pointer = 0; + if (Ray.internal() instanceof RayMultiWorkerNativeRuntime) { + pointer = ((RayMultiWorkerNativeRuntime) Ray.internal()) + .getCurrentRuntime().getNativeCoreWorkerPointer(); + } + return pointer; + } +} diff --git 
a/java/streaming/src/main/java/org/ray/streaming/core/runtime/context/RayRuntimeContext.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/context/RayRuntimeContext.java similarity index 66% rename from java/streaming/src/main/java/org/ray/streaming/core/runtime/context/RayRuntimeContext.java rename to streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/context/RayRuntimeContext.java index 6d796c30eff5..e6779733c235 100644 --- a/java/streaming/src/main/java/org/ray/streaming/core/runtime/context/RayRuntimeContext.java +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/context/RayRuntimeContext.java @@ -1,19 +1,21 @@ -package org.ray.streaming.core.runtime.context; +package org.ray.streaming.runtime.worker.context; -import static org.ray.streaming.util.ConfigKey.STREAMING_MAX_BATCH_COUNT; +import static org.ray.streaming.util.Config.STREAMING_BATCH_MAX_COUNT; import java.util.Map; -import org.ray.streaming.core.graph.ExecutionTask; + +import org.ray.streaming.api.context.RuntimeContext; +import org.ray.streaming.runtime.core.graph.ExecutionTask; /** * Use Ray to implement RuntimeContext. 
*/ public class RayRuntimeContext implements RuntimeContext { - private int taskId; private int taskIndex; private int parallelism; private Long batchId; + private final Long maxBatch; private Map config; public RayRuntimeContext(ExecutionTask executionTask, Map config, @@ -22,6 +24,11 @@ public RayRuntimeContext(ExecutionTask executionTask, Map config this.config = config; this.taskIndex = executionTask.getTaskIndex(); this.parallelism = parallelism; + if (config.containsKey(STREAMING_BATCH_MAX_COUNT)) { + this.maxBatch = Long.valueOf(String.valueOf(config.get(STREAMING_BATCH_MAX_COUNT))); + } else { + this.maxBatch = Long.MAX_VALUE; + } } @Override @@ -46,10 +53,7 @@ public Long getBatchId() { @Override public Long getMaxBatch() { - if (config.containsKey(STREAMING_MAX_BATCH_COUNT)) { - return Long.valueOf(String.valueOf(config.get(STREAMING_MAX_BATCH_COUNT))); - } - return Long.MAX_VALUE; + return maxBatch; } public void setBatchId(Long batchId) { diff --git a/java/streaming/src/main/java/org/ray/streaming/core/runtime/context/WorkerContext.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/context/WorkerContext.java similarity index 88% rename from java/streaming/src/main/java/org/ray/streaming/core/runtime/context/WorkerContext.java rename to streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/context/WorkerContext.java index 39117fd7b225..567909f81179 100644 --- a/java/streaming/src/main/java/org/ray/streaming/core/runtime/context/WorkerContext.java +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/context/WorkerContext.java @@ -1,8 +1,8 @@ -package org.ray.streaming.core.runtime.context; +package org.ray.streaming.runtime.worker.context; import java.io.Serializable; import java.util.Map; -import org.ray.streaming.core.graph.ExecutionGraph; +import org.ray.streaming.runtime.core.graph.ExecutionGraph; /** * Encapsulate the context information for worker 
initialization. diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/tasks/InputStreamTask.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/tasks/InputStreamTask.java new file mode 100644 index 000000000000..eed12f705cbc --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/tasks/InputStreamTask.java @@ -0,0 +1,53 @@ +package org.ray.streaming.runtime.worker.tasks; + +import org.ray.runtime.util.Serializer; +import org.ray.streaming.runtime.core.processor.Processor; +import org.ray.streaming.runtime.transfer.Message; +import org.ray.streaming.runtime.worker.JobWorker; +import org.ray.streaming.util.Config; + +public abstract class InputStreamTask extends StreamTask { + private volatile boolean running = true; + private volatile boolean stopped = false; + private long readTimeoutMillis; + + public InputStreamTask(int taskId, Processor processor, JobWorker streamWorker) { + super(taskId, processor, streamWorker); + readTimeoutMillis = Long.parseLong((String) streamWorker.getConfig() + .getOrDefault(Config.READ_TIMEOUT_MS, Config.DEFAULT_READ_TIMEOUT_MS)); + } + + @Override + protected void init() { + } + + @Override + public void run() { + while (running) { + Message item = reader.read(readTimeoutMillis); + if (item != null) { + byte[] bytes = new byte[item.body().remaining()]; + item.body().get(bytes); + Object obj = Serializer.decode(bytes); + processor.process(obj); + } + } + stopped = true; + } + + @Override + protected void cancelTask() throws Exception { + running = false; + while (!stopped) { + } + } + + @Override + public String toString() { + final StringBuilder sb = new StringBuilder("InputStreamTask{"); + sb.append("taskId=").append(taskId); + sb.append(", processor=").append(processor); + sb.append('}'); + return sb.toString(); + } +} diff --git 
a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/tasks/OneInputStreamTask.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/tasks/OneInputStreamTask.java new file mode 100644 index 000000000000..0b9491f76679 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/tasks/OneInputStreamTask.java @@ -0,0 +1,11 @@ +package org.ray.streaming.runtime.worker.tasks; + +import org.ray.streaming.runtime.core.processor.Processor; +import org.ray.streaming.runtime.worker.JobWorker; + +public class OneInputStreamTask extends InputStreamTask { + + public OneInputStreamTask(int taskId, Processor processor, JobWorker streamWorker) { + super(taskId, processor, streamWorker); + } +} diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/tasks/SourceStreamTask.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/tasks/SourceStreamTask.java new file mode 100644 index 000000000000..74a197708e0c --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/tasks/SourceStreamTask.java @@ -0,0 +1,30 @@ +package org.ray.streaming.runtime.worker.tasks; + +import org.ray.streaming.runtime.core.processor.Processor; +import org.ray.streaming.runtime.core.processor.SourceProcessor; +import org.ray.streaming.runtime.worker.JobWorker; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public class SourceStreamTask extends StreamTask { + private static final Logger LOGGER = LoggerFactory.getLogger(SourceStreamTask.class); + + public SourceStreamTask(int taskId, Processor processor, JobWorker worker) { + super(taskId, processor, worker); + } + + @Override + protected void init() { + } + + @Override + public void run() { + final SourceProcessor sourceProcessor = (SourceProcessor) this.processor; + sourceProcessor.run(); + } + + @Override + protected void cancelTask() 
throws Exception { + } + +} diff --git a/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/tasks/StreamTask.java b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/tasks/StreamTask.java new file mode 100644 index 000000000000..7d4f397ba04f --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/java/org/ray/streaming/runtime/worker/tasks/StreamTask.java @@ -0,0 +1,134 @@ +package org.ray.streaming.runtime.worker.tasks; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import org.ray.api.Ray; +import org.ray.api.RayActor; +import org.ray.api.id.ActorId; +import org.ray.streaming.api.collector.Collector; +import org.ray.streaming.api.context.RuntimeContext; +import org.ray.streaming.runtime.core.collector.OutputCollector; +import org.ray.streaming.runtime.core.graph.ExecutionEdge; +import org.ray.streaming.runtime.core.graph.ExecutionGraph; +import org.ray.streaming.runtime.core.graph.ExecutionNode; +import org.ray.streaming.runtime.core.processor.Processor; +import org.ray.streaming.runtime.transfer.ChannelID; +import org.ray.streaming.runtime.transfer.DataReader; +import org.ray.streaming.runtime.transfer.DataWriter; +import org.ray.streaming.runtime.worker.JobWorker; +import org.ray.streaming.runtime.worker.context.RayRuntimeContext; +import org.ray.streaming.util.Config; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +public abstract class StreamTask implements Runnable { + private static final Logger LOG = LoggerFactory.getLogger(StreamTask.class); + + protected int taskId; + protected Processor processor; + protected JobWorker worker; + protected DataReader reader; + private Map writers; + private Thread thread; + + public StreamTask(int taskId, Processor processor, JobWorker worker) { + this.taskId = taskId; + this.processor = processor; + this.worker = worker; + prepareTask(); + + this.thread = new 
Thread(Ray.wrapRunnable(this), this.getClass().getName() + + "-" + System.currentTimeMillis()); + this.thread.setDaemon(true); + } + + private void prepareTask() { + Map queueConf = new HashMap<>(); + worker.getConfig().forEach((k, v) -> queueConf.put(k, String.valueOf(v))); + String queueSize = (String) worker.getConfig() + .getOrDefault(Config.CHANNEL_SIZE, Config.CHANNEL_SIZE_DEFAULT); + queueConf.put(Config.CHANNEL_SIZE, queueSize); + queueConf.put(Config.TASK_JOB_ID, Ray.getRuntimeContext().getCurrentJobId().toString()); + String channelType = (String) worker.getConfig() + .getOrDefault(Config.CHANNEL_TYPE, Config.MEMORY_CHANNEL); + queueConf.put(Config.CHANNEL_TYPE, channelType); + + ExecutionGraph executionGraph = worker.getExecutionGraph(); + ExecutionNode executionNode = worker.getExecutionNode(); + + // writers + writers = new HashMap<>(); + List outputEdges = executionNode.getOutputEdges(); + List collectors = new ArrayList<>(); + for (ExecutionEdge edge : outputEdges) { + Map outputActorIds = new HashMap<>(); + Map> taskId2Worker = executionGraph + .getTaskId2WorkerByNodeId(edge.getTargetNodeId()); + taskId2Worker.forEach((targetTaskId, targetActor) -> { + String queueName = ChannelID.genIdStr(taskId, targetTaskId, executionGraph.getBuildTime()); + outputActorIds.put(queueName, targetActor.getId()); + }); + + if (!outputActorIds.isEmpty()) { + List channelIDs = new ArrayList<>(); + List toActorIds = new ArrayList<>(); + outputActorIds.forEach((k, v) -> { + channelIDs.add(k); + toActorIds.add(v); + }); + DataWriter writer = new DataWriter(channelIDs, toActorIds, queueConf); + LOG.info("Create DataWriter succeed."); + writers.put(edge, writer); + collectors.add(new OutputCollector(channelIDs, writer, edge.getPartition())); + } + } + + // consumer + List inputEdges = executionNode.getInputsEdges(); + Map inputActorIds = new HashMap<>(); + for (ExecutionEdge edge : inputEdges) { + Map> taskId2Worker = executionGraph + 
.getTaskId2WorkerByNodeId(edge.getSrcNodeId()); + taskId2Worker.forEach((srcTaskId, srcActor) -> { + String queueName = ChannelID.genIdStr(srcTaskId, taskId, executionGraph.getBuildTime()); + inputActorIds.put(queueName, srcActor.getId()); + }); + } + if (!inputActorIds.isEmpty()) { + List channelIDs = new ArrayList<>(); + List fromActorIds = new ArrayList<>(); + inputActorIds.forEach((k, v) -> { + channelIDs.add(k); + fromActorIds.add(v); + }); + LOG.info("Register queue consumer, queues {}.", channelIDs); + reader = new DataReader(channelIDs, fromActorIds, queueConf); + } + + RuntimeContext runtimeContext = new RayRuntimeContext( + worker.getExecutionTask(), worker.getConfig(), executionNode.getParallelism()); + + processor.open(collectors, runtimeContext); + + Runtime.getRuntime().addShutdownHook(new Thread(() -> { + try { + // Make DataReader stop read data when MockQueue destructor gets called to avoid crash + StreamTask.this.cancelTask(); + } catch (Exception e) { + e.printStackTrace(); + } + })); + } + + protected abstract void init() throws Exception; + + protected abstract void cancelTask() throws Exception; + + public void start() { + this.thread.start(); + LOG.info("started {}-{}", this.getClass().getSimpleName(), taskId); + } + +} diff --git a/streaming/java/streaming-runtime/src/main/resources/META-INF/services/org.ray.streaming.schedule.JobScheduler b/streaming/java/streaming-runtime/src/main/resources/META-INF/services/org.ray.streaming.schedule.JobScheduler new file mode 100644 index 000000000000..53719a32af25 --- /dev/null +++ b/streaming/java/streaming-runtime/src/main/resources/META-INF/services/org.ray.streaming.schedule.JobScheduler @@ -0,0 +1 @@ +org.ray.streaming.runtime.schedule.JobSchedulerImpl \ No newline at end of file diff --git a/java/streaming/src/test/java/org/ray/streaming/demo/WordCountTest.java b/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/demo/WordCountTest.java similarity index 88% rename from 
java/streaming/src/test/java/org/ray/streaming/demo/WordCountTest.java rename to streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/demo/WordCountTest.java index d104c8e6d1af..e6427120a52c 100644 --- a/java/streaming/src/test/java/org/ray/streaming/demo/WordCountTest.java +++ b/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/demo/WordCountTest.java @@ -1,4 +1,4 @@ -package org.ray.streaming.demo; +package org.ray.streaming.runtime.demo; import com.google.common.collect.ImmutableMap; import org.ray.streaming.api.context.StreamingContext; @@ -6,7 +6,7 @@ import org.ray.streaming.api.function.impl.ReduceFunction; import org.ray.streaming.api.function.impl.SinkFunction; import org.ray.streaming.api.stream.StreamSource; -import org.ray.streaming.util.ConfigKey; +import org.ray.streaming.util.Config; import java.io.Serializable; import java.util.ArrayList; import java.util.HashMap; @@ -31,7 +31,8 @@ public class WordCountTest implements Serializable { public void testWordCount() { StreamingContext streamingContext = StreamingContext.buildContext(); Map config = new HashMap<>(); - config.put(ConfigKey.STREAMING_MAX_BATCH_COUNT, 1); + config.put(Config.STREAMING_BATCH_MAX_COUNT, 1); + config.put(Config.CHANNEL_TYPE, Config.MEMORY_CHANNEL); streamingContext.withConfig(config); List text = new ArrayList<>(); text.add("hello world eagle eagle eagle"); @@ -46,7 +47,8 @@ public void testWordCount() { .keyBy(pair -> pair.word) .reduce((ReduceFunction) (oldValue, newValue) -> new WordAndCount(oldValue.word, oldValue.count + newValue.count)) - .sink((SinkFunction) result -> wordCount.put(result.word, result.count)); + .sink((SinkFunction) + result -> wordCount.put(result.word, result.count)); streamingContext.execute(); diff --git a/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/schedule/TaskAssignImplTest.java 
b/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/schedule/TaskAssignImplTest.java new file mode 100644 index 000000000000..f6057f28984f --- /dev/null +++ b/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/schedule/TaskAssignImplTest.java @@ -0,0 +1,75 @@ +package org.ray.streaming.runtime.schedule; + +import java.util.ArrayList; +import java.util.List; + +import com.google.common.collect.Lists; +import org.ray.api.RayActor; +import org.ray.api.id.ActorId; +import org.ray.api.id.ObjectId; +import org.ray.runtime.actor.LocalModeRayActor; +import org.ray.streaming.api.context.StreamingContext; +import org.ray.streaming.api.partition.impl.RoundRobinPartition; +import org.ray.streaming.api.stream.DataStream; +import org.ray.streaming.api.stream.StreamSink; +import org.ray.streaming.api.stream.StreamSource; +import org.ray.streaming.runtime.core.graph.ExecutionEdge; +import org.ray.streaming.runtime.core.graph.ExecutionGraph; +import org.ray.streaming.runtime.core.graph.ExecutionNode; +import org.ray.streaming.runtime.core.graph.ExecutionNode.NodeType; +import org.ray.streaming.runtime.worker.JobWorker; +import org.ray.streaming.plan.Plan; +import org.ray.streaming.plan.PlanBuilder; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testng.Assert; +import org.testng.annotations.Test; + +public class TaskAssignImplTest { + + private static final Logger LOGGER = LoggerFactory.getLogger(TaskAssignImplTest.class); + + @Test + public void testTaskAssignImpl() { + Plan plan = buildDataSyncPlan(); + + List> workers = new ArrayList<>(); + for(int i = 0; i < plan.getPlanVertexList().size(); i++) { + workers.add(new LocalModeRayActor(ActorId.fromRandom(), ObjectId.fromRandom())); + } + + ITaskAssign taskAssign = new TaskAssignImpl(); + ExecutionGraph executionGraph = taskAssign.assign(plan, workers); + + List executionNodeList = executionGraph.getExecutionNodeList(); + + 
Assert.assertEquals(executionNodeList.size(), 2); + ExecutionNode sourceNode = executionNodeList.get(0); + Assert.assertEquals(sourceNode.getNodeType(), NodeType.SOURCE); + Assert.assertEquals(sourceNode.getExecutionTasks().size(), 1); + Assert.assertEquals(sourceNode.getOutputEdges().size(), 1); + + List sourceExecutionEdges = sourceNode.getOutputEdges(); + + Assert.assertEquals(sourceExecutionEdges.size(), 1); + ExecutionEdge source2Sink = sourceExecutionEdges.get(0); + + Assert.assertEquals(source2Sink.getPartition().getClass(), RoundRobinPartition.class); + + ExecutionNode sinkNode = executionNodeList.get(1); + Assert.assertEquals(sinkNode.getNodeType(), NodeType.SINK); + Assert.assertEquals(sinkNode.getExecutionTasks().size(), 1); + Assert.assertEquals(sinkNode.getOutputEdges().size(), 0); + } + + public Plan buildDataSyncPlan() { + StreamingContext streamingContext = StreamingContext.buildContext(); + DataStream dataStream = StreamSource.buildSource(streamingContext, + Lists.newArrayList("a", "b", "c")); + StreamSink streamSink = dataStream.sink(x -> LOGGER.info(x)); + PlanBuilder planBuilder = new PlanBuilder(Lists.newArrayList(streamSink)); + + Plan plan = planBuilder.buildPlan(); + return plan; + } +} diff --git a/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/streamingqueue/StreamingQueueTest.java b/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/streamingqueue/StreamingQueueTest.java new file mode 100644 index 000000000000..0a65b1abc387 --- /dev/null +++ b/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/streamingqueue/StreamingQueueTest.java @@ -0,0 +1,234 @@ +package org.ray.streaming.runtime.streamingqueue; + +import com.google.common.collect.ImmutableMap; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; +import java.io.Serializable; +import 
java.lang.management.ManagementFactory; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import org.ray.api.Ray; +import org.ray.api.RayActor; +import org.ray.api.options.ActorCreationOptions; +import org.ray.api.options.ActorCreationOptions.Builder; +import org.ray.streaming.api.context.StreamingContext; +import org.ray.streaming.api.function.impl.FlatMapFunction; +import org.ray.streaming.api.function.impl.ReduceFunction; +import org.ray.streaming.api.stream.StreamSource; +import org.ray.streaming.runtime.transfer.ChannelID; +import org.ray.streaming.runtime.util.EnvUtil; +import org.ray.streaming.util.Config; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testng.Assert; +import org.testng.annotations.AfterMethod; +import org.testng.annotations.BeforeClass; +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +public class StreamingQueueTest implements Serializable { + private static Logger LOGGER = LoggerFactory.getLogger(StreamingQueueTest.class); + + static { + EnvUtil.loadNativeLibraries(); + } + + @org.testng.annotations.BeforeSuite + public void suiteSetUp() throws Exception { + LOGGER.info("Do set up"); + String management = ManagementFactory.getRuntimeMXBean().getName(); + String pid = management.split("@")[0]; + + LOGGER.info("StreamingQueueTest pid: {}", pid); + LOGGER.info("java.library.path = {}", System.getProperty("java.library.path")); + } + + @org.testng.annotations.AfterSuite + public void suiteTearDown() throws Exception { + LOGGER.warn("Do tear down"); + } + + @BeforeClass + public void setUp() { + } + + @BeforeMethod + void beforeMethod() { + + LOGGER.info("beforeTest"); + Ray.shutdown(); + System.setProperty("ray.resources", "CPU:4,RES-A:4"); + System.setProperty("ray.raylet.config.num_workers_per_process_java", "1"); + System.setProperty("ray.run-mode", "CLUSTER"); + 
System.setProperty("ray.redirect-output", "true"); + // ray init + Ray.init(); + } + + @AfterMethod + void afterMethod() { + LOGGER.info("afterTest"); + Ray.shutdown(); + System.clearProperty("ray.run-mode"); + } + + @Test(timeOut = 3000000) + public void testReaderWriter() { + LOGGER.info("StreamingQueueTest.testReaderWriter run-mode: {}", + System.getProperty("ray.run-mode")); + Ray.shutdown(); + System.setProperty("ray.resources", "CPU:4,RES-A:4"); + System.setProperty("ray.raylet.config.num_workers_per_process_java", "1"); + + System.setProperty("ray.run-mode", "CLUSTER"); + System.setProperty("ray.redirect-output", "true"); + // ray init + Ray.init(); + + ActorCreationOptions.Builder builder = new Builder(); + + RayActor writerActor = Ray.createActor(WriterWorker::new, "writer", + builder.createActorCreationOptions()); + RayActor readerActor = Ray.createActor(ReaderWorker::new, "reader", + builder.createActorCreationOptions()); + + LOGGER.info("call getName on writerActor: {}", + Ray.call(WriterWorker::getName, writerActor).get()); + LOGGER.info("call getName on readerActor: {}", + Ray.call(ReaderWorker::getName, readerActor).get()); + + // LOGGER.info(Ray.call(WriterWorker::testCallReader, writerActor, readerActor).get()); + List outputQueueList = new ArrayList<>(); + List inputQueueList = new ArrayList<>(); + int queueNum = 2; + for (int i = 0; i < queueNum; ++i) { + String qid = ChannelID.genRandomIdStr(); + LOGGER.info("getRandomQueueId: {}", qid); + inputQueueList.add(qid); + outputQueueList.add(qid); + readerActor.getId(); + } + + final int msgCount = 100; + Ray.call(ReaderWorker::init, readerActor, inputQueueList, writerActor, msgCount); + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + e.printStackTrace(); + } + Ray.call(WriterWorker::init, writerActor, outputQueueList, readerActor, msgCount); + + long time = 0; + while (time < 20000 && + Ray.call(ReaderWorker::getTotalMsg, readerActor).get() < msgCount * queueNum) { + try { + 
Thread.sleep(1000); + time += 1000; + } catch (InterruptedException e) { + e.printStackTrace(); + } + } + + Assert.assertEquals( + Ray.call(ReaderWorker::getTotalMsg, readerActor).get().intValue(), + msgCount * queueNum); + } + + @Test(timeOut = 60000) + public void testWordCount() { + LOGGER.info("StreamingQueueTest.testWordCount run-mode: {}", + System.getProperty("ray.run-mode")); + String resultFile = "/tmp/org.ray.streaming.runtime.streamingqueue.testWordCount.txt"; + deleteResultFile(resultFile); + + Map wordCount = new ConcurrentHashMap<>(); + StreamingContext streamingContext = StreamingContext.buildContext(); + Map config = new HashMap<>(); + config.put(Config.STREAMING_BATCH_MAX_COUNT, 1); + config.put(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL); + config.put(Config.CHANNEL_SIZE, "100000"); + streamingContext.withConfig(config); + List text = new ArrayList<>(); + text.add("hello world eagle eagle eagle"); + StreamSource streamSource = StreamSource.buildSource(streamingContext, text); + streamSource + .flatMap((FlatMapFunction) (value, collector) -> { + String[] records = value.split(" "); + for (String record : records) { + collector.collect(new WordAndCount(record, 1)); + } + }) + .keyBy(pair -> pair.word) + .reduce((ReduceFunction) (oldValue, newValue) -> { + LOGGER.info("reduce: {} {}", oldValue, newValue); + return new WordAndCount(oldValue.word, oldValue.count + newValue.count); + }) + .sink(s -> { + LOGGER.info("sink {} {}", s.word, s.count); + wordCount.put(s.word, s.count); + serializeResultToFile(resultFile, wordCount); + }); + + streamingContext.execute(); + + Map checkWordCount = + (Map) deserializeResultFromFile(resultFile); + // Sleep until the count for every word is computed. 
+    while (checkWordCount == null || checkWordCount.size() < 3) {
+      LOGGER.info("sleep");
+      try {
+        Thread.sleep(1000);
+      } catch (InterruptedException e) {
+        LOGGER.warn("Got an exception while sleeping.", e);
+      }
+      checkWordCount = (Map) deserializeResultFromFile(resultFile);
+    }
+    LOGGER.info("check");
+    Assert.assertEquals(checkWordCount,
+        ImmutableMap.of("eagle", 3, "hello", 1, "world", 1));
+  }
+
+  // Best-effort write: a failure is only logged, and the stream is always closed.
+  private void serializeResultToFile(String fileName, Object obj) {
+    try (ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(fileName))) {
+      out.writeObject(obj);
+    } catch (Exception e) {
+      LOGGER.error(String.valueOf(e));
+    }
+  }
+
+  // Returns the stored map, or null if unreadable; never asserts because the caller polls.
+  private Object deserializeResultFromFile(String fileName) {
+    Map checkWordCount = null;
+    try (ObjectInputStream in = new ObjectInputStream(new FileInputStream(fileName))) {
+      checkWordCount = (Map) in.readObject();
+    } catch (Exception e) {
+      LOGGER.error(String.valueOf(e));
+    }
+    return checkWordCount;
+  }
+
+  private static class WordAndCount implements Serializable {
+
+    public final String word;
+    public final Integer count;
+
+    public WordAndCount(String key, Integer count) {
+      this.word = key;
+      this.count = count;
+    }
+  }
+
+  // Delete now so stale results from an earlier run cannot pass the check; also on exit.
+  private void deleteResultFile(String path) {
+    File file = new File(path);
+    file.delete();
+    file.deleteOnExit();
+  }
+}
diff --git a/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/streamingqueue/Worker.java b/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/streamingqueue/Worker.java
new file mode 100644
index 000000000000..2c105576df73
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/streamingqueue/Worker.java
@@ -0,0 +1,280 @@
+package org.ray.streaming.runtime.streamingqueue;
+
+import java.lang.management.ManagementFactory;
+import java.nio.ByteBuffer;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.List; +import java.util.Map; +import java.util.Random; +import org.ray.api.Ray; +import org.ray.api.RayActor; +import org.ray.api.annotation.RayRemote; +import org.ray.api.id.ActorId; +import org.ray.runtime.RayMultiWorkerNativeRuntime; +import org.ray.runtime.actor.NativeRayActor; +import org.ray.runtime.functionmanager.JavaFunctionDescriptor; +import org.ray.streaming.runtime.transfer.ChannelID; +import org.ray.streaming.runtime.transfer.DataMessage; +import org.ray.streaming.runtime.transfer.DataReader; +import org.ray.streaming.runtime.transfer.DataWriter; +import org.ray.streaming.runtime.transfer.TransferHandler; +import org.ray.streaming.util.Config; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import org.testng.Assert; + +public class Worker { + private static final Logger LOGGER = LoggerFactory.getLogger(Worker.class); + + protected TransferHandler transferHandler = null; + + public Worker() { + transferHandler = new TransferHandler(((RayMultiWorkerNativeRuntime) Ray.internal()) + .getCurrentRuntime().getNativeCoreWorkerPointer(), + new JavaFunctionDescriptor(Worker.class.getName(), + "onWriterMessage", "([B)V"), + new JavaFunctionDescriptor(Worker.class.getName(), + "onWriterMessageSync", "([B)[B"), + new JavaFunctionDescriptor(Worker.class.getName(), + "onReaderMessage", "([B)V"), + new JavaFunctionDescriptor(Worker.class.getName(), + "onReaderMessageSync", "([B)[B")); + } + + public void onReaderMessage(byte[] buffer) { + transferHandler.onReaderMessage(buffer); + } + + public byte[] onReaderMessageSync(byte[] buffer) { + return transferHandler.onReaderMessageSync(buffer); + } + + public void onWriterMessage(byte[] buffer) { + transferHandler.onWriterMessage(buffer); + } + + public byte[] onWriterMessageSync(byte[] buffer) { + return transferHandler.onWriterMessageSync(buffer); + } +} + +@RayRemote +class ReaderWorker extends Worker { + private static final Logger LOGGER = LoggerFactory.getLogger(ReaderWorker.class); + + 
private String name = null; + private List inputQueueList = null; + private List inputActorIds = new ArrayList<>(); + private DataReader dataReader = null; + private long handler = 0; + private RayActor peerActor = null; + private int msgCount = 0; + private int totalMsg = 0; + + public ReaderWorker(String name) { + LOGGER.info("ReaderWorker constructor"); + this.name = name; + } + + public String getName() { + String management = ManagementFactory.getRuntimeMXBean().getName(); + String pid = management.split("@")[0]; + + LOGGER.info("pid: {} name: {}", pid, name); + return name; + } + + public String testRayCall() { + LOGGER.info("testRayCall called"); + return "testRayCall"; + } + + public boolean init(List inputQueueList, RayActor peer, int msgCount) { + + this.inputQueueList = inputQueueList; + this.peerActor = peer; + this.msgCount = msgCount; + + LOGGER.info("ReaderWorker init"); + LOGGER.info("java.library.path = {}", System.getProperty("java.library.path")); + + for (String queue : this.inputQueueList) { + inputActorIds.add(this.peerActor.getId()); + LOGGER.info("ReaderWorker actorId: {}", this.peerActor.getId()); + } + + Map conf = new HashMap<>(); + + conf.put(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL); + conf.put(Config.CHANNEL_SIZE, "100000"); + conf.put(Config.STREAMING_JOB_NAME, "integrationTest1"); + dataReader = new DataReader(inputQueueList, inputActorIds, conf); + + // Should not GetBundle in RayCall thread + Thread readThread = new Thread(Ray.wrapRunnable(new Runnable() { + @Override + public void run() { + consume(); + } + })); + readThread.start(); + + LOGGER.info("ReaderWorker init done"); + + return true; + } + + public final void consume() { + + int checkPointId = 1; + for (int i = 0; i < msgCount * inputQueueList.size(); ++i) { + DataMessage dataMessage = dataReader.read(100); + + if (dataMessage == null) { + LOGGER.error("dataMessage is null"); + i--; + continue; + } + + int bufferSize = dataMessage.body().remaining(); + int dataSize = 
dataMessage.body().getInt(); + + // check size + LOGGER.info("capacity {} bufferSize {} dataSize {}", + dataMessage.body().capacity(), bufferSize, dataSize); + Assert.assertEquals(bufferSize, dataSize); + if (dataMessage instanceof DataMessage) { + if (LOGGER.isInfoEnabled()) { + LOGGER.info("{} : {} message.", i, dataMessage.toString()); + } + // check content + for (int j = 0; j < dataSize - 4; ++j) { + Assert.assertEquals(dataMessage.body().get(), (byte) j); + } + } else { + LOGGER.error("unknown message type"); + Assert.fail(); + } + + totalMsg++; + } + + LOGGER.info("ReaderWorker consume data done."); + } + + void onQueueTransfer(long handler, byte[] buffer) { + } + + + public boolean done() { + return totalMsg == msgCount; + } + + public int getTotalMsg() { + return totalMsg; + } +} + +@RayRemote +class WriterWorker extends Worker { + private static final Logger LOGGER = LoggerFactory.getLogger(WriterWorker.class); + + private String name = null; + private List outputQueueList = null; + private List outputActorIds = new ArrayList<>(); + DataWriter dataWriter = null; + RayActor peerActor = null; + int msgCount = 0; + + public WriterWorker(String name) { + this.name = name; + } + + public String getName() { + String management = ManagementFactory.getRuntimeMXBean().getName(); + String pid = management.split("@")[0]; + + LOGGER.info("pid: {} name: {}", pid, name); + return name; + } + + public String testCallReader(RayActor readerActor) { + String name = (String) Ray.call(ReaderWorker::getName, readerActor).get(); + LOGGER.info("testCallReader: {}", name); + return name; + } + + public boolean init(List outputQueueList, RayActor peer, int msgCount) { + + this.outputQueueList = outputQueueList; + this.peerActor = peer; + this.msgCount = msgCount; + + LOGGER.info("WriterWorker init:"); + + for (String queue : this.outputQueueList) { + outputActorIds.add(this.peerActor.getId()); + LOGGER.info("WriterWorker actorId: {}", this.peerActor.getId()); + } + + 
LOGGER.info("Peer isDirectActorCall: {}", ((NativeRayActor) peer).isDirectCallActor()); + int count = 3; + while (count-- != 0) { + Ray.call(ReaderWorker::testRayCall, peer).get(); + } + + try { + Thread.sleep(2 * 1000); + } catch (InterruptedException e) { + e.printStackTrace(); + } + Map conf = new HashMap<>(); + + conf.put(Config.CHANNEL_TYPE, Config.NATIVE_CHANNEL); + conf.put(Config.CHANNEL_SIZE, "100000"); + conf.put(Config.STREAMING_JOB_NAME, "integrationTest1"); + + dataWriter = new DataWriter(this.outputQueueList, this.outputActorIds, conf); + Thread writerThread = new Thread(Ray.wrapRunnable(new Runnable() { + @Override + public void run() { + produce(); + } + })); + writerThread.start(); + + LOGGER.info("WriterWorker init done"); + return true; + } + + public final void produce() { + + int checkPointId = 1; + Random random = new Random(); + this.msgCount = 100; + for (int i = 0; i < this.msgCount; ++i) { + for (int j = 0; j < outputQueueList.size(); ++j) { + LOGGER.info("WriterWorker produce"); + int dataSize = (random.nextInt(100)) + 10; + if (LOGGER.isInfoEnabled()) { + LOGGER.info("dataSize: {}", dataSize); + } + ByteBuffer bb = ByteBuffer.allocate(dataSize); + bb.putInt(dataSize); + for (int k = 0; k < dataSize - 4; ++k) { + bb.put((byte) k); + } + + bb.clear(); + ChannelID qid = ChannelID.from(outputQueueList.get(j)); + dataWriter.write(qid, bb); + } + } + try { + Thread.sleep(20 * 1000); + } catch (InterruptedException e) { + e.printStackTrace(); + } + } +} \ No newline at end of file diff --git a/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/transfer/ChannelIDTest.java b/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/transfer/ChannelIDTest.java new file mode 100644 index 000000000000..654447f95774 --- /dev/null +++ b/streaming/java/streaming-runtime/src/test/java/org/ray/streaming/runtime/transfer/ChannelIDTest.java @@ -0,0 +1,22 @@ +package org.ray.streaming.runtime.transfer; + +import static 
org.testng.Assert.assertEquals;
+
+
+import org.ray.streaming.runtime.util.EnvUtil;
+import org.testng.annotations.Test;
+
+public class ChannelIDTest {
+
+  static {
+    EnvUtil.loadNativeLibraries();
+  }
+
+  @Test
+  public void testIdStrToBytes() {
+    String idStr = ChannelID.genRandomIdStr();
+    assertEquals(idStr.length(), ChannelID.ID_LENGTH * 2);
+    assertEquals(ChannelID.idStrToBytes(idStr).length, ChannelID.ID_LENGTH);
+  }
+
+}
\ No newline at end of file
diff --git a/streaming/java/streaming-runtime/src/test/resources/log4j.properties b/streaming/java/streaming-runtime/src/test/resources/log4j.properties
new file mode 100644
index 000000000000..30d876aec121
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/test/resources/log4j.properties
@@ -0,0 +1,6 @@
+log4j.rootLogger=INFO, stdout
+# Direct log messages to stdout
+log4j.appender.stdout=org.apache.log4j.ConsoleAppender
+log4j.appender.stdout.Target=System.out
+log4j.appender.stdout.layout=org.apache.log4j.PatternLayout
+log4j.appender.stdout.layout.ConversionPattern=%d{yyyy-MM-dd HH:mm:ss} %-5p %c{1}:%L - %m%n
diff --git a/streaming/java/streaming-runtime/src/test/resources/ray.conf b/streaming/java/streaming-runtime/src/test/resources/ray.conf
new file mode 100644
index 000000000000..fdc897fa624e
--- /dev/null
+++ b/streaming/java/streaming-runtime/src/test/resources/ray.conf
@@ -0,0 +1,3 @@
+ray {
+  run-mode = SINGLE_PROCESS
+}
diff --git a/streaming/java/test.sh b/streaming/java/test.sh
new file mode 100755
index 000000000000..c58f28f88aee
--- /dev/null
+++ b/streaming/java/test.sh
@@ -0,0 +1,41 @@
+#!/usr/bin/env bash
+
+# Cause the script to exit if a single command fails.
+set -e
+# Show explicitly which commands are currently running.
+set -x
+
+ROOT_DIR=$(cd "$(dirname "${BASH_SOURCE:-$0}")"; pwd)
+
+run_testng() {
+    # Start from 0 so a passing command does not leave exit_code unset or stale.
+    local exit_code=0
+    "$@" || exit_code=$?
+    # exit_code == 2 means there are skipped tests.
+ if [ $exit_code -ne 2 ] && [ $exit_code -ne 0 ] ; then + exit $exit_code + fi +} + +echo "build ray streaming" +bazel build //streaming/java:all + +echo "Linting Java code with checkstyle." +bazel test //streaming/java:all --test_tag_filters="checkstyle" --build_tests_only + +echo "Running streaming tests." +run_testng java -cp $ROOT_DIR/../../bazel-bin/streaming/java/all_streaming_tests_deploy.jar\ + org.testng.TestNG -d /tmp/ray_streaming_java_test_output $ROOT_DIR/testng.xml + +echo "Streaming TestNG results" +cat /tmp/ray_streaming_java_test_output/testng-results.xml + +echo "Testing maven install." +cd $ROOT_DIR/../../java +echo "build ray maven deps" +bazel build gen_maven_deps +echo "maven install ray" +mvn clean install -DskipTests +cd $ROOT_DIR +echo "maven install ray streaming" +mvn clean install -DskipTests diff --git a/java/streaming/testng.xml b/streaming/java/testng.xml similarity index 100% rename from java/streaming/testng.xml rename to streaming/java/testng.xml diff --git a/streaming/src/channel.cc b/streaming/src/channel.cc index de7c99f8e0da..1edd99c1b433 100644 --- a/streaming/src/channel.cc +++ b/streaming/src/channel.cc @@ -205,15 +205,23 @@ struct MockQueueItem { std::shared_ptr data; }; -struct MockQueue { +class MockQueue { + public: std::unordered_map>> message_buffer_; std::unordered_map>> consumed_buffer_; + static std::mutex mutex; + static MockQueue &GetMockQueue() { + static MockQueue mock_queue; + return mock_queue; + } }; -static MockQueue mock_queue; +std::mutex MockQueue::mutex; StreamingStatus MockProducer::CreateTransferChannel() { + std::unique_lock lock(MockQueue::mutex); + MockQueue &mock_queue = MockQueue::GetMockQueue(); mock_queue.message_buffer_[channel_info.channel_id] = std::make_shared>(500); mock_queue.consumed_buffer_[channel_info.channel_id] = @@ -222,12 +230,16 @@ StreamingStatus MockProducer::CreateTransferChannel() { } StreamingStatus MockProducer::DestroyTransferChannel() { + std::unique_lock 
lock(MockQueue::mutex); + MockQueue &mock_queue = MockQueue::GetMockQueue(); mock_queue.message_buffer_.erase(channel_info.channel_id); mock_queue.consumed_buffer_.erase(channel_info.channel_id); return StreamingStatus::OK; } StreamingStatus MockProducer::ProduceItemToChannel(uint8_t *data, uint32_t data_size) { + std::unique_lock lock(MockQueue::mutex); + MockQueue &mock_queue = MockQueue::GetMockQueue(); auto &ring_buffer = mock_queue.message_buffer_[channel_info.channel_id]; if (ring_buffer->Full()) { return StreamingStatus::OutOfMemory; @@ -244,6 +256,8 @@ StreamingStatus MockProducer::ProduceItemToChannel(uint8_t *data, uint32_t data_ StreamingStatus MockConsumer::ConsumeItemFromChannel(uint64_t &offset_id, uint8_t *&data, uint32_t &data_size, uint32_t timeout) { + std::unique_lock lock(MockQueue::mutex); + MockQueue &mock_queue = MockQueue::GetMockQueue(); auto &channel_id = channel_info.channel_id; if (mock_queue.message_buffer_.find(channel_id) == mock_queue.message_buffer_.end()) { return StreamingStatus::NoSuchItem; @@ -262,6 +276,8 @@ StreamingStatus MockConsumer::ConsumeItemFromChannel(uint64_t &offset_id, uint8_ } StreamingStatus MockConsumer::NotifyChannelConsumed(uint64_t offset_id) { + std::unique_lock lock(MockQueue::mutex); + MockQueue &mock_queue = MockQueue::GetMockQueue(); auto &channel_id = channel_info.channel_id; auto &ring_buffer = mock_queue.consumed_buffer_[channel_id]; while (!ring_buffer->Empty() && ring_buffer->Front().seq_id <= offset_id) { diff --git a/streaming/src/lib/java/org_ray_streaming_runtime_transfer_ChannelID.cc b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_ChannelID.cc new file mode 100644 index 000000000000..364d0af9f861 --- /dev/null +++ b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_ChannelID.cc @@ -0,0 +1,17 @@ +#include "org_ray_streaming_runtime_transfer_ChannelID.h" +#include "streaming_jni_common.h" +using namespace ray::streaming; + +JNIEXPORT jlong JNICALL 
Java_org_ray_streaming_runtime_transfer_ChannelID_createNativeID( + JNIEnv *env, jclass cls, jlong qid_address) { + auto id = ray::ObjectID::FromBinary( + std::string(reinterpret_cast(qid_address), ray::ObjectID::Size())); + return reinterpret_cast(new ray::ObjectID(id)); +} + +JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_ChannelID_destroyNativeID( + JNIEnv *env, jclass cls, jlong native_id_ptr) { + auto id = reinterpret_cast(native_id_ptr); + STREAMING_CHECK(id != nullptr); + delete id; +} diff --git a/streaming/src/lib/java/org_ray_streaming_runtime_transfer_ChannelID.h b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_ChannelID.h new file mode 100644 index 000000000000..c6353e9b1778 --- /dev/null +++ b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_ChannelID.h @@ -0,0 +1,31 @@ +/* DO NOT EDIT THIS FILE - it is machine generated */ +#include +/* Header for class org_ray_streaming_runtime_transfer_ChannelID */ + +#ifndef _Included_org_ray_streaming_runtime_transfer_ChannelID +#define _Included_org_ray_streaming_runtime_transfer_ChannelID +#ifdef __cplusplus +extern "C" { +#endif +#undef org_ray_streaming_runtime_transfer_ChannelID_ID_LENGTH +#define org_ray_streaming_runtime_transfer_ChannelID_ID_LENGTH 20L +/* + * Class: org_ray_streaming_runtime_transfer_ChannelID + * Method: createNativeID + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_org_ray_streaming_runtime_transfer_ChannelID_createNativeID + (JNIEnv *, jclass, jlong); + +/* + * Class: org_ray_streaming_runtime_transfer_ChannelID + * Method: destroyNativeID + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_ChannelID_destroyNativeID + (JNIEnv *, jclass, jlong); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/streaming/src/lib/java/org_ray_streaming_runtime_transfer_DataReader.cc b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_DataReader.cc new file mode 100644 index 000000000000..651ef6b32e54 --- /dev/null 
+++ b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_DataReader.cc @@ -0,0 +1,88 @@ +#include "org_ray_streaming_runtime_transfer_DataReader.h" +#include +#include "data_reader.h" +#include "runtime_context.h" +#include "streaming_jni_common.h" + +using namespace ray; +using namespace ray::streaming; + +JNIEXPORT jlong JNICALL +Java_org_ray_streaming_runtime_transfer_DataReader_createDataReaderNative( + JNIEnv *env, jclass, jobjectArray input_channels, jobjectArray input_actor_ids, + jlongArray seq_id_array, jlongArray msg_id_array, jlong timer_interval, + jboolean isRecreate, jbyteArray config_bytes, jboolean is_mock) { + STREAMING_LOG(INFO) << "[JNI]: create DataReader."; + std::vector input_channels_ids = + jarray_to_object_id_vec(env, input_channels); + std::vector actor_ids = jarray_to_actor_id_vec(env, input_actor_ids); + std::vector seq_ids = LongVectorFromJLongArray(env, seq_id_array).data; + std::vector msg_ids = LongVectorFromJLongArray(env, msg_id_array).data; + + auto ctx = std::make_shared(); + RawDataFromJByteArray conf(env, config_bytes); + if (conf.data_size > 0) { + STREAMING_LOG(INFO) << "load config, config bytes size: " << conf.data_size; + ctx->SetConfig(conf.data, conf.data_size); + } + if (is_mock) { + ctx->MarkMockTest(); + } + auto reader = new DataReader(ctx); + reader->Init(input_channels_ids, actor_ids, seq_ids, msg_ids, timer_interval); + STREAMING_LOG(INFO) << "create native DataReader succeed"; + return reinterpret_cast(reader); +} + +JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_DataReader_getBundleNative( + JNIEnv *env, jobject, jlong reader_ptr, jlong timeout_millis, jlong out, + jlong meta_addr) { + std::shared_ptr bundle; + auto reader = reinterpret_cast(reader_ptr); + auto status = reader->GetBundle((uint32_t)timeout_millis, bundle); + + // over timeout, return empty array. 
+ if (StreamingStatus::Interrupted == status) { + throwChannelInterruptException(env, "reader interrupted."); + } else if (StreamingStatus::GetBundleTimeOut == status) { + } else if (StreamingStatus::InitQueueFailed == status) { + throwRuntimeException(env, "init channel failed"); + } else if (StreamingStatus::WaitQueueTimeOut == status) { + throwRuntimeException(env, "wait channel object timeout"); + } + + if (StreamingStatus::OK != status) { + *reinterpret_cast(out) = 0; + *reinterpret_cast(out + 8) = 0; + return; + } + + // bundle data + // In streaming queue, bundle data and metadata will be different args of direct call, + // so we separate it here for future extensibility. + *reinterpret_cast(out) = + reinterpret_cast(bundle->data + kMessageBundleHeaderSize); + *reinterpret_cast(out + 8) = bundle->data_size - kMessageBundleHeaderSize; + + // bundle metadata + auto meta = reinterpret_cast(meta_addr); + // bundle header written by writer + std::memcpy(meta, bundle->data, kMessageBundleHeaderSize); + // append qid + std::memcpy(meta + kMessageBundleHeaderSize, bundle->from.Data(), kUniqueIDSize); +} + +JNIEXPORT void JNICALL +Java_org_ray_streaming_runtime_transfer_DataReader_stopReaderNative(JNIEnv *env, + jobject thisObj, + jlong ptr) { + auto reader = reinterpret_cast(ptr); + reader->Stop(); +} + +JNIEXPORT void JNICALL +Java_org_ray_streaming_runtime_transfer_DataReader_closeReaderNative(JNIEnv *env, + jobject thisObj, + jlong ptr) { + delete reinterpret_cast(ptr); +} diff --git a/streaming/src/lib/java/org_ray_streaming_runtime_transfer_DataReader.h b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_DataReader.h new file mode 100644 index 000000000000..f9f266a3d018 --- /dev/null +++ b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_DataReader.h @@ -0,0 +1,45 @@ +/* DO NOT EDIT THIS FILE - it is machine generated */ +#include +/* Header for class org_ray_streaming_runtime_transfer_DataReader */ + +#ifndef 
_Included_org_ray_streaming_runtime_transfer_DataReader +#define _Included_org_ray_streaming_runtime_transfer_DataReader +#ifdef __cplusplus +extern "C" { +#endif +/* + * Class: org_ray_streaming_runtime_transfer_DataReader + * Method: createDataReaderNative + * Signature: ([[B[[B[J[JJZ[BZ)J + */ +JNIEXPORT jlong JNICALL Java_org_ray_streaming_runtime_transfer_DataReader_createDataReaderNative + (JNIEnv *, jclass, jobjectArray, jobjectArray, jlongArray, jlongArray, jlong, jboolean, jbyteArray, jboolean); + +/* + * Class: org_ray_streaming_runtime_transfer_DataReader + * Method: getBundleNative + * Signature: (JJJJ)V + */ +JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_DataReader_getBundleNative + (JNIEnv *, jobject, jlong, jlong, jlong, jlong); + +/* + * Class: org_ray_streaming_runtime_transfer_DataReader + * Method: stopReaderNative + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_DataReader_stopReaderNative + (JNIEnv *, jobject, jlong); + +/* + * Class: org_ray_streaming_runtime_transfer_DataReader + * Method: closeReaderNative + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_DataReader_closeReaderNative + (JNIEnv *, jobject, jlong); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/streaming/src/lib/java/org_ray_streaming_runtime_transfer_DataWriter.cc b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_DataWriter.cc new file mode 100644 index 000000000000..439cd89a9974 --- /dev/null +++ b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_DataWriter.cc @@ -0,0 +1,82 @@ +#include "org_ray_streaming_runtime_transfer_DataWriter.h" +#include "config/streaming_config.h" +#include "data_writer.h" +#include "streaming_jni_common.h" + +using namespace ray::streaming; + +JNIEXPORT jlong JNICALL +Java_org_ray_streaming_runtime_transfer_DataWriter_createWriterNative( + JNIEnv *env, jclass, jobjectArray output_queue_ids, jobjectArray output_actor_ids, + 
jlongArray msg_ids, jlong channel_size, jbyteArray conf_bytes_array, + jboolean is_mock) { + STREAMING_LOG(INFO) << "[JNI]: createDataWriterNative."; + std::vector queue_id_vec = + jarray_to_object_id_vec(env, output_queue_ids); + for (auto id : queue_id_vec) { + STREAMING_LOG(INFO) << "output channel id: " << id.Hex(); + } + STREAMING_LOG(INFO) << "total channel size: " << channel_size << "*" + << queue_id_vec.size() << "=" << queue_id_vec.size() * channel_size; + LongVectorFromJLongArray long_array_obj(env, msg_ids); + std::vector msg_ids_vec = LongVectorFromJLongArray(env, msg_ids).data; + std::vector queue_size_vec(long_array_obj.data.size(), channel_size); + std::vector remain_id_vec; + std::vector actor_ids = jarray_to_actor_id_vec(env, output_actor_ids); + + STREAMING_LOG(INFO) << "actor_ids: " << actor_ids[0]; + + RawDataFromJByteArray conf(env, conf_bytes_array); + STREAMING_CHECK(conf.data != nullptr); + auto runtime_context = std::make_shared(); + if (conf.data_size > 0) { + runtime_context->SetConfig(conf.data, conf.data_size); + } + if (is_mock) { + runtime_context->MarkMockTest(); + } + auto *data_writer = new DataWriter(runtime_context); + auto status = data_writer->Init(queue_id_vec, actor_ids, msg_ids_vec, queue_size_vec); + if (status != StreamingStatus::OK) { + STREAMING_LOG(WARNING) << "DataWriter init failed."; + } else { + STREAMING_LOG(INFO) << "DataWriter init success"; + } + + data_writer->Run(); + return reinterpret_cast(data_writer); +} + +JNIEXPORT jlong JNICALL +Java_org_ray_streaming_runtime_transfer_DataWriter_writeMessageNative( + JNIEnv *env, jobject, jlong writer_ptr, jlong qid_ptr, jlong address, jint size) { + auto *data_writer = reinterpret_cast(writer_ptr); + auto qid = *reinterpret_cast(qid_ptr); + auto data = reinterpret_cast(address); + auto data_size = static_cast(size); + jlong result = data_writer->WriteMessageToBufferRing(qid, data, data_size, + StreamingMessageType::Message); + + if (result == 0) { + STREAMING_LOG(INFO) 
<< "writer interrupted, return 0."; + throwChannelInterruptException(env, "writer interrupted."); + } + return result; +} + +JNIEXPORT void JNICALL +Java_org_ray_streaming_runtime_transfer_DataWriter_stopWriterNative(JNIEnv *env, + jobject thisObj, + jlong ptr) { + STREAMING_LOG(INFO) << "jni: stop writer."; + auto *data_writer = reinterpret_cast(ptr); + data_writer->Stop(); +} + +JNIEXPORT void JNICALL +Java_org_ray_streaming_runtime_transfer_DataWriter_closeWriterNative(JNIEnv *env, + jobject thisObj, + jlong ptr) { + auto *data_writer = reinterpret_cast(ptr); + delete data_writer; +} \ No newline at end of file diff --git a/streaming/src/lib/java/org_ray_streaming_runtime_transfer_DataWriter.h b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_DataWriter.h new file mode 100644 index 000000000000..a6fdf533c4d5 --- /dev/null +++ b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_DataWriter.h @@ -0,0 +1,45 @@ +/* DO NOT EDIT THIS FILE - it is machine generated */ +#include +/* Header for class org_ray_streaming_runtime_transfer_DataWriter */ + +#ifndef _Included_org_ray_streaming_runtime_transfer_DataWriter +#define _Included_org_ray_streaming_runtime_transfer_DataWriter +#ifdef __cplusplus +extern "C" { +#endif +/* + * Class: org_ray_streaming_runtime_transfer_DataWriter + * Method: createWriterNative + * Signature: ([[B[[B[JJ[BZ)J + */ +JNIEXPORT jlong JNICALL Java_org_ray_streaming_runtime_transfer_DataWriter_createWriterNative + (JNIEnv *, jclass, jobjectArray, jobjectArray, jlongArray, jlong, jbyteArray, jboolean); + +/* + * Class: org_ray_streaming_runtime_transfer_DataWriter + * Method: writeMessageNative + * Signature: (JJJI)J + */ +JNIEXPORT jlong JNICALL Java_org_ray_streaming_runtime_transfer_DataWriter_writeMessageNative + (JNIEnv *, jobject, jlong, jlong, jlong, jint); + +/* + * Class: org_ray_streaming_runtime_transfer_DataWriter + * Method: stopWriterNative + * Signature: (J)V + */ +JNIEXPORT void JNICALL 
Java_org_ray_streaming_runtime_transfer_DataWriter_stopWriterNative + (JNIEnv *, jobject, jlong); + +/* + * Class: org_ray_streaming_runtime_transfer_DataWriter + * Method: closeWriterNative + * Signature: (J)V + */ +JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_DataWriter_closeWriterNative + (JNIEnv *, jobject, jlong); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/streaming/src/lib/java/org_ray_streaming_runtime_transfer_TransferHandler.cc b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_TransferHandler.cc new file mode 100644 index 000000000000..43f5e4a087f4 --- /dev/null +++ b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_TransferHandler.cc @@ -0,0 +1,75 @@ +#include "org_ray_streaming_runtime_transfer_TransferHandler.h" +#include "queue/queue_client.h" +#include "streaming_jni_common.h" + +using namespace ray::streaming; + +static std::shared_ptr JByteArrayToBuffer(JNIEnv *env, + jbyteArray bytes) { + RawDataFromJByteArray buf(env, bytes); + STREAMING_CHECK(buf.data != nullptr); + + return std::make_shared(buf.data, buf.data_size, true); +} + +JNIEXPORT jlong JNICALL +Java_org_ray_streaming_runtime_transfer_TransferHandler_createWriterClientNative( + JNIEnv *env, jobject this_obj, jlong core_worker_ptr, jobject async_func, + jobject sync_func) { + auto ray_async_func = FunctionDescriptorToRayFunction(env, async_func); + auto ray_sync_func = FunctionDescriptorToRayFunction(env, sync_func); + auto *writer_client = + new WriterClient(reinterpret_cast(core_worker_ptr), + ray_async_func, ray_sync_func); + return reinterpret_cast(writer_client); +} + +JNIEXPORT jlong JNICALL +Java_org_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative( + JNIEnv *env, jobject this_obj, jlong core_worker_ptr, jobject async_func, + jobject sync_func) { + ray::RayFunction ray_async_func = FunctionDescriptorToRayFunction(env, async_func); + ray::RayFunction ray_sync_func = FunctionDescriptorToRayFunction(env, sync_func); 
+ auto *reader_client = + new ReaderClient(reinterpret_cast(core_worker_ptr), + ray_async_func, ray_sync_func); + return reinterpret_cast(reader_client); +} + +JNIEXPORT void JNICALL +Java_org_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageNative( + JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) { + auto *writer_client = reinterpret_cast(ptr); + writer_client->OnWriterMessage(JByteArrayToBuffer(env, bytes)); +} + +JNIEXPORT jbyteArray JNICALL +Java_org_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageSyncNative( + JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) { + auto *writer_client = reinterpret_cast(ptr); + std::shared_ptr result_buffer = + writer_client->OnWriterMessageSync(JByteArrayToBuffer(env, bytes)); + jbyteArray arr = env->NewByteArray(result_buffer->Size()); + env->SetByteArrayRegion(arr, 0, result_buffer->Size(), + reinterpret_cast(result_buffer->Data())); + return arr; +} + +JNIEXPORT void JNICALL +Java_org_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageNative( + JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) { + auto *reader_client = reinterpret_cast(ptr); + reader_client->OnReaderMessage(JByteArrayToBuffer(env, bytes)); +} + +JNIEXPORT jbyteArray JNICALL +Java_org_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageSyncNative( + JNIEnv *env, jobject this_obj, jlong ptr, jbyteArray bytes) { + auto *reader_client = reinterpret_cast(ptr); + auto result_buffer = reader_client->OnReaderMessageSync(JByteArrayToBuffer(env, bytes)); + + jbyteArray arr = env->NewByteArray(result_buffer->Size()); + env->SetByteArrayRegion(arr, 0, result_buffer->Size(), + reinterpret_cast(result_buffer->Data())); + return arr; +} \ No newline at end of file diff --git a/streaming/src/lib/java/org_ray_streaming_runtime_transfer_TransferHandler.h b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_TransferHandler.h new file mode 100644 index 
000000000000..1cdc3e8abb78 --- /dev/null +++ b/streaming/src/lib/java/org_ray_streaming_runtime_transfer_TransferHandler.h @@ -0,0 +1,61 @@ +/* DO NOT EDIT THIS FILE - it is machine generated */ +#include +/* Header for class org_ray_streaming_runtime_transfer_TransferHandler */ + +#ifndef _Included_org_ray_streaming_runtime_transfer_TransferHandler +#define _Included_org_ray_streaming_runtime_transfer_TransferHandler +#ifdef __cplusplus +extern "C" { +#endif +/* + * Class: org_ray_streaming_runtime_transfer_TransferHandler + * Method: createWriterClientNative + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_createWriterClientNative + (JNIEnv *, jobject, jlong, jobject, jobject); + +/* + * Class: org_ray_streaming_runtime_transfer_TransferHandler + * Method: createReaderClientNative + * Signature: (J)J + */ +JNIEXPORT jlong JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_createReaderClientNative + (JNIEnv *, jobject, jlong, jobject, jobject); + +/* + * Class: org_ray_streaming_runtime_transfer_TransferHandler + * Method: handleWriterMessageNative + * Signature: (J[B)V + */ +JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageNative + (JNIEnv *, jobject, jlong, jbyteArray); + +/* + * Class: org_ray_streaming_runtime_transfer_TransferHandler + * Method: handleWriterMessageSyncNative + * Signature: (J[B)[B + */ +JNIEXPORT jbyteArray JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_handleWriterMessageSyncNative + (JNIEnv *, jobject, jlong, jbyteArray); + +/* + * Class: org_ray_streaming_runtime_transfer_TransferHandler + * Method: handleReaderMessageNative + * Signature: (J[B)V + */ +JNIEXPORT void JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageNative + (JNIEnv *, jobject, jlong, jbyteArray); + +/* + * Class: org_ray_streaming_runtime_transfer_TransferHandler + * Method: handleReaderMessageSyncNative + 
* Signature: (J[B)[B + */ +JNIEXPORT jbyteArray JNICALL Java_org_ray_streaming_runtime_transfer_TransferHandler_handleReaderMessageSyncNative + (JNIEnv *, jobject, jlong, jbyteArray); + +#ifdef __cplusplus +} +#endif +#endif diff --git a/streaming/src/lib/java/streaming_jni_common.cc b/streaming/src/lib/java/streaming_jni_common.cc new file mode 100644 index 000000000000..89dd7b75c0fd --- /dev/null +++ b/streaming/src/lib/java/streaming_jni_common.cc @@ -0,0 +1,123 @@ +#include "streaming_jni_common.h" + +std::vector +jarray_to_object_id_vec(JNIEnv *env, jobjectArray jarr) { + int stringCount = env->GetArrayLength(jarr); + std::vector object_id_vec; + for (int i = 0; i < stringCount; i++) { + auto jstr = (jbyteArray) (env->GetObjectArrayElement(jarr, i)); + UniqueIdFromJByteArray idFromJByteArray(env, jstr); + object_id_vec.push_back(idFromJByteArray.PID); + } + return object_id_vec; +} + +std::vector +jarray_to_actor_id_vec(JNIEnv *env, jobjectArray jarr) { + int count = env->GetArrayLength(jarr); + std::vector actor_id_vec; + for (int i = 0; i < count; i++) { + auto bytes = (jbyteArray)(env->GetObjectArrayElement(jarr, i)); + std::string id_str(ray::ActorID::Size(), 0); + env->GetByteArrayRegion(bytes, 0, ray::ActorID::Size(), + reinterpret_cast(&id_str.front())); + actor_id_vec.push_back(ActorID::FromBinary(id_str)); + } + + return actor_id_vec; +} + +jint throwRuntimeException(JNIEnv *env, const char *message) { + jclass exClass; + char className[] = "java/lang/RuntimeException"; + exClass = env->FindClass(className); + return env->ThrowNew(exClass, message); +} + +jint throwChannelInitException(JNIEnv *env, const char *message, + const std::vector &abnormal_queues) { + jclass array_list_class = env->FindClass("java/util/ArrayList"); + jmethodID array_list_constructor = env->GetMethodID(array_list_class, "", "()V"); + jmethodID array_list_add = env->GetMethodID(array_list_class, "add", "(Ljava/lang/Object;)Z"); + jobject array_list = 
env->NewObject(array_list_class, array_list_constructor); + + for (auto &q_id : abnormal_queues) { + jbyteArray jbyte_array = env->NewByteArray(kUniqueIDSize); + env->SetByteArrayRegion(jbyte_array, 0, kUniqueIDSize, const_cast(reinterpret_cast(q_id.Data()))); + env->CallBooleanMethod(array_list, array_list_add, jbyte_array); + } + + jclass ex_class = env->FindClass("org/ray/streaming/runtime/transfer/ChannelInitException"); + jmethodID ex_constructor = env->GetMethodID(ex_class, "", "(Ljava/lang/String;Ljava/util/List;)V"); + jstring message_jstr = env->NewStringUTF(message); + jobject ex_obj = env->NewObject(ex_class, ex_constructor, message_jstr, array_list); + env->DeleteLocalRef(message_jstr); + return env->Throw((jthrowable)ex_obj); +} + +jint throwChannelInterruptException(JNIEnv *env, const char *message) { + jclass ex_class = env->FindClass("org/ray/streaming/runtime/transfer/ChannelInterruptException"); + return env->ThrowNew(ex_class, message); +} + +jclass LoadClass(JNIEnv *env, const char *class_name) { + jclass tempLocalClassRef = env->FindClass(class_name); + jclass ret = (jclass)env->NewGlobalRef(tempLocalClassRef); + STREAMING_CHECK(ret) << "Can't load Java class " << class_name; + env->DeleteLocalRef(tempLocalClassRef); + return ret; +} + +template +void JavaListToNativeVector( + JNIEnv *env, jobject java_list, std::vector *native_vector, + std::function element_converter) { + jclass java_list_class = LoadClass(env, "java/util/List"); + jmethodID java_list_size = env->GetMethodID(java_list_class, "size", "()I"); + jmethodID java_list_get = env->GetMethodID(java_list_class, "get", "(I)Ljava/lang/Object;"); + int size = env->CallIntMethod(java_list, java_list_size); + native_vector->clear(); + for (int i = 0; i < size; i++) { + native_vector->emplace_back( + element_converter(env, env->CallObjectMethod(java_list, java_list_get, (jint)i))); + } +} + +/// Convert a Java String to C++ std::string. 
+std::string JavaStringToNativeString(JNIEnv *env, jstring jstr) { + const char *c_str = env->GetStringUTFChars(jstr, nullptr); + std::string result(c_str); + env->ReleaseStringUTFChars(static_cast(jstr), c_str); + return result; +} + +/// Convert a Java List to C++ std::vector. +void JavaStringListToNativeStringVector(JNIEnv *env, jobject java_list, + std::vector *native_vector) { + JavaListToNativeVector( + env, java_list, native_vector, [](JNIEnv *env, jobject jstr) { + return JavaStringToNativeString(env, static_cast(jstr)); + }); +} + +ray::RayFunction FunctionDescriptorToRayFunction(JNIEnv *env, jobject functionDescriptor) { + jclass java_language_class = LoadClass(env, "org/ray/runtime/generated/Common$Language"); + jclass java_function_descriptor_class = + LoadClass(env, "org/ray/runtime/functionmanager/FunctionDescriptor"); + jmethodID java_language_get_number = env->GetMethodID(java_language_class, "getNumber", "()I"); + jmethodID java_function_descriptor_get_language = + env->GetMethodID(java_function_descriptor_class, "getLanguage", + "()Lorg/ray/runtime/generated/Common$Language;"); + jobject java_language = + env->CallObjectMethod(functionDescriptor, java_function_descriptor_get_language); + int language = env->CallIntMethod(java_language, java_language_get_number); + std::vector function_descriptor; + jmethodID java_function_descriptor_to_list = + env->GetMethodID(java_function_descriptor_class, "toList", "()Ljava/util/List;"); + JavaStringListToNativeStringVector( + env, env->CallObjectMethod(functionDescriptor, java_function_descriptor_to_list), + &function_descriptor); + ray::RayFunction ray_function{static_cast<::Language>(language), function_descriptor}; + return ray_function; +} + diff --git a/streaming/src/lib/java/streaming_jni_common.h b/streaming/src/lib/java/streaming_jni_common.h new file mode 100644 index 000000000000..921def7d0a12 --- /dev/null +++ b/streaming/src/lib/java/streaming_jni_common.h @@ -0,0 +1,111 @@ +#ifndef 
RAY_STREAMING_JNI_COMMON_H +#define RAY_STREAMING_JNI_COMMON_H + +#include +#include +#include "ray/core_worker/common.h" +#include "util/streaming_logging.h" + +class UniqueIdFromJByteArray { + private: + JNIEnv *_env; + jbyteArray _bytes; + jbyte *b; + + public: + ray::ObjectID PID; + + UniqueIdFromJByteArray(JNIEnv *env, jbyteArray wid) { + _env = env; + _bytes = wid; + + b = reinterpret_cast(_env->GetByteArrayElements(_bytes, nullptr)); + PID = ray::ObjectID::FromBinary( + std::string(reinterpret_cast(b), ray::ObjectID::Size())); + } + + ~UniqueIdFromJByteArray() { + _env->ReleaseByteArrayElements(_bytes, b, 0); + } +}; + +class RawDataFromJByteArray { + private: + JNIEnv *_env; + jbyteArray _bytes; + + public: + uint8_t *data; + uint32_t data_size; + + RawDataFromJByteArray(JNIEnv *env, jbyteArray bytes) { + _env = env; + _bytes = bytes; + data_size = _env->GetArrayLength(_bytes); + jbyte *b = + reinterpret_cast(_env->GetByteArrayElements(_bytes, nullptr)); + data = reinterpret_cast(b); + } + + ~RawDataFromJByteArray() { + _env->ReleaseByteArrayElements(_bytes, reinterpret_cast(data), 0); + } + +}; + +class StringFromJString { + private: + JNIEnv *_env; + const char *j_str; + jstring jni_str; + + public: + std::string str; + + StringFromJString(JNIEnv *env, jstring jni_str_) { + jni_str = jni_str_; + _env = env; + j_str = env->GetStringUTFChars(jni_str, nullptr); + str = std::string(j_str); + } + + ~StringFromJString() { + _env->ReleaseStringUTFChars(jni_str, j_str); + } + +}; + +class LongVectorFromJLongArray { + private: + JNIEnv *_env; + jlongArray long_array; + jlong *long_array_ptr = nullptr; + + public: + std::vector data; + + LongVectorFromJLongArray(JNIEnv *env, jlongArray long_array_) { + _env = env; + long_array = long_array_; + + long_array_ptr = env->GetLongArrayElements(long_array, nullptr); + jsize seq_id_size = env->GetArrayLength(long_array); + data = std::vector(long_array_ptr, long_array_ptr + seq_id_size); + } + + ~LongVectorFromJLongArray() 
{ + _env->ReleaseLongArrayElements(long_array, long_array_ptr, 0); + } +}; + +std::vector +jarray_to_object_id_vec(JNIEnv *env, jobjectArray jarr); +std::vector +jarray_to_actor_id_vec(JNIEnv *env, jobjectArray jarr); + +jint throwRuntimeException(JNIEnv *env, const char *message); +jint throwChannelInitException(JNIEnv *env, const char *message, + const std::vector &abnormal_queues); +jint throwChannelInterruptException(JNIEnv *env, const char *message); +ray::RayFunction FunctionDescriptorToRayFunction(JNIEnv *env, jobject functionDescriptor); +#endif //RAY_STREAMING_JNI_COMMON_H