import hmac
import logging
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple

import sqlalchemy as sa

from szurubooru import config, db, errors, model, rest
from szurubooru.func import (
    comments,
    files,
    image_hash,
    images,
    mime,
    pools,
    scores,
    serialization,
    snapshots,
    tags,
    users,
    util,
)

logger = logging.getLogger(__name__)


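# A 1x1 transparent GIF (the bytes spell out a minimal "GIF89a" image), used
# as a stand-in thumbnail when thumbnail generation fails.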
EMPTY_PIXEL = (
    b"\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00"
    b"\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00"
    b"\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b"
)


class PostNotFoundError(errors.NotFoundError):
    pass


class PostAlreadyFeaturedError(errors.ValidationError):
    pass


class PostAlreadyUploadedError(errors.ValidationError):
    def __init__(self, other_post: model.Post) -> None:
        super().__init__(
            "Post already uploaded (%d)" % other_post.post_id,
            {
                "otherPostUrl": get_post_content_url(other_post),
                "otherPostId": other_post.post_id,
            },
        )


class InvalidPostIdError(errors.ValidationError):
    pass


class InvalidPostSafetyError(errors.ValidationError):
    pass


class InvalidPostSourceError(errors.ValidationError):
    pass


class InvalidPostContentError(errors.ValidationError):
    pass


class InvalidPostRelationError(errors.ValidationError):
    pass


class InvalidPostNoteError(errors.ValidationError):
    pass


class InvalidPostFlagError(errors.ValidationError):
    pass


SAFETY_MAP = {
    model.Post.SAFETY_SAFE: "safe",
    model.Post.SAFETY_SKETCHY: "sketchy",
    model.Post.SAFETY_UNSAFE: "unsafe",
}

TYPE_MAP = {
    model.Post.TYPE_IMAGE: "image",
    model.Post.TYPE_ANIMATION: "animation",
    model.Post.TYPE_VIDEO: "video",
    model.Post.TYPE_FLASH: "flash",
}

FLAG_MAP = {
    model.Post.FLAG_LOOP: "loop",
    model.Post.FLAG_SOUND: "sound",
}


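# The security hash is an HMAC of the post ID keyed with the server secret,
# truncated to 16 hex digits; it is embedded in content/thumbnail file names,
# which keeps those paths from being guessable from the post ID alone.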
def get_post_security_hash(id: int) -> str:
    return hmac.new(
        config.config["secret"].encode("utf8"),
        msg=str(id).encode("utf-8"),
        digestmod="md5",
    ).hexdigest()[0:16]


def get_post_content_url(post: model.Post) -> str:
    assert post
    return "%s/posts/%d_%s.%s" % (
        config.config["data_url"].rstrip("/"),
        post.post_id,
        get_post_security_hash(post.post_id),
        mime.get_extension(post.mime_type) or "dat",
    )


def get_post_thumbnail_url(post: model.Post) -> str:
    assert post
    return "%s/generated-thumbnails/%d_%s.jpg" % (
        config.config["data_url"].rstrip("/"),
        post.post_id,
        get_post_security_hash(post.post_id),
    )


def get_post_content_path(post: model.Post) -> str:
    assert post
    assert post.post_id
    return "posts/%d_%s.%s" % (
        post.post_id,
        get_post_security_hash(post.post_id),
        mime.get_extension(post.mime_type) or "dat",
    )


def get_post_thumbnail_path(post: model.Post) -> str:
    assert post
    return "generated-thumbnails/%d_%s.jpg" % (
        post.post_id,
        get_post_security_hash(post.post_id),
    )


def get_post_thumbnail_backup_path(post: model.Post) -> str:
    assert post
    return "posts/custom-thumbnails/%d_%s.dat" % (
        post.post_id,
        get_post_security_hash(post.post_id),
    )


def serialize_note(note: model.PostNote) -> rest.Response:
    assert note
    return {
        "polygon": note.polygon,
        "text": note.text,
    }


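# _serializers() maps public API field names to per-field callbacks; the base
# serializer (serialization.BaseSerializer) presumably calls only the
# callbacks for the requested fields, so expensive ones are computed on
# demand.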
class PostSerializer(serialization.BaseSerializer):
    def __init__(self, post: model.Post, auth_user: model.User) -> None:
        self.post = post
        self.auth_user = auth_user

    def _serializers(self) -> Dict[str, Callable[[], Any]]:
        return {
            "id": self.serialize_id,
            "version": self.serialize_version,
            "creationTime": self.serialize_creation_time,
            "lastEditTime": self.serialize_last_edit_time,
            "safety": self.serialize_safety,
            "source": self.serialize_source,
            "type": self.serialize_type,
            "mimeType": self.serialize_mime,
            "checksum": self.serialize_checksum,
            "checksumMD5": self.serialize_checksum_md5,
            "fileSize": self.serialize_file_size,
            "canvasWidth": self.serialize_canvas_width,
            "canvasHeight": self.serialize_canvas_height,
            "contentUrl": self.serialize_content_url,
            "thumbnailUrl": self.serialize_thumbnail_url,
            "flags": self.serialize_flags,
            "tags": self.serialize_tags,
            "relations": self.serialize_relations,
            "user": self.serialize_user,
            "score": self.serialize_score,
            "ownScore": self.serialize_own_score,
            "ownFavorite": self.serialize_own_favorite,
            "tagCount": self.serialize_tag_count,
            "favoriteCount": self.serialize_favorite_count,
            "commentCount": self.serialize_comment_count,
            "noteCount": self.serialize_note_count,
            "relationCount": self.serialize_relation_count,
            "featureCount": self.serialize_feature_count,
            "lastFeatureTime": self.serialize_last_feature_time,
            "favoritedBy": self.serialize_favorited_by,
            "hasCustomThumbnail": self.serialize_has_custom_thumbnail,
            "notes": self.serialize_notes,
            "comments": self.serialize_comments,
            "pools": self.serialize_pools,
        }

    def serialize_id(self) -> Any:
        return self.post.post_id

    def serialize_version(self) -> Any:
        return self.post.version

    def serialize_creation_time(self) -> Any:
        return self.post.creation_time

    def serialize_last_edit_time(self) -> Any:
        return self.post.last_edit_time

    def serialize_safety(self) -> Any:
        return SAFETY_MAP[self.post.safety]

    def serialize_source(self) -> Any:
        return self.post.source

    def serialize_type(self) -> Any:
        return TYPE_MAP[self.post.type]

    def serialize_mime(self) -> Any:
        return self.post.mime_type

    def serialize_checksum(self) -> Any:
        return self.post.checksum

    def serialize_checksum_md5(self) -> Any:
        return self.post.checksum_md5

    def serialize_file_size(self) -> Any:
        return self.post.file_size

    def serialize_canvas_width(self) -> Any:
        return self.post.canvas_width

    def serialize_canvas_height(self) -> Any:
        return self.post.canvas_height

    def serialize_content_url(self) -> Any:
        return get_post_content_url(self.post)

    def serialize_thumbnail_url(self) -> Any:
        return get_post_thumbnail_url(self.post)

    def serialize_flags(self) -> Any:
        return self.post.flags

    def serialize_tags(self) -> Any:
        return [
            {
                "names": [name.name for name in tag.names],
                "category": tag.category.name,
                "usages": tag.post_count,
            }
            for tag in tags.sort_tags(self.post.tags)
        ]

    def serialize_relations(self) -> Any:
        return sorted(
            {
                post["id"]: post
                for post in [
                    serialize_micro_post(rel, self.auth_user)
                    for rel in self.post.relations
                ]
            }.values(),
            key=lambda post: post["id"],
        )

    def serialize_user(self) -> Any:
        return users.serialize_micro_user(self.post.user, self.auth_user)

    def serialize_score(self) -> Any:
        return self.post.score

    def serialize_own_score(self) -> Any:
        return scores.get_score(self.post, self.auth_user)

    def serialize_own_favorite(self) -> Any:
        return (
            len(
                [
                    user
                    for user in self.post.favorited_by
                    if user.user_id == self.auth_user.user_id
                ]
            )
            > 0
        )

    def serialize_tag_count(self) -> Any:
        return self.post.tag_count

    def serialize_favorite_count(self) -> Any:
        return self.post.favorite_count

    def serialize_comment_count(self) -> Any:
        return self.post.comment_count

    def serialize_note_count(self) -> Any:
        return self.post.note_count

    def serialize_relation_count(self) -> Any:
        return self.post.relation_count

    def serialize_feature_count(self) -> Any:
        return self.post.feature_count

    def serialize_last_feature_time(self) -> Any:
        return self.post.last_feature_time

    def serialize_favorited_by(self) -> Any:
        return [
            users.serialize_micro_user(rel.user, self.auth_user)
            for rel in self.post.favorited_by
        ]

    def serialize_has_custom_thumbnail(self) -> Any:
        return files.has(get_post_thumbnail_backup_path(self.post))

    def serialize_notes(self) -> Any:
        return sorted(
            [serialize_note(note) for note in self.post.notes],
            key=lambda x: x["polygon"],
        )

    def serialize_comments(self) -> Any:
        return [
            comments.serialize_comment(comment, self.auth_user)
            for comment in sorted(
                self.post.comments, key=lambda comment: comment.creation_time
            )
        ]

    def serialize_pools(self) -> List[Any]:
        return [
            pools.serialize_micro_pool(pool)
            for pool in sorted(
                self.post.pools, key=lambda pool: pool.creation_time
            )
        ]


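# Illustrative use: ``options`` selects which of the fields registered in
# PostSerializer._serializers() get computed, e.g.
# serialize_post(post, auth_user, options=["id", "thumbnailUrl"]) as done by
# serialize_micro_post below.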
def serialize_post(
    post: Optional[model.Post], auth_user: model.User, options: List[str] = []
) -> Optional[rest.Response]:
    if not post:
        return None
    return PostSerializer(post, auth_user).serialize(options)


def serialize_micro_post(
    post: model.Post, auth_user: model.User
) -> Optional[rest.Response]:
    return serialize_post(
        post, auth_user=auth_user, options=["id", "thumbnailUrl"]
    )


def get_post_count() -> int:
    return db.session.query(sa.func.count(model.Post.post_id)).one()[0]


def try_get_post_by_id(post_id: int) -> Optional[model.Post]:
    return (
        db.session.query(model.Post)
        .filter(model.Post.post_id == post_id)
        .one_or_none()
    )


def get_post_by_id(post_id: int) -> model.Post:
    post = try_get_post_by_id(post_id)
    if not post:
        raise PostNotFoundError("Post %r not found." % post_id)
    return post


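# Fetches posts in bulk while preserving the order of the requested IDs
# (the database may return matching rows in any order).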
def get_posts_by_ids(ids: List[int]) -> List[model.Post]:
    if len(ids) == 0:
        return []
    posts = (
        db.session.query(model.Post)
        .filter(sa.sql.or_(model.Post.post_id == post_id for post_id in ids))
        .all()
    )
    id_order = {v: k for k, v in enumerate(ids)}
    return sorted(posts, key=lambda post: id_order.get(post.post_id))


def try_get_current_post_feature() -> Optional[model.PostFeature]:
    return (
        db.session.query(model.PostFeature)
        .order_by(model.PostFeature.time.desc())
        .first()
    )


def try_get_featured_post() -> Optional[model.Post]:
    post_feature = try_get_current_post_feature()
    return post_feature.post if post_feature else None


def create_post(
    content: bytes, tag_names: List[str], user: Optional[model.User]
) -> Tuple[model.Post, List[model.Tag]]:
    post = model.Post()
    post.safety = model.Post.SAFETY_SAFE
    post.user = user
    post.creation_time = datetime.utcnow()
    post.flags = []

    post.type = ""
    post.checksum = ""
    post.mime_type = ""

    update_post_content(post, content)
    new_tags = update_post_tags(post, tag_names)

    db.session.add(post)
    return post, new_tags


def update_post_safety(post: model.Post, safety: str) -> None:
    assert post
    safety = util.flip(SAFETY_MAP).get(safety, None)
    if not safety:
        raise InvalidPostSafetyError(
            "Safety can be either of %r." % list(SAFETY_MAP.values())
        )
    post.safety = safety


def update_post_source(post: model.Post, source: Optional[str]) -> None:
    assert post
    if util.value_exceeds_column_size(source, model.Post.source):
        raise InvalidPostSourceError("Source is too long.")
    post.source = source or None


@sa.events.event.listens_for(model.Post, "after_insert")
def _after_post_insert(
    _mapper: Any, _connection: Any, post: model.Post
) -> None:
    _sync_post_content(post)


@sa.events.event.listens_for(model.Post, "after_update")
def _after_post_update(
    _mapper: Any, _connection: Any, post: model.Post
) -> None:
    _sync_post_content(post)


@sa.events.event.listens_for(model.Post, "before_delete")
def _before_post_delete(
    _mapper: Any, _connection: Any, post: model.Post
) -> None:
    if post.post_id:
        if config.config["delete_source_files"]:
            files.delete(get_post_content_path(post))
            files.delete(get_post_thumbnail_path(post))


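# Raw bytes are staged on the instance as the ad-hoc attributes "__content"
# and "__thumbnail" (see update_post_content / update_post_thumbnail) and only
# written to disk here, from the SQLAlchemy after_insert/after_update hooks
# above, once the row actually exists.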
def _sync_post_content(post: model.Post) -> None:
    regenerate_thumb = False

    if hasattr(post, "__content"):
        content = getattr(post, "__content")
        files.save(get_post_content_path(post), content)
        delattr(post, "__content")
        regenerate_thumb = True

    if hasattr(post, "__thumbnail"):
        if getattr(post, "__thumbnail"):
            files.save(
                get_post_thumbnail_backup_path(post),
                getattr(post, "__thumbnail"),
            )
        else:
            files.delete(get_post_thumbnail_backup_path(post))
        delattr(post, "__thumbnail")
        regenerate_thumb = True

    if regenerate_thumb:
        generate_post_thumbnail(post)


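# For animated GIF uploads this creates additional MP4/WebM posts (depending
# on the convert.gif.to_mp4 / convert.gif.to_webm config switches), copying
# the original's tags, safety and source, and finally links the new posts
# back to the original via post relations.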
def generate_alternate_formats(
    post: model.Post, content: bytes
) -> List[Tuple[model.Post, List[model.Tag]]]:
    assert post
    assert content
    new_posts = []
    if mime.is_animated_gif(content):
        tag_names = [tag.first_name for tag in post.tags]

        if config.config["convert"]["gif"]["to_mp4"]:
            mp4_post, new_tags = create_post(
                images.Image(content).to_mp4(), tag_names, post.user
            )
            update_post_flags(mp4_post, ["loop"])
            update_post_safety(mp4_post, post.safety)
            update_post_source(mp4_post, post.source)
            new_posts += [(mp4_post, new_tags)]

        if config.config["convert"]["gif"]["to_webm"]:
            webm_post, new_tags = create_post(
                images.Image(content).to_webm(), tag_names, post.user
            )
            update_post_flags(webm_post, ["loop"])
            update_post_safety(webm_post, post.safety)
            update_post_source(webm_post, post.source)
            new_posts += [(webm_post, new_tags)]

        db.session.flush()

    new_posts = [p for p in new_posts if p[0] is not None]

    new_relations = [p[0].post_id for p in new_posts]
    if len(new_relations) > 0:
        update_post_relations(post, new_relations)

    return new_posts


def get_default_flags(content: bytes) -> List[str]:
    assert content
    ret = []
    if mime.is_video(mime.get_mime_type(content)):
        ret.append(model.Post.FLAG_LOOP)
        if images.Image(content).check_for_sound():
            ret.append(model.Post.FLAG_SOUND)
    return ret


def purge_post_signature(post: model.Post) -> None:
    (
        db.session.query(model.PostSignature)
        .filter(model.PostSignature.post_id == post.post_id)
        .delete()
    )


def generate_post_signature(post: model.Post, content: bytes) -> None:
    try:
        unpacked_signature = image_hash.generate_signature(content)
        packed_signature = image_hash.pack_signature(unpacked_signature)
        words = image_hash.generate_words(unpacked_signature)

        db.session.add(
            model.PostSignature(
                post=post, signature=packed_signature, words=words
            )
        )
    except errors.ProcessingError:
        if not config.config["allow_broken_uploads"]:
            raise InvalidPostContentError(
                "Unable to generate image hash data."
            )


def update_all_post_signatures() -> None:
    posts_to_hash = (
        db.session.query(model.Post)
        .filter(
            (model.Post.type == model.Post.TYPE_IMAGE)
            | (model.Post.type == model.Post.TYPE_ANIMATION)
        )
        .filter(model.Post.signature == None)  # noqa: E711
        .order_by(model.Post.post_id.asc())
        .all()
    )
    for post in posts_to_hash:
        try:
            generate_post_signature(
                post, files.get(get_post_content_path(post))
            )
            db.session.commit()
            logger.info("Created Signature - Post %d", post.post_id)
        except Exception as ex:
            logger.exception(ex)


def update_all_md5_checksums() -> None:
    posts_to_hash = (
        db.session.query(model.Post)
        .filter(model.Post.checksum_md5 == None)  # noqa: E711
        .order_by(model.Post.post_id.asc())
        .all()
    )
    for post in posts_to_hash:
        try:
            post.checksum_md5 = util.get_md5(
                files.get(get_post_content_path(post))
            )
            db.session.commit()
            logger.info("Created MD5 - Post %d", post.post_id)
        except Exception as ex:
            logger.exception(ex)


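# Content ingestion: detect the MIME type and post type, compute SHA1/MD5
# checksums, reject exact duplicates, refresh the image-hash signature for
# images, record file size and canvas dimensions, then stage the raw bytes
# for _sync_post_content to write out.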
def update_post_content(post: model.Post, content: Optional[bytes]) -> None:
    assert post
    if not content:
        raise InvalidPostContentError("Post content missing.")

    update_signature = False
    post.mime_type = mime.get_mime_type(content)
    if mime.is_flash(post.mime_type):
        post.type = model.Post.TYPE_FLASH
    elif mime.is_image(post.mime_type):
        update_signature = True
        if mime.is_animated_gif(content):
            post.type = model.Post.TYPE_ANIMATION
        else:
            post.type = model.Post.TYPE_IMAGE
    elif mime.is_video(post.mime_type):
        post.type = model.Post.TYPE_VIDEO
    else:
        raise InvalidPostContentError(
            "Unhandled file type: %r" % post.mime_type
        )

    post.checksum = util.get_sha1(content)
    post.checksum_md5 = util.get_md5(content)
    other_post = (
        db.session.query(model.Post)
        .filter(model.Post.checksum == post.checksum)
        .filter(model.Post.post_id != post.post_id)
        .one_or_none()
    )
    if (
        other_post
        and other_post.post_id
        and other_post.post_id != post.post_id
    ):
        raise PostAlreadyUploadedError(other_post)

    if update_signature:
        purge_post_signature(post)
        post.signature = generate_post_signature(post, content)

    post.file_size = len(content)
    try:
        image = images.Image(content)
        post.canvas_width = image.width
        post.canvas_height = image.height
    except errors.ProcessingError:
        if not config.config["allow_broken_uploads"]:
            raise InvalidPostContentError("Unable to process image metadata")
        else:
            post.canvas_width = None
            post.canvas_height = None
    if (post.canvas_width is not None and post.canvas_width <= 0) or (
        post.canvas_height is not None and post.canvas_height <= 0
    ):
        if not config.config["allow_broken_uploads"]:
            raise InvalidPostContentError(
                "Invalid image dimensions returned during processing"
            )
        else:
            post.canvas_width = None
            post.canvas_height = None
    setattr(post, "__content", content)


def update_post_thumbnail(
    post: model.Post, content: Optional[bytes] = None
) -> None:
    assert post
    setattr(post, "__thumbnail", content)


def generate_post_thumbnail(post: model.Post) -> None:
    assert post
    if files.has(get_post_thumbnail_backup_path(post)):
        content = files.get(get_post_thumbnail_backup_path(post))
    else:
        content = files.get(get_post_content_path(post))
    try:
        assert content
        image = images.Image(content)
        image.resize_fill(
            int(config.config["thumbnails"]["post_width"]),
            int(config.config["thumbnails"]["post_height"]),
        )
        files.save(get_post_thumbnail_path(post), image.to_jpeg())
    except errors.ProcessingError:
        files.save(get_post_thumbnail_path(post), EMPTY_PIXEL)


def update_post_tags(
    post: model.Post, tag_names: List[str]
) -> List[model.Tag]:
    assert post
    existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names)
    post.tags = existing_tags + new_tags
    return new_tags


def update_post_relations(post: model.Post, new_post_ids: List[int]) -> None:
    assert post
    try:
        new_post_ids = [int(id) for id in new_post_ids]
    except ValueError:
        raise InvalidPostRelationError("A relation must be numeric post ID.")
    old_posts = post.relations
    old_post_ids = [int(p.post_id) for p in old_posts]
    if new_post_ids:
        new_posts = (
            db.session.query(model.Post)
            .filter(model.Post.post_id.in_(new_post_ids))
            .all()
        )
    else:
        new_posts = []
    if len(new_posts) != len(new_post_ids):
        raise InvalidPostRelationError("One of relations does not exist.")
    if post.post_id in new_post_ids:
        raise InvalidPostRelationError("Post cannot relate to itself.")

    relations_to_del = [p for p in old_posts if p.post_id not in new_post_ids]
    relations_to_add = [p for p in new_posts if p.post_id not in old_post_ids]
    for relation in relations_to_del:
        post.relations.remove(relation)
        relation.relations.remove(post)
    for relation in relations_to_add:
        post.relations.append(relation)
        relation.relations.append(post)


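# Each note is a dict with a "polygon" (a list of at least three [x, y]
# points, both coordinates normalized to the 0..1 range relative to the
# image) and a non-empty "text"; anything else raises InvalidPostNoteError.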
def update_post_notes(post: model.Post, notes: Any) -> None:
    assert post
    post.notes = []
    for note in notes:
        for field in ("polygon", "text"):
            if field not in note:
                raise InvalidPostNoteError("Note is missing %r field." % field)
        if not note["text"]:
            raise InvalidPostNoteError("A note's text cannot be empty.")
        if not isinstance(note["polygon"], (list, tuple)):
            raise InvalidPostNoteError(
                "A note's polygon must be a list of points."
            )
        if len(note["polygon"]) < 3:
            raise InvalidPostNoteError(
                "A note's polygon must have at least 3 points."
            )
        for point in note["polygon"]:
            if not isinstance(point, (list, tuple)):
                raise InvalidPostNoteError(
                    "A note's polygon point must be a list of length 2."
                )
            if len(point) != 2:
                raise InvalidPostNoteError(
                    "A point in note's polygon must have two coordinates."
                )
            try:
                pos_x = float(point[0])
                pos_y = float(point[1])
                if not 0 <= pos_x <= 1 or not 0 <= pos_y <= 1:
                    raise InvalidPostNoteError(
                        "All points must fit in the image (0..1 range)."
                    )
            except ValueError:
                raise InvalidPostNoteError(
                    "A point in note's polygon must be numeric."
                )
        if util.value_exceeds_column_size(note["text"], model.PostNote.text):
            raise InvalidPostNoteError("Note text is too long.")
        post.notes.append(
            model.PostNote(polygon=note["polygon"], text=str(note["text"]))
        )


def update_post_flags(post: model.Post, flags: List[str]) -> None:
    assert post
    target_flags = []
    for flag in flags:
        flag = util.flip(FLAG_MAP).get(flag, None)
        if not flag:
            raise InvalidPostFlagError(
                "Flag must be one of %r." % list(FLAG_MAP.values())
            )
        target_flags.append(flag)
    post.flags = target_flags


def feature_post(post: model.Post, user: Optional[model.User]) -> None:
    assert post
    post_feature = model.PostFeature()
    post_feature.time = datetime.utcnow()
    post_feature.post = post
    post_feature.user = user
    db.session.add(post_feature)


def delete(post: model.Post) -> None:
    assert post
    db.session.delete(post)


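# Re-points tags, comments, scores, favorites and relations from the source
# post to the target post, skipping rows the target already has (via the
# anti_dup_func passed to merge_tables), then deletes the source post; with
# replace_content the target also takes over the source's file and flags.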
def merge_posts(
    source_post: model.Post, target_post: model.Post, replace_content: bool
) -> None:
    assert source_post
    assert target_post
    if source_post.post_id == target_post.post_id:
        raise InvalidPostRelationError("Cannot merge post with itself.")

    def merge_tables(
        table: model.Base,
        anti_dup_func: Optional[Callable[[model.Base, model.Base], bool]],
        source_post_id: int,
        target_post_id: int,
    ) -> None:
        alias1 = table
        alias2 = sa.orm.util.aliased(table)
        update_stmt = sa.sql.expression.update(alias1).where(
            alias1.post_id == source_post_id
        )

        if anti_dup_func is not None:
            update_stmt = update_stmt.where(
                ~sa.exists()
                .where(anti_dup_func(alias1, alias2))
                .where(alias2.post_id == target_post_id)
            )

        update_stmt = update_stmt.values(post_id=target_post_id)
        db.session.execute(update_stmt)

    def merge_tags(source_post_id: int, target_post_id: int) -> None:
        merge_tables(
            model.PostTag,
            lambda alias1, alias2: alias1.tag_id == alias2.tag_id,
            source_post_id,
            target_post_id,
        )

    def merge_scores(source_post_id: int, target_post_id: int) -> None:
        merge_tables(
            model.PostScore,
            lambda alias1, alias2: alias1.user_id == alias2.user_id,
            source_post_id,
            target_post_id,
        )

    def merge_favorites(source_post_id: int, target_post_id: int) -> None:
        merge_tables(
            model.PostFavorite,
            lambda alias1, alias2: alias1.user_id == alias2.user_id,
            source_post_id,
            target_post_id,
        )

    def merge_comments(source_post_id: int, target_post_id: int) -> None:
        merge_tables(model.Comment, None, source_post_id, target_post_id)

    def merge_relations(source_post_id: int, target_post_id: int) -> None:
        alias1 = model.PostRelation
        alias2 = sa.orm.util.aliased(model.PostRelation)
        update_stmt = (
            sa.sql.expression.update(alias1)
            .where(alias1.parent_id == source_post_id)
            .where(alias1.child_id != target_post_id)
            .where(
                ~sa.exists()
                .where(alias2.child_id == alias1.child_id)
                .where(alias2.parent_id == target_post_id)
            )
            .values(parent_id=target_post_id)
        )
        db.session.execute(update_stmt)

        update_stmt = (
            sa.sql.expression.update(alias1)
            .where(alias1.child_id == source_post_id)
            .where(alias1.parent_id != target_post_id)
            .where(
                ~sa.exists()
                .where(alias2.parent_id == alias1.parent_id)
                .where(alias2.child_id == target_post_id)
            )
            .values(child_id=target_post_id)
        )
        db.session.execute(update_stmt)

    merge_tags(source_post.post_id, target_post.post_id)
    merge_comments(source_post.post_id, target_post.post_id)
    merge_scores(source_post.post_id, target_post.post_id)
    merge_favorites(source_post.post_id, target_post.post_id)
    merge_relations(source_post.post_id, target_post.post_id)

    def transfer_flags(source_post_id: int, target_post_id: int) -> None:
        target = get_post_by_id(target_post_id)
        source = get_post_by_id(source_post_id)
        target.flags = source.flags
        db.session.flush()

    content = None
    if replace_content:
        content = files.get(get_post_content_path(source_post))
        transfer_flags(source_post.post_id, target_post.post_id)

    # fixes unknown issue with SA's cascade deletions
    purge_post_signature(source_post)
    delete(source_post)
    db.session.flush()

    if content is not None:
        update_post_content(target_post, content)


def search_by_image_exact(image_content: bytes) -> Optional[model.Post]:
    checksum = util.get_sha1(image_content)
    return (
        db.session.query(model.Post)
        .filter(model.Post.checksum == checksum)
        .one_or_none()
    )


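# Reverse image search: matches the query's image-hash words against stored
# post signatures in SQL to pick the 100 best candidates, then filters them by
# normalized signature distance, returning (distance, post) pairs below
# image_hash.DISTANCE_CUTOFF.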
def search_by_image(image_content: bytes) -> List[Tuple[float, model.Post]]:
    query_signature = image_hash.generate_signature(image_content)
    query_words = image_hash.generate_words(query_signature)

    """
    The unnest function is used here to expand one row containing the 'words'
    array into multiple rows each containing a singular word.

    Documentation of the unnest function can be found here:
    https://www.postgresql.org/docs/9.2/functions-array.html
    """

    dbquery = """
    SELECT s.post_id, s.signature, count(a.query) AS score
    FROM post_signature AS s, unnest(s.words, :q) AS a(word, query)
    WHERE a.word = a.query
    GROUP BY s.post_id
    ORDER BY score DESC LIMIT 100;
    """

    candidates = db.session.execute(dbquery, {"q": query_words})
    data = tuple(
        zip(
            *[
                (post_id, image_hash.unpack_signature(packedsig))
                for post_id, packedsig, score in candidates
            ]
        )
    )
    if data:
        candidate_post_ids, sigarray = data
        distances = image_hash.normalized_distance(sigarray, query_signature)
        return [
            (distance, try_get_post_by_id(candidate_post_id))
            for candidate_post_id, distance in zip(
                candidate_post_ids, distances
            )
            if distance < image_hash.DISTANCE_CUTOFF
        ]
    else:
        return []