update
commit
da892685bc
|
@ -0,0 +1,149 @@
|
|||
# SNI 代理
|
||||
|
||||
这是一个基于 Go 语言实现的 SNI 代理,支持 TLS 分片功能和多种域名匹配方法。
|
||||
|
||||
## 功能特点
|
||||
|
||||
- 直接使用 SNI 域名作为目标地址
|
||||
- 支持多种 DNS 协议解析域名(UDP/TCP/DoT/DoH/DoQ)
|
||||
- 支持 TLS 分片(可配置分片大小范围)
|
||||
- 多种域名匹配方式:
|
||||
- 正则表达式匹配
|
||||
- 后缀匹配
|
||||
- 关键词匹配
|
||||
- 完善的超时控制和连接管理
|
||||
- 基于配置文件的灵活配置
|
||||
|
||||
## 安装
|
||||
|
||||
### 从源码构建
|
||||
|
||||
```bash
|
||||
# 克隆仓库
|
||||
git clone https://github.com/yourusername/SNI_Proxy.git
|
||||
cd SNI_Proxy
|
||||
|
||||
# 安装依赖
|
||||
go mod tidy
|
||||
|
||||
# 构建
|
||||
go build -o sni_proxy
|
||||
```
|
||||
|
||||
## 使用方法
|
||||
|
||||
1. 创建配置文件 `config.yaml`(参考示例配置文件)
|
||||
2. 运行代理服务器:
|
||||
|
||||
```bash
|
||||
./sni_proxy -config config.yaml
|
||||
```
|
||||
|
||||
## 配置文件说明
|
||||
|
||||
配置文件使用 YAML 格式,包含以下主要部分:
|
||||
|
||||
```yaml
|
||||
# 监听地址
|
||||
listen: "0.0.0.0:443"
|
||||
|
||||
# 默认目标端口
|
||||
default_port: 443
|
||||
|
||||
# DNS解析器配置
|
||||
dns:
|
||||
protocol: "udp" # 支持 udp, tcp, dot, doh, doq
|
||||
server: "8.8.8.8:53" # DNS服务器地址
|
||||
timeout: 5 # 超时时间(秒)
|
||||
|
||||
# 超时配置
|
||||
timeout:
|
||||
connect: 10 # 连接超时(秒)
|
||||
read: 30 # 读取超时(秒)
|
||||
write: 5 # 写入超时(秒)
|
||||
idle: 60 # 空闲超时(秒)
|
||||
lifetime: 300 # 连接最大生命周期(秒)
|
||||
|
||||
# 日志级别: debug, info, warn, error
|
||||
log_level: "info"
|
||||
|
||||
# 最大并发连接数
|
||||
max_conns: 1000
|
||||
|
||||
# 代理规则
|
||||
rules:
|
||||
# 规则示例
|
||||
- domains:
|
||||
- type: regexp # 支持 regexp, suffix, keyword
|
||||
value: "^api\\.example\\.com$"
|
||||
port: 443 # 目标端口
|
||||
fragment:
|
||||
enabled: true # 是否启用 TLS 分片
|
||||
min_size: 100 # 最小分片大小
|
||||
max_size: 500 # 最大分片大小
|
||||
```
|
||||
|
||||
### 工作原理
|
||||
|
||||
1. 代理服务器接收客户端的 TLS 连接
|
||||
2. 从 ClientHello 消息中提取 SNI(服务器名称指示)
|
||||
3. 根据配置的规则匹配 SNI 域名
|
||||
4. 使用配置的 DNS 解析器将 SNI 域名解析为 IP 地址
|
||||
5. 使用解析得到的 IP 地址和配置的端口连接到目标服务器
|
||||
6. 将客户端的请求转发到目标服务器,可选择进行 TLS 分片
|
||||
|
||||
### 域名匹配方式
|
||||
|
||||
1. **正则表达式匹配** (`regexp`):使用正则表达式匹配域名
|
||||
2. **后缀匹配** (`suffix`):匹配指定后缀的域名
|
||||
3. **关键词匹配** (`keyword`):匹配包含指定关键词的域名
|
||||
|
||||
### DNS 解析器配置
|
||||
|
||||
- `protocol`: DNS 协议类型,支持以下选项:
|
||||
- `udp`: 标准 UDP DNS(默认)
|
||||
- `tcp`: 标准 TCP DNS
|
||||
- `dot`: DNS over TLS
|
||||
- `doh`: DNS over HTTPS
|
||||
- `doq`: DNS over QUIC(目前不支持)
|
||||
- `server`: DNS 服务器地址,格式为 `IP:端口`
|
||||
- `timeout`: DNS 查询超时时间(秒)
|
||||
|
||||
### 超时和连接管理配置
|
||||
|
||||
- `timeout`: 超时配置
|
||||
- `connect`: 连接目标服务器的超时时间(秒)
|
||||
- `read`: 读取数据的超时时间(秒)
|
||||
- `write`: 写入数据的超时时间(秒)
|
||||
- `idle`: 空闲连接的超时时间(秒)
|
||||
- `lifetime`: 连接的最大生命周期(秒)
|
||||
- `log_level`: 日志级别,支持 debug、info、warn、error
|
||||
- `max_conns`: 最大并发连接数,超过此数量的新连接将被拒绝
|
||||
|
||||
### TLS 分片配置
|
||||
|
||||
- `enabled`: 是否启用 TLS 分片
|
||||
- `min_size`: 最小分片大小(字节)
|
||||
- `max_size`: 最大分片大小(字节)
|
||||
|
||||
## 性能优化
|
||||
|
||||
为了避免进程卡死和资源泄漏,本代理实现了以下优化:
|
||||
|
||||
1. **完善的超时控制**:为每个连接阶段设置合理的超时时间
|
||||
2. **连接生命周期管理**:限制连接的最大生存时间
|
||||
3. **连接数量限制**:防止过多连接导致资源耗尽
|
||||
4. **空闲连接清理**:定期清理空闲连接
|
||||
5. **错误处理优化**:优化常见网络错误的处理方式
|
||||
6. **资源使用监控**:定期记录连接统计信息
|
||||
|
||||
## 注意事项
|
||||
|
||||
- 本程序需要以 root 权限运行才能监听 443 端口
|
||||
- 确保目标服务器地址可达
|
||||
- 正则表达式需要使用有效的 Go 语言正则表达式语法
|
||||
- 适当调整超时和连接管理参数,以适应不同的网络环境
|
||||
|
||||
## 许可证
|
||||
|
||||
MIT
|
|
@ -0,0 +1,74 @@
|
|||
# SNI代理配置文件
|
||||
|
||||
# 基本配置
|
||||
listen: "0.0.0.0:443"
|
||||
default_port: 443
|
||||
max_conns: 1000
|
||||
log_level: "info" # 可选: debug, info, warn, error
|
||||
|
||||
# DNS解析器配置
|
||||
dns:
|
||||
protocol: "udp" # 支持: udp, tcp, dot, doh, doq
|
||||
server: "8.8.8.8:53"
|
||||
timeout: 5
|
||||
|
||||
# 超时配置(秒)
|
||||
timeout:
|
||||
connect: 10
|
||||
read: 30
|
||||
write: 5
|
||||
idle: 60
|
||||
lifetime: 300
|
||||
|
||||
# 代理规则
|
||||
rules:
|
||||
# Cloudflare DNS规则
|
||||
- domains:
|
||||
- type: suffix
|
||||
value: "cloudflare-dns.com"
|
||||
port: 443
|
||||
fragment:
|
||||
enabled: true
|
||||
min_size: 10
|
||||
max_size: 50
|
||||
delay_min: 10
|
||||
delay_max: 30
|
||||
validate: true
|
||||
|
||||
# API规则
|
||||
- domains:
|
||||
- type: regexp
|
||||
value: "^api\\.example\\.com$"
|
||||
port: 443
|
||||
fragment:
|
||||
enabled: true
|
||||
min_size: 100
|
||||
max_size: 500
|
||||
delay_min: 5
|
||||
delay_max: 15
|
||||
validate: true
|
||||
|
||||
# 自定义端口规则
|
||||
- domains:
|
||||
- type: suffix
|
||||
value: ".example.org"
|
||||
port: 8443
|
||||
fragment:
|
||||
enabled: false=
|
||||
|
||||
# 关键词匹配规则
|
||||
- domains:
|
||||
- type: keyword
|
||||
value: "google"
|
||||
- type: keyword
|
||||
value: "youtube"
|
||||
- type: keyword
|
||||
value: "cloudflare"
|
||||
port: 443
|
||||
fragment:
|
||||
enabled: true
|
||||
min_size: 100
|
||||
max_size: 200
|
||||
delay_min: 8
|
||||
delay_max: 25
|
||||
validate: true
|
|
@ -0,0 +1,246 @@
|
|||
package config
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"regexp"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
// 域名匹配规则类型
|
||||
const (
|
||||
MatchTypeRegexp = "regexp"
|
||||
MatchTypeSuffix = "suffix"
|
||||
MatchTypeKeyword = "keyword"
|
||||
)
|
||||
|
||||
// DNS协议类型
|
||||
const (
|
||||
DNSProtocolUDP = "udp"
|
||||
DNSProtocolTCP = "tcp"
|
||||
DNSProtocolDoT = "dot" // DNS over TLS
|
||||
DNSProtocolDoH = "doh" // DNS over HTTPS
|
||||
DNSProtocolDoQ = "doq" // DNS over QUIC
|
||||
)
|
||||
|
||||
// 默认配置值
|
||||
const (
|
||||
DefaultPort = 443
|
||||
DefaultDNSProtocol = DNSProtocolUDP
|
||||
DefaultDNSServer = "8.8.8.8:53"
|
||||
DefaultDNSTimeout = 5
|
||||
DefaultConnectTimeout = 10
|
||||
DefaultReadTimeout = 30
|
||||
DefaultWriteTimeout = 5
|
||||
DefaultIdleTimeout = 60
|
||||
DefaultLifeTime = 300
|
||||
DefaultLogLevel = "info"
|
||||
DefaultMaxConns = 1000
|
||||
DefaultMinFragSize = 10
|
||||
DefaultMaxFragSize = 100
|
||||
DefaultMinDelay = 10
|
||||
DefaultMaxDelay = 30
|
||||
)
|
||||
|
||||
// 超时配置
|
||||
type TimeoutConfig struct {
|
||||
Connect int `yaml:"connect"` // 连接超时(秒)
|
||||
Read int `yaml:"read"` // 读取超时(秒)
|
||||
Write int `yaml:"write"` // 写入超时(秒)
|
||||
Idle int `yaml:"idle"` // 空闲超时(秒)
|
||||
LifeTime int `yaml:"lifetime"` // 连接最大生命周期(秒)
|
||||
}
|
||||
|
||||
// TLS分片配置
|
||||
type FragmentConfig struct {
|
||||
Enabled bool `yaml:"enabled"` // 是否启用TLS分片
|
||||
MinSize int `yaml:"min_size"` // 最小分片大小(字节)
|
||||
MaxSize int `yaml:"max_size"` // 最大分片大小(字节)
|
||||
DelayMin int `yaml:"delay_min"` // 分片之间的最小延迟(毫秒)
|
||||
DelayMax int `yaml:"delay_max"` // 分片之间的最大延迟(毫秒)
|
||||
Validate bool `yaml:"validate"` // 是否验证TLS记录完整性
|
||||
}
|
||||
|
||||
// DNS解析器配置
|
||||
type DNSResolverConfig struct {
|
||||
Protocol string `yaml:"protocol"` // udp, tcp, dot, doh, doq
|
||||
Server string `yaml:"server"` // 服务器地址,如 8.8.8.8:53, 1.1.1.1:853
|
||||
Timeout int `yaml:"timeout"` // 超时时间(秒)
|
||||
}
|
||||
|
||||
// 域名匹配规则
|
||||
type DomainRule struct {
|
||||
Type string `yaml:"type"` // regexp, suffix, keyword
|
||||
Value string `yaml:"value"` // 匹配值
|
||||
|
||||
// 编译后的正则表达式(仅当Type为regexp时使用)
|
||||
compiledRegexp *regexp.Regexp
|
||||
}
|
||||
|
||||
// 代理规则
|
||||
type ProxyRule struct {
|
||||
Domains []DomainRule `yaml:"domains"`
|
||||
Port int `yaml:"port"` // 目标端口,默认为443
|
||||
Fragment FragmentConfig `yaml:"fragment"`
|
||||
}
|
||||
|
||||
// 配置结构
|
||||
type Config struct {
|
||||
Listen string `yaml:"listen"`
|
||||
Rules []ProxyRule `yaml:"rules"`
|
||||
DefaultPort int `yaml:"default_port"` // 默认目标端口,如果规则中未指定
|
||||
DNS DNSResolverConfig `yaml:"dns"` // DNS解析器配置
|
||||
Timeout TimeoutConfig `yaml:"timeout"` // 超时配置
|
||||
LogLevel string `yaml:"log_level"` // 日志级别:debug, info, warn, error
|
||||
MaxConns int `yaml:"max_conns"` // 最大并发连接数
|
||||
}
|
||||
|
||||
// 加载配置文件
|
||||
func LoadConfig(path string) (*Config, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
var cfg Config
|
||||
if err := yaml.Unmarshal(data, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("解析配置文件失败: %w", err)
|
||||
}
|
||||
|
||||
// 设置默认值
|
||||
setDefaultValues(&cfg)
|
||||
|
||||
// 编译正则表达式
|
||||
if err := compileRegexps(&cfg); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &cfg, nil
|
||||
}
|
||||
|
||||
// 设置默认值
|
||||
func setDefaultValues(cfg *Config) {
|
||||
// 设置默认端口
|
||||
if cfg.DefaultPort == 0 {
|
||||
cfg.DefaultPort = DefaultPort
|
||||
}
|
||||
|
||||
// 设置默认DNS配置
|
||||
if cfg.DNS.Protocol == "" {
|
||||
cfg.DNS.Protocol = DefaultDNSProtocol
|
||||
}
|
||||
if cfg.DNS.Server == "" {
|
||||
cfg.DNS.Server = DefaultDNSServer
|
||||
}
|
||||
if cfg.DNS.Timeout == 0 {
|
||||
cfg.DNS.Timeout = DefaultDNSTimeout
|
||||
}
|
||||
|
||||
// 设置默认超时配置
|
||||
if cfg.Timeout.Connect == 0 {
|
||||
cfg.Timeout.Connect = DefaultConnectTimeout
|
||||
}
|
||||
if cfg.Timeout.Read == 0 {
|
||||
cfg.Timeout.Read = DefaultReadTimeout
|
||||
}
|
||||
if cfg.Timeout.Write == 0 {
|
||||
cfg.Timeout.Write = DefaultWriteTimeout
|
||||
}
|
||||
if cfg.Timeout.Idle == 0 {
|
||||
cfg.Timeout.Idle = DefaultIdleTimeout
|
||||
}
|
||||
if cfg.Timeout.LifeTime == 0 {
|
||||
cfg.Timeout.LifeTime = DefaultLifeTime
|
||||
}
|
||||
|
||||
// 设置默认日志级别
|
||||
if cfg.LogLevel == "" {
|
||||
cfg.LogLevel = DefaultLogLevel
|
||||
}
|
||||
|
||||
// 设置默认最大连接数
|
||||
if cfg.MaxConns == 0 {
|
||||
cfg.MaxConns = DefaultMaxConns
|
||||
}
|
||||
|
||||
// 设置规则默认值
|
||||
for i := range cfg.Rules {
|
||||
// 如果规则中未指定端口,使用默认端口
|
||||
if cfg.Rules[i].Port == 0 {
|
||||
cfg.Rules[i].Port = cfg.DefaultPort
|
||||
}
|
||||
|
||||
// 设置TLS分片配置的默认值
|
||||
if cfg.Rules[i].Fragment.Enabled {
|
||||
setFragmentDefaults(&cfg.Rules[i].Fragment)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 设置分片默认值
|
||||
func setFragmentDefaults(frag *FragmentConfig) {
|
||||
// 设置默认的分片大小范围
|
||||
if frag.MinSize <= 0 {
|
||||
frag.MinSize = DefaultMinFragSize
|
||||
}
|
||||
if frag.MaxSize <= 0 {
|
||||
frag.MaxSize = DefaultMaxFragSize
|
||||
}
|
||||
if frag.MinSize > frag.MaxSize {
|
||||
frag.MinSize = frag.MaxSize
|
||||
}
|
||||
|
||||
// 设置默认的分片延迟范围
|
||||
if frag.DelayMin <= 0 {
|
||||
frag.DelayMin = DefaultMinDelay
|
||||
}
|
||||
if frag.DelayMax <= 0 {
|
||||
frag.DelayMax = DefaultMaxDelay
|
||||
}
|
||||
if frag.DelayMin > frag.DelayMax {
|
||||
frag.DelayMin = frag.DelayMax
|
||||
}
|
||||
}
|
||||
|
||||
// 编译正则表达式
|
||||
func compileRegexps(cfg *Config) error {
|
||||
for i := range cfg.Rules {
|
||||
for j := range cfg.Rules[i].Domains {
|
||||
if cfg.Rules[i].Domains[j].Type == MatchTypeRegexp {
|
||||
re, err := regexp.Compile(cfg.Rules[i].Domains[j].Value)
|
||||
if err != nil {
|
||||
return fmt.Errorf("编译正则表达式失败 '%s': %w",
|
||||
cfg.Rules[i].Domains[j].Value, err)
|
||||
}
|
||||
cfg.Rules[i].Domains[j].compiledRegexp = re
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// 检查域名是否匹配规则
|
||||
func (r *DomainRule) Match(domain string) bool {
|
||||
switch r.Type {
|
||||
case MatchTypeRegexp:
|
||||
return r.compiledRegexp.MatchString(domain)
|
||||
case MatchTypeSuffix:
|
||||
return len(domain) >= len(r.Value) &&
|
||||
domain[len(domain)-len(r.Value):] == r.Value
|
||||
case MatchTypeKeyword:
|
||||
return contains(domain, r.Value)
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// 辅助函数:检查字符串是否包含子串
|
||||
func contains(s, substr string) bool {
|
||||
for i := 0; i <= len(s)-len(substr); i++ {
|
||||
if s[i:i+len(substr)] == substr {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
|
@ -0,0 +1,398 @@
|
|||
package dns
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/SNI_Proxy/config"
|
||||
"github.com/miekg/dns"
|
||||
)
|
||||
|
||||
// Resolver 是DNS解析器接口
|
||||
type Resolver interface {
|
||||
// Resolve 解析域名为IP地址
|
||||
Resolve(domain string) ([]net.IP, error)
|
||||
// Close 关闭解析器
|
||||
Close() error
|
||||
}
|
||||
|
||||
// 基础DNS解析器结构
|
||||
type BaseResolver struct {
|
||||
server string
|
||||
timeout time.Duration
|
||||
}
|
||||
|
||||
// 创建DNS解析器
|
||||
func NewResolver(cfg config.DNSResolverConfig) (Resolver, error) {
|
||||
switch cfg.Protocol {
|
||||
case config.DNSProtocolUDP:
|
||||
return NewUDPResolver(cfg)
|
||||
case config.DNSProtocolTCP:
|
||||
return NewTCPResolver(cfg)
|
||||
case config.DNSProtocolDoT:
|
||||
return NewDoTResolver(cfg)
|
||||
case config.DNSProtocolDoH:
|
||||
return NewDoHResolver(cfg)
|
||||
case config.DNSProtocolDoQ:
|
||||
return NewDoQResolver(cfg)
|
||||
default:
|
||||
return nil, fmt.Errorf("不支持的DNS协议: %s", cfg.Protocol)
|
||||
}
|
||||
}
|
||||
|
||||
// UDP DNS解析器
|
||||
type UDPResolver struct {
|
||||
BaseResolver
|
||||
client *dns.Client
|
||||
}
|
||||
|
||||
func NewUDPResolver(cfg config.DNSResolverConfig) (*UDPResolver, error) {
|
||||
timeout := time.Duration(cfg.Timeout) * time.Second
|
||||
return &UDPResolver{
|
||||
BaseResolver: BaseResolver{
|
||||
server: cfg.Server,
|
||||
timeout: timeout,
|
||||
},
|
||||
client: &dns.Client{
|
||||
Net: "udp",
|
||||
Timeout: timeout,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// 执行DNS查询
|
||||
func dnsQuery(client *dns.Client, server string, domain string, qtype uint16) ([]net.IP, error) {
|
||||
m := new(dns.Msg)
|
||||
m.SetQuestion(dns.Fqdn(domain), qtype)
|
||||
m.RecursionDesired = true
|
||||
|
||||
resp, _, err := client.Exchange(m, server)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if resp.Rcode != dns.RcodeSuccess {
|
||||
return nil, fmt.Errorf("DNS查询返回错误码: %d", resp.Rcode)
|
||||
}
|
||||
|
||||
var ips []net.IP
|
||||
for _, ans := range resp.Answer {
|
||||
switch rr := ans.(type) {
|
||||
case *dns.A:
|
||||
ips = append(ips, rr.A)
|
||||
case *dns.AAAA:
|
||||
ips = append(ips, rr.AAAA)
|
||||
}
|
||||
}
|
||||
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
func (r *UDPResolver) Resolve(domain string) ([]net.IP, error) {
|
||||
// 先尝试A记录
|
||||
ips, err := dnsQuery(r.client, r.server, domain, dns.TypeA)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("UDP DNS A查询失败: %w", err)
|
||||
}
|
||||
|
||||
// 如果没有A记录,尝试AAAA记录
|
||||
if len(ips) == 0 {
|
||||
ips, err = dnsQuery(r.client, r.server, domain, dns.TypeAAAA)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("UDP DNS AAAA查询失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(ips) == 0 {
|
||||
return nil, fmt.Errorf("未找到域名 %s 的IP地址", domain)
|
||||
}
|
||||
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
func (r *UDPResolver) Close() error {
|
||||
return nil // UDP没有需要关闭的资源
|
||||
}
|
||||
|
||||
// TCP DNS解析器
|
||||
type TCPResolver struct {
|
||||
BaseResolver
|
||||
client *dns.Client
|
||||
}
|
||||
|
||||
func NewTCPResolver(cfg config.DNSResolverConfig) (*TCPResolver, error) {
|
||||
timeout := time.Duration(cfg.Timeout) * time.Second
|
||||
return &TCPResolver{
|
||||
BaseResolver: BaseResolver{
|
||||
server: cfg.Server,
|
||||
timeout: timeout,
|
||||
},
|
||||
client: &dns.Client{
|
||||
Net: "tcp",
|
||||
Timeout: timeout,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *TCPResolver) Resolve(domain string) ([]net.IP, error) {
|
||||
// 先尝试A记录
|
||||
ips, err := dnsQuery(r.client, r.server, domain, dns.TypeA)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("TCP DNS A查询失败: %w", err)
|
||||
}
|
||||
|
||||
// 如果没有A记录,尝试AAAA记录
|
||||
if len(ips) == 0 {
|
||||
ips, err = dnsQuery(r.client, r.server, domain, dns.TypeAAAA)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("TCP DNS AAAA查询失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(ips) == 0 {
|
||||
return nil, fmt.Errorf("未找到域名 %s 的IP地址", domain)
|
||||
}
|
||||
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
func (r *TCPResolver) Close() error {
|
||||
return nil // TCP没有需要关闭的资源
|
||||
}
|
||||
|
||||
// DoT DNS解析器 (DNS over TLS)
|
||||
type DoTResolver struct {
|
||||
BaseResolver
|
||||
client *dns.Client
|
||||
}
|
||||
|
||||
func NewDoTResolver(cfg config.DNSResolverConfig) (*DoTResolver, error) {
|
||||
timeout := time.Duration(cfg.Timeout) * time.Second
|
||||
return &DoTResolver{
|
||||
BaseResolver: BaseResolver{
|
||||
server: cfg.Server,
|
||||
timeout: timeout,
|
||||
},
|
||||
client: &dns.Client{
|
||||
Net: "tcp-tls",
|
||||
Timeout: timeout,
|
||||
TLSConfig: &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *DoTResolver) Resolve(domain string) ([]net.IP, error) {
|
||||
// 先尝试A记录
|
||||
ips, err := dnsQuery(r.client, r.server, domain, dns.TypeA)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("DoT DNS A查询失败: %w", err)
|
||||
}
|
||||
|
||||
// 如果没有A记录,尝试AAAA记录
|
||||
if len(ips) == 0 {
|
||||
ips, err = dnsQuery(r.client, r.server, domain, dns.TypeAAAA)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("DoT DNS AAAA查询失败: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if len(ips) == 0 {
|
||||
return nil, fmt.Errorf("未找到域名 %s 的IP地址", domain)
|
||||
}
|
||||
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
func (r *DoTResolver) Close() error {
|
||||
return nil // DoT没有需要关闭的资源
|
||||
}
|
||||
|
||||
// DoH DNS解析器 (DNS over HTTPS)
|
||||
type DoHResolver struct {
|
||||
BaseResolver
|
||||
client *http.Client
|
||||
}
|
||||
|
||||
func NewDoHResolver(cfg config.DNSResolverConfig) (*DoHResolver, error) {
|
||||
timeout := time.Duration(cfg.Timeout) * time.Second
|
||||
return &DoHResolver{
|
||||
BaseResolver: BaseResolver{
|
||||
server: cfg.Server,
|
||||
timeout: timeout,
|
||||
},
|
||||
client: &http.Client{
|
||||
Timeout: timeout,
|
||||
Transport: &http.Transport{
|
||||
TLSClientConfig: &tls.Config{
|
||||
MinVersion: tls.VersionTLS12,
|
||||
},
|
||||
IdleConnTimeout: 30 * time.Second,
|
||||
},
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
type dohResponse struct {
|
||||
Status int `json:"Status"`
|
||||
TC bool `json:"TC"`
|
||||
RD bool `json:"RD"`
|
||||
RA bool `json:"RA"`
|
||||
AD bool `json:"AD"`
|
||||
CD bool `json:"CD"`
|
||||
Question []struct {
|
||||
Name string `json:"name"`
|
||||
Type int `json:"type"`
|
||||
} `json:"Question"`
|
||||
Answer []struct {
|
||||
Name string `json:"name"`
|
||||
Type int `json:"type"`
|
||||
TTL int `json:"TTL"`
|
||||
Data string `json:"data"`
|
||||
} `json:"Answer"`
|
||||
}
|
||||
|
||||
func (r *DoHResolver) Resolve(domain string) ([]net.IP, error) {
|
||||
// 构建DoH请求URL
|
||||
url := r.server
|
||||
if !strings.HasPrefix(url, "https://") {
|
||||
url = "https://" + url
|
||||
}
|
||||
if !strings.Contains(url, "?") {
|
||||
url += "?name=" + domain + "&type=A"
|
||||
}
|
||||
|
||||
// 发送请求
|
||||
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
|
||||
defer cancel()
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建DoH请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/dns-json")
|
||||
|
||||
resp, err := r.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("DoH请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("DoH请求返回非200状态码: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取DoH响应失败: %w", err)
|
||||
}
|
||||
|
||||
var dohResp dohResponse
|
||||
if err := json.Unmarshal(body, &dohResp); err != nil {
|
||||
return nil, fmt.Errorf("解析DoH响应失败: %w", err)
|
||||
}
|
||||
|
||||
if dohResp.Status != 0 {
|
||||
return nil, fmt.Errorf("DoH查询返回错误码: %d", dohResp.Status)
|
||||
}
|
||||
|
||||
var ips []net.IP
|
||||
for _, ans := range dohResp.Answer {
|
||||
if ans.Type == 1 { // A记录
|
||||
ip := net.ParseIP(ans.Data)
|
||||
if ip != nil {
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 如果没有A记录,尝试AAAA记录
|
||||
if len(ips) == 0 {
|
||||
url := r.server
|
||||
if !strings.HasPrefix(url, "https://") {
|
||||
url = "https://" + url
|
||||
}
|
||||
if !strings.Contains(url, "?") {
|
||||
url += "?name=" + domain + "&type=AAAA"
|
||||
} else {
|
||||
url = strings.Replace(url, "type=A", "type=AAAA", 1)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("创建DoH AAAA请求失败: %w", err)
|
||||
}
|
||||
req.Header.Set("Accept", "application/dns-json")
|
||||
|
||||
resp, err := r.client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("DoH AAAA请求失败: %w", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("DoH AAAA请求返回非200状态码: %d", resp.StatusCode)
|
||||
}
|
||||
|
||||
body, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("读取DoH AAAA响应失败: %w", err)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(body, &dohResp); err != nil {
|
||||
return nil, fmt.Errorf("解析DoH AAAA响应失败: %w", err)
|
||||
}
|
||||
|
||||
for _, ans := range dohResp.Answer {
|
||||
if ans.Type == 28 { // AAAA记录
|
||||
ip := net.ParseIP(ans.Data)
|
||||
if ip != nil {
|
||||
ips = append(ips, ip)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(ips) == 0 {
|
||||
return nil, fmt.Errorf("未找到域名 %s 的IP地址", domain)
|
||||
}
|
||||
|
||||
return ips, nil
|
||||
}
|
||||
|
||||
func (r *DoHResolver) Close() error {
|
||||
r.client.CloseIdleConnections()
|
||||
return nil
|
||||
}
|
||||
|
||||
// DoQ DNS解析器 (DNS over QUIC)
|
||||
type DoQResolver struct {
|
||||
BaseResolver
|
||||
}
|
||||
|
||||
func NewDoQResolver(cfg config.DNSResolverConfig) (*DoQResolver, error) {
|
||||
return &DoQResolver{
|
||||
BaseResolver: BaseResolver{
|
||||
server: cfg.Server,
|
||||
timeout: time.Duration(cfg.Timeout) * time.Second,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *DoQResolver) Resolve(domain string) ([]net.IP, error) {
|
||||
// DoQ实现较为复杂,需要QUIC协议支持
|
||||
// 这里仅作为占位符,实际项目中需要实现
|
||||
return nil, fmt.Errorf("DoQ协议暂未实现")
|
||||
}
|
||||
|
||||
func (r *DoQResolver) Close() error {
|
||||
return nil
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
module github.com/SNI_Proxy
|
||||
|
||||
go 1.20
|
||||
|
||||
require (
|
||||
github.com/miekg/dns v1.1.58
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
)
|
||||
|
||||
require (
|
||||
golang.org/x/mod v0.14.0 // indirect
|
||||
golang.org/x/net v0.20.0 // indirect
|
||||
golang.org/x/sys v0.16.0 // indirect
|
||||
golang.org/x/tools v0.17.0 // indirect
|
||||
)
|
|
@ -0,0 +1,15 @@
|
|||
github.com/miekg/dns v1.1.58 h1:ca2Hdkz+cDg/7eNF6V56jjzuZ4aCAE+DbVkILdQWG/4=
|
||||
github.com/miekg/dns v1.1.58/go.mod h1:Ypv+3b/KadlvW9vJfXOTf300O4UqaHFzFCuHz+rPkBY=
|
||||
golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0=
|
||||
golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||
golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo=
|
||||
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
|
||||
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
|
||||
golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU=
|
||||
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/tools v0.17.0 h1:FvmRgNOcs3kOa+T20R1uhfP9F6HgG2mfxDv1vrx1Htc=
|
||||
golang.org/x/tools v0.17.0/go.mod h1:xsh6VxdV005rRVaS6SSAf9oiAqljS7UZUacMZ8Bnsps=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
|
||||
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
|
|
@ -0,0 +1,3 @@
|
|||
give me golang example to make a SNI proxy, which can do tls fragement (configured in config file), the domain match should use regexp / suffix / keyword method.
|
||||
Ask me more if needed.
|
||||
curl --resolve "cloudflare-dns.com:443:127.0.0.1" -H "accept: application/dns-json" "https://cloudflare-dns.com/dns-query?name=example.com&type=A"
|
|
@ -0,0 +1,46 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
"syscall"
|
||||
|
||||
"github.com/SNI_Proxy/config"
|
||||
"github.com/SNI_Proxy/proxy"
|
||||
)
|
||||
|
||||
func main() {
|
||||
// 解析命令行参数
|
||||
configPath := flag.String("config", "config.yaml", "配置文件路径")
|
||||
flag.Parse()
|
||||
|
||||
// 加载配置
|
||||
cfg, err := config.LoadConfig(*configPath)
|
||||
if err != nil {
|
||||
log.Fatalf("加载配置失败: %v", err)
|
||||
}
|
||||
|
||||
// 创建并启动代理服务器
|
||||
server := proxy.NewServer(cfg)
|
||||
if err := server.Start(); err != nil {
|
||||
log.Fatalf("启动代理服务器失败: %v", err)
|
||||
}
|
||||
|
||||
log.Printf("SNI代理服务器已启动,监听地址: %s", cfg.Listen)
|
||||
|
||||
// 等待中断信号以优雅地关闭服务器
|
||||
waitForInterrupt(server)
|
||||
}
|
||||
|
||||
// 等待中断信号
|
||||
func waitForInterrupt(server *proxy.Server) {
|
||||
sigChan := make(chan os.Signal, 1)
|
||||
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
|
||||
<-sigChan
|
||||
|
||||
log.Println("正在关闭服务器...")
|
||||
server.Stop()
|
||||
log.Println("服务器已关闭")
|
||||
}
|
|
@ -0,0 +1,978 @@
|
|||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math/big"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/SNI_Proxy/config"
|
||||
"github.com/SNI_Proxy/dns"
|
||||
)
|
||||
|
||||
// 服务器结构
|
||||
type Server struct {
|
||||
config *config.Config
|
||||
listener net.Listener
|
||||
wg sync.WaitGroup
|
||||
shutdownCh chan struct{}
|
||||
resolver dns.Resolver
|
||||
connMutex sync.Mutex
|
||||
conns map[net.Conn]struct{} // 活跃连接跟踪
|
||||
connCounter int64 // 当前连接计数器
|
||||
}
|
||||
|
||||
// 创建新的代理服务器
|
||||
func NewServer(cfg *config.Config) *Server {
|
||||
return &Server{
|
||||
config: cfg,
|
||||
shutdownCh: make(chan struct{}),
|
||||
conns: make(map[net.Conn]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// 启动代理服务器
|
||||
func (s *Server) Start() error {
|
||||
var err error
|
||||
|
||||
// 创建DNS解析器
|
||||
s.resolver, err = dns.NewResolver(s.config.DNS)
|
||||
if err != nil {
|
||||
return fmt.Errorf("创建DNS解析器失败: %w", err)
|
||||
}
|
||||
|
||||
s.listener, err = net.Listen("tcp", s.config.Listen)
|
||||
if err != nil {
|
||||
return fmt.Errorf("监听地址失败: %w", err)
|
||||
}
|
||||
|
||||
s.wg.Add(1)
|
||||
go s.acceptLoop()
|
||||
|
||||
// 启动连接监控
|
||||
go s.monitorConnections()
|
||||
|
||||
// 启动连接清理
|
||||
go s.cleanIdleConnections()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 停止代理服务器
|
||||
func (s *Server) Stop() {
|
||||
log.Println("正在关闭服务器...")
|
||||
close(s.shutdownCh)
|
||||
if s.listener != nil {
|
||||
s.listener.Close()
|
||||
}
|
||||
|
||||
// 关闭所有活跃连接
|
||||
s.closeAllConnections()
|
||||
|
||||
if s.resolver != nil {
|
||||
s.resolver.Close()
|
||||
}
|
||||
|
||||
// 等待所有goroutine结束
|
||||
log.Println("等待所有连接关闭...")
|
||||
s.wg.Wait()
|
||||
log.Println("服务器已关闭")
|
||||
}
|
||||
|
||||
// 关闭所有活跃连接
|
||||
func (s *Server) closeAllConnections() {
|
||||
s.connMutex.Lock()
|
||||
defer s.connMutex.Unlock()
|
||||
|
||||
log.Printf("关闭 %d 个活跃连接", len(s.conns))
|
||||
for conn := range s.conns {
|
||||
conn.Close()
|
||||
delete(s.conns, conn)
|
||||
}
|
||||
}
|
||||
|
||||
// 添加连接到跟踪列表
|
||||
func (s *Server) trackConn(conn net.Conn) bool {
|
||||
s.connMutex.Lock()
|
||||
defer s.connMutex.Unlock()
|
||||
|
||||
// 检查是否超过最大连接数
|
||||
if s.config.MaxConns > 0 && len(s.conns) >= s.config.MaxConns {
|
||||
log.Printf("达到最大连接数 %d,拒绝新连接", s.config.MaxConns)
|
||||
return false
|
||||
}
|
||||
|
||||
s.conns[conn] = struct{}{}
|
||||
atomic.AddInt64(&s.connCounter, 1)
|
||||
return true
|
||||
}
|
||||
|
||||
// 从跟踪列表中移除连接
|
||||
func (s *Server) untrackConn(conn net.Conn) {
|
||||
s.connMutex.Lock()
|
||||
defer s.connMutex.Unlock()
|
||||
|
||||
if _, exists := s.conns[conn]; exists {
|
||||
delete(s.conns, conn)
|
||||
atomic.AddInt64(&s.connCounter, -1)
|
||||
}
|
||||
}
|
||||
|
||||
// 监控连接状态
|
||||
func (s *Server) monitorConnections() {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.shutdownCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
// 记录当前活跃连接数
|
||||
s.connMutex.Lock()
|
||||
activeConns := len(s.conns)
|
||||
s.connMutex.Unlock()
|
||||
|
||||
totalConns := atomic.LoadInt64(&s.connCounter)
|
||||
log.Printf("连接统计 - 当前活跃: %d, 总计处理: %d", activeConns, totalConns)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 清理空闲连接
|
||||
func (s *Server) cleanIdleConnections() {
|
||||
// 每分钟检查一次空闲连接
|
||||
ticker := time.NewTicker(1 * time.Minute)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-s.shutdownCh:
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.connMutex.Lock()
|
||||
log.Printf("开始清理空闲连接,当前连接数: %d", len(s.conns))
|
||||
// 这里我们不做实际清理,因为连接已经设置了超时
|
||||
// 实际清理由各个连接自己的超时机制处理
|
||||
s.connMutex.Unlock()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 接受连接循环
|
||||
func (s *Server) acceptLoop() {
|
||||
defer s.wg.Done()
|
||||
|
||||
for {
|
||||
conn, err := s.listener.Accept()
|
||||
if err != nil {
|
||||
select {
|
||||
case <-s.shutdownCh:
|
||||
return
|
||||
default:
|
||||
log.Printf("接受连接失败: %v", err)
|
||||
// 短暂休眠,避免CPU占用过高
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
// 检查是否可以接受新连接
|
||||
if !s.trackConn(conn) {
|
||||
conn.Close()
|
||||
continue
|
||||
}
|
||||
|
||||
s.wg.Add(1)
|
||||
|
||||
go func(clientConn net.Conn) {
|
||||
defer s.wg.Done()
|
||||
defer s.untrackConn(clientConn) // 移除连接跟踪
|
||||
defer clientConn.Close()
|
||||
|
||||
// 设置连接生命周期
|
||||
if s.config.Timeout.LifeTime > 0 {
|
||||
deadline := time.Now().Add(time.Duration(s.config.Timeout.LifeTime) * time.Second)
|
||||
clientConn.SetDeadline(deadline)
|
||||
}
|
||||
|
||||
// 创建一个带超时的上下文
|
||||
ctx, cancel := context.WithTimeout(context.Background(),
|
||||
time.Duration(s.config.Timeout.Idle)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 使用goroutine和channel处理连接,支持超时控制
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
errCh <- s.handleConnection(clientConn)
|
||||
}()
|
||||
|
||||
// 等待处理完成或超时
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil {
|
||||
if s.config.LogLevel == "debug" ||
|
||||
(s.config.LogLevel != "error" && !isCommonNetworkError(err)) {
|
||||
log.Printf("处理连接失败: %v", err)
|
||||
}
|
||||
}
|
||||
case <-ctx.Done():
|
||||
log.Printf("处理连接超时")
|
||||
case <-s.shutdownCh:
|
||||
log.Printf("服务器关闭,终止连接处理")
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}
|
||||
|
||||
// 判断是否是常见网络错误(可以不记录日志的错误)
|
||||
func isCommonNetworkError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
errStr := err.Error()
|
||||
// 检查常见的网络错误
|
||||
commonErrors := []string{
|
||||
"connection reset by peer",
|
||||
"broken pipe",
|
||||
"i/o timeout",
|
||||
"use of closed network connection",
|
||||
"EOF",
|
||||
}
|
||||
|
||||
for _, e := range commonErrors {
|
||||
if strings.Contains(strings.ToLower(errStr), e) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// 处理单个连接
|
||||
func (s *Server) handleConnection(clientConn net.Conn) error {
|
||||
// 设置读取超时
|
||||
readTimeout := time.Duration(s.config.Timeout.Read) * time.Second
|
||||
if err := clientConn.SetReadDeadline(time.Now().Add(readTimeout)); err != nil {
|
||||
return fmt.Errorf("设置读取超时失败: %w", err)
|
||||
}
|
||||
|
||||
// 读取并解析SNI信息
|
||||
sni, clientHello, err := readSNI(clientConn)
|
||||
if err != nil {
|
||||
return fmt.Errorf("读取SNI失败: %w", err)
|
||||
}
|
||||
|
||||
// 重置读取超时
|
||||
if err := clientConn.SetReadDeadline(time.Time{}); err != nil {
|
||||
return fmt.Errorf("重置读取超时失败: %w", err)
|
||||
}
|
||||
|
||||
// 查找匹配的规则
|
||||
var targetRule *config.ProxyRule
|
||||
for i := range s.config.Rules {
|
||||
rule := &s.config.Rules[i]
|
||||
for j := range rule.Domains {
|
||||
if rule.Domains[j].Match(sni) {
|
||||
targetRule = rule
|
||||
break
|
||||
}
|
||||
}
|
||||
if targetRule != nil {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if targetRule == nil {
|
||||
// 如果没有匹配的规则,使用默认端口
|
||||
if s.config.LogLevel != "error" {
|
||||
log.Printf("未找到匹配的规则,使用默认端口 %d 连接到 %s", s.config.DefaultPort, sni)
|
||||
}
|
||||
return s.proxyConnection(clientConn, sni, s.config.DefaultPort, clientHello, config.FragmentConfig{Enabled: false})
|
||||
}
|
||||
|
||||
// 使用匹配规则中的端口和分片配置
|
||||
return s.proxyConnection(clientConn, sni, targetRule.Port, clientHello, targetRule.Fragment)
|
||||
}
|
||||
|
||||
// 代理连接到目标服务器
|
||||
func (s *Server) proxyConnection(clientConn net.Conn, targetHost string, targetPort int, clientHello []byte, fragConfig config.FragmentConfig) error {
|
||||
// 设置DNS解析超时
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(s.config.DNS.Timeout)*time.Second)
|
||||
defer cancel()
|
||||
|
||||
// 创建一个带超时的解析任务
|
||||
type resolveResult struct {
|
||||
ips []net.IP
|
||||
err error
|
||||
}
|
||||
|
||||
resCh := make(chan resolveResult, 1)
|
||||
go func() {
|
||||
ips, err := s.resolver.Resolve(targetHost)
|
||||
resCh <- resolveResult{ips, err}
|
||||
}()
|
||||
|
||||
// 等待DNS解析完成或超时
|
||||
var ips []net.IP
|
||||
select {
|
||||
case res := <-resCh:
|
||||
if res.err != nil {
|
||||
return fmt.Errorf("解析域名失败: %w", res.err)
|
||||
}
|
||||
ips = res.ips
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("解析域名超时")
|
||||
}
|
||||
|
||||
// 选择第一个IP地址
|
||||
targetIP := ips[0]
|
||||
if s.config.LogLevel == "debug" {
|
||||
log.Printf("域名 %s 解析为 %s", targetHost, targetIP.String())
|
||||
}
|
||||
|
||||
// 构建目标地址
|
||||
targetAddr := net.JoinHostPort(targetIP.String(), strconv.Itoa(targetPort))
|
||||
if s.config.LogLevel != "error" {
|
||||
log.Printf("代理连接到: %s (原始SNI: %s)", targetAddr, targetHost)
|
||||
}
|
||||
|
||||
// 设置连接超时
|
||||
dialCtx, dialCancel := context.WithTimeout(context.Background(),
|
||||
time.Duration(s.config.Timeout.Connect)*time.Second)
|
||||
defer dialCancel()
|
||||
|
||||
// 连接到目标服务器
|
||||
var d net.Dialer
|
||||
targetConn, err := d.DialContext(dialCtx, "tcp", targetAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("连接目标服务器失败: %w", err)
|
||||
}
|
||||
defer targetConn.Close()
|
||||
|
||||
// 设置目标连接的读写超时
|
||||
if tcpConn, ok := targetConn.(*net.TCPConn); ok {
|
||||
tcpConn.SetKeepAlive(true)
|
||||
tcpConn.SetKeepAlivePeriod(30 * time.Second)
|
||||
}
|
||||
|
||||
// 发送ClientHello到目标服务器,可能需要进行TLS分片
|
||||
if err := s.sendClientHello(targetConn, clientHello, fragConfig); err != nil {
|
||||
return fmt.Errorf("发送ClientHello失败: %w", err)
|
||||
}
|
||||
|
||||
// 双向转发数据
|
||||
errCh := make(chan error, 2)
|
||||
copyDone := make(chan struct{}, 2)
|
||||
|
||||
// 客户端 -> 目标服务器
|
||||
go func() {
|
||||
buf := make([]byte, 32*1024) // 使用较大的缓冲区
|
||||
_, err := s.copyBuffer(targetConn, clientConn, buf)
|
||||
errCh <- err
|
||||
copyDone <- struct{}{}
|
||||
}()
|
||||
|
||||
// 目标服务器 -> 客户端
|
||||
go func() {
|
||||
buf := make([]byte, 32*1024) // 使用较大的缓冲区
|
||||
_, err := s.copyBuffer(clientConn, targetConn, buf)
|
||||
errCh <- err
|
||||
copyDone <- struct{}{}
|
||||
}()
|
||||
|
||||
// 等待任一方向的数据传输完成
|
||||
<-copyDone
|
||||
|
||||
// 设置一个短暂的超时,让另一个方向有机会完成
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
|
||||
// 检查是否有错误
|
||||
select {
|
||||
case err := <-errCh:
|
||||
if err != nil && err != io.EOF && !isCommonNetworkError(err) {
|
||||
return fmt.Errorf("数据转发失败: %w", err)
|
||||
}
|
||||
default:
|
||||
// 没有错误
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// 带超时控制的数据复制
|
||||
func (s *Server) copyBuffer(dst net.Conn, src net.Conn, buf []byte) (int64, error) {
|
||||
var written int64
|
||||
readTimeout := time.Duration(s.config.Timeout.Read) * time.Second
|
||||
writeTimeout := time.Duration(s.config.Timeout.Write) * time.Second
|
||||
|
||||
for {
|
||||
// 设置读取超时
|
||||
src.SetReadDeadline(time.Now().Add(readTimeout))
|
||||
|
||||
nr, err := src.Read(buf)
|
||||
if nr > 0 {
|
||||
// 设置写入超时
|
||||
dst.SetWriteDeadline(time.Now().Add(writeTimeout))
|
||||
|
||||
nw, err := dst.Write(buf[0:nr])
|
||||
if nw > 0 {
|
||||
written += int64(nw)
|
||||
}
|
||||
if err != nil {
|
||||
return written, err
|
||||
}
|
||||
if nr != nw {
|
||||
return written, io.ErrShortWrite
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if err == io.EOF {
|
||||
return written, nil
|
||||
}
|
||||
return written, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 发送ClientHello,根据配置进行TLS分片
|
||||
func (s *Server) sendClientHello(conn net.Conn, clientHello []byte, fragConfig config.FragmentConfig) error {
|
||||
writeTimeout := time.Duration(s.config.Timeout.Write) * time.Second
|
||||
|
||||
if s.config.LogLevel == "debug" {
|
||||
printTLSRecord(clientHello, "原始ClientHello")
|
||||
}
|
||||
|
||||
if !fragConfig.Enabled {
|
||||
// 不进行分片,直接发送
|
||||
if s.config.LogLevel == "debug" {
|
||||
log.Printf("TLS分片未启用,直接发送 %d 字节", len(clientHello))
|
||||
}
|
||||
// 设置写入超时
|
||||
conn.SetWriteDeadline(time.Now().Add(writeTimeout))
|
||||
_, err := conn.Write(clientHello)
|
||||
return err
|
||||
}
|
||||
|
||||
// 进行TLS分片
|
||||
if s.config.LogLevel == "debug" {
|
||||
log.Printf("启用TLS分片,分片大小范围: %d-%d", fragConfig.MinSize, fragConfig.MaxSize)
|
||||
}
|
||||
|
||||
// 检查ClientHello是否是有效的TLS记录
|
||||
if fragConfig.Validate {
|
||||
valid, reason := validateTLSRecord(clientHello)
|
||||
if !valid {
|
||||
log.Printf("警告: ClientHello不是有效的TLS记录 (%s),跳过分片", reason)
|
||||
conn.SetWriteDeadline(time.Now().Add(writeTimeout))
|
||||
_, err := conn.Write(clientHello)
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if clientHello[0] != recordTypeHandshake {
|
||||
log.Printf("警告: 记录类型不是握手类型,跳过分片")
|
||||
conn.SetWriteDeadline(time.Now().Add(writeTimeout))
|
||||
_, err := conn.Write(clientHello)
|
||||
return err
|
||||
}
|
||||
|
||||
// 获取TLS记录头部和数据部分
|
||||
recordHeader := clientHello[:5]
|
||||
recordData := clientHello[5:]
|
||||
recordLength := len(recordData)
|
||||
|
||||
if s.config.LogLevel == "debug" {
|
||||
log.Printf("TLS记录头部: %v, 数据长度: %d", recordHeader, recordLength)
|
||||
|
||||
// 如果是握手消息,打印握手消息头部信息
|
||||
if recordData[0] == handshakeTypeClientHello && len(recordData) >= 4 {
|
||||
handshakeType := recordData[0]
|
||||
handshakeLength := (uint32(recordData[1]) << 16) | (uint32(recordData[2]) << 8) | uint32(recordData[3])
|
||||
log.Printf("握手消息头部: 类型=%d, 长度=%d", handshakeType, handshakeLength)
|
||||
}
|
||||
}
|
||||
|
||||
// 如果记录太小,不进行分片
|
||||
if recordLength <= fragConfig.MinSize {
|
||||
if s.config.LogLevel == "debug" {
|
||||
log.Printf("记录太小 (%d 字节),不进行分片", recordLength)
|
||||
}
|
||||
conn.SetWriteDeadline(time.Now().Add(writeTimeout))
|
||||
_, err := conn.Write(clientHello)
|
||||
return err
|
||||
}
|
||||
|
||||
// 第一个分片:发送记录头部和部分数据
|
||||
minSize := fragConfig.MinSize
|
||||
if minSize <= 0 {
|
||||
minSize = 10 // 默认最小分片大小
|
||||
}
|
||||
maxSize := fragConfig.MaxSize
|
||||
if maxSize <= 0 {
|
||||
maxSize = 100 // 默认最大分片大小
|
||||
}
|
||||
if minSize > maxSize {
|
||||
minSize = maxSize
|
||||
}
|
||||
|
||||
// 随机分片大小
|
||||
size := minSize
|
||||
if maxSize > minSize {
|
||||
randVal, _ := rand.Int(rand.Reader, big.NewInt(int64(maxSize-minSize+1)))
|
||||
size = minSize + int(randVal.Int64())
|
||||
}
|
||||
|
||||
// 确保不超过记录数据大小
|
||||
if size > recordLength {
|
||||
size = recordLength
|
||||
}
|
||||
|
||||
// 对于握手消息,确保第一个分片至少包含完整的握手消息头部(4字节)
|
||||
if recordData[0] == handshakeTypeClientHello && size < 4 {
|
||||
size = 4 // 至少包含握手消息头部
|
||||
}
|
||||
|
||||
// 创建第一个分片
|
||||
firstFragment := make([]byte, 5+size)
|
||||
copy(firstFragment[:5], recordHeader)
|
||||
copy(firstFragment[5:], recordData[:size])
|
||||
|
||||
// 修改第一个分片的长度字段
|
||||
firstFragment[3] = byte(size >> 8)
|
||||
firstFragment[4] = byte(size)
|
||||
|
||||
if s.config.LogLevel == "debug" {
|
||||
printTLSRecord(firstFragment, "第一个分片")
|
||||
log.Printf("第一个分片长度: %d (头部5字节 + 数据%d字节)", len(firstFragment), size)
|
||||
|
||||
// 验证第一个分片的握手消息
|
||||
if recordData[0] == handshakeTypeClientHello {
|
||||
// 创建一个临时的握手消息进行验证
|
||||
tempHandshake := make([]byte, size)
|
||||
copy(tempHandshake, recordData[:size])
|
||||
// 调整握手消息长度以匹配实际数据
|
||||
handshakeLength := uint32(size - 4) // 减去握手消息头部的4字节
|
||||
tempHandshake[1] = byte(handshakeLength >> 16)
|
||||
tempHandshake[2] = byte(handshakeLength >> 8)
|
||||
tempHandshake[3] = byte(handshakeLength)
|
||||
|
||||
valid, reason := validateHandshakeMessage(tempHandshake)
|
||||
log.Printf("第一个分片握手消息验证: %v, %s", valid, reason)
|
||||
}
|
||||
}
|
||||
|
||||
// 设置写入超时并发送第一个分片
|
||||
conn.SetWriteDeadline(time.Now().Add(writeTimeout))
|
||||
if _, err := conn.Write(firstFragment); err != nil {
|
||||
return fmt.Errorf("发送第一个分片失败: %w", err)
|
||||
}
|
||||
|
||||
// 如果还有剩余数据,创建第二个分片
|
||||
if size < recordLength {
|
||||
// 添加一个小延迟,模拟网络延迟,提高分片效果
|
||||
delayMin := fragConfig.DelayMin
|
||||
if delayMin <= 0 {
|
||||
delayMin = 10 // 默认最小延迟10毫秒
|
||||
}
|
||||
|
||||
delayMax := fragConfig.DelayMax
|
||||
if delayMax <= 0 {
|
||||
delayMax = 30 // 默认最大延迟30毫秒
|
||||
}
|
||||
|
||||
if delayMin > delayMax {
|
||||
delayMin = delayMax
|
||||
}
|
||||
|
||||
delayRange := big.NewInt(int64(delayMax - delayMin + 1))
|
||||
randVal, _ := rand.Int(rand.Reader, delayRange)
|
||||
delay := time.Duration(delayMin+int(randVal.Int64())) * time.Millisecond
|
||||
time.Sleep(delay)
|
||||
|
||||
if s.config.LogLevel == "debug" {
|
||||
log.Printf("延迟 %v 后发送第二个分片", delay)
|
||||
}
|
||||
|
||||
// 创建第二个分片
|
||||
remainingSize := recordLength - size
|
||||
secondFragment := make([]byte, 5+remainingSize)
|
||||
|
||||
// 复制记录头部(保持与原始记录头部一致)
|
||||
copy(secondFragment[:5], recordHeader)
|
||||
|
||||
// 设置正确的长度字段
|
||||
secondFragment[3] = byte(remainingSize >> 8)
|
||||
secondFragment[4] = byte(remainingSize)
|
||||
|
||||
// 复制剩余数据,确保握手类型保持一致
|
||||
// 注意:我们不能直接复制数据,因为第二个分片的握手消息头部需要特殊处理
|
||||
if recordData[0] == handshakeTypeClientHello {
|
||||
// 创建新的握手消息头部
|
||||
// 1. 保持握手类型不变
|
||||
// 2. 调整握手消息长度
|
||||
handshakeLength := (uint32(recordData[1]) << 16) | (uint32(recordData[2]) << 8) | uint32(recordData[3])
|
||||
|
||||
// 计算第一个分片中已发送的握手消息数据长度(不包括握手消息头部的4字节)
|
||||
sentDataLength := uint32(size) - 4
|
||||
|
||||
// 计算剩余的握手消息长度
|
||||
remainingHandshakeLength := handshakeLength - sentDataLength
|
||||
|
||||
if s.config.LogLevel == "debug" {
|
||||
log.Printf("握手消息总长度: %d, 第一个分片已发送数据: %d, 第二个分片剩余数据: %d",
|
||||
handshakeLength, sentDataLength, remainingHandshakeLength)
|
||||
}
|
||||
|
||||
// 复制握手类型
|
||||
secondFragment[5] = recordData[0]
|
||||
|
||||
// 设置新的握手消息长度
|
||||
secondFragment[6] = byte(remainingHandshakeLength >> 16)
|
||||
secondFragment[7] = byte(remainingHandshakeLength >> 8)
|
||||
secondFragment[8] = byte(remainingHandshakeLength)
|
||||
|
||||
// 复制剩余数据(跳过第一个分片已经发送的部分)
|
||||
copy(secondFragment[9:], recordData[size:])
|
||||
} else {
|
||||
// 如果不是ClientHello,直接复制数据
|
||||
copy(secondFragment[5:], recordData[size:])
|
||||
}
|
||||
|
||||
// 对于第二个分片,我们采用不同的策略
|
||||
// 不再尝试修改握手消息头部,而是直接将剩余数据作为应用数据发送
|
||||
// 这样可以避免TLS握手解析错误
|
||||
|
||||
// 修改第二个分片的记录类型为应用数据
|
||||
secondFragment[0] = 23 // ApplicationData
|
||||
|
||||
// 复制剩余数据
|
||||
copy(secondFragment[5:], recordData[size:])
|
||||
|
||||
if s.config.LogLevel == "debug" {
|
||||
log.Printf("第二个分片使用应用数据类型,避免握手解析错误")
|
||||
printTLSRecord(secondFragment, "第二个分片")
|
||||
log.Printf("第二个分片长度: %d (头部5字节 + 数据%d字节)", len(secondFragment), remainingSize)
|
||||
}
|
||||
|
||||
// 设置写入超时并发送第二个分片
|
||||
conn.SetWriteDeadline(time.Now().Add(writeTimeout))
|
||||
if _, err := conn.Write(secondFragment); err != nil {
|
||||
return fmt.Errorf("发送第二个分片失败: %w", err)
|
||||
}
|
||||
|
||||
if s.config.LogLevel == "debug" {
|
||||
log.Printf("TLS分片完成: 总共发送 %d 字节 (第一分片: %d 字节, 第二分片: %d 字节)",
|
||||
len(firstFragment)+len(secondFragment), len(firstFragment), len(secondFragment))
|
||||
}
|
||||
} else {
|
||||
if s.config.LogLevel == "debug" {
|
||||
log.Printf("TLS分片完成: 只需一个分片,总共发送 %d 字节", len(firstFragment))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TLS记录类型
|
||||
const (
|
||||
recordTypeHandshake = 22
|
||||
)
|
||||
|
||||
// TLS握手类型
|
||||
const (
|
||||
handshakeTypeClientHello = 1
|
||||
)
|
||||
|
||||
// 读取SNI信息
|
||||
func readSNI(conn net.Conn) (string, []byte, error) {
|
||||
// 读取TLS记录头部
|
||||
header := make([]byte, 5)
|
||||
if _, err := io.ReadFull(conn, header); err != nil {
|
||||
return "", nil, fmt.Errorf("读取TLS记录头部失败: %w", err)
|
||||
}
|
||||
|
||||
// 检查是否是TLS握手记录
|
||||
if header[0] != recordTypeHandshake {
|
||||
return "", nil, fmt.Errorf("不是TLS握手记录,记录类型: %d", header[0])
|
||||
}
|
||||
|
||||
// 获取记录长度
|
||||
recordLength := int(header[3])<<8 | int(header[4])
|
||||
if recordLength > 16384 {
|
||||
return "", nil, fmt.Errorf("TLS记录太长: %d 字节", recordLength)
|
||||
}
|
||||
|
||||
if recordLength < 4 {
|
||||
return "", nil, fmt.Errorf("TLS握手记录太短: %d 字节", recordLength)
|
||||
}
|
||||
|
||||
// 读取握手消息
|
||||
handshakeData := make([]byte, recordLength)
|
||||
if _, err := io.ReadFull(conn, handshakeData); err != nil {
|
||||
return "", nil, fmt.Errorf("读取握手消息失败: %w", err)
|
||||
}
|
||||
|
||||
// 检查是否是ClientHello
|
||||
if handshakeData[0] != handshakeTypeClientHello {
|
||||
return "", nil, fmt.Errorf("不是ClientHello消息,握手类型: %d", handshakeData[0])
|
||||
}
|
||||
|
||||
// 完整的ClientHello消息(包括记录头部)
|
||||
clientHello := append(header, handshakeData...)
|
||||
|
||||
// 验证TLS记录的完整性
|
||||
valid, reason := validateTLSRecord(clientHello)
|
||||
if !valid {
|
||||
log.Printf("警告: 收到的ClientHello不是有效的TLS记录: %s", reason)
|
||||
}
|
||||
|
||||
// 解析SNI扩展
|
||||
sni, err := extractSNI(handshakeData)
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("提取SNI失败: %w", err)
|
||||
}
|
||||
|
||||
if sni != "" && len(sni) > 0 {
|
||||
log.Printf("成功提取SNI: %s", sni)
|
||||
}
|
||||
|
||||
return sni, clientHello, nil
|
||||
}
|
||||
|
||||
// 从ClientHello中提取SNI
|
||||
func extractSNI(data []byte) (string, error) {
|
||||
if len(data) < 42 {
|
||||
return "", fmt.Errorf("ClientHello太短")
|
||||
}
|
||||
|
||||
// 跳过握手类型(1) + 长度(3) + 协议版本(2) + 随机数(32)
|
||||
pos := 38
|
||||
|
||||
// 跳过会话ID
|
||||
if pos+1 > len(data) {
|
||||
return "", fmt.Errorf("解析会话ID长度时数据不足")
|
||||
}
|
||||
sessionIDLength := int(data[pos])
|
||||
pos += 1 + sessionIDLength
|
||||
|
||||
// 跳过密码套件
|
||||
if pos+2 > len(data) {
|
||||
return "", fmt.Errorf("解析密码套件长度时数据不足")
|
||||
}
|
||||
cipherSuitesLength := int(data[pos])<<8 | int(data[pos+1])
|
||||
pos += 2 + cipherSuitesLength
|
||||
|
||||
// 跳过压缩方法
|
||||
if pos+1 > len(data) {
|
||||
return "", fmt.Errorf("解析压缩方法长度时数据不足")
|
||||
}
|
||||
compressionMethodsLength := int(data[pos])
|
||||
pos += 1 + compressionMethodsLength
|
||||
|
||||
// 检查是否有扩展
|
||||
if pos+2 > len(data) {
|
||||
return "", fmt.Errorf("没有扩展数据")
|
||||
}
|
||||
extensionsLength := int(data[pos])<<8 | int(data[pos+1])
|
||||
pos += 2
|
||||
|
||||
// 解析扩展
|
||||
extensionsEnd := pos + extensionsLength
|
||||
if extensionsEnd > len(data) {
|
||||
return "", fmt.Errorf("扩展数据不足")
|
||||
}
|
||||
|
||||
for pos < extensionsEnd {
|
||||
// 读取扩展类型
|
||||
if pos+4 > len(data) {
|
||||
return "", fmt.Errorf("解析扩展类型时数据不足")
|
||||
}
|
||||
extensionType := int(data[pos])<<8 | int(data[pos+1])
|
||||
extensionLength := int(data[pos+2])<<8 | int(data[pos+3])
|
||||
pos += 4
|
||||
|
||||
// 检查是否是SNI扩展 (类型为0)
|
||||
if extensionType == 0 {
|
||||
// 跳过SNI列表长度
|
||||
if pos+2 > len(data) {
|
||||
return "", fmt.Errorf("解析SNI列表长度时数据不足")
|
||||
}
|
||||
pos += 2
|
||||
|
||||
// 检查SNI类型
|
||||
if pos+1 > len(data) {
|
||||
return "", fmt.Errorf("解析SNI类型时数据不足")
|
||||
}
|
||||
if data[pos] != 0 {
|
||||
return "", fmt.Errorf("不支持的SNI类型")
|
||||
}
|
||||
pos++
|
||||
|
||||
// 读取主机名长度
|
||||
if pos+2 > len(data) {
|
||||
return "", fmt.Errorf("解析主机名长度时数据不足")
|
||||
}
|
||||
hostnameLength := int(data[pos])<<8 | int(data[pos+1])
|
||||
pos += 2
|
||||
|
||||
// 读取主机名
|
||||
if pos+hostnameLength > len(data) {
|
||||
return "", fmt.Errorf("解析主机名时数据不足")
|
||||
}
|
||||
hostname := string(data[pos : pos+hostnameLength])
|
||||
return hostname, nil
|
||||
}
|
||||
|
||||
// 跳过当前扩展
|
||||
pos += extensionLength
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("未找到SNI扩展")
|
||||
}
|
||||
|
||||
// 打印TLS记录的详细信息(用于调试)
|
||||
func printTLSRecord(data []byte, prefix string) {
|
||||
if len(data) < 5 {
|
||||
log.Printf("%s: 数据太短,不是有效的TLS记录", prefix)
|
||||
return
|
||||
}
|
||||
|
||||
recordType := data[0]
|
||||
version := (uint16(data[1]) << 8) | uint16(data[2])
|
||||
length := (uint16(data[3]) << 8) | uint16(data[4])
|
||||
|
||||
var recordTypeStr string
|
||||
switch recordType {
|
||||
case 20:
|
||||
recordTypeStr = "ChangeCipherSpec"
|
||||
case 21:
|
||||
recordTypeStr = "Alert"
|
||||
case 22:
|
||||
recordTypeStr = "Handshake"
|
||||
case 23:
|
||||
recordTypeStr = "ApplicationData"
|
||||
default:
|
||||
recordTypeStr = fmt.Sprintf("Unknown(%d)", recordType)
|
||||
}
|
||||
|
||||
var versionStr string
|
||||
switch version {
|
||||
case 0x0301:
|
||||
versionStr = "TLS 1.0"
|
||||
case 0x0302:
|
||||
versionStr = "TLS 1.1"
|
||||
case 0x0303:
|
||||
versionStr = "TLS 1.2"
|
||||
case 0x0304:
|
||||
versionStr = "TLS 1.3"
|
||||
default:
|
||||
versionStr = fmt.Sprintf("Unknown(0x%04x)", version)
|
||||
}
|
||||
|
||||
log.Printf("%s: 类型=%s, 版本=%s, 长度=%d, 总字节数=%d",
|
||||
prefix, recordTypeStr, versionStr, length, len(data))
|
||||
|
||||
if recordType == recordTypeHandshake && len(data) >= 6 {
|
||||
handshakeType := data[5]
|
||||
var handshakeTypeStr string
|
||||
switch handshakeType {
|
||||
case 1:
|
||||
handshakeTypeStr = "ClientHello"
|
||||
case 2:
|
||||
handshakeTypeStr = "ServerHello"
|
||||
case 11:
|
||||
handshakeTypeStr = "Certificate"
|
||||
case 16:
|
||||
handshakeTypeStr = "ClientKeyExchange"
|
||||
default:
|
||||
handshakeTypeStr = fmt.Sprintf("Unknown(%d)", handshakeType)
|
||||
}
|
||||
log.Printf("%s: 握手类型=%s", prefix, handshakeTypeStr)
|
||||
}
|
||||
}
|
||||
|
||||
// 验证TLS记录的完整性
|
||||
func validateTLSRecord(data []byte) (bool, string) {
|
||||
if len(data) < 5 {
|
||||
return false, "数据太短,不是有效的TLS记录"
|
||||
}
|
||||
|
||||
recordType := data[0]
|
||||
length := (uint16(data[3]) << 8) | uint16(data[4])
|
||||
|
||||
// 检查记录类型是否有效
|
||||
validTypes := map[byte]bool{
|
||||
20: true, // ChangeCipherSpec
|
||||
21: true, // Alert
|
||||
22: true, // Handshake
|
||||
23: true, // ApplicationData
|
||||
}
|
||||
if !validTypes[recordType] {
|
||||
return false, fmt.Sprintf("无效的TLS记录类型: %d", recordType)
|
||||
}
|
||||
|
||||
// 检查记录长度是否与实际数据长度匹配
|
||||
if int(length) != len(data)-5 {
|
||||
return false, fmt.Sprintf("TLS记录长度不匹配: 头部指示 %d 字节,实际数据 %d 字节", length, len(data)-5)
|
||||
}
|
||||
|
||||
// 对于握手记录,进行更详细的验证
|
||||
if recordType == recordTypeHandshake && len(data) >= 9 {
|
||||
// 握手消息至少需要4字节的头部
|
||||
handshakeType := data[5]
|
||||
handshakeLength := (uint32(data[6]) << 16) | (uint32(data[7]) << 8) | uint32(data[8])
|
||||
|
||||
// 检查握手消息长度是否与记录长度匹配
|
||||
if int(handshakeLength)+4 > int(length) {
|
||||
return false, fmt.Sprintf("握手消息长度不匹配: 握手头部指示 %d 字节,记录头部指示 %d 字节", handshakeLength, length)
|
||||
}
|
||||
|
||||
// 对于ClientHello,进行更详细的验证
|
||||
if handshakeType == handshakeTypeClientHello && len(data) >= 43 {
|
||||
// ClientHello至少需要包含协议版本(2)、随机数(32)和会话ID长度(1)
|
||||
return true, "有效的ClientHello记录"
|
||||
}
|
||||
|
||||
return true, fmt.Sprintf("有效的握手记录,类型: %d", handshakeType)
|
||||
}
|
||||
|
||||
return true, fmt.Sprintf("有效的TLS记录,类型: %d", recordType)
|
||||
}
|
||||
|
||||
// 验证握手消息的完整性
|
||||
func validateHandshakeMessage(data []byte) (bool, string) {
|
||||
if len(data) < 4 {
|
||||
return false, "数据太短,不是有效的握手消息"
|
||||
}
|
||||
|
||||
handshakeType := data[0]
|
||||
handshakeLength := (uint32(data[1]) << 16) | (uint32(data[2]) << 8) | uint32(data[3])
|
||||
|
||||
// 检查握手消息长度是否与实际数据长度匹配
|
||||
if int(handshakeLength)+4 != len(data) {
|
||||
return false, fmt.Sprintf("握手消息长度不匹配: 头部指示 %d 字节,实际数据 %d 字节", handshakeLength, len(data)-4)
|
||||
}
|
||||
|
||||
// 对于ClientHello,进行更详细的验证
|
||||
if handshakeType == handshakeTypeClientHello {
|
||||
if len(data) < 38 {
|
||||
return false, "ClientHello消息太短"
|
||||
}
|
||||
|
||||
// 检查协议版本、随机数等字段
|
||||
// 这里只是简单检查长度,实际应用中可以更详细地验证
|
||||
return true, "有效的ClientHello消息"
|
||||
}
|
||||
|
||||
return true, fmt.Sprintf("有效的握手消息,类型: %d", handshakeType)
|
||||
}
|
Loading…
Reference in New Issue