diff --git a/server/session.go b/server/session.go index 2b66f61b..226b707a 100644 --- a/server/session.go +++ b/server/session.go @@ -76,6 +76,8 @@ func startSession(sessionId SessionId, sessionMetadata *proto.SessionMetadata, s } func (s *session) closeChannels() { + s.Lock() + defer s.Unlock() s.cancel() if s.heartbeatCh != nil { close(s.heartbeatCh) @@ -84,11 +86,6 @@ func (s *session) closeChannels() { s.log.Debug("Session channels closed") } -func (s *session) close() error { - s.log.Info("Session closing") - return s.delete() -} - func (s *session) delete() error { // Delete ephemeral data associated with this session sessionKey := SessionKey(s.id) @@ -171,7 +168,6 @@ func (s *session) waitForHeartbeats() { case <-timeoutCh: s.log.Warn("Session expired") - s.Lock() s.closeChannels() err := s.delete() @@ -181,7 +177,6 @@ func (s *session) waitForHeartbeats() { slog.Any("error", err), ) } - s.Unlock() s.sm.Lock() s.sm.sessions.Remove(s.id) diff --git a/server/session_manager.go b/server/session_manager.go index 8d9bd134..717bac62 100644 --- a/server/session_manager.go +++ b/server/session_manager.go @@ -202,10 +202,9 @@ func (sm *sessionManager) CloseSession(request *proto.CloseSessionRequest) (*pro } sm.sessions.Remove(s.id) sm.Unlock() - s.Lock() - defer s.Unlock() s.closeChannels() - err = s.close() + s.log.Info("Session closing") + err = s.delete() if err != nil { return nil, err } @@ -295,9 +294,7 @@ func (sm *sessionManager) Close() error { sm.cancel() for _, s := range sm.sessions.Values() { sm.sessions.Remove(s.id) - s.Lock() s.closeChannels() - s.Unlock() } sm.activeSessions.Unregister()