Commit 9a8b042

zamba2 base

yury-tokpanov committed Nov 28, 2024
1 parent 4b8844d commit 9a8b042

Showing 18 changed files with 5,060 additions and 38 deletions.
1 change: 1 addition & 0 deletions docs/source/en/index.md
@@ -364,6 +364,7 @@ Flax), PyTorch, and/or TensorFlow.
| [YOLOS](model_doc/yolos) | ✅ | ❌ | ❌ |
| [YOSO](model_doc/yoso) | ✅ | ❌ | ❌ |
| [Zamba](model_doc/zamba) | ✅ | ❌ | ❌ |
| [Zamba2](model_doc/zamba2) | ✅ | ❌ | ❌ |
| [ZoeDepth](model_doc/zoedepth) | ✅ | ❌ | ❌ |

<!-- End table-->
93 changes: 93 additions & 0 deletions docs/source/en/model_doc/zamba2.md
@@ -0,0 +1,93 @@
<!--Copyright 2024 The HuggingFace Team. All rights reserved.
Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
specific language governing permissions and limitations under the License.
⚠️ Note that this file is in Markdown but contains specific syntax for our doc-builder (similar to MDX) that may not be
rendered properly in your Markdown viewer.
-->
# Zamba2

Zamba2 is a large language model (LLM) trained by Zyphra and released under the Apache 2.0 license. See the [Zyphra Hugging Face](https://huggingface.co/collections/zyphra/) collection for the model weights.

This model was contributed by [pglo](https://huggingface.co/pglo).


## Model details

Zamba2-1.2B, Zamba2-2.7B and Zamba2-7B are hybrid models that combine state-space model blocks (specifically [Mamba](https://github.com/state-spaces/mamba)) with transformer blocks, and were trained with next-token prediction. Zamba2 inserts a shared transformer layer after every six Mamba blocks and uses the [Mistral v0.1 tokenizer](https://huggingface.co/mistralai/Mistral-7B-v0.1). We arrived at this architecture after a series of small-scale ablations. Zamba2-1.2B, Zamba2-2.7B and Zamba2-7B were pre-trained on 3T, 3T, and 2T tokens, respectively.

<img src="https://github.com/user-attachments/assets/c2cff209-b901-483c-87aa-774b82a0769f" width="30%" height="40%" />
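
The interleaving can be inspected from the configuration. A minimal sketch, assuming `Zamba2Config` exposes a `layers_block_type`-style field as `ZambaConfig` does (the `getattr` guard keeps it safe if it does not):

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained("Zyphra/Zamba2-7B")
print(config.model_type)         # "zamba2"
print(config.num_hidden_layers)  # total Mamba + shared-attention blocks
# Assumed to mirror ZambaConfig's layer-type listing; guarded in case the field differs.
print(getattr(config, "layers_block_type", None))
```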

## Quick start


### Prerequisites

Zamba2 requires `transformers` version 4.46.0 or higher:
```bash
pip install "transformers>=4.46.0"
```
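
If you want to fail fast on an older installation, a small runtime guard (a sketch using `packaging`, which ships as a `transformers` dependency):

```python
import transformers
from packaging import version

# Zamba2 classes only exist in transformers >= 4.46.0.
assert version.parse(transformers.__version__) >= version.parse("4.46.0"), (
    f"transformers {transformers.__version__} is too old for Zamba2; install >= 4.46.0"
)
```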

### Inference

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

# Load the tokenizer and model; bfloat16 halves memory use relative to float32.
tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B")
model = AutoModelForCausalLM.from_pretrained("Zyphra/Zamba2-7B", device_map="cuda", torch_dtype=torch.bfloat16)

# Tokenize the prompt and move the tensors to the same device as the model.
input_text = "What factors contributed to the fall of the Roman Empire?"
input_ids = tokenizer(input_text, return_tensors="pt").to("cuda")

# Greedily decode up to 100 new tokens.
outputs = model.generate(**input_ids, max_new_tokens=100)
print(tokenizer.decode(outputs[0]))
```
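
`generate` decodes greedily by default; for more varied completions, standard sampling arguments can be passed:

```python
# Sampling instead of greedy decoding.
outputs = model.generate(**input_ids, max_new_tokens=100, do_sample=True, temperature=0.7, top_p=0.9)
print(tokenizer.decode(outputs[0], skip_special_tokens=True))
```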


## Model card

The model cards can be found at:
* [Zamba2-1.2B](https://huggingface.co/Zyphra/Zamba2-1.2B)
* [Zamba2-2.7B](https://huggingface.co/Zyphra/Zamba2-2.7B)
* [Zamba2-7B](https://huggingface.co/Zyphra/Zamba2-7B)


## Issues
For issues with model output, or for community discussion, please use the Hugging Face community [forum](https://huggingface.co/Zyphra/Zamba2-7B/discussions).


## License

The model weights are released under the Apache 2.0 license.


## Zamba2Config

[[autodoc]] Zamba2Config


## Zamba2Model

[[autodoc]] Zamba2Model
- forward


## Zamba2ForCausalLM

[[autodoc]] Zamba2ForCausalLM
- forward


## Zamba2ForSequenceClassification

[[autodoc]] Zamba2ForSequenceClassification
- forward
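
A rough usage sketch (the `num_labels=2` head is hypothetical and randomly initialized, so it needs fine-tuning before the logits are meaningful):

```python
import torch
from transformers import AutoTokenizer, Zamba2ForSequenceClassification

tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-1.2B")
model = Zamba2ForSequenceClassification.from_pretrained("Zyphra/Zamba2-1.2B", num_labels=2)

inputs = tokenizer("Zamba2 is a hybrid Mamba/transformer model.", return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
print(logits.shape)  # torch.Size([1, 2])
```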
1 change: 1 addition & 0 deletions docs/source/en/perf_infer_gpu_one.md
@@ -304,6 +304,7 @@ For now, Transformers supports SDPA inference and training for the following architectures:
* [XLM-RoBERTa](https://huggingface.co/docs/transformers/model_doc/xlm-roberta#transformers.XLMRobertaModel)
* [XLM-RoBERTa-XL](https://huggingface.co/docs/transformers/model_doc/xlm-roberta-xl#transformers.XLMRobertaXLModel)
* [YOLOS](https://huggingface.co/docs/transformers/model_doc/yolos#transformers.YolosModel)
* [Zamba2](https://huggingface.co/docs/transformers/model_doc/zamba2)
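
With Zamba2 on this list, SDPA can be requested explicitly at load time (SDPA is also the default when available):

```python
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "Zyphra/Zamba2-7B",
    torch_dtype=torch.bfloat16,
    attn_implementation="sdpa",  # make the attention backend explicit
)
```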

<Tip>

16 changes: 16 additions & 0 deletions src/transformers/__init__.py
@@ -855,6 +855,7 @@
"models.yolos": ["YolosConfig"],
"models.yoso": ["YosoConfig"],
"models.zamba": ["ZambaConfig"],
"models.zamba2": ["Zamba2Config"],
"models.zoedepth": ["ZoeDepthConfig"],
"onnx": [],
"pipelines": [
@@ -3803,6 +3804,14 @@
"ZambaPreTrainedModel",
]
)
_import_structure["models.zamba2"].extend(
[
"Zamba2ForCausalLM",
"Zamba2ForSequenceClassification",
"Zamba2Model",
"Zamba2PreTrainedModel",
]
)
_import_structure["models.zoedepth"].extend(
[
"ZoeDepthForDepthEstimation",
@@ -5780,6 +5789,7 @@
from .models.yolos import YolosConfig
from .models.yoso import YosoConfig
from .models.zamba import ZambaConfig
from .models.zamba2 import Zamba2Config
from .models.zoedepth import ZoeDepthConfig

# Pipelines
@@ -8207,6 +8217,12 @@
    ZambaModel,
    ZambaPreTrainedModel,
)
from .models.zamba2 import (
    Zamba2ForCausalLM,
    Zamba2ForSequenceClassification,
    Zamba2Model,
    Zamba2PreTrainedModel,
)
from .models.zoedepth import (
    ZoeDepthForDepthEstimation,
    ZoeDepthPreTrainedModel,
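With the `_import_structure` entries and the mirrored concrete imports above, the new classes resolve lazily from the top-level package:

```python
# Resolved through transformers' lazy import machinery at first access.
from transformers import Zamba2Config, Zamba2ForCausalLM, Zamba2Model
```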
1 change: 1 addition & 0 deletions src/transformers/models/__init__.py
@@ -285,5 +285,6 @@
    yolos,
    yoso,
    zamba,
    zamba2,
    zoedepth,
)
2 changes: 2 additions & 0 deletions src/transformers/models/auto/configuration_auto.py
@@ -314,6 +314,7 @@
("yolos", "YolosConfig"),
("yoso", "YosoConfig"),
("zamba", "ZambaConfig"),
("zamba2", "Zamba2Config"),
("zoedepth", "ZoeDepthConfig"),
]
)
@@ -637,6 +638,7 @@
("yolos", "YOLOS"),
("yoso", "YOSO"),
("zamba", "Zamba"),
("zamba2", "Zamba2"),
("zoedepth", "ZoeDepth"),
]
)
3 changes: 3 additions & 0 deletions src/transformers/models/auto/modeling_auto.py
@@ -287,6 +287,7 @@
("yolos", "YolosModel"),
("yoso", "YosoModel"),
("zamba", "ZambaModel"),
("zamba2", "Zamba2Model"),
]
)

@@ -552,6 +553,7 @@
("xlnet", "XLNetLMHeadModel"),
("xmod", "XmodForCausalLM"),
("zamba", "ZambaForCausalLM"),
("zamba2", "Zamba2ForCausalLM"),
]
)

@@ -1008,6 +1010,7 @@
("xmod", "XmodForSequenceClassification"),
("yoso", "YosoForSequenceClassification"),
("zamba", "ZambaForSequenceClassification"),
("zamba2", "Zamba2ForSequenceClassification"),
]
)
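
With the "zamba2" entries registered in these mappings, the `Auto` classes dispatch without importing the Zamba2 classes directly (a small sketch; `from_config` builds a randomly initialized model):

```python
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("Zyphra/Zamba2-7B")  # resolves to Zamba2Config
model = AutoModelForCausalLM.from_config(config)         # dispatches to Zamba2ForCausalLM
print(type(model).__name__)
```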

7 changes: 7 additions & 0 deletions src/transformers/models/auto/tokenization_auto.py
@@ -566,6 +566,13 @@
"LlamaTokenizerFast" if is_tokenizers_available() else None,
),
),
(
"zamba2",
(
"LlamaTokenizer" if is_sentencepiece_available() else None,
"LlamaTokenizerFast" if is_tokenizers_available() else None,
),
),
]
)
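
This maps "zamba2" checkpoints to the Llama tokenizer classes, consistent with the Mistral v0.1 tokenizer noted in the model docs; a quick check:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Zyphra/Zamba2-7B")
print(type(tokenizer).__name__)  # LlamaTokenizerFast when `tokenizers` is installed
```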
