Skip to content

Commit

Permalink
More cleanup.
Browse files: browse the repository at this point in the history
  • Loading branch information
wRAR committed Feb 2, 2025
1 parent 0a69c60 commit 57a3b7e
Show file tree
Hide file tree
Showing 4 changed files with 58 additions and 68 deletions.
50 changes: 24 additions & 26 deletions cssselect/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@
import operator
import re
import sys
import typing
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Literal, Optional, Protocol, Union, cast, overload

if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Sequence
Expand Down Expand Up @@ -375,17 +374,17 @@ class Attrib:
Represents selector[namespace|attrib operator value]
"""

@typing.overload
@overload
def __init__(
self,
selector: Tree,
namespace: str | None,
attrib: str,
operator: typing.Literal["exists"],
operator: Literal["exists"],
value: None,
) -> None: ...

@typing.overload
@overload
def __init__(
self,
selector: Tree,
Expand Down Expand Up @@ -607,7 +606,7 @@ def parse_selector(stream: TokenStream) -> tuple[Tree, PseudoElement | None]:
)
if peek.is_delim("+", ">", "~"):
# A combinator
combinator = typing.cast(str, stream.next().value)
combinator = cast(str, stream.next().value)
stream.skip_whitespace()
else:
# By exclusion, the last parse_simple_selector() ended
Expand Down Expand Up @@ -653,7 +652,7 @@ def parse_simple_selector(
"Got pseudo-element ::%s not at the end of a selector" % pseudo_element
)
if peek.type == "HASH":
result = Hash(result, typing.cast(str, stream.next().value))
result = Hash(result, cast(str, stream.next().value))
elif peek == ("DELIM", "."):
stream.next()
result = Class(result, stream.next_ident())
Expand Down Expand Up @@ -766,7 +765,7 @@ def parse_relative_selector(stream: TokenStream) -> tuple[Token, Selector]:
("DELIM", "."),
("DELIM", "*"),
]:
subselector += typing.cast(str, next.value)
subselector += cast(str, next.value)
elif next == ("DELIM", ")"):
result = parse(subselector)
return combinator, result[0]
Expand Down Expand Up @@ -820,13 +819,13 @@ def parse_attrib(selector: Tree, stream: TokenStream) -> Attrib:
stream.skip_whitespace()
next = stream.next()
if next == ("DELIM", "]"):
return Attrib(selector, namespace, typing.cast(str, attrib), "exists", None)
return Attrib(selector, namespace, cast(str, attrib), "exists", None)
if next == ("DELIM", "="):
op = "="
elif next.is_delim("^", "$", "*", "~", "|", "!") and (
stream.peek() == ("DELIM", "=")
):
op = typing.cast(str, next.value) + "="
op = cast(str, next.value) + "="
stream.next()
else:
raise SelectorSyntaxError("Operator expected, got %s" % (next,))
Expand All @@ -838,7 +837,7 @@ def parse_attrib(selector: Tree, stream: TokenStream) -> Attrib:
next = stream.next()
if next != ("DELIM", "]"):
raise SelectorSyntaxError("Expected ']', got %s" % (next,))
return Attrib(selector, namespace, typing.cast(str, attrib), op, value)
return Attrib(selector, namespace, cast(str, attrib), op, value)


def parse_series(tokens: Iterable[Token]) -> tuple[int, int]:
Expand All @@ -852,7 +851,7 @@ def parse_series(tokens: Iterable[Token]) -> tuple[int, int]:
for token in tokens:
if token.type == "STRING":
raise ValueError("String tokens not allowed in series.")
s = "".join(typing.cast(str, token.value) for token in tokens).strip()
s = "".join(cast(str, token.value) for token in tokens).strip()
if s == "odd":
return 2, 1
if s == "even":
Expand All @@ -878,16 +877,16 @@ def parse_series(tokens: Iterable[Token]) -> tuple[int, int]:


class Token(tuple[str, Optional[str]]): # noqa: SLOT001
@typing.overload
@overload
def __new__(
cls,
type_: typing.Literal["IDENT", "HASH", "STRING", "S", "DELIM", "NUMBER"],
type_: Literal["IDENT", "HASH", "STRING", "S", "DELIM", "NUMBER"],
value: str,
pos: int,
) -> Self: ...

@typing.overload
def __new__(cls, type_: typing.Literal["EOF"], value: None, pos: int) -> Self: ...
@overload
def __new__(cls, type_: Literal["EOF"], value: None, pos: int) -> Self: ...

def __new__(cls, type_: str, value: str | None, pos: int) -> Self:
obj = tuple.__new__(cls, (type_, value))
Expand All @@ -913,7 +912,7 @@ def value(self) -> str | None:
def css(self) -> str:
if self.type == "STRING":
return repr(self.value)
return typing.cast(str, self.value)
return cast(str, self.value)


class EOFToken(Token):
Expand All @@ -936,12 +935,10 @@ class TokenMacros:
nmstart = "[_a-z]|%s|%s" % (escape, nonascii)


if typing.TYPE_CHECKING:

class MatchFunc(typing.Protocol):
def __call__(
self, string: str, pos: int = ..., endpos: int = ...
) -> re.Match[str] | None: ...
class MatchFunc(Protocol):
def __call__(
self, string: str, pos: int = ..., endpos: int = ...
) -> re.Match[str] | None: ...


def _compile(pattern: str) -> MatchFunc:
Expand Down Expand Up @@ -1062,7 +1059,7 @@ def next(self) -> Token:
self._peeking = False
assert self.peeked is not None
self.used.append(self.peeked)
return typing.cast(Token, self.peeked)
return self.peeked
next = self.next_token()
self.used.append(next)
return next
Expand All @@ -1071,13 +1068,14 @@ def peek(self) -> Token:
if not self._peeking:
self.peeked = self.next_token()
self._peeking = True
return typing.cast(Token, self.peeked)
assert self.peeked is not None
return self.peeked

def next_ident(self) -> str:
next = self.next()
if next.type != "IDENT":
raise SelectorSyntaxError("Expected ident, got %s" % (next,))
return typing.cast(str, next.value)
return cast(str, next.value)

def next_ident_or_star(self) -> str | None:
next = self.next()
Expand Down
65 changes: 31 additions & 34 deletions cssselect/xpath.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,8 @@
from __future__ import annotations

import re
import typing
from collections.abc import Callable
from typing import Optional
from typing import TYPE_CHECKING, Optional, cast

from cssselect.parser import (
Attrib,
Expand All @@ -38,6 +37,10 @@
parse_series,
)

if TYPE_CHECKING:
# typing.Self requires Python 3.11
from typing_extensions import Self


class ExpressionError(SelectorError, RuntimeError):
"""Unknown or unsupported selector (eg. pseudo-class)."""
Expand Down Expand Up @@ -67,7 +70,7 @@ def __str__(self) -> str:
def __repr__(self) -> str:
return "%s[%s]" % (self.__class__.__name__, self)

def add_condition(self, condition: str, conjuction: str = "and") -> XPathExpr:
def add_condition(self, condition: str, conjuction: str = "and") -> Self:
if self.condition:
self.condition = "(%s) %s (%s)" % (self.condition, conjuction, condition)
else:
Expand Down Expand Up @@ -96,7 +99,7 @@ def join(
other: XPathExpr,
closing_combiner: str | None = None,
has_inner_condition: bool = False,
) -> XPathExpr:
) -> Self:
path = str(self) + combiner
# Any "star prefix" is redundant when joining.
if other.path != "*/":
Expand Down Expand Up @@ -276,19 +279,18 @@ def xpath_literal(s: str) -> str:
elif '"' not in s:
s = '"%s"' % s
else:
s = "concat(%s)" % ",".join(
[
((("'" in part) and '"%s"') or "'%s'") % part
for part in split_at_single_quotes(s)
if part
]
)
parts_quoted = [
f'"{part}"' if "'" in part else f"'{part}'"
for part in split_at_single_quotes(s)
if part
]
s = "concat({})".format(",".join(parts_quoted))
return s

def xpath(self, parsed_selector: Tree) -> XPathExpr:
"""Translate any parsed selector object."""
type_name = type(parsed_selector).__name__
method = typing.cast(
method = cast(
Optional[Callable[[Tree], XPathExpr]],
getattr(self, "xpath_%s" % type_name.lower(), None),
)
Expand All @@ -301,7 +303,7 @@ def xpath(self, parsed_selector: Tree) -> XPathExpr:
def xpath_combinedselector(self, combined: CombinedSelector) -> XPathExpr:
"""Translate a combined selector."""
combinator = self.combinator_mapping[combined.combinator]
method = typing.cast(
method = cast(
Callable[[XPathExpr, XPathExpr], XPathExpr],
getattr(self, "xpath_%s_combinator" % combinator),
)
Expand All @@ -321,12 +323,12 @@ def xpath_relation(self, relation: Relation) -> XPathExpr:
combinator = relation.combinator
subselector = relation.subselector
right = self.xpath(subselector.parsed_tree)
method = typing.cast(
method = cast(
Callable[[XPathExpr, XPathExpr], XPathExpr],
getattr(
self,
"xpath_relation_%s_combinator"
% self.combinator_mapping[typing.cast(str, combinator.value)],
% self.combinator_mapping[cast(str, combinator.value)],
),
)
return method(xpath, right)
Expand All @@ -352,7 +354,7 @@ def xpath_specificityadjustment(self, matching: SpecificityAdjustment) -> XPathE
def xpath_function(self, function: Function) -> XPathExpr:
"""Translate a functional pseudo-class."""
method_name = "xpath_%s_function" % function.name.replace("-", "_")
method = typing.cast(
method = cast(
Optional[Callable[[XPathExpr, Function], XPathExpr]],
getattr(self, method_name, None),
)
Expand All @@ -363,7 +365,7 @@ def xpath_function(self, function: Function) -> XPathExpr:
def xpath_pseudo(self, pseudo: Pseudo) -> XPathExpr:
"""Translate a pseudo-class."""
method_name = "xpath_%s_pseudo" % pseudo.ident.replace("-", "_")
method = typing.cast(
method = cast(
Optional[Callable[[XPathExpr], XPathExpr]], getattr(self, method_name, None)
)
if not method:
Expand All @@ -374,7 +376,7 @@ def xpath_pseudo(self, pseudo: Pseudo) -> XPathExpr:
def xpath_attrib(self, selector: Attrib) -> XPathExpr:
"""Translate an attribute selector."""
operator = self.attribute_operator_mapping[selector.operator]
method = typing.cast(
method = cast(
Callable[[XPathExpr, str, Optional[str]], XPathExpr],
getattr(self, "xpath_attrib_%s" % operator),
)
Expand All @@ -393,7 +395,7 @@ def xpath_attrib(self, selector: Attrib) -> XPathExpr:
if selector.value is None:
value = None
elif self.lower_case_attribute_values:
value = typing.cast(str, selector.value.value).lower()
value = cast(str, selector.value.value).lower()
else:
value = selector.value.value
return method(self.xpath(selector.selector), attrib, value)
Expand Down Expand Up @@ -649,7 +651,7 @@ def xpath_contains_function(
"Expected a single string or ident for :contains(), got %r"
% function.arguments
)
value = typing.cast(str, function.arguments[0].value)
value = cast(str, function.arguments[0].value)
return xpath.add_condition("contains(., %s)" % self.xpath_literal(value))

def xpath_lang_function(self, xpath: XPathExpr, function: Function) -> XPathExpr:
Expand All @@ -658,7 +660,7 @@ def xpath_lang_function(self, xpath: XPathExpr, function: Function) -> XPathExpr
"Expected a single string or ident for :lang(), got %r"
% function.arguments
)
value = typing.cast(str, function.arguments[0].value)
value = cast(str, function.arguments[0].value)
return xpath.add_condition("lang(%s)" % (self.xpath_literal(value)))

# Pseudo: dispatch by pseudo-class name
Expand Down Expand Up @@ -748,9 +750,9 @@ def xpath_attrib_includes(
self, xpath: XPathExpr, name: str, value: str | None
) -> XPathExpr:
if value and is_non_whitespace(value):
arg = self.xpath_literal(" " + value + " ")
xpath.add_condition(
"%s and contains(concat(' ', normalize-space(%s), ' '), %s)"
% (name, name, self.xpath_literal(" " + value + " "))
f"{name} and contains(concat(' ', normalize-space({name}), ' '), {arg})"
)
else:
xpath.add_condition("0")
Expand All @@ -760,16 +762,11 @@ def xpath_attrib_dashmatch(
self, xpath: XPathExpr, name: str, value: str | None
) -> XPathExpr:
assert value is not None
arg = self.xpath_literal(value)
arg_dash = self.xpath_literal(value + "-")
# Weird, but true...
xpath.add_condition(
"%s and (%s = %s or starts-with(%s, %s))"
% (
name,
name,
self.xpath_literal(value),
name,
self.xpath_literal(value + "-"),
)
f"{name} and ({name} = {arg} or starts-with({name}, {arg_dash}))"
)
return xpath

Expand Down Expand Up @@ -853,13 +850,13 @@ def xpath_lang_function(self, xpath: XPathExpr, function: Function) -> XPathExpr
)
value = function.arguments[0].value
assert value
arg = self.xpath_literal(value.lower() + "-")
return xpath.add_condition(
"ancestor-or-self::*[@lang][1][starts-with(concat("
# XPath 1.0 has no lower-case function...
"translate(@%s, 'ABCDEFGHIJKLMNOPQRSTUVWXYZ', "
f"translate(@{self.lang_attribute}, 'ABCDEFGHIJKLMNOPQRSTUVWXYZ', "
"'abcdefghijklmnopqrstuvwxyz'), "
"'-'), %s)]"
% (self.lang_attribute, self.xpath_literal(value.lower() + "-"))
f"'-'), {arg})]"
)

def xpath_link_pseudo(self, xpath: XPathExpr) -> XPathExpr: # type: ignore[override]
Expand Down
1 change: 0 additions & 1 deletion pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ extension-pkg-allow-list=lxml
[MESSAGES CONTROL]
enable=useless-suppression
disable=consider-using-f-string,
duplicate-string-formatting-argument,
fixme,
invalid-name,
line-too-long,
Expand Down
10 changes: 3 additions & 7 deletions tests/test_cssselect.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,10 +78,7 @@ def repr_parse(css: str) -> list[str]:
selectors = parse(css)
for selector in selectors:
assert selector.pseudo_element is None
return [
repr(selector.parsed_tree).replace("(u'", "('")
for selector in selectors
]
return [repr(selector.parsed_tree) for selector in selectors]

def parse_many(first: str, *others: str) -> list[str]:
result = repr_parse(first)
Expand Down Expand Up @@ -196,7 +193,7 @@ def parse_pseudo(css: str) -> list[tuple[str, str | None]]:
pseudo = str(pseudo) if pseudo else pseudo
# No Symbol here
assert pseudo is None or isinstance(pseudo, str)
selector_as_str = repr(selector.parsed_tree).replace("(u'", "('")
selector_as_str = repr(selector.parsed_tree)
result.append((selector_as_str, pseudo))
return result

Expand Down Expand Up @@ -373,8 +370,7 @@ def get_error(css: str) -> str | None:
try:
parse(css)
except SelectorSyntaxError:
# Py2, Py3, ...
return str(sys.exc_info()[1]).replace("(u'", "('")
return str(sys.exc_info()[1])
return None

assert get_error("attributes(href)/html/body/a") == (
Expand Down

0 comments on commit 57a3b7e

Please sign in to comment.