Skip to content

Commit

Permalink
fix:improve cache (#40)
Browse files Browse the repository at this point in the history
* fix:improve cache

cache domain intents in their own folder to avoid retraining when config changes

improve error handling

* Update ovos_padatious/opm.py

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>

* fix:improve cache

cache domain intents in their own folder to avoid retraining when config changes

improve error handling

---------

Co-authored-by: coderabbitai[bot] <136622811+coderabbitai[bot]@users.noreply.github.com>
  • Loading branch information
JarbasAl and coderabbitai[bot] authored Jan 25, 2025
1 parent 97149f8 commit 4e2c2ce
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 7 deletions.
17 changes: 15 additions & 2 deletions ovos_padatious/opm.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,14 +274,27 @@ def __init__(self, bus: Optional[Union[MessageBusClient, FakeBus]] = None,
if engine_class is None and self.config.get("domain_engine"):
engine_class = DomainIntentContainer

use_stemmer = self.config.get("stem", False)
self.engine_class = engine_class or IntentContainer
intent_cache = expanduser(self.config.get('intent_cache') or
f"{xdg_data_home()}/{get_xdg_base()}/intent_cache")
if self.engine_class == DomainIntentContainer:
# allow user to switch back and forth without retraining
# cache is cheap, training isn't
intent_cache += "_domain"
if use_stemmer:
intent_cache += "_stemmer"
self.containers = {lang: self.engine_class(cache_dir=f"{intent_cache}/{lang}",
disable_padaos=self.config.get("disable_padaos", False))
for lang in langs}

if self.config.get("stem", False):
# pre-load any cached intents
for container in self.containers.values():
try:
container.instantiate_from_disk()
except Exception as e:
LOG.error(f"Failed to pre-load cached intents: {str(e)}")

if use_stemmer:
self.stemmers = {lang: Stemmer(lang)
for lang in langs if Stemmer.supports_lang(lang)}
else:
Expand Down
15 changes: 12 additions & 3 deletions ovos_padatious/training_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,10 @@ def add(self, name: str, lines: List[str], reload_cache: bool = False, must_trai
old_hsh = g.read()
min_ver = splitext(ovos_padatious.__version__)[0]
new_hsh = lines_hash([min_ver] + lines)
if not old_hsh:
LOG.debug("First time training")
elif old_hsh and old_hsh != new_hsh:
LOG.debug(f"{name} Hash changed! retraining - {old_hsh} {new_hsh}")
retrain = reload_cache or old_hsh != new_hsh
if not retrain:
try:
Expand Down Expand Up @@ -138,17 +142,22 @@ def train(self, debug: bool = True, single_thread: Union[None, bool] = None,

train = partial(_train_and_save, cache=self.cache, data=self.train_data, print_updates=debug)

objs = list(self.objects_to_train) # make a copy so its thread safe
fails = []
# Train objects sequentially
for obj in self.objects_to_train:
for obj in objs:
try:
train(obj)
except Exception as e:
LOG.error(f"Error training {obj.name}: {e}")
fails.append(obj)

# Load saved objects from disk
for obj in self.objects_to_train:
for obj in objs:
try:
self.objects.append(self.cls.from_file(name=obj.name, folder=self.cache))
except Exception as e:
LOG.error(f"Failed to load trained object {obj.name}: {e}")
self.objects_to_train = []
fails.append(obj)
self.objects_to_train = [o for o in self.objects_to_train
if o not in objs or o in fails]
2 changes: 1 addition & 1 deletion ovos_padatious/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def lines_hash(lines):
bytearray: Binary hash
"""
x = xxh32()
for i in lines:
for i in sorted(lines):
x.update(i.encode())
return x.digest()

Expand Down
2 changes: 1 addition & 1 deletion tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@


def test_lines_hash():
assert lines_hash(['word1', 'word2']) != lines_hash(['word2', 'word1'])
assert lines_hash(['word1', 'word2']) == lines_hash(['word2', 'word1'])
assert lines_hash(['word1', 'word2']) != lines_hash(['word1', 'word1'])


Expand Down

0 comments on commit 4e2c2ce

Please sign in to comment.