Hf data access #962

Open · wants to merge 12 commits into base: dev
data-processing-lib/python/requirements.txt (1 change: 1 addition, 0 deletions)
@@ -5,3 +5,4 @@
mmh3
psutil
polars>=1.9.0
huggingface-hub>=0.25.2
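The new dependency provides, among other things, an fsspec-style file system in which dataset repos are addressed with a "datasets/" prefix; that convention is what _validate_hf_config further down enforces. A minimal sketch using only the public huggingface_hub API, with a placeholder repo name (DataAccessHF's internals are not part of this diff):

from huggingface_hub import HfFileSystem

# Anonymous access works for public repos; a token is needed for gated
# repos and for writes.
fs = HfFileSystem()

# Dataset repos live under the "datasets/" prefix, e.g. datasets/<org>/<name>.
files = fs.ls("datasets/my-org/my-dataset", detail=False)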
@@ -2,6 +2,7 @@
from data_processing.data_access.data_access import DataAccess
from data_processing.data_access.data_access_local import DataAccessLocal
from data_processing.data_access.data_access_s3 import DataAccessS3
from data_processing.data_access.data_access_hf import DataAccessHF
from data_processing.data_access.data_access_factory_base import DataAccessFactoryBase
from data_processing.data_access.data_access_factory import DataAccessFactory
from data_processing.data_access.snapshot_utils import SnapshotUtils
@@ -19,6 +19,7 @@
DataAccessFactoryBase,
DataAccessLocal,
DataAccessS3,
DataAccessHF,
)
from data_processing.utils import ParamsUtils, str2bool

@@ -46,6 +47,7 @@ def __init__(self, cli_arg_prefix: str = "data_", enable_data_navigation: bool =
super().__init__(cli_arg_prefix=cli_arg_prefix)
self.s3_config = None
self.local_config = None
self.hf_config = None
self.enable_data_navigation = enable_data_navigation

def add_input_params(self, parser: argparse.ArgumentParser) -> None:
@@ -77,6 +79,7 @@ def add_input_params(self, parser: argparse.ArgumentParser) -> None:
self.__add_data_navigation_params(parser)

def __add_data_navigation_params(self, parser):
# s3 config
help_example_dict = {
"input_folder": [
"s3-path/your-input-bucket",
@@ -93,6 +96,7 @@ def __add_data_navigation_params(self, parser):
default=None,
help="AST string containing input/output paths.\n" + ParamsUtils.get_ast_help_text(help_example_dict),
)
# local config
help_example_dict = {
"input_folder": ["./input", "Path to input folder of files to be processed"],
"output_folder": ["/tmp/output", "Path to output folder of processed files"],
@@ -104,6 +108,19 @@ def __add_data_navigation_params(self, parser):
help="ast string containing input/output folders using local fs.\n"
+ ParamsUtils.get_ast_help_text(help_example_dict),
)
# hf config
help_example_dict = {
"hf_token": ["./input", "HF token required for write operation"],
"input_folder": ["./input", "Path to input folder of files to be processed"],
"output_folder": ["/tmp/output", "Path to output folder of processed files"],
}
parser.add_argument(
f"--{self.cli_arg_prefix}hf_config",
type=ast.literal_eval,
default=None,
help="ast string containing hf_token/input/output folders using hf fs.\n"
+ ParamsUtils.get_ast_help_text(help_example_dict),
)
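Since the flag is parsed with ast.literal_eval, the value passed on the command line is a Python dict literal in a single quoted string. A hedged sketch of a valid value, with placeholder repo paths:

import ast

# What --data_hf_config would receive as one quoted CLI argument:
raw = ("{'hf_token': '<token>', "
       "'input_folder': 'datasets/my-org/my-dataset/raw', "
       "'output_folder': 'datasets/my-org/my-dataset/processed'}")
hf_config = ast.literal_eval(raw)  # dict with hf_token/input_folder/output_folder keys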
parser.add_argument(
f"--{self.cli_arg_prefix}max_files", type=int, default=-1, help="Max amount of files to process"
)
@@ -154,6 +171,7 @@ def apply_input_params(self, args: Union[dict, argparse.Namespace]) -> bool:
s3_cred = arg_dict.get(f"{self.cli_arg_prefix}s3_cred", None)
s3_config = arg_dict.get(f"{self.cli_arg_prefix}s3_config", None)
local_config = arg_dict.get(f"{self.cli_arg_prefix}local_config", None)
hf_config = arg_dict.get(f"{self.cli_arg_prefix}hf_config", None)
checkpointing = arg_dict.get(f"{self.cli_arg_prefix}checkpointing", False)
max_files = arg_dict.get(f"{self.cli_arg_prefix}max_files", -1)
data_sets = arg_dict.get(f"{self.cli_arg_prefix}data_sets", None)
@@ -163,18 +181,20 @@ def apply_input_params(self, args: Union[dict, argparse.Namespace]) -> bool:
# check which configuration (S3, HF, or Local) is specified
s3_config_specified = 1 if s3_config is not None else 0
local_config_specified = 1 if local_config is not None else 0
hf_config_specified = 1 if hf_config is not None else 0

# check that only one (S3, HF, or Local) configuration is specified
if s3_config_specified + local_config_specified + hf_config_specified > 1:
self.logger.error(
f"data factory {self.cli_arg_prefix} "
f"{'S3, ' if s3_config_specified == 1 else ''}"
f"{'Local ' if local_config_specified == 1 else ''}"
f"{'hf ' if hf_config_specified == 1 else ''}"
"configurations specified, but only one configuration expected"
)
return False

# further validate the specified configuration (S3, HF, or Local)
if s3_config_specified == 1:
if not self._validate_s3_config(s3_config=s3_config):
return False
@@ -188,6 +208,20 @@ def apply_input_params(self, args: Union[dict, argparse.Namespace]) -> bool:
f'input path - {self.s3_config["input_folder"]}, '
f'output path - {self.s3_config["output_folder"]}'
)
elif hf_config_specified == 1:
if not self._validate_hf_config(hf_config=hf_config):
return False
self.hf_config = hf_config
self.logger.info(
f"data factory {self.cli_arg_prefix} is using HF data access: "
f"input_folder - {self.hf_config['input_folder']} "
f"output_folder - {self.hf_config['output_folder']}"
)
elif s3_cred is not None:
if not self._validate_s3_cred(s3_credentials=s3_cred):
return False
self.s3_cred = s3_cred
self.logger.info(f"data factory {self.cli_arg_prefix} is using s3 configuration without input/output path")
elif local_config_specified == 1:
if not self._validate_local_config(local_config=local_config):
return False
@@ -197,11 +231,6 @@ def apply_input_params(self, args: Union[dict, argparse.Namespace]) -> bool:
f"input_folder - {self.local_config['input_folder']} "
f"output_folder - {self.local_config['output_folder']}"
)
else:
self.logger.info(
f"data factory {self.cli_arg_prefix} " f"is using local configuration without input/output path"
@@ -240,26 +269,36 @@ def create_data_access(self) -> DataAccess:
Create data access based on the parameters
:return: corresponding data access class
"""
if self.hf_config is not None:
# If an HF config is specified, it's HF
return DataAccessHF(
hf_config=self.hf_config,
d_sets=self.dsets,
checkpoint=self.checkpointing,
m_files=self.max_files,
n_samples=self.n_samples,
files_to_use=self.files_to_use,
files_to_checkpoint=self.files_to_checkpoint,
)
if self.s3_config is not None or self.s3_cred is not None:
# If an S3 config or S3 credentials are specified, it's S3
return DataAccessS3(
s3_credentials=self.s3_cred,
s3_config=self.s3_config,
d_sets=self.dsets,
checkpoint=self.checkpointing,
m_files=self.max_files,
n_samples=self.n_samples,
files_to_use=self.files_to_use,
files_to_checkpoint=self.files_to_checkpoint,
)
# anything else is local data
return DataAccessLocal(
local_config=self.local_config,
d_sets=self.dsets,
checkpoint=self.checkpointing,
m_files=self.max_files,
n_samples=self.n_samples,
files_to_use=self.files_to_use,
files_to_checkpoint=self.files_to_checkpoint,
)
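For orientation, the end-to-end flow of the pieces above, sketched under the assumption that the factory is driven through argparse exactly as the S3 and local paths are (flag name uses the default "data_" prefix; repo paths are placeholders):

import argparse
from data_processing.data_access import DataAccessFactory

parser = argparse.ArgumentParser()
factory = DataAccessFactory()      # default cli_arg_prefix="data_"
factory.add_input_params(parser)   # registers --data_hf_config among others

args = parser.parse_args([
    "--data_hf_config",
    "{'hf_token': '<token>', "
    "'input_folder': 'datasets/my-org/my-dataset/raw', "
    "'output_folder': 'datasets/my-org/my-dataset/processed'}",
])

# apply_input_params() validates the config; create_data_access() then
# returns a DataAccessHF instance for it.
if factory.apply_input_params(args):
    data_access = factory.create_data_access()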
@@ -130,8 +130,8 @@ def _validate_local_config(self, local_config: dict[str, str]) -> bool:
def _validate_s3_config(self, s3_config: dict[str, str]) -> bool:
"""
Validate the S3 configuration.
:param s3_config: dictionary of s3 config
:return: True if s3 config is valid, False otherwise
"""
valid_config = True
if s3_config.get("input_folder", "") == "":
@@ -141,3 +141,35 @@ def _validate_s3_config(self, s3_config: dict[str, str]) -> bool:
valid_config = False
self.logger.error(f"data access factory {self.cli_arg_prefix}: Could not find output folder in s3 config")
return valid_config

def _validate_hf_config(self, hf_config: dict[str, str]) -> bool:
"""
Validate the HF configuration.
:param hf_config: dictionary of hf config
:return: True if hf config is valid, False otherwise
"""
valid_config = True
if hf_config.get("hf_token", "") == "":
self.logger.warning(f"data access factory {self.cli_arg_prefix}: "
f"HF token is not defined, write operations may fail")
input_folder = hf_config.get("input_folder", "")
if input_folder == "":
valid_config = False
self.logger.error(f"data access factory {self.cli_arg_prefix}: Could not find input folder in HF config")
elif not input_folder.startswith("datasets/"):
valid_config = False
self.logger.error(f"data access factory {self.cli_arg_prefix}: "
f"Input folder in HF config has to start with datasets/")

output_folder = hf_config.get("output_folder", "")
if output_folder == "":
valid_config = False
self.logger.error(f"data access factory {self.cli_arg_prefix}: Could not find output folder in HF config")
elif not output_folder.startswith("datasets/"):
valid_config = False
self.logger.error(f"data access factory {self.cli_arg_prefix}: "
f"Output folder in HF config has to start with datasets/")

return valid_config
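To make the contract concrete, a config that passes this validation next to one that fails it (repo paths are placeholders):

# Passes: both folders address an HF dataset repo via the "datasets/" prefix.
good_hf_config = {
    "hf_token": "<token>",  # optional for reads; writes may fail without it
    "input_folder": "datasets/my-org/my-dataset/raw",
    "output_folder": "datasets/my-org/my-dataset/processed",
}

# Fails: folders lacking the "datasets/" prefix are rejected with an error.
bad_hf_config = {
    "input_folder": "my-org/my-dataset/raw",
    "output_folder": "my-org/my-dataset/processed",
}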