Skip to content

Commit

Permalink
So close
Browse files Browse the repository at this point in the history
  • Loading branch information
JustMaier committed Dec 19, 2022
1 parent 3e1554b commit 75fca87
Show file tree
Hide file tree
Showing 6 changed files with 165 additions and 38 deletions.
10 changes: 10 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
__pycache__
/repositories
/venv
/outputs
/log
/webui.settings.bat
/.idea
.vscode
/test/stdout.txt
/test/stderr.txt
91 changes: 70 additions & 21 deletions civitai/api.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,27 @@
import json
import os
import requests
import re

from modules import shared, sd_models, generation_parameters_copypaste

try:
base_url = shared.cmd_options.civitai_endpoint
base_url = shared.cmd_opts.civitai_endpoint
except:
base_url = 'https://civitai.com/api/v1'

#region Utils
def log(message):
"""Log a message to the console."""
print(f'Civitai: {message}')

def parse_hypernetwork(string):
match = re.search(r'(.+)\(([^)]+)', string)
if match:
return {"name": match.group(1), "hash": match.group(2)}
return {"name": "", "hash": ""}
#endregion Utils

#region API
def req(endpoint, method='GET', data=None, params=None, headers=None):
"""Make a request to the Civitai API."""
Expand Down Expand Up @@ -64,36 +77,72 @@ def get_tags(query, page=1, page_size=20):

#region Downloading
download_locations = {
'model': os.path.join('models', 'stable-diffusion'),
'textual_inversion': os.path.join('embeddings'),
'aesthetic_gradient': os.path.join('aesthetic_embeddings'),
'hypernetwork': os.path.join('models', 'hypernetworks'),
'Checkpoint': os.path.join('models', 'stable-diffusion'),
'TextualInversion': os.path.join('embeddings'),
'AestheticGradient': os.path.join('extensions/stable-diffusion-webui-aesthetic-gradients','aesthetic_embeddings'),
'Hypernetwork': os.path.join('models', 'hypernetworks'),
}

def download(url, type):
async def download(url, type):
"""Download a file from the Civitai API using requests and save file to type specific location with the filename from the content disposition header."""
log(f'Downloading {type}: {url}')
response = requests.get(url, stream=True)
filename = response.headers['content-disposition'].split('filename=')[1]
with open(os.path.join(download_locations[type], filename), 'wb') as f:
for chunk in response.iter_content(chunk_size=1024):
if chunk:
f.write(chunk)

# update model list
if type == 'model': sd_models.list_models()
if response.status_code != 200:
raise Exception(f'Error: {response.status_code}')

filename = response.headers['content-disposition'].split('filename=')[1].strip('"')
dest = os.path.join(download_locations[type], filename)

if os.path.exists(dest):
log(f'File already exists: {dest}')
return (filename, False)

with open(dest, 'wb') as f:
for chunk in response.iter_content(chunk_size=4096):
if chunk: f.write(chunk)

log(f'Downloaded: {dest}')

return (filename, True)

async def run_model(name, url):
model = sd_models.get_closet_checkpoint_match(name)

if model is None:
(filename, downloaded) = await download(url, 'Checkpoint')
if downloaded: sd_models.list_models()
model = sd_models.get_closet_checkpoint_match(filename)
elif shared.opts.sd_model_checkpoint == model.title:
log('Model already loaded')
return model.filename
else:
filename = model.filename
log('Found model in model list')

if model is not None:
sd_models.load_model(model)
shared.opts.sd_model_checkpoint = model.title
shared.opts.save(shared.config_filename)
else: log('Could not find model in model list')

return filename

def download_model(url):
return download(url, 'model')

def download_textual_inversion(url):
return download(url, 'textual_inversion')
async def download_textual_inversion(url):
(filename) = await download(url, 'TextualInversion')
return filename

def download_aesthetic_gradient(url):
return download(url, 'aesthetic_gradient')
async def download_aesthetic_gradient(url):
(filename, downloaded) = await download(url, 'AestheticGradient')

def download_hypernetwork(url):
return download(url, 'hypernetwork')
return filename

async def download_hypernetwork(url):
(filename, downloaded) = await download(url, 'Hypernetwork')
shared.opts.sd_hypernetwork = filename
shared.reload_hypernetworks()

return filename

#endregion Downloading
50 changes: 50 additions & 0 deletions javascript/civitai.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
// #region [utils]
const log = (...args) => console.log(`[civitai]`, ...args);
const delay = (ms) => new Promise(resolve => setTimeout(resolve, ms));
const getElement = (selector, timeout = 10000) => new Promise((resolve, reject) => {
const interval = setInterval(() => {
const el = gradioApp().querySelector(selector);
timeout -= 100;
if (timeout < 0) {
reject('timeout');
clearInterval(interval);
} else if (el) {
resolve(el);
clearInterval(interval);
}
}, 100);
})
// #endregion


async function generate() {
const generateButton = await getElement('#txt2img_generate');
generateButton.click();
log('generating image');
}

async function handlePrompt(encodedPrompt, andGenerate = false) {
const prompt = atob(encodedPrompt);
log('injecting prompt', prompt);
const promptEl = await getElement('#txt2img_prompt textarea');
promptEl.value = prompt;

const pastePromptButton = await getElement('#paste');
pastePromptButton.click();
log('applying prompt');

if (andGenerate) {
await delay(3000);
await generate();
}
}

// Bootstrap
(async () => {
const searchParams = new URLSearchParams(location.search);
if (searchParams.has('civitai_prompt'))
handlePrompt(searchParams.get('civitai_prompt'), searchParams.has('civitai_generate'));

// clear search params
history.replaceState({}, document.title, location.pathname);
})()
3 changes: 1 addition & 2 deletions preload.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,4 @@
import argparse

def preload(parser: argparse.ArgumentParser):
parser.add_argument("--civitai-endpoint", type=str, help="Endpoint for interacting with a Civitai instance", default="https://civitai.com/api/v1")
parser.add_argument("--civitai-api", type=bool, help="Accept requests to install models", default=False)
parser.add_argument("--civitai-endpoint", type=str, help="Endpoint for interacting with a Civitai instance", default="https://civitai.com/api/v1")
47 changes: 33 additions & 14 deletions scripts/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,24 +4,43 @@
from fastapi import FastAPI

from modules import shared, script_callbacks as script_callbacks
from modules.sd_models import checkpoints_list

import extensions.sd_civitai_extension.civitai.api as civitai

def civitaiAPI(demo: gr.Blocks, app: FastAPI):
@app.get("/install/{id}")
async def install(id: str):
to_install = civitai.get_model_version(id)
print("Civitai Installing: " + to_install['name'])
url = to_install['downloadUrl']
type = to_install['type']
task = asyncio.create_task(civitai.download(url, type))
# To detect if the API is loaded
@app.get("/civitai/v1")
async def index():
return {"status": "success"}

try:
api_enabled = shared.cmd_options.civitai_api
except:
api_enabled = False
# To get a list of models
@app.get("/civitai/v1/models")
async def models():
return [{"name":x.model_name, "hash":x.hash} for x in checkpoints_list.values()]

if api_enabled:
script_callbacks.on_app_started(civitaiAPI)
print("Civitai API loaded")
# To get a list of hypernetworks
@app.get("/civitai/v1/hypernetworks")
async def hypernetworks():
return [civitai.parse_hypernetwork(name) for name in shared.hypernetworks]

# To download and select a model
@app.post("/civitai/v1/run/{id}")
async def run(id: str):
to_run = civitai.get_model_version(id)
to_run_name = f'{to_run["model"]["name"]} {to_run["name"]}'
civitai.log(f'Running: {to_run_name}')
primary_file = [x for x in to_run['files'] if x['primary'] == True][0]
name = primary_file['name']
url = to_run['downloadUrl']
type = to_run['model']['type']
if type == 'Checkpoint': await asyncio.create_task(civitai.run_model(name, url))
elif type == 'TextualInversion': await asyncio.create_task(civitai.download_textual_inversion(url))
elif type == 'AestheticGradient': await asyncio.create_task(civitai.download_aesthetic_gradient(url))
elif type == 'Hypernetwork': await asyncio.create_task(civitai.download_hypernetwork(url))

civitai.log(f'Loaded: {to_run_name}')
return {"status": "success"}

script_callbacks.on_app_started(civitaiAPI)
civitai.log("API loaded")
2 changes: 1 addition & 1 deletion scripts/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,4 @@ def on_ui_tabs():
civitai_page_size.change(fn=search_models, inputs=[civitai_query, civitai_sort, civitai_sort_period, civitai_tag, civitai_creator, civitai_page, civitai_page_size], outputs=[model_output, civitai_current_page])


script_callbacks.on_ui_tabs(on_ui_tabs)
# script_callbacks.on_ui_tabs(on_ui_tabs)

0 comments on commit 75fca87

Please sign in to comment.