diff --git a/server/szurubooru/search/base_search_config.py b/server/szurubooru/search/base_search_config.py index 8de0a23..1d63830 100644 --- a/server/szurubooru/search/base_search_config.py +++ b/server/szurubooru/search/base_search_config.py @@ -69,6 +69,9 @@ def _apply_str_criterion_to_column(column, query, criterion): return query.filter(expr) class BaseSearchConfig(object): + ORDER_DESC = 1 + ORDER_ASC = 2 + def create_query(self, session): raise NotImplementedError() diff --git a/server/szurubooru/search/search_executor.py b/server/szurubooru/search/search_executor.py index 8cbd6c3..154d5f2 100644 --- a/server/szurubooru/search/search_executor.py +++ b/server/szurubooru/search/search_executor.py @@ -9,9 +9,6 @@ class SearchExecutor(object): delegates sqlalchemy filter decoration to SearchConfig instances. ''' - ORDER_DESC = 1 - ORDER_ASC = 2 - def __init__(self, search_config): self._search_config = search_config @@ -52,13 +49,13 @@ class SearchExecutor(object): def _handle_key_value(self, query, key, value, negated): if key == 'order': if value.count(',') == 0: - order = self.ORDER_ASC + order = None elif value.count(',') == 1: value, order_str = value.split(',') if order_str == 'asc': - order = self.ORDER_ASC + order = self._search_config.ORDER_ASC elif order_str == 'desc': - order = self.ORDER_DESC + order = self._search_config.ORDER_DESC else: raise errors.SearchError( 'Unknown search direction: %r.' % order_str) @@ -66,10 +63,12 @@ class SearchExecutor(object): raise errors.SearchError( 'Too many commas in order search token.') if negated: - if order == self.ORDER_DESC: - order = self.ORDER_ASC + if order == self._search_config.ORDER_DESC: + order = self._search_config.ORDER_ASC + elif order == self._search_config.ORDER_ASC: + order = self._search_config.ORDER_DESC else: - order = self.ORDER_DESC + order = -1 return self._handle_order(query, value, order) elif key == 'special': return self._handle_special(query, value, negated) @@ -100,8 +99,17 @@ class SearchExecutor(object): def _handle_order(self, query, value, order): if value in self._search_config.order_columns: - column = self._search_config.order_columns[value] - if order == self.ORDER_ASC: + column, default_order = self._search_config.order_columns[value] + if order is None: + order = default_order + elif order == -1: + if default_order == self._search_config.ORDER_ASC: + order = self._search_config.ORDER_DESC + elif default_order == self._search_config.ORDER_DESC: + order = self._search_config.ORDER_ASC + else: + order = self._search_config.ORDER_ASC + if order == self._search_config.ORDER_ASC: column = column.asc() else: column = column.desc() diff --git a/server/szurubooru/search/user_search_config.py b/server/szurubooru/search/user_search_config.py index 45796f1..09798d2 100644 --- a/server/szurubooru/search/user_search_config.py +++ b/server/szurubooru/search/user_search_config.py @@ -35,11 +35,11 @@ class UserSearchConfig(BaseSearchConfig): def order_columns(self): return { 'random': func.random(), - 'name': db.User.name, - 'creation-date': db.User.creation_time, - 'creation-time': db.User.creation_time, - 'last-login-date': db.User.last_login_time, - 'last-login-time': db.User.last_login_time, - 'login-date': db.User.last_login_time, - 'login-time': db.User.last_login_time, + 'name': (db.User.name, self.ORDER_ASC), + 'creation-date': (db.User.creation_time, self.ORDER_DESC), + 'creation-time': (db.User.creation_time, self.ORDER_DESC), + 'last-login-date': (db.User.last_login_time, self.ORDER_DESC), + 'last-login-time': (db.User.last_login_time, self.ORDER_DESC), + 'login-date': (db.User.last_login_time, self.ORDER_DESC), + 'login-time': (db.User.last_login_time, self.ORDER_DESC), } diff --git a/server/szurubooru/tests/search/test_user_search_config.py b/server/szurubooru/tests/search/test_user_search_config.py index b1ec933..55a3a86 100644 --- a/server/szurubooru/tests/search/test_user_search_config.py +++ b/server/szurubooru/tests/search/test_user_search_config.py @@ -1,19 +1,7 @@ -from datetime import datetime +import datetime import pytest from szurubooru import db, errors, search -def mock_user(name): - user = db.User() - user.name = name - user.password = 'dummy' - user.password_salt = 'dummy' - user.password_hash = 'dummy' - user.email = 'dummy' - user.rank = 'dummy' - user.creation_time = datetime(1997, 1, 1) - user.avatar_style = db.User.AVATAR_GRAVATAR - return user - @pytest.fixture def executor(session): search_config = search.UserSearchConfig() @@ -29,7 +17,6 @@ def verify_unpaged(session, executor): assert actual_user_names == expected_user_names return verify -# ----------------------------------------------------------------------------- @pytest.mark.parametrize('input,expected_user_names', [ ('creation-time:2014', ['u1', 'u2']), ('creation-date:2014', ['u1', 'u2']), @@ -53,17 +40,16 @@ def verify_unpaged(session, executor): ('-creation-date:2014-01,2015', ['u2']), ]) def test_filter_by_creation_time( - verify_unpaged, session, input, expected_user_names): - user1 = mock_user('u1') - user2 = mock_user('u2') - user3 = mock_user('u3') - user1.creation_time = datetime(2014, 1, 1) - user2.creation_time = datetime(2014, 6, 1) - user3.creation_time = datetime(2015, 1, 1) + verify_unpaged, session, input, expected_user_names, user_factory): + user1 = user_factory(name='u1') + user2 = user_factory(name='u2') + user3 = user_factory(name='u3') + user1.creation_time = datetime.datetime(2014, 1, 1) + user2.creation_time = datetime.datetime(2014, 6, 1) + user3.creation_time = datetime.datetime(2015, 1, 1) session.add_all([user1, user2, user3]) verify_unpaged(input, expected_user_names) -# ----------------------------------------------------------------------------- @pytest.mark.parametrize('input,expected_user_names', [ ('name:user1', ['user1']), ('name:user2', ['user2']), @@ -82,41 +68,41 @@ def test_filter_by_creation_time( ('name:user1,user2', ['user1', 'user2']), ('-name:user1,user3', ['user2']), ]) -def test_filter_by_name(session, verify_unpaged, input, expected_user_names): - session.add(mock_user('user1')) - session.add(mock_user('user2')) - session.add(mock_user('user3')) +def test_filter_by_name( + session, verify_unpaged, input, expected_user_names, user_factory): + session.add(user_factory(name='user1')) + session.add(user_factory(name='user2')) + session.add(user_factory(name='user3')) verify_unpaged(input, expected_user_names) -# ----------------------------------------------------------------------------- @pytest.mark.parametrize('input,expected_user_names', [ ('', ['u1', 'u2']), ('u1', ['u1']), ('u2', ['u2']), ('u1,u2', ['u1', 'u2']), ]) -def test_anonymous(session, verify_unpaged, input, expected_user_names): - session.add(mock_user('u1')) - session.add(mock_user('u2')) +def test_anonymous( + session, verify_unpaged, input, expected_user_names, user_factory): + session.add(user_factory(name='u1')) + session.add(user_factory(name='u2')) verify_unpaged(input, expected_user_names) -# ----------------------------------------------------------------------------- @pytest.mark.parametrize('input,expected_user_names', [ ('creation-time:2014 u1', ['u1']), ('creation-time:2014 u2', ['u2']), ('creation-time:2016 u2', []), ]) -def test_combining_tokens(session, verify_unpaged, input, expected_user_names): - user1 = mock_user('u1') - user2 = mock_user('u2') - user3 = mock_user('u3') - user1.creation_time = datetime(2014, 1, 1) - user2.creation_time = datetime(2014, 6, 1) - user3.creation_time = datetime(2015, 1, 1) +def test_combining_tokens( + session, verify_unpaged, input, expected_user_names, user_factory): + user1 = user_factory(name='u1') + user2 = user_factory(name='u2') + user3 = user_factory(name='u3') + user1.creation_time = datetime.datetime(2014, 1, 1) + user2.creation_time = datetime.datetime(2014, 6, 1) + user3.creation_time = datetime.datetime(2015, 1, 1) session.add_all([user1, user2, user3]) verify_unpaged(input, expected_user_names) -# ----------------------------------------------------------------------------- @pytest.mark.parametrize( 'page,page_size,expected_total_count,expected_user_names', [ (1, 1, 2, ['u1']), @@ -126,17 +112,16 @@ def test_combining_tokens(session, verify_unpaged, input, expected_user_names): (0, 0, 2, []), ]) def test_paging( - session, executor, page, page_size, + session, executor, user_factory, page, page_size, expected_total_count, expected_user_names): - session.add(mock_user('u1')) - session.add(mock_user('u2')) + session.add(user_factory(name='u1')) + session.add(user_factory(name='u2')) actual_count, actual_users = executor.execute( session, '', page=page, page_size=page_size) actual_user_names = [u.name for u in actual_users] assert actual_count == expected_total_count assert actual_user_names == expected_user_names -# ----------------------------------------------------------------------------- @pytest.mark.parametrize('input,expected_user_names', [ ('', ['u1', 'u2']), ('order:name', ['u1', 'u2']), @@ -146,12 +131,32 @@ def test_paging( ('-order:name,asc', ['u2', 'u1']), ('-order:name,desc', ['u1', 'u2']), ]) -def test_order_by_name(session, verify_unpaged, input, expected_user_names): - session.add(mock_user('u2')) - session.add(mock_user('u1')) +def test_order_by_name( + session, verify_unpaged, input, expected_user_names, user_factory): + session.add(user_factory(name='u2')) + session.add(user_factory(name='u1')) + verify_unpaged(input, expected_user_names) + +@pytest.mark.parametrize('input,expected_user_names', [ + ('', ['u1', 'u2', 'u3']), + ('order:creation-date', ['u3', 'u2', 'u1']), + ('-order:creation-date', ['u1', 'u2', 'u3']), + ('order:creation-date,asc', ['u1', 'u2', 'u3']), + ('order:creation-date,desc', ['u3', 'u2', 'u1']), + ('-order:creation-date,asc', ['u3', 'u2', 'u1']), + ('-order:creation-date,desc', ['u1', 'u2', 'u3']), +]) +def test_order_by_name( + session, verify_unpaged, input, expected_user_names, user_factory): + user1 = user_factory(name='u1') + user2 = user_factory(name='u2') + user3 = user_factory(name='u3') + user1.creation_time = datetime.datetime(1991, 1, 1) + user2.creation_time = datetime.datetime(1991, 1, 2) + user3.creation_time = datetime.datetime(1991, 1, 3) + session.add_all([user3, user1, user2]) verify_unpaged(input, expected_user_names) -# ----------------------------------------------------------------------------- @pytest.mark.parametrize('input,expected_error', [ ('creation-date:..', errors.SearchError), ('creation-date:bad..', errors.ValidationError),