// Fanout takes a slice of input, a parallelism factor, and a worker factory.
// It runs the generated workers over every element of the input and returns a
// (possibly filtered) slice of the outputs in no particular order.
//
// Parameters:
//   - input: the items to process; each item is handed to exactly one worker.
//   - parallelism: the number of worker goroutines to start. Values below 1
//     are treated as 1.
//   - workerFactory: called once per worker, on the calling goroutine, with
//     the worker number (starting at 0). It returns a function of type
//     func(T) U that processes a single input, and a cleanup function, which
//     may be nil. The cleanup function is called once for each worker, with
//     that worker's number, after the worker has processed all of its inputs
//     and before Fanout returns.
//   - predicate: if not nil, only outputs for which it returns true are added
//     to the result. It is invoked on the worker goroutines.
//
// The returned slice is never nil.
func Fanout[T, U any](input []T, parallelism int, workerFactory func(int) (worker func(T) U, cleanup func(int)), predicate func(U) bool) []U {
	// With zero workers nothing would ever drain fanoutChan, so the feeder
	// goroutine (and the final wgFans.Wait) would block forever; a negative
	// value would panic when sizing the channels. Always run one worker.
	if parallelism < 1 {
		parallelism = 1
	}

	// Each input produces at most one output, so len(input) bounds the result.
	result := make([]U, 0, len(input))

	fanoutChan := make(chan T, parallelism)
	faninChan := make(chan U, parallelism)

	// send all the inputs to the fanout channel, then close it so the
	// workers' range loops terminate
	wgFans := sync.WaitGroup{}
	wgFans.Add(1)
	go func() {
		defer wgFans.Done()
		defer close(fanoutChan)
		for i := range input {
			fanoutChan <- input[i]
		}
	}()

	// collect the outputs; this is the only goroutine that touches result
	// until wgFans.Wait returns
	wgFans.Add(1)
	go func() {
		defer wgFans.Done()
		for r := range faninChan {
			result = append(result, r)
		}
	}()

	wgWorkers := sync.WaitGroup{}
	for i := 0; i < parallelism; i++ {
		wgWorkers.Add(1)
		worker, cleanup := workerFactory(i)
		go func(i int) {
			// Deferred calls run in reverse order, so cleanup completes
			// before the worker is reported as done.
			defer wgWorkers.Done()
			if cleanup != nil {
				defer cleanup(i)
			}
			for u := range fanoutChan {
				product := worker(u)
				if predicate == nil || predicate(product) {
					faninChan <- product
				}
			}
		}(i)
	}

	// wait for the workers to finish
	wgWorkers.Wait()
	// now we can close the fanin channel and wait for the fanin goroutine to finish
	// fanout should already be done but this makes sure we don't lose track of it
	close(faninChan)
	wgFans.Wait()

	return result
}

// EasyFanout is a convenience function for when you don't need all the
// features. It takes a slice of input, a parallelism factor, and a worker
// function. It calls the worker on every element of the input with the
// specified parallelism (values below 1 are treated as 1), and returns a
// slice of the outputs in no particular order. The same worker function is
// shared by every goroutine, there is no cleanup, and nothing is filtered.
func EasyFanout[T, U any](input []T, parallelism int, worker func(T) U) []U {
	return Fanout(input, parallelism, func(int) (func(T) U, func(int)) {
		return worker, nil
	}, nil)
}
// FanoutToMap takes a slice of input, a parallelism factor, and a worker
// factory. It runs the generated workers over every element of the input, and
// returns a (possibly filtered) map of the inputs to the outputs.
//
// Parameters:
//   - input: the items to process; each item is handed to exactly one worker
//     and becomes the key for that worker's output. If the same value appears
//     more than once in input, the map keeps whichever output was collected
//     last.
//   - parallelism: the number of worker goroutines to start. Values below 1
//     are treated as 1.
//   - workerFactory: called once per worker, on the calling goroutine, with
//     the worker number (starting at 0). It returns a function of type
//     func(T) U that processes a single input, and a cleanup function, which
//     may be nil. The cleanup function is called once for each worker, with
//     that worker's number, after the worker has processed all of its inputs
//     and before FanoutToMap returns.
//   - predicate: if not nil, only outputs for which it returns true are added
//     to the map. It is invoked on the worker goroutines.
//
// The returned map is never nil.
func FanoutToMap[T comparable, U any](input []T, parallelism int, workerFactory func(int) (worker func(T) U, cleanup func(int)), predicate func(U) bool) map[T]U {
	// With zero workers nothing would ever drain fanoutChan, so the feeder
	// goroutine (and the final wgFans.Wait) would block forever; a negative
	// value would panic when sizing the channels. Always run one worker.
	if parallelism < 1 {
		parallelism = 1
	}

	// Each input produces at most one entry, so len(input) bounds the map.
	result := make(map[T]U, len(input))
	type resultPair struct {
		key T
		val U
	}

	fanoutChan := make(chan T, parallelism)
	faninChan := make(chan resultPair, parallelism)

	// send all the inputs to the fanout channel, then close it so the
	// workers' range loops terminate
	wgFans := sync.WaitGroup{}
	wgFans.Add(1)
	go func() {
		defer wgFans.Done()
		defer close(fanoutChan)
		for i := range input {
			fanoutChan <- input[i]
		}
	}()

	// collect the outputs; this is the only goroutine that touches result
	// until wgFans.Wait returns
	wgFans.Add(1)
	go func() {
		defer wgFans.Done()
		for r := range faninChan {
			result[r.key] = r.val
		}
	}()

	wgWorkers := sync.WaitGroup{}
	for i := 0; i < parallelism; i++ {
		wgWorkers.Add(1)
		worker, cleanup := workerFactory(i)
		go func(i int) {
			// Deferred calls run in reverse order, so cleanup completes
			// before the worker is reported as done.
			defer wgWorkers.Done()
			if cleanup != nil {
				defer cleanup(i)
			}
			for t := range fanoutChan {
				product := worker(t)
				if predicate == nil || predicate(product) {
					faninChan <- resultPair{t, product}
				}
			}
		}(i)
	}

	// wait for the workers to finish
	wgWorkers.Wait()
	// now we can close the fanin channel and wait for the fanin goroutine to finish
	// fanout should already be done but this makes sure we don't lose track of it
	close(faninChan)
	wgFans.Wait()

	return result
}

// EasyFanoutToMap is a convenience function for when you don't need all the
// features. It takes a slice of input, a parallelism factor, and a worker
// function. It calls the worker on every element of the input with the
// specified parallelism (values below 1 are treated as 1), and returns a map
// of the inputs to the outputs. The same worker function is shared by every
// goroutine, there is no cleanup, and nothing is filtered.
func EasyFanoutToMap[T comparable, U any](input []T, parallelism int, worker func(T) U) map[T]U {
	return FanoutToMap(input, parallelism, func(int) (func(T) U, func(int)) {
		return worker, nil
	}, nil)
}
// FanoutChunksToMap takes a slice of input, a chunk size, a maximum
// parallelism factor, and a worker factory. It splits the input into
// consecutive chunks, runs the generated workers over the chunks, and returns
// a (possibly filtered) map built from the maps the workers produce.
//
// Parameters:
//   - input: the items to process. Workers receive sub-slices of input (not
//     copies), each holding at most chunkSize items.
//   - chunkSize: the number of items per chunk; values below 1 are treated
//     as 1. The final chunk may be shorter.
//   - maxParallelism: the maximum number of workers to run in parallel. The
//     actual number of workers is the minimum of maxParallelism and the
//     number of chunks, and is never less than 1.
//   - workerFactory: called once per worker, on the calling goroutine, with
//     the worker number (starting at 0). It returns a function that turns one
//     chunk into a map of results, and a cleanup function, which may be nil.
//     The cleanup function is called once for each worker, with that worker's
//     number, after the worker has processed all of its chunks and before
//     FanoutChunksToMap returns.
//   - predicate: if not nil, only entries whose value satisfies it are added
//     to the result. It is invoked on the worker goroutines.
//
// If several workers return the same key, the map keeps whichever entry was
// collected last. The returned map is never nil.
func FanoutChunksToMap[T comparable, U any](input []T, chunkSize int, maxParallelism int, workerFactory func(int) (worker func([]T) map[T]U, cleanup func(int)), predicate func(U) bool) map[T]U {
	result := make(map[T]U)

	if chunkSize <= 0 {
		chunkSize = 1
	}

	type resultPair struct {
		key T
		val U
	}

	// Count chunks with a ceiling division so a trailing partial chunk is
	// included (5 items in chunks of 2 is 3 chunks, not 2).
	numChunks := len(input) / chunkSize
	if len(input)%chunkSize != 0 {
		numChunks++
	}
	// Never drop below one worker: with zero workers nothing would drain
	// fanoutChan and the function would block forever, and a negative value
	// would panic when sizing the channels.
	parallelism := max(min(maxParallelism, numChunks), 1)
	fanoutChan := make(chan []T, parallelism)
	faninChan := make(chan resultPair, parallelism)

	// send all the chunks to the fanout channel, then close it so the
	// workers' range loops terminate
	wgFans := sync.WaitGroup{}
	wgFans.Add(1)
	go func() {
		defer wgFans.Done()
		defer close(fanoutChan)
		for i := 0; i < len(input); i += chunkSize {
			end := min(i+chunkSize, len(input))
			fanoutChan <- input[i:end]
		}
	}()

	// collect the entries; this is the only goroutine that touches result
	// until wgFans.Wait returns
	wgFans.Add(1)
	go func() {
		defer wgFans.Done()
		for r := range faninChan {
			result[r.key] = r.val
		}
	}()

	wgWorkers := sync.WaitGroup{}
	for i := 0; i < parallelism; i++ {
		wgWorkers.Add(1)
		worker, cleanup := workerFactory(i)
		go func(i int) {
			// Deferred calls run in reverse order, so cleanup completes
			// before the worker is reported as done.
			defer wgWorkers.Done()
			if cleanup != nil {
				defer cleanup(i)
			}
			for u := range fanoutChan {
				products := worker(u)
				for key, product := range products {
					if predicate == nil || predicate(product) {
						faninChan <- resultPair{key: key, val: product}
					}
				}
			}
		}(i)
	}

	// wait for the workers to finish
	wgWorkers.Wait()
	// now we can close the fanin channel and wait for the fanin goroutine to finish
	// fanout should already be done but this makes sure we don't lose track of it
	close(faninChan)
	wgFans.Wait()

	return result
}
workerFactory := func(i int) (func(int) int, func(int)) { + worker := func(i int) int { + return i * 2 + } + return worker, nil + } + + result := Fanout(input, parallelism, workerFactory, nil) + assert.ElementsMatch(t, []int{2, 4, 6, 8, 10}, result) +} + +func TestFanoutWithPredicate(t *testing.T) { + input := []int{1, 2, 3, 4, 5} + parallelism := 3 + workerFactory := func(i int) (func(int) int, func(int)) { + worker := func(i int) int { + return i * 2 + } + return worker, nil + } + predicate := func(i int) bool { + return i%4 == 0 + } + + result := Fanout(input, parallelism, workerFactory, predicate) + assert.ElementsMatch(t, []int{4, 8}, result) +} + +func TestFanoutWithCleanup(t *testing.T) { + input := []int{1, 2, 3, 4, 5} + parallelism := 4 + cleanups := []int{} + mut := sync.Mutex{} + workerFactory := func(i int) (func(int) int, func(int)) { + worker := func(i int) int { + return i * 2 + } + cleanup := func(i int) { + mut.Lock() + cleanups = append(cleanups, i) + mut.Unlock() + } + return worker, cleanup + } + + result := Fanout(input, parallelism, workerFactory, nil) + assert.ElementsMatch(t, []int{2, 4, 6, 8, 10}, result) + assert.ElementsMatch(t, []int{0, 1, 2, 3}, cleanups) +} + +var expected = map[int]int{ + 1: 2, + 2: 4, + 3: 6, + 4: 8, + 5: 10, +} + +func TestFanoutMap(t *testing.T) { + input := []int{1, 2, 3, 4, 5} + parallelism := 3 + workerFactory := func(i int) (func(int) int, func(int)) { + worker := func(i int) int { + return i * 2 + } + return worker, nil + } + + result := FanoutToMap(input, parallelism, workerFactory, nil) + assert.EqualValues(t, expected, result) +} + +func TestFanoutMapWithPredicate(t *testing.T) { + input := []int{1, 2, 3, 4, 5} + parallelism := 3 + workerFactory := func(i int) (func(int) int, func(int)) { + worker := func(i int) int { + return i * 2 + } + return worker, nil + } + predicate := func(i int) bool { + return i%4 == 0 + } + + result := FanoutToMap(input, parallelism, workerFactory, predicate) + 
assert.EqualValues(t, map[int]int{2: 4, 4: 8}, result) +} + +func TestFanoutMapWithCleanup(t *testing.T) { + input := []int{1, 2, 3, 4, 5} + parallelism := 4 + cleanups := []int{} + mut := sync.Mutex{} + workerFactory := func(i int) (func(int) int, func(int)) { + worker := func(i int) int { + return i * 2 + } + cleanup := func(i int) { + mut.Lock() + cleanups = append(cleanups, i) + mut.Unlock() + } + return worker, cleanup + } + + result := FanoutToMap(input, parallelism, workerFactory, nil) + assert.EqualValues(t, expected, result) + assert.ElementsMatch(t, []int{0, 1, 2, 3}, cleanups) +} + +func TestEasyFanout(t *testing.T) { + input := []int{1, 2, 3, 4, 5} + worker := func(i int) int { + return i * 2 + } + + result := EasyFanout(input, 3, worker) + assert.ElementsMatch(t, []int{2, 4, 6, 8, 10}, result) +} + +func TestEasyFanoutToMap(t *testing.T) { + input := []int{1, 2, 3, 4, 5} + worker := func(i int) int { + return i * 2 + } + + result := EasyFanoutToMap(input, 3, worker) + assert.EqualValues(t, expected, result) +} + +func BenchmarkFanoutParallelism(b *testing.B) { + parallelisms := []int{1, 3, 6, 10, 25, 100} + for _, parallelism := range parallelisms { + b.Run(fmt.Sprintf("parallelism%02d", parallelism), func(b *testing.B) { + + input := make([]int, b.N) + for i := range input { + input[i] = i + } + + workerFactory := func(i int) (func(int) string, func(int)) { + worker := func(i int) string { + h := sha256.Sum256(([]byte(fmt.Sprintf("%d", i)))) + time.Sleep(1 * time.Millisecond) + return hex.EncodeToString(h[:]) + } + cleanup := func(i int) {} + return worker, cleanup + } + b.ResetTimer() + _ = Fanout(input, parallelism, workerFactory, nil) + }) + } +} + +func BenchmarkFanoutMapParallelism(b *testing.B) { + parallelisms := []int{1, 3, 6, 10, 25, 100} + for _, parallelism := range parallelisms { + b.Run(fmt.Sprintf("parallelism%02d", parallelism), func(b *testing.B) { + + input := make([]int, b.N) + for i := range input { + input[i] = i + } + + 
workerFactory := func(i int) (func(int) string, func(int)) { + worker := func(i int) string { + h := sha256.Sum256(([]byte(fmt.Sprintf("%d", i)))) + time.Sleep(1 * time.Millisecond) + return hex.EncodeToString(h[:]) + } + cleanup := func(i int) {} + return worker, cleanup + } + b.ResetTimer() + _ = FanoutToMap(input, parallelism, workerFactory, nil) + }) + } +} diff --git a/generics/setttl.go b/generics/setttl.go new file mode 100644 index 0000000000..f39d6a65a4 --- /dev/null +++ b/generics/setttl.go @@ -0,0 +1,88 @@ +package generics + +import ( + "cmp" + "sort" + "sync" + "time" + + "github.com/jonboulle/clockwork" + "golang.org/x/exp/maps" +) + +// SetWithTTL is a unique set of items with a TTL (time to live) for each item. +// After the TTL expires, the item is automatically removed from the set. +// It is safe for concurrent use. +type SetWithTTL[T cmp.Ordered] struct { + Items map[T]time.Time + TTL time.Duration + Clock clockwork.Clock + mut sync.RWMutex +} + +// NewSetWithTTL returns a new SetWithTTL with elements `es` and a TTL of `ttl`. +func NewSetWithTTL[T cmp.Ordered](ttl time.Duration, es ...T) *SetWithTTL[T] { + s := &SetWithTTL[T]{ + Items: make(map[T]time.Time, len(es)), + TTL: ttl, + Clock: clockwork.NewRealClock(), + } + s.Add(es...) + return s +} + +// Add adds elements `es` to the SetWithTTL. +func (s *SetWithTTL[T]) Add(es ...T) { + s.mut.Lock() + defer s.mut.Unlock() + for _, e := range es { + s.Items[e] = s.Clock.Now().Add(s.TTL) + } +} + +// Remove removes elements `es` from the SetWithTTL. +func (s *SetWithTTL[T]) Remove(es ...T) { + s.mut.Lock() + defer s.mut.Unlock() + for _, e := range es { + delete(s.Items, e) + } +} + +// Contains returns true if the SetWithTTL contains `e`. 
+func (s *SetWithTTL[T]) Contains(e T) bool { + s.mut.RLock() + item, ok := s.Items[e] + s.mut.RUnlock() + if !ok { + return false + } + return item.After(time.Now()) +} + +func (s *SetWithTTL[T]) cleanup() int { + s.mut.Lock() + defer s.mut.Unlock() + maps.DeleteFunc(s.Items, func(k T, exp time.Time) bool { + return exp.Before(s.Clock.Now()) + }) + return len(s.Items) +} + +// Members returns the unique elements of the SetWithTTL in sorted order. +// It also removes any items that have expired. +func (s *SetWithTTL[T]) Members() []T { + s.cleanup() + s.mut.RLock() + members := maps.Keys(s.Items) + s.mut.RUnlock() + sort.Slice(members, func(i, j int) bool { + return cmp.Less(members[i], members[j]) + }) + return members +} + +// Length returns the number of items in the SetWithTTL after removing any expired items. +func (s *SetWithTTL[T]) Length() int { + return s.cleanup() +} diff --git a/generics/setttl_test.go b/generics/setttl_test.go new file mode 100644 index 0000000000..1fca1316d6 --- /dev/null +++ b/generics/setttl_test.go @@ -0,0 +1,26 @@ +package generics + +import ( + "testing" + "time" + + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/assert" +) + +func TestSetTTLBasics(t *testing.T) { + s := NewSetWithTTL(100*time.Millisecond, "a", "b", "b") + fakeclock := clockwork.NewFakeClock() + s.Clock = fakeclock + assert.Equal(t, 2, s.Length()) + fakeclock.Advance(50 * time.Millisecond) + s.Add("c") + assert.Equal(t, 3, s.Length()) + assert.Equal(t, s.Members(), []string{"a", "b", "c"}) + fakeclock.Advance(60 * time.Millisecond) + assert.Equal(t, 1, s.Length()) + assert.Equal(t, s.Members(), []string{"c"}) + fakeclock.Advance(100 * time.Millisecond) + assert.Equal(t, 0, s.Length()) + assert.Equal(t, s.Members(), []string{}) +} diff --git a/go.mod b/go.mod index 2d856aaf60..091b0ecb71 100644 --- a/go.mod +++ b/go.mod @@ -19,6 +19,7 @@ require ( github.com/honeycombio/husky v0.30.0 github.com/honeycombio/libhoney-go v1.23.1 
github.com/jessevdk/go-flags v1.5.0 + github.com/jonboulle/clockwork v0.4.0 github.com/json-iterator/go v1.1.12 github.com/klauspost/compress v1.17.9 github.com/panmari/cuckoofilter v1.0.3 diff --git a/go.sum b/go.sum index a1dc29334a..0261222db5 100644 --- a/go.sum +++ b/go.sum @@ -70,6 +70,8 @@ github.com/honeycombio/opentelemetry-proto-go/otlp v0.19.0-compat h1:fMpIzVAl5C2 github.com/honeycombio/opentelemetry-proto-go/otlp v0.19.0-compat/go.mod h1:mC2aK20Z/exugKpqCgcpwEadiS0im8K6mZsD4Is/hCY= github.com/jessevdk/go-flags v1.5.0 h1:1jKYvbxEjfUl0fmqTCOfonvskHHXMjBySTLW4y9LFvc= github.com/jessevdk/go-flags v1.5.0/go.mod h1:Fw0T6WPc1dYxT4mKEZRfG5kJhaTDP9pj1c2EWnYs/m4= +github.com/jonboulle/clockwork v0.4.0 h1:p4Cf1aMWXnXAUh8lVfewRBx1zaTSYKrKMF2g3ST4RZ4= +github.com/jonboulle/clockwork v0.4.0/go.mod h1:xgRqUGwRcjKCO1vbZUEtSLrqKoPSsUpK7fnezOII0kc= github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= github.com/klauspost/compress v1.17.3/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=