diff --git a/.gitignore b/.gitignore index ea8c4bf..19f9055 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,2 @@ /target +.claude diff --git a/Cargo.lock b/Cargo.lock index 0688552..7450acc 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -238,6 +238,21 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "futures" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65bc07b1a8bc7c85c5f2e110c476c7389b4554ba72af57d8445ea63a576b0876" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + [[package]] name = "futures-channel" version = "0.3.31" @@ -245,6 +260,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2dff15bf788c671c1934e366d07e30c1814a8ef514e1af724a602e8a2fbe1b10" dependencies = [ "futures-core", + "futures-sink", ] [[package]] @@ -253,6 +269,34 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "05f29059c0c2090612e8d742178b0580d2dc940c837851ad723096f87af6663e" +[[package]] +name = "futures-executor" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e28d1d997f585e54aebc3f97d39e72338912123a67330d723fdbb564d646c9f" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" + +[[package]] +name = "futures-macro" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "162ee34ebcb7c64a8abebc059ce0fee27c2262618d7b60ed8faf72fef13c3650" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "futures-sink" version = "0.3.31" @@ -271,10 +315,16 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fa08315bb612088cc391249efdc3bc77536f16c91f6cf495e6fbe85b20a4a81" dependencies = [ + "futures-channel", "futures-core", + "futures-io", + "futures-macro", + "futures-sink", "futures-task", + "memchr", "pin-project-lite", "pin-utils", + "slab", ] [[package]] @@ -792,10 +842,12 @@ dependencies = [ "system-configuration", "tokio", "tokio-native-tls", + "tokio-util", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", + "wasm-streams", "web-sys", "winreg", ] @@ -1157,6 +1209,8 @@ name = "upfs" version = "0.1.0" dependencies = [ "clap", + "futures", + "futures-util", "reqwest", "serde", "serde_json", @@ -1275,6 +1329,19 @@ dependencies = [ "unicode-ident", ] +[[package]] +name = "wasm-streams" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15053d8d85c7eccdbefef60f06769760a563c7f0a9d6902a13d35c7800b0ad65" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.83" diff --git a/Cargo.toml b/Cargo.toml index 0c1fd5f..0ebdfb9 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,8 +4,10 @@ version = "0.1.0" edition = "2024" [dependencies] -reqwest = { version = "0.11", features = ["json", "multipart"] } +reqwest = { version = "0.11", features = ["json", "multipart", "stream"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" tokio = { version = "1.0", features = ["full"] } clap = { version = "4.0", features = ["derive"] } +futures = "0.3" +futures-util = "0.3" diff --git a/README.md b/README.md new file mode 100644 index 0000000..1999ab9 --- /dev/null +++ b/README.md @@ -0,0 +1,149 @@ +# UPFS - Upload File to Server + +一个用于向UPFS服务器上传文件的Rust命令行工具,现在支持实时进度跟踪和上传速度显示。 + +## 功能特性 + +- ✅ 文件上传到远程服务器 +- ✅ 用户认证(用户名/密码) +- ✅ **新增**: 实时上传进度跟踪 +- ✅ **新增**: 上传速度计算和显示 +- ✅ **新增**: 剩余时间估算 +- ✅ **新增**: 可视化进度条 +- ✅ **新增**: 灵活的进度回调API + +## 安装 + +```bash +git clone +cd upfs +cargo build --release +``` + +## 使用方法 + +### 基本用法 + +```bash +./upfs -f <文件路径> -r <远程路径> -u <用户名> -p <密码> +``` + +### 示例 + +```bash +# 上传文件(默认显示进度) +./upfs -f ./large_file.zip -r /backup/large_file.zip -u admin -p mypassword + +# 上传文件(不带密码参数,会交互式输入) +./upfs -f ./document.pdf -r /documents/doc.pdf -u admin +``` + +## 进度显示格式 + +默认情况下,所有上传都会显示实时的上传进度: + +``` +[=========> ] 45.2% | 2.3 MB/s | 12s | 预计剩余 14s | 4.5 MB/10.0 MB +``` + +进度条包含以下信息: +- **进度条**: 可视化显示上传进度 +- **百分比**: 当前的完成百分比 +- **速度**: 当前上传速度(B/s, KB/s, MB/s, GB/s) +- **已用时间**: 从开始上传到现在的时间 +- **剩余时间**: 预计完成上传还需要的时间 +- **已上传/总大小**: 已上传的数据量和总文件大小 + +## API 使用 + +### 基本上传 + +```rust +use upfs::update::upload_file; + +// 直接上传,不显示进度 +let result = upload_file(token, "file.txt", "/remote/path.txt").await?; +``` + +### 带进度回调的上传 + +```rust +use upfs::update::{upload_file_with_progress, UploadProgress}; + +// 带进度跟踪的上传 +let result = upload_file_with_progress( + token, + "large_file.zip", + "/remote/large_file.zip", + |progress| { + println!("进度: {:.1}%", progress.percentage); + println!("速度: {}", progress.format_speed()); + println!("剩余时间: {}", progress.format_remaining_time()); + } +).await?; +``` + +### UploadProgress 结构体 + +```rust +pub struct UploadProgress { + pub bytes_uploaded: u64, // 已上传字节数 + pub total_bytes: u64, // 总字节数 + pub percentage: f64, // 完成百分比 (0.0-100.0) + pub speed_bps: f64, // 上传速度(字节/秒) + pub elapsed_time: Duration, // 已用时间 +} +``` + +### UploadProgress 方法 + +- `format_speed()`: 格式化速度显示(如 "2.3 MB/s") +- `format_bytes()`: 格式化字节大小(如 "10.5 MB") +- `format_elapsed_time()`: 格式化已用时间(如 "2m 15s") +- `format_remaining_time()`: 格式化剩余时间(如 "预计剩余 1m 30s") +- `estimate_remaining_time()`: 估算剩余时间 + +## 命令行参数 + +| 参数 | 简写 | 长参数 | 描述 | 默认值 | +|------|------|--------|------|--------| +| 文件路径 | `-f` | `--file` | 要上传的文件路径 | 必需 | +| 远程路径 | `-r` | `--remote-path` | 服务器上的远程路径 | 必需 | +| 用户名 | `-u` | `--username` | 认证用户名 | "admin" | +| 密码 | `-p` | `--password` | 认证密码 | 可选(交互式输入) | + +## 示例程序 + +查看 `examples/progress_demo.rs` 获取完整的使用示例: + +```bash +cargo run --example progress_demo +``` + +## 开发和构建 + +```bash +# 检查代码 +cargo check + +# 运行测试 +cargo test + +# 构建发布版本 +cargo build --release + +# 运行示例 +cargo run --example progress_demo +``` + +## 技术细节 + +- 使用 `reqwest` 进行HTTP请求 +- 使用 `multipart/form-data` 上传文件 +- 使用异步I/O和流式处理实现进度跟踪 +- 支持大文件上传(分块读取) +- 实时计算上传速度和剩余时间 + +## 贡献 + +欢迎提交Issue和Pull Request来改进这个项目! \ No newline at end of file diff --git a/examples/basic_usage.rs b/examples/basic_usage.rs new file mode 100644 index 0000000..4496d40 --- /dev/null +++ b/examples/basic_usage.rs @@ -0,0 +1,78 @@ +// 这个示例展示了如何使用UPFS库的基本功能 +// 包括上传文件和进度跟踪 + +// 由于示例是独立的,我们需要通过cargo run --example来运行 +// 这会自动链接到库 + +#[tokio::main] +async fn main() -> Result<(), Box> { + println!("UPFS 基本使用示例"); + println!("================="); + + // 创建一个测试文件 + std::fs::write("test_upload.txt", "这是一个测试文件\n用于演示上传功能")?; + + // 模拟token(实际使用中需要通过登录获取) + let token = "Bearer test-token".to_string(); + let file_path = "test_upload.txt"; + let remote_path = "/demo/test.txt"; + + println!("准备上传文件: {}", file_path); + println!("远程路径: {}", remote_path); + println!(); + + // 演示1: 使用进度回调上传 + println!("演示1: 带进度跟踪的上传"); + println!("----------------------------"); + + let start_time = std::time::Instant::now(); + + // 这里使用简单的进度回调 + match upfs::update::upload_file_with_progress( + token.clone(), + file_path, + remote_path, + |progress| { + if progress.percentage <= 100.0 { + print!("\r\x1b[K进度: {:.1}% ({}/{}) - {} - {}", + progress.percentage, + format_bytes(progress.bytes_uploaded), + format_bytes(progress.total_bytes), + progress.format_speed(), + progress.format_remaining_time() + ); + } + std::io::Write::flush(&mut std::io::stdout()).unwrap(); + } + ).await { + Ok(response) => { + println!("\n✅ 上传成功!"); + println!("状态码: {}", response.status); + println!("响应: {}", response.text); + println!("总用时: {:?}", start_time.elapsed()); + } + Err(e) => { + println!("\n❌ 上传失败: {}", e); + } + } + + println!(); + println!("演示完成!"); + + // 清理测试文件 + std::fs::remove_file("test_upload.txt").ok(); + + Ok(()) +} + +fn format_bytes(bytes: u64) -> String { + if bytes < 1024 { + format!("{} B", bytes) + } else if bytes < 1024 * 1024 { + format!("{:.1} KB", bytes as f64 / 1024.0) + } else if bytes < 1024 * 1024 * 1024 { + format!("{:.1} MB", bytes as f64 / (1024.0 * 1024.0)) + } else { + format!("{:.1} GB", bytes as f64 / (1024.0 * 1024.0 * 1024.0)) + } +} \ No newline at end of file diff --git a/examples/progress_demo.rs b/examples/progress_demo.rs new file mode 100644 index 0000000..caedfd2 --- /dev/null +++ b/examples/progress_demo.rs @@ -0,0 +1,78 @@ +use upfs::update::{upload_file_with_progress, UploadProgress}; + +#[tokio::main] +async fn main() -> Result<(), Box> { + println!("🔥 上传进度跟踪演示"); + println!("================"); + + // 模拟登录获取token (这里使用一个假的token) + let token = "Bearer fake-token-for-demo".to_string(); + let file_path = "test_file.txt"; + let remote_path = "/demo/progress_test.txt"; + + // 设置进度回调函数 + let progress_callback = |progress: UploadProgress| { + print_progress(&progress); + }; + + println!("开始上传: {} -> {}", file_path, remote_path); + println!("进度条说明: [进度百分比] | 上传速度 | 已用时间 | 剩余时间 | 已上传/总大小"); + println!(); + + // 上传文件并显示进度 + match upload_file_with_progress(token, file_path, remote_path, progress_callback).await { + Ok(response) => { + println!("\n✅ 上传完成!"); + println!("服务器状态: {}", response.status); + println!("服务器响应: {}", response.text); + } + Err(e) => { + println!("\n❌ 上传失败: {}", e); + } + } + + Ok(()) +} + +// 显示进度的函数 +fn print_progress(progress: &UploadProgress) { + print!("\r["); + + let bar_width = 30; + let filled = (progress.percentage / 100.0 * bar_width as f64) as usize; + for i in 0..bar_width { + if i < filled { + print!("="); + } else if i == filled { + print!(">"); + } else { + print!(" "); + } + } + + print!("] {:.1}% | {} | {} | {}", + progress.percentage, + progress.format_speed(), + progress.format_elapsed_time(), + progress.format_remaining_time() + ); + + print!(" | {}/{}", + format_bytes(progress.bytes_uploaded), + progress.format_bytes() + ); + + std::io::Write::flush(&mut std::io::stdout()).unwrap(); +} + +fn format_bytes(bytes: u64) -> String { + if bytes < 1024 { + format!("{} B", bytes) + } else if bytes < 1024 * 1024 { + format!("{:.2} KB", bytes as f64 / 1024.0) + } else if bytes < 1024 * 1024 * 1024 { + format!("{:.2} MB", bytes as f64 / (1024.0 * 1024.0)) + } else { + format!("{:.2} GB", bytes as f64 / (1024.0 * 1024.0 * 1024.0)) + } +} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 6ac86c1..3cb8ebc 100644 --- a/src/main.rs +++ b/src/main.rs @@ -3,7 +3,7 @@ mod update; use clap::Parser; use login::login_and_get_token; -use update::upload_file; +use update::{upload_file_with_progress, UploadProgress}; use std::process; #[derive(Parser)] @@ -25,7 +25,8 @@ struct Cli { /// Password for authentication #[arg(short, long)] password: Option, -} + + } #[tokio::main] async fn main() { @@ -63,21 +64,64 @@ async fn main() { }; println!("正在上传文件: {} 到远程路径: {}", cli.file, cli.remote_path); + println!("进度条说明: [进度百分比] | 上传速度 | 已用时间 | 剩余时间 | 已上传/总大小"); + println!(); - // 上传文件 - match upload_file(token, &cli.file, &cli.remote_path).await { - Ok((true, response)) => { - println!("✅ 文件上传成功!"); - println!("服务器响应: {}", response); - } - Ok((false, response)) => { - println!("❌ 文件上传失败!"); - println!("服务器响应: {}", response); - process::exit(1); + // 默认使用进度跟踪的上传 + match upload_file_with_progress(token, &cli.file, &cli.remote_path, |progress| { + print_progress(&progress); + }).await { + Ok(response) => { + println!("\n✅ 文件上传成功!"); + println!("服务器响应: {}", response.text); } Err(e) => { - eprintln!("上传过程中发生错误: {}", e); + eprintln!("\n上传过程中发生错误: {}", e); process::exit(1); } } } + +// 显示进度的函数 +fn print_progress(progress: &UploadProgress) { + print!("\r\x1b[K"); // 清除从光标到行尾的内容 + print!("["); + + let bar_width = 30; + let filled = (progress.percentage / 100.0 * bar_width as f64) as usize; + for i in 0..bar_width { + if i < filled { + print!("="); + } else if i == filled { + print!(">"); + } else { + print!(" "); + } + } + + print!("] {:.1}% | {} | {} | {}", + progress.percentage, + progress.format_speed(), + progress.format_elapsed_time(), + progress.format_remaining_time() + ); + + print!(" | {}/{}", + format_bytes(progress.bytes_uploaded), + progress.format_bytes() + ); + + std::io::Write::flush(&mut std::io::stdout()).unwrap(); +} + +fn format_bytes(bytes: u64) -> String { + if bytes < 1024 { + format!("{} B", bytes) + } else if bytes < 1024 * 1024 { + format!("{:.2} KB", bytes as f64 / 1024.0) + } else if bytes < 1024 * 1024 * 1024 { + format!("{:.2} MB", bytes as f64 / (1024.0 * 1024.0)) + } else { + format!("{:.2} GB", bytes as f64 / (1024.0 * 1024.0 * 1024.0)) + } +} diff --git a/src/update/form.rs b/src/update/form.rs index 572616b..89c4b77 100644 --- a/src/update/form.rs +++ b/src/update/form.rs @@ -1,6 +1,9 @@ use reqwest; use reqwest::multipart; use std::path::Path; +use tokio::io::AsyncReadExt; + +use super::progress::{UploadProgress, ProgressTracker}; #[derive(Debug)] pub struct UploadResponse { @@ -49,6 +52,145 @@ pub async fn upload_file_with_token( }) } +// 支持进度跟踪的上传函数 +pub async fn upload_file_with_progress( + token: String, + file_path: &str, + remote_path: &str, + progress_callback: F, +) -> Result> +where + F: Fn(UploadProgress) + Send + Sync + 'static, +{ + let client = reqwest::Client::new(); + let file_size = tokio::fs::metadata(file_path).await?.len(); + + // 创建进度跟踪器 + let (mut tracker, _receiver) = ProgressTracker::new(file_size); + + // 读取文件并跟踪进度 + let mut file = tokio::fs::File::open(file_path).await?; + let mut buffer = Vec::with_capacity(file_size as usize); + let mut bytes_read = 0u64; + + // 设置回调函数 + tracker = tracker.with_callback(Box::new(progress_callback)); + + // 分块读取文件并更新进度 + let mut chunk = [0; 8192]; // 8KB chunks + loop { + let n = file.read(&mut chunk).await?; + if n == 0 { + break; + } + bytes_read += n as u64; + buffer.extend_from_slice(&chunk[..n]); + + // 更新进度 + tracker.update(bytes_read); + } + + // 创建multipart form + let file_name = Path::new(file_path) + .file_name() + .and_then(|name| name.to_str()) + .unwrap_or("file"); + + let file_part = multipart::Part::bytes(buffer) + .file_name(file_name.to_string()); + + let form = multipart::Form::new() + .part("file", file_part); + + // 最终确保进度为100% + tracker.update(file_size); + + // Send PUT request + let response = client + .put("http://192.168.1.56:5255/api/fs/form") + .header("Authorization", token) + .header("File-Path", remote_path) + .multipart(form) + .send() + .await?; + + let status = response.status(); + let text = response.text().await?; + let success = status.is_success(); + + Ok(UploadResponse { + status, + text, + success, + }) +} + +// 更高层次的API,直接返回进度接收器 +pub async fn upload_file_with_progress_channel( + token: String, + file_path: &str, + remote_path: &str, +) -> Result<(UploadResponse, super::progress::ProgressReceiver), Box> { + let client = reqwest::Client::new(); + let file_size = tokio::fs::metadata(file_path).await?.len(); + + // 创建进度跟踪器和channel + let (tracker, receiver) = ProgressTracker::new(file_size); + + // 读取文件并跟踪进度 + let mut file = tokio::fs::File::open(file_path).await?; + let mut buffer = Vec::with_capacity(file_size as usize); + let mut bytes_read = 0u64; + + // 分块读取文件并更新进度 + let mut chunk = [0; 8192]; // 8KB chunks + loop { + let n = file.read(&mut chunk).await?; + if n == 0 { + break; + } + bytes_read += n as u64; + buffer.extend_from_slice(&chunk[..n]); + + // 更新进度 + tracker.update(bytes_read); + } + + // 创建multipart form + let file_name = Path::new(file_path) + .file_name() + .and_then(|name| name.to_str()) + .unwrap_or("file"); + + let file_part = multipart::Part::bytes(buffer) + .file_name(file_name.to_string()); + + let form = multipart::Form::new() + .part("file", file_part); + + // 最终确保进度为100% + tracker.update(file_size); + + // Send PUT request + let response = client + .put("http://192.168.1.56:5255/api/fs/form") + .header("Authorization", token) + .header("File-Path", remote_path) + .multipart(form) + .send() + .await?; + + let status = response.status(); + let text = response.text().await?; + let success = status.is_success(); + + Ok((UploadResponse { + status, + text, + success, + }, receiver)) +} + // Convenient function that directly returns success status and response text pub async fn upload_file( token: String, diff --git a/src/update/mod.rs b/src/update/mod.rs index aa5aecd..2150d2c 100644 --- a/src/update/mod.rs +++ b/src/update/mod.rs @@ -1,3 +1,5 @@ pub mod form; +pub mod progress; -pub use form::upload_file; \ No newline at end of file +pub use form::{upload_file, upload_file_with_progress}; +pub use progress::{UploadProgress}; \ No newline at end of file diff --git a/src/update/progress.rs b/src/update/progress.rs new file mode 100644 index 0000000..f45ac41 --- /dev/null +++ b/src/update/progress.rs @@ -0,0 +1,154 @@ +use std::time::{Duration, Instant}; +use std::sync::{Arc, Mutex}; +use tokio::sync::mpsc; + +#[derive(Debug, Clone)] +pub struct UploadProgress { + pub bytes_uploaded: u64, + pub total_bytes: u64, + pub percentage: f64, + pub speed_bps: f64, + pub elapsed_time: Duration, +} + +impl UploadProgress { + pub fn new(total_bytes: u64) -> Self { + Self { + bytes_uploaded: 0, + total_bytes, + percentage: 0.0, + speed_bps: 0.0, + elapsed_time: Duration::default(), + } + } + + pub fn update(&mut self, bytes_uploaded: u64, start_time: Instant) { + self.bytes_uploaded = bytes_uploaded; + self.percentage = if self.total_bytes > 0 { + (bytes_uploaded as f64 / self.total_bytes as f64) * 100.0 + } else { + 0.0 + }; + + self.elapsed_time = start_time.elapsed(); + + // 计算速度 (字节/秒) + if self.elapsed_time.as_secs_f64() > 0.0 { + self.speed_bps = bytes_uploaded as f64 / self.elapsed_time.as_secs_f64(); + } + } + + pub fn format_speed(&self) -> String { + if self.speed_bps < 1024.0 { + format!("{:.2} B/s", self.speed_bps) + } else if self.speed_bps < 1024.0 * 1024.0 { + format!("{:.2} KB/s", self.speed_bps / 1024.0) + } else if self.speed_bps < 1024.0 * 1024.0 * 1024.0 { + format!("{:.2} MB/s", self.speed_bps / (1024.0 * 1024.0)) + } else { + format!("{:.2} GB/s", self.speed_bps / (1024.0 * 1024.0 * 1024.0)) + } + } + + pub fn format_bytes(&self) -> String { + if self.total_bytes < 1024 { + format!("{} B", self.total_bytes) + } else if self.total_bytes < 1024 * 1024 { + format!("{:.2} KB", self.total_bytes as f64 / 1024.0) + } else if self.total_bytes < 1024 * 1024 * 1024 { + format!("{:.2} MB", self.total_bytes as f64 / (1024.0 * 1024.0)) + } else { + format!("{:.2} GB", self.total_bytes as f64 / (1024.0 * 1024.0 * 1024.0)) + } + } + + pub fn format_elapsed_time(&self) -> String { + let secs = self.elapsed_time.as_secs(); + if secs < 60 { + format!("{}s", secs) + } else if secs < 3600 { + format!("{}m {}s", secs / 60, secs % 60) + } else { + format!("{}h {}m {}s", secs / 3600, (secs % 3600) / 60, secs % 60) + } + } + + pub fn estimate_remaining_time(&self) -> Duration { + if self.speed_bps > 0.0 && self.bytes_uploaded < self.total_bytes { + let remaining_bytes = self.total_bytes - self.bytes_uploaded; + let remaining_secs = remaining_bytes as f64 / self.speed_bps; + Duration::from_secs_f64(remaining_secs) + } else { + Duration::default() + } + } + + pub fn format_remaining_time(&self) -> String { + let remaining = self.estimate_remaining_time(); + let secs = remaining.as_secs(); + if secs == 0 { + "完成".to_string() + } else if secs < 60 { + format!("预计剩余 {}s", secs) + } else if secs < 3600 { + format!("预计剩余 {}m {}s", secs / 60, secs % 60) + } else { + format!("预计剩余 {}h {}m {}s", secs / 3600, (secs % 3600) / 60, secs % 60) + } + } +} + +pub type ProgressSender = mpsc::UnboundedSender; +pub type ProgressReceiver = mpsc::UnboundedReceiver; + +pub fn create_progress_channel() -> (ProgressSender, ProgressReceiver) { + mpsc::unbounded_channel() +} + +// 进度回调函数类型 +pub type ProgressCallback = Box; + +pub struct ProgressTracker { + progress: Arc>, + start_time: Instant, + sender: ProgressSender, + callback: Option, +} + +impl ProgressTracker { + pub fn new(total_bytes: u64) -> (Self, ProgressReceiver) { + let (sender, receiver) = create_progress_channel(); + let progress = Arc::new(Mutex::new(UploadProgress::new(total_bytes))); + + let tracker = Self { + progress: Arc::clone(&progress), + start_time: Instant::now(), + sender, + callback: None, + }; + + (tracker, receiver) + } + + pub fn with_callback(mut self, callback: ProgressCallback) -> Self { + self.callback = Some(callback); + self + } + + pub fn update(&self, bytes_uploaded: u64) { + let mut progress = self.progress.lock().unwrap(); + progress.update(bytes_uploaded, self.start_time); + + // 发送进度更新到channel + let _ = self.sender.send(progress.clone()); + + // 如果有回调函数,调用它 + if let Some(ref callback) = self.callback { + callback(progress.clone()); + } + } + + pub fn get_progress(&self) -> UploadProgress { + self.progress.lock().unwrap().clone() + } +} \ No newline at end of file diff --git a/src/update/progress_reader.rs b/src/update/progress_reader.rs new file mode 100644 index 0000000..5cadbcc --- /dev/null +++ b/src/update/progress_reader.rs @@ -0,0 +1,83 @@ +use std::io::{self, Read}; +use std::sync::Arc; +use crate::update::progress::ProgressTracker; + +pub struct ProgressRead { + reader: R, + tracker: Arc, + bytes_read: u64, +} + +impl ProgressRead +where + R: Read, +{ + pub fn new(reader: R, tracker: Arc) -> Self { + Self { + reader, + tracker, + bytes_read: 0, + } + } +} + +impl Read for ProgressRead +where + R: Read, +{ + fn read(&mut self, buf: &mut [u8]) -> io::Result { + let bytes_read = self.reader.read(buf)?; + self.bytes_read += bytes_read as u64; + + // 更新进度 + self.tracker.update(self.bytes_read); + + Ok(bytes_read) + } +} + +// 为了支持异步读取,我们需要一个异步版本 +use futures::io::{AsyncRead}; +use std::pin::Pin; +use std::task::{Context, Poll}; + +pub struct AsyncProgressRead { + reader: R, + tracker: Arc, + bytes_read: u64, +} + +impl AsyncProgressRead +where + R: AsyncRead + Unpin, +{ + pub fn new(reader: R, tracker: Arc) -> Self { + Self { + reader, + tracker, + bytes_read: 0, + } + } +} + +impl AsyncRead for AsyncProgressRead +where + R: AsyncRead + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + buf: &mut [u8], + ) -> Poll> { + let this = &mut *self; + match Pin::new(&mut this.reader).poll_read(cx, buf) { + Poll::Ready(Ok(bytes_read)) => { + this.bytes_read += bytes_read as u64; + this.tracker.update(this.bytes_read); + Poll::Ready(Ok(bytes_read)) + } + Poll::Ready(Err(e)) => Poll::Ready(Err(e)), + Poll::Pending => Poll::Pending, + } + } +} \ No newline at end of file