diff --git a/collect/stressRelief/stress_relief_redis.go b/collect/stressRelief/stress_relief_redis.go index 9decd2aef1..ee453b2ea1 100644 --- a/collect/stressRelief/stress_relief_redis.go +++ b/collect/stressRelief/stress_relief_redis.go @@ -46,6 +46,7 @@ type StressRelief struct { stayOnUntil time.Time minDuration time.Duration identification string + stressGossipCh chan []byte eg *errgroup.Group @@ -104,46 +105,59 @@ func (s *StressRelief) Start() error { s.stressLevels = make(map[string]stressReport) s.done = make(chan struct{}) - s.eg = &errgroup.Group{} - s.Health.Register(stressReliefHealthSource, 5*calculationInterval) s.RefineryMetrics.Register("cluster_stress_level", "gauge") s.RefineryMetrics.Register("individual_stress_level", "gauge") s.RefineryMetrics.Register("stress_relief_activated", "gauge") - if err := s.Gossip.Subscribe("stress_level", s.onStressLevelMessage); err != nil { - return err - } + s.stressGossipCh = s.Gossip.Subscribe("stress_level", 20) + s.eg = &errgroup.Group{} + s.eg.Go(s.monitor) + + return nil +} - // start our monitor goroutine that periodically calls recalc - s.eg.Go(func() error { - tick := time.NewTicker(calculationInterval) - defer tick.Stop() - for { - select { - case <-tick.C: - currentLevel := s.Recalc() - // publish the stress level to the rest of the cluster - msg := stressLevelMessage{ - level: currentLevel, - id: s.identification, - } - err := s.Gossip.Publish("stress_level", msg.ToBytes()) - if err != nil { - s.Logger.Error().Logf("error publishing stress level: %s", err) - } else { - s.Health.Ready(stressReliefHealthSource, true) - } - case <-s.done: - s.Logger.Debug().Logf("Stopping StressRelief system") - return nil +func (s *StressRelief) monitor() error { + tick := time.NewTicker(calculationInterval) + defer tick.Stop() + for { + select { + case <-tick.C: + currentLevel := s.Recalc() + // publish the stress level to the rest of the cluster + msg := stressLevelMessage{ + level: currentLevel, + id: s.identification, + } + err := s.Gossip.Publish("stress_level", msg.ToBytes()) + if err != nil { + s.Logger.Error().Logf("error publishing stress level: %s", err) + } else { + s.Health.Ready(stressReliefHealthSource, true) } - } - }) + case data := <-s.stressGossipCh: + msg, err := newMessageFromBytes(data) + if err != nil { + s.Logger.Error().Logf("error parsing stress level message: %s", err) + continue + } + + s.lock.Lock() + s.stressLevels[msg.id] = stressReport{ + key: msg.id, + level: msg.level, + timestamp: s.Clock.Now(), + } + s.lock.Unlock() + + case <-s.done: + s.Logger.Debug().Logf("Stopping StressRelief system") + return nil + } + } - return nil } func (s *StressRelief) Stop() error { @@ -311,22 +325,6 @@ func (s *StressRelief) deterministicFraction() uint { return uint(float64(s.overallStressLevel-s.activateLevel)/float64(100-s.activateLevel)*100 + 0.5) } -func (s *StressRelief) onStressLevelMessage(data []byte) { - msg, err := newMessageFromBytes(data) - if err != nil { - s.Logger.Error().Logf("error parsing stress level message: %s", err) - return - } - - s.lock.Lock() - s.stressLevels[msg.id] = stressReport{ - key: msg.id, - level: msg.level, - timestamp: s.Clock.Now(), - } - s.lock.Unlock() -} - type stressReport struct { key string level uint diff --git a/internal/gossip/gossip.go b/internal/gossip/gossip.go index 29b8a11bca..3662a00086 100644 --- a/internal/gossip/gossip.go +++ b/internal/gossip/gossip.go @@ -2,10 +2,8 @@ package gossip import ( "bytes" - "errors" "github.com/facebookgo/startstop" - "golang.org/x/sync/errgroup" ) // Gossiper is an interface for broadcasting messages to all receivers @@ -14,80 +12,15 @@ type Gossiper interface { // Publish sends a message to all peers listening on the channel Publish(channel string, value []byte) error - // Subscribe listens for messages on the channel - Subscribe(channel string, callback func(data []byte)) error + // Subscribe returns a Go channel that will receive messages from the Gossip channel + // (Redis already called the thing we listen to a channel, so we have to live with that) + // The channel has a buffer of depth; if the buffer is full, messages will be dropped. + Subscribe(channel string, depth int) chan []byte startstop.Starter startstop.Stopper } -var _ Gossiper = &InMemoryGossip{} - -// InMemoryGossip is a Gossiper that uses an in-memory channel -type InMemoryGossip struct { - channel chan []byte - subscribers map[string][]func(data []byte) - - done chan struct{} - eg *errgroup.Group -} - -func (g *InMemoryGossip) Publish(channel string, value []byte) error { - msg := message{ - key: channel, - data: value, - } - - select { - case <-g.done: - return errors.New("gossip has been stopped") - case g.channel <- msg.ToBytes(): - default: - } - return nil -} - -func (g *InMemoryGossip) Subscribe(channel string, callback func(data []byte)) error { - select { - case <-g.done: - return errors.New("gossip has been stopped") - default: - } - - g.subscribers[channel] = append(g.subscribers[channel], callback) - return nil -} - -func (g *InMemoryGossip) Start() error { - g.channel = make(chan []byte, 10) - g.eg = &errgroup.Group{} - g.subscribers = make(map[string][]func(data []byte)) - g.done = make(chan struct{}) - - g.eg.Go(func() error { - for { - select { - case <-g.done: - return nil - case value := <-g.channel: - msg := newMessageFromBytes(value) - callbacks := g.subscribers[msg.key] - - for _, cb := range callbacks { - cb(msg.data) - } - } - } - }) - - return nil -} -func (g *InMemoryGossip) Stop() error { - close(g.done) - close(g.channel) - return g.eg.Wait() -} - type message struct { key string data []byte diff --git a/internal/gossip/gossip_inmem.go b/internal/gossip/gossip_inmem.go new file mode 100644 index 0000000000..b476781d0e --- /dev/null +++ b/internal/gossip/gossip_inmem.go @@ -0,0 +1,94 @@ +package gossip + +import ( + "errors" + "sync" + + "github.com/honeycombio/refinery/logger" + "golang.org/x/sync/errgroup" +) + +// InMemoryGossip is a Gossiper that uses an in-memory channel +type InMemoryGossip struct { + Logger logger.Logger `inject:""` + gossipCh chan []byte + subscriptions map[string][]chan []byte + + done chan struct{} + mut sync.RWMutex + eg *errgroup.Group +} + +var _ Gossiper = &InMemoryGossip{} + +func (g *InMemoryGossip) Publish(channel string, value []byte) error { + msg := message{ + key: channel, + data: value, + } + + select { + case <-g.done: + return errors.New("gossip has been stopped") + case g.gossipCh <- msg.ToBytes(): + default: + g.Logger.Warn().WithFields(map[string]interface{}{ + "channel": channel, + "msg": string(value), + }).Logf("Unable to publish message") + } + return nil +} + +func (g *InMemoryGossip) Subscribe(channel string, depth int) chan []byte { + select { + case <-g.done: + return nil + default: + } + + ch := make(chan []byte, depth) + g.mut.Lock() + g.subscriptions[channel] = append(g.subscriptions[channel], ch) + g.mut.Unlock() + + return ch +} + +func (g *InMemoryGossip) Start() error { + g.gossipCh = make(chan []byte, 10) + g.eg = &errgroup.Group{} + g.subscriptions = make(map[string][]chan []byte) + g.done = make(chan struct{}) + + g.eg.Go(func() error { + for { + select { + case <-g.done: + return nil + case value := <-g.gossipCh: + msg := newMessageFromBytes(value) + g.mut.RLock() + for _, ch := range g.subscriptions[msg.key] { + select { + case ch <- msg.data: + default: + g.Logger.Warn().WithFields(map[string]interface{}{ + "channel": msg.key, + "msg": string(msg.data), + }).Logf("Unable to forward message") + } + } + g.mut.RUnlock() + } + } + }) + + return nil +} + +func (g *InMemoryGossip) Stop() error { + close(g.done) + close(g.gossipCh) + return g.eg.Wait() +} diff --git a/internal/gossip/gossip_redis.go b/internal/gossip/gossip_redis.go index b3392fd0d2..93af233793 100644 --- a/internal/gossip/gossip_redis.go +++ b/internal/gossip/gossip_redis.go @@ -21,9 +21,9 @@ type GossipRedis struct { Logger logger.Logger `inject:""` eg *errgroup.Group - lock sync.RWMutex - subscribers map[string][]func(data []byte) - done chan struct{} + lock sync.RWMutex + subscriptions map[string][]chan []byte + done chan struct{} startstop.Stopper } @@ -35,7 +35,7 @@ type GossipRedis struct { func (g *GossipRedis) Start() error { g.eg = &errgroup.Group{} g.done = make(chan struct{}) - g.subscribers = make(map[string][]func(data []byte)) + g.subscriptions = make(map[string][]chan []byte) g.eg.Go(func() error { for { @@ -46,14 +46,22 @@ func (g *GossipRedis) Start() error { err := g.Redis.ListenPubSubChannels(nil, func(channel string, b []byte) { msg := newMessageFromBytes(b) g.lock.RLock() - callbacks := g.subscribers[msg.key] + chans := g.subscriptions[msg.key] g.lock.RUnlock() - for _, cb := range callbacks { - cb(msg.data) + // we never block on sending to a subscriber; if it's full, we drop the message + for _, ch := range chans { + select { + case ch <- msg.data: + default: + g.Logger.Warn().WithFields(map[string]interface{}{ + "channel": msg.key, + "msg": string(msg.data), + }).Logf("Unable to forward message") + } } }, g.done, "refinery-gossip") if err != nil { - g.Logger.Debug().Logf("Error listening to refinery-gossip channel: %v", err) + g.Logger.Warn().Logf("Error listening to refinery-gossip channel: %v", err) } } } @@ -68,20 +76,21 @@ func (g *GossipRedis) Stop() error { return g.eg.Wait() } -// Subscribe registers a callback for a given channel. -func (g *GossipRedis) Subscribe(channel string, cb func(data []byte)) error { +// Subscribe returns a channel that will receive messages from the Gossip channel. +// The channel has a buffer of depth; if the buffer is full, messages will be dropped. +func (g *GossipRedis) Subscribe(channel string, depth int) chan []byte { select { case <-g.done: - return errors.New("gossip has been stopped") + return nil default: } + ch := make(chan []byte, depth) g.lock.Lock() defer g.lock.Unlock() - g.subscribers[channel] = append(g.subscribers[channel], cb) - - return nil + g.subscriptions[channel] = append(g.subscriptions[channel], ch) + return ch } // Publish sends a message to all subscribers of a given channel. diff --git a/internal/gossip/gossip_redis_test.go b/internal/gossip/gossip_redis_test.go index e12ebada56..c0fdc3195e 100644 --- a/internal/gossip/gossip_redis_test.go +++ b/internal/gossip/gossip_redis_test.go @@ -2,15 +2,17 @@ package gossip import ( "testing" + "time" "github.com/honeycombio/refinery/config" "github.com/honeycombio/refinery/logger" "github.com/honeycombio/refinery/metrics" "github.com/honeycombio/refinery/redis" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) -func TestRoundTrip(t *testing.T) { +func TestRoundTripChanRedis(t *testing.T) { cfg := config.MockConfig{ GetRedisHostVal: "localhost:6379", } @@ -30,18 +32,74 @@ func TestRoundTrip(t *testing.T) { require.NoError(t, g.Start()) - // Test that we can register a handler - require.NoError(t, g.Subscribe("test", func(data []byte) { - require.Equal(t, "hi", string(data)) - })) + ch := g.Subscribe("test", 10) + require.NotNil(t, ch) - require.NoError(t, g.Subscribe("test2", func(data []byte) { - require.Equal(t, "bye", string(data)) - })) + ch2 := g.Subscribe("test2", 10) + require.NotNil(t, ch2) + + // This test is flaky unless we throw away the first message + g.Publish("throwaway", []byte("nevermind")) // Test that we can publish a message require.NoError(t, g.Publish("test", []byte("hi"))) require.NoError(t, g.Publish("test2", []byte("bye"))) + require.Eventually(t, func() bool { + time.Sleep(100 * time.Millisecond) + return len(ch) == 1 && len(ch2) == 1 + }, 5*time.Second, 200*time.Millisecond) + + select { + case hi := <-ch: + require.Equal(t, "hi", string(hi)) + default: + t.Fatal("expected to receive a message on channel 'test'") + } + + select { + case bye := <-ch2: + require.Equal(t, "bye", string(bye)) + default: + t.Fatal("expected to receive a message on channel 'test2'") + } + + require.NoError(t, g.Stop()) +} + +func TestRoundTripChanInMem(t *testing.T) { + g := &InMemoryGossip{} + + require.NoError(t, g.Start()) + + ch := g.Subscribe("test", 10) + require.NotNil(t, ch) + + ch2 := g.Subscribe("test2", 10) + require.NotNil(t, ch2) + + // Test that we can publish a message + require.NoError(t, g.Publish("test", []byte("hi"))) + require.NoError(t, g.Publish("test2", []byte("bye"))) + + assert.Eventually(t, func() bool { + time.Sleep(100 * time.Millisecond) + return len(ch) == 1 && len(ch2) == 1 + }, 5*time.Second, 200*time.Millisecond) + + select { + case hi := <-ch: + require.Equal(t, "hi", string(hi)) + default: + t.Fatal("expected to receive a message on channel 'test'") + } + + select { + case bye := <-ch2: + require.Equal(t, "bye", string(bye)) + default: + t.Fatal("expected to receive a message on channel 'test2'") + } + require.NoError(t, g.Stop()) }