FLUX.1-dev is one of the most popular open-weight models available today. Developed by Black Forest Labs, it has 12 billion parameters. The goal of this post is to provide a barebones example of the training loop to help you learn and implement your own fine-tuning pipelines.
Feel free to skip directly to the code if you already have background on the FLUX models. This code can be run directly on a GPU in a Marimo Notebook on Oxen.ai. Signing up will give you $10 of free credits to get started.
The FLUX.1 Series of Models
If you haven't seen examples generated by a fine-tuned FLUX model, they are of very high quality. For example, here is a quick experiment I ran using the code below. I took 20 images of my dog from my camera roll (left column) and was able to fine-tune FLUX to generate the images on the right.
Black Forest Labs initially launched the FLUX.1 series with a few open-weight models:
- FLUX.1 [dev] - An open-weight, guidance-distilled model for non-commercial applications. It has similar performance to the pro model, and you can contact Black Forest Labs for a commercial license.
- NOTE: This license just got a whole lot more restrictive, and is pretty expensive at $999/month
- FLUX.1 [schnell] - The fastest model is tailored for local development and personal use. It is openly available under an Apache 2.0 license 🎉
- Schnell is a step-distilled model, meaning it can generate an image in just a few steps. However, this makes it impractical to fine-tune directly, because each training step breaks down the distillation more and more.
Just yesterday, they also announced the newest member of the family, FLUX.1 Kontext, which is useful for editing images with text. This blog will focus on FLUX.1-dev, and we will follow up with more information on FLUX.1 Kontext in our next deep dive.
👨‍🎨 The Task
Our task is to train a LoRA on a character of our choosing. LoRAs are a smaller, lightweight set of parameters that are fast to train and can give a model superpowers. They can be used to customize models and imbue them with new characters or styles. In this case, we will be using my dog.
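If you haven't worked with LoRAs before, a peft LoraConfig is often all it takes to describe one. Here is a rough sketch; the rank and target module names are illustrative assumptions, not necessarily the exact values we use later in this post.

```python
from peft import LoraConfig

# A rough sketch of a LoRA configuration for the FLUX transformer.
# The rank and target modules here are illustrative, not the exact
# values used later in this post.
lora_config = LoraConfig(
    r=16,                     # rank of the low-rank update matrices
    lora_alpha=16,            # scaling factor applied to the LoRA update
    init_lora_weights="gaussian",
    target_modules=["to_q", "to_k", "to_v", "to_out.0"],  # attention projections
)

# Later, the base transformer can be wrapped so only the LoRA weights train:
# transformer = get_peft_model(transformer, lora_config)
```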
The two most common questions I get when diving into fine-tuning are:
- How much data do I need?
- How do I get the data?
For this example, all we needed was 20 diverse images of my dog. As a general rule of thumb: the more data, the better; the higher the quality, the better; and the more diverse the data, the better.
As for data sourcing, we will be doing a follow-up post on how to generate synthetic data from existing frontier models. If all it takes is 20-100 images, it is worth doing some hand capture with your team or hiring someone with the skills to collect the initial images.
🧠 The Model
Before we dive headfirst into the code, it's helpful to have some background on the model we're training. There's a lot of jargon in the component names, and understanding the architecture will help you grasp the internals more clearly.
In the FLUX.1 Kontext paper, the team states that "FLUX.1 is a rectified flow transformer trained in the latent space of an image autoencoder." Let's break down the two most important terms: "image autoencoder" and "rectified flow transformer".
Autoencoder
Let's start with: what is an autoencoder? An autoencoder takes in an image and compresses it down to a vector through a model such as a convolutional neural network, then decompresses it, trying to reconstruct the original image from only that vector (z). This intermediate vector is called the "latent space" representation of the image.
Having a good representation of the image in latent space is key for many aspects of the diffusion process. Autoencoders are also really easy to train since they don't need any human-labeled images; the objective is essentially the identity function, reconstructing the input from its compressed representation.
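To make the idea concrete, here is a toy convolutional autoencoder in PyTorch. This is a minimal sketch for illustration only, not FLUX's actual VAE (which is much larger and variational).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyAutoencoder(nn.Module):
    def __init__(self, latent_channels=4):
        super().__init__()
        # Encoder: compress the image into a smaller latent tensor z
        self.encoder = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),               # H -> H/2
            nn.ReLU(),
            nn.Conv2d(32, latent_channels, kernel_size=4, stride=2, padding=1), # H/2 -> H/4
        )
        # Decoder: reconstruct the image from z alone
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(latent_channels, 32, kernel_size=4, stride=2, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),
        )

    def forward(self, x):
        z = self.encoder(x)      # latent space representation
        return self.decoder(z)   # reconstruction

# The objective is essentially the identity function: reconstruct the input
model = TinyAutoencoder()
x = torch.randn(1, 3, 64, 64)
loss = F.mse_loss(model(x), x)
```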
Rectified Flow Transformer
The rectified flow transformer was introduced in the "Scaling Rectified Flow Transformers for High-Resolution Image Synthesis" paper, otherwise known as Stable Diffusion 3.
The diagram below abstracts all the internals of the Diffusion Transformer, but will give you a sense of the moving parts during training. We will reference these parts while implementing the code. If you are unfamiliar with Diffusion Models, we have a few Arxiv Dives on the process that will clear up some of the terminology.
The most important parts to pay attention to are:
- VAE - Encode the image into latent space
- T5 Encoder - Encode the text into latent space
- CLIP - Encode the text into latent space
- Noise - What we are trying to remove from the image. This is where the "rectified flow" comes in.
- Diffusion Transformer - Process the latent space to predict the noise
- Loss function - Measures how close the noise prediction is to the actual noise
After training, there is no longer an image input or VAE encoder. You simply input the text and pure noise in latent space, predict the noise, remove it, and iterate N times. Finally, the VAE decoder turns the denoised latent back into a valid image.
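To tie these parts to the training loop we are about to write, here is a minimal sketch of a single rectified flow training step. The transformer's call signature is a simplified assumption (the real FluxTransformer2DModel also takes pooled CLIP embeddings, positional ids, and guidance), but the encode → noise → predict → loss flow is the same.

```python
import torch
import torch.nn.functional as F

def rectified_flow_step(vae, transformer, image, prompt_embeds):
    # 1. Encode the image into latent space with the (frozen) VAE
    with torch.no_grad():
        latents = vae.encode(image).latent_dist.sample()

    # 2. Sample random noise and a timestep t in [0, 1]
    noise = torch.randn_like(latents)
    t = torch.rand(latents.shape[0], device=latents.device)
    t_ = t.view(-1, 1, 1, 1)

    # 3. Rectified flow: linearly interpolate between clean latents and noise
    noisy_latents = (1.0 - t_) * latents + t_ * noise

    # 4. The transformer predicts the "velocity" pointing from data to noise
    #    (simplified call signature for illustration)
    pred = transformer(noisy_latents, t, prompt_embeds)

    # 5. Loss: how close the prediction is to the true velocity
    target = noise - latents
    return F.mse_loss(pred, target)
```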
💻 The Hardware
All the experiments here were run on a single H100 on Oxen.ai. These days it is pretty cheap and easy to rent an H100 in the cloud for a couple hours, so I didn't optimize for anything smaller. In theory you can quantize the model and train it on an A10 with 24GB of VRAM, but I wanted a dead simple example without many bells and whistles.
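If you do want to try a smaller GPU, here is a hedged sketch of what 4-bit loading might look like. It assumes a recent diffusers release with bitsandbytes quantization support, and it is not the setup used in this post.

```python
import torch
from diffusers import FluxTransformer2DModel, BitsAndBytesConfig

# Load the 12B transformer in 4-bit NF4 to shrink its memory footprint
# (assumes a recent diffusers version with bitsandbytes support).
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)
```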
👨‍💻 The Code
We will be writing all of the code in a Marimo notebook, which will make it easy for us to iteratively run the cells and poke around the data and model. Marimo notebooks are pure Python, and can also be run as command line applications or web apps, making them easy to run anywhere.
⚙️ Install Dependencies
The code needs the following dependencies installed.
pip install pandas
pip install torch
pip install datasets
pip install trl
pip install peft
pip install huggingface_hub[hf_transfer]
pip install torchvision
pip install pillow
pip install tqdm
pip install diffusers
pip install transformers
pip install protobuf
pip install sentencepiece
pip install einops
pip install bitsandbytes
pip install oxenai
When running in Oxen.ai, simply upload your file, click "Launch Notebook", and provide your dependencies in the custom build script. Be sure to select a beefy GPU like the H100 with enough memory.
👨‍💻 Import Dependencies
In the first cell of the Notebook, start by importing all the libraries that we will need for training.
# Generic libs
import os
import math
import random
import gc
import json
from pathlib import Path
from datetime import datetime
from tqdm import tqdm
# Data types and utils
import torch
# For F.mse_loss
import torch.nn.functional as F
# To load the datasets
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import numpy as np
import pandas as pd
# Loading Models
from diffusers import (
    FluxTransformer2DModel,
    FlowMatchEulerDiscreteScheduler,
    DDPMScheduler,
    AutoencoderKL,
    FluxPipeline
)
from transformers import T5TokenizerFast, T5EncoderModel, CLIPTextModel, CLIPTokenizer
from peft import LoraConfig, get_peft_model, TaskType, get_peft_model_state_dict
from einops import rearrange, repeat
import bitsandbytes as bnb
# Saving Data to Disk
from safetensors.torch import save_file
from PIL import Image
# Saving Data to Oxen.ai (optional)
from oxen import RemoteRepo
We will be saving the training data, samples during training, and model weights to an Oxen.ai repository, hence the oxen dependency at the end. This is optional, but a good way to version and store your data as you are training models.
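As a rough sketch of how that versioning might look, the snippet below uploads a generated sample to a repository. The repository name and file path are placeholders, and the exact calls are assumptions based on the oxen library's documented RemoteRepo interface.

```python
from oxen import RemoteRepo

# Hypothetical repository name -- replace with your own Oxen.ai repo.
repo = RemoteRepo("your-username/flux-dog-lora")

# Upload a generated sample and commit it so each training run is versioned.
repo.add("samples/step_1000.png")
repo.commit("Add sample image at step 1000")
```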