Skip to content

Commit

Permalink
解决查询一类只返回一个实例名称的问题
Browse files Browse the repository at this point in the history
  • Loading branch information
matteriot committed Dec 3, 2019
1 parent 7225942 commit a423509
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 8 deletions.
54 changes: 51 additions & 3 deletions client.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,17 +201,22 @@ func (c *client) mainloop(ctx context.Context, params *LookupParams) {
c.shutdown()
return
case msg := <-msgCh:
//log.Println("==msgCh", msg)
entries = make(map[string]*ServiceEntry)
sections := append(msg.Answer, msg.Ns...)
sections = append(sections, msg.Extra...)

//log.Println("sections", sections)
for _, answer := range sections {
//log.Println("answer", answer)
switch rr := answer.(type) {
case *dns.PTR:
//log.Println("====dns.PTR", rr.Hdr.Name, rr.Ptr)
if params.ServiceName() != rr.Hdr.Name {
//log.Println("params.ServiceName() != rr.Hdr.Name")
continue
}
if params.ServiceInstanceName() != "" && params.ServiceInstanceName() != rr.Ptr {
//log.Println("params.ServiceInstanceName() != \"\" && params.ServiceInstanceName() != rr.Ptr")
continue
}
if _, ok := entries[rr.Ptr]; !ok {
Expand All @@ -222,6 +227,7 @@ func (c *client) mainloop(ctx context.Context, params *LookupParams) {
}
entries[rr.Ptr].TTL = rr.Hdr.Ttl
case *dns.SRV:
//log.Println("*dns.SRV", rr)
if params.ServiceInstanceName() != "" && params.ServiceInstanceName() != rr.Hdr.Name {
continue
} else if !strings.HasSuffix(rr.Hdr.Name, params.ServiceName()) {
Expand All @@ -237,6 +243,7 @@ func (c *client) mainloop(ctx context.Context, params *LookupParams) {
entries[rr.Hdr.Name].Port = int(rr.Port)
entries[rr.Hdr.Name].TTL = rr.Hdr.Ttl
case *dns.TXT:
//log.Println("dns.TXT", rr)
if params.ServiceInstanceName() != "" && params.ServiceInstanceName() != rr.Hdr.Name {
continue
} else if !strings.HasSuffix(rr.Hdr.Name, params.ServiceName()) {
Expand All @@ -257,37 +264,63 @@ func (c *client) mainloop(ctx context.Context, params *LookupParams) {
switch rr := answer.(type) {
case *dns.A:
for k, e := range entries {
//log.Println("A:",rr.A)
if e.HostName == rr.Hdr.Name {
entries[k].AddrIPv4 = append(entries[k].AddrIPv4, rr.A)
}
}
case *dns.AAAA:
for k, e := range entries {
//log.Println("A:",rr.AAAA)
if e.HostName == rr.Hdr.Name {
entries[k].AddrIPv6 = append(entries[k].AddrIPv6, rr.AAAA)
}
}
}
}
}

//log.Println("===entries", entries["home._home-assistant._tcp.local."])
if len(entries) > 0 {
//log.Println("len(entries) > 0")
//log.Println(entries)
//log.Println("===entries", entries["home._home-assistant._tcp.local."])
for k, e := range entries {
if e.TTL == 0 {
//log.Println("e.TTL == 0")
delete(entries, k)
delete(sentEntries, k)
continue
}
if _, ok := sentEntries[k]; ok {
//log.Println(" _, ok := sentEntries[k]; ok")
//log.Println(sentEntries[k])
//log.Println(entries[k])
continue
}

// If this is an DNS-SD query do not throw PTR away.
// It is expected to have only PTR for enumeration
//log.Println("=========")
//log.Println("ServiceTypeName:", params.ServiceRecord.ServiceTypeName() ,"ServiceName:", params.ServiceRecord.ServiceName())
if params.ServiceRecord.ServiceTypeName() != params.ServiceRecord.ServiceName() {
// Require at least one resolved IP address for ServiceEntry
// TODO: wait some more time as chances are high both will arrive.
if len(e.AddrIPv4) == 0 && len(e.AddrIPv6) == 0 {
//log.Println("len(e.AddrIPv4) == 0 && len(e.AddrIPv6) == 0")
//log.Println("e:", e)

newParams := defaultParams(e.Service)
newParams.Instance = e.Instance
newParams.Domain = e.Domain
newParams.Entries = params.Entries
_, cancel := context.WithCancel(ctx)
err := c.query(newParams)
if err != nil {
log.Println("cancel()")
cancel()
}
delete(entries, k)
delete(sentEntries, k)
continue
}
}
Expand All @@ -296,6 +329,7 @@ func (c *client) mainloop(ctx context.Context, params *LookupParams) {
// service entry.
params.Entries <- e
sentEntries[k] = e
//log.Println("sentEntries[k] = e")
params.disableProbing()
}
// reset entries
Expand Down Expand Up @@ -351,11 +385,14 @@ func (c *client) recv(ctx context.Context, l interface{}, msgCh chan *dns.Msg) {
fatalErr = err
continue
}
//log.Println(string(buf))
msg := new(dns.Msg)
if err := msg.Unpack(buf[:n]); err != nil {
// log.Printf("[WARN] mdns: Failed to unpack packet: %v", err)
//log.Println(string(buf))
//log.Printf("[WARN] mdns: Failed to unpack packet: %v", err)
continue
}
//log.Println("===msg:",msg)
select {
case msgCh <- msg:
// Submit decoded DNS message and continue.
Expand Down Expand Up @@ -409,10 +446,13 @@ func (c *client) periodicQuery(ctx context.Context, params *LookupParams) error
// Performs the actual query by service name (browse) or service instance name (lookup),
// start response listeners goroutines and loops over the entries channel.
func (c *client) query(params *LookupParams) error {
//log.Println("aaaa")
var serviceName, serviceInstanceName string
serviceName = fmt.Sprintf("%s.%s.", trimDot(params.Service), trimDot(params.Domain))
//log.Println("===serviceName:", serviceName)
if params.Instance != "" {
serviceInstanceName = fmt.Sprintf("%s.%s", params.Instance, serviceName)
//log.Println("===serviceInstanceName", serviceInstanceName)
}

// send the query
Expand All @@ -421,11 +461,18 @@ func (c *client) query(params *LookupParams) error {
m.Question = []dns.Question{
dns.Question{serviceInstanceName, dns.TypeSRV, dns.ClassINET},
dns.Question{serviceInstanceName, dns.TypeTXT, dns.ClassINET},
//dns.Question{serviceInstanceName, dns.TypeA, dns.ClassINET},
}
m.RecursionDesired = false
} else {
m.SetQuestion(serviceName, dns.TypePTR)
//m.Question = []dns.Question{
// dns.Question{serviceName, dns.TypeSRV, dns.ClassINET},
// dns.Question{serviceName, dns.TypeTXT, dns.ClassINET},
// dns.Question{serviceName, dns.TypeA, dns.ClassINET},
//}
m.RecursionDesired = false
//log.Println("++++++++++++++++++++dns:", m.String())
}
if err := c.sendQuery(m); err != nil {
return err
Expand All @@ -440,6 +487,7 @@ func (c *client) sendQuery(msg *dns.Msg) error {
if err != nil {
return err
}
//log.Println(string(buf))
if c.ipv4conn != nil {
var wcm ipv4.ControlMessage
for ifi := range c.ifaces {
Expand Down
14 changes: 9 additions & 5 deletions examples/resolv/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,21 +3,24 @@ package main
import (
"context"
"flag"
"fmt"
"log"
"time"

"github.com/grandcat/zeroconf"
"github.com/iotdevice/zeroconf"
)

var (
service = flag.String("service", "_workstation._tcp", "Set the service category to look for devices.")
service = flag.String("service", "_iotdevice._tcp", "Set the service category to look for devices.")
//service = flag.String("service", "_home-assistant._tcp", "Set the service category to look for devices.")
//service = flag.String("service", "home._home-assistant._tcp", "Set the service category to look for devices.")
domain = flag.String("domain", "local", "Set the search domain. For local networks, default is fine.")
waitTime = flag.Int("wait", 10, "Duration in [s] to run discovery.")
waitTime = flag.Int("wait", 2, "Duration in [s] to run discovery.")
)

func main() {
flag.Parse()

log.Println("=====================start============")
// Discover all services on the network (e.g. _workstation._tcp)
resolver, err := zeroconf.NewResolver(nil)
if err != nil {
Expand All @@ -27,14 +30,15 @@ func main() {
entries := make(chan *zeroconf.ServiceEntry)
go func(results <-chan *zeroconf.ServiceEntry) {
for entry := range results {
log.Println(entry)
fmt.Println("===entry:", entry)
}
log.Println("No more entries.")
}(entries)

ctx, cancel := context.WithTimeout(context.Background(), time.Second*time.Duration(*waitTime))
defer cancel()
err = resolver.Browse(ctx, *service, *domain, entries)
//err = resolver.Lookup(ctx, "home", "_home-assistant._tcp", "local", entries)
if err != nil {
log.Fatalln("Failed to browse:", err.Error())
}
Expand Down

0 comments on commit a423509

Please sign in to comment.