From b7f2982c9e85a9c5b9e48de3488567b3f9606475 Mon Sep 17 00:00:00 2001 From: rr- Date: Sun, 17 Jul 2016 20:58:42 +0200 Subject: [PATCH] server/posts: fix relations bidirectionality --- server/migrate-v1 | 3 +++ server/szurubooru/db/post.py | 7 +----- server/szurubooru/func/posts.py | 23 +++++++++++++------ server/szurubooru/func/snapshots.py | 2 +- .../search/configs/post_search_config.py | 3 +-- server/szurubooru/tests/db/test_post.py | 20 ++++++++-------- server/szurubooru/tests/func/test_posts.py | 19 +++++++++++---- .../szurubooru/tests/func/test_snapshots.py | 4 ++-- 8 files changed, 48 insertions(+), 33 deletions(-) diff --git a/server/migrate-v1 b/server/migrate-v1 index 6240c0c..6fc4f82 100755 --- a/server/migrate-v1 +++ b/server/migrate-v1 @@ -288,6 +288,9 @@ def import_post_relations(unused_post_ids, v1_session, v2_session): v2_session.add( db.PostRelation( parent_id=row['post1id'], child_id=row['post2id'])) + v2_session.add( + db.PostRelation( + parent_id=row['post2id'], child_id=row['post1id'])) v2_session.commit() def import_post_favorites(unused_post_ids, v1_session, v2_session): diff --git a/server/szurubooru/db/post.py b/server/szurubooru/db/post.py index 73314b2..38b0cba 100644 --- a/server/szurubooru/db/post.py +++ b/server/szurubooru/db/post.py @@ -112,16 +112,11 @@ class Post(Base): # foreign tables user = relationship('User') tags = relationship('Tag', backref='posts', secondary='post_tag') - relating_to = relationship( + relations = relationship( 'Post', secondary='post_relation', primaryjoin=post_id == PostRelation.parent_id, secondaryjoin=post_id == PostRelation.child_id, lazy='joined') - related_by = relationship( - 'Post', - secondary='post_relation', - primaryjoin=post_id == PostRelation.child_id, - secondaryjoin=post_id == PostRelation.parent_id, lazy='joined') features = relationship( 'PostFeature', cascade='all, delete-orphan', lazy='joined') scores = relationship( diff --git a/server/szurubooru/func/posts.py b/server/szurubooru/func/posts.py index 87b1cbe..ef21383 100644 --- a/server/szurubooru/func/posts.py +++ b/server/szurubooru/func/posts.py @@ -94,8 +94,7 @@ def serialize_post(post, authenticated_user, options=None): { post['id']: post for post in [ - serialize_micro_post(rel) \ - for rel in post.related_by + post.relating_to + serialize_micro_post(rel) for rel in post.relations ] }.values(), key=lambda post: post['id']), @@ -258,14 +257,24 @@ def update_post_tags(post, tag_names): existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names) post.tags = existing_tags + new_tags -def update_post_relations(post, post_ids): - relations = db.session \ +def update_post_relations(post, new_post_ids): + old_posts = post.relations + old_post_ids = [p.post_id for p in old_posts] + new_posts = db.session \ .query(db.Post) \ - .filter(db.Post.post_id.in_(post_ids)) \ + .filter(db.Post.post_id.in_(new_post_ids)) \ .all() - if len(relations) != len(post_ids): + if len(new_posts) != len(new_post_ids): raise InvalidPostRelationError('One of relations does not exist.') - post.relating_to = relations + + relations_to_del = [p for p in old_posts if p.post_id not in new_post_ids] + relations_to_add = [p for p in new_posts if p.post_id not in old_post_ids] + for relation in relations_to_del: + post.relations.remove(relation) + relation.relations.remove(post) + for relation in relations_to_add: + post.relations.append(relation) + relation.relations.append(post) def update_post_notes(post, notes): post.notes = [] diff --git a/server/szurubooru/func/snapshots.py b/server/szurubooru/func/snapshots.py index 49dea66..c65e47b 100644 --- a/server/szurubooru/func/snapshots.py +++ b/server/szurubooru/func/snapshots.py @@ -16,7 +16,7 @@ def get_post_snapshot(post): 'checksum': post.checksum, 'tags': sorted([tag.first_name for tag in post.tags]), 'relations': sorted([ - rel.post_id for rel in post.relating_to + post.related_by]), + rel.post_id for rel in post.relations]), 'notes': sorted([{ 'polygon': note.polygon, 'text': note.text, diff --git a/server/szurubooru/search/configs/post_search_config.py b/server/szurubooru/search/configs/post_search_config.py index b83002c..4fca489 100644 --- a/server/szurubooru/search/configs/post_search_config.py +++ b/server/szurubooru/search/configs/post_search_config.py @@ -97,8 +97,7 @@ class PostSearchConfig(BaseSearchConfig): defer(db.Post.tag_count), subqueryload(db.Post.tags).subqueryload(db.Tag.names), lazyload(db.Post.user), - lazyload(db.Post.relating_to), - lazyload(db.Post.related_by), + lazyload(db.Post.relations), lazyload(db.Post.notes), lazyload(db.Post.favorited_by), ) diff --git a/server/szurubooru/tests/db/test_post.py b/server/szurubooru/tests/db/test_post.py index 3a57e90..7a04e32 100644 --- a/server/szurubooru/tests/db/test_post.py +++ b/server/szurubooru/tests/db/test_post.py @@ -19,8 +19,8 @@ def test_saving_post(post_factory, user_factory, tag_factory): post.user = user post.tags.append(tag1) post.tags.append(tag2) - post.relating_to.append(related_post1) - post.relating_to.append(related_post2) + post.relations.append(related_post1) + post.relations.append(related_post2) db.session.commit() db.session.refresh(post) @@ -33,12 +33,10 @@ def test_saving_post(post_factory, user_factory, tag_factory): assert post.checksum == 'deadbeef' assert post.creation_time == datetime(1997, 1, 1) assert post.last_edit_time == datetime(1998, 1, 1) - assert len(post.relating_to) == 2 - assert len(related_post1.relating_to) == 0 - assert len(related_post1.relating_to) == 0 - assert len(post.related_by) == 0 - assert len(related_post1.related_by) == 1 - assert len(related_post1.related_by) == 1 + assert len(post.relations) == 2 + # relation bidirectionality is realized on business level in func.posts + assert len(related_post1.relations) == 0 + assert len(related_post2.relations) == 0 def test_cascade_deletions(post_factory, user_factory, tag_factory): user = user_factory() @@ -73,8 +71,8 @@ def test_cascade_deletions(post_factory, user_factory, tag_factory): post.user = user post.tags.append(tag1) post.tags.append(tag2) - post.relating_to.append(related_post1) - post.relating_to.append(related_post2) + post.relations.append(related_post1) + post.relations.append(related_post2) post.scores.append(score) post.favorited_by.append(favorite) post.features.append(feature) @@ -83,7 +81,7 @@ def test_cascade_deletions(post_factory, user_factory, tag_factory): assert not db.session.dirty assert post.user is not None and post.user.user_id is not None - assert len(post.relating_to) == 2 + assert len(post.relations) == 2 assert db.session.query(db.User).count() == 1 assert db.session.query(db.Tag).count() == 2 assert db.session.query(db.Post).count() == 3 diff --git a/server/szurubooru/tests/func/test_posts.py b/server/szurubooru/tests/func/test_posts.py index 1fde001..e83934b 100644 --- a/server/szurubooru/tests/func/test_posts.py +++ b/server/szurubooru/tests/func/test_posts.py @@ -394,11 +394,22 @@ def test_update_post_relations(post_factory): relation2 = post_factory() db.session.add_all([relation1, relation2]) db.session.flush() - post = db.Post() + post = post_factory() posts.update_post_relations(post, [relation1.post_id, relation2.post_id]) - assert len(post.relating_to) == 2 - assert post.relating_to[0].post_id == relation1.post_id - assert post.relating_to[1].post_id == relation2.post_id + assert len(post.relations) == 2 + assert post.relations[0].post_id == relation1.post_id + assert post.relations[1].post_id == relation2.post_id + +def test_relation_bidirectionality(post_factory): + relation1 = post_factory() + relation2 = post_factory() + db.session.add_all([relation1, relation2]) + db.session.flush() + post = post_factory() + posts.update_post_relations(post, [relation1.post_id, relation2.post_id]) + posts.update_post_relations(relation1, []) + assert len(post.relations) == 1 + assert post.relations[0].post_id == relation2.post_id def test_update_post_non_existing_relations(): post = db.Post() diff --git a/server/szurubooru/tests/func/test_snapshots.py b/server/szurubooru/tests/func/test_snapshots.py index abb182d..c82b82f 100644 --- a/server/szurubooru/tests/func/test_snapshots.py +++ b/server/szurubooru/tests/func/test_snapshots.py @@ -38,8 +38,8 @@ def test_serializing_post(post_factory, user_factory, tag_factory): post.source = 'example.com' post.tags.append(tag1) post.tags.append(tag2) - post.relating_to.append(related_post1) - post.relating_to.append(related_post2) + post.relations.append(related_post1) + post.relations.append(related_post2) post.scores.append(score) post.favorited_by.append(favorite) post.features.append(feature)