Skip to content

Commit

Permalink
Add support for USB connections
Browse files Browse the repository at this point in the history
Adds a new subclass of PybricksHub that manages USB connections.

Co-developed-by: David Lechner <david@pybricks.com>
Signed-off-by: Nate Karstens <nate.karstens@gmail.com>
  • Loading branch information
nkarstens authored and dlech committed Feb 1, 2025
1 parent 3d6cf82 commit b9e8e1e
Show file tree
Hide file tree
Showing 5 changed files with 222 additions and 7 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,9 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

## [Unreleased]

### Added
- Partial/experimental support for `pybricksdev run usb`.

### Fixed
- Fix crash when running `pybricksdev run ble -` (bug introduced in alpha.49).

Expand Down
45 changes: 38 additions & 7 deletions pybricksdev/cli/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,11 @@ def add_parser(self, subparsers: argparse._SubParsersAction):
)

async def run(self, args: argparse.Namespace):
from pybricksdev.ble import find_device
from pybricksdev.connections.ev3dev import EV3Connection
from pybricksdev.connections.lego import REPLHub
from pybricksdev.connections.pybricks import PybricksHubBLE

# Pick the right connection
if args.conntype == "ssh":
from pybricksdev.connections.ev3dev import EV3Connection

# So it's an ev3dev
if args.name is None:
print("--name is required for SSH connections", file=sys.stderr)
Expand All @@ -186,13 +184,46 @@ async def run(self, args: argparse.Namespace):
device_or_address = socket.gethostbyname(args.name)
hub = EV3Connection(device_or_address)
elif args.conntype == "ble":
from pybricksdev.ble import find_device as find_ble
from pybricksdev.connections.pybricks import PybricksHubBLE

# It is a Pybricks Hub with BLE. Device name or address is given.
print(f"Searching for {args.name or 'any hub with Pybricks service'}...")
device_or_address = await find_device(args.name)
device_or_address = await find_ble(args.name)
hub = PybricksHubBLE(device_or_address)

elif args.conntype == "usb":
hub = REPLHub()
from usb.core import find as find_usb

from pybricksdev.connections.pybricks import PybricksHubUSB
from pybricksdev.usb import (
LEGO_USB_VID,
MINDSTORMS_INVENTOR_USB_PID,
SPIKE_ESSENTIAL_USB_PID,
SPIKE_PRIME_USB_PID,
)

def is_pybricks_usb(dev):
return (
(dev.idVendor == LEGO_USB_VID)
and (
dev.idProduct
in [
SPIKE_PRIME_USB_PID,
SPIKE_ESSENTIAL_USB_PID,
MINDSTORMS_INVENTOR_USB_PID,
]
)
and dev.product.endswith("Pybricks")
)

device_or_address = find_usb(custom_match=is_pybricks_usb)

if device_or_address is not None:
hub = PybricksHubUSB(device_or_address)
else:
from pybricksdev.connections.lego import REPLHub

hub = REPLHub()
else:
raise ValueError(f"Unknown connection type: {args.conntype}")

Expand Down
148 changes: 148 additions & 0 deletions pybricksdev/connections/pybricks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import os
import struct
from typing import Awaitable, Callable, List, Optional, TypeVar
from uuid import UUID

import reactivex.operators as op
import semver
Expand All @@ -17,6 +18,10 @@
from reactivex.subject import BehaviorSubject, Subject
from tqdm.auto import tqdm
from tqdm.contrib.logging import logging_redirect_tqdm
from usb.control import get_descriptor
from usb.core import Device as USBDevice
from usb.core import Endpoint, USBTimeoutError
from usb.util import ENDPOINT_IN, ENDPOINT_OUT, endpoint_direction, find_descriptor

from pybricksdev.ble.lwp3.bytecodes import HubKind
from pybricksdev.ble.nus import NUS_RX_UUID, NUS_TX_UUID
Expand All @@ -38,6 +43,10 @@
from pybricksdev.connections import ConnectionState
from pybricksdev.tools import chunk
from pybricksdev.tools.checksum import xor_bytes
from pybricksdev.usb.pybricks import (
PybricksUsbInEpMessageType,
PybricksUsbOutEpMessageType,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -707,3 +716,142 @@ async def write_gatt_char(self, uuid: str, data, response: bool) -> None:

async def start_notify(self, uuid: str, callback: Callable) -> None:
return await self._client.start_notify(uuid, callback)


class PybricksHubUSB(PybricksHub):
_device: USBDevice
_ep_in: Endpoint
_ep_out: Endpoint
_notify_callbacks: dict[str, Callable] = {}
_monitor_task: asyncio.Task

def __init__(self, device: USBDevice):
super().__init__()
self._device = device
self._response_queue = asyncio.Queue[bytes]()

async def _client_connect(self) -> bool:
# Reset is essential to ensure endpoints are in a good state.
self._device.reset()
self._device.set_configuration()

# Save input and output endpoints
cfg = self._device.get_active_configuration()
intf = cfg[(0, 0)]
self._ep_in = find_descriptor(
intf,
custom_match=lambda e: endpoint_direction(e.bEndpointAddress)
== ENDPOINT_IN,
)
self._ep_out = find_descriptor(
intf,
custom_match=lambda e: endpoint_direction(e.bEndpointAddress)
== ENDPOINT_OUT,
)

# There is 1 byte overhead for PybricksUsbMessageType
self._max_write_size = self._ep_out.wMaxPacketSize - 1

# Get length of BOS descriptor
bos_descriptor = get_descriptor(self._device, 5, 0x0F, 0)
(ofst, _, bos_len, _) = struct.unpack("<BBHB", bos_descriptor)

# Get full BOS descriptor
bos_descriptor = get_descriptor(self._device, bos_len, 0x0F, 0)

while ofst < bos_len:
(len, desc_type, cap_type) = struct.unpack_from(
"<BBB", bos_descriptor, offset=ofst
)

if desc_type != 0x10:
logger.error("Expected Device Capability descriptor")
exit(1)

# Look for platform descriptors
if cap_type == 0x05:
uuid_bytes = bos_descriptor[ofst + 4 : ofst + 4 + 16]
uuid_str = str(UUID(bytes_le=bytes(uuid_bytes)))

if uuid_str == FW_REV_UUID:
fw_version = bytearray(bos_descriptor[ofst + 20 : ofst + len])
self.fw_version = Version(fw_version.decode())

elif uuid_str == SW_REV_UUID:
self._protocol_version = bytearray(
bos_descriptor[ofst + 20 : ofst + len]
)

elif uuid_str == PYBRICKS_HUB_CAPABILITIES_UUID:
caps = bytearray(bos_descriptor[ofst + 20 : ofst + len])
(
_,
self._capability_flags,
self._max_user_program_size,
) = unpack_hub_capabilities(caps)

ofst += len

self._monitor_task = asyncio.create_task(self._monitor_usb())

return True

async def _client_disconnect(self) -> bool:
self._monitor_task.cancel()
self._handle_disconnect()

async def read_gatt_char(self, uuid: str) -> bytearray:
# Most stuff is available via other properties due to reading BOS
# descriptor during connect.
raise NotImplementedError

async def write_gatt_char(self, uuid: str, data, response: bool) -> None:
if uuid.lower() != PYBRICKS_COMMAND_EVENT_UUID:
raise ValueError("Only Pybricks command event UUID is supported")

if not response:
raise ValueError("Response is required for USB")

self._ep_out.write(bytes([PybricksUsbOutEpMessageType.COMMAND]) + data)
# FIXME: This needs to race with hub disconnect, and could also use a
# timeout, otherwise it blocks forever. Pyusb doesn't currently seem to
# have any disconnect callback.
reply = await self._response_queue.get()

# REVISIT: could look up status error code and convert to string,
# although BLE doesn't do that either.
if int.from_bytes(reply[:4], "little") != 0:
raise RuntimeError(f"Write failed: {reply[0]}")

async def start_notify(self, uuid: str, callback: Callable) -> None:
# TODO: need to send subscribe message over USB
self._notify_callbacks[uuid] = callback

async def _monitor_usb(self):
loop = asyncio.get_running_loop()

while True:
msg = await loop.run_in_executor(None, self._read_usb)

if msg is None:
continue

if len(msg) == 0:
logger.warning("Empty USB message")
continue

if msg[0] == PybricksUsbInEpMessageType.RESPONSE:
self._response_queue.put_nowait(msg[1:])
elif msg[0] == PybricksUsbInEpMessageType.EVENT:
callback = self._notify_callbacks.get(PYBRICKS_COMMAND_EVENT_UUID)
if callback:
callback(None, msg[1:])
else:
logger.warning("Unknown USB message type: %d", msg[0])

def _read_usb(self) -> bytes | None:
try:
msg = self._ep_in.read(self._ep_in.wMaxPacketSize)
return msg
except USBTimeoutError:
return None
3 changes: 3 additions & 0 deletions pybricksdev/usb/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,8 @@
EV3_USB_PID = 0x0005
EV3_BOOTLOADER_USB_PID = 0x0006
SPIKE_PRIME_DFU_USB_PID = 0x0008
SPIKE_PRIME_USB_PID = 0x0009
SPIKE_ESSENTIAL_DFU_USB_PID = 0x000C
SPIKE_ESSENTIAL_USB_PID = 0x000D
MINDSTORMS_INVENTOR_USB_PID = 0x0010
MINDSTORMS_INVENTOR_DFU_USB_PID = 0x0011
30 changes: 30 additions & 0 deletions pybricksdev/usb/pybricks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
# SPDX-License-Identifier: MIT
# Copyright (c) 2025 The Pybricks Authors

"""
Pybricks-specific USB protocol.
"""

from enum import IntEnum


class PybricksUsbInEpMessageType(IntEnum):
RESPONSE = 1
"""
Analogous to BLE status response.
"""
EVENT = 2
"""
Analogous to BLE notification.
"""


class PybricksUsbOutEpMessageType(IntEnum):
SUBSCRIBE = 1
"""
Analogous to BLE Client Characteristic Configuration Descriptor (CCCD).
"""
COMMAND = 2
"""
Analogous to BLE write without response.
"""

0 comments on commit b9e8e1e

Please sign in to comment.