@@ -1461,6 +1461,7 @@ def test_add_batch(self) -> None:
1461
1461
last_batch_received = {"step3" : None },
1462
1462
last_batch_sent = {"step3" : None },
1463
1463
last_batch_flag_sent_to = [],
1464
+ received_batch_seq_nos = {},
1464
1465
)
1465
1466
1466
1467
batch_from_step_1 = _Batch (
@@ -1505,6 +1506,7 @@ def test_step_hash_finished(self) -> None:
1505
1506
},
1506
1507
last_batch_sent = {"step1" : None , "step2" : None , "step3" : None },
1507
1508
last_batch_flag_sent_to = ["step2" ],
1509
+ received_batch_seq_nos = {},
1508
1510
)
1509
1511
1510
1512
assert batch_manager .step_has_finished ("step1" ) is True
@@ -1533,6 +1535,7 @@ def test_add_batch_with_prepend(self) -> None:
1533
1535
last_batch_received = {"step3" : None },
1534
1536
last_batch_sent = {"step3" : None },
1535
1537
last_batch_flag_sent_to = [],
1538
+ received_batch_seq_nos = {},
1536
1539
)
1537
1540
batch_0 = _Batch (
1538
1541
seq_no = 0 ,
@@ -1562,6 +1565,7 @@ def test_add_batch_to_recover_offline_batch_generation(self) -> None:
1562
1565
},
1563
1566
last_batch_sent = {"step1" : None },
1564
1567
last_batch_flag_sent_to = [],
1568
+ received_batch_seq_nos = {},
1565
1569
)
1566
1570
1567
1571
batch_manager .add_batch_to_recover_offline_batch_generation (
@@ -1675,17 +1679,6 @@ def test_cache(self, dummy_batch_manager: _BatchManager) -> None:
1675
1679
)
1676
1680
assert batch_path .exists () and batch_path .is_file ()
1677
1681
1678
- # for buffered_step_name in step.data:
1679
- # buffered_step_dir = batch_manager_step_dir / buffered_step_name
1680
- # assert buffered_step_dir.exists() and buffered_step_dir.is_dir()
1681
-
1682
- # for batch in step.data[buffered_step_name]:
1683
- # batch_path = (
1684
- # buffered_step_dir
1685
- # / f"batch_{batch.seq_no}_{batch.data_hash}.json"
1686
- # )
1687
- # assert batch_path.exists() and batch_path.is_file()
1688
-
1689
1682
def test_load_from_cache (
1690
1683
self , dummy_dag : DAG , dummy_batch_manager : _BatchManager
1691
1684
) -> None :
@@ -1712,10 +1705,12 @@ def test_can_generate(self) -> None:
1712
1705
},
1713
1706
last_batch_sent = {"step_1" : None , "step_2" : None , "step_3" : None },
1714
1707
last_batch_flag_sent_to = [],
1708
+ received_batch_seq_nos = {"step_1" : [0 ], "step_2" : [0 ], "step_3" : [0 ]},
1715
1709
)
1716
1710
1717
1711
assert batch_manager .can_generate ()
1718
1712
1713
+ def test_can_generate_last_batch (self ) -> None :
1719
1714
batch_1 = _Batch (seq_no = 0 , step_name = "step_1" , last_batch = True )
1720
1715
batch_2 = _Batch (seq_no = 0 , step_name = "step_2" , last_batch = True )
1721
1716
batch_3 = _Batch (seq_no = 0 , step_name = "step_3" , last_batch = True )
@@ -1729,10 +1724,30 @@ def test_can_generate(self) -> None:
1729
1724
},
1730
1725
last_batch_sent = {"step_1" : batch_1 , "step_2" : batch_2 , "step_3" : batch_3 },
1731
1726
last_batch_flag_sent_to = [],
1727
+ received_batch_seq_nos = {"step_1" : [0 ], "step_2" : [0 ], "step_3" : [0 ]},
1732
1728
)
1733
1729
1734
1730
assert not batch_manager .can_generate ()
1735
1731
1732
+ def test_can_generate_last_batch_missing_seq_no (self ) -> None :
1733
+ batch_1 = _Batch (seq_no = 0 , step_name = "step_1" , last_batch = True )
1734
+ batch_2 = _Batch (seq_no = 0 , step_name = "step_2" , last_batch = True )
1735
+ batch_3 = _Batch (seq_no = 1 , step_name = "step_3" , last_batch = True )
1736
+
1737
+ batch_manager = _BatchManager (
1738
+ steps = {},
1739
+ last_batch_received = {
1740
+ "step_1" : batch_1 ,
1741
+ "step_2" : batch_2 ,
1742
+ "step_3" : batch_3 ,
1743
+ },
1744
+ last_batch_sent = {"step_1" : batch_1 , "step_2" : batch_2 , "step_3" : batch_3 },
1745
+ last_batch_flag_sent_to = [],
1746
+ received_batch_seq_nos = {"step_1" : [0 ], "step_2" : [0 ], "step_3" : [1 ]},
1747
+ )
1748
+
1749
+ assert batch_manager .can_generate ()
1750
+
1736
1751
def test_invalidate_cache_for (self ) -> None :
1737
1752
with Pipeline () as pipeline :
1738
1753
generator = DummyGeneratorStep ()
@@ -1788,6 +1803,7 @@ def test_reset_batch_manager_for_step(self) -> None:
1788
1803
"step1" : _Batch (seq_no = 0 , step_name = "step1" , last_batch = True )
1789
1804
},
1790
1805
last_batch_flag_sent_to = ["step1" ],
1806
+ received_batch_seq_nos = {},
1791
1807
)
1792
1808
1793
1809
dag = DAG ()
@@ -1874,6 +1890,7 @@ def test_dump(self) -> None:
1874
1890
)
1875
1891
},
1876
1892
last_batch_flag_sent_to = ["step99" ],
1893
+ received_batch_seq_nos = {"step3" : [0 ]},
1877
1894
)
1878
1895
assert batch_manager .dump () == {
1879
1896
"steps" : {
@@ -1952,6 +1969,7 @@ def test_dump(self) -> None:
1952
1969
}
1953
1970
},
1954
1971
"last_batch_flag_sent_to" : ["step99" ],
1972
+ "received_batch_seq_nos" : {"step3" : [0 ]},
1955
1973
"type_info" : {
1956
1974
"module" : "distilabel.pipeline.batch_manager" ,
1957
1975
"name" : "_BatchManager" ,
@@ -2106,6 +2124,7 @@ def test_from_dict(self) -> None:
2106
2124
},
2107
2125
},
2108
2126
"last_batch_flag_sent_to" : ["step3" ],
2127
+ "received_batch_seq_nos" : {"step3" : [0 ]},
2109
2128
"type_info" : {
2110
2129
"module" : "distilabel.pipeline.batch_manager" ,
2111
2130
"name" : "_BatchManager" ,
@@ -2128,3 +2147,5 @@ def test_from_dict(self) -> None:
2128
2147
assert isinstance (step , _Batch )
2129
2148
2130
2149
assert batch_manager ._last_batch_flag_sent_to == ["step3" ]
2150
+
2151
+ assert batch_manager ._received_batch_seq_nos == {"step3" : [0 ]}
0 commit comments