diff --git a/server/szurubooru/func/posts.py b/server/szurubooru/func/posts.py index 6ff3a87..d190b8d 100644 --- a/server/szurubooru/func/posts.py +++ b/server/szurubooru/func/posts.py @@ -377,6 +377,8 @@ def update_post_relations(post, new_post_ids): .all() if len(new_posts) != len(new_post_ids): raise InvalidPostRelationError('One of relations does not exist.') + if post.post_id in new_post_ids: + raise InvalidPostRelationError('Post cannot relate to itself.') relations_to_del = [p for p in old_posts if p.post_id not in new_post_ids] relations_to_add = [p for p in new_posts if p.post_id not in old_post_ids] diff --git a/server/szurubooru/tests/func/test_posts.py b/server/szurubooru/tests/func/test_posts.py index 6465f51..f44fb03 100644 --- a/server/szurubooru/tests/func/test_posts.py +++ b/server/szurubooru/tests/func/test_posts.py @@ -541,6 +541,14 @@ def test_update_post_relations_with_nonexisting_posts(): posts.update_post_relations(post, [100]) +def test_update_post_relations_with_itself(post_factory): + post = post_factory() + db.session.add(post) + db.session.flush() + with pytest.raises(posts.InvalidPostRelationError): + posts.update_post_relations(post, [post.post_id]) + + def test_update_post_notes(): post = db.Post() posts.update_post_notes(