338 lines
9.6 KiB
Python
338 lines
9.6 KiB
Python
import re
|
|
from datetime import datetime
|
|
from typing import Any, Callable, Dict, List, Optional, Tuple
|
|
|
|
import sqlalchemy as sa
|
|
|
|
from szurubooru import config, db, errors, model, rest
|
|
from szurubooru.func import pool_categories, posts, serialization, util
|
|
|
|
|
|
class PoolNotFoundError(errors.NotFoundError):
    """Raised when looking a pool up by id or by name finds nothing."""
|
|
|
|
|
|
class PoolAlreadyExistsError(errors.ValidationError):
    """Raised when a requested pool name is already used by another pool."""
|
|
|
|
|
|
class PoolIsInUseError(errors.ValidationError):
    """Validation error for a pool that is still in use.

    NOTE(review): nothing in this module raises it; confirm against callers.
    """
|
|
|
|
|
|
class InvalidPoolNameError(errors.ValidationError):
    """Raised when pool names are missing, too long or fail the name regex."""
|
|
|
|
|
|
class InvalidPoolDuplicateError(errors.ValidationError):
    """Raised when the same post id is listed more than once for a pool."""
|
|
|
|
|
|
class InvalidPoolCategoryError(errors.ValidationError):
    """Validation error for an unacceptable pool category.

    NOTE(review): nothing in this module raises it; confirm against callers.
    """
|
|
|
|
|
|
class InvalidPoolDescriptionError(errors.ValidationError):
    """Raised when a pool description is rejected as too long."""
|
|
|
|
|
|
class InvalidPoolRelationError(errors.ValidationError):
    """Raised when a pool is asked to be merged with itself."""
|
|
|
|
|
|
class InvalidPoolNonexistentPostError(errors.ValidationError):
    """Raised when a pool is given post ids that could not be fetched."""
|
|
|
|
|
|
def _verify_name_validity(name: str) -> None:
    """Validate a single pool name.

    Raises:
        InvalidPoolNameError: if ``util.value_exceeds_column_size`` flags
            ``name`` against ``model.PoolName.name``, or if ``name`` does
            not match the configured ``pool_name_regex``.
    """
    too_long = util.value_exceeds_column_size(name, model.PoolName.name)
    if too_long:
        raise InvalidPoolNameError("Name is too long.")

    pattern = config.config["pool_name_regex"]
    if re.match(pattern, name) is None:
        raise InvalidPoolNameError("Name must satisfy regex %r." % pattern)
|
|
|
|
|
|
def _get_names(pool: model.Pool) -> List[str]:
    """Return the ``name`` of every entry in ``pool.names``, in order."""
    assert pool
    collected = []
    for entry in pool.names:
        collected.append(entry.name)
    return collected
|
|
|
|
|
|
def _lower_list(names: List[str]) -> List[str]:
    """Return a new list holding the lower-cased form of each name."""
    lowered = []
    for name in names:
        lowered.append(name.lower())
    return lowered
|
|
|
|
|
|
def _check_name_intersection(
    names1: List[str], names2: List[str], case_sensitive: bool
) -> bool:
    """Tell whether the two name lists have at least one name in common.

    When ``case_sensitive`` is false, both lists are lower-cased before
    they are compared.
    """
    if case_sensitive:
        left, right = set(names1), names2
    else:
        left = {name.lower() for name in names1}
        right = [name.lower() for name in names2]
    return not left.isdisjoint(right)
|
|
|
|
|
|
def _duplicates(a: List[int]) -> List[int]:
    """Collect every repeated occurrence of a value in ``a``.

    The first occurrence of a value is skipped; each later occurrence is
    appended to the result in input order, so a value present three times
    is reported twice.
    """
    seen = set()
    repeated = []
    for value in a:
        if value in seen:
            repeated.append(value)
            continue
        seen.add(value)
    return repeated
|
|
|
|
|
|
def sort_pools(pools: List[model.Pool]) -> List[model.Pool]:
    """Return a sorted copy of ``pools``.

    Pools whose category name differs from the default category name sort
    first; ties are broken by category name, then by the pool's first name.
    """
    default_category_name = pool_categories.get_default_category_name()

    def sort_key(pool: model.Pool) -> Tuple[bool, str, str]:
        category_name = pool.category.name
        return (
            category_name == default_category_name,
            category_name,
            pool.names[0].name,
        )

    return sorted(pools, key=sort_key)
|
|
|
|
|
|
class PoolSerializer(serialization.BaseSerializer):
    """Serializer that exposes the fields of a single pool."""

    def __init__(self, pool: model.Pool) -> None:
        self.pool = pool

    def _serializers(self) -> Dict[str, Callable[[], Any]]:
        # Field name -> bound method producing that field's value.
        return dict(
            id=self.serialize_id,
            names=self.serialize_names,
            category=self.serialize_category,
            version=self.serialize_version,
            description=self.serialize_description,
            creationTime=self.serialize_creation_time,
            lastEditTime=self.serialize_last_edit_time,
            postCount=self.serialize_post_count,
            posts=self.serialize_posts,
        )

    def serialize_id(self) -> Any:
        return self.pool.pool_id

    def serialize_names(self) -> Any:
        names = []
        for pool_name in self.pool.names:
            names.append(pool_name.name)
        return names

    def serialize_category(self) -> Any:
        return self.pool.category.name

    def serialize_version(self) -> Any:
        return self.pool.version

    def serialize_description(self) -> Any:
        return self.pool.description

    def serialize_creation_time(self) -> Any:
        return self.pool.creation_time

    def serialize_last_edit_time(self) -> Any:
        return self.pool.last_edit_time

    def serialize_post_count(self) -> Any:
        return self.pool.post_count

    def serialize_posts(self) -> Any:
        # Each entry of ``pool.posts`` is serialized as a micro post; the
        # second argument to the helper is always ``None``.
        return [
            posts.serialize_micro_post(rel, None) for rel in self.pool.posts
        ]
|
|
|
|
|
|
def serialize_pool(
    pool: model.Pool, options: Optional[List[str]] = None
) -> Optional[rest.Response]:
    """Serialize ``pool`` through ``PoolSerializer``.

    Args:
        pool: the pool to serialize; a falsy value short-circuits.
        options: names of the fields to include. Omitting it (or passing
            ``None``) forwards an empty list to the serializer, exactly
            as the former ``[]`` default did.

    Returns:
        ``None`` when ``pool`` is falsy, otherwise whatever
        ``PoolSerializer(pool).serialize`` returns.
    """
    if not pool:
        return None
    # Build a fresh list per call: a ``[]`` default is a single object
    # shared by every call, so any mutation downstream would leak into
    # later calls.
    if options is None:
        options = []
    return PoolSerializer(pool).serialize(options)
|
|
|
|
|
|
def serialize_micro_pool(pool: model.Pool) -> Optional[rest.Response]:
    """Serialize ``pool`` restricted to a fixed, reduced set of fields."""
    micro_fields = ["id", "names", "category", "description", "postCount"]
    return serialize_pool(pool, options=micro_fields)
|
|
|
|
|
|
def try_get_pool_by_id(pool_id: int) -> Optional[model.Pool]:
    """Query the pool whose ``pool_id`` matches; ``None`` if there is none."""
    query = db.session.query(model.Pool)
    query = query.filter(model.Pool.pool_id == pool_id)
    return query.one_or_none()
|
|
|
|
|
|
def get_pool_by_id(pool_id: int) -> model.Pool:
    """Return the pool with the given id.

    Raises:
        PoolNotFoundError: if ``try_get_pool_by_id`` finds nothing.
    """
    pool = try_get_pool_by_id(pool_id)
    if pool:
        return pool
    raise PoolNotFoundError("Pool %r not found." % pool_id)
|
|
|
|
|
|
def try_get_pool_by_name(name: str) -> Optional[model.Pool]:
    """Query the pool owning ``name`` (compared lower-cased), or ``None``."""
    wanted = name.lower()
    query = db.session.query(model.Pool).join(model.PoolName)
    query = query.filter(sa.func.lower(model.PoolName.name) == wanted)
    return query.one_or_none()
|
|
|
|
|
|
def get_pool_by_name(name: str) -> model.Pool:
    """Return the pool that owns ``name``.

    Raises:
        PoolNotFoundError: if ``try_get_pool_by_name`` finds nothing.
    """
    pool = try_get_pool_by_name(name)
    if pool:
        return pool
    raise PoolNotFoundError("Pool %r not found." % name)
|
|
|
|
|
|
def get_pools_by_names(names: List[str]) -> List[model.Pool]:
    """Fetch every pool owning one of ``names`` (compared lower-cased).

    ``names`` is first passed through ``util.icase_unique``; when nothing
    is left, ``[]`` is returned without issuing a query.
    """
    names = util.icase_unique(names)
    if not names:
        return []

    # One ``lower(name) == <wanted>`` test per requested name, OR-ed.
    any_name_matches = sa.sql.or_(
        sa.func.lower(model.PoolName.name) == name.lower() for name in names
    )
    query = db.session.query(model.Pool).join(model.PoolName)
    return query.filter(any_name_matches).all()
|
|
|
|
|
|
def get_or_create_pools_by_names(
    names: List[str],
) -> Tuple[List[model.Pool], List[model.Pool]]:
    """Resolve ``names`` to pools, creating a pool for each unknown name.

    ``names`` is de-duplicated with ``util.icase_unique``. Every name that
    no fetched pool owns (compared lower-cased) gets a new single-name
    pool in the default category with no posts; that pool is added to
    ``db.session``.

    Returns:
        ``(existing_pools, new_pools)``.
    """
    names = util.icase_unique(names)
    existing_pools = get_pools_by_names(names)

    # Lower-cased names of all fetched pools, built once so that each
    # requested name costs a set lookup rather than a scan of every pool.
    known_names = {
        pool_name.name.lower()
        for existing_pool in existing_pools
        for pool_name in existing_pool.names
    }

    pool_category_name = pool_categories.get_default_category_name()
    new_pools = []
    for name in names:
        if name.lower() in known_names:
            continue
        new_pool = create_pool(
            names=[name], category_name=pool_category_name, post_ids=[]
        )
        db.session.add(new_pool)
        new_pools.append(new_pool)
    return existing_pools, new_pools
|
|
|
|
|
|
def delete(source_pool: model.Pool) -> None:
    """Hand ``source_pool`` to ``db.session.delete``."""
    assert source_pool
    session = db.session
    session.delete(source_pool)
|
|
|
|
|
|
def merge_pools(source_pool: model.Pool, target_pool: model.Pool) -> None:
    """Merge ``source_pool`` into ``target_pool`` and delete the source.

    Raises:
        InvalidPoolRelationError: if both pools have the same ``pool_id``.
    """
    assert source_pool
    assert target_pool
    source_id = source_pool.pool_id
    target_id = target_pool.pool_id
    if source_id == target_id:
        raise InvalidPoolRelationError("Cannot merge pool with itself.")

    # Re-point the source pool's PoolPost rows at the target pool, leaving
    # alone any row whose post already has a row in the target pool.
    pool_post = model.PoolPost
    other_pool_post = sa.orm.util.aliased(model.PoolPost)
    already_in_target = (
        sa.exists()
        .where(pool_post.post_id == other_pool_post.post_id)
        .where(other_pool_post.pool_id == target_id)
    )
    statement = (
        sa.sql.expression.update(pool_post)
        .where(pool_post.pool_id == source_id)
        .where(~already_in_target)
        .values(pool_id=target_id)
    )
    db.session.execute(statement)

    delete(source_pool)
|
|
|
|
|
|
def create_pool(
    names: List[str], category_name: str, post_ids: List[int]
) -> model.Pool:
    """Build a new pool from names, a category name and post ids.

    The creation time is stamped with ``datetime.utcnow()``; names,
    category and posts go through the matching ``update_pool_*`` helper,
    so their validation errors propagate. This function does not add the
    pool to the session itself.
    """
    new_pool = model.Pool()
    new_pool.creation_time = datetime.utcnow()

    update_pool_names(new_pool, names)
    update_pool_category_name(new_pool, category_name)
    update_pool_posts(new_pool, post_ids)

    return new_pool
|
|
|
|
|
|
def update_pool_category_name(pool: model.Pool, category_name: str) -> None:
    """Assign to ``pool`` the category looked up by ``category_name``."""
    assert pool
    category = pool_categories.get_category_by_name(category_name)
    pool.category = category
|
|
|
|
|
|
def update_pool_names(pool: model.Pool, names: List[str]) -> None:
    """Replace the names of ``pool`` with ``names``, keeping request order.

    Empty names are dropped and the rest is de-duplicated with
    ``util.icase_unique`` before anything else happens.

    Raises:
        InvalidPoolNameError: if no name is left, or a name fails
            ``_verify_name_validity``.
        PoolAlreadyExistsError: if a name (compared lower-cased) belongs
            to a ``PoolName`` row of a different pool.
    """
    # sanitize
    assert pool
    names = util.icase_unique([name for name in names if name])
    if not names:
        raise InvalidPoolNameError("At least one name must be specified.")
    for candidate in names:
        _verify_name_validity(candidate)

    # check for existing pools
    name_taken = sa.sql.false()
    for candidate in names:
        name_taken = name_taken | (
            sa.func.lower(model.PoolName.name) == candidate.lower()
        )
    if pool.pool_id:
        # Names already held by this very pool do not count as clashes.
        name_taken = name_taken & (model.PoolName.pool_id != pool.pool_id)
    clashes = db.session.query(model.PoolName).filter(name_taken).all()
    if clashes:
        raise PoolAlreadyExistsError(
            "One of names is already used by another pool."
        )

    # remove unwanted items (exact, case-sensitive comparison); iterate
    # over a copy because the collection is modified in the loop
    for pool_name in list(pool.names):
        if pool_name.name not in names:
            pool.names.remove(pool_name)

    # add wanted items (exact, case-sensitive comparison)
    for candidate in names:
        if candidate not in [pool_name.name for pool_name in pool.names]:
            pool.names.append(model.PoolName(candidate, -1))

    # set alias order to match the request
    for position, candidate in enumerate(names):
        wanted = candidate.lower()
        for pool_name in pool.names:
            if pool_name.name.lower() == wanted:
                pool_name.order = position
|
|
|
|
|
|
def update_pool_description(pool: model.Pool, description: str) -> None:
    """Store ``description`` on ``pool``; a falsy value is stored as ``None``.

    Raises:
        InvalidPoolDescriptionError: if ``util.value_exceeds_column_size``
            flags ``description`` against ``model.Pool.description``.
    """
    assert pool
    too_long = util.value_exceeds_column_size(
        description, model.Pool.description
    )
    if too_long:
        raise InvalidPoolDescriptionError("Description is too long.")
    pool.description = description if description else None
|
|
|
|
|
|
def update_pool_posts(pool: model.Pool, post_ids: List[int]) -> None:
    """Replace the posts of ``pool`` with the posts named by ``post_ids``.

    Raises:
        InvalidPoolDuplicateError: if ``post_ids`` repeats an id.
        InvalidPoolNonexistentPostError: if fewer posts were fetched than
            ids were requested; the message lists the ids not fetched.
    """
    assert pool

    repeated_ids = _duplicates(post_ids)
    if repeated_ids:
        repeated_text = ", ".join(str(post_id) for post_id in repeated_ids)
        raise InvalidPoolDuplicateError(
            "Duplicate post(s) in pool: " + repeated_text
        )

    found_posts = posts.get_posts_by_ids(post_ids)
    if len(found_posts) != len(post_ids):
        found_ids = {post.post_id for post in found_posts}
        missing_ids = set(post_ids) - found_ids
        missing_text = ", ".join(str(post_id) for post_id in missing_ids)
        raise InvalidPoolNonexistentPostError(
            "The following posts do not exist: " + missing_text
        )

    # Validation passed: drop the old membership and append the fetched
    # posts one by one, in the order they were returned.
    pool.posts.clear()
    for post in found_posts:
        pool.posts.append(post)
|