191 lines
		
	
	
		
			6.0 KiB
		
	
	
	
		
			Rust
		
	
	
		
			Executable File
		
	
	
	
	
			
		
		
	
	
			191 lines
		
	
	
		
			6.0 KiB
		
	
	
	
		
			Rust
		
	
	
		
			Executable File
		
	
	
	
	
| use clap::{App, Arg};
 | |
| use lws_client::LwsVfsIns;
 | |
| use std::{process, thread};
 | |
| use std::process::Command;
 | |
| extern crate log;
 | |
| use nix::unistd::{fork, ForkResult, getpid, getppid};
 | |
| use signal_hook::consts::{SIGCHLD, SIGINT};
 | |
| use signal_hook::iterator::Signals;
 | |
| use std::env;
 | |
| 
 | |
/// Server port used when `-s/--server` does not specify one.
const DEFAULT_PORT: u16 = 33_444;
/// File-info cache invalidation time, in milliseconds, used when `-c/--cache` is not given.
const DEFAULT_CACHE_LIFE: u32 = 10_000;
 | |
/// Returns the client IP of the current SSH session.
///
/// The IP is the first whitespace-separated field of the `SSH_CLIENT`
/// environment variable, which is expected to look like
/// `"<ip> <client port> <server port>"`. It is used as the default server host.
///
/// If `SSH_CLIENT` is unset, not valid Unicode, or blank, no host can be
/// guessed. In that case a hint is printed to stderr and the process exits
/// with status 2. It no longer panics, and it never returns an empty host.
fn get_ssh_clinet() -> String {
    let client = env::var("SSH_CLIENT").unwrap_or_default();
    // `split_whitespace` tolerates leading/repeated spaces, unlike `split(' ')`.
    match client.split_whitespace().next() {
        Some(ip) => ip.to_string(),
        None => {
            eprintln!(
                "only support auto get ssh connection: SSH_CLIENT is not set, \
                 please pass the server address with -s/--server"
            );
            process::exit(2)
        }
    }
}
 | |
| 
 | |
| fn param_parser() -> (String, String, u32) {
 | |
|     let matches = App::new("lws_client")
 | |
|         .version("1.0")
 | |
|         .author("Ekko.bao")
 | |
|         .about("linux&windows shared filesystem")
 | |
|         .arg(
 | |
|             Arg::with_name("s")
 | |
|                 .short('s')
 | |
|                 .long("server")
 | |
|                 .takes_value(true)
 | |
|                 .min_values(1)
 | |
|                 .max_values(1)
 | |
|                 .value_name("server")
 | |
|                 .help(
 | |
|                     format!(
 | |
|                         "server addr: [ip[:port]], default is [sshclient:{}]",
 | |
|                         DEFAULT_PORT
 | |
|                     )
 | |
|                     .as_str(),
 | |
|                 ),
 | |
|         )
 | |
|         .arg(
 | |
|             Arg::with_name("m")
 | |
|                 .short('m')
 | |
|                 .long("mount")
 | |
|                 .takes_value(true)
 | |
|                 .min_values(1)
 | |
|                 .max_values(1)
 | |
|                 .value_name("mount point")
 | |
|                 .required(true)
 | |
|                 .help("mount point, eg: ~/mnt"),
 | |
|         )
 | |
|         .arg(
 | |
|             Arg::with_name("c")
 | |
|                 .short('c')
 | |
|                 .long("cache")
 | |
|                 .takes_value(true)
 | |
|                 .min_values(1)
 | |
|                 .max_values(1)
 | |
|                 .value_name("cache invalid time")
 | |
|                 .help("file info cache invalid time, unit: ms, defualt is 10_000ms"),
 | |
|         )
 | |
|         .get_matches();
 | |
|     let (ip, port) = match matches.value_of("s") {
 | |
|         Some(server) => {
 | |
|             let split = server.split(":").collect::<Vec<&str>>();
 | |
|             if split.len() == 2 {
 | |
|                 (split[0].to_string(), split[1].to_string())
 | |
|             } else {
 | |
|                 (split[0].to_string(), DEFAULT_PORT.to_string())
 | |
|             }
 | |
|         }
 | |
|         None => (String::new(), String::new()),
 | |
|     };
 | |
|     let ip = if ip.len() == 0 { get_ssh_clinet() } else { ip };
 | |
|     let port = if port.len() == 0 {
 | |
|         DEFAULT_PORT.to_string()
 | |
|     } else {
 | |
|         port
 | |
|     };
 | |
|     let server = format!("{}:{}", ip, port);
 | |
|     log::info!("args server: [{}]", server);
 | |
|     let mount_point = matches.value_of("m").unwrap().to_string();
 | |
|     log::info!("args mount_point: [{}]", mount_point);
 | |
|     let cache_life = match matches.value_of("c") {
 | |
|         Some(cache_life) => cache_life.parse().unwrap(),
 | |
|         None => DEFAULT_CACHE_LIFE,
 | |
|     };
 | |
|     log::info!("args cache invalid time: [{}ms]", cache_life);
 | |
|     (server, mount_point, cache_life)
 | |
| }
 | |
| 
 | |
| fn umount_point(mount: &String) {
 | |
|     log::info!("fusermount -u {}", mount);
 | |
|     let output = Command::new("fusermount")
 | |
|        .arg("-u")
 | |
|        .arg(mount)
 | |
|        .output()
 | |
|        .expect("umount command execaute failed");
 | |
|     if output.status.success(){
 | |
|         log::info!("{}", format!("umount {} success", mount));
 | |
|     } else {
 | |
|         if let Some(code) = output.status.code() {
 | |
|             let stdout = String::from_utf8_lossy(&output.stdout).to_string();
 | |
|             let stderr:String = String::from_utf8_lossy(&output.stderr).to_string();
 | |
|             // 不存在的话就不需要umount了
 | |
|             if stderr.contains("not found") {
 | |
|                 log::info!("{}", format!("not need umount {}", mount));
 | |
|                 return;
 | |
|             }
 | |
|             log::error!("{}", format!("umount {} fail: {}", mount, code));
 | |
|             log::error!("{} {}", stdout, stderr);
 | |
|         } else {
 | |
|             log::error!("{}", format!("umount {} interrupted", mount));
 | |
|         }
 | |
|     }
 | |
|     
 | |
| }
 | |
| fn set_cleanup(mount: &String) {
 | |
|     // 捕获 SIGCHLD 信号
 | |
|     log::info!("waiting for SIGCHLD SIGINT signal");
 | |
|     let mut signals = Signals::new(&[SIGCHLD, SIGINT]).expect("Error creating signals");
 | |
|     for _ in signals.forever() {
 | |
|         log::info!("Catch SIGCHLD||SIGINT, exit....");
 | |
|         umount_point(&mount);
 | |
|         break;
 | |
|     }
 | |
| }
 | |
| 
 | |
| #[tokio::main]
 | |
| async fn main() -> Result<(), Box<dyn std::error::Error>> {
 | |
|     env_logger::init();
 | |
|     let (server, mount_point, cache_life) = param_parser();
 | |
|     let lws_ins = match LwsVfsIns::new(&server, cache_life).await {
 | |
|         Ok(ins) => ins,
 | |
|         Err(e) => {
 | |
|             log::error!("Error creating lws server instance: {:?}", e);
 | |
|             return Err(e);
 | |
|         }
 | |
|     };
 | |
|     // 尝试卸载此前的挂载点
 | |
|     umount_point(&mount_point);
 | |
|     child_process(&mount_point);
 | |
|     log::info!("lws client instance created");
 | |
|     match lws_ins.hello().await {
 | |
|         Err(e) => {
 | |
|             log::error!("lws client instance hello err {:?}", e);
 | |
|             return Err(e);
 | |
|         }
 | |
|         _ => {}
 | |
|     }
 | |
|     log::info!("start mount process");
 | |
|     let handle = thread::spawn(move || match LwsVfsIns::mount(&mount_point, lws_ins) {
 | |
|         Ok(_) => Ok::<i32, String>(0),
 | |
|         Err(e) => {
 | |
|             log::error!("mount err {:?}", e);
 | |
|             Ok::<i32, String>(-1)
 | |
|         }
 | |
|     });
 | |
|     match handle.join() {
 | |
|         Ok(_) => Ok(()),
 | |
|         Err(e) => {
 | |
|             log::error!("mount thread start err {:?}", e);
 | |
|             Err(Box::new(std::io::Error::new(
 | |
|                 std::io::ErrorKind::Other,
 | |
|                 "mount fail",
 | |
|             )))
 | |
|         }
 | |
|     }
 | |
| 
 | |
| }
 | |
| 
 | |
| fn child_process(mount: &String) {
 | |
|     match unsafe{fork()} {
 | |
|         Ok(ForkResult::Parent { child: _ }) => {
 | |
|             // no thing
 | |
|         }
 | |
|         Ok(ForkResult::Child) => {
 | |
|             let pid = getpid();
 | |
|             let ppid = getppid();
 | |
|             log::info!("parent: {}, child: {}", ppid, pid);
 | |
|             set_cleanup(mount);
 | |
|         }
 | |
|         Err(err) => {
 | |
|             log::error!("fork fail: {}", err);
 | |
|             process::exit(1);
 | |
|         }
 | |
|     }
 | |
|     
 | |
|     
 | |
| }
 |