From 1f747d338c64504615e8a71ca56d462f72e63a7c Mon Sep 17 00:00:00 2001 From: Rehan Date: Mon, 30 Nov 2020 13:16:10 +0500 Subject: [PATCH] Added renew subscriptions API end point and tests --- .../subscriptions/api/v2/tests/test_views.py | 142 +++++++++++++++++- ecommerce/subscriptions/api/v2/views.py | 63 +++++++- ecommerce/subscriptions/utils.py | 4 +- 3 files changed, 203 insertions(+), 6 deletions(-) diff --git a/ecommerce/subscriptions/api/v2/tests/test_views.py b/ecommerce/subscriptions/api/v2/tests/test_views.py index dd55f4d7e1b..cffd781bb0e 100644 --- a/ecommerce/subscriptions/api/v2/tests/test_views.py +++ b/ecommerce/subscriptions/api/v2/tests/test_views.py @@ -2,12 +2,16 @@ Unit tests for subscription API views. """ import json +import mock +from oscar.test.factories import OrderLineFactory import pytest from django.urls import reverse -from ecommerce.subscriptions.api.v2.tests.constants import LIMITED_ACCESS +from ecommerce.core.models import SiteConfiguration +from ecommerce.subscriptions.api.v2.tests.constants import LIMITED_ACCESS, FULL_ACCESS_COURSES from ecommerce.subscriptions.api.v2.tests.mixins import SubscriptionProductMixin +from ecommerce.subscriptions.api.v2.tests.utils import mock_user_subscription from ecommerce.tests.testcases import TestCase @@ -15,6 +19,27 @@ class SubscriptionViewSetTests(SubscriptionProductMixin, TestCase): """ Unit tests for "SubscriptionViewSet". """ + def _login_as_user(self, username=None, is_staff=False): + """ + Log in with a particular username or as staff user. + """ + user = self.create_user( + username=username, + is_staff=is_staff + ) + + self.client.logout() + self.client.login(username=user.username, password='test') + return user + + def _build_renew_subscription_url(self, subscription_id): + """ + build renew subscription url for a given subscription. + """ + return '{root_url}?subscription_id={subscription_id}'.format( + root_url=reverse('api:v2:subscriptions-renew-subscription-list'), + subscription_id=subscription_id + ) def test_list(self): """ @@ -105,10 +130,19 @@ def test_update(self): self.assertEqual(response.data.get('description'), subscription_data.get('description')) self.assertFalse(response.data.get('subscription_status')) - def test_toggle_course_payments(self): + def test_toggle_course_payments_without_authentication(self): + """ + Verify that subscriptions API does not toggle course payments flag without authentication. + """ + request_url = reverse('api:v2:subscriptions-toggle-course-payments-list') + response = self.client.post(request_url) + self.assertEqual(response.status_code, 401) + + def test_toggle_course_payments_with_authentication(self): """ - Verify that subscriptions API correctly toggles course payments flag. + Verify that subscriptions API correctly toggles course payments flag with authentication. """ + self._login_as_user(is_staff=True) request_url = reverse('api:v2:subscriptions-toggle-course-payments-list') expected_data = { 'course_payments': not self.site.siteconfiguration.enable_course_payments @@ -128,3 +162,105 @@ def test_course_payments_status(self): response = self.client.get(request_url) self.assertEqual(response.status_code, 200) self.assertEqual(response.data, expected_data) + + @mock.patch.object(SiteConfiguration, 'access_token', mock.Mock(return_value='foo')) + def test_renew_subscription_without_subscription_id_query_parameter(self): + """ + Verify that subscriptions API returns correct response on missing subscription id. + """ + self._login_as_user(is_staff=True) + request_url = reverse('api:v2:subscriptions-renew-subscription-list') + + expected_data = dict(error='Subscription ID not provided.') + response = self.client.get(request_url) + self.assertEqual(response.status_code, 406) + self.assertDictEqual(response.data, expected_data) + + @mock.patch.object(SiteConfiguration, 'access_token', mock.Mock(return_value='foo')) + def test_renew_subscription_without_subscription(self): + """ + Verify that subscriptions API returns correct response if subscription is inactive/non-existant. + """ + self._login_as_user(is_staff=True) + request_url = self._build_renew_subscription_url(1) + expected_data = dict(error='Subscription is inactive or does not exist.') + response = self.client.get(request_url) + self.assertEqual(response.status_code, 406) + self.assertDictEqual(response.data, expected_data) + + @mock.patch.object(SiteConfiguration, 'access_token', mock.Mock(return_value='foo')) + def test_renew_subscription_with_incorrect_subscription_type(self): + """ + Verify that subscriptions API correct response if requested subscription is of unsupported. + """ + self._login_as_user(is_staff=True) + limited_access_subscription = self.create_subscription( + subscription_type=LIMITED_ACCESS, stockrecords__partner=self.site.partner + ) + request_url = self._build_renew_subscription_url(str(limited_access_subscription.id)) + expected_data = dict( + error='Subscription of type {no_expiry_type} can\'t be renewed.'.format( + no_expiry_type=limited_access_subscription.attr.subscription_type.option + ) + ) + response = self.client.get(request_url) + self.assertEqual(response.status_code, 406) + self.assertDictEqual(response.data, expected_data) + + @mock.patch.object(SiteConfiguration, 'access_token', mock.Mock(return_value='foo')) + @mock.patch('ecommerce.subscriptions.utils.get_lms_resource_for_user') + def test_renew_subscription_with_existing_valid_subscription(self, lms_resource_for_user): + """ + Verify that subscriptions API correct response if user already has a valid subscription. + """ + self._login_as_user(is_staff=True) + full_access_courses_subscription = self.create_subscription( + subscription_type=FULL_ACCESS_COURSES, + subscription_status=True, + stockrecords__partner=self.site.partner + ) + request_url = self._build_renew_subscription_url(str(full_access_courses_subscription.id)) + lms_resource_for_user.return_value = [mock_user_subscription()] + expected_data = dict(error='User already has a valid subscription.') + response = self.client.get(request_url) + self.assertEqual(response.status_code, 406) + self.assertDictEqual(response.data, expected_data) + + @mock.patch.object(SiteConfiguration, 'access_token', mock.Mock(return_value='foo')) + @mock.patch('ecommerce.subscriptions.utils.get_lms_resource_for_user') + def test_renew_subscription_without_existing_purchase_of_subscription(self, lms_resource_for_user): + """ + Verify that subscriptions API correct response if user has not purchased requested subscription previously. + """ + self._login_as_user(is_staff=True) + full_access_courses_subscription = self.create_subscription( + subscription_type=FULL_ACCESS_COURSES, + subscription_status=True, + stockrecords__partner=self.site.partner + ) + request_url = self._build_renew_subscription_url(str(full_access_courses_subscription.id)) + lms_resource_for_user.return_value = [] + expected_data = dict(error='The user has not purchased the requested subscription previously.') + response = self.client.get(request_url) + self.assertEqual(response.status_code, 406) + self.assertDictEqual(response.data, expected_data) + + @mock.patch.object(SiteConfiguration, 'access_token', mock.Mock(return_value='foo')) + def test_renew_subscription_with_all_checks_passed(self): + """ + Verify that subscriptions API correctly redirects to basket page if all checks have passed. + """ + user = self._login_as_user(is_staff=True) + full_access_courses_subscription = self.create_subscription( + subscription_type=FULL_ACCESS_COURSES, + subscription_status=True, + stockrecords__partner=self.site.partner + ) + request_url = self._build_renew_subscription_url(str(full_access_courses_subscription.id)) + OrderLineFactory(product=full_access_courses_subscription, order__user=user) + response = self.client.get(request_url) + expected_url = '{base_url}?sku={sku}'.format( + base_url=reverse('basket:basket-add'), + sku=full_access_courses_subscription.stockrecords.first().partner_sku, + ) + self.assertRedirects(response, expected_url, fetch_redirect_response=False) diff --git a/ecommerce/subscriptions/api/v2/views.py b/ecommerce/subscriptions/api/v2/views.py index cbabb2931ed..a3fca3891b4 100644 --- a/ecommerce/subscriptions/api/v2/views.py +++ b/ecommerce/subscriptions/api/v2/views.py @@ -10,13 +10,19 @@ from rest_framework.response import Response from rest_framework.permissions import IsAuthenticated, AllowAny +from django.http import HttpResponseRedirect +from django.urls import reverse + from ecommerce.core.constants import SUBSCRIPTION_PRODUCT_CLASS_NAME from ecommerce.extensions.api.filters import ProductFilter from ecommerce.extensions.api.v2.views import NonDestroyableModelViewSet from ecommerce.extensions.edly_ecommerce_app.permissions import IsAdminOrCourseCreator from ecommerce.extensions.partner.shortcuts import get_partner_for_site from ecommerce.subscriptions.api.v2.serializers import SubscriptionListSerializer, SubscriptionSerializer +from ecommerce.subscriptions.api.v2.tests.constants import FULL_ACCESS_COURSES +from ecommerce.subscriptions.utils import get_valid_user_subscription +OrderLine = get_model('order', 'Line') Product = get_model('catalogue', 'Product') logger = logging.getLogger(__name__) @@ -34,7 +40,7 @@ def get_queryset(self): site_configuration = self.request.site.siteconfiguration filter_active_param = self.request.query_params.get('filter_active', 'false') filter_active = True if filter_active_param == 'true' else False - products = Product.objects.filter( + products = Product.objects.prefetch_related('stockrecords').filter( product_class__name=SUBSCRIPTION_PRODUCT_CLASS_NAME, stockrecords__partner=site_configuration.partner, ) @@ -54,7 +60,10 @@ def get_serializer_context(self): context['partner'] = get_partner_for_site(self.request) return context - @list_route(methods=['post']) + @list_route( + permission_classes=[IsAuthenticated, IsAdminOrCourseCreator], + methods=['post'] + ) def toggle_course_payments(self, request, **kwargs): """ View to toggle course payments. @@ -71,3 +80,53 @@ def course_payments_status(self, request, **kwargs): """ site_configuration = request.site.siteconfiguration return Response(status=status.HTTP_200_OK, data={'course_payments': site_configuration.enable_course_payments}) + + def _get_unacceptable_response_object(self, message): + """ + Get response object for unacceptable request with status code and message. + """ + return Response( + status=status.HTTP_406_NOT_ACCEPTABLE, + data={ + 'error': message + } + ) + + @list_route( + permission_classes=[IsAuthenticated], + methods=['get'] + ) + def renew_subscription(self, request, **kwargs): + """ + View to renew a subscription. + """ + subscription_id = request.query_params.get('subscription_id') + if not subscription_id: + return self._get_unacceptable_response_object('Subscription ID not provided.') + + subscription = self.get_queryset().filter(id=subscription_id).first() + if not subscription: + return self._get_unacceptable_response_object('Subscription is inactive or does not exist.') + + requested_subscription_type = subscription.attr.subscription_type.option + if requested_subscription_type != FULL_ACCESS_COURSES: + return self._get_unacceptable_response_object( + 'Subscription of type {no_expiry_type} can\'t be renewed.'.format( + no_expiry_type=requested_subscription_type + ) + ) + current_valid_subscription = get_valid_user_subscription(request.user, request.site) + if current_valid_subscription: + return self._get_unacceptable_response_object('User already has a valid subscription.') + + orders_lines = OrderLine.objects.filter(product=subscription, order__user=request.user) + if not orders_lines: + return self._get_unacceptable_response_object( + 'The user has not purchased the requested subscription previously.' + ) + + add_to_basket_url = '{base_url}?sku={sku}'.format( + base_url=reverse('basket:basket-add'), + sku=subscription.stockrecords.first().partner_sku, + ) + return HttpResponseRedirect(add_to_basket_url) diff --git a/ecommerce/subscriptions/utils.py b/ecommerce/subscriptions/utils.py index 11ee63a6aa0..9c618cb5dd1 100644 --- a/ecommerce/subscriptions/utils.py +++ b/ecommerce/subscriptions/utils.py @@ -38,7 +38,9 @@ def get_lms_resource_for_user(user, site, endpoint, resource_name=None, query_di try: data_list = endpoint.get(**query_dict) or [] - data_list = data_list[0] if len(data_list) > 0 else [] + if isinstance(data_list, list): + data_list = data_list[0] if len(data_list) > 0 else [] + TieredCache.set_all_tiers(cache_key, data_list, settings.LMS_API_CACHE_TIMEOUT) except (ConnectionError, SlumberBaseException, Timeout) as exc: logger.error('Failed to retrieve %s : %s', resource_name, str(exc))