diff --git a/agent/whois_ip_agent.py b/agent/whois_ip_agent.py index a8f9fef..195c999 100644 --- a/agent/whois_ip_agent.py +++ b/agent/whois_ip_agent.py @@ -122,6 +122,12 @@ def _process_ip(self, message: m.Message, host: str) -> None: network = ipaddress.ip_network(host) else: version = message.data.get("version") + if version is None: + try: + ip = ipaddress.ip_address(host) + version = ip.version + except ValueError: + raise ValueError(f"Invalid IP address: {host}") if version not in (4, 6): raise ValueError(f"Incorrect ip version {version}.") elif version == 4 and int(mask) < IPV4_CIDR_LIMIT: diff --git a/tests/whois_ip_agent_test.py b/tests/whois_ip_agent_test.py index 5225b0a..e2cc01b 100644 --- a/tests/whois_ip_agent_test.py +++ b/tests/whois_ip_agent_test.py @@ -1,6 +1,7 @@ """Unittests for WhoisIP agent.""" from typing import List, Dict +from unittest import mock import ipwhois import pytest @@ -304,3 +305,45 @@ def testWhoisIP_whenIPHasNoASN_doesNotCrash( test_agent.process(scan_message_global_ipv4_with_mask32) assert len(agent_mock) == 0 + + +def testWhoisIP_withIPv4AndMaskButNoVersion_shouldHandleVersionCorrectly( + test_agent: whois_ip_agent.WhoisIPAgent, + agent_persist_mock: dict[str | bytes, str | bytes], +) -> None: + """Test that process() handles the case when the version is None.""" + message_data = {"host": "80.121.155.176", "mask": "29"} + test_message = message.Message.from_data( + selector="v3.asset.ip.v4", + data=message_data, + ) + + with mock.patch.object( + test_agent, "_redis_client" + ) as mock_redis_client, mock.patch.object( + test_agent, "add_ip_network" + ) as mock_add_ip_network, mock.patch.object( + test_agent, "start", mock.MagicMock() + ), mock.patch.object(test_agent, "run", mock.MagicMock()), mock.patch( + "agent.whois_ip_agent.WhoisIPAgent.main", mock.MagicMock() + ): + mock_redis_client.sismember.return_value = False + + mock_add_ip_network.return_value = None + + test_agent.process(test_message) + + mock_add_ip_network.assert_called_once() + + +def testWhoisIP_whenInvalidIPAddressIsProvided_raisesValueError( + test_agent: whois_ip_agent.WhoisIPAgent, + mocker: plugin.MockerFixture, +) -> None: + """Test that a ValueError is raised when an invalid IP address is provided.""" + input_selector = "v3.asset.ip.v4" + input_data = {"host": "invalid_ip", "mask": "24"} + ip_msg = message.Message.from_data(selector=input_selector, data=input_data) + + with pytest.raises(ValueError, match="Invalid IP address: invalid_ip"): + test_agent.process(ip_msg)