From 22842e111c88cc54c31f6805d50accc06056ec62 Mon Sep 17 00:00:00 2001
From: Nikolay Bogoychev
Date: Fri, 10 Jul 2020 08:16:06 +0000
Subject: [PATCH] easier intgemm cmdparsing

---
 src/common/config_parser.cpp | 14 +++-----------
 src/tensors/backend.h        | 42 ++++++++++++++++++++++++++++++++++++++++++
 src/tensors/cpu/backend.h    |  7 +------
 3 files changed, 46 insertions(+), 17 deletions(-)

diff --git a/src/common/config_parser.cpp b/src/common/config_parser.cpp
index 33558abec..4eac2c775 100755
--- a/src/common/config_parser.cpp
+++ b/src/common/config_parser.cpp
@@ -629,18 +629,10 @@ void ConfigParser::addOptionsTranslation(cli::CLIWrapper& cli) {
   addSuboptionsDevices(cli);
   addSuboptionsBatching(cli);
 
-  cli.add("--optimize",
-      "Optimize speed aggressively sacrificing memory or precision by using 16bit integer CPU multiplication. Only available on CPU");
-  cli.add("--optimize8",
-      "Optimize speed even more aggressively sacrificing memory or precision by using 8bit integer CPU multiplication. Only available on CPU");
-  cli.add("--intgemm-shifted",
-      "Use a shifted GEMM implementation. Only available with intgemm8.");
-  cli.add("--intgemm-shifted-all",
-      "Use a shifted GEMM implementation even for operations without biases. Only available with intgemm8.");
+  cli.add("--gemm-precision",
+      "Use lower precision for the GEMM operations only. Supported values: float32, int16, int8, int8shift, int8shiftAlpha, int8shiftAll, int8shiftAlphaAll", "float32");
   cli.add("--dump-quantmult",
-      "Dump the quantization multipliers during an avarage run.");
-  cli.add("--use-precomputed-alphas",
-      "Use precomputed alphas for bias calculation.");
+      "Dump the quantization multipliers during an average run. \nTo be used to compute alphas for --gemm-precision int8shiftAlpha or int8shiftAlphaAll.");
   cli.add("--use-legacy-batching",
       "Use legacy codepath with a for loop of cblas_sgemm, instead of cblas_sgemm_batched.");
   cli.add("--skip-cost",
diff --git a/src/tensors/backend.h b/src/tensors/backend.h
index d1e567918..f450e823b 100644
--- a/src/tensors/backend.h
+++ b/src/tensors/backend.h
@@ -27,6 +27,48 @@ class Backend {
   virtual void configureDevice(Ptr options) = 0;
   virtual void synchronize() = 0;
 
+  // Sets the integer-GEMM flags from the "gemm-precision" and "dump-quantmult" options.
+  // When dump-quantmult is set, the value of gemm-precision is ignored: the 8-bit,
+  // shifted and shifted-all flags are switched on together with quantization-multiplier
+  // dumping. Otherwise gemm-precision selects the flags and must be one of
+  // float32, int16, int8, int8shift, int8shiftAlpha, int8shiftAll, int8shiftAlphaAll;
+  // any other value is reported through ABORT.
+  virtual void configureIntgemm(Ptr options) {
+    std::string gemmPrecision = options->get("gemm-precision");
+    bool dumpQuantMults = options->get("dump-quantmult");
+    if (dumpQuantMults) {
+      setOptimized8(true);
+      setShifted(true);
+      setShiftedAll(true);
+      setDumpQuantMult(true);
+      // gemmPrecision is not inspected in this branch.
+    } else if (gemmPrecision == "float32") {
+      // Default case, all variables are false. Do nothing
+    } else if (gemmPrecision == "int16") {
+      setOptimized(true);
+    } else if (gemmPrecision == "int8") {
+      setOptimized8(true);
+    } else if (gemmPrecision == "int8shift") {
+      setOptimized8(true);
+      setShifted(true);
+    } else if (gemmPrecision == "int8shiftAlpha") {
+      setOptimized8(true);
+      setShifted(true);
+      setPrecomputedAlpha(true);
+    } else if (gemmPrecision == "int8shiftAll") {
+      setOptimized8(true);
+      setShifted(true);
+      setShiftedAll(true);
+    } else if (gemmPrecision == "int8shiftAlphaAll") {
+      setOptimized8(true);
+      setShifted(true);
+      setShiftedAll(true);
+      setPrecomputedAlpha(true);
+    } else {
+      ABORT("Unknown option {} for command line parameter gemm-precision.", gemmPrecision);
+    }
+  }
+
   virtual void setClip(float clipValue) { clipValue_ = clipValue; }
   float getClip() { return clipValue_; }
 
diff --git a/src/tensors/cpu/backend.h b/src/tensors/cpu/backend.h
index a757b59e3..49c65235d 100644
--- a/src/tensors/cpu/backend.h
+++ b/src/tensors/cpu/backend.h
@@ -25,13 +25,8 @@ class Backend : public marian::Backend {
   void setDevice() override {}
 
   void configureDevice(Ptr options) override {
+    configureIntgemm(options);
     setClip(options->get("clip-gemm"));
-    setOptimized(options->get("optimize"));
-    setOptimized8(options->get("optimize8"));
-    setShifted(options->get("intgemm-shifted"));
-    setShiftedAll(options->get("intgemm-shifted-all"));
-    setDumpQuantMult(options->get("dump-quantmult"));
-    setPrecomputedAlpha(options->get("use-precomputed-alphas"));
     setLegacyBatchedGemm(options->get("use-legacy-batching"));
   }
   void synchronize() override {}