Adding chat capability for local transformers models #67

Merged
merged 1 commit on Apr 29, 2024
2 changes: 1 addition & 1 deletion alfred/client/client.py
@@ -427,7 +427,7 @@ def chat(self, log_save_path: Optional[str] = None, **kwargs: Any):
        :param log_save_path: The file to save the chat logs.
        :type log_save_path: Optional[str]
        """
-        if self.model_type in ["openai", "anthropic", "google"]:
+        if self.model_type in ["openai", "anthropic", "google", "huggingface"]:
            self.model.chat(log_save_path=log_save_path, **kwargs)
        else:
            logger.error(
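This one-line change routes `Client.chat` to the local HuggingFace model's new chat loop. A minimal usage sketch, assuming a typical alfred `Client` constructor (the model name and keyword arguments here are illustrative assumptions, not taken from this diff):

```python
# Hypothetical usage sketch; the Client constructor arguments are assumptions.
from alfred.client import Client

# A local transformers model; model_type "huggingface" now supports .chat().
client = Client(model_type="huggingface", model="meta-llama/Llama-2-7b-chat-hf")

# Dispatches to the HuggingFace model's chat() and saves the transcript as JSON.
client.chat(log_save_path="chat_log.json")
```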
1 change: 1 addition & 0 deletions alfred/client/ssh/utils.py
@@ -1,6 +1,7 @@
"""
Modified with ideas originated from https://github.com/paramiko/paramiko/blob/main/demos/forward.py
"""

import select
import socket
import threading
3 changes: 2 additions & 1 deletion alfred/data/wrench.py
@@ -1,8 +1,9 @@
"""

-Wrench Dataset Class is a dataset wrapper for Wrench, a weak supervision benchmark testbed. 
+Wrench Dataset Class is a dataset wrapper for Wrench, a weak supervision benchmark testbed.

"""

import json
import logging
import os
15 changes: 10 additions & 5 deletions alfred/fm/anthropic.py
@@ -15,7 +15,7 @@
"claude-2",
"claude-2.0",
"claude-3-opus-20240229",
"claude-3-sonnet-20240229"
"claude-3-sonnet-20240229",
)

try:
@@ -103,7 +103,9 @@ def _anthropic_query(
        return response.content[0].text

    def __init__(
-        self, model_string: str = "claude-3-opus-20240229", api_key: Optional[str] = None
+        self,
+        model_string: str = "claude-3-opus-20240229",
+        api_key: Optional[str] = None,
    ):
        """
        Initialize the Anthropic API wrapper.
@@ -272,10 +274,13 @@ def _feedback(feedback: str, no_newline=False, override=False):
                if isinstance(resp, MessageStopEvent):
                    break
                if isinstance(resp, ContentBlockStartEvent):
-                    resp=resp.content_block
+                    resp = resp.content_block
                if isinstance(resp, ContentBlockDeltaEvent):
-                    resp=resp.delta
-                if resp.type == "content_block_stop" or resp.type == "message_delta":
+                    resp = resp.delta
+                if (
+                    resp.type == "content_block_stop"
+                    or resp.type == "message_delta"
+                ):
                    break
                if resp.type != "text" and resp.type != "text_delta":
                    logger.warning(f"Unsupported response type {resp.type}")
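For context, the event classes being unwrapped here come from Anthropic's streaming Messages API: `content_block_start` events wrap a content block, while `content_block_delta` events wrap a text delta. A minimal sketch of the same unwrapping pattern, assuming the `anthropic` package and an `ANTHROPIC_API_KEY` in the environment:

```python
# Minimal streaming sketch with the Anthropic Messages API; assumes the
# `anthropic` package is installed and ANTHROPIC_API_KEY is set.
import anthropic

client = anthropic.Anthropic()
stream = client.messages.create(
    model="claude-3-opus-20240229",
    max_tokens=256,
    messages=[{"role": "user", "content": "Hello!"}],
    stream=True,
)

for event in stream:
    # Text lives one level down: on the content block or on the delta.
    if event.type == "content_block_start":
        print(event.content_block.text, end="", flush=True)
    elif event.type == "content_block_delta":
        print(event.delta.text, end="", flush=True)
    elif event.type == "message_stop":
        break
```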
101 changes: 101 additions & 0 deletions alfred/fm/huggingface.py
@@ -15,6 +15,9 @@

from .model import LocalAccessFoundationModel
from .response import CompletionResponse
from .utils import colorize_str, type_print

import json

logger = logging.getLogger(__name__)

@@ -457,3 +460,101 @@ def _encode_batch(self, batch_instance, **kwargs) -> List[torch.Tensor]:
        _hidden_state = self._get_hidden_states(inputs, reduction=reduction)

        return list(_hidden_state)

    def chat(self, **kwargs: Any):
        """
        Launch an interactive chat session
        """

        def _feedback(feedback: str, no_newline=False, override=False):
            if override:
                print("\r", end="")
            print(
                colorize_str("Chat AI: ", "GREEN"),
                end="",
            )
            type_print(feedback)
            print(
                "",
                end="\n" if not no_newline else "",
            )

        model = kwargs.get("model", self.model_string)
        c_title = colorize_str("Alfred's HuggingFace Chat", "BLUE")
        c_model = colorize_str(model, "WARNING")
        c_exit = colorize_str("exit", "FAIL")
        c_ctrlc = colorize_str("Ctrl+C", "FAIL")

        temperature = kwargs.get("temperature", 0.7)
        max_tokens = kwargs.get("max_tokens", 1024)
        log_save_path = kwargs.get("log_save_path", None)
        manual_chat_sequence = kwargs.get("manual_chat_sequence", None)

        print(f"Welcome to the {c_title} session!\nYou are using the {c_model} model.")
        print(f"Type '{c_exit}' or hit {c_ctrlc} to exit the chat session.")

        message_log = [
            # {
            #     "role": "system",
            #     "content": "You are a friendly chatbot.",
            # },
        ]

        print()
        print("======== Chat Begin ========")
        print()

        try:
            while True:
                if manual_chat_sequence is not None:
                    query = manual_chat_sequence.pop(0)
                    _feedback(query, no_newline=True)
                    print()
                    if len(manual_chat_sequence) == 0:
                        break
                else:
                    query = input(colorize_str("You: "))
                    if query == "exit":
                        _feedback("Goodbye!")
                        break

                message_log.append({"role": "user", "content": query})
                print(
                    colorize_str("Chat AI: ", "GREEN"),
                    end="",
                )
                try:
                    tokenized_chat = self.tokenizer.apply_chat_template(
                        message_log,
                        tokenize=True,
                        add_generation_prompt=True,
                        return_tensors="pt",
                    )
                    tokenized_chat = tokenized_chat.to(self.model.device)
                except Exception as e:
                    _feedback(f"Error: {e}")
                    break
                outputs = self.model.generate(
                    tokenized_chat,
                    max_new_tokens=max_tokens,
                    temperature=temperature,
                    do_sample=False if temperature == 0 else True,
                )
                outputs = outputs[0][len(tokenized_chat[0]) :]
                txt = self.tokenizer.decode(outputs, skip_special_tokens=True)
                type_print(txt)
                print()
                response = txt.strip().replace("\n", "")
                message_log.append({"role": "assistant", "content": response})
        except KeyboardInterrupt:
            _feedback("Goodbye!")

        print()
        print("======== Chat End ========")
        print()
        print(colorize_str("Thank you for using Alfred!"))

        if log_save_path:
            with open(log_save_path, "w") as f:
                json.dump(message_log, f)
            print(f"Your chat log is saved to {log_save_path}")
1 change: 1 addition & 0 deletions alfred/fm/query/ranked_query.py
@@ -3,6 +3,7 @@
Ranked Query Class encompasses query tem

"""

from typing import List, Union, Tuple, Callable

import numpy as np
1 change: 1 addition & 0 deletions alfred/fm/remote/protos/query_pb2.py

(Generated protobuf file; diff not rendered.)

1 change: 1 addition & 0 deletions alfred/fm/remote/protos/query_pb2_grpc.py
@@ -1,5 +1,6 @@
# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
"""Client and server classes corresponding to protobuf-defined services."""

import grpc

from . import query_pb2 as query__pb2
2 changes: 1 addition & 1 deletion alfred/fm/response/completion_response.py
@@ -17,7 +17,7 @@ class CompletionResponse(Response):

    def __init__(
        self,
-        prediction: str,
+        prediction: str = None,
        score: Optional[float] = None,
        embedding: Optional[Union[torch.Tensor, np.ndarray]] = None,
    ):
4 changes: 2 additions & 2 deletions alfred/fm/response/ranked_response.py
@@ -14,8 +14,8 @@ class RankedResponse(Response):

    def __init__(
        self,
-        prediction: str,
-        scores: Dict,
+        prediction: str = None,
+        scores: Dict = None,
        logits: Optional[Union[torch.Tensor, np.ndarray]] = None,
        embeddings: Optional[Union[torch.Tensor, np.ndarray]] = None,
    ):
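Making `prediction` and `scores` default to `None` lets these response objects be constructed before any text is available, for example mid-stream or when only embeddings are present. A small sketch of what the relaxed signatures permit, based only on the constructor fields visible in this diff (the import path is an assumption):

```python
# Sketch based only on the constructor signatures visible in this diff;
# the import path is an assumption.
from alfred.fm.response import CompletionResponse, RankedResponse

# Previously `prediction` was required; now a response can be created
# empty and filled in later.
empty = CompletionResponse()

scored = RankedResponse(
    prediction="positive",
    scores={"positive": 0.9, "negative": 0.1},
)
```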
5 changes: 3 additions & 2 deletions alfred/template/string_template.py
@@ -120,7 +120,9 @@ def from_promptsource(self, promptsource_template):
        self._metadata = promptsource_template["metadata"]
        self._answer_choices = promptsource_template["answer_choices"]

-    def apply(self, example: Union[Dict, List[Dict]], **kawrgs) -> Union[Query, List[Query]]:
+    def apply(
+        self, example: Union[Dict, List[Dict]], **kawrgs
+    ) -> Union[Query, List[Query]]:
        """
        Apply template to an example or a list of examples and returns a query object or a list of queries
@@ -138,7 +140,6 @@ def apply(self, example: Union[Dict, List[Dict]], **kawrgs) -> Union[Query, List[Query]]:
        else:
            raise ValueError(f"Unsupported example type: {type(example)}")

-
        if "key_translator" in kawrgs:
            key_translator = kawrgs["key_translator"]
        else:
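As the reformatted signature makes visible, `apply` accepts either a single example dict or a list of them and returns a `Query` or a list of queries. A hedged usage sketch (the import path, template string, and field name are illustrative assumptions, not from this diff):

```python
# Illustrative sketch; the StringTemplate import path, template text,
# and field names are assumptions, not shown in this diff.
from alfred.template import StringTemplate

template = StringTemplate("Review: [text]\nSentiment:")

# A single dict yields a single Query; a list of dicts yields a list.
query = template.apply({"text": "Great movie!"})
queries = template.apply([{"text": "Great movie!"}, {"text": "Terrible plot."}])
```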
12 changes: 6 additions & 6 deletions alfred/voter/voter.py
@@ -154,12 +154,12 @@ def clear_calibration(self):
        self._calibration = None

    def __call__(
-            self,
-            responses: Union[Iterable[str], str, Iterable[Response], Response],
-            matching_function: Optional[Callable] = None,
-            label_map: Optional[Dict] = None,
-            **kwargs: Any,
-    ) -> np.ndarray:
+        self,
+        responses: Union[Iterable[str], str, Iterable[Response], Response],
+        matching_function: Optional[Callable] = None,
+        label_map: Optional[Dict] = None,
+        **kwargs: Any,
+    ) -> np.ndarray:
        """
        Vote for the responses based on the matching function and the label maps
        """
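The voter's `__call__` maps free-form model responses onto label votes, optionally via a matching function and a label map, returning a numpy array. A heavily hedged sketch of how it might be invoked (the `Voter` import path, its constructor arguments, and the label map are all assumptions, not shown in this diff):

```python
# Hypothetical usage; the Voter import path, constructor arguments, and
# label map below are assumptions based only on this file's location.
from alfred.voter import Voter

voter = Voter(label_map={"positive": 1, "negative": 2})

# Accepts a single response or an iterable of responses/strings and
# returns a numpy array of votes.
votes = voter(["positive", "negative", "positive"])
```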