diff --git a/psiflow/reference/reference.py b/psiflow/reference/reference.py index 317d094..8479bba 100644 --- a/psiflow/reference/reference.py +++ b/psiflow/reference/reference.py @@ -115,7 +115,7 @@ def compute(self, dataset: Dataset, *outputs: Optional[Union[str, tuple]]): for output in outputs_: if output not in self.outputs: raise ValueError("output {} not in {}".format(output, self.outputs)) - index = outputs_.index(output) + index = self.outputs.index(output) to_return.append(compute_outputs[index]) if len(outputs_) == 1: return to_return[0] diff --git a/tests/test_reference.py b/tests/test_reference.py index a018e88..29b5eb9 100644 --- a/tests/test_reference.py +++ b/tests/test_reference.py @@ -181,8 +181,10 @@ def test_reference_d3(context, dataset, tmp_path): assert state.energy is not None assert state.energy < 0.0 # dispersion is attractive - data = dataset[:3].evaluate(reference) - energy = reference.compute(dataset[:3], "energy") + subset = dataset[:3] + data = subset.evaluate(reference) + energy = reference.compute(subset, "energy") + forces = reference.compute(subset, "forces") assert np.allclose( data.get("energy").result(), @@ -193,6 +195,7 @@ def test_reference_d3(context, dataset, tmp_path): 0.0, ) + assert len(forces.result().shape) == 3 @pytest.mark.filterwarnings("ignore:Original input file not found") def test_cp2k_success(context, simple_cp2k_input):