Skip to content

Commit

Permalink
enh:add niceness
Browse files Browse the repository at this point in the history
  • Loading branch information
blasscoc committed Nov 14, 2024
1 parent 3218051 commit 3265173
Show file tree
Hide file tree
Showing 3 changed files with 267 additions and 84 deletions.
84 changes: 84 additions & 0 deletions examples/real_data_example/src/progress.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
#ifndef REAL_DATA_EXAMPLE_PROGRESS_H_
#define REAL_DATA_EXAMPLE_PROGRESS_H_

#include <mdio/mdio.h>

#include <chrono>
#include <indicators/cursor_control.hpp>
#include <indicators/indeterminate_progress_bar.hpp>
#include <memory>
#include <sstream>
#include <string>
#include <thread>

namespace mdio {

template <typename T>
class ProgressTracker {
public:
explicit ProgressTracker(const std::string& message = "Loading data...")
: message_(message),
bar_{
indicators::option::BarWidth{40},
indicators::option::Start{"["},
indicators::option::Fill{"·"},
indicators::option::Lead{"<==>"},
indicators::option::End{"]"},
indicators::option::PostfixText{message},
indicators::option::ForegroundColor{indicators::Color::yellow},
indicators::option::FontStyles{std::vector<indicators::FontStyle>{
indicators::FontStyle::bold}},
} {
std::cout << "\033[2K\r" << std::flush;
indicators::show_console_cursor(false);
start_time_ = std::chrono::steady_clock::now();
}

void tick() {
auto current_time = std::chrono::steady_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
current_time - start_time_)
.count();
std::stringstream time_str;
time_str << message_ << " " << elapsed << "s";
bar_.set_option(indicators::option::PostfixText{time_str.str()});
bar_.tick();
}

void complete() {
bar_.mark_as_completed();
bar_.set_option(
indicators::option::ForegroundColor{indicators::Color::green});
bar_.set_option(indicators::option::PostfixText{message_ + " completed"});
indicators::show_console_cursor(true);
}

~ProgressTracker() { indicators::show_console_cursor(true); }

private:
std::string message_;
std::chrono::steady_clock::time_point start_time_;
indicators::IndeterminateProgressBar bar_;
};

template <typename T = void, DimensionIndex R = dynamic_rank,
ArrayOriginKind OriginKind = offset_origin>
Future<VariableData<T, R, OriginKind>> ReadWithProgress(
Variable<T>& variable, const std::string& message = "Loading data...") {
auto tracker = std::make_shared<ProgressTracker<T>>(message);
auto future = variable.Read();

std::thread([tracker, future]() {
while (!future.ready()) {
tracker->tick();
std::this_thread::sleep_for(std::chrono::milliseconds(200));
}
tracker->complete();
}).detach();

return future;
}

} // namespace mdio

#endif // REAL_DATA_EXAMPLE_PROGRESS_H_
169 changes: 85 additions & 84 deletions examples/real_data_example/src/real_data_example.cc
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,11 @@
#include <indicators/indeterminate_progress_bar.hpp>
#include <indicators/progress_bar.hpp>

#include "absl/flags/flag.h"
#include "absl/flags/parse.h"
#include "absl/strings/str_split.h"
#include "interpolation.h"
#include "progress.h"
#include "seismic_numpy.h"
#include "seismic_png.h"
#include "tensorstore/tensorstore.h"
Expand All @@ -13,105 +17,102 @@

using Index = mdio::Index;

absl::Status Run() {
ABSL_FLAG(std::string, inline_range, "{inline,700,701,1}",
"Inline range in format {inline,start,end,step}");
ABSL_FLAG(std::string, xline_range, "{crossline,500,700,1}",
"Crossline range in format {crossline,start,end,step}");
ABSL_FLAG(std::string, depth_range, "",
"Optional depth range in format {depth,start,end,step}");
ABSL_FLAG(std::string, variable_name, "seismic",
"Name of the seismic variable");
ABSL_FLAG(bool, print_dataset, false, "Print the dataset URL and return");
ABSL_FLAG(std::string, dataset_path,
"s3://tgs-opendata-poseidon/full_stack_agc.mdio",
"The path to the dataset");

using Index = mdio::Index;

// Make Run a template function for both the type and descriptors
template <typename T, typename... Descriptors>
absl::Status Run(const Descriptors... descriptors) {
// New feature to print the dataset if the flag is set

MDIO_ASSIGN_OR_RETURN(
auto dataset,
mdio::Dataset::Open(
std::string("s3://tgs-opendata-poseidon/full_stack_agc.mdio"),
mdio::constants::kOpen)
mdio::Dataset::Open(std::string(absl::GetFlag(FLAGS_dataset_path)),
mdio::constants::kOpen)
.result())
std::cout << dataset << std::endl;

auto inline_index = 700;

// Select a inline slice
mdio::SliceDescriptor desc1 = {"inline", inline_index, inline_index + 1, 1};
mdio::SliceDescriptor desc2 = {"crossline", 500, 700, 1};

MDIO_ASSIGN_OR_RETURN(auto inline_slice, dataset.isel({desc1, desc2}))

// Add seismic data reading example
MDIO_ASSIGN_OR_RETURN(auto seismic_var,
inline_slice.variables.get<float>("seismic"))

// Create and configure the progress bar
indicators::IndeterminateProgressBar bar{
indicators::option::BarWidth{40},
indicators::option::Start{"["},
indicators::option::Fill{"·"},
indicators::option::Lead{"<==>"},
indicators::option::End{"]"},
indicators::option::PostfixText{"Loading seismic data..."},
indicators::option::ForegroundColor{indicators::Color::yellow},
indicators::option::FontStyles{
std::vector<indicators::FontStyle>{indicators::FontStyle::bold}},
};

// Clear any existing output and hide cursor
std::cout << "\033[2K\r" << std::flush;
indicators::show_console_cursor(false);

// Start the async read operation
auto future = seismic_var.Read();
auto start_time = std::chrono::steady_clock::now();

// Show progress bar while waiting for data
while (!future.ready()) {
auto current_time = std::chrono::steady_clock::now();
auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
current_time - start_time)
.count();
std::stringstream time_str;
time_str << "Loading seismic data... " << elapsed << "s";
bar.set_option(indicators::option::PostfixText{time_str.str()});
bar.tick();
// Reduce update frequency slightly
std::this_thread::sleep_for(std::chrono::milliseconds(200));
if (absl::GetFlag(FLAGS_print_dataset)) {
std::cout << dataset << std::endl;
return absl::OkStatus(); // Return early if just printing the dataset
}

// Get the result and stop the bar
MDIO_ASSIGN_OR_RETURN(auto seismic_data, future.result())
bar.mark_as_completed();
bar.set_option(indicators::option::ForegroundColor{indicators::Color::green});
bar.set_option(indicators::option::PostfixText{"Seismic data loaded"});
// slice the dataset
MDIO_ASSIGN_OR_RETURN(auto inline_slice, dataset.isel(descriptors...));

// Show cursor again
indicators::show_console_cursor(true);
// Get variable with template type T using the variable name from the CLI
MDIO_ASSIGN_OR_RETURN(auto variable, inline_slice.variables.template get<T>(
absl::GetFlag(FLAGS_variable_name)));

auto seismic_accessor = seismic_data.get_data_accessor();

std::cout << seismic_data.dimensions() << std::endl;
if (variable.rank() != 3) {
return absl::InvalidArgumentError("Seismic data must be 3D");
}

// Get the domain of the variable
auto xline_inclusive_min =
seismic_var.get_store().domain()[1].interval().inclusive_min();
auto xline_exclusive_max =
seismic_var.get_store().domain()[1].interval().exclusive_max();
MDIO_ASSIGN_OR_RETURN(auto seismic_data, ReadWithProgress(variable).result())

auto depth_inclusive_min =
seismic_var.get_store().domain()[2].interval().inclusive_min();
auto depth_exclusive_max =
seismic_var.get_store().domain()[2].interval().exclusive_max();
auto seismic_accessor = seismic_data.get_data_accessor();

// Write numpy file
MDIO_RETURN_IF_ERROR(WriteNumpy(seismic_accessor, inline_index,
xline_inclusive_min, xline_exclusive_max,
depth_inclusive_min, depth_exclusive_max));

// Write PNG file
MDIO_RETURN_IF_ERROR(WritePNG(seismic_accessor, inline_index,
xline_inclusive_min, xline_exclusive_max,
depth_inclusive_min, depth_exclusive_max));
MDIO_RETURN_IF_ERROR(WriteNumpy(seismic_accessor, "seismic_slice.npy"));

return absl::OkStatus();
}

int main() {
auto status = Run();
if (!status.ok()) {
std::cout << "Task failed.\n" << status << std::endl;
} else {
std::cout << "MDIO Example Complete.\n";
mdio::SliceDescriptor ParseRange(std::string_view range) {
// Remove leading/trailing whitespace and braces
if (range.empty()) {
// FIXME - we need a better way to handle this
return {"ignore_me", 0, 1, 1};
}
return status.ok() ? 0 : 1;

range.remove_prefix(range.find_first_not_of(" {"));
range.remove_suffix(range.length() - range.find_last_not_of("} ") - 1);

// Split by comma into parts
std::vector<std::string_view> parts = absl::StrSplit(range, ',');

if (parts.size() != 4) {
throw std::runtime_error(
"Invalid range format. Expected {label,start,end,step}");
}

// Clean up the label (first part) by removing quotes and spaces
auto label = parts[0];
label.remove_prefix(std::min(label.find_first_not_of(" \""), label.size()));
label.remove_suffix(
label.length() -
std::min(label.find_last_not_of(" \"") + 1, label.size()));

return {
label, // dimension name
std::stoi(std::string(parts[1])), // start
std::stoi(std::string(parts[2])), // end
std::stoi(std::string(parts[3])) // step
};
}

int main(int argc, char* argv[]) {
absl::ParseCommandLine(argc, argv);

// keep me in memory
auto inline_range = absl::GetFlag(FLAGS_inline_range);
auto crossline_range = absl::GetFlag(FLAGS_xline_range);
auto depth_range = absl::GetFlag(FLAGS_depth_range);

auto desc1 = ParseRange(inline_range);
auto desc2 = ParseRange(crossline_range);
auto desc3 = ParseRange(depth_range);

return Run<float>(desc1, desc2, desc3).ok() ? 0 : 1;
}
98 changes: 98 additions & 0 deletions examples/real_data_example/src/seismic_numpy.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,98 @@
#pragma once

#include <absl/status/status.h>
#include <mdio/mdio.h>

#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <unordered_map>

#include "tensorstore/tensorstore.h"

std::string GetNumpyDtypeJson(const tensorstore::DataType& dtype) {
using namespace tensorstore;

// Mapping of DataTypeId to NumPy dtype JSON-like string format
static const std::unordered_map<DataTypeId, std::string> dtype_to_numpy_json{
{DataTypeId::int8_t, "|i1"}, {DataTypeId::uint8_t, "|u1"},
{DataTypeId::int16_t, "<i2"}, {DataTypeId::uint16_t, "<u2"},
{DataTypeId::int32_t, "<i4"}, {DataTypeId::uint32_t, "<u4"},
{DataTypeId::int64_t, "<i8"}, {DataTypeId::uint64_t, "<u8"},
{DataTypeId::float32_t, "<f4"}, {DataTypeId::float64_t, "<f8"},
};

// Find the matching NumPy dtype string or return "unknown" if not found
auto it = dtype_to_numpy_json.find(dtype.id());
if (it != dtype_to_numpy_json.end()) {
return it->second;
} else {
return "unknown dtype";
}
}

template <typename T, mdio::DimensionIndex R, mdio::ArrayOriginKind OriginKind>
absl::Status WriteNumpy(const mdio::SharedArray<T, R, OriginKind>& accessor,
const std::string& filename) {
auto domain = accessor.domain();
auto type_id = GetNumpyDtypeJson(accessor.dtype());

if (domain.rank() != 3) {
return absl::InvalidArgumentError("Expected a 3D array");
}

auto inline_inclusive_min = domain[0].inclusive_min();
auto inline_exclusive_max = domain[0].exclusive_max();

auto xline_inclusive_min = domain[1].inclusive_min();
auto xline_exclusive_max = domain[1].exclusive_max();

auto depth_inclusive_min = domain[2].inclusive_min();
auto depth_exclusive_max = domain[2].exclusive_max();

const int width = inline_exclusive_max - inline_inclusive_min;
const int height = xline_exclusive_max - xline_inclusive_min;
const int depth = depth_exclusive_max - depth_inclusive_min;

std::ofstream outfile(filename, std::ios::binary);
if (!outfile) {
return absl::InvalidArgumentError("Could not open numpy file for writing");
}

// Write the numpy header
// Format described at:
// https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html
const char magic_string[] = "\x93NUMPY\x01\x00"; // Magic string and version
outfile.write(magic_string, sizeof(magic_string) - 1); // Write as byte array

// Construct the header
std::stringstream header;
header << "{'descr': '" << type_id << "', 'fortran_order': False, 'shape': ("
<< width << ", " << height << ", " << depth << ")}";

// Pad header to multiple of 64 bytes
int header_len = header.str().length() + 1; // +1 for newline
int pad_len = 64 - (header_len % 64);
if (pad_len < 1) pad_len += 64;
std::string padding(pad_len - 1, ' ');
header << padding << "\n";

// Write header length and header
uint16_t header_size = header.str().length();
outfile.write(reinterpret_cast<char*>(&header_size), sizeof(header_size));
outfile.write(header.str().c_str(), header_size);

// Write the data
for (int il = inline_inclusive_min; il < inline_exclusive_max; ++il) {
for (int xl = xline_inclusive_min; xl < xline_exclusive_max; ++xl) {
for (int zl = depth_inclusive_min; zl < depth_exclusive_max; ++zl) {
T value = accessor(il, xl, zl);
outfile.write(reinterpret_cast<const char*>(&value), sizeof(T));
}
}
}

outfile.close();
return absl::OkStatus();
}

0 comments on commit 3265173

Please sign in to comment.