From 9e88973ef7076fdd32dbce904124e4a087675b22 Mon Sep 17 00:00:00 2001 From: Yoo1tic <137816438+Yoo1tic@users.noreply.github.com> Date: Mon, 28 Jul 2025 18:35:06 +0800 Subject: [PATCH] feat: refactor key validation logic and restructure configuration handling; remove unused modules --- Cargo.lock | 1 + Cargo.toml | 1 + src/config/{basic_config.rs => config.rs} | 0 src/config/mod.rs | 6 +- src/error.rs | 15 ++++ src/key_validator.rs | 78 ------------------- src/lib.rs | 5 +- src/main.rs | 22 ++---- .../http_client.rs} | 8 +- src/service/key_tester.rs | 63 +++++++++++++++ src/service/mod.rs | 7 ++ src/{ => service}/validation.rs | 23 +++++- src/types.rs | 7 -- 13 files changed, 119 insertions(+), 117 deletions(-) rename src/config/{basic_config.rs => config.rs} (100%) create mode 100644 src/error.rs delete mode 100644 src/key_validator.rs rename src/{config/basic_client.rs => service/http_client.rs} (93%) create mode 100644 src/service/key_tester.rs create mode 100644 src/service/mod.rs rename src/{ => service}/validation.rs (75%) diff --git a/Cargo.lock b/Cargo.lock index 508b4b2..f68f27a 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -398,6 +398,7 @@ dependencies = [ "reqwest", "serde", "serde_json", + "thiserror", "tokio", "toml 0.9.2", "url", diff --git a/Cargo.toml b/Cargo.toml index 9585149..c673841 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -23,3 +23,4 @@ async-stream = "0.3" figment = { version = "0.10", features = ["env", "toml"] } serde = { version = "1.0", features = ["derive"] } toml = "0.9" +thiserror = "2.0.12" diff --git a/src/config/basic_config.rs b/src/config/config.rs similarity index 100% rename from src/config/basic_config.rs rename to src/config/config.rs diff --git a/src/config/mod.rs b/src/config/mod.rs index 688680a..ee0ad2d 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,5 +1,3 @@ -mod basic_client; -mod basic_config; +mod config; -pub use basic_client::client_builder; -pub use basic_config::{KeyCheckerConfig, TEST_MESSAGE_BODY}; +pub use config::{KeyCheckerConfig, TEST_MESSAGE_BODY}; diff --git a/src/error.rs b/src/error.rs new file mode 100644 index 0000000..fb2090d --- /dev/null +++ b/src/error.rs @@ -0,0 +1,15 @@ +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum ValidationError { + #[error("HTTP error: {0}")] + HttpRequest(#[from] reqwest::Error), + + #[error("Key is unavailable or invalid")] + KeyUnavailable, + + #[error("Key validation failed: {0}")] + Invalid(String), +} + +pub type Result = std::result::Result; diff --git a/src/key_validator.rs b/src/key_validator.rs deleted file mode 100644 index a0f7a42..0000000 --- a/src/key_validator.rs +++ /dev/null @@ -1,78 +0,0 @@ -use anyhow::Result; -use backon::{ExponentialBuilder, Retryable}; -use reqwest::{Client, StatusCode}; -use tokio::time::Duration; -use url::Url; - -use crate::config::TEST_MESSAGE_BODY; -use crate::types::{GeminiKey, KeyStatus}; - -pub async fn validate_key_with_retry( - client: Client, - full_url: Url, - key: GeminiKey, -) -> Option { - let retry_policy = ExponentialBuilder::default() - .with_max_times(3) - .with_min_delay(Duration::from_secs(3)) - .with_max_delay(Duration::from_secs(5)); - - let result = (async || match keytest(client.to_owned(), &full_url, &key).await { - Ok(KeyStatus::Valid) => { - println!("Key: {}... -> SUCCESS", &key.as_ref()[..10]); - Ok(Some(key.clone())) - } - Ok(KeyStatus::Invalid) => { - println!("Key: {}... -> INVALID (Forbidden)", &key.as_ref()[..10]); - Ok(None) - } - Ok(KeyStatus::Retryable(reason)) => { - eprintln!( - "Key: {}... -> RETRYABLE (Reason: {})", - &key.as_ref()[..10], - reason - ); - Err(anyhow::anyhow!("Retryable error: {}", reason)) - } - Err(e) => { - eprintln!( - "Key: {}... -> NETWORK ERROR (Reason: {})", - &key.as_ref()[..10], - e - ); - Err(e) - } - }) - .retry(retry_policy) - .await; - - match result { - Ok(key_result) => key_result, - Err(_) => { - eprintln!( - "Key: {}... -> FAILED after all retries.", - &key.as_ref()[..10] - ); - None - } - } -} - -async fn keytest(client: Client, full_url: &Url, key: &GeminiKey) -> Result { - let response = client - .post(full_url.clone()) - .header("Content-Type", "application/json") - .header("X-goog-api-key", key.as_ref()) - .json(&*TEST_MESSAGE_BODY) - .send() - .await?; - - let status = response.status(); - - let key_status = match status { - StatusCode::OK => KeyStatus::Valid, - StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED => KeyStatus::Invalid, - other => KeyStatus::Retryable(format!("Received status {}, will retry.", other)), - }; - Ok(key_status) -} diff --git a/src/lib.rs b/src/lib.rs index e987dfc..9fc54f4 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,8 +1,8 @@ pub mod adapters; pub mod config; -pub mod key_validator; +pub mod error; pub mod types; -pub mod validation; +pub mod service; // ASCII art for Gemini pub const BANNER: &str = r#" @@ -12,4 +12,3 @@ pub const BANNER: &str = r#" / /_/ // __// / / / / // // / / // / \____/ \___//_/ /_/ /_//_//_/ /_//_/ "#; - diff --git a/src/main.rs b/src/main.rs index f17cad8..358cc45 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,30 +1,20 @@ use anyhow::Result; -use gemini_keychecker::{ - BANNER, - adapters::load_keys, - config::{KeyCheckerConfig, client_builder}, - validation::ValidationService, -}; +use gemini_keychecker::{BANNER, config::KeyCheckerConfig, service::start_validation}; use mimalloc::MiMalloc; #[global_allocator] static GLOBAL: MiMalloc = MiMalloc; -/// Main function - orchestrates the key validation process +/// Main function - displays banner and starts validation service #[tokio::main] async fn main() -> Result<()> { - let config = KeyCheckerConfig::load_config().unwrap(); - // Display banner and configuration status at startup println!("{BANNER}"); + + let config = KeyCheckerConfig::load_config()?; println!("{config}"); - let keys = load_keys(config.input_path.as_path())?; - let client = client_builder(&config)?; - - let validation_service = ValidationService::new(config, client); - validation_service.validate_keys(keys).await?; - - Ok(()) + // Start validation service + start_validation().await } diff --git a/src/config/basic_client.rs b/src/service/http_client.rs similarity index 93% rename from src/config/basic_client.rs rename to src/service/http_client.rs index b567980..456ea9b 100644 --- a/src/config/basic_client.rs +++ b/src/service/http_client.rs @@ -1,8 +1,6 @@ -use std::time::Duration; - -use reqwest::Client; - use crate::config::KeyCheckerConfig; +use reqwest::Client; +use std::time::Duration; pub fn client_builder(config: &KeyCheckerConfig) -> Result { // Set the maximum number of connections per host based on concurrency. @@ -17,7 +15,7 @@ pub fn client_builder(config: &KeyCheckerConfig) -> Result Result { + let api_endpoint = api_endpoint.into_url()?; + + match send_test_request(client, &api_endpoint, api_key.clone()).await { + Ok(response) => { + let status = response.status(); + match status { + StatusCode::OK => Ok(api_key), + StatusCode::UNAUTHORIZED + | StatusCode::FORBIDDEN + | StatusCode::TOO_MANY_REQUESTS => Err(ValidationError::KeyUnavailable), + _ => Err(ValidationError::HttpRequest( + response.error_for_status().unwrap_err(), + )), + } + } + Err(e) => Err(ValidationError::HttpRequest(e)), + } +} + +async fn send_test_request( + client: Client, + api_endpoint: &Url, + key: GeminiKey, +) -> Result { + let retry_policy = ExponentialBuilder::default() + .with_max_times(3) + .with_min_delay(Duration::from_secs(3)) + .with_max_delay(Duration::from_secs(5)); + + (async || { + let response = client + .post(api_endpoint.clone()) + .header("Content-Type", "application/json") + .header("X-goog-api-key", key.as_ref()) + .json(&*TEST_MESSAGE_BODY) + .send() + .await?; + + response.error_for_status() + }) + .retry(&retry_policy) + .when(|e: &reqwest::Error| { + !matches!( + e.status(), + Some(StatusCode::FORBIDDEN | StatusCode::UNAUTHORIZED) + ) + }) + .await +} diff --git a/src/service/mod.rs b/src/service/mod.rs new file mode 100644 index 0000000..424ef96 --- /dev/null +++ b/src/service/mod.rs @@ -0,0 +1,7 @@ +pub mod http_client; +pub mod key_tester; +pub mod validation; + +pub use http_client::client_builder; +pub use key_tester::validate_key; +pub use validation::{ValidationService, start_validation}; diff --git a/src/validation.rs b/src/service/validation.rs similarity index 75% rename from src/validation.rs rename to src/service/validation.rs index 9ad79c3..0b64ecb 100644 --- a/src/validation.rs +++ b/src/service/validation.rs @@ -5,9 +5,9 @@ use reqwest::Client; use std::time::Instant; use tokio::{fs, io::AsyncWriteExt, sync::mpsc}; -use crate::adapters::write_keys_txt_file; +use super::{key_tester::validate_key, http_client::client_builder}; +use crate::adapters::{write_keys_txt_file, load_keys}; use crate::config::KeyCheckerConfig; -use crate::key_validator::validate_key_with_retry; use crate::types::GeminiKey; pub struct ValidationService { @@ -48,9 +48,9 @@ impl ValidationService { // Create stream to validate keys concurrently let valid_keys_stream = stream - .map(|key| validate_key_with_retry(self.client.to_owned(), self.full_url.clone(), key)) + .map(|key| validate_key(self.client.clone(), self.full_url.clone(), key)) .buffer_unordered(self.config.concurrency) - .filter_map(|r| async { r }); + .filter_map(|result| async { result.ok() }); pin_mut!(valid_keys_stream); // Open output file for writing valid keys @@ -72,3 +72,18 @@ impl ValidationService { Ok(()) } } + +/// 启动验证服务 - 封装了所有启动逻辑 +pub async fn start_validation() -> Result<()> { + let config = KeyCheckerConfig::load_config()?; + + // 加载密钥 + let keys = load_keys(config.input_path.as_path())?; + + // 构建HTTP客户端 + let client = client_builder(&config)?; + + // 创建验证服务并启动 + let validation_service = ValidationService::new(config, client); + validation_service.validate_keys(keys).await +} diff --git a/src/types.rs b/src/types.rs index 44bbd1a..ffb9f7d 100644 --- a/src/types.rs +++ b/src/types.rs @@ -2,13 +2,6 @@ use regex::Regex; use std::str::FromStr; use std::sync::LazyLock; -#[derive(Debug)] -pub enum KeyStatus { - Valid, - Invalid, - Retryable(String), -} - #[derive(Debug, Clone)] pub struct GeminiKey { pub inner: String,