188 lines
5.7 KiB
Python
188 lines
5.7 KiB
Python
from datetime import datetime
|
|
from typing import Any, Callable, Dict, Optional
|
|
|
|
import sqlalchemy as sa
|
|
|
|
from szurubooru import db, model
|
|
from szurubooru.func import diff, net, users
|
|
|
|
|
|
def get_tag_category_snapshot(category: model.TagCategory) -> Dict[str, Any]:
    """Build the snapshot payload stored for a tag category.

    The payload holds the category's name and color, plus a plain boolean
    telling whether it is the default category.
    """
    assert category
    return dict(
        name=category.name,
        color=category.color,
        # Normalize whatever truthy/falsy value the column holds to a bool.
        default=bool(category.default),
    )
|
|
|
|
|
|
def get_tag_snapshot(tag: model.Tag) -> Dict[str, Any]:
    """Build the snapshot payload stored for a tag.

    Names are kept in the order the tag stores them. Suggestions and
    implications are reduced to the related tags' first names and sorted,
    so the payload does not depend on relation ordering.
    """
    assert tag
    tag_names = [entry.name for entry in tag.names]
    category_name = tag.category.name
    suggested = sorted(related.first_name for related in tag.suggestions)
    implied = sorted(related.first_name for related in tag.implications)
    return {
        "names": tag_names,
        "category": category_name,
        "suggestions": suggested,
        "implications": implied,
    }
|
|
|
|
|
|
def get_pool_category_snapshot(category: model.PoolCategory) -> Dict[str, Any]:
    """Build the snapshot payload stored for a pool category.

    The payload holds the category's name and color, plus a plain boolean
    telling whether it is the default category.
    """
    assert category
    return dict(
        name=category.name,
        color=category.color,
        # Normalize whatever truthy/falsy value the column holds to a bool.
        default=bool(category.default),
    )
|
|
|
|
|
|
def get_pool_snapshot(pool: model.Pool) -> Dict[str, Any]:
    """Build the snapshot payload stored for a pool.

    Names and post ids are recorded in the order the pool holds them. They
    are deliberately not sorted, so a reordering shows up in the snapshot.
    """
    assert pool
    payload: Dict[str, Any] = {}
    payload["names"] = [entry.name for entry in pool.names]
    payload["category"] = pool.category.name
    payload["posts"] = [member.post_id for member in pool.posts]
    return payload
|
|
|
|
|
|
def get_post_snapshot(post: model.Post) -> Dict[str, Any]:
    """Build the snapshot payload stored for a post.

    Scalar attributes are copied as-is. Tags are reduced to their first
    names and relations to post ids, and both lists are sorted. Notes become
    ``{"polygon": [[x, y], ...], "text": ...}`` dicts, ordered by polygon, so
    the payload is independent of relationship ordering.
    """
    assert post

    def note_to_dict(note: Any) -> Dict[str, Any]:
        # Copy each point into a fresh two-element list.
        return {
            "polygon": [[vertex[0], vertex[1]] for vertex in note.polygon],
            "text": note.text,
        }

    payload: Dict[str, Any] = {
        "source": post.source,
        "safety": post.safety,
        "checksum": post.checksum,
        "flags": post.flags,
        "featured": post.is_featured,
    }
    payload["tags"] = sorted(tag.first_name for tag in post.tags)
    payload["relations"] = sorted(other.post_id for other in post.relations)

    serialized_notes = [note_to_dict(note) for note in post.notes]
    serialized_notes.sort(key=lambda item: item["polygon"])
    payload["notes"] = serialized_notes
    return payload
|
|
|
|
|
|
_snapshot_factories = {
|
|
# lambdas allow mocking target functions in the tests
|
|
"tag_category": lambda entity: get_tag_category_snapshot(entity),
|
|
"tag": lambda entity: get_tag_snapshot(entity),
|
|
"post": lambda entity: get_post_snapshot(entity),
|
|
"pool_category": lambda entity: get_pool_category_snapshot(entity),
|
|
"pool": lambda entity: get_pool_snapshot(entity),
|
|
} # type: Dict[model.Base, Callable[[model.Base], Dict[str ,Any]]]
|
|
|
|
|
|
def serialize_snapshot(
    snapshot: model.Snapshot, auth_user: model.User
) -> Dict[str, Any]:
    """Convert a Snapshot row into its public dictionary form.

    The author is rendered with users.serialize_micro_user from the point of
    view of *auth_user*. The stored ``data`` and ``creation_time`` are passed
    through unchanged.
    """
    assert snapshot
    return dict(
        operation=snapshot.operation,
        type=snapshot.resource_type,
        id=snapshot.resource_name,
        user=users.serialize_micro_user(snapshot.user, auth_user),
        data=snapshot.data,
        time=snapshot.creation_time,
    )
|
|
|
|
|
|
def _post_to_webhooks(snapshot: model.Snapshot) -> None:
    """Serialize *snapshot* as an anonymous viewer and hand it to net.

    A throwaway, nameless ``User`` with rank ``"anonymous"`` is used as the
    viewer, so the payload is rendered with anonymous-level visibility.
    """
    anonymous_viewer = model.User()
    anonymous_viewer.name = None
    anonymous_viewer.rank = "anonymous"
    payload = serialize_snapshot(snapshot, anonymous_viewer)
    net.post_to_webhooks(payload)
|
|
|
|
|
|
def _create(
    operation: str, entity: model.Base, auth_user: Optional[model.User]
) -> model.Snapshot:
    """Build an unsaved Snapshot describing *operation* on *entity*.

    Only the header fields are filled in: timestamp, operation, resource
    identity and author. The caller sets ``data`` and adds the snapshot to
    the session.
    """
    kind, primary_key, display_name = model.util.get_resource_info(entity)

    new_snapshot = model.Snapshot()
    # Naive UTC timestamp, exactly as datetime.utcnow() returns it.
    new_snapshot.creation_time = datetime.utcnow()
    new_snapshot.operation = operation
    new_snapshot.resource_type = kind
    new_snapshot.resource_pkey = primary_key
    new_snapshot.resource_name = display_name
    new_snapshot.user = auth_user
    return new_snapshot
|
|
|
|
|
|
def create(entity: model.Base, auth_user: Optional[model.User]) -> None:
    """Record a "created" snapshot holding the full state of *entity*.

    The snapshot is added to ``db.session`` (not committed here) and then
    posted to webhooks.
    """
    assert entity
    record = _create(model.Snapshot.OPERATION_CREATED, entity, auth_user)
    build_payload = _snapshot_factories[record.resource_type]
    record.data = build_payload(entity)
    db.session.add(record)
    _post_to_webhooks(record)
|
|
|
|
|
|
def modify(entity: model.Base, auth_user: Optional[model.User]) -> None:
    """Record a "modified" snapshot containing only what changed in *entity*.

    The entity is re-read from the database through a separate, short-lived
    session bound to the same bind as ``db.session``. That copy does not see
    changes still pending in ``db.session``'s identity map. Its snapshot is
    diffed against a snapshot of the in-memory *entity*. When the diff is
    empty nothing is recorded. Otherwise the snapshot is added to
    ``db.session`` and posted to webhooks.

    Raises:
        AssertionError: if *entity* is falsy, no mapped class matches its
            table, or the row cannot be loaded by the separate session.
    """
    assert entity

    # Find the declarative class mapped to the entity's table, so the
    # separate session can load the stored row by primary key.
    entity_class = next(
        (
            cls
            for cls in model.Base._decl_class_registry.values()
            if hasattr(cls, "__table__")
            and cls.__table__.fullname == entity.__table__.fullname
        ),
        None,
    )
    assert entity_class

    snapshot = _create(model.Snapshot.OPERATION_MODIFIED, entity, auth_user)
    snapshot_factory = _snapshot_factories[snapshot.resource_type]

    detached_session = sa.orm.sessionmaker(bind=db.session.get_bind())()
    try:
        detached_entity = detached_session.query(entity_class).get(
            snapshot.resource_pkey
        )
        assert detached_entity, "Entity not found in DB, have you committed it?"
        detached_snapshot = snapshot_factory(detached_entity)
    finally:
        # Release the auxiliary session even when the lookup assertion or
        # the snapshot factory fails; previously it leaked in those cases.
        detached_session.close()

    active_snapshot = snapshot_factory(entity)

    snapshot.data = diff.get_dict_diff(detached_snapshot, active_snapshot)
    if not snapshot.data:
        return
    db.session.add(snapshot)
    _post_to_webhooks(snapshot)
|
|
|
|
|
|
def delete(entity: model.Base, auth_user: Optional[model.User]) -> None:
    """Record a "deleted" snapshot holding the last full state of *entity*.

    The snapshot is added to ``db.session`` (not committed here) and then
    posted to webhooks.
    """
    assert entity
    record = _create(model.Snapshot.OPERATION_DELETED, entity, auth_user)
    record.data = _snapshot_factories[record.resource_type](entity)
    db.session.add(record)
    _post_to_webhooks(record)
|
|
|
|
|
|
def merge(
    source_entity: model.Base,
    target_entity: model.Base,
    auth_user: Optional[model.User],
) -> None:
    """Record a "merged" snapshot for *source_entity* being merged away.

    The snapshot is attributed to the source resource. Its data is the
    two-element list ``[target_type, target_name]`` identifying the merge
    target. It is added to ``db.session`` and then posted to webhooks.
    """
    assert source_entity
    assert target_entity
    record = _create(
        model.Snapshot.OPERATION_MERGED, source_entity, auth_user
    )
    target_type, _unused_pkey, target_name = model.util.get_resource_info(
        target_entity
    )
    record.data = [target_type, target_name]
    db.session.add(record)
    _post_to_webhooks(record)
|