From 8c63d186f52631ce88c290af216ff06abd2c516c Mon Sep 17 00:00:00 2001 From: mjamroz Date: Sun, 27 Aug 2023 11:06:15 +0200 Subject: [PATCH] convert to onnx --- main_conversion.py | 15 ++++++++------- 1 file changed, 8 insertions(+), 7 deletions(-) diff --git a/main_conversion.py b/main_conversion.py index 96033a7..7e4e457 100644 --- a/main_conversion.py +++ b/main_conversion.py @@ -13,7 +13,7 @@ from options.opts import get_conversion_arguments from utils import logger from utils.checkpoint_utils import CHECKPOINT_EXTN -from utils.pytorch_to_coreml import convert_pytorch_to_coreml +from utils.pytorch_to_coreml import convert_pytorch_to_coreml, create_rand_tensor def main_worker_conversion(args: Optional[List[str]] = None): @@ -45,14 +45,17 @@ def main_worker_conversion(args: Optional[List[str]] = None): conversion_success = False try: - converted_models_dict = convert_pytorch_to_coreml( - opts=opts, pytorch_model=model - ) + converted_models_dict = convert_pytorch_to_coreml(opts=opts, pytorch_model=model) coreml_model = converted_models_dict["coreml"] jit_model = converted_models_dict["jit"] jit_optimized = converted_models_dict["jit_optimized"] coreml_model.save(model_dst_loc) + torch.onnx.export( + model, + create_rand_tensor(opts=opts, device="cpu"), + model_dst_loc.replace(f".{coreml_extn}", ".onnx"), + ) torch.jit.save( jit_model, @@ -67,9 +70,7 @@ def main_worker_conversion(args: Optional[List[str]] = None): conversion_success = True except Exception as e: logger.error( - "PyTorch to CoreML conversion failed. See below for error details:\n {}".format( - e - ) + "PyTorch to CoreML conversion failed. See below for error details:\n {}".format(e) ) if conversion_success: