server/password-reset: try to construct full URL
This commit is contained in:
		
							parent
							
								
									d85e746a65
								
							
						
					
					
						commit
						c9cb9aa539
					
				@ -13,7 +13,7 @@ MAIL_BODY = (
 | 
			
		||||
 | 
			
		||||
@rest.routes.get('/password-reset/(?P<user_name>[^/]+)/?')
 | 
			
		||||
def start_password_reset(
 | 
			
		||||
        _ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
 | 
			
		||||
        ctx: rest.Context, params: Dict[str, str]) -> rest.Response:
 | 
			
		||||
    user_name = params['user_name']
 | 
			
		||||
    user = users.get_user_by_name_or_email(user_name)
 | 
			
		||||
    if not user.email:
 | 
			
		||||
@ -21,12 +21,19 @@ def start_password_reset(
 | 
			
		||||
            'User %r hasn\'t supplied email. Cannot reset password.' % (
 | 
			
		||||
                user_name))
 | 
			
		||||
    token = auth.generate_authentication_token(user)
 | 
			
		||||
    url = '/password-reset/%s:%s' % (user.name, token)
 | 
			
		||||
 | 
			
		||||
    if 'HTTP_ORIGIN' in ctx.env:
 | 
			
		||||
        url = ctx.env['HTTP_ORIGIN'].rstrip('/')
 | 
			
		||||
    else:
 | 
			
		||||
        url = ''
 | 
			
		||||
    url += '/password-reset/%s:%s' % (user.name, token)
 | 
			
		||||
 | 
			
		||||
    mailer.send_mail(
 | 
			
		||||
        'noreply@%s' % config.config['name'],
 | 
			
		||||
        user.email,
 | 
			
		||||
        MAIL_SUBJECT.format(name=config.config['name']),
 | 
			
		||||
        MAIL_BODY.format(name=config.config['name'], url=url))
 | 
			
		||||
 | 
			
		||||
    return {}
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
@ -63,7 +63,7 @@ def _create_context(env: Dict[str, Any]) -> context.Context:
 | 
			
		||||
                'Could not decode the request body. The JSON '
 | 
			
		||||
                'was incorrect or was not encoded as UTF-8.')
 | 
			
		||||
 | 
			
		||||
    return context.Context(method, path, headers, params, files)
 | 
			
		||||
    return context.Context(env, method, path, headers, params, files)
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def application(
 | 
			
		||||
 | 
			
		||||
@ -11,11 +11,13 @@ Response = Optional[Dict[str, Any]]
 | 
			
		||||
class Context:
 | 
			
		||||
    def __init__(
 | 
			
		||||
            self,
 | 
			
		||||
            env: Dict[str, Any],
 | 
			
		||||
            method: str,
 | 
			
		||||
            url: str,
 | 
			
		||||
            headers: Dict[str, str] = None,
 | 
			
		||||
            params: Request = None,
 | 
			
		||||
            files: Dict[str, bytes] = None) -> None:
 | 
			
		||||
        self.env = env
 | 
			
		||||
        self.method = method
 | 
			
		||||
        self.url = url
 | 
			
		||||
        self._headers = headers or {}
 | 
			
		||||
 | 
			
		||||
@ -95,6 +95,7 @@ def session(query_logger):  # pylint: disable=unused-argument
 | 
			
		||||
def context_factory(session):
 | 
			
		||||
    def factory(params=None, files=None, user=None, headers=None):
 | 
			
		||||
        ctx = rest.Context(
 | 
			
		||||
            env={'HTTP_ORIGIN': 'http://example.com'},
 | 
			
		||||
            method=None,
 | 
			
		||||
            url=None,
 | 
			
		||||
            headers=headers or {},
 | 
			
		||||
 | 
			
		||||
@ -6,13 +6,14 @@ from szurubooru.func import net
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_has_param():
 | 
			
		||||
    ctx = rest.Context(method=None, url=None, params={'key': 'value'})
 | 
			
		||||
    ctx = rest.Context(env={}, method=None, url=None, params={'key': 'value'})
 | 
			
		||||
    assert ctx.has_param('key')
 | 
			
		||||
    assert not ctx.has_param('non-existing')
 | 
			
		||||
 | 
			
		||||
 | 
			
		||||
def test_get_file():
 | 
			
		||||
    ctx = rest.Context(method=None, url=None, files={'key': b'content'})
 | 
			
		||||
    ctx = rest.Context(
 | 
			
		||||
        env={}, method=None, url=None, files={'key': b'content'})
 | 
			
		||||
    assert ctx.get_file('key') == b'content'
 | 
			
		||||
    with pytest.raises(errors.ValidationError):
 | 
			
		||||
        ctx.get_file('non-existing')
 | 
			
		||||
@ -22,7 +23,7 @@ def test_get_file_from_url():
 | 
			
		||||
    with unittest.mock.patch('szurubooru.func.net.download'):
 | 
			
		||||
        net.download.return_value = b'content'
 | 
			
		||||
        ctx = rest.Context(
 | 
			
		||||
            method=None, url=None, params={'keyUrl': 'example.com'})
 | 
			
		||||
            env={}, method=None, url=None, params={'keyUrl': 'example.com'})
 | 
			
		||||
        assert ctx.get_file('key') == b'content'
 | 
			
		||||
        net.download.assert_called_once_with('example.com')
 | 
			
		||||
        with pytest.raises(errors.ValidationError):
 | 
			
		||||
@ -31,6 +32,7 @@ def test_get_file_from_url():
 | 
			
		||||
 | 
			
		||||
def test_getting_list_parameter():
 | 
			
		||||
    ctx = rest.Context(
 | 
			
		||||
        env={},
 | 
			
		||||
        method=None,
 | 
			
		||||
        url=None,
 | 
			
		||||
        params={'key': 'value', 'list': ['1', '2', '3']})
 | 
			
		||||
@ -43,6 +45,7 @@ def test_getting_list_parameter():
 | 
			
		||||
 | 
			
		||||
def test_getting_string_parameter():
 | 
			
		||||
    ctx = rest.Context(
 | 
			
		||||
        env={},
 | 
			
		||||
        method=None,
 | 
			
		||||
        url=None,
 | 
			
		||||
        params={'key': 'value', 'list': ['1', '2', '3']})
 | 
			
		||||
@ -55,6 +58,7 @@ def test_getting_string_parameter():
 | 
			
		||||
 | 
			
		||||
def test_getting_int_parameter():
 | 
			
		||||
    ctx = rest.Context(
 | 
			
		||||
        env={},
 | 
			
		||||
        method=None,
 | 
			
		||||
        url=None,
 | 
			
		||||
        params={'key': '50', 'err': 'invalid', 'list': [1, 2, 3]})
 | 
			
		||||
@ -76,7 +80,8 @@ def test_getting_int_parameter():
 | 
			
		||||
 | 
			
		||||
def test_getting_bool_parameter():
 | 
			
		||||
    def test(value):
 | 
			
		||||
        ctx = rest.Context(method=None, url=None, params={'key': value})
 | 
			
		||||
        ctx = rest.Context(
 | 
			
		||||
            env={}, method=None, url=None, params={'key': value})
 | 
			
		||||
        return ctx.get_param_as_bool('key')
 | 
			
		||||
 | 
			
		||||
    assert test('1') is True
 | 
			
		||||
@ -104,7 +109,7 @@ def test_getting_bool_parameter():
 | 
			
		||||
    with pytest.raises(errors.ValidationError):
 | 
			
		||||
        test(['1', '2'])
 | 
			
		||||
 | 
			
		||||
    ctx = rest.Context(method=None, url=None)
 | 
			
		||||
    ctx = rest.Context(env={}, method=None, url=None)
 | 
			
		||||
    with pytest.raises(errors.ValidationError):
 | 
			
		||||
        ctx.get_param_as_bool('non-existing')
 | 
			
		||||
    assert ctx.get_param_as_bool('non-existing', default=True) is True
 | 
			
		||||
 | 
			
		||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user