Skip to content

Commit

Permalink
connection port scanning optimized
Browse files Browse the repository at this point in the history
  • Loading branch information
BHznJNs committed Nov 15, 2024
1 parent bb59907 commit d9adfd3
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 49 deletions.
4 changes: 3 additions & 1 deletion adb_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,18 @@
os.environ["ADBUTILS_ADB_PATH"] = str(adb_bin_path)
ADB_BIN_PATH = str(adb_bin_path)

def try_connect_device(addr: str, timeout: float=4.0) -> adbutils.AdbClient | None:
def try_connect_device(addr: str, timeout: float=3.0) -> adbutils.AdbClient | None:
client = adbutils.AdbClient()
try:
output = client.connect(addr, timeout)
assert len(client.device_list()) > 0
LOGGER.write(LogType.Server, output)
except adbutils.AdbTimeout as e:
client.disconnect(addr)
LOGGER.write(LogType.Error, "Connect timeout: " + str(e))
return None
except Exception as e:
client.disconnect(addr)
LOGGER.write(LogType.Error, "Connect failed: " + str(e))
return None
return client
Expand Down
2 changes: 1 addition & 1 deletion main.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def close_notification_resolver(errno: Exception | None):
sys.exit(1)

client_socket = try_connect_server("localhost")
if not isinstance(client_socket, socket.socket):
if isinstance(client_socket, Exception):
close_notification_resolver(client_socket)
sys.exit(1)

Expand Down
90 changes: 52 additions & 38 deletions ui/connecting_window.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
import threading
import adbutils
import customtkinter as ctk

from dataclasses import dataclass
Expand All @@ -8,8 +9,9 @@
from adb_controller import try_connect_device, try_pairing
from ui import ICON_ICO_PATH
from utils.config_manager import CONFIG
from utils.logger import unreachable
from utils.scan_port import scan_port
from utils.ip_check import get_ip_addr_part, is_valid_ip_addr_part, is_valid_ip_str
from utils.ip_check import get_ip_from_ip_port, is_valid_ip, is_valid_ip_port
from utils.i18n import I18N

def mount_pairing_view(tabview: ctk.CTkTabview, connecting_addr_entry: ctk.CTkEntry):
Expand All @@ -20,14 +22,14 @@ def pair_callback():

addr = addr_entry.get().strip()
pairing_code = pairing_code_entry.get().strip()
if not is_valid_ip_str(addr):
if not is_valid_ip(addr):
error_label.configure(text=I18N(["Invalid pairing address!", "配对地址无效!"]))
return
ret = try_pairing(addr, pairing_code)
if ret == False:
error_label.configure(text=I18N(["Pairing Failed!", "配对失败!"]))
return
device_ip = get_ip_addr_part(addr)
device_ip = get_ip_from_ip_port(addr)
connecting_addr_entry.insert(0, device_ip)
tabview.set(I18N(["Connecting", "连接"]))
connecting_window.deiconify()
Expand Down Expand Up @@ -97,58 +99,69 @@ def validate_entry(text: str):
button2.pack(side=ctk.RIGHT)

def mount_connecting_view(tabview: ctk.CTkTabview) -> ctk.CTkEntry:
@dataclass
class ProcessOk: ip_port_str: str
@dataclass
class ProcessError: msg: str
process_data_queue: Queue[str | ProcessError] = Queue()

def process_ip_port(addr: str, to_scan_port: bool):
valid_ip_str = is_valid_ip_str(addr)
valid_ip_addr_part = is_valid_ip_addr_part(addr)
if to_scan_port:
if not valid_ip_str and not valid_ip_addr_part:
process_data_queue.put(
ProcessError(I18N(["Invalid connecting address!", "连接地址无效!"])))
return
# format address string into ip_part only string
ip_addr_part = addr if valid_ip_addr_part else get_ip_addr_part(addr)
process_data_queue: Queue[ProcessOk | ProcessError] = Queue()

target_port = scan_port(ip_addr_part)
if target_port is None:
process_data_queue.put(
ProcessError(I18N(["Port scanning failed, please check the IP address.", "扫描端口失败,请检查 IP 地址是否正确。"])))
def scan_and_connect(addr: str):
nonlocal process_data_queue
valid_ip = is_valid_ip(addr)
valid_ip_port = is_valid_ip_port(addr)
if not valid_ip and not valid_ip_port:
process_data_queue.put(
ProcessError(I18N(["Invalid connecting address!", "连接地址无效!"])))
return
# format address string into ip_part only string
ip_addr_part = addr if valid_ip else get_ip_from_ip_port(addr)

target_ports = scan_port(ip_addr_part)
for port in target_ports:
connect_addr = f"{ip_addr_part}:{port}"
ret = try_connect_device(connect_addr)
if ret is not None:
process_data_queue.put(ProcessOk(connect_addr))
return
process_data_queue.put(f"{ip_addr_part}:{target_port}")
elif not valid_ip_str:
# not to_scan_port and ip-port address string is not valid
process_data_queue.put(
ProcessError(I18N(["Port scanning failed, please check the IP address.", "扫描端口失败,请检查 IP 地址是否正确。"])))

def direct_connect(addr: str):
nonlocal process_data_queue
if not is_valid_ip(addr):
process_data_queue.put(
ProcessError(I18N(["Invalid connecting address!", "配对地址无效!"])))
else: process_data_queue.put(addr)
return
ret = try_connect_device(addr)
if ret is not None:
process_data_queue.put(ProcessOk(addr))
return
process_data_queue.put(
ProcessError(I18N(["Connecting failed, please retry.", "连接失败,请重试。"])))

def process_ip_port(addr: str, to_scan_port: bool):
if to_scan_port: scan_and_connect(addr)
else: direct_connect(addr)

def process_callback():
global connecting_window
nonlocal error_label, waiting_label
nonlocal process_data_queue, error_label, waiting_label
if process_data_queue.empty():
# wait for ip_port data processed
connecting_window.after(func=process_callback, ms=100)
return

connect_addr = process_data_queue.get()
if type(connect_addr) == ProcessError:
error_label.configure(text=connect_addr.msg)
result = process_data_queue.get()
if type(result) == ProcessError:
error_label.configure(text=result.msg)
waiting_label.configure(text="")
enable_widgets()
connecting_window.configure(cursor="arrow")
return

assert type(connect_addr) == str
# got valid ip_port string, connect
ret = try_connect_device(connect_addr)
if ret is None:
error_label.configure(text=I18N(["Connecting failed, please retry.", "连接失败!"]))
waiting_label.configure(text="")
return
CONFIG.config.device_ip1 = get_ip_addr_part(connect_addr)
connecting_window.destroy()
enable_widgets()
elif type(result) == ProcessOk:
CONFIG.config.device_ip1 = get_ip_from_ip_port(result.ip_port_str)
connecting_window.destroy()
else: unreachable("Connection result: " + str(result))

def enable_widgets():
nonlocal addr_entry, auto_scan_port, button1, button2
Expand Down Expand Up @@ -233,6 +246,7 @@ def open_connecting_window():
global connecting_window
def delete_window_callback():
connecting_window.destroy()
adbutils.AdbClient().server_kill()
sys.exit(0)

connecting_window = ctk.CTk()
Expand Down
10 changes: 5 additions & 5 deletions utils/ip_check.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import ipaddress

def get_ip_addr_part(ip_port_str: str) -> str:
def get_ip_from_ip_port(ip_port_str: str) -> str:
if ip_port_str.count(":") > 1:
# IPv6 address
ip_part = ip_port_str.rsplit(":", 1)[0]
Expand All @@ -10,14 +10,14 @@ def get_ip_addr_part(ip_port_str: str) -> str:

# for only ip address part, like:
# "192.168.2.1"
def is_valid_ip_addr_part(ip_str: str) -> bool:
def is_valid_ip(ip_str: str) -> bool:
try:
ipaddress.ip_address(ip_str)
return True
except: return False

# for ip + port string, like:
# "192.168.2.1:80"
def is_valid_ip_str(ip_port_str: str) -> bool:
ip_part = get_ip_addr_part(ip_port_str)
return is_valid_ip_addr_part(ip_part)
def is_valid_ip_port(ip_port_str: str) -> bool:
ip_part = get_ip_from_ip_port(ip_port_str)
return is_valid_ip(ip_part)
14 changes: 14 additions & 0 deletions utils/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,20 @@ def write(self, type: LogType, message: str):
self.file.write(complete_log_message + "\n")
self.file.flush()

def todo(msg: str | None=None):
global LOGGER
if msg is None:
LOGGER.write(LogType.Error, "not yet implemented")
else:
LOGGER.write(LogType.Error, "not yet implemented: " + msg)

def unreachable(msg: str | None=None):
global LOGGER
if msg is None:
LOGGER.write(LogType.Error, "entered unreachable code")
else:
LOGGER.write(LogType.Error, "entered unreachable code: " + msg)

if getattr(sys, "frozen", False):
log_base_dir = os.path.dirname(sys.executable)
else:
Expand Down
8 changes: 4 additions & 4 deletions utils/scan_port.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,20 +26,20 @@ def test_batch_port(ip: str, start_port: int, end_port: int) -> int | None:
selector.close()
if port is not None: return port

def scan_port(ip: str) -> int | None:
def scan_port(ip: str) -> list[int]:
executor = ThreadPoolExecutor(max_workers=4)
futures: list[Future] = []
for i in range(DEFAULT_START_PORT, DEFAULT_END_PORT, DEFAULT_STEP):
future = executor.submit(test_batch_port, ip, i, min(DEFAULT_END_PORT, i + DEFAULT_STEP))
futures.append(future)

target_port = None
target_ports = []
for future in futures:
try:
result = future.result()
if result is None: continue
target_port = result; break
target_ports.append(result)
except Exception as e:
LOGGER.write(LogType.Error, "Port scanning error: " + str(e))
executor.shutdown(cancel_futures=True)
return target_port
return target_ports

0 comments on commit d9adfd3

Please sign in to comment.