from typing import Any, Optional, Tuple, List, Dict, Callable
from datetime import datetime
import sqlalchemy as sa
from szurubooru import config, db, model, errors, rest
from szurubooru.func import (
    users, scores, comments, tags, util,
    mime, images, files, image_hash, serialization)


EMPTY_PIXEL = \
    b'\x47\x49\x46\x38\x39\x61\x01\x00\x01\x00\x80\x01\x00\x00\x00\x00' \
    b'\xff\xff\xff\x21\xf9\x04\x01\x00\x00\x01\x00\x2c\x00\x00\x00\x00' \
    b'\x01\x00\x01\x00\x00\x02\x02\x4c\x01\x00\x3b'


class PostNotFoundError(errors.NotFoundError):
    pass


class PostAlreadyFeaturedError(errors.ValidationError):
    pass


class PostAlreadyUploadedError(errors.ValidationError):
    def __init__(self, other_post: model.Post) -> None:
        super().__init__(
            'Post already uploaded (%d)' % other_post.post_id,
            {
                'otherPostUrl': get_post_content_url(other_post),
                'otherPostId': other_post.post_id,
            })


class InvalidPostIdError(errors.ValidationError):
    pass


class InvalidPostSafetyError(errors.ValidationError):
    pass


class InvalidPostSourceError(errors.ValidationError):
    pass


class InvalidPostContentError(errors.ValidationError):
    pass


class InvalidPostRelationError(errors.ValidationError):
    pass


class InvalidPostNoteError(errors.ValidationError):
    pass


class InvalidPostFlagError(errors.ValidationError):
    pass


class PostLookalike(image_hash.Lookalike):
    def __init__(self, score: int, distance: float, post: model.Post) -> None:
        super().__init__(score, distance, post.post_id)
        self.post = post


SAFETY_MAP = {
    model.Post.SAFETY_SAFE: 'safe',
    model.Post.SAFETY_SKETCHY: 'sketchy',
    model.Post.SAFETY_UNSAFE: 'unsafe',
}

TYPE_MAP = {
    model.Post.TYPE_IMAGE: 'image',
    model.Post.TYPE_ANIMATION: 'animation',
    model.Post.TYPE_VIDEO: 'video',
    model.Post.TYPE_FLASH: 'flash',
}

FLAG_MAP = {
    model.Post.FLAG_LOOP: 'loop',
}


def get_post_content_url(post: model.Post) -> str:
    assert post
    return '%s/posts/%d.%s' % (
        config.config['data_url'].rstrip('/'),
        post.post_id,
        mime.get_extension(post.mime_type) or 'dat')


def get_post_thumbnail_url(post: model.Post) -> str:
    assert post
    return '%s/generated-thumbnails/%d.jpg' % (
        config.config['data_url'].rstrip('/'),
        post.post_id)


def get_post_content_path(post: model.Post) -> str:
    assert post
    assert post.post_id
    return 'posts/%d.%s' % (
        post.post_id, mime.get_extension(post.mime_type) or 'dat')


def get_post_thumbnail_path(post: model.Post) -> str:
    assert post
    return 'generated-thumbnails/%d.jpg' % (post.post_id)


def get_post_thumbnail_backup_path(post: model.Post) -> str:
    assert post
    return 'posts/custom-thumbnails/%d.dat' % (post.post_id)


def serialize_note(note: model.PostNote) -> rest.Response:
    assert note
    return {
        'polygon': note.polygon,
        'text': note.text,
    }


class PostSerializer(serialization.BaseSerializer):
    def __init__(self, post: model.Post, auth_user: model.User) -> None:
        self.post = post
        self.auth_user = auth_user

    def _serializers(self) -> Dict[str, Callable[[], Any]]:
        return {
            'id': self.serialize_id,
            'version': self.serialize_version,
            'creationTime': self.serialize_creation_time,
            'lastEditTime': self.serialize_last_edit_time,
            'safety': self.serialize_safety,
            'source': self.serialize_source,
            'type': self.serialize_type,
            'mimeType': self.serialize_mime,
            'checksum': self.serialize_checksum,
            'fileSize': self.serialize_file_size,
            'canvasWidth': self.serialize_canvas_width,
            'canvasHeight': self.serialize_canvas_height,
            'contentUrl': self.serialize_content_url,
            'thumbnailUrl': self.serialize_thumbnail_url,
            'flags': self.serialize_flags,
            'tags': self.serialize_tags,
            'relations': self.serialize_relations,
            'user': self.serialize_user,
            'score': self.serialize_score,
            'ownScore': self.serialize_own_score,
            'ownFavorite': self.serialize_own_favorite,
            'tagCount': self.serialize_tag_count,
            'favoriteCount': self.serialize_favorite_count,
            'commentCount': self.serialize_comment_count,
            'noteCount': self.serialize_note_count,
            'relationCount': self.serialize_relation_count,
            'featureCount': self.serialize_feature_count,
            'lastFeatureTime': self.serialize_last_feature_time,
            'favoritedBy': self.serialize_favorited_by,
            'hasCustomThumbnail': self.serialize_has_custom_thumbnail,
            'notes': self.serialize_notes,
            'comments': self.serialize_comments,
        }

    def serialize_id(self) -> Any:
        return self.post.post_id

    def serialize_version(self) -> Any:
        return self.post.version

    def serialize_creation_time(self) -> Any:
        return self.post.creation_time

    def serialize_last_edit_time(self) -> Any:
        return self.post.last_edit_time

    def serialize_safety(self) -> Any:
        return SAFETY_MAP[self.post.safety]

    def serialize_source(self) -> Any:
        return self.post.source

    def serialize_type(self) -> Any:
        return TYPE_MAP[self.post.type]

    def serialize_mime(self) -> Any:
        return self.post.mime_type

    def serialize_checksum(self) -> Any:
        return self.post.checksum

    def serialize_file_size(self) -> Any:
        return self.post.file_size

    def serialize_canvas_width(self) -> Any:
        return self.post.canvas_width

    def serialize_canvas_height(self) -> Any:
        return self.post.canvas_height

    def serialize_content_url(self) -> Any:
        return get_post_content_url(self.post)

    def serialize_thumbnail_url(self) -> Any:
        return get_post_thumbnail_url(self.post)

    def serialize_flags(self) -> Any:
        return self.post.flags

    def serialize_tags(self) -> Any:
        return [tag.names[0].name for tag in tags.sort_tags(self.post.tags)]

    def serialize_relations(self) -> Any:
        return sorted(
            {
                post['id']: post
                for post in [
                    serialize_micro_post(rel, self.auth_user)
                    for rel in self.post.relations]
            }.values(),
            key=lambda post: post['id'])

    def serialize_user(self) -> Any:
        return users.serialize_micro_user(self.post.user, self.auth_user)

    def serialize_score(self) -> Any:
        return self.post.score

    def serialize_own_score(self) -> Any:
        return scores.get_score(self.post, self.auth_user)

    def serialize_own_favorite(self) -> Any:
        return len([
            user for user in self.post.favorited_by
            if user.user_id == self.auth_user.user_id]
        ) > 0

    def serialize_tag_count(self) -> Any:
        return self.post.tag_count

    def serialize_favorite_count(self) -> Any:
        return self.post.favorite_count

    def serialize_comment_count(self) -> Any:
        return self.post.comment_count

    def serialize_note_count(self) -> Any:
        return self.post.note_count

    def serialize_relation_count(self) -> Any:
        return self.post.relation_count

    def serialize_feature_count(self) -> Any:
        return self.post.feature_count

    def serialize_last_feature_time(self) -> Any:
        return self.post.last_feature_time

    def serialize_favorited_by(self) -> Any:
        return [
            users.serialize_micro_user(rel.user, self.auth_user)
            for rel in self.post.favorited_by
        ]

    def serialize_has_custom_thumbnail(self) -> Any:
        return files.has(get_post_thumbnail_backup_path(self.post))

    def serialize_notes(self) -> Any:
        return sorted(
            [serialize_note(note) for note in self.post.notes],
            key=lambda x: x['polygon'])

    def serialize_comments(self) -> Any:
        return [
            comments.serialize_comment(comment, self.auth_user)
            for comment in sorted(
                self.post.comments,
                key=lambda comment: comment.creation_time)]


def serialize_post(
        post: Optional[model.Post],
        auth_user: model.User,
        options: List[str]=[]) -> Optional[rest.Response]:
    if not post:
        return None
    return PostSerializer(post, auth_user).serialize(options)


def serialize_micro_post(
        post: model.Post, auth_user: model.User) -> Optional[rest.Response]:
    return serialize_post(
        post, auth_user=auth_user, options=['id', 'thumbnailUrl'])


def get_post_count() -> int:
    return db.session.query(sa.func.count(model.Post.post_id)).one()[0]


def try_get_post_by_id(post_id: int) -> Optional[model.Post]:
    return db.session \
        .query(model.Post) \
        .filter(model.Post.post_id == post_id) \
        .one_or_none()


def get_post_by_id(post_id: int) -> model.Post:
    post = try_get_post_by_id(post_id)
    if not post:
        raise PostNotFoundError('Post %r not found.' % post_id)
    return post


def try_get_current_post_feature() -> Optional[model.PostFeature]:
    return db.session \
        .query(model.PostFeature) \
        .order_by(model.PostFeature.time.desc()) \
        .first()


def try_get_featured_post() -> Optional[model.Post]:
    post_feature = try_get_current_post_feature()
    return post_feature.post if post_feature else None


def create_post(
        content: bytes,
        tag_names: List[str],
        user: Optional[model.User]) -> Tuple[model.Post, List[model.Tag]]:
    post = model.Post()
    post.safety = model.Post.SAFETY_SAFE
    post.user = user
    post.creation_time = datetime.utcnow()
    post.flags = []

    post.type = ''
    post.checksum = ''
    post.mime_type = ''
    db.session.add(post)

    update_post_content(post, content)
    new_tags = update_post_tags(post, tag_names)
    return (post, new_tags)


def update_post_safety(post: model.Post, safety: str) -> None:
    assert post
    safety = util.flip(SAFETY_MAP).get(safety, None)
    if not safety:
        raise InvalidPostSafetyError(
            'Safety can be either of %r.' % list(SAFETY_MAP.values()))
    post.safety = safety


def update_post_source(post: model.Post, source: Optional[str]) -> None:
    assert post
    if util.value_exceeds_column_size(source, model.Post.source):
        raise InvalidPostSourceError('Source is too long.')
    post.source = source or None


@sa.events.event.listens_for(model.Post, 'after_insert')
def _after_post_insert(
        _mapper: Any, _connection: Any, post: model.Post) -> None:
    _sync_post_content(post)


@sa.events.event.listens_for(model.Post, 'after_update')
def _after_post_update(
        _mapper: Any, _connection: Any, post: model.Post) -> None:
    _sync_post_content(post)


@sa.events.event.listens_for(model.Post, 'before_delete')
def _before_post_delete(
        _mapper: Any, _connection: Any, post: model.Post) -> None:
    if post.post_id:
        image_hash.delete_image(post.post_id)


def _sync_post_content(post: model.Post) -> None:
    regenerate_thumb = False

    if hasattr(post, '__content'):
        content = getattr(post, '__content')
        files.save(get_post_content_path(post), content)
        delattr(post, '__content')
        regenerate_thumb = True
        if post.post_id and post.type in (
                model.Post.TYPE_IMAGE, model.Post.TYPE_ANIMATION):
            image_hash.delete_image(post.post_id)
            image_hash.add_image(post.post_id, content)

    if hasattr(post, '__thumbnail'):
        if getattr(post, '__thumbnail'):
            files.save(
                get_post_thumbnail_backup_path(post),
                getattr(post, '__thumbnail'))
        else:
            files.delete(get_post_thumbnail_backup_path(post))
        delattr(post, '__thumbnail')
        regenerate_thumb = True

    if regenerate_thumb:
        generate_post_thumbnail(post)


def update_post_content(post: model.Post, content: Optional[bytes]) -> None:
    assert post
    if not content:
        raise InvalidPostContentError('Post content missing.')
    post.mime_type = mime.get_mime_type(content)
    if mime.is_flash(post.mime_type):
        post.type = model.Post.TYPE_FLASH
    elif mime.is_image(post.mime_type):
        if mime.is_animated_gif(content):
            post.type = model.Post.TYPE_ANIMATION
        else:
            post.type = model.Post.TYPE_IMAGE
    elif mime.is_video(post.mime_type):
        post.type = model.Post.TYPE_VIDEO
    else:
        raise InvalidPostContentError(
            'Unhandled file type: %r' % post.mime_type)

    post.checksum = util.get_sha1(content)
    other_post = db.session \
        .query(model.Post) \
        .filter(model.Post.checksum == post.checksum) \
        .filter(model.Post.post_id != post.post_id) \
        .one_or_none()
    if other_post \
            and other_post.post_id \
            and other_post.post_id != post.post_id:
        raise PostAlreadyUploadedError(other_post)

    post.file_size = len(content)
    try:
        image = images.Image(content)
        post.canvas_width = image.width
        post.canvas_height = image.height
    except errors.ProcessingError:
        post.canvas_width = None
        post.canvas_height = None
    if (post.canvas_width is not None and post.canvas_width <= 0) \
            or (post.canvas_height is not None and post.canvas_height <= 0):
        post.canvas_width = None
        post.canvas_height = None
    setattr(post, '__content', content)


def update_post_thumbnail(
        post: model.Post, content: Optional[bytes]=None) -> None:
    assert post
    setattr(post, '__thumbnail', content)


def generate_post_thumbnail(post: model.Post) -> None:
    assert post
    if files.has(get_post_thumbnail_backup_path(post)):
        content = files.get(get_post_thumbnail_backup_path(post))
    else:
        content = files.get(get_post_content_path(post))
    try:
        assert content
        image = images.Image(content)
        image.resize_fill(
            int(config.config['thumbnails']['post_width']),
            int(config.config['thumbnails']['post_height']))
        files.save(get_post_thumbnail_path(post), image.to_jpeg())
    except errors.ProcessingError:
        files.save(get_post_thumbnail_path(post), EMPTY_PIXEL)


def update_post_tags(
        post: model.Post, tag_names: List[str]) -> List[model.Tag]:
    assert post
    existing_tags, new_tags = tags.get_or_create_tags_by_names(tag_names)
    post.tags = existing_tags + new_tags
    return new_tags


def update_post_relations(post: model.Post, new_post_ids: List[int]) -> None:
    assert post
    try:
        new_post_ids = [int(id) for id in new_post_ids]
    except ValueError:
        raise InvalidPostRelationError(
            'A relation must be numeric post ID.')
    old_posts = post.relations
    old_post_ids = [int(p.post_id) for p in old_posts]
    if new_post_ids:
        new_posts = db.session \
            .query(model.Post) \
            .filter(model.Post.post_id.in_(new_post_ids)) \
            .all()
    else:
        new_posts = []
    if len(new_posts) != len(new_post_ids):
        raise InvalidPostRelationError('One of relations does not exist.')
    if post.post_id in new_post_ids:
        raise InvalidPostRelationError('Post cannot relate to itself.')

    relations_to_del = [p for p in old_posts if p.post_id not in new_post_ids]
    relations_to_add = [p for p in new_posts if p.post_id not in old_post_ids]
    for relation in relations_to_del:
        post.relations.remove(relation)
        relation.relations.remove(post)
    for relation in relations_to_add:
        post.relations.append(relation)
        relation.relations.append(post)


def update_post_notes(post: model.Post, notes: Any) -> None:
    assert post
    post.notes = []
    for note in notes:
        for field in ('polygon', 'text'):
            if field not in note:
                raise InvalidPostNoteError('Note is missing %r field.' % field)
        if not note['text']:
            raise InvalidPostNoteError('A note\'s text cannot be empty.')
        if not isinstance(note['polygon'], (list, tuple)):
            raise InvalidPostNoteError(
                'A note\'s polygon must be a list of points.')
        if len(note['polygon']) < 3:
            raise InvalidPostNoteError(
                'A note\'s polygon must have at least 3 points.')
        for point in note['polygon']:
            if not isinstance(point, (list, tuple)):
                raise InvalidPostNoteError(
                    'A note\'s polygon point must be a list of length 2.')
            if len(point) != 2:
                raise InvalidPostNoteError(
                    'A point in note\'s polygon must have two coordinates.')
            try:
                pos_x = float(point[0])
                pos_y = float(point[1])
                if not 0 <= pos_x <= 1 or not 0 <= pos_y <= 1:
                    raise InvalidPostNoteError(
                        'All points must fit in the image (0..1 range).')
            except ValueError:
                raise InvalidPostNoteError(
                    'A point in note\'s polygon must be numeric.')
        if util.value_exceeds_column_size(note['text'], model.PostNote.text):
            raise InvalidPostNoteError('Note text is too long.')
        post.notes.append(
            model.PostNote(polygon=note['polygon'], text=str(note['text'])))


def update_post_flags(post: model.Post, flags: List[str]) -> None:
    assert post
    target_flags = []
    for flag in flags:
        flag = util.flip(FLAG_MAP).get(flag, None)
        if not flag:
            raise InvalidPostFlagError(
                'Flag must be one of %r.' % list(FLAG_MAP.values()))
        target_flags.append(flag)
    post.flags = target_flags


def feature_post(post: model.Post, user: Optional[model.User]) -> None:
    assert post
    post_feature = model.PostFeature()
    post_feature.time = datetime.utcnow()
    post_feature.post = post
    post_feature.user = user
    db.session.add(post_feature)


def delete(post: model.Post) -> None:
    assert post
    db.session.delete(post)


def merge_posts(
        source_post: model.Post,
        target_post: model.Post,
        replace_content: bool) -> None:
    assert source_post
    assert target_post
    if source_post.post_id == target_post.post_id:
        raise InvalidPostRelationError('Cannot merge post with itself.')

    def merge_tables(
            table: model.Base,
            anti_dup_func: Optional[Callable[[model.Base, model.Base], bool]],
            source_post_id: int,
            target_post_id: int) -> None:
        alias1 = table
        alias2 = sa.orm.util.aliased(table)
        update_stmt = (
            sa.sql.expression.update(alias1)
            .where(alias1.post_id == source_post_id))

        if anti_dup_func is not None:
            update_stmt = (
                update_stmt
                .where(
                    ~sa.exists()
                    .where(anti_dup_func(alias1, alias2))
                    .where(alias2.post_id == target_post_id)))

        update_stmt = update_stmt.values(post_id=target_post_id)
        db.session.execute(update_stmt)

    def merge_tags(source_post_id: int, target_post_id: int) -> None:
        merge_tables(
            model.PostTag,
            lambda alias1, alias2: alias1.tag_id == alias2.tag_id,
            source_post_id,
            target_post_id)

    def merge_scores(source_post_id: int, target_post_id: int) -> None:
        merge_tables(
            model.PostScore,
            lambda alias1, alias2: alias1.user_id == alias2.user_id,
            source_post_id,
            target_post_id)

    def merge_favorites(source_post_id: int, target_post_id: int) -> None:
        merge_tables(
            model.PostFavorite,
            lambda alias1, alias2: alias1.user_id == alias2.user_id,
            source_post_id,
            target_post_id)

    def merge_comments(source_post_id: int, target_post_id: int) -> None:
        merge_tables(model.Comment, None, source_post_id, target_post_id)

    def merge_relations(source_post_id: int, target_post_id: int) -> None:
        alias1 = model.PostRelation
        alias2 = sa.orm.util.aliased(model.PostRelation)
        update_stmt = (
            sa.sql.expression.update(alias1)
            .where(alias1.parent_id == source_post_id)
            .where(alias1.child_id != target_post_id)
            .where(
                ~sa.exists()
                .where(alias2.child_id == alias1.child_id)
                .where(alias2.parent_id == target_post_id))
            .values(parent_id=target_post_id))
        db.session.execute(update_stmt)

        update_stmt = (
            sa.sql.expression.update(alias1)
            .where(alias1.child_id == source_post_id)
            .where(alias1.parent_id != target_post_id)
            .where(
                ~sa.exists()
                .where(alias2.parent_id == alias1.parent_id)
                .where(alias2.child_id == target_post_id))
            .values(child_id=target_post_id))
        db.session.execute(update_stmt)

    merge_tags(source_post.post_id, target_post.post_id)
    merge_comments(source_post.post_id, target_post.post_id)
    merge_scores(source_post.post_id, target_post.post_id)
    merge_favorites(source_post.post_id, target_post.post_id)
    merge_relations(source_post.post_id, target_post.post_id)

    delete(source_post)

    db.session.flush()

    if replace_content:
        content = files.get(get_post_content_path(source_post))
        update_post_content(target_post, content)


def search_by_image_exact(image_content: bytes) -> Optional[model.Post]:
    checksum = util.get_sha1(image_content)
    return db.session \
        .query(model.Post) \
        .filter(model.Post.checksum == checksum) \
        .one_or_none()


def search_by_image(image_content: bytes) -> List[PostLookalike]:
    ret = []
    for result in image_hash.search_by_image(image_content):
        ret.append(PostLookalike(
            score=result.score,
            distance=result.distance,
            post=get_post_by_id(result.path)))
    return ret


def populate_reverse_search() -> None:
    excluded_post_ids = image_hash.get_all_paths()

    post_ids_to_hash = (
        db.session
        .query(model.Post.post_id)
        .filter(
            (model.Post.type == model.Post.TYPE_IMAGE) |
            (model.Post.type == model.Post.TYPE_ANIMATION))
        .filter(~model.Post.post_id.in_(excluded_post_ids))
        .order_by(model.Post.post_id.asc())
        .all())

    for post_ids_chunk in util.chunks(post_ids_to_hash, 100):
        posts_chunk = (
            db.session
            .query(model.Post)
            .filter(model.Post.post_id.in_(post_ids_chunk))
            .all())
        for post in posts_chunk:
            content_path = get_post_content_path(post)
            if files.has(content_path):
                image_hash.add_image(post.post_id, files.get(content_path))