diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f83ac57 --- /dev/null +++ b/.gitignore @@ -0,0 +1,10 @@ +__pycache__ +/repositories +/venv +/outputs +/log +/webui.settings.bat +/.idea +.vscode +/test/stdout.txt +/test/stderr.txt diff --git a/civitai/api.py b/civitai/api.py index f515bec..01c0423 100644 --- a/civitai/api.py +++ b/civitai/api.py @@ -1,14 +1,27 @@ import json import os import requests +import re from modules import shared, sd_models, generation_parameters_copypaste try: - base_url = shared.cmd_options.civitai_endpoint + base_url = shared.cmd_opts.civitai_endpoint except: base_url = 'https://civitai.com/api/v1' +#region Utils +def log(message): + """Log a message to the console.""" + print(f'Civitai: {message}') + +def parse_hypernetwork(string): + match = re.search(r'(.+)\(([^)]+)', string) + if match: + return {"name": match.group(1), "hash": match.group(2)} + return {"name": "", "hash": ""} +#endregion Utils + #region API def req(endpoint, method='GET', data=None, params=None, headers=None): """Make a request to the Civitai API.""" @@ -64,36 +77,72 @@ def get_tags(query, page=1, page_size=20): #region Downloading download_locations = { - 'model': os.path.join('models', 'stable-diffusion'), - 'textual_inversion': os.path.join('embeddings'), - 'aesthetic_gradient': os.path.join('aesthetic_embeddings'), - 'hypernetwork': os.path.join('models', 'hypernetworks'), + 'Checkpoint': os.path.join('models', 'stable-diffusion'), + 'TextualInversion': os.path.join('embeddings'), + 'AestheticGradient': os.path.join('extensions/stable-diffusion-webui-aesthetic-gradients','aesthetic_embeddings'), + 'Hypernetwork': os.path.join('models', 'hypernetworks'), } -def download(url, type): +async def download(url, type): """Download a file from the Civitai API using requests and save file to type specific location with the filename from the content disposition header.""" + log(f'Downloading {type}: {url}') response = requests.get(url, stream=True) - filename = response.headers['content-disposition'].split('filename=')[1] - with open(os.path.join(download_locations[type], filename), 'wb') as f: - for chunk in response.iter_content(chunk_size=1024): - if chunk: - f.write(chunk) - # update model list - if type == 'model': sd_models.list_models() + if response.status_code != 200: + raise Exception(f'Error: {response.status_code}') + + filename = response.headers['content-disposition'].split('filename=')[1].strip('"') + dest = os.path.join(download_locations[type], filename) + + if os.path.exists(dest): + log(f'File already exists: {dest}') + return (filename, False) + + with open(dest, 'wb') as f: + for chunk in response.iter_content(chunk_size=4096): + if chunk: f.write(chunk) + + log(f'Downloaded: {dest}') + + return (filename, True) + +async def run_model(name, url): + model = sd_models.get_closet_checkpoint_match(name) + + if model is None: + (filename, downloaded) = await download(url, 'Checkpoint') + if downloaded: sd_models.list_models() + model = sd_models.get_closet_checkpoint_match(filename) + elif shared.opts.sd_model_checkpoint == model.title: + log('Model already loaded') + return model.filename + else: + filename = model.filename + log('Found model in model list') + + if model is not None: + sd_models.load_model(model) + shared.opts.sd_model_checkpoint = model.title + shared.opts.save(shared.config_filename) + else: log('Could not find model in model list') return filename -def download_model(url): - return download(url, 'model') -def download_textual_inversion(url): - return download(url, 'textual_inversion') +async def download_textual_inversion(url): + (filename) = await download(url, 'TextualInversion') + return filename -def download_aesthetic_gradient(url): - return download(url, 'aesthetic_gradient') +async def download_aesthetic_gradient(url): + (filename, downloaded) = await download(url, 'AestheticGradient') -def download_hypernetwork(url): - return download(url, 'hypernetwork') + return filename + +async def download_hypernetwork(url): + (filename, downloaded) = await download(url, 'Hypernetwork') + shared.opts.sd_hypernetwork = filename + shared.reload_hypernetworks() + + return filename #endregion Downloading diff --git a/javascript/civitai.js b/javascript/civitai.js new file mode 100644 index 0000000..d3c1445 --- /dev/null +++ b/javascript/civitai.js @@ -0,0 +1,50 @@ +// #region [utils] +const log = (...args) => console.log(`[civitai]`, ...args); +const delay = (ms) => new Promise(resolve => setTimeout(resolve, ms)); +const getElement = (selector, timeout = 10000) => new Promise((resolve, reject) => { + const interval = setInterval(() => { + const el = gradioApp().querySelector(selector); + timeout -= 100; + if (timeout < 0) { + reject('timeout'); + clearInterval(interval); + } else if (el) { + resolve(el); + clearInterval(interval); + } + }, 100); +}) +// #endregion + + +async function generate() { + const generateButton = await getElement('#txt2img_generate'); + generateButton.click(); + log('generating image'); +} + +async function handlePrompt(encodedPrompt, andGenerate = false) { + const prompt = atob(encodedPrompt); + log('injecting prompt', prompt); + const promptEl = await getElement('#txt2img_prompt textarea'); + promptEl.value = prompt; + + const pastePromptButton = await getElement('#paste'); + pastePromptButton.click(); + log('applying prompt'); + + if (andGenerate) { + await delay(3000); + await generate(); + } +} + +// Bootstrap +(async () => { + const searchParams = new URLSearchParams(location.search); + if (searchParams.has('civitai_prompt')) + handlePrompt(searchParams.get('civitai_prompt'), searchParams.has('civitai_generate')); + + // clear search params + history.replaceState({}, document.title, location.pathname); +})() \ No newline at end of file diff --git a/preload.py b/preload.py index 124779f..474e755 100644 --- a/preload.py +++ b/preload.py @@ -2,5 +2,4 @@ import argparse def preload(parser: argparse.ArgumentParser): - parser.add_argument("--civitai-endpoint", type=str, help="Endpoint for interacting with a Civitai instance", default="https://civitai.com/api/v1") - parser.add_argument("--civitai-api", type=bool, help="Accept requests to install models", default=False) \ No newline at end of file + parser.add_argument("--civitai-endpoint", type=str, help="Endpoint for interacting with a Civitai instance", default="https://civitai.com/api/v1") \ No newline at end of file diff --git a/scripts/api.py b/scripts/api.py index 900f0d8..40e2ef0 100644 --- a/scripts/api.py +++ b/scripts/api.py @@ -4,24 +4,43 @@ from fastapi import FastAPI from modules import shared, script_callbacks as script_callbacks +from modules.sd_models import checkpoints_list import extensions.sd_civitai_extension.civitai.api as civitai def civitaiAPI(demo: gr.Blocks, app: FastAPI): - @app.get("/install/{id}") - async def install(id: str): - to_install = civitai.get_model_version(id) - print("Civitai Installing: " + to_install['name']) - url = to_install['downloadUrl'] - type = to_install['type'] - task = asyncio.create_task(civitai.download(url, type)) + # To detect if the API is loaded + @app.get("/civitai/v1") + async def index(): return {"status": "success"} -try: - api_enabled = shared.cmd_options.civitai_api -except: - api_enabled = False + # To get a list of models + @app.get("/civitai/v1/models") + async def models(): + return [{"name":x.model_name, "hash":x.hash} for x in checkpoints_list.values()] -if api_enabled: - script_callbacks.on_app_started(civitaiAPI) - print("Civitai API loaded") \ No newline at end of file + # To get a list of hypernetworks + @app.get("/civitai/v1/hypernetworks") + async def hypernetworks(): + return [civitai.parse_hypernetwork(name) for name in shared.hypernetworks] + + # To download and select a model + @app.post("/civitai/v1/run/{id}") + async def run(id: str): + to_run = civitai.get_model_version(id) + to_run_name = f'{to_run["model"]["name"]} {to_run["name"]}' + civitai.log(f'Running: {to_run_name}') + primary_file = [x for x in to_run['files'] if x['primary'] == True][0] + name = primary_file['name'] + url = to_run['downloadUrl'] + type = to_run['model']['type'] + if type == 'Checkpoint': await asyncio.create_task(civitai.run_model(name, url)) + elif type == 'TextualInversion': await asyncio.create_task(civitai.download_textual_inversion(url)) + elif type == 'AestheticGradient': await asyncio.create_task(civitai.download_aesthetic_gradient(url)) + elif type == 'Hypernetwork': await asyncio.create_task(civitai.download_hypernetwork(url)) + + civitai.log(f'Loaded: {to_run_name}') + return {"status": "success"} + +script_callbacks.on_app_started(civitaiAPI) +civitai.log("API loaded") diff --git a/scripts/main.py b/scripts/main.py index df42150..4b084f8 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -46,4 +46,4 @@ def on_ui_tabs(): civitai_page_size.change(fn=search_models, inputs=[civitai_query, civitai_sort, civitai_sort_period, civitai_tag, civitai_creator, civitai_page, civitai_page_size], outputs=[model_output, civitai_current_page]) -script_callbacks.on_ui_tabs(on_ui_tabs) \ No newline at end of file +# script_callbacks.on_ui_tabs(on_ui_tabs) \ No newline at end of file