diff --git a/src/Core/Settings.cs b/src/Core/Settings.cs index 05d7376f..34c0791b 100644 --- a/src/Core/Settings.cs +++ b/src/Core/Settings.cs @@ -273,6 +273,9 @@ public class FileFormatData : AutoConfiguration [ConfigComment("If true, folders will be discarded from starred image paths.")] public bool StarNoFolders = false; + [ConfigComment("Whether to automatically use the base model type as part of the model path when downloading using the 'Model Download' tool")] + public bool GroupDownloadedModelsByBaseType = false; + public class ThemesImpl : SettingsOptionsAttribute.AbstractImpl { public override string[] GetOptions => [.. Program.Web.RegisteredThemes.Keys]; diff --git a/src/WebAPI/ModelsAPI.cs b/src/WebAPI/ModelsAPI.cs index ceeb00e6..04cb0fc4 100644 --- a/src/WebAPI/ModelsAPI.cs +++ b/src/WebAPI/ModelsAPI.cs @@ -505,8 +505,23 @@ public static async Task DoModelDownloadWS(Session session, WebSocket w } try { - string outPath = $"{handler.FolderPaths[0]}/{name}.safetensors"; - if (File.Exists(outPath)) + string baseModel = ""; + if (!string.IsNullOrWhiteSpace(metadata)) + { + JObject metadataObj = JObject.Parse(metadata); + baseModel = metadataObj["modelspec.baseModel"]?.ToString() ?? ""; + } + string modelOutPath; + if (!string.IsNullOrWhiteSpace(baseModel) && session.User.Settings.GroupDownloadedModelsByBaseType) + { + modelOutPath = $"{handler.FolderPaths[0]}/{baseModel}/{name}.safetensors"; + } + else + { + modelOutPath = $"{handler.FolderPaths[0]}/{name}.safetensors"; + } + modelOutPath = Utilities.StrictFilenameClean(modelOutPath); + if (File.Exists(modelOutPath)) { await ws.SendJson(new JObject() { ["error"] = "Model at that save path already exists." }, API.WebsocketTimeout); return null; @@ -516,7 +531,7 @@ public static async Task DoModelDownloadWS(Session session, WebSocket w { File.Delete(tempPath); } - Directory.CreateDirectory(Path.GetDirectoryName(outPath)); + Directory.CreateDirectory(Path.GetDirectoryName(modelOutPath)); using CancellationTokenSource canceller = new(); Task downloading = Utilities.DownloadFile(url, tempPath, (progress, total, perSec) => { @@ -556,10 +571,20 @@ public static async Task DoModelDownloadWS(Session session, WebSocket w } }); await downloading; - File.Move(tempPath, outPath); + File.Move(tempPath, modelOutPath); if (!string.IsNullOrWhiteSpace(metadata)) { - File.WriteAllText($"{handler.FolderPaths[0]}/{name}.json", metadata); + string metadataOutPath; + if (!string.IsNullOrWhiteSpace(baseModel) && session.User.Settings.GroupDownloadedModelsByBaseType) + { + metadataOutPath = $"{handler.FolderPaths[0]}/{baseModel}/{name}.json"; + } + else + { + metadataOutPath = $"{handler.FolderPaths[0]}/{name}.json"; + } + metadataOutPath = Utilities.StrictFilenameClean(metadataOutPath); + File.WriteAllText(metadataOutPath, metadata); } await ws.SendJson(new JObject() { ["success"] = true }, API.WebsocketTimeout); } diff --git a/src/wwwroot/js/genpage/utiltab.js b/src/wwwroot/js/genpage/utiltab.js index 7ba54bf9..95496104 100644 --- a/src/wwwroot/js/genpage/utiltab.js +++ b/src/wwwroot/js/genpage/utiltab.js @@ -259,6 +259,7 @@ class ModelDownloaderUtil { 'modelspec.author': rawData.creator.username, 'modelspec.description': `From ${url}\n${rawVersion.description || ''}\n${rawData.description}\n`, 'modelspec.date': rawVersion.createdAt, + 'modelspec.baseModel': rawVersion.baseModel, }; if (rawVersion.trainedWords) { metadata['modelspec.trigger_phrase'] = rawVersion.trainedWords.join(", ");