Skip to content

Commit

Permalink
UPD: fixed rgb + decreased size
Browse files Browse the repository at this point in the history
  • Loading branch information
bmalezieux committed Jan 23, 2024
1 parent da43a39 commit ac9b268
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 30 deletions.
1 change: 1 addition & 0 deletions giskard_vision/detectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ def get_issues(
group=IssueGroup(result.group, "Warning"),
meta=result.get_meta_required(),
example_manager=ExamplesImages,
display_warnings=False
)
current_issue.add_examples(result.filename_examples)
issues.append(current_issue)
Expand Down
2 changes: 2 additions & 0 deletions giskard_vision/landmark_detection/dataloaders/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,6 +503,8 @@ class DataLoaderWrapper(DataIteratorBase):
_wrapped_dataloader (DataIteratorBase): The wrapped data loader instance.
"""

dataloader_type: str = "standard"

def __init__(self, dataloader: DataIteratorBase) -> None:
"""
Initializes the DataLoaderWrapper with a given DataIteratorBase instance.
Expand Down
27 changes: 24 additions & 3 deletions giskard_vision/landmark_detection/detectors/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
from giskard_vision.landmark_detection.tests.performance import NMEMean
from giskard_vision.utils.errors import GiskardImportError

import os
import cv2


class LandmarkDetectionBaseDetector(DetectorVisionBase):
"""
Abstract class for Landmark Detection Detectors
Expand Down Expand Up @@ -38,11 +41,29 @@ def get_results(self, model: Any, dataset: Any) -> Sequence[ScanResult]:
dataloader_ref=dataset,
)

results.append(self.get_scan_result(test_result))
# Save example images from dataloader and dataset
os.makedirs("examples_images", exist_ok=True)
filename_examples = []

filename_example_dataloader_ref = f"examples_images/{dataset.name}.png"
cv2.imwrite(
filename_example_dataloader_ref,
cv2.resize(dataset[0][0][0], (0, 0), fx=0.3, fy=0.3)
)
filename_examples.append(filename_example_dataloader_ref)

filename_example_dataloader = f"examples_images/{dl.name}.png"
cv2.imwrite(
filename_example_dataloader,
cv2.resize(dl[0][0][0], (0, 0), fx=0.3, fy=0.3)
)
filename_examples.append(filename_example_dataloader)

results.append(self.get_scan_result(test_result, filename_examples))

return results

def get_scan_result(self, test_result) -> ScanResult:
def get_scan_result(self, test_result, filename_examples) -> ScanResult:
try:
from giskard.scanner.issues import IssueLevel
except (ImportError, ModuleNotFoundError) as e:
Expand All @@ -63,5 +84,5 @@ def get_scan_result(self, test_result) -> ScanResult:
metric_reference_value=test_result.metric_value_ref,
issue_level=issue_level,
slice_size=test_result.size_data,
filename_examples=test_result.filename_examples,
filename_examples=filename_examples,
)
17 changes: 0 additions & 17 deletions giskard_vision/landmark_detection/tests/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,6 @@ class TestResult:
dataloader_ref_name: Optional[str] = None
size_data: Optional[int] = None
issue_name: Optional[str] = None
filename_examples: Optional[str] = None

def _repr_html_(self):
"""
Expand Down Expand Up @@ -314,21 +313,6 @@ def run(
[prediction_result.prediction_fail_rate, prediction_result_ref.prediction_fail_rate]
)

import os

import matplotlib.pyplot as plt

os.makedirs("examples_images", exist_ok=True)
filename_examples = []

filename_example_dataloader_ref = f"examples_images/{dataloader_ref.name}.png"
plt.imsave(filename_example_dataloader_ref, dataloader_ref[0][0][0])
filename_examples.append(filename_example_dataloader_ref)

filename_example_dataloader = f"examples_images/{dataloader.name}.png"
plt.imsave(filename_example_dataloader, dataloader[0][0][0])
filename_examples.append(filename_example_dataloader)

return TestResult(
test_name=self.__class__.__name__,
description=self.metric.description,
Expand All @@ -347,5 +331,4 @@ def run(
dataloader_ref_name=dataloader_ref.name,
size_data=len(dataloader),
issue_name=dataloader.name,
filename_examples=filename_examples,
)
10 changes: 0 additions & 10 deletions giskard_vision/scanner/scanner.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,6 @@ def analyze(
detectors, model, dataset, verbose=verbose, raise_exceptions=raise_exceptions
)

# issues = self._postprocess(issues)

# Scan completed
elapsed = perf_counter() - time_start

Expand Down Expand Up @@ -126,14 +124,6 @@ def _run_detectors(self, detectors, model, dataset, verbose=True, raise_exceptio

return issues, errors

def _postprocess(self, issues: Sequence[Issue]) -> Sequence[Issue]:
# If we detected a Stochasticity issue, we will have a possibly false
# positive DataLeakage issue. We remove it here.
if any(issue.group == Stochasticity for issue in issues):
issues = [issue for issue in issues if issue.group != DataLeakage]

return issues

def get_detectors(self, tags: Optional[Sequence[str]] = None) -> Sequence:
"""Returns the detector instances."""
detectors = []
Expand Down

0 comments on commit ac9b268

Please sign in to comment.