Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUGFIX] argilla: review dataset import with new export flow #5756

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions argilla/src/argilla/datasets/_io/_disk.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ def from_disk(
name=dataset_model.name, workspace_id=workspace.id
)
dataset = cls.from_model(model=dataset_model, client=client)
dataset.get()
else:
# Create a new dataset and load the settings and records
if not os.path.exists(settings_path):
Expand Down
30 changes: 22 additions & 8 deletions argilla/src/argilla/datasets/_io/_hub.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,11 @@ def _log_dataset_records(hf_dataset: "HFDataset", dataset: "Dataset"):
for col in responses_columns:
question_name = col.split(".")[0]
if col.endswith("users"):
response_questions[question_name]["users"] = hf_dataset[col]
user_ids.update({UUID(user_id): UUID(user_id) for user_id in set(sum(hf_dataset[col], []))})
response_questions[question_name]["users"] = hf_dataset[col] or []
for users in hf_dataset[col]:
if users is None:
continue
user_ids.update({UUID(user_id): user_id for user_id in users})
elif col.endswith("responses"):
response_questions[question_name]["responses"] = hf_dataset[col]
elif col.endswith("status"):
Expand All @@ -240,7 +243,15 @@ def _log_dataset_records(hf_dataset: "HFDataset", dataset: "Dataset"):
user_ids[unknown_user_id] = my_user.id

# Create a mapper to map the Hugging Face dataset to a Record object
mapping = {col: col for col in hf_dataset.column_names if ".suggestion" in col}
mapping = {}
for col in hf_dataset.column_names:
if ".suggestion" in col:
mapping[col] = col
elif col.startswith("metadata.") and col.replace("metadata.", "") in dataset.schema:
mapping[col] = col.replace("metadata.", "")
elif col.startswith("vector.") and col.replace("vector.", "") in dataset.schema:
mapping[col] = col.replace("vector.", "")

mapper = IngestedRecordMapper(dataset=dataset, mapping=mapping, user_id=my_user.id)

# Extract responses and create Record objects
Expand All @@ -249,14 +260,17 @@ def _log_dataset_records(hf_dataset: "HFDataset", dataset: "Dataset"):
for idx, row in enumerate(hf_dataset):
record = mapper(row)
for question_name, values in response_questions.items():
response_values = values["responses"][idx]
response_users = values["users"][idx]
response_status = values["status"][idx]
response_values = values["responses"][idx] or []
response_users = values["users"][idx] or []
response_status = values["status"][idx] or []

used_users = set()
for value, user_id, status in zip(response_values, response_users, response_status):
user_id = user_ids[UUID(user_id)]
if user_id in response_users:
if user_id in used_users:
continue
response_users[user_id] = True

used_users.add(user_id)
response = Response(
user_id=user_id,
question_name=question_name,
Expand Down
20 changes: 14 additions & 6 deletions argilla/src/argilla/responses.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,16 +61,16 @@ def __init__(
status (Union[ResponseStatus, str]): The status of the response as "draft", "submitted", "discarded".
"""

if isinstance(status, str):
status = ResponseStatus(status)

if question_name is None:
raise ValueError("question_name is required")
if value is None:
if value is None and status == ResponseStatus.submitted:
raise ValueError("value is required")
if user_id is None:
raise ValueError("user_id is required")

if isinstance(status, str):
status = ResponseStatus(status)

self._record = _record
self.question_name = question_name
self.value = value
Expand Down Expand Up @@ -253,7 +253,7 @@ def _compute_user_id_from_responses(responses: List[Response]) -> Optional[UUID]
@staticmethod
def __responses_as_model_values(responses: List[Response]) -> Dict[str, Dict[str, Any]]:
"""Creates a dictionary of response values from a list of Responses"""
return {answer.question_name: {"value": answer.value} for answer in responses}
return {answer.question_name: {"value": answer.value} for answer in responses if answer.value is not None}

@classmethod
def __model_as_responses_list(cls, model: UserResponseModel, record: "Record") -> List[Response]:
Expand All @@ -276,4 +276,12 @@ def __ranking_from_model_value(cls, value: List[Dict[str, Any]]) -> List[str]:

@classmethod
def __ranking_to_model_value(cls, value: List[str]) -> List[Dict[str, str]]:
return [{"value": v} for v in value]
values = []
for v in value or []:
if isinstance(v, dict):
values.append(v)
elif isinstance(v, str):
values.append({"value": v})
else:
raise RecordResponsesError(f"Invalid value for ranking question: {v}")
return values
22 changes: 3 additions & 19 deletions argilla/src/argilla/settings/_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@
FieldSettings,
)
from argilla.settings._common import SettingsPropertyBase
from argilla.settings._metadata import MetadataField, MetadataType
from argilla.settings._vector import VectorField


try:
from typing import Self
Expand Down Expand Up @@ -296,21 +295,6 @@ def _field_from_model(model: FieldModel) -> Field:
raise ArgillaError(f"Unsupported field type: {model.settings.type}")


def _field_from_dict(data: dict) -> Union[Field, VectorField, MetadataType]:
def _field_from_dict(data: dict) -> Field:
"""Create a field instance from a field dictionary"""
field_type = data["type"]

if field_type == "text":
return TextField.from_dict(data)
elif field_type == "image":
return ImageField.from_dict(data)
elif field_type == "chat":
return ChatField.from_dict(data)
elif field_type == "custom":
return CustomField.from_dict(data)
elif field_type == "vector":
return VectorField.from_dict(data)
elif field_type == "metadata":
return MetadataField.from_dict(data)
else:
raise ArgillaError(f"Unsupported field type: {field_type}")
return _field_from_model(FieldModel(**data))
Loading