Fix sequence continuous batching close session race condition (#3198)
* add logging to trace jobgroup cleanup

* Monitor eventJobGroupIds

* Revert "Monitor eventJobGroupIds"

This reverts commit 70ef9b0.

* Log reset job group Ids

* Test adding job to eventJobGroupIds after completing streaming request

* Revert "Test adding job to evenJobGroupIds after completing streaming request"

This reverts commit ab78a9a.

* Force cleanup job group

* Repeat close session request to follow through with cleanup

* Improve detection of close session

* formatJava

* test not adding dummy job to closed job group

* Revert "test not adding dummy job to closed job group"

This reverts commit 51d706a.

* Remove debug logging

* comments about fix

* Avoid duplicate CompletableFutures

* Check executor task status using CompletableFuture object

* Track available local capacity for a worker

* Fix computation of capacity values

* Update test to check session cleanup

* Update pollQueueTasks key for pollJobGroup task

* formatJava
namannandan authored Jun 22, 2024
1 parent 688f09e commit 4c96e6f
Showing 2 changed files with 92 additions and 53 deletions.
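Read together, the commit messages above come down to two pieces of bookkeeping added to the sequence batch aggregator: a ConcurrentHashMap of CompletableFutures so the event dispatcher never queues a second poll task for a key whose previous task is still pending, and an AtomicInteger of free local capacity that is consumed when a job group is admitted and released when its session is cleaned up. The sketch below is a minimal, self-contained illustration of that pattern, not the TorchServe class itself; the names PollTaskBookkeeping, scheduleOnce, onJobGroupAdded and onJobGroupCleaned are hypothetical stand-ins.

    import java.util.concurrent.CompletableFuture;
    import java.util.concurrent.ConcurrentHashMap;
    import java.util.concurrent.ExecutorService;
    import java.util.concurrent.Executors;
    import java.util.concurrent.atomic.AtomicInteger;

    // Illustrative sketch only: at most one in-flight poll task per key, plus a
    // counter of how many more sequences/job groups this worker can accept.
    public class PollTaskBookkeeping {
        private final ExecutorService pollExecutors = Executors.newFixedThreadPool(4);
        private final ConcurrentHashMap<String, CompletableFuture<Void>> pollQueueTasks =
                new ConcurrentHashMap<>();
        private final AtomicInteger localCapacity = new AtomicInteger(4);

        // Schedule work under a key only if no earlier task for that key is still running.
        public void scheduleOnce(String key, Runnable work) {
            CompletableFuture<Void> previous = pollQueueTasks.get(key);
            if (previous != null && !previous.isDone()) {
                return; // a poll task for this key is already queued or running; skip the duplicate
            }
            pollQueueTasks.put(key, CompletableFuture.runAsync(work, pollExecutors));
        }

        public boolean hasCapacity() {
            return localCapacity.get() > 0;
        }

        // Admitting a job group consumes one unit of capacity.
        public void onJobGroupAdded() {
            localCapacity.decrementAndGet();
        }

        // Cleaning a job group (session closed or expired) releases its capacity and
        // drops the bookkeeping entry so a future task for the same key can run again.
        public void onJobGroupCleaned(String key) {
            pollQueueTasks.remove(key);
            localCapacity.incrementAndGet();
        }
    }

In the diff that follows, the map is keyed by the job group id (with the fixed "pollJobGroup" key used for fetching new groups), and cleanJobGroup() is where the entry is removed and the capacity returned.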
@@ -3,12 +3,14 @@
import java.util.LinkedHashSet;
import java.util.LinkedList;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import org.pytorch.serve.job.Job;
import org.pytorch.serve.job.JobGroup;
import org.pytorch.serve.util.messages.BaseModelRequest;
@@ -34,16 +36,20 @@ public class SequenceBatching extends BatchAggregator {
// A list of jobGroupIds which are added into current batch. These jobGroupIds need to be added
// back to eventJobGroupIds once their jobs are processed by a batch.
protected LinkedList<String> currentJobGroupIds;
private int localCapacity;
private AtomicInteger localCapacity;
private AtomicBoolean running = new AtomicBoolean(true);
// HashMap to track poll queue tasks in the executor queue
private ConcurrentHashMap<String, CompletableFuture<Void>> pollQueueTasks =
new ConcurrentHashMap<String, CompletableFuture<Void>>();

public SequenceBatching(Model model) {
super(model);
this.localCapacity =
new AtomicInteger(Math.max(1, model.getMaxNumSequence() / model.getMinWorkers()));
this.currentJobGroupIds = new LinkedList<>();
this.pollExecutors = Executors.newFixedThreadPool(model.getBatchSize() + 1);
this.pollExecutors = Executors.newFixedThreadPool(localCapacity.get() + 1);
this.jobsQueue = new LinkedBlockingDeque<>();
this.isPollJobGroup = new AtomicBoolean(false);
this.localCapacity = model.getMaxNumSequence() / model.getMinWorkers();
this.eventJobGroupIds = new LinkedBlockingDeque<>();
this.eventJobGroupIds.add("");
this.eventDispatcher = new Thread(new EventDispatcher());
@@ -70,8 +76,9 @@ private void pollJobGroup() throws InterruptedException {

int quota =
Math.min(
this.localCapacity - jobsQueue.size(),
model.getPendingJobGroups().size() / model.getMaxWorkers());
this.localCapacity.get(),
Math.max(
1, model.getPendingJobGroups().size() / model.getMaxWorkers()));
if (quota > 0 && model.getPendingJobGroups().size() > 0) {
model.getPendingJobGroups().drainTo(tmpJobGroups, quota);
}
@@ -120,6 +127,8 @@ private void cleanJobGroup(String jobGroupId) {
logger.debug("Clean jobGroup: {}", jobGroupId);
if (jobGroupId != null) {
model.removeJobGroup(jobGroupId);
pollQueueTasks.remove(jobGroupId);
localCapacity.incrementAndGet();
}
}

@@ -176,6 +185,7 @@ public void shutdownExecutors() {

private void addJobGroup(String jobGroupId) {
if (jobGroupId != null) {
localCapacity.decrementAndGet();
eventJobGroupIds.add(jobGroupId);
}
}
@@ -192,22 +202,39 @@ public void run() {
String jobGroupId =
eventJobGroupIds.poll(model.getMaxBatchDelay(), TimeUnit.MILLISECONDS);
if (jobGroupId == null || jobGroupId.isEmpty()) {
CompletableFuture.runAsync(
() -> {
try {
pollJobGroup();
} catch (InterruptedException e) {
logger.error("Failed to poll a job group", e);
}
},
pollExecutors);
// Skip fetching new job groups when no capacity is available
if (localCapacity.get() <= 0) {
continue;
}
// Avoid duplicate poll tasks in the executor queue
if (pollQueueTasks.containsKey("pollJobGroup")
&& !pollQueueTasks.get("pollJobGroup").isDone()) {
continue;
}
CompletableFuture<Void> pollTask =
CompletableFuture.runAsync(
() -> {
try {
pollJobGroup();
} catch (InterruptedException e) {
logger.error("Failed to poll a job group", e);
}
},
pollExecutors);
pollQueueTasks.put("pollJobGroup", pollTask);
} else {

CompletableFuture.runAsync(
() -> {
pollJobFromJobGroup(jobGroupId);
},
pollExecutors);
// Avoid duplicate poll tasks in the executor queue
if (pollQueueTasks.containsKey(jobGroupId)
&& !pollQueueTasks.get(jobGroupId).isDone()) {
continue;
}
CompletableFuture<Void> pollTask =
CompletableFuture.runAsync(
() -> {
pollJobFromJobGroup(jobGroupId);
},
pollExecutors);
pollQueueTasks.put(jobGroupId, pollTask);
}
} catch (InterruptedException e) {
if (running.get()) {
@@ -224,7 +251,7 @@ private void pollJobFromJobGroup(String jobGroupId) {
if (!jobGroup.isFinished()) {
job = jobGroup.pollJob(model.getSequenceMaxIdleMSec());
}
if (job == null) {
if (job == null || jobGroup.isFinished()) {
// JobGroup expired, clean it.
cleanJobGroup(jobGroupId);
// intent to add new job groups.
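Before the test changes, it is worth checking the new quota arithmetic in pollJobGroup (the @@ -70,8 +76,9 @@ hunk above) against the numbers used by the test's model config below (maxNumSequence: 2 and maxWorkers: 2 appear in the config; minWorkers is assumed here to also be 2). A small worked example, with the pending and queued counts assumed purely for illustration:

    public class QuotaExample {
        public static void main(String[] args) {
            int maxNumSequence = 2;   // from the test's model config
            int maxWorkers = 2;       // from the test's model config
            int minWorkers = 2;       // assumed equal to maxWorkers
            int pendingJobGroups = 1; // assumed: one session waiting to be admitted
            int jobsQueueSize = 0;    // assumed: nothing queued locally

            // Old formula: integer division can floor the quota to zero.
            int oldCapacity = maxNumSequence / minWorkers;                                 // 1
            int oldQuota =
                    Math.min(oldCapacity - jobsQueueSize, pendingJobGroups / maxWorkers);  // min(1, 0) = 0

            // New formula: capacity is at least 1 and the drain quota is clamped to at least 1.
            int newCapacity = Math.max(1, maxNumSequence / minWorkers);                    // 1
            int newQuota =
                    Math.min(newCapacity, Math.max(1, pendingJobGroups / maxWorkers));     // min(1, 1) = 1

            System.out.println("old quota = " + oldQuota + ", new quota = " + newQuota);   // 0 vs 1
        }
    }

With integer division alone, a single waiting session yields a quota of 0 and is never drained; the Math.max(1, ...) clamp guarantees at least one pending group is picked up whenever local capacity is available.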
@@ -22,7 +22,7 @@
maxWorkers: 2
batchSize: 1
maxNumSequence: 2
sequenceMaxIdleMSec: 5000
sequenceMaxIdleMSec: 60000
maxSequenceJobQueueSize: 10
sequenceBatching: true
continuousBatching: true
@@ -219,39 +219,51 @@ def test_infer_stateful_cancel(mar_file_path, model_store):

try:
test_utils.reg_resp = test_utils.register_model_with_params(params)
with requests.post(
url=f"http://localhost:8080/predictions/{model_name}",
data=str(2).encode(),
) as response:
s_id = response.headers.get("ts_request_sequence_id")
headers = {
"ts_request_sequence_id": s_id,
}

t0 = threading.Thread(
target=__infer_stateful_cancel,
args=(
model_name,
False,
headers,
"5",
),
)
t1 = threading.Thread(
target=__infer_stateful_cancel,
args=(
model_name,
True,
headers,
"-1",
),
)

t0.start()
t1.start()
# Open and close sessions multiple times (> maxNumSequence) to test session cleanup after the stream response
for _ in range(4):
with requests.post(
url=f"http://localhost:8080/predictions/{model_name}",
data=str(2).encode(),
) as response:
s_id = response.headers.get("ts_request_sequence_id")
headers = {
"ts_request_sequence_id": s_id,
}

t0 = threading.Thread(
target=__infer_stateful_cancel,
args=(
model_name,
False,
headers,
"5",
),
)
t1 = threading.Thread(
target=__infer_stateful_cancel,
args=(
model_name,
True,
headers,
"-1",
),
)

t0.start()
t1.start()

t0.join()
t1.join()

# Close session after cancellation request to free up session capacity
with requests.post(
url=f"http://localhost:8080/predictions/{model_name}",
headers=headers,
data=str(0).encode(),
) as response:
assert response.status_code == 200

t0.join()
t1.join()
finally:
test_utils.unregister_model(model_name)
