diff --git a/src/imagenie/grayscale.py b/src/imagenie/grayscale.py index c9bc4c1..42dc468 100644 --- a/src/imagenie/grayscale.py +++ b/src/imagenie/grayscale.py @@ -12,7 +12,7 @@ def grayscale(image): Returns: ------- ndarray - The grayscale image as a 2D NumPy array. + The grayscale image as a 2D NumPy array (dtype=uint8). Raises: ------ @@ -20,26 +20,16 @@ def grayscale(image): If the input is not a NumPy array. ValueError If the input is not a 2D or 3D NumPy array, or if the 3D array does not have 3 channels. - - Examples: - --------- - Convert an RGB image to grayscale: - >>> gray_image = grayscale(image) """ - # Ensure the input is a valid NumPy array if not isinstance(image, np.ndarray): raise TypeError("The input image must be a NumPy array.") - - # Handle already grayscale (2D) images if image.ndim == 2: return image - - # Handle RGB images (3D) elif image.ndim == 3: if image.shape[-1] != 3: raise ValueError("The input image must have 3 channels in the last dimension for RGB.") - # Use weighted average to convert to grayscale - return np.dot(image[..., :3], [0.2989, 0.5870, 0.1140]) - + # Convert to grayscale with rounding and cast to uint8 + return np.round(np.dot(image[..., :3], [0.2989, 0.5870, 0.1140])).astype(np.uint8) else: raise ValueError("The input image must be a 2D or 3D NumPy array.") + diff --git a/tests/test_grayscale.py b/tests/test_grayscale.py index d38400f..6b74ef3 100644 --- a/tests/test_grayscale.py +++ b/tests/test_grayscale.py @@ -1,5 +1,8 @@ import unittest import numpy as np +import sys +import os +sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '../src'))) from imagenie.grayscale import grayscale def test_grayscale(): @@ -54,3 +57,4 @@ def test_grayscale(): if __name__ == "__main__": test_grayscale() print("All tests passed.") +