diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 7ed0fb8a..722cf01b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -28,7 +28,15 @@ configure_file(nuclear.in ${PROJECT_BINARY_DIR}/nuclear) # Build the library find_package(Threads REQUIRED) -file(GLOB_RECURSE src "*.c" "*.cpp" "*.hpp" "*.ipp") +file( + GLOB_RECURSE + src + CONFIGURE_DEPENDS + "*.c" + "*.cpp" + "*.hpp" + "*.ipp" +) add_library(nuclear STATIC ${src}) add_library(NUClear::nuclear ALIAS nuclear) diff --git a/src/extension/network/NUClearNetwork.cpp b/src/extension/network/NUClearNetwork.cpp index 11839da9..ae0118b7 100644 --- a/src/extension/network/NUClearNetwork.cpp +++ b/src/extension/network/NUClearNetwork.cpp @@ -24,7 +24,6 @@ #include #include -#include #include #include #include @@ -501,7 +500,7 @@ namespace extension { if (ptr) { auto now = std::chrono::steady_clock::now(); - auto timeout = it->last_send + ptr->round_trip_time; + auto timeout = it->last_send + ptr->rtt.timeout(); // Check if we should have expected an ack by now for some packets if (timeout < now) { @@ -510,7 +509,7 @@ namespace extension { it->last_send = now; // The next time we should check for a timeout - auto next_timeout = now + ptr->round_trip_time; + auto next_timeout = now + ptr->rtt.timeout(); if (next_timeout < next_event) { next_event = next_timeout; next_event_callback(next_event); @@ -673,18 +672,11 @@ namespace extension { remote->last_update = std::chrono::steady_clock::now(); // Check if this packet is a retransmission of data - if (header.type == DATA_RETRANSMISSION) { + if (header.type == DATA_RETRANSMISSION + && remote->deduplicator.is_duplicate(packet.packet_id)) { - // See if we recently processed this packet - // NOLINTNEXTLINE(readability-qualified-auto) MSVC disagrees - auto it = std::find(remote->recent_packets.begin(), - remote->recent_packets.end(), - packet.packet_id); - - // We recently processed this packet, this is just a failed ack // Send the ack again if it was reliable - if (it != remote->recent_packets.end() && packet.reliable) { - + if (packet.reliable) { // Allocate room for the whole ack packet std::vector r(sizeof(ACKPacket) + (packet.packet_count / 8), 0); ACKPacket& response = *reinterpret_cast(r.data()); @@ -708,10 +700,10 @@ namespace extension { 0, &to.sock, to.size()); - - // We don't need to process this packet we already did - return; } + + // We don't need to process this packet we already did + return; } // If this is a solo packet (in a single chunk) @@ -739,13 +731,11 @@ namespace extension { 0, &to.sock, to.size()); - - // Set this packet to have been recently received - remote->recent_packets[remote->recent_packets_index - .fetch_add(1, std::memory_order_relaxed)] = - packet.packet_id; } + // Add the packet to our deduplicator + remote->deduplicator.add_packet(packet.packet_id); + packet_callback(*remote, packet.hash, packet.reliable, std::move(out)); } else { @@ -851,17 +841,12 @@ namespace extension { &part.data + p.second.size() - sizeof(DataPacket) + 1); } + // Add the packet to our deduplicator + remote->deduplicator.add_packet(packet.packet_id); + // Send our assembled data packet packet_callback(*remote, packet.hash, packet.reliable, std::move(out)); - // If the packet was reliable add that it was recently received - if (packet.reliable) { - // Set this packet to have been recently received - remote->recent_packets[remote->recent_packets_index - .fetch_add(1, std::memory_order_relaxed)] = - packet.packet_id; - } - // We have completed this packet, discard the data assemblers.erase(assemblers.find(packet.packet_id)); } @@ -869,7 +854,7 @@ namespace extension { // Check for and delete any timed out packets for (auto it = assemblers.begin(); it != assemblers.end();) { const auto now = std::chrono::steady_clock::now(); - const auto timeout = remote->round_trip_time * 10.0; + const auto timeout = remote->rtt.timeout() * 10.0; const auto& last_chunk_time = it->second.first; it = now > last_chunk_time + timeout ? assemblers.erase(it) : std::next(it); @@ -919,8 +904,7 @@ namespace extension { // Approximate how long the round trip is to this remote so we can work out how // long before retransmitting - // We use a baby kalman filter to help smooth out jitter - remote->measure_round_trip(round_trip); + remote->rtt.measure(round_trip); // Update our acks bool all_acked = true; @@ -987,7 +971,7 @@ namespace extension { s->last_send = std::chrono::steady_clock::now(); // The next time we should check for a timeout - auto next_timeout = s->last_send + remote->round_trip_time; + auto next_timeout = s->last_send + remote->rtt.timeout(); if (next_timeout < next_event) { next_event = next_timeout; next_event_callback(next_event); @@ -1108,7 +1092,7 @@ namespace extension { queue.targets.emplace_back(it->second, acks); // The next time we should check for a timeout - auto next_timeout = std::chrono::steady_clock::now() + it->second->round_trip_time; + auto next_timeout = std::chrono::steady_clock::now() + it->second->rtt.timeout(); if (next_timeout < next_event) { next_event = next_timeout; next_event_callback(next_event); diff --git a/src/extension/network/NUClearNetwork.hpp b/src/extension/network/NUClearNetwork.hpp index e2277af2..d5f17e6d 100644 --- a/src/extension/network/NUClearNetwork.hpp +++ b/src/extension/network/NUClearNetwork.hpp @@ -39,6 +39,8 @@ #include "../../util/network/sock_t.hpp" #include "../../util/platform.hpp" +#include "PacketDeduplicator.hpp" +#include "RTTEstimator.hpp" #include "wire_protocol.hpp" namespace NUClear { @@ -56,11 +58,7 @@ namespace extension { std::string name, const sock_t& target, const std::chrono::steady_clock::time_point& last_update = std::chrono::steady_clock::now()) - : name(std::move(name)), target(target), last_update(last_update) { - - // Set our recent packets to an invalid value - recent_packets.fill(-1); - } + : name(std::move(name)), target(target), last_update(last_update) {} /// The name of the remote target std::string name; @@ -68,10 +66,6 @@ namespace extension { sock_t target{}; /// When we last received data from the remote target std::chrono::steady_clock::time_point last_update; - /// A list of the last n packet groups to be received - std::array::max()> recent_packets{}; - /// An index for the recent_packets (circular buffer) - std::atomic recent_packets_index{0}; /// Mutex to protect the fragmented packet storage std::mutex assemblers_mutex; /// Storage for fragmented packets while we build them @@ -79,41 +73,11 @@ namespace extension { std::pair>>> assemblers; - /// Struct storing the kalman filter for round trip time - struct RoundTripKF { - float process_noise = 1e-6f; - float measurement_noise = 1e-1f; - float variance = 1.0f; - float mean = 1.0f; - }; - /// A little kalman filter for estimating round trip time - RoundTripKF round_trip_kf{}; - - std::chrono::steady_clock::duration round_trip_time{std::chrono::seconds(1)}; - - void measure_round_trip(std::chrono::steady_clock::duration time) { - - // Make our measurement into a float seconds type - const std::chrono::duration m = - std::chrono::duration_cast>(time); - - // Alias variables - const auto& Q = round_trip_kf.process_noise; - const auto& R = round_trip_kf.measurement_noise; - auto& P = round_trip_kf.variance; - auto& X = round_trip_kf.mean; - - // Calculate our kalman gain - const float K = (P + Q) / (P + Q + R); - - // Do filter - P = R * (P + Q) / (R + P + Q); - X = X + (m.count() - X) * K; + /// RTT estimator for this network target + RTTEstimator rtt; - // Put result into our variable - round_trip_time = std::chrono::duration_cast( - std::chrono::duration(X)); - } + /// Packet deduplicator for this network target + PacketDeduplicator deduplicator; }; NUClearNetwork() = default; diff --git a/src/extension/network/PacketDeduplicator.cpp b/src/extension/network/PacketDeduplicator.cpp new file mode 100644 index 00000000..c252a1f3 --- /dev/null +++ b/src/extension/network/PacketDeduplicator.cpp @@ -0,0 +1,76 @@ +/* + * MIT License + * + * Copyright (c) 2025 NUClear Contributors + * + * This file is part of the NUClear codebase. + * See https://github.com/Fastcode/NUClear for further info. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE + * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ +#include "PacketDeduplicator.hpp" + +#include + +namespace NUClear { +namespace extension { + namespace network { + + + bool PacketDeduplicator::is_duplicate(uint16_t packet_id) const { + // If we haven't seen any packets yet, nothing is a duplicate + if (!initialized) { + return false; + } + + // Calculate relative position in window using unsigned subtraction + const uint16_t relative_id = newest_seen - packet_id; + + // If the packet is too old or too new, it's not a duplicate + if (relative_id >= 256) { + return false; + } + + return window[relative_id]; + } + + void PacketDeduplicator::add_packet(uint16_t packet_id) { + // If this is our first packet, just set it as newest_seen + if (!initialized) { + newest_seen = packet_id; + window[0] = true; + initialized = true; + return; + } + + // Calculate relative position in window using unsigned subtraction + const uint16_t relative_id = newest_seen - packet_id; + + // If the distance is more than half the range, the packet is newer than our newest_seen + if (relative_id > 32768) { + // Calculate how far to shift to make this packet our newest + const uint16_t shift_amount = packet_id - newest_seen; + window <<= shift_amount; + newest_seen = packet_id; + window[0] = true; + } + // Packet is recent enough to be counted + else if (relative_id < 256) { + window[relative_id] = true; + } + } + + } // namespace network +} // namespace extension +} // namespace NUClear diff --git a/src/extension/network/PacketDeduplicator.hpp b/src/extension/network/PacketDeduplicator.hpp new file mode 100644 index 00000000..494f6d88 --- /dev/null +++ b/src/extension/network/PacketDeduplicator.hpp @@ -0,0 +1,67 @@ +/* + * MIT License + * + * Copyright (c) 2025 NUClear Contributors + * + * This file is part of the NUClear codebase. + * See https://github.com/Fastcode/NUClear for further info. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE + * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ +#ifndef NUCLEAR_EXTENSION_NETWORK_PACKET_DEDUPLICATOR_HPP +#define NUCLEAR_EXTENSION_NETWORK_PACKET_DEDUPLICATOR_HPP + +#include +#include + +namespace NUClear { +namespace extension { + namespace network { + + /** + * A class that implements a sliding window bitset for packet deduplication. + * Maintains a 256-bit window of recently seen packet IDs, sliding forward as new packets are added. + */ + class PacketDeduplicator { + public: + /** + * Check if a packet ID has been seen recently + * + * @param packet_id The packet ID to check + * + * @return true if the packet has been seen recently, false otherwise + */ + bool is_duplicate(uint16_t packet_id) const; + + /** + * Add a packet ID to the window + * + * @param packet_id The packet ID to add + */ + void add_packet(uint16_t packet_id); + + private: + /// Whether we've seen any packets yet + bool initialized{false}; + /// The newest packet ID we've seen + uint16_t newest_seen{0}; + /// The 256-bit window of seen packets (newest at 0, older at higher indices) + std::bitset<256> window; + }; + + } // namespace network +} // namespace extension +} // namespace NUClear + +#endif // NUCLEAR_EXTENSION_NETWORK_PACKET_DEDUPLICATOR_HPP diff --git a/src/extension/network/RTTEstimator.cpp b/src/extension/network/RTTEstimator.cpp new file mode 100644 index 00000000..13876130 --- /dev/null +++ b/src/extension/network/RTTEstimator.cpp @@ -0,0 +1,79 @@ +/* + * MIT License + * + * Copyright (c) 2025 NUClear Contributors + * + * This file is part of the NUClear codebase. + * See https://github.com/Fastcode/NUClear for further info. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE + * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ +#include "RTTEstimator.hpp" + +#include +#include +#include + +namespace NUClear { +namespace extension { + namespace network { + + RTTEstimator::RTTEstimator(float alpha, + float beta, + float initial_rtt, + float initial_rtt_var, + float min_rto, + float max_rto) + : alpha(alpha) + , beta(beta) + , min_rto(min_rto) + , max_rto(max_rto) + , smoothed_rtt(initial_rtt) + , rtt_var(initial_rtt_var) + , rto(std::min(std::max(initial_rtt + 4 * initial_rtt_var, min_rto), max_rto)) { + + if (alpha < 0.0f || alpha > 1.0f) { + throw std::invalid_argument("alpha must be in range [0,1]"); + } + if (beta < 0.0f || beta > 1.0f) { + throw std::invalid_argument("beta must be in range [0,1]"); + } + if (min_rto >= max_rto) { + throw std::invalid_argument("min_rto must be less than max_rto"); + } + } + + void RTTEstimator::measure(std::chrono::steady_clock::duration time) { + // Convert measurement to float seconds + const std::chrono::duration m = std::chrono::duration_cast>(time); + const float sample_rtt = m.count(); + + // Calculate RTT variation + const float err = sample_rtt - smoothed_rtt; + rtt_var = (1 - beta) * rtt_var + beta * std::abs(err); + + // Update smoothed RTT + smoothed_rtt = (1 - alpha) * smoothed_rtt + alpha * sample_rtt; + + // Calculate RTO (smoothed RTT + 4 * RTT variation) and bound to limits + rto = std::min(std::max(smoothed_rtt + 4 * rtt_var, min_rto), max_rto); + } + + std::chrono::steady_clock::duration RTTEstimator::timeout() const { + return std::chrono::duration_cast(std::chrono::duration(rto)); + } + + } // namespace network +} // namespace extension +} // namespace NUClear diff --git a/src/extension/network/RTTEstimator.hpp b/src/extension/network/RTTEstimator.hpp new file mode 100644 index 00000000..481659ea --- /dev/null +++ b/src/extension/network/RTTEstimator.hpp @@ -0,0 +1,107 @@ +/* + * MIT License + * + * Copyright (c) 2025 NUClear Contributors + * + * This file is part of the NUClear codebase. + * See https://github.com/Fastcode/NUClear for further info. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE + * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ +#ifndef NUCLEAR_EXTENSION_NETWORK_RTT_ESTIMATOR_HPP +#define NUCLEAR_EXTENSION_NETWORK_RTT_ESTIMATOR_HPP + +#include +#include +#include + +namespace NUClear { +namespace extension { + namespace network { + + /** + * Implements TCP-style Round Trip Time (RTT) estimation using Jacobson/Karels algorithm. + * + * This class provides RTT estimation functionality similar to TCP's RTT estimation mechanism. + * It uses an Exponentially Weighted Moving Average (EWMA) to smooth RTT measurements and + * calculate a retransmission timeout (RTO) value. The implementation follows the TCP + * Jacobson/Karels algorithm which provides robust RTT estimation that: + * - Smoothly tracks the mean RTT + * - Adapts to RTT variations + * - Handles network jitter + * - Provides conservative timeout values + */ + class RTTEstimator { + public: + /** + * Construct a new RTT Estimator + * + * @param alpha Weight for RTT smoothing (default: 0.125, TCP standard) + * @param beta Weight for RTT variation (default: 0.25, TCP standard) + * @param initial_rtt Initial RTT estimate in seconds (default: 1.0) + * @param initial_rtt_var Initial RTT variation in seconds (default: 0.0) + * @param min_rto Minimum RTO value in seconds (default: 0.1) + * @param max_rto Maximum RTO value in seconds (default: 60.0) + * + * The alpha and beta parameters control how quickly the estimator adapts to changes: + * - alpha: Lower values (e.g. 0.125) make the smoothed RTT more stable but slower to adapt + * - beta: Lower values (e.g. 0.25) make the RTT variation more stable but slower to adapt + * + * @throws std::invalid_argument if alpha or beta are not in range [0,1] + * @throws std::invalid_argument if min_rto >= max_rto + */ + RTTEstimator(float alpha = 0.125f, + float beta = 0.25f, + float initial_rtt = 1.0f, + float initial_rtt_var = 0.0f, + float min_rto = 0.1f, + float max_rto = 60.0f); + + /** + * Update the RTT estimate with a new measurement + * + * Updates the smoothed RTT, RTT variation, and RTO using the Jacobson/Karels algorithm: + * 1. RTT variation = (1 - beta) * old_variation + beta * |smoothed_rtt - new_rtt| + * 2. Smoothed RTT = (1 - alpha) * old_rtt + alpha * new_rtt + * 3. RTO = smoothed_rtt + 4 * rtt_var + * + * The RTO is bounded between min_rto and max_rto to prevent extreme values. + * + * @param time The measured round trip time + */ + void measure(std::chrono::steady_clock::duration time); + + /** + * Get the current retransmission timeout + * + * @return The RTO as a duration. This value represents the recommended timeout + * for network operations based on the current RTT estimates. + */ + std::chrono::steady_clock::duration timeout() const; + + private: + float alpha; ///< Weight for RTT smoothing (typically 0.125) + float beta; ///< Weight for RTT variation (typically 0.25) + float min_rto; ///< Minimum RTO value in seconds + float max_rto; ///< Maximum RTO value in seconds + float smoothed_rtt; ///< Smoothed RTT estimate in seconds + float rtt_var; ///< RTT variation in seconds + float rto; ///< Retransmission timeout in seconds + }; + + } // namespace network +} // namespace extension +} // namespace NUClear + +#endif // NUCLEAR_EXTENSION_NETWORK_RTT_ESTIMATOR_HPP diff --git a/tests/CMakeLists.txt b/tests/CMakeLists.txt index 9520419b..700a989c 100644 --- a/tests/CMakeLists.txt +++ b/tests/CMakeLists.txt @@ -40,7 +40,7 @@ set_target_properties(${catch2_targets} PROPERTIES CXX_CLANG_TIDY "") set_target_properties(${catch2_target} PROPERTIES CMAKE_CXX_FLAGS "") # Create a test_util library that is used by all tests -file(GLOB_RECURSE test_util_src "test_util/*.cpp") +file(GLOB_RECURSE test_util_src CONFIGURE_DEPENDS "test_util/*.cpp") add_library(test_util OBJECT ${test_util_src}) # This is linking WHOLE_ARCHIVE as otherwise the linker will remove the WSAHolder from the final binary # As a result the WSA initialisation code won't run and the network tests will fail @@ -50,7 +50,7 @@ target_include_directories(test_util PUBLIC ${PROJECT_BINARY_DIR}/include ${PROJ target_include_directories(test_util PUBLIC ${CMAKE_CURRENT_SOURCE_DIR}) # Create a test binary for each test file -file(GLOB_RECURSE test_sources "tests/*.cpp") +file(GLOB_RECURSE test_sources CONFIGURE_DEPENDS "tests/*.cpp") foreach(test_file ${test_sources}) get_filename_component(test_name ${test_file} NAME_WE) get_filename_component(test_dir ${test_file} DIRECTORY) diff --git a/tests/networktest.cpp b/tests/networktest.cpp index 581a712a..12f6238a 100644 --- a/tests/networktest.cpp +++ b/tests/networktest.cpp @@ -138,6 +138,9 @@ int main(int argc, const char* argv[]) { NUClear::Configuration config; config.default_pool_concurrency = 4; NUClear::PowerPlant plant(config, argc, argv); + plant.install(); + plant.install(); + plant.install(); plant.install(); plant.start(); diff --git a/tests/tests/network/PacketDeduplicator.cpp b/tests/tests/network/PacketDeduplicator.cpp new file mode 100644 index 00000000..3a0815f7 --- /dev/null +++ b/tests/tests/network/PacketDeduplicator.cpp @@ -0,0 +1,205 @@ +/* + * MIT License + * + * Copyright (c) 2025 NUClear Contributors + * + * This file is part of the NUClear codebase. + * See https://github.com/Fastcode/NUClear for further info. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE + * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ +#include "extension/network/PacketDeduplicator.hpp" + +#include +#include + +namespace NUClear { +namespace extension { + namespace network { + + SCENARIO("PacketDeduplicator basic functionality", "[network]") { + GIVEN("a new PacketDeduplicator") { + PacketDeduplicator dedup; + + WHEN("we check a new packet") { + THEN("it should not be a duplicate") { + REQUIRE_FALSE(dedup.is_duplicate(1)); + } + + AND_WHEN("we add the packet") { + dedup.add_packet(1); + + THEN("it should be marked as a duplicate") { + REQUIRE(dedup.is_duplicate(1)); + } + + AND_WHEN("we check a different packet") { + THEN("it should not be a duplicate") { + REQUIRE_FALSE(dedup.is_duplicate(2)); + } + + AND_WHEN("we add the second packet") { + dedup.add_packet(2); + + THEN("both packets should be marked as duplicates") { + REQUIRE(dedup.is_duplicate(1)); + REQUIRE(dedup.is_duplicate(2)); + } + } + } + } + } + } + } + + SCENARIO("PacketDeduplicator window sliding", "[network]") { + GIVEN("a new PacketDeduplicator") { + PacketDeduplicator dedup; + + WHEN("we add packets up to the window size") { + for (uint16_t i = 0; i < 256; ++i) { + dedup.add_packet(i); + } + + THEN("all packets should be marked as duplicates") { + for (uint16_t i = 0; i < 256; ++i) { + REQUIRE(dedup.is_duplicate(i)); + } + } + + AND_WHEN("we add a packet beyond the window") { + dedup.add_packet(256); + + THEN("the oldest packet should be forgotten") { + REQUIRE_FALSE(dedup.is_duplicate(0)); + } + + AND_THEN("the newest packet should be remembered") { + REQUIRE(dedup.is_duplicate(256)); + } + } + } + } + } + + SCENARIO("PacketDeduplicator out of order packets", "[network]") { + GIVEN("a new PacketDeduplicator") { + PacketDeduplicator dedup; + + WHEN("we add packets out of order") { + dedup.add_packet(5); + dedup.add_packet(3); + dedup.add_packet(7); + dedup.add_packet(1); + + THEN("all added packets should be marked as duplicates") { + REQUIRE(dedup.is_duplicate(1)); + REQUIRE(dedup.is_duplicate(3)); + REQUIRE(dedup.is_duplicate(5)); + REQUIRE(dedup.is_duplicate(7)); + } + + AND_THEN("unseen packets should not be marked as duplicates") { + REQUIRE_FALSE(dedup.is_duplicate(2)); + REQUIRE_FALSE(dedup.is_duplicate(4)); + REQUIRE_FALSE(dedup.is_duplicate(6)); + REQUIRE_FALSE(dedup.is_duplicate(8)); + } + } + } + } + + SCENARIO("PacketDeduplicator packet wrap around", "[network]") { + GIVEN("a new PacketDeduplicator") { + PacketDeduplicator dedup; + + WHEN("we add packets near the uint16_t wrap around point") { + const uint16_t start = 65530; + + for (uint16_t i = 0; i < 10; ++i) { + const uint16_t packet_id = (start + i) % 65536; + dedup.add_packet(packet_id); + } + + THEN("all added packets should be marked as duplicates") { + for (uint16_t i = 0; i < 10; ++i) { + const uint16_t packet_id = (start + i) % 65536; + REQUIRE(dedup.is_duplicate(packet_id)); + } + } + + AND_THEN("packets before the window should not be marked as duplicates") { + REQUIRE_FALSE(dedup.is_duplicate(start - 1)); + } + } + } + } + + SCENARIO("PacketDeduplicator old packets", "[network]") { + GIVEN("a new PacketDeduplicator") { + PacketDeduplicator dedup; + + WHEN("we add a packet and then slide the window past it") { + dedup.add_packet(1); + + // Slide window past packet 1 + for (uint16_t i = 2; i < 258; ++i) { + dedup.add_packet(i); + } + + THEN("the old packet should not be marked as a duplicate") { + REQUIRE_FALSE(dedup.is_duplicate(1)); + } + + AND_THEN("recent packets should still be marked as duplicates") { + REQUIRE(dedup.is_duplicate(256)); + REQUIRE(dedup.is_duplicate(257)); + } + } + } + } + + SCENARIO("PacketDeduplicator handles high initial packet IDs correctly", "[network]") { + GIVEN("A PacketDeduplicator") { + PacketDeduplicator dedup; + + WHEN("The first packet ID is greater than uint16_t max/2") { + const uint16_t high_id = 40000; // > 32768 + dedup.add_packet(high_id); + + THEN("The first packet should be marked as a duplicate") { + REQUIRE(dedup.is_duplicate(high_id)); + } + } + } + } + + SCENARIO("PacketDeduplicator handles adding old packets", "[network]") { + GIVEN("a PacketDeduplicator with a high packet ID") { + PacketDeduplicator dedup; + dedup.add_packet(512); // Start with a high packet ID + + WHEN("we try to add a much older packet") { + dedup.add_packet(1); // Try to add a packet that's more than 256 behind + + THEN("the old packet should not be marked as seen") { + REQUIRE_FALSE(dedup.is_duplicate(1)); + } + } + } + } + + } // namespace network +} // namespace extension +} // namespace NUClear diff --git a/tests/tests/network/RTTEstimator.cpp b/tests/tests/network/RTTEstimator.cpp new file mode 100644 index 00000000..259b0f94 --- /dev/null +++ b/tests/tests/network/RTTEstimator.cpp @@ -0,0 +1,289 @@ +/* + * MIT License + * + * Copyright (c) 2025 NUClear Contributors + * + * This file is part of the NUClear codebase. + * See https://github.com/Fastcode/NUClear for further info. + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated + * documentation files (the "Software"), to deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or substantial portions of the + * Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE + * WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR + * COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR + * OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ +#include "extension/network/RTTEstimator.hpp" + +#include +#include + +namespace NUClear { +namespace extension { + namespace network { + + using namespace std::chrono_literals; + + SCENARIO("RTTEstimator initial state", "[network]") { + GIVEN("a new RTTEstimator") { + const RTTEstimator rtt(0.125f, 0.25f, 1.0f, 0.0f); + + THEN("the initial timeout should be 1 second") { + REQUIRE(rtt.timeout() == 1s); + } + } + } + + SCENARIO("RTTEstimator with constant RTT", "[network]") { + GIVEN("a new RTTEstimator") { + RTTEstimator rtt(0.125f, 0.25f, 0.1f, 0.0f); + + WHEN("we feed it constant RTTs of 100ms") { + for (int i = 0; i < 20; ++i) { + rtt.measure(100ms); + } + + THEN("the timeout should be at least 100ms and not unreasonably high") { + REQUIRE(rtt.timeout() >= 100ms); + REQUIRE(rtt.timeout() <= 200ms); + } + } + } + } + + SCENARIO("RTTEstimator with increasing RTT", "[network]") { + GIVEN("a new RTTEstimator") { + RTTEstimator rtt(0.125f, 0.25f, 0.1f, 0.0f); + + WHEN("we measure 100ms then 200ms") { + rtt.measure(100ms); + rtt.measure(200ms); + + THEN("the timeout should be at least 200ms and not unreasonably high") { + REQUIRE(rtt.timeout() >= 200ms); + REQUIRE(rtt.timeout() <= 400ms); + } + } + } + } + + SCENARIO("RTTEstimator with decreasing RTT", "[network]") { + GIVEN("a new RTTEstimator") { + RTTEstimator rtt(0.125f, 0.25f, 0.2f, 0.0f); + + WHEN("we measure 200ms then 100ms") { + rtt.measure(200ms); + rtt.measure(100ms); + + THEN("the timeout should be at least 100ms and not unreasonably high") { + REQUIRE(rtt.timeout() >= 100ms); + REQUIRE(rtt.timeout() <= 400ms); + } + } + } + } + + SCENARIO("RTTEstimator with oscillating RTT", "[network]") { + GIVEN("a new RTTEstimator") { + RTTEstimator rtt(0.125f, 0.25f, 0.15f, 0.0f); + + WHEN("we feed it alternating RTTs of 100ms and 200ms") { + for (int i = 0; i < 20; ++i) { + rtt.measure(i % 2 == 0 ? 100ms : 200ms); + } + + THEN("the timeout should be at least 100ms and not unreasonably high") { + REQUIRE(rtt.timeout() >= 100ms); + REQUIRE(rtt.timeout() <= 400ms); + } + } + } + } + + SCENARIO("RTTEstimator with large RTT variation", "[network]") { + GIVEN("a new RTTEstimator") { + RTTEstimator rtt(0.125f, 0.25f, 0.1f, 0.0f); + + WHEN("we measure 100ms then 1 second") { + rtt.measure(100ms); + rtt.measure(1s); + + THEN("the timeout should be at least 1s and not unreasonably high") { + REQUIRE(rtt.timeout() >= 1s); + REQUIRE(rtt.timeout() <= 2s); + } + } + } + } + + SCENARIO("RTTEstimator with small RTT variation", "[network]") { + GIVEN("a new RTTEstimator") { + RTTEstimator rtt(0.125f, 0.25f, 0.1f, 0.0f); + + WHEN("we measure 100ms then 110ms") { + rtt.measure(100ms); + rtt.measure(110ms); + + THEN("the timeout should be at least 110ms and not unreasonably high") { + REQUIRE(rtt.timeout() >= 110ms); + REQUIRE(rtt.timeout() <= 200ms); + } + } + } + } + + SCENARIO("RTTEstimator with zero RTT", "[network]") { + GIVEN("a new RTTEstimator") { + RTTEstimator rtt(0.125f, 0.25f, 0.001f, 0.0f); + + WHEN("we measure 0ms") { + rtt.measure(0ms); + + THEN("the timeout should be at least 0ms and not unreasonably high") { + REQUIRE(rtt.timeout() >= 0ms); + REQUIRE(rtt.timeout() <= 1s); + } + } + } + } + + SCENARIO("RTTEstimator with very large RTT", "[network]") { + GIVEN("a new RTTEstimator") { + RTTEstimator rtt(0.125f, 0.25f, 30.0f, 0.0f); + + WHEN("we measure 30 seconds") { + rtt.measure(30s); + + THEN("the timeout should be at least 30s and not unreasonably high") { + REQUIRE(rtt.timeout() >= 30s); + REQUIRE(rtt.timeout() <= 35s); + } + } + } + } + + SCENARIO("RTTEstimator exact calculation verification", "[network]") { + GIVEN("a new RTTEstimator with known initial state") { + // Initialize with SRTT = 100ms, RTTVAR = 50ms + RTTEstimator rtt(0.125f, 0.25f, 0.1f, 0.05f); + + WHEN("we measure a 120ms RTT") { + rtt.measure(120ms); + + THEN("the values should match the TCP calculation") { + // Expected values: + // RTTVAR = (1 - 0.25) * 50 + 0.25 * |100 - 120| = 0.75 * 50 + 0.25 * 20 = 37.5 + 5 = 42.5ms + // SRTT = (1 - 0.125) * 100 + 0.125 * 120 = 87.5 + 15 = 102.5ms + // RTO = 102.5 + 4 * 42.5 = 272.5ms + REQUIRE(rtt.timeout() >= 270ms); + REQUIRE(rtt.timeout() <= 275ms); + } + } + } + } + + SCENARIO("RTTEstimator spike response", "[network]") { + GIVEN("a new RTTEstimator with stable RTT") { + RTTEstimator rtt(0.125f, 0.25f, 0.1f, 0.0f); + + WHEN("we feed it constant 100ms RTTs then a 500ms spike") { + // First establish a stable RTT + for (int i = 0; i < 10; ++i) { + rtt.measure(100ms); + } + auto before_spike = rtt.timeout(); + + // Then inject a spike + rtt.measure(500ms); + auto after_spike = rtt.timeout(); + + THEN("the timeout should increase but not dramatically") { + REQUIRE(after_spike > before_spike); // Should increase + REQUIRE(after_spike < 1s); // But not too much + } + + AND_WHEN("we return to normal RTT") { + for (int i = 0; i < 10; ++i) { + rtt.measure(100ms); + } + auto after_recovery = rtt.timeout(); + + THEN("it should recover to a reasonable value") { + REQUIRE(after_recovery >= 100ms); + REQUIRE(after_recovery <= 300ms); + } + } + } + } + } + + SCENARIO("RTTEstimator noise resilience", "[network]") { + GIVEN("a new RTTEstimator") { + RTTEstimator rtt(0.125f, 0.25f, 0.1f, 0.0f); + + WHEN("we feed it noisy RTTs around 100ms") { + // Generate noisy RTTs: 100ms ± 20ms + for (int i = 0; i < 50; ++i) { + auto noise = (i % 2 == 0 ? 20ms : -20ms); + rtt.measure(100ms + noise); + } + + THEN("the timeout should remain stable") { + REQUIRE(rtt.timeout() >= 100ms); + REQUIRE(rtt.timeout() <= 300ms); + } + + AND_WHEN("we continue with constant RTT") { + for (int i = 0; i < 10; ++i) { + rtt.measure(100ms); + } + auto final_timeout = rtt.timeout(); + + THEN("it should converge to the true RTT") { + REQUIRE(final_timeout >= 100ms); + REQUIRE(final_timeout <= 200ms); + } + } + } + } + } + + SCENARIO("RTTEstimator constructor validation", "[network]") { + GIVEN("invalid alpha values") { + THEN("it should throw std::invalid_argument") { + REQUIRE_THROWS_AS(RTTEstimator(-0.1f), std::invalid_argument); + REQUIRE_THROWS_AS(RTTEstimator(1.1f), std::invalid_argument); + } + } + + GIVEN("invalid beta values") { + THEN("it should throw std::invalid_argument") { + REQUIRE_THROWS_AS(RTTEstimator(0.125f, -0.1f), std::invalid_argument); + REQUIRE_THROWS_AS(RTTEstimator(0.125f, 1.1f), std::invalid_argument); + } + } + + GIVEN("invalid min_rto/max_rto combinations") { + THEN("it should throw std::invalid_argument") { + REQUIRE_THROWS_AS(RTTEstimator(0.125f, 0.25f, 1.0f, 0.0f, 1.0f, 0.5f), std::invalid_argument); + REQUIRE_THROWS_AS(RTTEstimator(0.125f, 0.25f, 1.0f, 0.0f, 1.0f, 1.0f), std::invalid_argument); + } + } + + GIVEN("valid arguments") { + THEN("it should not throw") { + REQUIRE_NOTHROW(RTTEstimator()); + REQUIRE_NOTHROW(RTTEstimator(0.125f, 0.25f, 1.0f, 0.0f, 0.1f, 60.0f)); + } + } + } + + } // namespace network +} // namespace extension +} // namespace NUClear