You can install AIHWKIT-Lightning with
pip install git+https://github.com/IBM/aihwkit-lightning.git
or with
git clone [email protected]:IBM/aihwkit-lightning.git
cd aihwkit-lightning
pip install -e .
For the triton mode to work, you need a GPU with compute capability >= 7.0 (e.g. V100, A100, H100) and triton installed.
Currently, only the nightly version works, which can be installed using
pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly
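To quickly check that your setup meets these requirements, you can run a small snippet like the following (an optional sketch assuming a CUDA-capable machine; it is not part of the toolkit itself):
import torch
# Compute capability must be at least 7.0 (V100, A100, H100, ...).
major, minor = torch.cuda.get_device_capability()
print(f"GPU compute capability: {major}.{minor}")
# Raises ImportError if the nightly install above did not succeed.
import triton
print(f"triton version: {triton.__version__}")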
from torch import Tensor
from torch.nn.functional import mse_loss
from torch.optim import SGD
# Import the aihwkit constructs.
from aihwkit_lightning.nn import AnalogLinear
from aihwkit_lightning.optim import AnalogOptimizer
from aihwkit_lightning.simulator.configs import TorchInferenceRPUConfig
x = Tensor([[0.1, 0.2, 0.4, 0.3], [0.2, 0.1, 0.1, 0.3]])
y = Tensor([[1.0, 0.5], [0.7, 0.3]])
# Define a network using a single Analog layer.
model = AnalogLinear(4, 2, rpu_config=TorchInferenceRPUConfig())
# Use the analog-aware stochastic gradient descent optimizer.
opt = AnalogOptimizer(SGD, model.analog_layers(), model.parameters(), lr=0.01)
# Train the network.
for epoch in range(10):
    # Reset gradients from the previous iteration.
    opt.zero_grad()
    pred = model(x)
    loss = mse_loss(pred, y)
    loss.backward()
    opt.step()
    print(f"Loss error: {loss:.4f}")
The examples folder contains examples that show how to use AIHWKIT-Lightning:
- Huggingface shows how to train a network with Huggingface and AIHWKIT-Lightning.
- DeepSpeed + AIHWKIT-Lightning shows how to integrate AIHWKIT-Lightning with DeepSpeed.
- SLURM + DeepSpeed + Huggingface Accelerate + AIHWKIT-Lightning shows how to do multi-node training of a language model using DeepSpeed, Slurm, Huggingface Accelerate and AIHWKIT-Lightning.
One can easily convert any model trained with AIHWKIT-Lightning to AIHWKIT.
from aihwkit_lightning.nn.export import export_to_aihwkit
# `model` is a model from AIHWKIT-Lightning
# `max_output_size` <= 0 means that we do not
# split layers along the output dimension
aihwkit_model = export_to_aihwkit(model=model, max_output_size=-1)
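As a quick sanity check (a minimal sketch rather than part of the export API; how closely the outputs agree depends on your RPU configuration and dtype), you can compare the exported model against the original one on the same input:
import torch
# `model` is the AIHWKIT-Lightning model, `aihwkit_model` the exported AIHWKIT model.
model.eval()
aihwkit_model.eval()
with torch.no_grad():
    lightning_out = model(x)
    aihwkit_out = aihwkit_model(x)
# In float32 the outputs should agree closely; half precision may need a
# larger tolerance (see the known limitations listed below).
print((lightning_out - aihwkit_out).abs().max())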
Before starting to write code for a possible contribution, please get in touch with us on Slack or by opening an issue. We can then discuss whether the proposed feature makes sense for this toolkit and how to proceed.
For bug-fixes, please follow the instructions below.
Install the development requirements.
pip install -r requirements_dev.txt
mypy --install-types
Create a fork from the main branch and make a well-documented PR. Make sure to run the following before submitting the PR:
make pytest
make black
make mypy
make pycodestyle
make pylint
All of these should pass.
AIHWKIT-Lightning can be accelerated using triton. This generally only makes sense when your layer sizes are in the thousands and when you want to split the layer into multiple tiles (splitting is only supported across the input dimension).
To enable triton for AnalogConv2d and AnalogLinear, either export AIHWKIT_USE_TRITON="1" in your environment or execute your script as AIHWKIT_USE_TRITON="1" python your_script.py. This feature is off by default.
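The flag can also be set from within Python (a minimal sketch; setting it before importing AIHWKIT-Lightning ensures it is visible by the time the layers are constructed):
import os
# Enable the experimental triton kernels for AnalogLinear / AnalogConv2d.
os.environ["AIHWKIT_USE_TRITON"] = "1"
from aihwkit_lightning.nn import AnalogLinear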
The triton kernel is generally faster than the normal PyTorch implementation, and much faster when you set mapping.max_input_size to something other than 0 or -1, i.e. when you split your weight matrix into tiles.
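For example, to split a large layer into tiles along the input dimension (a sketch assuming the mapping field of the TorchInferenceRPUConfig shown above; choose the tile size to match your target hardware):
from aihwkit_lightning.nn import AnalogLinear
from aihwkit_lightning.simulator.configs import TorchInferenceRPUConfig
rpu_config = TorchInferenceRPUConfig()
# A value other than 0 or -1 splits the weight matrix into tiles of at most
# 512 columns along the input dimension.
rpu_config.mapping.max_input_size = 512
layer = AnalogLinear(4096, 4096, rpu_config=rpu_config)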
However, some things are still not optimal. Until the following points are resolved, we consider the triton mode experimental.
- The sliced std() kernel that calculates the std() for slices of a tensor is not very fast. Fixing this would significantly speed up the scenario where we chunk the weight matrix along the input dimension.
- Tests checking the correctness against AIHWKIT are passing. Because we don't normalize inputs and weights, tests for float16 and bfloat16 only pass for a high atol. When the input is normalized (which is not needed and adds extra FLOPs), the tests also pass in half precision.
- Currently, torch.compile doesn't work when input range learning is activated, because a leaf variable requiring gradients gets updated in the forward pass.
- Input range learning is made up of three gradients: our own gradient plus the gradients resulting from the operations inp_slice = inp_slice / input_range[slice_idx] and out_slice *= input_range[slice_idx]. These two gradients are not accessible in triton mode. We verified that the downstream accuracy is not affected by this. The pure PyTorch version also leaves out these gradients, but the custom gradient is correct compared to AIHWKIT.
IBM Research has developed AIHWKIT-Lightning with Julian Büchel as the initial core author. You can contact us by opening a new issue in the repository.
@inproceedings{
aihwkitlightning,
title={AIHWKIT-Lightning: A Scalable HW-Aware Training Toolkit for Analog In-Memory Computing},
author={Julian Büchel and William Andrew Simon and Corey Lammie and Giovanni Acampa and Kaoutar El Maghraoui and Manuel Le Gallo and Abu Sebastian},
booktitle={NeurIPS 2024 Workshop Machine Learning with new Compute Paradigms},
year={2024},
url={https://openreview.net/forum?id=QNdxOgGmhR}
}
This project is licensed under the MIT License.