# aws_data_set_up.py

import ast
import io
import os
from collections import defaultdict

import boto3
import numpy as np
import torch
from PIL import Image, ImageDraw
from torch.utils.data import Dataset
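
# Annotation format (inferred from the parsing in load_s3_annotation below,
# not from a separate specification): one line per object, of the form
#
#     x_min,y_min,x_max,y_max,class_code[(x1, y1), (x2, y2), (x3, y3), ...]
#
# i.e. four integer box coordinates and an integer class code, followed by a
# bracketed list of (x, y) polygon vertices for the segmentation mask.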


class S3Dataset(Dataset):
    def __init__(self, images_directory, annotations_directory, transform, bucket_name,
                 target_size=300, max_annotations=100):
        self.transform = transform
        self.bucket_name = bucket_name
        self.s3_client = boto3.client('s3')
        self.target_size = target_size
        self.max_annotations = max_annotations
        # Ensure directory prefixes end with '/'
        self.images_directory = images_directory.rstrip('/') + '/'
        self.annotations_directory = annotations_directory.rstrip('/') + '/'
        # List files under each prefix using the S3 client
        self.image_files = self.list_s3_files(self.images_directory)
        self.annotation_files = self.list_s3_files(self.annotations_directory)
        # Keep only images that have a corresponding annotation file
        self.valid_image_files = self.filter_images_with_annotations(self.image_files)

    def __len__(self):
        return len(self.valid_image_files)

    def list_s3_files(self, prefix):
        """List object keys in the S3 bucket under a specific prefix.

        Used for both the image and annotation prefixes, hence the combined
        extension filter below.
        """
        paginator = self.s3_client.get_paginator('list_objects_v2')
        response_iterator = paginator.paginate(Bucket=self.bucket_name, Prefix=prefix)
        files = []
        for page in response_iterator:
            if 'Contents' in page:
                files.extend([
                    content['Key'] for content in page['Contents']
                    if content['Key'].endswith(('.jpg', '.png', '.jpeg', '.txt'))
                ])
        return files
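
    # Note: list_objects_v2 returns at most 1,000 keys per response, which is
    # why the paginated call above is needed for larger datasets. For a
    # hypothetical prefix 'images/', the result might look like
    # ['images/img_001.jpg', 'images/img_002.jpg', ...].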

    def __getitem__(self, idx):
        # Image handling
        image_path = self.valid_image_files[idx]
        image = self.load_s3_image(image_path)
        # Annotation handling: the annotation shares the image's base name
        # with a .txt extension (os.path.splitext also covers '.jpeg', which
        # a plain '.jpg' -> '.txt' string replacement would mangle)
        base_name = os.path.splitext(os.path.basename(image_path))[0]
        annotation_path = self.annotations_directory + base_name + '.txt'
        boxes, labels, masks = self.load_s3_annotation(annotation_path)
        boxes = torch.as_tensor(boxes, dtype=torch.float32) / self.target_size
        labels = torch.as_tensor(labels, dtype=torch.int64)
        if masks:
            masks = torch.stack(masks)
        else:
            masks = torch.zeros((0, self.target_size, self.target_size), dtype=torch.uint8)
        masks = self.pad_masks(masks, self.max_annotations)
        mask_labels = self.pad_mask_classes(labels, self.max_annotations)
        target = {'boxes': boxes, 'labels': labels, 'masks': masks, 'mask_labels': mask_labels}
        if self.transform:
            image = self.transform(image)
        return image, target

    def load_s3_image(self, path):
        obj = self.s3_client.get_object(Bucket=self.bucket_name, Key=path)
        return Image.open(io.BytesIO(obj['Body'].read())).convert('RGB')

    def load_s3_annotation(self, key):
        obj = self.s3_client.get_object(Bucket=self.bucket_name, Key=key)
        lines = obj['Body'].read().decode('utf-8').strip().split('\n')
        boxes, labels, masks_list = [], [], []
        for line in lines:
            # Everything before the first '[' is "x_min,y_min,x_max,y_max,class_code"
            bbox_class_part = line.split('[')[0].split(',')
            x_min, y_min, x_max, y_max = map(int, bbox_class_part[0:4])
            class_code = int(bbox_class_part[4])
            box = self.resize_box([x_min, y_min, x_max, y_max])
            # The remainder is a bracketed list of polygon vertices; parse it
            # with ast.literal_eval rather than eval so the annotation file
            # cannot execute arbitrary code
            polygon = ast.literal_eval('[' + line.split('[', 1)[1])
            mask = self.create_and_resize_mask((600, 600), (self.target_size, self.target_size), polygon)
            boxes.append(box)
            labels.append(class_code)
            masks_list.append(mask)
        return boxes, labels, masks_list
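
    # A worked example of the parsing above, for a hypothetical line
    # "50,60,200,220,3[(50, 60), (200, 60), (200, 220), (50, 220)]":
    #   line.split('[')[0].split(',') -> ['50', '60', '200', '220', '3']
    #   box = [50, 60, 200, 220], class_code = 3
    #   ast.literal_eval('[' + line.split('[', 1)[1])
    #     -> [(50, 60), (200, 60), (200, 220), (50, 220)]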

    def filter_images_with_annotations(self, image_files):
        valid_image_files = []
        for image_name in image_files:
            base_name = os.path.splitext(os.path.basename(image_name))[0]
            annotation_path = self.annotations_directory + base_name + '.txt'
            # Try to fetch the annotation file from S3
            try:
                obj = self.s3_client.get_object(Bucket=self.bucket_name, Key=annotation_path)
                annotation_lines = obj['Body'].read().decode('utf-8').strip().split('\n')
                if len(annotation_lines) <= self.max_annotations:
                    valid_image_files.append(image_name)
            except self.s3_client.exceptions.NoSuchKey:
                # Annotation file does not exist; skip this image
                continue
        return valid_image_files

    def resize_box(self, box, original_size=(600, 600)):
        """
        Scale a bounding box from the original image size to the target size.

        Parameters:
        - box (list of int): [x_min, y_min, x_max, y_max] coordinates.
        - original_size (tuple): (width, height) of the original image, default (600, 600).
        """
        x_min, y_min, x_max, y_max = box
        orig_w, orig_h = original_size
        new_w, new_h = self.target_size, self.target_size
        x_min = int(x_min * new_w / orig_w)
        x_max = int(x_max * new_w / orig_w)
        y_min = int(y_min * new_h / orig_h)
        y_max = int(y_max * new_h / orig_h)
        return [x_min, y_min, x_max, y_max]
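
    # For example, with target_size=300 and the default 600x600 original,
    # every coordinate is halved: resize_box([60, 80, 300, 420]) -> [30, 40, 150, 210].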

    def create_and_resize_mask(self, original_size, target_size, polygon):
        """
        Rasterize a single polygon into a binary mask and resize it.

        Parameters:
        - original_size (tuple): (width, height) of the original image, as expected by PIL.
        - target_size (tuple): (width, height) to which the mask is resized.
        - polygon (list of tuples): vertices of one polygon as (x, y) pairs.
        """
        mask = Image.new('L', original_size, 0)
        draw = ImageDraw.Draw(mask)
        if len(polygon) >= 3:
            # A polygon needs at least three vertices to enclose an area
            draw.polygon(polygon, outline=1, fill=1)
        else:
            print("Polygon with insufficient points:", polygon)
        # NEAREST keeps the mask binary; interpolating would blur the edges
        mask = mask.resize(target_size, Image.NEAREST)
        mask_array = np.array(mask)
        return torch.tensor(mask_array, dtype=torch.uint8)
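
    # For example, a triangle [(0, 0), (599, 0), (0, 599)] rasterized at
    # (600, 600) and resized to (300, 300) yields a (300, 300) uint8 tensor
    # with ones inside the triangle and zeros elsewhere.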

    def pad_masks(self, masks, max_pad=100):
        """Pad or truncate the mask tensor to a fixed number of masks."""
        padded_masks = torch.zeros((max_pad, masks.shape[1], masks.shape[2]), dtype=masks.dtype)
        actual_masks = min(max_pad, masks.shape[0])
        padded_masks[:actual_masks] = masks[:actual_masks]
        return padded_masks
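
    # For example, a stack of 5 masks with shape (5, 300, 300) is padded to
    # (100, 300, 300) with the last 95 slices all zeros; a stack of 120 masks
    # would be truncated to the first 100.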

    def pad_mask_classes(self, labels, max_pad=100):
        """Pad or truncate the label tensor to a fixed length, using -1 for padding."""
        padded_labels = torch.full((max_pad,), -1, dtype=torch.int64)
        actual_labels = min(max_pad, len(labels))
        padded_labels[:actual_labels] = labels[:actual_labels]
        return padded_labels
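
    # For example, labels tensor([2, 7]) becomes a length-100 tensor
    # [2, 7, -1, -1, ...], so -1 marks padded (empty) slots downstream.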

    def analyze_bounding_boxes(self):
        """Analyze the bounding box sizes and return statistics."""
        box_widths = []
        box_heights = []
        aspect_ratios = defaultdict(int)
        for image_name in self.image_files:
            # Construct the annotation file name from the image's base name
            base_name = os.path.splitext(os.path.basename(image_name))[0]
            annotation_path = self.annotations_directory + base_name + '.txt'
            try:
                obj = self.s3_client.get_object(Bucket=self.bucket_name, Key=annotation_path)
            except self.s3_client.exceptions.NoSuchKey:
                # No annotation for this image; skip it
                continue
            annotation_lines = obj['Body'].read().decode('utf-8').strip().split('\n')
            for line in annotation_lines:
                bbox_class_part = line.split('[')[0].split(',')
                x_min, y_min, x_max, y_max = map(int, bbox_class_part[0:4])
                width = x_max - x_min
                height = y_max - y_min
                if height == 0:
                    height = 1  # Avoid division by zero
                box_widths.append(width)
                box_heights.append(height)
                aspect_ratio = width / height
                aspect_ratios[aspect_ratio] += 1
        # Compute summary statistics
        if box_widths and box_heights:
            mean_width = np.mean(box_widths)
            mean_height = np.mean(box_heights)
            std_width = np.std(box_widths)
            std_height = np.std(box_heights)
        else:
            mean_width = mean_height = std_width = std_height = 0
        return {
            'box_widths': box_widths,
            'box_heights': box_heights,
            'aspect_ratios': aspect_ratios,
            'mean_width': mean_width,
            'mean_height': mean_height,
            'std_width': std_width,
            'std_height': std_height,
        }
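

# A minimal usage sketch, not part of the original class: the bucket name,
# prefixes, and transform below are hypothetical placeholders, and running it
# requires AWS credentials with read access to the bucket. Since boxes and
# labels vary in length per image, batch_size=1 avoids default-collate errors.
if __name__ == '__main__':
    from torch.utils.data import DataLoader
    from torchvision import transforms

    transform = transforms.Compose([
        transforms.Resize((300, 300)),  # match target_size below
        transforms.ToTensor(),
    ])
    dataset = S3Dataset(
        images_directory='images',            # hypothetical S3 prefix
        annotations_directory='annotations',  # hypothetical S3 prefix
        transform=transform,
        bucket_name='my-dataset-bucket',      # hypothetical bucket name
        target_size=300,
        max_annotations=100,
    )
    loader = DataLoader(dataset, batch_size=1)
    image, target = next(iter(loader))
    print(image.shape)            # e.g. torch.Size([1, 3, 300, 300])
    print(target['masks'].shape)  # torch.Size([1, 100, 300, 300])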