Skip to content

Commit b963c8e

Browse files
committed
Fix test for entity create
1 parent 841b72e commit b963c8e

File tree

6 files changed

+50
-17
lines changed

6 files changed

+50
-17
lines changed

acl/tests/test_api_v2.py

+9
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import json
22

3+
from mock import mock
4+
35
from acl.models import ACLBase
46
from airone.lib.acl import ACLType
57
from airone.lib.test import AironeViewTest
68
from airone.lib.types import AttrTypeValue
9+
from entity import tasks
710
from entity.models import Entity, EntityAttr
811
from role.models import Role
912

@@ -474,6 +477,9 @@ def _put_acl():
474477
self.assertEqual(resp.status_code, 200)
475478
self.assertTrue(role.is_permitted(acl, ACLType.Full))
476479

480+
@mock.patch(
481+
"entity.tasks.create_entity_v2.delay", mock.Mock(side_effect=tasks.create_entity_v2)
482+
)
477483
def test_list_history(self):
478484
self.initialization_for_retrieve_test()
479485
self.client.post("/entity/api/v2/", json.dumps({"name": "test"}), "application/json")
@@ -594,6 +600,9 @@ def test_list_history_with_role(self):
594600
],
595601
)
596602

603+
@mock.patch(
604+
"entity.tasks.create_entity_v2.delay", mock.Mock(side_effect=tasks.create_entity_v2)
605+
)
597606
def test_list_history_with_entity_attr(self):
598607
self.initialization_for_retrieve_test()
599608

airone/lib/drf.py

+15
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import yaml
44
from django.conf import settings
5+
from rest_framework import serializers
56
from rest_framework.exceptions import APIException, ParseError, ValidationError
67
from rest_framework.parsers import BaseParser
78
from rest_framework.renderers import BaseRenderer
@@ -150,3 +151,17 @@ def _convert_error_code(detail):
150151
response.data = _convert_error_code(response.data)
151152

152153
return response
154+
155+
156+
class AironeUserDefault(serializers.CurrentUserDefault):
157+
"""
158+
It enables to get user from the custom field in the context.
159+
The original CurrentUserDefault fetches it from request context,
160+
so it fails if the context doesn't have request.
161+
"""
162+
163+
def __call__(self, serializer_field):
164+
if "_user" in serializer_field.context:
165+
return serializer_field.context["_user"]
166+
167+
return super().__call__(serializer_field)

entity/api_v2/serializers.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from rest_framework.exceptions import PermissionDenied, ValidationError
1212

1313
import custom_view
14+
from airone.lib import drf
1415
from airone.lib.acl import ACLType
1516
from airone.lib.drf import DuplicatedObjectExistsError, ObjectNotExistsError, RequiredParameterError
1617
from airone.lib.log import Logger
@@ -69,7 +70,7 @@ def validate(self, webhook):
6970

7071

7172
class EntityAttrCreateSerializer(serializers.ModelSerializer):
72-
created_user = serializers.HiddenField(default=serializers.CurrentUserDefault())
73+
created_user = serializers.HiddenField(default=drf.AironeUserDefault())
7374

7475
class Meta:
7576
model = EntityAttr
@@ -361,6 +362,7 @@ def create(self, validated_data: EntityCreateData):
361362
if user is None:
362363
raise RequiredParameterError("user is required")
363364

365+
validated_data["created_user"] = user
364366
if custom_view.is_custom("before_create_entity_V2"):
365367
validated_data = custom_view.call_custom(
366368
"before_create_entity_v2", None, user, validated_data

entity/api_v2/views.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -134,10 +134,10 @@ def get_queryset(self):
134134
def create(self, request, *args, **kwargs):
135135
user: User = request.user
136136

137-
serializer = EntityCreateSerializer(data=request.data)
137+
serializer = EntityCreateSerializer(data=request.data, context={"_user": user})
138138
serializer.is_valid(raise_exception=True)
139139

140-
job = Job.new_create_entity_v2(user, None, params=serializer.validated_data)
140+
job = Job.new_create_entity_v2(user, None, params=request.data)
141141
job.run()
142142

143143
return Response(status=status.HTTP_202_ACCEPTED)

entity/tests/test_api_v2.py

+21-13
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import yaml
77
from django.conf import settings
88
from django.urls import reverse
9+
from rest_framework import status
910
from rest_framework.exceptions import ValidationError
1011

1112
from acl.models import ACLBase
@@ -465,6 +466,9 @@ def test_list_entity_without_permission(self):
465466
self.assertEqual(resp.status_code, 200)
466467
self.assertEqual(resp.json()["count"], 2)
467468

469+
@mock.patch(
470+
"entity.tasks.create_entity_v2.delay", mock.Mock(side_effect=tasks.create_entity_v2)
471+
)
468472
def test_create_entity(self):
469473
params = {
470474
"name": "entity1",
@@ -493,16 +497,9 @@ def test_create_entity(self):
493497
}
494498

495499
resp = self.client.post("/entity/api/v2/", json.dumps(params), "application/json")
496-
self.assertEqual(resp.status_code, 201)
500+
self.assertEqual(resp.status_code, status.HTTP_202_ACCEPTED)
497501

498-
entity: Entity = Entity.objects.get(id=resp.json()["id"])
499-
self.assertEqual(
500-
resp.json(),
501-
{
502-
"id": entity.id,
503-
"name": "entity1",
504-
},
505-
)
502+
entity: Entity = Entity.objects.get(name=params["name"])
506503
self.assertEqual(entity.name, "entity1")
507504
self.assertEqual(entity.note, "hoge")
508505
self.assertEqual(entity.status, Entity.STATUS_TOP_LEVEL)
@@ -1246,6 +1243,9 @@ def test_create_entity_with_invalid_param_webhooks(self):
12461243
},
12471244
)
12481245

1246+
@mock.patch(
1247+
"entity.tasks.create_entity_v2.delay", mock.Mock(side_effect=tasks.create_entity_v2)
1248+
)
12491249
def test_create_entity_with_attrs_referral(self):
12501250
params = {
12511251
"name": "entity1",
@@ -1261,25 +1261,33 @@ def test_create_entity_with_attrs_referral(self):
12611261
}
12621262

12631263
resp = self.client.post("/entity/api/v2/", json.dumps(params), "application/json")
1264+
self.assertEqual(resp.status_code, status.HTTP_202_ACCEPTED)
12641265

1265-
entity: Entity = Entity.objects.get(id=resp.json()["id"])
1266+
entity: Entity = Entity.objects.get(name=params["name"])
12661267
for entity_attr in entity.attrs.all():
12671268
if entity_attr.type & AttrTypeValue["object"]:
12681269
self.assertEqual([x.id for x in entity_attr.referral.all()], [self.ref_entity.id])
12691270
else:
12701271
self.assertEqual([x.id for x in entity_attr.referral.all()], [])
12711272

1273+
@mock.patch(
1274+
"entity.tasks.create_entity_v2.delay", mock.Mock(side_effect=tasks.create_entity_v2)
1275+
)
12721276
def test_create_entity_with_webhook_is_verified(self):
12731277
params = {
12741278
"name": "entity1",
12751279
"webhooks": [{"url": "http://example.net/"}, {"url": "http://hoge.hoge/"}],
12761280
}
12771281
resp = self.client.post("/entity/api/v2/", json.dumps(params), "application/json")
1278-
entity: Entity = Entity.objects.get(id=resp.json()["id"])
1282+
self.assertEqual(resp.status_code, status.HTTP_202_ACCEPTED)
1283+
entity: Entity = Entity.objects.get(name=params["name"])
12791284
self.assertEqual([x.is_verified for x in entity.webhooks.all()], [True, False])
12801285

12811286
@mock.patch("custom_view.is_custom", mock.Mock(return_value=True))
12821287
@mock.patch("custom_view.call_custom")
1288+
@mock.patch(
1289+
"entity.tasks.create_entity_v2.delay", mock.Mock(side_effect=tasks.create_entity_v2)
1290+
)
12831291
def test_create_entity_with_customview(self, mock_call_custom):
12841292
params = {"name": "hoge"}
12851293

@@ -1297,7 +1305,7 @@ def side_effect(handler_name, entity_name, user, *args):
12971305
self.assertEqual(user, self.user)
12981306

12991307
if handler_name == "before_create_entity_v2":
1300-
self.assertEqual(
1308+
self.assertDictEqual(
13011309
args[0],
13021310
{
13031311
"name": "hoge",
@@ -1315,7 +1323,7 @@ def side_effect(handler_name, entity_name, user, *args):
13151323

13161324
mock_call_custom.side_effect = side_effect
13171325
resp = self.client.post("/entity/api/v2/", json.dumps(params), "application/json")
1318-
self.assertEqual(resp.status_code, 201)
1326+
self.assertEqual(resp.status_code, status.HTTP_202_ACCEPTED)
13191327
self.assertTrue(mock_call_custom.called)
13201328

13211329
def test_create_entity_with_webhook_is_disabled(self):

job/models.py

-1
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,6 @@ def new_invoke_trigger(kls, user, target_entry, recv_attrs={}, dependent_job=Non
544544

545545
@classmethod
546546
def new_create_entity_v2(kls, user, target, text="", params={}):
547-
print(params)
548547
return kls._create_new_job(
549548
user,
550549
target,

0 commit comments

Comments
 (0)