forked from NSLS-II-CHX/workflows
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsparsify.py
206 lines (167 loc) · 6.27 KB
/
sparsify.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import sys
sys.path.insert(0, "/nsls2/data/chx/shared/workflows")
import numpy as np
import sparse
import tiled
import time
from functools import reduce
from masks import MaskClient, combine_masks
from pathlib import Path
from prefect import flow, task, get_run_logger
from tiled.client import from_profile
from tiled.structures.sparse import COOStructure
# Directory where exported artifacts are written.
EXPORT_PATH = Path("/nsls2/data/dssi/scratch/prefect-outputs/chx/")
# distributed_client = distributed.Client(n_workers=1, threads_per_worker=1, processes=False)

# Module-level Tiled clients, created at import time.
# NOTE(review): this performs network access on import — confirm that is
# acceptable wherever this module is imported.
tiled_client = from_profile("nsls2", "dask", username=None)["chx"]
tiled_client_chx = tiled_client["raw"]  # read-only raw Bluesky runs
tiled_client_sandbox = tiled_client["sandbox"]  # writable scratch area for results
def get_metadata(run):
    """
    Collect the BlueskyRun metadata into a flat dict.

    Parameters
    ----------
    run: BlueskyRun
        The run whose start/stop documents, descriptors, and patched-in
        detector metadata are harvested.

    Returns
    -------
    metadata: dict
        The BlueskyRun's metadata.
    """
    # TODO: Exception or Warning for collisions with the start metadata.
    metadata = {}
    metadata.update(run.start)
    metadata["suid"] = run.start["uid"].split("-")[0]  # short uid
    metadata.update(run.start.get("plan_args", {}))

    # Get the detector metadata.
    detector = run.start["detectors"][0]
    metadata["detector"] = f"{detector}_image"
    metadata["detectors"] = [detector]

    # Check if the method below applies to runs in general and not just for run2
    metadata.update(run["primary"].descriptors[0]["configuration"][detector]["data"])

    # Get filename prefix.
    # We think we can still sparsify a run if there are no resources,
    # so we don't raise an exception if no resource is found.
    for name, document in run.documents():
        if name == "resource":
            metadata["filename"] = str(
                Path(document.get("root", "/"), document["resource_path"])
            )
            break

    # NOTE(review): run["primary"] was already accessed unconditionally above,
    # so this guard can never be False here; it only matters if that earlier
    # line is ever made conditional.
    if "primary" in run:
        descriptor = run["primary"].descriptors[0]
        # data_keys is a required key in the descriptor document.
        # detector_name must be in the descriptor.data_keys
        metadata["img_shape"] = descriptor["data_keys"][f"{detector}_image"].get(
            "shape", []
        )[:2][::-1]

    # Fix up some datatypes.
    metadata["number of images"] = int(metadata["number of images"])
    metadata["start_time"] = time.strftime(
        "%Y-%m-%d %H:%M:%S", time.localtime(run.start["time"])
    )
    metadata["stop_time"] = time.strftime(
        "%Y-%m-%d %H:%M:%S", time.localtime(run.stop["time"])
    )

    # Get detector metadata recorded alongside the image data.
    dataset = run[f"{detector}_image_metadata_patched_in_at_runtime"]["data"].read()
    file_metadata = {key: dataset[key].values[0] for key in list(dataset)}
    metadata.update(file_metadata)

    # Drop the large per-pixel mask arrays; they are not needed downstream.
    # pop(..., None) tolerates detectors that do not provide these keys,
    # where the previous bare `del` would raise KeyError.
    metadata.pop("pixel_mask", None)
    metadata.pop("binary_mask", None)

    return metadata
def write_sparse_chunk(data, dataset_id=None, block_info=None, dataset=None):
    """
    Sparsify one dask block and write it to a Tiled sparse dataset.

    Intended for use with ``dask.array.Array.map_blocks``, which supplies
    ``block_info``. When ``block_info`` is falsy (e.g. during dask's
    meta-inference call), nothing is written.

    Parameters
    ----------
    data: array-like
        The dense chunk to sparsify.
    dataset_id: str, optional
        Tiled id used to look up the dataset when ``dataset`` is not given.
    block_info: dict, optional
        Dask-provided block metadata; ``block_info[None]["chunk-location"]``
        addresses the destination block.
    dataset: tiled client node, optional
        Pre-resolved destination dataset; looked up from the sandbox if None.

    Returns
    -------
    The input ``data``, unchanged.
    """
    if block_info:
        # Only sparsify when we will actually write; the previous version
        # converted to COO unconditionally and discarded the result on the
        # no-op path.
        result = sparse.COO(data)
        if dataset is None:
            # Re-create the client inside the worker; the module-level
            # client may not be usable from a dask worker process.
            tiled_client = from_profile("nsls2", "dask", username=None)["chx"]
            tiled_client_sandbox = tiled_client["sandbox"]
            dataset = tiled_client_sandbox[dataset_id]
        dataset.write_block(
            coords=result.coords,
            data=result.data,
            block=block_info[None]["chunk-location"],
        )
    # Returning `data` instead of `result` gives a nice performance
    # improvement. This causes dask not to update the resulting
    # dataset, which is not needed because we wrote the result
    # to tiled.
    return data
# TODO: Change "chip_mask" back to "pixel_mask"
@task
def sparsify(ref, mask_names=None):
    """
    Performs sparsification.

    Parameters
    ----------
    ref: string
        This is the reference to the BlueskyRun to be exported. It can be
        a partial uid, a full uid, a scan_id, or an index (e.g. -1).
    mask_names: list, optional
        A list of mask names to be applied. Defaults to ["chip_mask"].

    Returns
    -------
    dataset_uid: string
        The uid of the resulting dataset.
    """
    # Use a None sentinel instead of a mutable default argument.
    if mask_names is None:
        mask_names = ["chip_mask"]
    logger = get_run_logger()

    # Get the BlueskyRun from Tiled.
    run = tiled_client_chx[ref]

    # Compose the run metadata.
    metadata = get_metadata(run)
    detector_name = metadata["detector"]

    # Load the images.
    images = run["primary"]["data"][detector_name].read()

    # TODO: Save the detector image in the correct orientation,
    # so we don't have to rotate it.
    # Rotate the images if the detector is eiger500K_single_image.
    if detector_name == "eiger500K_single_image":
        images = np.rot90(images, axes=(3, 2))

    # Get the masks, record which ones were used, and combine them.
    mask_client = MaskClient(tiled_client_sandbox)
    uid_masks = [mask_client.get_mask(detector_name, mask_name)
                 for mask_name in mask_names]
    uids = [uid for uid, mask in uid_masks]
    masks = [mask for uid, mask in uid_masks]
    metadata['masks_names'] = mask_names
    metadata['mask_uids'] = uids
    mask = combine_masks(masks)

    # Flip the images.
    images = np.flip(images, axis=2)

    # Apply the mask to every frame.
    image_count = images.shape[1]
    mask = np.broadcast_to(mask, (image_count,) + mask.shape)
    images = images * mask

    # Let dask pick the chunk size.
    # Set the block_size_limit equal to the tiled size limit.
    images = images.rechunk(block_size_limit=75_000_000)

    # Create a new dataset in tiled.
    dataset = tiled_client_sandbox.new(
        "sparse",
        COOStructure(
            shape=images.shape,
            chunks=images.chunks,
        ),
        metadata=metadata,
    )
    dataset_id = dataset.item["id"]

    # Run sparsification and write the data to tiled in parallel.
    _ = images.map_blocks(
        write_sparse_chunk, dataset_id=dataset_id, dataset=dataset
    ).compute()
    logger.info(f"dataset_id: {dataset_id}")
    return dataset_id
# Make the Prefect Flow.
# A separate command is needed to register it with the Prefect server.
@flow
def sparsify_flow(ref, mask_names=None):
    """
    Prefect flow: log environment versions, then run the sparsify task.

    Parameters
    ----------
    ref: string
        Reference to the BlueskyRun (partial/full uid, scan_id, or index).
    mask_names: list, optional
        Mask names to apply. Defaults to ["chip_mask"].
    """
    # Use a None sentinel instead of a mutable default argument.
    if mask_names is None:
        mask_names = ["chip_mask"]
    logger = get_run_logger()
    logger.info(f"tiled: {tiled.__version__}")
    logger.info(f"sparse: {sparse.__version__}")
    logger.info(f"profiles: {tiled.profiles.list_profiles()['nsls2']}")
    # TODO: Change "chip_mask" back to "pixel_mask"
    sparsify(ref, mask_names=mask_names)