SNI_proxy/dns/resolver.go

399 lines
8.8 KiB
Go
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package dns
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"strings"
"time"
"github.com/SNI_Proxy/config"
"github.com/miekg/dns"
)
// Resolver is the common interface implemented by every DNS resolver in this
// package, whatever transport it uses.
type Resolver interface {
	// Close releases any resources held by the resolver.
	Close() error
	// Resolve looks up domain and returns its IP addresses.
	Resolve(domain string) ([]net.IP, error)
}
// BaseResolver holds the settings shared by all resolver implementations and
// is embedded in each of them.
type BaseResolver struct {
	server  string        // upstream server address as given in the config
	timeout time.Duration // per-query timeout derived from the config
}
// NewResolver creates the Resolver implementation selected by cfg.Protocol
// (UDP, TCP, DoT, DoH or DoQ). An unsupported protocol yields an error.
//
// The constructor results are checked before being handed back as a Resolver:
// returning a nil *XxxResolver directly through the interface would produce a
// non-nil Resolver that wraps a nil pointer, so a caller testing `r != nil`
// would be misled if a constructor ever fails.
func NewResolver(cfg config.DNSResolverConfig) (Resolver, error) {
	var (
		r   Resolver
		err error
	)
	switch cfg.Protocol {
	case config.DNSProtocolUDP:
		r, err = NewUDPResolver(cfg)
	case config.DNSProtocolTCP:
		r, err = NewTCPResolver(cfg)
	case config.DNSProtocolDoT:
		r, err = NewDoTResolver(cfg)
	case config.DNSProtocolDoH:
		r, err = NewDoHResolver(cfg)
	case config.DNSProtocolDoQ:
		r, err = NewDoQResolver(cfg)
	default:
		return nil, fmt.Errorf("不支持的DNS协议: %s", cfg.Protocol)
	}
	if err != nil {
		return nil, err
	}
	return r, nil
}
// UDPResolver sends plain DNS queries over UDP.
type UDPResolver struct {
	BaseResolver             // upstream address and timeout
	client       *dns.Client // miekg/dns client built by NewUDPResolver
}
// NewUDPResolver builds a UDPResolver for cfg.Server. cfg.Timeout is taken as
// a number of seconds and applied to the underlying DNS client. It never
// returns an error.
func NewUDPResolver(cfg config.DNSResolverConfig) (*UDPResolver, error) {
	timeout := time.Duration(cfg.Timeout) * time.Second
	r := &UDPResolver{client: &dns.Client{Net: "udp", Timeout: timeout}}
	r.server = cfg.Server
	r.timeout = timeout
	return r, nil
}
// dnsQuery sends one recursive query for domain (made fully qualified) with
// record type qtype to server through client. It returns the addresses of
// every A and AAAA record in the answer section. A non-success RCODE is
// reported as an error; an answer with no address records yields (nil, nil).
//
// dns.Client.Exchange does not fall back to TCP when a UDP response has the
// TC (truncated) bit set, and such a response may be missing records. A
// truncated UDP answer is therefore repeated over TCP with the same timeout.
func dnsQuery(client *dns.Client, server string, domain string, qtype uint16) ([]net.IP, error) {
	m := new(dns.Msg)
	m.SetQuestion(dns.Fqdn(domain), qtype)
	m.RecursionDesired = true

	resp, _, err := client.Exchange(m, server)
	if err != nil {
		return nil, err
	}
	// An empty Net selects UDP in miekg/dns; "udp4"/"udp6" are UDP as well.
	if resp.Truncated && (client.Net == "" || strings.HasPrefix(client.Net, "udp")) {
		tcpClient := &dns.Client{Net: "tcp", Timeout: client.Timeout}
		resp, _, err = tcpClient.Exchange(m, server)
		if err != nil {
			return nil, fmt.Errorf("UDP响应被截断, TCP重试失败: %w", err)
		}
	}
	if resp.Rcode != dns.RcodeSuccess {
		return nil, fmt.Errorf("DNS查询返回错误码: %d", resp.Rcode)
	}

	var ips []net.IP
	for _, ans := range resp.Answer {
		switch rr := ans.(type) {
		case *dns.A:
			ips = append(ips, rr.A)
		case *dns.AAAA:
			ips = append(ips, rr.AAAA)
		}
	}
	return ips, nil
}
// Resolve looks up domain over UDP. A records are queried first; an AAAA
// query is sent only when the A query succeeds without any addresses.
func (r *UDPResolver) Resolve(domain string) ([]net.IP, error) {
	addrs, err := dnsQuery(r.client, r.server, domain, dns.TypeA)
	if err != nil {
		return nil, fmt.Errorf("UDP DNS A查询失败: %w", err)
	}
	if len(addrs) > 0 {
		return addrs, nil
	}
	// No IPv4 answer: fall back to IPv6.
	if addrs, err = dnsQuery(r.client, r.server, domain, dns.TypeAAAA); err != nil {
		return nil, fmt.Errorf("UDP DNS AAAA查询失败: %w", err)
	}
	if len(addrs) == 0 {
		return nil, fmt.Errorf("未找到域名 %s 的IP地址", domain)
	}
	return addrs, nil
}
// Close implements Resolver. The UDP resolver keeps no connection of its own,
// so there is nothing to release and the result is always nil.
func (r *UDPResolver) Close() error { return nil }
// TCPResolver sends plain DNS queries over TCP.
type TCPResolver struct {
	BaseResolver             // upstream address and timeout
	client       *dns.Client // miekg/dns client built by NewTCPResolver
}
// NewTCPResolver builds a TCPResolver for cfg.Server. cfg.Timeout is taken as
// a number of seconds and applied to the underlying DNS client. It never
// returns an error.
func NewTCPResolver(cfg config.DNSResolverConfig) (*TCPResolver, error) {
	timeout := time.Duration(cfg.Timeout) * time.Second
	r := &TCPResolver{client: &dns.Client{Net: "tcp", Timeout: timeout}}
	r.server = cfg.Server
	r.timeout = timeout
	return r, nil
}
// Resolve looks up domain over TCP. A records are queried first; an AAAA
// query is sent only when the A query succeeds without any addresses.
func (r *TCPResolver) Resolve(domain string) ([]net.IP, error) {
	addrs, err := dnsQuery(r.client, r.server, domain, dns.TypeA)
	if err != nil {
		return nil, fmt.Errorf("TCP DNS A查询失败: %w", err)
	}
	if len(addrs) > 0 {
		return addrs, nil
	}
	// No IPv4 answer: fall back to IPv6.
	if addrs, err = dnsQuery(r.client, r.server, domain, dns.TypeAAAA); err != nil {
		return nil, fmt.Errorf("TCP DNS AAAA查询失败: %w", err)
	}
	if len(addrs) == 0 {
		return nil, fmt.Errorf("未找到域名 %s 的IP地址", domain)
	}
	return addrs, nil
}
// Close implements Resolver. The TCP resolver keeps no connection of its own
// between queries, so there is nothing to release and the result is always nil.
func (r *TCPResolver) Close() error { return nil }
// DoTResolver sends DNS queries over TLS (DNS over TLS, "tcp-tls").
type DoTResolver struct {
	BaseResolver             // upstream address and timeout
	client       *dns.Client // miekg/dns client built by NewDoTResolver
}
// NewDoTResolver builds a DoTResolver for cfg.Server. cfg.Timeout is taken as
// a number of seconds. The TLS configuration pins TLS 1.2 as the minimum
// version; every other TLS setting is left at Go's zero-value defaults. It
// never returns an error.
func NewDoTResolver(cfg config.DNSResolverConfig) (*DoTResolver, error) {
	timeout := time.Duration(cfg.Timeout) * time.Second
	tlsCfg := &tls.Config{MinVersion: tls.VersionTLS12}
	client := &dns.Client{Net: "tcp-tls", Timeout: timeout, TLSConfig: tlsCfg}

	r := &DoTResolver{client: client}
	r.server, r.timeout = cfg.Server, timeout
	return r, nil
}
// Resolve looks up domain over DNS-over-TLS. A records are queried first; an
// AAAA query is sent only when the A query succeeds without any addresses.
func (r *DoTResolver) Resolve(domain string) ([]net.IP, error) {
	addrs, err := dnsQuery(r.client, r.server, domain, dns.TypeA)
	if err != nil {
		return nil, fmt.Errorf("DoT DNS A查询失败: %w", err)
	}
	if len(addrs) > 0 {
		return addrs, nil
	}
	// No IPv4 answer: fall back to IPv6.
	if addrs, err = dnsQuery(r.client, r.server, domain, dns.TypeAAAA); err != nil {
		return nil, fmt.Errorf("DoT DNS AAAA查询失败: %w", err)
	}
	if len(addrs) == 0 {
		return nil, fmt.Errorf("未找到域名 %s 的IP地址", domain)
	}
	return addrs, nil
}
// Close implements Resolver. The DoT resolver keeps no connection of its own
// between queries, so there is nothing to release and the result is always nil.
func (r *DoTResolver) Close() error { return nil }
// DoHResolver resolves names over HTTPS (DNS over HTTPS, JSON API).
type DoHResolver struct {
	BaseResolver              // endpoint URL/host and timeout
	client       *http.Client // reused for every request; Close drops its idle connections
}
// NewDoHResolver builds a DoHResolver that talks to cfg.Server over HTTPS.
// cfg.Timeout is taken as a number of seconds and is used as the HTTP client's
// per-request timeout. TLS 1.2 is the minimum accepted TLS version, and idle
// keep-alive connections are dropped after 30 seconds. It never returns an
// error.
func NewDoHResolver(cfg config.DNSResolverConfig) (*DoHResolver, error) {
	timeout := time.Duration(cfg.Timeout) * time.Second
	return &DoHResolver{
		BaseResolver: BaseResolver{
			server:  cfg.Server,
			timeout: timeout,
		},
		client: &http.Client{
			Timeout: timeout,
			Transport: &http.Transport{
				TLSClientConfig: &tls.Config{
					MinVersion: tls.VersionTLS12,
				},
				// A custom TLSClientConfig turns off net/http's automatic HTTP/2
				// support; opt back in so concurrent lookups can share one
				// multiplexed connection.
				ForceAttemptHTTP2: true,
				IdleConnTimeout:   30 * time.Second,
			},
		},
	}, nil
}
// dohResponse mirrors the JSON body of an "application/dns-json" DoH reply.
// The resolver only inspects Status and Answer; the other fields are decoded
// but not used in this file.
type dohResponse struct {
	// Status is the DNS response code; the resolver treats anything but 0 as
	// a failure.
	Status int `json:"Status"`

	// Header flags as reported by the server.
	TC bool `json:"TC"`
	RD bool `json:"RD"`
	RA bool `json:"RA"`
	AD bool `json:"AD"`
	CD bool `json:"CD"`

	// Question echoes the name and numeric type that were queried.
	Question []struct {
		Name string `json:"name"`
		Type int    `json:"type"`
	} `json:"Question"`

	// Answer lists the returned records. Type is the numeric record type
	// (the resolver looks for 1 = A and 28 = AAAA) and Data is the record
	// value as text.
	Answer []struct {
		Name string `json:"name"`
		Type int    `json:"type"`
		TTL  int    `json:"TTL"`
		Data string `json:"data"`
	} `json:"Answer"`
}
// Resolve looks up domain through the DoH JSON API (Accept:
// application/dns-json). It queries A records first and only issues an AAAA
// query when the A query succeeds without any addresses. Both requests share
// one deadline derived from r.timeout.
func (r *DoHResolver) Resolve(domain string) ([]net.IP, error) {
	timeout := r.timeout
	if timeout <= 0 {
		// context.WithTimeout with a zero/negative duration yields an
		// already-expired context, which would make every request fail.
		// Fall back to a bounded default instead.
		timeout = 5 * time.Second
	}
	ctx, cancel := context.WithTimeout(context.Background(), timeout)
	defer cancel()

	ips, err := r.queryJSON(ctx, domain, "A", dns.TypeA, "DoH")
	if err != nil {
		return nil, err
	}
	// No IPv4 answer: fall back to IPv6.
	if len(ips) == 0 {
		ips, err = r.queryJSON(ctx, domain, "AAAA", dns.TypeAAAA, "DoH AAAA")
		if err != nil {
			return nil, err
		}
	}
	if len(ips) == 0 {
		return nil, fmt.Errorf("未找到域名 %s 的IP地址", domain)
	}
	return ips, nil
}

// queryJSON performs a single DoH JSON query for domain with record type
// typeName. It returns the parseable addresses of the answers whose numeric
// type equals rrType. label prefixes every error message ("DoH" or
// "DoH AAAA").
func (r *DoHResolver) queryJSON(ctx context.Context, domain, typeName string, rrType uint16, label string) ([]net.IP, error) {
	// Upper bound on the response body so a misbehaving server cannot make us
	// buffer an arbitrarily large payload.
	const maxBodySize = 1 << 20

	endpoint := r.server
	if !strings.HasPrefix(endpoint, "https://") {
		endpoint = "https://" + endpoint
	}
	req, err := http.NewRequestWithContext(ctx, http.MethodGet, endpoint, nil)
	if err != nil {
		return nil, fmt.Errorf("创建%s请求失败: %w", label, err)
	}
	// Set name/type through the parsed query so the (untrusted) domain is
	// percent-escaped. Any parameters already present in the configured URL
	// are kept, and stale name/type values are overridden.
	q := req.URL.Query()
	q.Set("name", domain)
	q.Set("type", typeName)
	req.URL.RawQuery = q.Encode()
	req.Header.Set("Accept", "application/dns-json")

	resp, err := r.client.Do(req)
	if err != nil {
		return nil, fmt.Errorf("%s请求失败: %w", label, err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusOK {
		return nil, fmt.Errorf("%s请求返回非200状态码: %d", label, resp.StatusCode)
	}

	body, err := io.ReadAll(io.LimitReader(resp.Body, maxBodySize))
	if err != nil {
		return nil, fmt.Errorf("读取%s响应失败: %w", label, err)
	}
	var out dohResponse
	if err := json.Unmarshal(body, &out); err != nil {
		return nil, fmt.Errorf("解析%s响应失败: %w", label, err)
	}
	if out.Status != 0 {
		return nil, fmt.Errorf("%s查询返回错误码: %d", label, out.Status)
	}

	var ips []net.IP
	for _, ans := range out.Answer {
		if ans.Type != int(rrType) {
			continue
		}
		if ip := net.ParseIP(ans.Data); ip != nil {
			ips = append(ips, ip)
		}
	}
	return ips, nil
}
// Close drops the idle keep-alive connections held by the resolver's HTTP
// client. It always returns nil.
func (r *DoHResolver) Close() error {
	c := r.client
	c.CloseIdleConnections()
	return nil
}
// DoQResolver is the DNS-over-QUIC resolver. It only stores the shared
// settings; no QUIC transport is implemented yet.
type DoQResolver struct {
	BaseResolver // upstream address and timeout
}
// NewDoQResolver records cfg.Server and cfg.Timeout (taken as seconds) in a
// new DoQResolver. It never returns an error.
func NewDoQResolver(cfg config.DNSResolverConfig) (*DoQResolver, error) {
	r := new(DoQResolver)
	r.server = cfg.Server
	r.timeout = time.Duration(cfg.Timeout) * time.Second
	return r, nil
}
// Resolve is a placeholder. DNS over QUIC needs a QUIC transport that is not
// implemented here, so every call fails without sending any query.
func (r *DoQResolver) Resolve(domain string) ([]net.IP, error) {
	notImplemented := fmt.Errorf("DoQ协议暂未实现")
	return nil, notImplemented
}
// Close implements Resolver. The DoQ placeholder holds no resources, so the
// result is always nil.
func (r *DoQResolver) Close() error { return nil }