Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

super_resolve.py not robust to Pillow version 10 or greater #17

Open
Declan-Curran1 opened this issue May 31, 2024 · 0 comments
Open

super_resolve.py not robust to Pillow version 10 or greater #17

Declan-Curran1 opened this issue May 31, 2024 · 0 comments

Comments

@Declan-Curran1
Copy link

Had to edit super_resolve.py to resolve the following issue:

:36: DeprecationWarning: getsize is deprecated and will be removed in Pillow 10 (2023-07-01). Use getbbox or getlength instead.

Had to replace the getsize calls with a bounding-box call (the edited file uses ImageDraw's textbbox). The full edited file is pasted below.


import torch
from utils import *
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model checkpoints
srgan_checkpoint = "C:/Users/decla/a-PyTorch-Tutorial-to-Super-Resolution-master/checkpoint_srgan.pth.tar"
srresnet_checkpoint = "C:/Users/decla/a-PyTorch-Tutorial-to-Super-Resolution-master/checkpoint_srresnet.pth.tar"

# Load models
# map_location remaps tensors that were saved from a CUDA device, so the
# checkpoints also load on a CPU-only machine (without it torch.load tries to
# restore them onto the device they were saved from).
# NOTE: torch.load unpickles arbitrary objects -- only load checkpoints that
# come from a trusted source. On PyTorch >= 2.6 the default weights_only=True
# rejects whole-model checkpoints like these; pass weights_only=False there.
print("Loading SRResNet model...")
srresnet = torch.load(srresnet_checkpoint, map_location=device)['model'].to(device)
srresnet.eval()  # inference mode for layers such as batch norm
print("Loading SRGAN model...")
srgan_generator = torch.load(srgan_checkpoint, map_location=device)['generator'].to(device)
srgan_generator.eval()

# Directory that receives the intermediate images and the final grid. This is
# the location the script originally hard-coded; pass ``output_dir`` to
# ``visualize_sr`` to write somewhere else.
_DEFAULT_OUTPUT_DIR = "C:/Users/decla/a-PyTorch-Tutorial-to-Super-Resolution-master/data"


def _draw_label(draw, text, font, center_x, image_top, gap=5):
    """Draw *text* centred horizontally on *center_x*, just above *image_top*.

    A failure while measuring or drawing the label is reported and swallowed,
    so a font problem never prevents the grid itself from being produced.
    """
    try:
        # textbbox replaces ImageFont.getsize, which was removed in Pillow 10.
        left, _top, right, bottom = draw.textbbox((0, 0), text, font=font)
        text_width = right - left
        draw.text(xy=[center_x - text_width / 2, image_top - bottom - gap], text=text, font=font, fill='black')
        print(f"Added {text} text.")
    except Exception as e:
        print(f"An error occurred while drawing {text} text: {e}")


def _super_resolve(model, lr_batch):
    """Run *model* on a one-image batch and convert its output to a PIL image."""
    sr_tensor = model(lr_batch).squeeze(0).cpu().detach()
    return convert_image(sr_tensor, source='[-1, 1]', target='pil')


def visualize_sr(img, halve=False, output_dir=_DEFAULT_OUTPUT_DIR):
    """Build, show and save a 2x2 comparison grid for one image.

    The grid contains the bicubic-upsampled image, the SRResNet output, the
    SRGAN output and the original high-resolution (HR) image. The low-resolution
    input is made by bicubic-downsampling the HR image by a factor of 4.

    Relies on the module-level ``srresnet``, ``srgan_generator`` and ``device``
    and on ``convert_image`` (expected to come from ``from utils import *``).

    :param img: path (or file object) of the HR image, passed to ``Image.open``
    :param halve: if True, halve the HR image's dimensions before anything else
    :param output_dir: directory that receives ``bicubic_img.png``,
        ``sr_img_srresnet.png``, ``sr_img_srgan.png``, ``hr_img.png`` and the
        final grid ``img1.png``
    :return: the grid as a PIL image, or None if any step failed (the error is
        printed rather than raised)
    """
    try:
        # Load image, downsample to obtain low-res version
        print("Loading and processing HR image...")
        hr_img = Image.open(img, mode="r")
        hr_img = hr_img.convert('RGB')
        print(f"Original HR image size: {hr_img.size}")
        if halve:
            hr_img = hr_img.resize((int(hr_img.width / 2), int(hr_img.height / 2)), Image.LANCZOS)
            print(f"HR image resized to: {hr_img.size}")
        lr_img = hr_img.resize((int(hr_img.width / 4), int(hr_img.height / 4)), Image.BICUBIC)
        print(f"LR image size: {lr_img.size}")

        # Bicubic Upsampling
        print("Performing bicubic upsampling...")
        bicubic_img = lr_img.resize((hr_img.width, hr_img.height), Image.BICUBIC)
        print(f"Bicubic image size: {bicubic_img.size}")

        # Both generators take the same input, so convert the LR image once.
        lr_batch = convert_image(lr_img, source='pil', target='imagenet-norm').unsqueeze(0).to(device)

        # Inference only: skip autograd bookkeeping to save memory.
        with torch.no_grad():
            # Super-resolution (SR) with SRResNet
            print("Generating SR image with SRResNet...")
            sr_img_srresnet = _super_resolve(srresnet, lr_batch)
            print(f"SRResNet image size: {sr_img_srresnet.size}")

            # Super-resolution (SR) with SRGAN
            print("Generating SR image with SRGAN...")
            sr_img_srgan = _super_resolve(srgan_generator, lr_batch)
            print(f"SRGAN image size: {sr_img_srgan.size}")

        # Save intermediate images for verification
        print("Saving intermediate images for verification...")
        bicubic_img.save(f"{output_dir}/bicubic_img.png", "PNG")
        sr_img_srresnet.save(f"{output_dir}/sr_img_srresnet.png", "PNG")
        sr_img_srgan.save(f"{output_dir}/sr_img_srgan.png", "PNG")
        hr_img.save(f"{output_dir}/hr_img.png", "PNG")

        # Create grid
        print("Creating image grid...")
        margin = 40
        grid_img = Image.new('RGB', (2 * hr_img.width + 3 * margin, 2 * hr_img.height + 3 * margin), (255, 255, 255))

        # Font
        draw = ImageDraw.Draw(grid_img)
        try:
            font = ImageFont.truetype("calibril.ttf", size=23)
            print("Loaded custom font.")
        except OSError:
            print("Defaulting to a terrible font. To use a font of your choice, include the link to its TTF file in the function.")
            font = ImageFont.load_default()
            print("Loaded default font.")

        # Grid geometry: two columns and two rows separated by `margin`.
        left_x = margin
        right_x = 2 * margin + bicubic_img.width
        top_y = margin
        bottom_y = 2 * margin + sr_img_srresnet.height
        left_center = margin + bicubic_img.width / 2
        right_center = right_x + sr_img_srresnet.width / 2

        # (image, label, paste x, paste y, label centre x, gap above image).
        # The "Original HR" label sits 1px above its image instead of 5px,
        # as in the original layout.
        panels = [
            (bicubic_img, "Bicubic", left_x, top_y, left_center, 5),
            (sr_img_srresnet, "SRResNet", right_x, top_y, right_center, 5),
            (sr_img_srgan, "SRGAN", left_x, bottom_y, left_center, 5),
            (hr_img, "Original HR", right_x, bottom_y, right_center, 1),
        ]
        for panel_img, label, x, y, center_x, gap in panels:
            grid_img.paste(panel_img, (x, y))
            _draw_label(draw, label, font, center_x, y, gap=gap)

        # Display grid using matplotlib
        print("Displaying image grid...")
        plt.imshow(grid_img)
        plt.axis('off')
        plt.show()

        # Save grid
        print("Saving image grid...")
        grid_img.save(f"{output_dir}/img1.png", "PNG")
        print("Image saved successfully")

        return grid_img
    except Exception as e:
        # Best-effort script behaviour: report the failure instead of raising.
        print(f"An error occurred: {e}")
        return None

# Run the demo only when the file is executed as a script, not when imported.
if __name__ == '__main__':
    print("Starting visualization...")
    grid_img = visualize_sr("C:/Users/decla/a-PyTorch-Tutorial-to-Super-Resolution-master/data/BSDS100/62096.png", halve=True)
    print("Visualization completed")

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant