server: refactor + add type hinting
- Added type hinting (for now, 3.5-compatible) - Split `db` namespace into `db` module and `model` namespace - Changed elastic search to be created lazily for each operation - Changed to class based approach in entity serialization to allow stronger typing - Removed `required` argument from `context.get_*` family of functions; now it's implied if `default` argument is omitted - Changed `unalias_dict` implementation to use less magic inputs
This commit is contained in:
		
							parent
							
								
									abf1fc2b2d
								
							
						
					
					
						commit
						ad842ee8a5
					
				| @ -8,7 +8,7 @@ import zlib | ||||
| import concurrent.futures | ||||
| import logging | ||||
| import coloredlogs | ||||
| import sqlalchemy | ||||
| import sqlalchemy as sa | ||||
| from szurubooru import config, db | ||||
| from szurubooru.func import files, images, posts, comments | ||||
| 
 | ||||
| @ -42,8 +42,8 @@ def get_v1_session(args): | ||||
|             port=args.port, | ||||
|             name=args.name) | ||||
|     logger.info('Connecting to %r...', dsn) | ||||
|     engine = sqlalchemy.create_engine(dsn) | ||||
|     session_maker = sqlalchemy.orm.sessionmaker(bind=engine) | ||||
|     engine = sa.create_engine(dsn) | ||||
|     session_maker = sa.orm.sessionmaker(bind=engine) | ||||
|     return session_maker() | ||||
| 
 | ||||
| def parse_args(): | ||||
|  | ||||
							
								
								
									
										14
									
								
								server/mypy.ini
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										14
									
								
								server/mypy.ini
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,14 @@ | ||||
| [mypy] | ||||
| ignore_missing_imports = True | ||||
| follow_imports = skip | ||||
| disallow_untyped_calls = True | ||||
| disallow_untyped_defs = True | ||||
| check_untyped_defs = True | ||||
| disallow_subclassing_any = False | ||||
| warn_redundant_casts = True | ||||
| warn_unused_ignores = True | ||||
| strict_optional = True | ||||
| strict_boolean = False | ||||
| 
 | ||||
| [mypy-szurubooru.tests.*] | ||||
| ignore_errors=True | ||||
| @ -1,31 +1,44 @@ | ||||
| import datetime | ||||
| from szurubooru import search | ||||
| from szurubooru.rest import routes | ||||
| from szurubooru.func import auth, comments, posts, scores, util, versions | ||||
| from typing import Dict | ||||
| from datetime import datetime | ||||
| from szurubooru import search, rest, model | ||||
| from szurubooru.func import ( | ||||
|     auth, comments, posts, scores, versions, serialization) | ||||
| 
 | ||||
| 
 | ||||
| _search_executor = search.Executor(search.configs.CommentSearchConfig()) | ||||
| 
 | ||||
| 
 | ||||
| def _serialize(ctx, comment, **kwargs): | ||||
| def _get_comment(params: Dict[str, str]) -> model.Comment: | ||||
|     try: | ||||
|         comment_id = int(params['comment_id']) | ||||
|     except TypeError: | ||||
|         raise comments.InvalidCommentIdError( | ||||
|             'Invalid comment ID: %r.' % params['comment_id']) | ||||
|     return comments.get_comment_by_id(comment_id) | ||||
| 
 | ||||
| 
 | ||||
| def _serialize( | ||||
|         ctx: rest.Context, comment: model.Comment) -> rest.Response: | ||||
|     return comments.serialize_comment( | ||||
|         comment, | ||||
|         ctx.user, | ||||
|         options=util.get_serialization_options(ctx), **kwargs) | ||||
|         options=serialization.get_serialization_options(ctx)) | ||||
| 
 | ||||
| 
 | ||||
| @routes.get('/comments/?') | ||||
| def get_comments(ctx, _params=None): | ||||
| @rest.routes.get('/comments/?') | ||||
| def get_comments( | ||||
|         ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: | ||||
|     auth.verify_privilege(ctx.user, 'comments:list') | ||||
|     return _search_executor.execute_and_serialize( | ||||
|         ctx, lambda comment: _serialize(ctx, comment)) | ||||
| 
 | ||||
| 
 | ||||
| @routes.post('/comments/?') | ||||
| def create_comment(ctx, _params=None): | ||||
| @rest.routes.post('/comments/?') | ||||
| def create_comment( | ||||
|         ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: | ||||
|     auth.verify_privilege(ctx.user, 'comments:create') | ||||
|     text = ctx.get_param_as_string('text', required=True) | ||||
|     post_id = ctx.get_param_as_int('postId', required=True) | ||||
|     text = ctx.get_param_as_string('text') | ||||
|     post_id = ctx.get_param_as_int('postId') | ||||
|     post = posts.get_post_by_id(post_id) | ||||
|     comment = comments.create_comment(ctx.user, post, text) | ||||
|     ctx.session.add(comment) | ||||
| @ -33,30 +46,30 @@ def create_comment(ctx, _params=None): | ||||
|     return _serialize(ctx, comment) | ||||
| 
 | ||||
| 
 | ||||
| @routes.get('/comment/(?P<comment_id>[^/]+)/?') | ||||
| def get_comment(ctx, params): | ||||
| @rest.routes.get('/comment/(?P<comment_id>[^/]+)/?') | ||||
| def get_comment(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: | ||||
|     auth.verify_privilege(ctx.user, 'comments:view') | ||||
|     comment = comments.get_comment_by_id(params['comment_id']) | ||||
|     comment = _get_comment(params) | ||||
|     return _serialize(ctx, comment) | ||||
| 
 | ||||
| 
 | ||||
| @routes.put('/comment/(?P<comment_id>[^/]+)/?') | ||||
| def update_comment(ctx, params): | ||||
|     comment = comments.get_comment_by_id(params['comment_id']) | ||||
| @rest.routes.put('/comment/(?P<comment_id>[^/]+)/?') | ||||
| def update_comment(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: | ||||
|     comment = _get_comment(params) | ||||
|     versions.verify_version(comment, ctx) | ||||
|     versions.bump_version(comment) | ||||
|     infix = 'own' if ctx.user.user_id == comment.user_id else 'any' | ||||
|     text = ctx.get_param_as_string('text', required=True) | ||||
|     text = ctx.get_param_as_string('text') | ||||
|     auth.verify_privilege(ctx.user, 'comments:edit:%s' % infix) | ||||
|     comments.update_comment_text(comment, text) | ||||
|     comment.last_edit_time = datetime.datetime.utcnow() | ||||
|     comment.last_edit_time = datetime.utcnow() | ||||
|     ctx.session.commit() | ||||
|     return _serialize(ctx, comment) | ||||
| 
 | ||||
| 
 | ||||
| @routes.delete('/comment/(?P<comment_id>[^/]+)/?') | ||||
| def delete_comment(ctx, params): | ||||
|     comment = comments.get_comment_by_id(params['comment_id']) | ||||
| @rest.routes.delete('/comment/(?P<comment_id>[^/]+)/?') | ||||
| def delete_comment(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: | ||||
|     comment = _get_comment(params) | ||||
|     versions.verify_version(comment, ctx) | ||||
|     infix = 'own' if ctx.user.user_id == comment.user_id else 'any' | ||||
|     auth.verify_privilege(ctx.user, 'comments:delete:%s' % infix) | ||||
| @ -65,20 +78,22 @@ def delete_comment(ctx, params): | ||||
|     return {} | ||||
| 
 | ||||
| 
 | ||||
| @routes.put('/comment/(?P<comment_id>[^/]+)/score/?') | ||||
| def set_comment_score(ctx, params): | ||||
| @rest.routes.put('/comment/(?P<comment_id>[^/]+)/score/?') | ||||
| def set_comment_score( | ||||
|         ctx: rest.Context, params: Dict[str, str]) -> rest.Response: | ||||
|     auth.verify_privilege(ctx.user, 'comments:score') | ||||
|     score = ctx.get_param_as_int('score', required=True) | ||||
|     comment = comments.get_comment_by_id(params['comment_id']) | ||||
|     score = ctx.get_param_as_int('score') | ||||
|     comment = _get_comment(params) | ||||
|     scores.set_score(comment, ctx.user, score) | ||||
|     ctx.session.commit() | ||||
|     return _serialize(ctx, comment) | ||||
| 
 | ||||
| 
 | ||||
| @routes.delete('/comment/(?P<comment_id>[^/]+)/score/?') | ||||
| def delete_comment_score(ctx, params): | ||||
| @rest.routes.delete('/comment/(?P<comment_id>[^/]+)/score/?') | ||||
| def delete_comment_score( | ||||
|         ctx: rest.Context, params: Dict[str, str]) -> rest.Response: | ||||
|     auth.verify_privilege(ctx.user, 'comments:score') | ||||
|     comment = comments.get_comment_by_id(params['comment_id']) | ||||
|     comment = _get_comment(params) | ||||
|     scores.delete_score(comment, ctx.user) | ||||
|     ctx.session.commit() | ||||
|     return _serialize(ctx, comment) | ||||
|  | ||||
| @ -1,19 +1,20 @@ | ||||
| import datetime | ||||
| import os | ||||
| from szurubooru import config | ||||
| from szurubooru.rest import routes | ||||
| from typing import Optional, Dict | ||||
| from datetime import datetime, timedelta | ||||
| from szurubooru import config, rest | ||||
| from szurubooru.func import posts, users, util | ||||
| 
 | ||||
| 
 | ||||
| _cache_time = None | ||||
| _cache_result = None | ||||
| _cache_time = None  # type: Optional[datetime] | ||||
| _cache_result = None  # type: Optional[int] | ||||
| 
 | ||||
| 
 | ||||
| def _get_disk_usage(): | ||||
| def _get_disk_usage() -> int: | ||||
|     global _cache_time, _cache_result  # pylint: disable=global-statement | ||||
|     threshold = datetime.timedelta(hours=48) | ||||
|     now = datetime.datetime.utcnow() | ||||
|     threshold = timedelta(hours=48) | ||||
|     now = datetime.utcnow() | ||||
|     if _cache_time and _cache_time > now - threshold: | ||||
|         assert _cache_result | ||||
|         return _cache_result | ||||
|     total_size = 0 | ||||
|     for dir_path, _, file_names in os.walk(config.config['data_dir']): | ||||
| @ -25,8 +26,9 @@ def _get_disk_usage(): | ||||
|     return total_size | ||||
| 
 | ||||
| 
 | ||||
| @routes.get('/info/?') | ||||
| def get_info(ctx, _params=None): | ||||
| @rest.routes.get('/info/?') | ||||
| def get_info( | ||||
|         ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: | ||||
|     post_feature = posts.try_get_current_post_feature() | ||||
|     return { | ||||
|         'postCount': posts.get_post_count(), | ||||
| @ -38,7 +40,7 @@ def get_info(ctx, _params=None): | ||||
|         'featuringUser': | ||||
|             users.serialize_user(post_feature.user, ctx.user) | ||||
|             if post_feature else None, | ||||
|         'serverTime': datetime.datetime.utcnow(), | ||||
|         'serverTime': datetime.utcnow(), | ||||
|         'config': { | ||||
|             'userNameRegex': config.config['user_name_regex'], | ||||
|             'passwordRegex': config.config['password_regex'], | ||||
|  | ||||
| @ -1,5 +1,5 @@ | ||||
| from szurubooru import config, errors | ||||
| from szurubooru.rest import routes | ||||
| from typing import Dict | ||||
| from szurubooru import config, errors, rest | ||||
| from szurubooru.func import auth, mailer, users, versions | ||||
| 
 | ||||
| 
 | ||||
| @ -10,9 +10,9 @@ MAIL_BODY = \ | ||||
|     'Otherwise, please ignore this email.' | ||||
| 
 | ||||
| 
 | ||||
| @routes.get('/password-reset/(?P<user_name>[^/]+)/?') | ||||
| def start_password_reset(_ctx, params): | ||||
|     ''' Send a mail with secure token to the correlated user. ''' | ||||
| @rest.routes.get('/password-reset/(?P<user_name>[^/]+)/?') | ||||
| def start_password_reset( | ||||
|         _ctx: rest.Context, params: Dict[str, str]) -> rest.Response: | ||||
|     user_name = params['user_name'] | ||||
|     user = users.get_user_by_name_or_email(user_name) | ||||
|     if not user.email: | ||||
| @ -30,13 +30,13 @@ def start_password_reset(_ctx, params): | ||||
|     return {} | ||||
| 
 | ||||
| 
 | ||||
| @routes.post('/password-reset/(?P<user_name>[^/]+)/?') | ||||
| def finish_password_reset(ctx, params): | ||||
|     ''' Verify token from mail, generate a new password and return it. ''' | ||||
| @rest.routes.post('/password-reset/(?P<user_name>[^/]+)/?') | ||||
| def finish_password_reset( | ||||
|         ctx: rest.Context, params: Dict[str, str]) -> rest.Response: | ||||
|     user_name = params['user_name'] | ||||
|     user = users.get_user_by_name_or_email(user_name) | ||||
|     good_token = auth.generate_authentication_token(user) | ||||
|     token = ctx.get_param_as_string('token', required=True) | ||||
|     token = ctx.get_param_as_string('token') | ||||
|     if token != good_token: | ||||
|         raise errors.ValidationError('Invalid password reset token.') | ||||
|     new_password = users.reset_user_password(user) | ||||
|  | ||||
| @ -1,44 +1,60 @@ | ||||
| import datetime | ||||
| from szurubooru import search, db, errors | ||||
| from szurubooru.rest import routes | ||||
| from typing import Optional, Dict | ||||
| from datetime import datetime | ||||
| from szurubooru import db, model, errors, rest, search | ||||
| from szurubooru.func import ( | ||||
|     auth, tags, posts, snapshots, favorites, scores, util, versions) | ||||
|     auth, tags, posts, snapshots, favorites, scores, serialization, versions) | ||||
| 
 | ||||
| 
 | ||||
| _search_executor = search.Executor(search.configs.PostSearchConfig()) | ||||
| _search_executor_config = search.configs.PostSearchConfig() | ||||
| _search_executor = search.Executor(_search_executor_config) | ||||
| 
 | ||||
| 
 | ||||
| def _serialize_post(ctx, post): | ||||
| def _get_post_id(params: Dict[str, str]) -> int: | ||||
|     try: | ||||
|         return int(params['post_id']) | ||||
|     except TypeError: | ||||
|         raise posts.InvalidPostIdError( | ||||
|             'Invalid post ID: %r.' % params['post_id']) | ||||
| 
 | ||||
| 
 | ||||
| def _get_post(params: Dict[str, str]) -> model.Post: | ||||
|     return posts.get_post_by_id(_get_post_id(params)) | ||||
| 
 | ||||
| 
 | ||||
| def _serialize_post( | ||||
|         ctx: rest.Context, post: Optional[model.Post]) -> rest.Response: | ||||
|     return posts.serialize_post( | ||||
|         post, | ||||
|         ctx.user, | ||||
|         options=util.get_serialization_options(ctx)) | ||||
|         options=serialization.get_serialization_options(ctx)) | ||||
| 
 | ||||
| 
 | ||||
| @routes.get('/posts/?') | ||||
| def get_posts(ctx, _params=None): | ||||
| @rest.routes.get('/posts/?') | ||||
| def get_posts( | ||||
|         ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: | ||||
|     auth.verify_privilege(ctx.user, 'posts:list') | ||||
|     _search_executor.config.user = ctx.user | ||||
|     _search_executor_config.user = ctx.user | ||||
|     return _search_executor.execute_and_serialize( | ||||
|         ctx, lambda post: _serialize_post(ctx, post)) | ||||
| 
 | ||||
| 
 | ||||
| @routes.post('/posts/?') | ||||
| def create_post(ctx, _params=None): | ||||
| @rest.routes.post('/posts/?') | ||||
| def create_post( | ||||
|         ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: | ||||
|     anonymous = ctx.get_param_as_bool('anonymous', default=False) | ||||
|     if anonymous: | ||||
|         auth.verify_privilege(ctx.user, 'posts:create:anonymous') | ||||
|     else: | ||||
|         auth.verify_privilege(ctx.user, 'posts:create:identified') | ||||
|     content = ctx.get_file('content', required=True) | ||||
|     tag_names = ctx.get_param_as_list('tags', required=False, default=[]) | ||||
|     safety = ctx.get_param_as_string('safety', required=True) | ||||
|     source = ctx.get_param_as_string('source', required=False, default=None) | ||||
|     content = ctx.get_file('content') | ||||
|     tag_names = ctx.get_param_as_list('tags', default=[]) | ||||
|     safety = ctx.get_param_as_string('safety') | ||||
|     source = ctx.get_param_as_string('source', default='') | ||||
|     if ctx.has_param('contentUrl') and not source: | ||||
|         source = ctx.get_param_as_string('contentUrl') | ||||
|     relations = ctx.get_param_as_list('relations', required=False) or [] | ||||
|     notes = ctx.get_param_as_list('notes', required=False) or [] | ||||
|     flags = ctx.get_param_as_list('flags', required=False) or [] | ||||
|         source = ctx.get_param_as_string('contentUrl', default='') | ||||
|     relations = ctx.get_param_as_list('relations', default=[]) | ||||
|     notes = ctx.get_param_as_list('notes', default=[]) | ||||
|     flags = ctx.get_param_as_list('flags', default=[]) | ||||
| 
 | ||||
|     post, new_tags = posts.create_post( | ||||
|         content, tag_names, None if anonymous else ctx.user) | ||||
| @ -61,16 +77,16 @@ def create_post(ctx, _params=None): | ||||
|     return _serialize_post(ctx, post) | ||||
| 
 | ||||
| 
 | ||||
| @routes.get('/post/(?P<post_id>[^/]+)/?') | ||||
| def get_post(ctx, params): | ||||
| @rest.routes.get('/post/(?P<post_id>[^/]+)/?') | ||||
| def get_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: | ||||
|     auth.verify_privilege(ctx.user, 'posts:view') | ||||
|     post = posts.get_post_by_id(params['post_id']) | ||||
|     post = _get_post(params) | ||||
|     return _serialize_post(ctx, post) | ||||
| 
 | ||||
| 
 | ||||
| @routes.put('/post/(?P<post_id>[^/]+)/?') | ||||
| def update_post(ctx, params): | ||||
|     post = posts.get_post_by_id(params['post_id']) | ||||
| @rest.routes.put('/post/(?P<post_id>[^/]+)/?') | ||||
| def update_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: | ||||
|     post = _get_post(params) | ||||
|     versions.verify_version(post, ctx) | ||||
|     versions.bump_version(post) | ||||
|     if ctx.has_file('content'): | ||||
| @ -104,7 +120,7 @@ def update_post(ctx, params): | ||||
|     if ctx.has_file('thumbnail'): | ||||
|         auth.verify_privilege(ctx.user, 'posts:edit:thumbnail') | ||||
|         posts.update_post_thumbnail(post, ctx.get_file('thumbnail')) | ||||
|     post.last_edit_time = datetime.datetime.utcnow() | ||||
|     post.last_edit_time = datetime.utcnow() | ||||
|     ctx.session.flush() | ||||
|     snapshots.modify(post, ctx.user) | ||||
|     ctx.session.commit() | ||||
| @ -112,10 +128,10 @@ def update_post(ctx, params): | ||||
|     return _serialize_post(ctx, post) | ||||
| 
 | ||||
| 
 | ||||
| @routes.delete('/post/(?P<post_id>[^/]+)/?') | ||||
| def delete_post(ctx, params): | ||||
| @rest.routes.delete('/post/(?P<post_id>[^/]+)/?') | ||||
| def delete_post(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: | ||||
|     auth.verify_privilege(ctx.user, 'posts:delete') | ||||
|     post = posts.get_post_by_id(params['post_id']) | ||||
|     post = _get_post(params) | ||||
|     versions.verify_version(post, ctx) | ||||
|     snapshots.delete(post, ctx.user) | ||||
|     posts.delete(post) | ||||
| @ -124,13 +140,14 @@ def delete_post(ctx, params): | ||||
|     return {} | ||||
| 
 | ||||
| 
 | ||||
| @routes.post('/post-merge/?') | ||||
| def merge_posts(ctx, _params=None): | ||||
|     source_post_id = ctx.get_param_as_string('remove', required=True) or '' | ||||
|     target_post_id = ctx.get_param_as_string('mergeTo', required=True) or '' | ||||
|     replace_content = ctx.get_param_as_bool('replaceContent') | ||||
| @rest.routes.post('/post-merge/?') | ||||
| def merge_posts( | ||||
|         ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: | ||||
|     source_post_id = ctx.get_param_as_int('remove') | ||||
|     target_post_id = ctx.get_param_as_int('mergeTo') | ||||
|     source_post = posts.get_post_by_id(source_post_id) | ||||
|     target_post = posts.get_post_by_id(target_post_id) | ||||
|     replace_content = ctx.get_param_as_bool('replaceContent') | ||||
|     versions.verify_version(source_post, ctx, 'removeVersion') | ||||
|     versions.verify_version(target_post, ctx, 'mergeToVersion') | ||||
|     versions.bump_version(target_post) | ||||
| @ -141,16 +158,18 @@ def merge_posts(ctx, _params=None): | ||||
|     return _serialize_post(ctx, target_post) | ||||
| 
 | ||||
| 
 | ||||
| @routes.get('/featured-post/?') | ||||
| def get_featured_post(ctx, _params=None): | ||||
| @rest.routes.get('/featured-post/?') | ||||
| def get_featured_post( | ||||
|         ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: | ||||
|     post = posts.try_get_featured_post() | ||||
|     return _serialize_post(ctx, post) | ||||
| 
 | ||||
| 
 | ||||
| @routes.post('/featured-post/?') | ||||
| def set_featured_post(ctx, _params=None): | ||||
| @rest.routes.post('/featured-post/?') | ||||
| def set_featured_post( | ||||
|         ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: | ||||
|     auth.verify_privilege(ctx.user, 'posts:feature') | ||||
|     post_id = ctx.get_param_as_int('id', required=True) | ||||
|     post_id = ctx.get_param_as_int('id') | ||||
|     post = posts.get_post_by_id(post_id) | ||||
|     featured_post = posts.try_get_featured_post() | ||||
|     if featured_post and featured_post.post_id == post.post_id: | ||||
| @ -162,55 +181,61 @@ def set_featured_post(ctx, _params=None): | ||||
|     return _serialize_post(ctx, post) | ||||
| 
 | ||||
| 
 | ||||
| @routes.put('/post/(?P<post_id>[^/]+)/score/?') | ||||
| def set_post_score(ctx, params): | ||||
| @rest.routes.put('/post/(?P<post_id>[^/]+)/score/?') | ||||
| def set_post_score(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: | ||||
|     auth.verify_privilege(ctx.user, 'posts:score') | ||||
|     post = posts.get_post_by_id(params['post_id']) | ||||
|     score = ctx.get_param_as_int('score', required=True) | ||||
|     post = _get_post(params) | ||||
|     score = ctx.get_param_as_int('score') | ||||
|     scores.set_score(post, ctx.user, score) | ||||
|     ctx.session.commit() | ||||
|     return _serialize_post(ctx, post) | ||||
| 
 | ||||
| 
 | ||||
| @routes.delete('/post/(?P<post_id>[^/]+)/score/?') | ||||
| def delete_post_score(ctx, params): | ||||
| @rest.routes.delete('/post/(?P<post_id>[^/]+)/score/?') | ||||
| def delete_post_score( | ||||
|         ctx: rest.Context, params: Dict[str, str]) -> rest.Response: | ||||
|     auth.verify_privilege(ctx.user, 'posts:score') | ||||
|     post = posts.get_post_by_id(params['post_id']) | ||||
|     post = _get_post(params) | ||||
|     scores.delete_score(post, ctx.user) | ||||
|     ctx.session.commit() | ||||
|     return _serialize_post(ctx, post) | ||||
| 
 | ||||
| 
 | ||||
| @routes.post('/post/(?P<post_id>[^/]+)/favorite/?') | ||||
| def add_post_to_favorites(ctx, params): | ||||
| @rest.routes.post('/post/(?P<post_id>[^/]+)/favorite/?') | ||||
| def add_post_to_favorites( | ||||
|         ctx: rest.Context, params: Dict[str, str]) -> rest.Response: | ||||
|     auth.verify_privilege(ctx.user, 'posts:favorite') | ||||
|     post = posts.get_post_by_id(params['post_id']) | ||||
|     post = _get_post(params) | ||||
|     favorites.set_favorite(post, ctx.user) | ||||
|     ctx.session.commit() | ||||
|     return _serialize_post(ctx, post) | ||||
| 
 | ||||
| 
 | ||||
| @routes.delete('/post/(?P<post_id>[^/]+)/favorite/?') | ||||
| def delete_post_from_favorites(ctx, params): | ||||
| @rest.routes.delete('/post/(?P<post_id>[^/]+)/favorite/?') | ||||
| def delete_post_from_favorites( | ||||
|         ctx: rest.Context, params: Dict[str, str]) -> rest.Response: | ||||
|     auth.verify_privilege(ctx.user, 'posts:favorite') | ||||
|     post = posts.get_post_by_id(params['post_id']) | ||||
|     post = _get_post(params) | ||||
|     favorites.unset_favorite(post, ctx.user) | ||||
|     ctx.session.commit() | ||||
|     return _serialize_post(ctx, post) | ||||
| 
 | ||||
| 
 | ||||
| @routes.get('/post/(?P<post_id>[^/]+)/around/?') | ||||
| def get_posts_around(ctx, params): | ||||
| @rest.routes.get('/post/(?P<post_id>[^/]+)/around/?') | ||||
| def get_posts_around( | ||||
|         ctx: rest.Context, params: Dict[str, str]) -> rest.Response: | ||||
|     auth.verify_privilege(ctx.user, 'posts:list') | ||||
|     _search_executor.config.user = ctx.user | ||||
|     _search_executor_config.user = ctx.user | ||||
|     post_id = _get_post_id(params) | ||||
|     return _search_executor.get_around_and_serialize( | ||||
|         ctx, params['post_id'], lambda post: _serialize_post(ctx, post)) | ||||
|         ctx, post_id, lambda post: _serialize_post(ctx, post)) | ||||
| 
 | ||||
| 
 | ||||
| @routes.post('/posts/reverse-search/?') | ||||
| def get_posts_by_image(ctx, _params=None): | ||||
| @rest.routes.post('/posts/reverse-search/?') | ||||
| def get_posts_by_image( | ||||
|         ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: | ||||
|     auth.verify_privilege(ctx.user, 'posts:reverse_search') | ||||
|     content = ctx.get_file('content', required=True) | ||||
|     content = ctx.get_file('content') | ||||
| 
 | ||||
|     try: | ||||
|         lookalikes = posts.search_by_image(content) | ||||
|  | ||||
| @ -1,13 +1,14 @@ | ||||
| from szurubooru import search | ||||
| from szurubooru.rest import routes | ||||
| from typing import Dict | ||||
| from szurubooru import search, rest | ||||
| from szurubooru.func import auth, snapshots | ||||
| 
 | ||||
| 
 | ||||
| _search_executor = search.Executor(search.configs.SnapshotSearchConfig()) | ||||
| 
 | ||||
| 
 | ||||
| @routes.get('/snapshots/?') | ||||
| def get_snapshots(ctx, _params=None): | ||||
| @rest.routes.get('/snapshots/?') | ||||
| def get_snapshots( | ||||
|         ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: | ||||
|     auth.verify_privilege(ctx.user, 'snapshots:list') | ||||
|     return _search_executor.execute_and_serialize( | ||||
|         ctx, lambda snapshot: snapshots.serialize_snapshot(snapshot, ctx.user)) | ||||
|  | ||||
| @ -1,18 +1,22 @@ | ||||
| import datetime | ||||
| from szurubooru import db, search | ||||
| from szurubooru.rest import routes | ||||
| from szurubooru.func import auth, tags, snapshots, util, versions | ||||
| from typing import Optional, List, Dict | ||||
| from datetime import datetime | ||||
| from szurubooru import db, model, search, rest | ||||
| from szurubooru.func import auth, tags, snapshots, serialization, versions | ||||
| 
 | ||||
| 
 | ||||
| _search_executor = search.Executor(search.configs.TagSearchConfig()) | ||||
| 
 | ||||
| 
 | ||||
| def _serialize(ctx, tag): | ||||
| def _serialize(ctx: rest.Context, tag: model.Tag) -> rest.Response: | ||||
|     return tags.serialize_tag( | ||||
|         tag, options=util.get_serialization_options(ctx)) | ||||
|         tag, options=serialization.get_serialization_options(ctx)) | ||||
| 
 | ||||
| 
 | ||||
| def _create_if_needed(tag_names, user): | ||||
| def _get_tag(params: Dict[str, str]) -> model.Tag: | ||||
|     return tags.get_tag_by_name(params['tag_name']) | ||||
| 
 | ||||
| 
 | ||||
| def _create_if_needed(tag_names: List[str], user: model.User) -> None: | ||||
|     if not tag_names: | ||||
|         return | ||||
|     _existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names) | ||||
| @ -23,25 +27,22 @@ def _create_if_needed(tag_names, user): | ||||
|         snapshots.create(tag, user) | ||||
| 
 | ||||
| 
 | ||||
| @routes.get('/tags/?') | ||||
| def get_tags(ctx, _params=None): | ||||
| @rest.routes.get('/tags/?') | ||||
| def get_tags(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: | ||||
|     auth.verify_privilege(ctx.user, 'tags:list') | ||||
|     return _search_executor.execute_and_serialize( | ||||
|         ctx, lambda tag: _serialize(ctx, tag)) | ||||
| 
 | ||||
| 
 | ||||
| @routes.post('/tags/?') | ||||
| def create_tag(ctx, _params=None): | ||||
| @rest.routes.post('/tags/?') | ||||
| def create_tag(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: | ||||
|     auth.verify_privilege(ctx.user, 'tags:create') | ||||
| 
 | ||||
|     names = ctx.get_param_as_list('names', required=True) | ||||
|     category = ctx.get_param_as_string('category', required=True) | ||||
|     description = ctx.get_param_as_string( | ||||
|         'description', required=False, default=None) | ||||
|     suggestions = ctx.get_param_as_list( | ||||
|         'suggestions', required=False, default=[]) | ||||
|     implications = ctx.get_param_as_list( | ||||
|         'implications', required=False, default=[]) | ||||
|     names = ctx.get_param_as_list('names') | ||||
|     category = ctx.get_param_as_string('category') | ||||
|     description = ctx.get_param_as_string('description', default='') | ||||
|     suggestions = ctx.get_param_as_list('suggestions', default=[]) | ||||
|     implications = ctx.get_param_as_list('implications', default=[]) | ||||
| 
 | ||||
|     _create_if_needed(suggestions, ctx.user) | ||||
|     _create_if_needed(implications, ctx.user) | ||||
| @ -56,16 +57,16 @@ def create_tag(ctx, _params=None): | ||||
|     return _serialize(ctx, tag) | ||||
| 
 | ||||
| 
 | ||||
| @routes.get('/tag/(?P<tag_name>.+)') | ||||
| def get_tag(ctx, params): | ||||
| @rest.routes.get('/tag/(?P<tag_name>.+)') | ||||
| def get_tag(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: | ||||
|     auth.verify_privilege(ctx.user, 'tags:view') | ||||
|     tag = tags.get_tag_by_name(params['tag_name']) | ||||
|     tag = _get_tag(params) | ||||
|     return _serialize(ctx, tag) | ||||
| 
 | ||||
| 
 | ||||
| @routes.put('/tag/(?P<tag_name>.+)') | ||||
| def update_tag(ctx, params): | ||||
|     tag = tags.get_tag_by_name(params['tag_name']) | ||||
| @rest.routes.put('/tag/(?P<tag_name>.+)') | ||||
| def update_tag(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: | ||||
|     tag = _get_tag(params) | ||||
|     versions.verify_version(tag, ctx) | ||||
|     versions.bump_version(tag) | ||||
|     if ctx.has_param('names'): | ||||
| @ -78,7 +79,7 @@ def update_tag(ctx, params): | ||||
|     if ctx.has_param('description'): | ||||
|         auth.verify_privilege(ctx.user, 'tags:edit:description') | ||||
|         tags.update_tag_description( | ||||
|             tag, ctx.get_param_as_string('description', default=None)) | ||||
|             tag, ctx.get_param_as_string('description')) | ||||
|     if ctx.has_param('suggestions'): | ||||
|         auth.verify_privilege(ctx.user, 'tags:edit:suggestions') | ||||
|         suggestions = ctx.get_param_as_list('suggestions') | ||||
| @ -89,7 +90,7 @@ def update_tag(ctx, params): | ||||
|         implications = ctx.get_param_as_list('implications') | ||||
|         _create_if_needed(implications, ctx.user) | ||||
|         tags.update_tag_implications(tag, implications) | ||||
|     tag.last_edit_time = datetime.datetime.utcnow() | ||||
|     tag.last_edit_time = datetime.utcnow() | ||||
|     ctx.session.flush() | ||||
|     snapshots.modify(tag, ctx.user) | ||||
|     ctx.session.commit() | ||||
| @ -97,9 +98,9 @@ def update_tag(ctx, params): | ||||
|     return _serialize(ctx, tag) | ||||
| 
 | ||||
| 
 | ||||
| @routes.delete('/tag/(?P<tag_name>.+)') | ||||
| def delete_tag(ctx, params): | ||||
|     tag = tags.get_tag_by_name(params['tag_name']) | ||||
| @rest.routes.delete('/tag/(?P<tag_name>.+)') | ||||
| def delete_tag(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: | ||||
|     tag = _get_tag(params) | ||||
|     versions.verify_version(tag, ctx) | ||||
|     auth.verify_privilege(ctx.user, 'tags:delete') | ||||
|     snapshots.delete(tag, ctx.user) | ||||
| @ -109,10 +110,11 @@ def delete_tag(ctx, params): | ||||
|     return {} | ||||
| 
 | ||||
| 
 | ||||
| @routes.post('/tag-merge/?') | ||||
| def merge_tags(ctx, _params=None): | ||||
|     source_tag_name = ctx.get_param_as_string('remove', required=True) or '' | ||||
|     target_tag_name = ctx.get_param_as_string('mergeTo', required=True) or '' | ||||
| @rest.routes.post('/tag-merge/?') | ||||
| def merge_tags( | ||||
|         ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: | ||||
|     source_tag_name = ctx.get_param_as_string('remove') | ||||
|     target_tag_name = ctx.get_param_as_string('mergeTo') | ||||
|     source_tag = tags.get_tag_by_name(source_tag_name) | ||||
|     target_tag = tags.get_tag_by_name(target_tag_name) | ||||
|     versions.verify_version(source_tag, ctx, 'removeVersion') | ||||
| @ -126,10 +128,11 @@ def merge_tags(ctx, _params=None): | ||||
|     return _serialize(ctx, target_tag) | ||||
| 
 | ||||
| 
 | ||||
| @routes.get('/tag-siblings/(?P<tag_name>.+)') | ||||
| def get_tag_siblings(ctx, params): | ||||
| @rest.routes.get('/tag-siblings/(?P<tag_name>.+)') | ||||
| def get_tag_siblings( | ||||
|         ctx: rest.Context, params: Dict[str, str]) -> rest.Response: | ||||
|     auth.verify_privilege(ctx.user, 'tags:view') | ||||
|     tag = tags.get_tag_by_name(params['tag_name']) | ||||
|     tag = _get_tag(params) | ||||
|     result = tags.get_tag_siblings(tag) | ||||
|     serialized_siblings = [] | ||||
|     for sibling, occurrences in result: | ||||
|  | ||||
| @ -1,15 +1,18 @@ | ||||
| from szurubooru.rest import routes | ||||
| from typing import Dict | ||||
| from szurubooru import model, rest | ||||
| from szurubooru.func import ( | ||||
|     auth, tags, tag_categories, snapshots, util, versions) | ||||
|     auth, tags, tag_categories, snapshots, serialization, versions) | ||||
| 
 | ||||
| 
 | ||||
| def _serialize(ctx, category): | ||||
| def _serialize( | ||||
|         ctx: rest.Context, category: model.TagCategory) -> rest.Response: | ||||
|     return tag_categories.serialize_category( | ||||
|         category, options=util.get_serialization_options(ctx)) | ||||
|         category, options=serialization.get_serialization_options(ctx)) | ||||
| 
 | ||||
| 
 | ||||
| @routes.get('/tag-categories/?') | ||||
| def get_tag_categories(ctx, _params=None): | ||||
| @rest.routes.get('/tag-categories/?') | ||||
| def get_tag_categories( | ||||
|         ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: | ||||
|     auth.verify_privilege(ctx.user, 'tag_categories:list') | ||||
|     categories = tag_categories.get_all_categories() | ||||
|     return { | ||||
| @ -17,11 +20,12 @@ def get_tag_categories(ctx, _params=None): | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| @routes.post('/tag-categories/?') | ||||
| def create_tag_category(ctx, _params=None): | ||||
| @rest.routes.post('/tag-categories/?') | ||||
| def create_tag_category( | ||||
|         ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: | ||||
|     auth.verify_privilege(ctx.user, 'tag_categories:create') | ||||
|     name = ctx.get_param_as_string('name', required=True) | ||||
|     color = ctx.get_param_as_string('color', required=True) | ||||
|     name = ctx.get_param_as_string('name') | ||||
|     color = ctx.get_param_as_string('color') | ||||
|     category = tag_categories.create_category(name, color) | ||||
|     ctx.session.add(category) | ||||
|     ctx.session.flush() | ||||
| @ -31,15 +35,17 @@ def create_tag_category(ctx, _params=None): | ||||
|     return _serialize(ctx, category) | ||||
| 
 | ||||
| 
 | ||||
| @routes.get('/tag-category/(?P<category_name>[^/]+)/?') | ||||
| def get_tag_category(ctx, params): | ||||
| @rest.routes.get('/tag-category/(?P<category_name>[^/]+)/?') | ||||
| def get_tag_category( | ||||
|         ctx: rest.Context, params: Dict[str, str]) -> rest.Response: | ||||
|     auth.verify_privilege(ctx.user, 'tag_categories:view') | ||||
|     category = tag_categories.get_category_by_name(params['category_name']) | ||||
|     return _serialize(ctx, category) | ||||
| 
 | ||||
| 
 | ||||
| @routes.put('/tag-category/(?P<category_name>[^/]+)/?') | ||||
| def update_tag_category(ctx, params): | ||||
| @rest.routes.put('/tag-category/(?P<category_name>[^/]+)/?') | ||||
| def update_tag_category( | ||||
|         ctx: rest.Context, params: Dict[str, str]) -> rest.Response: | ||||
|     category = tag_categories.get_category_by_name( | ||||
|         params['category_name'], lock=True) | ||||
|     versions.verify_version(category, ctx) | ||||
| @ -59,8 +65,9 @@ def update_tag_category(ctx, params): | ||||
|     return _serialize(ctx, category) | ||||
| 
 | ||||
| 
 | ||||
| @routes.delete('/tag-category/(?P<category_name>[^/]+)/?') | ||||
| def delete_tag_category(ctx, params): | ||||
| @rest.routes.delete('/tag-category/(?P<category_name>[^/]+)/?') | ||||
| def delete_tag_category( | ||||
|         ctx: rest.Context, params: Dict[str, str]) -> rest.Response: | ||||
|     category = tag_categories.get_category_by_name( | ||||
|         params['category_name'], lock=True) | ||||
|     versions.verify_version(category, ctx) | ||||
| @ -72,8 +79,9 @@ def delete_tag_category(ctx, params): | ||||
|     return {} | ||||
| 
 | ||||
| 
 | ||||
| @routes.put('/tag-category/(?P<category_name>[^/]+)/default/?') | ||||
| def set_tag_category_as_default(ctx, params): | ||||
| @rest.routes.put('/tag-category/(?P<category_name>[^/]+)/default/?') | ||||
| def set_tag_category_as_default( | ||||
|         ctx: rest.Context, params: Dict[str, str]) -> rest.Response: | ||||
|     auth.verify_privilege(ctx.user, 'tag_categories:set_default') | ||||
|     category = tag_categories.get_category_by_name( | ||||
|         params['category_name'], lock=True) | ||||
|  | ||||
| @ -1,10 +1,12 @@ | ||||
| from szurubooru.rest import routes | ||||
| from typing import Dict | ||||
| from szurubooru import rest | ||||
| from szurubooru.func import auth, file_uploads | ||||
| 
 | ||||
| 
 | ||||
| @routes.post('/uploads/?') | ||||
| def create_temporary_file(ctx, _params=None): | ||||
| @rest.routes.post('/uploads/?') | ||||
| def create_temporary_file( | ||||
|         ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: | ||||
|     auth.verify_privilege(ctx.user, 'uploads:create') | ||||
|     content = ctx.get_file('content', required=True, allow_tokens=False) | ||||
|     content = ctx.get_file('content', allow_tokens=False) | ||||
|     token = file_uploads.save(content) | ||||
|     return {'token': token} | ||||
|  | ||||
| @ -1,56 +1,57 @@ | ||||
| from szurubooru import search | ||||
| from szurubooru.rest import routes | ||||
| from szurubooru.func import auth, users, util, versions | ||||
| from typing import Any, Dict | ||||
| from szurubooru import model, search, rest | ||||
| from szurubooru.func import auth, users, serialization, versions | ||||
| 
 | ||||
| 
 | ||||
| _search_executor = search.Executor(search.configs.UserSearchConfig()) | ||||
| 
 | ||||
| 
 | ||||
| def _serialize(ctx, user, **kwargs): | ||||
| def _serialize( | ||||
|         ctx: rest.Context, user: model.User, **kwargs: Any) -> rest.Response: | ||||
|     return users.serialize_user( | ||||
|         user, | ||||
|         ctx.user, | ||||
|         options=util.get_serialization_options(ctx), | ||||
|         options=serialization.get_serialization_options(ctx), | ||||
|         **kwargs) | ||||
| 
 | ||||
| 
 | ||||
| @routes.get('/users/?') | ||||
| def get_users(ctx, _params=None): | ||||
| @rest.routes.get('/users/?') | ||||
| def get_users(ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: | ||||
|     auth.verify_privilege(ctx.user, 'users:list') | ||||
|     return _search_executor.execute_and_serialize( | ||||
|         ctx, lambda user: _serialize(ctx, user)) | ||||
| 
 | ||||
| 
 | ||||
| @routes.post('/users/?') | ||||
| def create_user(ctx, _params=None): | ||||
| @rest.routes.post('/users/?') | ||||
| def create_user( | ||||
|         ctx: rest.Context, _params: Dict[str, str]={}) -> rest.Response: | ||||
|     auth.verify_privilege(ctx.user, 'users:create') | ||||
|     name = ctx.get_param_as_string('name', required=True) | ||||
|     password = ctx.get_param_as_string('password', required=True) | ||||
|     email = ctx.get_param_as_string('email', required=False, default='') | ||||
|     name = ctx.get_param_as_string('name') | ||||
|     password = ctx.get_param_as_string('password') | ||||
|     email = ctx.get_param_as_string('email', default='') | ||||
|     user = users.create_user(name, password, email) | ||||
|     if ctx.has_param('rank'): | ||||
|         users.update_user_rank( | ||||
|             user, ctx.get_param_as_string('rank'), ctx.user) | ||||
|         users.update_user_rank(user, ctx.get_param_as_string('rank'), ctx.user) | ||||
|     if ctx.has_param('avatarStyle'): | ||||
|         users.update_user_avatar( | ||||
|             user, | ||||
|             ctx.get_param_as_string('avatarStyle'), | ||||
|             ctx.get_file('avatar')) | ||||
|             ctx.get_file('avatar', default=b'')) | ||||
|     ctx.session.add(user) | ||||
|     ctx.session.commit() | ||||
|     return _serialize(ctx, user, force_show_email=True) | ||||
| 
 | ||||
| 
 | ||||
| @routes.get('/user/(?P<user_name>[^/]+)/?') | ||||
| def get_user(ctx, params): | ||||
| @rest.routes.get('/user/(?P<user_name>[^/]+)/?') | ||||
| def get_user(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: | ||||
|     user = users.get_user_by_name(params['user_name']) | ||||
|     if ctx.user.user_id != user.user_id: | ||||
|         auth.verify_privilege(ctx.user, 'users:view') | ||||
|     return _serialize(ctx, user) | ||||
| 
 | ||||
| 
 | ||||
| @routes.put('/user/(?P<user_name>[^/]+)/?') | ||||
| def update_user(ctx, params): | ||||
| @rest.routes.put('/user/(?P<user_name>[^/]+)/?') | ||||
| def update_user(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: | ||||
|     user = users.get_user_by_name(params['user_name']) | ||||
|     versions.verify_version(user, ctx) | ||||
|     versions.bump_version(user) | ||||
| @ -74,13 +75,13 @@ def update_user(ctx, params): | ||||
|         users.update_user_avatar( | ||||
|             user, | ||||
|             ctx.get_param_as_string('avatarStyle'), | ||||
|             ctx.get_file('avatar')) | ||||
|             ctx.get_file('avatar', default=b'')) | ||||
|     ctx.session.commit() | ||||
|     return _serialize(ctx, user) | ||||
| 
 | ||||
| 
 | ||||
| @routes.delete('/user/(?P<user_name>[^/]+)/?') | ||||
| def delete_user(ctx, params): | ||||
| @rest.routes.delete('/user/(?P<user_name>[^/]+)/?') | ||||
| def delete_user(ctx: rest.Context, params: Dict[str, str]) -> rest.Response: | ||||
|     user = users.get_user_by_name(params['user_name']) | ||||
|     versions.verify_version(user, ctx) | ||||
|     infix = 'self' if ctx.user.user_id == user.user_id else 'any' | ||||
|  | ||||
| @ -1,8 +1,9 @@ | ||||
| from typing import Dict | ||||
| import os | ||||
| import yaml | ||||
| 
 | ||||
| 
 | ||||
| def merge(left, right): | ||||
| def merge(left: Dict, right: Dict) -> Dict: | ||||
|     for key in right: | ||||
|         if key in left: | ||||
|             if isinstance(left[key], dict) and isinstance(right[key], dict): | ||||
| @ -14,7 +15,7 @@ def merge(left, right): | ||||
|     return left | ||||
| 
 | ||||
| 
 | ||||
| def read_config(): | ||||
| def read_config() -> Dict: | ||||
|     with open('../config.yaml.dist') as handle: | ||||
|         ret = yaml.load(handle.read()) | ||||
|         if os.path.exists('../config.yaml'): | ||||
|  | ||||
							
								
								
									
										36
									
								
								server/szurubooru/db.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										36
									
								
								server/szurubooru/db.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,36 @@ | ||||
| from typing import Any | ||||
| import threading | ||||
| import sqlalchemy as sa | ||||
| import sqlalchemy.orm | ||||
| from szurubooru import config | ||||
| 
 | ||||
| # pylint: disable=invalid-name | ||||
| _data = threading.local() | ||||
| _engine = sa.create_engine(config.config['database'])  # type: Any | ||||
| sessionmaker = sa.orm.sessionmaker(bind=_engine, autoflush=False)  # type: Any | ||||
| session = sa.orm.scoped_session(sessionmaker)  # type: Any | ||||
| 
 | ||||
| 
 | ||||
| def get_session() -> Any: | ||||
|     global session | ||||
|     return session | ||||
| 
 | ||||
| 
 | ||||
| def set_sesssion(new_session: Any) -> None: | ||||
|     global session | ||||
|     session = new_session | ||||
| 
 | ||||
| 
 | ||||
| def reset_query_count() -> None: | ||||
|     _data.query_count = 0 | ||||
| 
 | ||||
| 
 | ||||
| def get_query_count() -> int: | ||||
|     return _data.query_count | ||||
| 
 | ||||
| 
 | ||||
| def _bump_query_count() -> None: | ||||
|     _data.query_count = getattr(_data, 'query_count', 0) + 1 | ||||
| 
 | ||||
| 
 | ||||
| sa.event.listen(_engine, 'after_execute', lambda *args: _bump_query_count()) | ||||
| @ -1,17 +0,0 @@ | ||||
| from szurubooru.db.base import Base | ||||
| from szurubooru.db.user import User | ||||
| from szurubooru.db.tag_category import TagCategory | ||||
| from szurubooru.db.tag import (Tag, TagName, TagSuggestion, TagImplication) | ||||
| from szurubooru.db.post import ( | ||||
|     Post, | ||||
|     PostTag, | ||||
|     PostRelation, | ||||
|     PostFavorite, | ||||
|     PostScore, | ||||
|     PostNote, | ||||
|     PostFeature) | ||||
| from szurubooru.db.comment import (Comment, CommentScore) | ||||
| from szurubooru.db.snapshot import Snapshot | ||||
| from szurubooru.db.session import ( | ||||
|     session, sessionmaker, reset_query_count, get_query_count) | ||||
| import szurubooru.db.util | ||||
| @ -1,27 +0,0 @@ | ||||
| import threading | ||||
| import sqlalchemy | ||||
| from szurubooru import config | ||||
| 
 | ||||
| 
 | ||||
| # pylint: disable=invalid-name | ||||
| _engine = sqlalchemy.create_engine(config.config['database']) | ||||
| sessionmaker = sqlalchemy.orm.sessionmaker(bind=_engine, autoflush=False) | ||||
| session = sqlalchemy.orm.scoped_session(sessionmaker) | ||||
| 
 | ||||
| _data = threading.local() | ||||
| 
 | ||||
| 
 | ||||
| def reset_query_count(): | ||||
|     _data.query_count = 0 | ||||
| 
 | ||||
| 
 | ||||
| def get_query_count(): | ||||
|     return _data.query_count | ||||
| 
 | ||||
| 
 | ||||
| def _bump_query_count(): | ||||
|     _data.query_count = getattr(_data, 'query_count', 0) + 1 | ||||
| 
 | ||||
| 
 | ||||
| sqlalchemy.event.listen( | ||||
|     _engine, 'after_execute', lambda *args: _bump_query_count()) | ||||
| @ -1,34 +0,0 @@ | ||||
| from sqlalchemy.inspection import inspect | ||||
| 
 | ||||
| 
 | ||||
| def get_resource_info(entity): | ||||
|     serializers = { | ||||
|         'tag': lambda tag: tag.first_name, | ||||
|         'tag_category': lambda category: category.name, | ||||
|         'comment': lambda comment: comment.comment_id, | ||||
|         'post': lambda post: post.post_id, | ||||
|     } | ||||
| 
 | ||||
|     resource_type = entity.__table__.name | ||||
|     assert resource_type in serializers | ||||
| 
 | ||||
|     primary_key = inspect(entity).identity | ||||
|     assert primary_key is not None | ||||
|     assert len(primary_key) == 1 | ||||
| 
 | ||||
|     resource_name = serializers[resource_type](entity) | ||||
|     assert resource_name | ||||
| 
 | ||||
|     resource_pkey = primary_key[0] | ||||
|     assert resource_pkey | ||||
| 
 | ||||
|     return (resource_type, resource_pkey, resource_name) | ||||
| 
 | ||||
| 
 | ||||
| def get_aux_entity(session, get_table_info, entity, user): | ||||
|     table, get_column = get_table_info(entity) | ||||
|     return session \ | ||||
|         .query(table) \ | ||||
|         .filter(get_column(table) == get_column(entity)) \ | ||||
|         .filter(table.user_id == user.user_id) \ | ||||
|         .one_or_none() | ||||
| @ -1,5 +1,11 @@ | ||||
| from typing import Dict | ||||
| 
 | ||||
| 
 | ||||
| class BaseError(RuntimeError): | ||||
|     def __init__(self, message='Unknown error', extra_fields=None): | ||||
|     def __init__( | ||||
|             self, | ||||
|             message: str='Unknown error', | ||||
|             extra_fields: Dict[str, str]=None) -> None: | ||||
|         super().__init__(message) | ||||
|         self.extra_fields = extra_fields | ||||
| 
 | ||||
|  | ||||
| @ -2,7 +2,10 @@ import os | ||||
| import time | ||||
| import logging | ||||
| import threading | ||||
| from typing import Callable, Any, Type | ||||
| 
 | ||||
| import coloredlogs | ||||
| import sqlalchemy as sa | ||||
| import sqlalchemy.orm.exc | ||||
| from szurubooru import config, db, errors, rest | ||||
| from szurubooru.func import posts, file_uploads | ||||
| @ -10,7 +13,10 @@ from szurubooru.func import posts, file_uploads | ||||
| from szurubooru import api, middleware | ||||
| 
 | ||||
| 
 | ||||
| def _map_error(ex, target_class, title): | ||||
| def _map_error( | ||||
|         ex: Exception, | ||||
|         target_class: Type[rest.errors.BaseHttpError], | ||||
|         title: str) -> rest.errors.BaseHttpError: | ||||
|     return target_class( | ||||
|         name=type(ex).__name__, | ||||
|         title=title, | ||||
| @ -18,38 +24,38 @@ def _map_error(ex, target_class, title): | ||||
|         extra_fields=getattr(ex, 'extra_fields', {})) | ||||
| 
 | ||||
| 
 | ||||
| def _on_auth_error(ex): | ||||
| def _on_auth_error(ex: Exception) -> None: | ||||
|     raise _map_error(ex, rest.errors.HttpForbidden, 'Authentication error') | ||||
| 
 | ||||
| 
 | ||||
| def _on_validation_error(ex): | ||||
| def _on_validation_error(ex: Exception) -> None: | ||||
|     raise _map_error(ex, rest.errors.HttpBadRequest, 'Validation error') | ||||
| 
 | ||||
| 
 | ||||
| def _on_search_error(ex): | ||||
| def _on_search_error(ex: Exception) -> None: | ||||
|     raise _map_error(ex, rest.errors.HttpBadRequest, 'Search error') | ||||
| 
 | ||||
| 
 | ||||
| def _on_integrity_error(ex): | ||||
| def _on_integrity_error(ex: Exception) -> None: | ||||
|     raise _map_error(ex, rest.errors.HttpConflict, 'Integrity violation') | ||||
| 
 | ||||
| 
 | ||||
| def _on_not_found_error(ex): | ||||
| def _on_not_found_error(ex: Exception) -> None: | ||||
|     raise _map_error(ex, rest.errors.HttpNotFound, 'Not found') | ||||
| 
 | ||||
| 
 | ||||
| def _on_processing_error(ex): | ||||
| def _on_processing_error(ex: Exception) -> None: | ||||
|     raise _map_error(ex, rest.errors.HttpBadRequest, 'Processing error') | ||||
| 
 | ||||
| 
 | ||||
| def _on_third_party_error(ex): | ||||
| def _on_third_party_error(ex: Exception) -> None: | ||||
|     raise _map_error( | ||||
|         ex, | ||||
|         rest.errors.HttpInternalServerError, | ||||
|         'Server configuration error') | ||||
| 
 | ||||
| 
 | ||||
| def _on_stale_data_error(_ex): | ||||
| def _on_stale_data_error(_ex: Exception) -> None: | ||||
|     raise rest.errors.HttpConflict( | ||||
|         name='IntegrityError', | ||||
|         title='Integrity violation', | ||||
| @ -58,7 +64,7 @@ def _on_stale_data_error(_ex): | ||||
|             'Please try again.')) | ||||
| 
 | ||||
| 
 | ||||
| def validate_config(): | ||||
| def validate_config() -> None: | ||||
|     ''' | ||||
|     Check whether config doesn't contain errors that might prove | ||||
|     lethal at runtime. | ||||
| @ -86,7 +92,7 @@ def validate_config(): | ||||
|         raise errors.ConfigError('Database is not configured') | ||||
| 
 | ||||
| 
 | ||||
| def purge_old_uploads(): | ||||
| def purge_old_uploads() -> None: | ||||
|     while True: | ||||
|         try: | ||||
|             file_uploads.purge_old_uploads() | ||||
| @ -95,7 +101,7 @@ def purge_old_uploads(): | ||||
|         time.sleep(60 * 5) | ||||
| 
 | ||||
| 
 | ||||
| def create_app(): | ||||
| def create_app() -> Callable[[Any, Any], Any]: | ||||
|     ''' Create a WSGI compatible App object. ''' | ||||
|     validate_config() | ||||
|     coloredlogs.install(fmt='[%(asctime)-15s] %(name)s %(message)s') | ||||
| @ -122,7 +128,7 @@ def create_app(): | ||||
|     rest.errors.handle(errors.NotFoundError, _on_not_found_error) | ||||
|     rest.errors.handle(errors.ProcessingError, _on_processing_error) | ||||
|     rest.errors.handle(errors.ThirdPartyError, _on_third_party_error) | ||||
|     rest.errors.handle(sqlalchemy.orm.exc.StaleDataError, _on_stale_data_error) | ||||
|     rest.errors.handle(sa.orm.exc.StaleDataError, _on_stale_data_error) | ||||
| 
 | ||||
|     return rest.application | ||||
| 
 | ||||
|  | ||||
| @ -1,22 +1,22 @@ | ||||
| import hashlib | ||||
| import random | ||||
| from collections import OrderedDict | ||||
| from szurubooru import config, db, errors | ||||
| from szurubooru import config, model, errors | ||||
| from szurubooru.func import util | ||||
| 
 | ||||
| 
 | ||||
| RANK_MAP = OrderedDict([ | ||||
|     (db.User.RANK_ANONYMOUS, 'anonymous'), | ||||
|     (db.User.RANK_RESTRICTED, 'restricted'), | ||||
|     (db.User.RANK_REGULAR, 'regular'), | ||||
|     (db.User.RANK_POWER, 'power'), | ||||
|     (db.User.RANK_MODERATOR, 'moderator'), | ||||
|     (db.User.RANK_ADMINISTRATOR, 'administrator'), | ||||
|     (db.User.RANK_NOBODY, 'nobody'), | ||||
|     (model.User.RANK_ANONYMOUS, 'anonymous'), | ||||
|     (model.User.RANK_RESTRICTED, 'restricted'), | ||||
|     (model.User.RANK_REGULAR, 'regular'), | ||||
|     (model.User.RANK_POWER, 'power'), | ||||
|     (model.User.RANK_MODERATOR, 'moderator'), | ||||
|     (model.User.RANK_ADMINISTRATOR, 'administrator'), | ||||
|     (model.User.RANK_NOBODY, 'nobody'), | ||||
| ]) | ||||
| 
 | ||||
| 
 | ||||
| def get_password_hash(salt, password): | ||||
| def get_password_hash(salt: str, password: str) -> str: | ||||
|     ''' Retrieve new-style password hash. ''' | ||||
|     digest = hashlib.sha256() | ||||
|     digest.update(config.config['secret'].encode('utf8')) | ||||
| @ -25,7 +25,7 @@ def get_password_hash(salt, password): | ||||
|     return digest.hexdigest() | ||||
| 
 | ||||
| 
 | ||||
| def get_legacy_password_hash(salt, password): | ||||
| def get_legacy_password_hash(salt: str, password: str) -> str: | ||||
|     ''' Retrieve old-style password hash. ''' | ||||
|     digest = hashlib.sha1() | ||||
|     digest.update(b'1A2/$_4xVa') | ||||
| @ -34,7 +34,7 @@ def get_legacy_password_hash(salt, password): | ||||
|     return digest.hexdigest() | ||||
| 
 | ||||
| 
 | ||||
| def create_password(): | ||||
| def create_password() -> str: | ||||
|     alphabet = { | ||||
|         'c': list('bcdfghijklmnpqrstvwxyz'), | ||||
|         'v': list('aeiou'), | ||||
| @ -44,7 +44,7 @@ def create_password(): | ||||
|     return ''.join(random.choice(alphabet[l]) for l in list(pattern)) | ||||
| 
 | ||||
| 
 | ||||
| def is_valid_password(user, password): | ||||
| def is_valid_password(user: model.User, password: str) -> bool: | ||||
|     assert user | ||||
|     salt, valid_hash = user.password_salt, user.password_hash | ||||
|     possible_hashes = [ | ||||
| @ -54,7 +54,7 @@ def is_valid_password(user, password): | ||||
|     return valid_hash in possible_hashes | ||||
| 
 | ||||
| 
 | ||||
| def has_privilege(user, privilege_name): | ||||
| def has_privilege(user: model.User, privilege_name: str) -> bool: | ||||
|     assert user | ||||
|     all_ranks = list(RANK_MAP.keys()) | ||||
|     assert privilege_name in config.config['privileges'] | ||||
| @ -65,13 +65,13 @@ def has_privilege(user, privilege_name): | ||||
|     return user.rank in good_ranks | ||||
| 
 | ||||
| 
 | ||||
| def verify_privilege(user, privilege_name): | ||||
| def verify_privilege(user: model.User, privilege_name: str) -> None: | ||||
|     assert user | ||||
|     if not has_privilege(user, privilege_name): | ||||
|         raise errors.AuthError('Insufficient privileges to do this.') | ||||
| 
 | ||||
| 
 | ||||
| def generate_authentication_token(user): | ||||
| def generate_authentication_token(user: model.User) -> str: | ||||
|     ''' Generate nonguessable challenge (e.g. links in password reminder). ''' | ||||
|     assert user | ||||
|     digest = hashlib.md5() | ||||
|  | ||||
| @ -1,21 +1,21 @@ | ||||
| from typing import Any, List, Dict | ||||
| from datetime import datetime | ||||
| 
 | ||||
| 
 | ||||
| class LruCacheItem: | ||||
|     def __init__(self, key, value): | ||||
|     def __init__(self, key: object, value: Any) -> None: | ||||
|         self.key = key | ||||
|         self.value = value | ||||
|         self.timestamp = datetime.utcnow() | ||||
| 
 | ||||
| 
 | ||||
| class LruCache: | ||||
|     def __init__(self, length, delta=None): | ||||
|     def __init__(self, length: int) -> None: | ||||
|         self.length = length | ||||
|         self.delta = delta | ||||
|         self.hash = {} | ||||
|         self.item_list = [] | ||||
|         self.hash = {}  # type: Dict[object, LruCacheItem] | ||||
|         self.item_list = []  # type: List[LruCacheItem] | ||||
| 
 | ||||
|     def insert_item(self, item): | ||||
|     def insert_item(self, item: LruCacheItem) -> None: | ||||
|         if item.key in self.hash: | ||||
|             item_index = next( | ||||
|                 i | ||||
| @ -31,11 +31,11 @@ class LruCache: | ||||
|             self.hash[item.key] = item | ||||
|             self.item_list.insert(0, item) | ||||
| 
 | ||||
|     def remove_all(self): | ||||
|     def remove_all(self) -> None: | ||||
|         self.hash = {} | ||||
|         self.item_list = [] | ||||
| 
 | ||||
|     def remove_item(self, item): | ||||
|     def remove_item(self, item: LruCacheItem) -> None: | ||||
|         del self.hash[item.key] | ||||
|         del self.item_list[self.item_list.index(item)] | ||||
| 
 | ||||
| @ -43,22 +43,22 @@ class LruCache: | ||||
| _CACHE = LruCache(length=100) | ||||
| 
 | ||||
| 
 | ||||
| def purge(): | ||||
| def purge() -> None: | ||||
|     _CACHE.remove_all() | ||||
| 
 | ||||
| 
 | ||||
| def has(key): | ||||
| def has(key: object) -> bool: | ||||
|     return key in _CACHE.hash | ||||
| 
 | ||||
| 
 | ||||
| def get(key): | ||||
| def get(key: object) -> Any: | ||||
|     return _CACHE.hash[key].value | ||||
| 
 | ||||
| 
 | ||||
| def remove(key): | ||||
| def remove(key: object) -> None: | ||||
|     if has(key): | ||||
|         del _CACHE.hash[key] | ||||
| 
 | ||||
| 
 | ||||
| def put(key, value): | ||||
| def put(key: object, value: Any) -> None: | ||||
|     _CACHE.insert_item(LruCacheItem(key, value)) | ||||
|  | ||||
| @ -1,6 +1,7 @@ | ||||
| import datetime | ||||
| from szurubooru import db, errors | ||||
| from szurubooru.func import users, scores, util | ||||
| from datetime import datetime | ||||
| from typing import Any, Optional, List, Dict, Callable | ||||
| from szurubooru import db, model, errors, rest | ||||
| from szurubooru.func import users, scores, util, serialization | ||||
| 
 | ||||
| 
 | ||||
| class InvalidCommentIdError(errors.ValidationError): | ||||
| @ -15,52 +16,87 @@ class EmptyCommentTextError(errors.ValidationError): | ||||
|     pass | ||||
| 
 | ||||
| 
 | ||||
| def serialize_comment(comment, auth_user, options=None): | ||||
|     return util.serialize_entity( | ||||
|         comment, | ||||
|         { | ||||
|             'id': lambda: comment.comment_id, | ||||
|             'user': | ||||
|                 lambda: users.serialize_micro_user(comment.user, auth_user), | ||||
|             'postId': lambda: comment.post.post_id, | ||||
|             'version': lambda: comment.version, | ||||
|             'text': lambda: comment.text, | ||||
|             'creationTime': lambda: comment.creation_time, | ||||
|             'lastEditTime': lambda: comment.last_edit_time, | ||||
|             'score': lambda: comment.score, | ||||
|             'ownScore': lambda: scores.get_score(comment, auth_user), | ||||
|         }, | ||||
|         options) | ||||
| class CommentSerializer(serialization.BaseSerializer): | ||||
|     def __init__(self, comment: model.Comment, auth_user: model.User) -> None: | ||||
|         self.comment = comment | ||||
|         self.auth_user = auth_user | ||||
| 
 | ||||
|     def _serializers(self) -> Dict[str, Callable[[], Any]]: | ||||
|         return { | ||||
|             'id': self.serialize_id, | ||||
|             'user': self.serialize_user, | ||||
|             'postId': self.serialize_post_id, | ||||
|             'version': self.serialize_version, | ||||
|             'text': self.serialize_text, | ||||
|             'creationTime': self.serialize_creation_time, | ||||
|             'lastEditTime': self.serialize_last_edit_time, | ||||
|             'score': self.serialize_score, | ||||
|             'ownScore': self.serialize_own_score, | ||||
|         } | ||||
| 
 | ||||
|     def serialize_id(self) -> Any: | ||||
|         return self.comment.comment_id | ||||
| 
 | ||||
|     def serialize_user(self) -> Any: | ||||
|         return users.serialize_micro_user(self.comment.user, self.auth_user) | ||||
| 
 | ||||
|     def serialize_post_id(self) -> Any: | ||||
|         return self.comment.post.post_id | ||||
| 
 | ||||
|     def serialize_version(self) -> Any: | ||||
|         return self.comment.version | ||||
| 
 | ||||
|     def serialize_text(self) -> Any: | ||||
|         return self.comment.text | ||||
| 
 | ||||
|     def serialize_creation_time(self) -> Any: | ||||
|         return self.comment.creation_time | ||||
| 
 | ||||
|     def serialize_last_edit_time(self) -> Any: | ||||
|         return self.comment.last_edit_time | ||||
| 
 | ||||
|     def serialize_score(self) -> Any: | ||||
|         return self.comment.score | ||||
| 
 | ||||
|     def serialize_own_score(self) -> Any: | ||||
|         return scores.get_score(self.comment, self.auth_user) | ||||
| 
 | ||||
| 
 | ||||
| def try_get_comment_by_id(comment_id): | ||||
|     try: | ||||
|         comment_id = int(comment_id) | ||||
|     except ValueError: | ||||
|         raise InvalidCommentIdError('Invalid comment ID: %r.' % comment_id) | ||||
| def serialize_comment( | ||||
|         comment: model.Comment, | ||||
|         auth_user: model.User, | ||||
|         options: List[str]=[]) -> rest.Response: | ||||
|     if comment is None: | ||||
|         return None | ||||
|     return CommentSerializer(comment, auth_user).serialize(options) | ||||
| 
 | ||||
| 
 | ||||
| def try_get_comment_by_id(comment_id: int) -> Optional[model.Comment]: | ||||
|     comment_id = int(comment_id) | ||||
|     return db.session \ | ||||
|         .query(db.Comment) \ | ||||
|         .filter(db.Comment.comment_id == comment_id) \ | ||||
|         .query(model.Comment) \ | ||||
|         .filter(model.Comment.comment_id == comment_id) \ | ||||
|         .one_or_none() | ||||
| 
 | ||||
| 
 | ||||
| def get_comment_by_id(comment_id): | ||||
| def get_comment_by_id(comment_id: int) -> model.Comment: | ||||
|     comment = try_get_comment_by_id(comment_id) | ||||
|     if comment: | ||||
|         return comment | ||||
|     raise CommentNotFoundError('Comment %r not found.' % comment_id) | ||||
| 
 | ||||
| 
 | ||||
| def create_comment(user, post, text): | ||||
|     comment = db.Comment() | ||||
| def create_comment( | ||||
|         user: model.User, post: model.Post, text: str) -> model.Comment: | ||||
|     comment = model.Comment() | ||||
|     comment.user = user | ||||
|     comment.post = post | ||||
|     update_comment_text(comment, text) | ||||
|     comment.creation_time = datetime.datetime.utcnow() | ||||
|     comment.creation_time = datetime.utcnow() | ||||
|     return comment | ||||
| 
 | ||||
| 
 | ||||
| def update_comment_text(comment, text): | ||||
| def update_comment_text(comment: model.Comment, text: str) -> None: | ||||
|     assert comment | ||||
|     if not text: | ||||
|         raise EmptyCommentTextError('Comment text cannot be empty.') | ||||
|  | ||||
| @ -1,21 +1,26 @@ | ||||
| def get_list_diff(old, new): | ||||
|     value = {'type': 'list change', 'added': [], 'removed': []} | ||||
| from typing import List, Dict, Any | ||||
| 
 | ||||
| 
 | ||||
| def get_list_diff(old: List[Any], new: List[Any]) -> Any: | ||||
|     equal = True | ||||
|     removed = []  # type: List[Any] | ||||
|     added = []  # type: List[Any] | ||||
| 
 | ||||
|     for item in old: | ||||
|         if item not in new: | ||||
|             equal = False | ||||
|             value['removed'].append(item) | ||||
|             removed.append(item) | ||||
| 
 | ||||
|     for item in new: | ||||
|         if item not in old: | ||||
|             equal = False | ||||
|             value['added'].append(item) | ||||
|             added.append(item) | ||||
| 
 | ||||
|     return None if equal else value | ||||
|     return None if equal else { | ||||
|         'type': 'list change', 'added': added, 'removed': removed} | ||||
| 
 | ||||
| 
 | ||||
| def get_dict_diff(old, new): | ||||
| def get_dict_diff(old: Dict[str, Any], new: Dict[str, Any]) -> Any: | ||||
|     value = {} | ||||
|     equal = True | ||||
| 
 | ||||
|  | ||||
| @ -1,32 +1,34 @@ | ||||
| import datetime | ||||
| from szurubooru import db, errors | ||||
| from typing import Any, Optional, Callable, Tuple | ||||
| from datetime import datetime | ||||
| from szurubooru import db, model, errors | ||||
| 
 | ||||
| 
 | ||||
| class InvalidFavoriteTargetError(errors.ValidationError): | ||||
|     pass | ||||
| 
 | ||||
| 
 | ||||
| def _get_table_info(entity): | ||||
| def _get_table_info( | ||||
|         entity: model.Base) -> Tuple[model.Base, Callable[[model.Base], Any]]: | ||||
|     assert entity | ||||
|     resource_type, _, _ = db.util.get_resource_info(entity) | ||||
|     resource_type, _, _ = model.util.get_resource_info(entity) | ||||
|     if resource_type == 'post': | ||||
|         return db.PostFavorite, lambda table: table.post_id | ||||
|         return model.PostFavorite, lambda table: table.post_id | ||||
|     raise InvalidFavoriteTargetError() | ||||
| 
 | ||||
| 
 | ||||
| def _get_fav_entity(entity, user): | ||||
| def _get_fav_entity(entity: model.Base, user: model.User) -> model.Base: | ||||
|     assert entity | ||||
|     assert user | ||||
|     return db.util.get_aux_entity(db.session, _get_table_info, entity, user) | ||||
|     return model.util.get_aux_entity(db.session, _get_table_info, entity, user) | ||||
| 
 | ||||
| 
 | ||||
| def has_favorited(entity, user): | ||||
| def has_favorited(entity: model.Base, user: model.User) -> bool: | ||||
|     assert entity | ||||
|     assert user | ||||
|     return _get_fav_entity(entity, user) is not None | ||||
| 
 | ||||
| 
 | ||||
| def unset_favorite(entity, user): | ||||
| def unset_favorite(entity: model.Base, user: Optional[model.User]) -> None: | ||||
|     assert entity | ||||
|     assert user | ||||
|     fav_entity = _get_fav_entity(entity, user) | ||||
| @ -34,7 +36,7 @@ def unset_favorite(entity, user): | ||||
|         db.session.delete(fav_entity) | ||||
| 
 | ||||
| 
 | ||||
| def set_favorite(entity, user): | ||||
| def set_favorite(entity: model.Base, user: Optional[model.User]) -> None: | ||||
|     from szurubooru.func import scores | ||||
|     assert entity | ||||
|     assert user | ||||
| @ -48,5 +50,5 @@ def set_favorite(entity, user): | ||||
|         fav_entity = table() | ||||
|         setattr(fav_entity, get_column(table).name, get_column(entity)) | ||||
|         fav_entity.user = user | ||||
|         fav_entity.time = datetime.datetime.utcnow() | ||||
|         fav_entity.time = datetime.utcnow() | ||||
|         db.session.add(fav_entity) | ||||
|  | ||||
| @ -1,27 +1,28 @@ | ||||
| import datetime | ||||
| from typing import Optional | ||||
| from datetime import datetime, timedelta | ||||
| from szurubooru.func import files, util | ||||
| 
 | ||||
| 
 | ||||
| MAX_MINUTES = 60 | ||||
| 
 | ||||
| 
 | ||||
| def _get_path(checksum): | ||||
| def _get_path(checksum: str) -> str: | ||||
|     return 'temporary-uploads/%s.dat' % checksum | ||||
| 
 | ||||
| 
 | ||||
| def purge_old_uploads(): | ||||
|     now = datetime.datetime.now() | ||||
| def purge_old_uploads() -> None: | ||||
|     now = datetime.now() | ||||
|     for file in files.scan('temporary-uploads'): | ||||
|         file_time = datetime.datetime.fromtimestamp(file.stat().st_ctime) | ||||
|         if now - file_time > datetime.timedelta(minutes=MAX_MINUTES): | ||||
|         file_time = datetime.fromtimestamp(file.stat().st_ctime) | ||||
|         if now - file_time > timedelta(minutes=MAX_MINUTES): | ||||
|             files.delete('temporary-uploads/%s' % file.name) | ||||
| 
 | ||||
| 
 | ||||
| def get(checksum): | ||||
| def get(checksum: str) -> Optional[bytes]: | ||||
|     return files.get('temporary-uploads/%s.dat' % checksum) | ||||
| 
 | ||||
| 
 | ||||
| def save(content): | ||||
| def save(content: bytes) -> str: | ||||
|     checksum = util.get_sha1(content) | ||||
|     path = _get_path(checksum) | ||||
|     if not files.has(path): | ||||
|  | ||||
| @ -1,32 +1,33 @@ | ||||
| from typing import Any, Optional, List | ||||
| import os | ||||
| from szurubooru import config | ||||
| 
 | ||||
| 
 | ||||
| def _get_full_path(path): | ||||
| def _get_full_path(path: str) -> str: | ||||
|     return os.path.join(config.config['data_dir'], path) | ||||
| 
 | ||||
| 
 | ||||
| def delete(path): | ||||
| def delete(path: str) -> None: | ||||
|     full_path = _get_full_path(path) | ||||
|     if os.path.exists(full_path): | ||||
|         os.unlink(full_path) | ||||
| 
 | ||||
| 
 | ||||
| def has(path): | ||||
| def has(path: str) -> bool: | ||||
|     return os.path.exists(_get_full_path(path)) | ||||
| 
 | ||||
| 
 | ||||
| def scan(path): | ||||
| def scan(path: str) -> List[os.DirEntry]: | ||||
|     if has(path): | ||||
|         return os.scandir(_get_full_path(path)) | ||||
|         return list(os.scandir(_get_full_path(path))) | ||||
|     return [] | ||||
| 
 | ||||
| 
 | ||||
| def move(source_path, target_path): | ||||
|     return os.rename(_get_full_path(source_path), _get_full_path(target_path)) | ||||
| def move(source_path: str, target_path: str) -> None: | ||||
|     os.rename(_get_full_path(source_path), _get_full_path(target_path)) | ||||
| 
 | ||||
| 
 | ||||
| def get(path): | ||||
| def get(path: str) -> Optional[bytes]: | ||||
|     full_path = _get_full_path(path) | ||||
|     if not os.path.exists(full_path): | ||||
|         return None | ||||
| @ -34,7 +35,7 @@ def get(path): | ||||
|         return handle.read() | ||||
| 
 | ||||
| 
 | ||||
| def save(path, content): | ||||
| def save(path: str, content: bytes) -> None: | ||||
|     full_path = _get_full_path(path) | ||||
|     os.makedirs(os.path.dirname(full_path), exist_ok=True) | ||||
|     with open(full_path, 'wb') as handle: | ||||
|  | ||||
| @ -1,6 +1,7 @@ | ||||
| import logging | ||||
| from io import BytesIO | ||||
| from datetime import datetime | ||||
| from typing import Any, Optional, Tuple, Set, List, Callable | ||||
| import elasticsearch | ||||
| import elasticsearch_dsl | ||||
| import numpy as np | ||||
| @ -10,13 +11,8 @@ from szurubooru import config, errors | ||||
| 
 | ||||
| # pylint: disable=invalid-name | ||||
| logger = logging.getLogger(__name__) | ||||
| es = elasticsearch.Elasticsearch([{ | ||||
|     'host': config.config['elasticsearch']['host'], | ||||
|     'port': config.config['elasticsearch']['port'], | ||||
| }]) | ||||
| 
 | ||||
| 
 | ||||
| # Math based on paper from H. Chi Wong, Marshall Bern and David Goldber | ||||
| # Math based on paper from H. Chi Wong, Marshall Bern and David Goldberg | ||||
| # Math code taken from https://github.com/ascribe/image-match | ||||
| # (which is licensed under Apache 2 license) | ||||
| 
 | ||||
| @ -32,14 +28,27 @@ MAX_WORDS = 63 | ||||
| ES_DOC_TYPE = 'image' | ||||
| ES_MAX_RESULTS = 100 | ||||
| 
 | ||||
| Window = Tuple[Tuple[float, float], Tuple[float, float]] | ||||
| NpMatrix = Any | ||||
| 
 | ||||
| def _preprocess_image(image_or_path): | ||||
|     img = Image.open(BytesIO(image_or_path)) | ||||
| 
 | ||||
| def _get_session() -> elasticsearch.Elasticsearch: | ||||
|     return elasticsearch.Elasticsearch([{ | ||||
|         'host': config.config['elasticsearch']['host'], | ||||
|         'port': config.config['elasticsearch']['port'], | ||||
|     }]) | ||||
| 
 | ||||
| 
 | ||||
| def _preprocess_image(content: bytes) -> NpMatrix: | ||||
|     img = Image.open(BytesIO(content)) | ||||
|     img = img.convert('RGB') | ||||
|     return rgb2gray(np.asarray(img, dtype=np.uint8)) | ||||
| 
 | ||||
| 
 | ||||
| def _crop_image(image, lower_percentile, upper_percentile): | ||||
| def _crop_image( | ||||
|         image: NpMatrix, | ||||
|         lower_percentile: float, | ||||
|         upper_percentile: float) -> Window: | ||||
|     rw = np.cumsum(np.sum(np.abs(np.diff(image, axis=1)), axis=1)) | ||||
|     cw = np.cumsum(np.sum(np.abs(np.diff(image, axis=0)), axis=0)) | ||||
|     upper_column_limit = np.searchsorted( | ||||
| @ -56,16 +65,19 @@ def _crop_image(image, lower_percentile, upper_percentile): | ||||
|     if lower_column_limit > upper_column_limit: | ||||
|         lower_column_limit = int(lower_percentile / 100. * image.shape[1]) | ||||
|         upper_column_limit = int(upper_percentile / 100. * image.shape[1]) | ||||
|     return [ | ||||
|     return ( | ||||
|         (lower_row_limit, upper_row_limit), | ||||
|         (lower_column_limit, upper_column_limit)] | ||||
|         (lower_column_limit, upper_column_limit)) | ||||
| 
 | ||||
| 
 | ||||
| def _normalize_and_threshold(diff_array, identical_tolerance, n_levels): | ||||
| def _normalize_and_threshold( | ||||
|         diff_array: NpMatrix, | ||||
|         identical_tolerance: float, | ||||
|         n_levels: int) -> None: | ||||
|     mask = np.abs(diff_array) < identical_tolerance | ||||
|     diff_array[mask] = 0. | ||||
|     if np.all(mask): | ||||
|         return None | ||||
|         return | ||||
|     positive_cutoffs = np.percentile( | ||||
|         diff_array[diff_array > 0.], np.linspace(0, 100, n_levels + 1)) | ||||
|     negative_cutoffs = np.percentile( | ||||
| @ -82,18 +94,24 @@ def _normalize_and_threshold(diff_array, identical_tolerance, n_levels): | ||||
|         diff_array[ | ||||
|             (diff_array <= interval[0]) & (diff_array >= interval[1])] = \ | ||||
|             -(level + 1) | ||||
|     return None | ||||
| 
 | ||||
| 
 | ||||
| def _compute_grid_points(image, n, window=None): | ||||
| def _compute_grid_points( | ||||
|         image: NpMatrix, | ||||
|         n: float, | ||||
|         window: Window=None) -> Tuple[NpMatrix, NpMatrix]: | ||||
|     if window is None: | ||||
|         window = [(0, image.shape[0]), (0, image.shape[1])] | ||||
|         window = ((0, image.shape[0]), (0, image.shape[1])) | ||||
|     x_coords = np.linspace(window[0][0], window[0][1], n + 2, dtype=int)[1:-1] | ||||
|     y_coords = np.linspace(window[1][0], window[1][1], n + 2, dtype=int)[1:-1] | ||||
|     return x_coords, y_coords | ||||
| 
 | ||||
| 
 | ||||
| def _compute_mean_level(image, x_coords, y_coords, p): | ||||
| def _compute_mean_level( | ||||
|         image: NpMatrix, | ||||
|         x_coords: NpMatrix, | ||||
|         y_coords: NpMatrix, | ||||
|         p: Optional[float]) -> NpMatrix: | ||||
|     if p is None: | ||||
|         p = max([2.0, int(0.5 + min(image.shape) / 20.)]) | ||||
|     avg_grey = np.zeros((x_coords.shape[0], y_coords.shape[0])) | ||||
| @ -108,7 +126,7 @@ def _compute_mean_level(image, x_coords, y_coords, p): | ||||
|     return avg_grey | ||||
| 
 | ||||
| 
 | ||||
| def _compute_differentials(grey_level_matrix): | ||||
| def _compute_differentials(grey_level_matrix: NpMatrix) -> NpMatrix: | ||||
|     flipped = np.fliplr(grey_level_matrix) | ||||
|     right_neighbors = -np.concatenate( | ||||
|         ( | ||||
| @ -152,8 +170,8 @@ def _compute_differentials(grey_level_matrix): | ||||
|         lower_right_neighbors])) | ||||
| 
 | ||||
| 
 | ||||
| def _generate_signature(path_or_image): | ||||
|     im_array = _preprocess_image(path_or_image) | ||||
| def _generate_signature(content: bytes) -> NpMatrix: | ||||
|     im_array = _preprocess_image(content) | ||||
|     image_limits = _crop_image( | ||||
|         im_array, | ||||
|         lower_percentile=LOWER_PERCENTILE, | ||||
| @ -169,7 +187,7 @@ def _generate_signature(path_or_image): | ||||
|     return np.ravel(diff_matrix).astype('int8') | ||||
| 
 | ||||
| 
 | ||||
| def _get_words(array, k, n): | ||||
| def _get_words(array: NpMatrix, k: int, n: int) -> NpMatrix: | ||||
|     word_positions = np.linspace( | ||||
|         0, array.shape[0], n, endpoint=False).astype('int') | ||||
|     assert k <= array.shape[0] | ||||
| @ -187,21 +205,23 @@ def _get_words(array, k, n): | ||||
|     return words | ||||
| 
 | ||||
| 
 | ||||
| def _words_to_int(word_array): | ||||
| def _words_to_int(word_array: NpMatrix) -> NpMatrix: | ||||
|     width = word_array.shape[1] | ||||
|     coding_vector = 3**np.arange(width) | ||||
|     return np.dot(word_array + 1, coding_vector) | ||||
| 
 | ||||
| 
 | ||||
| def _max_contrast(array): | ||||
| def _max_contrast(array: NpMatrix) -> None: | ||||
|     array[array > 0] = 1 | ||||
|     array[array < 0] = -1 | ||||
|     return None | ||||
| 
 | ||||
| 
 | ||||
| def _normalized_distance(_target_array, _vec, nan_value=1.0): | ||||
|     target_array = _target_array.astype(int) | ||||
|     vec = _vec.astype(int) | ||||
| def _normalized_distance( | ||||
|         target_array: NpMatrix, | ||||
|         vec: NpMatrix, | ||||
|         nan_value: float=1.0) -> List[float]: | ||||
|     target_array = target_array.astype(int) | ||||
|     vec = vec.astype(int) | ||||
|     topvec = np.linalg.norm(vec - target_array, axis=1) | ||||
|     norm1 = np.linalg.norm(vec, axis=0) | ||||
|     norm2 = np.linalg.norm(target_array, axis=1) | ||||
| @ -210,9 +230,9 @@ def _normalized_distance(_target_array, _vec, nan_value=1.0): | ||||
|     return finvec | ||||
| 
 | ||||
| 
 | ||||
| def _safety_blanket(default_param_factory): | ||||
|     def wrapper_outer(target_function): | ||||
|         def wrapper_inner(*args, **kwargs): | ||||
| def _safety_blanket(default_param_factory: Callable[[], Any]) -> Callable: | ||||
|     def wrapper_outer(target_function: Callable) -> Callable: | ||||
|         def wrapper_inner(*args: Any, **kwargs: Any) -> Any: | ||||
|             try: | ||||
|                 return target_function(*args, **kwargs) | ||||
|             except elasticsearch.exceptions.NotFoundError: | ||||
| @ -226,20 +246,20 @@ def _safety_blanket(default_param_factory): | ||||
|             except IOError: | ||||
|                 raise errors.ProcessingError('Not an image.') | ||||
|             except Exception as ex: | ||||
|                 raise errors.ThirdPartyError('Unknown error (%s).', ex) | ||||
|                 raise errors.ThirdPartyError('Unknown error (%s).' % ex) | ||||
|         return wrapper_inner | ||||
|     return wrapper_outer | ||||
| 
 | ||||
| 
 | ||||
| class Lookalike: | ||||
|     def __init__(self, score, distance, path): | ||||
|     def __init__(self, score: int, distance: float, path: Any) -> None: | ||||
|         self.score = score | ||||
|         self.distance = distance | ||||
|         self.path = path | ||||
| 
 | ||||
| 
 | ||||
| @_safety_blanket(lambda: None) | ||||
| def add_image(path, image_content): | ||||
| def add_image(path: str, image_content: bytes) -> None: | ||||
|     assert path | ||||
|     assert image_content | ||||
|     signature = _generate_signature(image_content) | ||||
| @ -253,7 +273,7 @@ def add_image(path, image_content): | ||||
|     for i in range(MAX_WORDS): | ||||
|         record['simple_word_' + str(i)] = words[i].tolist() | ||||
| 
 | ||||
|     es.index( | ||||
|     _get_session().index( | ||||
|         index=config.config['elasticsearch']['index'], | ||||
|         doc_type=ES_DOC_TYPE, | ||||
|         body=record, | ||||
| @ -261,20 +281,20 @@ def add_image(path, image_content): | ||||
| 
 | ||||
| 
 | ||||
| @_safety_blanket(lambda: None) | ||||
| def delete_image(path): | ||||
| def delete_image(path: str) -> None: | ||||
|     assert path | ||||
|     es.delete_by_query( | ||||
|     _get_session().delete_by_query( | ||||
|         index=config.config['elasticsearch']['index'], | ||||
|         doc_type=ES_DOC_TYPE, | ||||
|         body={'query': {'term': {'path': path}}}) | ||||
| 
 | ||||
| 
 | ||||
| @_safety_blanket(lambda: []) | ||||
| def search_by_image(image_content): | ||||
| def search_by_image(image_content: bytes) -> List[Lookalike]: | ||||
|     signature = _generate_signature(image_content) | ||||
|     words = _get_words(signature, k=SAMPLE_WORDS, n=MAX_WORDS) | ||||
| 
 | ||||
|     res = es.search( | ||||
|     res = _get_session().search( | ||||
|         index=config.config['elasticsearch']['index'], | ||||
|         doc_type=ES_DOC_TYPE, | ||||
|         body={ | ||||
| @ -299,7 +319,7 @@ def search_by_image(image_content): | ||||
|     sigs = np.array([x['_source']['signature'] for x in res]) | ||||
|     dists = _normalized_distance(sigs, np.array(signature)) | ||||
| 
 | ||||
|     ids = set() | ||||
|     ids = set()  # type: Set[int] | ||||
|     ret = [] | ||||
|     for item, dist in zip(res, dists): | ||||
|         id = item['_id'] | ||||
| @ -314,8 +334,8 @@ def search_by_image(image_content): | ||||
| 
 | ||||
| 
 | ||||
| @_safety_blanket(lambda: None) | ||||
| def purge(): | ||||
|     es.delete_by_query( | ||||
| def purge() -> None: | ||||
|     _get_session().delete_by_query( | ||||
|         index=config.config['elasticsearch']['index'], | ||||
|         doc_type=ES_DOC_TYPE, | ||||
|         body={'query': {'match_all': {}}}, | ||||
| @ -323,10 +343,10 @@ def purge(): | ||||
| 
 | ||||
| 
 | ||||
| @_safety_blanket(lambda: set()) | ||||
| def get_all_paths(): | ||||
| def get_all_paths() -> Set[str]: | ||||
|     search = ( | ||||
|         elasticsearch_dsl.Search( | ||||
|             using=es, | ||||
|             using=_get_session(), | ||||
|             index=config.config['elasticsearch']['index'], | ||||
|             doc_type=ES_DOC_TYPE) | ||||
|         .source(['path'])) | ||||
|  | ||||
| @ -1,3 +1,4 @@ | ||||
| from typing import List | ||||
| import logging | ||||
| import json | ||||
| import shlex | ||||
| @ -15,23 +16,23 @@ _SCALE_FIT_FMT = \ | ||||
| 
 | ||||
| 
 | ||||
| class Image: | ||||
|     def __init__(self, content): | ||||
|     def __init__(self, content: bytes) -> None: | ||||
|         self.content = content | ||||
|         self._reload_info() | ||||
| 
 | ||||
|     @property | ||||
|     def width(self): | ||||
|     def width(self) -> int: | ||||
|         return self.info['streams'][0]['width'] | ||||
| 
 | ||||
|     @property | ||||
|     def height(self): | ||||
|     def height(self) -> int: | ||||
|         return self.info['streams'][0]['height'] | ||||
| 
 | ||||
|     @property | ||||
|     def frames(self): | ||||
|     def frames(self) -> int: | ||||
|         return self.info['streams'][0]['nb_read_frames'] | ||||
| 
 | ||||
|     def resize_fill(self, width, height): | ||||
|     def resize_fill(self, width: int, height: int) -> None: | ||||
|         cli = [ | ||||
|             '-i', '{path}', | ||||
|             '-f', 'image2', | ||||
| @ -53,7 +54,7 @@ class Image: | ||||
|         assert self.content | ||||
|         self._reload_info() | ||||
| 
 | ||||
|     def to_png(self): | ||||
|     def to_png(self) -> bytes: | ||||
|         return self._execute([ | ||||
|             '-i', '{path}', | ||||
|             '-f', 'image2', | ||||
| @ -63,7 +64,7 @@ class Image: | ||||
|             '-', | ||||
|         ]) | ||||
| 
 | ||||
|     def to_jpeg(self): | ||||
|     def to_jpeg(self) -> bytes: | ||||
|         return self._execute([ | ||||
|             '-f', 'lavfi', | ||||
|             '-i', 'color=white:s=%dx%d' % (self.width, self.height), | ||||
| @ -76,7 +77,7 @@ class Image: | ||||
|             '-', | ||||
|         ]) | ||||
| 
 | ||||
|     def _execute(self, cli, program='ffmpeg'): | ||||
|     def _execute(self, cli: List[str], program: str='ffmpeg') -> bytes: | ||||
|         extension = mime.get_extension(mime.get_mime_type(self.content)) | ||||
|         assert extension | ||||
|         with util.create_temp_file(suffix='.' + extension) as handle: | ||||
| @ -99,7 +100,7 @@ class Image: | ||||
|                     'Error while processing image.\n' + err.decode('utf-8')) | ||||
|             return out | ||||
| 
 | ||||
|     def _reload_info(self): | ||||
|     def _reload_info(self) -> None: | ||||
|         self.info = json.loads(self._execute([ | ||||
|             '-i', '{path}', | ||||
|             '-of', 'json', | ||||
|  | ||||
| @ -3,7 +3,7 @@ import email.mime.text | ||||
| from szurubooru import config | ||||
| 
 | ||||
| 
 | ||||
| def send_mail(sender, recipient, subject, body): | ||||
| def send_mail(sender: str, recipient: str, subject: str, body: str) -> None: | ||||
|     msg = email.mime.text.MIMEText(body) | ||||
|     msg['Subject'] = subject | ||||
|     msg['From'] = sender | ||||
|  | ||||
| @ -1,7 +1,8 @@ | ||||
| import re | ||||
| from typing import Optional | ||||
| 
 | ||||
| 
 | ||||
| def get_mime_type(content): | ||||
| def get_mime_type(content: bytes) -> str: | ||||
|     if not content: | ||||
|         return 'application/octet-stream' | ||||
| 
 | ||||
| @ -26,7 +27,7 @@ def get_mime_type(content): | ||||
|     return 'application/octet-stream' | ||||
| 
 | ||||
| 
 | ||||
| def get_extension(mime_type): | ||||
| def get_extension(mime_type: str) -> Optional[str]: | ||||
|     extension_map = { | ||||
|         'application/x-shockwave-flash': 'swf', | ||||
|         'image/gif': 'gif', | ||||
| @ -39,19 +40,19 @@ def get_extension(mime_type): | ||||
|     return extension_map.get((mime_type or '').strip().lower(), None) | ||||
| 
 | ||||
| 
 | ||||
| def is_flash(mime_type): | ||||
| def is_flash(mime_type: str) -> bool: | ||||
|     return mime_type.lower() == 'application/x-shockwave-flash' | ||||
| 
 | ||||
| 
 | ||||
| def is_video(mime_type): | ||||
| def is_video(mime_type: str) -> bool: | ||||
|     return mime_type.lower() in ('application/ogg', 'video/mp4', 'video/webm') | ||||
| 
 | ||||
| 
 | ||||
| def is_image(mime_type): | ||||
| def is_image(mime_type: str) -> bool: | ||||
|     return mime_type.lower() in ('image/jpeg', 'image/png', 'image/gif') | ||||
| 
 | ||||
| 
 | ||||
| def is_animated_gif(content): | ||||
| def is_animated_gif(content: bytes) -> bool: | ||||
|     pattern = b'\x21\xF9\x04[\x00-\xFF]{4}\x00[\x2C\x21]' | ||||
|     return get_mime_type(content) == 'image/gif' \ | ||||
|         and len(re.findall(pattern, content)) > 1 | ||||
|  | ||||
| @ -2,7 +2,7 @@ import urllib.request | ||||
| from szurubooru import errors | ||||
| 
 | ||||
| 
 | ||||
| def download(url): | ||||
| def download(url: str) -> bytes: | ||||
|     assert url | ||||
|     request = urllib.request.Request(url) | ||||
|     request.add_header('Referer', url) | ||||
|  | ||||
| @ -1,8 +1,10 @@ | ||||
| import datetime | ||||
| import sqlalchemy | ||||
| from szurubooru import config, db, errors | ||||
| from typing import Any, Optional, Tuple, List, Dict, Callable | ||||
| from datetime import datetime | ||||
| import sqlalchemy as sa | ||||
| from szurubooru import config, db, model, errors, rest | ||||
| from szurubooru.func import ( | ||||
|     users, scores, comments, tags, util, mime, images, files, image_hash) | ||||
|     users, scores, comments, tags, util, | ||||
|     mime, images, files, image_hash, serialization) | ||||
| 
 | ||||
| 
 | ||||
| EMPTY_PIXEL = \ | ||||
| @ -20,7 +22,7 @@ class PostAlreadyFeaturedError(errors.ValidationError): | ||||
| 
 | ||||
| 
 | ||||
| class PostAlreadyUploadedError(errors.ValidationError): | ||||
|     def __init__(self, other_post): | ||||
|     def __init__(self, other_post: model.Post) -> None: | ||||
|         super().__init__( | ||||
|             'Post already uploaded (%d)' % other_post.post_id, | ||||
|             { | ||||
| @ -58,30 +60,30 @@ class InvalidPostFlagError(errors.ValidationError): | ||||
| 
 | ||||
| 
 | ||||
| class PostLookalike(image_hash.Lookalike): | ||||
|     def __init__(self, score, distance, post): | ||||
|     def __init__(self, score: int, distance: float, post: model.Post) -> None: | ||||
|         super().__init__(score, distance, post.post_id) | ||||
|         self.post = post | ||||
| 
 | ||||
| 
 | ||||
| SAFETY_MAP = { | ||||
|     db.Post.SAFETY_SAFE: 'safe', | ||||
|     db.Post.SAFETY_SKETCHY: 'sketchy', | ||||
|     db.Post.SAFETY_UNSAFE: 'unsafe', | ||||
|     model.Post.SAFETY_SAFE: 'safe', | ||||
|     model.Post.SAFETY_SKETCHY: 'sketchy', | ||||
|     model.Post.SAFETY_UNSAFE: 'unsafe', | ||||
| } | ||||
| 
 | ||||
| TYPE_MAP = { | ||||
|     db.Post.TYPE_IMAGE: 'image', | ||||
|     db.Post.TYPE_ANIMATION: 'animation', | ||||
|     db.Post.TYPE_VIDEO: 'video', | ||||
|     db.Post.TYPE_FLASH: 'flash', | ||||
|     model.Post.TYPE_IMAGE: 'image', | ||||
|     model.Post.TYPE_ANIMATION: 'animation', | ||||
|     model.Post.TYPE_VIDEO: 'video', | ||||
|     model.Post.TYPE_FLASH: 'flash', | ||||
| } | ||||
| 
 | ||||
| FLAG_MAP = { | ||||
|     db.Post.FLAG_LOOP: 'loop', | ||||
|     model.Post.FLAG_LOOP: 'loop', | ||||
| } | ||||
| 
 | ||||
| 
 | ||||
| def get_post_content_url(post): | ||||
| def get_post_content_url(post: model.Post) -> str: | ||||
|     assert post | ||||
|     return '%s/posts/%d.%s' % ( | ||||
|         config.config['data_url'].rstrip('/'), | ||||
| @ -89,31 +91,31 @@ def get_post_content_url(post): | ||||
|         mime.get_extension(post.mime_type) or 'dat') | ||||
| 
 | ||||
| 
 | ||||
| def get_post_thumbnail_url(post): | ||||
| def get_post_thumbnail_url(post: model.Post) -> str: | ||||
|     assert post | ||||
|     return '%s/generated-thumbnails/%d.jpg' % ( | ||||
|         config.config['data_url'].rstrip('/'), | ||||
|         post.post_id) | ||||
| 
 | ||||
| 
 | ||||
| def get_post_content_path(post): | ||||
| def get_post_content_path(post: model.Post) -> str: | ||||
|     assert post | ||||
|     assert post.post_id | ||||
|     return 'posts/%d.%s' % ( | ||||
|         post.post_id, mime.get_extension(post.mime_type) or 'dat') | ||||
| 
 | ||||
| 
 | ||||
| def get_post_thumbnail_path(post): | ||||
| def get_post_thumbnail_path(post: model.Post) -> str: | ||||
|     assert post | ||||
|     return 'generated-thumbnails/%d.jpg' % (post.post_id) | ||||
| 
 | ||||
| 
 | ||||
| def get_post_thumbnail_backup_path(post): | ||||
| def get_post_thumbnail_backup_path(post: model.Post) -> str: | ||||
|     assert post | ||||
|     return 'posts/custom-thumbnails/%d.dat' % (post.post_id) | ||||
| 
 | ||||
| 
 | ||||
| def serialize_note(note): | ||||
| def serialize_note(note: model.PostNote) -> rest.Response: | ||||
|     assert note | ||||
|     return { | ||||
|         'polygon': note.polygon, | ||||
| @ -121,113 +123,216 @@ def serialize_note(note): | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| def serialize_post(post, auth_user, options=None): | ||||
|     return util.serialize_entity( | ||||
|         post, | ||||
|         { | ||||
|             'id': lambda: post.post_id, | ||||
|             'version': lambda: post.version, | ||||
|             'creationTime': lambda: post.creation_time, | ||||
|             'lastEditTime': lambda: post.last_edit_time, | ||||
|             'safety': lambda: SAFETY_MAP[post.safety], | ||||
|             'source': lambda: post.source, | ||||
|             'type': lambda: TYPE_MAP[post.type], | ||||
|             'mimeType': lambda: post.mime_type, | ||||
|             'checksum': lambda: post.checksum, | ||||
|             'fileSize': lambda: post.file_size, | ||||
|             'canvasWidth': lambda: post.canvas_width, | ||||
|             'canvasHeight': lambda: post.canvas_height, | ||||
|             'contentUrl': lambda: get_post_content_url(post), | ||||
|             'thumbnailUrl': lambda: get_post_thumbnail_url(post), | ||||
|             'flags': lambda: post.flags, | ||||
|             'tags': lambda: [ | ||||
|                 tag.names[0].name for tag in tags.sort_tags(post.tags)], | ||||
|             'relations': lambda: sorted( | ||||
|                 { | ||||
|                     post['id']: | ||||
|                         post for post in [ | ||||
|                             serialize_micro_post(rel, auth_user) | ||||
|                             for rel in post.relations] | ||||
|                 }.values(), | ||||
|                 key=lambda post: post['id']), | ||||
|             'user': lambda: users.serialize_micro_user(post.user, auth_user), | ||||
|             'score': lambda: post.score, | ||||
|             'ownScore': lambda: scores.get_score(post, auth_user), | ||||
|             'ownFavorite': lambda: len([ | ||||
|                 user for user in post.favorited_by | ||||
|                 if user.user_id == auth_user.user_id] | ||||
|             ) > 0, | ||||
|             'tagCount': lambda: post.tag_count, | ||||
|             'favoriteCount': lambda: post.favorite_count, | ||||
|             'commentCount': lambda: post.comment_count, | ||||
|             'noteCount': lambda: post.note_count, | ||||
|             'relationCount': lambda: post.relation_count, | ||||
|             'featureCount': lambda: post.feature_count, | ||||
|             'lastFeatureTime': lambda: post.last_feature_time, | ||||
|             'favoritedBy': lambda: [ | ||||
|                 users.serialize_micro_user(rel.user, auth_user) | ||||
|                 for rel in post.favorited_by | ||||
|             ], | ||||
|             'hasCustomThumbnail': | ||||
|                 lambda: files.has(get_post_thumbnail_backup_path(post)), | ||||
|             'notes': lambda: sorted( | ||||
|                 [serialize_note(note) for note in post.notes], | ||||
|                 key=lambda x: x['polygon']), | ||||
|             'comments': lambda: [ | ||||
|                 comments.serialize_comment(comment, auth_user) | ||||
|                 for comment in sorted( | ||||
|                     post.comments, | ||||
|                     key=lambda comment: comment.creation_time)], | ||||
|         }, | ||||
|         options) | ||||
| class PostSerializer(serialization.BaseSerializer): | ||||
|     def __init__(self, post: model.Post, auth_user: model.User) -> None: | ||||
|         self.post = post | ||||
|         self.auth_user = auth_user | ||||
| 
 | ||||
|     def _serializers(self) -> Dict[str, Callable[[], Any]]: | ||||
|         return { | ||||
|             'id': self.serialize_id, | ||||
|             'version': self.serialize_version, | ||||
|             'creationTime': self.serialize_creation_time, | ||||
|             'lastEditTime': self.serialize_last_edit_time, | ||||
|             'safety': self.serialize_safety, | ||||
|             'source': self.serialize_source, | ||||
|             'type': self.serialize_type, | ||||
|             'mimeType': self.serialize_mime, | ||||
|             'checksum': self.serialize_checksum, | ||||
|             'fileSize': self.serialize_file_size, | ||||
|             'canvasWidth': self.serialize_canvas_width, | ||||
|             'canvasHeight': self.serialize_canvas_height, | ||||
|             'contentUrl': self.serialize_content_url, | ||||
|             'thumbnailUrl': self.serialize_thumbnail_url, | ||||
|             'flags': self.serialize_flags, | ||||
|             'tags': self.serialize_tags, | ||||
|             'relations': self.serialize_relations, | ||||
|             'user': self.serialize_user, | ||||
|             'score': self.serialize_score, | ||||
|             'ownScore': self.serialize_own_score, | ||||
|             'ownFavorite': self.serialize_own_favorite, | ||||
|             'tagCount': self.serialize_tag_count, | ||||
|             'favoriteCount': self.serialize_favorite_count, | ||||
|             'commentCount': self.serialize_comment_count, | ||||
|             'noteCount': self.serialize_note_count, | ||||
|             'relationCount': self.serialize_relation_count, | ||||
|             'featureCount': self.serialize_feature_count, | ||||
|             'lastFeatureTime': self.serialize_last_feature_time, | ||||
|             'favoritedBy': self.serialize_favorited_by, | ||||
|             'hasCustomThumbnail': self.serialize_has_custom_thumbnail, | ||||
|             'notes': self.serialize_notes, | ||||
|             'comments': self.serialize_comments, | ||||
|         } | ||||
| 
 | ||||
|     def serialize_id(self) -> Any: | ||||
|         return self.post.post_id | ||||
| 
 | ||||
|     def serialize_version(self) -> Any: | ||||
|         return self.post.version | ||||
| 
 | ||||
|     def serialize_creation_time(self) -> Any: | ||||
|         return self.post.creation_time | ||||
| 
 | ||||
|     def serialize_last_edit_time(self) -> Any: | ||||
|         return self.post.last_edit_time | ||||
| 
 | ||||
|     def serialize_safety(self) -> Any: | ||||
|         return SAFETY_MAP[self.post.safety] | ||||
| 
 | ||||
|     def serialize_source(self) -> Any: | ||||
|         return self.post.source | ||||
| 
 | ||||
|     def serialize_type(self) -> Any: | ||||
|         return TYPE_MAP[self.post.type] | ||||
| 
 | ||||
|     def serialize_mime(self) -> Any: | ||||
|         return self.post.mime_type | ||||
| 
 | ||||
|     def serialize_checksum(self) -> Any: | ||||
|         return self.post.checksum | ||||
| 
 | ||||
|     def serialize_file_size(self) -> Any: | ||||
|         return self.post.file_size | ||||
| 
 | ||||
|     def serialize_canvas_width(self) -> Any: | ||||
|         return self.post.canvas_width | ||||
| 
 | ||||
|     def serialize_canvas_height(self) -> Any: | ||||
|         return self.post.canvas_height | ||||
| 
 | ||||
|     def serialize_content_url(self) -> Any: | ||||
|         return get_post_content_url(self.post) | ||||
| 
 | ||||
|     def serialize_thumbnail_url(self) -> Any: | ||||
|         return get_post_thumbnail_url(self.post) | ||||
| 
 | ||||
|     def serialize_flags(self) -> Any: | ||||
|         return self.post.flags | ||||
| 
 | ||||
|     def serialize_tags(self) -> Any: | ||||
|         return [tag.names[0].name for tag in tags.sort_tags(self.post.tags)] | ||||
| 
 | ||||
|     def serialize_relations(self) -> Any: | ||||
|         return sorted( | ||||
|             { | ||||
|                 post['id']: post | ||||
|                 for post in [ | ||||
|                     serialize_micro_post(rel, self.auth_user) | ||||
|                     for rel in self.post.relations] | ||||
|             }.values(), | ||||
|             key=lambda post: post['id']) | ||||
| 
 | ||||
|     def serialize_user(self) -> Any: | ||||
|         return users.serialize_micro_user(self.post.user, self.auth_user) | ||||
| 
 | ||||
|     def serialize_score(self) -> Any: | ||||
|         return self.post.score | ||||
| 
 | ||||
|     def serialize_own_score(self) -> Any: | ||||
|         return scores.get_score(self.post, self.auth_user) | ||||
| 
 | ||||
|     def serialize_own_favorite(self) -> Any: | ||||
|         return len([ | ||||
|             user for user in self.post.favorited_by | ||||
|             if user.user_id == self.auth_user.user_id] | ||||
|         ) > 0 | ||||
| 
 | ||||
|     def serialize_tag_count(self) -> Any: | ||||
|         return self.post.tag_count | ||||
| 
 | ||||
|     def serialize_favorite_count(self) -> Any: | ||||
|         return self.post.favorite_count | ||||
| 
 | ||||
|     def serialize_comment_count(self) -> Any: | ||||
|         return self.post.comment_count | ||||
| 
 | ||||
|     def serialize_note_count(self) -> Any: | ||||
|         return self.post.note_count | ||||
| 
 | ||||
|     def serialize_relation_count(self) -> Any: | ||||
|         return self.post.relation_count | ||||
| 
 | ||||
|     def serialize_feature_count(self) -> Any: | ||||
|         return self.post.feature_count | ||||
| 
 | ||||
|     def serialize_last_feature_time(self) -> Any: | ||||
|         return self.post.last_feature_time | ||||
| 
 | ||||
|     def serialize_favorited_by(self) -> Any: | ||||
|         return [ | ||||
|             users.serialize_micro_user(rel.user, self.auth_user) | ||||
|             for rel in self.post.favorited_by | ||||
|         ] | ||||
| 
 | ||||
|     def serialize_has_custom_thumbnail(self) -> Any: | ||||
|         return files.has(get_post_thumbnail_backup_path(self.post)) | ||||
| 
 | ||||
|     def serialize_notes(self) -> Any: | ||||
|         return sorted( | ||||
|             [serialize_note(note) for note in self.post.notes], | ||||
|             key=lambda x: x['polygon']) | ||||
| 
 | ||||
|     def serialize_comments(self) -> Any: | ||||
|         return [ | ||||
|             comments.serialize_comment(comment, self.auth_user) | ||||
|             for comment in sorted( | ||||
|                 self.post.comments, | ||||
|                 key=lambda comment: comment.creation_time)] | ||||
| 
 | ||||
| 
 | ||||
| def serialize_micro_post(post, auth_user): | ||||
| def serialize_post( | ||||
|         post: Optional[model.Post], | ||||
|         auth_user: model.User, | ||||
|         options: List[str]=[]) -> Optional[rest.Response]: | ||||
|     if not post: | ||||
|         return None | ||||
|     return PostSerializer(post, auth_user).serialize(options) | ||||
| 
 | ||||
| 
 | ||||
| def serialize_micro_post( | ||||
|         post: model.Post, auth_user: model.User) -> Optional[rest.Response]: | ||||
|     return serialize_post( | ||||
|         post, | ||||
|         auth_user=auth_user, | ||||
|         options=['id', 'thumbnailUrl']) | ||||
|         post, auth_user=auth_user, options=['id', 'thumbnailUrl']) | ||||
| 
 | ||||
| 
 | ||||
| def get_post_count(): | ||||
|     return db.session.query(sqlalchemy.func.count(db.Post.post_id)).one()[0] | ||||
| def get_post_count() -> int: | ||||
|     return db.session.query(sa.func.count(model.Post.post_id)).one()[0] | ||||
| 
 | ||||
| 
 | ||||
| def try_get_post_by_id(post_id): | ||||
|     try: | ||||
|         post_id = int(post_id) | ||||
|     except ValueError: | ||||
|         raise InvalidPostIdError('Invalid post ID: %r.' % post_id) | ||||
| def try_get_post_by_id(post_id: int) -> Optional[model.Post]: | ||||
|     return db.session \ | ||||
|         .query(db.Post) \ | ||||
|         .filter(db.Post.post_id == post_id) \ | ||||
|         .query(model.Post) \ | ||||
|         .filter(model.Post.post_id == post_id) \ | ||||
|         .one_or_none() | ||||
| 
 | ||||
| 
 | ||||
| def get_post_by_id(post_id): | ||||
| def get_post_by_id(post_id: int) -> model.Post: | ||||
|     post = try_get_post_by_id(post_id) | ||||
|     if not post: | ||||
|         raise PostNotFoundError('Post %r not found.' % post_id) | ||||
|     return post | ||||
| 
 | ||||
| 
 | ||||
| def try_get_current_post_feature(): | ||||
| def try_get_current_post_feature() -> Optional[model.PostFeature]: | ||||
|     return db.session \ | ||||
|         .query(db.PostFeature) \ | ||||
|         .order_by(db.PostFeature.time.desc()) \ | ||||
|         .query(model.PostFeature) \ | ||||
|         .order_by(model.PostFeature.time.desc()) \ | ||||
|         .first() | ||||
| 
 | ||||
| 
 | ||||
| def try_get_featured_post(): | ||||
| def try_get_featured_post() -> Optional[model.Post]: | ||||
|     post_feature = try_get_current_post_feature() | ||||
|     return post_feature.post if post_feature else None | ||||
| 
 | ||||
| 
 | ||||
| def create_post(content, tag_names, user): | ||||
|     post = db.Post() | ||||
|     post.safety = db.Post.SAFETY_SAFE | ||||
| def create_post( | ||||
|         content: bytes, | ||||
|         tag_names: List[str], | ||||
|         user: Optional[model.User]) -> Tuple[model.Post, List[model.Tag]]: | ||||
|     post = model.Post() | ||||
|     post.safety = model.Post.SAFETY_SAFE | ||||
|     post.user = user | ||||
|     post.creation_time = datetime.datetime.utcnow() | ||||
|     post.creation_time = datetime.utcnow() | ||||
|     post.flags = [] | ||||
| 
 | ||||
|     post.type = '' | ||||
| @ -240,7 +345,7 @@ def create_post(content, tag_names, user): | ||||
|     return (post, new_tags) | ||||
| 
 | ||||
| 
 | ||||
| def update_post_safety(post, safety): | ||||
| def update_post_safety(post: model.Post, safety: str) -> None: | ||||
|     assert post | ||||
|     safety = util.flip(SAFETY_MAP).get(safety, None) | ||||
|     if not safety: | ||||
| @ -249,30 +354,33 @@ def update_post_safety(post, safety): | ||||
|     post.safety = safety | ||||
| 
 | ||||
| 
 | ||||
| def update_post_source(post, source): | ||||
| def update_post_source(post: model.Post, source: Optional[str]) -> None: | ||||
|     assert post | ||||
|     if util.value_exceeds_column_size(source, db.Post.source): | ||||
|     if util.value_exceeds_column_size(source, model.Post.source): | ||||
|         raise InvalidPostSourceError('Source is too long.') | ||||
|     post.source = source | ||||
|     post.source = source or None | ||||
| 
 | ||||
| 
 | ||||
| @sqlalchemy.events.event.listens_for(db.Post, 'after_insert') | ||||
| def _after_post_insert(_mapper, _connection, post): | ||||
| @sa.events.event.listens_for(model.Post, 'after_insert') | ||||
| def _after_post_insert( | ||||
|         _mapper: Any, _connection: Any, post: model.Post) -> None: | ||||
|     _sync_post_content(post) | ||||
| 
 | ||||
| 
 | ||||
| @sqlalchemy.events.event.listens_for(db.Post, 'after_update') | ||||
| def _after_post_update(_mapper, _connection, post): | ||||
| @sa.events.event.listens_for(model.Post, 'after_update') | ||||
| def _after_post_update( | ||||
|         _mapper: Any, _connection: Any, post: model.Post) -> None: | ||||
|     _sync_post_content(post) | ||||
| 
 | ||||
| 
 | ||||
| @sqlalchemy.events.event.listens_for(db.Post, 'before_delete') | ||||
| def _before_post_delete(_mapper, _connection, post): | ||||
| @sa.events.event.listens_for(model.Post, 'before_delete') | ||||
| def _before_post_delete( | ||||
|         _mapper: Any, _connection: Any, post: model.Post) -> None: | ||||
|     if post.post_id: | ||||
|         image_hash.delete_image(post.post_id) | ||||
| 
 | ||||
| 
 | ||||
| def _sync_post_content(post): | ||||
| def _sync_post_content(post: model.Post) -> None: | ||||
|     regenerate_thumb = False | ||||
| 
 | ||||
|     if hasattr(post, '__content'): | ||||
| @ -281,7 +389,7 @@ def _sync_post_content(post): | ||||
|         delattr(post, '__content') | ||||
|         regenerate_thumb = True | ||||
|         if post.post_id and post.type in ( | ||||
|                 db.Post.TYPE_IMAGE, db.Post.TYPE_ANIMATION): | ||||
|                 model.Post.TYPE_IMAGE, model.Post.TYPE_ANIMATION): | ||||
|             image_hash.delete_image(post.post_id) | ||||
|             image_hash.add_image(post.post_id, content) | ||||
| 
 | ||||
| @ -299,29 +407,29 @@ def _sync_post_content(post): | ||||
|         generate_post_thumbnail(post) | ||||
| 
 | ||||
| 
 | ||||
| def update_post_content(post, content): | ||||
| def update_post_content(post: model.Post, content: Optional[bytes]) -> None: | ||||
|     assert post | ||||
|     if not content: | ||||
|         raise InvalidPostContentError('Post content missing.') | ||||
|     post.mime_type = mime.get_mime_type(content) | ||||
|     if mime.is_flash(post.mime_type): | ||||
|         post.type = db.Post.TYPE_FLASH | ||||
|         post.type = model.Post.TYPE_FLASH | ||||
|     elif mime.is_image(post.mime_type): | ||||
|         if mime.is_animated_gif(content): | ||||
|             post.type = db.Post.TYPE_ANIMATION | ||||
|             post.type = model.Post.TYPE_ANIMATION | ||||
|         else: | ||||
|             post.type = db.Post.TYPE_IMAGE | ||||
|             post.type = model.Post.TYPE_IMAGE | ||||
|     elif mime.is_video(post.mime_type): | ||||
|         post.type = db.Post.TYPE_VIDEO | ||||
|         post.type = model.Post.TYPE_VIDEO | ||||
|     else: | ||||
|         raise InvalidPostContentError( | ||||
|             'Unhandled file type: %r' % post.mime_type) | ||||
| 
 | ||||
|     post.checksum = util.get_sha1(content) | ||||
|     other_post = db.session \ | ||||
|         .query(db.Post) \ | ||||
|         .filter(db.Post.checksum == post.checksum) \ | ||||
|         .filter(db.Post.post_id != post.post_id) \ | ||||
|         .query(model.Post) \ | ||||
|         .filter(model.Post.checksum == post.checksum) \ | ||||
|         .filter(model.Post.post_id != post.post_id) \ | ||||
|         .one_or_none() | ||||
|     if other_post \ | ||||
|             and other_post.post_id \ | ||||
| @ -343,18 +451,20 @@ def update_post_content(post, content): | ||||
|     setattr(post, '__content', content) | ||||
| 
 | ||||
| 
 | ||||
| def update_post_thumbnail(post, content=None): | ||||
| def update_post_thumbnail( | ||||
|         post: model.Post, content: Optional[bytes]=None) -> None: | ||||
|     assert post | ||||
|     setattr(post, '__thumbnail', content) | ||||
| 
 | ||||
| 
 | ||||
| def generate_post_thumbnail(post): | ||||
| def generate_post_thumbnail(post: model.Post) -> None: | ||||
|     assert post | ||||
|     if files.has(get_post_thumbnail_backup_path(post)): | ||||
|         content = files.get(get_post_thumbnail_backup_path(post)) | ||||
|     else: | ||||
|         content = files.get(get_post_content_path(post)) | ||||
|     try: | ||||
|         assert content | ||||
|         image = images.Image(content) | ||||
|         image.resize_fill( | ||||
|             int(config.config['thumbnails']['post_width']), | ||||
| @ -364,14 +474,15 @@ def generate_post_thumbnail(post): | ||||
|         files.save(get_post_thumbnail_path(post), EMPTY_PIXEL) | ||||
| 
 | ||||
| 
 | ||||
| def update_post_tags(post, tag_names): | ||||
| def update_post_tags( | ||||
|         post: model.Post, tag_names: List[str]) -> List[model.Tag]: | ||||
|     assert post | ||||
|     existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names) | ||||
|     post.tags = existing_tags + new_tags | ||||
|     return new_tags | ||||
| 
 | ||||
| 
 | ||||
| def update_post_relations(post, new_post_ids): | ||||
| def update_post_relations(post: model.Post, new_post_ids: List[int]) -> None: | ||||
|     assert post | ||||
|     try: | ||||
|         new_post_ids = [int(id) for id in new_post_ids] | ||||
| @ -382,8 +493,8 @@ def update_post_relations(post, new_post_ids): | ||||
|     old_post_ids = [int(p.post_id) for p in old_posts] | ||||
|     if new_post_ids: | ||||
|         new_posts = db.session \ | ||||
|             .query(db.Post) \ | ||||
|             .filter(db.Post.post_id.in_(new_post_ids)) \ | ||||
|             .query(model.Post) \ | ||||
|             .filter(model.Post.post_id.in_(new_post_ids)) \ | ||||
|             .all() | ||||
|     else: | ||||
|         new_posts = [] | ||||
| @ -402,7 +513,7 @@ def update_post_relations(post, new_post_ids): | ||||
|         relation.relations.append(post) | ||||
| 
 | ||||
| 
 | ||||
| def update_post_notes(post, notes): | ||||
| def update_post_notes(post: model.Post, notes: Any) -> None: | ||||
|     assert post | ||||
|     post.notes = [] | ||||
|     for note in notes: | ||||
| @ -433,13 +544,13 @@ def update_post_notes(post, notes): | ||||
|             except ValueError: | ||||
|                 raise InvalidPostNoteError( | ||||
|                     'A point in note\'s polygon must be numeric.') | ||||
|         if util.value_exceeds_column_size(note['text'], db.PostNote.text): | ||||
|         if util.value_exceeds_column_size(note['text'], model.PostNote.text): | ||||
|             raise InvalidPostNoteError('Note text is too long.') | ||||
|         post.notes.append( | ||||
|             db.PostNote(polygon=note['polygon'], text=str(note['text']))) | ||||
|             model.PostNote(polygon=note['polygon'], text=str(note['text']))) | ||||
| 
 | ||||
| 
 | ||||
| def update_post_flags(post, flags): | ||||
| def update_post_flags(post: model.Post, flags: List[str]) -> None: | ||||
|     assert post | ||||
|     target_flags = [] | ||||
|     for flag in flags: | ||||
| @ -451,88 +562,95 @@ def update_post_flags(post, flags): | ||||
|     post.flags = target_flags | ||||
| 
 | ||||
| 
 | ||||
| def feature_post(post, user): | ||||
| def feature_post(post: model.Post, user: Optional[model.User]) -> None: | ||||
|     assert post | ||||
|     post_feature = db.PostFeature() | ||||
|     post_feature.time = datetime.datetime.utcnow() | ||||
|     post_feature = model.PostFeature() | ||||
|     post_feature.time = datetime.utcnow() | ||||
|     post_feature.post = post | ||||
|     post_feature.user = user | ||||
|     db.session.add(post_feature) | ||||
| 
 | ||||
| 
 | ||||
| def delete(post): | ||||
| def delete(post: model.Post) -> None: | ||||
|     assert post | ||||
|     db.session.delete(post) | ||||
| 
 | ||||
| 
 | ||||
| def merge_posts(source_post, target_post, replace_content): | ||||
| def merge_posts( | ||||
|         source_post: model.Post, | ||||
|         target_post: model.Post, | ||||
|         replace_content: bool) -> None: | ||||
|     assert source_post | ||||
|     assert target_post | ||||
|     if source_post.post_id == target_post.post_id: | ||||
|         raise InvalidPostRelationError('Cannot merge post with itself.') | ||||
| 
 | ||||
|     def merge_tables(table, anti_dup_func, source_post_id, target_post_id): | ||||
|     def merge_tables( | ||||
|             table: model.Base, | ||||
|             anti_dup_func: Optional[Callable[[model.Base, model.Base], bool]], | ||||
|             source_post_id: int, | ||||
|             target_post_id: int) -> None: | ||||
|         alias1 = table | ||||
|         alias2 = sqlalchemy.orm.util.aliased(table) | ||||
|         alias2 = sa.orm.util.aliased(table) | ||||
|         update_stmt = ( | ||||
|             sqlalchemy.sql.expression.update(alias1) | ||||
|             sa.sql.expression.update(alias1) | ||||
|             .where(alias1.post_id == source_post_id)) | ||||
| 
 | ||||
|         if anti_dup_func is not None: | ||||
|             update_stmt = ( | ||||
|                 update_stmt | ||||
|                 .where( | ||||
|                     ~sqlalchemy.exists() | ||||
|                     ~sa.exists() | ||||
|                     .where(anti_dup_func(alias1, alias2)) | ||||
|                     .where(alias2.post_id == target_post_id))) | ||||
| 
 | ||||
|         update_stmt = update_stmt.values(post_id=target_post_id) | ||||
|         db.session.execute(update_stmt) | ||||
| 
 | ||||
|     def merge_tags(source_post_id, target_post_id): | ||||
|     def merge_tags(source_post_id: int, target_post_id: int) -> None: | ||||
|         merge_tables( | ||||
|             db.PostTag, | ||||
|             model.PostTag, | ||||
|             lambda alias1, alias2: alias1.tag_id == alias2.tag_id, | ||||
|             source_post_id, | ||||
|             target_post_id) | ||||
| 
 | ||||
|     def merge_scores(source_post_id, target_post_id): | ||||
|     def merge_scores(source_post_id: int, target_post_id: int) -> None: | ||||
|         merge_tables( | ||||
|             db.PostScore, | ||||
|             model.PostScore, | ||||
|             lambda alias1, alias2: alias1.user_id == alias2.user_id, | ||||
|             source_post_id, | ||||
|             target_post_id) | ||||
| 
 | ||||
|     def merge_favorites(source_post_id, target_post_id): | ||||
|     def merge_favorites(source_post_id: int, target_post_id: int) -> None: | ||||
|         merge_tables( | ||||
|             db.PostFavorite, | ||||
|             model.PostFavorite, | ||||
|             lambda alias1, alias2: alias1.user_id == alias2.user_id, | ||||
|             source_post_id, | ||||
|             target_post_id) | ||||
| 
 | ||||
|     def merge_comments(source_post_id, target_post_id): | ||||
|         merge_tables(db.Comment, None, source_post_id, target_post_id) | ||||
|     def merge_comments(source_post_id: int, target_post_id: int) -> None: | ||||
|         merge_tables(model.Comment, None, source_post_id, target_post_id) | ||||
| 
 | ||||
|     def merge_relations(source_post_id, target_post_id): | ||||
|         alias1 = db.PostRelation | ||||
|         alias2 = sqlalchemy.orm.util.aliased(db.PostRelation) | ||||
|     def merge_relations(source_post_id: int, target_post_id: int) -> None: | ||||
|         alias1 = model.PostRelation | ||||
|         alias2 = sa.orm.util.aliased(model.PostRelation) | ||||
|         update_stmt = ( | ||||
|             sqlalchemy.sql.expression.update(alias1) | ||||
|             sa.sql.expression.update(alias1) | ||||
|             .where(alias1.parent_id == source_post_id) | ||||
|             .where(alias1.child_id != target_post_id) | ||||
|             .where( | ||||
|                 ~sqlalchemy.exists() | ||||
|                 ~sa.exists() | ||||
|                 .where(alias2.child_id == alias1.child_id) | ||||
|                 .where(alias2.parent_id == target_post_id)) | ||||
|             .values(parent_id=target_post_id)) | ||||
|         db.session.execute(update_stmt) | ||||
| 
 | ||||
|         update_stmt = ( | ||||
|             sqlalchemy.sql.expression.update(alias1) | ||||
|             sa.sql.expression.update(alias1) | ||||
|             .where(alias1.child_id == source_post_id) | ||||
|             .where(alias1.parent_id != target_post_id) | ||||
|             .where( | ||||
|                 ~sqlalchemy.exists() | ||||
|                 ~sa.exists() | ||||
|                 .where(alias2.parent_id == alias1.parent_id) | ||||
|                 .where(alias2.child_id == target_post_id)) | ||||
|             .values(child_id=target_post_id)) | ||||
| @ -553,15 +671,15 @@ def merge_posts(source_post, target_post, replace_content): | ||||
|         update_post_content(target_post, content) | ||||
| 
 | ||||
| 
 | ||||
| def search_by_image_exact(image_content): | ||||
| def search_by_image_exact(image_content: bytes) -> Optional[model.Post]: | ||||
|     checksum = util.get_sha1(image_content) | ||||
|     return db.session \ | ||||
|         .query(db.Post) \ | ||||
|         .filter(db.Post.checksum == checksum) \ | ||||
|         .query(model.Post) \ | ||||
|         .filter(model.Post.checksum == checksum) \ | ||||
|         .one_or_none() | ||||
| 
 | ||||
| 
 | ||||
| def search_by_image(image_content): | ||||
| def search_by_image(image_content: bytes) -> List[PostLookalike]: | ||||
|     ret = [] | ||||
|     for result in image_hash.search_by_image(image_content): | ||||
|         ret.append(PostLookalike( | ||||
| @ -571,24 +689,24 @@ def search_by_image(image_content): | ||||
|     return ret | ||||
| 
 | ||||
| 
 | ||||
| def populate_reverse_search(): | ||||
| def populate_reverse_search() -> None: | ||||
|     excluded_post_ids = image_hash.get_all_paths() | ||||
| 
 | ||||
|     post_ids_to_hash = ( | ||||
|         db.session | ||||
|         .query(db.Post.post_id) | ||||
|         .query(model.Post.post_id) | ||||
|         .filter( | ||||
|             (db.Post.type == db.Post.TYPE_IMAGE) | | ||||
|             (db.Post.type == db.Post.TYPE_ANIMATION)) | ||||
|         .filter(~db.Post.post_id.in_(excluded_post_ids)) | ||||
|         .order_by(db.Post.post_id.asc()) | ||||
|             (model.Post.type == model.Post.TYPE_IMAGE) | | ||||
|             (model.Post.type == model.Post.TYPE_ANIMATION)) | ||||
|         .filter(~model.Post.post_id.in_(excluded_post_ids)) | ||||
|         .order_by(model.Post.post_id.asc()) | ||||
|         .all()) | ||||
| 
 | ||||
|     for post_ids_chunk in util.chunks(post_ids_to_hash, 100): | ||||
|         posts_chunk = ( | ||||
|             db.session | ||||
|             .query(db.Post) | ||||
|             .filter(db.Post.post_id.in_(post_ids_chunk)) | ||||
|             .query(model.Post) | ||||
|             .filter(model.Post.post_id.in_(post_ids_chunk)) | ||||
|             .all()) | ||||
|         for post in posts_chunk: | ||||
|             content_path = get_post_content_path(post) | ||||
|  | ||||
| @ -1,5 +1,6 @@ | ||||
| import datetime | ||||
| from szurubooru import db, errors | ||||
| from typing import Any, Tuple, Callable | ||||
| from szurubooru import db, model, errors | ||||
| 
 | ||||
| 
 | ||||
| class InvalidScoreTargetError(errors.ValidationError): | ||||
| @ -10,22 +11,23 @@ class InvalidScoreValueError(errors.ValidationError): | ||||
|     pass | ||||
| 
 | ||||
| 
 | ||||
| def _get_table_info(entity): | ||||
| def _get_table_info( | ||||
|         entity: model.Base) -> Tuple[model.Base, Callable[[model.Base], Any]]: | ||||
|     assert entity | ||||
|     resource_type, _, _ = db.util.get_resource_info(entity) | ||||
|     resource_type, _, _ = model.util.get_resource_info(entity) | ||||
|     if resource_type == 'post': | ||||
|         return db.PostScore, lambda table: table.post_id | ||||
|         return model.PostScore, lambda table: table.post_id | ||||
|     elif resource_type == 'comment': | ||||
|         return db.CommentScore, lambda table: table.comment_id | ||||
|         return model.CommentScore, lambda table: table.comment_id | ||||
|     raise InvalidScoreTargetError() | ||||
| 
 | ||||
| 
 | ||||
| def _get_score_entity(entity, user): | ||||
| def _get_score_entity(entity: model.Base, user: model.User) -> model.Base: | ||||
|     assert user | ||||
|     return db.util.get_aux_entity(db.session, _get_table_info, entity, user) | ||||
|     return model.util.get_aux_entity(db.session, _get_table_info, entity, user) | ||||
| 
 | ||||
| 
 | ||||
| def delete_score(entity, user): | ||||
| def delete_score(entity: model.Base, user: model.User) -> None: | ||||
|     assert entity | ||||
|     assert user | ||||
|     score_entity = _get_score_entity(entity, user) | ||||
| @ -33,7 +35,7 @@ def delete_score(entity, user): | ||||
|         db.session.delete(score_entity) | ||||
| 
 | ||||
| 
 | ||||
| def get_score(entity, user): | ||||
| def get_score(entity: model.Base, user: model.User) -> int: | ||||
|     assert entity | ||||
|     assert user | ||||
|     table, get_column = _get_table_info(entity) | ||||
| @ -45,7 +47,7 @@ def get_score(entity, user): | ||||
|     return row[0] if row else 0 | ||||
| 
 | ||||
| 
 | ||||
| def set_score(entity, user, score): | ||||
| def set_score(entity: model.Base, user: model.User, score: int) -> None: | ||||
|     from szurubooru.func import favorites | ||||
|     assert entity | ||||
|     assert user | ||||
|  | ||||
							
								
								
									
										27
									
								
								server/szurubooru/func/serialization.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										27
									
								
								server/szurubooru/func/serialization.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,27 @@ | ||||
| from typing import Any, Optional, List, Dict, Callable | ||||
| from szurubooru import db, model, rest, errors | ||||
| 
 | ||||
| 
 | ||||
| def get_serialization_options(ctx: rest.Context) -> List[str]: | ||||
|     return ctx.get_param_as_list('fields', default=[]) | ||||
| 
 | ||||
| 
 | ||||
| class BaseSerializer: | ||||
|     _fields = {}  # type: Dict[str, Callable[[model.Base], Any]] | ||||
| 
 | ||||
|     def serialize(self, options: List[str]) -> Any: | ||||
|         field_factories = self._serializers() | ||||
|         if not options: | ||||
|             options = list(field_factories.keys()) | ||||
|         ret = {} | ||||
|         for key in options: | ||||
|             if key not in field_factories: | ||||
|                 raise errors.ValidationError( | ||||
|                     'Invalid key: %r. Valid keys: %r.' % ( | ||||
|                         key, list(sorted(field_factories.keys())))) | ||||
|             factory = field_factories[key] | ||||
|             ret[key] = factory() | ||||
|         return ret | ||||
| 
 | ||||
|     def _serializers(self) -> Dict[str, Callable[[], Any]]: | ||||
|         raise NotImplementedError() | ||||
| @ -1,9 +1,10 @@ | ||||
| from typing import Any, Optional, Dict, Callable | ||||
| from datetime import datetime | ||||
| from szurubooru import db | ||||
| from szurubooru import db, model | ||||
| from szurubooru.func import diff, users | ||||
| 
 | ||||
| 
 | ||||
| def get_tag_category_snapshot(category): | ||||
| def get_tag_category_snapshot(category: model.TagCategory) -> Dict[str, Any]: | ||||
|     assert category | ||||
|     return { | ||||
|         'name': category.name, | ||||
| @ -12,7 +13,7 @@ def get_tag_category_snapshot(category): | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| def get_tag_snapshot(tag): | ||||
| def get_tag_snapshot(tag: model.Tag) -> Dict[str, Any]: | ||||
|     assert tag | ||||
|     return { | ||||
|         'names': [tag_name.name for tag_name in tag.names], | ||||
| @ -22,7 +23,7 @@ def get_tag_snapshot(tag): | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| def get_post_snapshot(post): | ||||
| def get_post_snapshot(post: model.Post) -> Dict[str, Any]: | ||||
|     assert post | ||||
|     return { | ||||
|         'source': post.source, | ||||
| @ -45,10 +46,11 @@ _snapshot_factories = { | ||||
|     'tag_category': lambda entity: get_tag_category_snapshot(entity), | ||||
|     'tag': lambda entity: get_tag_snapshot(entity), | ||||
|     'post': lambda entity: get_post_snapshot(entity), | ||||
| } | ||||
| }  # type: Dict[model.Base, Callable[[model.Base], Dict[str ,Any]]] | ||||
| 
 | ||||
| 
 | ||||
| def serialize_snapshot(snapshot, auth_user): | ||||
| def serialize_snapshot( | ||||
|         snapshot: model.Snapshot, auth_user: model.User) -> Dict[str, Any]: | ||||
|     assert snapshot | ||||
|     return { | ||||
|         'operation': snapshot.operation, | ||||
| @ -60,11 +62,14 @@ def serialize_snapshot(snapshot, auth_user): | ||||
|     } | ||||
| 
 | ||||
| 
 | ||||
| def _create(operation, entity, auth_user): | ||||
| def _create( | ||||
|         operation: str, | ||||
|         entity: model.Base, | ||||
|         auth_user: Optional[model.User]) -> model.Snapshot: | ||||
|     resource_type, resource_pkey, resource_name = ( | ||||
|         db.util.get_resource_info(entity)) | ||||
|         model.util.get_resource_info(entity)) | ||||
| 
 | ||||
|     snapshot = db.Snapshot() | ||||
|     snapshot = model.Snapshot() | ||||
|     snapshot.creation_time = datetime.utcnow() | ||||
|     snapshot.operation = operation | ||||
|     snapshot.resource_type = resource_type | ||||
| @ -74,33 +79,33 @@ def _create(operation, entity, auth_user): | ||||
|     return snapshot | ||||
| 
 | ||||
| 
 | ||||
| def create(entity, auth_user): | ||||
| def create(entity: model.Base, auth_user: Optional[model.User]) -> None: | ||||
|     assert entity | ||||
|     snapshot = _create(db.Snapshot.OPERATION_CREATED, entity, auth_user) | ||||
|     snapshot = _create(model.Snapshot.OPERATION_CREATED, entity, auth_user) | ||||
|     snapshot_factory = _snapshot_factories[snapshot.resource_type] | ||||
|     snapshot.data = snapshot_factory(entity) | ||||
|     db.session.add(snapshot) | ||||
| 
 | ||||
| 
 | ||||
| # pylint: disable=protected-access | ||||
| def modify(entity, auth_user): | ||||
| def modify(entity: model.Base, auth_user: Optional[model.User]) -> None: | ||||
|     assert entity | ||||
| 
 | ||||
|     model = next( | ||||
|     table = next( | ||||
|         ( | ||||
|             model | ||||
|             for model in db.Base._decl_class_registry.values() | ||||
|             if hasattr(model, '__table__') | ||||
|             and model.__table__.fullname == entity.__table__.fullname | ||||
|             cls | ||||
|             for cls in model.Base._decl_class_registry.values() | ||||
|             if hasattr(cls, '__table__') | ||||
|             and cls.__table__.fullname == entity.__table__.fullname | ||||
|         ), | ||||
|         None) | ||||
|     assert model | ||||
|     assert table | ||||
| 
 | ||||
|     snapshot = _create(db.Snapshot.OPERATION_MODIFIED, entity, auth_user) | ||||
|     snapshot = _create(model.Snapshot.OPERATION_MODIFIED, entity, auth_user) | ||||
|     snapshot_factory = _snapshot_factories[snapshot.resource_type] | ||||
| 
 | ||||
|     detached_session = db.sessionmaker() | ||||
|     detached_entity = detached_session.query(model).get(snapshot.resource_pkey) | ||||
|     detached_entity = detached_session.query(table).get(snapshot.resource_pkey) | ||||
|     assert detached_entity, 'Entity not found in DB, have you committed it?' | ||||
|     detached_snapshot = snapshot_factory(detached_entity) | ||||
|     detached_session.close() | ||||
| @ -113,19 +118,23 @@ def modify(entity, auth_user): | ||||
|     db.session.add(snapshot) | ||||
| 
 | ||||
| 
 | ||||
| def delete(entity, auth_user): | ||||
| def delete(entity: model.Base, auth_user: Optional[model.User]) -> None: | ||||
|     assert entity | ||||
|     snapshot = _create(db.Snapshot.OPERATION_DELETED, entity, auth_user) | ||||
|     snapshot = _create(model.Snapshot.OPERATION_DELETED, entity, auth_user) | ||||
|     snapshot_factory = _snapshot_factories[snapshot.resource_type] | ||||
|     snapshot.data = snapshot_factory(entity) | ||||
|     db.session.add(snapshot) | ||||
| 
 | ||||
| 
 | ||||
| def merge(source_entity, target_entity, auth_user): | ||||
| def merge( | ||||
|         source_entity: model.Base, | ||||
|         target_entity: model.Base, | ||||
|         auth_user: Optional[model.User]) -> None: | ||||
|     assert source_entity | ||||
|     assert target_entity | ||||
|     snapshot = _create(db.Snapshot.OPERATION_MERGED, source_entity, auth_user) | ||||
|     snapshot = _create( | ||||
|         model.Snapshot.OPERATION_MERGED, source_entity, auth_user) | ||||
|     resource_type, _resource_pkey, resource_name = ( | ||||
|         db.util.get_resource_info(target_entity)) | ||||
|         model.util.get_resource_info(target_entity)) | ||||
|     snapshot.data = [resource_type, resource_name] | ||||
|     db.session.add(snapshot) | ||||
|  | ||||
| @ -1,7 +1,8 @@ | ||||
| import re | ||||
| import sqlalchemy | ||||
| from szurubooru import config, db, errors | ||||
| from szurubooru.func import util, cache | ||||
| from typing import Any, Optional, Dict, List, Callable | ||||
| import sqlalchemy as sa | ||||
| from szurubooru import config, db, model, errors, rest | ||||
| from szurubooru.func import util, serialization, cache | ||||
| 
 | ||||
| 
 | ||||
| DEFAULT_CATEGORY_NAME_CACHE_KEY = 'default-tag-category' | ||||
| @ -27,28 +28,52 @@ class InvalidTagCategoryColorError(errors.ValidationError): | ||||
|     pass | ||||
| 
 | ||||
| 
 | ||||
| def _verify_name_validity(name): | ||||
| def _verify_name_validity(name: str) -> None: | ||||
|     name_regex = config.config['tag_category_name_regex'] | ||||
|     if not re.match(name_regex, name): | ||||
|         raise InvalidTagCategoryNameError( | ||||
|             'Name must satisfy regex %r.' % name_regex) | ||||
| 
 | ||||
| 
 | ||||
| def serialize_category(category, options=None): | ||||
|     return util.serialize_entity( | ||||
|         category, | ||||
|         { | ||||
|             'name': lambda: category.name, | ||||
|             'version': lambda: category.version, | ||||
|             'color': lambda: category.color, | ||||
|             'usages': lambda: category.tag_count, | ||||
|             'default': lambda: category.default, | ||||
|         }, | ||||
|         options) | ||||
| class TagCategorySerializer(serialization.BaseSerializer): | ||||
|     def __init__(self, category: model.TagCategory) -> None: | ||||
|         self.category = category | ||||
| 
 | ||||
|     def _serializers(self) -> Dict[str, Callable[[], Any]]: | ||||
|         return { | ||||
|             'name': self.serialize_name, | ||||
|             'version': self.serialize_version, | ||||
|             'color': self.serialize_color, | ||||
|             'usages': self.serialize_usages, | ||||
|             'default': self.serialize_default, | ||||
|         } | ||||
| 
 | ||||
|     def serialize_name(self) -> Any: | ||||
|         return self.category.name | ||||
| 
 | ||||
|     def serialize_version(self) -> Any: | ||||
|         return self.category.version | ||||
| 
 | ||||
|     def serialize_color(self) -> Any: | ||||
|         return self.category.color | ||||
| 
 | ||||
|     def serialize_usages(self) -> Any: | ||||
|         return self.category.tag_count | ||||
| 
 | ||||
|     def serialize_default(self) -> Any: | ||||
|         return self.category.default | ||||
| 
 | ||||
| 
 | ||||
| def create_category(name, color): | ||||
|     category = db.TagCategory() | ||||
| def serialize_category( | ||||
|         category: Optional[model.TagCategory], | ||||
|         options: List[str]=[]) -> Optional[rest.Response]: | ||||
|     if not category: | ||||
|         return None | ||||
|     return TagCategorySerializer(category).serialize(options) | ||||
| 
 | ||||
| 
 | ||||
| def create_category(name: str, color: str) -> model.TagCategory: | ||||
|     category = model.TagCategory() | ||||
|     update_category_name(category, name) | ||||
|     update_category_color(category, color) | ||||
|     if not get_all_categories(): | ||||
| @ -56,64 +81,66 @@ def create_category(name, color): | ||||
|     return category | ||||
| 
 | ||||
| 
 | ||||
| def update_category_name(category, name): | ||||
| def update_category_name(category: model.TagCategory, name: str) -> None: | ||||
|     assert category | ||||
|     if not name: | ||||
|         raise InvalidTagCategoryNameError('Name cannot be empty.') | ||||
|     expr = sqlalchemy.func.lower(db.TagCategory.name) == name.lower() | ||||
|     expr = sa.func.lower(model.TagCategory.name) == name.lower() | ||||
|     if category.tag_category_id: | ||||
|         expr = expr & ( | ||||
|             db.TagCategory.tag_category_id != category.tag_category_id) | ||||
|     already_exists = db.session.query(db.TagCategory).filter(expr).count() > 0 | ||||
|             model.TagCategory.tag_category_id != category.tag_category_id) | ||||
|     already_exists = ( | ||||
|         db.session.query(model.TagCategory).filter(expr).count() > 0) | ||||
|     if already_exists: | ||||
|         raise TagCategoryAlreadyExistsError( | ||||
|             'A category with this name already exists.') | ||||
|     if util.value_exceeds_column_size(name, db.TagCategory.name): | ||||
|     if util.value_exceeds_column_size(name, model.TagCategory.name): | ||||
|         raise InvalidTagCategoryNameError('Name is too long.') | ||||
|     _verify_name_validity(name) | ||||
|     category.name = name | ||||
|     cache.remove(DEFAULT_CATEGORY_NAME_CACHE_KEY) | ||||
| 
 | ||||
| 
 | ||||
| def update_category_color(category, color): | ||||
| def update_category_color(category: model.TagCategory, color: str) -> None: | ||||
|     assert category | ||||
|     if not color: | ||||
|         raise InvalidTagCategoryColorError('Color cannot be empty.') | ||||
|     if not re.match(r'^#?[0-9a-z]+$', color): | ||||
|         raise InvalidTagCategoryColorError('Invalid color.') | ||||
|     if util.value_exceeds_column_size(color, db.TagCategory.color): | ||||
|     if util.value_exceeds_column_size(color, model.TagCategory.color): | ||||
|         raise InvalidTagCategoryColorError('Color is too long.') | ||||
|     category.color = color | ||||
| 
 | ||||
| 
 | ||||
| def try_get_category_by_name(name, lock=False): | ||||
| def try_get_category_by_name( | ||||
|         name: str, lock: bool=False) -> Optional[model.TagCategory]: | ||||
|     query = db.session \ | ||||
|         .query(db.TagCategory) \ | ||||
|         .filter(sqlalchemy.func.lower(db.TagCategory.name) == name.lower()) | ||||
|         .query(model.TagCategory) \ | ||||
|         .filter(sa.func.lower(model.TagCategory.name) == name.lower()) | ||||
|     if lock: | ||||
|         query = query.with_lockmode('update') | ||||
|     return query.one_or_none() | ||||
| 
 | ||||
| 
 | ||||
| def get_category_by_name(name, lock=False): | ||||
| def get_category_by_name(name: str, lock: bool=False) -> model.TagCategory: | ||||
|     category = try_get_category_by_name(name, lock) | ||||
|     if not category: | ||||
|         raise TagCategoryNotFoundError('Tag category %r not found.' % name) | ||||
|     return category | ||||
| 
 | ||||
| 
 | ||||
| def get_all_category_names(): | ||||
|     return [row[0] for row in db.session.query(db.TagCategory.name).all()] | ||||
| def get_all_category_names() -> List[str]: | ||||
|     return [row[0] for row in db.session.query(model.TagCategory.name).all()] | ||||
| 
 | ||||
| 
 | ||||
| def get_all_categories(): | ||||
|     return db.session.query(db.TagCategory).all() | ||||
| def get_all_categories() -> List[model.TagCategory]: | ||||
|     return db.session.query(model.TagCategory).all() | ||||
| 
 | ||||
| 
 | ||||
| def try_get_default_category(lock=False): | ||||
| def try_get_default_category(lock: bool=False) -> Optional[model.TagCategory]: | ||||
|     query = db.session \ | ||||
|         .query(db.TagCategory) \ | ||||
|         .filter(db.TagCategory.default) | ||||
|         .query(model.TagCategory) \ | ||||
|         .filter(model.TagCategory.default) | ||||
|     if lock: | ||||
|         query = query.with_lockmode('update') | ||||
|     category = query.first() | ||||
| @ -121,22 +148,22 @@ def try_get_default_category(lock=False): | ||||
|     # category, get the first record available. | ||||
|     if not category: | ||||
|         query = db.session \ | ||||
|             .query(db.TagCategory) \ | ||||
|             .order_by(db.TagCategory.tag_category_id.asc()) | ||||
|             .query(model.TagCategory) \ | ||||
|             .order_by(model.TagCategory.tag_category_id.asc()) | ||||
|         if lock: | ||||
|             query = query.with_lockmode('update') | ||||
|         category = query.first() | ||||
|     return category | ||||
| 
 | ||||
| 
 | ||||
| def get_default_category(lock=False): | ||||
| def get_default_category(lock: bool=False) -> model.TagCategory: | ||||
|     category = try_get_default_category(lock) | ||||
|     if not category: | ||||
|         raise TagCategoryNotFoundError('No tag category created yet.') | ||||
|     return category | ||||
| 
 | ||||
| 
 | ||||
| def get_default_category_name(): | ||||
| def get_default_category_name() -> str: | ||||
|     if cache.has(DEFAULT_CATEGORY_NAME_CACHE_KEY): | ||||
|         return cache.get(DEFAULT_CATEGORY_NAME_CACHE_KEY) | ||||
|     default_category = get_default_category() | ||||
| @ -145,7 +172,7 @@ def get_default_category_name(): | ||||
|     return default_category_name | ||||
| 
 | ||||
| 
 | ||||
| def set_default_category(category): | ||||
| def set_default_category(category: model.TagCategory) -> None: | ||||
|     assert category | ||||
|     old_category = try_get_default_category(lock=True) | ||||
|     if old_category: | ||||
| @ -156,7 +183,7 @@ def set_default_category(category): | ||||
|     cache.remove(DEFAULT_CATEGORY_NAME_CACHE_KEY) | ||||
| 
 | ||||
| 
 | ||||
| def delete_category(category): | ||||
| def delete_category(category: model.TagCategory) -> None: | ||||
|     assert category | ||||
|     if len(get_all_category_names()) == 1: | ||||
|         raise TagCategoryIsInUseError('Cannot delete the last category.') | ||||
|  | ||||
| @ -1,10 +1,11 @@ | ||||
| import datetime | ||||
| import json | ||||
| import os | ||||
| import re | ||||
| import sqlalchemy | ||||
| from szurubooru import config, db, errors | ||||
| from szurubooru.func import util, tag_categories | ||||
| from typing import Any, Optional, Tuple, List, Dict, Callable | ||||
| from datetime import datetime | ||||
| import sqlalchemy as sa | ||||
| from szurubooru import config, db, model, errors, rest | ||||
| from szurubooru.func import util, tag_categories, serialization | ||||
| 
 | ||||
| 
 | ||||
| class TagNotFoundError(errors.NotFoundError): | ||||
| @ -35,31 +36,32 @@ class InvalidTagDescriptionError(errors.ValidationError): | ||||
|     pass | ||||
| 
 | ||||
| 
 | ||||
| def _verify_name_validity(name): | ||||
|     if util.value_exceeds_column_size(name, db.TagName.name): | ||||
| def _verify_name_validity(name: str) -> None: | ||||
|     if util.value_exceeds_column_size(name, model.TagName.name): | ||||
|         raise InvalidTagNameError('Name is too long.') | ||||
|     name_regex = config.config['tag_name_regex'] | ||||
|     if not re.match(name_regex, name): | ||||
|         raise InvalidTagNameError('Name must satisfy regex %r.' % name_regex) | ||||
| 
 | ||||
| 
 | ||||
| def _get_names(tag): | ||||
| def _get_names(tag: model.Tag) -> List[str]: | ||||
|     assert tag | ||||
|     return [tag_name.name for tag_name in tag.names] | ||||
| 
 | ||||
| 
 | ||||
| def _lower_list(names): | ||||
| def _lower_list(names: List[str]) -> List[str]: | ||||
|     return [name.lower() for name in names] | ||||
| 
 | ||||
| 
 | ||||
| def _check_name_intersection(names1, names2, case_sensitive): | ||||
| def _check_name_intersection( | ||||
|         names1: List[str], names2: List[str], case_sensitive: bool) -> bool: | ||||
|     if not case_sensitive: | ||||
|         names1 = _lower_list(names1) | ||||
|         names2 = _lower_list(names2) | ||||
|     return len(set(names1).intersection(names2)) > 0 | ||||
| 
 | ||||
| 
 | ||||
| def sort_tags(tags): | ||||
| def sort_tags(tags: List[model.Tag]) -> List[model.Tag]: | ||||
|     default_category_name = tag_categories.get_default_category_name() | ||||
|     return sorted( | ||||
|         tags, | ||||
| @ -70,35 +72,70 @@ def sort_tags(tags): | ||||
|     ) | ||||
| 
 | ||||
| 
 | ||||
| def serialize_tag(tag, options=None): | ||||
|     return util.serialize_entity( | ||||
|         tag, | ||||
|         { | ||||
|             'names': lambda: [tag_name.name for tag_name in tag.names], | ||||
|             'category': lambda: tag.category.name, | ||||
|             'version': lambda: tag.version, | ||||
|             'description': lambda: tag.description, | ||||
|             'creationTime': lambda: tag.creation_time, | ||||
|             'lastEditTime': lambda: tag.last_edit_time, | ||||
|             'usages': lambda: tag.post_count, | ||||
|             'suggestions': lambda: [ | ||||
|                 relation.names[0].name | ||||
|                 for relation in sort_tags(tag.suggestions)], | ||||
|             'implications': lambda: [ | ||||
|                 relation.names[0].name | ||||
|                 for relation in sort_tags(tag.implications)], | ||||
|         }, | ||||
|         options) | ||||
| class TagSerializer(serialization.BaseSerializer): | ||||
|     def __init__(self, tag: model.Tag) -> None: | ||||
|         self.tag = tag | ||||
| 
 | ||||
|     def _serializers(self) -> Dict[str, Callable[[], Any]]: | ||||
|         return { | ||||
|             'names': self.serialize_names, | ||||
|             'category': self.serialize_category, | ||||
|             'version': self.serialize_version, | ||||
|             'description': self.serialize_description, | ||||
|             'creationTime': self.serialize_creation_time, | ||||
|             'lastEditTime': self.serialize_last_edit_time, | ||||
|             'usages': self.serialize_usages, | ||||
|             'suggestions': self.serialize_suggestions, | ||||
|             'implications': self.serialize_implications, | ||||
|         } | ||||
| 
 | ||||
|     def serialize_names(self) -> Any: | ||||
|         return [tag_name.name for tag_name in self.tag.names] | ||||
| 
 | ||||
|     def serialize_category(self) -> Any: | ||||
|         return self.tag.category.name | ||||
| 
 | ||||
|     def serialize_version(self) -> Any: | ||||
|         return self.tag.version | ||||
| 
 | ||||
|     def serialize_description(self) -> Any: | ||||
|         return self.tag.description | ||||
| 
 | ||||
|     def serialize_creation_time(self) -> Any: | ||||
|         return self.tag.creation_time | ||||
| 
 | ||||
|     def serialize_last_edit_time(self) -> Any: | ||||
|         return self.tag.last_edit_time | ||||
| 
 | ||||
|     def serialize_usages(self) -> Any: | ||||
|         return self.tag.post_count | ||||
| 
 | ||||
|     def serialize_suggestions(self) -> Any: | ||||
|         return [ | ||||
|             relation.names[0].name | ||||
|             for relation in sort_tags(self.tag.suggestions)] | ||||
| 
 | ||||
|     def serialize_implications(self) -> Any: | ||||
|         return [ | ||||
|             relation.names[0].name | ||||
|             for relation in sort_tags(self.tag.implications)] | ||||
| 
 | ||||
| 
 | ||||
| def export_to_json(): | ||||
|     tags = {} | ||||
|     categories = {} | ||||
| def serialize_tag( | ||||
|         tag: model.Tag, options: List[str]=[]) -> Optional[rest.Response]: | ||||
|     if not tag: | ||||
|         return None | ||||
|     return TagSerializer(tag).serialize(options) | ||||
| 
 | ||||
| 
 | ||||
| def export_to_json() -> None: | ||||
|     tags = {}  # type: Dict[int, Any] | ||||
|     categories = {}  # type: Dict[int, Any] | ||||
| 
 | ||||
|     for result in db.session.query( | ||||
|             db.TagCategory.tag_category_id, | ||||
|             db.TagCategory.name, | ||||
|             db.TagCategory.color).all(): | ||||
|             model.TagCategory.tag_category_id, | ||||
|             model.TagCategory.name, | ||||
|             model.TagCategory.color).all(): | ||||
|         categories[result[0]] = { | ||||
|             'name': result[1], | ||||
|             'color': result[2], | ||||
| @ -106,8 +143,8 @@ def export_to_json(): | ||||
| 
 | ||||
|     for result in ( | ||||
|             db.session | ||||
|             .query(db.TagName.tag_id, db.TagName.name) | ||||
|             .order_by(db.TagName.order) | ||||
|             .query(model.TagName.tag_id, model.TagName.name) | ||||
|             .order_by(model.TagName.order) | ||||
|             .all()): | ||||
|         if not result[0] in tags: | ||||
|             tags[result[0]] = {'names': []} | ||||
| @ -115,8 +152,10 @@ def export_to_json(): | ||||
| 
 | ||||
|     for result in ( | ||||
|             db.session | ||||
|             .query(db.TagSuggestion.parent_id, db.TagName.name) | ||||
|             .join(db.TagName, db.TagName.tag_id == db.TagSuggestion.child_id) | ||||
|             .query(model.TagSuggestion.parent_id, model.TagName.name) | ||||
|             .join( | ||||
|                 model.TagName, | ||||
|                 model.TagName.tag_id == model.TagSuggestion.child_id) | ||||
|             .all()): | ||||
|         if 'suggestions' not in tags[result[0]]: | ||||
|             tags[result[0]]['suggestions'] = [] | ||||
| @ -124,17 +163,19 @@ def export_to_json(): | ||||
| 
 | ||||
|     for result in ( | ||||
|             db.session | ||||
|             .query(db.TagImplication.parent_id, db.TagName.name) | ||||
|             .join(db.TagName, db.TagName.tag_id == db.TagImplication.child_id) | ||||
|             .query(model.TagImplication.parent_id, model.TagName.name) | ||||
|             .join( | ||||
|                 model.TagName, | ||||
|                 model.TagName.tag_id == model.TagImplication.child_id) | ||||
|             .all()): | ||||
|         if 'implications' not in tags[result[0]]: | ||||
|             tags[result[0]]['implications'] = [] | ||||
|         tags[result[0]]['implications'].append(result[1]) | ||||
| 
 | ||||
|     for result in db.session.query( | ||||
|             db.Tag.tag_id, | ||||
|             db.Tag.category_id, | ||||
|             db.Tag.post_count).all(): | ||||
|             model.Tag.tag_id, | ||||
|             model.Tag.category_id, | ||||
|             model.Tag.post_count).all(): | ||||
|         tags[result[0]]['category'] = categories[result[1]]['name'] | ||||
|         tags[result[0]]['usages'] = result[2] | ||||
| 
 | ||||
| @ -148,33 +189,34 @@ def export_to_json(): | ||||
|         handle.write(json.dumps(output, separators=(',', ':'))) | ||||
| 
 | ||||
| 
 | ||||
| def try_get_tag_by_name(name): | ||||
| def try_get_tag_by_name(name: str) -> Optional[model.Tag]: | ||||
|     return ( | ||||
|         db.session | ||||
|         .query(db.Tag) | ||||
|         .join(db.TagName) | ||||
|         .filter(sqlalchemy.func.lower(db.TagName.name) == name.lower()) | ||||
|         .query(model.Tag) | ||||
|         .join(model.TagName) | ||||
|         .filter(sa.func.lower(model.TagName.name) == name.lower()) | ||||
|         .one_or_none()) | ||||
| 
 | ||||
| 
 | ||||
| def get_tag_by_name(name): | ||||
| def get_tag_by_name(name: str) -> model.Tag: | ||||
|     tag = try_get_tag_by_name(name) | ||||
|     if not tag: | ||||
|         raise TagNotFoundError('Tag %r not found.' % name) | ||||
|     return tag | ||||
| 
 | ||||
| 
 | ||||
| def get_tags_by_names(names): | ||||
| def get_tags_by_names(names: List[str]) -> List[model.Tag]: | ||||
|     names = util.icase_unique(names) | ||||
|     if len(names) == 0: | ||||
|         return [] | ||||
|     expr = sqlalchemy.sql.false() | ||||
|     expr = sa.sql.false() | ||||
|     for name in names: | ||||
|         expr = expr | (sqlalchemy.func.lower(db.TagName.name) == name.lower()) | ||||
|     return db.session.query(db.Tag).join(db.TagName).filter(expr).all() | ||||
|         expr = expr | (sa.func.lower(model.TagName.name) == name.lower()) | ||||
|     return db.session.query(model.Tag).join(model.TagName).filter(expr).all() | ||||
| 
 | ||||
| 
 | ||||
| def get_or_create_tags_by_names(names): | ||||
| def get_or_create_tags_by_names( | ||||
|         names: List[str]) -> Tuple[List[model.Tag], List[model.Tag]]: | ||||
|     names = util.icase_unique(names) | ||||
|     existing_tags = get_tags_by_names(names) | ||||
|     new_tags = [] | ||||
| @ -197,86 +239,87 @@ def get_or_create_tags_by_names(names): | ||||
|     return existing_tags, new_tags | ||||
| 
 | ||||
| 
 | ||||
| def get_tag_siblings(tag): | ||||
| def get_tag_siblings(tag: model.Tag) -> List[model.Tag]: | ||||
|     assert tag | ||||
|     tag_alias = sqlalchemy.orm.aliased(db.Tag) | ||||
|     pt_alias1 = sqlalchemy.orm.aliased(db.PostTag) | ||||
|     pt_alias2 = sqlalchemy.orm.aliased(db.PostTag) | ||||
|     tag_alias = sa.orm.aliased(model.Tag) | ||||
|     pt_alias1 = sa.orm.aliased(model.PostTag) | ||||
|     pt_alias2 = sa.orm.aliased(model.PostTag) | ||||
|     result = ( | ||||
|         db.session | ||||
|         .query(tag_alias, sqlalchemy.func.count(pt_alias2.post_id)) | ||||
|         .query(tag_alias, sa.func.count(pt_alias2.post_id)) | ||||
|         .join(pt_alias1, pt_alias1.tag_id == tag_alias.tag_id) | ||||
|         .join(pt_alias2, pt_alias2.post_id == pt_alias1.post_id) | ||||
|         .filter(pt_alias2.tag_id == tag.tag_id) | ||||
|         .filter(pt_alias1.tag_id != tag.tag_id) | ||||
|         .group_by(tag_alias.tag_id) | ||||
|         .order_by(sqlalchemy.func.count(pt_alias2.post_id).desc()) | ||||
|         .order_by(sa.func.count(pt_alias2.post_id).desc()) | ||||
|         .limit(50)) | ||||
|     return result | ||||
| 
 | ||||
| 
 | ||||
| def delete(source_tag): | ||||
| def delete(source_tag: model.Tag) -> None: | ||||
|     assert source_tag | ||||
|     db.session.execute( | ||||
|         sqlalchemy.sql.expression.delete(db.TagSuggestion) | ||||
|         .where(db.TagSuggestion.child_id == source_tag.tag_id)) | ||||
|         sa.sql.expression.delete(model.TagSuggestion) | ||||
|         .where(model.TagSuggestion.child_id == source_tag.tag_id)) | ||||
|     db.session.execute( | ||||
|         sqlalchemy.sql.expression.delete(db.TagImplication) | ||||
|         .where(db.TagImplication.child_id == source_tag.tag_id)) | ||||
|         sa.sql.expression.delete(model.TagImplication) | ||||
|         .where(model.TagImplication.child_id == source_tag.tag_id)) | ||||
|     db.session.delete(source_tag) | ||||
| 
 | ||||
| 
 | ||||
| def merge_tags(source_tag, target_tag): | ||||
| def merge_tags(source_tag: model.Tag, target_tag: model.Tag) -> None: | ||||
|     assert source_tag | ||||
|     assert target_tag | ||||
|     if source_tag.tag_id == target_tag.tag_id: | ||||
|         raise InvalidTagRelationError('Cannot merge tag with itself.') | ||||
| 
 | ||||
|     def merge_posts(source_tag_id, target_tag_id): | ||||
|         alias1 = db.PostTag | ||||
|         alias2 = sqlalchemy.orm.util.aliased(db.PostTag) | ||||
|     def merge_posts(source_tag_id: int, target_tag_id: int) -> None: | ||||
|         alias1 = model.PostTag | ||||
|         alias2 = sa.orm.util.aliased(model.PostTag) | ||||
|         update_stmt = ( | ||||
|             sqlalchemy.sql.expression.update(alias1) | ||||
|             sa.sql.expression.update(alias1) | ||||
|             .where(alias1.tag_id == source_tag_id)) | ||||
|         update_stmt = ( | ||||
|             update_stmt | ||||
|             .where( | ||||
|                 ~sqlalchemy.exists() | ||||
|                 ~sa.exists() | ||||
|                 .where(alias1.post_id == alias2.post_id) | ||||
|                 .where(alias2.tag_id == target_tag_id))) | ||||
|         update_stmt = update_stmt.values(tag_id=target_tag_id) | ||||
|         db.session.execute(update_stmt) | ||||
| 
 | ||||
|     def merge_relations(table, source_tag_id, target_tag_id): | ||||
|     def merge_relations( | ||||
|             table: model.Base, source_tag_id: int, target_tag_id: int) -> None: | ||||
|         alias1 = table | ||||
|         alias2 = sqlalchemy.orm.util.aliased(table) | ||||
|         alias2 = sa.orm.util.aliased(table) | ||||
|         update_stmt = ( | ||||
|             sqlalchemy.sql.expression.update(alias1) | ||||
|             sa.sql.expression.update(alias1) | ||||
|             .where(alias1.parent_id == source_tag_id) | ||||
|             .where(alias1.child_id != target_tag_id) | ||||
|             .where( | ||||
|                 ~sqlalchemy.exists() | ||||
|                 ~sa.exists() | ||||
|                 .where(alias2.child_id == alias1.child_id) | ||||
|                 .where(alias2.parent_id == target_tag_id)) | ||||
|             .values(parent_id=target_tag_id)) | ||||
|         db.session.execute(update_stmt) | ||||
| 
 | ||||
|         update_stmt = ( | ||||
|             sqlalchemy.sql.expression.update(alias1) | ||||
|             sa.sql.expression.update(alias1) | ||||
|             .where(alias1.child_id == source_tag_id) | ||||
|             .where(alias1.parent_id != target_tag_id) | ||||
|             .where( | ||||
|                 ~sqlalchemy.exists() | ||||
|                 ~sa.exists() | ||||
|                 .where(alias2.parent_id == alias1.parent_id) | ||||
|                 .where(alias2.child_id == target_tag_id)) | ||||
|             .values(child_id=target_tag_id)) | ||||
|         db.session.execute(update_stmt) | ||||
| 
 | ||||
|     def merge_suggestions(source_tag_id, target_tag_id): | ||||
|         merge_relations(db.TagSuggestion, source_tag_id, target_tag_id) | ||||
|     def merge_suggestions(source_tag_id: int, target_tag_id: int) -> None: | ||||
|         merge_relations(model.TagSuggestion, source_tag_id, target_tag_id) | ||||
| 
 | ||||
|     def merge_implications(source_tag_id, target_tag_id): | ||||
|         merge_relations(db.TagImplication, source_tag_id, target_tag_id) | ||||
|     def merge_implications(source_tag_id: int, target_tag_id: int) -> None: | ||||
|         merge_relations(model.TagImplication, source_tag_id, target_tag_id) | ||||
| 
 | ||||
|     merge_posts(source_tag.tag_id, target_tag.tag_id) | ||||
|     merge_suggestions(source_tag.tag_id, target_tag.tag_id) | ||||
| @ -284,9 +327,13 @@ def merge_tags(source_tag, target_tag): | ||||
|     delete(source_tag) | ||||
| 
 | ||||
| 
 | ||||
| def create_tag(names, category_name, suggestions, implications): | ||||
|     tag = db.Tag() | ||||
|     tag.creation_time = datetime.datetime.utcnow() | ||||
| def create_tag( | ||||
|         names: List[str], | ||||
|         category_name: str, | ||||
|         suggestions: List[str], | ||||
|         implications: List[str]) -> model.Tag: | ||||
|     tag = model.Tag() | ||||
|     tag.creation_time = datetime.utcnow() | ||||
|     update_tag_names(tag, names) | ||||
|     update_tag_category_name(tag, category_name) | ||||
|     update_tag_suggestions(tag, suggestions) | ||||
| @ -294,12 +341,12 @@ def create_tag(names, category_name, suggestions, implications): | ||||
|     return tag | ||||
| 
 | ||||
| 
 | ||||
| def update_tag_category_name(tag, category_name): | ||||
| def update_tag_category_name(tag: model.Tag, category_name: str) -> None: | ||||
|     assert tag | ||||
|     tag.category = tag_categories.get_category_by_name(category_name) | ||||
| 
 | ||||
| 
 | ||||
| def update_tag_names(tag, names): | ||||
| def update_tag_names(tag: model.Tag, names: List[str]) -> None: | ||||
|     # sanitize | ||||
|     assert tag | ||||
|     names = util.icase_unique([name for name in names if name]) | ||||
| @ -309,12 +356,12 @@ def update_tag_names(tag, names): | ||||
|         _verify_name_validity(name) | ||||
| 
 | ||||
|     # check for existing tags | ||||
|     expr = sqlalchemy.sql.false() | ||||
|     expr = sa.sql.false() | ||||
|     for name in names: | ||||
|         expr = expr | (sqlalchemy.func.lower(db.TagName.name) == name.lower()) | ||||
|         expr = expr | (sa.func.lower(model.TagName.name) == name.lower()) | ||||
|     if tag.tag_id: | ||||
|         expr = expr & (db.TagName.tag_id != tag.tag_id) | ||||
|     existing_tags = db.session.query(db.TagName).filter(expr).all() | ||||
|         expr = expr & (model.TagName.tag_id != tag.tag_id) | ||||
|     existing_tags = db.session.query(model.TagName).filter(expr).all() | ||||
|     if len(existing_tags): | ||||
|         raise TagAlreadyExistsError( | ||||
|             'One of names is already used by another tag.') | ||||
| @ -326,7 +373,7 @@ def update_tag_names(tag, names): | ||||
|     # add wanted items | ||||
|     for name in names: | ||||
|         if not _check_name_intersection(_get_names(tag), [name], True): | ||||
|             tag.names.append(db.TagName(name, None)) | ||||
|             tag.names.append(model.TagName(name, -1)) | ||||
| 
 | ||||
|     # set alias order to match the request | ||||
|     for i, name in enumerate(names): | ||||
| @ -336,7 +383,7 @@ def update_tag_names(tag, names): | ||||
| 
 | ||||
| 
 | ||||
| # TODO: what to do with relations that do not yet exist? | ||||
| def update_tag_implications(tag, relations): | ||||
| def update_tag_implications(tag: model.Tag, relations: List[str]) -> None: | ||||
|     assert tag | ||||
|     if _check_name_intersection(_get_names(tag), relations, False): | ||||
|         raise InvalidTagRelationError('Tag cannot imply itself.') | ||||
| @ -344,15 +391,15 @@ def update_tag_implications(tag, relations): | ||||
| 
 | ||||
| 
 | ||||
| # TODO: what to do with relations that do not yet exist? | ||||
| def update_tag_suggestions(tag, relations): | ||||
| def update_tag_suggestions(tag: model.Tag, relations: List[str]) -> None: | ||||
|     assert tag | ||||
|     if _check_name_intersection(_get_names(tag), relations, False): | ||||
|         raise InvalidTagRelationError('Tag cannot suggest itself.') | ||||
|     tag.suggestions = get_tags_by_names(relations) | ||||
| 
 | ||||
| 
 | ||||
| def update_tag_description(tag, description): | ||||
| def update_tag_description(tag: model.Tag, description: str) -> None: | ||||
|     assert tag | ||||
|     if util.value_exceeds_column_size(description, db.Tag.description): | ||||
|     if util.value_exceeds_column_size(description, model.Tag.description): | ||||
|         raise InvalidTagDescriptionError('Description is too long.') | ||||
|     tag.description = description | ||||
|     tag.description = description or None | ||||
|  | ||||
| @ -1,8 +1,9 @@ | ||||
| import datetime | ||||
| import re | ||||
| from sqlalchemy import func | ||||
| from szurubooru import config, db, errors | ||||
| from szurubooru.func import auth, util, files, images | ||||
| from typing import Any, Optional, Union, List, Dict, Callable | ||||
| from datetime import datetime | ||||
| import sqlalchemy as sa | ||||
| from szurubooru import config, db, model, errors, rest | ||||
| from szurubooru.func import auth, util, serialization, files, images | ||||
| 
 | ||||
| 
 | ||||
| class UserNotFoundError(errors.NotFoundError): | ||||
| @ -33,11 +34,11 @@ class InvalidAvatarError(errors.ValidationError): | ||||
|     pass | ||||
| 
 | ||||
| 
 | ||||
| def get_avatar_path(user_name): | ||||
| def get_avatar_path(user_name: str) -> str: | ||||
|     return 'avatars/' + user_name.lower() + '.png' | ||||
| 
 | ||||
| 
 | ||||
| def get_avatar_url(user): | ||||
| def get_avatar_url(user: model.User) -> str: | ||||
|     assert user | ||||
|     if user.avatar_style == user.AVATAR_GRAVATAR: | ||||
|         assert user.email or user.name | ||||
| @ -49,7 +50,10 @@ def get_avatar_url(user): | ||||
|         config.config['data_url'].rstrip('/'), user.name.lower()) | ||||
| 
 | ||||
| 
 | ||||
| def get_email(user, auth_user, force_show_email): | ||||
| def get_email( | ||||
|         user: model.User, | ||||
|         auth_user: model.User, | ||||
|         force_show_email: bool) -> Union[bool, str]: | ||||
|     assert user | ||||
|     assert auth_user | ||||
|     if not force_show_email \ | ||||
| @ -59,7 +63,8 @@ def get_email(user, auth_user, force_show_email): | ||||
|     return user.email | ||||
| 
 | ||||
| 
 | ||||
| def get_liked_post_count(user, auth_user): | ||||
| def get_liked_post_count( | ||||
|         user: model.User, auth_user: model.User) -> Union[bool, int]: | ||||
|     assert user | ||||
|     assert auth_user | ||||
|     if auth_user.user_id != user.user_id: | ||||
| @ -67,7 +72,8 @@ def get_liked_post_count(user, auth_user): | ||||
|     return user.liked_post_count | ||||
| 
 | ||||
| 
 | ||||
| def get_disliked_post_count(user, auth_user): | ||||
| def get_disliked_post_count( | ||||
|         user: model.User, auth_user: model.User) -> Union[bool, int]: | ||||
|     assert user | ||||
|     assert auth_user | ||||
|     if auth_user.user_id != user.user_id: | ||||
| @ -75,91 +81,144 @@ def get_disliked_post_count(user, auth_user): | ||||
|     return user.disliked_post_count | ||||
| 
 | ||||
| 
 | ||||
| def serialize_user(user, auth_user, options=None, force_show_email=False): | ||||
|     return util.serialize_entity( | ||||
|         user, | ||||
|         { | ||||
|             'name': lambda: user.name, | ||||
|             'creationTime': lambda: user.creation_time, | ||||
|             'lastLoginTime': lambda: user.last_login_time, | ||||
|             'version': lambda: user.version, | ||||
|             'rank': lambda: user.rank, | ||||
|             'avatarStyle': lambda: user.avatar_style, | ||||
|             'avatarUrl': lambda: get_avatar_url(user), | ||||
|             'commentCount': lambda: user.comment_count, | ||||
|             'uploadedPostCount': lambda: user.post_count, | ||||
|             'favoritePostCount': lambda: user.favorite_post_count, | ||||
|             'likedPostCount': | ||||
|                 lambda: get_liked_post_count(user, auth_user), | ||||
|             'dislikedPostCount': | ||||
|                 lambda: get_disliked_post_count(user, auth_user), | ||||
|             'email': | ||||
|                 lambda: get_email(user, auth_user, force_show_email), | ||||
|         }, | ||||
|         options) | ||||
| class UserSerializer(serialization.BaseSerializer): | ||||
|     def __init__( | ||||
|             self, | ||||
|             user: model.User, | ||||
|             auth_user: model.User, | ||||
|             force_show_email: bool=False) -> None: | ||||
|         self.user = user | ||||
|         self.auth_user = auth_user | ||||
|         self.force_show_email = force_show_email | ||||
| 
 | ||||
|     def _serializers(self) -> Dict[str, Callable[[], Any]]: | ||||
|         return { | ||||
|             'name': self.serialize_name, | ||||
|             'creationTime': self.serialize_creation_time, | ||||
|             'lastLoginTime': self.serialize_last_login_time, | ||||
|             'version': self.serialize_version, | ||||
|             'rank': self.serialize_rank, | ||||
|             'avatarStyle': self.serialize_avatar_style, | ||||
|             'avatarUrl': self.serialize_avatar_url, | ||||
|             'commentCount': self.serialize_comment_count, | ||||
|             'uploadedPostCount': self.serialize_uploaded_post_count, | ||||
|             'favoritePostCount': self.serialize_favorite_post_count, | ||||
|             'likedPostCount': self.serialize_liked_post_count, | ||||
|             'dislikedPostCount': self.serialize_disliked_post_count, | ||||
|             'email': self.serialize_email, | ||||
|         } | ||||
| 
 | ||||
|     def serialize_name(self) -> Any: | ||||
|         return self.user.name | ||||
| 
 | ||||
|     def serialize_creation_time(self) -> Any: | ||||
|         return self.user.creation_time | ||||
| 
 | ||||
|     def serialize_last_login_time(self) -> Any: | ||||
|         return self.user.last_login_time | ||||
| 
 | ||||
|     def serialize_version(self) -> Any: | ||||
|         return self.user.version | ||||
| 
 | ||||
|     def serialize_rank(self) -> Any: | ||||
|         return self.user.rank | ||||
| 
 | ||||
|     def serialize_avatar_style(self) -> Any: | ||||
|         return self.user.avatar_style | ||||
| 
 | ||||
|     def serialize_avatar_url(self) -> Any: | ||||
|         return get_avatar_url(self.user) | ||||
| 
 | ||||
|     def serialize_comment_count(self) -> Any: | ||||
|         return self.user.comment_count | ||||
| 
 | ||||
|     def serialize_uploaded_post_count(self) -> Any: | ||||
|         return self.user.post_count | ||||
| 
 | ||||
|     def serialize_favorite_post_count(self) -> Any: | ||||
|         return self.user.favorite_post_count | ||||
| 
 | ||||
|     def serialize_liked_post_count(self) -> Any: | ||||
|         return get_liked_post_count(self.user, self.auth_user) | ||||
| 
 | ||||
|     def serialize_disliked_post_count(self) -> Any: | ||||
|         return get_disliked_post_count(self.user, self.auth_user) | ||||
| 
 | ||||
|     def serialize_email(self) -> Any: | ||||
|         return get_email(self.user, self.auth_user, self.force_show_email) | ||||
| 
 | ||||
| 
 | ||||
| def serialize_micro_user(user, auth_user): | ||||
| def serialize_user( | ||||
|         user: Optional[model.User], | ||||
|         auth_user: model.User, | ||||
|         options: List[str]=[], | ||||
|         force_show_email: bool=False) -> Optional[rest.Response]: | ||||
|     if not user: | ||||
|         return None | ||||
|     return UserSerializer(user, auth_user, force_show_email).serialize(options) | ||||
| 
 | ||||
| 
 | ||||
| def serialize_micro_user( | ||||
|         user: Optional[model.User], | ||||
|         auth_user: model.User) -> Optional[rest.Response]: | ||||
|     return serialize_user( | ||||
|         user, | ||||
|         auth_user=auth_user, | ||||
|         options=['name', 'avatarUrl']) | ||||
|         user, auth_user=auth_user, options=['name', 'avatarUrl']) | ||||
| 
 | ||||
| 
 | ||||
| def get_user_count(): | ||||
|     return db.session.query(db.User).count() | ||||
| def get_user_count() -> int: | ||||
|     return db.session.query(model.User).count() | ||||
| 
 | ||||
| 
 | ||||
| def try_get_user_by_name(name): | ||||
| def try_get_user_by_name(name: str) -> Optional[model.User]: | ||||
|     return db.session \ | ||||
|         .query(db.User) \ | ||||
|         .filter(func.lower(db.User.name) == func.lower(name)) \ | ||||
|         .query(model.User) \ | ||||
|         .filter(sa.func.lower(model.User.name) == sa.func.lower(name)) \ | ||||
|         .one_or_none() | ||||
| 
 | ||||
| 
 | ||||
| def get_user_by_name(name): | ||||
| def get_user_by_name(name: str) -> model.User: | ||||
|     user = try_get_user_by_name(name) | ||||
|     if not user: | ||||
|         raise UserNotFoundError('User %r not found.' % name) | ||||
|     return user | ||||
| 
 | ||||
| 
 | ||||
| def try_get_user_by_name_or_email(name_or_email): | ||||
| def try_get_user_by_name_or_email(name_or_email: str) -> Optional[model.User]: | ||||
|     return ( | ||||
|         db.session | ||||
|         .query(db.User) | ||||
|         .query(model.User) | ||||
|         .filter( | ||||
|             (func.lower(db.User.name) == func.lower(name_or_email)) | | ||||
|             (func.lower(db.User.email) == func.lower(name_or_email))) | ||||
|             (sa.func.lower(model.User.name) == sa.func.lower(name_or_email)) | | ||||
|             (sa.func.lower(model.User.email) == sa.func.lower(name_or_email))) | ||||
|         .one_or_none()) | ||||
| 
 | ||||
| 
 | ||||
| def get_user_by_name_or_email(name_or_email): | ||||
| def get_user_by_name_or_email(name_or_email: str) -> model.User: | ||||
|     user = try_get_user_by_name_or_email(name_or_email) | ||||
|     if not user: | ||||
|         raise UserNotFoundError('User %r not found.' % name_or_email) | ||||
|     return user | ||||
| 
 | ||||
| 
 | ||||
| def create_user(name, password, email): | ||||
|     user = db.User() | ||||
| def create_user(name: str, password: str, email: str) -> model.User: | ||||
|     user = model.User() | ||||
|     update_user_name(user, name) | ||||
|     update_user_password(user, password) | ||||
|     update_user_email(user, email) | ||||
|     if get_user_count() > 0: | ||||
|         user.rank = util.flip(auth.RANK_MAP)[config.config['default_rank']] | ||||
|     else: | ||||
|         user.rank = db.User.RANK_ADMINISTRATOR | ||||
|     user.creation_time = datetime.datetime.utcnow() | ||||
|     user.avatar_style = db.User.AVATAR_GRAVATAR | ||||
|         user.rank = model.User.RANK_ADMINISTRATOR | ||||
|     user.creation_time = datetime.utcnow() | ||||
|     user.avatar_style = model.User.AVATAR_GRAVATAR | ||||
|     return user | ||||
| 
 | ||||
| 
 | ||||
| def update_user_name(user, name): | ||||
| def update_user_name(user: model.User, name: str) -> None: | ||||
|     assert user | ||||
|     if not name: | ||||
|         raise InvalidUserNameError('Name cannot be empty.') | ||||
|     if util.value_exceeds_column_size(name, db.User.name): | ||||
|     if util.value_exceeds_column_size(name, model.User.name): | ||||
|         raise InvalidUserNameError('User name is too long.') | ||||
|     name = name.strip() | ||||
|     name_regex = config.config['user_name_regex'] | ||||
| @ -174,7 +233,7 @@ def update_user_name(user, name): | ||||
|     user.name = name | ||||
| 
 | ||||
| 
 | ||||
| def update_user_password(user, password): | ||||
| def update_user_password(user: model.User, password: str) -> None: | ||||
|     assert user | ||||
|     if not password: | ||||
|         raise InvalidPasswordError('Password cannot be empty.') | ||||
| @ -186,20 +245,18 @@ def update_user_password(user, password): | ||||
|     user.password_hash = auth.get_password_hash(user.password_salt, password) | ||||
| 
 | ||||
| 
 | ||||
| def update_user_email(user, email): | ||||
| def update_user_email(user: model.User, email: str) -> None: | ||||
|     assert user | ||||
|     if email: | ||||
|         email = email.strip() | ||||
|     if not email: | ||||
|         email = None | ||||
|     if email and util.value_exceeds_column_size(email, db.User.email): | ||||
|     email = email.strip() | ||||
|     if util.value_exceeds_column_size(email, model.User.email): | ||||
|         raise InvalidEmailError('Email is too long.') | ||||
|     if not util.is_valid_email(email): | ||||
|         raise InvalidEmailError('E-mail is invalid.') | ||||
|     user.email = email | ||||
|     user.email = email or None | ||||
| 
 | ||||
| 
 | ||||
| def update_user_rank(user, rank, auth_user): | ||||
| def update_user_rank( | ||||
|         user: model.User, rank: str, auth_user: model.User) -> None: | ||||
|     assert user | ||||
|     if not rank: | ||||
|         raise InvalidRankError('Rank cannot be empty.') | ||||
| @ -208,7 +265,7 @@ def update_user_rank(user, rank, auth_user): | ||||
|     if not rank: | ||||
|         raise InvalidRankError( | ||||
|             'Rank can be either of %r.' % all_ranks) | ||||
|     if rank in (db.User.RANK_ANONYMOUS, db.User.RANK_NOBODY): | ||||
|     if rank in (model.User.RANK_ANONYMOUS, model.User.RANK_NOBODY): | ||||
|         raise InvalidRankError('Rank %r cannot be used.' % auth.RANK_MAP[rank]) | ||||
|     if all_ranks.index(auth_user.rank) \ | ||||
|             < all_ranks.index(rank) and get_user_count() > 0: | ||||
| @ -216,7 +273,10 @@ def update_user_rank(user, rank, auth_user): | ||||
|     user.rank = rank | ||||
| 
 | ||||
| 
 | ||||
| def update_user_avatar(user, avatar_style, avatar_content=None): | ||||
| def update_user_avatar( | ||||
|         user: model.User, | ||||
|         avatar_style: str, | ||||
|         avatar_content: Optional[bytes]=None) -> None: | ||||
|     assert user | ||||
|     if avatar_style == 'gravatar': | ||||
|         user.avatar_style = user.AVATAR_GRAVATAR | ||||
| @ -238,12 +298,12 @@ def update_user_avatar(user, avatar_style, avatar_content=None): | ||||
|                 avatar_style, ['gravatar', 'manual'])) | ||||
| 
 | ||||
| 
 | ||||
| def bump_user_login_time(user): | ||||
| def bump_user_login_time(user: model.User) -> None: | ||||
|     assert user | ||||
|     user.last_login_time = datetime.datetime.utcnow() | ||||
|     user.last_login_time = datetime.utcnow() | ||||
| 
 | ||||
| 
 | ||||
| def reset_user_password(user): | ||||
| def reset_user_password(user: model.User) -> str: | ||||
|     assert user | ||||
|     password = auth.create_password() | ||||
|     user.password_salt = auth.create_password() | ||||
|  | ||||
| @ -2,52 +2,39 @@ import os | ||||
| import hashlib | ||||
| import re | ||||
| import tempfile | ||||
| from typing import ( | ||||
|     Any, Optional, Union, Tuple, List, Dict, Generator, Callable, TypeVar) | ||||
| from datetime import datetime, timedelta | ||||
| from contextlib import contextmanager | ||||
| from szurubooru import errors | ||||
| 
 | ||||
| 
 | ||||
| def snake_case_to_lower_camel_case(text): | ||||
| T = TypeVar('T') | ||||
| 
 | ||||
| 
 | ||||
| def snake_case_to_lower_camel_case(text: str) -> str: | ||||
|     components = text.split('_') | ||||
|     return components[0].lower() + \ | ||||
|         ''.join(word[0].upper() + word[1:].lower() for word in components[1:]) | ||||
| 
 | ||||
| 
 | ||||
| def snake_case_to_upper_train_case(text): | ||||
| def snake_case_to_upper_train_case(text: str) -> str: | ||||
|     return '-'.join( | ||||
|         word[0].upper() + word[1:].lower() for word in text.split('_')) | ||||
| 
 | ||||
| 
 | ||||
| def snake_case_to_lower_camel_case_keys(source): | ||||
| def snake_case_to_lower_camel_case_keys( | ||||
|         source: Dict[str, Any]) -> Dict[str, Any]: | ||||
|     target = {} | ||||
|     for key, value in source.items(): | ||||
|         target[snake_case_to_lower_camel_case(key)] = value | ||||
|     return target | ||||
| 
 | ||||
| 
 | ||||
| def get_serialization_options(ctx): | ||||
|     return ctx.get_param_as_list('fields', required=False, default=None) | ||||
| 
 | ||||
| 
 | ||||
| def serialize_entity(entity, field_factories, options): | ||||
|     if not entity: | ||||
|         return None | ||||
|     if not options or len(options) == 0: | ||||
|         options = field_factories.keys() | ||||
|     ret = {} | ||||
|     for key in options: | ||||
|         if key not in field_factories: | ||||
|             raise errors.ValidationError('Invalid key: %r. Valid keys: %r.' % ( | ||||
|                 key, list(sorted(field_factories.keys())))) | ||||
|         factory = field_factories[key] | ||||
|         ret[key] = factory() | ||||
|     return ret | ||||
| 
 | ||||
| 
 | ||||
| @contextmanager | ||||
| def create_temp_file(**kwargs): | ||||
|     (handle, path) = tempfile.mkstemp(**kwargs) | ||||
|     os.close(handle) | ||||
| def create_temp_file(**kwargs: Any) -> Generator: | ||||
|     (descriptor, path) = tempfile.mkstemp(**kwargs) | ||||
|     os.close(descriptor) | ||||
|     try: | ||||
|         with open(path, 'r+b') as handle: | ||||
|             yield handle | ||||
| @ -55,17 +42,15 @@ def create_temp_file(**kwargs): | ||||
|         os.remove(path) | ||||
| 
 | ||||
| 
 | ||||
| def unalias_dict(input_dict): | ||||
|     output_dict = {} | ||||
|     for key_list, value in input_dict.items(): | ||||
|         if isinstance(key_list, str): | ||||
|             key_list = [key_list] | ||||
|         for key in key_list: | ||||
|             output_dict[key] = value | ||||
| def unalias_dict(source: List[Tuple[List[str], T]]) -> Dict[str, T]: | ||||
|     output_dict = {}  # type: Dict[str, T] | ||||
|     for aliases, value in source: | ||||
|         for alias in aliases: | ||||
|             output_dict[alias] = value | ||||
|     return output_dict | ||||
| 
 | ||||
| 
 | ||||
| def get_md5(source): | ||||
| def get_md5(source: Union[str, bytes]) -> str: | ||||
|     if not isinstance(source, bytes): | ||||
|         source = source.encode('utf-8') | ||||
|     md5 = hashlib.md5() | ||||
| @ -73,7 +58,7 @@ def get_md5(source): | ||||
|     return md5.hexdigest() | ||||
| 
 | ||||
| 
 | ||||
| def get_sha1(source): | ||||
| def get_sha1(source: Union[str, bytes]) -> str: | ||||
|     if not isinstance(source, bytes): | ||||
|         source = source.encode('utf-8') | ||||
|     sha1 = hashlib.sha1() | ||||
| @ -81,24 +66,25 @@ def get_sha1(source): | ||||
|     return sha1.hexdigest() | ||||
| 
 | ||||
| 
 | ||||
| def flip(source): | ||||
| def flip(source: Dict[Any, Any]) -> Dict[Any, Any]: | ||||
|     return {v: k for k, v in source.items()} | ||||
| 
 | ||||
| 
 | ||||
| def is_valid_email(email): | ||||
| def is_valid_email(email: Optional[str]) -> bool: | ||||
|     ''' Return whether given email address is valid or empty. ''' | ||||
|     return not email or re.match(r'^[^@]*@[^@]*\.[^@]*$', email) | ||||
|     return not email or re.match(r'^[^@]*@[^@]*\.[^@]*$', email) is not None | ||||
| 
 | ||||
| 
 | ||||
| class dotdict(dict):  # pylint: disable=invalid-name | ||||
|     ''' dot.notation access to dictionary attributes. ''' | ||||
|     def __getattr__(self, attr): | ||||
|     def __getattr__(self, attr: str) -> Any: | ||||
|         return self.get(attr) | ||||
| 
 | ||||
|     __setattr__ = dict.__setitem__ | ||||
|     __delattr__ = dict.__delitem__ | ||||
| 
 | ||||
| 
 | ||||
| def parse_time_range(value): | ||||
| def parse_time_range(value: str) -> Tuple[datetime, datetime]: | ||||
|     ''' Return tuple containing min/max time for given text representation. ''' | ||||
|     one_day = timedelta(days=1) | ||||
|     one_second = timedelta(seconds=1) | ||||
| @ -146,9 +132,9 @@ def parse_time_range(value): | ||||
|     raise errors.ValidationError('Invalid date format: %r.' % value) | ||||
| 
 | ||||
| 
 | ||||
| def icase_unique(source): | ||||
|     target = [] | ||||
|     target_low = [] | ||||
| def icase_unique(source: List[str]) -> List[str]: | ||||
|     target = []  # type: List[str] | ||||
|     target_low = []  # type: List[str] | ||||
|     for source_item in source: | ||||
|         if source_item.lower() not in target_low: | ||||
|             target.append(source_item) | ||||
| @ -156,7 +142,7 @@ def icase_unique(source): | ||||
|     return target | ||||
| 
 | ||||
| 
 | ||||
| def value_exceeds_column_size(value, column): | ||||
| def value_exceeds_column_size(value: Optional[str], column: Any) -> bool: | ||||
|     if not value: | ||||
|         return False | ||||
|     max_length = column.property.columns[0].type.length | ||||
| @ -165,6 +151,6 @@ def value_exceeds_column_size(value, column): | ||||
|     return len(value) > max_length | ||||
| 
 | ||||
| 
 | ||||
| def chunks(source_list, part_size): | ||||
| def chunks(source_list: List[Any], part_size: int) -> Generator: | ||||
|     for i in range(0, len(source_list), part_size): | ||||
|         yield source_list[i:i + part_size] | ||||
|  | ||||
| @ -1,8 +1,11 @@ | ||||
| from szurubooru import errors | ||||
| from szurubooru import errors, rest, model | ||||
| 
 | ||||
| 
 | ||||
| def verify_version(entity, context, field_name='version'): | ||||
|     actual_version = context.get_param_as_int(field_name, required=True) | ||||
| def verify_version( | ||||
|         entity: model.Base, | ||||
|         context: rest.Context, | ||||
|         field_name: str='version') -> None: | ||||
|     actual_version = context.get_param_as_int(field_name) | ||||
|     expected_version = entity.version | ||||
|     if actual_version != expected_version: | ||||
|         raise errors.IntegrityError( | ||||
| @ -10,5 +13,5 @@ def verify_version(entity, context, field_name='version'): | ||||
|             'Please try again.') | ||||
| 
 | ||||
| 
 | ||||
| def bump_version(entity): | ||||
| def bump_version(entity: model.Base) -> None: | ||||
|     entity.version = entity.version + 1 | ||||
|  | ||||
| @ -1,11 +1,11 @@ | ||||
| import base64 | ||||
| from szurubooru import db, errors | ||||
| from typing import Optional | ||||
| from szurubooru import db, model, errors, rest | ||||
| from szurubooru.func import auth, users | ||||
| from szurubooru.rest import middleware | ||||
| from szurubooru.rest.errors import HttpBadRequest | ||||
| 
 | ||||
| 
 | ||||
| def _authenticate(username, password): | ||||
| def _authenticate(username: str, password: str) -> model.User: | ||||
|     ''' Try to authenticate user. Throw AuthError for invalid users. ''' | ||||
|     user = users.get_user_by_name(username) | ||||
|     if not auth.is_valid_password(user, password): | ||||
| @ -13,16 +13,9 @@ def _authenticate(username, password): | ||||
|     return user | ||||
| 
 | ||||
| 
 | ||||
| def _create_anonymous_user(): | ||||
|     user = db.User() | ||||
|     user.name = None | ||||
|     user.rank = 'anonymous' | ||||
|     return user | ||||
| 
 | ||||
| 
 | ||||
| def _get_user(ctx): | ||||
| def _get_user(ctx: rest.Context) -> Optional[model.User]: | ||||
|     if not ctx.has_header('Authorization'): | ||||
|         return _create_anonymous_user() | ||||
|         return None | ||||
| 
 | ||||
|     try: | ||||
|         auth_type, credentials = ctx.get_header('Authorization').split(' ', 1) | ||||
| @ -41,10 +34,12 @@ def _get_user(ctx): | ||||
|             msg.format(ctx.get_header('Authorization'), str(err))) | ||||
| 
 | ||||
| 
 | ||||
| @middleware.pre_hook | ||||
| def process_request(ctx): | ||||
| @rest.middleware.pre_hook | ||||
| def process_request(ctx: rest.Context) -> None: | ||||
|     ''' Bind the user to request. Update last login time if needed. ''' | ||||
|     ctx.user = _get_user(ctx) | ||||
|     if ctx.get_param_as_bool('bump-login') and ctx.user.user_id: | ||||
|     auth_user = _get_user(ctx) | ||||
|     if auth_user: | ||||
|         ctx.user = auth_user | ||||
|     if ctx.get_param_as_bool('bump-login', default=False) and ctx.user.user_id: | ||||
|         users.bump_user_login_time(ctx.user) | ||||
|         ctx.session.commit() | ||||
|  | ||||
| @ -1,8 +1,9 @@ | ||||
| from szurubooru import rest | ||||
| from szurubooru.func import cache | ||||
| from szurubooru.rest import middleware | ||||
| 
 | ||||
| 
 | ||||
| @middleware.pre_hook | ||||
| def process_request(ctx): | ||||
| def process_request(ctx: rest.Context) -> None: | ||||
|     if ctx.method != 'GET': | ||||
|         cache.purge() | ||||
|  | ||||
| @ -1,5 +1,5 @@ | ||||
| import logging | ||||
| from szurubooru import db | ||||
| from szurubooru import db, rest | ||||
| from szurubooru.rest import middleware | ||||
| 
 | ||||
| 
 | ||||
| @ -7,12 +7,12 @@ logger = logging.getLogger(__name__) | ||||
| 
 | ||||
| 
 | ||||
| @middleware.pre_hook | ||||
| def process_request(_ctx): | ||||
| def process_request(_ctx: rest.Context) -> None: | ||||
|     db.reset_query_count() | ||||
| 
 | ||||
| 
 | ||||
| @middleware.post_hook | ||||
| def process_response(ctx): | ||||
| def process_response(ctx: rest.Context) -> None: | ||||
|     logger.info( | ||||
|         '%s %s (user=%s, queries=%d)', | ||||
|         ctx.method, | ||||
|  | ||||
| @ -2,7 +2,7 @@ import os | ||||
| import sys | ||||
| 
 | ||||
| import alembic | ||||
| import sqlalchemy | ||||
| import sqlalchemy as sa | ||||
| import logging.config | ||||
| 
 | ||||
| # make szurubooru module importable | ||||
| @ -48,7 +48,7 @@ def run_migrations_online(): | ||||
|     In this scenario we need to create an Engine | ||||
|     and associate a connection with the context. | ||||
|     ''' | ||||
|     connectable = sqlalchemy.engine_from_config( | ||||
|     connectable = sa.engine_from_config( | ||||
|         alembic_config.get_section(alembic_config.config_ini_section), | ||||
|         prefix='sqlalchemy.', | ||||
|         poolclass=sqlalchemy.pool.NullPool) | ||||
|  | ||||
							
								
								
									
										15
									
								
								server/szurubooru/model/__init__.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										15
									
								
								server/szurubooru/model/__init__.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,15 @@ | ||||
| from szurubooru.model.base import Base | ||||
| from szurubooru.model.user import User | ||||
| from szurubooru.model.tag_category import TagCategory | ||||
| from szurubooru.model.tag import Tag, TagName, TagSuggestion, TagImplication | ||||
| from szurubooru.model.post import ( | ||||
|     Post, | ||||
|     PostTag, | ||||
|     PostRelation, | ||||
|     PostFavorite, | ||||
|     PostScore, | ||||
|     PostNote, | ||||
|     PostFeature) | ||||
| from szurubooru.model.comment import Comment, CommentScore | ||||
| from szurubooru.model.snapshot import Snapshot | ||||
| import szurubooru.model.util | ||||
| @ -1,7 +1,8 @@ | ||||
| from sqlalchemy import Column, Integer, DateTime, UnicodeText, ForeignKey | ||||
| from sqlalchemy.orm import relationship, backref | ||||
| from sqlalchemy.sql.expression import func | ||||
| from szurubooru.db.base import Base | ||||
| from szurubooru.db import get_session | ||||
| from szurubooru.model.base import Base | ||||
| 
 | ||||
| 
 | ||||
| class CommentScore(Base): | ||||
| @ -48,12 +49,12 @@ class Comment(Base): | ||||
|         'CommentScore', cascade='all, delete-orphan', lazy='joined') | ||||
| 
 | ||||
|     @property | ||||
|     def score(self): | ||||
|         from szurubooru.db import session | ||||
|         return session \ | ||||
|             .query(func.sum(CommentScore.score)) \ | ||||
|             .filter(CommentScore.comment_id == self.comment_id) \ | ||||
|             .one()[0] or 0 | ||||
|     def score(self) -> int: | ||||
|         return ( | ||||
|             get_session() | ||||
|             .query(func.sum(CommentScore.score)) | ||||
|             .filter(CommentScore.comment_id == self.comment_id) | ||||
|             .one()[0] or 0) | ||||
| 
 | ||||
|     __mapper_args__ = { | ||||
|         'version_id_col': version, | ||||
| @ -3,8 +3,8 @@ from sqlalchemy import ( | ||||
|     Column, Integer, DateTime, Unicode, UnicodeText, PickleType, ForeignKey) | ||||
| from sqlalchemy.orm import ( | ||||
|     relationship, column_property, object_session, backref) | ||||
| from szurubooru.db.base import Base | ||||
| from szurubooru.db.comment import Comment | ||||
| from szurubooru.model.base import Base | ||||
| from szurubooru.model.comment import Comment | ||||
| 
 | ||||
| 
 | ||||
| class PostFeature(Base): | ||||
| @ -17,10 +17,9 @@ class PostFeature(Base): | ||||
|         'user_id', Integer, ForeignKey('user.id'), nullable=False, index=True) | ||||
|     time = Column('time', DateTime, nullable=False) | ||||
| 
 | ||||
|     post = relationship('Post') | ||||
|     post = relationship('Post')  # type: Post | ||||
|     user = relationship( | ||||
|         'User', | ||||
|         backref=backref('post_features', cascade='all, delete-orphan')) | ||||
|         'User', backref=backref('post_features', cascade='all, delete-orphan')) | ||||
| 
 | ||||
| 
 | ||||
| class PostScore(Base): | ||||
| @ -104,7 +103,7 @@ class PostRelation(Base): | ||||
|         nullable=False, | ||||
|         index=True) | ||||
| 
 | ||||
|     def __init__(self, parent_id, child_id): | ||||
|     def __init__(self, parent_id: int, child_id: int) -> None: | ||||
|         self.parent_id = parent_id | ||||
|         self.child_id = child_id | ||||
| 
 | ||||
| @ -127,7 +126,7 @@ class PostTag(Base): | ||||
|         nullable=False, | ||||
|         index=True) | ||||
| 
 | ||||
|     def __init__(self, post_id, tag_id): | ||||
|     def __init__(self, post_id: int, tag_id: int) -> None: | ||||
|         self.post_id = post_id | ||||
|         self.tag_id = tag_id | ||||
| 
 | ||||
| @ -197,7 +196,7 @@ class Post(Base): | ||||
|     canvas_area = column_property(canvas_width * canvas_height) | ||||
| 
 | ||||
|     @property | ||||
|     def is_featured(self): | ||||
|     def is_featured(self) -> bool: | ||||
|         featured_post = object_session(self) \ | ||||
|             .query(PostFeature) \ | ||||
|             .order_by(PostFeature.time.desc()) \ | ||||
| @ -1,7 +1,7 @@ | ||||
| from sqlalchemy.orm import relationship | ||||
| from sqlalchemy import ( | ||||
|     Column, Integer, DateTime, Unicode, PickleType, ForeignKey) | ||||
| from szurubooru.db.base import Base | ||||
| from szurubooru.model.base import Base | ||||
| 
 | ||||
| 
 | ||||
| class Snapshot(Base): | ||||
| @ -2,8 +2,8 @@ from sqlalchemy import ( | ||||
|     Column, Integer, DateTime, Unicode, UnicodeText, ForeignKey) | ||||
| from sqlalchemy.orm import relationship, column_property | ||||
| from sqlalchemy.sql.expression import func, select | ||||
| from szurubooru.db.base import Base | ||||
| from szurubooru.db.post import PostTag | ||||
| from szurubooru.model.base import Base | ||||
| from szurubooru.model.post import PostTag | ||||
| 
 | ||||
| 
 | ||||
| class TagSuggestion(Base): | ||||
| @ -24,7 +24,7 @@ class TagSuggestion(Base): | ||||
|         primary_key=True, | ||||
|         index=True) | ||||
| 
 | ||||
|     def __init__(self, parent_id, child_id): | ||||
|     def __init__(self, parent_id: int, child_id: int) -> None: | ||||
|         self.parent_id = parent_id | ||||
|         self.child_id = child_id | ||||
| 
 | ||||
| @ -47,7 +47,7 @@ class TagImplication(Base): | ||||
|         primary_key=True, | ||||
|         index=True) | ||||
| 
 | ||||
|     def __init__(self, parent_id, child_id): | ||||
|     def __init__(self, parent_id: int, child_id: int) -> None: | ||||
|         self.parent_id = parent_id | ||||
|         self.child_id = child_id | ||||
| 
 | ||||
| @ -61,7 +61,7 @@ class TagName(Base): | ||||
|     name = Column('name', Unicode(64), nullable=False, unique=True) | ||||
|     order = Column('ord', Integer, nullable=False, index=True) | ||||
| 
 | ||||
|     def __init__(self, name, order): | ||||
|     def __init__(self, name: str, order: int) -> None: | ||||
|         self.name = name | ||||
|         self.order = order | ||||
| 
 | ||||
| @ -1,8 +1,9 @@ | ||||
| from typing import Optional | ||||
| from sqlalchemy import Column, Integer, Unicode, Boolean, table | ||||
| from sqlalchemy.orm import column_property | ||||
| from sqlalchemy.sql.expression import func, select | ||||
| from szurubooru.db.base import Base | ||||
| from szurubooru.db.tag import Tag | ||||
| from szurubooru.model.base import Base | ||||
| from szurubooru.model.tag import Tag | ||||
| 
 | ||||
| 
 | ||||
| class TagCategory(Base): | ||||
| @ -14,7 +15,7 @@ class TagCategory(Base): | ||||
|     color = Column('color', Unicode(32), nullable=False, default='#000000') | ||||
|     default = Column('default', Boolean, nullable=False, default=False) | ||||
| 
 | ||||
|     def __init__(self, name=None): | ||||
|     def __init__(self, name: Optional[str]=None) -> None: | ||||
|         self.name = name | ||||
| 
 | ||||
|     tag_count = column_property( | ||||
| @ -1,9 +1,7 @@ | ||||
| from sqlalchemy import Column, Integer, Unicode, DateTime | ||||
| from sqlalchemy.orm import relationship | ||||
| from sqlalchemy.sql.expression import func | ||||
| from szurubooru.db.base import Base | ||||
| from szurubooru.db.post import Post, PostScore, PostFavorite | ||||
| from szurubooru.db.comment import Comment | ||||
| import sqlalchemy as sa | ||||
| from szurubooru.model.base import Base | ||||
| from szurubooru.model.post import Post, PostScore, PostFavorite | ||||
| from szurubooru.model.comment import Comment | ||||
| 
 | ||||
| 
 | ||||
| class User(Base): | ||||
| @ -20,63 +18,64 @@ class User(Base): | ||||
|     RANK_ADMINISTRATOR = 'administrator' | ||||
|     RANK_NOBODY = 'nobody'  # unattainable, used for privileges | ||||
| 
 | ||||
|     user_id = Column('id', Integer, primary_key=True) | ||||
|     creation_time = Column('creation_time', DateTime, nullable=False) | ||||
|     last_login_time = Column('last_login_time', DateTime) | ||||
|     version = Column('version', Integer, default=1, nullable=False) | ||||
|     name = Column('name', Unicode(50), nullable=False, unique=True) | ||||
|     password_hash = Column('password_hash', Unicode(64), nullable=False) | ||||
|     password_salt = Column('password_salt', Unicode(32)) | ||||
|     email = Column('email', Unicode(64), nullable=True) | ||||
|     rank = Column('rank', Unicode(32), nullable=False) | ||||
|     avatar_style = Column( | ||||
|         'avatar_style', Unicode(32), nullable=False, default=AVATAR_GRAVATAR) | ||||
|     user_id = sa.Column('id', sa.Integer, primary_key=True) | ||||
|     creation_time = sa.Column('creation_time', sa.DateTime, nullable=False) | ||||
|     last_login_time = sa.Column('last_login_time', sa.DateTime) | ||||
|     version = sa.Column('version', sa.Integer, default=1, nullable=False) | ||||
|     name = sa.Column('name', sa.Unicode(50), nullable=False, unique=True) | ||||
|     password_hash = sa.Column('password_hash', sa.Unicode(64), nullable=False) | ||||
|     password_salt = sa.Column('password_salt', sa.Unicode(32)) | ||||
|     email = sa.Column('email', sa.Unicode(64), nullable=True) | ||||
|     rank = sa.Column('rank', sa.Unicode(32), nullable=False) | ||||
|     avatar_style = sa.Column( | ||||
|         'avatar_style', sa.Unicode(32), nullable=False, | ||||
|         default=AVATAR_GRAVATAR) | ||||
| 
 | ||||
|     comments = relationship('Comment') | ||||
|     comments = sa.orm.relationship('Comment') | ||||
| 
 | ||||
|     @property | ||||
|     def post_count(self): | ||||
|     def post_count(self) -> int: | ||||
|         from szurubooru.db import session | ||||
|         return ( | ||||
|             session | ||||
|             .query(func.sum(1)) | ||||
|             .query(sa.sql.expression.func.sum(1)) | ||||
|             .filter(Post.user_id == self.user_id) | ||||
|             .one()[0] or 0) | ||||
| 
 | ||||
|     @property | ||||
|     def comment_count(self): | ||||
|     def comment_count(self) -> int: | ||||
|         from szurubooru.db import session | ||||
|         return ( | ||||
|             session | ||||
|             .query(func.sum(1)) | ||||
|             .query(sa.sql.expression.func.sum(1)) | ||||
|             .filter(Comment.user_id == self.user_id) | ||||
|             .one()[0] or 0) | ||||
| 
 | ||||
|     @property | ||||
|     def favorite_post_count(self): | ||||
|     def favorite_post_count(self) -> int: | ||||
|         from szurubooru.db import session | ||||
|         return ( | ||||
|             session | ||||
|             .query(func.sum(1)) | ||||
|             .query(sa.sql.expression.func.sum(1)) | ||||
|             .filter(PostFavorite.user_id == self.user_id) | ||||
|             .one()[0] or 0) | ||||
| 
 | ||||
|     @property | ||||
|     def liked_post_count(self): | ||||
|     def liked_post_count(self) -> int: | ||||
|         from szurubooru.db import session | ||||
|         return ( | ||||
|             session | ||||
|             .query(func.sum(1)) | ||||
|             .query(sa.sql.expression.func.sum(1)) | ||||
|             .filter(PostScore.user_id == self.user_id) | ||||
|             .filter(PostScore.score == 1) | ||||
|             .one()[0] or 0) | ||||
| 
 | ||||
|     @property | ||||
|     def disliked_post_count(self): | ||||
|     def disliked_post_count(self) -> int: | ||||
|         from szurubooru.db import session | ||||
|         return ( | ||||
|             session | ||||
|             .query(func.sum(1)) | ||||
|             .query(sa.sql.expression.func.sum(1)) | ||||
|             .filter(PostScore.user_id == self.user_id) | ||||
|             .filter(PostScore.score == -1) | ||||
|             .one()[0] or 0) | ||||
							
								
								
									
										42
									
								
								server/szurubooru/model/util.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										42
									
								
								server/szurubooru/model/util.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,42 @@ | ||||
| from typing import Tuple, Any, Dict, Callable, Union, Optional | ||||
| import sqlalchemy as sa | ||||
| from szurubooru.model.base import Base | ||||
| from szurubooru.model.user import User | ||||
| 
 | ||||
| 
 | ||||
| def get_resource_info(entity: Base) -> Tuple[Any, Any, Union[str, int]]: | ||||
|     serializers = { | ||||
|         'tag': lambda tag: tag.first_name, | ||||
|         'tag_category': lambda category: category.name, | ||||
|         'comment': lambda comment: comment.comment_id, | ||||
|         'post': lambda post: post.post_id, | ||||
|     }  # type: Dict[str, Callable[[Base], Any]] | ||||
| 
 | ||||
|     resource_type = entity.__table__.name | ||||
|     assert resource_type in serializers | ||||
| 
 | ||||
|     primary_key = sa.inspection.inspect(entity).identity  # type: Any | ||||
|     assert primary_key is not None | ||||
|     assert len(primary_key) == 1 | ||||
| 
 | ||||
|     resource_name = serializers[resource_type](entity)  # type: Union[str, int] | ||||
|     assert resource_name | ||||
| 
 | ||||
|     resource_pkey = primary_key[0]  # type: Any | ||||
|     assert resource_pkey | ||||
| 
 | ||||
|     return (resource_type, resource_pkey, resource_name) | ||||
| 
 | ||||
| 
 | ||||
| def get_aux_entity( | ||||
|         session: Any, | ||||
|         get_table_info: Callable[[Base], Tuple[Base, Callable[[Base], Any]]], | ||||
|         entity: Base, | ||||
|         user: User) -> Optional[Base]: | ||||
|     table, get_column = get_table_info(entity) | ||||
|     return ( | ||||
|         session | ||||
|         .query(table) | ||||
|         .filter(get_column(table) == get_column(entity)) | ||||
|         .filter(table.user_id == user.user_id) | ||||
|         .one_or_none()) | ||||
| @ -1,2 +1,2 @@ | ||||
| from szurubooru.rest.app import application | ||||
| from szurubooru.rest.context import Context | ||||
| from szurubooru.rest.context import Context, Response | ||||
|  | ||||
| @ -2,13 +2,14 @@ import urllib.parse | ||||
| import cgi | ||||
| import json | ||||
| import re | ||||
| from typing import Dict, Any, Callable, Tuple | ||||
| from datetime import datetime | ||||
| from szurubooru import db | ||||
| from szurubooru.func import util | ||||
| from szurubooru.rest import errors, middleware, routes, context | ||||
| 
 | ||||
| 
 | ||||
| def _json_serializer(obj): | ||||
| def _json_serializer(obj: Any) -> str: | ||||
|     ''' JSON serializer for objects not serializable by default JSON code ''' | ||||
|     if isinstance(obj, datetime): | ||||
|         serial = obj.isoformat('T') + 'Z' | ||||
| @ -16,12 +17,12 @@ def _json_serializer(obj): | ||||
|     raise TypeError('Type not serializable') | ||||
| 
 | ||||
| 
 | ||||
| def _dump_json(obj): | ||||
| def _dump_json(obj: Any) -> str: | ||||
|     return json.dumps(obj, default=_json_serializer, indent=2) | ||||
| 
 | ||||
| 
 | ||||
| def _get_headers(env): | ||||
|     headers = {} | ||||
| def _get_headers(env: Dict[str, Any]) -> Dict[str, str]: | ||||
|     headers = {}  # type: Dict[str, str] | ||||
|     for key, value in env.items(): | ||||
|         if key.startswith('HTTP_'): | ||||
|             key = util.snake_case_to_upper_train_case(key[5:]) | ||||
| @ -29,7 +30,7 @@ def _get_headers(env): | ||||
|     return headers | ||||
| 
 | ||||
| 
 | ||||
| def _create_context(env): | ||||
| def _create_context(env: Dict[str, Any]) -> context.Context: | ||||
|     method = env['REQUEST_METHOD'] | ||||
|     path = '/' + env['PATH_INFO'].lstrip('/') | ||||
|     headers = _get_headers(env) | ||||
| @ -64,7 +65,9 @@ def _create_context(env): | ||||
|     return context.Context(method, path, headers, params, files) | ||||
| 
 | ||||
| 
 | ||||
| def application(env, start_response): | ||||
| def application( | ||||
|         env: Dict[str, Any], | ||||
|         start_response: Callable[[str, Any], Any]) -> Tuple[bytes]: | ||||
|     try: | ||||
|         ctx = _create_context(env) | ||||
|         if 'application/json' not in ctx.get_header('Accept'): | ||||
| @ -106,9 +109,9 @@ def application(env, start_response): | ||||
|             return (_dump_json(response).encode('utf-8'),) | ||||
| 
 | ||||
|         except Exception as ex: | ||||
|             for exception_type, handler in errors.error_handlers.items(): | ||||
|             for exception_type, ex_handler in errors.error_handlers.items(): | ||||
|                 if isinstance(ex, exception_type): | ||||
|                     handler(ex) | ||||
|                     ex_handler(ex) | ||||
|             raise | ||||
| 
 | ||||
|     except errors.BaseHttpError as ex: | ||||
|  | ||||
| @ -1,111 +1,158 @@ | ||||
| from szurubooru import errors | ||||
| from typing import Any, Union, List, Dict, Optional, cast | ||||
| from szurubooru import model, errors | ||||
| from szurubooru.func import net, file_uploads | ||||
| 
 | ||||
| 
 | ||||
| def _lower_first(source): | ||||
|     return source[0].lower() + source[1:] | ||||
| 
 | ||||
| 
 | ||||
| def _param_wrapper(func): | ||||
|     def wrapper(self, name, required=False, default=None, **kwargs): | ||||
|         # pylint: disable=protected-access | ||||
|         if name in self._params: | ||||
|             value = self._params[name] | ||||
|             try: | ||||
|                 value = func(self, value, **kwargs) | ||||
|             except errors.InvalidParameterError as ex: | ||||
|                 raise errors.InvalidParameterError( | ||||
|                     'Parameter %r is invalid: %s' % ( | ||||
|                         name, _lower_first(str(ex)))) | ||||
|             return value | ||||
|         if not required: | ||||
|             return default | ||||
|         raise errors.MissingRequiredParameterError( | ||||
|             'Required parameter %r is missing.' % name) | ||||
|     return wrapper | ||||
| MISSING = object() | ||||
| Request = Dict[str, Any] | ||||
| Response = Optional[Dict[str, Any]] | ||||
| 
 | ||||
| 
 | ||||
| class Context: | ||||
|     def __init__(self, method, url, headers=None, params=None, files=None): | ||||
|     def __init__( | ||||
|             self, | ||||
|             method: str, | ||||
|             url: str, | ||||
|             headers: Dict[str, str]=None, | ||||
|             params: Request=None, | ||||
|             files: Dict[str, bytes]=None) -> None: | ||||
|         self.method = method | ||||
|         self.url = url | ||||
|         self._headers = headers or {} | ||||
|         self._params = params or {} | ||||
|         self._files = files or {} | ||||
| 
 | ||||
|         # provided by middleware | ||||
|         # self.session = None | ||||
|         # self.user = None | ||||
|         self.user = model.User() | ||||
|         self.user.name = None | ||||
|         self.user.rank = 'anonymous' | ||||
| 
 | ||||
|     def has_header(self, name): | ||||
|         self.session = None  # type: Any | ||||
| 
 | ||||
|     def has_header(self, name: str) -> bool: | ||||
|         return name in self._headers | ||||
| 
 | ||||
|     def get_header(self, name): | ||||
|         return self._headers.get(name, None) | ||||
|     def get_header(self, name: str) -> str: | ||||
|         return self._headers.get(name, '') | ||||
| 
 | ||||
|     def has_file(self, name, allow_tokens=True): | ||||
|     def has_file(self, name: str, allow_tokens: bool=True) -> bool: | ||||
|         return ( | ||||
|             name in self._files or | ||||
|             name + 'Url' in self._params or | ||||
|             (allow_tokens and name + 'Token' in self._params)) | ||||
| 
 | ||||
|     def get_file(self, name, required=False, allow_tokens=True): | ||||
|         ret = None | ||||
|         if name in self._files: | ||||
|             ret = self._files[name] | ||||
|         elif name + 'Url' in self._params: | ||||
|             ret = net.download(self._params[name + 'Url']) | ||||
|         elif allow_tokens and name + 'Token' in self._params: | ||||
|     def get_file( | ||||
|             self, | ||||
|             name: str, | ||||
|             default: Union[object, bytes]=MISSING, | ||||
|             allow_tokens: bool=True) -> bytes: | ||||
|         if name in self._files and self._files[name]: | ||||
|             return self._files[name] | ||||
| 
 | ||||
|         if name + 'Url' in self._params: | ||||
|             return net.download(self._params[name + 'Url']) | ||||
| 
 | ||||
|         if allow_tokens and name + 'Token' in self._params: | ||||
|             ret = file_uploads.get(self._params[name + 'Token']) | ||||
|             if required and not ret: | ||||
|             if ret: | ||||
|                 return ret | ||||
|             elif default is not MISSING: | ||||
|                 raise errors.MissingOrExpiredRequiredFileError( | ||||
|                     'Required file %r is missing or has expired.' % name) | ||||
|         if required and not ret: | ||||
|             raise errors.MissingRequiredFileError( | ||||
|                 'Required file %r is missing.' % name) | ||||
|         return ret | ||||
| 
 | ||||
|     def has_param(self, name): | ||||
|         if default is not MISSING: | ||||
|             return cast(bytes, default) | ||||
|         raise errors.MissingRequiredFileError( | ||||
|             'Required file %r is missing.' % name) | ||||
| 
 | ||||
|     def has_param(self, name: str) -> bool: | ||||
|         return name in self._params | ||||
| 
 | ||||
|     @_param_wrapper | ||||
|     def get_param_as_list(self, value): | ||||
|         if not isinstance(value, list): | ||||
|     def get_param_as_list( | ||||
|             self, | ||||
|             name: str, | ||||
|             default: Union[object, List[Any]]=MISSING) -> List[Any]: | ||||
|         if name not in self._params: | ||||
|             if default is not MISSING: | ||||
|                 return cast(List[Any], default) | ||||
|             raise errors.MissingRequiredParameterError( | ||||
|                 'Required parameter %r is missing.' % name) | ||||
|         value = self._params[name] | ||||
|         if type(value) is str: | ||||
|             if ',' in value: | ||||
|                 return value.split(',') | ||||
|             return [value] | ||||
|         return value | ||||
|         if type(value) is list: | ||||
|             return value | ||||
|         raise errors.InvalidParameterError( | ||||
|             'Parameter %r must be a list.' % name) | ||||
| 
 | ||||
|     @_param_wrapper | ||||
|     def get_param_as_string(self, value): | ||||
|         if isinstance(value, list): | ||||
|             try: | ||||
|                 value = ','.join(value) | ||||
|             except TypeError: | ||||
|                 raise errors.InvalidParameterError('Expected simple string.') | ||||
|         return value | ||||
|     def get_param_as_string( | ||||
|             self, | ||||
|             name: str, | ||||
|             default: Union[object, str]=MISSING) -> str: | ||||
|         if name not in self._params: | ||||
|             if default is not MISSING: | ||||
|                 return cast(str, default) | ||||
|             raise errors.MissingRequiredParameterError( | ||||
|                 'Required parameter %r is missing.' % name) | ||||
|         value = self._params[name] | ||||
|         try: | ||||
|             if value is None: | ||||
|                 return '' | ||||
|             if type(value) is list: | ||||
|                 return ','.join(value) | ||||
|             if type(value) is int or type(value) is float: | ||||
|                 return str(value) | ||||
|             if type(value) is str: | ||||
|                 return value | ||||
|         except TypeError: | ||||
|             pass | ||||
|         raise errors.InvalidParameterError( | ||||
|             'Parameter %r must be a string value.' % name) | ||||
| 
 | ||||
|     @_param_wrapper | ||||
|     def get_param_as_int(self, value, min=None, max=None): | ||||
|     def get_param_as_int( | ||||
|             self, | ||||
|             name: str, | ||||
|             default: Union[object, int]=MISSING, | ||||
|             min: Optional[int]=None, | ||||
|             max: Optional[int]=None) -> int: | ||||
|         if name not in self._params: | ||||
|             if default is not MISSING: | ||||
|                 return cast(int, default) | ||||
|             raise errors.MissingRequiredParameterError( | ||||
|                 'Required parameter %r is missing.' % name) | ||||
|         value = self._params[name] | ||||
|         try: | ||||
|             value = int(value) | ||||
|             if min is not None and value < min: | ||||
|                 raise errors.InvalidParameterError( | ||||
|                     'Parameter %r must be at least %r.' % (name, min)) | ||||
|             if max is not None and value > max: | ||||
|                 raise errors.InvalidParameterError( | ||||
|                     'Parameter %r may not exceed %r.' % (name, max)) | ||||
|             return value | ||||
|         except (ValueError, TypeError): | ||||
|             raise errors.InvalidParameterError( | ||||
|                 'The value must be an integer.') | ||||
|         if min is not None and value < min: | ||||
|             raise errors.InvalidParameterError( | ||||
|                 'The value must be at least %r.' % min) | ||||
|         if max is not None and value > max: | ||||
|             raise errors.InvalidParameterError( | ||||
|                 'The value may not exceed %r.' % max) | ||||
|         return value | ||||
|             pass | ||||
|         raise errors.InvalidParameterError( | ||||
|             'Parameter %r must be an integer value.' % name) | ||||
| 
 | ||||
|     @_param_wrapper | ||||
|     def get_param_as_bool(self, value): | ||||
|         value = str(value).lower() | ||||
|     def get_param_as_bool( | ||||
|             self, | ||||
|             name: str, | ||||
|             default: Union[object, bool]=MISSING) -> bool: | ||||
|         if name not in self._params: | ||||
|             if default is not MISSING: | ||||
|                 return cast(bool, default) | ||||
|             raise errors.MissingRequiredParameterError( | ||||
|                 'Required parameter %r is missing.' % name) | ||||
|         value = self._params[name] | ||||
|         try: | ||||
|             value = str(value).lower() | ||||
|         except TypeError: | ||||
|             pass | ||||
|         if value in ['1', 'y', 'yes', 'yeah', 'yep', 'yup', 't', 'true']: | ||||
|             return True | ||||
|         if value in ['0', 'n', 'no', 'nope', 'f', 'false']: | ||||
|             return False | ||||
|         raise errors.InvalidParameterError( | ||||
|             'The value must be a boolean value.') | ||||
|             'Parameter %r must be a boolean value.' % name) | ||||
|  | ||||
| @ -1,11 +1,19 @@ | ||||
| from typing import Callable, Type, Dict | ||||
| 
 | ||||
| 
 | ||||
| error_handlers = {}  # pylint: disable=invalid-name | ||||
| 
 | ||||
| 
 | ||||
| class BaseHttpError(RuntimeError): | ||||
|     code = None | ||||
|     reason = None | ||||
|     code = -1 | ||||
|     reason = '' | ||||
| 
 | ||||
|     def __init__(self, name, description, title=None, extra_fields=None): | ||||
|     def __init__( | ||||
|             self, | ||||
|             name: str, | ||||
|             description: str, | ||||
|             title: str=None, | ||||
|             extra_fields: Dict[str, str]=None) -> None: | ||||
|         super().__init__() | ||||
|         # error name for programmers | ||||
|         self.name = name | ||||
| @ -52,5 +60,7 @@ class HttpInternalServerError(BaseHttpError): | ||||
|     reason = 'Internal Server Error' | ||||
| 
 | ||||
| 
 | ||||
| def handle(exception_type, handler): | ||||
| def handle( | ||||
|         exception_type: Type[Exception], | ||||
|         handler: Callable[[Exception], None]) -> None: | ||||
|     error_handlers[exception_type] = handler | ||||
|  | ||||
| @ -1,11 +1,15 @@ | ||||
| from typing import Callable | ||||
| from szurubooru.rest.context import Context | ||||
| 
 | ||||
| 
 | ||||
| # pylint: disable=invalid-name | ||||
| pre_hooks = [] | ||||
| post_hooks = [] | ||||
| pre_hooks = []  # type: List[Callable[[Context], None]] | ||||
| post_hooks = []  # type: List[Callable[[Context], None]] | ||||
| 
 | ||||
| 
 | ||||
| def pre_hook(handler): | ||||
| def pre_hook(handler: Callable) -> None: | ||||
|     pre_hooks.append(handler) | ||||
| 
 | ||||
| 
 | ||||
| def post_hook(handler): | ||||
| def post_hook(handler: Callable) -> None: | ||||
|     post_hooks.insert(0, handler) | ||||
|  | ||||
| @ -1,32 +1,36 @@ | ||||
| from typing import Callable, Dict, Any | ||||
| from collections import defaultdict | ||||
| from szurubooru.rest.context import Context, Response | ||||
| 
 | ||||
| 
 | ||||
| routes = defaultdict(dict)  # pylint: disable=invalid-name | ||||
| # pylint: disable=invalid-name | ||||
| RouteHandler = Callable[[Context, Dict[str, str]], Response] | ||||
| routes = defaultdict(dict)  # type: Dict[str, Dict[str, RouteHandler]] | ||||
| 
 | ||||
| 
 | ||||
| def get(url): | ||||
|     def wrapper(handler): | ||||
| def get(url: str) -> Callable[[RouteHandler], RouteHandler]: | ||||
|     def wrapper(handler: RouteHandler) -> RouteHandler: | ||||
|         routes[url]['GET'] = handler | ||||
|         return handler | ||||
|     return wrapper | ||||
| 
 | ||||
| 
 | ||||
| def put(url): | ||||
|     def wrapper(handler): | ||||
| def put(url: str) -> Callable[[RouteHandler], RouteHandler]: | ||||
|     def wrapper(handler: RouteHandler) -> RouteHandler: | ||||
|         routes[url]['PUT'] = handler | ||||
|         return handler | ||||
|     return wrapper | ||||
| 
 | ||||
| 
 | ||||
| def post(url): | ||||
|     def wrapper(handler): | ||||
| def post(url: str) -> Callable[[RouteHandler], RouteHandler]: | ||||
|     def wrapper(handler: RouteHandler) -> RouteHandler: | ||||
|         routes[url]['POST'] = handler | ||||
|         return handler | ||||
|     return wrapper | ||||
| 
 | ||||
| 
 | ||||
| def delete(url): | ||||
|     def wrapper(handler): | ||||
| def delete(url: str) -> Callable[[RouteHandler], RouteHandler]: | ||||
|     def wrapper(handler: RouteHandler) -> RouteHandler: | ||||
|         routes[url]['DELETE'] = handler | ||||
|         return handler | ||||
|     return wrapper | ||||
|  | ||||
| @ -1,38 +1,47 @@ | ||||
| from szurubooru.search import tokens | ||||
| from typing import Optional, Tuple, Dict, Callable | ||||
| from szurubooru.search import tokens, criteria | ||||
| from szurubooru.search.query import SearchQuery | ||||
| from szurubooru.search.typing import SaColumn, SaQuery | ||||
| 
 | ||||
| Filter = Callable[[SaQuery, Optional[criteria.BaseCriterion], bool], SaQuery] | ||||
| 
 | ||||
| 
 | ||||
| class BaseSearchConfig: | ||||
|     SORT_NONE = tokens.SortToken.SORT_NONE | ||||
|     SORT_ASC = tokens.SortToken.SORT_ASC | ||||
|     SORT_DESC = tokens.SortToken.SORT_DESC | ||||
| 
 | ||||
|     def on_search_query_parsed(self, search_query): | ||||
|     def on_search_query_parsed(self, search_query: SearchQuery) -> None: | ||||
|         pass | ||||
| 
 | ||||
|     def create_filter_query(self, _disable_eager_loads): | ||||
|     def create_filter_query(self, _disable_eager_loads: bool) -> SaQuery: | ||||
|         raise NotImplementedError() | ||||
| 
 | ||||
|     def create_count_query(self, disable_eager_loads): | ||||
|     def create_count_query(self, disable_eager_loads: bool) -> SaQuery: | ||||
|         raise NotImplementedError() | ||||
| 
 | ||||
|     def create_around_query(self): | ||||
|     def create_around_query(self) -> SaQuery: | ||||
|         raise NotImplementedError() | ||||
| 
 | ||||
|     def finalize_query(self, query: SaQuery) -> SaQuery: | ||||
|         return query | ||||
| 
 | ||||
|     @property | ||||
|     def id_column(self): | ||||
|     def id_column(self) -> SaColumn: | ||||
|         return None | ||||
| 
 | ||||
|     @property | ||||
|     def anonymous_filter(self): | ||||
|     def anonymous_filter(self) -> Optional[Filter]: | ||||
|         return None | ||||
| 
 | ||||
|     @property | ||||
|     def special_filters(self): | ||||
|     def special_filters(self) -> Dict[str, Filter]: | ||||
|         return {} | ||||
| 
 | ||||
|     @property | ||||
|     def named_filters(self): | ||||
|     def named_filters(self) -> Dict[str, Filter]: | ||||
|         return {} | ||||
| 
 | ||||
|     @property | ||||
|     def sort_columns(self): | ||||
|     def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]: | ||||
|         return {} | ||||
|  | ||||
| @ -1,59 +1,62 @@ | ||||
| from sqlalchemy.sql.expression import func | ||||
| from szurubooru import db | ||||
| from typing import Tuple, Dict | ||||
| import sqlalchemy as sa | ||||
| from szurubooru import db, model | ||||
| from szurubooru.search.typing import SaColumn, SaQuery | ||||
| from szurubooru.search.configs import util as search_util | ||||
| from szurubooru.search.configs.base_search_config import BaseSearchConfig | ||||
| from szurubooru.search.configs.base_search_config import ( | ||||
|     BaseSearchConfig, Filter) | ||||
| 
 | ||||
| 
 | ||||
| class CommentSearchConfig(BaseSearchConfig): | ||||
|     def create_filter_query(self, _disable_eager_loads): | ||||
|         return db.session.query(db.Comment).join(db.User) | ||||
|     def create_filter_query(self, _disable_eager_loads: bool) -> SaQuery: | ||||
|         return db.session.query(model.Comment).join(model.User) | ||||
| 
 | ||||
|     def create_count_query(self, disable_eager_loads): | ||||
|     def create_count_query(self, disable_eager_loads: bool) -> SaQuery: | ||||
|         return self.create_filter_query(disable_eager_loads) | ||||
| 
 | ||||
|     def create_around_query(self): | ||||
|     def create_around_query(self) -> SaQuery: | ||||
|         raise NotImplementedError() | ||||
| 
 | ||||
|     def finalize_query(self, query): | ||||
|         return query.order_by(db.Comment.creation_time.desc()) | ||||
|     def finalize_query(self, query: SaQuery) -> SaQuery: | ||||
|         return query.order_by(model.Comment.creation_time.desc()) | ||||
| 
 | ||||
|     @property | ||||
|     def anonymous_filter(self): | ||||
|         return search_util.create_str_filter(db.Comment.text) | ||||
|     def anonymous_filter(self) -> SaQuery: | ||||
|         return search_util.create_str_filter(model.Comment.text) | ||||
| 
 | ||||
|     @property | ||||
|     def named_filters(self): | ||||
|     def named_filters(self) -> Dict[str, Filter]: | ||||
|         return { | ||||
|             'id': search_util.create_num_filter(db.Comment.comment_id), | ||||
|             'post': search_util.create_num_filter(db.Comment.post_id), | ||||
|             'user': search_util.create_str_filter(db.User.name), | ||||
|             'author': search_util.create_str_filter(db.User.name), | ||||
|             'text': search_util.create_str_filter(db.Comment.text), | ||||
|             'id': search_util.create_num_filter(model.Comment.comment_id), | ||||
|             'post': search_util.create_num_filter(model.Comment.post_id), | ||||
|             'user': search_util.create_str_filter(model.User.name), | ||||
|             'author': search_util.create_str_filter(model.User.name), | ||||
|             'text': search_util.create_str_filter(model.Comment.text), | ||||
|             'creation-date': | ||||
|                 search_util.create_date_filter(db.Comment.creation_time), | ||||
|                 search_util.create_date_filter(model.Comment.creation_time), | ||||
|             'creation-time': | ||||
|                 search_util.create_date_filter(db.Comment.creation_time), | ||||
|                 search_util.create_date_filter(model.Comment.creation_time), | ||||
|             'last-edit-date': | ||||
|                 search_util.create_date_filter(db.Comment.last_edit_time), | ||||
|                 search_util.create_date_filter(model.Comment.last_edit_time), | ||||
|             'last-edit-time': | ||||
|                 search_util.create_date_filter(db.Comment.last_edit_time), | ||||
|                 search_util.create_date_filter(model.Comment.last_edit_time), | ||||
|             'edit-date': | ||||
|                 search_util.create_date_filter(db.Comment.last_edit_time), | ||||
|                 search_util.create_date_filter(model.Comment.last_edit_time), | ||||
|             'edit-time': | ||||
|                 search_util.create_date_filter(db.Comment.last_edit_time), | ||||
|                 search_util.create_date_filter(model.Comment.last_edit_time), | ||||
|         } | ||||
| 
 | ||||
|     @property | ||||
|     def sort_columns(self): | ||||
|     def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]: | ||||
|         return { | ||||
|             'random': (func.random(), None), | ||||
|             'user': (db.User.name, self.SORT_ASC), | ||||
|             'author': (db.User.name, self.SORT_ASC), | ||||
|             'post': (db.Comment.post_id, self.SORT_DESC), | ||||
|             'creation-date': (db.Comment.creation_time, self.SORT_DESC), | ||||
|             'creation-time': (db.Comment.creation_time, self.SORT_DESC), | ||||
|             'last-edit-date': (db.Comment.last_edit_time, self.SORT_DESC), | ||||
|             'last-edit-time': (db.Comment.last_edit_time, self.SORT_DESC), | ||||
|             'edit-date': (db.Comment.last_edit_time, self.SORT_DESC), | ||||
|             'edit-time': (db.Comment.last_edit_time, self.SORT_DESC), | ||||
|             'random': (sa.sql.expression.func.random(), self.SORT_NONE), | ||||
|             'user': (model.User.name, self.SORT_ASC), | ||||
|             'author': (model.User.name, self.SORT_ASC), | ||||
|             'post': (model.Comment.post_id, self.SORT_DESC), | ||||
|             'creation-date': (model.Comment.creation_time, self.SORT_DESC), | ||||
|             'creation-time': (model.Comment.creation_time, self.SORT_DESC), | ||||
|             'last-edit-date': (model.Comment.last_edit_time, self.SORT_DESC), | ||||
|             'last-edit-time': (model.Comment.last_edit_time, self.SORT_DESC), | ||||
|             'edit-date': (model.Comment.last_edit_time, self.SORT_DESC), | ||||
|             'edit-time': (model.Comment.last_edit_time, self.SORT_DESC), | ||||
|         } | ||||
|  | ||||
| @ -1,13 +1,16 @@ | ||||
| from sqlalchemy.orm import subqueryload, lazyload, defer, aliased | ||||
| from sqlalchemy.sql.expression import func | ||||
| from szurubooru import db, errors | ||||
| from typing import Any, Optional, Tuple, Dict | ||||
| import sqlalchemy as sa | ||||
| from szurubooru import db, model, errors | ||||
| from szurubooru.func import util | ||||
| from szurubooru.search import criteria, tokens | ||||
| from szurubooru.search.typing import SaColumn, SaQuery | ||||
| from szurubooru.search.query import SearchQuery | ||||
| from szurubooru.search.configs import util as search_util | ||||
| from szurubooru.search.configs.base_search_config import BaseSearchConfig | ||||
| from szurubooru.search.configs.base_search_config import ( | ||||
|     BaseSearchConfig, Filter) | ||||
| 
 | ||||
| 
 | ||||
| def _enum_transformer(available_values, value): | ||||
| def _enum_transformer(available_values: Dict[str, Any], value: str) -> str: | ||||
|     try: | ||||
|         return available_values[value.lower()] | ||||
|     except KeyError: | ||||
| @ -16,71 +19,82 @@ def _enum_transformer(available_values, value): | ||||
|                 value, list(sorted(available_values.keys())))) | ||||
| 
 | ||||
| 
 | ||||
| def _type_transformer(value): | ||||
| def _type_transformer(value: str) -> str: | ||||
|     available_values = { | ||||
|         'image': db.Post.TYPE_IMAGE, | ||||
|         'animation': db.Post.TYPE_ANIMATION, | ||||
|         'animated': db.Post.TYPE_ANIMATION, | ||||
|         'anim': db.Post.TYPE_ANIMATION, | ||||
|         'gif': db.Post.TYPE_ANIMATION, | ||||
|         'video': db.Post.TYPE_VIDEO, | ||||
|         'webm': db.Post.TYPE_VIDEO, | ||||
|         'flash': db.Post.TYPE_FLASH, | ||||
|         'swf': db.Post.TYPE_FLASH, | ||||
|         'image': model.Post.TYPE_IMAGE, | ||||
|         'animation': model.Post.TYPE_ANIMATION, | ||||
|         'animated': model.Post.TYPE_ANIMATION, | ||||
|         'anim': model.Post.TYPE_ANIMATION, | ||||
|         'gif': model.Post.TYPE_ANIMATION, | ||||
|         'video': model.Post.TYPE_VIDEO, | ||||
|         'webm': model.Post.TYPE_VIDEO, | ||||
|         'flash': model.Post.TYPE_FLASH, | ||||
|         'swf': model.Post.TYPE_FLASH, | ||||
|     } | ||||
|     return _enum_transformer(available_values, value) | ||||
| 
 | ||||
| 
 | ||||
| def _safety_transformer(value): | ||||
| def _safety_transformer(value: str) -> str: | ||||
|     available_values = { | ||||
|         'safe': db.Post.SAFETY_SAFE, | ||||
|         'sketchy': db.Post.SAFETY_SKETCHY, | ||||
|         'questionable': db.Post.SAFETY_SKETCHY, | ||||
|         'unsafe': db.Post.SAFETY_UNSAFE, | ||||
|         'safe': model.Post.SAFETY_SAFE, | ||||
|         'sketchy': model.Post.SAFETY_SKETCHY, | ||||
|         'questionable': model.Post.SAFETY_SKETCHY, | ||||
|         'unsafe': model.Post.SAFETY_UNSAFE, | ||||
|     } | ||||
|     return _enum_transformer(available_values, value) | ||||
| 
 | ||||
| 
 | ||||
| def _create_score_filter(score): | ||||
|     def wrapper(query, criterion, negated): | ||||
| def _create_score_filter(score: int) -> Filter: | ||||
|     def wrapper( | ||||
|             query: SaQuery, | ||||
|             criterion: Optional[criteria.BaseCriterion], | ||||
|             negated: bool) -> SaQuery: | ||||
|         assert criterion | ||||
|         if not getattr(criterion, 'internal', False): | ||||
|             raise errors.SearchError( | ||||
|                 'Votes cannot be seen publicly. Did you mean %r?' | ||||
|                 % 'special:liked') | ||||
|         user_alias = aliased(db.User) | ||||
|         score_alias = aliased(db.PostScore) | ||||
|         user_alias = sa.orm.aliased(model.User) | ||||
|         score_alias = sa.orm.aliased(model.PostScore) | ||||
|         expr = score_alias.score == score | ||||
|         expr = expr & search_util.apply_str_criterion_to_column( | ||||
|             user_alias.name, criterion) | ||||
|         if negated: | ||||
|             expr = ~expr | ||||
|         ret = query \ | ||||
|             .join(score_alias, score_alias.post_id == db.Post.post_id) \ | ||||
|             .join(score_alias, score_alias.post_id == model.Post.post_id) \ | ||||
|             .join(user_alias, user_alias.user_id == score_alias.user_id) \ | ||||
|             .filter(expr) | ||||
|         return ret | ||||
|     return wrapper | ||||
| 
 | ||||
| 
 | ||||
| def _create_user_filter(): | ||||
|     def wrapper(query, criterion, negated): | ||||
| def _create_user_filter() -> Filter: | ||||
|     def wrapper( | ||||
|             query: SaQuery, | ||||
|             criterion: Optional[criteria.BaseCriterion], | ||||
|             negated: bool) -> SaQuery: | ||||
|         assert criterion | ||||
|         if isinstance(criterion, criteria.PlainCriterion) \ | ||||
|                 and not criterion.value: | ||||
|             # pylint: disable=singleton-comparison | ||||
|             expr = db.Post.user_id == None | ||||
|             expr = model.Post.user_id == None | ||||
|             if negated: | ||||
|                 expr = ~expr | ||||
|             return query.filter(expr) | ||||
|         return search_util.create_subquery_filter( | ||||
|             db.Post.user_id, | ||||
|             db.User.user_id, | ||||
|             db.User.name, | ||||
|             model.Post.user_id, | ||||
|             model.User.user_id, | ||||
|             model.User.name, | ||||
|             search_util.create_str_filter)(query, criterion, negated) | ||||
|     return wrapper | ||||
| 
 | ||||
| 
 | ||||
| class PostSearchConfig(BaseSearchConfig): | ||||
|     def on_search_query_parsed(self, search_query): | ||||
|     def __init__(self) -> None: | ||||
|         self.user = None  # type: Optional[model.User] | ||||
| 
 | ||||
|     def on_search_query_parsed(self, search_query: SearchQuery) -> SaQuery: | ||||
|         new_special_tokens = [] | ||||
|         for token in search_query.special_tokens: | ||||
|             if token.value in ('fav', 'liked', 'disliked'): | ||||
| @ -91,7 +105,7 @@ class PostSearchConfig(BaseSearchConfig): | ||||
|                 criterion = criteria.PlainCriterion( | ||||
|                     original_text=self.user.name, | ||||
|                     value=self.user.name) | ||||
|                 criterion.internal = True | ||||
|                 setattr(criterion, 'internal', True) | ||||
|                 search_query.named_tokens.append( | ||||
|                     tokens.NamedToken( | ||||
|                         name=token.value, | ||||
| @ -101,160 +115,324 @@ class PostSearchConfig(BaseSearchConfig): | ||||
|                 new_special_tokens.append(token) | ||||
|         search_query.special_tokens = new_special_tokens | ||||
| 
 | ||||
|     def create_around_query(self): | ||||
|         return db.session.query(db.Post).options(lazyload('*')) | ||||
|     def create_around_query(self) -> SaQuery: | ||||
|         return db.session.query(model.Post).options(sa.orm.lazyload('*')) | ||||
| 
 | ||||
|     def create_filter_query(self, disable_eager_loads): | ||||
|         strategy = lazyload if disable_eager_loads else subqueryload | ||||
|         return db.session.query(db.Post) \ | ||||
|     def create_filter_query(self, disable_eager_loads: bool) -> SaQuery: | ||||
|         strategy = ( | ||||
|             sa.orm.lazyload | ||||
|             if disable_eager_loads | ||||
|             else sa.orm.subqueryload) | ||||
|         return db.session.query(model.Post) \ | ||||
|             .options( | ||||
|                 lazyload('*'), | ||||
|                 sa.orm.lazyload('*'), | ||||
|                 # use config optimized for official client | ||||
|                 # defer(db.Post.score), | ||||
|                 # defer(db.Post.favorite_count), | ||||
|                 # defer(db.Post.comment_count), | ||||
|                 defer(db.Post.last_favorite_time), | ||||
|                 defer(db.Post.feature_count), | ||||
|                 defer(db.Post.last_feature_time), | ||||
|                 defer(db.Post.last_comment_creation_time), | ||||
|                 defer(db.Post.last_comment_edit_time), | ||||
|                 defer(db.Post.note_count), | ||||
|                 defer(db.Post.tag_count), | ||||
|                 strategy(db.Post.tags).subqueryload(db.Tag.names), | ||||
|                 strategy(db.Post.tags).defer(db.Tag.post_count), | ||||
|                 strategy(db.Post.tags).lazyload(db.Tag.implications), | ||||
|                 strategy(db.Post.tags).lazyload(db.Tag.suggestions)) | ||||
|                 # sa.orm.defer(model.Post.score), | ||||
|                 # sa.orm.defer(model.Post.favorite_count), | ||||
|                 # sa.orm.defer(model.Post.comment_count), | ||||
|                 sa.orm.defer(model.Post.last_favorite_time), | ||||
|                 sa.orm.defer(model.Post.feature_count), | ||||
|                 sa.orm.defer(model.Post.last_feature_time), | ||||
|                 sa.orm.defer(model.Post.last_comment_creation_time), | ||||
|                 sa.orm.defer(model.Post.last_comment_edit_time), | ||||
|                 sa.orm.defer(model.Post.note_count), | ||||
|                 sa.orm.defer(model.Post.tag_count), | ||||
|                 strategy(model.Post.tags).subqueryload(model.Tag.names), | ||||
|                 strategy(model.Post.tags).defer(model.Tag.post_count), | ||||
|                 strategy(model.Post.tags).lazyload(model.Tag.implications), | ||||
|                 strategy(model.Post.tags).lazyload(model.Tag.suggestions)) | ||||
| 
 | ||||
|     def create_count_query(self, _disable_eager_loads): | ||||
|         return db.session.query(db.Post) | ||||
|     def create_count_query(self, _disable_eager_loads: bool) -> SaQuery: | ||||
|         return db.session.query(model.Post) | ||||
| 
 | ||||
|     def finalize_query(self, query): | ||||
|         return query.order_by(db.Post.post_id.desc()) | ||||
|     def finalize_query(self, query: SaQuery) -> SaQuery: | ||||
|         return query.order_by(model.Post.post_id.desc()) | ||||
| 
 | ||||
|     @property | ||||
|     def id_column(self): | ||||
|         return db.Post.post_id | ||||
|     def id_column(self) -> SaColumn: | ||||
|         return model.Post.post_id | ||||
| 
 | ||||
|     @property | ||||
|     def anonymous_filter(self): | ||||
|     def anonymous_filter(self) -> Optional[Filter]: | ||||
|         return search_util.create_subquery_filter( | ||||
|             db.Post.post_id, | ||||
|             db.PostTag.post_id, | ||||
|             db.TagName.name, | ||||
|             model.Post.post_id, | ||||
|             model.PostTag.post_id, | ||||
|             model.TagName.name, | ||||
|             search_util.create_str_filter, | ||||
|             lambda subquery: subquery.join(db.Tag).join(db.TagName)) | ||||
|             lambda subquery: subquery.join(model.Tag).join(model.TagName)) | ||||
| 
 | ||||
|     @property | ||||
|     def named_filters(self): | ||||
|         return util.unalias_dict({ | ||||
|             'id': search_util.create_num_filter(db.Post.post_id), | ||||
|             'tag': search_util.create_subquery_filter( | ||||
|                 db.Post.post_id, | ||||
|                 db.PostTag.post_id, | ||||
|                 db.TagName.name, | ||||
|                 search_util.create_str_filter, | ||||
|                 lambda subquery: subquery.join(db.Tag).join(db.TagName)), | ||||
|             'score': search_util.create_num_filter(db.Post.score), | ||||
|             ('uploader', 'upload', 'submit'): | ||||
|                 _create_user_filter(), | ||||
|             'comment': search_util.create_subquery_filter( | ||||
|                 db.Post.post_id, | ||||
|                 db.Comment.post_id, | ||||
|                 db.User.name, | ||||
|                 search_util.create_str_filter, | ||||
|                 lambda subquery: subquery.join(db.User)), | ||||
|             'fav': search_util.create_subquery_filter( | ||||
|                 db.Post.post_id, | ||||
|                 db.PostFavorite.post_id, | ||||
|                 db.User.name, | ||||
|                 search_util.create_str_filter, | ||||
|                 lambda subquery: subquery.join(db.User)), | ||||
|             'liked': _create_score_filter(1), | ||||
|             'disliked': _create_score_filter(-1), | ||||
|             'tag-count': search_util.create_num_filter(db.Post.tag_count), | ||||
|             'comment-count': | ||||
|                 search_util.create_num_filter(db.Post.comment_count), | ||||
|             'fav-count': | ||||
|                 search_util.create_num_filter(db.Post.favorite_count), | ||||
|             'note-count': search_util.create_num_filter(db.Post.note_count), | ||||
|             'relation-count': | ||||
|                 search_util.create_num_filter(db.Post.relation_count), | ||||
|             'feature-count': | ||||
|                 search_util.create_num_filter(db.Post.feature_count), | ||||
|             'type': | ||||
|     def named_filters(self) -> Dict[str, Filter]: | ||||
|         return util.unalias_dict([ | ||||
|             ( | ||||
|                 ['id'], | ||||
|                 search_util.create_num_filter(model.Post.post_id) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['tag'], | ||||
|                 search_util.create_subquery_filter( | ||||
|                     model.Post.post_id, | ||||
|                     model.PostTag.post_id, | ||||
|                     model.TagName.name, | ||||
|                     search_util.create_str_filter, | ||||
|                     lambda subquery: | ||||
|                         subquery.join(model.Tag).join(model.TagName)) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['score'], | ||||
|                 search_util.create_num_filter(model.Post.score) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['uploader', 'upload', 'submit'], | ||||
|                 _create_user_filter() | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['comment'], | ||||
|                 search_util.create_subquery_filter( | ||||
|                     model.Post.post_id, | ||||
|                     model.Comment.post_id, | ||||
|                     model.User.name, | ||||
|                     search_util.create_str_filter, | ||||
|                     lambda subquery: subquery.join(model.User)) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['fav'], | ||||
|                 search_util.create_subquery_filter( | ||||
|                     model.Post.post_id, | ||||
|                     model.PostFavorite.post_id, | ||||
|                     model.User.name, | ||||
|                     search_util.create_str_filter, | ||||
|                     lambda subquery: subquery.join(model.User)) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['liked'], | ||||
|                 _create_score_filter(1) | ||||
|             ), | ||||
|             ( | ||||
|                 ['disliked'], | ||||
|                 _create_score_filter(-1) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['tag-count'], | ||||
|                 search_util.create_num_filter(model.Post.tag_count) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['comment-count'], | ||||
|                 search_util.create_num_filter(model.Post.comment_count) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['fav-count'], | ||||
|                 search_util.create_num_filter(model.Post.favorite_count) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['note-count'], | ||||
|                 search_util.create_num_filter(model.Post.note_count) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['relation-count'], | ||||
|                 search_util.create_num_filter(model.Post.relation_count) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['feature-count'], | ||||
|                 search_util.create_num_filter(model.Post.feature_count) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['type'], | ||||
|                 search_util.create_str_filter( | ||||
|                     db.Post.type, _type_transformer), | ||||
|             'content-checksum': search_util.create_str_filter( | ||||
|                 db.Post.checksum), | ||||
|             'file-size': search_util.create_num_filter(db.Post.file_size), | ||||
|             ('image-width', 'width'): | ||||
|                 search_util.create_num_filter(db.Post.canvas_width), | ||||
|             ('image-height', 'height'): | ||||
|                 search_util.create_num_filter(db.Post.canvas_height), | ||||
|             ('image-area', 'area'): | ||||
|                 search_util.create_num_filter(db.Post.canvas_area), | ||||
|             ('creation-date', 'creation-time', 'date', 'time'): | ||||
|                 search_util.create_date_filter(db.Post.creation_time), | ||||
|             ('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'): | ||||
|                 search_util.create_date_filter(db.Post.last_edit_time), | ||||
|             ('comment-date', 'comment-time'): | ||||
|                     model.Post.type, _type_transformer) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['content-checksum'], | ||||
|                 search_util.create_str_filter(model.Post.checksum) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['file-size'], | ||||
|                 search_util.create_num_filter(model.Post.file_size) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['image-width', 'width'], | ||||
|                 search_util.create_num_filter(model.Post.canvas_width) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['image-height', 'height'], | ||||
|                 search_util.create_num_filter(model.Post.canvas_height) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['image-area', 'area'], | ||||
|                 search_util.create_num_filter(model.Post.canvas_area) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['creation-date', 'creation-time', 'date', 'time'], | ||||
|                 search_util.create_date_filter(model.Post.creation_time) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'], | ||||
|                 search_util.create_date_filter(model.Post.last_edit_time) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['comment-date', 'comment-time'], | ||||
|                 search_util.create_date_filter( | ||||
|                     db.Post.last_comment_creation_time), | ||||
|             ('fav-date', 'fav-time'): | ||||
|                 search_util.create_date_filter(db.Post.last_favorite_time), | ||||
|             ('feature-date', 'feature-time'): | ||||
|                 search_util.create_date_filter(db.Post.last_feature_time), | ||||
|             ('safety', 'rating'): | ||||
|                     model.Post.last_comment_creation_time) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['fav-date', 'fav-time'], | ||||
|                 search_util.create_date_filter(model.Post.last_favorite_time) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['feature-date', 'feature-time'], | ||||
|                 search_util.create_date_filter(model.Post.last_feature_time) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['safety', 'rating'], | ||||
|                 search_util.create_str_filter( | ||||
|                     db.Post.safety, _safety_transformer), | ||||
|         }) | ||||
|                     model.Post.safety, _safety_transformer) | ||||
|             ), | ||||
|         ]) | ||||
| 
 | ||||
|     @property | ||||
|     def sort_columns(self): | ||||
|         return util.unalias_dict({ | ||||
|             'random': (func.random(), None), | ||||
|             'id': (db.Post.post_id, self.SORT_DESC), | ||||
|             'score': (db.Post.score, self.SORT_DESC), | ||||
|             'tag-count': (db.Post.tag_count, self.SORT_DESC), | ||||
|             'comment-count': (db.Post.comment_count, self.SORT_DESC), | ||||
|             'fav-count': (db.Post.favorite_count, self.SORT_DESC), | ||||
|             'note-count': (db.Post.note_count, self.SORT_DESC), | ||||
|             'relation-count': (db.Post.relation_count, self.SORT_DESC), | ||||
|             'feature-count': (db.Post.feature_count, self.SORT_DESC), | ||||
|             'file-size': (db.Post.file_size, self.SORT_DESC), | ||||
|             ('image-width', 'width'): | ||||
|                 (db.Post.canvas_width, self.SORT_DESC), | ||||
|             ('image-height', 'height'): | ||||
|                 (db.Post.canvas_height, self.SORT_DESC), | ||||
|             ('image-area', 'area'): | ||||
|                 (db.Post.canvas_area, self.SORT_DESC), | ||||
|             ('creation-date', 'creation-time', 'date', 'time'): | ||||
|                 (db.Post.creation_time, self.SORT_DESC), | ||||
|             ('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'): | ||||
|                 (db.Post.last_edit_time, self.SORT_DESC), | ||||
|             ('comment-date', 'comment-time'): | ||||
|                 (db.Post.last_comment_creation_time, self.SORT_DESC), | ||||
|             ('fav-date', 'fav-time'): | ||||
|                 (db.Post.last_favorite_time, self.SORT_DESC), | ||||
|             ('feature-date', 'feature-time'): | ||||
|                 (db.Post.last_feature_time, self.SORT_DESC), | ||||
|         }) | ||||
|     def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]: | ||||
|         return util.unalias_dict([ | ||||
|             ( | ||||
|                 ['random'], | ||||
|                 (sa.sql.expression.func.random(), self.SORT_NONE) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['id'], | ||||
|                 (model.Post.post_id, self.SORT_DESC) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['score'], | ||||
|                 (model.Post.score, self.SORT_DESC) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['tag-count'], | ||||
|                 (model.Post.tag_count, self.SORT_DESC) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['comment-count'], | ||||
|                 (model.Post.comment_count, self.SORT_DESC) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['fav-count'], | ||||
|                 (model.Post.favorite_count, self.SORT_DESC) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['note-count'], | ||||
|                 (model.Post.note_count, self.SORT_DESC) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['relation-count'], | ||||
|                 (model.Post.relation_count, self.SORT_DESC) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['feature-count'], | ||||
|                 (model.Post.feature_count, self.SORT_DESC) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['file-size'], | ||||
|                 (model.Post.file_size, self.SORT_DESC) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['image-width', 'width'], | ||||
|                 (model.Post.canvas_width, self.SORT_DESC) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['image-height', 'height'], | ||||
|                 (model.Post.canvas_height, self.SORT_DESC) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['image-area', 'area'], | ||||
|                 (model.Post.canvas_area, self.SORT_DESC) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['creation-date', 'creation-time', 'date', 'time'], | ||||
|                 (model.Post.creation_time, self.SORT_DESC) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'], | ||||
|                 (model.Post.last_edit_time, self.SORT_DESC) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['comment-date', 'comment-time'], | ||||
|                 (model.Post.last_comment_creation_time, self.SORT_DESC) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['fav-date', 'fav-time'], | ||||
|                 (model.Post.last_favorite_time, self.SORT_DESC) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['feature-date', 'feature-time'], | ||||
|                 (model.Post.last_feature_time, self.SORT_DESC) | ||||
|             ), | ||||
|         ]) | ||||
| 
 | ||||
|     @property | ||||
|     def special_filters(self): | ||||
|     def special_filters(self) -> Dict[str, Filter]: | ||||
|         return { | ||||
|             # handled by parsed | ||||
|             'fav': None, | ||||
|             'liked': None, | ||||
|             'disliked': None, | ||||
|             # handled by parser | ||||
|             'fav': self.noop_filter, | ||||
|             'liked': self.noop_filter, | ||||
|             'disliked': self.noop_filter, | ||||
|             'tumbleweed': self.tumbleweed_filter, | ||||
|         } | ||||
| 
 | ||||
|     def tumbleweed_filter(self, query, negated): | ||||
|         expr = \ | ||||
|             (db.Post.comment_count == 0) \ | ||||
|             & (db.Post.favorite_count == 0) \ | ||||
|             & (db.Post.score == 0) | ||||
|     def noop_filter( | ||||
|             self, | ||||
|             query: SaQuery, | ||||
|             _criterion: Optional[criteria.BaseCriterion], | ||||
|             _negated: bool) -> SaQuery: | ||||
|         return query | ||||
| 
 | ||||
|     def tumbleweed_filter( | ||||
|             self, | ||||
|             query: SaQuery, | ||||
|             _criterion: Optional[criteria.BaseCriterion], | ||||
|             negated: bool) -> SaQuery: | ||||
|         expr = ( | ||||
|             (model.Post.comment_count == 0) | ||||
|             & (model.Post.favorite_count == 0) | ||||
|             & (model.Post.score == 0)) | ||||
|         if negated: | ||||
|             expr = ~expr | ||||
|         return query.filter(expr) | ||||
|  | ||||
| @ -1,28 +1,37 @@ | ||||
| from szurubooru import db | ||||
| from typing import Dict | ||||
| from szurubooru import db, model | ||||
| from szurubooru.search.typing import SaQuery | ||||
| from szurubooru.search.configs import util as search_util | ||||
| from szurubooru.search.configs.base_search_config import BaseSearchConfig | ||||
| from szurubooru.search.configs.base_search_config import ( | ||||
|     BaseSearchConfig, Filter) | ||||
| 
 | ||||
| 
 | ||||
| class SnapshotSearchConfig(BaseSearchConfig): | ||||
|     def create_filter_query(self, _disable_eager_loads): | ||||
|         return db.session.query(db.Snapshot) | ||||
|     def create_filter_query(self, _disable_eager_loads: bool) -> SaQuery: | ||||
|         return db.session.query(model.Snapshot) | ||||
| 
 | ||||
|     def create_count_query(self, _disable_eager_loads): | ||||
|         return db.session.query(db.Snapshot) | ||||
|     def create_count_query(self, _disable_eager_loads: bool) -> SaQuery: | ||||
|         return db.session.query(model.Snapshot) | ||||
| 
 | ||||
|     def create_around_query(self): | ||||
|     def create_around_query(self) -> SaQuery: | ||||
|         raise NotImplementedError() | ||||
| 
 | ||||
|     def finalize_query(self, query): | ||||
|         return query.order_by(db.Snapshot.creation_time.desc()) | ||||
|     def finalize_query(self, query: SaQuery) -> SaQuery: | ||||
|         return query.order_by(model.Snapshot.creation_time.desc()) | ||||
| 
 | ||||
|     @property | ||||
|     def named_filters(self): | ||||
|     def named_filters(self) -> Dict[str, Filter]: | ||||
|         return { | ||||
|             'type': search_util.create_str_filter(db.Snapshot.resource_type), | ||||
|             'id': search_util.create_str_filter(db.Snapshot.resource_name), | ||||
|             'date': search_util.create_date_filter(db.Snapshot.creation_time), | ||||
|             'time': search_util.create_date_filter(db.Snapshot.creation_time), | ||||
|             'operation': search_util.create_str_filter(db.Snapshot.operation), | ||||
|             'user': search_util.create_str_filter(db.User.name), | ||||
|             'type': | ||||
|                 search_util.create_str_filter(model.Snapshot.resource_type), | ||||
|             'id': | ||||
|                 search_util.create_str_filter(model.Snapshot.resource_name), | ||||
|             'date': | ||||
|                 search_util.create_date_filter(model.Snapshot.creation_time), | ||||
|             'time': | ||||
|                 search_util.create_date_filter(model.Snapshot.creation_time), | ||||
|             'operation': | ||||
|                 search_util.create_str_filter(model.Snapshot.operation), | ||||
|             'user': | ||||
|                 search_util.create_str_filter(model.User.name), | ||||
|         } | ||||
|  | ||||
| @ -1,79 +1,134 @@ | ||||
| from sqlalchemy.orm import subqueryload, lazyload, defer | ||||
| from sqlalchemy.sql.expression import func | ||||
| from szurubooru import db | ||||
| from typing import Tuple, Dict | ||||
| import sqlalchemy as sa | ||||
| from szurubooru import db, model | ||||
| from szurubooru.func import util | ||||
| from szurubooru.search.typing import SaColumn, SaQuery | ||||
| from szurubooru.search.configs import util as search_util | ||||
| from szurubooru.search.configs.base_search_config import BaseSearchConfig | ||||
| from szurubooru.search.configs.base_search_config import ( | ||||
|     BaseSearchConfig, Filter) | ||||
| 
 | ||||
| 
 | ||||
| class TagSearchConfig(BaseSearchConfig): | ||||
|     def create_filter_query(self, _disable_eager_loads): | ||||
|         strategy = lazyload if _disable_eager_loads else subqueryload | ||||
|         return db.session.query(db.Tag) \ | ||||
|             .join(db.TagCategory) \ | ||||
|     def create_filter_query(self, _disable_eager_loads: bool) -> SaQuery: | ||||
|         strategy = ( | ||||
|             sa.orm.lazyload | ||||
|             if _disable_eager_loads | ||||
|             else sa.orm.subqueryload) | ||||
|         return db.session.query(model.Tag) \ | ||||
|             .join(model.TagCategory) \ | ||||
|             .options( | ||||
|                 defer(db.Tag.first_name), | ||||
|                 defer(db.Tag.suggestion_count), | ||||
|                 defer(db.Tag.implication_count), | ||||
|                 defer(db.Tag.post_count), | ||||
|                 strategy(db.Tag.names), | ||||
|                 strategy(db.Tag.suggestions).joinedload(db.Tag.names), | ||||
|                 strategy(db.Tag.implications).joinedload(db.Tag.names)) | ||||
|                 sa.orm.defer(model.Tag.first_name), | ||||
|                 sa.orm.defer(model.Tag.suggestion_count), | ||||
|                 sa.orm.defer(model.Tag.implication_count), | ||||
|                 sa.orm.defer(model.Tag.post_count), | ||||
|                 strategy(model.Tag.names), | ||||
|                 strategy(model.Tag.suggestions).joinedload(model.Tag.names), | ||||
|                 strategy(model.Tag.implications).joinedload(model.Tag.names)) | ||||
| 
 | ||||
|     def create_count_query(self, _disable_eager_loads): | ||||
|         return db.session.query(db.Tag) | ||||
|     def create_count_query(self, _disable_eager_loads: bool) -> SaQuery: | ||||
|         return db.session.query(model.Tag) | ||||
| 
 | ||||
|     def create_around_query(self): | ||||
|     def create_around_query(self) -> SaQuery: | ||||
|         raise NotImplementedError() | ||||
| 
 | ||||
|     def finalize_query(self, query): | ||||
|         return query.order_by(db.Tag.first_name.asc()) | ||||
|     def finalize_query(self, query: SaQuery) -> SaQuery: | ||||
|         return query.order_by(model.Tag.first_name.asc()) | ||||
| 
 | ||||
|     @property | ||||
|     def anonymous_filter(self): | ||||
|     def anonymous_filter(self) -> Filter: | ||||
|         return search_util.create_subquery_filter( | ||||
|             db.Tag.tag_id, | ||||
|             db.TagName.tag_id, | ||||
|             db.TagName.name, | ||||
|             model.Tag.tag_id, | ||||
|             model.TagName.tag_id, | ||||
|             model.TagName.name, | ||||
|             search_util.create_str_filter) | ||||
| 
 | ||||
|     @property | ||||
|     def named_filters(self): | ||||
|         return util.unalias_dict({ | ||||
|             'name': search_util.create_subquery_filter( | ||||
|                 db.Tag.tag_id, | ||||
|                 db.TagName.tag_id, | ||||
|                 db.TagName.name, | ||||
|                 search_util.create_str_filter), | ||||
|             'category': search_util.create_subquery_filter( | ||||
|                 db.Tag.category_id, | ||||
|                 db.TagCategory.tag_category_id, | ||||
|                 db.TagCategory.name, | ||||
|                 search_util.create_str_filter), | ||||
|             ('creation-date', 'creation-time'): | ||||
|                 search_util.create_date_filter(db.Tag.creation_time), | ||||
|             ('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'): | ||||
|                 search_util.create_date_filter(db.Tag.last_edit_time), | ||||
|             ('usage-count', 'post-count', 'usages'): | ||||
|                 search_util.create_num_filter(db.Tag.post_count), | ||||
|             'suggestion-count': | ||||
|                 search_util.create_num_filter(db.Tag.suggestion_count), | ||||
|             'implication-count': | ||||
|                 search_util.create_num_filter(db.Tag.implication_count), | ||||
|         }) | ||||
|     def named_filters(self) -> Dict[str, Filter]: | ||||
|         return util.unalias_dict([ | ||||
|             ( | ||||
|                 ['name'], | ||||
|                 search_util.create_subquery_filter( | ||||
|                     model.Tag.tag_id, | ||||
|                     model.TagName.tag_id, | ||||
|                     model.TagName.name, | ||||
|                     search_util.create_str_filter) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['category'], | ||||
|                 search_util.create_subquery_filter( | ||||
|                     model.Tag.category_id, | ||||
|                     model.TagCategory.tag_category_id, | ||||
|                     model.TagCategory.name, | ||||
|                     search_util.create_str_filter) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['creation-date', 'creation-time'], | ||||
|                 search_util.create_date_filter(model.Tag.creation_time) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'], | ||||
|                 search_util.create_date_filter(model.Tag.last_edit_time) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['usage-count', 'post-count', 'usages'], | ||||
|                 search_util.create_num_filter(model.Tag.post_count) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['suggestion-count'], | ||||
|                 search_util.create_num_filter(model.Tag.suggestion_count) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['implication-count'], | ||||
|                 search_util.create_num_filter(model.Tag.implication_count) | ||||
|             ), | ||||
|         ]) | ||||
| 
 | ||||
|     @property | ||||
|     def sort_columns(self): | ||||
|         return util.unalias_dict({ | ||||
|             'random': (func.random(), None), | ||||
|             'name': (db.Tag.first_name, self.SORT_ASC), | ||||
|             'category': (db.TagCategory.name, self.SORT_ASC), | ||||
|             ('creation-date', 'creation-time'): | ||||
|                 (db.Tag.creation_time, self.SORT_DESC), | ||||
|             ('last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'): | ||||
|                 (db.Tag.last_edit_time, self.SORT_DESC), | ||||
|             ('usage-count', 'post-count', 'usages'): | ||||
|                 (db.Tag.post_count, self.SORT_DESC), | ||||
|             'suggestion-count': (db.Tag.suggestion_count, self.SORT_DESC), | ||||
|             'implication-count': (db.Tag.implication_count, self.SORT_DESC), | ||||
|         }) | ||||
|     def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]: | ||||
|         return util.unalias_dict([ | ||||
|             ( | ||||
|                 ['random'], | ||||
|                 (sa.sql.expression.func.random(), self.SORT_NONE) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['name'], | ||||
|                 (model.Tag.first_name, self.SORT_ASC) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['category'], | ||||
|                 (model.TagCategory.name, self.SORT_ASC) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['creation-date', 'creation-time'], | ||||
|                 (model.Tag.creation_time, self.SORT_DESC) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['last-edit-date', 'last-edit-time', 'edit-date', 'edit-time'], | ||||
|                 (model.Tag.last_edit_time, self.SORT_DESC) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['usage-count', 'post-count', 'usages'], | ||||
|                 (model.Tag.post_count, self.SORT_DESC) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['suggestion-count'], | ||||
|                 (model.Tag.suggestion_count, self.SORT_DESC) | ||||
|             ), | ||||
| 
 | ||||
|             ( | ||||
|                 ['implication-count'], | ||||
|                 (model.Tag.implication_count, self.SORT_DESC) | ||||
|             ), | ||||
|         ]) | ||||
|  | ||||
| @ -1,53 +1,57 @@ | ||||
| from sqlalchemy.sql.expression import func | ||||
| from szurubooru import db | ||||
| from typing import Tuple, Dict | ||||
| import sqlalchemy as sa | ||||
| from szurubooru import db, model | ||||
| from szurubooru.search.typing import SaColumn, SaQuery | ||||
| from szurubooru.search.configs import util as search_util | ||||
| from szurubooru.search.configs.base_search_config import BaseSearchConfig | ||||
| from szurubooru.search.configs.base_search_config import ( | ||||
|     BaseSearchConfig, Filter) | ||||
| 
 | ||||
| 
 | ||||
| class UserSearchConfig(BaseSearchConfig): | ||||
|     def create_filter_query(self, _disable_eager_loads): | ||||
|         return db.session.query(db.User) | ||||
|     def create_filter_query(self, _disable_eager_loads: bool) -> SaQuery: | ||||
|         return db.session.query(model.User) | ||||
| 
 | ||||
|     def create_count_query(self, _disable_eager_loads): | ||||
|         return db.session.query(db.User) | ||||
|     def create_count_query(self, _disable_eager_loads: bool) -> SaQuery: | ||||
|         return db.session.query(model.User) | ||||
| 
 | ||||
|     def create_around_query(self): | ||||
|     def create_around_query(self) -> SaQuery: | ||||
|         raise NotImplementedError() | ||||
| 
 | ||||
|     def finalize_query(self, query): | ||||
|         return query.order_by(db.User.name.asc()) | ||||
|     def finalize_query(self, query: SaQuery) -> SaQuery: | ||||
|         return query.order_by(model.User.name.asc()) | ||||
| 
 | ||||
|     @property | ||||
|     def anonymous_filter(self): | ||||
|         return search_util.create_str_filter(db.User.name) | ||||
|     def anonymous_filter(self) -> Filter: | ||||
|         return search_util.create_str_filter(model.User.name) | ||||
| 
 | ||||
|     @property | ||||
|     def named_filters(self): | ||||
|     def named_filters(self) -> Dict[str, Filter]: | ||||
|         return { | ||||
|             'name': search_util.create_str_filter(db.User.name), | ||||
|             'name': | ||||
|                 search_util.create_str_filter(model.User.name), | ||||
|             'creation-date': | ||||
|                 search_util.create_date_filter(db.User.creation_time), | ||||
|                 search_util.create_date_filter(model.User.creation_time), | ||||
|             'creation-time': | ||||
|                 search_util.create_date_filter(db.User.creation_time), | ||||
|                 search_util.create_date_filter(model.User.creation_time), | ||||
|             'last-login-date': | ||||
|                 search_util.create_date_filter(db.User.last_login_time), | ||||
|                 search_util.create_date_filter(model.User.last_login_time), | ||||
|             'last-login-time': | ||||
|                 search_util.create_date_filter(db.User.last_login_time), | ||||
|                 search_util.create_date_filter(model.User.last_login_time), | ||||
|             'login-date': | ||||
|                 search_util.create_date_filter(db.User.last_login_time), | ||||
|                 search_util.create_date_filter(model.User.last_login_time), | ||||
|             'login-time': | ||||
|                 search_util.create_date_filter(db.User.last_login_time), | ||||
|                 search_util.create_date_filter(model.User.last_login_time), | ||||
|         } | ||||
| 
 | ||||
|     @property | ||||
|     def sort_columns(self): | ||||
|     def sort_columns(self) -> Dict[str, Tuple[SaColumn, str]]: | ||||
|         return { | ||||
|             'random': (func.random(), None), | ||||
|             'name': (db.User.name, self.SORT_ASC), | ||||
|             'creation-date': (db.User.creation_time, self.SORT_DESC), | ||||
|             'creation-time': (db.User.creation_time, self.SORT_DESC), | ||||
|             'last-login-date': (db.User.last_login_time, self.SORT_DESC), | ||||
|             'last-login-time': (db.User.last_login_time, self.SORT_DESC), | ||||
|             'login-date': (db.User.last_login_time, self.SORT_DESC), | ||||
|             'login-time': (db.User.last_login_time, self.SORT_DESC), | ||||
|             'random': (sa.sql.expression.func.random(), self.SORT_NONE), | ||||
|             'name': (model.User.name, self.SORT_ASC), | ||||
|             'creation-date': (model.User.creation_time, self.SORT_DESC), | ||||
|             'creation-time': (model.User.creation_time, self.SORT_DESC), | ||||
|             'last-login-date': (model.User.last_login_time, self.SORT_DESC), | ||||
|             'last-login-time': (model.User.last_login_time, self.SORT_DESC), | ||||
|             'login-date': (model.User.last_login_time, self.SORT_DESC), | ||||
|             'login-time': (model.User.last_login_time, self.SORT_DESC), | ||||
|         } | ||||
|  | ||||
| @ -1,10 +1,13 @@ | ||||
| import sqlalchemy | ||||
| from typing import Any, Optional, Callable | ||||
| import sqlalchemy as sa | ||||
| from szurubooru import db, errors | ||||
| from szurubooru.func import util | ||||
| from szurubooru.search import criteria | ||||
| from szurubooru.search.typing import SaColumn, SaQuery | ||||
| from szurubooru.search.configs.base_search_config import Filter | ||||
| 
 | ||||
| 
 | ||||
| def wildcard_transformer(value): | ||||
| def wildcard_transformer(value: str) -> str: | ||||
|     return ( | ||||
|         value | ||||
|         .replace('\\', '\\\\') | ||||
| @ -13,24 +16,21 @@ def wildcard_transformer(value): | ||||
|         .replace('*', '%')) | ||||
| 
 | ||||
| 
 | ||||
| def apply_num_criterion_to_column(column, criterion): | ||||
|     ''' | ||||
|     Decorate SQLAlchemy filter on given column using supplied criterion. | ||||
|     ''' | ||||
| def apply_num_criterion_to_column( | ||||
|         column: Any, criterion: criteria.BaseCriterion) -> Any: | ||||
|     try: | ||||
|         if isinstance(criterion, criteria.PlainCriterion): | ||||
|             expr = column == int(criterion.value) | ||||
|         elif isinstance(criterion, criteria.ArrayCriterion): | ||||
|             expr = column.in_(int(value) for value in criterion.values) | ||||
|         elif isinstance(criterion, criteria.RangedCriterion): | ||||
|             assert criterion.min_value != '' \ | ||||
|                 or criterion.max_value != '' | ||||
|             if criterion.min_value != '' and criterion.max_value != '': | ||||
|             assert criterion.min_value or criterion.max_value | ||||
|             if criterion.min_value and criterion.max_value: | ||||
|                 expr = column.between( | ||||
|                     int(criterion.min_value), int(criterion.max_value)) | ||||
|             elif criterion.min_value != '': | ||||
|             elif criterion.min_value: | ||||
|                 expr = column >= int(criterion.min_value) | ||||
|             elif criterion.max_value != '': | ||||
|             elif criterion.max_value: | ||||
|                 expr = column <= int(criterion.max_value) | ||||
|         else: | ||||
|             assert False | ||||
| @ -40,10 +40,13 @@ def apply_num_criterion_to_column(column, criterion): | ||||
|     return expr | ||||
| 
 | ||||
| 
 | ||||
| def create_num_filter(column): | ||||
|     def wrapper(query, criterion, negated): | ||||
|         expr = apply_num_criterion_to_column( | ||||
|             column, criterion) | ||||
| def create_num_filter(column: Any) -> Filter: | ||||
|     def wrapper( | ||||
|             query: SaQuery, | ||||
|             criterion: Optional[criteria.BaseCriterion], | ||||
|             negated: bool) -> SaQuery: | ||||
|         assert criterion | ||||
|         expr = apply_num_criterion_to_column(column, criterion) | ||||
|         if negated: | ||||
|             expr = ~expr | ||||
|         return query.filter(expr) | ||||
| @ -51,14 +54,13 @@ def create_num_filter(column): | ||||
| 
 | ||||
| 
 | ||||
| def apply_str_criterion_to_column( | ||||
|         column, criterion, transformer=wildcard_transformer): | ||||
|     ''' | ||||
|     Decorate SQLAlchemy filter on given column using supplied criterion. | ||||
|     ''' | ||||
|         column: SaColumn, | ||||
|         criterion: criteria.BaseCriterion, | ||||
|         transformer: Callable[[str], str]=wildcard_transformer) -> SaQuery: | ||||
|     if isinstance(criterion, criteria.PlainCriterion): | ||||
|         expr = column.ilike(transformer(criterion.value)) | ||||
|     elif isinstance(criterion, criteria.ArrayCriterion): | ||||
|         expr = sqlalchemy.sql.false() | ||||
|         expr = sa.sql.false() | ||||
|         for value in criterion.values: | ||||
|             expr = expr | column.ilike(transformer(value)) | ||||
|     elif isinstance(criterion, criteria.RangedCriterion): | ||||
| @ -68,8 +70,15 @@ def apply_str_criterion_to_column( | ||||
|     return expr | ||||
| 
 | ||||
| 
 | ||||
| def create_str_filter(column, transformer=wildcard_transformer): | ||||
|     def wrapper(query, criterion, negated): | ||||
| def create_str_filter( | ||||
|     column: SaColumn, | ||||
|     transformer: Callable[[str], str]=wildcard_transformer | ||||
| ) -> Filter: | ||||
|     def wrapper( | ||||
|             query: SaQuery, | ||||
|             criterion: Optional[criteria.BaseCriterion], | ||||
|             negated: bool) -> SaQuery: | ||||
|         assert criterion | ||||
|         expr = apply_str_criterion_to_column( | ||||
|             column, criterion, transformer) | ||||
|         if negated: | ||||
| @ -78,16 +87,13 @@ def create_str_filter(column, transformer=wildcard_transformer): | ||||
|     return wrapper | ||||
| 
 | ||||
| 
 | ||||
| def apply_date_criterion_to_column(column, criterion): | ||||
|     ''' | ||||
|     Decorate SQLAlchemy filter on given column using supplied criterion. | ||||
|     Parse the datetime inside the criterion. | ||||
|     ''' | ||||
| def apply_date_criterion_to_column( | ||||
|         column: SaQuery, criterion: criteria.BaseCriterion) -> SaQuery: | ||||
|     if isinstance(criterion, criteria.PlainCriterion): | ||||
|         min_date, max_date = util.parse_time_range(criterion.value) | ||||
|         expr = column.between(min_date, max_date) | ||||
|     elif isinstance(criterion, criteria.ArrayCriterion): | ||||
|         expr = sqlalchemy.sql.false() | ||||
|         expr = sa.sql.false() | ||||
|         for value in criterion.values: | ||||
|             min_date, max_date = util.parse_time_range(value) | ||||
|             expr = expr | column.between(min_date, max_date) | ||||
| @ -108,10 +114,13 @@ def apply_date_criterion_to_column(column, criterion): | ||||
|     return expr | ||||
| 
 | ||||
| 
 | ||||
| def create_date_filter(column): | ||||
|     def wrapper(query, criterion, negated): | ||||
|         expr = apply_date_criterion_to_column( | ||||
|             column, criterion) | ||||
| def create_date_filter(column: SaColumn) -> Filter: | ||||
|     def wrapper( | ||||
|             query: SaQuery, | ||||
|             criterion: Optional[criteria.BaseCriterion], | ||||
|             negated: bool) -> SaQuery: | ||||
|         assert criterion | ||||
|         expr = apply_date_criterion_to_column(column, criterion) | ||||
|         if negated: | ||||
|             expr = ~expr | ||||
|         return query.filter(expr) | ||||
| @ -119,18 +128,22 @@ def create_date_filter(column): | ||||
| 
 | ||||
| 
 | ||||
| def create_subquery_filter( | ||||
|         left_id_column, | ||||
|         right_id_column, | ||||
|         filter_column, | ||||
|         filter_factory, | ||||
|         subquery_decorator=None): | ||||
|         left_id_column: SaColumn, | ||||
|         right_id_column: SaColumn, | ||||
|         filter_column: SaColumn, | ||||
|         filter_factory: SaColumn, | ||||
|         subquery_decorator: Callable[[SaQuery], None]=None) -> Filter: | ||||
|     filter_func = filter_factory(filter_column) | ||||
| 
 | ||||
|     def wrapper(query, criterion, negated): | ||||
|     def wrapper( | ||||
|             query: SaQuery, | ||||
|             criterion: Optional[criteria.BaseCriterion], | ||||
|             negated: bool) -> SaQuery: | ||||
|         assert criterion | ||||
|         subquery = db.session.query(right_id_column.label('foreign_id')) | ||||
|         if subquery_decorator: | ||||
|             subquery = subquery_decorator(subquery) | ||||
|         subquery = subquery.options(sqlalchemy.orm.lazyload('*')) | ||||
|         subquery = subquery.options(sa.orm.lazyload('*')) | ||||
|         subquery = filter_func(subquery, criterion, False) | ||||
|         subquery = subquery.subquery('t') | ||||
|         expression = left_id_column.in_(subquery) | ||||
|  | ||||
| @ -1,34 +1,42 @@ | ||||
| class _BaseCriterion: | ||||
|     def __init__(self, original_text): | ||||
| from typing import Optional, List, Callable | ||||
| from szurubooru.search.typing import SaQuery | ||||
| 
 | ||||
| 
 | ||||
| class BaseCriterion: | ||||
|     def __init__(self, original_text: str) -> None: | ||||
|         self.original_text = original_text | ||||
| 
 | ||||
|     def __repr__(self): | ||||
|     def __repr__(self) -> str: | ||||
|         return self.original_text | ||||
| 
 | ||||
| 
 | ||||
| class RangedCriterion(_BaseCriterion): | ||||
|     def __init__(self, original_text, min_value, max_value): | ||||
| class RangedCriterion(BaseCriterion): | ||||
|     def __init__( | ||||
|             self, | ||||
|             original_text: str, | ||||
|             min_value: Optional[str], | ||||
|             max_value: Optional[str]) -> None: | ||||
|         super().__init__(original_text) | ||||
|         self.min_value = min_value | ||||
|         self.max_value = max_value | ||||
| 
 | ||||
|     def __hash__(self): | ||||
|     def __hash__(self) -> int: | ||||
|         return hash(('range', self.min_value, self.max_value)) | ||||
| 
 | ||||
| 
 | ||||
| class PlainCriterion(_BaseCriterion): | ||||
|     def __init__(self, original_text, value): | ||||
| class PlainCriterion(BaseCriterion): | ||||
|     def __init__(self, original_text: str, value: str) -> None: | ||||
|         super().__init__(original_text) | ||||
|         self.value = value | ||||
| 
 | ||||
|     def __hash__(self): | ||||
|     def __hash__(self) -> int: | ||||
|         return hash(self.value) | ||||
| 
 | ||||
| 
 | ||||
| class ArrayCriterion(_BaseCriterion): | ||||
|     def __init__(self, original_text, values): | ||||
| class ArrayCriterion(BaseCriterion): | ||||
|     def __init__(self, original_text: str, values: List[str]) -> None: | ||||
|         super().__init__(original_text) | ||||
|         self.values = values | ||||
| 
 | ||||
|     def __hash__(self): | ||||
|     def __hash__(self) -> int: | ||||
|         return hash(tuple(['array'] + self.values)) | ||||
|  | ||||
| @ -1,14 +1,18 @@ | ||||
| import sqlalchemy | ||||
| from szurubooru import db, errors | ||||
| from typing import Union, Tuple, List, Dict, Callable | ||||
| import sqlalchemy as sa | ||||
| from szurubooru import db, model, errors, rest | ||||
| from szurubooru.func import cache | ||||
| from szurubooru.search import tokens, parser | ||||
| from szurubooru.search.typing import SaQuery | ||||
| from szurubooru.search.query import SearchQuery | ||||
| from szurubooru.search.configs.base_search_config import BaseSearchConfig | ||||
| 
 | ||||
| 
 | ||||
| def _format_dict_keys(source): | ||||
| def _format_dict_keys(source: Dict) -> List[str]: | ||||
|     return list(sorted(source.keys())) | ||||
| 
 | ||||
| 
 | ||||
| def _get_order(order, default_order): | ||||
| def _get_order(order: str, default_order: str) -> Union[bool, str]: | ||||
|     if order == tokens.SortToken.SORT_DEFAULT: | ||||
|         return default_order or tokens.SortToken.SORT_ASC | ||||
|     if order == tokens.SortToken.SORT_NEGATED_DEFAULT: | ||||
| @ -26,50 +30,57 @@ class Executor: | ||||
|     delegates sqlalchemy filter decoration to SearchConfig instances. | ||||
|     ''' | ||||
| 
 | ||||
|     def __init__(self, search_config): | ||||
|     def __init__(self, search_config: BaseSearchConfig) -> None: | ||||
|         self.config = search_config | ||||
|         self.parser = parser.Parser() | ||||
| 
 | ||||
|     def get_around(self, query_text, entity_id): | ||||
|     def get_around( | ||||
|             self, | ||||
|             query_text: str, | ||||
|             entity_id: int) -> Tuple[model.Base, model.Base]: | ||||
|         search_query = self.parser.parse(query_text) | ||||
|         self.config.on_search_query_parsed(search_query) | ||||
|         filter_query = ( | ||||
|             self.config | ||||
|                 .create_around_query() | ||||
|                 .options(sqlalchemy.orm.lazyload('*'))) | ||||
|                 .options(sa.orm.lazyload('*'))) | ||||
|         filter_query = self._prepare_db_query( | ||||
|             filter_query, search_query, False) | ||||
|         prev_filter_query = ( | ||||
|             filter_query | ||||
|             .filter(self.config.id_column > entity_id) | ||||
|             .order_by(None) | ||||
|             .order_by(sqlalchemy.func.abs( | ||||
|                 self.config.id_column - entity_id).asc()) | ||||
|             .order_by(sa.func.abs(self.config.id_column - entity_id).asc()) | ||||
|             .limit(1)) | ||||
|         next_filter_query = ( | ||||
|             filter_query | ||||
|             .filter(self.config.id_column < entity_id) | ||||
|             .order_by(None) | ||||
|             .order_by(sqlalchemy.func.abs( | ||||
|                 self.config.id_column - entity_id).asc()) | ||||
|             .order_by(sa.func.abs(self.config.id_column - entity_id).asc()) | ||||
|             .limit(1)) | ||||
|         return [ | ||||
|         return ( | ||||
|             prev_filter_query.one_or_none(), | ||||
|             next_filter_query.one_or_none()] | ||||
|             next_filter_query.one_or_none()) | ||||
| 
 | ||||
|     def get_around_and_serialize(self, ctx, entity_id, serializer): | ||||
|         entities = self.get_around(ctx.get_param_as_string('query'), entity_id) | ||||
|     def get_around_and_serialize( | ||||
|         self, | ||||
|         ctx: rest.Context, | ||||
|         entity_id: int, | ||||
|         serializer: Callable[[model.Base], rest.Response] | ||||
|     ) -> rest.Response: | ||||
|         entities = self.get_around( | ||||
|             ctx.get_param_as_string('query', default=''), entity_id) | ||||
|         return { | ||||
|             'prev': serializer(entities[0]), | ||||
|             'next': serializer(entities[1]), | ||||
|         } | ||||
| 
 | ||||
|     def execute(self, query_text, page, page_size): | ||||
|         ''' | ||||
|         Parse input and return tuple containing total record count and filtered | ||||
|         entities. | ||||
|         ''' | ||||
| 
 | ||||
|     def execute( | ||||
|         self, | ||||
|         query_text: str, | ||||
|         page: int, | ||||
|         page_size: int | ||||
|     ) -> Tuple[int, List[model.Base]]: | ||||
|         search_query = self.parser.parse(query_text) | ||||
|         self.config.on_search_query_parsed(search_query) | ||||
| 
 | ||||
| @ -83,7 +94,7 @@ class Executor: | ||||
|             return cache.get(key) | ||||
| 
 | ||||
|         filter_query = self.config.create_filter_query(disable_eager_loads) | ||||
|         filter_query = filter_query.options(sqlalchemy.orm.lazyload('*')) | ||||
|         filter_query = filter_query.options(sa.orm.lazyload('*')) | ||||
|         filter_query = self._prepare_db_query(filter_query, search_query, True) | ||||
|         entities = filter_query \ | ||||
|             .offset(max(page - 1, 0) * page_size) \ | ||||
| @ -91,11 +102,11 @@ class Executor: | ||||
|             .all() | ||||
| 
 | ||||
|         count_query = self.config.create_count_query(disable_eager_loads) | ||||
|         count_query = count_query.options(sqlalchemy.orm.lazyload('*')) | ||||
|         count_query = count_query.options(sa.orm.lazyload('*')) | ||||
|         count_query = self._prepare_db_query(count_query, search_query, False) | ||||
|         count_statement = count_query \ | ||||
|             .statement \ | ||||
|             .with_only_columns([sqlalchemy.func.count()]) \ | ||||
|             .with_only_columns([sa.func.count()]) \ | ||||
|             .order_by(None) | ||||
|         count = db.session.execute(count_statement).scalar() | ||||
| 
 | ||||
| @ -103,8 +114,12 @@ class Executor: | ||||
|         cache.put(key, ret) | ||||
|         return ret | ||||
| 
 | ||||
|     def execute_and_serialize(self, ctx, serializer): | ||||
|         query = ctx.get_param_as_string('query') | ||||
|     def execute_and_serialize( | ||||
|         self, | ||||
|         ctx: rest.Context, | ||||
|         serializer: Callable[[model.Base], rest.Response] | ||||
|     ) -> rest.Response: | ||||
|         query = ctx.get_param_as_string('query', default='') | ||||
|         page = ctx.get_param_as_int('page', default=1, min=1) | ||||
|         page_size = ctx.get_param_as_int( | ||||
|             'pageSize', default=100, min=1, max=100) | ||||
| @ -117,48 +132,51 @@ class Executor: | ||||
|             'results': [serializer(entity) for entity in entities], | ||||
|         } | ||||
| 
 | ||||
|     def _prepare_db_query(self, db_query, search_query, use_sort): | ||||
|         ''' Parse input and return SQLAlchemy query. ''' | ||||
| 
 | ||||
|         for token in search_query.anonymous_tokens: | ||||
|     def _prepare_db_query( | ||||
|             self, | ||||
|             db_query: SaQuery, | ||||
|             search_query: SearchQuery, | ||||
|             use_sort: bool) -> SaQuery: | ||||
|         for anon_token in search_query.anonymous_tokens: | ||||
|             if not self.config.anonymous_filter: | ||||
|                 raise errors.SearchError( | ||||
|                     'Anonymous tokens are not valid in this context.') | ||||
|             db_query = self.config.anonymous_filter( | ||||
|                 db_query, token.criterion, token.negated) | ||||
|                 db_query, anon_token.criterion, anon_token.negated) | ||||
| 
 | ||||
|         for token in search_query.named_tokens: | ||||
|             if token.name not in self.config.named_filters: | ||||
|         for named_token in search_query.named_tokens: | ||||
|             if named_token.name not in self.config.named_filters: | ||||
|                 raise errors.SearchError( | ||||
|                     'Unknown named token: %r. Available named tokens: %r.' % ( | ||||
|                         token.name, | ||||
|                         named_token.name, | ||||
|                         _format_dict_keys(self.config.named_filters))) | ||||
|             db_query = self.config.named_filters[token.name]( | ||||
|                 db_query, token.criterion, token.negated) | ||||
|             db_query = self.config.named_filters[named_token.name]( | ||||
|                 db_query, named_token.criterion, named_token.negated) | ||||
| 
 | ||||
|         for token in search_query.special_tokens: | ||||
|             if token.value not in self.config.special_filters: | ||||
|         for sp_token in search_query.special_tokens: | ||||
|             if sp_token.value not in self.config.special_filters: | ||||
|                 raise errors.SearchError( | ||||
|                     'Unknown special token: %r. ' | ||||
|                     'Available special tokens: %r.' % ( | ||||
|                         token.value, | ||||
|                         sp_token.value, | ||||
|                         _format_dict_keys(self.config.special_filters))) | ||||
|             db_query = self.config.special_filters[token.value]( | ||||
|                 db_query, token.negated) | ||||
|             db_query = self.config.special_filters[sp_token.value]( | ||||
|                 db_query, None, sp_token.negated) | ||||
| 
 | ||||
|         if use_sort: | ||||
|             for token in search_query.sort_tokens: | ||||
|                 if token.name not in self.config.sort_columns: | ||||
|             for sort_token in search_query.sort_tokens: | ||||
|                 if sort_token.name not in self.config.sort_columns: | ||||
|                     raise errors.SearchError( | ||||
|                         'Unknown sort token: %r. ' | ||||
|                         'Available sort tokens: %r.' % ( | ||||
|                             token.name, | ||||
|                             sort_token.name, | ||||
|                             _format_dict_keys(self.config.sort_columns))) | ||||
|                 column, default_order = self.config.sort_columns[token.name] | ||||
|                 order = _get_order(token.order, default_order) | ||||
|                 if order == token.SORT_ASC: | ||||
|                 column, default_order = ( | ||||
|                     self.config.sort_columns[sort_token.name]) | ||||
|                 order = _get_order(sort_token.order, default_order) | ||||
|                 if order == sort_token.SORT_ASC: | ||||
|                     db_query = db_query.order_by(column.asc()) | ||||
|                 elif order == token.SORT_DESC: | ||||
|                 elif order == sort_token.SORT_DESC: | ||||
|                     db_query = db_query.order_by(column.desc()) | ||||
| 
 | ||||
|         db_query = self.config.finalize_query(db_query) | ||||
|  | ||||
| @ -1,9 +1,12 @@ | ||||
| import re | ||||
| from typing import List | ||||
| from szurubooru import errors | ||||
| from szurubooru.search import criteria, tokens | ||||
| from szurubooru.search.query import SearchQuery | ||||
| 
 | ||||
| 
 | ||||
| def _create_criterion(original_value, value): | ||||
| def _create_criterion( | ||||
|         original_value: str, value: str) -> criteria.BaseCriterion: | ||||
|     if ',' in value: | ||||
|         return criteria.ArrayCriterion( | ||||
|             original_value, value.split(',')) | ||||
| @ -15,12 +18,12 @@ def _create_criterion(original_value, value): | ||||
|     return criteria.PlainCriterion(original_value, value) | ||||
| 
 | ||||
| 
 | ||||
| def _parse_anonymous(value, negated): | ||||
| def _parse_anonymous(value: str, negated: bool) -> tokens.AnonymousToken: | ||||
|     criterion = _create_criterion(value, value) | ||||
|     return tokens.AnonymousToken(criterion, negated) | ||||
| 
 | ||||
| 
 | ||||
| def _parse_named(key, value, negated): | ||||
| def _parse_named(key: str, value: str, negated: bool) -> tokens.NamedToken: | ||||
|     original_value = value | ||||
|     if key.endswith('-min'): | ||||
|         key = key[:-4] | ||||
| @ -32,11 +35,11 @@ def _parse_named(key, value, negated): | ||||
|     return tokens.NamedToken(key, criterion, negated) | ||||
| 
 | ||||
| 
 | ||||
| def _parse_special(value, negated): | ||||
| def _parse_special(value: str, negated: bool) -> tokens.SpecialToken: | ||||
|     return tokens.SpecialToken(value, negated) | ||||
| 
 | ||||
| 
 | ||||
| def _parse_sort(value, negated): | ||||
| def _parse_sort(value: str, negated: bool) -> tokens.SortToken: | ||||
|     if value.count(',') == 0: | ||||
|         order_str = None | ||||
|     elif value.count(',') == 1: | ||||
| @ -67,23 +70,8 @@ def _parse_sort(value, negated): | ||||
|     return tokens.SortToken(value, order) | ||||
| 
 | ||||
| 
 | ||||
| class SearchQuery: | ||||
|     def __init__(self): | ||||
|         self.anonymous_tokens = [] | ||||
|         self.named_tokens = [] | ||||
|         self.special_tokens = [] | ||||
|         self.sort_tokens = [] | ||||
| 
 | ||||
|     def __hash__(self): | ||||
|         return hash(( | ||||
|             tuple(self.anonymous_tokens), | ||||
|             tuple(self.named_tokens), | ||||
|             tuple(self.special_tokens), | ||||
|             tuple(self.sort_tokens))) | ||||
| 
 | ||||
| 
 | ||||
| class Parser: | ||||
|     def parse(self, query_text): | ||||
|     def parse(self, query_text: str) -> SearchQuery: | ||||
|         query = SearchQuery() | ||||
|         for chunk in re.split(r'\s+', (query_text or '').lower()): | ||||
|             if not chunk: | ||||
|  | ||||
							
								
								
									
										16
									
								
								server/szurubooru/search/query.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										16
									
								
								server/szurubooru/search/query.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,16 @@ | ||||
| from szurubooru.search import tokens | ||||
| 
 | ||||
| 
 | ||||
| class SearchQuery: | ||||
|     def __init__(self) -> None: | ||||
|         self.anonymous_tokens = []  # type: List[tokens.AnonymousToken] | ||||
|         self.named_tokens = []  # type: List[tokens.NamedToken] | ||||
|         self.special_tokens = []  # type: List[tokens.SpecialToken] | ||||
|         self.sort_tokens = []  # type: List[tokens.SortToken] | ||||
| 
 | ||||
|     def __hash__(self) -> int: | ||||
|         return hash(( | ||||
|             tuple(self.anonymous_tokens), | ||||
|             tuple(self.named_tokens), | ||||
|             tuple(self.special_tokens), | ||||
|             tuple(self.sort_tokens))) | ||||
| @ -1,39 +1,44 @@ | ||||
| from szurubooru.search.criteria import BaseCriterion | ||||
| 
 | ||||
| 
 | ||||
| class AnonymousToken: | ||||
|     def __init__(self, criterion, negated): | ||||
|     def __init__(self, criterion: BaseCriterion, negated: bool) -> None: | ||||
|         self.criterion = criterion | ||||
|         self.negated = negated | ||||
| 
 | ||||
|     def __hash__(self): | ||||
|     def __hash__(self) -> int: | ||||
|         return hash((self.criterion, self.negated)) | ||||
| 
 | ||||
| 
 | ||||
| class NamedToken(AnonymousToken): | ||||
|     def __init__(self, name, criterion, negated): | ||||
|     def __init__( | ||||
|             self, name: str, criterion: BaseCriterion, negated: bool) -> None: | ||||
|         super().__init__(criterion, negated) | ||||
|         self.name = name | ||||
| 
 | ||||
|     def __hash__(self): | ||||
|     def __hash__(self) -> int: | ||||
|         return hash((self.name, self.criterion, self.negated)) | ||||
| 
 | ||||
| 
 | ||||
| class SortToken: | ||||
|     SORT_DESC = 'desc' | ||||
|     SORT_ASC = 'asc' | ||||
|     SORT_NONE = '' | ||||
|     SORT_DEFAULT = 'default' | ||||
|     SORT_NEGATED_DEFAULT = 'negated default' | ||||
| 
 | ||||
|     def __init__(self, name, order): | ||||
|     def __init__(self, name: str, order: str) -> None: | ||||
|         self.name = name | ||||
|         self.order = order | ||||
| 
 | ||||
|     def __hash__(self): | ||||
|     def __hash__(self) -> int: | ||||
|         return hash((self.name, self.order)) | ||||
| 
 | ||||
| 
 | ||||
| class SpecialToken: | ||||
|     def __init__(self, value, negated): | ||||
|     def __init__(self, value: str, negated: bool) -> None: | ||||
|         self.value = value | ||||
|         self.negated = negated | ||||
| 
 | ||||
|     def __hash__(self): | ||||
|     def __hash__(self) -> int: | ||||
|         return hash((self.value, self.negated)) | ||||
|  | ||||
							
								
								
									
										6
									
								
								server/szurubooru/search/typing.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										6
									
								
								server/szurubooru/search/typing.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,6 @@ | ||||
| from typing import Any, Callable | ||||
| 
 | ||||
| 
 | ||||
| SaColumn = Any | ||||
| SaQuery = Any | ||||
| SaQueryFactory = Callable[[], SaQuery] | ||||
| @ -1,19 +1,20 @@ | ||||
| from datetime import datetime | ||||
| from unittest.mock import patch | ||||
| import pytest | ||||
| from szurubooru import api, db, errors | ||||
| from szurubooru import api, db, model, errors | ||||
| from szurubooru.func import comments, posts | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture(autouse=True) | ||||
| def inject_config(config_injector): | ||||
|     config_injector({'privileges': {'comments:create': db.User.RANK_REGULAR}}) | ||||
|     config_injector( | ||||
|         {'privileges': {'comments:create': model.User.RANK_REGULAR}}) | ||||
| 
 | ||||
| 
 | ||||
| def test_creating_comment( | ||||
|         user_factory, post_factory, context_factory, fake_datetime): | ||||
|     post = post_factory() | ||||
|     user = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     user = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     db.session.add_all([post, user]) | ||||
|     db.session.flush() | ||||
|     with patch('szurubooru.func.comments.serialize_comment'), \ | ||||
| @ -24,7 +25,7 @@ def test_creating_comment( | ||||
|                 params={'text': 'input', 'postId': post.post_id}, | ||||
|                 user=user)) | ||||
|         assert result == 'serialized comment' | ||||
|         comment = db.session.query(db.Comment).one() | ||||
|         comment = db.session.query(model.Comment).one() | ||||
|         assert comment.text == 'input' | ||||
|         assert comment.creation_time == datetime(1997, 1, 1) | ||||
|         assert comment.last_edit_time is None | ||||
| @ -41,7 +42,7 @@ def test_creating_comment( | ||||
| def test_trying_to_pass_invalid_params( | ||||
|         user_factory, post_factory, context_factory, params): | ||||
|     post = post_factory() | ||||
|     user = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     user = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     db.session.add_all([post, user]) | ||||
|     db.session.flush() | ||||
|     real_params = {'text': 'input', 'postId': post.post_id} | ||||
| @ -63,11 +64,11 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field): | ||||
|         api.comment_api.create_comment( | ||||
|             context_factory( | ||||
|                 params={}, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR))) | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR))) | ||||
| 
 | ||||
| 
 | ||||
| def test_trying_to_comment_non_existing(user_factory, context_factory): | ||||
|     user = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     user = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     db.session.add_all([user]) | ||||
|     db.session.flush() | ||||
|     with pytest.raises(posts.PostNotFoundError): | ||||
| @ -81,4 +82,4 @@ def test_trying_to_create_without_privileges(user_factory, context_factory): | ||||
|         api.comment_api.create_comment( | ||||
|             context_factory( | ||||
|                 params={}, | ||||
|                 user=user_factory(rank=db.User.RANK_ANONYMOUS))) | ||||
|                 user=user_factory(rank=model.User.RANK_ANONYMOUS))) | ||||
|  | ||||
| @ -1,5 +1,5 @@ | ||||
| import pytest | ||||
| from szurubooru import api, db, errors | ||||
| from szurubooru import api, db, model, errors | ||||
| from szurubooru.func import comments | ||||
| 
 | ||||
| 
 | ||||
| @ -7,8 +7,8 @@ from szurubooru.func import comments | ||||
| def inject_config(config_injector): | ||||
|     config_injector({ | ||||
|         'privileges': { | ||||
|             'comments:delete:own': db.User.RANK_REGULAR, | ||||
|             'comments:delete:any': db.User.RANK_MODERATOR, | ||||
|             'comments:delete:own': model.User.RANK_REGULAR, | ||||
|             'comments:delete:any': model.User.RANK_MODERATOR, | ||||
|         }, | ||||
|     }) | ||||
| 
 | ||||
| @ -22,26 +22,26 @@ def test_deleting_own_comment(user_factory, comment_factory, context_factory): | ||||
|         context_factory(params={'version': 1}, user=user), | ||||
|         {'comment_id': comment.comment_id}) | ||||
|     assert result == {} | ||||
|     assert db.session.query(db.Comment).count() == 0 | ||||
|     assert db.session.query(model.Comment).count() == 0 | ||||
| 
 | ||||
| 
 | ||||
| def test_deleting_someones_else_comment( | ||||
|         user_factory, comment_factory, context_factory): | ||||
|     user1 = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     user2 = user_factory(rank=db.User.RANK_MODERATOR) | ||||
|     user1 = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     user2 = user_factory(rank=model.User.RANK_MODERATOR) | ||||
|     comment = comment_factory(user=user1) | ||||
|     db.session.add(comment) | ||||
|     db.session.commit() | ||||
|     api.comment_api.delete_comment( | ||||
|         context_factory(params={'version': 1}, user=user2), | ||||
|         {'comment_id': comment.comment_id}) | ||||
|     assert db.session.query(db.Comment).count() == 0 | ||||
|     assert db.session.query(model.Comment).count() == 0 | ||||
| 
 | ||||
| 
 | ||||
| def test_trying_to_delete_someones_else_comment_without_privileges( | ||||
|         user_factory, comment_factory, context_factory): | ||||
|     user1 = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     user2 = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     user1 = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     user2 = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     comment = comment_factory(user=user1) | ||||
|     db.session.add(comment) | ||||
|     db.session.commit() | ||||
| @ -49,7 +49,7 @@ def test_trying_to_delete_someones_else_comment_without_privileges( | ||||
|         api.comment_api.delete_comment( | ||||
|             context_factory(params={'version': 1}, user=user2), | ||||
|             {'comment_id': comment.comment_id}) | ||||
|     assert db.session.query(db.Comment).count() == 1 | ||||
|     assert db.session.query(model.Comment).count() == 1 | ||||
| 
 | ||||
| 
 | ||||
| def test_trying_to_delete_non_existing(user_factory, context_factory): | ||||
| @ -57,5 +57,5 @@ def test_trying_to_delete_non_existing(user_factory, context_factory): | ||||
|         api.comment_api.delete_comment( | ||||
|             context_factory( | ||||
|                 params={'version': 1}, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR)), | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR)), | ||||
|             {'comment_id': 1}) | ||||
|  | ||||
| @ -1,17 +1,18 @@ | ||||
| from unittest.mock import patch | ||||
| import pytest | ||||
| from szurubooru import api, db, errors | ||||
| from szurubooru import api, db, model, errors | ||||
| from szurubooru.func import comments | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture(autouse=True) | ||||
| def inject_config(config_injector): | ||||
|     config_injector({'privileges': {'comments:score': db.User.RANK_REGULAR}}) | ||||
|     config_injector( | ||||
|         {'privileges': {'comments:score': model.User.RANK_REGULAR}}) | ||||
| 
 | ||||
| 
 | ||||
| def test_simple_rating( | ||||
|         user_factory, comment_factory, context_factory, fake_datetime): | ||||
|     user = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     user = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     comment = comment_factory(user=user) | ||||
|     db.session.add(comment) | ||||
|     db.session.commit() | ||||
| @ -22,14 +23,14 @@ def test_simple_rating( | ||||
|             context_factory(params={'score': 1}, user=user), | ||||
|             {'comment_id': comment.comment_id}) | ||||
|         assert result == 'serialized comment' | ||||
|         assert db.session.query(db.CommentScore).count() == 1 | ||||
|         assert db.session.query(model.CommentScore).count() == 1 | ||||
|         assert comment is not None | ||||
|         assert comment.score == 1 | ||||
| 
 | ||||
| 
 | ||||
| def test_updating_rating( | ||||
|         user_factory, comment_factory, context_factory, fake_datetime): | ||||
|     user = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     user = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     comment = comment_factory(user=user) | ||||
|     db.session.add(comment) | ||||
|     db.session.commit() | ||||
| @ -42,14 +43,14 @@ def test_updating_rating( | ||||
|             api.comment_api.set_comment_score( | ||||
|                 context_factory(params={'score': -1}, user=user), | ||||
|                 {'comment_id': comment.comment_id}) | ||||
|         comment = db.session.query(db.Comment).one() | ||||
|         assert db.session.query(db.CommentScore).count() == 1 | ||||
|         comment = db.session.query(model.Comment).one() | ||||
|         assert db.session.query(model.CommentScore).count() == 1 | ||||
|         assert comment.score == -1 | ||||
| 
 | ||||
| 
 | ||||
| def test_updating_rating_to_zero( | ||||
|         user_factory, comment_factory, context_factory, fake_datetime): | ||||
|     user = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     user = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     comment = comment_factory(user=user) | ||||
|     db.session.add(comment) | ||||
|     db.session.commit() | ||||
| @ -62,14 +63,14 @@ def test_updating_rating_to_zero( | ||||
|             api.comment_api.set_comment_score( | ||||
|                 context_factory(params={'score': 0}, user=user), | ||||
|                 {'comment_id': comment.comment_id}) | ||||
|         comment = db.session.query(db.Comment).one() | ||||
|         assert db.session.query(db.CommentScore).count() == 0 | ||||
|         comment = db.session.query(model.Comment).one() | ||||
|         assert db.session.query(model.CommentScore).count() == 0 | ||||
|         assert comment.score == 0 | ||||
| 
 | ||||
| 
 | ||||
| def test_deleting_rating( | ||||
|         user_factory, comment_factory, context_factory, fake_datetime): | ||||
|     user = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     user = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     comment = comment_factory(user=user) | ||||
|     db.session.add(comment) | ||||
|     db.session.commit() | ||||
| @ -82,15 +83,15 @@ def test_deleting_rating( | ||||
|             api.comment_api.delete_comment_score( | ||||
|                 context_factory(user=user), | ||||
|                 {'comment_id': comment.comment_id}) | ||||
|         comment = db.session.query(db.Comment).one() | ||||
|         assert db.session.query(db.CommentScore).count() == 0 | ||||
|         comment = db.session.query(model.Comment).one() | ||||
|         assert db.session.query(model.CommentScore).count() == 0 | ||||
|         assert comment.score == 0 | ||||
| 
 | ||||
| 
 | ||||
| def test_ratings_from_multiple_users( | ||||
|         user_factory, comment_factory, context_factory, fake_datetime): | ||||
|     user1 = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     user2 = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     user1 = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     user2 = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     comment = comment_factory() | ||||
|     db.session.add_all([user1, user2, comment]) | ||||
|     db.session.commit() | ||||
| @ -103,8 +104,8 @@ def test_ratings_from_multiple_users( | ||||
|             api.comment_api.set_comment_score( | ||||
|                 context_factory(params={'score': -1}, user=user2), | ||||
|                 {'comment_id': comment.comment_id}) | ||||
|         comment = db.session.query(db.Comment).one() | ||||
|         assert db.session.query(db.CommentScore).count() == 2 | ||||
|         comment = db.session.query(model.Comment).one() | ||||
|         assert db.session.query(model.CommentScore).count() == 2 | ||||
|         assert comment.score == 0 | ||||
| 
 | ||||
| 
 | ||||
| @ -125,7 +126,7 @@ def test_trying_to_update_non_existing(user_factory, context_factory): | ||||
|         api.comment_api.set_comment_score( | ||||
|             context_factory( | ||||
|                 params={'score': 1}, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR)), | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR)), | ||||
|             {'comment_id': 5}) | ||||
| 
 | ||||
| 
 | ||||
| @ -138,5 +139,5 @@ def test_trying_to_rate_without_privileges( | ||||
|         api.comment_api.set_comment_score( | ||||
|             context_factory( | ||||
|                 params={'score': 1}, | ||||
|                 user=user_factory(rank=db.User.RANK_ANONYMOUS)), | ||||
|                 user=user_factory(rank=model.User.RANK_ANONYMOUS)), | ||||
|             {'comment_id': comment.comment_id}) | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| from unittest.mock import patch | ||||
| import pytest | ||||
| from szurubooru import api, db, errors | ||||
| from szurubooru import api, db, model, errors | ||||
| from szurubooru.func import comments | ||||
| 
 | ||||
| 
 | ||||
| @ -8,8 +8,8 @@ from szurubooru.func import comments | ||||
| def inject_config(config_injector): | ||||
|     config_injector({ | ||||
|         'privileges': { | ||||
|             'comments:list': db.User.RANK_REGULAR, | ||||
|             'comments:view': db.User.RANK_REGULAR, | ||||
|             'comments:list': model.User.RANK_REGULAR, | ||||
|             'comments:view': model.User.RANK_REGULAR, | ||||
|         }, | ||||
|     }) | ||||
| 
 | ||||
| @ -24,7 +24,7 @@ def test_retrieving_multiple(user_factory, comment_factory, context_factory): | ||||
|         result = api.comment_api.get_comments( | ||||
|             context_factory( | ||||
|                 params={'query': '', 'page': 1}, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR))) | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR))) | ||||
|         assert result == { | ||||
|             'query': '', | ||||
|             'page': 1, | ||||
| @ -40,7 +40,7 @@ def test_trying_to_retrieve_multiple_without_privileges( | ||||
|         api.comment_api.get_comments( | ||||
|             context_factory( | ||||
|                 params={'query': '', 'page': 1}, | ||||
|                 user=user_factory(rank=db.User.RANK_ANONYMOUS))) | ||||
|                 user=user_factory(rank=model.User.RANK_ANONYMOUS))) | ||||
| 
 | ||||
| 
 | ||||
| def test_retrieving_single(user_factory, comment_factory, context_factory): | ||||
| @ -51,7 +51,7 @@ def test_retrieving_single(user_factory, comment_factory, context_factory): | ||||
|         comments.serialize_comment.return_value = 'serialized comment' | ||||
|         result = api.comment_api.get_comment( | ||||
|             context_factory( | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR)), | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR)), | ||||
|             {'comment_id': comment.comment_id}) | ||||
|         assert result == 'serialized comment' | ||||
| 
 | ||||
| @ -60,7 +60,7 @@ def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): | ||||
|     with pytest.raises(comments.CommentNotFoundError): | ||||
|         api.comment_api.get_comment( | ||||
|             context_factory( | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR)), | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR)), | ||||
|             {'comment_id': 5}) | ||||
| 
 | ||||
| 
 | ||||
| @ -68,5 +68,5 @@ def test_trying_to_retrieve_single_without_privileges( | ||||
|         user_factory, context_factory): | ||||
|     with pytest.raises(errors.AuthError): | ||||
|         api.comment_api.get_comment( | ||||
|             context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)), | ||||
|             context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)), | ||||
|             {'comment_id': 5}) | ||||
|  | ||||
| @ -1,7 +1,7 @@ | ||||
| from datetime import datetime | ||||
| from unittest.mock import patch | ||||
| import pytest | ||||
| from szurubooru import api, db, errors | ||||
| from szurubooru import api, db, model, errors | ||||
| from szurubooru.func import comments | ||||
| 
 | ||||
| 
 | ||||
| @ -9,15 +9,15 @@ from szurubooru.func import comments | ||||
| def inject_config(config_injector): | ||||
|     config_injector({ | ||||
|         'privileges': { | ||||
|             'comments:edit:own': db.User.RANK_REGULAR, | ||||
|             'comments:edit:any': db.User.RANK_MODERATOR, | ||||
|             'comments:edit:own': model.User.RANK_REGULAR, | ||||
|             'comments:edit:any': model.User.RANK_MODERATOR, | ||||
|         }, | ||||
|     }) | ||||
| 
 | ||||
| 
 | ||||
| def test_simple_updating( | ||||
|         user_factory, comment_factory, context_factory, fake_datetime): | ||||
|     user = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     user = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     comment = comment_factory(user=user) | ||||
|     db.session.add(comment) | ||||
|     db.session.commit() | ||||
| @ -73,14 +73,14 @@ def test_trying_to_update_non_existing(user_factory, context_factory): | ||||
|         api.comment_api.update_comment( | ||||
|             context_factory( | ||||
|                 params={'text': 'new text'}, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR)), | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR)), | ||||
|             {'comment_id': 5}) | ||||
| 
 | ||||
| 
 | ||||
| def test_trying_to_update_someones_comment_without_privileges( | ||||
|         user_factory, comment_factory, context_factory): | ||||
|     user = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     user2 = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     user = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     user2 = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     comment = comment_factory(user=user) | ||||
|     db.session.add(comment) | ||||
|     db.session.commit() | ||||
| @ -93,8 +93,8 @@ def test_trying_to_update_someones_comment_without_privileges( | ||||
| 
 | ||||
| def test_updating_someones_comment_with_privileges( | ||||
|         user_factory, comment_factory, context_factory): | ||||
|     user = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     user2 = user_factory(rank=db.User.RANK_MODERATOR) | ||||
|     user = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     user2 = user_factory(rank=model.User.RANK_MODERATOR) | ||||
|     comment = comment_factory(user=user) | ||||
|     db.session.add(comment) | ||||
|     db.session.commit() | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| from unittest.mock import patch | ||||
| import pytest | ||||
| from szurubooru import api, db, errors | ||||
| from szurubooru import api, db, model, errors | ||||
| from szurubooru.func import auth, mailer | ||||
| 
 | ||||
| 
 | ||||
| @ -15,7 +15,7 @@ def inject_config(config_injector): | ||||
| 
 | ||||
| def test_reset_sending_email(context_factory, user_factory): | ||||
|     db.session.add(user_factory( | ||||
|         name='u1', rank=db.User.RANK_REGULAR, email='user@example.com')) | ||||
|         name='u1', rank=model.User.RANK_REGULAR, email='user@example.com')) | ||||
|     db.session.flush() | ||||
|     for initiating_user in ['u1', 'user@example.com']: | ||||
|         with patch('szurubooru.func.mailer.send_mail'): | ||||
| @ -39,7 +39,7 @@ def test_trying_to_reset_non_existing(context_factory): | ||||
| 
 | ||||
| def test_trying_to_reset_without_email(context_factory, user_factory): | ||||
|     db.session.add( | ||||
|         user_factory(name='u1', rank=db.User.RANK_REGULAR, email=None)) | ||||
|         user_factory(name='u1', rank=model.User.RANK_REGULAR, email=None)) | ||||
|     db.session.flush() | ||||
|     with pytest.raises(errors.ValidationError): | ||||
|         api.password_reset_api.start_password_reset( | ||||
| @ -48,7 +48,7 @@ def test_trying_to_reset_without_email(context_factory, user_factory): | ||||
| 
 | ||||
| def test_confirming_with_good_token(context_factory, user_factory): | ||||
|     user = user_factory( | ||||
|         name='u1', rank=db.User.RANK_REGULAR, email='user@example.com') | ||||
|         name='u1', rank=model.User.RANK_REGULAR, email='user@example.com') | ||||
|     old_hash = user.password_hash | ||||
|     db.session.add(user) | ||||
|     db.session.flush() | ||||
| @ -68,7 +68,7 @@ def test_trying_to_confirm_non_existing(context_factory): | ||||
| 
 | ||||
| def test_trying_to_confirm_without_token(context_factory, user_factory): | ||||
|     db.session.add(user_factory( | ||||
|         name='u1', rank=db.User.RANK_REGULAR, email='user@example.com')) | ||||
|         name='u1', rank=model.User.RANK_REGULAR, email='user@example.com')) | ||||
|     db.session.flush() | ||||
|     with pytest.raises(errors.ValidationError): | ||||
|         api.password_reset_api.finish_password_reset( | ||||
| @ -77,7 +77,7 @@ def test_trying_to_confirm_without_token(context_factory, user_factory): | ||||
| 
 | ||||
| def test_trying_to_confirm_with_bad_token(context_factory, user_factory): | ||||
|     db.session.add(user_factory( | ||||
|         name='u1', rank=db.User.RANK_REGULAR, email='user@example.com')) | ||||
|         name='u1', rank=model.User.RANK_REGULAR, email='user@example.com')) | ||||
|     db.session.flush() | ||||
|     with pytest.raises(errors.ValidationError): | ||||
|         api.password_reset_api.finish_password_reset( | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| from unittest.mock import patch | ||||
| import pytest | ||||
| from szurubooru import api, db, errors | ||||
| from szurubooru import api, db, model, errors | ||||
| from szurubooru.func import posts, tags, snapshots, net | ||||
| 
 | ||||
| 
 | ||||
| @ -8,16 +8,16 @@ from szurubooru.func import posts, tags, snapshots, net | ||||
| def inject_config(config_injector): | ||||
|     config_injector({ | ||||
|         'privileges': { | ||||
|             'posts:create:anonymous': db.User.RANK_REGULAR, | ||||
|             'posts:create:identified': db.User.RANK_REGULAR, | ||||
|             'tags:create': db.User.RANK_REGULAR, | ||||
|             'posts:create:anonymous': model.User.RANK_REGULAR, | ||||
|             'posts:create:identified': model.User.RANK_REGULAR, | ||||
|             'tags:create': model.User.RANK_REGULAR, | ||||
|         }, | ||||
|     }) | ||||
| 
 | ||||
| 
 | ||||
| def test_creating_minimal_posts( | ||||
|         context_factory, post_factory, user_factory): | ||||
|     auth_user = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     auth_user = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     post = post_factory() | ||||
|     db.session.add(post) | ||||
|     db.session.flush() | ||||
| @ -53,20 +53,20 @@ def test_creating_minimal_posts( | ||||
|         posts.update_post_thumbnail.assert_called_once_with( | ||||
|             post, 'post-thumbnail') | ||||
|         posts.update_post_safety.assert_called_once_with(post, 'safe') | ||||
|         posts.update_post_source.assert_called_once_with(post, None) | ||||
|         posts.update_post_source.assert_called_once_with(post, '') | ||||
|         posts.update_post_relations.assert_called_once_with(post, []) | ||||
|         posts.update_post_notes.assert_called_once_with(post, []) | ||||
|         posts.update_post_flags.assert_called_once_with(post, []) | ||||
|         posts.update_post_thumbnail.assert_called_once_with( | ||||
|             post, 'post-thumbnail') | ||||
|         posts.serialize_post.assert_called_once_with( | ||||
|             post, auth_user, options=None) | ||||
|             post, auth_user, options=[]) | ||||
|         snapshots.create.assert_called_once_with(post, auth_user) | ||||
|         tags.export_to_json.assert_called_once_with() | ||||
| 
 | ||||
| 
 | ||||
| def test_creating_full_posts(context_factory, post_factory, user_factory): | ||||
|     auth_user = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     auth_user = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     post = post_factory() | ||||
|     db.session.add(post) | ||||
|     db.session.flush() | ||||
| @ -109,14 +109,14 @@ def test_creating_full_posts(context_factory, post_factory, user_factory): | ||||
|         posts.update_post_flags.assert_called_once_with( | ||||
|             post, ['flag1', 'flag2']) | ||||
|         posts.serialize_post.assert_called_once_with( | ||||
|             post, auth_user, options=None) | ||||
|             post, auth_user, options=[]) | ||||
|         snapshots.create.assert_called_once_with(post, auth_user) | ||||
|         tags.export_to_json.assert_called_once_with() | ||||
| 
 | ||||
| 
 | ||||
| def test_anonymous_uploads( | ||||
|         config_injector, context_factory, post_factory, user_factory): | ||||
|     auth_user = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     auth_user = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     post = post_factory() | ||||
|     db.session.add(post) | ||||
|     db.session.flush() | ||||
| @ -126,7 +126,7 @@ def test_anonymous_uploads( | ||||
|             patch('szurubooru.func.posts.create_post'), \ | ||||
|             patch('szurubooru.func.posts.update_post_source'): | ||||
|         config_injector({ | ||||
|             'privileges': {'posts:create:anonymous': db.User.RANK_REGULAR}, | ||||
|             'privileges': {'posts:create:anonymous': model.User.RANK_REGULAR}, | ||||
|         }) | ||||
|         posts.create_post.return_value = [post, []] | ||||
|         api.post_api.create_post( | ||||
| @ -146,7 +146,7 @@ def test_anonymous_uploads( | ||||
| 
 | ||||
| def test_creating_from_url_saves_source( | ||||
|         config_injector, context_factory, post_factory, user_factory): | ||||
|     auth_user = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     auth_user = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     post = post_factory() | ||||
|     db.session.add(post) | ||||
|     db.session.flush() | ||||
| @ -157,7 +157,7 @@ def test_creating_from_url_saves_source( | ||||
|             patch('szurubooru.func.posts.create_post'), \ | ||||
|             patch('szurubooru.func.posts.update_post_source'): | ||||
|         config_injector({ | ||||
|             'privileges': {'posts:create:identified': db.User.RANK_REGULAR}, | ||||
|             'privileges': {'posts:create:identified': model.User.RANK_REGULAR}, | ||||
|         }) | ||||
|         net.download.return_value = b'content' | ||||
|         posts.create_post.return_value = [post, []] | ||||
| @ -177,7 +177,7 @@ def test_creating_from_url_saves_source( | ||||
| 
 | ||||
| def test_creating_from_url_with_source_specified( | ||||
|         config_injector, context_factory, post_factory, user_factory): | ||||
|     auth_user = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     auth_user = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     post = post_factory() | ||||
|     db.session.add(post) | ||||
|     db.session.flush() | ||||
| @ -188,7 +188,7 @@ def test_creating_from_url_with_source_specified( | ||||
|             patch('szurubooru.func.posts.create_post'), \ | ||||
|             patch('szurubooru.func.posts.update_post_source'): | ||||
|         config_injector({ | ||||
|             'privileges': {'posts:create:identified': db.User.RANK_REGULAR}, | ||||
|             'privileges': {'posts:create:identified': model.User.RANK_REGULAR}, | ||||
|         }) | ||||
|         net.download.return_value = b'content' | ||||
|         posts.create_post.return_value = [post, []] | ||||
| @ -218,14 +218,14 @@ def test_trying_to_omit_mandatory_field(context_factory, user_factory, field): | ||||
|             context_factory( | ||||
|                 params=params, | ||||
|                 files={'content': '...'}, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR))) | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR))) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize( | ||||
|     'field', ['tags', 'relations', 'source', 'notes', 'flags']) | ||||
| def test_omitting_optional_field( | ||||
|         field, context_factory, post_factory, user_factory): | ||||
|     auth_user = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     auth_user = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     post = post_factory() | ||||
|     db.session.add(post) | ||||
|     db.session.flush() | ||||
| @ -268,10 +268,10 @@ def test_errors_not_spending_ids( | ||||
|             'post_height': 300, | ||||
|         }, | ||||
|         'privileges': { | ||||
|             'posts:create:identified': db.User.RANK_REGULAR, | ||||
|             'posts:create:identified': model.User.RANK_REGULAR, | ||||
|         }, | ||||
|     }) | ||||
|     auth_user = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     auth_user = user_factory(rank=model.User.RANK_REGULAR) | ||||
| 
 | ||||
|     # successful request | ||||
|     with patch('szurubooru.func.posts.serialize_post'), \ | ||||
| @ -316,7 +316,7 @@ def test_trying_to_omit_content(context_factory, user_factory): | ||||
|                     'safety': 'safe', | ||||
|                     'tags': ['tag1', 'tag2'], | ||||
|                 }, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR))) | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR))) | ||||
| 
 | ||||
| 
 | ||||
| def test_trying_to_create_post_without_privileges( | ||||
| @ -324,16 +324,16 @@ def test_trying_to_create_post_without_privileges( | ||||
|     with pytest.raises(errors.AuthError): | ||||
|         api.post_api.create_post(context_factory( | ||||
|             params='whatever', | ||||
|             user=user_factory(rank=db.User.RANK_ANONYMOUS))) | ||||
|             user=user_factory(rank=model.User.RANK_ANONYMOUS))) | ||||
| 
 | ||||
| 
 | ||||
| def test_trying_to_create_tags_without_privileges( | ||||
|         config_injector, context_factory, user_factory): | ||||
|     config_injector({ | ||||
|         'privileges': { | ||||
|             'posts:create:anonymous': db.User.RANK_REGULAR, | ||||
|             'posts:create:identified': db.User.RANK_REGULAR, | ||||
|             'tags:create': db.User.RANK_ADMINISTRATOR, | ||||
|             'posts:create:anonymous': model.User.RANK_REGULAR, | ||||
|             'posts:create:identified': model.User.RANK_REGULAR, | ||||
|             'tags:create': model.User.RANK_ADMINISTRATOR, | ||||
|         }, | ||||
|     }) | ||||
|     with pytest.raises(errors.AuthError), \ | ||||
| @ -349,4 +349,4 @@ def test_trying_to_create_tags_without_privileges( | ||||
|                 files={ | ||||
|                     'content': posts.EMPTY_PIXEL, | ||||
|                 }, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR))) | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR))) | ||||
|  | ||||
| @ -1,16 +1,16 @@ | ||||
| from unittest.mock import patch | ||||
| import pytest | ||||
| from szurubooru import api, db, errors | ||||
| from szurubooru import api, db, model, errors | ||||
| from szurubooru.func import posts, tags, snapshots | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture(autouse=True) | ||||
| def inject_config(config_injector): | ||||
|     config_injector({'privileges': {'posts:delete': db.User.RANK_REGULAR}}) | ||||
|     config_injector({'privileges': {'posts:delete': model.User.RANK_REGULAR}}) | ||||
| 
 | ||||
| 
 | ||||
| def test_deleting(user_factory, post_factory, context_factory): | ||||
|     auth_user = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     auth_user = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     post = post_factory(id=1) | ||||
|     db.session.add(post) | ||||
|     db.session.flush() | ||||
| @ -20,7 +20,7 @@ def test_deleting(user_factory, post_factory, context_factory): | ||||
|             context_factory(params={'version': 1}, user=auth_user), | ||||
|             {'post_id': 1}) | ||||
|         assert result == {} | ||||
|         assert db.session.query(db.Post).count() == 0 | ||||
|         assert db.session.query(model.Post).count() == 0 | ||||
|         snapshots.delete.assert_called_once_with(post, auth_user) | ||||
|         tags.export_to_json.assert_called_once_with() | ||||
| 
 | ||||
| @ -28,7 +28,7 @@ def test_deleting(user_factory, post_factory, context_factory): | ||||
| def test_trying_to_delete_non_existing(user_factory, context_factory): | ||||
|     with pytest.raises(posts.PostNotFoundError): | ||||
|         api.post_api.delete_post( | ||||
|             context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), | ||||
|             context_factory(user=user_factory(rank=model.User.RANK_REGULAR)), | ||||
|             {'post_id': 999}) | ||||
| 
 | ||||
| 
 | ||||
| @ -38,6 +38,6 @@ def test_trying_to_delete_without_privileges( | ||||
|     db.session.commit() | ||||
|     with pytest.raises(errors.AuthError): | ||||
|         api.post_api.delete_post( | ||||
|             context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)), | ||||
|             context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)), | ||||
|             {'post_id': 1}) | ||||
|     assert db.session.query(db.Post).count() == 1 | ||||
|     assert db.session.query(model.Post).count() == 1 | ||||
|  | ||||
| @ -1,13 +1,14 @@ | ||||
| from datetime import datetime | ||||
| from unittest.mock import patch | ||||
| import pytest | ||||
| from szurubooru import api, db, errors | ||||
| from szurubooru import api, db, model, errors | ||||
| from szurubooru.func import posts | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture(autouse=True) | ||||
| def inject_config(config_injector): | ||||
|     config_injector({'privileges': {'posts:favorite': db.User.RANK_REGULAR}}) | ||||
|     config_injector( | ||||
|         {'privileges': {'posts:favorite': model.User.RANK_REGULAR}}) | ||||
| 
 | ||||
| 
 | ||||
| def test_adding_to_favorites( | ||||
| @ -23,8 +24,8 @@ def test_adding_to_favorites( | ||||
|             context_factory(user=user_factory()), | ||||
|             {'post_id': post.post_id}) | ||||
|         assert result == 'serialized post' | ||||
|         post = db.session.query(db.Post).one() | ||||
|         assert db.session.query(db.PostFavorite).count() == 1 | ||||
|         post = db.session.query(model.Post).one() | ||||
|         assert db.session.query(model.PostFavorite).count() == 1 | ||||
|         assert post is not None | ||||
|         assert post.favorite_count == 1 | ||||
|         assert post.score == 1 | ||||
| @ -47,9 +48,9 @@ def test_removing_from_favorites( | ||||
|             api.post_api.delete_post_from_favorites( | ||||
|                 context_factory(user=user), | ||||
|                 {'post_id': post.post_id}) | ||||
|         post = db.session.query(db.Post).one() | ||||
|         post = db.session.query(model.Post).one() | ||||
|         assert post.score == 1 | ||||
|         assert db.session.query(db.PostFavorite).count() == 0 | ||||
|         assert db.session.query(model.PostFavorite).count() == 0 | ||||
|         assert post.favorite_count == 0 | ||||
| 
 | ||||
| 
 | ||||
| @ -68,8 +69,8 @@ def test_favoriting_twice( | ||||
|             api.post_api.add_post_to_favorites( | ||||
|                 context_factory(user=user), | ||||
|                 {'post_id': post.post_id}) | ||||
|         post = db.session.query(db.Post).one() | ||||
|         assert db.session.query(db.PostFavorite).count() == 1 | ||||
|         post = db.session.query(model.Post).one() | ||||
|         assert db.session.query(model.PostFavorite).count() == 1 | ||||
|         assert post.favorite_count == 1 | ||||
| 
 | ||||
| 
 | ||||
| @ -92,8 +93,8 @@ def test_removing_twice( | ||||
|             api.post_api.delete_post_from_favorites( | ||||
|                 context_factory(user=user), | ||||
|                 {'post_id': post.post_id}) | ||||
|         post = db.session.query(db.Post).one() | ||||
|         assert db.session.query(db.PostFavorite).count() == 0 | ||||
|         post = db.session.query(model.Post).one() | ||||
|         assert db.session.query(model.PostFavorite).count() == 0 | ||||
|         assert post.favorite_count == 0 | ||||
| 
 | ||||
| 
 | ||||
| @ -113,8 +114,8 @@ def test_favorites_from_multiple_users( | ||||
|             api.post_api.add_post_to_favorites( | ||||
|                 context_factory(user=user2), | ||||
|                 {'post_id': post.post_id}) | ||||
|         post = db.session.query(db.Post).one() | ||||
|         assert db.session.query(db.PostFavorite).count() == 2 | ||||
|         post = db.session.query(model.Post).one() | ||||
|         assert db.session.query(model.PostFavorite).count() == 2 | ||||
|         assert post.favorite_count == 2 | ||||
|         assert post.last_favorite_time == datetime(1997, 12, 2) | ||||
| 
 | ||||
| @ -133,5 +134,5 @@ def test_trying_to_rate_without_privileges( | ||||
|     db.session.commit() | ||||
|     with pytest.raises(errors.AuthError): | ||||
|         api.post_api.add_post_to_favorites( | ||||
|             context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)), | ||||
|             context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)), | ||||
|             {'post_id': post.post_id}) | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| from unittest.mock import patch | ||||
| import pytest | ||||
| from szurubooru import api, db, errors | ||||
| from szurubooru import api, db, model, errors | ||||
| from szurubooru.func import posts, snapshots | ||||
| 
 | ||||
| 
 | ||||
| @ -8,14 +8,14 @@ from szurubooru.func import posts, snapshots | ||||
| def inject_config(config_injector): | ||||
|     config_injector({ | ||||
|         'privileges': { | ||||
|             'posts:feature': db.User.RANK_REGULAR, | ||||
|             'posts:view': db.User.RANK_REGULAR, | ||||
|             'posts:feature': model.User.RANK_REGULAR, | ||||
|             'posts:view': model.User.RANK_REGULAR, | ||||
|         }, | ||||
|     }) | ||||
| 
 | ||||
| 
 | ||||
| def test_featuring(user_factory, post_factory, context_factory): | ||||
|     auth_user = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     auth_user = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     post = post_factory(id=1) | ||||
|     db.session.add(post) | ||||
|     db.session.flush() | ||||
| @ -31,7 +31,7 @@ def test_featuring(user_factory, post_factory, context_factory): | ||||
|         assert posts.get_post_by_id(1).is_featured | ||||
|         result = api.post_api.get_featured_post( | ||||
|             context_factory( | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR))) | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR))) | ||||
|         assert result == 'serialized post' | ||||
|         snapshots.modify.assert_called_once_with(post, auth_user) | ||||
| 
 | ||||
| @ -40,7 +40,7 @@ def test_trying_to_omit_required_parameter(user_factory, context_factory): | ||||
|     with pytest.raises(errors.MissingRequiredParameterError): | ||||
|         api.post_api.set_featured_post( | ||||
|             context_factory( | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR))) | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR))) | ||||
| 
 | ||||
| 
 | ||||
| def test_trying_to_feature_the_same_post_twice( | ||||
| @ -51,12 +51,12 @@ def test_trying_to_feature_the_same_post_twice( | ||||
|         api.post_api.set_featured_post( | ||||
|             context_factory( | ||||
|                 params={'id': 1}, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR))) | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR))) | ||||
|         with pytest.raises(posts.PostAlreadyFeaturedError): | ||||
|             api.post_api.set_featured_post( | ||||
|                 context_factory( | ||||
|                     params={'id': 1}, | ||||
|                     user=user_factory(rank=db.User.RANK_REGULAR))) | ||||
|                     user=user_factory(rank=model.User.RANK_REGULAR))) | ||||
| 
 | ||||
| 
 | ||||
| def test_featuring_one_post_after_another( | ||||
| @ -72,12 +72,12 @@ def test_featuring_one_post_after_another( | ||||
|             api.post_api.set_featured_post( | ||||
|                 context_factory( | ||||
|                     params={'id': 1}, | ||||
|                     user=user_factory(rank=db.User.RANK_REGULAR))) | ||||
|                     user=user_factory(rank=model.User.RANK_REGULAR))) | ||||
|         with fake_datetime('1998'): | ||||
|             api.post_api.set_featured_post( | ||||
|                 context_factory( | ||||
|                     params={'id': 2}, | ||||
|                     user=user_factory(rank=db.User.RANK_REGULAR))) | ||||
|                     user=user_factory(rank=model.User.RANK_REGULAR))) | ||||
|         assert posts.try_get_featured_post() is not None | ||||
|         assert posts.try_get_featured_post().post_id == 2 | ||||
|         assert not posts.get_post_by_id(1).is_featured | ||||
| @ -89,7 +89,7 @@ def test_trying_to_feature_non_existing(user_factory, context_factory): | ||||
|         api.post_api.set_featured_post( | ||||
|             context_factory( | ||||
|                 params={'id': 1}, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR))) | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR))) | ||||
| 
 | ||||
| 
 | ||||
| def test_trying_to_feature_without_privileges(user_factory, context_factory): | ||||
| @ -97,10 +97,10 @@ def test_trying_to_feature_without_privileges(user_factory, context_factory): | ||||
|         api.post_api.set_featured_post( | ||||
|             context_factory( | ||||
|                 params={'id': 1}, | ||||
|                 user=user_factory(rank=db.User.RANK_ANONYMOUS))) | ||||
|                 user=user_factory(rank=model.User.RANK_ANONYMOUS))) | ||||
| 
 | ||||
| 
 | ||||
| def test_getting_featured_post_without_privileges_to_view( | ||||
|         user_factory, context_factory): | ||||
|     api.post_api.get_featured_post( | ||||
|         context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS))) | ||||
|         context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS))) | ||||
|  | ||||
| @ -1,16 +1,16 @@ | ||||
| from unittest.mock import patch | ||||
| import pytest | ||||
| from szurubooru import api, db, errors | ||||
| from szurubooru import api, db, model, errors | ||||
| from szurubooru.func import posts, snapshots | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture(autouse=True) | ||||
| def inject_config(config_injector): | ||||
|     config_injector({'privileges': {'posts:merge': db.User.RANK_REGULAR}}) | ||||
|     config_injector({'privileges': {'posts:merge': model.User.RANK_REGULAR}}) | ||||
| 
 | ||||
| 
 | ||||
| def test_merging(user_factory, context_factory, post_factory): | ||||
|     auth_user = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     auth_user = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     source_post = post_factory() | ||||
|     target_post = post_factory() | ||||
|     db.session.add_all([source_post, target_post]) | ||||
| @ -25,6 +25,7 @@ def test_merging(user_factory, context_factory, post_factory): | ||||
|                     'mergeToVersion': 1, | ||||
|                     'remove': source_post.post_id, | ||||
|                     'mergeTo': target_post.post_id, | ||||
|                     'replaceContent': False, | ||||
|                 }, | ||||
|                 user=auth_user)) | ||||
|         posts.merge_posts.called_once_with(source_post, target_post) | ||||
| @ -45,13 +46,14 @@ def test_trying_to_omit_mandatory_field( | ||||
|         'mergeToVersion': 1, | ||||
|         'remove': source_post.post_id, | ||||
|         'mergeTo': target_post.post_id, | ||||
|         'replaceContent': False, | ||||
|     } | ||||
|     del params[field] | ||||
|     with pytest.raises(errors.ValidationError): | ||||
|         api.post_api.merge_posts( | ||||
|             context_factory( | ||||
|                 params=params, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR))) | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR))) | ||||
| 
 | ||||
| 
 | ||||
| def test_trying_to_merge_non_existing( | ||||
| @ -63,12 +65,12 @@ def test_trying_to_merge_non_existing( | ||||
|         api.post_api.merge_posts( | ||||
|             context_factory( | ||||
|                 params={'remove': post.post_id, 'mergeTo': 999}, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR))) | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR))) | ||||
|     with pytest.raises(posts.PostNotFoundError): | ||||
|         api.post_api.merge_posts( | ||||
|             context_factory( | ||||
|                 params={'remove': 999, 'mergeTo': post.post_id}, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR))) | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR))) | ||||
| 
 | ||||
| 
 | ||||
| def test_trying_to_merge_without_privileges( | ||||
| @ -85,5 +87,6 @@ def test_trying_to_merge_without_privileges( | ||||
|                     'mergeToVersion': 1, | ||||
|                     'remove': source_post.post_id, | ||||
|                     'mergeTo': target_post.post_id, | ||||
|                     'replaceContent': False, | ||||
|                 }, | ||||
|                 user=user_factory(rank=db.User.RANK_ANONYMOUS))) | ||||
|                 user=user_factory(rank=model.User.RANK_ANONYMOUS))) | ||||
|  | ||||
| @ -1,12 +1,12 @@ | ||||
| from unittest.mock import patch | ||||
| import pytest | ||||
| from szurubooru import api, db, errors | ||||
| from szurubooru import api, db, model, errors | ||||
| from szurubooru.func import posts | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture(autouse=True) | ||||
| def inject_config(config_injector): | ||||
|     config_injector({'privileges': {'posts:score': db.User.RANK_REGULAR}}) | ||||
|     config_injector({'privileges': {'posts:score': model.User.RANK_REGULAR}}) | ||||
| 
 | ||||
| 
 | ||||
| def test_simple_rating( | ||||
| @ -22,8 +22,8 @@ def test_simple_rating( | ||||
|                 params={'score': 1}, user=user_factory()), | ||||
|             {'post_id': post.post_id}) | ||||
|         assert result == 'serialized post' | ||||
|         post = db.session.query(db.Post).one() | ||||
|         assert db.session.query(db.PostScore).count() == 1 | ||||
|         post = db.session.query(model.Post).one() | ||||
|         assert db.session.query(model.PostScore).count() == 1 | ||||
|         assert post is not None | ||||
|         assert post.score == 1 | ||||
| 
 | ||||
| @ -43,8 +43,8 @@ def test_updating_rating( | ||||
|             api.post_api.set_post_score( | ||||
|                 context_factory(params={'score': -1}, user=user), | ||||
|                 {'post_id': post.post_id}) | ||||
|         post = db.session.query(db.Post).one() | ||||
|         assert db.session.query(db.PostScore).count() == 1 | ||||
|         post = db.session.query(model.Post).one() | ||||
|         assert db.session.query(model.PostScore).count() == 1 | ||||
|         assert post.score == -1 | ||||
| 
 | ||||
| 
 | ||||
| @ -63,8 +63,8 @@ def test_updating_rating_to_zero( | ||||
|             api.post_api.set_post_score( | ||||
|                 context_factory(params={'score': 0}, user=user), | ||||
|                 {'post_id': post.post_id}) | ||||
|         post = db.session.query(db.Post).one() | ||||
|         assert db.session.query(db.PostScore).count() == 0 | ||||
|         post = db.session.query(model.Post).one() | ||||
|         assert db.session.query(model.PostScore).count() == 0 | ||||
|         assert post.score == 0 | ||||
| 
 | ||||
| 
 | ||||
| @ -83,8 +83,8 @@ def test_deleting_rating( | ||||
|             api.post_api.delete_post_score( | ||||
|                 context_factory(user=user), | ||||
|                 {'post_id': post.post_id}) | ||||
|         post = db.session.query(db.Post).one() | ||||
|         assert db.session.query(db.PostScore).count() == 0 | ||||
|         post = db.session.query(model.Post).one() | ||||
|         assert db.session.query(model.PostScore).count() == 0 | ||||
|         assert post.score == 0 | ||||
| 
 | ||||
| 
 | ||||
| @ -104,8 +104,8 @@ def test_ratings_from_multiple_users( | ||||
|             api.post_api.set_post_score( | ||||
|                 context_factory(params={'score': -1}, user=user2), | ||||
|                 {'post_id': post.post_id}) | ||||
|         post = db.session.query(db.Post).one() | ||||
|         assert db.session.query(db.PostScore).count() == 2 | ||||
|         post = db.session.query(model.Post).one() | ||||
|         assert db.session.query(model.PostScore).count() == 2 | ||||
|         assert post.score == 0 | ||||
| 
 | ||||
| 
 | ||||
| @ -136,5 +136,5 @@ def test_trying_to_rate_without_privileges( | ||||
|         api.post_api.set_post_score( | ||||
|             context_factory( | ||||
|                 params={'score': 1}, | ||||
|                 user=user_factory(rank=db.User.RANK_ANONYMOUS)), | ||||
|                 user=user_factory(rank=model.User.RANK_ANONYMOUS)), | ||||
|             {'post_id': post.post_id}) | ||||
|  | ||||
| @ -1,7 +1,7 @@ | ||||
| from datetime import datetime | ||||
| from unittest.mock import patch | ||||
| import pytest | ||||
| from szurubooru import api, db, errors | ||||
| from szurubooru import api, db, model, errors | ||||
| from szurubooru.func import posts | ||||
| 
 | ||||
| 
 | ||||
| @ -9,8 +9,8 @@ from szurubooru.func import posts | ||||
| def inject_config(config_injector): | ||||
|     config_injector({ | ||||
|         'privileges': { | ||||
|             'posts:list': db.User.RANK_REGULAR, | ||||
|             'posts:view': db.User.RANK_REGULAR, | ||||
|             'posts:list': model.User.RANK_REGULAR, | ||||
|             'posts:view': model.User.RANK_REGULAR, | ||||
|         }, | ||||
|     }) | ||||
| 
 | ||||
| @ -25,7 +25,7 @@ def test_retrieving_multiple(user_factory, post_factory, context_factory): | ||||
|         result = api.post_api.get_posts( | ||||
|             context_factory( | ||||
|                 params={'query': '', 'page': 1}, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR))) | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR))) | ||||
|         assert result == { | ||||
|             'query': '', | ||||
|             'page': 1, | ||||
| @ -36,10 +36,10 @@ def test_retrieving_multiple(user_factory, post_factory, context_factory): | ||||
| 
 | ||||
| 
 | ||||
| def test_using_special_tokens(user_factory, post_factory, context_factory): | ||||
|     auth_user = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     auth_user = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     post1 = post_factory(id=1) | ||||
|     post2 = post_factory(id=2) | ||||
|     post1.favorited_by = [db.PostFavorite( | ||||
|     post1.favorited_by = [model.PostFavorite( | ||||
|         user=auth_user, time=datetime.utcnow())] | ||||
|     db.session.add_all([post1, post2, auth_user]) | ||||
|     db.session.flush() | ||||
| @ -68,7 +68,7 @@ def test_trying_to_use_special_tokens_without_logging_in( | ||||
|         api.post_api.get_posts( | ||||
|             context_factory( | ||||
|                 params={'query': 'special:fav', 'page': 1}, | ||||
|                 user=user_factory(rank=db.User.RANK_ANONYMOUS))) | ||||
|                 user=user_factory(rank=model.User.RANK_ANONYMOUS))) | ||||
| 
 | ||||
| 
 | ||||
| def test_trying_to_retrieve_multiple_without_privileges( | ||||
| @ -77,7 +77,7 @@ def test_trying_to_retrieve_multiple_without_privileges( | ||||
|         api.post_api.get_posts( | ||||
|             context_factory( | ||||
|                 params={'query': '', 'page': 1}, | ||||
|                 user=user_factory(rank=db.User.RANK_ANONYMOUS))) | ||||
|                 user=user_factory(rank=model.User.RANK_ANONYMOUS))) | ||||
| 
 | ||||
| 
 | ||||
| def test_retrieving_single(user_factory, post_factory, context_factory): | ||||
| @ -86,7 +86,7 @@ def test_retrieving_single(user_factory, post_factory, context_factory): | ||||
|     with patch('szurubooru.func.posts.serialize_post'): | ||||
|         posts.serialize_post.return_value = 'serialized post' | ||||
|         result = api.post_api.get_post( | ||||
|             context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), | ||||
|             context_factory(user=user_factory(rank=model.User.RANK_REGULAR)), | ||||
|             {'post_id': 1}) | ||||
|         assert result == 'serialized post' | ||||
| 
 | ||||
| @ -94,7 +94,7 @@ def test_retrieving_single(user_factory, post_factory, context_factory): | ||||
| def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): | ||||
|     with pytest.raises(posts.PostNotFoundError): | ||||
|         api.post_api.get_post( | ||||
|             context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), | ||||
|             context_factory(user=user_factory(rank=model.User.RANK_REGULAR)), | ||||
|             {'post_id': 999}) | ||||
| 
 | ||||
| 
 | ||||
| @ -102,5 +102,5 @@ def test_trying_to_retrieve_single_without_privileges( | ||||
|         user_factory, context_factory): | ||||
|     with pytest.raises(errors.AuthError): | ||||
|         api.post_api.get_post( | ||||
|             context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)), | ||||
|             context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)), | ||||
|             {'post_id': 999}) | ||||
|  | ||||
| @ -1,7 +1,7 @@ | ||||
| from datetime import datetime | ||||
| from unittest.mock import patch | ||||
| import pytest | ||||
| from szurubooru import api, db, errors | ||||
| from szurubooru import api, db, model, errors | ||||
| from szurubooru.func import posts, tags, snapshots, net | ||||
| 
 | ||||
| 
 | ||||
| @ -9,22 +9,22 @@ from szurubooru.func import posts, tags, snapshots, net | ||||
| def inject_config(config_injector): | ||||
|     config_injector({ | ||||
|         'privileges': { | ||||
|             'posts:edit:tags': db.User.RANK_REGULAR, | ||||
|             'posts:edit:content': db.User.RANK_REGULAR, | ||||
|             'posts:edit:safety': db.User.RANK_REGULAR, | ||||
|             'posts:edit:source': db.User.RANK_REGULAR, | ||||
|             'posts:edit:relations': db.User.RANK_REGULAR, | ||||
|             'posts:edit:notes': db.User.RANK_REGULAR, | ||||
|             'posts:edit:flags': db.User.RANK_REGULAR, | ||||
|             'posts:edit:thumbnail': db.User.RANK_REGULAR, | ||||
|             'tags:create': db.User.RANK_MODERATOR, | ||||
|             'posts:edit:tags': model.User.RANK_REGULAR, | ||||
|             'posts:edit:content': model.User.RANK_REGULAR, | ||||
|             'posts:edit:safety': model.User.RANK_REGULAR, | ||||
|             'posts:edit:source': model.User.RANK_REGULAR, | ||||
|             'posts:edit:relations': model.User.RANK_REGULAR, | ||||
|             'posts:edit:notes': model.User.RANK_REGULAR, | ||||
|             'posts:edit:flags': model.User.RANK_REGULAR, | ||||
|             'posts:edit:thumbnail': model.User.RANK_REGULAR, | ||||
|             'tags:create': model.User.RANK_MODERATOR, | ||||
|         }, | ||||
|     }) | ||||
| 
 | ||||
| 
 | ||||
| def test_post_updating( | ||||
|         context_factory, post_factory, user_factory, fake_datetime): | ||||
|     auth_user = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     auth_user = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     post = post_factory() | ||||
|     db.session.add(post) | ||||
|     db.session.flush() | ||||
| @ -76,7 +76,7 @@ def test_post_updating( | ||||
|         posts.update_post_flags.assert_called_once_with( | ||||
|             post, ['flag1', 'flag2']) | ||||
|         posts.serialize_post.assert_called_once_with( | ||||
|             post, auth_user, options=None) | ||||
|             post, auth_user, options=[]) | ||||
|         snapshots.modify.assert_called_once_with(post, auth_user) | ||||
|         tags.export_to_json.assert_called_once_with() | ||||
|         assert post.last_edit_time == datetime(1997, 1, 1) | ||||
| @ -97,7 +97,7 @@ def test_uploading_from_url_saves_source( | ||||
|         api.post_api.update_post( | ||||
|             context_factory( | ||||
|                 params={'contentUrl': 'example.com', 'version': 1}, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR)), | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR)), | ||||
|             {'post_id': post.post_id}) | ||||
|         net.download.assert_called_once_with('example.com') | ||||
|         posts.update_post_content.assert_called_once_with(post, b'content') | ||||
| @ -122,7 +122,7 @@ def test_uploading_from_url_with_source_specified( | ||||
|                     'contentUrl': 'example.com', | ||||
|                     'source': 'example2.com', | ||||
|                     'version': 1}, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR)), | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR)), | ||||
|             {'post_id': post.post_id}) | ||||
|         net.download.assert_called_once_with('example.com') | ||||
|         posts.update_post_content.assert_called_once_with(post, b'content') | ||||
| @ -134,7 +134,7 @@ def test_trying_to_update_non_existing(context_factory, user_factory): | ||||
|         api.post_api.update_post( | ||||
|             context_factory( | ||||
|                 params='whatever', | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR)), | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR)), | ||||
|             {'post_id': 1}) | ||||
| 
 | ||||
| 
 | ||||
| @ -158,7 +158,7 @@ def test_trying_to_update_field_without_privileges( | ||||
|             context_factory( | ||||
|                 params={**params, **{'version': 1}}, | ||||
|                 files=files, | ||||
|                 user=user_factory(rank=db.User.RANK_ANONYMOUS)), | ||||
|                 user=user_factory(rank=model.User.RANK_ANONYMOUS)), | ||||
|             {'post_id': post.post_id}) | ||||
| 
 | ||||
| 
 | ||||
| @ -173,5 +173,5 @@ def test_trying_to_create_tags_without_privileges( | ||||
|         api.post_api.update_post( | ||||
|             context_factory( | ||||
|                 params={'tags': ['tag1', 'tag2'], 'version': 1}, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR)), | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR)), | ||||
|             {'post_id': post.post_id}) | ||||
|  | ||||
| @ -1,10 +1,10 @@ | ||||
| from datetime import datetime | ||||
| import pytest | ||||
| from szurubooru import api, db, errors | ||||
| from szurubooru import api, db, model, errors | ||||
| 
 | ||||
| 
 | ||||
| def snapshot_factory(): | ||||
|     snapshot = db.Snapshot() | ||||
|     snapshot = model.Snapshot() | ||||
|     snapshot.creation_time = datetime(1999, 1, 1) | ||||
|     snapshot.resource_type = 'dummy' | ||||
|     snapshot.resource_pkey = 1 | ||||
| @ -17,7 +17,7 @@ def snapshot_factory(): | ||||
| @pytest.fixture(autouse=True) | ||||
| def inject_config(config_injector): | ||||
|     config_injector({ | ||||
|         'privileges': {'snapshots:list': db.User.RANK_REGULAR}, | ||||
|         'privileges': {'snapshots:list': model.User.RANK_REGULAR}, | ||||
|     }) | ||||
| 
 | ||||
| 
 | ||||
| @ -29,7 +29,7 @@ def test_retrieving_multiple(user_factory, context_factory): | ||||
|     result = api.snapshot_api.get_snapshots( | ||||
|         context_factory( | ||||
|             params={'query': '', 'page': 1}, | ||||
|             user=user_factory(rank=db.User.RANK_REGULAR))) | ||||
|             user=user_factory(rank=model.User.RANK_REGULAR))) | ||||
|     assert result['query'] == '' | ||||
|     assert result['page'] == 1 | ||||
|     assert result['pageSize'] == 100 | ||||
| @ -43,4 +43,4 @@ def test_trying_to_retrieve_multiple_without_privileges( | ||||
|         api.snapshot_api.get_snapshots( | ||||
|             context_factory( | ||||
|                 params={'query': '', 'page': 1}, | ||||
|                 user=user_factory(rank=db.User.RANK_ANONYMOUS))) | ||||
|                 user=user_factory(rank=model.User.RANK_ANONYMOUS))) | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| from unittest.mock import patch | ||||
| import pytest | ||||
| from szurubooru import api, db, errors | ||||
| from szurubooru import api, db, model, errors | ||||
| from szurubooru.func import tag_categories, tags, snapshots | ||||
| 
 | ||||
| 
 | ||||
| @ -11,13 +11,13 @@ def _update_category_name(category, name): | ||||
| @pytest.fixture(autouse=True) | ||||
| def inject_config(config_injector): | ||||
|     config_injector({ | ||||
|         'privileges': {'tag_categories:create': db.User.RANK_REGULAR}, | ||||
|         'privileges': {'tag_categories:create': model.User.RANK_REGULAR}, | ||||
|     }) | ||||
| 
 | ||||
| 
 | ||||
| def test_creating_category( | ||||
|         tag_category_factory, user_factory, context_factory): | ||||
|     auth_user = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     auth_user = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     category = tag_category_factory(name='meta') | ||||
|     db.session.add(category) | ||||
| 
 | ||||
| @ -49,7 +49,7 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field): | ||||
|         api.tag_category_api.create_tag_category( | ||||
|             context_factory( | ||||
|                 params=params, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR))) | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR))) | ||||
| 
 | ||||
| 
 | ||||
| def test_trying_to_create_without_privileges(user_factory, context_factory): | ||||
| @ -57,4 +57,4 @@ def test_trying_to_create_without_privileges(user_factory, context_factory): | ||||
|         api.tag_category_api.create_tag_category( | ||||
|             context_factory( | ||||
|                 params={'name': 'meta', 'color': 'black'}, | ||||
|                 user=user_factory(rank=db.User.RANK_ANONYMOUS))) | ||||
|                 user=user_factory(rank=model.User.RANK_ANONYMOUS))) | ||||
|  | ||||
| @ -1,18 +1,18 @@ | ||||
| from unittest.mock import patch | ||||
| import pytest | ||||
| from szurubooru import api, db, errors | ||||
| from szurubooru import api, db, model, errors | ||||
| from szurubooru.func import tag_categories, tags, snapshots | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture(autouse=True) | ||||
| def inject_config(config_injector): | ||||
|     config_injector({ | ||||
|         'privileges': {'tag_categories:delete': db.User.RANK_REGULAR}, | ||||
|         'privileges': {'tag_categories:delete': model.User.RANK_REGULAR}, | ||||
|     }) | ||||
| 
 | ||||
| 
 | ||||
| def test_deleting(user_factory, tag_category_factory, context_factory): | ||||
|     auth_user = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     auth_user = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     category = tag_category_factory(name='category') | ||||
|     db.session.add(tag_category_factory(name='root')) | ||||
|     db.session.add(category) | ||||
| @ -23,8 +23,8 @@ def test_deleting(user_factory, tag_category_factory, context_factory): | ||||
|             context_factory(params={'version': 1}, user=auth_user), | ||||
|             {'category_name': 'category'}) | ||||
|         assert result == {} | ||||
|         assert db.session.query(db.TagCategory).count() == 1 | ||||
|         assert db.session.query(db.TagCategory).one().name == 'root' | ||||
|         assert db.session.query(model.TagCategory).count() == 1 | ||||
|         assert db.session.query(model.TagCategory).one().name == 'root' | ||||
|         snapshots.delete.assert_called_once_with(category, auth_user) | ||||
|         tags.export_to_json.assert_called_once_with() | ||||
| 
 | ||||
| @ -41,9 +41,9 @@ def test_trying_to_delete_used( | ||||
|         api.tag_category_api.delete_tag_category( | ||||
|             context_factory( | ||||
|                 params={'version': 1}, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR)), | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR)), | ||||
|             {'category_name': 'category'}) | ||||
|     assert db.session.query(db.TagCategory).count() == 1 | ||||
|     assert db.session.query(model.TagCategory).count() == 1 | ||||
| 
 | ||||
| 
 | ||||
| def test_trying_to_delete_last( | ||||
| @ -54,14 +54,14 @@ def test_trying_to_delete_last( | ||||
|         api.tag_category_api.delete_tag_category( | ||||
|             context_factory( | ||||
|                 params={'version': 1}, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR)), | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR)), | ||||
|             {'category_name': 'root'}) | ||||
| 
 | ||||
| 
 | ||||
| def test_trying_to_delete_non_existing(user_factory, context_factory): | ||||
|     with pytest.raises(tag_categories.TagCategoryNotFoundError): | ||||
|         api.tag_category_api.delete_tag_category( | ||||
|             context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), | ||||
|             context_factory(user=user_factory(rank=model.User.RANK_REGULAR)), | ||||
|             {'category_name': 'bad'}) | ||||
| 
 | ||||
| 
 | ||||
| @ -73,6 +73,6 @@ def test_trying_to_delete_without_privileges( | ||||
|         api.tag_category_api.delete_tag_category( | ||||
|             context_factory( | ||||
|                 params={'version': 1}, | ||||
|                 user=user_factory(rank=db.User.RANK_ANONYMOUS)), | ||||
|                 user=user_factory(rank=model.User.RANK_ANONYMOUS)), | ||||
|             {'category_name': 'category'}) | ||||
|     assert db.session.query(db.TagCategory).count() == 1 | ||||
|     assert db.session.query(model.TagCategory).count() == 1 | ||||
|  | ||||
| @ -1,5 +1,5 @@ | ||||
| import pytest | ||||
| from szurubooru import api, db, errors | ||||
| from szurubooru import api, db, model, errors | ||||
| from szurubooru.func import tag_categories | ||||
| 
 | ||||
| 
 | ||||
| @ -7,8 +7,8 @@ from szurubooru.func import tag_categories | ||||
| def inject_config(config_injector): | ||||
|     config_injector({ | ||||
|         'privileges': { | ||||
|             'tag_categories:list': db.User.RANK_REGULAR, | ||||
|             'tag_categories:view': db.User.RANK_REGULAR, | ||||
|             'tag_categories:list': model.User.RANK_REGULAR, | ||||
|             'tag_categories:view': model.User.RANK_REGULAR, | ||||
|         }, | ||||
|     }) | ||||
| 
 | ||||
| @ -21,7 +21,7 @@ def test_retrieving_multiple( | ||||
|     ]) | ||||
|     db.session.flush() | ||||
|     result = api.tag_category_api.get_tag_categories( | ||||
|         context_factory(user=user_factory(rank=db.User.RANK_REGULAR))) | ||||
|         context_factory(user=user_factory(rank=model.User.RANK_REGULAR))) | ||||
|     assert [cat['name'] for cat in result['results']] == ['c1', 'c2'] | ||||
| 
 | ||||
| 
 | ||||
| @ -30,7 +30,7 @@ def test_retrieving_single( | ||||
|     db.session.add(tag_category_factory(name='cat')) | ||||
|     db.session.flush() | ||||
|     result = api.tag_category_api.get_tag_category( | ||||
|         context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), | ||||
|         context_factory(user=user_factory(rank=model.User.RANK_REGULAR)), | ||||
|         {'category_name': 'cat'}) | ||||
|     assert result == { | ||||
|         'name': 'cat', | ||||
| @ -44,7 +44,7 @@ def test_retrieving_single( | ||||
| def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): | ||||
|     with pytest.raises(tag_categories.TagCategoryNotFoundError): | ||||
|         api.tag_category_api.get_tag_category( | ||||
|             context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), | ||||
|             context_factory(user=user_factory(rank=model.User.RANK_REGULAR)), | ||||
|             {'category_name': '-'}) | ||||
| 
 | ||||
| 
 | ||||
| @ -52,5 +52,5 @@ def test_trying_to_retrieve_single_without_privileges( | ||||
|         user_factory, context_factory): | ||||
|     with pytest.raises(errors.AuthError): | ||||
|         api.tag_category_api.get_tag_category( | ||||
|             context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)), | ||||
|             context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)), | ||||
|             {'category_name': '-'}) | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| from unittest.mock import patch | ||||
| import pytest | ||||
| from szurubooru import api, db, errors | ||||
| from szurubooru import api, db, model, errors | ||||
| from szurubooru.func import tag_categories, tags, snapshots | ||||
| 
 | ||||
| 
 | ||||
| @ -12,15 +12,15 @@ def _update_category_name(category, name): | ||||
| def inject_config(config_injector): | ||||
|     config_injector({ | ||||
|         'privileges': { | ||||
|             'tag_categories:edit:name': db.User.RANK_REGULAR, | ||||
|             'tag_categories:edit:color': db.User.RANK_REGULAR, | ||||
|             'tag_categories:set_default': db.User.RANK_REGULAR, | ||||
|             'tag_categories:edit:name': model.User.RANK_REGULAR, | ||||
|             'tag_categories:edit:color': model.User.RANK_REGULAR, | ||||
|             'tag_categories:set_default': model.User.RANK_REGULAR, | ||||
|         }, | ||||
|     }) | ||||
| 
 | ||||
| 
 | ||||
| def test_simple_updating(user_factory, tag_category_factory, context_factory): | ||||
|     auth_user = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     auth_user = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     category = tag_category_factory(name='name', color='black') | ||||
|     db.session.add(category) | ||||
|     db.session.flush() | ||||
| @ -61,7 +61,7 @@ def test_omitting_optional_field( | ||||
|         api.tag_category_api.update_tag_category( | ||||
|             context_factory( | ||||
|                 params={**params, **{'version': 1}}, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR)), | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR)), | ||||
|             {'category_name': 'name'}) | ||||
| 
 | ||||
| 
 | ||||
| @ -70,7 +70,7 @@ def test_trying_to_update_non_existing(user_factory, context_factory): | ||||
|         api.tag_category_api.update_tag_category( | ||||
|             context_factory( | ||||
|                 params={'name': ['dummy']}, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR)), | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR)), | ||||
|             {'category_name': 'bad'}) | ||||
| 
 | ||||
| 
 | ||||
| @ -86,7 +86,7 @@ def test_trying_to_update_without_privileges( | ||||
|         api.tag_category_api.update_tag_category( | ||||
|             context_factory( | ||||
|                 params={**params, **{'version': 1}}, | ||||
|                 user=user_factory(rank=db.User.RANK_ANONYMOUS)), | ||||
|                 user=user_factory(rank=model.User.RANK_ANONYMOUS)), | ||||
|             {'category_name': 'dummy'}) | ||||
| 
 | ||||
| 
 | ||||
| @ -106,7 +106,7 @@ def test_set_as_default(user_factory, tag_category_factory, context_factory): | ||||
|                     'color': 'white', | ||||
|                     'version': 1, | ||||
|                 }, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR)), | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR)), | ||||
|             {'category_name': 'name'}) | ||||
|         assert result == 'serialized category' | ||||
|         tag_categories.set_default_category.assert_called_once_with(category) | ||||
|  | ||||
| @ -1,16 +1,16 @@ | ||||
| from unittest.mock import patch | ||||
| import pytest | ||||
| from szurubooru import api, db, errors | ||||
| from szurubooru import api, db, model, errors | ||||
| from szurubooru.func import tags, snapshots | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture(autouse=True) | ||||
| def inject_config(config_injector): | ||||
|     config_injector({'privileges': {'tags:create': db.User.RANK_REGULAR}}) | ||||
|     config_injector({'privileges': {'tags:create': model.User.RANK_REGULAR}}) | ||||
| 
 | ||||
| 
 | ||||
| def test_creating_simple_tags(tag_factory, user_factory, context_factory): | ||||
|     auth_user = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     auth_user = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     tag = tag_factory() | ||||
|     with patch('szurubooru.func.tags.create_tag'), \ | ||||
|             patch('szurubooru.func.tags.get_or_create_tags_by_names'), \ | ||||
| @ -50,7 +50,7 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field): | ||||
|         api.tag_api.create_tag( | ||||
|             context_factory( | ||||
|                 params=params, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR))) | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR))) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize('field', ['implications', 'suggestions']) | ||||
| @ -70,7 +70,7 @@ def test_omitting_optional_field( | ||||
|         api.tag_api.create_tag( | ||||
|             context_factory( | ||||
|                 params=params, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR))) | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR))) | ||||
| 
 | ||||
| 
 | ||||
| def test_trying_to_create_tag_without_privileges( | ||||
| @ -84,4 +84,4 @@ def test_trying_to_create_tag_without_privileges( | ||||
|                     'suggestions': ['tag'], | ||||
|                     'implications': [], | ||||
|                 }, | ||||
|                 user=user_factory(rank=db.User.RANK_ANONYMOUS))) | ||||
|                 user=user_factory(rank=model.User.RANK_ANONYMOUS))) | ||||
|  | ||||
| @ -1,16 +1,16 @@ | ||||
| from unittest.mock import patch | ||||
| import pytest | ||||
| from szurubooru import api, db, errors | ||||
| from szurubooru import api, db, model, errors | ||||
| from szurubooru.func import tags, snapshots | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture(autouse=True) | ||||
| def inject_config(config_injector): | ||||
|     config_injector({'privileges': {'tags:delete': db.User.RANK_REGULAR}}) | ||||
|     config_injector({'privileges': {'tags:delete': model.User.RANK_REGULAR}}) | ||||
| 
 | ||||
| 
 | ||||
| def test_deleting(user_factory, tag_factory, context_factory): | ||||
|     auth_user = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     auth_user = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     tag = tag_factory(names=['tag']) | ||||
|     db.session.add(tag) | ||||
|     db.session.commit() | ||||
| @ -20,7 +20,7 @@ def test_deleting(user_factory, tag_factory, context_factory): | ||||
|             context_factory(params={'version': 1}, user=auth_user), | ||||
|             {'tag_name': 'tag'}) | ||||
|         assert result == {} | ||||
|         assert db.session.query(db.Tag).count() == 0 | ||||
|         assert db.session.query(model.Tag).count() == 0 | ||||
|         snapshots.delete.assert_called_once_with(tag, auth_user) | ||||
|         tags.export_to_json.assert_called_once_with() | ||||
| 
 | ||||
| @ -36,17 +36,17 @@ def test_deleting_used( | ||||
|         api.tag_api.delete_tag( | ||||
|             context_factory( | ||||
|                 params={'version': 1}, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR)), | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR)), | ||||
|             {'tag_name': 'tag'}) | ||||
|         db.session.refresh(post) | ||||
|         assert db.session.query(db.Tag).count() == 0 | ||||
|         assert db.session.query(model.Tag).count() == 0 | ||||
|         assert post.tags == [] | ||||
| 
 | ||||
| 
 | ||||
| def test_trying_to_delete_non_existing(user_factory, context_factory): | ||||
|     with pytest.raises(tags.TagNotFoundError): | ||||
|         api.tag_api.delete_tag( | ||||
|             context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), | ||||
|             context_factory(user=user_factory(rank=model.User.RANK_REGULAR)), | ||||
|             {'tag_name': 'bad'}) | ||||
| 
 | ||||
| 
 | ||||
| @ -58,6 +58,6 @@ def test_trying_to_delete_without_privileges( | ||||
|         api.tag_api.delete_tag( | ||||
|             context_factory( | ||||
|                 params={'version': 1}, | ||||
|                 user=user_factory(rank=db.User.RANK_ANONYMOUS)), | ||||
|                 user=user_factory(rank=model.User.RANK_ANONYMOUS)), | ||||
|             {'tag_name': 'tag'}) | ||||
|     assert db.session.query(db.Tag).count() == 1 | ||||
|     assert db.session.query(model.Tag).count() == 1 | ||||
|  | ||||
| @ -1,16 +1,16 @@ | ||||
| from unittest.mock import patch | ||||
| import pytest | ||||
| from szurubooru import api, db, errors | ||||
| from szurubooru import api, db, model, errors | ||||
| from szurubooru.func import tags, snapshots | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture(autouse=True) | ||||
| def inject_config(config_injector): | ||||
|     config_injector({'privileges': {'tags:merge': db.User.RANK_REGULAR}}) | ||||
|     config_injector({'privileges': {'tags:merge': model.User.RANK_REGULAR}}) | ||||
| 
 | ||||
| 
 | ||||
| def test_merging(user_factory, tag_factory, context_factory, post_factory): | ||||
|     auth_user = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     auth_user = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     source_tag = tag_factory(names=['source']) | ||||
|     target_tag = tag_factory(names=['target']) | ||||
|     db.session.add_all([source_tag, target_tag]) | ||||
| @ -62,7 +62,7 @@ def test_trying_to_omit_mandatory_field( | ||||
|         api.tag_api.merge_tags( | ||||
|             context_factory( | ||||
|                 params=params, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR))) | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR))) | ||||
| 
 | ||||
| 
 | ||||
| def test_trying_to_merge_non_existing( | ||||
| @ -73,12 +73,12 @@ def test_trying_to_merge_non_existing( | ||||
|         api.tag_api.merge_tags( | ||||
|             context_factory( | ||||
|                 params={'remove': 'good', 'mergeTo': 'bad'}, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR))) | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR))) | ||||
|     with pytest.raises(tags.TagNotFoundError): | ||||
|         api.tag_api.merge_tags( | ||||
|             context_factory( | ||||
|                 params={'remove': 'bad', 'mergeTo': 'good'}, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR))) | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR))) | ||||
| 
 | ||||
| 
 | ||||
| def test_trying_to_merge_without_privileges( | ||||
| @ -97,4 +97,4 @@ def test_trying_to_merge_without_privileges( | ||||
|                     'remove': 'source', | ||||
|                     'mergeTo': 'target', | ||||
|                 }, | ||||
|                 user=user_factory(rank=db.User.RANK_ANONYMOUS))) | ||||
|                 user=user_factory(rank=model.User.RANK_ANONYMOUS))) | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| from unittest.mock import patch | ||||
| import pytest | ||||
| from szurubooru import api, db, errors | ||||
| from szurubooru import api, db, model, errors | ||||
| from szurubooru.func import tags | ||||
| 
 | ||||
| 
 | ||||
| @ -8,8 +8,8 @@ from szurubooru.func import tags | ||||
| def inject_config(config_injector): | ||||
|     config_injector({ | ||||
|         'privileges': { | ||||
|             'tags:list': db.User.RANK_REGULAR, | ||||
|             'tags:view': db.User.RANK_REGULAR, | ||||
|             'tags:list': model.User.RANK_REGULAR, | ||||
|             'tags:view': model.User.RANK_REGULAR, | ||||
|         }, | ||||
|     }) | ||||
| 
 | ||||
| @ -24,7 +24,7 @@ def test_retrieving_multiple(user_factory, tag_factory, context_factory): | ||||
|         result = api.tag_api.get_tags( | ||||
|             context_factory( | ||||
|                 params={'query': '', 'page': 1}, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR))) | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR))) | ||||
|         assert result == { | ||||
|             'query': '', | ||||
|             'page': 1, | ||||
| @ -40,7 +40,7 @@ def test_trying_to_retrieve_multiple_without_privileges( | ||||
|         api.tag_api.get_tags( | ||||
|             context_factory( | ||||
|                 params={'query': '', 'page': 1}, | ||||
|                 user=user_factory(rank=db.User.RANK_ANONYMOUS))) | ||||
|                 user=user_factory(rank=model.User.RANK_ANONYMOUS))) | ||||
| 
 | ||||
| 
 | ||||
| def test_retrieving_single(user_factory, tag_factory, context_factory): | ||||
| @ -50,7 +50,7 @@ def test_retrieving_single(user_factory, tag_factory, context_factory): | ||||
|         tags.serialize_tag.return_value = 'serialized tag' | ||||
|         result = api.tag_api.get_tag( | ||||
|             context_factory( | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR)), | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR)), | ||||
|             {'tag_name': 'tag'}) | ||||
|         assert result == 'serialized tag' | ||||
| 
 | ||||
| @ -59,7 +59,7 @@ def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): | ||||
|     with pytest.raises(tags.TagNotFoundError): | ||||
|         api.tag_api.get_tag( | ||||
|             context_factory( | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR)), | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR)), | ||||
|             {'tag_name': '-'}) | ||||
| 
 | ||||
| 
 | ||||
| @ -68,5 +68,5 @@ def test_trying_to_retrieve_single_without_privileges( | ||||
|     with pytest.raises(errors.AuthError): | ||||
|         api.tag_api.get_tag( | ||||
|             context_factory( | ||||
|                 user=user_factory(rank=db.User.RANK_ANONYMOUS)), | ||||
|                 user=user_factory(rank=model.User.RANK_ANONYMOUS)), | ||||
|             {'tag_name': '-'}) | ||||
|  | ||||
| @ -1,12 +1,12 @@ | ||||
| from unittest.mock import patch | ||||
| import pytest | ||||
| from szurubooru import api, db, errors | ||||
| from szurubooru import api, db, model, errors | ||||
| from szurubooru.func import tags | ||||
| 
 | ||||
| 
 | ||||
| @pytest.fixture(autouse=True) | ||||
| def inject_config(config_injector): | ||||
|     config_injector({'privileges': {'tags:view': db.User.RANK_REGULAR}}) | ||||
|     config_injector({'privileges': {'tags:view': model.User.RANK_REGULAR}}) | ||||
| 
 | ||||
| 
 | ||||
| def test_get_tag_siblings(user_factory, tag_factory, context_factory): | ||||
| @ -21,7 +21,7 @@ def test_get_tag_siblings(user_factory, tag_factory, context_factory): | ||||
|             (tag_factory(names=['sib2']), 3), | ||||
|         ] | ||||
|         result = api.tag_api.get_tag_siblings( | ||||
|             context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), | ||||
|             context_factory(user=user_factory(rank=model.User.RANK_REGULAR)), | ||||
|             {'tag_name': 'tag'}) | ||||
|         assert result == { | ||||
|             'results': [ | ||||
| @ -40,12 +40,12 @@ def test_get_tag_siblings(user_factory, tag_factory, context_factory): | ||||
| def test_trying_to_retrieve_non_existing(user_factory, context_factory): | ||||
|     with pytest.raises(tags.TagNotFoundError): | ||||
|         api.tag_api.get_tag_siblings( | ||||
|             context_factory(user=user_factory(rank=db.User.RANK_REGULAR)), | ||||
|             context_factory(user=user_factory(rank=model.User.RANK_REGULAR)), | ||||
|             {'tag_name': '-'}) | ||||
| 
 | ||||
| 
 | ||||
| def test_trying_to_retrieve_without_privileges(user_factory, context_factory): | ||||
|     with pytest.raises(errors.AuthError): | ||||
|         api.tag_api.get_tag_siblings( | ||||
|             context_factory(user=user_factory(rank=db.User.RANK_ANONYMOUS)), | ||||
|             context_factory(user=user_factory(rank=model.User.RANK_ANONYMOUS)), | ||||
|             {'tag_name': '-'}) | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| from unittest.mock import patch | ||||
| import pytest | ||||
| from szurubooru import api, db, errors | ||||
| from szurubooru import api, db, model, errors | ||||
| from szurubooru.func import tags, snapshots | ||||
| 
 | ||||
| 
 | ||||
| @ -8,18 +8,18 @@ from szurubooru.func import tags, snapshots | ||||
| def inject_config(config_injector): | ||||
|     config_injector({ | ||||
|         'privileges': { | ||||
|             'tags:create': db.User.RANK_REGULAR, | ||||
|             'tags:edit:names': db.User.RANK_REGULAR, | ||||
|             'tags:edit:category': db.User.RANK_REGULAR, | ||||
|             'tags:edit:description': db.User.RANK_REGULAR, | ||||
|             'tags:edit:suggestions': db.User.RANK_REGULAR, | ||||
|             'tags:edit:implications': db.User.RANK_REGULAR, | ||||
|             'tags:create': model.User.RANK_REGULAR, | ||||
|             'tags:edit:names': model.User.RANK_REGULAR, | ||||
|             'tags:edit:category': model.User.RANK_REGULAR, | ||||
|             'tags:edit:description': model.User.RANK_REGULAR, | ||||
|             'tags:edit:suggestions': model.User.RANK_REGULAR, | ||||
|             'tags:edit:implications': model.User.RANK_REGULAR, | ||||
|         }, | ||||
|     }) | ||||
| 
 | ||||
| 
 | ||||
| def test_simple_updating(user_factory, tag_factory, context_factory): | ||||
|     auth_user = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     auth_user = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     tag = tag_factory(names=['tag1', 'tag2']) | ||||
|     db.session.add(tag) | ||||
|     db.session.commit() | ||||
| @ -56,8 +56,7 @@ def test_simple_updating(user_factory, tag_factory, context_factory): | ||||
|             tag, ['sug1', 'sug2']) | ||||
|         tags.update_tag_implications.assert_called_once_with( | ||||
|             tag, ['imp1', 'imp2']) | ||||
|         tags.serialize_tag.assert_called_once_with( | ||||
|             tag, options=None) | ||||
|         tags.serialize_tag.assert_called_once_with(tag, options=[]) | ||||
|         snapshots.modify.assert_called_once_with(tag, auth_user) | ||||
|         tags.export_to_json.assert_called_once_with() | ||||
| 
 | ||||
| @ -90,7 +89,7 @@ def test_omitting_optional_field( | ||||
|         api.tag_api.update_tag( | ||||
|             context_factory( | ||||
|                 params={**params, **{'version': 1}}, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR)), | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR)), | ||||
|             {'tag_name': 'tag'}) | ||||
| 
 | ||||
| 
 | ||||
| @ -99,7 +98,7 @@ def test_trying_to_update_non_existing(user_factory, context_factory): | ||||
|         api.tag_api.update_tag( | ||||
|             context_factory( | ||||
|                 params={'names': ['dummy']}, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR)), | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR)), | ||||
|             {'tag_name': 'tag1'}) | ||||
| 
 | ||||
| 
 | ||||
| @ -117,7 +116,7 @@ def test_trying_to_update_without_privileges( | ||||
|         api.tag_api.update_tag( | ||||
|             context_factory( | ||||
|                 params={**params, **{'version': 1}}, | ||||
|                 user=user_factory(rank=db.User.RANK_ANONYMOUS)), | ||||
|                 user=user_factory(rank=model.User.RANK_ANONYMOUS)), | ||||
|             {'tag_name': 'tag'}) | ||||
| 
 | ||||
| 
 | ||||
| @ -127,9 +126,9 @@ def test_trying_to_create_tags_without_privileges( | ||||
|     db.session.add(tag) | ||||
|     db.session.commit() | ||||
|     config_injector({'privileges': { | ||||
|         'tags:create': db.User.RANK_ADMINISTRATOR, | ||||
|         'tags:edit:suggestions': db.User.RANK_REGULAR, | ||||
|         'tags:edit:implications': db.User.RANK_REGULAR, | ||||
|         'tags:create': model.User.RANK_ADMINISTRATOR, | ||||
|         'tags:edit:suggestions': model.User.RANK_REGULAR, | ||||
|         'tags:edit:implications': model.User.RANK_REGULAR, | ||||
|     }}) | ||||
|     with patch('szurubooru.func.tags.get_or_create_tags_by_names'): | ||||
|         tags.get_or_create_tags_by_names.return_value = ([], ['new-tag']) | ||||
| @ -137,12 +136,12 @@ def test_trying_to_create_tags_without_privileges( | ||||
|             api.tag_api.update_tag( | ||||
|                 context_factory( | ||||
|                     params={'suggestions': ['tag1', 'tag2'], 'version': 1}, | ||||
|                     user=user_factory(rank=db.User.RANK_REGULAR)), | ||||
|                     user=user_factory(rank=model.User.RANK_REGULAR)), | ||||
|                 {'tag_name': 'tag'}) | ||||
|         db.session.rollback() | ||||
|         with pytest.raises(errors.AuthError): | ||||
|             api.tag_api.update_tag( | ||||
|                 context_factory( | ||||
|                     params={'implications': ['tag1', 'tag2'], 'version': 1}, | ||||
|                     user=user_factory(rank=db.User.RANK_REGULAR)), | ||||
|                     user=user_factory(rank=model.User.RANK_REGULAR)), | ||||
|                 {'tag_name': 'tag'}) | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| from unittest.mock import patch | ||||
| import pytest | ||||
| from szurubooru import api, db, errors | ||||
| from szurubooru import api, db, model, errors | ||||
| from szurubooru.func import users | ||||
| 
 | ||||
| 
 | ||||
| @ -31,7 +31,7 @@ def test_creating_user(user_factory, context_factory, fake_datetime): | ||||
|                     'avatarStyle': 'manual', | ||||
|                 }, | ||||
|                 files={'avatar': b'...'}, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR))) | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR))) | ||||
|         assert result == 'serialized user' | ||||
|         users.create_user.assert_called_once_with( | ||||
|             'chewie1', 'oks', 'asd@asd.asd') | ||||
| @ -50,7 +50,7 @@ def test_trying_to_omit_mandatory_field(user_factory, context_factory, field): | ||||
|         'password': 'oks', | ||||
|     } | ||||
|     user = user_factory() | ||||
|     auth_user = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     auth_user = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     del params[field] | ||||
|     with patch('szurubooru.func.users.create_user'), \ | ||||
|             pytest.raises(errors.MissingRequiredParameterError): | ||||
| @ -70,7 +70,7 @@ def test_omitting_optional_field(user_factory, context_factory, field): | ||||
|     } | ||||
|     del params[field] | ||||
|     user = user_factory() | ||||
|     auth_user = user_factory(rank=db.User.RANK_MODERATOR) | ||||
|     auth_user = user_factory(rank=model.User.RANK_MODERATOR) | ||||
|     with patch('szurubooru.func.users.create_user'), \ | ||||
|             patch('szurubooru.func.users.update_user_avatar'), \ | ||||
|             patch('szurubooru.func.users.serialize_user'): | ||||
| @ -84,4 +84,4 @@ def test_trying_to_create_user_without_privileges( | ||||
|     with pytest.raises(errors.AuthError): | ||||
|         api.user_api.create_user(context_factory( | ||||
|             params='whatever', | ||||
|             user=user_factory(rank=db.User.RANK_ANONYMOUS))) | ||||
|             user=user_factory(rank=model.User.RANK_ANONYMOUS))) | ||||
|  | ||||
| @ -1,5 +1,5 @@ | ||||
| import pytest | ||||
| from szurubooru import api, db, errors | ||||
| from szurubooru import api, db, model, errors | ||||
| from szurubooru.func import users | ||||
| 
 | ||||
| 
 | ||||
| @ -7,45 +7,45 @@ from szurubooru.func import users | ||||
| def inject_config(config_injector): | ||||
|     config_injector({ | ||||
|         'privileges': { | ||||
|             'users:delete:self': db.User.RANK_REGULAR, | ||||
|             'users:delete:any': db.User.RANK_MODERATOR, | ||||
|             'users:delete:self': model.User.RANK_REGULAR, | ||||
|             'users:delete:any': model.User.RANK_MODERATOR, | ||||
|         }, | ||||
|     }) | ||||
| 
 | ||||
| 
 | ||||
| def test_deleting_oneself(user_factory, context_factory): | ||||
|     user = user_factory(name='u', rank=db.User.RANK_REGULAR) | ||||
|     user = user_factory(name='u', rank=model.User.RANK_REGULAR) | ||||
|     db.session.add(user) | ||||
|     db.session.commit() | ||||
|     result = api.user_api.delete_user( | ||||
|         context_factory( | ||||
|             params={'version': 1}, user=user), {'user_name': 'u'}) | ||||
|     assert result == {} | ||||
|     assert db.session.query(db.User).count() == 0 | ||||
|     assert db.session.query(model.User).count() == 0 | ||||
| 
 | ||||
| 
 | ||||
| def test_deleting_someone_else(user_factory, context_factory): | ||||
|     user1 = user_factory(name='u1', rank=db.User.RANK_REGULAR) | ||||
|     user2 = user_factory(name='u2', rank=db.User.RANK_MODERATOR) | ||||
|     user1 = user_factory(name='u1', rank=model.User.RANK_REGULAR) | ||||
|     user2 = user_factory(name='u2', rank=model.User.RANK_MODERATOR) | ||||
|     db.session.add_all([user1, user2]) | ||||
|     db.session.commit() | ||||
|     api.user_api.delete_user( | ||||
|         context_factory( | ||||
|             params={'version': 1}, user=user2), {'user_name': 'u1'}) | ||||
|     assert db.session.query(db.User).count() == 1 | ||||
|     assert db.session.query(model.User).count() == 1 | ||||
| 
 | ||||
| 
 | ||||
| def test_trying_to_delete_someone_else_without_privileges( | ||||
|         user_factory, context_factory): | ||||
|     user1 = user_factory(name='u1', rank=db.User.RANK_REGULAR) | ||||
|     user2 = user_factory(name='u2', rank=db.User.RANK_REGULAR) | ||||
|     user1 = user_factory(name='u1', rank=model.User.RANK_REGULAR) | ||||
|     user2 = user_factory(name='u2', rank=model.User.RANK_REGULAR) | ||||
|     db.session.add_all([user1, user2]) | ||||
|     db.session.commit() | ||||
|     with pytest.raises(errors.AuthError): | ||||
|         api.user_api.delete_user( | ||||
|             context_factory( | ||||
|                 params={'version': 1}, user=user2), {'user_name': 'u1'}) | ||||
|     assert db.session.query(db.User).count() == 2 | ||||
|     assert db.session.query(model.User).count() == 2 | ||||
| 
 | ||||
| 
 | ||||
| def test_trying_to_delete_non_existing(user_factory, context_factory): | ||||
| @ -53,5 +53,5 @@ def test_trying_to_delete_non_existing(user_factory, context_factory): | ||||
|         api.user_api.delete_user( | ||||
|             context_factory( | ||||
|                 params={'version': 1}, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR)), | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR)), | ||||
|             {'user_name': 'bad'}) | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| from unittest.mock import patch | ||||
| import pytest | ||||
| from szurubooru import api, db, errors | ||||
| from szurubooru import api, db, model, errors | ||||
| from szurubooru.func import users | ||||
| 
 | ||||
| 
 | ||||
| @ -8,16 +8,16 @@ from szurubooru.func import users | ||||
| def inject_config(config_injector): | ||||
|     config_injector({ | ||||
|         'privileges': { | ||||
|             'users:list': db.User.RANK_REGULAR, | ||||
|             'users:view': db.User.RANK_REGULAR, | ||||
|             'users:edit:any:email': db.User.RANK_MODERATOR, | ||||
|             'users:list': model.User.RANK_REGULAR, | ||||
|             'users:view': model.User.RANK_REGULAR, | ||||
|             'users:edit:any:email': model.User.RANK_MODERATOR, | ||||
|         }, | ||||
|     }) | ||||
| 
 | ||||
| 
 | ||||
| def test_retrieving_multiple(user_factory, context_factory): | ||||
|     user1 = user_factory(name='u1', rank=db.User.RANK_MODERATOR) | ||||
|     user2 = user_factory(name='u2', rank=db.User.RANK_MODERATOR) | ||||
|     user1 = user_factory(name='u1', rank=model.User.RANK_MODERATOR) | ||||
|     user2 = user_factory(name='u2', rank=model.User.RANK_MODERATOR) | ||||
|     db.session.add_all([user1, user2]) | ||||
|     db.session.flush() | ||||
|     with patch('szurubooru.func.users.serialize_user'): | ||||
| @ -25,7 +25,7 @@ def test_retrieving_multiple(user_factory, context_factory): | ||||
|         result = api.user_api.get_users( | ||||
|             context_factory( | ||||
|                 params={'query': '', 'page': 1}, | ||||
|                 user=user_factory(rank=db.User.RANK_REGULAR))) | ||||
|                 user=user_factory(rank=model.User.RANK_REGULAR))) | ||||
|         assert result == { | ||||
|             'query': '', | ||||
|             'page': 1, | ||||
| @ -41,12 +41,12 @@ def test_trying_to_retrieve_multiple_without_privileges( | ||||
|         api.user_api.get_users( | ||||
|             context_factory( | ||||
|                 params={'query': '', 'page': 1}, | ||||
|                 user=user_factory(rank=db.User.RANK_ANONYMOUS))) | ||||
|                 user=user_factory(rank=model.User.RANK_ANONYMOUS))) | ||||
| 
 | ||||
| 
 | ||||
| def test_retrieving_single(user_factory, context_factory): | ||||
|     user = user_factory(name='u1', rank=db.User.RANK_REGULAR) | ||||
|     auth_user = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     user = user_factory(name='u1', rank=model.User.RANK_REGULAR) | ||||
|     auth_user = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     db.session.add(user) | ||||
|     db.session.flush() | ||||
|     with patch('szurubooru.func.users.serialize_user'): | ||||
| @ -57,7 +57,7 @@ def test_retrieving_single(user_factory, context_factory): | ||||
| 
 | ||||
| 
 | ||||
| def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): | ||||
|     auth_user = user_factory(rank=db.User.RANK_REGULAR) | ||||
|     auth_user = user_factory(rank=model.User.RANK_REGULAR) | ||||
|     with pytest.raises(users.UserNotFoundError): | ||||
|         api.user_api.get_user( | ||||
|             context_factory(user=auth_user), {'user_name': '-'}) | ||||
| @ -65,8 +65,8 @@ def test_trying_to_retrieve_single_non_existing(user_factory, context_factory): | ||||
| 
 | ||||
| def test_trying_to_retrieve_single_without_privileges( | ||||
|         user_factory, context_factory): | ||||
|     auth_user = user_factory(rank=db.User.RANK_ANONYMOUS) | ||||
|     db.session.add(user_factory(name='u1', rank=db.User.RANK_REGULAR)) | ||||
|     auth_user = user_factory(rank=model.User.RANK_ANONYMOUS) | ||||
|     db.session.add(user_factory(name='u1', rank=model.User.RANK_REGULAR)) | ||||
|     db.session.flush() | ||||
|     with pytest.raises(errors.AuthError): | ||||
|         api.user_api.get_user( | ||||
|  | ||||
| @ -1,6 +1,6 @@ | ||||
| from unittest.mock import patch | ||||
| import pytest | ||||
| from szurubooru import api, db, errors | ||||
| from szurubooru import api, db, model, errors | ||||
| from szurubooru.func import users | ||||
| 
 | ||||
| 
 | ||||
| @ -8,23 +8,23 @@ from szurubooru.func import users | ||||
| def inject_config(config_injector): | ||||
|     config_injector({ | ||||
|         'privileges': { | ||||
|             'users:edit:self:name': db.User.RANK_REGULAR, | ||||
|             'users:edit:self:pass': db.User.RANK_REGULAR, | ||||
|             'users:edit:self:email': db.User.RANK_REGULAR, | ||||
|             'users:edit:self:rank': db.User.RANK_MODERATOR, | ||||
|             'users:edit:self:avatar': db.User.RANK_MODERATOR, | ||||
|             'users:edit:any:name': db.User.RANK_MODERATOR, | ||||
|             'users:edit:any:pass': db.User.RANK_MODERATOR, | ||||
|             'users:edit:any:email': db.User.RANK_MODERATOR, | ||||
|             'users:edit:any:rank': db.User.RANK_ADMINISTRATOR, | ||||
|             'users:edit:any:avatar': db.User.RANK_ADMINISTRATOR, | ||||
|             'users:edit:self:name': model.User.RANK_REGULAR, | ||||
|             'users:edit:self:pass': model.User.RANK_REGULAR, | ||||
|             'users:edit:self:email': model.User.RANK_REGULAR, | ||||
|             'users:edit:self:rank': model.User.RANK_MODERATOR, | ||||
|             'users:edit:self:avatar': model.User.RANK_MODERATOR, | ||||
|             'users:edit:any:name': model.User.RANK_MODERATOR, | ||||
|             'users:edit:any:pass': model.User.RANK_MODERATOR, | ||||
|             'users:edit:any:email': model.User.RANK_MODERATOR, | ||||
|             'users:edit:any:rank': model.User.RANK_ADMINISTRATOR, | ||||
|             'users:edit:any:avatar': model.User.RANK_ADMINISTRATOR, | ||||
|         }, | ||||
|     }) | ||||
| 
 | ||||
| 
 | ||||
| def test_updating_user(context_factory, user_factory): | ||||
|     user = user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR) | ||||
|     auth_user = user_factory(rank=db.User.RANK_ADMINISTRATOR) | ||||
|     user = user_factory(name='u1', rank=model.User.RANK_ADMINISTRATOR) | ||||
|     auth_user = user_factory(rank=model.User.RANK_ADMINISTRATOR) | ||||
|     db.session.add(user) | ||||
|     db.session.flush() | ||||
| 
 | ||||
| @ -63,13 +63,13 @@ def test_updating_user(context_factory, user_factory): | ||||
|         users.update_user_avatar.assert_called_once_with( | ||||
|             user, 'manual', b'...') | ||||
|         users.serialize_user.assert_called_once_with( | ||||
|             user, auth_user, options=None) | ||||
|             user, auth_user, options=[]) | ||||
| 
 | ||||
| 
 | ||||
| @pytest.mark.parametrize( | ||||
|     'field', ['name', 'email', 'password', 'rank', 'avatarStyle']) | ||||
| def test_omitting_optional_field(user_factory, context_factory, field): | ||||
|     user = user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR) | ||||
|     user = user_factory(name='u1', rank=model.User.RANK_ADMINISTRATOR) | ||||
|     db.session.add(user) | ||||
|     db.session.flush() | ||||
|     params = { | ||||
| @ -96,7 +96,7 @@ def test_omitting_optional_field(user_factory, context_factory, field): | ||||
| 
 | ||||
| 
 | ||||
| def test_trying_to_update_non_existing(user_factory, context_factory): | ||||
|     user = user_factory(name='u1', rank=db.User.RANK_ADMINISTRATOR) | ||||
|     user = user_factory(name='u1', rank=model.User.RANK_ADMINISTRATOR) | ||||
|     db.session.add(user) | ||||
|     db.session.flush() | ||||
|     with pytest.raises(users.UserNotFoundError): | ||||
| @ -113,8 +113,8 @@ def test_trying_to_update_non_existing(user_factory, context_factory): | ||||
| ]) | ||||
| def test_trying_to_update_field_without_privileges( | ||||
|         user_factory, context_factory, params): | ||||
|     user1 = user_factory(name='u1', rank=db.User.RANK_REGULAR) | ||||
|     user2 = user_factory(name='u2', rank=db.User.RANK_REGULAR) | ||||
|     user1 = user_factory(name='u1', rank=model.User.RANK_REGULAR) | ||||
|     user2 = user_factory(name='u2', rank=model.User.RANK_REGULAR) | ||||
|     db.session.add_all([user1, user2]) | ||||
|     db.session.flush() | ||||
|     with pytest.raises(errors.AuthError): | ||||
|  | ||||
Some files were not shown because too many files have changed in this diff Show More
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user
	 rr-
						rr-