From c971a8bddcbdf47b218d3bc8734a3dfb7867fe1e Mon Sep 17 00:00:00 2001 From: James Brandreth Date: Tue, 28 Jan 2025 15:34:36 +0000 Subject: [PATCH] bug: set metamodel nclasses *before* building metamodel --- src/scripts/miade.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/src/scripts/miade.py b/src/scripts/miade.py index afcdd44..288b87f 100644 --- a/src/scripts/miade.py +++ b/src/scripts/miade.py @@ -18,7 +18,6 @@ from shutil import rmtree from typing import Optional, List from pydantic import BaseModel, validator -from datetime import datetime from tokenizers import ByteLevelBPETokenizer from gensim.models import Word2Vec @@ -746,13 +745,13 @@ def make(config_filepath: Path, temp_dir: Path = Path("./.temp"), output: Path = meta_model_categories.append(meta_spec.config.general.category_name) - meta_model = MiADE_MetaCAT(tokenizer=tokenizer, embeddings=embeddings, config=meta_spec.config) - annotations_path = meta_spec.annotations.get_or_download(temp_dir) if meta_spec.synthetic_data: synthetic_csv_path = meta_spec.synthetic_data.get_or_download(temp_dir) - synthetic_data = pl.read_csv(synthetic_csv_path)[["text", meta_model_name]] + synthetic_data = pl.read_csv(synthetic_csv_path, null_values=["nan"])[ + ["text", "start", "end", meta_model_name] + ].filter(pl.col(meta_model_name).is_not_null()) prepped_synthetic_data_path = temp_dir / Path(f"prepped_synthetic_data_{meta_model_name}") synthetic_data.write_csv(prepped_synthetic_data_path) synthetic_label_counts = _get_label_counts_for_synthetic_data(synthetic_data, meta_model_name) @@ -765,16 +764,19 @@ def make(config_filepath: Path, temp_dir: Path = Path("./.temp"), output: Path = synthetic_label_counts, ) + log.info(f"Number of examples: {label_counts}") + labels = list(label_counts.keys()) total_label_count = sum(label_counts.values()) - meta_model.config.general["category_value2id"] = dict(zip(labels, range(0, len(labels)))) + meta_spec.config.general["category_value2id"] = dict(zip(labels, range(0, len(labels)))) if meta_spec.balance_weights: - meta_model.config.train["class_weights"] = [ + meta_spec.config.train["class_weights"] = [ 1 - (label_counts[label] / total_label_count) for label in labels ] - meta_model.config.model.nclasses = len(labels) - meta_model.config.model["last_trained_on"] = datetime.now().strftime("%y%m%d%H%M%S") + meta_spec.config.model.nclasses = len(labels) + + meta_model = MiADE_MetaCAT(tokenizer=tokenizer, embeddings=embeddings, config=meta_spec.config) log.info( f"Starting MetaCAT training for {meta_spec.config.general['category_name']} for {meta_spec.config.train.nepochs} epoch(s) "