399 lines
8.8 KiB
Go
399 lines
8.8 KiB
Go
package dns
|
||
|
||
import (
|
||
"context"
|
||
"crypto/tls"
|
||
"encoding/json"
|
||
"fmt"
|
||
"io"
|
||
"net"
|
||
"net/http"
|
||
"strings"
|
||
"time"
|
||
|
||
"github.com/SNI_Proxy/config"
|
||
"github.com/miekg/dns"
|
||
)
|
||
|
||
// Resolver 是DNS解析器接口
|
||
type Resolver interface {
|
||
// Resolve 解析域名为IP地址
|
||
Resolve(domain string) ([]net.IP, error)
|
||
// Close 关闭解析器
|
||
Close() error
|
||
}
|
||
|
||
// 基础DNS解析器结构
|
||
type BaseResolver struct {
|
||
server string
|
||
timeout time.Duration
|
||
}
|
||
|
||
// 创建DNS解析器
|
||
func NewResolver(cfg config.DNSResolverConfig) (Resolver, error) {
|
||
switch cfg.Protocol {
|
||
case config.DNSProtocolUDP:
|
||
return NewUDPResolver(cfg)
|
||
case config.DNSProtocolTCP:
|
||
return NewTCPResolver(cfg)
|
||
case config.DNSProtocolDoT:
|
||
return NewDoTResolver(cfg)
|
||
case config.DNSProtocolDoH:
|
||
return NewDoHResolver(cfg)
|
||
case config.DNSProtocolDoQ:
|
||
return NewDoQResolver(cfg)
|
||
default:
|
||
return nil, fmt.Errorf("不支持的DNS协议: %s", cfg.Protocol)
|
||
}
|
||
}
|
||
|
||
// UDP DNS解析器
|
||
type UDPResolver struct {
|
||
BaseResolver
|
||
client *dns.Client
|
||
}
|
||
|
||
func NewUDPResolver(cfg config.DNSResolverConfig) (*UDPResolver, error) {
|
||
timeout := time.Duration(cfg.Timeout) * time.Second
|
||
return &UDPResolver{
|
||
BaseResolver: BaseResolver{
|
||
server: cfg.Server,
|
||
timeout: timeout,
|
||
},
|
||
client: &dns.Client{
|
||
Net: "udp",
|
||
Timeout: timeout,
|
||
},
|
||
}, nil
|
||
}
|
||
|
||
// 执行DNS查询
|
||
func dnsQuery(client *dns.Client, server string, domain string, qtype uint16) ([]net.IP, error) {
|
||
m := new(dns.Msg)
|
||
m.SetQuestion(dns.Fqdn(domain), qtype)
|
||
m.RecursionDesired = true
|
||
|
||
resp, _, err := client.Exchange(m, server)
|
||
if err != nil {
|
||
return nil, err
|
||
}
|
||
|
||
if resp.Rcode != dns.RcodeSuccess {
|
||
return nil, fmt.Errorf("DNS查询返回错误码: %d", resp.Rcode)
|
||
}
|
||
|
||
var ips []net.IP
|
||
for _, ans := range resp.Answer {
|
||
switch rr := ans.(type) {
|
||
case *dns.A:
|
||
ips = append(ips, rr.A)
|
||
case *dns.AAAA:
|
||
ips = append(ips, rr.AAAA)
|
||
}
|
||
}
|
||
|
||
return ips, nil
|
||
}
|
||
|
||
func (r *UDPResolver) Resolve(domain string) ([]net.IP, error) {
|
||
// 先尝试A记录
|
||
ips, err := dnsQuery(r.client, r.server, domain, dns.TypeA)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("UDP DNS A查询失败: %w", err)
|
||
}
|
||
|
||
// 如果没有A记录,尝试AAAA记录
|
||
if len(ips) == 0 {
|
||
ips, err = dnsQuery(r.client, r.server, domain, dns.TypeAAAA)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("UDP DNS AAAA查询失败: %w", err)
|
||
}
|
||
}
|
||
|
||
if len(ips) == 0 {
|
||
return nil, fmt.Errorf("未找到域名 %s 的IP地址", domain)
|
||
}
|
||
|
||
return ips, nil
|
||
}
|
||
|
||
func (r *UDPResolver) Close() error {
|
||
return nil // UDP没有需要关闭的资源
|
||
}
|
||
|
||
// TCP DNS解析器
|
||
type TCPResolver struct {
|
||
BaseResolver
|
||
client *dns.Client
|
||
}
|
||
|
||
func NewTCPResolver(cfg config.DNSResolverConfig) (*TCPResolver, error) {
|
||
timeout := time.Duration(cfg.Timeout) * time.Second
|
||
return &TCPResolver{
|
||
BaseResolver: BaseResolver{
|
||
server: cfg.Server,
|
||
timeout: timeout,
|
||
},
|
||
client: &dns.Client{
|
||
Net: "tcp",
|
||
Timeout: timeout,
|
||
},
|
||
}, nil
|
||
}
|
||
|
||
func (r *TCPResolver) Resolve(domain string) ([]net.IP, error) {
|
||
// 先尝试A记录
|
||
ips, err := dnsQuery(r.client, r.server, domain, dns.TypeA)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("TCP DNS A查询失败: %w", err)
|
||
}
|
||
|
||
// 如果没有A记录,尝试AAAA记录
|
||
if len(ips) == 0 {
|
||
ips, err = dnsQuery(r.client, r.server, domain, dns.TypeAAAA)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("TCP DNS AAAA查询失败: %w", err)
|
||
}
|
||
}
|
||
|
||
if len(ips) == 0 {
|
||
return nil, fmt.Errorf("未找到域名 %s 的IP地址", domain)
|
||
}
|
||
|
||
return ips, nil
|
||
}
|
||
|
||
func (r *TCPResolver) Close() error {
|
||
return nil // TCP没有需要关闭的资源
|
||
}
|
||
|
||
// DoT DNS解析器 (DNS over TLS)
|
||
type DoTResolver struct {
|
||
BaseResolver
|
||
client *dns.Client
|
||
}
|
||
|
||
func NewDoTResolver(cfg config.DNSResolverConfig) (*DoTResolver, error) {
|
||
timeout := time.Duration(cfg.Timeout) * time.Second
|
||
return &DoTResolver{
|
||
BaseResolver: BaseResolver{
|
||
server: cfg.Server,
|
||
timeout: timeout,
|
||
},
|
||
client: &dns.Client{
|
||
Net: "tcp-tls",
|
||
Timeout: timeout,
|
||
TLSConfig: &tls.Config{
|
||
MinVersion: tls.VersionTLS12,
|
||
},
|
||
},
|
||
}, nil
|
||
}
|
||
|
||
func (r *DoTResolver) Resolve(domain string) ([]net.IP, error) {
|
||
// 先尝试A记录
|
||
ips, err := dnsQuery(r.client, r.server, domain, dns.TypeA)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("DoT DNS A查询失败: %w", err)
|
||
}
|
||
|
||
// 如果没有A记录,尝试AAAA记录
|
||
if len(ips) == 0 {
|
||
ips, err = dnsQuery(r.client, r.server, domain, dns.TypeAAAA)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("DoT DNS AAAA查询失败: %w", err)
|
||
}
|
||
}
|
||
|
||
if len(ips) == 0 {
|
||
return nil, fmt.Errorf("未找到域名 %s 的IP地址", domain)
|
||
}
|
||
|
||
return ips, nil
|
||
}
|
||
|
||
func (r *DoTResolver) Close() error {
|
||
return nil // DoT没有需要关闭的资源
|
||
}
|
||
|
||
// DoH DNS解析器 (DNS over HTTPS)
|
||
type DoHResolver struct {
|
||
BaseResolver
|
||
client *http.Client
|
||
}
|
||
|
||
func NewDoHResolver(cfg config.DNSResolverConfig) (*DoHResolver, error) {
|
||
timeout := time.Duration(cfg.Timeout) * time.Second
|
||
return &DoHResolver{
|
||
BaseResolver: BaseResolver{
|
||
server: cfg.Server,
|
||
timeout: timeout,
|
||
},
|
||
client: &http.Client{
|
||
Timeout: timeout,
|
||
Transport: &http.Transport{
|
||
TLSClientConfig: &tls.Config{
|
||
MinVersion: tls.VersionTLS12,
|
||
},
|
||
IdleConnTimeout: 30 * time.Second,
|
||
},
|
||
},
|
||
}, nil
|
||
}
|
||
|
||
type dohResponse struct {
|
||
Status int `json:"Status"`
|
||
TC bool `json:"TC"`
|
||
RD bool `json:"RD"`
|
||
RA bool `json:"RA"`
|
||
AD bool `json:"AD"`
|
||
CD bool `json:"CD"`
|
||
Question []struct {
|
||
Name string `json:"name"`
|
||
Type int `json:"type"`
|
||
} `json:"Question"`
|
||
Answer []struct {
|
||
Name string `json:"name"`
|
||
Type int `json:"type"`
|
||
TTL int `json:"TTL"`
|
||
Data string `json:"data"`
|
||
} `json:"Answer"`
|
||
}
|
||
|
||
func (r *DoHResolver) Resolve(domain string) ([]net.IP, error) {
|
||
// 构建DoH请求URL
|
||
url := r.server
|
||
if !strings.HasPrefix(url, "https://") {
|
||
url = "https://" + url
|
||
}
|
||
if !strings.Contains(url, "?") {
|
||
url += "?name=" + domain + "&type=A"
|
||
}
|
||
|
||
// 发送请求
|
||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||
defer cancel()
|
||
|
||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("创建DoH请求失败: %w", err)
|
||
}
|
||
req.Header.Set("Accept", "application/dns-json")
|
||
|
||
resp, err := r.client.Do(req)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("DoH请求失败: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
return nil, fmt.Errorf("DoH请求返回非200状态码: %d", resp.StatusCode)
|
||
}
|
||
|
||
body, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("读取DoH响应失败: %w", err)
|
||
}
|
||
|
||
var dohResp dohResponse
|
||
if err := json.Unmarshal(body, &dohResp); err != nil {
|
||
return nil, fmt.Errorf("解析DoH响应失败: %w", err)
|
||
}
|
||
|
||
if dohResp.Status != 0 {
|
||
return nil, fmt.Errorf("DoH查询返回错误码: %d", dohResp.Status)
|
||
}
|
||
|
||
var ips []net.IP
|
||
for _, ans := range dohResp.Answer {
|
||
if ans.Type == 1 { // A记录
|
||
ip := net.ParseIP(ans.Data)
|
||
if ip != nil {
|
||
ips = append(ips, ip)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 如果没有A记录,尝试AAAA记录
|
||
if len(ips) == 0 {
|
||
url := r.server
|
||
if !strings.HasPrefix(url, "https://") {
|
||
url = "https://" + url
|
||
}
|
||
if !strings.Contains(url, "?") {
|
||
url += "?name=" + domain + "&type=AAAA"
|
||
} else {
|
||
url = strings.Replace(url, "type=A", "type=AAAA", 1)
|
||
}
|
||
|
||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("创建DoH AAAA请求失败: %w", err)
|
||
}
|
||
req.Header.Set("Accept", "application/dns-json")
|
||
|
||
resp, err := r.client.Do(req)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("DoH AAAA请求失败: %w", err)
|
||
}
|
||
defer resp.Body.Close()
|
||
|
||
if resp.StatusCode != http.StatusOK {
|
||
return nil, fmt.Errorf("DoH AAAA请求返回非200状态码: %d", resp.StatusCode)
|
||
}
|
||
|
||
body, err := io.ReadAll(resp.Body)
|
||
if err != nil {
|
||
return nil, fmt.Errorf("读取DoH AAAA响应失败: %w", err)
|
||
}
|
||
|
||
if err := json.Unmarshal(body, &dohResp); err != nil {
|
||
return nil, fmt.Errorf("解析DoH AAAA响应失败: %w", err)
|
||
}
|
||
|
||
for _, ans := range dohResp.Answer {
|
||
if ans.Type == 28 { // AAAA记录
|
||
ip := net.ParseIP(ans.Data)
|
||
if ip != nil {
|
||
ips = append(ips, ip)
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
if len(ips) == 0 {
|
||
return nil, fmt.Errorf("未找到域名 %s 的IP地址", domain)
|
||
}
|
||
|
||
return ips, nil
|
||
}
|
||
|
||
func (r *DoHResolver) Close() error {
|
||
r.client.CloseIdleConnections()
|
||
return nil
|
||
}
|
||
|
||
// DoQ DNS解析器 (DNS over QUIC)
|
||
type DoQResolver struct {
|
||
BaseResolver
|
||
}
|
||
|
||
func NewDoQResolver(cfg config.DNSResolverConfig) (*DoQResolver, error) {
|
||
return &DoQResolver{
|
||
BaseResolver: BaseResolver{
|
||
server: cfg.Server,
|
||
timeout: time.Duration(cfg.Timeout) * time.Second,
|
||
},
|
||
}, nil
|
||
}
|
||
|
||
func (r *DoQResolver) Resolve(domain string) ([]net.IP, error) {
|
||
// DoQ实现较为复杂,需要QUIC协议支持
|
||
// 这里仅作为占位符,实际项目中需要实现
|
||
return nil, fmt.Errorf("DoQ协议暂未实现")
|
||
}
|
||
|
||
func (r *DoQResolver) Close() error {
|
||
return nil
|
||
}
|