Skip to content

Commit

Permalink
[misc] remove Scheduler::Run()
Browse files Browse the repository at this point in the history
  • Loading branch information
ouonline committed Dec 7, 2023
1 parent 96bdc59 commit b7aa90c
Show file tree
Hide file tree
Showing 7 changed files with 26 additions and 45 deletions.
16 changes: 2 additions & 14 deletions src/ppl/nn/engines/cuda/graph_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -127,8 +127,8 @@ RetCode CudaGraphScheduler::ExecForEach(KernelExecContext& ctx,
return RC_SUCCESS;
}

RetCode CudaGraphScheduler::DoForEach(const function<RetCode(KernelImpl*, KernelExecContext*)>& exec,
Profiler* profiler) {
RetCode CudaGraphScheduler::ForEach(const function<RetCode(KernelImpl*, KernelExecContext*)>& exec,
Profiler* profiler) {
KernelExecContext ctx;
ctx.SetAcquireFunc(acquire_object_func_);
ctx.SetProfilingFlag((profiler != nullptr));
Expand Down Expand Up @@ -161,18 +161,6 @@ RetCode CudaGraphScheduler::DoForEach(const function<RetCode(KernelImpl*, Kernel
return RC_SUCCESS;
}

RetCode CudaGraphScheduler::Run(Profiler* profiler) {
return DoForEach(
[](KernelImpl* kernel, KernelExecContext* ctx) -> RetCode {
return kernel->Execute(ctx);
},
profiler);
}

RetCode CudaGraphScheduler::ForEach(const std::function<ppl::common::RetCode(KernelImpl*, KernelExecContext*)>& f) {
return DoForEach(f, nullptr);
}

void CudaGraphScheduler::GraphRunnerAddDevice(const CudaDevice* dev) {
graph_runner_.AddDevice(dev);
return;
Expand Down
6 changes: 3 additions & 3 deletions src/ppl/nn/engines/cuda/graph_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ class CudaGraphScheduler final : public Scheduler {
public:
CudaGraphScheduler();
ppl::common::RetCode Init(const Options&) override;
ppl::common::RetCode ForEach(const std::function<ppl::common::RetCode(KernelImpl*, KernelExecContext*)>&) override;
ppl::common::RetCode Run(Profiler*) override;
ppl::common::RetCode ForEach(const std::function<ppl::common::RetCode(KernelImpl*, KernelExecContext*)>&,
Profiler*) override;
void GraphRunnerAddDevice(const CudaDevice* dev);

private:
Expand Down Expand Up @@ -59,4 +59,4 @@ class CudaGraphScheduler final : public Scheduler {
};
}}} // namespace ppl::nn::cuda

#endif
#endif
6 changes: 5 additions & 1 deletion src/ppl/nn/runtime/partition_runner_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,11 @@ RetCode PartitionRunnerImpl::Sync() {
}

RetCode PartitionRunnerImpl::Run() {
auto rc = sched_->Run(nullptr);
auto rc = sched_->ForEach(
[](KernelImpl* kernel, KernelExecContext* ctx) -> RetCode {
return kernel->Execute(ctx);
},
nullptr);
if (rc != RC_SUCCESS) {
LOG(ERROR) << "PartitionRunner Run() failed: " << GetRetCodeStr(rc);
return rc;
Expand Down
14 changes: 10 additions & 4 deletions src/ppl/nn/runtime/runtime_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,11 @@ RetCode RuntimeImpl::RunAsync() {
constexpr Profiler* profiler = nullptr;
#endif

auto status = sched_->Run(profiler);
auto status = sched_->ForEach(
[](KernelImpl* kernel, KernelExecContext* ctx) -> RetCode {
return kernel->Execute(ctx);
},
profiler);
if (status != RC_SUCCESS) {
LOG(ERROR) << "Run() failed: " << GetRetCodeStr(status);
}
Expand Down Expand Up @@ -491,9 +495,11 @@ RetCode RuntimeImpl::ConfSetProfilingFlag(RuntimeImpl* rt, va_list args) {
}

RetCode RuntimeImpl::ConfInferShapes(RuntimeImpl* rt, va_list) {
return rt->sched_->ForEach([](KernelImpl* kernel, KernelExecContext* ctx) -> RetCode {
return kernel->Reshape(ctx);
});
return rt->sched_->ForEach(
[](KernelImpl* kernel, KernelExecContext* ctx) -> RetCode {
return kernel->Reshape(ctx);
},
nullptr);
}

RetCode RuntimeImpl::ConfSetScheduler(RuntimeImpl* rt, va_list args) {
Expand Down
5 changes: 2 additions & 3 deletions src/ppl/nn/runtime/scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,8 @@ class Scheduler {
public:
virtual ~Scheduler() {}
virtual ppl::common::RetCode Init(const Options&) = 0;
virtual ppl::common::RetCode ForEach(
const std::function<ppl::common::RetCode(KernelImpl*, KernelExecContext*)>&) = 0;
virtual ppl::common::RetCode Run(Profiler*) = 0;
virtual ppl::common::RetCode ForEach(const std::function<ppl::common::RetCode(KernelImpl*, KernelExecContext*)>&,
Profiler*) = 0;
};

}} // namespace ppl::nn
Expand Down
16 changes: 2 additions & 14 deletions src/ppl/nn/runtime/sequential_scheduler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,8 @@ RetCode SequentialScheduler::Init(const Options& options) {
return RC_SUCCESS;
}

RetCode SequentialScheduler::DoForEach(const function<RetCode(KernelImpl*, KernelExecContext*)>& exec,
Profiler* profiler) {
RetCode SequentialScheduler::ForEach(const function<RetCode(KernelImpl*, KernelExecContext*)>& exec,
Profiler* profiler) {
#ifndef NDEBUG
set<edgeid_t> edges_before;
for (uint32_t i = 0; i < edgeid2object_->size(); ++i) {
Expand Down Expand Up @@ -178,16 +178,4 @@ RetCode SequentialScheduler::DoForEach(const function<RetCode(KernelImpl*, Kerne
return RC_SUCCESS;
}

RetCode SequentialScheduler::Run(Profiler* profiler) {
return DoForEach(
[](KernelImpl* kernel, KernelExecContext* ctx) -> RetCode {
return kernel->Execute(ctx);
},
profiler);
}

RetCode SequentialScheduler::ForEach(const std::function<ppl::common::RetCode(KernelImpl*, KernelExecContext*)>& f) {
return DoForEach(f, nullptr);
}

}} // namespace ppl::nn
8 changes: 2 additions & 6 deletions src/ppl/nn/runtime/sequential_scheduler.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,8 @@ class SequentialScheduler final : public Scheduler {
public:
SequentialScheduler();
ppl::common::RetCode Init(const Options&) override;
ppl::common::RetCode ForEach(const std::function<ppl::common::RetCode(KernelImpl*, KernelExecContext*)>&) override;
ppl::common::RetCode Run(Profiler*) override;

private:
ppl::common::RetCode DoForEach(const std::function<ppl::common::RetCode(KernelImpl*, KernelExecContext*)>&,
Profiler*);
ppl::common::RetCode ForEach(const std::function<ppl::common::RetCode(KernelImpl*, KernelExecContext*)>&,
Profiler*) override;

private:
const ir::GraphTopo* topo_;
Expand Down

0 comments on commit b7aa90c

Please sign in to comment.