Skip to content

Commit

Permalink
tests fix
Browse files Browse the repository at this point in the history
  • Loading branch information
sarthak-dv committed Aug 21, 2024
1 parent e074787 commit d16081b
Showing 1 changed file with 21 additions and 85 deletions.
106 changes: 21 additions & 85 deletions tardis/visualization/tools/tests/test_liv_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,7 @@
import astropy.units as u
import numpy as np
import pytest
from matplotlib.collections import PolyCollection
from matplotlib.lines import Line2D
from matplotlib.testing.compare import compare_images

from tardis.base import run_tardis
from tardis.io.util import HDFWriterMixin
Expand Down Expand Up @@ -183,25 +182,6 @@ def generate_plot_mpl_hdf(self, plotter_generate_plot_mpl):
"step_x": plotter.step_x.value,
"step_y": plotter.step_y,
}
for index1, data in enumerate(fig.get_children()):
if isinstance(data.get_label(), str):
property_group["label" + str(index1)] = (
data.get_label().encode()
)
# save line plots
if isinstance(data, Line2D):
property_group["data" + str(index1)] = data.get_xydata()
property_group["linepath" + str(index1)] = (
data.get_path().vertices
)

# save artists which correspond to element contributions
if isinstance(data, PolyCollection):
for index2, path in enumerate(data.get_paths()):
property_group[
"polypath" + "ind_" + str(index1) + "ind_" + str(index2)
] = path.vertices

plot_data = PlotDataHDF(**property_group)
return plot_data

Expand All @@ -216,37 +196,26 @@ def test_generate_plot_mpl(
expected.get("plot_data_hdf/" + item).values.flatten(),
getattr(generate_plot_mpl_hdf, item),
)
labels = expected["plot_data_hdf/scalars"]
for index1, data in enumerate(fig.get_children()):
if isinstance(data.get_label(), str):
assert (
getattr(labels, "label" + str(index1)).decode()
== data.get_label()
)
# save line plots
if isinstance(data, Line2D):
np.testing.assert_allclose(
data.get_xydata(),
expected.get("plot_data_hdf/" + "data" + str(index1)),
)
np.testing.assert_allclose(
data.get_path().vertices,
expected.get("plot_data_hdf/" + "linepath" + str(index1)),
)
# save artists which correspond to element contributions
if isinstance(data, PolyCollection):
for index2, path in enumerate(data.get_paths()):
np.testing.assert_almost_equal(
path.vertices,
expected.get(
"plot_data_hdf/"
+ "polypath"
+ "ind_"
+ str(index1)
+ "ind_"
+ str(index2)
),
)

def test_mpl_image(self, plotter_generate_plot_mpl, tmp_path, request):
regression_data = RegressionData(request)
fig, _ = plotter_generate_plot_mpl
regression_data.fpath.parent.mkdir(parents=True, exist_ok=True)
fig.figure.savefig(tmp_path / f"{regression_data.fname_prefix}.png")

if regression_data.enable_generate_reference:
fig.figure.savefig(
regression_data.absolute_regression_data_dir
/ f"{regression_data.fname_prefix}.png"
)
pytest.skip("Skipping test to generate reference data")
else:
expected = str(
regression_data.absolute_regression_data_dir
/ f"{regression_data.fname_prefix}.png"
)
actual = str(tmp_path / f"{regression_data.fname_prefix}.png")
compare_images(expected, actual, tol=0.001)

@pytest.fixture(scope="function", params=combinations)
def plotter_generate_plot_ply(self, request, plotter):
Expand Down Expand Up @@ -283,14 +252,6 @@ def generate_plot_plotly_hdf(self, plotter_generate_plot_ply, request):
"step_x": plotter.step_x.value,
"step_y": plotter.step_y,
}
for index, data in enumerate(fig.data):
group = "_" + str(index)
if data.stackgroup:
property_group[group + "stackgroup"] = data.stackgroup.encode()
if data.name:
property_group[group + "name"] = data.name.encode()
property_group[group + "x"] = data.x
property_group[group + "y"] = data.y
plot_data = PlotDataHDF(**property_group)
return plot_data

Expand All @@ -306,28 +267,3 @@ def test_generate_plot_ply(
expected.get("plot_data_hdf/" + item).values.flatten(),
getattr(generate_plot_plotly_hdf, item),
)

for index, data in enumerate(fig.data):
group = "plot_data_hdf/" + "_" + str(index)
if data.stackgroup:
assert (
data.stackgroup
== getattr(
expected["/plot_data_hdf/scalars"],
"_" + str(index) + "stackgroup",
).decode()
)
if data.name:
assert (
data.name
== getattr(
expected["/plot_data_hdf/scalars"],
"_" + str(index) + "name",
).decode()
)
np.testing.assert_allclose(
data.x, expected.get(group + "x").values.flatten()
)
np.testing.assert_allclose(
data.y, expected.get(group + "y").values.flatten()
)

0 comments on commit d16081b

Please sign in to comment.