gallery.accords-library.com/server/szurubooru/func/pools.py

302 lines
8.9 KiB
Python
Raw Normal View History

2020-05-04 02:53:28 +00:00
import re
from typing import Any, Optional, Tuple, List, Dict, Callable
from datetime import datetime
import sqlalchemy as sa
from szurubooru import config, db, model, errors, rest
2020-05-04 07:09:33 +00:00
from szurubooru.func import util, pool_categories, serialization, posts
2020-05-04 02:53:28 +00:00
class PoolNotFoundError(errors.NotFoundError):
    """Raised when a pool lookup by id or by name finds no pool."""
class PoolAlreadyExistsError(errors.ValidationError):
    """Raised when a requested pool name is already used by another pool."""
class PoolIsInUseError(errors.ValidationError):
    """Validation error signalling that a pool is in use.

    NOTE(review): not raised by any code visible in this module; presumably
    raised by callers elsewhere -- confirm before relying on it.
    """
class InvalidPoolNameError(errors.ValidationError):
    """Raised when a pool name is missing, too long, or fails the regex."""
2020-05-04 07:09:33 +00:00
class InvalidPoolDuplicateError(errors.ValidationError):
    """Raised when the post ids given for a pool contain a duplicate."""
class InvalidPoolCategoryError(errors.ValidationError):
    """Validation error for an invalid pool category.

    NOTE(review): not raised by any code visible in this module; presumably
    raised elsewhere -- confirm.
    """
class InvalidPoolDescriptionError(errors.ValidationError):
    """Raised when a pool description does not fit its database column."""
2020-05-04 22:15:30 +00:00
class InvalidPoolRelationError(errors.ValidationError):
    """Raised for an invalid relation between pools, e.g. self-merging."""
2020-05-04 02:53:28 +00:00
def _verify_name_validity(name: str) -> None:
    """Validate one pool name against the column size and configured regex.

    Raises:
        InvalidPoolNameError: if ``util.value_exceeds_column_size`` reports
            that ``name`` is too long for ``PoolName.name``, or if ``name``
            does not match ``config.config['pool_name_regex']``. The match
            uses ``re.match``, so it is anchored at the start only.
    """
    if util.value_exceeds_column_size(name, model.PoolName.name):
        raise InvalidPoolNameError('Name is too long.')
    pattern = config.config['pool_name_regex']
    if re.match(pattern, name) is None:
        raise InvalidPoolNameError('Name must satisfy regex %r.' % pattern)
def _get_names(pool: model.Pool) -> List[str]:
    """Return the name strings of ``pool``, in ``pool.names`` order."""
    assert pool
    result = []
    for entry in pool.names:
        result.append(entry.name)
    return result
def _lower_list(names: List[str]) -> List[str]:
return [name.lower() for name in names]
def _check_name_intersection(
names1: List[str], names2: List[str], case_sensitive: bool) -> bool:
if not case_sensitive:
names1 = _lower_list(names1)
names2 = _lower_list(names2)
return len(set(names1).intersection(names2)) > 0
2020-05-04 07:09:33 +00:00
def _check_post_duplication(post_ids: List[int]) -> bool:
return len(post_ids) != len(set(post_ids))
2020-05-04 02:53:28 +00:00
def sort_pools(pools: List[model.Pool]) -> List[model.Pool]:
    """Return ``pools`` as a new sorted list.

    Pools outside the default category come first, because the ``False``
    key sorts before ``True``. Within that split, pools are ordered by
    category name and then by their first entry in ``pool.names``.
    """
    default_category_name = pool_categories.get_default_category_name()

    def sort_key(pool: model.Pool) -> Tuple[bool, str, str]:
        category_name = pool.category.name
        in_default = category_name == default_category_name
        return (in_default, category_name, pool.names[0].name)

    return sorted(pools, key=sort_key)
class PoolSerializer(serialization.BaseSerializer):
    """Serializer that exposes a pool's fields under their API field names."""

    def __init__(self, pool: model.Pool) -> None:
        self.pool = pool

    def _serializers(self) -> Dict[str, Callable[[], Any]]:
        # Each API field name maps to the bound method that produces its
        # value; the key order here is the order the fields are declared in.
        return dict(
            id=self.serialize_id,
            names=self.serialize_names,
            category=self.serialize_category,
            version=self.serialize_version,
            description=self.serialize_description,
            creationTime=self.serialize_creation_time,
            lastEditTime=self.serialize_last_edit_time,
            postCount=self.serialize_post_count,
            posts=self.serialize_posts,
        )

    def serialize_id(self) -> Any:
        """Return the pool's numeric id."""
        return self.pool.pool_id

    def serialize_names(self) -> Any:
        """Return the pool's name strings, in ``pool.names`` order."""
        return [entry.name for entry in self.pool.names]

    def serialize_category(self) -> Any:
        """Return the name of the pool's category."""
        return self.pool.category.name

    def serialize_version(self) -> Any:
        """Return the pool's version attribute."""
        return self.pool.version

    def serialize_description(self) -> Any:
        """Return the pool's description (``None`` when not set)."""
        return self.pool.description

    def serialize_creation_time(self) -> Any:
        """Return the pool's creation timestamp."""
        return self.pool.creation_time

    def serialize_last_edit_time(self) -> Any:
        """Return the pool's last-edit timestamp."""
        return self.pool.last_edit_time

    def serialize_post_count(self) -> Any:
        """Return the pool's ``post_count`` attribute."""
        return self.pool.post_count

    def serialize_posts(self) -> Any:
        """Return one ``{'id': post_id}`` dict per post in the pool."""
        return [{'id': member.post_id} for member in self.pool.posts]
2020-05-04 02:53:28 +00:00
def serialize_pool(
        pool: model.Pool,
        options: Optional[List[str]] = None) -> Optional[rest.Response]:
    """Serialize ``pool`` for an API response.

    Args:
        pool: The pool to serialize. A falsy value yields ``None``.
        options: Field names passed to ``PoolSerializer.serialize``.
            ``None`` (the default) is replaced with a new empty list.

    Returns:
        The serialized pool, or ``None`` when ``pool`` is falsy.
    """
    if not pool:
        return None
    # Build a fresh list per call instead of sharing one mutable default
    # list object across every call.
    if options is None:
        options = []
    return PoolSerializer(pool).serialize(options)
def try_get_pool_by_id(pool_id: int) -> Optional[model.Pool]:
    """Return the pool whose ``pool_id`` equals ``pool_id``, or ``None``."""
    query = db.session.query(model.Pool)
    matching = query.filter(model.Pool.pool_id == pool_id)
    return matching.one_or_none()
def get_pool_by_id(pool_id: int) -> model.Pool:
    """Return the pool with the given id.

    Raises:
        PoolNotFoundError: if no such pool exists.
    """
    found = try_get_pool_by_id(pool_id)
    if found:
        return found
    raise PoolNotFoundError('Pool %r not found.' % pool_id)
def try_get_pool_by_name(name: str) -> Optional[model.Pool]:
    """Return the pool owning ``name`` (case-insensitive match), or ``None``."""
    wanted = name.lower()
    query = db.session.query(model.Pool).join(model.PoolName)
    matching = query.filter(sa.func.lower(model.PoolName.name) == wanted)
    return matching.one_or_none()
def get_pool_by_name(name: str) -> model.Pool:
    """Return the pool owning ``name`` (case-insensitive match).

    Raises:
        PoolNotFoundError: if no pool has that name.
    """
    found = try_get_pool_by_name(name)
    if found:
        return found
    raise PoolNotFoundError('Pool %r not found.' % name)
def get_pools_by_names(names: List[str]) -> List[model.Pool]:
    """Return all pools owning any of ``names`` (case-insensitive match).

    ``names`` is de-duplicated with ``util.icase_unique`` first. An empty
    result returns ``[]`` without querying.
    """
    unique_names = util.icase_unique(names)
    if not unique_names:
        return []
    conditions = [
        sa.func.lower(model.PoolName.name) == entry.lower()
        for entry in unique_names]
    query = db.session.query(model.Pool).join(model.PoolName)
    return query.filter(sa.sql.or_(*conditions)).all()
def get_or_create_pools_by_names(
        names: List[str]) -> Tuple[List[model.Pool], List[model.Pool]]:
    """Resolve names to pools, creating pools for names not yet in use.

    ``names`` is de-duplicated with ``util.icase_unique``. Each name that no
    existing pool owns (case-insensitive comparison) gets a new pool in the
    default category with no posts. The new pool is added to ``db.session``.

    Returns:
        A ``(existing_pools, new_pools)`` tuple.
    """
    names = util.icase_unique(names)
    existing_pools = get_pools_by_names(names)
    new_pools = []
    pool_category_name = pool_categories.get_default_category_name()
    # Lower-cased names owned by existing pools; membership here is the
    # case-insensitive "already exists" test.
    known_names = {
        known.lower()
        for existing_pool in existing_pools
        for known in _get_names(existing_pool)}
    for name in names:
        if name.lower() in known_names:
            continue
        new_pool = create_pool(
            names=[name],
            category_name=pool_category_name,
            post_ids=[])
        db.session.add(new_pool)
        new_pools.append(new_pool)
    return existing_pools, new_pools
def delete(source_pool: model.Pool) -> None:
    """Pass ``source_pool`` to ``db.session.delete``."""
    assert source_pool
    session = db.session
    session.delete(source_pool)
def merge_pools(source_pool: model.Pool, target_pool: model.Pool) -> None:
    """Move the posts of ``source_pool`` into ``target_pool``, then delete it.

    Pool-post rows of the source are re-pointed at the target with a single
    SQL UPDATE, except for posts the target already contains. The source
    pool is then passed to ``delete``.

    NOTE(review): the rows left behind (posts already in the target) are
    presumably removed by a cascade on pool deletion -- confirm in the model.
    The UPDATE bypasses the ORM, so in-session ``posts`` collections may be
    stale afterwards.

    Raises:
        InvalidPoolRelationError: if both arguments are the same pool.
    """
    assert source_pool
    assert target_pool
    if source_pool.pool_id == target_pool.pool_id:
        raise InvalidPoolRelationError('Cannot merge pool with itself.')

    def merge_posts(source_pool_id: int, target_pool_id: int) -> None:
        moved = model.PoolPost
        existing = sa.orm.util.aliased(model.PoolPost)
        # True when the target already holds the post of the row being
        # updated; the reference to ``moved`` correlates to the UPDATE target.
        already_in_target = (
            sa.exists()
            .where(moved.post_id == existing.post_id)
            .where(existing.pool_id == target_pool_id))
        statement = (
            sa.sql.expression.update(moved)
            .where(moved.pool_id == source_pool_id)
            .where(~already_in_target)
            .values(pool_id=target_pool_id))
        db.session.execute(statement)

    merge_posts(source_pool.pool_id, target_pool.pool_id)
    delete(source_pool)
def create_pool(
        names: List[str],
        category_name: str,
        post_ids: List[int]) -> model.Pool:
    """Build and return a new pool.

    The creation time is set to ``datetime.utcnow()``. Names, category and
    posts are then applied, in that order, through ``update_pool_names``,
    ``update_pool_category_name`` and ``update_pool_posts``, so their
    validation errors propagate. This function does not call
    ``db.session.add`` itself.
    """
    new_pool = model.Pool()
    new_pool.creation_time = datetime.utcnow()
    update_pool_names(new_pool, names)
    update_pool_category_name(new_pool, category_name)
    update_pool_posts(new_pool, post_ids)
    return new_pool
def update_pool_category_name(pool: model.Pool, category_name: str) -> None:
    """Set ``pool.category`` to the category named ``category_name``.

    The lookup is ``pool_categories.get_category_by_name``. Any error it
    raises (presumably for an unknown name) propagates.
    """
    assert pool
    category = pool_categories.get_category_by_name(category_name)
    pool.category = category
def update_pool_names(pool: model.Pool, names: List[str]) -> None:
    """Replace the names of ``pool`` with ``names``.

    Empty entries are dropped, and the rest are de-duplicated with
    ``util.icase_unique``. Each remaining name is validated, and the
    database is checked (case-insensitively) for another pool using any of
    them. Existing ``PoolName`` entries whose text is no longer requested
    (exact, case-sensitive comparison) are removed, and missing names are
    appended. Every name's ``order`` is then set to its position in the
    request.

    Raises:
        InvalidPoolNameError: if no non-empty name is given, or a name fails
            ``_verify_name_validity``.
        PoolAlreadyExistsError: if another pool already uses one of the names.
    """
    assert pool
    names = util.icase_unique([name for name in names if name])
    if not names:
        raise InvalidPoolNameError('At least one name must be specified.')
    for name in names:
        _verify_name_validity(name)

    # Fold "lower(name) = n1 OR lower(name) = n2 ...", seeded with FALSE.
    taken = sa.sql.false()
    for name in names:
        taken = taken | (sa.func.lower(model.PoolName.name) == name.lower())
    if pool.pool_id:
        # Rows belonging to this very pool are not conflicts.
        taken = taken & (model.PoolName.pool_id != pool.pool_id)
    if db.session.query(model.PoolName).filter(taken).all():
        raise PoolAlreadyExistsError(
            'One of names is already used by another pool.')

    # Remove names that are no longer requested (exact match).
    for pool_name in list(pool.names):
        if pool_name.name not in names:
            pool.names.remove(pool_name)

    # Append requested names the pool does not hold yet. The -1 second
    # argument is presumably a placeholder order; it is overwritten below.
    present = {pool_name.name for pool_name in pool.names}
    for name in names:
        if name not in present:
            pool.names.append(model.PoolName(name, -1))
            present.add(name)

    # Make the stored order follow the order of the request.
    for position, name in enumerate(names):
        wanted = name.lower()
        for pool_name in pool.names:
            if pool_name.name.lower() == wanted:
                pool_name.order = position
def update_pool_description(pool: model.Pool, description: str) -> None:
    """Set ``pool.description``, storing ``None`` for a falsy description.

    Raises:
        InvalidPoolDescriptionError: if ``util.value_exceeds_column_size``
            reports the description as too long for ``Pool.description``.
    """
    assert pool
    too_long = util.value_exceeds_column_size(
        description, model.Pool.description)
    if too_long:
        raise InvalidPoolDescriptionError('Description is too long.')
    pool.description = description if description else None
2020-05-04 07:09:33 +00:00
def update_pool_posts(pool: model.Pool, post_ids: List[int]) -> None:
    """Replace the posts of ``pool`` with the posts whose ids are given.

    Posts are appended in the order of ``post_ids``. This function does not
    itself report ids for which ``posts.get_posts_by_ids`` returns no post;
    such ids are simply absent from the pool.

    Raises:
        InvalidPoolDuplicateError: if ``post_ids`` contains a repeated id.
    """
    assert pool
    if _check_post_duplication(post_ids):
        raise InvalidPoolDuplicateError('Duplicate post in pool.')
    pool.posts.clear()
    # Nothing visible here guarantees that the lookup returns posts in
    # request order, so re-sort them by each id's requested position.
    # sorted() is stable, so any post whose id is not found in ``position``
    # keeps its relative order at the end instead of being dropped.
    position = {post_id: index for index, post_id in enumerate(post_ids)}
    fetched = posts.get_posts_by_ids(post_ids)
    for post in sorted(
            fetched,
            key=lambda item: position.get(item.post_id, len(position))):
        pool.posts.append(post)