diff --git a/Sources/SMBClient/FileReader.swift b/Sources/SMBClient/FileReader.swift index 4b54c3b..25beef2 100644 --- a/Sources/SMBClient/FileReader.swift +++ b/Sources/SMBClient/FileReader.swift @@ -62,6 +62,42 @@ public class FileReader { return buffer } + public func download(toLocalFile localFile: URL, overwrite: Bool, progressHandler: (_ progress: Double) -> Void) async throws { + let fileProxy = try await fileProxy() + + var offset: UInt64 = 0 + + let fileManger = FileManager.default + let filePath = localFile.path + let fileExists = fileManger.fileExists(atPath: filePath) + // If file does not exist, create an empty file so we can create a FileHandle + if !fileExists { + try Data().write(to: localFile) + } + guard let fileHandle = FileHandle(forWritingAtPath: filePath) else { + throw URLError(.cannotWriteToFile) + } + // If file did already exist and we are not overriding, throw an error + if fileExists, !overwrite { + throw CocoaError(.fileWriteFileExists) + } + defer { + fileHandle.closeFile() + } + var response: Read.Response + repeat { + response = try await session.read( + fileId: fileProxy.id, + offset: offset + ) + fileHandle.seekToEndOfFile() + fileHandle.write(response.buffer) + offset += UInt64(response.buffer.count) + let progress = Double(offset) / Double(fileProxy.size) + progressHandler(progress) + } while NTStatus(response.header.status) != .endOfFile && offset < fileProxy.size + } + public func close() async throws { if let createResponse { try await session.close(fileId: createResponse.fileId) diff --git a/Sources/SMBClient/SMBClient.swift b/Sources/SMBClient/SMBClient.swift index bb7cc74..5aeccbc 100644 --- a/Sources/SMBClient/SMBClient.swift +++ b/Sources/SMBClient/SMBClient.swift @@ -125,6 +125,12 @@ public class SMBClient { return data } + public func download(path: String, localPath: URL, overwrite: Bool = false, progressHandler: (_ progress: Double) -> Void = { _ in }) async throws { + let fileReader = fileReader(path: Pathname.normalize(path)) + try await fileReader.download(toLocalFile: localPath, overwrite: overwrite, progressHandler: progressHandler) + try await fileReader.close() + } + public func upload(content: Data, path: String) async throws { try await upload(content: content, path: Pathname.normalize(path), progressHandler: { _ in }) } diff --git a/Tests/SMBClientTests/SMBClientTests.swift b/Tests/SMBClientTests/SMBClientTests.swift index 42d8629..84048b9 100644 --- a/Tests/SMBClientTests/SMBClientTests.swift +++ b/Tests/SMBClientTests/SMBClientTests.swift @@ -379,6 +379,28 @@ final class SMBClientTests: XCTestCase { XCTAssertEqual(data, try Data(contentsOf: fixtureURL.appending(component: "\(user.sharePath)/\(path)"))) } + func testDownloadIntoFile() async throws { + let user = bob + let client = SMBClient(host: "localhost", port: 4445) + try await client.login(username: user.username, password: user.password) + try await client.connectShare(user.share) + + let path = "test_files/file_example_JPG_1MB.jpg" + + var progressWasUpdated: Bool = false + let fileManager = FileManager.default + let tempFolder = URL(fileURLWithPath: NSTemporaryDirectory(), isDirectory: true) + let destinationFile = tempFolder.appending(path: "downloadedfile.jpg", directoryHint: .notDirectory) + try await client.download(path: path, localPath: destinationFile, overwrite: true) { progress in + progressWasUpdated = true + } + + XCTAssertTrue(fileManager.fileExists(atPath: destinationFile.path)) + let data = try Data(contentsOf: destinationFile) + XCTAssertEqual(data, try Data(contentsOf: fixtureURL.appending(component: "\(user.sharePath)/\(path)"))) + XCTAssertTrue(progressWasUpdated) + } + func testRandomRead01() async throws { let user = bob let client = SMBClient(host: "localhost", port: 4445)