Skip to content

Commit

Permalink
feat: make HRIConnector and ARIConnector classes generic (#375)
Browse files Browse the repository at this point in the history
  • Loading branch information
rachwalk authored Jan 17, 2025
1 parent c7d1f2b commit 1595c52
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 13 deletions.
7 changes: 5 additions & 2 deletions src/rai/rai/communication/ari_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Optional
from typing import Generic, Optional, TypeVar

from pydantic import Field

Expand All @@ -37,7 +37,10 @@ class ROS2RRIMessage(ARIMessage):
)


class ARIConnector(BaseConnector[ARIMessage]):
T = TypeVar("T", bound=ARIMessage)


class ARIConnector(Generic[T], BaseConnector[T]):
"""
Base class for Agent-Robot Interface (ARI) connectors.
Expand Down
8 changes: 5 additions & 3 deletions src/rai/rai/communication/base_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,14 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from abc import ABC, abstractmethod
from typing import Any, Callable, Generic, Optional, TypeVar
from abc import abstractmethod
from typing import Any, Callable, Generic, Optional, Protocol, TypeVar
from uuid import uuid4


class BaseMessage(ABC):
class BaseMessage(Protocol):
payload: Any

def __init__(self, payload: Any, *args, **kwargs):
self.payload = payload

Expand Down
36 changes: 29 additions & 7 deletions src/rai/rai/communication/hri_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Annotated, List, Literal, Optional, Sequence
from typing import (
Annotated,
Generic,
List,
Literal,
Optional,
Sequence,
TypeVar,
get_args,
)

from langchain_core.messages import AIMessage
from langchain_core.messages import BaseMessage as LangchainBaseMessage
Expand All @@ -25,6 +34,11 @@
from .base_connector import BaseConnector, BaseMessage


class HRIException(Exception):
def __init__(self, msg):
super().__init__(msg)


class HRIPayload(BaseModel):
text: str
images: Optional[Annotated[List[str], "base64 encoded png images"]] = None
Expand All @@ -42,8 +56,6 @@ def __init__(
self.images = payload.images
self.audios = payload.audios

# type: Literal["ai", "human"]

def __repr__(self):
return f"HRIMessage(type={self.message_author}, text={self.text}, images={self.images}, audios={self.audios})"

Expand Down Expand Up @@ -91,7 +103,10 @@ def from_langchain(
)


class HRIConnector(BaseConnector[HRIMessage]):
T = TypeVar("T", bound=HRIMessage)


class HRIConnector(Generic[T], BaseConnector[T]):
"""
Base class for Human-Robot Interaction (HRI) connectors.
Used for sending and receiving messages between human and robot from various sources.
Expand All @@ -105,19 +120,26 @@ def __init__(
):
self.configured_targets = configured_targets
self.configured_sources = configured_sources
if not hasattr(self, "__orig_bases__"):
self.__orig_bases__ = {}
raise HRIException(
f"Error while instantiating {str(self.__class__)}: Message type T derived from HRIMessage needs to be provided e.g. Connector[MessageType]()"
)
self.T_class = get_args(self.__orig_bases__[0])[0]

def _build_message(
self,
message: LangchainBaseMessage | RAIMultimodalMessage,
) -> HRIMessage:
return HRIMessage.from_langchain(message)
) -> T:

return self.T_class.from_langchain(message)

def send_all_targets(self, message: LangchainBaseMessage | RAIMultimodalMessage):
for target in self.configured_targets:
to_send = self._build_message(message)
self.send_message(to_send, target)

def receive_all_sources(self, timeout_sec: float = 1.0) -> dict[str, HRIMessage]:
def receive_all_sources(self, timeout_sec: float = 1.0) -> dict[str, T]:
ret = {}
for source in self.configured_sources:
received = self.receive_message(source, timeout_sec)
Expand Down
2 changes: 1 addition & 1 deletion src/rai/rai/communication/sound_device_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def __init__(self, config: AudioInputDeviceConfig):
self.dtype = config["dtype"]


class StreamingAudioInputDevice(HRIConnector):
class StreamingAudioInputDevice(HRIConnector[HRIMessage]):
"""Audio input device connector implementing the Human-Robot Interface.
This class provides audio streaming capabilities while conforming to the
Expand Down

0 comments on commit 1595c52

Please sign in to comment.