Rust WebSocket implementation

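# Cargo.toml [dependencies] excerpt. The code below also uses tokio, serde,
# serde_json and deadpool-tiberius (and presumably tower / tower-http for
# `ServiceBuilder` and `TraceLayer`), so those must be listed as well.
axum = { version = "0.7.5", features = ["ws"] }
hyper-util = "0.1.1"
hyper = { version = "1", features = ["full"] }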

main.rs

#![allow(unused)]
fn main() {
    use axum::extract::Request;
    use axum::middleware::Next;
    use axum::response::Response;
    use hyper::StatusCode;
    use std::net::IpAddr;
    pub async fn log_ip_middleware(
        request: Request<axum::body::Body>,
        next: Next,
    ) -> Result<Response, StatusCode> {
        let ip_port = request
            .extensions()
            .get::<axum::extract::ConnectInfo<std::net::SocketAddr>>()
            .map(|ci| (ci.0.ip(), ci.0.port()))
            .unwrap_or((IpAddr::from([0, 0, 0, 0]), 80));
        info!("request_ip_port: {}:{}", ip_port.0, ip_port.1);
        return Ok(next.run(request).await);
    }

    let app = Router::new()
        .route("/get_file/*filepath", get(get_file))
        .route("/ws", get(ws_handler))
        .layer(cors_layer)
        .layer(axum::middleware::from_fn(log_ip_middleware))
        .layer(axum::Extension(pool))
        .layer(DefaultBodyLimit::disable())
        .layer(
            ServiceBuilder::new()
                .layer(HandleErrorLayer::new(handle_error))
                .load_shed()
                .concurrency_limit(1024)
                .timeout(Duration::from_secs(100))
                .layer(TraceLayer::new_for_http()),
        );
    let listener: tokio::net::TcpListener =
        tokio::net::TcpListener::bind(format!("{}:{}", SERVER_HOST.lock(), SERVER_PORT.lock()))
            .await
            .unwrap();
    info!("listening on: http://{}", listener.local_addr().unwrap());
    axum::serve(listener, app).await.unwrap();
}

websocket.rs

#![allow(unused)]
fn main() {
use axum::extract::ws::Message;
use axum::extract::ws::WebSocket;
use axum::extract::WebSocketUpgrade;
use axum::response::Response;
use axum::Extension;
use deadpool_tiberius::deadpool::managed::Object;
use deadpool_tiberius::Manager;
use deadpool_tiberius::Pool;
use serde::Deserialize;
use serde::Serialize;

/// Axum handler for the `/ws` route.
///
/// Upgrades the HTTP connection to a WebSocket and hands the socket, together with
/// the database pool taken from the request extensions, to `handle_ws`.
pub async fn ws_handler(
    Extension(pool): Extension<Pool>,
    ws_upgrade: WebSocketUpgrade,
) -> Response {
    // The extractor already yields an owned `pool`, so it can be moved straight into
    // the upgrade callback; cloning it first only created a second handle.
    ws_upgrade.on_upgrade(move |ws| handle_ws(ws, pool))
}

/// JSON request sent by the browser client over the WebSocket.
///
/// `msg` names the requested command and `arg` carries its argument.
/// `Default` produces two empty strings.
#[derive(Debug, Clone, PartialEq)]
#[derive(Serialize, Deserialize, Default)]
pub struct WsGetMsg {
    /// Command name (for example `"search_news_list"`).
    msg: String,
    /// Argument for the command.
    arg: String,
}
/// Serves one upgraded WebSocket connection until the peer disconnects.
///
/// Every text frame is parsed as JSON into a [`WsGetMsg`]. When its `msg` is
/// `"search_news_list"`, the `arg` string is sent back to the client as a text
/// frame. Non-text frames, frames that are not valid `WsGetMsg` JSON, and any other
/// `msg` value are skipped.
///
/// `pool` is not used yet; it is accepted so the handler can later check out a
/// database connection (e.g. `pool.get().await`).
pub async fn handle_ws(
    mut socket: WebSocket,
    pool: deadpool_tiberius::deadpool::managed::Pool<deadpool_tiberius::Manager>,
) {
    while let Some(frame) = socket.recv().await {
        // A receive error means the connection is unusable; stop instead of silently
        // ignoring it and polling the broken socket again.
        let frame = match frame {
            Ok(frame) => frame,
            Err(err) => {
                println!("Web Socket receive error: {err}");
                break;
            }
        };

        let Message::Text(text) = frame else {
            continue;
        };

        let request: WsGetMsg = match serde_json::from_str(&text) {
            Ok(request) => request,
            Err(err) => {
                println!("Web Socket ignoring malformed message: {err}");
                continue;
            }
        };

        if request.msg != "search_news_list" {
            continue;
        }

        // NOTE(review): `arg` is client-controlled. If this is changed to send back file
        // contents (e.g. `std::fs::read(&request.arg)`), validate the path first —
        // otherwise any file the server process can read could be downloaded.
        if socket.send(Message::Text(request.arg)).await.is_err() {
            // Sending failed, so the peer is gone: stop serving this connection.
            println!("Web Socket Closed");
            return;
        }
    }
    println!("Web Socket Closed");
}

}

websocket.html

<!DOCTYPE html>
<html lang="en">

<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>Document</title>
</head>

<body>

    <label for="path">Path</label>
    <!-- `id` ties the input to its label; the script selects it by the `path` class. -->
    <input type="text" id="path" class="path" style="width: 80%;">
    <div class="app">
        <button class="btn">get_file_content</button>
    </div>
    <script>
        // Path most recently requested by the user; used to name the downloaded file.
        let path_val = "";

        // Save `blob` as `filename` by clicking a temporary <a download> link.
        function downloadBlob(blob, filename) {
            const objectUrl = URL.createObjectURL(blob);
            const anchor = Object.assign(document.createElement('a'), {
                href: objectUrl,
                download: filename,
            });
            document.body.append(anchor);
            anchor.click();
            anchor.remove();
            URL.revokeObjectURL(objectUrl);
        }
        // WebSocket connection to the Rust server's `/ws` route.
        const ws = new WebSocket('ws://127.0.0.1:15000/ws');
        ws.onopen = () => {
            console.log('WebSocket connection established');
        };

        // Every reply is saved as a download named after the last segment of the
        // path the user requested.
        ws.onmessage = (event) => {
            // Accept both Windows (`\`) and POSIX (`/`) separators; fall back to a
            // generic name when the path is empty.
            const file_name = path_val.split(/[\\/]/).pop() || 'download';
            console.log(file_name);
            // Binary frames arrive as a Blob (the default `binaryType`), but text frames
            // arrive as a string, which `URL.createObjectURL` rejects — wrap those.
            const blob = event.data instanceof Blob ? event.data : new Blob([event.data]);
            downloadBlob(blob, file_name);
        };

        ws.onclose = () => {
            console.log('WebSocket connection closed');
        };

        // Client-side mirror of the server's `WsGetMsg` struct: `msg` selects the
        // command and `arg` carries its argument.
        class WsGetMsg {
            constructor(msg = "", arg = "") {
                Object.assign(this, { msg, arg });
            }

            // Called by JSON.stringify; emits exactly the `msg` and `arg` fields.
            toJSON() {
                const { msg, arg } = this;
                return { msg, arg };
            }

            // Rebuild an instance from an already-parsed JSON object.
            static fromJSON({ msg, arg }) {
                return new WsGetMsg(msg, arg);
            }
        }

        const btn = document.querySelector('.btn');
        btn.onclick = () => {
            // Read the typed text. Storing the <input> element itself would serialize
            // to `{}`, which the server cannot parse as a string `arg`.
            path_val = document.querySelector('.path').value;
            // `send` throws while the socket is still connecting and silently drops
            // data once it is closed, so only send on an open connection.
            if (ws.readyState !== WebSocket.OPEN) {
                console.log('WebSocket is not open; request not sent');
                return;
            }
            const msg = new WsGetMsg("search_news_list", path_val);
            ws.send(JSON.stringify(msg));
        };
    </script>
</body>

</html>