Skip to content

Commit 28632fc

Browse files
author
Andrey Zelenchuk
committed
Fix losing initial payload because of the race.
1 parent e37937e commit 28632fc

File tree

1 file changed

+19
-26
lines changed

1 file changed

+19
-26
lines changed

channels_graphql_ws/graphql_ws_consumer.py

Lines changed: 19 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -713,9 +713,6 @@ async def _register_subscription(
713713
# `_sids_by_group` without any locks.
714714
self._assert_thread()
715715

716-
# The subject we will trigger on the `broadcast` message.
717-
trigger = rx.subjects.Subject()
718-
719716
# The subscription notification queue.
720717
queue_size = notification_queue_limit
721718
if not queue_size or queue_size < 0:
@@ -728,56 +725,41 @@ async def _register_subscription(
728725

729726
# Start an endless task which listens the `notification_queue`
730727
# and invokes subscription "resolver" on new notifications.
731-
async def notifier():
728+
async def notifier(observer: rx.Observer):
732729
"""Watch the notification queue and notify clients."""
733730

734731
# Assert we run in a proper thread.
735732
self._assert_thread()
736-
737-
# Dirty hack to partially workaround the race between:
738-
# 1) call to `result.subscribe` in `_on_gql_start`; and
739-
# 2) call to `trigger.on_next` below in this function.
740-
# The first call must be earlier. Otherwise, first one or more notifications
741-
# may be lost.
742-
await asyncio.sleep(1)
743-
744733
while True:
745734
serialized_payload = await notification_queue.get()
746735

747736
# Run a subscription's `publish` method (invoked by the
748-
# `trigger.on_next` function) within the threadpool used
737+
# `observer.on_next` function) within the threadpool used
749738
# for processing other GraphQL resolver functions.
750739
# NOTE: it is important to run the deserialization
751740
# in the worker thread as well.
752741
def workload():
753742
try:
754743
payload = Serializer.deserialize(serialized_payload)
755744
except Exception as ex: # pylint: disable=broad-except
756-
trigger.on_error(f"Cannot deserialize payload. {ex}")
745+
observer.on_error(f"Cannot deserialize payload. {ex}")
757746
else:
758-
trigger.on_next(payload)
747+
observer.on_next(payload)
759748

760749
await self._run_in_worker(workload)
761750

762751
# Message processed. This allows `Queue.join` to work.
763752
notification_queue.task_done()
764753

765-
# Enqueue the `publish` method execution. But do not notify
766-
# clients when `publish` returns `SKIP`.
767-
stream = trigger.map(publish_callback).filter( # pylint: disable=no-member
768-
lambda publish_returned: publish_returned is not self.SKIP
769-
)
770-
754+
def push_payloads(observer: rx.Observer):
771755
# Start listening for broadcasts (subscribe to the Channels
772756
# groups), spawn the notification processing task and put
773757
# subscription information into the registry.
774758
# NOTE: Update of `_sids_by_group` & `_subscriptions` must be
775759
# atomic i.e. without `awaits` in between.
776-
waitlist = []
777760
for group in groups:
778761
self._sids_by_group.setdefault(group, []).append(operation_id)
779-
waitlist.append(self._channel_layer.group_add(group, self.channel_name))
780-
notifier_task = self._spawn_background_task(notifier())
762+
notifier_task = self._spawn_background_task(notifier(observer))
781763
self._subscriptions[operation_id] = self._SubInf(
782764
groups=groups,
783765
sid=operation_id,
@@ -786,9 +768,20 @@ def workload():
786768
notifier_task=notifier_task,
787769
)
788770

789-
await asyncio.wait(waitlist)
771+
await asyncio.wait(
772+
[
773+
self._channel_layer.group_add(group, self.channel_name)
774+
for group in groups
775+
]
776+
)
790777

791-
return stream
778+
# Enqueue the `publish` method execution. But do not notify
779+
# clients when `publish` returns `SKIP`.
780+
return (
781+
rx.Observable.create(push_payloads) # pylint: disable=no-member
782+
.map(publish_callback)
783+
.filter(lambda publish_returned: publish_returned is not self.SKIP)
784+
)
792785

793786
async def _on_gql_stop(self, operation_id):
794787
"""Process the STOP message.

0 commit comments

Comments
 (0)