Skip to content

Commit

Permalink
Merge pull request #9 from scastlara/improve-locks
Browse files Browse the repository at this point in the history
feat: Now works with multiprocessing
  • Loading branch information
scastlara authored Oct 23, 2022
2 parents a40dce0 + 7c52c77 commit 9a5fdf3
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 26 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,4 @@ pytest --tags integration --tags MY_COMPONENT_NAME --tags-operand AND


## Extra
- It is thread-safe, so it can be used with [pytest-parallel](https://github.com/browsertron/pytest-parallel) `--tests-per-worker` option.
- It is thread-safe, so it can be used with [pytest-parallel](https://github.com/browsertron/pytest-parallel).
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "pytest_tagging"
version = "1.1.1"
version = "1.2.0"
description = "a pytest plugin to tag tests"
authors = ["Sergio Castillo <s.cast.lara@gmail.com>"]
readme = "README.md"
Expand Down
23 changes: 16 additions & 7 deletions pytest_tagging/plugin.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,28 @@
import threading
from collections import Counter
from enum import Enum

import pytest

from .utils import TagCounter, get_tags_from_item
from .utils import TagCounterThreadSafe, get_tags_from_item


class OperandChoices(Enum):
OR = "OR"
AND = "AND"


def pytest_configure(config):
def select_counter_class(config) -> type[Counter] | type[TagCounterThreadSafe]:
must_be_threadsafe = getattr(config.option, "workers", None) or getattr(
config.option, "tests_per_worker", None
)
return TagCounterThreadSafe if must_be_threadsafe else Counter


def pytest_configure(config) -> None:
config.addinivalue_line("markers", "tags('tag1', 'tag2'): add tags to a given test")
if not config.option.collectonly:
config.pluginmanager.register(TaggerRunner(), "taggerrunner")
counter_class = select_counter_class(config)
config.pluginmanager.register(TaggerRunner(counter_class), "taggerrunner")


def pytest_addoption(parser, pluginmanager) -> None:
Expand All @@ -39,9 +47,10 @@ def pytest_addoption(parser, pluginmanager) -> None:


class TaggerRunner:
def __init__(self):
self.lock = threading.Lock()
self.counter = TagCounter(self.lock)
def __init__(
self, counter_class: type[Counter] | type[TagCounterThreadSafe]
) -> None:
self.counter = counter_class()

def pytest_report_header(self, config) -> list[str]:
"""Add tagging config to pytest header."""
Expand Down
7 changes: 3 additions & 4 deletions pytest_tagging/utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
import threading
from multiprocessing import Manager
from typing import Any, Iterable


class TagCounter:
class TagCounterThreadSafe:
"""Counter that uses pytest caching module to store the counts"""

def __init__(self, lock: threading.Lock) -> None:
self.lock = lock
def __init__(self) -> None:
self._manager = Manager()
self.counter = self._manager.dict()
self.lock = self._manager.Lock()

def update(self, tags: Iterable[str]) -> None:
with self.lock:
Expand Down
46 changes: 39 additions & 7 deletions tests/test_plugin.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,27 @@
from collections import Counter
from unittest.mock import Mock

import pytest

from pytest_tagging.plugin import select_counter_class
from pytest_tagging.utils import TagCounterThreadSafe


@pytest.mark.parametrize(
"option, expected_counter",
(
("foo", Counter),
("workers", TagCounterThreadSafe),
("tests_per_worker", TagCounterThreadSafe),
),
)
def test_select_counter_class(option, expected_counter):
m_option = Mock(spec=[option])
setattr(m_option, option, 1)
m_config = Mock(option=m_option)

assert select_counter_class(m_config) is expected_counter


def test_collect_tag(testdir):
testdir.makepyfile(
Expand Down Expand Up @@ -29,7 +51,7 @@ def test_untagged():
"""
)
result = testdir.runpytest("--tags=foo")
result.assert_outcomes(passed=1)
result.assert_outcomes(passed=1, failed=0)


def test_collect_tags_or(testdir):
Expand Down Expand Up @@ -69,7 +91,7 @@ def test_tagged_3():
"""
)
result = testdir.runpytest("--tags=foo", "--tags=bar", "--tags-operand=AND")
result.assert_outcomes(passed=1)
result.assert_outcomes(passed=1, failed=0)


def test_summary_contains_counts(testdir):
Expand All @@ -87,29 +109,39 @@ def test_tagged_1():
result.stdout.re_match_lines("foo - 1")


def test_taggerrunner_with_parallel_with_threads(testdir):
"""This fails if the taggerrunner is not threadsafe"""
def test_taggerrunner_with_parallel_with_processes_and_threads(testdir):
"""
This test ensures counts are collected correctly when tests run in different processes and threads.
Cannot use `pytest.mark.parametrize` because `testdir` fixture ends up raising a weird
AssertionError on teardown.
"""
testdir.makepyfile(
"""
import pytest
from time import sleep
@pytest.mark.tags('foo')
def test_tagged_1():
sleep(0.01)
assert False
@pytest.mark.tags('foo', 'bar')
def test_tagged_2():
sleep(0.01)
assert False
@pytest.mark.tags('bar')
def test_tagged_3():
sleep(0.01)
assert False
@pytest.mark.tags('bar')
def test_tagged_4():
sleep(0.01)
assert False
"""
)
result = testdir.runpytest("--tests-per-worker=2")
result.stdout.re_match_lines("foo - 2")
result.stdout.re_match_lines("bar - 3")
for _ in range(10): # to ensure the passing test is not just (very bad) luck
result = testdir.runpytest("--workers=2", "--tests-per-worker=2")
result.stdout.re_match_lines("foo - 2")
result.stdout.re_match_lines("bar - 3")
12 changes: 6 additions & 6 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import threading

from pytest_tagging.utils import TagCounter
from pytest_tagging.utils import TagCounterThreadSafe


class TestTagCounter:
class TestTagCounterThreadSafe:
def test_update(self):
counter = TagCounter(threading.Lock())
counter = TagCounterThreadSafe()
assert dict(counter.items()) == {}

counter.update({"A", "B", "C"})
Expand All @@ -15,10 +15,10 @@ def test_update(self):
assert dict(counter.items()) == {"A": 2, "B": 1, "C": 1}

def test_empty_is_false(self):
assert bool(TagCounter(threading.Lock())) is False
assert bool(TagCounterThreadSafe()) is False

def test_sorted_items(self):
counter = TagCounter(threading.Lock())
counter = TagCounterThreadSafe()
counter.update({"A", "B", "C"})
counter.update({"A"})
counter.update({"A"})
Expand All @@ -27,7 +27,7 @@ def test_sorted_items(self):
assert list(counter.items()) == [("A", 3), ("B", 2), ("C", 1)]

def test_threadsafe_update(self):
counter = TagCounter(threading.Lock())
counter = TagCounterThreadSafe()

def update(counter):
counter.update({"A", "B"})
Expand Down

0 comments on commit 9a5fdf3

Please sign in to comment.