server: lint
This commit is contained in:
		
							parent
							
								
									fea9a94945
								
							
						
					
					
						commit
						4bc58a3c95
					
				@ -27,7 +27,7 @@ def _serialize(
 | 
			
		||||
 | 
			
		||||
@rest.routes.get('/comments/?')
 | 
			
		||||
def get_comments(
 | 
			
		||||
        ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
 | 
			
		||||
        ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
 | 
			
		||||
    auth.verify_privilege(ctx.user, 'comments:list')
 | 
			
		||||
    return _search_executor.execute_and_serialize(
 | 
			
		||||
        ctx, lambda comment: _serialize(ctx, comment))
 | 
			
		||||
@ -35,7 +35,7 @@ def get_comments(
 | 
			
		||||
 | 
			
		||||
@rest.routes.post('/comments/?')
 | 
			
		||||
def create_comment(
 | 
			
		||||
        ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
 | 
			
		||||
        ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
 | 
			
		||||
    auth.verify_privilege(ctx.user, 'comments:create')
 | 
			
		||||
    text = ctx.get_param_as_string('text')
 | 
			
		||||
    post_id = ctx.get_param_as_int('postId')
 | 
			
		||||
 | 
			
		||||
@ -28,7 +28,7 @@ def _get_disk_usage() -> int:
 | 
			
		||||
 | 
			
		||||
@rest.routes.get('/info/?')
 | 
			
		||||
def get_info(
 | 
			
		||||
        ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
 | 
			
		||||
        ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
 | 
			
		||||
    post_feature = posts.try_get_current_post_feature()
 | 
			
		||||
    return {
 | 
			
		||||
        'postCount': posts.get_post_count(),
 | 
			
		||||
 | 
			
		||||
@ -5,10 +5,10 @@ from hashlib import md5
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
MAIL_SUBJECT = 'Password reset for {name}'
 | 
			
		||||
MAIL_BODY = \
 | 
			
		||||
    'You (or someone else) requested to reset your password on {name}.\n' \
 | 
			
		||||
    'If you wish to proceed, click this link: {url}\n' \
 | 
			
		||||
    'Otherwise, please ignore this email.'
 | 
			
		||||
MAIL_BODY = (
 | 
			
		||||
    'You (or someone else) requested to reset your password on {name}.\n'
 | 
			
		||||
    'If you wish to proceed, click this link: {url}\n'
 | 
			
		||||
    'Otherwise, please ignore this email.')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@rest.routes.get('/password-reset/(?P<user_name>[^/]+)/?')
 | 
			
		||||
 | 
			
		||||
@ -31,7 +31,7 @@ def _serialize_post(
 | 
			
		||||
 | 
			
		||||
@rest.routes.get('/posts/?')
 | 
			
		||||
def get_posts(
 | 
			
		||||
        ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
 | 
			
		||||
        ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
 | 
			
		||||
    auth.verify_privilege(ctx.user, 'posts:list')
 | 
			
		||||
    _search_executor_config.user = ctx.user
 | 
			
		||||
    return _search_executor.execute_and_serialize(
 | 
			
		||||
@ -40,7 +40,7 @@ def get_posts(
 | 
			
		||||
 | 
			
		||||
@rest.routes.post('/posts/?')
 | 
			
		||||
def create_post(
 | 
			
		||||
        ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
 | 
			
		||||
        ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
 | 
			
		||||
    anonymous = ctx.get_param_as_bool('anonymous', default=False)
 | 
			
		||||
    if anonymous:
 | 
			
		||||
        auth.verify_privilege(ctx.user, 'posts:create:anonymous')
 | 
			
		||||
@ -144,7 +144,7 @@ def delete_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
 | 
			
		||||
 | 
			
		||||
@rest.routes.post('/post-merge/?')
 | 
			
		||||
def merge_posts(
 | 
			
		||||
        ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
 | 
			
		||||
        ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
 | 
			
		||||
    source_post_id = ctx.get_param_as_int('remove')
 | 
			
		||||
    target_post_id = ctx.get_param_as_int('mergeTo')
 | 
			
		||||
    source_post = posts.get_post_by_id(source_post_id)
 | 
			
		||||
@ -162,14 +162,14 @@ def merge_posts(
 | 
			
		||||
 | 
			
		||||
@rest.routes.get('/featured-post/?')
 | 
			
		||||
def get_featured_post(
 | 
			
		||||
        ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
 | 
			
		||||
        ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
 | 
			
		||||
    post = posts.try_get_featured_post()
 | 
			
		||||
    return _serialize_post(ctx, post)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@rest.routes.post('/featured-post/?')
 | 
			
		||||
def set_featured_post(
 | 
			
		||||
        ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
 | 
			
		||||
        ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
 | 
			
		||||
    auth.verify_privilege(ctx.user, 'posts:feature')
 | 
			
		||||
    post_id = ctx.get_param_as_int('id')
 | 
			
		||||
    post = posts.get_post_by_id(post_id)
 | 
			
		||||
@ -235,7 +235,7 @@ def get_posts_around(
 | 
			
		||||
 | 
			
		||||
@rest.routes.post('/posts/reverse-search/?')
 | 
			
		||||
def get_posts_by_image(
 | 
			
		||||
        ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
 | 
			
		||||
        ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
 | 
			
		||||
    auth.verify_privilege(ctx.user, 'posts:reverse_search')
 | 
			
		||||
    content = ctx.get_file('content')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -8,7 +8,7 @@ _search_executor = search.Executor(search.configs.SnapshotSearchConfig())
 | 
			
		||||
 | 
			
		||||
@rest.routes.get('/snapshots/?')
 | 
			
		||||
def get_snapshots(
 | 
			
		||||
        ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
 | 
			
		||||
        ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
 | 
			
		||||
    auth.verify_privilege(ctx.user, 'snapshots:list')
 | 
			
		||||
    return _search_executor.execute_and_serialize(
 | 
			
		||||
        ctx, lambda snapshot: snapshots.serialize_snapshot(snapshot, ctx.user))
 | 
			
		||||
 | 
			
		||||
@ -28,14 +28,15 @@ def _create_if_needed(tag_names: List[str], user: model.User) -> None:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@rest.routes.get('/tags/?')
 | 
			
		||||
def get_tags(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
 | 
			
		||||
def get_tags(ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
 | 
			
		||||
    auth.verify_privilege(ctx.user, 'tags:list')
 | 
			
		||||
    return _search_executor.execute_and_serialize(
 | 
			
		||||
        ctx, lambda tag: _serialize(ctx, tag))
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@rest.routes.post('/tags/?')
 | 
			
		||||
def create_tag(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
 | 
			
		||||
def create_tag(
 | 
			
		||||
        ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
 | 
			
		||||
    auth.verify_privilege(ctx.user, 'tags:create')
 | 
			
		||||
 | 
			
		||||
    names = ctx.get_param_as_string_list('names')
 | 
			
		||||
@ -112,7 +113,7 @@ def delete_tag(ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
 | 
			
		||||
 | 
			
		||||
@rest.routes.post('/tag-merge/?')
 | 
			
		||||
def merge_tags(
 | 
			
		||||
        ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
 | 
			
		||||
        ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
 | 
			
		||||
    source_tag_name = ctx.get_param_as_string('remove')
 | 
			
		||||
    target_tag_name = ctx.get_param_as_string('mergeTo')
 | 
			
		||||
    source_tag = tags.get_tag_by_name(source_tag_name)
 | 
			
		||||
 | 
			
		||||
@ -12,7 +12,7 @@ def _serialize(
 | 
			
		||||
 | 
			
		||||
@rest.routes.get('/tag-categories/?')
 | 
			
		||||
def get_tag_categories(
 | 
			
		||||
        ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
 | 
			
		||||
        ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
 | 
			
		||||
    auth.verify_privilege(ctx.user, 'tag_categories:list')
 | 
			
		||||
    categories = tag_categories.get_all_categories()
 | 
			
		||||
    return {
 | 
			
		||||
@ -22,7 +22,7 @@ def get_tag_categories(
 | 
			
		||||
 | 
			
		||||
@rest.routes.post('/tag-categories/?')
 | 
			
		||||
def create_tag_category(
 | 
			
		||||
        ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
 | 
			
		||||
        ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
 | 
			
		||||
    auth.verify_privilege(ctx.user, 'tag_categories:create')
 | 
			
		||||
    name = ctx.get_param_as_string('name')
 | 
			
		||||
    color = ctx.get_param_as_string('color')
 | 
			
		||||
 | 
			
		||||
@ -5,7 +5,7 @@ from szurubooru.func import auth, file_uploads
 | 
			
		||||
 | 
			
		||||
@rest.routes.post('/uploads/?')
 | 
			
		||||
def create_temporary_file(
 | 
			
		||||
        ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
 | 
			
		||||
        ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
 | 
			
		||||
    auth.verify_privilege(ctx.user, 'uploads:create')
 | 
			
		||||
    content = ctx.get_file('content', allow_tokens=False)
 | 
			
		||||
    token = file_uploads.save(content)
 | 
			
		||||
 | 
			
		||||
@ -16,7 +16,8 @@ def _serialize(
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@rest.routes.get('/users/?')
 | 
			
		||||
def get_users(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
 | 
			
		||||
def get_users(
 | 
			
		||||
        ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
 | 
			
		||||
    auth.verify_privilege(ctx.user, 'users:list')
 | 
			
		||||
    return _search_executor.execute_and_serialize(
 | 
			
		||||
        ctx, lambda user: _serialize(ctx, user))
 | 
			
		||||
@ -24,7 +25,7 @@ def get_users(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
 | 
			
		||||
 | 
			
		||||
@rest.routes.post('/users/?')
 | 
			
		||||
def create_user(
 | 
			
		||||
        ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response:
 | 
			
		||||
        ctx: rest.Context, _params: Dict[str, str] = {}) -> rest.Response:
 | 
			
		||||
    auth.verify_privilege(ctx.user, 'users:create')
 | 
			
		||||
    name = ctx.get_param_as_string('name')
 | 
			
		||||
    password = ctx.get_param_as_string('password')
 | 
			
		||||
 | 
			
		||||
@ -4,8 +4,8 @@ from typing import Dict
 | 
			
		||||
class BaseError(RuntimeError):
 | 
			
		||||
    def __init__(
 | 
			
		||||
            self,
 | 
			
		||||
            message: str='Unknown error',
 | 
			
		||||
            extra_fields: Dict[str, str]=None) -> None:
 | 
			
		||||
            message: str = 'Unknown error',
 | 
			
		||||
            extra_fields: Dict[str, str] = None) -> None:
 | 
			
		||||
        super().__init__(message)
 | 
			
		||||
        self.extra_fields = extra_fields
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -21,9 +21,9 @@ class LruCache:
 | 
			
		||||
                i
 | 
			
		||||
                for i, v in enumerate(self.item_list)
 | 
			
		||||
                if v.key == item.key)
 | 
			
		||||
            self.item_list[:] \
 | 
			
		||||
                = self.item_list[:item_index] \
 | 
			
		||||
                + self.item_list[item_index + 1:]
 | 
			
		||||
            self.item_list[:] = (
 | 
			
		||||
                self.item_list[:item_index] +
 | 
			
		||||
                self.item_list[item_index + 1:])
 | 
			
		||||
            self.item_list.insert(0, item)
 | 
			
		||||
        else:
 | 
			
		||||
            if len(self.item_list) > self.length:
 | 
			
		||||
 | 
			
		||||
@ -1,7 +1,7 @@
 | 
			
		||||
from datetime import datetime
 | 
			
		||||
from typing import Any, Optional, List, Dict, Callable
 | 
			
		||||
from szurubooru import db, model, errors, rest
 | 
			
		||||
from szurubooru.func import users, scores, util, serialization
 | 
			
		||||
from szurubooru.func import users, scores, serialization
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class InvalidCommentIdError(errors.ValidationError):
 | 
			
		||||
@ -65,7 +65,7 @@ class CommentSerializer(serialization.BaseSerializer):
 | 
			
		||||
def serialize_comment(
 | 
			
		||||
        comment: model.Comment,
 | 
			
		||||
        auth_user: model.User,
 | 
			
		||||
        options: List[str]=[]) -> rest.Response:
 | 
			
		||||
        options: List[str] = []) -> rest.Response:
 | 
			
		||||
    if comment is None:
 | 
			
		||||
        return None
 | 
			
		||||
    return CommentSerializer(comment, auth_user).serialize(options)
 | 
			
		||||
@ -73,10 +73,11 @@ def serialize_comment(
 | 
			
		||||
 | 
			
		||||
def try_get_comment_by_id(comment_id: int) -> Optional[model.Comment]:
 | 
			
		||||
    comment_id = int(comment_id)
 | 
			
		||||
    return db.session \
 | 
			
		||||
        .query(model.Comment) \
 | 
			
		||||
        .filter(model.Comment.comment_id == comment_id) \
 | 
			
		||||
        .one_or_none()
 | 
			
		||||
    return (
 | 
			
		||||
        db.session
 | 
			
		||||
        .query(model.Comment)
 | 
			
		||||
        .filter(model.Comment.comment_id == comment_id)
 | 
			
		||||
        .one_or_none())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_comment_by_id(comment_id: int) -> model.Comment:
 | 
			
		||||
 | 
			
		||||
@ -99,7 +99,7 @@ def _normalize_and_threshold(
 | 
			
		||||
def _compute_grid_points(
 | 
			
		||||
        image: NpMatrix,
 | 
			
		||||
        n: float,
 | 
			
		||||
        window: Window=None) -> Tuple[NpMatrix, NpMatrix]:
 | 
			
		||||
        window: Window = None) -> Tuple[NpMatrix, NpMatrix]:
 | 
			
		||||
    if window is None:
 | 
			
		||||
        window = ((0, image.shape[0]), (0, image.shape[1]))
 | 
			
		||||
    x_coords = np.linspace(window[0][0], window[0][1], n + 2, dtype=int)[1:-1]
 | 
			
		||||
@ -219,7 +219,7 @@ def _max_contrast(array: NpMatrix) -> None:
 | 
			
		||||
def _normalized_distance(
 | 
			
		||||
        target_array: NpMatrix,
 | 
			
		||||
        vec: NpMatrix,
 | 
			
		||||
        nan_value: float=1.0) -> List[float]:
 | 
			
		||||
        nan_value: float = 1.0) -> List[float]:
 | 
			
		||||
    target_array = target_array.astype(int)
 | 
			
		||||
    vec = vec.astype(int)
 | 
			
		||||
    topvec = np.linalg.norm(vec - target_array, axis=1)
 | 
			
		||||
 | 
			
		||||
@ -11,8 +11,8 @@ from szurubooru.func import mime, util
 | 
			
		||||
logger = logging.getLogger(__name__)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
_SCALE_FIT_FMT = \
 | 
			
		||||
    r'scale=iw*max({width}/iw\,{height}/ih):ih*max({width}/iw\,{height}/ih)'
 | 
			
		||||
_SCALE_FIT_FMT = (
 | 
			
		||||
    r'scale=iw*max({width}/iw\,{height}/ih):ih*max({width}/iw\,{height}/ih)')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class Image:
 | 
			
		||||
@ -77,7 +77,7 @@ class Image:
 | 
			
		||||
            '-',
 | 
			
		||||
        ])
 | 
			
		||||
 | 
			
		||||
    def _execute(self, cli: List[str], program: str='ffmpeg') -> bytes:
 | 
			
		||||
    def _execute(self, cli: List[str], program: str = 'ffmpeg') -> bytes:
 | 
			
		||||
        extension = mime.get_extension(mime.get_mime_type(self.content))
 | 
			
		||||
        assert extension
 | 
			
		||||
        with util.create_temp_file(suffix='.' + extension) as handle:
 | 
			
		||||
 | 
			
		||||
@ -7,10 +7,10 @@ from szurubooru.func import (
 | 
			
		||||
    mime, images, files, image_hash, serialization)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
EMPTY_PIXEL = \
 | 
			
		||||
    b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00' \
 | 
			
		||||
    b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \
 | 
			
		||||
    b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b'
 | 
			
		||||
EMPTY_PIXEL = (
 | 
			
		||||
    b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00'
 | 
			
		||||
    b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00'
 | 
			
		||||
    b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class PostNotFoundError(errors.NotFoundError):
 | 
			
		||||
@ -283,7 +283,7 @@ class PostSerializer(serialization.BaseSerializer):
 | 
			
		||||
def serialize_post(
 | 
			
		||||
        post: Optional[model.Post],
 | 
			
		||||
        auth_user: model.User,
 | 
			
		||||
        options: List[str]=[]) -> Optional[rest.Response]:
 | 
			
		||||
        options: List[str] = []) -> Optional[rest.Response]:
 | 
			
		||||
    if not post:
 | 
			
		||||
        return None
 | 
			
		||||
    return PostSerializer(post, auth_user).serialize(options)
 | 
			
		||||
@ -300,10 +300,11 @@ def get_post_count() -> int:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def try_get_post_by_id(post_id: int) -> Optional[model.Post]:
 | 
			
		||||
    return db.session \
 | 
			
		||||
        .query(model.Post) \
 | 
			
		||||
        .filter(model.Post.post_id == post_id) \
 | 
			
		||||
        .one_or_none()
 | 
			
		||||
    return (
 | 
			
		||||
        db.session
 | 
			
		||||
        .query(model.Post)
 | 
			
		||||
        .filter(model.Post.post_id == post_id)
 | 
			
		||||
        .one_or_none())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_post_by_id(post_id: int) -> model.Post:
 | 
			
		||||
@ -314,10 +315,11 @@ def get_post_by_id(post_id: int) -> model.Post:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def try_get_current_post_feature() -> Optional[model.PostFeature]:
 | 
			
		||||
    return db.session \
 | 
			
		||||
        .query(model.PostFeature) \
 | 
			
		||||
        .order_by(model.PostFeature.time.desc()) \
 | 
			
		||||
        .first()
 | 
			
		||||
    return (
 | 
			
		||||
        db.session
 | 
			
		||||
        .query(model.PostFeature)
 | 
			
		||||
        .order_by(model.PostFeature.time.desc())
 | 
			
		||||
        .first())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def try_get_featured_post() -> Optional[model.Post]:
 | 
			
		||||
@ -426,11 +428,12 @@ def update_post_content(post: model.Post, content: Optional[bytes]) -> None:
 | 
			
		||||
            'Unhandled file type: %r' % post.mime_type)
 | 
			
		||||
 | 
			
		||||
    post.checksum = util.get_sha1(content)
 | 
			
		||||
    other_post = db.session \
 | 
			
		||||
        .query(model.Post) \
 | 
			
		||||
        .filter(model.Post.checksum == post.checksum) \
 | 
			
		||||
        .filter(model.Post.post_id != post.post_id) \
 | 
			
		||||
        .one_or_none()
 | 
			
		||||
    other_post = (
 | 
			
		||||
        db.session
 | 
			
		||||
        .query(model.Post)
 | 
			
		||||
        .filter(model.Post.checksum == post.checksum)
 | 
			
		||||
        .filter(model.Post.post_id != post.post_id)
 | 
			
		||||
        .one_or_none())
 | 
			
		||||
    if other_post \
 | 
			
		||||
            and other_post.post_id \
 | 
			
		||||
            and other_post.post_id != post.post_id:
 | 
			
		||||
@ -452,7 +455,7 @@ def update_post_content(post: model.Post, content: Optional[bytes]) -> None:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def update_post_thumbnail(
 | 
			
		||||
        post: model.Post, content: Optional[bytes]=None) -> None:
 | 
			
		||||
        post: model.Post, content: Optional[bytes] = None) -> None:
 | 
			
		||||
    assert post
 | 
			
		||||
    setattr(post, '__thumbnail', content)
 | 
			
		||||
 | 
			
		||||
@ -492,10 +495,11 @@ def update_post_relations(post: model.Post, new_post_ids: List[int]) -> None:
 | 
			
		||||
    old_posts = post.relations
 | 
			
		||||
    old_post_ids = [int(p.post_id) for p in old_posts]
 | 
			
		||||
    if new_post_ids:
 | 
			
		||||
        new_posts = db.session \
 | 
			
		||||
            .query(model.Post) \
 | 
			
		||||
            .filter(model.Post.post_id.in_(new_post_ids)) \
 | 
			
		||||
            .all()
 | 
			
		||||
        new_posts = (
 | 
			
		||||
            db.session
 | 
			
		||||
            .query(model.Post)
 | 
			
		||||
            .filter(model.Post.post_id.in_(new_post_ids))
 | 
			
		||||
            .all())
 | 
			
		||||
    else:
 | 
			
		||||
        new_posts = []
 | 
			
		||||
    if len(new_posts) != len(new_post_ids):
 | 
			
		||||
@ -673,10 +677,11 @@ def merge_posts(
 | 
			
		||||
 | 
			
		||||
def search_by_image_exact(image_content: bytes) -> Optional[model.Post]:
 | 
			
		||||
    checksum = util.get_sha1(image_content)
 | 
			
		||||
    return db.session \
 | 
			
		||||
        .query(model.Post) \
 | 
			
		||||
        .filter(model.Post.checksum == checksum) \
 | 
			
		||||
        .one_or_none()
 | 
			
		||||
    return (
 | 
			
		||||
        db.session
 | 
			
		||||
        .query(model.Post)
 | 
			
		||||
        .filter(model.Post.checksum == checksum)
 | 
			
		||||
        .one_or_none())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def search_by_image(image_content: bytes) -> List[PostLookalike]:
 | 
			
		||||
 | 
			
		||||
@ -39,11 +39,12 @@ def get_score(entity: model.Base, user: model.User) -> int:
 | 
			
		||||
    assert entity
 | 
			
		||||
    assert user
 | 
			
		||||
    table, get_column = _get_table_info(entity)
 | 
			
		||||
    row = db.session \
 | 
			
		||||
        .query(table.score) \
 | 
			
		||||
        .filter(get_column(table) == get_column(entity)) \
 | 
			
		||||
        .filter(table.user_id == user.user_id) \
 | 
			
		||||
        .one_or_none()
 | 
			
		||||
    row = (
 | 
			
		||||
        db.session
 | 
			
		||||
        .query(table.score)
 | 
			
		||||
        .filter(get_column(table) == get_column(entity))
 | 
			
		||||
        .filter(table.user_id == user.user_id)
 | 
			
		||||
        .one_or_none())
 | 
			
		||||
    return row[0] if row else 0
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1,5 +1,5 @@
 | 
			
		||||
from typing import Any, Optional, List, Dict, Callable
 | 
			
		||||
from szurubooru import db, model, rest, errors
 | 
			
		||||
from typing import Any, List, Dict, Callable
 | 
			
		||||
from szurubooru import model, rest, errors
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_serialization_options(ctx: rest.Context) -> List[str]:
 | 
			
		||||
 | 
			
		||||
@ -66,7 +66,7 @@ class TagCategorySerializer(serialization.BaseSerializer):
 | 
			
		||||
 | 
			
		||||
def serialize_category(
 | 
			
		||||
        category: Optional[model.TagCategory],
 | 
			
		||||
        options: List[str]=[]) -> Optional[rest.Response]:
 | 
			
		||||
        options: List[str] = []) -> Optional[rest.Response]:
 | 
			
		||||
    if not category:
 | 
			
		||||
        return None
 | 
			
		||||
    return TagCategorySerializer(category).serialize(options)
 | 
			
		||||
@ -113,16 +113,17 @@ def update_category_color(category: model.TagCategory, color: str) -> None:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def try_get_category_by_name(
 | 
			
		||||
        name: str, lock: bool=False) -> Optional[model.TagCategory]:
 | 
			
		||||
    query = db.session \
 | 
			
		||||
        .query(model.TagCategory) \
 | 
			
		||||
        .filter(sa.func.lower(model.TagCategory.name) == name.lower())
 | 
			
		||||
        name: str, lock: bool = False) -> Optional[model.TagCategory]:
 | 
			
		||||
    query = (
 | 
			
		||||
        db.session
 | 
			
		||||
        .query(model.TagCategory)
 | 
			
		||||
        .filter(sa.func.lower(model.TagCategory.name) == name.lower()))
 | 
			
		||||
    if lock:
 | 
			
		||||
        query = query.with_lockmode('update')
 | 
			
		||||
    return query.one_or_none()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_category_by_name(name: str, lock: bool=False) -> model.TagCategory:
 | 
			
		||||
def get_category_by_name(name: str, lock: bool = False) -> model.TagCategory:
 | 
			
		||||
    category = try_get_category_by_name(name, lock)
 | 
			
		||||
    if not category:
 | 
			
		||||
        raise TagCategoryNotFoundError('Tag category %r not found.' % name)
 | 
			
		||||
@ -137,26 +138,29 @@ def get_all_categories() -> List[model.TagCategory]:
 | 
			
		||||
    return db.session.query(model.TagCategory).all()
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def try_get_default_category(lock: bool=False) -> Optional[model.TagCategory]:
 | 
			
		||||
    query = db.session \
 | 
			
		||||
        .query(model.TagCategory) \
 | 
			
		||||
        .filter(model.TagCategory.default)
 | 
			
		||||
def try_get_default_category(
 | 
			
		||||
        lock: bool = False) -> Optional[model.TagCategory]:
 | 
			
		||||
    query = (
 | 
			
		||||
        db.session
 | 
			
		||||
        .query(model.TagCategory)
 | 
			
		||||
        .filter(model.TagCategory.default))
 | 
			
		||||
    if lock:
 | 
			
		||||
        query = query.with_lockmode('update')
 | 
			
		||||
    category = query.first()
 | 
			
		||||
    # if for some reason (e.g. as a result of migration) there's no default
 | 
			
		||||
    # category, get the first record available.
 | 
			
		||||
    if not category:
 | 
			
		||||
        query = db.session \
 | 
			
		||||
            .query(model.TagCategory) \
 | 
			
		||||
            .order_by(model.TagCategory.tag_category_id.asc())
 | 
			
		||||
        query = (
 | 
			
		||||
            db.session
 | 
			
		||||
            .query(model.TagCategory)
 | 
			
		||||
            .order_by(model.TagCategory.tag_category_id.asc()))
 | 
			
		||||
        if lock:
 | 
			
		||||
            query = query.with_lockmode('update')
 | 
			
		||||
        category = query.first()
 | 
			
		||||
    return category
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_default_category(lock: bool=False) -> model.TagCategory:
 | 
			
		||||
def get_default_category(lock: bool = False) -> model.TagCategory:
 | 
			
		||||
    category = try_get_default_category(lock)
 | 
			
		||||
    if not category:
 | 
			
		||||
        raise TagCategoryNotFoundError('No tag category created yet.')
 | 
			
		||||
 | 
			
		||||
@ -122,7 +122,7 @@ class TagSerializer(serialization.BaseSerializer):
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def serialize_tag(
 | 
			
		||||
        tag: model.Tag, options: List[str]=[]) -> Optional[rest.Response]:
 | 
			
		||||
        tag: model.Tag, options: List[str] = []) -> Optional[rest.Response]:
 | 
			
		||||
    if not tag:
 | 
			
		||||
        return None
 | 
			
		||||
    return TagSerializer(tag).serialize(options)
 | 
			
		||||
@ -209,7 +209,8 @@ def get_tags_by_names(names: List[str]) -> List[model.Tag]:
 | 
			
		||||
    names = util.icase_unique(names)
 | 
			
		||||
    if len(names) == 0:
 | 
			
		||||
        return []
 | 
			
		||||
    return (db.session.query(model.Tag)
 | 
			
		||||
    return (
 | 
			
		||||
        db.session.query(model.Tag)
 | 
			
		||||
        .join(model.TagName)
 | 
			
		||||
        .filter(
 | 
			
		||||
            sa.sql.or_(
 | 
			
		||||
 | 
			
		||||
@ -86,7 +86,7 @@ class UserSerializer(serialization.BaseSerializer):
 | 
			
		||||
            self,
 | 
			
		||||
            user: model.User,
 | 
			
		||||
            auth_user: model.User,
 | 
			
		||||
            force_show_email: bool=False) -> None:
 | 
			
		||||
            force_show_email: bool = False) -> None:
 | 
			
		||||
        self.user = user
 | 
			
		||||
        self.auth_user = auth_user
 | 
			
		||||
        self.force_show_email = force_show_email
 | 
			
		||||
@ -151,8 +151,8 @@ class UserSerializer(serialization.BaseSerializer):
 | 
			
		||||
def serialize_user(
 | 
			
		||||
        user: Optional[model.User],
 | 
			
		||||
        auth_user: model.User,
 | 
			
		||||
        options: List[str]=[],
 | 
			
		||||
        force_show_email: bool=False) -> Optional[rest.Response]:
 | 
			
		||||
        options: List[str] = [],
 | 
			
		||||
        force_show_email: bool = False) -> Optional[rest.Response]:
 | 
			
		||||
    if not user:
 | 
			
		||||
        return None
 | 
			
		||||
    return UserSerializer(user, auth_user, force_show_email).serialize(options)
 | 
			
		||||
@ -170,10 +170,11 @@ def get_user_count() -> int:
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def try_get_user_by_name(name: str) -> Optional[model.User]:
 | 
			
		||||
    return db.session \
 | 
			
		||||
        .query(model.User) \
 | 
			
		||||
        .filter(sa.func.lower(model.User.name) == sa.func.lower(name)) \
 | 
			
		||||
        .one_or_none()
 | 
			
		||||
    return (
 | 
			
		||||
        db.session
 | 
			
		||||
        .query(model.User)
 | 
			
		||||
        .filter(sa.func.lower(model.User.name) == sa.func.lower(name))
 | 
			
		||||
        .one_or_none())
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def get_user_by_name(name: str) -> model.User:
 | 
			
		||||
@ -276,7 +277,7 @@ def update_user_rank(
 | 
			
		||||
def update_user_avatar(
 | 
			
		||||
        user: model.User,
 | 
			
		||||
        avatar_style: str,
 | 
			
		||||
        avatar_content: Optional[bytes]=None) -> None:
 | 
			
		||||
        avatar_content: Optional[bytes] = None) -> None:
 | 
			
		||||
    assert user
 | 
			
		||||
    if avatar_style == 'gravatar':
 | 
			
		||||
        user.avatar_style = user.AVATAR_GRAVATAR
 | 
			
		||||
 | 
			
		||||
@ -2,8 +2,7 @@ import os
 | 
			
		||||
import hashlib
 | 
			
		||||
import re
 | 
			
		||||
import tempfile
 | 
			
		||||
from typing import (
 | 
			
		||||
    Any, Optional, Union, Tuple, List, Dict, Generator, Callable, TypeVar)
 | 
			
		||||
from typing import Any, Optional, Union, Tuple, List, Dict, Generator, TypeVar
 | 
			
		||||
from datetime import datetime, timedelta
 | 
			
		||||
from contextlib import contextmanager
 | 
			
		||||
from szurubooru import errors
 | 
			
		||||
 | 
			
		||||
@ -4,7 +4,7 @@ from szurubooru import errors, rest, model
 | 
			
		||||
def verify_version(
 | 
			
		||||
        entity: model.Base,
 | 
			
		||||
        context: rest.Context,
 | 
			
		||||
        field_name: str='version') -> None:
 | 
			
		||||
        field_name: str = 'version') -> None:
 | 
			
		||||
    actual_version = context.get_param_as_int(field_name)
 | 
			
		||||
    expected_version = entity.version
 | 
			
		||||
    if actual_version != expected_version:
 | 
			
		||||
 | 
			
		||||
@ -27,8 +27,9 @@ def _get_user(ctx: rest.Context) -> Optional[model.User]:
 | 
			
		||||
            credentials.encode('ascii')).decode('utf8').split(':')
 | 
			
		||||
        return _authenticate(username, password)
 | 
			
		||||
    except ValueError as err:
 | 
			
		||||
        msg = 'Basic authentication header value are not properly formed. ' \
 | 
			
		||||
            + 'Supplied header {0}. Got error: {1}'
 | 
			
		||||
        msg = (
 | 
			
		||||
            'Basic authentication header value are not properly formed. '
 | 
			
		||||
            'Supplied header {0}. Got error: {1}')
 | 
			
		||||
        raise HttpBadRequest(
 | 
			
		||||
            'ValidationError',
 | 
			
		||||
            msg.format(ctx.get_header('Authorization'), str(err)))
 | 
			
		||||
 | 
			
		||||
@ -50,13 +50,14 @@ def upgrade():
 | 
			
		||||
 | 
			
		||||
def downgrade():
 | 
			
		||||
    session = sa.orm.session.Session(bind=op.get_bind())
 | 
			
		||||
    default_category = session \
 | 
			
		||||
        .query(TagCategory) \
 | 
			
		||||
        .filter(TagCategory.name == 'default') \
 | 
			
		||||
        .filter(TagCategory.color == 'default') \
 | 
			
		||||
        .filter(TagCategory.version == 1) \
 | 
			
		||||
        .filter(TagCategory.default == True) \
 | 
			
		||||
        .one_or_none()
 | 
			
		||||
    default_category = (
 | 
			
		||||
        session
 | 
			
		||||
        .query(TagCategory)
 | 
			
		||||
        .filter(TagCategory.name == 'default')
 | 
			
		||||
        .filter(TagCategory.color == 'default')
 | 
			
		||||
        .filter(TagCategory.version == 1)
 | 
			
		||||
        .filter(TagCategory.default == 1)
 | 
			
		||||
        .one_or_none())
 | 
			
		||||
    if default_category:
 | 
			
		||||
        session.delete(default_category)
 | 
			
		||||
    session.commit()
 | 
			
		||||
 | 
			
		||||
@ -211,10 +211,11 @@ class Post(Base):
 | 
			
		||||
 | 
			
		||||
    @property
 | 
			
		||||
    def is_featured(self) -> bool:
 | 
			
		||||
        featured_post = sa.orm.object_session(self) \
 | 
			
		||||
            .query(PostFeature) \
 | 
			
		||||
            .order_by(PostFeature.time.desc()) \
 | 
			
		||||
            .first()
 | 
			
		||||
        featured_post = (
 | 
			
		||||
            sa.orm.object_session(self)
 | 
			
		||||
            .query(PostFeature)
 | 
			
		||||
            .order_by(PostFeature.time.desc())
 | 
			
		||||
            .first())
 | 
			
		||||
        return featured_post and featured_post.post_id == self.post_id
 | 
			
		||||
 | 
			
		||||
    score = sa.orm.column_property(
 | 
			
		||||
 | 
			
		||||
@ -14,7 +14,7 @@ class TagCategory(Base):
 | 
			
		||||
        'color', sa.Unicode(32), nullable=False, default='#000000')
 | 
			
		||||
    default = sa.Column('default', sa.Boolean, nullable=False, default=False)
 | 
			
		||||
 | 
			
		||||
    def __init__(self, name: Optional[str]=None) -> None:
 | 
			
		||||
    def __init__(self, name: Optional[str] = None) -> None:
 | 
			
		||||
        self.name = name
 | 
			
		||||
 | 
			
		||||
    tag_count = sa.orm.column_property(
 | 
			
		||||
 | 
			
		||||
@ -13,9 +13,9 @@ class Context:
 | 
			
		||||
            self,
 | 
			
		||||
            method: str,
 | 
			
		||||
            url: str,
 | 
			
		||||
            headers: Dict[str, str]=None,
 | 
			
		||||
            params: Request=None,
 | 
			
		||||
            files: Dict[str, bytes]=None) -> None:
 | 
			
		||||
            headers: Dict[str, str] = None,
 | 
			
		||||
            params: Request = None,
 | 
			
		||||
            files: Dict[str, bytes] = None) -> None:
 | 
			
		||||
        self.method = method
 | 
			
		||||
        self.url = url
 | 
			
		||||
        self._headers = headers or {}
 | 
			
		||||
@ -34,7 +34,7 @@ class Context:
 | 
			
		||||
    def get_header(self, name: str) -> str:
 | 
			
		||||
        return self._headers.get(name, '')
 | 
			
		||||
 | 
			
		||||
    def has_file(self, name: str, allow_tokens: bool=True) -> bool:
 | 
			
		||||
    def has_file(self, name: str, allow_tokens: bool = True) -> bool:
 | 
			
		||||
        return (
 | 
			
		||||
            name in self._files or
 | 
			
		||||
            name + 'Url' in self._params or
 | 
			
		||||
@ -43,8 +43,8 @@ class Context:
 | 
			
		||||
    def get_file(
 | 
			
		||||
            self,
 | 
			
		||||
            name: str,
 | 
			
		||||
            default: Union[object, bytes]=MISSING,
 | 
			
		||||
            allow_tokens: bool=True) -> bytes:
 | 
			
		||||
            default: Union[object, bytes] = MISSING,
 | 
			
		||||
            allow_tokens: bool = True) -> bytes:
 | 
			
		||||
        if name in self._files and self._files[name]:
 | 
			
		||||
            return self._files[name]
 | 
			
		||||
 | 
			
		||||
@ -70,7 +70,7 @@ class Context:
 | 
			
		||||
    def get_param_as_list(
 | 
			
		||||
            self,
 | 
			
		||||
            name: str,
 | 
			
		||||
            default: Union[object, List[Any]]=MISSING) -> List[Any]:
 | 
			
		||||
            default: Union[object, List[Any]] = MISSING) -> List[Any]:
 | 
			
		||||
        if name not in self._params:
 | 
			
		||||
            if default is not MISSING:
 | 
			
		||||
                return cast(List[Any], default)
 | 
			
		||||
@ -89,7 +89,7 @@ class Context:
 | 
			
		||||
    def get_param_as_int_list(
 | 
			
		||||
            self,
 | 
			
		||||
            name: str,
 | 
			
		||||
            default: Union[object, List[int]]=MISSING) -> List[int]:
 | 
			
		||||
            default: Union[object, List[int]] = MISSING) -> List[int]:
 | 
			
		||||
        ret = self.get_param_as_list(name, default)
 | 
			
		||||
        for item in ret:
 | 
			
		||||
            if type(item) is not int:
 | 
			
		||||
@ -100,7 +100,7 @@ class Context:
 | 
			
		||||
    def get_param_as_string_list(
 | 
			
		||||
            self,
 | 
			
		||||
            name: str,
 | 
			
		||||
            default: Union[object, List[str]]=MISSING) -> List[str]:
 | 
			
		||||
            default: Union[object, List[str]] = MISSING) -> List[str]:
 | 
			
		||||
        ret = self.get_param_as_list(name, default)
 | 
			
		||||
        for item in ret:
 | 
			
		||||
            if type(item) is not str:
 | 
			
		||||
@ -111,7 +111,7 @@ class Context:
 | 
			
		||||
    def get_param_as_string(
 | 
			
		||||
            self,
 | 
			
		||||
            name: str,
 | 
			
		||||
            default: Union[object, str]=MISSING) -> str:
 | 
			
		||||
            default: Union[object, str] = MISSING) -> str:
 | 
			
		||||
        if name not in self._params:
 | 
			
		||||
            if default is not MISSING:
 | 
			
		||||
                return cast(str, default)
 | 
			
		||||
@ -135,9 +135,9 @@ class Context:
 | 
			
		||||
    def get_param_as_int(
 | 
			
		||||
            self,
 | 
			
		||||
            name: str,
 | 
			
		||||
            default: Union[object, int]=MISSING,
 | 
			
		||||
            min: Optional[int]=None,
 | 
			
		||||
            max: Optional[int]=None) -> int:
 | 
			
		||||
            default: Union[object, int] = MISSING,
 | 
			
		||||
            min: Optional[int] = None,
 | 
			
		||||
            max: Optional[int] = None) -> int:
 | 
			
		||||
        if name not in self._params:
 | 
			
		||||
            if default is not MISSING:
 | 
			
		||||
                return cast(int, default)
 | 
			
		||||
@ -161,7 +161,7 @@ class Context:
 | 
			
		||||
    def get_param_as_bool(
 | 
			
		||||
            self,
 | 
			
		||||
            name: str,
 | 
			
		||||
            default: Union[object, bool]=MISSING) -> bool:
 | 
			
		||||
            default: Union[object, bool] = MISSING) -> bool:
 | 
			
		||||
        if name not in self._params:
 | 
			
		||||
            if default is not MISSING:
 | 
			
		||||
                return cast(bool, default)
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,4 @@
 | 
			
		||||
from typing import Callable, Type, Dict
 | 
			
		||||
from typing import Optional, Callable, Type, Dict
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
error_handlers = {}  # pylint: disable=invalid-name
 | 
			
		||||
@ -12,8 +12,8 @@ class BaseHttpError(RuntimeError):
 | 
			
		||||
            self,
 | 
			
		||||
            name: str,
 | 
			
		||||
            description: str,
 | 
			
		||||
            title: str=None,
 | 
			
		||||
            extra_fields: Dict[str, str]=None) -> None:
 | 
			
		||||
            title: Optional[str] = None,
 | 
			
		||||
            extra_fields: Optional[Dict[str, str]] = None) -> None:
 | 
			
		||||
        super().__init__()
 | 
			
		||||
        # error name for programmers
 | 
			
		||||
        self.name = name
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,4 @@
 | 
			
		||||
from typing import Callable
 | 
			
		||||
from typing import List, Callable
 | 
			
		||||
from szurubooru.rest.context import Context
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,4 @@
 | 
			
		||||
from typing import Callable, Dict, Any
 | 
			
		||||
from typing import Callable, Dict
 | 
			
		||||
from collections import defaultdict
 | 
			
		||||
from szurubooru.rest.context import Context, Response
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -52,10 +52,11 @@ def _create_score_filter(score: int) -> Filter:
 | 
			
		||||
            user_alias.name, criterion)
 | 
			
		||||
        if negated:
 | 
			
		||||
            expr = ~expr
 | 
			
		||||
        ret = query \
 | 
			
		||||
            .join(score_alias, score_alias.post_id == model.Post.post_id) \
 | 
			
		||||
            .join(user_alias, user_alias.user_id == score_alias.user_id) \
 | 
			
		||||
            .filter(expr)
 | 
			
		||||
        ret = (
 | 
			
		||||
            query
 | 
			
		||||
            .join(score_alias, score_alias.post_id == model.Post.post_id)
 | 
			
		||||
            .join(user_alias, user_alias.user_id == score_alias.user_id)
 | 
			
		||||
            .filter(expr))
 | 
			
		||||
        return ret
 | 
			
		||||
    return wrapper
 | 
			
		||||
 | 
			
		||||
@ -124,7 +125,8 @@ class PostSearchConfig(BaseSearchConfig):
 | 
			
		||||
            sa.orm.lazyload
 | 
			
		||||
            if disable_eager_loads
 | 
			
		||||
            else sa.orm.subqueryload)
 | 
			
		||||
        return db.session.query(model.Post) \
 | 
			
		||||
        return (
 | 
			
		||||
            db.session.query(model.Post)
 | 
			
		||||
            .options(
 | 
			
		||||
                sa.orm.lazyload('*'),
 | 
			
		||||
                # use config optimized for official client
 | 
			
		||||
@ -141,7 +143,7 @@ class PostSearchConfig(BaseSearchConfig):
 | 
			
		||||
                strategy(model.Post.tags).subqueryload(model.Tag.names),
 | 
			
		||||
                strategy(model.Post.tags).defer(model.Tag.post_count),
 | 
			
		||||
                strategy(model.Post.tags).lazyload(model.Tag.implications),
 | 
			
		||||
                strategy(model.Post.tags).lazyload(model.Tag.suggestions))
 | 
			
		||||
                strategy(model.Post.tags).lazyload(model.Tag.suggestions)))
 | 
			
		||||
 | 
			
		||||
    def create_count_query(self, _disable_eager_loads: bool) -> SaQuery:
 | 
			
		||||
        return db.session.query(model.Post)
 | 
			
		||||
 | 
			
		||||
@ -14,8 +14,9 @@ class TagSearchConfig(BaseSearchConfig):
 | 
			
		||||
            sa.orm.lazyload
 | 
			
		||||
            if _disable_eager_loads
 | 
			
		||||
            else sa.orm.subqueryload)
 | 
			
		||||
        return db.session.query(model.Tag) \
 | 
			
		||||
            .join(model.TagCategory) \
 | 
			
		||||
        return (
 | 
			
		||||
            db.session.query(model.Tag)
 | 
			
		||||
            .join(model.TagCategory)
 | 
			
		||||
            .options(
 | 
			
		||||
                sa.orm.defer(model.Tag.first_name),
 | 
			
		||||
                sa.orm.defer(model.Tag.suggestion_count),
 | 
			
		||||
@ -23,7 +24,7 @@ class TagSearchConfig(BaseSearchConfig):
 | 
			
		||||
                sa.orm.defer(model.Tag.post_count),
 | 
			
		||||
                strategy(model.Tag.names),
 | 
			
		||||
                strategy(model.Tag.suggestions).joinedload(model.Tag.names),
 | 
			
		||||
                strategy(model.Tag.implications).joinedload(model.Tag.names))
 | 
			
		||||
                strategy(model.Tag.implications).joinedload(model.Tag.names)))
 | 
			
		||||
 | 
			
		||||
    def create_count_query(self, _disable_eager_loads: bool) -> SaQuery:
 | 
			
		||||
        return db.session.query(model.Tag)
 | 
			
		||||
 | 
			
		||||
@ -69,7 +69,7 @@ def float_transformer(value: str) -> float:
 | 
			
		||||
def apply_num_criterion_to_column(
 | 
			
		||||
        column: Any,
 | 
			
		||||
        criterion: criteria.BaseCriterion,
 | 
			
		||||
        transformer: Callable[[str], Number]=integer_transformer) -> SaQuery:
 | 
			
		||||
        transformer: Callable[[str], Number] = integer_transformer) -> SaQuery:
 | 
			
		||||
    try:
 | 
			
		||||
        if isinstance(criterion, criteria.PlainCriterion):
 | 
			
		||||
            expr = column == transformer(criterion.value)
 | 
			
		||||
@ -95,7 +95,7 @@ def apply_num_criterion_to_column(
 | 
			
		||||
 | 
			
		||||
def create_num_filter(
 | 
			
		||||
        column: Any,
 | 
			
		||||
        transformer: Callable[[str], Number]=integer_transformer) -> SaQuery:
 | 
			
		||||
        transformer: Callable[[str], Number] = integer_transformer) -> SaQuery:
 | 
			
		||||
    def wrapper(
 | 
			
		||||
            query: SaQuery,
 | 
			
		||||
            criterion: Optional[criteria.BaseCriterion],
 | 
			
		||||
@ -111,7 +111,7 @@ def create_num_filter(
 | 
			
		||||
def apply_str_criterion_to_column(
 | 
			
		||||
        column: SaColumn,
 | 
			
		||||
        criterion: criteria.BaseCriterion,
 | 
			
		||||
        transformer: Callable[[str], str]=wildcard_transformer) -> SaQuery:
 | 
			
		||||
        transformer: Callable[[str], str] = wildcard_transformer) -> SaQuery:
 | 
			
		||||
    if isinstance(criterion, criteria.PlainCriterion):
 | 
			
		||||
        expr = column.ilike(transformer(criterion.value))
 | 
			
		||||
    elif isinstance(criterion, criteria.ArrayCriterion):
 | 
			
		||||
@ -128,8 +128,8 @@ def apply_str_criterion_to_column(
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def create_str_filter(
 | 
			
		||||
    column: SaColumn, transformer: Callable[[str], str]=wildcard_transformer
 | 
			
		||||
) -> Filter:
 | 
			
		||||
        column: SaColumn,
 | 
			
		||||
        transformer: Callable[[str], str] = wildcard_transformer) -> Filter:
 | 
			
		||||
    def wrapper(
 | 
			
		||||
            query: SaQuery,
 | 
			
		||||
            criterion: Optional[criteria.BaseCriterion],
 | 
			
		||||
@ -187,7 +187,7 @@ def create_subquery_filter(
 | 
			
		||||
        right_id_column: SaColumn,
 | 
			
		||||
        filter_column: SaColumn,
 | 
			
		||||
        filter_factory: SaColumn,
 | 
			
		||||
        subquery_decorator: Callable[[SaQuery], None]=None) -> Filter:
 | 
			
		||||
        subquery_decorator: Callable[[SaQuery], None] = None) -> Filter:
 | 
			
		||||
    filter_func = filter_factory(filter_column)
 | 
			
		||||
 | 
			
		||||
    def wrapper(
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,4 @@
 | 
			
		||||
from typing import Optional, List, Callable
 | 
			
		||||
from typing import Optional, List
 | 
			
		||||
from szurubooru.search.typing import SaQuery
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -100,18 +100,20 @@ class Executor:
 | 
			
		||||
        filter_query = self.config.create_filter_query(disable_eager_loads)
 | 
			
		||||
        filter_query = filter_query.options(sa.orm.lazyload('*'))
 | 
			
		||||
        filter_query = self._prepare_db_query(filter_query, search_query, True)
 | 
			
		||||
        entities = filter_query \
 | 
			
		||||
            .offset(offset) \
 | 
			
		||||
            .limit(limit) \
 | 
			
		||||
            .all()
 | 
			
		||||
        entities = (
 | 
			
		||||
            filter_query
 | 
			
		||||
            .offset(offset)
 | 
			
		||||
            .limit(limit)
 | 
			
		||||
            .all())
 | 
			
		||||
 | 
			
		||||
        count_query = self.config.create_count_query(disable_eager_loads)
 | 
			
		||||
        count_query = count_query.options(sa.orm.lazyload('*'))
 | 
			
		||||
        count_query = self._prepare_db_query(count_query, search_query, False)
 | 
			
		||||
        count_statement = count_query \
 | 
			
		||||
            .statement \
 | 
			
		||||
            .with_only_columns([sa.func.count()]) \
 | 
			
		||||
            .order_by(None)
 | 
			
		||||
        count_statement = (
 | 
			
		||||
            count_query
 | 
			
		||||
            .statement
 | 
			
		||||
            .with_only_columns([sa.func.count()])
 | 
			
		||||
            .order_by(None))
 | 
			
		||||
        count = db.session.execute(count_statement).scalar()
 | 
			
		||||
 | 
			
		||||
        ret = (count, entities)
 | 
			
		||||
 | 
			
		||||
@ -1,5 +1,4 @@
 | 
			
		||||
import re
 | 
			
		||||
from typing import Match, List
 | 
			
		||||
from szurubooru import errors
 | 
			
		||||
from szurubooru.search import criteria, tokens
 | 
			
		||||
from szurubooru.search.query import SearchQuery
 | 
			
		||||
 | 
			
		||||
@ -1,4 +1,5 @@
 | 
			
		||||
from szurubooru.search import tokens
 | 
			
		||||
from typing import List
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
class SearchQuery:
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,6 @@
 | 
			
		||||
from unittest.mock import patch
 | 
			
		||||
import pytest
 | 
			
		||||
from szurubooru import api, db, model, errors
 | 
			
		||||
from szurubooru import api, model, errors
 | 
			
		||||
from szurubooru.func import tags, snapshots
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -1,6 +1,6 @@
 | 
			
		||||
from unittest.mock import patch
 | 
			
		||||
import pytest
 | 
			
		||||
from szurubooru import api, db, model, errors
 | 
			
		||||
from szurubooru import api, model, errors
 | 
			
		||||
from szurubooru.func import users
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -5,10 +5,10 @@ from szurubooru import db, model, errors
 | 
			
		||||
from szurubooru.func import auth, users, files, util
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
EMPTY_PIXEL = \
 | 
			
		||||
    b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00' \
 | 
			
		||||
    b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \
 | 
			
		||||
    b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b'
 | 
			
		||||
EMPTY_PIXEL = (
 | 
			
		||||
    b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00'
 | 
			
		||||
    b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00'
 | 
			
		||||
    b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@pytest.mark.parametrize('user_name', ['test', 'TEST'])
 | 
			
		||||
 | 
			
		||||
@ -28,11 +28,12 @@ def test_saving_tag(tag_factory):
 | 
			
		||||
    tag.implications.append(imp2)
 | 
			
		||||
    db.session.commit()
 | 
			
		||||
 | 
			
		||||
    tag = db.session \
 | 
			
		||||
        .query(model.Tag) \
 | 
			
		||||
        .join(model.TagName) \
 | 
			
		||||
        .filter(model.TagName.name == 'alias1') \
 | 
			
		||||
        .one()
 | 
			
		||||
    tag = (
 | 
			
		||||
        db.session
 | 
			
		||||
        .query(model.Tag)
 | 
			
		||||
        .join(model.TagName)
 | 
			
		||||
        .filter(model.TagName.name == 'alias1')
 | 
			
		||||
        .one())
 | 
			
		||||
    assert [tag_name.name for tag_name in tag.names] == ['alias1', 'alias2']
 | 
			
		||||
    assert tag.category.name == 'category'
 | 
			
		||||
    assert tag.creation_time == datetime(1997, 1, 1)
 | 
			
		||||
 | 
			
		||||
@ -300,7 +300,7 @@ def test_filter_by_note_count(
 | 
			
		||||
    ('note-text:text3*', [3]),
 | 
			
		||||
    ('note-text:text3a,text2', [2, 3]),
 | 
			
		||||
])
 | 
			
		||||
def test_filter_by_note_count(
 | 
			
		||||
def test_filter_by_note_text(
 | 
			
		||||
        verify_unpaged, post_factory, note_factory, input, expected_post_ids):
 | 
			
		||||
    post1 = post_factory(id=1)
 | 
			
		||||
    post2 = post_factory(id=2)
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user