Skip to content

Commit

Permalink
Backported serialization to python 3.6
Browse files Browse the repository at this point in the history
  • Loading branch information
w.jurasz committed Sep 4, 2019
1 parent 59fc7c2 commit 935ac0e
Show file tree
Hide file tree
Showing 6 changed files with 21 additions and 14 deletions.
7 changes: 3 additions & 4 deletions ros1_roboy/Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,12 @@ WORKDIR /ros1
COPY . .

RUN apt update
RUN apt install python3.7 python3.7-dev ffmpeg iputils-ping -y
RUN apt install ffmpeg iputils-ping -y

RUN git clone https://github.com/Roboy/sonosco.git
RUN cd sonosco; git checkout demo; python3.7 -m pip install -e .
RUN python3.7 -m pip install webrtcvad
RUN cd sonosco; git checkout develop; pip3 install .
RUN chmod +x STT_server.py
RUN chmod +x STT_client.py
#RUN . ~/melodic_ws/devel/setup.bash
#RUN source ~/melodic_ws/devel/setup.bash
#RUN roscore &
#ENTRYPOINT [ "bash", "-c", "source /opt/ros/melodic/setup.bash; python3 STT_srv.py" ]
4 changes: 2 additions & 2 deletions ros1_roboy/STT_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from sonosco.models import DeepSpeech2

from sonosco.models.deepspeech2_inference import DeepSpeech2Inference
from sonosco.inference.deepspeech2_inference import DeepSpeech2Inference
from sonosco.ros1.server import SonoscoROS1
from roboy_cognition_msgs.srv import RecognizeSpeech
from roboy_control_msgs.msg import ControlLeds
Expand All @@ -14,7 +14,7 @@

model_path = "pretrained/deepspeech_final.pth"

asr = DeepSpeech2Inference(DeepSpeech2.load_model(model_path))
asr = DeepSpeech2Inference(model_path)
leave = False
got_a_sentence = False

Expand Down
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,10 @@
include_package_data=True,
dependency_links=[],
classifiers=[
"Programming Language :: Python :: 3.6",
"Programming Language :: Python :: 3.7",
"License :: OSI Approved :: BSD License",
],
python_requires='>=3.7',
install_requires = required
python_requires='>=3.6',
install_requires=required
)
6 changes: 4 additions & 2 deletions sonosco/inference/dummp_asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,11 @@

class DummyASR(SonoscoASR):


def __init__(self) -> None:
super().__init__(None)
super().__init__("")

def infer(self, sound_bytes):
return "dummy transcript"

def infer_from_path(self, path: str) -> str:
return "dummy transcript"
9 changes: 6 additions & 3 deletions sonosco/model/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,13 @@
from sonosco.common.constants import CLASS_MODULE_FIELD, CLASS_NAME_FIELD, SERIALIZED_FIELD
from dataclasses import _process_class, _create_fn, _set_new_attribute, fields
import typing
from typing import List, Set, Tuple, Dict
from torch import nn
import torch

__primitives = {int, float, str, bool}
# TODO support for dict with Union value type
__iterables = [list, set, tuple, dict]
__iterables = [list, set, tuple, dict, List, Set, Tuple, Dict]


# TODO: Prevent user from serializing lambdas.
Expand Down Expand Up @@ -115,7 +116,7 @@ def __create_serialize_body(fields_to_serialize: typing.Iterable, model: bool, e
body_lines.append(f" {c.name} = {{")
__encode_serializable_serialization(body_lines, c)
body_lines.append(f"}}")
body_lines.append(f"else: raise TypeError(\"Callable must be a function for now\")")
body_lines.append(f"else: raise TypeError(f\"Callable must be a function for now: {c.name}\")")

for field in callable_iterables:
body_lines.append(f"{field.name} = []")
Expand Down Expand Up @@ -275,4 +276,6 @@ def __is_callable(obj: any) -> bool:
objs = list(obj.__args__)
else:
objs = [obj]
return all([hasattr(obj, '__origin__') and obj.__origin__ == collections.abc.Callable for obj in objs])
return all([hasattr(obj, '__origin__') and (
obj.__origin__ == collections.abc.Callable or obj.__origin__ == typing.Callable)
for obj in objs])
4 changes: 3 additions & 1 deletion tests/test_model_trainer_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
SOS = '#'
PADDING_VALUE = '%'


def test_mode_trainer_serialization():
config_path = "model_trainer_config_test.yaml"
config = parse_yaml(config_path)["train"]
Expand Down Expand Up @@ -60,4 +61,5 @@ def test_mode_trainer_serialization():
'test_data_loader': test_loader,
}, with_config=True)
assert trainer_deserialized is not None
assert deserialized_config == config
assert deserialized_config == config

0 comments on commit 935ac0e

Please sign in to comment.