package proxy import ( "context" "crypto/rand" "fmt" "io" "log" "math/big" "net" "strconv" "strings" "sync" "sync/atomic" "time" "github.com/SNI_Proxy/config" "github.com/SNI_Proxy/dns" ) // 服务器结构 type Server struct { config *config.Config listener net.Listener wg sync.WaitGroup shutdownCh chan struct{} resolver dns.Resolver connMutex sync.Mutex conns map[net.Conn]struct{} // 活跃连接跟踪 connCounter int64 // 当前连接计数器 } // 创建新的代理服务器 func NewServer(cfg *config.Config) *Server { return &Server{ config: cfg, shutdownCh: make(chan struct{}), conns: make(map[net.Conn]struct{}), } } // 启动代理服务器 func (s *Server) Start() error { var err error // 创建DNS解析器 s.resolver, err = dns.NewResolver(s.config.DNS) if err != nil { return fmt.Errorf("创建DNS解析器失败: %w", err) } s.listener, err = net.Listen("tcp", s.config.Listen) if err != nil { return fmt.Errorf("监听地址失败: %w", err) } s.wg.Add(1) go s.acceptLoop() // 启动连接监控 go s.monitorConnections() // 启动连接清理 go s.cleanIdleConnections() return nil } // 停止代理服务器 func (s *Server) Stop() { log.Println("正在关闭服务器...") close(s.shutdownCh) if s.listener != nil { s.listener.Close() } // 关闭所有活跃连接 s.closeAllConnections() if s.resolver != nil { s.resolver.Close() } // 等待所有goroutine结束 log.Println("等待所有连接关闭...") s.wg.Wait() log.Println("服务器已关闭") } // 关闭所有活跃连接 func (s *Server) closeAllConnections() { s.connMutex.Lock() defer s.connMutex.Unlock() log.Printf("关闭 %d 个活跃连接", len(s.conns)) for conn := range s.conns { conn.Close() delete(s.conns, conn) } } // 添加连接到跟踪列表 func (s *Server) trackConn(conn net.Conn) bool { s.connMutex.Lock() defer s.connMutex.Unlock() // 检查是否超过最大连接数 if s.config.MaxConns > 0 && len(s.conns) >= s.config.MaxConns { log.Printf("达到最大连接数 %d,拒绝新连接", s.config.MaxConns) return false } s.conns[conn] = struct{}{} atomic.AddInt64(&s.connCounter, 1) return true } // 从跟踪列表中移除连接 func (s *Server) untrackConn(conn net.Conn) { s.connMutex.Lock() defer s.connMutex.Unlock() if _, exists := s.conns[conn]; exists { delete(s.conns, conn) atomic.AddInt64(&s.connCounter, -1) } } // 监控连接状态 func (s *Server) monitorConnections() { ticker := time.NewTicker(30 * time.Second) defer ticker.Stop() for { select { case <-s.shutdownCh: return case <-ticker.C: // 记录当前活跃连接数 s.connMutex.Lock() activeConns := len(s.conns) s.connMutex.Unlock() totalConns := atomic.LoadInt64(&s.connCounter) log.Printf("连接统计 - 当前活跃: %d, 总计处理: %d", activeConns, totalConns) } } } // 清理空闲连接 func (s *Server) cleanIdleConnections() { // 每分钟检查一次空闲连接 ticker := time.NewTicker(1 * time.Minute) defer ticker.Stop() for { select { case <-s.shutdownCh: return case <-ticker.C: s.connMutex.Lock() log.Printf("开始清理空闲连接,当前连接数: %d", len(s.conns)) // 这里我们不做实际清理,因为连接已经设置了超时 // 实际清理由各个连接自己的超时机制处理 s.connMutex.Unlock() } } } // 接受连接循环 func (s *Server) acceptLoop() { defer s.wg.Done() for { conn, err := s.listener.Accept() if err != nil { select { case <-s.shutdownCh: return default: log.Printf("接受连接失败: %v", err) // 短暂休眠,避免CPU占用过高 time.Sleep(100 * time.Millisecond) continue } } // 检查是否可以接受新连接 if !s.trackConn(conn) { conn.Close() continue } s.wg.Add(1) go func(clientConn net.Conn) { defer s.wg.Done() defer s.untrackConn(clientConn) // 移除连接跟踪 defer clientConn.Close() // 设置连接生命周期 if s.config.Timeout.LifeTime > 0 { deadline := time.Now().Add(time.Duration(s.config.Timeout.LifeTime) * time.Second) clientConn.SetDeadline(deadline) } // 创建一个带超时的上下文 ctx, cancel := context.WithTimeout(context.Background(), time.Duration(s.config.Timeout.Idle)*time.Second) defer cancel() // 使用goroutine和channel处理连接,支持超时控制 errCh := make(chan error, 1) go func() { errCh <- s.handleConnection(clientConn) }() // 等待处理完成或超时 select { case err := <-errCh: if err != nil { if s.config.LogLevel == "debug" || (s.config.LogLevel != "error" && !isCommonNetworkError(err)) { log.Printf("处理连接失败: %v", err) } } case <-ctx.Done(): log.Printf("处理连接超时") case <-s.shutdownCh: log.Printf("服务器关闭,终止连接处理") } }(conn) } } // 判断是否是常见网络错误(可以不记录日志的错误) func isCommonNetworkError(err error) bool { if err == nil { return false } errStr := err.Error() // 检查常见的网络错误 commonErrors := []string{ "connection reset by peer", "broken pipe", "i/o timeout", "use of closed network connection", "EOF", } for _, e := range commonErrors { if strings.Contains(strings.ToLower(errStr), e) { return true } } return false } // 处理单个连接 func (s *Server) handleConnection(clientConn net.Conn) error { // 设置读取超时 readTimeout := time.Duration(s.config.Timeout.Read) * time.Second if err := clientConn.SetReadDeadline(time.Now().Add(readTimeout)); err != nil { return fmt.Errorf("设置读取超时失败: %w", err) } // 读取并解析SNI信息 sni, clientHello, err := readSNI(clientConn) if err != nil { return fmt.Errorf("读取SNI失败: %w", err) } // 重置读取超时 if err := clientConn.SetReadDeadline(time.Time{}); err != nil { return fmt.Errorf("重置读取超时失败: %w", err) } // 查找匹配的规则 var targetRule *config.ProxyRule for i := range s.config.Rules { rule := &s.config.Rules[i] for j := range rule.Domains { if rule.Domains[j].Match(sni) { targetRule = rule break } } if targetRule != nil { break } } if targetRule == nil { // 如果没有匹配的规则,使用默认端口 if s.config.LogLevel != "error" { log.Printf("未找到匹配的规则,使用默认端口 %d 连接到 %s", s.config.DefaultPort, sni) } return s.proxyConnection(clientConn, sni, s.config.DefaultPort, clientHello, config.FragmentConfig{Enabled: false}) } // 使用匹配规则中的端口和分片配置 return s.proxyConnection(clientConn, sni, targetRule.Port, clientHello, targetRule.Fragment) } // 代理连接到目标服务器 func (s *Server) proxyConnection(clientConn net.Conn, targetHost string, targetPort int, clientHello []byte, fragConfig config.FragmentConfig) error { // 设置DNS解析超时 ctx, cancel := context.WithTimeout(context.Background(), time.Duration(s.config.DNS.Timeout)*time.Second) defer cancel() // 创建一个带超时的解析任务 type resolveResult struct { ips []net.IP err error } resCh := make(chan resolveResult, 1) go func() { ips, err := s.resolver.Resolve(targetHost) resCh <- resolveResult{ips, err} }() // 等待DNS解析完成或超时 var ips []net.IP select { case res := <-resCh: if res.err != nil { return fmt.Errorf("解析域名失败: %w", res.err) } ips = res.ips case <-ctx.Done(): return fmt.Errorf("解析域名超时") } // 选择第一个IP地址 targetIP := ips[0] if s.config.LogLevel == "debug" { log.Printf("域名 %s 解析为 %s", targetHost, targetIP.String()) } // 构建目标地址 targetAddr := net.JoinHostPort(targetIP.String(), strconv.Itoa(targetPort)) if s.config.LogLevel != "error" { log.Printf("代理连接到: %s (原始SNI: %s)", targetAddr, targetHost) } // 设置连接超时 dialCtx, dialCancel := context.WithTimeout(context.Background(), time.Duration(s.config.Timeout.Connect)*time.Second) defer dialCancel() // 连接到目标服务器 var d net.Dialer targetConn, err := d.DialContext(dialCtx, "tcp", targetAddr) if err != nil { return fmt.Errorf("连接目标服务器失败: %w", err) } defer targetConn.Close() // 设置目标连接的读写超时 if tcpConn, ok := targetConn.(*net.TCPConn); ok { tcpConn.SetKeepAlive(true) tcpConn.SetKeepAlivePeriod(30 * time.Second) } // 发送ClientHello到目标服务器,可能需要进行TLS分片 if err := s.sendClientHello(targetConn, clientHello, fragConfig); err != nil { return fmt.Errorf("发送ClientHello失败: %w", err) } // 双向转发数据 errCh := make(chan error, 2) copyDone := make(chan struct{}, 2) // 客户端 -> 目标服务器 go func() { buf := make([]byte, 32*1024) // 使用较大的缓冲区 _, err := s.copyBuffer(targetConn, clientConn, buf) errCh <- err copyDone <- struct{}{} }() // 目标服务器 -> 客户端 go func() { buf := make([]byte, 32*1024) // 使用较大的缓冲区 _, err := s.copyBuffer(clientConn, targetConn, buf) errCh <- err copyDone <- struct{}{} }() // 等待任一方向的数据传输完成 <-copyDone // 设置一个短暂的超时,让另一个方向有机会完成 time.Sleep(100 * time.Millisecond) // 检查是否有错误 select { case err := <-errCh: if err != nil && err != io.EOF && !isCommonNetworkError(err) { return fmt.Errorf("数据转发失败: %w", err) } default: // 没有错误 } return nil } // 带超时控制的数据复制 func (s *Server) copyBuffer(dst net.Conn, src net.Conn, buf []byte) (int64, error) { var written int64 readTimeout := time.Duration(s.config.Timeout.Read) * time.Second writeTimeout := time.Duration(s.config.Timeout.Write) * time.Second for { // 设置读取超时 src.SetReadDeadline(time.Now().Add(readTimeout)) nr, err := src.Read(buf) if nr > 0 { // 设置写入超时 dst.SetWriteDeadline(time.Now().Add(writeTimeout)) nw, err := dst.Write(buf[0:nr]) if nw > 0 { written += int64(nw) } if err != nil { return written, err } if nr != nw { return written, io.ErrShortWrite } } if err != nil { if err == io.EOF { return written, nil } return written, err } } } // 发送ClientHello,根据配置进行TLS分片 func (s *Server) sendClientHello(conn net.Conn, clientHello []byte, fragConfig config.FragmentConfig) error { writeTimeout := time.Duration(s.config.Timeout.Write) * time.Second if s.config.LogLevel == "debug" { printTLSRecord(clientHello, "原始ClientHello") } if !fragConfig.Enabled { // 不进行分片,直接发送 if s.config.LogLevel == "debug" { log.Printf("TLS分片未启用,直接发送 %d 字节", len(clientHello)) } // 设置写入超时 conn.SetWriteDeadline(time.Now().Add(writeTimeout)) _, err := conn.Write(clientHello) return err } // 进行TLS分片 if s.config.LogLevel == "debug" { log.Printf("启用TLS分片,分片大小范围: %d-%d", fragConfig.MinSize, fragConfig.MaxSize) } // 检查ClientHello是否是有效的TLS记录 if fragConfig.Validate { valid, reason := validateTLSRecord(clientHello) if !valid { log.Printf("警告: ClientHello不是有效的TLS记录 (%s),跳过分片", reason) conn.SetWriteDeadline(time.Now().Add(writeTimeout)) _, err := conn.Write(clientHello) return err } } if clientHello[0] != recordTypeHandshake { log.Printf("警告: 记录类型不是握手类型,跳过分片") conn.SetWriteDeadline(time.Now().Add(writeTimeout)) _, err := conn.Write(clientHello) return err } // 获取TLS记录头部和数据部分 recordHeader := clientHello[:5] recordData := clientHello[5:] recordLength := len(recordData) if s.config.LogLevel == "debug" { log.Printf("TLS记录头部: %v, 数据长度: %d", recordHeader, recordLength) // 如果是握手消息,打印握手消息头部信息 if recordData[0] == handshakeTypeClientHello && len(recordData) >= 4 { handshakeType := recordData[0] handshakeLength := (uint32(recordData[1]) << 16) | (uint32(recordData[2]) << 8) | uint32(recordData[3]) log.Printf("握手消息头部: 类型=%d, 长度=%d", handshakeType, handshakeLength) } } // 如果记录太小,不进行分片 if recordLength <= fragConfig.MinSize { if s.config.LogLevel == "debug" { log.Printf("记录太小 (%d 字节),不进行分片", recordLength) } conn.SetWriteDeadline(time.Now().Add(writeTimeout)) _, err := conn.Write(clientHello) return err } // 第一个分片:发送记录头部和部分数据 minSize := fragConfig.MinSize if minSize <= 0 { minSize = 10 // 默认最小分片大小 } maxSize := fragConfig.MaxSize if maxSize <= 0 { maxSize = 100 // 默认最大分片大小 } if minSize > maxSize { minSize = maxSize } // 随机分片大小 size := minSize if maxSize > minSize { randVal, _ := rand.Int(rand.Reader, big.NewInt(int64(maxSize-minSize+1))) size = minSize + int(randVal.Int64()) } // 确保不超过记录数据大小 if size > recordLength { size = recordLength } // 对于握手消息,确保第一个分片至少包含完整的握手消息头部(4字节) if recordData[0] == handshakeTypeClientHello && size < 4 { size = 4 // 至少包含握手消息头部 } // 创建第一个分片 firstFragment := make([]byte, 5+size) copy(firstFragment[:5], recordHeader) copy(firstFragment[5:], recordData[:size]) // 修改第一个分片的长度字段 firstFragment[3] = byte(size >> 8) firstFragment[4] = byte(size) if s.config.LogLevel == "debug" { printTLSRecord(firstFragment, "第一个分片") log.Printf("第一个分片长度: %d (头部5字节 + 数据%d字节)", len(firstFragment), size) // 验证第一个分片的握手消息 if recordData[0] == handshakeTypeClientHello { // 创建一个临时的握手消息进行验证 tempHandshake := make([]byte, size) copy(tempHandshake, recordData[:size]) // 调整握手消息长度以匹配实际数据 handshakeLength := uint32(size - 4) // 减去握手消息头部的4字节 tempHandshake[1] = byte(handshakeLength >> 16) tempHandshake[2] = byte(handshakeLength >> 8) tempHandshake[3] = byte(handshakeLength) valid, reason := validateHandshakeMessage(tempHandshake) log.Printf("第一个分片握手消息验证: %v, %s", valid, reason) } } // 设置写入超时并发送第一个分片 conn.SetWriteDeadline(time.Now().Add(writeTimeout)) if _, err := conn.Write(firstFragment); err != nil { return fmt.Errorf("发送第一个分片失败: %w", err) } // 如果还有剩余数据,创建第二个分片 if size < recordLength { // 添加一个小延迟,模拟网络延迟,提高分片效果 delayMin := fragConfig.DelayMin if delayMin <= 0 { delayMin = 10 // 默认最小延迟10毫秒 } delayMax := fragConfig.DelayMax if delayMax <= 0 { delayMax = 30 // 默认最大延迟30毫秒 } if delayMin > delayMax { delayMin = delayMax } delayRange := big.NewInt(int64(delayMax - delayMin + 1)) randVal, _ := rand.Int(rand.Reader, delayRange) delay := time.Duration(delayMin+int(randVal.Int64())) * time.Millisecond time.Sleep(delay) if s.config.LogLevel == "debug" { log.Printf("延迟 %v 后发送第二个分片", delay) } // 创建第二个分片 remainingSize := recordLength - size secondFragment := make([]byte, 5+remainingSize) // 复制记录头部(保持与原始记录头部一致) copy(secondFragment[:5], recordHeader) // 设置正确的长度字段 secondFragment[3] = byte(remainingSize >> 8) secondFragment[4] = byte(remainingSize) // 复制剩余数据,确保握手类型保持一致 // 注意:我们不能直接复制数据,因为第二个分片的握手消息头部需要特殊处理 if recordData[0] == handshakeTypeClientHello { // 创建新的握手消息头部 // 1. 保持握手类型不变 // 2. 调整握手消息长度 handshakeLength := (uint32(recordData[1]) << 16) | (uint32(recordData[2]) << 8) | uint32(recordData[3]) // 计算第一个分片中已发送的握手消息数据长度(不包括握手消息头部的4字节) sentDataLength := uint32(size) - 4 // 计算剩余的握手消息长度 remainingHandshakeLength := handshakeLength - sentDataLength if s.config.LogLevel == "debug" { log.Printf("握手消息总长度: %d, 第一个分片已发送数据: %d, 第二个分片剩余数据: %d", handshakeLength, sentDataLength, remainingHandshakeLength) } // 复制握手类型 secondFragment[5] = recordData[0] // 设置新的握手消息长度 secondFragment[6] = byte(remainingHandshakeLength >> 16) secondFragment[7] = byte(remainingHandshakeLength >> 8) secondFragment[8] = byte(remainingHandshakeLength) // 复制剩余数据(跳过第一个分片已经发送的部分) copy(secondFragment[9:], recordData[size:]) } else { // 如果不是ClientHello,直接复制数据 copy(secondFragment[5:], recordData[size:]) } // 对于第二个分片,我们采用不同的策略 // 不再尝试修改握手消息头部,而是直接将剩余数据作为应用数据发送 // 这样可以避免TLS握手解析错误 // 修改第二个分片的记录类型为应用数据 secondFragment[0] = 23 // ApplicationData // 复制剩余数据 copy(secondFragment[5:], recordData[size:]) if s.config.LogLevel == "debug" { log.Printf("第二个分片使用应用数据类型,避免握手解析错误") printTLSRecord(secondFragment, "第二个分片") log.Printf("第二个分片长度: %d (头部5字节 + 数据%d字节)", len(secondFragment), remainingSize) } // 设置写入超时并发送第二个分片 conn.SetWriteDeadline(time.Now().Add(writeTimeout)) if _, err := conn.Write(secondFragment); err != nil { return fmt.Errorf("发送第二个分片失败: %w", err) } if s.config.LogLevel == "debug" { log.Printf("TLS分片完成: 总共发送 %d 字节 (第一分片: %d 字节, 第二分片: %d 字节)", len(firstFragment)+len(secondFragment), len(firstFragment), len(secondFragment)) } } else { if s.config.LogLevel == "debug" { log.Printf("TLS分片完成: 只需一个分片,总共发送 %d 字节", len(firstFragment)) } } return nil } // TLS记录类型 const ( recordTypeHandshake = 22 ) // TLS握手类型 const ( handshakeTypeClientHello = 1 ) // 读取SNI信息 func readSNI(conn net.Conn) (string, []byte, error) { // 读取TLS记录头部 header := make([]byte, 5) if _, err := io.ReadFull(conn, header); err != nil { return "", nil, fmt.Errorf("读取TLS记录头部失败: %w", err) } // 检查是否是TLS握手记录 if header[0] != recordTypeHandshake { return "", nil, fmt.Errorf("不是TLS握手记录,记录类型: %d", header[0]) } // 获取记录长度 recordLength := int(header[3])<<8 | int(header[4]) if recordLength > 16384 { return "", nil, fmt.Errorf("TLS记录太长: %d 字节", recordLength) } if recordLength < 4 { return "", nil, fmt.Errorf("TLS握手记录太短: %d 字节", recordLength) } // 读取握手消息 handshakeData := make([]byte, recordLength) if _, err := io.ReadFull(conn, handshakeData); err != nil { return "", nil, fmt.Errorf("读取握手消息失败: %w", err) } // 检查是否是ClientHello if handshakeData[0] != handshakeTypeClientHello { return "", nil, fmt.Errorf("不是ClientHello消息,握手类型: %d", handshakeData[0]) } // 完整的ClientHello消息(包括记录头部) clientHello := append(header, handshakeData...) // 验证TLS记录的完整性 valid, reason := validateTLSRecord(clientHello) if !valid { log.Printf("警告: 收到的ClientHello不是有效的TLS记录: %s", reason) } // 解析SNI扩展 sni, err := extractSNI(handshakeData) if err != nil { return "", nil, fmt.Errorf("提取SNI失败: %w", err) } if sni != "" && len(sni) > 0 { log.Printf("成功提取SNI: %s", sni) } return sni, clientHello, nil } // 从ClientHello中提取SNI func extractSNI(data []byte) (string, error) { if len(data) < 42 { return "", fmt.Errorf("ClientHello太短") } // 跳过握手类型(1) + 长度(3) + 协议版本(2) + 随机数(32) pos := 38 // 跳过会话ID if pos+1 > len(data) { return "", fmt.Errorf("解析会话ID长度时数据不足") } sessionIDLength := int(data[pos]) pos += 1 + sessionIDLength // 跳过密码套件 if pos+2 > len(data) { return "", fmt.Errorf("解析密码套件长度时数据不足") } cipherSuitesLength := int(data[pos])<<8 | int(data[pos+1]) pos += 2 + cipherSuitesLength // 跳过压缩方法 if pos+1 > len(data) { return "", fmt.Errorf("解析压缩方法长度时数据不足") } compressionMethodsLength := int(data[pos]) pos += 1 + compressionMethodsLength // 检查是否有扩展 if pos+2 > len(data) { return "", fmt.Errorf("没有扩展数据") } extensionsLength := int(data[pos])<<8 | int(data[pos+1]) pos += 2 // 解析扩展 extensionsEnd := pos + extensionsLength if extensionsEnd > len(data) { return "", fmt.Errorf("扩展数据不足") } for pos < extensionsEnd { // 读取扩展类型 if pos+4 > len(data) { return "", fmt.Errorf("解析扩展类型时数据不足") } extensionType := int(data[pos])<<8 | int(data[pos+1]) extensionLength := int(data[pos+2])<<8 | int(data[pos+3]) pos += 4 // 检查是否是SNI扩展 (类型为0) if extensionType == 0 { // 跳过SNI列表长度 if pos+2 > len(data) { return "", fmt.Errorf("解析SNI列表长度时数据不足") } pos += 2 // 检查SNI类型 if pos+1 > len(data) { return "", fmt.Errorf("解析SNI类型时数据不足") } if data[pos] != 0 { return "", fmt.Errorf("不支持的SNI类型") } pos++ // 读取主机名长度 if pos+2 > len(data) { return "", fmt.Errorf("解析主机名长度时数据不足") } hostnameLength := int(data[pos])<<8 | int(data[pos+1]) pos += 2 // 读取主机名 if pos+hostnameLength > len(data) { return "", fmt.Errorf("解析主机名时数据不足") } hostname := string(data[pos : pos+hostnameLength]) return hostname, nil } // 跳过当前扩展 pos += extensionLength } return "", fmt.Errorf("未找到SNI扩展") } // 打印TLS记录的详细信息(用于调试) func printTLSRecord(data []byte, prefix string) { if len(data) < 5 { log.Printf("%s: 数据太短,不是有效的TLS记录", prefix) return } recordType := data[0] version := (uint16(data[1]) << 8) | uint16(data[2]) length := (uint16(data[3]) << 8) | uint16(data[4]) var recordTypeStr string switch recordType { case 20: recordTypeStr = "ChangeCipherSpec" case 21: recordTypeStr = "Alert" case 22: recordTypeStr = "Handshake" case 23: recordTypeStr = "ApplicationData" default: recordTypeStr = fmt.Sprintf("Unknown(%d)", recordType) } var versionStr string switch version { case 0x0301: versionStr = "TLS 1.0" case 0x0302: versionStr = "TLS 1.1" case 0x0303: versionStr = "TLS 1.2" case 0x0304: versionStr = "TLS 1.3" default: versionStr = fmt.Sprintf("Unknown(0x%04x)", version) } log.Printf("%s: 类型=%s, 版本=%s, 长度=%d, 总字节数=%d", prefix, recordTypeStr, versionStr, length, len(data)) if recordType == recordTypeHandshake && len(data) >= 6 { handshakeType := data[5] var handshakeTypeStr string switch handshakeType { case 1: handshakeTypeStr = "ClientHello" case 2: handshakeTypeStr = "ServerHello" case 11: handshakeTypeStr = "Certificate" case 16: handshakeTypeStr = "ClientKeyExchange" default: handshakeTypeStr = fmt.Sprintf("Unknown(%d)", handshakeType) } log.Printf("%s: 握手类型=%s", prefix, handshakeTypeStr) } } // 验证TLS记录的完整性 func validateTLSRecord(data []byte) (bool, string) { if len(data) < 5 { return false, "数据太短,不是有效的TLS记录" } recordType := data[0] length := (uint16(data[3]) << 8) | uint16(data[4]) // 检查记录类型是否有效 validTypes := map[byte]bool{ 20: true, // ChangeCipherSpec 21: true, // Alert 22: true, // Handshake 23: true, // ApplicationData } if !validTypes[recordType] { return false, fmt.Sprintf("无效的TLS记录类型: %d", recordType) } // 检查记录长度是否与实际数据长度匹配 if int(length) != len(data)-5 { return false, fmt.Sprintf("TLS记录长度不匹配: 头部指示 %d 字节,实际数据 %d 字节", length, len(data)-5) } // 对于握手记录,进行更详细的验证 if recordType == recordTypeHandshake && len(data) >= 9 { // 握手消息至少需要4字节的头部 handshakeType := data[5] handshakeLength := (uint32(data[6]) << 16) | (uint32(data[7]) << 8) | uint32(data[8]) // 检查握手消息长度是否与记录长度匹配 if int(handshakeLength)+4 > int(length) { return false, fmt.Sprintf("握手消息长度不匹配: 握手头部指示 %d 字节,记录头部指示 %d 字节", handshakeLength, length) } // 对于ClientHello,进行更详细的验证 if handshakeType == handshakeTypeClientHello && len(data) >= 43 { // ClientHello至少需要包含协议版本(2)、随机数(32)和会话ID长度(1) return true, "有效的ClientHello记录" } return true, fmt.Sprintf("有效的握手记录,类型: %d", handshakeType) } return true, fmt.Sprintf("有效的TLS记录,类型: %d", recordType) } // 验证握手消息的完整性 func validateHandshakeMessage(data []byte) (bool, string) { if len(data) < 4 { return false, "数据太短,不是有效的握手消息" } handshakeType := data[0] handshakeLength := (uint32(data[1]) << 16) | (uint32(data[2]) << 8) | uint32(data[3]) // 检查握手消息长度是否与实际数据长度匹配 if int(handshakeLength)+4 != len(data) { return false, fmt.Sprintf("握手消息长度不匹配: 头部指示 %d 字节,实际数据 %d 字节", handshakeLength, len(data)-4) } // 对于ClientHello,进行更详细的验证 if handshakeType == handshakeTypeClientHello { if len(data) < 38 { return false, "ClientHello消息太短" } // 检查协议版本、随机数等字段 // 这里只是简单检查长度,实际应用中可以更详细地验证 return true, "有效的ClientHello消息" } return true, fmt.Sprintf("有效的握手消息,类型: %d", handshakeType) }