From 39b62ae6ef1d44711f38eecb1a8c87ff4e2761b8 Mon Sep 17 00:00:00 2001 From: Trent Houliston Date: Sun, 18 May 2025 18:28:02 +1000 Subject: [PATCH 01/14] Network test needs these --- tests/networktest.cpp | 3 +++ 1 file changed, 3 insertions(+) diff --git a/tests/networktest.cpp b/tests/networktest.cpp index 581a712a..12f6238a 100644 --- a/tests/networktest.cpp +++ b/tests/networktest.cpp @@ -138,6 +138,9 @@ int main(int argc, const char* argv[]) { NUClear::Configuration config; config.default_pool_concurrency = 4; NUClear::PowerPlant plant(config, argc, argv); + plant.install(); + plant.install(); + plant.install(); plant.install(); plant.start(); From 118cd9f0310e4fb7b9563d87a413471c0090c565 Mon Sep 17 00:00:00 2001 From: Trent Houliston Date: Fri, 16 May 2025 18:51:34 +1000 Subject: [PATCH 02/14] Swap to using the same timeout estimator as TCP --- src/CMakeLists.txt | 10 +- src/extension/network/NUClearNetwork.cpp | 13 +- src/extension/network/NUClearNetwork.hpp | 38 +--- src/extension/network/RTTEstimator.cpp | 52 +++++ src/extension/network/RTTEstimator.hpp | 125 +++++++++++ tests/CMakeLists.txt | 4 +- tests/tests/network/RTTEstimator.cpp | 252 +++++++++++++++++++++++ 7 files changed, 449 insertions(+), 45 deletions(-) create mode 100644 src/extension/network/RTTEstimator.cpp create mode 100644 src/extension/network/RTTEstimator.hpp create mode 100644 tests/tests/network/RTTEstimator.cpp diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7ed0fb8a..722cf01b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -28,7 +28,15 @@ configure_file(nuclear.in ${PROJECT_BINARY_DIR}/nuclear) # Build the library find_package(Threads REQUIRED) -file(GLOB_RECURSE src "*.c" "*.cpp" "*.hpp" "*.ipp") +file( + GLOB_RECURSE + src + CONFIGURE_DEPENDS + "*.c" + "*.cpp" + "*.hpp" + "*.ipp" +) add_library(nuclear STATIC ${src}) add_library(NUClear::nuclear ALIAS nuclear) diff --git a/src/extension/network/NUClearNetwork.cpp b/src/extension/network/NUClearNetwork.cpp index 11839da9..fcc28fe8 100644 --- a/src/extension/network/NUClearNetwork.cpp +++ b/src/extension/network/NUClearNetwork.cpp @@ -501,7 +501,7 @@ namespace extension { if (ptr) { auto now = std::chrono::steady_clock::now(); - auto timeout = it->last_send + ptr->round_trip_time; + auto timeout = it->last_send + ptr->rtt.timeout(); // Check if we should have expected an ack by now for some packets if (timeout < now) { @@ -510,7 +510,7 @@ namespace extension { it->last_send = now; // The next time we should check for a timeout - auto next_timeout = now + ptr->round_trip_time; + auto next_timeout = now + ptr->rtt.timeout(); if (next_timeout < next_event) { next_event = next_timeout; next_event_callback(next_event); @@ -869,7 +869,7 @@ namespace extension { // Check for and delete any timed out packets for (auto it = assemblers.begin(); it != assemblers.end();) { const auto now = std::chrono::steady_clock::now(); - const auto timeout = remote->round_trip_time * 10.0; + const auto timeout = remote->rtt.timeout() * 10.0; const auto& last_chunk_time = it->second.first; it = now > last_chunk_time + timeout ? assemblers.erase(it) : std::next(it); @@ -919,8 +919,7 @@ namespace extension { // Approximate how long the round trip is to this remote so we can work out how // long before retransmitting - // We use a baby kalman filter to help smooth out jitter - remote->measure_round_trip(round_trip); + remote->rtt.measure(round_trip); // Update our acks bool all_acked = true; @@ -987,7 +986,7 @@ namespace extension { s->last_send = std::chrono::steady_clock::now(); // The next time we should check for a timeout - auto next_timeout = s->last_send + remote->round_trip_time; + auto next_timeout = s->last_send + remote->rtt.timeout(); if (next_timeout < next_event) { next_event = next_timeout; next_event_callback(next_event); @@ -1108,7 +1107,7 @@ namespace extension { queue.targets.emplace_back(it->second, acks); // The next time we should check for a timeout - auto next_timeout = std::chrono::steady_clock::now() + it->second->round_trip_time; + auto next_timeout = std::chrono::steady_clock::now() + it->second->rtt.timeout(); if (next_timeout < next_event) { next_event = next_timeout; next_event_callback(next_event); diff --git a/src/extension/network/NUClearNetwork.hpp b/src/extension/network/NUClearNetwork.hpp index e2277af2..4f75c840 100644 --- a/src/extension/network/NUClearNetwork.hpp +++ b/src/extension/network/NUClearNetwork.hpp @@ -39,6 +39,7 @@ #include "../../util/network/sock_t.hpp" #include "../../util/platform.hpp" +#include "RTTEstimator.hpp" #include "wire_protocol.hpp" namespace NUClear { @@ -79,41 +80,8 @@ namespace extension { std::pair>>> assemblers; - /// Struct storing the kalman filter for round trip time - struct RoundTripKF { - float process_noise = 1e-6f; - float measurement_noise = 1e-1f; - float variance = 1.0f; - float mean = 1.0f; - }; - /// A little kalman filter for estimating round trip time - RoundTripKF round_trip_kf{}; - - std::chrono::steady_clock::duration round_trip_time{std::chrono::seconds(1)}; - - void measure_round_trip(std::chrono::steady_clock::duration time) { - - // Make our measurement into a float seconds type - const std::chrono::duration m = - std::chrono::duration_cast>(time); - - // Alias variables - const auto& Q = round_trip_kf.process_noise; - const auto& R = round_trip_kf.measurement_noise; - auto& P = round_trip_kf.variance; - auto& X = round_trip_kf.mean; - - // Calculate our kalman gain - const float K = (P + Q) / (P + Q + R); - - // Do filter - P = R * (P + Q) / (R + P + Q); - X = X + (m.count() - X) * K; - - // Put result into our variable - round_trip_time = std::chrono::duration_cast( - std::chrono::duration(X)); - } + /// RTT estimator for this network target + RTTEstimator rtt; }; NUClearNetwork() = default; diff --git a/src/extension/network/RTTEstimator.cpp b/src/extension/network/RTTEstimator.cpp new file mode 100644 index 00000000..00567616 --- /dev/null +++ b/src/extension/network/RTTEstimator.cpp @@ -0,0 +1,52 @@ +/* + * MIT License + * + * Copyright (c) 2025 NUClear Contributors + * + * This file is part of the NUClear codebase. + * See https://github.com/Fastcode/NUClear for further info. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE + * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ +#include "RTTEstimator.hpp" + +#include + +namespace NUClear { +namespace extension { + namespace network { + + void RTTEstimator::measure(std::chrono::steady_clock::duration time) { + // Convert measurement to float seconds + const std::chrono::duration m = std::chrono::duration_cast>(time); + const float sample_rtt = m.count(); + + // Calculate RTT variation + const float err = sample_rtt - smoothed_rtt; + rtt_var = (1 - beta) * rtt_var + beta * std::abs(err); + + // Update smoothed RTT + smoothed_rtt = (1 - alpha) * smoothed_rtt + alpha * sample_rtt; + + // Calculate RTO (smoothed RTT + 4 * RTT variation) and bound to limits + rto = std::min(std::max(smoothed_rtt + 4 * rtt_var, min_rto), max_rto); + } + + std::chrono::steady_clock::duration RTTEstimator::timeout() const { + return std::chrono::duration_cast(std::chrono::duration(rto)); + } + + } // namespace network +} // namespace extension +} // namespace NUClear diff --git a/src/extension/network/RTTEstimator.hpp b/src/extension/network/RTTEstimator.hpp new file mode 100644 index 00000000..164bd74e --- /dev/null +++ b/src/extension/network/RTTEstimator.hpp @@ -0,0 +1,125 @@ +/* + * MIT License + * + * Copyright (c) 2025 NUClear Contributors + * + * This file is part of the NUClear codebase. + * See https://github.com/Fastcode/NUClear for further info. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE + * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ +#ifndef NUCLEAR_EXTENSION_NETWORK_RTT_ESTIMATOR_HPP +#define NUCLEAR_EXTENSION_NETWORK_RTT_ESTIMATOR_HPP + +#include +#include +#include + +namespace NUClear { +namespace extension { + namespace network { + + /** + * Implements TCP-style Round Trip Time (RTT) estimation using Jacobson/Karels algorithm. + * + * This class provides RTT estimation functionality similar to TCP's RTT estimation mechanism. + * It uses an Exponentially Weighted Moving Average (EWMA) to smooth RTT measurements and + * calculate a retransmission timeout (RTO) value. The implementation follows the TCP + * Jacobson/Karels algorithm which provides robust RTT estimation that: + * - Smoothly tracks the mean RTT + * - Adapts to RTT variations + * - Handles network jitter + * - Provides conservative timeout values + */ + class RTTEstimator { + public: + /** + * Construct a new RTT Estimator + * + * @param alpha Weight for RTT smoothing (default: 0.125, TCP standard) + * @param beta Weight for RTT variation (default: 0.25, TCP standard) + * @param initial_rtt Initial RTT estimate in seconds (default: 1.0) + * @param initial_rtt_var Initial RTT variation in seconds (default: 0.0) + * @param min_rto Minimum RTO value in seconds (default: 0.1) + * @param max_rto Maximum RTO value in seconds (default: 60.0) + * + * The alpha and beta parameters control how quickly the estimator adapts to changes: + * - alpha: Lower values (e.g. 0.125) make the smoothed RTT more stable but slower to adapt + * - beta: Lower values (e.g. 0.25) make the RTT variation more stable but slower to adapt + * + * @throws std::invalid_argument if alpha or beta are not in range [0,1] + * @throws std::invalid_argument if min_rto >= max_rto + */ + RTTEstimator(float alpha = 0.125f, + float beta = 0.25f, + float initial_rtt = 1.0f, + float initial_rtt_var = 0.0f, + float min_rto = 0.1f, + float max_rto = 60.0f) + : alpha(alpha) + , beta(beta) + , min_rto(min_rto) + , max_rto(max_rto) + , smoothed_rtt(initial_rtt) + , rtt_var(initial_rtt_var) + , rto(std::min(std::max(initial_rtt + 4 * initial_rtt_var, min_rto), max_rto)) { + + if (alpha < 0.0f || alpha > 1.0f) { + throw std::invalid_argument("alpha must be in range [0,1]"); + } + if (beta < 0.0f || beta > 1.0f) { + throw std::invalid_argument("beta must be in range [0,1]"); + } + if (min_rto >= max_rto) { + throw std::invalid_argument("min_rto must be less than max_rto"); + } + } + + /** + * Update the RTT estimate with a new measurement + * + * Updates the smoothed RTT, RTT variation, and RTO using the Jacobson/Karels algorithm: + * 1. RTT variation = (1 - beta) * old_variation + beta * |smoothed_rtt - new_rtt| + * 2. Smoothed RTT = (1 - alpha) * old_rtt + alpha * new_rtt + * 3. RTO = smoothed_rtt + 4 * rtt_var + * + * The RTO is bounded between min_rto and max_rto to prevent extreme values. + * + * @param time The measured round trip time + */ + void measure(std::chrono::steady_clock::duration time); + + /** + * Get the current retransmission timeout + * + * @return The RTO as a duration. This value represents the recommended timeout + * for network operations based on the current RTT estimates. + */ + std::chrono::steady_clock::duration timeout() const; + + private: + float alpha; ///< Weight for RTT smoothing (typically 0.125) + float beta; ///< Weight for RTT variation (typically 0.25) + float min_rto; ///< Minimum RTO value in seconds + float max_rto; ///< Maximum RTO value in seconds + float smoothed_rtt; ///< Smoothed RTT estimate in seconds + float rtt_var; ///< RTT variation in seconds + float rto; ///< Retransmission timeout in seconds + }; + + } // namespace network +} // namespace extension +} // namespace NUClear + +#endif // NUCLEAR_EXTENSION_NETWORK_RTT_ESTIMATOR_HPP diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 9520419b..700a989c 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -40,7 +40,7 @@ set_target_properties(${catch2_targets} PROPERTIES CXX_CLANG_TIDY "") set_target_properties(${catch2_target} PROPERTIES CMAKE_CXX_FLAGS "") # Create a test_util library that is used by all tests -file(GLOB_RECURSE test_util_src "test_util/*.cpp") +file(GLOB_RECURSE test_util_src CONFIGURE_DEPENDS "test_util/*.cpp") add_library(test_util OBJECT ${test_util_src}) # This is linking WHOLE_ARCHIVE as otherwise the linker will remove the WSAHolder from the final binary # As a result the WSA initialisation code won't run and the network tests will fail @@ -50,7 +50,7 @@ target_include_directories(test_util PUBLIC ${PROJECT_BINARY_DIR}/include ${PROJ target_include_directories(test_util PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) # Create a test binary for each test file -file(GLOB_RECURSE test_sources "tests/*.cpp") +file(GLOB_RECURSE test_sources CONFIGURE_DEPENDS "tests/*.cpp") foreach(test_file ${test_sources}) get_filename_component(test_name ${test_file} NAME_WE) get_filename_component(test_dir ${test_file} DIRECTORY) diff --git a/tests/tests/network/RTTEstimator.cpp b/tests/tests/network/RTTEstimator.cpp new file mode 100644 index 00000000..d272c0c9 --- /dev/null +++ b/tests/tests/network/RTTEstimator.cpp @@ -0,0 +1,252 @@ +/* + * MIT License + * + * Copyright (c) 2025 NUClear Contributors + * + * This file is part of the NUClear codebase. + * See https://github.com/Fastcode/NUClear for further info. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE + * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ +#include "extension/network/RTTEstimator.hpp" + +#include +#include + +using namespace NUClear::extension::network; +using namespace std::chrono_literals; + +SCENARIO("RTTEstimator initial state", "[network]") { + GIVEN("a new RTTEstimator") { + RTTEstimator rtt(0.125f, 0.25f, 1.0f, 0.0f); + + THEN("the initial timeout should be 1 second") { + REQUIRE(rtt.timeout() == 1s); + } + } +} + +SCENARIO("RTTEstimator with constant RTT", "[network]") { + GIVEN("a new RTTEstimator") { + RTTEstimator rtt(0.125f, 0.25f, 0.1f, 0.0f); + + WHEN("we feed it constant RTTs of 100ms") { + for (int i = 0; i < 20; ++i) { + rtt.measure(100ms); + } + + THEN("the timeout should be at least 100ms and not unreasonably high") { + REQUIRE(rtt.timeout() >= 100ms); + REQUIRE(rtt.timeout() <= 200ms); + } + } + } +} + +SCENARIO("RTTEstimator with increasing RTT", "[network]") { + GIVEN("a new RTTEstimator") { + RTTEstimator rtt(0.125f, 0.25f, 0.1f, 0.0f); + + WHEN("we measure 100ms then 200ms") { + rtt.measure(100ms); + rtt.measure(200ms); + + THEN("the timeout should be at least 200ms and not unreasonably high") { + REQUIRE(rtt.timeout() >= 200ms); + REQUIRE(rtt.timeout() <= 400ms); + } + } + } +} + +SCENARIO("RTTEstimator with decreasing RTT", "[network]") { + GIVEN("a new RTTEstimator") { + RTTEstimator rtt(0.125f, 0.25f, 0.2f, 0.0f); + + WHEN("we measure 200ms then 100ms") { + rtt.measure(200ms); + rtt.measure(100ms); + + THEN("the timeout should be at least 100ms and not unreasonably high") { + REQUIRE(rtt.timeout() >= 100ms); + REQUIRE(rtt.timeout() <= 400ms); + } + } + } +} + +SCENARIO("RTTEstimator with oscillating RTT", "[network]") { + GIVEN("a new RTTEstimator") { + RTTEstimator rtt(0.125f, 0.25f, 0.15f, 0.0f); + + WHEN("we feed it alternating RTTs of 100ms and 200ms") { + for (int i = 0; i < 20; ++i) { + rtt.measure(i % 2 == 0 ? 100ms : 200ms); + } + + THEN("the timeout should be at least 100ms and not unreasonably high") { + REQUIRE(rtt.timeout() >= 100ms); + REQUIRE(rtt.timeout() <= 400ms); + } + } + } +} + +SCENARIO("RTTEstimator with large RTT variation", "[network]") { + GIVEN("a new RTTEstimator") { + RTTEstimator rtt(0.125f, 0.25f, 0.1f, 0.0f); + + WHEN("we measure 100ms then 1 second") { + rtt.measure(100ms); + rtt.measure(1s); + + THEN("the timeout should be at least 1s and not unreasonably high") { + REQUIRE(rtt.timeout() >= 1s); + REQUIRE(rtt.timeout() <= 2s); + } + } + } +} + +SCENARIO("RTTEstimator with small RTT variation", "[network]") { + GIVEN("a new RTTEstimator") { + RTTEstimator rtt(0.125f, 0.25f, 0.1f, 0.0f); + + WHEN("we measure 100ms then 110ms") { + rtt.measure(100ms); + rtt.measure(110ms); + + THEN("the timeout should be at least 110ms and not unreasonably high") { + REQUIRE(rtt.timeout() >= 110ms); + REQUIRE(rtt.timeout() <= 200ms); + } + } + } +} + +SCENARIO("RTTEstimator with zero RTT", "[network]") { + GIVEN("a new RTTEstimator") { + RTTEstimator rtt(0.125f, 0.25f, 0.001f, 0.0f); + + WHEN("we measure 0ms") { + rtt.measure(0ms); + + THEN("the timeout should be at least 0ms and not unreasonably high") { + REQUIRE(rtt.timeout() >= 0ms); + REQUIRE(rtt.timeout() <= 1s); + } + } + } +} + +SCENARIO("RTTEstimator with very large RTT", "[network]") { + GIVEN("a new RTTEstimator") { + RTTEstimator rtt(0.125f, 0.25f, 30.0f, 0.0f); + + WHEN("we measure 30 seconds") { + rtt.measure(30s); + + THEN("the timeout should be at least 30s and not unreasonably high") { + REQUIRE(rtt.timeout() >= 30s); + REQUIRE(rtt.timeout() <= 35s); + } + } + } +} + +SCENARIO("RTTEstimator exact calculation verification", "[network]") { + GIVEN("a new RTTEstimator with known initial state") { + // Initialize with SRTT = 100ms, RTTVAR = 50ms + RTTEstimator rtt(0.125f, 0.25f, 0.1f, 0.05f); + + WHEN("we measure a 120ms RTT") { + rtt.measure(120ms); + + THEN("the values should match the TCP calculation") { + // Expected values: + // RTTVAR = (1 - 0.25) * 50 + 0.25 * |100 - 120| = 0.75 * 50 + 0.25 * 20 = 37.5 + 5 = 42.5ms + // SRTT = (1 - 0.125) * 100 + 0.125 * 120 = 87.5 + 15 = 102.5ms + // RTO = 102.5 + 4 * 42.5 = 272.5ms + REQUIRE(rtt.timeout() >= 270ms); + REQUIRE(rtt.timeout() <= 275ms); + } + } + } +} + +SCENARIO("RTTEstimator spike response", "[network]") { + GIVEN("a new RTTEstimator with stable RTT") { + RTTEstimator rtt(0.125f, 0.25f, 0.1f, 0.0f); + + WHEN("we feed it constant 100ms RTTs then a 500ms spike") { + // First establish a stable RTT + for (int i = 0; i < 10; ++i) { + rtt.measure(100ms); + } + auto before_spike = rtt.timeout(); + + // Then inject a spike + rtt.measure(500ms); + auto after_spike = rtt.timeout(); + + THEN("the timeout should increase but not dramatically") { + REQUIRE(after_spike > before_spike); // Should increase + REQUIRE(after_spike < 1s); // But not too much + } + + AND_WHEN("we return to normal RTT") { + for (int i = 0; i < 10; ++i) { + rtt.measure(100ms); + } + auto after_recovery = rtt.timeout(); + + THEN("it should recover to a reasonable value") { + REQUIRE(after_recovery >= 100ms); + REQUIRE(after_recovery <= 300ms); + } + } + } + } +} + +SCENARIO("RTTEstimator noise resilience", "[network]") { + GIVEN("a new RTTEstimator") { + RTTEstimator rtt(0.125f, 0.25f, 0.1f, 0.0f); + + WHEN("we feed it noisy RTTs around 100ms") { + // Generate noisy RTTs: 100ms ± 20ms + for (int i = 0; i < 50; ++i) { + auto noise = (i % 2 == 0 ? 20ms : -20ms); + rtt.measure(100ms + noise); + } + + THEN("the timeout should remain stable") { + REQUIRE(rtt.timeout() >= 100ms); + REQUIRE(rtt.timeout() <= 300ms); + } + + AND_WHEN("we continue with constant RTT") { + for (int i = 0; i < 10; ++i) { + rtt.measure(100ms); + } + auto final_timeout = rtt.timeout(); + + THEN("it should converge to the true RTT") { + REQUIRE(final_timeout >= 100ms); + REQUIRE(final_timeout <= 200ms); + } + } + } + } +} From bb93e71857f2c0e6fa4d616c2b5cb6d16f2e730d Mon Sep 17 00:00:00 2001 From: Trent Houliston Date: Sun, 18 May 2025 20:27:22 +1000 Subject: [PATCH 03/14] Add a packet deduplicator and tests using a bitset window --- src/extension/network/NUClearNetwork.cpp | 71 +++---- src/extension/network/NUClearNetwork.hpp | 14 +- src/extension/network/PacketDeduplicator.cpp | 75 ++++++++ src/extension/network/PacketDeduplicator.hpp | 69 +++++++ tests/tests/network/PacketDeduplicator.cpp | 183 +++++++++++++++++++ 5 files changed, 362 insertions(+), 50 deletions(-) create mode 100644 src/extension/network/PacketDeduplicator.cpp create mode 100644 src/extension/network/PacketDeduplicator.hpp create mode 100644 tests/tests/network/PacketDeduplicator.cpp diff --git a/src/extension/network/NUClearNetwork.cpp b/src/extension/network/NUClearNetwork.cpp index fcc28fe8..2f75acdd 100644 --- a/src/extension/network/NUClearNetwork.cpp +++ b/src/extension/network/NUClearNetwork.cpp @@ -675,39 +675,35 @@ namespace extension { // Check if this packet is a retransmission of data if (header.type == DATA_RETRANSMISSION) { - // See if we recently processed this packet - // NOLINTNEXTLINE(readability-qualified-auto) MSVC disagrees - auto it = std::find(remote->recent_packets.begin(), - remote->recent_packets.end(), - packet.packet_id); - // We recently processed this packet, this is just a failed ack - // Send the ack again if it was reliable - if (it != remote->recent_packets.end() && packet.reliable) { + if (remote->deduplicator.is_duplicate(packet.packet_id)) { - // Allocate room for the whole ack packet - std::vector r(sizeof(ACKPacket) + (packet.packet_count / 8), 0); - ACKPacket& response = *reinterpret_cast(r.data()); - response = ACKPacket(); - response.packet_id = packet.packet_id; - response.packet_no = packet.packet_no; - response.packet_count = packet.packet_count; + // Send the ack again if it was reliable + if (packet.reliable) { + // Allocate room for the whole ack packet + std::vector r(sizeof(ACKPacket) + (packet.packet_count / 8), 0); + ACKPacket& response = *reinterpret_cast(r.data()); + response = ACKPacket(); + response.packet_id = packet.packet_id; + response.packet_no = packet.packet_no; + response.packet_count = packet.packet_count; - // Set the bits for all packets (we got the whole thing) - for (int i = 0; i < packet.packet_count; ++i) { - (&response.packets)[i / 8] |= uint8_t(1 << (i % 8)); - } + // Set the bits for all packets (we got the whole thing) + for (int i = 0; i < packet.packet_count; ++i) { + (&response.packets)[i / 8] |= uint8_t(1 << (i % 8)); + } - // Make who we are sending it to into a useable address - const sock_t& to = remote->target; + // Make who we are sending it to into a useable address + const sock_t& to = remote->target; - // Send the packet - ::sendto(data_fd, - reinterpret_cast(r.data()), - static_cast(r.size()), - 0, - &to.sock, - to.size()); + // Send the packet + ::sendto(data_fd, + reinterpret_cast(r.data()), + static_cast(r.size()), + 0, + &to.sock, + to.size()); + } // We don't need to process this packet we already did return; @@ -739,13 +735,11 @@ namespace extension { 0, &to.sock, to.size()); - - // Set this packet to have been recently received - remote->recent_packets[remote->recent_packets_index - .fetch_add(1, std::memory_order_relaxed)] = - packet.packet_id; } + // Add the packet to our deduplicator + remote->deduplicator.add_packet(packet.packet_id); + packet_callback(*remote, packet.hash, packet.reliable, std::move(out)); } else { @@ -851,17 +845,12 @@ namespace extension { &part.data + p.second.size() - sizeof(DataPacket) + 1); } + // Add the packet to our deduplicator + remote->deduplicator.add_packet(packet.packet_id); + // Send our assembled data packet packet_callback(*remote, packet.hash, packet.reliable, std::move(out)); - // If the packet was reliable add that it was recently received - if (packet.reliable) { - // Set this packet to have been recently received - remote->recent_packets[remote->recent_packets_index - .fetch_add(1, std::memory_order_relaxed)] = - packet.packet_id; - } - // We have completed this packet, discard the data assemblers.erase(assemblers.find(packet.packet_id)); } diff --git a/src/extension/network/NUClearNetwork.hpp b/src/extension/network/NUClearNetwork.hpp index 4f75c840..d5f17e6d 100644 --- a/src/extension/network/NUClearNetwork.hpp +++ b/src/extension/network/NUClearNetwork.hpp @@ -39,6 +39,7 @@ #include "../../util/network/sock_t.hpp" #include "../../util/platform.hpp" +#include "PacketDeduplicator.hpp" #include "RTTEstimator.hpp" #include "wire_protocol.hpp" @@ -57,11 +58,7 @@ namespace extension { std::string name, const sock_t& target, const std::chrono::steady_clock::time_point& last_update = std::chrono::steady_clock::now()) - : name(std::move(name)), target(target), last_update(last_update) { - - // Set our recent packets to an invalid value - recent_packets.fill(-1); - } + : name(std::move(name)), target(target), last_update(last_update) {} /// The name of the remote target std::string name; @@ -69,10 +66,6 @@ namespace extension { sock_t target{}; /// When we last received data from the remote target std::chrono::steady_clock::time_point last_update; - /// A list of the last n packet groups to be received - std::array::max()> recent_packets{}; - /// An index for the recent_packets (circular buffer) - std::atomic recent_packets_index{0}; /// Mutex to protect the fragmented packet storage std::mutex assemblers_mutex; /// Storage for fragmented packets while we build them @@ -82,6 +75,9 @@ namespace extension { /// RTT estimator for this network target RTTEstimator rtt; + + /// Packet deduplicator for this network target + PacketDeduplicator deduplicator; }; NUClearNetwork() = default; diff --git a/src/extension/network/PacketDeduplicator.cpp b/src/extension/network/PacketDeduplicator.cpp new file mode 100644 index 00000000..ac13666b --- /dev/null +++ b/src/extension/network/PacketDeduplicator.cpp @@ -0,0 +1,75 @@ +/* + * MIT License + * + * Copyright (c) 2025 NUClear Contributors + * + * This file is part of the NUClear codebase. + * See https://github.com/Fastcode/NUClear for further info. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE + * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ +#include "PacketDeduplicator.hpp" + +namespace NUClear { +namespace extension { + namespace network { + + PacketDeduplicator::PacketDeduplicator() : newest_seen(0) {} + + bool PacketDeduplicator::is_duplicate(uint16_t packet_id) const { + // If we haven't seen any packets yet, nothing is a duplicate + if (!initialized) { + return false; + } + + // Calculate relative position in window using unsigned subtraction + uint16_t relative_id = newest_seen - packet_id; + + // If the packet is too old or too new, it's not a duplicate + if (relative_id >= 256) { + return false; + } + + return window[relative_id]; + } + + void PacketDeduplicator::add_packet(uint16_t packet_id) { + // If this is our first packet, just set it as newest_seen + if (!initialized) { + newest_seen = packet_id; + window[0] = true; + initialized = true; + return; + } + + // Calculate relative position in window using unsigned subtraction + uint16_t relative_id = newest_seen - packet_id; + + // If the distance is more than half the range, the packet is newer than our newest_seen + if (relative_id > 32768) { + // Calculate how far to shift to make this packet our newest + uint16_t shift_amount = packet_id - newest_seen; + window <<= shift_amount; + newest_seen = packet_id; + window[0] = true; + } + // Packet is recent enough to be counted + else if (relative_id < 256) { + window[relative_id] = true; + } + } + + } // namespace network +} // namespace extension +} // namespace NUClear diff --git a/src/extension/network/PacketDeduplicator.hpp b/src/extension/network/PacketDeduplicator.hpp new file mode 100644 index 00000000..36df9804 --- /dev/null +++ b/src/extension/network/PacketDeduplicator.hpp @@ -0,0 +1,69 @@ +/* + * MIT License + * + * Copyright (c) 2025 NUClear Contributors + * + * This file is part of the NUClear codebase. + * See https://github.com/Fastcode/NUClear for further info. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE + * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ +#ifndef NUCLEAR_EXTENSION_NETWORK_PACKET_DEDUPLICATOR_HPP +#define NUCLEAR_EXTENSION_NETWORK_PACKET_DEDUPLICATOR_HPP + +#include +#include + +namespace NUClear { +namespace extension { + namespace network { + + /** + * A class that implements a sliding window bitset for packet deduplication. + * Maintains a 256-bit window of recently seen packet IDs, sliding forward as new packets are added. + */ + class PacketDeduplicator { + public: + PacketDeduplicator(); + + /** + * Check if a packet ID has been seen recently + * + * @param packet_id The packet ID to check + * + * @return true if the packet has been seen recently, false otherwise + */ + bool is_duplicate(uint16_t packet_id) const; + + /** + * Add a packet ID to the window + * + * @param packet_id The packet ID to add + */ + void add_packet(uint16_t packet_id); + + private: + /// Whether we've seen any packets yet + bool initialized{false}; + /// The newest packet ID we've seen + uint16_t newest_seen; + /// The 256-bit window of seen packets (newest at 0, older at higher indices) + std::bitset<256> window; + }; + + } // namespace network +} // namespace extension +} // namespace NUClear + +#endif // NUCLEAR_EXTENSION_NETWORK_PACKET_DEDUPLICATOR_HPP diff --git a/tests/tests/network/PacketDeduplicator.cpp b/tests/tests/network/PacketDeduplicator.cpp new file mode 100644 index 00000000..297600bd --- /dev/null +++ b/tests/tests/network/PacketDeduplicator.cpp @@ -0,0 +1,183 @@ +/* + * MIT License + * + * Copyright (c) 2025 NUClear Contributors + * + * This file is part of the NUClear codebase. + * See https://github.com/Fastcode/NUClear for further info. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE + * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ +#include "extension/network/PacketDeduplicator.hpp" + +#include + +using namespace NUClear::extension::network; + +SCENARIO("PacketDeduplicator basic functionality", "[network]") { + GIVEN("a new PacketDeduplicator") { + PacketDeduplicator dedup; + + WHEN("we check a new packet") { + THEN("it should not be a duplicate") { + REQUIRE_FALSE(dedup.is_duplicate(1)); + } + + AND_WHEN("we add the packet") { + dedup.add_packet(1); + + THEN("it should be marked as a duplicate") { + REQUIRE(dedup.is_duplicate(1)); + } + + AND_WHEN("we check a different packet") { + THEN("it should not be a duplicate") { + REQUIRE_FALSE(dedup.is_duplicate(2)); + } + + AND_WHEN("we add the second packet") { + dedup.add_packet(2); + + THEN("both packets should be marked as duplicates") { + REQUIRE(dedup.is_duplicate(1)); + REQUIRE(dedup.is_duplicate(2)); + } + } + } + } + } + } +} + +SCENARIO("PacketDeduplicator window sliding", "[network]") { + GIVEN("a new PacketDeduplicator") { + PacketDeduplicator dedup; + + WHEN("we add packets up to the window size") { + for (uint16_t i = 0; i < 256; ++i) { + dedup.add_packet(i); + } + + THEN("all packets should be marked as duplicates") { + for (uint16_t i = 0; i < 256; ++i) { + REQUIRE(dedup.is_duplicate(i)); + } + } + + AND_WHEN("we add a packet beyond the window") { + dedup.add_packet(256); + + THEN("the oldest packet should be forgotten") { + REQUIRE_FALSE(dedup.is_duplicate(0)); + } + + AND_THEN("the newest packet should be remembered") { + REQUIRE(dedup.is_duplicate(256)); + } + } + } + } +} + +SCENARIO("PacketDeduplicator out of order packets", "[network]") { + GIVEN("a new PacketDeduplicator") { + PacketDeduplicator dedup; + + WHEN("we add packets out of order") { + dedup.add_packet(5); + dedup.add_packet(3); + dedup.add_packet(7); + dedup.add_packet(1); + + THEN("all added packets should be marked as duplicates") { + REQUIRE(dedup.is_duplicate(1)); + REQUIRE(dedup.is_duplicate(3)); + REQUIRE(dedup.is_duplicate(5)); + REQUIRE(dedup.is_duplicate(7)); + } + + AND_THEN("unseen packets should not be marked as duplicates") { + REQUIRE_FALSE(dedup.is_duplicate(2)); + REQUIRE_FALSE(dedup.is_duplicate(4)); + REQUIRE_FALSE(dedup.is_duplicate(6)); + REQUIRE_FALSE(dedup.is_duplicate(8)); + } + } + } +} + +SCENARIO("PacketDeduplicator packet wrap around", "[network]") { + GIVEN("a new PacketDeduplicator") { + PacketDeduplicator dedup; + + WHEN("we add packets near the uint16_t wrap around point") { + uint16_t start = 65530; + + for (uint16_t i = 0; i < 10; ++i) { + uint16_t packet_id = (start + i) % 65536; + dedup.add_packet(packet_id); + } + + THEN("all added packets should be marked as duplicates") { + for (uint16_t i = 0; i < 10; ++i) { + uint16_t packet_id = (start + i) % 65536; + REQUIRE(dedup.is_duplicate(packet_id)); + } + } + + AND_THEN("packets before the window should not be marked as duplicates") { + REQUIRE_FALSE(dedup.is_duplicate(start - 1)); + } + } + } +} + +SCENARIO("PacketDeduplicator old packets", "[network]") { + GIVEN("a new PacketDeduplicator") { + PacketDeduplicator dedup; + + WHEN("we add a packet and then slide the window past it") { + dedup.add_packet(1); + + // Slide window past packet 1 + for (uint16_t i = 2; i < 258; ++i) { + dedup.add_packet(i); + } + + THEN("the old packet should not be marked as a duplicate") { + REQUIRE_FALSE(dedup.is_duplicate(1)); + } + + AND_THEN("recent packets should still be marked as duplicates") { + REQUIRE(dedup.is_duplicate(256)); + REQUIRE(dedup.is_duplicate(257)); + } + } + } +} + +SCENARIO("PacketDeduplicator handles high initial packet IDs correctly", "[network]") { + GIVEN("A PacketDeduplicator") { + PacketDeduplicator dedup; + + WHEN("The first packet ID is greater than uint16_t max/2") { + uint16_t high_id = 40000; // > 32768 + dedup.add_packet(high_id); + + THEN("The first packet should be marked as a duplicate") { + REQUIRE(dedup.is_duplicate(high_id)); + } + } + } +} From 95ec398105afc240a98a738bbcd49ca85dd346f2 Mon Sep 17 00:00:00 2001 From: Trent Houliston Date: Sun, 18 May 2025 21:05:29 +1000 Subject: [PATCH 04/14] lint --- src/extension/network/PacketDeduplicator.cpp | 8 ++++---- src/extension/network/PacketDeduplicator.hpp | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/extension/network/PacketDeduplicator.cpp b/src/extension/network/PacketDeduplicator.cpp index ac13666b..5415448c 100644 --- a/src/extension/network/PacketDeduplicator.cpp +++ b/src/extension/network/PacketDeduplicator.cpp @@ -25,7 +25,7 @@ namespace NUClear { namespace extension { namespace network { - PacketDeduplicator::PacketDeduplicator() : newest_seen(0) {} + PacketDeduplicator::PacketDeduplicator() {} bool PacketDeduplicator::is_duplicate(uint16_t packet_id) const { // If we haven't seen any packets yet, nothing is a duplicate @@ -34,7 +34,7 @@ namespace extension { } // Calculate relative position in window using unsigned subtraction - uint16_t relative_id = newest_seen - packet_id; + const uint16_t relative_id = newest_seen - packet_id; // If the packet is too old or too new, it's not a duplicate if (relative_id >= 256) { @@ -54,12 +54,12 @@ namespace extension { } // Calculate relative position in window using unsigned subtraction - uint16_t relative_id = newest_seen - packet_id; + const uint16_t relative_id = newest_seen - packet_id; // If the distance is more than half the range, the packet is newer than our newest_seen if (relative_id > 32768) { // Calculate how far to shift to make this packet our newest - uint16_t shift_amount = packet_id - newest_seen; + const uint16_t shift_amount = packet_id - newest_seen; window <<= shift_amount; newest_seen = packet_id; window[0] = true; diff --git a/src/extension/network/PacketDeduplicator.hpp b/src/extension/network/PacketDeduplicator.hpp index 36df9804..53b53c50 100644 --- a/src/extension/network/PacketDeduplicator.hpp +++ b/src/extension/network/PacketDeduplicator.hpp @@ -57,7 +57,7 @@ namespace extension { /// Whether we've seen any packets yet bool initialized{false}; /// The newest packet ID we've seen - uint16_t newest_seen; + uint16_t newest_seen{0}; /// The 256-bit window of seen packets (newest at 0, older at higher indices) std::bitset<256> window; }; From ce450ae0736f7add2d4b98c29280fe3bc17e9672 Mon Sep 17 00:00:00 2001 From: Trent Houliston Date: Sun, 18 May 2025 21:09:12 +1000 Subject: [PATCH 05/14] . --- src/extension/network/NUClearNetwork.cpp | 1 - src/extension/network/PacketDeduplicator.cpp | 3 ++- src/extension/network/PacketDeduplicator.hpp | 2 -- 3 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/extension/network/NUClearNetwork.cpp b/src/extension/network/NUClearNetwork.cpp index 2f75acdd..284d89a3 100644 --- a/src/extension/network/NUClearNetwork.cpp +++ b/src/extension/network/NUClearNetwork.cpp @@ -24,7 +24,6 @@ #include #include -#include #include #include #include diff --git a/src/extension/network/PacketDeduplicator.cpp b/src/extension/network/PacketDeduplicator.cpp index 5415448c..c252a1f3 100644 --- a/src/extension/network/PacketDeduplicator.cpp +++ b/src/extension/network/PacketDeduplicator.cpp @@ -21,11 +21,12 @@ */ #include "PacketDeduplicator.hpp" +#include + namespace NUClear { namespace extension { namespace network { - PacketDeduplicator::PacketDeduplicator() {} bool PacketDeduplicator::is_duplicate(uint16_t packet_id) const { // If we haven't seen any packets yet, nothing is a duplicate diff --git a/src/extension/network/PacketDeduplicator.hpp b/src/extension/network/PacketDeduplicator.hpp index 53b53c50..494f6d88 100644 --- a/src/extension/network/PacketDeduplicator.hpp +++ b/src/extension/network/PacketDeduplicator.hpp @@ -35,8 +35,6 @@ namespace extension { */ class PacketDeduplicator { public: - PacketDeduplicator(); - /** * Check if a packet ID has been seen recently * From c7ec4fa4217904feb57d77f0ba8a2c29d74668fe Mon Sep 17 00:00:00 2001 From: Trent Houliston Date: Sun, 18 May 2025 21:13:13 +1000 Subject: [PATCH 06/14] . --- src/extension/network/RTTEstimator.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/extension/network/RTTEstimator.cpp b/src/extension/network/RTTEstimator.cpp index 00567616..71ac81d8 100644 --- a/src/extension/network/RTTEstimator.cpp +++ b/src/extension/network/RTTEstimator.cpp @@ -21,6 +21,8 @@ */ #include "RTTEstimator.hpp" +#include +#include #include namespace NUClear { From b540a01affa318512f3564869c466c9b254c7175 Mon Sep 17 00:00:00 2001 From: Trent Houliston Date: Sun, 18 May 2025 21:27:47 +1000 Subject: [PATCH 07/14] Improve RTTEstimator coverage --- src/extension/network/RTTEstimator.cpp | 25 ++ src/extension/network/RTTEstimator.hpp | 20 +- tests/tests/network/RTTEstimator.cpp | 363 ++++++++++++++----------- 3 files changed, 226 insertions(+), 182 deletions(-) diff --git a/src/extension/network/RTTEstimator.cpp b/src/extension/network/RTTEstimator.cpp index 71ac81d8..13876130 100644 --- a/src/extension/network/RTTEstimator.cpp +++ b/src/extension/network/RTTEstimator.cpp @@ -29,6 +29,31 @@ namespace NUClear { namespace extension { namespace network { + RTTEstimator::RTTEstimator(float alpha, + float beta, + float initial_rtt, + float initial_rtt_var, + float min_rto, + float max_rto) + : alpha(alpha) + , beta(beta) + , min_rto(min_rto) + , max_rto(max_rto) + , smoothed_rtt(initial_rtt) + , rtt_var(initial_rtt_var) + , rto(std::min(std::max(initial_rtt + 4 * initial_rtt_var, min_rto), max_rto)) { + + if (alpha < 0.0f || alpha > 1.0f) { + throw std::invalid_argument("alpha must be in range [0,1]"); + } + if (beta < 0.0f || beta > 1.0f) { + throw std::invalid_argument("beta must be in range [0,1]"); + } + if (min_rto >= max_rto) { + throw std::invalid_argument("min_rto must be less than max_rto"); + } + } + void RTTEstimator::measure(std::chrono::steady_clock::duration time) { // Convert measurement to float seconds const std::chrono::duration m = std::chrono::duration_cast>(time); diff --git a/src/extension/network/RTTEstimator.hpp b/src/extension/network/RTTEstimator.hpp index 164bd74e..481659ea 100644 --- a/src/extension/network/RTTEstimator.hpp +++ b/src/extension/network/RTTEstimator.hpp @@ -66,25 +66,7 @@ namespace extension { float initial_rtt = 1.0f, float initial_rtt_var = 0.0f, float min_rto = 0.1f, - float max_rto = 60.0f) - : alpha(alpha) - , beta(beta) - , min_rto(min_rto) - , max_rto(max_rto) - , smoothed_rtt(initial_rtt) - , rtt_var(initial_rtt_var) - , rto(std::min(std::max(initial_rtt + 4 * initial_rtt_var, min_rto), max_rto)) { - - if (alpha < 0.0f || alpha > 1.0f) { - throw std::invalid_argument("alpha must be in range [0,1]"); - } - if (beta < 0.0f || beta > 1.0f) { - throw std::invalid_argument("beta must be in range [0,1]"); - } - if (min_rto >= max_rto) { - throw std::invalid_argument("min_rto must be less than max_rto"); - } - } + float max_rto = 60.0f); /** * Update the RTT estimate with a new measurement diff --git a/tests/tests/network/RTTEstimator.cpp b/tests/tests/network/RTTEstimator.cpp index d272c0c9..734fd92c 100644 --- a/tests/tests/network/RTTEstimator.cpp +++ b/tests/tests/network/RTTEstimator.cpp @@ -24,229 +24,266 @@ #include #include -using namespace NUClear::extension::network; -using namespace std::chrono_literals; +namespace NUClear { +namespace extension { + namespace network { -SCENARIO("RTTEstimator initial state", "[network]") { - GIVEN("a new RTTEstimator") { - RTTEstimator rtt(0.125f, 0.25f, 1.0f, 0.0f); + using namespace std::chrono_literals; - THEN("the initial timeout should be 1 second") { - REQUIRE(rtt.timeout() == 1s); - } - } -} - -SCENARIO("RTTEstimator with constant RTT", "[network]") { - GIVEN("a new RTTEstimator") { - RTTEstimator rtt(0.125f, 0.25f, 0.1f, 0.0f); + SCENARIO("RTTEstimator initial state", "[network]") { + GIVEN("a new RTTEstimator") { + RTTEstimator rtt(0.125f, 0.25f, 1.0f, 0.0f); - WHEN("we feed it constant RTTs of 100ms") { - for (int i = 0; i < 20; ++i) { - rtt.measure(100ms); - } - - THEN("the timeout should be at least 100ms and not unreasonably high") { - REQUIRE(rtt.timeout() >= 100ms); - REQUIRE(rtt.timeout() <= 200ms); + THEN("the initial timeout should be 1 second") { + REQUIRE(rtt.timeout() == 1s); + } } } - } -} -SCENARIO("RTTEstimator with increasing RTT", "[network]") { - GIVEN("a new RTTEstimator") { - RTTEstimator rtt(0.125f, 0.25f, 0.1f, 0.0f); + SCENARIO("RTTEstimator with constant RTT", "[network]") { + GIVEN("a new RTTEstimator") { + RTTEstimator rtt(0.125f, 0.25f, 0.1f, 0.0f); - WHEN("we measure 100ms then 200ms") { - rtt.measure(100ms); - rtt.measure(200ms); + WHEN("we feed it constant RTTs of 100ms") { + for (int i = 0; i < 20; ++i) { + rtt.measure(100ms); + } - THEN("the timeout should be at least 200ms and not unreasonably high") { - REQUIRE(rtt.timeout() >= 200ms); - REQUIRE(rtt.timeout() <= 400ms); + THEN("the timeout should be at least 100ms and not unreasonably high") { + REQUIRE(rtt.timeout() >= 100ms); + REQUIRE(rtt.timeout() <= 200ms); + } + } } } - } -} -SCENARIO("RTTEstimator with decreasing RTT", "[network]") { - GIVEN("a new RTTEstimator") { - RTTEstimator rtt(0.125f, 0.25f, 0.2f, 0.0f); + SCENARIO("RTTEstimator with increasing RTT", "[network]") { + GIVEN("a new RTTEstimator") { + RTTEstimator rtt(0.125f, 0.25f, 0.1f, 0.0f); - WHEN("we measure 200ms then 100ms") { - rtt.measure(200ms); - rtt.measure(100ms); + WHEN("we measure 100ms then 200ms") { + rtt.measure(100ms); + rtt.measure(200ms); - THEN("the timeout should be at least 100ms and not unreasonably high") { - REQUIRE(rtt.timeout() >= 100ms); - REQUIRE(rtt.timeout() <= 400ms); + THEN("the timeout should be at least 200ms and not unreasonably high") { + REQUIRE(rtt.timeout() >= 200ms); + REQUIRE(rtt.timeout() <= 400ms); + } + } } } - } -} -SCENARIO("RTTEstimator with oscillating RTT", "[network]") { - GIVEN("a new RTTEstimator") { - RTTEstimator rtt(0.125f, 0.25f, 0.15f, 0.0f); + SCENARIO("RTTEstimator with decreasing RTT", "[network]") { + GIVEN("a new RTTEstimator") { + RTTEstimator rtt(0.125f, 0.25f, 0.2f, 0.0f); - WHEN("we feed it alternating RTTs of 100ms and 200ms") { - for (int i = 0; i < 20; ++i) { - rtt.measure(i % 2 == 0 ? 100ms : 200ms); - } + WHEN("we measure 200ms then 100ms") { + rtt.measure(200ms); + rtt.measure(100ms); - THEN("the timeout should be at least 100ms and not unreasonably high") { - REQUIRE(rtt.timeout() >= 100ms); - REQUIRE(rtt.timeout() <= 400ms); + THEN("the timeout should be at least 100ms and not unreasonably high") { + REQUIRE(rtt.timeout() >= 100ms); + REQUIRE(rtt.timeout() <= 400ms); + } + } } } - } -} -SCENARIO("RTTEstimator with large RTT variation", "[network]") { - GIVEN("a new RTTEstimator") { - RTTEstimator rtt(0.125f, 0.25f, 0.1f, 0.0f); + SCENARIO("RTTEstimator with oscillating RTT", "[network]") { + GIVEN("a new RTTEstimator") { + RTTEstimator rtt(0.125f, 0.25f, 0.15f, 0.0f); - WHEN("we measure 100ms then 1 second") { - rtt.measure(100ms); - rtt.measure(1s); + WHEN("we feed it alternating RTTs of 100ms and 200ms") { + for (int i = 0; i < 20; ++i) { + rtt.measure(i % 2 == 0 ? 100ms : 200ms); + } - THEN("the timeout should be at least 1s and not unreasonably high") { - REQUIRE(rtt.timeout() >= 1s); - REQUIRE(rtt.timeout() <= 2s); + THEN("the timeout should be at least 100ms and not unreasonably high") { + REQUIRE(rtt.timeout() >= 100ms); + REQUIRE(rtt.timeout() <= 400ms); + } + } } } - } -} -SCENARIO("RTTEstimator with small RTT variation", "[network]") { - GIVEN("a new RTTEstimator") { - RTTEstimator rtt(0.125f, 0.25f, 0.1f, 0.0f); + SCENARIO("RTTEstimator with large RTT variation", "[network]") { + GIVEN("a new RTTEstimator") { + RTTEstimator rtt(0.125f, 0.25f, 0.1f, 0.0f); - WHEN("we measure 100ms then 110ms") { - rtt.measure(100ms); - rtt.measure(110ms); + WHEN("we measure 100ms then 1 second") { + rtt.measure(100ms); + rtt.measure(1s); - THEN("the timeout should be at least 110ms and not unreasonably high") { - REQUIRE(rtt.timeout() >= 110ms); - REQUIRE(rtt.timeout() <= 200ms); + THEN("the timeout should be at least 1s and not unreasonably high") { + REQUIRE(rtt.timeout() >= 1s); + REQUIRE(rtt.timeout() <= 2s); + } + } } } - } -} -SCENARIO("RTTEstimator with zero RTT", "[network]") { - GIVEN("a new RTTEstimator") { - RTTEstimator rtt(0.125f, 0.25f, 0.001f, 0.0f); + SCENARIO("RTTEstimator with small RTT variation", "[network]") { + GIVEN("a new RTTEstimator") { + RTTEstimator rtt(0.125f, 0.25f, 0.1f, 0.0f); - WHEN("we measure 0ms") { - rtt.measure(0ms); + WHEN("we measure 100ms then 110ms") { + rtt.measure(100ms); + rtt.measure(110ms); - THEN("the timeout should be at least 0ms and not unreasonably high") { - REQUIRE(rtt.timeout() >= 0ms); - REQUIRE(rtt.timeout() <= 1s); + THEN("the timeout should be at least 110ms and not unreasonably high") { + REQUIRE(rtt.timeout() >= 110ms); + REQUIRE(rtt.timeout() <= 200ms); + } + } } } - } -} -SCENARIO("RTTEstimator with very large RTT", "[network]") { - GIVEN("a new RTTEstimator") { - RTTEstimator rtt(0.125f, 0.25f, 30.0f, 0.0f); + SCENARIO("RTTEstimator with zero RTT", "[network]") { + GIVEN("a new RTTEstimator") { + RTTEstimator rtt(0.125f, 0.25f, 0.001f, 0.0f); - WHEN("we measure 30 seconds") { - rtt.measure(30s); + WHEN("we measure 0ms") { + rtt.measure(0ms); - THEN("the timeout should be at least 30s and not unreasonably high") { - REQUIRE(rtt.timeout() >= 30s); - REQUIRE(rtt.timeout() <= 35s); + THEN("the timeout should be at least 0ms and not unreasonably high") { + REQUIRE(rtt.timeout() >= 0ms); + REQUIRE(rtt.timeout() <= 1s); + } + } } } - } -} -SCENARIO("RTTEstimator exact calculation verification", "[network]") { - GIVEN("a new RTTEstimator with known initial state") { - // Initialize with SRTT = 100ms, RTTVAR = 50ms - RTTEstimator rtt(0.125f, 0.25f, 0.1f, 0.05f); + SCENARIO("RTTEstimator with very large RTT", "[network]") { + GIVEN("a new RTTEstimator") { + RTTEstimator rtt(0.125f, 0.25f, 30.0f, 0.0f); - WHEN("we measure a 120ms RTT") { - rtt.measure(120ms); + WHEN("we measure 30 seconds") { + rtt.measure(30s); - THEN("the values should match the TCP calculation") { - // Expected values: - // RTTVAR = (1 - 0.25) * 50 + 0.25 * |100 - 120| = 0.75 * 50 + 0.25 * 20 = 37.5 + 5 = 42.5ms - // SRTT = (1 - 0.125) * 100 + 0.125 * 120 = 87.5 + 15 = 102.5ms - // RTO = 102.5 + 4 * 42.5 = 272.5ms - REQUIRE(rtt.timeout() >= 270ms); - REQUIRE(rtt.timeout() <= 275ms); + THEN("the timeout should be at least 30s and not unreasonably high") { + REQUIRE(rtt.timeout() >= 30s); + REQUIRE(rtt.timeout() <= 35s); + } + } } } - } -} - -SCENARIO("RTTEstimator spike response", "[network]") { - GIVEN("a new RTTEstimator with stable RTT") { - RTTEstimator rtt(0.125f, 0.25f, 0.1f, 0.0f); - - WHEN("we feed it constant 100ms RTTs then a 500ms spike") { - // First establish a stable RTT - for (int i = 0; i < 10; ++i) { - rtt.measure(100ms); - } - auto before_spike = rtt.timeout(); - // Then inject a spike - rtt.measure(500ms); - auto after_spike = rtt.timeout(); - - THEN("the timeout should increase but not dramatically") { - REQUIRE(after_spike > before_spike); // Should increase - REQUIRE(after_spike < 1s); // But not too much + SCENARIO("RTTEstimator exact calculation verification", "[network]") { + GIVEN("a new RTTEstimator with known initial state") { + // Initialize with SRTT = 100ms, RTTVAR = 50ms + RTTEstimator rtt(0.125f, 0.25f, 0.1f, 0.05f); + + WHEN("we measure a 120ms RTT") { + rtt.measure(120ms); + + THEN("the values should match the TCP calculation") { + // Expected values: + // RTTVAR = (1 - 0.25) * 50 + 0.25 * |100 - 120| = 0.75 * 50 + 0.25 * 20 = 37.5 + 5 = 42.5ms + // SRTT = (1 - 0.125) * 100 + 0.125 * 120 = 87.5 + 15 = 102.5ms + // RTO = 102.5 + 4 * 42.5 = 272.5ms + REQUIRE(rtt.timeout() >= 270ms); + REQUIRE(rtt.timeout() <= 275ms); + } + } } + } - AND_WHEN("we return to normal RTT") { - for (int i = 0; i < 10; ++i) { - rtt.measure(100ms); + SCENARIO("RTTEstimator spike response", "[network]") { + GIVEN("a new RTTEstimator with stable RTT") { + RTTEstimator rtt(0.125f, 0.25f, 0.1f, 0.0f); + + WHEN("we feed it constant 100ms RTTs then a 500ms spike") { + // First establish a stable RTT + for (int i = 0; i < 10; ++i) { + rtt.measure(100ms); + } + auto before_spike = rtt.timeout(); + + // Then inject a spike + rtt.measure(500ms); + auto after_spike = rtt.timeout(); + + THEN("the timeout should increase but not dramatically") { + REQUIRE(after_spike > before_spike); // Should increase + REQUIRE(after_spike < 1s); // But not too much + } + + AND_WHEN("we return to normal RTT") { + for (int i = 0; i < 10; ++i) { + rtt.measure(100ms); + } + auto after_recovery = rtt.timeout(); + + THEN("it should recover to a reasonable value") { + REQUIRE(after_recovery >= 100ms); + REQUIRE(after_recovery <= 300ms); + } + } } - auto after_recovery = rtt.timeout(); + } + } - THEN("it should recover to a reasonable value") { - REQUIRE(after_recovery >= 100ms); - REQUIRE(after_recovery <= 300ms); + SCENARIO("RTTEstimator noise resilience", "[network]") { + GIVEN("a new RTTEstimator") { + RTTEstimator rtt(0.125f, 0.25f, 0.1f, 0.0f); + + WHEN("we feed it noisy RTTs around 100ms") { + // Generate noisy RTTs: 100ms ± 20ms + for (int i = 0; i < 50; ++i) { + auto noise = (i % 2 == 0 ? 20ms : -20ms); + rtt.measure(100ms + noise); + } + + THEN("the timeout should remain stable") { + REQUIRE(rtt.timeout() >= 100ms); + REQUIRE(rtt.timeout() <= 300ms); + } + + AND_WHEN("we continue with constant RTT") { + for (int i = 0; i < 10; ++i) { + rtt.measure(100ms); + } + auto final_timeout = rtt.timeout(); + + THEN("it should converge to the true RTT") { + REQUIRE(final_timeout >= 100ms); + REQUIRE(final_timeout <= 200ms); + } + } } } } - } -} -SCENARIO("RTTEstimator noise resilience", "[network]") { - GIVEN("a new RTTEstimator") { - RTTEstimator rtt(0.125f, 0.25f, 0.1f, 0.0f); - - WHEN("we feed it noisy RTTs around 100ms") { - // Generate noisy RTTs: 100ms ± 20ms - for (int i = 0; i < 50; ++i) { - auto noise = (i % 2 == 0 ? 20ms : -20ms); - rtt.measure(100ms + noise); + SCENARIO("RTTEstimator constructor validation", "[network]") { + GIVEN("invalid alpha values") { + THEN("it should throw std::invalid_argument") { + REQUIRE_THROWS_AS(RTTEstimator(-0.1f), std::invalid_argument); + REQUIRE_THROWS_AS(RTTEstimator(1.1f), std::invalid_argument); + } } - THEN("the timeout should remain stable") { - REQUIRE(rtt.timeout() >= 100ms); - REQUIRE(rtt.timeout() <= 300ms); + GIVEN("invalid beta values") { + THEN("it should throw std::invalid_argument") { + REQUIRE_THROWS_AS(RTTEstimator(0.125f, -0.1f), std::invalid_argument); + REQUIRE_THROWS_AS(RTTEstimator(0.125f, 1.1f), std::invalid_argument); + } } - AND_WHEN("we continue with constant RTT") { - for (int i = 0; i < 10; ++i) { - rtt.measure(100ms); + GIVEN("invalid min_rto/max_rto combinations") { + THEN("it should throw std::invalid_argument") { + REQUIRE_THROWS_AS(RTTEstimator(0.125f, 0.25f, 1.0f, 0.0f, 1.0f, 0.5f), std::invalid_argument); + REQUIRE_THROWS_AS(RTTEstimator(0.125f, 0.25f, 1.0f, 0.0f, 1.0f, 1.0f), std::invalid_argument); } - auto final_timeout = rtt.timeout(); + } - THEN("it should converge to the true RTT") { - REQUIRE(final_timeout >= 100ms); - REQUIRE(final_timeout <= 200ms); + GIVEN("valid arguments") { + THEN("it should not throw") { + REQUIRE_NOTHROW(RTTEstimator()); + REQUIRE_NOTHROW(RTTEstimator(0.125f, 0.25f, 1.0f, 0.0f, 0.1f, 60.0f)); } } } - } -} + + } // namespace network +} // namespace extension +} // namespace NUClear From 6e07f82503c0dfb65db120362afa70745b6a3dfc Mon Sep 17 00:00:00 2001 From: Trent Houliston Date: Sun, 18 May 2025 21:31:17 +1000 Subject: [PATCH 08/14] Fix test namespace --- tests/tests/network/PacketDeduplicator.cpp | 236 +++++++++++---------- 1 file changed, 121 insertions(+), 115 deletions(-) diff --git a/tests/tests/network/PacketDeduplicator.cpp b/tests/tests/network/PacketDeduplicator.cpp index 297600bd..5fa6cbb9 100644 --- a/tests/tests/network/PacketDeduplicator.cpp +++ b/tests/tests/network/PacketDeduplicator.cpp @@ -23,161 +23,167 @@ #include -using namespace NUClear::extension::network; +namespace NUClear { +namespace extension { + namespace network { -SCENARIO("PacketDeduplicator basic functionality", "[network]") { - GIVEN("a new PacketDeduplicator") { - PacketDeduplicator dedup; + SCENARIO("PacketDeduplicator basic functionality", "[network]") { + GIVEN("a new PacketDeduplicator") { + PacketDeduplicator dedup; - WHEN("we check a new packet") { - THEN("it should not be a duplicate") { - REQUIRE_FALSE(dedup.is_duplicate(1)); - } - - AND_WHEN("we add the packet") { - dedup.add_packet(1); - - THEN("it should be marked as a duplicate") { - REQUIRE(dedup.is_duplicate(1)); - } - - AND_WHEN("we check a different packet") { + WHEN("we check a new packet") { THEN("it should not be a duplicate") { - REQUIRE_FALSE(dedup.is_duplicate(2)); + REQUIRE_FALSE(dedup.is_duplicate(1)); } - AND_WHEN("we add the second packet") { - dedup.add_packet(2); + AND_WHEN("we add the packet") { + dedup.add_packet(1); - THEN("both packets should be marked as duplicates") { + THEN("it should be marked as a duplicate") { REQUIRE(dedup.is_duplicate(1)); - REQUIRE(dedup.is_duplicate(2)); + } + + AND_WHEN("we check a different packet") { + THEN("it should not be a duplicate") { + REQUIRE_FALSE(dedup.is_duplicate(2)); + } + + AND_WHEN("we add the second packet") { + dedup.add_packet(2); + + THEN("both packets should be marked as duplicates") { + REQUIRE(dedup.is_duplicate(1)); + REQUIRE(dedup.is_duplicate(2)); + } + } } } } } } - } -} -SCENARIO("PacketDeduplicator window sliding", "[network]") { - GIVEN("a new PacketDeduplicator") { - PacketDeduplicator dedup; + SCENARIO("PacketDeduplicator window sliding", "[network]") { + GIVEN("a new PacketDeduplicator") { + PacketDeduplicator dedup; - WHEN("we add packets up to the window size") { - for (uint16_t i = 0; i < 256; ++i) { - dedup.add_packet(i); - } + WHEN("we add packets up to the window size") { + for (uint16_t i = 0; i < 256; ++i) { + dedup.add_packet(i); + } - THEN("all packets should be marked as duplicates") { - for (uint16_t i = 0; i < 256; ++i) { - REQUIRE(dedup.is_duplicate(i)); - } - } + THEN("all packets should be marked as duplicates") { + for (uint16_t i = 0; i < 256; ++i) { + REQUIRE(dedup.is_duplicate(i)); + } + } - AND_WHEN("we add a packet beyond the window") { - dedup.add_packet(256); + AND_WHEN("we add a packet beyond the window") { + dedup.add_packet(256); - THEN("the oldest packet should be forgotten") { - REQUIRE_FALSE(dedup.is_duplicate(0)); - } + THEN("the oldest packet should be forgotten") { + REQUIRE_FALSE(dedup.is_duplicate(0)); + } - AND_THEN("the newest packet should be remembered") { - REQUIRE(dedup.is_duplicate(256)); + AND_THEN("the newest packet should be remembered") { + REQUIRE(dedup.is_duplicate(256)); + } + } } } } - } -} - -SCENARIO("PacketDeduplicator out of order packets", "[network]") { - GIVEN("a new PacketDeduplicator") { - PacketDeduplicator dedup; - - WHEN("we add packets out of order") { - dedup.add_packet(5); - dedup.add_packet(3); - dedup.add_packet(7); - dedup.add_packet(1); - - THEN("all added packets should be marked as duplicates") { - REQUIRE(dedup.is_duplicate(1)); - REQUIRE(dedup.is_duplicate(3)); - REQUIRE(dedup.is_duplicate(5)); - REQUIRE(dedup.is_duplicate(7)); - } - AND_THEN("unseen packets should not be marked as duplicates") { - REQUIRE_FALSE(dedup.is_duplicate(2)); - REQUIRE_FALSE(dedup.is_duplicate(4)); - REQUIRE_FALSE(dedup.is_duplicate(6)); - REQUIRE_FALSE(dedup.is_duplicate(8)); + SCENARIO("PacketDeduplicator out of order packets", "[network]") { + GIVEN("a new PacketDeduplicator") { + PacketDeduplicator dedup; + + WHEN("we add packets out of order") { + dedup.add_packet(5); + dedup.add_packet(3); + dedup.add_packet(7); + dedup.add_packet(1); + + THEN("all added packets should be marked as duplicates") { + REQUIRE(dedup.is_duplicate(1)); + REQUIRE(dedup.is_duplicate(3)); + REQUIRE(dedup.is_duplicate(5)); + REQUIRE(dedup.is_duplicate(7)); + } + + AND_THEN("unseen packets should not be marked as duplicates") { + REQUIRE_FALSE(dedup.is_duplicate(2)); + REQUIRE_FALSE(dedup.is_duplicate(4)); + REQUIRE_FALSE(dedup.is_duplicate(6)); + REQUIRE_FALSE(dedup.is_duplicate(8)); + } + } } } - } -} -SCENARIO("PacketDeduplicator packet wrap around", "[network]") { - GIVEN("a new PacketDeduplicator") { - PacketDeduplicator dedup; + SCENARIO("PacketDeduplicator packet wrap around", "[network]") { + GIVEN("a new PacketDeduplicator") { + PacketDeduplicator dedup; - WHEN("we add packets near the uint16_t wrap around point") { - uint16_t start = 65530; + WHEN("we add packets near the uint16_t wrap around point") { + uint16_t start = 65530; - for (uint16_t i = 0; i < 10; ++i) { - uint16_t packet_id = (start + i) % 65536; - dedup.add_packet(packet_id); - } + for (uint16_t i = 0; i < 10; ++i) { + uint16_t packet_id = (start + i) % 65536; + dedup.add_packet(packet_id); + } - THEN("all added packets should be marked as duplicates") { - for (uint16_t i = 0; i < 10; ++i) { - uint16_t packet_id = (start + i) % 65536; - REQUIRE(dedup.is_duplicate(packet_id)); - } - } + THEN("all added packets should be marked as duplicates") { + for (uint16_t i = 0; i < 10; ++i) { + uint16_t packet_id = (start + i) % 65536; + REQUIRE(dedup.is_duplicate(packet_id)); + } + } - AND_THEN("packets before the window should not be marked as duplicates") { - REQUIRE_FALSE(dedup.is_duplicate(start - 1)); + AND_THEN("packets before the window should not be marked as duplicates") { + REQUIRE_FALSE(dedup.is_duplicate(start - 1)); + } + } } } - } -} -SCENARIO("PacketDeduplicator old packets", "[network]") { - GIVEN("a new PacketDeduplicator") { - PacketDeduplicator dedup; + SCENARIO("PacketDeduplicator old packets", "[network]") { + GIVEN("a new PacketDeduplicator") { + PacketDeduplicator dedup; - WHEN("we add a packet and then slide the window past it") { - dedup.add_packet(1); + WHEN("we add a packet and then slide the window past it") { + dedup.add_packet(1); - // Slide window past packet 1 - for (uint16_t i = 2; i < 258; ++i) { - dedup.add_packet(i); - } + // Slide window past packet 1 + for (uint16_t i = 2; i < 258; ++i) { + dedup.add_packet(i); + } - THEN("the old packet should not be marked as a duplicate") { - REQUIRE_FALSE(dedup.is_duplicate(1)); - } + THEN("the old packet should not be marked as a duplicate") { + REQUIRE_FALSE(dedup.is_duplicate(1)); + } - AND_THEN("recent packets should still be marked as duplicates") { - REQUIRE(dedup.is_duplicate(256)); - REQUIRE(dedup.is_duplicate(257)); + AND_THEN("recent packets should still be marked as duplicates") { + REQUIRE(dedup.is_duplicate(256)); + REQUIRE(dedup.is_duplicate(257)); + } + } } } - } -} -SCENARIO("PacketDeduplicator handles high initial packet IDs correctly", "[network]") { - GIVEN("A PacketDeduplicator") { - PacketDeduplicator dedup; + SCENARIO("PacketDeduplicator handles high initial packet IDs correctly", "[network]") { + GIVEN("A PacketDeduplicator") { + PacketDeduplicator dedup; - WHEN("The first packet ID is greater than uint16_t max/2") { - uint16_t high_id = 40000; // > 32768 - dedup.add_packet(high_id); + WHEN("The first packet ID is greater than uint16_t max/2") { + uint16_t high_id = 40000; // > 32768 + dedup.add_packet(high_id); - THEN("The first packet should be marked as a duplicate") { - REQUIRE(dedup.is_duplicate(high_id)); + THEN("The first packet should be marked as a duplicate") { + REQUIRE(dedup.is_duplicate(high_id)); + } + } } } - } -} + + } // namespace network +} // namespace extension +} // namespace NUClear From 9f5c6aa6e7ba1f4ee894c0842bc4051d7440963f Mon Sep 17 00:00:00 2001 From: Trent Houliston Date: Sun, 18 May 2025 21:33:06 +1000 Subject: [PATCH 09/14] Coverage --- tests/tests/network/PacketDeduplicator.cpp | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/tests/tests/network/PacketDeduplicator.cpp b/tests/tests/network/PacketDeduplicator.cpp index 5fa6cbb9..23820fb7 100644 --- a/tests/tests/network/PacketDeduplicator.cpp +++ b/tests/tests/network/PacketDeduplicator.cpp @@ -184,6 +184,21 @@ namespace extension { } } + SCENARIO("PacketDeduplicator handles adding old packets", "[network]") { + GIVEN("a PacketDeduplicator with a high packet ID") { + PacketDeduplicator dedup; + dedup.add_packet(512); // Start with a high packet ID + + WHEN("we try to add a much older packet") { + dedup.add_packet(1); // Try to add a packet that's more than 256 behind + + THEN("the old packet should not be marked as seen") { + REQUIRE_FALSE(dedup.is_duplicate(1)); + } + } + } + } + } // namespace network } // namespace extension } // namespace NUClear From f7c1970fec9ce0b07f324b6b57d372bf1cd0a7bd Mon Sep 17 00:00:00 2001 From: Trent Houliston Date: Mon, 19 May 2025 09:41:10 +1000 Subject: [PATCH 10/14] linting --- tests/tests/network/PacketDeduplicator.cpp | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/tests/network/PacketDeduplicator.cpp b/tests/tests/network/PacketDeduplicator.cpp index 23820fb7..3653248f 100644 --- a/tests/tests/network/PacketDeduplicator.cpp +++ b/tests/tests/network/PacketDeduplicator.cpp @@ -22,6 +22,7 @@ #include "extension/network/PacketDeduplicator.hpp" #include +#include namespace NUClear { namespace extension { From af33b5131d92c323921ae57fa5cc045662c99585 Mon Sep 17 00:00:00 2001 From: Trent Houliston Date: Mon, 19 May 2025 09:47:58 +1000 Subject: [PATCH 11/14] merge if statements --- src/extension/network/NUClearNetwork.cpp | 57 +++++++++++------------- 1 file changed, 27 insertions(+), 30 deletions(-) diff --git a/src/extension/network/NUClearNetwork.cpp b/src/extension/network/NUClearNetwork.cpp index 284d89a3..ae0118b7 100644 --- a/src/extension/network/NUClearNetwork.cpp +++ b/src/extension/network/NUClearNetwork.cpp @@ -672,41 +672,38 @@ namespace extension { remote->last_update = std::chrono::steady_clock::now(); // Check if this packet is a retransmission of data - if (header.type == DATA_RETRANSMISSION) { + if (header.type == DATA_RETRANSMISSION + && remote->deduplicator.is_duplicate(packet.packet_id)) { - // We recently processed this packet, this is just a failed ack - if (remote->deduplicator.is_duplicate(packet.packet_id)) { - - // Send the ack again if it was reliable - if (packet.reliable) { - // Allocate room for the whole ack packet - std::vector r(sizeof(ACKPacket) + (packet.packet_count / 8), 0); - ACKPacket& response = *reinterpret_cast(r.data()); - response = ACKPacket(); - response.packet_id = packet.packet_id; - response.packet_no = packet.packet_no; - response.packet_count = packet.packet_count; - - // Set the bits for all packets (we got the whole thing) - for (int i = 0; i < packet.packet_count; ++i) { - (&response.packets)[i / 8] |= uint8_t(1 << (i % 8)); - } - - // Make who we are sending it to into a useable address - const sock_t& to = remote->target; + // Send the ack again if it was reliable + if (packet.reliable) { + // Allocate room for the whole ack packet + std::vector r(sizeof(ACKPacket) + (packet.packet_count / 8), 0); + ACKPacket& response = *reinterpret_cast(r.data()); + response = ACKPacket(); + response.packet_id = packet.packet_id; + response.packet_no = packet.packet_no; + response.packet_count = packet.packet_count; - // Send the packet - ::sendto(data_fd, - reinterpret_cast(r.data()), - static_cast(r.size()), - 0, - &to.sock, - to.size()); + // Set the bits for all packets (we got the whole thing) + for (int i = 0; i < packet.packet_count; ++i) { + (&response.packets)[i / 8] |= uint8_t(1 << (i % 8)); } - // We don't need to process this packet we already did - return; + // Make who we are sending it to into a useable address + const sock_t& to = remote->target; + + // Send the packet + ::sendto(data_fd, + reinterpret_cast(r.data()), + static_cast(r.size()), + 0, + &to.sock, + to.size()); } + + // We don't need to process this packet we already did + return; } // If this is a solo packet (in a single chunk) From 097abd0e67abde2b2f9625f4e5d8dcaab5ea98be Mon Sep 17 00:00:00 2001 From: Trent Houliston Date: Mon, 19 May 2025 10:53:25 +1000 Subject: [PATCH 12/14] linting --- tests/tests/network/PacketDeduplicator.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/tests/network/PacketDeduplicator.cpp b/tests/tests/network/PacketDeduplicator.cpp index 3653248f..3a0815f7 100644 --- a/tests/tests/network/PacketDeduplicator.cpp +++ b/tests/tests/network/PacketDeduplicator.cpp @@ -125,16 +125,16 @@ namespace extension { PacketDeduplicator dedup; WHEN("we add packets near the uint16_t wrap around point") { - uint16_t start = 65530; + const uint16_t start = 65530; for (uint16_t i = 0; i < 10; ++i) { - uint16_t packet_id = (start + i) % 65536; + const uint16_t packet_id = (start + i) % 65536; dedup.add_packet(packet_id); } THEN("all added packets should be marked as duplicates") { for (uint16_t i = 0; i < 10; ++i) { - uint16_t packet_id = (start + i) % 65536; + const uint16_t packet_id = (start + i) % 65536; REQUIRE(dedup.is_duplicate(packet_id)); } } @@ -175,7 +175,7 @@ namespace extension { PacketDeduplicator dedup; WHEN("The first packet ID is greater than uint16_t max/2") { - uint16_t high_id = 40000; // > 32768 + const uint16_t high_id = 40000; // > 32768 dedup.add_packet(high_id); THEN("The first packet should be marked as a duplicate") { From 694f4c351b8abc7f406f9045f79005f2d9346d29 Mon Sep 17 00:00:00 2001 From: Trent Houliston Date: Mon, 19 May 2025 13:06:07 +1000 Subject: [PATCH 13/14] . --- tests/tests/network/RTTEstimator.cpp | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/tests/network/RTTEstimator.cpp b/tests/tests/network/RTTEstimator.cpp index 734fd92c..b139e7b3 100644 --- a/tests/tests/network/RTTEstimator.cpp +++ b/tests/tests/network/RTTEstimator.cpp @@ -23,6 +23,7 @@ #include #include +#include namespace NUClear { namespace extension { @@ -32,7 +33,7 @@ namespace extension { SCENARIO("RTTEstimator initial state", "[network]") { GIVEN("a new RTTEstimator") { - RTTEstimator rtt(0.125f, 0.25f, 1.0f, 0.0f); + const RTTEstimator rtt(0.125f, 0.25f, 1.0f, 0.0f); THEN("the initial timeout should be 1 second") { REQUIRE(rtt.timeout() == 1s); From e05cee3d8d2a48acff2984b68e749eb4636d27c5 Mon Sep 17 00:00:00 2001 From: Trent Houliston Date: Mon, 19 May 2025 13:14:00 +1000 Subject: [PATCH 14/14] . --- tests/tests/network/RTTEstimator.cpp | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/tests/network/RTTEstimator.cpp b/tests/tests/network/RTTEstimator.cpp index b139e7b3..259b0f94 100644 --- a/tests/tests/network/RTTEstimator.cpp +++ b/tests/tests/network/RTTEstimator.cpp @@ -21,7 +21,6 @@ */ #include "extension/network/RTTEstimator.hpp" -#include #include #include