diff --git a/Packs/Base/Scripts/DBotPredictPhishingWords/DBotPredictPhishingWords.py b/Packs/Base/Scripts/DBotPredictPhishingWords/DBotPredictPhishingWords.py index 538f3c05315b..fc133c958e62 100644 --- a/Packs/Base/Scripts/DBotPredictPhishingWords/DBotPredictPhishingWords.py +++ b/Packs/Base/Scripts/DBotPredictPhishingWords/DBotPredictPhishingWords.py @@ -9,14 +9,28 @@ logging.getLogger('transformers').setLevel(logging.ERROR) FASTTEXT_MODEL_TYPE = 'FASTTEXT_MODEL_TYPE' -TORCH_TYPE = 'torch' UNKNOWN_MODEL_TYPE = 'UNKNOWN_MODEL_TYPE' +TORCH_TYPE = demisto_ml.ModelType.Torch.value +FASTTEXT_TYPE = demisto_ml.ModelType.FastText.value def OrderedSet(iterable): return list(dict.fromkeys(iterable)) +def update_model(model_data, model_type, model_name): + res = demisto.executeCommand( + 'createMLModel', + { + 'modelData': model_data, + 'modelName': model_name, + 'modelType': model_type, + } + ) + if is_error(res): + raise DemistoException(f'Unable to update model: {res}') + + def get_model_data(model_name: str, store_type: str, is_return_error: bool) -> tuple[dict, str]: def load_from_models(model_name: str) -> None | tuple[dict, str]: @@ -106,10 +120,18 @@ def preprocess_text(text, model_type, is_return_error): def predict_phishing_words(model_name, model_store_type, email_subject, email_body, min_text_length, label_threshold, word_threshold, top_word_limit, is_return_error, set_incidents_fields=False): model_data, model_type = get_model_data(model_name, model_store_type, is_return_error) - if model_type.strip() == '' or model_type.strip() == 'Phishing': - model_type = FASTTEXT_MODEL_TYPE - if model_type not in [FASTTEXT_MODEL_TYPE, TORCH_TYPE, UNKNOWN_MODEL_TYPE]: - model_type = UNKNOWN_MODEL_TYPE + + if model_type in ('Phishing', 'torch'): + model_data, model_type = demisto_ml.renew_model(model_data, model_type) + update_model(model_data, model_type, model_name) + + model_type = { + '': FASTTEXT_MODEL_TYPE, + FASTTEXT_TYPE: FASTTEXT_MODEL_TYPE, + FASTTEXT_MODEL_TYPE: FASTTEXT_MODEL_TYPE, + TORCH_TYPE: TORCH_TYPE, + UNKNOWN_MODEL_TYPE: UNKNOWN_MODEL_TYPE, + }.get(model_type.strip(), UNKNOWN_MODEL_TYPE) phishing_model = demisto_ml.phishing_model_loads_handler(model_data, model_type) @@ -171,8 +193,7 @@ def predict_single_incident_full_output(email_subject, email_body, is_return_err explain_result['Probability'] = float(explain_result["Probability"]) predicted_prob = explain_result["Probability"] if predicted_prob < label_threshold: - handle_error("Label probability is {:.2f} and it's below the input confidence threshold".format( - predicted_prob), is_return_error) + handle_error(f"Label probability is {predicted_prob:.2f} and it's below the input confidence threshold", is_return_error) positive_tokens = OrderedSet(explain_result['PositiveWords']) negative_tokens = OrderedSet(explain_result['NegativeWords']) @@ -196,8 +217,8 @@ def predict_single_incident_full_output(email_subject, email_body, is_return_err explain_result_hr = {} explain_result_hr['TextTokensHighlighted'] = highlighted_text_markdown explain_result_hr['Label'] = predicted_label - explain_result_hr['Probability'] = "%.2f" % predicted_prob - explain_result_hr['Confidence'] = "%.2f" % predicted_prob + explain_result_hr['Probability'] = f"{predicted_prob:.2f}" + explain_result_hr['Confidence'] = f"{predicted_prob:.2f}" explain_result_hr['PositiveWords'] = ", ".join([w.lower() for w in positive_words]) explain_result_hr['NegativeWords'] = ", ".join([w.lower() for w in negative_words]) incident_context = demisto.incidents()[0] @@ -257,7 +278,6 @@ def main(): demisto.args()['returnError'] == 'true', demisto.args().get('setIncidentFields', 'false') == 'true' ) - return result diff --git a/Packs/Base/Scripts/DBotTrainTextClassifierV2/DBotTrainTextClassifierV2.py b/Packs/Base/Scripts/DBotTrainTextClassifierV2/DBotTrainTextClassifierV2.py index 8ddac2f11140..88b6f0d35706 100644 --- a/Packs/Base/Scripts/DBotTrainTextClassifierV2/DBotTrainTextClassifierV2.py +++ b/Packs/Base/Scripts/DBotTrainTextClassifierV2/DBotTrainTextClassifierV2.py @@ -26,7 +26,10 @@ AUTO_TRAINING_ALGO = 'auto' # the following mapping need to correspond to predict_phishing_words func at DBotPredictPhishingWords -ALGO_TO_MODEL_TYPE = {FASTTEXT_TRAINING_ALGO: 'Phishing', FINETUNE_TRAINING_ALGO: 'torch'} +ALGO_TO_MODEL_TYPE = { + FASTTEXT_TRAINING_ALGO: demisto_ml.ModelType.FastText.value, + FINETUNE_TRAINING_ALGO: demisto_ml.ModelType.Torch.value +} FINETUNE_LABELS = ['Malicious', 'Non-Malicious'] @@ -348,7 +351,7 @@ def validate_labels_and_decide_algorithm(y, algorithm): labels_counter = Counter(y) # type: Dict[str, int] illegal_labels_for_fine_tune = [label for label in labels_counter if label not in FINETUNE_LABELS] if algorithm == FINETUNE_TRAINING_ALGO and len(illegal_labels_for_fine_tune) > 0: - error = ['When trainingAlgorithm is set to {}, all labels mus be mapped to {}.\n'.format(algorithm, + error = ['When trainingAlgorithm is set to {}, all labels must be mapped to {}.\n'.format(algorithm, ', '.join( FINETUNE_LABELS))] error += ['The following labels/verdicts need to be mapped to one of those values: '] diff --git a/Packs/ML/Scripts/DBotPredictOutOfTheBoxV2/DBotPredictOutOfTheBoxV2.py b/Packs/ML/Scripts/DBotPredictOutOfTheBoxV2/DBotPredictOutOfTheBoxV2.py index 0c526a008b14..e649ca40feaf 100644 --- a/Packs/ML/Scripts/DBotPredictOutOfTheBoxV2/DBotPredictOutOfTheBoxV2.py +++ b/Packs/ML/Scripts/DBotPredictOutOfTheBoxV2/DBotPredictOutOfTheBoxV2.py @@ -29,7 +29,7 @@ def load_oob_model(): 'modelName': OUT_OF_THE_BOX_MODEL_NAME, 'modelLabels': ['Malicious', 'Non-Malicious'], 'modelOverride': 'true', - 'modelType': 'torch', + 'modelType': demisto_ml.ModelType.Torch.value, 'modelExtraInfo': {'threshold': THRESHOLD, OOB_VERSION_INFO_KEY: SCRIPT_MODEL_VERSION }