server/general: improve pylint score
+ incorporate some in-house rules - no more useless doc strings...
This commit is contained in:
parent
9ce67b64ed
commit
2578a297bf
|
@ -0,0 +1,14 @@
|
||||||
|
[basic]
|
||||||
|
method-rgx=[a-z_][a-z0-9_]{2,30}$|^test_
|
||||||
|
|
||||||
|
[variables]
|
||||||
|
dummy-variables-rgx=_|dummy
|
||||||
|
|
||||||
|
[format]
|
||||||
|
max-line-length=90
|
||||||
|
|
||||||
|
[messages control]
|
||||||
|
disable=missing-docstring,no-self-use,too-few-public-methods
|
||||||
|
|
||||||
|
[typecheck]
|
||||||
|
generated-members=add|add_all
|
|
@ -1,12 +1,10 @@
|
||||||
''' Exports BaseApi. '''
|
|
||||||
|
|
||||||
import types
|
import types
|
||||||
|
|
||||||
def _bind_method(target, desired_method_name):
|
def _bind_method(target, desired_method_name):
|
||||||
actual_method = getattr(target, desired_method_name)
|
actual_method = getattr(target, desired_method_name)
|
||||||
def _wrapper_method(self, request, response, *args, **kwargs):
|
def _wrapper_method(_self, request, _response, *args, **kwargs):
|
||||||
request.context.result = actual_method(
|
request.context.result = actual_method(
|
||||||
request, request.context, *args, **kwargs)
|
request.context, *args, **kwargs)
|
||||||
return types.MethodType(_wrapper_method, target)
|
return types.MethodType(_wrapper_method, target)
|
||||||
|
|
||||||
class BaseApi(object):
|
class BaseApi(object):
|
||||||
|
|
|
@ -1,47 +1,47 @@
|
||||||
''' Exports PasswordReminderApi. '''
|
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
|
from szurubooru import errors
|
||||||
from szurubooru.api.base_api import BaseApi
|
from szurubooru.api.base_api import BaseApi
|
||||||
from szurubooru.errors import ValidationError, NotFoundError
|
|
||||||
|
MAIL_SUBJECT = 'Password reset for {name}'
|
||||||
|
MAIL_BODY = \
|
||||||
|
'You (or someone else) requested to reset your password on {name}.\n' \
|
||||||
|
'If you wish to proceed, click this link: {url}\n' \
|
||||||
|
'Otherwise, please ignore this email.'
|
||||||
|
|
||||||
class PasswordReminderApi(BaseApi):
|
class PasswordReminderApi(BaseApi):
|
||||||
''' API for password reminders. '''
|
|
||||||
def __init__(self, config, mailer, user_service):
|
def __init__(self, config, mailer, user_service):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._config = config
|
self._config = config
|
||||||
self._mailer = mailer
|
self._mailer = mailer
|
||||||
self._user_service = user_service
|
self._user_service = user_service
|
||||||
|
|
||||||
def get(self, request, context, user_name):
|
def get(self, context, user_name):
|
||||||
user = self._user_service.get_by_name(context.session, user_name)
|
user = self._user_service.get_by_name(context.session, user_name)
|
||||||
if not user:
|
if not user:
|
||||||
raise NotFoundError('User %r not found.' % user_name)
|
raise errors.NotFoundError('User %r not found.' % user_name)
|
||||||
if not user.email:
|
if not user.email:
|
||||||
raise ValidationError(
|
raise errors.ValidationError(
|
||||||
'User %r hasn\'t supplied email. Cannot reset password.' % user_name)
|
'User %r hasn\'t supplied email. Cannot reset password.' % user_name)
|
||||||
token = self._generate_authentication_token(user)
|
token = self._generate_authentication_token(user)
|
||||||
|
url = '%s/password-reset/%s' % (
|
||||||
|
self._config['basic']['base_url'].rstrip('/'), token)
|
||||||
self._mailer.send(
|
self._mailer.send(
|
||||||
'noreply@%s' % self._config['basic']['name'],
|
'noreply@%s' % self._config['basic']['name'],
|
||||||
user.email,
|
user.email,
|
||||||
'Password reset for %s' % self._config['basic']['name'],
|
MAIL_SUBJECT.format(name=self._config['basic']['name']),
|
||||||
'You (or someone else) requested to reset your password on %s.\n'
|
MAIL_BODY.format(name=self._config['basic']['name'], url=url))
|
||||||
'If you wish to proceed, click this link: %s/password-reset/%s\n'
|
|
||||||
'Otherwise, please ignore this email.' %
|
|
||||||
(self._config['basic']['name'],
|
|
||||||
self._config['basic']['base_url'].rstrip('/'),
|
|
||||||
token))
|
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def post(self, request, context, user_name):
|
def post(self, context, user_name):
|
||||||
user = self._user_service.get_by_name(context.session, user_name)
|
user = self._user_service.get_by_name(context.session, user_name)
|
||||||
if not user:
|
if not user:
|
||||||
raise NotFoundError('User %r not found.' % user_name)
|
raise errors.NotFoundError('User %r not found.' % user_name)
|
||||||
good_token = self._generate_authentication_token(user)
|
good_token = self._generate_authentication_token(user)
|
||||||
if not 'token' in context.request:
|
if not 'token' in context.request:
|
||||||
raise ValidationError('Missing password reset token.')
|
raise errors.ValidationError('Missing password reset token.')
|
||||||
token = context.request['token']
|
token = context.request['token']
|
||||||
if token != good_token:
|
if token != good_token:
|
||||||
raise ValidationError('Invalid password reset token.')
|
raise errors.ValidationError('Invalid password reset token.')
|
||||||
new_password = self._user_service.reset_password(user)
|
new_password = self._user_service.reset_password(user)
|
||||||
context.session.commit()
|
context.session.commit()
|
||||||
return {'password': new_password}
|
return {'password': new_password}
|
||||||
|
|
|
@ -1,11 +1,9 @@
|
||||||
''' Exports UserListApi and UserDetailApi. '''
|
|
||||||
|
|
||||||
import re
|
import re
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
|
from szurubooru import errors
|
||||||
|
from szurubooru import util
|
||||||
from szurubooru.api.base_api import BaseApi
|
from szurubooru.api.base_api import BaseApi
|
||||||
from szurubooru.errors import IntegrityError, ValidationError, NotFoundError, AuthError
|
from szurubooru.services import search
|
||||||
from szurubooru.services.search import UserSearchConfig, SearchExecutor
|
|
||||||
from szurubooru.util import is_valid_email
|
|
||||||
|
|
||||||
def _serialize_user(authenticated_user, user):
|
def _serialize_user(authenticated_user, user):
|
||||||
ret = {
|
ret = {
|
||||||
|
@ -22,17 +20,18 @@ def _serialize_user(authenticated_user, user):
|
||||||
|
|
||||||
class UserListApi(BaseApi):
|
class UserListApi(BaseApi):
|
||||||
''' API for lists of users. '''
|
''' API for lists of users. '''
|
||||||
|
|
||||||
def __init__(self, auth_service, user_service):
|
def __init__(self, auth_service, user_service):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._auth_service = auth_service
|
self._auth_service = auth_service
|
||||||
self._user_service = user_service
|
self._user_service = user_service
|
||||||
self._search_executor = SearchExecutor(UserSearchConfig())
|
self._search_executor = search.SearchExecutor(search.UserSearchConfig())
|
||||||
|
|
||||||
def get(self, request, context):
|
def get(self, context):
|
||||||
''' Retrieves a list of users. '''
|
''' Retrieves a list of users. '''
|
||||||
self._auth_service.verify_privilege(context.user, 'users:list')
|
self._auth_service.verify_privilege(context.user, 'users:list')
|
||||||
query = request.get_param_as_string('query')
|
query = context.get_param_as_string('query')
|
||||||
page = request.get_param_as_int('page', 1)
|
page = context.get_param_as_int('page', 1)
|
||||||
count, users = self._search_executor.execute(context.session, query, page)
|
count, users = self._search_executor.execute(context.session, query, page)
|
||||||
return {
|
return {
|
||||||
'query': query,
|
'query': query,
|
||||||
|
@ -42,7 +41,7 @@ class UserListApi(BaseApi):
|
||||||
'users': [_serialize_user(context.user, user) for user in users],
|
'users': [_serialize_user(context.user, user) for user in users],
|
||||||
}
|
}
|
||||||
|
|
||||||
def post(self, request, context):
|
def post(self, context):
|
||||||
''' Creates a new user. '''
|
''' Creates a new user. '''
|
||||||
self._auth_service.verify_privilege(context.user, 'users:create')
|
self._auth_service.verify_privilege(context.user, 'users:create')
|
||||||
|
|
||||||
|
@ -51,18 +50,19 @@ class UserListApi(BaseApi):
|
||||||
password = context.request['password']
|
password = context.request['password']
|
||||||
email = context.request['email'].strip()
|
email = context.request['email'].strip()
|
||||||
except KeyError as ex:
|
except KeyError as ex:
|
||||||
raise ValidationError('Field %r not found.' % ex.args[0])
|
raise errors.ValidationError('Field %r not found.' % ex.args[0])
|
||||||
|
|
||||||
user = self._user_service.create_user(
|
user = self._user_service.create_user(
|
||||||
context.session, name, password, email)
|
context.session, name, password, email)
|
||||||
try:
|
try:
|
||||||
context.session.commit()
|
context.session.commit()
|
||||||
except sqlalchemy.exc.IntegrityError:
|
except sqlalchemy.exc.IntegrityError:
|
||||||
raise IntegrityError('User %r already exists.' % name)
|
raise errors.IntegrityError('User %r already exists.' % name)
|
||||||
return {'user': _serialize_user(context.user, user)}
|
return {'user': _serialize_user(context.user, user)}
|
||||||
|
|
||||||
class UserDetailApi(BaseApi):
|
class UserDetailApi(BaseApi):
|
||||||
''' API for individual users. '''
|
''' API for individual users. '''
|
||||||
|
|
||||||
def __init__(self, config, auth_service, password_service, user_service):
|
def __init__(self, config, auth_service, password_service, user_service):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self._available_access_ranks = config['service']['user_ranks']
|
self._available_access_ranks = config['service']['user_ranks']
|
||||||
|
@ -72,19 +72,19 @@ class UserDetailApi(BaseApi):
|
||||||
self._auth_service = auth_service
|
self._auth_service = auth_service
|
||||||
self._user_service = user_service
|
self._user_service = user_service
|
||||||
|
|
||||||
def get(self, request, context, user_name):
|
def get(self, context, user_name):
|
||||||
''' Retrieves an user. '''
|
''' Retrieves an user. '''
|
||||||
self._auth_service.verify_privilege(context.user, 'users:view')
|
self._auth_service.verify_privilege(context.user, 'users:view')
|
||||||
user = self._user_service.get_by_name(context.session, user_name)
|
user = self._user_service.get_by_name(context.session, user_name)
|
||||||
if not user:
|
if not user:
|
||||||
raise NotFoundError('User %r not found.' % user_name)
|
raise errors.NotFoundError('User %r not found.' % user_name)
|
||||||
return {'user': _serialize_user(context.user, user)}
|
return {'user': _serialize_user(context.user, user)}
|
||||||
|
|
||||||
def put(self, request, context, user_name):
|
def put(self, context, user_name):
|
||||||
''' Updates an existing user. '''
|
''' Updates an existing user. '''
|
||||||
user = self._user_service.get_by_name(context.session, user_name)
|
user = self._user_service.get_by_name(context.session, user_name)
|
||||||
if not user:
|
if not user:
|
||||||
raise NotFoundError('User %r not found.' % user_name)
|
raise errors.NotFoundError('User %r not found.' % user_name)
|
||||||
|
|
||||||
if context.user.user_id == user.user_id:
|
if context.user.user_id == user.user_id:
|
||||||
infix = 'self'
|
infix = 'self'
|
||||||
|
@ -96,7 +96,7 @@ class UserDetailApi(BaseApi):
|
||||||
context.user, 'users:edit:%s:name' % infix)
|
context.user, 'users:edit:%s:name' % infix)
|
||||||
name = context.request['name'].strip()
|
name = context.request['name'].strip()
|
||||||
if not re.match(self._name_regex, name):
|
if not re.match(self._name_regex, name):
|
||||||
raise ValidationError(
|
raise errors.ValidationError(
|
||||||
'Name must satisfy regex %r.' % self._name_regex)
|
'Name must satisfy regex %r.' % self._name_regex)
|
||||||
user.name = name
|
user.name = name
|
||||||
|
|
||||||
|
@ -105,7 +105,7 @@ class UserDetailApi(BaseApi):
|
||||||
self._auth_service.verify_privilege(
|
self._auth_service.verify_privilege(
|
||||||
context.user, 'users:edit:%s:pass' % infix)
|
context.user, 'users:edit:%s:pass' % infix)
|
||||||
if not re.match(self._password_regex, password):
|
if not re.match(self._password_regex, password):
|
||||||
raise ValidationError(
|
raise errors.ValidationError(
|
||||||
'Password must satisfy regex %r.' % self._password_regex)
|
'Password must satisfy regex %r.' % self._password_regex)
|
||||||
user.password_salt = self._password_service.create_password()
|
user.password_salt = self._password_service.create_password()
|
||||||
user.password_hash = self._password_service.get_password_hash(
|
user.password_hash = self._password_service.get_password_hash(
|
||||||
|
@ -114,12 +114,10 @@ class UserDetailApi(BaseApi):
|
||||||
if 'email' in context.request:
|
if 'email' in context.request:
|
||||||
self._auth_service.verify_privilege(
|
self._auth_service.verify_privilege(
|
||||||
context.user, 'users:edit:%s:email' % infix)
|
context.user, 'users:edit:%s:email' % infix)
|
||||||
email = context.request['email'].strip()
|
email = context.request['email'].strip() or None
|
||||||
if not is_valid_email(email):
|
if not util.is_valid_email(email):
|
||||||
raise ValidationError('%r is not a vaild email address.' % email)
|
raise errors.ValidationError(
|
||||||
# prefer nulls to empty strings in the DB
|
'%r is not a vaild email address.' % email)
|
||||||
if not email:
|
|
||||||
email = None
|
|
||||||
user.email = email
|
user.email = email
|
||||||
|
|
||||||
if 'accessRank' in context.request:
|
if 'accessRank' in context.request:
|
||||||
|
@ -127,12 +125,12 @@ class UserDetailApi(BaseApi):
|
||||||
context.user, 'users:edit:%s:rank' % infix)
|
context.user, 'users:edit:%s:rank' % infix)
|
||||||
rank = context.request['accessRank'].strip()
|
rank = context.request['accessRank'].strip()
|
||||||
if not rank in self._available_access_ranks:
|
if not rank in self._available_access_ranks:
|
||||||
raise ValidationError(
|
raise errors.ValidationError(
|
||||||
'Bad access rank. Valid access ranks: %r' \
|
'Bad access rank. Valid access ranks: %r' \
|
||||||
% self._available_access_ranks)
|
% self._available_access_ranks)
|
||||||
if self._available_access_ranks.index(context.user.access_rank) \
|
if self._available_access_ranks.index(context.user.access_rank) \
|
||||||
< self._available_access_ranks.index(rank):
|
< self._available_access_ranks.index(rank):
|
||||||
raise AuthError(
|
raise errors.AuthError(
|
||||||
'Trying to set higher access rank than one has')
|
'Trying to set higher access rank than one has')
|
||||||
user.access_rank = rank
|
user.access_rank = rank
|
||||||
|
|
||||||
|
@ -141,6 +139,6 @@ class UserDetailApi(BaseApi):
|
||||||
try:
|
try:
|
||||||
context.session.commit()
|
context.session.commit()
|
||||||
except sqlalchemy.exc.IntegrityError:
|
except sqlalchemy.exc.IntegrityError:
|
||||||
raise IntegrityError('User %r already exists.' % name)
|
raise errors.IntegrityError('User %r already exists.' % name)
|
||||||
|
|
||||||
return {'user': _serialize_user(context.user, user)}
|
return {'user': _serialize_user(context.user, user)}
|
||||||
|
|
|
@ -1,16 +1,15 @@
|
||||||
''' Exports create_app. '''
|
''' Exports create_app. '''
|
||||||
|
|
||||||
import os
|
|
||||||
import falcon
|
import falcon
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
import sqlalchemy.orm
|
import sqlalchemy.orm
|
||||||
import szurubooru.api
|
import szurubooru.api
|
||||||
import szurubooru.config
|
import szurubooru.config
|
||||||
|
import szurubooru.errors
|
||||||
import szurubooru.middleware
|
import szurubooru.middleware
|
||||||
import szurubooru.services
|
import szurubooru.services
|
||||||
import szurubooru.services.search
|
import szurubooru.services.search
|
||||||
import szurubooru.util
|
import szurubooru.util
|
||||||
from szurubooru.errors import *
|
|
||||||
|
|
||||||
class _CustomRequest(falcon.Request):
|
class _CustomRequest(falcon.Request):
|
||||||
context_type = szurubooru.util.dotdict
|
context_type = szurubooru.util.dotdict
|
||||||
|
@ -28,28 +27,26 @@ class _CustomRequest(falcon.Request):
|
||||||
return default
|
return default
|
||||||
raise falcon.HTTPMissingParam(name)
|
raise falcon.HTTPMissingParam(name)
|
||||||
|
|
||||||
def _on_auth_error(ex, request, response, params):
|
def _on_auth_error(ex, _request, _response, _params):
|
||||||
raise falcon.HTTPForbidden(
|
raise falcon.HTTPForbidden(
|
||||||
title='Authentication error', description=str(ex))
|
title='Authentication error', description=str(ex))
|
||||||
|
|
||||||
def _on_validation_error(ex, request, response, params):
|
def _on_validation_error(ex, _request, _response, _params):
|
||||||
raise falcon.HTTPBadRequest(title='Validation error', description=str(ex))
|
raise falcon.HTTPBadRequest(title='Validation error', description=str(ex))
|
||||||
|
|
||||||
def _on_search_error(ex, request, response, params):
|
def _on_search_error(ex, _request, _response, _params):
|
||||||
raise falcon.HTTPBadRequest(title='Search error', description=str(ex))
|
raise falcon.HTTPBadRequest(title='Search error', description=str(ex))
|
||||||
|
|
||||||
def _on_integrity_error(ex, request, response, params):
|
def _on_integrity_error(ex, _request, _response, _params):
|
||||||
raise falcon.HTTPConflict(
|
raise falcon.HTTPConflict(
|
||||||
title='Integrity violation', description=ex.args[0])
|
title='Integrity violation', description=ex.args[0])
|
||||||
|
|
||||||
def _on_not_found_error(ex, request, response, params):
|
def _on_not_found_error(ex, _request, _response, _params):
|
||||||
raise falcon.HTTPNotFound(title='Not found', description=str(ex))
|
raise falcon.HTTPNotFound(title='Not found', description=str(ex))
|
||||||
|
|
||||||
def create_app():
|
def create_app():
|
||||||
''' Creates a WSGI compatible App object. '''
|
''' Creates a WSGI compatible App object. '''
|
||||||
config = szurubooru.config.Config()
|
config = szurubooru.config.Config()
|
||||||
root_dir = os.path.dirname(__file__)
|
|
||||||
static_dir = os.path.join(root_dir, os.pardir, 'static')
|
|
||||||
|
|
||||||
engine = sqlalchemy.create_engine(
|
engine = sqlalchemy.create_engine(
|
||||||
'{schema}://{user}:{password}@{host}:{port}/{name}'.format(
|
'{schema}://{user}:{password}@{host}:{port}/{name}'.format(
|
||||||
|
@ -77,9 +74,10 @@ def create_app():
|
||||||
app = falcon.API(
|
app = falcon.API(
|
||||||
request_type=_CustomRequest,
|
request_type=_CustomRequest,
|
||||||
middleware=[
|
middleware=[
|
||||||
|
szurubooru.middleware.ImbueContext(),
|
||||||
szurubooru.middleware.RequireJson(),
|
szurubooru.middleware.RequireJson(),
|
||||||
szurubooru.middleware.JsonTranslator(),
|
szurubooru.middleware.JsonTranslator(),
|
||||||
szurubooru.middleware.DbSession(session_maker),
|
szurubooru.middleware.DbSession(scoped_session),
|
||||||
szurubooru.middleware.Authenticator(auth_service, user_service),
|
szurubooru.middleware.Authenticator(auth_service, user_service),
|
||||||
])
|
])
|
||||||
|
|
||||||
|
|
|
@ -1,11 +1,6 @@
|
||||||
''' Exports Config. '''
|
|
||||||
|
|
||||||
import os
|
import os
|
||||||
import configobj
|
import configobj
|
||||||
|
import szurubooru.errors
|
||||||
class ConfigurationError(RuntimeError):
|
|
||||||
''' A problem with config.ini file. '''
|
|
||||||
pass
|
|
||||||
|
|
||||||
class Config(object):
|
class Config(object):
|
||||||
''' INI config parser and container. '''
|
''' INI config parser and container. '''
|
||||||
|
@ -26,13 +21,14 @@ class Config(object):
|
||||||
all_ranks = self['service']['user_ranks']
|
all_ranks = self['service']['user_ranks']
|
||||||
for privilege, rank in self['privileges'].items():
|
for privilege, rank in self['privileges'].items():
|
||||||
if rank not in all_ranks:
|
if rank not in all_ranks:
|
||||||
raise ConfigurationError(
|
raise szurubooru.errors.ConfigError(
|
||||||
'Rank %r for privilege %r is missing from user_ranks' % (
|
'Rank %r for privilege %r is missing from user_ranks' % (
|
||||||
rank, privilege))
|
rank, privilege))
|
||||||
for rank in ['anonymous', 'admin', 'nobody']:
|
for rank in ['anonymous', 'admin', 'nobody']:
|
||||||
if rank not in all_ranks:
|
if rank not in all_ranks:
|
||||||
raise ConfigurationError('Fixed rank %r is missing from user_ranks' % rank)
|
raise szurubooru.errors.ConfigError(
|
||||||
|
'Fixed rank %r is missing from user_ranks' % rank)
|
||||||
if self['service']['default_user_rank'] not in all_ranks:
|
if self['service']['default_user_rank'] not in all_ranks:
|
||||||
raise ConfigurationError(
|
raise szurubooru.errors.ConfigError(
|
||||||
'Default rank %r is missing from user_ranks' % (
|
'Default rank %r is missing from user_ranks' % (
|
||||||
self['service']['default_user_rank']))
|
self['service']['default_user_rank']))
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
''' Exports custom errors. '''
|
class ConfigError(RuntimeError):
|
||||||
|
''' A problem with config.ini file. '''
|
||||||
|
|
||||||
class AuthError(RuntimeError):
|
class AuthError(RuntimeError):
|
||||||
''' Generic authentication error '''
|
''' Generic authentication error '''
|
||||||
|
|
|
@ -4,3 +4,4 @@ from szurubooru.middleware.authenticator import Authenticator
|
||||||
from szurubooru.middleware.json_translator import JsonTranslator
|
from szurubooru.middleware.json_translator import JsonTranslator
|
||||||
from szurubooru.middleware.require_json import RequireJson
|
from szurubooru.middleware.require_json import RequireJson
|
||||||
from szurubooru.middleware.db_session import DbSession
|
from szurubooru.middleware.db_session import DbSession
|
||||||
|
from szurubooru.middleware.imbue_context import ImbueContext
|
||||||
|
|
|
@ -1,9 +1,7 @@
|
||||||
''' Exports Authenticator. '''
|
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import falcon
|
import falcon
|
||||||
from szurubooru.model.user import User
|
from szurubooru import errors
|
||||||
from szurubooru.errors import AuthError
|
from szurubooru import model
|
||||||
|
|
||||||
class Authenticator(object):
|
class Authenticator(object):
|
||||||
'''
|
'''
|
||||||
|
@ -15,7 +13,7 @@ class Authenticator(object):
|
||||||
self._auth_service = auth_service
|
self._auth_service = auth_service
|
||||||
self._user_service = user_service
|
self._user_service = user_service
|
||||||
|
|
||||||
def process_request(self, request, response):
|
def process_request(self, request, _response):
|
||||||
''' Executed before passing the request to the API. '''
|
''' Executed before passing the request to the API. '''
|
||||||
request.context.user = self._get_user(request)
|
request.context.user = self._get_user(request)
|
||||||
if request.get_param_as_bool('bump-login') \
|
if request.get_param_as_bool('bump-login') \
|
||||||
|
@ -51,13 +49,13 @@ class Authenticator(object):
|
||||||
''' Tries to authenticate user. Throws AuthError for invalid users. '''
|
''' Tries to authenticate user. Throws AuthError for invalid users. '''
|
||||||
user = self._user_service.get_by_name(session, username)
|
user = self._user_service.get_by_name(session, username)
|
||||||
if not user:
|
if not user:
|
||||||
raise AuthError('No such user.')
|
raise errors.AuthError('No such user.')
|
||||||
if not self._auth_service.is_valid_password(user, password):
|
if not self._auth_service.is_valid_password(user, password):
|
||||||
raise AuthError('Invalid password.')
|
raise errors.AuthError('Invalid password.')
|
||||||
return user
|
return user
|
||||||
|
|
||||||
def _create_anonymous_user(self):
|
def _create_anonymous_user(self):
|
||||||
user = User()
|
user = model.User()
|
||||||
user.name = None
|
user.name = None
|
||||||
user.access_rank = 'anonymous'
|
user.access_rank = 'anonymous'
|
||||||
user.password = None
|
user.password = None
|
||||||
|
|
|
@ -1,16 +1,14 @@
|
||||||
''' Exports DbSession. '''
|
|
||||||
|
|
||||||
class DbSession(object):
|
class DbSession(object):
|
||||||
''' Attaches database session to the context of every request. '''
|
''' Attaches database session to the context of every request. '''
|
||||||
|
|
||||||
def __init__(self, session_factory):
|
def __init__(self, session_factory):
|
||||||
self._session_factory = session_factory
|
self._session_factory = session_factory
|
||||||
|
|
||||||
def process_request(self, request, response):
|
def process_request(self, request, _response):
|
||||||
''' Executed before passing the request to the API. '''
|
''' Executed before passing the request to the API. '''
|
||||||
request.context.session = self._session_factory()
|
request.context.session = self._session_factory()
|
||||||
|
|
||||||
def process_response(self, request, response, resource):
|
def process_response(self, request, _response, _resource):
|
||||||
'''
|
'''
|
||||||
Executed before passing the response to falcon.
|
Executed before passing the response to falcon.
|
||||||
Any commits to database need to happen explicitly in the API layer.
|
Any commits to database need to happen explicitly in the API layer.
|
||||||
|
|
|
@ -0,0 +1,8 @@
|
||||||
|
class ImbueContext(object):
|
||||||
|
''' Decorates context with methods from falcon's request. '''
|
||||||
|
|
||||||
|
def process_request(self, request, _response):
|
||||||
|
request.context.get_param_as_string = request.get_param_as_string
|
||||||
|
request.context.get_param_as_bool = request.get_param_as_bool
|
||||||
|
request.context.get_param_as_int = request.get_param_as_int
|
||||||
|
request.context.get_param_as_list = request.get_param_as_list
|
|
@ -1,12 +1,10 @@
|
||||||
''' Exports JsonTranslator. '''
|
import datetime
|
||||||
|
|
||||||
import json
|
import json
|
||||||
from datetime import datetime
|
|
||||||
import falcon
|
import falcon
|
||||||
|
|
||||||
def json_serial(obj):
|
def json_serial(obj):
|
||||||
''' JSON serializer for objects not serializable by default JSON code '''
|
''' JSON serializer for objects not serializable by default JSON code '''
|
||||||
if isinstance(obj, datetime):
|
if isinstance(obj, datetime.datetime):
|
||||||
serial = obj.isoformat()
|
serial = obj.isoformat()
|
||||||
return serial
|
return serial
|
||||||
raise TypeError('Type not serializable')
|
raise TypeError('Type not serializable')
|
||||||
|
@ -17,7 +15,7 @@ class JsonTranslator(object):
|
||||||
context.
|
context.
|
||||||
'''
|
'''
|
||||||
|
|
||||||
def process_request(self, request, response):
|
def process_request(self, request, _response):
|
||||||
''' Executed before passing the request to the API. '''
|
''' Executed before passing the request to the API. '''
|
||||||
if request.content_length in (None, 0):
|
if request.content_length in (None, 0):
|
||||||
return
|
return
|
||||||
|
@ -37,7 +35,7 @@ class JsonTranslator(object):
|
||||||
'Could not decode the request body. The '
|
'Could not decode the request body. The '
|
||||||
'JSON was incorrect or not encoded as UTF-8.')
|
'JSON was incorrect or not encoded as UTF-8.')
|
||||||
|
|
||||||
def process_response(self, request, response, resource):
|
def process_response(self, request, response, _resource):
|
||||||
''' Executed before passing the response to falcon. '''
|
''' Executed before passing the response to falcon. '''
|
||||||
if 'result' not in request.context:
|
if 'result' not in request.context:
|
||||||
return
|
return
|
||||||
|
|
|
@ -1,12 +1,10 @@
|
||||||
''' Exports RequireJson. '''
|
|
||||||
|
|
||||||
import falcon
|
import falcon
|
||||||
|
|
||||||
class RequireJson(object):
|
class RequireJson(object):
|
||||||
''' Sanitizes requests so that only JSON is accepted. '''
|
''' Sanitizes requests so that only JSON is accepted. '''
|
||||||
|
|
||||||
def process_request(self, req, resp):
|
def process_request(self, request, _response):
|
||||||
''' Executed before passing the request to the API. '''
|
''' Executed before passing the request to the API. '''
|
||||||
if not req.client_accepts_json:
|
if not request.client_accepts_json:
|
||||||
raise falcon.HTTPNotAcceptable(
|
raise falcon.HTTPNotAcceptable(
|
||||||
'This API only supports responses encoded as JSON.')
|
'This API only supports responses encoded as JSON.')
|
||||||
|
|
|
@ -1,12 +1,9 @@
|
||||||
# pylint: disable=too-many-instance-attributes,too-few-public-methods
|
# pylint: disable=too-many-instance-attributes,too-few-public-methods
|
||||||
|
|
||||||
''' Exports User. '''
|
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from szurubooru.model.base import Base
|
from szurubooru.model.base import Base
|
||||||
|
|
||||||
class User(Base):
|
class User(Base):
|
||||||
''' Database representation of an user. '''
|
|
||||||
__tablename__ = 'user'
|
__tablename__ = 'user'
|
||||||
|
|
||||||
AVATAR_GRAVATAR = 1
|
AVATAR_GRAVATAR = 1
|
||||||
|
|
|
@ -1,10 +1,6 @@
|
||||||
''' Exports AuthService. '''
|
from szurubooru import errors
|
||||||
|
|
||||||
from szurubooru.errors import AuthError
|
|
||||||
|
|
||||||
class AuthService(object):
|
class AuthService(object):
|
||||||
''' Services related to user authentication '''
|
|
||||||
|
|
||||||
def __init__(self, config, password_service):
|
def __init__(self, config, password_service):
|
||||||
self._config = config
|
self._config = config
|
||||||
self._password_service = password_service
|
self._password_service = password_service
|
||||||
|
@ -29,4 +25,4 @@ class AuthService(object):
|
||||||
minimal_rank = self._config['privileges'][privilege_name]
|
minimal_rank = self._config['privileges'][privilege_name]
|
||||||
good_ranks = all_ranks[all_ranks.index(minimal_rank):]
|
good_ranks = all_ranks[all_ranks.index(minimal_rank):]
|
||||||
if user.access_rank not in good_ranks:
|
if user.access_rank not in good_ranks:
|
||||||
raise AuthError('Insufficient privileges to do this.')
|
raise errors.AuthError('Insufficient privileges to do this.')
|
||||||
|
|
|
@ -1,12 +1,12 @@
|
||||||
import smtplib
|
import smtplib
|
||||||
from email.mime.text import MIMEText
|
import email.mime.text
|
||||||
|
|
||||||
class Mailer(object):
|
class Mailer(object):
|
||||||
def __init__(self, config):
|
def __init__(self, config):
|
||||||
self._config = config
|
self._config = config
|
||||||
|
|
||||||
def send(self, sender, recipient, subject, body):
|
def send(self, sender, recipient, subject, body):
|
||||||
msg = MIMEText(body)
|
msg = email.mime.text.MIMEText(body)
|
||||||
msg['Subject'] = subject
|
msg['Subject'] = subject
|
||||||
msg['From'] = sender
|
msg['From'] = sender
|
||||||
msg['To'] = recipient
|
msg['To'] = recipient
|
||||||
|
|
|
@ -1,5 +1,3 @@
|
||||||
''' Exports PasswordService. '''
|
|
||||||
|
|
||||||
import hashlib
|
import hashlib
|
||||||
import random
|
import random
|
||||||
|
|
||||||
|
|
|
@ -1,34 +1,32 @@
|
||||||
''' Exports BaseSearchConfig. '''
|
|
||||||
|
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from szurubooru.errors import SearchError
|
import szurubooru.errors
|
||||||
from szurubooru.services.search.criteria import *
|
from szurubooru import util
|
||||||
from szurubooru.util import parse_time_range
|
from szurubooru.services.search import criteria
|
||||||
|
|
||||||
def _apply_criterion_to_column(
|
def _apply_criterion_to_column(
|
||||||
column, query, criterion, allow_composite=True, allow_ranged=True):
|
column, query, criterion, allow_composite=True, allow_ranged=True):
|
||||||
''' Decorates SQLAlchemy filter on given column using supplied criterion. '''
|
''' Decorates SQLAlchemy filter on given column using supplied criterion. '''
|
||||||
if isinstance(criterion, StringSearchCriterion):
|
if isinstance(criterion, criteria.StringSearchCriterion):
|
||||||
filter = column == criterion.value
|
expr = column == criterion.value
|
||||||
if criterion.negated:
|
if criterion.negated:
|
||||||
filter = ~filter
|
expr = ~expr
|
||||||
return query.filter(filter)
|
return query.filter(expr)
|
||||||
elif isinstance(criterion, ArraySearchCriterion):
|
elif isinstance(criterion, criteria.ArraySearchCriterion):
|
||||||
if not allow_composite:
|
if not allow_composite:
|
||||||
raise SearchError(
|
raise szurubooru.errors.SearchError(
|
||||||
'Composite token %r is invalid in this context.' % (criterion,))
|
'Composite token %r is invalid in this context.' % (criterion,))
|
||||||
filter = column.in_(criterion.values)
|
expr = column.in_(criterion.values)
|
||||||
if criterion.negated:
|
if criterion.negated:
|
||||||
filter = ~filter
|
expr = ~expr
|
||||||
return query.filter(filter)
|
return query.filter(expr)
|
||||||
elif isinstance(criterion, RangedSearchCriterion):
|
elif isinstance(criterion, criteria.RangedSearchCriterion):
|
||||||
if not allow_ranged:
|
if not allow_ranged:
|
||||||
raise SearchError(
|
raise szurubooru.errors.SearchError(
|
||||||
'Ranged token %r is invalid in this context.' % (criterion,))
|
'Ranged token %r is invalid in this context.' % (criterion,))
|
||||||
filter = column.between(criterion.min_value, criterion.max_value)
|
expr = column.between(criterion.min_value, criterion.max_value)
|
||||||
if criterion.negated:
|
if criterion.negated:
|
||||||
filter = ~filter
|
expr = ~expr
|
||||||
return query.filter(filter)
|
return query.filter(expr)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError('Invalid search type: %r.' % (criterion,))
|
raise RuntimeError('Invalid search type: %r.' % (criterion,))
|
||||||
|
|
||||||
|
@ -37,36 +35,35 @@ def _apply_date_criterion_to_column(column, query, criterion):
|
||||||
Decorates SQLAlchemy filter on given column using supplied criterion.
|
Decorates SQLAlchemy filter on given column using supplied criterion.
|
||||||
Parses the datetime inside the criterion.
|
Parses the datetime inside the criterion.
|
||||||
'''
|
'''
|
||||||
if isinstance(criterion, StringSearchCriterion):
|
if isinstance(criterion, criteria.StringSearchCriterion):
|
||||||
min_date, max_date = parse_time_range(criterion.value)
|
min_date, max_date = util.parse_time_range(criterion.value)
|
||||||
filter = column.between(min_date, max_date)
|
expr = column.between(min_date, max_date)
|
||||||
if criterion.negated:
|
if criterion.negated:
|
||||||
filter = ~filter
|
expr = ~expr
|
||||||
return query.filter(filter)
|
return query.filter(expr)
|
||||||
elif isinstance(criterion, ArraySearchCriterion):
|
elif isinstance(criterion, criteria.ArraySearchCriterion):
|
||||||
result = query
|
expr = sqlalchemy.sql.false()
|
||||||
filter = sqlalchemy.sql.false()
|
|
||||||
for value in criterion.values:
|
for value in criterion.values:
|
||||||
min_date, max_date = parse_time_range(value)
|
min_date, max_date = util.parse_time_range(value)
|
||||||
filter = filter | column.between(min_date, max_date)
|
expr = expr | column.between(min_date, max_date)
|
||||||
if criterion.negated:
|
if criterion.negated:
|
||||||
filter = ~filter
|
expr = ~expr
|
||||||
return query.filter(filter)
|
return query.filter(expr)
|
||||||
elif isinstance(criterion, RangedSearchCriterion):
|
elif isinstance(criterion, criteria.RangedSearchCriterion):
|
||||||
assert criterion.min_value or criterion.max_value
|
assert criterion.min_value or criterion.max_value
|
||||||
if criterion.min_value and criterion.max_value:
|
if criterion.min_value and criterion.max_value:
|
||||||
min_date = parse_time_range(criterion.min_value)[0]
|
min_date = util.parse_time_range(criterion.min_value)[0]
|
||||||
max_date = parse_time_range(criterion.max_value)[1]
|
max_date = util.parse_time_range(criterion.max_value)[1]
|
||||||
filter = column.between(min_date, max_date)
|
expr = column.between(min_date, max_date)
|
||||||
elif criterion.min_value:
|
elif criterion.min_value:
|
||||||
min_date = parse_time_range(criterion.min_value)[0]
|
min_date = util.parse_time_range(criterion.min_value)[0]
|
||||||
filter = column >= min_date
|
expr = column >= min_date
|
||||||
elif criterion.max_value:
|
elif criterion.max_value:
|
||||||
max_date = parse_time_range(criterion.max_value)[1]
|
max_date = util.parse_time_range(criterion.max_value)[1]
|
||||||
filter = column <= max_date
|
expr = column <= max_date
|
||||||
if criterion.negated:
|
if criterion.negated:
|
||||||
filter = ~filter
|
expr = ~expr
|
||||||
return query.filter(filter)
|
return query.filter(expr)
|
||||||
|
|
||||||
class BaseSearchConfig(object):
|
class BaseSearchConfig(object):
|
||||||
def create_query(self, session):
|
def create_query(self, session):
|
||||||
|
|
|
@ -1,4 +1,4 @@
|
||||||
class BaseSearchCriterion(object):
|
class _BaseSearchCriterion(object):
|
||||||
def __init__(self, original_text, negated):
|
def __init__(self, original_text, negated):
|
||||||
self.original_text = original_text
|
self.original_text = original_text
|
||||||
self.negated = negated
|
self.negated = negated
|
||||||
|
@ -6,18 +6,18 @@ class BaseSearchCriterion(object):
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return self.original_text
|
return self.original_text
|
||||||
|
|
||||||
class RangedSearchCriterion(BaseSearchCriterion):
|
class RangedSearchCriterion(_BaseSearchCriterion):
|
||||||
def __init__(self, original_text, negated, min_value, max_value):
|
def __init__(self, original_text, negated, min_value, max_value):
|
||||||
super().__init__(original_text, negated)
|
super().__init__(original_text, negated)
|
||||||
self.min_value = min_value
|
self.min_value = min_value
|
||||||
self.max_value = max_value
|
self.max_value = max_value
|
||||||
|
|
||||||
class StringSearchCriterion(BaseSearchCriterion):
|
class StringSearchCriterion(_BaseSearchCriterion):
|
||||||
def __init__(self, original_text, negated, value):
|
def __init__(self, original_text, negated, value):
|
||||||
super().__init__(original_text, negated)
|
super().__init__(original_text, negated)
|
||||||
self.value = value
|
self.value = value
|
||||||
|
|
||||||
class ArraySearchCriterion(BaseSearchCriterion):
|
class ArraySearchCriterion(_BaseSearchCriterion):
|
||||||
def __init__(self, original_text, negated, values):
|
def __init__(self, original_text, negated, values):
|
||||||
super().__init__(original_text, negated)
|
super().__init__(original_text, negated)
|
||||||
self.values = values
|
self.values = values
|
||||||
|
|
|
@ -2,8 +2,8 @@
|
||||||
|
|
||||||
import re
|
import re
|
||||||
import sqlalchemy
|
import sqlalchemy
|
||||||
from szurubooru.errors import SearchError
|
from szurubooru import errors
|
||||||
from szurubooru.services.search.criteria import *
|
from szurubooru.services.search import criteria
|
||||||
|
|
||||||
class SearchExecutor(object):
|
class SearchExecutor(object):
|
||||||
ORDER_DESC = 1
|
ORDER_DESC = 1
|
||||||
|
@ -62,9 +62,11 @@ class SearchExecutor(object):
|
||||||
elif order_str == 'desc':
|
elif order_str == 'desc':
|
||||||
order = self.ORDER_DESC
|
order = self.ORDER_DESC
|
||||||
else:
|
else:
|
||||||
raise SearchError('Unknown search direction: %r.' % order_str)
|
raise errors.SearchError(
|
||||||
|
'Unknown search direction: %r.' % order_str)
|
||||||
else:
|
else:
|
||||||
raise SearchError('Too many commas in order search token.')
|
raise errors.SearchError(
|
||||||
|
'Too many commas in order search token.')
|
||||||
if negated:
|
if negated:
|
||||||
if order == self.ORDER_DESC:
|
if order == self.ORDER_DESC:
|
||||||
order = self.ORDER_ASC
|
order = self.ORDER_ASC
|
||||||
|
@ -79,21 +81,22 @@ class SearchExecutor(object):
|
||||||
|
|
||||||
def _handle_anonymous(self, query, criterion):
|
def _handle_anonymous(self, query, criterion):
|
||||||
if not self._search_config.anonymous_filter:
|
if not self._search_config.anonymous_filter:
|
||||||
raise SearchError(
|
raise errors.SearchError(
|
||||||
'Anonymous tokens are not valid in this context.')
|
'Anonymous tokens are not valid in this context.')
|
||||||
return self._search_config.anonymous_filter(query, criterion)
|
return self._search_config.anonymous_filter(query, criterion)
|
||||||
|
|
||||||
def _handle_named(self, query, key, criterion):
|
def _handle_named(self, query, key, criterion):
|
||||||
if key in self._search_config.named_filters:
|
if key in self._search_config.named_filters:
|
||||||
return self._search_config.named_filters[key](query, criterion)
|
return self._search_config.named_filters[key](query, criterion)
|
||||||
raise SearchError(
|
raise errors.SearchError(
|
||||||
'Unknown named token: %r. Available named tokens: %r.' % (
|
'Unknown named token: %r. Available named tokens: %r.' % (
|
||||||
key, list(self._search_config.named_filters.keys())))
|
key, list(self._search_config.named_filters.keys())))
|
||||||
|
|
||||||
def _handle_special(self, query, value, negated):
|
def _handle_special(self, query, value, negated):
|
||||||
if value in self._search_config.special_filters:
|
if value in self._search_config.special_filters:
|
||||||
return self._search_config.special_filters[value](query, criterion)
|
return self._search_config.special_filters[value](
|
||||||
raise SearchError(
|
query, value, negated)
|
||||||
|
raise errors.SearchError(
|
||||||
'Unknown special token: %r. Available special tokens: %r.' % (
|
'Unknown special token: %r. Available special tokens: %r.' % (
|
||||||
value, list(self._search_config.special_filters.keys())))
|
value, list(self._search_config.special_filters.keys())))
|
||||||
|
|
||||||
|
@ -105,7 +108,7 @@ class SearchExecutor(object):
|
||||||
else:
|
else:
|
||||||
column = column.desc()
|
column = column.desc()
|
||||||
return query.order_by(column)
|
return query.order_by(column)
|
||||||
raise SearchError(
|
raise errors.SearchError(
|
||||||
'Unknown search order: %r. Available search orders: %r.' % (
|
'Unknown search order: %r. Available search orders: %r.' % (
|
||||||
value, list(self._search_config.order_columns.keys())))
|
value, list(self._search_config.order_columns.keys())))
|
||||||
|
|
||||||
|
@ -113,8 +116,9 @@ class SearchExecutor(object):
|
||||||
if '..' in value:
|
if '..' in value:
|
||||||
low, high = value.split('..')
|
low, high = value.split('..')
|
||||||
if not low and not high:
|
if not low and not high:
|
||||||
raise SearchError('Empty ranged value')
|
raise errors.SearchError('Empty ranged value')
|
||||||
return RangedSearchCriterion(value, negated, low, high)
|
return criteria.RangedSearchCriterion(value, negated, low, high)
|
||||||
if ',' in value:
|
if ',' in value:
|
||||||
return ArraySearchCriterion(value, negated, value.split(','))
|
return criteria.ArraySearchCriterion(
|
||||||
return StringSearchCriterion(value, negated, value)
|
value, negated, value.split(','))
|
||||||
|
return criteria.StringSearchCriterion(value, negated, value)
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
''' Exports UserSearchConfig. '''
|
''' Exports UserSearchConfig. '''
|
||||||
|
|
||||||
from sqlalchemy.sql.expression import func
|
from sqlalchemy.sql.expression import func
|
||||||
from szurubooru.errors import SearchError
|
|
||||||
from szurubooru.model import User
|
from szurubooru.model import User
|
||||||
from szurubooru.services.search.base_search_config import BaseSearchConfig
|
from szurubooru.services.search.base_search_config import BaseSearchConfig
|
||||||
|
|
||||||
|
|
|
@ -1,10 +1,8 @@
|
||||||
''' Exports UserService. '''
|
|
||||||
|
|
||||||
import re
|
import re
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from szurubooru.errors import ValidationError
|
from szurubooru import errors
|
||||||
from szurubooru.model.user import User
|
from szurubooru import model
|
||||||
from szurubooru.util import is_valid_email
|
from szurubooru import util
|
||||||
|
|
||||||
class UserService(object):
|
class UserService(object):
|
||||||
''' User management '''
|
''' User management '''
|
||||||
|
@ -19,29 +17,26 @@ class UserService(object):
|
||||||
''' Creates an user with given parameters and returns it. '''
|
''' Creates an user with given parameters and returns it. '''
|
||||||
|
|
||||||
if not re.match(self._name_regex, name):
|
if not re.match(self._name_regex, name):
|
||||||
raise ValidationError(
|
raise errors.ValidationError(
|
||||||
'Name must satisfy regex %r.' % self._name_regex)
|
'Name must satisfy regex %r.' % self._name_regex)
|
||||||
|
|
||||||
if not re.match(self._password_regex, password):
|
if not re.match(self._password_regex, password):
|
||||||
raise ValidationError(
|
raise errors.ValidationError(
|
||||||
'Password must satisfy regex %r.' % self._password_regex)
|
'Password must satisfy regex %r.' % self._password_regex)
|
||||||
|
|
||||||
if not is_valid_email(email):
|
if not util.is_valid_email(email):
|
||||||
raise ValidationError('%r is not a vaild email address.' % email)
|
raise errors.ValidationError(
|
||||||
|
'%r is not a vaild email address.' % email)
|
||||||
|
|
||||||
# prefer nulls to empty strings in the DB
|
user = model.User()
|
||||||
if not email:
|
|
||||||
email = None
|
|
||||||
|
|
||||||
user = User()
|
|
||||||
user.name = name
|
user.name = name
|
||||||
user.password_salt = self._password_service.create_password()
|
user.password_salt = self._password_service.create_password()
|
||||||
user.password_hash = self._password_service.get_password_hash(
|
user.password_hash = self._password_service.get_password_hash(
|
||||||
user.password_salt, password)
|
user.password_salt, password)
|
||||||
user.email = email
|
user.email = email or None
|
||||||
user.access_rank = self._config['service']['default_user_rank']
|
user.access_rank = self._config['service']['default_user_rank']
|
||||||
user.creation_time = datetime.now()
|
user.creation_time = datetime.now()
|
||||||
user.avatar_style = User.AVATAR_GRAVATAR
|
user.avatar_style = model.User.AVATAR_GRAVATAR
|
||||||
|
|
||||||
session.add(user)
|
session.add(user)
|
||||||
return user
|
return user
|
||||||
|
@ -58,4 +53,4 @@ class UserService(object):
|
||||||
|
|
||||||
def get_by_name(self, session, name):
|
def get_by_name(self, session, name):
|
||||||
''' Retrieves an user by its name. '''
|
''' Retrieves an user by its name. '''
|
||||||
return session.query(User).filter_by(name=name).first()
|
return session.query(model.User).filter_by(name=name).first()
|
||||||
|
|
|
@ -39,8 +39,6 @@ class TestUserDetailApi(DatabaseTestCase):
|
||||||
self.context = dotdict()
|
self.context = dotdict()
|
||||||
self.context.session = self.session
|
self.context.session = self.session
|
||||||
self.context.request = {}
|
self.context.request = {}
|
||||||
self.request = dotdict()
|
|
||||||
self.request.context = self.context
|
|
||||||
|
|
||||||
def _create_user(self, name, rank='admin'):
|
def _create_user(self, name, rank='admin'):
|
||||||
user = User()
|
user = User()
|
||||||
|
@ -58,7 +56,7 @@ class TestUserDetailApi(DatabaseTestCase):
|
||||||
admin_user = self._create_user('u1', 'admin')
|
admin_user = self._create_user('u1', 'admin')
|
||||||
self.session.add(admin_user)
|
self.session.add(admin_user)
|
||||||
self.context.user = admin_user
|
self.context.user = admin_user
|
||||||
self.api.put(self.request, self.context, 'u1')
|
self.api.put(self.context, 'u1')
|
||||||
admin_user = self.session.query(User).filter_by(name='u1').one()
|
admin_user = self.session.query(User).filter_by(name='u1').one()
|
||||||
self.assertEqual(admin_user.name, 'u1')
|
self.assertEqual(admin_user.name, 'u1')
|
||||||
self.assertEqual(admin_user.email, 'dummy')
|
self.assertEqual(admin_user.email, 'dummy')
|
||||||
|
@ -74,7 +72,7 @@ class TestUserDetailApi(DatabaseTestCase):
|
||||||
'password': 'valid',
|
'password': 'valid',
|
||||||
'accessRank': 'mod',
|
'accessRank': 'mod',
|
||||||
}
|
}
|
||||||
self.api.put(self.request, self.context, 'u1')
|
self.api.put(self.context, 'u1')
|
||||||
admin_user = self.session.query(User).filter_by(name='chewie').one()
|
admin_user = self.session.query(User).filter_by(name='chewie').one()
|
||||||
self.assertEqual(admin_user.name, 'chewie')
|
self.assertEqual(admin_user.name, 'chewie')
|
||||||
self.assertEqual(admin_user.email, 'asd@asd.asd')
|
self.assertEqual(admin_user.email, 'asd@asd.asd')
|
||||||
|
@ -87,7 +85,7 @@ class TestUserDetailApi(DatabaseTestCase):
|
||||||
self.session.add(admin_user)
|
self.session.add(admin_user)
|
||||||
self.context.user = admin_user
|
self.context.user = admin_user
|
||||||
self.context.request = {'email': ''}
|
self.context.request = {'email': ''}
|
||||||
self.api.put(self.request, self.context, 'u1')
|
self.api.put(self.context, 'u1')
|
||||||
admin_user = self.session.query(User).filter_by(name='u1').one()
|
admin_user = self.session.query(User).filter_by(name='u1').one()
|
||||||
self.assertEqual(admin_user.email, None)
|
self.assertEqual(admin_user.email, None)
|
||||||
|
|
||||||
|
@ -97,16 +95,16 @@ class TestUserDetailApi(DatabaseTestCase):
|
||||||
self.context.user = admin_user
|
self.context.user = admin_user
|
||||||
self.context.request = {'name': '.'}
|
self.context.request = {'name': '.'}
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
ValidationError, self.api.put, self.request, self.context, 'u1')
|
ValidationError, self.api.put, self.context, 'u1')
|
||||||
self.context.request = {'password': '.'}
|
self.context.request = {'password': '.'}
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
ValidationError, self.api.put, self.request, self.context, 'u1')
|
ValidationError, self.api.put, self.context, 'u1')
|
||||||
self.context.request = {'accessRank': '.'}
|
self.context.request = {'accessRank': '.'}
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
ValidationError, self.api.put, self.request, self.context, 'u1')
|
ValidationError, self.api.put, self.context, 'u1')
|
||||||
self.context.request = {'email': '.'}
|
self.context.request = {'email': '.'}
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
ValidationError, self.api.put, self.request, self.context, 'u1')
|
ValidationError, self.api.put, self.context, 'u1')
|
||||||
|
|
||||||
def test_user_trying_to_update_someone_else(self):
|
def test_user_trying_to_update_someone_else(self):
|
||||||
user1 = self._create_user('u1', 'regular_user')
|
user1 = self._create_user('u1', 'regular_user')
|
||||||
|
@ -120,7 +118,7 @@ class TestUserDetailApi(DatabaseTestCase):
|
||||||
{'password': 'whatever'}]:
|
{'password': 'whatever'}]:
|
||||||
self.context.request = request
|
self.context.request = request
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
AuthError, self.api.put, self.request, self.context, user2.name)
|
AuthError, self.api.put, self.context, user2.name)
|
||||||
|
|
||||||
def test_user_trying_to_become_someone_else(self):
|
def test_user_trying_to_become_someone_else(self):
|
||||||
user1 = self._create_user('u1', 'regular_user')
|
user1 = self._create_user('u1', 'regular_user')
|
||||||
|
@ -129,7 +127,7 @@ class TestUserDetailApi(DatabaseTestCase):
|
||||||
self.context.user = user1
|
self.context.user = user1
|
||||||
self.context.request = {'name': 'u2'}
|
self.context.request = {'name': 'u2'}
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
ValidationError, self.api.put, self.request, self.context, 'u1')
|
ValidationError, self.api.put, self.context, 'u1')
|
||||||
|
|
||||||
def test_mods_trying_to_become_admin(self):
|
def test_mods_trying_to_become_admin(self):
|
||||||
user1 = self._create_user('u1', 'mod')
|
user1 = self._create_user('u1', 'mod')
|
||||||
|
@ -138,6 +136,6 @@ class TestUserDetailApi(DatabaseTestCase):
|
||||||
self.context.user = user1
|
self.context.user = user1
|
||||||
self.context.request = {'accessRank': 'admin'}
|
self.context.request = {'accessRank': 'admin'}
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
AuthError, self.api.put, self.request, self.context, user1.name)
|
AuthError, self.api.put, self.context, user1.name)
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
AuthError, self.api.put, self.request, self.context, user2.name)
|
AuthError, self.api.put, self.context, user2.name)
|
||||||
|
|
|
@ -1,18 +1,17 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from szurubooru.errors import SearchError
|
from szurubooru import errors
|
||||||
from szurubooru.model.user import User
|
from szurubooru import model
|
||||||
from szurubooru.services.search.search_executor import SearchExecutor
|
from szurubooru.services import search
|
||||||
from szurubooru.services.search.user_search_config import UserSearchConfig
|
|
||||||
from szurubooru.tests.database_test_case import DatabaseTestCase
|
from szurubooru.tests.database_test_case import DatabaseTestCase
|
||||||
|
|
||||||
class TestUserSearchExecutor(DatabaseTestCase):
|
class TestUserSearchExecutor(DatabaseTestCase):
|
||||||
def setUp(self):
|
def setUp(self):
|
||||||
super().setUp()
|
super().setUp()
|
||||||
self.search_config = UserSearchConfig()
|
self.search_config = search.UserSearchConfig()
|
||||||
self.executor = SearchExecutor(self.search_config)
|
self.executor = search.SearchExecutor(self.search_config)
|
||||||
|
|
||||||
def _create_user(self, name):
|
def _create_user(self, name):
|
||||||
user = User()
|
user = model.User()
|
||||||
user.name = name
|
user.name = name
|
||||||
user.password = 'dummy'
|
user.password = 'dummy'
|
||||||
user.password_salt = 'dummy'
|
user.password_salt = 'dummy'
|
||||||
|
@ -20,7 +19,7 @@ class TestUserSearchExecutor(DatabaseTestCase):
|
||||||
user.email = 'dummy'
|
user.email = 'dummy'
|
||||||
user.access_rank = 'dummy'
|
user.access_rank = 'dummy'
|
||||||
user.creation_time = datetime.now()
|
user.creation_time = datetime.now()
|
||||||
user.avatar_style = User.AVATAR_GRAVATAR
|
user.avatar_style = model.User.AVATAR_GRAVATAR
|
||||||
return user
|
return user
|
||||||
|
|
||||||
def _test(self, query, page, expected_count, expected_user_names):
|
def _test(self, query, page, expected_count, expected_user_names):
|
||||||
|
@ -60,7 +59,7 @@ class TestUserSearchExecutor(DatabaseTestCase):
|
||||||
self._test('%s:2014-06..' % alias, 1, 2, ['u2', 'u3'])
|
self._test('%s:2014-06..' % alias, 1, 2, ['u2', 'u3'])
|
||||||
self._test('%s:..2014-06' % alias, 1, 2, ['u1', 'u2'])
|
self._test('%s:..2014-06' % alias, 1, 2, ['u1', 'u2'])
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
SearchError, self.executor.execute, self.session, '%s:..', 1)
|
errors.SearchError, self.executor.execute, self.session, '%s:..', 1)
|
||||||
|
|
||||||
def test_filter_by_negated_ranged_creation_time(self):
|
def test_filter_by_negated_ranged_creation_time(self):
|
||||||
user1 = self._create_user('u1')
|
user1 = self._create_user('u1')
|
||||||
|
@ -116,7 +115,7 @@ class TestUserSearchExecutor(DatabaseTestCase):
|
||||||
|
|
||||||
def test_filter_by_ranged_name(self):
|
def test_filter_by_ranged_name(self):
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
SearchError, self.executor.execute, self.session, 'name:u1..u2', 1)
|
errors.SearchError, self.executor.execute, self.session, 'name:u1..u2', 1)
|
||||||
|
|
||||||
def test_paging(self):
|
def test_paging(self):
|
||||||
self.executor.page_size = 1
|
self.executor.page_size = 1
|
||||||
|
@ -161,4 +160,4 @@ class TestUserSearchExecutor(DatabaseTestCase):
|
||||||
|
|
||||||
def test_special(self):
|
def test_special(self):
|
||||||
self.assertRaises(
|
self.assertRaises(
|
||||||
SearchError, self.executor.execute, self.session, 'special:-', 1)
|
errors.SearchError, self.executor.execute, self.session, 'special:-', 1)
|
||||||
|
|
|
@ -5,6 +5,7 @@ from szurubooru.util import parse_time_range
|
||||||
from szurubooru.errors import ValidationError
|
from szurubooru.errors import ValidationError
|
||||||
|
|
||||||
class FakeDatetime(datetime):
|
class FakeDatetime(datetime):
|
||||||
|
@staticmethod
|
||||||
def now(tz=None):
|
def now(tz=None):
|
||||||
return datetime(1997, 1, 2, 3, 4, 5, tzinfo=tz)
|
return datetime(1997, 1, 2, 3, 4, 5, tzinfo=tz)
|
||||||
|
|
||||||
|
@ -15,28 +16,28 @@ class TestParseTime(unittest.TestCase):
|
||||||
def test_today(self):
|
def test_today(self):
|
||||||
szurubooru.util.datetime.datetime = FakeDatetime
|
szurubooru.util.datetime.datetime = FakeDatetime
|
||||||
date_min, date_max = parse_time_range('today')
|
date_min, date_max = parse_time_range('today')
|
||||||
self.assertEquals(date_min, datetime(1997, 1, 2, 0, 0, 0))
|
self.assertEqual(date_min, datetime(1997, 1, 2, 0, 0, 0))
|
||||||
self.assertEquals(date_max, datetime(1997, 1, 2, 23, 59, 59))
|
self.assertEqual(date_max, datetime(1997, 1, 2, 23, 59, 59))
|
||||||
|
|
||||||
def test_yesterday(self):
|
def test_yesterday(self):
|
||||||
szurubooru.util.datetime.datetime = FakeDatetime
|
szurubooru.util.datetime.datetime = FakeDatetime
|
||||||
date_min, date_max = parse_time_range('yesterday')
|
date_min, date_max = parse_time_range('yesterday')
|
||||||
self.assertEquals(date_min, datetime(1997, 1, 1, 0, 0, 0))
|
self.assertEqual(date_min, datetime(1997, 1, 1, 0, 0, 0))
|
||||||
self.assertEquals(date_max, datetime(1997, 1, 1, 23, 59, 59))
|
self.assertEqual(date_max, datetime(1997, 1, 1, 23, 59, 59))
|
||||||
|
|
||||||
def test_year(self):
|
def test_year(self):
|
||||||
date_min, date_max = parse_time_range('1999')
|
date_min, date_max = parse_time_range('1999')
|
||||||
self.assertEquals(date_min, datetime(1999, 1, 1, 0, 0, 0))
|
self.assertEqual(date_min, datetime(1999, 1, 1, 0, 0, 0))
|
||||||
self.assertEquals(date_max, datetime(1999, 12, 31, 23, 59, 59))
|
self.assertEqual(date_max, datetime(1999, 12, 31, 23, 59, 59))
|
||||||
|
|
||||||
def test_month(self):
|
def test_month(self):
|
||||||
for text in ['1999-2', '1999-02']:
|
for text in ['1999-2', '1999-02']:
|
||||||
date_min, date_max = parse_time_range(text)
|
date_min, date_max = parse_time_range(text)
|
||||||
self.assertEquals(date_min, datetime(1999, 2, 1, 0, 0, 0))
|
self.assertEqual(date_min, datetime(1999, 2, 1, 0, 0, 0))
|
||||||
self.assertEquals(date_max, datetime(1999, 2, 28, 23, 59, 59))
|
self.assertEqual(date_max, datetime(1999, 2, 28, 23, 59, 59))
|
||||||
|
|
||||||
def test_day(self):
|
def test_day(self):
|
||||||
for text in ['1999-2-6', '1999-02-6', '1999-2-06', '1999-02-06']:
|
for text in ['1999-2-6', '1999-02-6', '1999-2-06', '1999-02-06']:
|
||||||
date_min, date_max = parse_time_range(text)
|
date_min, date_max = parse_time_range(text)
|
||||||
self.assertEquals(date_min, datetime(1999, 2, 6, 0, 0, 0))
|
self.assertEqual(date_min, datetime(1999, 2, 6, 0, 0, 0))
|
||||||
self.assertEquals(date_max, datetime(1999, 2, 6, 23, 59, 59))
|
self.assertEqual(date_max, datetime(1999, 2, 6, 23, 59, 59))
|
||||||
|
|
|
@ -1,12 +1,10 @@
|
||||||
''' Exports miscellaneous functions and data structures. '''
|
|
||||||
|
|
||||||
import datetime
|
import datetime
|
||||||
import re
|
import re
|
||||||
from szurubooru.errors import ValidationError
|
from szurubooru.errors import ValidationError
|
||||||
|
|
||||||
def is_valid_email(email):
|
def is_valid_email(email):
|
||||||
''' Validates given email address. '''
|
''' Validates given email address. '''
|
||||||
return not email or re.match('^[^@]*@[^@]*\.[^@]*$', email)
|
return not email or re.match(r'^[^@]*@[^@]*\.[^@]*$', email)
|
||||||
|
|
||||||
class dotdict(dict): # pylint: disable=invalid-name
|
class dotdict(dict): # pylint: disable=invalid-name
|
||||||
''' dot.notation access to dictionary attributes. '''
|
''' dot.notation access to dictionary attributes. '''
|
||||||
|
@ -28,24 +26,24 @@ def parse_time_range(value, timezone=datetime.timezone(datetime.timedelta())):
|
||||||
now = datetime.datetime.now(tz=timezone)
|
now = datetime.datetime.now(tz=timezone)
|
||||||
return (
|
return (
|
||||||
datetime.datetime(now.year, now.month, now.day, 0, 0, 0),
|
datetime.datetime(now.year, now.month, now.day, 0, 0, 0),
|
||||||
datetime.datetime(now.year, now.month, now.day, 0, 0, 0)
|
datetime.datetime(now.year, now.month, now.day, 0, 0, 0) \
|
||||||
+ one_day - one_second)
|
+ one_day - one_second)
|
||||||
|
|
||||||
if value == 'yesterday':
|
if value == 'yesterday':
|
||||||
now = datetime.datetime.now(tz=timezone)
|
now = datetime.datetime.now(tz=timezone)
|
||||||
return (
|
return (
|
||||||
datetime.datetime(now.year, now.month, now.day, 0, 0, 0) - one_day,
|
datetime.datetime(now.year, now.month, now.day, 0, 0, 0) - one_day,
|
||||||
datetime.datetime(now.year, now.month, now.day, 0, 0, 0)
|
datetime.datetime(now.year, now.month, now.day, 0, 0, 0) \
|
||||||
- one_second)
|
- one_second)
|
||||||
|
|
||||||
match = re.match('^(\d{4})$', value)
|
match = re.match(r'^(\d{4})$', value)
|
||||||
if match:
|
if match:
|
||||||
year = int(match.group(1))
|
year = int(match.group(1))
|
||||||
return (
|
return (
|
||||||
datetime.datetime(year, 1, 1),
|
datetime.datetime(year, 1, 1),
|
||||||
datetime.datetime(year + 1, 1, 1) - one_second)
|
datetime.datetime(year + 1, 1, 1) - one_second)
|
||||||
|
|
||||||
match = re.match('^(\d{4})-(\d{1,2})$', value)
|
match = re.match(r'^(\d{4})-(\d{1,2})$', value)
|
||||||
if match:
|
if match:
|
||||||
year = int(match.group(1))
|
year = int(match.group(1))
|
||||||
month = int(match.group(2))
|
month = int(match.group(2))
|
||||||
|
@ -53,7 +51,7 @@ def parse_time_range(value, timezone=datetime.timezone(datetime.timedelta())):
|
||||||
datetime.datetime(year, month, 1),
|
datetime.datetime(year, month, 1),
|
||||||
datetime.datetime(year, month + 1, 1) - one_second)
|
datetime.datetime(year, month + 1, 1) - one_second)
|
||||||
|
|
||||||
match = re.match('^(\d{4})-(\d{1,2})-(\d{1,2})$', value)
|
match = re.match(r'^(\d{4})-(\d{1,2})-(\d{1,2})$', value)
|
||||||
if match:
|
if match:
|
||||||
year = int(match.group(1))
|
year = int(match.group(1))
|
||||||
month = int(match.group(2))
|
month = int(match.group(2))
|
||||||
|
|
Loading…
Reference in New Issue