Skip to content

Commit

Permalink
Fix image aspect ratio anyres logic issue
Browse files Browse the repository at this point in the history
  • Loading branch information
kcz358 committed Aug 14, 2024
1 parent a514188 commit dfa6e54
Showing 1 changed file with 27 additions and 1 deletion.
28 changes: 27 additions & 1 deletion python/sglang/srt/models/llava.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
h = num_patch_height * height
w = num_patch_width * width
new_h, new_w = unpad_image_shape(h, w, image_s)
print(new_h, new_w, image_s)

if "anyres_max" in self.config.image_aspect_ratio:
matched_anyres_max_num_patches = re.match(
Expand All @@ -109,6 +110,8 @@ def pad_input_ids(self, input_ids, pad_value, pt_shape=None, image_size=None):
pad_ids = pad_value * (
(new_image_feature_len + len(pad_value)) // len(pad_value)
)
print(f"New image_feature_len : {new_image_feature_len}")
print(f"Pad ids len : {(new_image_feature_len + len(pad_value)) // len(pad_value)}")
# print("calculated new_image_feature_len: ", new_image_feature_len)
offset = input_ids.index(self.config.image_token_index)
# old_len + pad_len - 1, because we need to remove image_token_id
Expand Down Expand Up @@ -255,7 +258,29 @@ def forward(
# num_patch_height, num_patch_width, height, width, -1
# )

if (
if "unpad" in self.mm_patch_merge_type:
image_feature = image_feature.permute(
4, 0, 2, 1, 3
).contiguous()
image_feature = image_feature.flatten(1, 2).flatten(
2, 3
)
image_feature = unpad_image(
image_feature, image_sizes[image_idx][0]
)
image_feature = torch.cat(
(
image_feature,
self.language_model.model.image_newline[
:, None, None
].expand(*image_feature.shape[:-1], 1),
),
dim=-1,
)
image_feature = image_feature.flatten(1, 2).transpose(
0, 1
)
elif (
"unpad" in self.mm_patch_merge_type
and "anyres_max" in image_aspect_ratio
and matched_anyres_max_num_patches
Expand Down Expand Up @@ -349,6 +374,7 @@ def forward(
] = image_features[pt][j]
except RuntimeError as e:
print(f"RuntimeError in llava image encoding: {e}")
print(image_features[pt].shape)
print(input_embeds.shape)
print(start_idx, image_offsets[i])
pt += 1
Expand Down

0 comments on commit dfa6e54

Please sign in to comment.