diff --git a/pixelateTG.py b/pixelateTG.py
index b75362f..e2032e6 100644
--- a/pixelateTG.py
+++ b/pixelateTG.py
@@ -1,8 +1,9 @@
 import os
-from dotenv import load_dotenv # Import the load_dotenv function from python-dotenv
+from dotenv import load_dotenv
 import cv2
 import random
 import imageio
+import numpy as np
 from telegram import Update, InlineKeyboardButton, InlineKeyboardMarkup
 from telegram.ext import Updater, CallbackContext, CommandHandler, CallbackQueryHandler, MessageHandler, Filters
 from concurrent.futures import ThreadPoolExecutor, wait
@@ -12,59 +13,66 @@
 # Load environment variables from .env file
 load_dotenv()

-TOKEN = os.getenv('TELEGRAM_BOT_TOKEN') # Get the Telegram bot token from the environment variable
+TOKEN = os.getenv('TELEGRAM_BOT_TOKEN')
 MAX_THREADS = 15
 PIXELATION_FACTOR = 0.04
-RESIZE_FACTOR = 1.5 # Common resize factor
+RESIZE_FACTOR = 1.5
 executor = ThreadPoolExecutor(max_workers=MAX_THREADS)

+# Global MTCNN detector
+mtcnn_detector = MTCNN()
+
+# Cache for overlay files
+overlay_cache = {}
+
 def start(update: Update, context: CallbackContext) -> None:
     update.message.reply_text('Send me a picture or a GIF, and I will pixelate faces in it!')

 def detect_heads(image):
-    mtcnn = MTCNN()
-    faces = mtcnn.detect_faces(image)
-    head_boxes = [(face['box'][0], face['box'][1], int(RESIZE_FACTOR * face['box'][2]), int(RESIZE_FACTOR * face['box'][3])) for face in faces]
-    return head_boxes
+    global mtcnn_detector
+    faces = mtcnn_detector.detect_faces(image)
+    return [(int(face['box'][0]), int(face['box'][1]), int(face['box'][2]), int(face['box'][3])) for face in faces]
+
+def get_overlay_files(overlay_type):
+    if overlay_type not in overlay_cache:
+        overlay_cache[overlay_type] = [name for name in os.listdir() if name.startswith(f'{overlay_type}_')]
+    return overlay_cache[overlay_type]

 def overlay(photo_path, user_id, overlay_type, resize_factor, bot):
-    image = cv2.imread(photo_path)
+    image_data = np.fromfile(photo_path, np.uint8)
+    image = cv2.imdecode(image_data, cv2.IMREAD_UNCHANGED)
+
     heads = detect_heads(image)
-
+    overlay_files = get_overlay_files(overlay_type)
+
     for (x, y, w, h) in heads:
-        overlay_files = [name for name in os.listdir() if name.startswith(f'{overlay_type}_')]
         if not overlay_files:
             continue
         random_overlay = random.choice(overlay_files)
-        overlay_image = cv2.imread(random_overlay, cv2.IMREAD_UNCHANGED)
+
+        overlay_data = np.fromfile(random_overlay, np.uint8)
+        overlay_image = cv2.imdecode(overlay_data, cv2.IMREAD_UNCHANGED)
+
         original_aspect_ratio = overlay_image.shape[1] / overlay_image.shape[0]
-
-        # Calculate new dimensions for the overlay
         new_width = int(resize_factor * w)
         new_height = int(new_width / original_aspect_ratio)

-        # Ensure the overlay is centered on the face
         center_x = x + w // 2
         center_y = y + h // 2

-        # Overlay position adjusted for better centering
         overlay_x = int(center_x - 0.5 * resize_factor * w) - int(0.1 * resize_factor * w)
         overlay_y = int(center_y - 0.5 * resize_factor * h) - int(0.1 * resize_factor * w)

-        # Clamp values to ensure they are within the image boundaries
        overlay_x = max(0, overlay_x)
         overlay_y = max(0, overlay_y)

-        # Resize the overlay image
         overlay_image_resized = cv2.resize(overlay_image, (new_width, new_height), interpolation=cv2.INTER_AREA)

-        # Calculate the regions of interest (ROI)
         roi_start_x = overlay_x
         roi_start_y = overlay_y
         roi_end_x = min(image.shape[1], overlay_x + new_width)
         roi_end_y = min(image.shape[0], overlay_y + new_height)

-        # Blend the overlay onto the image
         try:
             overlay_part = overlay_image_resized[:roi_end_y - roi_start_y, :roi_end_x - roi_start_x]
             alpha_mask = overlay_part[:, :, 3] / 255.0
@@ -81,7 +89,6 @@ def overlay(photo_path, user_id, overlay_type, resize_factor, bot):
     cv2.imwrite(processed_path, image, [int(cv2.IMWRITE_JPEG_QUALITY), 95])
     return processed_path

-# Overlay functions
 def liotta_overlay(photo_path, user_id, bot):
     return overlay(photo_path, user_id, 'liotta', RESIZE_FACTOR, bot)

@@ -102,8 +109,12 @@ def clowns_overlay(photo_path, user_id, bot):
     return overlay(photo_path, user_id, 'clown', RESIZE_FACTOR, bot)

 def process_gif(gif_path, session_id, user_id, bot):
-    frames = imageio.mimread(gif_path)
-    processed_frames = [process_image(frame, user_id, session_id, bot) for frame in frames]
+    reader = imageio.get_reader(gif_path)
+    processed_frames = []
+    for frame in reader:
+        processed_frame = process_image(frame, user_id, session_id, bot)
+        processed_frames.append(processed_frame)
+
     processed_gif_path = f"processed/{user_id}_{session_id}.gif"
     imageio.mimsave(processed_gif_path, processed_frames)
     return processed_gif_path
@@ -135,7 +146,6 @@ def pixelate_faces(update: Update, context: CallbackContext) -> None:
                      InlineKeyboardButton("🏆 Chad", callback_data=f'chad_overlay_{session_id}')]
         ]

-        # Check if it's a private chat, if yes, include the "⚔️ Pixel" button
         if update.message.chat.type == 'private':
             keyboard.append([InlineKeyboardButton("⚔️ Pixel", callback_data=f'pixelate_{session_id}')])

@@ -160,7 +170,6 @@ def pixelate_faces(update: Update, context: CallbackContext) -> None:
     else:
         update.message.reply_text('Please send either a photo or a GIF.')

-
 def pixelate_command(update: Update, context: CallbackContext) -> None:
     if update.message.reply_to_message and update.message.reply_to_message.photo:
         session_id = str(uuid4())
@@ -201,22 +210,16 @@ def process_image(photo_path, user_id, session_id, bot):
     faces = detect_heads(image)

     for (x, y, w, h) in faces:
-        # Define the region of interest (ROI)
         roi = image[y:y+h, x:x+w]
-
-        # Apply pixelation to the ROI
-        pixelation_size = max(1, int(PIXELATION_FACTOR * min(w, h))) # Ensure pixelation size is at least 1
+        pixelation_size = max(1, int(PIXELATION_FACTOR * min(w, h)))
         pixelated_roi = cv2.resize(roi, (pixelation_size, pixelation_size), interpolation=cv2.INTER_NEAREST)
         pixelated_roi = cv2.resize(pixelated_roi, (w, h), interpolation=cv2.INTER_NEAREST)
-
-        # Replace the original face region with the pixelated ROI
         image[y:y+h, x:x+w] = pixelated_roi

     processed_path = f"processed/{user_id}_{session_id}_pixelated.jpg"
     cv2.imwrite(processed_path, image, [int(cv2.IMWRITE_JPEG_QUALITY), 95])
     return processed_path

-
 def button_callback(update: Update, context: CallbackContext) -> None:
     query = update.callback_query
     query.answer()
@@ -259,7 +262,6 @@ def button_callback(update: Update, context: CallbackContext) -> None:

 def main() -> None:
     updater = Updater(TOKEN)
-
     dispatcher = updater.dispatcher

     dispatcher.add_handler(CommandHandler("start", start))