import re from datetime import datetime from typing import Any, Callable, Dict, List, Optional, Tuple import sqlalchemy as sa from szurubooru import config, db, errors, model, rest from szurubooru.func import pool_categories, posts, serialization, util class PoolNotFoundError(errors.NotFoundError): pass class PoolAlreadyExistsError(errors.ValidationError): pass class PoolIsInUseError(errors.ValidationError): pass class InvalidPoolNameError(errors.ValidationError): pass class InvalidPoolDuplicateError(errors.ValidationError): pass class InvalidPoolCategoryError(errors.ValidationError): pass class InvalidPoolDescriptionError(errors.ValidationError): pass class InvalidPoolRelationError(errors.ValidationError): pass class InvalidPoolNonexistentPostError(errors.ValidationError): pass def _verify_name_validity(name: str) -> None: if util.value_exceeds_column_size(name, model.PoolName.name): raise InvalidPoolNameError("Name is too long.") name_regex = config.config["pool_name_regex"] if not re.match(name_regex, name): raise InvalidPoolNameError("Name must satisfy regex %r." % name_regex) def _get_names(pool: model.Pool) -> List[str]: assert pool return [pool_name.name for pool_name in pool.names] def _lower_list(names: List[str]) -> List[str]: return [name.lower() for name in names] def _check_name_intersection( names1: List[str], names2: List[str], case_sensitive: bool ) -> bool: if not case_sensitive: names1 = _lower_list(names1) names2 = _lower_list(names2) return len(set(names1).intersection(names2)) > 0 def _duplicates(a: List[int]) -> List[int]: seen = set() dupes = [] for x in a: if x not in seen: seen.add(x) else: dupes.append(x) return dupes def sort_pools(pools: List[model.Pool]) -> List[model.Pool]: default_category_name = pool_categories.get_default_category_name() return sorted( pools, key=lambda pool: ( default_category_name == pool.category.name, pool.category.name, pool.names[0].name, ), ) class PoolSerializer(serialization.BaseSerializer): def __init__(self, pool: model.Pool) -> None: self.pool = pool def _serializers(self) -> Dict[str, Callable[[], Any]]: return { "id": self.serialize_id, "names": self.serialize_names, "category": self.serialize_category, "version": self.serialize_version, "description": self.serialize_description, "creationTime": self.serialize_creation_time, "lastEditTime": self.serialize_last_edit_time, "postCount": self.serialize_post_count, "posts": self.serialize_posts, } def serialize_id(self) -> Any: return self.pool.pool_id def serialize_names(self) -> Any: return [pool_name.name for pool_name in self.pool.names] def serialize_category(self) -> Any: return self.pool.category.name def serialize_version(self) -> Any: return self.pool.version def serialize_description(self) -> Any: return self.pool.description def serialize_creation_time(self) -> Any: return self.pool.creation_time def serialize_last_edit_time(self) -> Any: return self.pool.last_edit_time def serialize_post_count(self) -> Any: return self.pool.post_count def serialize_posts(self) -> Any: return [ post for post in [ posts.serialize_micro_post(rel, None) for rel in self.pool.posts ] ] def serialize_pool( pool: model.Pool, options: List[str] = [] ) -> Optional[rest.Response]: if not pool: return None return PoolSerializer(pool).serialize(options) def serialize_micro_pool(pool: model.Pool) -> Optional[rest.Response]: return serialize_pool( pool, options=["id", "names", "category", "description", "postCount"] ) def try_get_pool_by_id(pool_id: int) -> Optional[model.Pool]: return ( db.session.query(model.Pool) .filter(model.Pool.pool_id == pool_id) .one_or_none() ) def get_pool_by_id(pool_id: int) -> model.Pool: pool = try_get_pool_by_id(pool_id) if not pool: raise PoolNotFoundError("Pool %r not found." % pool_id) return pool def try_get_pool_by_name(name: str) -> Optional[model.Pool]: return ( db.session.query(model.Pool) .join(model.PoolName) .filter(sa.func.lower(model.PoolName.name) == name.lower()) .one_or_none() ) def get_pool_by_name(name: str) -> model.Pool: pool = try_get_pool_by_name(name) if not pool: raise PoolNotFoundError("Pool %r not found." % name) return pool def get_pools_by_names(names: List[str]) -> List[model.Pool]: names = util.icase_unique(names) if len(names) == 0: return [] return ( db.session.query(model.Pool) .join(model.PoolName) .filter( sa.sql.or_( sa.func.lower(model.PoolName.name) == name.lower() for name in names ) ) .all() ) def get_or_create_pools_by_names( names: List[str], ) -> Tuple[List[model.Pool], List[model.Pool]]: names = util.icase_unique(names) existing_pools = get_pools_by_names(names) new_pools = [] pool_category_name = pool_categories.get_default_category_name() for name in names: found = False for existing_pool in existing_pools: if _check_name_intersection( _get_names(existing_pool), [name], False ): found = True break if not found: new_pool = create_pool( names=[name], category_name=pool_category_name, post_ids=[] ) db.session.add(new_pool) new_pools.append(new_pool) return existing_pools, new_pools def delete(source_pool: model.Pool) -> None: assert source_pool db.session.delete(source_pool) def merge_pools(source_pool: model.Pool, target_pool: model.Pool) -> None: assert source_pool assert target_pool if source_pool.pool_id == target_pool.pool_id: raise InvalidPoolRelationError("Cannot merge pool with itself.") def merge_pool_posts(source_pool_id: int, target_pool_id: int) -> None: alias1 = model.PoolPost alias2 = sa.orm.util.aliased(model.PoolPost) update_stmt = sa.sql.expression.update(alias1).where( alias1.pool_id == source_pool_id ) update_stmt = update_stmt.where( ~sa.exists() .where(alias1.post_id == alias2.post_id) .where(alias2.pool_id == target_pool_id) ) update_stmt = update_stmt.values(pool_id=target_pool_id) db.session.execute(update_stmt) merge_pool_posts(source_pool.pool_id, target_pool.pool_id) delete(source_pool) def create_pool( names: List[str], category_name: str, post_ids: List[int] ) -> model.Pool: pool = model.Pool() pool.creation_time = datetime.utcnow() update_pool_names(pool, names) update_pool_category_name(pool, category_name) update_pool_posts(pool, post_ids) return pool def update_pool_category_name(pool: model.Pool, category_name: str) -> None: assert pool pool.category = pool_categories.get_category_by_name(category_name) def update_pool_names(pool: model.Pool, names: List[str]) -> None: # sanitize assert pool names = util.icase_unique([name for name in names if name]) if not len(names): raise InvalidPoolNameError("At least one name must be specified.") for name in names: _verify_name_validity(name) # check for existing pools expr = sa.sql.false() for name in names: expr = expr | (sa.func.lower(model.PoolName.name) == name.lower()) if pool.pool_id: expr = expr & (model.PoolName.pool_id != pool.pool_id) existing_pools = db.session.query(model.PoolName).filter(expr).all() if len(existing_pools): raise PoolAlreadyExistsError( "One of names is already used by another pool." ) # remove unwanted items for pool_name in pool.names[:]: if not _check_name_intersection([pool_name.name], names, True): pool.names.remove(pool_name) # add wanted items for name in names: if not _check_name_intersection(_get_names(pool), [name], True): pool.names.append(model.PoolName(name, -1)) # set alias order to match the request for i, name in enumerate(names): for pool_name in pool.names: if pool_name.name.lower() == name.lower(): pool_name.order = i def update_pool_description(pool: model.Pool, description: str) -> None: assert pool if util.value_exceeds_column_size(description, model.Pool.description): raise InvalidPoolDescriptionError("Description is too long.") pool.description = description or None def update_pool_posts(pool: model.Pool, post_ids: List[int]) -> None: assert pool dupes = _duplicates(post_ids) if len(dupes) > 0: dupes = ", ".join(list(str(x) for x in dupes)) raise InvalidPoolDuplicateError("Duplicate post(s) in pool: " + dupes) ret = posts.get_posts_by_ids(post_ids) if len(post_ids) != len(ret): missing = set(post_ids) - set(post.post_id for post in ret) missing = ", ".join(list(str(x) for x in missing)) raise InvalidPoolNonexistentPostError( "The following posts do not exist: " + missing ) pool.posts.clear() for post in ret: pool.posts.append(post)