Skip to content

Commit

Permalink
Admin actions to run user helm charts (#1471)
Browse files Browse the repository at this point in the history
Adds admin actions that will let us upgrade user
helm charts. Initially this is the bootstrap user
chart and the reset home directory chart, but
could be expanded in future.
  • Loading branch information
michaeljcollinsuk authored Feb 24, 2025
1 parent 2ab77a0 commit c5ad8fc
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 12 deletions.
38 changes: 37 additions & 1 deletion controlpanel/api/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,16 @@
import csv

# Third-party
from django.contrib import admin
from django.conf import settings
from django.contrib import admin, messages
from django.http import HttpResponse
from django.utils import timezone
from django.utils.translation import ngettext
from simple_history.admin import SimpleHistoryAdmin

# First-party/Local
from controlpanel.api.models import App, Feedback, IPAllowlist, S3Bucket, ToolDeployment, User
from controlpanel.api.tasks.user import upgrade_user_helm_chart


def make_migration_pending(modeladmin, request, queryset):
Expand Down Expand Up @@ -56,6 +59,39 @@ class UserAdmin(admin.ModelAdmin):
"email",
"auth0_id",
)
actions = [
"upgrade_bootstrap_user_helm_chart",
"upgrade_provision_user_helm_chart",
"reset_home_directory",
]

def _upgrade_helm_chart(self, request, queryset, chart_name):
total = 0
for user in queryset:
if not user.is_iam_user:
continue

upgrade_user_helm_chart.delay(user.username, chart_name)
total += 1

self.message_user(
request,
ngettext(
f"{chart_name} helm chart updated for %d user.",
f"{chart_name} helm chart updated for %d users.",
total,
)
% total,
messages.SUCCESS,
)

@admin.action(description="Upgrade bootstrap-user helm chart")
def upgrade_bootstrap_user_helm_chart(self, request, queryset):
self._upgrade_helm_chart(request, queryset, f"{settings.HELM_REPO}/bootstrap-user")

@admin.action(description="Reset users home directory")
def reset_home_directory(self, request, queryset):
self._upgrade_helm_chart(request, queryset, f"{settings.HELM_REPO}/reset-user-efs-home")


class IPAllowlistAdmin(SimpleHistoryAdmin):
Expand Down
12 changes: 12 additions & 0 deletions controlpanel/api/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,18 @@ def __init__(self, user):
def _init_aws_services(self):
self.aws_role_service = self.create_aws_service(AWSRole)

def get_helm_chart(self, chart_name):
"""
Lookup helm chart dictionary by name. This is fine for now as there are not many charts,
but if the number of charts grows, we should consider refactoring to store them in a
way that allows a more efficient lookup.
"""
for chart_type, charts in self.user_helm_charts.items():
for chart in charts:
if chart["chart"] == chart_name:
return chart
return None

@property
def user_helm_charts(self):
# The full list of the charts required for a user under different situations
Expand Down
13 changes: 2 additions & 11 deletions controlpanel/api/tasks/tools.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,9 @@
# Third-party
from celery import shared_task
from django.apps import apps

# First-party/Local
from controlpanel.api import cluster, helm


def _get_model(model_name):
"""
This is used to avoid a circular import when calling tasks from within models. I feel like this
is the best worst option. For futher reading on this issue and the lack of an ideal solution:
https://stackoverflow.com/questions/26379026/resolving-circular-imports-in-celery-and-django
"""
return apps.get_model("api", model_name)
from controlpanel.api import helm
from controlpanel.utils import _get_model


# TODO do we need to use acks_late? try without first
Expand Down
18 changes: 18 additions & 0 deletions controlpanel/api/tasks/user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
# Third-party
from celery import shared_task

# First-party/Local
from controlpanel.api import cluster
from controlpanel.utils import _get_model


@shared_task(acks_on_failure_or_timeout=False)
def upgrade_user_helm_chart(username, chart_name):
User = _get_model("User")
try:
user = User.objects.get(username=username)
except User.DoesNotExist:
return
cluster_user = cluster.User(user)
chart = cluster_user.get_helm_chart(chart_name)
cluster_user._run_helm_install_command(chart)
10 changes: 10 additions & 0 deletions controlpanel/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from channels.exceptions import StopConsumer
from channels.generic.http import AsyncHttpConsumer
from channels.layers import get_channel_layer
from django.apps import apps
from django.conf import settings
from django.template.defaultfilters import slugify
from nacl import encoding, public
Expand Down Expand Up @@ -246,3 +247,12 @@ def start_background_task(task, message):
**message,
},
)


def _get_model(model_name):
"""
This is used to avoid a circular import when calling tasks from within models. I feel like this
is the best worst option. For futher reading on this issue and the lack of an ideal solution:
https://stackoverflow.com/questions/26379026/resolving-circular-imports-in-celery-and-django
"""
return apps.get_model("api", model_name)
46 changes: 46 additions & 0 deletions tests/api/tasks/test_user.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
# Standard library
from unittest.mock import MagicMock, patch

# Third-party
import pytest

# First-party/Local
from controlpanel.api.models import User
from controlpanel.api.tasks.user import upgrade_user_helm_chart


@pytest.fixture()
def mock_get_user_model():
with patch("controlpanel.api.tasks.user._get_model") as mock_get_model:
mock_get_model.return_value = User
yield mock_get_model


@patch("controlpanel.api.tasks.user.cluster.User")
def test_upgrade_user_helm_chart_user_does_not_exist(mock_cluster_user, mock_get_user_model):
with patch.object(User.objects, "get") as mock_get:
mock_get.side_effect = User.DoesNotExist
upgrade_user_helm_chart("nonexistent_user", "chart_name")

mock_get_user_model.assert_called_once_with("User")
mock_cluster_user.assert_not_called()


@patch("controlpanel.api.tasks.user.cluster.User")
def test_upgrade_user_helm_chart_success(mock_cluster_user, mock_get_user_model):
cluster_user_instance = MagicMock()
mock_cluster_user.return_value = cluster_user_instance

chart = MagicMock()
cluster_user_instance.get_helm_chart.return_value = chart

user_instance = MagicMock()

with patch.object(User.objects, "get") as mock_get:
mock_get.return_value = user_instance
upgrade_user_helm_chart("existing_user", "chart_name")

mock_get_user_model.assert_called_once_with("User")
mock_cluster_user.assert_called_once_with(user_instance)
cluster_user_instance.get_helm_chart.assert_called_once_with("chart_name")
cluster_user_instance._run_helm_install_command.assert_called_once_with(chart)
1 change: 1 addition & 0 deletions tests/api/views/test_user.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

# First-party/Local
from controlpanel.api.models import User
from controlpanel.api.tasks.user import upgrade_user_helm_chart


@pytest.fixture(autouse=True)
Expand Down

0 comments on commit c5ad8fc

Please sign in to comment.