difyDslGenCheck_en.py
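
"""Generate a Dify workflow DSL with Claude and check it in a LangGraph loop.

As implemented below: a generator node drafts the workflow YAML, a check node
asks the model to judge the draft against the rules in
workflow_generator_prompt.yml, and a human operator gate decides whether a
failing draft is accepted anyway or sent back for regeneration.
"""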
import logging
import operator
import re
from typing import Annotated, Any

import yaml
from langchain_anthropic import ChatAnthropic
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import ConfigurableField
from langgraph.graph import END, StateGraph
from langgraph.graph.state import CompiledStateGraph
from pydantic import BaseModel, Field


# Logging configuration
def setup_logging():
    logging.basicConfig(
        level=logging.INFO,
        format='%(asctime)s - %(levelname)s - %(message)s',
        datefmt='%Y-%m-%d %H:%M:%S'
    )


# State class
class State(BaseModel):
    query: str = Field(..., description="Workflow content that the user wants to generate")
    messages: Annotated[list[str], operator.add] = Field(
        default=[], description="Response history"
    )
    current_judge: bool = Field(default=False, description="Quality check result")
    judgement_reason: str = Field(default="", description="Quality check judgment reason")
    operator_approved: bool = Field(default=False, description="Operator approval status")


class Judgement(BaseModel):
    reason: str = Field(default="", description="Judgment reason")
    judge: bool = Field(default=False, description="Judgment result")


class WorkflowGenerator:
    def __init__(self):
        self.llm = ChatAnthropic(model="claude-3-5-sonnet-20241022", temperature=0.0)
        # Expose max_tokens as a per-call configurable field
        self.llm = self.llm.configurable_fields(max_tokens=ConfigurableField(id='max_tokens'))

    def load_prompt(self, file_path: str) -> dict:
        with open(file_path, 'r', encoding='utf-8') as file:
            return yaml.safe_load(file)
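
    # NOTE: workflow_generator_prompt.yml is not included alongside this file.
    # Whatever YAML mapping it contains is parsed into a dict here and then
    # interpolated into the prompts below as plain text, so any structure that
    # spells out the Dify DSL generation rules will work.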

    def generate_workflow(self, state: State) -> dict[str, Any]:
        logging.info("workflow_generator_node: START")
        query = state.query
        role = "You are an expert in generating Dify workflows."
        role_details = self.load_prompt("workflow_generator_prompt.yml")
        # If the previous check found issues, include the reason so the model can fix them
        if state.judgement_reason:
            prompt = ChatPromptTemplate.from_template(
                """{role}
{role_details}
{query}
The following issues were detected in the previous generation, please fix them:
{judgement_reason}""".strip()
            )
        else:
            prompt = ChatPromptTemplate.from_template(
                """{role}
{role_details}
{query}""".strip()
            )
        chain = prompt | self.llm.with_config({"max_tokens": 8192}) | StrOutputParser()
        answer = self._get_complete_answer(chain, role, role_details, query, state.judgement_reason)
        logging.info("workflow_generator_node: END")
        return {"messages": [answer]}

    def _get_complete_answer(self, chain, role, role_details, query, judgement_reason=""):
        answer = ""
        while True:
            try:
                current_answer = chain.invoke({
                    "role": role,
                    "role_details": role_details,
                    # Pass any partial answer back so a truncated generation can be continued
                    "query": query + ("\nExisting answer:" + answer if answer else ""),
                    "judgement_reason": judgement_reason
                })
                answer += current_answer
                break
            except Exception as e:
                # Only retry on context-length overflows; re-raise everything else
                if "maximum context length" not in str(e):
                    raise
        return answer

    def check_workflow(self, state: State) -> dict[str, Any]:
        logging.info("check_node: START")
        answer = state.messages[-1]
        prompt_data = self.load_prompt("workflow_generator_prompt.yml")
        prompt = ChatPromptTemplate.from_template(
            """
            Please check whether the generated workflow follows the rules specified in the prompt.
            Answer 'False' if there are issues, 'True' if there are none.
            Also explain the reason for your judgment.
            Prompt: {prompt_data}
            Answer: {answer}
            """
        )
        chain = prompt | self.llm.with_structured_output(Judgement)
        result: Judgement = chain.invoke({
            "answer": answer,
            "prompt_data": prompt_data
        })
        logging.info("check_node: END%s", " with error" if not result.judge else "")
        return {
            "current_judge": result.judge,
            "judgement_reason": result.reason
        }


def ask_operator(state: State) -> dict[str, Any]:
    # Human-in-the-loop gate: show the detected issues and let the operator
    # decide whether to accept the workflow anyway or send it back for regeneration.
    logging.info("Checking with operator...")
    print(f"\nWarning: The following issues were detected:\n{state.judgement_reason}")
    print("\nGenerated workflow:")
    print(state.messages[-1])
    while True:
        response = input("\nDo you want to regenerate this workflow? (y/n): ").lower()
        if response == 'y':
            return {"operator_approved": False}
        elif response == 'n':
            return {"operator_approved": True}
        else:
            print("Invalid input. Please enter y or n.")


def create_workflow_graph(generator: WorkflowGenerator) -> CompiledStateGraph:
    workflow = StateGraph(State)
    workflow.add_node("workflow_generator", generator.generate_workflow)
    workflow.add_node("check", generator.check_workflow)
    workflow.add_node("ask_operator", ask_operator)
    workflow.set_entry_point("workflow_generator")
    workflow.add_edge("workflow_generator", "check")
    # Check passes -> done; otherwise ask the operator
    workflow.add_conditional_edges(
        "check",
        lambda state: state.current_judge,
        {True: END, False: "ask_operator"}
    )
    # Operator accepts -> done; otherwise regenerate
    workflow.add_conditional_edges(
        "ask_operator",
        lambda state: state.operator_approved,
        {True: END, False: "workflow_generator"}
    )
    return workflow.compile()


def main():
    setup_logging()
    wanted_workflow = """
    Purpose: Research and create an article about cooking recipes
    1. Search the internet for cooking recipes and get 3 URLs
    2. Retrieve information from the 3 URLs
    3. Feed the information obtained from the 3 URLs to the LLM and organize the cooking recipe for output
    """
    generator = WorkflowGenerator()
    workflow = create_workflow_graph(generator)
    initial_state = State(query=wanted_workflow)
    result = workflow.invoke(initial_state)
    logging.info(f"Judgment: {result['current_judge']}")
    logging.info(f"Judgment reason: {result['judgement_reason']}")
    # Extract the part enclosed by ```yaml and ``` from the last message
    yaml_content = re.search(r'```yaml\n(.*?)```', result['messages'][-1], re.DOTALL)
    if yaml_content:
        logging.info(f"Result: \n {yaml_content.group(1)}")
    else:
        logging.error("YAML content not found.")


if __name__ == "__main__":
    main()
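
# ---------------------------------------------------------------------------
# Usage notes (a sketch, not part of the original script):
#
# ChatAnthropic reads the ANTHROPIC_API_KEY environment variable, so export it
# before running:
#
#     export ANTHROPIC_API_KEY=sk-...
#     python difyDslGenCheck_en.py
#
# The script also expects workflow_generator_prompt.yml in the working
# directory, which is not included here. A minimal placeholder that
# load_prompt() would accept (the keys below are illustrative assumptions,
# not the original file) could look like:
#
#     role: You are an expert in generating Dify workflows.
#     rules:
#       - Output the finished workflow inside a ```yaml fenced code block.
# ---------------------------------------------------------------------------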