diff --git a/invenio_rdm_records/config.py b/invenio_rdm_records/config.py index a36b75d68f..bc777b54ed 100644 --- a/invenio_rdm_records/config.py +++ b/invenio_rdm_records/config.py @@ -4,6 +4,7 @@ # Copyright (C) 2019 Northwestern University. # Copyright (C) 2021-2023 Graz University of Technology. # Copyright (C) 2023 TU Wien. +# Copyright (C) 2023 KTH Royal Institute of Technology. # # Invenio-RDM-Records is free software; you can redistribute it and/or modify # it under the terms of the MIT License; see LICENSE file for more details. @@ -125,6 +126,12 @@ def always_valid(identifier): RDM_ALLOW_RESTRICTED_RECORDS = True """Allow users to set restricted/private records.""" +# +# Record communities +# +RDM_ENSURE_RECORD_COMMUNITY_EXISTS = False +"""Enforces at least one community per record on remove community function.""" + # # Search configuration # diff --git a/invenio_rdm_records/services/communities/service.py b/invenio_rdm_records/services/communities/service.py index 26ab36b993..448c103549 100644 --- a/invenio_rdm_records/services/communities/service.py +++ b/invenio_rdm_records/services/communities/service.py @@ -1,12 +1,13 @@ # -*- coding: utf-8 -*- # # Copyright (C) 2023 CERN. +# Copyright (C) 2023 KTH Royal Institute of Technology. # # Invenio-RDM-Records is free software; you can redistribute it and/or modify # it under the terms of the MIT License; see LICENSE file for more details. """RDM Record Communities Service.""" - +from flask import current_app from invenio_communities.proxies import current_communities from invenio_i18n import lazy_gettext as _ from invenio_pidstore.errors import PIDDoesNotExistError @@ -34,6 +35,7 @@ InvalidAccessRestrictions, OpenRequestAlreadyExists, RecordCommunityMissing, + RecordCommunityRequired, ) @@ -166,6 +168,10 @@ def _remove(self, identity, community_id, record): self.require_permission( identity, "remove_community", record=record, community_id=community_id ) + if current_app.config.get("RDM_ENSURE_RECORD_COMMUNITY_EXISTS"): + rec_communities = record.parent.communities.ids + if len(rec_communities) <= 1: + raise RecordCommunityRequired() # Default community is deleted when the exact same community is removed from the record record.parent.communities.remove(community_id) @@ -185,12 +191,17 @@ def remove(self, identity, id_, data, uow): ) communities = valid_data["communities"] processed = [] + for community in communities: community_id = community["id"] try: self._remove(identity, community_id, record) processed.append({"community": community_id}) - except (RecordCommunityMissing, PermissionDeniedError) as ex: + except ( + RecordCommunityMissing, + PermissionDeniedError, + RecordCommunityRequired, + ) as ex: errors.append( { "community": community_id, @@ -213,7 +224,7 @@ def search( search_preference=None, expand=False, extra_filter=None, - **kwargs + **kwargs, ): """Search for record's communities.""" record = self.record_cls.pid.resolve(id_) @@ -230,7 +241,7 @@ def search( search_preference=search_preference, expand=expand, extra_filter=communities_filter, - **kwargs + **kwargs, ) @staticmethod @@ -274,7 +285,7 @@ def search_suggested_communities( expand=False, by_membership=False, extra_filter=None, - **kwargs + **kwargs, ): """Search for communities that can be added to a record.""" record = self.record_cls.pid.resolve(id_) @@ -294,7 +305,7 @@ def search_suggested_communities( params=params, search_preference=search_preference, extra_filter=communities_filter, - **kwargs + **kwargs, ) return current_communities.service.search( @@ -303,5 +314,5 @@ def search_suggested_communities( search_preference=search_preference, expand=expand, extra_filter=communities_filter, - **kwargs + **kwargs, ) diff --git a/invenio_rdm_records/services/errors.py b/invenio_rdm_records/services/errors.py index 4280ed59bf..e03c7c3801 100644 --- a/invenio_rdm_records/services/errors.py +++ b/invenio_rdm_records/services/errors.py @@ -2,6 +2,7 @@ # # Copyright (C) 2021 CERN. # Copyright (C) 2023 Graz University of Technology. +# Copyright (C) 2023 KTH Royal Institute of Technology. # # Invenio-RDM-Records is free software; you can redistribute it and/or modify # it under the terms of the MIT License; see LICENSE file for more details. @@ -142,6 +143,15 @@ def description(self): ) +class RecordCommunityRequired(Exception): + """Record associated community required.""" + + @property + def description(self): + """Exception description.""" + return _("Last community on record cannot be removed.") + + class InvalidCommunityVisibility(Exception): """Community visibility does not match the content.""" diff --git a/tests/resources/test_resources_communities.py b/tests/resources/test_resources_communities.py index 2a0f59aa97..8a8ab82789 100644 --- a/tests/resources/test_resources_communities.py +++ b/tests/resources/test_resources_communities.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- # # Copyright (C) 2023 CERN. +# Copyright (C) 2023 KTH Royal Institute of Technology. # # Invenio-RDM-Records is free software; you can redistribute it and/or modify # it under the terms of the MIT License; see LICENSE file for more details. @@ -20,6 +21,20 @@ from invenio_rdm_records.requests.community_inclusion import CommunityInclusion +# The 'autouse=True' parameter will ensure that the fixture is executed for all tests by default, +# unless a test has been explicitly marked with the '@pytest.mark. +# escape_record_community_exists_fixture' marker to escape the fixture. +@pytest.fixture(autouse=True) +def ensure_record_community_exists(request, app): + if "escape_record_community_exists_fixture" in request.keywords: + yield + else: + old_value = app.config.get("RDM_ENSURE_RECORD_COMMUNITY_EXISTS", False) + app.config["RDM_ENSURE_RECORD_COMMUNITY_EXISTS"] = True + yield + app.config["RDM_ENSURE_RECORD_COMMUNITY_EXISTS"] = old_value + + def _add_to_community(db, record, community): record.parent.communities.add(community._record, default=False) record.parent.commit() @@ -536,6 +551,52 @@ def test_invalid_record_or_draft( assert response.status_code == 404 +def test_remove_last_community( + client, + uploader, + curator, + record_community, + headers, + community, +): + """Test removal of a community from the record.""" + for user in [uploader, curator]: + record = record_community.create_record() + + data = {"communities": [{"id": community.id}]} + client = user.login(client) + response = client.get( + f"/communities/{community.id}/records", + headers=headers, + json=data, + ) + assert ( + len(response.json["hits"]["hits"][0]["parent"]["communities"]["ids"]) == 1 + ) + + response = client.delete( + f"/records/{record.pid.pid_value}/communities", + headers=headers, + json=data, + ) + assert response.status_code == 400 + assert response.json.get("errors") + record_saved = client.get(f"/records/{record.pid.pid_value}", headers=headers) + assert record_saved.json["parent"]["communities"] + + client = user.logout(client) + # check communities number + response = client.get( + f"/communities/{community.id}/records", + headers=headers, + json=data, + ) + assert ( + len(response.json["hits"]["hits"][0]["parent"]["communities"]["ids"]) == 1 + ) + + +@pytest.mark.escape_record_community_exists_fixture def test_remove_community( client, uploader, curator, record_community, headers, community ): @@ -585,6 +646,7 @@ def test_remove_missing_permission( assert record_saved.json["parent"]["communities"] == {"ids": [str(community.id)]} +@pytest.mark.escape_record_community_exists_fixture def test_remove_existing_non_existing_community( client, uploader, record_community, headers, community ): @@ -605,6 +667,34 @@ def test_remove_existing_non_existing_community( assert not record_saved.json["parent"]["communities"] +def test_remove_last_existing_non_existing_community( + client, uploader, record_community, headers, community +): + """Test removal of an existing and non-existing community from the record, + while ensuring at least one community exists.""" + data = { + "communities": [ + {"id": community.id}, + {"id": "wrong-id"}, + {"id": "wrong-id2"}, + ] + } + + client = uploader.login(client) + record = record_community.create_record() + + response = client.delete( + f"/records/{record.pid.pid_value}/communities", + headers=headers, + json=data, + ) + assert response.status_code == 400 + # Should get 3 errors: Can't remove community, 2 bad IDs + assert len(response.json["errors"]) == 3 + record_saved = client.get(f"/records/{record.pid.pid_value}", headers=headers) + assert record_saved.json["parent"]["communities"] + + @pytest.mark.parametrize( "payload", [[{"id": "wrong-id"}], [{"id": "duplicated-id"}, {"id": "duplicated-id"}], []],