Skip to content

Commit

Permalink
fixes for civitai downloading and model importing
Browse files Browse the repository at this point in the history
  • Loading branch information
w4ffl35 committed Dec 20, 2023
1 parent 5e0cc76 commit c1c4bfe
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 24 deletions.
9 changes: 8 additions & 1 deletion src/airunner/aihandler/download_civitai.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import tqdm
import requests
from json.decoder import JSONDecodeError
from airunner.aihandler.logger import Logger


class DownloadCivitAI:
Expand All @@ -9,7 +11,12 @@ class DownloadCivitAI:
def get_json(model_id):
url = f"https://civitai.com/api/v1/models/{model_id}"
response = requests.get(url)
json = response.json()

try:
json = response.json()
except JSONDecodeError:
Logger.error(f"Failed to decode JSON from {url}")
print(response)
return json

def download_model(self, url, file_name, size_kb, callback):
Expand Down
113 changes: 91 additions & 22 deletions src/airunner/widgets/model_manager/import_widget.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import threading

from airunner.data.models import AIModel
from airunner.data.models import AIModel, Lora, Embedding
from airunner.utils import get_session
from airunner.widgets.base_widget import BaseWidget
from airunner.widgets.model_manager.templates.import_ui import Ui_import_model_widget
Expand Down Expand Up @@ -62,38 +62,81 @@ def download_model(self):
diffuser_model_version = model_version["baseModel"]
pipeline_class = self.settings_manager.get_pipeline_classname(pipeline_action, diffuser_model_version, category)
diffuser_model_versions = self.settings_manager.model_versions
file_path = self.download_path(file, diffuser_model_version) # path is the download path of the model
model_type = model_data["type"]
file_path = self.download_path(file, diffuser_model_version, pipeline_action, model_type) # path is the download path of the model

print("Name", name)
print("Path", file_path)
print("Branch", "main")
print("Version", diffuser_model_version)
print("Category", category)
print("Pipeline Action", pipeline_action)

trained_words = model_version.get("trained_words", [])
trained_words = ",".join(trained_words)

session = get_session()
model_exists = session.query(AIModel).filter_by(
name=name,
path=file_path,
branch="main",
version=diffuser_model_version,
category=category,
pipeline_action=pipeline_action,
).first()
if not model_exists:
new_model = AIModel(
if model_type == "Checkpoint":
model_exists = session.query(AIModel).filter_by(
name=name,
path=file_path,
branch="main",
version=diffuser_model_version,
category=category,
pipeline_action=pipeline_action,
enabled=True,
is_default=False
)
session.add(new_model)
session.commit()
).first()
if not model_exists:
new_model = AIModel(
name=name,
path=file_path,
branch="main",
version=diffuser_model_version,
category=category,
pipeline_action=pipeline_action,
enabled=True,
is_default=False
)
session.add(new_model)
session.commit()
elif model_type == "LORA":
lora_exists = session.query(Lora).filter_by(
name=name,
path=file_path,
).first()
if not lora_exists:
new_lora = Lora(
name=name,
path=file_path,
scale=1,
enabled=True,
loaded=False,
trigger_word=trained_words,
)
session.add(new_lora)
session.commit()
elif model_type == "TextualInversion":
embedding_exists = session.query(Embedding).filter_by(
name=name,
path=file_path,
).first()
if not embedding_exists:
new_embedding = Embedding(
name=name,
path=file_path,
enabled=True,
loaded=False,
trigger_word=trained_words,
)
session.add(new_embedding)
session.commit()
elif model_type == "VAE":
# todo save vae here
pass
elif model_type == "Controlnet":
# todo save controlnet here
pass
elif model_type == "Poses":
# todo save poses here
pass

print("starting download")
self.download_model_thread(download_url, file_path, size_kb)
Expand Down Expand Up @@ -163,8 +206,33 @@ def import_models(self):
def model_version_changed(self, index):
self.set_model_form_data()

def download_path(self, file, version):
path = self.settings_manager.path_settings.model_base_path
def download_path(self, file, version, pipeline_action, model_type):

if model_type == "LORA":
path = self.settings_manager.path_settings.lora_model_path
elif model_type == "Checkpoint":
if pipeline_action == "txt2img":
path = self.settings_manager.path_settings.txt2img_model_path
elif pipeline_action == "outpaint":
path = self.settings_manager.path_settings.outpaint_model_path
elif pipeline_action == "upscale":
path = self.settings_manager.path_settings.upscale_model_path
elif pipeline_action == "depth2img":
path = self.settings_manager.path_settings.depth2img_model_path
elif pipeline_action == "pix2pix":
path = self.settings_manager.path_settings.pix2pix_model_path
elif model_type == "TextualInversion":
path = self.settings_manager.path_settings.embeddings_model_path
elif model_type == "VAE":
# todo save vae here
pass
elif model_type == "Controlnet":
# todo save controlnet here
pass
elif model_type == "Poses":
# todo save poses here
pass

file_name = file["name"]
return f"{path}/{version}/{file_name}"

Expand All @@ -184,7 +252,7 @@ def set_model_form_data(self):
diffuser_model_version = model_version["baseModel"]
pipeline_class = self.settings_manager.get_pipeline_classname(pipeline_action, diffuser_model_version, category)
diffuser_model_versions = self.settings_manager.model_versions
path = self.download_path(file, diffuser_model_version) # path is the download path of the model
path = self.download_path(file, diffuser_model_version, pipeline_action, self.current_model_data["type"]) # path is the download path of the model

self.ui.model_form.set_model_form_data(
categories,
Expand All @@ -196,7 +264,8 @@ def set_model_form_data(self):
diffuser_model_version,
path,
self.current_model_data["name"],
model_data=self.current_model_data
model_data=self.current_model_data,
model_type=self.current_model_data["type"]
)

if self.is_civitai:
Expand Down
8 changes: 7 additions & 1 deletion src/airunner/widgets/model_manager/model_form_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ def set_model_form_data(
diffuser_model_version,
path,
model_name,
model_data
model_data,
model_type
):
self.ui.category.clear()
self.ui.category.addItems(categories)
Expand All @@ -30,10 +31,15 @@ def set_model_form_data(
self.ui.diffuser_model_version.clear()
self.ui.diffuser_model_version.addItems(diffuser_model_versions)
self.ui.diffuser_model_version.setCurrentText(diffuser_model_version)
self.ui.model_type.clear()
self.ui.model_type.addItems(["Checkpoint", "LORA", "Embedding", "VAE", "Controlnet", "Pose"])
self.ui.pipeline_class_line_edit.setText(pipeline_class)
self.ui.enabled.setChecked(True)
self.ui.path_line_edit.setText(path)

# set current model type
self.ui.model_type.setCurrentText(model_type)

# clear the table
self.ui.model_data_table.clearContents()
self.ui.model_data_table.setRowCount(5)
Expand Down

0 comments on commit c1c4bfe

Please sign in to comment.