Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cmd/errtrace: Add toolexec mode for automatic rewriting #90

Merged
merged 14 commits into from
Feb 18, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 79 additions & 41 deletions cmd/errtrace/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ func main() {
Stdin: os.Stdin,
Stderr: os.Stderr,
Stdout: os.Stdout,
Getenv: os.Getenv,
}

os.Exit(cmd.Run(os.Args[1:]))
}

Expand Down Expand Up @@ -98,8 +100,6 @@ func (p *mainParams) Parse(w io.Writer, args []string) error {
flag.BoolVar(&p.List, "l", false,
"list files that would be modified without making any changes.")

// TODO: toolexec mode

if err := flag.Parse(args); err != nil {
return errtrace.Wrap(err)
}
Expand Down Expand Up @@ -176,13 +176,18 @@ type mainCmd struct {
Stdin io.Reader
Stdout io.Writer
Stderr io.Writer
Getenv func(string) string

log *log.Logger
}

func (cmd *mainCmd) Run(args []string) (exitCode int) {
cmd.log = log.New(cmd.Stderr, "", 0)

if exitCode, ok := cmd.handleToolExec(args); ok {
return exitCode
}

var p mainParams
if err := p.Parse(cmd.Stderr, args); err != nil {
if errors.Is(err, flag.ErrHelp) {
Expand Down Expand Up @@ -346,18 +351,65 @@ type fileRequest struct {
// The collected information is used to pick a package name,
// whether we need an import, etc. and *then* the edits are applied.
func (cmd *mainCmd) processFile(r fileRequest) error {
fset := token.NewFileSet()

src, err := cmd.readFile(r)
if err != nil {
return errtrace.Wrap(err)
}

f, err := parser.ParseFile(fset, r.Filename, src, parser.ParseComments)
parsed, err := cmd.parseFile(r.Filename, src)
if err != nil {
return errtrace.Wrap(err)
}

for _, line := range parsed.unusedOptouts {
cmd.log.Printf("%s:%d:unused errtrace:skip", r.Filename, line)
}
if r.List {
if len(parsed.inserts) > 0 {
_, err = fmt.Fprintf(cmd.Stdout, "%s\n", r.Filename)
}
return errtrace.Wrap(err)
}

out := bytes.NewBuffer(nil)
if err := cmd.rewriteFile(parsed, out); err != nil {
return errtrace.Wrap(err)
}

outSrc := out.Bytes()
if r.Format {
outSrc, err = gofmt.Source(outSrc)
if err != nil {
return errtrace.Wrap(fmt.Errorf("format: %w", err))
}
}

if r.Write {
err = os.WriteFile(r.Filename, outSrc, 0o644)
} else {
_, err = cmd.Stdout.Write(outSrc)
}
return errtrace.Wrap(err)
}

type file struct {
src []byte
fset *token.FileSet
file *ast.File

errtracePkg string
importsErrtrace bool
inserts []insert
unusedOptouts []int
}

func (cmd *mainCmd) parseFile(filename string, src []byte) (file, error) {
fset := token.NewFileSet()
f, err := parser.ParseFile(fset, filename, src, parser.ParseComments)
if err != nil {
return file{}, errtrace.Wrap(err)
}

errtracePkg := "errtrace" // name to use for errtrace package
var importsErrtrace bool // whether the file imports errtrace already
for _, imp := range f.Imports {
Expand Down Expand Up @@ -409,25 +461,15 @@ func (cmd *mainCmd) processFile(r fileRequest) error {
ast.Walk(&w, f)

// Look for unused optouts and warn about them.
var unusedOptouts []int
if len(w.optouts) > 0 {
unusedOptouts := make([]int, 0, len(w.optouts))
unusedOptouts = make([]int, 0, len(w.optouts))
for line, used := range w.optouts {
if used == 0 {
unusedOptouts = append(unusedOptouts, line)
}
}
sort.Ints(unusedOptouts)

for _, line := range unusedOptouts {
cmd.log.Printf("%s:%d:unused errtrace:skip", r.Filename, line)
}
}

if r.List {
if len(inserts) > 0 {
_, err = fmt.Fprintf(cmd.Stdout, "%s\n", r.Filename)
}
return errtrace.Wrap(err)
}

// If errtrace isn't imported, but at least one insert was made,
Expand Down Expand Up @@ -487,13 +529,23 @@ func (cmd *mainCmd) processFile(r fileRequest) error {
return inserts[i].Pos() < inserts[j].Pos()
})

out := bytes.NewBuffer(nil)
return file{
src: src,
fset: fset,
file: f,
errtracePkg: errtracePkg,
importsErrtrace: importsErrtrace,
inserts: inserts,
unusedOptouts: unusedOptouts,
}, nil
}

func (cmd *mainCmd) rewriteFile(f file, out *bytes.Buffer) error {
var lastOffset int
filePos := fset.File(f.Pos()) // position information for this file
for _, it := range inserts {
filePos := f.fset.File(f.file.Pos()) // position information for this file
for _, it := range f.inserts {
offset := filePos.Offset(it.Pos())
_, _ = out.Write(src[lastOffset:offset])
_, _ = out.Write(f.src[lastOffset:offset])
lastOffset = offset

switch it := it.(type) {
Expand All @@ -503,15 +555,15 @@ func (cmd *mainCmd) processFile(r fileRequest) error {
_, _ = io.WriteString(out, "import ")
}

if errtracePkg == "errtrace" {
if f.errtracePkg == "errtrace" {
// Don't use named imports if we're using the default name.
fmt.Fprintf(out, "%q", "braces.dev/errtrace")
} else {
fmt.Fprintf(out, "%s %q", errtracePkg, "braces.dev/errtrace")
fmt.Fprintf(out, "%s %q", f.errtracePkg, "braces.dev/errtrace")
}

case *insertWrapOpen:
fmt.Fprintf(out, "%s.Wrap", errtracePkg)
fmt.Fprintf(out, "%s.Wrap", f.errtracePkg)
if it.N > 1 {
fmt.Fprintf(out, "%d", it.N)
}
Expand All @@ -536,30 +588,16 @@ func (cmd *mainCmd) processFile(r fileRequest) error {
if i > 0 {
_, _ = out.WriteString(", ")
}
fmt.Fprintf(out, "%s.Wrap(%s)", errtracePkg, name)
fmt.Fprintf(out, "%s.Wrap(%s)", f.errtracePkg, name)
}
_, _ = out.WriteString("; ")

default:
cmd.log.Panicf("unhandled insertion type %T", it)
}
}
_, _ = out.Write(src[lastOffset:]) // flush remaining

outSrc := out.Bytes()
if r.Format {
outSrc, err = gofmt.Source(outSrc)
if err != nil {
return errtrace.Wrap(fmt.Errorf("format: %w", err))
}
}

if r.Write {
err = os.WriteFile(r.Filename, outSrc, 0o644)
} else {
_, err = cmd.Stdout.Write(outSrc)
}
return errtrace.Wrap(err)
_, _ = out.Write(f.src[lastOffset:]) // flush remaining
return nil
}

func (cmd *mainCmd) readFile(r fileRequest) ([]byte, error) {
Expand Down
7 changes: 7 additions & 0 deletions cmd/errtrace/testdata/main/foo/foo.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package foo

import "errors"

func Foo() error {
return errors.New("test")
}
5 changes: 5 additions & 0 deletions cmd/errtrace/testdata/main/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
module braces.dev/errtrace/cmd/errtrace/testdata/main

go 1.21.4

require braces.dev/errtrace v0.3.0
2 changes: 2 additions & 0 deletions cmd/errtrace/testdata/main/go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
braces.dev/errtrace v0.3.0 h1:pzfd6LcWgfWtXLaNFWRnxV/7NP+FSOlIjRLwDuHfPxs=
braces.dev/errtrace v0.3.0/go.mod h1:YQpXdo+u5iimgQdZzFoic8AjedEDncXGpp6/2SfazzI=
13 changes: 13 additions & 0 deletions cmd/errtrace/testdata/main/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package main

import (
"fmt"

"braces.dev/errtrace/cmd/errtrace/testdata/main/foo"
)

func main() {
if err := foo.Foo(); err != nil {
fmt.Printf("%+v\n", err)
}
}
4 changes: 4 additions & 0 deletions cmd/errtrace/testdata/toolexec-test/errtrace.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package main

// Opt-in to errtrace wrapping with toolexec.
import _ "braces.dev/errtrace"
17 changes: 17 additions & 0 deletions cmd/errtrace/testdata/toolexec-test/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package main

import (
"fmt"

"braces.dev/errtrace/cmd/errtrace/testdata/toolexec-test/p1"
)

func main() {
if err := callP1(); err != nil {
fmt.Printf("%+v\n", err)
}
}

func callP1() error {
return p1.WrapP2() // @trace
}
12 changes: 12 additions & 0 deletions cmd/errtrace/testdata/toolexec-test/p1/p1.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package p1

import (
"fmt"

"braces.dev/errtrace/cmd/errtrace/testdata/toolexec-test/p2"
)

// WrapP2 wraps an error return from p2.
func WrapP2() error {
return fmt.Errorf("test2: %w", p2.ReturnErr())
}
4 changes: 4 additions & 0 deletions cmd/errtrace/testdata/toolexec-test/p2/errtrace.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package p2

// Opt-in to errtrace wrapping with toolexec.
import _ "braces.dev/errtrace"
12 changes: 12 additions & 0 deletions cmd/errtrace/testdata/toolexec-test/p2/p2.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package p2

import (
"errors"

"braces.dev/errtrace"
)

// ReturnErr returns an error.
func ReturnErr() error {
return errtrace.Wrap(errors.New("test")) // @trace
}
Loading
Loading