server/general: add assertions
This commit is contained in:
parent
bb86e9bf56
commit
c2bbf7b62c
|
@ -40,6 +40,7 @@ def create_password():
|
|||
return ''.join(random.choice(alphabet[l]) for l in list(pattern))
|
||||
|
||||
def is_valid_password(user, password):
|
||||
assert user
|
||||
salt, valid_hash = user.password_salt, user.password_hash
|
||||
possible_hashes = [
|
||||
get_password_hash(salt, password),
|
||||
|
@ -48,6 +49,7 @@ def is_valid_password(user, password):
|
|||
return valid_hash in possible_hashes
|
||||
|
||||
def has_privilege(user, privilege_name):
|
||||
assert user
|
||||
all_ranks = list(RANK_MAP.keys())
|
||||
assert privilege_name in config.config['privileges']
|
||||
assert user.rank in all_ranks
|
||||
|
@ -57,11 +59,13 @@ def has_privilege(user, privilege_name):
|
|||
return user.rank in good_ranks
|
||||
|
||||
def verify_privilege(user, privilege_name):
|
||||
assert user
|
||||
if not has_privilege(user, privilege_name):
|
||||
raise errors.AuthError('Insufficient privileges to do this.')
|
||||
|
||||
def generate_authentication_token(user):
|
||||
''' Generate nonguessable challenge (e.g. links in password reminder). '''
|
||||
assert user
|
||||
digest = hashlib.md5()
|
||||
digest.update(config.config['secret'].encode('utf8'))
|
||||
digest.update(user.password_salt.encode('utf8'))
|
||||
|
|
|
@ -42,6 +42,7 @@ def create_comment(user, post, text):
|
|||
return comment
|
||||
|
||||
def update_comment_text(comment, text):
|
||||
assert comment
|
||||
if not text:
|
||||
raise EmptyCommentTextError('Comment text cannot be empty.')
|
||||
comment.text = text
|
||||
|
|
|
@ -5,23 +5,32 @@ from szurubooru.func import scores
|
|||
class InvalidFavoriteTargetError(errors.ValidationError): pass
|
||||
|
||||
def _get_table_info(entity):
|
||||
assert entity
|
||||
resource_type, _, _ = db.util.get_resource_info(entity)
|
||||
if resource_type == 'post':
|
||||
return db.PostFavorite, lambda table: table.post_id
|
||||
raise InvalidFavoriteTargetError()
|
||||
|
||||
def _get_fav_entity(entity, user):
|
||||
assert entity
|
||||
assert user
|
||||
return db.util.get_aux_entity(db.session, _get_table_info, entity, user)
|
||||
|
||||
def has_favorited(entity, user):
|
||||
assert entity
|
||||
assert user
|
||||
return _get_fav_entity(entity, user) is not None
|
||||
|
||||
def unset_favorite(entity, user):
|
||||
assert entity
|
||||
assert user
|
||||
fav_entity = _get_fav_entity(entity, user)
|
||||
if fav_entity:
|
||||
db.session.delete(fav_entity)
|
||||
|
||||
def set_favorite(entity, user):
|
||||
assert entity
|
||||
assert user
|
||||
try:
|
||||
scores.set_score(entity, user, 1)
|
||||
except scores.InvalidScoreTargetError:
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
import urllib.request
|
||||
|
||||
def download(url):
|
||||
assert url
|
||||
request = urllib.request.Request(url)
|
||||
request.add_header('Referer', url)
|
||||
with urllib.request.urlopen(request) as handle:
|
||||
|
|
|
@ -36,27 +36,33 @@ FLAG_MAP = {
|
|||
}
|
||||
|
||||
def get_post_content_url(post):
|
||||
assert post
|
||||
return '%s/posts/%d.%s' % (
|
||||
config.config['data_url'].rstrip('/'),
|
||||
post.post_id,
|
||||
mime.get_extension(post.mime_type) or 'dat')
|
||||
|
||||
def get_post_thumbnail_url(post):
|
||||
assert post
|
||||
return '%s/generated-thumbnails/%d.jpg' % (
|
||||
config.config['data_url'].rstrip('/'),
|
||||
post.post_id)
|
||||
|
||||
def get_post_content_path(post):
|
||||
assert post
|
||||
return 'posts/%d.%s' % (
|
||||
post.post_id, mime.get_extension(post.mime_type) or 'dat')
|
||||
|
||||
def get_post_thumbnail_path(post):
|
||||
assert post
|
||||
return 'generated-thumbnails/%d.jpg' % (post.post_id)
|
||||
|
||||
def get_post_thumbnail_backup_path(post):
|
||||
assert post
|
||||
return 'posts/custom-thumbnails/%d.dat' % (post.post_id)
|
||||
|
||||
def serialize_note(note):
|
||||
assert note
|
||||
return {
|
||||
'polygon': note.polygon,
|
||||
'text': note.text,
|
||||
|
@ -175,6 +181,7 @@ def create_post(content, tag_names, user):
|
|||
return (post, new_tags)
|
||||
|
||||
def update_post_safety(post, safety):
|
||||
assert post
|
||||
safety = util.flip(SAFETY_MAP).get(safety, None)
|
||||
if not safety:
|
||||
raise InvalidPostSafetyError(
|
||||
|
@ -182,11 +189,13 @@ def update_post_safety(post, safety):
|
|||
post.safety = safety
|
||||
|
||||
def update_post_source(post, source):
|
||||
assert post
|
||||
if util.value_exceeds_column_size(source, db.Post.source):
|
||||
raise InvalidPostSourceError('Source is too long.')
|
||||
post.source = source
|
||||
|
||||
def update_post_content(post, content):
|
||||
assert post
|
||||
if not content:
|
||||
raise InvalidPostContentError('Post content missing.')
|
||||
post.mime_type = mime.get_mime_type(content)
|
||||
|
@ -227,6 +236,7 @@ def update_post_content(post, content):
|
|||
update_post_thumbnail(post, content=None, do_delete=False)
|
||||
|
||||
def update_post_thumbnail(post, content=None, do_delete=True):
|
||||
assert post
|
||||
if not content:
|
||||
content = files.get(get_post_content_path(post))
|
||||
if do_delete:
|
||||
|
@ -236,6 +246,7 @@ def update_post_thumbnail(post, content=None, do_delete=True):
|
|||
generate_post_thumbnail(post)
|
||||
|
||||
def generate_post_thumbnail(post):
|
||||
assert post
|
||||
if files.has(get_post_thumbnail_backup_path(post)):
|
||||
content = files.get(get_post_thumbnail_backup_path(post))
|
||||
else:
|
||||
|
@ -250,11 +261,13 @@ def generate_post_thumbnail(post):
|
|||
files.save(get_post_thumbnail_path(post), EMPTY_PIXEL)
|
||||
|
||||
def update_post_tags(post, tag_names):
|
||||
assert post
|
||||
existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names)
|
||||
post.tags = existing_tags + new_tags
|
||||
return new_tags
|
||||
|
||||
def update_post_relations(post, new_post_ids):
|
||||
assert post
|
||||
old_posts = post.relations
|
||||
old_post_ids = [p.post_id for p in old_posts]
|
||||
new_posts = db.session \
|
||||
|
@ -274,6 +287,7 @@ def update_post_relations(post, new_post_ids):
|
|||
relation.relations.append(post)
|
||||
|
||||
def update_post_notes(post, notes):
|
||||
assert post
|
||||
post.notes = []
|
||||
for note in notes:
|
||||
for field in ('polygon', 'text'):
|
||||
|
@ -309,6 +323,7 @@ def update_post_notes(post, notes):
|
|||
db.PostNote(polygon=note['polygon'], text=str(note['text'])))
|
||||
|
||||
def update_post_flags(post, flags):
|
||||
assert post
|
||||
target_flags = []
|
||||
for flag in flags:
|
||||
flag = util.flip(FLAG_MAP).get(flag, None)
|
||||
|
@ -319,6 +334,7 @@ def update_post_flags(post, flags):
|
|||
post.flags = target_flags
|
||||
|
||||
def feature_post(post, user):
|
||||
assert post
|
||||
post_feature = db.PostFeature()
|
||||
post_feature.time = datetime.datetime.utcnow()
|
||||
post_feature.post = post
|
||||
|
@ -326,4 +342,5 @@ def feature_post(post, user):
|
|||
db.session.add(post_feature)
|
||||
|
||||
def delete(post):
|
||||
assert post
|
||||
db.session.delete(post)
|
||||
|
|
|
@ -6,6 +6,7 @@ class InvalidScoreTargetError(errors.ValidationError): pass
|
|||
class InvalidScoreValueError(errors.ValidationError): pass
|
||||
|
||||
def _get_table_info(entity):
|
||||
assert entity
|
||||
resource_type, _, _ = db.util.get_resource_info(entity)
|
||||
if resource_type == 'post':
|
||||
return db.PostScore, lambda table: table.post_id
|
||||
|
@ -14,14 +15,19 @@ def _get_table_info(entity):
|
|||
raise InvalidScoreTargetError()
|
||||
|
||||
def _get_score_entity(entity, user):
|
||||
assert user
|
||||
return db.util.get_aux_entity(db.session, _get_table_info, entity, user)
|
||||
|
||||
def delete_score(entity, user):
|
||||
assert entity
|
||||
assert user
|
||||
score_entity = _get_score_entity(entity, user)
|
||||
if score_entity:
|
||||
db.session.delete(score_entity)
|
||||
|
||||
def get_score(entity, user):
|
||||
assert entity
|
||||
assert user
|
||||
table, get_column = _get_table_info(entity)
|
||||
row = db.session \
|
||||
.query(table.score) \
|
||||
|
@ -31,6 +37,8 @@ def get_score(entity, user):
|
|||
return row[0] if row else 0
|
||||
|
||||
def set_score(entity, user, score):
|
||||
assert entity
|
||||
assert user
|
||||
if not score:
|
||||
delete_score(entity, user)
|
||||
try:
|
||||
|
|
|
@ -40,6 +40,7 @@ serializers = {
|
|||
}
|
||||
|
||||
def get_previous_snapshot(snapshot):
|
||||
assert snapshot
|
||||
return db.session \
|
||||
.query(db.Snapshot) \
|
||||
.filter(db.Snapshot.resource_type == snapshot.resource_type) \
|
||||
|
@ -50,6 +51,7 @@ def get_previous_snapshot(snapshot):
|
|||
.first()
|
||||
|
||||
def get_snapshots(entity):
|
||||
assert entity
|
||||
resource_type, resource_id, _ = db.util.get_resource_info(entity)
|
||||
return db.session \
|
||||
.query(db.Snapshot) \
|
||||
|
@ -59,6 +61,7 @@ def get_snapshots(entity):
|
|||
.all()
|
||||
|
||||
def serialize_snapshot(snapshot, earlier_snapshot=()):
|
||||
assert snapshot
|
||||
if earlier_snapshot is ():
|
||||
earlier_snapshot = get_previous_snapshot(snapshot)
|
||||
return {
|
||||
|
@ -82,6 +85,8 @@ def get_serialized_history(entity):
|
|||
return ret
|
||||
|
||||
def _save(operation, entity, auth_user):
|
||||
assert operation
|
||||
assert entity
|
||||
resource_type, resource_id, resource_repr = db.util.get_resource_info(entity)
|
||||
now = datetime.datetime.utcnow()
|
||||
|
||||
|
@ -115,10 +120,13 @@ def _save(operation, entity, auth_user):
|
|||
db.session.add(snapshot)
|
||||
|
||||
def save_entity_creation(entity, auth_user):
|
||||
assert entity
|
||||
_save(db.Snapshot.OPERATION_CREATED, entity, auth_user)
|
||||
|
||||
def save_entity_modification(entity, auth_user):
|
||||
assert entity
|
||||
_save(db.Snapshot.OPERATION_MODIFIED, entity, auth_user)
|
||||
|
||||
def save_entity_deletion(entity, auth_user):
|
||||
assert entity
|
||||
_save(db.Snapshot.OPERATION_DELETED, entity, auth_user)
|
||||
|
|
|
@ -37,6 +37,7 @@ def create_category(name, color):
|
|||
return category
|
||||
|
||||
def update_category_name(category, name):
|
||||
assert category
|
||||
if not name:
|
||||
raise InvalidTagCategoryNameError('Name cannot be empty.')
|
||||
expr = sqlalchemy.func.lower(db.TagCategory.name) == name.lower()
|
||||
|
@ -52,6 +53,7 @@ def update_category_name(category, name):
|
|||
category.name = name
|
||||
|
||||
def update_category_color(category, color):
|
||||
assert category
|
||||
if not color:
|
||||
raise InvalidTagCategoryColorError('Color cannot be empty.')
|
||||
if not re.match(r'^#?[0-9a-z]+$', color):
|
||||
|
@ -103,6 +105,7 @@ def get_default_category():
|
|||
return category
|
||||
|
||||
def set_default_category(category):
|
||||
assert category
|
||||
old_category = try_get_default_category()
|
||||
if old_category:
|
||||
old_category.default = False
|
||||
|
|
|
@ -20,6 +20,7 @@ def _verify_name_validity(name):
|
|||
raise InvalidTagNameError('Name must satisfy regex %r.' % name_regex)
|
||||
|
||||
def _get_plain_names(tag):
|
||||
assert tag
|
||||
return [tag_name.name for tag_name in tag.names]
|
||||
|
||||
def _lower_list(names):
|
||||
|
@ -157,6 +158,7 @@ def get_or_create_tags_by_names(names):
|
|||
return existing_tags, new_tags
|
||||
|
||||
def get_tag_siblings(tag):
|
||||
assert tag
|
||||
tag_alias = sqlalchemy.orm.aliased(db.Tag)
|
||||
pt_alias1 = sqlalchemy.orm.aliased(db.PostTag)
|
||||
pt_alias2 = sqlalchemy.orm.aliased(db.PostTag)
|
||||
|
@ -172,6 +174,7 @@ def get_tag_siblings(tag):
|
|||
return result
|
||||
|
||||
def delete(source_tag):
|
||||
assert source_tag
|
||||
db.session.execute(
|
||||
sqlalchemy.sql.expression.delete(db.TagSuggestion) \
|
||||
.where(db.TagSuggestion.child_id == source_tag.tag_id))
|
||||
|
@ -181,6 +184,8 @@ def delete(source_tag):
|
|||
db.session.delete(source_tag)
|
||||
|
||||
def merge_tags(source_tag, target_tag):
|
||||
assert source_tag
|
||||
assert target_tag
|
||||
if source_tag.tag_id == target_tag.tag_id:
|
||||
raise InvalidTagRelationError('Cannot merge tag with itself.')
|
||||
pt1 = db.PostTag
|
||||
|
@ -205,9 +210,11 @@ def create_tag(names, category_name, suggestions, implications):
|
|||
return tag
|
||||
|
||||
def update_tag_category_name(tag, category_name):
|
||||
assert tag
|
||||
tag.category = tag_categories.get_category_by_name(category_name)
|
||||
|
||||
def update_tag_names(tag, names):
|
||||
assert tag
|
||||
names = util.icase_unique([name for name in names if name])
|
||||
if not len(names):
|
||||
raise InvalidTagNameError('At least one name must be specified.')
|
||||
|
@ -232,16 +239,19 @@ def update_tag_names(tag, names):
|
|||
tag.names.append(db.TagName(name))
|
||||
|
||||
def update_tag_implications(tag, relations):
|
||||
assert tag
|
||||
if _check_name_intersection(_get_plain_names(tag), relations):
|
||||
raise InvalidTagRelationError('Tag cannot imply itself.')
|
||||
tag.implications = get_tags_by_names(relations)
|
||||
|
||||
def update_tag_suggestions(tag, relations):
|
||||
assert tag
|
||||
if _check_name_intersection(_get_plain_names(tag), relations):
|
||||
raise InvalidTagRelationError('Tag cannot suggest itself.')
|
||||
tag.suggestions = get_tags_by_names(relations)
|
||||
|
||||
def update_tag_description(tag, description):
|
||||
assert tag
|
||||
if util.value_exceeds_column_size(description, db.Tag.description):
|
||||
raise InvalidTagDescriptionError('Description is too long.')
|
||||
tag.description = description
|
||||
|
|
|
@ -16,15 +16,20 @@ def _get_avatar_path(name):
|
|||
return 'avatars/' + name.lower() + '.png'
|
||||
|
||||
def _get_avatar_url(user):
|
||||
assert user
|
||||
if user.avatar_style == user.AVATAR_GRAVATAR:
|
||||
assert user.email or user.name
|
||||
return 'https://gravatar.com/avatar/%s?d=retro&s=%d' % (
|
||||
util.get_md5((user.email or user.name).lower()),
|
||||
config.config['thumbnails']['avatar_width'])
|
||||
else:
|
||||
assert user.name
|
||||
return '%s/avatars/%s.png' % (
|
||||
config.config['data_url'].rstrip('/'), user.name.lower())
|
||||
|
||||
def _get_email(user, authenticated_user, force_show_email):
|
||||
assert user
|
||||
assert authenticated_user
|
||||
if not force_show_email \
|
||||
and authenticated_user.user_id != user.user_id \
|
||||
and not auth.has_privilege(authenticated_user, 'users:edit:any:email'):
|
||||
|
@ -32,11 +37,15 @@ def _get_email(user, authenticated_user, force_show_email):
|
|||
return user.email
|
||||
|
||||
def _get_liked_post_count(user, authenticated_user):
|
||||
assert user
|
||||
assert authenticated_user
|
||||
if authenticated_user.user_id != user.user_id:
|
||||
return False
|
||||
return user.liked_post_count
|
||||
|
||||
def _get_disliked_post_count(user, authenticated_user):
|
||||
assert user
|
||||
assert authenticated_user
|
||||
if authenticated_user.user_id != user.user_id:
|
||||
return False
|
||||
return user.disliked_post_count
|
||||
|
@ -113,6 +122,7 @@ def create_user(name, password, email):
|
|||
return user
|
||||
|
||||
def update_user_name(user, name):
|
||||
assert user
|
||||
if not name:
|
||||
raise InvalidUserNameError('Name cannot be empty.')
|
||||
if util.value_exceeds_column_size(name, db.User.name):
|
||||
|
@ -130,6 +140,7 @@ def update_user_name(user, name):
|
|||
user.name = name
|
||||
|
||||
def update_user_password(user, password):
|
||||
assert user
|
||||
if not password:
|
||||
raise InvalidPasswordError('Password cannot be empty.')
|
||||
password_regex = config.config['password_regex']
|
||||
|
@ -140,6 +151,7 @@ def update_user_password(user, password):
|
|||
user.password_hash = auth.get_password_hash(user.password_salt, password)
|
||||
|
||||
def update_user_email(user, email):
|
||||
assert user
|
||||
if email:
|
||||
email = email.strip()
|
||||
if not email:
|
||||
|
@ -151,6 +163,7 @@ def update_user_email(user, email):
|
|||
user.email = email
|
||||
|
||||
def update_user_rank(user, rank, authenticated_user):
|
||||
assert user
|
||||
if not rank:
|
||||
raise InvalidRankError('Rank cannot be empty.')
|
||||
rank = util.flip(auth.RANK_MAP).get(rank.strip(), None)
|
||||
|
@ -166,6 +179,7 @@ def update_user_rank(user, rank, authenticated_user):
|
|||
user.rank = rank
|
||||
|
||||
def update_user_avatar(user, avatar_style, avatar_content):
|
||||
assert user
|
||||
if avatar_style == 'gravatar':
|
||||
user.avatar_style = user.AVATAR_GRAVATAR
|
||||
elif avatar_style == 'manual':
|
||||
|
@ -186,9 +200,11 @@ def update_user_avatar(user, avatar_style, avatar_content):
|
|||
avatar_style, ['gravatar', 'manual']))
|
||||
|
||||
def bump_user_login_time(user):
|
||||
assert user
|
||||
user.last_login_time = datetime.datetime.utcnow()
|
||||
|
||||
def reset_user_password(user):
|
||||
assert user
|
||||
password = auth.create_password()
|
||||
user.password_salt = auth.create_password()
|
||||
user.password_hash = auth.get_password_hash(user.password_salt, password)
|
||||
|
|
Loading…
Reference in New Issue