-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
267 additions
and
84 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
#ifndef REAL_DATA_EXAMPLE_PROGRESS_H_ | ||
#define REAL_DATA_EXAMPLE_PROGRESS_H_ | ||
|
||
#include <mdio/mdio.h> | ||
|
||
#include <chrono> | ||
#include <indicators/cursor_control.hpp> | ||
#include <indicators/indeterminate_progress_bar.hpp> | ||
#include <memory> | ||
#include <sstream> | ||
#include <string> | ||
#include <thread> | ||
|
||
namespace mdio { | ||
|
||
template <typename T> | ||
class ProgressTracker { | ||
public: | ||
explicit ProgressTracker(const std::string& message = "Loading data...") | ||
: message_(message), | ||
bar_{ | ||
indicators::option::BarWidth{40}, | ||
indicators::option::Start{"["}, | ||
indicators::option::Fill{"·"}, | ||
indicators::option::Lead{"<==>"}, | ||
indicators::option::End{"]"}, | ||
indicators::option::PostfixText{message}, | ||
indicators::option::ForegroundColor{indicators::Color::yellow}, | ||
indicators::option::FontStyles{std::vector<indicators::FontStyle>{ | ||
indicators::FontStyle::bold}}, | ||
} { | ||
std::cout << "\033[2K\r" << std::flush; | ||
indicators::show_console_cursor(false); | ||
start_time_ = std::chrono::steady_clock::now(); | ||
} | ||
|
||
void tick() { | ||
auto current_time = std::chrono::steady_clock::now(); | ||
auto elapsed = std::chrono::duration_cast<std::chrono::seconds>( | ||
current_time - start_time_) | ||
.count(); | ||
std::stringstream time_str; | ||
time_str << message_ << " " << elapsed << "s"; | ||
bar_.set_option(indicators::option::PostfixText{time_str.str()}); | ||
bar_.tick(); | ||
} | ||
|
||
void complete() { | ||
bar_.mark_as_completed(); | ||
bar_.set_option( | ||
indicators::option::ForegroundColor{indicators::Color::green}); | ||
bar_.set_option(indicators::option::PostfixText{message_ + " completed"}); | ||
indicators::show_console_cursor(true); | ||
} | ||
|
||
~ProgressTracker() { indicators::show_console_cursor(true); } | ||
|
||
private: | ||
std::string message_; | ||
std::chrono::steady_clock::time_point start_time_; | ||
indicators::IndeterminateProgressBar bar_; | ||
}; | ||
|
||
template <typename T = void, DimensionIndex R = dynamic_rank, | ||
ArrayOriginKind OriginKind = offset_origin> | ||
Future<VariableData<T, R, OriginKind>> ReadWithProgress( | ||
Variable<T>& variable, const std::string& message = "Loading data...") { | ||
auto tracker = std::make_shared<ProgressTracker<T>>(message); | ||
auto future = variable.Read(); | ||
|
||
std::thread([tracker, future]() { | ||
while (!future.ready()) { | ||
tracker->tick(); | ||
std::this_thread::sleep_for(std::chrono::milliseconds(200)); | ||
} | ||
tracker->complete(); | ||
}).detach(); | ||
|
||
return future; | ||
} | ||
|
||
} // namespace mdio | ||
|
||
#endif // REAL_DATA_EXAMPLE_PROGRESS_H_ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,98 @@ | ||
#pragma once | ||
|
||
#include <absl/status/status.h> | ||
#include <mdio/mdio.h> | ||
|
||
#include <fstream> | ||
#include <iostream> | ||
#include <sstream> | ||
#include <string> | ||
#include <unordered_map> | ||
|
||
#include "tensorstore/tensorstore.h" | ||
|
||
std::string GetNumpyDtypeJson(const tensorstore::DataType& dtype) { | ||
using namespace tensorstore; | ||
|
||
// Mapping of DataTypeId to NumPy dtype JSON-like string format | ||
static const std::unordered_map<DataTypeId, std::string> dtype_to_numpy_json{ | ||
{DataTypeId::int8_t, "|i1"}, {DataTypeId::uint8_t, "|u1"}, | ||
{DataTypeId::int16_t, "<i2"}, {DataTypeId::uint16_t, "<u2"}, | ||
{DataTypeId::int32_t, "<i4"}, {DataTypeId::uint32_t, "<u4"}, | ||
{DataTypeId::int64_t, "<i8"}, {DataTypeId::uint64_t, "<u8"}, | ||
{DataTypeId::float32_t, "<f4"}, {DataTypeId::float64_t, "<f8"}, | ||
}; | ||
|
||
// Find the matching NumPy dtype string or return "unknown" if not found | ||
auto it = dtype_to_numpy_json.find(dtype.id()); | ||
if (it != dtype_to_numpy_json.end()) { | ||
return it->second; | ||
} else { | ||
return "unknown dtype"; | ||
} | ||
} | ||
|
||
template <typename T, mdio::DimensionIndex R, mdio::ArrayOriginKind OriginKind> | ||
absl::Status WriteNumpy(const mdio::SharedArray<T, R, OriginKind>& accessor, | ||
const std::string& filename) { | ||
auto domain = accessor.domain(); | ||
auto type_id = GetNumpyDtypeJson(accessor.dtype()); | ||
|
||
if (domain.rank() != 3) { | ||
return absl::InvalidArgumentError("Expected a 3D array"); | ||
} | ||
|
||
auto inline_inclusive_min = domain[0].inclusive_min(); | ||
auto inline_exclusive_max = domain[0].exclusive_max(); | ||
|
||
auto xline_inclusive_min = domain[1].inclusive_min(); | ||
auto xline_exclusive_max = domain[1].exclusive_max(); | ||
|
||
auto depth_inclusive_min = domain[2].inclusive_min(); | ||
auto depth_exclusive_max = domain[2].exclusive_max(); | ||
|
||
const int width = inline_exclusive_max - inline_inclusive_min; | ||
const int height = xline_exclusive_max - xline_inclusive_min; | ||
const int depth = depth_exclusive_max - depth_inclusive_min; | ||
|
||
std::ofstream outfile(filename, std::ios::binary); | ||
if (!outfile) { | ||
return absl::InvalidArgumentError("Could not open numpy file for writing"); | ||
} | ||
|
||
// Write the numpy header | ||
// Format described at: | ||
// https://numpy.org/doc/stable/reference/generated/numpy.lib.format.html | ||
const char magic_string[] = "\x93NUMPY\x01\x00"; // Magic string and version | ||
outfile.write(magic_string, sizeof(magic_string) - 1); // Write as byte array | ||
|
||
// Construct the header | ||
std::stringstream header; | ||
header << "{'descr': '" << type_id << "', 'fortran_order': False, 'shape': (" | ||
<< width << ", " << height << ", " << depth << ")}"; | ||
|
||
// Pad header to multiple of 64 bytes | ||
int header_len = header.str().length() + 1; // +1 for newline | ||
int pad_len = 64 - (header_len % 64); | ||
if (pad_len < 1) pad_len += 64; | ||
std::string padding(pad_len - 1, ' '); | ||
header << padding << "\n"; | ||
|
||
// Write header length and header | ||
uint16_t header_size = header.str().length(); | ||
outfile.write(reinterpret_cast<char*>(&header_size), sizeof(header_size)); | ||
outfile.write(header.str().c_str(), header_size); | ||
|
||
// Write the data | ||
for (int il = inline_inclusive_min; il < inline_exclusive_max; ++il) { | ||
for (int xl = xline_inclusive_min; xl < xline_exclusive_max; ++xl) { | ||
for (int zl = depth_inclusive_min; zl < depth_exclusive_max; ++zl) { | ||
T value = accessor(il, xl, zl); | ||
outfile.write(reinterpret_cast<const char*>(&value), sizeof(T)); | ||
} | ||
} | ||
} | ||
|
||
outfile.close(); | ||
return absl::OkStatus(); | ||
} |