From e9a573ac818e8941fa22e7d5b08d070265907fee Mon Sep 17 00:00:00 2001 From: Neil Mehta Date: Fri, 20 Dec 2024 14:48:02 -0500 Subject: [PATCH 1/3] [MLX] Preserve dtype of array when converting to torch --- outlines/processors/base_logits_processor.py | 4 +++- tests/processors/test_base_processor.py | 2 +- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/outlines/processors/base_logits_processor.py b/outlines/processors/base_logits_processor.py index 44b55af2e..9e09d2fca 100644 --- a/outlines/processors/base_logits_processor.py +++ b/outlines/processors/base_logits_processor.py @@ -110,8 +110,10 @@ def _to_torch(tensor_like: Array) -> torch.Tensor: import mlx.core as mx # https://ml-explore.github.io/mlx/build/html/usage/numpy.html#pytorch + if tensor_like.dtype == mx.bfloat16: + tensor_like = tensor_like.astype(mx.float32) return torch.from_dlpack( - np.array(tensor_like.astype(mx.float32), copy=False) + np.array(tensor_like, copy=False) ) elif is_jax_array_type(type(tensor_like)): diff --git a/tests/processors/test_base_processor.py b/tests/processors/test_base_processor.py index cd9f48278..1413e6cea 100644 --- a/tests/processors/test_base_processor.py +++ b/tests/processors/test_base_processor.py @@ -17,7 +17,7 @@ try: import mlx.core as mx - arrays["mlx"] = mx.array([[1, 2], [3, 4]], dtype=mx.float32) + arrays["mlx"] = mx.array([[1, 2], [3, 4]], dtype=mx.bfloat16) except ImportError: pass From 40f6ad10b94d85fd7d2e68a45fdd89856e73c575 Mon Sep 17 00:00:00 2001 From: Neil Mehta Date: Fri, 20 Dec 2024 15:16:32 -0500 Subject: [PATCH 2/3] Fix style and run pre-commit check --- outlines/processors/base_logits_processor.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/outlines/processors/base_logits_processor.py b/outlines/processors/base_logits_processor.py index 9e09d2fca..800e69f79 100644 --- a/outlines/processors/base_logits_processor.py +++ b/outlines/processors/base_logits_processor.py @@ -112,9 +112,7 @@ def _to_torch(tensor_like: Array) -> torch.Tensor: # https://ml-explore.github.io/mlx/build/html/usage/numpy.html#pytorch if tensor_like.dtype == mx.bfloat16: tensor_like = tensor_like.astype(mx.float32) - return torch.from_dlpack( - np.array(tensor_like, copy=False) - ) + return torch.from_dlpack(np.array(tensor_like, copy=False)) elif is_jax_array_type(type(tensor_like)): import jax From a1056a535979a4ead251d47960986dbca4bf1279 Mon Sep 17 00:00:00 2001 From: Neil Mehta Date: Sun, 22 Dec 2024 15:24:37 -0500 Subject: [PATCH 3/3] Improve test coverage --- tests/processors/test_base_processor.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/processors/test_base_processor.py b/tests/processors/test_base_processor.py index 1413e6cea..d2a1e1af2 100644 --- a/tests/processors/test_base_processor.py +++ b/tests/processors/test_base_processor.py @@ -17,7 +17,8 @@ try: import mlx.core as mx - arrays["mlx"] = mx.array([[1, 2], [3, 4]], dtype=mx.bfloat16) + arrays["mlx"] = mx.array([[1, 2], [3, 4]], dtype=mx.float32) + arrays["mlx_bfloat16"] = mx.array([[1, 2], [3, 4]], dtype=mx.bfloat16) except ImportError: pass @@ -59,7 +60,12 @@ def test_from_torch(array_type, processor): torch_tensor = torch.tensor([[1, 2], [3, 4]], dtype=torch.float32) data = processor._from_torch(torch_tensor, type(arrays[array_type])) assert isinstance(data, type(arrays[array_type])) - assert np.allclose(data, arrays[array_type]) + if array_type == "mlx_bfloat16": + # For bfloat16, we expect the output to be float32 due to the conversion + assert data.dtype == mx.float32 + assert np.allclose(np.array(data), np.array([[1, 2], [3, 4]], dtype=np.float32)) + else: + assert np.allclose(data, arrays[array_type]) @pytest.mark.parametrize("array_type", arrays.keys())