diff --git a/server/szurubooru/func/tags.py b/server/szurubooru/func/tags.py index dc145f3..3fbf26c 100644 --- a/server/szurubooru/func/tags.py +++ b/server/szurubooru/func/tags.py @@ -138,13 +138,13 @@ def get_tag_siblings(tag): pt_alias1 = sqlalchemy.orm.aliased(db.PostTag) pt_alias2 = sqlalchemy.orm.aliased(db.PostTag) result = db.session \ - .query(tag_alias, sqlalchemy.func.count(tag_alias.tag_id)) \ + .query(tag_alias, sqlalchemy.func.count(pt_alias2.post_id)) \ .join(pt_alias1, pt_alias1.tag_id == tag_alias.tag_id) \ .join(pt_alias2, pt_alias2.post_id == pt_alias1.post_id) \ .filter(pt_alias2.tag_id == tag.tag_id) \ .filter(pt_alias1.tag_id != tag.tag_id) \ .group_by(tag_alias.tag_id) \ - .order_by(tag_alias.post_count.desc()) \ + .order_by(sqlalchemy.func.count(pt_alias2.post_id).desc()) \ .limit(50) return result diff --git a/server/szurubooru/tests/api/test_tag_siblings_retrieving.py b/server/szurubooru/tests/api/test_tag_siblings_retrieving.py index 37ac7de..cddb04a 100644 --- a/server/szurubooru/tests/api/test_tag_siblings_retrieving.py +++ b/server/szurubooru/tests/api/test_tag_siblings_retrieving.py @@ -4,11 +4,11 @@ from szurubooru import api, db, errors from szurubooru.func import util, tags def assert_results(result, expected_tag_names_and_occurrences): - actual_tag_names_and_occurences = {} + actual_tag_names_and_occurences = [] for item in result['results']: tag_name = item['tag']['names'][0] occurrences = item['occurrences'] - actual_tag_names_and_occurences[tag_name] = occurrences + actual_tag_names_and_occurences.append((tag_name, occurrences)) assert actual_tag_names_and_occurences == expected_tag_names_and_occurrences @pytest.fixture @@ -33,7 +33,7 @@ def test_unused(test_ctx): result = test_ctx.api.get( test_ctx.context_factory( user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), 'tag') - assert_results(result, {}) + assert_results(result, []) def test_used_alone(test_ctx): tag = test_ctx.tag_factory(names=['tag']) @@ -43,7 +43,7 @@ def test_used_alone(test_ctx): result = test_ctx.api.get( test_ctx.context_factory( user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), 'tag') - assert_results(result, {}) + assert_results(result, []) def test_used_with_others(test_ctx): tag1 = test_ctx.tag_factory(names=['tag1']) @@ -54,11 +54,11 @@ def test_used_with_others(test_ctx): result = test_ctx.api.get( test_ctx.context_factory( user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), 'tag1') - assert_results(result, {'tag2': 1}) + assert_results(result, [('tag2', 1)]) result = test_ctx.api.get( test_ctx.context_factory( user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), 'tag2') - assert_results(result, {'tag1': 1}) + assert_results(result, [('tag1', 1)]) def test_used_with_multiple_others(test_ctx): tag1 = test_ctx.tag_factory(names=['tag1']) @@ -66,21 +66,26 @@ def test_used_with_multiple_others(test_ctx): tag3 = test_ctx.tag_factory(names=['tag3']) post1 = test_ctx.post_factory() post2 = test_ctx.post_factory() + post3 = test_ctx.post_factory() + post4 = test_ctx.post_factory() post1.tags = [tag1, tag2, tag3] post2.tags = [tag1, tag3] - db.session.add_all([post1, post2, tag1, tag2, tag3]) + post3.tags = [tag2] + post4.tags = [tag2] + db.session.add_all([post1, post2, post3, post4, tag1, tag2, tag3]) result = test_ctx.api.get( test_ctx.context_factory( user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), 'tag1') - assert_results(result, {'tag2': 1, 'tag3': 2}) + assert_results(result, [('tag3', 2), ('tag2', 1)]) result = test_ctx.api.get( test_ctx.context_factory( user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), 'tag2') - assert_results(result, {'tag1': 1, 'tag3': 1}) + assert_results(result, [('tag1', 1), ('tag3', 1)]) result = test_ctx.api.get( test_ctx.context_factory( user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), 'tag3') - assert_results(result, {'tag1': 2, 'tag2': 1}) + # even though tag2 is used more widely, tag1 is more relevant to tag3 + assert_results(result, [('tag1', 2), ('tag2', 1)]) def test_trying_to_retrieve_non_existing(test_ctx): with pytest.raises(tags.TagNotFoundError):