diff --git a/Sources/EnumeratorMacroImpl/EnumeratorMacroType.swift b/Sources/EnumeratorMacroImpl/EnumeratorMacroType.swift index ac2cdd3..ad7a85a 100644 --- a/Sources/EnumeratorMacroImpl/EnumeratorMacroType.swift +++ b/Sources/EnumeratorMacroImpl/EnumeratorMacroType.swift @@ -82,8 +82,7 @@ extension EnumeratorMacroType: MemberMacro { guard let rendered else { return nil } - let noEmptyLines = rendered.split(separator: "\n").joined(separator: "\n") - return (noEmptyLines, syntax) + return (rendered, syntax) } catch { let message: MacroError let errorSyntax: SyntaxProtocol @@ -168,11 +167,17 @@ extension EnumeratorMacroType: MemberMacro { ($0, result.codeSyntax) } } - let postProcessedSyntaxes = syntaxes.compactMap { + let postProcessedSyntaxes = syntaxes.compactMap { (syntax, codeSyntax) -> DeclSyntax? in - let postProcessor = PostProcessor() - let newSyntax = postProcessor.rewrite(syntax) - guard let declSyntax = DeclSyntax(newSyntax) else { + var processedSyntax = Syntax(syntax) + + let excessiveTriviaRemover = ExcessiveTriviaRemover() + processedSyntax = excessiveTriviaRemover.rewrite(processedSyntax) + + let switchRewriter = SwitchRewriter() + processedSyntax = switchRewriter.rewrite(processedSyntax) + + guard let declSyntax = DeclSyntax(processedSyntax) else { context.diagnose( Diagnostic( node: codeSyntax, diff --git a/Sources/EnumeratorMacroImpl/ExcessiveTriviaRemover.swift b/Sources/EnumeratorMacroImpl/ExcessiveTriviaRemover.swift new file mode 100644 index 0000000..2b30a6e --- /dev/null +++ b/Sources/EnumeratorMacroImpl/ExcessiveTriviaRemover.swift @@ -0,0 +1,45 @@ +import SwiftSyntax + +final class ExcessiveTriviaRemover: SyntaxRewriter { + /// Remove empty lines if there are more than 1 lines stacked together. + override func visitAny(_ node: Syntax) -> Syntax? { + var node = node + + var modifiedLeadingTrivia = false + let newLeadingTrivia = node.leadingTrivia.pieces.map { piece in + if case let .newlines(count) = piece, + count > 1 { + modifiedLeadingTrivia = true + return TriviaPiece.newlines(1) + } else { + return piece + } + } + if modifiedLeadingTrivia { + node = node.with( + \.leadingTrivia, + Trivia(pieces: newLeadingTrivia) + ) + } + + var modifiedTrailingTrivia = false + let newTrailingTrivia = node.trailingTrivia.pieces.map { piece in + if case let .newlines(count) = piece, + count > 1 { + modifiedTrailingTrivia = true + return TriviaPiece.newlines(1) + } else { + return piece + } + } + if modifiedTrailingTrivia { + node = node.with( + \.trailingTrivia, + Trivia(pieces: newTrailingTrivia) + ) + } + + let modified = modifiedLeadingTrivia || modifiedTrailingTrivia + return modified ? self.rewrite(node) : nil + } +} diff --git a/Sources/EnumeratorMacroImpl/PostProcessor.swift b/Sources/EnumeratorMacroImpl/SwitchRewriter.swift similarity index 99% rename from Sources/EnumeratorMacroImpl/PostProcessor.swift rename to Sources/EnumeratorMacroImpl/SwitchRewriter.swift index 8befe98..ea1dbb7 100644 --- a/Sources/EnumeratorMacroImpl/PostProcessor.swift +++ b/Sources/EnumeratorMacroImpl/SwitchRewriter.swift @@ -1,6 +1,6 @@ import SwiftSyntax -final class PostProcessor: SyntaxRewriter { +final class SwitchRewriter: SyntaxRewriter { override func visit(_ node: SwitchCaseSyntax) -> SwitchCaseSyntax { self.removeUnusedLet( self.removeUnusedArguments( diff --git a/Tests/EnumeratorMacroTests/EnumeratorMacroTests.swift b/Tests/EnumeratorMacroTests/EnumeratorMacroTests.swift index c491951..79c82fa 100644 --- a/Tests/EnumeratorMacroTests/EnumeratorMacroTests.swift +++ b/Tests/EnumeratorMacroTests/EnumeratorMacroTests.swift @@ -5,7 +5,7 @@ import SwiftSyntaxMacrosTestSupport import XCTest final class EnumeratorMacroTests: XCTestCase { - func testCreatesCaseName() throws { + func testCreatesCaseName() { assertMacroExpansion( #""" @Enumerator( @@ -15,9 +15,6 @@ final class EnumeratorMacroTests: XCTestCase { {{#cases}} case .{{name}}: "{{name}}" - - - {{/cases}} } } @@ -47,7 +44,7 @@ final class EnumeratorMacroTests: XCTestCase { ) } - func testCreatesACopyOfSelf() throws { + func testCreatesACopyOfSelf() { assertMacroExpansion( #""" @Enumerator(""" @@ -80,7 +77,7 @@ final class EnumeratorMacroTests: XCTestCase { ) } - func testCreatesDeclarationsForCaseChecking() throws { + func testCreatesDeclarationsForCaseChecking() { assertMacroExpansion( #""" @Enumerator(""" @@ -139,7 +136,7 @@ final class EnumeratorMacroTests: XCTestCase { ) } - func testCreatesSubtypeWithMultiMacroArguments() throws { + func testCreatesSubtypeWithMultiMacroArguments() { assertMacroExpansion( #""" @Enumerator(""" @@ -193,7 +190,7 @@ final class EnumeratorMacroTests: XCTestCase { ) } - func testCreatesGetCaseValueFunctions() throws { + func testCreatesGetCaseValueFunctions() { assertMacroExpansion( #""" @Enumerator(""" @@ -245,7 +242,7 @@ final class EnumeratorMacroTests: XCTestCase { ) } - func testProperlyReadsComments() throws { + func testProperlyReadsComments() { assertMacroExpansion( #""" @Enumerator(""" @@ -298,7 +295,50 @@ final class EnumeratorMacroTests: XCTestCase { ) } - func testDiagnosesNotAnEnum() throws { + func removesExcessiveTrivia() { + assertMacroExpansion( + #""" + @Enumerator( + """ + var caseName: String { + switch self { + {{#cases}} + case .{{name}}: + "{{name}}" + + + + {{/cases}} + } + } + """ + ) + enum TestEnum { + case a + case b + } + """#, + expandedSource: #""" + enum TestEnum { + case a + case b + + var caseName: String { + switch self { + case .a: + "a" + case .b: + "b" + } + } + } + """#, + macros: EnumeratorMacroEntryPoint.macros + ) + } + + + func testDiagnosesNotAnEnum() { assertMacroExpansion( #""" @Enumerator(""" @@ -333,7 +373,7 @@ final class EnumeratorMacroTests: XCTestCase { ) } - func testDiagnosesNoArguments() throws { + func testDiagnosesNoArguments() { assertMacroExpansion( #""" @Enumerator @@ -362,7 +402,7 @@ final class EnumeratorMacroTests: XCTestCase { ) } - func testDiagnosesEmptyArguments() throws { + func testDiagnosesEmptyArguments() { assertMacroExpansion( #""" @Enumerator @@ -391,7 +431,7 @@ final class EnumeratorMacroTests: XCTestCase { ) } - func testDiagnosesUnacceptableArguments() throws { + func testDiagnosesUnacceptableArguments() { assertMacroExpansion( #""" @Enumerator(myVariable) @@ -420,7 +460,7 @@ final class EnumeratorMacroTests: XCTestCase { ) } - func testDiagnosesStringInterpolationInMustacheTemplate() throws { + func testDiagnosesStringInterpolationInMustacheTemplate() { assertMacroExpansion( #""" @Enumerator(""" @@ -474,7 +514,7 @@ final class EnumeratorMacroTests: XCTestCase { ) } - func testDiagnosesBadMustacheTemplate() throws { + func testDiagnosesBadMustacheTemplate() { assertMacroExpansion( #""" @Enumerator(""" @@ -513,7 +553,7 @@ final class EnumeratorMacroTests: XCTestCase { ) } - func testDiagnosesErroneousSwiftCode() throws { + func testDiagnosesErroneousSwiftCode() { assertMacroExpansion( #""" @Enumerator(""" @@ -640,7 +680,7 @@ final class EnumeratorMacroTests: XCTestCase { ) } - func testRemovesUnusedLetInSwitchStatements() throws { + func testRemovesUnusedLetInSwitchStatements() { assertMacroExpansion( #""" @Enumerator(""" @@ -679,7 +719,7 @@ final class EnumeratorMacroTests: XCTestCase { ) } - func testRemovesArgumentInSwitchStatements() throws { + func testRemovesArgumentInSwitchStatements() { assertMacroExpansion( #""" @Enumerator(""" @@ -718,7 +758,7 @@ final class EnumeratorMacroTests: XCTestCase { ) } - func testRemovesArgumentInSwitchStatementsWithMultipleArgumentsWhereOneArgIsUsed() throws { + func testRemovesArgumentInSwitchStatementsWithMultipleArgumentsWhereOneArgIsUsed() { assertMacroExpansion( #""" @Enumerator(""" @@ -757,7 +797,7 @@ final class EnumeratorMacroTests: XCTestCase { ) } -// func testAppliesFixIts() throws { +// func testAppliesFixIts() { // assertMacroExpansion( // #""" // @Enumerator("""