Skip to content

Commit

Permalink
add slanext models (#14374)
Browse files Browse the repository at this point in the history
* add slanext models

* refine codes

* refine codes

* refine codes
  • Loading branch information
liu-jiaxuan authored Dec 13, 2024
1 parent f49dec9 commit ae67d96
Show file tree
Hide file tree
Showing 7 changed files with 930 additions and 11 deletions.
179 changes: 179 additions & 0 deletions configs/table/SLANeXt_wired.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
Global:
use_gpu: true
epoch_num: 400
log_smooth_window: 20
print_batch_step: 20
save_model_dir: ./output/SLANeXt_wired
save_epoch_step: 400
eval_batch_step:
- 0
- 331
cal_metric_during_train: true
pretrained_model: null
checkpoints: null
save_inference_dir: ./output/SLANeXt_wired/infer
use_visualdl: false
infer_img: ppstructure/docs/table/table.jpg
character_dict_path: ppocr/utils/dict/table_structure_dict_ch.txt
character_type: en
max_text_length: 500
box_format: xyxyxyxy
infer_mode: false
use_sync_bn: true
save_res_path: output/infer

Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
clip_norm: 5.0
lr:
name: Cosine
learning_rate: 0.0001
warmup_epoch: 1
regularizer:
name: L2
factor: 0.0

Architecture:
model_type: table
algorithm: SLANeXt
Backbone:
name: Vary_VIT_B
image_size: 512
encoder_embed_dim: 768
encoder_depth: 12
encoder_num_heads: 12
encoder_global_attn_indexes: [2, 5, 8, 11]
Head:
name: SLAHead
hidden_size: 512
max_text_length: 500
loc_reg_num: 8

Loss:
name: SLALoss
structure_weight: 1.0
# SLANeXt does not train the cell location task by default, set the loc_weight if needed.
loc_weight: 0.0
loc_loss: smooth_l1

PostProcess:
name: TableLabelDecode
merge_no_span_structure: true

Metric:
name: TableMetric
main_indicator: acc
compute_bbox_metric: false
loc_reg_num: 8
box_format: xyxyxyxy
del_thead_tbody: true

Train:
dataset:
name: PubTabDataSet
data_dir: train_data/table/train/
label_file_list:
- train_data/table/train.txt
ratio_list:
- 1
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- TableLabelEncode:
learn_empty_box: false
merge_no_span_structure: true
replace_empty_cell_token: false
loc_reg_num: 8
max_text_length: 500
- TableBoxEncode:
in_box_format: xyxyxyxy
out_box_format: xyxyxyxy
- ResizeTableImage:
max_len: 512
resize_bboxes: true
- NormalizeImage:
scale: 1./255.
mean:
- 0.485
- 0.456
- 0.406
std:
- 0.229
- 0.224
- 0.225
order: hwc
- PaddingTableImage:
size:
- 512
- 512
- ToCHWImage: null
- KeepKeys:
keep_keys:
- image
- structure
- bboxes
- bbox_masks
- length
- shape
loader:
shuffle: true
batch_size_per_card: 48
drop_last: true
num_workers: 1

Eval:
dataset:
name: PubTabDataSet
data_dir: train_data/table/val/
label_file_list:
- train_data/table/val.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- TableLabelEncode:
learn_empty_box: false
merge_no_span_structure: true
replace_empty_cell_token: false
loc_reg_num: 8
max_text_length: 500
- TableBoxEncode:
in_box_format: xyxyxyxy
out_box_format: xyxyxyxy
- ResizeTableImage:
max_len: 512
resize_bboxes: true
- NormalizeImage:
scale: 1./255.
mean:
- 0.485
- 0.456
- 0.406
std:
- 0.229
- 0.224
- 0.225
order: hwc
- PaddingTableImage:
size:
- 512
- 512
- ToCHWImage: null
- KeepKeys:
keep_keys:
- image
- structure
- bboxes
- bbox_masks
- length
- shape
loader:
shuffle: false
drop_last: false
batch_size_per_card: 48
num_workers: 1

profiler_options: null
179 changes: 179 additions & 0 deletions configs/table/SLANeXt_wireless.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,179 @@
Global:
use_gpu: true
epoch_num: 400
log_smooth_window: 20
print_batch_step: 20
save_model_dir: ./output/SLANeXt_wireless
save_epoch_step: 400
eval_batch_step:
- 0
- 331
cal_metric_during_train: true
pretrained_model: null
checkpoints: null
save_inference_dir: ./output/SLANeXt_wireless/infer
use_visualdl: false
infer_img: ppstructure/docs/table/table.jpg
character_dict_path: ppocr/utils/dict/table_structure_dict_ch.txt
character_type: en
max_text_length: 500
box_format: xyxyxyxy
infer_mode: false
use_sync_bn: true
save_res_path: output/infer

Optimizer:
name: AdamW
beta1: 0.9
beta2: 0.999
clip_norm: 5.0
lr:
name: Cosine
learning_rate: 0.0001
warmup_epoch: 1
regularizer:
name: L2
factor: 0.0

Architecture:
model_type: table
algorithm: SLANeXt
Backbone:
name: Vary_VIT_B
image_size: 512
encoder_embed_dim: 768
encoder_depth: 12
encoder_num_heads: 12
encoder_global_attn_indexes: [2, 5, 8, 11]
Head:
name: SLAHead
hidden_size: 512
max_text_length: 500
loc_reg_num: 8

Loss:
name: SLALoss
structure_weight: 1.0
# SLANeXt does not train the cell location task by default, set the loc_weight if needed.
loc_weight: 0.0
loc_loss: smooth_l1

PostProcess:
name: TableLabelDecode
merge_no_span_structure: true

Metric:
name: TableMetric
main_indicator: acc
compute_bbox_metric: false
loc_reg_num: 8
box_format: xyxyxyxy
del_thead_tbody: true

Train:
dataset:
name: PubTabDataSet
data_dir: train_data/table/train/
label_file_list:
- train_data/table/train.txt
ratio_list:
- 1
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- TableLabelEncode:
learn_empty_box: false
merge_no_span_structure: true
replace_empty_cell_token: false
loc_reg_num: 8
max_text_length: 500
- TableBoxEncode:
in_box_format: xyxyxyxy
out_box_format: xyxyxyxy
- ResizeTableImage:
max_len: 512
resize_bboxes: true
- NormalizeImage:
scale: 1./255.
mean:
- 0.485
- 0.456
- 0.406
std:
- 0.229
- 0.224
- 0.225
order: hwc
- PaddingTableImage:
size:
- 512
- 512
- ToCHWImage: null
- KeepKeys:
keep_keys:
- image
- structure
- bboxes
- bbox_masks
- length
- shape
loader:
shuffle: true
batch_size_per_card: 48
drop_last: true
num_workers: 1

Eval:
dataset:
name: PubTabDataSet
data_dir: train_data/table/val/
label_file_list:
- train_data/table/val.txt
transforms:
- DecodeImage:
img_mode: BGR
channel_first: false
- TableLabelEncode:
learn_empty_box: false
merge_no_span_structure: true
replace_empty_cell_token: false
loc_reg_num: 8
max_text_length: 500
- TableBoxEncode:
in_box_format: xyxyxyxy
out_box_format: xyxyxyxy
- ResizeTableImage:
max_len: 512
resize_bboxes: true
- NormalizeImage:
scale: 1./255.
mean:
- 0.485
- 0.456
- 0.406
std:
- 0.229
- 0.224
- 0.225
order: hwc
- PaddingTableImage:
size:
- 512
- 512
- ToCHWImage: null
- KeepKeys:
keep_keys:
- image
- structure
- bboxes
- bbox_masks
- length
- shape
loader:
shuffle: false
drop_last: false
batch_size_per_card: 48
num_workers: 1

profiler_options: null
4 changes: 2 additions & 2 deletions ppocr/losses/table_att_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


class TableAttentionLoss(nn.Layer):
def __init__(self, structure_weight, loc_weight, **kwargs):
def __init__(self, structure_weight=1.0, loc_weight=0.0, **kwargs):
super(TableAttentionLoss, self).__init__()
self.loss_func = nn.CrossEntropyLoss(weight=None, reduction="none")
self.structure_weight = structure_weight
Expand Down Expand Up @@ -58,7 +58,7 @@ def forward(self, predicts, batch):


class SLALoss(nn.Layer):
def __init__(self, structure_weight, loc_weight, loc_loss="mse", **kwargs):
def __init__(self, structure_weight=1.0, loc_weight=0.0, loc_loss="mse", **kwargs):
super(SLALoss, self).__init__()
self.loss_func = nn.CrossEntropyLoss(weight=None, reduction="mean")
self.structure_weight = structure_weight
Expand Down
Loading

0 comments on commit ae67d96

Please sign in to comment.