Skip to content

Commit

Permalink
feat - in local & global search, add support to output the raw chunk …
Browse files Browse the repository at this point in the history
…(instead of processing it via the associated OutputParser) [default is False]
  • Loading branch information
ksachdeva committed Oct 8, 2024
1 parent 69d9aa9 commit 2c5210b
Show file tree
Hide file tree
Showing 3 changed files with 20 additions and 2 deletions.
6 changes: 6 additions & 0 deletions examples/simple-app/app/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ def global_search(
True, # noqa: FBT003
help="Repeat instructions in the prompt",
),
output_raw: bool = typer.Option(False, help="Output raw response"), # noqa: FBT001, FBT003
enable_langsmith: bool = typer.Option(False, help="Enable Langsmith"), # noqa: FBT001, FBT003
):
if enable_langsmith:
Expand All @@ -99,6 +100,7 @@ def global_search(
],
["Show References", str(show_references)],
["Repeat Instructions In Prompt", str(repeat_instructions)],
["Output Raw", str(output_raw)],
]
)

Expand Down Expand Up @@ -133,6 +135,7 @@ def global_search(
context_builder=KeyPointsContextBuilder(
token_counter=TiktokenCounter(),
),
output_raw=output_raw,
)

global_search = GlobalSearch(
Expand Down Expand Up @@ -169,6 +172,7 @@ def local_search(
True, # noqa: FBT003
help="Repeat instructions in the prompt",
),
output_raw: bool = typer.Option(False, help="Output raw response"), # noqa: FBT001, FBT003
enable_langsmith: bool = typer.Option(False, help="Enable Langsmith"), # noqa: FBT001, FBT003
):
if enable_langsmith:
Expand Down Expand Up @@ -196,6 +200,7 @@ def local_search(
],
["Show References", str(show_references)],
["Repeat Instructions In Prompt", str(repeat_instructions)],
["Output Raw", str(output_raw)],
]
)

Expand Down Expand Up @@ -255,6 +260,7 @@ def local_search(
ollama_num_context=ollama_num_context,
),
retriever=retriever,
output_raw=output_raw,
)

# get the chain
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,13 @@ def __init__(
llm: BaseLLM,
prompt_builder: PromptBuilder,
context_builder: KeyPointsContextBuilder,
*,
output_raw: bool = False,
):
self._llm = llm
self._prompt_builder = prompt_builder
self._context_builder = context_builder
self._output_raw = output_raw

def __call__(self) -> Runnable:
kp_lambda = partial(
Expand All @@ -44,7 +47,10 @@ def __call__(self) -> Runnable:
)

prompt, output_parser = self._prompt_builder.build()
base_chain = prompt | self._llm | output_parser
base_chain = prompt | self._llm

if not self._output_raw:
base_chain = base_chain | output_parser

search_chain: Runnable = {
"report_data": operator.itemgetter("report_data")
Expand Down
8 changes: 7 additions & 1 deletion src/langchain_graphrag/query/local_search/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,21 @@ def __init__(
llm: BaseLLM,
prompt_builder: PromptBuilder,
retriever: BaseRetriever,
*,
output_raw: bool = False,
):
self._llm = llm
self._prompt_builder = prompt_builder
self._retriever = retriever
self._output_raw = output_raw

def __call__(self) -> Runnable:
prompt, output_parser = self._prompt_builder.build()

base_chain = prompt | self._llm | output_parser
base_chain = prompt | self._llm

if not self._output_raw:
base_chain = base_chain | output_parser

search_chain: Runnable = {
"context_data": self._retriever | _format_docs,
Expand Down

0 comments on commit 2c5210b

Please sign in to comment.