diff --git a/app/bot/nlu/entity_extractors/crf_entity_extractor.py b/app/bot/nlu/entity_extractors/crf_entity_extractor.py
index 7d3022c2..d5c45d56 100755
--- a/app/bot/nlu/entity_extractors/crf_entity_extractor.py
+++ b/app/bot/nlu/entity_extractors/crf_entity_extractor.py
@@ -1,6 +1,8 @@
 import pycrfsuite
 from app.config import app_config
 
+MODEL_NAME = "crf__entity_extractor.model"
+
 class CRFEntityExtractor:
     """
     Performs NER training, prediction, model import/export
@@ -10,6 +12,7 @@ def __init__(self, synonyms={}):
         import spacy
         self.tokenizer = spacy.load("en_core_web_md")
         self.synonyms = synonyms
+        self.tagger = None
 
     def replace_synonyms(self, entities):
         """
@@ -88,7 +91,7 @@ def sent_to_labels(self, sent):
         """
         return [label for token, postag, label in sent]
 
-    def train(self, train_sentences, model_name):
+    def train(self, train_sentences, model_path: str):
         """
         Train NER model for given model
         :param train_sentences:
@@ -110,9 +113,23 @@ def train(self, train_sentences, model_name):
             # include transitions that are possible, but not observed
             'feature.possible_transitions': True
         })
-        trainer.train('model_files/%s.model' % model_name)
+        trainer.train(f"{model_path}/{MODEL_NAME}")
         return True
 
+    def load(self, model_path: str) -> bool:
+        """
+        Load the CRF model that train() wrote as MODEL_NAME under model_path
+        :param model_path: Directory containing the MODEL_NAME file
+        :return: True if successful, False otherwise
+        """
+        try:
+            self.tagger = pycrfsuite.Tagger()
+            self.tagger.open(f"{model_path}/{MODEL_NAME}")
+            return True
+        except Exception as e:
+            print(f"Error loading CRF model: {e}")
+            return False
+
     def crf2json(self, tagged_sentence):
         """
         Extract label-value pair from NER prediction output
@@ -143,19 +160,16 @@ def extract_ner_labels(self, predicted_labels):
                 labels.append(tp[2:])
         return labels
 
-    def predict(self, model_name, message):
+    def predict(self, message):
         """
-        Predict NER labels for given model and query
-        :param model_name:
+        Predict NER labels for given message
         :param message:
         :return:
         """
         spacy_doc = message.get("spacy_doc")
         tagged_token = self.pos_tagger(spacy_doc)
         words = [token.text for token in spacy_doc]
-        tagger = pycrfsuite.Tagger()
-        tagger.open("{}/{}.model".format(app_config.MODELS_DIR, model_name))
-        predicted_labels = tagger.tag(self.sent_to_features(tagged_token))
+        predicted_labels = self.tagger.tag(self.sent_to_features(tagged_token))
         extracted_entities = self.crf2json(
             zip(words, predicted_labels))
         return self.replace_synonyms(extracted_entities)
diff --git a/app/bot/nlu/pipeline.py b/app/bot/nlu/pipeline.py
index f927605b..ab19c004 100644
--- a/app/bot/nlu/pipeline.py
+++ b/app/bot/nlu/pipeline.py
@@ -111,28 +111,19 @@ def __init__(self, synonyms: Optional[Dict[str, str]] = None):
         self.extractor = CRFEntityExtractor(synonyms or {})
 
     def train(self, training_data: List[Dict[str, Any]], model_path: str) -> None:
-        # Group training data by intent
-        intent_data = {}
-        for example in training_data:
-            intent = example.get("intent")
-            if intent not in intent_data:
-                intent_data[intent] = []
-            intent_data[intent].append(example)
-
-        # Train model for each intent
-        for intent_id, examples in intent_data.items():
-            ner_training_data = self.extractor.json2crf(examples)
-            self.extractor.train(ner_training_data, intent_id)
+        # Convert all training data to CRF format at once
+        ner_training_data = self.extractor.json2crf(training_data)
+        # Train a single model for all entities, saved under model_path so load() can find it
+        self.extractor.train(ner_training_data, model_path)
 
     def load(self, model_path: str) -> bool:
-        # Entity extractor loads models on demand per intent
-        return True
+        # Load the single entity model
+        return self.extractor.load(model_path)
 
     def process(self, message: Dict[str, Any]) -> Dict[str, Any]:
-        if not message.get("text") or not message.get("intent", {}).get("intent") or not message.get("spacy_doc"):
+        if not message.get("text") or not message.get("spacy_doc"):
             return message
 
-        intent_id = message["intent"]["intent"]
-        entities = self.extractor.predict(intent_id,message)
+        entities = self.extractor.predict(message)
         message["entities"] = entities
         return message
\ No newline at end of file