Skip to content

Commit

Permalink
runtime-v2: fix itemIndex in parallel loop (#904)
Browse files Browse the repository at this point in the history
  • Loading branch information
brig authored May 11, 2024
1 parent eeaa674 commit a9de18e
Show file tree
Hide file tree
Showing 3 changed files with 53 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -163,12 +163,14 @@ protected void eval(Runtime runtime, State state, ThreadId threadId, Context ctx
int batchSize = toBatchSize(runtime, ctx, parallelism);

List<ArrayList<Serializable>> batches = batches(items, batchSize);
int itemIndexStart = 0;
for (ArrayList<Serializable> batch : batches) {
evalBatch(state, threadId, batch, outVarsAccumulator);
evalBatch(itemIndexStart, state, threadId, batch, outVarsAccumulator);
itemIndexStart += batch.size();
}
}

private void evalBatch(State state, ThreadId threadId, ArrayList<Serializable> items, Map<String, List<Serializable>> outVarsAccumulator) {
private void evalBatch(int itemIndexStart, State state, ThreadId threadId, ArrayList<Serializable> items, Map<String, List<Serializable>> outVarsAccumulator) {
Frame frame = state.peekFrame(threadId);

List<Map.Entry<ThreadId, Serializable>> forks = items.stream()
Expand All @@ -177,13 +179,12 @@ private void evalBatch(State state, ThreadId threadId, ArrayList<Serializable> i

for (int i = 0; i < forks.size(); i++) {
Map.Entry<ThreadId, Serializable> f = forks.get(i);

Frame cmdFrame = Frame.builder()
.nonRoot()
.build();

cmdFrame.setLocal(CURRENT_ITEMS, items);
cmdFrame.setLocal(CURRENT_INDEX, i);
cmdFrame.setLocal(CURRENT_INDEX, itemIndexStart + i);
cmdFrame.setLocal(CURRENT_ITEM, f.getValue());

// fork will create rootFrame for forked commands
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1814,6 +1814,27 @@ public void testStringIfExpression() throws Exception {
assertLog(log, ".*it's true.*");
}

@Test
public void testParallelLoopItemIndex() throws Exception {
deploy("parallelLoopItemIndex");

save(ProcessConfiguration.builder()
.build());

byte[] log = run();
assertLog(log, ".*serial: five==5.*");
assertLog(log, ".*serial: four==4.*");
assertLog(log, ".*serial: three==3.*");
assertLog(log, ".*serial: two==2.*");
assertLog(log, ".*serial: one==1.*");

assertLog(log, ".*parallel: five==5.*");
assertLog(log, ".*parallel: four==4.*");
assertLog(log, ".*parallel: three==3.*");
assertLog(log, ".*parallel: two==2.*");
assertLog(log, ".*parallel: one==1.*");
}

private void deploy(String resource) throws URISyntaxException, IOException {
Path src = Paths.get(MainTest.class.getResource(resource).toURI());
IOUtils.copy(src, workDir);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
flows:
default:
# serial
- call: main
in:
item: "${item}"
index: "${itemIndex}"
prefix: "serial"
out: x
loop:
mode: serial
items: ['one', 'two', 'three', 'four', 'five']

# parallel
- call: main
in:
item: "${item}"
index: "${itemIndex}"
prefix: "parallel"
out: x
loop:
mode: parallel
items: ['one', 'two', 'three', 'four', 'five']
parallelism: 2

main:
- log: "${prefix += ': ' += item += '==' += (itemIndex+1)}"

0 comments on commit a9de18e

Please sign in to comment.