diff --git a/README.md b/README.md
index 8ff1751..1acc252 100644
--- a/README.md
+++ b/README.md
@@ -8,6 +8,12 @@ Machine Learning Serving focused on GenAI & LLMs with simplicity as the top prio
## Installation
+**Stable:**
+```shell
+pip install FastServeAI
+```
+
+**Latest:**
```shell
pip install git+https://github.com/aniketmaurya/fastserve.git@main
```
@@ -20,7 +26,8 @@ python -m fastserve
## Usage/Examples
-### Serve Mistral-7B with Llama-cpp
+
+### Serve LLMs with Llama-cpp
```python
from fastserve.models import ServeLlamaCpp
@@ -32,6 +39,29 @@ serve.run_server()
or, run `python -m fastserve.models --model llama-cpp --model_path openhermes-2-mistral-7b.Q5_K_M.gguf` from terminal.
+
+### Serve LLMs with vLLM
+
+```python
+from fastserve.models import ServeVLLM
+
+app = ServeVLLM("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+app.run_server()
+```
+
+You can use the FastServe client, which automatically applies the chat template for you:
+
+```python
+from fastserve.client import vLLMClient
+from rich import print
+
+client = vLLMClient("TinyLlama/TinyLlama-1.1B-Chat-v1.0")
+response = client.chat("Write a python function to resize image to 224x224", keep_context=True)
+# print(client.context)
+print(response["outputs"][0]["text"])
+```
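+
+If you don't want to use the client, you can also POST to the endpoint directly. The snippet
+below is a minimal sketch that mirrors the request `vLLMClient` builds under the hood (the URL,
+payload fields, and response shape are taken from the client); note that the prompt is sent
+as-is, so for chat models you may want to apply the chat template yourself.
+
+```python
+import requests
+
+payload = {
+    "prompt": "Write a python function to resize image to 224x224",
+    "temperature": 0.8,
+    "top_p": 1,
+    "max_tokens": 500,
+    "stop": [],
+}
+response = requests.post("http://localhost:8000/endpoint", json=payload).json()
+print(response["outputs"][0]["text"])
+```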
+
+
### Serve SDXL Turbo
```python
@@ -46,7 +76,7 @@ or, run `python -m fastserve.models --model sdxl-turbo --batch_size 2 --timeout
This application comes with an UI. You can access it at [http://localhost:8000/ui](http://localhost:8000/ui) .
-
+
### Face Detection
diff --git a/src/fastserve/client/__init__.py b/src/fastserve/client/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/src/fastserve/client/vllm.py b/src/fastserve/client/vllm.py
new file mode 100644
index 0000000..73dcc18
--- /dev/null
+++ b/src/fastserve/client/vllm.py
@@ -0,0 +1,47 @@
+import logging
+
+import requests
+
+
+class Client:
+    def __init__(self):
+        pass
+
+
+class vLLMClient(Client):
+    def __init__(self, model: str, base_url="http://localhost:8000/endpoint"):
+        from transformers import AutoTokenizer
+
+        super().__init__()
+        self.tokenizer = AutoTokenizer.from_pretrained(model)
+        self.context = []
+        self.base_url = base_url
+
+    def chat(self, prompt: str, keep_context=False):
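+        # keep_context=True stores the running conversation in self.context and
+        # sends the full history with every request; otherwise only the new
+        # user message is sent.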
+        new_msg = {"role": "user", "content": prompt}
+        if keep_context:
+            self.context.append(new_msg)
+            messages = self.context
+        else:
+            messages = [new_msg]
+
+        logging.info(messages)
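+        # Render the messages into a single prompt string using the model's
+        # chat template; the server receives plain text, not a message list.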
+        chat = self.tokenizer.apply_chat_template(
+            messages, tokenize=False, add_generation_prompt=True
+        )
+        headers = {
+            "accept": "application/json",
+            "Content-Type": "application/json",
+        }
+        data = {
+            "prompt": chat,
+            "temperature": 0.8,
+            "top_p": 1,
+            "max_tokens": 500,
+            "stop": [],
+        }
+
+        response = requests.post(self.base_url, headers=headers, json=data).json()
+        if keep_context:
+            self.context.append(
+                {"role": "assistant", "content": response["outputs"][0]["text"]}
+            )
+        return response
diff --git a/src/fastserve/models/__init__.py b/src/fastserve/models/__init__.py
index 3c7d565..8c495da 100644
--- a/src/fastserve/models/__init__.py
+++ b/src/fastserve/models/__init__.py
@@ -5,3 +5,4 @@
from fastserve.models.llama_cpp import ServeLlamaCpp as ServeLlamaCpp
from fastserve.models.sdxl_turbo import ServeSDXLTurbo as ServeSDXLTurbo
from fastserve.models.ssd import ServeSSD1B as ServeSSD1B
+from fastserve.models.vllm import ServeVLLM as ServeVLLM
diff --git a/src/fastserve/models/vllm.py b/src/fastserve/models/vllm.py
index 3f57fe2..db89fe0 100644
--- a/src/fastserve/models/vllm.py
+++ b/src/fastserve/models/vllm.py
@@ -1,46 +1,65 @@
-import os
-from typing import List
+import logging
+from typing import Any, List, Optional
-from fastapi import FastAPI
from pydantic import BaseModel
-from vllm import LLM, SamplingParams
-tensor_parallel_size = int(os.environ.get("DEVICES", "1"))
-print("tensor_parallel_size: ", tensor_parallel_size)
+from fastserve.core import FastServe
-llm = LLM("meta-llama/Llama-2-7b-hf", tensor_parallel_size=tensor_parallel_size)
+logger = logging.getLogger(__name__)
class PromptRequest(BaseModel):
-    prompt: str
-    temperature: float = 1
+    prompt: str = "Write a python function to resize image to 224x224"
+    temperature: float = 0.8
+    top_p: float = 1.0
    max_tokens: int = 200
    stop: List[str] = []
class ResponseModel(BaseModel):
    prompt: str
-    prompt_token_ids: List  # The token IDs of the prompt.
-    outputs: List[str]  # The output sequences of the request.
+    prompt_token_ids: Optional[List] = None  # The token IDs of the prompt.
+    text: str  # The generated text of the request.
    finished: bool  # Whether the whole request is finished.
-app = FastAPI()
+class ServeVLLM(FastServe):
+    def __init__(
+        self,
+        model,
+        batch_size=1,
+        timeout=0.0,
+        *args,
+        **kwargs,
+    ):
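+        # Imported inside __init__ (rather than at module import time) so that
+        # vllm is only required when ServeVLLM is actually used.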
+        from vllm import LLM
+
+        self.llm = LLM(model)
+        self.args = args
+        self.kwargs = kwargs
+        super().__init__(
+            batch_size,
+            timeout,
+            input_schema=PromptRequest,
+            # response_schema=ResponseModel,
+        )
+
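+    # Generate a completion for a single PromptRequest; batching is handled
+    # in handle() below.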
+    def __call__(self, request: PromptRequest) -> Any:
+        from vllm import SamplingParams
+
+        sampling_params = SamplingParams(
+            temperature=request.temperature,
+            top_p=request.top_p,
+            max_tokens=request.max_tokens,
+            stop=request.stop,
+        )
+        result = self.llm.generate(request.prompt, sampling_params=sampling_params)
+        logger.info(result)
+        return result
-@app.post("/serve", response_model=ResponseModel)
-def serve(request: PromptRequest):
-    sampling_params = SamplingParams(
-        max_tokens=request.max_tokens,
-        temperature=request.temperature,
-        stop=request.stop,
-    )
+    def handle(self, batch: List[PromptRequest]) -> List:
+        responses = []
+        for request in batch:
+            output = self(request)
+            responses.extend(output)
-    result = llm.generate(request.prompt, sampling_params=sampling_params)[0]
-    response = ResponseModel(
-        prompt=request.prompt,
-        prompt_token_ids=result.prompt_token_ids,
-        outputs=result.outputs,
-        finished=result.finished,
-    )
-    return response
+        return responses
diff --git a/src/fastserve/utils.py b/src/fastserve/utils.py
index 55deeb6..0dc2c59 100644
--- a/src/fastserve/utils.py
+++ b/src/fastserve/utils.py
@@ -23,3 +23,27 @@ def get_ui_folder():
    path = os.path.join(os.path.dirname(__file__), "../ui")
    path = os.path.abspath(path)
    return path
+
+
+def download_file(url: str, dest: str = None):
+    import requests
+    from tqdm import tqdm
+
+    # Default the destination to the file name taken from the URL.
+    if dest is None:
+        dest = os.path.abspath(os.path.basename(url))
+
+    response = requests.get(url, stream=True)
+    response.raise_for_status()
+    total_size = int(response.headers.get("content-length", 0))
+    block_size = 1024
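+    # Stream the download in block_size chunks and update a tqdm progress bar.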
+    with open(dest, "wb") as file, tqdm(
+        desc=dest,
+        total=total_size,
+        unit="iB",
+        unit_scale=True,
+        unit_divisor=1024,
+    ) as bar:
+        for data in response.iter_content(block_size):
+            file.write(data)
+            bar.update(len(data))
+    return dest