diff --git a/README.md b/README.md index 35d4c2d..2d37957 100644 --- a/README.md +++ b/README.md @@ -96,13 +96,18 @@ Here is an example of how to configure Flask-Compress with caching using Flask-C The example demonstrates how to create a simple cache instance with a 1-hour timeout, and use it to cache compressed responses for incoming requests. ```python +from flask import Flask +from flask_compress import Compress +from flask_cache import Cache + # Initializing flask app app = Flask(__name__) -cache = Cache(app, config={ - 'CACHE_TYPE': 'simple', +cache = Cache(config={ + 'CACHE_TYPE': 'SimpleCache', 'CACHE_DEFAULT_TIMEOUT': 60*60 # 1 hour cache timeout }) +cache.init_app(app) # Define a function to return cache key for incoming requests def get_cache_key(request): diff --git a/flask_compress/flask_compress.py b/flask_compress/flask_compress.py index 1d2210b..2db1971 100644 --- a/flask_compress/flask_compress.py +++ b/flask_compress/flask_compress.py @@ -179,7 +179,7 @@ def after_request(self, response): response.direct_passthrough = False if self.cache is not None: - key = self.cache_key(request) + key = f"{chosen_algorithm};{self.cache_key(request)}" compressed_content = self.cache.get(key) if compressed_content is None: compressed_content = self.compress(app, response, chosen_algorithm) diff --git a/tests/test_flask_compress.py b/tests/test_flask_compress.py index fda15f7..6874eb0 100644 --- a/tests/test_flask_compress.py +++ b/tests/test_flask_compress.py @@ -1,7 +1,10 @@ +import gzip import os +import tempfile import unittest from flask import Flask, render_template +from flask_caching import Cache from flask_compress import Compress @@ -438,5 +441,61 @@ def test_disabled_stream(self): self.assertGreater(self.file_size, len(response.data)) +class CachingCompressionTests(unittest.TestCase): + def setUp(self): + self.view_calls = 0 + self.tmpdir = tempfile.TemporaryDirectory() + + self.app = Flask(__name__) + self.app.testing = True + cache = Cache( + config={ + "CACHE_TYPE": "FileSystemCache", + "CACHE_DIR": self.tmpdir.name, + "CACHE_DEFAULT_TIMEOUT": 60 * 60, # 1 hour cache timeout + } + ) + cache.init_app(self.app) + + def get_cache_key(request): + return request.url + + compress = Compress() + compress.init_app(self.app) + + compress.cache = cache + compress.cache_key = get_cache_key + + @self.app.route("/route/") + def view(): + self.view_calls += 1 + return render_template("large.html") + + def tearDown(self): + self.tmpdir.cleanup() + + def test_compression(self): + # Here we are testing cache pollution where the same query is cached + # but with different compression algorithms. The cache key should include + # the compression algorithm so that the cache is not polluted. + client = self.app.test_client() + + headers = [("Accept-Encoding", "deflate")] + response = client.get("/route/", headers=headers) + self.assertEqual(response.status_code, 200) + self.assertIn("Content-Encoding", response.headers) + self.assertEqual(response.headers.get("Content-Encoding"), "deflate") + self.assertEqual(self.view_calls, 1) + + headers = [("Accept-Encoding", "gzip")] + response = client.get("/route/", headers=headers) + self.assertEqual(response.status_code, 200) + self.assertIn("Content-Encoding", response.headers) + self.assertEqual(response.headers.get("Content-Encoding"), "gzip") + self.assertEqual(self.view_calls, 2) + # If cache is polluted, this decompression fails as we get brotli + _ = gzip.decompress(response.data) + + if __name__ == "__main__": unittest.main() diff --git a/tox.ini b/tox.ini index 2235b3c..e9402e8 100644 --- a/tox.ini +++ b/tox.ini @@ -25,6 +25,7 @@ wheel_build_env = .pkg deps = coverage[toml] pytest + flask-caching commands = coverage run -m pytest {posargs}