How to Fine-Tune a FLUX.1-dev LoRA with Code, Step by Step

A barebones example of the training loop to help you learn and implement your own FLUX.1-dev fine-tuning pipelines

FLUX.1-dev is one of the most popular open-weight image generation models available today. Developed by Black Forest Labs, it has 12 billion parameters. The goal of this post is to provide a barebones example of the training loop to help you learn and implement your own fine-tuning pipelines.

Feel free to skip directly to the code if you already have background on the FLUX models. This code can be run directly on a GPU in a Marimo Notebook on Oxen.ai. Signing up will give you $10 of free credits to get started.

 


The FLUX.1 Series of Models

If you haven't seen examples generated by a fine-tuned FLUX model, they are of very high quality. For example, here is a quick experiment I ran using the code below. I took 20 images of my dog from my camera roll (left column) and was able to fine-tune FLUX to generate the images on the right.

Black Forest Labs initially launched the FLUX.1 series with a few open-weight models, most notably FLUX.1 [schnell] and FLUX.1 [dev].

Just yesterday, they also announced the newest member of the family, FLUX.1 Kontext, which is useful for editing images with text instructions. This blog will focus on FLUX.1-dev, and we will follow up with a deeper dive on FLUX.1 Kontext.

👨‍🎨 The Task

Our task is to train a LoRA on a character of our choosing. A LoRA (Low-Rank Adaptation) is a small, lightweight set of parameters that is fast to train and can give a model superpowers. LoRAs can be used to customize models and imbue them with new characters or styles. In this case, the character will be my dog.
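
To make that concrete, here is a minimal sketch of what attaching LoRA adapters to the FLUX transformer looks like with peft. The rank, alpha, and target module names below are illustrative assumptions (to_q, to_k, to_v, and to_out.0 are the attention projections used in diffusers attention blocks), it assumes you have accepted the FLUX.1-dev license on Hugging Face, and the training code later imports the same LoraConfig and get_peft_model helpers.

import torch
from diffusers import FluxTransformer2DModel
from peft import LoraConfig, get_peft_model

# Load only the diffusion transformer (requires access to the gated
# black-forest-labs/FLUX.1-dev repo on Hugging Face).
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    torch_dtype=torch.bfloat16,
)

# Illustrative LoRA settings: a rank-16 update on the attention projections.
lora_config = LoraConfig(
    r=16,
    lora_alpha=16,
    init_lora_weights="gaussian",
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],
)

transformer = get_peft_model(transformer, lora_config)
transformer.print_trainable_parameters()  # only the LoRA matrices are trainable

With a rank of 16, the trainable parameters are a tiny fraction of the 12 billion base weights, which is what makes LoRA training fast and cheap.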

The two most common questions I get when diving into fine-tuning are: how much data do I need, and where do I source it?

For this example, all we needed was 20 diverse images of my dog. As a general rule of thumb: the more data, the better; the higher the quality, the better; and the more diverse the data, the better.

As for data sourcing, we will be doing a follow-up post on how to generate synthetic data from existing frontier models. If all it takes is 20-100 images, it is worth doing some hand capture yourself, from your team, or hiring someone with the skills to collect the initial images.

🧠 The Model

Before we dive headfirst into the code, it's helpful to have some background on the model we're training. There's a lot of jargon in the component names, and understanding the architecture will help you grasp the internals more clearly.

In the FLUX.1 Kontext paper, the team states that "FLUX.1 is a rectified flow transformer trained in the latent space of an image autoencoder." Let's break down the two most important terms: "image autoencoder" and "rectified flow transformer".

Autoencoder

Let's start with: what is an autoencoder? An autoencoder takes in an image and compresses it down to a vector through a model such as a convolutional neural network, then decompresses it, trying to reconstruct the original image from only that vector (z). This intermediate vector is what is called the "latent space" representation of the image.

Having a good representation of the image in latent space is key for many aspects of the diffusion process. Autoencoders are also really easy to train since they don't need any human-labeled images: the objective is essentially the identity function, reconstructing the input from its own compressed representation.
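
Here is a toy autoencoder in PyTorch to make the idea concrete. The architecture and sizes are made up for illustration (FLUX uses a much larger pretrained VAE, loaded via AutoencoderKL), but the objective is the same: reconstruct the input image from its latent vector z.

import torch
import torch.nn as nn
import torch.nn.functional as F

# Toy convolutional autoencoder (illustrative only, not the FLUX VAE).
class TinyAutoencoder(nn.Module):
    def __init__(self, latent_channels=4):
        super().__init__()
        # Encoder: 3x256x256 image -> latent_channels x 32 x 32 latent "z"
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, 4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(32, 64, 4, stride=2, padding=1), nn.ReLU(),
            nn.Conv2d(64, latent_channels, 4, stride=2, padding=1),
        )
        # Decoder: latent -> reconstructed image
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(latent_channels, 64, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(64, 32, 4, stride=2, padding=1), nn.ReLU(),
            nn.ConvTranspose2d(32, 3, 4, stride=2, padding=1),
        )

    def forward(self, x):
        z = self.encoder(x)              # latent space representation
        return self.decoder(z), z

# The "label" is the input itself, so no human annotation is needed.
model = TinyAutoencoder()
images = torch.randn(8, 3, 256, 256)     # stand-in for a batch of real images
reconstruction, z = model(images)
loss = F.mse_loss(reconstruction, images)  # reconstruction (identity) objective
loss.backward()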

Rectified Flow Transformer

The rectified flow transformer was introduced in the "Scaling Rectified Flow Transformers for High-Resolution Image Synthesis" paper, otherwise known as the Stable Diffusion 3 paper.

The diagram below abstracts away the internals of the Diffusion Transformer, but it will give you a sense of the moving parts during training. We will reference these parts while implementing the code. If you are unfamiliar with diffusion models, we have a few Arxiv Dives on the process that will clear up some of the terminology.

The most important parts to pay attention to are:

  1. VAE - Encode the image into latent space
  2. T5 Encoder - Encode the text prompt into a sequence of token embeddings
  3. CLIP - Encode the text prompt into a single pooled text embedding
  4. Noise - What we are trying to remove from the image. This is where the "rectified flow" comes in (see the training-step sketch after this list).
  5. Diffusion Transformer - Process the latent space to predict the noise
  6. Loss function - How close the noise prediction is to the actual noise
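
To tie those pieces together, here is a minimal sketch of a single rectified-flow training step on the latents. The function and argument names are placeholders (the real FluxTransformer2DModel forward pass takes more inputs, such as pooled CLIP embeddings and positional ids), but the straight-line interpolation, velocity target, and MSE loss are the heart of the objective.

import torch
import torch.nn.functional as F

def rectified_flow_step(transformer, latents, text_emb):
    # 4. Sample the noise we will mix into the clean latents.
    noise = torch.randn_like(latents)
    t = torch.rand(latents.shape[0], device=latents.device)  # timestep in [0, 1]
    t_ = t.view(-1, 1, 1, 1)

    # Rectified flow: the noisy sample is a straight-line interpolation
    # between the clean latents (t=0) and pure noise (t=1).
    noisy_latents = (1.0 - t_) * latents + t_ * noise

    # 5. The diffusion transformer predicts the velocity pointing
    #    from the data toward the noise.
    pred = transformer(noisy_latents, t, text_emb)

    # 6. Loss: how close the prediction is to the true velocity.
    target = noise - latents
    return F.mse_loss(pred, target)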

At inference time, you no longer have the image input or the VAE encoder. You simply input the text and a random noise latent, then iteratively predict and remove noise for N steps. Finally, you use the VAE decoder to turn the denoised latent back into an image.
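
For reference, here is what that inference loop looks like when you let diffusers handle it for you. This is a minimal sketch: the LoRA path is hypothetical, and the step count and guidance scale are just reasonable defaults to start from.

import torch
from diffusers import FluxPipeline

# Load the base pipeline (text encoders, transformer, and VAE decoder).
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

# Optional: attach LoRA weights you trained (hypothetical path).
pipe.load_lora_weights("path/to/my_dog_lora")

image = pipe(
    prompt="a photo of my dog playing in the snow",
    num_inference_steps=28,  # N denoising iterations
    guidance_scale=3.5,
    height=1024,
    width=1024,
).images[0]
image.save("sample.png")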

 


 

💻 The Hardware

All the experiments here were run on a single H100 on Oxen.ai. These days it is pretty cheap and easy to rent an H100 in the cloud for a couple hours, so I didn't optimize for anything smaller. In theory you can quantize the model and train it on an A10 with 24GB of VRAM, but I wanted a dead simple example without many bells and whistles.

👨‍💻 The Code

We will be writing all of the code in a Marimo notebook, which makes it easy to iteratively run cells and poke around the data and model. Marimo notebooks are pure Python and can also be run as command-line applications or web apps, making them easy to run anywhere.

⚙️ Install Dependencies

The code needs the following dependencies installed.

pip install pandas
pip install torch
pip install datasets
pip install trl
pip install peft
pip install huggingface_hub[hf_transfer]
pip install safetensors
pip install torchvision
pip install pillow
pip install tqdm
pip install diffusers
pip install transformers
pip install protobuf
pip install sentencepiece
pip install einops
pip install bitsandbytes
pip install oxenai

When running in Oxen.ai, simply upload your file, click "Launch Notebook", and provide your dependencies in the custom build script. Be sure to select a beefy GPU like the H100 with enough memory.

 

👨‍💻 Import Dependencies

In the first cell of the Notebook, start by importing all the libraries that we will need for training.

# Generic libs
import os
import math
import random
import gc
import json
from pathlib import Path
from datetime import datetime
from tqdm import tqdm

# Data types and utils
import torch
# For F.mse_loss
import torch.nn.functional as F
# To load the datasets
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import numpy as np
import pandas as pd

# Loading Models
from diffusers import (
    FluxTransformer2DModel,
    FlowMatchEulerDiscreteScheduler,
    DDPMScheduler,
    AutoencoderKL,
    FluxPipeline
)
from transformers import T5TokenizerFast, T5EncoderModel, CLIPTextModel, CLIPTokenizer
from peft import LoraConfig, get_peft_model, TaskType, get_peft_model_state_dict
from einops import rearrange, repeat
import bitsandbytes as bnb

# Saving Data to Disk
from safetensors.torch import save_file
from PIL import Image

# Saving Data to Oxen.ai (optional)
from oxen import RemoteRepo

We will be saving the training data, samples generated during training, and model weights to an Oxen.ai repository, hence the oxen dependency at the end. This is optional, but it's a good way to version and store your data as you train models.
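
As a rough sketch of what that looks like (the repository name and file paths here are hypothetical, and you should double-check the method signatures against the oxenai docs):

from oxen import RemoteRepo

# Point at an existing repository on Oxen.ai.
repo = RemoteRepo("your-username/flux-dog-lora")

# Upload a sample image generated during training and the LoRA weights.
repo.add("samples/step_500.png")
repo.add("weights/flux_lora.safetensors")
repo.commit("Add training sample and LoRA weights at step 500")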

 

 
