Rust is celebrated for its emphasis on safety and performance, largely thanks to its robust compile-time checks.
However, there are situations where you need to bypass these checks to perform low-level operations—this is where
Rust’s unsafe keyword comes in. While unsafe opens the door to powerful features, it also comes with significant
risks.
The solution?
Encapsulating unsafe code in safe abstractions.
This post explores what that means, why it’s important, and how to do it effectively.
Understanding unsafe in Rust
Rust enforces strict memory safety guarantees by default. However, some operations are inherently unsafe and require
explicit acknowledgment from the programmer. These include:
Raw pointer manipulation: Directly accessing memory without bounds or validity checks.
Foreign Function Interface (FFI): Interacting with non-Rust code (e.g., calling C functions).
Manual memory management: Allocating and freeing memory without Rust’s usual safeguards.
Concurrency primitives: Implementing data structures that require custom synchronization logic.
When you write unsafe code, you’re essentially telling the compiler, “I know what I’m doing; trust me.”
While this is sometimes necessary, it’s critical to minimize the potential for misuse by others.
Why Wrap Unsafe Code in Safe Abstractions?
Using unsafe is a trade-off. It gives you access to low-level features and optimizations but requires you to
manually uphold the invariants that Rust would otherwise enforce. Safe abstractions address this challenge by:
Avoiding Undefined Behavior: Preventing common pitfalls like null pointer dereferences, data races, or buffer overflows.
Improving Maintainability: Reducing the scattering of unsafe blocks across the codebase makes it easier to audit and debug.
Providing Ease of Use: Enabling most developers to rely on Rust’s safety guarantees without needing to understand the intricacies of the underlying unsafe implementation.
What is a Safe Abstraction?
A safe abstraction is an API or module where the internal implementation may use unsafe code, but the external
interface ensures that incorrect usage is either impossible or extremely difficult.
Let’s look at how to create one.
Example: Safe Wrapping of Unsafe Memory Allocation
Here’s a simplified example of wrapping unsafe memory management into a safe abstraction:
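pub struct SafeAllocator {
    // Internal raw pointer or other unsafe constructs
    ptr: *mut u8,
    size: usize,
}

impl SafeAllocator {
    pub fn new(size: usize) -> Self {
        // calloc zero-initializes the block, so handing out `&[u8]` views
        // never exposes uninitialized memory.
        let ptr = unsafe { libc::calloc(size, 1) as *mut u8 };
        if ptr.is_null() {
            panic!("Failed to allocate memory");
        }
        Self { ptr, size }
    }

    pub fn allocate(&self, offset: usize, len: usize) -> &[u8] {
        // checked_add guards against `offset + len` wrapping around.
        match offset.checked_add(len) {
            Some(end) if end <= self.size => {}
            _ => panic!("Out of bounds access"),
        }
        // Safety: the range was bounds-checked above and `ptr` is valid for `size` bytes.
        unsafe { std::slice::from_raw_parts(self.ptr.add(offset), len) }
    }

    pub fn deallocate(self) {
        // Consuming `self` runs Drop, which frees the memory exactly once.
        // Calling free here as well would be a double free.
    }
}

impl Drop for SafeAllocator {
    fn drop(&mut self) {
        unsafe {
            libc::free(self.ptr as *mut libc::c_void);
        }
    }
}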
In this example:
unsafe is confined to specific, well-defined sections of the code.
The API ensures that users cannot misuse the allocator (e.g., by accessing out-of-bounds memory).
Drop ensures memory is automatically freed when the allocator goes out of scope.
Example Usage of SafeAllocator
Here’s how you might use the SafeAllocator in practice:
fn main() {
    // Create a new SafeAllocator with 1024 bytes of memory
    let allocator = SafeAllocator::new(1024);

    // Allocate a slice of 128 bytes starting from offset 0
    let slice = allocator.allocate(0, 128);
    println!("Allocated slice of length: {}", slice.len());

    // The allocator will automatically deallocate memory when it goes out of scope
}
This usage demonstrates:
How to create and interact with the SafeAllocator API.
That memory is automatically managed via Rust’s Drop trait, preventing leaks.
Leveraging Rust’s Type System
Rust’s type system is another powerful tool for enforcing invariants. For example, you can use:
Lifetimes: To ensure references don’t outlive the data they point to.
PhantomData: To associate types or lifetimes with otherwise untyped data.
Ownership and Borrowing Rules: To enforce safe access patterns at compile time.
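As a minimal sketch of how lifetimes and PhantomData work together, consider a read-only view that stores a raw pointer internally. The type name RawView and its methods are hypothetical, introduced here only for illustration. The PhantomData<&'a T> field ties the view to the borrowed data, so the compiler rejects any use of the view after the source has been dropped.

```rust
use std::marker::PhantomData;

/// A read-only view over a run of `T` values, stored as a raw pointer.
/// `PhantomData<&'a T>` makes the view behave as if it held a `&'a T`,
/// so the borrow checker ties it to the lifetime of the source data.
pub struct RawView<'a, T> {
    ptr: *const T,
    len: usize,
    _marker: PhantomData<&'a T>,
}

impl<'a, T> RawView<'a, T> {
    /// Safe constructor: the slice proves `ptr` is valid for `len` elements.
    pub fn new(data: &'a [T]) -> Self {
        Self {
            ptr: data.as_ptr(),
            len: data.len(),
            _marker: PhantomData,
        }
    }

    /// Bounds-checked access; returns `None` instead of reading out of range.
    pub fn get(&self, index: usize) -> Option<&'a T> {
        if index < self.len {
            // SAFETY: `index < len`, and `ptr` came from a slice borrowed
            // for `'a`, so the element is valid for that lifetime.
            Some(unsafe { &*self.ptr.add(index) })
        } else {
            None
        }
    }
}

fn main() {
    let numbers = vec![10, 20, 30];
    let view = RawView::new(&numbers);
    println!("{:?} {:?}", view.get(1), view.get(3));
}
```

The only unsafe block sits behind a bounds check, and the constructor accepts nothing but a valid slice, so callers have no way to build a view that points at invalid memory.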
Documentation of Safety Contracts
Any unsafe code should include clear documentation of the invariants it relies on. For example:
// Safety:
// - `ptr` must be non-null and point to a valid memory region.
// - `len` must not exceed the bounds of the allocated memory.
unsafe { std::slice::from_raw_parts(ptr, len) }
This makes it easier for future maintainers to understand and verify the correctness of the code.
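The same idea extends to unsafe functions. A common convention is a `# Safety` section in the doc comment that states what the caller must guarantee, paired with a safe wrapper that discharges those obligations. The sketch below illustrates that pattern; the function names read_byte_unchecked and read_byte are hypothetical.

```rust
/// Reads the byte at `index` from the buffer starting at `ptr`.
///
/// # Safety
/// - `ptr` must be non-null and valid for reads of `len` bytes.
/// - `index` must be less than `len`.
unsafe fn read_byte_unchecked(ptr: *const u8, len: usize, index: usize) -> u8 {
    debug_assert!(index < len);
    // SAFETY: the caller guarantees the contract documented above.
    unsafe { *ptr.add(index) }
}

/// Safe wrapper: the slice guarantees validity and the bounds are checked
/// here, so callers cannot violate the contract of the unchecked function.
fn read_byte(data: &[u8], index: usize) -> Option<u8> {
    if index < data.len() {
        // SAFETY: `data` is a live slice, so its pointer is valid for
        // `data.len()` bytes, and `index < data.len()` was just checked.
        Some(unsafe { read_byte_unchecked(data.as_ptr(), data.len(), index) })
    } else {
        None
    }
}

fn main() {
    let data = [1u8, 2, 3];
    println!("{:?} {:?}", read_byte(&data, 1), read_byte(&data, 3));
}
```

Each unsafe block carries a comment explaining why the contract holds at that call site, which gives a reviewer a concrete claim to check.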
Real-World Examples of Safe Abstractions
Many Rust libraries provide excellent examples of safe abstractions over unsafe code:
std::sync::Mutex: Internally uses unsafe for thread synchronization but exposes a safe API for locking and unlocking.
Vec: The Rust standard library’s Vec type uses unsafe for raw memory allocation and resizing but ensures bounds checks and proper memory management externally.
crossbeam: Provides safe concurrency primitives built on low-level atomic operations.
Costs and Benefits
While writing safe abstractions requires extra effort and careful thought, the benefits outweigh the costs:
Benefits:
Reduced Risk of Bugs: Encapsulating unsafe code minimizes the chance of introducing undefined behavior.
Improved Developer Experience: Safe APIs make it easier for others to use your code without worrying about low-level details.
Easier Auditing: With unsafe code isolated, it’s easier to review and verify its correctness.
Costs:
Initial Effort: Designing a robust safe abstraction takes time and expertise.
Performance Overhead: In rare cases, adding safety layers may incur slight overhead (though usually negligible in well-designed abstractions).
Conclusion
Writing safe abstractions for unsafe Rust code is both an art and a science. It involves understanding the invariants
of your unsafe code, leveraging Rust’s type system to enforce safety, and documenting your assumptions clearly. By
doing so, you can harness the power of unsafe while maintaining Rust’s guarantees of memory safety and concurrency
correctness—the best of both worlds.
In today’s post, we’ll build a simple key value server, and we’ll do it iteratively. We’ll start simple
and then add safety, concurrency, and networking as we go.
Implementation
Now we’ll get started with our iterations. The finished code will be available at the end of this post.
Baseline
All of our implementations will deal with a KeyValueStore struct. This struct will hold all of the variables that
we want to keep track of in our server.
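As a rough sketch of such a baseline, assuming String keys and String values and the method names insert, get, and delete used later in this post, the struct might look like this. The exact signatures are assumptions made for illustration.

```rust
use std::collections::HashMap;

/// Baseline store: plain `String` keys and values, single-threaded.
struct KeyValueStore {
    data: HashMap<String, String>,
}

impl KeyValueStore {
    fn new() -> Self {
        KeyValueStore {
            data: HashMap::new(),
        }
    }

    /// Inserts or overwrites the value stored under `key`.
    fn insert(&mut self, key: String, value: String) {
        self.data.insert(key, value);
    }

    /// Returns a reference to the value, if the key exists.
    fn get(&self, key: &str) -> Option<&String> {
        self.data.get(key)
    }

    /// Removes the key; removing a missing key is a no-op.
    fn delete(&mut self, key: &str) {
        self.data.remove(key);
    }
}

fn main() {
    let mut store = KeyValueStore::new();
    store.insert("key1".to_string(), "hello".to_string());
    println!("{:?}", store.get("key1"));
    store.delete("key1");
    println!("{:?}", store.get("key1"));
}
```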
String is pretty limiting to store as far as the value side is concerned. We can upgrade this to specifically use
data types that we will find useful via an enum:
#[derive(Debug, Clone)]
enum Value {
    String(String),
    Integer(i64),
    Float(f64),
    Boolean(bool),
    Binary(Vec<u8>),
    // Add more variants as needed
}
We can swap out the value side of our data member now, too.
struct KeyValueStore {
    data: HashMap<String, Value>,
}
The implementation simply swaps the String for Value:
We’re now able to not only store strings. We can store integers, floats, binary, and booleans. This makes our key value
store a lot more versatile.
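A sketch of how the Value-based version could fit together is shown here. The PartialEq derive and the method signatures are assumptions added so the example can be compared and run on its own.

```rust
use std::collections::HashMap;

// PartialEq is added here only so values can be compared in examples.
#[derive(Debug, Clone, PartialEq)]
enum Value {
    String(String),
    Integer(i64),
    Float(f64),
    Boolean(bool),
    Binary(Vec<u8>),
}

struct KeyValueStore {
    data: HashMap<String, Value>,
}

impl KeyValueStore {
    fn new() -> Self {
        KeyValueStore {
            data: HashMap::new(),
        }
    }

    fn insert(&mut self, key: String, value: Value) {
        self.data.insert(key, value);
    }

    fn get(&self, key: &str) -> Option<&Value> {
        self.data.get(key)
    }

    fn delete(&mut self, key: &str) {
        self.data.remove(key);
    }
}

fn main() {
    let mut store = KeyValueStore::new();
    store.insert("name".to_string(), Value::String("kv".to_string()));
    store.insert("count".to_string(), Value::Integer(42));
    store.insert("ratio".to_string(), Value::Float(0.5));
    store.insert("ready".to_string(), Value::Boolean(true));
    store.insert("blob".to_string(), Value::Binary(vec![0xde, 0xad]));
    println!("{:?}", store.get("count"));
    store.delete("count");
    println!("{:?}", store.get("count"));
}
```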
Thread Safety
We will have multiple threads of execution trying to perform actions on this structure at the same time, so we will
add some thread safety to the process now. Wrapping data in Arc will give us a thread safe, reference counting
pointer. We’re also going to need to lock this data structure for reading and for writing. We can use RwLock to
take care of that for us.
We update our data structure to include these new types:
These functions are now thread safe, which means calling code can be multithreaded and we can guarantee that our data
structure will be treated consistently.
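A minimal sketch of the thread-safe structure, assuming the map is wrapped as Arc<RwLock<HashMap<String, Value>>> and that get returns an owned clone so the lock is released before the caller uses the value. The Value enum is trimmed to two variants here for brevity.

```rust
use std::collections::HashMap;
use std::sync::{Arc, RwLock};
use std::thread;

// Trimmed version of the Value enum, enough for this sketch.
#[derive(Debug, Clone, PartialEq)]
enum Value {
    String(String),
    Integer(i64),
}

struct KeyValueStore {
    data: Arc<RwLock<HashMap<String, Value>>>,
}

impl KeyValueStore {
    fn new() -> Self {
        KeyValueStore {
            data: Arc::new(RwLock::new(HashMap::new())),
        }
    }

    // Writers take the exclusive lock.
    fn insert(&self, key: String, value: Value) {
        let mut locked = self.data.write().unwrap();
        locked.insert(key, value);
    }

    // Readers share the lock and return an owned copy of the value.
    fn get(&self, key: &str) -> Option<Value> {
        let locked = self.data.read().unwrap();
        locked.get(key).cloned()
    }

    fn delete(&self, key: &str) {
        let mut locked = self.data.write().unwrap();
        locked.remove(key);
    }
}

fn main() {
    let store = Arc::new(KeyValueStore::new());
    store.insert("name".to_string(), Value::String("kv".to_string()));

    let handles: Vec<_> = (0..4)
        .map(|i| {
            let store = Arc::clone(&store);
            thread::spawn(move || store.insert(format!("key{}", i), Value::Integer(i * 10)))
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }

    println!("{:?}", store.get("key2"));
    store.delete("key2");
    println!("{:?}", store.get("key2"));
}
```

Note that the methods now take &self rather than &mut self: the RwLock provides the interior mutability, which is what allows one store to be shared across threads.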
fn main() {
    let store = Arc::new(KeyValueStore::new());

    // Create a vector to hold thread handles
    let mut handles = vec![];

    // Spawn threads to perform inserts
    for i in 0..5 {
        let store = Arc::clone(&store);
        let handle = thread::spawn(move || {
            let key = format!("key{}", i);
            let value = Value::Integer(i * 10);
            store.insert(key.clone(), value);
            println!("Thread {} inserted: {}", i, key);
        });
        handles.push(handle);
    }

    // Spawn threads to read values
    for i in 0..5 {
        let store = Arc::clone(&store);
        let handle = thread::spawn(move || {
            let key = format!("key{}", i);
            if let Some(value) = store.get(&key) {
                println!("Thread {} read: {} -> {:?}", i, key, value);
            } else {
                println!("Thread {} could not find: {}", i, key);
            }
        });
        handles.push(handle);
    }

    // Spawn threads to delete keys
    for i in 0..5 {
        let store = Arc::clone(&store);
        let handle = thread::spawn(move || {
            let key = format!("key{}", i);
            store.delete(&key);
            println!("Thread {} deleted: {}", i, key);
        });
        handles.push(handle);
    }

    // Wait for all threads to complete
    for handle in handles {
        handle.join().unwrap();
    }

    println!("Final state of the store: {:?}", store.data.read().unwrap());
}
Error handling
You can see that we’re using unwrap in the implementation functions, which might be ok for tests or short scripts. If
we expect to run this code in production, we should replace these calls with proper error handling.
In order to do that, we need to define our error domain first. We create an enum called StoreError. As we fill out
our implementation, we’ll run into a number of different error cases. We’ll use StoreError to centralise all of these
errors so we can express them clearly.
We’ve implemented From<PoisonError<T>> for our StoreError because PoisonError is the error returned when we try to
acquire a lock whose previous holder panicked. This conversion is what lets the ? operator turn a failed lock
acquisition into a StoreError.
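A sketch of what such an error type might look like follows. The KeyNotFound and IoError variants appear in the code in this post; the LockError variant name and the Display messages are assumptions made for illustration.

```rust
use std::fmt;
use std::sync::{PoisonError, RwLock};

#[derive(Debug, PartialEq)]
enum StoreError {
    KeyNotFound(String),
    LockError(String), // hypothetical variant name for poisoned locks
    IoError(String),
}

impl fmt::Display for StoreError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            StoreError::KeyNotFound(key) => write!(f, "key not found: {}", key),
            StoreError::LockError(msg) => write!(f, "lock error: {}", msg),
            StoreError::IoError(msg) => write!(f, "I/O error: {}", msg),
        }
    }
}

impl std::error::Error for StoreError {}

// Lets `?` convert a failed lock acquisition into a StoreError.
impl<T> From<PoisonError<T>> for StoreError {
    fn from(err: PoisonError<T>) -> Self {
        StoreError::LockError(err.to_string())
    }
}

// Lets `?` convert file system failures into a StoreError.
impl From<std::io::Error> for StoreError {
    fn from(err: std::io::Error) -> Self {
        StoreError::IoError(err.to_string())
    }
}

/// Example of `?` relying on the PoisonError conversion.
fn read_len(lock: &RwLock<Vec<u8>>) -> Result<usize, StoreError> {
    let guard = lock.read()?;
    Ok(guard.len())
}

fn main() {
    let lock = RwLock::new(vec![1, 2, 3]);
    println!("{:?}", read_len(&lock));
    println!("{}", StoreError::KeyNotFound("key1".to_string()));
}
```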
Our insert, get, and delete methods now need an upgrade. We’ll be returning Result<T, E> values from our
functions now to accommodate potential failures.
fn insert(&self, key: String, value: Value) -> Result<(), StoreError> {
    let mut locked = self.data.write()?;
    locked.insert(key, value);
    Ok(())
}

fn get(&self, key: &str) -> Result<Option<Value>, StoreError> {
    let locked = self.data.read()?;
    Ok(locked.get(key).cloned()) // Clone the value to return an owned copy
}

fn delete(&self, key: &str) -> Result<(), StoreError> {
    let mut locked = self.data.write()?;
    if locked.remove(key).is_none() {
        return Err(StoreError::KeyNotFound(key.to_string()));
    }
    Ok(())
}
We’ve removed the use of unwrap now, swapping it for the ? operator. This propagates any failure up to the
calling code, where it can actually be handled.
Using the File System
We need to be able to persist the state of our key value store out to disk for durability. In order to do this, we need
to keep track of where we’ll write the file. We add a file_path member to our structure:
Starting out this implementation simply, we just write a load and save function that we can call at any time. Before
we do this we need some extra dependencies added for serialisation:
This will allow us to reduce our internal state to JSON.
Loading the database off disk
/// Load the state from a file
fn load(&self) -> Result<(), StoreError> {
    if let Some(ref path) = self.file_path {
        match fs::read_to_string(path) {
            Ok(contents) => {
                let deserialized: HashMap<String, Value> = serde_json::from_str(&contents)?;
                let mut locked = self.data.write()?;
                *locked = deserialized; // Replace the current state with the loaded one
                Ok(())
            }
            Err(e) if e.kind() == ErrorKind::NotFound => {
                // File doesn't exist, just return Ok (no data to load)
                Ok(())
            }
            Err(e) => Err(e.into()),
        }
    } else {
        Err(StoreError::IoError("File path not set".to_string()))
    }
}
We need to make sure that a file_path was specified. We read everything off from the file into contents as a big
string. Using serde_json::from_str we can turn that contents into the deserialised representation. From there, we
simply swap out the underlying content.
We’ve got some new errors to deal with here in IoError.
This will be used for our write implementation which looks like this:
/// Save the current state to a file
fn save(&self) -> Result<(), StoreError> {
    if let Some(ref path) = self.file_path {
        let locked = self.data.read()?;
        let serialized = serde_json::to_string(&*locked)?;
        fs::write(path, serialized)?;
        Ok(())
    } else {
        Err(StoreError::IoError("File path not set".to_string()))
    }
}
The magic here really is serde_json::to_string, which takes our internal state and serialises it as JSON.
Finally, we’ll add some networking to the solution. A really basic network interface will allow remote clients to
perform the get, set, and delete operations for us.
The handle_client function is the heart of the server process, performing the needed processing on incoming requests
and routing them to the database instance:
fn handle_client(mut stream: TcpStream, store: Arc<KeyValueStore>) {
    let mut buffer = [0; 512];

    // Read the incoming request
    match stream.read(&mut buffer) {
        Ok(n) => {
            // Only interpret the bytes that were actually read
            let request = String::from_utf8_lossy(&buffer[..n]);
            let mut parts = request.split_whitespace();
            let command = parts.next();

            let response = match command {
                Some("SET") => {
                    let key = parts.next().unwrap_or_default().to_string();
                    let value = parts.next().unwrap_or_default().to_string();
                    match store.insert(key, Value::String(value)) {
                        Ok(()) => "OK\n".to_string(),
                        Err(_) => "Error\n".to_string(),
                    }
                }
                Some("GET") => {
                    let key = parts.next().unwrap_or_default();
                    if let Ok(Some(value)) = store.get(key) {
                        format!("{:?}\n", value)
                    } else {
                        "Key not found\n".to_string()
                    }
                }
                Some("DEL") => {
                    let key = parts.next().unwrap_or_default();
                    match store.delete(key) {
                        Ok(()) => "OK\n".to_string(),
                        Err(_) => "Error\n".to_string(),
                    }
                }
                _ => "Unknown command\n".to_string(),
            };

            // Send the response back to the client
            if let Err(e) = stream.write_all(response.as_bytes()) {
                eprintln!("Failed to write to socket: {}", e);
            }
        }
        Err(e) => eprintln!("Failed to read from socket: {}", e),
    }
}
Our networking “protocol” looks like this:
-- set the key "key1" to the value "hello"
SET key1 hello
-- get the value of the key "key1"
GET key1
-- remove the value and key "key1"
DEL key1
We read in the request data from the client into request. This gets split up on whitespace into parts, with command
given the first of these parts. The code expects command to be either SET, GET, or DEL, which is then
handled in the following pattern match.
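The parsing step can also be pulled out into a small pure function, which makes the protocol easy to test without a socket. The following is a sketch of that idea; the Command enum and parse_command function are hypothetical names, and unlike the handler this version rejects requests with missing arguments instead of substituting empty strings.

```rust
/// One parsed request line. Borrowed from the request text, so no copying.
#[derive(Debug, PartialEq)]
enum Command<'a> {
    Set { key: &'a str, value: &'a str },
    Get { key: &'a str },
    Del { key: &'a str },
    Unknown,
}

/// Splits a request on whitespace and maps it onto a Command.
fn parse_command(request: &str) -> Command<'_> {
    let mut parts = request.split_whitespace();
    // Tuple fields are evaluated left to right: command, key, value.
    match (parts.next(), parts.next(), parts.next()) {
        (Some("SET"), Some(key), Some(value)) => Command::Set { key, value },
        (Some("GET"), Some(key), _) => Command::Get { key },
        (Some("DEL"), Some(key), _) => Command::Del { key },
        _ => Command::Unknown,
    }
}

fn main() {
    println!("{:?}", parse_command("SET key1 hello\r\n"));
    println!("{:?}", parse_command("GET key1\r\n"));
    println!("{:?}", parse_command("DEL key1\r\n"));
    println!("{:?}", parse_command("PING\r\n"));
}
```

Because split_whitespace treats the trailing carriage return and newline sent by telnet as separators, the parsed key and value come out clean.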
This function gets mounted onto the server in the main function which now looks like this:
fn main() {
    let store = Arc::new(KeyValueStore::new(None));
    let listener = TcpListener::bind("127.0.0.1:7878").unwrap();

    println!("Server running on 127.0.0.1:7878");

    for stream in listener.incoming() {
        match stream {
            Ok(stream) => {
                let store = Arc::clone(&store);
                std::thread::spawn(move || handle_client(stream, store));
            }
            Err(e) => eprintln!("Connection failed: {}", e),
        }
    }
}
We’re starting our server on port 7878 and handling each connection with our handle_client function.
Running this and giving it a test with telnet gives us the following:
➜ telnet 127.0.0.1 7878
Trying 127.0.0.1...
Connected to 127.0.0.1.
Escape character is '^]'.
SET key1 hello
OK
Connection closed by foreign host.
➜ telnet 127.0.0.1 7878
Trying 127.0.0.1...
Connected to 127.0.0.1.
Escape character is '^]'.
GET key1
String("hello")
Connection closed by foreign host.
➜ telnet 127.0.0.1 7878
Trying 127.0.0.1...
Connected to 127.0.0.1.
Escape character is '^]'.
DEL key1
OK
Connection closed by foreign host.
➜ telnet 127.0.0.1 7878
Trying 127.0.0.1...
Connected to 127.0.0.1.
Escape character is '^]'.
GET key1
Key not found
Connection closed by foreign host.
So, it works. It’s crude and needs to be patched to be a little more production ready than this - but this is a start.
Conclusion
In this article, we walked through building a thread-safe, persistent key-value store in Rust. We started with a simple
in-memory implementation and iteratively improved it by:
Adding support for multiple data types using an enum.
Ensuring thread safety with RwLock and Arc.
Replacing unwrap with proper error handling.
Adding file persistence using JSON serialization and deserialization.
Adding some basic network access.
This provides a solid foundation for a more robust and scalable key-value server. Next steps could include:
Implementing advanced features like snapshots or replication.
Optimizing for performance with tools like async I/O or a custom storage engine.
Rust, known for its performance, memory safety, and low-level control, is gaining traction in domains traditionally
dominated by Python, such as machine learning (ML). While Python is the go-to for prototyping ML models due to its
mature ecosystem, Rust shines in scenarios demanding high performance, safety, and seamless system-level integration.
In this post, we’ll explore how to implement logistic regression in Rust and discuss the implications of the model’s
output.
Why use Rust?
Before diving into code, it’s worth asking: why choose Rust for ML when Python’s libraries like TensorFlow and PyTorch
exist?
Benefits of Rust:
Performance: Rust offers near-C speeds, making it ideal for performance-critical tasks.
Memory Safety: Its ownership model ensures memory safety, preventing bugs like segmentation faults and data races.
Integration: Rust can easily integrate with low-level systems, making it a great choice for embedding ML models into IoT, edge devices, or game engines.
Control: Rust provides fine-grained control over execution, allowing developers to optimize their models at a deeper level.
While Rust’s ML ecosystem is still evolving, libraries like ndarray, linfa, and smartcore provide foundational
tools for implementing machine learning models.
Logistic Regression
Logistic regression is a simple yet powerful algorithm for binary classification. It predicts whether a data point
belongs to class 0 or 1 based on a weighted sum of features passed through a sigmoid function.
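Before bringing in any crates, the prediction step itself can be sketched in plain Rust. The function names predict_probability and predict_class, and the 0.5 decision threshold, are assumptions chosen for illustration.

```rust
/// Squashes any real number into the open interval (0, 1).
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

/// Probability that `features` belongs to class 1:
/// the weighted sum of the features plus the bias, passed through the sigmoid.
fn predict_probability(weights: &[f64], bias: f64, features: &[f64]) -> f64 {
    let z = weights
        .iter()
        .zip(features)
        .map(|(w, x)| w * x)
        .sum::<f64>()
        + bias;
    sigmoid(z)
}

/// Thresholds the probability at 0.5 to pick a class label.
fn predict_class(weights: &[f64], bias: f64, features: &[f64]) -> u8 {
    if predict_probability(weights, bias, features) >= 0.5 {
        1
    } else {
        0
    }
}

fn main() {
    let weights = [2.0, -1.0];
    let bias = 0.5;
    let point = [1.0, 0.5];
    println!("p = {:.4}", predict_probability(&weights, bias, &point));
    println!("class = {}", predict_class(&weights, bias, &point));
}
```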
Below is a Rust implementation of logistic regression using the ndarray crate for numerical operations.
use ndarray::{Array2, Array1};
use ndarray_rand::RandomExt;
use ndarray_rand::rand_distr::Uniform;

fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

fn logistic_regression(X: &Array2<f64>, y: &Array1<f64>, learning_rate: f64, epochs: usize) -> Array1<f64> {
    let (n_samples, n_features) = X.dim();
    let mut weights = Array1::<f64>::random(n_features, Uniform::new(-0.01, 0.01));
    let mut bias = 0.0;

    for _ in 0..epochs {
        let linear_model = X.dot(&weights) + bias;
        let predictions = linear_model.mapv(sigmoid);

        // Compute the error
        let error = &predictions - y;

        // Compute gradients
        let gradient_weights = X.t().dot(&error) / n_samples as f64;
        let gradient_bias = error.sum() / n_samples as f64;

        // Update weights and bias
        weights -= &(learning_rate * gradient_weights);
        bias -= learning_rate * gradient_bias;
    }

    weights
}

fn main() {
    let X = Array2::random((100, 2), Uniform::new(-1.0, 1.0)); // Random features
    let y = Array1::random(100, Uniform::new(0.0, 1.0)).mapv(|v| if v > 0.5 { 1.0 } else { 0.0 }); // Random labels

    let weights = logistic_regression(&X, &y, 0.01, 1000);
    println!("Trained Weights: {:?}", weights);
}
Key Concepts:
Sigmoid Function: Converts the linear combination of inputs into a value between 0 and 1.
Gradient Descent: Updates weights and bias iteratively to minimize the error between predictions and actual labels.
Random Initialization: Weights start with small random values and are fine-tuned during training.
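To make the gradient descent update concrete, here is a dependency-free sketch of a single training step over plain vectors. It uses the same update rule as a batch logistic regression, but the function name gradient_step and the data layout are assumptions, and weights start at zero rather than at small random values to keep the example deterministic.

```rust
fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

/// One batch gradient-descent step: the gradient is the average of
/// (prediction - label) * feature over all samples.
fn gradient_step(
    features: &[Vec<f64>],
    labels: &[f64],
    weights: &mut [f64],
    bias: &mut f64,
    learning_rate: f64,
) {
    let n = features.len() as f64;
    let mut grad_w = vec![0.0; weights.len()];
    let mut grad_b = 0.0;

    for (row, &y) in features.iter().zip(labels) {
        let z = row
            .iter()
            .zip(weights.iter())
            .map(|(x, w)| x * w)
            .sum::<f64>()
            + *bias;
        let error = sigmoid(z) - y;
        for (g, x) in grad_w.iter_mut().zip(row) {
            *g += error * x;
        }
        grad_b += error;
    }

    // Move against the averaged gradient.
    for (w, g) in weights.iter_mut().zip(&grad_w) {
        *w -= learning_rate * g / n;
    }
    *bias -= learning_rate * grad_b / n;
}

fn main() {
    // Tiny separable dataset: negative inputs are class 0, positive are class 1.
    let features = vec![vec![-1.0], vec![1.0]];
    let labels = vec![0.0, 1.0];
    let mut weights = vec![0.0];
    let mut bias = 0.0;

    for _ in 0..1000 {
        gradient_step(&features, &labels, &mut weights, &mut bias, 0.1);
    }
    println!("weights = {:?}, bias = {}", weights, bias);
}
```

On this toy dataset the single weight grows positive with each step, so positive inputs are pushed toward class 1 and negative inputs toward class 0.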
Output
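When you run the code, it prints the trained weights. Because the features and labels are generated randomly, the
exact values will differ from run to run.
To classify a data point, pass its weighted sum through the sigmoid: predictions close to 1 indicate class 1, while predictions close to 0 indicate class 0.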
Why Does This Matter?
This simple implementation demonstrates the flexibility and control Rust provides for machine learning tasks. While
Python excels in rapid prototyping, Rust’s performance and safety make it ideal for deploying models in production,
especially in resource-constrained or latency-critical environments.
When Should You Use Rust for ML?
Rust is a great choice if:
Performance is critical: For example, in real-time systems or embedded devices.
Memory safety is a priority: Safe Rust rules out common bugs like use-after-free errors and data races.
Integration with system-level components is needed: Rust can seamlessly work in environments where Python may not be ideal.
Custom ML Implementations: You want more control over how the algorithms are built and optimized.
For research or quick prototyping, Python remains the best choice due to its rich ecosystem and community. However,
for production-grade systems, Rust’s strengths make it a compelling alternative.
Conclusion
While Rust’s machine learning ecosystem is still maturing, it’s already capable of handling fundamental ML tasks like
logistic regression. By combining performance, safety, and control, Rust offers a unique proposition for ML developers
building high-performance or production-critical applications.
In this blog post, we’ll dive into the fascinating world of ARM
assembly programming by writing and running a basic bootloader. ARM’s
dominance in mobile and embedded systems makes it an essential architecture to understand for developers working with
low-level programming or optimization.
We’re going to use QEMU, an open-source emulator, so we can develop and test our code right on a PC. We won’t need any
hardware (just yet).
What is ARM?
ARM, short for Advanced RISC Machine, is a family of Reduced
Instruction Set Computing (RISC) architectures.
ARM processors power billions of devices, from smartphones
and tablets to embedded systems and IoT devices. Its popularity stems from its energy efficiency and simplicity
compared to complex instruction set computing (CISC)
architectures like x86.
Why Emulation?
While ARM assembly is usually executed on physical devices, emulation tools like QEMU allow
you to:
Test code without requiring hardware.
Experiment with different ARM-based architectures and peripherals.
Debug programs more effectively using tools like GDB.
Supported ARM Hardware
Before we begin coding, let’s take a brief look at some popular ARM-based platforms:
Raspberry Pi: A widely used single-board computer.
BeagleBone Black: A powerful option for embedded projects.
STM32 Microcontrollers: Common in IoT and robotics applications.
Setup
Before we begin, we need to set up our development and build environment. I’m using Manjaro, so package
names might be slightly different for your distro of choice.
QEMU emulates a variety of hardware architectures, including ARM.
sudo pacman -S qemu-system-arm
Now we need to install the ARM toolchain which will include the assembler (as), linker (ld), and other essential tools.
sudo pacman -S arm-none-eabi-gcc arm-none-eabi-binutils
Now you should have everything you need to get going.
Bootloader
Our goal is to write a minimal ARM assembly program that outputs “Hello, World!” via the UART interface.
The Code
Here is the source code for our bootloader, saved as boot.s:
.section .text
.global _start
_start:
LDR R0, =0x101f1000 @ UART0 base address
LDR R1, =message @ Address of the message
LDR R2, =message_end @ Address of the end of the message
loop:
LDRB R3, [R1], #1 @ Load a byte from the message and increment the pointer
CMP R1, R2 @ Check if we’ve reached the end of the message
BEQ done @ If yes, branch to done
STRB R3, [R0] @ Output the character to UART0
B loop @ Repeat for the next character
done:
B done @ Infinite loop to prevent execution from going beyond
message:
.asciz "Hello, World!\n" @ Null-terminated string
message_end:
Breaking this down line by line, we get:
LDR R0, =0x101f1000: Load the memory address of UART0 (used for serial output) into register R0.
LDR R1, =message: Load the starting address of the message into R1.
LDR R2, =message_end: Load the end address of the message into R2.
After this setup, we move into a loop.
Load a byte from the message into R3 and advance the pointer in R1.
Compare R1 (current pointer) with R2 (end of the message), and branch to done if they are equal.
Otherwise, write the character to UART0 and repeat.
Finally, we finish up with an infinite loop to prevent the program from running into uninitialized memory.
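The control flow of that loop can be easier to follow when modelled in a higher-level language. The sketch below mirrors the assembly in Rust, with a Vec<u8> standing in for the UART0 data register; the function name emit_message is hypothetical, and the empty-message guard exists only in this model, not in the assembly.

```rust
/// Models the bootloader loop. `message` is every byte between the
/// `message` and `message_end` labels, including the NUL that `.asciz`
/// appends. `uart` stands in for the UART0 data register.
fn emit_message(message: &[u8], uart: &mut Vec<u8>) {
    let end = message.len(); // R2: address one past the last byte
    let mut cursor = 0; // R1: current read position

    // Guard for this model only; the assembly assumes a non-empty message.
    if end == 0 {
        return;
    }

    loop {
        let byte = message[cursor]; // LDRB R3, [R1], #1 (load the byte...)
        cursor += 1; //                ...then post-increment the pointer
        if cursor == end {
            break; // CMP R1, R2 / BEQ done
        }
        uart.push(byte); // STRB R3, [R0]
    }
}

fn main() {
    let mut uart = Vec::new();
    emit_message(b"Hello, World!\n\0", &mut uart);
    print!("{}", String::from_utf8_lossy(&uart));
}
```

Because the comparison happens after the pointer is incremented but before the store, the final byte (the NUL terminator) is loaded but never written to the UART.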
Building
First we need to assemble the code into an object file:
arm-none-eabi-as -o boot.o boot.s
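Next, we link the object file to produce an executable. A minimal invocation, assuming the output name boot.elf that the flashing step expects, is:
arm-none-eabi-ld -o boot.elf boot.o
For deployment, we’ll use a Raspberry Pi as an example. This process is similar for other ARM-based boards. Note that the UART base address used in boot.s (0x101f1000) is the one exposed by QEMU’s versatilepb machine; on real hardware, substitute the UART address documented for your board.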
Flashing
First, we need to convert the ELF file to a raw binary format suitable for booting:
arm-none-eabi-objcopy -O binary boot.elf boot.bin
Use a tool like dd to write the binary to an SD card:
dd if=boot.bin of=/dev/sdX bs=512 seek=2048
Running
Insert the SD card into the board.
Power up the device and connect to its UART output (e.g., using a USB-to-serial adapter).
You should see “Hello, World!” printed on the serial console.
Conclusion
Congratulations! You’ve successfully written, emulated, and deployed a simple ARM bootloader. Along the way, you learned:
How to write and debug ARM assembly.
How to use QEMU for emulation.
How to deploy code to real hardware.
From here, you can explore more advanced topics like interrupts, timers, or even writing a simple operating system kernel. The journey into ARM assembly has just begun!
Loss functions are the unsung heroes of machine learning. They guide the
learning process by quantifying the difference between the predicted and actual outputs. While frameworks like
PyTorch and TensorFlow offer a plethora of standard loss
functions such as Cross-Entropy and
Mean Squared Error, there are times when a custom loss function
is necessary.
In this post, we’ll explore the why and how of custom loss functions by:
Setting up a simple neural network.
Using standard loss functions to train the model.
Introducing and implementing custom loss functions tailored to specific needs.
Pre-reqs
Before we begin, you’ll need to set up a Python project and install some dependencies. We’ll be using PyTorch and
torchvision. To install these dependencies, use the following command:
pip install torch torchvision
Once installed, verify the installation by running:
python -c "import torch; print(torch.__version__)"
Network Setup
Let’s start by creating a simple neural network to classify data. For simplicity, we’ll use a toy dataset like the
MNIST digits dataset.
Dataset preparation
Use the MNIST dataset (handwritten digits) as an example.
Normalize the dataset for faster convergence during training.
import torch
import torch.optim as optim
from torchvision import datasets, transforms

# Data preparation
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_data = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
train_loader = torch.utils.data.DataLoader(train_data, batch_size=64, shuffle=True)
Model Architecture
Input layer flattens the 28x28 pixel images into a single vector.
Two hidden layers with 128 and 64 neurons, each followed by a ReLU activation.
An output layer with 10 neurons (one for each digit) and no activation (handled by the loss function).
# Simple Neural Network
import torch.nn as nn

class SimpleNN(nn.Module):
    def __init__(self):
        super(SimpleNN, self).__init__()
        self.fc1 = nn.Linear(28 * 28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)

    def forward(self, x):
        x = x.view(x.size(0), -1)  # Flatten the input
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x
Training Setup:
Use an optimizer (e.g., Adam) and CrossEntropyLoss for training.
Loop over the dataset for a fixed number of epochs, computing loss and updating weights.
# Initialize model, optimizer, and device
model = SimpleNN()
optimizer = optim.Adam(model.parameters(), lr=0.001)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
Standard Loss
Let’s train the model using the standard Cross-Entropy Loss, which is suitable for classification tasks.
Combines log_softmax and negative log likelihood into one step.
Suitable for classification tasks as it penalizes incorrect predictions heavily.
# Standard loss function
criterion = nn.CrossEntropyLoss()

# Training loop
def train_model(model, train_loader, criterion, optimizer, epochs=5):
    model.train()
    for epoch in range(epochs):
        total_loss = 0
        for images, labels in train_loader:
            images, labels = images.to(device), labels.to(device)

            # Forward pass
            outputs = model(images)
            loss = criterion(outputs, labels)

            # Backward pass and optimization
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()

        print(f'Epoch {epoch+1}/{epochs}, Loss: {total_loss/len(train_loader):.4f}')

train_model(model, train_loader, criterion, optimizer)
The output of this training session should look something like this:
Standard loss functions may not work well in cases like:
Imbalanced Datasets: Classes have significantly different frequencies.
Multi-Task Learning: Different tasks require different weights.
Task-Specific Goals: Optimizing for metrics like precision or recall rather than accuracy.
Example: Weighted Loss
Suppose we want to penalize misclassifying certain classes more heavily. We can achieve this by implementing a
weighted Cross-Entropy Loss.
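In symbols, one way to write such a weighted loss is shown here. The notation is introduced for illustration: N is the batch size, C the number of classes, z_{i,c} the logit for sample i and class c, y_i the true label of sample i, and w_c the weight assigned to class c.

```latex
\mathcal{L}_{\text{weighted}}
  = -\frac{1}{N} \sum_{i=1}^{N} w_{y_i} \,
    \log \frac{\exp(z_{i, y_i})}{\sum_{c=0}^{C-1} \exp(z_{i, c})}
```

Setting every w_c to 1 recovers the standard mean Cross-Entropy Loss. This form divides by the batch size N; PyTorch's built-in nn.CrossEntropyLoss with a weight argument and mean reduction instead divides by the sum of the weights of the targets in the batch, so the two can differ by a scale factor.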
# Custom weighted loss function
class WeightedCrossEntropyLoss(nn.Module):
    def __init__(self, class_weights):
        super(WeightedCrossEntropyLoss, self).__init__()
        self.class_weights = torch.tensor(class_weights).to(device)

    def forward(self, outputs, targets):
        log_probs = torch.log_softmax(outputs, dim=1)
        loss = -torch.sum(self.class_weights[targets] * log_probs[range(len(targets)), targets]) / len(targets)
        return loss

# Example: Higher weight for class 0
class_weights = [2.0 if i == 0 else 1.0 for i in range(10)]
custom_criterion = WeightedCrossEntropyLoss(class_weights)

# Training with custom loss function
train_model(model, train_loader, custom_criterion, optimizer)
After running this, you should see output like the following:
Sometimes, you might want to combine multiple objectives into a single loss function.
# Custom loss combining Cross-Entropy and L1 regularization
class CombinedLoss(nn.Module):
    def __init__(self, alpha=0.1):
        super(CombinedLoss, self).__init__()
        self.ce_loss = nn.CrossEntropyLoss()
        self.alpha = alpha

    def forward(self, outputs, targets, model):
        ce_loss = self.ce_loss(outputs, targets)
        l1_loss = sum(torch.sum(torch.abs(param)) for param in model.parameters())
        return ce_loss + self.alpha * l1_loss

custom_criterion = CombinedLoss(alpha=0.01)

# Training with combined loss
train_model(model, train_loader, lambda outputs, targets: custom_criterion(outputs, targets, model), optimizer)
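Written out, the combined objective adds an L1 penalty over the model parameters to the Cross-Entropy term. The symbols here are introduced for illustration: Θ is the set of parameter tensors of the model, and α is the regularization strength.

```latex
\mathcal{L}_{\text{combined}}
  = \mathcal{L}_{\text{CE}}
  + \alpha \sum_{\theta \in \Theta} \lVert \theta \rVert_1,
\qquad
\lVert \theta \rVert_1 = \sum_{j} \lvert \theta_j \rvert
```

Larger values of α push more weights toward zero at the cost of a higher Cross-Entropy term, so α is usually tuned alongside the learning rate.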
Comparing Results
To compare the results of standard and custom loss functions, you need to evaluate the following:
Training Loss:
Plot the loss per epoch for both standard and custom loss functions.
Accuracy:
Measure training and validation accuracy after each epoch.
Compare how well the model performs in predicting each class.
Precision and Recall:
Useful for imbalanced datasets to measure performance on minority classes.
Visualization:
Confusion matrix: Visualize how often each class is misclassified.
Loss curve: Show convergence speed and stability for different loss functions.
We can use graphs to visualise how these metrics perform:
from sklearn.metrics import classification_report, confusion_matrix
import matplotlib.pyplot as plt
import numpy as np

# After training
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for images, labels in train_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model(images)
        preds = torch.argmax(outputs, dim=1)
        all_preds.extend(preds.cpu().numpy())
        all_labels.extend(labels.cpu().numpy())

# Confusion Matrix
cm = confusion_matrix(all_labels, all_preds)
plt.imshow(cm, cmap='Blues')
plt.title('Confusion Matrix')
plt.colorbar()
plt.show()

# Classification Report
print(classification_report(all_labels, all_preds))
We can also produce visualisations of our loss curves:
# Assuming loss values are stored during training
plt.plot(range(len(train_losses)), train_losses, label="Standard Loss")
plt.plot(range(len(custom_losses)), custom_losses, label="Custom Loss")
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.title('Loss Curve')
plt.show()
Conclusion
Custom loss functions empower you to fine-tune your neural networks for unique problems. By carefully designing and
experimenting with loss functions, you can align your model’s learning process with the specific goals of your
application.
Some closing tips for custom loss functions:
Always start with a simple baseline (e.g., Cross-Entropy Loss) to understand your model’s behavior.
Visualize performance across metrics, especially when using weighted or multi-objective losses.
Experiment with different weights and loss combinations to find the optimal setup for your task.
The key is to balance complexity and interpretability—sometimes, even simple tweaks can significantly impact
performance.