From fda6a7cdd3d98e51eb6fa28e784bae4a99a466d7 Mon Sep 17 00:00:00 2001 From: Kishikawa Katsumi Date: Wed, 4 Dec 2024 06:25:52 +0900 Subject: [PATCH] Make requests and responses typed --- .../DirectoryStructure.swift | 6 +- Sources/SMBClient/Messages/Close.swift | 6 +- Sources/SMBClient/Messages/Create.swift | 6 +- Sources/SMBClient/Messages/Echo.swift | 6 +- Sources/SMBClient/Messages/IOCtl.swift | 6 +- Sources/SMBClient/Messages/Logoff.swift | 6 +- Sources/SMBClient/Messages/Message.swift | 13 +++ Sources/SMBClient/Messages/Negotiate.swift | 6 +- .../SMBClient/Messages/QueryDirectory.swift | 8 +- Sources/SMBClient/Messages/QueryInfo.swift | 6 +- .../Read+CustomDebugStringConvertible.swift | 2 +- Sources/SMBClient/Messages/Read.swift | 6 +- Sources/SMBClient/Messages/SessionSetup.swift | 6 +- Sources/SMBClient/Messages/SetInfo.swift | 6 +- Sources/SMBClient/Messages/TreeConnect.swift | 6 +- .../SMBClient/Messages/TreeDisconnect.swift | 6 +- .../Write+CustomDebugStringConvertible.swift | 55 +++++++++ Sources/SMBClient/Messages/Write.swift | 6 +- Sources/SMBClient/Session.swift | 109 +++++++----------- 19 files changed, 171 insertions(+), 100 deletions(-) create mode 100644 Sources/SMBClient/Messages/Message.swift create mode 100644 Sources/SMBClient/Messages/Write+CustomDebugStringConvertible.swift diff --git a/Examples/FileBrowser/FileBrowser (macOS)/DirectoryStructure.swift b/Examples/FileBrowser/FileBrowser (macOS)/DirectoryStructure.swift index 36ed9aa..9172588 100644 --- a/Examples/FileBrowser/FileBrowser (macOS)/DirectoryStructure.swift +++ b/Examples/FileBrowser/FileBrowser (macOS)/DirectoryStructure.swift @@ -93,7 +93,7 @@ class DirectoryStructure { } func update(_ outlineView: NSOutlineView) { - guard let rootNodes: [FileNode] = DataRepository.shared.nodes(join(server, path)) else { + guard let rootNodes: [FileNode] = DataRepository.shared.nodes(join(server, treeAccessor.share, path)) else { return } @@ -102,7 +102,7 @@ class DirectoryStructure { return $0.isDirectory && outlineView.isItemExpanded($0) } .reduce(into: [FileNode]()) { - guard let nodes: [FileNode] = DataRepository.shared.nodes(join(server, $1.path)) else { + guard let nodes: [FileNode] = DataRepository.shared.nodes(join(server, treeAccessor.share, $1.path)) else { return } let parent = $1.id @@ -211,7 +211,7 @@ class DirectoryStructure { let nodes = files .map { FileNode(path: join(path, $0.name), file: $0, parent: parent?.id) } - DataRepository.shared.set(join(server, path), nodes: nodes) + DataRepository.shared.set(join(server, treeAccessor.share, path), nodes: nodes) return nodes } diff --git a/Sources/SMBClient/Messages/Close.swift b/Sources/SMBClient/Messages/Close.swift index 7c275fc..a19f7d3 100644 --- a/Sources/SMBClient/Messages/Close.swift +++ b/Sources/SMBClient/Messages/Close.swift @@ -1,7 +1,9 @@ import Foundation public enum Close { - public struct Request { + public struct Request: Message.Request { + public typealias Response = Close.Response + public let header: Header public let structureSize: UInt16 public let flags: UInt16 @@ -44,7 +46,7 @@ public enum Close { } } - public struct Response { + public struct Response: Message.Response { public let header: Header public let structureSize: UInt16 public let flags: UInt16 diff --git a/Sources/SMBClient/Messages/Create.swift b/Sources/SMBClient/Messages/Create.swift index cf8c76e..943a8a1 100644 --- a/Sources/SMBClient/Messages/Create.swift +++ b/Sources/SMBClient/Messages/Create.swift @@ -1,7 +1,9 @@ import Foundation public enum Create { - public struct Request { + public struct Request: Message.Request { + public typealias Response = Create.Response + public private(set) var header: Header public let structureSize: UInt16 public let securityFlags: UInt8 @@ -93,7 +95,7 @@ public enum Create { } } - public struct Response { + public struct Response: Message.Response { public let header: Header public let structureSize: UInt16 public let oplockLevel: UInt8 diff --git a/Sources/SMBClient/Messages/Echo.swift b/Sources/SMBClient/Messages/Echo.swift index c9243bb..e2336db 100644 --- a/Sources/SMBClient/Messages/Echo.swift +++ b/Sources/SMBClient/Messages/Echo.swift @@ -1,7 +1,9 @@ import Foundation public enum Echo { - public struct Request { + public struct Request: Message.Request { + public typealias Response = Echo.Response + public let header: Header public let structureSize: UInt16 public let reserved: UInt16 @@ -37,7 +39,7 @@ public enum Echo { } } - public struct Response { + public struct Response: Message.Response { public let header: Header public let structureSize: UInt16 public let reserved: UInt16 diff --git a/Sources/SMBClient/Messages/IOCtl.swift b/Sources/SMBClient/Messages/IOCtl.swift index 9ac82d7..da7debc 100644 --- a/Sources/SMBClient/Messages/IOCtl.swift +++ b/Sources/SMBClient/Messages/IOCtl.swift @@ -1,7 +1,9 @@ import Foundation public enum IOCtl { - public struct Request { + public struct Request: Message.Request { + public typealias Response = IOCtl.Response + public let header: Header public let structureSize: UInt16 public let reserved: UInt16 @@ -75,7 +77,7 @@ public enum IOCtl { } } - public struct Response { + public struct Response: Message.Response { public let header: Header public let structureSize: UInt16 public let reserved: UInt16 diff --git a/Sources/SMBClient/Messages/Logoff.swift b/Sources/SMBClient/Messages/Logoff.swift index 16ad1f5..aabc481 100644 --- a/Sources/SMBClient/Messages/Logoff.swift +++ b/Sources/SMBClient/Messages/Logoff.swift @@ -1,7 +1,9 @@ import Foundation public enum Logoff { - public struct Request { + public struct Request: Message.Request { + public typealias Response = Logoff.Response + public let header: Header public let structureSize: UInt16 public let reserved: UInt16 @@ -36,7 +38,7 @@ public enum Logoff { } } - public struct Response { + public struct Response: Message.Response { public let header: Header public let structureSize: UInt16 public let reserved: UInt16 diff --git a/Sources/SMBClient/Messages/Message.swift b/Sources/SMBClient/Messages/Message.swift new file mode 100644 index 0000000..9186b49 --- /dev/null +++ b/Sources/SMBClient/Messages/Message.swift @@ -0,0 +1,13 @@ +import Foundation + +public enum Message { + public protocol Request { + associatedtype Response: Message.Response + func encoded() -> Data + } + + public protocol Response { + var header: Header { get } + init(data: Data) + } +} diff --git a/Sources/SMBClient/Messages/Negotiate.swift b/Sources/SMBClient/Messages/Negotiate.swift index 69014c4..94229b3 100644 --- a/Sources/SMBClient/Messages/Negotiate.swift +++ b/Sources/SMBClient/Messages/Negotiate.swift @@ -1,7 +1,9 @@ import Foundation public enum Negotiate { - public struct Request { + public struct Request: Message.Request { + public typealias Response = Negotiate.Response + public let header: Header public let structureSize: UInt16 public let dialectCount: UInt16 @@ -65,7 +67,7 @@ public enum Negotiate { } } - public struct Response { + public struct Response: Message.Response { public let header: Header public let structureSize: UInt16 public let securityMode: SecurityMode diff --git a/Sources/SMBClient/Messages/QueryDirectory.swift b/Sources/SMBClient/Messages/QueryDirectory.swift index e7493e0..29b33ac 100644 --- a/Sources/SMBClient/Messages/QueryDirectory.swift +++ b/Sources/SMBClient/Messages/QueryDirectory.swift @@ -1,7 +1,9 @@ import Foundation public enum QueryDirectory { - public struct Request { + public struct Request: Message.Request { + public typealias Response = QueryDirectory.Response + public let header: Header public let structureSize: UInt16 public let fileInformationClass: FileInformationClass @@ -67,7 +69,7 @@ public enum QueryDirectory { } } - public struct Response { + public struct Response: Message.Response { public let header: Header public let structureSize: UInt16 public let outputBufferOffset: UInt16 @@ -82,7 +84,7 @@ public enum QueryDirectory { structureSize = reader.read() outputBufferOffset = reader.read() outputBufferLength = reader.read() - buffer = data[UInt32(outputBufferOffset).. [FileDirectoryInformation] { diff --git a/Sources/SMBClient/Messages/QueryInfo.swift b/Sources/SMBClient/Messages/QueryInfo.swift index 39dafe5..17f1969 100644 --- a/Sources/SMBClient/Messages/QueryInfo.swift +++ b/Sources/SMBClient/Messages/QueryInfo.swift @@ -1,7 +1,9 @@ import Foundation public enum QueryInfo { - public struct Request { + public struct Request: Message.Request { + public typealias Response = QueryInfo.Response + public let header: Header public let structureSize: UInt16 public let infoType: InfoType @@ -69,7 +71,7 @@ public enum QueryInfo { } } - public struct Response { + public struct Response: Message.Response { public let header: Header public let structureSize: UInt16 public let outputBufferOffset: UInt16 diff --git a/Sources/SMBClient/Messages/Read+CustomDebugStringConvertible.swift b/Sources/SMBClient/Messages/Read+CustomDebugStringConvertible.swift index 5743fd1..8a3e40b 100644 --- a/Sources/SMBClient/Messages/Read+CustomDebugStringConvertible.swift +++ b/Sources/SMBClient/Messages/Read+CustomDebugStringConvertible.swift @@ -32,7 +32,7 @@ extension Read.Response: CustomDebugStringConvertible { Blob Length: \(dataLength) Data Rmaining: \(dataRemaining) Reserved2: \(String(format: "%08x", reserved2)) - Data: \(buffer.hex) + Data: \(buffer) """ } } diff --git a/Sources/SMBClient/Messages/Read.swift b/Sources/SMBClient/Messages/Read.swift index e417607..57b5398 100644 --- a/Sources/SMBClient/Messages/Read.swift +++ b/Sources/SMBClient/Messages/Read.swift @@ -1,7 +1,9 @@ import Foundation public enum Read { - public struct Request { + public struct Request: Message.Request { + public typealias Response = Read.Response + public let header: Header public let structureSize: UInt16 public let padding: UInt8 @@ -72,7 +74,7 @@ public enum Read { } } - public struct Response { + public struct Response: Message.Response { public let header: Header public let structureSize: UInt16 public let dataOffset: UInt8 diff --git a/Sources/SMBClient/Messages/SessionSetup.swift b/Sources/SMBClient/Messages/SessionSetup.swift index 12bfd2b..54abcf2 100644 --- a/Sources/SMBClient/Messages/SessionSetup.swift +++ b/Sources/SMBClient/Messages/SessionSetup.swift @@ -1,7 +1,9 @@ import Foundation public enum SessionSetup { - public struct Request { + public struct Request: Message.Request { + public typealias Response = SessionSetup.Response + public let header: Header public let structureSize: UInt16 public let flags: Flags @@ -60,7 +62,7 @@ public enum SessionSetup { } } - public struct Response { + public struct Response: Message.Response { public let header: Header public let structureSize: UInt16 diff --git a/Sources/SMBClient/Messages/SetInfo.swift b/Sources/SMBClient/Messages/SetInfo.swift index 2b5f78a..6b589f4 100644 --- a/Sources/SMBClient/Messages/SetInfo.swift +++ b/Sources/SMBClient/Messages/SetInfo.swift @@ -1,7 +1,9 @@ import Foundation enum SetInfo { - public struct Request { + public struct Request: Message.Request { + public typealias Response = SetInfo.Response + public let header: Header public let structureSize: UInt16 public let infoType: InfoType @@ -69,7 +71,7 @@ enum SetInfo { } } - public struct Response { + public struct Response: Message.Response { public let header: Header public let structureSize: UInt16 diff --git a/Sources/SMBClient/Messages/TreeConnect.swift b/Sources/SMBClient/Messages/TreeConnect.swift index c25c7a8..40ec0f4 100644 --- a/Sources/SMBClient/Messages/TreeConnect.swift +++ b/Sources/SMBClient/Messages/TreeConnect.swift @@ -1,7 +1,9 @@ import Foundation public enum TreeConnect { - public struct Request { + public struct Request: Message.Request { + public typealias Response = TreeConnect.Response + public let header: Header public let structureSize: UInt16 public let reserved: UInt16 @@ -48,7 +50,7 @@ public enum TreeConnect { } } - public struct Response { + public struct Response: Message.Response { public let header: Header public let structureSize: UInt16 public let shareType: UInt8 diff --git a/Sources/SMBClient/Messages/TreeDisconnect.swift b/Sources/SMBClient/Messages/TreeDisconnect.swift index 2261fbb..ace3b81 100644 --- a/Sources/SMBClient/Messages/TreeDisconnect.swift +++ b/Sources/SMBClient/Messages/TreeDisconnect.swift @@ -1,7 +1,9 @@ import Foundation public enum TreeDisconnect { - public struct Request { + public struct Request: Message.Request { + public typealias Response = TreeDisconnect.Response + public let header: Header public let structureSize: UInt16 public let reserved: UInt16 @@ -37,7 +39,7 @@ public enum TreeDisconnect { } } - public struct Response { + public struct Response: Message.Response { public let header: Header public let structureSize: UInt16 public let reserved: UInt16 diff --git a/Sources/SMBClient/Messages/Write+CustomDebugStringConvertible.swift b/Sources/SMBClient/Messages/Write+CustomDebugStringConvertible.swift new file mode 100644 index 0000000..42937a5 --- /dev/null +++ b/Sources/SMBClient/Messages/Write+CustomDebugStringConvertible.swift @@ -0,0 +1,55 @@ +import Foundation + +extension Write.Request: CustomDebugStringConvertible { + public var debugDescription: String { + """ + \(header) + Write Request (\(String(format: "0x%02x", header.command))) + StructureSize: \(structureSize) + Data Offset: \(dataOffset) + Write Length: \(length) + File Offset: \(offset) + GUID handle: \(fileId.to(type: UUID.self)) + Channel: 0x\(String(format: "%08x", channel)) + Remaining Bytes: \(remainingBytes) + Write Flags: \(flags) + Blob Offset: \(writeChannelInfoOffset) + Blob Length: \(writeChannelInfoLength) + Channel Info Blob: \(buffer.hex) + """ + } +} + +extension Write.Response: CustomDebugStringConvertible { + public var debugDescription: String { + """ + \(header) + Read Response (\(String(format: "0x%02x", header.command))) + StructureSize: \(structureSize) + Reserved: \(String(format: "%04x", reserved)) + Write Count: \(count) + Write Remaining: \(remaining) + Channel Info Offset: \(writeChannelInfoOffset) + Channel Info Length: \(writeChannelInfoLength) + """ + } +} + +extension Write.Flags: CustomDebugStringConvertible { + public var debugDescription: String { + var values = [String]() + + if contains(.writeThrough) { + values.append("Write Through") + } + if contains(.writeUnbuffered) { + values.append("Unbuffered") + } + + if values.isEmpty { + return "0x\(String(format: "%08x", rawValue))" + } else { + return "0x\(String(format: "%08x", rawValue)) (\(values.joined(separator: ", ")))" + } + } +} diff --git a/Sources/SMBClient/Messages/Write.swift b/Sources/SMBClient/Messages/Write.swift index 062f254..624d175 100644 --- a/Sources/SMBClient/Messages/Write.swift +++ b/Sources/SMBClient/Messages/Write.swift @@ -1,7 +1,9 @@ import Foundation public enum Write { - public struct Request { + public struct Request: Message.Request { + public typealias Response = Write.Response + public let header: Header public let structureSize: UInt16 public let dataOffset: UInt16 @@ -69,7 +71,7 @@ public enum Write { } } - public struct Response { + public struct Response: Message.Response { public let header: Header public let structureSize: UInt16 public let reserved: UInt16 diff --git a/Sources/SMBClient/Session.swift b/Sources/SMBClient/Session.swift index f12ba3d..7c2cc3c 100644 --- a/Sources/SMBClient/Session.swift +++ b/Sources/SMBClient/Session.swift @@ -74,8 +74,7 @@ public class Session { dialects: dialects ) - let data = try await send(request.encoded()) - let response = Negotiate.Response(data: data) + let response = try await send(request) maxTransactSize = response.maxTransactSize maxReadSize = response.maxReadSize @@ -105,7 +104,7 @@ public class Session { previousSessionId: 0, securityBuffer: securityBuffer ) - let response = SessionSetup.Response(data: try await send(request.encoded())) + let response = try await send(request) if NTStatus(response.header.status) == .moreProcessingRequired { let challengeMessage = NTLM.ChallengeMessage(data: response.buffer) @@ -129,8 +128,7 @@ public class Session { securityBuffer: authenticateMessage.encoded() ) - let data = try await send(request.encoded()) - let response = SessionSetup.Response(data: data) + let response = try await send(request) sessionId = response.header.sessionId self.signingKey = signingKey @@ -149,8 +147,7 @@ public class Session { sessionId: sessionId ) - let data = try await send(request.encoded()) - let response = Logoff.Response(data: data) + let response = try await send(request) sessionId = 0 @@ -202,8 +199,7 @@ public class Session { path: #"\\\#(server)\\#(path)"# ) - let data = try await send(request.encoded()) - let response = TreeConnect.Response(data: data) + let response = try await send(request) treeId = response.header.treeId connectedTree = path @@ -219,8 +215,7 @@ public class Session { sessionId: sessionId ) - let data = try await send(request.encoded()) - let response = TreeDisconnect.Response(data: data) + let response = try await send(request) treeId = 0 connectedTree = nil @@ -248,8 +243,7 @@ public class Session { name: name ) - let data = try await send(request.encoded()) - return Create.Response(data: data) + return try await send(request) } public func read(fileId: Data, offset: UInt64) async throws -> Read.Response { @@ -270,8 +264,7 @@ public class Session { length: readSize ) - let response = Read.Response(data: try await send(request.encoded())) - return response + return try await send(request) } @discardableResult @@ -294,8 +287,7 @@ public class Session { data: data ) - let response = Write.Response(data: try await send(request.encoded())) - return response + return try await send(request) } @discardableResult @@ -307,8 +299,7 @@ public class Session { fileId: fileId ) - let data = try await send(request.encoded()) - return Close.Response(data: data) + return try await send(request) } public func queryDirectory(path: String, pattern: String) async throws -> [FileDirectoryInformation] { @@ -340,13 +331,7 @@ public class Session { outputBufferLength: outputBufferLength ) - let data = try await send( - createRequest.encoded(), - queryDirectoryRequest.encoded() - ) - - let createResponse = Create.Response(data: data) - let queryDirectoryResponse = QueryDirectory.Response(data: Data(data[createResponse.header.nextCommand...])) + let (createResponse, queryDirectoryResponse) = try await send(createRequest, queryDirectoryRequest) var files: [FileDirectoryInformation] = queryDirectoryResponse.files() @@ -366,8 +351,7 @@ public class Session { outputBufferLength: outputBufferLength ) - let data = try await send(queryDirectoryRequest.encoded()) - let queryDirectoryResponse = QueryDirectory.Response(data: data) + let queryDirectoryResponse = try await send(queryDirectoryRequest) files.append(contentsOf: queryDirectoryResponse.files()) if NTStatus(queryDirectoryResponse.header.status) == .noMoreFiles { @@ -401,13 +385,8 @@ public class Session { fileId: temporaryUUID ) - let data = try await send( - createRequest.encoded(), - closeRequest.encoded() - ) - - let createResponse = Create.Response(data: data) - return createResponse + let (response, _) = try await send(createRequest, closeRequest) + return response } public func existFile(path: String) async throws -> Bool { @@ -467,15 +446,8 @@ public class Session { fileId: temporaryUUID ) - let data = try await send( - createRequest.encoded(), - queryInfoRequest.encoded(), - closeRequest.encoded() - ) - - let createResponse = Create.Response(data: data) - let queryInfoResponse = QueryInfo.Response(data: Data(data[createResponse.header.nextCommand...])) - return queryInfoResponse + let (_, response, _) = try await send(createRequest, queryInfoRequest, closeRequest) + return response } @discardableResult @@ -535,11 +507,7 @@ public class Session { fileId: temporaryUUID ) - _ = try await send( - createRequest.encoded(), - setInfoRequest.encoded(), - closeRequest.encoded() - ) + _ = try await send(createRequest, setInfoRequest, closeRequest) } public func deleteFile(path: String) async throws { @@ -571,11 +539,7 @@ public class Session { fileId: temporaryUUID ) - _ = try await send( - createRequest.encoded(), - setInfoRequest.encoded(), - closeRequest.encoded() - ) + _ = try await send(createRequest, setInfoRequest, closeRequest) } public func move(from: String, to: String) async throws { @@ -607,11 +571,7 @@ public class Session { fileId: temporaryUUID ) - _ = try await send( - createRequest.encoded(), - setInfoRequest.encoded(), - closeRequest.encoded() - ) + _ = try await send(createRequest, setInfoRequest, closeRequest) } @discardableResult @@ -621,8 +581,7 @@ public class Session { sessionId: sessionId ) - let data = try await send(request.encoded()) - return Echo.Response(data: data) + return try await send(request) } @discardableResult @@ -654,8 +613,7 @@ public class Session { output: Data() ) - let data = try await send(request.encoded()) - return IOCtl.Response(data: data) + return try await send(request) } func netShareEnum(fileId: Data) async throws -> IOCtl.Response { @@ -679,12 +637,29 @@ public class Session { output: Data() ) - let data = try await send(request.encoded()) - return IOCtl.Response(data: data) + return try await send(request) + } + + private func send(_ message: Request) async throws -> Request.Response { + let packet = message.encoded() + let data = try await connection.send(sign(packet)) + let response = Request.Response(data: data) + return response + } + + private func send(_ m1: R1, _ m2: R2) async throws -> (R1.Response, R2.Response) { + let data = try await send(m1.encoded(), m2.encoded()) + let r1 = R1.Response(data: data) + let r2 = R2.Response(data: Data(data[r1.header.nextCommand...])) + return (r1, r2) } - private func send(_ packet: Data) async throws -> Data { - try await connection.send(sign(packet)) + private func send(_ m1: R1, _ m2: R2, _ m3: R3) async throws -> (R1.Response, R2.Response, R3.Response) { + let data = try await send(m1.encoded(), m2.encoded(), m3.encoded()) + let r1 = R1.Response(data: data) + let r2 = R2.Response(data: Data(data[r1.header.nextCommand...])) + let r3 = R3.Response(data: Data(data[r2.header.nextCommand...])) + return (r1, r2, r3) } private func send(_ packets: Data...) async throws -> Data {