Skip to content

Commit

Permalink
feedback from PR
Browse files Browse the repository at this point in the history
  • Loading branch information
anpep committed Feb 24, 2025
1 parent 4dfa7ee commit 900d0c7
Show file tree
Hide file tree
Showing 5 changed files with 94 additions and 86 deletions.
58 changes: 26 additions & 32 deletions internals/overlord/servstate/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"io"
"maps"
"os"
"os/exec"
"os/user"
Expand Down Expand Up @@ -116,19 +117,15 @@ func (m *ServiceManager) doStart(task *state.Task, tomb *tomb.Tomb) error {
}

currentPlan := m.getPlan()
config, ok := currentPlan.Services[request.Name]
if !ok {
config, configFound := currentPlan.Services[request.Name]
if !configFound {
return fmt.Errorf("cannot find service %q in plan", request.Name)
}

var workload *Workload
if s, ok := currentPlan.Sections[WorkloadsField]; ok {
ws, ok := s.(*WorkloadsSection)
if !ok {
return fmt.Errorf("internal error: invalid section type %T", ws)
}
workload, ok = ws.Entries[config.Workload]
if config.Workload != "" && !ok {
if config.Workload != "" {
ws := currentPlan.Sections[WorkloadsField].(*WorkloadsSection)
if workload = ws.Entries[config.Workload]; workload == nil {
return fmt.Errorf("cannot find workload %q for service %q in plan", config.Workload, request.Name)
}
}
Expand Down Expand Up @@ -184,30 +181,29 @@ func (m *ServiceManager) serviceForStart(config *plan.Service, workload *Workloa
m.servicesLock.Lock()
defer m.servicesLock.Unlock()

var w *Workload
if workload != nil {
w = workload.copy()
}

service = m.services[config.Name]
if service == nil {
// Not already started, create a new service object.
service = &serviceData{
manager: m,
state: stateInitial,
config: config.Copy(),
workload: w,
logs: servicelog.NewRingBuffer(maxLogBytes),
started: make(chan error, 1),
stopped: make(chan error, 2), // enough for killTimeElapsed to send, and exit if it happens after
manager: m,
state: stateInitial,
logs: servicelog.NewRingBuffer(maxLogBytes),
started: make(chan error, 1),
stopped: make(chan error, 2), // enough for killTimeElapsed to send, and exit if it happens after
}
service.config = config.Copy()
if workload != nil {
service.workload = workload.copy()
}
m.services[config.Name] = service
return service, ""
}

// Ensure config is up-to-date from the plan whenever the user starts a service.
service.config = config.Copy()
service.workload = w
if workload != nil {
service.workload = workload.copy()
}

switch service.state {
case stateInitial, stateStarting, stateRunning:
Expand Down Expand Up @@ -358,17 +354,21 @@ func (s *serviceData) startInternal() error {
s.cmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}

// Copy environment to avoid updating original.
environment := make(map[string]string, len(s.config.Environment))
for k, v := range s.config.Environment {
environment[k] = v
var environment map[string]string
if s.workload != nil && len(s.workload.Environment) > 0 {
environment = maps.Clone(s.workload.Environment)
} else if len(s.config.Environment) > 0 {
environment = maps.Clone(s.config.Environment)
} else {
environment = make(map[string]string)
}

s.cmd.Dir = s.config.WorkingDir

// Start as another user if specified in plan.
var uid, gid *int
if s.config.UserID != nil || s.config.GroupID != nil || s.config.User != "" || s.config.Group != "" {
// User/group config from the service takes precedence
// User/group config from the service takes precedence if any of them are set
uid, gid, err = osutil.NormalizeUidGid(s.config.UserID, s.config.GroupID, s.config.User, s.config.Group)
if err != nil {
return err
Expand Down Expand Up @@ -408,12 +408,6 @@ func (s *serviceData) startInternal() error {
}
}

if s.workload != nil && len(s.workload.Environment) != 0 {
for k, v := range s.workload.Environment {
environment[k] = v
}
}

// Pass service description's environment variables to child process.
s.cmd.Env = os.Environ()
for k, v := range environment {
Expand Down
2 changes: 2 additions & 0 deletions internals/overlord/servstate/manager.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,8 @@ func (m *ServiceManager) Replan() ([][]string, [][]string, error) {
return nil, nil, fmt.Errorf("internal error: invalid section type %T", ws)
}
s.workload = ws.Entries[s.config.Workload].copy()
} else {
s.workload = nil
}
}
needsRestart[name] = true
Expand Down
64 changes: 36 additions & 28 deletions internals/overlord/servstate/workloads.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,11 @@ import (
"bytes"
"errors"
"fmt"
"maps"

yaml "gopkg.in/yaml.v3"
"gopkg.in/yaml.v3"

"github.com/canonical/pebble/internals/osutil"
"github.com/canonical/pebble/internals/plan"
)

Expand All @@ -47,6 +49,7 @@ func (ext *WorkloadsSectionExtension) ParseSection(data yaml.Node) (plan.Section
// The following issue prevents us from using the yaml.Node decoder
// with KnownFields = true behavior. Once one of the proposals get
// merged, we can remove the intermediate Marshal step.
// https://github.com/go-yaml/yaml/issues/460
if len(data.Content) != 0 {
yml, err := yaml.Marshal(data)
if err != nil {
Expand Down Expand Up @@ -74,10 +77,19 @@ func (ext *WorkloadsSectionExtension) ValidatePlan(p *plan.Plan) error {
return fmt.Errorf("internal error: invalid section type %T", ws)
}
for name, service := range p.Services {
_, ok := ws.Entries[service.Workload]
if service.Workload != "" && !ok {
if service.Workload == "" {
continue
}
if _, ok := ws.Entries[service.Workload]; !ok {
return &plan.FormatError{
Message: fmt.Sprintf(`plan service %q workload not defined: %q`, name, service.Workload),
}
}
}
for name, workload := range ws.Entries {
if _, _, err := osutil.NormalizeUidGid(workload.UserID, workload.GroupID, workload.User, workload.Group); err != nil {
return &plan.FormatError{
Message: fmt.Sprintf(`plan service %q cannot run in unknown workload %q`, name, service.Workload),
Message: fmt.Sprintf(`plan workload %q %v`, err, name),
}
}
}
Expand All @@ -100,7 +112,7 @@ func (ws *WorkloadsSection) Validate() error {
for name, workload := range ws.Entries {
if workload == nil {
return &plan.FormatError{
Message: fmt.Sprintf("workload %q has a null value", name),
Message: fmt.Sprintf("workload %q cannot have a null value", name),
}
}
if err := workload.validate(); err != nil {
Expand All @@ -113,10 +125,10 @@ func (ws *WorkloadsSection) Validate() error {
}

func (ws *WorkloadsSection) combine(other *WorkloadsSection) error {
if len(other.Entries) != 0 && ws.Entries == nil {
ws.Entries = make(map[string]*Workload, len(other.Entries))
}
for name, workload := range other.Entries {
if ws.Entries == nil {
ws.Entries = make(map[string]*Workload, len(other.Entries))
}
switch workload.Override {
case plan.MergeOverride:
if current, ok := ws.Entries[name]; ok {
Expand Down Expand Up @@ -164,46 +176,42 @@ func (w *Workload) validate() error {

func (w *Workload) copy() *Workload {
copied := *w
if w.Environment != nil {
copied.Environment = make(map[string]string, len(w.Environment))
for k, v := range w.Environment {
copied.Environment[k] = v
}
}
if w.UserID != nil {
copied.UserID = copyIntPtr(w.UserID)
}
if w.GroupID != nil {
copied.GroupID = copyIntPtr(w.GroupID)
}
copied.Environment = maps.Clone(w.Environment)
copied.UserID = copyPtr(w.UserID)
copied.GroupID = copyPtr(w.GroupID)
return &copied
}

func (w *Workload) merge(other *Workload) {
if len(other.Environment) != 0 && w.Environment == nil {
w.Environment = make(map[string]string, len(other.Environment))
}
for k, v := range other.Environment {
w.Environment[k] = v
if len(other.Environment) > 0 {
w.Environment = makeMapIfNil(w.Environment)
maps.Copy(w.Environment, other.Environment)
}
if other.UserID != nil {
w.UserID = copyIntPtr(other.UserID)
w.UserID = copyPtr(other.UserID)
}
if other.User != "" {
w.User = other.User
}
if other.GroupID != nil {
w.GroupID = copyIntPtr(other.GroupID)
w.GroupID = copyPtr(other.GroupID)
}
if other.Group != "" {
w.Group = other.Group
}
}

func copyIntPtr(p *int) *int {
func copyPtr[T any](p *T) *T {
if p == nil {
return nil
}
copied := *p
return &copied
}

func makeMapIfNil[K comparable, V any](m map[K]V) map[K]V {
if m == nil {
m = make(map[K]V)
}
return m
}
Loading

0 comments on commit 900d0c7

Please sign in to comment.