Skip to content

Commit

Permalink
performance: support padatious (#42)
Browse files Browse the repository at this point in the history
* performance: support padatious

* performance: support padatious

* tests
  • Loading branch information
JarbasAl authored Dec 11, 2024
1 parent 065d3a2 commit 0c8b517
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 3 deletions.
43 changes: 40 additions & 3 deletions ocp_pipeline/opm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,8 @@
from ovos_utils.ocp import MediaType, PlaybackType, PlaybackMode, PlayerState, OCP_ID, \
MediaEntry, Playlist, MediaState, TrackState, dict2entry, PluginStream
from ovos_workshop.app import OVOSAbstractApplication
from padacioso import IntentContainer
from ovos_utils.xdg_utils import xdg_data_home
from ovos_config.meta import get_xdg_base

from ocp_pipeline.feats import OCPFeaturizer
from ocp_pipeline.legacy import LegacyCommonPlay
Expand All @@ -48,6 +49,7 @@ class OCPPipelineMatcher(ConfidenceMatcherPipeline, OVOSAbstractApplication):
"next.intent", "prev.intent", "pause.intent", "play_favorites.intent",
"resume.intent", "like_song.intent"]
intent_matchers = {}
intent_cache = f"{xdg_data_home()}/{get_xdg_base()}/intent_cache"

def __init__(self, bus: Optional[Union[MessageBusClient, FakeBus]] = None,
config: Optional[Dict] = None):
Expand Down Expand Up @@ -150,9 +152,22 @@ def register_ocp_api_events(self):
@classmethod
def load_intent_files(cls):
intent_files = cls.load_resource_files()

try:
from ovos_padatious import IntentContainer
is_padatious = True
except ImportError:
from padacioso import IntentContainer
is_padatious = False
LOG.warning("Padatious not available, using padacioso. intent matching will be orders of magnitude slower!")

for lang, intent_data in intent_files.items():
lang = standardize_lang_tag(lang)
cls.intent_matchers[lang] = IntentContainer()
if is_padatious:
cache = f"{cls.intent_cache}/{lang}"
cls.intent_matchers[lang] = IntentContainer(cache)
else:
cls.intent_matchers[lang] = IntentContainer()
for intent_name in cls.intents:
samples = intent_data.get(intent_name)
if samples:
Expand Down Expand Up @@ -301,6 +316,10 @@ def handle_player_state_update(self, message: Message):
def match_high(self, utterances: List[str], lang: str, message: Message = None) -> Optional[IntentHandlerMatch]:
""" exact matches only, handles playback control
recommended after high confidence intents pipeline stage """

if not len(self.skill_aliases): # skill_id registered when skills load
return None # dont waste compute cycles, no media skills -> no match

lang = self._get_closest_lang(lang)
if lang is None: # no intents registered for this lang
return None
Expand All @@ -310,9 +329,21 @@ def match_high(self, utterances: List[str], lang: str, message: Message = None)
utterance = utterances[0].lower()
match = self.intent_matchers[lang].calc_intent(utterance)

if hasattr(match, "name"): # padatious
match = {
"name": match.name,
"conf": match.conf,
"entities": match.matches
}

if match["name"] is None:
return None
LOG.info(f"OCP exact match: {match}")

if match.get("conf", 1.0) < 0.7:
LOG.debug(f"Ignoring low confidence OCP match: {match}")
return None

LOG.info(f"OCP match: {match}")

player = self.get_player(message)

Expand Down Expand Up @@ -1128,6 +1159,12 @@ def match(self, utterances: List[str], lang: str, message: Message = None) -> Op
return None

match = OCPPipelineMatcher.intent_matchers[lang].calc_intent(utterance)
if hasattr(match, "name"): # padatious
match = {
"name": match.name,
"conf": match.conf,
"entities": match.matches
}

if match["name"] is None:
return None
Expand Down
2 changes: 2 additions & 0 deletions tests/test_ocp.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ def setUp(self):
os.path.dirname(ocp_pipeline.opm.__file__) + "/models/ocp_entities_v0.csv"
]}
self.ocp = OCPPipelineMatcher(config=config)
self.ocp.skill_aliases["test"] = ["Test Skill"] # pretend a skill is loaded or matching is skipped

def test_match_high(self):
result = self.ocp.match_high(["play metallica"], "en-US")
Expand Down Expand Up @@ -114,6 +115,7 @@ def setUp(self):
os.path.dirname(ocp_pipeline.opm.__file__) + "/models/ocp_entities_v0.csv"
]}
self.ocp = OCPPipelineMatcher(config=config)
self.ocp.skill_aliases["test"] = ["Test Skill"] # pretend a skill is loaded or matching is skipped

def test_match_high(self):
result = self.ocp.match_high(["play metallica"], "en-US")
Expand Down

0 comments on commit 0c8b517

Please sign in to comment.