858 lines
27 KiB
Python
858 lines
27 KiB
Python
import logging
|
|
import hmac
|
|
from typing import Any, Optional, Tuple, List, Dict, Callable
|
|
from datetime import datetime
|
|
import sqlalchemy as sa
|
|
from szurubooru import config, db, model, errors, rest
|
|
from szurubooru.func import (
|
|
users, scores, comments, tags, util,
|
|
mime, images, files, image_hash, serialization, snapshots)
|
|
|
|
|
|
logger = logging.getLogger(__name__)


# A minimal 1x1 GIF89a image (header bytes spell "GIF89a").  It is written
# in place of a thumbnail when thumbnail generation fails.
EMPTY_PIXEL = bytes.fromhex(
    '47494638396101000100800100000000'
    'ffffff21f90401000001002c00000000'
    '010001000002024c01003b')
|
|
|
|
|
|
class PostNotFoundError(errors.NotFoundError):
    """Raised when no post exists with the requested id."""
|
|
|
|
|
|
class PostAlreadyFeaturedError(errors.ValidationError):
    """Validation error for featuring a post that is already featured.

    Not raised in this module; presumably raised by the API layer.
    """
|
|
|
|
|
|
class PostAlreadyUploadedError(errors.ValidationError):
    """Raised when uploaded content has the same checksum as another post.

    The error details expose the existing post's content URL and id so the
    client can point the user at the duplicate.
    """

    def __init__(self, other_post: model.Post) -> None:
        message = 'Post already uploaded (%d)' % other_post.post_id
        details = {
            'otherPostUrl': get_post_content_url(other_post),
            'otherPostId': other_post.post_id,
        }
        super().__init__(message, details)
|
|
|
|
|
|
class InvalidPostIdError(errors.ValidationError):
    """Validation error for a malformed post id.

    Not raised in this module; presumably raised by callers.
    """
|
|
|
|
|
|
class InvalidPostSafetyError(errors.ValidationError):
    """Raised when a safety name is not one of the ``SAFETY_MAP`` values."""
|
|
|
|
|
|
class InvalidPostSourceError(errors.ValidationError):
    """Raised when a post source does not fit its database column."""
|
|
|
|
|
|
class InvalidPostContentError(errors.ValidationError):
    """Raised when post content is missing, unsupported, or unprocessable."""
|
|
|
|
|
|
class InvalidPostRelationError(errors.ValidationError):
    """Raised for invalid relations or an attempt to merge a post with itself."""
|
|
|
|
|
|
class InvalidPostNoteError(errors.ValidationError):
    """Raised when a post note's text or polygon fails validation."""
|
|
|
|
|
|
class InvalidPostFlagError(errors.ValidationError):
    """Raised when a flag name is not one of the ``FLAG_MAP`` values."""
|
|
|
|
|
|
# Model safety value -> public name.  Used when serializing; flipped via
# util.flip() to parse incoming names.  Order defines the error-message list.
SAFETY_MAP = {
    model.Post.SAFETY_SAFE: 'safe',
    model.Post.SAFETY_SKETCHY: 'sketchy',
    model.Post.SAFETY_UNSAFE: 'unsafe',
}

# Model post type value -> public name, used when serializing.
TYPE_MAP = {
    model.Post.TYPE_IMAGE: 'image',
    model.Post.TYPE_ANIMATION: 'animation',
    model.Post.TYPE_VIDEO: 'video',
    model.Post.TYPE_FLASH: 'flash',
}

# Model flag value -> public name; flipped via util.flip() to parse input.
FLAG_MAP = {
    model.Post.FLAG_LOOP: 'loop',
    model.Post.FLAG_SOUND: 'sound',
}
|
|
|
|
|
|
def get_post_security_hash(id: int) -> str:
    """Return a 16-hex-digit HMAC-MD5 tag for a post id.

    The HMAC is keyed with the configured ``secret``; the tag is embedded
    in the content and thumbnail file names/URLs built in this module.
    """
    secret_key = config.config['secret'].encode('utf8')
    message = str(id).encode('utf-8')
    mac = hmac.new(secret_key, msg=message, digestmod='md5')
    return mac.hexdigest()[:16]
|
|
|
|
|
|
def get_post_content_url(post: model.Post) -> str:
    """Return the public URL of the post's content file.

    The extension comes from the post's MIME type, falling back to ``dat``.
    """
    assert post
    base_url = config.config['data_url'].rstrip('/')
    security_hash = get_post_security_hash(post.post_id)
    extension = mime.get_extension(post.mime_type) or 'dat'
    return '%s/posts/%d_%s.%s' % (
        base_url, post.post_id, security_hash, extension)
|
|
|
|
|
|
def get_post_thumbnail_url(post: model.Post) -> str:
    """Return the public URL of the post's generated JPEG thumbnail."""
    assert post
    base_url = config.config['data_url'].rstrip('/')
    security_hash = get_post_security_hash(post.post_id)
    return '%s/generated-thumbnails/%d_%s.jpg' % (
        base_url, post.post_id, security_hash)
|
|
|
|
|
|
def get_post_content_path(post: model.Post) -> str:
    """Return the storage path of the post's content (as used with ``files``).

    Requires the post to already have an id.
    """
    assert post
    assert post.post_id
    security_hash = get_post_security_hash(post.post_id)
    extension = mime.get_extension(post.mime_type) or 'dat'
    return 'posts/%d_%s.%s' % (post.post_id, security_hash, extension)
|
|
|
|
|
|
def get_post_thumbnail_path(post: model.Post) -> str:
    """Return the storage path of the post's generated JPEG thumbnail."""
    assert post
    security_hash = get_post_security_hash(post.post_id)
    return 'generated-thumbnails/%d_%s.jpg' % (post.post_id, security_hash)
|
|
|
|
|
|
def get_post_thumbnail_backup_path(post: model.Post) -> str:
    """Return the storage path where a user-supplied thumbnail is kept."""
    assert post
    security_hash = get_post_security_hash(post.post_id)
    return 'posts/custom-thumbnails/%d_%s.dat' % (
        post.post_id, security_hash)
|
|
|
|
|
|
def serialize_note(note: model.PostNote) -> rest.Response:
    """Return the API dict (``polygon`` and ``text``) for a post note."""
    assert note
    return dict(polygon=note.polygon, text=note.text)
|
|
|
|
|
|
class PostSerializer(serialization.BaseSerializer):
    """Turns a ``model.Post`` into its REST API dictionary.

    Each public field name maps to a zero-argument method in
    ``_serializers``.  ``auth_user`` is the viewing user; it drives the
    per-viewer fields (``ownScore``, ``ownFavorite``) and is forwarded to
    the nested user/post/comment serializers.
    """

    def __init__(self, post: model.Post, auth_user: model.User) -> None:
        self.post = post
        self.auth_user = auth_user

    def _serializers(self) -> Dict[str, Callable[[], Any]]:
        """Return the mapping of API field name -> producing method."""
        return {
            'id': self.serialize_id,
            'version': self.serialize_version,
            'creationTime': self.serialize_creation_time,
            'lastEditTime': self.serialize_last_edit_time,
            'safety': self.serialize_safety,
            'source': self.serialize_source,
            'type': self.serialize_type,
            'mimeType': self.serialize_mime,
            'checksum': self.serialize_checksum,
            'fileSize': self.serialize_file_size,
            'canvasWidth': self.serialize_canvas_width,
            'canvasHeight': self.serialize_canvas_height,
            'contentUrl': self.serialize_content_url,
            'thumbnailUrl': self.serialize_thumbnail_url,
            'flags': self.serialize_flags,
            'tags': self.serialize_tags,
            'relations': self.serialize_relations,
            'user': self.serialize_user,
            'score': self.serialize_score,
            'ownScore': self.serialize_own_score,
            'ownFavorite': self.serialize_own_favorite,
            'tagCount': self.serialize_tag_count,
            'favoriteCount': self.serialize_favorite_count,
            'commentCount': self.serialize_comment_count,
            'noteCount': self.serialize_note_count,
            'relationCount': self.serialize_relation_count,
            'featureCount': self.serialize_feature_count,
            'lastFeatureTime': self.serialize_last_feature_time,
            'favoritedBy': self.serialize_favorited_by,
            'hasCustomThumbnail': self.serialize_has_custom_thumbnail,
            'notes': self.serialize_notes,
            'comments': self.serialize_comments,
        }

    def serialize_id(self) -> Any:
        return self.post.post_id

    def serialize_version(self) -> Any:
        return self.post.version

    def serialize_creation_time(self) -> Any:
        return self.post.creation_time

    def serialize_last_edit_time(self) -> Any:
        return self.post.last_edit_time

    def serialize_safety(self) -> Any:
        return SAFETY_MAP[self.post.safety]

    def serialize_source(self) -> Any:
        return self.post.source

    def serialize_type(self) -> Any:
        return TYPE_MAP[self.post.type]

    def serialize_mime(self) -> Any:
        return self.post.mime_type

    def serialize_checksum(self) -> Any:
        return self.post.checksum

    def serialize_file_size(self) -> Any:
        return self.post.file_size

    def serialize_canvas_width(self) -> Any:
        return self.post.canvas_width

    def serialize_canvas_height(self) -> Any:
        return self.post.canvas_height

    def serialize_content_url(self) -> Any:
        return get_post_content_url(self.post)

    def serialize_thumbnail_url(self) -> Any:
        return get_post_thumbnail_url(self.post)

    def serialize_flags(self) -> Any:
        return self.post.flags

    def serialize_tags(self) -> Any:
        # Tags in tags.sort_tags() order, each with all names and usage count.
        return [
            {
                'names': [name.name for name in tag.names],
                'category': tag.category.name,
                'usages': tag.post_count,
            }
            for tag in tags.sort_tags(self.post.tags)]

    def serialize_relations(self) -> Any:
        # Collapse duplicate related posts by id, then order them by id.
        return sorted(
            {
                post['id']: post
                for post in [
                    serialize_micro_post(rel, self.auth_user)
                    for rel in self.post.relations]
            }.values(),
            key=lambda post: post['id'])

    def serialize_user(self) -> Any:
        return users.serialize_micro_user(self.post.user, self.auth_user)

    def serialize_score(self) -> Any:
        return self.post.score

    def serialize_own_score(self) -> Any:
        return scores.get_score(self.post, self.auth_user)

    def serialize_own_favorite(self) -> Any:
        # Entries of ``favorited_by`` are favorite records carrying a
        # ``user_id``; stop at the first one belonging to the viewer.
        return any(
            favorite.user_id == self.auth_user.user_id
            for favorite in self.post.favorited_by)

    def serialize_tag_count(self) -> Any:
        return self.post.tag_count

    def serialize_favorite_count(self) -> Any:
        return self.post.favorite_count

    def serialize_comment_count(self) -> Any:
        return self.post.comment_count

    def serialize_note_count(self) -> Any:
        return self.post.note_count

    def serialize_relation_count(self) -> Any:
        return self.post.relation_count

    def serialize_feature_count(self) -> Any:
        return self.post.feature_count

    def serialize_last_feature_time(self) -> Any:
        return self.post.last_feature_time

    def serialize_favorited_by(self) -> Any:
        return [
            users.serialize_micro_user(favorite.user, self.auth_user)
            for favorite in self.post.favorited_by
        ]

    def serialize_has_custom_thumbnail(self) -> Any:
        return files.has(get_post_thumbnail_backup_path(self.post))

    def serialize_notes(self) -> Any:
        # Sorted by polygon so the output order is deterministic.
        return sorted(
            [serialize_note(note) for note in self.post.notes],
            key=lambda x: x['polygon'])

    def serialize_comments(self) -> Any:
        # Oldest comment first.
        return [
            comments.serialize_comment(comment, self.auth_user)
            for comment in sorted(
                self.post.comments,
                key=lambda comment: comment.creation_time)]
|
|
|
|
|
|
def serialize_post(
        post: Optional[model.Post],
        auth_user: model.User,
        options: Optional[List[str]] = None) -> Optional[rest.Response]:
    """Serialize ``post`` for ``auth_user``, or return ``None`` for no post.

    Args:
        post: the post to serialize; a falsy value yields ``None``.
        auth_user: the viewing user (affects per-user fields).
        options: API field names to include; ``None``/empty is passed to
            the serializer as an empty list.
    """
    if not post:
        return None
    # ``None`` default instead of a shared mutable ``[]`` default.
    fields = options if options is not None else []
    return PostSerializer(post, auth_user).serialize(fields)
|
|
|
|
|
|
def serialize_micro_post(
        post: model.Post, auth_user: model.User) -> Optional[rest.Response]:
    """Serialize just the ``id`` and ``thumbnailUrl`` fields of a post."""
    micro_fields = ['id', 'thumbnailUrl']
    return serialize_post(post, auth_user=auth_user, options=micro_fields)
|
|
|
|
|
|
def get_post_count() -> int:
    """Return the total number of posts."""
    count_query = db.session.query(sa.func.count(model.Post.post_id))
    row = count_query.one()
    return row[0]
|
|
|
|
|
|
def try_get_post_by_id(post_id: int) -> Optional[model.Post]:
    """Return the post with ``post_id``, or ``None`` when there is none."""
    query = db.session.query(model.Post)
    query = query.filter(model.Post.post_id == post_id)
    return query.one_or_none()
|
|
|
|
|
|
def get_post_by_id(post_id: int) -> model.Post:
    """Return the post with ``post_id``.

    Raises:
        PostNotFoundError: if no such post exists.
    """
    found = try_get_post_by_id(post_id)
    if found:
        return found
    raise PostNotFoundError('Post %r not found.' % post_id)
|
|
|
|
|
|
def try_get_current_post_feature() -> Optional[model.PostFeature]:
    """Return the most recent ``PostFeature`` record, or ``None``."""
    newest_first = model.PostFeature.time.desc()
    query = db.session.query(model.PostFeature).order_by(newest_first)
    return query.first()
|
|
|
|
|
|
def try_get_featured_post() -> Optional[model.Post]:
    """Return the post of the most recent feature, or ``None``."""
    current_feature = try_get_current_post_feature()
    if not current_feature:
        return None
    return current_feature.post
|
|
|
|
|
|
def create_post(
        content: bytes,
        tag_names: List[str],
        user: Optional[model.User]) -> Tuple[model.Post, List[model.Tag]]:
    """Build a new post from raw content and tags and add it to the session.

    The post starts as ``SAFETY_SAFE`` with no flags, owned by ``user``.

    Returns:
        ``(post, new_tags)`` where ``new_tags`` are the tags that
        ``update_post_tags`` reports as newly created.
    """
    new_post = model.Post()
    new_post.safety = model.Post.SAFETY_SAFE
    new_post.user = user
    new_post.creation_time = datetime.utcnow()
    new_post.flags = []

    # Placeholders; update_post_content() overwrites all three below.
    new_post.type = ''
    new_post.checksum = ''
    new_post.mime_type = ''

    update_post_content(new_post, content)
    created_tags = update_post_tags(new_post, tag_names)

    db.session.add(new_post)
    return new_post, created_tags
|
|
|
|
|
|
def update_post_safety(post: model.Post, safety: str) -> None:
    """Set the post's safety from its public name (see ``SAFETY_MAP``).

    Raises:
        InvalidPostSafetyError: if ``safety`` is not a known name.
    """
    assert post
    name_to_value = util.flip(SAFETY_MAP)
    value = name_to_value.get(safety, None)
    if not value:
        raise InvalidPostSafetyError(
            'Safety can be either of %r.' % list(SAFETY_MAP.values()))
    post.safety = value
|
|
|
|
|
|
def update_post_source(post: model.Post, source: Optional[str]) -> None:
    """Set the post's source; empty values are stored as ``None``.

    Raises:
        InvalidPostSourceError: if ``source`` exceeds the column size.
    """
    assert post
    too_long = util.value_exceeds_column_size(source, model.Post.source)
    if too_long:
        raise InvalidPostSourceError('Source is too long.')
    post.source = source if source else None
|
|
|
|
|
|
@sa.events.event.listens_for(model.Post, 'after_insert')
def _after_post_insert(
        _mapper: Any, _connection: Any, post: model.Post) -> None:
    """SQLAlchemy ``after_insert`` hook: write staged files for ``post``."""
    return _sync_post_content(post)
|
|
|
|
|
|
@sa.events.event.listens_for(model.Post, 'after_update')
def _after_post_update(
        _mapper: Any, _connection: Any, post: model.Post) -> None:
    """SQLAlchemy ``after_update`` hook: write staged files for ``post``."""
    return _sync_post_content(post)
|
|
|
|
|
|
@sa.events.event.listens_for(model.Post, 'before_delete')
def _before_post_delete(
        _mapper: Any, _connection: Any, post: model.Post) -> None:
    """SQLAlchemy ``before_delete`` hook: remove a post's stored files.

    Only acts on posts that have an id, and only when the
    ``delete_source_files`` option is enabled; then both the content file
    and the generated thumbnail are deleted.
    """
    if not post.post_id:
        return
    if not config.config['delete_source_files']:
        return
    files.delete(get_post_content_path(post))
    files.delete(get_post_thumbnail_path(post))
|
|
|
|
|
|
def _sync_post_content(post: model.Post) -> None:
    """Write bytes staged on ``post`` out to file storage.

    ``update_post_content`` and ``update_post_thumbnail`` stash bytes on
    the object under the ``__content`` / ``__thumbnail`` attributes.  This
    saves them, removes the staged attributes, and regenerates the
    thumbnail if either one was present.
    """
    needs_thumbnail = False

    if hasattr(post, '__content'):
        files.save(get_post_content_path(post), getattr(post, '__content'))
        delattr(post, '__content')
        needs_thumbnail = True

    if hasattr(post, '__thumbnail'):
        backup_path = get_post_thumbnail_backup_path(post)
        custom_thumbnail = getattr(post, '__thumbnail')
        if not custom_thumbnail:
            # A falsy staged thumbnail means "remove the custom thumbnail".
            files.delete(backup_path)
        else:
            files.save(backup_path, custom_thumbnail)
        delattr(post, '__thumbnail')
        needs_thumbnail = True

    if needs_thumbnail:
        generate_post_thumbnail(post)
|
|
|
|
|
|
def generate_alternate_formats(post: model.Post, content: bytes) \
        -> List[Tuple[model.Post, List[model.Tag]]]:
    """Create MP4/WebM sibling posts for an animated GIF upload.

    Which formats are produced is controlled by the ``convert.gif.to_mp4``
    and ``convert.gif.to_webm`` options.  Each new post gets the source
    post's tag names, user, safety and source, plus the ``loop`` flag, and
    is added to the source post's relations (existing relations are kept).

    Returns:
        A list of ``(new_post, newly_created_tags)`` pairs; empty when
        ``content`` is not an animated GIF or no conversion is enabled.
    """
    assert post
    assert content
    new_posts = []
    if not mime.is_animated_gif(content):
        return new_posts

    tag_names = [tag.first_name for tag in post.tags]
    gif_options = config.config['convert']['gif']
    converters = (
        ('to_mp4', lambda: images.Image(content).to_mp4()),
        ('to_webm', lambda: images.Image(content).to_webm()),
    )
    for option, convert in converters:
        if not gif_options[option]:
            continue
        alt_post, new_tags = create_post(convert(), tag_names, post.user)
        update_post_flags(alt_post, ['loop'])
        update_post_safety(alt_post, post.safety)
        update_post_source(alt_post, post.source)
        new_posts.append((alt_post, new_tags))

    # Flush so the new posts receive ids before they are related.
    db.session.flush()

    new_posts = [p for p in new_posts if p[0] is not None]

    new_relations = [p[0].post_id for p in new_posts]
    if new_relations:
        # update_post_relations() replaces the entire relation set, so keep
        # the relations the post already has instead of dropping them.
        existing_relations = [rel.post_id for rel in post.relations]
        update_post_relations(post, existing_relations + new_relations)

    return new_posts
|
|
|
|
|
|
def get_default_flags(content: bytes) -> List[str]:
    """Return the initial flags for uploaded ``content``.

    Videos get ``FLAG_LOOP``, plus ``FLAG_SOUND`` when
    ``images.Image(...).check_for_sound()`` reports sound.  Any other
    content gets no flags.
    """
    assert content
    if not mime.is_video(mime.get_mime_type(content)):
        return []
    flags = [model.Post.FLAG_LOOP]
    if images.Image(content).check_for_sound():
        flags.append(model.Post.FLAG_SOUND)
    return flags
|
|
|
|
|
|
def purge_post_signature(post: model.Post) -> None:
    """Bulk-delete all ``PostSignature`` rows belonging to ``post``."""
    matching = db.session.query(model.PostSignature).filter(
        model.PostSignature.post_id == post.post_id)
    matching.delete()
|
|
|
|
|
|
def generate_post_signature(post: model.Post, content: bytes) -> None:
    """Compute the image signature of ``content`` and add it to the session.

    The packed signature and its search words are stored in a new
    ``PostSignature`` linked to ``post``.  Nothing is returned.

    Raises:
        InvalidPostContentError: if hashing fails and the
            ``allow_broken_uploads`` option is disabled (otherwise the
            failure is ignored).
    """
    try:
        signature = image_hash.generate_signature(content)
        db.session.add(model.PostSignature(
            post=post,
            signature=image_hash.pack_signature(signature),
            words=image_hash.generate_words(signature)))
    except errors.ProcessingError:
        if config.config['allow_broken_uploads']:
            return
        raise InvalidPostContentError(
            'Unable to generate image hash data.')
|
|
|
|
|
|
def update_all_post_signatures() -> None:
    """Backfill signatures for image/animation posts that lack one.

    Posts are processed in ascending id order and committed one by one.  A
    failure on one post is logged and its transaction rolled back, so the
    remaining posts are still processed.
    """
    posts_to_hash = (
        db.session
        .query(model.Post)
        .filter(
            (model.Post.type == model.Post.TYPE_IMAGE) |
            (model.Post.type == model.Post.TYPE_ANIMATION))
        .filter(model.Post.signature == None)  # noqa: E711
        .order_by(model.Post.post_id.asc())
        .all())
    for post in posts_to_hash:
        post_id = post.post_id
        try:
            generate_post_signature(
                post, files.get(get_post_content_path(post)))
            db.session.commit()
            logger.info('Hashed Post %d', post_id)
        except Exception:
            # A failed flush/commit leaves the session unusable until it is
            # rolled back; without this every later post would fail too.
            db.session.rollback()
            logger.exception('Failed to hash Post %d', post_id)
|
|
|
|
|
|
def update_post_content(post: model.Post, content: Optional[bytes]) -> None:
    """Replace a post's content, deriving all content-based metadata.

    Sets ``mime_type``, ``type``, ``checksum``, ``file_size`` and canvas
    dimensions from ``content``.  For images, the stored signature is
    regenerated.  The bytes themselves are only staged on the object
    (``__content``); ``_sync_post_content`` writes them out from the
    post's after-insert/after-update hooks.

    Raises:
        InvalidPostContentError: content is missing, of an unhandled type,
            or (unless ``allow_broken_uploads``) unprocessable.
        PostAlreadyUploadedError: another post has the same checksum.
    """
    assert post
    if not content:
        raise InvalidPostContentError('Post content missing.')

    # Only still/animated images get a perceptual signature.
    update_signature = False
    post.mime_type = mime.get_mime_type(content)
    if mime.is_flash(post.mime_type):
        post.type = model.Post.TYPE_FLASH
    elif mime.is_image(post.mime_type):
        update_signature = True
        if mime.is_animated_gif(content):
            post.type = model.Post.TYPE_ANIMATION
        else:
            post.type = model.Post.TYPE_IMAGE
    elif mime.is_video(post.mime_type):
        post.type = model.Post.TYPE_VIDEO
    else:
        raise InvalidPostContentError(
            'Unhandled file type: %r' % post.mime_type)

    # Reject exact duplicates of other posts (SHA-1 of the raw bytes).
    # NOTE(review): one_or_none() raises if several posts already share
    # this checksum -- confirm that state cannot occur.
    post.checksum = util.get_sha1(content)
    other_post = (
        db.session
        .query(model.Post)
        .filter(model.Post.checksum == post.checksum)
        .filter(model.Post.post_id != post.post_id)
        .one_or_none())
    if other_post \
            and other_post.post_id \
            and other_post.post_id != post.post_id:
        raise PostAlreadyUploadedError(other_post)

    if update_signature:
        purge_post_signature(post)
        # generate_post_signature() returns None; the new PostSignature is
        # linked through its own ``post=`` argument, so this assigns None.
        # NOTE(review): confirm this assignment is intentional.
        post.signature = generate_post_signature(post, content)

    post.file_size = len(content)
    try:
        image = images.Image(content)
        post.canvas_width = image.width
        post.canvas_height = image.height
    except errors.ProcessingError:
        if not config.config['allow_broken_uploads']:
            raise InvalidPostContentError(
                'Unable to process image metadata')
        else:
            post.canvas_width = None
            post.canvas_height = None
    # Non-positive dimensions are treated like a processing failure.
    if (post.canvas_width is not None and post.canvas_width <= 0) \
            or (post.canvas_height is not None and post.canvas_height <= 0):
        if not config.config['allow_broken_uploads']:
            raise InvalidPostContentError(
                'Invalid image dimensions returned during processing')
        else:
            post.canvas_width = None
            post.canvas_height = None
    # Staged for _sync_post_content(), which saves it after the DB write.
    setattr(post, '__content', content)
|
|
|
|
|
|
def update_post_thumbnail(
        post: model.Post, content: Optional[bytes] = None) -> None:
    """Stage a custom thumbnail for ``post``, or its removal if falsy.

    The bytes are only stored on the object (``__thumbnail``);
    ``_sync_post_content`` saves or deletes the backup file later.
    """
    assert post
    post.__dict__['__thumbnail'] = content
|
|
|
|
|
|
def generate_post_thumbnail(post: model.Post) -> None:
    """Render the post's thumbnail JPEG and save it.

    The custom thumbnail backup is used when present, otherwise the post's
    content.  The image is resized to the configured
    ``thumbnails.post_width`` x ``thumbnails.post_height``.  On a
    processing error ``EMPTY_PIXEL`` is saved instead.
    """
    assert post
    backup_path = get_post_thumbnail_backup_path(post)
    if files.has(backup_path):
        source_path = backup_path
    else:
        source_path = get_post_content_path(post)
    content = files.get(source_path)
    try:
        assert content
        thumbnail = images.Image(content)
        thumbnail.resize_fill(
            int(config.config['thumbnails']['post_width']),
            int(config.config['thumbnails']['post_height']))
        files.save(get_post_thumbnail_path(post), thumbnail.to_jpeg())
    except errors.ProcessingError:
        files.save(get_post_thumbnail_path(post), EMPTY_PIXEL)
|
|
|
|
|
|
def update_post_tags(
        post: model.Post, tag_names: List[str]) -> List[model.Tag]:
    """Replace the post's tags with the tags named in ``tag_names``.

    Returns:
        The "new" tags reported by ``tags.get_or_create_tags_by_names``.
    """
    assert post
    found_tags, created_tags = tags.get_or_create_tags_by_names(tag_names)
    post.tags = found_tags + created_tags
    return created_tags
|
|
|
|
|
|
def update_post_relations(post: model.Post, new_post_ids: List[int]) -> None:
    """Make ``post``'s relations exactly the posts in ``new_post_ids``.

    Relations are kept symmetric: every added or removed related post gets
    the mirror change on its own ``relations`` list.  Repeated ids are
    treated as a single relation.

    Raises:
        InvalidPostRelationError: an id is not numeric, refers to a post
            that does not exist, or refers to ``post`` itself.
    """
    assert post
    try:
        # dict.fromkeys() drops repeated ids (keeping order) so a duplicate
        # is not mistaken for a missing post by the count check below.
        new_post_ids = list(dict.fromkeys(
            int(post_id) for post_id in new_post_ids))
    except (TypeError, ValueError):
        raise InvalidPostRelationError(
            'A relation must be numeric post ID.')
    old_posts = post.relations
    old_post_ids = {int(p.post_id) for p in old_posts}
    if new_post_ids:
        new_posts = (
            db.session
            .query(model.Post)
            .filter(model.Post.post_id.in_(new_post_ids))
            .all())
    else:
        new_posts = []
    if len(new_posts) != len(new_post_ids):
        raise InvalidPostRelationError('One of relations does not exist.')
    if post.post_id in new_post_ids:
        raise InvalidPostRelationError('Post cannot relate to itself.')

    new_post_id_set = set(new_post_ids)
    # Both lists are built before mutating ``post.relations`` (which is the
    # same list object as ``old_posts``).
    relations_to_del = [
        p for p in old_posts if p.post_id not in new_post_id_set]
    relations_to_add = [
        p for p in new_posts if p.post_id not in old_post_ids]
    for relation in relations_to_del:
        post.relations.remove(relation)
        relation.relations.remove(post)
    for relation in relations_to_add:
        post.relations.append(relation)
        relation.relations.append(post)
|
|
|
|
|
|
def update_post_notes(post: model.Post, notes: Any) -> None:
    """Replace the post's notes after validating each one.

    Each note must contain a non-empty ``text`` (within the column size)
    and a ``polygon``: a list/tuple of at least 3 points, each a
    list/tuple of two numeric coordinates in the 0..1 range.

    Raises:
        InvalidPostNoteError: if any note fails validation.
    """
    assert post
    post.notes = []
    for note in notes:
        for field in ('polygon', 'text'):
            if field not in note:
                raise InvalidPostNoteError('Note is missing %r field.' % field)
        if not note['text']:
            raise InvalidPostNoteError('A note\'s text cannot be empty.')
        if not isinstance(note['polygon'], (list, tuple)):
            raise InvalidPostNoteError(
                'A note\'s polygon must be a list of points.')
        if len(note['polygon']) < 3:
            raise InvalidPostNoteError(
                'A note\'s polygon must have at least 3 points.')
        for point in note['polygon']:
            if not isinstance(point, (list, tuple)):
                raise InvalidPostNoteError(
                    'A note\'s polygon point must be a list of length 2.')
            if len(point) != 2:
                raise InvalidPostNoteError(
                    'A point in note\'s polygon must have two coordinates.')
            try:
                pos_x = float(point[0])
                pos_y = float(point[1])
            except (TypeError, ValueError):
                # TypeError covers None / nested lists, which float()
                # rejects without a ValueError.
                raise InvalidPostNoteError(
                    'A point in note\'s polygon must be numeric.')
            if not 0 <= pos_x <= 1 or not 0 <= pos_y <= 1:
                raise InvalidPostNoteError(
                    'All points must fit in the image (0..1 range).')
        if util.value_exceeds_column_size(note['text'], model.PostNote.text):
            raise InvalidPostNoteError('Note text is too long.')
        post.notes.append(
            model.PostNote(polygon=note['polygon'], text=str(note['text'])))
|
|
|
|
|
|
def update_post_flags(post: model.Post, flags: List[str]) -> None:
    """Replace the post's flags with the given public flag names.

    Raises:
        InvalidPostFlagError: if a name is not a ``FLAG_MAP`` value.
    """
    assert post
    name_to_flag = util.flip(FLAG_MAP)
    target_flags = []
    for name in flags:
        flag = name_to_flag.get(name, None)
        if not flag:
            raise InvalidPostFlagError(
                'Flag must be one of %r.' % list(FLAG_MAP.values()))
        target_flags.append(flag)
    post.flags = target_flags
|
|
|
|
|
|
def feature_post(post: model.Post, user: Optional[model.User]) -> None:
    """Add a ``PostFeature`` for ``post`` by ``user``, timestamped now (UTC)."""
    assert post
    feature = model.PostFeature()
    feature.time = datetime.utcnow()
    feature.post = post
    feature.user = user
    db.session.add(feature)
|
|
|
|
|
|
def delete(post: model.Post) -> None:
    """Mark ``post`` for deletion in the current database session."""
    assert post
    return db.session.delete(post)
|
|
|
|
|
|
def merge_posts(
        source_post: model.Post,
        target_post: model.Post,
        replace_content: bool) -> None:
    """Merge ``source_post`` into ``target_post`` and delete the source.

    Tags, comments, scores, favorites and relations are re-pointed from
    the source to the target with bulk SQL UPDATEs.  Source rows that
    would duplicate an existing target row are left in place (presumably
    removed with the source post by cascade -- NOTE(review): confirm).
    When ``replace_content`` is true, the target also takes over the
    source's file content and flags.

    Raises:
        InvalidPostRelationError: if both arguments are the same post.
    """
    assert source_post
    assert target_post
    if source_post.post_id == target_post.post_id:
        raise InvalidPostRelationError('Cannot merge post with itself.')

    def merge_tables(
            table: model.Base,
            anti_dup_func: Optional[Callable[[model.Base, model.Base], bool]],
            source_post_id: int,
            target_post_id: int) -> None:
        # Re-point ``table`` rows from the source post to the target post.
        # With ``anti_dup_func``, a source row is skipped when the target
        # already has a row it considers equal (compared via a self-alias).
        alias1 = table
        alias2 = sa.orm.util.aliased(table)
        update_stmt = (
            sa.sql.expression.update(alias1)
            .where(alias1.post_id == source_post_id))

        if anti_dup_func is not None:
            update_stmt = (
                update_stmt
                .where(
                    ~sa.exists()
                    .where(anti_dup_func(alias1, alias2))
                    .where(alias2.post_id == target_post_id)))

        update_stmt = update_stmt.values(post_id=target_post_id)
        db.session.execute(update_stmt)

    def merge_tags(source_post_id: int, target_post_id: int) -> None:
        # Skip tags the target already has.
        merge_tables(
            model.PostTag,
            lambda alias1, alias2: alias1.tag_id == alias2.tag_id,
            source_post_id,
            target_post_id)

    def merge_scores(source_post_id: int, target_post_id: int) -> None:
        # Skip users who already scored the target.
        merge_tables(
            model.PostScore,
            lambda alias1, alias2: alias1.user_id == alias2.user_id,
            source_post_id,
            target_post_id)

    def merge_favorites(source_post_id: int, target_post_id: int) -> None:
        # Skip users who already favorited the target.
        merge_tables(
            model.PostFavorite,
            lambda alias1, alias2: alias1.user_id == alias2.user_id,
            source_post_id,
            target_post_id)

    def merge_comments(source_post_id: int, target_post_id: int) -> None:
        # Comments never collide, so all of them move.
        merge_tables(model.Comment, None, source_post_id, target_post_id)

    def merge_relations(source_post_id: int, target_post_id: int) -> None:
        # Relations are (parent_id, child_id) rows, so each direction is
        # moved separately.  Rows linking source and target directly are
        # skipped (they would make the target relate to itself), as are
        # rows the target already has.
        alias1 = model.PostRelation
        alias2 = sa.orm.util.aliased(model.PostRelation)
        update_stmt = (
            sa.sql.expression.update(alias1)
            .where(alias1.parent_id == source_post_id)
            .where(alias1.child_id != target_post_id)
            .where(
                ~sa.exists()
                .where(alias2.child_id == alias1.child_id)
                .where(alias2.parent_id == target_post_id))
            .values(parent_id=target_post_id))
        db.session.execute(update_stmt)

        update_stmt = (
            sa.sql.expression.update(alias1)
            .where(alias1.child_id == source_post_id)
            .where(alias1.parent_id != target_post_id)
            .where(
                ~sa.exists()
                .where(alias2.parent_id == alias1.parent_id)
                .where(alias2.child_id == target_post_id))
            .values(child_id=target_post_id))
        db.session.execute(update_stmt)

    merge_tags(source_post.post_id, target_post.post_id)
    merge_comments(source_post.post_id, target_post.post_id)
    merge_scores(source_post.post_id, target_post.post_id)
    merge_favorites(source_post.post_id, target_post.post_id)
    merge_relations(source_post.post_id, target_post.post_id)

    def transfer_flags(source_post_id: int, target_post_id: int) -> None:
        # Copy the source's flags onto the target (re-loaded by id).
        target = get_post_by_id(target_post_id)
        source = get_post_by_id(source_post_id)
        target.flags = source.flags
        db.session.flush()

    # Read the source's file before the source post is deleted below.
    content = None
    if replace_content:
        content = files.get(get_post_content_path(source_post))
        transfer_flags(source_post.post_id, target_post.post_id)

    # fixes unknown issue with SA's cascade deletions
    purge_post_signature(source_post)
    delete(source_post)
    db.session.flush()

    # Done after the flush so the source's checksum no longer exists and
    # update_post_content() does not report the content as a duplicate.
    if content is not None:
        update_post_content(target_post, content)
|
|
|
|
|
|
def search_by_image_exact(image_content: bytes) -> Optional[model.Post]:
    """Return the post whose stored SHA-1 checksum equals that of the bytes."""
    wanted_checksum = util.get_sha1(image_content)
    matches = db.session.query(model.Post).filter(
        model.Post.checksum == wanted_checksum)
    return matches.one_or_none()
|
|
|
|
|
|
def search_by_image(image_content: bytes) -> List[Tuple[float, model.Post]]:
    """Find posts visually similar to ``image_content``.

    Up to 100 candidates are pre-selected in SQL by the number of
    signature words they share with the query image.  They are then
    ranked by ``image_hash.normalized_distance``, and only candidates
    closer than ``image_hash.DISTANCE_CUTOFF`` are returned, as
    ``(distance, post)`` pairs.
    """
    query_signature = image_hash.generate_signature(image_content)
    query_words = image_hash.generate_words(query_signature)

    # unnest() expands the stored ``words`` array (alongside the query
    # words) into one row per element, so matching words can be counted.
    # See https://www.postgresql.org/docs/9.2/functions-array.html
    # Wrapped in sa.text(): passing a plain SQL string to Session.execute()
    # is deprecated in SQLAlchemy 1.4 and rejected in 2.0.
    dbquery = sa.text('''
        SELECT s.post_id, s.signature, count(a.query) AS score
        FROM post_signature AS s, unnest(s.words, :q) AS a(word, query)
        WHERE a.word = a.query
        GROUP BY s.post_id
        ORDER BY score DESC LIMIT 100;
    ''')

    candidates = db.session.execute(dbquery, {'q': query_words})
    data = tuple(zip(*[
        (post_id, image_hash.unpack_signature(packedsig))
        for post_id, packedsig, _score in candidates
    ]))
    if not data:
        return []

    candidate_post_ids, sigarray = data
    distances = image_hash.normalized_distance(sigarray, query_signature)
    return [
        (distance, try_get_post_by_id(candidate_post_id))
        for candidate_post_id, distance
        in zip(candidate_post_ids, distances)
        if distance < image_hash.DISTANCE_CUTOFF
    ]
|