diff --git a/mdio/dataset.h b/mdio/dataset.h index c4db168..b637a12 100644 --- a/mdio/dataset.h +++ b/mdio/dataset.h @@ -1174,8 +1174,10 @@ class Dataset { * @return An `mdio::Future` if the selection was valid and successful, or an * error if the selection was invalid. */ - Future> SelectField(const std::string variableName, - const std::string fieldName) { + template + Future> SelectField(const std::string& variableName, + const std::string& fieldName) { // Ensure that the variable exists in the Dataset if (!variables.contains_key(variableName)) { return absl::Status( @@ -1184,11 +1186,11 @@ class Dataset { } // Grab the Variable from the Dataset - auto varRes = variables.get(variableName); - if (!varRes.status().ok()) { - return varRes.status(); - } - mdio::Variable var = varRes.value(); + MDIO_ASSIGN_OR_RETURN(auto var, variables.at(variableName)); + // Preserve the intervals so it can be re-sliced to the same dimensions + MDIO_ASSIGN_OR_RETURN(auto intervals, var.get_intervals()); + intervals.pop_back(); // Remove the byte dimension (Doesn't matter if it + // doesn't exist) // Ensure that the Variable is of dtype structarray auto spec = var.spec(); @@ -1248,6 +1250,12 @@ class Dataset { } base["kvstore"]["driver"] = specJson["kvstore"]["driver"]; base["kvstore"]["path"] = specJson["kvstore"]["path"]; + // Remove trailing slashes. This causes issue #130 + while (base["kvstore"]["path"].get().back() == '/') { + base["kvstore"]["path"] = + base["kvstore"]["path"].get().substr( + 0, base["kvstore"]["path"].get().size() - 1); + } // Handle cloud stores if (specJson["kvstore"].contains("bucket")) { @@ -1258,18 +1266,31 @@ class Dataset { base["kvstore"]["path"] = cloudPath; } - auto fieldedVar = mdio::Variable<>::Open(base, constants::kOpen); + auto fieldedVar = mdio::Variable::Open(base, constants::kOpen); - auto pair = tensorstore::PromiseFuturePair>::Make(); + auto pair = tensorstore::PromiseFuturePair>::Make(); fieldedVar.ExecuteWhenReady( - [this, promise = pair.promise, - variableName](tensorstore::ReadyFuture> readyFut) { + [this, promise = pair.promise, variableName, intervals]( + tensorstore::ReadyFuture> readyFut) { auto ready_result = readyFut.result(); if (!ready_result.ok()) { promise.SetResult(ready_result.status()); } else { - this->variables.add(variableName, ready_result.value()); - promise.SetResult(ready_result); + // Re-slice the Variable to the same dimensions as the original + std::vector> slices; + slices.reserve(intervals.size() - 1); + for (const auto& interval : intervals) { + slices.emplace_back(mdio::RangeDescriptor( + {interval.label, interval.inclusive_min, + interval.exclusive_max, 1})); + } + auto slicedVarRes = ready_result.value().slice(slices); + if (!slicedVarRes.status().ok()) { + promise.SetResult(slicedVarRes.status()); + } else { + this->variables.add(variableName, ready_result.value()); + promise.SetResult(slicedVarRes); + } } }); return pair.future; diff --git a/mdio/dataset_test.cc b/mdio/dataset_test.cc index a97f0fb..bf17b50 100644 --- a/mdio/dataset_test.cc +++ b/mdio/dataset_test.cc @@ -656,6 +656,72 @@ TEST(Dataset, selRepeatedRangeStop) { ASSERT_FALSE(sliceRes.status().ok()); } +TEST(Dataset, selectField) { + auto json_var = GetToyExample(); + + auto dataset = mdio::Dataset::from_json(json_var, "zarrs/acceptance", + mdio::constants::kCreateClean); + + ASSERT_TRUE(dataset.status().ok()) << dataset.status(); + auto ds = dataset.value(); + + mdio::RangeDescriptor desc1 = {"inline", 0, 5, 1}; + mdio::RangeDescriptor desc2 = {"crossline", 0, 5, 1}; + mdio::RangeDescriptor desc3 = {"depth", 0, 5, 1}; + + auto sliceRes = ds.isel(desc1, desc2, desc3); + ASSERT_TRUE(sliceRes.status().ok()) << sliceRes.status(); + auto slicedDs = sliceRes.value(); + + auto structImageHeadersRes = + slicedDs.variables.get("image_headers"); + ASSERT_TRUE(structImageHeadersRes.status().ok()) + << structImageHeadersRes.status(); + auto structImageHeaders = structImageHeadersRes.value(); + auto asdf = structImageHeaders.get_intervals(); + ASSERT_TRUE(asdf.status().ok()) << asdf.status(); + auto structIntervals = asdf.value(); + ASSERT_EQ(structIntervals.size(), 3); + + auto selectedVarFut = + slicedDs.SelectField("image_headers", "cdp-x"); + ASSERT_TRUE(selectedVarFut.status().ok()) << selectedVarFut.status(); + + auto typedInervalsRes = selectedVarFut.value().get_intervals(); + ASSERT_TRUE(typedInervalsRes.status().ok()) << typedInervalsRes.status(); + auto typedIntervals = typedInervalsRes.value(); + ASSERT_EQ(typedIntervals.size(), 2); + + EXPECT_EQ(typedIntervals[0].label, structIntervals[0].label) + << "Dimension 0 labels did not match"; + EXPECT_EQ(typedIntervals[1].label, structIntervals[1].label) + << "Dimension 1 labels did not match"; + EXPECT_EQ(typedIntervals[0].inclusive_min, structIntervals[0].inclusive_min) + << "Dimension 0 min did not match"; + EXPECT_EQ(typedIntervals[1].inclusive_min, structIntervals[1].inclusive_min) + << "Dimension 1 min did not match"; + EXPECT_EQ(typedIntervals[0].exclusive_max, structIntervals[0].exclusive_max) + << "Dimension 0 max did not match"; + EXPECT_EQ(typedIntervals[1].exclusive_max, structIntervals[1].exclusive_max) + << "Dimension 1 max did not match"; +} + +TEST(Dataset, selectFieldName) { + auto json_var = GetToyExample(); + + auto dataset = mdio::Dataset::from_json(json_var, "zarrs/acceptance", + mdio::constants::kCreateClean); + + ASSERT_TRUE(dataset.status().ok()) << dataset.status(); + auto ds = dataset.value(); + + auto selectedVarFut = + ds.SelectField("image_headers", "cdp-x"); + ASSERT_TRUE(selectedVarFut.status().ok()) << selectedVarFut.status(); + EXPECT_EQ(selectedVarFut.value().get_variable_name(), "image_headers") + << "Expected selected variable to be named image_headers"; +} + TEST(Dataset, fromConsolidatedMeta) { auto json_vars = GetToyExample();