SNI_proxy/proxy/server.go

979 lines
26 KiB
Go
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

package proxy
import (
"context"
"crypto/rand"
"fmt"
"io"
"log"
"math/big"
"net"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/SNI_Proxy/config"
"github.com/SNI_Proxy/dns"
)
// 服务器结构
type Server struct {
config *config.Config
listener net.Listener
wg sync.WaitGroup
shutdownCh chan struct{}
resolver dns.Resolver
connMutex sync.Mutex
conns map[net.Conn]struct{} // 活跃连接跟踪
connCounter int64 // 当前连接计数器
}
// 创建新的代理服务器
func NewServer(cfg *config.Config) *Server {
return &Server{
config: cfg,
shutdownCh: make(chan struct{}),
conns: make(map[net.Conn]struct{}),
}
}
// 启动代理服务器
func (s *Server) Start() error {
var err error
// 创建DNS解析器
s.resolver, err = dns.NewResolver(s.config.DNS)
if err != nil {
return fmt.Errorf("创建DNS解析器失败: %w", err)
}
s.listener, err = net.Listen("tcp", s.config.Listen)
if err != nil {
return fmt.Errorf("监听地址失败: %w", err)
}
s.wg.Add(1)
go s.acceptLoop()
// 启动连接监控
go s.monitorConnections()
// 启动连接清理
go s.cleanIdleConnections()
return nil
}
// 停止代理服务器
func (s *Server) Stop() {
log.Println("正在关闭服务器...")
close(s.shutdownCh)
if s.listener != nil {
s.listener.Close()
}
// 关闭所有活跃连接
s.closeAllConnections()
if s.resolver != nil {
s.resolver.Close()
}
// 等待所有goroutine结束
log.Println("等待所有连接关闭...")
s.wg.Wait()
log.Println("服务器已关闭")
}
// 关闭所有活跃连接
func (s *Server) closeAllConnections() {
s.connMutex.Lock()
defer s.connMutex.Unlock()
log.Printf("关闭 %d 个活跃连接", len(s.conns))
for conn := range s.conns {
conn.Close()
delete(s.conns, conn)
}
}
// 添加连接到跟踪列表
func (s *Server) trackConn(conn net.Conn) bool {
s.connMutex.Lock()
defer s.connMutex.Unlock()
// 检查是否超过最大连接数
if s.config.MaxConns > 0 && len(s.conns) >= s.config.MaxConns {
log.Printf("达到最大连接数 %d拒绝新连接", s.config.MaxConns)
return false
}
s.conns[conn] = struct{}{}
atomic.AddInt64(&s.connCounter, 1)
return true
}
// 从跟踪列表中移除连接
func (s *Server) untrackConn(conn net.Conn) {
s.connMutex.Lock()
defer s.connMutex.Unlock()
if _, exists := s.conns[conn]; exists {
delete(s.conns, conn)
atomic.AddInt64(&s.connCounter, -1)
}
}
// 监控连接状态
func (s *Server) monitorConnections() {
ticker := time.NewTicker(30 * time.Second)
defer ticker.Stop()
for {
select {
case <-s.shutdownCh:
return
case <-ticker.C:
// 记录当前活跃连接数
s.connMutex.Lock()
activeConns := len(s.conns)
s.connMutex.Unlock()
totalConns := atomic.LoadInt64(&s.connCounter)
log.Printf("连接统计 - 当前活跃: %d, 总计处理: %d", activeConns, totalConns)
}
}
}
// 清理空闲连接
func (s *Server) cleanIdleConnections() {
// 每分钟检查一次空闲连接
ticker := time.NewTicker(1 * time.Minute)
defer ticker.Stop()
for {
select {
case <-s.shutdownCh:
return
case <-ticker.C:
s.connMutex.Lock()
log.Printf("开始清理空闲连接,当前连接数: %d", len(s.conns))
// 这里我们不做实际清理,因为连接已经设置了超时
// 实际清理由各个连接自己的超时机制处理
s.connMutex.Unlock()
}
}
}
// 接受连接循环
func (s *Server) acceptLoop() {
defer s.wg.Done()
for {
conn, err := s.listener.Accept()
if err != nil {
select {
case <-s.shutdownCh:
return
default:
log.Printf("接受连接失败: %v", err)
// 短暂休眠避免CPU占用过高
time.Sleep(100 * time.Millisecond)
continue
}
}
// 检查是否可以接受新连接
if !s.trackConn(conn) {
conn.Close()
continue
}
s.wg.Add(1)
go func(clientConn net.Conn) {
defer s.wg.Done()
defer s.untrackConn(clientConn) // 移除连接跟踪
defer clientConn.Close()
// 设置连接生命周期
if s.config.Timeout.LifeTime > 0 {
deadline := time.Now().Add(time.Duration(s.config.Timeout.LifeTime) * time.Second)
clientConn.SetDeadline(deadline)
}
// 创建一个带超时的上下文
ctx, cancel := context.WithTimeout(context.Background(),
time.Duration(s.config.Timeout.Idle)*time.Second)
defer cancel()
// 使用goroutine和channel处理连接支持超时控制
errCh := make(chan error, 1)
go func() {
errCh <- s.handleConnection(clientConn)
}()
// 等待处理完成或超时
select {
case err := <-errCh:
if err != nil {
if s.config.LogLevel == "debug" ||
(s.config.LogLevel != "error" && !isCommonNetworkError(err)) {
log.Printf("处理连接失败: %v", err)
}
}
case <-ctx.Done():
log.Printf("处理连接超时")
case <-s.shutdownCh:
log.Printf("服务器关闭,终止连接处理")
}
}(conn)
}
}
// 判断是否是常见网络错误(可以不记录日志的错误)
func isCommonNetworkError(err error) bool {
if err == nil {
return false
}
errStr := err.Error()
// 检查常见的网络错误
commonErrors := []string{
"connection reset by peer",
"broken pipe",
"i/o timeout",
"use of closed network connection",
"EOF",
}
for _, e := range commonErrors {
if strings.Contains(strings.ToLower(errStr), e) {
return true
}
}
return false
}
// 处理单个连接
func (s *Server) handleConnection(clientConn net.Conn) error {
// 设置读取超时
readTimeout := time.Duration(s.config.Timeout.Read) * time.Second
if err := clientConn.SetReadDeadline(time.Now().Add(readTimeout)); err != nil {
return fmt.Errorf("设置读取超时失败: %w", err)
}
// 读取并解析SNI信息
sni, clientHello, err := readSNI(clientConn)
if err != nil {
return fmt.Errorf("读取SNI失败: %w", err)
}
// 重置读取超时
if err := clientConn.SetReadDeadline(time.Time{}); err != nil {
return fmt.Errorf("重置读取超时失败: %w", err)
}
// 查找匹配的规则
var targetRule *config.ProxyRule
for i := range s.config.Rules {
rule := &s.config.Rules[i]
for j := range rule.Domains {
if rule.Domains[j].Match(sni) {
targetRule = rule
break
}
}
if targetRule != nil {
break
}
}
if targetRule == nil {
// 如果没有匹配的规则,使用默认端口
if s.config.LogLevel != "error" {
log.Printf("未找到匹配的规则,使用默认端口 %d 连接到 %s", s.config.DefaultPort, sni)
}
return s.proxyConnection(clientConn, sni, s.config.DefaultPort, clientHello, config.FragmentConfig{Enabled: false})
}
// 使用匹配规则中的端口和分片配置
return s.proxyConnection(clientConn, sni, targetRule.Port, clientHello, targetRule.Fragment)
}
// 代理连接到目标服务器
func (s *Server) proxyConnection(clientConn net.Conn, targetHost string, targetPort int, clientHello []byte, fragConfig config.FragmentConfig) error {
// 设置DNS解析超时
ctx, cancel := context.WithTimeout(context.Background(), time.Duration(s.config.DNS.Timeout)*time.Second)
defer cancel()
// 创建一个带超时的解析任务
type resolveResult struct {
ips []net.IP
err error
}
resCh := make(chan resolveResult, 1)
go func() {
ips, err := s.resolver.Resolve(targetHost)
resCh <- resolveResult{ips, err}
}()
// 等待DNS解析完成或超时
var ips []net.IP
select {
case res := <-resCh:
if res.err != nil {
return fmt.Errorf("解析域名失败: %w", res.err)
}
ips = res.ips
case <-ctx.Done():
return fmt.Errorf("解析域名超时")
}
// 选择第一个IP地址
targetIP := ips[0]
if s.config.LogLevel == "debug" {
log.Printf("域名 %s 解析为 %s", targetHost, targetIP.String())
}
// 构建目标地址
targetAddr := net.JoinHostPort(targetIP.String(), strconv.Itoa(targetPort))
if s.config.LogLevel != "error" {
log.Printf("代理连接到: %s (原始SNI: %s)", targetAddr, targetHost)
}
// 设置连接超时
dialCtx, dialCancel := context.WithTimeout(context.Background(),
time.Duration(s.config.Timeout.Connect)*time.Second)
defer dialCancel()
// 连接到目标服务器
var d net.Dialer
targetConn, err := d.DialContext(dialCtx, "tcp", targetAddr)
if err != nil {
return fmt.Errorf("连接目标服务器失败: %w", err)
}
defer targetConn.Close()
// 设置目标连接的读写超时
if tcpConn, ok := targetConn.(*net.TCPConn); ok {
tcpConn.SetKeepAlive(true)
tcpConn.SetKeepAlivePeriod(30 * time.Second)
}
// 发送ClientHello到目标服务器可能需要进行TLS分片
if err := s.sendClientHello(targetConn, clientHello, fragConfig); err != nil {
return fmt.Errorf("发送ClientHello失败: %w", err)
}
// 双向转发数据
errCh := make(chan error, 2)
copyDone := make(chan struct{}, 2)
// 客户端 -> 目标服务器
go func() {
buf := make([]byte, 32*1024) // 使用较大的缓冲区
_, err := s.copyBuffer(targetConn, clientConn, buf)
errCh <- err
copyDone <- struct{}{}
}()
// 目标服务器 -> 客户端
go func() {
buf := make([]byte, 32*1024) // 使用较大的缓冲区
_, err := s.copyBuffer(clientConn, targetConn, buf)
errCh <- err
copyDone <- struct{}{}
}()
// 等待任一方向的数据传输完成
<-copyDone
// 设置一个短暂的超时,让另一个方向有机会完成
time.Sleep(100 * time.Millisecond)
// 检查是否有错误
select {
case err := <-errCh:
if err != nil && err != io.EOF && !isCommonNetworkError(err) {
return fmt.Errorf("数据转发失败: %w", err)
}
default:
// 没有错误
}
return nil
}
// 带超时控制的数据复制
func (s *Server) copyBuffer(dst net.Conn, src net.Conn, buf []byte) (int64, error) {
var written int64
readTimeout := time.Duration(s.config.Timeout.Read) * time.Second
writeTimeout := time.Duration(s.config.Timeout.Write) * time.Second
for {
// 设置读取超时
src.SetReadDeadline(time.Now().Add(readTimeout))
nr, err := src.Read(buf)
if nr > 0 {
// 设置写入超时
dst.SetWriteDeadline(time.Now().Add(writeTimeout))
nw, err := dst.Write(buf[0:nr])
if nw > 0 {
written += int64(nw)
}
if err != nil {
return written, err
}
if nr != nw {
return written, io.ErrShortWrite
}
}
if err != nil {
if err == io.EOF {
return written, nil
}
return written, err
}
}
}
// 发送ClientHello根据配置进行TLS分片
func (s *Server) sendClientHello(conn net.Conn, clientHello []byte, fragConfig config.FragmentConfig) error {
writeTimeout := time.Duration(s.config.Timeout.Write) * time.Second
if s.config.LogLevel == "debug" {
printTLSRecord(clientHello, "原始ClientHello")
}
if !fragConfig.Enabled {
// 不进行分片,直接发送
if s.config.LogLevel == "debug" {
log.Printf("TLS分片未启用直接发送 %d 字节", len(clientHello))
}
// 设置写入超时
conn.SetWriteDeadline(time.Now().Add(writeTimeout))
_, err := conn.Write(clientHello)
return err
}
// 进行TLS分片
if s.config.LogLevel == "debug" {
log.Printf("启用TLS分片分片大小范围: %d-%d", fragConfig.MinSize, fragConfig.MaxSize)
}
// 检查ClientHello是否是有效的TLS记录
if fragConfig.Validate {
valid, reason := validateTLSRecord(clientHello)
if !valid {
log.Printf("警告: ClientHello不是有效的TLS记录 (%s),跳过分片", reason)
conn.SetWriteDeadline(time.Now().Add(writeTimeout))
_, err := conn.Write(clientHello)
return err
}
}
if clientHello[0] != recordTypeHandshake {
log.Printf("警告: 记录类型不是握手类型,跳过分片")
conn.SetWriteDeadline(time.Now().Add(writeTimeout))
_, err := conn.Write(clientHello)
return err
}
// 获取TLS记录头部和数据部分
recordHeader := clientHello[:5]
recordData := clientHello[5:]
recordLength := len(recordData)
if s.config.LogLevel == "debug" {
log.Printf("TLS记录头部: %v, 数据长度: %d", recordHeader, recordLength)
// 如果是握手消息,打印握手消息头部信息
if recordData[0] == handshakeTypeClientHello && len(recordData) >= 4 {
handshakeType := recordData[0]
handshakeLength := (uint32(recordData[1]) << 16) | (uint32(recordData[2]) << 8) | uint32(recordData[3])
log.Printf("握手消息头部: 类型=%d, 长度=%d", handshakeType, handshakeLength)
}
}
// 如果记录太小,不进行分片
if recordLength <= fragConfig.MinSize {
if s.config.LogLevel == "debug" {
log.Printf("记录太小 (%d 字节),不进行分片", recordLength)
}
conn.SetWriteDeadline(time.Now().Add(writeTimeout))
_, err := conn.Write(clientHello)
return err
}
// 第一个分片:发送记录头部和部分数据
minSize := fragConfig.MinSize
if minSize <= 0 {
minSize = 10 // 默认最小分片大小
}
maxSize := fragConfig.MaxSize
if maxSize <= 0 {
maxSize = 100 // 默认最大分片大小
}
if minSize > maxSize {
minSize = maxSize
}
// 随机分片大小
size := minSize
if maxSize > minSize {
randVal, _ := rand.Int(rand.Reader, big.NewInt(int64(maxSize-minSize+1)))
size = minSize + int(randVal.Int64())
}
// 确保不超过记录数据大小
if size > recordLength {
size = recordLength
}
// 对于握手消息确保第一个分片至少包含完整的握手消息头部4字节
if recordData[0] == handshakeTypeClientHello && size < 4 {
size = 4 // 至少包含握手消息头部
}
// 创建第一个分片
firstFragment := make([]byte, 5+size)
copy(firstFragment[:5], recordHeader)
copy(firstFragment[5:], recordData[:size])
// 修改第一个分片的长度字段
firstFragment[3] = byte(size >> 8)
firstFragment[4] = byte(size)
if s.config.LogLevel == "debug" {
printTLSRecord(firstFragment, "第一个分片")
log.Printf("第一个分片长度: %d (头部5字节 + 数据%d字节)", len(firstFragment), size)
// 验证第一个分片的握手消息
if recordData[0] == handshakeTypeClientHello {
// 创建一个临时的握手消息进行验证
tempHandshake := make([]byte, size)
copy(tempHandshake, recordData[:size])
// 调整握手消息长度以匹配实际数据
handshakeLength := uint32(size - 4) // 减去握手消息头部的4字节
tempHandshake[1] = byte(handshakeLength >> 16)
tempHandshake[2] = byte(handshakeLength >> 8)
tempHandshake[3] = byte(handshakeLength)
valid, reason := validateHandshakeMessage(tempHandshake)
log.Printf("第一个分片握手消息验证: %v, %s", valid, reason)
}
}
// 设置写入超时并发送第一个分片
conn.SetWriteDeadline(time.Now().Add(writeTimeout))
if _, err := conn.Write(firstFragment); err != nil {
return fmt.Errorf("发送第一个分片失败: %w", err)
}
// 如果还有剩余数据,创建第二个分片
if size < recordLength {
// 添加一个小延迟,模拟网络延迟,提高分片效果
delayMin := fragConfig.DelayMin
if delayMin <= 0 {
delayMin = 10 // 默认最小延迟10毫秒
}
delayMax := fragConfig.DelayMax
if delayMax <= 0 {
delayMax = 30 // 默认最大延迟30毫秒
}
if delayMin > delayMax {
delayMin = delayMax
}
delayRange := big.NewInt(int64(delayMax - delayMin + 1))
randVal, _ := rand.Int(rand.Reader, delayRange)
delay := time.Duration(delayMin+int(randVal.Int64())) * time.Millisecond
time.Sleep(delay)
if s.config.LogLevel == "debug" {
log.Printf("延迟 %v 后发送第二个分片", delay)
}
// 创建第二个分片
remainingSize := recordLength - size
secondFragment := make([]byte, 5+remainingSize)
// 复制记录头部(保持与原始记录头部一致)
copy(secondFragment[:5], recordHeader)
// 设置正确的长度字段
secondFragment[3] = byte(remainingSize >> 8)
secondFragment[4] = byte(remainingSize)
// 复制剩余数据,确保握手类型保持一致
// 注意:我们不能直接复制数据,因为第二个分片的握手消息头部需要特殊处理
if recordData[0] == handshakeTypeClientHello {
// 创建新的握手消息头部
// 1. 保持握手类型不变
// 2. 调整握手消息长度
handshakeLength := (uint32(recordData[1]) << 16) | (uint32(recordData[2]) << 8) | uint32(recordData[3])
// 计算第一个分片中已发送的握手消息数据长度不包括握手消息头部的4字节
sentDataLength := uint32(size) - 4
// 计算剩余的握手消息长度
remainingHandshakeLength := handshakeLength - sentDataLength
if s.config.LogLevel == "debug" {
log.Printf("握手消息总长度: %d, 第一个分片已发送数据: %d, 第二个分片剩余数据: %d",
handshakeLength, sentDataLength, remainingHandshakeLength)
}
// 复制握手类型
secondFragment[5] = recordData[0]
// 设置新的握手消息长度
secondFragment[6] = byte(remainingHandshakeLength >> 16)
secondFragment[7] = byte(remainingHandshakeLength >> 8)
secondFragment[8] = byte(remainingHandshakeLength)
// 复制剩余数据(跳过第一个分片已经发送的部分)
copy(secondFragment[9:], recordData[size:])
} else {
// 如果不是ClientHello直接复制数据
copy(secondFragment[5:], recordData[size:])
}
// 对于第二个分片,我们采用不同的策略
// 不再尝试修改握手消息头部,而是直接将剩余数据作为应用数据发送
// 这样可以避免TLS握手解析错误
// 修改第二个分片的记录类型为应用数据
secondFragment[0] = 23 // ApplicationData
// 复制剩余数据
copy(secondFragment[5:], recordData[size:])
if s.config.LogLevel == "debug" {
log.Printf("第二个分片使用应用数据类型,避免握手解析错误")
printTLSRecord(secondFragment, "第二个分片")
log.Printf("第二个分片长度: %d (头部5字节 + 数据%d字节)", len(secondFragment), remainingSize)
}
// 设置写入超时并发送第二个分片
conn.SetWriteDeadline(time.Now().Add(writeTimeout))
if _, err := conn.Write(secondFragment); err != nil {
return fmt.Errorf("发送第二个分片失败: %w", err)
}
if s.config.LogLevel == "debug" {
log.Printf("TLS分片完成: 总共发送 %d 字节 (第一分片: %d 字节, 第二分片: %d 字节)",
len(firstFragment)+len(secondFragment), len(firstFragment), len(secondFragment))
}
} else {
if s.config.LogLevel == "debug" {
log.Printf("TLS分片完成: 只需一个分片,总共发送 %d 字节", len(firstFragment))
}
}
return nil
}
// TLS记录类型
const (
recordTypeHandshake = 22
)
// TLS握手类型
const (
handshakeTypeClientHello = 1
)
// 读取SNI信息
func readSNI(conn net.Conn) (string, []byte, error) {
// 读取TLS记录头部
header := make([]byte, 5)
if _, err := io.ReadFull(conn, header); err != nil {
return "", nil, fmt.Errorf("读取TLS记录头部失败: %w", err)
}
// 检查是否是TLS握手记录
if header[0] != recordTypeHandshake {
return "", nil, fmt.Errorf("不是TLS握手记录记录类型: %d", header[0])
}
// 获取记录长度
recordLength := int(header[3])<<8 | int(header[4])
if recordLength > 16384 {
return "", nil, fmt.Errorf("TLS记录太长: %d 字节", recordLength)
}
if recordLength < 4 {
return "", nil, fmt.Errorf("TLS握手记录太短: %d 字节", recordLength)
}
// 读取握手消息
handshakeData := make([]byte, recordLength)
if _, err := io.ReadFull(conn, handshakeData); err != nil {
return "", nil, fmt.Errorf("读取握手消息失败: %w", err)
}
// 检查是否是ClientHello
if handshakeData[0] != handshakeTypeClientHello {
return "", nil, fmt.Errorf("不是ClientHello消息握手类型: %d", handshakeData[0])
}
// 完整的ClientHello消息包括记录头部
clientHello := append(header, handshakeData...)
// 验证TLS记录的完整性
valid, reason := validateTLSRecord(clientHello)
if !valid {
log.Printf("警告: 收到的ClientHello不是有效的TLS记录: %s", reason)
}
// 解析SNI扩展
sni, err := extractSNI(handshakeData)
if err != nil {
return "", nil, fmt.Errorf("提取SNI失败: %w", err)
}
if sni != "" && len(sni) > 0 {
log.Printf("成功提取SNI: %s", sni)
}
return sni, clientHello, nil
}
// 从ClientHello中提取SNI
func extractSNI(data []byte) (string, error) {
if len(data) < 42 {
return "", fmt.Errorf("ClientHello太短")
}
// 跳过握手类型(1) + 长度(3) + 协议版本(2) + 随机数(32)
pos := 38
// 跳过会话ID
if pos+1 > len(data) {
return "", fmt.Errorf("解析会话ID长度时数据不足")
}
sessionIDLength := int(data[pos])
pos += 1 + sessionIDLength
// 跳过密码套件
if pos+2 > len(data) {
return "", fmt.Errorf("解析密码套件长度时数据不足")
}
cipherSuitesLength := int(data[pos])<<8 | int(data[pos+1])
pos += 2 + cipherSuitesLength
// 跳过压缩方法
if pos+1 > len(data) {
return "", fmt.Errorf("解析压缩方法长度时数据不足")
}
compressionMethodsLength := int(data[pos])
pos += 1 + compressionMethodsLength
// 检查是否有扩展
if pos+2 > len(data) {
return "", fmt.Errorf("没有扩展数据")
}
extensionsLength := int(data[pos])<<8 | int(data[pos+1])
pos += 2
// 解析扩展
extensionsEnd := pos + extensionsLength
if extensionsEnd > len(data) {
return "", fmt.Errorf("扩展数据不足")
}
for pos < extensionsEnd {
// 读取扩展类型
if pos+4 > len(data) {
return "", fmt.Errorf("解析扩展类型时数据不足")
}
extensionType := int(data[pos])<<8 | int(data[pos+1])
extensionLength := int(data[pos+2])<<8 | int(data[pos+3])
pos += 4
// 检查是否是SNI扩展 (类型为0)
if extensionType == 0 {
// 跳过SNI列表长度
if pos+2 > len(data) {
return "", fmt.Errorf("解析SNI列表长度时数据不足")
}
pos += 2
// 检查SNI类型
if pos+1 > len(data) {
return "", fmt.Errorf("解析SNI类型时数据不足")
}
if data[pos] != 0 {
return "", fmt.Errorf("不支持的SNI类型")
}
pos++
// 读取主机名长度
if pos+2 > len(data) {
return "", fmt.Errorf("解析主机名长度时数据不足")
}
hostnameLength := int(data[pos])<<8 | int(data[pos+1])
pos += 2
// 读取主机名
if pos+hostnameLength > len(data) {
return "", fmt.Errorf("解析主机名时数据不足")
}
hostname := string(data[pos : pos+hostnameLength])
return hostname, nil
}
// 跳过当前扩展
pos += extensionLength
}
return "", fmt.Errorf("未找到SNI扩展")
}
// 打印TLS记录的详细信息用于调试
func printTLSRecord(data []byte, prefix string) {
if len(data) < 5 {
log.Printf("%s: 数据太短不是有效的TLS记录", prefix)
return
}
recordType := data[0]
version := (uint16(data[1]) << 8) | uint16(data[2])
length := (uint16(data[3]) << 8) | uint16(data[4])
var recordTypeStr string
switch recordType {
case 20:
recordTypeStr = "ChangeCipherSpec"
case 21:
recordTypeStr = "Alert"
case 22:
recordTypeStr = "Handshake"
case 23:
recordTypeStr = "ApplicationData"
default:
recordTypeStr = fmt.Sprintf("Unknown(%d)", recordType)
}
var versionStr string
switch version {
case 0x0301:
versionStr = "TLS 1.0"
case 0x0302:
versionStr = "TLS 1.1"
case 0x0303:
versionStr = "TLS 1.2"
case 0x0304:
versionStr = "TLS 1.3"
default:
versionStr = fmt.Sprintf("Unknown(0x%04x)", version)
}
log.Printf("%s: 类型=%s, 版本=%s, 长度=%d, 总字节数=%d",
prefix, recordTypeStr, versionStr, length, len(data))
if recordType == recordTypeHandshake && len(data) >= 6 {
handshakeType := data[5]
var handshakeTypeStr string
switch handshakeType {
case 1:
handshakeTypeStr = "ClientHello"
case 2:
handshakeTypeStr = "ServerHello"
case 11:
handshakeTypeStr = "Certificate"
case 16:
handshakeTypeStr = "ClientKeyExchange"
default:
handshakeTypeStr = fmt.Sprintf("Unknown(%d)", handshakeType)
}
log.Printf("%s: 握手类型=%s", prefix, handshakeTypeStr)
}
}
// 验证TLS记录的完整性
func validateTLSRecord(data []byte) (bool, string) {
if len(data) < 5 {
return false, "数据太短不是有效的TLS记录"
}
recordType := data[0]
length := (uint16(data[3]) << 8) | uint16(data[4])
// 检查记录类型是否有效
validTypes := map[byte]bool{
20: true, // ChangeCipherSpec
21: true, // Alert
22: true, // Handshake
23: true, // ApplicationData
}
if !validTypes[recordType] {
return false, fmt.Sprintf("无效的TLS记录类型: %d", recordType)
}
// 检查记录长度是否与实际数据长度匹配
if int(length) != len(data)-5 {
return false, fmt.Sprintf("TLS记录长度不匹配: 头部指示 %d 字节,实际数据 %d 字节", length, len(data)-5)
}
// 对于握手记录,进行更详细的验证
if recordType == recordTypeHandshake && len(data) >= 9 {
// 握手消息至少需要4字节的头部
handshakeType := data[5]
handshakeLength := (uint32(data[6]) << 16) | (uint32(data[7]) << 8) | uint32(data[8])
// 检查握手消息长度是否与记录长度匹配
if int(handshakeLength)+4 > int(length) {
return false, fmt.Sprintf("握手消息长度不匹配: 握手头部指示 %d 字节,记录头部指示 %d 字节", handshakeLength, length)
}
// 对于ClientHello进行更详细的验证
if handshakeType == handshakeTypeClientHello && len(data) >= 43 {
// ClientHello至少需要包含协议版本(2)、随机数(32)和会话ID长度(1)
return true, "有效的ClientHello记录"
}
return true, fmt.Sprintf("有效的握手记录,类型: %d", handshakeType)
}
return true, fmt.Sprintf("有效的TLS记录类型: %d", recordType)
}
// 验证握手消息的完整性
func validateHandshakeMessage(data []byte) (bool, string) {
if len(data) < 4 {
return false, "数据太短,不是有效的握手消息"
}
handshakeType := data[0]
handshakeLength := (uint32(data[1]) << 16) | (uint32(data[2]) << 8) | uint32(data[3])
// 检查握手消息长度是否与实际数据长度匹配
if int(handshakeLength)+4 != len(data) {
return false, fmt.Sprintf("握手消息长度不匹配: 头部指示 %d 字节,实际数据 %d 字节", handshakeLength, len(data)-4)
}
// 对于ClientHello进行更详细的验证
if handshakeType == handshakeTypeClientHello {
if len(data) < 38 {
return false, "ClientHello消息太短"
}
// 检查协议版本、随机数等字段
// 这里只是简单检查长度,实际应用中可以更详细地验证
return true, "有效的ClientHello消息"
}
return true, fmt.Sprintf("有效的握手消息,类型: %d", handshakeType)
}