From 9d6a0e0173831a6284be89cf71c1e627b4d8c0c9 Mon Sep 17 00:00:00 2001 From: rr- Date: Fri, 21 Oct 2016 21:48:08 +0200 Subject: [PATCH] server/posts: add post merging --- API.md | 35 +++ config.yaml.dist | 1 + server/szurubooru/api/post_api.py | 16 ++ server/szurubooru/func/posts.py | 75 ++++++ .../szurubooru/tests/api/test_post_merging.py | 89 +++++++ server/szurubooru/tests/conftest.py | 24 ++ server/szurubooru/tests/func/test_posts.py | 219 ++++++++++++++++++ 7 files changed, 459 insertions(+) create mode 100644 server/szurubooru/tests/api/test_post_merging.py diff --git a/API.md b/API.md index 27d4de7..53fec51 100644 --- a/API.md +++ b/API.md @@ -36,6 +36,7 @@ - [Updating post](#updating-post) - [Getting post](#getting-post) - [Deleting post](#deleting-post) + - [Merging posts](#merging-posts) - [Rating post](#rating-post) - [Adding post to favorites](#adding-post-to-favorites) - [Removing post from favorites](#removing-post-from-favorites) @@ -910,6 +911,40 @@ data. Deletes existing post. Related posts and tags are kept. +## Merging posts +- **Request** + + `POST /post-merge/` + +- **Input** + + ```json5 + { + "removeVersion": , + "remove": , + "mergeToVersion": , + "mergeTo": + } + ``` + +- **Output** + + A [post resource](#post) containing the merged post. + +- **Errors** + + - the version of either post is outdated + - the source or target post does not exist + - the source post is the same as the target post + - privileges are too low + +- **Description** + + Removes source post and merges all of its tags, relations, scores, + favorites and comments to the target post. Source post properties such as + its content, safety, source, whether to loop the video and other scalar + values do not get transferred and are discarded. + ## Rating post - **Request** diff --git a/config.yaml.dist b/config.yaml.dist index 88548f7..0e18de4 100644 --- a/config.yaml.dist +++ b/config.yaml.dist @@ -80,6 +80,7 @@ privileges: 'posts:feature': moderator 'posts:delete': moderator 'posts:score': regular + 'posts:merge': moderator 'posts:favorite': regular 'tags:create': regular diff --git a/server/szurubooru/api/post_api.py b/server/szurubooru/api/post_api.py index e4d4975..9cb1237 100644 --- a/server/szurubooru/api/post_api.py +++ b/server/szurubooru/api/post_api.py @@ -124,6 +124,22 @@ def delete_post(ctx, params): return {} +@routes.post('/post-merge/?') +def merge_posts(ctx, _params=None): + source_post_id = ctx.get_param_as_string('remove', required=True) or '' + target_post_id = ctx.get_param_as_string('mergeTo', required=True) or '' + source_post = posts.get_post_by_id(source_post_id) + target_post = posts.get_post_by_id(target_post_id) + versions.verify_version(source_post, ctx, 'removeVersion') + versions.verify_version(target_post, ctx, 'mergeToVersion') + versions.bump_version(target_post) + auth.verify_privilege(ctx.user, 'posts:merge') + posts.merge_posts(source_post, target_post) + snapshots.merge(source_post, target_post, ctx.user) + ctx.session.commit() + return _serialize_post(ctx, target_post) + + @routes.get('/featured-post/?') def get_featured_post(ctx, _params=None): post = posts.try_get_featured_post() diff --git a/server/szurubooru/func/posts.py b/server/szurubooru/func/posts.py index ee43b92..f3725e5 100644 --- a/server/szurubooru/func/posts.py +++ b/server/szurubooru/func/posts.py @@ -440,3 +440,78 @@ def feature_post(post, user): def delete(post): assert post db.session.delete(post) + + +def merge_posts(source_post, target_post): + assert source_post + assert target_post + if source_post.post_id == target_post.post_id: + raise InvalidPostRelationError('Cannot merge post with itself.') + + def merge_tables(table, anti_dup_func, source_post_id, target_post_id): + table1 = table + table2 = sqlalchemy.orm.util.aliased(table) + update_stmt = (sqlalchemy.sql.expression.update(table1) + .where(table1.post_id == source_post_id)) + + if anti_dup_func is not None: + update_stmt = (update_stmt + .where(~sqlalchemy.exists() + .where(anti_dup_func(table1, table2)) + .where(table2.post_id == target_post_id))) + + update_stmt = (update_stmt.values(post_id=target_post_id)) + db.session.execute(update_stmt) + + def merge_tags(source_post_id, target_post_id): + merge_tables( + db.PostTag, + lambda alias1, alias2: alias1.tag_id == alias2.tag_id, + source_post_id, + target_post_id) + + def merge_scores(source_post_id, target_post_id): + merge_tables( + db.PostScore, + lambda alias1, alias2: alias1.user_id == alias2.user_id, + source_post_id, + target_post_id) + + def merge_favorites(source_post_id, target_post_id): + merge_tables( + db.PostFavorite, + lambda alias1, alias2: alias1.user_id == alias2.user_id, + source_post_id, + target_post_id) + + def merge_comments(source_post_id, target_post_id): + merge_tables(db.Comment, None, source_post_id, target_post_id) + + def merge_relations(source_post_id, target_post_id): + table1 = db.PostRelation + table2 = sqlalchemy.orm.util.aliased(db.PostRelation) + update_stmt = (sqlalchemy.sql.expression.update(table1) + .where(table1.parent_id == source_post_id) + .where(table1.child_id != target_post_id) + .where(~sqlalchemy.exists() + .where(table2.child_id == table1.child_id) + .where(table2.parent_id == target_post_id)) + .values(parent_id=target_post_id)) + db.session.execute(update_stmt) + + update_stmt = (sqlalchemy.sql.expression.update(table1) + .where(table1.child_id == source_post_id) + .where(table1.parent_id != target_post_id) + .where(~sqlalchemy.exists() + .where(table2.parent_id == table1.parent_id) + .where(table2.child_id == target_post_id)) + .values(child_id=target_post_id)) + db.session.execute(update_stmt) + + merge_tags(source_post.post_id, target_post.post_id) + merge_comments(source_post.post_id, target_post.post_id) + merge_scores(source_post.post_id, target_post.post_id) + merge_favorites(source_post.post_id, target_post.post_id) + merge_relations(source_post.post_id, target_post.post_id) + + delete(source_post) diff --git a/server/szurubooru/tests/api/test_post_merging.py b/server/szurubooru/tests/api/test_post_merging.py new file mode 100644 index 0000000..e654090 --- /dev/null +++ b/server/szurubooru/tests/api/test_post_merging.py @@ -0,0 +1,89 @@ +from unittest.mock import patch +import pytest +from szurubooru import api, db, errors +from szurubooru.func import posts, snapshots + + +@pytest.fixture(autouse=True) +def inject_config(config_injector): + config_injector({'privileges': {'posts:merge': db.User.RANK_REGULAR}}) + + +def test_merging(user_factory, context_factory, post_factory): + auth_user = user_factory(rank=db.User.RANK_REGULAR) + source_post = post_factory() + target_post = post_factory() + db.session.add_all([source_post, target_post]) + db.session.flush() + with patch('szurubooru.func.posts.serialize_post'), \ + patch('szurubooru.func.posts.merge_posts'), \ + patch('szurubooru.func.snapshots.merge'): + api.post_api.merge_posts( + context_factory( + params={ + 'removeVersion': 1, + 'mergeToVersion': 1, + 'remove': source_post.post_id, + 'mergeTo': target_post.post_id, + }, + user=auth_user)) + posts.merge_posts.called_once_with(source_post, target_post) + snapshots.merge.assert_called_once_with( + source_post, target_post, auth_user) + + +@pytest.mark.parametrize( + 'field', ['remove', 'mergeTo', 'removeVersion', 'mergeToVersion']) +def test_trying_to_omit_mandatory_field( + user_factory, post_factory, context_factory, field): + source_post = post_factory() + target_post = post_factory() + db.session.add_all([source_post, target_post]) + db.session.commit() + params = { + 'removeVersion': 1, + 'mergeToVersion': 1, + 'remove': source_post.post_id, + 'mergeTo': target_post.post_id, + } + del params[field] + with pytest.raises(errors.ValidationError): + api.post_api.merge_posts( + context_factory( + params=params, + user=user_factory(rank=db.User.RANK_REGULAR))) + + +def test_trying_to_merge_non_existing( + user_factory, post_factory, context_factory): + post = post_factory() + db.session.add(post) + db.session.commit() + with pytest.raises(posts.PostNotFoundError): + api.post_api.merge_posts( + context_factory( + params={'remove': post.post_id, 'mergeTo': 999}, + user=user_factory(rank=db.User.RANK_REGULAR))) + with pytest.raises(posts.PostNotFoundError): + api.post_api.merge_posts( + context_factory( + params={'remove': 999, 'mergeTo': post.post_id}, + user=user_factory(rank=db.User.RANK_REGULAR))) + + +def test_trying_to_merge_without_privileges( + user_factory, post_factory, context_factory): + source_post = post_factory() + target_post = post_factory() + db.session.add_all([source_post, target_post]) + db.session.commit() + with pytest.raises(errors.AuthError): + api.post_api.merge_posts( + context_factory( + params={ + 'removeVersion': 1, + 'mergeToVersion': 1, + 'remove': source_post.post_id, + 'mergeTo': target_post.post_id, + }, + user=user_factory(rank=db.User.RANK_ANONYMOUS))) diff --git a/server/szurubooru/tests/conftest.py b/server/szurubooru/tests/conftest.py index 1f40887..49f9a17 100644 --- a/server/szurubooru/tests/conftest.py +++ b/server/szurubooru/tests/conftest.py @@ -192,6 +192,30 @@ def comment_factory(user_factory, post_factory): return factory +@pytest.fixture +def post_score_factory(user_factory, post_factory): + def factory(post=None, user=None, score=1): + if user is None: + user = user_factory() + if post is None: + post = post_factory() + return db.PostScore( + post=post, user=user, score=score, time=datetime(1999, 1, 1)) + return factory + + +@pytest.fixture +def post_favorite_factory(user_factory, post_factory): + def factory(post=None, user=None): + if user is None: + user = user_factory() + if post is None: + post = post_factory() + return db.PostFavorite( + post=post, user=user, time=datetime(1999, 1, 1)) + return factory + + @pytest.fixture def read_asset(): def get(path): diff --git a/server/szurubooru/tests/func/test_posts.py b/server/szurubooru/tests/func/test_posts.py index 9c5ecd0..033e015 100644 --- a/server/szurubooru/tests/func/test_posts.py +++ b/server/szurubooru/tests/func/test_posts.py @@ -605,3 +605,222 @@ def test_delete(post_factory): posts.delete(post) db.session.flush() assert posts.get_post_count() == 0 + + +def test_merge_posts_deletes_source_post(post_factory): + source_post = post_factory() + target_post = post_factory() + db.session.add_all([source_post, target_post]) + db.session.flush() + posts.merge_posts(source_post, target_post) + db.session.flush() + assert posts.try_get_post_by_id(source_post.post_id) is None + post = posts.get_post_by_id(target_post.post_id) + assert post is not None + + +def test_merge_posts_with_itself(post_factory): + source_post = post_factory() + db.session.add(source_post) + db.session.flush() + with pytest.raises(posts.InvalidPostRelationError): + posts.merge_posts(source_post, source_post) + + +def test_merge_posts_moves_tags(post_factory, tag_factory): + source_post = post_factory() + target_post = post_factory() + tag = tag_factory() + tag.posts = [source_post] + db.session.add_all([source_post, target_post, tag]) + db.session.commit() + assert source_post.tag_count == 1 + assert target_post.tag_count == 0 + posts.merge_posts(source_post, target_post) + db.session.commit() + assert posts.try_get_post_by_id(source_post.post_id) is None + assert posts.get_post_by_id(target_post.post_id).tag_count == 1 + + +def test_merge_posts_doesnt_duplicate_tags(post_factory, tag_factory): + source_post = post_factory() + target_post = post_factory() + tag = tag_factory() + tag.posts = [source_post, target_post] + db.session.add_all([source_post, target_post, tag]) + db.session.commit() + assert source_post.tag_count == 1 + assert target_post.tag_count == 1 + posts.merge_posts(source_post, target_post) + db.session.commit() + assert posts.try_get_post_by_id(source_post.post_id) is None + assert posts.get_post_by_id(target_post.post_id).tag_count == 1 + + +def test_merge_posts_moves_comments(post_factory, comment_factory): + source_post = post_factory() + target_post = post_factory() + comment = comment_factory(post=source_post) + db.session.add_all([source_post, target_post, comment]) + db.session.commit() + assert source_post.comment_count == 1 + assert target_post.comment_count == 0 + posts.merge_posts(source_post, target_post) + db.session.commit() + assert posts.try_get_post_by_id(source_post.post_id) is None + assert posts.get_post_by_id(target_post.post_id).comment_count == 1 + + +def test_merge_posts_moves_scores(post_factory, post_score_factory): + source_post = post_factory() + target_post = post_factory() + score = post_score_factory(post=source_post, score=1) + db.session.add_all([source_post, target_post, score]) + db.session.commit() + assert source_post.score == 1 + assert target_post.score == 0 + posts.merge_posts(source_post, target_post) + db.session.commit() + assert posts.try_get_post_by_id(source_post.post_id) is None + assert posts.get_post_by_id(target_post.post_id).score == 1 + + +def test_merge_posts_doesnt_duplicate_scores( + post_factory, user_factory, post_score_factory): + source_post = post_factory() + target_post = post_factory() + user = user_factory() + score1 = post_score_factory(post=source_post, score=1, user=user) + score2 = post_score_factory(post=target_post, score=1, user=user) + db.session.add_all([source_post, target_post, score1, score2]) + db.session.commit() + assert source_post.score == 1 + assert target_post.score == 1 + posts.merge_posts(source_post, target_post) + db.session.commit() + assert posts.try_get_post_by_id(source_post.post_id) is None + assert posts.get_post_by_id(target_post.post_id).score == 1 + + +def test_merge_posts_moves_favorites(post_factory, post_favorite_factory): + source_post = post_factory() + target_post = post_factory() + favorite = post_favorite_factory(post=source_post) + db.session.add_all([source_post, target_post, favorite]) + db.session.commit() + assert source_post.favorite_count == 1 + assert target_post.favorite_count == 0 + posts.merge_posts(source_post, target_post) + db.session.commit() + assert posts.try_get_post_by_id(source_post.post_id) is None + assert posts.get_post_by_id(target_post.post_id).favorite_count == 1 + + +def test_merge_posts_doesnt_duplicate_favorites( + post_factory, user_factory, post_favorite_factory): + source_post = post_factory() + target_post = post_factory() + user = user_factory() + favorite1 = post_favorite_factory(post=source_post, user=user) + favorite2 = post_favorite_factory(post=target_post, user=user) + db.session.add_all([source_post, target_post, favorite1, favorite2]) + db.session.commit() + assert source_post.favorite_count == 1 + assert target_post.favorite_count == 1 + posts.merge_posts(source_post, target_post) + db.session.commit() + assert posts.try_get_post_by_id(source_post.post_id) is None + assert posts.get_post_by_id(target_post.post_id).favorite_count == 1 + + +def test_merge_posts_moves_child_relations(post_factory): + source_post = post_factory() + target_post = post_factory() + related_post = post_factory() + source_post.relations = [related_post] + db.session.add_all([source_post, target_post, related_post]) + db.session.commit() + assert source_post.relation_count == 1 + assert target_post.relation_count == 0 + posts.merge_posts(source_post, target_post) + db.session.commit() + assert posts.try_get_post_by_id(source_post.post_id) is None + assert posts.get_post_by_id(target_post.post_id).relation_count == 1 + + +def test_merge_posts_doesnt_duplicate_child_relations(post_factory): + source_post = post_factory() + target_post = post_factory() + related_post = post_factory() + source_post.relations = [related_post] + target_post.relations = [related_post] + db.session.add_all([source_post, target_post, related_post]) + db.session.commit() + assert source_post.relation_count == 1 + assert target_post.relation_count == 1 + posts.merge_posts(source_post, target_post) + db.session.commit() + assert posts.try_get_post_by_id(source_post.post_id) is None + assert posts.get_post_by_id(target_post.post_id).relation_count == 1 + + +def test_merge_posts_moves_parent_relations(post_factory): + source_post = post_factory() + target_post = post_factory() + related_post = post_factory() + related_post.relations = [source_post] + db.session.add_all([source_post, target_post, related_post]) + db.session.commit() + assert source_post.relation_count == 1 + assert target_post.relation_count == 0 + assert related_post.relation_count == 1 + posts.merge_posts(source_post, target_post) + db.session.commit() + assert posts.try_get_post_by_id(source_post.post_id) is None + assert posts.get_post_by_id(target_post.post_id).relation_count == 1 + assert posts.get_post_by_id(related_post.post_id).relation_count == 1 + + +def test_merge_posts_doesnt_duplicate_parent_relations(post_factory): + source_post = post_factory() + target_post = post_factory() + related_post = post_factory() + related_post.relations = [source_post, target_post] + db.session.add_all([source_post, target_post, related_post]) + db.session.commit() + assert source_post.relation_count == 1 + assert target_post.relation_count == 1 + assert related_post.relation_count == 2 + posts.merge_posts(source_post, target_post) + db.session.commit() + assert posts.try_get_post_by_id(source_post.post_id) is None + assert posts.get_post_by_id(target_post.post_id).relation_count == 1 + assert posts.get_post_by_id(related_post.post_id).relation_count == 1 + + +def test_merge_posts_doesnt_create_relation_loop_for_children(post_factory): + source_post = post_factory() + target_post = post_factory() + source_post.relations = [target_post] + db.session.add_all([source_post, target_post]) + db.session.commit() + assert source_post.relation_count == 1 + assert target_post.relation_count == 1 + posts.merge_posts(source_post, target_post) + db.session.commit() + assert posts.try_get_post_by_id(source_post.post_id) is None + assert posts.get_post_by_id(target_post.post_id).relation_count == 0 + + +def test_merge_posts_doesnt_create_relation_loop_for_parents(post_factory): + source_post = post_factory() + target_post = post_factory() + target_post.relations = [source_post] + db.session.add_all([source_post, target_post]) + db.session.commit() + assert source_post.relation_count == 1 + assert target_post.relation_count == 1 + posts.merge_posts(source_post, target_post) + db.session.commit() + assert posts.try_get_post_by_id(source_post.post_id) is None + assert posts.get_post_by_id(target_post.post_id).relation_count == 0