diff --git a/README.md b/README.md index 2de447d5..692c13db 100644 --- a/README.md +++ b/README.md @@ -31,6 +31,7 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in - Faster and memory efficient latent decoding with [TAESD](https://github.com/madebyollin/taesd) - Upscale images generated with [ESRGAN](https://github.com/xinntao/Real-ESRGAN) - VAE tiling processing for reduce memory usage +- Textual Inversion support (embeddings) - Sampling method - `Euler A` - `Euler` @@ -53,9 +54,7 @@ Inference of [Stable Diffusion](https://github.com/CompVis/stable-diffusion) in - [ ] More sampling methods - [ ] Make inference faster - The current implementation of ggml_conv_2d is slow and has high memory usage - - Implement Winograd Convolution 2D for 3x3 kernel filtering - [ ] Continuing to reduce memory usage (quantizing the weights of ggml_conv_2d) -- [ ] Implement Textual Inversion (embeddings) - [ ] Implement Inpainting support - [ ] k-quants support diff --git a/examples/cli/main.cpp b/examples/cli/main.cpp index 6264d6e2..0788aeec 100644 --- a/examples/cli/main.cpp +++ b/examples/cli/main.cpp @@ -60,6 +60,7 @@ struct SDParams { std::string vae_path; std::string taesd_path; std::string esrgan_path; + std::string embeddings_path; ggml_type wtype = GGML_TYPE_COUNT; std::string lora_model_dir; std::string output_path = "output.png"; @@ -121,6 +122,7 @@ void print_usage(int argc, const char* argv[]) { printf(" -m, --model [MODEL] path to model\n"); printf(" --vae [VAE] path to vae\n"); printf(" --taesd [TAESD_PATH] path to taesd. Using Tiny AutoEncoder for fast decoding (low quality)\n"); + printf(" --embd-dir [EMBEDDING_PATH] path to directory with Textual Inversion embeddings\n"); printf(" --upscale-model [ESRGAN_PATH] path to esrgan model. Upscale images after generate, just RealESRGAN_x4plus_anime_6B supported by now.\n"); printf(" --type [TYPE] weight type (f32, f16, q4_0, q4_1, q5_0, q5_1, q8_0)\n"); printf(" If not specified, the default is the type of the weight file.\n"); @@ -201,6 +203,12 @@ void parse_args(int argc, const char** argv, SDParams& params) { break; } params.esrgan_path = argv[i]; + } else if (arg == "--embd-dir") { + if (++i >= argc) { + invalid_arg = true; + break; + } + params.embeddings_path = argv[i]; } else if (arg == "--type") { if (++i >= argc) { invalid_arg = true; break; } @@ -484,7 +492,7 @@ int main(int argc, const char* argv[]) { StableDiffusion sd(params.n_threads, vae_decode_only, params.taesd_path, params.esrgan_path, true, params.vae_tiling, params.lora_model_dir, params.rng_type); - if (!sd.load_from_file(params.model_path, params.vae_path, params.wtype, params.schedule, params.clip_skip)) { + if (!sd.load_from_file(params.model_path, params.vae_path, params.embeddings_path, params.wtype, params.schedule, params.clip_skip)) { return 1; } diff --git a/stable-diffusion.cpp b/stable-diffusion.cpp index 4386e9e5..c4a812ec 100644 --- a/stable-diffusion.cpp +++ b/stable-diffusion.cpp @@ -79,11 +79,22 @@ std::string sd_get_system_info() { return ss.str(); } -static void ggml_log_callback_default(ggml_log_level level, const char* text, void* user_data) { - (void)level; - (void)user_data; - fputs(text, stderr); - fflush(stderr); +std::string ltrim(const std::string& s) { + auto it = std::find_if(s.begin(), s.end(), [](unsigned char ch) { + return !std::isspace(ch); + }); + return std::string(it, s.end()); +} + +std::string rtrim(const std::string& s) { + auto it = std::find_if(s.rbegin(), s.rend(), [](unsigned char ch) { + return !std::isspace(ch); + }); + return std::string(s.begin(), it.base());
+} + +std::string trim(const std::string& s) { + return rtrim(ltrim(s)); } void ggml_tensor_set_f32_randn(struct ggml_tensor* tensor, std::shared_ptr<RNG> rng) { @@ -608,6 +619,15 @@ std::pair<std::unordered_map<std::string, float>, std::string> extract_and_remov return std::make_pair(filename2multiplier, text); } +void ggml_backend_tensor_set_and_sync(ggml_backend_t backend, struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { + #ifdef SD_USE_CUBLAS + ggml_backend_tensor_set_async(backend, tensor, data, offset, size); + ggml_backend_synchronize(backend); + #else + ggml_backend_tensor_set(tensor, data, offset, size); + #endif +} + void ggml_backend_tensor_get_and_sync(ggml_backend_t backend, const struct ggml_tensor * tensor, void * data, size_t offset, size_t size) { #ifdef SD_USE_CUBLAS ggml_backend_tensor_get_async(backend, tensor, data, offset, size); @@ -655,242 +675,6 @@ std::vector<std::pair<int, std::u32string>> bytes_to_unicode() { return byte_unicode_pairs; } -// Ref: https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py -class CLIPTokenizer { -private: - SDVersion version = VERSION_1_x; - std::map<int, std::u32string> byte_encoder; - std::map<std::u32string, int> encoder; - std::map<std::pair<std::u32string, std::u32string>, int> bpe_ranks; - std::regex pat; - - static std::string strip(const std::string& str) { - std::string::size_type start = str.find_first_not_of(" \t\n\r\v\f"); - std::string::size_type end = str.find_last_not_of(" \t\n\r\v\f"); - - if (start == std::string::npos) { - // String contains only whitespace characters - return ""; - } - - return str.substr(start, end - start + 1); - } - - static std::string whitespace_clean(std::string text) { - text = std::regex_replace(text, std::regex(R"(\s+)"), " "); - text = strip(text); - return text; - } - - static std::set<std::pair<std::u32string, std::u32string>> get_pairs(const std::vector<std::u32string>& subwords) { - std::set<std::pair<std::u32string, std::u32string>> pairs; - if (subwords.size() == 0) { - return pairs; - } - std::u32string prev_subword = subwords[0]; - for (int i = 1; i < subwords.size(); i++) { - std::u32string subword = subwords[i]; - std::pair<std::u32string, std::u32string> pair(prev_subword, subword); - pairs.insert(pair); - prev_subword = subword; - } - return pairs; - } - -public: - CLIPTokenizer(SDVersion version = VERSION_1_x) - : version(version) {} - - void load_from_merges(const std::string& merges_utf8_str) { - auto byte_unicode_pairs = bytes_to_unicode(); - byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end()); - // for (auto & pair: byte_unicode_pairs) { - // std::cout << pair.first << ": " << pair.second << std::endl; - // } - std::vector<std::u32string> merges; - size_t start = 0; - size_t pos; - std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str); - while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) { - merges.push_back(merges_utf32_str.substr(start, pos - start)); - start = pos + 1; - } - // LOG_DEBUG("merges size %llu", merges.size()); - GGML_ASSERT(merges.size() == 48895); - merges = std::vector<std::u32string>(merges.begin() + 1, merges.end()); - std::vector<std::pair<std::u32string, std::u32string>> merge_pairs; - for (const auto& merge : merges) { - size_t space_pos = merge.find(' '); - merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1)); - // LOG_DEBUG("%s", utf32_to_utf8(merge.substr(space_pos + 1)).c_str()); - } - std::vector<std::u32string> vocab; - for (const auto& pair : byte_unicode_pairs) { - vocab.push_back(pair.second); - } - for (const auto& pair : byte_unicode_pairs) { - vocab.push_back(pair.second + utf8_to_utf32("</w>")); - } - for (const auto& merge : merge_pairs) { - vocab.push_back(merge.first + merge.second); - } - vocab.push_back(utf8_to_utf32("<|startoftext|>")); - vocab.push_back(utf8_to_utf32("<|endoftext|>")); -
LOG_DEBUG("vocab size: %llu", vocab.size()); - int i = 0; - for (const auto& token : vocab) { - encoder[token] = i++; - } - - int rank = 0; - for (const auto& merge : merge_pairs) { - bpe_ranks[merge] = rank++; - } - }; - - std::u32string bpe(const std::u32string& token) { - std::vector word; - - for (int i = 0; i < token.size() - 1; i++) { - word.emplace_back(1, token[i]); - } - word.push_back(token.substr(token.size() - 1) + utf8_to_utf32("")); - - std::set> pairs = get_pairs(word); - - if (pairs.empty()) { - return token + utf8_to_utf32(""); - } - - while (true) { - auto min_pair_iter = std::min_element(pairs.begin(), - pairs.end(), - [&](const std::pair& a, - const std::pair& b) { - if (bpe_ranks.find(a) == bpe_ranks.end()) { - return false; - } else if (bpe_ranks.find(b) == bpe_ranks.end()) { - return true; - } - return bpe_ranks.at(a) < bpe_ranks.at(b); - }); - - const std::pair& bigram = *min_pair_iter; - - if (bpe_ranks.find(bigram) == bpe_ranks.end()) { - break; - } - - std::u32string first = bigram.first; - std::u32string second = bigram.second; - std::vector new_word; - int32_t i = 0; - - while (i < word.size()) { - auto it = std::find(word.begin() + i, word.end(), first); - if (it == word.end()) { - new_word.insert(new_word.end(), word.begin() + i, word.end()); - break; - } - new_word.insert(new_word.end(), word.begin() + i, it); - i = static_cast(std::distance(word.begin(), it)); - - if (word[i] == first && i < static_cast(word.size()) - 1 && word[i + 1] == second) { - new_word.push_back(first + second); - i += 2; - } else { - new_word.push_back(word[i]); - i += 1; - } - } - - word = new_word; - - if (word.size() == 1) { - break; - } - pairs = get_pairs(word); - } - - std::u32string result; - for (int i = 0; i < word.size(); i++) { - result += word[i]; - if (i != word.size() - 1) { - result += utf8_to_utf32(" "); - } - } - - return result; - } - - std::vector tokenize(std::string text, size_t max_length = 0, bool padding = false) { - std::vector tokens = encode(text); - tokens.insert(tokens.begin(), BOS_TOKEN_ID); - if (max_length > 0) { - if (tokens.size() > max_length - 1) { - tokens.resize(max_length - 1); - tokens.push_back(EOS_TOKEN_ID); - } else { - tokens.push_back(EOS_TOKEN_ID); - if (padding) { - int pad_token_id = PAD_TOKEN_ID; - if (version == VERSION_2_x) { - pad_token_id = 0; - } - tokens.insert(tokens.end(), max_length - tokens.size(), pad_token_id); - } - } - } - return tokens; - } - - std::vector encode(std::string text) { - std::string original_text = text; - std::vector bpe_tokens; - text = whitespace_clean(text); - std::transform(text.begin(), text.end(), text.begin(), [](unsigned char c) { return std::tolower(c); }); - - std::regex pat(R"(<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[[:alpha:]]+|[[:digit:]]|[^[:space:][:alpha:][:digit:]]+)", - std::regex::icase); - - std::smatch matches; - std::string str = text; - std::vector token_strs; - while (std::regex_search(str, matches, pat)) { - for (auto& token : matches) { - std::string token_str = token.str(); - std::u32string utf32_token; - for (int i = 0; i < token_str.length(); i++) { - char b = token_str[i]; - utf32_token += byte_encoder[b]; - } - auto bpe_strs = bpe(utf32_token); - size_t start = 0; - size_t pos; - while ((pos = bpe_strs.find(' ', start)) != std::u32string::npos) { - auto bpe_str = bpe_strs.substr(start, pos - start); - bpe_tokens.push_back(encoder[bpe_str]); - token_strs.push_back(utf32_to_utf8(bpe_str)); - - start = pos + 1; - } - auto bpe_str = bpe_strs.substr(start, 
bpe_strs.size() - start); - bpe_tokens.push_back(encoder[bpe_str]); - token_strs.push_back(utf32_to_utf8(bpe_str)); - } - str = matches.suffix(); - } - std::stringstream ss; - ss << "["; - for (auto token : token_strs) { - ss << "\"" << token << "\", "; - } - ss << "]"; - LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str()); - return bpe_tokens; - } -}; - // Ref: https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/cad87bf4e3e0b0a759afa94e933527c3123d59bc/modules/prompt_parser.py#L345 // // Parses a string with attention tokens and returns a list of pairs: text and its associated weight. @@ -1194,12 +978,17 @@ struct CLIPTextModel { struct ggml_tensor* token_embed_weight; struct ggml_tensor* position_embed_weight; + struct ggml_tensor* token_embed_custom; + // transformer std::vector<ResidualAttentionBlock> resblocks; struct ggml_tensor* final_ln_w; struct ggml_tensor* final_ln_b; struct ggml_tensor* text_projection; + std::string embd_dir; + int32_t num_custom_embeddings = 0; + std::vector<std::string> readed_embeddings; CLIPTextModel(CLIPVersion version = OPENAI_CLIP_VIT_L_14, int clip_skip = 1, @@ -1236,6 +1025,9 @@ struct CLIPTextModel { mem_size += hidden_size * max_position_embeddings * ggml_type_sizef(GGML_TYPE_I32); // position_ids mem_size += hidden_size * vocab_size * ggml_type_sizef(wtype); // token_embed_weight mem_size += hidden_size * max_position_embeddings * ggml_type_sizef(wtype); // position_embed_weight + if(version == OPENAI_CLIP_VIT_L_14) { + mem_size += hidden_size * max_position_embeddings * ggml_type_sizef(wtype); // token_embed_custom + } for (int i = 0; i < num_hidden_layers; i++) { mem_size += resblocks[i].calculate_mem_size(wtype); } @@ -1246,6 +1038,39 @@ return static_cast<size_t>(mem_size); } + bool load_embedding(std::string embd_name, std::string embd_path, std::vector<int>& bpe_tokens) { + // the order matters + ModelLoader model_loader; + if(!model_loader.init_from_file(embd_path)) { + LOG_ERROR("loading embedding '%s' failed", embd_name.c_str()); + return false; + } + struct ggml_init_params params; + params.mem_size = 32 * 1024; // max for custom embeddings 32 KB + params.mem_buffer = NULL; + params.no_alloc = false; + struct ggml_context* embd_ctx = ggml_init(params); + struct ggml_tensor* embd = NULL; + auto on_load = [&](const TensorStorage& tensor_storage, ggml_tensor** dst_tensor) { + if(tensor_storage.ne[0] != hidden_size) { + LOG_DEBUG("embedding wrong hidden size, got %i, expected %i", (int)tensor_storage.ne[0], hidden_size); + return false; + } + embd = ggml_new_tensor_2d(embd_ctx, token_embed_weight->type, hidden_size, tensor_storage.n_dims > 1 ?
tensor_storage.ne[1] : 1); + *dst_tensor = embd; + return true; + }; + model_loader.load_tensors(on_load, NULL); + if(embd == NULL) { + LOG_ERROR("embedding '%s' contains no tensor with hidden size %i", embd_name.c_str(), hidden_size); + return false; + } + ggml_backend_tensor_set(token_embed_custom, embd->data, num_custom_embeddings * hidden_size * ggml_type_size(token_embed_custom->type), ggml_nbytes(embd)); + readed_embeddings.push_back(embd_name); + for(int i = 0; i < embd->ne[1]; i++) { + bpe_tokens.push_back(vocab_size + num_custom_embeddings); + // LOG_DEBUG("new custom token: %i", vocab_size + num_custom_embeddings); + num_custom_embeddings++; + } + return true; + } + void map_by_name(std::map<std::string, struct ggml_tensor*>& tensors, const std::string prefix) { tensors[prefix + "embeddings.token_embedding.weight"] = token_embed_weight; tensors[prefix + "embeddings.position_embedding.weight"] = position_embed_weight; @@ -1260,14 +1085,14 @@ struct CLIPTextModel { } } - struct ggml_tensor* forward(struct ggml_context* ctx0, struct ggml_tensor* input_ids, uint32_t max_token_idx = 0, bool return_pooled = false) { + struct ggml_tensor* forward(struct ggml_context* ctx0, struct ggml_tensor* input_ids, struct ggml_tensor* tkn_embeddings, uint32_t max_token_idx = 0, bool return_pooled = false) { // input_ids: [N, n_token] GGML_ASSERT(input_ids->ne[0] <= position_ids->ne[0]); // token_embedding + position_embedding struct ggml_tensor* x; x = ggml_add(ctx0, - ggml_get_rows(ctx0, token_embed_weight, input_ids), + ggml_get_rows(ctx0, tkn_embeddings == NULL ? token_embed_weight : tkn_embeddings, input_ids), ggml_get_rows(ctx0, position_embed_weight, ggml_view_1d(ctx0, position_ids, input_ids->ne[0], 0))); // [N, n_token, hidden_size] @@ -1317,6 +1142,10 @@ struct CLIPTextModel { text_projection = ggml_new_tensor_2d(ctx, GGML_TYPE_F32, projection_dim, hidden_size); } + if(version == OPENAI_CLIP_VIT_L_14) { + token_embed_custom = ggml_new_tensor_2d(ctx, wtype, hidden_size, max_position_embeddings); + } + // alloc all tensors linked to this context for (struct ggml_tensor* t = ggml_get_first_tensor(ctx); t != NULL; t = ggml_get_next_tensor(ctx, t)) { if (t->data == NULL) { @@ -1333,8 +1162,264 @@ struct CLIPTextModel { for (int i = 0; i < max_position_embeddings; i++) { pos_temp.push_back(i); } - ggml_backend_tensor_set(position_ids, pos_temp.data(), 0, ggml_nbytes(position_ids)); + ggml_backend_tensor_set_and_sync(backend, position_ids, pos_temp.data(), 0, ggml_nbytes(position_ids)); + } + } +}; + + + +// Ref: https://github.com/openai/CLIP/blob/main/clip/simple_tokenizer.py +class CLIPTokenizer { +private: + SDVersion version = VERSION_1_x; + std::map<int, std::u32string> byte_encoder; + std::map<std::u32string, int> encoder; + std::map<std::pair<std::u32string, std::u32string>, int> bpe_ranks; + std::regex pat; + + static std::string strip(const std::string& str) { + std::string::size_type start = str.find_first_not_of(" \t\n\r\v\f"); + std::string::size_type end = str.find_last_not_of(" \t\n\r\v\f"); + + if (start == std::string::npos) { + // String contains only whitespace characters + return ""; + } + + return str.substr(start, end - start + 1); + } + + static std::string whitespace_clean(std::string text) { + text = std::regex_replace(text, std::regex(R"(\s+)"), " "); + text = strip(text); + return text; + } + + static std::set<std::pair<std::u32string, std::u32string>> get_pairs(const std::vector<std::u32string>& subwords) { + std::set<std::pair<std::u32string, std::u32string>> pairs; + if (subwords.size() == 0) { + return pairs; + } + std::u32string prev_subword = subwords[0]; + for (int i = 1; i < subwords.size(); i++) { + std::u32string subword = subwords[i]; + std::pair<std::u32string, std::u32string> pair(prev_subword, subword); + pairs.insert(pair); + prev_subword = subword; + } + return pairs; + } + +public: + CLIPTokenizer(SDVersion version = VERSION_1_x) + :
version(version) {} + + void load_from_merges(const std::string& merges_utf8_str) { + auto byte_unicode_pairs = bytes_to_unicode(); + byte_encoder = std::map<int, std::u32string>(byte_unicode_pairs.begin(), byte_unicode_pairs.end()); + // for (auto & pair: byte_unicode_pairs) { + // std::cout << pair.first << ": " << pair.second << std::endl; + // } + std::vector<std::u32string> merges; + size_t start = 0; + size_t pos; + std::u32string merges_utf32_str = utf8_to_utf32(merges_utf8_str); + while ((pos = merges_utf32_str.find('\n', start)) != std::string::npos) { + merges.push_back(merges_utf32_str.substr(start, pos - start)); + start = pos + 1; + } + // LOG_DEBUG("merges size %llu", merges.size()); + GGML_ASSERT(merges.size() == 48895); + merges = std::vector<std::u32string>(merges.begin() + 1, merges.end()); + std::vector<std::pair<std::u32string, std::u32string>> merge_pairs; + for (const auto& merge : merges) { + size_t space_pos = merge.find(' '); + merge_pairs.emplace_back(merge.substr(0, space_pos), merge.substr(space_pos + 1)); + // LOG_DEBUG("%s", utf32_to_utf8(merge.substr(space_pos + 1)).c_str()); + } + std::vector<std::u32string> vocab; + for (const auto& pair : byte_unicode_pairs) { + vocab.push_back(pair.second); + } + for (const auto& pair : byte_unicode_pairs) { + vocab.push_back(pair.second + utf8_to_utf32("</w>")); + } + for (const auto& merge : merge_pairs) { + vocab.push_back(merge.first + merge.second); + } + vocab.push_back(utf8_to_utf32("<|startoftext|>")); + vocab.push_back(utf8_to_utf32("<|endoftext|>")); + LOG_DEBUG("vocab size: %llu", vocab.size()); + int i = 0; + for (const auto& token : vocab) { + encoder[token] = i++; + } + + int rank = 0; + for (const auto& merge : merge_pairs) { + bpe_ranks[merge] = rank++; + } + }; + + std::u32string bpe(const std::u32string& token) { + std::vector<std::u32string> word; + + for (int i = 0; i < token.size() - 1; i++) { + word.emplace_back(1, token[i]); + } + word.push_back(token.substr(token.size() - 1) + utf8_to_utf32("</w>")); + + std::set<std::pair<std::u32string, std::u32string>> pairs = get_pairs(word); + + if (pairs.empty()) { + return token + utf8_to_utf32("</w>"); + } + + while (true) { + auto min_pair_iter = std::min_element(pairs.begin(), + pairs.end(), + [&](const std::pair<std::u32string, std::u32string>& a, + const std::pair<std::u32string, std::u32string>& b) { + if (bpe_ranks.find(a) == bpe_ranks.end()) { + return false; + } else if (bpe_ranks.find(b) == bpe_ranks.end()) { + return true; + } + return bpe_ranks.at(a) < bpe_ranks.at(b); + }); + + const std::pair<std::u32string, std::u32string>& bigram = *min_pair_iter; + + if (bpe_ranks.find(bigram) == bpe_ranks.end()) { + break; + } + + std::u32string first = bigram.first; + std::u32string second = bigram.second; + std::vector<std::u32string> new_word; + int32_t i = 0; + + while (i < word.size()) { + auto it = std::find(word.begin() + i, word.end(), first); + if (it == word.end()) { + new_word.insert(new_word.end(), word.begin() + i, word.end()); + break; + } + new_word.insert(new_word.end(), word.begin() + i, it); + i = static_cast<int32_t>(std::distance(word.begin(), it)); + + if (word[i] == first && i < static_cast<int32_t>(word.size()) - 1 && word[i + 1] == second) { + new_word.push_back(first + second); + i += 2; + } else { + new_word.push_back(word[i]); + i += 1; + } + } + + word = new_word; + + if (word.size() == 1) { + break; + } + pairs = get_pairs(word); + } + + std::u32string result; + for (int i = 0; i < word.size(); i++) { + result += word[i]; + if (i != word.size() - 1) { + result += utf8_to_utf32(" "); + } } + + return result; + } + + std::vector<int> tokenize(std::string text, CLIPTextModel& text_model, size_t max_length = 0, bool padding = false) { + std::vector<int> tokens = encode(text, text_model); + tokens.insert(tokens.begin(), BOS_TOKEN_ID); + if
(max_length > 0) { + if (tokens.size() > max_length - 1) { + tokens.resize(max_length - 1); + tokens.push_back(EOS_TOKEN_ID); + } else { + tokens.push_back(EOS_TOKEN_ID); + if (padding) { + int pad_token_id = PAD_TOKEN_ID; + if (version == VERSION_2_x) { + pad_token_id = 0; + } + tokens.insert(tokens.end(), max_length - tokens.size(), pad_token_id); + } + } + } + return tokens; + } + + std::vector<int> encode(std::string text, CLIPTextModel& text_model) { + std::string original_text = text; + std::vector<int> bpe_tokens; + text = whitespace_clean(text); + std::transform(text.begin(), text.end(), text.begin(), [](unsigned char c) { return std::tolower(c); }); + + std::regex pat(R"(<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[[:alpha:]]+|[[:digit:]]|[^[:space:][:alpha:][:digit:]]+)", + std::regex::icase); + + std::smatch matches; + std::string str = text; + std::vector<std::string> token_strs; + while (std::regex_search(str, matches, pat)) { + // check for a Textual Inversion embedding file before falling back to BPE + size_t word_end = str.find(","); + std::string embd_name = word_end == std::string::npos ? str : str.substr(0, word_end); + embd_name = trim(embd_name); + std::string embd_path = path_join(text_model.embd_dir, embd_name + ".pt"); + if(!file_exists(embd_path)) { + embd_path = path_join(text_model.embd_dir, embd_name + ".ckpt"); + } + if(!file_exists(embd_path)) { + embd_path = path_join(text_model.embd_dir, embd_name + ".safetensors"); + } + if(file_exists(embd_path)) { + if(text_model.load_embedding(embd_name, embd_path, bpe_tokens)) { + if(word_end != std::string::npos) { + str = str.substr(word_end); + } else { + str = ""; // consume the remainder so the same embedding is not matched forever + } + continue; + } + } + for (auto& token : matches) { + std::string token_str = token.str(); + std::u32string utf32_token; + for (int i = 0; i < token_str.length(); i++) { + unsigned char b = token_str[i]; + utf32_token += byte_encoder[b]; + } + auto bpe_strs = bpe(utf32_token); + size_t start = 0; + size_t pos; + while ((pos = bpe_strs.find(' ', start)) != std::u32string::npos) { + auto bpe_str = bpe_strs.substr(start, pos - start); + bpe_tokens.push_back(encoder[bpe_str]); + token_strs.push_back(utf32_to_utf8(bpe_str)); + + start = pos + 1; + } + auto bpe_str = bpe_strs.substr(start, bpe_strs.size() - start); + bpe_tokens.push_back(encoder[bpe_str]); + token_strs.push_back(utf32_to_utf8(bpe_str)); + } + str = matches.suffix(); + } + std::stringstream ss; + ss << "["; + for (auto token : token_strs) { + ss << "\"" << token << "\", "; + } + ss << "]"; + LOG_DEBUG("split prompt \"%s\" to tokens %s", original_text.c_str(), ss.str().c_str()); + return bpe_tokens; } }; @@ -1344,10 +1429,10 @@ struct FrozenCLIPEmbedder { CLIPTextModel text_model; struct ggml_tensor* forward(struct ggml_context* ctx, struct ggml_allocr* allocr, const std::string& prompt) { - std::vector<int> tokens = tokenizer.tokenize(prompt, text_model.max_position_embeddings, true); + std::vector<int> tokens = tokenizer.tokenize(prompt, text_model, text_model.max_position_embeddings, true); struct ggml_tensor* input_ids = ggml_new_tensor_1d(ctx, GGML_TYPE_I32, tokens.size()); memcpy(input_ids->data, tokens.data(), tokens.size() * ggml_element_size(input_ids)); - struct ggml_tensor* hidden_states = text_model.forward(ctx, input_ids); + struct ggml_tensor* hidden_states = text_model.forward(ctx, input_ids, NULL); return hidden_states; } }; @@ -1405,11 +1490,11 @@ struct FrozenCLIPEmbedderWithCustomWords { } } - struct ggml_tensor*
forward(struct ggml_context* ctx0, struct ggml_tensor* input_ids, struct ggml_tensor* input_ids2, struct ggml_tensor* embeddings, uint32_t max_token_idx = 0, bool return_pooled = false) { if (return_pooled) { - return text_model2.forward(ctx0, input_ids2, max_token_idx, return_pooled); + return text_model2.forward(ctx0, input_ids2, NULL, max_token_idx, return_pooled); } - auto hidden_states = text_model.forward(ctx0, input_ids); // [N, n_token, hidden_size] + auto hidden_states = text_model.forward(ctx0, input_ids, embeddings); // [N, n_token, hidden_size] // LOG_DEBUG("hidden_states: %d %d %d %d %d", hidden_states->n_dims, hidden_states->ne[0], hidden_states->ne[1], hidden_states->ne[2], hidden_states->ne[3]); if (version == VERSION_XL) { hidden_states = ggml_reshape_4d(ctx0, @@ -1420,7 +1505,7 @@ struct FrozenCLIPEmbedderWithCustomWords { hidden_states->ne[3]); hidden_states = ggml_cont(ctx0, ggml_permute(ctx0, hidden_states, 2, 0, 1, 3)); - auto hidden_states2 = text_model2.forward(ctx0, input_ids2); // [N, n_token, hidden_size2] + auto hidden_states2 = text_model2.forward(ctx0, input_ids2, NULL); // [N, n_token, hidden_size2] hidden_states2 = ggml_reshape_4d(ctx0, hidden_states2, hidden_states2->ne[0], @@ -1462,7 +1547,7 @@ struct FrozenCLIPEmbedderWithCustomWords { for (const auto& item : parsed_attention) { const std::string& curr_text = item.first; float curr_weight = item.second; - std::vector<int> curr_tokens = tokenizer.encode(curr_text); + std::vector<int> curr_tokens = tokenizer.encode(curr_text, text_model); tokens.insert(tokens.end(), curr_tokens.begin(), curr_tokens.end()); weights.insert(weights.end(), curr_tokens.size(), curr_weight); } @@ -1563,7 +1648,7 @@ struct FrozenCLIPEmbedderWithCustomWords { ggml_allocr_alloc(allocr, input_ids); if (!ggml_allocr_is_measure(allocr)) { - ggml_backend_tensor_set(input_ids, tokens.data(), 0, tokens.size() * ggml_element_size(input_ids)); + ggml_backend_tensor_set_and_sync(backend, input_ids, tokens.data(), 0, tokens.size() * ggml_element_size(input_ids)); } struct ggml_tensor* input_ids2 = NULL; @@ -1585,11 +1670,30 @@ struct FrozenCLIPEmbedderWithCustomWords { // printf("\n"); if (!ggml_allocr_is_measure(allocr)) { - ggml_backend_tensor_set(input_ids2, tokens.data(), 0, tokens.size() * ggml_element_size(input_ids2)); + ggml_backend_tensor_set_and_sync(backend, input_ids2, tokens.data(), 0, tokens.size() * ggml_element_size(input_ids2)); } } + struct ggml_tensor* embeddings = NULL; - struct ggml_tensor* hidden_states = forward(ctx0, input_ids, input_ids2, max_token_idx, return_pooled); + if(version != VERSION_XL) { + embeddings = ggml_new_tensor_2d(ctx0, wtype, text_model.hidden_size, text_model.vocab_size + text_model.num_custom_embeddings /* custom placeholder */); + ggml_allocr_alloc(allocr, embeddings); + + if (!ggml_allocr_is_measure(allocr)) { + // inefficient: build the concatenated [vocab + custom] table by round-tripping both halves through host memory, since the backend buffers cannot be joined in place + void* freeze_data = malloc(ggml_nbytes(text_model.token_embed_weight)); + ggml_backend_tensor_get_and_sync(backend, text_model.token_embed_weight, freeze_data, 0, ggml_nbytes(text_model.token_embed_weight)); + ggml_backend_tensor_set_and_sync(backend, embeddings, freeze_data, 0, ggml_nbytes(text_model.token_embed_weight)); + free(freeze_data); + // concatenate custom embeddings + void* custom_data = malloc(ggml_nbytes(text_model.token_embed_custom)); + ggml_backend_tensor_get_and_sync(backend, text_model.token_embed_custom, custom_data, 0, ggml_nbytes(text_model.token_embed_custom)); +
ggml_backend_tensor_set_and_sync(backend, embeddings, custom_data, ggml_nbytes(text_model.token_embed_weight), text_model.num_custom_embeddings * text_model.hidden_size * ggml_type_size(wtype)); + free(custom_data); + } + } + + struct ggml_tensor* hidden_states = forward(ctx0, input_ids, input_ids2, embeddings, max_token_idx, return_pooled); ggml_build_forward_expand(gf, hidden_states); ggml_free(ctx0); @@ -1679,6 +1783,7 @@ struct FrozenCLIPEmbedderWithCustomWords { } }; + /*==================================================== UnetModel =====================================================*/ struct ResBlock { @@ -2835,16 +2940,16 @@ struct UNetModel { } // pass data to device backend if (!ggml_allocr_is_measure(compute_alloc)) { - ggml_backend_tensor_set(x_t, x->data, 0, ggml_nbytes(x)); - ggml_backend_tensor_set(context_t, context->data, 0, ggml_nbytes(context)); + ggml_backend_tensor_set_and_sync(backend, x_t, x->data, 0, ggml_nbytes(x)); + ggml_backend_tensor_set_and_sync(backend, context_t, context->data, 0, ggml_nbytes(context)); if (timesteps_t != NULL) { - ggml_backend_tensor_set(timesteps_t, timesteps->data, 0, ggml_nbytes(timesteps)); + ggml_backend_tensor_set_and_sync(backend, timesteps_t, timesteps->data, 0, ggml_nbytes(timesteps)); } if (t_emb_t != NULL) { - ggml_backend_tensor_set(t_emb_t, t_emb->data, 0, ggml_nbytes(t_emb)); + ggml_backend_tensor_set_and_sync(backend, t_emb_t, t_emb->data, 0, ggml_nbytes(t_emb)); } if (y != NULL) { - ggml_backend_tensor_set(y_t, y->data, 0, ggml_nbytes(y)); + ggml_backend_tensor_set_and_sync(backend, y_t, y->data, 0, ggml_nbytes(y)); } } } else { @@ -3675,7 +3780,7 @@ struct AutoEncoderKL { // pass data to device backend if (!ggml_allocr_is_measure(compute_alloc)) { - ggml_backend_tensor_set(z_, z->data, 0, ggml_nbytes(z)); + ggml_backend_tensor_set_and_sync(backend, z_, z->data, 0, ggml_nbytes(z)); } } else { z_ = z; @@ -4360,7 +4465,7 @@ struct TinyAutoEncoder { // pass data to device backend if (!ggml_allocr_is_measure(compute_alloc)) { - ggml_backend_tensor_set(z_, z->data, 0, ggml_nbytes(z)); + ggml_backend_tensor_set_and_sync(backend, z_, z->data, 0, ggml_nbytes(z)); } } else { z_ = z; @@ -4872,7 +4977,7 @@ struct ESRGAN { ggml_allocr_alloc(compute_alloc, os); if (!ggml_allocr_is_measure(compute_alloc)) { float scale = 0.2f; - ggml_backend_tensor_set(os, &scale, 0, sizeof(scale)); + ggml_backend_tensor_set_and_sync(backend, os, &scale, 0, sizeof(scale)); } // it's performing a compute, check if backend isn't cpu @@ -4883,7 +4988,7 @@ struct ESRGAN { // pass data to device backend if (!ggml_allocr_is_measure(compute_alloc)) { - ggml_backend_tensor_set(x_, x->data, 0, ggml_nbytes(x)); + ggml_backend_tensor_set_and_sync(backend, x_, x->data, 0, ggml_nbytes(x)); } } else { x_ = x; @@ -5086,7 +5191,7 @@ struct LoraModel { ggml_allocr_alloc(compute_alloc, lora_scale); if (!ggml_allocr_is_measure(compute_alloc)) { - ggml_backend_tensor_set(lora_scale, &scale_value, 0, ggml_nbytes(lora_scale)); + ggml_backend_tensor_set_and_sync(backend, lora_scale, &scale_value, 0, ggml_nbytes(lora_scale)); } // flat lora tensors to multiply it @@ -5359,6 +5464,7 @@ class StableDiffusionGGML { bool load_from_file(const std::string& model_path, const std::string& vae_path, + std::string embeddings_path, ggml_type wtype, Schedule schedule, int clip_skip) { @@ -5368,7 +5474,6 @@ class StableDiffusionGGML { #endif #ifdef SD_USE_METAL LOG_DEBUG("Using Metal backend"); - ggml_metal_log_set_callback(ggml_log_callback_default, nullptr); backend = 
ggml_backend_metal_init(); #endif @@ -5435,6 +5540,8 @@ class StableDiffusionGGML { return false; } + cond_stage_model.text_model.embd_dir = embeddings_path; + ggml_type vae_type = model_data_type; if (version == VERSION_XL) { vae_type = GGML_TYPE_F32; // avoid nan, not work... @@ -6449,10 +6556,11 @@ StableDiffusion::StableDiffusion(int n_threads, bool StableDiffusion::load_from_file(const std::string& model_path, const std::string& vae_path, + std::string embeddings_path, ggml_type wtype, Schedule s, int clip_skip) { - return sd->load_from_file(model_path, vae_path, wtype, s, clip_skip); + return sd->load_from_file(model_path, vae_path, embeddings_path, wtype, s, clip_skip); } std::vector<uint8_t*> StableDiffusion::txt2img(std::string prompt, diff --git a/stable-diffusion.h b/stable-diffusion.h index 3ae012f9..059f9e68 100644 --- a/stable-diffusion.h +++ b/stable-diffusion.h @@ -50,6 +50,7 @@ class StableDiffusion { bool load_from_file(const std::string& model_path, const std::string& vae_path, + std::string embeddings_path, ggml_type wtype, Schedule d = DEFAULT, int clip_skip = -1);
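
Usage sketch (illustrative, not part of the diff): the snippet below mirrors the call sites in examples/cli/main.cpp to show how the embeddings directory is threaded through the new API. The file names, the "my-style" trigger word, and the commented constructor parameter names are assumptions for illustration; the argument order and the values STD_DEFAULT_RNG, GGML_TYPE_COUNT, DEFAULT, and clip_skip = -1 follow the signatures visible in this diff.

    // Inside main(), assuming model.safetensors and an embeddings/ directory exist.
    StableDiffusion sd(/*n_threads*/ 4, /*vae_decode_only*/ true,
                       /*taesd_path*/ "", /*esrgan_path*/ "",
                       /*free_params_immediately (assumed name)*/ true,
                       /*vae_tiling*/ false, /*lora_model_dir*/ "", STD_DEFAULT_RNG);
    // New third argument: the directory scanned by the tokenizer for
    // <name>.pt / <name>.ckpt / <name>.safetensors Textual Inversion files.
    if (!sd.load_from_file("model.safetensors", /*vae_path*/ "",
                           /*embeddings_path*/ "embeddings/",
                           GGML_TYPE_COUNT, DEFAULT, /*clip_skip*/ -1)) {
        return 1;
    }
    // A prompt token equal to an embedding file's basename (e.g. "my-style"
    // for embeddings/my-style.safetensors) now resolves to the learned
    // embedding rows instead of being BPE-encoded. The CLI equivalent is
    // the new --embd-dir flag added in parse_args above.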