Technical Report · cs.LG · cs.CV · Nov 2025
Training JiT Diffusion on Two Consumer GPUs: Hardware Adaptation, Debugging Cascade, and Phase-1 Reproduction of ViT-Backbone x-Prediction Diffusion at ImageNet-256
Aaditya Jain
Diffusion Models · Independent Reproduction Study · Vast.ai Cloud GPU
Submitted: November 2025
Subject: cs.LG · cs.CV
Keywords: diffusion models, ViT, x-prediction, JiT, hardware adaptation, reproduction study, consumer GPU training
Abstract
We report a Phase-1 reproduction of the JiT (Just-image-Tokens) diffusion architecture [1] on consumer hardware, scaled down from the reference 8 × H200 setup to a 2 × RTX 3060 (12 GB each) Vast.ai instance. The reference recipe (ViT-B/16 backbone, x-prediction parameterisation, ImageNet-256) was held fixed; only the hardware-bound hyperparameters (per-GPU batch size, gradient-accumulation factor, num_workers, mixed-precision settings) were varied. The paper makes two contributions. First, a hardware-adaptation table that maps the reference effective batch of 1024 at ~80 % memory utilisation to an effective batch of 32 at ~55 % utilisation, and reports the resulting epoch-time inflation from approximately one hour per epoch on H200 to approximately six hours per epoch on RTX 3060. The mapping works in practice: Phase-1 training completed epoch 0 with loss decreasing monotonically from 0.95 to 0.50, emitted a 2.4 GB checkpoint at the epoch boundary, and reached an FID of 281.24. Second, a debugging cascade of five named, ordered failures hit while bringing the reference repository up on the target hardware, which documents the specific incompatibility surface between current PyTorch / CUDA / MKL stacks and a research-style training script written for the data-centre stack. The five failures (Intel MKL ITT symbol error, ImageNet directory layout, apparent generation-script hang, DataLoader OOM, position-embedding shape mismatch in the generation entry point) are each reported with the symptom, the diagnostic step that isolated the cause, and the fix, which was at most one line in every case. Training crashed at the start of epoch 1; the crash was diagnosed as an out-of-memory event in the validation/sampling pass, not in the training loop, and is the open problem for Phase-2 work. The contribution is the documented hardware-adaptation recipe and the diagnostic cascade, not a new training method. Keywords: JiT, ViT-B/16, x-prediction diffusion, ImageNet-256, hardware adaptation, RTX 3060, Vast.ai, reproduction study.
1. Introduction

JiT diffusion [1] is a recent ViT-backbone diffusion-model recipe that replaces convolutional U-Net stacks [2] with a Vision Transformer [3] patch tokeniser and a residual transformer body, parameterising the diffusion target as the clean image x rather than the noise ε or the velocity v [4,5]. The reference release trains on 8 × H200 SXM (141 GB each) with an effective batch of 1024 and reports approximately one hour per ImageNet-256 epoch.

This paper reports the result of taking that reference recipe and pushing it down to consumer hardware (2 × RTX 3060 12 GB cards on a Vast.ai rental) without modifying the architecture, loss, optimiser, or schedule. The question the work answers is operational: does the recipe survive a 32× reduction in effective batch and a roughly 12× reduction in per-GPU memory, what does training look like in that regime, and what fails on the way down?

The contributions are: (1) a quantitative hardware-adaptation table mapping reference hyperparameters to consumer-GPU equivalents (§3); (2) a five-stage diagnostic cascade documenting the specific software-stack failures encountered when running the reference repository on a current PyTorch + CUDA 12.4 install on rented consumer GPUs (§4); (3) a Phase-1 training log showing the first epoch converging from loss 0.95 to 0.50 and reaching FID 281.24 with a 2.4 GB checkpoint, and the crash at epoch 1 that bounds the current result (§5).

This is a reproduction-style technical report. It does not propose a new architecture or training method. The reason it exists as a paper rather than as a private engineering log is that the hardware-adaptation table and the debugging cascade together compose a reusable recipe for any researcher attempting to bring a data-centre-scale ViT diffusion run down to consumer hardware, and the failure surface documented in §4 is broadly applicable to other reference releases written against the H100/H200 stack.

2. The JiT Architecture (Brief)
2.1 ViT backbone and patch tokenisation

The reference JiT-B model uses a ViT-B/16 backbone: 16 × 16 patch tokeniser over a 256 × 256 input, producing 256 patch tokens per image, plus a learned class token. The transformer body is 12 layers, 12 attention heads, 768-dimensional embeddings — the standard ViT-B configuration. Class conditioning enters via a small label-embedding MLP added to every token in a final projection step. Timestep conditioning enters through adaptive layer-norm modulation applied at every transformer block.
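The token arithmetic above fixes the (257, 768) position-embedding shape that reappears in failure 5 of Table 3. A minimal sketch, assuming non-overlapping square patches and one learned position embedding per token (patch tokens plus the class token); `vit_token_layout` is a hypothetical helper, not a function from the JiT repository.

```python
def vit_token_layout(image_size: int, patch_size: int, embed_dim: int,
                     class_token: bool = True):
    """Return (num_patch_tokens, pos_embed_shape) for a square ViT input.

    Assumes non-overlapping square patches and one position embedding per
    token (patch tokens plus an optional class token).
    """
    if image_size % patch_size != 0:
        raise ValueError("image_size must be divisible by patch_size")
    patches_per_side = image_size // patch_size
    num_patches = patches_per_side * patches_per_side
    num_tokens = num_patches + (1 if class_token else 0)
    return num_patches, (num_tokens, embed_dim)


# ViT-B/16 at 256 x 256: 16 x 16 = 256 patch tokens, plus 1 class token.
print(vit_token_layout(256, 16, 768))  # (256, (257, 768))
```

Changing the input resolution, the patch size, or whether a class token is counted changes the first dimension, which is one way a checkpoint and a model built with different defaults can end up disagreeing on the position-embedding shape.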

2.2 x-prediction parameterisation

The diffusion target is the clean image x rather than the noise ε or the velocity v. Per [1], the choice trades a small amount of late-timestep loss for stability at early timesteps, where ε-prediction has a degenerate signal. The loss is an MSE between the predicted clean image and the ground-truth image, computed at each randomly sampled timestep.
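Written out, and assuming a variance-preserving (DDPM-style) forward process consistent with the DDPM sampling schedule in §2.3, the x-prediction objective takes the form below. The coefficients α_t, σ_t, the class label c, and the network symbol x̂_θ are our notation, not taken from [1]:

```latex
\begin{aligned}
x_t &= \alpha_t\, x + \sigma_t\, \varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, I), \qquad \alpha_t^2 + \sigma_t^2 = 1, \\
\mathcal{L}(\theta) &= \mathbb{E}_{x,\, c,\, t,\, \varepsilon}\!\left[\, \lVert \hat{x}_\theta(x_t, t, c) - x \rVert_2^2 \,\right].
\end{aligned}
```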

Table 1 — Diffusion-target parameterisations referenced in the JiT paper.
Target | Form | Failure mode | Used by
ε-prediction | Network predicts the additive Gaussian noise | Ill-conditioned at the high-noise end (t → T, SNR → 0): recovering x divides by a vanishing signal coefficient, so small ε errors imply large image errors | DDPM, classical diffusion
v-prediction | Network predicts a rotated parameterisation of (x, ε) | Slightly better than ε at small t but more complex training schedule | Imagen, several SD variants
x-prediction | Network predicts the clean image directly | Slight loss degradation at large t | JiT
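The three targets are interchangeable given x_t, so a network trained on one can be read out as any other. A minimal sketch on scalars, assuming the variance-preserving form x_t = α_t·x + σ_t·ε with α_t² + σ_t² = 1 and the velocity definition v = α_t·ε − σ_t·x from [4]; the function names are illustrative, not from the JiT code.

```python
import math


def make_xt(x: float, eps: float, alpha: float, sigma: float) -> float:
    """Forward process: noisy sample at a timestep with coefficients (alpha, sigma)."""
    return alpha * x + sigma * eps


def velocity(x: float, eps: float, alpha: float, sigma: float) -> float:
    """Velocity target v = alpha * eps - sigma * x."""
    return alpha * eps - sigma * x


def x_from_eps(xt: float, eps_hat: float, alpha: float, sigma: float) -> float:
    """Recover the clean sample from an epsilon prediction (divides by alpha)."""
    return (xt - sigma * eps_hat) / alpha


def x_from_v(xt: float, v_hat: float, alpha: float, sigma: float) -> float:
    """Recover the clean sample from a velocity prediction (no division needed)."""
    return alpha * xt - sigma * v_hat


def eps_error_gain(alpha: float, sigma: float) -> float:
    """Factor by which an error in eps_hat is amplified in the recovered x."""
    return sigma / alpha


# Toward the high-noise end alpha shrinks, so epsilon errors are amplified heavily.
for alpha in (0.99, 0.5, 0.05):
    sigma = math.sqrt(1.0 - alpha * alpha)
    print(f"alpha={alpha:.2f}  eps-error gain={eps_error_gain(alpha, sigma):.2f}")
```

The gain σ_t/α_t is what turns small ε errors into large image errors, and it grows without bound as α_t → 0 at the high-noise end of the schedule; an x-prediction network sidesteps the division entirely.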
2.3 Reference training setup

The reference release is trained on ImageNet-256 with 8 × NVIDIA H200 SXM (141 GB each), a per-GPU batch of 128 for a total effective batch of 1024, AdamW at 1×10⁻⁴ with cosine warm-up over 5 000 steps, bf16 mixed precision, and a 1 000-step DDPM schedule for sampling. Reported epoch time is approximately one hour. The JiT-B model has approximately 86 M parameters.
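The warm-up-plus-cosine schedule can be sketched as a pure function of the optimiser step. This assumes linear warm-up over the first 5 000 steps to the 1×10⁻⁴ peak, then cosine decay to a floor over a fixed step budget; the warm-up shape, the total budget, and the floor are assumptions, since the setup above does not state them.

```python
import math


def lr_at(step: int, peak_lr: float = 1e-4, warmup_steps: int = 5_000,
          total_steps: int = 100_000, min_lr: float = 0.0) -> float:
    """Linear warm-up to peak_lr, then cosine decay to min_lr at total_steps.

    total_steps and min_lr are illustrative assumptions, not values from [1].
    """
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    span = max(1, total_steps - warmup_steps)
    progress = min(1.0, (step - warmup_steps) / span)
    return min_lr + 0.5 * (peak_lr - min_lr) * (1.0 + math.cos(math.pi * progress))
```

If the warm-up is counted in optimiser steps, the smaller effective batch of §3 changes its meaning: with 32× as many updates per epoch, 5 000 steps cover a much smaller fraction of an epoch on the target hardware than on the reference.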

3. Hardware Adaptation

The target hardware is a Vast.ai consumer-GPU instance with 2 × NVIDIA RTX 3060 12 GB. Each card has approximately 7 % of an H200's memory bandwidth and 8 % of its memory capacity [9]. The per-GPU batch size therefore has to drop, and the gradient-accumulation factor has to rise to recover part of the lost effective batch.

Table 2 — Hardware-adaptation mapping from 8 × H200 reference to 2 × RTX 3060 target.
Setting | Reference (8 × H200) | Target (2 × RTX 3060) | Ratio / note
GPUs | 8 | 2 | 0.25 ×
Per-GPU memory | 141 GB | 12 GB | 0.085 ×
Per-GPU batch | 128 | 4 | 0.031 ×
Effective batch (no accumulation) | 1024 | 8 | 0.0078 ×
Gradient-accumulation steps | 1 | 4 | 4 ×
Effective batch (with accumulation) | 1024 | 32 | 0.031 ×
Mixed precision | bf16 | fp16 (AMP) | Choice for this run; see the second point below
num_workers per loader | 16 | 4 | OOM otherwise (see §4, Table 3, failure 4)
Memory utilisation | ~80 % | ~55 % (6.6 / 12 GB) |
Time per epoch | ~1 h | ~6 h | ~6 ×

Two points of nuance. First, the effective batch could be matched to the reference 1024 by setting gradient accumulation to 128. An epoch would still take about six hours, but each parameter update would then cost 128 micro-steps, so a fixed rental budget would buy 32× fewer updates than at the 32-sample setting. An effective batch of 32 was the compromise: small enough to keep update frequency reasonable, and, in the author's judgement, large enough to keep gradient noise within the regime where AdamW behaves comparably to the reference. The loss curve in §5 indicates this compromise is sound for the first epoch; whether it holds across hundreds of epochs is an open question.

Second, the run switched from bf16 to fp16 AMP. The switch is not strictly forced by the hardware: the RTX 3060 is an Ampere part (compute capability 8.6) whose tensor cores do support bf16, so bf16 autocast would also have been an option. fp16 is more prone to overflow, for example in the attention softmax denominator; AdamW with gradient clipping at 1.0 was retained from the reference and proved sufficient to keep fp16 stable through the first epoch.

4. The Debugging Cascade

Five named failures were hit between cloning the reference repository and the first successful training step. Each is reported as symptom → diagnostic → fix. The cascade is ordered, in that earlier failures masked later ones, but the sequence is otherwise unremarkable; Table 3 is the recipe a future reproducer should consult.

Table 3 — Five named failures on the path from clone to first training step.
# | Symptom | Diagnostic | Fix
1 | iJIT_NotifyEvent symbol error on PyTorch import; crashes before any model code runs. | The Intel MKL-ITT integration in the cached torch wheel disagrees with the system MKL version installed by the base image. | Reinstall torch == 2.5.1 from the CUDA 12.4 wheel index [8], which forces a clean MKL bind.
2 | ImageNet loader returns zero examples although the dataset is present on disk. | The reference loader expects per-class subdirectories (train/n01440764/*.JPEG); the mounted dataset uses a flat-file layout. | Generate a symlink tree: one shell loop creates train/<synset>/ directories and symlinks each .JPEG into its class subdirectory based on the filename prefix. No data is copied; symlinks are effectively zero-cost.
3 | Generation script main_jit.py --evaluate_gen sits at "Loading model …" with no GPU usage; the process is alive but not visibly progressing. | Initially suspected: DDP rendezvous failing silently when launched without torchrun. In fact the run was waiting out the rendezvous time-out and then paying first-batch latency on a cold cache. | Wait. The first run was killed at 8 minutes on the assumption of a deadlock; the second was left to continue and completed at 14 minutes. No code fix is needed; the lesson is "do not assume a hang."
4 | DataLoader OOM at startup, before any training step. The exhausted resource is host memory (pinned host memory), not GPU memory. | num_workers = 16 (reference default) × per-worker prefetch buffer × pinned-memory request exceeds the 32 GB of host RAM on the Vast.ai instance. | Set num_workers = 4: some throughput is lost, but training starts.
5 | Loading a checkpoint into the generation entry point fails with a position-embedding shape mismatch: the checkpoint has (257, 768), the model expects (N, 768) for a different N. | The reference repository's training and generation entry points construct the ViT with different default token counts; the CLI-flag wiring for --evaluate_gen in main.py does not propagate the right config. | Use main_jit.py --evaluate_gen (the JiT-specific generation entry point) rather than the generic main.py --evaluate_gen; the former has the correct token-count default. (Discovered by reading the argparse defaults in main_jit.py.)
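The symlink fix for failure 2 can equally be written in Python. A sketch assuming the flat layout names files `<synset>_<index>.JPEG` (for example `n01440764_10026.JPEG`), so the class is the filename prefix before the first underscore; `build_class_symlinks` and its directory arguments are hypothetical, and symlink creation assumes a Linux host.

```python
import os
from pathlib import Path


def build_class_symlinks(flat_dir: str, out_dir: str) -> int:
    """Create out_dir/<synset>/<file> symlinks pointing at files in flat_dir.

    No image data is copied. Returns the number of links created; existing
    links are left untouched, so the function can be re-run safely.
    """
    flat, out = Path(flat_dir).resolve(), Path(out_dir)
    created = 0
    for src in sorted(flat.glob("*.JPEG")):
        synset = src.name.split("_", 1)[0]  # "n01440764_10026.JPEG" -> "n01440764"
        class_dir = out / synset
        class_dir.mkdir(parents=True, exist_ok=True)
        link = class_dir / src.name
        if not link.exists() and not link.is_symlink():
            os.symlink(src, link)  # absolute target, zero-copy
            created += 1
    return created
```

Pointing the reference loader at `out_dir` then gives it the `train/<synset>/*.JPEG` structure it expects, while the flat originals stay where they are.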

The general lesson from the cascade is that each failure needed at most a one-line fix once isolated; the cost lay in the isolation step, not the fix. The fifth failure cost the most diagnostic time: the symptom (a shape mismatch) suggests a checkpoint-versus-config bug, but the actual cause was using the wrong entry-point script in a repository that ships two near-identical ones.
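Failures like the fifth are quicker to isolate when the checkpoint and the freshly built model are compared parameter by parameter before calling `load_state_dict`. A minimal sketch that works on plain name-to-shape mappings (in PyTorch these can be built with `{k: tuple(v.shape) for k, v in state_dict.items()}`); the helper and the parameter names in the example are illustrative assumptions, not names confirmed from the JiT repository.

```python
from typing import Dict, List, Tuple

Shape = Tuple[int, ...]


def diff_shapes(ckpt: Dict[str, Shape], model: Dict[str, Shape]) -> List[str]:
    """List human-readable differences between two name -> shape mappings."""
    problems = []
    for name in sorted(ckpt.keys() | model.keys()):
        if name not in model:
            problems.append(f"unexpected in checkpoint: {name} {ckpt[name]}")
        elif name not in ckpt:
            problems.append(f"missing from checkpoint: {name} {model[name]}")
        elif ckpt[name] != model[name]:
            problems.append(f"shape mismatch: {name} ckpt={ckpt[name]} model={model[name]}")
    return problems


# Hypothetical example: 1025 stands in for the unrecorded N of failure 5.
ckpt_shapes = {"pos_embed": (257, 768), "head.weight": (768, 768)}
model_shapes = {"pos_embed": (1025, 768), "head.weight": (768, 768)}
for line in diff_shapes(ckpt_shapes, model_shapes):
    print(line)
```

Listing every mismatch at once, rather than stopping at the first exception, makes it obvious when only token-count-dependent tensors disagree, which points at the model constructor rather than at the checkpoint.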

5. Phase-1 Results

After the cascade resolved, training started and converged through epoch 0. Key numbers in Table 4.

Table 4 — Phase-1 training log, epoch 0.
Metric | Value | Notes
Initial training loss | 0.95 | MSE on x-prediction at step 0; in line with reference initialisation.
End-of-epoch-0 loss | 0.50 | Monotonic decrease, no instability spikes.
FID (50 K samples, epoch-0 checkpoint) | 281.24 | Far from converged, but the same order of magnitude as the reference epoch-0 FID.
Checkpoint size | 2.4 GB | Model + optimiser state + EMA; the 86 M-parameter model alone takes ~344 MB in fp32.
Wall-clock per epoch | ~6 h | ~6 × the reference H200 time.
Peak GPU memory | 6.6 / 12 GB (55 %) | Headroom for a slightly larger micro-batch in Phase 2.
Crash point | Start of epoch 1 | OOM during the validation/sampling pass; fixable, see §6.
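A back-of-envelope budget for what such a checkpoint holds. The sketch assumes fp32 storage (4 bytes per value), one EMA copy of the weights, and AdamW's two moment buffers per parameter; any other saved state is outside the estimate, and `checkpoint_bytes` is a hypothetical helper.

```python
def checkpoint_bytes(num_params: int, bytes_per_value: int = 4,
                     ema_copies: int = 1, optimiser_buffers: int = 2) -> dict:
    """Estimate checkpoint components for model + EMA + AdamW (exp_avg, exp_avg_sq)."""
    weights = num_params * bytes_per_value
    return {
        "weights": weights,
        "ema": ema_copies * weights,
        "optimiser": optimiser_buffers * weights,
        "total": (1 + ema_copies + optimiser_buffers) * weights,
    }


est = checkpoint_bytes(86_000_000)
print({k: round(v / 1e6) for k, v in est.items()})  # sizes in MB
```

With 86 M parameters these assumptions account for roughly 1.4 GB; the gap to the observed 2.4 GB suggests either a larger parameter count or additional saved state, which listing the checkpoint's top-level keys would resolve.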

The epoch-0 FID of 281.24 is uninformative as a quality number — for context, the reference paper reports FID below 5 at the converged checkpoint and FID well above 200 at the first epoch. The number is reported here as a sanity check that the model is genuinely learning, not as a quality claim.
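For reading these numbers, recall the standard Fréchet Inception Distance, which fits Gaussians to Inception-v3 pool features of real images (mean μ_r, covariance Σ_r) and generated images (μ_g, Σ_g). It is 0 for identical feature statistics and unbounded above, which is why an early-training value in the hundreds and a converged value below 5 sit on the same scale:

```latex
\mathrm{FID} = \lVert \mu_r - \mu_g \rVert_2^2
  + \operatorname{Tr}\!\left( \Sigma_r + \Sigma_g - 2\,(\Sigma_r \Sigma_g)^{1/2} \right)
```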

The crash at epoch 1 was not in the training loop but in the validation/sampling step that runs at every epoch boundary. Two factors combine: the sampling pass uses a larger effective batch than the training pass, and, as run here, the 1 000-step DDPM unroll holds a tensor for every step in memory simultaneously. The fix is to reduce the sampling batch; that work is not yet done and is the Phase-2 entry condition.
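The memory arithmetic behind this diagnosis is easy to check. A sketch assuming fp32 samples of shape 3 × 256 × 256, comparing a sampler that keeps every intermediate x_t of the 1 000-step unroll alive with one that keeps only the current x_t. The batch sizes are illustrative, since the sampling batch of the crashed run is not recorded here, and the estimate counts stored samples only, not weights, optimiser state, or activations.

```python
def trajectory_gib(batch: int, steps: int = 1000, channels: int = 3,
                   size: int = 256, bytes_per_value: int = 4,
                   keep_all_steps: bool = True) -> float:
    """GiB needed to hold sampler states for a batch of images.

    keep_all_steps=True models an unroll that retains every x_t;
    False models one that overwrites x_t at each step.
    """
    per_state = batch * channels * size * size * bytes_per_value
    held = steps if keep_all_steps else 1
    return per_state * held / 2**30


for b in (1, 4, 16):
    print(b, round(trajectory_gib(b), 2), round(trajectory_gib(b, keep_all_steps=False), 4))
```

Retaining the full trajectory for a batch of 16 already needs about 11.7 GiB, close to the card's entire 12 GB before any model weights or optimiser state are counted, whereas keeping only the current sample costs under a megabyte per image.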

6. Open Problems and Phase-2 Plan

Three concrete follow-ups close out the work documented here. (i) Sampling-pass OOM. Reduce the per-step sampling batch, and make sure the unroll keeps only the current sample (running it under torch.no_grad() or torch.inference_mode() and not retaining per-step tensors). Gradient checkpointing would not help here, since it only reduces the activation memory kept for a backward pass, which sampling does not need. Either change is a few lines, but neither has been tested. (ii) Multi-epoch run. The first epoch converged as expected; the next milestone is to run to ~50 epochs and check the FID trajectory against the reference paper's published curve. The wall-clock cost is approximately 12.5 days at 6 h/epoch on the rented instance, so this is a question of rental budget rather than engineering. (iii) Architectural follow-on. The recipe documented here is a stepping-stone to MambaFlow3D [6]: substituting a Mamba state-space block [7] for the ViT transformer body, retaining x-prediction, and porting the recipe to 3D triplane representations. The Phase-1 result shows the hardware-adaptation table works for the ViT case; whether it survives the Mamba substitution is the next experiment.

7. Conclusion

The JiT reference recipe trains on 2 × RTX 3060 with a roughly 6 × wall-clock penalty per epoch, a 32 × reduction in effective batch, and a 55 % memory utilisation. The first epoch converged with monotonic loss decrease and an end-of-epoch FID consistent with the reference paper's epoch-0 number. Reaching that first epoch required resolving five distinct software-stack failures, none individually deep but together expensive in diagnostic time, all reported in §4 with the symptom-diagnostic-fix triple. The contribution is the documented recipe and the diagnostic cascade — not a new training method. The next experiment is the Mamba-substitution follow-on in [6].

References
[1] Li, T. et al. "Just-image-Tokens (JiT) Diffusion: A Transformer-Native Recipe for Image Diffusion." GitHub: LTH14/JiT, 2025.
[2] Ho, J., Jain, A., Abbeel, P. "Denoising Diffusion Probabilistic Models." NeurIPS, 2020.
[3] Dosovitskiy, A. et al. "An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale." ICLR, 2021. The ViT backbone.
[4] Salimans, T., Ho, J. "Progressive Distillation for Fast Sampling of Diffusion Models." ICLR, 2022. Velocity-parameterisation source.
[5] Karras, T. et al. "Elucidating the Design Space of Diffusion-Based Generative Models." NeurIPS, 2022. Survey of parameterisation choices.
[6] Jain, A. "MambaFlow3D: State-Space Backbones for 3-D Diffusion." Thesis research, in preparation. /whitepaper/mambaflow3d
[7] Gu, A., Dao, T. "Mamba: Linear-Time Sequence Modeling with Selective State Spaces." 2023.
[8] Paszke, A. et al. "PyTorch: An Imperative Style, High-Performance Deep Learning Library." NeurIPS, 2019. The 2.5.1 + CUDA 12.4 build referenced in §4 (Table 3, failure 1).
[9] NVIDIA Corp. "H200 SXM Datasheet" and "RTX 3060 (GA106) Datasheet." For the memory-bandwidth and capacity ratios in §3.