From e33409b7e7d814fb47859259a21e3232bdd2f336 Mon Sep 17 00:00:00 2001
From: Vlad Shulman
Date: Thu, 2 May 2024 16:56:52 -0700
Subject: [PATCH] updating vllm

---
 mistral/mistral-7b-instruct-vllm/config.yaml    | 9 +++++----
 mistral/mistral-7b-instruct-vllm/model/model.py | 6 ++++++
 2 files changed, 11 insertions(+), 4 deletions(-)

diff --git a/mistral/mistral-7b-instruct-vllm/config.yaml b/mistral/mistral-7b-instruct-vllm/config.yaml
index 2795bc73..cc3e8e7d 100644
--- a/mistral/mistral-7b-instruct-vllm/config.yaml
+++ b/mistral/mistral-7b-instruct-vllm/config.yaml
@@ -5,7 +5,7 @@ model_metadata:
   avatar_url: https://cdn.baseten.co/production/static/explore/mistral_logo.png
   cover_image_url: https://cdn.baseten.co/production/static/explore/mistral.png
   engine_args:
-    model: mistralai/Mistral-7B-Instruct-v0.1
+    model: mistralai/Mistral-7B-Instruct-v0.2
   example_model_input:
     prompt: What is the Mistral wind?
   pretty_name: Mistral 7B Instruct
@@ -15,12 +15,13 @@ model_metadata:
 model_name: Mistral 7B Instruct vLLM
 python_version: py311
 requirements:
-- vllm==0.2.1.post1
+- vllm==0.4.1
 resources:
-  accelerator: A10G
+  accelerator: A100
   memory: 25Gi
   use_gpu: true
 runtime:
   predict_concurrency: 256
-secrets: {}
 system_packages: []
+secrets:
+  hf_access_token: null
diff --git a/mistral/mistral-7b-instruct-vllm/model/model.py b/mistral/mistral-7b-instruct-vllm/model/model.py
index 926140f4..11af58dc 100644
--- a/mistral/mistral-7b-instruct-vllm/model/model.py
+++ b/mistral/mistral-7b-instruct-vllm/model/model.py
@@ -1,3 +1,4 @@
+import os
 import uuid
 from typing import Any
 
@@ -10,6 +11,11 @@ class Model:
     def __init__(self, **kwargs) -> None:
         self.engine_args = kwargs["config"]["model_metadata"]["engine_args"]
         self.prompt_format = kwargs["config"]["model_metadata"]["prompt_format"]
+        self._secrets = kwargs["secrets"]
+
+        if "hf_access_token" in self._secrets._base_secrets.keys():
+            # Set the environment variable
+            os.environ["HF_TOKEN"] = self._secrets["hf_access_token"]
 
     def load(self) -> None:
         self.llm_engine = AsyncLLMEngine.from_engine_args(