diff --git a/server/session.go b/server/session.go index 226b707a..e3a4dfc4 100644 --- a/server/session.go +++ b/server/session.go @@ -17,6 +17,7 @@ package server import ( "context" "fmt" + "io" "log/slog" "net/url" "sync" @@ -30,6 +31,7 @@ import ( // --- Session type session struct { + io.Closer sync.Mutex id SessionId clientIdentity string @@ -75,7 +77,7 @@ func startSession(sessionId SessionId, sessionMetadata *proto.SessionMetadata, s return s } -func (s *session) closeChannels() { +func (s *session) Close() { s.Lock() defer s.Unlock() s.cancel() @@ -168,7 +170,7 @@ func (s *session) waitForHeartbeats() { case <-timeoutCh: s.log.Warn("Session expired") - s.closeChannels() + s.Close() err := s.delete() if err != nil { diff --git a/server/session_manager.go b/server/session_manager.go index 717bac62..648c9210 100644 --- a/server/session_manager.go +++ b/server/session_manager.go @@ -202,8 +202,9 @@ func (sm *sessionManager) CloseSession(request *proto.CloseSessionRequest) (*pro } sm.sessions.Remove(s.id) sm.Unlock() - s.closeChannels() + s.log.Info("Session closing") + s.Close() err = s.delete() if err != nil { return nil, err @@ -294,7 +295,7 @@ func (sm *sessionManager) Close() error { sm.cancel() for _, s := range sm.sessions.Values() { sm.sessions.Remove(s.id) - s.closeChannels() + s.Close() } sm.activeSessions.Unregister()