From 935ac0e80fc133d11d552cb1d44454f90b769506 Mon Sep 17 00:00:00 2001 From: "w.jurasz" Date: Wed, 4 Sep 2019 21:07:57 +0200 Subject: [PATCH] Backported serialization to python 3.6 --- ros1_roboy/Dockerfile | 7 +++---- ros1_roboy/STT_server.py | 4 ++-- setup.py | 5 +++-- sonosco/inference/dummp_asr.py | 6 ++++-- sonosco/model/serialization.py | 9 ++++++--- tests/test_model_trainer_serialization.py | 4 +++- 6 files changed, 21 insertions(+), 14 deletions(-) diff --git a/ros1_roboy/Dockerfile b/ros1_roboy/Dockerfile index f7b0e83..b0ab104 100644 --- a/ros1_roboy/Dockerfile +++ b/ros1_roboy/Dockerfile @@ -6,13 +6,12 @@ WORKDIR /ros1 COPY . . RUN apt update -RUN apt install python3.7 python3.7-dev ffmpeg iputils-ping -y +RUN apt install ffmpeg iputils-ping -y RUN git clone https://github.com/Roboy/sonosco.git -RUN cd sonosco; git checkout demo; python3.7 -m pip install -e . -RUN python3.7 -m pip install webrtcvad +RUN cd sonosco; git checkout develop; pip3 install . RUN chmod +x STT_server.py RUN chmod +x STT_client.py -#RUN . ~/melodic_ws/devel/setup.bash +#RUN source ~/melodic_ws/devel/setup.bash #RUN roscore & #ENTRYPOINT [ "bash", "-c", "source /opt/ros/melodic/setup.bash; python3 STT_srv.py" ] diff --git a/ros1_roboy/STT_server.py b/ros1_roboy/STT_server.py index e2978f6..b2ea5af 100644 --- a/ros1_roboy/STT_server.py +++ b/ros1_roboy/STT_server.py @@ -4,7 +4,7 @@ from sonosco.models import DeepSpeech2 -from sonosco.models.deepspeech2_inference import DeepSpeech2Inference +from sonosco.inference.deepspeech2_inference import DeepSpeech2Inference from sonosco.ros1.server import SonoscoROS1 from roboy_cognition_msgs.srv import RecognizeSpeech from roboy_control_msgs.msg import ControlLeds @@ -14,7 +14,7 @@ model_path = "pretrained/deepspeech_final.pth" -asr = DeepSpeech2Inference(DeepSpeech2.load_model(model_path)) +asr = DeepSpeech2Inference(model_path) leave = False got_a_sentence = False diff --git a/setup.py b/setup.py index da95e74..0bbffcd 100644 --- a/setup.py +++ b/setup.py @@ -20,9 +20,10 @@ include_package_data=True, dependency_links=[], classifiers=[ + "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", "License :: OSI Approved :: BSD License", ], - python_requires='>=3.7', - install_requires = required + python_requires='>=3.6', + install_requires=required ) diff --git a/sonosco/inference/dummp_asr.py b/sonosco/inference/dummp_asr.py index 3ec5c47..478c56e 100644 --- a/sonosco/inference/dummp_asr.py +++ b/sonosco/inference/dummp_asr.py @@ -3,9 +3,11 @@ class DummyASR(SonoscoASR): - def __init__(self) -> None: - super().__init__(None) + super().__init__("") def infer(self, sound_bytes): return "dummy transcript" + + def infer_from_path(self, path: str) -> str: + return "dummy transcript" diff --git a/sonosco/model/serialization.py b/sonosco/model/serialization.py index d84fef1..319c47a 100644 --- a/sonosco/model/serialization.py +++ b/sonosco/model/serialization.py @@ -3,12 +3,13 @@ from sonosco.common.constants import CLASS_MODULE_FIELD, CLASS_NAME_FIELD, SERIALIZED_FIELD from dataclasses import _process_class, _create_fn, _set_new_attribute, fields import typing +from typing import List, Set, Tuple, Dict from torch import nn import torch __primitives = {int, float, str, bool} # TODO support for dict with Union value type -__iterables = [list, set, tuple, dict] +__iterables = [list, set, tuple, dict, List, Set, Tuple, Dict] # TODO: Prevent user from serializing lambdas. @@ -115,7 +116,7 @@ def __create_serialize_body(fields_to_serialize: typing.Iterable, model: bool, e body_lines.append(f" {c.name} = {{") __encode_serializable_serialization(body_lines, c) body_lines.append(f"}}") - body_lines.append(f"else: raise TypeError(\"Callable must be a function for now\")") + body_lines.append(f"else: raise TypeError(f\"Callable must be a function for now: {c.name}\")") for field in callable_iterables: body_lines.append(f"{field.name} = []") @@ -275,4 +276,6 @@ def __is_callable(obj: any) -> bool: objs = list(obj.__args__) else: objs = [obj] - return all([hasattr(obj, '__origin__') and obj.__origin__ == collections.abc.Callable for obj in objs]) + return all([hasattr(obj, '__origin__') and ( + obj.__origin__ == collections.abc.Callable or obj.__origin__ == typing.Callable) + for obj in objs]) diff --git a/tests/test_model_trainer_serialization.py b/tests/test_model_trainer_serialization.py index 50b78f4..c6fc174 100644 --- a/tests/test_model_trainer_serialization.py +++ b/tests/test_model_trainer_serialization.py @@ -20,6 +20,7 @@ SOS = '#' PADDING_VALUE = '%' + def test_mode_trainer_serialization(): config_path = "model_trainer_config_test.yaml" config = parse_yaml(config_path)["train"] @@ -60,4 +61,5 @@ def test_mode_trainer_serialization(): 'test_data_loader': test_loader, }, with_config=True) assert trainer_deserialized is not None - assert deserialized_config == config \ No newline at end of file + assert deserialized_config == config +