From 6ea73e1841cbf7a19c3b6f43e7386ef8a5945d5e Mon Sep 17 00:00:00 2001
From: liyou_b <2953090824@qq.com>
Date: Wed, 18 Dec 2024 08:02:36 +0000
Subject: [PATCH] =?UTF-8?q?!16993=20=E3=80=90PROF=E3=80=91Master:=20dynolo?=
 =?UTF-8?q?g=20for=20dynamic=20profiling=20Merge=20pull=20request=20!16993?=
 =?UTF-8?q?=20from=20liyou=5Fb/dynolog=5Fmaster?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 torch_npu/csrc/profiler/CMakeLists.txt        |   2 +-
 .../csrc/profiler/dyno/DynoLogNpuMonitor.cpp  |  29 +++
 .../csrc/profiler/dyno/DynoLogNpuMonitor.h    |  20 ++
 torch_npu/csrc/profiler/dyno/MonitorBase.h    |  12 ++
 torch_npu/csrc/profiler/dyno/NpuIpcClient.cpp | 134 +++++++++++++
 torch_npu/csrc/profiler/dyno/NpuIpcClient.h   |  96 +++++++++
 torch_npu/csrc/profiler/dyno/NpuIpcEndPoint.h | 185 ++++++++++++++++++
 .../profiler/dyno/PyDynamicMonitorProxy.h     |  30 +++
 torch_npu/csrc/profiler/dyno/utils.h          |  99 ++++++++++
 torch_npu/csrc/profiler/init.cpp              |   6 +-
 .../_dynamic_profiler_config_context.py       | 142 +++++++++++---
 .../_dynamic_profiler_monitor.py              | 119 +++++++----
 .../_dynamic_profiler_monitor_shm.py          |  59 +++---
 .../_dynamic_profiler_utils.py                | 151 +++++++++-----
 torch_npu/profiler/_non_intrusive_profile.py  |  27 ++-
 torch_npu/profiler/dynamic_profile.py         |  52 +++--
 16 files changed, 999 insertions(+), 164 deletions(-)
 create mode 100644 torch_npu/csrc/profiler/dyno/DynoLogNpuMonitor.cpp
 create mode 100644 torch_npu/csrc/profiler/dyno/DynoLogNpuMonitor.h
 create mode 100644 torch_npu/csrc/profiler/dyno/MonitorBase.h
 create mode 100644 torch_npu/csrc/profiler/dyno/NpuIpcClient.cpp
 create mode 100644 torch_npu/csrc/profiler/dyno/NpuIpcClient.h
 create mode 100644 torch_npu/csrc/profiler/dyno/NpuIpcEndPoint.h
 create mode 100644 torch_npu/csrc/profiler/dyno/PyDynamicMonitorProxy.h
 create mode 100644 torch_npu/csrc/profiler/dyno/utils.h

diff --git a/torch_npu/csrc/profiler/CMakeLists.txt b/torch_npu/csrc/profiler/CMakeLists.txt
index 9a39c2d4d8..52d68a84c6 100644
--- a/torch_npu/csrc/profiler/CMakeLists.txt
+++ b/torch_npu/csrc/profiler/CMakeLists.txt
@@ -1,4 +1,4 @@
-FILE(GLOB _PROF_SRCS *.cpp)
+FILE(GLOB _PROF_SRCS *.cpp dyno/*.h dyno/*.cpp)
 
 LIST(APPEND PROF_SRCS ${_PROF_SRCS})
 
diff --git a/torch_npu/csrc/profiler/dyno/DynoLogNpuMonitor.cpp b/torch_npu/csrc/profiler/dyno/DynoLogNpuMonitor.cpp
new file mode 100644
index 0000000000..80238935c1
--- /dev/null
+++ b/torch_npu/csrc/profiler/dyno/DynoLogNpuMonitor.cpp
@@ -0,0 +1,29 @@
+#include "DynoLogNpuMonitor.h"
+#include "utils.h"
+namespace torch_npu {
+namespace profiler {
+bool DynoLogNpuMonitor::Init()
+{
+    if (isInitialized_) {
+        ASCEND_LOGW("DynoLog npu monitor is already initialized !");
+        return true;
+    }
+    bool res = ipcClient_.RegisterInstance(npuId_);
+    if (res) {
+        isInitialized_ = true;
+        ASCEND_LOGI("DynoLog npu monitor initialized successfully !");
+    }
+    return res;
+}
+std::string DynoLogNpuMonitor::Poll()
+{
+    std::string res = ipcClient_.IpcClientNpuConfig();
+    if (res.empty()) {
+        ASCEND_LOGI("Request from dynolog server is empty !");
+        return "";
+    }
+    ASCEND_LOGI("Received NPU configuration successfully");
+    return res;
+}
+} // namespace profiler
+} // namespace torch_npu
\ No newline at end of file
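Reviewer note: the monitor above is the C++ half of the feature; once registered it is simply polled on an interval, exactly as the `worker_dyno_func` added later in this patch does. A minimal Python sketch of that driving loop (`handle_config` and the interval are illustrative placeholders):

```python
# Sketch only: how DynoLogNpuMonitor::Poll() ends up being driven from
# Python via the proxy bound later in this patch (see init.cpp below).
import time

def poll_forever(proxy, interval_s: int = 2):
    while True:
        time.sleep(interval_s)
        cfg = proxy.poll_dyno()   # DynoLogNpuMonitor::Poll() under the hood
        if cfg:
            handle_config(cfg)    # hypothetical consumer of the KEY=VALUE text
```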
diff --git a/torch_npu/csrc/profiler/dyno/DynoLogNpuMonitor.h b/torch_npu/csrc/profiler/dyno/DynoLogNpuMonitor.h
new file mode 100644
index 0000000000..338385418b
--- /dev/null
+++ b/torch_npu/csrc/profiler/dyno/DynoLogNpuMonitor.h
@@ -0,0 +1,20 @@
+#pragma once
+#include "torch_npu/csrc/toolkit/profiler/common/singleton.h"
+#include "MonitorBase.h"
+#include "NpuIpcClient.h"
+namespace torch_npu {
+namespace profiler {
+class DynoLogNpuMonitor : public MonitorBase, public torch_npu::toolkit::profiler::Singleton<DynoLogNpuMonitor> {
+    friend class torch_npu::toolkit::profiler::Singleton<DynoLogNpuMonitor>;
+public:
+    DynoLogNpuMonitor() = default;
+    bool Init() override;
+    std::string Poll() override;
+    void SetNpuId(int id) override { npuId_ = id; }
+private:
+    bool isInitialized_ = false;
+    int32_t npuId_ = 0;
+    IpcClient ipcClient_;
+};
+} // namespace profiler
+} // namespace torch_npu
diff --git a/torch_npu/csrc/profiler/dyno/MonitorBase.h b/torch_npu/csrc/profiler/dyno/MonitorBase.h
new file mode 100644
index 0000000000..0250260c98
--- /dev/null
+++ b/torch_npu/csrc/profiler/dyno/MonitorBase.h
@@ -0,0 +1,12 @@
+#pragma once
+#include <string>
+namespace torch_npu {
+namespace profiler {
+class MonitorBase {
+public:
+    virtual bool Init() = 0;
+    virtual std::string Poll() = 0;
+    virtual void SetNpuId(int id) = 0;
+};
+} // namespace profiler
+} // namespace torch_npu
diff --git a/torch_npu/csrc/profiler/dyno/NpuIpcClient.cpp b/torch_npu/csrc/profiler/dyno/NpuIpcClient.cpp
new file mode 100644
index 0000000000..968b2cc4f4
--- /dev/null
+++ b/torch_npu/csrc/profiler/dyno/NpuIpcClient.cpp
@@ -0,0 +1,134 @@
+#include "NpuIpcClient.h"
+
+namespace torch_npu {
+namespace profiler {
+bool IpcClient::RegisterInstance(int32_t id)
+{
+    NpuContext context{
+        .npu = id,
+        .pid = getpid(),
+        .jobId = JOB_ID,
+    };
+    std::unique_ptr<Message> message = Message::ConstructMessage(context, "ctxt");
+    try {
+        if (!SyncSendMessage(*message, std::string(DYNO_IPC_NAME))) {
+            ASCEND_LOGW("Failed to send register ctxt for pid %d with dyno", context.pid);
+            return false;
+        }
+    } catch (const std::exception &e) {
+        ASCEND_LOGW("Error when SyncSendMessage %s !", e.what());
+        return false;
+    }
+    ASCEND_LOGI("Register pid %d for dynolog success !", context.pid);
+    return true;
+}
+std::string IpcClient::IpcClientNpuConfig()
+{
+    int size = pids_.size();
+    auto *req = (NpuRequest *)malloc(sizeof(NpuRequest) + sizeof(int32_t) * size);
+    req->type = DYNO_IPC_TYPE;
+    req->pidSize = size;
+    req->jobId = JOB_ID;
+    for (int i = 0; i < size; i++) {
+        req->pids[i] = pids_[i];
+    }
+    std::unique_ptr<Message> message = Message::ConstructMessage<NpuRequest, int32_t>(*req, "req", size);
+    if (!SyncSendMessage(*message, std::string(DYNO_IPC_NAME))) {
+        ASCEND_LOGW("Failed to send config to dyno server !");
+        free(req);
+        req = nullptr;
+        return "";
+    }
+    free(req);
+    message = PollRecvMessage(MAX_IPC_RETRIES, MAX_SLEEP_US);
+    if (!message) {
+        ASCEND_LOGW("Failed to receive on-demand config !");
+        return "";
+    }
+    std::string res = std::string((char *)message->buf.get(), message->metadata.size);
+    return res;
+}
+std::unique_ptr<Message> IpcClient::ReceiveMessage()
+{
+    std::lock_guard<std::mutex> wguard(dequeLock_);
+    if (msgDynoDeque_.empty()) {
+        return nullptr;
+    }
+    std::unique_ptr<Message> message = std::move(msgDynoDeque_.front());
+    msgDynoDeque_.pop_front();
+    return message;
+}
+bool IpcClient::SyncSendMessage(const Message &message, const std::string &destName, int numRetry, int sleepTimeUs)
+{
+    if (destName.empty()) {
+        ASCEND_LOGW("Can not send to empty socket name !");
+        return false;
+    }
+    int i = 0;
+    std::vector<NpuPayLoad> npuPayLoad{ NpuPayLoad(sizeof(struct Metadata), (void *)&message.metadata),
+        NpuPayLoad(message.metadata.size, message.buf.get()) };
+    try {
+        auto ctxt = ep_.BuildSendNpuCtxt(destName, npuPayLoad, std::vector<fileDesT>());
+        while (!ep_.TrySendMessage(*ctxt) && i < numRetry) {
+            i++;
+            usleep(sleepTimeUs);
+            sleepTimeUs *= 2;
+        }
+    } catch (const std::exception &e) {
+        ASCEND_LOGW("Error when SyncSendMessage %s !", e.what());
+        return false;
+    }
+    return i < numRetry;
+}
+bool IpcClient::Recv()
+{
+    try {
+        Metadata recvMetadata;
+        std::vector<NpuPayLoad> PeekNpuPayLoad{ NpuPayLoad(sizeof(struct Metadata), &recvMetadata) };
+        auto peekCtxt = ep_.BuildNpuRcvCtxt(PeekNpuPayLoad);
+        bool successFlag = false;
+        try {
+            successFlag = ep_.TryPeekMessage(*peekCtxt);
+        } catch (std::exception &e) {
+            ASCEND_LOGW("Error when TryPeekMessage: %s !", e.what());
+            return false;
+        }
+        if (successFlag) {
+            std::unique_ptr<Message> npuMessage = std::make_unique<Message>(Message());
+            npuMessage->metadata = recvMetadata;
+            npuMessage->buf = std::unique_ptr<unsigned char[]>(new unsigned char[recvMetadata.size]);
+            npuMessage->src = std::string(ep_.GetName(*peekCtxt));
+            std::vector<NpuPayLoad> npuPayLoad{
+                NpuPayLoad(sizeof(struct Metadata), (void *)&npuMessage->metadata),
+                NpuPayLoad(recvMetadata.size, npuMessage->buf.get()) };
+            auto recvCtxt = ep_.BuildNpuRcvCtxt(npuPayLoad);
+            try {
+                successFlag = ep_.TryRcvMessage(*recvCtxt);
+            } catch (std::exception &e) {
+                ASCEND_LOGW("Error when TryRcvMessage: %s !", e.what());
+                return false;
+            }
+            if (successFlag) {
+                std::lock_guard<std::mutex> wguard(dequeLock_);
+                msgDynoDeque_.push_back(std::move(npuMessage));
+                return true;
+            }
+        }
+    } catch (std::exception &e) {
+        ASCEND_LOGW("Error in Recv(): %s !", e.what());
+        return false;
+    }
+    return false;
+}
+std::unique_ptr<Message> IpcClient::PollRecvMessage(int maxRetry, int sleepTimeUs)
+{
+    for (int i = 0; i < maxRetry; i++) {
+        if (Recv()) {
+            return ReceiveMessage();
+        }
+        usleep(sleepTimeUs);
+    }
+    return nullptr;
+}
+} // namespace profiler
+} // namespace torch_npu
\ No newline at end of file
diff --git a/torch_npu/csrc/profiler/dyno/NpuIpcClient.h b/torch_npu/csrc/profiler/dyno/NpuIpcClient.h
new file mode 100644
index 0000000000..42fff6765f
--- /dev/null
+++ b/torch_npu/csrc/profiler/dyno/NpuIpcClient.h
@@ -0,0 +1,96 @@
+#pragma once
+#include <cstring>
+#include <deque>
+#include <memory>
+#include <mutex>
+#include <stdexcept>
+#include <string>
+#include <type_traits>
+#include <vector>
+#include "NpuIpcEndPoint.h"
+#include "utils.h"
+namespace torch_npu {
+namespace profiler {
+constexpr int TYPE_SIZE = 32;
+constexpr int JOB_ID = 0;
+constexpr const char *DYNO_IPC_NAME = "dynolog";
+constexpr const int DYNO_IPC_TYPE = 3;
+constexpr const int MAX_IPC_RETRIES = 5;
+constexpr const int MAX_SLEEP_US = 10000;
+struct NpuRequest {
+    int type;
+    int pidSize;
+    int64_t jobId;
+    int32_t pids[0];
+};
+struct NpuContext {
+    int32_t npu;
+    pid_t pid;
+    int64_t jobId;
+};
+struct Metadata {
+    size_t size = 0;
+    char type[TYPE_SIZE] = "";
+};
+struct Message {
+    Metadata metadata;
+    std::unique_ptr<unsigned char[]> buf;
+    std::string src;
+    template <typename T> static std::unique_ptr<Message> ConstructMessage(const T &data, const std::string &type)
+    {
+        std::unique_ptr<Message> ipcNpuMessage = std::make_unique<Message>(Message());
+        if (type.size() + 1 > sizeof(ipcNpuMessage->metadata.type)) {
+            throw std::runtime_error("Type string is too long to fit in metadata.type");
+        }
+        memcpy(ipcNpuMessage->metadata.type, type.c_str(), type.size() + 1);
+#if __cplusplus >= 201703L
+        if constexpr (std::is_same<T, std::string>::value == true) {
+            ipcNpuMessage->metadata.size = data.size();
+            ipcNpuMessage->buf = std::make_unique<unsigned char[]>(ipcNpuMessage->metadata.size);
+            memcpy(ipcNpuMessage->buf.get(), data.c_str(), data.size());
+            return ipcNpuMessage;
+        }
+#endif
+        static_assert(std::is_trivially_copyable<T>::value);
+        ipcNpuMessage->metadata.size = sizeof(data);
+        ipcNpuMessage->buf = std::make_unique<unsigned char[]>(ipcNpuMessage->metadata.size);
+        memcpy(ipcNpuMessage->buf.get(), &data, sizeof(data));
+        return ipcNpuMessage;
+    }
+
+    template <typename T, typename U>
+    static std::unique_ptr<Message> ConstructMessage(const T &data, const std::string &type, int n)
+    {
+        std::unique_ptr<Message> ipcNpuMessage = std::make_unique<Message>(Message());
+        if (type.size() + 1 > sizeof(ipcNpuMessage->metadata.type)) {
+            throw std::runtime_error("Type string is too long to fit in metadata.type");
+        }
+        memcpy(ipcNpuMessage->metadata.type, type.c_str(), type.size() + 1);
+        static_assert(std::is_trivially_copyable<T>::value);
+        static_assert(std::is_trivially_copyable<U>::value);
+        ipcNpuMessage->metadata.size = sizeof(data) + sizeof(U) * n;
+        ipcNpuMessage->buf = std::make_unique<unsigned char[]>(ipcNpuMessage->metadata.size);
+        memcpy(ipcNpuMessage->buf.get(), &data, ipcNpuMessage->metadata.size);
+        return ipcNpuMessage;
+    }
+};
+class IpcClient {
+public:
+    IpcClient(const IpcClient &) = delete;
+    IpcClient &operator = (const IpcClient &) = delete;
+    IpcClient() = default;
+    bool RegisterInstance(int32_t npu);
+    std::string IpcClientNpuConfig();
+private:
+    std::vector<int32_t> pids_ = GetPids();
+    NpuIpcEndPoint<0> ep_{ "dynoconfigclient" + GenerateUuidV4() };
+    std::mutex dequeLock_;
+    std::deque<std::unique_ptr<Message>> msgDynoDeque_;
+    std::unique_ptr<Message> ReceiveMessage();
+    bool SyncSendMessage(const Message &message, const std::string &destName, int numRetry = 10,
+        int sleepTimeUs = 10000);
+    bool Recv();
+    std::unique_ptr<Message> PollRecvMessage(int maxRetry, int sleepTimeUs);
+};
+} // namespace profiler
+} // namespace torch_npu
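`RegisterInstance` sends a `"ctxt"`-typed message carrying an `NpuContext`. A Python sketch of the same bytes, assuming a 64-bit little-endian host so the field packing mirrors the C++ structs:

```python
# Sketch of what RegisterInstance() puts on the wire: Metadata typed
# "ctxt" plus a packed NpuContext{int32 npu; pid_t pid; int64 jobId}.
import os
import struct

def build_register_datagram(npu_id: int, job_id: int = 0) -> bytes:
    context = struct.pack("@iiq", npu_id, os.getpid(), job_id)  # NpuContext, 16 bytes
    metadata = struct.pack("@Q32s", len(context), b"ctxt")      # Metadata header
    return metadata + context
```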
socket name is empty"); + } + auto ctxt = BuildNpuCtxt_(npuPayLoad, fileDes.size()); + ctxt->msghdr.msg_namelen = SetSocketAdress(desAddrName, ctxt->messageName); + if (!fileDes.empty()) { + if (sizeof(ctxt->fileDesPtr) < fileDes.size() * sizeof(fileDesT)) { + throw std::runtime_error("Memcpy failed when fileDes size large than ctxt fileDesPtr"); + } + memcpy(ctxt->fileDesPtr, fileDes.data(), fileDes.size() * sizeof(fileDesT)); + } + return ctxt; + } + [[nodiscard]] bool TrySendMessage(Ctxt const & ctxt, bool retryOnConnRefused = true) + { + ssize_t retCode = sendmsg(socketFd, &ctxt.msghdr, MSG_DONTWAIT); + if (retCode > 0) { + return true; + } + if ((errno == EAGAIN || errno == EWOULDBLOCK) && retCode == -1) { + return false; + } + if (retryOnConnRefused && errno == ECONNREFUSED && retCode == -1) { + return false; + } + throw std::runtime_error("TrySendMessage occur " + std::string(std::strerror(errno))); + } + [[nodiscard]] auto BuildNpuRcvCtxt(const std::vector &npuPayLoad) + { + return BuildNpuCtxt_(npuPayLoad, MaxNumFileDes); + } + [[nodiscard]] bool TryRcvMessage(Ctxt &ctxt) noexcept + { + size_t retCode = recvmsg(socketFd, &ctxt.msghdr, MSG_DONTWAIT); + if (retCode > 0) { + return true; + } + if (retCode == 0) { + return false; + } + if (errno == EWOULDBLOCK || errno == EAGAIN) { + return false; + } + throw std::runtime_error("TryRcvMessage occur " + std::string(std::strerror(errno))); + } + [[nodiscard]] bool TryPeekMessage(Ctxt &ctxt) + { + ssize_t ret = recvmsg(socketFd, &ctxt.msghdr, MSG_DONTWAIT | MSG_PEEK); + if (ret > 0) { + return true; + } + if (ret == 0) { + return false; + } + if (errno == EAGAIN || errno == EWOULDBLOCK) { + return false; + } + throw std::runtime_error("TryPeekMessage occur " + std::string(std::strerror(errno))); + } + const char *GetName(Ctxt const & ctxt) const noexcept + { + if (ctxt.messageName.sun_path[0] != STR_END_CHAR) { + throw std::runtime_error("GetName() want to got abstract socket, but got " + + std::string(ctxt.messageName.sun_path)); + } + return ctxt.messageName.sun_path + 1; + } + std::vector GetFileDes(const Ctxt &ctxt) const + { + struct cmsghdr *cmg = CMSG_FIRSTHDR(&ctxt.msghdl); + unsigned numFileDes = (cmg->cmsg_len - sizeof(struct cmsghdr)) / sizeof(fileDesT); + return { ctxt.fileDesPtr, ctxt.fileDesPtr + numFileDes }; + } +protected: + fileDesT socketFd; + size_t SetSocketAdress(const std::string &srcSocket, struct sockaddr_un &destSocket) + { + if (srcSocket.size() > addressMaxLen) { + throw std::runtime_error("Abstract UNIX Socket path cannot be larger than addressMaxLen"); + } + destSocket.sun_family = AF_UNIX; + destSocket.sun_path[0] = STR_END_CHAR; + if (srcSocket.empty()) { + return sizeof(sa_family_t); + } + srcSocket.copy(destSocket.sun_path + 1, srcSocket.size()); + destSocket.sun_path[srcSocket.size() + 1] = STR_END_CHAR; + return sizeof(sa_family_t) + srcSocket.size() + 2; + } + auto BuildNpuCtxt_(const std::vector &npuPayLoad, unsigned numFileDes) + { + auto ctxt = std::make_unique(npuPayLoad.size()); + std::memset(&ctxt->msghdr, 0, sizeof(ctxt->msghdr)); + for (int i = 0; i < npuPayLoad.size(); i++) { + ctxt->iov[i] = {npuPayLoad[i].data, npuPayLoad[i].size}; + } + ctxt->msghdr.msg_name = &ctxt->messageName; + ctxt->msghdr.msg_namelen = sizeof(decltype(ctxt->messageName)); + ctxt->msghdr.msg_iov = ctxt->iov.data(); + ctxt->msghdr.msg_iovlen = npuPayLoad.size(); + ctxt->fileDesPtr = nullptr; + if (numFileDes == 0) { + return ctxt; + } + const size_t fileDesSize = sizeof(fileDesT) * numFileDes; + 
diff --git a/torch_npu/csrc/profiler/dyno/PyDynamicMonitorProxy.h b/torch_npu/csrc/profiler/dyno/PyDynamicMonitorProxy.h
new file mode 100644
index 0000000000..e08b72640a
--- /dev/null
+++ b/torch_npu/csrc/profiler/dyno/PyDynamicMonitorProxy.h
@@ -0,0 +1,30 @@
+#pragma once
+#include "MonitorBase.h"
+#include "DynoLogNpuMonitor.h"
+namespace torch_npu {
+namespace profiler {
+class PyDynamicMonitorProxy {
+public:
+    PyDynamicMonitorProxy() = default;
+    bool InitDyno(int npuId)
+    {
+        try {
+            monitor_ = DynoLogNpuMonitor::GetInstance();
+            monitor_->SetNpuId(npuId);
+            bool res = monitor_->Init();
+            return res;
+        } catch (const std::exception &e) {
+            ASCEND_LOGE("Error when init dyno %s !", e.what());
+            return false;
+        }
+    }
+    std::string PollDyno()
+    {
+        return monitor_->Poll();
+    }
+
+private:
+    MonitorBase *monitor_ = nullptr;
+};
+} // namespace profiler
+} // namespace torch_npu
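This proxy is what the pybind11 registration in init.cpp (below) exposes to Python as `init_dyno`/`poll_dyno`. Minimal usage, assuming torch_npu is built with this patch:

```python
# Minimal sketch of the binding added in init.cpp below.
from torch_npu._C._profiler import PyDynamicMonitorProxy

proxy = PyDynamicMonitorProxy()
if proxy.init_dyno(0):        # npuId=0; registers this process with dynolog
    print(proxy.poll_dyno())  # "KEY=VALUE\n..." once a trace request is pending
```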
diff --git a/torch_npu/csrc/profiler/dyno/utils.h b/torch_npu/csrc/profiler/dyno/utils.h
new file mode 100644
index 0000000000..f95624f184
--- /dev/null
+++ b/torch_npu/csrc/profiler/dyno/utils.h
@@ -0,0 +1,99 @@
+//
+// Created by liyou on 2024/12/3.
+//
+#pragma once
+#include <cstdint>
+#include <cstdio>
+#include <fstream>
+#include <random>
+#include <sstream>
+#include <string>
+#include <sys/types.h>
+#include <unistd.h>
+#include <utility>
+#include <vector>
+#include "torch_npu/csrc/core/npu/npu_log.h"
+namespace torch_npu {
+namespace profiler {
+inline int32_t GetProcessId()
+{
+    int32_t pid = 0;
+    pid = static_cast<int32_t>(getpid());
+    return pid;
+}
+
+inline std::pair<int32_t, std::string> GetParentPidAndCommand(int32_t pid)
+{
+    std::string fileName = "/proc/" + std::to_string(pid) + "/stat";
+    std::ifstream statFile(fileName);
+    if (!statFile) {
+        return std::make_pair(0, "");
+    }
+    int32_t parentPid = 0;
+    char command[256] = {0};
+    std::string line;
+    if (std::getline(statFile, line)) {
+        sscanf(line.c_str(), "%*d (%255[^)]) %*c %d", command, &parentPid);
+        ASCEND_LOGI("Successfully got parent pid %d", parentPid);
+        return std::make_pair(parentPid, std::string(command));
+    }
+    ASCEND_LOGW("Failed to parse /proc/%d/stat", pid);
+    return std::make_pair(0, "");
+}
+
+constexpr int MaxParentPids = 5;
+inline std::vector<std::pair<int32_t, std::string>> GetPidCommandPairsofAncestors()
+{
+    std::vector<std::pair<int32_t, std::string>> process_pids_and_cmds;
+    process_pids_and_cmds.reserve(MaxParentPids + 1);
+    int32_t current_pid = GetProcessId();
+    for (int i = 0; i <= MaxParentPids && (i == 0 || current_pid > 1); i++) {
+        std::pair<int32_t, std::string> parent_pid_and_cmd = GetParentPidAndCommand(current_pid);
+        process_pids_and_cmds.push_back(std::make_pair(current_pid, parent_pid_and_cmd.second));
+        current_pid = parent_pid_and_cmd.first;
+    }
+    return process_pids_and_cmds;
+}
+
+inline std::vector<int32_t> GetPids()
+{
+    const auto &pids = GetPidCommandPairsofAncestors();
+    std::vector<int32_t> res;
+    res.reserve(pids.size());
+    for (const auto &pidPair : pids) {
+        res.push_back(pidPair.first);
+    }
+    return res;
+}
+inline std::string GenerateUuidV4()
+{
+    static std::random_device randomDevice;
+    static std::mt19937 gen(randomDevice());
+    static std::uniform_int_distribution<> dis(0, 15);
+    static std::uniform_int_distribution<> dis2(8, 11);
+
+    std::stringstream stringStream;
+    stringStream << std::hex;
+    for (int i = 0; i < 8; i++) {
+        stringStream << dis(gen);
+    }
+    stringStream << "-";
+    for (int j = 0; j < 4; j++) {
+        stringStream << dis(gen);
+    }
+    stringStream << "-4";
+    for (int k = 0; k < 3; k++) {
+        stringStream << dis(gen);
+    }
+    stringStream << "-";
+    stringStream << dis2(gen);
+    for (int m = 0; m < 3; m++) {
+        stringStream << dis(gen);
+    }
+    stringStream << "-";
+    for (int n = 0; n < 12; n++) {
+        stringStream << dis(gen);
+    }
+    return stringStream.str();
+}
+} // namespace profiler
+} // namespace torch_npu
diff --git a/torch_npu/csrc/profiler/init.cpp b/torch_npu/csrc/profiler/init.cpp
index 94ad31d5a6..4ae750c6b2 100644
--- a/torch_npu/csrc/profiler/init.cpp
+++ b/torch_npu/csrc/profiler/init.cpp
@@ -18,6 +18,7 @@
 #include "torch_npu/csrc/toolkit/profiler/common/utils.h"
 #include "torch_npu/csrc/framework/interface/LibAscendHal.h"
 #include "torch_npu/csrc/core/npu/NPUException.h"
+#include "torch_npu/csrc/profiler/dyno/PyDynamicMonitorProxy.h"
 
 namespace torch_npu {
 namespace profiler {
@@ -64,7 +65,10 @@ PyObject* profiler_initExtension(PyObject* _unused, PyObject *unused) {
     py::class_<NpuProfilerConfig>(m, "NpuProfilerConfig")
         .def(py::init());
-
+    py::class_<PyDynamicMonitorProxy>(m, "PyDynamicMonitorProxy")
+        .def(py::init<>())
+        .def("init_dyno", &PyDynamicMonitorProxy::InitDyno, py::arg("npuId"))
+        .def("poll_dyno", &PyDynamicMonitorProxy::PollDyno);
     m.def("_supported_npu_activities", []() {
       std::set<NpuActivityType> activities {
           NpuActivityType::CPU,
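The config payload that `poll_dyno` returns is plain `KEY=VALUE` text, which the Python side (next file) translates into the same dict shape as the JSON config file. A sketch with illustrative values, mirroring `DynamicProfilerUtils.dyno_str_to_json` further down:

```python
# Sketch of the dynolog KEY=VALUE config format consumed by ConfigContext
# below; keys are the ones this patch reads, values are illustrative.
dyno_response = (
    "ACTIVITIES_LOG_FILE=./prof_out\n"
    "PROFILE_START_ITERATION_ROUNDUP=10\n"
    "ACTIVITIES_ITERATIONS=2\n"
    "PROFILE_WITH_STACK=true"
)

def dyno_str_to_dict(res: str) -> dict:
    out = {}
    for pair in res.split("\n"):
        kv = pair.split("=")
        if len(kv) == 2:
            out[kv[0]] = kv[1]
    return out  # {'ACTIVITIES_LOG_FILE': './prof_out', ...}
```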
diff --git a/torch_npu/profiler/_dynamic_profiler/_dynamic_profiler_config_context.py b/torch_npu/profiler/_dynamic_profiler/_dynamic_profiler_config_context.py
index 6b335c2ecd..0cb049d5db 100644
--- a/torch_npu/profiler/_dynamic_profiler/_dynamic_profiler_config_context.py
+++ b/torch_npu/profiler/_dynamic_profiler/_dynamic_profiler_config_context.py
@@ -1,12 +1,14 @@
 import json
 from torch_npu._C._profiler import ProfilerActivity
 from ..experimental_config import _ExperimentalConfig, ProfilerLevel, AiCMetrics
-from ._dynamic_profiler_utils import logger, _get_rank_id
+from ._dynamic_profiler_utils import DynamicProfilerUtils
 
 
 class ConfigContext:
     DEFAULT_ACTIVE_NUM = 1
     DEFAULT_START_STEP = 0
+    DEFAULT_PROF_DIR = "./"
+    BOOL_MAP = {'true': True, 'false': False}
 
     def __init__(self, json_data: dict):
         self.activity_set = set()
@@ -18,14 +20,15 @@ def __init__(self, json_data: dict):
         self.with_flops = False
         self.with_modules = False
         self.is_rank = False
-        self.rank_set = set()
+        self._rank_set = set()
         self.experimental_config = None
         self._active = 1
         self._start_step = 0
         self.is_valid = False
         self._meta_data = {}
-        self._rank_id = _get_rank_id()
         self._async_mode = False
+        self._is_dyno = DynamicProfilerUtils.is_dyno_model()
+        self._rank_id = DynamicProfilerUtils.get_rank_id()
         self.parse(json_data)
 
     def parse(self, json_data: dict):
@@ -36,22 +39,17 @@ def parse(self, json_data: dict):
             activity = getattr(ProfilerActivity, entry.upper(), None)
             if activity:
                 self.activity_set.add(activity)
-        self.prof_path = json_data.get('prof_dir')
+        self._parse_prof_dir(json_data)
         self._meta_data = json_data.get('metadata', {})
         self._analyse = json_data.get('analyse', False)
         self._async_mode = json_data.get('async_mode', False)
-        self.record_shapes = json_data.get('record_shapes', False)
-        self.profile_memory = json_data.get('profile_memory', False)
-        self.with_stack = json_data.get('with_stack', False)
-        self.with_flops = json_data.get('with_flops', False)
-        self.with_modules = json_data.get('with_modules', False)
-        self._active = json_data.get('active', self.DEFAULT_ACTIVE_NUM)
-        self._start_step = json_data.get("start_step", self.DEFAULT_START_STEP)
-        if not isinstance(self._start_step, int) or self._start_step < 0:
-            logger.info(f"Start step is not valid, will be reset to {self.DEFAULT_START_STEP}.")
-            self._start_step = self.DEFAULT_START_STEP
-        else:
-            logger.info(f"Start step will be set to {self._start_step}.")
+        self._parse_record_shapes(json_data)
+        self._parse_profiler_memory(json_data)
+        self._parse_with_flops(json_data)
+        self._parse_with_stack(json_data)
+        self._parse_with_modules(json_data)
+        self._parse_active(json_data)
+        self._parse_start_step(json_data)
         exp_config = json_data.get('experimental_config')
         if not exp_config:
             self.experimental_config = None
@@ -78,36 +76,122 @@ def parse(self, json_data: dict):
             export_type=export_type,
             msprof_tx=msprof_tx
         )
-        self.parse_ranks(json_data)
+        self._parse_ranks(json_data)
 
-    def parse_ranks(self, json_data: dict):
+    def _parse_start_step(self, json_data: dict):
+        if not self._is_dyno:
+            self._start_step = json_data.get("start_step", self.DEFAULT_START_STEP)
+        else:
+            start_step = json_data.get("PROFILE_START_ITERATION_ROUNDUP", self.DEFAULT_START_STEP)
+            try:
+                self._start_step = int(start_step)
+            except ValueError:
+                self._start_step = self.DEFAULT_START_STEP
+
+        if not isinstance(self._start_step, int) or self._start_step < 0:
+            DynamicProfilerUtils.out_log("Start step is not valid, will be reset to {}.".format(
+                self.DEFAULT_START_STEP), DynamicProfilerUtils.LoggerLevelEnum.INFO)
+            self._start_step = self.DEFAULT_START_STEP
+        DynamicProfilerUtils.out_log("Start step will be set to {}.".format(
+            self._start_step), DynamicProfilerUtils.LoggerLevelEnum.INFO)
+
+    def _parse_prof_dir(self, json_data: dict):
+        if not self._is_dyno:
+            self.prof_path = json_data.get('prof_dir', self.DEFAULT_PROF_DIR)
+        else:
+            self.prof_path = json_data.get("ACTIVITIES_LOG_FILE", self.DEFAULT_PROF_DIR)
+
+    def _parse_active(self, json_data: dict):
+        if not self._is_dyno:
+            self._active = json_data.get("active", self.DEFAULT_ACTIVE_NUM)
+        else:
+            active = json_data.get("ACTIVITIES_ITERATIONS", self.DEFAULT_ACTIVE_NUM)
+            try:
+                self._active = int(active)
+            except ValueError:
+                self._active = self.DEFAULT_ACTIVE_NUM
+
+    def _parse_with_stack(self, json_data: dict):
+        if not self._is_dyno:
+            self.with_stack = json_data.get('with_stack', False)
+        else:
+            with_stack = json_data.get("PROFILE_WITH_STACK")
+            if isinstance(with_stack, str):
+                self.with_stack = self.BOOL_MAP.get(with_stack.lower(), False)
+            else:
+                self.with_stack = False
+
+    def _parse_record_shapes(self, json_data: dict):
+        if not self._is_dyno:
+            self.record_shapes = json_data.get('record_shapes', False)
+        else:
+            record_shapes = json_data.get("PROFILE_REPORT_INPUT_SHAPES")
+            if isinstance(record_shapes, str):
+                self.record_shapes = self.BOOL_MAP.get(record_shapes.lower(), False)
+            else:
+                self.record_shapes = False
+
+    def _parse_profiler_memory(self, json_data: dict):
+        if not self._is_dyno:
+            self.profile_memory = json_data.get('profile_memory', False)
+        else:
+            profile_memory = json_data.get("PROFILE_PROFILE_MEMORY")
+            if isinstance(profile_memory, str):
+                self.profile_memory = self.BOOL_MAP.get(profile_memory.lower(), False)
+            else:
+                self.profile_memory = False
+
+    def _parse_with_flops(self, json_data: dict):
+        if not self._is_dyno:
+            self.with_flops = json_data.get('with_flops', False)
+        else:
+            with_flops = json_data.get("PROFILE_WITH_FLOPS")
+            if isinstance(with_flops, str):
+                self.with_flops = self.BOOL_MAP.get(with_flops.lower(), False)
+            else:
+                self.with_flops = False
+
+    def _parse_with_modules(self, json_data: dict):
+        if not self._is_dyno:
+            self.with_modules = json_data.get('with_modules', False)
+        else:
+            with_modules = json_data.get("PROFILE_WITH_MODULES")
+            if isinstance(with_modules, str):
+                self.with_modules = self.BOOL_MAP.get(with_modules.lower(), False)
+            else:
+                self.with_modules = False
+
+    def _parse_ranks(self, json_data: dict):
         self.is_rank = json_data.get("is_rank", False)
         if not isinstance(self.is_rank, bool):
             self.is_rank = False
-            logger.warning("Set is_rank failed, is_rank must be bool!")
+            DynamicProfilerUtils.out_log("Set is_rank failed, is_rank must be bool!",
+                                         DynamicProfilerUtils.LoggerLevelEnum.WARNING)
             return
         if not self.is_rank:
             return
-        logger.info("Set is_rank success!")
+        DynamicProfilerUtils.out_log("Set is_rank success!", DynamicProfilerUtils.LoggerLevelEnum.INFO)
         ranks = json_data.get("rank_list", False)
         if not isinstance(ranks, list):
-            logger.warning("Set rank_list failed, rank_list must be list!")
+            DynamicProfilerUtils.out_log("Set rank_list failed, rank_list must be list!",
+                                         DynamicProfilerUtils.LoggerLevelEnum.WARNING)
             return
         for rank in ranks:
             if isinstance(rank, int) and rank >= 0:
-                self.rank_set.add(rank)
+                self._rank_set.add(rank)
 
     def valid(self) -> bool:
         if not self.is_valid:
             return False
         if not self.is_rank:
             return True
-        if self._rank_id in self.rank_set:
+        if self._rank_id in self._rank_set:
             self._analyse = False
-            logger.info("Rank {} is in valid rank_list {}, profiler data analyse will be closed !".format(
-                self._rank_id, self.rank_set))
+            DynamicProfilerUtils.out_log("Rank {} is in rank_list {}, profiler data analyse will be closed !".format(
+                self._rank_id, self._rank_set), DynamicProfilerUtils.LoggerLevelEnum.INFO)
             return True
-        logger.warning("Rank {} not in valid rank_list {}!".format(self._rank_id, self.rank_set))
+        DynamicProfilerUtils.out_log("Rank {} is not in valid rank_list {}!".format(self._rank_id, self._rank_set),
+                                     DynamicProfilerUtils.LoggerLevelEnum.WARNING)
         return False
 
     def meta_data(self):
@@ -148,7 +232,8 @@ def with_modules(self) -> bool:
 
     def active(self) -> int:
         if not isinstance(self._active, int) or self._active <= 0:
-            logger.warning("Invalid parameter active, reset it to 1.")
+            DynamicProfilerUtils.out_log("Invalid parameter active, reset it to 1.",
+                                         DynamicProfilerUtils.LoggerLevelEnum.WARNING)
             return self.DEFAULT_ACTIVE_NUM
         return self._active
 
@@ -169,6 +254,3 @@ def bytes_to_profiler_cfg_json(bytes_shm: bytes) -> dict:
         cfg_json_str = bytes_shm.decode("utf-8")
         cfg_json = json.loads(cfg_json_str)
         return cfg_json
-
-
-
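A quick illustration of the `BOOL_MAP` convention the parsers above share: dyno flags arrive as strings, and anything other than a case-insensitive `"true"`/`"false"` falls back to `False`:

```python
# Illustrative check of the BOOL_MAP fallback behaviour used above.
BOOL_MAP = {'true': True, 'false': False}

for raw in ("true", "True", "FALSE", "1", None):
    parsed = BOOL_MAP.get(raw.lower(), False) if isinstance(raw, str) else False
    print(raw, "->", parsed)  # "1" and None both map to False
```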
diff --git a/torch_npu/profiler/_dynamic_profiler/_dynamic_profiler_monitor.py b/torch_npu/profiler/_dynamic_profiler/_dynamic_profiler_monitor.py
index c0703d517f..1cc58c1a23 100644
--- a/torch_npu/profiler/_dynamic_profiler/_dynamic_profiler_monitor.py
+++ b/torch_npu/profiler/_dynamic_profiler/_dynamic_profiler_monitor.py
@@ -5,44 +5,45 @@
 import json
 import struct
 import multiprocessing
-
-from ._dynamic_profiler_utils import logger, logger_monitor, init_logger, _get_rank_id
+from torch_npu._C._profiler import PyDynamicMonitorProxy
 from ._dynamic_profiler_config_context import ConfigContext
+from ._dynamic_profiler_utils import DynamicProfilerUtils
 from ._dynamic_profiler_monitor_shm import DynamicProfilerShareMemory
 
 
 class DynamicProfilerMonitor:
     def __init__(
-            self,
-            path: str,
-            buffer_size: int = 1024,
-            poll_interval: int = 2
+            self
     ):
-        self._path = path
-        self._rank_id = _get_rank_id()
-        self._buffer_size = buffer_size
+        self._path = DynamicProfilerUtils.CFG_CONFIG_PATH
+        self._rank_id = DynamicProfilerUtils.get_rank_id()
+        self._buffer_size = DynamicProfilerUtils.CFG_BUFFER_SIZE
         self._monitor_process = None
         self.prof_cfg_context = None
         self._shared_loop_flag = multiprocessing.Value('b', True)
-        self._step_time = multiprocessing.Value('i', poll_interval)
-        self._config_path = os.path.join(self._path, 'profiler_config.json')
+        self._step_time = multiprocessing.Value('i', DynamicProfilerUtils.POLL_INTERVAL)
+        self._config_path = None
+        self._is_dyno = DynamicProfilerUtils.is_dyno_model()
+        if not self._is_dyno:
+            self._config_path = os.path.join(self._path, 'profiler_config.json')
         self._shm_obj = DynamicProfilerShareMemory(
             self._path,
            self._config_path,
-            self._rank_id,
-            self._buffer_size)
+            self._rank_id)
         self._cur_time = int(time.time())
         self._create_process()
 
     def shm_to_prof_conf_context(self):
         if self._shm_obj is None:
-            logger.warning('Rank %d shared memory is None !', self._rank_id)
+            DynamicProfilerUtils.out_log('Rank {} shared memory is None !'.format(
+                self._rank_id), DynamicProfilerUtils.LoggerLevelEnum.ERROR)
             return None
         try:
            time_bytes_data = self._shm_obj.read_bytes(read_time=True)
            shm_cfg_change_time = struct.unpack("<I", time_bytes_data)[0]

             if len(prof_cfg_bytes) > max_size:
-                logger_monitor.warning("Dynamic profiler process load json failed, "
-                                       "because cfg bytes size over %d bytes", max_size)
+                dynamic_profiler_utils.out_log("Load json failed, because cfg bytes size is over {} bytes".format(
+                    max_size), dynamic_profiler_utils.LoggerLevelEnum.WARNING, is_monitor_process=True)
                 continue
             try:
                 if is_mmap and mmap is not None:
@@ -151,8 +157,51 @@ def worker_func(params_dict):
                 elif shm is not None:
                     shm.write_bytes_over_py38(prof_cfg_bytes)
             except Exception as ex:
-                logger_monitor.warning("Dynamic profiler cfg bytes write failed, %s has occur!", str(ex))
+                dynamic_profiler_utils.out_log("Dynamic profiler cfg bytes write failed, {} has occurred!".format(
+                    str(ex)), dynamic_profiler_utils.LoggerLevelEnum.ERROR, is_monitor_process=True)
         else:
-            logger_monitor.error("Dynamic profiler cfg json not exists")
+            dynamic_profiler_utils.out_log("Dynamic profiler cfg json does not exist",
+                                           dynamic_profiler_utils.LoggerLevelEnum.ERROR, is_monitor_process=True)
         time.sleep(poll_interval.value)
-    logger_monitor.info("Dynamic profiler process done")
+    dynamic_profiler_utils.out_log("Dynamic profiler process done", dynamic_profiler_utils.LoggerLevelEnum.INFO,
+                                   is_monitor_process=True)
+
+
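The dyno worker below runs in the monitor child process and receives all of its dependencies through a single `params_dict`. A sketch of the wiring, which assumes `_create_process` follows the same pattern as the existing JSON-file path (the helper name is hypothetical):

```python
# Sketch (assumed wiring): how worker_dyno_func below would be launched.
import multiprocessing

def _spawn_dyno_worker(shm_obj, loop_flag, poll_interval, rank_id):
    process = multiprocessing.Process(target=worker_dyno_func, args=({
        "loop_flag": loop_flag,          # multiprocessing.Value('b', True)
        "poll_interval": poll_interval,  # multiprocessing.Value('i', 2)
        "shm": shm_obj.shm,              # SharedMemory the worker writes into
        "rank_id": rank_id,
        "max_size": DynamicProfilerUtils.CFG_BUFFER_SIZE,
        "dynamic_profiler_utils": DynamicProfilerUtils,
    },))
    process.daemon = True
    process.start()
    return process
```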
+def worker_dyno_func(params_dict):
+    """ Dyno monitor process worker function, python version >= 3.8"""
+    loop_flag = params_dict.get("loop_flag")
+    poll_interval = params_dict.get("poll_interval")
+    shm = params_dict.get("shm")
+    rank_id = params_dict.get("rank_id")
+    max_size = params_dict.get("max_size")
+    dynamic_profiler_utils = params_dict.get("dynamic_profiler_utils")
+
+    py_dyno_monitor = PyDynamicMonitorProxy()
+    ret = py_dyno_monitor.init_dyno(rank_id)
+    if not ret:
+        dynamic_profiler_utils.out_log("Init dynolog failed !", dynamic_profiler_utils.LoggerLevelEnum.WARNING)
+        return
+    dynamic_profiler_utils.out_log("Init dynolog success !", dynamic_profiler_utils.LoggerLevelEnum.INFO)
+    while loop_flag.value:
+        time.sleep(poll_interval.value)
+        res = py_dyno_monitor.poll_dyno()
+        data = DynamicProfilerUtils.dyno_str_to_json(res)
+        if data:
+            data['is_valid'] = True
+            dynamic_profiler_utils.out_log("Dynolog profiler process load json success",
+                                           dynamic_profiler_utils.LoggerLevelEnum.INFO)
+        else:
+            continue
+        time_bytes = struct.pack("<I", int(time.time()))
+        json_bytes = bytes(json.dumps(data), encoding="utf-8")
+        prof_cfg_bytes = time_bytes + json_bytes
+        if len(prof_cfg_bytes) > max_size:
+            dynamic_profiler_utils.out_log("Load json failed, because cfg bytes size is over {} bytes".format(
+                max_size), dynamic_profiler_utils.LoggerLevelEnum.WARNING)
+            continue
+        try:
+            if shm is not None:
+                shm.write_bytes_over_py38(prof_cfg_bytes)
+        except Exception as ex:
+            dynamic_profiler_utils.out_log("Dynamic profiler cfg bytes write failed, {} has occurred!".format(str(ex)),
+                                           dynamic_profiler_utils.LoggerLevelEnum.ERROR)
+    dynamic_profiler_utils.out_log("Dynolog profiler process done", dynamic_profiler_utils.LoggerLevelEnum.INFO)
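Both workers publish configs through the same shared-memory record: a 4-byte little-endian change timestamp followed by the UTF-8 config JSON (the `"<I"` prefix format is an assumption consistent with `read_bytes(read_time=True)` and the `struct.pack` call above). A standalone sketch:

```python
# Sketch of the shared-memory record written by the workers and read back
# by shm_to_prof_conf_context: 4-byte LE change time + config JSON bytes.
import json
import struct
import time

def encode_cfg(cfg: dict) -> bytes:
    return struct.pack("<I", int(time.time())) + json.dumps(cfg).encode("utf-8")

def decode_cfg(buf: bytes):
    change_time = struct.unpack("<I", buf[:4])[0]
    return change_time, json.loads(buf[4:].decode("utf-8"))
```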
diff --git a/torch_npu/profiler/_dynamic_profiler/_dynamic_profiler_monitor_shm.py b/torch_npu/profiler/_dynamic_profiler/_dynamic_profiler_monitor_shm.py
index 6c9b817537..93e0b0ab74 100644
--- a/torch_npu/profiler/_dynamic_profiler/_dynamic_profiler_monitor_shm.py
+++ b/torch_npu/profiler/_dynamic_profiler/_dynamic_profiler_monitor_shm.py
@@ -11,7 +11,7 @@
 from ...utils._path_manager import PathManager
 from ...utils._error_code import ErrCode, prof_error
 from ..analysis.prof_common_func._file_manager import FileManager
-from ._dynamic_profiler_utils import logger
+from ._dynamic_profiler_utils import DynamicProfilerUtils
 
 
 class DynamicProfilerShareMemory:
@@ -46,13 +46,13 @@ def __init__(
             path: str,
             config_path: str,
             rank_id: int,
-            buffer_size: int = 1024
     ):
         self._path = path
         self.config_path = config_path
         self._rank_id = rank_id
         self.shm_path = f"DynamicProfileNpuShm{datetime.utcnow().strftime('%Y%m%d%H')}"
-        self._shm_buf_bytes_size = buffer_size
+        self._shm_buf_bytes_size = DynamicProfilerUtils.CFG_BUFFER_SIZE
+        self._is_dyno = DynamicProfilerUtils.is_dyno_model()
         self.is_create_process = False
         self.shm = None
         self.cur_mtime = 0
@@ -61,11 +61,10 @@ def __init__(
         self._time_bytes_size = len(self._time_data_bytes)
         self._clean_shm_for_killed()
         self._create_shm()
-        if self.is_create_process:
+        if self.is_create_process and not self._is_dyno:
             self._create_prof_cfg()
 
-    @staticmethod
-    def _get_pid_st_ctime(pid):
+    def _get_pid_st_ctime(self, pid):
         try:
             fd = os.open("/proc/" + str(pid), os.O_RDONLY, stat.S_IRUSR | stat.S_IRGRP)
             stat_ino = os.fstat(fd)
@@ -73,7 +72,8 @@ def _get_pid_st_ctime(self, pid):
             create_time = stat_ino.st_ctime
             return create_time
         except Exception as ex:
-            logger.warning("An error is occurred: %s", str(ex))
+            DynamicProfilerUtils.out_log("An error occurred: {}".format(
+                str(ex)), DynamicProfilerUtils.LoggerLevelEnum.ERROR)
             return None
 
     def _clean_shm_for_killed(self):
@@ -93,7 +93,8 @@ def _clean_shm_for_killed(self):
 
     def _create_prof_cfg(self):
         if not os.path.exists(self.config_path):
-            logger.info("Create profiler_config.json default.")
+            DynamicProfilerUtils.out_log("Create profiler_config.json default.",
+                                         DynamicProfilerUtils.LoggerLevelEnum.INFO)
             FileManager.create_json_file_by_path(
                 self.config_path,
                 self.JSON_DATA,
@@ -130,7 +131,8 @@ def _create_shm_over_py38(self):
                                        lambda *args, **kwargs: None):
                     self.shm = shared_memory.SharedMemory(name=self.shm_path)
                 self.is_create_process = False
-                logger.info("Rank %d shared memory is connected.", self._rank_id)
+                DynamicProfilerUtils.out_log("Rank {} shared memory is connected.".format(
+                    self._rank_id), DynamicProfilerUtils.LoggerLevelEnum.INFO)
                 break
             except FileNotFoundError:
                 try:
@@ -140,16 +142,18 @@ def _create_shm_over_py38(self):
                     self.is_create_process = True
                     bytes_data = self._get_default_cfg_bytes()
                     self.shm.buf[:self._shm_buf_bytes_size] = bytes_data
-                    logger.info("Rank %d shared memory is created.", self._rank_id)
+                    DynamicProfilerUtils.out_log("Rank {} shared memory is created.".format(
+                        self._rank_id), DynamicProfilerUtils.LoggerLevelEnum.INFO)
                     break
                 except Exception as ex:  # other process will go to step 1 and open shm file
                     try_times -= 1
-                    logger.warning("Rank %d shared memory create failed, retry times = %d, %s has occur.",
-                                   self._rank_id, try_times, str(ex))
+                    DynamicProfilerUtils.out_log("Rank {} shared memory create failed, "
+                                                 "retry times = {}, {} has occurred.".format(
+                        self._rank_id, try_times, str(ex)), DynamicProfilerUtils.LoggerLevelEnum.ERROR)
                     time.sleep(random.uniform(0, 0.02))  # sleep 0 ~ 20 ms
         if try_times <= 0:
-            raise RuntimeError("Failed to create shared memory.")
+            raise RuntimeError("Failed to create shared memory." + prof_error(ErrCode.VALUE))
 
     def clean_resource(self):
         if sys.version_info >= (3, 8):
@@ -165,9 +169,11 @@ def _clean_shm_over_py38(self):
         try:
             self.shm.close()
             self.shm.unlink()
-            logger.info("Rank %s unlink shm", self._rank_id)
+            DynamicProfilerUtils.out_log("Rank {} unlink shm".format(
+                self._rank_id), DynamicProfilerUtils.LoggerLevelEnum.INFO)
         except Exception as ex:
-            logger.warning("Rank %s unlink shm failed, may be removed, %s hs occur", self._rank_id, str(ex))
+            DynamicProfilerUtils.out_log("Rank {} unlink shm failed, may already be removed, {} has occurred".format(
+                self._rank_id, str(ex)), DynamicProfilerUtils.LoggerLevelEnum.ERROR)
         self.shm = None
 
     def _clean_shm_py37(self):
@@ -179,16 +185,18 @@ def _clean_shm_py37(self):
                 self._memory_mapped_file.close()
             elif self.fd:
                 os.close(self.fd)
-            logger.info("Rank %s unlink shm", self._rank_id)
+            DynamicProfilerUtils.out_log("Rank {} unlink shm".format(
+                self._rank_id), DynamicProfilerUtils.LoggerLevelEnum.INFO)
         except Exception as ex:
-            logger.warning("Rank %s unlink shm failed, may be removed, %s has occur ", self._rank_id, str(ex))
+            DynamicProfilerUtils.out_log("Rank {} unlink shm failed, may already be removed, {} has occurred".format(
+                self._rank_id, str(ex)), DynamicProfilerUtils.LoggerLevelEnum.ERROR)
         PathManager.remove_path_safety(os.path.dirname(self.shm_path))
         self.shm = None
 
     def _create_shm_py37(self):
         """Create a json monitor process based on whether the SharedMemory is successfully created py37"""
-        logger.warning("Dynamic profiler is not work well on python 3.7x, "
-                       "please update to python 3.8+ for better performance.")
+        DynamicProfilerUtils.out_log("Dynamic profiler does not work well on Python 3.7.x, "
+                                     "please update to Python 3.8+ for better performance.",
+                                     DynamicProfilerUtils.LoggerLevelEnum.INFO)
         try_times = 10
         while try_times:
             try:
@@ -197,7 +205,8 @@ def _create_shm_py37(self):
                 self._memory_mapped_file = os.fdopen(self.fd, 'rb')
                 self.shm = mmap.mmap(self._memory_mapped_file.fileno(), length=self._shm_buf_bytes_size)
                 self.is_create_process = False
-                logger.info("Rank %d shared memory is connected.", self._rank_id)
+                DynamicProfilerUtils.out_log("Rank {} shared memory is connected.".format(
+                    self._rank_id), DynamicProfilerUtils.LoggerLevelEnum.INFO)
                 break
             except ValueError:
                 time.sleep(0.02)
@@ -219,17 +228,19 @@ def _create_shm_py37(self):
                     self._memory_mapped_file = os.fdopen(self.fd, 'rb')
                     self.shm = mmap.mmap(self._memory_mapped_file.fileno(), length=self._shm_buf_bytes_size)
                     self.is_create_process = True
-                    logger.info("Rank %d shared memory is created.", self._rank_id)
+                    DynamicProfilerUtils.out_log("Rank {} shared memory is created.".format(
+                        self._rank_id), DynamicProfilerUtils.LoggerLevelEnum.INFO)
                     break
                 except Exception as ex:  # other process will go to step 1 and open shm file
                     try_times -= 1
-                    logger.warning("Rank %d shared memory create failed, retry times = %d, %s has occur .",
-                                   self._rank_id, try_times, str(ex))
+                    DynamicProfilerUtils.out_log("Rank {} shared memory create failed, "
+                                                 "retry times = {}, {} has occurred.".format(
+                        self._rank_id, try_times, str(ex)), DynamicProfilerUtils.LoggerLevelEnum.ERROR)
                     time.sleep(random.uniform(0, 0.02))  # sleep 0 ~ 20 ms
         if try_times <= 0:
-            raise RuntimeError("Failed to create shared memory.")
+            raise RuntimeError("Failed to create shared memory." + prof_error(ErrCode.VALUE))
 
     def read_bytes(self, read_time=False):
         """Read bytes from shared memory"""
diff --git a/torch_npu/profiler/_dynamic_profiler/_dynamic_profiler_utils.py b/torch_npu/profiler/_dynamic_profiler/_dynamic_profiler_utils.py
index a3a348124f..01ff72940c 100644
--- a/torch_npu/profiler/_dynamic_profiler/_dynamic_profiler_utils.py
+++ b/torch_npu/profiler/_dynamic_profiler/_dynamic_profiler_utils.py
@@ -1,53 +1,112 @@
 import os
 import socket
 import logging
+from enum import Enum
 from logging.handlers import RotatingFileHandler
 import torch
+
 from ...utils._path_manager import PathManager
+from ..analysis.prof_common_func._constant import print_info_msg
+from ..analysis.prof_common_func._constant import print_warn_msg
+from ..analysis.prof_common_func._constant import print_error_msg
+
+
+class DynamicProfilerUtils:
+    class DynamicProfilerConfigModel(Enum):
+        CFG_CONFIG = 0  # enabled via the JSON config file
+        DYNO_CONFIG = 1  # enabled via the dynolog daemon
+
+    class LoggerLevelEnum(Enum):
+        INFO = 0
+        WARNING = 1
+        ERROR = 2
+
+    DYNAMIC_PROFILER_MODEL = DynamicProfilerConfigModel.CFG_CONFIG
+    LOGGER = logging.getLogger("DynamicProfiler")
+    LOGGER_MONITOR = logging.getLogger("DynamicProfilerMonitor")
+    DYNAMIC_LOG_TO_FILE_MAP = {
+        LoggerLevelEnum.INFO: LOGGER.info,
+        LoggerLevelEnum.WARNING: LOGGER.warning,
+        LoggerLevelEnum.ERROR: LOGGER.error
+    }
+    DYNAMIC_MONITOR_LOG_TO_FILE_MAP = {
+        LoggerLevelEnum.INFO: LOGGER_MONITOR.info,
+        LoggerLevelEnum.WARNING: LOGGER_MONITOR.warning,
+        LoggerLevelEnum.ERROR: LOGGER_MONITOR.error
+    }
+
+    DYNAMIC_LOG_TO_STDOUT_MAP = {
+        LoggerLevelEnum.INFO: print_info_msg,
+        LoggerLevelEnum.WARNING: print_warn_msg,
+        LoggerLevelEnum.ERROR: print_error_msg
+    }
+
+    CFG_CONFIG_PATH = None
+    CFG_BUFFER_SIZE = 1024 * 1024
+    POLL_INTERVAL = 2
+
+    @classmethod
+    def init_logger(cls, is_monitor_process: bool = False):
+        logger_ = cls.LOGGER_MONITOR if is_monitor_process else cls.LOGGER
+        path = cls.CFG_CONFIG_PATH
+        path = os.path.join(path, 'log')
+        if not os.path.exists(path):
+            PathManager.make_dir_safety(path)
+        worker_name = "{}".format(socket.gethostname())
+        log_name = "dp_{}_{}_rank_{}.log".format(worker_name, os.getpid(), cls.get_rank_id())
+        if is_monitor_process:
+            log_name = "monitor_" + log_name
+        log_file = os.path.join(path, log_name)
+        if not os.path.exists(log_file):
+            PathManager.create_file_safety(log_file)
+        handler = RotatingFileHandler(filename=log_file, maxBytes=1024 * 200, backupCount=1)
+        formatter = logging.Formatter("%(asctime)s [%(levelname)s] [%(process)d] %(filename)s: %(message)s")
+        handler.setFormatter(formatter)
+        logger_.setLevel(logging.DEBUG)
+        logger_.addHandler(handler)
+
+    @classmethod
+    def is_dyno_model(cls):
+        if DynamicProfilerUtils.DYNAMIC_PROFILER_MODEL == DynamicProfilerUtils.DynamicProfilerConfigModel.CFG_CONFIG:
+            return False
+        return True
+
+    @classmethod
+    def out_log(cls, message: str, level: LoggerLevelEnum = LoggerLevelEnum.INFO, is_monitor_process: bool = False):
+        if not cls.is_dyno_model():
+            if not is_monitor_process:
+                cls.DYNAMIC_LOG_TO_FILE_MAP[level](message)
+            else:
+                cls.DYNAMIC_MONITOR_LOG_TO_FILE_MAP[level](message)
+        else:
+            cls.stdout_log(message, level)
+
+    @classmethod
+    def stdout_log(cls, message: str, level: LoggerLevelEnum = LoggerLevelEnum.INFO):
+        cls.DYNAMIC_LOG_TO_STDOUT_MAP[level](message)
+
+    @staticmethod
+    def get_rank_id() -> int:
+        try:
+            rank_id = os.environ.get('RANK')
+            if rank_id is None and torch.distributed.is_available() and torch.distributed.is_initialized():
+                rank_id = torch.distributed.get_rank()
+            if not isinstance(rank_id, int):
+                rank_id = int(rank_id)
+        except Exception as ex:
+            print_warn_msg(f"Failed to get rank id: {str(ex)}, rank_id will be set to -1 !")
+            rank_id = -1
+
+        return rank_id
+
+    @staticmethod
+    def dyno_str_to_json(res: str):
+        res_dict = {}
+        pairs = str(res).split("\n")
+        char_equal = '='
+        for pair in pairs:
+            str_split = pair.split(char_equal)
+            if len(str_split) == 2:
+                res_dict[str_split[0]] = str_split[1]
-
-logger = logging.getLogger("DynamicProfiler")
-logger_monitor = logging.getLogger("DynamicProfilerMonitor")
-
-
-def init_logger(logger_: logging.Logger, path: str, is_monitor_process: bool = False):
-    path = os.path.join(path, 'log')
-    if not os.path.exists(path):
-        PathManager.make_dir_safety(path)
-    worker_name = "{}".format(socket.gethostname())
-    log_name = "dp_{}_{}_rank_{}.log".format(worker_name, os.getpid(), _get_rank_id())
-    if is_monitor_process:
-        log_name = "monitor_" + log_name
-    log_file = os.path.join(path, log_name)
-    if not os.path.exists(log_file):
-        PathManager.create_file_safety(log_file)
-    handler = RotatingFileHandler(filename=log_file, maxBytes=1024 * 200, backupCount=1)
-    formatter = logging.Formatter("%(asctime)s [%(levelname)s] [%(process)d] %(filename)s: %(message)s")
-    handler.setFormatter(formatter)
-    logger_.setLevel(logging.DEBUG)
-    logger_.addHandler(handler)
-
-
-def _get_rank_id() -> int:
-    try:
-        rank_id = os.environ.get('RANK')
-        if rank_id is None and torch.distributed.is_available() and torch.distributed.is_initialized():
-            rank_id = torch.distributed.get_rank()
-        if not isinstance(rank_id, int):
-            rank_id = int(rank_id)
-    except Exception as ex:
-        logger.warning("Get rank id %s, rank_id will be set to -1 !", str(ex))
-        rank_id = -1
-
-    return rank_id
-
-
-def _get_device_id() -> int:
-    try:
-        device_id = os.environ.get('LOCAL_RANK')
-        if not isinstance(device_id, int):
-            device_id = int(device_id)
-    except Exception as ex:
-        logger.warning("Get device id %s, device_id will be set to -1 !", str(ex))
-        device_id = -1
-
-    return device_id
+        return res_dict
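The bootstrap in `_non_intrusive_profile.py` (next file) now distinguishes two enablement paths. A sketch of both, with a placeholder path:

```python
# Sketch of the two enablement paths _NonIntrusiveProfile.init() checks;
# the directory value is a placeholder.
import os

os.environ["KINETO_USE_DAEMON"] = "1"            # dynolog daemon mode
# -- or --
os.environ["PROF_CONFIG_PATH"] = "/path/to/cfg"  # shared JSON-config-file mode
```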
diff --git a/torch_npu/profiler/_non_intrusive_profile.py b/torch_npu/profiler/_non_intrusive_profile.py
index ef2fe9cd38..a60303adec 100644
--- a/torch_npu/profiler/_non_intrusive_profile.py
+++ b/torch_npu/profiler/_non_intrusive_profile.py
@@ -1,10 +1,11 @@
 import os
+import sys
 import functools
 
 import torch
 
 from ..utils._path_manager import PathManager
-from .dynamic_profile import _DynamicProfile
+from ._dynamic_profiler._dynamic_profiler_utils import DynamicProfilerUtils
 from .dynamic_profile import init as dp_init
 from .dynamic_profile import step as dp_step
 from .analysis.prof_common_func._constant import print_error_msg
@@ -58,13 +59,27 @@ def step(*args, **kwargs):
     @staticmethod
     def init():
         prof_config_path = os.getenv("PROF_CONFIG_PATH", "")
-        if not prof_config_path:
-            return
+        dyno_enable_flag = os.getenv("KINETO_USE_DAEMON", 0)
         try:
-            PathManager.check_input_directory_path(prof_config_path)
-        except RuntimeError:
-            print_error_msg(f"The path '{prof_config_path}' is invalid, and profiler will not be enabled.")
+            dyno_enable_flag = int(dyno_enable_flag)
+        except ValueError:
+            print_error_msg("Environment variable KINETO_USE_DAEMON value is not valid, will be set to 0 !")
+            dyno_enable_flag = 0
+        if not prof_config_path and dyno_enable_flag != 1:
+            return
+        is_dyno = True
+        if prof_config_path:
+            try:
+                PathManager.check_input_directory_path(prof_config_path)
+            except RuntimeError:
+                print_error_msg(f"The path '{prof_config_path}' is invalid, and profiler will not be enabled.")
+                return
+            is_dyno = False
+        if is_dyno and sys.version_info < (3, 8):
+            print_error_msg("Dynolog is only supported on Python 3.8 and above !")
             return
+        elif is_dyno:
+            DynamicProfilerUtils.DYNAMIC_PROFILER_MODEL = DynamicProfilerUtils.DynamicProfilerConfigModel.DYNO_CONFIG
         dp_init(prof_config_path)
         if torch.__version__ >= "2.0.0":
             torch.optim.Optimizer._patch_step_function = _NonIntrusiveProfile.patch_step_function
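For completeness, the explicit (non-env-var) integration of the `dynamic_profile` module patched below; step counts and the step function are placeholders for the user's training loop:

```python
# Illustrative training-loop integration of torch_npu's dynamic_profile.
from torch_npu.profiler import dynamic_profile as dp

dp.init("./dynamic_profile_cfg")  # config dir (placeholder); path is unused in dynolog mode
for _ in range(num_steps):        # num_steps assumed defined by the caller
    train_one_step()              # placeholder for the user's step function
    dp.step()                     # drives start/stop decisions from the active cfg
```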
diff --git a/torch_npu/profiler/dynamic_profile.py b/torch_npu/profiler/dynamic_profile.py
index 096a74370a..2e248b24cd 100644
--- a/torch_npu/profiler/dynamic_profile.py
+++ b/torch_npu/profiler/dynamic_profile.py
@@ -8,12 +8,9 @@
 
 from .analysis.prof_common_func._singleton import Singleton
 from ..utils._path_manager import PathManager
-from .analysis.prof_common_func._constant import print_info_msg
-from .analysis.prof_common_func._constant import print_warn_msg
-from .analysis.prof_common_func._constant import print_error_msg
 from .analysis.prof_common_func._utils import no_exception_func
 from .analysis.prof_common_func._file_manager import FileManager
-from ._dynamic_profiler._dynamic_profiler_utils import logger, init_logger
+from ._dynamic_profiler._dynamic_profiler_utils import DynamicProfilerUtils
 from ._dynamic_profiler._dynamic_profiler_monitor import DynamicProfilerMonitor
 from ._dynamic_profiler._dynamic_profiler_config_context import ConfigContext
 
@@ -27,8 +24,6 @@
 @Singleton
 class _DynamicProfile:
     RECORD_TIME_STEP = 10
-    CFG_BUFFER_SIZE = 1024 * 1024
-    POLL_INTERVAL = 2
 
     def __init__(self) -> None:
         self.prof = None
@@ -42,11 +37,12 @@ def __init__(self) -> None:
         self._step_time = 0
         self._min_poll_interval = 1
 
-    def init(self, path: str):
+    def init(self):
         if self.repeat_init:
-            print_warn_msg("Init dynamic profiling repeatedly")
+            DynamicProfilerUtils.stdout_log("Init dynamic profiling repeatedly",
+                                            DynamicProfilerUtils.LoggerLevelEnum.WARNING)
             return
-        self._dynamic_monitor = DynamicProfilerMonitor(path, self.CFG_BUFFER_SIZE, self.POLL_INTERVAL)
+        self._dynamic_monitor = DynamicProfilerMonitor()
         self.repeat_init = True
         atexit.register(self._clean_resource)
 
@@ -54,7 +50,9 @@ def _clean_resource(self):
         if self.prof is not None:
             self.prof.stop()
             self.prof = None
-            print_warn_msg("Profiler stop when process exit, check cfg json active whether over all step!")
+            DynamicProfilerUtils.stdout_log(
+                "Profiler stopped when the process exits, check whether cfg json active covers all steps!",
+                DynamicProfilerUtils.LoggerLevelEnum.WARNING)
         self._dynamic_monitor.clean_resource()
 
     def _dynamic_profiler_valid(self):
@@ -77,7 +75,8 @@ def step(self):
             if 0 == self.step_num:
                 self.prof.stop()
                 self.prof = None
-                logger.info(f"Stop Dynamic Profiler at {self.cur_step} step.")
+                DynamicProfilerUtils.out_log("Stop Dynamic Profiler at {} step.".format(
+                    self.cur_step), DynamicProfilerUtils.LoggerLevelEnum.INFO)
         elif self.prof is None and self.cfg_ctx is not None and self.cur_step == self.cfg_ctx.start_step():
             self.step_num = self.cfg_ctx.active()
             self.enable_prof()
@@ -85,7 +84,9 @@ def step(self):
 
     def start(self, config_path: str):
         if self.prof:
-            print_error_msg(f"Profiler already started. Cannot call start interface while the profiler is active. ")
+            DynamicProfilerUtils.stdout_log("Profiler already started. "
                                            "Cannot call start interface while the profiler is active. ",
+                                            DynamicProfilerUtils.LoggerLevelEnum.ERROR)
             return
         enable_config_path = ""
         if config_path:
@@ -94,18 +95,22 @@ def start(self, config_path: str):
                 PathManager.check_directory_path_readable(config_path)
                 enable_config_path = config_path
             except Exception as err:
-                logger.error(f"The provided config_path is invalid: {config_path}. Details: {err}")
+                DynamicProfilerUtils.stdout_log("The provided config_path is invalid: {}. Details: {}".format(
+                    config_path, str(err)), DynamicProfilerUtils.LoggerLevelEnum.ERROR)
                 enable_config_path = ""
         if not enable_config_path:
             enable_config_path = self._dynamic_monitor._config_path
-        print_info_msg(f"The start interface profiler enable config path is set to {enable_config_path}")
+        DynamicProfilerUtils.stdout_log("The start interface profiler enable config path is set to {}".format(
+            enable_config_path), DynamicProfilerUtils.LoggerLevelEnum.INFO)
         try:
             json_data = FileManager.read_json_file(enable_config_path)
             if not json_data:
-                print_error_msg(f"The config data is empty from: {enable_config_path}. Please check the config file. ")
+                DynamicProfilerUtils.stdout_log("The config data is empty from: {}. Please check the config file. ".format(
+                    enable_config_path), DynamicProfilerUtils.LoggerLevelEnum.ERROR)
                 return
         except RuntimeError:
-            print_error_msg(f"Failed to read config from : {enable_config_path}. Please check the config file. ")
+            DynamicProfilerUtils.stdout_log("Failed to read config from: {}. Please check the config file. ".format(
+                enable_config_path), DynamicProfilerUtils.LoggerLevelEnum.ERROR)
             return
         self.cfg_ctx = ConfigContext(json_data)
         self.step_num = self.cfg_ctx.active()
@@ -128,19 +133,24 @@ def enable_prof(self):
         self.prof.start()
         for key, value in self.cfg_ctx.meta_data().items():
             self.prof.add_metadata_json(str(key), json.dumps(value))
-        logger.info(f"Start Dynamic Profiler at {self.cur_step} step.")
+        DynamicProfilerUtils.out_log("Start Dynamic Profiler at {} step.".format(
+            self.cur_step), DynamicProfilerUtils.LoggerLevelEnum.INFO)
 
 
 @no_exception_func()
 def init(path: str):
+    if DynamicProfilerUtils.is_dyno_model():
+        _DynamicProfile().init()
+        return
     try:
        PathManager.check_input_directory_path(path)
    except RuntimeError:
-        print_error_msg(f"The path '{path}' is invalid, and profiler will not be enabled.")
+        DynamicProfilerUtils.stdout_log("The path '{}' is invalid, and profiler will not be enabled.".format(
+            path), DynamicProfilerUtils.LoggerLevelEnum.ERROR)
         return
-    dp_path = os.path.abspath(path)
-    init_logger(logger, dp_path)
-    _DynamicProfile().init(dp_path)
+    DynamicProfilerUtils.CFG_CONFIG_PATH = os.path.abspath(path)
+    DynamicProfilerUtils.init_logger()
+    _DynamicProfile().init()
 
 
 @no_exception_func()