diff --git a/README.md b/README.md index a44c458..ac62bfc 100644 --- a/README.md +++ b/README.md @@ -33,6 +33,8 @@ pip install graphviz ```python import os + +# 将日志级别设置为6 os.environ['GLOG_v'] = '6' import paddle diff --git a/example.py b/example.py index 606813d..39157f1 100644 --- a/example.py +++ b/example.py @@ -30,12 +30,27 @@ def forward(self, img): output = self.fc(feature.reshape([img.shape[0], -1])) return output +class Transformer(nn.Layer): + def __init__(self) -> None: + super(Transformer, self).__init__() + encoder_layer = nn.TransformerEncoderLayer(128, 2, 512) + self.encoder = nn.TransformerEncoder(encoder_layer, 2) + + + def forward(self, img): + img = self.encoder(img) + return img + + if __name__ == '__main__': # 定义网络 model = Model() x = paddle.randn([1, 3, 32, 32]) + # model = Transformer() + # x = paddle.randn([2, 4, 128]) + # 正向推理 y = model(x) diff --git a/images/result.png b/images/result.png index a0d08f7..73696aa 100644 Binary files a/images/result.png and b/images/result.png differ diff --git a/paddleviz/viz.py b/paddleviz/viz.py index c0763e2..1f48b30 100644 --- a/paddleviz/viz.py +++ b/paddleviz/viz.py @@ -77,118 +77,125 @@ def add_output_tensor(var, color='darkolivegreen1'): return dot -def processOPLog(str, op_id, dot): - """ extract edge info of operator from output log - - :param str: the log of operator - :param op_id: the pointer of oparator - :param dot: the whole dot - """ - start = str.find("Input") - - while str[start] != '\n': - start += 1 +def add_edge_info(dot): + with open('./output.txt', encoding='utf-8') as f: + content = f.read() + + start = 0 - start = str.find('(', start) + 1 + while content.find('gradnode_ptr', start) != -1: + start = content.find('gradnode_ptr', start) + end = content.find('\n', start) + if content[end - 1] == '\r': + end -= 1 + + start += 15 + op_ptr = content[start: end].strip(' ') + + end = content.find("backward.cc:288", start) - end = str.find("Output") + op_log = content[start: end] - # handle node of input - p = start + # Get operator input and output objects according to the log + op_input_output = parseOpLog(op_log, op_ptr, dot) - # each line of the operator's input is processed in turn using a double pointer - while p != -1 and p < end: + +def parseOpLog(op_log, op_ptr, dot): + """ + extract edge info of operator from output log - # find carriage return - while str[p] != '\n': - p += 1 + :param op_log: the log of operator + :param op_ptr: the pointer of oparator + :param dot: the whole dot + """ + + op_input_output = {} - # find the properties of the edge - name = str[start + 1: str.find(',', start)].strip(' ') - - start = str.find("Ptr", start) - ptr = str[start + 5: str.find(',', start)] + op_log = op_log.replace('\n', ' ') - start = str.find("Dtype", start) - dtype = str[start + 7: str.find(',', start)] + start = op_log.find("Input") - start = str.find("Place", start) - place = str[start + 7: str.find(',', start)] + end = op_log.find("Output") - start = str.find("Shape", start) - shape = '[{}]'.format(str[start + 7: str.find(']', start)].strip(' ')) + # Find the input parameters of the operator + op_input_output["input"] = parseMultiParam(op_log[start: end]) if start != -1 else [] - p = str.find('(', p, end) + # Process the input parameters of the operator in turn + for param in op_input_output["input"]: - start = p - - if 'grad' not in name: + if 'grad' not in param["name"]: continue - print("input_op: {} -> {}".format(op_id, ptr)) + param_ptr = param["ptr"] - if ptr not in grad_nodes: - grad_nodes[ptr] = {} + if param_ptr not in grad_nodes: + grad_nodes[param_ptr] = {} # if have previously recorded which operator output came from, need to add information on the side - if "output_op" in grad_nodes[ptr]: - edge_info = "dtype: {} \n place: {} \n shape: {} \n".format(dtype, place, shape) - dot.edge(grad_nodes[ptr]["output_op"], op_id, _attributes={'label': edge_info}) - - + if "output_op" in grad_nodes[param_ptr]: + edge_info = "dtype: {} \n place: {} \n shape: {} \n".format(param["dtype"], param["place"], param["shape"]) + dot.edge(grad_nodes[param_ptr]["output_op"], op_ptr, _attributes={'label': edge_info}) - # handle node of output - start = end - while str[start] != '\n': - start += 1 - start = str.find('(', start) + # Find the output parameters of the operator + op_input_output["output"] = parseMultiParam(op_log[end: ]) if end != -1 else [] - end = len(str) - - p = start + # Process the output parameters of the operator in turn + for param in op_input_output["output"]: + param_ptr = param["ptr"] - # each line of the operator's output is processed in turn using a double pointer - while p != -1 and p < end: + if param_ptr not in grad_nodes: + grad_nodes[param_ptr] = {} + + grad_nodes[param_ptr]["output_op"] = op_ptr - # find carriage return - while p < end and str[p] != '\n': - p += 1 + return op_input_output - # find the properties of the edge - name = str[start + 1: str.find(',', start)].strip(' ') - start = str.find("Ptr", start) - ptr = str[start + 5: str.find(',', start)] +def parseMultiParam(multi_param_log): + multi_param = [] + start, end = 0, len(multi_param_log) - if ptr not in grad_nodes: - grad_nodes[ptr] = {} - - grad_nodes[ptr]["output_op"] = op_id - print("output_op: {} -> {}".format(op_id, ptr)) + while multi_param_log.find('(', start, end) != -1: + param_start = multi_param_log.find('(', start, end) + param_end = multi_param_log.find('}]),', start, end) + start = param_end + 1 + # Converts a parameter string to an object + param = parseParam(multi_param_log[param_start + 1: param_end]) + multi_param.append(param) - p = str.find('(', p, end) - - start = p + return multi_param -def add_edge_info(dot): - with open('./output.txt', encoding='utf-8') as f: - content = f.read() - +def parseParam(param_log): + param = {} start = 0 - - while content.find('gradnode_ptr', start) != -1: - start = content.find('gradnode_ptr', start) - end = content.find('\n', start) - if content[end - 1] == '\r': - end -= 1 - - start += 15 - op_ptr = content[start: end] + + name = param_log[start: param_log.find(',', start)].strip(' ') + param["name"] = name - end = content.find("backward.cc:288", start) + start = param_log.find("Ptr", start) + ptr = param_log[start + 5: param_log.find(',', start)].strip(' ') + param["ptr"] = ptr + + start = param_log.find("Dtype", start) + if start == -1: + param["dtype"] = None + param["place"] = None + param["shape"] = None + return param + + dtype = param_log[start + 7: param_log.find(',', start)].strip(' ') + param["dtype"] = dtype - op_log = content[start: end] + start = param_log.find("Place", start) + place = param_log[start + 7: param_log.find(',', start)].strip(' ') + param["place"] = place + + start = param_log.find("Shape", start) + shape = '[{}]'.format(param_log[start + 7: param_log.find(']', start)].strip(' ')) + param["shape"] = shape + + # print(param) - processOPLog(op_log, op_ptr, dot) \ No newline at end of file + return param \ No newline at end of file