diff --git a/civitai/hashes.py b/civitai/hashes.py index 43118d8..d2cf5df 100644 --- a/civitai/hashes.py +++ b/civitai/hashes.py @@ -1,4 +1,4 @@ -from blake3 import blake3 as hasher +# from blake3 import blake3 as hasher import os.path from modules import hashes as sd_hashes @@ -9,50 +9,50 @@ # It was not noticably faster in our tests, so we are sticking with SHA256 for now. # Especially since SHA256 is the standard inside this UI. -cache_key = "civitai_hashes" +# cache_key = "civitai_hashes" -def blake3_from_cache(filename, title): - hashes = sd_hashes.cache(cache_key) - ondisk_mtime = os.path.getmtime(filename) +# def blake3_from_cache(filename, title): +# hashes = sd_hashes.cache(cache_key) +# ondisk_mtime = os.path.getmtime(filename) - if title not in hashes: - return None +# if title not in hashes: +# return None - cached_blake3 = hashes[title].get("blake3", None) - cached_mtime = hashes[title].get("mtime", 0) +# cached_blake3 = hashes[title].get("blake3", None) +# cached_mtime = hashes[title].get("mtime", 0) - if ondisk_mtime > cached_mtime or cached_blake3 is None: - return None +# if ondisk_mtime > cached_mtime or cached_blake3 is None: +# return None - return cached_blake3 +# return cached_blake3 -def calculate_blake3(filename): - hash_blake3 = hasher() - blksize = 1024 * 1024 +# def calculate_blake3(filename): +# hash_blake3 = hasher() +# blksize = 1024 * 1024 - with open(filename, "rb") as f: - for chunk in iter(lambda: f.read(blksize), b""): - hash_blake3.update(chunk) +# with open(filename, "rb") as f: +# for chunk in iter(lambda: f.read(blksize), b""): +# hash_blake3.update(chunk) - return hash_blake3.hexdigest() +# return hash_blake3.hexdigest() -def blake3(filename, title): - hashes = sd_hashes.cache(cache_key) +# def blake3(filename, title): +# hashes = sd_hashes.cache(cache_key) - blake3_value = blake3_from_cache(filename, title) - if blake3_value is not None: - return blake3_value +# blake3_value = blake3_from_cache(filename, title) +# if blake3_value is not None: +# return blake3_value - print(f"Calculating blake3 for {filename}: ", end='') - blake3_value = calculate_blake3(filename) - print(f"{blake3_value}") +# print(f"Calculating blake3 for {filename}: ", end='') +# blake3_value = calculate_blake3(filename) +# print(f"{blake3_value}") - if title not in hashes: - hashes[title] = {} +# if title not in hashes: +# hashes[title] = {} - hashes[title]["mtime"] = os.path.getmtime(filename) - hashes[title]["blake3"] = blake3_value +# hashes[title]["mtime"] = os.path.getmtime(filename) +# hashes[title]["blake3"] = blake3_value - sd_hashes.dump_cache() +# sd_hashes.dump_cache() - return blake3_value +# return blake3_value diff --git a/civitai/models.py b/civitai/models.py index ef5a48c..f67359d 100644 --- a/civitai/models.py +++ b/civitai/models.py @@ -45,6 +45,10 @@ class ResourceRequest(BaseModel): url: str = Field(default=None, title="URL", description="The URL of the resource to download.", required=False) previewImage: str = Field(default=None, title="Preview Image", description="The URL of the preview image.", required=False) +class RoomPresence(BaseModel): + client: int = Field(default=None, title="Clients", description="The number of clients in the room") + sd: int = Field(default=None, title="Stable Diffusion Clients", description="The number of Stable Diffusion Clients in the room") + class Command(BaseModel): id: str = Field(default=None, title="ID", description="The ID of the command.") type: CommandTypes = Field(default=None, title="Type", description="The type of command to execute.") diff --git a/install.py b/install.py index e1b65c5..9ef04f2 100644 --- a/install.py +++ b/install.py @@ -1,21 +1,27 @@ -import filecmp -import importlib.util +import subprocess import os -import shutil import sys -import sysconfig import git from launch import run -if sys.version_info < (3, 8): - import importlib_metadata -else: - import importlib.metadata as importlib_metadata - req_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt") +def is_package_installed(package_name, version): + # strip [] from package name + package_name = package_name.split("[")[0] + try: + result = subprocess.run(['pip', 'show', package_name], stdout=subprocess.PIPE, stderr=subprocess.PIPE) + except FileNotFoundError: + return False + if result.returncode == 0: + for line in result.stdout.decode('utf-8').splitlines(): + if line.startswith('Version: '): + installed_version = line.split(' ')[-1] + if installed_version == version: + return True + return False def check_versions(): global req_file @@ -27,29 +33,11 @@ def check_versions(): if len(splits) == 2: key = splits[0] reqs_dict[key] = splits[1].replace("\n", "").strip() - # print(f"Reqs dict: {reqs_dict}") - checks = ["socketio[client]","blake3"] - for check in checks: - check_ver = "N/A" - status = "[ ]" - try: - check_available = importlib.util.find_spec(check) is not None - if check_available: - check_ver = importlib_metadata.version(check) - if check in reqs_dict: - req_version = reqs_dict[check] - if str(check_ver) == str(req_version): - status = "[+]" - else: - status = "[!]" - except importlib_metadata.PackageNotFoundError: - check_available = False - if not check_available: - status = "[!]" - print(f"{status} {check} NOT installed.") - else: - print(f"{status} {check} version {check_ver} installed.") - + # Loop through reqs and check if installed + for req in reqs_dict: + available = is_package_installed(req, reqs_dict[req]) + if available: print(f"[+] {req} version {reqs_dict[req]} installed.") + else : print(f"[!] {req} version {reqs_dict[req]} NOT installed.") base_dir = os.path.dirname(os.path.realpath(__file__)) revision = "" diff --git a/requirements.txt b/requirements.txt index c097620..8635dbf 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,2 +1 @@ -python-socketio[client]==5.7.2 -blake3==0.3.3 \ No newline at end of file +python-socketio[client]==5.7.2 \ No newline at end of file diff --git a/scripts/main.py b/scripts/main.py index 0f8e0bb..8f138c2 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -7,7 +7,7 @@ import os import extensions.sd_civitai_extension.civitai.lib as civitai -from extensions.sd_civitai_extension.civitai.models import Command, CommandActivitiesList, CommandResourcesAdd, CommandActivitiesCancel, CommandResourcesList, CommandResourcesRemove, ErrorPayload, JoinedPayload, UpgradeKeyPayload +from extensions.sd_civitai_extension.civitai.models import Command, CommandActivitiesList, CommandResourcesAdd, CommandActivitiesCancel, CommandResourcesList, CommandResourcesRemove, ErrorPayload, JoinedPayload, RoomPresence, UpgradeKeyPayload from modules import shared, sd_models, script_callbacks, hashes @@ -105,11 +105,26 @@ def on_resources_remove(payload: CommandResourcesRemove): socketio_url = 'https://link.civitai.com' sio = socketio.Client() +should_reconnect = False @sio.event def connect(): - civitai.log('Connected to Civitai Link') + global should_reconnect + + civitai.log('Connected to Civitai Link Server') sio.emit('iam', {"type": "sd"}) + if should_reconnect: + key = shared.opts.data.get("civitai_link_key", None) + if key is None: return + join_room(key) + should_reconnect = False + +@sio.event +def disconnect(): + global should_reconnect + + civitai.log('Disconnected from Civitai Link Server') + should_reconnect = True @sio.on('command') def on_command(payload: Command): @@ -125,10 +140,10 @@ def on_command(payload: Command): elif command == 'resources:remove': return on_resources_remove(payload) -@sio.on('linkStatus') -def on_link_status(payload: bool): - civitai.connected = payload - civitai.log("Civitai Link ready") +@sio.on('roomPresence') +def on_link_status(payload: RoomPresence): + civitai.log(f"Presence update: SD: {payload['sd']}, Clients: {payload['client']}") + civitai.connected = payload['sd'] > 0 and payload['client'] > 0 @sio.on('upgradeKey') def on_upgrade_key(payload: UpgradeKeyPayload): @@ -139,11 +154,6 @@ def on_upgrade_key(payload: UpgradeKeyPayload): def on_error(payload: ErrorPayload): civitai.log(f"Error: {payload['msg']}") -@sio.on('joined') -def on_joined(payload: JoinedPayload): - if payload['type'] != 'client': return - civitai.log("Client joined") - def command_response(payload, history=False): payload['updatedAt'] = datetime.now(timezone.utc).isoformat() if history: civitai.add_activity(payload)