Commit 63b6c58

fixed a bug related to retrieving offsets (#1063)
1 parent: f6df6fc

2 files changed: +37 −1 lines

tensorflow_io/core/kernels/kafka_kernels.cc

Lines changed: 8 additions & 1 deletion
@@ -772,8 +772,15 @@ class KafkaRebalanceCb : public RdKafka::RebalanceCb {
   void rebalance_cb(RdKafka::KafkaConsumer* consumer, RdKafka::ErrorCode err,
                     std::vector<RdKafka::TopicPartition*>& partitions) {
     LOG(ERROR) << "REBALANCE: " << RdKafka::err2str(err);
+    LOG(ERROR) << "Retrieved committed offsets with status code: "
+               << consumer->committed(partitions, 5000);
+
     for (int partition = 0; partition < partitions.size(); partition++) {
-      partitions[partition]->set_offset(RdKafka::Topic::OFFSET_STORED);
+      if (partitions[partition]->offset() == -1001) {
+        LOG(INFO)
+            << "The consumer group was newly created, reading from beginning";
+        partitions[partition]->set_offset(RdKafka::Topic::OFFSET_BEGINNING);
+      }
       LOG(INFO) << "REBALANCE: " << partitions[partition]->topic() << "["
                 << partitions[partition]->partition() << "], "
                 << partitions[partition]->offset() << " "
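
For context, the -1001 sentinel checked above corresponds to librdkafka's RdKafka::Topic::OFFSET_INVALID, which consumer->committed() reports for a partition that has no committed offset yet (for example, a freshly created consumer group). Below is a minimal standalone sketch of the same fallback pattern, assuming only the rdkafkacpp headers; the ResolveStartOffsets helper name is illustrative and not part of this commit.

// Illustrative sketch (not part of the commit), assuming librdkafka's C++ API.
#include <vector>

#include <librdkafka/rdkafkacpp.h>

// Fetches the committed offsets for the assigned partitions, then falls back
// to reading from the beginning for any partition that has no committed
// offset (RdKafka::Topic::OFFSET_INVALID, i.e. -1001), as happens when the
// consumer group was just created.
void ResolveStartOffsets(RdKafka::KafkaConsumer* consumer,
                         std::vector<RdKafka::TopicPartition*>& partitions) {
  consumer->committed(partitions, /*timeout_ms=*/5000);
  for (RdKafka::TopicPartition* tp : partitions) {
    if (tp->offset() == RdKafka::Topic::OFFSET_INVALID) {
      tp->set_offset(RdKafka::Topic::OFFSET_BEGINNING);
    }
  }
}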

tests/test_kafka_eager.py

Lines changed: 29 additions & 0 deletions
@@ -184,3 +184,32 @@ def test_avro_encode_decode():
     entries = tfio.experimental.serialization.decode_avro(message, schema=schema)
     assert np.all(entries["f1"].numpy() == f1.numpy())
     assert np.all(entries["f2"].numpy() == f2.numpy())
+
+
+def test_kafka_group_io_dataset_new_cg():
+    """Test the functionality of the KafkaGroupIODataset when the consumer group
+    is being newly created.
+    """
+    dataset = tfio.experimental.streaming.KafkaGroupIODataset(
+        topics=["key-partition-test"],
+        group_id="cgtest",
+        servers="localhost:9092",
+        configuration=["session.timeout.ms=7000", "max.poll.interval.ms=8000"],
+    )
+    assert np.all(
+        sorted([k.numpy() for (k, _) in dataset])
+        == sorted([("D" + str(i % 10)).encode() for i in range(10)])
+    )
+
+
+def test_kafka_group_io_dataset_no_lag():
+    """Test the functionality of the KafkaGroupIODataset when the
+    consumer group has read all the messages and committed the offsets.
+    """
+    dataset = tfio.experimental.streaming.KafkaGroupIODataset(
+        topics=["key-partition-test"],
+        group_id="cgtest",
+        servers="localhost:9092",
+        configuration=["session.timeout.ms=7000", "max.poll.interval.ms=8000"],
+    )
+    assert np.all(sorted([k.numpy() for (k, _) in dataset]) == [])
