diff --git a/API.md b/API.md index c36bbf6..05df63d 100644 --- a/API.md +++ b/API.md @@ -29,7 +29,7 @@ - [Listing tag siblings](#listing-tag-siblings) - Posts - ~~Listing posts~~ - - ~~Creating post~~ + - [Creating post](#creating-post) - ~~Updating post~~ - [Getting post](#getting-post) - [Deleting post](#deleting-post) @@ -69,6 +69,7 @@ - [Detailed tag](#detailed-tag) - [Post](#post) - [Detailed post](#detailed-post) + - [Note](#note) - [Comment](#comment) - [Detailed comment](#detailed-comment) - [Snapshot](#snapshot) @@ -125,7 +126,6 @@ Depending on the deployment, the URLs might be relative to some base path such as `/api/`. Values denoted with diamond braces (``) signify variable data. - ## Listing tag categories - **Request** @@ -150,7 +150,6 @@ data. caching. The data directory and its URL are controlled with `data_dir` and `data_url` variables in server's configuration. - ## Creating tag category - **Request** @@ -181,7 +180,6 @@ data. Creates a new tag category using specified parameters. Name must match `tag_category_name_regex` from server's configuration. - ## Updating tag category - **Request** @@ -214,7 +212,6 @@ data. match `tag_category_name_regex` from server's configuration. All fields are optional - update concerns only provided fields. - ## Getting tag category - **Request** @@ -233,7 +230,6 @@ data. Retrieves information about an existing tag category. - ## Deleting tag category - **Request** @@ -257,7 +253,6 @@ data. Deletes existing tag category. The tag category to be deleted must have no usages. - ## Listing tags - **Request** @@ -327,7 +322,6 @@ data. None. - ## Creating tag - **Request** @@ -370,7 +364,6 @@ data. first tag category found. If there are no tag categories established yet, an error will be thrown. - ## Updating tag - **Request** @@ -412,7 +405,6 @@ data. their category is set to the first tag category found. All fields are optional - update concerns only provided fields. - ## Getting tag - **Request** @@ -431,7 +423,6 @@ data. Retrieves information about an existing tag. - ## Deleting tag - **Request** @@ -453,7 +444,6 @@ data. Deletes existing tag. The tag to be deleted must have no usages. - ## Merging tags - **Request** @@ -485,7 +475,6 @@ data. and are discarded. The target tag effectively remains unchanged with the exception of the set of posts it's used in. - ## Listing tag siblings - **Request** @@ -520,6 +509,48 @@ data. appears with given tag. Results are sorted by occurrences count and the list is truncated to the first 50 elements. Doesn't use paging. +## Creating post +- **Request** + + `POST /posts/` + +- **Input** + + ```json5 + { + "tags": [, , ], + "safety": , + "source": , // optional + "relations": [, , ], // optional + "notes": [, , ], // optional + "flags": [, ] // optional + } + ``` + +- **Files** + + - `content` - the content of the content. + - `thumbnail` - the content of custom thumbnail (optional). + +- **Output** + + A [detailed post resource](#detailed-post). + +- **Errors** + + - tags have invalid names + - safety is invalid + - relations refer to non-existing posts + - privileges are too low + +- **Description** + + Creates a new post. If specified tags do not exist yet, they will be + automatically created. Tags created automatically have no implications, no + suggestions, one name and their category is set to the first tag category + found. Safety must be any of `"safe"`, `"sketchy"` or `"unsafe"`. `` + currently can be only `"loop"` to enable looping for video posts. Sending + empty `thumbnail` will cause the post to use default thumbnail. ## Getting post - **Request** @@ -539,7 +570,6 @@ data. Retrieves information about an existing post. - ## Deleting post - **Request** @@ -560,7 +590,6 @@ data. Deletes existing post. Related posts and tags are kept. - ## Rating post - **Request** @@ -589,7 +618,6 @@ data. Updates score of authenticated user for given post. Valid scores are -1, 0 and 1. - ## Adding post to favorites - **Request** @@ -608,7 +636,6 @@ data. Marks the post as favorite for authenticated user. - ## Removing post from favorites - **Request** @@ -627,7 +654,6 @@ data. Unmarks the post as favorite for authenticated user. - ## Getting featured post - **Request** @@ -647,7 +673,6 @@ data. client. If no post is featured, `` is null and `snapshots` array is empty. - ## Featuring post - **Request** @@ -666,7 +691,6 @@ data. Features a post on the main page in web client. - ## Listing comments - **Request** @@ -722,7 +746,6 @@ data. None. - ## Creating comment - **Request** @@ -751,7 +774,6 @@ data. Creates a new comment under given post. - ## Updating comment - **Request** @@ -779,7 +801,6 @@ data. Updates an existing comment text. - ## Getting comment - **Request** @@ -798,7 +819,6 @@ data. Retrieves information about an existing comment. - ## Deleting comment - **Request** @@ -819,7 +839,6 @@ data. Deletes existing comment. - ## Rating comment - **Request** @@ -848,7 +867,6 @@ data. Updates score of authenticated user for given comment. Valid scores are -1, 0 and 1. - ## Listing users - **Request** @@ -900,7 +918,6 @@ data. None. - ## Creating user - **Request** @@ -947,7 +964,6 @@ data. administrator, whereas subsequent users will be given the rank indicated by `default_rank` in the server's configuration. - ## Updating user - **Request** @@ -993,7 +1009,6 @@ data. `manual`. `manual` avatar style requires client to pass also `avatar` file - see [file uploads](#file-uploads) for details. - ## Getting user - **Request** @@ -1012,7 +1027,6 @@ data. Retrieves information about an existing user. - ## Deleting user - **Request** @@ -1033,7 +1047,6 @@ data. Deletes existing user. - ## Password reset - step 1: mail request - **Request** @@ -1058,7 +1071,6 @@ data. mailbox, which is a strong indication they are the rightful owner of the account. - ## Password reset - step 2: confirmation - **Request** @@ -1091,7 +1103,6 @@ data. Generates a new password for given user. Password is sent as plain-text, so it is recommended to connect through HTTPS. - ## Listing snapshots - **Request** @@ -1133,7 +1144,6 @@ data. None. - ## Getting global info - **Request** @@ -1325,29 +1335,34 @@ One file together with its metadata posted to the site. ```json5 { "id": , + "creationTime": , + "lastEditTime": , "safety": , + "source": , "type": , "checksum": , - "source": , "canvasWidth": , "canvasHeight": , + "contentUrl": , + "thumbnailUrl": , "flags": , "tags": , "relations": , - "creationTime": , - "lastEditTime": , + "notes": , "user": , "score": , "ownScore": , - "favoritedBy": , "featureCount": , - "lastFeatureTime": + "lastFeatureTime": , + "favoritedBy": } ``` **Field meaning** - ``: the post identifier. +- ``: time the tag was created, formatted as per RFC 3339. +- ``: time the tag was edited, formatted as per RFC 3339. - ``: whether the post is safe for work. Available values: @@ -1356,6 +1371,7 @@ One file together with its metadata posted to the site. - `"sketchy"` - `"unsafe"` +- ``: where the post was grabbed form, supplied by the user. - ``: the type of the post. Available values: @@ -1368,24 +1384,25 @@ One file together with its metadata posted to the site. - ``: the file checksum. Used in snapshots to signify changes of the post content. -- ``: where the post was grabbed form, supplied by the user. - `` and ``: the original width and height of the post content. +- ``: where the post content is located. +- ``: where the post thumbnail is located. - ``: various flags such as whether the post is looped, represented as array of plain strings. - ``: list of tag names the post is tagged with. - ``: a list of related post IDs. Links to related posts are shown to the user by the web client. -- ``: time the tag was created, formatted as per RFC 3339. -- ``: time the tag was edited, formatted as per RFC 3339. +- ``: a list of post annotations, serialized as list of [note + resources](#note). - ``: who created the post, serialized as [user resource](#user). - ``: the collective score (+1/-1 rating) of the given post. - ``: the score (+1/-1 rating) of the given post by the authenticated user. -- ``: list of users, serialized as [user resources](#user). - ``: how many times has the post been featured. - ``: the last time the post was featured, formatted as per RFC 3339. +- ``: list of users, serialized as [user resources](#user). ## Detailed post **Description** @@ -1416,6 +1433,27 @@ A post with extra information. earlier versions. - ``: a [comment resource](#comment) for given post. +## Note +**Description** + +A text annotation rendered on top of the post. + +**Structure** + +```json5 +{ + "polygon": , + "text": , +} +``` + +**Field meaning** +- ``: where to draw the annotation. Each point must have + coordinates within 0 to 1. For example, `[[0,0],[0,1],[1,1],[1,0]]` will draw + the annotation on the whole post, whereas `[[0,0],[0,0.5],[0.5,0.5],[0.5,0]]` + will draw it inside the post's upper left quarter. +- ``: the annotation text. The client should render is as Markdown. + ## Comment **Description** @@ -1439,6 +1477,7 @@ A comment under a post. **Field meaning** - ``: the comment identifier. - ``: a post resource the post is linked with. +- ``: the comment content. The client should render is as Markdown. - ``: a user resource the post is created by. - ``: time the comment was created, formatted as per RFC 3339. - ``: time the comment was edited, formatted as per RFC 3339. @@ -1542,7 +1581,7 @@ A snapshot is a version of a database resource. "checksum": "deadbeef", "tags": ["tag1", "tag2"], "relations": [1, 2], - "notes": [{"polygon": [[1,1],[200,1],[200,200],[1,200]], "text": "..."}], + "notes": [, , ], "flags": ["loop"], "featured": false } diff --git a/server/szurubooru/api/__init__.py b/server/szurubooru/api/__init__.py index fb7324b..6649bf7 100644 --- a/server/szurubooru/api/__init__.py +++ b/server/szurubooru/api/__init__.py @@ -15,6 +15,7 @@ from szurubooru.api.comment_api import ( CommentDetailApi, CommentScoreApi) from szurubooru.api.post_api import ( + PostListApi, PostDetailApi, PostFeatureApi, PostScoreApi, diff --git a/server/szurubooru/api/context.py b/server/szurubooru/api/context.py index 58890e0..0829cf7 100644 --- a/server/szurubooru/api/context.py +++ b/server/szurubooru/api/context.py @@ -12,8 +12,16 @@ class Context(object): def has_param(self, name): return name in self.input - def get_file(self, name): - return self.files.get(name, None) + def has_file(self, name): + return name in self.files + + def get_file(self, name, required=False): + if name in self.files: + return self.files[name] + if not required: + return None + raise errors.MissingRequiredFileError( + 'Required file %r is missing.' % name) def get_param_as_list(self, name, required=False, default=None): if name in self.input: @@ -23,7 +31,8 @@ class Context(object): return param if not required: return default - raise errors.ValidationError('Required paramter %r is missing.' % name) + raise errors.MissingRequiredParameterError( + 'Required paramter %r is missing.' % name) def get_param_as_string(self, name, required=False, default=None): if name in self.input: @@ -32,12 +41,14 @@ class Context(object): try: param = ','.join(param) except: - raise errors.ValidationError( - 'Parameter %r is invalid - expected simple string.' % name) + raise errors.InvalidParameterError( + 'Parameter %r is invalid - expected simple string.' + % name) return param if not required: return default - raise errors.ValidationError('Required paramter %r is missing.' % name) + raise errors.MissingRequiredParameterError( + 'Required paramter %r is missing.' % name) # pylint: disable=redefined-builtin,too-many-arguments def get_param_as_int( @@ -47,21 +58,21 @@ class Context(object): try: val = int(val) except (ValueError, TypeError): - raise errors.ValidationError( + raise errors.InvalidParameterError( 'Parameter %r is invalid: the value must be an integer.' % name) if min is not None and val < min: - raise errors.ValidationError( + raise errors.InvalidParameterError( 'Parameter %r is invalid: the value must be at least %r.' % (name, min)) if max is not None and val > max: - raise errors.ValidationError( + raise errors.InvalidParameterError( 'Parameter %r is invalid: the value may not exceed %r.' % (name, max)) return val if not required: return default - raise errors.ValidationError( + raise errors.MissingRequiredParameterError( 'Required parameter %r is missing.' % name) class Request(falcon.Request): diff --git a/server/szurubooru/api/post_api.py b/server/szurubooru/api/post_api.py index 436464f..a299dc4 100644 --- a/server/szurubooru/api/post_api.py +++ b/server/szurubooru/api/post_api.py @@ -1,6 +1,29 @@ from szurubooru.api.base_api import BaseApi from szurubooru.func import auth, tags, posts, snapshots, favorites, scores +class PostListApi(BaseApi): + def post(self, ctx): + auth.verify_privilege(ctx.user, 'posts:create') + content = ctx.get_file('content', required=True) + tag_names = ctx.get_param_as_list('tags', required=True) + safety = ctx.get_param_as_string('safety', required=True) + source = ctx.get_param_as_string('source', required=False, default=None) + relations = ctx.get_param_as_list('relations', required=False) or [] + notes = ctx.get_param_as_list('notes', required=False) or [] + flags = ctx.get_param_as_list('flags', required=False) or [] + + post = posts.create_post(content, tag_names, ctx.user) + posts.update_post_safety(post, safety) + posts.update_post_source(post, source) + posts.update_post_relations(post, relations) + posts.update_post_notes(post, notes) + posts.update_post_flags(post, flags) + ctx.session.add(post) + snapshots.save_entity_creation(post, ctx.user) + ctx.session.commit() + tags.export_to_json() + return posts.serialize_post_with_details(post, ctx.user) + class PostDetailApi(BaseApi): def get(self, ctx, post_id): auth.verify_privilege(ctx.user, 'posts:view') diff --git a/server/szurubooru/app.py b/server/szurubooru/app.py index 85ccda8..f9d16bb 100644 --- a/server/szurubooru/app.py +++ b/server/szurubooru/app.py @@ -65,6 +65,7 @@ def create_app(): app.add_route('/tag-merge/', api.TagMergeApi()) app.add_route('/tag-siblings/{tag_name}', api.TagSiblingsApi()) + app.add_route('/posts/', api.PostListApi()) app.add_route('/post/{post_id}', api.PostDetailApi()) app.add_route('/post/{post_id}/score', api.PostScoreApi()) app.add_route('/post/{post_id}/favorite', api.PostFavoriteApi()) diff --git a/server/szurubooru/db/post.py b/server/szurubooru/db/post.py index 8d6b496..7affa73 100644 --- a/server/szurubooru/db/post.py +++ b/server/szurubooru/db/post.py @@ -71,26 +71,29 @@ class Post(Base): SAFETY_SAFE = 'safe' SAFETY_SKETCHY = 'sketchy' SAFETY_UNSAFE = 'unsafe' - TYPE_IMAGE = 'anim' - TYPE_ANIMATION = 'anim' - TYPE_FLASH = 'flash' + TYPE_IMAGE = 'image' + TYPE_ANIMATION = 'animation' TYPE_VIDEO = 'video' - TYPE_YOUTUBE = 'youtube' - FLAG_LOOP_VIDEO = 1 + TYPE_FLASH = 'flash' + # basic meta post_id = Column('id', Integer, primary_key=True) user_id = Column('user_id', Integer, ForeignKey('user.id')) creation_time = Column('creation_time', DateTime, nullable=False) last_edit_time = Column('last_edit_time', DateTime) safety = Column('safety', String(32), nullable=False) + source = Column('source', String(200)) + flags = Column('flags', PickleType, default=None) + + # content description type = Column('type', String(32), nullable=False) checksum = Column('checksum', String(64), nullable=False) - source = Column('source', String(200)) file_size = Column('file_size', Integer) canvas_width = Column('image_width', Integer) canvas_height = Column('image_height', Integer) - flags = Column('flags', PickleType, default=None) + mime_type = Column('mime-type', String(32), nullable=False) + # foreign tables user = relationship('User') tags = relationship('Tag', backref='posts', secondary='post_tag') relations = relationship( @@ -106,8 +109,9 @@ class Post(Base): 'PostFavorite', cascade='all, delete-orphan', lazy='joined') notes = relationship( 'PostNote', cascade='all, delete-orphan', lazy='joined') - comments = relationship('Comment') + + # dynamic columns tag_count = column_property( select([func.count(PostTag.tag_id)]) \ .where(PostTag.post_id == post_id) \ diff --git a/server/szurubooru/errors.py b/server/szurubooru/errors.py index 60491da..8842382 100644 --- a/server/szurubooru/errors.py +++ b/server/szurubooru/errors.py @@ -5,3 +5,7 @@ class ValidationError(RuntimeError): pass class SearchError(RuntimeError): pass class NotFoundError(RuntimeError): pass class ProcessingError(RuntimeError): pass + +class MissingRequiredFileError(ValidationError): pass +class MissingRequiredParameterError(ValidationError): pass +class InvalidParameterError(ValidationError): pass diff --git a/server/szurubooru/func/files.py b/server/szurubooru/func/files.py index 7091f88..2a0217b 100644 --- a/server/szurubooru/func/files.py +++ b/server/szurubooru/func/files.py @@ -1,8 +1,23 @@ import os from szurubooru import config +def _get_full_path(path): + return os.path.join(config.config['data_dir'], path) + +def delete(path): + full_path = _get_full_path(path) + if os.path.exists(full_path): + os.unlink(full_path) + +def get(path): + full_path = _get_full_path(path) + if not os.path.exists(full_path): + return None + with open(full_path, 'rb') as handle: + return handle.read() + def save(path, content): - full_path = os.path.join(config.config['data_dir'], path) + full_path = _get_full_path(path) os.makedirs(os.path.dirname(full_path), exist_ok=True) with open(full_path, 'wb') as handle: handle.write(content) diff --git a/server/szurubooru/func/images.py b/server/szurubooru/func/images.py index e94d675..741d726 100644 --- a/server/szurubooru/func/images.py +++ b/server/szurubooru/func/images.py @@ -1,3 +1,4 @@ +import json import subprocess from szurubooru import errors @@ -7,6 +8,19 @@ _SCALE_FIT_FMT = \ class Image(object): def __init__(self, content): self.content = content + self._reload_info() + + @property + def width(self): + return self.info['streams'][0]['width'] + + @property + def height(self): + return self.info['streams'][0]['height'] + + @property + def frames(self): + return self.info['streams'][0]['nb_read_frames'] def resize_fill(self, width, height): self.content = self._execute([ @@ -17,6 +31,7 @@ class Image(object): '-vcodec', 'png', '-', ]) + self._reload_info() def to_png(self): return self._execute([ @@ -36,9 +51,9 @@ class Image(object): '-', ]) - def _execute(self, cli): + def _execute(self, cli, program='ffmpeg'): proc = subprocess.Popen( - ['ffmpeg', '-loglevel', '24'] + cli, + [program, '-loglevel', '24'] + cli, stdout=subprocess.PIPE, stdin=subprocess.PIPE, stderr=subprocess.PIPE) @@ -47,3 +62,15 @@ class Image(object): raise errors.ProcessingError( 'Error while processing image.\n' + err.decode('utf-8')) return out + + def _reload_info(self): + self.info = json.loads(self._execute([ + '-of', 'json', + '-select_streams', 'v', + '-show_streams', + '-count_frames', + '-i', '-', + ], program='ffprobe').decode('utf-8')) + assert 'streams' in self.info + if len(self.info['streams']) != 1: + raise errors.ProcessingError('Multiple video streams detected.') diff --git a/server/szurubooru/func/mime.py b/server/szurubooru/func/mime.py index a7d26ca..26f17f0 100644 --- a/server/szurubooru/func/mime.py +++ b/server/szurubooru/func/mime.py @@ -2,7 +2,7 @@ import re def get_mime_type(content): if not content: - return None + return 'application/octet-stream' if content[0:3] in (b'CWS', b'FWS', b'ZWS'): return 'application/x-shockwave-flash' @@ -33,7 +33,7 @@ def get_extension(mime_type): 'video/mp4': 'mp4', 'video/webm': 'webm', } - return extension_map.get(mime_type.strip().lower(), None) + return extension_map.get((mime_type or '').strip().lower(), None) def is_flash(mime_type): return mime_type.lower() == 'application/x-shockwave-flash' diff --git a/server/szurubooru/func/posts.py b/server/szurubooru/func/posts.py index 7e1524a..b79c117 100644 --- a/server/szurubooru/func/posts.py +++ b/server/szurubooru/func/posts.py @@ -1,10 +1,62 @@ import datetime import sqlalchemy -from szurubooru import db, errors -from szurubooru.func import users, snapshots, scores, comments +from szurubooru import config, db, errors +from szurubooru.func import ( + users, snapshots, scores, comments, tags, util, mime, images, files) + +EMPTY_PIXEL = \ + b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00' \ + b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \ + b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b' class PostNotFoundError(errors.NotFoundError): pass class PostAlreadyFeaturedError(errors.ValidationError): pass +class PostAlreadyUploadedError(errors.ValidationError): pass +class InvalidPostSafetyError(errors.ValidationError): pass +class InvalidPostSourceError(errors.ValidationError): pass +class InvalidPostContentError(errors.ValidationError): pass +class InvalidPostRelationError(errors.ValidationError): pass +class InvalidPostNoteError(errors.ValidationError): pass +class InvalidPostFlagError(errors.ValidationError): pass + +SAFETY_MAP = { + db.Post.SAFETY_SAFE: 'safe', + db.Post.SAFETY_SKETCHY: 'sketchy', + db.Post.SAFETY_UNSAFE: 'unsafe', +} +TYPE_MAP = { + db.Post.TYPE_IMAGE: 'image', + db.Post.TYPE_ANIMATION: 'animation', + db.Post.TYPE_VIDEO: 'video', + db.Post.TYPE_FLASH: 'flash', +} + +def get_post_content_url(post): + return '%s/posts/%d.%s' % ( + config.config['data_url'].rstrip('/'), + post.post_id, + mime.get_extension(post.mime_type) or 'dat') + +def get_post_thumbnail_url(post): + return '%s/generated-thumbnails/%d.jpg' % ( + config.config['data_url'].rstrip('/'), + post.post_id) + +def get_post_content_path(post): + return 'posts/%d.%s' % ( + post.post_id, mime.get_extension(post.mime_type) or 'dat') + +def get_post_thumbnail_path(post): + return 'generated-thumbnails/%d.jpg' % (post.post_id) + +def get_post_thumbnail_backup_path(post): + return 'posts/custom-thumbnails/%d.dat' % (post.post_id) + +def serialize_note(note): + return { + 'polygon': note.path, + 'text': note.text, + } def serialize_post(post, authenticated_user): if not post: @@ -14,20 +66,19 @@ def serialize_post(post, authenticated_user): 'id': post.post_id, 'creationTime': post.creation_time, 'lastEditTime': post.last_edit_time, - 'safety': post.safety, - 'type': post.type, - 'checksum': post.checksum, + 'safety': SAFETY_MAP[post.safety], 'source': post.source, + 'type': TYPE_MAP[post.type], + 'checksum': post.checksum, 'fileSize': post.file_size, 'canvasWidth': post.canvas_width, 'canvasHeight': post.canvas_height, + 'contentUrl': get_post_content_url(post), + 'thumbnailUrl': get_post_thumbnail_url(post), 'flags': post.flags, 'tags': [tag.first_name for tag in post.tags], 'relations': [rel.post_id for rel in post.relations], - 'notes': sorted([{ - 'path': note.path, - 'text': note.text, - } for note in post.notes]), + 'notes': sorted(serialize_note(note) for note in post.notes), 'user': users.serialize_user(post.user, authenticated_user), 'score': post.score, 'featureCount': post.feature_count, @@ -75,6 +126,140 @@ def try_get_featured_post(): .first() return post_feature.post if post_feature else None +def create_post(content, tag_names, user): + post = db.Post() + post.safety = db.Post.SAFETY_SAFE + post.user = user + post.creation_time = datetime.datetime.now() + post.flags = [] + + # we'll need post ID + post.type = '' + post.checksum = '' + post.mime_type = '' + db.session.add(post) + db.session.flush() + + update_post_content(post, content) + update_post_tags(post, tag_names) + return post + +def update_post_safety(post, safety): + safety = util.flip(SAFETY_MAP).get(safety, None) + if not safety: + raise InvalidPostSafetyError( + 'Safety can be either of %r.', list(SAFETY_MAP.values())) + post.safety = safety + +def update_post_source(post, source): + if util.value_exceeds_column_size(source, db.Post.source): + raise InvalidPostSourceError('Source is too long.') + post.source = source + +def update_post_content(post, content): + if not content: + raise InvalidPostContentError('Post content missing.') + post.mime_type = mime.get_mime_type(content) + if mime.is_flash(post.mime_type): + post.type = db.Post.TYPE_FLASH + elif mime.is_image(post.mime_type): + if mime.is_animated_gif(content): + post.type = db.Post.TYPE_ANIMATION + else: + post.type = db.Post.TYPE_IMAGE + elif mime.is_video(post.mime_type): + post.type = db.Post.TYPE_VIDEO + else: + raise InvalidPostContentError('Unhandled file type: %r' % post.mime_type) + + post.checksum = util.get_md5(content) + other_post = db.session \ + .query(db.Post) \ + .filter(db.Post.checksum == post.checksum) \ + .filter(db.Post.post_id != post.post_id) \ + .one_or_none() + if other_post: + raise PostAlreadyUploadedError( + 'Post already uploaded (%d)' % other_post.post_id) + + post.file_size = len(content) + try: + image = images.Image(content) + post.canvas_width = image.width + post.canvas_height = image.height + except errors.ProcessingError: + post.canvas_width = None + post.canvas_height = None + files.save(get_post_content_path(post), content) + update_post_thumbnail(post, content=None, delete=False) + +def update_post_thumbnail(post, content=None, delete=True): + if content is None: + content = files.get(get_post_content_path(post)) + if delete: + files.delete(get_post_thumbnail_backup_path(post)) + else: + files.save(get_post_thumbnail_backup_path(post), content) + + try: + image = images.Image(content) + image.resize_fill( + int(config.config['thumbnails']['post_width']), + int(config.config['thumbnails']['post_height'])) + files.save(get_post_thumbnail_path(post), image.to_jpeg()) + except errors.ProcessingError: + files.save(get_post_thumbnail_path(post), EMPTY_PIXEL) + +def update_post_tags(post, tag_names): + existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names) + post.tags = existing_tags + new_tags + +def update_post_relations(post, post_ids): + relations = db.session \ + .query(db.Post) \ + .filter(db.Post.post_id.in_(post_ids)) \ + .all() + if len(relations) != len(post_ids): + raise InvalidPostRelationError('One of relations does not exist.') + post.relations = relations + +def update_post_notes(post, notes): + post.notes = [] + for note in notes: + for field in ('polygon', 'text'): + if field not in note: + raise InvalidPostNoteError('Note is missing %r field.' % field) + if not note['text']: + raise InvalidPostNoteError('A note\'s text cannot be empty.') + if len(note['polygon']) < 3: + raise InvalidPostNoteError( + 'A note\'s polygon must have at least 3 points.') + for point in note['polygon']: + if len(point) != 2: + raise InvalidPostNoteError( + 'A point in note\'s polygon must have two coordinates.') + try: + pos_x = float(point[0]) + pos_y = float(point[1]) + if not 0 <= pos_x <= 1 or not 0 <= pos_y <= 1: + raise InvalidPostNoteError( + 'A point in note\'s polygon must be in 0..1 range.') + except ValueError: + raise InvalidPostNoteError( + 'A point in note\'s polygon must be numeric.') + if util.value_exceeds_column_size(note['text'], db.PostNote.text): + raise InvalidPostNoteError('Note text is too long.') + post.notes.append( + db.PostNote(polygon=note['polygon'], text=note['text'])) + +def update_post_flags(post, flags): + available_flags = ('loop',) + for flag in flags: + if flag not in available_flags: + raise InvalidPostFlagError( + 'Flag must be one of %r.' % available_flags) + post.flags = flags + def feature_post(post, user): post_feature = db.PostFeature() post_feature.time = datetime.datetime.now() diff --git a/server/szurubooru/func/tag_categories.py b/server/szurubooru/func/tag_categories.py index ea6f2bd..de9df8b 100644 --- a/server/szurubooru/func/tag_categories.py +++ b/server/szurubooru/func/tag_categories.py @@ -77,7 +77,7 @@ def try_get_default_category(): .query(db.TagCategory) \ .order_by(db.TagCategory.tag_category_id.asc()) \ .limit(1) \ - .one() + .first() def get_default_category(): category = try_get_default_category() diff --git a/server/szurubooru/func/tags.py b/server/szurubooru/func/tags.py index 7b8b776..ecb3c99 100644 --- a/server/szurubooru/func/tags.py +++ b/server/szurubooru/func/tags.py @@ -12,6 +12,9 @@ class TagIsInUseError(errors.ValidationError): pass class InvalidTagNameError(errors.ValidationError): pass class InvalidTagRelationError(errors.ValidationError): pass +DEFAULT_CATEGORY_NAME = 'Default' +DEFAULT_CATEGORY_COLOR = 'default' + def _verify_name_validity(name): name_regex = config.config['tag_name_regex'] if not re.match(name_regex, name): @@ -26,6 +29,13 @@ def _lower_list(names): def _check_name_intersection(names1, names2): return len(set(_lower_list(names1)).intersection(_lower_list(names2))) > 0 +def _get_default_category_name(): + tag_category = tag_categories.try_get_default_category() + if tag_category: + return tag_category.name + else: + return DEFAULT_CATEGORY_NAME + def serialize_tag(tag): return { 'names': [tag_name.name for tag_name in tag.names], @@ -104,23 +114,24 @@ def get_or_create_tags_by_names(names): names = util.icase_unique(names) for name in names: _verify_name_validity(name) - related_tags = get_tags_by_names(names) + existing_tags = get_tags_by_names(names) new_tags = [] + tag_category_name = _get_default_category_name() for name in names: found = False - for related_tag in related_tags: - if _check_name_intersection(_get_plain_names(related_tag), [name]): + for existing_tag in existing_tags: + if _check_name_intersection(_get_plain_names(existing_tag), [name]): found = True break if not found: new_tag = create_tag( names=[name], - category_name=tag_categories.get_default_category().name, + category_name=tag_category_name, suggestions=[], implications=[]) db.session.add(new_tag) new_tags.append(new_tag) - return related_tags, new_tags + return existing_tags, new_tags def get_tag_siblings(tag): tag_alias = sqlalchemy.orm.aliased(db.Tag) @@ -159,7 +170,8 @@ def update_tag_category_name(tag, category_name): .filter(db.TagCategory.name == category_name) \ .first() if not category: - category = tag_categories.create_category(category_name, 'default') + category = tag_categories.create_category( + category_name, DEFAULT_CATEGORY_COLOR) db.session.add(category) tag.category = category diff --git a/server/szurubooru/func/users.py b/server/szurubooru/func/users.py index 09581cf..2e17e63 100644 --- a/server/szurubooru/func/users.py +++ b/server/szurubooru/func/users.py @@ -1,5 +1,4 @@ import datetime -import hashlib import re from sqlalchemy import func from szurubooru import config, db, errors @@ -28,11 +27,9 @@ def serialize_user(user, authenticated_user, force_show_email=False): } if user.avatar_style == user.AVATAR_GRAVATAR: - md5 = hashlib.md5() - md5.update((user.email or user.name).lower().encode('utf-8')) - digest = md5.hexdigest() ret['avatarUrl'] = 'http://gravatar.com/avatar/%s?d=retro&s=%d' % ( - digest, config.config['thumbnails']['avatar_width']) + util.get_md5((user.email or user.name).lower()), + config.config['thumbnails']['avatar_width']) else: ret['avatarUrl'] = '%s/avatars/%s.jpg' % ( config.config['data_url'].rstrip('/'), user.name.lower()) diff --git a/server/szurubooru/func/util.py b/server/szurubooru/func/util.py index 83d3931..2fcddd8 100644 --- a/server/szurubooru/func/util.py +++ b/server/szurubooru/func/util.py @@ -1,8 +1,19 @@ import datetime +import hashlib import re from sqlalchemy.inspection import inspect from szurubooru.errors import ValidationError +def get_md5(source): + if not isinstance(source, bytes): + source = source.encode('utf-8') + md5 = hashlib.md5() + md5.update(source) + return md5.hexdigest() + +def flip(source): + return {v: k for k, v in source.items()} + def get_resource_info(entity): serializers = { 'tag': lambda tag: tag.first_name, @@ -96,4 +107,9 @@ def icase_unique(source): return target def value_exceeds_column_size(value, column): - return len(value) > column.property.columns[0].type.length + if not value: + return False + max_length = column.property.columns[0].type.length + if max_length is None: + return False + return len(value) > max_length diff --git a/server/szurubooru/migrations/versions/23abaf4a0a4b_add_mime_type_to_posts.py b/server/szurubooru/migrations/versions/23abaf4a0a4b_add_mime_type_to_posts.py new file mode 100644 index 0000000..8ac336e --- /dev/null +++ b/server/szurubooru/migrations/versions/23abaf4a0a4b_add_mime_type_to_posts.py @@ -0,0 +1,20 @@ +''' +Add mime type to posts + +Revision ID: 23abaf4a0a4b +Created at: 2016-05-02 00:02:33.024885 +''' + +import sqlalchemy as sa +from alembic import op + +revision = '23abaf4a0a4b' +down_revision = 'ed6dd16a30f3' +branch_labels = None +depends_on = None + +def upgrade(): + op.add_column('post', sa.Column('mime-type', sa.String(length=32), nullable=False)) + +def downgrade(): + op.drop_column('post', 'mime-type') diff --git a/server/szurubooru/tests/api/test_comment_creating.py b/server/szurubooru/tests/api/test_comment_creating.py index efa9130..218239c 100644 --- a/server/szurubooru/tests/api/test_comment_creating.py +++ b/server/szurubooru/tests/api/test_comment_creating.py @@ -6,6 +6,7 @@ from szurubooru.func import util, posts @pytest.fixture def test_ctx(config_injector, context_factory, post_factory, user_factory): config_injector({ + 'data_url': 'http://example.com', 'ranks': ['anonymous', 'regular_user'], 'rank_names': {'anonymous': 'Peasant', 'regular_user': 'Lord'}, 'privileges': {'comments:create': 'regular_user'}, diff --git a/server/szurubooru/tests/api/test_comment_rating.py b/server/szurubooru/tests/api/test_comment_rating.py index 73ad4b6..7d10700 100644 --- a/server/szurubooru/tests/api/test_comment_rating.py +++ b/server/szurubooru/tests/api/test_comment_rating.py @@ -6,6 +6,7 @@ from szurubooru.func import util, comments, scores @pytest.fixture def test_ctx(config_injector, context_factory, user_factory, comment_factory): config_injector({ + 'data_url': 'http://example.com', 'ranks': ['anonymous', 'regular_user', 'mod'], 'rank_names': {'anonymous': 'Peasant', 'regular_user': 'Lord'}, 'privileges': { diff --git a/server/szurubooru/tests/api/test_comment_retrieving.py b/server/szurubooru/tests/api/test_comment_retrieving.py index cc9bfcb..b0011b6 100644 --- a/server/szurubooru/tests/api/test_comment_retrieving.py +++ b/server/szurubooru/tests/api/test_comment_retrieving.py @@ -6,6 +6,7 @@ from szurubooru.func import util, comments @pytest.fixture def test_ctx(context_factory, config_injector, user_factory, comment_factory): config_injector({ + 'data_url': 'http://example.com', 'ranks': ['anonymous', 'regular_user', 'mod', 'admin'], 'rank_names': {'regular_user': 'Peasant'}, 'privileges': { diff --git a/server/szurubooru/tests/api/test_comment_updating.py b/server/szurubooru/tests/api/test_comment_updating.py index 0d5297b..08ade5c 100644 --- a/server/szurubooru/tests/api/test_comment_updating.py +++ b/server/szurubooru/tests/api/test_comment_updating.py @@ -6,6 +6,7 @@ from szurubooru.func import util, comments @pytest.fixture def test_ctx(config_injector, context_factory, user_factory, comment_factory): config_injector({ + 'data_url': 'http://example.com', 'ranks': ['anonymous', 'regular_user', 'mod'], 'rank_names': {'anonymous': 'Peasant', 'regular_user': 'Lord', 'mod': 'King'}, 'privileges': { diff --git a/server/szurubooru/tests/api/test_post_creating.py b/server/szurubooru/tests/api/test_post_creating.py new file mode 100644 index 0000000..9467506 --- /dev/null +++ b/server/szurubooru/tests/api/test_post_creating.py @@ -0,0 +1,133 @@ +import datetime +import os +import unittest.mock +import pytest +from szurubooru import api, db, errors +from szurubooru.func import posts, tags, snapshots + +@pytest.fixture(autouse=True) +def inject_config(config_injector): + config_injector({ + 'ranks': ['anonymous', 'regular_user'], + 'privileges': {'posts:create': 'regular_user'}, + }) + +def test_creating_minimal_posts( + context_factory, post_factory, user_factory): + auth_user = user_factory(rank='regular_user') + post = post_factory() + db.session.add(post) + db.session.flush() + + with unittest.mock.patch('szurubooru.func.posts.create_post'), \ + unittest.mock.patch('szurubooru.func.posts.update_post_safety'), \ + unittest.mock.patch('szurubooru.func.posts.update_post_source'), \ + unittest.mock.patch('szurubooru.func.posts.update_post_relations'), \ + unittest.mock.patch('szurubooru.func.posts.update_post_notes'), \ + unittest.mock.patch('szurubooru.func.posts.update_post_flags'), \ + unittest.mock.patch('szurubooru.func.posts.serialize_post_with_details'), \ + unittest.mock.patch('szurubooru.func.tags.export_to_json'), \ + unittest.mock.patch('szurubooru.func.snapshots.save_entity_creation'): + + posts.create_post.return_value = post + posts.serialize_post_with_details.return_value = 'serialized post' + + result = api.PostListApi().post( + context_factory( + input={ + 'safety': 'safe', + 'tags': ['tag1', 'tag2'], + }, + files={ + 'content': 'post-content', + }, + user=auth_user)) + + assert result == 'serialized post' + posts.create_post.assert_called_once_with( + 'post-content', ['tag1', 'tag2'], auth_user) + posts.update_post_safety.assert_called_once_with(post, 'safe') + posts.update_post_source.assert_called_once_with(post, None) + posts.update_post_relations.assert_called_once_with(post, []) + posts.update_post_notes.assert_called_once_with(post, []) + posts.update_post_flags.assert_called_once_with(post, []) + posts.serialize_post_with_details.assert_called_once_with(post, auth_user) + tags.export_to_json.assert_called_once_with() + snapshots.save_entity_creation.assert_called_once_with(post, auth_user) + +def test_creating_full_posts(context_factory, post_factory, user_factory): + auth_user = user_factory(rank='regular_user') + post = post_factory() + db.session.add(post) + db.session.flush() + + with unittest.mock.patch('szurubooru.func.posts.create_post'), \ + unittest.mock.patch('szurubooru.func.posts.update_post_safety'), \ + unittest.mock.patch('szurubooru.func.posts.update_post_source'), \ + unittest.mock.patch('szurubooru.func.posts.update_post_relations'), \ + unittest.mock.patch('szurubooru.func.posts.update_post_notes'), \ + unittest.mock.patch('szurubooru.func.posts.update_post_flags'), \ + unittest.mock.patch('szurubooru.func.posts.serialize_post_with_details'), \ + unittest.mock.patch('szurubooru.func.tags.export_to_json'), \ + unittest.mock.patch('szurubooru.func.snapshots.save_entity_creation'): + + posts.create_post.return_value = post + posts.serialize_post_with_details.return_value = 'serialized post' + + result = api.PostListApi().post( + context_factory( + input={ + 'safety': 'safe', + 'tags': ['tag1', 'tag2'], + 'relations': [1, 2], + 'source': 'source', + 'notes': ['note1', 'note2'], + 'flags': ['flag1', 'flag2'], + }, + files={ + 'content': 'post-content', + }, + user=auth_user)) + + assert result == 'serialized post' + posts.create_post.assert_called_once_with( + 'post-content', ['tag1', 'tag2'], auth_user) + posts.update_post_safety.assert_called_once_with(post, 'safe') + posts.update_post_source.assert_called_once_with(post, 'source') + posts.update_post_relations.assert_called_once_with(post, [1, 2]) + posts.update_post_notes.assert_called_once_with(post, ['note1', 'note2']) + posts.update_post_flags.assert_called_once_with(post, ['flag1', 'flag2']) + posts.serialize_post_with_details.assert_called_once_with(post, auth_user) + tags.export_to_json.assert_called_once_with() + snapshots.save_entity_creation.assert_called_once_with(post, auth_user) + +@pytest.mark.parametrize('field', ['tags', 'safety']) +def test_trying_to_omit_mandatory_field(context_factory, user_factory, field): + input = { + 'safety': 'safe', + 'tags': ['tag1', 'tag2'], + } + del input[field] + with pytest.raises(errors.MissingRequiredParameterError): + api.PostListApi().post( + context_factory( + input=input, + files={'content': '...'}, + user=user_factory(rank='regular_user'))) + +def test_trying_to_omit_content(context_factory, user_factory): + with pytest.raises(errors.MissingRequiredFileError): + api.PostListApi().post( + context_factory( + input={ + 'safety': 'safe', + 'tags': ['tag1', 'tag2'], + }, + user=user_factory(rank='regular_user'))) + +def test_trying_to_create_without_privileges(context_factory, user_factory): + with pytest.raises(errors.AuthError): + api.PostListApi().post( + context_factory( + input={'name': 'meta', 'colro': 'black'}, + user=user_factory(rank='anonymous'))) diff --git a/server/szurubooru/tests/api/test_post_favoriting.py b/server/szurubooru/tests/api/test_post_favoriting.py index 30c3f2d..e5e627d 100644 --- a/server/szurubooru/tests/api/test_post_favoriting.py +++ b/server/szurubooru/tests/api/test_post_favoriting.py @@ -6,6 +6,7 @@ from szurubooru.func import util, posts @pytest.fixture def test_ctx(config_injector, context_factory, user_factory, post_factory): config_injector({ + 'data_url': 'http://example.com', 'ranks': ['anonymous', 'regular_user', 'mod'], 'rank_names': {'anonymous': 'Peasant', 'regular_user': 'Lord'}, 'privileges': { diff --git a/server/szurubooru/tests/api/test_post_featuring.py b/server/szurubooru/tests/api/test_post_featuring.py index 005b04b..8a030f3 100644 --- a/server/szurubooru/tests/api/test_post_featuring.py +++ b/server/szurubooru/tests/api/test_post_featuring.py @@ -6,6 +6,7 @@ from szurubooru.func import util, posts @pytest.fixture def test_ctx(context_factory, config_injector, user_factory, post_factory): config_injector({ + 'data_url': 'http://example.com', 'privileges': { 'posts:feature': 'regular_user', 'posts:view': 'regular_user', diff --git a/server/szurubooru/tests/api/test_post_rating.py b/server/szurubooru/tests/api/test_post_rating.py index 65486e5..def80fa 100644 --- a/server/szurubooru/tests/api/test_post_rating.py +++ b/server/szurubooru/tests/api/test_post_rating.py @@ -5,6 +5,7 @@ from szurubooru.func import util, posts, scores @pytest.fixture def test_ctx(config_injector, context_factory, user_factory, post_factory): config_injector({ + 'data_url': 'http://example.com', 'ranks': ['anonymous', 'regular_user'], 'rank_names': {'anonymous': 'Peasant', 'regular_user': 'Lord'}, 'privileges': {'posts:score': 'regular_user'}, diff --git a/server/szurubooru/tests/api/test_post_retrieving.py b/server/szurubooru/tests/api/test_post_retrieving.py index d307b8d..ec39a9a 100644 --- a/server/szurubooru/tests/api/test_post_retrieving.py +++ b/server/szurubooru/tests/api/test_post_retrieving.py @@ -6,6 +6,7 @@ from szurubooru.func import util, posts @pytest.fixture def test_ctx(context_factory, config_injector, user_factory, post_factory): config_injector({ + 'data_url': 'http://example.com', 'privileges': { 'posts:list': 'regular_user', 'posts:view': 'regular_user', diff --git a/server/szurubooru/tests/assets/flash.swf b/server/szurubooru/tests/assets/flash.swf index c6195c4..c02191c 100644 Binary files a/server/szurubooru/tests/assets/flash.swf and b/server/szurubooru/tests/assets/flash.swf differ diff --git a/server/szurubooru/tests/assets/gif.gif b/server/szurubooru/tests/assets/gif.gif index edaf2b9..a4a9e40 100644 Binary files a/server/szurubooru/tests/assets/gif.gif and b/server/szurubooru/tests/assets/gif.gif differ diff --git a/server/szurubooru/tests/assets/jpeg.jpg b/server/szurubooru/tests/assets/jpeg.jpg index 71911bf..2f0b00f 100644 Binary files a/server/szurubooru/tests/assets/jpeg.jpg and b/server/szurubooru/tests/assets/jpeg.jpg differ diff --git a/server/szurubooru/tests/assets/png-broken.png b/server/szurubooru/tests/assets/png-broken.png new file mode 100644 index 0000000..230aa15 Binary files /dev/null and b/server/szurubooru/tests/assets/png-broken.png differ diff --git a/server/szurubooru/tests/assets/png-transparent.png b/server/szurubooru/tests/assets/png-transparent.png deleted file mode 100644 index 91a99b9..0000000 Binary files a/server/szurubooru/tests/assets/png-transparent.png and /dev/null differ diff --git a/server/szurubooru/tests/assets/png.png b/server/szurubooru/tests/assets/png.png new file mode 100644 index 0000000..b7923db Binary files /dev/null and b/server/szurubooru/tests/assets/png.png differ diff --git a/server/szurubooru/tests/conftest.py b/server/szurubooru/tests/conftest.py index 5b4cbfd..99f4ffa 100644 --- a/server/szurubooru/tests/conftest.py +++ b/server/szurubooru/tests/conftest.py @@ -1,4 +1,5 @@ import contextlib +import os import datetime import uuid import pytest @@ -136,6 +137,7 @@ def post_factory(): post.type = type post.checksum = checksum post.flags = [] + post.mime_type = 'application/octet-stream' post.creation_time = datetime.datetime(1996, 1, 1) return post return factory @@ -156,3 +158,11 @@ def comment_factory(user_factory, post_factory): comment.creation_time = datetime.datetime(1996, 1, 1) return comment return factory + +@pytest.fixture +def read_asset(): + def get(path): + path = os.path.join(os.path.dirname(__file__), 'assets', path) + with open(path, 'rb') as handle: + return handle.read() + return get diff --git a/server/szurubooru/tests/db/test_post.py b/server/szurubooru/tests/db/test_post.py index 45a8951..041f58a 100644 --- a/server/szurubooru/tests/db/test_post.py +++ b/server/szurubooru/tests/db/test_post.py @@ -13,6 +13,7 @@ def test_saving_post(post_factory, user_factory, tag_factory): post.checksum = 'deadbeef' post.creation_time = datetime(1997, 1, 1) post.last_edit_time = datetime(1998, 1, 1) + post.mime_type = 'application/whatever' db.session.add_all([user, tag1, tag2, related_post1, related_post2, post]) post.user = user diff --git a/server/szurubooru/tests/func/test_mime.py b/server/szurubooru/tests/func/test_mime.py index 830b275..1d5a13a 100644 --- a/server/szurubooru/tests/func/test_mime.py +++ b/server/szurubooru/tests/func/test_mime.py @@ -6,7 +6,7 @@ from szurubooru.func import mime ('mp4.mp4', 'video/mp4'), ('webm.webm', 'video/webm'), ('flash.swf', 'application/x-shockwave-flash'), - ('png-transparent.png', 'image/png'), + ('png.png', 'image/png'), ('jpeg.jpg', 'image/jpeg'), ('gif.gif', 'image/gif'), ]) diff --git a/server/szurubooru/tests/func/test_posts.py b/server/szurubooru/tests/func/test_posts.py new file mode 100644 index 0000000..c0c458c --- /dev/null +++ b/server/szurubooru/tests/func/test_posts.py @@ -0,0 +1,463 @@ +import os +import datetime +import unittest.mock +import pytest +from szurubooru import db +from szurubooru.func import posts, users, comments, snapshots, tags, images + +@pytest.mark.parametrize('input_mime_type,expected_url', [ + ('image/jpeg', 'http://example.com/posts/1.jpg'), + ('image/gif', 'http://example.com/posts/1.gif'), + ('totally/unknown', 'http://example.com/posts/1.dat'), +]) +def test_get_post_url(input_mime_type, expected_url, config_injector): + config_injector({'data_url': 'http://example.com/'}) + post = db.Post() + post.post_id = 1 + post.mime_type = input_mime_type + assert posts.get_post_content_url(post) == expected_url + +@pytest.mark.parametrize('input_mime_type', ['image/jpeg', 'image/gif']) +def test_get_post_thumbnail_url(input_mime_type, config_injector): + config_injector({'data_url': 'http://example.com/'}) + post = db.Post() + post.post_id = 1 + post.mime_type = input_mime_type + assert posts.get_post_thumbnail_url(post) \ + == 'http://example.com/generated-thumbnails/1.jpg' + +@pytest.mark.parametrize('input_mime_type,expected_path', [ + ('image/jpeg', 'posts/1.jpg'), + ('image/gif', 'posts/1.gif'), + ('totally/unknown', 'posts/1.dat'), +]) +def test_get_post_content_path(input_mime_type, expected_path): + post = db.Post() + post.post_id = 1 + post.mime_type = input_mime_type + assert posts.get_post_content_path(post) == expected_path + +@pytest.mark.parametrize('input_mime_type', ['image/jpeg', 'image/gif']) +def test_get_post_thumbnail_path(input_mime_type): + post = db.Post() + post.post_id = 1 + post.mime_type = input_mime_type + assert posts.get_post_thumbnail_path(post) == 'generated-thumbnails/1.jpg' + +@pytest.mark.parametrize('input_mime_type', ['image/jpeg', 'image/gif']) +def test_get_post_thumbnail_backup_path(input_mime_type): + post = db.Post() + post.post_id = 1 + post.mime_type = input_mime_type + assert posts.get_post_thumbnail_backup_path(post) \ + == 'posts/custom-thumbnails/1.dat' + +def test_serialize_note(): + note = db.PostNote() + note.path = [[0, 1], [1, 1], [1, 0], [0, 0]] + note.text = '...' + assert posts.serialize_note(note) == { + 'polygon': [[0, 1], [1, 1], [1, 0], [0, 0]], + 'text': '...' + } + +def test_serialize_empty_post(): + assert posts.serialize_post(None, None) is None + +def test_serialize_post(post_factory, user_factory, tag_factory): + with unittest.mock.patch('szurubooru.func.users.serialize_user'): + users.serialize_user.side_effect = lambda user, auth_user: user.name + + auth_user = user_factory(name='auth user') + post = db.Post() + post.post_id = 1 + post.creation_time = datetime.datetime(1997, 1, 1) + post.last_edit_time = datetime.datetime(1998, 1, 1) + post.tags = [ + tag_factory(names=['tag1', 'tag2']), + tag_factory(names=['tag3']) + ] + post.safety = db.Post.SAFETY_SAFE + post.source = '4gag' + post.type = db.Post.TYPE_IMAGE + post.checksum = 'deadbeef' + post.mime_type = 'image/jpeg' + post.file_size = 100 + post.user = user_factory(name='post author') + post.canvas_width = 200 + post.canvas_height = 300 + post.flags = ['loop'] + db.session.add(post) + db.session.flush() + db.session.add_all([ + db.PostFavorite( + post=post, + user=user_factory(name='fav1'), + time=datetime.datetime(1800, 1, 1)), + db.PostFeature( + post=post, + user=user_factory(), + time=datetime.datetime(1999, 1, 1)), + db.PostScore( + post=post, + user=auth_user, + score=-1, + time=datetime.datetime(1800, 1, 1)), + db.PostScore( + post=post, + user=user_factory(), + score=1, + time=datetime.datetime(1800, 1, 1)), + db.PostScore( + post=post, + user=user_factory(), + score=1, + time=datetime.datetime(1800, 1, 1))]) + db.session.flush() + + result = posts.serialize_post(post, auth_user) + + assert result == { + 'id': 1, + 'creationTime': datetime.datetime(1997, 1, 1), + 'lastEditTime': datetime.datetime(1998, 1, 1), + 'safety': 'safe', + 'source': '4gag', + 'type': 'image', + 'checksum': 'deadbeef', + 'fileSize': 100, + 'canvasWidth': 200, + 'canvasHeight': 300, + 'contentUrl': 'http://example.com/posts/1.jpg', + 'thumbnailUrl': 'http://example.com/generated-thumbnails/1.jpg', + 'flags': ['loop'], + 'tags': ['tag1', 'tag3'], + 'relations': [], + 'notes': [], + 'user': 'post author', + 'score': 1, + 'ownScore': -1, + 'featureCount': 1, + 'lastFeatureTime': datetime.datetime(1999, 1, 1), + 'favoritedBy': ['fav1'], + } + +def test_serialize_post_with_details(post_factory, comment_factory, user_factory): + with unittest.mock.patch('szurubooru.func.comments.serialize_comment'), \ + unittest.mock.patch('szurubooru.func.snapshots.get_serialized_history'), \ + unittest.mock.patch('szurubooru.func.posts.serialize_post'): + comments.serialize_comment.side_effect \ + = lambda comment, auth_user: comment.user.name + posts.serialize_post.side_effect \ + = lambda post, auth_user: post.post_id + snapshots.get_serialized_history.return_value = 'snapshot history' + + auth_user = user_factory(name='auth user') + post = post_factory() + post.comments = [ + comment_factory(user=user_factory(name='commenter1')), + comment_factory(user=user_factory(name='commenter2')), + ] + db.session.add(post) + db.session.flush() + + result = posts.serialize_post_with_details(post, auth_user) + + assert result == { + 'post': post.post_id, + 'snapshots': 'snapshot history', + 'comments': ['commenter1', 'commenter2'], + } + +def test_get_post_count(post_factory): + previous_count = posts.get_post_count() + db.session.add_all([post_factory(), post_factory()]) + new_count = posts.get_post_count() + assert previous_count == 0 + assert new_count == 2 + +def test_try_get_post_by_id(post_factory): + post = post_factory() + db.session.add(post) + db.session.flush() + assert posts.try_get_post_by_id(post.post_id) == post + assert posts.try_get_post_by_id(post.post_id + 1) is None + +def test_get_post_by_id(post_factory): + post = post_factory() + db.session.add(post) + db.session.flush() + assert posts.get_post_by_id(post.post_id) == post + with pytest.raises(posts.PostNotFoundError): + posts.get_post_by_id(post.post_id + 1) + +def test_create_post(user_factory, fake_datetime): + with unittest.mock.patch('szurubooru.func.posts.update_post_content'), \ + unittest.mock.patch('szurubooru.func.posts.update_post_tags'), \ + fake_datetime('1997-01-01'): + auth_user = user_factory() + post = posts.create_post('content', ['tag'], auth_user) + assert post.creation_time == datetime.datetime(1997, 1, 1) + assert post.last_edit_time is None + posts.update_post_tags.assert_called_once_with(post, ['tag']) + posts.update_post_content.assert_called_once_with(post, 'content') + +@pytest.mark.parametrize('input_safety,expected_safety', [ + ('safe', db.Post.SAFETY_SAFE), + ('sketchy', db.Post.SAFETY_SKETCHY), + ('unsafe', db.Post.SAFETY_UNSAFE), +]) +def test_update_post_safety(input_safety, expected_safety): + post = db.Post() + posts.update_post_safety(post, input_safety) + assert post.safety == expected_safety + +def test_update_post_invalid_safety(): + post = db.Post() + with pytest.raises(posts.InvalidPostSafetyError): + posts.update_post_safety(post, 'bad') + +def test_update_post_source(): + post = db.Post() + posts.update_post_source(post, 'x') + assert post.source == 'x' + +def test_update_post_invalid_source(): + post = db.Post() + with pytest.raises(posts.InvalidPostSourceError): + posts.update_post_source(post, 'x' * 1000) + +@pytest.mark.parametrize( + 'input_file,expected_mime_type,expected_type,output_file_name', [ + ('png.png', 'image/png', db.Post.TYPE_IMAGE, '1.png'), + ('jpeg.jpg', 'image/jpeg', db.Post.TYPE_IMAGE, '1.jpg'), + ('gif.gif', 'image/gif', db.Post.TYPE_IMAGE, '1.gif'), + ('gif-animated.gif', 'image/gif', db.Post.TYPE_ANIMATION, '1.gif'), + ('webm.webm', 'video/webm', db.Post.TYPE_VIDEO, '1.webm'), + ('mp4.mp4', 'video/mp4', db.Post.TYPE_VIDEO, '1.mp4'), + ('flash.swf', 'application/x-shockwave-flash', db.Post.TYPE_FLASH, '1.swf'), +]) +def test_update_post_content( + tmpdir, + config_injector, + post_factory, + read_asset, + input_file, + expected_mime_type, + expected_type, + output_file_name): + with unittest.mock.patch('szurubooru.func.util.get_md5', return_value='crc'): + config_injector({ + 'data_dir': str(tmpdir.mkdir('data')), + 'thumbnails': { + 'post_width': 300, + 'post_height': 300, + }, + }) + post = post_factory(id=1) + db.session.add(post) + db.session.flush() + posts.update_post_content(post, read_asset(input_file)) + assert post.mime_type == expected_mime_type + assert post.type == expected_type + assert post.checksum == 'crc' + assert os.path.exists(str(tmpdir) + '/data/posts/' + output_file_name) + +def test_update_post_content_to_existing_content( + tmpdir, config_injector, post_factory, read_asset): + config_injector({ + 'data_dir': str(tmpdir.mkdir('data')), + 'thumbnails': { + 'post_width': 300, + 'post_height': 300, + }, + }) + post = post_factory() + another_post = post_factory() + db.session.add_all([post, another_post]) + db.session.flush() + posts.update_post_content(post, read_asset('png.png')) + with pytest.raises(posts.PostAlreadyUploadedError): + posts.update_post_content(another_post, read_asset('png.png')) + +def test_update_post_content_broken_content( + tmpdir, config_injector, post_factory, read_asset): + # the rationale behind this behavior is to salvage user upload even if the + # server software thinks it's broken. chances are the server is wrong, + # especially about flash movies. + config_injector({ + 'data_dir': str(tmpdir.mkdir('data')), + 'thumbnails': { + 'post_width': 300, + 'post_height': 300, + }, + }) + post = post_factory() + another_post = post_factory() + db.session.add_all([post, another_post]) + db.session.flush() + posts.update_post_content(post, read_asset('png-broken.png')) + assert post.canvas_width is None + assert post.canvas_height is None + +@pytest.mark.parametrize('input_content', [None, b'not a media file']) +def test_update_post_invalid_content(input_content): + post = db.Post() + with pytest.raises(posts.InvalidPostContentError): + posts.update_post_content(post, input_content) + +def test_update_post_thumbnail_to_new_one( + tmpdir, config_injector, read_asset, post_factory): + config_injector({ + 'data_dir': str(tmpdir.mkdir('data')), + 'thumbnails': { + 'post_width': 300, + 'post_height': 300, + }, + }) + post = post_factory(id=1) + db.session.add(post) + db.session.flush() + posts.update_post_content(post, read_asset('png.png')) + posts.update_post_thumbnail(post, read_asset('jpeg.jpg')) + assert os.path.exists(str(tmpdir) + '/data/posts/custom-thumbnails/1.dat') + assert os.path.exists(str(tmpdir) + '/data/generated-thumbnails/1.jpg') + with open(str(tmpdir) + '/data/posts/custom-thumbnails/1.dat', 'rb') as handle: + assert handle.read() == read_asset('jpeg.jpg') + +def test_update_post_thumbnail_to_default( + tmpdir, config_injector, read_asset, post_factory): + config_injector({ + 'data_dir': str(tmpdir.mkdir('data')), + 'thumbnails': { + 'post_width': 300, + 'post_height': 300, + }, + }) + post = post_factory(id=1) + db.session.add(post) + db.session.flush() + posts.update_post_content(post, read_asset('png.png')) + posts.update_post_thumbnail(post, read_asset('jpeg.jpg')) + posts.update_post_thumbnail(post, None) + assert not os.path.exists(str(tmpdir) + '/data/posts/custom-thumbnails/1.dat') + assert os.path.exists(str(tmpdir) + '/data/generated-thumbnails/1.jpg') + +def test_update_post_thumbnail_broken_thumbnail( + tmpdir, config_injector, read_asset, post_factory): + config_injector({ + 'data_dir': str(tmpdir.mkdir('data')), + 'thumbnails': { + 'post_width': 300, + 'post_height': 300, + }, + }) + post = post_factory(id=1) + db.session.add(post) + db.session.flush() + posts.update_post_content(post, read_asset('png.png')) + posts.update_post_thumbnail(post, read_asset('png-broken.png')) + assert os.path.exists(str(tmpdir) + '/data/posts/custom-thumbnails/1.dat') + assert os.path.exists(str(tmpdir) + '/data/generated-thumbnails/1.jpg') + with open(str(tmpdir) + '/data/posts/custom-thumbnails/1.dat', 'rb') as handle: + assert handle.read() == read_asset('png-broken.png') + with open(str(tmpdir) + '/data/generated-thumbnails/1.jpg', 'rb') as handle: + image = images.Image(handle.read()) + assert image.width == 1 + assert image.height == 1 + +def test_update_post_content_leaves_custom_thumbnail( + tmpdir, config_injector, read_asset, post_factory): + config_injector({ + 'data_dir': str(tmpdir.mkdir('data')), + 'thumbnails': { + 'post_width': 300, + 'post_height': 300, + }, + }) + post = post_factory(id=1) + db.session.add(post) + db.session.flush() + posts.update_post_content(post, read_asset('png.png')) + posts.update_post_thumbnail(post, read_asset('jpeg.jpg')) + posts.update_post_content(post, read_asset('png.png')) + assert os.path.exists(str(tmpdir) + '/data/posts/custom-thumbnails/1.dat') + assert os.path.exists(str(tmpdir) + '/data/generated-thumbnails/1.jpg') + +def test_update_post_tags(tag_factory): + post = db.Post() + with unittest.mock.patch('szurubooru.func.tags.get_or_create_tags_by_names'): + tags.get_or_create_tags_by_names.side_effect \ + = lambda tag_names: \ + ([tag_factory(names=[name]) for name in tag_names], []) + posts.update_post_tags(post, ['tag1', 'tag2']) + assert len(post.tags) == 2 + assert post.tags[0].names[0].name == 'tag1' + assert post.tags[1].names[0].name == 'tag2' + +def test_update_post_relations(post_factory): + relation1 = post_factory() + relation2 = post_factory() + db.session.add_all([relation1, relation2]) + db.session.flush() + post = db.Post() + posts.update_post_relations(post, [relation1.post_id, relation2.post_id]) + assert len(post.relations) == 2 + assert post.relations[0].post_id == relation1.post_id + assert post.relations[1].post_id == relation2.post_id + +def test_update_post_non_existing_relations(): + post = db.Post() + with pytest.raises(posts.InvalidPostRelationError): + posts.update_post_relations(post, [100]) + +def test_update_post_notes(): + post = db.Post() + posts.update_post_notes( + post, + [ + {'polygon': [[0, 0], [0, 1], [1, 0], [0, 0]], 'text': 'text1'}, + {'polygon': [[0, 0], [0, 1], [1, 0], [0, 0]], 'text': 'text2'}, + ]) + assert len(post.notes) == 2 + assert post.notes[0].polygon == [[0, 0], [0, 1], [1, 0], [0, 0]] + assert post.notes[0].text == 'text1' + assert post.notes[1].polygon == [[0, 0], [0, 1], [1, 0], [0, 0]] + assert post.notes[1].text == 'text2' + +@pytest.mark.parametrize('input', [ + [{'polygon': [[0, 0]], 'text': '...'}], + [{'polygon': [[0, 0], [0, 0], [0, 2]], 'text': '...'}], + [{'polygon': [[0, 0], [0, 0], [0, '...']], 'text': '...'}], + [{'polygon': [[0, 0], [0, 0], [0, 0, 0]], 'text': '...'}], + [{'polygon': [[0, 0], [0, 0], [0]], 'text': '...'}], + [{'polygon': [[0, 0], [0, 0], [0, 1]], 'text': ''}], + [{'polygon': [[0, 0], [0, 0], [0, 1]], 'text': None}], + [{'text': '...'}], + [{'polygon': [[0, 0], [0, 0], [0, 1]]}], +]) +def test_update_post_invalid_notes(input): + post = db.Post() + with pytest.raises(posts.InvalidPostNoteError): + posts.update_post_notes(post, input) + +def test_update_post_flags(): + post = db.Post() + posts.update_post_flags(post, ['loop']) + assert post.flags == ['loop'] + +def test_update_post_invalid_flags(): + post = db.Post() + with pytest.raises(posts.InvalidPostFlagError): + posts.update_post_flags(post, ['invalid']) + +def test_featuring_post(post_factory, user_factory): + post = post_factory() + user = user_factory() + + previous_featured_post = posts.try_get_featured_post() + posts.feature_post(post, user) + new_featured_post = posts.try_get_featured_post() + + assert previous_featured_post is None + assert new_featured_post == post