Skip to content

Commit

Permalink
mock-tasks: allow flows to be executed instead of tasks (#1042)
Browse files Browse the repository at this point in the history
  • Loading branch information
brig authored Nov 20, 2024
1 parent 73f9c70 commit 90c4b0d
Show file tree
Hide file tree
Showing 11 changed files with 192 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,4 +74,8 @@ public Serializable result() {
public String throwError() {
return MapUtils.getString(definition, "throwError");
}

public String executeFlow() {
return MapUtils.getString(definition, "executeFlow");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,22 @@
* =====
*/

import com.walmartlabs.concord.runtime.v2.model.FlowCall;
import com.walmartlabs.concord.runtime.v2.model.FlowCallOptions;
import com.walmartlabs.concord.runtime.v2.model.Location;
import com.walmartlabs.concord.runtime.v2.model.ProcessDefinition;
import com.walmartlabs.concord.runtime.v2.runner.vm.VMUtils;
import com.walmartlabs.concord.runtime.v2.sdk.*;
import com.walmartlabs.concord.runtime.v2.sdk.Compiler;
import com.walmartlabs.concord.sdk.MapUtils;
import com.walmartlabs.concord.svm.VM;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.Serializable;
import java.lang.reflect.Array;
import java.util.Arrays;
import java.util.Map;
import java.util.function.Supplier;

public class MockTask implements Task {
Expand All @@ -50,7 +61,7 @@ public MockTask(Context ctx, String taskName,

@Override
public TaskResult execute(Variables input) throws Exception{
MockDefinition mockDefinition = mockDefinitionProvider.find(ctx, taskName, input);
var mockDefinition = mockDefinitionProvider.find(ctx, taskName, input);
if (mockDefinition == null) {
return delegate.get().execute(input);
}
Expand All @@ -61,13 +72,19 @@ public TaskResult execute(Variables input) throws Exception{
throw new UserDefinedException(mockDefinition.throwError());
}

boolean success = MapUtils.getBoolean(mockDefinition.out(), "ok", true);
var result = mockDefinition.out();
if (mockDefinition.executeFlow() != null) {
var flowResult = executeFlow(mockDefinition.executeFlow(), input.toMap());
result = assertMap(flowResult);
}

boolean success = MapUtils.getBoolean(result, "ok", true);
return TaskResult.of(success)
.values(mockDefinition.out());
.values(result);
}

public Object call(CustomTaskMethodResolver.InvocationContext ic, String method, Class<?>[] paramTypes, Object[] params) {
MockDefinition mockDefinition = mockDefinitionProvider.find(ctx, taskName, method, params);
var mockDefinition = mockDefinitionProvider.find(ctx, taskName, method, params);
if (mockDefinition == null) {
return ic.invoker().invoke(delegate.get(), method, paramTypes, params);
}
Expand All @@ -78,7 +95,12 @@ public Object call(CustomTaskMethodResolver.InvocationContext ic, String method,
throw new UserDefinedException(mockDefinition.throwError());
}

return mockDefinition.result();
var result = mockDefinition.result();
if (mockDefinition.executeFlow() != null) {
result = executeFlow(mockDefinition.executeFlow(), toMap(params));
}

return result;
}

public String taskName() {
Expand All @@ -88,4 +110,56 @@ public String taskName() {
public Class<? extends Task> originalTaskClass() {
return originalTaskClass;
}

@SuppressWarnings({"rawtypes", "unchecked"})
private Serializable executeFlow(String flowName, Map<String, Object> input) {
log.info("Executing flow '{}' to get mock results", flowName);

var runtime = ctx.execution().runtime();
var state = ctx.execution().state();
var compiler = runtime.getService(Compiler.class);
var pd = runtime.getService(ProcessDefinition.class);

var callOptions = FlowCallOptions.builder()
.input((Map)input)
.addOut("result")
.build();
var flowCallCommand = compiler.compile(pd, new FlowCall(Location.builder().build(), flowName, callOptions));

var currentThreadId = ctx.execution().currentThreadId();
var forkThreadId = state.nextThreadId();

state.fork(currentThreadId, forkThreadId, flowCallCommand);

var targetFrame = state.peekFrame(forkThreadId);
VMUtils.putLocals(targetFrame, VMUtils.getCombinedLocals(state, currentThreadId));

try {
var result = runtime.eval(state, forkThreadId);
return result.lastFrame().getLocal("result");
} catch (Exception e) {
throw new RuntimeException(e);
}
}

@SuppressWarnings({"unchecked"})
private Map<String, Object> assertMap(Serializable maybeMap) {
if (maybeMap == null) {
return Map.of();
}

if (maybeMap instanceof Map<?,?>) {
return (Map<String, Object>) maybeMap;
}

throw new IllegalArgumentException("Flow should set result as Map. Actual: " + maybeMap.getClass());
}

private static Map<String, Object> toMap(Object[] params) {
if (params == null || params.length == 0) {
return Map.of();
}

return Map.of("args", Arrays.asList(params));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,17 @@ public void testMethodMock() throws Exception {
assertLog(log, ".*result.ok: BOO.*");
}

@Test
public void testMethodMockWithFlowExecute() throws Exception {
runtime.deploy("method-mock-with-flow-execute");

byte[] log = runtime.run();
assertLog(log, ".*" + Pattern.quote("The actual 'testTask.myMethod()' is not being executed; this is a mock") + ".*");
assertLog(log, ".* Executing flow 'assertMyMethod' to get mock results.*");
assertLog(log, ".*" + Pattern.quote("flow can access method args: [1, b, false, [1, 2, 3], {k=v}]") + ".*");
assertLog(log, ".*result.ok: WOW.*");
}

@Test
public void testMethodMockWithAny() throws Exception {
runtime.deploy("method-mock-with-any");
Expand All @@ -60,4 +71,14 @@ public void testMethodMockWithAny() throws Exception {
assertLog(log, ".*" + Pattern.quote("The actual 'testTask.myMethod()' is not being executed; this is a mock") + ".*");
assertLog(log, ".*result.ok: BOO.*");
}

@Test
public void testTaskMockWithFlowExecute() throws Exception {
runtime.deploy("task-mock-with-flow-execute");

byte[] log = runtime.run();
assertLog(log, ".*testTaskLogic can access task input params: p1=value-1, p2=value-2.*");
assertLog(log, ".*The actual task is not being executed; this is a mock.*");
assertLog(log, ".*result.ok: .*fromMockAsFlow=WOW.*");
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
flows:
default:
- expr: "${testTask.myMethod(1, 'b', false, [1, 2, 3], {'k': 'v'})}"
out: result

- log: "result.ok: ${result}"

assertMyMethod:
- log: "flow can access method args: ${args}"

- set:
result: "WOW"

configuration:
arguments:
mocks:
- task: "testTask"
method: "myMethod"
args: [1, 'b', false, [1, 2, 3], {k: 'v'}]
executeFlow: "assertMyMethod"
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
flows:
default:
- task: testTask
in:
p1: "value-1"
p2: "value-2"
out: result

- log: "result.ok: ${result}"

testTaskLogic:
- log: "testTaskLogic can access task input params: p1=${p1}, p2=${p2}"

- set:
result:
fromMockAsFlow: "WOW"

# and override result

configuration:
arguments:
mocks:
- task: "testTask"
executeFlow: "testTaskLogic"
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,9 @@
*/

import com.google.inject.Injector;
import com.walmartlabs.concord.svm.Runtime;
import com.walmartlabs.concord.svm.*;
import com.walmartlabs.concord.runtime.v2.runner.vm.LoggedException;
import com.walmartlabs.concord.svm.State;
import com.walmartlabs.concord.svm.ThreadId;
import com.walmartlabs.concord.svm.VM;
import com.walmartlabs.concord.svm.Runtime;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

Expand All @@ -49,17 +47,19 @@ public DefaultRuntime(VM vm, Injector injector) {

@Override
public void spawn(State state, ThreadId threadId) {
executor.submit(() -> {
try {
vm.eval(this, state, threadId);
} catch (LoggedException e) {
throw e.getCause();
} catch (Exception e) {
log.error("Error while evaluating commands for thread {}", threadId, e);
throw e;
}
return null;
});
executor.submit(() -> eval(state, threadId));
}

@Override
public EvalResult eval(State state, ThreadId threadId) throws Exception {
try {
return vm.eval(this, state, threadId);
} catch (LoggedException e) {
throw e.getCause();
} catch (Exception e) {
log.error("Error while evaluating commands for thread {}", threadId, e);
throw e;
}
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -81,4 +81,4 @@ public static void pushLogContext(ThreadId threadId, State state, LogContext log

private LogSegmentUtils() {
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -134,4 +134,4 @@ private static String getExceptionMessage(Exception e) {
List<Throwable> exceptions = ExceptionUtils.getExceptionList(e);
return exceptions.get(exceptions.size() - 1).getMessage();
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package com.walmartlabs.concord.svm;

import java.io.Serial;
import java.io.Serializable;

public class EvalResult implements Serializable {

@Serial
private static final long serialVersionUID = 1L;

private final Frame lastFrame;

public EvalResult(Frame lastFrame) {
this.lastFrame = lastFrame;
}

public Frame lastFrame() {
return lastFrame;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,11 @@ public interface Runtime {
*/
void spawn(State state, ThreadId threadId);

/**
* Runs the specified "vm" thread on current java thread.
*/
EvalResult eval(State state, ThreadId threadId) throws Exception;

/**
* Returns an instance of the specified service using the underlying injector.
*/
Expand Down
16 changes: 2 additions & 14 deletions runtime/v2/vm/src/main/java/com/walmartlabs/concord/svm/VM.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.Serializable;
import java.util.Collection;
import java.util.Map;
import java.util.Set;
Expand Down Expand Up @@ -60,7 +59,7 @@ public void start(State state) throws Exception {
throw e;
}

listeners.fireAfterProcessEnds(runtime, state, result.lastFrame);
listeners.fireAfterProcessEnds(runtime, state, result.lastFrame());

log.debug("start -> done");
}
Expand Down Expand Up @@ -92,7 +91,7 @@ public void resume(State state, Set<String> eventRefs) throws Exception {
throw e;
}

listeners.fireAfterProcessEnds(runtime, state, result.lastFrame);
listeners.fireAfterProcessEnds(runtime, state, result.lastFrame());

log.debug("resume ['{}'] -> done", eventRefs);
}
Expand Down Expand Up @@ -249,15 +248,4 @@ private static void wakeSuspended(State state) {
}
}
}

private static class EvalResult implements Serializable {

private static final long serialVersionUID = 1L;

private final Frame lastFrame;

private EvalResult(Frame lastFrame) {
this.lastFrame = lastFrame;
}
}
}

0 comments on commit 90c4b0d

Please sign in to comment.