diff --git a/py/server/tests/test_table_data_service.py b/py/server/tests/test_table_data_service.py index 792ff1bcdcc..596cb187c6a 100644 --- a/py/server/tests/test_table_data_service.py +++ b/py/server/tests/test_table_data_service.py @@ -53,6 +53,8 @@ def __init__(self, gen_pa_table: Generator[pa.Table, None, None], pt_schema: pa. self.partitions_size_subscriptions: Dict[TableLocationKey, bool] = {} self.existing_partitions_called: int = 0 self.partition_size_called: int = 0 + self.is_size_sub_failure_cb_called: bool = False + self.size_sub_failure_cb_called_cond: threading.Condition = threading.Condition() def table_schema(self, table_key: TableKeyImpl, schema_cb: Callable[[pa.Schema, Optional[pa.Schema]], None], @@ -159,15 +161,21 @@ def _th_partition_size_changes(self, table_key: TableKeyImpl, table_location_key return while self.subscriptions_enabled_for_test and self.partitions_size_subscriptions[table_location_key]: - pa_table = self.partitions[table_location_key] - rbs = pa_table.to_batches() - rbs.append(pa_table.to_batches()[0]) - new_pa_table = pa.Table.from_batches(rbs) - self.partitions[table_location_key] = new_pa_table - size_cb(new_pa_table.num_rows) if self.sub_partition_size_fail_test: - failure_cb(Exception("table location size subscription failure")) + # give main test thread a chance to wait on the condition + time.sleep(0.1) + with self.size_sub_failure_cb_called_cond: + failure_cb(Exception("table location size subscription failure")) + self.is_size_sub_failure_cb_called = True + self.size_sub_failure_cb_called_cond.notify() return + else: + pa_table = self.partitions[table_location_key] + rbs = pa_table.to_batches() + rbs.append(pa_table.to_batches()[0]) + new_pa_table = pa.Table.from_batches(rbs) + self.partitions[table_location_key] = new_pa_table + size_cb(new_pa_table.num_rows) time.sleep(0.1) def subscribe_to_table_location_size(self, table_key: TableKeyImpl, @@ -347,7 +355,7 @@ def test_partition_sub_failure(self): table = data_service.make_table(TableKeyImpl("test"), refreshing=True) with self.assertRaises(Exception) as cm: # failure_cb will be called in the background thread after 2 PUG cycles, 3 seconds timeout should be enough - self.wait_ticking_table_update(table, 600, 3) + self.wait_ticking_table_update(table, 1024, 3) self.assertTrue(table.is_failed) def test_partition_size_sub_failure(self): @@ -357,9 +365,23 @@ def test_partition_size_sub_failure(self): data_service = TableDataService(backend) backend.sub_partition_size_fail_test = True table = data_service.make_table(TableKeyImpl("test"), refreshing=True) + + # wait for location/size subscription to be established + self.wait_ticking_table_update(table, 2, 1) + + with backend.size_sub_failure_cb_called_cond: + # the test backend will trigger a size subscription failure + if not backend.is_size_sub_failure_cb_called: + if not backend.size_sub_failure_cb_called_cond.wait(timeout=5): + self.fail("size subscription failure callback was not called in 5s") + else: + # size subscription failure callback was already called + pass + with self.assertRaises(Exception) as cm: - # failure_cb will be called in the background thread after 2 PUG cycles, 3 seconds timeout should be enough - self.wait_ticking_table_update(table, 600, 3) + # for a real PUG with 1s interval, the failure is buffered after the roots are + # processed on one cycle, it won't be delivered until the next cycle + self.wait_ticking_table_update(table, 1024, 2) self.assertTrue(table.is_failed)