diff --git a/plugin/ipset.go b/plugin/ipset.go index a7f90ca..02a751f 100644 --- a/plugin/ipset.go +++ b/plugin/ipset.go @@ -35,12 +35,12 @@ func (i *Ipset) Init(config map[string]interface{}) error { for name := range sets { domains, err := domain.TreeFromFile(sets[name].(map[string]interface{})["domain_file"].(string)) - if err != nil { - continue - } + i.Domains[name] = domains - set, err := ipset.New(name, "hash:net", &ipset.Params{}) + set, err := ipset.New(name, "hash:ip", &ipset.Params{ + Timeout: 600, + }) if set == nil { log.Error(err) continue @@ -50,6 +50,15 @@ func (i *Ipset) Init(config map[string]interface{}) error { log.Error(err) } + IpFile := sets[name].(map[string]interface{})["ip_file"].(string) + + for _, ip := range parseIPList(IpFile) { + if err = set.Add(ip, 0); err != nil { + log.Error(err) + } + + } + i.Set[name] = set } @@ -61,7 +70,8 @@ func (i *Ipset) Where() uint8 { } func (i *Ipset) HandleDns(ctx *common.Context) { - if ctx.Response != nil && len(ctx.Response.Answer) <= 0 { + + if ctx.Response != nil && len(ctx.Response.Answer) <= 0 && ctx.Query.Question[0].Qtype != dns.TypeA { return } log.Debug(dns.Field(ctx.Response.Answer[0], 1)) @@ -69,10 +79,13 @@ func (i *Ipset) HandleDns(ctx *common.Context) { for setName := range i.Domains { if i.Domains[setName].Has(domain.Domain(ctx.Response.Question[0].Name)) { for _, ans := range ctx.Response.Answer { - if ans.Header().Rrtype != dns.TypeA { - continue - } - err := i.Set[setName].Add(dns.Field(ans, 1), 0) + + err := i.Set[setName].Add(dns.Field(ans, 1), func(a, b int) int { + if a > b { + return a + } + return b + }(int(ans.Header().Ttl), 3600)) log.Debugf("ipset add %s to %s", ctx.Response.Question[0].Name, setName) if err != nil { log.Error(err)