SHARK is a portable High Performance Machine Learning Runtime for PyTorch.

In this blog we demonstrate PyTorch training and inference on the Apple M1Max GPU with SHARK, with only a few lines of additional code, outperforming Apple’s Tensorflow-metal plugin. Though Apple has released GPU support for Tensorflow via the now-deprecated tensorflow-macos plugin and the newer tensorflow-metal plugin, the most popular machine learning framework, PyTorch, lacks GPU support on Apple Silicon. Until now.

SHARK is built on MLIR and IREE and can target various hardware seamlessly. Since SHARK generates kernels on the fly for each workload, you can port to new architectures like the M1Max without a vendor-provided, hand-written and hand-tuned library.

Nod.ai has added AMD GPU support, retargeting code generation for AMD MI100/MI200-class devices so machine learning workloads can move from Nvidia V100/A100 to AMD MI100 seamlessly. In the past we have demonstrated better codegen than Intel MKL and Apache/OctoML TVM on Intel Alderlake CPUs, and outperformed Nvidia’s cuDNN/cuBLAS/CUTLASS as used by ML frameworks such as Onnxruntime, Pytorch/Torchscript and Tensorflow/XLA. Today we demonstrate SHARK targeting Apple’s 32-core GPU in the M1Max with PyTorch models for BERT inference and training. So if you love PyTorch and want to use those 32 GPU cores in your new Apple Silicon Macbook Pro, read on.

For our experiments we use a 14″ MacBook Pro with the Apple M1 Max and 64GB RAM. We have also run the same benchmarks on a 16″ MacBook Pro and observed the same performance; neither machine thermally throttles during our benchmarks.

SHARK Runtime on M1MAX GPU: 1.5x faster than TF-metal 2.8.0

SHARK on Apple M1MAX GPU

Here is the output of running the shark-bench tool on the microsoft/MiniLM-L12-H384-uncased model.

(base) anush@MacBook-Pro examples % ./shark-bench --module_file=minilm_jan6_m1max.vmfb --entry_function=predict  --function_input=1x128xi32 --function_input=1x128xi32 --function_input=1x128xi32 --benchmark_repetitions=10
2022-02-20T10:17:55-08:00
Running ./shark-bench
Run on (10 X 24.1214 MHz CPU s)
CPU Caches:
  L1 Data 64 KiB (x10)
  L1 Instruction 128 KiB (x10)
  L2 Unified 4096 KiB (x5)
Load Average: 2.14, 1.91, 1.81
-----------------------------------------------------------------------------------
Benchmark                                         Time             CPU   Iterations
-----------------------------------------------------------------------------------
BM_predict/process_time/real_time              11.7 ms         1.39 ms           58
BM_predict/process_time/real_time              11.5 ms         1.34 ms           58
BM_predict/process_time/real_time              11.7 ms         1.43 ms           58
BM_predict/process_time/real_time              11.6 ms         1.30 ms           58
BM_predict/process_time/real_time              11.5 ms         1.33 ms           58
BM_predict/process_time/real_time              11.7 ms         1.46 ms           58
BM_predict/process_time/real_time              11.6 ms         1.31 ms           58
BM_predict/process_time/real_time              11.5 ms         1.33 ms           58
BM_predict/process_time/real_time              11.7 ms         1.46 ms           58
BM_predict/process_time/real_time              11.6 ms         1.30 ms           58
BM_predict/process_time/real_time_mean         11.6 ms         1.36 ms           10
BM_predict/process_time/real_time_median       11.6 ms         1.34 ms           10
BM_predict/process_time/real_time_stddev      0.074 ms        0.063 ms           10
BM_predict/process_time/real_time_cv           0.64 %          4.65 %            10
(base) anush@MacBook-Pro examples % 
Arguments: Namespace(models=['microsoft/MiniLM-L12-H384-uncased'], model_source='pt', model_class=None, engines=['tensorflow'], cache_dir='./cache_models', onnx_dir='./onnx_models', use_gpu=True, precision=<Precision.FLOAT32: 'fp32'>, verbose=False, overwrite=False, optimize_onnx=False, validate_onnx=False, fusion_csv='fusion.csv', detail_csv='detail.csv', result_csv='result.csv', input_counts=[1], test_times=1000, batch_sizes=[1], sequence_lengths=[128], disable_ort_io_binding=False, num_threads=[10])
Metal device set to: Apple M1 Max

systemMemory: 64.00 GB
maxCacheSize: 24.00 GB

All model checkpoint layers were used when initializing TFBertModel.

All the layers of TFBertModel were initialized from the model checkpoint at microsoft/MiniLM-L12-H384-uncased.
If your task is similar to the task the model of the checkpoint was trained on, you can already use TFBertModel for predictions without further training.
Run Tensorflow on microsoft/MiniLM-L12-H384-uncased with input shape [1, 128]
{'engine': 'tensorflow', 'version': '2.8.0', 'device': 'cuda', 'optimizer': '', 'precision': <Precision.FLOAT32: 'fp32'>, 'io_binding': '', 'model_name': 'microsoft/MiniLM-L12-H384-uncased', 'inputs': 1, 'threads': 10, 'batch_size': 1, 'sequence_length': 128, 'datetime': '2022-02-21 07:48:04.965312', 'test_times': 1000, 'latency_variance': '0.00', 'latency_90_percentile': '19.25', 'latency_95_percentile': '22.10', 'latency_99_percentile': '23.14', 'average_latency_ms': '16.99', 'QPS': '58.85'}

SHARK on M1 CPU: 2x faster than TF, ~1.5x faster than PyTorch and ONNX

SHARK on Apple M1MAX CPUs

Here is a video of the demo runs.

SHARK BERT Inference on M1Max at 1.5X Performance of TF-Metal

BERT Training

SHARK Runtime: 2x faster than TF-metal 2.8

SHARK Training on Apple M1MAX GPU

In our tests we noticed that the Tensorflow-metal plugin doesn’t seem to offload the backwards graph onto the GPU efficiently. Since Tensorflow-metal is a binary-only release, we have no way to debug it. The same Tensorflow implementation offloads well onto CUDA GPUs.
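As a quick way to check offload behavior yourself, TensorFlow can log the device each op is placed on. The toy graph below is purely illustrative, but the same flag can be set before a full BERT training step to see whether the backward ops land on GPU:0.

# A minimal sketch for checking op placement with TensorFlow's built-in
# device-placement logging. The toy graph is illustrative; set the same
# flag before the real BERT training step to see where backward ops run.
import tensorflow as tf

tf.debugging.set_log_device_placement(True)  # log each op's device to stderr

x = tf.random.uniform((1, 128))
w = tf.Variable(tf.random.normal((128, 5)))

with tf.GradientTape() as tape:
    logits = tf.matmul(x, w)             # forward op
    loss = tf.reduce_sum(logits ** 2)    # toy loss

grads = tape.gradient(loss, [w])         # backward ops are logged too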

(base) anush@MacBook-Pro examples % ./shark-bench --module_file=bert_training_feb17.vmfb --function_input=1x512xi32 --function_input=1x512xi32 --function_input=1x512xi32 --function_input=1xi32 --entry_function=learn --benchmark_repetitions=10
2022-02-20T23:04:22-08:00
Running ./shark-bench
Run on (10 X 24.2416 MHz CPU s)
CPU Caches:
  L1 Data 64 KiB (x10)
  L1 Instruction 128 KiB (x10)
  L2 Unified 4096 KiB (x5)
Load Average: 2.03, 2.53, 2.31
---------------------------------------------------------------------------------
Benchmark                                       Time             CPU   Iterations
---------------------------------------------------------------------------------
BM_learn/process_time/real_time               104 ms         16.2 ms            5
BM_learn/process_time/real_time               104 ms         16.1 ms            5
BM_learn/process_time/real_time               105 ms         15.3 ms            5
BM_learn/process_time/real_time               104 ms         16.0 ms            5
BM_learn/process_time/real_time               103 ms         15.2 ms            5
BM_learn/process_time/real_time               105 ms         16.9 ms            5
BM_learn/process_time/real_time               104 ms         15.5 ms            5
BM_learn/process_time/real_time               104 ms         14.9 ms            5
BM_learn/process_time/real_time               105 ms         17.4 ms            5
BM_learn/process_time/real_time               105 ms         15.7 ms            5
BM_learn/process_time/real_time_mean          104 ms         15.9 ms           10
BM_learn/process_time/real_time_median        104 ms         15.8 ms           10
BM_learn/process_time/real_time_stddev      0.830 ms        0.774 ms           10
BM_learn/process_time/real_time_cv           0.80 %          4.86 %            10
(base) anush@MacBook-Pro examples % 
(base) anush@MacBook-Pro nlp_models % python ./bert_small_tf_run.py
Metal device set to: Apple M1 Max

systemMemory: 64.00 GB
maxCacheSize: 24.00 GB

2022-02-21 07:58:36.593857: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2022-02-21 07:58:36.594010: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)
1 Physical GPUs, 1 Logical GPU
Model: "bert_classifier"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_word_ids (InputLayer)    [(None, None)]       0           []                               
                                                                                                  
 input_mask (InputLayer)        [(None, None)]       0           []                               
                                                                                                  
 input_type_ids (InputLayer)    [(None, None)]       0           []                               
                                                                                                  
 bert_encoder_1 (BertEncoder)   [(None, None, 768),  15250176    ['input_word_ids[0][0]',         
                                 (None, 768)]                     'input_mask[0][0]',             
                                                                  'input_type_ids[0][0]']         
                                                                                                  
 dropout_1 (Dropout)            (None, 768)          0           ['bert_encoder_1[0][1]']         
                                                                                                  
 sentence_prediction (Classific  (None, 5)           3845        ['dropout_1[0][0]']              
 ationHead)                                                                                       
                                                                                                  
==================================================================================================
Total params: 15,254,021
Trainable params: 15,254,021
Non-trainable params: 0
__________________________________________________________________________________________________
...
time: 1.7728650569915771
time/iter: 0.19698500633239746

PERFORMANCE / WATT

A100 vs M1MAX for BERT Training

Power measurements need to be done in a controlled and instrumented environment; however, this is a good approximation for running the same BERT training model on an A100 and an M1MAX.

The A100 draws 131W peak during the training run, while the M1MAX GPU draws a maximum of 15.4W. The A100 runs an iteration in ~7ms, vs 104ms (with SHARK) and ~197ms (with TF-Metal) on the M1MAX. Since TF-Metal doesn’t offload the training graph well onto the GPU, we drop it from our perf/watt comparisons.
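As a rough back-of-the-envelope check (treating peak draw as a proxy for average power, which a controlled measurement would refine), the energy per training iteration works out as follows:

# Rough energy-per-iteration estimate from the peak power and latency
# figures above; peak draw is only a proxy for average power, so treat
# these as ballpark numbers.
a100_joules = 131 * 0.007      # 131 W * ~7 ms/iter   -> ~0.92 J per iteration
m1max_joules = 15.4 * 0.104    # 15.4 W * 104 ms/iter -> ~1.60 J per iteration
print(f"A100: {a100_joules:.2f} J/iter  M1Max GPU: {m1max_joules:.2f} J/iter")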

Mon Feb 21 00:56:33 2022                                                       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 470.57.02    Driver Version: 470.57.02    CUDA Version: 11.4     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|===============================+======================+======================|
|   0  NVIDIA A100-SXM...  Off  | 00000000:00:04.0 Off |                    0 |
| N/A   36C    P0   131W / 350W |  39341MiB / 40536MiB |     53%      Default |
|                               |                      |             Disabled |
+-------------------------------+----------------------+----------------------+
samples/ModelCompiler/nlp_models$ python bert_small_tf_run.py
..
Model: "bert_classifier"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to
==================================================================================================
input_word_ids (InputLayer)     [(None, None)]       0
__________________________________________________________________________________________________
input_mask (InputLayer)         [(None, None)]       0
__________________________________________________________________________________________________
input_type_ids (InputLayer)     [(None, None)]       0
__________________________________________________________________________________________________
bert_encoder_1 (BertEncoder)    [(None, None, 768),  15250176    input_word_ids[0][0]
                                                                 input_mask[0][0]
                                                                 input_type_ids[0][0]
__________________________________________________________________________________________________
dropout_1 (Dropout)             (None, 768)          0           bert_encoder_1[0][1]
__________________________________________________________________________________________________
sentence_prediction (Classifica (None, 5)            3845        dropout_1[0][0]
==================================================================================================
Total params: 15,254,021
Trainable params: 15,254,021
Non-trainable params: 0
__________________________________________________________________________________________________
..                                                                            
time: 6.907362699508667
time/iter: 0.006977134039907744
Package Power: 2.48W (avg: 2.34W peak: 37.78W) throttle: no          
CPU: 1.37W (avg: 1.03W peak: 32.32W)    GPU: 0.00W (avg: 0.00W peak: 15.41W)                  

SOFTWARE

The SHARK Runtime is available as a pip package and requires only a few lines of code changes in your Pytorch/Python file. Here is an example of running Resnet50 and BERT from Python using SHARK, and we plan to add more Python examples there. For the experiments in this post we use a benchmark binary built with Google Benchmark support to run a controlled experiment, though calling from Python only adds a few microseconds. We test a BERT MiniLM model (used in HuggingFace’s Infinity demos) and a BERT training model. You can also try bert-base-uncased, Resnet50, Mobilenetv3, etc.; all the models are automatically offloaded to the GPU via torch-mlir.
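For reference, here is a minimal sketch of what the Python flow looks like. The SharkInference wrapper and its arguments are assumptions based on the SHARK examples, so check the repository for the exact API.

# A minimal sketch of running a HuggingFace BERT model through SHARK from
# Python. The SharkInference import path and argument names are assumptions
# based on the SHARK examples; consult the repository for the exact API.
import torch
from transformers import AutoModelForSequenceClassification
from shark.shark_inference import SharkInference  # assumed import path

model = AutoModelForSequenceClassification.from_pretrained(
    "microsoft/MiniLM-L12-H384-uncased", torchscript=True)
model.eval()

# BERT-style inputs: token ids, attention mask, token type ids (1x128),
# matching the --function_input shapes passed to shark-bench above.
inputs = tuple(torch.zeros(1, 128, dtype=torch.int32) for _ in range(3))

# Compile through torch-mlir/IREE and dispatch to the GPU backend.
shark_module = SharkInference(model, inputs, device="gpu", jit_trace=True)
result = shark_module.forward(inputs)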

The Torch-MLIR lowering for the BERT training graph is being integrated in this staging branch. All the code to recreate the tests is here and here. However, you will need to install torchvision from source since torchvision doesn’t support the M1 yet. You will also need to build SHARK from the apple-m1-max-support branch of the SHARK repository.

(base) anush@MacBook-Pro examples % pip list | grep tensorflow
tensorflow-estimator          2.6.0
tensorflow-macos              2.8.0
tensorflow-metal              0.3.0
(base) anush@MacBook-Pro examples % pip list | grep onnx                                            
onnx                          1.10.1
onnxconverter-common          1.8.1
(base) anush@MacBook-Pro examples % pip list | grep tensor
tensorboard                   2.8.0
tensorboard-data-server       0.6.0
tensorboard-plugin-wit        1.8.0
tensorflow-estimator          2.6.0
tensorflow-macos              2.8.0
tensorflow-metal              0.3.0
(base) anush@MacBook-Pro examples % pip list | grep torch 
torch                         1.10.0
torchvision                   0.9.0a0
(base) anush@MacBook-Pro examples % 

We use asitop, which can be installed via pip, to monitor power usage, GPU usage, core throttling, etc.

Apple’s CoreML can target not only the CPU and GPU but also the Apple Neural Engine, though only for inference. We did try to get CoreML to work for the inference comparison but ran into model conversion issues here and excluded it from the tests. The latest coremltools seems to require an older Tensorflow version (2.5.0), which is no longer the default when you install tensorflow with conda.
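For anyone who wants to retry the conversion, the path we attempted looks roughly like the sketch below; ct.convert is coremltools’ documented entry point, while the saved-model path is hypothetical.

# A rough sketch of the CoreML conversion attempt. ct.convert is
# coremltools' documented entry point; the saved-model path below is
# hypothetical, and this conversion step is where we hit the errors
# described in the linked issue.
import coremltools as ct
import tensorflow as tf

tf_model = tf.keras.models.load_model("bert_classifier_saved_model")  # hypothetical path
mlmodel = ct.convert(tf_model)
mlmodel.save("bert_classifier.mlmodel")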

The macOS WindowServer crashes when the GPU is driven to its maximum. We have filed a Feedback issue with Apple on this bug and hope it will be resolved soon, but when recreating results we recommend ssh-ing into the MacBook Pro and leaving the display off.

All the code in these repositories is very much a work in progress, shown as a technical preview, and requires some polish before it is ready for non-technical users. We plan to keep making it more user friendly and to add eager-mode support to Torch-MLIR so PyTorch support on M1Max GPUs works out of the box, for seamless development to deployment. If putting all the open source pieces together is not your thing, or if you have a business case for deploying PyTorch with GPU support on M1 devices today and want a professional solution, please sign up here and our solutions team will reach out to help.

BERT MINILM Artifacts [ MLIR | GPUBIN ]

BERT Training Artifacts [ MLIR | GPUBIN ]

Conclusion and future work

Using the SHARK Runtime, we demonstrate high performance PyTorch models on Apple M1Max GPUs, outperforming Tensorflow-Metal by 1.5x for inference and 2x for training on BERT models. In the near future we plan to improve the end-user experience and add “eager” mode support so it is seamless from development to deployment on any hardware.

If you would like access to the commercial version of the SHARK Runtime, sign up for access here, and if you have trouble recreating the results, please open an issue.

Acknowledgements

SHARK is built on the open source packages Torch-MLIR, LLVM/MLIR and Google IREE, and we are thankful to all the developers and the community for their support. We specifically want to call out Ben Vanik, Lei Zhang, Stella Laurenzo and Thomas Raoux for their help, support and guidance on the MLIR/IREE GPU codegen paths.
