master
jianghanxin 2025-03-17 20:01:21 +08:00
commit da892685bc
10 changed files with 1924 additions and 0 deletions

149
README.md Normal file
View File

@ -0,0 +1,149 @@
# SNI 代理
这是一个基于 Go 语言实现的 SNI 代理,支持 TLS 分片功能和多种域名匹配方法。
## 功能特点
- 直接使用 SNI 域名作为目标地址
- 支持多种 DNS 协议解析域名UDP/TCP/DoT/DoH/DoQ
- 支持 TLS 分片(可配置分片大小范围)
- 多种域名匹配方式:
- 正则表达式匹配
- 后缀匹配
- 关键词匹配
- 完善的超时控制和连接管理
- 基于配置文件的灵活配置
## 安装
### 从源码构建
```bash
# 克隆仓库
git clone https://github.com/yourusername/SNI_Proxy.git
cd SNI_Proxy
# 安装依赖
go mod tidy
# 构建
go build -o sni_proxy
```
## 使用方法
1. 创建配置文件 `config.yaml`(参考示例配置文件)
2. 运行代理服务器:
```bash
./sni_proxy -config config.yaml
```
## 配置文件说明
配置文件使用 YAML 格式,包含以下主要部分:
```yaml
# 监听地址
listen: "0.0.0.0:443"
# 默认目标端口
default_port: 443
# DNS解析器配置
dns:
protocol: "udp" # 支持 udp, tcp, dot, doh, doq
server: "8.8.8.8:53" # DNS服务器地址
timeout: 5 # 超时时间(秒)
# 超时配置
timeout:
connect: 10 # 连接超时(秒)
read: 30 # 读取超时(秒)
write: 5 # 写入超时(秒)
idle: 60 # 空闲超时(秒)
lifetime: 300 # 连接最大生命周期(秒)
# 日志级别: debug, info, warn, error
log_level: "info"
# 最大并发连接数
max_conns: 1000
# 代理规则
rules:
# 规则示例
- domains:
- type: regexp # 支持 regexp, suffix, keyword
value: "^api\\.example\\.com$"
port: 443 # 目标端口
fragment:
enabled: true # 是否启用 TLS 分片
min_size: 100 # 最小分片大小
max_size: 500 # 最大分片大小
```
### 工作原理
1. 代理服务器接收客户端的 TLS 连接
2. 从 ClientHello 消息中提取 SNI服务器名称指示
3. 根据配置的规则匹配 SNI 域名
4. 使用配置的 DNS 解析器将 SNI 域名解析为 IP 地址
5. 使用解析得到的 IP 地址和配置的端口连接到目标服务器
6. 将客户端的请求转发到目标服务器,可选择进行 TLS 分片
### 域名匹配方式
1. **正则表达式匹配** (`regexp`):使用正则表达式匹配域名
2. **后缀匹配** (`suffix`):匹配指定后缀的域名
3. **关键词匹配** (`keyword`):匹配包含指定关键词的域名
### DNS 解析器配置
- `protocol`: DNS 协议类型,支持以下选项:
- `udp`: 标准 UDP DNS默认
- `tcp`: 标准 TCP DNS
- `dot`: DNS over TLS
- `doh`: DNS over HTTPS
- `doq`: DNS over QUIC目前不支持
- `server`: DNS 服务器地址,格式为 `IP:端口`
- `timeout`: DNS 查询超时时间(秒)
### 超时和连接管理配置
- `timeout`: 超时配置
- `connect`: 连接目标服务器的超时时间(秒)
- `read`: 读取数据的超时时间(秒)
- `write`: 写入数据的超时时间(秒)
- `idle`: 空闲连接的超时时间(秒)
- `lifetime`: 连接的最大生命周期(秒)
- `log_level`: 日志级别,支持 debug、info、warn、error
- `max_conns`: 最大并发连接数,超过此数量的新连接将被拒绝
### TLS 分片配置
- `enabled`: 是否启用 TLS 分片
- `min_size`: 最小分片大小(字节)
- `max_size`: 最大分片大小(字节)
## 性能优化
为了避免进程卡死和资源泄漏,本代理实现了以下优化:
1. **完善的超时控制**:为每个连接阶段设置合理的超时时间
2. **连接生命周期管理**:限制连接的最大生存时间
3. **连接数量限制**:防止过多连接导致资源耗尽
4. **空闲连接清理**:定期清理空闲连接
5. **错误处理优化**:优化常见网络错误的处理方式
6. **资源使用监控**:定期记录连接统计信息
## 注意事项
- 本程序需要以 root 权限运行才能监听 443 端口
- 确保目标服务器地址可达
- 正则表达式需要使用有效的 Go 语言正则表达式语法
- 适当调整超时和连接管理参数,以适应不同的网络环境
## 许可证
MIT

BIN
SNI_Proxy Executable file

Binary file not shown.

74
config.yaml Normal file
View File

@ -0,0 +1,74 @@
# SNI代理配置文件
# 基本配置
listen: "0.0.0.0:443"
default_port: 443
max_conns: 1000
log_level: "info" # 可选: debug, info, warn, error
# DNS解析器配置
dns:
protocol: "udp" # 支持: udp, tcp, dot, doh, doq
server: "8.8.8.8:53"
timeout: 5
# 超时配置(秒)
timeout:
connect: 10
read: 30
write: 5
idle: 60
lifetime: 300
# 代理规则
rules:
# Cloudflare DNS规则
- domains:
- type: suffix
value: "cloudflare-dns.com"
port: 443
fragment:
enabled: true
min_size: 10
max_size: 50
delay_min: 10
delay_max: 30
validate: true
# API规则
- domains:
- type: regexp
value: "^api\\.example\\.com$"
port: 443
fragment:
enabled: true
min_size: 100
max_size: 500
delay_min: 5
delay_max: 15
validate: true
# 自定义端口规则
- domains:
- type: suffix
value: ".example.org"
port: 8443
fragment:
enabled: false=
# 关键词匹配规则
- domains:
- type: keyword
value: "google"
- type: keyword
value: "youtube"
- type: keyword
value: "cloudflare"
port: 443
fragment:
enabled: true
min_size: 100
max_size: 200
delay_min: 8
delay_max: 25
validate: true

246
config/config.go Normal file
View File

@ -0,0 +1,246 @@
package config
import (
"fmt"
"os"
"regexp"
"gopkg.in/yaml.v3"
)
// 域名匹配规则类型
const (
MatchTypeRegexp = "regexp"
MatchTypeSuffix = "suffix"
MatchTypeKeyword = "keyword"
)
// DNS协议类型
const (
DNSProtocolUDP = "udp"
DNSProtocolTCP = "tcp"
DNSProtocolDoT = "dot" // DNS over TLS
DNSProtocolDoH = "doh" // DNS over HTTPS
DNSProtocolDoQ = "doq" // DNS over QUIC
)
// 默认配置值
const (
DefaultPort = 443
DefaultDNSProtocol = DNSProtocolUDP
DefaultDNSServer = "8.8.8.8:53"
DefaultDNSTimeout = 5
DefaultConnectTimeout = 10
DefaultReadTimeout = 30
DefaultWriteTimeout = 5
DefaultIdleTimeout = 60
DefaultLifeTime = 300
DefaultLogLevel = "info"
DefaultMaxConns = 1000
DefaultMinFragSize = 10
DefaultMaxFragSize = 100
DefaultMinDelay = 10
DefaultMaxDelay = 30
)
// 超时配置
type TimeoutConfig struct {
Connect int `yaml:"connect"` // 连接超时(秒)
Read int `yaml:"read"` // 读取超时(秒)
Write int `yaml:"write"` // 写入超时(秒)
Idle int `yaml:"idle"` // 空闲超时(秒)
LifeTime int `yaml:"lifetime"` // 连接最大生命周期(秒)
}
// TLS分片配置
type FragmentConfig struct {
Enabled bool `yaml:"enabled"` // 是否启用TLS分片
MinSize int `yaml:"min_size"` // 最小分片大小(字节)
MaxSize int `yaml:"max_size"` // 最大分片大小(字节)
DelayMin int `yaml:"delay_min"` // 分片之间的最小延迟(毫秒)
DelayMax int `yaml:"delay_max"` // 分片之间的最大延迟(毫秒)
Validate bool `yaml:"validate"` // 是否验证TLS记录完整性
}
// DNS解析器配置
type DNSResolverConfig struct {
Protocol string `yaml:"protocol"` // udp, tcp, dot, doh, doq
Server string `yaml:"server"` // 服务器地址,如 8.8.8.8:53, 1.1.1.1:853
Timeout int `yaml:"timeout"` // 超时时间(秒)
}
// 域名匹配规则
type DomainRule struct {
Type string `yaml:"type"` // regexp, suffix, keyword
Value string `yaml:"value"` // 匹配值
// 编译后的正则表达式仅当Type为regexp时使用
compiledRegexp *regexp.Regexp
}
// 代理规则
type ProxyRule struct {
Domains []DomainRule `yaml:"domains"`
Port int `yaml:"port"` // 目标端口默认为443
Fragment FragmentConfig `yaml:"fragment"`
}
// 配置结构
type Config struct {
Listen string `yaml:"listen"`
Rules []ProxyRule `yaml:"rules"`
DefaultPort int `yaml:"default_port"` // 默认目标端口,如果规则中未指定
DNS DNSResolverConfig `yaml:"dns"` // DNS解析器配置
Timeout TimeoutConfig `yaml:"timeout"` // 超时配置
LogLevel string `yaml:"log_level"` // 日志级别debug, info, warn, error
MaxConns int `yaml:"max_conns"` // 最大并发连接数
}
// 加载配置文件
func LoadConfig(path string) (*Config, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("读取配置文件失败: %w", err)
}
var cfg Config
if err := yaml.Unmarshal(data, &cfg); err != nil {
return nil, fmt.Errorf("解析配置文件失败: %w", err)
}
// 设置默认值
setDefaultValues(&cfg)
// 编译正则表达式
if err := compileRegexps(&cfg); err != nil {
return nil, err
}
return &cfg, nil
}
// 设置默认值
func setDefaultValues(cfg *Config) {
// 设置默认端口
if cfg.DefaultPort == 0 {
cfg.DefaultPort = DefaultPort
}
// 设置默认DNS配置
if cfg.DNS.Protocol == "" {
cfg.DNS.Protocol = DefaultDNSProtocol
}
if cfg.DNS.Server == "" {
cfg.DNS.Server = DefaultDNSServer
}
if cfg.DNS.Timeout == 0 {
cfg.DNS.Timeout = DefaultDNSTimeout
}
// 设置默认超时配置
if cfg.Timeout.Connect == 0 {
cfg.Timeout.Connect = DefaultConnectTimeout
}
if cfg.Timeout.Read == 0 {
cfg.Timeout.Read = DefaultReadTimeout
}
if cfg.Timeout.Write == 0 {
cfg.Timeout.Write = DefaultWriteTimeout
}
if cfg.Timeout.Idle == 0 {
cfg.Timeout.Idle = DefaultIdleTimeout
}
if cfg.Timeout.LifeTime == 0 {
cfg.Timeout.LifeTime = DefaultLifeTime
}
// 设置默认日志级别
if cfg.LogLevel == "" {
cfg.LogLevel = DefaultLogLevel
}
// 设置默认最大连接数
if cfg.MaxConns == 0 {
cfg.MaxConns = DefaultMaxConns
}
// 设置规则默认值
for i := range cfg.Rules {
// 如果规则中未指定端口,使用默认端口
if cfg.Rules[i].Port == 0 {
cfg.Rules[i].Port = cfg.DefaultPort
}
// 设置TLS分片配置的默认值
if cfg.Rules[i].Fragment.Enabled {
setFragmentDefaults(&cfg.Rules[i].Fragment)
}
}
}
// 设置分片默认值
func setFragmentDefaults(frag *FragmentConfig) {
// 设置默认的分片大小范围
if frag.MinSize <= 0 {
frag.MinSize = DefaultMinFragSize
}
if frag.MaxSize <= 0 {
frag.MaxSize = DefaultMaxFragSize
}
if frag.MinSize > frag.MaxSize {
frag.MinSize = frag.MaxSize
}
// 设置默认的分片延迟范围
if frag.DelayMin <= 0 {
frag.DelayMin = DefaultMinDelay
}
if frag.DelayMax <= 0 {
frag.DelayMax = DefaultMaxDelay
}
if frag.DelayMin > frag.DelayMax {
frag.DelayMin = frag.DelayMax
}
}
// 编译正则表达式
func compileRegexps(cfg *Config) error {
for i := range cfg.Rules {
for j := range cfg.Rules[i].Domains {
if cfg.Rules[i].Domains[j].Type == MatchTypeRegexp {
re, err := regexp.Compile(cfg.Rules[i].Domains[j].Value)
if err != nil {
return fmt.Errorf("编译正则表达式失败 '%s': %w",
cfg.Rules[i].Domains[j].Value, err)
}
cfg.Rules[i].Domains[j].compiledRegexp = re
}
}
}
return nil
}
// 检查域名是否匹配规则
func (r *DomainRule) Match(domain string) bool {
switch r.Type {
case MatchTypeRegexp:
return r.compiledRegexp.MatchString(domain)
case MatchTypeSuffix:
return len(domain) >= len(r.Value) &&
domain[len(domain)-len(r.Value):] == r.Value
case MatchTypeKeyword:
return contains(domain, r.Value)
default:
return false
}
}
// 辅助函数:检查字符串是否包含子串
func contains(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
if s[i:i+len(substr)] == substr {
return true
}
}
return false
}

398
dns/resolver.go Normal file
View File

@ -0,0 +1,398 @@
package dns
import (
"context"
"crypto/tls"
"encoding/json"
"fmt"
"io"
"net"
"net/http"
"strings"
"time"
"github.com/SNI_Proxy/config"
"github.com/miekg/dns"
)
// Resolver 是DNS解析器接口
type Resolver interface {
// Resolve 解析域名为IP地址
Resolve(domain string) ([]net.IP, error)
// Close 关闭解析器
Close() error
}
// 基础DNS解析器结构
type BaseResolver struct {
server string
timeout time.Duration
}
// 创建DNS解析器
func NewResolver(cfg config.DNSResolverConfig) (Resolver, error) {
switch cfg.Protocol {
case config.DNSProtocolUDP:
return NewUDPResolver(cfg)
case config.DNSProtocolTCP:
return NewTCPResolver(cfg)
case config.DNSProtocolDoT:
return NewDoTResolver(cfg)
case config.DNSProtocolDoH:
return NewDoHResolver(cfg)
case config.DNSProtocolDoQ:
return NewDoQResolver(cfg)
default:
return nil, fmt.Errorf("不支持的DNS协议: %s", cfg.Protocol)
}
}
// UDP DNS解析器
type UDPResolver struct {
BaseResolver
client *dns.Client
}
func NewUDPResolver(cfg config.DNSResolverConfig) (*UDPResolver, error) {
timeout := time.Duration(cfg.Timeout) * time.Second
return &UDPResolver{
BaseResolver: BaseResolver{
server: cfg.Server,
timeout: timeout,
},
client: &dns.Client{
Net: "udp",
Timeout: timeout,
},
}, nil
}
// 执行DNS查询
func dnsQuery(client *dns.Client, server string, domain string, qtype uint16) ([]net.IP, error) {
m := new(dns.Msg)
m.SetQuestion(dns.Fqdn(domain), qtype)
m.RecursionDesired = true
resp, _, err := client.Exchange(m, server)
if err != nil {
return nil, err
}
if resp.Rcode != dns.RcodeSuccess {
return nil, fmt.Errorf("DNS查询返回错误码: %d", resp.Rcode)
}
var ips []net.IP
for _, ans := range resp.Answer {
switch rr := ans.(type) {
case *dns.A:
ips = append(ips, rr.A)
case *dns.AAAA:
ips = append(ips, rr.AAAA)
}
}
return ips, nil
}
func (r *UDPResolver) Resolve(domain string) ([]net.IP, error) {
// 先尝试A记录
ips, err := dnsQuery(r.client, r.server, domain, dns.TypeA)
if err != nil {
return nil, fmt.Errorf("UDP DNS A查询失败: %w", err)
}
// 如果没有A记录尝试AAAA记录
if len(ips) == 0 {
ips, err = dnsQuery(r.client, r.server, domain, dns.TypeAAAA)
if err != nil {
return nil, fmt.Errorf("UDP DNS AAAA查询失败: %w", err)
}
}
if len(ips) == 0 {
return nil, fmt.Errorf("未找到域名 %s 的IP地址", domain)
}
return ips, nil
}
func (r *UDPResolver) Close() error {
return nil // UDP没有需要关闭的资源
}
// TCP DNS解析器
type TCPResolver struct {
BaseResolver
client *dns.Client
}
func NewTCPResolver(cfg config.DNSResolverConfig) (*TCPResolver, error) {
timeout := time.Duration(cfg.Timeout) * time.Second
return &TCPResolver{
BaseResolver: BaseResolver{
server: cfg.Server,
timeout: timeout,
},
client: &dns.Client{
Net: "tcp",
Timeout: timeout,
},
}, nil
}
func (r *TCPResolver) Resolve(domain string) ([]net.IP, error) {
// 先尝试A记录
ips, err := dnsQuery(r.client, r.server, domain, dns.TypeA)
if err != nil {
return nil, fmt.Errorf("TCP DNS A查询失败: %w", err)
}
// 如果没有A记录尝试AAAA记录
if len(ips) == 0 {
ips, err = dnsQuery(r.client, r.server, domain, dns.TypeAAAA)
if err != nil {
return nil, fmt.Errorf("TCP DNS AAAA查询失败: %w", err)
}
}
if len(ips) == 0 {
return nil, fmt.Errorf("未找到域名 %s 的IP地址", domain)
}
return ips, nil
}
func (r *TCPResolver) Close() error {
return nil // TCP没有需要关闭的资源
}
// DoT DNS解析器 (DNS over TLS)
type DoTResolver struct {
BaseResolver
client *dns.Client
}
func NewDoTResolver(cfg config.DNSResolverConfig) (*DoTResolver, error) {
timeout := time.Duration(cfg.Timeout) * time.Second
return &DoTResolver{
BaseResolver: BaseResolver{
server: cfg.Server,
timeout: timeout,
},
client: &dns.Client{
Net: "tcp-tls",
Timeout: timeout,
TLSConfig: &tls.Config{
MinVersion: tls.VersionTLS12,
},
},
}, nil
}
func (r *DoTResolver) Resolve(domain string) ([]net.IP, error) {
// 先尝试A记录
ips, err := dnsQuery(r.client, r.server, domain, dns.TypeA)
if err != nil {
return nil, fmt.Errorf("DoT DNS A查询失败: %w", err)
}
// 如果没有A记录尝试AAAA记录
if len(ips) == 0 {
ips, err = dnsQuery(r.client, r.server, domain, dns.TypeAAAA)
if err != nil {
return nil, fmt.Errorf("DoT DNS AAAA查询失败: %w", err)
}
}
if len(ips) == 0 {
return nil, fmt.Errorf("未找到域名 %s 的IP地址", domain)
}
return ips, nil
}
func (r *DoTResolver) Close() error {
return nil // DoT没有需要关闭的资源
}
// DoH DNS解析器 (DNS over HTTPS)
type DoHResolver struct {
BaseResolver
client *http.Client
}
func NewDoHResolver(cfg config.DNSResolverConfig) (*DoHResolver, error) {
timeout := time.Duration(cfg.Timeout) * time.Second
return &DoHResolver{
BaseResolver: BaseResolver{
server: cfg.Server,
timeout: timeout,
},
client: &http.Client{
Timeout: timeout,
Transport: &http.Transport{
TLSClientConfig: &tls.Config{
MinVersion: tls.VersionTLS12,
},
IdleConnTimeout: 30 * time.Second,
},
},
}, nil
}
type dohResponse struct {
Status int `json:"Status"`
TC bool `json:"TC"`
RD bool `json:"RD"`
RA bool `json:"RA"`
AD bool `json:"AD"`
CD bool `json:"CD"`
Question []struct {
Name string `json:"name"`
Type int `json:"type"`
} `json:"Question"`
Answer []struct {
Name string `json:"name"`
Type int `json:"type"`
TTL int `json:"TTL"`
Data string `json:"data"`
} `json:"Answer"`
}
func (r *DoHResolver) Resolve(domain string) ([]net.IP, error) {
// 构建DoH请求URL
url := r.server
if !strings.HasPrefix(url, "https://") {
url = "https://" + url
}
if !strings.Contains(url, "?") {
url += "?name=" + domain + "&type=A"
}
// 发送请求
ctx, cancel := context.WithTimeout(context.Background(), r.timeout)
defer cancel()
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, fmt.Errorf("创建DoH请求失败: %w", err)
}
req.Header.Set("Accept", "application/dns-json")
resp, err := r.client.Do(req)
if err != nil {
return nil, fmt.Errorf("DoH请求失败: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("DoH请求返回非200状态码: %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取DoH响应失败: %w", err)
}
var dohResp dohResponse
if err := json.Unmarshal(body, &dohResp); err != nil {
return nil, fmt.Errorf("解析DoH响应失败: %w", err)
}
if dohResp.Status != 0 {
return nil, fmt.Errorf("DoH查询返回错误码: %d", dohResp.Status)
}
var ips []net.IP
for _, ans := range dohResp.Answer {
if ans.Type == 1 { // A记录
ip := net.ParseIP(ans.Data)
if ip != nil {
ips = append(ips, ip)
}
}
}
// 如果没有A记录尝试AAAA记录
if len(ips) == 0 {
url := r.server
if !strings.HasPrefix(url, "https://") {
url = "https://" + url
}
if !strings.Contains(url, "?") {
url += "?name=" + domain + "&type=AAAA"
} else {
url = strings.Replace(url, "type=A", "type=AAAA", 1)
}
req, err := http.NewRequestWithContext(ctx, "GET", url, nil)
if err != nil {
return nil, fmt.Errorf("创建DoH AAAA请求失败: %w", err)
}
req.Header.Set("Accept", "application/dns-json")
resp, err := r.client.Do(req)
if err != nil {
return nil, fmt.Errorf("DoH AAAA请求失败: %w", err)
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("DoH AAAA请求返回非200状态码: %d", resp.StatusCode)
}
body, err := io.ReadAll(resp.Body)
if err != nil {
return nil, fmt.Errorf("读取DoH AAAA响应失败: %w", err)
}
if err := json.Unmarshal(body, &dohResp); err != nil {
return nil, fmt.Errorf("解析DoH AAAA响应失败: %w", err)
}
for _, ans := range dohResp.Answer {
if ans.Type == 28 { // AAAA记录
ip := net.ParseIP(ans.Data)
if ip != nil {
ips = append(ips, ip)
}
}
}
}
if len(ips) == 0 {
return nil, fmt.Errorf("未找到域名 %s 的IP地址", domain)
}
return ips, nil
}
func (r *DoHResolver) Close() error {
r.client.CloseIdleConnections()
return nil
}
// DoQ DNS解析器 (DNS over QUIC)
type DoQResolver struct {
BaseResolver
}
func NewDoQResolver(cfg config.DNSResolverConfig) (*DoQResolver, error) {
return &DoQResolver{
BaseResolver: BaseResolver{
server: cfg.Server,
timeout: time.Duration(cfg.Timeout) * time.Second,
},
}, nil
}
func (r *DoQResolver) Resolve(domain string) ([]net.IP, error) {
// DoQ实现较为复杂需要QUIC协议支持
// 这里仅作为占位符,实际项目中需要实现
return nil, fmt.Errorf("DoQ协议暂未实现")
}
func (r *DoQResolver) Close() error {
return nil
}

15
go.mod Normal file
View File

@ -0,0 +1,15 @@
module github.com/SNI_Proxy
go 1.20
require (
github.com/miekg/dns v1.1.58
gopkg.in/yaml.v3 v3.0.1
)
require (
golang.org/x/mod v0.14.0 // indirect
golang.org/x/net v0.20.0 // indirect
golang.org/x/sys v0.16.0 // indirect
golang.org/x/tools v0.17.0 // indirect
)

15
go.sum Normal file
View File

@ -0,0 +1,15 @@
github.com/miekg/dns v1.1.58 h1:ca2Hdkz+cDg/7eNF6V56jjzuZ4aCAE+DbVkILdQWG/4=
github.com/miekg/dns v1.1.58/go.mod h1:Ypv+3b/KadlvW9vJfXOTf300O4UqaHFzFCuHz+rPkBY=
golang.org/x/mod v0.14.0 h1:dGoOF9QVLYng8IHTm7BAyWqCqSheQ5pYWGhzW00YJr0=
golang.org/x/mod v0.14.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/net v0.20.0 h1:aCL9BSgETF1k+blQaYUBx9hJ9LOGP3gAVemcZlf1Kpo=
golang.org/x/net v0.20.0/go.mod h1:z8BVo6PvndSri0LbOE3hAn0apkU+1YvI6E70E9jsnvY=
golang.org/x/sync v0.6.0 h1:5BMeUDZ7vkXGfEr1x9B4bRcTH4lpkTkpdh0T/J+qjbQ=
golang.org/x/sys v0.16.0 h1:xWw16ngr6ZMtmxDyKyIgsE93KNKz5HKmMa3b8ALHidU=
golang.org/x/sys v0.16.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/tools v0.17.0 h1:FvmRgNOcs3kOa+T20R1uhfP9F6HgG2mfxDv1vrx1Htc=
golang.org/x/tools v0.17.0/go.mod h1:xsh6VxdV005rRVaS6SSAf9oiAqljS7UZUacMZ8Bnsps=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=

3
init.txt Normal file
View File

@ -0,0 +1,3 @@
give me golang example to make a SNI proxy, which can do tls fragement (configured in config file), the domain match should use regexp / suffix / keyword method.
Ask me more if needed.
curl --resolve "cloudflare-dns.com:443:127.0.0.1" -H "accept: application/dns-json" "https://cloudflare-dns.com/dns-query?name=example.com&type=A"

46
main.go Normal file
View File

@ -0,0 +1,46 @@
package main
import (
"flag"
"log"
"os"
"os/signal"
"syscall"
"github.com/SNI_Proxy/config"
"github.com/SNI_Proxy/proxy"
)
func main() {
// 解析命令行参数
configPath := flag.String("config", "config.yaml", "配置文件路径")
flag.Parse()
// 加载配置
cfg, err := config.LoadConfig(*configPath)
if err != nil {
log.Fatalf("加载配置失败: %v", err)
}
// 创建并启动代理服务器
server := proxy.NewServer(cfg)
if err := server.Start(); err != nil {
log.Fatalf("启动代理服务器失败: %v", err)
}
log.Printf("SNI代理服务器已启动监听地址: %s", cfg.Listen)
// 等待中断信号以优雅地关闭服务器
waitForInterrupt(server)
}
// 等待中断信号
func waitForInterrupt(server *proxy.Server) {
sigChan := make(chan os.Signal, 1)
signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
<-sigChan
log.Println("正在关闭服务器...")
server.Stop()
log.Println("服务器已关闭")
}

978
proxy/server.go Normal file
View File

@ -0,0 +1,978 @@
package proxy
import (
"context"
"crypto/rand"
"fmt"
"io"
"log"
"math/big"
"net"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/SNI_Proxy/config"
"github.com/SNI_Proxy/dns"
)
// 服务器结构
type Server struct {
config *config.Config
listener net.Listener
wg sync.WaitGroup
shutdownCh chan struct{}
resolver dns.Resolver
connMutex sync.Mutex
conns map[net.Conn]struct{} // 活跃连接跟踪
connCounter int64 // 当前连接计数器
}
// 创建新的代理服务器
func NewServer(cfg *config.Config) *Server {
return &Server{
config: cfg,
shutdownCh: make(chan struct{}),
conns: make(map[net.Conn]struct{}),
}
}
// 启动代理服务器
func (s *Server) Start() error {
var err error
// 创建DNS解析器
s.resolver, err = dns.NewResolver(s.config.DNS)
if err != nil {
return fmt.Errorf("创建DNS解析器失败: %w", err)
}
s.listener, err = net.Listen("tcp", s.config.Listen)
if err != nil {
return fmt.Errorf("监听地址失败: %w", err)
}
s.wg.Add(1)
go s.acceptLoop()
// 启动连接监控
go s.monitorConnections()
// 启动连接清理
go s.cleanIdleConnections()
return nil
}
// 停止代理服务器
func (s *Server) Stop() {
log.Println("正在关闭服务器...")
close(s.shutdownCh)
if s.listener != nil {
s.listener.Close()
}
// 关闭所有活跃连接
s.closeAllConnections()
if s.resolver != nil {
s.resolver.Close()
}
// 等待所有goroutine结束
log.Println("等待所有连接关闭...")
s.wg.Wait()
log.Println("服务器已关闭")
}
// 关闭所有活跃连接
func (s *Server) closeAllConnections() {
s.connMutex.Lock()
defer s.connMutex.Unlock()
log.Printf("关闭 %d 个活跃连接", len(s.conns))
for conn := range s.conns {
conn.Close()
delete(s.conns, conn)
}
}
// 添加连接到跟踪列表
func (s *Server) trackConn(conn net.Conn) bool {
s.connMutex.Lock()
defer s.connMutex.Unlock()
// 检查是否超过最大连接数
if s.config.MaxConns > 0 && len(s.conns) >= s.config.MaxConns {
log.Printf("达到最大连接数 %d拒绝新连接", s.config.MaxConns)
return false
}
s.conns[conn] = struct{}{}
atomic.AddInt64(&s.connCounter, 1)
return true
}
// 从跟踪列表中移除连接
func (s *Server) untrackConn(conn net.Conn) {
s.connMutex.Lock()
defer s.connMutex.Unlock()
if _, exists := s.conns[conn]; exists {
delete(s.conns, conn)
atomic.AddInt64(&s.connCounter, -1)
}
}
// 监控连接状态
func (s *Server) monitorConnections() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-s.shutdownCh:
return
case <-ticker.C:
// 记录当前活跃连接数
s.connMutex.Lock()
activeConns := len(s.conns)
s.connMutex.Unlock()
totalConns := atomic.LoadInt64(&s.connCounter)
log.Printf("连接统计 - 当前活跃: %d, 总计处理: %d", activeConns, totalConns)
}
}
}
// 清理空闲连接
func (s *Server) cleanIdleConnections() {
// 每分钟检查一次空闲连接
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-s.shutdownCh:
return
case <-ticker.C:
s.connMutex.Lock()
log.Printf("开始清理空闲连接,当前连接数: %d", len(s.conns))
// 这里我们不做实际清理,因为连接已经设置了超时
// 实际清理由各个连接自己的超时机制处理
s.connMutex.Unlock()
}
}
}
// 接受连接循环
func (s *Server) acceptLoop() {
defer s.wg.Done()
for {
conn, err := s.listener.Accept()
if err != nil {
select {
case <-s.shutdownCh:
return
default:
log.Printf("接受连接失败: %v", err)
// 短暂休眠避免CPU占用过高
time.Sleep(100 * time.Millisecond)
continue
}
}
// 检查是否可以接受新连接
if !s.trackConn(conn) {
conn.Close()
continue
}
s.wg.Add(1)
go func(clientConn net.Conn) {
defer s.wg.Done()
defer s.untrackConn(clientConn) // 移除连接跟踪
defer clientConn.Close()
// 设置连接生命周期
if s.config.Timeout.LifeTime > 0 {
deadline := time.Now().Add(time.Duration(s.config.Timeout.LifeTime) * time.Second)
clientConn.SetDeadline(deadline)
}
// 创建一个带超时的上下文
ctx, cancel := context.WithTimeout(context.Background(),
time.Duration(s.config.Timeout.Idle)*time.Second)
defer cancel()
// 使用goroutine和channel处理连接支持超时控制
errCh := make(chan error, 1)
go func() {
errCh <- s.handleConnection(clientConn)
}()
// 等待处理完成或超时
select {
case err := <-errCh:
if err != nil {
if s.config.LogLevel == "debug" ||
(s.config.LogLevel != "error" && !isCommonNetworkError(err)) {
log.Printf("处理连接失败: %v", err)
}
}
case <-ctx.Done():
log.Printf("处理连接超时")
case <-s.shutdownCh:
log.Printf("服务器关闭,终止连接处理")
}
}(conn)
}
}
// 判断是否是常见网络错误(可以不记录日志的错误)
func isCommonNetworkError(err error) bool {
if err == nil {
return false
}
errStr := err.Error()
// 检查常见的网络错误
commonErrors := []string{
"connection reset by peer",
"broken pipe",
"i/o timeout",
"use of closed network connection",
"EOF",
}
for _, e := range commonErrors {
if strings.Contains(strings.ToLower(errStr), e) {
return true
}
}
return false
}
// 处理单个连接
func (s *Server) handleConnection(clientConn net.Conn) error {
// 设置读取超时
readTimeout := time.Duration(s.config.Timeout.Read) * time.Second
if err := clientConn.SetReadDeadline(time.Now().Add(readTimeout)); err != nil {
return fmt.Errorf("设置读取超时失败: %w", err)
}
// 读取并解析SNI信息
sni, clientHello, err := readSNI(clientConn)
if err != nil {
return fmt.Errorf("读取SNI失败: %w", err)
}
// 重置读取超时
if err := clientConn.SetReadDeadline(time.Time{}); err != nil {
return fmt.Errorf("重置读取超时失败: %w", err)
}
// 查找匹配的规则
var targetRule *config.ProxyRule
for i := range s.config.Rules {
rule := &s.config.Rules[i]
for j := range rule.Domains {
if rule.Domains[j].Match(sni) {
targetRule = rule
break
}
}
if targetRule != nil {
break
}
}
if targetRule == nil {
// 如果没有匹配的规则,使用默认端口
if s.config.LogLevel != "error" {
log.Printf("未找到匹配的规则,使用默认端口 %d 连接到 %s", s.config.DefaultPort, sni)
}
return s.proxyConnection(clientConn, sni, s.config.DefaultPort, clientHello, config.FragmentConfig{Enabled: false})
}
// 使用匹配规则中的端口和分片配置
return s.proxyConnection(clientConn, sni, targetRule.Port, clientHello, targetRule.Fragment)
}
// 代理连接到目标服务器
func (s *Server) proxyConnection(clientConn net.Conn, targetHost string, targetPort int, clientHello []byte, fragConfig config.FragmentConfig) error {
// 设置DNS解析超时
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(s.config.DNS.Timeout)*time.Second)
defer cancel()
// 创建一个带超时的解析任务
type resolveResult struct {
ips []net.IP
err error
}
resCh := make(chan resolveResult, 1)
go func() {
ips, err := s.resolver.Resolve(targetHost)
resCh <- resolveResult{ips, err}
}()
// 等待DNS解析完成或超时
var ips []net.IP
select {
case res := <-resCh:
if res.err != nil {
return fmt.Errorf("解析域名失败: %w", res.err)
}
ips = res.ips
case <-ctx.Done():
return fmt.Errorf("解析域名超时")
}
// 选择第一个IP地址
targetIP := ips[0]
if s.config.LogLevel == "debug" {
log.Printf("域名 %s 解析为 %s", targetHost, targetIP.String())
}
// 构建目标地址
targetAddr := net.JoinHostPort(targetIP.String(), strconv.Itoa(targetPort))
if s.config.LogLevel != "error" {
log.Printf("代理连接到: %s (原始SNI: %s)", targetAddr, targetHost)
}
// 设置连接超时
dialCtx, dialCancel := context.WithTimeout(context.Background(),
time.Duration(s.config.Timeout.Connect)*time.Second)
defer dialCancel()
// 连接到目标服务器
var d net.Dialer
targetConn, err := d.DialContext(dialCtx, "tcp", targetAddr)
if err != nil {
return fmt.Errorf("连接目标服务器失败: %w", err)
}
defer targetConn.Close()
// 设置目标连接的读写超时
if tcpConn, ok := targetConn.(*net.TCPConn); ok {
tcpConn.SetKeepAlive(true)
tcpConn.SetKeepAlivePeriod(30 * time.Second)
}
// 发送ClientHello到目标服务器可能需要进行TLS分片
if err := s.sendClientHello(targetConn, clientHello, fragConfig); err != nil {
return fmt.Errorf("发送ClientHello失败: %w", err)
}
// 双向转发数据
errCh := make(chan error, 2)
copyDone := make(chan struct{}, 2)
// 客户端 -> 目标服务器
go func() {
buf := make([]byte, 32*1024) // 使用较大的缓冲区
_, err := s.copyBuffer(targetConn, clientConn, buf)
errCh <- err
copyDone <- struct{}{}
}()
// 目标服务器 -> 客户端
go func() {
buf := make([]byte, 32*1024) // 使用较大的缓冲区
_, err := s.copyBuffer(clientConn, targetConn, buf)
errCh <- err
copyDone <- struct{}{}
}()
// 等待任一方向的数据传输完成
<-copyDone
// 设置一个短暂的超时,让另一个方向有机会完成
time.Sleep(100 * time.Millisecond)
// 检查是否有错误
select {
case err := <-errCh:
if err != nil && err != io.EOF && !isCommonNetworkError(err) {
return fmt.Errorf("数据转发失败: %w", err)
}
default:
// 没有错误
}
return nil
}
// 带超时控制的数据复制
func (s *Server) copyBuffer(dst net.Conn, src net.Conn, buf []byte) (int64, error) {
var written int64
readTimeout := time.Duration(s.config.Timeout.Read) * time.Second
writeTimeout := time.Duration(s.config.Timeout.Write) * time.Second
for {
// 设置读取超时
src.SetReadDeadline(time.Now().Add(readTimeout))
nr, err := src.Read(buf)
if nr > 0 {
// 设置写入超时
dst.SetWriteDeadline(time.Now().Add(writeTimeout))
nw, err := dst.Write(buf[0:nr])
if nw > 0 {
written += int64(nw)
}
if err != nil {
return written, err
}
if nr != nw {
return written, io.ErrShortWrite
}
}
if err != nil {
if err == io.EOF {
return written, nil
}
return written, err
}
}
}
// 发送ClientHello根据配置进行TLS分片
func (s *Server) sendClientHello(conn net.Conn, clientHello []byte, fragConfig config.FragmentConfig) error {
writeTimeout := time.Duration(s.config.Timeout.Write) * time.Second
if s.config.LogLevel == "debug" {
printTLSRecord(clientHello, "原始ClientHello")
}
if !fragConfig.Enabled {
// 不进行分片,直接发送
if s.config.LogLevel == "debug" {
log.Printf("TLS分片未启用直接发送 %d 字节", len(clientHello))
}
// 设置写入超时
conn.SetWriteDeadline(time.Now().Add(writeTimeout))
_, err := conn.Write(clientHello)
return err
}
// 进行TLS分片
if s.config.LogLevel == "debug" {
log.Printf("启用TLS分片分片大小范围: %d-%d", fragConfig.MinSize, fragConfig.MaxSize)
}
// 检查ClientHello是否是有效的TLS记录
if fragConfig.Validate {
valid, reason := validateTLSRecord(clientHello)
if !valid {
log.Printf("警告: ClientHello不是有效的TLS记录 (%s),跳过分片", reason)
conn.SetWriteDeadline(time.Now().Add(writeTimeout))
_, err := conn.Write(clientHello)
return err
}
}
if clientHello[0] != recordTypeHandshake {
log.Printf("警告: 记录类型不是握手类型,跳过分片")
conn.SetWriteDeadline(time.Now().Add(writeTimeout))
_, err := conn.Write(clientHello)
return err
}
// 获取TLS记录头部和数据部分
recordHeader := clientHello[:5]
recordData := clientHello[5:]
recordLength := len(recordData)
if s.config.LogLevel == "debug" {
log.Printf("TLS记录头部: %v, 数据长度: %d", recordHeader, recordLength)
// 如果是握手消息,打印握手消息头部信息
if recordData[0] == handshakeTypeClientHello && len(recordData) >= 4 {
handshakeType := recordData[0]
handshakeLength := (uint32(recordData[1]) << 16) | (uint32(recordData[2]) << 8) | uint32(recordData[3])
log.Printf("握手消息头部: 类型=%d, 长度=%d", handshakeType, handshakeLength)
}
}
// 如果记录太小,不进行分片
if recordLength <= fragConfig.MinSize {
if s.config.LogLevel == "debug" {
log.Printf("记录太小 (%d 字节),不进行分片", recordLength)
}
conn.SetWriteDeadline(time.Now().Add(writeTimeout))
_, err := conn.Write(clientHello)
return err
}
// 第一个分片:发送记录头部和部分数据
minSize := fragConfig.MinSize
if minSize <= 0 {
minSize = 10 // 默认最小分片大小
}
maxSize := fragConfig.MaxSize
if maxSize <= 0 {
maxSize = 100 // 默认最大分片大小
}
if minSize > maxSize {
minSize = maxSize
}
// 随机分片大小
size := minSize
if maxSize > minSize {
randVal, _ := rand.Int(rand.Reader, big.NewInt(int64(maxSize-minSize+1)))
size = minSize + int(randVal.Int64())
}
// 确保不超过记录数据大小
if size > recordLength {
size = recordLength
}
// 对于握手消息确保第一个分片至少包含完整的握手消息头部4字节
if recordData[0] == handshakeTypeClientHello && size < 4 {
size = 4 // 至少包含握手消息头部
}
// 创建第一个分片
firstFragment := make([]byte, 5+size)
copy(firstFragment[:5], recordHeader)
copy(firstFragment[5:], recordData[:size])
// 修改第一个分片的长度字段
firstFragment[3] = byte(size >> 8)
firstFragment[4] = byte(size)
if s.config.LogLevel == "debug" {
printTLSRecord(firstFragment, "第一个分片")
log.Printf("第一个分片长度: %d (头部5字节 + 数据%d字节)", len(firstFragment), size)
// 验证第一个分片的握手消息
if recordData[0] == handshakeTypeClientHello {
// 创建一个临时的握手消息进行验证
tempHandshake := make([]byte, size)
copy(tempHandshake, recordData[:size])
// 调整握手消息长度以匹配实际数据
handshakeLength := uint32(size - 4) // 减去握手消息头部的4字节
tempHandshake[1] = byte(handshakeLength >> 16)
tempHandshake[2] = byte(handshakeLength >> 8)
tempHandshake[3] = byte(handshakeLength)
valid, reason := validateHandshakeMessage(tempHandshake)
log.Printf("第一个分片握手消息验证: %v, %s", valid, reason)
}
}
// 设置写入超时并发送第一个分片
conn.SetWriteDeadline(time.Now().Add(writeTimeout))
if _, err := conn.Write(firstFragment); err != nil {
return fmt.Errorf("发送第一个分片失败: %w", err)
}
// 如果还有剩余数据,创建第二个分片
if size < recordLength {
// 添加一个小延迟,模拟网络延迟,提高分片效果
delayMin := fragConfig.DelayMin
if delayMin <= 0 {
delayMin = 10 // 默认最小延迟10毫秒
}
delayMax := fragConfig.DelayMax
if delayMax <= 0 {
delayMax = 30 // 默认最大延迟30毫秒
}
if delayMin > delayMax {
delayMin = delayMax
}
delayRange := big.NewInt(int64(delayMax - delayMin + 1))
randVal, _ := rand.Int(rand.Reader, delayRange)
delay := time.Duration(delayMin+int(randVal.Int64())) * time.Millisecond
time.Sleep(delay)
if s.config.LogLevel == "debug" {
log.Printf("延迟 %v 后发送第二个分片", delay)
}
// 创建第二个分片
remainingSize := recordLength - size
secondFragment := make([]byte, 5+remainingSize)
// 复制记录头部(保持与原始记录头部一致)
copy(secondFragment[:5], recordHeader)
// 设置正确的长度字段
secondFragment[3] = byte(remainingSize >> 8)
secondFragment[4] = byte(remainingSize)
// 复制剩余数据,确保握手类型保持一致
// 注意:我们不能直接复制数据,因为第二个分片的握手消息头部需要特殊处理
if recordData[0] == handshakeTypeClientHello {
// 创建新的握手消息头部
// 1. 保持握手类型不变
// 2. 调整握手消息长度
handshakeLength := (uint32(recordData[1]) << 16) | (uint32(recordData[2]) << 8) | uint32(recordData[3])
// 计算第一个分片中已发送的握手消息数据长度不包括握手消息头部的4字节
sentDataLength := uint32(size) - 4
// 计算剩余的握手消息长度
remainingHandshakeLength := handshakeLength - sentDataLength
if s.config.LogLevel == "debug" {
log.Printf("握手消息总长度: %d, 第一个分片已发送数据: %d, 第二个分片剩余数据: %d",
handshakeLength, sentDataLength, remainingHandshakeLength)
}
// 复制握手类型
secondFragment[5] = recordData[0]
// 设置新的握手消息长度
secondFragment[6] = byte(remainingHandshakeLength >> 16)
secondFragment[7] = byte(remainingHandshakeLength >> 8)
secondFragment[8] = byte(remainingHandshakeLength)
// 复制剩余数据(跳过第一个分片已经发送的部分)
copy(secondFragment[9:], recordData[size:])
} else {
// 如果不是ClientHello直接复制数据
copy(secondFragment[5:], recordData[size:])
}
// 对于第二个分片,我们采用不同的策略
// 不再尝试修改握手消息头部,而是直接将剩余数据作为应用数据发送
// 这样可以避免TLS握手解析错误
// 修改第二个分片的记录类型为应用数据
secondFragment[0] = 23 // ApplicationData
// 复制剩余数据
copy(secondFragment[5:], recordData[size:])
if s.config.LogLevel == "debug" {
log.Printf("第二个分片使用应用数据类型,避免握手解析错误")
printTLSRecord(secondFragment, "第二个分片")
log.Printf("第二个分片长度: %d (头部5字节 + 数据%d字节)", len(secondFragment), remainingSize)
}
// 设置写入超时并发送第二个分片
conn.SetWriteDeadline(time.Now().Add(writeTimeout))
if _, err := conn.Write(secondFragment); err != nil {
return fmt.Errorf("发送第二个分片失败: %w", err)
}
if s.config.LogLevel == "debug" {
log.Printf("TLS分片完成: 总共发送 %d 字节 (第一分片: %d 字节, 第二分片: %d 字节)",
len(firstFragment)+len(secondFragment), len(firstFragment), len(secondFragment))
}
} else {
if s.config.LogLevel == "debug" {
log.Printf("TLS分片完成: 只需一个分片,总共发送 %d 字节", len(firstFragment))
}
}
return nil
}
// TLS记录类型
const (
recordTypeHandshake = 22
)
// TLS握手类型
const (
handshakeTypeClientHello = 1
)
// 读取SNI信息
func readSNI(conn net.Conn) (string, []byte, error) {
// 读取TLS记录头部
header := make([]byte, 5)
if _, err := io.ReadFull(conn, header); err != nil {
return "", nil, fmt.Errorf("读取TLS记录头部失败: %w", err)
}
// 检查是否是TLS握手记录
if header[0] != recordTypeHandshake {
return "", nil, fmt.Errorf("不是TLS握手记录记录类型: %d", header[0])
}
// 获取记录长度
recordLength := int(header[3])<<8 | int(header[4])
if recordLength > 16384 {
return "", nil, fmt.Errorf("TLS记录太长: %d 字节", recordLength)
}
if recordLength < 4 {
return "", nil, fmt.Errorf("TLS握手记录太短: %d 字节", recordLength)
}
// 读取握手消息
handshakeData := make([]byte, recordLength)
if _, err := io.ReadFull(conn, handshakeData); err != nil {
return "", nil, fmt.Errorf("读取握手消息失败: %w", err)
}
// 检查是否是ClientHello
if handshakeData[0] != handshakeTypeClientHello {
return "", nil, fmt.Errorf("不是ClientHello消息握手类型: %d", handshakeData[0])
}
// 完整的ClientHello消息包括记录头部
clientHello := append(header, handshakeData...)
// 验证TLS记录的完整性
valid, reason := validateTLSRecord(clientHello)
if !valid {
log.Printf("警告: 收到的ClientHello不是有效的TLS记录: %s", reason)
}
// 解析SNI扩展
sni, err := extractSNI(handshakeData)
if err != nil {
return "", nil, fmt.Errorf("提取SNI失败: %w", err)
}
if sni != "" && len(sni) > 0 {
log.Printf("成功提取SNI: %s", sni)
}
return sni, clientHello, nil
}
// 从ClientHello中提取SNI
func extractSNI(data []byte) (string, error) {
if len(data) < 42 {
return "", fmt.Errorf("ClientHello太短")
}
// 跳过握手类型(1) + 长度(3) + 协议版本(2) + 随机数(32)
pos := 38
// 跳过会话ID
if pos+1 > len(data) {
return "", fmt.Errorf("解析会话ID长度时数据不足")
}
sessionIDLength := int(data[pos])
pos += 1 + sessionIDLength
// 跳过密码套件
if pos+2 > len(data) {
return "", fmt.Errorf("解析密码套件长度时数据不足")
}
cipherSuitesLength := int(data[pos])<<8 | int(data[pos+1])
pos += 2 + cipherSuitesLength
// 跳过压缩方法
if pos+1 > len(data) {
return "", fmt.Errorf("解析压缩方法长度时数据不足")
}
compressionMethodsLength := int(data[pos])
pos += 1 + compressionMethodsLength
// 检查是否有扩展
if pos+2 > len(data) {
return "", fmt.Errorf("没有扩展数据")
}
extensionsLength := int(data[pos])<<8 | int(data[pos+1])
pos += 2
// 解析扩展
extensionsEnd := pos + extensionsLength
if extensionsEnd > len(data) {
return "", fmt.Errorf("扩展数据不足")
}
for pos < extensionsEnd {
// 读取扩展类型
if pos+4 > len(data) {
return "", fmt.Errorf("解析扩展类型时数据不足")
}
extensionType := int(data[pos])<<8 | int(data[pos+1])
extensionLength := int(data[pos+2])<<8 | int(data[pos+3])
pos += 4
// 检查是否是SNI扩展 (类型为0)
if extensionType == 0 {
// 跳过SNI列表长度
if pos+2 > len(data) {
return "", fmt.Errorf("解析SNI列表长度时数据不足")
}
pos += 2
// 检查SNI类型
if pos+1 > len(data) {
return "", fmt.Errorf("解析SNI类型时数据不足")
}
if data[pos] != 0 {
return "", fmt.Errorf("不支持的SNI类型")
}
pos++
// 读取主机名长度
if pos+2 > len(data) {
return "", fmt.Errorf("解析主机名长度时数据不足")
}
hostnameLength := int(data[pos])<<8 | int(data[pos+1])
pos += 2
// 读取主机名
if pos+hostnameLength > len(data) {
return "", fmt.Errorf("解析主机名时数据不足")
}
hostname := string(data[pos : pos+hostnameLength])
return hostname, nil
}
// 跳过当前扩展
pos += extensionLength
}
return "", fmt.Errorf("未找到SNI扩展")
}
// 打印TLS记录的详细信息用于调试
func printTLSRecord(data []byte, prefix string) {
if len(data) < 5 {
log.Printf("%s: 数据太短不是有效的TLS记录", prefix)
return
}
recordType := data[0]
version := (uint16(data[1]) << 8) | uint16(data[2])
length := (uint16(data[3]) << 8) | uint16(data[4])
var recordTypeStr string
switch recordType {
case 20:
recordTypeStr = "ChangeCipherSpec"
case 21:
recordTypeStr = "Alert"
case 22:
recordTypeStr = "Handshake"
case 23:
recordTypeStr = "ApplicationData"
default:
recordTypeStr = fmt.Sprintf("Unknown(%d)", recordType)
}
var versionStr string
switch version {
case 0x0301:
versionStr = "TLS 1.0"
case 0x0302:
versionStr = "TLS 1.1"
case 0x0303:
versionStr = "TLS 1.2"
case 0x0304:
versionStr = "TLS 1.3"
default:
versionStr = fmt.Sprintf("Unknown(0x%04x)", version)
}
log.Printf("%s: 类型=%s, 版本=%s, 长度=%d, 总字节数=%d",
prefix, recordTypeStr, versionStr, length, len(data))
if recordType == recordTypeHandshake && len(data) >= 6 {
handshakeType := data[5]
var handshakeTypeStr string
switch handshakeType {
case 1:
handshakeTypeStr = "ClientHello"
case 2:
handshakeTypeStr = "ServerHello"
case 11:
handshakeTypeStr = "Certificate"
case 16:
handshakeTypeStr = "ClientKeyExchange"
default:
handshakeTypeStr = fmt.Sprintf("Unknown(%d)", handshakeType)
}
log.Printf("%s: 握手类型=%s", prefix, handshakeTypeStr)
}
}
// 验证TLS记录的完整性
func validateTLSRecord(data []byte) (bool, string) {
if len(data) < 5 {
return false, "数据太短不是有效的TLS记录"
}
recordType := data[0]
length := (uint16(data[3]) << 8) | uint16(data[4])
// 检查记录类型是否有效
validTypes := map[byte]bool{
20: true, // ChangeCipherSpec
21: true, // Alert
22: true, // Handshake
23: true, // ApplicationData
}
if !validTypes[recordType] {
return false, fmt.Sprintf("无效的TLS记录类型: %d", recordType)
}
// 检查记录长度是否与实际数据长度匹配
if int(length) != len(data)-5 {
return false, fmt.Sprintf("TLS记录长度不匹配: 头部指示 %d 字节,实际数据 %d 字节", length, len(data)-5)
}
// 对于握手记录,进行更详细的验证
if recordType == recordTypeHandshake && len(data) >= 9 {
// 握手消息至少需要4字节的头部
handshakeType := data[5]
handshakeLength := (uint32(data[6]) << 16) | (uint32(data[7]) << 8) | uint32(data[8])
// 检查握手消息长度是否与记录长度匹配
if int(handshakeLength)+4 > int(length) {
return false, fmt.Sprintf("握手消息长度不匹配: 握手头部指示 %d 字节,记录头部指示 %d 字节", handshakeLength, length)
}
// 对于ClientHello进行更详细的验证
if handshakeType == handshakeTypeClientHello && len(data) >= 43 {
// ClientHello至少需要包含协议版本(2)、随机数(32)和会话ID长度(1)
return true, "有效的ClientHello记录"
}
return true, fmt.Sprintf("有效的握手记录,类型: %d", handshakeType)
}
return true, fmt.Sprintf("有效的TLS记录类型: %d", recordType)
}
// 验证握手消息的完整性
func validateHandshakeMessage(data []byte) (bool, string) {
if len(data) < 4 {
return false, "数据太短,不是有效的握手消息"
}
handshakeType := data[0]
handshakeLength := (uint32(data[1]) << 16) | (uint32(data[2]) << 8) | uint32(data[3])
// 检查握手消息长度是否与实际数据长度匹配
if int(handshakeLength)+4 != len(data) {
return false, fmt.Sprintf("握手消息长度不匹配: 头部指示 %d 字节,实际数据 %d 字节", handshakeLength, len(data)-4)
}
// 对于ClientHello进行更详细的验证
if handshakeType == handshakeTypeClientHello {
if len(data) < 38 {
return false, "ClientHello消息太短"
}
// 检查协议版本、随机数等字段
// 这里只是简单检查长度,实际应用中可以更详细地验证
return true, "有效的ClientHello消息"
}
return true, fmt.Sprintf("有效的握手消息,类型: %d", handshakeType)
}