Skip to content

Commit

Permalink
cmd/errtrace: Add toolexec mode for automatic rewriting (#90)
Browse files Browse the repository at this point in the history
Fixes #17

The go command supports `-toolexec` to intercept calls to the underlying
compile/link tools. By intercepting compile commands, we can modify the
 .go files passed to it to rewrite as part of the build process.

There are some limitations: we can't add new dependencies to a package,
so this initial version only rewrites packages that already import errtrace,
acting as an opt-in to rewriting.

We automatically determine if we're in toolexec mode based on the
arguments/environment to simplify usage,
```
# pass -toolexec=errtrace, can use absolute paths if errtrace is not in PATH
$ go build -toolexec=errtrace pkg/to/build

# also compatible with go run, which is used by tests
$ go run -toolexec=errtrace pkg/to/run
```
  • Loading branch information
prashantv authored Feb 18, 2024
1 parent bf015bb commit 2fb2821
Show file tree
Hide file tree
Showing 13 changed files with 466 additions and 41 deletions.
120 changes: 79 additions & 41 deletions cmd/errtrace/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,9 @@ func main() {
Stdin: os.Stdin,
Stderr: os.Stderr,
Stdout: os.Stdout,
Getenv: os.Getenv,
}

os.Exit(cmd.Run(os.Args[1:]))
}

Expand Down Expand Up @@ -98,8 +100,6 @@ func (p *mainParams) Parse(w io.Writer, args []string) error {
flag.BoolVar(&p.List, "l", false,
"list files that would be modified without making any changes.")

// TODO: toolexec mode

if err := flag.Parse(args); err != nil {
return errtrace.Wrap(err)
}
Expand Down Expand Up @@ -176,13 +176,18 @@ type mainCmd struct {
Stdin io.Reader
Stdout io.Writer
Stderr io.Writer
Getenv func(string) string

log *log.Logger
}

func (cmd *mainCmd) Run(args []string) (exitCode int) {
cmd.log = log.New(cmd.Stderr, "", 0)

if exitCode, ok := cmd.handleToolExec(args); ok {
return exitCode
}

var p mainParams
if err := p.Parse(cmd.Stderr, args); err != nil {
if errors.Is(err, flag.ErrHelp) {
Expand Down Expand Up @@ -346,18 +351,65 @@ type fileRequest struct {
// The collected information is used to pick a package name,
// whether we need an import, etc. and *then* the edits are applied.
func (cmd *mainCmd) processFile(r fileRequest) error {
fset := token.NewFileSet()

src, err := cmd.readFile(r)
if err != nil {
return errtrace.Wrap(err)
}

f, err := parser.ParseFile(fset, r.Filename, src, parser.ParseComments)
parsed, err := cmd.parseFile(r.Filename, src)
if err != nil {
return errtrace.Wrap(err)
}

for _, line := range parsed.unusedOptouts {
cmd.log.Printf("%s:%d:unused errtrace:skip", r.Filename, line)
}
if r.List {
if len(parsed.inserts) > 0 {
_, err = fmt.Fprintf(cmd.Stdout, "%s\n", r.Filename)
}
return errtrace.Wrap(err)
}

var out bytes.Buffer
if err := cmd.rewriteFile(parsed, &out); err != nil {
return errtrace.Wrap(err)
}

outSrc := out.Bytes()
if r.Format {
outSrc, err = gofmt.Source(outSrc)
if err != nil {
return errtrace.Wrap(fmt.Errorf("format: %w", err))
}
}

if r.Write {
err = os.WriteFile(r.Filename, outSrc, 0o644)
} else {
_, err = cmd.Stdout.Write(outSrc)
}
return errtrace.Wrap(err)
}

type parsedFile struct {
src []byte
fset *token.FileSet
file *ast.File

errtracePkg string
importsErrtrace bool
inserts []insert
unusedOptouts []int // list of line numbers
}

func (cmd *mainCmd) parseFile(filename string, src []byte) (parsedFile, error) {
fset := token.NewFileSet()
f, err := parser.ParseFile(fset, filename, src, parser.ParseComments)
if err != nil {
return parsedFile{}, errtrace.Wrap(err)
}

errtracePkg := "errtrace" // name to use for errtrace package
var importsErrtrace bool // whether the file imports errtrace already
for _, imp := range f.Imports {
Expand Down Expand Up @@ -409,25 +461,15 @@ func (cmd *mainCmd) processFile(r fileRequest) error {
ast.Walk(&w, f)

// Look for unused optouts and warn about them.
var unusedOptouts []int
if len(w.optouts) > 0 {
unusedOptouts := make([]int, 0, len(w.optouts))
unusedOptouts = make([]int, 0, len(w.optouts))
for line, used := range w.optouts {
if used == 0 {
unusedOptouts = append(unusedOptouts, line)
}
}
sort.Ints(unusedOptouts)

for _, line := range unusedOptouts {
cmd.log.Printf("%s:%d:unused errtrace:skip", r.Filename, line)
}
}

if r.List {
if len(inserts) > 0 {
_, err = fmt.Fprintf(cmd.Stdout, "%s\n", r.Filename)
}
return errtrace.Wrap(err)
}

// If errtrace isn't imported, but at least one insert was made,
Expand Down Expand Up @@ -487,13 +529,23 @@ func (cmd *mainCmd) processFile(r fileRequest) error {
return inserts[i].Pos() < inserts[j].Pos()
})

out := bytes.NewBuffer(nil)
return parsedFile{
src: src,
fset: fset,
file: f,
errtracePkg: errtracePkg,
importsErrtrace: importsErrtrace,
inserts: inserts,
unusedOptouts: unusedOptouts,
}, nil
}

func (cmd *mainCmd) rewriteFile(f parsedFile, out *bytes.Buffer) error {
var lastOffset int
filePos := fset.File(f.Pos()) // position information for this file
for _, it := range inserts {
filePos := f.fset.File(f.file.Pos()) // position information for this file
for _, it := range f.inserts {
offset := filePos.Offset(it.Pos())
_, _ = out.Write(src[lastOffset:offset])
_, _ = out.Write(f.src[lastOffset:offset])
lastOffset = offset

switch it := it.(type) {
Expand All @@ -503,15 +555,15 @@ func (cmd *mainCmd) processFile(r fileRequest) error {
_, _ = io.WriteString(out, "import ")
}

if errtracePkg == "errtrace" {
if f.errtracePkg == "errtrace" {
// Don't use named imports if we're using the default name.
fmt.Fprintf(out, "%q", "braces.dev/errtrace")
} else {
fmt.Fprintf(out, "%s %q", errtracePkg, "braces.dev/errtrace")
fmt.Fprintf(out, "%s %q", f.errtracePkg, "braces.dev/errtrace")
}

case *insertWrapOpen:
fmt.Fprintf(out, "%s.Wrap", errtracePkg)
fmt.Fprintf(out, "%s.Wrap", f.errtracePkg)
if it.N > 1 {
fmt.Fprintf(out, "%d", it.N)
}
Expand All @@ -536,30 +588,16 @@ func (cmd *mainCmd) processFile(r fileRequest) error {
if i > 0 {
_, _ = out.WriteString(", ")
}
fmt.Fprintf(out, "%s.Wrap(%s)", errtracePkg, name)
fmt.Fprintf(out, "%s.Wrap(%s)", f.errtracePkg, name)
}
_, _ = out.WriteString("; ")

default:
cmd.log.Panicf("unhandled insertion type %T", it)
}
}
_, _ = out.Write(src[lastOffset:]) // flush remaining

outSrc := out.Bytes()
if r.Format {
outSrc, err = gofmt.Source(outSrc)
if err != nil {
return errtrace.Wrap(fmt.Errorf("format: %w", err))
}
}

if r.Write {
err = os.WriteFile(r.Filename, outSrc, 0o644)
} else {
_, err = cmd.Stdout.Write(outSrc)
}
return errtrace.Wrap(err)
_, _ = out.Write(f.src[lastOffset:]) // flush remaining
return nil
}

func (cmd *mainCmd) readFile(r fileRequest) ([]byte, error) {
Expand Down
7 changes: 7 additions & 0 deletions cmd/errtrace/testdata/main/foo/foo.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
package foo

import "errors"

func Foo() error {
return errors.New("test")
}
5 changes: 5 additions & 0 deletions cmd/errtrace/testdata/main/go.mod
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
module braces.dev/errtrace/cmd/errtrace/testdata/main

go 1.21.4

require braces.dev/errtrace v0.3.0
2 changes: 2 additions & 0 deletions cmd/errtrace/testdata/main/go.sum
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
braces.dev/errtrace v0.3.0 h1:pzfd6LcWgfWtXLaNFWRnxV/7NP+FSOlIjRLwDuHfPxs=
braces.dev/errtrace v0.3.0/go.mod h1:YQpXdo+u5iimgQdZzFoic8AjedEDncXGpp6/2SfazzI=
13 changes: 13 additions & 0 deletions cmd/errtrace/testdata/main/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package main

import (
"fmt"

"braces.dev/errtrace/cmd/errtrace/testdata/main/foo"
)

func main() {
if err := foo.Foo(); err != nil {
fmt.Printf("%+v\n", err)
}
}
4 changes: 4 additions & 0 deletions cmd/errtrace/testdata/toolexec-test/errtrace.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package main

// Opt-in to errtrace wrapping with toolexec.
import _ "braces.dev/errtrace"
17 changes: 17 additions & 0 deletions cmd/errtrace/testdata/toolexec-test/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package main

import (
"fmt"

"braces.dev/errtrace/cmd/errtrace/testdata/toolexec-test/p1"
)

func main() {
if err := callP1(); err != nil {
fmt.Printf("%+v\n", err)
}
}

func callP1() error {
return p1.WrapP2() // @trace
}
12 changes: 12 additions & 0 deletions cmd/errtrace/testdata/toolexec-test/p1/p1.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package p1

import (
"fmt"

"braces.dev/errtrace/cmd/errtrace/testdata/toolexec-test/p2"
)

// WrapP2 wraps an error return from p2.
func WrapP2() error {
return fmt.Errorf("test2: %w", p2.CallP3())
}
12 changes: 12 additions & 0 deletions cmd/errtrace/testdata/toolexec-test/p2/p2.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
package p2

import (
"braces.dev/errtrace"

"braces.dev/errtrace/cmd/errtrace/testdata/toolexec-test/p3"
)

// CallP3 calls p3, and wraps the error.
func CallP3() error {
return errtrace.Wrap(p3.ReturnErr()) // @trace
}
4 changes: 4 additions & 0 deletions cmd/errtrace/testdata/toolexec-test/p3/errtrace.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
package p3

// Opt-in to errtrace wrapping with toolexec.
import _ "braces.dev/errtrace"
10 changes: 10 additions & 0 deletions cmd/errtrace/testdata/toolexec-test/p3/p3.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package p3

import (
"errors"
)

// ReturnErr returns an error.
func ReturnErr() error {
return errors.New("test") // @trace
}
Loading

0 comments on commit 2fb2821

Please sign in to comment.