
Commit

Post-merge fixes.
- new function configureDevice(Ptr<const Options> options);
- moved all code that reads options and sets backend flags into
  configureDevice()
ugermann committed Jul 3, 2020
1 parent be99988 commit 5b93079
Showing 5 changed files with 22 additions and 31 deletions.
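
Conceptually, the commit replaces per-flag backend setter calls at every call site with a single virtual hook that each backend overrides. A minimal sketch of the pattern follows, using simplified stand-in types rather than marian's real Options/Ptr API (the field names optimize/clipGemm are invented for illustration):

#include <memory>

template <class T> using Ptr = std::shared_ptr<T>;

// Toy stand-in for marian::Options (an assumption, not the real class).
struct Options { bool optimize = false; float clipGemm = 0.0f; };

class Backend {
public:
  virtual ~Backend() = default;
  // Each backend reads exactly the options it understands.
  virtual void configureDevice(Ptr<Options const> options) = 0;
};

class CpuBackend : public Backend {
  bool optimize_ = false;
public:
  void configureDevice(Ptr<Options const> options) override {
    optimize_ = options->optimize;  // CPU-only flags are applied here
  }
};

class GpuBackend : public Backend {
  float clip_ = 0.0f;
public:
  void configureDevice(Ptr<Options const> options) override {
    clip_ = options->clipGemm;      // GPU applies only what it supports
  }
};

int main() {
  auto options = std::make_shared<Options>();
  Ptr<Backend> backend = std::make_shared<CpuBackend>();
  backend->configureDevice(options);  // virtual dispatch picks the CPU override
}

Call sites then collapse to the single line graph->getBackend()->configureDevice(options_); the explicit DeviceType::cpu branch disappears because dispatch selects the right override.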
12 changes: 1 addition & 11 deletions src/rescorer/rescorer.h
@@ -73,17 +73,7 @@ class Rescore : public ModelTask {
   auto precison = options_->get<std::vector<std::string>>("precision", {"float32"});
   graph->setDefaultElementType(typeFromString(precison[0])); // only use first type, used for parameter type in graph
   graph->setDevice(device);
-  graph->getBackend()->setClip(options_->get<float>("clip-gemm"));
-  if (device.type == DeviceType::cpu) {
-    graph->getBackend()->setOptimized(options_->get<bool>("optimize"));
-    graph->getBackend()->setOptimized8(options_->get<bool>("optimize8"));
-    graph->getBackend()->setShifted(options_->get<bool>("intgemm-shifted"));
-    graph->getBackend()->setShiftedAll(options_->get<bool>("intgemm-shifted-all"));
-    graph->getBackend()->setDumpQuantMult(options_->get<bool>("dump-quantmult"));
-    graph->getBackend()->setPrecomputedAlpha(options_->get<bool>("use-precomputed-alphas"));
-    graph->getBackend()->setLegacyBatchedGemm(options_->get<bool>("use-legacy-batching"));
-  }
-
+  graph->getBackend()->configureDevice(options_);
   graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
   graphs_.push_back(graph);
 }
2 changes: 2 additions & 0 deletions src/tensors/backend.h
@@ -1,6 +1,7 @@
 #pragma once
 
 #include "common/definitions.h"
+#include "common/options.h"
 #include "tensors/rand.h"
 
 namespace marian {
@@ -23,6 +24,7 @@ class Backend {

   // for GPU only, calls cudaSetDevice, does nothing on CPU. Maybe change name.
   virtual void setDevice() = 0;
+  virtual void configureDevice(Ptr<Options const> options) = 0;
   virtual void synchronize() = 0;
 
   virtual void setClip(float clipValue) { clipValue_ = clipValue; }
12 changes: 12 additions & 0 deletions src/tensors/cpu/backend.h
@@ -21,7 +21,19 @@ class Backend : public marian::Backend {

 public:
   Backend(DeviceId deviceId, size_t seed) : marian::Backend(deviceId, seed) {}
+
   void setDevice() override {}
+
+  void configureDevice(Ptr<Options const> options) override {
+    setClip(options->get<float>("clip-gemm"));
+    setOptimized(options->get<bool>("optimize"));
+    setOptimized8(options->get<bool>("optimize8"));
+    setShifted(options->get<bool>("intgemm-shifted"));
+    setShiftedAll(options->get<bool>("intgemm-shifted-all"));
+    setDumpQuantMult(options->get<bool>("dump-quantmult"));
+    setPrecomputedAlpha(options->get<bool>("use-precomputed-alphas"));
+    setLegacyBatchedGemm(options->get<bool>("use-legacy-batching"));
+  }
   void synchronize() override {}
 
   // for CPU & inference only, sets to use optimized code for inference. Does nothing for GPU.
5 changes: 5 additions & 0 deletions src/tensors/gpu/backend.h
@@ -44,6 +44,11 @@ class Backend : public marian::Backend {
     }
   }
 
+  void configureDevice(Ptr<Options const> options) override {
+    setClip(options->get<float>("clip-gemm"));
+  }
+
+
   void setDevice() override { CUDA_CHECK(cudaSetDevice((int)deviceId_.no)); }
 
   void synchronize() override { CUDA_CHECK(cudaStreamSynchronize(0)); }
22 changes: 2 additions & 20 deletions src/translator/translator.h
@@ -85,16 +85,7 @@ class Translate : public ModelTask {
   auto prec = options_->get<std::vector<std::string>>("precision", {"float32"});
   graph->setDefaultElementType(typeFromString(prec[0]));
   graph->setDevice(device);
-  graph->getBackend()->setClip(options_->get<float>("clip-gemm"));
-  if (device.type == DeviceType::cpu) {
-    graph->getBackend()->setOptimized(options_->get<bool>("optimize"));
-    graph->getBackend()->setOptimized8(options_->get<bool>("optimize8"));
-    graph->getBackend()->setShifted(options_->get<bool>("intgemm-shifted"));
-    graph->getBackend()->setShiftedAll(options_->get<bool>("intgemm-shifted-all"));
-    graph->getBackend()->setDumpQuantMult(options_->get<bool>("dump-quantmult"));
-    graph->getBackend()->setPrecomputedAlpha(options_->get<bool>("use-precomputed-alphas"));
-    graph->getBackend()->setLegacyBatchedGemm(options_->get<bool>("use-legacy-batching"));
-  }
+  graph->getBackend()->configureDevice(options_);
   graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
   graphs_[id] = graph;
 
@@ -233,16 +224,7 @@ class TranslateService : public ModelServiceTask {
   auto precison = options_->get<std::vector<std::string>>("precision", {"float32"});
   graph->setDefaultElementType(typeFromString(precison[0])); // only use first type, used for parameter type in graph
   graph->setDevice(device);
-  graph->getBackend()->setClip(options_->get<float>("clip-gemm"));
-  if (device.type == DeviceType::cpu) {
-    graph->getBackend()->setOptimized(options_->get<bool>("optimize"));
-    graph->getBackend()->setOptimized8(options_->get<bool>("optimize8"));
-    graph->getBackend()->setShifted(options_->get<bool>("intgemm-shifted"));
-    graph->getBackend()->setShiftedAll(options_->get<bool>("intgemm-shifted-all"));
-    graph->getBackend()->setDumpQuantMult(options_->get<bool>("dump-quantmult"));
-    graph->getBackend()->setPrecomputedAlpha(options_->get<bool>("use-precomputed-alphas"));
-    graph->getBackend()->setLegacyBatchedGemm(options_->get<bool>("use-legacy-batching"));
-  }
+  graph->getBackend()->configureDevice(options_);
   graph->reserveWorkspaceMB(options_->get<size_t>("workspace"));
   graphs_.push_back(graph);
 
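One practical payoff of consolidating the option reads: a future backend flag is wired up inside a single configureDevice() override instead of at three call sites (rescorer.h plus the two translator.h constructors). A self-contained toy showing the shape of such a change; the flag use-fused-gemm and every type below are invented for illustration:

#include <iostream>
#include <map>
#include <memory>
#include <string>

template <class T> using Ptr = std::shared_ptr<T>;

// Toy stand-in for marian::Options: boolean flags only.
struct Options {
  std::map<std::string, bool> flags;
  bool get(const std::string& key) const {
    auto it = flags.find(key);
    return it != flags.end() && it->second;
  }
};

struct CpuBackend {
  bool optimize = false;
  bool fusedGemm = false;
  void configureDevice(Ptr<Options const> options) {
    optimize  = options->get("optimize");
    // Hypothetical new flag: one line here, no call-site changes.
    fusedGemm = options->get("use-fused-gemm");
  }
};

int main() {
  auto options = std::make_shared<Options>();
  options->flags["use-fused-gemm"] = true;
  CpuBackend backend;
  backend.configureDevice(options);  // the same call every site already makes
  std::cout << std::boolalpha << backend.fusedGemm << '\n';  // prints: true
}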
