979 lines
26 KiB
Go
979 lines
26 KiB
Go
package proxy
|
||
|
||
import (
|
||
"context"
|
||
"crypto/rand"
|
||
"fmt"
|
||
"io"
|
||
"log"
|
||
"math/big"
|
||
"net"
|
||
"strconv"
|
||
"strings"
|
||
"sync"
|
||
"sync/atomic"
|
||
"time"
|
||
|
||
"github.com/SNI_Proxy/config"
|
||
"github.com/SNI_Proxy/dns"
|
||
)
|
||
|
||
// 服务器结构
|
||
type Server struct {
|
||
config *config.Config
|
||
listener net.Listener
|
||
wg sync.WaitGroup
|
||
shutdownCh chan struct{}
|
||
resolver dns.Resolver
|
||
connMutex sync.Mutex
|
||
conns map[net.Conn]struct{} // 活跃连接跟踪
|
||
connCounter int64 // 当前连接计数器
|
||
}
|
||
|
||
// 创建新的代理服务器
|
||
func NewServer(cfg *config.Config) *Server {
|
||
return &Server{
|
||
config: cfg,
|
||
shutdownCh: make(chan struct{}),
|
||
conns: make(map[net.Conn]struct{}),
|
||
}
|
||
}
|
||
|
||
// 启动代理服务器
|
||
func (s *Server) Start() error {
|
||
var err error
|
||
|
||
// 创建DNS解析器
|
||
s.resolver, err = dns.NewResolver(s.config.DNS)
|
||
if err != nil {
|
||
return fmt.Errorf("创建DNS解析器失败: %w", err)
|
||
}
|
||
|
||
s.listener, err = net.Listen("tcp", s.config.Listen)
|
||
if err != nil {
|
||
return fmt.Errorf("监听地址失败: %w", err)
|
||
}
|
||
|
||
s.wg.Add(1)
|
||
go s.acceptLoop()
|
||
|
||
// 启动连接监控
|
||
go s.monitorConnections()
|
||
|
||
// 启动连接清理
|
||
go s.cleanIdleConnections()
|
||
|
||
return nil
|
||
}
|
||
|
||
// 停止代理服务器
|
||
func (s *Server) Stop() {
|
||
log.Println("正在关闭服务器...")
|
||
close(s.shutdownCh)
|
||
if s.listener != nil {
|
||
s.listener.Close()
|
||
}
|
||
|
||
// 关闭所有活跃连接
|
||
s.closeAllConnections()
|
||
|
||
if s.resolver != nil {
|
||
s.resolver.Close()
|
||
}
|
||
|
||
// 等待所有goroutine结束
|
||
log.Println("等待所有连接关闭...")
|
||
s.wg.Wait()
|
||
log.Println("服务器已关闭")
|
||
}
|
||
|
||
// 关闭所有活跃连接
|
||
func (s *Server) closeAllConnections() {
|
||
s.connMutex.Lock()
|
||
defer s.connMutex.Unlock()
|
||
|
||
log.Printf("关闭 %d 个活跃连接", len(s.conns))
|
||
for conn := range s.conns {
|
||
conn.Close()
|
||
delete(s.conns, conn)
|
||
}
|
||
}
|
||
|
||
// 添加连接到跟踪列表
|
||
func (s *Server) trackConn(conn net.Conn) bool {
|
||
s.connMutex.Lock()
|
||
defer s.connMutex.Unlock()
|
||
|
||
// 检查是否超过最大连接数
|
||
if s.config.MaxConns > 0 && len(s.conns) >= s.config.MaxConns {
|
||
log.Printf("达到最大连接数 %d,拒绝新连接", s.config.MaxConns)
|
||
return false
|
||
}
|
||
|
||
s.conns[conn] = struct{}{}
|
||
atomic.AddInt64(&s.connCounter, 1)
|
||
return true
|
||
}
|
||
|
||
// 从跟踪列表中移除连接
|
||
func (s *Server) untrackConn(conn net.Conn) {
|
||
s.connMutex.Lock()
|
||
defer s.connMutex.Unlock()
|
||
|
||
if _, exists := s.conns[conn]; exists {
|
||
delete(s.conns, conn)
|
||
atomic.AddInt64(&s.connCounter, -1)
|
||
}
|
||
}
|
||
|
||
// 监控连接状态
|
||
func (s *Server) monitorConnections() {
|
||
ticker := time.NewTicker(30 * time.Second)
|
||
defer ticker.Stop()
|
||
|
||
for {
|
||
select {
|
||
case <-s.shutdownCh:
|
||
return
|
||
case <-ticker.C:
|
||
// 记录当前活跃连接数
|
||
s.connMutex.Lock()
|
||
activeConns := len(s.conns)
|
||
s.connMutex.Unlock()
|
||
|
||
totalConns := atomic.LoadInt64(&s.connCounter)
|
||
log.Printf("连接统计 - 当前活跃: %d, 总计处理: %d", activeConns, totalConns)
|
||
}
|
||
}
|
||
}
|
||
|
||
// 清理空闲连接
|
||
func (s *Server) cleanIdleConnections() {
|
||
// 每分钟检查一次空闲连接
|
||
ticker := time.NewTicker(1 * time.Minute)
|
||
defer ticker.Stop()
|
||
|
||
for {
|
||
select {
|
||
case <-s.shutdownCh:
|
||
return
|
||
case <-ticker.C:
|
||
s.connMutex.Lock()
|
||
log.Printf("开始清理空闲连接,当前连接数: %d", len(s.conns))
|
||
// 这里我们不做实际清理,因为连接已经设置了超时
|
||
// 实际清理由各个连接自己的超时机制处理
|
||
s.connMutex.Unlock()
|
||
}
|
||
}
|
||
}
|
||
|
||
// 接受连接循环
|
||
func (s *Server) acceptLoop() {
|
||
defer s.wg.Done()
|
||
|
||
for {
|
||
conn, err := s.listener.Accept()
|
||
if err != nil {
|
||
select {
|
||
case <-s.shutdownCh:
|
||
return
|
||
default:
|
||
log.Printf("接受连接失败: %v", err)
|
||
// 短暂休眠,避免CPU占用过高
|
||
time.Sleep(100 * time.Millisecond)
|
||
continue
|
||
}
|
||
}
|
||
|
||
// 检查是否可以接受新连接
|
||
if !s.trackConn(conn) {
|
||
conn.Close()
|
||
continue
|
||
}
|
||
|
||
s.wg.Add(1)
|
||
|
||
go func(clientConn net.Conn) {
|
||
defer s.wg.Done()
|
||
defer s.untrackConn(clientConn) // 移除连接跟踪
|
||
defer clientConn.Close()
|
||
|
||
// 设置连接生命周期
|
||
if s.config.Timeout.LifeTime > 0 {
|
||
deadline := time.Now().Add(time.Duration(s.config.Timeout.LifeTime) * time.Second)
|
||
clientConn.SetDeadline(deadline)
|
||
}
|
||
|
||
// 创建一个带超时的上下文
|
||
ctx, cancel := context.WithTimeout(context.Background(),
|
||
time.Duration(s.config.Timeout.Idle)*time.Second)
|
||
defer cancel()
|
||
|
||
// 使用goroutine和channel处理连接,支持超时控制
|
||
errCh := make(chan error, 1)
|
||
go func() {
|
||
errCh <- s.handleConnection(clientConn)
|
||
}()
|
||
|
||
// 等待处理完成或超时
|
||
select {
|
||
case err := <-errCh:
|
||
if err != nil {
|
||
if s.config.LogLevel == "debug" ||
|
||
(s.config.LogLevel != "error" && !isCommonNetworkError(err)) {
|
||
log.Printf("处理连接失败: %v", err)
|
||
}
|
||
}
|
||
case <-ctx.Done():
|
||
log.Printf("处理连接超时")
|
||
case <-s.shutdownCh:
|
||
log.Printf("服务器关闭,终止连接处理")
|
||
}
|
||
}(conn)
|
||
}
|
||
}
|
||
|
||
// 判断是否是常见网络错误(可以不记录日志的错误)
|
||
func isCommonNetworkError(err error) bool {
|
||
if err == nil {
|
||
return false
|
||
}
|
||
|
||
errStr := err.Error()
|
||
// 检查常见的网络错误
|
||
commonErrors := []string{
|
||
"connection reset by peer",
|
||
"broken pipe",
|
||
"i/o timeout",
|
||
"use of closed network connection",
|
||
"EOF",
|
||
}
|
||
|
||
for _, e := range commonErrors {
|
||
if strings.Contains(strings.ToLower(errStr), e) {
|
||
return true
|
||
}
|
||
}
|
||
|
||
return false
|
||
}
|
||
|
||
// 处理单个连接
|
||
func (s *Server) handleConnection(clientConn net.Conn) error {
|
||
// 设置读取超时
|
||
readTimeout := time.Duration(s.config.Timeout.Read) * time.Second
|
||
if err := clientConn.SetReadDeadline(time.Now().Add(readTimeout)); err != nil {
|
||
return fmt.Errorf("设置读取超时失败: %w", err)
|
||
}
|
||
|
||
// 读取并解析SNI信息
|
||
sni, clientHello, err := readSNI(clientConn)
|
||
if err != nil {
|
||
return fmt.Errorf("读取SNI失败: %w", err)
|
||
}
|
||
|
||
// 重置读取超时
|
||
if err := clientConn.SetReadDeadline(time.Time{}); err != nil {
|
||
return fmt.Errorf("重置读取超时失败: %w", err)
|
||
}
|
||
|
||
// 查找匹配的规则
|
||
var targetRule *config.ProxyRule
|
||
for i := range s.config.Rules {
|
||
rule := &s.config.Rules[i]
|
||
for j := range rule.Domains {
|
||
if rule.Domains[j].Match(sni) {
|
||
targetRule = rule
|
||
break
|
||
}
|
||
}
|
||
if targetRule != nil {
|
||
break
|
||
}
|
||
}
|
||
|
||
if targetRule == nil {
|
||
// 如果没有匹配的规则,使用默认端口
|
||
if s.config.LogLevel != "error" {
|
||
log.Printf("未找到匹配的规则,使用默认端口 %d 连接到 %s", s.config.DefaultPort, sni)
|
||
}
|
||
return s.proxyConnection(clientConn, sni, s.config.DefaultPort, clientHello, config.FragmentConfig{Enabled: false})
|
||
}
|
||
|
||
// 使用匹配规则中的端口和分片配置
|
||
return s.proxyConnection(clientConn, sni, targetRule.Port, clientHello, targetRule.Fragment)
|
||
}
|
||
|
||
// 代理连接到目标服务器
|
||
func (s *Server) proxyConnection(clientConn net.Conn, targetHost string, targetPort int, clientHello []byte, fragConfig config.FragmentConfig) error {
|
||
// 设置DNS解析超时
|
||
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(s.config.DNS.Timeout)*time.Second)
|
||
defer cancel()
|
||
|
||
// 创建一个带超时的解析任务
|
||
type resolveResult struct {
|
||
ips []net.IP
|
||
err error
|
||
}
|
||
|
||
resCh := make(chan resolveResult, 1)
|
||
go func() {
|
||
ips, err := s.resolver.Resolve(targetHost)
|
||
resCh <- resolveResult{ips, err}
|
||
}()
|
||
|
||
// 等待DNS解析完成或超时
|
||
var ips []net.IP
|
||
select {
|
||
case res := <-resCh:
|
||
if res.err != nil {
|
||
return fmt.Errorf("解析域名失败: %w", res.err)
|
||
}
|
||
ips = res.ips
|
||
case <-ctx.Done():
|
||
return fmt.Errorf("解析域名超时")
|
||
}
|
||
|
||
// 选择第一个IP地址
|
||
targetIP := ips[0]
|
||
if s.config.LogLevel == "debug" {
|
||
log.Printf("域名 %s 解析为 %s", targetHost, targetIP.String())
|
||
}
|
||
|
||
// 构建目标地址
|
||
targetAddr := net.JoinHostPort(targetIP.String(), strconv.Itoa(targetPort))
|
||
if s.config.LogLevel != "error" {
|
||
log.Printf("代理连接到: %s (原始SNI: %s)", targetAddr, targetHost)
|
||
}
|
||
|
||
// 设置连接超时
|
||
dialCtx, dialCancel := context.WithTimeout(context.Background(),
|
||
time.Duration(s.config.Timeout.Connect)*time.Second)
|
||
defer dialCancel()
|
||
|
||
// 连接到目标服务器
|
||
var d net.Dialer
|
||
targetConn, err := d.DialContext(dialCtx, "tcp", targetAddr)
|
||
if err != nil {
|
||
return fmt.Errorf("连接目标服务器失败: %w", err)
|
||
}
|
||
defer targetConn.Close()
|
||
|
||
// 设置目标连接的读写超时
|
||
if tcpConn, ok := targetConn.(*net.TCPConn); ok {
|
||
tcpConn.SetKeepAlive(true)
|
||
tcpConn.SetKeepAlivePeriod(30 * time.Second)
|
||
}
|
||
|
||
// 发送ClientHello到目标服务器,可能需要进行TLS分片
|
||
if err := s.sendClientHello(targetConn, clientHello, fragConfig); err != nil {
|
||
return fmt.Errorf("发送ClientHello失败: %w", err)
|
||
}
|
||
|
||
// 双向转发数据
|
||
errCh := make(chan error, 2)
|
||
copyDone := make(chan struct{}, 2)
|
||
|
||
// 客户端 -> 目标服务器
|
||
go func() {
|
||
buf := make([]byte, 32*1024) // 使用较大的缓冲区
|
||
_, err := s.copyBuffer(targetConn, clientConn, buf)
|
||
errCh <- err
|
||
copyDone <- struct{}{}
|
||
}()
|
||
|
||
// 目标服务器 -> 客户端
|
||
go func() {
|
||
buf := make([]byte, 32*1024) // 使用较大的缓冲区
|
||
_, err := s.copyBuffer(clientConn, targetConn, buf)
|
||
errCh <- err
|
||
copyDone <- struct{}{}
|
||
}()
|
||
|
||
// 等待任一方向的数据传输完成
|
||
<-copyDone
|
||
|
||
// 设置一个短暂的超时,让另一个方向有机会完成
|
||
time.Sleep(100 * time.Millisecond)
|
||
|
||
// 检查是否有错误
|
||
select {
|
||
case err := <-errCh:
|
||
if err != nil && err != io.EOF && !isCommonNetworkError(err) {
|
||
return fmt.Errorf("数据转发失败: %w", err)
|
||
}
|
||
default:
|
||
// 没有错误
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// 带超时控制的数据复制
|
||
func (s *Server) copyBuffer(dst net.Conn, src net.Conn, buf []byte) (int64, error) {
|
||
var written int64
|
||
readTimeout := time.Duration(s.config.Timeout.Read) * time.Second
|
||
writeTimeout := time.Duration(s.config.Timeout.Write) * time.Second
|
||
|
||
for {
|
||
// 设置读取超时
|
||
src.SetReadDeadline(time.Now().Add(readTimeout))
|
||
|
||
nr, err := src.Read(buf)
|
||
if nr > 0 {
|
||
// 设置写入超时
|
||
dst.SetWriteDeadline(time.Now().Add(writeTimeout))
|
||
|
||
nw, err := dst.Write(buf[0:nr])
|
||
if nw > 0 {
|
||
written += int64(nw)
|
||
}
|
||
if err != nil {
|
||
return written, err
|
||
}
|
||
if nr != nw {
|
||
return written, io.ErrShortWrite
|
||
}
|
||
}
|
||
if err != nil {
|
||
if err == io.EOF {
|
||
return written, nil
|
||
}
|
||
return written, err
|
||
}
|
||
}
|
||
}
|
||
|
||
// 发送ClientHello,根据配置进行TLS分片
|
||
func (s *Server) sendClientHello(conn net.Conn, clientHello []byte, fragConfig config.FragmentConfig) error {
|
||
writeTimeout := time.Duration(s.config.Timeout.Write) * time.Second
|
||
|
||
if s.config.LogLevel == "debug" {
|
||
printTLSRecord(clientHello, "原始ClientHello")
|
||
}
|
||
|
||
if !fragConfig.Enabled {
|
||
// 不进行分片,直接发送
|
||
if s.config.LogLevel == "debug" {
|
||
log.Printf("TLS分片未启用,直接发送 %d 字节", len(clientHello))
|
||
}
|
||
// 设置写入超时
|
||
conn.SetWriteDeadline(time.Now().Add(writeTimeout))
|
||
_, err := conn.Write(clientHello)
|
||
return err
|
||
}
|
||
|
||
// 进行TLS分片
|
||
if s.config.LogLevel == "debug" {
|
||
log.Printf("启用TLS分片,分片大小范围: %d-%d", fragConfig.MinSize, fragConfig.MaxSize)
|
||
}
|
||
|
||
// 检查ClientHello是否是有效的TLS记录
|
||
if fragConfig.Validate {
|
||
valid, reason := validateTLSRecord(clientHello)
|
||
if !valid {
|
||
log.Printf("警告: ClientHello不是有效的TLS记录 (%s),跳过分片", reason)
|
||
conn.SetWriteDeadline(time.Now().Add(writeTimeout))
|
||
_, err := conn.Write(clientHello)
|
||
return err
|
||
}
|
||
}
|
||
|
||
if clientHello[0] != recordTypeHandshake {
|
||
log.Printf("警告: 记录类型不是握手类型,跳过分片")
|
||
conn.SetWriteDeadline(time.Now().Add(writeTimeout))
|
||
_, err := conn.Write(clientHello)
|
||
return err
|
||
}
|
||
|
||
// 获取TLS记录头部和数据部分
|
||
recordHeader := clientHello[:5]
|
||
recordData := clientHello[5:]
|
||
recordLength := len(recordData)
|
||
|
||
if s.config.LogLevel == "debug" {
|
||
log.Printf("TLS记录头部: %v, 数据长度: %d", recordHeader, recordLength)
|
||
|
||
// 如果是握手消息,打印握手消息头部信息
|
||
if recordData[0] == handshakeTypeClientHello && len(recordData) >= 4 {
|
||
handshakeType := recordData[0]
|
||
handshakeLength := (uint32(recordData[1]) << 16) | (uint32(recordData[2]) << 8) | uint32(recordData[3])
|
||
log.Printf("握手消息头部: 类型=%d, 长度=%d", handshakeType, handshakeLength)
|
||
}
|
||
}
|
||
|
||
// 如果记录太小,不进行分片
|
||
if recordLength <= fragConfig.MinSize {
|
||
if s.config.LogLevel == "debug" {
|
||
log.Printf("记录太小 (%d 字节),不进行分片", recordLength)
|
||
}
|
||
conn.SetWriteDeadline(time.Now().Add(writeTimeout))
|
||
_, err := conn.Write(clientHello)
|
||
return err
|
||
}
|
||
|
||
// 第一个分片:发送记录头部和部分数据
|
||
minSize := fragConfig.MinSize
|
||
if minSize <= 0 {
|
||
minSize = 10 // 默认最小分片大小
|
||
}
|
||
maxSize := fragConfig.MaxSize
|
||
if maxSize <= 0 {
|
||
maxSize = 100 // 默认最大分片大小
|
||
}
|
||
if minSize > maxSize {
|
||
minSize = maxSize
|
||
}
|
||
|
||
// 随机分片大小
|
||
size := minSize
|
||
if maxSize > minSize {
|
||
randVal, _ := rand.Int(rand.Reader, big.NewInt(int64(maxSize-minSize+1)))
|
||
size = minSize + int(randVal.Int64())
|
||
}
|
||
|
||
// 确保不超过记录数据大小
|
||
if size > recordLength {
|
||
size = recordLength
|
||
}
|
||
|
||
// 对于握手消息,确保第一个分片至少包含完整的握手消息头部(4字节)
|
||
if recordData[0] == handshakeTypeClientHello && size < 4 {
|
||
size = 4 // 至少包含握手消息头部
|
||
}
|
||
|
||
// 创建第一个分片
|
||
firstFragment := make([]byte, 5+size)
|
||
copy(firstFragment[:5], recordHeader)
|
||
copy(firstFragment[5:], recordData[:size])
|
||
|
||
// 修改第一个分片的长度字段
|
||
firstFragment[3] = byte(size >> 8)
|
||
firstFragment[4] = byte(size)
|
||
|
||
if s.config.LogLevel == "debug" {
|
||
printTLSRecord(firstFragment, "第一个分片")
|
||
log.Printf("第一个分片长度: %d (头部5字节 + 数据%d字节)", len(firstFragment), size)
|
||
|
||
// 验证第一个分片的握手消息
|
||
if recordData[0] == handshakeTypeClientHello {
|
||
// 创建一个临时的握手消息进行验证
|
||
tempHandshake := make([]byte, size)
|
||
copy(tempHandshake, recordData[:size])
|
||
// 调整握手消息长度以匹配实际数据
|
||
handshakeLength := uint32(size - 4) // 减去握手消息头部的4字节
|
||
tempHandshake[1] = byte(handshakeLength >> 16)
|
||
tempHandshake[2] = byte(handshakeLength >> 8)
|
||
tempHandshake[3] = byte(handshakeLength)
|
||
|
||
valid, reason := validateHandshakeMessage(tempHandshake)
|
||
log.Printf("第一个分片握手消息验证: %v, %s", valid, reason)
|
||
}
|
||
}
|
||
|
||
// 设置写入超时并发送第一个分片
|
||
conn.SetWriteDeadline(time.Now().Add(writeTimeout))
|
||
if _, err := conn.Write(firstFragment); err != nil {
|
||
return fmt.Errorf("发送第一个分片失败: %w", err)
|
||
}
|
||
|
||
// 如果还有剩余数据,创建第二个分片
|
||
if size < recordLength {
|
||
// 添加一个小延迟,模拟网络延迟,提高分片效果
|
||
delayMin := fragConfig.DelayMin
|
||
if delayMin <= 0 {
|
||
delayMin = 10 // 默认最小延迟10毫秒
|
||
}
|
||
|
||
delayMax := fragConfig.DelayMax
|
||
if delayMax <= 0 {
|
||
delayMax = 30 // 默认最大延迟30毫秒
|
||
}
|
||
|
||
if delayMin > delayMax {
|
||
delayMin = delayMax
|
||
}
|
||
|
||
delayRange := big.NewInt(int64(delayMax - delayMin + 1))
|
||
randVal, _ := rand.Int(rand.Reader, delayRange)
|
||
delay := time.Duration(delayMin+int(randVal.Int64())) * time.Millisecond
|
||
time.Sleep(delay)
|
||
|
||
if s.config.LogLevel == "debug" {
|
||
log.Printf("延迟 %v 后发送第二个分片", delay)
|
||
}
|
||
|
||
// 创建第二个分片
|
||
remainingSize := recordLength - size
|
||
secondFragment := make([]byte, 5+remainingSize)
|
||
|
||
// 复制记录头部(保持与原始记录头部一致)
|
||
copy(secondFragment[:5], recordHeader)
|
||
|
||
// 设置正确的长度字段
|
||
secondFragment[3] = byte(remainingSize >> 8)
|
||
secondFragment[4] = byte(remainingSize)
|
||
|
||
// 复制剩余数据,确保握手类型保持一致
|
||
// 注意:我们不能直接复制数据,因为第二个分片的握手消息头部需要特殊处理
|
||
if recordData[0] == handshakeTypeClientHello {
|
||
// 创建新的握手消息头部
|
||
// 1. 保持握手类型不变
|
||
// 2. 调整握手消息长度
|
||
handshakeLength := (uint32(recordData[1]) << 16) | (uint32(recordData[2]) << 8) | uint32(recordData[3])
|
||
|
||
// 计算第一个分片中已发送的握手消息数据长度(不包括握手消息头部的4字节)
|
||
sentDataLength := uint32(size) - 4
|
||
|
||
// 计算剩余的握手消息长度
|
||
remainingHandshakeLength := handshakeLength - sentDataLength
|
||
|
||
if s.config.LogLevel == "debug" {
|
||
log.Printf("握手消息总长度: %d, 第一个分片已发送数据: %d, 第二个分片剩余数据: %d",
|
||
handshakeLength, sentDataLength, remainingHandshakeLength)
|
||
}
|
||
|
||
// 复制握手类型
|
||
secondFragment[5] = recordData[0]
|
||
|
||
// 设置新的握手消息长度
|
||
secondFragment[6] = byte(remainingHandshakeLength >> 16)
|
||
secondFragment[7] = byte(remainingHandshakeLength >> 8)
|
||
secondFragment[8] = byte(remainingHandshakeLength)
|
||
|
||
// 复制剩余数据(跳过第一个分片已经发送的部分)
|
||
copy(secondFragment[9:], recordData[size:])
|
||
} else {
|
||
// 如果不是ClientHello,直接复制数据
|
||
copy(secondFragment[5:], recordData[size:])
|
||
}
|
||
|
||
// 对于第二个分片,我们采用不同的策略
|
||
// 不再尝试修改握手消息头部,而是直接将剩余数据作为应用数据发送
|
||
// 这样可以避免TLS握手解析错误
|
||
|
||
// 修改第二个分片的记录类型为应用数据
|
||
secondFragment[0] = 23 // ApplicationData
|
||
|
||
// 复制剩余数据
|
||
copy(secondFragment[5:], recordData[size:])
|
||
|
||
if s.config.LogLevel == "debug" {
|
||
log.Printf("第二个分片使用应用数据类型,避免握手解析错误")
|
||
printTLSRecord(secondFragment, "第二个分片")
|
||
log.Printf("第二个分片长度: %d (头部5字节 + 数据%d字节)", len(secondFragment), remainingSize)
|
||
}
|
||
|
||
// 设置写入超时并发送第二个分片
|
||
conn.SetWriteDeadline(time.Now().Add(writeTimeout))
|
||
if _, err := conn.Write(secondFragment); err != nil {
|
||
return fmt.Errorf("发送第二个分片失败: %w", err)
|
||
}
|
||
|
||
if s.config.LogLevel == "debug" {
|
||
log.Printf("TLS分片完成: 总共发送 %d 字节 (第一分片: %d 字节, 第二分片: %d 字节)",
|
||
len(firstFragment)+len(secondFragment), len(firstFragment), len(secondFragment))
|
||
}
|
||
} else {
|
||
if s.config.LogLevel == "debug" {
|
||
log.Printf("TLS分片完成: 只需一个分片,总共发送 %d 字节", len(firstFragment))
|
||
}
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// TLS记录类型
|
||
const (
|
||
recordTypeHandshake = 22
|
||
)
|
||
|
||
// TLS握手类型
|
||
const (
|
||
handshakeTypeClientHello = 1
|
||
)
|
||
|
||
// 读取SNI信息
|
||
func readSNI(conn net.Conn) (string, []byte, error) {
|
||
// 读取TLS记录头部
|
||
header := make([]byte, 5)
|
||
if _, err := io.ReadFull(conn, header); err != nil {
|
||
return "", nil, fmt.Errorf("读取TLS记录头部失败: %w", err)
|
||
}
|
||
|
||
// 检查是否是TLS握手记录
|
||
if header[0] != recordTypeHandshake {
|
||
return "", nil, fmt.Errorf("不是TLS握手记录,记录类型: %d", header[0])
|
||
}
|
||
|
||
// 获取记录长度
|
||
recordLength := int(header[3])<<8 | int(header[4])
|
||
if recordLength > 16384 {
|
||
return "", nil, fmt.Errorf("TLS记录太长: %d 字节", recordLength)
|
||
}
|
||
|
||
if recordLength < 4 {
|
||
return "", nil, fmt.Errorf("TLS握手记录太短: %d 字节", recordLength)
|
||
}
|
||
|
||
// 读取握手消息
|
||
handshakeData := make([]byte, recordLength)
|
||
if _, err := io.ReadFull(conn, handshakeData); err != nil {
|
||
return "", nil, fmt.Errorf("读取握手消息失败: %w", err)
|
||
}
|
||
|
||
// 检查是否是ClientHello
|
||
if handshakeData[0] != handshakeTypeClientHello {
|
||
return "", nil, fmt.Errorf("不是ClientHello消息,握手类型: %d", handshakeData[0])
|
||
}
|
||
|
||
// 完整的ClientHello消息(包括记录头部)
|
||
clientHello := append(header, handshakeData...)
|
||
|
||
// 验证TLS记录的完整性
|
||
valid, reason := validateTLSRecord(clientHello)
|
||
if !valid {
|
||
log.Printf("警告: 收到的ClientHello不是有效的TLS记录: %s", reason)
|
||
}
|
||
|
||
// 解析SNI扩展
|
||
sni, err := extractSNI(handshakeData)
|
||
if err != nil {
|
||
return "", nil, fmt.Errorf("提取SNI失败: %w", err)
|
||
}
|
||
|
||
if sni != "" && len(sni) > 0 {
|
||
log.Printf("成功提取SNI: %s", sni)
|
||
}
|
||
|
||
return sni, clientHello, nil
|
||
}
|
||
|
||
// 从ClientHello中提取SNI
|
||
func extractSNI(data []byte) (string, error) {
|
||
if len(data) < 42 {
|
||
return "", fmt.Errorf("ClientHello太短")
|
||
}
|
||
|
||
// 跳过握手类型(1) + 长度(3) + 协议版本(2) + 随机数(32)
|
||
pos := 38
|
||
|
||
// 跳过会话ID
|
||
if pos+1 > len(data) {
|
||
return "", fmt.Errorf("解析会话ID长度时数据不足")
|
||
}
|
||
sessionIDLength := int(data[pos])
|
||
pos += 1 + sessionIDLength
|
||
|
||
// 跳过密码套件
|
||
if pos+2 > len(data) {
|
||
return "", fmt.Errorf("解析密码套件长度时数据不足")
|
||
}
|
||
cipherSuitesLength := int(data[pos])<<8 | int(data[pos+1])
|
||
pos += 2 + cipherSuitesLength
|
||
|
||
// 跳过压缩方法
|
||
if pos+1 > len(data) {
|
||
return "", fmt.Errorf("解析压缩方法长度时数据不足")
|
||
}
|
||
compressionMethodsLength := int(data[pos])
|
||
pos += 1 + compressionMethodsLength
|
||
|
||
// 检查是否有扩展
|
||
if pos+2 > len(data) {
|
||
return "", fmt.Errorf("没有扩展数据")
|
||
}
|
||
extensionsLength := int(data[pos])<<8 | int(data[pos+1])
|
||
pos += 2
|
||
|
||
// 解析扩展
|
||
extensionsEnd := pos + extensionsLength
|
||
if extensionsEnd > len(data) {
|
||
return "", fmt.Errorf("扩展数据不足")
|
||
}
|
||
|
||
for pos < extensionsEnd {
|
||
// 读取扩展类型
|
||
if pos+4 > len(data) {
|
||
return "", fmt.Errorf("解析扩展类型时数据不足")
|
||
}
|
||
extensionType := int(data[pos])<<8 | int(data[pos+1])
|
||
extensionLength := int(data[pos+2])<<8 | int(data[pos+3])
|
||
pos += 4
|
||
|
||
// 检查是否是SNI扩展 (类型为0)
|
||
if extensionType == 0 {
|
||
// 跳过SNI列表长度
|
||
if pos+2 > len(data) {
|
||
return "", fmt.Errorf("解析SNI列表长度时数据不足")
|
||
}
|
||
pos += 2
|
||
|
||
// 检查SNI类型
|
||
if pos+1 > len(data) {
|
||
return "", fmt.Errorf("解析SNI类型时数据不足")
|
||
}
|
||
if data[pos] != 0 {
|
||
return "", fmt.Errorf("不支持的SNI类型")
|
||
}
|
||
pos++
|
||
|
||
// 读取主机名长度
|
||
if pos+2 > len(data) {
|
||
return "", fmt.Errorf("解析主机名长度时数据不足")
|
||
}
|
||
hostnameLength := int(data[pos])<<8 | int(data[pos+1])
|
||
pos += 2
|
||
|
||
// 读取主机名
|
||
if pos+hostnameLength > len(data) {
|
||
return "", fmt.Errorf("解析主机名时数据不足")
|
||
}
|
||
hostname := string(data[pos : pos+hostnameLength])
|
||
return hostname, nil
|
||
}
|
||
|
||
// 跳过当前扩展
|
||
pos += extensionLength
|
||
}
|
||
|
||
return "", fmt.Errorf("未找到SNI扩展")
|
||
}
|
||
|
||
// 打印TLS记录的详细信息(用于调试)
|
||
func printTLSRecord(data []byte, prefix string) {
|
||
if len(data) < 5 {
|
||
log.Printf("%s: 数据太短,不是有效的TLS记录", prefix)
|
||
return
|
||
}
|
||
|
||
recordType := data[0]
|
||
version := (uint16(data[1]) << 8) | uint16(data[2])
|
||
length := (uint16(data[3]) << 8) | uint16(data[4])
|
||
|
||
var recordTypeStr string
|
||
switch recordType {
|
||
case 20:
|
||
recordTypeStr = "ChangeCipherSpec"
|
||
case 21:
|
||
recordTypeStr = "Alert"
|
||
case 22:
|
||
recordTypeStr = "Handshake"
|
||
case 23:
|
||
recordTypeStr = "ApplicationData"
|
||
default:
|
||
recordTypeStr = fmt.Sprintf("Unknown(%d)", recordType)
|
||
}
|
||
|
||
var versionStr string
|
||
switch version {
|
||
case 0x0301:
|
||
versionStr = "TLS 1.0"
|
||
case 0x0302:
|
||
versionStr = "TLS 1.1"
|
||
case 0x0303:
|
||
versionStr = "TLS 1.2"
|
||
case 0x0304:
|
||
versionStr = "TLS 1.3"
|
||
default:
|
||
versionStr = fmt.Sprintf("Unknown(0x%04x)", version)
|
||
}
|
||
|
||
log.Printf("%s: 类型=%s, 版本=%s, 长度=%d, 总字节数=%d",
|
||
prefix, recordTypeStr, versionStr, length, len(data))
|
||
|
||
if recordType == recordTypeHandshake && len(data) >= 6 {
|
||
handshakeType := data[5]
|
||
var handshakeTypeStr string
|
||
switch handshakeType {
|
||
case 1:
|
||
handshakeTypeStr = "ClientHello"
|
||
case 2:
|
||
handshakeTypeStr = "ServerHello"
|
||
case 11:
|
||
handshakeTypeStr = "Certificate"
|
||
case 16:
|
||
handshakeTypeStr = "ClientKeyExchange"
|
||
default:
|
||
handshakeTypeStr = fmt.Sprintf("Unknown(%d)", handshakeType)
|
||
}
|
||
log.Printf("%s: 握手类型=%s", prefix, handshakeTypeStr)
|
||
}
|
||
}
|
||
|
||
// 验证TLS记录的完整性
|
||
func validateTLSRecord(data []byte) (bool, string) {
|
||
if len(data) < 5 {
|
||
return false, "数据太短,不是有效的TLS记录"
|
||
}
|
||
|
||
recordType := data[0]
|
||
length := (uint16(data[3]) << 8) | uint16(data[4])
|
||
|
||
// 检查记录类型是否有效
|
||
validTypes := map[byte]bool{
|
||
20: true, // ChangeCipherSpec
|
||
21: true, // Alert
|
||
22: true, // Handshake
|
||
23: true, // ApplicationData
|
||
}
|
||
if !validTypes[recordType] {
|
||
return false, fmt.Sprintf("无效的TLS记录类型: %d", recordType)
|
||
}
|
||
|
||
// 检查记录长度是否与实际数据长度匹配
|
||
if int(length) != len(data)-5 {
|
||
return false, fmt.Sprintf("TLS记录长度不匹配: 头部指示 %d 字节,实际数据 %d 字节", length, len(data)-5)
|
||
}
|
||
|
||
// 对于握手记录,进行更详细的验证
|
||
if recordType == recordTypeHandshake && len(data) >= 9 {
|
||
// 握手消息至少需要4字节的头部
|
||
handshakeType := data[5]
|
||
handshakeLength := (uint32(data[6]) << 16) | (uint32(data[7]) << 8) | uint32(data[8])
|
||
|
||
// 检查握手消息长度是否与记录长度匹配
|
||
if int(handshakeLength)+4 > int(length) {
|
||
return false, fmt.Sprintf("握手消息长度不匹配: 握手头部指示 %d 字节,记录头部指示 %d 字节", handshakeLength, length)
|
||
}
|
||
|
||
// 对于ClientHello,进行更详细的验证
|
||
if handshakeType == handshakeTypeClientHello && len(data) >= 43 {
|
||
// ClientHello至少需要包含协议版本(2)、随机数(32)和会话ID长度(1)
|
||
return true, "有效的ClientHello记录"
|
||
}
|
||
|
||
return true, fmt.Sprintf("有效的握手记录,类型: %d", handshakeType)
|
||
}
|
||
|
||
return true, fmt.Sprintf("有效的TLS记录,类型: %d", recordType)
|
||
}
|
||
|
||
// 验证握手消息的完整性
|
||
func validateHandshakeMessage(data []byte) (bool, string) {
|
||
if len(data) < 4 {
|
||
return false, "数据太短,不是有效的握手消息"
|
||
}
|
||
|
||
handshakeType := data[0]
|
||
handshakeLength := (uint32(data[1]) << 16) | (uint32(data[2]) << 8) | uint32(data[3])
|
||
|
||
// 检查握手消息长度是否与实际数据长度匹配
|
||
if int(handshakeLength)+4 != len(data) {
|
||
return false, fmt.Sprintf("握手消息长度不匹配: 头部指示 %d 字节,实际数据 %d 字节", handshakeLength, len(data)-4)
|
||
}
|
||
|
||
// 对于ClientHello,进行更详细的验证
|
||
if handshakeType == handshakeTypeClientHello {
|
||
if len(data) < 38 {
|
||
return false, "ClientHello消息太短"
|
||
}
|
||
|
||
// 检查协议版本、随机数等字段
|
||
// 这里只是简单检查长度,实际应用中可以更详细地验证
|
||
return true, "有效的ClientHello消息"
|
||
}
|
||
|
||
return true, fmt.Sprintf("有效的握手消息,类型: %d", handshakeType)
|
||
}
|