234 lines
6.2 KiB
Python
234 lines
6.2 KiB
Python
# pylint: disable=redefined-outer-name
|
|
import contextlib
|
|
import os
|
|
import random
|
|
import string
|
|
from unittest.mock import patch
|
|
from datetime import datetime
|
|
import pytest
|
|
import freezegun
|
|
import sqlalchemy as sa
|
|
from szurubooru import config, db, model, rest
|
|
|
|
|
|
class QueryCounter:
|
|
def __init__(self):
|
|
self._statements = []
|
|
|
|
def __enter__(self):
|
|
self._statements = []
|
|
|
|
def __exit__(self, *args, **kwargs):
|
|
self._statements = []
|
|
|
|
def create_before_cursor_execute(self):
|
|
def before_cursor_execute(
|
|
_conn, _cursor, statement, _params, _context, _executemany):
|
|
self._statements.append(statement)
|
|
return before_cursor_execute
|
|
|
|
@property
|
|
def statements(self):
|
|
return self._statements
|
|
|
|
|
|
if not config.config['test_database']:
|
|
raise RuntimeError('Test database not configured.')
|
|
|
|
_query_counter = QueryCounter()
|
|
_engine = sa.create_engine(config.config['test_database'])
|
|
model.Base.metadata.drop_all(bind=_engine)
|
|
model.Base.metadata.create_all(bind=_engine)
|
|
sa.event.listen(
|
|
_engine,
|
|
'before_cursor_execute',
|
|
_query_counter.create_before_cursor_execute())
|
|
|
|
|
|
def get_unique_name():
|
|
alphabet = string.ascii_letters + string.digits
|
|
return ''.join(random.choice(alphabet) for _ in range(8))
|
|
|
|
|
|
@pytest.fixture
|
|
def fake_datetime():
|
|
@contextlib.contextmanager
|
|
def injector(now):
|
|
freezer = freezegun.freeze_time(now)
|
|
freezer.start()
|
|
yield
|
|
freezer.stop()
|
|
return injector
|
|
|
|
|
|
@pytest.fixture()
|
|
def query_counter():
|
|
return _query_counter
|
|
|
|
|
|
@pytest.fixture
|
|
def query_logger():
|
|
if pytest.config.option.verbose > 0:
|
|
import logging
|
|
import coloredlogs
|
|
coloredlogs.install(
|
|
fmt='[%(asctime)-15s] %(name)s %(message)s', isatty=True)
|
|
logging.basicConfig()
|
|
logging.getLogger('sqlalchemy.engine').setLevel(logging.INFO)
|
|
|
|
|
|
@pytest.yield_fixture(scope='function', autouse=True)
|
|
def session(query_logger): # pylint: disable=unused-argument
|
|
db.sessionmaker = sa.orm.sessionmaker(
|
|
bind=_engine, autoflush=False)
|
|
db.session = sa.orm.scoped_session(db.sessionmaker)
|
|
try:
|
|
yield db.session
|
|
finally:
|
|
db.session.remove()
|
|
for table in reversed(model.Base.metadata.sorted_tables):
|
|
db.session.execute(table.delete())
|
|
db.session.commit()
|
|
|
|
|
|
@pytest.fixture
|
|
def context_factory(session):
|
|
def factory(params=None, files=None, user=None):
|
|
ctx = rest.Context(
|
|
method=None,
|
|
url=None,
|
|
headers={},
|
|
params=params or {},
|
|
files=files or {})
|
|
ctx.session = session
|
|
ctx.user = user or model.User()
|
|
return ctx
|
|
return factory
|
|
|
|
|
|
@pytest.fixture
|
|
def config_injector():
|
|
def injector(new_config_content):
|
|
config.config = new_config_content
|
|
return injector
|
|
|
|
|
|
@pytest.fixture
|
|
def user_factory():
|
|
def factory(name=None, rank=model.User.RANK_REGULAR, email='dummy'):
|
|
user = model.User()
|
|
user.name = name or get_unique_name()
|
|
user.password_salt = 'dummy'
|
|
user.password_hash = 'dummy'
|
|
user.email = email
|
|
user.rank = rank
|
|
user.creation_time = datetime(1997, 1, 1)
|
|
user.avatar_style = model.User.AVATAR_GRAVATAR
|
|
return user
|
|
return factory
|
|
|
|
|
|
@pytest.fixture
|
|
def tag_category_factory():
|
|
def factory(name=None, color='dummy', default=False):
|
|
category = model.TagCategory()
|
|
category.name = name or get_unique_name()
|
|
category.color = color
|
|
category.default = default
|
|
return category
|
|
return factory
|
|
|
|
|
|
@pytest.fixture
|
|
def tag_factory():
|
|
def factory(names=None, category=None):
|
|
if not category:
|
|
category = model.TagCategory(get_unique_name())
|
|
db.session.add(category)
|
|
tag = model.Tag()
|
|
tag.names = []
|
|
for i, name in enumerate(names or [get_unique_name()]):
|
|
tag.names.append(model.TagName(name, i))
|
|
tag.category = category
|
|
tag.creation_time = datetime(1996, 1, 1)
|
|
return tag
|
|
return factory
|
|
|
|
|
|
@pytest.yield_fixture
|
|
def skip_post_hashing():
|
|
with patch('szurubooru.func.image_hash.add_image'), \
|
|
patch('szurubooru.func.image_hash.delete_image'):
|
|
yield
|
|
|
|
|
|
@pytest.fixture
|
|
def post_factory(skip_post_hashing):
|
|
# pylint: disable=invalid-name
|
|
def factory(
|
|
id=None,
|
|
safety=model.Post.SAFETY_SAFE,
|
|
type=model.Post.TYPE_IMAGE,
|
|
checksum='...'):
|
|
post = model.Post()
|
|
post.post_id = id
|
|
post.safety = safety
|
|
post.type = type
|
|
post.checksum = checksum
|
|
post.flags = []
|
|
post.mime_type = 'application/octet-stream'
|
|
post.creation_time = datetime(1996, 1, 1)
|
|
return post
|
|
return factory
|
|
|
|
|
|
@pytest.fixture
|
|
def comment_factory(user_factory, post_factory):
|
|
def factory(user=None, post=None, text='dummy', time=None):
|
|
if not user:
|
|
user = user_factory()
|
|
db.session.add(user)
|
|
if not post:
|
|
post = post_factory()
|
|
db.session.add(post)
|
|
comment = model.Comment()
|
|
comment.user = user
|
|
comment.post = post
|
|
comment.text = text
|
|
comment.creation_time = time or datetime(1996, 1, 1)
|
|
return comment
|
|
return factory
|
|
|
|
|
|
@pytest.fixture
|
|
def post_score_factory(user_factory, post_factory):
|
|
def factory(post=None, user=None, score=1):
|
|
if user is None:
|
|
user = user_factory()
|
|
if post is None:
|
|
post = post_factory()
|
|
return model.PostScore(
|
|
post=post, user=user, score=score, time=datetime(1999, 1, 1))
|
|
return factory
|
|
|
|
|
|
@pytest.fixture
|
|
def post_favorite_factory(user_factory, post_factory):
|
|
def factory(post=None, user=None):
|
|
if user is None:
|
|
user = user_factory()
|
|
if post is None:
|
|
post = post_factory()
|
|
return model.PostFavorite(
|
|
post=post, user=user, time=datetime(1999, 1, 1))
|
|
return factory
|
|
|
|
|
|
@pytest.fixture
|
|
def read_asset():
|
|
def get(path):
|
|
path = os.path.join(os.path.dirname(__file__), 'assets', path)
|
|
with open(path, 'rb') as handle:
|
|
return handle.read()
|
|
return get
|