server/scores+favorites: merge duplicate code

This commit is contained in:
rr- 2016-05-10 11:56:24 +02:00
parent f140ae6176
commit cd15cdff7a
6 changed files with 41 additions and 47 deletions

View File

@ -22,3 +22,4 @@ from szurubooru.db.session import (
session, session,
reset_query_count, reset_query_count,
get_query_count) get_query_count)
import szurubooru.db.util

View File

@ -0,0 +1,32 @@
from sqlalchemy.inspection import inspect
def get_resource_info(entity):
serializers = {
'tag': lambda tag: tag.first_name,
'tag_category': lambda category: category.name,
'comment': lambda comment: comment.comment_id,
'post': lambda post: post.post_id,
}
resource_type = entity.__table__.name
assert resource_type in serializers
primary_key = inspect(entity).identity
assert primary_key is not None
assert len(primary_key) == 1
resource_repr = serializers[resource_type](entity)
assert resource_repr
resource_id = primary_key[0]
assert resource_id
return (resource_type, resource_id, resource_repr)
def get_aux_entity(session, get_table_info, entity, user):
table, get_column = get_table_info(entity)
return session \
.query(table) \
.filter(get_column(table) == get_column(entity)) \
.filter(table.user_id == user.user_id) \
.one_or_none()

View File

@ -1,21 +1,14 @@
import datetime import datetime
from szurubooru import db from szurubooru import db
from szurubooru.func import util
def _get_table_info(entity): def _get_table_info(entity):
resource_type, _, _ = util.get_resource_info(entity) resource_type, _, _ = db.util.get_resource_info(entity)
if resource_type == 'post': if resource_type == 'post':
return db.PostFavorite, lambda table: table.post_id return db.PostFavorite, lambda table: table.post_id
else:
assert False assert False
def _get_fav_entity(entity, user): def _get_fav_entity(entity, user):
table, get_column = _get_table_info(entity) return db.util.get_aux_entity(db.session, _get_table_info, entity, user)
return db.session \
.query(table) \
.filter(get_column(table) == get_column(entity)) \
.filter(table.user_id == user.user_id) \
.one_or_none()
def has_favorited(entity, user): def has_favorited(entity, user):
return _get_fav_entity(entity, user) is not None return _get_fav_entity(entity, user) is not None

View File

@ -1,25 +1,18 @@
import datetime import datetime
from szurubooru import db, errors from szurubooru import db, errors
from szurubooru.func import util
class InvalidScoreError(errors.ValidationError): pass class InvalidScoreError(errors.ValidationError): pass
def _get_table_info(entity): def _get_table_info(entity):
resource_type, _, _ = util.get_resource_info(entity) resource_type, _, _ = db.util.get_resource_info(entity)
if resource_type == 'post': if resource_type == 'post':
return db.PostScore, lambda table: table.post_id return db.PostScore, lambda table: table.post_id
elif resource_type == 'comment': elif resource_type == 'comment':
return db.CommentScore, lambda table: table.comment_id return db.CommentScore, lambda table: table.comment_id
else:
assert False assert False
def _get_score_entity(entity, user): def _get_score_entity(entity, user):
table, get_column = _get_table_info(entity) return db.util.get_aux_entity(db.session, _get_table_info, entity, user)
return db.session \
.query(table) \
.filter(get_column(table) == get_column(entity)) \
.filter(table.user_id == user.user_id) \
.one_or_none()
def delete_score(entity, user): def delete_score(entity, user):
score_entity = _get_score_entity(entity, user) score_entity = _get_score_entity(entity, user)

View File

@ -1,6 +1,5 @@
import datetime import datetime
from szurubooru import db from szurubooru import db
from szurubooru.func import util
def get_tag_snapshot(tag): def get_tag_snapshot(tag):
return { return {
@ -49,7 +48,7 @@ def get_previous_snapshot(snapshot):
.first() .first()
def get_snapshots(entity): def get_snapshots(entity):
resource_type, resource_id, _ = util.get_resource_info(entity) resource_type, resource_id, _ = db.util.get_resource_info(entity)
return db.session \ return db.session \
.query(db.Snapshot) \ .query(db.Snapshot) \
.filter(db.Snapshot.resource_type == resource_type) \ .filter(db.Snapshot.resource_type == resource_type) \
@ -81,7 +80,7 @@ def get_serialized_history(entity):
return ret return ret
def _save(operation, entity, auth_user): def _save(operation, entity, auth_user):
resource_type, resource_id, resource_repr = util.get_resource_info(entity) resource_type, resource_id, resource_repr = db.util.get_resource_info(entity)
now = datetime.datetime.now() now = datetime.datetime.now()
snapshot = db.Snapshot() snapshot = db.Snapshot()

View File

@ -1,7 +1,6 @@
import datetime import datetime
import hashlib import hashlib
import re import re
from sqlalchemy.inspection import inspect
from szurubooru.errors import ValidationError from szurubooru.errors import ValidationError
def unalias_dict(input_dict): def unalias_dict(input_dict):
@ -23,29 +22,6 @@ def get_md5(source):
def flip(source): def flip(source):
return {v: k for k, v in source.items()} return {v: k for k, v in source.items()}
def get_resource_info(entity):
serializers = {
'tag': lambda tag: tag.first_name,
'tag_category': lambda category: category.name,
'comment': lambda comment: comment.comment_id,
'post': lambda post: post.post_id,
}
resource_type = entity.__table__.name
assert resource_type in serializers
primary_key = inspect(entity).identity
assert primary_key is not None
assert len(primary_key) == 1
resource_repr = serializers[resource_type](entity)
assert resource_repr
resource_id = primary_key[0]
assert resource_id
return (resource_type, resource_id, resource_repr)
def is_valid_email(email): def is_valid_email(email):
''' Return whether given email address is valid or empty. ''' ''' Return whether given email address is valid or empty. '''
return not email or re.match(r'^[^@]*@[^@]*\.[^@]*$', email) return not email or re.match(r'^[^@]*@[^@]*\.[^@]*$', email)