server/tags: fix sorting tag siblings

This commit is contained in:
rr- 2016-05-22 11:59:16 +02:00
parent cf3b97b8de
commit 6a48020426
2 changed files with 17 additions and 12 deletions

View File

@ -138,13 +138,13 @@ def get_tag_siblings(tag):
pt_alias1 = sqlalchemy.orm.aliased(db.PostTag) pt_alias1 = sqlalchemy.orm.aliased(db.PostTag)
pt_alias2 = sqlalchemy.orm.aliased(db.PostTag) pt_alias2 = sqlalchemy.orm.aliased(db.PostTag)
result = db.session \ result = db.session \
.query(tag_alias, sqlalchemy.func.count(tag_alias.tag_id)) \ .query(tag_alias, sqlalchemy.func.count(pt_alias2.post_id)) \
.join(pt_alias1, pt_alias1.tag_id == tag_alias.tag_id) \ .join(pt_alias1, pt_alias1.tag_id == tag_alias.tag_id) \
.join(pt_alias2, pt_alias2.post_id == pt_alias1.post_id) \ .join(pt_alias2, pt_alias2.post_id == pt_alias1.post_id) \
.filter(pt_alias2.tag_id == tag.tag_id) \ .filter(pt_alias2.tag_id == tag.tag_id) \
.filter(pt_alias1.tag_id != tag.tag_id) \ .filter(pt_alias1.tag_id != tag.tag_id) \
.group_by(tag_alias.tag_id) \ .group_by(tag_alias.tag_id) \
.order_by(tag_alias.post_count.desc()) \ .order_by(sqlalchemy.func.count(pt_alias2.post_id).desc()) \
.limit(50) .limit(50)
return result return result

View File

@ -4,11 +4,11 @@ from szurubooru import api, db, errors
from szurubooru.func import util, tags from szurubooru.func import util, tags
def assert_results(result, expected_tag_names_and_occurrences): def assert_results(result, expected_tag_names_and_occurrences):
actual_tag_names_and_occurences = {} actual_tag_names_and_occurences = []
for item in result['results']: for item in result['results']:
tag_name = item['tag']['names'][0] tag_name = item['tag']['names'][0]
occurrences = item['occurrences'] occurrences = item['occurrences']
actual_tag_names_and_occurences[tag_name] = occurrences actual_tag_names_and_occurences.append((tag_name, occurrences))
assert actual_tag_names_and_occurences == expected_tag_names_and_occurrences assert actual_tag_names_and_occurences == expected_tag_names_and_occurrences
@pytest.fixture @pytest.fixture
@ -33,7 +33,7 @@ def test_unused(test_ctx):
result = test_ctx.api.get( result = test_ctx.api.get(
test_ctx.context_factory( test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), 'tag') user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), 'tag')
assert_results(result, {}) assert_results(result, [])
def test_used_alone(test_ctx): def test_used_alone(test_ctx):
tag = test_ctx.tag_factory(names=['tag']) tag = test_ctx.tag_factory(names=['tag'])
@ -43,7 +43,7 @@ def test_used_alone(test_ctx):
result = test_ctx.api.get( result = test_ctx.api.get(
test_ctx.context_factory( test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), 'tag') user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), 'tag')
assert_results(result, {}) assert_results(result, [])
def test_used_with_others(test_ctx): def test_used_with_others(test_ctx):
tag1 = test_ctx.tag_factory(names=['tag1']) tag1 = test_ctx.tag_factory(names=['tag1'])
@ -54,11 +54,11 @@ def test_used_with_others(test_ctx):
result = test_ctx.api.get( result = test_ctx.api.get(
test_ctx.context_factory( test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), 'tag1') user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), 'tag1')
assert_results(result, {'tag2': 1}) assert_results(result, [('tag2', 1)])
result = test_ctx.api.get( result = test_ctx.api.get(
test_ctx.context_factory( test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), 'tag2') user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), 'tag2')
assert_results(result, {'tag1': 1}) assert_results(result, [('tag1', 1)])
def test_used_with_multiple_others(test_ctx): def test_used_with_multiple_others(test_ctx):
tag1 = test_ctx.tag_factory(names=['tag1']) tag1 = test_ctx.tag_factory(names=['tag1'])
@ -66,21 +66,26 @@ def test_used_with_multiple_others(test_ctx):
tag3 = test_ctx.tag_factory(names=['tag3']) tag3 = test_ctx.tag_factory(names=['tag3'])
post1 = test_ctx.post_factory() post1 = test_ctx.post_factory()
post2 = test_ctx.post_factory() post2 = test_ctx.post_factory()
post3 = test_ctx.post_factory()
post4 = test_ctx.post_factory()
post1.tags = [tag1, tag2, tag3] post1.tags = [tag1, tag2, tag3]
post2.tags = [tag1, tag3] post2.tags = [tag1, tag3]
db.session.add_all([post1, post2, tag1, tag2, tag3]) post3.tags = [tag2]
post4.tags = [tag2]
db.session.add_all([post1, post2, post3, post4, tag1, tag2, tag3])
result = test_ctx.api.get( result = test_ctx.api.get(
test_ctx.context_factory( test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), 'tag1') user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), 'tag1')
assert_results(result, {'tag2': 1, 'tag3': 2}) assert_results(result, [('tag3', 2), ('tag2', 1)])
result = test_ctx.api.get( result = test_ctx.api.get(
test_ctx.context_factory( test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), 'tag2') user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), 'tag2')
assert_results(result, {'tag1': 1, 'tag3': 1}) assert_results(result, [('tag1', 1), ('tag3', 1)])
result = test_ctx.api.get( result = test_ctx.api.get(
test_ctx.context_factory( test_ctx.context_factory(
user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), 'tag3') user=test_ctx.user_factory(rank=db.User.RANK_REGULAR)), 'tag3')
assert_results(result, {'tag1': 2, 'tag2': 1}) # even though tag2 is used more widely, tag1 is more relevant to tag3
assert_results(result, [('tag1', 2), ('tag2', 1)])
def test_trying_to_retrieve_non_existing(test_ctx): def test_trying_to_retrieve_non_existing(test_ctx):
with pytest.raises(tags.TagNotFoundError): with pytest.raises(tags.TagNotFoundError):