Skip to content

Commit

Permalink
Merge pull request #64 from UBC-MDS/grayscale_function
Browse files Browse the repository at this point in the history
Grayscale function
  • Loading branch information
yfan810 authored Jan 18, 2025
2 parents 3b70f38 + f4d82a8 commit 914d622
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 14 deletions.
18 changes: 4 additions & 14 deletions src/imagenie/grayscale.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,34 +12,24 @@ def grayscale(image):
Returns:
-------
ndarray
The grayscale image as a 2D NumPy array.
The grayscale image as a 2D NumPy array (dtype=uint8).
Raises:
------
TypeError
If the input is not a NumPy array.
ValueError
If the input is not a 2D or 3D NumPy array, or if the 3D array does not have 3 channels.
Examples:
---------
Convert an RGB image to grayscale:
>>> gray_image = grayscale(image)
"""
# Ensure the input is a valid NumPy array
if not isinstance(image, np.ndarray):
raise TypeError("The input image must be a NumPy array.")

# Handle already grayscale (2D) images
if image.ndim == 2:
return image

# Handle RGB images (3D)
elif image.ndim == 3:
if image.shape[-1] != 3:
raise ValueError("The input image must have 3 channels in the last dimension for RGB.")
# Use weighted average to convert to grayscale
return np.dot(image[..., :3], [0.2989, 0.5870, 0.1140])

# Convert to grayscale with rounding and cast to uint8
return np.round(np.dot(image[..., :3], [0.2989, 0.5870, 0.1140])).astype(np.uint8)
else:
raise ValueError("The input image must be a 2D or 3D NumPy array.")

4 changes: 4 additions & 0 deletions tests/test_grayscale.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
import unittest
import numpy as np
import sys
import os
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src')))
from imagenie.grayscale import grayscale

def test_grayscale():
Expand Down Expand Up @@ -54,3 +57,4 @@ def test_grayscale():
if __name__ == "__main__":
test_grayscale()
print("All tests passed.")

0 comments on commit 914d622

Please sign in to comment.