From d9adfd38c3a7829ce534462339117e64a90db0bd Mon Sep 17 00:00:00 2001 From: BHznJNs <441768875@qq.com> Date: Fri, 15 Nov 2024 22:32:42 +0800 Subject: [PATCH] connection port scanning optimized --- adb_controller.py | 4 +- main.py | 2 +- ui/connecting_window.py | 90 ++++++++++++++++++++++++----------------- utils/ip_check.py | 10 ++--- utils/logger.py | 14 +++++++ utils/scan_port.py | 8 ++-- 6 files changed, 79 insertions(+), 49 deletions(-) diff --git a/adb_controller.py b/adb_controller.py index 1bbe9f7..f712ffe 100644 --- a/adb_controller.py +++ b/adb_controller.py @@ -13,16 +13,18 @@ os.environ["ADBUTILS_ADB_PATH"] = str(adb_bin_path) ADB_BIN_PATH = str(adb_bin_path) -def try_connect_device(addr: str, timeout: float=4.0) -> adbutils.AdbClient | None: +def try_connect_device(addr: str, timeout: float=3.0) -> adbutils.AdbClient | None: client = adbutils.AdbClient() try: output = client.connect(addr, timeout) assert len(client.device_list()) > 0 LOGGER.write(LogType.Server, output) except adbutils.AdbTimeout as e: + client.disconnect(addr) LOGGER.write(LogType.Error, "Connect timeout: " + str(e)) return None except Exception as e: + client.disconnect(addr) LOGGER.write(LogType.Error, "Connect failed: " + str(e)) return None return client diff --git a/main.py b/main.py index 86a0a95..6e58cfa 100644 --- a/main.py +++ b/main.py @@ -58,7 +58,7 @@ def close_notification_resolver(errno: Exception | None): sys.exit(1) client_socket = try_connect_server("localhost") - if not isinstance(client_socket, socket.socket): + if isinstance(client_socket, Exception): close_notification_resolver(client_socket) sys.exit(1) diff --git a/ui/connecting_window.py b/ui/connecting_window.py index 181d921..2ad9710 100644 --- a/ui/connecting_window.py +++ b/ui/connecting_window.py @@ -1,5 +1,6 @@ import sys import threading +import adbutils import customtkinter as ctk from dataclasses import dataclass @@ -8,8 +9,9 @@ from adb_controller import try_connect_device, try_pairing from ui import ICON_ICO_PATH from utils.config_manager import CONFIG +from utils.logger import unreachable from utils.scan_port import scan_port -from utils.ip_check import get_ip_addr_part, is_valid_ip_addr_part, is_valid_ip_str +from utils.ip_check import get_ip_from_ip_port, is_valid_ip, is_valid_ip_port from utils.i18n import I18N def mount_pairing_view(tabview: ctk.CTkTabview, connecting_addr_entry: ctk.CTkEntry): @@ -20,14 +22,14 @@ def pair_callback(): addr = addr_entry.get().strip() pairing_code = pairing_code_entry.get().strip() - if not is_valid_ip_str(addr): + if not is_valid_ip(addr): error_label.configure(text=I18N(["Invalid pairing address!", "配对地址无效!"])) return ret = try_pairing(addr, pairing_code) if ret == False: error_label.configure(text=I18N(["Pairing Failed!", "配对失败!"])) return - device_ip = get_ip_addr_part(addr) + device_ip = get_ip_from_ip_port(addr) connecting_addr_entry.insert(0, device_ip) tabview.set(I18N(["Connecting", "连接"])) connecting_window.deiconify() @@ -97,58 +99,69 @@ def validate_entry(text: str): button2.pack(side=ctk.RIGHT) def mount_connecting_view(tabview: ctk.CTkTabview) -> ctk.CTkEntry: + @dataclass + class ProcessOk: ip_port_str: str @dataclass class ProcessError: msg: str - process_data_queue: Queue[str | ProcessError] = Queue() - def process_ip_port(addr: str, to_scan_port: bool): - valid_ip_str = is_valid_ip_str(addr) - valid_ip_addr_part = is_valid_ip_addr_part(addr) - if to_scan_port: - if not valid_ip_str and not valid_ip_addr_part: - process_data_queue.put( - ProcessError(I18N(["Invalid connecting address!", "连接地址无效!"]))) - return - # format address string into ip_part only string - ip_addr_part = addr if valid_ip_addr_part else get_ip_addr_part(addr) + process_data_queue: Queue[ProcessOk | ProcessError] = Queue() - target_port = scan_port(ip_addr_part) - if target_port is None: - process_data_queue.put( - ProcessError(I18N(["Port scanning failed, please check the IP address.", "扫描端口失败,请检查 IP 地址是否正确。"]))) + def scan_and_connect(addr: str): + nonlocal process_data_queue + valid_ip = is_valid_ip(addr) + valid_ip_port = is_valid_ip_port(addr) + if not valid_ip and not valid_ip_port: + process_data_queue.put( + ProcessError(I18N(["Invalid connecting address!", "连接地址无效!"]))) + return + # format address string into ip_part only string + ip_addr_part = addr if valid_ip else get_ip_from_ip_port(addr) + + target_ports = scan_port(ip_addr_part) + for port in target_ports: + connect_addr = f"{ip_addr_part}:{port}" + ret = try_connect_device(connect_addr) + if ret is not None: + process_data_queue.put(ProcessOk(connect_addr)) return - process_data_queue.put(f"{ip_addr_part}:{target_port}") - elif not valid_ip_str: - # not to_scan_port and ip-port address string is not valid + process_data_queue.put( + ProcessError(I18N(["Port scanning failed, please check the IP address.", "扫描端口失败,请检查 IP 地址是否正确。"]))) + + def direct_connect(addr: str): + nonlocal process_data_queue + if not is_valid_ip(addr): process_data_queue.put( ProcessError(I18N(["Invalid connecting address!", "配对地址无效!"]))) - else: process_data_queue.put(addr) + return + ret = try_connect_device(addr) + if ret is not None: + process_data_queue.put(ProcessOk(addr)) + return + process_data_queue.put( + ProcessError(I18N(["Connecting failed, please retry.", "连接失败,请重试。"]))) + + def process_ip_port(addr: str, to_scan_port: bool): + if to_scan_port: scan_and_connect(addr) + else: direct_connect(addr) def process_callback(): global connecting_window - nonlocal error_label, waiting_label + nonlocal process_data_queue, error_label, waiting_label if process_data_queue.empty(): # wait for ip_port data processed connecting_window.after(func=process_callback, ms=100) return - connect_addr = process_data_queue.get() - if type(connect_addr) == ProcessError: - error_label.configure(text=connect_addr.msg) + result = process_data_queue.get() + if type(result) == ProcessError: + error_label.configure(text=result.msg) waiting_label.configure(text="") - enable_widgets() connecting_window.configure(cursor="arrow") - return - - assert type(connect_addr) == str - # got valid ip_port string, connect - ret = try_connect_device(connect_addr) - if ret is None: - error_label.configure(text=I18N(["Connecting failed, please retry.", "连接失败!"])) - waiting_label.configure(text="") - return - CONFIG.config.device_ip1 = get_ip_addr_part(connect_addr) - connecting_window.destroy() + enable_widgets() + elif type(result) == ProcessOk: + CONFIG.config.device_ip1 = get_ip_from_ip_port(result.ip_port_str) + connecting_window.destroy() + else: unreachable("Connection result: " + str(result)) def enable_widgets(): nonlocal addr_entry, auto_scan_port, button1, button2 @@ -233,6 +246,7 @@ def open_connecting_window(): global connecting_window def delete_window_callback(): connecting_window.destroy() + adbutils.AdbClient().server_kill() sys.exit(0) connecting_window = ctk.CTk() diff --git a/utils/ip_check.py b/utils/ip_check.py index 02f6731..545d2d8 100644 --- a/utils/ip_check.py +++ b/utils/ip_check.py @@ -1,6 +1,6 @@ import ipaddress -def get_ip_addr_part(ip_port_str: str) -> str: +def get_ip_from_ip_port(ip_port_str: str) -> str: if ip_port_str.count(":") > 1: # IPv6 address ip_part = ip_port_str.rsplit(":", 1)[0] @@ -10,7 +10,7 @@ def get_ip_addr_part(ip_port_str: str) -> str: # for only ip address part, like: # "192.168.2.1" -def is_valid_ip_addr_part(ip_str: str) -> bool: +def is_valid_ip(ip_str: str) -> bool: try: ipaddress.ip_address(ip_str) return True @@ -18,6 +18,6 @@ def is_valid_ip_addr_part(ip_str: str) -> bool: # for ip + port string, like: # "192.168.2.1:80" -def is_valid_ip_str(ip_port_str: str) -> bool: - ip_part = get_ip_addr_part(ip_port_str) - return is_valid_ip_addr_part(ip_part) +def is_valid_ip_port(ip_port_str: str) -> bool: + ip_part = get_ip_from_ip_port(ip_port_str) + return is_valid_ip(ip_part) diff --git a/utils/logger.py b/utils/logger.py index 97f3e07..134b9da 100644 --- a/utils/logger.py +++ b/utils/logger.py @@ -29,6 +29,20 @@ def write(self, type: LogType, message: str): self.file.write(complete_log_message + "\n") self.file.flush() +def todo(msg: str | None=None): + global LOGGER + if msg is None: + LOGGER.write(LogType.Error, "not yet implemented") + else: + LOGGER.write(LogType.Error, "not yet implemented: " + msg) + +def unreachable(msg: str | None=None): + global LOGGER + if msg is None: + LOGGER.write(LogType.Error, "entered unreachable code") + else: + LOGGER.write(LogType.Error, "entered unreachable code: " + msg) + if getattr(sys, "frozen", False): log_base_dir = os.path.dirname(sys.executable) else: diff --git a/utils/scan_port.py b/utils/scan_port.py index 6f113d3..8c00c03 100644 --- a/utils/scan_port.py +++ b/utils/scan_port.py @@ -26,20 +26,20 @@ def test_batch_port(ip: str, start_port: int, end_port: int) -> int | None: selector.close() if port is not None: return port -def scan_port(ip: str) -> int | None: +def scan_port(ip: str) -> list[int]: executor = ThreadPoolExecutor(max_workers=4) futures: list[Future] = [] for i in range(DEFAULT_START_PORT, DEFAULT_END_PORT, DEFAULT_STEP): future = executor.submit(test_batch_port, ip, i, min(DEFAULT_END_PORT, i + DEFAULT_STEP)) futures.append(future) - target_port = None + target_ports = [] for future in futures: try: result = future.result() if result is None: continue - target_port = result; break + target_ports.append(result) except Exception as e: LOGGER.write(LogType.Error, "Port scanning error: " + str(e)) executor.shutdown(cancel_futures=True) - return target_port + return target_ports