From 70d8de30c5615d44a15d859221c0ffa14ee420de Mon Sep 17 00:00:00 2001 From: miro Date: Wed, 18 Dec 2024 23:39:17 +0000 Subject: [PATCH] feat:improve_media_clf ensure correct MediaType if a skill is explicitly requested skip media classification if only 1 MediaType is available restrict valid media classifications to installed skills media types --- ocp_pipeline/opm.py | 71 +++++++++++++++++++++++++++------------------ 1 file changed, 42 insertions(+), 29 deletions(-) diff --git a/ocp_pipeline/opm.py b/ocp_pipeline/opm.py index ebe89e0..d47093e 100644 --- a/ocp_pipeline/opm.py +++ b/ocp_pipeline/opm.py @@ -484,11 +484,15 @@ def _process_play_query(self, utterance: str, lang: str, match: dict = None, if skill_id not in sess.blacklisted_skills and any(s.lower() in utterance for s in samples) ] + valid_labels = [] if valid_skills: LOG.info(f"OCP specific skill names matched: {valid_skills}") + for mtype, skills in self.media2skill.items(): + if any([s in skills for s in valid_skills]): + valid_labels.append(mtype) # classify the query media type - media_type, conf = self.classify_media(utterance, lang) + media_type, conf = self.classify_media(utterance, lang, valid_labels=valid_labels) # extract the query string query = self.remove_voc(utterance, "Play", lang).strip() @@ -692,70 +696,77 @@ def handle_search_error_intent(self, message: Message): self.ocp_api.stop(source_message=message) # NLP - def voc_match_media(self, query: str, lang: str) -> Tuple[MediaType, float]: + def voc_match_media(self, query: str, lang: str, valid_labels: Optional[List[MediaType]] = None) -> Tuple[MediaType, float]: lang = standardize_lang_tag(lang) + valid_labels = valid_labels or [m for m, s in self.media2skill.items() if s] or list(MediaType) # simplistic approach via voc_match, works anywhere # and it's easy to localize, but isn't very accurate - if self.voc_match(query, "MusicKeyword", lang=lang): + if MediaType.MUSIC in valid_labels and self.voc_match(query, "MusicKeyword", lang=lang): # NOTE - before movie to handle "{movie_name} soundtrack" return MediaType.MUSIC, 0.6 - elif self.voc_match(query, "MovieKeyword", lang=lang): - if self.voc_match(query, "ShortKeyword", lang=lang): + elif any([s in valid_labels for s in [MediaType.MOVIE, MediaType.SHORT_FILM, MediaType.SILENT_MOVIE, MediaType.BLACK_WHITE_MOVIE]]) and \ + self.voc_match(query, "MovieKeyword", lang=lang): + if MediaType.SHORT_FILM in valid_labels and self.voc_match(query, "ShortKeyword", lang=lang): return MediaType.SHORT_FILM, 0.7 - elif self.voc_match(query, "SilentKeyword", lang=lang): + elif MediaType.SILENT_MOVIE in valid_labels and self.voc_match(query, "SilentKeyword", lang=lang): return MediaType.SILENT_MOVIE, 0.7 - elif self.voc_match(query, "BWKeyword", lang=lang): + elif MediaType.BLACK_WHITE_MOVIE in valid_labels and self.voc_match(query, "BWKeyword", lang=lang): return MediaType.BLACK_WHITE_MOVIE, 0.7 return MediaType.MOVIE, 0.6 - elif self.voc_match(query, "DocumentaryKeyword", lang=lang): + elif MediaType.DOCUMENTARY in valid_labels and self.voc_match(query, "DocumentaryKeyword", lang=lang): return MediaType.DOCUMENTARY, 0.6 - elif self.voc_match(query, "AudioBookKeyword", lang=lang): + elif MediaType.AUDIOBOOK in valid_labels and self.voc_match(query, "AudioBookKeyword", lang=lang): return MediaType.AUDIOBOOK, 0.6 - elif self.voc_match(query, "NewsKeyword", lang=lang): + elif MediaType.NEWS in valid_labels and self.voc_match(query, "NewsKeyword", lang=lang): return MediaType.NEWS, 0.6 - elif self.voc_match(query, "AnimeKeyword", lang=lang): + elif MediaType.ANIME in valid_labels and self.voc_match(query, "AnimeKeyword", lang=lang): return MediaType.ANIME, 0.6 - elif self.voc_match(query, "CartoonKeyword", lang=lang): + elif MediaType.CARTOON in valid_labels and self.voc_match(query, "CartoonKeyword", lang=lang): return MediaType.CARTOON, 0.6 - elif self.voc_match(query, "PodcastKeyword", lang=lang): + elif MediaType.PODCAST in valid_labels and self.voc_match(query, "PodcastKeyword", lang=lang): return MediaType.PODCAST, 0.6 - elif self.voc_match(query, "TVKeyword", lang=lang): + elif MediaType.TV in valid_labels and self.voc_match(query, "TVKeyword", lang=lang): return MediaType.TV, 0.6 - elif self.voc_match(query, "SeriesKeyword", lang=lang): + elif MediaType.VIDEO_EPISODES in valid_labels and self.voc_match(query, "SeriesKeyword", lang=lang): return MediaType.VIDEO_EPISODES, 0.6 - elif self.voc_match(query, "AudioDramaKeyword", lang=lang): + elif MediaType.RADIO_THEATRE in valid_labels and self.voc_match(query, "AudioDramaKeyword", lang=lang): # NOTE - before "radio" to allow "radio theatre" return MediaType.RADIO_THEATRE, 0.6 - elif self.voc_match(query, "RadioKeyword", lang=lang): + elif MediaType.RADIO in valid_labels and self.voc_match(query, "RadioKeyword", lang=lang): return MediaType.RADIO, 0.6 - elif self.voc_match(query, "ComicBookKeyword", lang=lang): + elif MediaType.VISUAL_STORY in valid_labels and self.voc_match(query, "ComicBookKeyword", lang=lang): return MediaType.VISUAL_STORY, 0.4 - elif self.voc_match(query, "GameKeyword", lang=lang): + elif MediaType.GAME in valid_labels and self.voc_match(query, "GameKeyword", lang=lang): return MediaType.GAME, 0.4 - elif self.voc_match(query, "ADKeyword", lang=lang): + elif MediaType.AUDIO_DESCRIPTION in valid_labels and self.voc_match(query, "ADKeyword", lang=lang): return MediaType.AUDIO_DESCRIPTION, 0.4 - elif self.voc_match(query, "ASMRKeyword", lang=lang): + elif MediaType.ASMR in valid_labels and self.voc_match(query, "ASMRKeyword", lang=lang): return MediaType.ASMR, 0.4 - elif self.voc_match(query, "AdultKeyword", lang=lang): - if self.voc_match(query, "CartoonKeyword", lang=lang) or \ + elif any([s in valid_labels for s in [MediaType.ADULT, MediaType.HENTAI, MediaType.ADULT_AUDIO]]) and self.voc_match(query, "AdultKeyword", lang=lang): + if MediaType.HENTAI in valid_labels and self.voc_match(query, "CartoonKeyword", lang=lang) or \ self.voc_match(query, "AnimeKeyword", lang=lang) or \ self.voc_match(query, "HentaiKeyword", lang=lang): return MediaType.HENTAI, 0.4 - elif self.voc_match(query, "AudioKeyword", lang=lang) or \ + elif MediaType.ADULT_AUDIO in valid_labels and self.voc_match(query, "AudioKeyword", lang=lang) or \ self.voc_match(query, "ASMRKeyword", lang=lang): return MediaType.ADULT_AUDIO, 0.4 return MediaType.ADULT, 0.4 - elif self.voc_match(query, "HentaiKeyword", lang=lang): + elif MediaType.HENTAI in valid_labels and self.voc_match(query, "HentaiKeyword", lang=lang): return MediaType.HENTAI, 0.4 - elif self.voc_match(query, "VideoKeyword", lang=lang): + elif MediaType.VIDEO in valid_labels and self.voc_match(query, "VideoKeyword", lang=lang): return MediaType.VIDEO, 0.4 - elif self.voc_match(query, "AudioKeyword", lang=lang): + elif MediaType.AUDIO in valid_labels and self.voc_match(query, "AudioKeyword", lang=lang): return MediaType.AUDIO, 0.4 return MediaType.GENERIC, 0.0 - def classify_media(self, query: str, lang: str) -> Tuple[MediaType, float]: + def classify_media(self, query: str, lang: str, valid_labels: Optional[List[MediaType]] = None) -> Tuple[MediaType, float]: """ determine what media type is being requested """ lang = standardize_lang_tag(lang) + valid_labels = valid_labels or [m for m, s in self.media2skill.items() if s] or list(MediaType) + LOG.debug(f"valid media types: {valid_labels}") + if len(valid_labels) == 1: + return valid_labels[0], 1.0 + # using a trained classifier (Experimental) if self.config.get("experimental_media_classifier", False): from ovos_classifiers.skovos.classifier import SklearnOVOSClassifier @@ -768,6 +779,8 @@ def classify_media(self, query: str, lang: str) -> Tuple[MediaType, float]: featurizer: OCPFeaturizer = self._media_clf[1] X = featurizer.transform([query]) preds = clf.predict_labels(X)[0] + preds = {k: v for k, v in preds.items() + if OCPFeaturizer.label2media(k) in valid_labels} label = max(preds, key=preds.get) prob = float(round(preds[label], 3)) LOG.info(f"OVOSCommonPlay MediaType prediction: {label} confidence: {prob}") @@ -779,7 +792,7 @@ def classify_media(self, query: str, lang: str) -> Tuple[MediaType, float]: return OCPFeaturizer.label2media(label), prob except: LOG.exception(f"OCP classifier exception: {query}") - return self.voc_match_media(query, lang) + return self.voc_match_media(query, lang, valid_labels) def is_ocp_query(self, query: str, lang: str) -> Tuple[bool, float]: """ determine if a playback question is being asked"""