Last week, PyTorch introduced torchao (architecture optimisation), a native library designed to enhance model training and inference by leveraging low-bit data types, quantization, and sparsity.
According to the PyTorch team, torchao's quantization algorithms, applied to popular models like Llama 3 and diffusion models, have demonstrated up to a 97% speedup in inference and a 73% reduction in peak VRAM, while maintaining high accuracy. "Quantizing weights to int4 and the KV cache to int8 supports Llama 3.1 8B at full 128K context length, running in under 18.9GB of VRAM," the team noted.
"If you're interested in making your models faster and smaller for training or inference, we hope you'll find torchao useful and easy to integrate," said the PyTorch team.
Built "mostly" in PyTorch code, and with a little added flexibility for CUDA or Triton, torchao offers developers an accessible toolkit to streamline inference and training processes for deep learning workloads, by making the models smaller and thus faster.
AIM has previously noted how important it is for LLMs to shift to a 1-bit architecture, which essentially means restricting each weight of the model to one of three values: -1, 0, or 1, reducing the amount of compute required per parameter or token.
Such a shift would make these models far more efficient at inference time, but the usual cost is a drop in model quality. The PyTorch team, however, seems dedicated to making that transition as close to lossless as possible.
"You have to give it to the pytorch team, they're really great at bringing complex optimization schemes...down to a simple to use API," said a developer on Hacker News. PyTorch has always been the love of developers.
For benchmarking, the team tested its techniques on Llama 3 and several diffusion models with low-bit quantization and saw only minimal drops in accuracy. Baselines were measured on an NVIDIA A100 80GB GPU.
For Llama 3 8B, inference was up to 97% faster using autoquant, an automated tool that determines the best way to apply quantization, with int4 weights.
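In practice, this is meant to be close to a one-liner. The snippet below is a minimal, illustrative sketch of the autoquant workflow, using a toy model in place of Llama 3 8B; the shapes are placeholders and the exact API surface may differ between torchao releases.

```python
import torch
import torch.nn as nn
import torchao

# Toy stand-in for an LLM; in practice this would be a model such as Llama 3 8B.
model = nn.Sequential(
    nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024)
).to(device="cuda", dtype=torch.bfloat16)

# autoquant benchmarks each linear layer and picks the quantization scheme
# (e.g. int4 weight-only) that is fastest for the shapes seen at runtime.
model = torchao.autoquant(torch.compile(model, mode="max-autotune"))

# The first call triggers the benchmarking; later calls use the chosen kernels.
x = torch.randn(16, 1024, device="cuda", dtype=torch.bfloat16)
out = model(x)
```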
Accuracy is preserved in part through Quantization-Aware Training (QAT), which minimises the potential accuracy loss caused by low-bit quantization and provides a more robust path to model optimisation. On the training side, PyTorch has also introduced experimental 8-bit and 4-bit optimisers as drop-in replacements for AdamW, helping users improve model training efficiency.
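As a rough sketch of what that drop-in replacement looks like (the torchao.prototype.low_bit_optim module path reflects the optimisers' prototype status at the time and may change; the model and hyperparameters here are placeholders):

```python
import torch
import torch.nn as nn
# Prototype module path at the time of writing; it may move as the API stabilises.
from torchao.prototype.low_bit_optim import AdamW8bit

model = nn.Linear(4096, 4096, device="cuda", dtype=torch.bfloat16)

# Drop-in replacement for torch.optim.AdamW: the optimizer state (exp_avg,
# exp_avg_sq) is stored in 8 bits, sharply reducing optimizer memory.
optimizer = AdamW8bit(model.parameters(), lr=1e-4)

x = torch.randn(8, 4096, device="cuda", dtype=torch.bfloat16)
loss = model(x).pow(2).mean()
loss.backward()
optimizer.step()
optimizer.zero_grad()
```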
Post-training quantization at fewer than 4 bits can often lead to accuracy degradation. torchao addresses this with QAT, which has been shown to recover up to 96% of the lost accuracy on benchmarks like Hellaswag. The process is integrated into a torchtune recipe, simplifying the implementation of QAT and making torchao an indispensable tool for training models while preserving accuracy.
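The QAT flow follows a prepare/train/convert pattern. The sketch below is based on the prototype API described in PyTorch's QAT write-up; the import path and the toy model are illustrative and may not match the current release exactly.

```python
import torch.nn as nn
# Prototype QAT API at the time of writing; the import path may change.
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

model = nn.Sequential(nn.Linear(1024, 1024), nn.ReLU(), nn.Linear(1024, 1024))

# prepare() inserts "fake quantization" ops so the model experiences int8-activation /
# int4-weight rounding during fine-tuning and learns to compensate for it.
quantizer = Int8DynActInt4WeightQATQuantizer()
model = quantizer.prepare(model)

# ... fine-tune the prepared model as usual (forward, backward, optimizer step) ...

# convert() replaces the fake-quantized layers with genuinely low-bit quantized ones.
model = quantizer.convert(model)
```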
Further, by leveraging low-bit data types like int4 and float8, torchao provides powerful quantization options to optimise model performance. It supports dynamic activation quantization across multiple data types and integrates sparsity for enhanced flexibility.
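Dynamic activation quantization, for instance, is exposed through the same one-shot quantize_ API. The sketch below is illustrative: int8_dynamic_activation_int8_weight is one of several configs torchao ships, and the toy model is a placeholder.

```python
import torch
import torch.nn as nn
# Post-training quantization API; exact config names may vary between releases.
from torchao.quantization import quantize_, int8_dynamic_activation_int8_weight

model = nn.Sequential(
    nn.Linear(2048, 2048), nn.ReLU(), nn.Linear(2048, 2048)
).to(device="cuda", dtype=torch.bfloat16)

# Weights are quantized to int8 ahead of time, while activations are quantized
# to int8 dynamically at runtime, per batch, so no calibration set is needed.
quantize_(model, int8_dynamic_activation_int8_weight())
```

A similarly terse API (sparsify_ in torchao.sparsity) is intended to apply semi-structured sparsity to the same layers.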
Beyond inference, torchao extends its optimization features to training processes. With support for low-precision computation and communication, torchao offers efficient workflows starting with float8 for torch.nn.Linear layers.
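A hedged sketch of that float8 training path, based on the torchao.float8 module (the helper name and hardware requirements are as documented at the time and may evolve):

```python
import torch
import torch.nn as nn
# float8 training helper from torchao.float8 at the time of writing.
from torchao.float8 import convert_to_float8_training

model = nn.Sequential(
    nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 4096)
).to(device="cuda", dtype=torch.bfloat16)

# Swaps each torch.nn.Linear for a float8-aware version so its matmuls run in
# float8; this requires hardware with float8 support, such as H100-class GPUs.
convert_to_float8_training(model)
model = torch.compile(model)

# ... then train as usual: forward, backward, optimizer step ...
```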
"Why isn't this merged into PyTorch?" is one of the questions that developers asked. Mark Saroufim from the PyTorch team from Meta said that there are trade-offs. "Having a separate repository is called 'out of core,' while including it in PyTorch is 'in core.'" PyTorch is a large library, and adding code takes time due to complex continuous integration (CI), strict backward compatibility (BC) rules, and dependency challenges.
Creating separate repositories like torchao, torchtune, torchchat, etc., keeps PyTorch leaner, with a smaller binary size. It also allows teams to focus on their specific optimisations, which is why developers love PyTorch. "Mostly comes down to what's fastest to develop, it's faster to write a few custom kernels than it is to develop a new compiler backend," said Mark.
torchao has already been integrated into key open-source projects like Hugging Face transformers, providing an inference backend, and diffusers-torchao for accelerating diffusion models. It also serves as a reference implementation for QLoRA and QAT in torchtune.
Additionally, torchao's low-bit quantization techniques are being utilised in the SGLang project, demonstrating its value across research and production.
Looking ahead, PyTorch plans to expand torchao's capabilities, including exploring sub-4-bit quantization, developing high-throughput kernels, extending support to more layers, and optimising it for additional hardware backends like MX hardware.