diff --git a/civitai/hashes.py b/civitai/hashes.py
new file mode 100644
index 0000000..43118d8
--- /dev/null
+++ b/civitai/hashes.py
@@ -0,0 +1,58 @@
+from blake3 import blake3 as hasher
+import os.path
+
+from modules import hashes as sd_hashes
+
+# NOTE: About this file
+# ---------------------------------
+# This is not being used. It was added to see if Blake3 was faster than SHA256.
+# It was not noticeably faster in our tests, so we are sticking with SHA256 for now,
+# especially since SHA256 is the standard inside this UI.
+
+cache_key = "civitai_hashes"
+
+def blake3_from_cache(filename, title):
+    hashes = sd_hashes.cache(cache_key)
+    ondisk_mtime = os.path.getmtime(filename)
+
+    if title not in hashes:
+        return None
+
+    cached_blake3 = hashes[title].get("blake3", None)
+    cached_mtime = hashes[title].get("mtime", 0)
+
+    if ondisk_mtime > cached_mtime or cached_blake3 is None:
+        return None
+
+    return cached_blake3
+
+def calculate_blake3(filename):
+    hash_blake3 = hasher()
+    blksize = 1024 * 1024
+
+    with open(filename, "rb") as f:
+        for chunk in iter(lambda: f.read(blksize), b""):
+            hash_blake3.update(chunk)
+
+    return hash_blake3.hexdigest()
+
+def blake3(filename, title):
+    hashes = sd_hashes.cache(cache_key)
+
+    blake3_value = blake3_from_cache(filename, title)
+    if blake3_value is not None:
+        return blake3_value
+
+    print(f"Calculating blake3 for {filename}: ", end='')
+    blake3_value = calculate_blake3(filename)
+    print(f"{blake3_value}")
+
+    if title not in hashes:
+        hashes[title] = {}
+
+    hashes[title]["mtime"] = os.path.getmtime(filename)
+    hashes[title]["blake3"] = blake3_value
+
+    sd_hashes.dump_cache()
+
+    return blake3_value
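+
+# Note (added comment, inferred from the code above): each cache entry is expected to
+# look roughly like {"<title>": {"mtime": <os.path.getmtime float>, "blake3": "<hex digest>"}}.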
diff --git a/civitai/lib.py b/civitai/lib.py
index 95a31b9..93be7f3 100644
--- a/civitai/lib.py
+++ b/civitai/lib.py
@@ -1,29 +1,61 @@
 import json
 import os
+import shutil
+import tempfile
+import time
 import requests
-import re
+import glob
 
-from importlib import import_module
-from basicsr.utils.download_util import load_file_from_url
-from modules import shared, sd_models, sd_vae
+from tqdm import tqdm
+from modules import shared, sd_models, sd_vae, hashes
 from modules.paths import models_path
 from extensions.sd_civitai_extension.civitai.models import ResourceRequest
 
+#region shared variables
 try:
     base_url = shared.cmd_opts.civitai_endpoint
 except:
     base_url = 'https://civitai.com/api/v1'
 
+connected = False
+user_agent = 'CivitaiLink:Automatic1111'
+download_chunk_size = 8192
+#endregion
+
 #region Utils
 def log(message):
     """Log a message to the console."""
     print(f'Civitai: {message}')
 
-def parse_hypernetwork(string):
-    match = re.search(r'(.+)\(([^)]+)', string)
-    if match:
-        return {"name": match.group(1), "hash": match.group(2), "type": "Hypernetwork"}
-    return {"name": "", "hash": "", "type": "Hypernetwork"}
+def download_file(url, dest, on_progress=None):
+    if os.path.exists(dest):
+        log(f'File already exists: {dest}')
+
+    log(f'Downloading: "{url}" to {dest}\n')
+
+    response = requests.get(url, stream=True, headers={"User-Agent": user_agent})
+    total = int(response.headers.get('content-length', 0))
+    start_time = time.time()
+
+    dest = os.path.expanduser(dest)
+    dst_dir = os.path.dirname(dest)
+    f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
+
+    try:
+        current = 0
+        with tqdm(total=total, unit='B', unit_scale=True, unit_divisor=1024) as bar:
+            for data in response.iter_content(chunk_size=download_chunk_size):
+                current += len(data)
+                pos = f.write(data)
+                bar.update(pos)
+                if on_progress is not None:
+                    on_progress(current, total, start_time)
+        f.close()
+        shutil.move(f.name, dest)
+    finally:
+        f.close()
+        if os.path.exists(f.name):
+            os.remove(f.name)
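+
+# Note (added comment): download_file streams into a NamedTemporaryFile created in the
+# destination folder and only moves it into place once the download has finished, so a
+# partial download should never be left behind looking like a valid file; the finally
+# block removes any leftover temp file.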
 #endregion Utils
 
 #region API
@@ -31,7 +63,7 @@ def req(endpoint, method='GET', data=None, params=None, headers=None):
     """Make a request to the Civitai API."""
     if headers is None:
         headers = {}
-    headers['User-Agent'] = 'Automatic1111'
+    headers['User-Agent'] = user_agent
     if data is not None:
         data = json.dumps(data)
     if not endpoint.startswith('/'):
@@ -83,34 +115,140 @@ def get_tags(query, page=1, page_size=20):
     return response
 #endregion API
 
-#region Auto Utils
+#region Get Utils
+def get_automatic_type(type: str):
+    if type == 'Hypernetwork': return 'hypernet'
+    return type.lower()
+
+def get_automatic_name(type: str, filename: str, folder: str):
+    abspath = os.path.abspath(filename)
+    if abspath.startswith(folder):
+        fullname = abspath.replace(folder, '')
+    else:
+        fullname = os.path.basename(filename)
+
+    if fullname.startswith("\\") or fullname.startswith("/"):
+        fullname = fullname[1:]
+
+    if type == 'Checkpoint': return fullname
+    return os.path.splitext(fullname)[0]
+
+def has_preview(filename: str):
+    return os.path.isfile(os.path.splitext(filename)[0] + '.png')
+
+def get_resources_in_folder(type, folder, exts=[], exts_exclude=[]):
+    resources = []
+    os.makedirs(folder, exist_ok=True)
+
+    candidates = []
+    for ext in exts:
+        candidates += glob.glob(os.path.join(folder, '**/*.' + ext), recursive=True)
+    for ext in exts_exclude:
+        candidates = [x for x in candidates if not x.endswith(ext)]
+
+    folder = os.path.abspath(folder)
+    automatic_type = get_automatic_type(type)
+    for filename in sorted(candidates):
+        if os.path.isdir(filename):
+            continue
+
+        name = os.path.splitext(filename)[0]
+        automatic_name = get_automatic_name(type, filename, folder)
+        hash = hashes.sha256(filename, f"{automatic_type}/{automatic_name}")
+
+        resources.append({'type': type, 'name': name, 'hash': hash, 'path': filename, 'hasPreview': has_preview(filename) })
+
+    return resources
+
+resources = []
+def load_resource_list(types=['LORA', 'Hypernetwork', 'TextualInversion', 'Checkpoint']):
+    global resources
+
+    # If resources is empty and types is empty, load all types
+    # This is a helper to be able to get the resource list without
+    # having to worry about initialization. On subsequent calls, no work will be done
+    if len(resources) == 0 and len(types) == 0:
+        types = ['LORA', 'Hypernetwork', 'TextualInversion', 'Checkpoint']
+
+    if 'LORA' in types:
+        resources = [r for r in resources if r['type'] != 'LORA']
+        resources += get_resources_in_folder('LORA', shared.cmd_opts.lora_dir, ['pt', 'safetensors', 'ckpt'])
+    if 'Hypernetwork' in types:
+        resources = [r for r in resources if r['type'] != 'Hypernetwork']
+        resources += get_resources_in_folder('Hypernetwork', shared.cmd_opts.hypernetwork_dir, ['pt', 'safetensors', 'ckpt'])
+    if 'TextualInversion' in types:
+        resources = [r for r in resources if r['type'] != 'TextualInversion']
+        resources += get_resources_in_folder('TextualInversion', shared.cmd_opts.embeddings_dir, ['pt'])
+    if 'Checkpoint' in types:
+        resources = [r for r in resources if r['type'] != 'Checkpoint']
+        resources += get_resources_in_folder('Checkpoint', sd_models.model_path, ['safetensors', 'ckpt'], ['vae.safetensors', 'vae.ckpt'])
+
+    return resources
+
+def get_resource_by_hash(hash: str):
+    resources = load_resource_list([])
+
+    found = [resource for resource in resources if hash.lower() == resource['hash']]
+    if found:
+        return found[0]
+
+    return None
+
 def get_model_by_hash(hash: str):
     found = [info for info in sd_models.checkpoints_list.values() if hash == info.sha256 or hash == info.shorthash or hash == info.hash]
     if found:
         return found[0]
 
     return None
-#endregion Auto Utils
+
+#endregion Get Utils
+
+#region Removing
+def remove_resource(resource: ResourceRequest):
+    removed = None
+    target = get_resource_by_hash(resource['hash'])
+    if target is None or target['type'] != resource['type']: removed = False
+    elif os.path.exists(target['path']):
+        os.remove(target['path'])
+        removed = True
+
+    if removed == True:
+        log(f'Removed resource')
+        load_resource_list([resource['type']])
+        if resource['type'] == 'Checkpoint':
+            sd_models.list_models()
+        elif resource['type'] == 'Hypernetwork':
+            shared.reload_hypernetworks()
+        # elif resource['type'] == 'LORA':
+            # TODO: reload LORA
+    elif removed == None:
+        log(f'Resource not found')
+#endregion Removing
 
 #region Downloading
-def load_if_missing(path, url):
+def load_if_missing(path, url, on_progress=None):
     if os.path.exists(path): return True
     if url is None: return False
 
-    dir, file = os.path.split(path)
-    load_file_from_url(url, dir, True, file)
+    download_file(url, path, on_progress)
 
     return None
 
-async def load_resource(resource: ResourceRequest):
-    if resource.type == 'Checkpoint': await load_model(resource)
-    if resource.type == 'CheckpointConfig': await load_model_config(resource)
-    elif resource.type == 'Hypernetwork': await load_hypernetwork(resource)
-    elif resource.type == 'TextualInversion': await load_textual_inversion(resource)
-    elif resource.type == 'AestheticGradient': await load_aesthetic_gradient(resource)
-    elif resource.type == 'VAE': await load_vae(resource)
-    elif resource.type == 'LORA': await load_lora(resource)
+def load_resource(resource: ResourceRequest, on_progress=None):
+    resource['hash'] = resource['hash'].lower()
+    existing_resource = get_resource_by_hash(resource['hash'])
+    if existing_resource:
+        log(f'Already have resource: {resource["name"]}')
+        return
+
+    if resource['type'] == 'Checkpoint': load_model(resource, on_progress)
+    elif resource['type'] == 'CheckpointConfig': load_model_config(resource, on_progress)
+    elif resource['type'] == 'Hypernetwork': load_hypernetwork(resource, on_progress)
+    elif resource['type'] == 'TextualInversion': load_textual_inversion(resource, on_progress)
+    elif resource['type'] == 'LORA': load_lora(resource, on_progress)
 
-async def fetch_model_by_hash(hash: str):
+    load_resource_list([resource['type']])
+
+def fetch_model_by_hash(hash: str):
     model_version = get_model_version_by_hash(hash)
     if model_version is None:
         log(f'Could not find model version with hash {hash}')
@@ -123,65 +261,71 @@ async def fetch_model_by_hash(hash: str):
         hash=file['hashes']['SHA256'],
         url=file['downloadUrl']
     )
-    await load_resource(resource)
-
-async def load_model_config(resource: ResourceRequest):
-    load_if_missing(os.path.join(models_path, 'stable-diffusion', resource.name), resource.url)
+    load_resource(resource)
 
-async def load_model(resource: ResourceRequest):
-    if shared.opts.data["sd_checkpoint_hash"] == resource.hash: return
+def load_model_config(resource: ResourceRequest, on_progress=None):
+    load_if_missing(os.path.join(sd_models.model_path, resource['name']), resource['url'], on_progress)
 
-    model = get_model_by_hash(resource.hash)
+def load_model(resource: ResourceRequest, on_progress=None):
+    model = get_model_by_hash(resource['hash'])
     if model is not None: log('Found model in model list')
-    if model is None and resource.url is not None:
+    if model is None and resource['url'] is not None:
         log('Downloading model')
-        load_file_from_url(resource.url, os.path.join(models_path, 'stable-diffusion'), True, resource.name)
+        download_file(resource['url'], os.path.join(sd_models.model_path, resource['name']), on_progress)
         sd_models.list_models()
-        model = get_model_by_hash(resource.hash)
+        model = get_model_by_hash(resource['hash'])
 
-    if model is not None:
-        sd_models.load_model(model)
-        shared.opts.save(shared.config_filename)
-    else: log('Could not find model and no URL was provided')
+    return model
 
-async def load_hypernetwork(resource: ResourceRequest):
-    # TODO: rig some way to work with hashes instead of names to avoid collisions
+def load_textual_inversion(resource: ResourceRequest, on_progress=None):
+    load_if_missing(os.path.join(shared.cmd_opts.embeddings_dir, resource['name']), resource['url'], on_progress)
 
-    if shared.opts.sd_hypernetwork == resource.name:
-        log('Hypernetwork already loaded')
-        return
+def load_lora(resource: ResourceRequest, on_progress=None):
+    isAvailable = load_if_missing(os.path.join(shared.cmd_opts.lora_dir, resource['name']), resource['url'], on_progress)
+    # TODO: reload lora list - not sure best way to import this
+    # if isAvailable is None:
+    #     lora.list_available_loras()
+
+def load_vae(resource: ResourceRequest, on_progress=None):
+    # TODO: find by hash instead of name
+    if not resource['name'].endswith('.pt'): resource['name'] += '.pt'
+    full_path = os.path.join(models_path, 'VAE', resource['name'])
 
-    full_path = os.path.join(models_path, 'hypernetworks', resource.name);
+    isAvailable = load_if_missing(full_path, resource['url'], on_progress)
+    if isAvailable is None:
+        sd_vae.refresh_vae_list()
+
+def load_hypernetwork(resource: ResourceRequest, on_progress=None):
+    full_path = os.path.join(shared.cmd_opts.hypernetwork_dir, resource['name']);
     if not full_path.endswith('.pt'): full_path += '.pt'
-    isAvailable = load_if_missing(full_path, resource.url)
-    if not isAvailable:
-        log('Could not find hypernetwork')
-        return
+    isAvailable = load_if_missing(full_path, resource['url'], on_progress)
+    if isAvailable is None:
+        shared.reload_hypernetworks()
 
-    shared.opts.sd_hypernetwork = resource.name
-    shared.opts.save(shared.config_filename)
-    shared.reload_hypernetworks()
 
+#endregion Downloading
 
-async def load_textual_inversion(resource: ResourceRequest):
-    # TODO: rig some way to work with hashes instead of names to avoid collisions
-    load_if_missing(os.path.join('embeddings', resource.name), resource.url)
+#region Selecting Resources
+def select_model(resource: ResourceRequest):
+    if shared.opts.data["sd_checkpoint_hash"] == resource['hash']: return
 
-async def load_aesthetic_gradient(resource: ResourceRequest):
-    # TODO: rig some way to work with hashes instead of names to avoid collisions
-    load_if_missing(os.path.join('extensions/stable-diffusion-webui-aesthetic-gradients','aesthetic_embeddings', resource.name), resource.url)
+    model = load_model(resource)
 
-async def load_vae(resource: ResourceRequest):
-    # TODO: rig some way to work with hashes instead of names to avoid collisions
+    if model is not None:
+        sd_models.load_model(model)
+        shared.opts.save(shared.config_filename)
+    else: log('Could not find model and no URL was provided')
 
-    if not resource.name.endswith('.pt'): resource.name += '.pt'
-    full_path = os.path.join(models_path, 'VAE', resource.name)
+def select_vae(resource: ResourceRequest):
+    # TODO: find by hash instead of name
+    if not resource['name'].endswith('.pt'): resource['name'] += '.pt'
+    full_path = os.path.join(models_path, 'VAE', resource['name'])
 
-    if sd_vae.loaded_vae_file is not None and sd_vae.get_filename(sd_vae.loaded_vae_file) == resource.name:
+    if sd_vae.loaded_vae_file is not None and sd_vae.get_filename(sd_vae.loaded_vae_file) == resource['name']:
         log('VAE already loaded')
         return
 
-    isAvailable = load_if_missing(full_path, resource.url)
+    isAvailable = load_if_missing(full_path, resource['url'])
     if not isAvailable:
         log('Could not find VAE')
        return
@@ -189,61 +333,24 @@ async def load_vae(resource: ResourceRequest):
     sd_vae.refresh_vae_list()
     sd_vae.load_vae(shared.sd_model, full_path)
 
-async def load_lora(resource: ResourceRequest):
-    isAvailable = load_if_missing(os.path.join('extensions/sd-webui-additional-networks/models/lora', resource.name), resource.url)
-
-    # TODO: Auto refresh LORA
-    # lora = import_module('extensions.sd-webui-additional-networks.scripts.additional_networks')
-    # if lora is None:
-    #     log('LORA extension not installed')
-    #     return
-    # if isAvailable is None: # isAvailable is None if the file was downloaded
-    #     lora.update_lora_models()
-
-    # TODO: Select LORA
-    # ¯\_(ツ)_/¯
-
-async def old_load_model(name, url=None):
-    if shared.opts.sd_model_checkpoint == name: return
-
-    model = sd_models.get_closet_checkpoint_match(name)
-    if model is None and url is not None:
-        log('Downloading model')
-        load_if_missing(os.path.join(models_path, 'stable-diffusion', name), url)
-        sd_models.list_models()
-        model = sd_models.get_closet_checkpoint_match(name)
-    elif shared.opts.sd_model_checkpoint == model.title:
-        log('Model already loaded')
-        return
-    else:
-        log('Found model in model list')
-
-    if model is not None:
-        sd_models.load_model(model)
-        shared.opts.sd_model_checkpoint = model.title
-        shared.opts.save(shared.config_filename)
-    else: log('Could not find model in model list')
-
-
-async def download_textual_inversion(name, url):
-    load_if_missing(os.path.join('embeddings', name), url)
-
-async def download_aesthetic_gradient(name, url):
-    load_if_missing(os.path.join('extensions/stable-diffusion-webui-aesthetic-gradients','aesthetic_embeddings', name), url)
+def clear_vae():
+    log('Clearing VAE')
+    sd_vae.clear_loaded_vae()
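+
+# Note (added comment): load_if_missing (defined above) returns True when the file is
+# already on disk, False when no URL was provided, and None after downloading a new
+# file, which is why the loaders above only refresh their model lists on None.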
 
-async def old_load_hypernetwork(name, url=None):
-    if shared.opts.sd_hypernetwork == name:
+def select_hypernetwork(resource: ResourceRequest):
+    # TODO: find by hash instead of name
+    if shared.opts.sd_hypernetwork == resource['name']:
         log('Hypernetwork already loaded')
         return
 
-    full_path = os.path.join(models_path, 'hypernetworks', name);
+    full_path = os.path.join(shared.cmd_opts.hypernetwork_dir, resource['name']);
     if not full_path.endswith('.pt'): full_path += '.pt'
-    isAvailable = load_if_missing(full_path, url)
+    isAvailable = load_if_missing(full_path, resource['url'])
     if not isAvailable:
         log('Could not find hypernetwork')
         return
 
-    shared.opts.sd_hypernetwork = name
+    shared.opts.sd_hypernetwork = resource['name']
     shared.opts.save(shared.config_filename)
     shared.reload_hypernetworks()
 
@@ -254,25 +361,4 @@ def clear_hypernetwork():
     shared.opts.sd_hypernetwork = 'None'
     shared.opts.save(shared.config_filename)
     shared.reload_hypernetworks()
-
-async def load_vae(name, url=None):
-    if not name.endswith('.pt'): name += '.pt'
-    full_path = os.path.join(models_path, 'VAE', name)
-
-    if sd_vae.loaded_vae_file is not None and sd_vae.get_filename(sd_vae.loaded_vae_file) == name:
-        log('VAE already loaded')
-        return
-
-    isAvailable = load_if_missing(full_path, url)
-    if not isAvailable:
-        log('Could not find VAE')
-        return
-
-    sd_vae.refresh_vae_list()
-    sd_vae.load_vae(shared.sd_model, full_path)
-
-def clear_vae():
-    log('Clearing VAE')
-    sd_vae.clear_loaded_vae()
-
-#endregion Downloading
+#endregion
diff --git a/civitai/models.py b/civitai/models.py
index 256ba7a..db1e997 100644
--- a/civitai/models.py
+++ b/civitai/models.py
@@ -10,6 +10,11 @@ class ResourceTypes(str, Enum):
     LORA = "LORA"
     VAE = "VAE"
 
+class CommandTypes(str, Enum):
+    ResourcesList = "resources:list"
+    ResourcesAdd = "resources:add"
+    ResourcesRemove = "resources:remove"
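+
+# Note (added comment): these values mirror the command strings dispatched in
+# scripts/main.py ('resources:list', 'resources:add', 'resources:remove').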
+
 class ImageParams(BaseModel):
     prompt: str = Field(default="", title="Prompt", description="The prompt to use when generating the image.")
     negativePrompt: str = Field(default="", title="Negative Prompt", description="The negative prompt to use when generating the image.")
@@ -34,4 +39,35 @@ class ResourceRequest(BaseModel):
     name: str = Field(default=None, title="Name", description="The name of the resource to download.")
     type: ResourceTypes = Field(default=None, title="Type", description="The type of the resource to download.")
     hash: str = Field(default=None, title="Hash", description="The SHA256 hash of the resource to download.")
-    url: str = Field(default=None, title="URL", description="The URL of the resource to download.", required=False)
\ No newline at end of file
+    url: str = Field(default=None, title="URL", description="The URL of the resource to download.", required=False)
+    previewImage: str = Field(default=None, title="Preview Image", description="The URL of the preview image.", required=False)
+    addons: list[str] = Field(default=[], title="Addons", description="The addons to download with the resource.", required=False)
+
+class Command(BaseModel):
+    id: str = Field(default=None, title="ID", description="The ID of the command.")
+    type: CommandTypes = Field(default=None, title="Type", description="The type of command to execute.")
+
+class CommandResourcesList(Command):
+    type: CommandTypes = Field(default=CommandTypes.ResourcesList, title="Type", description="The type of command to execute.")
+    types: list[ResourceTypes] = Field(default=[], title="Types", description="The types of resources to list.")
+
+class CommandResourcesAdd(Command):
+    type: CommandTypes = Field(default=CommandTypes.ResourcesAdd, title="Type", description="The type of command to execute.")
+    resources: list[ResourceRequest] = Field(default=[], title="Resources", description="The resources to add.")
+
+class ResourceRemoveRequest(BaseModel):
+    type: ResourceTypes = Field(default=None, title="Type", description="The type of the resource to remove.")
+    hash: str = Field(default=None, title="Hash", description="The SHA256 hash of the resource to remove.")
+
+class CommandResourcesRemove(Command):
+    type: CommandTypes = Field(default=CommandTypes.ResourcesRemove, title="Type", description="The type of command to execute.")
+    resources: list[ResourceRemoveRequest] = Field(default=[], title="Resources", description="The resources to remove.")
+
+class UpgradeKeyPayload(BaseModel):
+    key: str = Field(default=None, title="Key", description="The upgraded key.")
+
+class ErrorPayload(BaseModel):
+    msg: str = Field(default=None, title="Message", description="The error message.")
+
+class JoinedPayload(BaseModel):
+    type: str = Field(default=None, title="Type", description="The type of the client that joined.")
\ No newline at end of file
diff --git a/install.py b/install.py
index 10eb05d..e1b65c5 100644
--- a/install.py
+++ b/install.py
@@ -1 +1,91 @@
-# install script
\ No newline at end of file
+import filecmp
+import importlib.util
+import os
+import shutil
+import sys
+import sysconfig
+
+import git
+
+from launch import run
+
+if sys.version_info < (3, 8):
+    import importlib_metadata
+else:
+    import importlib.metadata as importlib_metadata
+
+req_file = os.path.join(os.path.dirname(os.path.realpath(__file__)), "requirements.txt")
+
+
+def check_versions():
+    global req_file
+    reqs = open(req_file, 'r')
+    lines = reqs.readlines()
+    reqs_dict = {}
+    for line in lines:
+        splits = line.split("==")
+        if len(splits) == 2:
+            key = splits[0]
+            reqs_dict[key] = splits[1].replace("\n", "").strip()
+    # print(f"Reqs dict: {reqs_dict}")
+    checks = ["socketio[client]","blake3"]
+    for check in checks:
+        check_ver = "N/A"
+        status = "[ ]"
+        try:
+            check_available = importlib.util.find_spec(check) is not None
+            if check_available:
+                check_ver = importlib_metadata.version(check)
+                if check in reqs_dict:
+                    req_version = reqs_dict[check]
+                    if str(check_ver) == str(req_version):
+                        status = "[+]"
+                    else:
+                        status = "[!]"
+        except importlib_metadata.PackageNotFoundError:
+            check_available = False
+        if not check_available:
+            status = "[!]"
+            print(f"{status} {check} NOT installed.")
+        else:
+            print(f"{status} {check} version {check_ver} installed.")
+
+
+base_dir = os.path.dirname(os.path.realpath(__file__))
+revision = ""
+app_revision = ""
+
+try:
+    repo = git.Repo(base_dir)
+    revision = repo.rev_parse("HEAD")
+    app_repo = git.Repo(os.path.join(base_dir, "..", ".."))
+    app_revision = app_repo.rev_parse("HEAD")
+except:
+    pass
+
+print("")
+print("#######################################################################################################")
+print("Initializing Civitai Link")
+print("If submitting an issue on github, please provide the below text for debugging purposes:")
+print("")
+print(f"Python revision: {sys.version}")
+print(f"Civitai Link revision: {revision}")
+print(f"SD-WebUI revision: {app_revision}")
+print("")
+civitai_skip_install = os.environ.get('CIVITAI_SKIP_INSTALL', False)
+
+try:
+    requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
+    if requirements_file == req_file:
+        civitai_skip_install = True
+except:
+    pass
+
+if not civitai_skip_install:
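+    # Note (added comment): setting the CIVITAI_SKIP_INSTALL environment variable to any
+    # non-empty value skips this pip install step (see the os.environ.get call above).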
+    name = "Civitai Link"
+    run(f'"{sys.executable}" -m pip install -r "{req_file}"', f"Checking {name} requirements...",
+        f"Couldn't install {name} requirements.")
+
+check_versions()
+print("")
+print("#######################################################################################################")
diff --git a/javascript/civitai.js b/javascript/civitai.js
index 9488623..cc22249 100644
--- a/javascript/civitai.js
+++ b/javascript/civitai.js
@@ -1,74 +1,93 @@
-// #region [utils]
-const log = (...args) => console.log(`[civitai]`, ...args);
-const delay = (ms) => new Promise(resolve => setTimeout(resolve, ms));
-const getElement = (selector, timeout = 10000) => new Promise((resolve, reject) => {
-    const interval = setInterval(() => {
-        const el = gradioApp().querySelector(selector);
-        timeout -= 100;
-        if (timeout < 0) {
-            reject('timeout');
-            clearInterval(interval);
-        } else if (el) {
-            resolve(el);
-            clearInterval(interval);
-        }
-    }, 100);
-})
-// #endregion
+(async function () {
+    // #region [utils]
+    const log = (...args) => console.log(`[civitai]`, ...args);
+    const delay = (ms) => new Promise(resolve => setTimeout(resolve, ms));
+    const getElement = (selector, timeout = 10000) => new Promise((resolve, reject) => {
+        const interval = setInterval(() => {
+            const el = gradioApp().querySelector(selector);
+            timeout -= 100;
+            if (timeout < 0) {
+                reject('timeout');
+                clearInterval(interval);
+            } else if (el) {
+                resolve(el);
+                clearInterval(interval);
+            }
+        }, 100);
+    })
+    // #endregion
+
+
+    async function generate() {
+        const generateButton = await getElement('#txt2img_generate');
+        generateButton.click();
+        log('generating image');
+    }
 
+    async function handlePrompt(prompt, andGenerate = false, delayMs = 3000) {
+        log('injecting prompt', prompt);
+        const promptEl = await getElement('#txt2img_prompt textarea');
+        promptEl.value = prompt;
+        promptEl.dispatchEvent(new Event("input", { bubbles: true })); // internal Svelte trigger
 
-async function generate() {
-    const generateButton = await getElement('#txt2img_generate');
-    generateButton.click();
-    log('generating image');
-}
+        const pastePromptButton = await getElement('#paste');
+        pastePromptButton.click();
+        log('applying prompt');
 
-async function handlePrompt(prompt, andGenerate = false, delayMs = 3000) {
-    log('injecting prompt', prompt);
-    const promptEl = await getElement('#txt2img_prompt textarea');
-    promptEl.value = prompt;
-    promptEl.dispatchEvent(new Event("input", { bubbles: true })); // internal Svelte trigger
+        if (andGenerate) {
+            await delay(delayMs);
+            await generate();
+            notifyParent({generate: true});
+        }
+    }
 
-    const pastePromptButton = await getElement('#paste');
-    pastePromptButton.click();
-    log('applying prompt');
+    function notifyParent(msg) {
+        if (child && child.sendMessageToParent)
+            child.sendMessageToParent(msg);
+    }
 
-    if (andGenerate) {
-        await delay(delayMs);
-        await generate();
-        notifyParent({generate: true});
+    async function refreshModels() {
+        const refreshModelsButton = await getElement('#refresh_sd_model_checkpoint');
+        refreshModelsButton.click();
     }
-}
 
-function notifyParent(msg) {
-    if (child && child.sendMessageToParent)
-        child.sendMessageToParent(msg);
-}
+    let child;
+    async function hookChild() {
+        child = new AcrossTabs.default.Child({
+            // origin: 'https://civitai.com',
+            origin: 'http://localhost:3000',
+            onParentCommunication: commandHandler
+        });
+    }
 
-async function refreshModels() {
-    const refreshModelsButton = await getElement('#refresh_sd_model_checkpoint');
-    refreshModelsButton.click();
-}
+    function commandHandler({ command, ...data }) {
+        log('tab communication', { command, data })
+        switch (command) {
+            case 'generate': return handlePrompt(data.generationParams, true, 500);
+            case 'refresh-models': return refreshModels();
+        }
+    }
 
-let child;
-async function hookChild() {
-    child = new AcrossTabs.default.Child({
-        // origin: 'https://civitai.com',
-        origin: 'http://localhost:3000',
-        onParentCommunication: commandHandler
-    });
-}
+    let statusElement = document.createElement('div');
+    let currentStatus = false;
+    async function checkStatus() {
+        const { connected } = await fetch('/civitai/v1/link-status').then(x=>x.json());
+        if (currentStatus != connected) {
+            currentStatus = connected;
+            statusElement.classList.toggle('connected', connected);
+        }
+    }
+    async function startStatusChecks() {
+        statusElement.id = 'civitai-status';
+        statusElement.classList.add('civitai-status');
+        await getElement('.gradio-container'); // wait for gradio to load
+        gradioApp().appendChild(statusElement);
 
-function commandHandler({ command, ...data }) {
-    log('tab communication', { command, data })
-    switch (command) {
-        case 'generate': return handlePrompt(data.generationParams, true, 500);
-        case 'refresh-models': return refreshModels();
+        setInterval(checkStatus, 1000 * 10);
+        checkStatus();
     }
-}
 
-// Bootstrap
-(async () => {
+    // Bootstrap
     const searchParams = new URLSearchParams(location.search);
     if (searchParams.has('civitai_prompt'))
         handlePrompt(atob(searchParams.get('civitai_prompt')), searchParams.has('civitai_generate'));
@@ -79,4 +98,5 @@ function commandHandler({ command, ...data }) {
 
     // clear search params
     history.replaceState({}, document.title, location.pathname);
-})()
\ No newline at end of file
+    await startStatusChecks();
+})();
diff --git a/preload.py b/preload.py
index 474e755..effc7ce 100644
--- a/preload.py
+++ b/preload.py
@@ -2,4 +2,5 @@ import argparse
 
 def preload(parser: argparse.ArgumentParser):
-    parser.add_argument("--civitai-endpoint", type=str, help="Endpoint for interacting with a Civitai instance", default="https://civitai.com/api/v1")
\ No newline at end of file
+    parser.add_argument("--civitai-endpoint", type=str, help="Endpoint for interacting with a Civitai instance", default="https://civitai.com/api/v1")
+    parser.add_argument("--civitai-link-endpoint", type=str, help="Endpoint for interacting with a Civitai Link server", default="https://link.civitai.com/api/socketio")
\ No newline at end of file
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..c097620
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,2 @@
+python-socketio[client]==5.7.2
+blake3==0.3.3
\ No newline at end of file
diff --git a/scripts/api.py b/scripts/api.py
index 7d8d095..e81aa56 100644
--- a/scripts/api.py
+++ b/scripts/api.py
@@ -1,114 +1,16 @@
 # api endpoints
-import asyncio
 import gradio as gr
 
 from fastapi import FastAPI
-from modules import shared, script_callbacks as script_callbacks
-from modules.hypernetworks import hypernetwork
-from modules.api.api import encode_pil_to_base64, validate_sampler_name
-from modules.api.models import StableDiffusionTxt2ImgProcessingAPI, TextToImageResponse
-from modules.processing import StableDiffusionProcessingTxt2Img, process_images
-from modules.sd_models import checkpoints_list
-from modules.call_queue import queue_lock
-from typing import List
+from modules import script_callbacks as script_callbacks
 
 import extensions.sd_civitai_extension.civitai.lib as civitai
 from extensions.sd_civitai_extension.civitai.models import GenerateImageRequest, ResourceRequest
 
 
 def civitaiAPI(demo: gr.Blocks, app: FastAPI):
-    # To detect if the API is loaded
-    @app.get("/civitai/v1")
-    async def index():
-        return {"status": "success"}
-
-    # To get a list of resources available
-    @app.get("/civitai/v1/resources")
-    async def get_resources():
-        models = [{"name":x.model_name, "hash":x.sha256, "type":"Checkpoint"} for x in checkpoints_list.values()]
-        hypernetworks = [civitai.parse_hypernetwork(name) for name in shared.hypernetworks]
-        return models + hypernetworks
-
-    # To activate a list of resources
-    @app.post("/civitai/v1/resources")
-    async def set_resources(resources: List[ResourceRequest]):
-        for resource in resources:
-            await asyncio.create_task(civitai.load_resource(resource))
-
-        return {"status": "success"}
-
-
-    # To download and select a model
-    # @app.post("/civitai/v1/run/{id}")
-    # async def run(id: str):
-    #     to_run = civitai.get_model_version(id)
-    #     to_run_name = f'{to_run["model"]["name"]} {to_run["name"]}'
-    #     civitai.log(f'Running: {to_run_name}')
-    #     primary_file = [x for x in to_run['files'] if x['primary'] == True][0]
-    #     name = primary_file['name']
-    #     hash = primary_file['hashes']['AutoV1']
-    #     url = to_run['downloadUrl']
-    #     type = to_run['model']['type']
-    #     if type == 'Checkpoint':
-    #         try:
-    #             config_file = [x for x in to_run['files'] if x['type'] == "Config"][0]
-    #             await asyncio.create_task(civitai.load_config(config_file['name'], config_file['downloadUrl']))
-    #         except IndexError: config_file = None
-    #         await asyncio.create_task(civitai.load_model(name, url))
-    #     elif type == 'TextualInversion': await asyncio.create_task(civitai.download_textual_inversion(name, url))
-    #     elif type == 'AestheticGradient': await asyncio.create_task(civitai.download_aesthetic_gradient(name, url))
-    #     elif type == 'Hypernetwork': await asyncio.create_task(civitai.load_hypernetwork(name, url))
-
-    #     civitai.log(f'Loaded: {to_run_name}')
-    #     return {"status": "success"}
-
-    def txt2img(txt2imgreq: StableDiffusionTxt2ImgProcessingAPI):
-        populate = txt2imgreq.copy(update={  # Override __init__ params
-            "sampler_name": validate_sampler_name(txt2imgreq.sampler_name or txt2imgreq.sampler_index),
-            "do_not_save_samples": True,
-            "do_not_save_grid": True
-            }
-        )
-        if populate.sampler_name:
-            populate.sampler_index = None  # prevent a warning later on
-
-        args = vars(populate)
-        args.pop('script_name', None)
-
-        with queue_lock:
-            p = StableDiffusionProcessingTxt2Img(sd_model=shared.sd_model, **args)
-
-            shared.state.begin()
-            processed = process_images(p)
-            shared.state.end()
-
-        b64images = list(map(encode_pil_to_base64, processed.images))
-
-        return TextToImageResponse(images=b64images, parameters=vars(txt2imgreq), info=processed.js())
-
-    @app.post("/civitai/v1/generate/image", response_model=TextToImageResponse)
-    async def generate_image(req: GenerateImageRequest):
-        if (req.vae is None): civitai.clear_vae()
-        if (req.hypernetwork is None): civitai.clear_hypernetwork()
-
-        if (req.model is not None): await asyncio.create_task(civitai.load_model(req.model))
-        if (req.hypernetwork is not None):
-            await asyncio.create_task(civitai.load_hypernetwork(req.hypernetwork))
-            hypernetwork.apply_strength(req.params.hypernetworkStrength)
-        if (req.vae is not None): await asyncio.create_task(civitai.load_vae(req.vae))
-
-        return txt2img(
-            StableDiffusionTxt2ImgProcessingAPI(
-                prompt=req.params.prompt,
-                negative_prompt=req.params.negativePrompt,
-                seed=req.params.seed,
-                steps=req.params.steps,
-                width=req.params.width,
-                height=req.params.height,
-                cfg_scale=req.params.cfgScale,
-                n_iter=req.quantity,
-                batch_size=req.batchSize,
-            )
-        )
+    @app.get('/civitai/v1/link-status')
+    def link_status():
+        return { "connected": civitai.connected }
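+
+    # Example (added comment, assuming the default local WebUI address):
+    #   GET http://127.0.0.1:7860/civitai/v1/link-status  ->  {"connected": false}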
 
 script_callbacks.on_app_started(civitaiAPI)
 civitai.log("API loaded")
diff --git a/scripts/main.py b/scripts/main.py
index d05f252..3256d88 100644
--- a/scripts/main.py
+++ b/scripts/main.py
@@ -1,49 +1,151 @@
 # main ui
+import time
 import gradio as gr
+import socketio
+import os
 
 import extensions.sd_civitai_extension.civitai.lib as civitai
+from extensions.sd_civitai_extension.civitai.models import Command, CommandResourcesAdd, CommandResourcesList, CommandResourcesRemove, ErrorPayload, JoinedPayload, UpgradeKeyPayload
 
-from modules import shared, sd_models, script_callbacks
-
-def on_ui_tabs():
-    with gr.Blocks() as civitai_interface:
-        # Nav row with Civitai logo, search bar, sort, sort period, tag select, and creator select
-        with gr.Row():
-            gr.HTML("Civitai")
-        with gr.Group():
-            civitai_query = gr.Textbox(label="Search", default="")
-            civitai_button_search = gr.Button(label="🔎");
-            civitai_sort = gr.Dropdown(label="Sort", value="Most Downloaded", options=["Most Downloaded", "Most Recent", "Most Liked", "Most Viewed"])
-            civitai_sort_period = gr.Dropdown(label="Sort Period", value="All Time", options=["All Time", "Last 30 Days", "Last 7 Days", "Last 24 Hours"])
-            civitai_tag = gr.Dropdown(label="Tag", choices=["All", "Anime", "Cartoon", "Comic", "Game", "Movie", "Music", "Other", "Realistic", "TV"])
-            civitai_creator = gr.Dropdown(label="Creator", choices=["All", "Anime", "Cartoon", "Comic", "Game", "Movie", "Music", "Other", "Realistic", "TV"])
-            civitai_page = gr.Number(visible=False, value=1)
-            civitai_page_size = gr.Number(visible=False, default=20)
-        # Model list
-        with gr.Row():
-            model_output = gr.HTML()
-        # Pagination
-        with gr.Row():
-            civitai_button_prev = gr.Button(label="Previous")
-            civitai_current_page = gr.HTML("Page 1 of 1")
-            civitai_button_next = gr.Button(label="Next")
-
-        # Dummy Elements
-        download_model_version_id = gr.Number(visible=False, value=0, elem_id="download_model_version_id")
-        download_model_button = gr.Button(label="Download", visible=False, elem_id="download_model_button")
-
-        # Handle button clicks
-        civitai_button_search.click(fn=search_models, inputs=[civitai_query, civitai_sort, civitai_sort_period, civitai_tag, civitai_creator, civitai_page, civitai_page_size], outputs=[model_output, civitai_current_page])
-        civitai_button_prev.click(fn=prev_page, inputs=[civitai_page], outputs=[civitai_page])
-        civitai_button_next.click(fn=next_page, inputs=[civitai_page], outputs=[civitai_page])
-
-        # Handle dropdown changes
-        civitai_tag.change(fn=search_models, inputs=[civitai_query, civitai_sort, civitai_sort_period, civitai_tag, civitai_creator, civitai_page, civitai_page_size], outputs=[model_output, civitai_current_page])
-        civitai_sort.change(fn=search_models, inputs=[civitai_query, civitai_sort, civitai_sort_period, civitai_tag, civitai_creator, civitai_page, civitai_page_size], outputs=[model_output, civitai_current_page])
-        civitai_sort_period.change(fn=search_models, inputs=[civitai_query, civitai_sort, civitai_sort_period, civitai_tag, civitai_creator, civitai_page, civitai_page_size], outputs=[model_output, civitai_current_page])
-        civitai_creator.change(fn=search_models, inputs=[civitai_query, civitai_sort, civitai_sort_period, civitai_tag, civitai_creator, civitai_page, civitai_page_size], outputs=[model_output, civitai_current_page])
-        civitai_page.change(fn=search_models, inputs=[civitai_query, civitai_sort, civitai_sort_period, civitai_tag, civitai_creator, civitai_page, civitai_page_size], outputs=[model_output, civitai_current_page])
-        civitai_page_size.change(fn=search_models, inputs=[civitai_query, civitai_sort, civitai_sort_period, civitai_tag, civitai_creator, civitai_page, civitai_page_size], outputs=[model_output, civitai_current_page])
+from modules import shared, sd_models, script_callbacks, hashes
+
+#region Civitai Link Command Handlers
+def on_resources_list(payload: CommandResourcesList):
+    types = payload['types'] if 'types' in payload else []
+    resources = civitai.load_resource_list(types)
+    sio.emit('commandStatus', { 'type': payload['type'], 'resources': resources, 'status': 'success' })
+
+report_interval = 10
+def on_resources_add(payload: CommandResourcesAdd):
+    resources = payload['resources']
+    payload['status'] = 'processing'
+
+    last_report = time.time()
+    def report_status(force=False):
+        nonlocal last_report
+        current_time = time.time()
+        if force or current_time - last_report > report_interval:
+            sio.emit('commandStatus', { 'id': payload['id'], 'type': payload['type'], 'resources': resources, 'status': payload['status'] })
+            last_report = current_time
+
+    def progress_for_resource(resource):
+        def on_progress(current: int, total: int, start_time: float):
+            current_time = time.time()
+            elapsed_time = current_time - start_time
+            speed = current / elapsed_time
+            remaining_time = (total - current) / speed
+            progress = current / total * 100
+            resource['status'] = 'processing'
+            resource['progress'] = progress
+            resource['remainingTime'] = remaining_time
+            resource['speed'] = speed
+            report_status()
+
+        return on_progress
+
+    had_error = False
+    for resource in resources:
+        try:
+            civitai.load_resource(resource, progress_for_resource(resource))
+            resource['status'] = 'success'
+        except Exception as e:
+            civitai.log(e)
+            resource['status'] = 'error'
+            resource['error'] = 'Failed to download resource'
+            had_error = True
+        report_status(True)
+
+
+    payload['status'] = 'success' if not had_error else 'error'
+    if had_error:
+        payload['error'] = 'Failed to download some resources'
+
+    report_status(True)
+
+def on_resources_remove(payload: CommandResourcesRemove):
+    resources = payload['resources']
+
+    had_error = False
+    for resource in resources:
+        try:
+            civitai.remove_resource(resource)
+            resource['status'] = 'success'
+        except Exception as e:
+            civitai.log(e)
+            resource['status'] = 'error'
+            resource['error'] = 'Failed to remove resource'
+            had_error = True
+
+    sio.emit('commandStatus', { 'id': payload['id'], 'type': payload['type'], 'resources': resources, 'status': 'success' if not had_error else 'error' })
+#endregion
+
+#region SocketIO Events
+try:
+    socketio_url = shared.cmd_opts.civitai_link_endpoint
+except:
+    socketio_url = 'https://link.civitai.com'
+
+sio = socketio.Client()
+
+@sio.event
+def connect():
+    civitai.log('Connected to Civitai Link')
+    sio.emit('iam', {"type": "sd"})
+
+@sio.on('command')
+def on_command(payload: Command):
+    command = payload['type']
+    civitai.log(f"command: {payload['type']}")
+    if command == 'resources:list': return on_resources_list(payload)
+    elif command == 'resources:add': return on_resources_add(payload)
+    elif command == 'resources:remove': return on_resources_remove(payload)
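+
+    # Note (added comment): the socket.io payloads arrive here as plain dicts; the
+    # Command* models in civitai/models.py describe the fields these handlers expect.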
+
+
+@sio.on('linkStatus')
+def on_link_status(payload: bool):
+    civitai.connected = payload
+    civitai.log("Civitai Link ready")
+
+@sio.on('upgradeKey')
+def on_upgrade_key(payload: UpgradeKeyPayload):
+    civitai.log("Link Key upgraded")
+    shared.opts.data['civitai_link_key'] = payload['key']
+
+@sio.on('error')
+def on_error(payload: ErrorPayload):
+    civitai.log(f"Error: {payload['msg']}")
+
+@sio.on('joined')
+def on_joined(payload: JoinedPayload):
+    if payload['type'] != 'client': return
+    civitai.log("Client joined")
+#endregion
+
+#region SocketIO Connection Management
+def socketio_connect():
+    sio.connect(socketio_url, socketio_path='/api/socketio')
+
+def join_room(key):
+    def on_join(payload):
+        civitai.log(f"Joined room {key}")
+    sio.emit('join', key, callback=on_join)
+
+def connect_to_civitai(demo: gr.Blocks, app):
+    key = shared.opts.data.get("civitai_link_key", None)
+    if key is None: return
+
+    socketio_connect()
+    join_room(key)
+
+def on_civitai_link_key_changed():
+    if not sio.connected: socketio_connect()
+    key = shared.opts.data.get("civitai_link_key", None)
+    join_room(key)
+#endregion
+
+def on_ui_settings():
+    section = ('civitai_link', "Civitai Link")
+    shared.opts.add_option("civitai_link_key", shared.OptionInfo("", "Your Civitai Link Key", section=section, onchange=on_civitai_link_key_changed))
 
 # Automatically pull model with corresponding hash from Civitai
@@ -56,4 +158,5 @@ def on_infotext_pasted(infotext, params):
         civitai.fetch_model_by_hash(model_hash)
 
 script_callbacks.on_infotext_pasted(on_infotext_pasted)
-# script_callbacks.on_ui_tabs(on_ui_tabs)
\ No newline at end of file
+script_callbacks.on_ui_settings(on_ui_settings)
+script_callbacks.on_app_started(connect_to_civitai)
\ No newline at end of file
diff --git a/style.css b/style.css
index e69de29..77da882 100644
--- a/style.css
+++ b/style.css
@@ -0,0 +1,32 @@
+div.civitai-status{
+    position: absolute;
+    top: 7px;
+    right: 5px;
+    width:24px;
+    height:24px;
+    background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 178 178' style='enable-background:new 0 0 178 178' xml:space='preserve'%3E%3ClinearGradient id='a' gradientUnits='userSpaceOnUse' x1='89.3' y1='1.5' x2='89.3' y2='177.014'%3E%3Cstop offset='0' style='stop-color:%23081692'/%3E%3Cstop offset='1' style='stop-color:%231e043c'/%3E%3C/linearGradient%3E%3ClinearGradient id='b' gradientUnits='userSpaceOnUse' x1='89.3' y1='1.5' x2='89.3' y2='177.014'%3E%3Cstop offset='0' style='stop-color:%231284f7'/%3E%3Cstop offset='1' style='stop-color:%230a20c9'/%3E%3C/linearGradient%3E%3Cpath style='fill:url(%23a)' d='M13.3 45.4v87.7l76 43.9 76-43.9V45.4l-76-43.9z'/%3E%3Cpath style='fill:url(%23b)' d='m89.3 29.2 52 30v60l-52 30-52-30v-60l52-30m0-27.7-76 43.9v87.8l76 43.9 76-43.9V45.4l-76-43.9z'/%3E%3Cpath style='fill:%23fff' d='m104.1 97.2-14.9 8.5-14.9-8.5v-17l14.9-8.5 14.9 8.5h18.2V69.7l-33-19-33 19v38.1l33 19 33-19V97.2z'/%3E%3C/svg%3E");
+}
+
+div.civitai-status:before {
+    width:6px;
+    height:6px;
+    background: red;
+    border-radius: 50%;
+    content: '';
+    position:absolute;
+    top:-4px;
+    right:1px;
+    border: 1px solid rgba(255,255,255,0.3);
+}
+
+/* blinking animation */
+@keyframes blink {
+    0% { opacity: 0.2; }
+    50% { opacity: 1; }
+    100% { opacity: 0.2; }
+}
+
+div.civitai-status.connected:before {
+    background:green;
+    animation: blink 1s ease-in-out infinite;
+}
\ No newline at end of file