Bringing SOTA quantization to mobile LLM deployment: A practical Executorch integration guide
Running Llama-3-8B in 2.5 GB of RAM: https://huggingface.co/ISTA-DASLab/AQLM-executorch
Goals
Executorch
Last year, PyTorch set out to bridge the gap between existing PyTorch capabilities and mobile and edge architectures. As a result, the Executorch library was created. Built upon the elaborate existing Intermediate Representation and core operator frameworks, its first releases provided a way to export and run PyTorch models on edge devices with minimal runtime overhead.
But, as is the case with many complicated neural network architectures and optimizations, many performance-critical parts of the computational graph cannot be efficiently represented with basic core operations and require kernels that are both operation- and architecture-specific. Executorch was therefore designed around the notion of delegating parts of the computation graph to specific backends. These backends wrap those parts in high-level external operations and hide the specific kernel optimizations behind them. This idea is described in more detail in the official documentation.
Most existing backend integrations handle broad, widely used operations such as matrix multiplications, convolutions and LayerNorms. In contrast, many R&D applications require quick prototyping and integration of very specific kernels. For example, consider a new quantization method that operates on a specific weight representation and comes with its own matrix multiplication kernel. It makes sense to integrate this single kernel directly into Executorch to quickly evaluate its performance.
Luckily, Executorch provides all the tools necessary to integrate custom C++ operations into both the model export and the on-device runtimes with relatively little configuration overhead. Sadly, I found the documentation on how to actually do it lacking in detail. Therefore, one of the main goals of this post is to explain in detail the steps needed to turn your C++ kernel into a traced operation that is registered and running on a mobile phone.
AQLM
Specifically, I will adapt one of the latest extreme Large Language Model quantization methods, AQLM, to run on an Android phone. AQLM was designed to quantize LLMs to as low as 2 bits per parameter, and it was presented as a poster at ICML 2024. More details about the method can be found in the original paper or in a detailed blog post explanation.
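A quick calculation shows where the 2 bits per parameter come from. The sketch below assumes the 2x8 configuration used later in this post: 2 codebooks with 8-bit codes, applied to groups of 8 consecutive weights. The group size of 8 and the fp16 codebook entries are assumptions. The helper names are hypothetical.

```cpp
#include <cstdint>

// Hypothetical helper: average number of bits spent on codes per weight when each
// group of `in_group_size` consecutive weights is encoded with `num_codebooks`
// codes of `nbits_per_code` bits each. Codebooks and scales are not counted here.
constexpr double code_bits_per_weight(int num_codebooks, int nbits_per_code,
                                      int in_group_size) {
  return static_cast<double>(num_codebooks * nbits_per_code) / in_group_size;
}

// Hypothetical helper: size of one matrix's codebooks in bits, assuming fp16
// entries (2^nbits_per_code vectors of length in_group_size per codebook).
constexpr std::int64_t codebook_bits(int num_codebooks, int nbits_per_code,
                                     int in_group_size) {
  return std::int64_t{num_codebooks} * (std::int64_t{1} << nbits_per_code) *
         in_group_size * 16;
}

// Assumed 2x8 setup: 2 codebooks, 8-bit codes, groups of 8 weights.
static_assert(code_bits_per_weight(2, 8, 8) == 2.0, "2x8 with groups of 8 -> 2 bits");
```

The codebooks add a fixed overhead per matrix. For a 4096×4096 projection (about 16.8M weights), the 65,536 bits of codebooks add roughly 0.004 bits per weight. Per-row fp16 scales (another assumption) add about the same amount, so the total stays at roughly 2 bits per parameter.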
Code Modifications
From torch:: to torch::executor::
Let us assume we have a kernel implemented to be used with the C++ Torch library.
On a high level, AQLM stores linear layer matrices as three sets of parameters: codes, codebooks and scales. A custom kernel that multiplies model activations by such a matrix therefore has to receive all of those tensors, as well as the activations themselves.
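To make these three sets of parameters concrete before looking at the kernel itself, here is a slow reference sketch that reconstructs a dense weight matrix from them. The memory layouts, the per-row scales and the name dequantize_aqlm are assumptions for illustration. They are not necessarily the layout used by the actual kernel.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Reference (slow) reconstruction of a dense weight matrix from AQLM-style
// parameters. Assumed row-major layouts, for illustration only:
//   codes:     [out_features][num_input_groups][num_codebooks], 8-bit indices
//   codebooks: [num_codebooks][256][in_group_size]
//   scales:    [out_features], one scale per output row
// Returns W as [out_features][num_input_groups * in_group_size].
std::vector<float> dequantize_aqlm(const std::vector<std::uint8_t>& codes,
                                   const std::vector<float>& codebooks,
                                   const std::vector<float>& scales,
                                   std::size_t out_features,
                                   std::size_t num_input_groups,
                                   std::size_t num_codebooks,
                                   std::size_t in_group_size) {
  constexpr std::size_t kCodebookSize = 256;  // 8-bit codes
  assert(codes.size() == out_features * num_input_groups * num_codebooks);
  assert(codebooks.size() == num_codebooks * kCodebookSize * in_group_size);
  assert(scales.size() == out_features);

  const std::size_t in_features = num_input_groups * in_group_size;
  std::vector<float> weight(out_features * in_features, 0.0f);
  for (std::size_t o = 0; o < out_features; ++o) {
    for (std::size_t g = 0; g < num_input_groups; ++g) {
      float* dst = &weight[o * in_features + g * in_group_size];
      // A group of weights is the sum of one vector from each codebook...
      for (std::size_t m = 0; m < num_codebooks; ++m) {
        const std::uint8_t code = codes[(o * num_input_groups + g) * num_codebooks + m];
        const float* entry = &codebooks[(m * kCodebookSize + code) * in_group_size];
        for (std::size_t t = 0; t < in_group_size; ++t) dst[t] += entry[t];
      }
      // ...multiplied by the scale of its output row.
      for (std::size_t t = 0; t < in_group_size; ++t) dst[t] *= scales[o];
    }
  }
  return weight;
}
```

Materializing W defeats the purpose of the compression. This sketch is only useful as a correctness reference; the actual kernel avoids it by working with look-up tables.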
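#include <optional>
#include <torch/torch.h>

torch::Tensor code2x8_lut_matmat(
    const torch::Tensor& input,
    const torch::Tensor& codes,
    const torch::Tensor& codebooks,
    const torch::Tensor& scales,
    const std::optional<torch::Tensor>& bias
) {
    // Calculations: the output tensor is allocated here, inside the kernel
    torch::Tensor out;
    return out;
}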
A few things must be done to make it Executorch compliant:
Output in args: Executorch mostly operates on preallocated buffers. By convention, the preallocated output tensor is therefore passed as the last argument (hence the _out suffix in the function name).
torch::Tensor → torch::executor::Tensor: a much more limited tensor interface. It lacks many familiar methods (e.g., automatic allocation) but is far more lightweight. Notably, there is no matmul operation for such tensors, so one has to use the low-level functions provided in
#include <executorch/kernels/optimized/blas/CPUBlas.h>
torch::executor::RuntimeContext: added as the first argument. It handles logging and memory allocation.
The result is:
#include <executorch/runtime/kernel/kernel_includes.h>

namespace torch {
namespace executor {
namespace native {

// Out-variant: RuntimeContext first, preallocated output last
Tensor& code2x8_lut_matmat_out(
    RuntimeContext& ctx,
    const Tensor& input,
    const Tensor& codes,
    const Tensor& codebooks,
    const Tensor& scales,
    const optional<Tensor>& bias,
    Tensor& out
) {
    // Calculations
    return out;
}

} // namespace native
} // namespace executor
} // namespace torch
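The out-variant convention can be illustrated without any Executorch headers. In this minimal sketch, HypotheticalContext, ConstMatrixView, MatrixView and linear_out are hypothetical stand-ins, not Executorch types or APIs. Naive loops play the role of the CPUBlas routines, whose exact signatures are not reproduced here.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical stand-ins (NOT Executorch types): a context that can record a
// failure, and non-owning matrix views that never allocate.
struct HypotheticalContext {
  bool failed = false;
  void fail() { failed = true; }
};

struct ConstMatrixView {
  const float* data;
  std::size_t rows;
  std::size_t cols;
};

struct MatrixView {
  float* data;
  std::size_t rows;
  std::size_t cols;
};

// Out-variant convention: context first, inputs next, preallocated output last,
// and the output is returned by reference. Computes out = input * weight^T, i.e.
// a linear layer with weight shaped [out_features][in_features].
MatrixView& linear_out(HypotheticalContext& ctx, const ConstMatrixView& input,
                       const ConstMatrixView& weight, MatrixView& out) {
  assert(input.data != nullptr && weight.data != nullptr && out.data != nullptr);
  // The kernel cannot resize `out`: a shape mismatch is reported, not fixed.
  if (input.cols != weight.cols || out.rows != input.rows ||
      out.cols != weight.rows) {
    ctx.fail();
    return out;
  }
  // Naive loops standing in for the optimized BLAS routines a real kernel would call.
  for (std::size_t i = 0; i < input.rows; ++i) {
    for (std::size_t j = 0; j < weight.rows; ++j) {
      float acc = 0.0f;
      for (std::size_t p = 0; p < input.cols; ++p) {
        acc += input.data[i * input.cols + p] * weight.data[j * weight.cols + p];
      }
      out.data[i * out.cols + j] = acc;
    }
  }
  return out;
}
```

Because the caller owns every buffer, the kernel performs no allocation. This matches the preallocated-buffer model described above.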
Executorch kernel registration and CMake
As described in the documentation, the easiest way to link this kernel into Executorch is the EXECUTORCH_LIBRARY macro. However, the documentation omits the fact that a very specific linking procedure is needed for this macro to take effect at program launch. Namely, the target_link_options_shared_lib function from Executorch's CMake utilities needs to be called. This function forces whole-archive linking of the kernel library. Without it, the linker may discard the static registration objects created by the macro, since nothing references them directly. Moreover, target_link_options_shared_lib needs to be invoked for every library and executable that is to use your kernel. Therefore, it appears crucial that the kernel be written and built as part of the Executorch project in its repository, as is the case with the custom kernels the Executorch team provides and tests.
For this reason, this post builds the integration from within the Executorch repository, hooking into both its CMake structure and its Python scripts for LLM export. All the modifications described here are available as a Pull Request on top of Executorch v0.3.0, which is how I present the code for this post.
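A toy, self-contained illustration shows why the linking step matters. It is not Executorch's implementation: KernelRegistrar, kernel_registry and TOY_REGISTER_KERNEL are hypothetical names. Registration macros of this kind typically define a global object whose constructor adds the kernel to a registry during static initialization.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <utility>

// Toy illustration of static kernel registration (NOT Executorch's implementation).
using KernelFn = std::function<int(int)>;

// Function-local static: constructed on first use, which sidesteps the
// static initialization order problem between translation units.
std::map<std::string, KernelFn>& kernel_registry() {
  static std::map<std::string, KernelFn> registry;
  return registry;
}

// Constructing one of these registers a kernel under a name.
struct KernelRegistrar {
  KernelRegistrar(const std::string& name, KernelFn fn) {
    const bool inserted = kernel_registry().emplace(name, std::move(fn)).second;
    assert(inserted && "kernel registered twice");
    (void)inserted;  // silence unused-variable warnings when NDEBUG is set
  }
};

// The macro only defines a global object whose constructor typically runs before
// main(). Nothing references that object, so if it sits in a static library the
// linker may leave its object file out, and the kernel is never registered.
#define TOY_REGISTER_KERNEL(var, name, fn) static KernelRegistrar var{name, fn}

int twice(int x) { return 2 * x; }
TOY_REGISTER_KERNEL(registrar_twice, "aqlm::twice", twice);
```

In a single translation unit, as here, the registration always happens. Once the object lives in a static library that no other code references, a regular link may omit it. Whole-archive linking guards against exactly this failure mode.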
PyTorch registration
The documentation addresses this topic quite well. However, I would add that for the PyTorch integration to be available from within the Executorch Python library, corresponding modifications must also be made to install.py.
CMake arguments and selective builds
It should be mentioned that it is not obvious which binding should be linked where, and how. In general, in CMake we separate binaries that only need the low-level Executorch runtime from binaries that need the whole torch library, such as the pip-installable package.
In this example, the first library is called aqlm and the second aqlm_aot_lib. The separation is controlled with the CMake argument EXECUTORCH_BUILD_KERNELS_CUSTOM_AOT. The former is linked to everything, including the Python library, the Llama CLI executables and the Android dynamic library. The latter is only linked to the Python library.
How to build
Since AQLM is now properly integrated with the Executorch repository, we can follow the official documentation to build everything from source.
The setup.py script properly builds and links the Python package, making the kernel available to the Llama export script. Modifications to the Python package files located in executorch/examples/models/llama2 make it possible to load, trace and export AQLM models. The steps described in the Android Llama Demo section of the documentation automatically build and link AQLM wherever necessary.
How to export AQLM models to .pte
Simply follow the original Llama export guide, adding -qmode aqlm-2x8 and --converted_aqlm_checkpoint_path. The latter should point to the weights produced with the convert_from_hf.ipynb notebook, which I added to easily convert AQLM checkpoints published on the HF Hub.
Notably, I decided to also quantize the embedding layer and the model head with existing Executorch quantization code. This makes for a fair comparison with other quantization methods that quantize them too.
Android NDK version
For unknown reasons, I observed that changing the Android NDK version from 25.2.9519653 to 25.0.8775105 results in a 10x slowdown. My advice is to be cautious when selecting the NDK version.
Running the Demo
Prebuilt APK and exported models
We quantized meta-llama/Meta-Llama-3.1-8B-Instruct with AQLM to produce ISTA-DASLab/Meta-Llama-3.1-8B-Instruct-AQLM-PV-2Bit-2x8-hf, which retains most of the original model's performance on zero-shot tasks. Moreover, we applied PV-Tuning, a SOTA post-quantization tuning technique, to further decrease the quantization error.
We converted the HF checkpoint into a torch checkpoint with examples/models/llama2/aqlm/convert_from_hf.ipynb.
We exported the checkpoint using the following script.
#!/bin/bash
python -m examples.models.llama2.export_llama \
-p ~/models/Meta-Llama-3.1-8B-Instruct-AQLM-PV-2Bit-2x8-hf/params.json \
--output_name "llama-3.1-8b-aqlm-pv.pte" \
-X -kv --use_sdpa_with_kv_cache -d fp32 \
--metadata '{"get_bos_id":128000, "get_eos_id":128009}' \
--max_seq_len 1024 \
--embedding-quantize 4,32 \
-qmode aqlm-2x8 --converted_aqlm_checkpoint_path ~/models/Meta-Llama-3.1-8B-Instruct-AQLM-PV-2Bit-2x8-hf/model.pth --group_size 128
As you can see, we also seamlessly applied existing Executorch embedding quantization (--embedding-quantize 4,32).
We uploaded the resulting llama-3.1-8b-aqlm-pv.pte checkpoint to the Hugging Face Hub, together with the original Llama-3 tokenizer and an APK to run them (built following the Android Llama Demo instructions).
To run it, download the .pte checkpoint and the tokenizer and put them into the /data/local/tmp/llama folder on your phone. After installing and launching the application, you will see an option to load them.
Measured performance
AQLM:
RAM consumption: 2.3 GB
Generation speed: 1.0-1.1 tok/s
8DA4W:
Checkpoint size: 4.1 GB
Generation speed: 9-10 tok/s
Demonstration
Further Improvements
Optimizing the kernels
The AQLM kernels that we wrote are single-core and in no way optimized for `arm-v8a`, which is why they are almost 10x slower than 8DA4W.
However, the performance-critical section of AQLM is just 25 lines of loads and additions, so it can most probably be heavily optimized for specific architectures.
#include <cstdint>
#include <cstring>

// Accumulates, for every input row and output feature, the two LUT entries
// selected by the pair of 8-bit codes of each input group.
template<typename fp_dtype>
void quadruple_for(
    int num_inputs,
    int num_input_groups, const fp_dtype* __restrict__ lut,
    int out_features, const uint8_t* __restrict__ b_alt,
    fp_dtype* __restrict__ output_vec
)
{
    std::memset(output_vec, 0, num_inputs * out_features * sizeof(fp_dtype));
    // lut layout: [num_inputs][num_input_groups][2 codebooks][256 entries]
    const int lut_stride = num_input_groups * 2 * 256;
    // b_alt layout: [num_input_groups][out_features][2 codes]
    const int b_alt_stride = 2 * out_features;
    for (int input = 0; input < num_inputs; ++input) {
        for (int j = 0; j < num_input_groups; ++j) {
            const fp_dtype* lut_ptr = lut + input * lut_stride + j * 2 * 256;
            const uint8_t* b_alt_ptr = b_alt + j * b_alt_stride;
            for (int i = 0; i < out_features; ++i) {
                // One lookup per codebook: first codebook, then the second (offset 256)
                output_vec[input * out_features + i] += lut_ptr[b_alt_ptr[i * 2]];
                output_vec[input * out_features + i] += lut_ptr[256 + b_alt_ptr[i * 2 + 1]];
            }
        }
    }
}
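The kernel above assumes the look-up table has already been filled. Here is a minimal sketch of that preparation step. It assumes the LUT layout implied by the indexing in quadruple_for ([num_inputs][num_input_groups][2][256]) and codebooks stored as [2][256][in_group_size]. build_lut is a hypothetical name. Each entry is the dot product of one input group with one codebook vector, so decoding a group of weights costs two lookups per output feature instead of in_group_size multiply-adds.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical LUT builder for the 2x8 scheme. Assumed layouts:
//   input:     [num_inputs][num_input_groups * in_group_size]
//   codebooks: [2][256][in_group_size]
//   lut:       [num_inputs][num_input_groups][2][256]  (as indexed by quadruple_for)
std::vector<float> build_lut(const std::vector<float>& input,
                             const std::vector<float>& codebooks,
                             int num_inputs, int num_input_groups,
                             int in_group_size) {
  constexpr int kNumCodebooks = 2;
  constexpr int kCodebookSize = 256;
  assert(input.size() == static_cast<std::size_t>(num_inputs) * num_input_groups * in_group_size);
  assert(codebooks.size() == static_cast<std::size_t>(kNumCodebooks) * kCodebookSize * in_group_size);

  std::vector<float> lut(static_cast<std::size_t>(num_inputs) * num_input_groups *
                         kNumCodebooks * kCodebookSize);
  for (int n = 0; n < num_inputs; ++n) {
    for (int j = 0; j < num_input_groups; ++j) {
      // The slice of activations covered by input group j.
      const float* x = &input[(static_cast<std::size_t>(n) * num_input_groups + j) * in_group_size];
      float* lut_group = &lut[(static_cast<std::size_t>(n) * num_input_groups + j) *
                              kNumCodebooks * kCodebookSize];
      for (int m = 0; m < kNumCodebooks; ++m) {
        for (int k = 0; k < kCodebookSize; ++k) {
          const float* c = &codebooks[(static_cast<std::size_t>(m) * kCodebookSize + k) * in_group_size];
          float dot = 0.0f;
          for (int t = 0; t < in_group_size; ++t) dot += x[t] * c[t];
          // Entry m * 256 + k matches lut_ptr[code] / lut_ptr[256 + code] in quadruple_for.
          lut_group[m * kCodebookSize + k] = dot;
        }
      }
    }
  }
  return lut;
}
```

Building the table costs 2 × 256 × in_group_size multiply-adds per input group, regardless of out_features. It therefore pays off roughly once out_features is well above 512, which holds for LLM linear layers with thousands of output features. quadruple_for takes no scales or bias, so those are presumably applied outside of it.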
If you have experience with `arm-v8a` optimizations, or with any other architecture you would like to deploy AQLM to, and would like to collaborate, I'd be happy to hear from you!
Contact me at: LinkedIn, Twitter, Telegram, andrei@panferov.org
This post was done in collaboration with Vladimir Malinovskii.
Hi I'm Mengwei Liu from the ExecuTorch team. Thank you so much for taking the time to write this up, and to provide valuable feedback! We created an issue https://github.com/pytorch/executorch/issues/4719 to track your suggestions about improving the documentation. Please add more comments there if you have specifics. We'd also like to include you on PRs that improve the docs to make sure that we're incorporating your feedback.