server/favorites: favoriting sets score to 1

This commit is contained in:
rr- 2016-05-21 22:29:31 +02:00
parent 519f606a39
commit 16d4d3ca68
5 changed files with 23 additions and 10 deletions

View File

@ -1,11 +1,14 @@
import datetime import datetime
from szurubooru import db from szurubooru import db, errors
from szurubooru.func import scores
class InvalidFavoriteTargetError(errors.ValidationError): pass
def _get_table_info(entity): def _get_table_info(entity):
resource_type, _, _ = db.util.get_resource_info(entity) resource_type, _, _ = db.util.get_resource_info(entity)
if resource_type == 'post': if resource_type == 'post':
return db.PostFavorite, lambda table: table.post_id return db.PostFavorite, lambda table: table.post_id
assert False raise InvalidFavoriteTargetError()
def _get_fav_entity(entity, user): def _get_fav_entity(entity, user):
return db.util.get_aux_entity(db.session, _get_table_info, entity, user) return db.util.get_aux_entity(db.session, _get_table_info, entity, user)
@ -19,6 +22,10 @@ def unset_favorite(entity, user):
db.session.delete(fav_entity) db.session.delete(fav_entity)
def set_favorite(entity, user): def set_favorite(entity, user):
try:
scores.set_score(entity, user, 1)
except scores.InvalidScoreTargetError:
pass
fav_entity = _get_fav_entity(entity, user) fav_entity = _get_fav_entity(entity, user)
if not fav_entity: if not fav_entity:
table, get_column = _get_table_info(entity) table, get_column = _get_table_info(entity)

View File

@ -1,7 +1,8 @@
import datetime import datetime
from szurubooru import db, errors from szurubooru import db, errors
class InvalidScoreError(errors.ValidationError): pass class InvalidScoreTargetError(errors.ValidationError): pass
class InvalidScoreValueError(errors.ValidationError): pass
def _get_table_info(entity): def _get_table_info(entity):
resource_type, _, _ = db.util.get_resource_info(entity) resource_type, _, _ = db.util.get_resource_info(entity)
@ -9,7 +10,7 @@ def _get_table_info(entity):
return db.PostScore, lambda table: table.post_id return db.PostScore, lambda table: table.post_id
elif resource_type == 'comment': elif resource_type == 'comment':
return db.CommentScore, lambda table: table.comment_id return db.CommentScore, lambda table: table.comment_id
assert False raise InvalidScoreTargetError()
def _get_score_entity(entity, user): def _get_score_entity(entity, user):
return db.util.get_aux_entity(db.session, _get_table_info, entity, user) return db.util.get_aux_entity(db.session, _get_table_info, entity, user)
@ -31,7 +32,7 @@ def set_score(entity, user, score):
delete_score(entity, user) delete_score(entity, user)
return return
if score not in (-1, 1): if score not in (-1, 1):
raise InvalidScoreError( raise InvalidScoreValueError(
'Score %r is invalid. Valid scores: %r.' % (score, (-1, 1))) 'Score %r is invalid. Valid scores: %r.' % (score, (-1, 1)))
score_entity = _get_score_entity(entity, user) score_entity = _get_score_entity(entity, user)
if score_entity: if score_entity:

View File

@ -110,8 +110,8 @@ def test_ratings_from_multiple_users(test_ctx, fake_datetime):
@pytest.mark.parametrize('input,expected_exception', [ @pytest.mark.parametrize('input,expected_exception', [
({'score': None}, errors.ValidationError), ({'score': None}, errors.ValidationError),
({'score': ''}, errors.ValidationError), ({'score': ''}, errors.ValidationError),
({'score': -2}, scores.InvalidScoreError), ({'score': -2}, scores.InvalidScoreValueError),
({'score': 2}, scores.InvalidScoreError), ({'score': 2}, scores.InvalidScoreValueError),
({'score': [1]}, errors.ValidationError), ({'score': [1]}, errors.ValidationError),
]) ])
def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception): def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception):

View File

@ -23,10 +23,11 @@ def test_ctx(
ret.api = api.PostFavoriteApi() ret.api = api.PostFavoriteApi()
return ret return ret
def test_simple_rating(test_ctx, fake_datetime): def test_adding_to_favorites(test_ctx, fake_datetime):
post = test_ctx.post_factory() post = test_ctx.post_factory()
db.session.add(post) db.session.add(post)
db.session.commit() db.session.commit()
assert post.score == 0
with fake_datetime('1997-12-01'): with fake_datetime('1997-12-01'):
result = test_ctx.api.post( result = test_ctx.api.post(
test_ctx.context_factory(user=test_ctx.user_factory()), test_ctx.context_factory(user=test_ctx.user_factory()),
@ -37,21 +38,25 @@ def test_simple_rating(test_ctx, fake_datetime):
assert db.session.query(db.PostFavorite).count() == 1 assert db.session.query(db.PostFavorite).count() == 1
assert post is not None assert post is not None
assert post.favorite_count == 1 assert post.favorite_count == 1
assert post.score == 1
def test_removing_from_favorites(test_ctx, fake_datetime): def test_removing_from_favorites(test_ctx, fake_datetime):
user = test_ctx.user_factory() user = test_ctx.user_factory()
post = test_ctx.post_factory() post = test_ctx.post_factory()
db.session.add(post) db.session.add(post)
db.session.commit() db.session.commit()
assert post.score == 0
with fake_datetime('1997-12-01'): with fake_datetime('1997-12-01'):
result = test_ctx.api.post( result = test_ctx.api.post(
test_ctx.context_factory(user=user), test_ctx.context_factory(user=user),
post.post_id) post.post_id)
assert post.score == 1
with fake_datetime('1997-12-02'): with fake_datetime('1997-12-02'):
result = test_ctx.api.delete( result = test_ctx.api.delete(
test_ctx.context_factory(user=user), test_ctx.context_factory(user=user),
post.post_id) post.post_id)
post = db.session.query(db.Post).one() post = db.session.query(db.Post).one()
assert post.score == 1
assert db.session.query(db.PostFavorite).count() == 0 assert db.session.query(db.PostFavorite).count() == 0
assert post.favorite_count == 0 assert post.favorite_count == 0

View File

@ -106,8 +106,8 @@ def test_ratings_from_multiple_users(test_ctx, fake_datetime):
@pytest.mark.parametrize('input,expected_exception', [ @pytest.mark.parametrize('input,expected_exception', [
({'score': None}, errors.ValidationError), ({'score': None}, errors.ValidationError),
({'score': ''}, errors.ValidationError), ({'score': ''}, errors.ValidationError),
({'score': -2}, scores.InvalidScoreError), ({'score': -2}, scores.InvalidScoreValueError),
({'score': 2}, scores.InvalidScoreError), ({'score': 2}, scores.InvalidScoreValueError),
({'score': [1]}, errors.ValidationError), ({'score': [1]}, errors.ValidationError),
]) ])
def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception): def test_trying_to_pass_invalid_input(test_ctx, input, expected_exception):