Skip to content

Commit

Permalink
Merge pull request #134 from TGSAI/133_SelectField_fix
Browse files Browse the repository at this point in the history
Fix SelectField Dataset method returning unsliced void Variable
  • Loading branch information
markspec authored Oct 22, 2024
2 parents f5cc49d + 2739713 commit 5f270f0
Show file tree
Hide file tree
Showing 2 changed files with 100 additions and 13 deletions.
47 changes: 34 additions & 13 deletions mdio/dataset.h
Original file line number Diff line number Diff line change
Expand Up @@ -1174,8 +1174,10 @@ class Dataset {
* @return An `mdio::Future` if the selection was valid and successful, or an
* error if the selection was invalid.
*/
Future<Variable<>> SelectField(const std::string variableName,
const std::string fieldName) {
template <typename T = void, DimensionIndex R = dynamic_rank,
ReadWriteMode M = ReadWriteMode::dynamic>
Future<Variable<T, R, M>> SelectField(const std::string& variableName,
const std::string& fieldName) {
// Ensure that the variable exists in the Dataset
if (!variables.contains_key(variableName)) {
return absl::Status(
Expand All @@ -1184,11 +1186,11 @@ class Dataset {
}

// Grab the Variable from the Dataset
auto varRes = variables.get(variableName);
if (!varRes.status().ok()) {
return varRes.status();
}
mdio::Variable var = varRes.value();
MDIO_ASSIGN_OR_RETURN(auto var, variables.at(variableName));
// Preserve the intervals so it can be re-sliced to the same dimensions
MDIO_ASSIGN_OR_RETURN(auto intervals, var.get_intervals());
intervals.pop_back(); // Remove the byte dimension (Doesn't matter if it
// doesn't exist)

// Ensure that the Variable is of dtype structarray
auto spec = var.spec();
Expand Down Expand Up @@ -1248,6 +1250,12 @@ class Dataset {
}
base["kvstore"]["driver"] = specJson["kvstore"]["driver"];
base["kvstore"]["path"] = specJson["kvstore"]["path"];
// Remove trailing slashes. This causes issue #130
while (base["kvstore"]["path"].get<std::string>().back() == '/') {
base["kvstore"]["path"] =
base["kvstore"]["path"].get<std::string>().substr(
0, base["kvstore"]["path"].get<std::string>().size() - 1);
}

// Handle cloud stores
if (specJson["kvstore"].contains("bucket")) {
Expand All @@ -1258,18 +1266,31 @@ class Dataset {
base["kvstore"]["path"] = cloudPath;
}

auto fieldedVar = mdio::Variable<>::Open(base, constants::kOpen);
auto fieldedVar = mdio::Variable<T, R, M>::Open(base, constants::kOpen);

auto pair = tensorstore::PromiseFuturePair<mdio::Variable<>>::Make();
auto pair = tensorstore::PromiseFuturePair<mdio::Variable<T, R, M>>::Make();
fieldedVar.ExecuteWhenReady(
[this, promise = pair.promise,
variableName](tensorstore::ReadyFuture<mdio::Variable<>> readyFut) {
[this, promise = pair.promise, variableName, intervals](
tensorstore::ReadyFuture<mdio::Variable<T, R, M>> readyFut) {
auto ready_result = readyFut.result();
if (!ready_result.ok()) {
promise.SetResult(ready_result.status());
} else {
this->variables.add(variableName, ready_result.value());
promise.SetResult(ready_result);
// Re-slice the Variable to the same dimensions as the original
std::vector<mdio::RangeDescriptor<Index>> slices;
slices.reserve(intervals.size() - 1);
for (const auto& interval : intervals) {
slices.emplace_back(mdio::RangeDescriptor<Index>(
{interval.label, interval.inclusive_min,
interval.exclusive_max, 1}));
}
auto slicedVarRes = ready_result.value().slice(slices);
if (!slicedVarRes.status().ok()) {
promise.SetResult(slicedVarRes.status());
} else {
this->variables.add(variableName, ready_result.value());
promise.SetResult(slicedVarRes);
}
}
});
return pair.future;
Expand Down
66 changes: 66 additions & 0 deletions mdio/dataset_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,72 @@ TEST(Dataset, selRepeatedRangeStop) {
ASSERT_FALSE(sliceRes.status().ok());
}

TEST(Dataset, selectField) {
auto json_var = GetToyExample();

auto dataset = mdio::Dataset::from_json(json_var, "zarrs/acceptance",
mdio::constants::kCreateClean);

ASSERT_TRUE(dataset.status().ok()) << dataset.status();
auto ds = dataset.value();

mdio::RangeDescriptor<mdio::Index> desc1 = {"inline", 0, 5, 1};
mdio::RangeDescriptor<mdio::Index> desc2 = {"crossline", 0, 5, 1};
mdio::RangeDescriptor<mdio::Index> desc3 = {"depth", 0, 5, 1};

auto sliceRes = ds.isel(desc1, desc2, desc3);
ASSERT_TRUE(sliceRes.status().ok()) << sliceRes.status();
auto slicedDs = sliceRes.value();

auto structImageHeadersRes =
slicedDs.variables.get<mdio::dtypes::byte_t>("image_headers");
ASSERT_TRUE(structImageHeadersRes.status().ok())
<< structImageHeadersRes.status();
auto structImageHeaders = structImageHeadersRes.value();
auto asdf = structImageHeaders.get_intervals();
ASSERT_TRUE(asdf.status().ok()) << asdf.status();
auto structIntervals = asdf.value();
ASSERT_EQ(structIntervals.size(), 3);

auto selectedVarFut =
slicedDs.SelectField<mdio::dtypes::int32_t>("image_headers", "cdp-x");
ASSERT_TRUE(selectedVarFut.status().ok()) << selectedVarFut.status();

auto typedInervalsRes = selectedVarFut.value().get_intervals();
ASSERT_TRUE(typedInervalsRes.status().ok()) << typedInervalsRes.status();
auto typedIntervals = typedInervalsRes.value();
ASSERT_EQ(typedIntervals.size(), 2);

EXPECT_EQ(typedIntervals[0].label, structIntervals[0].label)
<< "Dimension 0 labels did not match";
EXPECT_EQ(typedIntervals[1].label, structIntervals[1].label)
<< "Dimension 1 labels did not match";
EXPECT_EQ(typedIntervals[0].inclusive_min, structIntervals[0].inclusive_min)
<< "Dimension 0 min did not match";
EXPECT_EQ(typedIntervals[1].inclusive_min, structIntervals[1].inclusive_min)
<< "Dimension 1 min did not match";
EXPECT_EQ(typedIntervals[0].exclusive_max, structIntervals[0].exclusive_max)
<< "Dimension 0 max did not match";
EXPECT_EQ(typedIntervals[1].exclusive_max, structIntervals[1].exclusive_max)
<< "Dimension 1 max did not match";
}

TEST(Dataset, selectFieldName) {
auto json_var = GetToyExample();

auto dataset = mdio::Dataset::from_json(json_var, "zarrs/acceptance",
mdio::constants::kCreateClean);

ASSERT_TRUE(dataset.status().ok()) << dataset.status();
auto ds = dataset.value();

auto selectedVarFut =
ds.SelectField<mdio::dtypes::int32_t>("image_headers", "cdp-x");
ASSERT_TRUE(selectedVarFut.status().ok()) << selectedVarFut.status();
EXPECT_EQ(selectedVarFut.value().get_variable_name(), "image_headers")
<< "Expected selected variable to be named image_headers";
}

TEST(Dataset, fromConsolidatedMeta) {
auto json_vars = GetToyExample();

Expand Down

0 comments on commit 5f270f0

Please sign in to comment.