From b47b36c40d1cece2484b7f18c2d1cdfcdf92a47e Mon Sep 17 00:00:00 2001
From: ganweinan <361882091@qq.com>
Date: Sat, 23 Mar 2024 16:38:21 +0800
Subject: [PATCH] [20240323] Complete base modules for the contrastive & mlm
 alignment modes
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 alignment/arguments.py              | 58 +++---------------------
 alignment/data.py                   | 26 +++++--------
 alignment/{run_train.py => main.py} | 51 +++++++++++--------------
 alignment/model.py                  |  9 -----
 4 files changed, 36 insertions(+), 108 deletions(-)
 rename alignment/{run_train.py => main.py} (68%)

diff --git a/alignment/arguments.py b/alignment/arguments.py
index ab24a697..4c4afcbf 100644
--- a/alignment/arguments.py
+++ b/alignment/arguments.py
@@ -1,6 +1,6 @@
 import os
 from dataclasses import dataclass, field
-from typing import Optional, List
+from typing import Optional
 
 from transformers import TrainingArguments
 
@@ -38,34 +38,8 @@ class ModelArguments:
     projection_in_dim: int = field(default=768)
     projection_out_dim: int = field(default=1)
 
-    # p*-tuning
-    model_type: str = field(
-        default="bert",
-        metadata={
-            "help": "The type of model, where we currently support bert, roberta, deberta"
-        }
-    )
-    prefix: bool = field(
-        default=False,
-        metadata={
-            "help": "Will use P-tuning v2 during training"
-        }
-    )
-    prompt: bool = field(
-        default=False,
-        metadata={
-            "help": "Will use prompt tuning during training"
-        }
-    )
-    prompt_from_vocab: bool = field(
-        default=True,
-        metadata={
-            "help": "Will prompt embeddings initalized from plm's word embeddings"
-        }
-    )
-    prompt_encoder_type: str = field(default=None)
-    pre_seq_len: int = field(
-        default=100,
+    prefix_len: int = field(
+        default=32,
         metadata={
             "help": "The length of prompt"
         }
@@ -145,29 +119,7 @@ class DataArguments:
         },
     )
 
-    def __post_init__(self):
-        if self.dataset_name is not None:
-            info = self.dataset_name.split('/')
-            self.dataset_split = info[-1] if len(info) == 3 else 'train'
-            self.dataset_name = "/".join(info[:-1]) if len(info) == 3 else '/'.join(info)
-            self.dataset_language = 'default'
-            if ':' in self.dataset_name:
-                self.dataset_name, self.dataset_language = self.dataset_name.split(':')
-        if self.train_dir is not None:
-            files = os.listdir(self.train_dir)
-            self.train_path = [
-                os.path.join(self.train_dir, f)
-                for f in files
-                if f.endswith('tsv') or f.endswith('json')
-            ]
-
 
 @dataclass
-class DenseTrainingArguments(TrainingArguments):
-    warmup_ratio: float = field(default=0.1)
-    negatives_x_device: bool = field(default=False, metadata={"help": "share negatives across devices"})
-    do_encode: bool = field(default=False, metadata={"help": "run the encoding loop"})
-
-    grad_cache: bool = field(default=False, metadata={"help": "Use gradient cache update"})
-    gc_q_chunk_size: int = field(default=4)
-    gc_p_chunk_size: int = field(default=32)
+class AlignmentTrainingArguments(TrainingArguments):
+    alignment_mode: str = field(default="contrastive", metadata={"help": "contrastive or mlm"})
diff --git a/alignment/data.py b/alignment/data.py
index 9dbbe8ff..b3d9f457 100644
--- a/alignment/data.py
+++ b/alignment/data.py
@@ -1,26 +1,20 @@
-import os
-import random
 import json
 from dataclasses import dataclass
-from typing import Optional, Union, List, Dict, Tuple, Any
-import itertools
+from typing import Optional, List, Tuple
 
 import numpy as np
-import datasets
 import torch
 import torch.utils.data as Data
 from torch.utils.data import Dataset
-from transformers import PreTrainedTokenizer, BatchEncoding, DataCollatorWithPadding
-from transformers.tokenization_utils_base import PaddingStrategy, PreTrainedTokenizerBase
+from transformers import PreTrainedTokenizer, DataCollatorWithPadding
+from transformers.tokenization_utils_base import PreTrainedTokenizerBase
 
 import pandas as pd
 import torch
-from sklearn.metrics import mean_squared_error
-from sklearn.model_selection import train_test_split
 from sklearn.preprocessing import LabelEncoder
 
 from .arguments import DataArguments
-from ..deepctr_torch.inputs import build_input_features, get_feature_names
+from ..deepctr_torch.inputs import get_feature_names
 from ..deepctr_torch.inputs import (DenseFeat, SparseFeat, VarLenSparseFeat)
 
 import logging
@@ -101,11 +95,12 @@ def __getitem__(self, item):
 
 @dataclass
 class ContrastiveAlignmentCollator(DataCollatorWithPadding):
+    tokenizer: PreTrainedTokenizerBase
     max_len: int = 64
 
     def __call__(self, features):
-        # batch inputs for PLM\LLM
+        # batch inputs for nlp model
         text_input = [feat_map["text_model_input"] for feat_map in features]
         text_input_batch = self.tokenizer.pad(
             text_input,
@@ -124,20 +119,17 @@ def __call__(self, features):
 class MlmAlignmentCollator(DataCollatorWithPadding):
     tokenizer: PreTrainedTokenizerBase
-    padding: Union[bool, str, PaddingStrategy] = True
-    max_length: Optional[int] = 64
-    pad_to_multiple_of: Optional[int] = None
-    return_tensors: str = "pt"
+    max_len: Optional[int] = 64
     mlm_probability: float = 0.15
 
     def __call__(self, features):
-        # batch inputs for PLM\LLM
+        # batch inputs for nlp model
         text_input = [feat_map["text_model_input"] for feat_map in features]
         batch = self.tokenizer.pad(
             text_input,
             padding='max_length',
-            max_length=self.max_length,
+            max_length=self.max_len,
             return_tensors="pt",
         )
         # generate input & label for mlm train
diff --git a/alignment/run_train.py b/alignment/main.py
similarity index 68%
rename from alignment/run_train.py
rename to alignment/main.py
index 0e0e1a67..35d3710f 100644
--- a/alignment/run_train.py
+++ b/alignment/main.py
@@ -1,10 +1,8 @@
 import logging
 import os
 import sys
-import json
 
 sys.path.insert(0, '..')
-import datasets
 from transformers import AutoConfig, AutoTokenizer, AutoModel
 from transformers import (
     HfArgumentParser,
@@ -12,9 +10,9 @@
 )
 
 from .arguments import ModelArguments, DataArguments, \
-    DenseTrainingArguments as TrainingArguments
-from .data import AlignmentDataset, ContrastiveAlignmentCollator
-from .model import ContrastiveAlignmentModel
+    AlignmentTrainingArguments as TrainingArguments
+from .data import AlignmentDataset, ContrastiveAlignmentCollator, MlmAlignmentCollator
+from .model import ContrastiveAlignmentModel, MlmAlignmentModel
 from .trainer import AlignmentTrainer as Trainer
 
 from deepctr_torch.models.deepfm import DeepFM
@@ -62,21 +60,6 @@ def main():
 
     set_seed(training_args.seed)
 
-    # config = AutoConfig.from_pretrained(
-    #     model_args.config_name if model_args.config_name else model_args.model_name_or_path,
-    #     cache_dir=model_args.cache_dir,
-    #     num_labels=model_args.projection_out_dim
-    # )
-    # # p*-tuning
-    # config.fine_tuning = model_args.fine_tuning
-    # config.prefix = model_args.prefix
-    # config.prompt = model_args.prompt
-    # config.prompt_from_vocab = model_args.prompt_from_vocab
-    # config.prompt_encoder_type = model_args.prompt_encoder_type
-    # config.pre_seq_len = model_args.pre_seq_len
-    # config.prefix_projection = model_args.prefix_projection
-    # config.prefix_hidden_size = model_args.prefix_hidden_size
-    # config.hidden_dropout_prob = model_args.hidden_dropout_prob
 
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
@@ -97,20 +80,30 @@ def main():
 
     # build text model
     text_model = AutoModel.from_pretrained(model_args.model_name_or_path, add_pooling_layer=False)
-    alignment_model = ContrastiveAlignmentModel(ctr_model = ctr_model,
-                                                text_model = text_model,
-                                                model_args = model_args,
-                                                data_args = data_args,
-                                                train_args = training_args)
-
+    # build alignment train model
+    if training_args.alignment_mode == "contrastive":
+        alignment_model = ContrastiveAlignmentModel(ctr_model = ctr_model,
+                                                    text_model = text_model,
+                                                    model_args = model_args,
+                                                    data_args = data_args,
+                                                    train_args = training_args)
+        data_collator=ContrastiveAlignmentCollator(tokenizer=tokenizer)
+
+    elif training_args.alignment_mode == "mlm":
+        alignment_model = MlmAlignmentModel(ctr_model = ctr_model,
+                                            text_model = text_model,
+                                            model_args = model_args,
+                                            data_args = data_args,
+                                            train_args = training_args)
+        data_collator=MlmAlignmentCollator(tokenizer=tokenizer)
+    else:
+        raise ValueError("Alignment mode must be in [contrastive, mlm]")
 
     trainer = Trainer(
         model=alignment_model,
         args=training_args,
         train_dataset=train_dataset,
-        data_collator=ContrastiveAlignmentCollator(
-            tokenizer
-        ),
+        data_collator=data_collator,
     )
 
     trainer.train()
diff --git a/alignment/model.py b/alignment/model.py
index 01953d04..62dc425d 100644
--- a/alignment/model.py
+++ b/alignment/model.py
@@ -154,21 +154,12 @@ def __init__(
         self.train_args = train_args
         self.data_args = data_args
 
-        # determine whether dimension alignment is needed
         self.text_model_config = AutoConfig.from_pretrained(
             self.model_args.model_name_or_path,
             cache_dir=self.model_args.cache_dir,
             revision=self.model_args.model_revision,
             use_auth_token=True if self.model_args.use_auth_token else None,
         )
-        # if self.model_args.ctr_hidden_dim == text_model_config.hidden_size:
-        #     logger.info("CTR hidden size equal to Text model hidden size")
-        #     self.add_pooler = False
-        # else:
-        #     logger.warning("CTR hidden size not equal to Text model hidden size, add pooler layer")
-        #     self.add_pooler = True
-        #     self.pooler = LinearPooler(input_dim=text_model_config.hidden_size,
-        #                                output_dim=self.model_args.ctr_hidden_dim)
 
         self.prompt_layers = nn.Sequential(
             nn.Linear(self.model_args.ctr_hidden_dim, self.text_model_config.hidden_size),
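
Note on the MLM collator: the data.py hunk ends at the comment "# generate
input & label for mlm train", so the masking step itself is not visible in
this patch. For reference, the sketch below shows the standard BERT-style
80/10/10 masking recipe (mirroring Hugging Face's
DataCollatorForLanguageModeling.torch_mask_tokens), which is what a collator
with an mlm_probability field typically applies. The helper name mask_tokens
is illustrative only and is not taken from this repository; it would
presumably be applied to batch["input_ids"] after tokenizer.pad.

    import torch

    def mask_tokens(inputs, tokenizer, mlm_probability=0.15):
        """Return (masked_inputs, labels) for MLM training."""
        labels = inputs.clone()

        # Sample positions to mask, never touching special tokens
        # ([CLS], [SEP], [PAD], ...).
        probability_matrix = torch.full(labels.shape, mlm_probability)
        special_tokens_mask = torch.tensor(
            [
                tokenizer.get_special_tokens_mask(seq, already_has_special_tokens=True)
                for seq in labels.tolist()
            ],
            dtype=torch.bool,
        )
        probability_matrix.masked_fill_(special_tokens_mask, value=0.0)
        masked_indices = torch.bernoulli(probability_matrix).bool()

        # Loss is computed only on masked positions.
        labels[~masked_indices] = -100

        # 80% of masked positions: replace with [MASK].
        indices_replaced = (
            torch.bernoulli(torch.full(labels.shape, 0.8)).bool() & masked_indices
        )
        inputs[indices_replaced] = tokenizer.convert_tokens_to_ids(tokenizer.mask_token)

        # 10% of masked positions: replace with a random token.
        indices_random = (
            torch.bernoulli(torch.full(labels.shape, 0.5)).bool()
            & masked_indices
            & ~indices_replaced
        )
        random_words = torch.randint(len(tokenizer), labels.shape, dtype=torch.long)
        inputs[indices_random] = random_words[indices_random]

        # Remaining 10%: keep the original token (inputs unchanged there).
        return inputs, labels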
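Since DenseTrainingArguments is replaced by AlignmentTrainingArguments, the
alignment mode is now chosen from the command line. A minimal parsing sketch
(main() actually parses ModelArguments and DataArguments as well; the
output_dir value here is a placeholder):

    from transformers import HfArgumentParser
    from alignment.arguments import AlignmentTrainingArguments

    parser = HfArgumentParser(AlignmentTrainingArguments)
    # Equivalent to: python -m alignment.main --output_dir ./out --alignment_mode mlm
    (training_args,) = parser.parse_args_into_dataclasses(
        args=["--output_dir", "./out", "--alignment_mode", "mlm"]
    )
    # "mlm" selects MlmAlignmentModel + MlmAlignmentCollator in main();
    # "contrastive" (the default) selects the contrastive pair.
    assert training_args.alignment_mode == "mlm"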