From 3265173d07a227c27819b015b57c68c7ef3b096a Mon Sep 17 00:00:00 2001 From: blasscoc Date: Thu, 14 Nov 2024 08:48:46 -0600 Subject: [PATCH] enh:add niceness --- examples/real_data_example/src/progress.h | 84 +++++++++ .../src/real_data_example.cc | 169 +++++++++--------- .../real_data_example/src/seismic_numpy.h | 98 ++++++++++ 3 files changed, 267 insertions(+), 84 deletions(-) create mode 100644 examples/real_data_example/src/progress.h create mode 100644 examples/real_data_example/src/seismic_numpy.h diff --git a/examples/real_data_example/src/progress.h b/examples/real_data_example/src/progress.h new file mode 100644 index 0000000..d2c2aa8 --- /dev/null +++ b/examples/real_data_example/src/progress.h @@ -0,0 +1,84 @@ +#ifndef REAL_DATA_EXAMPLE_PROGRESS_H_ +#define REAL_DATA_EXAMPLE_PROGRESS_H_ + +#include + +#include +#include +#include +#include +#include +#include +#include + +namespace mdio { + +template +class ProgressTracker { + public: + explicit ProgressTracker(const std::string& message = "Loading data...") + : message_(message), + bar_{ + indicators::option::BarWidth{40}, + indicators::option::Start{"["}, + indicators::option::Fill{"·"}, + indicators::option::Lead{"<==>"}, + indicators::option::End{"]"}, + indicators::option::PostfixText{message}, + indicators::option::ForegroundColor{indicators::Color::yellow}, + indicators::option::FontStyles{std::vector{ + indicators::FontStyle::bold}}, + } { + std::cout << "\033[2K\r" << std::flush; + indicators::show_console_cursor(false); + start_time_ = std::chrono::steady_clock::now(); + } + + void tick() { + auto current_time = std::chrono::steady_clock::now(); + auto elapsed = std::chrono::duration_cast( + current_time - start_time_) + .count(); + std::stringstream time_str; + time_str << message_ << " " << elapsed << "s"; + bar_.set_option(indicators::option::PostfixText{time_str.str()}); + bar_.tick(); + } + + void complete() { + bar_.mark_as_completed(); + bar_.set_option( + indicators::option::ForegroundColor{indicators::Color::green}); + bar_.set_option(indicators::option::PostfixText{message_ + " completed"}); + indicators::show_console_cursor(true); + } + + ~ProgressTracker() { indicators::show_console_cursor(true); } + + private: + std::string message_; + std::chrono::steady_clock::time_point start_time_; + indicators::IndeterminateProgressBar bar_; +}; + +template +Future> ReadWithProgress( + Variable& variable, const std::string& message = "Loading data...") { + auto tracker = std::make_shared>(message); + auto future = variable.Read(); + + std::thread([tracker, future]() { + while (!future.ready()) { + tracker->tick(); + std::this_thread::sleep_for(std::chrono::milliseconds(200)); + } + tracker->complete(); + }).detach(); + + return future; +} + +} // namespace mdio + +#endif // REAL_DATA_EXAMPLE_PROGRESS_H_ \ No newline at end of file diff --git a/examples/real_data_example/src/real_data_example.cc b/examples/real_data_example/src/real_data_example.cc index 6202a3a..88944c0 100644 --- a/examples/real_data_example/src/real_data_example.cc +++ b/examples/real_data_example/src/real_data_example.cc @@ -4,7 +4,11 @@ #include #include +#include "absl/flags/flag.h" +#include "absl/flags/parse.h" +#include "absl/strings/str_split.h" #include "interpolation.h" +#include "progress.h" #include "seismic_numpy.h" #include "seismic_png.h" #include "tensorstore/tensorstore.h" @@ -13,105 +17,102 @@ using Index = mdio::Index; -absl::Status Run() { +ABSL_FLAG(std::string, inline_range, "{inline,700,701,1}", + "Inline range in format {inline,start,end,step}"); +ABSL_FLAG(std::string, xline_range, "{crossline,500,700,1}", + "Crossline range in format {crossline,start,end,step}"); +ABSL_FLAG(std::string, depth_range, "", + "Optional depth range in format {depth,start,end,step}"); +ABSL_FLAG(std::string, variable_name, "seismic", + "Name of the seismic variable"); +ABSL_FLAG(bool, print_dataset, false, "Print the dataset URL and return"); +ABSL_FLAG(std::string, dataset_path, + "s3://tgs-opendata-poseidon/full_stack_agc.mdio", + "The path to the dataset"); + +using Index = mdio::Index; + +// Make Run a template function for both the type and descriptors +template +absl::Status Run(const Descriptors... descriptors) { + // New feature to print the dataset if the flag is set + MDIO_ASSIGN_OR_RETURN( auto dataset, - mdio::Dataset::Open( - std::string("s3://tgs-opendata-poseidon/full_stack_agc.mdio"), - mdio::constants::kOpen) + mdio::Dataset::Open(std::string(absl::GetFlag(FLAGS_dataset_path)), + mdio::constants::kOpen) .result()) - std::cout << dataset << std::endl; - - auto inline_index = 700; - - // Select a inline slice - mdio::SliceDescriptor desc1 = {"inline", inline_index, inline_index + 1, 1}; - mdio::SliceDescriptor desc2 = {"crossline", 500, 700, 1}; - - MDIO_ASSIGN_OR_RETURN(auto inline_slice, dataset.isel({desc1, desc2})) - - // Add seismic data reading example - MDIO_ASSIGN_OR_RETURN(auto seismic_var, - inline_slice.variables.get("seismic")) - - // Create and configure the progress bar - indicators::IndeterminateProgressBar bar{ - indicators::option::BarWidth{40}, - indicators::option::Start{"["}, - indicators::option::Fill{"·"}, - indicators::option::Lead{"<==>"}, - indicators::option::End{"]"}, - indicators::option::PostfixText{"Loading seismic data..."}, - indicators::option::ForegroundColor{indicators::Color::yellow}, - indicators::option::FontStyles{ - std::vector{indicators::FontStyle::bold}}, - }; - // Clear any existing output and hide cursor - std::cout << "\033[2K\r" << std::flush; - indicators::show_console_cursor(false); - - // Start the async read operation - auto future = seismic_var.Read(); - auto start_time = std::chrono::steady_clock::now(); - - // Show progress bar while waiting for data - while (!future.ready()) { - auto current_time = std::chrono::steady_clock::now(); - auto elapsed = std::chrono::duration_cast( - current_time - start_time) - .count(); - std::stringstream time_str; - time_str << "Loading seismic data... " << elapsed << "s"; - bar.set_option(indicators::option::PostfixText{time_str.str()}); - bar.tick(); - // Reduce update frequency slightly - std::this_thread::sleep_for(std::chrono::milliseconds(200)); + if (absl::GetFlag(FLAGS_print_dataset)) { + std::cout << dataset << std::endl; + return absl::OkStatus(); // Return early if just printing the dataset } - // Get the result and stop the bar - MDIO_ASSIGN_OR_RETURN(auto seismic_data, future.result()) - bar.mark_as_completed(); - bar.set_option(indicators::option::ForegroundColor{indicators::Color::green}); - bar.set_option(indicators::option::PostfixText{"Seismic data loaded"}); + // slice the dataset + MDIO_ASSIGN_OR_RETURN(auto inline_slice, dataset.isel(descriptors...)); - // Show cursor again - indicators::show_console_cursor(true); + // Get variable with template type T using the variable name from the CLI + MDIO_ASSIGN_OR_RETURN(auto variable, inline_slice.variables.template get( + absl::GetFlag(FLAGS_variable_name))); - auto seismic_accessor = seismic_data.get_data_accessor(); - - std::cout << seismic_data.dimensions() << std::endl; + if (variable.rank() != 3) { + return absl::InvalidArgumentError("Seismic data must be 3D"); + } - // Get the domain of the variable - auto xline_inclusive_min = - seismic_var.get_store().domain()[1].interval().inclusive_min(); - auto xline_exclusive_max = - seismic_var.get_store().domain()[1].interval().exclusive_max(); + MDIO_ASSIGN_OR_RETURN(auto seismic_data, ReadWithProgress(variable).result()) - auto depth_inclusive_min = - seismic_var.get_store().domain()[2].interval().inclusive_min(); - auto depth_exclusive_max = - seismic_var.get_store().domain()[2].interval().exclusive_max(); + auto seismic_accessor = seismic_data.get_data_accessor(); // Write numpy file - MDIO_RETURN_IF_ERROR(WriteNumpy(seismic_accessor, inline_index, - xline_inclusive_min, xline_exclusive_max, - depth_inclusive_min, depth_exclusive_max)); - - // Write PNG file - MDIO_RETURN_IF_ERROR(WritePNG(seismic_accessor, inline_index, - xline_inclusive_min, xline_exclusive_max, - depth_inclusive_min, depth_exclusive_max)); + MDIO_RETURN_IF_ERROR(WriteNumpy(seismic_accessor, "seismic_slice.npy")); return absl::OkStatus(); } -int main() { - auto status = Run(); - if (!status.ok()) { - std::cout << "Task failed.\n" << status << std::endl; - } else { - std::cout << "MDIO Example Complete.\n"; +mdio::SliceDescriptor ParseRange(std::string_view range) { + // Remove leading/trailing whitespace and braces + if (range.empty()) { + // FIXME - we need a better way to handle this + return {"ignore_me", 0, 1, 1}; } - return status.ok() ? 0 : 1; + + range.remove_prefix(range.find_first_not_of(" {")); + range.remove_suffix(range.length() - range.find_last_not_of("} ") - 1); + + // Split by comma into parts + std::vector parts = absl::StrSplit(range, ','); + + if (parts.size() != 4) { + throw std::runtime_error( + "Invalid range format. Expected {label,start,end,step}"); + } + + // Clean up the label (first part) by removing quotes and spaces + auto label = parts[0]; + label.remove_prefix(std::min(label.find_first_not_of(" \""), label.size())); + label.remove_suffix( + label.length() - + std::min(label.find_last_not_of(" \"") + 1, label.size())); + + return { + label, // dimension name + std::stoi(std::string(parts[1])), // start + std::stoi(std::string(parts[2])), // end + std::stoi(std::string(parts[3])) // step + }; +} + +int main(int argc, char* argv[]) { + absl::ParseCommandLine(argc, argv); + + // keep me in memory + auto inline_range = absl::GetFlag(FLAGS_inline_range); + auto crossline_range = absl::GetFlag(FLAGS_xline_range); + auto depth_range = absl::GetFlag(FLAGS_depth_range); + + auto desc1 = ParseRange(inline_range); + auto desc2 = ParseRange(crossline_range); + auto desc3 = ParseRange(depth_range); + + return Run(desc1, desc2, desc3).ok() ? 0 : 1; } \ No newline at end of file diff --git a/examples/real_data_example/src/seismic_numpy.h b/examples/real_data_example/src/seismic_numpy.h new file mode 100644 index 0000000..d81560b --- /dev/null +++ b/examples/real_data_example/src/seismic_numpy.h @@ -0,0 +1,98 @@ +#pragma once + +#include +#include + +#include +#include +#include +#include +#include + +#include "tensorstore/tensorstore.h" + +std::string GetNumpyDtypeJson(const tensorstore::DataType& dtype) { + using namespace tensorstore; + + // Mapping of DataTypeId to NumPy dtype JSON-like string format + static const std::unordered_map dtype_to_numpy_json{ + {DataTypeId::int8_t, "|i1"}, {DataTypeId::uint8_t, "|u1"}, + {DataTypeId::int16_t, "second; + } else { + return "unknown dtype"; + } +} + +template +absl::Status WriteNumpy(const mdio::SharedArray& accessor, + const std::string& filename) { + auto domain = accessor.domain(); + auto type_id = GetNumpyDtypeJson(accessor.dtype()); + + if (domain.rank() != 3) { + return absl::InvalidArgumentError("Expected a 3D array"); + } + + auto inline_inclusive_min = domain[0].inclusive_min(); + auto inline_exclusive_max = domain[0].exclusive_max(); + + auto xline_inclusive_min = domain[1].inclusive_min(); + auto xline_exclusive_max = domain[1].exclusive_max(); + + auto depth_inclusive_min = domain[2].inclusive_min(); + auto depth_exclusive_max = domain[2].exclusive_max(); + + const int width = inline_exclusive_max - inline_inclusive_min; + const int height = xline_exclusive_max - xline_inclusive_min; + const int depth = depth_exclusive_max - depth_inclusive_min; + + std::ofstream outfile(filename, std::ios::binary); + if (!outfile) { + return absl::InvalidArgumentError("Could not open numpy file for writing"); + } + + // Write the numpy header + // Format described at: + // https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html + const char magic_string[] = "\x93NUMPY\x01\x00"; // Magic string and version + outfile.write(magic_string, sizeof(magic_string) - 1); // Write as byte array + + // Construct the header + std::stringstream header; + header << "{'descr': '" << type_id << "', 'fortran_order': False, 'shape': (" + << width << ", " << height << ", " << depth << ")}"; + + // Pad header to multiple of 64 bytes + int header_len = header.str().length() + 1; // +1 for newline + int pad_len = 64 - (header_len % 64); + if (pad_len < 1) pad_len += 64; + std::string padding(pad_len - 1, ' '); + header << padding << "\n"; + + // Write header length and header + uint16_t header_size = header.str().length(); + outfile.write(reinterpret_cast(&header_size), sizeof(header_size)); + outfile.write(header.str().c_str(), header_size); + + // Write the data + for (int il = inline_inclusive_min; il < inline_exclusive_max; ++il) { + for (int xl = xline_inclusive_min; xl < xline_exclusive_max; ++xl) { + for (int zl = depth_inclusive_min; zl < depth_exclusive_max; ++zl) { + T value = accessor(il, xl, zl); + outfile.write(reinterpret_cast(&value), sizeof(T)); + } + } + } + + outfile.close(); + return absl::OkStatus(); +} \ No newline at end of file