My name is Maksim Levental and I’m a PhD student in CS at UChicago. This Spring I worked as a compiler intern at Nod Labs. The project was in collaboration with Google’s IREE and torch-mlir projects and involved implementing a way to use PyTorch as an eager frontend to the MLIR compiler stack. What this means is users of PyTorch (which is the overwhelming majority of the research community) can now reap all the benefits of a powerful compiler stack, including sophisticated polyhedral optimizations, without dramatically altering their workflow. In addition, it enables PyTorch users to target hardware platforms that aren’t currently supported in-tree in PyTorch (such as bare-metal platforms, through IREE). 

The eager mode implementation winds its way through PyTorch’s dispatcher, its native intermediate representation (IR) TorchScript (TS), and the MLIR torch dialect. The details are worthwhile and interesting in and of themselves but pretty technical and so instead I’d rather talk about what it takes to port a non-trivial DNN model to MLIR in the absence of eager mode.

The model in question is Hugging Face’s (HF) implementation of Meta’s Open Pre-trained Transformers (OPT). For context, transformer models such as OPT are used for all sorts of Natural Language Processing (NLP) tasks, such as text completion, question answering, and language translation. Transformer models are famously very large (hundreds of millions to hundreds of billions of parameters), and thus very expensive to train, requiring extremely large datasets in order to get the weights to converge. Hence, performance improvement in the compute kernels composing such a model could easily amount to (pun intended) hundreds of thousands to millions of dollars saved during training (inference latency is always important but training is the real money maker/taker for these models). Thus, it makes a lot of sense (and cents!) to compile at least some parts of the model. 

The first step to compiling OPT via MLIR is to extract a TS representation of the model:

import torch
from transformers import OPTModel, OPTConfig, GPT2Tokenizer

configuration = OPTConfig()
configuration.return_dict = False # easier this way

model = OPTModel(configuration)
model.eval() # disable dropout, etc.
tokenizer = GPT2Tokenizer.from_pretrained("facebook/opt-350m")

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
input_ids, attention_mask = inputs["input_ids"], inputs["attention_mask"]
ts = torch.jit.trace(model, (input_ids, attention_mask))
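Before lowering further, it's worth sanity-checking what the tracer actually captured. Here's a minimal sketch (using a toy function rather than OPT) of inspecting the TorchScript IR that `jit.trace` records; every `aten` op that shows up in this graph is an op torch-mlir will need to support:

```python
import torch

# a tiny stand-in for the traced model above, just to show what jit.trace records
def f(x):
    return torch.relu(x + 1)

ts_demo = torch.jit.trace(f, torch.randn(3))
print(ts_demo.graph)  # aten-level TorchScript IR consumed by torch-mlir
```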

Most likely you’ll see warnings about “Converting a tensor to a Python boolean”; these stem from the dynamic nature of the Python specification of the model and jit.trace’s fundamental inability to capture that dynamism accurately. In principle it would be better to statically analyze the model (using torch.jit.script), but that’s also a non-starter due to the class hierarchy designed by HF. 
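To make the warning concrete, here's a toy module (not from HF's code) with data-dependent control flow; the tracer evaluates the condition once, bakes the taken branch into the graph, and warns you it did so:

```python
import torch

class Dynamic(torch.nn.Module):
    def forward(self, x):
        # data-dependent control flow: the tracer evaluates the condition once
        # and bakes the taken branch into the graph, hence the warning
        if x.sum() > 0:
            return x + 1
        return x - 1

# emits "Converting a tensor to a Python boolean" (a TracerWarning)
traced = torch.jit.trace(Dynamic(), torch.ones(3))
# the trace always takes the branch seen at trace time,
# even for inputs where the condition would go the other way
```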

After tracing the model you can use various torch-mlir APIs to further lower the model, or you can just use torch_mlir.compile to do the tracing and lowering in one fell swoop:

import torch_mlir

module = torch_mlir.compile(
    model,
    (input_ids, attention_mask),
    output_type="linalg-on-tensors",
)

Or at least so you hope; this generates a big scary error:

Traceback (most recent call last):
  File "dSHARK/tank/pytorch/opt/", line 18, in <module>
    module = torch_mlir.compile(
  File "torch_mlir/", line 149, in compile
  File "torch_mlir/", line 53, in run_pipeline_with_repro_report
    raise Exception(f"""
Lowering TorchScript IR -> Torch Backend IR failed with the following diagnostics:
error: '' op operand type mismatch: expected operand type '!torch.float', but provided '!torch.number' for operand number 0
note: see current operation: %1025 = ""(%130, %1021, %1022, %1023, %1024) {callee = @__torch_mlir_shape_fn.aten.arange} …

The short story here is that parts of HF’s implementation of OPT rely heavily on Python’s duck typing and PyTorch’s permissive type system to function (pun not intended). This particular error is due to mask.size(-1) occasionally returning a tensor rather than an integer; the specific fix is int(mask.size(-1)).
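A minimal sketch of the fix (make_positions is a hypothetical helper, not HF's actual code):

```python
import torch

def make_positions(mask):
    # hypothetical helper illustrating the fix: under tracing, mask.size(-1)
    # can be recorded as a tensor, which lowers to !torch.number rather than
    # the !torch.int/!torch.float the shape functions expect; int() forces a
    # plain Python integer
    src_len = int(mask.size(-1))
    return torch.arange(src_len)
```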

It’s perhaps unsurprising that there are several more such idiosyncrasies that need to be handled before OPT can be successfully lowered to even the torch dialect. If you do fight the good fight, eventually you’ll be able to lower to the torch dialect, but you will ultimately be thwarted when trying to lower to the linalg dialect; you will see errors like

%132 = torch.aten.view %130, %131 : !torch.tensor<[1,7,768],f32>, !torch.list<int> -> !torch.tensor<[1,7,12,64],f32> loc(#loc13)

and unsupported types. The explanation here is actually subtle (having to do with value semantics and container types), but suffice it to say it’s not happening anytime soon (maybe I’ll get to it soon…). For all these reasons, it would in some sense be preferable if there were a way to incrementally compile OPT, one usable subgraph at a time (if you’re thinking this sounds a lot like a JIT, you’re right!). This, then, is the value of eager mode in torch-mlir – a painless way to lower the parts of your model that are currently supported by torch-mlir and leave the remainder to PyTorch. By the way, I’d be remiss if I didn’t mention that PyTorch is also aiming to soon support BYOC (bring your own compiler) through TorchDynamo, within which all of this will be much simpler.
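The compile-what-you-can, fall-back-for-the-rest idea can be sketched as follows (compile_or_fallback and compile_fn are hypothetical stand-ins, not the actual torch-mlir eager-mode API):

```python
def compile_or_fallback(fn, compile_fn):
    # compile_fn stands in for a torch-mlir lowering entry point (hypothetical);
    # if it rejects the graph (e.g. an unsupported op), keep running the
    # original function eagerly in PyTorch instead of failing outright
    try:
        return compile_fn(fn)
    except Exception:
        return fn
```

The real eager mode does this at a finer granularity (per dispatched op rather than per function), but the fallback principle is the same.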

In summary, if you’re using cutting edge, compute intensive, ginormous models, you probably want to compile them, at least incrementally. Eager mode within torch-mlir is a great way to dip your toes into the shallow end of the compiler framework pool, even if ultimately you’ll bite the bullet and compile your entire model. I really enjoyed working on the project and I learned a ton (I basically knew nothing about DNN compilers when I started…). I want to thank Sean Silva, Horace He, Ramiro Leal-Cavazos, Anush and everyone at Nod that helped me get up to speed on MLIR (and didn’t crucify me for git push -f blowing up the repo 😂).
