An Overview of PyTorch Lightning with a Simple Code Walkthrough
As you continue your research into the Python programming language and explore the many frameworks in its ecosystem, you may come across PyTorch Lightning. You may even ask yourself, "Why use PyTorch Lightning when there are so many other frameworks out there, especially if I am already using PyTorch?"
It is a fair question, and this guide will help you understand why PyTorch Lightning could be effective for your AI project and how to get started with it. The guide covers three key topics:
- First, whether PyTorch Lightning is actually better than plain PyTorch.
- Second, how switching from PyTorch to PyTorch Lightning can help you succeed in your next project.
- Finally, how to use PyTorch Lightning in practice, so you can put it to work in your next AI project.
Is PyTorch Lightning Better Than PyTorch?
When deciding whether to use PyTorch Lightning instead of another PyTorch-based framework, a natural question is whether PyTorch Lightning is better than PyTorch itself. Before answering, though, we need to understand what PyTorch and PyTorch Lightning actually are.
PyTorch is an open-source framework for machine learning. It is based on the Torch library and is widely used for applications such as computer vision and natural language processing.
PyTorch Lightning is a high-level Python framework built on top of PyTorch. It was designed with academic researchers in mind, so they could experiment with novel deep learning and machine learning models. More specifically, PyTorch Lightning makes machine learning code easier to scale, so researchers can build AI models more quickly and efficiently.
With this in mind, an academic focused on intensive research and AI development would benefit greatly from PyTorch Lightning. It lets you experiment and advance your research far more quickly than many other PyTorch frameworks. However, it has non-academic use cases as well.
While it may be impossible to simply say that PyTorch Lightning is better or worse than PyTorch, it is possible to say that, in some circumstances, PyTorch Lightning is the best fit for the job.
Switching from PyTorch to PyTorch Lightning
Switching from PyTorch to PyTorch Lightning can feel tricky. However, at its core, PyTorch Lightning simply streamlines and cleans up your code.
It might feel like magic, but PyTorch Lightning simply reorganizes the boilerplate in your PyTorch code so that it follows a consistent structure rather than an ad hoc one. Most importantly, it does this for each of the loops in model training: training, validation, and testing.
All the code that stays the same from one training run to the next is reorganized and abstracted away. As a result, your code looks cleaner, is easier to read, and makes flaws or errors easier to track down. It also makes it faster for others to iterate from one AI model to the next.
Switching from PyTorch to PyTorch Lightning is mostly a matter of getting used to seeing that boilerplate replaced by a few structured methods. If you get that structure right the first time, the rest of the process should go much more smoothly.
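To make the idea concrete, here is a deliberately tiny, framework-free sketch of the pattern in plain Python. It is an illustration under simplifying assumptions, not how Lightning is implemented. The hypothetical ToyModule only describes the model: a loss and its gradient in training_step, and an optimizer in configure_optimizers. The hypothetical ToyTrainer owns the boilerplate of looping over epochs and batches and calling the optimizer.

```python
# Hypothetical toy classes illustrating the "trainer owns the loop" pattern.
# These are NOT Lightning APIs; the hook names merely echo Lightning's.


class SGD:
    """Plain gradient descent over a shared dict of scalar parameters."""

    def __init__(self, params, lr):
        self.params = params  # same dict object the module holds
        self.lr = lr

    def step(self, grads):
        for name, g in grads.items():
            self.params[name] -= self.lr * g


class ToyModule:
    """Fits y = w * x with a mean squared-error loss (the 'science' part)."""

    def __init__(self):
        self.params = {"w": 0.0}

    def training_step(self, batch):
        # Return the batch loss and its gradient with respect to w.
        xs, ys = batch
        n = len(xs)
        w = self.params["w"]
        loss = sum((w * x - y) ** 2 for x, y in zip(xs, ys)) / n
        grad_w = sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / n
        return loss, {"w": grad_w}

    def configure_optimizers(self):
        return SGD(self.params, lr=0.01)


class ToyTrainer:
    """Owns the engineering: epochs, the batch loop, optimizer calls, history."""

    def __init__(self, max_epochs):
        self.max_epochs = max_epochs
        self.history = []

    def fit(self, module, batches):
        optimizer = module.configure_optimizers()
        for _ in range(self.max_epochs):
            for batch in batches:
                loss, grads = module.training_step(batch)
                optimizer.step(grads)
                self.history.append(loss)
        return module


# Toy data generated from y = 3x, split into two batches.
batches = [([1.0, 2.0], [3.0, 6.0]), ([3.0, 4.0], [9.0, 12.0])]
trainer = ToyTrainer(max_epochs=200)
model = trainer.fit(ToyModule(), batches)
print(round(model.params["w"], 3))
```

To try a different model, you would change only ToyModule. The loop in ToyTrainer.fit stays untouched, which is the kind of separation Lightning applies to real PyTorch code.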
Another benefit of switching from PyTorch to PyTorch Lightning is that it comes with a suite of built-in features, such as progress bars and checkpointing, at no extra cost. Why use PyTorch Lightning? For the sake of simplicity!
How to Use PyTorch Lightning
Getting started with PyTorch Lightning takes only a few steps, and those steps will save you a lot of time down the road. Like many other Python projects, PyTorch Lightning installs with pip (the package installer for Python).
We recommend using your favorite virtual environment manager to handle installs and dependencies without cluttering your main Python installation. Once it is installed, running PyTorch Lightning is fairly straightforward.
Using PyTorch Lightning is similar to using raw PyTorch. As mentioned above, the main difference is that you no longer have to write the training boilerplate yourself. Beyond that, your model class inherits from LightningModule instead of nn.Module. Because LightningModule is itself a subclass of nn.Module, your layers and forward method work as before. PyTorch Lightning then handles the engineering around training, such as the training loop and checkpointing.
Another important point is that PyTorch Lightning is hardware agnostic. Depending on the kind of model you are building and the research you are doing, you can run the same model code on CPUs, GPUs, or other accelerators such as TPUs, usually by changing only the Trainer's arguments. This makes PyTorch Lightning one of the most flexible options for creating reproducible AI models.
If you are wondering how to use PyTorch Lightning, the good news is that it is as simple as using raw PyTorch. If you are still wondering why you should use PyTorch Lightning instead of raw PyTorch, or any other framework, the primary reasons are speed, efficiency, and reproducibility.
A Simple Code Walkthrough Example
Let's first install Lightning:
```shell
pip install pytorch-lightning
```
You also need PyTorch installed in the same system or environment. As mentioned above, the key to organizing code with Lightning is the LightningModule class. You define a class that inherits from it and build your model there.
The initialization may look like the following:
```python
import torch
from torch import nn
from torch.nn import functional as F
from pytorch_lightning import LightningModule


class LitMNIST(LightningModule):
    def __init__(self):
        super().__init__()
        # MNIST images are (1, 28, 28): (channels, height, width)
        self.layer_1 = nn.Linear(28 * 28, 128)
        self.layer_2 = nn.Linear(128, 256)
        self.layer_3 = nn.Linear(256, 10)
```
Here you are building a network of three fully connected (linear) layers:
- the first maps the flattened 28 × 28 = 784-pixel input (e.g., an MNIST image) to 128 units,
- the second maps those 128 units to a hidden layer of 256 units,
- the third maps the 256 units to 10 outputs, one per digit class.
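As a quick sanity check on the size of this network, you can work out the parameter count by hand. The sketch below assumes each nn.Linear keeps its default bias=True. Under that assumption, a layer with in_features inputs and out_features outputs holds in_features × out_features weights plus out_features biases.

```python
def linear_params(in_features, out_features):
    """Weights plus biases of a fully connected layer with bias=True."""
    return in_features * out_features + out_features


# (in_features, out_features) for layer_1, layer_2, layer_3 in LitMNIST
layer_sizes = [(28 * 28, 128), (128, 256), (256, 10)]
per_layer = [linear_params(i, o) for i, o in layer_sizes]
total = sum(per_layer)
print(per_layer, total)  # [100480, 33024, 2570] 136074
```

In PyTorch itself, `sum(p.numel() for p in LitMNIST().parameters())` should report the same total.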
Next, inside the class, you define the forward pass, just as you would in plain PyTorch:
```python
def forward(self, x):
    batch_size, channels, height, width = x.size()

    # (b, 1, 28, 28) -> (b, 1*28*28)
    x = x.view(batch_size, -1)
    x = self.layer_1(x)
    x = F.relu(x)
    x = self.layer_2(x)
    x = F.relu(x)
    x = self.layer_3(x)
    x = F.log_softmax(x, dim=1)
    return x
```
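The final F.log_softmax call turns the 10 raw scores from layer_3 into log-probabilities. The plain-Python sketch below reproduces that computation for a single example so you can see what it does. It is an illustration only: the scores are made up, and PyTorch's own kernel is what actually runs. It uses the usual max-subtraction trick so that large scores do not overflow exp().

```python
import math


def log_softmax(scores):
    """Numerically stable log-softmax over a list of raw scores."""
    m = max(scores)  # subtracting the max keeps exp() from overflowing
    log_sum = m + math.log(sum(math.exp(s - m) for s in scores))
    return [s - log_sum for s in scores]


scores = [2.0, 1.0, 0.1]  # hypothetical layer_3 outputs for one image
log_probs = log_softmax(scores)
probs = [math.exp(lp) for lp in log_probs]  # back to probabilities
print([round(p, 3) for p in probs])

# Large scores are handled safely, whereas math.exp(1000.0) alone would overflow.
print(log_softmax([1000.0, 1000.0]))
```

The exponentiated outputs sum to 1 and keep the ordering of the original scores, which is why the output of forward can be read as a (log) probability distribution over the digits.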
After this, you need to define two more components. The first is the training step:
```python
def training_step(self, batch, batch_idx):
    x, y = batch
    logits = self(x)  # forward() returns log-probabilities
    loss = F.nll_loss(logits, y)
    return loss
```
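F.nll_loss expects exactly what forward returns: log-probabilities plus integer class labels. With its default reduction='mean', it averages the negative log-probability assigned to the correct class across the batch. Below is a minimal plain-Python sketch of that computation, using a made-up batch of two images and three classes for illustration.

```python
import math


def nll_loss(batch_log_probs, targets):
    """Mean negative log-probability of the correct class (reduction='mean')."""
    picked = [lp[t] for lp, t in zip(batch_log_probs, targets)]
    return -sum(picked) / len(picked)


# Hypothetical batch: each row already holds log-probabilities.
batch_log_probs = [
    [math.log(0.7), math.log(0.2), math.log(0.1)],  # true class 0
    [math.log(0.1), math.log(0.8), math.log(0.1)],  # true class 1
]
targets = [0, 1]
loss = nll_loss(batch_log_probs, targets)
print(round(loss, 4))
```

Because forward ends with log_softmax, this pairing computes the same value as applying F.cross_entropy to the raw layer_3 scores.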
Lastly, you define the optimizer for training. Here Lightning clearly differs from plain PyTorch style: instead of creating the optimizer in your training script, you return it from the configure_optimizers method. For example, to use the popular Adam optimizer:
```python
from torch.optim import Adam  # place with the other imports at the top of the file


def configure_optimizers(self):
    return Adam(self.parameters(), lr=1e-3)
```
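To see what the optimizer returned from configure_optimizers does on each step, here is a plain-Python sketch of the Adam update rule for a single scalar parameter. The default hyperparameters mirror torch.optim.Adam (lr=1e-3, betas=(0.9, 0.999), eps=1e-8). The toy objective f(w) = (w - 3)^2 and the larger lr=0.05 used in the demo are assumptions, chosen so that it converges in a few hundred steps. The sketch also leaves out PyTorch's extras, such as weight_decay and amsgrad.

```python
import math


def adam_minimize(grad_fn, w, lr=1e-3, betas=(0.9, 0.999), eps=1e-8, steps=1000):
    """Adam update rule (Kingma & Ba) applied to one scalar parameter."""
    beta1, beta2 = betas
    m, v = 0.0, 0.0
    for t in range(1, steps + 1):
        g = grad_fn(w)
        m = beta1 * m + (1 - beta1) * g      # running mean of gradients
        v = beta2 * v + (1 - beta2) * g * g  # running mean of squared gradients
        m_hat = m / (1 - beta1 ** t)         # bias corrections for the zero init
        v_hat = v / (1 - beta2 ** t)
        w -= lr * m_hat / (math.sqrt(v_hat) + eps)
    return w


# Minimize f(w) = (w - 3)^2, whose gradient is 2 * (w - 3).
w_final = adam_minimize(lambda w: 2 * (w - 3), w=0.0, lr=0.05, steps=500)
print(round(w_final, 3))
```

After the bias correction, Adam's very first step has size close to lr regardless of the gradient's magnitude. This is one reason a fixed learning rate such as 1e-3 works across many models.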
How should you load data for training? The recommended approach is to wrap your dataset in PyTorch's DataLoader class. Here is the code to get the data:
```python
import os

from torch.utils.data import DataLoader
from torchvision import transforms
from torchvision.datasets import MNIST

# Prepare the transforms standard for MNIST: scale to [0, 1], then normalize
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# Download the training split into the current working directory
mnist_train = MNIST(os.getcwd(), train=True, download=True, transform=transform)
mnist_train = DataLoader(mnist_train, batch_size=64)
```
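The two constants passed to Normalize are the mean and standard deviation commonly used for MNIST pixel intensities, after ToTensor has scaled those intensities to [0, 1]. Below is a quick plain-Python check of what the pipeline does to the darkest and brightest pixel values. It sketches the arithmetic only and is not torchvision's implementation.

```python
def to_tensor_scale(pixel):
    """ToTensor maps 8-bit pixels in [0, 255] to floats in [0.0, 1.0]."""
    return pixel / 255.0


def normalize(value, mean=0.1307, std=0.3081):
    """Normalize((mean,), (std,)) computes (value - mean) / std per channel."""
    return (value - mean) / std


for pixel in (0, 255):
    print(pixel, round(normalize(to_tensor_scale(pixel)), 4))
```

After normalization, inputs lie roughly between -0.42 and 2.82 and are approximately centred on zero, which tends to make optimization better behaved.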
Note that, after the last line of the data-loading code, mnist_train is an instance of the DataLoader class rather than the raw dataset. Finally, you can pass the model and this object to Lightning's fit method to start training:
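```python
from pytorch_lightning import Trainer

model = LitMNIST()
# Set max_epochs explicitly: if neither max_epochs nor max_steps is given,
# Lightning falls back to 1000 epochs
trainer = Trainer(max_epochs=1)
trainer.fit(model, mnist_train)
```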
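trainer.fit calls training_step once per batch. Assuming the standard 60,000-image MNIST training split and DataLoader's default drop_last=False, a short calculation shows how many steps one epoch takes and how large the final partial batch is:

```python
import math

num_train_images = 60_000  # size of the MNIST training split
batch_size = 64            # as passed to DataLoader above

# With drop_last=False (the default), the last, smaller batch is kept.
steps_per_epoch = math.ceil(num_train_images / batch_size)
last_batch_size = num_train_images - (steps_per_epoch - 1) * batch_size
print(steps_per_epoch, last_batch_size)  # 938 32
```

That per-epoch batch count should match what Lightning's progress bar counts up to during each epoch.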
For more details and examples, see the official PyTorch Lightning documentation.
Looking For More Information On PyTorch and PyTorch Lightning?
As you can see, we are big fans of PyTorch Lightning, especially for people who want to experiment with and research what machine learning and deep learning models can do. For research or purely academic pursuits, we think there is no better next step than moving from PyTorch to PyTorch Lightning.
What do you think, though? Have we answered why use PyTorch Lightning at all? Or is there something you think we missed? We would love to hear from you and answer any questions you may have or help you get started on your next PyTorch Lightning project.
Interested in more PyTorch Lightning tutorials? Check out the PyTorch Lightning website for more great walkthroughs and to see how the growing community is using it.
Feel free to contact us if you have any questions or take a look at our Deep Learning Solutions if you're interested in a workstation or server to run PyTorch/PyTorch Lightning on.