gallery.accords-library.com/server/szurubooru/func/pools.py

338 lines
9.6 KiB
Python
Raw Normal View History

2020-05-04 02:53:28 +00:00
import re
from datetime import datetime
from typing import Any, Callable, Dict, List, Optional, Tuple
2020-05-04 02:53:28 +00:00
import sqlalchemy as sa
from szurubooru import config, db, errors, model, rest
from szurubooru.func import pool_categories, posts, serialization, util
2020-05-04 02:53:28 +00:00
class PoolNotFoundError(errors.NotFoundError):
    """Raised when a pool lookup by ID or by name finds no pool."""
class PoolAlreadyExistsError(errors.ValidationError):
    """Raised when a requested pool name is already used by another pool."""
class PoolIsInUseError(errors.ValidationError):
    """Validation error for a pool that is in use.

    NOTE(review): not raised by any code visible in this module.
    """
class InvalidPoolNameError(errors.ValidationError):
    """Raised when a pool name is missing, too long, or fails the regex."""
2020-05-04 07:09:33 +00:00
class InvalidPoolDuplicateError(errors.ValidationError):
    """Raised when the same post ID is listed more than once for a pool."""
class InvalidPoolCategoryError(errors.ValidationError):
    """Validation error for an invalid pool category.

    NOTE(review): not raised by any code visible in this module.
    """
class InvalidPoolDescriptionError(errors.ValidationError):
    """Raised when a pool description exceeds its column size."""
2020-05-04 22:15:30 +00:00
class InvalidPoolRelationError(errors.ValidationError):
    """Raised when a pool is asked to merge with itself."""
2020-05-05 02:12:54 +00:00
class InvalidPoolNonexistentPostError(errors.ValidationError):
    """Raised when a pool references post IDs that cannot be found."""
2020-05-04 02:53:28 +00:00
def _verify_name_validity(name: str) -> None:
    """Check that one pool name fits its column and the configured regex.

    Raises:
        InvalidPoolNameError: if ``name`` exceeds the size of the
            ``PoolName.name`` column, or does not match the
            ``pool_name_regex`` configuration entry.
    """
    too_long = util.value_exceeds_column_size(name, model.PoolName.name)
    if too_long:
        raise InvalidPoolNameError("Name is too long.")

    pattern = config.config["pool_name_regex"]
    if re.match(pattern, name) is None:
        raise InvalidPoolNameError("Name must satisfy regex %r." % pattern)
2020-05-04 02:53:28 +00:00
def _get_names(pool: model.Pool) -> List[str]:
    """Return the name string of every entry in ``pool.names``."""
    assert pool
    collected = []
    for alias in pool.names:
        collected.append(alias.name)
    return collected
def _lower_list(names: List[str]) -> List[str]:
return [name.lower() for name in names]
def _check_name_intersection(
names1: List[str], names2: List[str], case_sensitive: bool
) -> bool:
2020-05-04 02:53:28 +00:00
if not case_sensitive:
names1 = _lower_list(names1)
names2 = _lower_list(names2)
return len(set(names1).intersection(names2)) > 0
2020-05-05 02:12:54 +00:00
def _duplicates(a: List[int]) -> List[int]:
seen = set()
dupes = []
for x in a:
if x not in seen:
seen.add(x)
else:
dupes.append(x)
return dupes
2020-05-04 07:09:33 +00:00
2020-06-03 00:43:18 +00:00
2020-05-04 02:53:28 +00:00
def sort_pools(pools: List[model.Pool]) -> List[model.Pool]:
    """Return a new list of pools ordered by category, then first name.

    Pools whose category is the default category sort after all others,
    because the first key component is ``True`` for them.
    """
    default_category_name = pool_categories.get_default_category_name()

    def ordering(pool: model.Pool) -> Tuple[bool, str, str]:
        category = pool.category.name
        return (
            default_category_name == category,
            category,
            pool.names[0].name,
        )

    return sorted(pools, key=ordering)
class PoolSerializer(serialization.BaseSerializer):
    """Serializer that exposes the fields of a single pool."""

    def __init__(self, pool: model.Pool) -> None:
        self.pool = pool

    def _serializers(self) -> Dict[str, Callable[[], Any]]:
        """Map every serializable field name to the method producing it."""
        field_methods = {
            "id": self.serialize_id,
            "names": self.serialize_names,
            "category": self.serialize_category,
            "version": self.serialize_version,
            "description": self.serialize_description,
            "creationTime": self.serialize_creation_time,
            "lastEditTime": self.serialize_last_edit_time,
            "postCount": self.serialize_post_count,
            "posts": self.serialize_posts,
        }
        return field_methods

    def serialize_id(self) -> Any:
        """Return the pool's ``pool_id``."""
        return self.pool.pool_id

    def serialize_names(self) -> Any:
        """Return the name strings of the pool, in ``pool.names`` order."""
        names = []
        for alias in self.pool.names:
            names.append(alias.name)
        return names

    def serialize_category(self) -> Any:
        """Return the name of the pool's category."""
        category = self.pool.category
        return category.name

    def serialize_version(self) -> Any:
        """Return the pool's ``version``."""
        return self.pool.version

    def serialize_description(self) -> Any:
        """Return the pool's ``description``."""
        return self.pool.description

    def serialize_creation_time(self) -> Any:
        """Return the pool's ``creation_time``."""
        return self.pool.creation_time

    def serialize_last_edit_time(self) -> Any:
        """Return the pool's ``last_edit_time``."""
        return self.pool.last_edit_time

    def serialize_post_count(self) -> Any:
        """Return the pool's ``post_count``."""
        return self.pool.post_count

    def serialize_posts(self) -> Any:
        """Return each entry of ``pool.posts`` as a micro post.

        Every entry is passed to ``posts.serialize_micro_post`` with
        ``None`` as the second argument.
        """
        return [
            posts.serialize_micro_post(member, None)
            for member in self.pool.posts
        ]
2020-05-04 07:09:33 +00:00
2020-05-04 02:53:28 +00:00
def serialize_pool(
    pool: model.Pool, options: Optional[List[str]] = None
) -> Optional[rest.Response]:
    """Serialize ``pool`` with ``PoolSerializer``.

    Args:
        pool: the pool to serialize; a falsy value yields ``None``.
        options: field names handed to ``PoolSerializer.serialize``.
            When omitted, an empty list is passed, matching the
            previous default.

    Returns:
        The serializer's output, or ``None`` when ``pool`` is falsy.
    """
    if not pool:
        return None
    # Build a fresh list per call rather than sharing one mutable default
    # object between every invocation of this function.
    requested = [] if options is None else options
    return PoolSerializer(pool).serialize(requested)
def serialize_micro_pool(pool: model.Pool) -> Optional[rest.Response]:
    """Serialize ``pool`` restricted to a fixed, reduced set of fields."""
    micro_fields = ["id", "names", "category", "description", "postCount"]
    return serialize_pool(pool, options=micro_fields)
2020-05-04 02:53:28 +00:00
def try_get_pool_by_id(pool_id: int) -> Optional[model.Pool]:
    """Query the pool whose ``pool_id`` equals the argument.

    Returns the result of ``one_or_none()`` on that query.
    """
    all_pools = db.session.query(model.Pool)
    matching = all_pools.filter(model.Pool.pool_id == pool_id)
    return matching.one_or_none()
2020-05-04 02:53:28 +00:00
def get_pool_by_id(pool_id: int) -> model.Pool:
    """Return the pool with the given ID.

    Raises:
        PoolNotFoundError: if ``try_get_pool_by_id`` returns a falsy
            result.
    """
    found = try_get_pool_by_id(pool_id)
    if found:
        return found
    raise PoolNotFoundError("Pool %r not found." % pool_id)
def try_get_pool_by_name(name: str) -> Optional[model.Pool]:
    """Query the pool that has a name equal to ``name``, ignoring case.

    Both sides of the comparison are lowercased. Returns the result of
    ``one_or_none()`` on that query.
    """
    wanted = name.lower()
    pools_with_names = db.session.query(model.Pool).join(model.PoolName)
    matching = pools_with_names.filter(
        sa.func.lower(model.PoolName.name) == wanted
    )
    return matching.one_or_none()
2020-05-04 02:53:28 +00:00
def get_pool_by_name(name: str) -> model.Pool:
    """Return the pool carrying the given name.

    Raises:
        PoolNotFoundError: if ``try_get_pool_by_name`` returns a falsy
            result.
    """
    found = try_get_pool_by_name(name)
    if found:
        return found
    raise PoolNotFoundError("Pool %r not found." % name)
def get_pools_by_names(names: List[str]) -> List[model.Pool]:
    """Return all pools that have at least one of ``names``, ignoring case.

    The names are first passed through ``util.icase_unique``. An empty
    result from that step returns an empty list without querying.
    """
    names = util.icase_unique(names)
    if not names:
        return []

    # One lowercase equality test per requested name, OR-ed together.
    name_matches = [
        sa.func.lower(model.PoolName.name) == wanted.lower()
        for wanted in names
    ]
    query = db.session.query(model.Pool).join(model.PoolName)
    return query.filter(sa.sql.or_(*name_matches)).all()
2020-05-04 02:53:28 +00:00
def get_or_create_pools_by_names(
    names: List[str],
) -> Tuple[List[model.Pool], List[model.Pool]]:
    """Fetch the pools matching ``names`` and create the missing ones.

    A name counts as matched when it equals, ignoring case, any name of
    a pool returned by ``get_pools_by_names``. Each unmatched name gets
    a new pool in the default category with no posts, which is added to
    the database session.

    Returns:
        A ``(existing_pools, new_pools)`` tuple.
    """
    names = util.icase_unique(names)
    existing_pools = get_pools_by_names(names)
    new_pools = []
    default_category = pool_categories.get_default_category_name()

    for name in names:
        already_known = any(
            _check_name_intersection(_get_names(candidate), [name], False)
            for candidate in existing_pools
        )
        if already_known:
            continue
        created = create_pool(
            names=[name], category_name=default_category, post_ids=[]
        )
        db.session.add(created)
        new_pools.append(created)

    return existing_pools, new_pools
def delete(source_pool: model.Pool) -> None:
    """Hand ``source_pool`` to ``db.session.delete``."""
    assert source_pool
    session = db.session
    session.delete(source_pool)
def merge_pools(source_pool: model.Pool, target_pool: model.Pool) -> None:
    """Move the posts of ``source_pool`` into ``target_pool``, then delete it.

    Raises:
        InvalidPoolRelationError: if both arguments have the same
            ``pool_id``.
    """
    assert source_pool
    assert target_pool
    source_id, target_id = source_pool.pool_id, target_pool.pool_id
    if source_id == target_id:
        raise InvalidPoolRelationError("Cannot merge pool with itself.")

    # Re-point the source pool's PoolPost rows at the target pool, leaving
    # alone any row whose post already has a PoolPost row in the target.
    link = model.PoolPost
    target_link = sa.orm.util.aliased(model.PoolPost)
    already_in_target = (
        sa.exists()
        .where(link.post_id == target_link.post_id)
        .where(target_link.pool_id == target_id)
    )
    move_links = (
        sa.sql.expression.update(link)
        .where(link.pool_id == source_id)
        .where(~already_in_target)
        .values(pool_id=target_id)
    )
    db.session.execute(move_links)

    delete(source_pool)
def create_pool(
    names: List[str], category_name: str, post_ids: List[int]
) -> model.Pool:
    """Build a new pool with the given names, category and posts.

    The creation time is set to the current UTC time. Names, category
    and posts are applied through the ``update_pool_*`` functions, so
    their validation errors propagate from here. The pool is returned
    without being added to the database session.
    """
    new_pool = model.Pool()
    new_pool.creation_time = datetime.utcnow()

    update_pool_names(new_pool, names)
    update_pool_category_name(new_pool, category_name)
    update_pool_posts(new_pool, post_ids)

    return new_pool
def update_pool_category_name(pool: model.Pool, category_name: str) -> None:
    """Assign to ``pool`` the category looked up by ``category_name``."""
    assert pool
    category = pool_categories.get_category_by_name(category_name)
    pool.category = category
def update_pool_names(pool: model.Pool, names: List[str]) -> None:
    """Replace the names of ``pool`` with ``names``, keeping request order.

    Empty entries are dropped and the rest is passed through
    ``util.icase_unique`` before validation.

    Raises:
        InvalidPoolNameError: if no name remains after sanitizing, or a
            name fails ``_verify_name_validity``.
        PoolAlreadyExistsError: if any name, compared in lowercase, is
            already held by a different pool.
    """
    assert pool

    # Sanitize the request.
    names = util.icase_unique([name for name in names if name])
    if not names:
        raise InvalidPoolNameError("At least one name must be specified.")
    for name in names:
        _verify_name_validity(name)

    # Look for the requested names on pools other than this one.
    clash_filter = sa.sql.false()
    for name in names:
        clash_filter = clash_filter | (
            sa.func.lower(model.PoolName.name) == name.lower()
        )
    if pool.pool_id:
        clash_filter = clash_filter & (model.PoolName.pool_id != pool.pool_id)
    clashes = db.session.query(model.PoolName).filter(clash_filter).all()
    if clashes:
        raise PoolAlreadyExistsError(
            "One of names is already used by another pool."
        )

    # Drop current names that were not requested (exact-case comparison).
    # Iterate over a copy because the collection is mutated in the loop.
    wanted = set(names)
    for current in list(pool.names):
        if current.name not in wanted:
            pool.names.remove(current)

    # Attach requested names the pool does not have yet (exact-case).
    present = {current.name for current in pool.names}
    for name in names:
        if name not in present:
            pool.names.append(model.PoolName(name, -1))
            present.add(name)

    # Make each name's order follow its position in the request
    # (case-insensitive match).
    position_by_lowered = {name.lower(): i for i, name in enumerate(names)}
    for current in pool.names:
        lowered = current.name.lower()
        if lowered in position_by_lowered:
            current.order = position_by_lowered[lowered]
def update_pool_description(pool: model.Pool, description: str) -> None:
    """Set the description of ``pool``; a falsy value is stored as ``None``.

    Raises:
        InvalidPoolDescriptionError: if ``description`` exceeds the size
            of the ``Pool.description`` column.
    """
    assert pool
    too_long = util.value_exceeds_column_size(
        description, model.Pool.description
    )
    if too_long:
        raise InvalidPoolDescriptionError("Description is too long.")
    pool.description = description if description else None
2020-05-04 07:09:33 +00:00
def update_pool_posts(pool: model.Pool, post_ids: List[int]) -> None:
    """Replace the posts of ``pool`` with the posts named by ``post_ids``.

    Args:
        pool: the pool to modify.
        post_ids: IDs of the posts the pool should contain.

    Raises:
        InvalidPoolDuplicateError: if ``post_ids`` lists an ID more than
            once.
        InvalidPoolNonexistentPostError: if ``posts.get_posts_by_ids``
            returns fewer posts than IDs were requested.
    """
    assert pool

    dupes = _duplicates(post_ids)
    if dupes:
        raise InvalidPoolDuplicateError(
            "Duplicate post(s) in pool: " + ", ".join(str(x) for x in dupes)
        )

    found_posts = posts.get_posts_by_ids(post_ids)
    if len(post_ids) != len(found_posts):
        found_ids = {post.post_id for post in found_posts}
        # Sorted so the message does not depend on set iteration order.
        missing = sorted(set(post_ids) - found_ids)
        raise InvalidPoolNonexistentPostError(
            "The following posts do not exist: "
            + ", ".join(str(x) for x in missing)
        )

    # The pool is only touched once every ID has been validated.
    pool.posts.clear()
    for post in found_posts:
        pool.posts.append(post)