Skip to main content

resource_tracker/collector/
network.rs

1use crate::metrics::NetworkMetrics;
2use procfs::net::dev_status;
3use std::collections::HashMap;
4use std::time::Instant;
5
6type Result<T> = std::result::Result<T, Box<dyn std::error::Error>>;
7
8// ---------------------------------------------------------------------------
9// sysfs helpers
10// ---------------------------------------------------------------------------
11
/// Read a sysfs attribute file and return its contents with surrounding
/// whitespace removed.
///
/// Yields `None` when the file cannot be read or when nothing but whitespace
/// remains after trimming.
fn sysfs_read(path: &str) -> Option<String> {
    let raw = std::fs::read_to_string(path).ok()?;
    let trimmed = raw.trim();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed.to_owned())
    }
}
18
19fn net_attr(iface: &str, attr: &str) -> Option<String> {
20    sysfs_read(&format!("/sys/class/net/{}/{}", iface, attr))
21}
22
23// ---------------------------------------------------------------------------
24// Hardware identity - read once at startup
25// ---------------------------------------------------------------------------
26
/// Static identity of a network interface as read from sysfs.
///
/// `Debug`/`PartialEq` make cached entries inspectable and comparable in
/// diagnostics and tests; `Default` is the "nothing known" identity.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
struct InterfaceInfo {
    /// Trimmed contents of `/sys/class/net/<if>/address`; `None` if the
    /// attribute is unreadable or empty.
    mac_address: Option<String>,
    /// Kernel driver name resolved from `/sys/class/net/<if>/device/driver`
    /// symlink basename, e.g. "igc", "virtio_net", "e1000e".
    driver: Option<String>,
}
34
35fn read_interface_info(iface: &str) -> InterfaceInfo {
36    let mac_address = net_attr(iface, "address");
37
38    // The driver symlink points to something like
39    // ../../../../bus/pci/drivers/igc - we just want the basename.
40    let driver = std::fs::read_link(format!("/sys/class/net/{}/device/driver", iface))
41        .ok()
42        .and_then(|p| p.file_name().map(|n| n.to_string_lossy().to_string()));
43
44    InterfaceInfo {
45        mac_address,
46        driver,
47    }
48}
49
50/// Discover all non-loopback interfaces and cache their static identity.
51/// Called once in NetworkCollector::new().
52fn discover_interfaces() -> HashMap<String, InterfaceInfo> {
53    let Ok(entries) = std::fs::read_dir("/sys/class/net") else {
54        return HashMap::new();
55    };
56    entries
57        .flatten()
58        .filter_map(|e| {
59            let name = e.file_name().to_string_lossy().to_string();
60            if name == "lo" {
61                return None;
62            }
63            let info = read_interface_info(&name);
64            Some((name, info))
65        })
66        .collect()
67}
68
69// ---------------------------------------------------------------------------
70// Dynamic link state - polled each interval
71// ---------------------------------------------------------------------------
72
73fn read_operstate(iface: &str) -> Option<String> {
74    net_attr(iface, "operstate")
75}
76
77fn read_speed_mbps(iface: &str) -> Option<i64> {
78    net_attr(iface, "speed")?.parse().ok()
79}
80
81fn read_mtu(iface: &str) -> Option<u32> {
82    net_attr(iface, "mtu")?.parse().ok()
83}
84
85// ---------------------------------------------------------------------------
86// Delta snapshot + Collector
87// ---------------------------------------------------------------------------
88
/// Cumulative byte counters for every interface at one sampling moment,
/// kept so the next sample can turn counter deltas into rates.
struct Snapshot {
    /// Received-byte counter per interface name.
    rx_bytes: HashMap<String, u64>,
    /// Transmitted-byte counter per interface name.
    tx_bytes: HashMap<String, u64>,
    /// When the counters were sampled.
    instant: Instant,
}
94
95pub struct NetworkCollector {
96    /// Static hardware identity, cached once in new().
97    iface_cache: HashMap<String, InterfaceInfo>,
98    prev: Option<Snapshot>,
99}
100
101impl NetworkCollector {
102    pub fn new() -> Self {
103        Self {
104            iface_cache: discover_interfaces(),
105            prev: None,
106        }
107    }
108
109    pub fn collect(&mut self) -> Result<Vec<NetworkMetrics>> {
110        let devs = dev_status()?;
111        let now = Instant::now();
112
113        let rx_bytes: HashMap<String, u64> = devs
114            .iter()
115            .map(|(name, s)| (name.clone(), s.recv_bytes))
116            .collect();
117        let tx_bytes: HashMap<String, u64> = devs
118            .iter()
119            .map(|(name, s)| (name.clone(), s.sent_bytes))
120            .collect();
121
122        let mut metrics: Vec<NetworkMetrics> = devs
123            .keys()
124            .filter(|n| *n != "lo")
125            .map(|name| {
126                let info = self.iface_cache.get(name);
127
128                let (rx_bps, tx_bps) = match &self.prev {
129                    None => (0.0, 0.0),
130                    Some(prev) => {
131                        let secs = (now - prev.instant).as_secs_f64().max(0.001);
132                        let rx = rx_bytes[name];
133                        let tx = tx_bytes[name];
134                        let prx = prev.rx_bytes.get(name).copied().unwrap_or(rx);
135                        let ptx = prev.tx_bytes.get(name).copied().unwrap_or(tx);
136                        (
137                            rx.saturating_sub(prx) as f64 / secs,
138                            tx.saturating_sub(ptx) as f64 / secs,
139                        )
140                    }
141                };
142
143                NetworkMetrics {
144                    interface: name.clone(),
145                    mac_address: info.and_then(|i| i.mac_address.clone()),
146                    driver: info.and_then(|i| i.driver.clone()),
147                    operstate: read_operstate(name),
148                    speed_mbps: read_speed_mbps(name),
149                    mtu: read_mtu(name),
150                    rx_bytes_per_sec: rx_bps,
151                    tx_bytes_per_sec: tx_bps,
152                    rx_bytes_total: rx_bytes[name],
153                    tx_bytes_total: tx_bytes[name],
154                }
155            })
156            .collect();
157
158        metrics.sort_by(|a, b| a.interface.cmp(&b.interface));
159        self.prev = Some(Snapshot {
160            instant: now,
161            rx_bytes,
162            tx_bytes,
163        });
164        Ok(metrics)
165    }
166}
167
168// ---------------------------------------------------------------------------
169// Unit tests
170// ---------------------------------------------------------------------------
171
#[cfg(test)]
mod tests {
    use super::*;

    // T-NET-01: first collect() returns Ok; all rates are 0.0 (no prior snapshot).
    #[test]
    fn test_network_first_collect_rates_zero() {
        let mut nc = NetworkCollector::new();
        let sample = nc.collect().expect("first collect() failed");
        for m in &sample {
            assert_eq!(
                m.rx_bytes_per_sec, 0.0,
                "rx_bytes_per_sec must be 0.0 on first collect for {}",
                m.interface
            );
            assert_eq!(
                m.tx_bytes_per_sec, 0.0,
                "tx_bytes_per_sec must be 0.0 on first collect for {}",
                m.interface
            );
        }
    }

    // T-NET-02: second collect() returns Ok; all rates are >= 0.0.
    #[test]
    fn test_network_second_collect_rates_nonneg() {
        let mut nc = NetworkCollector::new();
        nc.collect().expect("first collect() failed");
        let sample = nc.collect().expect("second collect() failed");
        for m in &sample {
            assert!(
                m.rx_bytes_per_sec >= 0.0,
                "rx_bytes_per_sec must be >= 0.0 for {}",
                m.interface
            );
            assert!(
                m.tx_bytes_per_sec >= 0.0,
                "tx_bytes_per_sec must be >= 0.0 for {}",
                m.interface
            );
        }
    }

    // T-NET-03: loopback ("lo") is excluded; results are sorted alphabetically.
    #[test]
    fn test_network_no_loopback_sorted() {
        let mut nc = NetworkCollector::new();
        let sample = nc.collect().expect("collect() failed");

        let mut names: Vec<&str> = Vec::with_capacity(sample.len());
        for m in &sample {
            assert_ne!(m.interface, "lo", "loopback must not appear in results");
            names.push(m.interface.as_str());
        }

        let mut sorted = names.clone();
        sorted.sort_unstable();
        assert_eq!(names, sorted, "interfaces must be sorted alphabetically");
    }

    // T-NET-04: cumulative totals are non-decreasing between two consecutive calls.
    #[test]
    fn test_network_totals_nondecreasing() {
        let mut nc = NetworkCollector::new();
        let before = nc.collect().expect("first collect() failed");
        let after = nc.collect().expect("second collect() failed");

        let mut baseline: std::collections::HashMap<&str, (u64, u64)> =
            std::collections::HashMap::with_capacity(before.len());
        for m in &before {
            baseline.insert(m.interface.as_str(), (m.rx_bytes_total, m.tx_bytes_total));
        }

        for m in &after {
            // Interfaces that appeared between the two samples have no baseline.
            let Some(&(prev_rx, prev_tx)) = baseline.get(m.interface.as_str()) else {
                continue;
            };
            assert!(
                m.rx_bytes_total >= prev_rx,
                "rx_bytes_total decreased for {}: {} < {}",
                m.interface,
                m.rx_bytes_total,
                prev_rx
            );
            assert!(
                m.tx_bytes_total >= prev_tx,
                "tx_bytes_total decreased for {}: {} < {}",
                m.interface,
                m.tx_bytes_total,
                prev_tx
            );
        }
    }
}