From dfc95843265f5cc93d04569beb4ccf281c71240e Mon Sep 17 00:00:00 2001 From: blee-bot <93bslee@gmail.com> Date: Fri, 1 Nov 2024 13:38:00 +0900 Subject: [PATCH] Remove long exceeded line, make alias. Remove long exceeded line, make alias. ONE-DCO-1.0-Signed-off-by: Banseok Lee --- .../record-hessian/include/record-hessian/HessianComputer.h | 6 ++++-- compiler/record-hessian/src/HessianComputer.cpp | 3 ++- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/compiler/record-hessian/include/record-hessian/HessianComputer.h b/compiler/record-hessian/include/record-hessian/HessianComputer.h index 55dd2167af4..490303de73f 100644 --- a/compiler/record-hessian/include/record-hessian/HessianComputer.h +++ b/compiler/record-hessian/include/record-hessian/HessianComputer.h @@ -25,9 +25,11 @@ namespace record_hessian { /** - * @brief Record approximated hessian matrix from GPTQ paper(https://arxiv.org/abs/2210.17323). + * @brief Record approximated hessian matrix from + * GPTQ paper(https://arxiv.org/abs/2210.17323). */ using HessianMap = std::unordered_map>; +using HessianVectorMap = std::unordered_map; class HessianComputer { @@ -38,7 +40,7 @@ class HessianComputer std::unique_ptr getMap(); private: - std::unordered_map _hessian_map; + HessianVectorMap _hessian_map; const luci_interpreter::Tensor *_input_tensor = nullptr; void recordHessianForConv2D(const luci::CircleNode *node); diff --git a/compiler/record-hessian/src/HessianComputer.cpp b/compiler/record-hessian/src/HessianComputer.cpp index e4bdbe9dca4..9c22852ac48 100644 --- a/compiler/record-hessian/src/HessianComputer.cpp +++ b/compiler/record-hessian/src/HessianComputer.cpp @@ -23,7 +23,8 @@ namespace record_hessian /** * @brief unfold the vector with NHWC shape, inherently acting in an in-place manner. - * @note (N, H, W, C) -> (N, L, H*W*C). See details(https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html). + * @note (N, H, W, C) -> (N, L, H*W*C). + * See details(https://pytorch.org/docs/stable/generated/torch.nn.Unfold.html). */ void unfold(std::vector &buf, uint32_t input_n, uint32_t input_h, uint32_t input_w, uint32_t input_c, uint32_t stride_h, uint32_t stride_w, uint32_t dilation_h,