# Triton + Flyte: How to get to realtime ML
Hey there! Welcome to my blog post on Triton!
The goal of this blog post is to speak to your real-time inference needs and showcase one path to success.
If you've trained an ML model and thought about deployment, this blog is for you.
With that, let's get started!
If you've dabbled with ML before, you've probably run a Hugging Face inference example like the one below and thought it was pretty cool:
```python=
from transformers import BertTokenizer, BertModel
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained("bert-base-uncased")
text = "Replace me by any text you'd like."
encoded_input = tokenizer(text, return_tensors='pt')
output = model(**encoded_input)
print(f"model outputs; try comparing different inputs to see how it changes:{output}")
```
Your model did the thing! It ran on your CPU-only system (it took a few minutes), so now you can just hook it into your express.js web server... right?
## What's the problem? Is realtime ML Hard?
Wrong, my friend... if only it were that simple.
### GPUs
// GPU Image
First off: the basic example may have taken minutes, but end users don't want to wait minutes for a widget on their website to load or for their phone's voice recognition to respond. Some models we might not be able to run in reasonable human time at all without a powerful GPU (looking at you, Llama 2 70B!).
Accelerating your model's inference can have a monumental effect on performance; {{insert statistic}} but it requires setting up and configuring something separate from the express.js service you were using before.
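To make that concrete, here's a minimal sketch that takes the earlier Hugging Face example and moves it onto a GPU when one is available. It's a rough timing loop rather than a rigorous benchmark, and it assumes a CUDA-capable device for the accelerated path:
```python=
import time

import torch
from transformers import BertTokenizer, BertModel

tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertModel.from_pretrained("bert-base-uncased")
encoded_input = tokenizer("Replace me by any text you'd like.", return_tensors='pt')

# Pick the GPU if one is available, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
model = model.to(device).eval()
encoded_input = {k: v.to(device) for k, v in encoded_input.items()}

# Time a handful of forward passes on whichever device we landed on
with torch.no_grad():
    start = time.perf_counter()
    for _ in range(10):
        model(**encoded_input)
    if device == "cuda":
        torch.cuda.synchronize()  # wait for GPU kernels to finish before stopping the clock
    elapsed = time.perf_counter() - start

print(f"10 forward passes on {device}: {elapsed:.2f}s")
```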
### Model Versioning
// Model Versioning Graphic
Let's assume you were able to find a server with a GPU and run your express.js service there. You deploy your model and it runs for a while, but then of course you want to update it. You retrain your model on new data, and your boss asks you a question: "How are you going to ensure the new model is better? What is our rollback plan?"
### Multiple Model Management
// Multiple inference model flowchart
OK, so you got a GPU attached to your express.js node, it just barely fits, and you're able to perform some basic inference via your web app. Now your boss wants you to deploy another model, and this one has substantially different resource requirements. Unlike the first use case, it only needs to run occasionally, so spinning up a brand new server + express.js instance just for it seems wasteful. There must be a better way, right?
What looked pretty easy is now looking complicated, but there are tools we can use! From now on we're calling these "Model Serving Frameworks" and there are a few options.
## Existing Tools
Let's delve into some tools we can leverage, understanding that the market is quite diverse! Here, we’ll not only list the tools but also give a glimpse into specific scenarios where each might shine.
> Note: We won’t be diving deeply into each of these frameworks, but exploring them before making a decision is highly recommended.
#### Mosec
[Mosec](https://github.com/mosecorg/mosec) is a streamlined model serving framework focusing on high concurrency and asynchronous processing.
- **Pros**:
- Efficient in handling multiple requests simultaneously.
- Excellent for real-time applications due to its asynchronous nature.
- **Cons**:
- More basic in functionality compared to full-featured serving systems.
- **Use-Case Scenario**:
Ideal for a startup developing a chatbot service where handling numerous concurrent user queries with minimal latency is crucial. Mosec's ability to efficiently process simultaneous requests makes it a strong candidate for such real-time interaction applications.
#### TensorFlow Serving
[TensorFlow Serving](https://www.tensorflow.org/tfx/guide/serving) is specifically designed for TensorFlow models, offering robust integration and serving capabilities.
- **Pros**:
- Seamless multi-version model serving, ideal for experiments.
- Deep TensorFlow integration ensures optimized model performance.
- **Cons**:
- Limited to TensorFlow, less suitable for other ML frameworks.
- **Use-Case Scenario**:
Perfect for a data science team in a large corporation, which consistently iterates and experiments on TensorFlow-based models for image recognition. TensorFlow Serving's multi-version serving capability allows for effortless A/B testing of different model versions in production.
#### TorchServe
[TorchServe](https://github.com/pytorch/serve) is PyTorch's native solution for model deployment, focusing on simplicity and native support.
- **Pros**:
- Intuitive deployment of PyTorch models, supporting dynamic graphs.
- Offers model versioning and multi-model serving.
- **Cons**:
- Restricted to the PyTorch ecosystem, not as versatile.
- **Use-Case Scenario**:
Ideal for an AI research lab working extensively with PyTorch models, especially those involving dynamic computation graphs like in natural language processing. TorchServe's native support ensures smooth integration and deployment.
#### Seldon Core
[Seldon Core](https://github.com/SeldonIO/seldon-core) is a Kubernetes-native platform, ideal for complex machine learning deployment workflows.
- **Pros**:
- Highly scalable, supporting diverse ML frameworks and languages.
- Advanced orchestration capabilities for sophisticated deployment strategies.
- **Cons**:
- Requires Kubernetes proficiency, potentially complex for beginners.
- **Use-Case Scenario**:
Suitable for an enterprise with a diverse set of ML models across different frameworks, looking to deploy these on a Kubernetes cluster. Seldon Core's ability to manage complex workflows and scale effectively makes it a great choice for such environments.
#### Nvidia Triton
[Nvidia Triton](https://github.com/triton-inference-server/server) offers a comprehensive solution for model serving, particularly optimized for GPU acceleration.
- **Pros**:
- Exceptional GPU optimization, enhancing performance for complex models.
- Supports a wide range of frameworks, including TensorFlow, PyTorch, and ONNX.
- Advanced features like model versioning and multi-model management.
- **Cons**:
- May require more setup and configuration, especially to fully leverage GPU capabilities.
- **Use-Case Scenario**:
Best for a tech company needing to deploy a variety of complex models (like large-scale language models) that require intensive computation. Nvidia Triton's GPU optimization ensures high performance, making it ideal for such demanding scenarios.
Each of these tools has unique strengths, catering to various needs in the context of ML serving – from efficient model management to comprehensive, scalable serving solutions. Nvidia Triton, while requiring more setup, stands out for its advanced GPU optimization and versatility, making it a potent choice for deep learning deployments.
## What's interesting about Nvidia Triton
All of the above tools can help us; however, when it comes to getting "the most" out of the GPUs, Nvidia has them beat. If you're curious about benchmark results, [Nvidia compared Triton serving on GPUs vs. CPUs](https://docs.nvidia.com/ai-enterprise/natural-language/0.1.0/benchmark.html#triton-bert-large-benchmarks).
### Distinct Features in Model Serving
NVIDIA Triton sets itself apart in the model serving domain with specific features that elevate it above competitors like Mosec, TorchServe, and TensorFlow Serving:
**GPU Optimizations**: In Triton, optimizing GPU performance for an ML model involves more than just assigning the model to a GPU. You can leverage specific settings in the `config.pbtxt` file to fine-tune GPU usage. For instance, you can optimize an ONNX model by enabling GPU execution and setting an optimization profile for better performance:
```protobuf=
backend: "onnxruntime"
optimization {
  execution_accelerators {
    gpu_execution_accelerator : [ {
      name : "tensorrt"
      parameters { key: "precision_mode" value: "FP16" }
    }]
  }
}
instance_group [
  {
    kind: KIND_GPU
    count: 1
  }
]
```
In this example, the ONNX model is configured to use the TensorRT execution accelerator with FP16 precision on the GPU. This setting can significantly improve inference performance by optimizing how the model executes on the GPU, making it faster and more efficient. This level of control is a testament to Triton's capability to enhance model serving on GPUs, especially for complex models and backends like ONNX. Kernel and layer fusion, for example, are applied automatically for both ONNX Runtime and TensorRT. If you want to see everything that's possible in a config, take a look [here](https://docs.nvidia.com/deeplearning/triton-inference-server/user-guide/docs/user_guide/model_configuration.html).
**Framework Support**: Triton's ability to support multiple frameworks like TensorFlow, PyTorch, and ONNX is not common among its peers. This feature is configured in the `config.pbtxt` file, where you define the [backend](https://github.com/triton-inference-server/backend) and the corresponding framework for each model:
```protobuf=
backend: "tensorflow"
```
Triton supports models serialized for PyTorch, TensorFlow, ONNX, OpenVINO, vLLM, TensorRT, and more.
You can even [build your own backend](https://github.com/triton-inference-server/backend#how-can-i-develop-my-own-triton-backend) to support your specific model serialization and runtime requirements.
**Advanced Model Versioning**: Triton's model versioning allows for seamless transitions between different versions of a model. You can easily specify which versions to serve using the `model_version_policy` in the `config.pbtxt`:
```protobuf=
model_version_policy {
  latest {
    num_versions: 2
  }
}
```
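With this policy, Triton keeps only the two most recent versions live. On the client side you can confirm which versions are actually being served; here's a minimal sketch using the `tritonclient` HTTP API, where the endpoint and model name are placeholders for your own:
```python=
from tritonclient.http import InferenceServerClient

# Hypothetical endpoint and model name - substitute your own
client = InferenceServerClient("localhost:8000")

for version in ["1", "2", "3"]:
    ready = client.is_model_ready("llama2_chat", model_version=version)
    print(f"version {version} ready: {ready}")
```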
**Efficient Multi-Model Management**: Managing multiple models is simplified in Triton. Each model carries its own `config.pbtxt`, giving you per-model control over things like instance counts, target devices, and request batching. For example, dynamic batching lets Triton group incoming requests to improve throughput:
```protobuf=
dynamic_batching {
  preferred_batch_size: [4, 8]
  max_queue_delay_microseconds: 100
}
```
In essence, NVIDIA Triton's standout features are not just about being better; they are about providing precise control and optimization. From GPU utilization to multi-framework support, advanced versioning, and multi-model management, Triton offers detailed configurations that empower users to fine-tune their model serving to an unprecedented degree.
## The Triton How-To Guide
In this guide, we'll be using the recommended [containerized](https://github.com/triton-inference-server/server/blob/main/docs/customization_guide/build.md#building-with-docker) approach for running the service. Your mileage may vary if you intend to build Triton outside of the provided containers.
First off, let's talk about server requirements:
- [docker](https://docs.docker.com/get-docker)
- [CUDA/CUDNN](https://developer.nvidia.com/cuda-toolkit)
- [nvidia-docker](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html)
### QuickStart
We'll be using the default pre-built OCI images for this: one for the server and one for the client SDK. You can pull them locally via docker:
```bash=
docker pull nvcr.io/nvidia/tritonserver:23.12-py3
docker pull nvcr.io/nvidia/tritonserver:23.12-py3-sdk
```
Now that we've pulled the images, we can run the server in its default mode, following the steps found [here](https://github.com/triton-inference-server/server#serve-a-model-in-3-easy-steps):
```bash=
# Step 1: Create the example model repository
git clone -b r23.12 https://github.com/triton-inference-server/server.git
cd server/docs/examples
./fetch_models.sh
# Step 2: Launch triton from the NGC Triton container
docker run --gpus=1 --rm --net=host -v ${PWD}/model_repository:/models nvcr.io/nvidia/tritonserver:23.12-py3 tritonserver --model-repository=/models
# Step 3: Sending an Inference Request
# In a separate console, launch the image_client example from the NGC Triton SDK container
docker run -it --rm --net=host nvcr.io/nvidia/tritonserver:23.12-py3-sdk
/workspace/install/bin/image_client -m densenet_onnx -c 3 -s INCEPTION /workspace/images/mug.jpg
# Inference should return the following
Image '/workspace/images/mug.jpg':
15.346230 (504) = COFFEE MUG
13.224326 (968) = CUP
10.422965 (505) = COFFEEPOT
```
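If you'd rather poke at the server from Python than use the bundled `image_client`, the SDK container also ships the `tritonclient` package (installable elsewhere with `pip install tritonclient[http]`). Here's a quick sanity check against the server we just launched; it only assumes the quickstart's default ports and model names:
```python=
from tritonclient.http import InferenceServerClient

# The quickstart server listens for HTTP requests on port 8000
client = InferenceServerClient("localhost:8000")

print("server live:  ", client.is_server_live())
print("server ready: ", client.is_server_ready())
print("densenet_onnx ready:", client.is_model_ready("densenet_onnx"))
```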
Great, we have a basic toy example running, but that doesn't actually get us set up, right? We have our _own_ custom models.
### Configuring Backends
First off, regarding backends: the image we're using above comes pre-built with the following backends:
- pytorch
- tensorflow
- openvino
- onnxruntime
- python
- fil
- dali
Which is fantastic! This is a relatively new change and is super helpful; however, you may need the latest version of a backend, or a backend that isn't bundled in the image at all.
If you're intent on attaching a custom backend, you'll need to compile it from source from the corresponding GitHub repo; here are a few that I think are the most useful:
- [pytorch](https://github.com/triton-inference-server/pytorch_backend)
- [tensorflow](https://github.com/triton-inference-server/tensorflow_backend)
- [onnx](https://github.com/triton-inference-server/onnxruntime_backend)
- [TensorRT](https://github.com/triton-inference-server/tensorrt_backend)
Backends sit inside the container in the `/opt/tritonserver/backends` directory and have the following expected structure:
```
/opt/tritonserver/backends
- {backend_name}
-- libtriton_{backend_name}.so
-- ... (any additional shared libraries the backend depends on)
```
Make sure that `libtriton_{backend_name}.so` exists and is found in `/opt/tritonserver/backends/{backend_name}`. You may also need to add your new backend's libraries to `LD_LIBRARY_PATH`.
I'd recommend baking the new backend into your image as part of a Dockerfile addition, but you can also mount it into the container as a docker volume with `-v`.
## Understanding the Model Repository
The model repository in Triton is pretty cool, but it has some structural specifics that aren't concretely spelled out in the documentation. This is the structure of a model repository:
```
Model Repository Directory
- {model name}
- - config.pbtxt
- - {version number}
- - - ... Model Artifacts ... (model.pb, model.chkpt, savedmodel, etc.)
```
And this is an example of what that directory structure looks like:
```
- llama2_chat
- - config.pbtxt
- - 1
- - - model.pb
- - 2
- - - model.pb
```
As you can see, each model is defined separately from the others.
One nice thing with Triton is that the repository can point to a local filesystem location (within the container), or to an S3/GCS bucket. When running with polling enabled (`--model-control-mode=poll`), the Triton Server will check the model repository location at a set interval for updates. New or updated models are then loaded according to the settings in their `config.pbtxt`.
> Make sure your `config.pbtxt` is formatted correctly, as Triton can hard-fail if any `config.pbtxt` files are malformed. Check the server logs after updating a config file in your repository.
In the quickstart above, the `./fetch_models.sh` command populated a local model repository called `model_repository`; inspect it for a working example.
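If you end up scripting your own repository, laying one out is mostly directory plumbing. Here's a minimal sketch; the helper name, the `model.pt` filename, and the example values are illustrative, and syncing the result to S3/GCS is left to whatever tooling you already use:
```python=
import os
import shutil

def add_model_version(repo_root: str, model_name: str, version: int,
                      model_artifact: str, config_text: str) -> None:
    """Copy a serialized model and its config.pbtxt into the layout Triton expects."""
    model_dir = os.path.join(repo_root, model_name)
    version_dir = os.path.join(model_dir, str(version))
    os.makedirs(version_dir, exist_ok=True)

    # config.pbtxt lives next to the version directories, not inside them
    with open(os.path.join(model_dir, "config.pbtxt"), "w") as f:
        f.write(config_text)

    # Each numbered version directory holds that version's model artifacts
    shutil.copy(model_artifact, os.path.join(version_dir, "model.pt"))

# e.g. add_model_version("model_repository", "llama2_chat", 1, "traced.pt", config)
```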
## Setting up an Inference Server
Now that we've got a model repository set up, let's finally launch our inference server.
For simplicity's sake, we launched this server as a single container on an EC2 instance; in production, the recommended approach is to sit it behind a load balancer with the ability to autoscale to meet demand.
To launch the server, we make sure to use the docker image that contains the correct backends, or mount them separately.
```bash=
docker run --gpus=1 --rm --net=host nvcr.io/nvidia/tritonserver:23.12-py3 tritonserver --model-repository=s3://some_bucket_path
```
And with that (assuming we've opened the right ports and allowed access) we can start making inference requests against our Triton Server.
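A quick way to confirm the server is reachable and that your models made it over from the bucket is to query it with `tritonclient`; the hostname below is a placeholder for your instance:
```python=
from tritonclient.http import InferenceServerClient

# Hypothetical public endpoint for the instance running Triton
client = InferenceServerClient("<your-ec2-host>:8000")

print("server ready:", client.is_server_ready())
# List every model Triton currently sees in the repository
for entry in client.get_model_repository_index():
    print(entry)
```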
## Triton + Flyte Architecture
![Triton + Flyte architecture](https://i.imgur.com/pkd4Wq1.png)
[Flyte](https://www.flyte.org) is an open source, general-purpose compute platform that excels at data science and ML workflows; it runs on Kubernetes and uses DAGs as its bread and butter. It's used by some of the biggest open source organizations!
It's a great tool; but it doesn't have any built-in solution for real time inference.
Triton fits into a Flyte deployment in two places. First, within a training pipeline: one Flyte task can train and validate the model, and another can push an update into a Triton server's model repository.
Second, Triton itself persists as a simple OCI container; it can be hosted on third-party platforms such as AWS SageMaker, ModelZ, or Streamlit.
It's also possible to deploy Triton as a dedicated Kubernetes ReplicaSet with a load balancer in front, running alongside an existing Flyte cluster, similar to how the Spark operator is deployed.
## Deploy an ML model from Flyte -> Triton Server
Now that we have a Triton server running and listening to a model repository in S3 that our Flyte cluster has write access to, we're ready to show off our workflow.
> We'll be skipping over imports here, but you can find the entire workflow and executable code in our [public github repo](https://github.com/unionai-oss/workshops/tree/main/triton_flyte_client).
To register an arbitrary model (in this case, taken from the Hugging Face Hub) on any kind of schedule, we need to set up a Flyte workflow and an accompanying task.
```python=
# We leverage Flyte's ImageSpec feature to build on top of an Nvidia CUDA base image using envd
image = ImageSpec(
    requirements="requirements.txt",
    registry="ghcr.io/unionai-oss",
    name="trition-deployer",
    base_image="nvidia/cuda:12.3.1-runtime-ubuntu20.04",
    python_version="3.10",
)

# Setting up default parameters for your workflow makes it easy to trigger for testing purposes
@workflow
def register_model(model_name: str = "distilroberta-finetuned-financial-news-sentiment-analysis",
                   model_registry_uri: str = "s3://<put-your-s3-bucket-here>/triton-model-registry",
                   hf_hub_model_name: str = "mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis",
                   version: int = 1) -> FlyteDirectory:
    return register_hf_text_classifier(hf_hub_model_name=hf_hub_model_name, model_name=model_name,
                                        model_registry_uri=model_registry_uri, version=version)

@task(container_image=image, requests=Resources(mem="25Gi", gpu="1"))
def register_hf_text_classifier(hf_hub_model_name: str, model_name: str, model_registry_uri: str, version: int) -> FlyteDirectory:
    # Pull the model from the Hugging Face Hub and move it to the GPU for tracing
    model = AutoModelForTokenClassification.from_pretrained(hf_hub_model_name, torchscript=True)
    model = model.to("cuda")
    tokenizer = AutoTokenizer.from_pretrained(hf_hub_model_name)

    # Trace the model with a dummy input so it can be exported to TorchScript
    dummy_input = tokenizer("foo bar", padding='max_length', max_length=128, truncation=True, return_tensors="pt")
    dummy_input = {key: value.to("cuda") for key, value in dummy_input.items()}
    with th.no_grad():
        traced_model = th.jit.trace(model, example_inputs=[value for key, value in dummy_input.items()])
        output = traced_model(**dummy_input)

    # Record the input/output names and shapes so we can generate a matching config.pbtxt
    inputs_and_shapes = {key: value.shape for key, value in dummy_input.items()}
    outputs_and_shapes = {}
    if isinstance(output, tuple):
        for i, tensor in enumerate(output):
            outputs_and_shapes[f"output_{i}"] = tensor.shape
    else:
        outputs_and_shapes["output"] = output.shape

    # Write the model artifacts and config into the layout Triton expects
    os.makedirs(f"/tmp/{model_name}/{version}", exist_ok=True)
    config = generate_config_pbtxt(model_name, inputs_and_shapes, outputs_and_shapes)
    traced_model.save(f"/tmp/{model_name}/{version}/model.pt")
    with open(f"/tmp/{model_name}/config.pbtxt", "w") as f:
        f.write(config)

    # FlyteDirectory uploads the local directory to the remote model registry location
    return FlyteDirectory(path=f"/tmp/{model_name}", remote_directory=f"{model_registry_uri}/{model_name}")
```
The above task & workflow allow for a user to arbitrarily take a Text Extraction model trained in pytorch; export it to TorchScript (a Pytorch Triton backend requirement), and then create a config file with the right IO so Triton can understand what it's doing.
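`generate_config_pbtxt` comes from the linked repo (it's among the imports we skipped). To make the flow concrete, here's a rough sketch of what such a helper might emit for a TorchScript model with INT32 token inputs and FP32 outputs; treat it purely as an illustration, since Triton's PyTorch backend has its own conventions for tensor naming that you may need to follow:
```python=
# Hypothetical sketch of the helper used above; the real implementation lives in the linked repo.
def generate_config_pbtxt(model_name: str, inputs_and_shapes: dict, outputs_and_shapes: dict) -> str:
    """Render a minimal config.pbtxt for the PyTorch backend from traced input/output shapes."""

    def tensor_block(name, shape, dtype):
        dims = ", ".join(str(dim) for dim in shape)
        return f'  {{\n    name: "{name}"\n    data_type: {dtype}\n    dims: [ {dims} ]\n  }}'

    inputs = ",\n".join(tensor_block(name, shape, "TYPE_INT32") for name, shape in inputs_and_shapes.items())
    outputs = ",\n".join(tensor_block(name, shape, "TYPE_FP32") for name, shape in outputs_and_shapes.items())

    return (
        f'name: "{model_name}"\n'
        'backend: "pytorch"\n'
        # max_batch_size of 0 means the dims above already include the batch dimension
        'max_batch_size: 0\n'
        f'input [\n{inputs}\n]\n'
        f'output [\n{outputs}\n]\n'
        'instance_group [ { kind: KIND_GPU count: 1 } ]\n'
    )
```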
As you can see from the output of `register_hf_text_classifier`, we're returning a `FlyteDirectory`! This is really interesting, because Flyte creates the S3 directory on the fly and replicates the local directory structure to the remote bucket. The return value itself isn't consumed downstream; Triton simply picks up the change when it polls that S3 bucket for updates.
We can now interact directly with the model loaded in Triton using the following Flyte tasks:
> Remember to ensure that the host running your Triton Server container is accessible from your Flyte cluster! Otherwise your inference requests will fail with connection or HTTP errors.
```python=
# In our example, we're hardcoding the inference client information; make sure to modify this for your own use case.
@task(container_image=image)
def make_inference_request(input: str, model_name: str, hf_hub_model_name: str, model_version: int) -> np.ndarray:
    # Tokenize the input exactly the way the model was traced: fixed length of 128
    tokenizer = AutoTokenizer.from_pretrained(hf_hub_model_name)
    client = InferenceServerClient("<path-to-aws-ec2-instance>:8000")
    encoded = tokenizer(input, padding='max_length', max_length=128, truncation=True)
    encoded = {key: np.asarray(value, dtype=np.int32).reshape(1, -1) for key, value in encoded.items()}

    # Wrap each tokenizer output as a Triton InferInput and send the request
    model_inputs = [
        InferInput(key, value.shape, "INT32").set_data_from_numpy(value) for key, value in encoded.items()
    ]
    infer_result = client.infer(model_name, model_inputs, model_version=str(model_version))
    result = infer_result.as_numpy("output_0")
    return result

# The Triton client not only supports serial requests, but also handles async requests internally.
# This makes it easy to ship a large batch of requests to your Triton Server.
@task(container_image=image)
async def make_batch_inference_request(input: str, model_name: str, hf_hub_model_name: str, model_version: int, num_reqs: int) -> List[np.ndarray]:
    tokenizer = AutoTokenizer.from_pretrained(hf_hub_model_name, padding=True)
    client = InferenceServerClient("<path-to-aws-instance>:8000")

    # Fire off all requests without waiting for each response
    async_results = []
    for _ in range(num_reqs):
        encoded = tokenizer(input, padding='max_length', max_length=128, truncation=True)
        encoded = {key: np.asarray(value, dtype=np.int32).reshape(1, -1) for key, value in encoded.items()}
        model_inputs = [
            InferInput(key, value.shape, "INT32").set_data_from_numpy(value) for key, value in encoded.items()
        ]
        async_result = client.async_infer(model_name, model_inputs, model_version=str(model_version))
        async_results.append(async_result)

    # Collect the responses once they're all in flight
    results = [(async_result.get_result()).as_numpy("output_0") for async_result in async_results]
    return results
```
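Since Flyte tasks and workflows are plain Python callables, you can also smoke-test the whole flow locally before registering it on a cluster. This assumes a local GPU (for the tracing task), credentials for the S3 bucket, and network access to the Triton host:
```python=
# Hypothetical local smoke test: Flyte tasks and workflows can be invoked directly as Python functions
if __name__ == "__main__":
    model_dir = register_model(
        model_registry_uri="s3://<put-your-s3-bucket-here>/triton-model-registry",
    )
    print(f"model repository entry written to: {model_dir}")

    result = make_inference_request(
        input="Quarterly revenue beat expectations.",
        model_name="distilroberta-finetuned-financial-news-sentiment-analysis",
        hf_hub_model_name="mrm8488/distilroberta-finetuned-financial-news-sentiment-analysis",
        model_version=1,
    )
    print(f"raw model output: {result}")
```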
## Wrapping up
Diving into the world of Nvidia Triton and Flyte, we've journeyed through the maze of setting up real-time ML operations, uncovering how these formidable tools reshape the deployment landscape. Triton stands out with its exceptional GPU optimization and multi-model prowess, proving itself as an indispensable asset for varied ML endeavors. When paired with Flyte's adeptness in orchestrating and managing complex ML workflows, the combination becomes a powerhouse, propelling teams towards efficient and scalable model deployment. In the ever-evolving realm of machine learning, leveraging such advanced technologies is key to maintaining a competitive edge, shifting the focus from deployment challenges to pioneering new frontiers in ML innovation.