From cd15cdff7a399160113269b8ee70d62272329318 Mon Sep 17 00:00:00 2001 From: rr- Date: Tue, 10 May 2016 11:56:24 +0200 Subject: [PATCH] server/scores+favorites: merge duplicate code --- server/szurubooru/db/__init__.py | 1 + server/szurubooru/db/util.py | 32 +++++++++++++++++++++++++++++ server/szurubooru/func/favorites.py | 13 +++--------- server/szurubooru/func/scores.py | 13 +++--------- server/szurubooru/func/snapshots.py | 5 ++--- server/szurubooru/func/util.py | 24 ---------------------- 6 files changed, 41 insertions(+), 47 deletions(-) create mode 100644 server/szurubooru/db/util.py diff --git a/server/szurubooru/db/__init__.py b/server/szurubooru/db/__init__.py index 49ca156..1dbed94 100644 --- a/server/szurubooru/db/__init__.py +++ b/server/szurubooru/db/__init__.py @@ -22,3 +22,4 @@ from szurubooru.db.session import ( session, reset_query_count, get_query_count) +import szurubooru.db.util diff --git a/server/szurubooru/db/util.py b/server/szurubooru/db/util.py new file mode 100644 index 0000000..d9e5d67 --- /dev/null +++ b/server/szurubooru/db/util.py @@ -0,0 +1,32 @@ +from sqlalchemy.inspection import inspect + +def get_resource_info(entity): + serializers = { + 'tag': lambda tag: tag.first_name, + 'tag_category': lambda category: category.name, + 'comment': lambda comment: comment.comment_id, + 'post': lambda post: post.post_id, + } + + resource_type = entity.__table__.name + assert resource_type in serializers + + primary_key = inspect(entity).identity + assert primary_key is not None + assert len(primary_key) == 1 + + resource_repr = serializers[resource_type](entity) + assert resource_repr + + resource_id = primary_key[0] + assert resource_id + + return (resource_type, resource_id, resource_repr) + +def get_aux_entity(session, get_table_info, entity, user): + table, get_column = get_table_info(entity) + return session \ + .query(table) \ + .filter(get_column(table) == get_column(entity)) \ + .filter(table.user_id == user.user_id) \ + .one_or_none() diff --git a/server/szurubooru/func/favorites.py b/server/szurubooru/func/favorites.py index 9c3dda7..65adc47 100644 --- a/server/szurubooru/func/favorites.py +++ b/server/szurubooru/func/favorites.py @@ -1,21 +1,14 @@ import datetime from szurubooru import db -from szurubooru.func import util def _get_table_info(entity): - resource_type, _, _ = util.get_resource_info(entity) + resource_type, _, _ = db.util.get_resource_info(entity) if resource_type == 'post': return db.PostFavorite, lambda table: table.post_id - else: - assert False + assert False def _get_fav_entity(entity, user): - table, get_column = _get_table_info(entity) - return db.session \ - .query(table) \ - .filter(get_column(table) == get_column(entity)) \ - .filter(table.user_id == user.user_id) \ - .one_or_none() + return db.util.get_aux_entity(db.session, _get_table_info, entity, user) def has_favorited(entity, user): return _get_fav_entity(entity, user) is not None diff --git a/server/szurubooru/func/scores.py b/server/szurubooru/func/scores.py index 62ee773..03600dd 100644 --- a/server/szurubooru/func/scores.py +++ b/server/szurubooru/func/scores.py @@ -1,25 +1,18 @@ import datetime from szurubooru import db, errors -from szurubooru.func import util class InvalidScoreError(errors.ValidationError): pass def _get_table_info(entity): - resource_type, _, _ = util.get_resource_info(entity) + resource_type, _, _ = db.util.get_resource_info(entity) if resource_type == 'post': return db.PostScore, lambda table: table.post_id elif resource_type == 'comment': return db.CommentScore, lambda table: table.comment_id - else: - assert False + assert False def _get_score_entity(entity, user): - table, get_column = _get_table_info(entity) - return db.session \ - .query(table) \ - .filter(get_column(table) == get_column(entity)) \ - .filter(table.user_id == user.user_id) \ - .one_or_none() + return db.util.get_aux_entity(db.session, _get_table_info, entity, user) def delete_score(entity, user): score_entity = _get_score_entity(entity, user) diff --git a/server/szurubooru/func/snapshots.py b/server/szurubooru/func/snapshots.py index 4ecd9d3..44deb62 100644 --- a/server/szurubooru/func/snapshots.py +++ b/server/szurubooru/func/snapshots.py @@ -1,6 +1,5 @@ import datetime from szurubooru import db -from szurubooru.func import util def get_tag_snapshot(tag): return { @@ -49,7 +48,7 @@ def get_previous_snapshot(snapshot): .first() def get_snapshots(entity): - resource_type, resource_id, _ = util.get_resource_info(entity) + resource_type, resource_id, _ = db.util.get_resource_info(entity) return db.session \ .query(db.Snapshot) \ .filter(db.Snapshot.resource_type == resource_type) \ @@ -81,7 +80,7 @@ def get_serialized_history(entity): return ret def _save(operation, entity, auth_user): - resource_type, resource_id, resource_repr = util.get_resource_info(entity) + resource_type, resource_id, resource_repr = db.util.get_resource_info(entity) now = datetime.datetime.now() snapshot = db.Snapshot() diff --git a/server/szurubooru/func/util.py b/server/szurubooru/func/util.py index a80485d..f19e7fc 100644 --- a/server/szurubooru/func/util.py +++ b/server/szurubooru/func/util.py @@ -1,7 +1,6 @@ import datetime import hashlib import re -from sqlalchemy.inspection import inspect from szurubooru.errors import ValidationError def unalias_dict(input_dict): @@ -23,29 +22,6 @@ def get_md5(source): def flip(source): return {v: k for k, v in source.items()} -def get_resource_info(entity): - serializers = { - 'tag': lambda tag: tag.first_name, - 'tag_category': lambda category: category.name, - 'comment': lambda comment: comment.comment_id, - 'post': lambda post: post.post_id, - } - - resource_type = entity.__table__.name - assert resource_type in serializers - - primary_key = inspect(entity).identity - assert primary_key is not None - assert len(primary_key) == 1 - - resource_repr = serializers[resource_type](entity) - assert resource_repr - - resource_id = primary_key[0] - assert resource_id - - return (resource_type, resource_id, resource_repr) - def is_valid_email(email): ''' Return whether given email address is valid or empty. ''' return not email or re.match(r'^[^@]*@[^@]*\.[^@]*$', email)