Add image alignment research code #69

Merged
merged 1 commit on Jun 13, 2024
Changes from all commits

3 changes: 3 additions & 0 deletions OCR/alignment/__init__.py
@@ -0,0 +1,3 @@
from .four_point_transform import FourPointTransform as FourPointTransform
from .image_homography import ImageHomography as ImageHomography
from .random_perspective_transform import RandomPerspectiveTransform as RandomPerspectiveTransform
47 changes: 47 additions & 0 deletions OCR/alignment/four_point_transform.py
@@ -0,0 +1,47 @@
"""
Uses quadrilateral edge detection and executes a four-point perspective transform on a source image.
"""

from pathlib import Path
import functools

import numpy as np
import cv2 as cv


class FourPointTransform:
def __init__(self, image: Path):
self.image = cv.imread(str(image), cv.IMREAD_GRAYSCALE)

@staticmethod
def _order_points(quadrilateral: np.ndarray) -> np.ndarray:
"Reorder points from a 4x2 input array representing the vertices of a quadrilateral, such that the coordinates of each vertex are arranged in order from top left, top right, bottom right, and bottom left."
quadrilateral = quadrilateral.reshape(4, 2)
output_quad = np.zeros([4, 2]).astype(np.float32)
s = quadrilateral.sum(axis=1)
output_quad[0] = quadrilateral[np.argmin(s)]
output_quad[2] = quadrilateral[np.argmax(s)]
diff = np.diff(quadrilateral, axis=1)
output_quad[1] = quadrilateral[np.argmin(diff)]
output_quad[3] = quadrilateral[np.argmax(diff)]
return output_quad

def find_largest_contour(self):
"""Compute contours for an image and find the biggest one by area."""
# OpenCV 4.x returns (contours, hierarchy)
contours, _ = cv.findContours(self.image, cv.RETR_EXTERNAL, cv.CHAIN_APPROX_SIMPLE)
return functools.reduce(lambda a, b: b if cv.contourArea(a) < cv.contourArea(b) else a, contours)

def simplify_polygon(self, contour):
"""Simplify to a polygon with (hopefully four) vertices."""
perimeter = cv.arcLength(contour, True)
return cv.approxPolyDP(contour, 0.01 * perimeter, True)

def dewarp(self) -> np.ndarray:
biggest_contour = self.find_largest_contour()
simplified = self.simplify_polygon(biggest_contour)

height, width = self.image.shape
destination = np.array([[0, 0], [width, 0], [width, height], [0, height]], dtype=np.float32)

M = cv.getPerspectiveTransform(self._order_points(simplified), destination)
return cv.warpPerspective(self.image, M, (width, height))
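For reference, a small standalone sketch of the sum/difference ordering trick used in `_order_points`, with hypothetical corner coordinates; this is an illustration only, not part of the module:

import numpy as np

# Hypothetical corners of a quadrilateral, deliberately out of order, as (x, y):
quad = np.array([[10, 90], [90, 10], [10, 10], [90, 90]], dtype=np.float32)

s = quad.sum(axis=1)       # x + y: [100, 100,  20, 180]
d = np.diff(quad, axis=1)  # y - x: [ 80, -80,   0,   0]

top_left = quad[np.argmin(s)]      # [10, 10] -> smallest x + y
bottom_right = quad[np.argmax(s)]  # [90, 90] -> largest x + y
top_right = quad[np.argmin(d)]     # [90, 10] -> most negative y - x
bottom_left = quad[np.argmax(d)]   # [10, 90] -> most positive y - x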
50 changes: 50 additions & 0 deletions OCR/alignment/image_homography.py
@@ -0,0 +1,50 @@
from pathlib import Path

import numpy as np
import cv2 as cv


class ImageHomography:
def __init__(self, template: Path, match_ratio=0.3):
"""Initialize the image homography pipeline with a `template` image."""
if match_ratio >= 1 or match_ratio <= 0:
raise ValueError("`match_ratio` must be between 0 and 1")

self.template = cv.imread(str(template))
self.match_ratio = match_ratio
self._sift = cv.SIFT_create()

def estimate_self_similarity(self):
"""Calibrate `match_ratio` using a self-similarity metric."""
raise NotImplementedError

def compute_descriptors(self, img):
"""Compute SIFT descriptors for a target `img`."""
return self._sift.detectAndCompute(img, None)

def knn_match(self, descriptor_template, descriptor_query):
"""Return k-nearest neighbors match (k=2) between descriptors generated from a template and query image."""
matcher = cv.DescriptorMatcher_create(cv.DescriptorMatcher_FLANNBASED)
return matcher.knnMatch(descriptor_template, descriptor_query, 2)

def transform_homography(self, other):
"""Run the image homography pipeline against a query image."""
# find the keypoints and descriptors with SIFT
kp1, descriptors1 = self.compute_descriptors(self.template)
kp2, descriptors2 = self.compute_descriptors(other)

knn_matches = self.knn_match(descriptors1, descriptors2)

# Filter matches using Lowe's ratio test.
# Use an aggressive threshold here; the larger the image, the more aggressively this should be filtered.
good_matches = []
for m, n in knn_matches:
if m.distance < self.match_ratio * n.distance:
good_matches.append(m)

src_pts = np.float32([kp1[m.queryIdx].pt for m in good_matches]).reshape(-1, 1, 2)
dst_pts = np.float32([kp2[m.trainIdx].pt for m in good_matches]).reshape(-1, 1, 2)

M, _ = cv.findHomography(dst_pts, src_pts, cv.RANSAC, 5.0)

return cv.warpPerspective(other, M, (self.template.shape[1], self.template.shape[0]))
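A minimal usage sketch of the homography pipeline, assuming a clean template scan and a photographed copy on disk (the file names are hypothetical; the import path follows the one used in the tests):

import cv2 as cv
from alignment import ImageHomography

aligner = ImageHomography("template_form.jpg", match_ratio=0.3)
query = cv.imread("photographed_form.jpg")
aligned = aligner.transform_homography(query)  # query warped into the template's coordinate frame
cv.imwrite("aligned_form.jpg", aligned)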
37 changes: 37 additions & 0 deletions OCR/alignment/random_perspective_transform.py
@@ -0,0 +1,37 @@
"""
Applies a perspective transform to a base image with between 10% and 90% distortion.
"""

from pathlib import Path

import torchvision.transforms as transforms
from PIL import Image


class RandomPerspectiveTransform:
"""Generate a random perspective transform based on a template `image`."""

def __init__(self, image: Path):
self.image = Image.open(image)

@staticmethod
def _make_transform(distortion_scale: float) -> object:
"""
Internal function to create a composed transformer for perspective warps.

This must be instantiated anew on each call so that the RandomPerspective transformer is truly random across repeated calls to `transform`.
"""
return transforms.Compose(
[
transforms.RandomPerspective(distortion_scale=distortion_scale, p=1),
transforms.ToTensor(),
transforms.ToPILImage(),
]
)

def transform(self, distortion_scale: float) -> object:
"""Warp the template image with specified `distortion_scale`."""
if distortion_scale < 0 or distortion_scale >= 1:
raise ValueError("`distortion_scale` must be between 0 and 1")

return self._make_transform(distortion_scale)(self.image)
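A minimal usage sketch with a hypothetical file name; `transform` returns a PIL image, which can be saved directly or converted to a NumPy array for OpenCV:

from alignment import RandomPerspectiveTransform

warper = RandomPerspectiveTransform("template_form.jpg")
warped = warper.transform(distortion_scale=0.3)  # PIL image with a random perspective warp applied
warped.save("warped_form.jpg")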
39 changes: 38 additions & 1 deletion OCR/poetry.lock


1 change: 1 addition & 0 deletions OCR/pyproject.toml
@@ -12,6 +12,7 @@ opencv-python = "^4.9.0.80"
python-dotenv = "^1.0.1"
transformers = {extras = ["torch"], version = "^4.39.3"}
pillow = "^10.3.0"
torchvision = "^0.18.0"

[tool.poetry.group.dev.dependencies]
ruff = "^0.3.7"
27 changes: 27 additions & 0 deletions OCR/tests/alignment_test.py
@@ -0,0 +1,27 @@
import os

import cv2 as cv
import numpy as np

from alignment import ImageHomography, RandomPerspectiveTransform


path = os.path.dirname(__file__)

template_image_path = os.path.join(path, "./assets/template_hep.jpg")
filled_image_path = os.path.join(path, "./assets/form_filled_hep.jpg")
filled_image = cv.imread(filled_image_path)


class TestAlignment:
def test_random_warp(self):
transformed = RandomPerspectiveTransform(filled_image_path).transform(distortion_scale=0.1)
assert np.median(cv.absdiff(np.array(transformed), filled_image)) > 0

def test_alignment_filled(self):
aligner = ImageHomography(template_image_path)
warped_image = np.array(RandomPerspectiveTransform(filled_image_path).transform(distortion_scale=0.1))
aligned = aligner.transform_homography(warped_image)
res = cv.absdiff(aligner.template, aligned)
assert aligner.template.shape == warped_image.shape
assert np.median(res) == 0
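Assuming the project layout above and a Poetry-managed environment, these tests would likely be run from the OCR directory with something like:

poetry run pytest tests/alignment_test.py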
Binary file added OCR/tests/assets/template_hep.jpg