server/net: prevent youtube-dl errors when downloading image links

This commit is contained in:
Shyam Sunder 2021-01-07 08:28:22 -05:00
parent c732e62844
commit 2b9a4ab786
2 changed files with 20 additions and 4 deletions

View File

@@ -7,6 +7,7 @@ from threading import Thread
from typing import Any, Dict, List from typing import Any, Dict, List
from szurubooru import config, errors from szurubooru import config, errors
from szurubooru.func import mime
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
_dl_chunk_size = 2 ** 15 _dl_chunk_size = 2 ** 15
@@ -22,8 +23,12 @@ class DownloadTooLargeError(DownloadError):
def download(url: str, use_video_downloader: bool = False) -> bytes: def download(url: str, use_video_downloader: bool = False) -> bytes:
assert url assert url
youtube_dl_error = None
if use_video_downloader: if use_video_downloader:
url = _get_youtube_dl_content_url(url) try:
url = _get_youtube_dl_content_url(url) or url
except errors.ThirdPartyError as ex:
youtube_dl_error = ex
request = urllib.request.Request(url) request = urllib.request.Request(url)
if config.config["user_agent"]: if config.config["user_agent"]:
@@ -41,11 +46,18 @@ def download(url: str, use_video_downloader: bool = False) -> bytes:
content_buffer += chunk content_buffer += chunk
except urllib.error.HTTPError as ex: except urllib.error.HTTPError as ex:
raise DownloadError(url) from ex raise DownloadError(url) from ex
if (
youtube_dl_error
and mime.get_mime_type(content_buffer) == "application/octet-stream"
):
raise youtube_dl_error
return content_buffer return content_buffer
def _get_youtube_dl_content_url(url: str) -> str: def _get_youtube_dl_content_url(url: str) -> str:
cmd = ["youtube-dl", "--format", "best"] cmd = ["youtube-dl", "--format", "best", "--no-playlist"]
if config.config["user_agent"]: if config.config["user_agent"]:
cmd.extend(["--user-agent", config.config["user_agent"]]) cmd.extend(["--user-agent", config.config["user_agent"]])
cmd.extend(["--get-url", url]) cmd.extend(["--get-url", url])
@@ -85,6 +97,6 @@ def _post_to_webhook(webhook: str, payload: Dict[str, Any]) -> int:
f"Webhook {webhook} returned {res.status} {res.reason}" f"Webhook {webhook} returned {res.status} {res.reason}"
) )
return res.status return res.status
except urllib.error.URLError as e: except urllib.error.URLError as ex:
logger.warning(f"Unable to call webhook {webhook}: {str(e)}") logger.warning(f"Unable to call webhook {webhook}: {ex}")
return 400 return 400

View File

@@ -85,6 +85,10 @@ def test_too_large_download(url):
"https://upload.wikimedia.org/wikipedia/commons/a/ad/Utah_teapot.png", # noqa: E501 "https://upload.wikimedia.org/wikipedia/commons/a/ad/Utah_teapot.png", # noqa: E501
"cfadcbdeda1204dc1363ee5c1969191f26be2e41", "cfadcbdeda1204dc1363ee5c1969191f26be2e41",
), ),
(
"https://i.imgur.com/GPgh0AN.jpg",
"26861a4663fedae48e5beed3eec5156ded20640f",
),
], ],
) )
def test_content_download(url, expected_sha1): def test_content_download(url, expected_sha1):