Description

Haskell Frontend for StableHLO — type-safe ML inference on CPU and GPU.

HHLO is a Haskell library and runtime for building, compiling, and executing machine-learning programs targeting StableHLO, the portable intermediate representation of the OpenXLA ecosystem.

Instead of replicating JAX's Python-based tracing infrastructure, HHLO generates StableHLO MLIR text directly from Haskell and compiles it to CPU or GPU via the PJRT plugin interface.

Key features:

  • Type-safe EDSL with phantom-shape tensors (e.g. Tensor '[2,3] 'F32)
  • 50+ ops: matmul, conv2d, softmax, batch-norm, control flow, and more
  • CPU execution via PJRT CPU plugin
  • GPU execution via PJRT CUDA plugin with device enumeration
  • Multi-GPU concurrent inference scaling
  • No LLVM/MLIR build dependency: emits text and lets PJRT parse it

See the README for setup instructions and examples.

HHLO — Haskell Frontend for StableHLO

HHLO is a Haskell library and runtime for building, compiling, and executing machine learning programs targeting StableHLO, the portable, versioned intermediate representation of the OpenXLA ecosystem.

Instead of replicating JAX's Python-based tracing infrastructure, HHLO generates StableHLO MLIR text directly from Haskell and compiles it to CPU or GPU via the PJRT plugin interface.


Design

HHLO is structured in four layers:

┌─────────────────────────────────────┐
│  EDSL (HHLO.EDSL.Ops)               │  Type-safe frontend: add, matmul, relu, etc.
├─────────────────────────────────────┤
│  IR Builder (HHLO.IR.Builder)       │  Stateful monad for constructing MLIR
├─────────────────────────────────────┤
│  Pretty Printer (HHLO.IR.Pretty)    │  Emits StableHLO MLIR text
├─────────────────────────────────────┤
│  PJRT Runtime (HHLO.Runtime.*)      │  Compile → Execute on CPU or GPU
└─────────────────────────────────────┘

Text Emission + PJRT

The library emits StableHLO MLIR text directly and hands it to PJRT_Client_Compile. This is the same path used by JAX's C++ backend and avoids the heavy dependency of building LLVM/MLIR from source.
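To make the builder-then-printer split concrete, here is a minimal sketch that depends only on base. It is not HHLO's actual API: Builder, binaryOp, renderFunc, and tensorTy are hypothetical names, and the shape is fixed to 2x2 f32 to keep the example short.

```haskell
import Data.List (intercalate)

-- An SSA value is identified by its name, e.g. "%arg0" or "%0".
newtype Value = Value String

-- Builder state: next fresh result index and the ops emitted so far (newest first).
data BuildState = BuildState Int [String]

-- Hand-rolled state monad (hypothetical, not HHLO.IR.Builder).
newtype Builder a = Builder { runBuilder :: BuildState -> (a, BuildState) }

instance Functor Builder where
  fmap f (Builder g) = Builder $ \s -> let (a, s') = g s in (f a, s')

instance Applicative Builder where
  pure a = Builder $ \s -> (a, s)
  Builder f <*> Builder g = Builder $ \s ->
    let (h, s1) = f s
        (a, s2) = g s1
    in (h a, s2)

instance Monad Builder where
  Builder g >>= k = Builder $ \s -> let (a, s') = g s in runBuilder (k a) s'

-- The sketch fixes one tensor type for every value.
tensorTy :: String
tensorTy = "tensor<2x2xf32>"

-- Emit one binary StableHLO op and return a handle to its result.
binaryOp :: String -> Value -> Value -> Builder Value
binaryOp name (Value a) (Value b) = Builder $ \(BuildState n ops) ->
  let res  = "%" ++ show n
      line = res ++ " = stablehlo." ++ name ++ " " ++ a ++ ", " ++ b
             ++ " : (" ++ tensorTy ++ ", " ++ tensorTy ++ ") -> " ++ tensorTy
  in (Value res, BuildState (n + 1) (line : ops))

-- Run the builder and print the collected ops as one MLIR function.
renderFunc :: (Value -> Value -> Builder Value) -> String
renderFunc body =
  let (Value out, BuildState _ ops) =
        runBuilder (body (Value "%arg0") (Value "%arg1")) (BuildState 0 [])
  in intercalate "\n" $
       [ "module {"
       , "  func.func @main(%arg0: " ++ tensorTy ++ ", %arg1: " ++ tensorTy
           ++ ") -> " ++ tensorTy ++ " {"
       ]
       ++ map ("    " ++) (reverse ops)
       ++ [ "    return " ++ out ++ " : " ++ tensorTy, "  }", "}" ]

-- (a + b) * (a - b), rendered as text that a PJRT plugin could parse.
main :: IO ()
main = putStrLn $ renderFunc $ \a b -> do
  s <- binaryOp "add" a b
  d <- binaryOp "subtract" a b
  binaryOp "multiply" s d
```

The point of the design is that nothing here links against MLIR: the program is only a string until the PJRT plugin parses it.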

Phantom Types

Every tensor carries its shape and dtype as phantom type parameters:

Tensor '[2, 3] 'F32   -- 2×3 matrix of Float32

Matmul, broadcast, and conv shapes are checked at compile time via type families.
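As a rough illustration of how a type family can reject a bad matmul at compile time, consider the standalone sketch below. It is an assumption-laden miniature, not the definitions in HHLO.Core.Types: the Tensor newtype, MatMulShape, and the list-based matmul are invented for this example.

```haskell
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}

import Data.Proxy (Proxy (..))
import GHC.TypeLits (KnownNat, Nat, natVal)

data DType = F32 | I32

-- Shape and dtype exist only at the type level; the payload is row-major data.
newtype Tensor (shape :: [Nat]) (dtype :: DType) = Tensor [Float]

toList :: Tensor shape dtype -> [Float]
toList (Tensor xs) = xs

-- Result shape of a 2-D matmul. There is no equation for mismatched inner
-- dimensions, and the argument types of matmul below enforce the match.
type family MatMulShape (a :: [Nat]) (b :: [Nat]) :: [Nat] where
  MatMulShape '[m, k] '[k, n] = '[m, n]

matmul
  :: forall m k n. (KnownNat m, KnownNat k, KnownNat n)
  => Tensor '[m, k] 'F32
  -> Tensor '[k, n] 'F32
  -> Tensor (MatMulShape '[m, k] '[k, n]) 'F32
matmul (Tensor xs) (Tensor ys) =
  Tensor
    [ sum [ xs !! (i * k + l) * ys !! (l * n + j) | l <- [0 .. k - 1] ]
    | i <- [0 .. m - 1], j <- [0 .. n - 1]
    ]
  where
    -- Recover the dimensions from the types at runtime.
    m, k, n :: Int
    m = fromIntegral (natVal (Proxy :: Proxy m))
    k = fromIntegral (natVal (Proxy :: Proxy k))
    n = fromIntegral (natVal (Proxy :: Proxy n))

a :: Tensor '[2, 3] 'F32
a = Tensor [1, 2, 3, 4, 5, 6]

b :: Tensor '[3, 2] 'F32
b = Tensor [1, 0, 0, 1, 1, 1]

-- The annotation is checked against MatMulShape '[2, 3] '[3, 2].
c :: Tensor '[2, 2] 'F32
c = matmul a b

-- Uncommenting the next line is a type error: inner dimensions 3 and 2 differ.
-- bad = matmul a a

main :: IO ()
main = print (toList c)   -- prints [4.0,5.0,10.0,11.0]
```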

ForeignPtr Finalizers

PJRT buffers and executables are managed by ForeignPtr finalizers that automatically call PJRT_Buffer_Destroy and PJRT_LoadedExecutable_Destroy when values are garbage-collected, so references can simply drop out of scope without explicit cleanup.
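The finalizer pattern can be tried without PJRT. The sketch below uses only base: a malloc'd Float stands in for a device buffer, and a Haskell closure stands in for PJRT_Buffer_Destroy. The demo function and the destroyed counter are illustrative assumptions, not part of HHLO.

```haskell
import Data.IORef (modifyIORef', newIORef, readIORef)
import Foreign.Concurrent (newForeignPtr)
import Foreign.ForeignPtr (finalizeForeignPtr, withForeignPtr)
import Foreign.Marshal.Alloc (free, malloc)
import Foreign.Ptr (Ptr)
import Foreign.Storable (peek, poke)

-- Returns the value read through the managed pointer and the number of
-- times the stand-in "destroy" finalizer ran.
demo :: IO (Float, Int)
demo = do
  destroyed <- newIORef (0 :: Int)
  raw <- malloc :: IO (Ptr Float)          -- stand-in for a PJRT buffer
  poke raw 42
  -- Attach the cleanup action; in HHLO this role is played by the PJRT destroy call.
  buf <- newForeignPtr raw (free raw >> modifyIORef' destroyed (+ 1))
  value <- withForeignPtr buf peek
  -- Normally the garbage collector triggers the finalizer at some later point.
  -- Forcing it here makes the demo deterministic.
  finalizeForeignPtr buf
  count <- readIORef destroyed
  pure (value, count)

main :: IO ()
main = demo >>= print   -- prints (42.0,1)
```

Because finalizers run at the garbage collector's discretion, code that needs prompt release of device memory would have to finalize explicitly, as the demo does.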

Dynamic Output Counts

The runtime queries the compiled executable for its actual number of outputs via PJRT_Executable_NumOutputs instead of guessing or hardcoding a maximum.

Async Execution

HHLO.Runtime.Async provides true non-blocking execution: executeAsync returns buffer handles immediately, bufferReady polls for completion, and awaitBuffers blocks until device-side computation finishes.
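The launch, poll, and await roles can be modelled with an MVar and a worker thread. This sketch only mirrors the shape of the API described above: Pending, executeAsyncSketch, readySketch, and awaitSketch are hypothetical names, and a sleeping thread stands in for device-side work.

```haskell
import Control.Concurrent (forkIO, threadDelay)
import Control.Concurrent.MVar (MVar, newEmptyMVar, putMVar, readMVar, tryReadMVar)
import Data.Maybe (isJust)

-- A handle that is filled in once the (simulated) device finishes.
newtype Pending a = Pending (MVar a)

-- Start the work on another thread and return immediately.
executeAsyncSketch :: IO a -> IO (Pending a)
executeAsyncSketch work = do
  slot <- newEmptyMVar
  _ <- forkIO (work >>= putMVar slot)
  pure (Pending slot)

-- Non-blocking poll: has the result arrived yet?
readySketch :: Pending a -> IO Bool
readySketch (Pending slot) = isJust <$> tryReadMVar slot

-- Block until the result is available.
awaitSketch :: Pending a -> IO a
awaitSketch (Pending slot) = readMVar slot

-- Simulated kernel: a 50 ms delay followed by an element-wise relu.
reluJob :: [Float] -> IO [Float]
reluJob xs = threadDelay 50000 >> pure (map (max 0) xs)

main :: IO ()
main = do
  h <- executeAsyncSketch (reluJob [-1, 2, -3, 4])
  early <- readySketch h     -- usually False: the job is still sleeping
  putStrLn ("ready right after launch: " ++ show early)
  result <- awaitSketch h
  late <- readySketch h      -- True once awaitSketch has returned
  print (result, late)
```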

Device Enumeration & Selection

HHLO.Runtime.Device lets you discover and select specific GPUs at runtime:

addressableDevices api client        -- list all devices
deviceKind api dev                   -- "cpu" or "NVIDIA GeForce RTX 5090"
defaultGPUDevice api client          -- first non-CPU device
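The "first non-CPU device" rule is easy to state as a pure function. The sketch below assumes a hypothetical Device record with a kind string like the ones shown above; the real HHLO.Runtime.Device types wrap PJRT handles and are not reproduced here.

```haskell
import Data.List (find)

-- Hypothetical stand-in for a PJRT device handle plus its reported kind.
data Device = Device
  { deviceIndex :: Int
  , deviceKindName :: String
  } deriving (Eq, Show)

-- Pick the first device whose kind is not "cpu", if any.
firstNonCpu :: [Device] -> Maybe Device
firstNonCpu = find ((/= "cpu") . deviceKindName)

main :: IO ()
main = do
  let devices =
        [ Device 0 "cpu"
        , Device 1 "NVIDIA GeForce RTX 5090"
        , Device 2 "NVIDIA GeForce RTX 5090"
        ]
  print (firstNonCpu devices)
  print (firstNonCpu [Device 0 "cpu"])   -- prints Nothing
```

Returning Maybe keeps the CPU-only case explicit, so callers can fall back to the CPU plugin instead of failing.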

Multi-GPU Inference Scaling

HHLO.Runtime.Execute provides executeReplicas for running the same compiled model concurrently across multiple GPUs:

compileWithOptions api client mlirText
    (defaultCompileOptions { optNumReplicas = numDevs })

-- Launch independent forward passes on all GPUs
executeReplicas api exec
    [ (gpu0, [bufA0, bufB0])
    , (gpu1, [bufA1, bufB1])
    , ...
    ]
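Conceptually, replica execution launches one independent job per device and then collects the results in launch order. The base-only sketch below models that with threads and MVars. runReplicas is a hypothetical helper, device ids are plain Ints, and a sum stands in for a forward pass.

```haskell
import Control.Concurrent (forkIO)
import Control.Concurrent.MVar (newEmptyMVar, putMVar, takeMVar)

-- Launch one job per (device, input) pair, then wait for all of them.
-- Results come back in the same order as the input list.
runReplicas :: (dev -> input -> IO out) -> [(dev, input)] -> IO [out]
runReplicas run jobs = do
  slots <- mapM launch jobs      -- start every replica before waiting on any
  mapM takeMVar slots
  where
    launch (d, x) = do
      slot <- newEmptyMVar
      _ <- forkIO (run d x >>= putMVar slot)
      pure slot

-- Stand-in for a forward pass: tag the result with the device id.
forward :: Int -> [Float] -> IO (Int, Float)
forward dev xs = pure (dev, sum xs)

main :: IO ()
main = runReplicas forward [(0, [1, 2]), (1, [3, 4])] >>= print
-- prints [(0,3.0),(1,7.0)]
```

The sketch does not handle exceptions: if a job throws, its slot is never filled and takeMVar blocks, so real code would need to propagate failures.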

Multi-Result Operations

The AST Operation type supports multiple results, enabling ops like stablehlo.rng_bit_generator and multi-value control flow:

-- Two-result operation
(newState, output) <- rngBitGenerator state

Multi-Value Control Flow

whileLoop2 / conditional2 carry multiple typed tensors through loops and conditionals without manual packing:

-- Loop with two accumulators: counter and running sum
(resultCounter, resultSum) <- whileLoop2 counter0 sum0
    (\c s -> compare c limit "LT")
    (\c s -> do
        cNext <- add c one
        sNext <- add s cNext
        returnTuple2 cNext sNext)
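For intuition about what such a two-accumulator loop computes, here is a pure Haskell reference model. whileLoop2Ref is a hypothetical helper that takes its arguments in the same order as the call above but works on ordinary values rather than tensors.

```haskell
-- Reference semantics: iterate the body while the condition holds,
-- threading both loop-carried values through each step.
whileLoop2Ref
  :: a -> b
  -> (a -> b -> Bool)      -- loop condition
  -> (a -> b -> (a, b))    -- loop body
  -> (a, b)
whileLoop2Ref c0 s0 cond body = go c0 s0
  where
    go c s
      | cond c s  = uncurry go (body c s)
      | otherwise = (c, s)

-- Counter and running sum, as in the EDSL example, with limit = 5.
countAndSum :: Int -> (Int, Int)
countAndSum limit =
  whileLoop2Ref 0 0
    (\c _ -> c < limit)
    (\c s -> let cNext = c + 1 in (cNext, s + cNext))

main :: IO ()
main = print (countAndSum 5)   -- prints (5,15)
```

With limit = 5 the counter steps through 1..5 and the sum accumulates 1 + 2 + 3 + 4 + 5 = 15.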

Random Number Generation

Three RNG primitives are exposed in the EDSL:

uniform  <- rngUniform a b      -- uniform in [a, b)
normal   <- rngNormal            -- standard normal (mean 0, std 1)
(newSt, bits) <- rngBitGenerator state   -- Threefry bit generator

Installation

System Requirements

  • GHC 9.6+ and Cabal 3.10+
  • Linux x86_64 (other platforms supported by PJRT artifacts may work)
  • curl, tar, and a standard C toolchain (gcc or clang)
  • libstdc++ and libdl (usually present on Linux)

Download PJRT Plugins

Run the provided script to download prebuilt PJRT plugins:

./pjrt_script.sh

This downloads libpjrt_cpu.so from the zml/pjrt-artifacts nightly builds into deps/pjrt/. If you have an NVIDIA GPU with nvidia-smi available, the CUDA plugin is also fetched automatically.

Build the Project

cabal build all

This compiles the library, the demo, the examples, and the test suite.


Usage

CPU (works out of the box)

cabal run example-add --flag=examples
cabal test

Note: All example-* executables are guarded by the examples flag in hhlo.cabal (defaults to False). Append --flag=examples to every cabal run example-* command.

GPU (requires runtime libraries)

The PJRT CUDA plugin depends on NVIDIA runtime libraries: cuDNN, NCCL, and NVSHMEM. These are commonly available via conda, pip, or system packages.

If you already have them (e.g. via PyTorch or JAX installations), simply run:

./setup_gpu_env.sh
source ~/.bashrc

This idempotent script auto-discovers the libraries and appends their directories to LD_LIBRARY_PATH in ~/.bashrc. After that, GPU examples work directly:

cabal run example-gpu-add --flag=examples
cabal run example-gpu-matmul-bench --flag=examples
cabal run example-multi-gpu-inference --flag=examples

EDSL Quick Start

{-# LANGUAGE DataKinds #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeApplications #-}

import HHLO.Core.Types
import HHLO.EDSL.Ops
import HHLO.IR.AST (FuncArg(..), TensorType(..))
import HHLO.IR.Builder
import HHLO.IR.Pretty
import qualified Data.Text.IO as T   -- putStrLn for Text lives in Data.Text.IO

-- Build a program: c = a + b
program :: Module
program = moduleFromBuilder @'[2,2] @'F32 "main"
    [ FuncArg "a" (TensorType [2, 2] F32)
    , FuncArg "b" (TensorType [2, 2] F32)
    ]
    $ do
        a <- arg
        b <- arg
        c <- add a b
        return c

main :: IO ()
main = T.putStrLn (render program)

Output:

module {
  func.func @main(%arg0: tensor<2x2xf32>, %arg1: tensor<2x2xf32>) -> tensor<2x2xf32> {
      %0 = stablehlo.add %arg0, %arg1 : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32>
      return %0 : tensor<2x2xf32>
  }
}

Running the Demo

cabal run hhlo-demo

The demo builds a stablehlo.add program via the EDSL, compiles it with PJRT CPU, creates F32 input buffers, executes, and reads back the result:

=== HHLO End-to-End Demo ===
Loading PJRT CPU plugin...
Plugin loaded.
...
Result: [6.0,8.0,10.0,12.0]
SUCCESS: Results match expected values!

Running Examples

Standalone examples are provided in examples/:

| #  | Command | Description |
|----|---------|-------------|
| 1  | cabal run example-add --flag=examples | Element-wise c = a + b |
| 2  | cabal run example-matmul --flag=examples | 2×3 @ 3×2 matrix multiply |
| 3  | cabal run example-chain-ops --flag=examples | (a + b) * (a - b) |
| 4  | cabal run example-async --flag=examples | Async executeAsync + relu |
| 5  | cabal run example-mlp --flag=examples | 2-layer MLP |
| 6  | cabal run example-mlp-batched --flag=examples | Batched MLP |
| 7  | cabal run example-tuple --flag=examples | Multi-result func.func |
| 8  | cabal run example-reduce --flag=examples | reduceSum over all dimensions |
| 9  | cabal run example-softmax --flag=examples | 1-D and batched 2-D softmax |
| 10 | cabal run example-conv2d --flag=examples | NHWC conv2d |
| 11 | cabal run example-batch-norm --flag=examples | Batch norm inference |
| 12 | cabal run example-while --flag=examples | whileLoop count-up |
| 13 | cabal run example-conditional --flag=examples | conditional if-then-else |
| 14 | cabal run example-gather --flag=examples | gather rows from matrix |
| 15 | cabal run example-scatter --flag=examples | scatter replace into vector |
| 16 | cabal run example-slice --flag=examples | slice sub-array extraction |
| 17 | cabal run example-pad --flag=examples | pad with edge/interior padding |
| 18 | cabal run example-dynamic-slice --flag=examples | dynamicSlice runtime indices |
| 19 | cabal run example-sort --flag=examples | sort 1-D ascending |
| 20 | cabal run example-select --flag=examples | Element-wise ternary select |
| 21 | cabal run example-map --flag=examples | map with custom computation |
| 22 | cabal run example-new-ops-smoke-test --flag=examples | Smoke test for newer ops |
| 23 | cabal run example-resnet --flag=examples | ResNet-18 toy (8×8 input) |
| 24 | cabal run example-alexnet --flag=examples | AlexNet toy (16×16 input) |
| 25 | cabal run example-transformer --flag=examples | Transformer encoder (1×4×16) |
| 26 | cabal run example-unet --flag=examples | UNet segmentation toy (16×16) |
| 30 | cabal run example-rng-uniform --flag=examples | rngUniform random floats [0,1) |
| 31 | cabal run example-rng-normal --flag=examples | rngNormal standard normal distribution |
| 32 | cabal run example-rng-bit-generator --flag=examples | rngBitGenerator Threefry PRNG |
| 33 | cabal run example-multi-value-loop --flag=examples | whileLoop2 with two loop-carried values |
| 27 | cabal run example-gpu-add --flag=examples | GPU smoke test |
| 28 | cabal run example-gpu-matmul-bench --flag=examples | GPU 4096×4096 benchmark |
| 29 | cabal run example-multi-gpu-inference --flag=examples | Multi-GPU concurrent matmul |

Tests

CPU Tests (default)

cabal test

Runs 124 tests across three tiers:

  • Tier 1 — Golden tests — Verify rendered MLIR text for EDSL ops, IR constructs, NN layers, and control flow.
  • Tier 2 — End-to-end runtime tests — Load the PJRT CPU plugin, compile StableHLO programs, execute them, and verify numerical results. Covers arithmetic, matmul, reductions, data movement, and NN ops.
  • Tier 3 — Runtime integration tests — Buffer metadata queries, async execution, and error handling.

GPU Tests

HHLO_TEST_GPU=1 cabal test

Runs the full 124 CPU tests plus 6 additional GPU integration tests:

  • EndToEnd.GPU — GPU availability and device enumeration
  • Runtime.BufferGPU — Buffer round-trip and metadata queries on GPU
  • Runtime.AsyncGPU — Async execution and bufferReady polling on GPU
  • Runtime.MultiGPU — Concurrent executeReplicas across all GPUs

Sample output:

HHLO Tests
  EDSL.Ops
    Binary element-wise
      add:                            OK
      ...
  EndToEnd.Arithmetic
    relu:                             OK (0.02s)
    ...
  Runtime.Buffer
    buffer round-trip f32:            OK
  Runtime.Async
    buffer ready after sync execute:  OK (0.02s)
  EndToEnd.GPU
    gpu available:                    OK
  Runtime.BufferGPU
    gpu buffer round-trip f32:        OK
  Runtime.AsyncGPU
    gpu executeAsync + await:         OK
  Runtime.MultiGPU
    execute replicas on all GPUs:     OK

All 130 tests passed (16.27s)

Project Structure

.
├── app/                    # hhlo-demo executable
├── cbits/                  # C shim around PJRT C API
│   ├── pjrt_c_api.h        # Upstream PJRT header
│   ├── pjrt_shim.c         # Thin wrapper exposing flat C functions
│   └── pjrt_shim.h         # C header for the shim
├── deps/
│   └── pjrt/               # Downloaded PJRT plugins (.so files)
│       └── lib_symlinks/   # Compatibility symlinks for missing library versions
├── doc/                    # Architecture and design documents
├── examples/               # Standalone example programs (01–33)
├── src/HHLO/
│   ├── Core/Types.hs       # DType, Shape, HostType type families
│   ├── IR/
│   │   ├── AST.hs          # MLIR AST (Operation, Function, Module)
│   │   ├── Builder.hs      # Stateful Builder monad + Tensor/Tuple GADTs
│   │   └── Pretty.hs       # MLIR text pretty-printer
│   ├── EDSL/Ops.hs         # Type-safe frontend ops (50+ ops)
│   └── Runtime/
│       ├── PJRT/
│       │   ├── FFI.hs      # C FFI declarations
│       │   ├── Types.hs    # Opaque pointer newtypes + buffer type constants
│       │   ├── Error.hs    # PJRT error handling
│       │   └── Plugin.hs   # Backend-agnostic plugin loading (withPJRT)
│       ├── Device.hs       # Device enumeration & selection
│       ├── Compile.hs      # MLIR → PJRT executable (with `CompileOptions`)
│       ├── Execute.hs      # Synchronous + device-targeted + multi-GPU replica execution
│       ├── Async.hs        # Non-blocking execution with PJRT_Event
│       └── Buffer.hs       # Host↔device buffer transfers + metadata queries
├── test/
│   ├── Test/
│   │   ├── EDSL/Ops.hs
│   │   ├── IR/
│   │   │   ├── Builder.hs
│   │   │   ├── Pretty.hs
│   │   │   ├── PrettyOps.hs
│   │   │   ├── PrettyNN.hs
│   │   │   └── PrettyControlFlow.hs
│   │   ├── Runtime/
│   │   │   ├── EndToEnd*.hs       # CPU E2E test modules
│   │   │   ├── EndToEndGPU.hs     # GPU availability test
│   │   │   ├── Buffer.hs
│   │   │   ├── BufferGPU.hs       # GPU buffer integration tests
│   │   │   ├── Async.hs
│   │   │   ├── AsyncGPU.hs        # GPU async tests
│   │   │   ├── MultiGPU.hs        # Multi-GPU inference scaling tests
│   │   │   └── Errors.hs
│   │   └── Utils.hs
│   └── Main.hs
├── hhlo.cabal
├── pjrt_script.sh          # Downloads PJRT plugins
├── setup_gpu_env.sh        # Auto-configures LD_LIBRARY_PATH for GPU
└── README.md

Architecture Docs

The doc/ directory contains detailed design documents:

| Document | Contents |
|----------|----------|
| implementation-design.md | Four-layer architecture and design decisions |
| progress-and-remaining-work.md | Current status, completed features, and backlog |
| test-suite-documentation.md | Test catalog and tier descriptions |

License

MIT License — see LICENSE.

Metadata

Version

0.2.0.0
