Skip to content

Commit

Permalink
fix: Identify s3/remote uri path correctly
Browse files Browse the repository at this point in the history
Signed-off-by: ntkathole <[email protected]>
  • Loading branch information
ntkathole committed Feb 22, 2025
1 parent f3a24de commit ac389ab
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 17 deletions.
8 changes: 3 additions & 5 deletions sdk/python/feast/infra/offline_stores/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,9 @@ def persist(
# Check if the specified location already exists.
if not allow_overwrite and os.path.exists(storage.file_options.uri):
raise SavedDatasetLocationAlreadyExists(location=storage.file_options.uri)

if not Path(storage.file_options.uri).is_absolute():
absolute_path = Path(self.repo_path) / storage.file_options.uri
else:
absolute_path = Path(storage.file_options.uri)
absolute_path = FileSource.get_uri_for_file_path(
repo_path=self.repo_path, uri=storage.file_options.uri
)

filesystem, path = FileSource.create_filesystem_and_path(
str(absolute_path),
Expand Down
7 changes: 3 additions & 4 deletions sdk/python/feast/infra/offline_stores/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,9 @@ def _write_data_source(

file_options = data_source.file_options

if not Path(file_options.uri).is_absolute():
absolute_path = Path(repo_path) / file_options.uri
else:
absolute_path = Path(file_options.uri)
absolute_path = FileSource.get_uri_for_file_path(
repo_path=repo_path, uri=file_options.uri
)

if (
mode == "overwrite"
Expand Down
21 changes: 13 additions & 8 deletions sdk/python/feast/infra/offline_stores/file_source.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from pathlib import Path
from typing import Callable, Dict, Iterable, List, Optional, Tuple
from urllib.parse import urlparse

import pyarrow
from packaging import version
Expand Down Expand Up @@ -154,17 +155,21 @@ def validate(self, config: RepoConfig):
def source_datatype_to_feast_value_type() -> Callable[[str], ValueType]:
return type_map.pa_to_feast_value_type

@staticmethod
def get_uri_for_file_path(repo_path, uri):
parsed_uri = urlparse(uri)
if parsed_uri.scheme and parsed_uri.netloc:
return uri # Keep remote URIs as they are
if repo_path is not None and not Path(uri).is_absolute():
return str(Path(repo_path) / uri)
return str(Path(uri))

def get_table_column_names_and_types(
self, config: RepoConfig
) -> Iterable[Tuple[str, str]]:
if (
config.repo_path is not None
and not Path(self.file_options.uri).is_absolute()
):
absolute_path = config.repo_path / self.file_options.uri
else:
absolute_path = Path(self.file_options.uri)

absolute_path = self.get_uri_for_file_path(
repo_path=config.repo_path, uri=self.file_options.uri
)
filesystem, path = FileSource.create_filesystem_and_path(
str(absolute_path), self.file_options.s3_endpoint_override
)
Expand Down
24 changes: 24 additions & 0 deletions sdk/python/tests/unit/infra/offline_stores/test_offline_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
TrinoRetrievalJob,
)
from feast.infra.offline_stores.dask import DaskRetrievalJob
from feast.infra.offline_stores.file_source import FileSource
from feast.infra.offline_stores.offline_store import RetrievalJob, RetrievalMetadata
from feast.infra.offline_stores.redshift import (
RedshiftOfflineStoreConfig,
Expand Down Expand Up @@ -246,3 +247,26 @@ def test_to_arrow_timeout(retrieval_job, timeout: Optional[int]):
with patch.object(retrieval_job, "_to_arrow_internal") as mock_to_arrow_internal:
retrieval_job.to_arrow(timeout=timeout)
mock_to_arrow_internal.assert_called_once_with(timeout=timeout)


@pytest.mark.parametrize(
"repo_path, uri, expected",
[
# Remote URI - Should return as-is
(
"/some/repo",
"s3://bucket-name/file.parquet",
"s3://bucket-name/file.parquet",
),
# Absolute Path - Should return as-is
("/some/repo", "/abs/path/file.parquet", "/abs/path/file.parquet"),
# Relative Path with repo_path - Should combine
("/some/repo", "data/output.parquet", "/some/repo/data/output.parquet"),
# Relative Path without repo_path - Should return absolute path
(None, "C:/path/to/file.parquet", "C:/path/to/file.parquet"),
],
ids=["s3_uri", "absolute_path", "relative_path", "windows_path"],
)
def test_get_uri_for_file_path(repo_path, uri, expected):
result = FileSource.get_uri_for_file_path(repo_path=repo_path, uri=uri)
assert result == expected, f"Expected {expected}, but got {result}"

0 comments on commit ac389ab

Please sign in to comment.