Skip to content

Commit

Permalink
repair bug in latexocr cpu infer and typo (#14552)
Browse files Browse the repository at this point in the history
  • Loading branch information
liuhongen1234567 authored Jan 16, 2025
1 parent 52bc8f0 commit cf4c059
Show file tree
Hide file tree
Showing 9 changed files with 48 additions and 39 deletions.
2 changes: 1 addition & 1 deletion configs/rec/PP-FormuaNet/rec_pp_formulanet_l.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ PostProcess:
Metric:
name: LaTeXOCRMetric
main_indicator: exp_rate
cal_blue_score: False
cal_bleu_score: False

Train:
dataset:
Expand Down
2 changes: 1 addition & 1 deletion configs/rec/PP-FormuaNet/rec_pp_formulanet_s.yml
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ PostProcess:
Metric:
name: LaTeXOCRMetric
main_indicator: exp_rate
cal_blue_score: False
cal_bleu_score: False

Train:
dataset:
Expand Down
2 changes: 1 addition & 1 deletion configs/rec/rec_latex_ocr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ PostProcess:
Metric:
name: LaTeXOCRMetric
main_indicator: exp_rate
cal_blue_score: False
cal_bleu_score: False

Train:
dataset:
Expand Down
2 changes: 1 addition & 1 deletion configs/rec/rec_unimernet.yml
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ PostProcess:
Metric:
name: LaTeXOCRMetric
main_indicator: exp_rate
cal_blue_score: False
cal_bleu_score: False

Train:
dataset:
Expand Down
2 changes: 1 addition & 1 deletion ppocr/metrics/bleu.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def __call__(self, line):
return self._post_tokenizer(f" {line} ")


def compute_blue_score(
def compute_bleu_score(
predictions, references, tokenizer=Tokenizer13a(), max_order=4, smooth=False
):
# if only one reference is provided make sure we still use list of lists
Expand Down
38 changes: 19 additions & 19 deletions ppocr/metrics/rec_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import numpy as np
import string
from .bleu import compute_blue_score, compute_edit_distance
from .bleu import compute_bleu_score, compute_edit_distance


class RecMetric(object):
Expand Down Expand Up @@ -181,21 +181,21 @@ def epoch_reset(self):


class LaTeXOCRMetric(object):
def __init__(self, main_indicator="exp_rate", cal_blue_score=False, **kwargs):
def __init__(self, main_indicator="exp_rate", cal_bleu_score=False, **kwargs):
self.main_indicator = main_indicator
self.cal_blue_score = cal_blue_score
self.cal_bleu_score = cal_bleu_score
self.edit_right = []
self.exp_right = []
self.blue_right = []
self.bleu_right = []
self.e1_right = []
self.e2_right = []
self.e3_right = []
self.editdistance_total_length = 0
self.exp_total_num = 0
self.edit_dist = 0
self.exp_rate = 0
if self.cal_blue_score:
self.blue_score = 0
if self.cal_bleu_score:
self.bleu_score = 0
self.e1 = 0
self.e2 = 0
self.e3 = 0
Expand Down Expand Up @@ -227,16 +227,16 @@ def __call__(self, preds, batch, **kwargs):

self.edit_dist = sum(lev_dist) # float
self.exp_rate = line_right # float
if self.cal_blue_score:
self.blue_score = compute_blue_score(word_pred, word_label)
if self.cal_bleu_score:
self.bleu_score = compute_bleu_score(word_pred, word_label)
self.e1 = e1
self.e2 = e2
self.e3 = e3
exp_length = len(word_label)
self.edit_right.append(self.edit_dist)
self.exp_right.append(self.exp_rate)
if self.cal_blue_score:
self.blue_right.append(self.blue_score * batch_size)
if self.cal_bleu_score:
self.bleu_right.append(self.bleu_score * batch_size)
self.e1_right.append(self.e1)
self.e2_right.append(self.e2)
self.e3_right.append(self.e3)
Expand All @@ -247,21 +247,21 @@ def get_metric(self):
"""
return {
'edit distance': 0,
"blue_score": 0,
"bleu_score": 0,
"exp_rate": 0,
}
"""
cur_edit_distance = sum(self.edit_right) / self.exp_total_num
cur_exp_rate = sum(self.exp_right) / self.exp_total_num
if self.cal_blue_score:
cur_blue_score = sum(self.blue_right) / self.editdistance_total_length
if self.cal_bleu_score:
cur_bleu_score = sum(self.bleu_right) / self.editdistance_total_length
cur_exp_1 = sum(self.e1_right) / self.exp_total_num
cur_exp_2 = sum(self.e2_right) / self.exp_total_num
cur_exp_3 = sum(self.e3_right) / self.exp_total_num
self.reset()
if self.cal_blue_score:
if self.cal_bleu_score:
return {
"blue_score": cur_blue_score,
"bleu_score": cur_bleu_score,
"edit distance": cur_edit_distance,
"exp_rate": cur_exp_rate,
"exp_rate<=1 ": cur_exp_1,
Expand All @@ -281,17 +281,17 @@ def get_metric(self):
def reset(self):
self.edit_dist = 0
self.exp_rate = 0
if self.cal_blue_score:
self.blue_score = 0
if self.cal_bleu_score:
self.bleu_score = 0
self.e1 = 0
self.e2 = 0
self.e3 = 0

def epoch_reset(self):
self.edit_right = []
self.exp_right = []
if self.cal_blue_score:
self.blue_right = []
if self.cal_bleu_score:
self.bleu_right = []
self.e1_right = []
self.e2_right = []
self.e3_right = []
Expand Down
22 changes: 11 additions & 11 deletions ppocr/modeling/backbones/rec_resnetv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,16 +90,9 @@ def __init__(

self.running_mean = paddle.zeros([self._out_channels], dtype="float32")
self.running_variance = paddle.ones([self._out_channels], dtype="float32")
orin_shape = self.weight.shape
new_weight = F.batch_norm(
self.weight.reshape([1, self._out_channels, -1]),
self.running_mean,
self.running_variance,
momentum=0.0,
epsilon=self.eps,
use_global_stats=False,
).reshape(orin_shape)
self.weight.set_value(new_weight.numpy())
self.batch_norm = paddle.nn.BatchNorm1D(
self._out_channels, use_global_stats=False
)

def forward(self, x):
if not self.training:
Expand All @@ -110,7 +103,14 @@ def forward(self, x):
else:
x = pad_same(x, self._kernel_size, self._stride, self._dilation)
if self.export:
weight = self.weight
weight = paddle.reshape(
self.batch_norm(
self.weight.reshape([1, self._out_channels, -1]).cast(
paddle.float32
),
),
self.weight.shape,
)
else:
weight = paddle.reshape(
F.batch_norm(
Expand Down
11 changes: 10 additions & 1 deletion ppocr/utils/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,16 @@ def dump_infer_config(config, path, logger):
infer_cfg["PreProcess"] = {"transform_ops": config["Eval"]["dataset"]["transforms"]}
postprocess = OrderedDict()
for k, v in config["PostProcess"].items():
postprocess[k] = v
if config["Architecture"].get("algorithm") in [
"LaTeXOCR",
"UniMERNet",
"PP-FormulaNet-L",
"PP-FormulaNet-S",
]:
if k != "rec_char_dict_path":
postprocess[k] = v
else:
postprocess[k] = v

if config["Architecture"].get("algorithm") in ["LaTeXOCR"]:
tokenizer_file = config["Global"].get("rec_char_dict_path")
Expand Down
6 changes: 3 additions & 3 deletions tools/eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,16 +107,16 @@ def main():
model_type = "can"
elif config["Architecture"]["algorithm"] == "LaTeXOCR":
model_type = "latexocr"
config["Metric"]["cal_blue_score"] = True
config["Metric"]["cal_bleu_score"] = True
elif config["Architecture"]["algorithm"] == "UniMERNet":
model_type = "unimernet"
config["Metric"]["cal_blue_score"] = True
config["Metric"]["cal_bleu_score"] = True
elif config["Architecture"]["algorithm"] in [
"PP-FormulaNet-S",
"PP-FormulaNet-L",
]:
model_type = "pp_formulanet"
config["Metric"]["cal_blue_score"] = True
config["Metric"]["cal_bleu_score"] = True
else:
model_type = config["Architecture"]["model_type"]
else:
Expand Down

0 comments on commit cf4c059

Please sign in to comment.