Skip to content

Commit

Permalink
read entire dataset into memory before trying to access individual states; implement much faster get_length method
Browse files (browse the repository at this point in the history)
  • Loading branch information
svandenhaute committed Jan 30, 2024
1 parent ff6afad commit 8574a6a
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 7 deletions.
20 changes: 15 additions & 5 deletions psiflow/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,11 +234,12 @@ def read_dataset(
atoms = list(read_extxyz(f, index=index_or_indices))[0]
data = FlowAtoms.from_atoms(atoms) # single atoms instance
data.calc = None
else:
else: # read data all at once in memory, then extract whatever we need
all_data = list(read_extxyz(f, index=slice(None)))
if type(index_or_indices) is list:
data = [list(read_extxyz(f, index=i))[0] for i in index_or_indices]
data = [all_data[i] for i in index_or_indices]
elif type(index_or_indices) is slice:
data = list(read_extxyz(f, index=index_or_indices))
data = all_data[index_or_indices]
else:
raise ValueError
data = [FlowAtoms.from_atoms(a) for a in data] # list of atoms
Expand Down Expand Up @@ -285,8 +286,17 @@ def join_dataset(inputs: List[File] = [], outputs: List[File] = []) -> None:

@typeguard.typechecked
def get_length_dataset(inputs: List[File] = []) -> int:
    """Count the frames stored in a dataset file without parsing them.

    The file is scanned line by line. Each frame is expected to start with a
    line holding the number of atoms, followed by ``natoms + 1`` further
    lines that are skipped unread. Counting stops at the first header line
    that cannot be converted to an integer, which includes the empty string
    returned at end of file.

    Args:
        inputs: Files to read; only ``inputs[0]`` is opened.

    Returns:
        The number of frame headers encountered.

    Note:
        A frame whose header is present but whose remaining lines are
        truncated is still counted, because the skipped lines are never
        validated.
    """
    nframes = 0
    with open(inputs[0], "r") as f:
        while True:
            try:
                natoms = int(f.readline())
            except ValueError:
                # Empty string at EOF, or any non-integer line, ends the scan.
                break
            nframes += 1
            for _ in range(natoms + 1):  # skip ahead to the next frame header
                f.readline()
    return nframes


# Wrap get_length_dataset with python_app, bound to the "default_threads" executor.
app_length_dataset = python_app(get_length_dataset, executors=["default_threads"])
Expand Down
6 changes: 4 additions & 2 deletions psiflow/walkers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,10 +161,12 @@ def parameters(self) -> dict[str, Any]:
def multiply(cls, nwalkers: int, data_start: Dataset, **kwargs) -> list[BaseWalker]:
    """Create ``nwalkers`` walkers seeded from the states in ``data_start``.

    Every walker is first constructed from ``data_start[0]`` and then
    assigned state ``i % length`` of the dataset, so the states are reused
    cyclically when there are more walkers than states.

    Args:
        nwalkers: Number of walkers to create.
        data_start: Dataset providing the starting states.
        **kwargs: Forwarded unchanged to the walker constructor.

    Returns:
        The list of walkers; walker ``i`` has ``seed == i``.
    """
    walkers = [cls(data_start[0], **kwargs) for _ in range(nwalkers)]
    # Waits for the dataset length, which is needed for the modulo below.
    length = data_start.length().result()
    # Fetch the dataset as a list once and pick individual states from it,
    # rather than indexing the dataset separately for every walker.
    data_in_memory = data_start.as_list()
    for i, walker in enumerate(walkers):
        state = unpack_i(data_in_memory, i % length)
        walker.seed = i
        walker.set_initial_state(state)
        walker.set_state(state)
    return walkers

@classmethod
Expand Down

0 comments on commit 8574a6a

Please sign in to comment.