Skip to content

Commit

Permalink
Allows worker to me embedded into KAI client (#293)
Browse files Browse the repository at this point in the history
* allows to run as daemon:

* allow for exist from start_job not finding jobs

* allow UI to exit

* stop UI when stopping worker

* fix: scribe timeout won't retry:

* lint: E721
  • Loading branch information
db0 authored Sep 16, 2023
1 parent 6223435 commit fbb0bfb
Show file tree
Hide file tree
Showing 5 changed files with 42 additions and 11 deletions.
2 changes: 1 addition & 1 deletion bridgeData_template.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ max_context_length: 1024
# When set to true, the horde alias behind the API key will be appended to the model that is advertised to the horde
# This will prevent the model from being used from the shared pool, but will ensure that no other worker
# can pretend to serve it
branded_model: true
branded_model: false

## Alchemist (Image interrogation and post-processing)

Expand Down
3 changes: 1 addition & 2 deletions worker/bridge_data/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,7 @@ def reload_models(self, model_manager):
if self.models_reloading:
return
self.models_reloading = True
thread = threading.Thread(target=self._reload_models, args=(model_manager,))
thread.daemon = True
thread = threading.Thread(target=self._reload_models, args=(model_manager,), daemon=True)
thread.start()

@logger.catch(reraise=True)
Expand Down
16 changes: 12 additions & 4 deletions worker/jobs/scribe.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ def __init__(self, mm, bd, pop):
self.current_payload["quiet"] = True
self.requested_softprompt = self.current_payload.get("softprompt")
self.censored = None
self.max_seconds = None

@logger.catch(reraise=True)
def start_job(self):
Expand All @@ -35,7 +36,9 @@ def start_job(self):
if self.status == JobStatus.FAULTED:
self.start_submit_thread()
return
self.stale_time = time.time() + (self.current_payload.get("max_length", 80) / 2) + 10
# we also re-use this for the https timeout to llm inference
self.max_seconds = (self.current_payload.get("max_length", 80) / 2) + 10
self.stale_time = time.time() + self.max_seconds
# These params will always exist in the payload from the horde
gen_payload = self.current_payload
if "width" in gen_payload or "length" in gen_payload or "steps" in gen_payload:
Expand Down Expand Up @@ -63,14 +66,19 @@ def start_job(self):
gen_req = requests.post(
self.bridge_data.kai_url + "/api/latest/generate/",
json=self.current_payload,
timeout=300,
timeout=self.max_seconds,
)
except (requests.exceptions.ConnectionError, requests.exceptions.ReadTimeout):
except requests.exceptions.ConnectionError:
logger.error(f"Worker {self.bridge_data.kai_url} unavailable. Retrying in 3 seconds...")
loop_retry += 1
time.sleep(3)
continue
if not isinstance(gen_req.json(), dict):
except requests.exceptions.ReadTimeout:
logger.error(f"Worker {self.bridge_data.kai_url} request timeout. Aborting.")
self.status = JobStatus.FAULTED
self.start_submit_thread()
return
if isinstance(gen_req.json(), dict):
logger.error(
(
f"KAI instance {self.bridge_data.kai_url} API unexpected response on generate: {gen_req}. "
Expand Down
9 changes: 9 additions & 0 deletions worker/ui.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,6 +93,7 @@ class TerminalUI:
CLIENT_AGENT = "terminalui:1:db0"

def __init__(self, bridge_data):
self.should_stop = False
self.bridge_data = bridge_data

self.dreamer_worker = False
Expand Down Expand Up @@ -802,9 +803,13 @@ def poll(self):
def main_loop(self, stdscr):
self.main = stdscr
while True:
if self.should_stop:
return
try:
self.initialise()
while True:
if self.should_stop:
return
if self.poll():
return
time.sleep(1 / self.gpu.samples_per_second)
Expand All @@ -814,8 +819,12 @@ def main_loop(self, stdscr):
logger.error(str(exc))

def run(self):
self.should_stop = False
curses.wrapper(self.main_loop)

def stop(self):
self.should_stop = True

def get_hordelib_version(self):
try:
return pkg_resources.get_distribution("hordelib").version
Expand Down
23 changes: 19 additions & 4 deletions worker/workers/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def __init__(self, this_model_manager, this_bridge_data):
self.run_count = 0
self.pilot_job_was_run = False
self.last_config_reload = 0
self.is_daemon = False
self.should_stop = False
self.should_restart = False
self.consecutive_executor_restarts = 0
Expand All @@ -25,6 +26,7 @@ def __init__(self, this_model_manager, this_bridge_data):
self.soft_restarts = 0
self.executor = None
self.ui = None
self.ui_class = None
self.last_stats_time = time.time()
logger.stats("Starting new stats session")
# These two should be filled in by the extending classes
Expand All @@ -42,14 +44,20 @@ def startup_terminal_ui(self):
else:
from worker.ui import TerminalUI

ui = TerminalUI(self.bridge_data)
self.ui = threading.Thread(target=ui.run, daemon=True)
self.ui_class = TerminalUI(self.bridge_data)
self.ui = threading.Thread(target=self.ui_class.run, daemon=True)
self.ui.start()

def on_restart(self):
"""Called when the worker loop is restarted. Make sure to invoke super().on_restart() when overriding."""
self.soft_restarts += 1

@logger.catch(reraise=True)
def stop(self):
self.should_stop = True
self.ui_class.stop()
logger.info("Stop methods called")

@logger.catch(reraise=True)
def start(self):
self.reload_data()
Expand Down Expand Up @@ -83,7 +91,10 @@ def start(self):
logger.error("Too many soft restarts, exiting the worker. Please review your config.")
logger.error("You can try asking for help in the official discord if this persists.")
logger.init("Worker", status="Shutting Down")
sys.exit(self.exit_rc)
if self.is_daemon:
return
else: # noqa: RET505
sys.exit(self.exit_rc)

def process_jobs(self):
# logger.debug("Cron: Starting process_jobs()")
Expand Down Expand Up @@ -147,6 +158,8 @@ def start_job(self):
if self.bridge_data.queue_size == 0:
if jobs := self.pop_job():
job = jobs[0]
if self.should_stop:
return False
elif len(self.waiting_jobs) > 0:
job = self.waiting_jobs.pop(0)
else:
Expand Down Expand Up @@ -229,7 +242,9 @@ def get_uptime_kudos(self):

def reload_data(self):
"""This is just a utility function to reload the configuration"""
self.bridge_data.reload_data()
# Daemons are fed the configuration externally
if not self.is_daemon:
self.bridge_data.reload_data()

def reload_bridge_data(self):
self.reload_data()
Expand Down

0 comments on commit fbb0bfb

Please sign in to comment.