// Package dns provides DNS resolvers over several transports (UDP, TCP,
// DNS-over-TLS and the DNS-over-HTTPS JSON API) behind a common Resolver
// interface. DNS-over-QUIC is declared but not implemented yet.
package dns

import (
	"context"
	"crypto/tls"
	"encoding/json"
	"fmt"
	"io"
	"net"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/SNI_Proxy/config"
	"github.com/miekg/dns"
)

const (
	// defaultTimeout is applied when the configured timeout is zero or
	// negative. Without it the DoH resolver would create an already-expired
	// context (context.WithTimeout with a zero duration) and every DoH
	// lookup would fail immediately.
	defaultTimeout = 5 * time.Second

	// maxDoHResponseSize caps how many bytes of a DoH response body are read,
	// so a misbehaving server cannot make us buffer an unbounded payload.
	maxDoHResponseSize = 1 << 20

	// Ports used when a configured server address omits the port.
	defaultDNSPort = "53"
	defaultDoTPort = "853"
)

// Resolver is the interface implemented by every DNS resolver in this package.
type Resolver interface {
	// Resolve resolves domain to its IP addresses.
	Resolve(domain string) ([]net.IP, error)
	// Close releases any resources held by the resolver.
	Close() error
}

// Compile-time checks that every implementation satisfies Resolver.
var (
	_ Resolver = (*UDPResolver)(nil)
	_ Resolver = (*TCPResolver)(nil)
	_ Resolver = (*DoTResolver)(nil)
	_ Resolver = (*DoHResolver)(nil)
	_ Resolver = (*DoQResolver)(nil)
)

// BaseResolver holds the settings shared by all resolver implementations.
type BaseResolver struct {
	server  string
	timeout time.Duration
}

// NewResolver creates the resolver that matches cfg.Protocol.
// It returns an error for protocols it does not recognise.
func NewResolver(cfg config.DNSResolverConfig) (Resolver, error) {
	switch cfg.Protocol {
	case config.DNSProtocolUDP:
		return NewUDPResolver(cfg)
	case config.DNSProtocolTCP:
		return NewTCPResolver(cfg)
	case config.DNSProtocolDoT:
		return NewDoTResolver(cfg)
	case config.DNSProtocolDoH:
		return NewDoHResolver(cfg)
	case config.DNSProtocolDoQ:
		return NewDoQResolver(cfg)
	default:
		return nil, fmt.Errorf("不支持的DNS协议: %s", cfg.Protocol)
	}
}

// configTimeout converts cfg.Timeout (in seconds) into a Duration, falling
// back to defaultTimeout for non-positive values.
func configTimeout(cfg config.DNSResolverConfig) time.Duration {
	timeout := time.Duration(cfg.Timeout) * time.Second
	if timeout <= 0 {
		return defaultTimeout
	}
	return timeout
}

// withDefaultPort returns server unchanged when it already has a port (or is
// empty); otherwise it appends port, bracketing IPv6 literals as needed.
func withDefaultPort(server, port string) string {
	if server == "" || port == "" {
		return server
	}
	if _, _, err := net.SplitHostPort(server); err == nil {
		return server
	}
	host := strings.TrimSuffix(strings.TrimPrefix(server, "["), "]")
	return net.JoinHostPort(host, port)
}

// newBaseResolver builds the shared settings from cfg. When defaultPort is
// non-empty, a server address without a port gets that port appended.
func newBaseResolver(cfg config.DNSResolverConfig, defaultPort string) BaseResolver {
	return BaseResolver{
		server:  withDefaultPort(cfg.Server, defaultPort),
		timeout: configTimeout(cfg),
	}
}

// newDNSClient creates a miekg/dns client for the given network
// ("udp", "tcp" or "tcp-tls"); tlsConfig may be nil for non-TLS networks.
func newDNSClient(network string, timeout time.Duration, tlsConfig *tls.Config) *dns.Client {
	return &dns.Client{
		Net:       network,
		Timeout:   timeout,
		TLSConfig: tlsConfig,
	}
}

// dnsQuery sends a single recursive query for domain/qtype to server and
// returns every A and AAAA address found in the answer section. A non-success
// response code is reported as an error.
func dnsQuery(client *dns.Client, server string, domain string, qtype uint16) ([]net.IP, error) {
	m := new(dns.Msg)
	m.SetQuestion(dns.Fqdn(domain), qtype)
	m.RecursionDesired = true

	resp, _, err := client.Exchange(m, server)
	if err != nil {
		return nil, err
	}
	if resp.Rcode != dns.RcodeSuccess {
		return nil, fmt.Errorf("DNS查询返回错误码: %d", resp.Rcode)
	}

	var ips []net.IP
	for _, ans := range resp.Answer {
		switch rr := ans.(type) {
		case *dns.A:
			ips = append(ips, rr.A)
		case *dns.AAAA:
			ips = append(ips, rr.AAAA)
		}
	}
	return ips, nil
}

// resolveAThenAAAA queries A records first and only falls back to AAAA when
// the A query succeeded but returned no addresses. label (e.g. "UDP")
// prefixes the error messages so callers can tell which transport failed.
func resolveAThenAAAA(client *dns.Client, server, domain, label string) ([]net.IP, error) {
	ips, err := dnsQuery(client, server, domain, dns.TypeA)
	if err != nil {
		return nil, fmt.Errorf("%s DNS A查询失败: %w", label, err)
	}
	if len(ips) == 0 {
		ips, err = dnsQuery(client, server, domain, dns.TypeAAAA)
		if err != nil {
			return nil, fmt.Errorf("%s DNS AAAA查询失败: %w", label, err)
		}
	}
	if len(ips) == 0 {
		return nil, fmt.Errorf("未找到域名 %s 的IP地址", domain)
	}
	return ips, nil
}

// UDPResolver resolves names with plain DNS over UDP.
type UDPResolver struct {
	BaseResolver
	client *dns.Client
}

// NewUDPResolver creates a UDP resolver for cfg.Server. A server given
// without a port defaults to port 53.
func NewUDPResolver(cfg config.DNSResolverConfig) (*UDPResolver, error) {
	base := newBaseResolver(cfg, defaultDNSPort)
	return &UDPResolver{
		BaseResolver: base,
		client:       newDNSClient("udp", base.timeout, nil),
	}, nil
}

// Resolve returns the A records of domain, or its AAAA records when there
// are no A records.
func (r *UDPResolver) Resolve(domain string) ([]net.IP, error) {
	return resolveAThenAAAA(r.client, r.server, domain, "UDP")
}

// Close is a no-op: the UDP resolver holds no long-lived resources.
func (r *UDPResolver) Close() error {
	return nil
}

// TCPResolver resolves names with plain DNS over TCP.
type TCPResolver struct {
	BaseResolver
	client *dns.Client
}

// NewTCPResolver creates a TCP resolver for cfg.Server. A server given
// without a port defaults to port 53.
func NewTCPResolver(cfg config.DNSResolverConfig) (*TCPResolver, error) {
	base := newBaseResolver(cfg, defaultDNSPort)
	return &TCPResolver{
		BaseResolver: base,
		client:       newDNSClient("tcp", base.timeout, nil),
	}, nil
}

// Resolve returns the A records of domain, or its AAAA records when there
// are no A records.
func (r *TCPResolver) Resolve(domain string) ([]net.IP, error) {
	return resolveAThenAAAA(r.client, r.server, domain, "TCP")
}

// Close is a no-op: the TCP resolver holds no long-lived resources.
func (r *TCPResolver) Close() error {
	return nil
}

// DoTResolver resolves names with DNS over TLS.
type DoTResolver struct {
	BaseResolver
	client *dns.Client
}

// NewDoTResolver creates a DNS-over-TLS resolver (TLS 1.2 minimum) for
// cfg.Server. A server given without a port defaults to port 853.
func NewDoTResolver(cfg config.DNSResolverConfig) (*DoTResolver, error) {
	base := newBaseResolver(cfg, defaultDoTPort)
	return &DoTResolver{
		BaseResolver: base,
		client: newDNSClient("tcp-tls", base.timeout, &tls.Config{
			MinVersion: tls.VersionTLS12,
		}),
	}, nil
}

// Resolve returns the A records of domain, or its AAAA records when there
// are no A records.
func (r *DoTResolver) Resolve(domain string) ([]net.IP, error) {
	return resolveAThenAAAA(r.client, r.server, domain, "DoT")
}

// Close is a no-op: the DoT resolver holds no long-lived resources.
func (r *DoTResolver) Close() error {
	return nil
}

// DoHResolver resolves names with DNS over HTTPS using the JSON API
// (Accept: application/dns-json).
type DoHResolver struct {
	BaseResolver
	client *http.Client
}

// NewDoHResolver creates a DoH resolver. cfg.Server is the endpoint URL;
// https:// is assumed when it has no scheme.
func NewDoHResolver(cfg config.DNSResolverConfig) (*DoHResolver, error) {
	base := newBaseResolver(cfg, "")
	return &DoHResolver{
		BaseResolver: base,
		client: &http.Client{
			Timeout: base.timeout,
			Transport: &http.Transport{
				TLSClientConfig: &tls.Config{
					MinVersion: tls.VersionTLS12,
				},
				IdleConnTimeout: 30 * time.Second,
			},
		},
	}, nil
}

// dohResponse mirrors the JSON body returned by DoH JSON-API servers.
type dohResponse struct {
	Status   int  `json:"Status"`
	TC       bool `json:"TC"`
	RD       bool `json:"RD"`
	RA       bool `json:"RA"`
	AD       bool `json:"AD"`
	CD       bool `json:"CD"`
	Question []struct {
		Name string `json:"name"`
		Type int    `json:"type"`
	} `json:"Question"`
	Answer []struct {
		Name string `json:"name"`
		Type int    `json:"type"`
		TTL  int    `json:"TTL"`
		Data string `json:"data"`
	} `json:"Answer"`
}

// Resolve returns the A records of domain, or its AAAA records when there
// are no A records. A single deadline (the resolver timeout) covers both
// requests.
func (r *DoHResolver) Resolve(domain string) ([]net.IP, error) {
	ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
	defer cancel()

	ips, err := r.query(ctx, domain, dns.TypeA, "")
	if err != nil {
		return nil, err
	}
	if len(ips) == 0 {
		if ips, err = r.query(ctx, domain, dns.TypeAAAA, " AAAA"); err != nil {
			return nil, err
		}
	}
	if len(ips) == 0 {
		return nil, fmt.Errorf("未找到域名 %s 的IP地址", domain)
	}
	return ips, nil
}

// queryURL builds the JSON-API request URL for domain and qtype. Query
// parameters already present on the configured server URL are kept; name
// and type are always set from the arguments and properly escaped (domain
// may originate from an untrusted client).
func (r *DoHResolver) queryURL(domain string, qtype uint16) (string, error) {
	raw := r.server
	if !strings.Contains(raw, "://") {
		raw = "https://" + raw
	}
	u, err := url.Parse(raw)
	if err != nil {
		return "", err
	}
	q := u.Query()
	q.Set("name", domain)
	q.Set("type", dns.TypeToString[qtype])
	u.RawQuery = q.Encode()
	return u.String(), nil
}

// query performs one DoH JSON request for domain/qtype and returns the
// addresses of answers whose type equals qtype. label is inserted into error
// messages ("" for A, " AAAA" for AAAA).
func (r *DoHResolver) query(ctx context.Context, domain string, qtype uint16, label string) ([]net.IP, error) {
	reqURL, err := r.queryURL(domain, qtype)
	if err != nil {
		return nil, fmt.Errorf("创建DoH%s请求失败: %w", label, err)
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, reqURL, nil)
	if err != nil {
		return nil, fmt.Errorf("创建DoH%s请求失败: %w", label, err)
	}
	req.Header.Set("Accept", "application/dns-json")

	resp, err := r.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("DoH%s请求失败: %w", label, err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("DoH%s请求返回非200状态码: %d", label, resp.StatusCode)
	}

	body, err := io.ReadAll(io.LimitReader(resp.Body, maxDoHResponseSize))
	if err != nil {
		return nil, fmt.Errorf("读取DoH%s响应失败: %w", label, err)
	}

	var dohResp dohResponse
	if err := json.Unmarshal(body, &dohResp); err != nil {
		return nil, fmt.Errorf("解析DoH%s响应失败: %w", label, err)
	}
	if dohResp.Status != dns.RcodeSuccess {
		return nil, fmt.Errorf("DoH%s查询返回错误码: %d", label, dohResp.Status)
	}

	var ips []net.IP
	for _, ans := range dohResp.Answer {
		// Skip CNAMEs and any other record types interleaved in the answer.
		if ans.Type != int(qtype) {
			continue
		}
		if ip := net.ParseIP(ans.Data); ip != nil {
			ips = append(ips, ip)
		}
	}
	return ips, nil
}

// Close releases idle HTTP connections held by the DoH client.
func (r *DoHResolver) Close() error {
	r.client.CloseIdleConnections()
	return nil
}

// DoQResolver is a placeholder for DNS over QUIC.
type DoQResolver struct {
	BaseResolver
}

// NewDoQResolver creates the DoQ placeholder resolver.
func NewDoQResolver(cfg config.DNSResolverConfig) (*DoQResolver, error) {
	return &DoQResolver{
		BaseResolver: newBaseResolver(cfg, ""),
	}, nil
}

// Resolve always fails: DoQ requires QUIC support, which is not implemented.
func (r *DoQResolver) Resolve(domain string) ([]net.IP, error) {
	return nil, fmt.Errorf("DoQ协议暂未实现")
}

// Close is a no-op for the placeholder resolver.
func (r *DoQResolver) Close() error {
	return nil
}