Skip to content

Commit

Permalink
Add testing.check_figures_equal to avoid storing baseline images (#555)
Browse files Browse the repository at this point in the history
* Turn check_figures_equal into a decorator function

Also moved test_check_figures_* to a doctest
under check_figures_equal.

* Ensure pytest fixtures can be used with check_figures_equal decorator
* Add notes on using check_figures_equal to CONTRIBUTING.md
* Extra checks to ensure image files exist or not

Co-authored-by: Wei Ji <[email protected]>
  • Loading branch information
seisman and weiji14 authored Sep 4, 2020
1 parent a3d6f84 commit 97a585b
Show file tree
Hide file tree
Showing 4 changed files with 163 additions and 4 deletions.
34 changes: 32 additions & 2 deletions CONTRIBUTING.md
Original file line number Diff line number Diff line change
Expand Up @@ -310,8 +310,38 @@ Leave a comment in the PR and we'll help you out.

### Testing plots

We use the [pytest-mpl](https://github.com/matplotlib/pytest-mpl) plug-in to test plot
generating code.
Writing an image-based test is only slightly more difficult than a simple test.
The main consideration is that you must specify the "baseline" or reference
image, and compare it with a "generated" or test image. This is handled using
the *decorator* functions `@check_figures_equal` and
`@pytest.mark.mpl_image_compare` whose usage are further described below.

#### Using check_figures_equal

This approach draws the same figure using two different methods (the reference
method and the tested method), and checks that both of them are the same.
It takes two `pygmt.Figure` objects ('fig_ref' and 'fig_test'), generates a png
image, and checks for the Root Mean Square (RMS) error between the two.
Here's an example:

```python
@check_figures_equal()
def test_my_plotting_case(fig_ref, fig_test):
"Test that my plotting function works"
fig_ref.grdimage("@earth_relief_01d_g", projection="W120/15c", cmap="geo")
fig_test.grdimage(grid, projection="W120/15c", cmap="geo")
```

Note: This is the recommended way to test plots whenever possible, such as when
we want to compare a reference GMT plot created from NetCDF files with one
generated by PyGMT that passes through several layers of virtualfile machinery.
Using this method will help save space in the git repository by not having to
store baseline images as with the other method below.

#### Using mpl_image_compare

This method uses the [pytest-mpl](https://github.com/matplotlib/pytest-mpl)
plug-in to test plot generating code.
Every time the tests are run, `pytest-mpl` compares the generated plots with known
correct ones stored in `pygmt/tests/baseline`.
If your test created a `pygmt.Figure` object, you can test it by adding a *decorator* and
Expand Down
6 changes: 6 additions & 0 deletions pygmt/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,3 +44,9 @@ class GMTVersionError(GMTError):
"""
Raised when an incompatible version of GMT is being used.
"""


class GMTImageComparisonFailure(AssertionError):
"""
Raised when a comparison between two images fails.
"""
113 changes: 113 additions & 0 deletions pygmt/helpers/testing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,113 @@
"""
Helper functions for testing.
"""

import inspect
import os

from matplotlib.testing.compare import compare_images

from ..exceptions import GMTImageComparisonFailure
from ..figure import Figure


def check_figures_equal(*, tol=0.0, result_dir="result_images"):
"""
Decorator for test cases that generate and compare two figures.
The decorated function must take two arguments, *fig_ref* and *fig_test*,
and draw the reference and test images on them. After the function
returns, the figures are saved and compared.
This decorator is practically identical to matplotlib's check_figures_equal
function, but adapted for PyGMT figures. See also the original code at
https://matplotlib.org/3.3.1/api/testing_api.html#
matplotlib.testing.decorators.check_figures_equal
Parameters
----------
tol : float
The RMS threshold above which the test is considered failed.
result_dir : str
The directory where the figures will be stored.
Examples
--------
>>> import pytest
>>> import shutil
>>> @check_figures_equal(result_dir="tmp_result_images")
... def test_check_figures_equal(fig_ref, fig_test):
... fig_ref.basemap(projection="X5c", region=[0, 5, 0, 5], frame=True)
... fig_test.basemap(projection="X5c", region=[0, 5, 0, 5], frame="af")
>>> test_check_figures_equal()
>>> assert len(os.listdir("tmp_result_images")) == 0
>>> shutil.rmtree(path="tmp_result_images") # cleanup folder if tests pass
>>> @check_figures_equal(result_dir="tmp_result_images")
... def test_check_figures_unequal(fig_ref, fig_test):
... fig_ref.basemap(projection="X5c", region=[0, 5, 0, 5], frame=True)
... fig_test.basemap(projection="X5c", region=[0, 3, 0, 3], frame=True)
>>> with pytest.raises(GMTImageComparisonFailure):
... test_check_figures_unequal()
>>> for suffix in ["", "-expected", "-failed-diff"]:
... assert os.path.exists(
... os.path.join(
... "tmp_result_images",
... f"test_check_figures_unequal{suffix}.png",
... )
... )
>>> shutil.rmtree(path="tmp_result_images") # cleanup folder if tests pass
"""

def decorator(func):

os.makedirs(result_dir, exist_ok=True)
old_sig = inspect.signature(func)

def wrapper(*args, **kwargs):
try:
fig_ref = Figure()
fig_test = Figure()
func(*args, fig_ref=fig_ref, fig_test=fig_test, **kwargs)
ref_image_path = os.path.join(
result_dir, func.__name__ + "-expected.png"
)
test_image_path = os.path.join(result_dir, func.__name__ + ".png")
fig_ref.savefig(ref_image_path)
fig_test.savefig(test_image_path)

# Code below is adapted for PyGMT, and is originally based on
# matplotlib.testing.decorators._raise_on_image_difference
err = compare_images(
expected=ref_image_path,
actual=test_image_path,
tol=tol,
in_decorator=True,
)
if err is None: # Images are the same
os.remove(ref_image_path)
os.remove(test_image_path)
else: # Images are not the same
for key in ["actual", "expected", "diff"]:
err[key] = os.path.relpath(err[key])
raise GMTImageComparisonFailure(
"images not close (RMS %(rms).3f):\n\t%(actual)s\n\t%(expected)s "
% err
)
finally:
del fig_ref
del fig_test

parameters = [
param
for param in old_sig.parameters.values()
if param.name not in {"fig_test", "fig_ref"}
]
new_sig = old_sig.replace(parameters=parameters)
wrapper.__signature__ = new_sig

return wrapper

return decorator
14 changes: 12 additions & 2 deletions pygmt/tests/test_grdimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
Test Figure.grdimage
"""
import numpy as np
import xarray as xr
import pytest
import xarray as xr

from .. import Figure
from ..exceptions import GMTInvalidInput
from ..datasets import load_earth_relief
from ..exceptions import GMTInvalidInput
from ..helpers.testing import check_figures_equal


@pytest.fixture(scope="module", name="grid")
Expand Down Expand Up @@ -93,3 +94,12 @@ def test_grdimage_over_dateline(xrgrid):
xrgrid.gmt.gtype = 1 # geographic coordinate system
fig.grdimage(grid=xrgrid, region="g", projection="A0/0/1c", V="i")
return fig


@check_figures_equal()
def test_grdimage_central_longitude(grid, fig_ref, fig_test):
"""
Test that plotting a grid centred at different longitudes/meridians work.
"""
fig_ref.grdimage("@earth_relief_01d_g", projection="W120/15c", cmap="geo")
fig_test.grdimage(grid, projection="W120/15c", cmap="geo")

0 comments on commit 97a585b

Please sign in to comment.