Skip to content

Commit

Permalink
Python: Retain explicitly specified 'unique' with frac_as_map
Browse files Browse the repository at this point in the history
  • Loading branch information
dbaston committed Aug 9, 2024
1 parent 024a6c4 commit f387f56
Show file tree
Hide file tree
Showing 4 changed files with 29 additions and 13 deletions.
6 changes: 5 additions & 1 deletion python/src/exactextract/exact_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,7 +218,11 @@ def prep_ops(stats, values, weights=None, *, add_unique=False):
if op.stat == "unique":
del need_unique[op.key()]
for op in need_unique.values():
ops += prepare_operations([change_stat(op, "unique")], values, weights)
for unique_op in prepare_operations(
[change_stat(op, "unique")], values, weights
):
unique_op.name = "@delete@" + unique_op.name
ops.append(unique_op)

return ops

Expand Down
27 changes: 19 additions & 8 deletions python/src/exactextract/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def __init__(
Args:
array_type: type that should be used to represent array outputs.
Either "numpy" (default) or "list".
Either "numpy" (default), "list", or "set"
map_fields: An optional dictionary of fields to be created by
interpreting one field as keys and another as values, in the format
``{ dst_field : (src_keys, src_vals) }``. For example, the fields
Expand All @@ -38,24 +38,28 @@ def __init__(
"""
super().__init__()

if array_type not in ("numpy", "list"):
if array_type not in ("numpy", "list", "set"):
raise ValueError("Unsupported array_type: " + array_type)

self.array_type = array_type
self.feature_list = []
self.map_fields = map_fields or {}
self.op_fields = {}
self.ops = []
self.remove_temporary_fields = False

def add_operation(self, op):
self.ops.append(op)
if op.name.startswith("@delete@"):
self.remove_temporary_fields = True

def write(self, feature):
f = JSONFeature()
feature.copy_to(f)

self._convert_arrays(f)
self._create_map_fields(f)
self._convert_arrays(f)
self._remove_temporary_fields(f)

self.feature_list.append(f.feature)

Expand All @@ -71,6 +75,12 @@ def _convert_arrays(self, f):
for k in props:
if type(props[k]) is np.ndarray:
props[k] = list(props[k])
elif self.array_type == "set":
import numpy as np

for k in props:
if type(props[k]) is np.ndarray:
props[k] = set(props[k])

def _fields_for_stat(self, stat):
return [o.name for o in self.ops if o.stat == stat]
Expand All @@ -79,7 +89,6 @@ def _create_map_fields(self, f):
props = f.feature["properties"]

new_fields = {}
to_delete = set()
for field in self.map_fields:
key_stat, val_stat = self.map_fields[field]

Expand All @@ -99,12 +108,14 @@ def _create_map_fields(self, f):
k: v for k, v in zip(props[key_field], props[val_field])
}

to_delete.add(key_field)
to_delete.add(val_field)
for field in to_delete:
del props[field]
props.update(new_fields)

def _remove_temporary_fields(self, f):
if self.remove_temporary_fields:
for field in list(f.feature["properties"]):
if field.startswith("@delete@"):
del f.feature["properties"][field]


class PandasWriter(Writer):
"""Creates a (Geo)Pandas DataFrame"""
Expand Down
4 changes: 2 additions & 2 deletions python/src/pybindings/operation_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ bind_operation(py::module& m)
.def("key", &Operation::key)
.def("weighted", &Operation::weighted)
.def_readonly("stat", &Operation::stat)
.def_readonly("name", &Operation::name)
.def_readwrite("name", &Operation::name)
.def_readonly("values", &Operation::values)
.def_readonly("weights", &Operation::weights);

Expand All @@ -135,7 +135,7 @@ bind_operation(py::module& m)

m.def("prepare_operations", py::overload_cast<const std::vector<std::string>&, const std::vector<RasterSource*>&, const std::vector<RasterSource*>&>(&prepare_operations));

m.def("change_stat", [](const Operation& op, std::string stat) {
m.def("change_stat", [](const Operation& op, std::string_view stat) {
auto sd = op.descriptor();
sd.name = "";
sd.stat = stat;
Expand Down
5 changes: 3 additions & 2 deletions python/tests/test_exact_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -1019,14 +1019,15 @@ def test_output_map_fields():
result = exact_extract(
rast,
square,
["frac", "weighted_frac"],
["frac", "weighted_frac", "unique"],
weights=weights,
output_options={"frac_as_map": True},
output_options={"frac_as_map": True, "array_type": "set"},
)

assert result[0]["properties"] == {
"frac": {1: 0.25, 2: 0.5, 3: 0.1875, 4: 0.0625},
"weighted_frac": {1: 0.0, 2: 0.0, 3: 0.75, 4: 0.25},
"unique": {1, 2, 3, 4},
}


Expand Down

0 comments on commit f387f56

Please sign in to comment.