diff --git a/sunode/problem.py b/sunode/problem.py index 2700fb4..54d396f 100644 --- a/sunode/problem.py +++ b/sunode/problem.py @@ -107,34 +107,38 @@ def solution_to_xarray( # type: ignore solution = solution.view(self.state_dtype)[..., 0] params = self.extract_params(user_data) - def as_dict(array, prepend=None): # type: ignore + def as_dict(array, dims, prepend=None): # type: ignore if prepend is None: prepend = [] dtype = array.dtype out = {} for name in dtype.names: if array[name].dtype == np.float64: - out['_'.join(prepend + [name])] = array[name] + out['_'.join(prepend + [name])] = (tuple(dims[name][1]), array[name]) else: - out.update(as_dict(array[name], prepend + [name])) + out.update(as_dict(array[name], dims[name], prepend + [name])) return out - data = xr.Dataset() + data = xr.Dataset(coords=self.coords) data['time'] = ('time', tvals) # TODO t0? if unstack_state: - state = as_dict(solution, ['solution']) + state = as_dict(solution, self.state_subset.dims, ['solution']) for name in state: - assert name not in data - data[name] = ('time', state[name]) + if name in data: + raise ValueError(f"Variable {name} is not unique.") + dims, vals = state[name] + data[name] = (('time',) + dims, vals) else: data['solution'] = ('time', solution) if unstack_params: - params = as_dict(params, ['parameters']) + params = as_dict(params, self.params_subset.dims, ['parameters']) for name in params: - assert name not in data - data[name] = params[name] + if name in data: + raise ValueError(f"Variable {name} is not unique.") + dims, vals = params[name] + data[name] = (dims, vals) else: data['parameters'] = params