diff --git a/llmfoundry/command_utils/__init__.py b/llmfoundry/command_utils/__init__.py
index 0226c4f408..756e611a88 100644
--- a/llmfoundry/command_utils/__init__.py
+++ b/llmfoundry/command_utils/__init__.py
@@ -8,6 +8,9 @@
     convert_dataset_json,
     convert_dataset_json_from_args,
 )
+from llmfoundry.command_utils.data_prep.convert_delta_to_contrastive_mds import (
+    convert_delta_to_contrastive_mds,
+)
 from llmfoundry.command_utils.data_prep.convert_delta_to_json import (
     convert_delta_to_json_from_args,
     fetch_DT,
@@ -44,6 +47,7 @@
     'convert_dataset_hf_from_args',
     'convert_dataset_json',
     'convert_dataset_json_from_args',
+    'convert_delta_to_contrastive_mds',
     'convert_finetuning_dataset_from_args',
     'convert_finetuning_dataset',
     'convert_text_to_mds',
diff --git a/llmfoundry/command_utils/data_prep/convert_delta_to_contrastive_mds.py b/llmfoundry/command_utils/data_prep/convert_delta_to_contrastive_mds.py
new file mode 100644
index 0000000000..30bce3ddb7
--- /dev/null
+++ b/llmfoundry/command_utils/data_prep/convert_delta_to_contrastive_mds.py
@@ -0,0 +1,162 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+import json
+import logging
+import os
+import tempfile
+from typing import TYPE_CHECKING, Optional
+
+from streaming import MDSWriter
+
+from llmfoundry.command_utils.data_prep.convert_delta_to_json import (
+    _check_imports,
+    fetch_DT,
+    run_query,
+    validate_and_get_cluster_info,
+)
+
+if TYPE_CHECKING:
+    from databricks.sql.client import Cursor
+    from pyspark.sql import SparkSession
+
+logger = logging.getLogger(__name__)
+
+
+def validate_columns_in_table(
+    required_columns: list,
+    optional_columns: list,
+    table_name: str,
+    method: str,
+    cursor: Optional['Cursor'] = None,
+    spark: Optional['SparkSession'] = None,
+) -> bool:
+    """Validate that required and optional columns exist in the Delta table."""
+    try:
+        result = run_query(
+            f'SHOW COLUMNS IN {table_name}',
+            method,
+            cursor,
+            spark,
+        )
+
+        # Get the actual column names
+        assert result
+        actual_columns = [row.asDict()['col_name'] for row in result]
+
+        missing_required = set(required_columns) - set(actual_columns)
+        allowed_columns = set(required_columns + optional_columns)
+        extra_columns = set(actual_columns) - allowed_columns
+
+        if missing_required:
+            logger.error(f'Missing required columns: {missing_required}')
+            return False
+        if extra_columns:
+            logger.warning(f'Extra columns found: {extra_columns}')
+            return False
+
+        logger.info(
+            f'Table {table_name} contains the required and optional columns.',
+        )
+        return True
+    except Exception as e:
+        logger.error(f'Error validating columns in table {table_name}: {e}')
+        return False
+
+
+def convert_delta_to_contrastive_mds(
+    delta_table_name: str,
+    http_path: Optional[str],
+    cluster_id: Optional[str],
+    use_serverless: bool,
+    output_path: str,
+    batch_size: int,
+    processes: int,
+):
+    _check_imports()
+    from databricks.sdk import WorkspaceClient
+    w = WorkspaceClient()
+    DATABRICKS_HOST = w.config.host
+    DATABRICKS_TOKEN = w.config.token
+
+    logger.info(
+        f'Validating columns in table {delta_table_name} and cluster info...',
+    )
+    dtypes = {
+        'query_text': 'str',
+        'positive_passage': 'str',
+        'negative_passages': 'str',
+    }
+    required_columns = ['query_text', 'positive_passage']
+    optional_columns = ['negative_passages']
+    method, dbsql, sparkSession = validate_and_get_cluster_info(
+        cluster_id=cluster_id,
+        databricks_host=DATABRICKS_HOST,
+        databricks_token=DATABRICKS_TOKEN,
+        http_path=http_path,
+        use_serverless=use_serverless,
+    )
+    logger.info('Validated cluster info')
+    if not validate_columns_in_table(
+        required_columns=required_columns,
+        optional_columns=optional_columns,
+        table_name=delta_table_name,
+        method=method,
+        cursor=dbsql.cursor() if dbsql else None,
+        spark=sparkSession,
+    ):
+        logger.error('Column validation failed. Exiting.')
+        raise ValueError('Column validation failed.')
+    logger.info(f'Validated columns in table {delta_table_name}')
+
+    compression = 'zstd:7'
+    hashes = ['sha1']
+    limit = '10mb'
+
+    def convert_x(x: dict) -> dict:
+        return {
+            'query_text': x['query_text'],
+            'positive_passage': x['positive_passage'],
+            'negative_passages':
+                json.dumps(x['negative_passages'])
+                if 'negative_passages' in x else '[]',
+        }
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        logger.info(f'Created temporary directory at {temp_dir}')
+        json_output_path = os.path.join(temp_dir, 'output.jsonl')
+        try:
+            fetch_DT(
+                delta_table_name=delta_table_name,
+                json_output_folder=temp_dir,
+                http_path=http_path,
+                cluster_id=cluster_id,
+                use_serverless=use_serverless,
+                json_output_filename='output.jsonl',
+                batch_size=batch_size,
+                processes=processes,
+                DATABRICKS_HOST=DATABRICKS_HOST,
+                DATABRICKS_TOKEN=DATABRICKS_TOKEN,
+            )
+        except Exception as e:
+            logger.error(f'Error fetching data: {e}')
+            raise e
+        with MDSWriter(
+            out=output_path,
+            columns=dtypes,
+            compression=compression,
+            hashes=hashes,
+            size_limit=limit,
+        ) as out:
+            try:
+                with open(json_output_path, 'r') as f:
+                    for line in f:
+                        out.write(convert_x(json.loads(line)))
+            except FileNotFoundError as e:
+                logger.error(f'JSON output file not found: {e}')
+                raise e
+
+    logger.info(f'Wrote to MDS at {output_path}')
diff --git a/llmfoundry/data/__init__.py b/llmfoundry/data/__init__.py
index 5710be0c55..3511bb39b1 100644
--- a/llmfoundry/data/__init__.py
+++ b/llmfoundry/data/__init__.py
@@ -1,6 +1,7 @@
 # Copyright 2022 MosaicML LLM Foundry authors
 # SPDX-License-Identifier: Apache-2.0
 
+from llmfoundry.data.contrastive_pairs.dataloader import build_pairs_dataloader
 from llmfoundry.data.data import (
     SUPPORTED_MDS_ENCODING_TYPES,
     ConcatTokensDataset,
@@ -38,6 +39,7 @@
 dataloaders.register('text', func=build_text_dataloader)
 dataloaders.register('finetuning', func=build_finetuning_dataloader)
+dataloaders.register('contrastive_pairs', func=build_pairs_dataloader)
 
 dataset_replication_validators.register(
     'dataset_replication_validator',
diff --git a/llmfoundry/data/contrastive_pairs/__init__.py b/llmfoundry/data/contrastive_pairs/__init__.py
new file mode 100644
index 0000000000..24ce1cd652
--- /dev/null
+++ b/llmfoundry/data/contrastive_pairs/__init__.py
@@ -0,0 +1,12 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+from llmfoundry.data.contrastive_pairs.dataloader import (
+    StreamingPairsDataset,
+    build_pairs_dataloader,
+)
+
+__all__ = [
+    'StreamingPairsDataset',
+    'build_pairs_dataloader',
+]
diff --git a/llmfoundry/data/contrastive_pairs/dataloader.py b/llmfoundry/data/contrastive_pairs/dataloader.py
new file mode 100644
index 0000000000..d9760aa926
--- /dev/null
+++ b/llmfoundry/data/contrastive_pairs/dataloader.py
@@ -0,0 +1,352 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""Dataset and dataloader for contrastive training.
+
+Builds a StreamingPairsDataset and a dataloader for contrastive training.
+"""
+
+import json
+from itertools import islice
+from typing import Any, Literal, Mapping, Optional, Union
+
+import numpy as np
+import torch
+from composer.core import DataSpec
+from streaming import Stream, StreamingDataset
+from torch.utils.data import DataLoader
+from transformers import PreTrainedTokenizerBase
+
+from llmfoundry import registry
+from llmfoundry.data.text_data import (
+    ConcatenatedSequenceCollatorWrapper,
+    StreamingTextDataset,
+)
+from llmfoundry.utils.registry_utils import construct_from_registry
+
+ContrastiveSampleType = Literal['one_query_multiple_responses',
+                                'one_query_one_response']
+
+
+def _get_contrastive_sample_type(
+    sample: Mapping[str, Any],
+) -> ContrastiveSampleType:
+    """Get the type of contrastive sample from the sample.
+
+    Args:
+        sample (Mapping): A sample from the dataset.
+
+    Returns:
+        ContrastiveSampleType: The type of contrastive sample.
+    """
+    sample_contains_text_a = any(
+        key.startswith('text_a') for key in sample.keys()
+    )
+    sample_contains_text_b = any(
+        key.startswith('text_b') for key in sample.keys()
+    )
+
+    if sample_contains_text_a and sample_contains_text_b:
+        return 'one_query_one_response'
+    elif 'query_text' in sample and 'positive_passage' in sample and 'negative_passages' in sample:
+        return 'one_query_multiple_responses'
+    else:
+        raise ValueError(
+            'Sample does not contain the required keys for contrastive training. \
+            For datasets with one query and one response, the keys must contain \
+            "text_a" and "text_b". For datasets with one query and multiple responses, \
+            the keys must contain "query_text", "positive_passage", and "negative_passages".',
+        )
+
+
+class StreamingPairsDataset(StreamingTextDataset):
+    """Contrastive pairs dataset using MosaicML's StreamingTextDataset.
+
+    Args:
+        max_hard_negatives (int, optional): The maximum number of hard negatives to include in the
+            contrastive training samples. Defaults to ``None``. If ``None``, all hard negatives are
+            included.
+        prepend_query (str, optional): Text to prepend to the query text. Defaults to ``''``.
+        prepend_passage (str, optional): Text to prepend to the passage text. Defaults to ``''``.
+        append_eos_token (bool, optional): Whether to append the EOS token to the query and passage
+            text. Defaults to ``False``. Mutually exclusive with ``append_token``.
+        append_token (str, optional): Token to append to the query and passage text. Defaults to
+            ``''``. Mutually exclusive with ``append_eos_token``.
+        shuffle_hard_negatives (bool, optional): Whether to shuffle the hard negatives. Defaults to
+            ``False``.
+        **kwargs: Additional keyword arguments to pass to the superclass. See ``StreamingTextDataset``
+            for more information.
+    """
+
+    def __init__(
+        self,
+        max_hard_negatives: Optional[int] = None,
+        prepend_query: str = '',
+        prepend_passage: str = '',
+        append_eos_token: bool = False,
+        append_token: str = '',
+        shuffle_hard_negatives: bool = False,
+        **kwargs: Any,
+    ):
+        super().__init__(**kwargs)
+
+        self.max_hard_negatives = max_hard_negatives
+        self.prepend_query = prepend_query
+        self.prepend_passage = prepend_passage
+        self.shuffle_hard_negatives = shuffle_hard_negatives
+        self._generator = np.random.default_rng(seed=self.shuffle_seed)
+        if append_eos_token:
+            if append_token != '':
+                raise ValueError(
+                    'The arguments append_eos_token and append_token are mutually exclusive.',
+                )
+            self.append_token = self.tokenizer.eos_token
+        else:
+            self.append_token = append_token
+
+    def _get_contrastive_samples(
+        self,
+        query_text: str,
+        positive_response: str,
+        negative_responses: list[str],
+    ) -> dict[str, Union[str, list[str]]]:
+        """Flatten contrastive samples into a list of strings.
+
+        Args:
+            query_text (str): The query text.
+            positive_response (str): The positive response.
+            negative_responses (List[str]): The negative responses.
+
+        Returns:
+            Dict[str, Union[str, List[str]]]: The contrastive samples, with keys 'query', 'positive', and 'negative'.
+        """
+        query_text = f'{self.prepend_query}{query_text}{self.append_token}'
+        positive_response = f'{self.prepend_passage}{positive_response}{self.append_token}'
+        if self.shuffle_hard_negatives:
+            self._generator.shuffle(negative_responses)
+        negative_responses = negative_responses[:self.max_hard_negatives]
+        negative_responses = [
+            f'{self.prepend_passage}{response}{self.append_token}'
+            for response in negative_responses
+        ]
+        return {
+            'query': query_text,
+            'positive': positive_response,
+            'negative': negative_responses,
+        }
+
+    def __getitem__(self, idx: int) -> dict[str, list[int]]:
+        sample = StreamingDataset.__getitem__(self, idx)
+
+        sample_type = _get_contrastive_sample_type(sample)
+        if sample_type == 'one_query_one_response':
+            text_samples = self._get_contrastive_samples(
+                sample['text_a'],
+                sample['text_b'],
+                [],
+            )
+        elif sample_type == 'one_query_multiple_responses':
+            negative_passages_str = sample['negative_passages']
+            text_samples = self._get_contrastive_samples(
+                sample['query_text'],
+                sample['positive_passage'],
+                json.loads(negative_passages_str),
+            )
+        else:
+            raise ValueError(f'Unknown sample type: {sample_type}')
+
+        token_samples = self._tokenize(text_samples)
+        return token_samples
+
+    def _tokenize(
+        self,
+        text_samples: dict[str, Union[str, list[str]]],
+    ) -> dict[str, list[int]]:
+        if self.tokenizer.pad_token is None:
+            raise RuntimeError(
+                'If tokenizing on-the-fly, tokenizer must have a pad_token_id',
+            )
+
+        text_samples_list = [text_samples['query'], text_samples['positive']]
+        text_samples_negatives = text_samples['negative']
+        assert isinstance(text_samples_negatives, list)  # pyright type check
+        text_samples_list.extend(text_samples_negatives)
+        return self.tokenizer(
+            text_samples_list,
+            truncation=True,
+            padding='max_length',
+            max_length=self.max_seq_len,
+        )
+
+
+def build_pairs_dataloader(
+    dataset: dict[str, Any],
+    tokenizer: PreTrainedTokenizerBase,
+    device_batch_size: int,
+    drop_last: bool,
+    num_workers: int,
+    pin_memory: bool = True,
+    prefetch_factor: int = 2,
+    persistent_workers: bool = True,
+    timeout: int = 0,
+    max_hard_negatives: Optional[int] = None,
+) -> DataSpec:
+    dataset_cfg = dataset
+    streams_dict = dataset.pop('streams', None)
+    eos_token_id = dataset.pop('eos_token_id', None)
+    bos_token_id = dataset.pop('bos_token_id', None)
+
+    streams = None
+    if streams_dict is not None:
+        streams = []
+        for stream in streams_dict.values():
+            # stream is the streams kwargs
+            # fwd all kwargs with **stream allows streaming to check args
+            streams.append(Stream(**stream))
+
+    pairs_dataset = StreamingPairsDataset(
+        tokenizer=tokenizer,
+        streams=streams,
+        batch_size=device_batch_size,
+        max_hard_negatives=max_hard_negatives,
+        **dataset,
+    )
+
+    dataloader_cfg = {
+        'name': 'contrastive_pairs',
+        'dataset': dataset_cfg,
+        'drop_last': drop_last,
+        'num_workers': num_workers,
+        'pin_memory': pin_memory,
+        'prefetch_factor': prefetch_factor,
+        'persistent_workers': persistent_workers,
+        'timeout': timeout,
+    }
+
+    collate_fn, _ = construct_from_registry(
+        name='text_collator',
+        registry=registry.collators,
+        partial_function=False,
+        kwargs={
+            'dataloader_cfg': dataloader_cfg,
+            'tokenizer': tokenizer,
+            'dataset_batch_size': device_batch_size,
+        },
+    )
+
+    if (eos_token_id is not None) or (bos_token_id is not None):
+        # Note: Will raise an error if both are non-None
+        collate_fn = ConcatenatedSequenceCollatorWrapper(
+            base_collator=collate_fn,
+            eos_token_id=eos_token_id,
+            bos_token_id=bos_token_id,
+        )
+
+    def collate_fn_without_labels(batch: list[Any]) -> dict[str, torch.Tensor]:
+        # Contrastive learning does not require labels, and some embedding
+        # models even error out if they are present
+        processed_batch: dict[str, torch.Tensor] = collate_fn(batch)
+        if 'labels' in processed_batch:
+            del processed_batch['labels']
+        return processed_batch
+
+    dl = DataLoader(
+        pairs_dataset,
+        collate_fn=collate_fn_without_labels,
+        batch_size=device_batch_size,
+        drop_last=drop_last,
+        num_workers=num_workers,
+        pin_memory=pin_memory,
+        prefetch_factor=prefetch_factor,
+        persistent_workers=persistent_workers,
+        timeout=timeout,
+    )
+
+    return DataSpec(dataloader=dl)
+
+
+# Helpful to test if your dataloader is working locally
+# Run `python dataloader.py --local_path [local] [--remote_path remote, optional]`
+# and verify that batches are printed out
+if __name__ == '__main__':
+    import argparse
+
+    from llmfoundry.utils.builders import build_tokenizer
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        '--tokenizer',
+        type=str,
+        default='EleutherAI/gpt-neox-20b',
+        help='the name of the tokenizer to use',
+    )
+    parser.add_argument(
+        '--local_path',
+        type=str,
+        required=True,
+        help='the path to the local copy of the dataset',
+    )
+    parser.add_argument(
+        '--remote_path',
+        type=str,
+        default=None,
+        help='the path to the remote copy to stream from (optional)',
+    )
+    parser.add_argument(
+        '--split',
+        type=str,
+        default='train',
+        help='which split of the dataset to use',
+    )
+    parser.add_argument(
+        '--max_seq_len',
+        type=int,
+        default=32,
+        help='max sequence length to test',
+    )
+
+    args = parser.parse_args()
+
+    if args.remote_path is not None:
+        print(
+            f'Reading {args.split} split from {args.local_path} <- streamed from <- {args.remote_path}',
+        )
+    else:
+        print(f'Reading {args.split} split from {args.local_path}')
+
+    cfg = {
+        'dataset': {
+            'local': args.local_path,
+            'remote': args.remote_path,
+            'split': args.split,
+            'shuffle': False,
+            'max_seq_len': args.max_seq_len,
+            'keep_zip': True,  # in case we need compressed files after testing
+        },
+        'drop_last': False,
+        'num_workers': 4,
+    }
+    device_batch_size = 2
+
+    tokenizer_name = args.tokenizer
+    tokenizer_kwargs = {'model_max_length': args.max_seq_len}
+    tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
+
+    loader = build_pairs_dataloader(
+        **cfg,
+        tokenizer=tokenizer,
+        device_batch_size=device_batch_size,
+    ).dataloader
+    assert isinstance(loader, DataLoader)
+    assert isinstance(loader.dataset, StreamingPairsDataset)
+    tokenizer = loader.dataset.tokenizer
+
+    for batch_ix, batch in enumerate(islice(loader, 5)):
+        print('\n')
+        print('#' * 20, f'Batch {batch_ix}', '#' * 20)
+        for k, v in batch.items():
+            print(k, v.shape, v.dtype)
+        for sample_ix, token_sample in enumerate(batch['input_ids']):
+            print('-' * 20, f' Sample {sample_ix} ', '-' * 20)
+            print(tokenizer.decode(token_sample))
diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py
index 37d4c32b23..5689adbb20 100644
--- a/llmfoundry/data/text_data.py
+++ b/llmfoundry/data/text_data.py
@@ -23,7 +23,7 @@
 from transformers import PreTrainedTokenizerBase
 
 from llmfoundry import registry
-from llmfoundry.data import (
+from llmfoundry.data.data import (
     SUPPORTED_MDS_ENCODING_TYPES,
     stream_remote_local_validate,
 )
diff --git a/llmfoundry/models/__init__.py b/llmfoundry/models/__init__.py
index 827fe2ce56..d6af82f616 100644
--- a/llmfoundry/models/__init__.py
+++ b/llmfoundry/models/__init__.py
@@ -8,6 +8,7 @@
     OpenAICausalLMEvalWrapper,
     OpenAIChatAPIEvalWrapper,
 )
+from llmfoundry.models.llm_embed import ContrastiveModel, FinetuneEmbeddingModel
 from llmfoundry.models.mpt import (
     ComposerMPTCausalLM,
     MPTConfig,
@@ -24,6 +25,8 @@
 models.register('fmapi_causal_lm', func=FMAPICasualLMEvalWrapper)
 models.register('openai_chat', func=OpenAIChatAPIEvalWrapper)
 models.register('fmapi_chat', func=FMAPIChatAPIEvalWrapper)
+models.register('finetune_embedding_model', func=FinetuneEmbeddingModel)
+models.register('contrastive_lm', func=ContrastiveModel)
 
 __all__ = [
     'ComposerHFCausalLM',
diff --git a/llmfoundry/models/llm_embed/README.md b/llmfoundry/models/llm_embed/README.md
new file mode 100644
index 0000000000..20e5d358c1
--- /dev/null
+++ b/llmfoundry/models/llm_embed/README.md
@@ -0,0 +1,73 @@
+# Embedding models
+
+_Q: What is a contrastive loss?_
+
+The contrastive loss can be thought of as a loss that creates a high similarity score between two similar samples, and a low score between two very different samples. Formally, this can be achieved by using the cross-entropy loss for an N-way softmax classifier. Some motivation for the contrastive loss can be found in [this blogpost](https://ankeshanand.com/blog/2020/01/26/contrative-self-supervised-learning.html). This has become the dominant method for training embedding/retrieval models.
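+As a rough sketch (illustrative only, not the exact implementation in this repo), the N-way softmax view of the loss with in-batch negatives looks like:
+
+```python
+import torch
+import torch.nn.functional as F
+
+
+def info_nce(q: torch.Tensor, p: torch.Tensor, temperature: float = 1.0):
+    """InfoNCE over in-batch negatives: the i-th passage is the positive for the i-th query."""
+    q = F.normalize(q, dim=-1)
+    p = F.normalize(p, dim=-1)
+    scores = q @ p.t() / temperature  # (batch, batch) similarity matrix
+    labels = torch.arange(q.size(0))  # diagonal entries are the positives
+    return F.cross_entropy(scores, labels)
+
+
+loss = info_nce(torch.randn(8, 16), torch.randn(8, 16))
+```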
+
+_Q: How does the data need to be formatted?_
+
+The data simply needs to come in "pairs." This can either be in the form of:
+1. positive pairs such as "query: is snowboarding allowed at the alta ski resort?" "passage: the alta ski resort does not allow snowboarding," or
+2. tuples of positive pairs with curated hard negative pairs.
+
+In the first case, the InfoNCE loss treats the rest of the samples in the batch as "soft negatives" (this scenario is most successful with _very_ large global batch sizes, on the order of 16-32k). In the second scenario, the InfoNCE loss uses the hard negatives in the denominator of the loss and much smaller global batch sizes (e.g. 32).
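+Concretely, the dataloader added in this PR reads MDS samples in one of the two layouts below (hard negatives are stored as a JSON-encoded list of strings). An illustrative sketch, reusing the example above:
+
+```python
+# Layout 1: a positive pair
+pair_sample = {
+    'text_a': 'is snowboarding allowed at the alta ski resort?',
+    'text_b': 'the alta ski resort does not allow snowboarding',
+}
+
+# Layout 2: a query with one positive and JSON-encoded hard negatives
+triplet_sample = {
+    'query_text': 'is snowboarding allowed at the alta ski resort?',
+    'positive_passage': 'the alta ski resort does not allow snowboarding',
+    'negative_passages': '["alta offers world-class skiing", "snowboard rentals nearby"]',
+}
+```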
+
+_Q: How do you get a vector embedding out of a decoder? I thought you could only do that with encoders?_
+
+Before the final "logit" layer of the decoder, the tensor still has dimensions _batch size_ x _sequence length_ x _hidden dimension_. In order to get a single vector representation for a single sample, we can average along the sequence length dimension (i.e. average the vectors representing each token). Alternatively, you can append an `<|endoftext|>` token to the end of each sample and extract the vector for this token alone (this seems to work well for RepLlama).
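+A minimal sketch of the averaging strategy (mirroring the `avg` pooling used in `modeling_llm_embed.py`), where `mask` marks non-padding tokens:
+
+```python
+import torch
+
+
+def mean_pool(hidden: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
+    """Average token vectors, ignoring padding positions."""
+    hidden = hidden.masked_fill(~mask.bool().unsqueeze(-1), 0.0)
+    return hidden.sum(dim=1) / mask.sum(dim=1, keepdim=True)
+
+
+# (batch, seq_len, hidden_dim) -> (batch, hidden_dim)
+emb = mean_pool(torch.randn(2, 8, 16), torch.ones(2, 8))
+```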
+
+The main additions are as follows:
+
+* The class `ContrastiveModel(HuggingFaceModel)`, which implements the InfoNCE loss in the `.loss()` function.
+* A dataloader for contrastive pairs, `build_pairs_dataloader()`. This can handle positive pairs formatted as `text_a` and `text_b`, or positive pairs with hard negatives formatted as `query_text`, `positive_passage`, and a JSON-encoded list of `negative_passages`.
+
+## Example YAML
+
+```yaml
+variables:
+  data_local:
+  data_remote: # If blank, files must be present in data_local
+  max_seq_len: 2048
+  global_seed: 17
+
+  # Run Name
+  run_name: # If left blank, will be read from env var $RUN_NAME
+
+max_seq_len: ${variables.max_seq_len}
+run_name: ${variables.run_name}
+
+# Model
+model:
+  name: contrastive_lm
+  init_device: meta
+  d_model: 768
+  n_heads: 12
+  n_layers: 12
+  expansion_ratio: 4
+  max_seq_len: ${variables.max_seq_len}
+  vocab_size: 50368
+  attn_config:
+    attn_impl: flash
+
+# Tokenizer
+tokenizer:
+  name: EleutherAI/gpt-neox-20b
+  kwargs:
+    model_max_length: ${variables.max_seq_len}
+
+# Dataloaders
+train_loader:
+  name: contrastive_pairs
+  dataset:
+    local: ${variables.data_local}
+    split: null
+    remote: ${variables.data_remote}
+    shuffle: true
+    max_seq_len: ${variables.max_seq_len}
+    shuffle_seed: ${variables.global_seed}
+    prepend_query: 'query: '
+    prepend_passage: 'passage: '
+    append_eos_token: true
+  drop_last: true
+  num_workers: 8
+```
diff --git a/llmfoundry/models/llm_embed/__init__.py b/llmfoundry/models/llm_embed/__init__.py
new file mode 100644
index 0000000000..57bce28ac2
--- /dev/null
+++ b/llmfoundry/models/llm_embed/__init__.py
@@ -0,0 +1,15 @@
+# Copyright 2022 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+from llmfoundry.models.llm_embed.finetune_embedding_model import \
+    FinetuneEmbeddingModel
+from llmfoundry.models.llm_embed.modeling_llm_embed import (
+    ContrastiveEvalLoss,
+    ContrastiveModel,
+)
+
+__all__ = [
+    'ContrastiveModel',
+    'ContrastiveEvalLoss',
+    'FinetuneEmbeddingModel',
+]
diff --git a/llmfoundry/models/llm_embed/finetune_embedding_model.py b/llmfoundry/models/llm_embed/finetune_embedding_model.py
new file mode 100644
index 0000000000..6b389edf6a
--- /dev/null
+++ b/llmfoundry/models/llm_embed/finetune_embedding_model.py
@@ -0,0 +1,99 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Mapping
+
+import torch
+from composer.utils import dist
+from transformers import AutoModel
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+from llmfoundry.models.llm_embed.modeling_llm_embed import ContrastiveModel
+
+
+class FinetuneEmbeddingModel(ContrastiveModel):
+
+    def construct_model(self) -> torch.nn.Module:
+        # Define the model construction specific to FinetuneEmbeddingModel
+        model = None
+
+        def load_model():
+            return AutoModel.from_pretrained(
+                self.pretrained_model_name_or_path,
+                trust_remote_code=self.trust_remote_code,
+                use_auth_token=self.use_auth_token,
+                **self.kwargs,
+            )
+
+        # Rank 0 downloads the weights first; the other ranks load from the
+        # cache after the barrier
+        if dist.get_global_rank() == 0:
+            model = load_model()
+        dist.barrier()
+        if model is None:
+            model = load_model()
+
+        assert model, 'Model is not loaded properly'
+        return model
+
+    def get_hidden_state(self, outputs: CausalLMOutputWithPast) -> torch.Tensor:
+        """Override to return the last hidden state."""
+        return outputs.last_hidden_state
+
+    def handle_language_head(
+        self,
+        outputs: CausalLMOutputWithPast,
+    ) -> torch.Tensor:
+        """Override to skip language head handling."""
+        return torch.tensor(
+            0,
+            dtype=torch.float32,
+            device=outputs.last_hidden_state.device,
+        )
+
+    def flops_per_batch(self, batch: Mapping) -> int:
+        # Get the batch size and maximum sequence length
+        bs, msl = batch['input_ids'].shape[0:2]
+
+        model_dimension = self._get_attribute(
+            self.model.config,
+            [
+                'hidden_size',
+                'd_model',
+                'n_embd',
+                'dim',
+                'embed_dim',
+                'embedding_size',
+                'hidden_dim',
+            ],
+        )
+
+        num_layers = self._get_attribute(
+            self.model.config,
+            [
+                'num_hidden_layers',
+                'n_layer',
+                'num_layers',
+                'encoder_layers',
+                'decoder_layers',
+                'n_layers',
+                'num_blocks',
+                'layer_count',
+            ],
+        )
+
+        num_parameters = sum(p.numel() for p in self.model.parameters())
+
+        # Estimate FLOPs
+        params_flops = 2 * num_parameters
+        seq_flops = params_flops * msl
+        attn_flops = bs * num_layers * 2 * msl * model_dimension
+        total_flops = seq_flops * bs + attn_flops
+
+        return total_flops
+
+    def _get_attribute(self, config: Any, possible_names: list):
+        """Retrieve an attribute from config using a list of possible names."""
+        for name in possible_names:
+            value = getattr(config, name, None)
+            if value is not None:
+                return value
+        return None
diff --git a/llmfoundry/models/llm_embed/modeling_llm_embed.py b/llmfoundry/models/llm_embed/modeling_llm_embed.py
new file mode 100644
index 0000000000..694d46d184
--- /dev/null
+++ b/llmfoundry/models/llm_embed/modeling_llm_embed.py
@@ -0,0 +1,439 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+"""LLM Contrastive Embedding Model.
+
+Implements the InfoNCE loss using the MPT architecture. The resulting model
+can be used as a vector embedding model.
+
+This is inspired by Microsoft Research's unilm repository
+https://github.com/microsoft/unilm
+"""
+
+import logging
+from dataclasses import dataclass
+from typing import Any, Mapping, MutableMapping, Optional, Union, cast
+
+import torch
+import torch.distributed
+import torch.nn as nn
+import torch.nn.functional as F
+from composer.models import HuggingFaceModel
+from composer.utils import dist
+from einops import rearrange
+from omegaconf import OmegaConf as om
+from torch.distributed.nn.functional import all_gather
+from torchmetrics import Metric
+from transformers import PreTrainedTokenizerBase
+from transformers.modeling_outputs import CausalLMOutputWithPast
+
+from llmfoundry import registry
+from llmfoundry.models.hf.hf_causal_lm import ComposerHFCausalLM
+from llmfoundry.models.mpt.configuration_mpt import MPTConfig
+from llmfoundry.models.mpt.modeling_mpt import MPTForCausalLM
+from llmfoundry.models.utils.config_moe_args import create_set_process_group
+
+log = logging.getLogger(__name__)
+
+
+class ContrastiveEvalLoss(Metric):
+
+    def __init__(self):
+        super().__init__()
+        self.add_state('loss', default=torch.tensor(0.0), dist_reduce_fx='sum')
+        self.add_state('total', default=torch.tensor(0), dist_reduce_fx='sum')
+
+    def update(self, outputs: Any, labels: Any):
+        loss = outputs['loss']
+        if loss.device != self.loss.device:
+            loss = loss.to(self.loss.device)
+        self.loss += loss
+        self.total += 1
+
+    def compute(self):
+        return self.loss / self.total
+
+
+@dataclass
+class ContrastiveConfig:
+    """Configuration for the contrastive loss.
+
+    Args:
+        temperature (Union[int, float], optional): Temperature for the InfoNCE loss. Defaults to 1.
+        vector_representation (str, optional): The vector representation to use. Defaults to 'avg'.
+        normalize_output (bool, optional): Whether to normalize the output. Defaults to True.
+        pos_step_size (int, optional): The step size for positive samples. Defaults to 2.
+        gather_in_batch_negatives (bool, optional): Whether to call all_gather on all samples in the
+            global batch. Defaults to False.
+        use_legacy_gradient_passthrough (bool, optional): Whether to use the legacy gradient passthrough. Defaults to False.
+        infonce_process_group_size (int, optional): The size of the process group for the InfoNCE loss. Defaults to None.
+    """
+    temperature: Union[int, float] = 1
+    vector_representation: str = 'avg'
+    normalize_output: bool = True
+    pos_step_size: int = 2
+    gather_in_batch_negatives: bool = False
+    use_legacy_gradient_passthrough: bool = False
+    infonce_process_group_size: Optional[int] = None
+
+
+class ContrastiveModel(HuggingFaceModel):
+    """A contrastive loss function wrapping an MPT or Hugging Face architecture.
+
+    This model applies a contrastive loss function to either an MPT (Mosaic Pretrained Transformer)
+    or a Hugging Face architecture. It allows for bidirectional encoding by modifying the attention mask.
+
+    Args:
+        tokenizer (PreTrainedTokenizerBase): The tokenizer used for tokenization.
+        contrastive_config (Dict[str, Any], optional): Configuration for the contrastive loss. Defaults to None. See `ContrastiveConfig`.
+        pretrained_model_name_or_path (Optional[str], optional): Pretrained model name or path. Defaults to None.
+        pretrained_lora_id_or_path (Optional[str], optional): Pretrained LoRA (Low Rank Adaptation) ID or path. Defaults to None.
+        trust_remote_code (bool, optional): Whether to trust remote code. Defaults to False.
+        init_device (str, optional): The initial device. Defaults to 'cpu'.
+        use_flash_attention_2 (bool, optional): Whether to use Flash Attention 2. Defaults to True.
+        use_auth_token (bool, optional): Whether to use an authentication token. Defaults to False.
+        config_overrides (Optional[Dict[str, Any]], optional): Overrides for the model configuration. Defaults to None.
+        load_in_8bit (bool, optional): Whether to load the model in 8-bit mode. Defaults to False.
+        loss_fn (str, optional): The loss function to use (either 'torch_crossentropy' or 'fused_crossentropy'). Defaults to 'fused_crossentropy'.
+        **kwargs (Dict[str, Any]): Additional keyword arguments.
+    """
+
+    def __init__(
+        self,
+        tokenizer: PreTrainedTokenizerBase,
+        contrastive_config: Optional[dict[str, Any]] = None,
+        pretrained_model_name_or_path: Optional[str] = None,
+        pretrained_lora_id_or_path: Optional[str] = None,
+        trust_remote_code: bool = False,
+        init_device: str = 'cpu',
+        use_flash_attention_2: bool = True,
+        use_auth_token: bool = False,
+        config_overrides: Optional[dict[str, Any]] = None,
+        load_in_8bit: bool = False,
+        loss_fn: str = 'fused_crossentropy',
+        **kwargs: dict[str, Any],
+    ):
+        self.pretrained_model_name_or_path = pretrained_model_name_or_path
+        self.pretrained_lora_id_or_path = pretrained_lora_id_or_path
+        self.trust_remote_code = trust_remote_code
+        self.init_device = init_device
+        self.use_flash_attention_2 = use_flash_attention_2
+        self.use_auth_token = use_auth_token
+        self.config_overrides = config_overrides
+        self.load_in_8bit = load_in_8bit
+        self.kwargs = kwargs
+        self.is_mpt = False
+
+        contrastive_config = contrastive_config or {}
+        contrastive_config_obj: ContrastiveConfig = om.structured(
+            ContrastiveConfig(**contrastive_config),
+        )
+        if tokenizer.pad_token is None:  # type: ignore
+            tokenizer.pad_token = tokenizer.eos_token
+
+        model = self.construct_model()
+
+        train_metrics: list[Metric] = []  # TODO: no train metrics for embedding models yet!
+
+        self.eval_metrics = [
+            ContrastiveEvalLoss(),
+        ]
+
+        super().__init__(
+            model=model,
+            tokenizer=tokenizer,
+            use_logits=False,
+            metrics=train_metrics,
+            eval_metrics=self.eval_metrics,  # type: ignore
+            shift_labels=False,
+            allow_embedding_resizing=True,
+        )
+
+        # Temperature for the InfoNCE loss
+        self.temperature = contrastive_config_obj.temperature
+
+        # Set the vector representation to either be the average of all the
+        # individual token vectors, or the EOS token at the end of the sequence
+        self.vector_representation = contrastive_config_obj.vector_representation
+        self.normalize_output = contrastive_config_obj.normalize_output
+
+        self.step_size = contrastive_config_obj.pos_step_size
+        self.gather_in_batch_negatives = contrastive_config_obj.gather_in_batch_negatives
+        self.use_legacy_gradient_passthrough = contrastive_config_obj.use_legacy_gradient_passthrough
+        self.n_active_params = sum(p.numel() for p in self.parameters())
+        if loss_fn == 'fused_crossentropy':
+            try:
+                from flash_attn.losses.cross_entropy import \
+                    CrossEntropyLoss as FusedCrossEntropyLoss
+
+                self.loss_fn = FusedCrossEntropyLoss(ignore_index=-100)
+            except:
+                raise ValueError(
+                    'Fused Cross Entropy is not installed. Either (1) have a CUDA-compatible GPU '
+                    +
+                    'and `pip install .[gpu]` if installing from source or `pip install xentropy-cuda-lib@git+https://github.com/HazyResearch/flash-attention.git@v1.0.3#subdirectory=csrc/xentropy` '
+                    +
+                    'if installing from pypi, or (2) set your config model.loss_fn=torch_crossentropy.',
+                )
+        elif loss_fn == 'torch_crossentropy':
+            self.loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
+        else:
+            raise ValueError(
+                f'Specified loss_fn={loss_fn} not recognized. `loss_fn` must be one of [`fused_crossentropy`, `torch_crossentropy`].',
+            )
+
+        self.infonce_process_group = None
+        if contrastive_config_obj.infonce_process_group_size is not None:
+            pg_size = contrastive_config_obj.infonce_process_group_size
+            self.infonce_process_group = create_set_process_group(pg_size)
+
+    def construct_model(self):
+        if self.pretrained_model_name_or_path:
+            model_class = registry.models.get('hf_causal_lm')
+            model_class = cast(type[ComposerHFCausalLM], model_class)
+            model = model_class.build_inner_model(
+                pretrained=True,
+                pretrained_model_name_or_path=self.pretrained_model_name_or_path,
+                pretrained_lora_id_or_path=self.pretrained_lora_id_or_path,
+                trust_remote_code=self.trust_remote_code,
+                init_device=self.init_device,
+                use_flash_attention_2=self.use_flash_attention_2,
+                use_auth_token=self.use_auth_token,
+                config_overrides=self.config_overrides or {},
+                load_in_8bit=self.load_in_8bit,
+                **self.kwargs,
+            )
+        else:
+            model = MPTForCausalLM(MPTConfig(**self.kwargs))
+            self.is_mpt = True
+        return model
+
+    def format_queries_batch(
+        self,
+        batch: MutableMapping,
+        last_hidden_state: torch.Tensor,
+    ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
+        """Format `queries` by selecting every ``n``th entry from the batch.
+
+        Here ``n`` is the step size: the number of rows per query group
+        (the query itself, its positive passage, and any hard negatives).
+        """
+        queries = {}
+        for key in batch:
+            # Select every `step_size`-th entry from the batch for the given key
+            queries[key] = batch[key][0::self.step_size, :]
+
+        # Select every `step_size`-th entry from `last_hidden_state` along the batch dimension
+        return queries, last_hidden_state[0::self.step_size, :, :]
+
+    def format_passages_batch(
+        self,
+        batch: MutableMapping,
+        last_hidden_state: torch.Tensor,
+    ) -> tuple[dict[str, torch.Tensor], torch.Tensor]:
+        """Format `passages` by dropping the query rows from the batch.
+
+        The step size is the number of rows per query group (the query, its
+        positive passage, and any hard negatives); all non-query rows are kept.
+        """
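+        # Illustrative example: with step_size=3 (one positive and one hard
+        # negative per query), a collapsed batch laid out as
+        # [q0, p0, n0, q1, p1, n1] yields the index tensor [1, 2, 4, 5]
+        # below, i.e. every row that is not a query.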
`vector_representation` must be one of [`avg`, `eos`].', + ) + + q_pooled_outputs = pool_fn(queries_last_hidden_state, query_attn_mask) + p_pooled_outputs = pool_fn( + passages_last_hidden_state, + passage_attn_mask, + ) + + if self.normalize_output: + q_pooled_outputs = F.normalize(q_pooled_outputs, dim=-1) + p_pooled_outputs = F.normalize(p_pooled_outputs, dim=-1) + + # Use all_gather to include negatives across mini batch + if self.gather_in_batch_negatives: + all_q_pooled_outputs = self._cat_gather( + q_pooled_outputs, + group=self.infonce_process_group, + ) + all_p_pooled_outputs = self._cat_gather( + p_pooled_outputs, + group=self.infonce_process_group, + ) + else: + all_q_pooled_outputs = q_pooled_outputs + all_p_pooled_outputs = p_pooled_outputs + + assert all_q_pooled_outputs is not None + assert all_p_pooled_outputs is not None + + all_scores = self._full_contrastive_scores( + queries=all_q_pooled_outputs, + passages=all_p_pooled_outputs, + ) + all_scores = all_scores * (1 / self.temperature) + zero + + all_labels = torch.arange( + all_scores.size(0), + device=q_pooled_outputs.device, + dtype=torch.long, + ) + all_labels = all_labels * ( + p_pooled_outputs.size(0) // q_pooled_outputs.size(0) + ) + + return all_scores, all_labels + + def _full_contrastive_scores( + self, + queries: torch.Tensor, + passages: torch.Tensor, + ) -> torch.Tensor: + + # this calculates the inner product between query and passage pairs + qp = torch.mm(queries, passages.t()) + + return qp + + def loss( + self, + outputs: CausalLMOutputWithPast, + batch: MutableMapping, + ) -> torch.Tensor: + scores, labels = self._compute_scores(batch, outputs) + loss = self.loss_fn(scores, labels) + return loss + + def eval_forward( + self, + batch: MutableMapping, + outputs: Optional[Any] = None, + ): + if outputs is None: + outputs = self.forward(batch) + val_loss = self.loss(outputs, batch) + return {'loss': val_loss, 'outputs': outputs} + + def flops_per_batch(self, batch: Mapping) -> int: + # Note: this computation does not take into account padding, and assumes + # that the dataset has been constructed without padding. 
Additionally, we + # assume the backward pass is approximately 2x the forward pass + + bs, msl = batch['input_ids'].shape[0:2] + params_flops_per_token = 2 * self.n_active_params + params_flops_per_seq = params_flops_per_token * msl + attn_flops_per_seq = ( + self.model.config.n_layers * 2 * 2 * + (self.model.config.d_model * (msl**2)) + ) + + return (params_flops_per_seq + attn_flops_per_seq) * 3 * bs diff --git a/scripts/data_prep/delta_to_contrastive_mds.py b/scripts/data_prep/delta_to_contrastive_mds.py new file mode 100644 index 0000000000..7bb75f755b --- /dev/null +++ b/scripts/data_prep/delta_to_contrastive_mds.py @@ -0,0 +1,79 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +import logging +import os +from argparse import ArgumentParser, Namespace + +from llmfoundry.command_utils import convert_delta_to_contrastive_mds + +logger = logging.getLogger(__name__) + + +def parse_args() -> Namespace: + parser = ArgumentParser( + description= + 'Download Delta table from UC and convert to JSON to save locally.', + ) + parser.add_argument( + '--delta_table_name', + required=True, + type=str, + help='UC table ..', + ) + parser.add_argument( + '--output_path', + required=True, + type=str, + help='Local path to save the converted JSON', + ) + parser.add_argument( + '--http_path', + required=False, + type=str, + help='http_path is set then dbsql method is used', + ) + parser.add_argument( + '--cluster_id', + required=False, + type=str, + help= + 'Cluster ID with runtime newer than 14.1.0 and access mode of either assigned or shared can use databricks-connect.', + ) + parser.add_argument( + '--use_serverless', + required=False, + type=bool, + default=False, + help= + 'Use serverless or not. Make sure the workspace is entitled with serverless', + ) + parser.add_argument( + '--batch_size', + required=False, + type=int, + default=1 << 30, + help='Batch size for processing the data', + ) + parser.add_argument( + '--processes', + required=False, + type=int, + default=os.cpu_count(), + help='Number of processes to use for parallel processing', + ) + parsed = parser.parse_args() + return parsed + + +if __name__ == '__main__': + args = parse_args() + convert_delta_to_contrastive_mds( + delta_table_name=args.delta_table_name, + http_path=args.http_path, + cluster_id=args.cluster_id, + use_serverless=args.use_serverless, + output_path=args.output_path, + batch_size=args.batch_size, + processes=args.processes, + ) diff --git a/tests/a_scripts/data_prep/test_delta_to_contrastive.py b/tests/a_scripts/data_prep/test_delta_to_contrastive.py new file mode 100644 index 0000000000..4cd22e356a --- /dev/null +++ b/tests/a_scripts/data_prep/test_delta_to_contrastive.py @@ -0,0 +1,468 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +# tests/a_scripts/contrastive/test_delta_to_contrastive.py + +import json +import unittest +from typing import Any +from unittest.mock import MagicMock, mock_open, patch + + +class TestValidateColumnsInTable(unittest.TestCase): + """Unit tests for the validate_columns_in_table function.""" + + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_contrastive_mds.run_query', + autospec=True, + ) + def test_validate_columns_success(self, mock_run_query: MagicMock) -> None: + # Import inside the test after patching + from llmfoundry.command_utils.data_prep.convert_delta_to_contrastive_mds import \ + validate_columns_in_table + + # Mock the run_query to return all required and optional columns + 
mock_run_query.return_value = [ + MagicMock( + asDict=MagicMock(return_value={'col_name': 'query_text'}), + ), + MagicMock( + asDict=MagicMock(return_value={'col_name': 'positive_passage'}), + ), + MagicMock( + asDict=MagicMock( + return_value={'col_name': 'negative_passages'}, + ), + ), + ] + + required_columns = ['query_text', 'positive_passage'] + optional_columns = ['negative_passages'] + table_name = 'test_table' + method = 'dbconnect' + + result: bool = validate_columns_in_table( + required_columns=required_columns, + optional_columns=optional_columns, + table_name=table_name, + method=method, + cursor=None, + spark=None, + ) + + self.assertTrue(result) + mock_run_query.assert_called_once_with( + f'SHOW COLUMNS IN {table_name}', + method, + None, + None, + ) + + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_contrastive_mds.run_query', + autospec=True, + ) + def test_validate_columns_missing_required( + self, + mock_run_query: MagicMock, + ) -> None: + from llmfoundry.command_utils.data_prep.convert_delta_to_contrastive_mds import \ + validate_columns_in_table + + # Mock the run_query to return missing required columns + mock_run_query.return_value = [ + MagicMock( + asDict=MagicMock(return_value={'col_name': 'query_text'}), + ), + ] + + required_columns = ['query_text', 'positive_passage'] + optional_columns = ['negative_passages'] + table_name = 'test_table' + method = 'dbconnect' + + with self.assertLogs( + 'llmfoundry.command_utils.data_prep.convert_delta_to_contrastive_mds', + level='ERROR', + ) as log: + result: bool = validate_columns_in_table( + required_columns=required_columns, + optional_columns=optional_columns, + table_name=table_name, + method=method, + cursor=None, + spark=None, + ) + + self.assertFalse(result) + self.assertIn('Missing required columns', log.output[0]) + + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_contrastive_mds.run_query', + autospec=True, + ) + def test_validate_columns_extra_columns( + self, + mock_run_query: MagicMock, + ) -> None: + from llmfoundry.command_utils.data_prep.convert_delta_to_contrastive_mds import \ + validate_columns_in_table + + # Mock the run_query to return extra columns + mock_run_query.return_value = [ + MagicMock( + asDict=MagicMock(return_value={'col_name': 'query_text'}), + ), + MagicMock( + asDict=MagicMock(return_value={'col_name': 'positive_passage'}), + ), + MagicMock( + asDict=MagicMock(return_value={'col_name': 'extra_column'}), + ), + ] + + required_columns = ['query_text', 'positive_passage'] + optional_columns = ['negative_passages'] + table_name = 'test_table' + method = 'dbconnect' + + with self.assertLogs( + 'llmfoundry.command_utils.data_prep.convert_delta_to_contrastive_mds', + level='WARNING', + ) as log: + result: bool = validate_columns_in_table( + required_columns=required_columns, + optional_columns=optional_columns, + table_name=table_name, + method=method, + cursor=None, + spark=None, + ) + + self.assertFalse(result) + self.assertIn('Extra columns found', log.output[0]) + + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_contrastive_mds.run_query', + autospec=True, + ) + def test_validate_columns_exception( + self, + mock_run_query: MagicMock, + ) -> None: + from llmfoundry.command_utils.data_prep.convert_delta_to_contrastive_mds import \ + validate_columns_in_table + + # Mock run_query to raise an exception + mock_run_query.side_effect = Exception('Test Exception') + + required_columns = ['query_text', 'positive_passage'] + optional_columns = 
['negative_passages'] + table_name = 'test_table' + method = 'dbconnect' + + with self.assertLogs( + 'llmfoundry.command_utils.data_prep.convert_delta_to_contrastive_mds', + level='ERROR', + ) as log: + result: bool = validate_columns_in_table( + required_columns=required_columns, + optional_columns=optional_columns, + table_name=table_name, + method=method, + cursor=None, + spark=None, + ) + + self.assertFalse(result) + self.assertIn('Error validating columns in table', log.output[0]) + + +class TestMainFunction(unittest.TestCase): + """Unit tests for the main function.""" + + @patch('databricks.sdk.WorkspaceClient', autospec=True) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_contrastive_mds.validate_columns_in_table', + autospec=True, + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_contrastive_mds.validate_and_get_cluster_info', + autospec=True, + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_contrastive_mds.fetch_DT', + autospec=True, + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_contrastive_mds.MDSWriter', + autospec=True, + ) + def test_main_success( + self, + mock_mds_writer: MagicMock, + mock_fetch_DT: MagicMock, + mock_validate_cluster_info: MagicMock, + mock_validate_columns: MagicMock, + mock_workspace_client_class: MagicMock, + ) -> None: + with patch( + 'databricks.sdk.WorkspaceClient', + mock_workspace_client_class, + ): + from llmfoundry.command_utils.data_prep.convert_delta_to_contrastive_mds import \ + convert_delta_to_contrastive_mds + + # Setup mocks + mock_workspace_client_instance = MagicMock() + mock_workspace_client_class.return_value = mock_workspace_client_instance + + mock_validate_cluster_info.return_value = ( + 'dbconnect', + MagicMock(), + MagicMock(), + ) + mock_validate_columns.return_value = True + + mock_mds_instance: MagicMock = MagicMock() + mock_mds_writer.return_value.__enter__.return_value = mock_mds_instance + + args: dict[str, Any] = { + 'delta_table_name': 'test_table', + 'http_path': 'http://test_path', + 'cluster_id': 'cluster123', + 'use_serverless': False, + 'output_path': '/output/path', + 'batch_size': 1000, + 'processes': 4, + } + + with patch('tempfile.TemporaryDirectory') as mock_temp_dir: + mock_temp_dir.return_value.__enter__.return_value = '/tmp/mock_dir' + + # **Update the mock_open data to include negative_passages as a list** + with patch( + 'builtins.open', + mock_open( + read_data= + '{"query_text": "sample", "positive_passage": "passage", "negative_passages": []}\n', + ), + ): + convert_delta_to_contrastive_mds(**args) + + # Assertions + mock_workspace_client_class.assert_called_once() + mock_validate_cluster_info.assert_called_once_with( + cluster_id='cluster123', + databricks_host=mock_workspace_client_instance.config. + host, + databricks_token=mock_workspace_client_instance.config. + token, + http_path='http://test_path', + use_serverless=False, + ) + mock_validate_columns.assert_called_once_with( + required_columns=['query_text', 'positive_passage'], + optional_columns=['negative_passages'], + table_name='test_table', + method='dbconnect', + cursor=mock_validate_cluster_info.return_value[1]. 
+ cursor(), + spark=mock_validate_cluster_info.return_value[2], + ) + mock_fetch_DT.assert_called_once_with( + delta_table_name='test_table', + json_output_folder='/tmp/mock_dir', + http_path='http://test_path', + cluster_id='cluster123', + use_serverless=False, + json_output_filename='output.jsonl', + batch_size=1000, + processes=4, + DATABRICKS_HOST=mock_workspace_client_instance.config. + host, + DATABRICKS_TOKEN=mock_workspace_client_instance.config. + token, + ) + mock_mds_writer.assert_called_once_with( + out='/output/path', + columns={ + 'query_text': 'str', + 'positive_passage': 'str', + 'negative_passages': 'str', + }, + compression='zstd:7', + hashes=['sha1'], + size_limit='10mb', + ) + mock_mds_instance.write.assert_called_once_with({ + 'query_text': 'sample', + 'positive_passage': 'passage', + 'negative_passages': json.dumps([]), + }) + + @patch('databricks.sdk.WorkspaceClient', autospec=True) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_contrastive_mds.validate_columns_in_table', + autospec=True, + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_contrastive_mds.validate_and_get_cluster_info', + autospec=True, + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_contrastive_mds.fetch_DT', + side_effect=Exception('Fetch DT Error'), + autospec=True, + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_contrastive_mds.MDSWriter', + autospec=True, + ) + def test_main_fetch_DT_exception( + self, + mock_mds_writer: MagicMock, + mock_fetch_DT: MagicMock, + mock_validate_cluster_info: MagicMock, + mock_validate_columns: MagicMock, + mock_workspace_client_class: MagicMock, + ) -> None: + """Test that main raises an exception when fetch_DT fails.""" + with patch( + 'databricks.sdk.WorkspaceClient', + mock_workspace_client_class, + ): + from llmfoundry.command_utils.data_prep.convert_delta_to_contrastive_mds import \ + convert_delta_to_contrastive_mds + + # Setup mocks + mock_workspace_client_instance = MagicMock() + mock_workspace_client_class.return_value = mock_workspace_client_instance + + mock_validate_cluster_info.return_value = ( + 'dbconnect', + MagicMock(), + MagicMock(), + ) + mock_validate_columns.return_value = True + + args: dict[str, Any] = { + 'delta_table_name': 'test_table', + 'http_path': 'http://test_path', + 'cluster_id': 'cluster123', + 'use_serverless': False, + 'output_path': '/output/path', + 'batch_size': 1000, + 'processes': 4, + } + + with self.assertLogs( + 'llmfoundry.command_utils.data_prep.convert_delta_to_contrastive_mds', + level='ERROR', + ) as log: + with self.assertRaises(Exception) as cm: + convert_delta_to_contrastive_mds(**args) + + self.assertIn('Error fetching data: Fetch DT Error', log.output[0]) + self.assertEqual(str(cm.exception), 'Fetch DT Error') + + @patch('databricks.sdk.WorkspaceClient', autospec=True) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_contrastive_mds.validate_columns_in_table', + return_value=True, + autospec=True, + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_contrastive_mds.validate_and_get_cluster_info', + return_value=('dbconnect', MagicMock(), MagicMock()), + autospec=True, + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_contrastive_mds.fetch_DT', + autospec=True, + ) + @patch( + 'llmfoundry.command_utils.data_prep.convert_delta_to_contrastive_mds.MDSWriter', + autospec=True, + ) + def test_main_temporary_directory_handling( + self, + mock_mds_writer: MagicMock, + mock_fetch_DT: MagicMock, 
+ mock_validate_cluster_info: MagicMock, + mock_validate_columns: MagicMock, + mock_workspace_client_class: MagicMock, + ) -> None: + with patch( + 'databricks.sdk.WorkspaceClient', + mock_workspace_client_class, + ): + from llmfoundry.command_utils.data_prep.convert_delta_to_contrastive_mds import \ + convert_delta_to_contrastive_mds + + # Setup mocks + mock_workspace_client_instance = MagicMock() + mock_workspace_client_class.return_value = mock_workspace_client_instance + + mock_validate_cluster_info.return_value = ( + 'dbconnect', + MagicMock(), + MagicMock(), + ) + mock_validate_columns.return_value = True + + mock_mds_instance: MagicMock = MagicMock() + mock_mds_writer.return_value.__enter__.return_value = mock_mds_instance + + args: dict[str, Any] = { + 'delta_table_name': 'test_table', + 'http_path': 'http://test_path', + 'cluster_id': 'cluster123', + 'use_serverless': False, + 'output_path': '/output/path', + 'batch_size': 1000, + 'processes': 4, + } + + with patch('tempfile.TemporaryDirectory') as mock_temp_dir: + mock_temp_dir.return_value.__enter__.return_value = '/tmp/mock_dir' + # **Update the mock_open data to include negative_passages as a list** + with patch( + 'builtins.open', + mock_open( + read_data= + '{"query_text": "sample", "positive_passage": "passage", "negative_passages": []}\n', + ), + ): + convert_delta_to_contrastive_mds(**args) + mock_temp_dir.assert_called_once() + mock_fetch_DT.assert_called_once_with( + delta_table_name='test_table', + json_output_folder='/tmp/mock_dir', + http_path='http://test_path', + cluster_id='cluster123', + use_serverless=False, + json_output_filename='output.jsonl', + batch_size=1000, + processes=4, + DATABRICKS_HOST=mock_workspace_client_instance.config. + host, + DATABRICKS_TOKEN=mock_workspace_client_instance.config. 
+ token, + ) + mock_mds_writer.assert_called_once_with( + out='/output/path', + columns={ + 'query_text': 'str', + 'positive_passage': 'str', + 'negative_passages': 'str', + }, + compression='zstd:7', + hashes=['sha1'], + size_limit='10mb', + ) + mock_mds_instance.write.assert_called_once_with({ + 'query_text': 'sample', + 'positive_passage': 'passage', + 'negative_passages': json.dumps([]), + }) diff --git a/tests/data/test_contrastive_dataloader.py b/tests/data/test_contrastive_dataloader.py new file mode 100644 index 0000000000..4e664bb458 --- /dev/null +++ b/tests/data/test_contrastive_dataloader.py @@ -0,0 +1,74 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from typing import cast + +import pytest +import torch + +from llmfoundry.utils.builders import build_dataloader +from tests.data_utils import temporary_contrastive_streaming_dataset +from tests.test_utils import MockTokenizer + + +@pytest.fixture +def mock_tokenizer() -> MockTokenizer: + return MockTokenizer() + + +@pytest.mark.parametrize( + 'ds_format', + ['one_query_one_response', 'one_query_multiple_responses'], +) +def test_pairs_dataloader( + ds_format: str, + mock_tokenizer: MockTokenizer, +) -> None: + with temporary_contrastive_streaming_dataset(ds_format) as data_dir: + cfg = { + 'name': 'contrastive_pairs', + 'dataset': { + 'remote': data_dir, + 'split': 'train', + 'max_seq_len': 1024, + 'shuffle_hard_negatives': False, + }, + 'drop_last': False, + 'num_workers': 1, + 'max_hard_negatives': 2, + } + + dl = build_dataloader(cfg, mock_tokenizer, 1) + + for i, batch in enumerate(dl.dataloader): + batch_dict = cast(dict[str, torch.Tensor], batch) + batch_input_ids = batch_dict['input_ids'] + # query + positive + max 2 hard negatives + assert batch_input_ids.shape[1] <= 4 + if ds_format == 'one_query_one_response': + # 0th item is the query, 1st item is the positive, 2nd item is (optionally) the negative + tokenizer_output = mock_tokenizer( + [f'hello {i}', f'world {i}'], + padding='max_length', + max_length=1024, + return_tensors='pt', + ) + tokenizer_dict = cast(dict[str, torch.Tensor], tokenizer_output) + expected_ids = tokenizer_dict['input_ids'] + else: + # 0th item is the query, 1st item is the positive, 2nd and 3rd items are the negatives + tokenizer_output = mock_tokenizer( + [ + f'query {i}', + f'positive passage {i}', + f'negative passage {i}', + f'negative passage {i + 1}', + ], + padding='max_length', + max_length=1024, + return_tensors='pt', + ) + tokenizer_dict = cast(dict[str, torch.Tensor], tokenizer_output) + expected_ids = tokenizer_dict['input_ids'] + + assert torch.allclose(batch_input_ids[0], expected_ids) diff --git a/tests/data_utils.py b/tests/data_utils.py index 67c1be9f6e..57da51956e 100644 --- a/tests/data_utils.py +++ b/tests/data_utils.py @@ -4,11 +4,14 @@ import json import os import shutil +from contextlib import contextmanager from pathlib import Path +from tempfile import TemporaryDirectory from typing import Optional from omegaconf import DictConfig from omegaconf import OmegaConf as om +from streaming import MDSWriter from llmfoundry.command_utils import ( convert_dataset_hf, @@ -308,3 +311,56 @@ def gpt_tiny_cfg(dataset_name: str, device: str): test_cfg.precision = 'fp32' return test_cfg + + +@contextmanager +def temporary_contrastive_streaming_dataset(ds_format: str): + dir_name, cleanup_fn = build_temporary_contrastive_streaming_dataset( + ds_format, + ) + + try: + yield dir_name + finally: + cleanup_fn() + + +def 
build_temporary_contrastive_streaming_dataset(ds_format: str):
+    tempdir = TemporaryDirectory()
+    columns = {
+        'text_a': 'str',
+        'text_b': 'str',
+        'id': 'int',
+    } if ds_format == 'one_query_one_response' else {
+        'query_text': 'str',
+        'positive_passage': 'str',
+        'negative_passages': 'str',
+        'id': 'int',
+    }
+    with MDSWriter(
+        columns=columns,
+        out=os.path.join(tempdir.name, 'train'),
+        compression=None,
+    ) as output_writer:
+        for i in range(100):
+            if ds_format == 'one_query_one_response':
+                output_writer.write({
+                    'text_a': f'hello {i}',
+                    'text_b': f'world {i}',
+                    'id': i,
+                })
+            elif ds_format == 'one_query_multiple_responses':
+                output_writer.write({
+                    'query_text':
+                        f'query {i}',
+                    'positive_passage':
+                        f'positive passage {i}',
+                    'negative_passages':
+                        f'["negative passage {i}", "negative passage {i + 1}", "negative passage {i + 2}"]',
+                    'id':
+                        i,
+                })
+            else:
+                raise ValueError(f'Unknown format: {ds_format}')
+
+    return tempdir.name, tempdir.cleanup
diff --git a/tests/models/llm_embed/test_embedding_finetune.py b/tests/models/llm_embed/test_embedding_finetune.py
new file mode 100644
index 0000000000..ad33f39ffd
--- /dev/null
+++ b/tests/models/llm_embed/test_embedding_finetune.py
@@ -0,0 +1,176 @@
+# Copyright 2024 MosaicML LLM Foundry authors
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any
+from unittest.mock import patch
+
+import pytest
+import torch
+from transformers import AutoConfig
+from transformers.modeling_outputs import \
+    BaseModelOutputWithPastAndCrossAttentions
+
+from llmfoundry.models.llm_embed import FinetuneEmbeddingModel
+from tests.test_utils import MockTokenizer
+
+
+@pytest.fixture
+def mock_tokenizer() -> MockTokenizer:
+    return MockTokenizer()
+
+
+class MockAutoModel(torch.nn.Module):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.config: AutoConfig = AutoConfig.from_pretrained(
+            'bert-base-uncased',
+        )
+        self.config.hidden_size = 768
+        self.config.num_hidden_layers = 12
+        self.config.n_layers = 12
+        self.config.vocab_size = 30000
+        self.linear: torch.nn.Linear = torch.nn.Linear(
+            self.config.hidden_size,
+            self.config.hidden_size,
+        )
+
+    @classmethod
+    def from_pretrained(cls, *args: Any, **kwargs: Any) -> 'MockAutoModel':
+        return cls()
+
+    def forward(
+        self,
+        **kwargs: Any,
+    ) -> BaseModelOutputWithPastAndCrossAttentions:
+        # Simulate forward pass
+        input_ids: torch.Tensor = kwargs.get(
+            'input_ids',
+            torch.zeros(1, 10, dtype=torch.long),
+        )
+        batch_size: int = input_ids.size(0)
+        seq_length: int = input_ids.size(1)
+        last_hidden_state: torch.Tensor = torch.randn(
+            batch_size,
+            seq_length,
+            self.config.hidden_size,
+        )
+        last_hidden_state = self.linear(last_hidden_state)
+        return BaseModelOutputWithPastAndCrossAttentions(
+            last_hidden_state=last_hidden_state,
+            hidden_states=(last_hidden_state,) *
+            (self.config.num_hidden_layers + 1),
+        )
+
+
+@pytest.fixture
+def mock_auto_model() -> MockAutoModel:
+    return MockAutoModel()
+
+
+@pytest.fixture
+def model(
+    mock_tokenizer: MockTokenizer,
+    mock_auto_model: MockAutoModel,
+) -> FinetuneEmbeddingModel:
+    with patch('transformers.AutoModel.from_pretrained', return_value=mock_auto_model), \
+        patch('composer.utils.dist.get_global_rank', return_value=0), \
+        patch('composer.utils.dist.barrier'), \
+        patch('llmfoundry.models.llm_embed.FinetuneEmbeddingModel.construct_model', return_value=mock_auto_model):
+        model_instance: FinetuneEmbeddingModel = FinetuneEmbeddingModel(
+            tokenizer=mock_tokenizer,
+            pretrained_model_name_or_path='bert-base-uncased',
loss_fn='torch_crossentropy', + ) + return model_instance + + +def test_construct_model(model: FinetuneEmbeddingModel) -> None: + with patch( + 'transformers.AutoModel.from_pretrained', + return_value=model.model, + ): + constructed_model = model.construct_model() + assert constructed_model is not None + assert isinstance(constructed_model, MockAutoModel) + + +def test_get_hidden_state(model: FinetuneEmbeddingModel) -> None: + mock_outputs: BaseModelOutputWithPastAndCrossAttentions = BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=torch.randn(1, 10, model.model.config.hidden_size), + ) + hidden_state: torch.Tensor = model.get_hidden_state(mock_outputs) + assert torch.equal(hidden_state, mock_outputs.last_hidden_state) + + +def test_handle_language_head(model: FinetuneEmbeddingModel) -> None: + mock_outputs: BaseModelOutputWithPastAndCrossAttentions = BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=torch.randn(1, 10, model.model.config.hidden_size), + ) + result: torch.Tensor = model.handle_language_head(mock_outputs) + assert isinstance(result, torch.Tensor) + assert result.item() == 0 + assert result.dtype == torch.float32 + assert result.device == mock_outputs.last_hidden_state.device + + +def test_flops_per_batch(model: FinetuneEmbeddingModel) -> None: + batch: dict[str, torch.Tensor] = { + 'input_ids': torch.randint(0, 1000, (2, 128)), + } + flops: int = model.flops_per_batch(batch) + assert isinstance(flops, int) + assert flops > 0 + + +def test_get_attribute(model: FinetuneEmbeddingModel) -> None: + config: AutoConfig = AutoConfig.from_pretrained('bert-base-uncased') + config.hidden_size = 768 + config.d_model = 1024 + config.n_embd = 512 + + attribute_value: Any = model._get_attribute( + config, + ['hidden_size', 'd_model', 'n_embd'], + ) + assert attribute_value == 768 + attribute_value = model._get_attribute(config, ['d_model', 'n_embd']) + assert attribute_value == 1024 + attribute_value = model._get_attribute( + config, + ['non_existent', 'also_non_existent'], + ) + assert attribute_value is None + + +@pytest.mark.parametrize( + 'dist_initialized', + [ + pytest.param( + True, + marks=[ + pytest.mark.gpu, + pytest.mark.world_size(2), + ], + ), + pytest.param(False), + ], +) +def test_construct_model_distributed( + mock_tokenizer: MockTokenizer, + mock_auto_model: MockAutoModel, + dist_initialized: bool, +) -> None: + with patch('torch.distributed.is_initialized', return_value=dist_initialized), \ + patch('torch.distributed.get_rank', return_value=0), \ + patch('torch.distributed.barrier'), \ + patch('transformers.AutoModel.from_pretrained', return_value=mock_auto_model), \ + patch('llmfoundry.models.llm_embed.FinetuneEmbeddingModel.construct_model', return_value=mock_auto_model): + model_instance: FinetuneEmbeddingModel = FinetuneEmbeddingModel( + tokenizer=mock_tokenizer, + pretrained_model_name_or_path='bert-base-uncased', + loss_fn='torch_crossentropy', + ) + constructed_model: torch.nn.Module = model_instance.construct_model() + assert constructed_model is not None + assert isinstance(constructed_model, torch.nn.Module) diff --git a/tests/models/llm_embed/test_llm_embedding.py b/tests/models/llm_embed/test_llm_embedding.py new file mode 100644 index 0000000000..97dc39788e --- /dev/null +++ b/tests/models/llm_embed/test_llm_embedding.py @@ -0,0 +1,408 @@ +# Copyright 2024 MosaicML LLM Foundry authors +# SPDX-License-Identifier: Apache-2.0 + +from contextlib import nullcontext +from typing import Any, Optional +from unittest.mock import patch + 
+import pytest +import torch +from composer import Trainer +from composer.core import get_precision_context +from torch.nn.parallel import DistributedDataParallel as DDP +from transformers import AutoConfig +from transformers.modeling_outputs import \ + BaseModelOutputWithPastAndCrossAttentions + +from llmfoundry.models.llm_embed import ContrastiveEvalLoss, ContrastiveModel +from llmfoundry.utils.builders import build_dataloader +from tests.data_utils import temporary_contrastive_streaming_dataset +from tests.test_utils import MockTokenizer + + +@pytest.fixture +def mock_tokenizer() -> MockTokenizer: + return MockTokenizer() + + +class MockAutoModel(torch.nn.Module): + + def __init__(self) -> None: + super().__init__() + self.config: AutoConfig = AutoConfig.from_pretrained( + 'bert-base-uncased', + ) + self.config.hidden_size = 768 + self.config.num_hidden_layers = 12 + self.config.n_layers = 12 + self.config.vocab_size = 30000 + self.linear: torch.nn.Linear = torch.nn.Linear( + self.config.hidden_size, + self.config.hidden_size, + ) + + @classmethod + def from_pretrained(cls, *args: Any, **kwargs: Any) -> 'MockAutoModel': + return cls() + + def forward( + self, + **kwargs: Any, + ) -> BaseModelOutputWithPastAndCrossAttentions: + # Simulate forward pass + input_ids: torch.Tensor = kwargs.get( + 'input_ids', + torch.zeros(1, 10, dtype=torch.long), + ) + batch_size: int = input_ids.size(0) + seq_length: int = input_ids.size(1) + last_hidden_state: torch.Tensor = torch.randn( + batch_size, + seq_length, + self.config.hidden_size, + ) + last_hidden_state = self.linear(last_hidden_state) + return BaseModelOutputWithPastAndCrossAttentions( + last_hidden_state=last_hidden_state, + hidden_states=(last_hidden_state,) * + (self.config.num_hidden_layers + 1), + ) + + +@pytest.fixture +def mock_auto_model() -> MockAutoModel: + return MockAutoModel() + + +@pytest.fixture +def model( + mock_tokenizer: MockTokenizer, + mock_auto_model: MockAutoModel, +) -> ContrastiveModel: + with patch('transformers.AutoModel.from_pretrained', return_value=mock_auto_model), \ + patch('torch.distributed.is_initialized', return_value=False), \ + patch('torch.distributed.get_rank', return_value=0), \ + patch('torch.distributed.barrier'), \ + patch('llmfoundry.models.llm_embed.FinetuneEmbeddingModel.construct_model', return_value=mock_auto_model): + model_instance: ContrastiveModel = ContrastiveModel( + tokenizer=mock_tokenizer, + pretrained_model_name_or_path='bert-base-uncased', + loss_fn='torch_crossentropy', + use_flash_attention_2=False, + ) + return model_instance + + +def build_lm_config(is_hf: bool, attn_impl: Optional[str]) -> dict[str, Any]: + if is_hf: + assert attn_impl is None + return {'pretrained_model_name_or_path': 'facebook/opt-350m'} + else: + assert attn_impl is not None + return { + 'num_layers': 2, + 'word_embed_proj_dim': 768, + 'd_model': 768, + 'n_heads': 12, + 'vocab_size': 100352, + 'attn_config': { + 'attn_impl': attn_impl, + }, + } + + +def build_tokenizer_config(is_hf: bool) -> dict[str, Any]: + return {'vocab_size': 50257 if is_hf else 100352} + + +@pytest.mark.gpu +@pytest.mark.parametrize('is_hf', [True, False]) +@pytest.mark.parametrize('attn_impl', ['flash', 'torch']) +def test_mpt_embedding_lm( + is_hf: bool, + attn_impl: str, + mock_tokenizer: MockTokenizer, +): + maybe_attn_impl = None if is_hf else attn_impl + lm_config = build_lm_config(is_hf, maybe_attn_impl) + + model = ContrastiveModel(**lm_config, tokenizer=mock_tokenizer).to( + torch.bfloat16, + ).to('cuda') + model_inputs_batch 
= mock_tokenizer([['pair 1 a', 'pair 1 b'],
+                                         ['pair 2 a', 'pair 2 b']],
+                                        padding='max_length',
+                                        truncation=True,
+                                        max_length=128,
+                                        return_tensors='pt')
+    if isinstance(model_inputs_batch, dict):
+        model_inputs_batch = {
+            k: v.to('cuda') for k, v in model_inputs_batch.items()
+        }
+
+    ctx = get_precision_context(
+        'amp_bf16',
+    ) if maybe_attn_impl == 'flash' else nullcontext()
+    with ctx:
+        outputs = model(model_inputs_batch)
+
+    assert isinstance(outputs, dict)
+    assert 'hidden_states' in outputs
+
+    hidden_states = outputs['hidden_states']
+    assert isinstance(hidden_states, tuple)
+
+    last_hidden_state = hidden_states[-1]
+    proj_dim = model.model.config.word_embed_proj_dim
+    assert last_hidden_state.shape == (
+        4,
+        128,
+        proj_dim,
+    )  # 2 pairs * 2 texts per pair, 128 sequence length, word_embed_proj_dim dim
+    assert last_hidden_state.dtype == torch.bfloat16
+    assert last_hidden_state.device.type == 'cuda'
+
+
+dataloader_config = lambda remote, local_ext: {
+    'name': 'contrastive_pairs',
+    'dataset': {
+        'remote': remote,
+        'local': remote + '_' + local_ext,
+        'split': 'train',
+        'max_seq_len': 1024,
+    },
+    'drop_last': False,
+    'num_workers': 1,
+    'max_hard_negatives': 1,
+}
+
+
+@pytest.mark.gpu
+@pytest.mark.parametrize('is_hf', [True, False])
+@pytest.mark.parametrize('attn_impl', ['flash', 'torch'])
+@pytest.mark.parametrize(
+    'ds_format',
+    ['one_query_one_response', 'one_query_multiple_responses'],
+)
+def test_contrastive_loss(
+    ds_format: str,
+    is_hf: bool,
+    attn_impl: str,
+    mock_tokenizer: MockTokenizer,
+):
+    maybe_attn_impl = None if is_hf else attn_impl
+
+    with temporary_contrastive_streaming_dataset(ds_format) as data_dir:
+        lm_config = build_lm_config(is_hf, maybe_attn_impl)
+        model = ContrastiveModel(**lm_config, tokenizer=mock_tokenizer).to(
+            torch.bfloat16,
+        ).to('cuda')
+
+        train_dataloader = build_dataloader(
+            dataloader_config(data_dir, 'local'),
+            mock_tokenizer,
+            2,
+        )
+
+        precision = 'amp_bf16' if maybe_attn_impl == 'flash' else 'fp32'
+        ctx = get_precision_context(
+            'amp_bf16',
+        ) if maybe_attn_impl == 'flash' else nullcontext()
+        with ctx:
+            trainer = Trainer(
+                model=model,
+                train_dataloader=train_dataloader,
+                precision=precision,
+                max_duration='3ba',
+            )
+            trainer.fit()
+
+
+@pytest.mark.gpu
+@pytest.mark.world_size(2)
+@pytest.mark.parametrize(
+    'use_legacy_gradient_passthrough',
+    [
+        pytest.param(
+            True,
+            marks=pytest.mark.xfail(reason='Does not backprop gradients.'),
+        ),
+        False,
+    ],
+)
+def test_distributed_loss(
+    use_legacy_gradient_passthrough: bool,
+    mock_tokenizer: MockTokenizer,
+):
+    is_hf = False
+
+    lm_config = build_lm_config(is_hf, 'flash')
+    lm_config['contrastive_config'] = {
+        'gather_in_batch_negatives': True,
+        'use_legacy_gradient_passthrough': use_legacy_gradient_passthrough,
+    }
+
+    lm_config_single_device = lm_config.copy()
+    lm_config_single_device['contrastive_config'] = lm_config[
+        'contrastive_config'].copy()
+    lm_config_single_device['contrastive_config']['gather_in_batch_negatives'
+                                                 ] = False
+
+    model_for_ddp = ContrastiveModel(mock_tokenizer, **lm_config)
+    model_ddp = model_for_ddp.to('cuda').to(torch.bfloat16)
+    model_ddp = DDP(model_ddp)
+
+    model = ContrastiveModel(mock_tokenizer,
+                             **lm_config_single_device).to('cuda').to(
+                                 torch.bfloat16,
+                             )
+    model.load_state_dict(model_for_ddp.state_dict())
+
+    input_batch = mock_tokenizer([['pair 1 a', 'pair 1 b'],
+                                  ['pair 2 a', 'pair 2 b'],
+                                  ['pair 3 a', 'pair 3 b'],
+                                  ['pair 4 a', 'pair 4 b']],
+                                 padding='max_length',
+                                 truncation=True,
+                                 max_length=128,
+                                 return_tensors='pt')
+    if isinstance(input_batch, dict):
+        input_batch = {k: v.to('cuda') for k, v in input_batch.items()}
+
+    # Backprop through the DDP-wrapped model and check that gradients reach
+    # its parameters; the legacy gradient passthrough fails to do this, which
+    # is what the xfail above expects. The loss(outputs, batch) call assumes
+    # the ComposerModel interface.
+    outputs_ddp = model_ddp(input_batch)
+    loss_ddp = model_for_ddp.loss(outputs_ddp, input_batch)
+    loss_ddp.backward()
+    assert any(
+        p.grad is not None and p.grad.abs().sum() > 0
+        for p in model_for_ddp.parameters()
+    )
+
+    # The single-device reference model (same weights, no cross-rank gather)
+    # should produce a finite loss on the same batch.
+    outputs_single = model(input_batch)
+    loss_single = model.loss(outputs_single, input_batch)
+    assert torch.isfinite(loss_single)
+
+
+def test_contrastive_eval_loss_update_and_compute() -> None:
+    metric = ContrastiveEvalLoss()
+
+    # Mock outputs and labels
+    outputs1 = {'loss': torch.tensor(1.0)}
+    outputs2 = {'loss': torch.tensor(2.0)}
+    outputs3 = {'loss': torch.tensor(3.0)}
+
+    # Update metric
+    metric.update(outputs1, None)
+    metric.update(outputs2, None)
+    metric.update(outputs3, None)
+
+    # Compute average loss
+    average_loss = metric.compute()
+    assert average_loss == pytest.approx(2.0)
+
+
+def test_contrastive_eval_loss_device_handling() -> None:
+    metric = ContrastiveEvalLoss()
+
+    # Mock outputs on a different device
+    if torch.cuda.is_available():
+        device = 'cuda'
+    else:
+        device = 'cpu'
+    loss_tensor = torch.tensor(1.5, device=device)
+    outputs = {'loss': loss_tensor}
+
+    metric.update(outputs, None)
+
+    # Ensure loss is moved to metric's device
+    assert metric.loss.device.type == device
+    assert metric.loss == loss_tensor
+
+
+def test_eval_forward_without_outputs(model: ContrastiveModel) -> None:
+    # Create a mock batch
+    batch = {
+        'input_ids': torch.randint(0, 1000, (2, 128)),
+        'attention_mask': torch.ones(2, 128, dtype=torch.long),
+        'labels': torch.randint(0, 2, (2, 128)),
+    }
+
+    # Mock the forward method to return a mock output with 'loss'
+    with patch.object(model, 'forward') as mock_forward, \
+            patch.object(model, 'loss') as mock_loss:
+        mock_forward.return_value = {
+            'loss': torch.tensor(1.0),
+            'hidden_states': None,
+        }
+        mock_loss.return_value = torch.tensor(1.0)
+
+        result = model.eval_forward(batch)
+
+        assert isinstance(result, dict)
+        assert 'loss' in result
+        assert 'outputs' in result
+        assert result['loss'] == torch.tensor(1.0)
+        assert result['outputs'] == mock_forward.return_value
+
+
+def test_eval_forward_with_outputs(model: ContrastiveModel) -> None:
+    # Create a mock batch and outputs
+    batch = {
+        'input_ids': torch.randint(0, 1000, (2, 128)),
+        'attention_mask': torch.ones(2, 128, dtype=torch.long),
+        'labels': torch.randint(0, 2, (2, 128)),
+    }
+    mock_outputs = {'loss': torch.tensor(2.0), 'hidden_states': None}
+
+    # Mock the loss method
+    with patch.object(model, 'loss') as mock_loss:
+        mock_loss.return_value = torch.tensor(2.0)
+
+        result = model.eval_forward(batch, outputs=mock_outputs)
+
+        assert isinstance(result, dict)
+        assert 'loss' in result
+        assert 'outputs' in result
+        assert result['loss'] == torch.tensor(2.0)
+        assert result['outputs'] == mock_outputs
+
+
+def test_eval_forward_returns_correct_structure(
+    model: ContrastiveModel,
+) -> None:
+    # Create a mock batch
+    batch = {
+        'input_ids': torch.randint(0, 1000, (1, 50)),
+        'attention_mask': torch.ones(1, 50, dtype=torch.long),
+        'labels': torch.randint(0, 2, (1, 50)),
+    }
+
+    # Mock the forward and loss methods
+    with patch.object(model, 'forward') as mock_forward, \
+            patch.object(model, 'loss') as mock_loss:
+        mock_forward.return_value = {
+            'loss': torch.tensor(0.5),
+            'hidden_states': None,
+        }
+        mock_loss.return_value = torch.tensor(0.5)
+
+        result = model.eval_forward(batch)
+
+        assert isinstance(result, dict)
+        assert set(result.keys()) == {'loss', 'outputs'}
+        assert isinstance(result['loss'], torch.Tensor)
+        assert isinstance(result['outputs'], dict)
+        assert 'loss' in result['outputs']
+        assert result['outputs']['loss'] == torch.tensor(0.5)
+
+
+def test_eval_forward_handles_missing_outputs(model: ContrastiveModel) -> None:
+    # Create a mock batch without 'loss' in outputs
+    batch = {
+        'input_ids': torch.randint(0, 1000, (2, 128)),
+        'attention_mask': torch.ones(2, 128, dtype=torch.long),
+        'labels': torch.randint(0, 2, (2, 128)),
+    }
+
+    # Mock the forward method to return outputs without 'loss'
+    with patch.object(model, 'forward') as mock_forward, \
+            patch.object(model, 'loss') as mock_loss:
+        mock_forward.return_value = {'hidden_states': None}
+        mock_loss.return_value = torch.tensor(
+            1.0,
+        )  # Assume loss is computed elsewhere
+
+        result = model.eval_forward(batch)
+
+        assert isinstance(result, dict)
+        assert 'loss' in result
+        assert 'outputs' in result
+        assert result['loss'] == torch.tensor(1.0)
+        assert result['outputs'] == mock_forward.return_value
diff --git a/tests/test_utils.py b/tests/test_utils.py
index 05c0881b9f..30f4b56c58 100644
--- a/tests/test_utils.py
+++ b/tests/test_utils.py
@@ -2,11 +2,13 @@
 # SPDX-License-Identifier: Apache-2.0
 
 import copy
-from typing import Any
+from typing import Any, Union
 
 import catalogue
 import pytest
+import torch
 from omegaconf import DictConfig
+from transformers import PreTrainedTokenizerBase
 
 from llmfoundry.registry import config_transforms
 from llmfoundry.utils.config_utils import (
@@ -102,3 +104,49 @@ def test_logged_cfg():
         'device_eval_batch_size': 1,
     })
     assert expected_config == logged_config
+
+
+class MockTokenizer(PreTrainedTokenizerBase):
+
+    def __init__(self) -> None:
+        super().__init__()
+        self.pad_token: str = '<pad>'
+        self.eos_token: str = '</s>'
+        self.bos_token: str = '<s>'
+        self.unk_token: str = '<unk>'
+        self._vocab_size: int = 30000
+
+    def __len__(self) -> int:
+        return self._vocab_size
+
+    def convert_tokens_to_ids(
+        self,
+        tokens: Union[str, list[str]],
+    ) -> Union[int, list[int]]:
+        return 0
+
+    @property
+    def pad_token_id(self) -> int:
+        return 0
+
+    def _batch_encode_plus(self, *args: Any,
+                           **kwargs: Any) -> dict[str, torch.Tensor]:
+        batch_texts = args[0] if args else kwargs.get(
+            'batch_text_or_text_pairs',
+            [],
+        )
+        max_length = kwargs.get('max_length', 1024)
+
+        if isinstance(batch_texts[0], list):
+            texts = [t for pair in batch_texts for t in pair]
+        else:
+            texts = batch_texts
+
+        token_ids = torch.tensor([
+            [hash(text) % 1000 + j for j in range(max_length)] for text in texts
+        ])
+
+        return {
+            'input_ids': token_ids,
+            'attention_mask': torch.ones_like(token_ids),
+        }
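
Putting the pieces together, here is a minimal usage sketch of the workflow the tests above exercise: export a Delta table of contrastive pairs to MDS with the new converter, then stream it back through the new contrastive_pairs dataloader. The table name, cluster id, and paths are hypothetical placeholders; the converter needs a reachable Databricks workspace and auth, and writing output_path into a train/ subdirectory is an assumption made so that the dataset's split='train' resolves against the parent directory.

from transformers import AutoTokenizer

from llmfoundry.command_utils import convert_delta_to_contrastive_mds
from llmfoundry.utils.builders import build_dataloader

# Export a Delta table with required columns query_text and positive_passage
# (negative_passages, a JSON-encoded list, is optional) to MDS shards.
convert_delta_to_contrastive_mds(
    delta_table_name='main.my_schema.contrastive_pairs',  # hypothetical table
    http_path=None,
    cluster_id='my-cluster-id',  # or set use_serverless=True
    use_serverless=False,
    output_path='/tmp/mds/contrastive_pairs/train',
    batch_size=1000,
    processes=4,
)

# Stream the shards back through the registered 'contrastive_pairs' loader,
# mirroring the config used in tests/data/test_contrastive_dataloader.py.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
cfg = {
    'name': 'contrastive_pairs',
    'dataset': {
        'remote': '/tmp/mds/contrastive_pairs',
        'split': 'train',
        'max_seq_len': 1024,
    },
    'drop_last': False,
    'num_workers': 1,
    'max_hard_negatives': 2,
}
dl = build_dataloader(cfg, tokenizer, 2)  # (cfg, tokenizer, device_batch_size)
batch = next(iter(dl.dataloader))  # dict with 'input_ids' per query/passage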