diff --git a/server/szurubooru/func/posts.py b/server/szurubooru/func/posts.py index 3863e5b..ee43b92 100644 --- a/server/szurubooru/func/posts.py +++ b/server/szurubooru/func/posts.py @@ -355,8 +355,13 @@ def update_post_tags(post, tag_names): def update_post_relations(post, new_post_ids): assert post + try: + new_post_ids = [int(id) for id in new_post_ids] + except ValueError: + raise InvalidPostRelationError( + 'A relation must be numeric post ID.') old_posts = post.relations - old_post_ids = [p.post_id for p in old_posts] + old_post_ids = [int(p.post_id) for p in old_posts] new_posts = db.session \ .query(db.Post) \ .filter(db.Post.post_id.in_(new_post_ids)) \ diff --git a/server/szurubooru/tests/func/test_posts.py b/server/szurubooru/tests/func/test_posts.py index 56daaee..9c5ecd0 100644 --- a/server/szurubooru/tests/func/test_posts.py +++ b/server/szurubooru/tests/func/test_posts.py @@ -521,6 +521,7 @@ def test_update_post_relations_bidirectionality(post_factory): db.session.flush() post = post_factory() posts.update_post_relations(post, [relation1.post_id, relation2.post_id]) + db.session.flush() posts.update_post_relations(relation1, []) assert len(post.relations) == 1 assert post.relations[0].post_id == relation2.post_id