Skip to content

Commit

Permalink
bug: set metamodel nclasses *before* building metamodel
Browse files: browse the repository at this point in the history
  • Loading branch information
jamesbrandreth committed Jan 28, 2025
1 parent 56270cf commit c971a8b
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions src/scripts/miade.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -18,7 +18,6 @@
from shutil import rmtree
from typing import Optional, List
from pydantic import BaseModel, validator
from datetime import datetime

from tokenizers import ByteLevelBPETokenizer
from gensim.models import Word2Vec
Expand Down Expand Up @@ -746,13 +745,13 @@ def make(config_filepath: Path, temp_dir: Path = Path("./.temp"), output: Path =

meta_model_categories.append(meta_spec.config.general.category_name)

meta_model = MiADE_MetaCAT(tokenizer=tokenizer, embeddings=embeddings, config=meta_spec.config)

annotations_path = meta_spec.annotations.get_or_download(temp_dir)

if meta_spec.synthetic_data:
synthetic_csv_path = meta_spec.synthetic_data.get_or_download(temp_dir)
synthetic_data = pl.read_csv(synthetic_csv_path)[["text", meta_model_name]]
synthetic_data = pl.read_csv(synthetic_csv_path, null_values=["nan"])[
["text", "start", "end", meta_model_name]
].filter(pl.col(meta_model_name).is_not_null())
prepped_synthetic_data_path = temp_dir / Path(f"prepped_synthetic_data_{meta_model_name}")
synthetic_data.write_csv(prepped_synthetic_data_path)
synthetic_label_counts = _get_label_counts_for_synthetic_data(synthetic_data, meta_model_name)
Expand All @@ -765,16 +764,19 @@ def make(config_filepath: Path, temp_dir: Path = Path("./.temp"), output: Path =
synthetic_label_counts,
)

log.info(f"Number of examples: {label_counts}")

labels = list(label_counts.keys())
total_label_count = sum(label_counts.values())

meta_model.config.general["category_value2id"] = dict(zip(labels, range(0, len(labels))))
meta_spec.config.general["category_value2id"] = dict(zip(labels, range(0, len(labels))))
if meta_spec.balance_weights:
meta_model.config.train["class_weights"] = [
meta_spec.config.train["class_weights"] = [
1 - (label_counts[label] / total_label_count) for label in labels
]
meta_model.config.model.nclasses = len(labels)
meta_model.config.model["last_trained_on"] = datetime.now().strftime("%y%m%d%H%M%S")
meta_spec.config.model.nclasses = len(labels)

meta_model = MiADE_MetaCAT(tokenizer=tokenizer, embeddings=embeddings, config=meta_spec.config)

log.info(
f"Starting MetaCAT training for {meta_spec.config.general['category_name']} for {meta_spec.config.train.nepochs} epoch(s) "
Expand Down

0 comments on commit c971a8b

Please sign in to comment.