Skip to content

Commit

Permalink
When no more records will fit in the current chunk after adding the current
Browse files Browse the repository at this point in the history

record (most likely a single record exceeds the desired chunk size), write the
chunk after adding the record instead of keeping a large chunk in memory.

This required redesigning `LastPos()`, because after a successful
`WriteRecord()` the chunk encoder no longer necessarily holds the last record.

As a bonus, `LastPos()` remains valid after `Flush()` and `Close()`.

A side effect, besides writing the chunk earlier, is that `Pos()` in such a case
points after that chunk, having `record_index() == 0`, rather than at the end of
the chunk. In other words, in more cases `Pos()` will happen to be the canonical
position of the next record, although still not always.

Cosmetics: let `RecordWriterBase::WriteRecordImpl()` also cover
`WriteRecord(MessageLite)` by making the parameters variadic. This reduces
source code duplication.

Cosmetics: drop the overflow protection for the case where a chunk in memory
would exceed the `uint64_t` range.
PiperOrigin-RevId: 604212432
  • Loading branch information
QrczakMK committed Feb 5, 2024
1 parent 0800c8f commit e57d97d
Show file tree
Hide file tree
Showing 4 changed files with 82 additions and 60 deletions.
3 changes: 1 addition & 2 deletions python/riegeli/records/record_writer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -716,8 +716,7 @@ Seeking to any equivalent position leads to reading the same record.
last_pos.numeric returns the position as an int.
Precondition:
a record was successfully written and there was no intervening call to
close() or flush().
a record was successfully written
)doc"),
nullptr},
{const_cast<char*>("pos"), reinterpret_cast<getter>(RecordWriterPos),
Expand Down
2 changes: 2 additions & 0 deletions python/riegeli/records/tests/records_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,6 +337,7 @@ def test_write_read_record(self, file_spec, random_access, parallelism):
positions.append(canonical_pos)
writer.close()
end_pos = writer.pos
self.assertEqual(writer.last_pos, positions[-1])
with riegeli.RecordReader(
files.reading_open(),
owns_src=files.reading_should_close,
Expand Down Expand Up @@ -375,6 +376,7 @@ def test_write_read_message(self, file_spec, random_access, parallelism):
positions.append(canonical_pos)
writer.close()
end_pos = writer.pos
self.assertEqual(writer.last_pos, positions[-1])
with riegeli.RecordReader(
files.reading_open(),
owns_src=files.reading_should_close,
Expand Down
113 changes: 62 additions & 51 deletions riegeli/records/record_writer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -811,31 +811,31 @@ void RecordWriterBase::Reset(Closed) {
Object::Reset(kClosed);
desired_chunk_size_ = 0;
chunk_size_so_far_ = 0;
last_record_is_valid_ = false;
last_record_ = LastRecordIsInvalid();
worker_.reset();
}

void RecordWriterBase::Reset() {
Object::Reset();
desired_chunk_size_ = 0;
chunk_size_so_far_ = 0;
last_record_is_valid_ = false;
last_record_ = LastRecordIsInvalid();
worker_.reset();
}

RecordWriterBase::RecordWriterBase(RecordWriterBase&& that) noexcept
: Object(static_cast<Object&&>(that)),
desired_chunk_size_(that.desired_chunk_size_),
chunk_size_so_far_(that.chunk_size_so_far_),
last_record_is_valid_(std::exchange(that.last_record_is_valid_, false)),
last_record_(std::exchange(that.last_record_, LastRecordIsInvalid())),
worker_(std::move(that.worker_)) {}

RecordWriterBase& RecordWriterBase::operator=(
RecordWriterBase&& that) noexcept {
Object::operator=(static_cast<Object&&>(that));
desired_chunk_size_ = that.desired_chunk_size_;
chunk_size_so_far_ = that.chunk_size_so_far_;
last_record_is_valid_ = std::exchange(that.last_record_is_valid_, false);
last_record_ = std::exchange(that.last_record_, LastRecordIsInvalid());
worker_ = std::move(that.worker_);
return *this;
}
Expand Down Expand Up @@ -868,8 +868,10 @@ void RecordWriterBase::Done() {
"null worker_ but RecordWriterBase is_open()";
return;
}
last_record_is_valid_ = false;
if (chunk_size_so_far_ != 0) {
if (chunk_size_so_far_ > 0) {
if (absl::holds_alternative<LastRecordIsValid>(last_record_)) {
last_record_ = LastRecordIsValidAt{worker_->LastPos()};
}
if (ABSL_PREDICT_FALSE(!worker_->CloseChunk())) {
FailWithoutAnnotation(worker_->status());
}
Expand Down Expand Up @@ -902,73 +904,58 @@ absl::Status RecordWriterBase::AnnotateOverDest(absl::Status status) {
bool RecordWriterBase::WriteRecord(const google::protobuf::MessageLite& record,
SerializeOptions serialize_options) {
if (ABSL_PREDICT_FALSE(!ok())) return false;
last_record_is_valid_ = false;
// Decoding a chunk writes records to one array, and their positions to
// another array. We limit the size of both arrays together, to include
// attempts to accumulate an unbounded number of empty records.
const size_t size = serialize_options.GetByteSize(record);
const uint64_t added_size =
SaturatingAdd(IntCast<uint64_t>(size), uint64_t{sizeof(uint64_t)});
if (ABSL_PREDICT_FALSE(chunk_size_so_far_ > desired_chunk_size_ ||
added_size >
desired_chunk_size_ - chunk_size_so_far_) &&
chunk_size_so_far_ > 0) {
if (ABSL_PREDICT_FALSE(!worker_->CloseChunk())) {
return FailWithoutAnnotation(worker_->status());
}
worker_->OpenChunk();
chunk_size_so_far_ = 0;
}
chunk_size_so_far_ += added_size;
if (ABSL_PREDICT_FALSE(!worker_->AddRecord(record, serialize_options))) {
return FailWithoutAnnotation(worker_->status());
}
last_record_is_valid_ = true;
return true;
return WriteRecordImpl(size, record, std::move(serialize_options));
}

bool RecordWriterBase::WriteRecord(absl::string_view record) {
return WriteRecordImpl(record);
if (ABSL_PREDICT_FALSE(!ok())) return false;
return WriteRecordImpl(record.size(), record);
}

template <typename Src,
std::enable_if_t<std::is_same<Src, std::string>::value, int>>
bool RecordWriterBase::WriteRecord(Src&& record) {
if (ABSL_PREDICT_FALSE(!ok())) return false;
const size_t size = record.size();
// `std::move(record)` is correct and `std::forward<Src>(record)` is not
// necessary: `Src` is always `std::string`, never an lvalue reference.
return WriteRecordImpl(std::move(record));
return WriteRecordImpl(size, std::move(record));
}

template bool RecordWriterBase::WriteRecord(std::string&& record);

bool RecordWriterBase::WriteRecord(const Chain& record) {
return WriteRecordImpl(record);
if (ABSL_PREDICT_FALSE(!ok())) return false;
return WriteRecordImpl(record.size(), record);
}

bool RecordWriterBase::WriteRecord(Chain&& record) {
return WriteRecordImpl(std::move(record));
if (ABSL_PREDICT_FALSE(!ok())) return false;
const size_t size = record.size();
return WriteRecordImpl(size, std::move(record));
}

bool RecordWriterBase::WriteRecord(const absl::Cord& record) {
return WriteRecordImpl(record);
if (ABSL_PREDICT_FALSE(!ok())) return false;
return WriteRecordImpl(record.size(), record);
}

bool RecordWriterBase::WriteRecord(absl::Cord&& record) {
return WriteRecordImpl(std::move(record));
if (ABSL_PREDICT_FALSE(!ok())) return false;
const size_t size = record.size();
return WriteRecordImpl(size, std::move(record));
}

template <typename Record>
inline bool RecordWriterBase::WriteRecordImpl(Record&& record) {
if (ABSL_PREDICT_FALSE(!ok())) return false;
last_record_is_valid_ = false;
template <typename... Args>
inline bool RecordWriterBase::WriteRecordImpl(size_t size, Args&&... args) {
last_record_ = LastRecordIsInvalid();
// Decoding a chunk writes records to one array, and their positions to
// another array. We limit the size of both arrays together, to include
// attempts to accumulate an unbounded number of empty records.
const uint64_t added_size = SaturatingAdd(IntCast<uint64_t>(record.size()),
uint64_t{sizeof(uint64_t)});
if (ABSL_PREDICT_FALSE(chunk_size_so_far_ > desired_chunk_size_ ||
added_size >
desired_chunk_size_ - chunk_size_so_far_) &&
const uint64_t added_size = uint64_t{size} + uint64_t{sizeof(uint64_t)};
if (ABSL_PREDICT_FALSE(chunk_size_so_far_ + added_size >
desired_chunk_size_) &&
chunk_size_so_far_ > 0) {
if (ABSL_PREDICT_FALSE(!worker_->CloseChunk())) {
return FailWithoutAnnotation(worker_->status());
Expand All @@ -977,17 +964,32 @@ inline bool RecordWriterBase::WriteRecordImpl(Record&& record) {
chunk_size_so_far_ = 0;
}
chunk_size_so_far_ += added_size;
if (ABSL_PREDICT_FALSE(!worker_->AddRecord(std::forward<Record>(record)))) {
if (ABSL_PREDICT_FALSE(!worker_->AddRecord(std::forward<Args>(args)...))) {
return FailWithoutAnnotation(worker_->status());
}
last_record_is_valid_ = true;
if (ABSL_PREDICT_FALSE(chunk_size_so_far_ + uint64_t{sizeof(uint64_t)} >
desired_chunk_size_)) {
// No more records will fit in this chunk, most likely a single record
// exceeds the desired chunk size. Write the chunk now to avoid keeping a
// large chunk in memory.
last_record_ = LastRecordIsValidAt{worker_->LastPos()};
if (ABSL_PREDICT_FALSE(!worker_->CloseChunk())) {
return FailWithoutAnnotation(worker_->status());
}
worker_->OpenChunk();
chunk_size_so_far_ = 0;
return true;
}
last_record_ = LastRecordIsValid();
return true;
}

bool RecordWriterBase::Flush(FlushType flush_type) {
if (ABSL_PREDICT_FALSE(!ok())) return false;
last_record_is_valid_ = false;
if (chunk_size_so_far_ != 0) {
if (chunk_size_so_far_ > 0) {
if (absl::holds_alternative<LastRecordIsValid>(last_record_)) {
last_record_ = LastRecordIsValidAt{worker_->LastPos()};
}
if (ABSL_PREDICT_FALSE(!worker_->CloseChunk())) {
return FailWithoutAnnotation(worker_->status());
}
Expand All @@ -1000,7 +1002,7 @@ bool RecordWriterBase::Flush(FlushType flush_type) {
return FailWithoutAnnotation(worker_->status());
}
}
if (chunk_size_so_far_ != 0) {
if (chunk_size_so_far_ > 0) {
worker_->OpenChunk();
chunk_size_so_far_ = 0;
}
Expand All @@ -1014,8 +1016,10 @@ RecordWriterBase::FutureStatus RecordWriterBase::FutureFlush(
promise.set_value(status());
return promise.get_future();
}
last_record_is_valid_ = false;
if (chunk_size_so_far_ != 0) {
if (chunk_size_so_far_ > 0) {
if (absl::holds_alternative<LastRecordIsValid>(last_record_)) {
last_record_ = LastRecordIsValidAt{worker_->LastPos()};
}
if (ABSL_PREDICT_FALSE(!worker_->CloseChunk())) {
FailWithoutAnnotation(worker_->status());
std::promise<absl::Status> promise;
Expand All @@ -1037,7 +1041,7 @@ RecordWriterBase::FutureStatus RecordWriterBase::FutureFlush(
} else {
result = worker_->FutureFlush(flush_type);
}
if (chunk_size_so_far_ != 0) {
if (chunk_size_so_far_ > 0) {
worker_->OpenChunk();
chunk_size_so_far_ = 0;
}
Expand All @@ -1048,6 +1052,13 @@ FutureRecordPosition RecordWriterBase::LastPos() const {
RIEGELI_ASSERT(last_record_is_valid())
<< "Failed precondition of RecordWriterBase::LastPos(): "
"no record was recently written";
{
const LastRecordIsValidAt* const last_record_at_pos =
absl::get_if<LastRecordIsValidAt>(&last_record_);
if (last_record_at_pos != nullptr) {
return last_record_at_pos->pos;
}
}
RIEGELI_ASSERT(worker_ != nullptr)
<< "Failed invariant of RecordWriterBase: "
"last position should be valid but worker is null";
Expand Down
24 changes: 17 additions & 7 deletions riegeli/records/record_writer.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#ifndef RIEGELI_RECORDS_RECORD_WRITER_H_
#define RIEGELI_RECORDS_RECORD_WRITER_H_

#include <stddef.h>
#include <stdint.h>

#include <future>
Expand All @@ -30,6 +31,7 @@
#include "absl/strings/cord.h"
#include "absl/strings/string_view.h"
#include "absl/types/optional.h"
#include "absl/types/variant.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/message_lite.h"
#include "riegeli/base/assert.h"
Expand Down Expand Up @@ -529,13 +531,14 @@ class RecordWriterBase : public Object {
// `LastPos().get().numeric()` returns the position as an integer of type
// `Position`.
//
// Precondition: a record was successfully written and there was no
// intervening call to `Close()`, `Flush()` or `FutureFlush()` (this can be
// checked with `last_record_is_valid()`).
// Precondition: a record was successfully written (this can be checked with
// `last_record_is_valid()`).
FutureRecordPosition LastPos() const;

// Returns `true` if calling `LastPos()` is valid.
bool last_record_is_valid() const { return last_record_is_valid_; }
bool last_record_is_valid() const {
return !absl::holds_alternative<LastRecordIsInvalid>(last_record_);
}

// Returns a position of the next record (or the end of file if there is no
// next record).
Expand Down Expand Up @@ -591,12 +594,19 @@ class RecordWriterBase : public Object {
class SerialWorker;
class Worker;

template <typename Record>
bool WriteRecordImpl(Record&& record);
struct LastRecordIsInvalid {};
struct LastRecordIsValid {}; // At one record before `Pos()`.
struct LastRecordIsValidAt {
FutureRecordPosition pos;
};

template <typename... Args>
bool WriteRecordImpl(size_t size, Args&&... args);

uint64_t desired_chunk_size_ = 0;
uint64_t chunk_size_so_far_ = 0;
bool last_record_is_valid_ = false;
absl::variant<LastRecordIsInvalid, LastRecordIsValid, LastRecordIsValidAt>
last_record_ = LastRecordIsInvalid();
// Invariant: if `is_open()` then `worker_ != nullptr`.
std::unique_ptr<Worker> worker_;
};
Expand Down

0 comments on commit e57d97d

Please sign in to comment.