
Differentiable RGB <-> LAB #4

Open
AliShahin opened this issue Sep 10, 2019 · 8 comments

Comments

@AliShahin

Hi,

Thanks for the useful implementation of color conversion in PyTorch. I am wondering if you have any implementation for their differentiable versions?

Thanks in advance.

@aribenjamin

aribenjamin commented Feb 6, 2020

I wanted these so I wrote them myself. Could be a pull request here, but, eh, here are the functions at least.
Edit: these all now work with batches of images too, of shape (n_batches, 3, x, y)

import torch

def BGR_to_RGB(image):
    "Just permutes the color channels"
    if len(image.size())==3:
        # advanced indexing keeps all three channels; image[2,1,0] would select a single pixel value
        return image[[2, 1, 0]]
    elif len(image.size())==4:
        return image[:,[2,1,0]]

def RGB_to_BGR(image):
    "Just permutes the color channels"
    return BGR_to_RGB(image)

def srgb_to_rgb(image):
    """Linearizes sRGB to RGB. Assumes input is in range [0,1].
    Works for batched images too.


    :param image: A pytorch tensor of shape (3, n_pixels_x, n_pixels_y) in which the channels are gamma-corrected R, G, B
    :return: A pytorch tensor of shape (3, n_pixels_x, n_pixels_y) in which the channels are linearized RGB
    """
#     assert torch.max(image) <= 1

    small_u = image/12.92
    big_u = torch.pow((image+0.055)/1.055, 12./5)

    return torch.where(image<=0.04045, small_u, big_u)

def rgb_to_srgb(image):
    """Applies gamma correction to rgb to get sRGB. Assumes input is in range [0,1]
    Works for batched images too.

    :param image: A pytorch tensor of shape (3, n_pixels_x, n_pixels_y) in which the channels are linearized R, G, B
    :return: A pytorch tensor of shape (3, n_pixels_x, n_pixels_y) in which the channels are gamma-corrected RGB
    """
#     assert torch.max(image) <= 1

    small_u = image*12.92
    # 1/2.4 is the exact inverse of the 2.4 exponent used in srgb_to_rgb
    big_u = torch.pow(image, 1/2.4)*1.055-0.055

    return torch.where(image<=0.0031308, small_u, big_u)


def rgb_to_xyz(images):
    """
    Converts true (linearized) RGB to XYZ.
        Works for batched images too.


    :param image: A pytorch tensor of shape (3, n_pixels_x, n_pixels_y) in which the channels are true R, G, B
                        OR a batched version with (N, 3, x, y)

    :return: A pytorch tensor of shape (3, n_pixels_x, n_pixels_y) in which the channels are X, Y, Z
                    Or the batched version

    """
    M65 = torch.tensor([[0.4124564,  0.3575761,  0.1804375],
                         [0.2126729,  0.7151522,  0.0721750],
                         [0.0193339,  0.1191920 , 0.9503041]]).to(images.device)

    # multiply by the matrix only along the color dimension
    images = torch.einsum("...jkl,jm->...mkl", images, M65)

    return images

def xyz_to_rgb(images):
    """
    Converts XYZ to true (linearized) RGB.
        Works for batched images too.


    :param image: A pytorch tensor of shape (3, n_pixels_x, n_pixels_y) in which the channels are true X, Y, Z
                    OR a batched version with (N, 3, x, y)
    :return: A pytorch tensor of shape (3, n_pixels_x, n_pixels_y) in which the channels are R, G, B
                Or the batched version
    """
    M65 = torch.tensor([[0.4124564,  0.3575761,  0.1804375],
                         [0.2126729,  0.7151522,  0.0721750],
                         [0.0193339,  0.1191920 , 0.9503041]]).to(images.device)
    M65_inv = torch.inverse(M65)

    # multiply by the inverse matrix only along the color dimension
    images = torch.einsum("...jkl,jm->...mkl", images, M65_inv)

    return images

def luv_to_xyz(image):
    """
    Converts LUV to XYZ. Note: the white-point constants used below
    (u' = 0.2009, v' = 0.4610) correspond to CIE Illuminant C, not D65.

    :param image: A pytorch tensor of shape (3, n_pixels_x, n_pixels_y) in which the channels are L, U, V
    :return: A pytorch tensor of shape (3, n_pixels_x, n_pixels_y) in which the channels are X, Y, Z
    """
    # single image
    if len(image.size()) == 3:
        u_prime = image[1] / (13 * image[0]) + .2009
        v_prime = image[2] / (13 * image[0]) + .4610

        small_Y = image[0] * (3. / 29) ** 3
        large_Y = ((image[0] + 16.) / 116.) ** 3

        Y = torch.where(image[0] <= 8, small_Y, large_Y)
        d = 0
        # batch of images
    elif len(image.size()) == 4:

        u_prime = image[:, 1] / (13 * image[:, 0]) + .2009
        v_prime = image[:, 2] / (13 * image[:, 0]) + .4610

        small_Y = image[:, 0] * (3. / 29) ** 3
        large_Y = ((image[:, 0] + 16.) / 116.) ** 3

        Y = torch.where(image[:, 0] <= 8, small_Y, large_Y)
        d = 1

    X = Y * 9 * u_prime / (4 * v_prime)
    Z = Y * (12 - 3 * u_prime - 20 * v_prime) / (4 * v_prime)

    xyz_image = torch.stack((X, Y, Z), dim=d)

    return xyz_image


def xyz_to_luv(image):
    """
    Converts XYZ to LUV. Note: the white-point constants used below
    (u' = 0.2009, v' = 0.4610) correspond to CIE Illuminant C, not D65.

    :param image: A pytorch tensor of shape (3, n_pixels_x, n_pixels_y) in which the channels are XYZ
    :return: A pytorch tensor of shape (3, n_pixels_x, n_pixels_y) in which the channels are LUV
    """
    # single image
    if len(image.size()) == 3:
        small_L = (29. / 3) ** 3 * image[1]
        large_L = 116 * torch.pow(image[1], 1 / 3.) - 16
        L = torch.where(image[1] <= (6. / 29) ** 3, small_L, large_L)

        u_prime = 4 * image[0] / (image[0] + 15 * image[1] + 3 * image[2])
        v_prime = 9 * image[1] / (image[0] + 15 * image[1] + 3 * image[2])
        d = 0
    elif len(image.size()) == 4:
        small_L = (29. / 3) ** 3 * image[:, 1]
        large_L = 116 * torch.pow(image[:, 1], 1 / 3.) - 16
        L = torch.where(image[:, 1] <= (6. / 29) ** 3, small_L, large_L)

        u_prime = 4 * image[:, 0] / (image[:, 0] + 15 * image[:, 1] + 3 * image[:, 2])
        v_prime = 9 * image[:, 1] / (image[:, 0] + 15 * image[:, 1] + 3 * image[:, 2])
        d = 1

    u = 13 * L * (u_prime - .2009)
    v = 13 * L * (v_prime - .4610)

    luv_image = torch.stack((L, u, v), dim=d)

    return luv_image

def luv_to_lch(image):
    # single image
    if len(image.size()) == 3:
        C = torch.pow(torch.pow(image[1],2) + torch.pow(image[2],2), 0.5)
        h = torch.atan2(image[2], image[1])

        lch_image = torch.stack((image[0], C, h))
    # batched image
    if len(image.size()) == 4:
        C = torch.pow(torch.pow(image[:,1],2) + torch.pow(image[:,2],2), 0.5)
        h = torch.atan2(image[:,2], image[:,1])

        lch_image = torch.stack((image[:,0], C, h), dim=1)

    return lch_image

def lch_to_luv(image):

    # single image
    if len(image.size()) == 3:
        u = image[1] * torch.cos(image[2])
        v = image[1] * torch.sin(image[2])

        luv_image = torch.stack((image[0], u, v))
    # batched image
    if len(image.size()) == 4:
        u = image[:,1] * torch.cos(image[:,2])
        v = image[:,1] * torch.sin(image[:,2])

        luv_image = torch.stack((image[:,0], u, v), dim=1)
    return luv_image


def srgb_to_xyz(image):
    return rgb_to_xyz(srgb_to_rgb(image))

def xyz_to_srgb(image):
    return rgb_to_srgb(xyz_to_rgb(image))

def srgb_to_luv(image):
    return xyz_to_luv(srgb_to_xyz(image))

def luv_to_srgb(image):
    return xyz_to_srgb(luv_to_xyz(image))

def srgb_to_lch(image):
    return luv_to_lch(xyz_to_luv(srgb_to_xyz(image)))

def lch_to_srgb(image):
    return xyz_to_srgb(luv_to_xyz(lch_to_luv(image)))

def sbgr_to_lch(image):
    return srgb_to_lch(BGR_to_RGB(image))

def lch_to_sbgr(image):
    return BGR_to_RGB(lch_to_srgb(image))
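A quick way to sanity-check that these piecewise definitions stay differentiable end-to-end is a round trip through the two sRGB transfer functions. The sketch below is self-contained (it re-declares the two functions with the exponents written exactly as 2.4 and 1/2.4, rather than importing the thread's code):

```python
import torch

# Standalone round-trip check of the piecewise sRGB transfer functions.

def srgb_to_linear(x):
    return torch.where(x <= 0.04045, x / 12.92, ((x + 0.055) / 1.055) ** 2.4)

def linear_to_srgb(x):
    # clamp keeps the gradient of the power branch finite near zero,
    # even where that branch is not selected by torch.where
    return torch.where(x <= 0.0031308, x * 12.92,
                       1.055 * x.clamp(min=1e-8) ** (1.0 / 2.4) - 0.055)

# values kept strictly inside (0, 1) to avoid the endpoints
img = torch.rand(2, 3, 8, 8) * 0.98 + 0.01
img.requires_grad_(True)

out = linear_to_srgb(srgb_to_linear(img))
out.sum().backward()

assert torch.allclose(out, img, atol=1e-4)   # round trip is (nearly) exact
assert torch.isfinite(img.grad).all()        # no NaN/inf gradients
```

The clamp matters because torch.where evaluates both branches, so an infinite gradient in the unselected power branch would otherwise propagate NaNs into the backward pass.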

@wangg12

wangg12 commented Feb 12, 2020

@aribenjamin Cool, how about rgb_to_lab ?

@AliShahin
Author

AliShahin commented Feb 15, 2020

As part of the implementation of one of my papers, I have provided a differentiable version of RGB2Lab and Lab2RGB in PyTorch, here: https://github.com/smartcameras/EdgeFool/blob/master/Train/rgb_lab_formulation_pytorch.py
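Since the thread title asks for Lab specifically, here is a minimal, hypothetical sketch of the XYZ -> Lab step in the same channel-first layout as the code above (this is not the linked implementation; the white-point constants are the standard published D65 values):

```python
import torch

# Hypothetical differentiable XYZ -> Lab sketch (channel-first tensors,
# D65 white point). A sketch, not the linked implementation.

D65_WHITE = torch.tensor([0.95047, 1.0, 1.08883])

def xyz_to_lab(xyz):
    """xyz: (3, H, W) or (N, 3, H, W) with channels X, Y, Z."""
    c_dim = 0 if xyz.dim() == 3 else 1
    # (3, 1, 1) broadcasts against both unbatched and batched inputs
    t = xyz / D65_WHITE.view(3, 1, 1).to(xyz.device)
    delta = 6.0 / 29.0
    # clamp keeps the cube-root branch's gradient finite where it is unselected
    f = torch.where(t > delta ** 3,
                    t.clamp(min=1e-8) ** (1.0 / 3.0),
                    t / (3 * delta ** 2) + 4.0 / 29.0)
    fx, fy, fz = f.unbind(dim=c_dim)
    L = 116.0 * fy - 16.0
    a = 500.0 * (fx - fy)
    b = 200.0 * (fy - fz)
    return torch.stack((L, a, b), dim=c_dim)
```

An easy correctness check: feeding in the D65 white point itself should return L = 100 with a = b = 0.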

@nile649

nile649 commented Mar 8, 2020

@AliShahin your implementation's results are definitely weird in comparison with skimage's results.

@fantasy-fish

> [Quoted @aribenjamin's comment and code from above.]

Thank you for your code. But what's the range of each channel in LUV mode?

@jorge-pessoa
Owner

Hello,
Thank you for the discussion and contributions. @aribenjamin I have searched your repositories and it seems you have not released these differentiable implementations anywhere. Would it be okay to integrate them into pytorch-colors, with due credit in the README and code?

@jorge-pessoa
Owner

@aribenjamin I will keep this issue open for visibility of your code. To reiterate: would you be okay with including this code in the main package as a merge request, with credit?

@legel

legel commented Mar 4, 2023

Since I found this issue first on GitHub, just to direct others to this implementation in Kornia, which looks good to go:
https://kornia.readthedocs.io/en/latest/_modules/kornia/color/lab.html#rgb_to_lab
