Skip to content

Commit

Permalink
core: revert backchannel only filtering (cherry-pick #10455) (#10457)
Browse files Browse the repository at this point in the history
core: revert backchannel only filtering (#10455)

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
Co-authored-by: Jens L <jens@goauthentik.io>
  • Loading branch information
gcp-cherry-pick-bot[bot] and BeryJu authored Jul 11, 2024
1 parent f98204e commit db1d091
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 8 deletions.
13 changes: 10 additions & 3 deletions authentik/lib/sync/outgoing/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from django.core.paginator import Paginator
from django.db.models import Model
from django.db.models.query import Q
from django.db.models.signals import m2m_changed, post_save, pre_delete

from authentik.core.models import Group, User
Expand Down Expand Up @@ -34,7 +35,9 @@ def post_save_provider(sender: type[Model], instance: OutgoingSyncProvider, crea

def model_post_save(sender: type[Model], instance: User | Group, created: bool, **_):
"""Post save handler"""
if not provider_type.objects.filter(backchannel_application__isnull=False).exists():
if not provider_type.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
).exists():
return
task_sync_direct.delay(class_to_path(instance.__class__), instance.pk, Direction.add.value)

Expand All @@ -43,7 +46,9 @@ def model_post_save(sender: type[Model], instance: User | Group, created: bool,

def model_pre_delete(sender: type[Model], instance: User | Group, **_):
"""Pre-delete handler"""
if not provider_type.objects.filter(backchannel_application__isnull=False).exists():
if not provider_type.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
).exists():
return
task_sync_direct.delay(
class_to_path(instance.__class__), instance.pk, Direction.remove.value
Expand All @@ -58,7 +63,9 @@ def model_m2m_changed(
"""Sync group membership"""
if action not in ["post_add", "post_remove"]:
return
if not provider_type.objects.filter(backchannel_application__isnull=False).exists():
if not provider_type.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
).exists():
return
# reverse: instance is a Group, pk_set is a list of user pks
# non-reverse: instance is a User, pk_set is a list of groups
Expand Down
16 changes: 12 additions & 4 deletions authentik/lib/sync/outgoing/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from celery.result import allow_join_result
from django.core.paginator import Paginator
from django.db.models import Model, QuerySet
from django.db.models.query import Q
from django.utils.text import slugify
from django.utils.translation import gettext_lazy as _
from structlog.stdlib import BoundLogger, get_logger
Expand Down Expand Up @@ -37,7 +38,9 @@ def __init__(self, provider_model: type[OutgoingSyncProvider]) -> None:
self._provider_model = provider_model

def sync_all(self, single_sync: Callable[[int], None]):
for provider in self._provider_model.objects.filter(backchannel_application__isnull=False):
for provider in self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
):
self.trigger_single_task(provider, single_sync)

def trigger_single_task(self, provider: OutgoingSyncProvider, sync_task: Callable[[int], None]):
Expand All @@ -62,7 +65,8 @@ def sync_single(
provider_pk=provider_pk,
)
provider = self._provider_model.objects.filter(
pk=provider_pk, backchannel_application__isnull=False
Q(backchannel_application__isnull=False) | Q(application__isnull=False),
pk=provider_pk,
).first()
if not provider:
return
Expand Down Expand Up @@ -204,7 +208,9 @@ def sync_signal_direct(self, model: str, pk: str | int, raw_op: str):
if not instance:
return
operation = Direction(raw_op)
for provider in self._provider_model.objects.filter(backchannel_application__isnull=False):
for provider in self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
):
client = provider.client_for_model(instance.__class__)
# Check if the object is allowed within the provider's restrictions
queryset = provider.get_object_qs(instance.__class__)
Expand Down Expand Up @@ -233,7 +239,9 @@ def sync_signal_m2m(self, group_pk: str, action: str, pk_set: list[int]):
group = Group.objects.filter(pk=group_pk).first()
if not group:
return
for provider in self._provider_model.objects.filter(backchannel_application__isnull=False):
for provider in self._provider_model.objects.filter(
Q(backchannel_application__isnull=False) | Q(application__isnull=False)
):
# Check if the object is allowed within the provider's restrictions
queryset: QuerySet = provider.get_object_qs(Group)
# The queryset we get from the provider must include the instance we've got given
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@ const doGroupBy = (items: Provider[]) => groupBy(items, (item) => item.verboseNa
async function fetch(query?: string) {
const args: ProvidersAllListRequest = {
ordering: "name",
backchannel: false,
};
if (query !== undefined) {
args.search = query;
Expand Down

0 comments on commit db1d091

Please sign in to comment.