From 16617f29a3d7362091336825421dd5f771fb8bc2 Mon Sep 17 00:00:00 2001 From: Marcin Kowalczyk Date: Tue, 14 Jan 2025 10:42:28 +0100 Subject: [PATCH] Fix wrapping `Reader`s and `Writer`s when the underlying `Reader` or `Writer` fails. * Ensure that buffer pointers of the wrapper stay related to the underlying object even in the case of failure. This avoids `SyncBuffer()` modifying `cursor()`, which would no longer be related to `start()` and `limit()`. * Let `Fail()` set the failure right away and then annotate it, instead of setting an annotated failure. This avoids an infinite recursion if `AnnotateStatusImpl()` calls `Fail()`. * Fix `PositionShifting{Reader,Writer}::MakeBuffer()` near the end of the `Position` range: - Ensure that buffer pointers stay related to the underlying object even if the position overflows. - Shrink the buffer to account for the remaining `Position` space if that is sufficient to avoid position overflow. - Fail if shrinking the buffer would make `available() < min_length`. - Propagate failure from `MakeBuffer()` to its callers. Cosmetics: * Let `PositionShiftingReader` expose the underlying buffer from its `start()` rather than `cursor()`. This makes seeking within the buffer more efficient. * Let `LimitingReader` ensure that its buffer pointers always stay related to the underlying `Reader`. This avoids a conditional in `SyncBuffer()`. 
PiperOrigin-RevId: 715284599 --- riegeli/base/object.cc | 9 +++-- riegeli/bytes/backward_writer.cc | 2 +- riegeli/bytes/limiting_reader.h | 17 +++++---- riegeli/bytes/position_shifting_reader.cc | 26 ++++++-------- riegeli/bytes/position_shifting_reader.h | 42 ++++++++++++++--------- riegeli/bytes/position_shifting_writer.cc | 22 ++++++------ riegeli/bytes/position_shifting_writer.h | 33 +++++++++++------- riegeli/bytes/writer.cc | 2 +- 8 files changed, 83 insertions(+), 70 deletions(-) diff --git a/riegeli/base/object.cc b/riegeli/base/object.cc index d9c72628..0800b679 100644 --- a/riegeli/base/object.cc +++ b/riegeli/base/object.cc @@ -65,9 +65,11 @@ void Object::Done() {} bool Object::Fail(absl::Status status) { RIEGELI_ASSERT(!status.ok()) << "Failed precondition of Object::Fail(): status not failed"; - status = AnnotateStatus(std::move(status)); + if (ABSL_PREDICT_FALSE(!not_failed())) return false; + state_.Fail(std::move(status)); + state_.SetStatus(AnnotateStatus(state_.status())); OnFail(); - return state_.Fail(std::move(status)); + return false; } void Object::SetStatus(absl::Status status) { @@ -87,8 +89,9 @@ bool Object::FailWithoutAnnotation(absl::Status status) { << "Failed precondition of Object::FailWithoutAnnotation(): " "status not failed"; if (ABSL_PREDICT_FALSE(!not_failed())) return false; + state_.Fail(std::move(status)); OnFail(); - return state_.Fail(std::move(status)); + return false; } absl::Status Object::StatusOrAnnotate(absl::Status other_status) { diff --git a/riegeli/bytes/backward_writer.cc b/riegeli/bytes/backward_writer.cc index 6eada754..243f39c8 100644 --- a/riegeli/bytes/backward_writer.cc +++ b/riegeli/bytes/backward_writer.cc @@ -40,7 +40,7 @@ namespace riegeli { -void BackwardWriter::OnFail() { set_buffer(); } +void BackwardWriter::OnFail() { set_buffer(start()); } absl::Status BackwardWriter::AnnotateStatusImpl(absl::Status status) { if (is_open()) return Annotate(status, absl::StrCat("at byte ", pos())); diff --git 
a/riegeli/bytes/limiting_reader.h b/riegeli/bytes/limiting_reader.h index 62d901b3..6b1104cf 100644 --- a/riegeli/bytes/limiting_reader.h +++ b/riegeli/bytes/limiting_reader.h @@ -279,9 +279,9 @@ class LimitingReaderBase : public Reader { bool fail_if_longer_ = false; // Invariants if `is_open()`: - // `start() == SrcReader()->start() || start() == nullptr` - // `limit() <= SrcReader()->limit() || limit() == nullptr` - // `start_pos() == SrcReader()->start_pos() || start() == nullptr` + // `start() >= SrcReader()->start()` + // `limit() <= SrcReader()->limit()` + // `start_pos() >= SrcReader()->start_pos()` // `limit_pos() <= max_pos_` }; @@ -460,11 +460,10 @@ inline void LimitingReaderBase::Initialize(Reader* src, Options&& options) { inline void LimitingReaderBase::set_max_pos(Position max_pos) { max_pos_ = max_pos; - if (limit_pos() > max_pos_) { + if (ABSL_PREDICT_FALSE(limit_pos() > max_pos_)) { if (ABSL_PREDICT_FALSE(pos() > max_pos_)) { - set_buffer(); + set_buffer(cursor()); set_limit_pos(max_pos_); - CheckEnough(); return; } set_buffer(start(), @@ -498,15 +497,15 @@ inline Position LimitingReaderBase::max_length() const { } inline void LimitingReaderBase::SyncBuffer(Reader& src) { - if (ABSL_PREDICT_TRUE(cursor() != nullptr)) src.set_cursor(cursor()); + src.set_cursor(cursor()); } inline void LimitingReaderBase::MakeBuffer(Reader& src) { set_buffer(src.start(), src.start_to_limit(), src.start_to_cursor()); set_limit_pos(src.limit_pos()); - if (limit_pos() > max_pos_) { + if (ABSL_PREDICT_FALSE(limit_pos() > max_pos_)) { if (ABSL_PREDICT_FALSE(pos() > max_pos_)) { - set_buffer(); + set_buffer(cursor()); } else { set_buffer(start(), start_to_limit() - IntCast(limit_pos() - max_pos_), diff --git a/riegeli/bytes/position_shifting_reader.cc b/riegeli/bytes/position_shifting_reader.cc index 55fba7e1..3cf69c15 100644 --- a/riegeli/bytes/position_shifting_reader.cc +++ b/riegeli/bytes/position_shifting_reader.cc @@ -86,8 +86,7 @@ bool 
PositionShiftingReaderBase::PullSlow(size_t min_length, Reader& src = *SrcReader(); SyncBuffer(src); const bool pull_ok = src.Pull(min_length, recommended_length); - MakeBuffer(src); - return pull_ok; + return MakeBuffer(src, min_length) && pull_ok; } bool PositionShiftingReaderBase::ReadSlow(size_t length, char* dest) { @@ -98,8 +97,7 @@ bool PositionShiftingReaderBase::ReadSlow(size_t length, char* dest) { Reader& src = *SrcReader(); SyncBuffer(src); const bool read_ok = src.Read(length, dest); - MakeBuffer(src); - return read_ok; + return MakeBuffer(src) && read_ok; } bool PositionShiftingReaderBase::ReadSlow(size_t length, Chain& dest) { @@ -129,8 +127,7 @@ inline bool PositionShiftingReaderBase::ReadInternal(size_t length, Reader& src = *SrcReader(); SyncBuffer(src); const bool read_ok = src.ReadAndAppend(length, dest); - MakeBuffer(src); - return read_ok; + return MakeBuffer(src) && read_ok; } bool PositionShiftingReaderBase::CopySlow(Position length, Writer& dest) { @@ -141,8 +138,7 @@ bool PositionShiftingReaderBase::CopySlow(Position length, Writer& dest) { Reader& src = *SrcReader(); SyncBuffer(src); const bool copy_ok = src.Copy(length, dest); - MakeBuffer(src); - return copy_ok; + return MakeBuffer(src) && copy_ok; } bool PositionShiftingReaderBase::CopySlow(size_t length, BackwardWriter& dest) { @@ -153,8 +149,7 @@ bool PositionShiftingReaderBase::CopySlow(size_t length, BackwardWriter& dest) { Reader& src = *SrcReader(); SyncBuffer(src); const bool copy_ok = src.Copy(length, dest); - MakeBuffer(src); - return copy_ok; + return MakeBuffer(src) && copy_ok; } bool PositionShiftingReaderBase::ReadOrPullSomeSlow( @@ -169,8 +164,7 @@ bool PositionShiftingReaderBase::ReadOrPullSomeSlow( Reader& src = *SrcReader(); SyncBuffer(src); const bool read_ok = src.ReadOrPullSome(max_length, get_dest); - MakeBuffer(src); - return read_ok; + return MakeBuffer(src) && read_ok; } void PositionShiftingReaderBase::ReadHintSlow(size_t min_length, @@ -211,8 +205,7 @@ bool 
PositionShiftingReaderBase::SeekSlow(Position new_pos) { Reader& src = *SrcReader(); SyncBuffer(src); const bool seek_ok = src.Seek(new_pos - base_pos_); - MakeBuffer(src); - return seek_ok; + return MakeBuffer(src) && seek_ok; } bool PositionShiftingReaderBase::SupportsSize() { @@ -225,8 +218,9 @@ absl::optional PositionShiftingReaderBase::SizeImpl() { Reader& src = *SrcReader(); SyncBuffer(src); const absl::optional size = src.Size(); - MakeBuffer(src); - if (ABSL_PREDICT_FALSE(size == absl::nullopt)) return absl::nullopt; + if (ABSL_PREDICT_FALSE(!MakeBuffer(src) || size == absl::nullopt)) { + return absl::nullopt; + } if (ABSL_PREDICT_FALSE(*size > std::numeric_limits::max() - base_pos_)) { FailOverflow(); diff --git a/riegeli/bytes/position_shifting_reader.h b/riegeli/bytes/position_shifting_reader.h index 52b13abc..6c3dadde 100644 --- a/riegeli/bytes/position_shifting_reader.h +++ b/riegeli/bytes/position_shifting_reader.h @@ -27,6 +27,7 @@ #include "absl/status/status.h" #include "absl/strings/cord.h" #include "absl/types/optional.h" +#include "riegeli/base/arithmetic.h" #include "riegeli/base/assert.h" #include "riegeli/base/chain.h" #include "riegeli/base/dependency.h" @@ -93,9 +94,9 @@ class PositionShiftingReaderBase : public Reader { // Sets cursor of `src` to cursor of `*this`. void SyncBuffer(Reader& src); - // Sets buffer pointers of `*this` to buffer pointers of `src`, adjusting - // `start()` to hide data already read. Fails `*this` if `src` failed. - void MakeBuffer(Reader& src); + // Sets buffer pointers of `*this` to buffer pointers of `src`. Fails `*this` + // if `src` failed or there is not enough `Position` space for `min_length`. 
+ bool MakeBuffer(Reader& src, size_t min_length = 0); void Done() override; ABSL_ATTRIBUTE_COLD absl::Status AnnotateStatusImpl( @@ -125,10 +126,10 @@ class PositionShiftingReaderBase : public Reader { Position base_pos_ = 0; - // Invariants if `is_open()`: - // `start() >= SrcReader()->cursor()` + // Invariants if `ok()`: + // `start() == SrcReader()->start()` // `limit() == SrcReader()->limit()` - // `limit_pos() == SrcReader()->limit_pos() + base_pos_` + // `start_pos() == SrcReader()->start_pos() + base_pos_` }; // A `Reader` which reads from another `Reader`, reporting positions shifted so @@ -240,17 +241,27 @@ inline void PositionShiftingReaderBase::SyncBuffer(Reader& src) { src.set_cursor(cursor()); } -inline void PositionShiftingReaderBase::MakeBuffer(Reader& src) { - if (ABSL_PREDICT_FALSE(src.limit_pos() > - std::numeric_limits::max() - base_pos_)) { - FailOverflow(); - return; +inline bool PositionShiftingReaderBase::MakeBuffer(Reader& src, + size_t min_length) { + const Position max_pos = std::numeric_limits::max() - base_pos_; + if (ABSL_PREDICT_FALSE(src.limit_pos() > max_pos)) { + if (ABSL_PREDICT_FALSE(src.pos() > max_pos)) { + set_buffer(src.cursor()); + set_limit_pos(std::numeric_limits::max()); + return FailOverflow(); + } + set_buffer(src.start(), IntCast(max_pos - src.start_pos()), + src.start_to_cursor()); + set_limit_pos(std::numeric_limits::max()); + if (ABSL_PREDICT_FALSE(available() < min_length)) return FailOverflow(); + } else { + set_buffer(src.start(), src.start_to_limit(), src.start_to_cursor()); + set_limit_pos(src.limit_pos() + base_pos_); } - set_buffer(src.cursor(), src.available()); - set_limit_pos(src.limit_pos() + base_pos_); if (ABSL_PREDICT_FALSE(!src.ok())) { - FailWithoutAnnotation(AnnotateOverSrc(src.status())); + return FailWithoutAnnotation(AnnotateOverSrc(src.status())); } + return true; } template @@ -332,8 +343,7 @@ bool PositionShiftingReader::SyncImpl(SyncType sync_type) { if (sync_type != SyncType::kFromObject || 
src_.IsOwning()) { sync_ok = src_->Sync(sync_type); } - MakeBuffer(*src_); - return sync_ok; + return MakeBuffer(*src_) && sync_ok; } } // namespace riegeli diff --git a/riegeli/bytes/position_shifting_writer.cc b/riegeli/bytes/position_shifting_writer.cc index 00fa529d..47c8bd2a 100644 --- a/riegeli/bytes/position_shifting_writer.cc +++ b/riegeli/bytes/position_shifting_writer.cc @@ -88,8 +88,7 @@ bool PositionShiftingWriterBase::PushSlow(size_t min_length, Writer& dest = *DestWriter(); SyncBuffer(dest); const bool push_ok = dest.Push(min_length, recommended_length); - MakeBuffer(dest); - return push_ok; + return MakeBuffer(dest, min_length) && push_ok; } bool PositionShiftingWriterBase::WriteSlow(absl::string_view src) { @@ -147,8 +146,7 @@ inline bool PositionShiftingWriterBase::WriteInternal(Src&& src) { Writer& dest = *DestWriter(); SyncBuffer(dest); const bool write_ok = dest.Write(std::forward(src)); - MakeBuffer(dest); - return write_ok; + return MakeBuffer(dest) && write_ok; } bool PositionShiftingWriterBase::SupportsRandomAccess() { @@ -167,8 +165,7 @@ bool PositionShiftingWriterBase::SeekSlow(Position new_pos) { Writer& dest = *DestWriter(); SyncBuffer(dest); const bool seek_ok = dest.Seek(new_pos - base_pos_); - MakeBuffer(dest); - return seek_ok; + return MakeBuffer(dest) && seek_ok; } absl::optional PositionShiftingWriterBase::SizeImpl() { @@ -176,8 +173,9 @@ absl::optional PositionShiftingWriterBase::SizeImpl() { Writer& dest = *DestWriter(); SyncBuffer(dest); const absl::optional size = dest.Size(); - MakeBuffer(dest); - if (ABSL_PREDICT_FALSE(size == absl::nullopt)) return absl::nullopt; + if (ABSL_PREDICT_FALSE(!MakeBuffer(dest) || size == absl::nullopt)) { + return absl::nullopt; + } if (ABSL_PREDICT_FALSE(*size > std::numeric_limits::max() - base_pos_)) { FailOverflow(); @@ -199,8 +197,7 @@ bool PositionShiftingWriterBase::TruncateImpl(Position new_size) { Writer& dest = *DestWriter(); SyncBuffer(dest); const bool truncate_ok = 
dest.Truncate(new_size - base_pos_); - MakeBuffer(dest); - return truncate_ok; + return MakeBuffer(dest) && truncate_ok; } bool PositionShiftingWriterBase::SupportsReadMode() { @@ -214,8 +211,9 @@ Reader* PositionShiftingWriterBase::ReadModeImpl(Position initial_pos) { SyncBuffer(dest); Reader* const base_reader = dest.ReadMode(SaturatingSub(initial_pos, base_pos_)); - MakeBuffer(dest); - if (ABSL_PREDICT_FALSE(base_reader == nullptr)) return nullptr; + if (ABSL_PREDICT_FALSE(!MakeBuffer(dest) || base_reader == nullptr)) { + return nullptr; + } PositionShiftingReader<>* const reader = associated_reader_.ResetReader( base_reader, PositionShiftingReaderBase::Options().set_base_pos(base_pos_)); diff --git a/riegeli/bytes/position_shifting_writer.h b/riegeli/bytes/position_shifting_writer.h index b47e4573..ad878be9 100644 --- a/riegeli/bytes/position_shifting_writer.h +++ b/riegeli/bytes/position_shifting_writer.h @@ -95,8 +95,9 @@ class PositionShiftingWriterBase : public Writer { void SyncBuffer(Writer& dest); // Sets buffer pointers of `*this` to buffer pointers of `dest`, adjusting - // `start()` to hide data already written. Fails `*this` if `dest` failed. - void MakeBuffer(Writer& dest); + // `start()` to hide data already written. Fails `*this` if `dest` failed + // or there is not enough `Position` space for `min_length`. 
+ bool MakeBuffer(Writer& dest, size_t min_length = 0); void Done() override; ABSL_ATTRIBUTE_COLD absl::Status AnnotateStatusImpl( @@ -246,17 +247,26 @@ inline void PositionShiftingWriterBase::SyncBuffer(Writer& dest) { dest.set_cursor(cursor()); } -inline void PositionShiftingWriterBase::MakeBuffer(Writer& dest) { - if (ABSL_PREDICT_FALSE(dest.pos() > - std::numeric_limits::max() - base_pos_)) { - FailOverflow(); - return; +inline bool PositionShiftingWriterBase::MakeBuffer(Writer& dest, + size_t min_length) { + const Position max_pos = std::numeric_limits::max() - base_pos_; + if (ABSL_PREDICT_FALSE(dest.limit_pos() > max_pos)) { + if (ABSL_PREDICT_FALSE(dest.pos() > max_pos)) { + set_buffer(dest.cursor()); + set_start_pos(std::numeric_limits::max()); + return FailOverflow(); + } + set_buffer(dest.cursor(), IntCast(max_pos - dest.pos())); + set_start_pos(dest.pos() + base_pos_); + if (ABSL_PREDICT_FALSE(available() < min_length)) return FailOverflow(); + } else { + set_buffer(dest.cursor(), dest.available()); + set_start_pos(dest.pos() + base_pos_); } - set_buffer(dest.cursor(), dest.available()); - set_start_pos(dest.pos() + base_pos_); if (ABSL_PREDICT_FALSE(!dest.ok())) { - FailWithoutAnnotation(AnnotateOverDest(dest.status())); + return FailWithoutAnnotation(AnnotateOverDest(dest.status())); } + return true; } template @@ -331,8 +341,7 @@ bool PositionShiftingWriter::FlushImpl(FlushType flush_type) { if (flush_type != FlushType::kFromObject || dest_.IsOwning()) { flush_ok = dest_->Flush(flush_type); } - MakeBuffer(*dest_); - return flush_ok; + return MakeBuffer(*dest_) && flush_ok; } } // namespace riegeli diff --git a/riegeli/bytes/writer.cc b/riegeli/bytes/writer.cc index 013e56df..afb7236e 100644 --- a/riegeli/bytes/writer.cc +++ b/riegeli/bytes/writer.cc @@ -39,7 +39,7 @@ namespace riegeli { -void Writer::OnFail() { set_buffer(); } +void Writer::OnFail() { set_buffer(start()); } absl::Status Writer::AnnotateStatusImpl(absl::Status status) { if (is_open()) 
return Annotate(status, absl::StrCat("at byte ", pos()));