From 00f30732c5198826794616bce127f88264e23d41 Mon Sep 17 00:00:00 2001 From: blee-bot <93bslee@gmail.com> Date: Mon, 4 Nov 2024 10:06:19 +0900 Subject: [PATCH] Split codes into core snippets for small PR. Split codes into core snippets for small PR. ONE-DCO-1.0-Signed-off-by: Banseok Lee --- .../include/record-hessian/HessianObserver.h | 45 ++++++++++ compiler/record-hessian/src/RecordHessian.cpp | 85 ------------------- .../record-hessian/src/RecordHessian.test.cpp | 62 -------------- 3 files changed, 45 insertions(+), 147 deletions(-) create mode 100644 compiler/record-hessian/include/record-hessian/HessianObserver.h delete mode 100644 compiler/record-hessian/src/RecordHessian.test.cpp diff --git a/compiler/record-hessian/include/record-hessian/HessianObserver.h b/compiler/record-hessian/include/record-hessian/HessianObserver.h new file mode 100644 index 00000000000..46a39b28b46 --- /dev/null +++ b/compiler/record-hessian/include/record-hessian/HessianObserver.h @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef __RECORD_HESSIAN_HESSIANOBSERVER_H__ +#define __RECORD_HESSIAN_HESSIANOBSERVER_H__ + +#include "record-hessian/HessianComputer.h" + +#include +#include +#include + +namespace record_hessian +{ + +class HessianObserver : public luci_interpreter::ExecutionObserver +{ +public: + HessianObserver() = default; + + void postTensorWrite(const luci::CircleNode *node, + const luci_interpreter::Tensor *tensor) override; + + std::unique_ptr hessianData() { return _hessian_computer.getMap(); } + +private: + HessianComputer _hessian_computer; +}; + +} // namespace record_hessian + +#endif // __RECORD_HESSIAN_HESSIANOBSERVER_H__ diff --git a/compiler/record-hessian/src/RecordHessian.cpp b/compiler/record-hessian/src/RecordHessian.cpp index 4f54170bf34..95e8545403b 100644 --- a/compiler/record-hessian/src/RecordHessian.cpp +++ b/compiler/record-hessian/src/RecordHessian.cpp @@ -104,88 +104,3 @@ void verifyTypeShape(const luci::CircleInput *input_node, const DataType &dtype, } } // namespace - -namespace record_hessian -{ - -void RecordHessian::initialize(luci::Module *module) -{ - // Create and initialize interpreters and observers - - _module = module; - - auto interpreter = std::make_unique(module); - auto observer = std::make_unique(); - - interpreter->attachObserver(observer.get()); - - _observer = std::move(observer); - _interpreter = std::move(interpreter); -} - -std::unique_ptr RecordHessian::profileData(const std::string &input_data_path) -{ - try - { - dio::hdf5::HDF5Importer importer(input_data_path); - importer.importGroup("value"); - - bool is_raw_data = importer.isRawData(); - - const auto num_records = importer.numData(); - if (num_records == 0) - throw std::runtime_error("RecordHessian: The input data file does not contain any record."); - - const auto input_nodes = loco::input_nodes(_module->graph()); - const auto num_inputs = input_nodes.size(); - - for (int32_t record_idx = 0; record_idx < num_records; record_idx++) - { - if (num_inputs != static_cast(importer.numInputs(record_idx))) - throw std::runtime_error("RecordHessian: Wrong number of inputs."); - - std::cout << "Recording " << record_idx << "'th data for hessian." << std::endl; - - for (uint32_t input_idx = 0; input_idx < num_inputs; input_idx++) - { - const auto *input_node = loco::must_cast(input_nodes[input_idx]); - assert(input_node->index() == input_idx); - checkInputDimension(input_node); - std::vector input_data(getTensorSize(input_node)); - - if (!is_raw_data) - { - DataType dtype; - Shape shape; - importer.readTensor(record_idx, input_idx, &dtype, &shape, input_data.data(), - input_data.size()); - - // Check the type and the shape of the input data is valid - verifyTypeShape(input_node, dtype, shape); - } - else - { - // Skip type/shape check for raw data - importer.readTensor(record_idx, input_idx, input_data.data(), input_data.size()); - } - - // TODO: Input data is copied twice (file -> buffer (input_data) -> interpreter inputs) - // We can redcue the copy by directly writing data from file to interpreter inputs - getInterpreter()->writeInputTensor(input_node, input_data.data(), input_data.size()); - } - - getInterpreter()->interpret(); - } - - std::cout << "Recording finished. Number of recorded data: " << num_records << std::endl; - } - catch (const H5::Exception &e) - { - H5::Exception::printErrorStack(); - throw std::runtime_error("RecordHessian: HDF5 error occurred."); - } - - return getObserver()->hessianData(); -} - -} // namespace record_hessian diff --git a/compiler/record-hessian/src/RecordHessian.test.cpp b/compiler/record-hessian/src/RecordHessian.test.cpp deleted file mode 100644 index 66e5fe0a299..00000000000 --- a/compiler/record-hessian/src/RecordHessian.test.cpp +++ /dev/null @@ -1,62 +0,0 @@ -/* - * Copyright (c) 2024 Samsung Electronics Co., Ltd. All Rights Reserved - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "record-hessian/RecordHessian.h" - -#include -#include -#include -#include -#include -#include -#include - -#include - -using namespace record_hessian; - -TEST(RecordHessianTest, profileDataInvalidInputPath_NEG) -{ - // Create a module and a graph - auto m = luci::make_module(); - - // Initialize RecordHessian - RecordHessian rh; - rh.initialize(m.get()); - - // Provide an invalid input_data_path - std::string invalid_input_data_path = "invalid_h5_file"; - - // Call profileData and expect an exception - EXPECT_ANY_THROW( - { std::unique_ptr hessian_map = rh.profileData(invalid_input_data_path); }); -} - -TEST(RecordHessianTest, profileDataNonexistingFile_NEG) -{ - // Create a module and a graph - auto m = luci::make_module(); - - // Initialize RecordHessian - RecordHessian rh; - rh.initialize(m.get()); - - // // Provide an invalid input_data_path - std::string non_existing_h5 = "non_existing.h5"; - - // // Call profileData and expect an exception - EXPECT_ANY_THROW({ std::unique_ptr hessian_map = rh.profileData(non_existing_h5); }); -}