Cogs and Levers A blog full of technical stuff

Writing a Key Value Server in Rust

Introduction

In today’s post, we’ll build a simple key-value server, and we’ll do it iteratively. We’ll start with a simple in-memory implementation and then add safety, concurrency, and networking as we go.

Implementation

Now we’ll get started with our iterations. The finished code will be available at the end of this post.

Baseline

All of our implementations will deal with a KeyValueStore struct. This struct will hold all of the variables that we want to keep track of in our server.

use std::collections::HashMap;

struct KeyValueStore {
    data: HashMap<String, String>,
}

We define data as the in-memory representation of our database. We use String keys and store String values.

Our implementation is very basic. All we’re really doing is delegating to the functionality that HashMap already provides.

impl KeyValueStore {
    fn new() -> Self {
        Self {
            data: HashMap::new(),
        }
    }

    fn insert(&mut self, key: String, value: String) {
        self.data.insert(key, value);
    }

    fn get(&self, key: &str) -> Option<&String> {
        self.data.get(key)
    }

    fn delete(&mut self, key: &str) {
        self.data.remove(key);
    }
}

This is a pretty decent starting point. We can use our KeyValueStore in some basic tests:

fn main() {
    let mut store = KeyValueStore::new();
    store.insert("key1".to_string(), "value1".to_string());
    println!("{:?}", store.get("key1"));
    store.delete("key1");
    println!("{:?}", store.get("key1"));
}

Variants

String is pretty limiting as far as the value side is concerned. We can upgrade this to support the data types we’ll find useful via an enum:

#[derive(Debug, Clone)]
enum Value {
    String(String),
    Integer(i64),
    Float(f64),
    Boolean(bool),
    Binary(Vec<u8>),
    // Add more variants as needed
}

We can swap out the value side of our data member now, too.

struct KeyValueStore {
    data: HashMap<String, Value>,
}

The implementation simply swaps the String for Value:

impl KeyValueStore {
    fn new() -> Self {
        Self {
            data: HashMap::new(),
        }
    }

    fn insert(&mut self, key: String, value: Value) {
        self.data.insert(key, value);
    }

    fn get(&self, key: &str) -> Option<&Value> {
        self.data.get(key)
    }

    fn delete(&mut self, key: &str) {
        self.data.remove(key);
    }
}

We’re now able to store not only strings, but also integers, floats, booleans, and binary data. This makes our key value store a lot more versatile.
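
As a quick sketch (reusing the mutable store from the earlier main), storing a few of the new variants looks like this:

store.insert("name".to_string(), Value::String("rust".to_string()));
store.insert("count".to_string(), Value::Integer(42));
store.insert("ratio".to_string(), Value::Float(0.5));
store.insert("enabled".to_string(), Value::Boolean(true));

println!("{:?}", store.get("count")); // Some(Integer(42))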

Thread Safety

We will have multiple threads of execution trying to perform actions on this structure at the same time, so we will add some thread safety to the process now. Wrapping data in Arc will give us a thread safe, reference counting pointer. We’re also going to need to lock this data structure for reading and for writing. We can use RwLock to take care of that for us.

We update our data structure to include these new types:

use std::sync::{Arc, RwLock};

struct KeyValueStore {
    data: Arc<RwLock<HashMap<String, Value>>>,
}

Now our implementation functions need to change to work with these new structures. We can keep the structure of functions the same though.

impl KeyValueStore {
    fn new() -> Self {
        Self {
            data: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    fn insert(&self, key: String, value: Value) {
        let mut locked = self.data.write().unwrap();
        locked.insert(key, value);
    }

    fn get(&self, key: &str) -> Option<Value> {
        let locked = self.data.read().unwrap();
        locked.get(key).cloned()
    }

    fn delete(&self, key: &str) {
        let mut locked = self.data.write().unwrap();
        locked.remove(key);
    }
}

These functions are now thread-safe, which means calling code can be multithreaded and we can guarantee that our data structure will be accessed consistently.

use std::thread;

fn main() {
    let store = Arc::new(KeyValueStore::new());

    // Create a vector to hold thread handles
    let mut handles = vec![];

    // Spawn threads to perform inserts
    for i in 0..5 {
        let store = Arc::clone(&store);
        let handle = thread::spawn(move || {
            let key = format!("key{}", i);
            let value = Value::Integer(i * 10);
            store.insert(key.clone(), value);
            println!("Thread {} inserted: {}", i, key);
        });
        handles.push(handle);
    }

    // Spawn threads to read values
    for i in 0..5 {
        let store = Arc::clone(&store);
        let handle = thread::spawn(move || {
            let key = format!("key{}", i);
            if let Some(value) = store.get(&key) {
                println!("Thread {} read: {} -> {:?}", i, key, value);
            } else {
                println!("Thread {} could not find: {}", i, key);
            }
        });
        handles.push(handle);
    }

    // Spawn threads to delete keys
    for i in 0..5 {
        let store = Arc::clone(&store);
        let handle = thread::spawn(move || {
            let key = format!("key{}", i);
            store.delete(&key);
            println!("Thread {} deleted: {}", i, key);
        });
        handles.push(handle);
    }

    // Wait for all threads to complete
    for handle in handles {
        handle.join().unwrap();
    }

    println!("Final state of the store: {:?}", store.data.read().unwrap());
}

Error handling

You can see that we’re using unwrap in the implementation functions, which might be OK for tests or short scripts. If we expect to run this code in production, we’d be best served replacing these with actual error handling.

In order to do that, we need to define our error domain first. We create an enum called StoreError. As we fill out our implementation, we’ll run into a number of different error cases. We’ll use StoreError to centralise all of these errors so we can express them clearly.

#[derive(Debug)]
enum StoreError {
    LockError(String),
    KeyNotFound(String),
}

use std::sync::PoisonError;

impl<T> From<PoisonError<T>> for StoreError {
    fn from(err: PoisonError<T>) -> Self {
        StoreError::LockError(format!("Lock poisoned: {}", err))
    }
}

We’ve implemented From<PoisonError<T>> for our StoreError because PoisonError is the error returned when acquiring a lock that has been poisoned, which happens when another thread panics while holding that lock.

Our insert, get, and delete methods now need an upgrade. We’ll be returning Result<T, E> values from our functions now to accommodate potential failures.

fn insert(&self, key: String, value: Value) -> Result<(), StoreError> {
    let mut locked = self.data.write()?;
    locked.insert(key, value);
    Ok(())
}

fn get(&self, key: &str) -> Result<Option<Value>, StoreError> {
    let locked = self.data.read()?;
    Ok(locked.get(key).cloned()) // Clone the value to return an owned copy
}

fn delete(&self, key: &str) -> Result<(), StoreError> {
    let mut locked = self.data.write()?;
    if locked.remove(key).is_none() {
        return Err(StoreError::KeyNotFound(key.to_string()));
    }
    Ok(())
}

We’ve removed the use of unwrap now, swapping to the ? operator. Any lock failure is converted into a StoreError and bubbled up, so calling code can actually handle it.
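
As a small usage sketch, calling code can now match on the specific error variants:

match store.delete("missing") {
    Ok(()) => println!("deleted"),
    Err(StoreError::KeyNotFound(key)) => println!("no such key: {}", key),
    Err(e) => eprintln!("store error: {:?}", e),
}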

Using the File System

We need to be able to persist the state of our key value store out to disk for durability. In order to do this, we need to keep track of where we’ll write the file. We add a file_path member to our structure:

struct KeyValueStore {
    data: Arc<RwLock<HashMap<String, Value>>>,
    file_path: Option<String>,
}

impl KeyValueStore {
    fn new(file_path: Option<String>) -> Self {
        Self {
            data: Arc::new(RwLock::new(HashMap::new())),
            file_path,
        }
    }
}

Starting this implementation out simply, we’ll just write load and save functions that we can call at any time. Before we do this, we need some extra dependencies for serialisation:

[dependencies]
serde = { version = "1.0.217", features = ["derive"] }
serde_json = "1.0.137"

This will allow us to reduce our internal state to JSON.
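
One thing to note: for serde_json to be able to serialise our HashMap<String, Value>, the Value enum needs to derive serde’s Serialize and Deserialize traits:

use serde::{Serialize, Deserialize};

#[derive(Debug, Clone, Serialize, Deserialize)]
enum Value {
    String(String),
    Integer(i64),
    Float(f64),
    Boolean(bool),
    Binary(Vec<u8>),
}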

Loading the database off disk

use std::fs;
use std::io::ErrorKind;

/// Load the state from a file
fn load(&self) -> Result<(), StoreError> {
    if let Some(ref path) = self.file_path {
        match fs::read_to_string(path) {
            Ok(contents) => {
                let deserialized: HashMap<String, Value> = serde_json::from_str(&contents)?;
                let mut locked = self.data.write()?;
                *locked = deserialized; // Replace the current state with the loaded one
                Ok(())
            }
            Err(e) if e.kind() == ErrorKind::NotFound => {
                // File doesn't exist, just return Ok (no data to load)
                Ok(())
            }
            Err(e) => Err(e.into()),
        }
    } else {
        Err(StoreError::IoError("File path not set".to_string()))
    }
}

We need to make sure that a file_path was specified. We read everything from the file into contents as one big string. Using serde_json::from_str we can turn that contents into the deserialised representation. From there, we simply swap out the underlying data.

We’ve got some new error variants to deal with here: IoError for file system failures and SerdeError for serialisation failures.

#[derive(Debug)]
enum StoreError {
    LockError(String),
    KeyNotFound(String),
    IoError(String),
    SerdeError(String),
}
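
The ? operators in load and save rely on conversions from the underlying error types into StoreError; a minimal sketch of those From implementations:

use std::io;

impl From<io::Error> for StoreError {
    fn from(err: io::Error) -> Self {
        StoreError::IoError(err.to_string())
    }
}

impl From<serde_json::Error> for StoreError {
    fn from(err: serde_json::Error) -> Self {
        StoreError::SerdeError(err.to_string())
    }
}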

These variants also get used in our save implementation, which looks like this:

/// Save the current state to a file
fn save(&self) -> Result<(), StoreError> {
    if let Some(ref path) = self.file_path {
        let locked = self.data.read()?;
        let serialized = serde_json::to_string(&*locked)?;
        fs::write(path, serialized)?;
        Ok(())
    } else {
        Err(StoreError::IoError("File path not set".to_string()))
    }
}

The magic here really is serde_json::to_string taking our internal state and serialising it as JSON.

An example of the resulting file looks like this:

{
    "key2":{"Integer":20},
    "key4":{"Integer":40},
    "key1":{"Integer":10},
    "key3":{"Integer":30},
    "key0":{"Integer":0}
}

Networking

Finally, we’ll add some networking to the solution. A really basic network interface will allow remote clients to perform SET, GET, and DEL operations for us.

The handle_client function is the heart of the server process, performing the needed processing on incoming requests and routing them to the database instance:

use std::io::{Read, Write};
use std::net::TcpStream;

fn handle_client(mut stream: TcpStream, store: Arc<KeyValueStore>) {
    let mut buffer = [0; 512];

    // Read the incoming request
    match stream.read(&mut buffer) {
        Ok(bytes_read) => {
            // Only decode the bytes we actually received
            let request = String::from_utf8_lossy(&buffer[..bytes_read]);
            let mut parts = request.trim().split_whitespace();
            let command = parts.next();

            let response = match command {
                Some("SET") => {
                    let key = parts.next().unwrap_or_default().to_string();
                    let value = parts.next().unwrap_or_default().to_string();
                    match store.insert(key, Value::String(value)) {
                        Ok(()) => "OK\n".to_string(),
                        Err(e) => format!("Error: {:?}\n", e),
                    }
                }
                Some("GET") => {
                    let key = parts.next().unwrap_or_default();
                    if let Ok(Some(value)) = store.get(key) {
                        format!("{:?}\n", value)
                    } else {
                        "Key not found\n".to_string()
                    }
                }
                Some("DEL") => {
                    let key = parts.next().unwrap_or_default();
                    match store.delete(key) {
                        Ok(()) => "OK\n".to_string(),
                        Err(e) => format!("Error: {:?}\n", e),
                    }
                }
                _ => "Unknown command\n".to_string(),
            };

            // Send the response back to the client
            if let Err(e) = stream.write_all(response.as_bytes()) {
                eprintln!("Failed to write to socket: {}", e);
            }
        }
        Err(e) => eprintln!("Failed to read from socket: {}", e),
    }
}

Our networking “protocol” looks like this:

-- set the key "key1" to the value "hello"
SET key1 hello

-- get the value of the key "key1"
GET key1

-- remove the value and key "key1"
DEL key1

This is all made possible by the following:

let request = String::from_utf8_lossy(&buffer);
let mut parts = request.trim().split_whitespace();
let command = parts.next();

We read the request data from the client into request. This gets split on whitespace into parts, with command taking the first of those parts. The code expects command to be SET, GET, or DEL, which is then handled in the following pattern match.

This function gets mounted onto the server in the main function which now looks like this:

use std::net::TcpListener;

fn main() {
    let store = Arc::new(KeyValueStore::new(None));
    let listener = TcpListener::bind("127.0.0.1:7878").unwrap();

    println!("Server running on 127.0.0.1:7878");

    for stream in listener.incoming() {
        match stream {
            Ok(stream) => {
                let store = Arc::clone(&store);
                std::thread::spawn(move || handle_client(stream, store));
            }
            Err(e) => eprintln!("Connection failed: {}", e),
        }
    }
}

We’re starting our server on port 7878 and handling each connection with our handle_client function.

Running this and giving it a test with telnet gives us the following:

➜  telnet 127.0.0.1 7878
Trying 127.0.0.1...
Connected to 127.0.0.1.
Escape character is '^]'.
SET key1 hello
OK
Connection closed by foreign host.

➜  telnet 127.0.0.1 7878
Trying 127.0.0.1...
Connected to 127.0.0.1.
Escape character is '^]'.
GET key1
String("hello")
Connection closed by foreign host.

➜  telnet 127.0.0.1 7878
Trying 127.0.0.1...
Connected to 127.0.0.1.
Escape character is '^]'.
DEL key1
OK
Connection closed by foreign host.

➜  telnet 127.0.0.1 7878
Trying 127.0.0.1...
Connected to 127.0.0.1.
Escape character is '^]'.
GET key1
Key not found
Connection closed by foreign host.

So, it works. It’s crude and needs some hardening to be production ready - but it’s a start.

Conclusion

In this article, we walked through building a thread-safe, persistent key-value store in Rust. We started with a simple in-memory implementation and iteratively improved it by:

  • Adding support for multiple data types using an enum.
  • Ensuring thread safety with RwLock and Arc.
  • Replacing unwrap with proper error handling.
  • Adding file persistence using JSON serialization and deserialization.
  • Adding basic network access.

This provides a solid foundation for a more robust and scalable key-value server. Next steps could include:

  • Implementing advanced features like snapshots or replication.
  • Optimizing for performance with tools like async I/O or a custom storage engine.
  • Supporting partial reads and memory mapping.
  • Adding clustering.

The full implementation can be found here.

Implementing an ML Model in Rust

Introduction

Rust, known for its performance, memory safety, and low-level control, is gaining traction in domains traditionally dominated by Python, such as machine learning (ML). While Python is the go-to for prototyping ML models due to its mature ecosystem, Rust shines in scenarios demanding high performance, safety, and seamless system-level integration.

In this post, we’ll explore how to implement logistic regression in Rust and discuss the implications of the model’s output.

Why use Rust?

Before diving into code, it’s worth asking: why choose Rust for ML when Python’s libraries like TensorFlow and PyTorch exist?

Benefits of Rust:

  • Performance: Rust offers near-C speeds, making it ideal for performance-critical tasks.
  • Memory Safety: Its ownership model ensures memory safety, preventing bugs like segmentation faults and data races.
  • Integration: Rust can easily integrate with low-level systems, making it a great choice for embedding ML models into IoT, edge devices, or game engines.
  • Control: Rust provides fine-grained control over execution, allowing developers to optimize their models at a deeper level.

While Rust’s ML ecosystem is still evolving, libraries like ndarray, linfa, and smartcore provide foundational tools for implementing machine learning models.

Logistic Regression

Logistic regression is a simple yet powerful algorithm for binary classification. It predicts whether a data point belongs to class 0 or 1 based on a weighted sum of features passed through a sigmoid function.
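
Concretely, for a feature vector x with weights w and bias b, the model predicts

p = \sigma(w \cdot x + b), \qquad \sigma(z) = \frac{1}{1 + e^{-z}}

where p is interpreted as the probability that the point belongs to class 1.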

Below is a Rust implementation of logistic regression using the ndarray crate for numerical operations.
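
If you’re following along, you’ll need the ndarray and ndarray-rand crates in your Cargo.toml. The versions below are an assumption; any recent compatible pair should work:

[dependencies]
ndarray = "0.15"
ndarray-rand = "0.14"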

use ndarray::{Array2, Array1};
use ndarray_rand::RandomExt;
use ndarray_rand::rand_distr::Uniform;

fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

fn logistic_regression(X: &Array2<f64>, y: &Array1<f64>, learning_rate: f64, epochs: usize) -> Array1<f64> {
    let (n_samples, n_features) = X.dim();
    let mut weights = Array1::<f64>::random(n_features, Uniform::new(-0.01, 0.01));
    let mut bias = 0.0;

    for _ in 0..epochs {
        let linear_model = X.dot(&weights) + bias;
        let predictions = linear_model.mapv(sigmoid);

        // Compute the error
        let error = &predictions - y;

        // Compute gradients
        let gradient_weights = X.t().dot(&error) / n_samples as f64;
        let gradient_bias = error.sum() / n_samples as f64;

        // Update weights and bias
        weights -= &(learning_rate * gradient_weights);
        bias -= learning_rate * gradient_bias;
    }

    weights
}

fn main() {
    let X = Array2::random((100, 2), Uniform::new(-1.0, 1.0)); // Random features
    let y = Array1::random(100, Uniform::new(0.0, 1.0)).mapv(|v| if v > 0.5 { 1.0 } else { 0.0 }); // Random labels

    let weights = logistic_regression(&X, &y, 0.01, 1000);
    println!("Trained Weights: {:?}", weights);
}

Key Concepts:

  • Sigmoid Function: Converts the linear combination of inputs into a value between 0 and 1.
  • Gradient Descent: Updates weights and bias iteratively to minimize the error between predictions and actual labels.
  • Random Initialization: Weights start with small random values and are fine-tuned during training.

Output

When you run the code, you’ll see output similar to this:

Trained Weights: [0.034283492207871635, 0.3083430316223569], shape=[2], strides=[1], layout=CFcf (0xf), const ndim=1

What Does This Mean?

  1. Weights: Each weight corresponds to a feature in your dataset. For example, with 2 input features, the model learns two weights.
    • A positive weight means the feature increases the likelihood of predicting 1.
    • A negative weight means the feature decreases the likelihood of predicting 1.
  2. Bias (Optional): The bias adjusts the decision boundary and accounts for data not centered at zero. To view the bias, you’ll need to return it from logistic_regression (see the sketch after this list) and modify the print statement:
println!("Trained Weights: {:?}, Bias: {}", weights, bias);
  3. Predictions: To test the model, use new data points and calculate their predictions:
let new_data = array![0.5, -0.2];
let linear_combination = new_data.dot(&weights) + bias;
let prediction = sigmoid(linear_combination);
println!("Prediction Probability: {}", prediction);

Predictions close to 1 indicate class 1, while predictions close to 0 indicate class 0.
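
The snippets above assume that bias is available at the call site, which means logistic_regression needs to return it alongside the weights. A minimal sketch of that change:

// Return the bias alongside the weights
fn logistic_regression(X: &Array2<f64>, y: &Array1<f64>, learning_rate: f64, epochs: usize) -> (Array1<f64>, f64) {
    // ... training loop exactly as before ...
    (weights, bias)
}

// Destructure at the call site (the array! macro comes from `use ndarray::array;`)
let (weights, bias) = logistic_regression(&X, &y, 0.01, 1000);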

Why Does This Matter?

This simple implementation demonstrates the flexibility and control Rust provides for machine learning tasks. While Python excels in rapid prototyping, Rust’s performance and safety make it ideal for deploying models in production, especially in resource-constrained or latency-critical environments.

When Should You Use Rust for ML?

Rust is a great choice if:

  • Performance is critical: For example, in real-time systems or embedded devices.
  • Memory safety is a priority: Rust eliminates common bugs like memory leaks.
  • Integration with system-level components is needed: Rust can seamlessly work in environments where Python may not be ideal.
  • Custom ML Implementations: You want more control over how the algorithms are built and optimized.

For research or quick prototyping, Python remains the best choice due to its rich ecosystem and community. However, for production-grade systems, Rust’s strengths make it a compelling alternative.

Conclusion

While Rust’s machine learning ecosystem is still maturing, it’s already capable of handling fundamental ML tasks like logistic regression. By combining performance, safety, and control, Rust offers a unique proposition for ML developers building high-performance or production-critical applications.

Writing a Custom Loss Function for a Neural Network

Introduction

Loss functions are the unsung heroes of machine learning. They guide the learning process by quantifying the difference between the predicted and actual outputs. While frameworks like PyTorch and TensorFlow offer a plethora of standard loss functions such as Cross-Entropy and Mean Squared Error, there are times when a custom loss function is necessary.

In this post, we’ll explore the why and how of custom loss functions by:

  1. Setting up a simple neural network.
  2. Using standard loss functions to train the model.
  3. Introducing and implementing custom loss functions tailored to specific needs.

Pre-reqs

Before we begin, you’ll need to set up a Python project and install some dependencies. We’ll be using PyTorch and torchvision. To install these dependencies, use the following command:

pip install torch torchvision

Once installed, verify the installation by running:

python -c "import torch; print(torch.__version__)"

Network Setup

Let’s start by creating a simple neural network to classify data. For simplicity, we’ll use a toy dataset like the MNIST digits dataset.

Dataset preparation

  • Use the MNIST dataset (handwritten digits) as an example.
  • Normalize the dataset for faster convergence during training.

import torch
import torch.optim as optim
from torchvision import datasets, transforms

# Data preparation
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)

Model Architecture

  • Input layer flattens the 28x28 pixel images into a single vector.
  • Two hidden layers with 128 and 64 neurons, each followed by a ReLU activation.
  • An output layer with 10 neurons (one for each digit) and no activation (handled by the loss function).

# Simple Neural Network
import torch.nn as nn

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten the input
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

Training Setup

  • Use an optimizer (e.g., Adam) and CrossEntropyLoss for training.
  • Loop over the dataset for a fixed number of epochs, computing loss and updating weights.

# Initialize model, optimizer, and device
model = SimpleNN()
optimizer = optim.Adam(model.parameters(), lr=0.001)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

Standard Loss

Let’s train the model using the standard Cross-Entropy Loss, which is suitable for classification tasks.

  • Combines log_softmax and negative log likelihood into one step.
  • Suitable for classification tasks as it penalizes incorrect predictions heavily.

# Standard loss function
criterion = nn.CrossEntropyLoss()

# Training loop
def train_model(model, train_loader, criterion, optimizer, epochs=5):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f'Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(train_loader):.4f}')

train_model(model, train_loader, criterion, optimizer)

The output of this training session should look something like this:

Epoch 1/5, Loss: 0.3932
Epoch 2/5, Loss: 0.1834
Epoch 3/5, Loss: 0.1352
Epoch 4/5, Loss: 0.1054
Epoch 5/5, Loss: 0.0914

Custom Loss

Why Custom Loss Functions?

Standard loss functions may not work well in cases like:

  • Imbalanced Datasets: Classes have significantly different frequencies.
  • Multi-Task Learning: Different tasks require different weights.
  • Task-Specific Goals: Optimizing for metrics like precision or recall rather than accuracy.

Example: Weighted Loss

Suppose we want to penalize misclassifying certain classes more heavily. We can achieve this by implementing a weighted Cross-Entropy Loss.

# Custom weighted loss function
class WeightedCrossEntropyLoss(nn.Module):
    def __init__(self, class_weights):
        super(WeightedCrossEntropyLoss, self).__init__()
        self.class_weights = torch.tensor(class_weights).to(device)

    def forward(self, outputs, targets):
        log_probs = torch.log_softmax(outputs, dim=1)
        loss = -torch.sum(self.class_weights[targets] * log_probs[range(len(targets)), targets]) / len(targets)
        return loss

# Example: Higher weight for class 0
class_weights = [2.0 if i == 0 else 1.0 for i in range(10)]
custom_criterion = WeightedCrossEntropyLoss(class_weights)

# Training with custom loss function
train_model(model, train_loader, custom_criterion, optimizer)

After running this, you should see output like the following:

Epoch 1/5, Loss: 0.4222
Epoch 2/5, Loss: 0.1970
Epoch 3/5, Loss: 0.1390
Epoch 4/5, Loss: 0.1124
Epoch 5/5, Loss: 0.0976
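
As an aside, PyTorch’s built-in loss accepts per-class weights directly. The following is nearly equivalent to our custom class; the only difference is that the built-in mean reduction divides by the sum of the selected weights rather than the batch size:

# Built-in equivalent (up to the normalisation detail noted above)
builtin_criterion = nn.CrossEntropyLoss(weight=torch.tensor(class_weights).to(device))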

Example: Combining Losses

Sometimes, you might want to combine multiple objectives into a single loss function.

# Custom loss combining Cross-Entropy and L1 regularization
class CombinedLoss(nn.Module):
    def __init__(self, alpha=0.1):
        super(CombinedLoss, self).__init__()
        self.ce_loss = nn.CrossEntropyLoss()
        self.alpha = alpha

    def forward(self, outputs, targets, model):
        ce_loss = self.ce_loss(outputs, targets)
        l1_loss = sum(torch.sum(torch.abs(param)) for param in model.parameters())
        return ce_loss + self.alpha * l1_loss

custom_criterion = CombinedLoss(alpha=0.01)

# Training with combined loss
train_model(model, train_loader, lambda outputs, targets: custom_criterion(outputs, targets, model), optimizer)

Comparing Results

To compare the results of standard and custom loss functions, you need to evaluate the following:

  1. Training Loss:
    • Plot the loss per epoch for both standard and custom loss functions.
  2. Accuracy:
    • Measure training and validation accuracy after each epoch.
    • Compare how well the model performs in predicting each class.
  3. Precision and Recall:
    • Useful for imbalanced datasets to measure performance on minority classes.
  4. Visualization:
    • Confusion matrix: Visualize how often each class is misclassified.
    • Loss curve: Show convergence speed and stability for different loss functions.

We can use graphs to visualise how these metrics perform:

from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import numpy as np

# After training
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        preds = torch.argmax(outputs, dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Confusion Matrix
cm = confusion_matrix(all_labels, all_preds)
plt.imshow(cm, cmap='Blues')
plt.title('Confusion Matrix')
plt.colorbar()
plt.show()

# Classification Report
print(classification_report(all_labels, all_preds))

We can also produce visualisations of our loss curves:

# Assuming loss values are stored during training
plt.plot(range(len(train_losses)), train_losses, label="Standard Loss")
plt.plot(range(len(custom_losses)), custom_losses, label="Custom Loss")
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss Curve')
plt.show()
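
The plotting code above assumes train_losses and custom_losses lists exist. A hypothetical variant of train_model that collects the average loss per epoch might look like this:

# Variant of train_model that returns per-epoch average losses
def train_model_tracked(model, train_loader, criterion, optimizer, epochs=5):
    losses = []
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            loss = criterion(outputs, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            total_loss += loss.item()
        losses.append(total_loss / len(train_loader))
        print(f'Epoch {epoch+1}/{epochs}, Loss: {losses[-1]:.4f}')
    return losses

# e.g. (re-initialise the model between runs for a fair comparison):
# train_losses = train_model_tracked(model, train_loader, criterion, optimizer)
# custom_losses = train_model_tracked(model, train_loader, custom_criterion, optimizer)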

Conclusion

Custom loss functions empower you to fine-tune your neural networks for unique problems. By carefully designing and experimenting with loss functions, you can align your model’s learning process with the specific goals of your application.

Some closing tips for custom loss functions:

  • Always start with a simple baseline (e.g., Cross-Entropy Loss) to understand your model’s behavior.
  • Visualize performance across metrics, especially when using weighted or multi-objective losses.
  • Experiment with different weights and loss combinations to find the optimal setup for your task.

The key is to balance complexity and interpretability—sometimes, even simple tweaks can significantly impact performance.

Implementing an LRU Cache in Rust

Introduction

When building high-performance software, caches often play a vital role in optimizing performance by reducing redundant computations or avoiding repeated I/O operations. One such common caching strategy is the Least Recently Used (LRU) cache, which ensures that the most recently accessed data stays available while evicting the least accessed items when space runs out.

What Is an LRU Cache?

At its core, an LRU cache stores a limited number of key-value pairs. When you access or insert an item:

  • If the item exists, it is marked as “recently used.”
  • If the item doesn’t exist and the cache is full, the least recently used item is evicted to make space for the new one.

LRU caches are particularly useful in scenarios where access patterns favor recently used data, such as:

  • Web page caching in browsers.
  • Database query caching for repeated queries.
  • API response caching to reduce repeated external requests.

In this post, we’ll build a simple and functional implementation of an LRU cache in Rust. Instead of diving into complex data structures like custom linked lists, we’ll leverage Rust’s standard library collections (HashMap and VecDeque) to achieve:

  • Constant-time access and updates using HashMap.
  • Efficient tracking of usage order with VecDeque.

This straightforward approach is easy to follow and demonstrates Rust’s powerful ownership model and memory safety.

LRUCache Structure

We’ll begin with a struct that defines the cache:

use std::collections::{HashMap, VecDeque};

pub struct LRUCache<K, V> {
    capacity: usize,                 // Maximum number of items the cache can hold
    map: HashMap<K, V>,              // Key-value store
    order: VecDeque<K>,              // Tracks the order of key usage
}

This structure holds:

  1. capacity: The maximum number of items the cache can store.
  2. map: The main storage for key-value pairs.
  3. order: A queue to maintain the usage order of keys.

Implementation

Our implementation of LRUCache includes some constraints on the generic types K (key) and V (value). Specifically, the K type requires the following traits:

impl<K: Clone + Eq + std::hash::Hash + PartialEq, V> LRUCache<K, V> {
}

The Clone trait allows us to create a copy of the key when needed (via .clone()). Eq ensures that keys can be compared for full equality. The Hash trait enables us to hash the keys, which is a requirement for using HashMap, and finally PartialEq allows for equality comparisons between two keys.

Technically Eq already implies PartialEq (Eq is a supertrait of PartialEq), but we explicitly include it here for clarity.

Create the Cache

To initialize the cache, we add a new method:

pub fn new(capacity: usize) -> Self {
    LRUCache {
        capacity,
        map: HashMap::with_capacity(capacity),
        order: VecDeque::with_capacity(capacity),
    }
}

  • HashMap::with_capacity: Preallocates space for the HashMap to avoid repeated resizing.
  • VecDeque::with_capacity: Allocates space for tracking key usage.

Value access via get

The get method retrieves a value by key and updates its usage order:

pub fn get(&mut self, key: &K) -> Option<&V> {
    if self.map.contains_key(key) {
        // Move the key to the back of the order queue
        self.order.retain(|k| k != key);
        self.order.push_back(key.clone());
        self.map.get(key)
    } else {
        None
    }
}

  • Check if the key exists via contains_key.
  • Remove the key from its old position in order and push it to the back.
  • Return the value from the HashMap.

In cases where a value never existed or has been evicted, this function sends None back to the caller.

Value insertion via put

The put method adds a new key-value pair or updates an existing one:

pub fn put(&mut self, key: K, value: V) {
    if self.map.contains_key(&key) {
        // Update existing key's value and mark it as most recently used
        self.map.insert(key.clone(), value);
        self.order.retain(|k| k != &key);
        self.order.push_back(key);
    } else {
        if self.map.len() == self.capacity {
            // Evict the least recently used item
            if let Some(lru_key) = self.order.pop_front() {
                self.map.remove(&lru_key);
            }
        }
        self.map.insert(key.clone(), value);
        self.order.push_back(key);
    }
}

  • If the key exists
    • The value is updated in map
    • The key is moved to the back of order
  • If the cache is full
    • Remove the least recently used key (which will be the front of order) from map
  • Insert the new key-value pair and mark it as recently used

Size

Finally, we add a helper method to get the current size of the cache:

pub fn len(&self) -> usize {
    self.map.len()
}

Testing

Now we can test our cache:

fn main() {
    let mut cache = LRUCache::new(3);

    cache.put("a", 1);
    cache.put("b", 2);
    cache.put("c", 3);

    println!("{:?}", cache.get(&"a")); // Some(1)
    cache.put("d", 4); // Evicts "b"
    println!("{:?}", cache.get(&"b")); // None
    println!("{:?}", cache.get(&"c")); // Some(3)
    println!("{:?}", cache.get(&"d")); // Some(4)
}

Running this code, we see the following:

Some(1)
None
Some(3)
Some(4)

Conclusion

In this post, we built a simple yet functional LRU cache in Rust. A full implementation can be found as a gist here.

While this implementation is perfect for understanding the basic principles, it can be extended further with:

  • Thread safety using synchronization primitives like Mutex or RwLock.
  • Custom linked structures for more efficient eviction and insertion.
  • Diagnostics and monitoring to observe cache performance in real-world scenarios.

If you’re looking for a robust cache for production, libraries like lru offer feature-rich implementations. But for learning purposes, rolling your own cache is an excellent way to dive deep into Rust’s collections and ownership model.

Building a Packet Sniffer with Raw Sockets in C

Introduction

Network packet sniffing is an essential skill in the toolbox of any systems programmer or network engineer. It enables us to inspect network traffic, debug communication issues, and even learn how various networking protocols function under the hood.

In this article, we will walk through the process of building a simple network packet sniffer in C using raw sockets.

Before we begin, it might help to run through a quick networking primer.

OSI and Networking Layers

Before diving into the code, let’s briefly revisit the OSI model—a conceptual framework that standardizes network communication into seven distinct layers:

  1. Physical Layer: Deals with the physical connection and transmission of raw data bits.
  2. Data Link Layer: Responsible for framing and MAC addressing. Ethernet operates at this layer.
  3. Network Layer: Handles logical addressing (IP addresses) and routing. This layer is where IP packets are structured.
  4. Transport Layer: Ensures reliable data transfer with protocols like TCP and UDP.
  5. Session Layer: Manages sessions between applications.
  6. Presentation Layer: Transforms data formats (e.g., encryption, compression).
  7. Application Layer: Interfaces directly with the user (e.g., HTTP, FTP).

Our packet sniffer focuses on Layers 2 through 4. By analyzing Ethernet, IP, TCP, UDP, and ICMP headers, we gain insights into packet structure and how data travels across a network.

The Code

In this section, we’ll run through the functions that are needed to implement our packet sniffer. The layers that we’ll focus on are:

  • Layer 2 (Data Link): Capturing raw Ethernet frames and extracting MAC addresses.
  • Layer 3 (Network): Parsing IP headers for source and destination IPs.
  • Layer 4 (Transport): Inspecting TCP, UDP, and ICMP protocols to understand port-level communication and message types.
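
Before we get into the individual functions, here’s a sketch of the headers and constants the code below assumes (Linux/glibc; the BUFFER_SIZE value is our choice - 65536 comfortably fits any single packet):

#include <stdio.h>
#include <stdlib.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <sys/socket.h>
#include <linux/if_ether.h>   /* struct ethhdr, ETH_P_ALL */
#include <netinet/ip.h>       /* struct iphdr */
#include <netinet/tcp.h>      /* struct tcphdr */
#include <netinet/udp.h>      /* struct udphdr */
#include <netinet/ip_icmp.h>  /* struct icmphdr */

#define BUFFER_SIZE 65536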

Layer 2 (Data Link)

The Data Link Layer is responsible for the physical addressing of devices on a network. It includes the Ethernet header, which contains the source and destination MAC addresses. In this section, we analyze and print the Ethernet header.

void print_eth_header(unsigned char *buffer, int size) { 
    struct ethhdr *eth = (struct ethhdr *)buffer;

    printf("\nEthernet Header\n");
    printf("   |-Source Address      : %.2X-%.2X-%.2X-%.2X-%.2X-%.2X \n",
           eth->h_source[0], eth->h_source[1], eth->h_source[2], eth->h_source[3], eth->h_source[4], eth->h_source[5]);
    printf("   |-Destination Address : %.2X-%.2X-%.2X-%.2X-%.2X-%.2X \n",
           eth->h_dest[0], eth->h_dest[1], eth->h_dest[2], eth->h_dest[3], eth->h_dest[4], eth->h_dest[5]);
    printf("   |-Protocol            : %u \n", (unsigned short)eth->h_proto);
}

Layer 3 (Network)

The Network Layer handles logical addressing and routing. In our code, this corresponds to the IP header, where we extract source and destination IP addresses.

void print_ip_header(unsigned char *buffer, int size) { 
    struct iphdr *ip = (struct iphdr *)(buffer + sizeof(struct ethhdr));

    printf("\nIP Header\n");
    printf("   |-Source IP        : %s\n", inet_ntoa(*(struct in_addr *)&ip->saddr));
    printf("   |-Destination IP   : %s\n", inet_ntoa(*(struct in_addr *)&ip->daddr));
    printf("   |-Protocol         : %d\n", ip->protocol);
}

Here, we use the iphdr structure to parse the IP header. The inet_ntoa function converts the source and destination IP addresses from binary format to a human-readable string.

Layer 4 (Transport)

The Transport Layer ensures reliable data transfer and includes protocols like TCP, UDP, and ICMP. We have specific functions to parse and display these packets:

The TCP version of this function has a source and destination for the packet, but also has a sequence and acknowledgement which are key features for this protocol.

void print_tcp_packet(unsigned char *buffer, int size) {
    struct iphdr *ip = (struct iphdr *)(buffer + sizeof(struct ethhdr));
    struct tcphdr *tcp = (struct tcphdr *)(buffer + sizeof(struct ethhdr) + ip->ihl * 4);

    printf("\nTCP Packet\n");
    print_ip_header(buffer, size);
    printf("\n   |-Source Port      : %u\n", ntohs(tcp->source));
    printf("   |-Destination Port : %u\n", ntohs(tcp->dest));
    printf("   |-Sequence Number  : %u\n", ntohl(tcp->seq));
    printf("   |-Acknowledgement  : %u\n", ntohl(tcp->ack_seq));
}

The UDP counterpart doesn’t have the sequencing or acknowledgement fields, as UDP is a connectionless protocol that makes no delivery guarantees.

void print_udp_packet(unsigned char *buffer, int size) {
    struct iphdr *ip = (struct iphdr *)(buffer + sizeof(struct ethhdr));
    struct udphdr *udp = (struct udphdr *)(buffer + sizeof(struct ethhdr) + ip->ihl * 4);

    printf("\nUDP Packet\n");
    print_ip_header(buffer, size);
    printf("\n   |-Source Port      : %u\n", ntohs(udp->source));
    printf("   |-Destination Port : %u\n", ntohs(udp->dest));
    printf("   |-Length           : %u\n", ntohs(udp->len));
}

ICMP’s type and code identify the kind of message (for example, echo request or echo reply), while the checksum is used to verify the header’s integrity.

void print_icmp_packet(unsigned char *buffer, int size) {
    struct iphdr *ip = (struct iphdr *)(buffer + sizeof(struct ethhdr));
    struct icmphdr *icmp = (struct icmphdr *)(buffer + sizeof(struct ethhdr) + ip->ihl * 4);

    printf("\nICMP Packet\n");
    print_ip_header(buffer, size);
    printf("\n   |-Type : %d\n", icmp->type);
    printf("   |-Code : %d\n", icmp->code);
    printf("   |-Checksum : %d\n", ntohs(icmp->checksum));
}

Tying it all together

The architecture of this code is fairly simple. The main function sets up a loop that continually receives raw frames from the socket. From there, a determination is made about which transport protocol the packet carries, and we dispatch to the function that specialises in that protocol.

int main() {
    int sock_raw;
    struct sockaddr saddr;
    socklen_t saddr_len = sizeof(saddr);

    unsigned char *buffer = (unsigned char *)malloc(BUFFER_SIZE);
    if (buffer == NULL) {
        perror("Failed to allocate memory");
        return 1;
    }

    sock_raw = socket(AF_PACKET, SOCK_RAW, htons(ETH_P_ALL));
    if (sock_raw < 0) {
        perror("Socket Error");
        free(buffer);
        return 1;
    }

    printf("Starting packet sniffer...\n");

    while (1) {
        int data_size = recvfrom(sock_raw, buffer, BUFFER_SIZE, 0, &saddr, &saddr_len);
        if (data_size < 0) {
            perror("Failed to receive packets");
            break;
        }
        process_packet(buffer, data_size);
    }

    close(sock_raw);
    free(buffer);
    return 0;
}

The recvfrom call reads the raw bytes in from the socket.

The process_packet function is responsible for the dispatch of the information. This is really a switch statement focused on the incoming protocol:

void process_packet(unsigned char *buffer, int size) {
    struct iphdr *ip_header = (struct iphdr *)(buffer + sizeof(struct ethhdr));

    switch (ip_header->protocol) {
        case IPPROTO_TCP:
            print_tcp_packet(buffer, size);
            break;
        case IPPROTO_UDP:
            print_udp_packet(buffer, size);
            break;
        case IPPROTO_ICMP:
            print_icmp_packet(buffer, size);
            break;
        default:
            print_ip_header(buffer, size);
            break;
    }
}

This then ties all of our functions in together.

Running

Because this application pulls raw packets straight from your networking stack, you will need to run it as root to get that low-level access.
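
Assuming the source file is named psniff.c, you can build it with something like:

gcc -o psniff psniff.c

Then run the resulting binary as root: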

sudo ./psniff

Conclusion

Building a network packet sniffer using raw sockets in C offers valuable insight into how data flows through the network stack and how different protocols interact. By breaking down packets layer by layer—from the Data Link Layer (Ethernet) to the Transport Layer (TCP, UDP, ICMP)—we gain a deeper understanding of networking concepts and system-level programming.

This project demonstrates key topics such as:

  • Capturing raw packets using sockets.
  • Parsing headers to extract meaningful information.
  • Mapping functionality to specific OSI layers.

Packet sniffers like this are not only useful for learning but also serve as foundational tools for network diagnostics, debugging, and security monitoring. However, it’s essential to use such tools ethically and responsibly, adhering to legal and organizational guidelines.

In the future, we could extend this sniffer by writing packet payloads to a file, adding packet filtering (e.g., only capturing HTTP or DNS traffic), or even integrating with libraries like libpcap for more advanced use cases.

A full gist of this code is available to check out.