Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[client] Support DNS Labels for Peer Addressing #3252

Draft
wants to merge 25 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
1147702
Add DNS labels support to login request and configuration
hakansa Jan 30, 2025
d1025b2
[client] Update Login method to include DNS labels parameter
hakansa Jan 30, 2025
89b1e5a
move validateDomains to management.domain pkg to reuse
hakansa Jan 30, 2025
fad86ab
Add cleanDNSLabels field to LoginRequest and update validation logic
hakansa Jan 31, 2025
d943431
Rename dnsLabelsFlag to extra-dns-labels for clarity in command usage
hakansa Feb 3, 2025
1c34883
Merge branch 'main' into feature/dns-labels
hakansa Feb 3, 2025
12979f2
Refactor localResolver to improve multiple DNS record handling
hakansa Feb 3, 2025
5d87c66
Refactor local handler update to support multiple DNS records per domain
hakansa Feb 5, 2025
1682c53
Merge branch 'main' into feature/dns-labels
hakansa Feb 5, 2025
8dbd6d1
Remove unused import from routes_handler_test.go
hakansa Feb 5, 2025
940e8b3
Update management/domain/validate.go
hakansa Feb 6, 2025
c615337
Use slices.Equal for DNS label comparison and add error logging for r…
hakansa Feb 6, 2025
2c38be7
Update protoc version in generated proto files
hakansa Feb 6, 2025
5afa933
Fix DNS labels comparison logic
hakansa Feb 6, 2025
ceb4280
Merge branch 'main' into feature/dns-labels
hakansa Feb 6, 2025
6fc6c00
fix proto merge conflict
hakansa Feb 6, 2025
4946a7b
Replace assert library
hakansa Feb 6, 2025
ffdfa6b
use domain.FromPunycodeList on client login
hakansa Feb 7, 2025
74d4bc7
Merge branch 'main' into feature/dns-labels
hakansa Feb 10, 2025
64e1433
client grpc proto fix
hakansa Feb 10, 2025
f2a749f
[client] Update DNS labels flag description to specify maximum limit
hakansa Feb 10, 2025
b375671
Implement round-robin rotation for DNS records when multiple records …
hakansa Feb 11, 2025
94c8f84
Refactor DNS label handling in login and up commands for improved val…
hakansa Feb 12, 2025
f041b0f
Refactor DNS label validation in up command to improve error handling
hakansa Feb 12, 2025
d536f8f
Merge branch 'main' into feature/dns-labels
hakansa Feb 12, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions client/cmd/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,17 @@ var loginCmd = &cobra.Command{

client := proto.NewDaemonServiceClient(conn)

var dnsLabelsReq []string
if dnsLabelsValidated != nil {
dnsLabelsReq = dnsLabelsValidated.ToSafeStringList()
}

loginRequest := proto.LoginRequest{
SetupKey: providedSetupKey,
ManagementUrl: managementURL,
IsLinuxDesktopClient: isLinuxRunningDesktop(),
Hostname: hostName,
DnsLabels: dnsLabelsReq,
}

if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
Expand Down
46 changes: 44 additions & 2 deletions client/cmd/up.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/util"
)

Expand All @@ -29,9 +30,16 @@ const (
interfaceInputType
)

const (
dnsLabelsFlag = "extra-dns-labels"
)

var (
foregroundMode bool
upCmd = &cobra.Command{
foregroundMode bool
dnsLabels []string
dnsLabelsValidated domain.List

upCmd = &cobra.Command{
Use: "up",
Short: "install, login and start Netbird client",
RunE: upFunc,
Expand All @@ -49,6 +57,14 @@ func init() {
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false, "Block access to local networks (LAN) when using this peer as a router or exit node")

upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil,
`Sets DNS labels`+
`You can specify a comma-separated list of up to 32 labels. `+
`An empty string "" clears the previous configuration. `+
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
`or --extra-dns-labels ""`,
)
}

func upFunc(cmd *cobra.Command, args []string) error {
Expand All @@ -67,6 +83,11 @@ func upFunc(cmd *cobra.Command, args []string) error {
return err
}

dnsLabelsValidated, err = validateDnsLabels(dnsLabels)
if err != nil {
return err
}

ctx := internal.CtxInitState(cmd.Context())

if hostName != "" {
Expand Down Expand Up @@ -98,6 +119,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
NATExternalIPs: natExternalIPs,
CustomDNSAddress: customDNSAddressConverted,
ExtraIFaceBlackList: extraIFaceBlackList,
DNSLabels: dnsLabelsValidated,
}

if cmd.Flag(enableRosenpassFlag).Changed {
Expand Down Expand Up @@ -240,6 +262,8 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
IsLinuxDesktopClient: isLinuxRunningDesktop(),
Hostname: hostName,
ExtraIFaceBlacklist: extraIFaceBlackList,
DnsLabels: dnsLabels,
CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0,
}

if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
Expand Down Expand Up @@ -430,6 +454,24 @@ func parseCustomDNSAddress(modified bool) ([]byte, error) {
return parsed, nil
}

func validateDnsLabels(labels []string) (domain.List, error) {
var (
domains domain.List
err error
)

if len(labels) == 0 {
return domains, nil
}

domains, err = domain.ValidateDomains(labels)
if err != nil {
return nil, fmt.Errorf("failed to validate dns labels: %v", err)
}

return domains, nil
}

func isValidAddrPort(input string) bool {
if input == "" {
return true
Expand Down
14 changes: 14 additions & 0 deletions client/internal/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"os"
"reflect"
"runtime"
"slices"
"strings"
"time"

Expand All @@ -20,6 +21,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/dynamic"
"github.com/netbirdio/netbird/client/ssh"
mgm "github.com/netbirdio/netbird/management/client"
"github.com/netbirdio/netbird/management/domain"
"github.com/netbirdio/netbird/util"
)

Expand Down Expand Up @@ -68,6 +70,8 @@ type ConfigInput struct {
DisableFirewall *bool

BlockLANAccess *bool

DNSLabels domain.List
}

// Config Configuration type
Expand All @@ -93,6 +97,8 @@ type Config struct {

BlockLANAccess bool

DNSLabels domain.List

// SSHKey is a private SSH key in a PEM format
SSHKey string

Expand Down Expand Up @@ -489,6 +495,14 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
}
}

if input.DNSLabels != nil && !slices.Equal(config.DNSLabels, input.DNSLabels) {
log.Infof("updating DNS labels [ %s ] (old value: [ %s ])",
input.DNSLabels.SafeString(),
config.DNSLabels.SafeString())
config.DNSLabels = input.DNSLabels
updated = true
}

return updated, nil
}

Expand Down
2 changes: 1 addition & 1 deletion client/internal/connect.go
Original file line number Diff line number Diff line change
Expand Up @@ -478,7 +478,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
config.DisableDNS,
config.DisableFirewall,
)
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey)
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
if err != nil {
return nil, err
}
Expand Down
62 changes: 45 additions & 17 deletions client/internal/dns/local.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ type registrationMap map[string]struct{}

type localResolver struct {
registeredMap registrationMap
records sync.Map
records sync.Map // key: string (domain_class_type), value: []dns.RR
}

func (d *localResolver) MatchSubdomains() bool {
Expand Down Expand Up @@ -43,11 +43,12 @@ func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
replyMessage := &dns.Msg{}
replyMessage.SetReply(r)
replyMessage.RecursionAvailable = true
replyMessage.Rcode = dns.RcodeSuccess

response := d.lookupRecord(r)
if response != nil {
replyMessage.Answer = append(replyMessage.Answer, response)
// lookup all records matching the question
records := d.lookupRecords(r)
if len(records) > 0 {
replyMessage.Rcode = dns.RcodeSuccess
replyMessage.Answer = append(replyMessage.Answer, records...)
} else {
replyMessage.Rcode = dns.RcodeNameError
}
Expand All @@ -58,37 +59,64 @@ func (d *localResolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
}

func (d *localResolver) lookupRecord(r *dns.Msg) dns.RR {
// lookupRecords fetches *all* DNS records matching the first question in r.
func (d *localResolver) lookupRecords(r *dns.Msg) []dns.RR {
if len(r.Question) == 0 {
return nil
}
question := r.Question[0]
record, found := d.records.Load(buildRecordKey(question.Name, question.Qclass, question.Qtype))
key := buildRecordKey(question.Name, question.Qclass, question.Qtype)

value, found := d.records.Load(key)
if !found {
return nil
}

return record.(dns.RR)
records, ok := value.([]dns.RR)
if !ok {
log.Errorf("failed to cast records to []dns.RR, records: %v", value)
return nil
hakansa marked this conversation as resolved.
Show resolved Hide resolved
}

// if there's more than one record, rotate them (round-robin)
if len(records) > 1 {
first := records[0]
records = append(records[1:], first)
d.records.Store(key, records)
}

return records
}

func (d *localResolver) registerRecord(record nbdns.SimpleRecord) error {
fullRecord, err := dns.NewRR(record.String())
// registerRecord stores a new record by appending it to any existing list
func (d *localResolver) registerRecord(record nbdns.SimpleRecord) (string, error) {
rr, err := dns.NewRR(record.String())
if err != nil {
return fmt.Errorf("register record: %w", err)
return "", fmt.Errorf("register record: %w", err)
}

fullRecord.Header().Rdlength = record.Len()
rr.Header().Rdlength = record.Len()
header := rr.Header()
key := buildRecordKey(header.Name, header.Class, header.Rrtype)

header := fullRecord.Header()
d.records.Store(buildRecordKey(header.Name, header.Class, header.Rrtype), fullRecord)
// load any existing slice of records, then append
existing, _ := d.records.LoadOrStore(key, []dns.RR{})
records := existing.([]dns.RR)
records = append(records, rr)

return nil
// store updated slice
d.records.Store(key, records)
return key, nil
}

// deleteRecord removes *all* records under the recordKey.
func (d *localResolver) deleteRecord(recordKey string) {
d.records.Delete(dns.Fqdn(recordKey))
}

// buildRecordKey consistently generates a key: name_class_type
func buildRecordKey(name string, class, qType uint16) string {
key := fmt.Sprintf("%s_%d_%d", name, class, qType)
return key
return fmt.Sprintf("%s_%d_%d", dns.Fqdn(name), class, qType)
}

func (d *localResolver) probeAvailability() {}
2 changes: 1 addition & 1 deletion client/internal/dns/local_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ func TestLocalResolver_ServeDNS(t *testing.T) {
resolver := &localResolver{
registeredMap: make(registrationMap),
}
_ = resolver.registerRecord(testCase.inputRecord)
_, _ = resolver.registerRecord(testCase.inputRecord)
var responseMSG *dns.Msg
responseWriter := &mockResponseWriter{
WriteMsgFunc: func(m *dns.Msg) error {
Expand Down
39 changes: 28 additions & 11 deletions client/internal/dns/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -393,18 +393,22 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
s.service.Stop()
}

localMuxUpdates, localRecords, err := s.buildLocalHandlerUpdate(update.CustomZones)
localMuxUpdates, localRecordsByDomain, err := s.buildLocalHandlerUpdate(update.CustomZones)
if err != nil {
return fmt.Errorf("not applying dns update, error: %v", err)
}

upstreamMuxUpdates, err := s.buildUpstreamHandlerUpdate(update.NameServerGroups)
if err != nil {
return fmt.Errorf("not applying dns update, error: %v", err)
}
muxUpdates := append(localMuxUpdates, upstreamMuxUpdates...) //nolint:gocritic

s.updateMux(muxUpdates)
s.updateLocalResolver(localRecords)

// register local records
s.updateLocalResolver(localRecordsByDomain)

s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())

hostUpdate := s.currentConfig
Expand Down Expand Up @@ -434,9 +438,12 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
return nil
}

func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone) ([]handlerWrapper, map[string]nbdns.SimpleRecord, error) {
func (s *DefaultServer) buildLocalHandlerUpdate(
customZones []nbdns.CustomZone,
) ([]handlerWrapper, map[string][]nbdns.SimpleRecord, error) {

var muxUpdates []handlerWrapper
localRecords := make(map[string]nbdns.SimpleRecord, 0)
localRecords := make(map[string][]nbdns.SimpleRecord)

for _, customZone := range customZones {
if len(customZone.Records) == 0 {
Expand All @@ -449,15 +456,18 @@ func (s *DefaultServer) buildLocalHandlerUpdate(customZones []nbdns.CustomZone)
priority: PriorityMatchDomain,
})

// group all records under this domain
for _, record := range customZone.Records {
var class uint16 = dns.ClassINET
if record.Class != nbdns.DefaultClass {
return nil, nil, fmt.Errorf("received an invalid class type: %s", record.Class)
}
key := buildRecordKey(record.Name, class, uint16(record.Type))
localRecords[key] = record

localRecords[key] = append(localRecords[key], record)
}
}

return muxUpdates, localRecords, nil
}

Expand Down Expand Up @@ -593,7 +603,8 @@ func (s *DefaultServer) updateMux(muxUpdates []handlerWrapper) {
s.dnsMuxMap = muxUpdateMap
}

func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord) {
func (s *DefaultServer) updateLocalResolver(update map[string][]nbdns.SimpleRecord) {
// remove old records that are no longer present
for key := range s.localResolver.registeredMap {
_, found := update[key]
if !found {
Expand All @@ -602,12 +613,18 @@ func (s *DefaultServer) updateLocalResolver(update map[string]nbdns.SimpleRecord
}

updatedMap := make(registrationMap)
for key, record := range update {
err := s.localResolver.registerRecord(record)
if err != nil {
log.Warnf("got an error while registering the record (%s), error: %v", record.String(), err)
for _, recs := range update {
for _, rec := range recs {
// convert the record to a dns.RR and register
key, err := s.localResolver.registerRecord(rec)
if err != nil {
log.Warnf("got an error while registering the record (%s), error: %v",
rec.String(), err)
continue
}

updatedMap[key] = struct{}{}
}
updatedMap[key] = struct{}{}
}

s.localResolver.registeredMap = updatedMap
Expand Down
2 changes: 1 addition & 1 deletion client/internal/dns/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -573,7 +573,7 @@ func TestDNSServerStartStop(t *testing.T) {
}
time.Sleep(100 * time.Millisecond)
defer dnsServer.Stop()
err = dnsServer.localResolver.registerRecord(zoneRecords[0])
_, err = dnsServer.localResolver.registerRecord(zoneRecords[0])
if err != nil {
t.Error(err)
}
Expand Down
2 changes: 1 addition & 1 deletion client/internal/login.go
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte
config.DisableDNS,
config.DisableFirewall,
)
_, err = mgmClient.Login(*serverKey, sysInfo, pubSSHKey)
_, err = mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
return serverKey, err
}

Expand Down
Loading
Loading