-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmcts.py
242 lines (205 loc) · 9.52 KB
/
mcts.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
import asyncio
import random
from PyQt6.QtCore import QThread, pyqtSignal
from tree import Node, Tree
import math
import jax.numpy as jnp
import psutil
import os
import time
import logging
# Configure root logging once at import time: DEBUG level, lines formatted as
# "time - logger name - level - message". Per the stdlib docs, basicConfig is a
# no-op when the root logger already has handlers.
logging.basicConfig(
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    level=logging.DEBUG,
)
logger = logging.getLogger(__name__)
class MCTSWorker(QThread):
    """Monte Carlo Tree Search over language-model token continuations.

    ``run()`` (the QThread entry point) drives ``run_async`` on a private
    asyncio event loop. Each iteration selects up to ``num_workers`` distinct
    leaves with UCB1 and expands them with one batched
    ``language_model.extract_distribution_batch`` call. It then scores each
    leaf by its own token log-probability and backpropagates ``exp(logprob)``
    to the root.

    Signals:
        finished: emitted once when ``run_async`` exits, normally or on error.
            NOTE(review): this shadows QThread's built-in ``finished`` signal.
        update_ui_signal(list, int): path summaries from
            ``get_all_leaf_paths`` and the total node count.
        performance_signal(dict): a snapshot of the timing/counter data.
    """

    finished = pyqtSignal()
    update_ui_signal = pyqtSignal(list, int)
    performance_signal = pyqtSignal(dict)

    # Maps each "<phase>_time" prefix in performance_data to the counter that
    # is actually incremented for that phase, so averages can be computed.
    _AVERAGE_KEYS = {
        "selection": "select_count",
        "expansion": "expand_count",
        "simulation": "simulate_count",
        "backpropagation": "backpropagate_count",
        "language_model": "language_model_count",
    }

    def __init__(self, language_model, num_workers=4):
        """Create an idle worker.

        Args:
            language_model: object providing ``set_event_loop(loop)`` and the
                coroutine ``extract_distribution_batch(states, temperature=...,
                min_prob=..., entropy_factor=...)``.
            num_workers: number of selections attempted per MCTS iteration.
        """
        super().__init__()
        self.language_model = language_model
        self.tree = None
        self.running = False
        self.prompt_token_ids = None  # must be set via set_params() before start()
        self.temperature = 1.0
        self.min_prob = 1e-6
        self.entropy_factor = 3.0
        self.eps = 0.01  # retained for compatibility; not used by the search
        self.num_workers = num_workers
        # Event loop owned by run(); None while the thread is not running.
        self.loop = None
        # Placeholders only: run() replaces both with primitives created for
        # its own event loop, so they are never shared across loops.
        self.lock = asyncio.Lock()
        self.pause_condition = asyncio.Condition()
        self.performance_data = self.reset_performance_data()
        self.paused = False

    def set_params(self, prompt_token_ids, temperature, min_prob, entropy_factor):
        """Configure the prompt and sampling parameters used by the next run."""
        self.prompt_token_ids = prompt_token_ids
        self.temperature = temperature
        self.min_prob = min_prob
        self.entropy_factor = entropy_factor

    def run(self):
        """QThread entry point: run ``run_async`` on a fresh event loop."""
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        self.loop = loop
        # asyncio primitives bind to the first loop that waits on them, so a
        # fresh pair per run keeps a restarted thread from hitting
        # "bound to a different event loop" errors.
        self.lock = asyncio.Lock()
        self.pause_condition = asyncio.Condition()
        self.language_model.set_event_loop(loop)
        try:
            loop.run_until_complete(self.run_async())
        finally:
            self.loop = None
            asyncio.set_event_loop(None)
            loop.close()

    async def run_async(self):
        """Run MCTS iterations until stopped.

        The loop exits when ``running`` is cleared, when ``reset()`` drops the
        tree, or when process memory use exceeds 80%. It emits
        ``update_ui_signal`` at most every 0.5 s and ``performance_signal``
        after every iteration. ``finished`` is emitted exactly once on exit;
        errors are logged rather than raised.
        """
        try:
            logger.info("Starting MCTS run_async")
            if self.prompt_token_ids is None:
                raise ValueError("set_params() must be called before starting MCTS")
            self.performance_data = self.reset_performance_data()
            self.tree = Tree(self.prompt_token_ids)
            iteration = 0
            process = psutil.Process(os.getpid())
            self.running = True
            last_update_time = 0
            update_interval = 0.5  # seconds between UI refreshes
            start_time = time.time()
            while self.running:
                async with self.pause_condition:
                    while self.paused:
                        await self.pause_condition.wait()
                iteration_start = time.time()
                logger.debug(f"Starting iteration {iteration}")
                await self.mcts_iteration()
                iteration_time = time.time() - iteration_start
                logger.debug(f"Iteration {iteration} completed in {iteration_time:.4f} seconds")
                self.performance_data["iterations"] += 1
                self.performance_data["total_time"] += iteration_time

                tree = self.tree
                if tree is None:
                    # reset() dropped the tree (from another thread) mid-run.
                    break

                # Limit UI updates
                current_time = time.time()
                if current_time - last_update_time >= update_interval:
                    paths = self.get_all_leaf_paths(tree.root)
                    total_nodes = self.count_nodes(tree.root)
                    logger.debug(f"Emitting update signal: {len(paths)} paths, {total_nodes} total nodes")
                    self.update_ui_signal.emit(paths, total_nodes)
                    last_update_time = current_time

                # Average per-phase times using the counters that are actually incremented.
                for phase, count_key in self._AVERAGE_KEYS.items():
                    count = self.performance_data.get(count_key, 0)
                    if count > 0:
                        self.performance_data[f"avg_{phase}_time"] = (
                            self.performance_data.get(f"{phase}_time", 0) / count
                        )
                # Emit a snapshot so the receiver never sees later mutations.
                self.performance_signal.emit(dict(self.performance_data))

                memory_percent = process.memory_percent()
                if memory_percent > 80:
                    logger.warning(f"Memory usage reached {memory_percent:.2f}%. Stopping MCTS.")
                    break
                iteration += 1
                await asyncio.sleep(0.01)  # yield to other tasks between iterations
            end_time = time.time()
            logger.info(f"MCTS completed in {end_time - start_time:.2f} seconds")
        except Exception as e:
            logger.error(f"Error in MCTS process: {e}", exc_info=True)
        finally:
            self.running = False
            self.finished.emit()

    def reset(self):
        """Stop the search loop and drop the tree.

        A paused search is released as well. Otherwise it would block forever
        in ``pause_condition.wait()`` and never observe the stop.
        """
        self.running = False
        self.tree = None
        if self.paused:
            self.resume()

    async def mcts_iteration(self):
        """Run one batched select / expand / simulate / backpropagate round."""
        logger.debug("Starting MCTS iteration")
        tree = self.tree
        if tree is None:
            return
        leaf_nodes = []
        seen = set()
        for _ in range(self.num_workers):
            leaf_node = await self.select(tree.root)
            # With no virtual loss, the same leaf is often chosen repeatedly.
            # Expanding it twice would attach duplicate children.
            if not leaf_node.children and id(leaf_node) not in seen:
                seen.add(id(leaf_node))
                leaf_nodes.append(leaf_node)
        logger.debug(f"Selected {len(leaf_nodes)} leaf nodes")
        if leaf_nodes:
            logger.debug("Expanding leaf nodes")
            await self.expand_batch(leaf_nodes)
        for leaf_node in leaf_nodes:
            logger.debug("Simulating and backpropagating")
            value = await self.simulate(leaf_node)
            await self.backpropagate(leaf_node, value)
        logger.debug("MCTS iteration completed")

    async def select(self, node):
        """Descend from ``node`` to a node with no children using UCB1.

        At each level, a random unvisited child is chosen if any exist.
        Otherwise the first child maximising
        ``value / visits + sqrt(2 * ln(parent.visits) / visits)`` is chosen.
        """
        async with self.lock:
            start_time = time.time()
            while node.children:
                unvisited = [child for child in node.children if child.visits == 0]
                if unvisited:
                    node = random.choice(unvisited)
                else:
                    log_parent_visits = math.log(node.visits)
                    node = max(
                        node.children,
                        key=lambda child: (child.value / child.visits)
                        + math.sqrt(2 * log_parent_visits / child.visits),
                    )
            self.performance_data["selection_time"] += time.time() - start_time
            self.performance_data["select_count"] += 1
        return node

    async def expand_batch(self, nodes):
        """Expand ``nodes`` with a single batched language-model call.

        Each non-None result is unpacked as ``(distribution, entropy,
        raw_sum)``. The node stores ``entropy`` and ``raw_sum``, and one child
        is added per ``token_id -> prob`` entry, with ``logprob = log(prob)``.
        A ``None`` result leaves its node unexpanded.
        """
        start_time = time.time()
        states = [self.get_state(node) for node in nodes]
        lm_start = time.time()
        results = await self.language_model.extract_distribution_batch(
            states,
            temperature=self.temperature,
            min_prob=self.min_prob,
            entropy_factor=self.entropy_factor,
        )
        self.performance_data["language_model_time"] += time.time() - lm_start
        self.performance_data["language_model_count"] += 1
        for node, result in zip(nodes, results):
            if result is None:
                continue
            distribution, entropy, raw_sum = result
            node.entropy = entropy
            node.raw_sum = raw_sum
            for token_id, prob in distribution.items():
                child = Node([int(token_id)], parent=node)
                child.logprob = jnp.log(prob)
                node.add_child(child)
            node.child_count = len(node.children)
        self.performance_data["expansion_time"] += time.time() - start_time
        self.performance_data["expand_count"] += 1

    async def simulate(self, node):
        """Return the leaf's own logprob, or 0.0 for a parentless node or an infinite logprob."""
        start_time = time.time()
        value = node.logprob if node.parent and not jnp.isinf(node.logprob) else 0.0
        self.performance_data["simulation_time"] += time.time() - start_time
        self.performance_data["simulate_count"] += 1
        return value

    async def backpropagate(self, node, value):
        """Add one visit and ``exp(value)`` (0 if ``value`` is infinite) to ``node`` and all its ancestors."""
        start_time = time.time()
        reward = 0 if jnp.isinf(value) else math.exp(value)
        async with self.lock:
            while node:
                node.visits += 1
                node.value += reward
                node = node.parent
        self.performance_data["backpropagation_time"] += time.time() - start_time
        self.performance_data["backpropagate_count"] += 1

    def get_all_leaf_paths(self, root):
        """Return ``dfs_parents`` summaries under ``root``, most probable first."""
        paths = []
        self.dfs_parents(root, [], 0.0, paths)
        return sorted(paths, key=lambda x: x[1], reverse=True)

    def get_state(self, node):
        """Concatenate ``token_ids`` from the root down to ``node`` into one jnp array."""
        segments = []
        current = node
        while current:
            segments.append(current.token_ids)
            current = current.parent
        segments.reverse()
        return jnp.concatenate(segments)

    def dfs_parents(self, node, current_path, current_logprob, paths, root=None):
        """Collect summaries of nodes whose children are all leaves.

        The walk accumulates token ids and the sum of finite node logprobs;
        the tree root's own logprob is excluded. For each node that has
        children, all of them childless, the method appends a 7-tuple to
        ``paths``:
        ``(token_path, exp(cumulative_logprob), entropy, depth, raw_sum,
        num_children, child_data)``.
        Here ``child_data`` is a list of ``(token_id, prob, logprob)`` sorted
        by descending ``prob``. All other nodes are recursed into.

        ``root`` defaults to ``self.tree.root``. It is passed down the
        recursion, so the tree is looked up only once per traversal.
        """
        if root is None:
            root = self.tree.root
        new_path = current_path + list(node.token_ids)
        contributes = node is not root and not jnp.isinf(node.logprob)
        new_logprob = current_logprob + (node.logprob if contributes else 0.0)
        children = node.children
        if children and all(not child.children for child in children):
            child_data = sorted(
                (
                    (int(child.token_ids[0]), math.exp(child.logprob), child.logprob)
                    for child in children
                ),
                key=lambda item: item[1],
                reverse=True,
            )
            probability = math.exp(new_logprob) if not jnp.isinf(new_logprob) else 0
            paths.append((new_path, probability, node.entropy, node.depth,
                          node.raw_sum, len(children), child_data))
        else:
            for child in children:
                self.dfs_parents(child, new_path, new_logprob, paths, root)

    def count_nodes(self, node):
        """Return the number of nodes in the subtree rooted at ``node``, inclusive.

        This is iterative, so deep trees cannot hit the recursion limit.
        """
        total = 0
        stack = [node]
        while stack:
            current = stack.pop()
            total += 1
            stack.extend(current.children)
        return total

    def reset_performance_data(self):
        """Return a fresh dict of zeroed timing accumulators and counters."""
        return {
            "iterations": 0,
            "total_time": 0,
            "selection_time": 0,
            "expansion_time": 0,
            "simulation_time": 0,
            "backpropagation_time": 0,
            "language_model_time": 0,
            "select_count": 0,
            "expand_count": 0,
            "simulate_count": 0,
            "backpropagate_count": 0,
            "language_model_count": 0,
            "selection_count": 0,  # never incremented; kept for existing consumers
        }

    def pause(self):
        """Ask the search loop to wait before its next iteration."""
        self.paused = True

    def resume(self):
        """Clear the pause flag and wake the search loop if it is waiting.

        Intended to be callable from another thread (e.g. the GUI thread). The
        wake-up coroutine is scheduled onto the worker's own event loop.
        """
        self.paused = False
        loop = self.loop
        if loop is None or not loop.is_running():
            return
        wake = self._wake_paused()
        try:
            asyncio.run_coroutine_threadsafe(wake, loop)
        except RuntimeError:
            # The loop closed between the check and the call; nothing to wake.
            wake.close()

    async def _wake_paused(self):
        """Notify all waiters on ``pause_condition``; ``notify_all`` requires holding its lock."""
        async with self.pause_condition:
            self.pause_condition.notify_all()