Research Timeline · Aditya Jain / Apple Maps · 3D Reconstruction
Topic 27 · Nov 2025 · Diffusion · ViT · Hardware Adaptation

JiT Diffusion —
Training on Consumer GPUs.

Adapting the LTH14/JiT (Just-in-Time) ViT-based diffusion model from its 8× H200 reference setup down to a 2× RTX 3060 cloud rig, and getting ImageNet-256 training through epoch 0 after a debugging cascade. Epoch 0 FID 281.24, loss 0.95 → 0.50, then a crash from overly aggressive DataLoader workers — a documented hardware-adaptation log.

00 — Motivation

A benchmark, not a reproduction. Can JiT train on consumer GPUs?

The broader thesis line involves a sequence of generative-3D experiments using diffusion and flow-matching architectures. Each new architecture that the field publishes is a candidate component for downstream work, but only if it can be trained on accessible hardware — the 8× H200 setups in the original papers are not the deployment substrate. Before committing to JiT (or any other architecture) as the 3-D-extension backbone, the question is whether it trains at all on consumer-GPU hardware and how the training behaves relative to the prior flow-matching work on the same data.

JiT (Just-in-Time diffusion, LTH14 group) is recent at the time of this work (paper ~one week old when the experiment started). The interesting architectural choices: Vision Transformer (ViT) backbone instead of the standard U-Net, and x-prediction (predict the clean image directly) instead of the more common ε-prediction (predict noise) or v-prediction (predict velocity field) used by flow-matching variants. Reference setup: 8× H200 GPUs, batch size 128 per GPU, ImageNet-256. The question for the thesis: does this work at 2× RTX 3060 (12 GB each)?

The experiment is explicitly a benchmark, not a reproduction — the goal is not to match the paper's reported FID at scale, but to establish how the architecture behaves under tighter compute, where the failure modes are, and how that compares to the prior flow-matching setup on the same hardware. The Phase 1 result: stable training through epoch 0, FID 281.24 (early-training range), a crash at epoch 1 from a DataLoader-worker tuning mistake, and M4-Mac inference blocked on a pos_embed tensor-shape mismatch. The debugging path is the contribution.

What it informs
The benchmark feeds the architectural choice for subsequent thesis topics — particularly MambaFlow3D (Topic 26) which evaluates Mamba vs Transformer vs Flow-matching on the same hardware substrate. The JiT result establishes the ViT/x-prediction baseline that the Mamba comparison is measured against.
01 — JiT Architecture

ViT backbone + x-prediction objective.

JiT departs from the standard diffusion architecture in two ways:

Backbone — Vision Transformer instead of U-Net. The standard latent-diffusion design (Stable Diffusion, Imagen) uses a U-Net to map noisy latents to a prediction. JiT replaces this with a Vision Transformer — the input image is split into 16 × 16 patches, each patch is linearly embedded, the patches are processed by a transformer with self-attention, and the output patches are reassembled into the predicted image. The advantage: patch-level attention captures long-range dependencies that a convolutional U-Net has to build up through depth.
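As a concrete illustration of that patch layout (a sketch, not the authors' code): the 16 × 16 patchify step for a 256² RGB image in NumPy, with illustrative function and variable names. For JiT-B/16 this yields 256 tokens of dimension 16·16·3 = 768.

```python
import numpy as np

def patchify(img, patch=16):
    """Split an (H, W, C) image into a sequence of flattened patches.

    For a 256x256x3 input with 16-pixel patches this yields 256 tokens
    of dimension 16*16*3 = 768, matching the token layout described above.
    """
    H, W, C = img.shape
    gh, gw = H // patch, W // patch
    x = img.reshape(gh, patch, gw, patch, C)      # split rows and columns
    x = x.transpose(0, 2, 1, 3, 4)                # (gh, gw, patch, patch, C)
    return x.reshape(gh * gw, patch * patch * C)  # (tokens, token_dim)

img = np.random.rand(256, 256, 3).astype(np.float32)
tokens = patchify(img)
print(tokens.shape)  # (256, 768)
```

Each token would then pass through a linear embedding before the transformer blocks; the inverse reshape reassembles output tokens into the predicted image.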

Objective — x-prediction instead of ε-prediction. Most diffusion models train the network to predict the noise ε that was added to the clean image — the loss is ‖ε − ε̂‖². JiT trains the network to predict the clean image directly — loss is ‖x₀ − x̂₀‖². The two objectives are mathematically related (you can convert one prediction to the other given the noise schedule) but have different gradient properties in practice. The JiT authors argue x-prediction has more stable training at the high-signal-to-noise end of the diffusion schedule (where ε-prediction is poorly conditioned because the ε signal is small).

Prediction type | Network target | Used by
ε-prediction (epsilon) | Predict the noise ε added to the clean image | DDPM, Stable Diffusion 1.x, classical diffusion
v-prediction (velocity) | Predict the velocity field v = α·ε − σ·x₀ | Stable Diffusion 2.x, Flow Matching, Rectified Flow
x-prediction | Predict the clean image x₀ directly | JiT, some Imagen variants
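The claim that the objectives are interconvertible given the noise schedule can be checked numerically. A minimal sketch, assuming the common parameterisation x_t = α·x₀ + σ·ε (the schedule values and function names here are illustrative, not JiT's):

```python
import numpy as np

def eps_from_x(x_t, x0_hat, alpha, sigma):
    """Recover the noise estimate implied by an x-prediction."""
    return (x_t - alpha * x0_hat) / sigma

def v_from_x(x0_hat, eps_hat, alpha, sigma):
    """Velocity target v = alpha * eps - sigma * x0 (as in the table)."""
    return alpha * eps_hat - sigma * x0_hat

# Round-trip check on synthetic data: a perfect x-prediction implies
# the exact noise that was added.
rng = np.random.default_rng(0)
x0, eps = rng.normal(size=4), rng.normal(size=4)
alpha, sigma = 0.8, 0.6               # illustrative; alpha^2 + sigma^2 = 1
x_t = alpha * x0 + sigma * eps
eps_hat = eps_from_x(x_t, x0, alpha, sigma)
print(np.allclose(eps_hat, eps))  # True
```

The practical difference the JiT authors point to is not the algebra but the conditioning: near the clean end of the schedule σ is small, so recovering ε involves dividing by a small number, which is exactly where ε-prediction is poorly conditioned.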
02 — Hardware Adaptation

From 8 × H200 to 2 × RTX 3060 on Vast.ai.

The reference training setup uses 8 × NVIDIA H200 GPUs with ~141 GB of HBM3e each — a ~1 TB total HBM pool. The work here adapts to a Vast.ai cloud instance with 2 × RTX 3060 (12 GB GDDR6 each, 24 GB total). Adapting requires re-deriving the training configuration to fit the smaller memory budget while keeping the optimisation dynamics roughly equivalent.

Setting | Reference (8× H200) | This work (2× RTX 3060)
Total GPU memory | ~1,130 GB HBM3e | 24 GB GDDR6
Batch size per GPU | 128 | 16 (could go 24–32; kept conservative for stability)
Effective batch size | 1,024 | 32 (32× smaller; informs LR scaling)
Memory used per GPU | ~80 % (~110 GB) | ~55 % (~6.6 GB), room to grow
Model variant | JiT-B/16 and /32 variants | JiT-B/16 (base model, 16-pixel patches)
Image resolution | 256² | 256² (matched)
Dataset | ImageNet-1k (full) | ImageNet-256 from Kaggle (same data)

The batch-size reduction from 1,024 → 32 is the dominant change. The learning rate was scaled linearly with batch size (LR × 32/1024, i.e. LR/32), per the standard linear scaling rule for batch-size adaptation; gradient accumulation was not used because the reduced batch was still in the stable optimisation range for the available steps.
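The linear scaling rule above, as arithmetic. The base learning rate here is an assumed placeholder, not a value from the paper:

```python
# Linear LR scaling for the 1024 -> 32 effective-batch reduction.
ref_batch = 8 * 128    # 8 GPUs x 128 per GPU = 1024
new_batch = 2 * 16     # 2 GPUs x 16 per GPU  = 32
base_lr = 1e-4         # assumed reference LR, for illustration only

scaled_lr = base_lr * new_batch / ref_batch   # = base_lr / 32
print(scaled_lr)  # 3.125e-06
```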

Core Insight

8 × H200 → 2 × 3060.
Trains. Slowly.

The interesting result is that JiT's ViT-plus-x-prediction architecture is not memory-bound on consumer GPUs — 16-images-per-3060 fits comfortably in 12 GB with headroom. The bottleneck is throughput, not capacity: the same wallclock that produced epoch 0 on this rig in ~6 hours produced 50+ epochs on the reference H200 setup. The architecture works; the compute budget is the only thing that's different.

03 — Debugging Cascade

Five named issues, in the order they surfaced.

# | Symptom | Root cause | Fix
1 | iJIT_NotifyEvent CUDA error on startup | PyTorch / CUDA version mismatch on the Vast.ai instance | Reinstall torch==2.5.1 with CUDA 12.4 wheels
2 | Training script can't find the ImageNet split structure | Kaggle ImageNet-256 archive is flat class directories, not train/ + val/ | Restructure with symlinks: ln -s class_dir train/class_dir
3 | Loss looks healthy but generation appears to hang | Sampling runs separately from training (157 steps, ~51 min on this rig) | Wait; diagnostic only. Generation completes normally.
4 | Training crashes at epoch 1 | DataLoader workers killed by OOM: num_workers=12 is too aggressive for the 24 GB of system memory on this Vast.ai instance | Reduce to num_workers=4; reduce per-GPU batch to 12 if memory pressure persists
5 | M4 Mac inference fails with a pos_embed shape mismatch | Checkpoint contains weights for a different architecture variant than the inference script expects; likely code-version skew between training and inference paths | Use the original main_jit.py --evaluate_gen path with a dummy ImageNet dataset structure instead of custom inference scripts
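The symlink restructuring from issue #2 can be sketched as a small shell script. The SRC/DST paths are illustrative, and a tiny demo class directory is created here so the sketch runs standalone; point SRC at the real extracted Kaggle archive instead.

```shell
#!/bin/sh
# Restructure a flat Kaggle ImageNet-256 layout (one directory per class)
# into the train/<class>/ hierarchy the training script expects, using
# symlinks instead of copies to avoid duplicating the dataset on disk.
SRC=imagenet256_flat
DST=imagenet256

mkdir -p "$SRC/n01440764"   # demo class dir; remove when using real data
mkdir -p "$DST/train"

for class_dir in "$SRC"/*/; do
    name=$(basename "$class_dir")
    # -s symlink, -f overwrite, -n don't descend into an existing link
    ln -sfn "$(cd "$class_dir" && pwd)" "$DST/train/$name"
done

ls -l "$DST/train"
```

The same pattern would apply for a val/ split if the script expects one.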
04 — Phase 1 Results

Epoch 0 FID 281.24. Far from publication quality; expected for the budget.

After the debugging cascade, training proceeded stably through epoch 0. Loss curve fell from ~0.95 at step 0 to ~0.50 at end-of-epoch (clean monotonic decrease, no instabilities). Sampling at end of epoch 0 produced an FID of 281.24 on the held-out validation split with the standard cfg=2.9 classifier-free guidance scale.

Context for the number: trained-to-convergence JiT-B/16 in the reference paper achieves FID ~3.5 on the same dataset. The 281 result is approximately what one would expect after one pass through the data with a 32× smaller effective batch size — early-training range. The point of the experiment was not to chase FID, but to confirm the architecture trains stably under the constrained compute budget.

Metric | Value
Final loss (end of epoch 0) | ~0.50 (from ~0.95 at step 0)
FID (epoch 0) | 281.24
Generation parameters | --model JiT-B/16 --img_size 256 --gen_bsz 16 --num_images 64 --cfg 2.9
Checkpoint size | 2.4 GB
Wallclock per epoch | ~6 hours on 2× RTX 3060
GPU memory utilisation | ~6.6 GB / 12 GB per GPU (55 %)
Training termination | Crashed at epoch 1 (DataLoader worker OOM)

Interactive Demo · Live

Step through the diffusion denoising loop on a stylised input. Pick a target class, then advance through 10 sampling steps to watch noise resolve into a structured image. The middle pane shows the 16×16 ViT patch grid that JiT operates on; the right pane shows the current x-prediction output at the active step.


Full Technical Paper

arXiv-format write-up · JiT on consumer GPUs · architecture study, hardware adaptation, debugging cascade, baseline results

Related Thesis Chapters
Hexplane Autoencoder
Companion architecture experiment from the same period — same hardware substrate (consumer GPU), different objective (reconstruction vs generation). Both informed the architectural choices for downstream work.
Hierarchical Part-Based Triplane
Downstream consumer of the architectural-choice analysis. The variable-cardinality diffusion problem flagged in the triplane paper is informed by the prediction-type and scaling lessons from this benchmark.
SculptNet — Coarse-to-Fine Reconstruction
Sister architecture experiment — different decomposition strategy (primitive assembly vs continuous denoising) but same consumer-hardware deployment constraint.
Appendix — Raw Materials
Transcripts & Source References
Restricted Access