@@ -184,3 +184,32 @@ def test_avro_encode_decode():
184
184
entries = tfio .experimental .serialization .decode_avro (message , schema = schema )
185
185
assert np .all (entries ["f1" ].numpy () == f1 .numpy ())
186
186
assert np .all (entries ["f2" ].numpy () == f2 .numpy ())
187
+
188
+
189
+ def test_kafka_group_io_dataset_new_cg ():
190
+ """Test the functionality of the KafkaGroupIODataset when the consumer group
191
+ is being newly created.
192
+ """
193
+ dataset = tfio .experimental .streaming .KafkaGroupIODataset (
194
+ topics = ["key-partition-test" ],
195
+ group_id = "cgtest" ,
196
+ servers = "localhost:9092" ,
197
+ configuration = ["session.timeout.ms=7000" , "max.poll.interval.ms=8000" ],
198
+ )
199
+ assert np .all (
200
+ sorted ([k .numpy () for (k , _ ) in dataset ])
201
+ == sorted ([("D" + str (i % 10 )).encode () for i in range (10 )])
202
+ )
203
+
204
+
205
+ def test_kafka_group_io_dataset_no_lag ():
206
+ """Test the functionality of the KafkaGroupIODataset when the
207
+ consumer group has read all the messages and committed the offsets.
208
+ """
209
+ dataset = tfio .experimental .streaming .KafkaGroupIODataset (
210
+ topics = ["key-partition-test" ],
211
+ group_id = "cgtest" ,
212
+ servers = "localhost:9092" ,
213
+ configuration = ["session.timeout.ms=7000" , "max.poll.interval.ms=8000" ],
214
+ )
215
+ assert np .all (sorted ([k .numpy () for (k , _ ) in dataset ]) == [])
0 commit comments