Skip to content

Commit bff9ca6

Browse files
authored
Improve localCounter performance and memory footprint (#34)
1 parent c6b43ff commit bff9ca6

File tree

4 files changed

+136
-46
lines changed

4 files changed

+136
-46
lines changed

limit_key.go

+15
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
package httprate
2+
3+
import (
4+
"fmt"
5+
"time"
6+
7+
"github.com/cespare/xxhash/v2"
8+
)
9+
10+
func LimitCounterKey(key string, window time.Time) uint64 {
11+
h := xxhash.New()
12+
h.WriteString(key)
13+
h.WriteString(fmt.Sprintf("%d", window.Unix()))
14+
return h.Sum64()
15+
}

limiter.go

+7-5
Original file line numberDiff line numberDiff line change
@@ -44,8 +44,10 @@ func newRateLimiter(requestLimit int, windowLength time.Duration, options ...Opt
4444

4545
if rl.limitCounter == nil {
4646
rl.limitCounter = &localCounter{
47-
counters: make(map[uint64]*count),
48-
windowLength: windowLength,
47+
latestWindow: time.Now().UTC().Truncate(windowLength),
48+
latestCounters: make(map[uint64]int),
49+
previousCounters: make(map[uint64]int),
50+
windowLength: windowLength,
4951
}
5052
}
5153
rl.limitCounter.Config(requestLimit, windowLength)
@@ -133,16 +135,16 @@ func (l *rateLimiter) Handler(next http.Handler) http.Handler {
133135
}
134136

135137
func (l *rateLimiter) calculateRate(key string, requestLimit int) (bool, float64, error) {
136-
t := time.Now().UTC()
137-
currentWindow := t.Truncate(l.windowLength)
138+
now := time.Now().UTC()
139+
currentWindow := now.Truncate(l.windowLength)
138140
previousWindow := currentWindow.Add(-l.windowLength)
139141

140142
currCount, prevCount, err := l.limitCounter.Get(key, currentWindow, previousWindow)
141143
if err != nil {
142144
return false, 0, err
143145
}
144146

145-
diff := t.Sub(currentWindow)
147+
diff := now.Sub(currentWindow)
146148
rate := float64(prevCount)*(float64(l.windowLength)-float64(diff))/float64(l.windowLength) + float64(currCount)
147149
if rate > float64(requestLimit) {
148150
return false, rate, nil

local_counter.go

+30-37
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
package httprate
22

33
import (
4-
"fmt"
54
"sync"
65
"time"
76

@@ -11,15 +10,11 @@ import (
1110
var _ LimitCounter = &localCounter{}
1211

1312
type localCounter struct {
14-
counters map[uint64]*count
15-
windowLength time.Duration
16-
lastEvict time.Time
17-
mu sync.RWMutex
18-
}
19-
20-
type count struct {
21-
value int
22-
updatedAt time.Time
13+
latestWindow time.Time
14+
previousCounters map[uint64]int
15+
latestCounters map[uint64]int
16+
windowLength time.Duration
17+
mu sync.RWMutex
2318
}
2419

2520
func (c *localCounter) Config(requestLimit int, windowLength time.Duration) {
@@ -37,17 +32,12 @@ func (c *localCounter) IncrementBy(key string, currentWindow time.Time, amount i
3732
c.mu.Lock()
3833
defer c.mu.Unlock()
3934

40-
c.evict()
35+
c.evict(currentWindow)
4136

42-
hkey := LimitCounterKey(key, currentWindow)
37+
hkey := limitCounterKey(key, currentWindow)
4338

44-
v, ok := c.counters[hkey]
45-
if !ok {
46-
v = &count{}
47-
c.counters[hkey] = v
48-
}
49-
v.value += amount
50-
v.updatedAt = time.Now()
39+
count, _ := c.latestCounters[hkey]
40+
c.latestCounters[hkey] = count + amount
5141

5242
return nil
5343
}
@@ -56,36 +46,39 @@ func (c *localCounter) Get(key string, currentWindow, previousWindow time.Time)
5646
c.mu.RLock()
5747
defer c.mu.RUnlock()
5848

59-
curr, ok := c.counters[LimitCounterKey(key, currentWindow)]
60-
if !ok {
61-
curr = &count{value: 0, updatedAt: time.Now()}
49+
if c.latestWindow == currentWindow {
50+
curr, _ := c.latestCounters[limitCounterKey(key, currentWindow)]
51+
prev, _ := c.previousCounters[limitCounterKey(key, previousWindow)]
52+
return curr, prev, nil
6253
}
63-
prev, ok := c.counters[LimitCounterKey(key, previousWindow)]
64-
if !ok {
65-
prev = &count{value: 0, updatedAt: time.Now()}
54+
55+
if c.latestWindow == previousWindow {
56+
prev, _ := c.latestCounters[limitCounterKey(key, previousWindow)]
57+
return 0, prev, nil
6658
}
6759

68-
return curr.value, prev.value, nil
60+
return 0, 0, nil
6961
}
7062

71-
func (c *localCounter) evict() {
72-
d := c.windowLength * 3
73-
74-
if time.Since(c.lastEvict) < d {
63+
func (c *localCounter) evict(currentWindow time.Time) {
64+
if c.latestWindow == currentWindow {
7565
return
7666
}
77-
c.lastEvict = time.Now()
7867

79-
for k, v := range c.counters {
80-
if time.Since(v.updatedAt) >= d {
81-
delete(c.counters, k)
82-
}
68+
previousWindow := currentWindow.Add(-c.windowLength)
69+
if c.latestWindow == previousWindow {
70+
c.latestWindow = currentWindow
71+
c.latestCounters, c.previousCounters = make(map[uint64]int), c.latestCounters
72+
return
8373
}
74+
75+
c.latestWindow = currentWindow
76+
// NOTE: Don't use clear() to keep backward-compatibility.
77+
c.previousCounters, c.latestCounters = make(map[uint64]int), make(map[uint64]int)
8478
}
8579

86-
func LimitCounterKey(key string, window time.Time) uint64 {
80+
func limitCounterKey(key string, window time.Time) uint64 {
8781
h := xxhash.New()
8882
h.WriteString(key)
89-
h.WriteString(fmt.Sprintf("%d", window.Unix()))
9083
return h.Sum64()
9184
}

local_counter_test.go

+84-4
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,94 @@ import (
88
"time"
99
)
1010

11+
func TestLocalCounter(t *testing.T) {
12+
limitCounter := &localCounter{
13+
latestWindow: time.Now().UTC().Truncate(time.Second),
14+
latestCounters: make(map[uint64]int),
15+
previousCounters: make(map[uint64]int),
16+
windowLength: time.Second,
17+
}
18+
19+
// Time = NOW()
20+
currentWindow := time.Now().UTC().Truncate(time.Second)
21+
previousWindow := currentWindow.Add(-time.Second)
22+
23+
for i := 0; i < 5; i++ {
24+
curr, prev, _ := limitCounter.Get(fmt.Sprintf("key-%v", i), currentWindow, previousWindow)
25+
if curr != 0 {
26+
t.Errorf("unexpected curr = %v, expected %v", curr, 0)
27+
}
28+
if prev != 0 {
29+
t.Errorf("unexpected prev = %v, expected %v", prev, 0)
30+
}
31+
32+
_ = limitCounter.IncrementBy(fmt.Sprintf("key-%v", i), currentWindow, 1)
33+
_ = limitCounter.IncrementBy(fmt.Sprintf("key-%v", i), currentWindow, 99)
34+
35+
curr, prev, _ = limitCounter.Get(fmt.Sprintf("key-%v", i), currentWindow, previousWindow)
36+
if curr != 100 {
37+
t.Errorf("unexpected curr = %v, expected %v", curr, 100)
38+
}
39+
if prev != 0 {
40+
t.Errorf("unexpected prev = %v, expected %v", prev, 0)
41+
}
42+
}
43+
44+
// Time++
45+
currentWindow = currentWindow.Add(time.Second)
46+
previousWindow = previousWindow.Add(time.Second)
47+
48+
for i := 0; i < 5; i++ {
49+
curr, prev, _ := limitCounter.Get(fmt.Sprintf("key-%v", i), currentWindow, previousWindow)
50+
if curr != 0 {
51+
t.Errorf("unexpected curr = %v, expected %v", curr, 0)
52+
}
53+
if prev != 100 {
54+
t.Errorf("unexpected prev = %v, expected %v", prev, 100)
55+
}
56+
_ = limitCounter.IncrementBy(fmt.Sprintf("key-%v", i), currentWindow, 50)
57+
}
58+
59+
// Time++
60+
currentWindow = currentWindow.Add(time.Second)
61+
previousWindow = previousWindow.Add(time.Second)
62+
63+
for i := 0; i < 5; i++ {
64+
curr, prev, _ := limitCounter.Get(fmt.Sprintf("key-%v", i), currentWindow, previousWindow)
65+
if curr != 0 {
66+
t.Errorf("unexpected curr = %v, expected %v", curr, 0)
67+
}
68+
if prev != 50 {
69+
t.Errorf("unexpected prev = %v, expected %v", prev, 50)
70+
}
71+
_ = limitCounter.IncrementBy(fmt.Sprintf("key-%v", i), currentWindow, 99)
72+
}
73+
74+
// Time += 10
75+
currentWindow = currentWindow.Add(10 * time.Second)
76+
previousWindow = previousWindow.Add(10 * time.Second)
77+
78+
for i := 0; i < 5; i++ {
79+
curr, prev, _ := limitCounter.Get(fmt.Sprintf("key-%v", i), currentWindow, previousWindow)
80+
if curr != 0 {
81+
t.Errorf("unexpected curr = %v, expected %v", curr, 0)
82+
}
83+
if prev != 0 {
84+
t.Errorf("unexpected prev = %v, expected %v", prev, 0)
85+
}
86+
_ = limitCounter.IncrementBy(fmt.Sprintf("key-%v", i), currentWindow, 99)
87+
}
88+
}
89+
1190
func BenchmarkLocalCounter(b *testing.B) {
1291
limitCounter := &localCounter{
13-
counters: make(map[uint64]*count),
14-
windowLength: time.Second,
92+
latestWindow: time.Now().UTC().Truncate(time.Second),
93+
latestCounters: make(map[uint64]int),
94+
previousCounters: make(map[uint64]int),
95+
windowLength: time.Second,
1596
}
1697

17-
t := time.Now().UTC()
18-
currentWindow := t.Truncate(time.Second)
98+
currentWindow := time.Now().UTC().Truncate(time.Second)
1999
previousWindow := currentWindow.Add(-time.Second)
20100

21101
b.ResetTimer()

0 commit comments

Comments
 (0)