Skip to content

Commit

Permalink
Fix file too long on tmp file create (#2132)
Browse files Browse the repository at this point in the history
* Splits the extension on a ? query parameter
* Ensure that the filename in URLPath is truncated
  • Loading branch information
8W9aG authored Feb 3, 2025
1 parent 552a0cc commit 0e36b61
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 1 deletion.
4 changes: 4 additions & 0 deletions python/cog/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,9 @@ class URLPath(pathlib.PosixPath): # pylint: disable=abstract-method
_path: Optional[Path]

def __init__(self, *, source: str, filename: str, fileobj: io.IOBase) -> None: # pylint: disable=super-init-not-called
if len(filename) > FILENAME_MAX_LENGTH:
filename = _truncate_filename_bytes(filename, FILENAME_MAX_LENGTH)

self.source = source
self.filename = filename
self.fileobj = fileobj
Expand Down Expand Up @@ -540,5 +543,6 @@ def _truncate_filename_bytes(s: str, length: int, encoding: str = "utf-8") -> st
and avoiding text encoding corruption from truncation.
"""
root, ext = os.path.splitext(s.encode(encoding))
ext = ext.decode(encoding).split("?")[0].encode(encoding)
root = root[: length - len(ext) - 1]
return root.decode(encoding, "ignore") + "~" + ext.decode(encoding)
23 changes: 22 additions & 1 deletion python/tests/test_types.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import io
import pickle
import random
import string

import pytest
import responses

from cog.types import Secret, URLFile, get_filename
from cog.types import Secret, URLFile, URLPath, get_filename


def test_urlfile_protocol_validation():
Expand Down Expand Up @@ -146,3 +148,22 @@ def test_secret_type():

assert secret.get_secret_value() == secret_value
assert str(secret) == "**********"


def test_truncate_filename_if_long():
# Test that a file too long exception is not raised.
random_str = "".join(
random.SystemRandom().choice(string.ascii_uppercase + string.digits)
for _ in range(350)
)
big_query = "query=" + random_str
filename = "waffwyyg~.zip"
url = "https://www.amazon.com/" + filename + "?" + big_query
fileobj = io.BytesIO()
url_path = URLPath(
source=url,
filename=filename + "?" + big_query,
fileobj=fileobj,
)
assert url_path.filename == "waffwyyg~~.zip"
_ = url_path.convert()

0 comments on commit 0e36b61

Please sign in to comment.