-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmain.py
143 lines (117 loc) · 4.73 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
import functools
import logging
import os
import queue
import sys
import time

from PyQt6.QtWidgets import QApplication
from PyQt6.QtCore import QThread, pyqtSignal

from ui import MCTSUI
from mcts import MCTSWorker
from model import LanguageModel
# Application-wide logging configuration: DEBUG level, with timestamp, logger
# name, level and message on every line.
logging.basicConfig(level=logging.DEBUG, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Model identifier and cache directory passed to LanguageModel.from_pretrained.
# Both can be overridden from the environment so the path does not have to be
# edited in source.
# NOTE: the default cache path is a placeholder. Set MCTS_MODEL_CACHE_DIR (or
# edit the default) before running, otherwise the model is cached under a
# relative "path/to/your/cache/directory".
MODEL_NAME = os.environ.get("MCTS_MODEL_NAME", "distilgpt2")
MODEL_CACHE_DIR = os.environ.get("MCTS_MODEL_CACHE_DIR", "path/to/your/cache/directory")
# Decorator to measure and print the execution time of functions
def timing_decorator(func):
    """Wrap *func* so that every call prints how long it took.

    Elapsed time is measured with ``time.perf_counter()``. This is a monotonic,
    high-resolution clock, so unlike ``time.time()`` the reading cannot be
    skewed (or go negative) when the wall clock is adjusted during the call.

    Args:
        func: The callable to time.

    Returns:
        A wrapper that keeps *func*'s metadata (via ``functools.wraps``),
        returns whatever *func* returns, and prints
        ``"<func name> took <seconds> seconds"`` after the call completes.
        If *func* raises, the exception propagates and nothing is printed.
    """
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start_time = time.perf_counter()
        result = func(*args, **kwargs)
        elapsed = time.perf_counter() - start_time
        print(f"{func.__name__} took {elapsed:.4f} seconds")
        return result
    return wrapper
# Thread to handle UI updates asynchronously
class UIUpdateThread(QThread):
    """Background thread that forwards queued MCTS results via a Qt signal.

    It repeatedly pulls ``(paths, total_nodes)`` items from
    ``ui.update_queue`` and re-emits each one through ``update_signal``.
    Callers connect that signal to a UI slot (``MainWindow`` connects it to
    ``update_ui``).

    NOTE(review): ``ui.update_queue`` is assumed to be a ``queue.Queue``-like
    object whose ``get(timeout=...)`` raises ``queue.Empty`` on timeout;
    confirm against ``MCTSUI``.
    """

    update_signal = pyqtSignal(list, int)

    # Seconds to block on the queue before re-checking ``running``. This bounds
    # how long stop() has to wait for the loop to notice it.
    POLL_TIMEOUT = 0.1
    # Seconds to pause after an unexpected error. Without this, an error that
    # recurs immediately (e.g. a missing ``update_queue``) would spin the loop
    # at full CPU and flood the log.
    ERROR_BACKOFF = 0.1

    def __init__(self, ui):
        """Store the UI object whose ``update_queue`` will be polled.

        Args:
            ui: Object exposing an ``update_queue`` attribute.
        """
        super().__init__()
        self.ui = ui
        # Loop flag checked once per poll; cleared by stop().
        self.running = True

    def run(self):
        """Poll the queue and emit each item until ``running`` is cleared."""
        while self.running:
            try:
                paths, total_nodes = self.ui.update_queue.get(timeout=self.POLL_TIMEOUT)
                self.update_signal.emit(paths, total_nodes)
            except queue.Empty:
                continue
            except Exception as e:
                logger.error(f"Error in UIUpdateThread: {e}", exc_info=True)
                time.sleep(self.ERROR_BACKOFF)

    def stop(self):
        """Ask the polling loop to exit and block until the thread finishes.

        Returns:
            The result of ``QThread.wait()``.
        """
        self.running = False
        return self.wait()
# Main application window class
class MainWindow(MCTSUI):
    """Top-level window that wires the language model and MCTS worker into the UI.

    On construction it:

    - loads the language model;
    - creates an ``MCTSWorker`` with 16 workers;
    - hands the model to the UI and connects the UI and worker signals;
    - starts a ``UIUpdateThread`` that drains ``self.update_queue``.

    When the window is closed, the UI update thread and any running MCTS worker
    are stopped first, so no QThread is left running at teardown.
    """

    def __init__(self):
        super().__init__()
        try:
            logger.info("Initializing MainWindow...")
            logger.info("Initializing language model...")
            self.language_model = LanguageModel.from_pretrained(MODEL_NAME, cache_dir=MODEL_CACHE_DIR)
            logger.info("Creating MCTS worker...")
            self.worker = MCTSWorker(self.language_model, num_workers=16)
            logger.info("Setting language model in UI...")
            self.set_language_model(self.language_model)
            logger.info("Connecting signals...")
            self.connect_signals()
            logger.info("Initializing UI update thread...")
            self.init_ui_update_thread()
            logger.info("MainWindow initialization complete.")
        except Exception as e:
            # Log with traceback, then re-raise so the caller sees the failure.
            logger.error(f"Error initializing MainWindow: {e}", exc_info=True)
            raise

    def connect_signals(self):
        """Connect various signals to their respective slots"""
        self.worker.finished.connect(self.reset_ui)
        self.worker.update_ui_signal.connect(self.update_ui)
        self.start_mcts_signal.connect(self.start_mcts_worker)
        self.reset_mcts_signal.connect(self.stop_mcts)
        self.pause_mcts_signal.connect(self.pause_mcts_worker)
        self.worker.performance_signal.connect(self.log_performance)

    def init_ui_update_thread(self):
        """Initialize and start the UI update thread"""
        self.ui_update_thread = UIUpdateThread(self)
        self.ui_update_thread.update_signal.connect(self.update_ui)
        self.ui_update_thread.start()

    @timing_decorator
    def start_mcts_worker(self, prompt, temperature, min_prob, entropy_factor):
        """Start the MCTS worker with given parameters.

        The prompt is tokenized with the language model's tokenizer, and the
        resulting token ids and search parameters are passed to
        ``worker.set_params`` before the worker is started.
        """
        prompt_token_ids = self.language_model.tokenizer.encode(prompt)
        self.worker.set_params(prompt_token_ids, temperature, min_prob, entropy_factor)
        self.worker.start()

    def log_performance(self, performance_data):
        """Log performance data from the MCTS worker"""
        print(f"Performance: {performance_data}")

    def stop_mcts(self):
        """Stop the MCTS process: signal the worker, wait for it, reset the UI."""
        print("Stopping MCTS")
        self.worker.running = False
        self.worker.wait()
        self.reset_ui()

    def pause_mcts_worker(self):
        """Pause or resume the MCTS worker"""
        if self.worker.paused:
            self.worker.resume()
        else:
            self.worker.pause()

    def start_mcts(self):
        """Start the MCTS process"""
        print("Starting MCTS")
        super().start_mcts()

    def reset_mcts(self):
        """Reset the MCTS process"""
        print("Resetting MCTS")
        super().reset_mcts()

    def update_ui(self, paths, total_nodes):
        """Update the UI with new MCTS data"""
        print(f"MainWindow update_ui called: {len(paths)} paths, {total_nodes} total nodes")
        super().update_ui(paths, total_nodes)

    def closeEvent(self, event):
        """Stop background threads, then let the base class handle the close.

        Args:
            event: The Qt close event, forwarded unchanged to the base class.
        """
        self._stop_background_threads()
        super().closeEvent(event)

    def _stop_background_threads(self):
        """Ask the UI update thread and the MCTS worker to exit, and wait for both."""
        # getattr: __init__ may have failed before these attributes were set.
        ui_thread = getattr(self, "ui_update_thread", None)
        if ui_thread is not None and ui_thread.isRunning():
            # UIUpdateThread.run() re-checks ``running`` after each 0.1 s poll.
            ui_thread.running = False
            ui_thread.wait()
        worker = getattr(self, "worker", None)
        if worker is not None and worker.isRunning():
            # Same shutdown handshake as stop_mcts().
            # NOTE(review): if the worker is paused, confirm that clearing
            # ``running`` still lets it exit; otherwise wait() blocks the close.
            worker.running = False
            worker.wait()
if __name__ == "__main__":
    # The QApplication is created before the window and owns the event loop.
    app = QApplication(sys.argv)
    # Exit status used if anything below raises an ordinary Exception.
    exit_code = 1
    try:
        logger.info("Starting application...")
        main_window = MainWindow()
        main_window.show()
        logger.info("Entering main event loop...")
        # Blocks until the event loop quits; its return value is the exit status.
        exit_code = app.exec()
    except Exception as e:
        logger.error(f"Error in main: {e}", exc_info=True)
    sys.exit(exit_code)