From 6637ff611f5392f6622cb805cd9108902bae95c6 Mon Sep 17 00:00:00 2001 From: Matteo Merli Date: Fri, 26 Jan 2024 10:44:22 -0800 Subject: [PATCH 1/2] Close session in go client close --- oxia/async_client_impl.go | 7 ++-- oxia/async_client_impl_test.go | 50 ++++++++++++++++++++++++--- oxia/sessions.go | 54 +++++++++++++++++++++--------- server/kv/notifications_tracker.go | 7 ++++ server/session_manager.go | 1 + 5 files changed, 95 insertions(+), 24 deletions(-) diff --git a/oxia/async_client_impl.go b/oxia/async_client_impl.go index 8c5bed03..136033c7 100644 --- a/oxia/async_client_impl.go +++ b/oxia/async_client_impl.go @@ -98,13 +98,14 @@ func NewAsyncClient(serviceAddress string, opts ...ClientOption) (AsyncClient, e } func (c *clientImpl) Close() error { - c.cancel() - - return multierr.Combine( + err := multierr.Combine( + c.sessions.Close(), c.writeBatchManager.Close(), c.readBatchManager.Close(), c.clientPool.Close(), ) + c.cancel() + return err } func (c *clientImpl) Put(key string, value []byte, options ...PutOption) <-chan PutResult { diff --git a/oxia/async_client_impl_test.go b/oxia/async_client_impl_test.go index b15961f3..08eaae99 100644 --- a/oxia/async_client_impl_test.go +++ b/oxia/async_client_impl_test.go @@ -30,6 +30,7 @@ import ( ) func init() { + common.LogJSON = false common.ConfigureLogger() } @@ -295,9 +296,8 @@ func TestAsyncClientImpl_OverrideEphemeral(t *testing.T) { } func TestAsyncClientImpl_ClientIdentity(t *testing.T) { - identity1 := newKey() client1, err := NewSyncClient(serviceAddress, - WithIdentity(identity1), + WithIdentity("client-1"), ) assert.NoError(t, err) @@ -306,10 +306,11 @@ func TestAsyncClientImpl_ClientIdentity(t *testing.T) { assert.NoError(t, err) assert.True(t, version.Ephemeral) - assert.Equal(t, identity1, version.ClientIdentity) + assert.Equal(t, "client-1", version.ClientIdentity) client2, err := NewSyncClient(serviceAddress, WithSessionTimeout(2*time.Second), + WithIdentity("client-2"), ) assert.NoError(t, err) @@ -319,14 +320,53 @@ func TestAsyncClientImpl_ClientIdentity(t *testing.T) { assert.EqualValues(t, 0, version.ModificationsCount) assert.Equal(t, "v1", string(res)) assert.True(t, version.Ephemeral) - assert.Equal(t, identity1, version.ClientIdentity) + assert.Equal(t, "client-1", version.ClientIdentity) version, err = client2.Put(context.Background(), k, []byte("v2"), Ephemeral()) assert.NoError(t, err) assert.True(t, version.Ephemeral) - assert.NotSame(t, "", version.ClientIdentity) + assert.Equal(t, "client-2", version.ClientIdentity) assert.NoError(t, client1.Close()) assert.NoError(t, client2.Close()) } + +func TestSyncClientImpl_SessionNotifications(t *testing.T) { + standaloneServer, err := server.NewStandalone(server.NewTestConfig(t.TempDir())) + assert.NoError(t, err) + + serviceAddress := fmt.Sprintf("localhost:%d", standaloneServer.RpcPort()) + client1, err := NewSyncClient(serviceAddress, WithIdentity("client-1")) + assert.NoError(t, err) + + client2, err := NewSyncClient(serviceAddress, WithIdentity("client-1")) + assert.NoError(t, err) + + notifications, err := client2.GetNotifications() + assert.NoError(t, err) + + ctx := context.Background() + + s1, _ := client1.Put(ctx, "/a", []byte("0"), Ephemeral()) + + n := <-notifications.Ch() + assert.Equal(t, KeyCreated, n.Type) + assert.Equal(t, "/a", n.Key) + assert.Equal(t, s1.VersionId, n.VersionId) + + err = client1.Close() + assert.NoError(t, err) + + select { + case n = <-notifications.Ch(): + assert.Equal(t, KeyDeleted, n.Type) + assert.Equal(t, "/a", n.Key) + case <-time.After(3 * time.Second): + assert.Fail(t, "read from channel timed out") + + } + + assert.NoError(t, client2.Close()) + assert.NoError(t, standaloneServer.Close()) +} diff --git a/oxia/sessions.go b/oxia/sessions.go index 7b70e096..8eff1168 100644 --- a/oxia/sessions.go +++ b/oxia/sessions.go @@ -18,6 +18,7 @@ import ( "context" "errors" "fmt" + "go.uber.org/multierr" "log/slog" "sync" "time" @@ -40,6 +41,7 @@ func newSessions(ctx context.Context, shardManager internal.ShardManager, pool c clientOpts: options, log: slog.With( slog.String("component", "oxia-session-manager"), + slog.String("client-identity", options.identity), ), } return s @@ -71,13 +73,15 @@ func (s *sessions) startSession(shardId int64) *clientSession { cs := &clientSession{ shardId: shardId, sessions: s, - ctx: s.ctx, started: make(chan error), log: slog.With( slog.String("component", "session"), slog.Int64("shard", shardId), ), } + + cs.ctx, cs.cancel = context.WithCancel(s.ctx) + cs.log.Debug("Creating session") go common.DoWithLabels( cs.ctx, @@ -90,6 +94,15 @@ func (s *sessions) startSession(shardId int64) *clientSession { return cs } +func (s *sessions) Close() error { + var err error + for _, cs := range s.sessionsByShard { + err = multierr.Append(err, cs.Close()) + } + + return err +} + type clientSession struct { sync.Mutex started chan error @@ -98,6 +111,7 @@ type clientSession struct { log *slog.Logger sessions *sessions ctx context.Context + cancel context.CancelFunc } func (cs *clientSession) executeWithId(callback func(int64, error)) { @@ -163,6 +177,7 @@ func (cs *clientSession) createSession() error { cs.sessionId = sessionId cs.log = cs.log.With( slog.Int64("session-id", sessionId), + slog.String("client-identity", cs.sessions.clientIdentity), ) close(cs.started) cs.log.Debug("Successfully created session") @@ -200,7 +215,7 @@ func (cs *clientSession) createSession() error { ) }) - if !errors.Is(err, context.Canceled) { + if err != nil && !errors.Is(err, context.Canceled) { cs.log.Error( "Failed to keep alive session", slog.Any("error", err), @@ -217,11 +232,30 @@ func (cs *clientSession) getRpc() (proto.OxiaClientClient, error) { return cs.sessions.pool.GetClientRpc(leader) } +func (cs *clientSession) Close() error { + cs.cancel() + + rpc, err := cs.getRpc() + if err != nil { + return err + } + ctx, cancel := context.WithTimeout(cs.sessions.ctx, cs.sessions.clientOpts.requestTimeout) + defer cancel() + + if _, err = rpc.CloseSession(ctx, &proto.CloseSessionRequest{ + ShardId: cs.shardId, + SessionId: cs.sessionId, + }); err != nil { + return err + } + return nil +} + func (cs *clientSession) keepAlive() error { cs.sessions.Lock() cs.Lock() timeout := cs.sessions.clientOpts.sessionTimeout - ctx := cs.sessions.ctx + ctx := cs.ctx shardId := cs.shardId sessionId := cs.sessionId cs.Unlock() @@ -248,19 +282,7 @@ func (cs *clientSession) keepAlive() error { return err } case <-ctx.Done(): - ctx, cancel := context.WithTimeout(context.Background(), cs.sessions.clientOpts.requestTimeout) - rpc, err = cs.getRpc() - if err != nil { - cancel() - return err - } - _, err = rpc.CloseSession(ctx, &proto.CloseSessionRequest{ - ShardId: shardId, - SessionId: sessionId, - }) - - cancel() - return err + return nil } } } diff --git a/server/kv/notifications_tracker.go b/server/kv/notifications_tracker.go index 1ebf3b5c..ff580e0e 100644 --- a/server/kv/notifications_tracker.go +++ b/server/kv/notifications_tracker.go @@ -19,6 +19,7 @@ import ( "fmt" "log/slog" "math" + "strings" "sync" "sync/atomic" "time" @@ -58,6 +59,9 @@ func newNotifications(shardId int64, offset int64, timestamp uint64) *notificati } func (n *notifications) Modified(key string, versionId, modificationsCount int64) { + if strings.HasPrefix(key, common.InternalKeyPrefix) { + return + } nType := proto.NotificationType_KEY_CREATED if modificationsCount > 0 { nType = proto.NotificationType_KEY_MODIFIED @@ -69,6 +73,9 @@ func (n *notifications) Modified(key string, versionId, modificationsCount int64 } func (n *notifications) Deleted(key string) { + if strings.HasPrefix(key, common.InternalKeyPrefix) { + return + } n.batch.Notifications[key] = &proto.Notification{ Type: proto.NotificationType_KEY_DELETED, } diff --git a/server/session_manager.go b/server/session_manager.go index b0820304..499e1708 100644 --- a/server/session_manager.go +++ b/server/session_manager.go @@ -137,6 +137,7 @@ func (sm *sessionManager) createSession(request *proto.CreateSessionRequest, min metadata := proto.SessionMetadataFromVTPool() metadata.TimeoutMs = uint32(timeout.Milliseconds()) + metadata.Identity = request.ClientIdentity defer metadata.ReturnToVTPool() marshalledMetadata, err := metadata.MarshalVT() From ecf38bb57faaf0a49f600e6cb9c85122af404f85 Mon Sep 17 00:00:00 2001 From: Matteo Merli Date: Fri, 26 Jan 2024 14:38:12 -0800 Subject: [PATCH 2/2] Fixed lint --- oxia/async_client_impl_test.go | 1 - oxia/sessions.go | 3 ++- server/session_manager_test.go | 3 ++- 3 files changed, 4 insertions(+), 3 deletions(-) diff --git a/oxia/async_client_impl_test.go b/oxia/async_client_impl_test.go index 08eaae99..aabaf32d 100644 --- a/oxia/async_client_impl_test.go +++ b/oxia/async_client_impl_test.go @@ -364,7 +364,6 @@ func TestSyncClientImpl_SessionNotifications(t *testing.T) { assert.Equal(t, "/a", n.Key) case <-time.After(3 * time.Second): assert.Fail(t, "read from channel timed out") - } assert.NoError(t, client2.Close()) diff --git a/oxia/sessions.go b/oxia/sessions.go index 8eff1168..8d6ca805 100644 --- a/oxia/sessions.go +++ b/oxia/sessions.go @@ -18,11 +18,12 @@ import ( "context" "errors" "fmt" - "go.uber.org/multierr" "log/slog" "sync" "time" + "go.uber.org/multierr" + "github.com/cenkalti/backoff/v4" "google.golang.org/grpc/status" diff --git a/server/session_manager_test.go b/server/session_manager_test.go index e562887c..3da30b1d 100644 --- a/server/session_manager_test.go +++ b/server/session_manager_test.go @@ -17,11 +17,12 @@ package server import ( "context" "errors" - "github.com/streamnative/oxia/server/wal" "io" "testing" "time" + "github.com/streamnative/oxia/server/wal" + "github.com/stretchr/testify/assert" pb "google.golang.org/protobuf/proto"