gallery.accords-library.com/server/szurubooru/func/posts.py

947 lines
28 KiB
Python
Raw Normal View History

import hmac
import logging
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple
import sqlalchemy as sa
2016-04-30 21:17:08 +00:00
from szurubooru import config, db, errors, model, rest
from szurubooru.func import (
comments,
files,
image_hash,
images,
mime,
pools,
scores,
serialization,
snapshots,
tags,
users,
util,
)
# Module-level logger named after this module.
logger = logging.getLogger(__name__)

# Raw bytes of a GIF89a file whose logical screen is 1x1 pixels. It is saved
# as the thumbnail by generate_post_thumbnail when image processing fails.
EMPTY_PIXEL = (
    b"\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00"
    b"\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00"
    b"\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b"
)
2016-04-22 18:58:04 +00:00
class PostNotFoundError(errors.NotFoundError):
    """Raised by ``get_post_by_id`` when no post has the requested ID."""

    pass
class PostAlreadyFeaturedError(errors.ValidationError):
    """Validation error for featuring a post that is already featured.

    NOTE(review): nothing in the visible part of this module raises it;
    presumably it is raised by callers -- confirm.
    """

    pass
class PostAlreadyUploadedError(errors.ValidationError):
    """Validation error for content that duplicates an existing post.

    Besides the message, the error carries the ID and the content URL of
    the post that already holds the same content.
    """

    def __init__(self, other_post: model.Post) -> None:
        message = "Post already uploaded (%d)" % other_post.post_id
        extra_fields = {
            "otherPostUrl": get_post_content_url(other_post),
            "otherPostId": other_post.post_id,
        }
        super().__init__(message, extra_fields)
class InvalidPostIdError(errors.ValidationError):
    """Validation error for a malformed post ID.

    NOTE(review): nothing in the visible part of this module raises it;
    presumably it is raised by callers -- confirm.
    """

    pass
class InvalidPostSafetyError(errors.ValidationError):
    """Raised by ``update_post_safety`` for an unknown safety name."""

    pass
class InvalidPostSourceError(errors.ValidationError):
    """Raised by ``update_post_source`` when the source is too long."""

    pass
class InvalidPostContentError(errors.ValidationError):
    """Raised when post content is missing, unsupported or unprocessable.

    Raised by ``update_post_content`` and ``generate_post_signature``.
    """

    pass
class InvalidPostRelationError(errors.ValidationError):
    """Raised for invalid relations between posts.

    Raised by ``update_post_relations`` and by ``merge_posts`` when a post
    would be merged with itself.
    """

    pass
class InvalidPostNoteError(errors.ValidationError):
    """Raised by ``update_post_notes`` when a note fails validation."""

    pass
class InvalidPostFlagError(errors.ValidationError):
    """Raised by ``update_post_flags`` for an unknown flag name."""

    pass
2016-04-30 21:17:08 +00:00
# Maps the model's safety constants to the names used in serialized posts.
# update_post_safety flips this mapping to translate names back.
SAFETY_MAP = {
    model.Post.SAFETY_SAFE: "safe",
    model.Post.SAFETY_SKETCHY: "sketchy",
    model.Post.SAFETY_UNSAFE: "unsafe",
}

# Maps the model's post type constants to the names used in serialized posts.
TYPE_MAP = {
    model.Post.TYPE_IMAGE: "image",
    model.Post.TYPE_ANIMATION: "animation",
    model.Post.TYPE_VIDEO: "video",
    model.Post.TYPE_FLASH: "flash",
}

# Maps the model's flag constants to flag names. update_post_flags flips this
# mapping to validate and translate the names it receives.
FLAG_MAP = {
    model.Post.FLAG_LOOP: "loop",
    model.Post.FLAG_SOUND: "sound",
}
2016-04-30 21:17:08 +00:00
def get_post_security_hash(id: int) -> str:
    """Return the 16-character hex token that accompanies a post ID in paths.

    The token is the first 16 hex digits of an HMAC-MD5 over the decimal
    form of ``id``, keyed with the configured ``secret``.
    """
    secret_key = config.config["secret"].encode("utf8")
    message = str(id).encode("utf-8")
    hex_digest = hmac.new(secret_key, msg=message, digestmod="md5").hexdigest()
    return hex_digest[:16]
def get_post_content_url(post: model.Post) -> str:
    """Return the URL of the post's content file under ``data_url``.

    The file extension is derived from the post's MIME type and falls back
    to ``dat`` when no extension is known.
    """
    assert post
    base_url = config.config["data_url"].rstrip("/")
    post_id = post.post_id
    security_hash = get_post_security_hash(post.post_id)
    extension = mime.get_extension(post.mime_type) or "dat"
    return "%s/posts/%d_%s.%s" % (base_url, post_id, security_hash, extension)
2016-04-30 21:17:08 +00:00
def get_post_thumbnail_url(post: model.Post) -> str:
    """Return the URL of the post's generated JPEG thumbnail."""
    assert post
    base_url = config.config["data_url"].rstrip("/")
    post_id = post.post_id
    security_hash = get_post_security_hash(post.post_id)
    return "%s/generated-thumbnails/%d_%s.jpg" % (
        base_url,
        post_id,
        security_hash,
    )
2016-04-30 21:17:08 +00:00
def get_post_content_path(post: model.Post) -> str:
    """Return the storage path of the post's content file.

    The post must already have an ID. The extension is derived from the
    MIME type and falls back to ``dat`` when no extension is known.
    """
    assert post
    assert post.post_id
    post_id = post.post_id
    security_hash = get_post_security_hash(post.post_id)
    extension = mime.get_extension(post.mime_type) or "dat"
    return "posts/%d_%s.%s" % (post_id, security_hash, extension)
2016-04-30 21:17:08 +00:00
def get_post_thumbnail_path(post: model.Post) -> str:
    """Return the storage path of the post's generated JPEG thumbnail."""
    assert post
    post_id = post.post_id
    security_hash = get_post_security_hash(post.post_id)
    return "generated-thumbnails/%d_%s.jpg" % (post_id, security_hash)
2016-04-30 21:17:08 +00:00
def get_post_thumbnail_backup_path(post: model.Post) -> str:
    """Return the storage path where a custom thumbnail upload is kept."""
    assert post
    post_id = post.post_id
    security_hash = get_post_security_hash(post.post_id)
    return "posts/custom-thumbnails/%d_%s.dat" % (post_id, security_hash)
2016-04-30 21:17:08 +00:00
def serialize_note(note: model.PostNote) -> rest.Response:
    """Return the note as a dict with its ``polygon`` and ``text``."""
    assert note
    return dict(polygon=note.polygon, text=note.text)
class PostSerializer(serialization.BaseSerializer):
    """Serializes a post into response fields for a given viewer.

    ``_serializers`` maps each public field name to the bound method that
    produces its value; user-dependent fields are computed for ``auth_user``.
    """

    def __init__(self, post: model.Post, auth_user: model.User) -> None:
        self.post = post
        self.auth_user = auth_user

    def _serializers(self) -> Dict[str, Callable[[], Any]]:
        """Return the field-name -> producer mapping for this post."""
        field_map = {
            "id": self.serialize_id,
            "version": self.serialize_version,
            "creationTime": self.serialize_creation_time,
            "lastEditTime": self.serialize_last_edit_time,
            "safety": self.serialize_safety,
            "source": self.serialize_source,
            "type": self.serialize_type,
            "mimeType": self.serialize_mime,
            "checksum": self.serialize_checksum,
            "fileSize": self.serialize_file_size,
            "canvasWidth": self.serialize_canvas_width,
            "canvasHeight": self.serialize_canvas_height,
            "contentUrl": self.serialize_content_url,
            "thumbnailUrl": self.serialize_thumbnail_url,
            "flags": self.serialize_flags,
            "tags": self.serialize_tags,
            "relations": self.serialize_relations,
            "user": self.serialize_user,
            "score": self.serialize_score,
            "ownScore": self.serialize_own_score,
            "ownFavorite": self.serialize_own_favorite,
            "tagCount": self.serialize_tag_count,
            "favoriteCount": self.serialize_favorite_count,
            "commentCount": self.serialize_comment_count,
            "noteCount": self.serialize_note_count,
            "relationCount": self.serialize_relation_count,
            "featureCount": self.serialize_feature_count,
            "lastFeatureTime": self.serialize_last_feature_time,
            "favoritedBy": self.serialize_favorited_by,
            "hasCustomThumbnail": self.serialize_has_custom_thumbnail,
            "notes": self.serialize_notes,
            "comments": self.serialize_comments,
            "pools": self.serialize_pools,
        }
        return field_map

    def serialize_id(self) -> Any:
        """Return the post ID."""
        return self.post.post_id

    def serialize_version(self) -> Any:
        """Return the post's version value."""
        return self.post.version

    def serialize_creation_time(self) -> Any:
        """Return the post's creation time."""
        return self.post.creation_time

    def serialize_last_edit_time(self) -> Any:
        """Return the post's last edit time."""
        return self.post.last_edit_time

    def serialize_safety(self) -> Any:
        """Return the safety name looked up in ``SAFETY_MAP``."""
        return SAFETY_MAP[self.post.safety]

    def serialize_source(self) -> Any:
        """Return the post's source."""
        return self.post.source

    def serialize_type(self) -> Any:
        """Return the type name looked up in ``TYPE_MAP``."""
        return TYPE_MAP[self.post.type]

    def serialize_mime(self) -> Any:
        """Return the post's MIME type."""
        return self.post.mime_type

    def serialize_checksum(self) -> Any:
        """Return the post's content checksum."""
        return self.post.checksum

    def serialize_file_size(self) -> Any:
        """Return the post's file size."""
        return self.post.file_size

    def serialize_canvas_width(self) -> Any:
        """Return the post's canvas width."""
        return self.post.canvas_width

    def serialize_canvas_height(self) -> Any:
        """Return the post's canvas height."""
        return self.post.canvas_height

    def serialize_content_url(self) -> Any:
        """Return the URL of the post's content file."""
        return get_post_content_url(self.post)

    def serialize_thumbnail_url(self) -> Any:
        """Return the URL of the post's thumbnail."""
        return get_post_thumbnail_url(self.post)

    def serialize_flags(self) -> Any:
        """Return the post's flags as stored on the model."""
        return self.post.flags

    def serialize_tags(self) -> Any:
        """Return names, category and usage count for each sorted tag."""
        serialized = []
        for tag in tags.sort_tags(self.post.tags):
            serialized.append(
                {
                    "names": [name.name for name in tag.names],
                    "category": tag.category.name,
                    "usages": tag.post_count,
                }
            )
        return serialized

    def serialize_relations(self) -> Any:
        """Return micro-serialized related posts, unique and sorted by ID."""
        micro_posts = [
            serialize_micro_post(rel, self.auth_user)
            for rel in self.post.relations
        ]
        # Keying by ID collapses duplicates; the last occurrence wins.
        unique_by_id = {}
        for micro_post in micro_posts:
            unique_by_id[micro_post["id"]] = micro_post
        return sorted(unique_by_id.values(), key=lambda post: post["id"])

    def serialize_user(self) -> Any:
        """Return the micro-serialized user attached to the post."""
        return users.serialize_micro_user(self.post.user, self.auth_user)

    def serialize_score(self) -> Any:
        """Return the post's score."""
        return self.post.score

    def serialize_own_score(self) -> Any:
        """Return the score ``auth_user`` gave to the post."""
        return scores.get_score(self.post, self.auth_user)

    def serialize_own_favorite(self) -> Any:
        """Return whether ``auth_user`` is among the post's favorites."""
        return any(
            favorite.user_id == self.auth_user.user_id
            for favorite in self.post.favorited_by
        )

    def serialize_tag_count(self) -> Any:
        """Return the post's tag count."""
        return self.post.tag_count

    def serialize_favorite_count(self) -> Any:
        """Return the post's favorite count."""
        return self.post.favorite_count

    def serialize_comment_count(self) -> Any:
        """Return the post's comment count."""
        return self.post.comment_count

    def serialize_note_count(self) -> Any:
        """Return the post's note count."""
        return self.post.note_count

    def serialize_relation_count(self) -> Any:
        """Return the post's relation count."""
        return self.post.relation_count

    def serialize_feature_count(self) -> Any:
        """Return the post's feature count."""
        return self.post.feature_count

    def serialize_last_feature_time(self) -> Any:
        """Return the time the post was last featured."""
        return self.post.last_feature_time

    def serialize_favorited_by(self) -> Any:
        """Return the micro-serialized users who favorited the post."""
        favorited_users = []
        for rel in self.post.favorited_by:
            favorited_users.append(
                users.serialize_micro_user(rel.user, self.auth_user)
            )
        return favorited_users

    def serialize_has_custom_thumbnail(self) -> Any:
        """Return whether a custom thumbnail backup file exists."""
        backup_path = get_post_thumbnail_backup_path(self.post)
        return files.has(backup_path)

    def serialize_notes(self) -> Any:
        """Return the serialized notes sorted by their polygon."""
        serialized = [serialize_note(note) for note in self.post.notes]
        serialized.sort(key=lambda x: x["polygon"])
        return serialized

    def serialize_comments(self) -> Any:
        """Return the serialized comments, oldest first."""
        ordered_comments = sorted(
            self.post.comments, key=lambda comment: comment.creation_time
        )
        return [
            comments.serialize_comment(comment, self.auth_user)
            for comment in ordered_comments
        ]

    def serialize_pools(self) -> List[Any]:
        """Return the serialized pools containing the post, oldest first."""
        ordered_pools = sorted(
            self.post.pools, key=lambda pool: pool.creation_time
        )
        return [pools.serialize_pool(pool) for pool in ordered_pools]
2020-05-04 09:20:23 +00:00
def serialize_post(
    post: Optional[model.Post],
    auth_user: model.User,
    options: Optional[List[str]] = None,
) -> Optional[rest.Response]:
    """Serialize a post for ``auth_user``, or return ``None`` for no post.

    Args:
        post: the post to serialize; a falsy value yields ``None``.
        auth_user: the user the user-dependent fields are computed for.
        options: names of the fields to include. Omitting it is equivalent
            to passing an empty list, as before.

    Returns:
        The result of ``PostSerializer.serialize``, or ``None``.
    """
    if not post:
        return None
    # A fresh list per call replaces the former shared mutable default.
    selected_fields = [] if options is None else options
    return PostSerializer(post, auth_user).serialize(selected_fields)
def serialize_micro_post(
    post: model.Post, auth_user: model.User
) -> Optional[rest.Response]:
    """Serialize only the ``id`` and ``thumbnailUrl`` fields of a post."""
    micro_fields = ["id", "thumbnailUrl"]
    return serialize_post(post, auth_user=auth_user, options=micro_fields)
def get_post_count() -> int:
    """Return the number of posts, counted by the database."""
    count_query = db.session.query(sa.func.count(model.Post.post_id))
    row = count_query.one()
    return row[0]
2016-04-22 18:58:04 +00:00
def try_get_post_by_id(post_id: int) -> Optional[model.Post]:
    """Return the post with the given ID, or ``None`` if there is none."""
    query = db.session.query(model.Post)
    query = query.filter(model.Post.post_id == post_id)
    return query.one_or_none()
2016-04-22 18:58:04 +00:00
def get_post_by_id(post_id: int) -> model.Post:
    """Return the post with the given ID.

    Raises:
        PostNotFoundError: no post has this ID.
    """
    found_post = try_get_post_by_id(post_id)
    if found_post:
        return found_post
    raise PostNotFoundError("Post %r not found." % post_id)
2020-06-03 00:43:18 +00:00
def get_posts_by_ids(ids: List[int]) -> List[model.Post]:
    """Return the posts with the given IDs, ordered like ``ids``.

    IDs that match no post are skipped. When an ID is repeated, its last
    position in ``ids`` decides where the post is placed.
    """
    if not ids:
        return []
    # A single IN clause, as update_post_relations already uses, replaces
    # an OR chain built from a generator of equality comparisons.
    posts = (
        db.session.query(model.Post)
        .filter(model.Post.post_id.in_(ids))
        .all()
    )
    id_order = {post_id: index for index, post_id in enumerate(ids)}
    return sorted(posts, key=lambda post: id_order.get(post.post_id))
def try_get_current_post_feature() -> Optional[model.PostFeature]:
    """Return the feature record with the latest time, or ``None``."""
    newest_first = db.session.query(model.PostFeature).order_by(
        model.PostFeature.time.desc()
    )
    return newest_first.first()
def try_get_featured_post() -> Optional[model.Post]:
    """Return the post of the current feature record, or ``None``."""
    current_feature = try_get_current_post_feature()
    if current_feature:
        return current_feature.post
    return None
def create_post(
    content: bytes, tag_names: List[str], user: Optional[model.User]
) -> Tuple[model.Post, List[model.Tag]]:
    """Build a new post from content and tag names and add it to the session.

    The post starts out as safe, without flags, and with placeholder type,
    checksum and MIME type that ``update_post_content`` then fills in.

    Returns:
        The new post and the tags that had to be created for it.
    """
    new_post = model.Post()
    new_post.safety = model.Post.SAFETY_SAFE
    new_post.user = user
    new_post.creation_time = datetime.utcnow()
    new_post.flags = []

    # Placeholders; update_post_content assigns the real values below.
    new_post.type = ""
    new_post.checksum = ""
    new_post.mime_type = ""

    update_post_content(new_post, content)
    created_tags = update_post_tags(new_post, tag_names)

    db.session.add(new_post)
    return new_post, created_tags
2016-04-30 21:17:08 +00:00
def update_post_safety(post: model.Post, safety: str) -> None:
    """Set the post's safety from one of the names in ``SAFETY_MAP``.

    Raises:
        InvalidPostSafetyError: ``safety`` is not a known safety name.
    """
    assert post
    name_to_safety = util.flip(SAFETY_MAP)
    model_safety = name_to_safety.get(safety, None)
    if model_safety:
        post.safety = model_safety
        return
    raise InvalidPostSafetyError(
        "Safety can be either of %r." % list(SAFETY_MAP.values())
    )
def update_post_source(post: model.Post, source: Optional[str]) -> None:
    """Set the post's source, storing ``None`` for any falsy value.

    Raises:
        InvalidPostSourceError: the source exceeds the column size.
    """
    assert post
    too_long = util.value_exceeds_column_size(source, model.Post.source)
    if too_long:
        raise InvalidPostSourceError("Source is too long.")
    post.source = source if source else None
2016-04-30 21:17:08 +00:00
@sa.events.event.listens_for(model.Post, "after_insert")
def _after_post_insert(
    _mapper: Any, _connection: Any, post: model.Post
) -> None:
    """Listener for ``after_insert``: write the post's pending files."""
    _sync_post_content(post)
@sa.events.event.listens_for(model.Post, "after_update")
def _after_post_update(
    _mapper: Any, _connection: Any, post: model.Post
) -> None:
    """Listener for ``after_update``: write the post's pending files."""
    _sync_post_content(post)
@sa.events.event.listens_for(model.Post, "before_delete")
def _before_post_delete(
    _mapper: Any, _connection: Any, post: model.Post
) -> None:
    """Listener for ``before_delete``: remove the post's stored files.

    The content file and generated thumbnail are deleted only for posts
    that have an ID and only when ``delete_source_files`` is enabled.
    """
    if post.post_id and config.config["delete_source_files"]:
        files.delete(get_post_content_path(post))
        files.delete(get_post_thumbnail_path(post))
def _sync_post_content(post: model.Post) -> None:
    """Write content/thumbnail data stashed on the post to storage.

    ``update_post_content`` and ``update_post_thumbnail`` leave their data
    on the post as ``__content`` and ``__thumbnail``. Each present
    attribute is written out (or, for a falsy thumbnail, the backup is
    deleted) and removed, after which the thumbnail is regenerated.
    """
    content_pending = hasattr(post, "__content")
    if content_pending:
        pending_content = getattr(post, "__content")
        files.save(get_post_content_path(post), pending_content)
        delattr(post, "__content")

    thumbnail_pending = hasattr(post, "__thumbnail")
    if thumbnail_pending:
        backup_path = get_post_thumbnail_backup_path(post)
        if getattr(post, "__thumbnail"):
            files.save(backup_path, getattr(post, "__thumbnail"))
        else:
            # A falsy thumbnail means the custom thumbnail is removed.
            files.delete(backup_path)
        delattr(post, "__thumbnail")

    if content_pending or thumbnail_pending:
        generate_post_thumbnail(post)
def generate_alternate_formats(
    post: model.Post, content: bytes
) -> List[Tuple[model.Post, List[model.Tag]]]:
    """Create converted copies of an animated GIF and relate them to it.

    For animated GIF content, an MP4 and/or WebM post is created according
    to the ``convert.gif`` configuration. Each copy takes the original's
    tags, user, safety and source and gets the ``loop`` flag. The session
    is flushed and the copies are set as the relations of ``post``.

    Returns:
        One ``(post, newly created tags)`` pair per created copy.
    """
    assert post
    assert content
    new_posts = []

    def add_converted_post(converted_content: bytes) -> None:
        # Create one copy and align its metadata with the original post.
        converted_post, created_tags = create_post(
            converted_content, tag_names, post.user
        )
        update_post_flags(converted_post, ["loop"])
        update_post_safety(converted_post, post.safety)
        update_post_source(converted_post, post.source)
        new_posts.append((converted_post, created_tags))

    if mime.is_animated_gif(content):
        tag_names = [tag.first_name for tag in post.tags]

        if config.config["convert"]["gif"]["to_mp4"]:
            add_converted_post(images.Image(content).to_mp4())

        if config.config["convert"]["gif"]["to_webm"]:
            add_converted_post(images.Image(content).to_webm())

        db.session.flush()

    new_posts = [pair for pair in new_posts if pair[0] is not None]

    new_relations = [pair[0].post_id for pair in new_posts]
    if new_relations:
        update_post_relations(post, new_relations)

    return new_posts
def get_default_flags(content: bytes) -> List[str]:
    """Return the flags that content should get by default.

    Video content gets the loop flag, plus the sound flag when the sound
    check on the content succeeds. Other content gets no flags.
    """
    assert content
    default_flags = []
    if not mime.is_video(mime.get_mime_type(content)):
        return default_flags
    default_flags.append(model.Post.FLAG_LOOP)
    if images.Image(content).check_for_sound():
        default_flags.append(model.Post.FLAG_SOUND)
    return default_flags
def purge_post_signature(post: model.Post) -> None:
    """Delete the stored signature rows that belong to the post."""
    signature_rows = db.session.query(model.PostSignature).filter(
        model.PostSignature.post_id == post.post_id
    )
    signature_rows.delete()
def generate_post_signature(post: model.Post, content: bytes) -> None:
    """Compute the image signature of content and add it to the session.

    A processing failure is ignored when ``allow_broken_uploads`` is
    enabled; no signature is stored in that case.

    Raises:
        InvalidPostContentError: hashing failed and broken uploads are
            not allowed.
    """
    try:
        raw_signature = image_hash.generate_signature(content)
        signature_record = model.PostSignature(
            post=post,
            signature=image_hash.pack_signature(raw_signature),
            words=image_hash.generate_words(raw_signature),
        )
        db.session.add(signature_record)
    except errors.ProcessingError:
        if config.config["allow_broken_uploads"]:
            return
        raise InvalidPostContentError("Unable to generate image hash data.")
def update_all_post_signatures() -> None:
    """Generate signatures for image and animation posts that lack one.

    Posts are processed in ascending ID order and committed one at a time,
    so a failure on one post does not undo the posts already hashed.
    Failures are logged and the loop continues with the next post.
    """
    posts_to_hash = (
        db.session.query(model.Post)
        .filter(
            (model.Post.type == model.Post.TYPE_IMAGE)
            | (model.Post.type == model.Post.TYPE_ANIMATION)
        )
        .filter(model.Post.signature == None)  # noqa: E711
        .order_by(model.Post.post_id.asc())
        .all()
    )
    for post in posts_to_hash:
        try:
            generate_post_signature(
                post, files.get(get_post_content_path(post))
            )
            db.session.commit()
            logger.info("Hashed Post %d", post.post_id)
        except Exception as ex:
            logger.exception(ex)
            # Reset the session: after a failed flush/commit it refuses
            # further work until rolled back, which would otherwise make
            # every remaining post fail as well.
            db.session.rollback()
def update_post_content(post: model.Post, content: Optional[bytes]) -> None:
    """Validate new content and record its metadata on the post.

    Sets the MIME type, post type, checksum, file size and canvas
    dimensions, regenerates the signature for image content, and stashes
    the bytes on the post as ``__content`` for ``_sync_post_content`` to
    write to storage.

    Raises:
        InvalidPostContentError: the content is empty, has an unhandled
            MIME type, or (unless ``allow_broken_uploads`` is enabled)
            cannot be processed or reports non-positive dimensions.
        PostAlreadyUploadedError: another post has the same checksum.
    """
    assert post
    if not content:
        raise InvalidPostContentError("Post content missing.")

    # Only image content (still or animated) gets a signature.
    update_signature = False
    post.mime_type = mime.get_mime_type(content)
    if mime.is_flash(post.mime_type):
        post.type = model.Post.TYPE_FLASH
    elif mime.is_image(post.mime_type):
        update_signature = True
        if mime.is_animated_gif(content):
            post.type = model.Post.TYPE_ANIMATION
        else:
            post.type = model.Post.TYPE_IMAGE
    elif mime.is_video(post.mime_type):
        post.type = model.Post.TYPE_VIDEO
    else:
        raise InvalidPostContentError(
            "Unhandled file type: %r" % post.mime_type
        )

    # Reject content whose checksum already belongs to a different post.
    post.checksum = util.get_sha1(content)
    other_post = (
        db.session.query(model.Post)
        .filter(model.Post.checksum == post.checksum)
        .filter(model.Post.post_id != post.post_id)
        .one_or_none()
    )
    if (
        other_post
        and other_post.post_id
        and other_post.post_id != post.post_id
    ):
        raise PostAlreadyUploadedError(other_post)

    if update_signature:
        purge_post_signature(post)
        # NOTE(review): generate_post_signature has no return statement, so
        # this assigns None; the new signature is only added to the session
        # inside that function. Confirm this is intended.
        post.signature = generate_post_signature(post, content)

    post.file_size = len(content)
    try:
        image = images.Image(content)
        post.canvas_width = image.width
        post.canvas_height = image.height
    except errors.ProcessingError:
        if not config.config["allow_broken_uploads"]:
            raise InvalidPostContentError("Unable to process image metadata")
        else:
            # Broken uploads are accepted without known dimensions.
            post.canvas_width = None
            post.canvas_height = None
    if (post.canvas_width is not None and post.canvas_width <= 0) or (
        post.canvas_height is not None and post.canvas_height <= 0
    ):
        if not config.config["allow_broken_uploads"]:
            raise InvalidPostContentError(
                "Invalid image dimensions returned during processing"
            )
        else:
            post.canvas_width = None
            post.canvas_height = None
    # Stash the bytes; _sync_post_content writes them to storage later.
    setattr(post, "__content", content)
2016-04-30 21:17:08 +00:00
def update_post_thumbnail(
    post: model.Post, content: Optional[bytes] = None
) -> None:
    """Stash custom thumbnail bytes on the post as ``__thumbnail``.

    ``_sync_post_content`` later saves them, or deletes the stored custom
    thumbnail when ``content`` is falsy.
    """
    assert post
    post.__dict__  # noqa: B018 -- see note below
    setattr(post, "__thumbnail", content)
2016-04-30 21:17:08 +00:00
def generate_post_thumbnail(post: model.Post) -> None:
    """Render the post's JPEG thumbnail and save it to storage.

    The custom thumbnail backup is used as the source when it exists,
    otherwise the post's content. When image processing fails, the
    ``EMPTY_PIXEL`` placeholder is saved instead.
    """
    assert post
    backup_path = get_post_thumbnail_backup_path(post)
    if files.has(backup_path):
        source_bytes = files.get(backup_path)
    else:
        source_bytes = files.get(get_post_content_path(post))
    try:
        assert source_bytes
        thumbnail = images.Image(source_bytes)
        target_width = int(config.config["thumbnails"]["post_width"])
        target_height = int(config.config["thumbnails"]["post_height"])
        thumbnail.resize_fill(target_width, target_height)
        files.save(get_post_thumbnail_path(post), thumbnail.to_jpeg())
    except errors.ProcessingError:
        files.save(get_post_thumbnail_path(post), EMPTY_PIXEL)
def update_post_tags(
    post: model.Post, tag_names: List[str]
) -> List[model.Tag]:
    """Replace the post's tags with the tags named in ``tag_names``.

    Returns:
        The tags that did not exist before and were created.
    """
    assert post
    tag_groups = tags.get_or_create_tags_by_names(tag_names)
    already_existing, freshly_created = tag_groups
    post.tags = already_existing + freshly_created
    return freshly_created
2016-04-30 21:17:08 +00:00
def update_post_relations(post: model.Post, new_post_ids: List[int]) -> None:
    """Make the post related to exactly the posts in ``new_post_ids``.

    Relations are kept symmetric: every relation added to or removed from
    ``post`` is mirrored on the other post.

    Args:
        post: the post whose relations are replaced.
        new_post_ids: IDs of the posts to relate to; repeated IDs count
            once.

    Raises:
        InvalidPostRelationError: an ID is not numeric, refers to a post
            that does not exist, or is the post's own ID.
    """
    assert post
    try:
        # dict.fromkeys drops repeated IDs while keeping their order, so a
        # duplicate no longer trips the existence check below.
        new_post_ids = list(
            dict.fromkeys(int(post_id) for post_id in new_post_ids)
        )
    except (TypeError, ValueError):
        # int() raises TypeError, not ValueError, for values such as None.
        raise InvalidPostRelationError("A relation must be numeric post ID.")
    old_posts = post.relations
    old_post_ids = {int(p.post_id) for p in old_posts}
    if new_post_ids:
        new_posts = (
            db.session.query(model.Post)
            .filter(model.Post.post_id.in_(new_post_ids))
            .all()
        )
    else:
        new_posts = []
    if len(new_posts) != len(new_post_ids):
        raise InvalidPostRelationError("One of relations does not exist.")
    if post.post_id in new_post_ids:
        raise InvalidPostRelationError("Post cannot relate to itself.")

    wanted_ids = set(new_post_ids)
    relations_to_del = [p for p in old_posts if p.post_id not in wanted_ids]
    relations_to_add = [p for p in new_posts if p.post_id not in old_post_ids]
    for relation in relations_to_del:
        post.relations.remove(relation)
        relation.relations.remove(post)
    for relation in relations_to_add:
        post.relations.append(relation)
        relation.relations.append(post)
2016-04-30 21:17:08 +00:00
def update_post_notes(post: model.Post, notes: Any) -> None:
    """Replace the post's notes with validated notes built from ``notes``.

    Each note must provide a non-empty ``text`` and a ``polygon`` of at
    least three points. A point is a pair of numeric coordinates, each in
    the 0..1 range.

    The existing notes are cleared before validation starts, so a
    validation error leaves the post with the notes accepted so far.

    Raises:
        InvalidPostNoteError: a note is missing a field, has an empty or
            too long text, or has a malformed polygon.
    """
    assert post
    post.notes = []
    for note in notes:
        for field in ("polygon", "text"):
            if field not in note:
                raise InvalidPostNoteError("Note is missing %r field." % field)
        if not note["text"]:
            raise InvalidPostNoteError("A note's text cannot be empty.")
        if not isinstance(note["polygon"], (list, tuple)):
            raise InvalidPostNoteError(
                "A note's polygon must be a list of points."
            )
        if len(note["polygon"]) < 3:
            raise InvalidPostNoteError(
                "A note's polygon must have at least 3 points."
            )
        for point in note["polygon"]:
            if not isinstance(point, (list, tuple)):
                raise InvalidPostNoteError(
                    "A note's polygon point must be a list of length 2."
                )
            if len(point) != 2:
                raise InvalidPostNoteError(
                    "A point in note's polygon must have two coordinates."
                )
            try:
                pos_x = float(point[0])
                pos_y = float(point[1])
            except (TypeError, ValueError):
                # float() raises TypeError, not ValueError, for values
                # such as None or a nested list.
                raise InvalidPostNoteError(
                    "A point in note's polygon must be numeric."
                )
            if not 0 <= pos_x <= 1 or not 0 <= pos_y <= 1:
                raise InvalidPostNoteError(
                    "All points must fit in the image (0..1 range)."
                )
        if util.value_exceeds_column_size(note["text"], model.PostNote.text):
            raise InvalidPostNoteError("Note text is too long.")
        post.notes.append(
            model.PostNote(polygon=note["polygon"], text=str(note["text"]))
        )
2016-04-30 21:17:08 +00:00
def update_post_flags(post: model.Post, flags: List[str]) -> None:
    """Replace the post's flags with those named in ``flags``.

    Args:
        post: the post to update.
        flags: flag names as listed in the values of ``FLAG_MAP``.

    Raises:
        InvalidPostFlagError: a name is not a known flag. The post's flags
            are left untouched in that case.
    """
    assert post
    # Build the name -> flag lookup once rather than on every iteration.
    name_to_flag = util.flip(FLAG_MAP)
    target_flags = []
    for flag_name in flags:
        flag = name_to_flag.get(flag_name, None)
        if not flag:
            raise InvalidPostFlagError(
                "Flag must be one of %r." % list(FLAG_MAP.values())
            )
        target_flags.append(flag)
    post.flags = target_flags
2016-04-30 21:17:08 +00:00
def feature_post(post: model.Post, user: Optional[model.User]) -> None:
    """Add a feature record for the post, stamped with the current UTC time.

    ``user`` is stored on the record as the featuring user.
    """
    assert post
    feature_record = model.PostFeature()
    feature_record.time = datetime.utcnow()
    feature_record.post = post
    feature_record.user = user
    db.session.add(feature_record)
def delete(post: model.Post) -> None:
    """Mark the post for deletion in the current database session."""
    assert post
    session = db.session
    session.delete(post)
2016-10-21 19:48:08 +00:00
def merge_posts(
    source_post: model.Post, target_post: model.Post, replace_content: bool
) -> None:
    """Merge ``source_post`` into ``target_post`` and delete the source.

    Tags, comments, scores, favorites and relations of the source are
    re-pointed at the target with direct UPDATE statements. With
    ``replace_content``, the source's flags and content are carried over
    to the target as well.

    Raises:
        InvalidPostRelationError: both arguments are the same post.
    """
    assert source_post
    assert target_post
    if source_post.post_id == target_post.post_id:
        raise InvalidPostRelationError("Cannot merge post with itself.")

    def merge_tables(
        table: model.Base,
        anti_dup_func: Optional[Callable[[model.Base, model.Base], bool]],
        source_post_id: int,
        target_post_id: int,
    ) -> None:
        # Re-point rows of `table` from the source post to the target post.
        # With anti_dup_func, a row is skipped when the target post already
        # has a row that anti_dup_func considers equivalent.
        alias1 = table
        alias2 = sa.orm.util.aliased(table)
        update_stmt = sa.sql.expression.update(alias1).where(
            alias1.post_id == source_post_id
        )

        if anti_dup_func is not None:
            update_stmt = update_stmt.where(
                ~sa.exists()
                .where(anti_dup_func(alias1, alias2))
                .where(alias2.post_id == target_post_id)
            )

        update_stmt = update_stmt.values(post_id=target_post_id)
        db.session.execute(update_stmt)

    def merge_tags(source_post_id: int, target_post_id: int) -> None:
        # Move tag links, skipping tags the target already has.
        merge_tables(
            model.PostTag,
            lambda alias1, alias2: alias1.tag_id == alias2.tag_id,
            source_post_id,
            target_post_id,
        )

    def merge_scores(source_post_id: int, target_post_id: int) -> None:
        # Move scores, skipping users who already scored the target.
        merge_tables(
            model.PostScore,
            lambda alias1, alias2: alias1.user_id == alias2.user_id,
            source_post_id,
            target_post_id,
        )

    def merge_favorites(source_post_id: int, target_post_id: int) -> None:
        # Move favorites, skipping users who already favorited the target.
        merge_tables(
            model.PostFavorite,
            lambda alias1, alias2: alias1.user_id == alias2.user_id,
            source_post_id,
            target_post_id,
        )

    def merge_comments(source_post_id: int, target_post_id: int) -> None:
        # Comments are all moved; no duplicate check is applied.
        merge_tables(model.Comment, None, source_post_id, target_post_id)

    def merge_relations(source_post_id: int, target_post_id: int) -> None:
        alias1 = model.PostRelation
        alias2 = sa.orm.util.aliased(model.PostRelation)
        # Rows where the source is the parent: re-parent them to the
        # target, except rows pointing at the target itself or duplicating
        # a relation the target already has.
        update_stmt = (
            sa.sql.expression.update(alias1)
            .where(alias1.parent_id == source_post_id)
            .where(alias1.child_id != target_post_id)
            .where(
                ~sa.exists()
                .where(alias2.child_id == alias1.child_id)
                .where(alias2.parent_id == target_post_id)
            )
            .values(parent_id=target_post_id)
        )
        db.session.execute(update_stmt)

        # Rows where the source is the child: the mirror image of the
        # statement above.
        update_stmt = (
            sa.sql.expression.update(alias1)
            .where(alias1.child_id == source_post_id)
            .where(alias1.parent_id != target_post_id)
            .where(
                ~sa.exists()
                .where(alias2.parent_id == alias1.parent_id)
                .where(alias2.child_id == target_post_id)
            )
            .values(child_id=target_post_id)
        )
        db.session.execute(update_stmt)

    merge_tags(source_post.post_id, target_post.post_id)
    merge_comments(source_post.post_id, target_post.post_id)
    merge_scores(source_post.post_id, target_post.post_id)
    merge_favorites(source_post.post_id, target_post.post_id)
    merge_relations(source_post.post_id, target_post.post_id)

    def transfer_flags(source_post_id: int, target_post_id: int) -> None:
        # Copy the source post's flags onto the target and flush.
        target = get_post_by_id(target_post_id)
        source = get_post_by_id(source_post_id)
        target.flags = source.flags
        db.session.flush()

    # Read the source content before the source post is deleted below.
    content = None
    if replace_content:
        content = files.get(get_post_content_path(source_post))
        transfer_flags(source_post.post_id, target_post.post_id)

    # fixes unknown issue with SA's cascade deletions
    purge_post_signature(source_post)
    delete(source_post)
    db.session.flush()

    # Applied only after the source is flushed away; update_post_content
    # rejects content whose checksum still belongs to another post.
    if content is not None:
        update_post_content(target_post, content)
def search_by_image_exact(image_content: bytes) -> Optional[model.Post]:
    """Return the post whose checksum equals that of the content, if any."""
    checksum = util.get_sha1(image_content)
    matching_posts = db.session.query(model.Post).filter(
        model.Post.checksum == checksum
    )
    return matching_posts.one_or_none()
def search_by_image(image_content: bytes) -> List[Tuple[float, model.Post]]:
    """Return posts whose signature is close to that of the given image.

    Up to 100 candidates are picked in the database by counting signature
    words shared with the query. They are then filtered by normalized
    signature distance against ``image_hash.DISTANCE_CUTOFF``.

    Returns:
        ``(distance, post)`` pairs for the candidates below the cutoff.
    """
    query_signature = image_hash.generate_signature(image_content)
    query_words = image_hash.generate_words(query_signature)

    # The unnest function is used here to expand one row containing the
    # 'words' array into multiple rows each containing a singular word.
    # Documentation of the unnest function can be found here:
    # https://www.postgresql.org/docs/9.2/functions-array.html

    dbquery = """
SELECT s.post_id, s.signature, count(a.query) AS score
FROM post_signature AS s, unnest(s.words, :q) AS a(word, query)
WHERE a.word = a.query
GROUP BY s.post_id
ORDER BY score DESC LIMIT 100;
"""

    candidates = db.session.execute(dbquery, {"q": query_words})
    found_ids = []
    found_signatures = []
    for post_id, packedsig, _score in candidates:
        found_ids.append(post_id)
        found_signatures.append(image_hash.unpack_signature(packedsig))
    if not found_ids:
        return []

    candidate_post_ids = tuple(found_ids)
    sigarray = tuple(found_signatures)
    distances = image_hash.normalized_distance(sigarray, query_signature)
    close_matches = []
    for candidate_post_id, distance in zip(candidate_post_ids, distances):
        if distance < image_hash.DISTANCE_CUTOFF:
            close_matches.append(
                (distance, try_get_post_by_id(candidate_post_id))
            )
    return close_matches