diff --git a/internal/gen.go b/internal/gen.go index 8ecefbe..fb43f13 100644 --- a/internal/gen.go +++ b/internal/gen.go @@ -77,7 +77,7 @@ func Generate(ctx context.Context, args Args) (*Result, error) { var testingOut []byte if args.ForTesting { - testingOut, err = execTestingTpl(tplData) + testingOut, err = execTestingTpl(tplData, specBcks) if err != nil { return nil, err } @@ -102,7 +102,7 @@ func Generate(ctx context.Context, args Args) (*Result, error) { SpecBackends: specBcks, SelectedNodes: selected.SelectedNodes, TransBackends: selected.TransitiveBackends, - TplData: tplData, + TplData: &tplData, WeldOutput: weldOut, BackendsOutput: bcksOut, TestingOutput: testingOut, diff --git a/internal/gen_test.go b/internal/gen_test.go index 88546d6..b5c6d57 100644 --- a/internal/gen_test.go +++ b/internal/gen_test.go @@ -134,6 +134,12 @@ func TestGenerate(t *testing.T) { Name: "variadic", WorkDir: "example/variadic/state", }, + { + Name: "transitive-testing", + WorkDir: "example/transitive/state", + Tags: "!dev", + ForTesting: true, + }, } for _, test := range tests { diff --git a/internal/template.go b/internal/template.go index 795e836..da7643b 100644 --- a/internal/template.go +++ b/internal/template.go @@ -32,7 +32,7 @@ var ( ) // execWeldTpl returns the generated source of the template data. -func execWeldTpl(data *TplData) ([]byte, error) { +func execWeldTpl(data TplData) ([]byte, error) { var buf bytes.Buffer err := weldTpl.Execute(&buf, data) if err != nil { @@ -53,7 +53,9 @@ func execWeldTpl(data *TplData) ([]byte, error) { return src, nil } -func execTestingTpl(data *TplData) ([]byte, error) { +func execTestingTpl(data TplData, bcks Backends) ([]byte, error) { + data.Deps = filterTransitiveDeps(data, bcks) + var buf bytes.Buffer err := testingTpl.Execute(&buf, data) if err != nil { @@ -74,11 +76,7 @@ func execTestingTpl(data *TplData) ([]byte, error) { return src, nil } -func maybeExecBackendsTpl(tplData *TplData, bcks Backends, genBcks bool) ([]byte, error) { - if !genBcks { - return nil, nil - } - +func filterTransitiveDeps(tplData TplData, bcks Backends) []TplDep { // Remove transitive deps var deps []TplDep for _, dep := range tplData.Deps { @@ -94,11 +92,18 @@ func maybeExecBackendsTpl(tplData *TplData, bcks Backends, genBcks bool) ([]byte } } - clone := *tplData - clone.Deps = deps + return deps +} + +func maybeExecBackendsTpl(tplData TplData, bcks Backends, genBcks bool) ([]byte, error) { + if !genBcks { + return nil, nil + } + + tplData.Deps = filterTransitiveDeps(tplData, bcks) var buf bytes.Buffer - err := bcksTpl.Execute(&buf, clone) + err := bcksTpl.Execute(&buf, tplData) if err != nil { return nil, err } @@ -118,14 +123,14 @@ func maybeExecBackendsTpl(tplData *TplData, bcks Backends, genBcks bool) ([]byte } // makeTplData returns the template data for backends and selected nodes. -func makeTplData(in, out *packages.Package, tags string, selected NodeSelection, specBcks Backends) (*TplData, error) { +func makeTplData(in, out *packages.Package, tags string, selected NodeSelection, specBcks Backends) (TplData, error) { pkgCache := NewPkgCache(in, out) pkgCache.Add(specBcks.Package) unionDeps := union(specBcks, selected.TransitiveBackends) err := sortInDependencyOrder(unionDeps, selected.SelectedNodes, selected.UnselectedTypes) if err != nil { - return nil, errors.Wrap(err, "error sorting in dependency order") + return TplData{}, errors.Wrap(err, "error sorting in dependency order") } // TODO(neil): The deps are now sorted mostly alphabetically, but also in @@ -141,7 +146,7 @@ func makeTplData(in, out *packages.Package, tags string, selected NodeSelection, for _, param := range selected.UnselectedTypes { v, err := type2Param(param) if err != nil { - return nil, err + return TplData{}, err } varMap[param.String()] = v } @@ -161,7 +166,7 @@ func makeTplData(in, out *packages.Package, tags string, selected NodeSelection, for _, dep := range unionDeps { d, err := makeTplDep(pkgCache, selected.SelectedNodes, dep.Getter, dep.Type, varMap) if err != nil { - return nil, err + return TplData{}, err } orig := d.Var @@ -185,20 +190,20 @@ func makeTplData(in, out *packages.Package, tags string, selected NodeSelection, tb, err := makeTplBcks(pkgCache, selected.TransitiveBackends) if err != nil { - return nil, err + return TplData{}, err } bcksTypeRef, err := makeTypeRef(pkgCache, specBcks.Type) if err != nil { - return nil, err + return TplData{}, err } params, err := makeParams(pkgCache, selected.UnselectedTypes) if err != nil { - return nil, err + return TplData{}, err } - return &TplData{ + return TplData{ Package: out.Name, Tags: tags, BackendsName: specBcks.Name, diff --git a/internal/templates/testing.tmpl b/internal/templates/testing.tmpl index 9277bd1..f2478e2 100644 --- a/internal/templates/testing.tmpl +++ b/internal/templates/testing.tmpl @@ -26,7 +26,7 @@ type Testing{{.BackendsName}} struct { {{end}}{{end}} } -{{range .Deps}} +{{range .Deps}}{{if not .IsDuplicate }} func (ti *Testing{{$.BackendsName}}) {{.Getter}}() {{.Type}} { if ti.{{.Var}} != nil { return ti.{{.Var}} @@ -34,19 +34,9 @@ func (ti *Testing{{$.BackendsName}}) {{.Getter}}() {{.Type}} { return ti.{{$.BackendsName}}.{{.Getter}}() } -{{end}} - -{{if .TransBcks -}} -// Transitive dependency interface assertions. -var ( -{{range .TransBcks -}} - _ {{.}} = (*Testing{{$.BackendsName}})(nil) -{{end}} -) -{{end}} - +{{end}}{{end}} -{{range .Deps }} +{{range .Deps }}{{if not .IsDuplicate }} // Set{{.Getter}}ForTesting is a runtime available override for the "{{.Var}}" dependency that should only be used for testing. func (ti *Testing{{$.BackendsName}}) Set{{.Getter}}ForTesting(t *testing.T, {{.Var}} {{.Type}}) { t.Cleanup(func() { @@ -55,4 +45,4 @@ func (ti *Testing{{$.BackendsName}}) Set{{.Getter}}ForTesting(t *testing.T, {{.V ti.{{.Var}} = {{.Var}} } -{{end}} \ No newline at end of file +{{end}}{{end}} \ No newline at end of file diff --git a/internal/testdata/example/transitive/state/testing_gen.go b/internal/testdata/example/transitive/state/testing_gen.go new file mode 100644 index 0000000..93aedcd --- /dev/null +++ b/internal/testdata/example/transitive/state/testing_gen.go @@ -0,0 +1,93 @@ +//go:build !dev + +package state + +// Code generated by weld. DO NOT EDIT. + +import ( + transitive_ops "example/transitive/ops" + "testing" +) + +func NewTestingBackends(b Backends) *TestingBackends { + return &TestingBackends{ + Backends: b, + } +} + +type TestingBackends struct { + Backends + + foo transitive_ops.Foo + qux transitive_ops.Qux + bar transitive_ops.Bar + baz transitive_ops.Baz +} + +func (ti *TestingBackends) Foo() transitive_ops.Foo { + if ti.foo != nil { + return ti.foo + } + + return ti.Backends.Foo() +} + +func (ti *TestingBackends) Qux() transitive_ops.Qux { + if ti.qux != nil { + return ti.qux + } + + return ti.Backends.Qux() +} + +func (ti *TestingBackends) Bar() transitive_ops.Bar { + if ti.bar != nil { + return ti.bar + } + + return ti.Backends.Bar() +} + +func (ti *TestingBackends) Baz() transitive_ops.Baz { + if ti.baz != nil { + return ti.baz + } + + return ti.Backends.Baz() +} + +// SetFooForTesting is a runtime available override for the "foo" dependency that should only be used for testing. +func (ti *TestingBackends) SetFooForTesting(t *testing.T, foo transitive_ops.Foo) { + t.Cleanup(func() { + ti.foo = nil + }) + + ti.foo = foo +} + +// SetQuxForTesting is a runtime available override for the "qux" dependency that should only be used for testing. +func (ti *TestingBackends) SetQuxForTesting(t *testing.T, qux transitive_ops.Qux) { + t.Cleanup(func() { + ti.qux = nil + }) + + ti.qux = qux +} + +// SetBarForTesting is a runtime available override for the "bar" dependency that should only be used for testing. +func (ti *TestingBackends) SetBarForTesting(t *testing.T, bar transitive_ops.Bar) { + t.Cleanup(func() { + ti.bar = nil + }) + + ti.bar = bar +} + +// SetBazForTesting is a runtime available override for the "baz" dependency that should only be used for testing. +func (ti *TestingBackends) SetBazForTesting(t *testing.T, baz transitive_ops.Baz) { + t.Cleanup(func() { + ti.baz = nil + }) + + ti.baz = baz +} diff --git a/internal/testdata/transitive-testing_bcksoutput.golden b/internal/testdata/transitive-testing_bcksoutput.golden new file mode 100644 index 0000000..138b4aa --- /dev/null +++ b/internal/testdata/transitive-testing_bcksoutput.golden @@ -0,0 +1,14 @@ +package state + +// Code generated by weld. DO NOT EDIT. + +import ( + transitive_ops "example/transitive/ops" +) + +type Backends interface { + Foo() transitive_ops.Foo + Qux() transitive_ops.Qux + Bar() transitive_ops.Bar + Baz() transitive_ops.Baz +} diff --git a/internal/testdata/transitive-testing_graph.golden b/internal/testdata/transitive-testing_graph.golden new file mode 100644 index 0000000..72f85df --- /dev/null +++ b/internal/testdata/transitive-testing_graph.golden @@ -0,0 +1,24 @@ +Set[16]: (inline) + Set[12]: var example/backends/providers.WeldProd + Set[6]: var example/backends/providers.GRPC + Set[2]: var example/identity/users/client/grpc.Provider + Func[1]: func example/identity/users/client/grpc.New() (*example/identity/users/client/grpc.client, error) + Bind[1]: example/identity/users.Client(*example/identity/users/client/grpc.client) + Set[2]: var example/identity/email/client/grpc.Provider + Func[1]: func example/identity/email/client/grpc.New() (*example/identity/email/client/grpc.client, error) + Bind[1]: example/identity/email.Client(*example/identity/email/client/grpc.client) + Set[2]: var example/exchange/client/grpc.Provider + Func[1]: func example/exchange/client/grpc.New() (*example/exchange/client/grpc.client, error) + Bind[1]: example/exchange.Client(*example/exchange/client/grpc.client) + Set[3]: var example/backends/providers.DB + Func[1]: func example/identity/email/db.Connect() (*example/identity/email/db.EmailDB, error) + Func[1]: func example/identity/users/db.Connect() (*example/identity/users/db.UsersDB, error) + Func[1]: func example/exchange/db.Connect() (*example/exchange/db.ExchangeDB, error) + Set[3]: var example/backends/providers.External + Func[1]: func example/external/mail.New(opts ...example/external/mail.option) (*example/external/mail.Mailer, error) + Func[1]: func example/external/mail/mail.New() (*example/external/mail/mail.MailerLegacy, error) + Func[1]: func example/external/versioned.New() *example/external/versioned/v1.Service + Func[1]: func example/transitive/ops.NewFoo() (example/transitive/ops.Foo, error) + Func[1]: func example/transitive/ops.NewBar(foo example/transitive/ops.Foo) example/transitive/ops.Bar + Func[1]: func example/transitive/ops.NewBaz(bar example/transitive/ops.Bar, qux example/transitive/ops.Qux) example/transitive/ops.Baz + Func[1]: func example/transitive/ops.NewQux() (example/transitive/ops.Qux, error) diff --git a/internal/testdata/transitive-testing_selected.golden b/internal/testdata/transitive-testing_selected.golden new file mode 100644 index 0000000..9617861 --- /dev/null +++ b/internal/testdata/transitive-testing_selected.golden @@ -0,0 +1,4 @@ +Func[1]: func example/transitive/ops.NewBar(foo example/transitive/ops.Foo) example/transitive/ops.Bar +Func[1]: func example/transitive/ops.NewBaz(bar example/transitive/ops.Bar, qux example/transitive/ops.Qux) example/transitive/ops.Baz +Func[1]: func example/transitive/ops.NewFoo() (example/transitive/ops.Foo, error) +Func[1]: func example/transitive/ops.NewQux() (example/transitive/ops.Qux, error) diff --git a/internal/testdata/transitive-testing_specBack.golden b/internal/testdata/transitive-testing_specBack.golden new file mode 100644 index 0000000..fcf8ba6 --- /dev/null +++ b/internal/testdata/transitive-testing_specBack.golden @@ -0,0 +1 @@ +example/transitive/state.Backends[4]: example/transitive/ops.Bar, example/transitive/ops.Baz, example/transitive/ops.Foo, example/transitive/ops.Qux diff --git a/internal/testdata/transitive-testing_testingoutput.golden b/internal/testdata/transitive-testing_testingoutput.golden new file mode 100644 index 0000000..93aedcd --- /dev/null +++ b/internal/testdata/transitive-testing_testingoutput.golden @@ -0,0 +1,93 @@ +//go:build !dev + +package state + +// Code generated by weld. DO NOT EDIT. + +import ( + transitive_ops "example/transitive/ops" + "testing" +) + +func NewTestingBackends(b Backends) *TestingBackends { + return &TestingBackends{ + Backends: b, + } +} + +type TestingBackends struct { + Backends + + foo transitive_ops.Foo + qux transitive_ops.Qux + bar transitive_ops.Bar + baz transitive_ops.Baz +} + +func (ti *TestingBackends) Foo() transitive_ops.Foo { + if ti.foo != nil { + return ti.foo + } + + return ti.Backends.Foo() +} + +func (ti *TestingBackends) Qux() transitive_ops.Qux { + if ti.qux != nil { + return ti.qux + } + + return ti.Backends.Qux() +} + +func (ti *TestingBackends) Bar() transitive_ops.Bar { + if ti.bar != nil { + return ti.bar + } + + return ti.Backends.Bar() +} + +func (ti *TestingBackends) Baz() transitive_ops.Baz { + if ti.baz != nil { + return ti.baz + } + + return ti.Backends.Baz() +} + +// SetFooForTesting is a runtime available override for the "foo" dependency that should only be used for testing. +func (ti *TestingBackends) SetFooForTesting(t *testing.T, foo transitive_ops.Foo) { + t.Cleanup(func() { + ti.foo = nil + }) + + ti.foo = foo +} + +// SetQuxForTesting is a runtime available override for the "qux" dependency that should only be used for testing. +func (ti *TestingBackends) SetQuxForTesting(t *testing.T, qux transitive_ops.Qux) { + t.Cleanup(func() { + ti.qux = nil + }) + + ti.qux = qux +} + +// SetBarForTesting is a runtime available override for the "bar" dependency that should only be used for testing. +func (ti *TestingBackends) SetBarForTesting(t *testing.T, bar transitive_ops.Bar) { + t.Cleanup(func() { + ti.bar = nil + }) + + ti.bar = bar +} + +// SetBazForTesting is a runtime available override for the "baz" dependency that should only be used for testing. +func (ti *TestingBackends) SetBazForTesting(t *testing.T, baz transitive_ops.Baz) { + t.Cleanup(func() { + ti.baz = nil + }) + + ti.baz = baz +} diff --git a/internal/testdata/transitive-testing_tpldata.golden b/internal/testdata/transitive-testing_tpldata.golden new file mode 100644 index 0000000..3c07348 --- /dev/null +++ b/internal/testdata/transitive-testing_tpldata.golden @@ -0,0 +1,55 @@ +package: state +tags: '!dev' +backendstype: Backends +backendsname: Backends +imports: + example/transitive/ops: + name: transitive_ops + pkgpath: example/transitive/ops + aliased: true +params: [] +deps: +- type: transitive_ops.Foo + var: foo + getter: Foo + isduplicate: false + provider: + funcpkg: transitive_ops + funcname: NewFoo + returnserr: true + params: [] + errwrapmsg: transitive ops new foo +- type: transitive_ops.Qux + var: qux + getter: Qux + isduplicate: false + provider: + funcpkg: transitive_ops + funcname: NewQux + returnserr: true + params: [] + errwrapmsg: transitive ops new qux +- type: transitive_ops.Bar + var: bar + getter: Bar + isduplicate: false + provider: + funcpkg: transitive_ops + funcname: NewBar + returnserr: false + params: + - b.foo + errwrapmsg: transitive ops new bar +- type: transitive_ops.Baz + var: baz + getter: Baz + isduplicate: false + provider: + funcpkg: transitive_ops + funcname: NewBaz + returnserr: false + params: + - b.bar + - b.qux + errwrapmsg: transitive ops new baz +transbcks: [] diff --git a/internal/testdata/transitive-testing_transBack.golden b/internal/testdata/transitive-testing_transBack.golden new file mode 100644 index 0000000..e69de29 diff --git a/internal/testdata/transitive-testing_weldoutput.golden b/internal/testdata/transitive-testing_weldoutput.golden new file mode 100644 index 0000000..93438c5 --- /dev/null +++ b/internal/testdata/transitive-testing_weldoutput.golden @@ -0,0 +1,57 @@ +//go:build !dev + +package state + +// Code generated by weld. DO NOT EDIT. + +import ( + transitive_ops "example/transitive/ops" + + "github.com/luno/jettison/errors" +) + +func MakeBackends() (Backends, error) { + var ( + b backendsImpl + err error + ) + + b.foo, err = transitive_ops.NewFoo() + if err != nil { + return nil, errors.Wrap(err, "transitive ops new foo") + } + + b.qux, err = transitive_ops.NewQux() + if err != nil { + return nil, errors.Wrap(err, "transitive ops new qux") + } + + b.bar = transitive_ops.NewBar(b.foo) + + b.baz = transitive_ops.NewBaz(b.bar, b.qux) + + return &b, nil +} + +type backendsImpl struct { + foo transitive_ops.Foo + qux transitive_ops.Qux + bar transitive_ops.Bar + baz transitive_ops.Baz +} + +func (b *backendsImpl) Foo() transitive_ops.Foo { + return b.foo +} + +func (b *backendsImpl) Qux() transitive_ops.Qux { + return b.qux +} + +func (b *backendsImpl) Bar() transitive_ops.Bar { + return b.bar +} + +func (b *backendsImpl) Baz() transitive_ops.Baz { + return b.baz +}