Skip to content

Commit

Permalink
modify export with pir (#14441)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhangyubo0722 authored Dec 30, 2024
1 parent 0d41ffc commit 2f0a29e
Showing 1 changed file with 27 additions and 31 deletions.
58 changes: 27 additions & 31 deletions ppocr/utils/export_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -318,42 +318,31 @@ def dynamic_to_static(model, arch_config, logger, input_shape=None):


def export_single_model(
model, arch_config, save_path, logger, yaml_path, input_shape=None, quanter=None
model,
arch_config,
save_path,
logger,
yaml_path,
config,
input_shape=None,
quanter=None,
):

model = dynamic_to_static(model, arch_config, logger, input_shape)

if quanter is None:
paddle_version = version.parse(paddle.__version__)
if (
paddle_version >= version.parse("3.0.0b2")
or paddle_version == version.parse("0.0.0")
) and os.environ.get("FLAGS_enable_pir_api", None) not in ["0", "False"]:
save_path = os.path.dirname(save_path)
for enable_pir in [True, False]:
if not enable_pir:
save_path_no_pir = os.path.join(save_path, "inference")
model.forward.rollback()
with paddle.pir_utils.OldIrGuard():
model = dynamic_to_static(
model, arch_config, logger, input_shape
)
paddle.jit.save(model, save_path_no_pir)
else:
save_path_pir = os.path.join(
os.path.dirname(save_path),
f"{os.path.basename(save_path)}_pir",
"inference",
)
paddle.jit.save(model, save_path_pir)
shutil.copy(
yaml_path,
os.path.join(
os.path.dirname(save_path_pir), os.path.basename(yaml_path)
),
)
else:
if config["Global"].get("export_with_pir", False):
paddle_version = version.parse(paddle.__version__)
assert (
paddle_version >= version.parse("3.0.0b2")
or paddle_version == version.parse("0.0.0")
) and os.environ.get("FLAGS_enable_pir_api", None) not in ["0", "False"]
paddle.jit.save(model, save_path)
else:
model.forward.rollback()
with paddle.pir_utils.OldIrGuard():
model = dynamic_to_static(model, arch_config, logger, input_shape)
paddle.jit.save(model, save_path)
else:
quanter.save_quantized_model(model, save_path)
logger.info("inference model is saved to {}".format(save_path))
Expand Down Expand Up @@ -472,9 +461,16 @@ def export(config, base_model=None, save_path=None):
sub_model_save_path,
logger,
yaml_path,
config,
)
else:
save_path = os.path.join(save_path, "inference")
export_single_model(
model, arch_config, save_path, logger, yaml_path, input_shape=input_shape
model,
arch_config,
save_path,
logger,
yaml_path,
config,
input_shape=input_shape,
)

0 comments on commit 2f0a29e

Please sign in to comment.