gallery.accords-library.com/server/szurubooru/func/posts.py

789 lines
25 KiB
Python
Raw Normal View History

import hmac
from typing import Any, Optional, Tuple, List, Dict, Callable
from datetime import datetime
import sqlalchemy as sa
from szurubooru import config, db, model, errors, rest
2016-04-30 21:17:08 +00:00
from szurubooru.func import (
users, scores, comments, tags, util,
mime, images, files, image_hash, serialization, snapshots)
2016-04-30 21:17:08 +00:00
2017-04-24 21:30:53 +00:00
# A minimal 1x1 GIF89a image (header 'GIF89a', logical screen 1x1).
# generate_post_thumbnail() saves it as the thumbnail whenever the real
# thumbnail cannot be produced (images.Image raises ProcessingError).
EMPTY_PIXEL = (
b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00'
b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00'
b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b')
2016-04-22 18:58:04 +00:00
class PostNotFoundError(errors.NotFoundError):
    """Raised when a post looked up by id does not exist."""


class PostAlreadyFeaturedError(errors.ValidationError):
    """Raised when featuring a post that is already the featured post."""
class PostAlreadyUploadedError(errors.ValidationError):
    """Raised when uploaded content has the same checksum as another post.

    The error payload carries the URL and id of the existing post so the
    client can point the user at it.
    """

    def __init__(self, other_post: model.Post) -> None:
        message = 'Post already uploaded (%d)' % other_post.post_id
        extra_fields = {
            'otherPostUrl': get_post_content_url(other_post),
            'otherPostId': other_post.post_id,
        }
        super().__init__(message, extra_fields)
class InvalidPostIdError(errors.ValidationError):
    """Raised for a malformed post id."""


class InvalidPostSafetyError(errors.ValidationError):
    """Raised when a safety name is not one of SAFETY_MAP's values."""


class InvalidPostSourceError(errors.ValidationError):
    """Raised when a post source does not fit its database column."""


class InvalidPostContentError(errors.ValidationError):
    """Raised when post content is missing or of an unsupported type."""


class InvalidPostRelationError(errors.ValidationError):
    """Raised for invalid relation ids or an invalid merge request."""


class InvalidPostNoteError(errors.ValidationError):
    """Raised when a note's polygon or text fails validation."""


class InvalidPostFlagError(errors.ValidationError):
    """Raised when a flag name is not one of FLAG_MAP's values."""
2016-04-30 21:17:08 +00:00
class PostLookalike(image_hash.Lookalike):
    """An image-hash search hit, with the matching post object attached.

    The base Lookalike is keyed by the post id; the full post is kept on
    ``self.post`` so callers need not look it up again.
    """

    def __init__(self, score: int, distance: float, post: model.Post) -> None:
        super().__init__(score, distance, post.post_id)
        self.post = post
2016-04-30 21:17:08 +00:00
# Internal post safety value -> name exposed through the API
# (serialize_safety reads it; update_post_safety reverses it for input).
SAFETY_MAP = {
    model.Post.SAFETY_SAFE: 'safe',
    model.Post.SAFETY_SKETCHY: 'sketchy',
    model.Post.SAFETY_UNSAFE: 'unsafe',
}

# Internal post type -> name exposed through the API (see serialize_type).
TYPE_MAP = {
    model.Post.TYPE_IMAGE: 'image',
    model.Post.TYPE_ANIMATION: 'animation',
    model.Post.TYPE_VIDEO: 'video',
    model.Post.TYPE_FLASH: 'flash',
}

# Internal post flag -> name accepted by update_post_flags.
FLAG_MAP = {
    model.Post.FLAG_LOOP: 'loop',
    model.Post.FLAG_SOUND: 'sound',
}
2016-04-30 21:17:08 +00:00
def get_post_security_hash(id: int) -> str:
    """Return a 16-hex-character secret token derived from a post id.

    The token is the HMAC of ``str(id)`` keyed with ``config['secret']`` and
    is embedded in post file names and URLs so they cannot be guessed from
    the id alone.

    ``digestmod`` is passed explicitly: ``hmac.new`` has required it since
    Python 3.8 (omitting it raises TypeError).  MD5 was the implicit default
    before, so using it keeps every existing file name/URL unchanged.
    """
    return hmac.new(
        config.config['secret'].encode('utf8'),
        msg=str(id).encode('utf-8'),
        digestmod='md5').hexdigest()[0:16]
def get_post_content_url(post: model.Post) -> str:
    """Return the public URL of the post's uploaded file.

    The extension comes from the post's MIME type, falling back to ``dat``.
    """
    assert post
    data_url = config.config['data_url'].rstrip('/')
    token = get_post_security_hash(post.post_id)
    extension = mime.get_extension(post.mime_type) or 'dat'
    return '%s/posts/%d_%s.%s' % (data_url, post.post_id, token, extension)
def get_post_thumbnail_url(post: model.Post) -> str:
    """Return the public URL of the post's generated JPEG thumbnail."""
    assert post
    data_url = config.config['data_url'].rstrip('/')
    token = get_post_security_hash(post.post_id)
    return '%s/generated-thumbnails/%d_%s.jpg' % (
        data_url, post.post_id, token)
2016-04-30 21:17:08 +00:00
def get_post_content_path(post: model.Post) -> str:
    """Return the storage path (relative to the data dir) of the post file.

    Requires a persisted post, since the path is built from its id.
    """
    assert post
    assert post.post_id
    token = get_post_security_hash(post.post_id)
    extension = mime.get_extension(post.mime_type) or 'dat'
    return 'posts/%d_%s.%s' % (post.post_id, token, extension)
2016-04-30 21:17:08 +00:00
def get_post_thumbnail_path(post: model.Post) -> str:
    """Return the storage path of the post's generated JPEG thumbnail."""
    assert post
    token = get_post_security_hash(post.post_id)
    return 'generated-thumbnails/%d_%s.jpg' % (post.post_id, token)
2016-04-30 21:17:08 +00:00
def get_post_thumbnail_backup_path(post: model.Post) -> str:
    """Return the storage path of the user-supplied (custom) thumbnail."""
    assert post
    token = get_post_security_hash(post.post_id)
    return 'posts/custom-thumbnails/%d_%s.dat' % (post.post_id, token)
2016-04-30 21:17:08 +00:00
def serialize_note(note: model.PostNote) -> rest.Response:
    """Return the API representation (polygon + text) of a post note."""
    assert note
    return dict(polygon=note.polygon, text=note.text)
class PostSerializer(serialization.BaseSerializer):
    """Build the API representation of a post as seen by ``auth_user``.

    ``_serializers`` maps every response field name to the method that
    produces it.  Presumably ``BaseSerializer.serialize(options)`` emits only
    the requested fields (``serialize_micro_post`` asks for just ``id`` and
    ``thumbnailUrl``) -- confirm in ``serialization``.
    """

    def __init__(self, post: model.Post, auth_user: model.User) -> None:
        self.post = post
        # The viewer; used for per-user fields (ownScore, ownFavorite) and
        # when serializing nested users/comments/relations.
        self.auth_user = auth_user

    def _serializers(self) -> Dict[str, Callable[[], Any]]:
        return {
            'id': self.serialize_id,
            'version': self.serialize_version,
            'creationTime': self.serialize_creation_time,
            'lastEditTime': self.serialize_last_edit_time,
            'safety': self.serialize_safety,
            'source': self.serialize_source,
            'type': self.serialize_type,
            'mimeType': self.serialize_mime,
            'checksum': self.serialize_checksum,
            'fileSize': self.serialize_file_size,
            'canvasWidth': self.serialize_canvas_width,
            'canvasHeight': self.serialize_canvas_height,
            'contentUrl': self.serialize_content_url,
            'thumbnailUrl': self.serialize_thumbnail_url,
            'flags': self.serialize_flags,
            'tags': self.serialize_tags,
            'relations': self.serialize_relations,
            'user': self.serialize_user,
            'score': self.serialize_score,
            'ownScore': self.serialize_own_score,
            'ownFavorite': self.serialize_own_favorite,
            'tagCount': self.serialize_tag_count,
            'favoriteCount': self.serialize_favorite_count,
            'commentCount': self.serialize_comment_count,
            'noteCount': self.serialize_note_count,
            'relationCount': self.serialize_relation_count,
            'featureCount': self.serialize_feature_count,
            'lastFeatureTime': self.serialize_last_feature_time,
            'favoritedBy': self.serialize_favorited_by,
            'hasCustomThumbnail': self.serialize_has_custom_thumbnail,
            'notes': self.serialize_notes,
            'comments': self.serialize_comments,
        }

    def serialize_id(self) -> Any:
        return self.post.post_id

    def serialize_version(self) -> Any:
        return self.post.version

    def serialize_creation_time(self) -> Any:
        return self.post.creation_time

    def serialize_last_edit_time(self) -> Any:
        return self.post.last_edit_time

    def serialize_safety(self) -> Any:
        # Internal safety value -> API name.
        return SAFETY_MAP[self.post.safety]

    def serialize_source(self) -> Any:
        return self.post.source

    def serialize_type(self) -> Any:
        # Internal post type -> API name.
        return TYPE_MAP[self.post.type]

    def serialize_mime(self) -> Any:
        return self.post.mime_type

    def serialize_checksum(self) -> Any:
        return self.post.checksum

    def serialize_file_size(self) -> Any:
        return self.post.file_size

    def serialize_canvas_width(self) -> Any:
        return self.post.canvas_width

    def serialize_canvas_height(self) -> Any:
        return self.post.canvas_height

    def serialize_content_url(self) -> Any:
        return get_post_content_url(self.post)

    def serialize_thumbnail_url(self) -> Any:
        return get_post_thumbnail_url(self.post)

    def serialize_flags(self) -> Any:
        return self.post.flags

    def serialize_tags(self) -> Any:
        # Each tag as all of its names, its category name and usage count,
        # in the order produced by tags.sort_tags.
        return [
            {
                'names': [name.name for name in tag.names],
                'category': tag.category.name,
                'usages': tag.post_count,
            }
            for tag in tags.sort_tags(self.post.tags)]

    def serialize_relations(self) -> Any:
        # Keyed by id first so a related post listed twice appears once;
        # the result is ordered by post id.
        return sorted(
            {
                post['id']: post
                for post in [
                    serialize_micro_post(rel, self.auth_user)
                    for rel in self.post.relations]
            }.values(),
            key=lambda post: post['id'])

    def serialize_user(self) -> Any:
        return users.serialize_micro_user(self.post.user, self.auth_user)

    def serialize_score(self) -> Any:
        return self.post.score

    def serialize_own_score(self) -> Any:
        return scores.get_score(self.post, self.auth_user)

    def serialize_own_favorite(self) -> Any:
        # True when one of the post's favorites belongs to the viewer;
        # any() stops at the first match instead of building a list.
        return any(
            favorite.user_id == self.auth_user.user_id
            for favorite in self.post.favorited_by)

    def serialize_tag_count(self) -> Any:
        return self.post.tag_count

    def serialize_favorite_count(self) -> Any:
        return self.post.favorite_count

    def serialize_comment_count(self) -> Any:
        return self.post.comment_count

    def serialize_note_count(self) -> Any:
        return self.post.note_count

    def serialize_relation_count(self) -> Any:
        return self.post.relation_count

    def serialize_feature_count(self) -> Any:
        return self.post.feature_count

    def serialize_last_feature_time(self) -> Any:
        return self.post.last_feature_time

    def serialize_favorited_by(self) -> Any:
        return [
            users.serialize_micro_user(rel.user, self.auth_user)
            for rel in self.post.favorited_by
        ]

    def serialize_has_custom_thumbnail(self) -> Any:
        return files.has(get_post_thumbnail_backup_path(self.post))

    def serialize_notes(self) -> Any:
        return sorted(
            [serialize_note(note) for note in self.post.notes],
            key=lambda x: x['polygon'])

    def serialize_comments(self) -> Any:
        # Oldest comment first.
        return [
            comments.serialize_comment(comment, self.auth_user)
            for comment in sorted(
                self.post.comments,
                key=lambda comment: comment.creation_time)]
def serialize_post(
        post: Optional[model.Post],
        auth_user: model.User,
        options: Optional[List[str]] = None) -> Optional[rest.Response]:
    """Serialize ``post`` for ``auth_user``; return None when post is falsy.

    ``options`` lists the response field names to include; omitting it
    passes an empty list, exactly as the old ``[]`` default did.  The
    default is now ``None`` so no list object is shared between calls.
    """
    if not post:
        return None
    return PostSerializer(post, auth_user).serialize(
        options if options is not None else [])
def serialize_micro_post(
        post: model.Post, auth_user: model.User) -> Optional[rest.Response]:
    """Serialize only what is needed to link to a post (id + thumbnail)."""
    micro_fields = ['id', 'thumbnailUrl']
    return serialize_post(post, auth_user=auth_user, options=micro_fields)
def get_post_count() -> int:
    """Return the total number of posts in the database."""
    count_query = db.session.query(sa.func.count(model.Post.post_id))
    row = count_query.one()
    return row[0]
2016-04-22 18:58:04 +00:00
def try_get_post_by_id(post_id: int) -> Optional[model.Post]:
    """Return the post with ``post_id``, or None if there is none."""
    matching = db.session.query(model.Post).filter(
        model.Post.post_id == post_id)
    return matching.one_or_none()
2016-04-22 18:58:04 +00:00
def get_post_by_id(post_id: int) -> model.Post:
    """Return the post with ``post_id``; raise PostNotFoundError if absent."""
    found = try_get_post_by_id(post_id)
    if found:
        return found
    raise PostNotFoundError('Post %r not found.' % post_id)
def try_get_current_post_feature() -> Optional[model.PostFeature]:
    """Return the most recent PostFeature (by time), or None if none exist."""
    newest_first = model.PostFeature.time.desc()
    return db.session.query(model.PostFeature).order_by(newest_first).first()
def try_get_featured_post() -> Optional[model.Post]:
    """Return the post of the most recent feature, or None."""
    current_feature = try_get_current_post_feature()
    if not current_feature:
        return None
    return current_feature.post
def create_post(
        content: bytes,
        tag_names: List[str],
        user: Optional[model.User]) -> Tuple[model.Post, List[model.Tag]]:
    """Create a new safe post owned by ``user`` and add it to the session.

    Returns the post and the tags that had to be newly created for
    ``tag_names``.  Content validation errors from update_post_content
    propagate to the caller.
    """
    new_post = model.Post()
    new_post.safety = model.Post.SAFETY_SAFE
    new_post.user = user
    new_post.creation_time = datetime.utcnow()
    new_post.flags = []
    # Placeholders; update_post_content() below sets the real values.
    new_post.type = ''
    new_post.checksum = ''
    new_post.mime_type = ''
    db.session.add(new_post)
    update_post_content(new_post, content)
    created_tags = update_post_tags(new_post, tag_names)
    return new_post, created_tags
2016-04-30 21:17:08 +00:00
def update_post_safety(post: model.Post, safety: str) -> None:
    """Set the post's safety from its API name (a value of SAFETY_MAP).

    Raises InvalidPostSafetyError for an unknown name.
    """
    assert post
    internal_safety = util.flip(SAFETY_MAP).get(safety, None)
    if not internal_safety:
        raise InvalidPostSafetyError(
            'Safety can be either of %r.' % list(SAFETY_MAP.values()))
    post.safety = internal_safety
def update_post_source(post: model.Post, source: Optional[str]) -> None:
    """Set the post's source; an empty value is stored as None.

    Raises InvalidPostSourceError when the value is too long for its column.
    """
    assert post
    if util.value_exceeds_column_size(source, model.Post.source):
        raise InvalidPostSourceError('Source is too long.')
    post.source = source if source else None
2016-04-30 21:17:08 +00:00
@sa.events.event.listens_for(model.Post, 'after_insert')
def _after_post_insert(
        _mapper: Any, _connection: Any, post: model.Post) -> None:
    """Once a post row is inserted, write out its staged content/thumbnail."""
    _sync_post_content(post)
@sa.events.event.listens_for(model.Post, 'after_update')
def _after_post_update(
        _mapper: Any, _connection: Any, post: model.Post) -> None:
    """Once a post row is updated, write out its staged content/thumbnail."""
    _sync_post_content(post)
@sa.events.event.listens_for(model.Post, 'before_delete')
def _before_post_delete(
        _mapper: Any, _connection: Any, post: model.Post) -> None:
    """Before a persisted post is deleted, drop its image-hash entry and,
    when ``delete_source_files`` is enabled, its content and thumbnail files.
    """
    if not post.post_id:
        return
    image_hash.delete_image(post.post_id)
    if not config.config['delete_source_files']:
        return
    for stale_path in (
            get_post_content_path(post), get_post_thumbnail_path(post)):
        files.delete(stale_path)
def _sync_post_content(post: model.Post) -> None:
    """Flush content/thumbnail bytes staged on the post instance to storage.

    update_post_content stages bytes as the ``__content`` attribute and
    update_post_thumbnail stages ``__thumbnail`` (falsy = remove the custom
    thumbnail).  Each staged attribute is written out and then removed; if
    either was present the generated thumbnail is rebuilt.
    """
    thumbnail_stale = False

    if hasattr(post, '__content'):
        new_content = getattr(post, '__content')
        files.save(get_post_content_path(post), new_content)
        delattr(post, '__content')
        thumbnail_stale = True
        # Only still images and animations are kept in the reverse-image
        # search index; replace any previous entry for this post.
        if post.post_id and post.type in (
                model.Post.TYPE_IMAGE, model.Post.TYPE_ANIMATION):
            image_hash.delete_image(post.post_id)
            image_hash.add_image(post.post_id, new_content)

    if hasattr(post, '__thumbnail'):
        backup_path = get_post_thumbnail_backup_path(post)
        custom_thumbnail = getattr(post, '__thumbnail')
        if custom_thumbnail:
            files.save(backup_path, custom_thumbnail)
        else:
            files.delete(backup_path)
        delattr(post, '__thumbnail')
        thumbnail_stale = True

    if thumbnail_stale:
        generate_post_thumbnail(post)
def _create_alternate_format_post(
        post: model.Post,
        converted_content: bytes,
        tag_names: List[str]) -> Tuple[model.Post, List[model.Tag]]:
    """Create a looping post holding ``converted_content`` that copies
    ``post``'s uploader, tags, safety and source."""
    new_post, new_tags = create_post(converted_content, tag_names, post.user)
    update_post_flags(new_post, ['loop'])
    # update_post_safety expects the API name, so translate the internal
    # value (the same lookup serialize_safety performs).
    update_post_safety(new_post, SAFETY_MAP[post.safety])
    update_post_source(new_post, post.source)
    return new_post, new_tags


def generate_alternate_formats(post: model.Post, content: bytes) \
        -> List[Tuple[model.Post, List[model.Tag]]]:
    """Create MP4/WebM copies of an animated GIF as new related posts.

    Which formats are produced is controlled by
    ``config['convert']['gif']['to_mp4' / 'to_webm']``.  The new posts are
    flushed (to obtain ids) and added to ``post``'s relations, keeping the
    relations ``post`` already had.  Returns a list of
    ``(new_post, newly_created_tags)`` pairs; empty for non-GIF content.
    """
    assert post
    assert content
    new_posts = []
    if mime.is_animated_gif(content):
        # One name per tag is enough to look the tag up again.  (tag.names is
        # a list of name rows; the previous nested comprehension called
        # ``.name`` on that list and raised AttributeError.)
        tag_names = [tag.names[0].name for tag in post.tags if tag.names]
        gif_config = config.config['convert']['gif']
        if gif_config['to_mp4']:
            new_posts.append(_create_alternate_format_post(
                post, images.Image(content).to_mp4(), tag_names))
        if gif_config['to_webm']:
            new_posts.append(_create_alternate_format_post(
                post, images.Image(content).to_webm(), tag_names))
        # Flush so the new posts get ids we can relate to.
        db.session.flush()
        new_posts = [p for p in new_posts if p[0] is not None]
        new_relations = [p[0].post_id for p in new_posts]
        if new_relations:
            # update_post_relations replaces the whole relation set, so pass
            # the existing relations too instead of silently dropping them.
            existing_relations = [rel.post_id for rel in post.relations]
            update_post_relations(post, existing_relations + new_relations)
    return new_posts
def update_post_content(post: model.Post, content: Optional[bytes]) -> None:
    """Validate new file content for ``post`` and stage it for saving.

    Sets mime_type, type, checksum, file_size and canvas dimensions (None
    when the content is not a decodable image or reports non-positive
    dimensions).  The bytes are stashed on the instance as ``__content`` and
    written to storage by the after_insert/after_update listeners.

    Raises InvalidPostContentError for missing or unsupported content and
    PostAlreadyUploadedError when another post has the same checksum.
    """
    assert post
    if not content:
        raise InvalidPostContentError('Post content missing.')

    post.mime_type = mime.get_mime_type(content)
    if mime.is_flash(post.mime_type):
        post.type = model.Post.TYPE_FLASH
    elif mime.is_image(post.mime_type):
        if mime.is_animated_gif(content):
            post.type = model.Post.TYPE_ANIMATION
        else:
            post.type = model.Post.TYPE_IMAGE
    elif mime.is_video(post.mime_type):
        post.type = model.Post.TYPE_VIDEO
    else:
        raise InvalidPostContentError(
            'Unhandled file type: %r' % post.mime_type)

    post.checksum = util.get_sha1(content)
    # first() rather than one_or_none(): if several posts already share this
    # checksum (legacy data, races), one_or_none() would raise
    # MultipleResultsFound instead of reporting the duplicate.
    other_post = (
        db.session
        .query(model.Post)
        .filter(model.Post.checksum == post.checksum)
        .filter(model.Post.post_id != post.post_id)
        .order_by(model.Post.post_id.asc())
        .first())
    if other_post \
            and other_post.post_id \
            and other_post.post_id != post.post_id:
        raise PostAlreadyUploadedError(other_post)

    post.file_size = len(content)
    try:
        image = images.Image(content)
        post.canvas_width = image.width
        post.canvas_height = image.height
    except errors.ProcessingError:
        post.canvas_width = None
        post.canvas_height = None
    # Treat nonsensical dimensions as unknown.
    if (post.canvas_width is not None and post.canvas_width <= 0) \
            or (post.canvas_height is not None and post.canvas_height <= 0):
        post.canvas_width = None
        post.canvas_height = None
    setattr(post, '__content', content)
2016-04-30 21:17:08 +00:00
def update_post_thumbnail(
        post: model.Post, content: Optional[bytes] = None) -> None:
    """Stage a custom thumbnail for ``post``.

    The bytes are stored as the ``__thumbnail`` attribute and written out
    when the post is next inserted/updated; a falsy ``content`` removes the
    existing custom thumbnail at that point instead.
    """
    assert post
    setattr(post, '__thumbnail', content)
2016-04-30 21:17:08 +00:00
def generate_post_thumbnail(post: model.Post) -> None:
    """(Re)build the post's JPEG thumbnail.

    The source is the custom thumbnail when one exists, otherwise the post
    content; it is resized to the configured post thumbnail size.  If image
    processing fails, EMPTY_PIXEL is saved instead.
    """
    assert post
    backup_path = get_post_thumbnail_backup_path(post)
    if files.has(backup_path):
        source_bytes = files.get(backup_path)
    else:
        source_bytes = files.get(get_post_content_path(post))

    thumbnail_path = get_post_thumbnail_path(post)
    thumbnail_config = config.config['thumbnails']
    try:
        assert source_bytes
        thumbnail = images.Image(source_bytes)
        thumbnail.resize_fill(
            int(thumbnail_config['post_width']),
            int(thumbnail_config['post_height']))
        files.save(thumbnail_path, thumbnail.to_jpeg())
    except errors.ProcessingError:
        files.save(thumbnail_path, EMPTY_PIXEL)
def update_post_tags(
        post: model.Post, tag_names: List[str]) -> List[model.Tag]:
    """Replace the post's tags with those named in ``tag_names``.

    Unknown names are created; the newly created tags are returned.
    """
    assert post
    found_tags, created_tags = tags.get_or_create_tags_by_names(tag_names)
    post.tags = found_tags + created_tags
    return created_tags
2016-04-30 21:17:08 +00:00
def update_post_relations(post: model.Post, new_post_ids: List[int]) -> None:
    """Make ``post``'s relations exactly the posts in ``new_post_ids``.

    Relations are kept symmetric: each added/removed related post gets
    ``post`` added to/removed from its own relations as well.

    Raises InvalidPostRelationError when an id is not numeric, refers to a
    missing post, or refers to ``post`` itself.
    """
    assert post
    try:
        new_post_ids = [int(post_id) for post_id in new_post_ids]
    except (TypeError, ValueError):
        # TypeError covers values such as None or nested lists, which
        # previously escaped as an unhandled error.
        raise InvalidPostRelationError(
            'A relation must be numeric post ID.')
    # De-duplicate (order kept) so a repeated id is not mistaken for a
    # missing post by the count comparison below.
    new_post_ids = list(dict.fromkeys(new_post_ids))
    new_id_set = set(new_post_ids)

    old_posts = post.relations
    old_post_ids = {int(p.post_id) for p in old_posts}
    if new_post_ids:
        new_posts = (
            db.session
            .query(model.Post)
            .filter(model.Post.post_id.in_(new_post_ids))
            .all())
    else:
        new_posts = []
    if len(new_posts) != len(new_post_ids):
        raise InvalidPostRelationError('One of relations does not exist.')
    if post.post_id in new_id_set:
        raise InvalidPostRelationError('Post cannot relate to itself.')

    # Materialize both lists before mutating post.relations.
    relations_to_del = [p for p in old_posts if p.post_id not in new_id_set]
    relations_to_add = [p for p in new_posts if p.post_id not in old_post_ids]
    for relation in relations_to_del:
        post.relations.remove(relation)
        relation.relations.remove(post)
    for relation in relations_to_add:
        post.relations.append(relation)
        relation.relations.append(post)
2016-04-30 21:17:08 +00:00
def _validate_note_polygon(polygon: Any) -> None:
    """Raise InvalidPostNoteError unless ``polygon`` is a list/tuple of 3+
    numeric ``(x, y)`` points with both coordinates within 0..1."""
    if not isinstance(polygon, (list, tuple)):
        raise InvalidPostNoteError(
            'A note\'s polygon must be a list of points.')
    if len(polygon) < 3:
        raise InvalidPostNoteError(
            'A note\'s polygon must have at least 3 points.')
    for point in polygon:
        if not isinstance(point, (list, tuple)):
            raise InvalidPostNoteError(
                'A note\'s polygon point must be a list of length 2.')
        if len(point) != 2:
            raise InvalidPostNoteError(
                'A point in note\'s polygon must have two coordinates.')
        try:
            pos_x = float(point[0])
            pos_y = float(point[1])
        except (TypeError, ValueError):
            # TypeError: float(None), float([..]) etc. used to escape
            # unhandled; both now report the same validation error.
            raise InvalidPostNoteError(
                'A point in note\'s polygon must be numeric.')
        if not 0 <= pos_x <= 1 or not 0 <= pos_y <= 1:
            raise InvalidPostNoteError(
                'All points must fit in the image (0..1 range).')


def update_post_notes(post: model.Post, notes: Any) -> None:
    """Replace the post's notes with validated ``notes``.

    Each note must provide a non-empty ``text`` and a ``polygon`` of at
    least three points in relative (0..1) image coordinates.  Raises
    InvalidPostNoteError on the first invalid note.
    """
    assert post
    post.notes = []
    for note in notes:
        for field in ('polygon', 'text'):
            if field not in note:
                raise InvalidPostNoteError('Note is missing %r field.' % field)
        if not note['text']:
            raise InvalidPostNoteError('A note\'s text cannot be empty.')
        _validate_note_polygon(note['polygon'])
        # Check the length of the string that is actually stored.
        text = str(note['text'])
        if util.value_exceeds_column_size(text, model.PostNote.text):
            raise InvalidPostNoteError('Note text is too long.')
        post.notes.append(model.PostNote(polygon=note['polygon'], text=text))
2016-04-30 21:17:08 +00:00
def update_post_flags(post: model.Post, flags: List[str]) -> None:
    """Replace the post's flags, given as API names (values of FLAG_MAP).

    Raises InvalidPostFlagError on the first unknown name.
    """
    assert post
    name_to_flag = util.flip(FLAG_MAP)
    resolved_flags = []
    for flag_name in flags:
        internal_flag = name_to_flag.get(flag_name, None)
        if not internal_flag:
            raise InvalidPostFlagError(
                'Flag must be one of %r.' % list(FLAG_MAP.values()))
        resolved_flags.append(internal_flag)
    post.flags = resolved_flags
2016-04-30 21:17:08 +00:00
def feature_post(post: model.Post, user: Optional[model.User]) -> None:
    """Record (in the session) that ``user`` featured ``post`` right now."""
    assert post
    feature = model.PostFeature()
    feature.time = datetime.utcnow()
    feature.post = post
    feature.user = user
    db.session.add(feature)
def delete(post: model.Post) -> None:
    """Mark ``post`` for deletion in the current database session."""
    assert post
    db.session.delete(post)
2016-10-21 19:48:08 +00:00
def merge_posts(
        source_post: model.Post,
        target_post: model.Post,
        replace_content: bool) -> None:
    """Fold ``source_post`` into ``target_post`` and delete the source.

    Tags, comments, scores, favorites and relations are re-pointed to the
    target with bulk UPDATE statements.  Rows that would duplicate something
    the target already has (same tag / same user / same related post) are
    left attached to the source -- presumably removed together with it.
    When ``replace_content`` is set, the source's file content is read before
    the source is deleted and then applied to the target.

    Raises InvalidPostRelationError when both posts are the same.
    """
    assert source_post
    assert target_post
    if source_post.post_id == target_post.post_id:
        raise InvalidPostRelationError('Cannot merge post with itself.')

    source_id = source_post.post_id
    target_id = target_post.post_id

    def repoint_rows(
            table: model.Base,
            anti_dup_func: Optional[
                Callable[[model.Base, model.Base], bool]]) -> None:
        # UPDATE table SET post_id = target WHERE post_id = source
        # [AND NOT EXISTS (a target row matching per anti_dup_func)].
        other = sa.orm.util.aliased(table)
        stmt = (
            sa.sql.expression.update(table)
            .where(table.post_id == source_id))
        if anti_dup_func is not None:
            stmt = stmt.where(
                ~sa.exists()
                .where(anti_dup_func(table, other))
                .where(other.post_id == target_id))
        db.session.execute(stmt.values(post_id=target_id))

    def repoint_relations() -> None:
        # Relations are stored as directed (parent, child) rows; move both
        # directions, skipping rows that would link the target to itself or
        # duplicate an existing target relation.
        rel = model.PostRelation
        other = sa.orm.util.aliased(model.PostRelation)
        db.session.execute(
            sa.sql.expression.update(rel)
            .where(rel.parent_id == source_id)
            .where(rel.child_id != target_id)
            .where(
                ~sa.exists()
                .where(other.child_id == rel.child_id)
                .where(other.parent_id == target_id))
            .values(parent_id=target_id))
        db.session.execute(
            sa.sql.expression.update(rel)
            .where(rel.child_id == source_id)
            .where(rel.parent_id != target_id)
            .where(
                ~sa.exists()
                .where(other.parent_id == rel.parent_id)
                .where(other.child_id == target_id))
            .values(child_id=target_id))

    repoint_rows(
        model.PostTag, lambda mine, theirs: mine.tag_id == theirs.tag_id)
    repoint_rows(model.Comment, None)
    repoint_rows(
        model.PostScore, lambda mine, theirs: mine.user_id == theirs.user_id)
    repoint_rows(
        model.PostFavorite,
        lambda mine, theirs: mine.user_id == theirs.user_id)
    repoint_relations()

    if replace_content:
        new_content = files.get(get_post_content_path(source_post))
    else:
        new_content = None

    delete(source_post)
    db.session.flush()

    if new_content is not None:
        update_post_content(target_post, new_content)
def search_by_image_exact(image_content: bytes) -> Optional[model.Post]:
    """Return a post whose content has the same SHA-1 as ``image_content``.

    If several posts share the checksum, the one with the lowest id is
    returned; ``one_or_none()`` would raise MultipleResultsFound there
    (same reasoning as the duplicate check in update_post_content).
    Returns None when there is no match.
    """
    checksum = util.get_sha1(image_content)
    return (
        db.session
        .query(model.Post)
        .filter(model.Post.checksum == checksum)
        .order_by(model.Post.post_id.asc())
        .first())
def search_by_image(image_content: bytes) -> List[PostLookalike]:
    """Return reverse-image-search hits as PostLookalike objects.

    Hits whose post id no longer resolves to a post are skipped; the order
    of image_hash.search_by_image is preserved.
    """
    lookalikes = []
    for hit in image_hash.search_by_image(image_content):
        matched_post = try_get_post_by_id(hit.path)
        if not matched_post:
            continue
        lookalikes.append(PostLookalike(
            score=hit.score,
            distance=hit.distance,
            post=matched_post))
    return lookalikes
def populate_reverse_search() -> None:
    """Add every image/animation post missing from the image-hash index.

    Posts are processed in ascending id order, 100 at a time; posts whose
    content file is missing from storage are skipped.
    """
    excluded_post_ids = image_hash.get_all_paths()

    id_query = (
        db.session
        .query(model.Post.post_id)
        .filter(
            (model.Post.type == model.Post.TYPE_IMAGE) |
            (model.Post.type == model.Post.TYPE_ANIMATION)))
    # Skip the NOT IN filter when nothing is indexed yet: an empty in_() is
    # special-cased (and warned about) by SQLAlchemy.
    if excluded_post_ids:
        id_query = id_query.filter(~model.Post.post_id.in_(excluded_post_ids))
    # query(column).all() yields 1-tuples; unpack to plain ids so the chunked
    # in_() below receives scalars rather than tuple rows.
    post_ids_to_hash = [
        post_id
        for (post_id,) in id_query.order_by(model.Post.post_id.asc()).all()]

    for post_ids_chunk in util.chunks(post_ids_to_hash, 100):
        posts_chunk = (
            db.session
            .query(model.Post)
            .filter(model.Post.post_id.in_(post_ids_chunk))
            .all())
        for post in posts_chunk:
            content_path = get_post_content_path(post)
            if files.has(content_path):
                image_hash.add_image(post.post_id, files.get(content_path))