Skip to content

Commit

Permalink
Fix issue in solution to xarray conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt committed Sep 17, 2020
1 parent f3c4082 commit d3fe385
Showing 1 changed file with 14 additions and 10 deletions.
24 changes: 14 additions & 10 deletions sunode/problem.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,34 +107,38 @@ def solution_to_xarray( # type: ignore
solution = solution.view(self.state_dtype)[..., 0]
params = self.extract_params(user_data)

def as_dict(array, prepend=None): # type: ignore
def as_dict(array, dims, prepend=None): # type: ignore
if prepend is None:
prepend = []
dtype = array.dtype
out = {}
for name in dtype.names:
if array[name].dtype == np.float64:
out['_'.join(prepend + [name])] = array[name]
out['_'.join(prepend + [name])] = (tuple(dims[name][1]), array[name])
else:
out.update(as_dict(array[name], prepend + [name]))
out.update(as_dict(array[name], dims[name], prepend + [name]))
return out

data = xr.Dataset()
data = xr.Dataset(coords=self.coords)
data['time'] = ('time', tvals)
# TODO t0?
if unstack_state:
state = as_dict(solution, ['solution'])
state = as_dict(solution, self.state_subset.dims, ['solution'])
for name in state:
assert name not in data
data[name] = ('time', state[name])
if name in data:
raise ValueError(f"Variable {name} is not unique.")
dims, vals = state[name]
data[name] = (('time',) + dims, vals)
else:
data['solution'] = ('time', solution)

if unstack_params:
params = as_dict(params, ['parameters'])
params = as_dict(params, self.params_subset.dims, ['parameters'])
for name in params:
assert name not in data
data[name] = params[name]
if name in data:
raise ValueError(f"Variable {name} is not unique.")
dims, vals = params[name]
data[name] = (dims, vals)
else:
data['parameters'] = params

Expand Down

0 comments on commit d3fe385

Please sign in to comment.