Skip to content

Commit dfd9aba

Browse files
committed
aghuser: imp code
1 parent b6e7522 commit dfd9aba

File tree

2 files changed

+68
-41
lines changed

2 files changed

+68
-41
lines changed

internal/aghuser/sessionstorage.go

+59-37
Original file line numberDiff line numberDiff line change
@@ -108,8 +108,8 @@ func NewDefaultSessionStorage(
108108
if err != nil {
109109
ds.logger.ErrorContext(ctx, "opening db %q: %w", dbFilename, err)
110110
if errors.Is(err, berrors.ErrInvalid) {
111-
const s = `AdGuard Home cannot be initialized due to an incompatible file system.
112-
Please read the explanation here: https://github.com/AdguardTeam/AdGuardHome/wiki/Getting-Started#limitations`
111+
const s = "AdGuard Home cannot be initialized due to an incompatible file system.\n" +
112+
"Please read the explanation here: https://adguard-dns.io/kb/adguard-home/getting-started/#limitations"
113113
slogutil.PrintLines(ctx, ds.logger, slog.LevelError, "", s)
114114
}
115115

@@ -133,11 +133,9 @@ func (ds *DefaultSessionStorage) loadSessions(ctx context.Context) (err error) {
133133

134134
needRollback := true
135135
defer func() {
136-
if !needRollback {
137-
return
136+
if needRollback {
137+
err = errors.WithDeferred(err, tx.Rollback())
138138
}
139-
140-
err = errors.Join(err, tx.Rollback())
141139
}()
142140

143141
bkt := tx.Bucket([]byte(bboltBucketSessions))
@@ -178,21 +176,61 @@ func (ds *DefaultSessionStorage) processSessions(
178176
ctx context.Context,
179177
bkt *bbolt.Bucket,
180178
) (removed int, err error) {
181-
now := ds.clock.Now()
182179
invalidSessions := [][]byte{}
183180

184-
err = bkt.ForEach(func(k, v []byte) (txErr error) {
185-
s, txErr := deserialize(v)
186-
if txErr != nil || now.After(s.Expire) {
187-
invalidSessions = append(invalidSessions, k)
181+
err = bkt.ForEach(ds.bboltSessionHandler(ctx, &invalidSessions))
182+
if err != nil {
183+
return 0, fmt.Errorf("iterating over sessions: %w", err)
184+
}
188185

189-
return txErr
186+
var errs []error
187+
for _, s := range invalidSessions {
188+
if err = bkt.Delete(s); err != nil {
189+
errs = append(errs, err)
190190
}
191+
}
192+
193+
if err = errors.Join(errs...); err != nil {
194+
return 0, fmt.Errorf("deleting sessions: %w", err)
195+
}
196+
197+
return len(invalidSessions), nil
198+
}
199+
200+
// bboltSessionHandler returns a function for [bbolt.Bucket.ForEach] that
201+
// iterates over stored sessions, deserializes them, and logs any errors
202+
// encountered. The returned error is always nil, as these errors are
203+
// considered non-critical to stop the iteration process.
204+
func (ds *DefaultSessionStorage) bboltSessionHandler(
205+
ctx context.Context,
206+
invalidSessions *[][]byte,
207+
) (fn func(k, v []byte) (err error)) {
208+
now := ds.clock.Now()
191209

192-
var u *User
193-
u, txErr = ds.userDB.ByLogin(ctx, s.UserLogin)
194-
if txErr != nil {
195-
invalidSessions = append(invalidSessions, k)
210+
return func(k, v []byte) (err error) {
211+
s, err := bboltDecode(v)
212+
if err != nil {
213+
*invalidSessions = append(*invalidSessions, k)
214+
ds.logger.DebugContext(ctx, "deserializing session", slogutil.KeyError, err)
215+
216+
return nil
217+
}
218+
219+
if now.After(s.Expire) {
220+
*invalidSessions = append(*invalidSessions, k)
221+
222+
return nil
223+
}
224+
225+
u, err := ds.userDB.ByLogin(ctx, s.UserLogin)
226+
if err != nil {
227+
// Should not happen, as it currently always returns nil for error.
228+
panic(err)
229+
}
230+
231+
if u == nil {
232+
ds.logger.DebugContext(ctx, "no saved user by name", "name", s.UserLogin)
233+
*invalidSessions = append(*invalidSessions, k)
196234

197235
return nil
198236
}
@@ -203,19 +241,7 @@ func (ds *DefaultSessionStorage) processSessions(
203241
ds.sessions[t] = s
204242

205243
return nil
206-
})
207-
if err != nil {
208-
// Don't wrap the error because it's informative enough as is.
209-
return 0, err
210244
}
211-
212-
for _, s := range invalidSessions {
213-
if err = bkt.Delete(s); err != nil {
214-
return 0, fmt.Errorf("deleting session: %w", err)
215-
}
216-
}
217-
218-
return len(invalidSessions), nil
219245
}
220246

221247
// bboltBucketSessions is the name of the bucket storing web user sessions in
@@ -232,12 +258,8 @@ const (
232258
bboltSessionNameLen = 2
233259
)
234260

235-
// deserialize decodes a binary data into a session.
236-
//
237-
// TODO(s.chzhen): !! Improve naming.
238-
func deserialize(data []byte) (s *Session, err error) {
239-
defer func() { err = errors.Annotate(err, "deserializing session: %w") }()
240-
261+
// bboltDecode deserializes decodes a binary data into a session.
262+
func bboltDecode(data []byte) (s *Session, err error) {
241263
if len(data) < bboltSessionExpireLen+bboltSessionNameLen {
242264
return nil, fmt.Errorf("length of the data is less than expected: got %d", len(data))
243265
}
@@ -259,8 +281,8 @@ func deserialize(data []byte) (s *Session, err error) {
259281
}, nil
260282
}
261283

262-
// serialize encodes a session properties into a binary data.
263-
func serialize(s *Session) (data []byte) {
284+
// bboltEncode serializes a session properties into a binary data.
285+
func bboltEncode(s *Session) (data []byte) {
264286
data = make([]byte, bboltSessionExpireLen+bboltSessionNameLen+len(s.UserLogin))
265287

266288
expireData := data[:bboltSessionExpireLen]
@@ -319,7 +341,7 @@ func (ds *DefaultSessionStorage) store(s *Session) (err error) {
319341
return fmt.Errorf("creating bucket: %w", err)
320342
}
321343

322-
err = bkt.Put(s.Token[:], serialize(s))
344+
err = bkt.Put(s.Token[:], bboltEncode(s))
323345
if err != nil {
324346
return fmt.Errorf("putting data: %w", err)
325347
}

internal/aghuser/sessionstorage_test.go

+9-4
Original file line numberDiff line numberDiff line change
@@ -20,16 +20,21 @@ func TestDefaultSessionStorage(t *testing.T) {
2020
)
2121

2222
var (
23-
ctx = testutil.ContextWithTimeout(t, testTimeout)
24-
logger = slogutil.NewDiscardLogger()
23+
ctx = testutil.ContextWithTimeout(t, testTimeout)
24+
logger = slogutil.NewDiscardLogger()
25+
)
26+
27+
const (
2528
sessionTTL = time.Minute
29+
timeStep = time.Second
2630
)
2731

28-
// Set up a mock clock to test expired sessions.
32+
// Set up a mock clock to test expired sessions. Each call to [clock.Now]
33+
// will return the [date] incremented by [timeStep].
2934
date := time.Now()
3035
clock := &faketime.Clock{
3136
OnNow: func() (now time.Time) {
32-
date = date.Add(time.Second)
37+
date = date.Add(timeStep)
3338

3439
return date
3540
},

0 commit comments

Comments
 (0)