forked from msamogh/schema_attention_model
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathchat.py
246 lines (200 loc) · 8.42 KB
/
chat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
import asyncio
import json
import numpy as np
import os
import yaml
from dataclasses import dataclass, field, asdict
from typing import List
import re
import torch
from pprint import pprint
from typing import Any, Dict, Tuple, Text
from collections import Counter, defaultdict
from torch.utils.data import DataLoader
from tqdm import tqdm, trange
from tokenizers import BertWordPieceTokenizer
from data_readers import filter_dataset, NextActionDataset, NextActionSchema
from models import ActionBertModel, SchemaActionBertModel
from STAR.apis import api
from data_model_utils import CURR_DIR, get_system_action, load_saved_model
from slot_extraction import get_entity, to_db_result_string
# Load the user-facing message templates (e.g. the "welcome" and "rephrase"
# texts used below). ``yaml.safe_load`` is used because ``yaml.load`` without
# an explicit ``Loader`` can construct arbitrary Python objects and raises
# TypeError on PyYAML >= 6.0.
# NOTE(review): the path is relative to the current working directory, not to
# this file's location — confirm that is intended.
with open("messages.yml", "r", encoding="utf-8") as f:
    MESSAGES = yaml.safe_load(f)["templates"]
def user_utterance_to_model_input(user_input, requested_entity_name):
    """Convert a raw user utterance into the model's user-turn string.

    Short greetings (fewer than 8 characters containing the word "hello",
    "hi" or "hey") are normalised to ``"hello hello"``. If the previous system
    turn requested an entity, the utterance is then rephrased as an explicit
    statement about that entity via ``explicitize_user_input``. The result is
    wrapped with ``get_turn_str("User", ...)``.

    Args:
        user_input: The raw text typed by the user.
        requested_entity_name: Name of the entity requested by the previous
            system response, or ``None``.

    Returns:
        The formatted ``"[User] ... [SEP] "`` turn string.
    """
    # Match greetings as whole words. A plain substring test also fires on
    # short answers such as "chicago" ("hi") or "they" ("hey") and would
    # replace a real answer with a greeting.
    words = set(re.findall(r"[a-z]+", user_input.lower()))
    if len(user_input) < 8 and words & {"hello", "hi", "hey"}:
        user_input = "hello hello"
    user_input = explicitize_user_input(user_input, requested_entity_name)
    return get_turn_str("User", user_input)
def get_requested_entity_if_exists(system_response):
    """Return the entity name annotated in a system response, if any.

    The ``%%name%%`` form is tried first, then the ``%name%`` form. Both
    patterns are greedy, so with several markers everything between the first
    and the last one is returned.

    Returns:
        The text between the markers, or ``None`` when no annotation exists.
    """
    for pattern in (r"%%(.*)%%", r"%(.*)%"):
        match = re.search(pattern, system_response)
        if match is not None:
            return match.group(1)
    return None
def remove_entity_annotations(system_response):
    """Strip entity annotations from a system response for display.

    A ``%%...%%`` span is removed entirely, markers and contents; afterwards
    every remaining ``%`` character is dropped, so ``%name%`` becomes ``name``.
    """
    without_hidden = re.sub(r"%%(.*)%%", "", system_response)
    return "".join(ch for ch in without_hidden if ch != "%")
def explicitize_user_input(user_input, requested_entity_name):
    """Rephrase an answer as an explicit statement about the requested entity.

    Returns ``user_input`` unchanged when ``requested_entity_name`` is None,
    otherwise ``"The <entity> is <user_input>."``.
    """
    if requested_entity_name is None:
        return user_input
    return f"The {requested_entity_name} is {user_input}."
def get_turn_str(speaker, utterance):
    """Format one dialogue turn as ``"[<speaker>] <utterance> [SEP] "``.

    Leading and trailing whitespace is stripped from ``utterance``.
    """
    cleaned = utterance.strip()
    return f"[{speaker}] {cleaned} [SEP] "
def get_db_result_string(task, request_type, slots):
    """Query the task's API with the current slots and format the result.

    Args:
        task: Task name passed to ``api.call_api``.
        request_type: One of ``"[QUERY]"``, ``"[QUERY_BOOK]"`` or
            ``"[QUERY_CHECK]"``. It is sent as a ``RequestType`` constraint
            of ``""``, ``"Book"`` or ``"Check"`` respectively.
        slots: Slot name -> value mapping collected so far. It is NOT
            modified. Keys are converted to TitleCase without spaces
            (e.g. ``"request type"`` -> ``"RequestType"``) for the API.

    Returns:
        ``(db_result_string, api_response)``. ``api_response`` is the first
        element returned by ``api.call_api``; ``db_result_string`` is its
        rendering via ``to_db_result_string``.

    Raises:
        ValueError: If ``request_type`` is not one of the supported queries.
    """
    # Explicit check instead of ``assert``, which is stripped under ``-O``.
    if request_type not in ("[QUERY]", "[QUERY_BOOK]", "[QUERY_CHECK]"):
        raise ValueError(f"Unsupported request type: {request_type!r}")

    def to_title_case(key):
        return key.title().replace(" ", "").strip()

    # Work on a copy so the caller's slot dict (the dialogue context) is not
    # silently given a "request type" entry as a side effect.
    constraints = dict(slots)
    # "[QUERY_BOOK]" -> "QUERY_BOOK" -> "BOOK" -> "Book"; "[QUERY]" -> "".
    constraints["request type"] = request_type[1:-1][len("QUERY_"):].title()
    api_response = api.call_api(
        task,
        constraints=[{to_title_case(k): v for k, v in constraints.items()}],
    )[0]
    return to_db_result_string(api_response), api_response
def title_to_snake_case(s):
    """Convert a TitleCase/camelCase name to snake_case.

    An underscore is inserted before every ASCII uppercase letter except one
    at the very start, then the whole string is lowercased
    (``"RequestType"`` -> ``"request_type"``).
    """
    pieces = []
    for idx, ch in enumerate(s):
        if idx > 0 and "A" <= ch <= "Z":
            pieces.append("_")
        pieces.append(ch)
    return "".join(pieces).lower()
@dataclass
class DialogueContext:
    """State carried across the turns of one dialogue.

    ``handle_web_message`` rebuilds it from a plain dict with
    ``DialogueContext(**context)`` and returns it via ``asdict``.
    """

    # Model input sequence: formatted user/agent turn strings and DB results.
    history: List[Text] = field(default_factory=list)
    # Slot name -> value extracted from the user's answers.
    slots: Dict[Text, Any] = field(default_factory=dict)
    # Previous system reply; read to detect which entity was requested.
    prev_sys_response: Text = ""
    # API results merged across all queries made in this dialogue.
    db_results_so_far: Dict[Text, Any] = field(default_factory=dict)
def load_schema_json(task):
    """Load and return the parsed schema JSON for ``task``.

    Reads ``<CURR_DIR>/STAR/tasks/<task>/<task>.json``. The file is closed
    deterministically (the original left the handle open) and decoded as
    UTF-8 rather than with the locale's default encoding.
    """
    path = os.path.join(CURR_DIR, "STAR", "tasks", task, f"{task}.json")
    with open(path, "r", encoding="utf-8") as schema_file:
        return json.load(schema_file)
async def handle_web_message(
    context: Dict[Text, Any],
    new_message: Text,
    model: SchemaActionBertModel,
    schema: Dict,
    task: Text,
    domain: Text,
) -> Text:
    """Process one user message for the web front-end and return a JSON reply.

    ``context`` is a dict of ``DialogueContext`` fields. The user message is
    appended to the history and the model predicts the next system action.
    Each ``[QUERY*]`` action triggers an API call; its result is appended to
    the history and the model is asked again, until a non-query reply is
    produced.

    Args:
        context: Serialized dialogue state (``DialogueContext`` fields).
        new_message: The user's latest message.
        model: Model passed through to ``get_system_action``.
        schema: Task schema; ``schema["replies"]`` maps actions to templates.
        task: Task name.
        domain: Domain name passed to ``get_system_action``.

    Returns:
        A JSON string ``{"response": ..., "updated_context": ...}``. When the
        action predicted from the user's message is ambiguous, that user turn
        is dropped from the history and ``response`` asks for a rephrase.
    """
    ctx = DialogueContext(**context)
    waiting_for_user_input = False
    db_result_string = None

    while True:
        if db_result_string is not None:
            ctx.history.append(db_result_string)
        else:
            if waiting_for_user_input:
                break
            requested_entity_name = get_requested_entity_if_exists(
                ctx.prev_sys_response
            )
            if requested_entity_name is not None:
                # The dataclass field is ``prev_sys_response``; the previous
                # ``prev_system_response`` attribute does not exist.
                ctx.slots[requested_entity_name] = get_entity(
                    requested_entity_name, ctx.prev_sys_response, new_message
                )
            ctx.history.append(
                user_utterance_to_model_input(new_message, requested_entity_name)
            )

        # Get system action and ask user to rephrase if necessary
        system_action, is_ambiguous = await get_system_action(
            model, ctx.history, domain, task
        )
        # Only re-prompt when the action was predicted from the user's message,
        # not from a DB result.
        if db_result_string is None and is_ambiguous:
            del ctx.history[-1]
            rephrase = (
                f"{MESSAGES['rephrase']}"
                f"{remove_entity_annotations(ctx.prev_sys_response)}\n"
            )
            # ``json`` is the stdlib module; calling it directly raised TypeError.
            return json.dumps(
                {"response": rephrase, "updated_context": asdict(ctx)}
            )

        # Get system response
        system_response = schema["replies"][system_action]
        if db_result_string is not None:
            system_response = system_response.format(
                **{title_to_snake_case(k): v for k, v in ctx.db_results_so_far.items()}
            )
        # Persist the reply in the context (not a dead local) so the next
        # message can find the entity this reply requested.
        ctx.prev_sys_response = system_response
        ctx.history.append(
            get_turn_str("Agent", remove_entity_annotations(system_response))
        )
        db_result_string = None

        # Handle DB calls separately
        if system_response in ["[QUERY]", "[QUERY_BOOK]", "[QUERY_CHECK]"]:
            db_result_string, db_result_dict = get_db_result_string(
                task=task, request_type=system_response, slots=ctx.slots
            )
            ctx.db_results_so_far.update(db_result_dict)
        else:
            waiting_for_user_input = True  # as opposed to waiting for an API response

    return json.dumps(
        {
            "response": f"{remove_entity_annotations(system_response)}\n",
            "updated_context": asdict(ctx),
        }
    )
def chat(domain, task):
    """Run an interactive console dialogue for ``task``.

    Loads the task schema and saved model, prints the welcome message, then
    loops forever. Each iteration either reads a user message from stdin or
    feeds back the pending API result, then predicts the next system action.
    It performs ``[QUERY*]`` API calls as needed and prints non-query replies.

    Args:
        domain: Domain name passed to ``get_system_action``.
        task: Task name used to load the schema/model and to query the API.
    """
    schema = load_schema_json(task)
    model = load_saved_model(task=task)
    ctx = DialogueContext()
    # Previously read before ever being assigned (UnboundLocalError on the
    # first iteration). The previous reply now lives in ``ctx.prev_sys_response``.
    db_result_string = None
    print(MESSAGES["welcome"])

    while True:
        if db_result_string is not None:
            # Feed the pending API result to the model instead of reading input.
            ctx.history.append(db_result_string)
        else:
            user_input_raw = input("USER: >> ").strip().lower()
            requested_entity_name = get_requested_entity_if_exists(
                ctx.prev_sys_response
            )
            user_input_model = user_utterance_to_model_input(
                user_input_raw, requested_entity_name
            )
            if requested_entity_name is not None:
                # The dataclass field is ``prev_sys_response``; the previous
                # ``prev_system_response`` attribute does not exist.
                ctx.slots[requested_entity_name] = get_entity(
                    requested_entity_name, ctx.prev_sys_response, user_input_raw
                )
            ctx.history.append(user_input_model)

        # Get system action and ask user to rephrase if necessary
        system_action, is_ambiguous = asyncio.run(
            get_system_action(model, ctx.history, domain, task)
        )
        # Only re-prompt for a user turn. After a DB result, the last history
        # entry is that result: deleting it and continuing would re-append the
        # same result and could loop forever.
        if db_result_string is None and is_ambiguous:
            print(
                f"SYS: >> Sorry, I didn't catch that. "
                f"Could you rephrase that more explicitly?\n"
                f"{remove_entity_annotations(ctx.prev_sys_response)}"
            )
            # Undo appending of the latest user utterance.
            del ctx.history[-1]
            continue

        system_response = schema["replies"][system_action]
        if db_result_string is not None:
            system_response = system_response.format(
                **{title_to_snake_case(k): v for k, v in ctx.db_results_so_far.items()}
            )
        # Reset database result variable
        db_result_string = None

        # Handle DB calls separately
        if system_response in ["[QUERY]", "[QUERY_BOOK]", "[QUERY_CHECK]"]:
            db_result_string, db_result_dict = get_db_result_string(
                task=task, request_type=system_response, slots=ctx.slots
            )
            ctx.db_results_so_far.update(db_result_dict)
        else:
            print(f"SYS: >> {remove_entity_annotations(system_response)}")

        ctx.prev_sys_response = system_response
        ctx.history.append(
            get_turn_str("Agent", remove_entity_annotations(system_response))
        )
if __name__ == "__main__":
    # NOTE(review): running this file directly does nothing, because no CLI
    # entry point is wired up. Call chat(domain, task) from other code to
    # start the console loop.
    pass