About inference time #10

Open
Blank-z0 opened this issue Jan 9, 2025 · 0 comments
Blank-z0 commented Jan 9, 2025

Hi, great work! Zamba2 is a great hybrid SSM model for the whole community.
I'm trying to reproduce the time-to-first-token (TTFT) results presented in your tech report, but I couldn't match the numbers reported for Zamba2. For the 1.2B and 2.7B models, I set the input prompt length to 2048, the number of output tokens to 1, and the batch size to 1.
However, the TTFT of Zamba2 is higher than that of attention-based models such as Phi2-2.7B, Qwen2-1.5B, and Qwen2.5-3B.
For example, on a single A100 (40 GB): Zamba2-2.7B vs. Phi2-2.7B is 150 ms vs. 90 ms, and Zamba2-1.2B vs. Qwen2-1.5B is 94 ms vs. 81 ms.
For Zamba2, I used mamba-ssm and causal-conv1d to speed up inference, and for the attention-based LLMs I used FlashAttention-2.
The following is my machine environment:

PRETTY_NAME="Ubuntu 22.04.2 LTS"
NAME="Ubuntu"
VERSION_ID="22.04"
VERSION="22.04.2 LTS (Jammy Jellyfish)"
VERSION_CODENAME=jammy
ID=ubuntu
ID_LIKE=debian
HOME_URL="https://www.ubuntu.com/"
SUPPORT_URL="https://help.ubuntu.com/"
BUG_REPORT_URL="https://bugs.launchpad.net/ubuntu/"
PRIVACY_POLICY_URL="https://www.ubuntu.com/legal/terms-and-policies/privacy-policy"
UBUNTU_CODENAME=jammy

libnccl2 2.18.3-1+cuda12.2

Python packages:

Package                   Version              Editable project location
------------------------- -------------------- ----------------------------------------------------------
absl-py                   1.4.0
accelerate                0.33.0
accelerator               2024.3.8.dev1
aiohttp                   3.8.4
aiosignal                 1.3.1
annotated-types           0.7.0
apex                      0.1
argon2-cffi               21.3.0
argon2-cffi-bindings      21.2.0
asttokens                 2.2.1
astunparse                1.6.3
async-timeout             4.0.2
attrs                     23.1.0
audioread                 3.0.0
backcall                  0.2.0
beautifulsoup4            4.12.2
bitsandbytes              0.43.0
bleach                    6.0.0
blis                      0.7.10
bottle                    0.12.25
cachetools                5.3.1
catalogue                 2.0.9
causal-conv1d             1.3.0.post1
certifi                   2023.7.22
cffi                      1.15.1
charset-normalizer        3.2.0
click                     8.1.5
cloudpickle               2.2.1
cmake                     3.27.1
comm                      0.1.4
confection                0.1.1
contourpy                 1.1.0
cubinlinker               0.3.0+2.g7c3675e
cuda-python               12.1.0rc5+1.g994d8d0
cudf                      23.6.0
cugraph                   23.6.0
cugraph-dgl               23.6.0
cugraph-service-client    23.6.0
cugraph-service-server    23.6.0
cuml                      23.6.0
cupy-cuda12x              12.1.0
cycler                    0.11.0
cymem                     2.0.7
Cython                    3.0.0
dask                      2023.3.2
dask-cuda                 23.6.0
dask-cudf                 23.6.0
datasets                  2.18.0
debugpy                   1.8.7
decorator                 5.1.1
deepspeed                 0.14.4
defusedxml                0.7.1
dill                      0.3.8
distributed               2023.3.2.1
dm-tree                   0.1.8
docker-pycreds            0.4.0
einops                    0.6.1
exceptiongroup            1.1.2
execnet                   2.0.2
executing                 1.2.0
expecttest                0.1.3
fastjsonschema            2.18.0
fastrlock                 0.8.1
fbgemm-gpu                0.6.0
filelock                  3.12.2
flash-attn                2.6.3
fonttools                 4.42.0
frozenlist                1.4.0
fsspec                    2023.6.0
gast                      0.5.4
gin-config                0.5.0
gitdb                     4.0.11
GitPython                 3.1.43
google-auth               2.22.0
google-auth-oauthlib      0.4.6
graphsurgeon              0.4.6
grpcio                    1.56.2
hjson                     3.1.0
huggingface-hub           0.26.5
hypothesis                5.35.1
idna                      3.4
importlib-metadata        6.8.0
iniconfig                 2.0.0
intel-openmp              2021.4.0
iopath                    0.1.10
ipykernel                 6.25.0
ipython                   8.14.0
ipython-genutils          0.2.0
jedi                      0.19.0
Jinja2                    3.1.2
joblib                    1.3.1
json5                     0.9.14
jsonlines                 4.0.0
jsonschema                4.18.6
jsonschema-specifications 2023.7.1
jupyter_client            8.3.0
jupyter_core              5.3.1
jupyter-tensorboard       0.2.0
jupyterlab                2.3.2
jupyterlab-pygments       0.2.2
jupyterlab-server         1.2.0
jupytext                  1.15.0
kiwisolver                1.4.4
langcodes                 3.3.0
librosa                   0.9.2
llvmlite                  0.40.1
locket                    1.0.0
mamba-ssm                 2.1.0
Markdown                  3.4.4
markdown-it-py            3.0.0
MarkupSafe                2.1.3
matplotlib                3.7.2
matplotlib-inline         0.1.6
mdit-py-plugins           0.4.0
mdurl                     0.1.2
mistune                   3.0.1
mkl                       2021.1.1
mkl-devel                 2021.1.1
mkl-include               2021.1.1
mock                      5.1.0
modelscope                1.22.0
mpmath                    1.3.0
msgpack                   1.0.5
multidict                 6.0.4
multiprocess              0.70.16
murmurhash                1.0.9
nbclient                  0.8.0
nbconvert                 7.7.3
nbformat                  5.9.2
nest-asyncio              1.5.7
networkx                  2.6.3
ninja                     1.11.1
notebook                  6.4.10
numba                     0.57.1+1.gc785c8f1f
numpy                     1.22.2
nvidia-cublas-cu12        12.1.3.1
nvidia-cuda-cupti-cu12    12.1.105
nvidia-cuda-nvrtc-cu12    12.1.105
nvidia-cuda-runtime-cu12  12.1.105
nvidia-cudnn-cu12         9.1.0.70
nvidia-cufft-cu12         11.0.2.54
nvidia-curand-cu12        10.3.2.106
nvidia-cusolver-cu12      11.4.5.107
nvidia-cusparse-cu12      12.1.0.106
nvidia-dali-cuda120       1.28.0
nvidia-ml-py              12.560.30
nvidia-nccl-cu12          2.20.5
nvidia-nvjitlink-cu12     12.6.68
nvidia-nvtx-cu12          12.1.105
nvidia-pyindex            1.0.9
nvtx                      0.2.5
oauthlib                  3.2.2
onnx                      1.14.0
opencv                    4.7.0
packaging                 23.1
pandas                    1.5.2
pandocfilters             1.5.0
parso                     0.8.3
partd                     1.4.0
pathy                     0.10.2
peft                      0.11.0
pexpect                   4.8.0
pickleshare               0.7.5
Pillow                    9.2.0
pip                       23.2.1
platformdirs              3.10.0
pluggy                    1.2.0
ply                       3.11
polygraphy                0.47.1
pooch                     1.7.0
portalocker               2.10.1
preshed                   3.0.8
prettytable               3.8.0
prometheus-client         0.17.1
prompt-toolkit            3.0.39
protobuf                  4.21.12
psutil                    5.9.4
ptxcompiler               0.8.1+1.g4a94326
ptyprocess                0.7.0
pure-eval                 0.2.2
py-cpuinfo                9.0.0
pyarrow                   17.0.0
pyarrow-hotfix            0.6
pyasn1                    0.5.0
pyasn1-modules            0.3.0
pybind11                  2.11.1
pycocotools               2.0+nv0.7.3
pycparser                 2.21
pydantic                  2.9.2
pydantic_core             2.23.4
Pygments                  2.16.1
pylibcugraph              23.6.0
pylibcugraphops           23.6.0
pylibraft                 23.6.0
Pympler                   1.1
pynvml                    11.4.1
pyparsing                 3.0.9
pytest                    7.4.0
pytest-flakefinder        1.1.0
pytest-rerunfailures      12.0
pytest-shard              0.1.2
pytest-xdist              3.3.1
python-dateutil           2.8.2
python-hostlist           1.23.0
pytorch-quantization      2.1.2
pytz                      2023.3
PyYAML                    6.0.1
pyzmq                     25.1.0
raft-dask                 23.6.0
referencing               0.30.2
regex                     2023.6.3
requests                  2.31.0
requests-oauthlib         1.3.1
resampy                   0.4.2
rmm                       23.6.0
rpds-py                   0.9.2
rsa                       4.9
safetensors               0.4.5
scikit-learn              1.2.0
scipy                     1.11.1
Send2Trash                1.8.2
sentencepiece             0.1.99
sentry-sdk                2.17.0
setproctitle              1.3.3
setuptools                68.0.0
six                       1.16.0
smart-open                6.3.0
smmap                     5.0.1
sortedcontainers          2.4.0
soundfile                 0.12.1
soupsieve                 2.4.1
spacy                     3.6.0
spacy-legacy              3.0.12
spacy-loggers             1.0.4
sphinx-glpi-theme         0.3
srsly                     2.4.7
stack-data                0.6.2
sympy                     1.12
tabulate                  0.9.0
tbb                       2021.10.0
tblib                     2.0.0
tensorboard               2.9.0
tensorboard-data-server   0.6.1
tensorboard-plugin-wit    1.8.1
tensorrt                  8.6.1
terminado                 0.17.1
thinc                     8.1.10
threadpoolctl             3.2.0
thriftpy2                 0.4.16
tinycss2                  1.2.1
tokenizers                0.19.1
toml                      0.10.2
tomli                     2.0.1
toolz                     0.12.0
torch                     2.4.0
torchdata                 0.7.1
torchvision               0.19.0
tornado                   6.3.2
tqdm                      4.65.0
traitlets                 5.9.0
transformer_engine_cu12   1.12.0
transformers              4.43.0.dev0          /codes/transformers_zamba2
treelite                  3.2.0
treelite-runtime          3.2.0
triton                    3.0.0
typer                     0.9.0
types-dataclasses         0.6.6
typing_extensions         4.12.2
ucx-py                    0.32.0
uff                       0.6.9
urllib3                   1.26.16
waitress                  3.0.2
wandb                     0.18.5
wasabi                    1.1.2
wcwidth                   0.2.6
webencodings              0.5.1
Werkzeug                  2.3.6
wheel                     0.41.1
xdoctest                  1.0.2
xgboost                   1.7.5
xxhash                    3.5.0
yarl                      1.9.2
zict                      3.0.0
zipp                      3.16.2
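
For reference, this is how I check that the fused-kernel packages are actually importable in the benchmark environment. I'm assuming that if either package were missing, the model would silently fall back to a slower pure-PyTorch path, so I wanted to rule that out (`kernels_available` is just a name I made up for this check):

```python
import importlib.util

def kernels_available(names=("mamba_ssm", "causal_conv1d")):
    """Return {package: bool}: whether each fused-kernel package is importable.

    find_spec() probes the import machinery without actually importing,
    so this is cheap and has no side effects.
    """
    return {name: importlib.util.find_spec(name) is not None for name in names}

print(kernels_available())  # both entries are True in my environment
```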

My mamba-ssm and causal-conv1d versions match the ones pinned in setup.py. Below is my code to reproduce TTFT:

# Modified from https://github.com/state-spaces/mamba/blob/main/benchmarks/benchmark_generation_mamba_simple.py
import argparse
import sys
import time

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

parser = argparse.ArgumentParser(description="Generation benchmarking")
parser.add_argument("--model-name", type=str, default="Qwen2-1.5B")
parser.add_argument("--prompt", type=str, default=None)
parser.add_argument("--promptlen", type=int, default=2048)
parser.add_argument("--genlen", type=int, default=1)
parser.add_argument("--temperature", type=float, default=1.0)
parser.add_argument("--topk", type=int, default=0)
parser.add_argument("--topp", type=float, default=1.0)
parser.add_argument("--minp", type=float, default=0.0)
parser.add_argument("--repetition-penalty", type=float, default=1.0)
parser.add_argument("--batch", type=int, default=1)
# parser.add_argument("--output_to_files", type=bool, default=True)
args = parser.parse_args()

repeats = 3
device = "cuda"
dtype = torch.bfloat16
root = '../../model/'
cache_dir = './'

print(f"Loading model {args.model_name}")

tokenizer = AutoTokenizer.from_pretrained(root+args.model_name)
model = AutoModelForCausalLM.from_pretrained(root+args.model_name, device_map=device, torch_dtype=dtype, attn_implementation="flash_attention_2")
model.eval()
print(f"Number of parameters: {sum(p.numel() for p in model.parameters() if p.requires_grad)}")

torch.random.manual_seed(0)
if args.prompt is None:
    input_ids = torch.randint(1, 1000, (args.batch, args.promptlen), dtype=torch.long, device="cuda")
    attn_mask = torch.ones_like(input_ids, dtype=torch.long, device="cuda")
else:
    tokens = tokenizer(args.prompt, return_tensors="pt")
    input_ids = tokens.input_ids.to(device=device)
    attn_mask = tokens.attention_mask.to(device=device)
max_length = input_ids.shape[1] + args.genlen
print(input_ids.shape)
fn = lambda: model.generate(
    input_ids=input_ids,
    attention_mask=attn_mask,
    max_new_tokens=args.genlen,
    return_dict_in_generate=True,
    pad_token_id=tokenizer.eos_token_id,
    do_sample=False,
    temperature=args.temperature,
    top_k=args.topk,
    top_p=args.topp,
    repetition_penalty=args.repetition_penalty,
)
try:
    out = fn()
except RuntimeError as e:
    if "out of memory" in str(e).lower():
        # `out` was never assigned here, so log the requested sizes instead.
        with open('output.txt', mode='a', encoding='utf-8') as f:
            print(f"Prompt length: {input_ids.shape[1]}, generation length: {args.genlen}\n", file=f)
            print(f"{args.model_name} OOM!\n\n", file=f)
        print(f"{args.model_name} OOM!")
        sys.exit(1)
    raise  # re-raise anything that is not an OOM

if args.prompt is not None:
    print(tokenizer.batch_decode(out.sequences.tolist()))

torch.cuda.synchronize()
start = time.time()
for _ in range(repeats):
    try:
        fn()
    except RuntimeError as e:
        if "out of memory" in str(e).lower():
            with open(f'{args.model_name}-speed.txt', mode='a', encoding='utf-8') as f:
                print(f"Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}", file=f)
                print(f"{args.model_name} OOM!\n\n", file=f)
            print(f"{args.model_name} OOM!")
            sys.exit(1)
        raise  # re-raise anything that is not an OOM
torch.cuda.synchronize()
end = time.time()
print(f"Batch size: {args.batch}, Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}")
print(f"{args.model_name} prompt processing + decoding time: {(end - start) / repeats * 1000:.0f}ms\n\n")
with open(f'{args.model_name}-speed.txt', mode='a', encoding='utf-8') as f:
    print(f"Batch size: {args.batch}, Prompt length: {len(input_ids[0])}, generation length: {len(out.sequences[0]) - len(input_ids[0])}", file=f)
    print(f"{args.model_name} prompt processing + decoding time: {(end - start) / repeats * 1000:.0f}ms\n\n", file=f)
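
For what it's worth, here is the small timing helper I considered using instead of raw `time.time()`, since `time.perf_counter()` is the clock recommended for measuring intervals. The GPU sync still has to happen inside the timed callable (via `torch.cuda.synchronize()`), and `time_ms` is just my own name for the helper:

```python
import time

def time_ms(fn, repeats=3, warmup=1):
    """Return the mean wall-clock time of fn() in milliseconds.

    Warmup runs are excluded so one-time costs (kernel autotuning,
    allocator growth, compilation) do not skew the mean. For GPU work,
    fn itself must end with torch.cuda.synchronize(), otherwise only
    kernel-launch time is measured.
    """
    for _ in range(warmup):
        fn()
    start = time.perf_counter()
    for _ in range(repeats):
        fn()
    return (time.perf_counter() - start) / repeats * 1000.0
```

A higher-fidelity alternative on GPU would be pairs of `torch.cuda.Event(enable_timing=True)` with `elapsed_time()`, which measure on-device and avoid host-side jitter.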

I would like to know if my reproduction process is correct and how you calculated the TTFT results shown in the tech report.
Looking forward to your reply, thanks!
