Adapting the LTH14/JiT (Just-in-Time) ViT-based diffusion model from its 8× H200 reference setup down to a 2× RTX 3060 cloud rig and getting ImageNet-256 training running stably after a debugging cascade. Epoch 0 FID 281.24, loss 0.95 → 0.50, then a crash from too-aggressive DataLoader workers — a documented hardware-adaptation log.
The broader thesis line involves a sequence of generative-3D experiments using diffusion and flow-matching architectures. Each new architecture that the field publishes is a candidate component for downstream work, but only if it can be trained on accessible hardware — the 8× H200 setups in the original papers are not the deployment substrate. Before committing to JiT (or any other architecture) as the 3-D-extension backbone, the question is whether it trains at all on consumer-GPU hardware and how the training behaves relative to the prior flow-matching work on the same data.
JiT (Just-in-Time diffusion, LTH14 group) is recent at the time of this work (paper ~one week old when the experiment started). The interesting architectural choices: Vision Transformer (ViT) backbone instead of the standard U-Net, and x-prediction (predict the clean image directly) instead of the more common ε-prediction (predict noise) or v-prediction (predict velocity field) used by flow-matching variants. Reference setup: 8× H200 GPUs, batch size 128 per GPU, ImageNet-256. The question for the thesis: does this work at 2× RTX 3060 (12 GB each)?
The experiment is explicitly a benchmark, not a reproduction —
the goal is not to match the paper's reported FID at scale, it's to
establish how the architecture behaves under tighter compute, where
the failure modes are, and how that compares to the prior flow-matching
setup on the same hardware. The Phase 1 result here: training runs stably
through epoch 0, FID 281.24 (early-convergence range); the run crashed at
epoch 1 from a DataLoader-worker tuning mistake, and M4-Mac inference is
blocked on a pos_embed tensor-shape mismatch. The
debugging path is the contribution.
JiT departs from the standard diffusion architecture in two ways:
Backbone — Vision Transformer instead of U-Net. The
standard latent-diffusion design (Stable Diffusion, Imagen) uses a
U-Net to map noisy latents to a prediction. JiT replaces this with a
Vision Transformer — the input image is split into 16 × 16
patches, each patch is linearly embedded, the patches are processed by
a transformer with self-attention, and the output patches are
reassembled into the predicted image. The advantage: patch-level
attention captures long-range dependencies that a convolutional U-Net
has to build up through depth.
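As a rough illustration (not the JiT reference code; module names and hyperparameters are hypothetical, and timestep/class conditioning is omitted), the patchify → transformer → unpatchify path looks like:

```python
import torch
import torch.nn as nn

class TinyViTDenoiser(nn.Module):
    """Illustrative ViT denoiser: patchify -> transformer -> unpatchify.
    A simplified sketch, not the JiT reference implementation."""

    def __init__(self, img_size=256, patch=16, dim=768, depth=12, heads=12):
        super().__init__()
        self.patch = patch
        self.grid = img_size // patch                      # 16 patches per side
        n_patches = self.grid ** 2                         # 256 patches total
        self.embed = nn.Linear(3 * patch * patch, dim)     # linear patch embedding
        self.pos_embed = nn.Parameter(torch.zeros(1, n_patches, dim))
        layer = nn.TransformerEncoderLayer(dim, heads, 4 * dim, batch_first=True)
        self.blocks = nn.TransformerEncoder(layer, depth)
        self.to_pixels = nn.Linear(dim, 3 * patch * patch) # project back to pixels

    def forward(self, x_noisy):                            # (B, 3, 256, 256)
        B = x_noisy.shape[0]
        p, g = self.patch, self.grid
        # split into non-overlapping 16x16 patches and flatten each one
        patches = x_noisy.unfold(2, p, p).unfold(3, p, p)          # (B, 3, g, g, p, p)
        patches = patches.permute(0, 2, 3, 1, 4, 5).reshape(B, g * g, -1)
        tokens = self.embed(patches) + self.pos_embed              # patch tokens
        tokens = self.blocks(tokens)                               # global self-attention
        out = self.to_pixels(tokens)                               # (B, g*g, 3*p*p)
        # reassemble per-patch predictions into the full predicted image
        out = out.reshape(B, g, g, 3, p, p).permute(0, 3, 1, 4, 2, 5)
        return out.reshape(B, 3, g * p, g * p)                     # x-prediction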
Objective — x-prediction instead of ε-prediction.
Most diffusion models train the network to predict the noise ε
that was added to the clean image — the loss is ‖ε − ε̂‖².
JiT trains the network to predict the clean image directly —
loss is ‖x₀ − x̂₀‖². The two objectives are mathematically
related (you can convert one prediction to the other given the noise
schedule) but have different gradient properties in practice. The
JiT authors argue x-prediction has more stable training at the
high-signal-to-noise end of the diffusion schedule (where ε-prediction
is poorly conditioned because the ε signal is small).
| Prediction type | Network target | Used by |
|---|---|---|
| ε-prediction (epsilon) | Predict the noise added to the clean image | DDPM, Stable Diffusion 1.x, classical diffusion |
| v-prediction (velocity) | Predict the velocity field v = α·ε − σ·x₀ | Stable Diffusion 2.x, Flow Matching, Rectified Flow |
| x-prediction | Predict the clean image x₀ directly | JiT, some Imagen variants |
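A minimal sketch of the two objectives and their conversion, assuming a generic interpolation x_t = α·x₀ + σ·ε (the actual JiT schedule and conditioning are not reproduced here; timestep input to the model is omitted for brevity):

```python
import torch

def xpred_loss(model, x0, alpha, sigma):
    """x-prediction objective: ||x0 - x0_hat||^2.
    Assumes x_t = alpha * x0 + sigma * eps (illustrative schedule)."""
    eps = torch.randn_like(x0)
    x_t = alpha * x0 + sigma * eps
    x0_hat = model(x_t)                       # network predicts the clean image
    return ((x0 - x0_hat) ** 2).mean()

def eps_from_xpred(x_t, x0_hat, alpha, sigma):
    """Convert an x-prediction to the equivalent eps-prediction:
    x_t = alpha * x0 + sigma * eps  =>  eps = (x_t - alpha * x0) / sigma.
    Poorly conditioned when sigma is small (the high-SNR end of the schedule)."""
    return (x_t - alpha * x0_hat) / sigma
```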
The reference training setup uses 8 × NVIDIA H200 GPUs with ~141 GB of HBM3e each — a ~1 TB total HBM pool. The work here adapts to a Vast.ai cloud instance with 2 × RTX 3060 (12 GB GDDR6 each, 24 GB total). Adapting requires re-deriving the training configuration to fit the smaller memory budget while keeping the optimisation dynamics roughly equivalent.
| Setting | Reference (8×H200) | This work (2×RTX 3060) |
|---|---|---|
| Total GPU memory | ~1 130 GB HBM3e | 24 GB GDDR6 |
| Batch size per GPU | 128 | 16 (could go 24–32; conservative for stability) |
| Effective batch size | 1 024 | 32 (32× smaller — informs LR scaling) |
| Memory used per GPU | ~80 % (~110 GB) | ~55 % (~6.6 GB) — room to grow |
| Model variant | JiT-B/16, /32 variants | JiT-B/16 — base model, 16-pixel patches |
| Image resolution | 256² | 256² (matched) |
| Dataset | ImageNet-1k (full) | ImageNet-256 from Kaggle (same data) |
The batch-size reduction from 1024 → 32 is the dominant change. The learning rate was scaled linearly with batch size (LR / 32), following the linear scaling rule (the square-root rule is the common alternative); gradient accumulation was not used because the reduced batch was still in the stable optimisation range for the available steps.
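A sketch of the scaling arithmetic (the base learning rate shown is a placeholder, not the paper's value):

```python
# Effective batch sizes are from the text; the base LR is a placeholder.
ref_batch = 8 * 128          # 1024 (reference: 8x H200, 128 per GPU)
new_batch = 2 * 16           # 32   (this work: 2x RTX 3060, 16 per GPU)
scale = new_batch / ref_batch        # 1/32

base_lr = 1e-4                       # placeholder, not the reference value
linear_lr = base_lr * scale          # linear scaling rule (used here): LR / 32
sqrt_lr = base_lr * scale ** 0.5     # square-root rule, the common alternative

print(f"linear-scaled LR: {linear_lr:.2e}, sqrt-scaled LR: {sqrt_lr:.2e}")
```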
8 × H200 → 2 × 3060.
Trains. Slowly.
The interesting result is that JiT's ViT-plus-x-prediction architecture is not memory-bound on consumer GPUs — 16 images per 3060 fit comfortably in 12 GB with headroom. The bottleneck is throughput, not capacity: the ~6 hours of wallclock that produced epoch 0 on this rig would cover 50+ epochs on the reference H200 setup. The architecture works; the compute budget is the only thing that's different.
| # | Symptom | Root cause | Fix |
|---|---|---|---|
| 1 | `iJIT_NotifyEvent` CUDA error on startup | PyTorch / CUDA version mismatch on the Vast.ai instance | Reinstall `torch==2.5.1` with CUDA 12.4 wheels |
| 2 | Training script can't find the ImageNet split structure | Kaggle ImageNet-256 archive is flat class directories, not `train/` + `val/` | Restructure with symlinks: `ln -s class_dir train/class_dir` |
| 3 | Loss looks healthy but generation hangs | Sampling phase runs separately from training (157 steps, ~51 min on this rig) | Wait — diagnostic only; generation completes normally |
| 4 | Training crashes at epoch 1 | DataLoader workers killed by OOM; `num_workers=12` is too aggressive for the 24 GB system memory on this Vast.ai instance | Reduce to `num_workers=4`; reduce per-GPU batch to 12 if memory pressure persists |
| 5 | M4 Mac inference fails with `pos_embed` shape mismatch | Checkpoint contains model weights for a different architecture variant than the inference script expects — possibly a code-version skew between training and inference paths | Use the original `main_jit.py --evaluate_gen` path with a dummy ImageNet dataset structure rather than custom inference scripts |
After the debugging cascade, training proceeded stably through epoch
0. Loss curve fell from ~0.95 at step 0 to ~0.50 at end-of-epoch
(clean monotonic decrease, no instabilities). Sampling at end of
epoch 0 produced an FID of 281.24 on the held-out
validation split with the standard cfg=2.9 classifier-free
guidance scale.
Context for the number: trained-to-convergence JiT-B/16 in the reference paper achieves FID ~3.5 on the same dataset. The 281 result is approximately what one would expect after one pass through the data with 32× smaller effective batch size — early-convergence range. The point of the experiment was not to chase FID, but to confirm the architecture trains stably under the constrained compute budget.
| Metric | Value |
|---|---|
| Final loss (end of epoch 0) | ~0.50 (from ~0.95 at step 0) |
| FID (epoch 0) | 281.24 |
| Generation parameters | --model JiT-B/16 --img_size 256 --gen_bsz 16 --num_images 64 --cfg 2.9 |
| Checkpoint size | 2.4 GB |
| Wallclock per epoch | ~6 hours on 2× RTX 3060 |
| GPU memory utilisation | ~6.6 GB / 12 GB per GPU (55 %) |
| Training termination | Crashed at epoch 1 (DataLoader worker OOM) |
Interactive demo: step through the diffusion denoising loop on a stylised input. Pick a target class, then advance through 10 sampling steps to watch noise resolve into a structured image. The middle pane shows the 16×16 ViT patch grid that JiT operates on; the right pane shows the current x-prediction output at the active step.
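For orientation, a minimal sketch of the loop the demo steps through, with a toy linear schedule and with class conditioning / CFG omitted (not the JiT sampler):

```python
import torch

@torch.no_grad()
def sample_xpred(model, shape=(1, 3, 256, 256), steps=10, device="cuda"):
    """Minimal x-prediction sampling loop (illustrative, not the JiT sampler).
    At each step the network predicts the clean image x0_hat, and the sample
    is re-noised to the next (lower) noise level (a DDIM-style update)."""
    x = torch.randn(shape, device=device)              # start from pure noise
    # toy linear schedule from t=1 (all noise) down to t=0 (clean); placeholder
    ts = torch.linspace(1.0, 0.0, steps + 1, device=device)
    for i in range(steps):
        t, t_next = ts[i], ts[i + 1]
        alpha, sigma = 1 - t, t                        # toy interpolation schedule
        x0_hat = model(x)                              # current x-prediction
        eps_hat = (x - alpha * x0_hat) / sigma.clamp(min=1e-4)
        alpha_n, sigma_n = 1 - t_next, t_next
        x = alpha_n * x0_hat + sigma_n * eps_hat       # move to next noise level
    return x0_hat
```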
arXiv-format write-up · JiT on consumer GPUs · architecture study, hardware adaptation, debugging cascade, baseline results