server: lint

This commit is contained in:
rr- 2017-04-24 23:30:53 +02:00
parent fea9a94945
commit 4bc58a3c95
42 changed files with 192 additions and 169 deletions

View File

@ -27,7 +27,7 @@ def _serialize(
@rest.routes.get('/comments/?') @rest.routes.get('/comments/?')
def get_comments( def get_comments(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'comments:list') auth.verify_privilege(ctx.user, 'comments:list')
return _search_executor.execute_and_serialize( return _search_executor.execute_and_serialize(
ctx, lambda comment: _serialize(ctx, comment)) ctx, lambda comment: _serialize(ctx, comment))
@ -35,7 +35,7 @@ def get_comments(
@rest.routes.post('/comments/?') @rest.routes.post('/comments/?')
def create_comment( def create_comment(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'comments:create') auth.verify_privilege(ctx.user, 'comments:create')
text = ctx.get_param_as_string('text') text = ctx.get_param_as_string('text')
post_id = ctx.get_param_as_int('postId') post_id = ctx.get_param_as_int('postId')

View File

@ -28,7 +28,7 @@ def _get_disk_usage() -> int:
@rest.routes.get('/info/?') @rest.routes.get('/info/?')
def get_info( def get_info(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
post_feature = posts.try_get_current_post_feature() post_feature = posts.try_get_current_post_feature()
return { return {
'postCount': posts.get_post_count(), 'postCount': posts.get_post_count(),

View File

@ -5,10 +5,10 @@ from hashlib import md5
MAIL_SUBJECT = 'Password reset for {name}' MAIL_SUBJECT = 'Password reset for {name}'
MAIL_BODY = \ MAIL_BODY = (
'You (or someone else) requested to reset your password on {name}.\n' \ 'You (or someone else) requested to reset your password on {name}.\n'
'If you wish to proceed, click this link: {url}\n' \ 'If you wish to proceed, click this link: {url}\n'
'Otherwise, please ignore this email.' 'Otherwise, please ignore this email.')
@rest.routes.get('/password-reset/(?P<user_name>[^/]+)/?') @rest.routes.get('/password-reset/(?P<user_name>[^/]+)/?')

View File

@ -31,7 +31,7 @@ def _serialize_post(
@rest.routes.get('/posts/?') @rest.routes.get('/posts/?')
def get_posts( def get_posts(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:list') auth.verify_privilege(ctx.user, 'posts:list')
_search_executor_config.user = ctx.user _search_executor_config.user = ctx.user
return _search_executor.execute_and_serialize( return _search_executor.execute_and_serialize(
@ -40,7 +40,7 @@ def get_posts(
@rest.routes.post('/posts/?') @rest.routes.post('/posts/?')
def create_post( def create_post(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
anonymous = ctx.get_param_as_bool('anonymous', default=False) anonymous = ctx.get_param_as_bool('anonymous', default=False)
if anonymous: if anonymous:
auth.verify_privilege(ctx.user, 'posts:create:anonymous') auth.verify_privilege(ctx.user, 'posts:create:anonymous')
@ -144,7 +144,7 @@ def delete_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
@rest.routes.post('/post-merge/?') @rest.routes.post('/post-merge/?')
def merge_posts( def merge_posts(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
source_post_id = ctx.get_param_as_int('remove') source_post_id = ctx.get_param_as_int('remove')
target_post_id = ctx.get_param_as_int('mergeTo') target_post_id = ctx.get_param_as_int('mergeTo')
source_post = posts.get_post_by_id(source_post_id) source_post = posts.get_post_by_id(source_post_id)
@ -162,14 +162,14 @@ def merge_posts(
@rest.routes.get('/featured-post/?') @rest.routes.get('/featured-post/?')
def get_featured_post( def get_featured_post(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
post = posts.try_get_featured_post() post = posts.try_get_featured_post()
return _serialize_post(ctx, post) return _serialize_post(ctx, post)
@rest.routes.post('/featured-post/?') @rest.routes.post('/featured-post/?')
def set_featured_post( def set_featured_post(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:feature') auth.verify_privilege(ctx.user, 'posts:feature')
post_id = ctx.get_param_as_int('id') post_id = ctx.get_param_as_int('id')
post = posts.get_post_by_id(post_id) post = posts.get_post_by_id(post_id)
@ -235,7 +235,7 @@ def get_posts_around(
@rest.routes.post('/posts/reverse-search/?') @rest.routes.post('/posts/reverse-search/?')
def get_posts_by_image( def get_posts_by_image(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'posts:reverse_search') auth.verify_privilege(ctx.user, 'posts:reverse_search')
content = ctx.get_file('content') content = ctx.get_file('content')

View File

@ -8,7 +8,7 @@ _search_executor = search.Executor(search.configs.SnapshotSearchConfig())
@rest.routes.get('/snapshots/?') @rest.routes.get('/snapshots/?')
def get_snapshots( def get_snapshots(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'snapshots:list') auth.verify_privilege(ctx.user, 'snapshots:list')
return _search_executor.execute_and_serialize( return _search_executor.execute_and_serialize(
ctx, lambda snapshot: snapshots.serialize_snapshot(snapshot, ctx.user)) ctx, lambda snapshot: snapshots.serialize_snapshot(snapshot, ctx.user))

View File

@ -28,14 +28,15 @@ def _create_if_needed(tag_names: List[str], user: model.User) -> None:
@rest.routes.get('/tags/?') @rest.routes.get('/tags/?')
def get_tags(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: def get_tags(ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'tags:list') auth.verify_privilege(ctx.user, 'tags:list')
return _search_executor.execute_and_serialize( return _search_executor.execute_and_serialize(
ctx, lambda tag: _serialize(ctx, tag)) ctx, lambda tag: _serialize(ctx, tag))
@rest.routes.post('/tags/?') @rest.routes.post('/tags/?')
def create_tag(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: def create_tag(
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'tags:create') auth.verify_privilege(ctx.user, 'tags:create')
names = ctx.get_param_as_string_list('names') names = ctx.get_param_as_string_list('names')
@ -112,7 +113,7 @@ def delete_tag(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
@rest.routes.post('/tag-merge/?') @rest.routes.post('/tag-merge/?')
def merge_tags( def merge_tags(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
source_tag_name = ctx.get_param_as_string('remove') source_tag_name = ctx.get_param_as_string('remove')
target_tag_name = ctx.get_param_as_string('mergeTo') target_tag_name = ctx.get_param_as_string('mergeTo')
source_tag = tags.get_tag_by_name(source_tag_name) source_tag = tags.get_tag_by_name(source_tag_name)

View File

@ -12,7 +12,7 @@ def _serialize(
@rest.routes.get('/tag-categories/?') @rest.routes.get('/tag-categories/?')
def get_tag_categories( def get_tag_categories(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'tag_categories:list') auth.verify_privilege(ctx.user, 'tag_categories:list')
categories = tag_categories.get_all_categories() categories = tag_categories.get_all_categories()
return { return {
@ -22,7 +22,7 @@ def get_tag_categories(
@rest.routes.post('/tag-categories/?') @rest.routes.post('/tag-categories/?')
def create_tag_category( def create_tag_category(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'tag_categories:create') auth.verify_privilege(ctx.user, 'tag_categories:create')
name = ctx.get_param_as_string('name') name = ctx.get_param_as_string('name')
color = ctx.get_param_as_string('color') color = ctx.get_param_as_string('color')

View File

@ -5,7 +5,7 @@ from szurubooru.func import auth, file_uploads
@rest.routes.post('/uploads/?') @rest.routes.post('/uploads/?')
def create_temporary_file( def create_temporary_file(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'uploads:create') auth.verify_privilege(ctx.user, 'uploads:create')
content = ctx.get_file('content', allow_tokens=False) content = ctx.get_file('content', allow_tokens=False)
token = file_uploads.save(content) token = file_uploads.save(content)

View File

@ -16,7 +16,8 @@ def _serialize(
@rest.routes.get('/users/?') @rest.routes.get('/users/?')
def get_users(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: def get_users(
ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'users:list') auth.verify_privilege(ctx.user, 'users:list')
return _search_executor.execute_and_serialize( return _search_executor.execute_and_serialize(
ctx, lambda user: _serialize(ctx, user)) ctx, lambda user: _serialize(ctx, user))
@ -24,7 +25,7 @@ def get_users(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
@rest.routes.post('/users/?') @rest.routes.post('/users/?')
def create_user( def create_user(
ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
auth.verify_privilege(ctx.user, 'users:create') auth.verify_privilege(ctx.user, 'users:create')
name = ctx.get_param_as_string('name') name = ctx.get_param_as_string('name')
password = ctx.get_param_as_string('password') password = ctx.get_param_as_string('password')

View File

@ -4,8 +4,8 @@ from typing import Dict
class BaseError(RuntimeError): class BaseError(RuntimeError):
def __init__( def __init__(
self, self,
message: str='Unknown error', message: str = 'Unknown error',
extra_fields: Dict[str, str]=None) -> None: extra_fields: Dict[str, str] = None) -> None:
super().__init__(message) super().__init__(message)
self.extra_fields = extra_fields self.extra_fields = extra_fields

View File

@ -21,9 +21,9 @@ class LruCache:
i i
for i, v in enumerate(self.item_list) for i, v in enumerate(self.item_list)
if v.key == item.key) if v.key == item.key)
self.item_list[:] \ self.item_list[:] = (
= self.item_list[:item_index] \ self.item_list[:item_index] +
+ self.item_list[item_index + 1:] self.item_list[item_index + 1:])
self.item_list.insert(0, item) self.item_list.insert(0, item)
else: else:
if len(self.item_list) > self.length: if len(self.item_list) > self.length:

View File

@ -1,7 +1,7 @@
from datetime import datetime from datetime import datetime
from typing import Any, Optional, List, Dict, Callable from typing import Any, Optional, List, Dict, Callable
from szurubooru import db, model, errors, rest from szurubooru import db, model, errors, rest
from szurubooru.func import users, scores, util, serialization from szurubooru.func import users, scores, serialization
class InvalidCommentIdError(errors.ValidationError): class InvalidCommentIdError(errors.ValidationError):
@ -65,7 +65,7 @@ class CommentSerializer(serialization.BaseSerializer):
def serialize_comment( def serialize_comment(
comment: model.Comment, comment: model.Comment,
auth_user: model.User, auth_user: model.User,
options: List[str]=[]) -> rest.Response: options: List[str] = []) -> rest.Response:
if comment is None: if comment is None:
return None return None
return CommentSerializer(comment, auth_user).serialize(options) return CommentSerializer(comment, auth_user).serialize(options)
@ -73,10 +73,11 @@ def serialize_comment(
def try_get_comment_by_id(comment_id: int) -> Optional[model.Comment]: def try_get_comment_by_id(comment_id: int) -> Optional[model.Comment]:
comment_id = int(comment_id) comment_id = int(comment_id)
return db.session \ return (
.query(model.Comment) \ db.session
.filter(model.Comment.comment_id == comment_id) \ .query(model.Comment)
.one_or_none() .filter(model.Comment.comment_id == comment_id)
.one_or_none())
def get_comment_by_id(comment_id: int) -> model.Comment: def get_comment_by_id(comment_id: int) -> model.Comment:

View File

@ -99,7 +99,7 @@ def _normalize_and_threshold(
def _compute_grid_points( def _compute_grid_points(
image: NpMatrix, image: NpMatrix,
n: float, n: float,
window: Window=None) -> Tuple[NpMatrix, NpMatrix]: window: Window = None) -> Tuple[NpMatrix, NpMatrix]:
if window is None: if window is None:
window = ((0, image.shape[0]), (0, image.shape[1])) window = ((0, image.shape[0]), (0, image.shape[1]))
x_coords = np.linspace(window[0][0], window[0][1], n + 2, dtype=int)[1:-1] x_coords = np.linspace(window[0][0], window[0][1], n + 2, dtype=int)[1:-1]
@ -219,7 +219,7 @@ def _max_contrast(array: NpMatrix) -> None:
def _normalized_distance( def _normalized_distance(
target_array: NpMatrix, target_array: NpMatrix,
vec: NpMatrix, vec: NpMatrix,
nan_value: float=1.0) -> List[float]: nan_value: float = 1.0) -> List[float]:
target_array = target_array.astype(int) target_array = target_array.astype(int)
vec = vec.astype(int) vec = vec.astype(int)
topvec = np.linalg.norm(vec - target_array, axis=1) topvec = np.linalg.norm(vec - target_array, axis=1)

View File

@ -11,8 +11,8 @@ from szurubooru.func import mime, util
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_SCALE_FIT_FMT = \ _SCALE_FIT_FMT = (
r'scale=iw*max({width}/iw\,{height}/ih):ih*max({width}/iw\,{height}/ih)' r'scale=iw*max({width}/iw\,{height}/ih):ih*max({width}/iw\,{height}/ih)')
class Image: class Image:
@ -77,7 +77,7 @@ class Image:
'-', '-',
]) ])
def _execute(self, cli: List[str], program: str='ffmpeg') -> bytes: def _execute(self, cli: List[str], program: str = 'ffmpeg') -> bytes:
extension = mime.get_extension(mime.get_mime_type(self.content)) extension = mime.get_extension(mime.get_mime_type(self.content))
assert extension assert extension
with util.create_temp_file(suffix='.' + extension) as handle: with util.create_temp_file(suffix='.' + extension) as handle:

View File

@ -7,10 +7,10 @@ from szurubooru.func import (
mime, images, files, image_hash, serialization) mime, images, files, image_hash, serialization)
EMPTY_PIXEL = \ EMPTY_PIXEL = (
b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00' \ b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00'
b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \ b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00'
b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b' b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b')
class PostNotFoundError(errors.NotFoundError): class PostNotFoundError(errors.NotFoundError):
@ -283,7 +283,7 @@ class PostSerializer(serialization.BaseSerializer):
def serialize_post( def serialize_post(
post: Optional[model.Post], post: Optional[model.Post],
auth_user: model.User, auth_user: model.User,
options: List[str]=[]) -> Optional[rest.Response]: options: List[str] = []) -> Optional[rest.Response]:
if not post: if not post:
return None return None
return PostSerializer(post, auth_user).serialize(options) return PostSerializer(post, auth_user).serialize(options)
@ -300,10 +300,11 @@ def get_post_count() -> int:
def try_get_post_by_id(post_id: int) -> Optional[model.Post]: def try_get_post_by_id(post_id: int) -> Optional[model.Post]:
return db.session \ return (
.query(model.Post) \ db.session
.filter(model.Post.post_id == post_id) \ .query(model.Post)
.one_or_none() .filter(model.Post.post_id == post_id)
.one_or_none())
def get_post_by_id(post_id: int) -> model.Post: def get_post_by_id(post_id: int) -> model.Post:
@ -314,10 +315,11 @@ def get_post_by_id(post_id: int) -> model.Post:
def try_get_current_post_feature() -> Optional[model.PostFeature]: def try_get_current_post_feature() -> Optional[model.PostFeature]:
return db.session \ return (
.query(model.PostFeature) \ db.session
.order_by(model.PostFeature.time.desc()) \ .query(model.PostFeature)
.first() .order_by(model.PostFeature.time.desc())
.first())
def try_get_featured_post() -> Optional[model.Post]: def try_get_featured_post() -> Optional[model.Post]:
@ -426,11 +428,12 @@ def update_post_content(post: model.Post, content: Optional[bytes]) -> None:
'Unhandled file type: %r' % post.mime_type) 'Unhandled file type: %r' % post.mime_type)
post.checksum = util.get_sha1(content) post.checksum = util.get_sha1(content)
other_post = db.session \ other_post = (
.query(model.Post) \ db.session
.filter(model.Post.checksum == post.checksum) \ .query(model.Post)
.filter(model.Post.post_id != post.post_id) \ .filter(model.Post.checksum == post.checksum)
.one_or_none() .filter(model.Post.post_id != post.post_id)
.one_or_none())
if other_post \ if other_post \
and other_post.post_id \ and other_post.post_id \
and other_post.post_id != post.post_id: and other_post.post_id != post.post_id:
@ -452,7 +455,7 @@ def update_post_content(post: model.Post, content: Optional[bytes]) -> None:
def update_post_thumbnail( def update_post_thumbnail(
post: model.Post, content: Optional[bytes]=None) -> None: post: model.Post, content: Optional[bytes] = None) -> None:
assert post assert post
setattr(post, '__thumbnail', content) setattr(post, '__thumbnail', content)
@ -492,10 +495,11 @@ def update_post_relations(post: model.Post, new_post_ids: List[int]) -> None:
old_posts = post.relations old_posts = post.relations
old_post_ids = [int(p.post_id) for p in old_posts] old_post_ids = [int(p.post_id) for p in old_posts]
if new_post_ids: if new_post_ids:
new_posts = db.session \ new_posts = (
.query(model.Post) \ db.session
.filter(model.Post.post_id.in_(new_post_ids)) \ .query(model.Post)
.all() .filter(model.Post.post_id.in_(new_post_ids))
.all())
else: else:
new_posts = [] new_posts = []
if len(new_posts) != len(new_post_ids): if len(new_posts) != len(new_post_ids):
@ -673,10 +677,11 @@ def merge_posts(
def search_by_image_exact(image_content: bytes) -> Optional[model.Post]: def search_by_image_exact(image_content: bytes) -> Optional[model.Post]:
checksum = util.get_sha1(image_content) checksum = util.get_sha1(image_content)
return db.session \ return (
.query(model.Post) \ db.session
.filter(model.Post.checksum == checksum) \ .query(model.Post)
.one_or_none() .filter(model.Post.checksum == checksum)
.one_or_none())
def search_by_image(image_content: bytes) -> List[PostLookalike]: def search_by_image(image_content: bytes) -> List[PostLookalike]:

View File

@ -39,11 +39,12 @@ def get_score(entity: model.Base, user: model.User) -> int:
assert entity assert entity
assert user assert user
table, get_column = _get_table_info(entity) table, get_column = _get_table_info(entity)
row = db.session \ row = (
.query(table.score) \ db.session
.filter(get_column(table) == get_column(entity)) \ .query(table.score)
.filter(table.user_id == user.user_id) \ .filter(get_column(table) == get_column(entity))
.one_or_none() .filter(table.user_id == user.user_id)
.one_or_none())
return row[0] if row else 0 return row[0] if row else 0

View File

@ -1,5 +1,5 @@
from typing import Any, Optional, List, Dict, Callable from typing import Any, List, Dict, Callable
from szurubooru import db, model, rest, errors from szurubooru import model, rest, errors
def get_serialization_options(ctx: rest.Context) -> List[str]: def get_serialization_options(ctx: rest.Context) -> List[str]:

View File

@ -66,7 +66,7 @@ class TagCategorySerializer(serialization.BaseSerializer):
def serialize_category( def serialize_category(
category: Optional[model.TagCategory], category: Optional[model.TagCategory],
options: List[str]=[]) -> Optional[rest.Response]: options: List[str] = []) -> Optional[rest.Response]:
if not category: if not category:
return None return None
return TagCategorySerializer(category).serialize(options) return TagCategorySerializer(category).serialize(options)
@ -113,16 +113,17 @@ def update_category_color(category: model.TagCategory, color: str) -> None:
def try_get_category_by_name( def try_get_category_by_name(
name: str, lock: bool=False) -> Optional[model.TagCategory]: name: str, lock: bool = False) -> Optional[model.TagCategory]:
query = db.session \ query = (
.query(model.TagCategory) \ db.session
.filter(sa.func.lower(model.TagCategory.name) == name.lower()) .query(model.TagCategory)
.filter(sa.func.lower(model.TagCategory.name) == name.lower()))
if lock: if lock:
query = query.with_lockmode('update') query = query.with_lockmode('update')
return query.one_or_none() return query.one_or_none()
def get_category_by_name(name: str, lock: bool=False) -> model.TagCategory: def get_category_by_name(name: str, lock: bool = False) -> model.TagCategory:
category = try_get_category_by_name(name, lock) category = try_get_category_by_name(name, lock)
if not category: if not category:
raise TagCategoryNotFoundError('Tag category %r not found.' % name) raise TagCategoryNotFoundError('Tag category %r not found.' % name)
@ -137,26 +138,29 @@ def get_all_categories() -> List[model.TagCategory]:
return db.session.query(model.TagCategory).all() return db.session.query(model.TagCategory).all()
def try_get_default_category(lock: bool=False) -> Optional[model.TagCategory]: def try_get_default_category(
query = db.session \ lock: bool = False) -> Optional[model.TagCategory]:
.query(model.TagCategory) \ query = (
.filter(model.TagCategory.default) db.session
.query(model.TagCategory)
.filter(model.TagCategory.default))
if lock: if lock:
query = query.with_lockmode('update') query = query.with_lockmode('update')
category = query.first() category = query.first()
# if for some reason (e.g. as a result of migration) there's no default # if for some reason (e.g. as a result of migration) there's no default
# category, get the first record available. # category, get the first record available.
if not category: if not category:
query = db.session \ query = (
.query(model.TagCategory) \ db.session
.order_by(model.TagCategory.tag_category_id.asc()) .query(model.TagCategory)
.order_by(model.TagCategory.tag_category_id.asc()))
if lock: if lock:
query = query.with_lockmode('update') query = query.with_lockmode('update')
category = query.first() category = query.first()
return category return category
def get_default_category(lock: bool=False) -> model.TagCategory: def get_default_category(lock: bool = False) -> model.TagCategory:
category = try_get_default_category(lock) category = try_get_default_category(lock)
if not category: if not category:
raise TagCategoryNotFoundError('No tag category created yet.') raise TagCategoryNotFoundError('No tag category created yet.')

View File

@ -122,7 +122,7 @@ class TagSerializer(serialization.BaseSerializer):
def serialize_tag( def serialize_tag(
tag: model.Tag, options: List[str]=[]) -> Optional[rest.Response]: tag: model.Tag, options: List[str] = []) -> Optional[rest.Response]:
if not tag: if not tag:
return None return None
return TagSerializer(tag).serialize(options) return TagSerializer(tag).serialize(options)
@ -209,7 +209,8 @@ def get_tags_by_names(names: List[str]) -> List[model.Tag]:
names = util.icase_unique(names) names = util.icase_unique(names)
if len(names) == 0: if len(names) == 0:
return [] return []
return (db.session.query(model.Tag) return (
db.session.query(model.Tag)
.join(model.TagName) .join(model.TagName)
.filter( .filter(
sa.sql.or_( sa.sql.or_(

View File

@ -86,7 +86,7 @@ class UserSerializer(serialization.BaseSerializer):
self, self,
user: model.User, user: model.User,
auth_user: model.User, auth_user: model.User,
force_show_email: bool=False) -> None: force_show_email: bool = False) -> None:
self.user = user self.user = user
self.auth_user = auth_user self.auth_user = auth_user
self.force_show_email = force_show_email self.force_show_email = force_show_email
@ -151,8 +151,8 @@ class UserSerializer(serialization.BaseSerializer):
def serialize_user( def serialize_user(
user: Optional[model.User], user: Optional[model.User],
auth_user: model.User, auth_user: model.User,
options: List[str]=[], options: List[str] = [],
force_show_email: bool=False) -> Optional[rest.Response]: force_show_email: bool = False) -> Optional[rest.Response]:
if not user: if not user:
return None return None
return UserSerializer(user, auth_user, force_show_email).serialize(options) return UserSerializer(user, auth_user, force_show_email).serialize(options)
@ -170,10 +170,11 @@ def get_user_count() -> int:
def try_get_user_by_name(name: str) -> Optional[model.User]: def try_get_user_by_name(name: str) -> Optional[model.User]:
return db.session \ return (
.query(model.User) \ db.session
.filter(sa.func.lower(model.User.name) == sa.func.lower(name)) \ .query(model.User)
.one_or_none() .filter(sa.func.lower(model.User.name) == sa.func.lower(name))
.one_or_none())
def get_user_by_name(name: str) -> model.User: def get_user_by_name(name: str) -> model.User:
@ -276,7 +277,7 @@ def update_user_rank(
def update_user_avatar( def update_user_avatar(
user: model.User, user: model.User,
avatar_style: str, avatar_style: str,
avatar_content: Optional[bytes]=None) -> None: avatar_content: Optional[bytes] = None) -> None:
assert user assert user
if avatar_style == 'gravatar': if avatar_style == 'gravatar':
user.avatar_style = user.AVATAR_GRAVATAR user.avatar_style = user.AVATAR_GRAVATAR

View File

@ -2,8 +2,7 @@ import os
import hashlib import hashlib
import re import re
import tempfile import tempfile
from typing import ( from typing import Any, Optional, Union, Tuple, List, Dict, Generator, TypeVar
Any, Optional, Union, Tuple, List, Dict, Generator, Callable, TypeVar)
from datetime import datetime, timedelta from datetime import datetime, timedelta
from contextlib import contextmanager from contextlib import contextmanager
from szurubooru import errors from szurubooru import errors

View File

@ -4,7 +4,7 @@ from szurubooru import errors, rest, model
def verify_version( def verify_version(
entity: model.Base, entity: model.Base,
context: rest.Context, context: rest.Context,
field_name: str='version') -> None: field_name: str = 'version') -> None:
actual_version = context.get_param_as_int(field_name) actual_version = context.get_param_as_int(field_name)
expected_version = entity.version expected_version = entity.version
if actual_version != expected_version: if actual_version != expected_version:

View File

@ -27,8 +27,9 @@ def _get_user(ctx: rest.Context) -> Optional[model.User]:
credentials.encode('ascii')).decode('utf8').split(':') credentials.encode('ascii')).decode('utf8').split(':')
return _authenticate(username, password) return _authenticate(username, password)
except ValueError as err: except ValueError as err:
msg = 'Basic authentication header value are not properly formed. ' \ msg = (
+ 'Supplied header {0}. Got error: {1}' 'Basic authentication header value are not properly formed. '
'Supplied header {0}. Got error: {1}')
raise HttpBadRequest( raise HttpBadRequest(
'ValidationError', 'ValidationError',
msg.format(ctx.get_header('Authorization'), str(err))) msg.format(ctx.get_header('Authorization'), str(err)))

View File

@ -50,13 +50,14 @@ def upgrade():
def downgrade(): def downgrade():
session = sa.orm.session.Session(bind=op.get_bind()) session = sa.orm.session.Session(bind=op.get_bind())
default_category = session \ default_category = (
.query(TagCategory) \ session
.filter(TagCategory.name == 'default') \ .query(TagCategory)
.filter(TagCategory.color == 'default') \ .filter(TagCategory.name == 'default')
.filter(TagCategory.version == 1) \ .filter(TagCategory.color == 'default')
.filter(TagCategory.default == True) \ .filter(TagCategory.version == 1)
.one_or_none() .filter(TagCategory.default == 1)
.one_or_none())
if default_category: if default_category:
session.delete(default_category) session.delete(default_category)
session.commit() session.commit()

View File

@ -211,10 +211,11 @@ class Post(Base):
@property @property
def is_featured(self) -> bool: def is_featured(self) -> bool:
featured_post = sa.orm.object_session(self) \ featured_post = (
.query(PostFeature) \ sa.orm.object_session(self)
.order_by(PostFeature.time.desc()) \ .query(PostFeature)
.first() .order_by(PostFeature.time.desc())
.first())
return featured_post and featured_post.post_id == self.post_id return featured_post and featured_post.post_id == self.post_id
score = sa.orm.column_property( score = sa.orm.column_property(

View File

@ -14,7 +14,7 @@ class TagCategory(Base):
'color', sa.Unicode(32), nullable=False, default='#000000') 'color', sa.Unicode(32), nullable=False, default='#000000')
default = sa.Column('default', sa.Boolean, nullable=False, default=False) default = sa.Column('default', sa.Boolean, nullable=False, default=False)
def __init__(self, name: Optional[str]=None) -> None: def __init__(self, name: Optional[str] = None) -> None:
self.name = name self.name = name
tag_count = sa.orm.column_property( tag_count = sa.orm.column_property(

View File

@ -13,9 +13,9 @@ class Context:
self, self,
method: str, method: str,
url: str, url: str,
headers: Dict[str, str]=None, headers: Dict[str, str] = None,
params: Request=None, params: Request = None,
files: Dict[str, bytes]=None) -> None: files: Dict[str, bytes] = None) -> None:
self.method = method self.method = method
self.url = url self.url = url
self._headers = headers or {} self._headers = headers or {}
@ -34,7 +34,7 @@ class Context:
def get_header(self, name: str) -> str: def get_header(self, name: str) -> str:
return self._headers.get(name, '') return self._headers.get(name, '')
def has_file(self, name: str, allow_tokens: bool=True) -> bool: def has_file(self, name: str, allow_tokens: bool = True) -> bool:
return ( return (
name in self._files or name in self._files or
name + 'Url' in self._params or name + 'Url' in self._params or
@ -43,8 +43,8 @@ class Context:
def get_file( def get_file(
self, self,
name: str, name: str,
default: Union[object, bytes]=MISSING, default: Union[object, bytes] = MISSING,
allow_tokens: bool=True) -> bytes: allow_tokens: bool = True) -> bytes:
if name in self._files and self._files[name]: if name in self._files and self._files[name]:
return self._files[name] return self._files[name]
@ -70,7 +70,7 @@ class Context:
def get_param_as_list( def get_param_as_list(
self, self,
name: str, name: str,
default: Union[object, List[Any]]=MISSING) -> List[Any]: default: Union[object, List[Any]] = MISSING) -> List[Any]:
if name not in self._params: if name not in self._params:
if default is not MISSING: if default is not MISSING:
return cast(List[Any], default) return cast(List[Any], default)
@ -89,7 +89,7 @@ class Context:
def get_param_as_int_list( def get_param_as_int_list(
self, self,
name: str, name: str,
default: Union[object, List[int]]=MISSING) -> List[int]: default: Union[object, List[int]] = MISSING) -> List[int]:
ret = self.get_param_as_list(name, default) ret = self.get_param_as_list(name, default)
for item in ret: for item in ret:
if type(item) is not int: if type(item) is not int:
@ -100,7 +100,7 @@ class Context:
def get_param_as_string_list( def get_param_as_string_list(
self, self,
name: str, name: str,
default: Union[object, List[str]]=MISSING) -> List[str]: default: Union[object, List[str]] = MISSING) -> List[str]:
ret = self.get_param_as_list(name, default) ret = self.get_param_as_list(name, default)
for item in ret: for item in ret:
if type(item) is not str: if type(item) is not str:
@ -111,7 +111,7 @@ class Context:
def get_param_as_string( def get_param_as_string(
self, self,
name: str, name: str,
default: Union[object, str]=MISSING) -> str: default: Union[object, str] = MISSING) -> str:
if name not in self._params: if name not in self._params:
if default is not MISSING: if default is not MISSING:
return cast(str, default) return cast(str, default)
@ -135,9 +135,9 @@ class Context:
def get_param_as_int( def get_param_as_int(
self, self,
name: str, name: str,
default: Union[object, int]=MISSING, default: Union[object, int] = MISSING,
min: Optional[int]=None, min: Optional[int] = None,
max: Optional[int]=None) -> int: max: Optional[int] = None) -> int:
if name not in self._params: if name not in self._params:
if default is not MISSING: if default is not MISSING:
return cast(int, default) return cast(int, default)
@ -161,7 +161,7 @@ class Context:
def get_param_as_bool( def get_param_as_bool(
self, self,
name: str, name: str,
default: Union[object, bool]=MISSING) -> bool: default: Union[object, bool] = MISSING) -> bool:
if name not in self._params: if name not in self._params:
if default is not MISSING: if default is not MISSING:
return cast(bool, default) return cast(bool, default)

View File

@ -1,4 +1,4 @@
from typing import Callable, Type, Dict from typing import Optional, Callable, Type, Dict
error_handlers = {} # pylint: disable=invalid-name error_handlers = {} # pylint: disable=invalid-name
@ -12,8 +12,8 @@ class BaseHttpError(RuntimeError):
self, self,
name: str, name: str,
description: str, description: str,
title: str=None, title: Optional[str] = None,
extra_fields: Dict[str, str]=None) -> None: extra_fields: Optional[Dict[str, str]] = None) -> None:
super().__init__() super().__init__()
# error name for programmers # error name for programmers
self.name = name self.name = name

View File

@ -1,4 +1,4 @@
from typing import Callable from typing import List, Callable
from szurubooru.rest.context import Context from szurubooru.rest.context import Context

View File

@ -1,4 +1,4 @@
from typing import Callable, Dict, Any from typing import Callable, Dict
from collections import defaultdict from collections import defaultdict
from szurubooru.rest.context import Context, Response from szurubooru.rest.context import Context, Response

View File

@ -52,10 +52,11 @@ def _create_score_filter(score: int) -> Filter:
user_alias.name, criterion) user_alias.name, criterion)
if negated: if negated:
expr = ~expr expr = ~expr
ret = query \ ret = (
.join(score_alias, score_alias.post_id == model.Post.post_id) \ query
.join(user_alias, user_alias.user_id == score_alias.user_id) \ .join(score_alias, score_alias.post_id == model.Post.post_id)
.filter(expr) .join(user_alias, user_alias.user_id == score_alias.user_id)
.filter(expr))
return ret return ret
return wrapper return wrapper
@ -124,7 +125,8 @@ class PostSearchConfig(BaseSearchConfig):
sa.orm.lazyload sa.orm.lazyload
if disable_eager_loads if disable_eager_loads
else sa.orm.subqueryload) else sa.orm.subqueryload)
return db.session.query(model.Post) \ return (
db.session.query(model.Post)
.options( .options(
sa.orm.lazyload('*'), sa.orm.lazyload('*'),
# use config optimized for official client # use config optimized for official client
@ -141,7 +143,7 @@ class PostSearchConfig(BaseSearchConfig):
strategy(model.Post.tags).subqueryload(model.Tag.names), strategy(model.Post.tags).subqueryload(model.Tag.names),
strategy(model.Post.tags).defer(model.Tag.post_count), strategy(model.Post.tags).defer(model.Tag.post_count),
strategy(model.Post.tags).lazyload(model.Tag.implications), strategy(model.Post.tags).lazyload(model.Tag.implications),
strategy(model.Post.tags).lazyload(model.Tag.suggestions)) strategy(model.Post.tags).lazyload(model.Tag.suggestions)))
def create_count_query(self, _disable_eager_loads: bool) -> SaQuery: def create_count_query(self, _disable_eager_loads: bool) -> SaQuery:
return db.session.query(model.Post) return db.session.query(model.Post)

View File

@ -14,8 +14,9 @@ class TagSearchConfig(BaseSearchConfig):
sa.orm.lazyload sa.orm.lazyload
if _disable_eager_loads if _disable_eager_loads
else sa.orm.subqueryload) else sa.orm.subqueryload)
return db.session.query(model.Tag) \ return (
.join(model.TagCategory) \ db.session.query(model.Tag)
.join(model.TagCategory)
.options( .options(
sa.orm.defer(model.Tag.first_name), sa.orm.defer(model.Tag.first_name),
sa.orm.defer(model.Tag.suggestion_count), sa.orm.defer(model.Tag.suggestion_count),
@ -23,7 +24,7 @@ class TagSearchConfig(BaseSearchConfig):
sa.orm.defer(model.Tag.post_count), sa.orm.defer(model.Tag.post_count),
strategy(model.Tag.names), strategy(model.Tag.names),
strategy(model.Tag.suggestions).joinedload(model.Tag.names), strategy(model.Tag.suggestions).joinedload(model.Tag.names),
strategy(model.Tag.implications).joinedload(model.Tag.names)) strategy(model.Tag.implications).joinedload(model.Tag.names)))
def create_count_query(self, _disable_eager_loads: bool) -> SaQuery: def create_count_query(self, _disable_eager_loads: bool) -> SaQuery:
return db.session.query(model.Tag) return db.session.query(model.Tag)

View File

@ -69,7 +69,7 @@ def float_transformer(value: str) -> float:
def apply_num_criterion_to_column( def apply_num_criterion_to_column(
column: Any, column: Any,
criterion: criteria.BaseCriterion, criterion: criteria.BaseCriterion,
transformer: Callable[[str], Number]=integer_transformer) -> SaQuery: transformer: Callable[[str], Number] = integer_transformer) -> SaQuery:
try: try:
if isinstance(criterion, criteria.PlainCriterion): if isinstance(criterion, criteria.PlainCriterion):
expr = column == transformer(criterion.value) expr = column == transformer(criterion.value)
@ -95,7 +95,7 @@ def apply_num_criterion_to_column(
def create_num_filter( def create_num_filter(
column: Any, column: Any,
transformer: Callable[[str], Number]=integer_transformer) -> SaQuery: transformer: Callable[[str], Number] = integer_transformer) -> SaQuery:
def wrapper( def wrapper(
query: SaQuery, query: SaQuery,
criterion: Optional[criteria.BaseCriterion], criterion: Optional[criteria.BaseCriterion],
@ -111,7 +111,7 @@ def create_num_filter(
def apply_str_criterion_to_column( def apply_str_criterion_to_column(
column: SaColumn, column: SaColumn,
criterion: criteria.BaseCriterion, criterion: criteria.BaseCriterion,
transformer: Callable[[str], str]=wildcard_transformer) -> SaQuery: transformer: Callable[[str], str] = wildcard_transformer) -> SaQuery:
if isinstance(criterion, criteria.PlainCriterion): if isinstance(criterion, criteria.PlainCriterion):
expr = column.ilike(transformer(criterion.value)) expr = column.ilike(transformer(criterion.value))
elif isinstance(criterion, criteria.ArrayCriterion): elif isinstance(criterion, criteria.ArrayCriterion):
@ -128,8 +128,8 @@ def apply_str_criterion_to_column(
def create_str_filter( def create_str_filter(
column: SaColumn, transformer: Callable[[str], str]=wildcard_transformer column: SaColumn,
) -> Filter: transformer: Callable[[str], str] = wildcard_transformer) -> Filter:
def wrapper( def wrapper(
query: SaQuery, query: SaQuery,
criterion: Optional[criteria.BaseCriterion], criterion: Optional[criteria.BaseCriterion],
@ -187,7 +187,7 @@ def create_subquery_filter(
right_id_column: SaColumn, right_id_column: SaColumn,
filter_column: SaColumn, filter_column: SaColumn,
filter_factory: SaColumn, filter_factory: SaColumn,
subquery_decorator: Callable[[SaQuery], None]=None) -> Filter: subquery_decorator: Callable[[SaQuery], None] = None) -> Filter:
filter_func = filter_factory(filter_column) filter_func = filter_factory(filter_column)
def wrapper( def wrapper(

View File

@ -1,4 +1,4 @@
from typing import Optional, List, Callable from typing import Optional, List
from szurubooru.search.typing import SaQuery from szurubooru.search.typing import SaQuery

View File

@ -100,18 +100,20 @@ class Executor:
filter_query = self.config.create_filter_query(disable_eager_loads) filter_query = self.config.create_filter_query(disable_eager_loads)
filter_query = filter_query.options(sa.orm.lazyload('*')) filter_query = filter_query.options(sa.orm.lazyload('*'))
filter_query = self._prepare_db_query(filter_query, search_query, True) filter_query = self._prepare_db_query(filter_query, search_query, True)
entities = filter_query \ entities = (
.offset(offset) \ filter_query
.limit(limit) \ .offset(offset)
.all() .limit(limit)
.all())
count_query = self.config.create_count_query(disable_eager_loads) count_query = self.config.create_count_query(disable_eager_loads)
count_query = count_query.options(sa.orm.lazyload('*')) count_query = count_query.options(sa.orm.lazyload('*'))
count_query = self._prepare_db_query(count_query, search_query, False) count_query = self._prepare_db_query(count_query, search_query, False)
count_statement = count_query \ count_statement = (
.statement \ count_query
.with_only_columns([sa.func.count()]) \ .statement
.order_by(None) .with_only_columns([sa.func.count()])
.order_by(None))
count = db.session.execute(count_statement).scalar() count = db.session.execute(count_statement).scalar()
ret = (count, entities) ret = (count, entities)

View File

@ -1,5 +1,4 @@
import re import re
from typing import Match, List
from szurubooru import errors from szurubooru import errors
from szurubooru.search import criteria, tokens from szurubooru.search import criteria, tokens
from szurubooru.search.query import SearchQuery from szurubooru.search.query import SearchQuery

View File

@ -1,4 +1,5 @@
from szurubooru.search import tokens from szurubooru.search import tokens
from typing import List
class SearchQuery: class SearchQuery:

View File

@ -1,6 +1,6 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from szurubooru import api, db, model, errors from szurubooru import api, model, errors
from szurubooru.func import tags, snapshots from szurubooru.func import tags, snapshots

View File

@ -1,6 +1,6 @@
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
from szurubooru import api, db, model, errors from szurubooru import api, model, errors
from szurubooru.func import users from szurubooru.func import users

View File

@ -5,10 +5,10 @@ from szurubooru import db, model, errors
from szurubooru.func import auth, users, files, util from szurubooru.func import auth, users, files, util
EMPTY_PIXEL = \ EMPTY_PIXEL = (
b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00' \ b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00'
b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \ b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00'
b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b' b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b')
@pytest.mark.parametrize('user_name', ['test', 'TEST']) @pytest.mark.parametrize('user_name', ['test', 'TEST'])

View File

@ -28,11 +28,12 @@ def test_saving_tag(tag_factory):
tag.implications.append(imp2) tag.implications.append(imp2)
db.session.commit() db.session.commit()
tag = db.session \ tag = (
.query(model.Tag) \ db.session
.join(model.TagName) \ .query(model.Tag)
.filter(model.TagName.name == 'alias1') \ .join(model.TagName)
.one() .filter(model.TagName.name == 'alias1')
.one())
assert [tag_name.name for tag_name in tag.names] == ['alias1', 'alias2'] assert [tag_name.name for tag_name in tag.names] == ['alias1', 'alias2']
assert tag.category.name == 'category' assert tag.category.name == 'category'
assert tag.creation_time == datetime(1997, 1, 1) assert tag.creation_time == datetime(1997, 1, 1)

View File

@ -300,7 +300,7 @@ def test_filter_by_note_count(
('note-text:text3*', [3]), ('note-text:text3*', [3]),
('note-text:text3a,text2', [2, 3]), ('note-text:text3a,text2', [2, 3]),
]) ])
def test_filter_by_note_count( def test_filter_by_note_text(
verify_unpaged, post_factory, note_factory, input, expected_post_ids): verify_unpaged, post_factory, note_factory, input, expected_post_ids):
post1 = post_factory(id=1) post1 = post_factory(id=1)
post2 = post_factory(id=2) post2 = post_factory(id=2)