Python remains the dominant language for machine learning, yet modern AI systems rely heavily on compilation and runtime optimization to achieve high performance. Two ecosystems, PyTorch and JAX, approach this challenge in fundamentally different ways while sharing a common goal: preserving Python productivity without sacrificing execution performance.
In this poster, we'll compare the execution stacks of PyTorch and JAX from a Python developer's perspective. Starting from simple Python model code, we'll trace how each framework captures computation, represents it internally, and executes it efficiently on accelerators.
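To make "capturing computation" concrete, here is a deliberately minimal sketch of trace-based graph capture, the mental model behind approaches like jax.jit tracing. This is illustrative toy code, not the real PyTorch or JAX machinery: a proxy object stands in for a tensor, overloads arithmetic, and records each operation into a simple graph.

```python
# Toy sketch of trace-based graph capture (illustrative only; not the
# actual PyTorch or JAX implementation). All names here are invented.

class Node:
    """One recorded operation in the captured graph."""
    def __init__(self, op, inputs):
        self.op = op          # operation name, e.g. "mul"
        self.inputs = inputs  # upstream Nodes or Python constants

class Proxy:
    """Stands in for a tensor during tracing; overloads arithmetic."""
    def __init__(self, node, graph):
        self.node = node
        self.graph = graph

    def _record(self, op, other):
        other_node = other.node if isinstance(other, Proxy) else other
        node = Node(op, [self.node, other_node])
        self.graph.append(node)
        return Proxy(node, self.graph)

    def __mul__(self, other):
        return self._record("mul", other)

    def __add__(self, other):
        return self._record("add", other)

def capture(fn):
    """Run fn once on a Proxy input and return the recorded graph."""
    graph = []
    fn(Proxy(Node("input", []), graph))
    return graph

def model(x):
    return x * 2 + 1

graph = capture(model)
print([n.op for n in graph])  # → ['mul', 'add']
```

Running the model once on a proxy yields an operation list that a compiler backend could then optimize and lower, which is the essence of what both frameworks do at far greater sophistication.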
We'll introduce key concepts such as graph capture, intermediate representations, and runtime execution, using PyTorch 2.x and the JAX/OpenXLA ecosystem as concrete examples. Rather than diving into low-level compiler theory, we'll focus on building accurate mental models: how control flow is handled, where compilation boundaries lie, and how these design choices affect performance, debuggability, and developer experience.
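The control-flow question can be previewed with another toy sketch (assumed, simplified classes, not real framework code). Under tracing, a Python `if` runs once at capture time: a branch on a plain Python value is resolved then and baked into the graph, while a branch on a traced value has no concrete truth value to decide with. The real frameworks react differently at that point: JAX raises a tracer error, while torch.compile falls back to eager Python via a graph break.

```python
# Illustrative sketch of control flow under tracing (invented names;
# not the real PyTorch or JAX behavior, just the shared mental model).

class TraceError(RuntimeError):
    pass

class Proxy:
    """Abstract traced value: comparisons stay symbolic, truth tests fail."""
    def __gt__(self, other):
        return Proxy()  # the comparison result is itself a traced value

    def __bool__(self):
        # Deciding a Python `if` needs a concrete value we don't have.
        raise TraceError("data-dependent control flow during tracing")

def model(x, use_bias):
    if use_bias:        # plain Python bool: resolved once, at trace time
        return "graph with bias"
    return "graph without bias"

def data_dependent(x):
    if x > 0:           # needs x's runtime value -> cannot be traced here
        return "positive branch"
    return "negative branch"

print(model(Proxy(), use_bias=True))  # → graph with bias

try:
    data_dependent(Proxy())
except TraceError as err:
    print("tracing failed:", err)
```

Only one branch of `model` ends up in the captured graph, which is why tracing-based systems recompile when such trace-time arguments change, and why data-dependent branches mark a compilation boundary.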