Skip to content

Commit

Permalink
fix:improve_media_clf (#46)
Browse files Browse the repository at this point in the history
ensure correct MediaType if a skill is explicitly requested

skip media classification if only 1 MediaType is available

restrict valid media classifications to installed skills media types
  • Loading branch information
JarbasAl authored Dec 19, 2024
1 parent 818bc6c commit 4c313cd
Showing 1 changed file with 42 additions and 29 deletions.
71 changes: 42 additions & 29 deletions ocp_pipeline/opm.py
Original file line number Diff line number Diff line change
Expand Up @@ -484,11 +484,15 @@ def _process_play_query(self, utterance: str, lang: str, match: dict = None,
if skill_id not in sess.blacklisted_skills and
any(s.lower() in utterance for s in samples)
]
valid_labels = []
if valid_skills:
LOG.info(f"OCP specific skill names matched: {valid_skills}")
for mtype, skills in self.media2skill.items():
if any([s in skills for s in valid_skills]):
valid_labels.append(mtype)

# classify the query media type
media_type, conf = self.classify_media(utterance, lang)
media_type, conf = self.classify_media(utterance, lang, valid_labels=valid_labels)

# extract the query string
query = self.remove_voc(utterance, "Play", lang).strip()
Expand Down Expand Up @@ -692,70 +696,77 @@ def handle_search_error_intent(self, message: Message):
self.ocp_api.stop(source_message=message)

# NLP
def voc_match_media(self, query: str, lang: str) -> Tuple[MediaType, float]:
def voc_match_media(self, query: str, lang: str, valid_labels: Optional[List[MediaType]] = None) -> Tuple[MediaType, float]:
lang = standardize_lang_tag(lang)
valid_labels = valid_labels or [m for m, s in self.media2skill.items() if s] or list(MediaType)
# simplistic approach via voc_match, works anywhere
# and it's easy to localize, but isn't very accurate
if self.voc_match(query, "MusicKeyword", lang=lang):
if MediaType.MUSIC in valid_labels and self.voc_match(query, "MusicKeyword", lang=lang):
# NOTE - before movie to handle "{movie_name} soundtrack"
return MediaType.MUSIC, 0.6
elif self.voc_match(query, "MovieKeyword", lang=lang):
if self.voc_match(query, "ShortKeyword", lang=lang):
elif any([s in valid_labels for s in [MediaType.MOVIE, MediaType.SHORT_FILM, MediaType.SILENT_MOVIE, MediaType.BLACK_WHITE_MOVIE]]) and \
self.voc_match(query, "MovieKeyword", lang=lang):
if MediaType.SHORT_FILM in valid_labels and self.voc_match(query, "ShortKeyword", lang=lang):
return MediaType.SHORT_FILM, 0.7
elif self.voc_match(query, "SilentKeyword", lang=lang):
elif MediaType.SILENT_MOVIE in valid_labels and self.voc_match(query, "SilentKeyword", lang=lang):
return MediaType.SILENT_MOVIE, 0.7
elif self.voc_match(query, "BWKeyword", lang=lang):
elif MediaType.BLACK_WHITE_MOVIE in valid_labels and self.voc_match(query, "BWKeyword", lang=lang):
return MediaType.BLACK_WHITE_MOVIE, 0.7
return MediaType.MOVIE, 0.6
elif self.voc_match(query, "DocumentaryKeyword", lang=lang):
elif MediaType.DOCUMENTARY in valid_labels and self.voc_match(query, "DocumentaryKeyword", lang=lang):
return MediaType.DOCUMENTARY, 0.6
elif self.voc_match(query, "AudioBookKeyword", lang=lang):
elif MediaType.AUDIOBOOK in valid_labels and self.voc_match(query, "AudioBookKeyword", lang=lang):
return MediaType.AUDIOBOOK, 0.6
elif self.voc_match(query, "NewsKeyword", lang=lang):
elif MediaType.NEWS in valid_labels and self.voc_match(query, "NewsKeyword", lang=lang):
return MediaType.NEWS, 0.6
elif self.voc_match(query, "AnimeKeyword", lang=lang):
elif MediaType.ANIME in valid_labels and self.voc_match(query, "AnimeKeyword", lang=lang):
return MediaType.ANIME, 0.6
elif self.voc_match(query, "CartoonKeyword", lang=lang):
elif MediaType.CARTOON in valid_labels and self.voc_match(query, "CartoonKeyword", lang=lang):
return MediaType.CARTOON, 0.6
elif self.voc_match(query, "PodcastKeyword", lang=lang):
elif MediaType.PODCAST in valid_labels and self.voc_match(query, "PodcastKeyword", lang=lang):
return MediaType.PODCAST, 0.6
elif self.voc_match(query, "TVKeyword", lang=lang):
elif MediaType.TV in valid_labels and self.voc_match(query, "TVKeyword", lang=lang):
return MediaType.TV, 0.6
elif self.voc_match(query, "SeriesKeyword", lang=lang):
elif MediaType.VIDEO_EPISODES in valid_labels and self.voc_match(query, "SeriesKeyword", lang=lang):
return MediaType.VIDEO_EPISODES, 0.6
elif self.voc_match(query, "AudioDramaKeyword", lang=lang):
elif MediaType.RADIO_THEATRE in valid_labels and self.voc_match(query, "AudioDramaKeyword", lang=lang):
# NOTE - before "radio" to allow "radio theatre"
return MediaType.RADIO_THEATRE, 0.6
elif self.voc_match(query, "RadioKeyword", lang=lang):
elif MediaType.RADIO in valid_labels and self.voc_match(query, "RadioKeyword", lang=lang):
return MediaType.RADIO, 0.6
elif self.voc_match(query, "ComicBookKeyword", lang=lang):
elif MediaType.VISUAL_STORY in valid_labels and self.voc_match(query, "ComicBookKeyword", lang=lang):
return MediaType.VISUAL_STORY, 0.4
elif self.voc_match(query, "GameKeyword", lang=lang):
elif MediaType.GAME in valid_labels and self.voc_match(query, "GameKeyword", lang=lang):
return MediaType.GAME, 0.4
elif self.voc_match(query, "ADKeyword", lang=lang):
elif MediaType.AUDIO_DESCRIPTION in valid_labels and self.voc_match(query, "ADKeyword", lang=lang):
return MediaType.AUDIO_DESCRIPTION, 0.4
elif self.voc_match(query, "ASMRKeyword", lang=lang):
elif MediaType.ASMR in valid_labels and self.voc_match(query, "ASMRKeyword", lang=lang):
return MediaType.ASMR, 0.4
elif self.voc_match(query, "AdultKeyword", lang=lang):
if self.voc_match(query, "CartoonKeyword", lang=lang) or \
elif any([s in valid_labels for s in [MediaType.ADULT, MediaType.HENTAI, MediaType.ADULT_AUDIO]]) and self.voc_match(query, "AdultKeyword", lang=lang):
if MediaType.HENTAI in valid_labels and self.voc_match(query, "CartoonKeyword", lang=lang) or \
self.voc_match(query, "AnimeKeyword", lang=lang) or \
self.voc_match(query, "HentaiKeyword", lang=lang):
return MediaType.HENTAI, 0.4
elif self.voc_match(query, "AudioKeyword", lang=lang) or \
elif MediaType.ADULT_AUDIO in valid_labels and self.voc_match(query, "AudioKeyword", lang=lang) or \
self.voc_match(query, "ASMRKeyword", lang=lang):
return MediaType.ADULT_AUDIO, 0.4
return MediaType.ADULT, 0.4
elif self.voc_match(query, "HentaiKeyword", lang=lang):
elif MediaType.HENTAI in valid_labels and self.voc_match(query, "HentaiKeyword", lang=lang):
return MediaType.HENTAI, 0.4
elif self.voc_match(query, "VideoKeyword", lang=lang):
elif MediaType.VIDEO in valid_labels and self.voc_match(query, "VideoKeyword", lang=lang):
return MediaType.VIDEO, 0.4
elif self.voc_match(query, "AudioKeyword", lang=lang):
elif MediaType.AUDIO in valid_labels and self.voc_match(query, "AudioKeyword", lang=lang):
return MediaType.AUDIO, 0.4
return MediaType.GENERIC, 0.0

def classify_media(self, query: str, lang: str) -> Tuple[MediaType, float]:
def classify_media(self, query: str, lang: str, valid_labels: Optional[List[MediaType]] = None) -> Tuple[MediaType, float]:
""" determine what media type is being requested """
lang = standardize_lang_tag(lang)
valid_labels = valid_labels or [m for m, s in self.media2skill.items() if s] or list(MediaType)
LOG.debug(f"valid media types: {valid_labels}")
if len(valid_labels) == 1:
return valid_labels[0], 1.0

# using a trained classifier (Experimental)
if self.config.get("experimental_media_classifier", False):
from ovos_classifiers.skovos.classifier import SklearnOVOSClassifier
Expand All @@ -768,6 +779,8 @@ def classify_media(self, query: str, lang: str) -> Tuple[MediaType, float]:
featurizer: OCPFeaturizer = self._media_clf[1]
X = featurizer.transform([query])
preds = clf.predict_labels(X)[0]
preds = {k: v for k, v in preds.items()
if OCPFeaturizer.label2media(k) in valid_labels}
label = max(preds, key=preds.get)
prob = float(round(preds[label], 3))
LOG.info(f"OVOSCommonPlay MediaType prediction: {label} confidence: {prob}")
Expand All @@ -779,7 +792,7 @@ def classify_media(self, query: str, lang: str) -> Tuple[MediaType, float]:
return OCPFeaturizer.label2media(label), prob
except:
LOG.exception(f"OCP classifier exception: {query}")
return self.voc_match_media(query, lang)
return self.voc_match_media(query, lang, valid_labels)

def is_ocp_query(self, query: str, lang: str) -> Tuple[bool, float]:
""" determine if a playback question is being asked"""
Expand Down

0 comments on commit 4c313cd

Please sign in to comment.