Skip to content

Commit

Permalink
Merge pull request #2 from PFCCLab/task-2
Browse files Browse the repository at this point in the history
【Optimize code】Parse log to dict
  • Loading branch information
huajiao-hjyp authored Sep 20, 2023
2 parents efbecff + a6f6703 commit 2df3178
Show file tree
Hide file tree
Showing 4 changed files with 108 additions and 84 deletions.
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ pip install graphviz

```python
import os

# 将日志级别设置为6
os.environ['GLOG_v'] = '6'

import paddle
Expand Down
15 changes: 15 additions & 0 deletions example.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,12 +30,27 @@ def forward(self, img):
output = self.fc(feature.reshape([img.shape[0], -1]))
return output

class Transformer(nn.Layer):
def __init__(self) -> None:
super(Transformer, self).__init__()
encoder_layer = nn.TransformerEncoderLayer(128, 2, 512)
self.encoder = nn.TransformerEncoder(encoder_layer, 2)


def forward(self, img):
img = self.encoder(img)
return img


if __name__ == '__main__':

# 定义网络
model = Model()
x = paddle.randn([1, 3, 32, 32])

# model = Transformer()
# x = paddle.randn([2, 4, 128])

# 正向推理
y = model(x)

Expand Down
Binary file modified images/result.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
175 changes: 91 additions & 84 deletions paddleviz/viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,118 +77,125 @@ def add_output_tensor(var, color='darkolivegreen1'):
return dot


def processOPLog(str, op_id, dot):
""" extract edge info of operator from output log
:param str: the log of operator
:param op_id: the pointer of oparator
:param dot: the whole dot
"""
start = str.find("Input")

while str[start] != '\n':
start += 1
def add_edge_info(dot):
with open('./output.txt', encoding='utf-8') as f:
content = f.read()

start = 0

start = str.find('(', start) + 1
while content.find('gradnode_ptr', start) != -1:
start = content.find('gradnode_ptr', start)
end = content.find('\n', start)
if content[end - 1] == '\r':
end -= 1

start += 15
op_ptr = content[start: end].strip(' ')

end = content.find("backward.cc:288", start)

end = str.find("Output")
op_log = content[start: end]

# handle node of input
p = start
# Get operator input and output objects according to the log
op_input_output = parseOpLog(op_log, op_ptr, dot)

# each line of the operator's input is processed in turn using a double pointer
while p != -1 and p < end:

def parseOpLog(op_log, op_ptr, dot):
"""
extract edge info of operator from output log
# find carriage return
while str[p] != '\n':
p += 1
:param op_log: the log of operator
:param op_ptr: the pointer of oparator
:param dot: the whole dot
"""

op_input_output = {}

# find the properties of the edge
name = str[start + 1: str.find(',', start)].strip(' ')

start = str.find("Ptr", start)
ptr = str[start + 5: str.find(',', start)]
op_log = op_log.replace('\n', ' ')

start = str.find("Dtype", start)
dtype = str[start + 7: str.find(',', start)]
start = op_log.find("Input")

start = str.find("Place", start)
place = str[start + 7: str.find(',', start)]
end = op_log.find("Output")

start = str.find("Shape", start)
shape = '[{}]'.format(str[start + 7: str.find(']', start)].strip(' '))
# Find the input parameters of the operator
op_input_output["input"] = parseMultiParam(op_log[start: end]) if start != -1 else []

p = str.find('(', p, end)
# Process the input parameters of the operator in turn
for param in op_input_output["input"]:

start = p

if 'grad' not in name:
if 'grad' not in param["name"]:
continue

print("input_op: {} -> {}".format(op_id, ptr))
param_ptr = param["ptr"]

if ptr not in grad_nodes:
grad_nodes[ptr] = {}
if param_ptr not in grad_nodes:
grad_nodes[param_ptr] = {}

# if have previously recorded which operator output came from, need to add information on the side
if "output_op" in grad_nodes[ptr]:
edge_info = "dtype: {} \n place: {} \n shape: {} \n".format(dtype, place, shape)
dot.edge(grad_nodes[ptr]["output_op"], op_id, _attributes={'label': edge_info})


if "output_op" in grad_nodes[param_ptr]:
edge_info = "dtype: {} \n place: {} \n shape: {} \n".format(param["dtype"], param["place"], param["shape"])
dot.edge(grad_nodes[param_ptr]["output_op"], op_ptr, _attributes={'label': edge_info})

# handle node of output
start = end
while str[start] != '\n':
start += 1

start = str.find('(', start)
# Find the output parameters of the operator
op_input_output["output"] = parseMultiParam(op_log[end: ]) if end != -1 else []

end = len(str)

p = start
# Process the output parameters of the operator in turn
for param in op_input_output["output"]:
param_ptr = param["ptr"]

# each line of the operator's output is processed in turn using a double pointer
while p != -1 and p < end:
if param_ptr not in grad_nodes:
grad_nodes[param_ptr] = {}

grad_nodes[param_ptr]["output_op"] = op_ptr

# find carriage return
while p < end and str[p] != '\n':
p += 1
return op_input_output

# find the properties of the edge
name = str[start + 1: str.find(',', start)].strip(' ')

start = str.find("Ptr", start)
ptr = str[start + 5: str.find(',', start)]
def parseMultiParam(multi_param_log):
multi_param = []
start, end = 0, len(multi_param_log)

if ptr not in grad_nodes:
grad_nodes[ptr] = {}

grad_nodes[ptr]["output_op"] = op_id
print("output_op: {} -> {}".format(op_id, ptr))
while multi_param_log.find('(', start, end) != -1:
param_start = multi_param_log.find('(', start, end)
param_end = multi_param_log.find('}]),', start, end)
start = param_end + 1
# Converts a parameter string to an object
param = parseParam(multi_param_log[param_start + 1: param_end])
multi_param.append(param)

p = str.find('(', p, end)

start = p
return multi_param


def add_edge_info(dot):
with open('./output.txt', encoding='utf-8') as f:
content = f.read()

def parseParam(param_log):
param = {}
start = 0

while content.find('gradnode_ptr', start) != -1:
start = content.find('gradnode_ptr', start)
end = content.find('\n', start)
if content[end - 1] == '\r':
end -= 1

start += 15
op_ptr = content[start: end]

name = param_log[start: param_log.find(',', start)].strip(' ')
param["name"] = name

end = content.find("backward.cc:288", start)
start = param_log.find("Ptr", start)
ptr = param_log[start + 5: param_log.find(',', start)].strip(' ')
param["ptr"] = ptr

start = param_log.find("Dtype", start)
if start == -1:
param["dtype"] = None
param["place"] = None
param["shape"] = None
return param

dtype = param_log[start + 7: param_log.find(',', start)].strip(' ')
param["dtype"] = dtype

op_log = content[start: end]
start = param_log.find("Place", start)
place = param_log[start + 7: param_log.find(',', start)].strip(' ')
param["place"] = place

start = param_log.find("Shape", start)
shape = '[{}]'.format(param_log[start + 7: param_log.find(']', start)].strip(' '))
param["shape"] = shape

# print(param)

processOPLog(op_log, op_ptr, dot)
return param

0 comments on commit 2df3178

Please sign in to comment.