diff --git a/app/app_test.go b/app/app_test.go index 3b3904cc4c..8309a5ea1e 100644 --- a/app/app_test.go +++ b/app/app_test.go @@ -33,6 +33,7 @@ import ( "github.com/honeycombio/refinery/internal/peer" "github.com/honeycombio/refinery/logger" "github.com/honeycombio/refinery/metrics" + "github.com/honeycombio/refinery/pubsub" "github.com/honeycombio/refinery/sample" "github.com/honeycombio/refinery/sharder" "github.com/honeycombio/refinery/transmit" @@ -198,6 +199,8 @@ func newStartedApp( &inject.Object{Value: samplerFactory}, &inject.Object{Value: &health.Health{}}, &inject.Object{Value: clockwork.NewRealClock()}, + &inject.Object{Value: &pubsub.LocalPubSub{}}, + &inject.Object{Value: &collect.EMAThroughputCalculator{}, Name: "throughputCalculator"}, &inject.Object{Value: &collect.MockStressReliever{}, Name: "stressRelief"}, &inject.Object{Value: &a}, ) diff --git a/cmd/refinery/main.go b/cmd/refinery/main.go index 0f33a28d5b..be2b6d695c 100644 --- a/cmd/refinery/main.go +++ b/cmd/refinery/main.go @@ -260,6 +260,7 @@ func main() { {Value: version, Name: "version"}, {Value: samplerFactory}, {Value: stressRelief, Name: "stressRelief"}, + {Value: &collect.EMAThroughputCalculator{}, Name: "throughputCalculator"}, {Value: &health.Health{}}, {Value: &configwatcher.ConfigWatcher{}}, {Value: &a}, diff --git a/collect/collect.go b/collect/collect.go index 007d117164..cba3844483 100644 --- a/collect/collect.go +++ b/collect/collect.go @@ -67,11 +67,12 @@ type InMemCollector struct { Health health.Recorder `inject:""` Sharder sharder.Sharder `inject:""` - Transmission transmit.Transmission `inject:"upstreamTransmission"` - Metrics metrics.Metrics `inject:"genericMetrics"` - SamplerFactory *sample.SamplerFactory `inject:""` - StressRelief StressReliever `inject:"stressRelief"` - Peers peer.Peers `inject:""` + Transmission transmit.Transmission `inject:"upstreamTransmission"` + Metrics metrics.Metrics `inject:"genericMetrics"` + SamplerFactory *sample.SamplerFactory `inject:""` + StressRelief StressReliever `inject:"stressRelief"` + ThroughputCalculator *EMAThroughputCalculator `inject:"throughputCalculator"` + Peers peer.Peers `inject:""` // For test use only BlockOnAddSpan bool @@ -128,6 +129,10 @@ func (i *InMemCollector) Start() error { i.Metrics.Register("trace_send_on_shutdown", "counter") i.Metrics.Register("trace_forwarded_on_shutdown", "counter") + i.Metrics.Register("original_sample_rate_before_multi", "histogram") + i.Metrics.Register("sample_rate_multi", "histogram") + i.Metrics.Register("trace_aggregate_sample_rate", "histogram") + i.Metrics.Register(TraceSendGotRoot, "counter") i.Metrics.Register(TraceSendExpired, "counter") i.Metrics.Register(TraceSendSpanLimit, "counter") @@ -660,6 +665,7 @@ func (i *InMemCollector) dealWithSentTrace(ctx context.Context, tr cache.TraceSe } } if keep { + i.ThroughputCalculator.IncrementEventCount(1) i.Logger.Debug().WithField("trace_id", sp.TraceID).Logf("Sending span because of previous decision to send trace") mergeTraceAndSpanSampleRates(sp, tr.Rate(), isDryRun) // if this span is a late root span, possibly update it with our current span count @@ -781,7 +787,25 @@ func (i *InMemCollector) send(trace *types.Trace, sendReason string) { } // make sampling decision and update the trace - rate, shouldSend, reason, key := sampler.GetSampleRate(trace) + originalRate, reason, key := sampler.GetSampleRate(trace) + sampleRateMultiplier := i.ThroughputCalculator.GetSamplingRateMultiplier() + i.Metrics.Histogram("original_sample_rate_before_multi", originalRate) + i.Metrics.Histogram("sample_rate_multi", sampleRateMultiplier) + + // counting the expected number of spans based on the original sample rate + // this will tell us the throughput we would have sent without the adjustment from the multiplier + i.ThroughputCalculator.IncrementEventCount(float64(trace.DescendantCount()) / float64(originalRate)) + + // TODO: if the sample rate returned by the sampler is set to 1, we should not + // modify the sample rate with the multiplier + var rate uint + if originalRate == 1 { + rate = originalRate + } else { + rate = uint(float64(originalRate) * sampleRateMultiplier) + } + shouldSend := sampler.MakeSamplingDecision(rate, trace) + trace.SetSampleRate(rate) trace.KeepSample = shouldSend logFields["reason"] = reason @@ -799,6 +823,7 @@ func (i *InMemCollector) send(trace *types.Trace, sendReason string) { i.Logger.Info().WithFields(logFields).Logf("Dropping trace because of sampling") return } + i.Metrics.Increment("trace_send_kept") // This will observe sample rate decisions only if the trace is kept i.Metrics.Histogram("trace_kept_sample_rate", float64(rate)) diff --git a/collect/collect_test.go b/collect/collect_test.go index 89d406be6b..93879a335c 100644 --- a/collect/collect_test.go +++ b/collect/collect_test.go @@ -22,6 +22,7 @@ import ( "github.com/honeycombio/refinery/internal/peer" "github.com/honeycombio/refinery/logger" "github.com/honeycombio/refinery/metrics" + "github.com/honeycombio/refinery/pubsub" "github.com/honeycombio/refinery/sample" "github.com/honeycombio/refinery/sharder" "github.com/honeycombio/refinery/transmit" @@ -50,14 +51,15 @@ func newTestCollector(conf config.Config, transmission transmit.Transmission) *I healthReporter.Start() return &InMemCollector{ - Config: conf, - Clock: clock, - Logger: &logger.NullLogger{}, - Tracer: noop.NewTracerProvider().Tracer("test"), - Health: healthReporter, - Transmission: transmission, - Metrics: &metrics.NullMetrics{}, - StressRelief: &MockStressReliever{}, + Config: conf, + Clock: clock, + Logger: &logger.NullLogger{}, + Tracer: noop.NewTracerProvider().Tracer("test"), + Health: healthReporter, + Transmission: transmission, + Metrics: &metrics.NullMetrics{}, + StressRelief: &MockStressReliever{}, + ThroughputCalculator: &EMAThroughputCalculator{}, SamplerFactory: &sample.SamplerFactory{ Config: conf, Metrics: s, @@ -423,14 +425,20 @@ func TestDryRunMode(t *testing.T) { var traceID2 = "def456" var traceID3 = "ghi789" // sampling decisions based on trace ID - sampleRate1, keepTraceID1, _, _ := sampler.GetSampleRate(&types.Trace{TraceID: traceID1}) + trace1 := &types.Trace{TraceID: traceID1} + sampleRate1, _, _ := sampler.GetSampleRate(trace1) + keepTraceID1 := sampler.MakeSamplingDecision(sampleRate1, trace1) // would be dropped if dry run mode was not enabled assert.False(t, keepTraceID1) assert.Equal(t, uint(10), sampleRate1) - sampleRate2, keepTraceID2, _, _ := sampler.GetSampleRate(&types.Trace{TraceID: traceID2}) + trace2 := &types.Trace{TraceID: traceID2} + sampleRate2, _, _ := sampler.GetSampleRate(trace2) + keepTraceID2 := sampler.MakeSamplingDecision(sampleRate2, trace2) assert.True(t, keepTraceID2) assert.Equal(t, uint(10), sampleRate2) - sampleRate3, keepTraceID3, _, _ := sampler.GetSampleRate(&types.Trace{TraceID: traceID3}) + trace3 := &types.Trace{TraceID: traceID3} + sampleRate3, _, _ := sampler.GetSampleRate(trace3) + keepTraceID3 := sampler.MakeSamplingDecision(sampleRate3, trace3) // would be dropped if dry run mode was not enabled assert.False(t, keepTraceID3) assert.Equal(t, uint(10), sampleRate3) @@ -827,8 +835,11 @@ func TestDependencyInjection(t *testing.T) { &inject.Object{Value: &sharder.SingleServerSharder{}}, &inject.Object{Value: &transmit.MockTransmission{}, Name: "upstreamTransmission"}, &inject.Object{Value: &metrics.NullMetrics{}, Name: "genericMetrics"}, + &inject.Object{Value: &metrics.NullMetrics{}, Name: "metrics"}, &inject.Object{Value: &sample.SamplerFactory{}}, &inject.Object{Value: &MockStressReliever{}, Name: "stressRelief"}, + &inject.Object{Value: &pubsub.LocalPubSub{}}, + &inject.Object{Value: &EMAThroughputCalculator{}, Name: "throughputCalculator"}, &inject.Object{Value: &peer.MockPeers{}}, ) if err != nil { diff --git a/collect/throughput_calculator.go b/collect/throughput_calculator.go new file mode 100644 index 0000000000..8325ecb76a --- /dev/null +++ b/collect/throughput_calculator.go @@ -0,0 +1,205 @@ +package collect + +import ( + "context" + "fmt" + "math" + "strconv" + "strings" + "sync" + "time" + + "github.com/honeycombio/refinery/config" + "github.com/honeycombio/refinery/internal/peer" + "github.com/honeycombio/refinery/metrics" + "github.com/honeycombio/refinery/pubsub" + "github.com/jonboulle/clockwork" +) + +const emaThroughputTopic = "ema_throughput" + +// EMAThroughputCalculator encapsulates the logic to calculate a throughput value using an Exponential Moving Average (EMA). +type EMAThroughputCalculator struct { + Config config.Config `inject:""` + Metrics metrics.Metrics `inject:"metrics"` + Clock clockwork.Clock `inject:""` + Pubsub pubsub.PubSub `inject:""` + Peer peer.Peers `inject:""` + + throughputLimit uint + weight float64 // Smoothing factor for EMA + intervalLength time.Duration // Length of the interval + hostID string + + mut sync.RWMutex + throughputs map[string]throughputReport + clusterEMA uint + weightedEventTotal float64 // Internal count of events in the current interval + done chan struct{} +} + +// NewEMAThroughputCalculator creates a new instance of EMAThroughputCalculator. +func (c *EMAThroughputCalculator) Start() error { + cfg := c.Config.GetThroughputCalculatorConfig() + c.throughputLimit = uint(cfg.Limit) + c.done = make(chan struct{}) + + // if throughput limit is not set, disable the calculator + if c.throughputLimit == 0 { + return nil + } + + c.intervalLength = time.Duration(cfg.AdjustmentInterval) + if c.intervalLength == 0 { + c.intervalLength = 15 * time.Second + } + + c.weight = cfg.Weight + if c.weight == 0 { + c.weight = 0.5 + } + + peerID, err := c.Peer.GetInstanceID() + if err != nil { + return err + } + c.hostID = peerID + c.throughputs = make(map[string]throughputReport) + + c.Metrics.Register("cluster_throughput", "gauge") + c.Metrics.Register("cluster_ema_throughput", "gauge") + c.Metrics.Register("individual_throughput", "gauge") + c.Metrics.Register("ema_throughput_publish_error", "counter") + // Subscribe to the throughput topic so we can react to throughput + // changes in the cluster. + c.Pubsub.Subscribe(context.Background(), emaThroughputTopic, c.onThroughputUpdate) + + // have a centralized peer metric service that's responsible for publishing and + // receiving peer metrics + // it could have a channel that's receiving metrics from different source + // it then only send a message if the value has changed and it has passed the configured interval for the metric + // there could be a third case that basically says you have to send it now because we have passed the configured interval and we haven't send a message about this metric since the last interval + go func() { + ticker := c.Clock.NewTicker(c.intervalLength) + defer ticker.Stop() + + for { + select { + case <-c.done: + return + case <-ticker.Chan(): + currentThroughput := c.updateEMA() + err := c.Pubsub.Publish(context.Background(), emaThroughputTopic, newThroughputMessage(currentThroughput, peerID).String()) + if err != nil { + c.Metrics.Count("ema_throughput_publish_error", 1) + } + } + } + + }() + + return nil +} + +func (c *EMAThroughputCalculator) onThroughputUpdate(ctx context.Context, msg string) { + throughputMsg, err := unmarshalThroughputMessage(msg) + if err != nil { + return + } + c.mut.Lock() + c.throughputs[throughputMsg.peerID] = throughputReport{ + key: throughputMsg.peerID, + throughput: throughputMsg.throughput, + timestamp: c.Clock.Now(), + } + c.mut.Unlock() +} + +func (c *EMAThroughputCalculator) Stop() { + close(c.done) +} + +// IncrementEventCount increments the internal event count by a specified amount. +func (c *EMAThroughputCalculator) IncrementEventCount(count float64) { + c.mut.Lock() + c.weightedEventTotal += count + c.mut.Unlock() +} + +// updateEMA calculates the current throughput and updates the EMA. +func (c *EMAThroughputCalculator) updateEMA() uint { + c.mut.Lock() + defer c.mut.Unlock() + + var totalThroughput float64 + + for _, report := range c.throughputs { + if c.Clock.Since(report.timestamp) > c.intervalLength*2 { + delete(c.throughputs, report.key) + continue + } + + totalThroughput += float64(report.throughput) + } + c.Metrics.Gauge("cluster_throughput", totalThroughput) + c.clusterEMA = uint(math.Ceil(c.weight*totalThroughput + (1-c.weight)*float64(c.clusterEMA))) + c.Metrics.Gauge("cluster_ema_throughput", c.clusterEMA) + + // calculating throughput for the next interval + currentThroughput := float64(c.weightedEventTotal) / c.intervalLength.Seconds() + c.Metrics.Gauge("individual_throughput", currentThroughput) + c.weightedEventTotal = 0 // Reset the event count for the new interval + + return uint(currentThroughput) +} + +// GetSamplingRateMultiplier calculates and returns a sampling rate multiplier +// based on the difference between the configured throughput limit and the current throughput. +func (c *EMAThroughputCalculator) GetSamplingRateMultiplier() float64 { + if c.throughputLimit == 0 { + return 1.0 // No limit set, so no adjustment needed + } + + c.mut.RLock() + currentEMA := c.clusterEMA + c.mut.RUnlock() + + if currentEMA <= c.throughputLimit { + return 1.0 // Throughput is within the limit, no adjustment needed + } + + return float64(currentEMA) / float64(c.throughputLimit) +} + +type throughputReport struct { + key string + throughput uint + timestamp time.Time +} + +type throughputMessage struct { + peerID string + throughput uint +} + +func newThroughputMessage(throughput uint, peerID string) *throughputMessage { + return &throughputMessage{throughput: throughput, peerID: peerID} +} + +func (msg *throughputMessage) String() string { + return msg.peerID + "|" + fmt.Sprint(msg.throughput) +} + +func unmarshalThroughputMessage(msg string) (*throughputMessage, error) { + if len(msg) < 2 { + return nil, fmt.Errorf("empty message") + } + + parts := strings.SplitN(msg, "|", 2) + throughput, err := strconv.Atoi(parts[1]) + if err != nil { + return nil, err + } + + return newThroughputMessage(uint(throughput), parts[0]), nil +} diff --git a/collect/throughput_calculator_test.go b/collect/throughput_calculator_test.go new file mode 100644 index 0000000000..ae6149667c --- /dev/null +++ b/collect/throughput_calculator_test.go @@ -0,0 +1,157 @@ +package collect + +import ( + "context" + "math" + "sync" + "testing" + "time" + + "github.com/honeycombio/refinery/config" + "github.com/honeycombio/refinery/internal/peer" + "github.com/honeycombio/refinery/metrics" + "github.com/honeycombio/refinery/pubsub" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEMAThroughputCalculator(t *testing.T) { + fakeClock := clockwork.NewFakeClock() + + weight := 0.5 + intervalLength := time.Second + throughputLimit := 100 + calculator := &EMAThroughputCalculator{ + Clock: fakeClock, + Metrics: &metrics.NullMetrics{}, + Pubsub: &pubsub.LocalPubSub{}, + Peer: &peer.MockPeers{}, + done: make(chan struct{}), + hostID: "test-host", + throughputs: make(map[string]throughputReport), + intervalLength: intervalLength, + weight: weight, + throughputLimit: uint(throughputLimit), + } + calculator.Pubsub.Start() + defer calculator.Pubsub.Stop() + + calculator.IncrementEventCount(150) + + calculator.updateEMA() + // check that the EMA was updated correctly + expectedThroughput := float64(150) / intervalLength.Seconds() + // starting lastEMA is 0 + expectedEMA := weight*expectedThroughput + (1-weight)*0 + calculator.mut.RLock() + require.Equal(t, uint(expectedEMA), calculator.clusterEMA, "EMA calculation is incorrect", calculator.clusterEMA) + require.Equal(t, 0, calculator.weightedEventTotal, "event count is not reset after EMA calculation") + calculator.mut.RUnlock() + + multiplier := calculator.GetSamplingRateMultiplier() + assert.Equal(t, 1.0, multiplier, "Sampling rate multiplier is incorrect") + + calculator.IncrementEventCount(300) + + calculator.updateEMA() + newThroughput := float64(300) / intervalLength.Seconds() + expectedEMA = math.Ceil(weight*newThroughput + (1-weight)*expectedEMA) + calculator.mut.RLock() + assert.Equal(t, uint(expectedEMA), calculator.clusterEMA, "EMA calculation after second interval is incorrect") + require.Equal(t, 0, calculator.weightedEventTotal, "event count is not reset after EMA calculation") + calculator.mut.RUnlock() + + multiplier = calculator.GetSamplingRateMultiplier() + assert.Equal(t, 1.88, multiplier, "Sampling rate multiplier should be 1 when throughput is within the limit") +} + +func TestEMAThroughputCalculator_Concurrent(t *testing.T) { + fakeClock := clockwork.NewFakeClock() + + weight := 0.5 + intervalLength := time.Second + throughputLimit := 100 + + calculator := &EMAThroughputCalculator{ + Clock: fakeClock, + Config: &config.MockConfig{ + GetThroughputCalculatorVal: config.ThroughputCalculatorConfig{ + Limit: throughputLimit, + Weight: weight, + AdjustmentInterval: config.Duration(intervalLength), + }, + }, + Pubsub: &pubsub.LocalPubSub{}, + Peer: &peer.MockPeers{}, + Metrics: &metrics.NullMetrics{}, + } + calculator.Pubsub.Start() + defer calculator.Pubsub.Stop() + calculator.Start() + defer calculator.Stop() + + numGoroutines := 10 + incrementsPerGoroutine := 100 + + var wg sync.WaitGroup + wg.Add(numGoroutines * 2) + + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + for j := 0; j < incrementsPerGoroutine; j++ { + calculator.IncrementEventCount(1) + } + fakeClock.Advance(intervalLength) + }() + } + + for i := 0; i < numGoroutines; i++ { + go func() { + defer wg.Done() + for j := 0; j < incrementsPerGoroutine; j++ { + rate := calculator.GetSamplingRateMultiplier() + assert.GreaterOrEqual(t, rate, 1.0) + } + }() + } + wg.Wait() +} + +func TestEMAThroughputCalculator_MultiplePeers(t *testing.T) { + mockPubSub := &pubsub.LocalPubSub{} + mockPeers := &peer.MockPeers{ + Peers: []string{"instance-1", "instance-2", "instance-3"}, + ID: "instance-1", + } + + fakeClock := clockwork.NewFakeClock() + + calculator := &EMAThroughputCalculator{ + Config: &config.MockConfig{ + GetThroughputCalculatorVal: config.ThroughputCalculatorConfig{ + Limit: 1000, + Weight: 0.5, + AdjustmentInterval: config.Duration(time.Second), + }, + }, + Clock: fakeClock, + Metrics: &metrics.NullMetrics{}, + Pubsub: mockPubSub, + Peer: mockPeers, + intervalLength: time.Second, + weight: 0.5, + throughputs: make(map[string]throughputReport), + } + + // Simulate multiple peers reporting their throughputs + calculator.weightedEventTotal = 100 + calculator.onThroughputUpdate(context.Background(), "instance-2|200") + calculator.onThroughputUpdate(context.Background(), "instance-3|300") + + // Update EMA and check the combined cluster EMA + calculator.updateEMA() + + assert.Equal(t, uint(625), calculator.clusterEMA, "The cluster EMA should be the sum of all peer throughputs.", int(calculator.clusterEMA)) +} diff --git a/config/config.go b/config/config.go index 1fa599829a..eeac08dd35 100644 --- a/config/config.go +++ b/config/config.go @@ -150,6 +150,7 @@ type Config interface { GetSampleCacheConfig() SampleCacheConfig GetStressReliefConfig() StressReliefConfig + GetThroughputCalculatorConfig() ThroughputCalculatorConfig GetAdditionalAttributes() map[string]string diff --git a/config/file_config.go b/config/file_config.go index aea5fd6c65..ddcaf1fe62 100644 --- a/config/file_config.go +++ b/config/file_config.go @@ -46,28 +46,35 @@ type fileConfig struct { var _ Config = (*fileConfig)(nil) type configContents struct { - General GeneralConfig `yaml:"General"` - Network NetworkConfig `yaml:"Network"` - AccessKeys AccessKeyConfig `yaml:"AccessKeys"` - Telemetry RefineryTelemetryConfig `yaml:"RefineryTelemetry"` - Traces TracesConfig `yaml:"Traces"` - Debugging DebuggingConfig `yaml:"Debugging"` - Logger LoggerConfig `yaml:"Logger"` - HoneycombLogger HoneycombLoggerConfig `yaml:"HoneycombLogger"` - StdoutLogger StdoutLoggerConfig `yaml:"StdoutLogger"` - PrometheusMetrics PrometheusMetricsConfig `yaml:"PrometheusMetrics"` - LegacyMetrics LegacyMetricsConfig `yaml:"LegacyMetrics"` - OTelMetrics OTelMetricsConfig `yaml:"OTelMetrics"` - OTelTracing OTelTracingConfig `yaml:"OTelTracing"` - PeerManagement PeerManagementConfig `yaml:"PeerManagement"` - RedisPeerManagement RedisPeerManagementConfig `yaml:"RedisPeerManagement"` - Collection CollectionConfig `yaml:"Collection"` - BufferSizes BufferSizeConfig `yaml:"BufferSizes"` - Specialized SpecializedConfig `yaml:"Specialized"` - IDFieldNames IDFieldsConfig `yaml:"IDFields"` - GRPCServerParameters GRPCServerParameters `yaml:"GRPCServerParameters"` - SampleCache SampleCacheConfig `yaml:"SampleCache"` - StressRelief StressReliefConfig `yaml:"StressRelief"` + General GeneralConfig `yaml:"General"` + Network NetworkConfig `yaml:"Network"` + AccessKeys AccessKeyConfig `yaml:"AccessKeys"` + Telemetry RefineryTelemetryConfig `yaml:"RefineryTelemetry"` + Traces TracesConfig `yaml:"Traces"` + Debugging DebuggingConfig `yaml:"Debugging"` + Logger LoggerConfig `yaml:"Logger"` + HoneycombLogger HoneycombLoggerConfig `yaml:"HoneycombLogger"` + StdoutLogger StdoutLoggerConfig `yaml:"StdoutLogger"` + PrometheusMetrics PrometheusMetricsConfig `yaml:"PrometheusMetrics"` + LegacyMetrics LegacyMetricsConfig `yaml:"LegacyMetrics"` + OTelMetrics OTelMetricsConfig `yaml:"OTelMetrics"` + OTelTracing OTelTracingConfig `yaml:"OTelTracing"` + PeerManagement PeerManagementConfig `yaml:"PeerManagement"` + RedisPeerManagement RedisPeerManagementConfig `yaml:"RedisPeerManagement"` + Collection CollectionConfig `yaml:"Collection"` + BufferSizes BufferSizeConfig `yaml:"BufferSizes"` + Specialized SpecializedConfig `yaml:"Specialized"` + IDFieldNames IDFieldsConfig `yaml:"IDFields"` + GRPCServerParameters GRPCServerParameters `yaml:"GRPCServerParameters"` + SampleCache SampleCacheConfig `yaml:"SampleCache"` + StressRelief StressReliefConfig `yaml:"StressRelief"` + ThroughputCalculator ThroughputCalculatorConfig `yaml:"ThroughputCalculator"` +} + +type ThroughputCalculatorConfig struct { + Limit int `json:"limit" yaml:"Limit,omitempty" validate:"required,gte=1"` + Weight float64 `json:"weight" yaml:"Weight,omitempty" validate:"gt=0,lt=1"` + AdjustmentInterval Duration `json:"adjustmentinterval" yaml:"AdjustmentInterval,omitempty"` } type GeneralConfig struct { @@ -953,6 +960,13 @@ func (f *fileConfig) GetStressReliefConfig() StressReliefConfig { return f.mainConfig.StressRelief } +func (f *fileConfig) GetThroughputCalculatorConfig() ThroughputCalculatorConfig { + f.mux.RLock() + defer f.mux.RUnlock() + + return f.mainConfig.ThroughputCalculator +} + func (f *fileConfig) GetTraceIdFieldNames() []string { f.mux.RLock() defer f.mux.RUnlock() diff --git a/config/metadata/configMeta.yaml b/config/metadata/configMeta.yaml index ee2bdadace..fe70a35ce7 100644 --- a/config/metadata/configMeta.yaml +++ b/config/metadata/configMeta.yaml @@ -1787,3 +1787,48 @@ groups: If this duration is `0`, then Refinery will not start in stressed mode, which will provide faster startup at the possible cost of startup instability. + + - name: ThroughputCalculator + sortorder: 1 + title: Throughput Calculator + description: > + The Throughput Calculator is designed to dynamically adjust the sampling + rates of all configured samplers to maintain an overall throughput that + does not exceed a specified limit. This configuration uses an Exponential + Moving Average (EMA) to track the current throughput and applies a sample + rate multiplier that scales the sample rates accordingly. + fields: + - name: Limit + type: int + validations: + - type: requiredInGroup + summary: is the maximum allowable throughput per second for the entire Refinery cluster. + description: > + The maximum number of events per second you want to send to Honeycomb, aggregated across + all instances in the cluster. + + Refinery will adjust sample rate calculated from all configured samplers dynamically to + try to ensure that the overall traffic from all instances combined does not exceed this limit. + - name: Weight + type: float + validations: + - type: minimum + arg: 0 + - type: maximum + arg: 1 + summary: is the weight to use when calculating the EMA. + description: > + The weight to use when calculating the EMA. It should be a number + between `0` and `1`. Larger values weight the average more toward + recent observations. In other words, a larger weight will cause + sample rates more quickly adapt to traffic patterns, while a smaller + weight will result in sample rates that are less sensitive to bursts + or drops in traffic and thus more consistent over time. + - name: AdjustmentInterval + type: duration + summary: is how often the sampler will recalculate the sample rate. + description: > + The duration after which the EMA Dynamic Sampler should recalculate + its internal counters. It should be specified as a duration string. + For example, "30s" or "1m". Defaults to "15s". + diff --git a/config/mock.go b/config/mock.go index 34497a3a10..c53e7bdec2 100644 --- a/config/mock.go +++ b/config/mock.go @@ -27,6 +27,7 @@ type MockConfig struct { GetLoggerLevelVal Level GetPeersVal []string GetRedisPeerManagementVal RedisPeerManagementConfig + GetThroughputCalculatorVal ThroughputCalculatorConfig GetSamplerTypeName string GetSamplerTypeVal interface{} GetMetricsTypeVal string @@ -274,13 +275,18 @@ func (m *MockConfig) GetAllSamplerRules() *V2SamplerConfig { return nil } - v := &V2SamplerConfig{ - Samplers: map[string]*V2SamplerChoice{"dataset1": choice}, - } + v := &V2SamplerConfig{Samplers: map[string]*V2SamplerChoice{"dataset1": choice}} return v } +func (m *MockConfig) GetThroughputCalculatorConfig() ThroughputCalculatorConfig { + m.Mux.RLock() + defer m.Mux.RUnlock() + + return m.GetThroughputCalculatorVal +} + func (m *MockConfig) GetUpstreamBufferSize() int { m.Mux.RLock() defer m.Mux.RUnlock() diff --git a/metrics/otel_metrics.go b/metrics/otel_metrics.go index f879f49543..1fc12899ae 100644 --- a/metrics/otel_metrics.go +++ b/metrics/otel_metrics.go @@ -34,7 +34,7 @@ type OTelMetrics struct { counters map[string]metric.Int64Counter gauges map[string]metric.Float64ObservableGauge - histograms map[string]metric.Int64Histogram + histograms map[string]metric.Float64Histogram updowns map[string]metric.Int64UpDownCounter // values keeps a map of all the non-histogram metrics and their current value @@ -52,7 +52,7 @@ func (o *OTelMetrics) Start() error { o.counters = make(map[string]metric.Int64Counter) o.gauges = make(map[string]metric.Float64ObservableGauge) - o.histograms = make(map[string]metric.Int64Histogram) + o.histograms = make(map[string]metric.Float64Histogram) o.updowns = make(map[string]metric.Int64UpDownCounter) o.values = make(map[string]float64) @@ -215,7 +215,7 @@ func (o *OTelMetrics) Register(name string, metricType string) { } o.gauges[name] = g case "histogram": - h, err := o.meter.Int64Histogram(name) + h, err := o.meter.Float64Histogram(name) if err != nil { o.Logger.Error().WithString("msg", "failed to create histogram").WithString("name", name) return @@ -268,7 +268,7 @@ func (o *OTelMetrics) Histogram(name string, val interface{}) { if h, ok := o.histograms[name]; ok { f := ConvertNumeric(val) - h.Record(context.Background(), int64(f)) + h.Record(context.Background(), f) o.values[name] += f } } diff --git a/route/route_test.go b/route/route_test.go index 1649df16cf..b1d226532b 100644 --- a/route/route_test.go +++ b/route/route_test.go @@ -21,6 +21,7 @@ import ( "github.com/honeycombio/refinery/internal/peer" "github.com/honeycombio/refinery/logger" "github.com/honeycombio/refinery/metrics" + "github.com/honeycombio/refinery/pubsub" "github.com/honeycombio/refinery/sharder" "github.com/honeycombio/refinery/transmit" "github.com/jonboulle/clockwork" @@ -387,11 +388,11 @@ func TestDebugAllRules(t *testing.T) { }{ { format: "json", - expect: `{"rulesversion":0,"samplers":{"dataset1":{"deterministicsampler":{"samplerate":0},"rulesbasedsampler":null,"dynamicsampler":null,"emadynamicsampler":null,"emathroughputsampler":null,"windowedthroughputsampler":null,"totalthroughputsampler":null}}}`, + expect: `{"rulesversion":0,"throughputlimit":{"limit":0,"weight":0,"adjustmentinterval":"0s"},"samplers":{"dataset1":{"deterministicsampler":{"samplerate":0},"rulesbasedsampler":null,"dynamicsampler":null,"emadynamicsampler":null,"emathroughputsampler":null,"windowedthroughputsampler":null,"totalthroughputsampler":null}}}`, }, { format: "toml", - expect: "RulesVersion = 0\n\n[Samplers]\n[Samplers.dataset1]\n[Samplers.dataset1.DeterministicSampler]\nSampleRate = 0\n", + expect: "RulesVersion = 0\n\n[ThroughPutLimit]\nLimit = 0\nWeight = 0.0\nAdjustmentInterval = '0s'\n\n[Samplers]\n[Samplers.dataset1]\n[Samplers.dataset1.DeterministicSampler]\nSampleRate = 0\n", }, { format: "yaml", @@ -491,6 +492,8 @@ func TestDependencyInjection(t *testing.T) { &inject.Object{Value: &collect.MockStressReliever{}, Name: "stressRelief"}, &inject.Object{Value: &peer.MockPeers{}}, &inject.Object{Value: &health.Health{}}, + &inject.Object{Value: &pubsub.LocalPubSub{}}, + &inject.Object{Value: &collect.EMAThroughputCalculator{}, Name: "throughputCalculator"}, &inject.Object{Value: clockwork.NewFakeClock()}, ) if err != nil { diff --git a/sample/deterministic.go b/sample/deterministic.go index 1d2e5a44c4..22d29d3230 100644 --- a/sample/deterministic.go +++ b/sample/deterministic.go @@ -15,6 +15,8 @@ import ( // other sharding that uses the trace ID (eg deterministic sharding) const shardingSalt = "5VQ8l2jE5aJLPVqk" +var _ Sampler = (*DeterministicSampler)(nil) + type DeterministicSampler struct { Config *config.DeterministicSamplerConfig Logger logger.Logger @@ -42,10 +44,19 @@ func (d *DeterministicSampler) Start() error { return nil } -func (d *DeterministicSampler) GetSampleRate(trace *types.Trace) (rate uint, keep bool, reason string, key string) { +func (d *DeterministicSampler) GetSampleRate(trace *types.Trace) (rate uint, reason string, key string) { if d.sampleRate <= 1 { - return 1, true, "deterministic/always", "" + return 1, "deterministic/always", "" + } + + return uint(d.sampleRate), "deterministic/chance", "" +} + +func (d *DeterministicSampler) MakeSamplingDecision(rate uint, trace *types.Trace) bool { + if rate == 1 { + return true } + sum := sha1.Sum([]byte(trace.TraceID + shardingSalt)) v := binary.BigEndian.Uint32(sum[:4]) shouldKeep := v <= d.upperBound @@ -55,5 +66,5 @@ func (d *DeterministicSampler) GetSampleRate(trace *types.Trace) (rate uint, kee d.Metrics.Increment(d.prefix + "num_dropped") } - return uint(d.sampleRate), shouldKeep, "deterministic/chance", "" + return shouldKeep } diff --git a/sample/deterministic_test.go b/sample/deterministic_test.go index 3ec6afcf01..65704cf76e 100644 --- a/sample/deterministic_test.go +++ b/sample/deterministic_test.go @@ -50,7 +50,9 @@ func TestGetSampleRate(t *testing.T) { ds.Start() for i, tst := range tsts { - rate, keep, reason, key := ds.GetSampleRate(tst.trace) + rate, reason, key := ds.GetSampleRate(tst.trace) + keep := ds.MakeSamplingDecision(rate, tst.trace) + assert.Equal(t, uint(10), rate, "sample rate should be fixed") assert.Equal(t, tst.sampled, keep, "%d: trace ID %s should be %v", i, tst.trace.TraceID, tst.sampled) assert.Equal(t, "deterministic/chance", reason) diff --git a/sample/dynamic.go b/sample/dynamic.go index 7367e5865b..54dfdc5253 100644 --- a/sample/dynamic.go +++ b/sample/dynamic.go @@ -1,7 +1,6 @@ package sample import ( - "math/rand" "time" dynsampler "github.com/honeycombio/dynsampler-go" @@ -12,6 +11,8 @@ import ( "github.com/honeycombio/refinery/types" ) +var _ Sampler = (*DynamicSampler)(nil) + type DynamicSampler struct { Config *config.DynamicSamplerConfig Logger logger.Logger @@ -63,26 +64,21 @@ func (d *DynamicSampler) Start() error { return nil } -func (d *DynamicSampler) GetSampleRate(trace *types.Trace) (rate uint, keep bool, reason string, key string) { +func (d *DynamicSampler) GetSampleRate(trace *types.Trace) (rate uint, reason string, key string) { key = d.key.build(trace) count := int(trace.DescendantCount()) + rate = uint(d.dynsampler.GetSampleRateMulti(key, count)) if rate < 1 { // protect against dynsampler being broken even though it shouldn't be rate = 1 } - shouldKeep := rand.Intn(int(rate)) == 0 + d.Logger.Debug().WithFields(map[string]interface{}{ "sample_key": key, "sample_rate": rate, - "sample_keep": shouldKeep, "trace_id": trace.TraceID, "span_count": count, - }).Logf("got sample rate and decision") - if shouldKeep { - d.Metrics.Increment(d.prefix + "num_kept") - } else { - d.Metrics.Increment(d.prefix + "num_dropped") - } + }).Logf("got sample rate") d.Metrics.Histogram(d.prefix+"sample_rate", float64(rate)) for name, val := range d.dynsampler.GetMetrics(d.prefix) { switch getMetricType(name) { @@ -94,5 +90,22 @@ func (d *DynamicSampler) GetSampleRate(trace *types.Trace) (rate uint, keep bool d.Metrics.Gauge(name, val) } } - return rate, shouldKeep, "dynamic", key + return rate, "dynamic", key +} + +func (d *DynamicSampler) MakeSamplingDecision(rate uint, trace *types.Trace) bool { + shouldKeep := makeSamplingDecision(rate) + if shouldKeep { + d.Metrics.Increment(d.prefix + "num_kept") + } else { + d.Metrics.Increment(d.prefix + "num_dropped") + } + + d.Logger.Debug().WithFields(map[string]interface{}{ + "sample_rate": rate, + "sample_keep": shouldKeep, + "trace_id": trace.TraceID, + }).Logf("got sample decision") + + return shouldKeep } diff --git a/sample/dynamic_ema.go b/sample/dynamic_ema.go index 7fd04d2780..fbfe1510ae 100644 --- a/sample/dynamic_ema.go +++ b/sample/dynamic_ema.go @@ -1,7 +1,6 @@ package sample import ( - "math/rand" "time" dynsampler "github.com/honeycombio/dynsampler-go" @@ -12,6 +11,8 @@ import ( "github.com/honeycombio/refinery/types" ) +var _ Sampler = (*EMADynamicSampler)(nil) + type EMADynamicSampler struct { Config *config.EMADynamicSamplerConfig Logger logger.Logger @@ -71,26 +72,21 @@ func (d *EMADynamicSampler) Start() error { return nil } -func (d *EMADynamicSampler) GetSampleRate(trace *types.Trace) (rate uint, keep bool, reason string, key string) { +func (d *EMADynamicSampler) GetSampleRate(trace *types.Trace) (rate uint, reason string, key string) { key = d.key.build(trace) count := int(trace.DescendantCount()) + rate = uint(d.dynsampler.GetSampleRateMulti(key, count)) if rate < 1 { // protect against dynsampler being broken even though it shouldn't be rate = 1 } - shouldKeep := rand.Intn(int(rate)) == 0 + d.Logger.Debug().WithFields(map[string]interface{}{ "sample_key": key, "sample_rate": rate, - "sample_keep": shouldKeep, "trace_id": trace.TraceID, "span_count": count, - }).Logf("got sample rate and decision") - if shouldKeep { - d.Metrics.Increment(d.prefix + "num_kept") - } else { - d.Metrics.Increment(d.prefix + "num_dropped") - } + }).Logf("got sample rate") d.Metrics.Histogram(d.prefix+"sample_rate", float64(rate)) for name, val := range d.dynsampler.GetMetrics(d.prefix) { switch getMetricType(name) { @@ -102,5 +98,21 @@ func (d *EMADynamicSampler) GetSampleRate(trace *types.Trace) (rate uint, keep b d.Metrics.Gauge(name, val) } } - return rate, shouldKeep, "emadynamic", key + return rate, "emadynamic", key +} + +func (d *EMADynamicSampler) MakeSamplingDecision(rate uint, trace *types.Trace) bool { + shouldKeep := makeSamplingDecision(rate) + if shouldKeep { + d.Metrics.Increment(d.prefix + "num_kept") + } else { + d.Metrics.Increment(d.prefix + "num_dropped") + } + + d.Logger.Debug().WithFields(map[string]interface{}{ + "sample_rate": rate, + "trace_id": trace.TraceID, + }).Logf("got sample decision") + + return shouldKeep } diff --git a/sample/dynamic_ema_test.go b/sample/dynamic_ema_test.go index d344071f08..122ce5bf88 100644 --- a/sample/dynamic_ema_test.go +++ b/sample/dynamic_ema_test.go @@ -39,7 +39,7 @@ func TestDynamicEMAAddSampleRateKeyToTrace(t *testing.T) { }) } sampler.Start() - rate, _, reason, key := sampler.GetSampleRate(trace) + rate, reason, key := sampler.GetSampleRate(trace) spans := trace.GetSpans() diff --git a/sample/dynamic_test.go b/sample/dynamic_test.go index 358d8d6f0e..6c3c97c846 100644 --- a/sample/dynamic_test.go +++ b/sample/dynamic_test.go @@ -48,7 +48,8 @@ func TestDynamicAddSampleRateKeyToTrace(t *testing.T) { }) } sampler.Start() - rate, keep, reason, key := sampler.GetSampleRate(trace) + rate, reason, key := sampler.GetSampleRate(trace) + keep := sampler.MakeSamplingDecision(rate, trace) spans := trace.GetSpans() assert.Len(t, spans, spanCount, "should have the same number of spans as input") diff --git a/sample/ema_throughput.go b/sample/ema_throughput.go index 7d27983c67..a1900bc027 100644 --- a/sample/ema_throughput.go +++ b/sample/ema_throughput.go @@ -1,7 +1,6 @@ package sample import ( - "math/rand" "time" dynsampler "github.com/honeycombio/dynsampler-go" @@ -12,6 +11,8 @@ import ( "github.com/honeycombio/refinery/types" ) +var _ Sampler = (*EMAThroughputSampler)(nil) + type EMAThroughputSampler struct { Config *config.EMAThroughputSamplerConfig Logger logger.Logger @@ -88,27 +89,23 @@ func (d *EMAThroughputSampler) SetClusterSize(size int) { } } -func (d *EMAThroughputSampler) GetSampleRate(trace *types.Trace) (rate uint, keep bool, reason string, key string) { +func (d *EMAThroughputSampler) GetSampleRate(trace *types.Trace) (rate uint, reason string, key string) { key = d.key.build(trace) count := int(trace.DescendantCount()) + rate = uint(d.dynsampler.GetSampleRateMulti(key, count)) if rate < 1 { // protect against dynsampler being broken even though it shouldn't be rate = 1 } - shouldKeep := rand.Intn(int(rate)) == 0 d.Logger.Debug().WithFields(map[string]interface{}{ "sample_key": key, "sample_rate": rate, - "sample_keep": shouldKeep, "trace_id": trace.TraceID, "span_count": count, - }).Logf("got sample rate and decision") - if shouldKeep { - d.Metrics.Increment(d.prefix + "num_kept") - } else { - d.Metrics.Increment(d.prefix + "num_dropped") - } + }).Logf("got sample rate") + d.Metrics.Histogram(d.prefix+"sample_rate", float64(rate)) + for name, val := range d.dynsampler.GetMetrics(d.prefix) { switch getMetricType(name) { case "counter": @@ -119,5 +116,19 @@ func (d *EMAThroughputSampler) GetSampleRate(trace *types.Trace) (rate uint, kee d.Metrics.Gauge(name, val) } } - return rate, shouldKeep, "emathroughput", key + return rate, "emathroughput", key +} + +func (d *EMAThroughputSampler) MakeSamplingDecision(rate uint, trace *types.Trace) bool { + shouldKeep := makeSamplingDecision(rate) + if shouldKeep { + d.Metrics.Increment(d.prefix + "num_kept") + } else { + d.Metrics.Increment(d.prefix + "num_dropped") + } + d.Logger.Debug().WithFields(map[string]interface{}{ + "sample_rate": rate, + "trace_id": trace.TraceID, + }).Logf("got sample decision") + return shouldKeep } diff --git a/sample/ema_throughput_test.go b/sample/ema_throughput_test.go index 152e439aa3..f412315bc1 100644 --- a/sample/ema_throughput_test.go +++ b/sample/ema_throughput_test.go @@ -39,7 +39,7 @@ func TestEMAThroughputAddSampleRateKeyToTrace(t *testing.T) { }) } sampler.Start() - rate, _, reason, key := sampler.GetSampleRate(trace) + rate, reason, key := sampler.GetSampleRate(trace) spans := trace.GetSpans() @@ -47,5 +47,4 @@ func TestEMAThroughputAddSampleRateKeyToTrace(t *testing.T) { assert.Equal(t, uint(10), rate, "sample rate should be 10") assert.Equal(t, "emathroughput", reason) assert.Equal(t, "4•,200•,true•,/{slug}/fun•,", key) - } diff --git a/sample/rules.go b/sample/rules.go index c1795fe036..6dfeb95d39 100644 --- a/sample/rules.go +++ b/sample/rules.go @@ -2,7 +2,6 @@ package sample import ( "encoding/json" - "math/rand" "strings" "github.com/honeycombio/refinery/config" @@ -12,6 +11,8 @@ import ( "github.com/tidwall/gjson" ) +var _ Sampler = (*RulesBasedSampler)(nil) + type RulesBasedSampler struct { Config *config.RulesBasedSamplerConfig Logger logger.Logger @@ -79,7 +80,7 @@ func (s *RulesBasedSampler) Start() error { return nil } -func (s *RulesBasedSampler) GetSampleRate(trace *types.Trace) (rate uint, keep bool, reason string, key string) { +func (s *RulesBasedSampler) GetSampleRate(trace *types.Trace) (rate uint, reason string, key string) { logger := s.Logger.Debug().WithFields(map[string]interface{}{ "trace_id": trace.TraceID, }) @@ -106,7 +107,6 @@ func (s *RulesBasedSampler) GetSampleRate(trace *types.Trace) (rate uint, keep b if matched { var rate uint - var keep bool var samplerReason string var key string @@ -117,36 +117,47 @@ func (s *RulesBasedSampler) GetSampleRate(trace *types.Trace) (rate uint, keep b logger.WithFields(map[string]interface{}{ "rule_name": rule.Name, }).Logf("could not find downstream sampler for rule: %s", rule.Name) - return 1, true, reason + "bad_rule:" + rule.Name, "" + return 1, reason + "bad_rule:" + rule.Name, "" } - rate, keep, samplerReason, key = sampler.GetSampleRate(trace) + rate, samplerReason, key = sampler.GetSampleRate(trace) reason += rule.Name + ":" + samplerReason } else { rate = uint(rule.SampleRate) - keep = !rule.Drop && rule.SampleRate > 0 && rand.Intn(rule.SampleRate) == 0 reason += rule.Name s.Metrics.Histogram(s.prefix+"sample_rate", float64(rate)) - } - - if keep { - s.Metrics.Increment(s.prefix + "num_kept") - } else { - s.Metrics.Increment(s.prefix + "num_dropped") if rule.Drop { - // If we dropped because of an explicit drop rule, then increment that too. - s.Metrics.Increment(s.prefix + "num_dropped_by_drop_rule") + rate = 0 } } + logger.WithFields(map[string]interface{}{ "rate": rate, - "keep": keep, "drop_rule": rule.Drop, }).Logf("got sample rate and decision") - return rate, keep, reason, key + return rate, reason, key } } - return 1, true, "no rule matched", "" + return 1, "no rule matched", "" +} + +func (s *RulesBasedSampler) MakeSamplingDecision(rate uint, trace *types.Trace) bool { + if rate == 0 { + // If we dropped because of an explicit drop rule, then increment that too. + s.Metrics.Increment(s.prefix + "num_dropped_by_drop_rule") + return false + } + if rate == 1 { + return true + } + keep := makeSamplingDecision(rate) + if keep { + s.Metrics.Increment(s.prefix + "num_kept") + } else { + s.Metrics.Increment(s.prefix + "num_dropped") + } + + return keep } func ruleMatchesTrace(t *types.Trace, rule *config.RulesBasedSamplerRule, checkNestedFields bool) bool { diff --git a/sample/rules_test.go b/sample/rules_test.go index d4ec68bd32..32eea1776f 100644 --- a/sample/rules_test.go +++ b/sample/rules_test.go @@ -834,7 +834,7 @@ func TestRules(t *testing.T) { }, ExpectedName: "Check that the number of descendants is greater than 3", ExpectedKeep: false, - ExpectedRate: 1, + ExpectedRate: 0, }, { Rules: &config.RulesBasedSamplerConfig{ @@ -919,7 +919,8 @@ func TestRules(t *testing.T) { } } - rate, keep, reason, key := sampler.GetSampleRate(trace) + rate, reason, key := sampler.GetSampleRate(trace) + keep := sampler.MakeSamplingDecision(rate, trace) assert.Equal(t, d.ExpectedRate, rate, d.Rules) name := d.ExpectedName @@ -1112,7 +1113,8 @@ func TestRulesWithNestedFields(t *testing.T) { trace.AddSpan(span) } - rate, keep, reason, key := sampler.GetSampleRate(trace) + rate, reason, key := sampler.GetSampleRate(trace) + keep := sampler.MakeSamplingDecision(rate, trace) assert.Equal(t, d.ExpectedRate, rate, d.Rules) name := d.ExpectedName @@ -1196,7 +1198,8 @@ func TestRulesWithDynamicSampler(t *testing.T) { } sampler.Start() - rate, keep, reason, key := sampler.GetSampleRate(trace) + rate, reason, key := sampler.GetSampleRate(trace) + keep := sampler.MakeSamplingDecision(rate, trace) assert.Equal(t, d.ExpectedRate, rate, d.Rules) name := d.ExpectedName @@ -1282,7 +1285,8 @@ func TestRulesWithEMADynamicSampler(t *testing.T) { } sampler.Start() - rate, keep, reason, key := sampler.GetSampleRate(trace) + rate, reason, key := sampler.GetSampleRate(trace) + keep := sampler.MakeSamplingDecision(rate, trace) assert.Equal(t, d.ExpectedRate, rate, d.Rules) name := d.ExpectedName @@ -1406,9 +1410,15 @@ func TestRuleMatchesSpanMatchingSpan(t *testing.T) { } sampler.Start() - rate, keep, _, _ := sampler.GetSampleRate(trace) + rate, _, _ := sampler.GetSampleRate(trace) + keep := sampler.MakeSamplingDecision(rate, trace) + + if !keep { + assert.Equal(t, uint(0), rate, rate) + } else { + assert.Equal(t, uint(1), rate, rate) + } - assert.Equal(t, uint(1), rate, rate) if scope == "span" { assert.Equal(t, tc.keepSpanScope, keep, keep) } else { @@ -1988,7 +1998,9 @@ func TestRulesDatatypes(t *testing.T) { trace.AddSpan(span) } - rate, keep, _, _ := sampler.GetSampleRate(trace) + rate, _, _ := sampler.GetSampleRate(trace) + keep := sampler.MakeSamplingDecision(rate, trace) + assert.Equal(t, d.ExpectedRate, rate, d.Rules) // because keep depends on sampling rate, we can only test expectedKeep when it should be false if !d.ExpectedKeep { @@ -2065,7 +2077,7 @@ func TestRegexpRules(t *testing.T) { trace.AddSpan(span) } - rate, _, _, _ := sampler.GetSampleRate(trace) + rate, _, _ := sampler.GetSampleRate(trace) assert.Equal(t, d.rate, rate, d) }) } @@ -2136,7 +2148,9 @@ func TestRulesWithDeterministicSampler(t *testing.T) { } sampler.Start() - rate, keep, reason, key := sampler.GetSampleRate(trace) + rate, reason, key := sampler.GetSampleRate(trace) + keep := sampler.MakeSamplingDecision(rate, trace) + assert.Equal(t, "", key) assert.Equal(t, d.ExpectedRate, rate, d.Rules) @@ -2850,7 +2864,7 @@ func TestRulesRootSpanContext(t *testing.T) { spans := trace.GetSpans() assert.Len(t, spans, len(d.Spans), "should have the same number of spans as input") - rate, _, reason, key := sampler.GetSampleRate(trace) + rate, reason, key := sampler.GetSampleRate(trace) assert.Equal(t, "", key) assert.Equal(t, d.ExpectedRate, rate, d.Rules) diff --git a/sample/sample.go b/sample/sample.go index 4ddbad1d01..4ab85fd85d 100644 --- a/sample/sample.go +++ b/sample/sample.go @@ -2,6 +2,7 @@ package sample import ( "fmt" + "math/rand" "os" "strings" @@ -13,7 +14,8 @@ import ( ) type Sampler interface { - GetSampleRate(trace *types.Trace) (rate uint, keep bool, reason string, key string) + GetSampleRate(trace *types.Trace) (rate uint, reason string, key string) + MakeSamplingDecision(rate uint, trace *types.Trace) bool Start() error } @@ -109,3 +111,7 @@ func getMetricType(name string) string { } return "gauge" } + +func makeSamplingDecision(rate uint) bool { + return rand.Intn(int(rate)) == 0 +} diff --git a/sample/totalthroughput.go b/sample/totalthroughput.go index 3475d29b5b..716be6e0d2 100644 --- a/sample/totalthroughput.go +++ b/sample/totalthroughput.go @@ -1,7 +1,6 @@ package sample import ( - "math/rand" "time" dynsampler "github.com/honeycombio/dynsampler-go" @@ -12,6 +11,8 @@ import ( "github.com/honeycombio/refinery/types" ) +var _ Sampler = (*TotalThroughputSampler)(nil) + type TotalThroughputSampler struct { Config *config.TotalThroughputSamplerConfig Logger logger.Logger @@ -80,26 +81,22 @@ func (d *TotalThroughputSampler) SetClusterSize(size int) { } } -func (d *TotalThroughputSampler) GetSampleRate(trace *types.Trace) (rate uint, keep bool, reason string, key string) { +func (d *TotalThroughputSampler) GetSampleRate(trace *types.Trace) (rate uint, reason string, key string) { key = d.key.build(trace) count := int(trace.DescendantCount()) + rate = uint(d.dynsampler.GetSampleRateMulti(key, count)) if rate < 1 { // protect against dynsampler being broken even though it shouldn't be rate = 1 } - shouldKeep := rand.Intn(int(rate)) == 0 + d.Logger.Debug().WithFields(map[string]interface{}{ "sample_key": key, "sample_rate": rate, - "sample_keep": shouldKeep, "trace_id": trace.TraceID, "span_count": count, - }).Logf("got sample rate and decision") - if shouldKeep { - d.Metrics.Increment(d.prefix + "num_kept") - } else { - d.Metrics.Increment(d.prefix + "num_dropped") - } + }).Logf("got sample rate") + d.Metrics.Histogram(d.prefix+"sample_rate", float64(rate)) for name, val := range d.dynsampler.GetMetrics(d.prefix) { switch getMetricType(name) { @@ -111,5 +108,20 @@ func (d *TotalThroughputSampler) GetSampleRate(trace *types.Trace) (rate uint, k d.Metrics.Gauge(name, val) } } - return rate, shouldKeep, "totalthroughput", key + return rate, "totalthroughput", key +} + +func (d *TotalThroughputSampler) MakeSamplingDecision(rate uint, trace *types.Trace) bool { + shouldKeep := makeSamplingDecision(rate) + if shouldKeep { + d.Metrics.Increment(d.prefix + "num_kept") + } else { + d.Metrics.Increment(d.prefix + "num_dropped") + } + d.Logger.Debug().WithFields(map[string]interface{}{ + "sample_rate": rate, + "sample_keep": shouldKeep, + "trace_id": trace.TraceID, + }).Logf("got sample decision") + return shouldKeep } diff --git a/sample/windowed_throughput.go b/sample/windowed_throughput.go index 1c1e1cd542..d02c5e02bb 100644 --- a/sample/windowed_throughput.go +++ b/sample/windowed_throughput.go @@ -1,7 +1,6 @@ package sample import ( - "math/rand" "time" dynsampler "github.com/honeycombio/dynsampler-go" @@ -12,6 +11,8 @@ import ( "github.com/honeycombio/refinery/types" ) +var _ Sampler = (*WindowedThroughputSampler)(nil) + type WindowedThroughputSampler struct { Config *config.WindowedThroughputSamplerConfig Logger logger.Logger @@ -76,27 +77,23 @@ func (d *WindowedThroughputSampler) SetClusterSize(size int) { } } -func (d *WindowedThroughputSampler) GetSampleRate(trace *types.Trace) (rate uint, keep bool, reason string, key string) { +func (d *WindowedThroughputSampler) GetSampleRate(trace *types.Trace) (rate uint, reason string, key string) { key = d.key.build(trace) count := int(trace.DescendantCount()) + rate = uint(d.dynsampler.GetSampleRateMulti(key, count)) if rate < 1 { // protect against dynsampler being broken even though it shouldn't be rate = 1 } - shouldKeep := rand.Intn(int(rate)) == 0 d.Logger.Debug().WithFields(map[string]interface{}{ "sample_key": key, "sample_rate": rate, - "sample_keep": shouldKeep, "trace_id": trace.TraceID, "span_count": count, - }).Logf("got sample rate and decision") - if shouldKeep { - d.Metrics.Increment(d.prefix + "num_kept") - } else { - d.Metrics.Increment(d.prefix + "num_dropped") - } + }).Logf("got sample rate") + d.Metrics.Histogram(d.prefix+"sample_rate", float64(rate)) + for name, val := range d.dynsampler.GetMetrics(d.prefix) { switch getMetricType(name) { case "counter": @@ -107,5 +104,20 @@ func (d *WindowedThroughputSampler) GetSampleRate(trace *types.Trace) (rate uint d.Metrics.Gauge(name, val) } } - return rate, shouldKeep, "Windowedthroughput", key + return rate, "Windowedthroughput", key +} + +func (d *WindowedThroughputSampler) MakeSamplingDecision(rate uint, trace *types.Trace) bool { + keep := makeSamplingDecision(rate) + if keep { + d.Metrics.Increment(d.prefix + "num_kept") + } else { + d.Metrics.Increment(d.prefix + "num_dropped") + } + + d.Logger.Debug().WithFields(map[string]interface{}{ + "sample_rate": rate, + "trace_id": trace.TraceID, + }).Logf("got sample decision") + return keep } diff --git a/sample/windowed_throughput_test.go b/sample/windowed_throughput_test.go index 9fb7d9894b..64c0b3e228 100644 --- a/sample/windowed_throughput_test.go +++ b/sample/windowed_throughput_test.go @@ -39,7 +39,7 @@ func TestWindowedThroughputAddSampleRateKeyToTrace(t *testing.T) { }) } sampler.Start() - rate, _, reason, key := sampler.GetSampleRate(trace) + rate, reason, key := sampler.GetSampleRate(trace) spans := trace.GetSpans()