diff --git a/raft/raft.go b/raft/raft.go index e3fb1fa7..741174e7 100644 --- a/raft/raft.go +++ b/raft/raft.go @@ -92,58 +92,143 @@ type PeerPayload struct { // Node represents a node in the Raft cluster. type Node struct { - raft *raft.Raft - config *raft.Config - Fsm *FSM + raft *raft.Raft + config *raft.Config + Fsm *FSM + logStore raft.LogStore + stableStore raft.StableStore + snapshotStore raft.SnapshotStore + transport raft.Transport + Logger zerolog.Logger + Peers []config.RaftPeer + rpcServer *grpc.Server + rpcClient *rpcClient + grpcAddr string + grpcIsSecure bool + peerSyncCancel context.CancelFunc +} + +type nodeConfig struct { + config *raft.Config + nodeID string + raftAddr string + raftDir string +} + +type stores struct { logStore raft.LogStore stableStore raft.StableStore snapshotStore raft.SnapshotStore - transport raft.Transport - Logger zerolog.Logger - Peers []config.RaftPeer - rpcServer *grpc.Server - rpcClient *rpcClient - grpcAddr string - grpcIsSecure bool } // NewRaftNode creates and initializes a new Raft node. func NewRaftNode(logger zerolog.Logger, raftConfig config.Raft) (*Node, error) { - RaftNodeConfig := raft.DefaultConfig() + // Initialize basic configuration + nodeConfig, err := initializeNodeConfig(logger, raftConfig) + if err != nil { + return nil, fmt.Errorf("failed to initialize node config: %w", err) + } - // Create HcLogAdapter to wrap zerolog logger - hcLogger := logging.NewHcLogAdapter(&logger, "raft") - RaftNodeConfig.Logger = hcLogger + // Create and initialize FSM + fsm := initializeFSM(raftConfig.Peers) + + // Initialize storage components + stores, err := initializeStores(nodeConfig.raftDir) + if err != nil { + return nil, fmt.Errorf("failed to initialize stores: %w", err) + } + + // Setup transport + transport, err := setupTransport(nodeConfig.raftAddr) + if err != nil { + return nil, fmt.Errorf("failed to setup transport: %w", err) + } + + // Create Raft instance + raftNode, err := raft.NewRaft( + nodeConfig.config, + fsm, + stores.logStore, + stores.stableStore, + stores.snapshotStore, + transport, + ) + if err != nil { + return nil, fmt.Errorf("failed to create raft instance: %w", err) + } + + // Create and initialize node + node := &Node{ + raft: raftNode, + config: nodeConfig.config, + Fsm: fsm, + logStore: stores.logStore, + stableStore: stores.stableStore, + snapshotStore: stores.snapshotStore, + transport: transport, + Logger: logger, + Peers: raftConfig.Peers, + grpcAddr: raftConfig.GRPCAddress, + grpcIsSecure: raftConfig.IsSecure, + } + + // Initialize networking + if err := initializeNetworking(node, raftConfig); err != nil { + return nil, fmt.Errorf("failed to initialize networking: %w", err) + } + + // Handle cluster configuration + if err := configureCluster(node, raftConfig, nodeConfig, transport); err != nil { + return nil, fmt.Errorf("failed to configure cluster: %w", err) + } + + return node, nil +} + +// initializeNodeConfig initializes the node configuration. +func initializeNodeConfig(logger zerolog.Logger, raftConfig config.Raft) (*nodeConfig, error) { + config := raft.DefaultConfig() + config.Logger = logging.NewHcLogAdapter(&logger, "raft") - var err error nodeID := raftConfig.NodeID - if raftConfig.NodeID == "" { + var err error + if nodeID == "" { nodeID, err = os.Hostname() if err != nil { return nil, fmt.Errorf("error getting hostname: %w", err) } } - raftAddr := raftConfig.Address - RaftNodeConfig.LocalID = raft.ServerID(nodeID) + + config.LocalID = raft.ServerID(nodeID) raftDir := filepath.Join(raftConfig.Directory, nodeID) - err = os.MkdirAll(raftDir, os.ModePerm) - if err != nil { + + if err := os.MkdirAll(raftDir, os.ModePerm); err != nil { return nil, fmt.Errorf("error creating raft directory: %w", err) } - // Create the FSM - fsm := NewFSM() + return &nodeConfig{ + config: config, + nodeID: nodeID, + raftAddr: raftConfig.Address, + raftDir: raftDir, + }, nil +} - // Add all peers to FSM if not already present +// initializeFSM initializes the FSM. +func initializeFSM(peers []config.RaftPeer) *FSM { + fsm := NewFSM() fsm.mu.Lock() - for _, peer := range raftConfig.Peers { + for _, peer := range peers { if _, exists := fsm.raftPeers[peer.ID]; !exists { fsm.raftPeers[peer.ID] = peer } } fsm.mu.Unlock() + return fsm +} - // Create the log store and stable store +// initializeStores initializes the stores. +func initializeStores(raftDir string) (*stores, error) { logStore, err := raftboltdb.NewBoltStore(filepath.Join(raftDir, "raft-log.db")) if err != nil { return nil, fmt.Errorf("error creating log store: %w", err) @@ -154,98 +239,107 @@ func NewRaftNode(logger zerolog.Logger, raftConfig config.Raft) (*Node, error) { return nil, fmt.Errorf("error creating stable store: %w", err) } - // Create the snapshot store snapshotStore, err := raft.NewFileSnapshotStore(raftDir, maxSnapshots, os.Stderr) if err != nil { return nil, fmt.Errorf("error creating snapshot store: %w", err) } - // Setup Raft communication + return &stores{ + logStore: logStore, + stableStore: stableStore, + snapshotStore: snapshotStore, + }, nil +} + +// setupTransport sets up the transport. +func setupTransport(raftAddr string) (raft.Transport, error) { addr, err := net.ResolveTCPAddr("tcp", raftAddr) if err != nil { return nil, fmt.Errorf("error resolving TCP address: %w", err) } + transport, err := raft.NewTCPTransport(raftAddr, addr, maxPool, transportTimeout, os.Stderr) if err != nil { return nil, fmt.Errorf("error creating TCP transport: %w", err) } - // Create the Raft node - raftNode, err := raft.NewRaft(RaftNodeConfig, fsm, logStore, stableStore, snapshotStore, transport) - if err != nil { - return nil, fmt.Errorf("error creating Raft: %w", err) - } - - node := &Node{ - raft: raftNode, - config: RaftNodeConfig, - Fsm: fsm, - logStore: logStore, - stableStore: stableStore, - snapshotStore: snapshotStore, - transport: transport, - Logger: logger, - Peers: raftConfig.Peers, - grpcAddr: raftConfig.GRPCAddress, - grpcIsSecure: raftConfig.IsSecure, - } + return transport, nil +} - // Initialize RPC client +// initializeNetworking initializes the networking. +func initializeNetworking(node *Node, raftConfig config.Raft) error { node.rpcClient = newRPCClient(node) - // Start RPC server if err := node.startGRPCServer(raftConfig.CertFile, raftConfig.KeyFile); err != nil { - return nil, fmt.Errorf("failed to start RPC server: %w", err) + return fmt.Errorf("failed to start RPC server: %w", err) } - ctx := context.Background() - node.StartPeerSynchronizer(ctx) + // Start peer synchronizer with a cancellable context + node.startPeerSynchronization() + + return nil +} + +// startPeerSynchronization initializes and starts the peer synchronization process. +func (n *Node) startPeerSynchronization() { + ctx, cancel := context.WithCancel(context.Background()) + n.peerSyncCancel = cancel // Store cancel function for cleanup + n.StartPeerSynchronizer(ctx) +} - // Handle bootstrapping +// configureCluster configures the cluster. +func configureCluster(node *Node, raftConfig config.Raft, nodeConfig *nodeConfig, transport raft.Transport) error { if raftConfig.IsBootstrap { - configuration := raft.Configuration{ - Servers: make([]raft.Server, len(node.Peers)), - } - for i, peer := range node.Peers { - configuration.Servers[i] = raft.Server{ - ID: raft.ServerID(peer.ID), - Address: raft.ServerAddress(peer.Address), - } - } + return bootstrapCluster(node, nodeConfig, transport) + } - selfPeer := config.RaftPeer{ - ID: string(node.config.LocalID), - Address: raftAddr, - GRPCAddress: node.grpcAddr, - } - configuration.Servers = append(configuration.Servers, raft.Server{ - ID: RaftNodeConfig.LocalID, - Address: transport.LocalAddr(), - }) + if len(node.Peers) == 0 { + node.Logger.Info().Msg("no peers found, skipping cluster connection") + return nil + } - // Add self to both configuration and FSM - fsm.mu.Lock() - if _, exists := fsm.raftPeers[string(node.config.LocalID)]; !exists { - fsm.raftPeers[string(node.config.LocalID)] = selfPeer + go func() { + if err := node.tryConnectToCluster(nodeConfig.raftAddr); err != nil { + node.Logger.Error().Err(err).Msg("failed to connect to cluster") } - fsm.mu.Unlock() - - node.raft.BootstrapCluster(configuration) - } else { - // if peers not exists skip tryConnectToCluster - if len(node.Peers) == 0 { - node.Logger.Info().Msg("no peers found, skipping cluster connection") - return node, nil + }() + + return nil +} + +func bootstrapCluster(node *Node, nodeConfig *nodeConfig, transport raft.Transport) error { + configuration := raft.Configuration{ + Servers: make([]raft.Server, len(node.Peers)), + } + + for i, peer := range node.Peers { + configuration.Servers[i] = raft.Server{ + ID: raft.ServerID(peer.ID), + Address: raft.ServerAddress(peer.Address), } + } - go func() { - if err := node.tryConnectToCluster(raftAddr); err != nil { - node.Logger.Error().Err(err).Msg("failed to connect to cluster") - } - }() + selfPeer := config.RaftPeer{ + ID: string(node.config.LocalID), + Address: nodeConfig.raftAddr, + GRPCAddress: node.grpcAddr, } - return node, nil + configuration.Servers = append(configuration.Servers, raft.Server{ + ID: nodeConfig.config.LocalID, + Address: transport.LocalAddr(), + }) + + node.Fsm.mu.Lock() + if _, exists := node.Fsm.raftPeers[string(node.config.LocalID)]; !exists { + node.Fsm.raftPeers[string(node.config.LocalID)] = selfPeer + } + node.Fsm.mu.Unlock() + + if err := node.raft.BootstrapCluster(configuration).Error(); err != nil { + return fmt.Errorf("failed to bootstrap cluster: %w", err) + } + return nil } // tryConnectToCluster attempts to connect to the cluster by sending AddPeer requests to all peers. @@ -604,6 +698,10 @@ func (n *Node) forwardToLeader(ctx context.Context, data []byte, timeout time.Du // shutdown properly, ignoring the ErrRaftShutdown error which indicates the node was already // shutdown. func (n *Node) Shutdown() error { + if n.peerSyncCancel != nil { + n.peerSyncCancel() + } + if n.rpcServer != nil { n.rpcServer.GracefulStop() }