191 lines
6.0 KiB
Rust
Executable File
191 lines
6.0 KiB
Rust
Executable File
use clap::{App, Arg};
|
|
use lws_client::LwsVfsIns;
|
|
use std::{process, thread};
|
|
use std::process::Command;
|
|
extern crate log;
|
|
use nix::unistd::{fork, ForkResult, getpid, getppid};
|
|
use signal_hook::consts::{SIGCHLD, SIGINT};
|
|
use signal_hook::iterator::Signals;
|
|
use std::env;
|
|
|
|
/// TCP port used when `-s/--server` omits the port (or gives an empty one).
const DEFAULT_PORT: u16 = 33_444;

/// Default file-info cache invalidation time, in milliseconds, used when
/// `-c/--cache` is not given.
const DEFAULT_CACHE_LIFE: u32 = 10_000;
|
|
/// Returns the remote address of the current SSH session.
///
/// Reads the `SSH_CLIENT` environment variable and returns everything before
/// its first space (presumably the client IP, since OpenSSH formats it as
/// `"<ip> <client port> <server port>"`). If there is no space, the whole
/// value is returned.
///
/// # Panics
/// Panics if `SSH_CLIENT` is unset or not valid Unicode.
fn get_ssh_clinet() -> String {
    let ssh_client = env::var("SSH_CLIENT").expect("only support auto get ssh connection");
    match ssh_client.split_once(' ') {
        Some((head, _rest)) => head.to_owned(),
        None => ssh_client,
    }
}
|
|
|
|
fn param_parser() -> (String, String, u32) {
|
|
let matches = App::new("lws_client")
|
|
.version("1.0")
|
|
.author("Ekko.bao")
|
|
.about("linux&windows shared filesystem")
|
|
.arg(
|
|
Arg::with_name("s")
|
|
.short('s')
|
|
.long("server")
|
|
.takes_value(true)
|
|
.min_values(1)
|
|
.max_values(1)
|
|
.value_name("server")
|
|
.help(
|
|
format!(
|
|
"server addr: [ip[:port]], default is [sshclient:{}]",
|
|
DEFAULT_PORT
|
|
)
|
|
.as_str(),
|
|
),
|
|
)
|
|
.arg(
|
|
Arg::with_name("m")
|
|
.short('m')
|
|
.long("mount")
|
|
.takes_value(true)
|
|
.min_values(1)
|
|
.max_values(1)
|
|
.value_name("mount point")
|
|
.required(true)
|
|
.help("mount point, eg: ~/mnt"),
|
|
)
|
|
.arg(
|
|
Arg::with_name("c")
|
|
.short('c')
|
|
.long("cache")
|
|
.takes_value(true)
|
|
.min_values(1)
|
|
.max_values(1)
|
|
.value_name("cache invalid time")
|
|
.help("file info cache invalid time, unit: ms, defualt is 10_000ms"),
|
|
)
|
|
.get_matches();
|
|
let (ip, port) = match matches.value_of("s") {
|
|
Some(server) => {
|
|
let split = server.split(":").collect::<Vec<&str>>();
|
|
if split.len() == 2 {
|
|
(split[0].to_string(), split[1].to_string())
|
|
} else {
|
|
(split[0].to_string(), DEFAULT_PORT.to_string())
|
|
}
|
|
}
|
|
None => (String::new(), String::new()),
|
|
};
|
|
let ip = if ip.len() == 0 { get_ssh_clinet() } else { ip };
|
|
let port = if port.len() == 0 {
|
|
DEFAULT_PORT.to_string()
|
|
} else {
|
|
port
|
|
};
|
|
let server = format!("{}:{}", ip, port);
|
|
log::info!("args server: [{}]", server);
|
|
let mount_point = matches.value_of("m").unwrap().to_string();
|
|
log::info!("args mount_point: [{}]", mount_point);
|
|
let cache_life = match matches.value_of("c") {
|
|
Some(cache_life) => cache_life.parse().unwrap(),
|
|
None => DEFAULT_CACHE_LIFE,
|
|
};
|
|
log::info!("args cache invalid time: [{}ms]", cache_life);
|
|
(server, mount_point, cache_life)
|
|
}
|
|
|
|
fn umount_point(mount: &String) {
|
|
log::info!("fusermount -u {}", mount);
|
|
let output = Command::new("fusermount")
|
|
.arg("-u")
|
|
.arg(mount)
|
|
.output()
|
|
.expect("umount command execaute failed");
|
|
if output.status.success(){
|
|
log::info!("{}", format!("umount {} success", mount));
|
|
} else {
|
|
if let Some(code) = output.status.code() {
|
|
let stdout = String::from_utf8_lossy(&output.stdout).to_string();
|
|
let stderr:String = String::from_utf8_lossy(&output.stderr).to_string();
|
|
// 不存在的话就不需要umount了
|
|
if stderr.contains("not found") {
|
|
log::info!("{}", format!("not need umount {}", mount));
|
|
return;
|
|
}
|
|
log::error!("{}", format!("umount {} fail: {}", mount, code));
|
|
log::error!("{} {}", stdout, stderr);
|
|
} else {
|
|
log::error!("{}", format!("umount {} interrupted", mount));
|
|
}
|
|
}
|
|
|
|
}
|
|
fn set_cleanup(mount: &String) {
|
|
// 捕获 SIGCHLD 信号
|
|
log::info!("waiting for SIGCHLD SIGINT signal");
|
|
let mut signals = Signals::new(&[SIGCHLD, SIGINT]).expect("Error creating signals");
|
|
for _ in signals.forever() {
|
|
log::info!("Catch SIGCHLD||SIGINT, exit....");
|
|
umount_point(&mount);
|
|
break;
|
|
}
|
|
}
|
|
|
|
#[tokio::main]
|
|
async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|
env_logger::init();
|
|
let (server, mount_point, cache_life) = param_parser();
|
|
let lws_ins = match LwsVfsIns::new(&server, cache_life).await {
|
|
Ok(ins) => ins,
|
|
Err(e) => {
|
|
log::error!("Error creating lws server instance: {:?}", e);
|
|
return Err(e);
|
|
}
|
|
};
|
|
// 尝试卸载此前的挂载点
|
|
umount_point(&mount_point);
|
|
child_process(&mount_point);
|
|
log::info!("lws client instance created");
|
|
match lws_ins.hello().await {
|
|
Err(e) => {
|
|
log::error!("lws client instance hello err {:?}", e);
|
|
return Err(e);
|
|
}
|
|
_ => {}
|
|
}
|
|
log::info!("start mount process");
|
|
let handle = thread::spawn(move || match LwsVfsIns::mount(&mount_point, lws_ins) {
|
|
Ok(_) => Ok::<i32, String>(0),
|
|
Err(e) => {
|
|
log::error!("mount err {:?}", e);
|
|
Ok::<i32, String>(-1)
|
|
}
|
|
});
|
|
match handle.join() {
|
|
Ok(_) => Ok(()),
|
|
Err(e) => {
|
|
log::error!("mount thread start err {:?}", e);
|
|
Err(Box::new(std::io::Error::new(
|
|
std::io::ErrorKind::Other,
|
|
"mount fail",
|
|
)))
|
|
}
|
|
}
|
|
|
|
}
|
|
|
|
fn child_process(mount: &String) {
|
|
match unsafe{fork()} {
|
|
Ok(ForkResult::Parent { child: _ }) => {
|
|
// no thing
|
|
}
|
|
Ok(ForkResult::Child) => {
|
|
let pid = getpid();
|
|
let ppid = getppid();
|
|
log::info!("parent: {}, child: {}", ppid, pid);
|
|
set_cleanup(mount);
|
|
}
|
|
Err(err) => {
|
|
log::error!("fork fail: {}", err);
|
|
process::exit(1);
|
|
}
|
|
}
|
|
|
|
|
|
}
|