commit da892685bc4e6ed96a56dc5f1d1b4adf08afec23 Author: jianghanxin Date: Mon Mar 17 20:01:21 2025 +0800 update diff --git a/README.md b/README.md new file mode 100644 index 0000000..d0ab216 --- /dev/null +++ b/README.md @@ -0,0 +1,149 @@ +# SNI 代理 + +这是一个基于 Go 语言实现的 SNI 代理,支持 TLS 分片功能和多种域名匹配方法。 + +## 功能特点 + +- 直接使用 SNI 域名作为目标地址 +- 支持多种 DNS 协议解析域名(UDP/TCP/DoT/DoH/DoQ) +- 支持 TLS 分片(可配置分片大小范围) +- 多种域名匹配方式: + - 正则表达式匹配 + - 后缀匹配 + - 关键词匹配 +- 完善的超时控制和连接管理 +- 基于配置文件的灵活配置 + +## 安装 + +### 从源码构建 + +```bash +# 克隆仓库 +git clone https://github.com/yourusername/SNI_Proxy.git +cd SNI_Proxy + +# 安装依赖 +go mod tidy + +# 构建 +go build -o sni_proxy +``` + +## 使用方法 + +1. 创建配置文件 `config.yaml`(参考示例配置文件) +2. 运行代理服务器: + +```bash +./sni_proxy -config config.yaml +``` + +## 配置文件说明 + +配置文件使用 YAML 格式,包含以下主要部分: + +```yaml +# 监听地址 +listen: "0.0.0.0:443" + +# 默认目标端口 +default_port: 443 + +# DNS解析器配置 +dns: + protocol: "udp" # 支持 udp, tcp, dot, doh, doq + server: "8.8.8.8:53" # DNS服务器地址 + timeout: 5 # 超时时间(秒) + +# 超时配置 +timeout: + connect: 10 # 连接超时(秒) + read: 30 # 读取超时(秒) + write: 5 # 写入超时(秒) + idle: 60 # 空闲超时(秒) + lifetime: 300 # 连接最大生命周期(秒) + +# 日志级别: debug, info, warn, error +log_level: "info" + +# 最大并发连接数 +max_conns: 1000 + +# 代理规则 +rules: + # 规则示例 + - domains: + - type: regexp # 支持 regexp, suffix, keyword + value: "^api\\.example\\.com$" + port: 443 # 目标端口 + fragment: + enabled: true # 是否启用 TLS 分片 + min_size: 100 # 最小分片大小 + max_size: 500 # 最大分片大小 +``` + +### 工作原理 + +1. 代理服务器接收客户端的 TLS 连接 +2. 从 ClientHello 消息中提取 SNI(服务器名称指示) +3. 根据配置的规则匹配 SNI 域名 +4. 使用配置的 DNS 解析器将 SNI 域名解析为 IP 地址 +5. 使用解析得到的 IP 地址和配置的端口连接到目标服务器 +6. 将客户端的请求转发到目标服务器,可选择进行 TLS 分片 + +### 域名匹配方式 + +1. **正则表达式匹配** (`regexp`):使用正则表达式匹配域名 +2. **后缀匹配** (`suffix`):匹配指定后缀的域名 +3. **关键词匹配** (`keyword`):匹配包含指定关键词的域名 + +### DNS 解析器配置 + +- `protocol`: DNS 协议类型,支持以下选项: + - `udp`: 标准 UDP DNS(默认) + - `tcp`: 标准 TCP DNS + - `dot`: DNS over TLS + - `doh`: DNS over HTTPS + - `doq`: DNS over QUIC(目前不支持) +- `server`: DNS 服务器地址,格式为 `IP:端口` +- `timeout`: DNS 查询超时时间(秒) + +### 超时和连接管理配置 + +- `timeout`: 超时配置 + - `connect`: 连接目标服务器的超时时间(秒) + - `read`: 读取数据的超时时间(秒) + - `write`: 写入数据的超时时间(秒) + - `idle`: 空闲连接的超时时间(秒) + - `lifetime`: 连接的最大生命周期(秒) +- `log_level`: 日志级别,支持 debug、info、warn、error +- `max_conns`: 最大并发连接数,超过此数量的新连接将被拒绝 + +### TLS 分片配置 + +- `enabled`: 是否启用 TLS 分片 +- `min_size`: 最小分片大小(字节) +- `max_size`: 最大分片大小(字节) + +## 性能优化 + +为了避免进程卡死和资源泄漏,本代理实现了以下优化: + +1. **完善的超时控制**:为每个连接阶段设置合理的超时时间 +2. **连接生命周期管理**:限制连接的最大生存时间 +3. **连接数量限制**:防止过多连接导致资源耗尽 +4. **空闲连接清理**:定期清理空闲连接 +5. **错误处理优化**:优化常见网络错误的处理方式 +6. **资源使用监控**:定期记录连接统计信息 + +## 注意事项 + +- 本程序需要以 root 权限运行才能监听 443 端口 +- 确保目标服务器地址可达 +- 正则表达式需要使用有效的 Go 语言正则表达式语法 +- 适当调整超时和连接管理参数,以适应不同的网络环境 + +## 许可证 + +MIT \ No newline at end of file diff --git a/SNI_Proxy b/SNI_Proxy new file mode 100755 index 0000000..733743e Binary files /dev/null and b/SNI_Proxy differ diff --git a/config.yaml b/config.yaml new file mode 100644 index 0000000..0866177 --- /dev/null +++ b/config.yaml @@ -0,0 +1,74 @@ +# SNI代理配置文件 + +# 基本配置 +listen: "0.0.0.0:443" +default_port: 443 +max_conns: 1000 +log_level: "info" # 可选: debug, info, warn, error + +# DNS解析器配置 +dns: + protocol: "udp" # 支持: udp, tcp, dot, doh, doq + server: "8.8.8.8:53" + timeout: 5 + +# 超时配置(秒) +timeout: + connect: 10 + read: 30 + write: 5 + idle: 60 + lifetime: 300 + +# 代理规则 +rules: + # Cloudflare DNS规则 + - domains: + - type: suffix + value: "cloudflare-dns.com" + port: 443 + fragment: + enabled: true + min_size: 10 + max_size: 50 + delay_min: 10 + delay_max: 30 + validate: true + + # API规则 + - domains: + - type: regexp + value: "^api\\.example\\.com$" + port: 443 + fragment: + enabled: true + min_size: 100 + max_size: 500 + delay_min: 5 + delay_max: 15 + validate: true + + # 自定义端口规则 + - domains: + - type: suffix + value: ".example.org" + port: 8443 + fragment: + enabled: false= + + # 关键词匹配规则 + - domains: + - type: keyword + value: "google" + - type: keyword + value: "youtube" + - type: keyword + value: "cloudflare" + port: 443 + fragment: + enabled: true + min_size: 100 + max_size: 200 + delay_min: 8 + delay_max: 25 + validate: true \ No newline at end of file diff --git a/config/config.go b/config/config.go new file mode 100644 index 0000000..32fcc2f --- /dev/null +++ b/config/config.go @@ -0,0 +1,246 @@ +package config + +import ( + "fmt" + "os" + "regexp" + + "gopkg.in/yaml.v3" +) + +// 域名匹配规则类型 +const ( + MatchTypeRegexp = "regexp" + MatchTypeSuffix = "suffix" + MatchTypeKeyword = "keyword" +) + +// DNS协议类型 +const ( + DNSProtocolUDP = "udp" + DNSProtocolTCP = "tcp" + DNSProtocolDoT = "dot" // DNS over TLS + DNSProtocolDoH = "doh" // DNS over HTTPS + DNSProtocolDoQ = "doq" // DNS over QUIC +) + +// 默认配置值 +const ( + DefaultPort = 443 + DefaultDNSProtocol = DNSProtocolUDP + DefaultDNSServer = "8.8.8.8:53" + DefaultDNSTimeout = 5 + DefaultConnectTimeout = 10 + DefaultReadTimeout = 30 + DefaultWriteTimeout = 5 + DefaultIdleTimeout = 60 + DefaultLifeTime = 300 + DefaultLogLevel = "info" + DefaultMaxConns = 1000 + DefaultMinFragSize = 10 + DefaultMaxFragSize = 100 + DefaultMinDelay = 10 + DefaultMaxDelay = 30 +) + +// 超时配置 +type TimeoutConfig struct { + Connect int `yaml:"connect"` // 连接超时(秒) + Read int `yaml:"read"` // 读取超时(秒) + Write int `yaml:"write"` // 写入超时(秒) + Idle int `yaml:"idle"` // 空闲超时(秒) + LifeTime int `yaml:"lifetime"` // 连接最大生命周期(秒) +} + +// TLS分片配置 +type FragmentConfig struct { + Enabled bool `yaml:"enabled"` // 是否启用TLS分片 + MinSize int `yaml:"min_size"` // 最小分片大小(字节) + MaxSize int `yaml:"max_size"` // 最大分片大小(字节) + DelayMin int `yaml:"delay_min"` // 分片之间的最小延迟(毫秒) + DelayMax int `yaml:"delay_max"` // 分片之间的最大延迟(毫秒) + Validate bool `yaml:"validate"` // 是否验证TLS记录完整性 +} + +// DNS解析器配置 +type DNSResolverConfig struct { + Protocol string `yaml:"protocol"` // udp, tcp, dot, doh, doq + Server string `yaml:"server"` // 服务器地址,如 8.8.8.8:53, 1.1.1.1:853 + Timeout int `yaml:"timeout"` // 超时时间(秒) +} + +// 域名匹配规则 +type DomainRule struct { + Type string `yaml:"type"` // regexp, suffix, keyword + Value string `yaml:"value"` // 匹配值 + + // 编译后的正则表达式(仅当Type为regexp时使用) + compiledRegexp *regexp.Regexp +} + +// 代理规则 +type ProxyRule struct { + Domains []DomainRule `yaml:"domains"` + Port int `yaml:"port"` // 目标端口,默认为443 + Fragment FragmentConfig `yaml:"fragment"` +} + +// 配置结构 +type Config struct { + Listen string `yaml:"listen"` + Rules []ProxyRule `yaml:"rules"` + DefaultPort int `yaml:"default_port"` // 默认目标端口,如果规则中未指定 + DNS DNSResolverConfig `yaml:"dns"` // DNS解析器配置 + Timeout TimeoutConfig `yaml:"timeout"` // 超时配置 + LogLevel string `yaml:"log_level"` // 日志级别:debug, info, warn, error + MaxConns int `yaml:"max_conns"` // 最大并发连接数 +} + +// 加载配置文件 +func LoadConfig(path string) (*Config, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, fmt.Errorf("读取配置文件失败: %w", err) + } + + var cfg Config + if err := yaml.Unmarshal(data, &cfg); err != nil { + return nil, fmt.Errorf("解析配置文件失败: %w", err) + } + + // 设置默认值 + setDefaultValues(&cfg) + + // 编译正则表达式 + if err := compileRegexps(&cfg); err != nil { + return nil, err + } + + return &cfg, nil +} + +// 设置默认值 +func setDefaultValues(cfg *Config) { + // 设置默认端口 + if cfg.DefaultPort == 0 { + cfg.DefaultPort = DefaultPort + } + + // 设置默认DNS配置 + if cfg.DNS.Protocol == "" { + cfg.DNS.Protocol = DefaultDNSProtocol + } + if cfg.DNS.Server == "" { + cfg.DNS.Server = DefaultDNSServer + } + if cfg.DNS.Timeout == 0 { + cfg.DNS.Timeout = DefaultDNSTimeout + } + + // 设置默认超时配置 + if cfg.Timeout.Connect == 0 { + cfg.Timeout.Connect = DefaultConnectTimeout + } + if cfg.Timeout.Read == 0 { + cfg.Timeout.Read = DefaultReadTimeout + } + if cfg.Timeout.Write == 0 { + cfg.Timeout.Write = DefaultWriteTimeout + } + if cfg.Timeout.Idle == 0 { + cfg.Timeout.Idle = DefaultIdleTimeout + } + if cfg.Timeout.LifeTime == 0 { + cfg.Timeout.LifeTime = DefaultLifeTime + } + + // 设置默认日志级别 + if cfg.LogLevel == "" { + cfg.LogLevel = DefaultLogLevel + } + + // 设置默认最大连接数 + if cfg.MaxConns == 0 { + cfg.MaxConns = DefaultMaxConns + } + + // 设置规则默认值 + for i := range cfg.Rules { + // 如果规则中未指定端口,使用默认端口 + if cfg.Rules[i].Port == 0 { + cfg.Rules[i].Port = cfg.DefaultPort + } + + // 设置TLS分片配置的默认值 + if cfg.Rules[i].Fragment.Enabled { + setFragmentDefaults(&cfg.Rules[i].Fragment) + } + } +} + +// 设置分片默认值 +func setFragmentDefaults(frag *FragmentConfig) { + // 设置默认的分片大小范围 + if frag.MinSize <= 0 { + frag.MinSize = DefaultMinFragSize + } + if frag.MaxSize <= 0 { + frag.MaxSize = DefaultMaxFragSize + } + if frag.MinSize > frag.MaxSize { + frag.MinSize = frag.MaxSize + } + + // 设置默认的分片延迟范围 + if frag.DelayMin <= 0 { + frag.DelayMin = DefaultMinDelay + } + if frag.DelayMax <= 0 { + frag.DelayMax = DefaultMaxDelay + } + if frag.DelayMin > frag.DelayMax { + frag.DelayMin = frag.DelayMax + } +} + +// 编译正则表达式 +func compileRegexps(cfg *Config) error { + for i := range cfg.Rules { + for j := range cfg.Rules[i].Domains { + if cfg.Rules[i].Domains[j].Type == MatchTypeRegexp { + re, err := regexp.Compile(cfg.Rules[i].Domains[j].Value) + if err != nil { + return fmt.Errorf("编译正则表达式失败 '%s': %w", + cfg.Rules[i].Domains[j].Value, err) + } + cfg.Rules[i].Domains[j].compiledRegexp = re + } + } + } + return nil +} + +// 检查域名是否匹配规则 +func (r *DomainRule) Match(domain string) bool { + switch r.Type { + case MatchTypeRegexp: + return r.compiledRegexp.MatchString(domain) + case MatchTypeSuffix: + return len(domain) >= len(r.Value) && + domain[len(domain)-len(r.Value):] == r.Value + case MatchTypeKeyword: + return contains(domain, r.Value) + default: + return false + } +} + +// 辅助函数:检查字符串是否包含子串 +func contains(s, substr string) bool { + for i := 0; i <= len(s)-len(substr); i++ { + if s[i:i+len(substr)] == substr { + return true + } + } + return false +} diff --git a/dns/resolver.go b/dns/resolver.go new file mode 100644 index 0000000..d49853b --- /dev/null +++ b/dns/resolver.go @@ -0,0 +1,398 @@ +package dns + +import ( + "context" + "crypto/tls" + "encoding/json" + "fmt" + "io" + "net" + "net/http" + "strings" + "time" + + "github.com/SNI_Proxy/config" + "github.com/miekg/dns" +) + +// Resolver 是DNS解析器接口 +type Resolver interface { + // Resolve 解析域名为IP地址 + Resolve(domain string) ([]net.IP, error) + // Close 关闭解析器 + Close() error +} + +// 基础DNS解析器结构 +type BaseResolver struct { + server string + timeout time.Duration +} + +// 创建DNS解析器 +func NewResolver(cfg config.DNSResolverConfig) (Resolver, error) { + switch cfg.Protocol { + case config.DNSProtocolUDP: + return NewUDPResolver(cfg) + case config.DNSProtocolTCP: + return NewTCPResolver(cfg) + case config.DNSProtocolDoT: + return NewDoTResolver(cfg) + case config.DNSProtocolDoH: + return NewDoHResolver(cfg) + case config.DNSProtocolDoQ: + return NewDoQResolver(cfg) + default: + return nil, fmt.Errorf("不支持的DNS协议: %s", cfg.Protocol) + } +} + +// UDP DNS解析器 +type UDPResolver struct { + BaseResolver + client *dns.Client +} + +func NewUDPResolver(cfg config.DNSResolverConfig) (*UDPResolver, error) { + timeout := time.Duration(cfg.Timeout) * time.Second + return &UDPResolver{ + BaseResolver: BaseResolver{ + server: cfg.Server, + timeout: timeout, + }, + client: &dns.Client{ + Net: "udp", + Timeout: timeout, + }, + }, nil +} + +// 执行DNS查询 +func dnsQuery(client *dns.Client, server string, domain string, qtype uint16) ([]net.IP, error) { + m := new(dns.Msg) + m.SetQuestion(dns.Fqdn(domain), qtype) + m.RecursionDesired = true + + resp, _, err := client.Exchange(m, server) + if err != nil { + return nil, err + } + + if resp.Rcode != dns.RcodeSuccess { + return nil, fmt.Errorf("DNS查询返回错误码: %d", resp.Rcode) + } + + var ips []net.IP + for _, ans := range resp.Answer { + switch rr := ans.(type) { + case *dns.A: + ips = append(ips, rr.A) + case *dns.AAAA: + ips = append(ips, rr.AAAA) + } + } + + return ips, nil +} + +func (r *UDPResolver) Resolve(domain string) ([]net.IP, error) { + // 先尝试A记录 + ips, err := dnsQuery(r.client, r.server, domain, dns.TypeA) + if err != nil { + return nil, fmt.Errorf("UDP DNS A查询失败: %w", err) + } + + // 如果没有A记录,尝试AAAA记录 + if len(ips) == 0 { + ips, err = dnsQuery(r.client, r.server, domain, dns.TypeAAAA) + if err != nil { + return nil, fmt.Errorf("UDP DNS AAAA查询失败: %w", err) + } + } + + if len(ips) == 0 { + return nil, fmt.Errorf("未找到域名 %s 的IP地址", domain) + } + + return ips, nil +} + +func (r *UDPResolver) Close() error { + return nil // UDP没有需要关闭的资源 +} + +// TCP DNS解析器 +type TCPResolver struct { + BaseResolver + client *dns.Client +} + +func NewTCPResolver(cfg config.DNSResolverConfig) (*TCPResolver, error) { + timeout := time.Duration(cfg.Timeout) * time.Second + return &TCPResolver{ + BaseResolver: BaseResolver{ + server: cfg.Server, + timeout: timeout, + }, + client: &dns.Client{ + Net: "tcp", + Timeout: timeout, + }, + }, nil +} + +func (r *TCPResolver) Resolve(domain string) ([]net.IP, error) { + // 先尝试A记录 + ips, err := dnsQuery(r.client, r.server, domain, dns.TypeA) + if err != nil { + return nil, fmt.Errorf("TCP DNS A查询失败: %w", err) + } + + // 如果没有A记录,尝试AAAA记录 + if len(ips) == 0 { + ips, err = dnsQuery(r.client, r.server, domain, dns.TypeAAAA) + if err != nil { + return nil, fmt.Errorf("TCP DNS AAAA查询失败: %w", err) + } + } + + if len(ips) == 0 { + return nil, fmt.Errorf("未找到域名 %s 的IP地址", domain) + } + + return ips, nil +} + +func (r *TCPResolver) Close() error { + return nil // TCP没有需要关闭的资源 +} + +// DoT DNS解析器 (DNS over TLS) +type DoTResolver struct { + BaseResolver + client *dns.Client +} + +func NewDoTResolver(cfg config.DNSResolverConfig) (*DoTResolver, error) { + timeout := time.Duration(cfg.Timeout) * time.Second + return &DoTResolver{ + BaseResolver: BaseResolver{ + server: cfg.Server, + timeout: timeout, + }, + client: &dns.Client{ + Net: "tcp-tls", + Timeout: timeout, + TLSConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + }, + }, + }, nil +} + +func (r *DoTResolver) Resolve(domain string) ([]net.IP, error) { + // 先尝试A记录 + ips, err := dnsQuery(r.client, r.server, domain, dns.TypeA) + if err != nil { + return nil, fmt.Errorf("DoT DNS A查询失败: %w", err) + } + + // 如果没有A记录,尝试AAAA记录 + if len(ips) == 0 { + ips, err = dnsQuery(r.client, r.server, domain, dns.TypeAAAA) + if err != nil { + return nil, fmt.Errorf("DoT DNS AAAA查询失败: %w", err) + } + } + + if len(ips) == 0 { + return nil, fmt.Errorf("未找到域名 %s 的IP地址", domain) + } + + return ips, nil +} + +func (r *DoTResolver) Close() error { + return nil // DoT没有需要关闭的资源 +} + +// DoH DNS解析器 (DNS over HTTPS) +type DoHResolver struct { + BaseResolver + client *http.Client +} + +func NewDoHResolver(cfg config.DNSResolverConfig) (*DoHResolver, error) { + timeout := time.Duration(cfg.Timeout) * time.Second + return &DoHResolver{ + BaseResolver: BaseResolver{ + server: cfg.Server, + timeout: timeout, + }, + client: &http.Client{ + Timeout: timeout, + Transport: &http.Transport{ + TLSClientConfig: &tls.Config{ + MinVersion: tls.VersionTLS12, + }, + IdleConnTimeout: 30 * time.Second, + }, + }, + }, nil +} + +type dohResponse struct { + Status int `json:"Status"` + TC bool `json:"TC"` + RD bool `json:"RD"` + RA bool `json:"RA"` + AD bool `json:"AD"` + CD bool `json:"CD"` + Question []struct { + Name string `json:"name"` + Type int `json:"type"` + } `json:"Question"` + Answer []struct { + Name string `json:"name"` + Type int `json:"type"` + TTL int `json:"TTL"` + Data string `json:"data"` + } `json:"Answer"` +} + +func (r *DoHResolver) Resolve(domain string) ([]net.IP, error) { + // 构建DoH请求URL + url := r.server + if !strings.HasPrefix(url, "https://") { + url = "https://" + url + } + if !strings.Contains(url, "?") { + url += "?name=" + domain + "&type=A" + } + + // 发送请求 + ctx, cancel := context.WithTimeout(context.Background(), r.timeout) + defer cancel() + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, fmt.Errorf("创建DoH请求失败: %w", err) + } + req.Header.Set("Accept", "application/dns-json") + + resp, err := r.client.Do(req) + if err != nil { + return nil, fmt.Errorf("DoH请求失败: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("DoH请求返回非200状态码: %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("读取DoH响应失败: %w", err) + } + + var dohResp dohResponse + if err := json.Unmarshal(body, &dohResp); err != nil { + return nil, fmt.Errorf("解析DoH响应失败: %w", err) + } + + if dohResp.Status != 0 { + return nil, fmt.Errorf("DoH查询返回错误码: %d", dohResp.Status) + } + + var ips []net.IP + for _, ans := range dohResp.Answer { + if ans.Type == 1 { // A记录 + ip := net.ParseIP(ans.Data) + if ip != nil { + ips = append(ips, ip) + } + } + } + + // 如果没有A记录,尝试AAAA记录 + if len(ips) == 0 { + url := r.server + if !strings.HasPrefix(url, "https://") { + url = "https://" + url + } + if !strings.Contains(url, "?") { + url += "?name=" + domain + "&type=AAAA" + } else { + url = strings.Replace(url, "type=A", "type=AAAA", 1) + } + + req, err := http.NewRequestWithContext(ctx, "GET", url, nil) + if err != nil { + return nil, fmt.Errorf("创建DoH AAAA请求失败: %w", err) + } + req.Header.Set("Accept", "application/dns-json") + + resp, err := r.client.Do(req) + if err != nil { + return nil, fmt.Errorf("DoH AAAA请求失败: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return nil, fmt.Errorf("DoH AAAA请求返回非200状态码: %d", resp.StatusCode) + } + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("读取DoH AAAA响应失败: %w", err) + } + + if err := json.Unmarshal(body, &dohResp); err != nil { + return nil, fmt.Errorf("解析DoH AAAA响应失败: %w", err) + } + + for _, ans := range dohResp.Answer { + if ans.Type == 28 { // AAAA记录 + ip := net.ParseIP(ans.Data) + if ip != nil { + ips = append(ips, ip) + } + } + } + } + + if len(ips) == 0 { + return nil, fmt.Errorf("未找到域名 %s 的IP地址", domain) + } + + return ips, nil +} + +func (r *DoHResolver) Close() error { + r.client.CloseIdleConnections() + return nil +} + +// DoQ DNS解析器 (DNS over QUIC) +type DoQResolver struct { + BaseResolver +} + +func NewDoQResolver(cfg config.DNSResolverConfig) (*DoQResolver, error) { + return &DoQResolver{ + BaseResolver: BaseResolver{ + server: cfg.Server, + timeout: time.Duration(cfg.Timeout) * time.Second, + }, + }, nil +} + +func (r *DoQResolver) Resolve(domain string) ([]net.IP, error) { + // DoQ实现较为复杂,需要QUIC协议支持 + // 这里仅作为占位符,实际项目中需要实现 + return nil, fmt.Errorf("DoQ协议暂未实现") +} + +func (r *DoQResolver) Close() error { + return nil +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..be335b8 --- /dev/null +++ b/go.mod @@ -0,0 +1,15 @@ +module github.com/SNI_Proxy + +go 1.20 + +require ( + github.com/miekg/dns v1.1.58 + gopkg.in/yaml.v3 v3.0.1 +) + +require ( + golang.org/x/mod v0.14.0 // indirect + golang.org/x/net v0.20.0 // indirect + golang.org/x/sys v0.16.0 // indirect + golang.org/x/tools v0.17.0 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..3b0de61 --- /dev/null +++ b/go.sum @@ -0,0 +1,15 @@ +github.com/miekg/dns v1.1.58 h1:ca2Hdkz+cDg/7eNF6V56jjzuZ4aCAE+DbVkILdQWG/4= +github.com/miekg/dns v1.1.58/go.mod h1:Ypv+3b/KadlvW9vJfXOTf300O4UqaHFzFCuHz+rPkBY= +golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0= +golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= +golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo= +golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY= +golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ= +golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU= +golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= +golang.org/x/tools v0.17.0 h1:FvmRgNOcs3kOa+T20R1uhfP9F6HgG2mfxDv1vrx1Htc= +golang.org/x/tools v0.17.0/go.mod h1:xsh6VxdV005rRVaS6SSAf9oiAqljS7UZUacMZ8Bnsps= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/init.txt b/init.txt new file mode 100644 index 0000000..9867d58 --- /dev/null +++ b/init.txt @@ -0,0 +1,3 @@ +give me golang example to make a SNI proxy, which can do tls fragement (configured in config file), the domain match should use regexp / suffix / keyword method. +Ask me more if needed. +curl --resolve "cloudflare-dns.com:443:127.0.0.1" -H "accept: application/dns-json" "https://cloudflare-dns.com/dns-query?name=example.com&type=A" diff --git a/main.go b/main.go new file mode 100644 index 0000000..9cec7ca --- /dev/null +++ b/main.go @@ -0,0 +1,46 @@ +package main + +import ( + "flag" + "log" + "os" + "os/signal" + "syscall" + + "github.com/SNI_Proxy/config" + "github.com/SNI_Proxy/proxy" +) + +func main() { + // 解析命令行参数 + configPath := flag.String("config", "config.yaml", "配置文件路径") + flag.Parse() + + // 加载配置 + cfg, err := config.LoadConfig(*configPath) + if err != nil { + log.Fatalf("加载配置失败: %v", err) + } + + // 创建并启动代理服务器 + server := proxy.NewServer(cfg) + if err := server.Start(); err != nil { + log.Fatalf("启动代理服务器失败: %v", err) + } + + log.Printf("SNI代理服务器已启动,监听地址: %s", cfg.Listen) + + // 等待中断信号以优雅地关闭服务器 + waitForInterrupt(server) +} + +// 等待中断信号 +func waitForInterrupt(server *proxy.Server) { + sigChan := make(chan os.Signal, 1) + signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM) + <-sigChan + + log.Println("正在关闭服务器...") + server.Stop() + log.Println("服务器已关闭") +} diff --git a/proxy/server.go b/proxy/server.go new file mode 100644 index 0000000..b49c87f --- /dev/null +++ b/proxy/server.go @@ -0,0 +1,978 @@ +package proxy + +import ( + "context" + "crypto/rand" + "fmt" + "io" + "log" + "math/big" + "net" + "strconv" + "strings" + "sync" + "sync/atomic" + "time" + + "github.com/SNI_Proxy/config" + "github.com/SNI_Proxy/dns" +) + +// 服务器结构 +type Server struct { + config *config.Config + listener net.Listener + wg sync.WaitGroup + shutdownCh chan struct{} + resolver dns.Resolver + connMutex sync.Mutex + conns map[net.Conn]struct{} // 活跃连接跟踪 + connCounter int64 // 当前连接计数器 +} + +// 创建新的代理服务器 +func NewServer(cfg *config.Config) *Server { + return &Server{ + config: cfg, + shutdownCh: make(chan struct{}), + conns: make(map[net.Conn]struct{}), + } +} + +// 启动代理服务器 +func (s *Server) Start() error { + var err error + + // 创建DNS解析器 + s.resolver, err = dns.NewResolver(s.config.DNS) + if err != nil { + return fmt.Errorf("创建DNS解析器失败: %w", err) + } + + s.listener, err = net.Listen("tcp", s.config.Listen) + if err != nil { + return fmt.Errorf("监听地址失败: %w", err) + } + + s.wg.Add(1) + go s.acceptLoop() + + // 启动连接监控 + go s.monitorConnections() + + // 启动连接清理 + go s.cleanIdleConnections() + + return nil +} + +// 停止代理服务器 +func (s *Server) Stop() { + log.Println("正在关闭服务器...") + close(s.shutdownCh) + if s.listener != nil { + s.listener.Close() + } + + // 关闭所有活跃连接 + s.closeAllConnections() + + if s.resolver != nil { + s.resolver.Close() + } + + // 等待所有goroutine结束 + log.Println("等待所有连接关闭...") + s.wg.Wait() + log.Println("服务器已关闭") +} + +// 关闭所有活跃连接 +func (s *Server) closeAllConnections() { + s.connMutex.Lock() + defer s.connMutex.Unlock() + + log.Printf("关闭 %d 个活跃连接", len(s.conns)) + for conn := range s.conns { + conn.Close() + delete(s.conns, conn) + } +} + +// 添加连接到跟踪列表 +func (s *Server) trackConn(conn net.Conn) bool { + s.connMutex.Lock() + defer s.connMutex.Unlock() + + // 检查是否超过最大连接数 + if s.config.MaxConns > 0 && len(s.conns) >= s.config.MaxConns { + log.Printf("达到最大连接数 %d,拒绝新连接", s.config.MaxConns) + return false + } + + s.conns[conn] = struct{}{} + atomic.AddInt64(&s.connCounter, 1) + return true +} + +// 从跟踪列表中移除连接 +func (s *Server) untrackConn(conn net.Conn) { + s.connMutex.Lock() + defer s.connMutex.Unlock() + + if _, exists := s.conns[conn]; exists { + delete(s.conns, conn) + atomic.AddInt64(&s.connCounter, -1) + } +} + +// 监控连接状态 +func (s *Server) monitorConnections() { + ticker := time.NewTicker(30 * time.Second) + defer ticker.Stop() + + for { + select { + case <-s.shutdownCh: + return + case <-ticker.C: + // 记录当前活跃连接数 + s.connMutex.Lock() + activeConns := len(s.conns) + s.connMutex.Unlock() + + totalConns := atomic.LoadInt64(&s.connCounter) + log.Printf("连接统计 - 当前活跃: %d, 总计处理: %d", activeConns, totalConns) + } + } +} + +// 清理空闲连接 +func (s *Server) cleanIdleConnections() { + // 每分钟检查一次空闲连接 + ticker := time.NewTicker(1 * time.Minute) + defer ticker.Stop() + + for { + select { + case <-s.shutdownCh: + return + case <-ticker.C: + s.connMutex.Lock() + log.Printf("开始清理空闲连接,当前连接数: %d", len(s.conns)) + // 这里我们不做实际清理,因为连接已经设置了超时 + // 实际清理由各个连接自己的超时机制处理 + s.connMutex.Unlock() + } + } +} + +// 接受连接循环 +func (s *Server) acceptLoop() { + defer s.wg.Done() + + for { + conn, err := s.listener.Accept() + if err != nil { + select { + case <-s.shutdownCh: + return + default: + log.Printf("接受连接失败: %v", err) + // 短暂休眠,避免CPU占用过高 + time.Sleep(100 * time.Millisecond) + continue + } + } + + // 检查是否可以接受新连接 + if !s.trackConn(conn) { + conn.Close() + continue + } + + s.wg.Add(1) + + go func(clientConn net.Conn) { + defer s.wg.Done() + defer s.untrackConn(clientConn) // 移除连接跟踪 + defer clientConn.Close() + + // 设置连接生命周期 + if s.config.Timeout.LifeTime > 0 { + deadline := time.Now().Add(time.Duration(s.config.Timeout.LifeTime) * time.Second) + clientConn.SetDeadline(deadline) + } + + // 创建一个带超时的上下文 + ctx, cancel := context.WithTimeout(context.Background(), + time.Duration(s.config.Timeout.Idle)*time.Second) + defer cancel() + + // 使用goroutine和channel处理连接,支持超时控制 + errCh := make(chan error, 1) + go func() { + errCh <- s.handleConnection(clientConn) + }() + + // 等待处理完成或超时 + select { + case err := <-errCh: + if err != nil { + if s.config.LogLevel == "debug" || + (s.config.LogLevel != "error" && !isCommonNetworkError(err)) { + log.Printf("处理连接失败: %v", err) + } + } + case <-ctx.Done(): + log.Printf("处理连接超时") + case <-s.shutdownCh: + log.Printf("服务器关闭,终止连接处理") + } + }(conn) + } +} + +// 判断是否是常见网络错误(可以不记录日志的错误) +func isCommonNetworkError(err error) bool { + if err == nil { + return false + } + + errStr := err.Error() + // 检查常见的网络错误 + commonErrors := []string{ + "connection reset by peer", + "broken pipe", + "i/o timeout", + "use of closed network connection", + "EOF", + } + + for _, e := range commonErrors { + if strings.Contains(strings.ToLower(errStr), e) { + return true + } + } + + return false +} + +// 处理单个连接 +func (s *Server) handleConnection(clientConn net.Conn) error { + // 设置读取超时 + readTimeout := time.Duration(s.config.Timeout.Read) * time.Second + if err := clientConn.SetReadDeadline(time.Now().Add(readTimeout)); err != nil { + return fmt.Errorf("设置读取超时失败: %w", err) + } + + // 读取并解析SNI信息 + sni, clientHello, err := readSNI(clientConn) + if err != nil { + return fmt.Errorf("读取SNI失败: %w", err) + } + + // 重置读取超时 + if err := clientConn.SetReadDeadline(time.Time{}); err != nil { + return fmt.Errorf("重置读取超时失败: %w", err) + } + + // 查找匹配的规则 + var targetRule *config.ProxyRule + for i := range s.config.Rules { + rule := &s.config.Rules[i] + for j := range rule.Domains { + if rule.Domains[j].Match(sni) { + targetRule = rule + break + } + } + if targetRule != nil { + break + } + } + + if targetRule == nil { + // 如果没有匹配的规则,使用默认端口 + if s.config.LogLevel != "error" { + log.Printf("未找到匹配的规则,使用默认端口 %d 连接到 %s", s.config.DefaultPort, sni) + } + return s.proxyConnection(clientConn, sni, s.config.DefaultPort, clientHello, config.FragmentConfig{Enabled: false}) + } + + // 使用匹配规则中的端口和分片配置 + return s.proxyConnection(clientConn, sni, targetRule.Port, clientHello, targetRule.Fragment) +} + +// 代理连接到目标服务器 +func (s *Server) proxyConnection(clientConn net.Conn, targetHost string, targetPort int, clientHello []byte, fragConfig config.FragmentConfig) error { + // 设置DNS解析超时 + ctx, cancel := context.WithTimeout(context.Background(), time.Duration(s.config.DNS.Timeout)*time.Second) + defer cancel() + + // 创建一个带超时的解析任务 + type resolveResult struct { + ips []net.IP + err error + } + + resCh := make(chan resolveResult, 1) + go func() { + ips, err := s.resolver.Resolve(targetHost) + resCh <- resolveResult{ips, err} + }() + + // 等待DNS解析完成或超时 + var ips []net.IP + select { + case res := <-resCh: + if res.err != nil { + return fmt.Errorf("解析域名失败: %w", res.err) + } + ips = res.ips + case <-ctx.Done(): + return fmt.Errorf("解析域名超时") + } + + // 选择第一个IP地址 + targetIP := ips[0] + if s.config.LogLevel == "debug" { + log.Printf("域名 %s 解析为 %s", targetHost, targetIP.String()) + } + + // 构建目标地址 + targetAddr := net.JoinHostPort(targetIP.String(), strconv.Itoa(targetPort)) + if s.config.LogLevel != "error" { + log.Printf("代理连接到: %s (原始SNI: %s)", targetAddr, targetHost) + } + + // 设置连接超时 + dialCtx, dialCancel := context.WithTimeout(context.Background(), + time.Duration(s.config.Timeout.Connect)*time.Second) + defer dialCancel() + + // 连接到目标服务器 + var d net.Dialer + targetConn, err := d.DialContext(dialCtx, "tcp", targetAddr) + if err != nil { + return fmt.Errorf("连接目标服务器失败: %w", err) + } + defer targetConn.Close() + + // 设置目标连接的读写超时 + if tcpConn, ok := targetConn.(*net.TCPConn); ok { + tcpConn.SetKeepAlive(true) + tcpConn.SetKeepAlivePeriod(30 * time.Second) + } + + // 发送ClientHello到目标服务器,可能需要进行TLS分片 + if err := s.sendClientHello(targetConn, clientHello, fragConfig); err != nil { + return fmt.Errorf("发送ClientHello失败: %w", err) + } + + // 双向转发数据 + errCh := make(chan error, 2) + copyDone := make(chan struct{}, 2) + + // 客户端 -> 目标服务器 + go func() { + buf := make([]byte, 32*1024) // 使用较大的缓冲区 + _, err := s.copyBuffer(targetConn, clientConn, buf) + errCh <- err + copyDone <- struct{}{} + }() + + // 目标服务器 -> 客户端 + go func() { + buf := make([]byte, 32*1024) // 使用较大的缓冲区 + _, err := s.copyBuffer(clientConn, targetConn, buf) + errCh <- err + copyDone <- struct{}{} + }() + + // 等待任一方向的数据传输完成 + <-copyDone + + // 设置一个短暂的超时,让另一个方向有机会完成 + time.Sleep(100 * time.Millisecond) + + // 检查是否有错误 + select { + case err := <-errCh: + if err != nil && err != io.EOF && !isCommonNetworkError(err) { + return fmt.Errorf("数据转发失败: %w", err) + } + default: + // 没有错误 + } + + return nil +} + +// 带超时控制的数据复制 +func (s *Server) copyBuffer(dst net.Conn, src net.Conn, buf []byte) (int64, error) { + var written int64 + readTimeout := time.Duration(s.config.Timeout.Read) * time.Second + writeTimeout := time.Duration(s.config.Timeout.Write) * time.Second + + for { + // 设置读取超时 + src.SetReadDeadline(time.Now().Add(readTimeout)) + + nr, err := src.Read(buf) + if nr > 0 { + // 设置写入超时 + dst.SetWriteDeadline(time.Now().Add(writeTimeout)) + + nw, err := dst.Write(buf[0:nr]) + if nw > 0 { + written += int64(nw) + } + if err != nil { + return written, err + } + if nr != nw { + return written, io.ErrShortWrite + } + } + if err != nil { + if err == io.EOF { + return written, nil + } + return written, err + } + } +} + +// 发送ClientHello,根据配置进行TLS分片 +func (s *Server) sendClientHello(conn net.Conn, clientHello []byte, fragConfig config.FragmentConfig) error { + writeTimeout := time.Duration(s.config.Timeout.Write) * time.Second + + if s.config.LogLevel == "debug" { + printTLSRecord(clientHello, "原始ClientHello") + } + + if !fragConfig.Enabled { + // 不进行分片,直接发送 + if s.config.LogLevel == "debug" { + log.Printf("TLS分片未启用,直接发送 %d 字节", len(clientHello)) + } + // 设置写入超时 + conn.SetWriteDeadline(time.Now().Add(writeTimeout)) + _, err := conn.Write(clientHello) + return err + } + + // 进行TLS分片 + if s.config.LogLevel == "debug" { + log.Printf("启用TLS分片,分片大小范围: %d-%d", fragConfig.MinSize, fragConfig.MaxSize) + } + + // 检查ClientHello是否是有效的TLS记录 + if fragConfig.Validate { + valid, reason := validateTLSRecord(clientHello) + if !valid { + log.Printf("警告: ClientHello不是有效的TLS记录 (%s),跳过分片", reason) + conn.SetWriteDeadline(time.Now().Add(writeTimeout)) + _, err := conn.Write(clientHello) + return err + } + } + + if clientHello[0] != recordTypeHandshake { + log.Printf("警告: 记录类型不是握手类型,跳过分片") + conn.SetWriteDeadline(time.Now().Add(writeTimeout)) + _, err := conn.Write(clientHello) + return err + } + + // 获取TLS记录头部和数据部分 + recordHeader := clientHello[:5] + recordData := clientHello[5:] + recordLength := len(recordData) + + if s.config.LogLevel == "debug" { + log.Printf("TLS记录头部: %v, 数据长度: %d", recordHeader, recordLength) + + // 如果是握手消息,打印握手消息头部信息 + if recordData[0] == handshakeTypeClientHello && len(recordData) >= 4 { + handshakeType := recordData[0] + handshakeLength := (uint32(recordData[1]) << 16) | (uint32(recordData[2]) << 8) | uint32(recordData[3]) + log.Printf("握手消息头部: 类型=%d, 长度=%d", handshakeType, handshakeLength) + } + } + + // 如果记录太小,不进行分片 + if recordLength <= fragConfig.MinSize { + if s.config.LogLevel == "debug" { + log.Printf("记录太小 (%d 字节),不进行分片", recordLength) + } + conn.SetWriteDeadline(time.Now().Add(writeTimeout)) + _, err := conn.Write(clientHello) + return err + } + + // 第一个分片:发送记录头部和部分数据 + minSize := fragConfig.MinSize + if minSize <= 0 { + minSize = 10 // 默认最小分片大小 + } + maxSize := fragConfig.MaxSize + if maxSize <= 0 { + maxSize = 100 // 默认最大分片大小 + } + if minSize > maxSize { + minSize = maxSize + } + + // 随机分片大小 + size := minSize + if maxSize > minSize { + randVal, _ := rand.Int(rand.Reader, big.NewInt(int64(maxSize-minSize+1))) + size = minSize + int(randVal.Int64()) + } + + // 确保不超过记录数据大小 + if size > recordLength { + size = recordLength + } + + // 对于握手消息,确保第一个分片至少包含完整的握手消息头部(4字节) + if recordData[0] == handshakeTypeClientHello && size < 4 { + size = 4 // 至少包含握手消息头部 + } + + // 创建第一个分片 + firstFragment := make([]byte, 5+size) + copy(firstFragment[:5], recordHeader) + copy(firstFragment[5:], recordData[:size]) + + // 修改第一个分片的长度字段 + firstFragment[3] = byte(size >> 8) + firstFragment[4] = byte(size) + + if s.config.LogLevel == "debug" { + printTLSRecord(firstFragment, "第一个分片") + log.Printf("第一个分片长度: %d (头部5字节 + 数据%d字节)", len(firstFragment), size) + + // 验证第一个分片的握手消息 + if recordData[0] == handshakeTypeClientHello { + // 创建一个临时的握手消息进行验证 + tempHandshake := make([]byte, size) + copy(tempHandshake, recordData[:size]) + // 调整握手消息长度以匹配实际数据 + handshakeLength := uint32(size - 4) // 减去握手消息头部的4字节 + tempHandshake[1] = byte(handshakeLength >> 16) + tempHandshake[2] = byte(handshakeLength >> 8) + tempHandshake[3] = byte(handshakeLength) + + valid, reason := validateHandshakeMessage(tempHandshake) + log.Printf("第一个分片握手消息验证: %v, %s", valid, reason) + } + } + + // 设置写入超时并发送第一个分片 + conn.SetWriteDeadline(time.Now().Add(writeTimeout)) + if _, err := conn.Write(firstFragment); err != nil { + return fmt.Errorf("发送第一个分片失败: %w", err) + } + + // 如果还有剩余数据,创建第二个分片 + if size < recordLength { + // 添加一个小延迟,模拟网络延迟,提高分片效果 + delayMin := fragConfig.DelayMin + if delayMin <= 0 { + delayMin = 10 // 默认最小延迟10毫秒 + } + + delayMax := fragConfig.DelayMax + if delayMax <= 0 { + delayMax = 30 // 默认最大延迟30毫秒 + } + + if delayMin > delayMax { + delayMin = delayMax + } + + delayRange := big.NewInt(int64(delayMax - delayMin + 1)) + randVal, _ := rand.Int(rand.Reader, delayRange) + delay := time.Duration(delayMin+int(randVal.Int64())) * time.Millisecond + time.Sleep(delay) + + if s.config.LogLevel == "debug" { + log.Printf("延迟 %v 后发送第二个分片", delay) + } + + // 创建第二个分片 + remainingSize := recordLength - size + secondFragment := make([]byte, 5+remainingSize) + + // 复制记录头部(保持与原始记录头部一致) + copy(secondFragment[:5], recordHeader) + + // 设置正确的长度字段 + secondFragment[3] = byte(remainingSize >> 8) + secondFragment[4] = byte(remainingSize) + + // 复制剩余数据,确保握手类型保持一致 + // 注意:我们不能直接复制数据,因为第二个分片的握手消息头部需要特殊处理 + if recordData[0] == handshakeTypeClientHello { + // 创建新的握手消息头部 + // 1. 保持握手类型不变 + // 2. 调整握手消息长度 + handshakeLength := (uint32(recordData[1]) << 16) | (uint32(recordData[2]) << 8) | uint32(recordData[3]) + + // 计算第一个分片中已发送的握手消息数据长度(不包括握手消息头部的4字节) + sentDataLength := uint32(size) - 4 + + // 计算剩余的握手消息长度 + remainingHandshakeLength := handshakeLength - sentDataLength + + if s.config.LogLevel == "debug" { + log.Printf("握手消息总长度: %d, 第一个分片已发送数据: %d, 第二个分片剩余数据: %d", + handshakeLength, sentDataLength, remainingHandshakeLength) + } + + // 复制握手类型 + secondFragment[5] = recordData[0] + + // 设置新的握手消息长度 + secondFragment[6] = byte(remainingHandshakeLength >> 16) + secondFragment[7] = byte(remainingHandshakeLength >> 8) + secondFragment[8] = byte(remainingHandshakeLength) + + // 复制剩余数据(跳过第一个分片已经发送的部分) + copy(secondFragment[9:], recordData[size:]) + } else { + // 如果不是ClientHello,直接复制数据 + copy(secondFragment[5:], recordData[size:]) + } + + // 对于第二个分片,我们采用不同的策略 + // 不再尝试修改握手消息头部,而是直接将剩余数据作为应用数据发送 + // 这样可以避免TLS握手解析错误 + + // 修改第二个分片的记录类型为应用数据 + secondFragment[0] = 23 // ApplicationData + + // 复制剩余数据 + copy(secondFragment[5:], recordData[size:]) + + if s.config.LogLevel == "debug" { + log.Printf("第二个分片使用应用数据类型,避免握手解析错误") + printTLSRecord(secondFragment, "第二个分片") + log.Printf("第二个分片长度: %d (头部5字节 + 数据%d字节)", len(secondFragment), remainingSize) + } + + // 设置写入超时并发送第二个分片 + conn.SetWriteDeadline(time.Now().Add(writeTimeout)) + if _, err := conn.Write(secondFragment); err != nil { + return fmt.Errorf("发送第二个分片失败: %w", err) + } + + if s.config.LogLevel == "debug" { + log.Printf("TLS分片完成: 总共发送 %d 字节 (第一分片: %d 字节, 第二分片: %d 字节)", + len(firstFragment)+len(secondFragment), len(firstFragment), len(secondFragment)) + } + } else { + if s.config.LogLevel == "debug" { + log.Printf("TLS分片完成: 只需一个分片,总共发送 %d 字节", len(firstFragment)) + } + } + + return nil +} + +// TLS记录类型 +const ( + recordTypeHandshake = 22 +) + +// TLS握手类型 +const ( + handshakeTypeClientHello = 1 +) + +// 读取SNI信息 +func readSNI(conn net.Conn) (string, []byte, error) { + // 读取TLS记录头部 + header := make([]byte, 5) + if _, err := io.ReadFull(conn, header); err != nil { + return "", nil, fmt.Errorf("读取TLS记录头部失败: %w", err) + } + + // 检查是否是TLS握手记录 + if header[0] != recordTypeHandshake { + return "", nil, fmt.Errorf("不是TLS握手记录,记录类型: %d", header[0]) + } + + // 获取记录长度 + recordLength := int(header[3])<<8 | int(header[4]) + if recordLength > 16384 { + return "", nil, fmt.Errorf("TLS记录太长: %d 字节", recordLength) + } + + if recordLength < 4 { + return "", nil, fmt.Errorf("TLS握手记录太短: %d 字节", recordLength) + } + + // 读取握手消息 + handshakeData := make([]byte, recordLength) + if _, err := io.ReadFull(conn, handshakeData); err != nil { + return "", nil, fmt.Errorf("读取握手消息失败: %w", err) + } + + // 检查是否是ClientHello + if handshakeData[0] != handshakeTypeClientHello { + return "", nil, fmt.Errorf("不是ClientHello消息,握手类型: %d", handshakeData[0]) + } + + // 完整的ClientHello消息(包括记录头部) + clientHello := append(header, handshakeData...) + + // 验证TLS记录的完整性 + valid, reason := validateTLSRecord(clientHello) + if !valid { + log.Printf("警告: 收到的ClientHello不是有效的TLS记录: %s", reason) + } + + // 解析SNI扩展 + sni, err := extractSNI(handshakeData) + if err != nil { + return "", nil, fmt.Errorf("提取SNI失败: %w", err) + } + + if sni != "" && len(sni) > 0 { + log.Printf("成功提取SNI: %s", sni) + } + + return sni, clientHello, nil +} + +// 从ClientHello中提取SNI +func extractSNI(data []byte) (string, error) { + if len(data) < 42 { + return "", fmt.Errorf("ClientHello太短") + } + + // 跳过握手类型(1) + 长度(3) + 协议版本(2) + 随机数(32) + pos := 38 + + // 跳过会话ID + if pos+1 > len(data) { + return "", fmt.Errorf("解析会话ID长度时数据不足") + } + sessionIDLength := int(data[pos]) + pos += 1 + sessionIDLength + + // 跳过密码套件 + if pos+2 > len(data) { + return "", fmt.Errorf("解析密码套件长度时数据不足") + } + cipherSuitesLength := int(data[pos])<<8 | int(data[pos+1]) + pos += 2 + cipherSuitesLength + + // 跳过压缩方法 + if pos+1 > len(data) { + return "", fmt.Errorf("解析压缩方法长度时数据不足") + } + compressionMethodsLength := int(data[pos]) + pos += 1 + compressionMethodsLength + + // 检查是否有扩展 + if pos+2 > len(data) { + return "", fmt.Errorf("没有扩展数据") + } + extensionsLength := int(data[pos])<<8 | int(data[pos+1]) + pos += 2 + + // 解析扩展 + extensionsEnd := pos + extensionsLength + if extensionsEnd > len(data) { + return "", fmt.Errorf("扩展数据不足") + } + + for pos < extensionsEnd { + // 读取扩展类型 + if pos+4 > len(data) { + return "", fmt.Errorf("解析扩展类型时数据不足") + } + extensionType := int(data[pos])<<8 | int(data[pos+1]) + extensionLength := int(data[pos+2])<<8 | int(data[pos+3]) + pos += 4 + + // 检查是否是SNI扩展 (类型为0) + if extensionType == 0 { + // 跳过SNI列表长度 + if pos+2 > len(data) { + return "", fmt.Errorf("解析SNI列表长度时数据不足") + } + pos += 2 + + // 检查SNI类型 + if pos+1 > len(data) { + return "", fmt.Errorf("解析SNI类型时数据不足") + } + if data[pos] != 0 { + return "", fmt.Errorf("不支持的SNI类型") + } + pos++ + + // 读取主机名长度 + if pos+2 > len(data) { + return "", fmt.Errorf("解析主机名长度时数据不足") + } + hostnameLength := int(data[pos])<<8 | int(data[pos+1]) + pos += 2 + + // 读取主机名 + if pos+hostnameLength > len(data) { + return "", fmt.Errorf("解析主机名时数据不足") + } + hostname := string(data[pos : pos+hostnameLength]) + return hostname, nil + } + + // 跳过当前扩展 + pos += extensionLength + } + + return "", fmt.Errorf("未找到SNI扩展") +} + +// 打印TLS记录的详细信息(用于调试) +func printTLSRecord(data []byte, prefix string) { + if len(data) < 5 { + log.Printf("%s: 数据太短,不是有效的TLS记录", prefix) + return + } + + recordType := data[0] + version := (uint16(data[1]) << 8) | uint16(data[2]) + length := (uint16(data[3]) << 8) | uint16(data[4]) + + var recordTypeStr string + switch recordType { + case 20: + recordTypeStr = "ChangeCipherSpec" + case 21: + recordTypeStr = "Alert" + case 22: + recordTypeStr = "Handshake" + case 23: + recordTypeStr = "ApplicationData" + default: + recordTypeStr = fmt.Sprintf("Unknown(%d)", recordType) + } + + var versionStr string + switch version { + case 0x0301: + versionStr = "TLS 1.0" + case 0x0302: + versionStr = "TLS 1.1" + case 0x0303: + versionStr = "TLS 1.2" + case 0x0304: + versionStr = "TLS 1.3" + default: + versionStr = fmt.Sprintf("Unknown(0x%04x)", version) + } + + log.Printf("%s: 类型=%s, 版本=%s, 长度=%d, 总字节数=%d", + prefix, recordTypeStr, versionStr, length, len(data)) + + if recordType == recordTypeHandshake && len(data) >= 6 { + handshakeType := data[5] + var handshakeTypeStr string + switch handshakeType { + case 1: + handshakeTypeStr = "ClientHello" + case 2: + handshakeTypeStr = "ServerHello" + case 11: + handshakeTypeStr = "Certificate" + case 16: + handshakeTypeStr = "ClientKeyExchange" + default: + handshakeTypeStr = fmt.Sprintf("Unknown(%d)", handshakeType) + } + log.Printf("%s: 握手类型=%s", prefix, handshakeTypeStr) + } +} + +// 验证TLS记录的完整性 +func validateTLSRecord(data []byte) (bool, string) { + if len(data) < 5 { + return false, "数据太短,不是有效的TLS记录" + } + + recordType := data[0] + length := (uint16(data[3]) << 8) | uint16(data[4]) + + // 检查记录类型是否有效 + validTypes := map[byte]bool{ + 20: true, // ChangeCipherSpec + 21: true, // Alert + 22: true, // Handshake + 23: true, // ApplicationData + } + if !validTypes[recordType] { + return false, fmt.Sprintf("无效的TLS记录类型: %d", recordType) + } + + // 检查记录长度是否与实际数据长度匹配 + if int(length) != len(data)-5 { + return false, fmt.Sprintf("TLS记录长度不匹配: 头部指示 %d 字节,实际数据 %d 字节", length, len(data)-5) + } + + // 对于握手记录,进行更详细的验证 + if recordType == recordTypeHandshake && len(data) >= 9 { + // 握手消息至少需要4字节的头部 + handshakeType := data[5] + handshakeLength := (uint32(data[6]) << 16) | (uint32(data[7]) << 8) | uint32(data[8]) + + // 检查握手消息长度是否与记录长度匹配 + if int(handshakeLength)+4 > int(length) { + return false, fmt.Sprintf("握手消息长度不匹配: 握手头部指示 %d 字节,记录头部指示 %d 字节", handshakeLength, length) + } + + // 对于ClientHello,进行更详细的验证 + if handshakeType == handshakeTypeClientHello && len(data) >= 43 { + // ClientHello至少需要包含协议版本(2)、随机数(32)和会话ID长度(1) + return true, "有效的ClientHello记录" + } + + return true, fmt.Sprintf("有效的握手记录,类型: %d", handshakeType) + } + + return true, fmt.Sprintf("有效的TLS记录,类型: %d", recordType) +} + +// 验证握手消息的完整性 +func validateHandshakeMessage(data []byte) (bool, string) { + if len(data) < 4 { + return false, "数据太短,不是有效的握手消息" + } + + handshakeType := data[0] + handshakeLength := (uint32(data[1]) << 16) | (uint32(data[2]) << 8) | uint32(data[3]) + + // 检查握手消息长度是否与实际数据长度匹配 + if int(handshakeLength)+4 != len(data) { + return false, fmt.Sprintf("握手消息长度不匹配: 头部指示 %d 字节,实际数据 %d 字节", handshakeLength, len(data)-4) + } + + // 对于ClientHello,进行更详细的验证 + if handshakeType == handshakeTypeClientHello { + if len(data) < 38 { + return false, "ClientHello消息太短" + } + + // 检查协议版本、随机数等字段 + // 这里只是简单检查长度,实际应用中可以更详细地验证 + return true, "有效的ClientHello消息" + } + + return true, fmt.Sprintf("有效的握手消息,类型: %d", handshakeType) +}