Skip to content

Commit 9a40c60

Browse files
authored
fix: data races in some situations (#476)
1 parent 17169dc commit 9a40c60

File tree

7 files changed

+186
-49
lines changed

7 files changed

+186
-49
lines changed

browser_context.go

+13-9
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ import (
99
"strings"
1010
"sync"
1111

12+
"github.com/playwright-community/playwright-go/internal/safe"
1213
"golang.org/x/exp/slices"
1314
)
1415

@@ -23,7 +24,7 @@ type browserContextImpl struct {
2324
browser *browserImpl
2425
serviceWorkers []Worker
2526
backgroundPages []Page
26-
bindings map[string]BindingCallFunction
27+
bindings *safe.SyncMap[string, BindingCallFunction]
2728
tracing *tracingImpl
2829
request *apiRequestContextImpl
2930
harRecorders map[string]harRecordingMetadata
@@ -240,18 +241,21 @@ func (b *browserContextImpl) ExposeBinding(name string, binding BindingCallFunct
240241
needsHandle = handle[0]
241242
}
242243
for _, page := range b.Pages() {
243-
if _, ok := page.(*pageImpl).bindings[name]; ok {
244+
if _, ok := page.(*pageImpl).bindings.Load(name); ok {
244245
return fmt.Errorf("Function '%s' has been already registered in one of the pages", name)
245246
}
246247
}
247-
if _, ok := b.bindings[name]; ok {
248+
if _, ok := b.bindings.Load(name); ok {
248249
return fmt.Errorf("Function '%s' has been already registered", name)
249250
}
250-
b.bindings[name] = binding
251251
_, err := b.channel.Send("exposeBinding", map[string]interface{}{
252252
"name": name,
253253
"needsHandle": needsHandle,
254254
})
255+
if err != nil {
256+
return err
257+
}
258+
b.bindings.Store(name, binding)
255259
return err
256260
}
257261

@@ -533,11 +537,11 @@ func (b *browserContextImpl) StorageState(paths ...string) (*StorageState, error
533537
}
534538

535539
func (b *browserContextImpl) onBinding(binding *bindingCallImpl) {
536-
function := b.bindings[binding.initializer["name"].(string)]
537-
if function == nil {
540+
function, ok := b.bindings.Load(binding.initializer["name"].(string))
541+
if !ok || function == nil {
538542
return
539543
}
540-
go binding.Call(function)
544+
binding.Call(function)
541545
}
542546

543547
func (b *browserContextImpl) onClose() {
@@ -740,7 +744,7 @@ func newBrowserContext(parent *channelOwner, objectType string, guid string, ini
740744
pages: make([]Page, 0),
741745
backgroundPages: make([]Page, 0),
742746
routes: make([]*routeHandlerEntry, 0),
743-
bindings: make(map[string]BindingCallFunction),
747+
bindings: safe.NewSyncMap[string, BindingCallFunction](),
744748
harRecorders: make(map[string]harRecordingMetadata),
745749
closed: make(chan struct{}, 1),
746750
harRouters: make([]*harRouter, 0),
@@ -754,7 +758,7 @@ func newBrowserContext(parent *channelOwner, objectType string, guid string, ini
754758
bt.request = fromChannel(initializer["requestContext"]).(*apiRequestContextImpl)
755759
bt.clock = newClock(bt)
756760
bt.channel.On("bindingCall", func(params map[string]interface{}) {
757-
bt.onBinding(fromChannel(params["binding"]).(*bindingCallImpl))
761+
go bt.onBinding(fromChannel(params["binding"]).(*bindingCallImpl))
758762
})
759763

760764
bt.channel.On("close", bt.onClose)

channel_owner.go

+2-2
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ func (c *channelOwner) dispose(reason ...string) {
2323
if c.parent != nil {
2424
delete(c.parent.objects, c.guid)
2525
}
26-
delete(c.connection.objects, c.guid)
26+
c.connection.objects.Delete(c.guid)
2727
if len(reason) > 0 {
2828
c.wasCollected = reason[0] == "gc"
2929
}
@@ -89,7 +89,7 @@ func (c *channelOwner) createChannelOwner(self interface{}, parent *channelOwner
8989
c.parent.objects[guid] = c
9090
}
9191
if c.connection != nil {
92-
c.connection.objects[guid] = c
92+
c.connection.objects.Store(guid, c)
9393
}
9494
c.channel = newChannel(c, self)
9595
c.eventToSubscriptionMapping = map[string]string{}

connection.go

+12-10
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ import (
1212
"time"
1313

1414
"github.com/go-stack/stack"
15+
"github.com/playwright-community/playwright-go/internal/safe"
1516
)
1617

1718
var (
@@ -27,10 +28,10 @@ type result struct {
2728
type connection struct {
2829
transport transport
2930
apiZone sync.Map
30-
objects map[string]*channelOwner
31+
objects *safe.SyncMap[string, *channelOwner]
3132
lastID atomic.Uint32
3233
rootObject *rootChannelOwner
33-
callbacks sync.Map
34+
callbacks *safe.SyncMap[uint32, *protocolCallback]
3435
afterClose func()
3536
onClose func() error
3637
isRemote bool
@@ -97,21 +98,21 @@ func (c *connection) Dispatch(msg *message) {
9798
method := msg.Method
9899
if msg.ID != 0 {
99100
cb, _ := c.callbacks.LoadAndDelete(uint32(msg.ID))
100-
if cb.(*protocolCallback).noReply {
101+
if cb.noReply {
101102
return
102103
}
103104
if msg.Error != nil {
104-
cb.(*protocolCallback).SetResult(result{
105+
cb.SetResult(result{
105106
Error: parseError(msg.Error.Error),
106107
})
107108
} else {
108-
cb.(*protocolCallback).SetResult(result{
109+
cb.SetResult(result{
109110
Data: c.replaceGuidsWithChannels(msg.Result),
110111
})
111112
}
112113
return
113114
}
114-
object := c.objects[msg.GUID]
115+
object, _ := c.objects.Load(msg.GUID)
115116
if method == "__create__" {
116117
c.createRemoteObject(
117118
object, msg.Params["type"].(string), msg.Params["guid"].(string), msg.Params["initializer"],
@@ -122,7 +123,7 @@ func (c *connection) Dispatch(msg *message) {
122123
return
123124
}
124125
if method == "__adopt__" {
125-
child, ok := c.objects[msg.Params["guid"].(string)]
126+
child, ok := c.objects.Load(msg.Params["guid"].(string))
126127
if !ok {
127128
return
128129
}
@@ -205,7 +206,7 @@ func (c *connection) replaceGuidsWithChannels(payload interface{}) interface{} {
205206
if v.Kind() == reflect.Map {
206207
mapV := payload.(map[string]interface{})
207208
if guid, hasGUID := mapV["guid"]; hasGUID {
208-
if channelOwner, ok := c.objects[guid.(string)]; ok {
209+
if channelOwner, ok := c.objects.Load(guid.(string)); ok {
209210
return channelOwner.channel
210211
}
211212
}
@@ -254,7 +255,7 @@ func (c *connection) sendMessageToServer(object *channelOwner, method string, pa
254255
return nil, fmt.Errorf("could not send message: %w", err)
255256
}
256257

257-
return cb.(*protocolCallback), nil
258+
return cb, nil
258259
}
259260

260261
func (c *connection) setInTracing(isTracing bool) {
@@ -327,7 +328,8 @@ func serializeCallLocation(caller stack.Call) map[string]interface{} {
327328
func newConnection(transport transport, localUtils ...*localUtilsImpl) *connection {
328329
connection := &connection{
329330
abort: make(chan struct{}, 1),
330-
objects: make(map[string]*channelOwner),
331+
callbacks: safe.NewSyncMap[uint32, *protocolCallback](),
332+
objects: safe.NewSyncMap[string, *channelOwner](),
331333
transport: transport,
332334
isRemote: false,
333335
closedError: &safeValue[error]{},

event_emitter.go

+25-18
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ type (
2323
hasInit bool
2424
}
2525
eventRegister struct {
26+
sync.Mutex
2627
listeners []listener
2728
}
2829
listener struct {
@@ -33,18 +34,15 @@ type (
3334

3435
func (e *eventEmitter) Emit(name string, payload ...interface{}) (hasListener bool) {
3536
e.eventsMutex.Lock()
36-
defer e.eventsMutex.Unlock()
3737
e.init()
3838

3939
evt, ok := e.events[name]
4040
if !ok {
41+
e.eventsMutex.Unlock()
4142
return
4243
}
43-
44-
hasListener = evt.count() > 0
45-
46-
evt.callHandlers(payload...)
47-
return
44+
e.eventsMutex.Unlock()
45+
return evt.callHandlers(payload...) > 0
4846
}
4947

5048
func (e *eventEmitter) Once(name string, handler interface{}) {
@@ -60,10 +58,11 @@ func (e *eventEmitter) RemoveListener(name string, handler interface{}) {
6058
defer e.eventsMutex.Unlock()
6159
e.init()
6260

63-
if _, ok := e.events[name]; !ok {
64-
return
61+
if evt, ok := e.events[name]; ok {
62+
evt.Lock()
63+
defer evt.Unlock()
64+
evt.removeHandler(handler)
6565
}
66-
e.events[name].removeHandler(handler)
6766
}
6867

6968
// ListenerCount count the listeners by name, count all if name is empty
@@ -90,6 +89,7 @@ func (e *eventEmitter) ListenerCount(name string) int {
9089

9190
func (e *eventEmitter) addEvent(name string, handler interface{}, once bool) {
9291
e.eventsMutex.Lock()
92+
defer e.eventsMutex.Unlock()
9393
e.init()
9494

9595
if _, ok := e.events[name]; !ok {
@@ -98,7 +98,6 @@ func (e *eventEmitter) addEvent(name string, handler interface{}, once bool) {
9898
}
9999
}
100100
e.events[name].addHandler(handler, once)
101-
e.eventsMutex.Unlock()
102101
}
103102

104103
func (e *eventEmitter) init() {
@@ -108,23 +107,27 @@ func (e *eventEmitter) init() {
108107
}
109108
}
110109

111-
func (e *eventRegister) addHandler(handler interface{}, once bool) {
112-
e.listeners = append(e.listeners, listener{handler: handler, once: once})
110+
func (er *eventRegister) addHandler(handler interface{}, once bool) {
111+
er.Lock()
112+
defer er.Unlock()
113+
er.listeners = append(er.listeners, listener{handler: handler, once: once})
113114
}
114115

115-
func (e *eventRegister) count() int {
116-
return len(e.listeners)
116+
func (er *eventRegister) count() int {
117+
er.Lock()
118+
defer er.Unlock()
119+
return len(er.listeners)
117120
}
118121

119122
func (e *eventRegister) removeHandler(handler interface{}) {
120123
handlerPtr := reflect.ValueOf(handler).Pointer()
121124

122-
e.listeners = slices.DeleteFunc[[]listener](e.listeners, func(l listener) bool {
125+
e.listeners = slices.DeleteFunc(e.listeners, func(l listener) bool {
123126
return reflect.ValueOf(l.handler).Pointer() == handlerPtr
124127
})
125128
}
126129

127-
func (e *eventRegister) callHandlers(payloads ...interface{}) {
130+
func (er *eventRegister) callHandlers(payloads ...interface{}) int {
128131
payloadV := make([]reflect.Value, 0)
129132

130133
for _, p := range payloads {
@@ -136,10 +139,14 @@ func (e *eventRegister) callHandlers(payloads ...interface{}) {
136139
handlerV.Call(payloadV[:int(math.Min(float64(handlerV.Type().NumIn()), float64(len(payloadV))))])
137140
}
138141

139-
for _, l := range e.listeners {
142+
er.Lock()
143+
defer er.Unlock()
144+
count := len(er.listeners)
145+
for _, l := range er.listeners {
140146
if l.once {
141-
defer e.removeHandler(l.handler)
147+
defer er.removeHandler(l.handler)
142148
}
143149
handle(l)
144150
}
151+
return count
145152
}

internal/safe/map.go

+91
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
package safe
2+
3+
import (
4+
"sync"
5+
6+
"golang.org/x/exp/maps"
7+
)
8+
9+
// SyncMap is a thread-safe map
10+
type SyncMap[K comparable, V any] struct {
11+
sync.RWMutex
12+
m map[K]V
13+
}
14+
15+
// NewSyncMap creates a new thread-safe map
16+
func NewSyncMap[K comparable, V any]() *SyncMap[K, V] {
17+
return &SyncMap[K, V]{
18+
m: make(map[K]V),
19+
}
20+
}
21+
22+
func (m *SyncMap[K, V]) Store(k K, v V) {
23+
m.Lock()
24+
defer m.Unlock()
25+
m.m[k] = v
26+
}
27+
28+
func (m *SyncMap[K, V]) Load(k K) (v V, ok bool) {
29+
m.RLock()
30+
defer m.RUnlock()
31+
v, ok = m.m[k]
32+
return
33+
}
34+
35+
// LoadOrStore returns the existing value for the key if present. Otherwise, it stores and returns the given value.
36+
func (m *SyncMap[K, V]) LoadOrStore(k K, v V) (actual V, loaded bool) {
37+
m.Lock()
38+
defer m.Unlock()
39+
actual, loaded = m.m[k]
40+
if loaded {
41+
return
42+
}
43+
m.m[k] = v
44+
return v, false
45+
}
46+
47+
// LoadAndDelete deletes the value for a key, and returns the previous value if any.
48+
func (m *SyncMap[K, V]) LoadAndDelete(k K) (v V, loaded bool) {
49+
m.Lock()
50+
defer m.Unlock()
51+
v, loaded = m.m[k]
52+
if loaded {
53+
delete(m.m, k)
54+
}
55+
return
56+
}
57+
58+
func (m *SyncMap[K, V]) Delete(k K) {
59+
m.Lock()
60+
defer m.Unlock()
61+
delete(m.m, k)
62+
}
63+
64+
func (m *SyncMap[K, V]) Clear() {
65+
m.Lock()
66+
defer m.Unlock()
67+
maps.Clear(m.m)
68+
}
69+
70+
func (m *SyncMap[K, V]) Len() int {
71+
m.RLock()
72+
defer m.RUnlock()
73+
return len(m.m)
74+
}
75+
76+
func (m *SyncMap[K, V]) Clone() map[K]V {
77+
m.RLock()
78+
defer m.RUnlock()
79+
return maps.Clone(m.m)
80+
}
81+
82+
func (m *SyncMap[K, V]) Range(f func(k K, v V) bool) {
83+
m.RLock()
84+
defer m.RUnlock()
85+
86+
for k, v := range m.m {
87+
if !f(k, v) {
88+
break
89+
}
90+
}
91+
}

0 commit comments

Comments
 (0)