Merge branch 'main' of https://github.com/stanfordnlp/dspy
okhat committed Jul 8, 2024
2 parents 19a4fc9 + 28c70ec commit ce0333e
Showing 5 changed files with 45 additions and 17 deletions.
46 changes: 33 additions & 13 deletions dsp/modules/aws_models.py
```diff
@@ -275,6 +275,8 @@ def __init__(
         self.aws_provider = aws_provider
         self.provider = aws_provider.get_provider_name()

+        self.kwargs["stop"] = ["<|eot_id|>"]
+
         for k, v in kwargs.items():
             self.kwargs[k] = v

@@ -290,25 +292,43 @@ def _create_body(self, prompt: str, **kwargs) -> tuple[int, dict[str, str | float]]:
         for k, v in kwargs.items():
             base_args[k] = v

-        n, query_args = self.aws_provider.sanitize_kwargs(base_args)
+        n, base_args = self.aws_provider.sanitize_kwargs(base_args)

         # Meta models do not support the following parameters
-        query_args.pop("frequency_penalty", None)
-        query_args.pop("num_generations", None)
-        query_args.pop("presence_penalty", None)
-        query_args.pop("model", None)
+        base_args.pop("frequency_penalty", None)
+        base_args.pop("num_generations", None)
+        base_args.pop("presence_penalty", None)
+        base_args.pop("model", None)

-        max_tokens = query_args.pop("max_tokens", None)
-        if max_tokens:
-            query_args["max_gen_len"] = max_tokens
+        max_tokens = base_args.pop("max_tokens", None)
+
+        query_args: dict[str, str | float] = {}
+        if isinstance(self.aws_provider, Bedrock):
+            if max_tokens:
+                base_args["max_gen_len"] = max_tokens
+            query_args = base_args
+            query_args["prompt"] = prompt
+        elif isinstance(self.aws_provider, Sagemaker):
+            if max_tokens:
+                base_args["max_new_tokens"] = max_tokens
+            query_args["parameters"] = base_args
+            query_args["inputs"] = prompt
+        else:
+            raise ValueError("Error - provider not recognized")

-        query_args["prompt"] = prompt
         return (n, query_args)

     def _call_model(self, body: str) -> str:
-        response = self.aws_provider.predictor.invoke_model(
-            modelId=self._model_name,
+        response = self.aws_provider.call_model(
+            model_id=self._model_name,
             body=body,
         )
-        response_body = json.loads(response["body"].read())
-        return response_body["generation"]
+        if isinstance(self.aws_provider, Bedrock):
+            response_body = json.loads(response["body"].read())
+            completion = response_body["generation"]
+        elif isinstance(self.aws_provider, Sagemaker):
+            response_body = json.loads(response["Body"].read())
+            completion = response_body["generated_text"]
+        else:
+            raise ValueError("Error - provider not recognized")
+        return completion
```
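For context, a rough sketch of the request-body shapes the two branches of `_create_body` now produce; the prompt text and parameter values below are placeholders for illustration, not values taken from the code:

```python
# Bedrock Meta models: a flat body with the prompt at the top level.
bedrock_body = {
    "prompt": "Hello, Llama!",   # placeholder prompt
    "temperature": 0.0,          # remaining sanitized kwargs stay at the top level
    "max_gen_len": 512,          # renamed from max_tokens
}

# SageMaker endpoints: the prompt goes under "inputs" and the
# generation parameters are nested under "parameters".
sagemaker_body = {
    "inputs": "Hello, Llama!",   # placeholder prompt
    "parameters": {
        "temperature": 0.0,
        "max_new_tokens": 512,   # renamed from max_tokens
    },
}
```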
6 changes: 5 additions & 1 deletion dsp/modules/hf_client.py
```diff
@@ -118,7 +118,7 @@ def send_hftgi_request_v00(arg, **kwargs):


 class HFClientVLLM(HFModel):
-    def __init__(self, model, port, model_type: Literal['chat', 'text'] = 'text', url="http://localhost", **kwargs):
+    def __init__(self, model, port, model_type: Literal['chat', 'text'] = 'text', url="http://localhost", http_request_kwargs=None, **kwargs):
         super().__init__(model=model, is_client=True)

         if isinstance(url, list):
@@ -132,6 +132,7 @@ def __init__(self, model, port, model_type: Literal['chat', 'text'] = 'text', url="http://localhost", **kwargs):

         self.urls_const = tuple(self.urls)
         self.port = port
+        self.http_request_kwargs = http_request_kwargs or {}
         self.model_type = model_type
         self.headers = {"Content-Type": "application/json"}
         self.kwargs |= kwargs
@@ -198,6 +199,7 @@ def _generate(self, prompt, **kwargs):
                 port=self.port,
                 json=payload,
                 headers=self.headers,
+                **self.http_request_kwargs,
             )

             try:
@@ -225,6 +227,7 @@ def _generate(self, prompt, **kwargs):
                 port=self.port,
                 json=payload,
                 headers=self.headers,
+                **self.http_request_kwargs,
             )

             try:
```
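A minimal sketch of how the new `http_request_kwargs` argument could be used, assuming `HFClientVLLM` is importable from the top-level `dspy` package; the model name, port, and timeout are placeholders. Whatever is passed here is splatted into the HTTP request alongside `json` and `headers`, so options accepted by the underlying request call (e.g. `timeout`) are the intended use:

```python
import dspy

# Placeholder model name, port, and timeout; adjust to your vLLM server.
lm = dspy.HFClientVLLM(
    model="meta-llama/Meta-Llama-3-8B-Instruct",
    port=8000,
    url="http://localhost",
    http_request_kwargs={"timeout": 120},  # forwarded to the HTTP call in _generate
)
dspy.settings.configure(lm=lm)
```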
3 changes: 2 additions & 1 deletion dspy/propose/grounded_proposer.py
```diff
@@ -237,6 +237,7 @@ def __init__(
         prompt_model,
         trainset,
         program_code_string=None,
+        view_data_batch_size=10,
         use_dataset_summary=True,
         program_aware=True,
         use_task_demos=True,
@@ -257,7 +258,7 @@ def __init__(
         self.prompt_model = prompt_model
         self.program_code_string = program_code_string
         self.data_summary = create_dataset_summary(
-            trainset=trainset, view_data_batch_size=10, prompt_model=prompt_model,
+            trainset=trainset, view_data_batch_size=view_data_batch_size, prompt_model=prompt_model,
         )
         print(f"DATA SUMMARY: {self.data_summary}")

```
6 changes: 4 additions & 2 deletions dspy/retrieve/qdrant_rm.py
````diff
@@ -38,7 +38,7 @@ class QdrantRM(dspy.Retrieve):
     Below is a code snippet that shows how to use Qdrant in the forward() function of a module
     ```python
-    self.retrieve = QdrantRM("my_collection_name", qdrant_client=qdrant_client, k=num_passages)
+    self.retrieve = QdrantRM(question, k=num_passages, filter=filter)
     ```
     """

@@ -62,12 +62,13 @@ def __init__(

         super().__init__(k=k)

-    def forward(self, query_or_queries: Union[str, list[str]], k: Optional[int] = None) -> dspy.Prediction:
+    def forward(self, query_or_queries: Union[str, list[str]], k: Optional[int] = None, filter: Optional[models.Filter]=None) -> dspy.Prediction:
         """Search with Qdrant for self.k top passages for query.

         Args:
             query_or_queries (Union[str, List[str]]): The query or queries to search for.
             k (Optional[int]): The number of top passages to retrieve. Defaults to self.k.
+            filter (Optional["Filter"]): "Look only for points which satisfies this conditions". Default: None.

         Returns:
             dspy.Prediction: An object containing the retrieved passages.
@@ -90,6 +91,7 @@ def forward(self, query_or_queries: Union[str, list[str]], k: Optional[int] = None) -> dspy.Prediction:
                 vector=vector,
                 limit=k or self.k,
                 with_payload=[self._document_field],
+                filter=filter,
             )
             for vector in vectors
         ]
````
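A minimal sketch of calling the updated `forward` with a Qdrant payload filter; the connection details, collection name, payload field, and query are placeholders, and the client is assumed to already hold indexed vectors:

```python
from qdrant_client import QdrantClient, models

from dspy.retrieve.qdrant_rm import QdrantRM

client = QdrantClient("localhost", port=6333)  # placeholder connection details
retriever = QdrantRM("my_collection_name", qdrant_client=client, k=3)

# Restrict retrieval to points whose payload field "lang" equals "en".
lang_filter = models.Filter(
    must=[models.FieldCondition(key="lang", match=models.MatchValue(value="en"))],
)

prediction = retriever.forward("What is DSPy?", k=3, filter=lang_filter)
print(prediction.passages)
```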
1 change: 1 addition & 0 deletions dspy/teleprompt/mipro_optimizer_v2.py
```diff
@@ -220,6 +220,7 @@ def compile(
             trainset=trainset,
             prompt_model=self.prompt_model,
             program_code_string=self.program_code_string,
+            view_data_batch_size=self.view_data_batch_size,
             program_aware=program_aware_proposer,
         )

```
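Taken together with the grounded_proposer.py change above, the batch-size knob can now be set once on the optimizer and is forwarded to the dataset-summary step. A rough sketch, assuming MIPROv2's constructor exposes a `view_data_batch_size` argument backing the `self.view_data_batch_size` used in `compile()`; the metric and language models below are placeholders:

```python
import dspy
from dspy.teleprompt import MIPROv2

def exact_match(example, pred, trace=None):  # placeholder metric
    return example.answer == pred.answer

optimizer = MIPROv2(
    metric=exact_match,
    prompt_model=dspy.OpenAI(model="gpt-3.5-turbo"),  # placeholder prompt model
    task_model=dspy.OpenAI(model="gpt-3.5-turbo"),    # placeholder task model
    view_data_batch_size=5,  # forwarded to GroundedProposer and on to create_dataset_summary
)
```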
