server/posts: add post creating

This commit is contained in:
rr- 2016-04-30 23:17:08 +02:00
parent a567974784
commit ce095816d9
35 changed files with 1060 additions and 90 deletions

127
API.md
View File

@ -29,7 +29,7 @@
- [Listing tag siblings](#listing-tag-siblings)
- Posts
- ~~Listing posts~~
- ~~Creating post~~
- [Creating post](#creating-post)
- ~~Updating post~~
- [Getting post](#getting-post)
- [Deleting post](#deleting-post)
@ -69,6 +69,7 @@
- [Detailed tag](#detailed-tag)
- [Post](#post)
- [Detailed post](#detailed-post)
- [Note](#note)
- [Comment](#comment)
- [Detailed comment](#detailed-comment)
- [Snapshot](#snapshot)
@ -125,7 +126,6 @@ Depending on the deployment, the URLs might be relative to some base path such
as `/api/`. Values denoted with diamond braces (`<like this>`) signify variable
data.
## Listing tag categories
- **Request**
@ -150,7 +150,6 @@ data.
caching. The data directory and its URL are controlled with `data_dir` and
`data_url` variables in server's configuration.
## Creating tag category
- **Request**
@ -181,7 +180,6 @@ data.
Creates a new tag category using specified parameters. Name must match
`tag_category_name_regex` from server's configuration.
## Updating tag category
- **Request**
@ -214,7 +212,6 @@ data.
match `tag_category_name_regex` from server's configuration. All fields are
optional - update concerns only provided fields.
## Getting tag category
- **Request**
@ -233,7 +230,6 @@ data.
Retrieves information about an existing tag category.
## Deleting tag category
- **Request**
@ -257,7 +253,6 @@ data.
Deletes existing tag category. The tag category to be deleted must have no
usages.
## Listing tags
- **Request**
@ -327,7 +322,6 @@ data.
None.
## Creating tag
- **Request**
@ -370,7 +364,6 @@ data.
first tag category found. If there are no tag categories established yet,
an error will be thrown.
## Updating tag
- **Request**
@ -412,7 +405,6 @@ data.
their category is set to the first tag category found. All fields are
optional - update concerns only provided fields.
## Getting tag
- **Request**
@ -431,7 +423,6 @@ data.
Retrieves information about an existing tag.
## Deleting tag
- **Request**
@ -453,7 +444,6 @@ data.
Deletes existing tag. The tag to be deleted must have no usages.
## Merging tags
- **Request**
@ -485,7 +475,6 @@ data.
and are discarded. The target tag effectively remains unchanged with the
exception of the set of posts it's used in.
## Listing tag siblings
- **Request**
@ -520,6 +509,48 @@ data.
appears with given tag. Results are sorted by occurrences count and the
list is truncated to the first 50 elements. Doesn't use paging.
## Creating post
- **Request**
`POST /posts/`
- **Input**
```json5
{
"tags": [<tag1>, <tag2>, <tag3>],
"safety": <safety>,
"source": <source>, // optional
"relations": [<post1>, <post2>, <post3>], // optional
"notes": [<note1>, <note2>, <note3>], // optional
"flags": [<flag1>, <flag2>] // optional
}
```
- **Files**
- `content` - the content of the content.
- `thumbnail` - the content of custom thumbnail (optional).
- **Output**
A [detailed post resource](#detailed-post).
- **Errors**
- tags have invalid names
- safety is invalid
- relations refer to non-existing posts
- privileges are too low
- **Description**
Creates a new post. If specified tags do not exist yet, they will be
automatically created. Tags created automatically have no implications, no
suggestions, one name and their category is set to the first tag category
found. Safety must be any of `"safe"`, `"sketchy"` or `"unsafe"`. `<flag>`
currently can be only `"loop"` to enable looping for video posts. Sending
empty `thumbnail` will cause the post to use default thumbnail.
## Getting post
- **Request**
@ -539,7 +570,6 @@ data.
Retrieves information about an existing post.
## Deleting post
- **Request**
@ -560,7 +590,6 @@ data.
Deletes existing post. Related posts and tags are kept.
## Rating post
- **Request**
@ -589,7 +618,6 @@ data.
Updates score of authenticated user for given post. Valid scores are -1, 0
and 1.
## Adding post to favorites
- **Request**
@ -608,7 +636,6 @@ data.
Marks the post as favorite for authenticated user.
## Removing post from favorites
- **Request**
@ -627,7 +654,6 @@ data.
Unmarks the post as favorite for authenticated user.
## Getting featured post
- **Request**
@ -647,7 +673,6 @@ data.
client. If no post is featured, `<post>` is null and `snapshots` array is
empty.
## Featuring post
- **Request**
@ -666,7 +691,6 @@ data.
Features a post on the main page in web client.
## Listing comments
- **Request**
@ -722,7 +746,6 @@ data.
None.
## Creating comment
- **Request**
@ -751,7 +774,6 @@ data.
Creates a new comment under given post.
## Updating comment
- **Request**
@ -779,7 +801,6 @@ data.
Updates an existing comment text.
## Getting comment
- **Request**
@ -798,7 +819,6 @@ data.
Retrieves information about an existing comment.
## Deleting comment
- **Request**
@ -819,7 +839,6 @@ data.
Deletes existing comment.
## Rating comment
- **Request**
@ -848,7 +867,6 @@ data.
Updates score of authenticated user for given comment. Valid scores are -1,
0 and 1.
## Listing users
- **Request**
@ -900,7 +918,6 @@ data.
None.
## Creating user
- **Request**
@ -947,7 +964,6 @@ data.
administrator, whereas subsequent users will be given the rank indicated by
`default_rank` in the server's configuration.
## Updating user
- **Request**
@ -993,7 +1009,6 @@ data.
`manual`. `manual` avatar style requires client to pass also `avatar`
file - see [file uploads](#file-uploads) for details.
## Getting user
- **Request**
@ -1012,7 +1027,6 @@ data.
Retrieves information about an existing user.
## Deleting user
- **Request**
@ -1033,7 +1047,6 @@ data.
Deletes existing user.
## Password reset - step 1: mail request
- **Request**
@ -1058,7 +1071,6 @@ data.
mailbox, which is a strong indication they are the rightful owner of the
account.
## Password reset - step 2: confirmation
- **Request**
@ -1091,7 +1103,6 @@ data.
Generates a new password for given user. Password is sent as plain-text, so
it is recommended to connect through HTTPS.
## Listing snapshots
- **Request**
@ -1133,7 +1144,6 @@ data.
None.
## Getting global info
- **Request**
@ -1325,29 +1335,34 @@ One file together with its metadata posted to the site.
```json5
{
"id": <id>,
"creationTime": <creation-time>,
"lastEditTime": <last-edit-time>,
"safety": <safety>,
"source": <source>,
"type": <type>,
"checksum": <checksum>,
"source": <source>,
"canvasWidth": <canvas-width>,
"canvasHeight": <canvas-height>,
"contentUrl": <content-url>,
"thumbnailUrl": <thumbnail-url>,
"flags": <flags>,
"tags": <tags>,
"relations": <relations>,
"creationTime": <creation-time>,
"lastEditTime": <last-edit-time>,
"notes": <notes>,
"user": <user>,
"score": <score>,
"ownScore": <own-score>,
"favoritedBy": <favorited-by>,
"featureCount": <feature-count>,
"lastFeatureTime": <last-feature-time>
"lastFeatureTime": <last-feature-time>,
"favoritedBy": <favorited-by>
}
```
**Field meaning**
- `<id>`: the post identifier.
- `<creation-time>`: time the tag was created, formatted as per RFC 3339.
- `<last-edit-time>`: time the tag was edited, formatted as per RFC 3339.
- `<safety>`: whether the post is safe for work.
Available values:
@ -1356,6 +1371,7 @@ One file together with its metadata posted to the site.
- `"sketchy"`
- `"unsafe"`
- `<source>`: where the post was grabbed form, supplied by the user.
- `<type>`: the type of the post.
Available values:
@ -1368,24 +1384,25 @@ One file together with its metadata posted to the site.
- `<checksum>`: the file checksum. Used in snapshots to signify changes of the
post content.
- `<source>`: where the post was grabbed form, supplied by the user.
- `<canvas-width>` and `<canvas-height>`: the original width and height of the
post content.
- `<content-url>`: where the post content is located.
- `<thumbnail-url>`: where the post thumbnail is located.
- `<flags>`: various flags such as whether the post is looped, represented as
array of plain strings.
- `<tags>`: list of tag names the post is tagged with.
- `<relations>`: a list of related post IDs. Links to related posts are shown
to the user by the web client.
- `<creation-time>`: time the tag was created, formatted as per RFC 3339.
- `<last-edit-time>`: time the tag was edited, formatted as per RFC 3339.
- `<notes>`: a list of post annotations, serialized as list of [note
resources](#note).
- `<user>`: who created the post, serialized as [user resource](#user).
- `<score>`: the collective score (+1/-1 rating) of the given post.
- `<own-score>`: the score (+1/-1 rating) of the given post by the
authenticated user.
- `<favorited-by>`: list of users, serialized as [user resources](#user).
- `<feature-count>`: how many times has the post been featured.
- `<last-feature-time>`: the last time the post was featured, formatted as per
RFC 3339.
- `<favorited-by>`: list of users, serialized as [user resources](#user).
## Detailed post
**Description**
@ -1416,6 +1433,27 @@ A post with extra information.
earlier versions.
- `<comment>`: a [comment resource](#comment) for given post.
## Note
**Description**
A text annotation rendered on top of the post.
**Structure**
```json5
{
"polygon": <list-of-points>,
"text": <text>,
}
```
**Field meaning**
- `<list-of-points>`: where to draw the annotation. Each point must have
coordinates within 0 to 1. For example, `[[0,0],[0,1],[1,1],[1,0]]` will draw
the annotation on the whole post, whereas `[[0,0],[0,0.5],[0.5,0.5],[0.5,0]]`
will draw it inside the post's upper left quarter.
- `<text>`: the annotation text. The client should render is as Markdown.
## Comment
**Description**
@ -1439,6 +1477,7 @@ A comment under a post.
**Field meaning**
- `<id>`: the comment identifier.
- `<post>`: a post resource the post is linked with.
- `<text>`: the comment content. The client should render is as Markdown.
- `<author>`: a user resource the post is created by.
- `<creation-time>`: time the comment was created, formatted as per RFC 3339.
- `<last-edit-time>`: time the comment was edited, formatted as per RFC 3339.
@ -1542,7 +1581,7 @@ A snapshot is a version of a database resource.
"checksum": "deadbeef",
"tags": ["tag1", "tag2"],
"relations": [1, 2],
"notes": [{"polygon": [[1,1],[200,1],[200,200],[1,200]], "text": "..."}],
"notes": [<note1>, <note2>, <note3>],
"flags": ["loop"],
"featured": false
}

View File

@ -15,6 +15,7 @@ from szurubooru.api.comment_api import (
CommentDetailApi,
CommentScoreApi)
from szurubooru.api.post_api import (
PostListApi,
PostDetailApi,
PostFeatureApi,
PostScoreApi,

View File

@ -12,8 +12,16 @@ class Context(object):
def has_param(self, name):
return name in self.input
def get_file(self, name):
return self.files.get(name, None)
def has_file(self, name):
return name in self.files
def get_file(self, name, required=False):
if name in self.files:
return self.files[name]
if not required:
return None
raise errors.MissingRequiredFileError(
'Required file %r is missing.' % name)
def get_param_as_list(self, name, required=False, default=None):
if name in self.input:
@ -23,7 +31,8 @@ class Context(object):
return param
if not required:
return default
raise errors.ValidationError('Required paramter %r is missing.' % name)
raise errors.MissingRequiredParameterError(
'Required paramter %r is missing.' % name)
def get_param_as_string(self, name, required=False, default=None):
if name in self.input:
@ -32,12 +41,14 @@ class Context(object):
try:
param = ','.join(param)
except:
raise errors.ValidationError(
'Parameter %r is invalid - expected simple string.' % name)
raise errors.InvalidParameterError(
'Parameter %r is invalid - expected simple string.'
% name)
return param
if not required:
return default
raise errors.ValidationError('Required paramter %r is missing.' % name)
raise errors.MissingRequiredParameterError(
'Required paramter %r is missing.' % name)
# pylint: disable=redefined-builtin,too-many-arguments
def get_param_as_int(
@ -47,21 +58,21 @@ class Context(object):
try:
val = int(val)
except (ValueError, TypeError):
raise errors.ValidationError(
raise errors.InvalidParameterError(
'Parameter %r is invalid: the value must be an integer.'
% name)
if min is not None and val < min:
raise errors.ValidationError(
raise errors.InvalidParameterError(
'Parameter %r is invalid: the value must be at least %r.'
% (name, min))
if max is not None and val > max:
raise errors.ValidationError(
raise errors.InvalidParameterError(
'Parameter %r is invalid: the value may not exceed %r.'
% (name, max))
return val
if not required:
return default
raise errors.ValidationError(
raise errors.MissingRequiredParameterError(
'Required parameter %r is missing.' % name)
class Request(falcon.Request):

View File

@ -1,6 +1,29 @@
from szurubooru.api.base_api import BaseApi
from szurubooru.func import auth, tags, posts, snapshots, favorites, scores
class PostListApi(BaseApi):
def post(self, ctx):
auth.verify_privilege(ctx.user, 'posts:create')
content = ctx.get_file('content', required=True)
tag_names = ctx.get_param_as_list('tags', required=True)
safety = ctx.get_param_as_string('safety', required=True)
source = ctx.get_param_as_string('source', required=False, default=None)
relations = ctx.get_param_as_list('relations', required=False) or []
notes = ctx.get_param_as_list('notes', required=False) or []
flags = ctx.get_param_as_list('flags', required=False) or []
post = posts.create_post(content, tag_names, ctx.user)
posts.update_post_safety(post, safety)
posts.update_post_source(post, source)
posts.update_post_relations(post, relations)
posts.update_post_notes(post, notes)
posts.update_post_flags(post, flags)
ctx.session.add(post)
snapshots.save_entity_creation(post, ctx.user)
ctx.session.commit()
tags.export_to_json()
return posts.serialize_post_with_details(post, ctx.user)
class PostDetailApi(BaseApi):
def get(self, ctx, post_id):
auth.verify_privilege(ctx.user, 'posts:view')

View File

@ -65,6 +65,7 @@ def create_app():
app.add_route('/tag-merge/', api.TagMergeApi())
app.add_route('/tag-siblings/{tag_name}', api.TagSiblingsApi())
app.add_route('/posts/', api.PostListApi())
app.add_route('/post/{post_id}', api.PostDetailApi())
app.add_route('/post/{post_id}/score', api.PostScoreApi())
app.add_route('/post/{post_id}/favorite', api.PostFavoriteApi())

View File

@ -71,26 +71,29 @@ class Post(Base):
SAFETY_SAFE = 'safe'
SAFETY_SKETCHY = 'sketchy'
SAFETY_UNSAFE = 'unsafe'
TYPE_IMAGE = 'anim'
TYPE_ANIMATION = 'anim'
TYPE_FLASH = 'flash'
TYPE_IMAGE = 'image'
TYPE_ANIMATION = 'animation'
TYPE_VIDEO = 'video'
TYPE_YOUTUBE = 'youtube'
FLAG_LOOP_VIDEO = 1
TYPE_FLASH = 'flash'
# basic meta
post_id = Column('id', Integer, primary_key=True)
user_id = Column('user_id', Integer, ForeignKey('user.id'))
creation_time = Column('creation_time', DateTime, nullable=False)
last_edit_time = Column('last_edit_time', DateTime)
safety = Column('safety', String(32), nullable=False)
source = Column('source', String(200))
flags = Column('flags', PickleType, default=None)
# content description
type = Column('type', String(32), nullable=False)
checksum = Column('checksum', String(64), nullable=False)
source = Column('source', String(200))
file_size = Column('file_size', Integer)
canvas_width = Column('image_width', Integer)
canvas_height = Column('image_height', Integer)
flags = Column('flags', PickleType, default=None)
mime_type = Column('mime-type', String(32), nullable=False)
# foreign tables
user = relationship('User')
tags = relationship('Tag', backref='posts', secondary='post_tag')
relations = relationship(
@ -106,8 +109,9 @@ class Post(Base):
'PostFavorite', cascade='all, delete-orphan', lazy='joined')
notes = relationship(
'PostNote', cascade='all, delete-orphan', lazy='joined')
comments = relationship('Comment')
# dynamic columns
tag_count = column_property(
select([func.count(PostTag.tag_id)]) \
.where(PostTag.post_id == post_id) \

View File

@ -5,3 +5,7 @@ class ValidationError(RuntimeError): pass
class SearchError(RuntimeError): pass
class NotFoundError(RuntimeError): pass
class ProcessingError(RuntimeError): pass
class MissingRequiredFileError(ValidationError): pass
class MissingRequiredParameterError(ValidationError): pass
class InvalidParameterError(ValidationError): pass

View File

@ -1,8 +1,23 @@
import os
from szurubooru import config
def _get_full_path(path):
return os.path.join(config.config['data_dir'], path)
def delete(path):
full_path = _get_full_path(path)
if os.path.exists(full_path):
os.unlink(full_path)
def get(path):
full_path = _get_full_path(path)
if not os.path.exists(full_path):
return None
with open(full_path, 'rb') as handle:
return handle.read()
def save(path, content):
full_path = os.path.join(config.config['data_dir'], path)
full_path = _get_full_path(path)
os.makedirs(os.path.dirname(full_path), exist_ok=True)
with open(full_path, 'wb') as handle:
handle.write(content)

View File

@ -1,3 +1,4 @@
import json
import subprocess
from szurubooru import errors
@ -7,6 +8,19 @@ _SCALE_FIT_FMT = \
class Image(object):
def __init__(self, content):
self.content = content
self._reload_info()
@property
def width(self):
return self.info['streams'][0]['width']
@property
def height(self):
return self.info['streams'][0]['height']
@property
def frames(self):
return self.info['streams'][0]['nb_read_frames']
def resize_fill(self, width, height):
self.content = self._execute([
@ -17,6 +31,7 @@ class Image(object):
'-vcodec', 'png',
'-',
])
self._reload_info()
def to_png(self):
return self._execute([
@ -36,9 +51,9 @@ class Image(object):
'-',
])
def _execute(self, cli):
def _execute(self, cli, program='ffmpeg'):
proc = subprocess.Popen(
['ffmpeg', '-loglevel', '24'] + cli,
[program, '-loglevel', '24'] + cli,
stdout=subprocess.PIPE,
stdin=subprocess.PIPE,
stderr=subprocess.PIPE)
@ -47,3 +62,15 @@ class Image(object):
raise errors.ProcessingError(
'Error while processing image.\n' + err.decode('utf-8'))
return out
def _reload_info(self):
self.info = json.loads(self._execute([
'-of', 'json',
'-select_streams', 'v',
'-show_streams',
'-count_frames',
'-i', '-',
], program='ffprobe').decode('utf-8'))
assert 'streams' in self.info
if len(self.info['streams']) != 1:
raise errors.ProcessingError('Multiple video streams detected.')

View File

@ -2,7 +2,7 @@ import re
def get_mime_type(content):
if not content:
return None
return 'application/octet-stream'
if content[0:3] in (b'CWS', b'FWS', b'ZWS'):
return 'application/x-shockwave-flash'
@ -33,7 +33,7 @@ def get_extension(mime_type):
'video/mp4': 'mp4',
'video/webm': 'webm',
}
return extension_map.get(mime_type.strip().lower(), None)
return extension_map.get((mime_type or '').strip().lower(), None)
def is_flash(mime_type):
return mime_type.lower() == 'application/x-shockwave-flash'

View File

@ -1,10 +1,62 @@
import datetime
import sqlalchemy
from szurubooru import db, errors
from szurubooru.func import users, snapshots, scores, comments
from szurubooru import config, db, errors
from szurubooru.func import (
users, snapshots, scores, comments, tags, util, mime, images, files)
EMPTY_PIXEL = \
b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00' \
b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \
b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b'
class PostNotFoundError(errors.NotFoundError): pass
class PostAlreadyFeaturedError(errors.ValidationError): pass
class PostAlreadyUploadedError(errors.ValidationError): pass
class InvalidPostSafetyError(errors.ValidationError): pass
class InvalidPostSourceError(errors.ValidationError): pass
class InvalidPostContentError(errors.ValidationError): pass
class InvalidPostRelationError(errors.ValidationError): pass
class InvalidPostNoteError(errors.ValidationError): pass
class InvalidPostFlagError(errors.ValidationError): pass
SAFETY_MAP = {
db.Post.SAFETY_SAFE: 'safe',
db.Post.SAFETY_SKETCHY: 'sketchy',
db.Post.SAFETY_UNSAFE: 'unsafe',
}
TYPE_MAP = {
db.Post.TYPE_IMAGE: 'image',
db.Post.TYPE_ANIMATION: 'animation',
db.Post.TYPE_VIDEO: 'video',
db.Post.TYPE_FLASH: 'flash',
}
def get_post_content_url(post):
return '%s/posts/%d.%s' % (
config.config['data_url'].rstrip('/'),
post.post_id,
mime.get_extension(post.mime_type) or 'dat')
def get_post_thumbnail_url(post):
return '%s/generated-thumbnails/%d.jpg' % (
config.config['data_url'].rstrip('/'),
post.post_id)
def get_post_content_path(post):
return 'posts/%d.%s' % (
post.post_id, mime.get_extension(post.mime_type) or 'dat')
def get_post_thumbnail_path(post):
return 'generated-thumbnails/%d.jpg' % (post.post_id)
def get_post_thumbnail_backup_path(post):
return 'posts/custom-thumbnails/%d.dat' % (post.post_id)
def serialize_note(note):
return {
'polygon': note.path,
'text': note.text,
}
def serialize_post(post, authenticated_user):
if not post:
@ -14,20 +66,19 @@ def serialize_post(post, authenticated_user):
'id': post.post_id,
'creationTime': post.creation_time,
'lastEditTime': post.last_edit_time,
'safety': post.safety,
'type': post.type,
'checksum': post.checksum,
'safety': SAFETY_MAP[post.safety],
'source': post.source,
'type': TYPE_MAP[post.type],
'checksum': post.checksum,
'fileSize': post.file_size,
'canvasWidth': post.canvas_width,
'canvasHeight': post.canvas_height,
'contentUrl': get_post_content_url(post),
'thumbnailUrl': get_post_thumbnail_url(post),
'flags': post.flags,
'tags': [tag.first_name for tag in post.tags],
'relations': [rel.post_id for rel in post.relations],
'notes': sorted([{
'path': note.path,
'text': note.text,
} for note in post.notes]),
'notes': sorted(serialize_note(note) for note in post.notes),
'user': users.serialize_user(post.user, authenticated_user),
'score': post.score,
'featureCount': post.feature_count,
@ -75,6 +126,140 @@ def try_get_featured_post():
.first()
return post_feature.post if post_feature else None
def create_post(content, tag_names, user):
post = db.Post()
post.safety = db.Post.SAFETY_SAFE
post.user = user
post.creation_time = datetime.datetime.now()
post.flags = []
# we'll need post ID
post.type = ''
post.checksum = ''
post.mime_type = ''
db.session.add(post)
db.session.flush()
update_post_content(post, content)
update_post_tags(post, tag_names)
return post
def update_post_safety(post, safety):
safety = util.flip(SAFETY_MAP).get(safety, None)
if not safety:
raise InvalidPostSafetyError(
'Safety can be either of %r.', list(SAFETY_MAP.values()))
post.safety = safety
def update_post_source(post, source):
if util.value_exceeds_column_size(source, db.Post.source):
raise InvalidPostSourceError('Source is too long.')
post.source = source
def update_post_content(post, content):
if not content:
raise InvalidPostContentError('Post content missing.')
post.mime_type = mime.get_mime_type(content)
if mime.is_flash(post.mime_type):
post.type = db.Post.TYPE_FLASH
elif mime.is_image(post.mime_type):
if mime.is_animated_gif(content):
post.type = db.Post.TYPE_ANIMATION
else:
post.type = db.Post.TYPE_IMAGE
elif mime.is_video(post.mime_type):
post.type = db.Post.TYPE_VIDEO
else:
raise InvalidPostContentError('Unhandled file type: %r' % post.mime_type)
post.checksum = util.get_md5(content)
other_post = db.session \
.query(db.Post) \
.filter(db.Post.checksum == post.checksum) \
.filter(db.Post.post_id != post.post_id) \
.one_or_none()
if other_post:
raise PostAlreadyUploadedError(
'Post already uploaded (%d)' % other_post.post_id)
post.file_size = len(content)
try:
image = images.Image(content)
post.canvas_width = image.width
post.canvas_height = image.height
except errors.ProcessingError:
post.canvas_width = None
post.canvas_height = None
files.save(get_post_content_path(post), content)
update_post_thumbnail(post, content=None, delete=False)
def update_post_thumbnail(post, content=None, delete=True):
if content is None:
content = files.get(get_post_content_path(post))
if delete:
files.delete(get_post_thumbnail_backup_path(post))
else:
files.save(get_post_thumbnail_backup_path(post), content)
try:
image = images.Image(content)
image.resize_fill(
int(config.config['thumbnails']['post_width']),
int(config.config['thumbnails']['post_height']))
files.save(get_post_thumbnail_path(post), image.to_jpeg())
except errors.ProcessingError:
files.save(get_post_thumbnail_path(post), EMPTY_PIXEL)
def update_post_tags(post, tag_names):
existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names)
post.tags = existing_tags + new_tags
def update_post_relations(post, post_ids):
relations = db.session \
.query(db.Post) \
.filter(db.Post.post_id.in_(post_ids)) \
.all()
if len(relations) != len(post_ids):
raise InvalidPostRelationError('One of relations does not exist.')
post.relations = relations
def update_post_notes(post, notes):
post.notes = []
for note in notes:
for field in ('polygon', 'text'):
if field not in note:
raise InvalidPostNoteError('Note is missing %r field.' % field)
if not note['text']:
raise InvalidPostNoteError('A note\'s text cannot be empty.')
if len(note['polygon']) < 3:
raise InvalidPostNoteError(
'A note\'s polygon must have at least 3 points.')
for point in note['polygon']:
if len(point) != 2:
raise InvalidPostNoteError(
'A point in note\'s polygon must have two coordinates.')
try:
pos_x = float(point[0])
pos_y = float(point[1])
if not 0 <= pos_x <= 1 or not 0 <= pos_y <= 1:
raise InvalidPostNoteError(
'A point in note\'s polygon must be in 0..1 range.')
except ValueError:
raise InvalidPostNoteError(
'A point in note\'s polygon must be numeric.')
if util.value_exceeds_column_size(note['text'], db.PostNote.text):
raise InvalidPostNoteError('Note text is too long.')
post.notes.append(
db.PostNote(polygon=note['polygon'], text=note['text']))
def update_post_flags(post, flags):
available_flags = ('loop',)
for flag in flags:
if flag not in available_flags:
raise InvalidPostFlagError(
'Flag must be one of %r.' % available_flags)
post.flags = flags
def feature_post(post, user):
post_feature = db.PostFeature()
post_feature.time = datetime.datetime.now()

View File

@ -77,7 +77,7 @@ def try_get_default_category():
.query(db.TagCategory) \
.order_by(db.TagCategory.tag_category_id.asc()) \
.limit(1) \
.one()
.first()
def get_default_category():
category = try_get_default_category()

View File

@ -12,6 +12,9 @@ class TagIsInUseError(errors.ValidationError): pass
class InvalidTagNameError(errors.ValidationError): pass
class InvalidTagRelationError(errors.ValidationError): pass
DEFAULT_CATEGORY_NAME = 'Default'
DEFAULT_CATEGORY_COLOR = 'default'
def _verify_name_validity(name):
name_regex = config.config['tag_name_regex']
if not re.match(name_regex, name):
@ -26,6 +29,13 @@ def _lower_list(names):
def _check_name_intersection(names1, names2):
return len(set(_lower_list(names1)).intersection(_lower_list(names2))) > 0
def _get_default_category_name():
tag_category = tag_categories.try_get_default_category()
if tag_category:
return tag_category.name
else:
return DEFAULT_CATEGORY_NAME
def serialize_tag(tag):
return {
'names': [tag_name.name for tag_name in tag.names],
@ -104,23 +114,24 @@ def get_or_create_tags_by_names(names):
names = util.icase_unique(names)
for name in names:
_verify_name_validity(name)
related_tags = get_tags_by_names(names)
existing_tags = get_tags_by_names(names)
new_tags = []
tag_category_name = _get_default_category_name()
for name in names:
found = False
for related_tag in related_tags:
if _check_name_intersection(_get_plain_names(related_tag), [name]):
for existing_tag in existing_tags:
if _check_name_intersection(_get_plain_names(existing_tag), [name]):
found = True
break
if not found:
new_tag = create_tag(
names=[name],
category_name=tag_categories.get_default_category().name,
category_name=tag_category_name,
suggestions=[],
implications=[])
db.session.add(new_tag)
new_tags.append(new_tag)
return related_tags, new_tags
return existing_tags, new_tags
def get_tag_siblings(tag):
tag_alias = sqlalchemy.orm.aliased(db.Tag)
@ -159,7 +170,8 @@ def update_tag_category_name(tag, category_name):
.filter(db.TagCategory.name == category_name) \
.first()
if not category:
category = tag_categories.create_category(category_name, 'default')
category = tag_categories.create_category(
category_name, DEFAULT_CATEGORY_COLOR)
db.session.add(category)
tag.category = category

View File

@ -1,5 +1,4 @@
import datetime
import hashlib
import re
from sqlalchemy import func
from szurubooru import config, db, errors
@ -28,11 +27,9 @@ def serialize_user(user, authenticated_user, force_show_email=False):
}
if user.avatar_style == user.AVATAR_GRAVATAR:
md5 = hashlib.md5()
md5.update((user.email or user.name).lower().encode('utf-8'))
digest = md5.hexdigest()
ret['avatarUrl'] = 'http://gravatar.com/avatar/%s?d=retro&s=%d' % (
digest, config.config['thumbnails']['avatar_width'])
util.get_md5((user.email or user.name).lower()),
config.config['thumbnails']['avatar_width'])
else:
ret['avatarUrl'] = '%s/avatars/%s.jpg' % (
config.config['data_url'].rstrip('/'), user.name.lower())

View File

@ -1,8 +1,19 @@
import datetime
import hashlib
import re
from sqlalchemy.inspection import inspect
from szurubooru.errors import ValidationError
def get_md5(source):
if not isinstance(source, bytes):
source = source.encode('utf-8')
md5 = hashlib.md5()
md5.update(source)
return md5.hexdigest()
def flip(source):
return {v: k for k, v in source.items()}
def get_resource_info(entity):
serializers = {
'tag': lambda tag: tag.first_name,
@ -96,4 +107,9 @@ def icase_unique(source):
return target
def value_exceeds_column_size(value, column):
return len(value) > column.property.columns[0].type.length
if not value:
return False
max_length = column.property.columns[0].type.length
if max_length is None:
return False
return len(value) > max_length

View File

@ -0,0 +1,20 @@
'''
Add mime type to posts
Revision ID: 23abaf4a0a4b
Created at: 2016-05-02 00:02:33.024885
'''
import sqlalchemy as sa
from alembic import op
revision = '23abaf4a0a4b'
down_revision = 'ed6dd16a30f3'
branch_labels = None
depends_on = None
def upgrade():
op.add_column('post', sa.Column('mime-type', sa.String(length=32), nullable=False))
def downgrade():
op.drop_column('post', 'mime-type')

View File

@ -6,6 +6,7 @@ from szurubooru.func import util, posts
@pytest.fixture
def test_ctx(config_injector, context_factory, post_factory, user_factory):
config_injector({
'data_url': 'http://example.com',
'ranks': ['anonymous', 'regular_user'],
'rank_names': {'anonymous': 'Peasant', 'regular_user': 'Lord'},
'privileges': {'comments:create': 'regular_user'},

View File

@ -6,6 +6,7 @@ from szurubooru.func import util, comments, scores
@pytest.fixture
def test_ctx(config_injector, context_factory, user_factory, comment_factory):
config_injector({
'data_url': 'http://example.com',
'ranks': ['anonymous', 'regular_user', 'mod'],
'rank_names': {'anonymous': 'Peasant', 'regular_user': 'Lord'},
'privileges': {

View File

@ -6,6 +6,7 @@ from szurubooru.func import util, comments
@pytest.fixture
def test_ctx(context_factory, config_injector, user_factory, comment_factory):
config_injector({
'data_url': 'http://example.com',
'ranks': ['anonymous', 'regular_user', 'mod', 'admin'],
'rank_names': {'regular_user': 'Peasant'},
'privileges': {

View File

@ -6,6 +6,7 @@ from szurubooru.func import util, comments
@pytest.fixture
def test_ctx(config_injector, context_factory, user_factory, comment_factory):
config_injector({
'data_url': 'http://example.com',
'ranks': ['anonymous', 'regular_user', 'mod'],
'rank_names': {'anonymous': 'Peasant', 'regular_user': 'Lord', 'mod': 'King'},
'privileges': {

View File

@ -0,0 +1,133 @@
import datetime
import os
import unittest.mock
import pytest
from szurubooru import api, db, errors
from szurubooru.func import posts, tags, snapshots
@pytest.fixture(autouse=True)
def inject_config(config_injector):
config_injector({
'ranks': ['anonymous', 'regular_user'],
'privileges': {'posts:create': 'regular_user'},
})
def test_creating_minimal_posts(
context_factory, post_factory, user_factory):
auth_user = user_factory(rank='regular_user')
post = post_factory()
db.session.add(post)
db.session.flush()
with unittest.mock.patch('szurubooru.func.posts.create_post'), \
unittest.mock.patch('szurubooru.func.posts.update_post_safety'), \
unittest.mock.patch('szurubooru.func.posts.update_post_source'), \
unittest.mock.patch('szurubooru.func.posts.update_post_relations'), \
unittest.mock.patch('szurubooru.func.posts.update_post_notes'), \
unittest.mock.patch('szurubooru.func.posts.update_post_flags'), \
unittest.mock.patch('szurubooru.func.posts.serialize_post_with_details'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'), \
unittest.mock.patch('szurubooru.func.snapshots.save_entity_creation'):
posts.create_post.return_value = post
posts.serialize_post_with_details.return_value = 'serialized post'
result = api.PostListApi().post(
context_factory(
input={
'safety': 'safe',
'tags': ['tag1', 'tag2'],
},
files={
'content': 'post-content',
},
user=auth_user))
assert result == 'serialized post'
posts.create_post.assert_called_once_with(
'post-content', ['tag1', 'tag2'], auth_user)
posts.update_post_safety.assert_called_once_with(post, 'safe')
posts.update_post_source.assert_called_once_with(post, None)
posts.update_post_relations.assert_called_once_with(post, [])
posts.update_post_notes.assert_called_once_with(post, [])
posts.update_post_flags.assert_called_once_with(post, [])
posts.serialize_post_with_details.assert_called_once_with(post, auth_user)
tags.export_to_json.assert_called_once_with()
snapshots.save_entity_creation.assert_called_once_with(post, auth_user)
def test_creating_full_posts(context_factory, post_factory, user_factory):
auth_user = user_factory(rank='regular_user')
post = post_factory()
db.session.add(post)
db.session.flush()
with unittest.mock.patch('szurubooru.func.posts.create_post'), \
unittest.mock.patch('szurubooru.func.posts.update_post_safety'), \
unittest.mock.patch('szurubooru.func.posts.update_post_source'), \
unittest.mock.patch('szurubooru.func.posts.update_post_relations'), \
unittest.mock.patch('szurubooru.func.posts.update_post_notes'), \
unittest.mock.patch('szurubooru.func.posts.update_post_flags'), \
unittest.mock.patch('szurubooru.func.posts.serialize_post_with_details'), \
unittest.mock.patch('szurubooru.func.tags.export_to_json'), \
unittest.mock.patch('szurubooru.func.snapshots.save_entity_creation'):
posts.create_post.return_value = post
posts.serialize_post_with_details.return_value = 'serialized post'
result = api.PostListApi().post(
context_factory(
input={
'safety': 'safe',
'tags': ['tag1', 'tag2'],
'relations': [1, 2],
'source': 'source',
'notes': ['note1', 'note2'],
'flags': ['flag1', 'flag2'],
},
files={
'content': 'post-content',
},
user=auth_user))
assert result == 'serialized post'
posts.create_post.assert_called_once_with(
'post-content', ['tag1', 'tag2'], auth_user)
posts.update_post_safety.assert_called_once_with(post, 'safe')
posts.update_post_source.assert_called_once_with(post, 'source')
posts.update_post_relations.assert_called_once_with(post, [1, 2])
posts.update_post_notes.assert_called_once_with(post, ['note1', 'note2'])
posts.update_post_flags.assert_called_once_with(post, ['flag1', 'flag2'])
posts.serialize_post_with_details.assert_called_once_with(post, auth_user)
tags.export_to_json.assert_called_once_with()
snapshots.save_entity_creation.assert_called_once_with(post, auth_user)
@pytest.mark.parametrize('field', ['tags', 'safety'])
def test_trying_to_omit_mandatory_field(context_factory, user_factory, field):
input = {
'safety': 'safe',
'tags': ['tag1', 'tag2'],
}
del input[field]
with pytest.raises(errors.MissingRequiredParameterError):
api.PostListApi().post(
context_factory(
input=input,
files={'content': '...'},
user=user_factory(rank='regular_user')))
def test_trying_to_omit_content(context_factory, user_factory):
with pytest.raises(errors.MissingRequiredFileError):
api.PostListApi().post(
context_factory(
input={
'safety': 'safe',
'tags': ['tag1', 'tag2'],
},
user=user_factory(rank='regular_user')))
def test_trying_to_create_without_privileges(context_factory, user_factory):
with pytest.raises(errors.AuthError):
api.PostListApi().post(
context_factory(
input={'name': 'meta', 'colro': 'black'},
user=user_factory(rank='anonymous')))

View File

@ -6,6 +6,7 @@ from szurubooru.func import util, posts
@pytest.fixture
def test_ctx(config_injector, context_factory, user_factory, post_factory):
config_injector({
'data_url': 'http://example.com',
'ranks': ['anonymous', 'regular_user', 'mod'],
'rank_names': {'anonymous': 'Peasant', 'regular_user': 'Lord'},
'privileges': {

View File

@ -6,6 +6,7 @@ from szurubooru.func import util, posts
@pytest.fixture
def test_ctx(context_factory, config_injector, user_factory, post_factory):
config_injector({
'data_url': 'http://example.com',
'privileges': {
'posts:feature': 'regular_user',
'posts:view': 'regular_user',

View File

@ -5,6 +5,7 @@ from szurubooru.func import util, posts, scores
@pytest.fixture
def test_ctx(config_injector, context_factory, user_factory, post_factory):
config_injector({
'data_url': 'http://example.com',
'ranks': ['anonymous', 'regular_user'],
'rank_names': {'anonymous': 'Peasant', 'regular_user': 'Lord'},
'privileges': {'posts:score': 'regular_user'},

View File

@ -6,6 +6,7 @@ from szurubooru.func import util, posts
@pytest.fixture
def test_ctx(context_factory, config_injector, user_factory, post_factory):
config_injector({
'data_url': 'http://example.com',
'privileges': {
'posts:list': 'regular_user',
'posts:view': 'regular_user',

Binary file not shown.

Before

Width:  |  Height:  |  Size: 14 B

After

Width:  |  Height:  |  Size: 43 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 107 B

After

Width:  |  Height:  |  Size: 12 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 100 B

Binary file not shown.

Before

Width:  |  Height:  |  Size: 67 B

Binary file not shown.

After

Width:  |  Height:  |  Size: 15 KiB

View File

@ -1,4 +1,5 @@
import contextlib
import os
import datetime
import uuid
import pytest
@ -136,6 +137,7 @@ def post_factory():
post.type = type
post.checksum = checksum
post.flags = []
post.mime_type = 'application/octet-stream'
post.creation_time = datetime.datetime(1996, 1, 1)
return post
return factory
@ -156,3 +158,11 @@ def comment_factory(user_factory, post_factory):
comment.creation_time = datetime.datetime(1996, 1, 1)
return comment
return factory
@pytest.fixture
def read_asset():
def get(path):
path = os.path.join(os.path.dirname(__file__), 'assets', path)
with open(path, 'rb') as handle:
return handle.read()
return get

View File

@ -13,6 +13,7 @@ def test_saving_post(post_factory, user_factory, tag_factory):
post.checksum = 'deadbeef'
post.creation_time = datetime(1997, 1, 1)
post.last_edit_time = datetime(1998, 1, 1)
post.mime_type = 'application/whatever'
db.session.add_all([user, tag1, tag2, related_post1, related_post2, post])
post.user = user

View File

@ -6,7 +6,7 @@ from szurubooru.func import mime
('mp4.mp4', 'video/mp4'),
('webm.webm', 'video/webm'),
('flash.swf', 'application/x-shockwave-flash'),
('png-transparent.png', 'image/png'),
('png.png', 'image/png'),
('jpeg.jpg', 'image/jpeg'),
('gif.gif', 'image/gif'),
])

View File

@ -0,0 +1,463 @@
import os
import datetime
import unittest.mock
import pytest
from szurubooru import db
from szurubooru.func import posts, users, comments, snapshots, tags, images
@pytest.mark.parametrize('input_mime_type,expected_url', [
('image/jpeg', 'http://example.com/posts/1.jpg'),
('image/gif', 'http://example.com/posts/1.gif'),
('totally/unknown', 'http://example.com/posts/1.dat'),
])
def test_get_post_url(input_mime_type, expected_url, config_injector):
config_injector({'data_url': 'http://example.com/'})
post = db.Post()
post.post_id = 1
post.mime_type = input_mime_type
assert posts.get_post_content_url(post) == expected_url
@pytest.mark.parametrize('input_mime_type', ['image/jpeg', 'image/gif'])
def test_get_post_thumbnail_url(input_mime_type, config_injector):
config_injector({'data_url': 'http://example.com/'})
post = db.Post()
post.post_id = 1
post.mime_type = input_mime_type
assert posts.get_post_thumbnail_url(post) \
== 'http://example.com/generated-thumbnails/1.jpg'
@pytest.mark.parametrize('input_mime_type,expected_path', [
('image/jpeg', 'posts/1.jpg'),
('image/gif', 'posts/1.gif'),
('totally/unknown', 'posts/1.dat'),
])
def test_get_post_content_path(input_mime_type, expected_path):
post = db.Post()
post.post_id = 1
post.mime_type = input_mime_type
assert posts.get_post_content_path(post) == expected_path
@pytest.mark.parametrize('input_mime_type', ['image/jpeg', 'image/gif'])
def test_get_post_thumbnail_path(input_mime_type):
post = db.Post()
post.post_id = 1
post.mime_type = input_mime_type
assert posts.get_post_thumbnail_path(post) == 'generated-thumbnails/1.jpg'
@pytest.mark.parametrize('input_mime_type', ['image/jpeg', 'image/gif'])
def test_get_post_thumbnail_backup_path(input_mime_type):
post = db.Post()
post.post_id = 1
post.mime_type = input_mime_type
assert posts.get_post_thumbnail_backup_path(post) \
== 'posts/custom-thumbnails/1.dat'
def test_serialize_note():
note = db.PostNote()
note.path = [[0, 1], [1, 1], [1, 0], [0, 0]]
note.text = '...'
assert posts.serialize_note(note) == {
'polygon': [[0, 1], [1, 1], [1, 0], [0, 0]],
'text': '...'
}
def test_serialize_empty_post():
assert posts.serialize_post(None, None) is None
def test_serialize_post(post_factory, user_factory, tag_factory):
with unittest.mock.patch('szurubooru.func.users.serialize_user'):
users.serialize_user.side_effect = lambda user, auth_user: user.name
auth_user = user_factory(name='auth user')
post = db.Post()
post.post_id = 1
post.creation_time = datetime.datetime(1997, 1, 1)
post.last_edit_time = datetime.datetime(1998, 1, 1)
post.tags = [
tag_factory(names=['tag1', 'tag2']),
tag_factory(names=['tag3'])
]
post.safety = db.Post.SAFETY_SAFE
post.source = '4gag'
post.type = db.Post.TYPE_IMAGE
post.checksum = 'deadbeef'
post.mime_type = 'image/jpeg'
post.file_size = 100
post.user = user_factory(name='post author')
post.canvas_width = 200
post.canvas_height = 300
post.flags = ['loop']
db.session.add(post)
db.session.flush()
db.session.add_all([
db.PostFavorite(
post=post,
user=user_factory(name='fav1'),
time=datetime.datetime(1800, 1, 1)),
db.PostFeature(
post=post,
user=user_factory(),
time=datetime.datetime(1999, 1, 1)),
db.PostScore(
post=post,
user=auth_user,
score=-1,
time=datetime.datetime(1800, 1, 1)),
db.PostScore(
post=post,
user=user_factory(),
score=1,
time=datetime.datetime(1800, 1, 1)),
db.PostScore(
post=post,
user=user_factory(),
score=1,
time=datetime.datetime(1800, 1, 1))])
db.session.flush()
result = posts.serialize_post(post, auth_user)
assert result == {
'id': 1,
'creationTime': datetime.datetime(1997, 1, 1),
'lastEditTime': datetime.datetime(1998, 1, 1),
'safety': 'safe',
'source': '4gag',
'type': 'image',
'checksum': 'deadbeef',
'fileSize': 100,
'canvasWidth': 200,
'canvasHeight': 300,
'contentUrl': 'http://example.com/posts/1.jpg',
'thumbnailUrl': 'http://example.com/generated-thumbnails/1.jpg',
'flags': ['loop'],
'tags': ['tag1', 'tag3'],
'relations': [],
'notes': [],
'user': 'post author',
'score': 1,
'ownScore': -1,
'featureCount': 1,
'lastFeatureTime': datetime.datetime(1999, 1, 1),
'favoritedBy': ['fav1'],
}
def test_serialize_post_with_details(post_factory, comment_factory, user_factory):
with unittest.mock.patch('szurubooru.func.comments.serialize_comment'), \
unittest.mock.patch('szurubooru.func.snapshots.get_serialized_history'), \
unittest.mock.patch('szurubooru.func.posts.serialize_post'):
comments.serialize_comment.side_effect \
= lambda comment, auth_user: comment.user.name
posts.serialize_post.side_effect \
= lambda post, auth_user: post.post_id
snapshots.get_serialized_history.return_value = 'snapshot history'
auth_user = user_factory(name='auth user')
post = post_factory()
post.comments = [
comment_factory(user=user_factory(name='commenter1')),
comment_factory(user=user_factory(name='commenter2')),
]
db.session.add(post)
db.session.flush()
result = posts.serialize_post_with_details(post, auth_user)
assert result == {
'post': post.post_id,
'snapshots': 'snapshot history',
'comments': ['commenter1', 'commenter2'],
}
def test_get_post_count(post_factory):
previous_count = posts.get_post_count()
db.session.add_all([post_factory(), post_factory()])
new_count = posts.get_post_count()
assert previous_count == 0
assert new_count == 2
def test_try_get_post_by_id(post_factory):
post = post_factory()
db.session.add(post)
db.session.flush()
assert posts.try_get_post_by_id(post.post_id) == post
assert posts.try_get_post_by_id(post.post_id + 1) is None
def test_get_post_by_id(post_factory):
post = post_factory()
db.session.add(post)
db.session.flush()
assert posts.get_post_by_id(post.post_id) == post
with pytest.raises(posts.PostNotFoundError):
posts.get_post_by_id(post.post_id + 1)
def test_create_post(user_factory, fake_datetime):
with unittest.mock.patch('szurubooru.func.posts.update_post_content'), \
unittest.mock.patch('szurubooru.func.posts.update_post_tags'), \
fake_datetime('1997-01-01'):
auth_user = user_factory()
post = posts.create_post('content', ['tag'], auth_user)
assert post.creation_time == datetime.datetime(1997, 1, 1)
assert post.last_edit_time is None
posts.update_post_tags.assert_called_once_with(post, ['tag'])
posts.update_post_content.assert_called_once_with(post, 'content')
@pytest.mark.parametrize('input_safety,expected_safety', [
('safe', db.Post.SAFETY_SAFE),
('sketchy', db.Post.SAFETY_SKETCHY),
('unsafe', db.Post.SAFETY_UNSAFE),
])
def test_update_post_safety(input_safety, expected_safety):
post = db.Post()
posts.update_post_safety(post, input_safety)
assert post.safety == expected_safety
def test_update_post_invalid_safety():
post = db.Post()
with pytest.raises(posts.InvalidPostSafetyError):
posts.update_post_safety(post, 'bad')
def test_update_post_source():
post = db.Post()
posts.update_post_source(post, 'x')
assert post.source == 'x'
def test_update_post_invalid_source():
post = db.Post()
with pytest.raises(posts.InvalidPostSourceError):
posts.update_post_source(post, 'x' * 1000)
@pytest.mark.parametrize(
'input_file,expected_mime_type,expected_type,output_file_name', [
('png.png', 'image/png', db.Post.TYPE_IMAGE, '1.png'),
('jpeg.jpg', 'image/jpeg', db.Post.TYPE_IMAGE, '1.jpg'),
('gif.gif', 'image/gif', db.Post.TYPE_IMAGE, '1.gif'),
('gif-animated.gif', 'image/gif', db.Post.TYPE_ANIMATION, '1.gif'),
('webm.webm', 'video/webm', db.Post.TYPE_VIDEO, '1.webm'),
('mp4.mp4', 'video/mp4', db.Post.TYPE_VIDEO, '1.mp4'),
('flash.swf', 'application/x-shockwave-flash', db.Post.TYPE_FLASH, '1.swf'),
])
def test_update_post_content(
tmpdir,
config_injector,
post_factory,
read_asset,
input_file,
expected_mime_type,
expected_type,
output_file_name):
with unittest.mock.patch('szurubooru.func.util.get_md5', return_value='crc'):
config_injector({
'data_dir': str(tmpdir.mkdir('data')),
'thumbnails': {
'post_width': 300,
'post_height': 300,
},
})
post = post_factory(id=1)
db.session.add(post)
db.session.flush()
posts.update_post_content(post, read_asset(input_file))
assert post.mime_type == expected_mime_type
assert post.type == expected_type
assert post.checksum == 'crc'
assert os.path.exists(str(tmpdir) + '/data/posts/' + output_file_name)
def test_update_post_content_to_existing_content(
tmpdir, config_injector, post_factory, read_asset):
config_injector({
'data_dir': str(tmpdir.mkdir('data')),
'thumbnails': {
'post_width': 300,
'post_height': 300,
},
})
post = post_factory()
another_post = post_factory()
db.session.add_all([post, another_post])
db.session.flush()
posts.update_post_content(post, read_asset('png.png'))
with pytest.raises(posts.PostAlreadyUploadedError):
posts.update_post_content(another_post, read_asset('png.png'))
def test_update_post_content_broken_content(
tmpdir, config_injector, post_factory, read_asset):
# the rationale behind this behavior is to salvage user upload even if the
# server software thinks it's broken. chances are the server is wrong,
# especially about flash movies.
config_injector({
'data_dir': str(tmpdir.mkdir('data')),
'thumbnails': {
'post_width': 300,
'post_height': 300,
},
})
post = post_factory()
another_post = post_factory()
db.session.add_all([post, another_post])
db.session.flush()
posts.update_post_content(post, read_asset('png-broken.png'))
assert post.canvas_width is None
assert post.canvas_height is None
@pytest.mark.parametrize('input_content', [None, b'not a media file'])
def test_update_post_invalid_content(input_content):
post = db.Post()
with pytest.raises(posts.InvalidPostContentError):
posts.update_post_content(post, input_content)
def test_update_post_thumbnail_to_new_one(
tmpdir, config_injector, read_asset, post_factory):
config_injector({
'data_dir': str(tmpdir.mkdir('data')),
'thumbnails': {
'post_width': 300,
'post_height': 300,
},
})
post = post_factory(id=1)
db.session.add(post)
db.session.flush()
posts.update_post_content(post, read_asset('png.png'))
posts.update_post_thumbnail(post, read_asset('jpeg.jpg'))
assert os.path.exists(str(tmpdir) + '/data/posts/custom-thumbnails/1.dat')
assert os.path.exists(str(tmpdir) + '/data/generated-thumbnails/1.jpg')
with open(str(tmpdir) + '/data/posts/custom-thumbnails/1.dat', 'rb') as handle:
assert handle.read() == read_asset('jpeg.jpg')
def test_update_post_thumbnail_to_default(
tmpdir, config_injector, read_asset, post_factory):
config_injector({
'data_dir': str(tmpdir.mkdir('data')),
'thumbnails': {
'post_width': 300,
'post_height': 300,
},
})
post = post_factory(id=1)
db.session.add(post)
db.session.flush()
posts.update_post_content(post, read_asset('png.png'))
posts.update_post_thumbnail(post, read_asset('jpeg.jpg'))
posts.update_post_thumbnail(post, None)
assert not os.path.exists(str(tmpdir) + '/data/posts/custom-thumbnails/1.dat')
assert os.path.exists(str(tmpdir) + '/data/generated-thumbnails/1.jpg')
def test_update_post_thumbnail_broken_thumbnail(
tmpdir, config_injector, read_asset, post_factory):
config_injector({
'data_dir': str(tmpdir.mkdir('data')),
'thumbnails': {
'post_width': 300,
'post_height': 300,
},
})
post = post_factory(id=1)
db.session.add(post)
db.session.flush()
posts.update_post_content(post, read_asset('png.png'))
posts.update_post_thumbnail(post, read_asset('png-broken.png'))
assert os.path.exists(str(tmpdir) + '/data/posts/custom-thumbnails/1.dat')
assert os.path.exists(str(tmpdir) + '/data/generated-thumbnails/1.jpg')
with open(str(tmpdir) + '/data/posts/custom-thumbnails/1.dat', 'rb') as handle:
assert handle.read() == read_asset('png-broken.png')
with open(str(tmpdir) + '/data/generated-thumbnails/1.jpg', 'rb') as handle:
image = images.Image(handle.read())
assert image.width == 1
assert image.height == 1
def test_update_post_content_leaves_custom_thumbnail(
tmpdir, config_injector, read_asset, post_factory):
config_injector({
'data_dir': str(tmpdir.mkdir('data')),
'thumbnails': {
'post_width': 300,
'post_height': 300,
},
})
post = post_factory(id=1)
db.session.add(post)
db.session.flush()
posts.update_post_content(post, read_asset('png.png'))
posts.update_post_thumbnail(post, read_asset('jpeg.jpg'))
posts.update_post_content(post, read_asset('png.png'))
assert os.path.exists(str(tmpdir) + '/data/posts/custom-thumbnails/1.dat')
assert os.path.exists(str(tmpdir) + '/data/generated-thumbnails/1.jpg')
def test_update_post_tags(tag_factory):
post = db.Post()
with unittest.mock.patch('szurubooru.func.tags.get_or_create_tags_by_names'):
tags.get_or_create_tags_by_names.side_effect \
= lambda tag_names: \
([tag_factory(names=[name]) for name in tag_names], [])
posts.update_post_tags(post, ['tag1', 'tag2'])
assert len(post.tags) == 2
assert post.tags[0].names[0].name == 'tag1'
assert post.tags[1].names[0].name == 'tag2'
def test_update_post_relations(post_factory):
relation1 = post_factory()
relation2 = post_factory()
db.session.add_all([relation1, relation2])
db.session.flush()
post = db.Post()
posts.update_post_relations(post, [relation1.post_id, relation2.post_id])
assert len(post.relations) == 2
assert post.relations[0].post_id == relation1.post_id
assert post.relations[1].post_id == relation2.post_id
def test_update_post_non_existing_relations():
post = db.Post()
with pytest.raises(posts.InvalidPostRelationError):
posts.update_post_relations(post, [100])
def test_update_post_notes():
post = db.Post()
posts.update_post_notes(
post,
[
{'polygon': [[0, 0], [0, 1], [1, 0], [0, 0]], 'text': 'text1'},
{'polygon': [[0, 0], [0, 1], [1, 0], [0, 0]], 'text': 'text2'},
])
assert len(post.notes) == 2
assert post.notes[0].polygon == [[0, 0], [0, 1], [1, 0], [0, 0]]
assert post.notes[0].text == 'text1'
assert post.notes[1].polygon == [[0, 0], [0, 1], [1, 0], [0, 0]]
assert post.notes[1].text == 'text2'
@pytest.mark.parametrize('input', [
[{'polygon': [[0, 0]], 'text': '...'}],
[{'polygon': [[0, 0], [0, 0], [0, 2]], 'text': '...'}],
[{'polygon': [[0, 0], [0, 0], [0, '...']], 'text': '...'}],
[{'polygon': [[0, 0], [0, 0], [0, 0, 0]], 'text': '...'}],
[{'polygon': [[0, 0], [0, 0], [0]], 'text': '...'}],
[{'polygon': [[0, 0], [0, 0], [0, 1]], 'text': ''}],
[{'polygon': [[0, 0], [0, 0], [0, 1]], 'text': None}],
[{'text': '...'}],
[{'polygon': [[0, 0], [0, 0], [0, 1]]}],
])
def test_update_post_invalid_notes(input):
post = db.Post()
with pytest.raises(posts.InvalidPostNoteError):
posts.update_post_notes(post, input)
def test_update_post_flags():
post = db.Post()
posts.update_post_flags(post, ['loop'])
assert post.flags == ['loop']
def test_update_post_invalid_flags():
post = db.Post()
with pytest.raises(posts.InvalidPostFlagError):
posts.update_post_flags(post, ['invalid'])
def test_featuring_post(post_factory, user_factory):
post = post_factory()
user = user_factory()
previous_featured_post = posts.try_get_featured_post()
posts.feature_post(post, user)
new_featured_post = posts.try_get_featured_post()
assert previous_featured_post is None
assert new_featured_post == post