From 41838157526e03895564d729ebdda25e77c0d8fe Mon Sep 17 00:00:00 2001 From: Leonardo Araujo Date: Mon, 25 Mar 2024 21:15:32 -0300 Subject: [PATCH] fix: several fixes --- include/win_socket.hpp | 15 +++++-- src/unix_socket.cpp | 6 +-- src/win_socket.cpp | 91 +++++++++++++++++++++++++++++++++++++----- 3 files changed, 96 insertions(+), 16 deletions(-) diff --git a/include/win_socket.hpp b/include/win_socket.hpp index e01be2a..c036737 100644 --- a/include/win_socket.hpp +++ b/include/win_socket.hpp @@ -16,6 +16,9 @@ #pragma comment(lib, "Mswsock.lib") #pragma comment(lib, "AdvApi32.lib") #include "socket.hpp" +#include "utils.hpp" +#include "base_exceptions.hpp" +#include "console_logger.hpp" namespace tpt { @@ -28,16 +31,20 @@ namespace tpt unsigned int port; unsigned int max_connections; std::string ip_address; + std::vector ip_blacklist; + ConsoleLogger logger; public: WinSocket(); - WinSocket(unsigned int port); - WinSocket(std::string ip_address, unsigned int port); - WinSocket(std::string ip_address, unsigned int port, unsigned int max_connections); + WinSocket(ConsoleLogger logger); + WinSocket(ConsoleLogger logger, unsigned int port); + WinSocket(ConsoleLogger logger, std::string ip_address, unsigned int port); + WinSocket(ConsoleLogger logger, std::string ip_address, unsigned int port, unsigned int max_connections); ~WinSocket(); + std::string getClientIp(); virtual void bindSocket() override; virtual void listenToConnections() override; - virtual void acceptConnection(SOCKET&client_socket, void *client_address) override; + virtual void acceptConnection(SOCKET &client_socket, void *client_address) override; virtual ssize_t receiveData(SOCKET client_socket, char *buffer, unsigned int buffer_size) override; virtual void sendData(SOCKET client_socket, const void *buffer, unsigned int buffer_size, int flags) override; virtual void closeSocket() override; diff --git a/src/unix_socket.cpp b/src/unix_socket.cpp index 370346c..da465d2 100644 --- a/src/unix_socket.cpp +++ b/src/unix_socket.cpp @@ -18,7 +18,7 @@ UnixSocket::UnixSocket() std::cout << "Error code: " + errno << std::endl; exit(EXIT_FAILURE); } - std::cout << "Socket created!" << std::endl; + LOG_INFO(logger, "Socket created!"); this->server_address.sin_family = AF_INET; this->server_address.sin_port = htons(this->port); @@ -42,7 +42,7 @@ UnixSocket::UnixSocket(ConsoleLogger logger) std::cout << "Error code: " + errno << std::endl; exit(EXIT_FAILURE); } - std::cout << "Socket created!" << std::endl; + LOG_INFO(logger, "Socket created!"); this->server_address.sin_family = AF_INET; this->server_address.sin_port = htons(this->port); @@ -66,7 +66,7 @@ UnixSocket::UnixSocket(ConsoleLogger logger, unsigned int port) std::cout << "Error code: " + errno << std::endl; exit(EXIT_FAILURE); } - std::cout << "Socket created!" << std::endl; + LOG_INFO(logger, "Socket created!"); this->server_address.sin_family = AF_INET; this->server_address.sin_port = htons(this->port); diff --git a/src/win_socket.cpp b/src/win_socket.cpp index 6d325fb..6f7c48d 100644 --- a/src/win_socket.cpp +++ b/src/win_socket.cpp @@ -8,6 +8,7 @@ WinSocket::WinSocket() this->ip_address = "127.0.0.1"; this->port = 8000; this->max_connections = 10; + this->logger = ConsoleLogger(); if (WSAStartup(MAKEWORD(2, 2), &wsa) != 0) { @@ -28,13 +29,46 @@ WinSocket::WinSocket() this->server_address.sin_family = AF_INET; this->server_address.sin_port = htons(this->port); this->server_address.sin_addr.s_addr = inet_addr(this->ip_address.c_str()); + + Utils::fillIPBlacklist(this->ip_blacklist); +} + +WinSocket::WinSocket(ConsoleLogger logger) +{ + this->ip_address = "127.0.0.1"; + this->port = 8000; + this->max_connections = 10; + this->logger = logger; + + if (WSAStartup(MAKEWORD(2, 2), &wsa) != 0) + { + std::printf("Failed. Error Code : %d", WSAGetLastError()); + WSACleanup(); + exit(EXIT_FAILURE); + } + + std::cout << "Creating socket ..." << std::endl; + this->server_socket = socket(AF_INET, SOCK_STREAM, 0); + if (this->server_socket == INVALID_SOCKET) + { + std::printf("Could not create socket: %d\n", WSAGetLastError()); + WSACleanup(); + exit(EXIT_FAILURE); + } + + this->server_address.sin_family = AF_INET; + this->server_address.sin_port = htons(this->port); + this->server_address.sin_addr.s_addr = inet_addr(this->ip_address.c_str()); + + Utils::fillIPBlacklist(this->ip_blacklist); } -WinSocket::WinSocket(unsigned int port) +WinSocket::WinSocket(ConsoleLogger logger, unsigned int port) { this->ip_address = "127.0.0.1"; this->port = port; this->max_connections = 10; + this->logger = logger; if (WSAStartup(MAKEWORD(2, 2), &wsa) != 0) { @@ -55,13 +89,16 @@ WinSocket::WinSocket(unsigned int port) this->server_address.sin_family = AF_INET; this->server_address.sin_port = htons(this->port); this->server_address.sin_addr.s_addr = inet_addr(this->ip_address.c_str()); + + Utils::fillIPBlacklist(this->ip_blacklist); } -WinSocket::WinSocket(std::string ip_address, unsigned int port) +WinSocket::WinSocket(ConsoleLogger logger, std::string ip_address, unsigned int port) { this->ip_address = ip_address; this->port = port; this->max_connections = 10; + this->logger = logger; if (WSAStartup(MAKEWORD(2, 2), &wsa) != 0) { @@ -82,13 +119,16 @@ WinSocket::WinSocket(std::string ip_address, unsigned int port) this->server_address.sin_family = AF_INET; this->server_address.sin_port = htons(this->port); this->server_address.sin_addr.s_addr = inet_addr(this->ip_address.c_str()); + + Utils::fillIPBlacklist(this->ip_blacklist); } -WinSocket::WinSocket(std::string ip_address, unsigned int port, unsigned int max_connections) +WinSocket::WinSocket(ConsoleLogger logger, std::string ip_address, unsigned int port, unsigned int max_connections) { this->ip_address = ip_address; this->port = port; this->max_connections = 10; + this->logger = logger; if (WSAStartup(MAKEWORD(2, 2), &wsa) != 0) { @@ -109,6 +149,8 @@ WinSocket::WinSocket(std::string ip_address, unsigned int port, unsigned int max this->server_address.sin_family = AF_INET; this->server_address.sin_port = htons(this->port); this->server_address.sin_addr.s_addr = inet_addr(this->ip_address.c_str()); + + Utils::fillIPBlacklist(this->ip_blacklist); } void WinSocket::bindSocket() @@ -138,12 +180,43 @@ void WinSocket::listenToConnections() void WinSocket::acceptConnection(SOCKET &client_socket, void *client_address) { int client_addr_size = sizeof(sockaddr_in); - client_socket = accept(this->server_socket, static_cast(client_address), &client_addr_size); - if (client_socket == INVALID_SOCKET) { + client_socket = accept(this->server_socket, static_cast(client_address), &client_addr_size); + if (client_socket == INVALID_SOCKET) + { std::printf("Error accepting connections: %d\n", WSAGetLastError()); WSACleanup(); exit(EXIT_FAILURE); } + + // Assuming client_address is meant to store the result + if (client_address != nullptr) + { + std::memcpy(client_address, &client_addr_storage, client_addr_size); + } + + char ip_str[INET6_ADDRSTRLEN] = {0}; // Large enough for both IPv4 and IPv6 + if (client_addr_storage.ss_family == AF_INET) + { + // IPv4 + struct sockaddr_in *addr_in = (struct sockaddr_in *)&client_addr_storage; + inet_ntop(AF_INET, &addr_in->sin_addr, ip_str, INET_ADDRSTRLEN); + } + else if (client_addr_storage.ss_family == AF_INET6) + { + // IPv6 + struct sockaddr_in6 *addr_in6 = (struct sockaddr_in6 *)&client_addr_storage; + inet_ntop(AF_INET6, &addr_in6->sin6_addr, ip_str, INET6_ADDRSTRLEN); + } + + this->client_ip = std::string(ip_str); + + for (auto it : this->ip_blacklist) + { + if (this->client_ip == it) + { + throw IPBlackListedException(); + } + } } ssize_t WinSocket::receiveData(SOCKET client_socket, char *buffer, unsigned int buffer_size) @@ -151,16 +224,16 @@ ssize_t WinSocket::receiveData(SOCKET client_socket, char *buffer, unsigned int ssize_t data = recv(client_socket, buffer, buffer_size, 0); if (data < 0) { - perror("Receive error"); - std::cout << "Error code: " + errno << std::endl; - exit(1); + std::printf("Receive error\n"); + WSACleanup(); + exit(EXIT_FAILURE); } return data; } void WinSocket::sendData(SOCKET client_socket, const void *buffer, unsigned int buffer_size, int flags) { - send(client_socket, (char*)buffer, buffer_size, flags); + send(client_socket, (char *)buffer, buffer_size, flags); } void WinSocket::closeSocket()