Skip to content

Commit

Permalink
update for XEvent
Browse files Browse the repository at this point in the history
  • Loading branch information
cj401-amd committed Nov 16, 2024
1 parent f39ddbb commit 030d8e8
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 64 deletions.
1 change: 0 additions & 1 deletion xla/backends/profiler/gpu/device_tracer_rocm.cc
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,6 @@ absl::Status GpuTracer::DoStart() {

rocm_tracer_->setup();
rocm_tracer_->start();
rocm_tracer_->identify(1);

uint64_t start_gputime_ns = rocm_tracer_->GetTimestamp();
uint64_t start_walltime_ns = tsl::EnvTime::NowNanos();
Expand Down
13 changes: 5 additions & 8 deletions xla/backends/profiler/gpu/rocm_collector.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ enum class RocmTracerEventType {
Unsupported = 0,
HIP_RUNTIME_API,
KERNEL_DISPATCH,
MEMORY_COPY,
};

const char* GetRocmTracerEventTypeName(const RocmTracerEventType& type);
Expand Down Expand Up @@ -116,16 +117,12 @@ struct RocmTracerEvent {
// RocmTracerEventDomain domain;
RocmTracerEventType type;
std::string name;
// This points to strings in AnnotationMap, which should outlive the point
// where serialization happens.
// absl::string_view annotation;
// absl::string_view roctx_range;
uint64_t start_time_ns = 0;
uint64_t end_time_ns = 0;
uint32_t device_id = kInvalidDeviceId;
uint32_t correlation_id = kInvalidCorrelationId;
uint32_t thread_id = kInvalidThreadId;
int64_t stream_id = kInvalidStreamId;
uint32_t device_id = 0;
uint32_t correlation_id = 0;
uint32_t thread_id = 0;
int64_t stream_id = 0;
};

struct RocmTraceCollectorOptions {
Expand Down
87 changes: 34 additions & 53 deletions xla/backends/profiler/gpu/rocm_tracer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,18 +183,12 @@ tool_tracing_callback(rocprofiler_context_id_t context,
"array of headers. this should never happen"};
*/

auto* collector_impl = static_cast<RocmTraceCollector*>(user_data);
LOG(ERROR) << "Number of gpus = " << &collector_impl;


LOG(INFO) << "Number of heads = " << num_headers;
LOG(INFO) << "Tracing category = " << ROCPROFILER_BUFFER_CATEGORY_TRACING;
for(size_t i = 0; i < num_headers; ++i)
{
auto* header = headers[i];

RocmTracerEvent event;

auto kind_name = std::string{};
LOG(INFO) << "head category = " << header->category;
LOG(INFO) << "head kind = " << header->kind;
Expand Down Expand Up @@ -223,13 +217,15 @@ tool_tracing_callback(rocprofiler_context_id_t context,
auto* record =
static_cast<rocprofiler_buffer_tracing_hip_api_record_t*>(header->payload);

/*
event.type = RocmTracerEventType::HIP_RUNTIME_API;
event.start_time_ns = record->start_timestamp;
event.end_time_ns = record->end_timestamp;
// event.device_id = record->dispatch_info.agent_id.handle;;
// event.stream_id = record->stream_id;
event.correlation_id = record->correlation_id.internal;
event.name = client_name_info[record->kind][record->operation];
*/

auto info = std::stringstream{};
info << "tid=" << record->thread_id << ", context=" << context.handle
Expand All @@ -250,23 +246,31 @@ tool_tracing_callback(rocprofiler_context_id_t context,
// throw std::runtime_error{msg.str()};
}

static_cast<call_stack_t*>(user_data)->emplace_back(
source_location{__FUNCTION__, __FILE__, __LINE__, kind_name + info.str()});
static_cast<RocmTracerEvent*>(user_data)->emplace_back(
RocmTracerEvent{RocmTracerEventType::HIP_RUNTIME_API,
client_name_info[record->kind][record->operation],
record->start_timestamp,
record->end_timestamp,
0, // how to access device id,
record->correlation_id.internal,
record->thread_id,
record->stream_id});
}
else if(header->category == ROCPROFILER_BUFFER_CATEGORY_TRACING &&
header->kind == ROCPROFILER_BUFFER_TRACING_KERNEL_DISPATCH)
{
auto* record =
static_cast<rocprofiler_buffer_tracing_kernel_dispatch_record_t*>(header->payload);

/*
event.type = RocmTracerEventType::KERNEL_DISPATCH;
event.start_time_ns = record->start_timestamp;
event.end_time_ns = record->end_timestamp;
event.device_id = record->dispatch_info.agent_id.handle;
event.stream_id = record->dispatch_info.queue_id.handle;
event.correlation_id = record->correlation_id.internal;
event.name = client_kernels.at(record->dispatch_info.kernel_id).kernel_name;

*/
auto info = std::stringstream{};

info << "tid=" << record->thread_id << ", context=" << context.handle
Expand Down Expand Up @@ -296,8 +300,15 @@ tool_tracing_callback(rocprofiler_context_id_t context,
printf("kernel dispatch: start > end");
// throw std::runtime_error("kernel dispatch: start > end");

static_cast<call_stack_t*>(user_data)->emplace_back(
source_location{__FUNCTION__, __FILE__, __LINE__, kind_name + info.str()});
static_cast<RocmTracerEvent*>(user_data)->emplace_back(
RocmTracerEvent{RocmTracerEventType::KERNEL_DISPATCH,
client_name_info[record->kind][record->operation],
record->start_timestamp,
record->end_timestamp,
0, // how to access device id,
record->correlation_id.internal,
record->thread_id,
record->stream_id});
}
else if(header->category == ROCPROFILER_BUFFER_CATEGORY_TRACING &&
header->kind == ROCPROFILER_BUFFER_TRACING_MEMORY_COPY)
Expand All @@ -322,8 +333,15 @@ tool_tracing_callback(rocprofiler_context_id_t context,
printf("memory copy: start > end \n");
// throw std::runtime_error("memory copy: start > end");

static_cast<call_stack_t*>(user_data)->emplace_back(
source_location{__FUNCTION__, __FILE__, __LINE__, kind_name + info.str()});
static_cast<RocmTracerEvent*>(user_data)->emplace_back(
RocmTracerEvent{RocmTracerEventType::MEMORY_COPY,
client_name_info[record->kind][record->operation],
record->start_timestamp,
record->end_timestamp,
0, // how to access device id,
record->correlation_id.internal,
record->thread_id,
record->stream_id});
}
else
{
Expand All @@ -333,45 +351,19 @@ tool_tracing_callback(rocprofiler_context_id_t context,
std::cout << _msg.str() << std::endl;
// throw std::runtime_error{_msg.str()};
}
// Pass the created event to RocmTraceCollectorImpl
LOG(ERROR) << "Prepare transferring event to RocmTraceCollectorImpl";
collector_impl->AddEvent(event);
}
}

int tool_init(rocprofiler_client_finalize_t fini_func, void* tool_data)
{
assert(tool_data != nullptr);

auto* call_stack_v = static_cast<call_stack_t*>(tool_data);

call_stack_v->emplace_back(source_location{__FUNCTION__, __FILE__, __LINE__, ""});
// auto* call_stack_v = static_cast<call_stack_t*>(tool_data);
// call_stack_v->emplace_back(source_location{__FUNCTION__, __FILE__, __LINE__, ""});

client_name_info = rocm_get_buffer_tracing_names<std::string_view>();
// client_name_info = get_default_buffer_tracing_names();

for(const auto& itr : client_name_info)
{
auto name_idx = std::stringstream{};
name_idx << " [" << std::setw(3) << itr.value << "]";
call_stack_v->emplace_back(
source_location{"rocprofiler_buffer_tracing_kind_names " + name_idx.str(),
__FILE__,
__LINE__,
std::string{itr.name}});

for(auto [didx, ditr] : itr.items())
{
auto operation_idx = std::stringstream{};
operation_idx << " [" << std::setw(3) << didx << "]";
call_stack_v->emplace_back(source_location{
"rocprofiler_buffer_tracing_kind_operation_names" + operation_idx.str(),
__FILE__,
__LINE__,
std::string{"- "} + std::string{*ditr}});
}
}

client_fini_func = fini_func;

ROCPROFILER_CALL(se::wrap::rocprofiler_create_context(&client_ctx), "context creation");
Expand Down Expand Up @@ -435,9 +427,6 @@ int tool_init(rocprofiler_client_finalize_t fini_func, void* tool_data)
void tool_fini(void* tool_data){
assert(tool_data != nullptr);

auto* _call_stack = static_cast<call_stack_t*>(tool_data);
_call_stack->emplace_back(source_location{__FUNCTION__, __FILE__, __LINE__, ""});

print_call_stack(*_call_stack);

delete _call_stack;
Expand Down Expand Up @@ -496,14 +485,6 @@ void RocmTracer::start(){
ROCPROFILER_CALL(se::wrap::rocprofiler_start_context(client_ctx), "context start");
}

void RocmTracer::identify(uint64_t val){
auto _tid = rocprofiler_thread_id_t{};
se::wrap::rocprofiler_get_thread_id(&_tid);
rocprofiler_user_data_t user_data = {};
user_data.value = val;
se::wrap::rocprofiler_push_external_correlation_id(client_ctx, _tid, user_data);
}

void RocmTracer::stop(){
ROCPROFILER_CALL(se::wrap::rocprofiler_stop_context(client_ctx), "context stop");
}
Expand Down Expand Up @@ -545,7 +526,7 @@ rocprofiler_configure(uint32_t version,

std::clog << info.str() << std::endl;

auto* client_tool_data = new std::vector<>{};
auto* client_tool_data = new std::vector<RocmTracerEvent>{};

// create configure data
static auto cfg =
Expand Down
2 changes: 0 additions & 2 deletions xla/backends/profiler/gpu/rocm_tracer.h
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,10 @@ class RocmTracer {
static uint64_t GetTimestamp();
static int NumGpus();


void setup() CLIENT_API;
void start() CLIENT_API;
void stop() CLIENT_API;
void shutdown() CLIENT_API;
void identify(uint64_t corr_id) CLIENT_API;

private:
// Private constructor for singleton
Expand Down

0 comments on commit 030d8e8

Please sign in to comment.