diff --git a/xla/backends/profiler/gpu/device_tracer_rocm.cc b/xla/backends/profiler/gpu/device_tracer_rocm.cc index ae69b22f9c328..4501cd45bc383 100644 --- a/xla/backends/profiler/gpu/device_tracer_rocm.cc +++ b/xla/backends/profiler/gpu/device_tracer_rocm.cc @@ -70,8 +70,7 @@ namespace se = ::stream_executor; // GpuTracer for ROCm GPU. class GpuTracer : public profiler::ProfilerInterface { public: - GpuTracer() { - // se::rocprofiler_force_configure + GpuTracer(RocmTracer* rocmtracer) : rocm_tracer_(rocmtracer) { LOG(ERROR) << "GpuTrace with rocprofv3...\n"; Start(); LOG(INFO) << "GpuTracer created..."; @@ -120,14 +119,12 @@ RocmTraceCollectorOptions GpuTracer::GetRocmTraceCollectorOptions( } absl::Status GpuTracer::DoStart() { - /* if (!rocm_tracer_->IsAvailable()) { return tsl::errors::Unavailable("Another profile session running."); } - */ + // AnnotationStack::Enable(true); -/* RocmTraceCollectorOptions trace_collector_options = GetRocmTraceCollectorOptions(rocm_tracer_->NumGpus()); uint64_t start_gputime_ns = rocm_tracer_->GetTimestamp(); @@ -135,11 +132,11 @@ absl::Status GpuTracer::DoStart() { rocm_trace_collector_ = CreateRocmCollector( trace_collector_options, start_walltime_ns, start_gputime_ns); - RocmTracerOptions tracer_options = GetRocmTracerOptions(); - rocm_tracer_->Enable(tracer_options, rocm_trace_collector_.get()); - */ - LOG(ERROR) << "cj rocm_tracer_collector = " << rocm_trace_collector_.get(); - LOG(ERROR) << "cj rocm_tracer_ collector = " << rocm_tracer_->get_collector(); + // RocmTracerOptions tracer_options = GetRocmTracerOptions(); + // rocm_tracer_->Enable(tracer_options, rocm_trace_collector_.get()); + + // LOG(ERROR) << "cj rocm_tracer_collector = " << rocm_trace_collector_.get(); + // LOG(ERROR) << "cj rocm_tracer_ collector = " << rocm_tracer_->get_collector(); // LOG(ERROR) << "cj check XSpace = " << space; LOG(ERROR) << "DO START ..."; @@ -151,8 +148,14 @@ absl::Status GpuTracer::DoStart() { return absel::; } */ - rocm_tracer_->setup(); - rocm_tracer_->start(); + for (auto& event: rocm_tracer_->GetEvents()) { + rocm_trace_collector_->AddEvent(std::move(event)); + } + LOG(ERROR) << "DO START after moving events..."; + rocm_trace_collector_->Flush(); + LOG(ERROR) << "DO START after flush..."; + + LOG(ERROR) << "Export XSpace after flush..."; return absl::OkStatus(); } @@ -168,8 +171,8 @@ absl::Status GpuTracer::Start() { } absl::Status GpuTracer::DoStop() { - rocm_tracer_->stop(); - rocm_tracer_->shutdown(); + // rocm_tracer_->stop(); + // rocm_tracer_->shutdown(); return absl::OkStatus(); } @@ -182,6 +185,9 @@ absl::Status GpuTracer::Stop() { } absl::Status GpuTracer::CollectData(XSpace* space) { + if (rocm_trace_collector_) rocm_trace_collector_->Export(space); + LOG(ERROR) << "CollectData XSpace = " << space; + switch (profiling_state_) { case State::kNotStarted: VLOG(3) << "No trace data collected, session wasn't started"; @@ -196,7 +202,7 @@ absl::Status GpuTracer::CollectData(XSpace* space) { VLOG(3) << "No trace data collected"; return absl::OkStatus(); case State::kStoppedOk: { - if (rocm_trace_collector_) rocm_trace_collector_->Export(space); + // if (rocm_trace_collector_) rocm_trace_collector_->Export(space); return absl::OkStatus(); } } @@ -211,15 +217,15 @@ std::unique_ptr CreateGpuTracer( return nullptr; } - /* profiler::RocmTracer* rocm_tracer = profiler::RocmTracer::GetRocmTracerSingleton(); LOG(ERROR) << "cj rocm_tracer is available = " << rocm_tracer->IsAvailable(); + LOG(ERROR) << "Traced events = " << rocm_tracer->GetEvents().size(); if (!rocm_tracer->IsAvailable()) { return nullptr; } - */ - return std::make_unique(); + + return std::make_unique(rocm_tracer); } auto register_rocm_gpu_tracer_factory = [] { diff --git a/xla/backends/profiler/gpu/rocm_collector.cc b/xla/backends/profiler/gpu/rocm_collector.cc index 6d6a5d8bb2b2e..d502438b73a89 100644 --- a/xla/backends/profiler/gpu/rocm_collector.cc +++ b/xla/backends/profiler/gpu/rocm_collector.cc @@ -455,9 +455,9 @@ class RocmTraceCollectorImpl : public profiler::RocmTraceCollector { }; void RocmTraceCollectorImpl::AddEvent(RocmTracerEvent&& event) { - LOG(ERROR) << "Starting RocmTraceCollectorImpl::AddEvent"; - // mutex_lock lock(event_maps_mutex_); - // events_.push_back(std::move(event)); + // LOG(ERROR) << "Starting RocmTraceCollectorImpl::AddEvent"; + mutex_lock lock(event_maps_mutex_); + events_.push_back(std::move(event)); } void RocmTraceCollectorImpl::Flush() { diff --git a/xla/backends/profiler/gpu/rocm_tracer.cc b/xla/backends/profiler/gpu/rocm_tracer.cc index 8159151fbba9b..b806a63f38d48 100644 --- a/xla/backends/profiler/gpu/rocm_tracer.cc +++ b/xla/backends/profiler/gpu/rocm_tracer.cc @@ -59,7 +59,7 @@ extern "C" rocprofiler_tool_configure_result_t* rocprofiler_configure( rocprofiler_client_id_t* id ); -auto rocmtracer_singleton = xla::profiler::RocmTracer::GetRocmTracerSingleton(); +// auto rocmtracer_singleton = xla::profiler::RocmTracer::GetRocmTracerSingleton(); template using buffer_name_info_t = rocprofiler::sdk::utility::name_info; @@ -67,7 +67,8 @@ using buffer_name_info_t = rocprofiler::sdk::utility::name_infocategory == ROCPROFILER_BUFFER_CATEGORY_TRACING && header->kind == ROCPROFILER_BUFFER_TRACING_MEMORY_COPY) @@ -519,8 +545,8 @@ void RocmTracer::stop(){ ROCPROFILER_CALL(se::wrap::rocprofiler_stop_context(client_ctx), "context stop"); } - -/* static */ RocmTracer* RocmTracer::GetRocmTracerSingleton() { +/* +RocmTracer* RocmTracer::GetRocmTracerSingleton() { LOG(INFO) << "Entering GetRocmTracerSingleton..."; static std::once_flag flag; @@ -538,10 +564,15 @@ void RocmTracer::stop(){ abort(); // Ensure the program stops if initialization fails. } - LOG(INFO) << "Returning RocmTracer singleton instance."; + LOG(INFO) << "Returning RocmTracer singleton instance." << instance; return instance; } +*/ +/* static */ RocmTracer* RocmTracer::GetRocmTracerSingleton() { + static auto* singleton = new RocmTracer(); + return singleton; +} bool RocmTracer::IsAvailable() const { return GetRocmTracerSingleton() != nullptr; @@ -602,24 +633,6 @@ rocprofiler_configure(uint32_t version, // store client info xla::profiler::client_id = id; LOG(ERROR) << "Configure rocprofiler-sdk...\n"; - - // auto rocmtracer_singleton = xla::profiler::RocmTracer::GetRocmTracerSingleton(); - // LOG(ERROR) << "cj -1 rocprofiler_configure() with rocm collector"; - /* - auto trace_collector_options = GetRocmTraceCollectorOptions(rocmtracer_singleton->NumGpus()); - LOG(ERROR) << "cj 0 rocprofiler_configure() with rocm collector"; - - uint64_t start_gputime_ns = rocmtracer_singleton->GetTimestamp(); - LOG(ERROR) << "cj 1 rocprofiler_configure() with rocm collector"; - uint64_t start_walltime_ns = tsl::EnvTime::NowNanos(); - auto rocm_trace_collector_ = xla::profiler::CreateRocmCollector( - trace_collector_options, start_walltime_ns, start_gputime_ns); - LOG(ERROR) << "cj 2 rocprofiler_configure() with rocm collector"; - auto tracer_options = GetRocmTracerOptions(); - LOG(ERROR) << "cj 3 rocprofiler_configure() with rocm collector"; - rocmtracer_singleton->Enable(tracer_options, rocm_trace_collector_.get()); - LOG(ERROR) << "cj 4 rocprofiler_configure() with rocm collector"; - */ // compute major/minor/patch version info uint32_t major = version / 10000; diff --git a/xla/backends/profiler/gpu/rocm_tracer.h b/xla/backends/profiler/gpu/rocm_tracer.h index 75b741c5654cc..41660134cc440 100644 --- a/xla/backends/profiler/gpu/rocm_tracer.h +++ b/xla/backends/profiler/gpu/rocm_tracer.h @@ -41,6 +41,8 @@ limitations under the License. namespace xla { namespace profiler { +// std::vector all_rocm_events_1; + struct RocmTracerOptions { std::set api_tracking_set; // actual api set we want to profile @@ -65,6 +67,8 @@ class RocmTracer { static int NumGpus(); void Enable(RocmTraceCollector* collector); RocmTraceCollector* get_collector() { return collector_; } + void AppendEvent(RocmTracerEvent event) { rocm_events_.push_back(event); } + RocmTracerEvent_t GetEvents() {return rocm_events_;} void setup() CLIENT_API; void start() CLIENT_API; @@ -80,6 +84,7 @@ class RocmTracer { int num_gpus_; // std::optional options_; RocmTraceCollector* collector_ = nullptr; + RocmTracerEvent_t rocm_events_; static tsl::mutex mtx; public: