This tutorial demonstrates how to leverage multiple GPUs for distributed model training using the fal.distributed module. We’ll build a production-ready Flux LoRA training service that uses Distributed Data Parallel (DDP) for efficient multi-GPU training with real-time progress streaming.
For a comprehensive overview of multi-GPU parallelism strategies including DDP and when to use them, see the Multi-GPU Workloads Overview.

Two-App Architecture

This tutorial demonstrates a microservices architecture with two separate apps communicating with each other:
  1. Preprocessor App (flux-preprocessor): Handles image preprocessing, captioning, and VAE/text encoding across multiple GPUs
  2. Training App (flux-training): Handles DDP training with the preprocessed data
Architectural Note: While this workflow could be implemented as a single app, we’re using two separate apps to demonstrate how to orchestrate multiple ML models across microservices. This pattern showcases:
  • Independent scaling: Scale preprocessing and training separately based on demand
  • Service isolation: Each model stays warm independently (no cold starts)
  • Flexible infrastructure: Different GPU types/counts for different workloads
  • Inter-service communication: Apps communicate via fal_client API calls
This microservices approach is valuable when building complex AI systems with multiple models that need independent deployment and scaling.
Important: For this tutorial, you need to run the preprocessor app first to get its app name (username/uuid format from the fal run output), then pass that app name to the training app.

🚀 Try this Example

View the complete source code on GitHub.
"""
Flux LoRA Training App - Calls Separate Preprocessor

Architecture:
- This app handles TRAINING only (multi-GPU DDP)
- Calls the flux-preprocessor-demo app for preprocessing (runs on its own GPUs)
- Both apps stay warm, no reload overhead
- Clean separation of concerns
"""

import fal
from typing import ClassVar

from fal.distributed import DistributedRunner
from fal.toolkit import File, download_file
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field

from fal_demos.distributed.training.flux_lora.trainer.training_worker import (
    FluxLoRATrainingWorker,
)


class CompleteTrainingRequest(BaseModel):
    """Request model for complete training (preprocessing + training)"""
    
    SCHEMA_IGNORES: ClassVar[set[str]] = {"preprocessor_app"}
    
    images_data_url: str = Field(
        description="URL to ZIP file containing training images"
    )
    trigger_word: str = Field(
        default="ohwx",
        description="Trigger word to inject into captions (e.g. 'ohwx', 'txcl')"
    )
    steps: int = Field(
        default=250,
        description="Number of training steps",
        ge=100,
        le=10000,
    )
    learning_rate: float = Field(
        default=5e-4,
        description="Base learning rate for optimizer"
    )
    b_up_factor: float = Field(
        default=3.0,
        description="Learning rate multiplier for lora_B parameters"
    )
    batch_size: int = Field(
        default=4,
        description="Batch size per GPU"
    )
    gradient_accumulation_steps: int = Field(
        default=1,
        description="Number of steps to accumulate gradients"
    )
    resolution: int = Field(
        default=512,
        description="Training resolution",
        ge=256,
        le=1024,
    )
    seed: int = Field(
        default=42,
        description="Random seed for reproducibility"
    )
    preprocessor_app: str = Field(
        description="URL to the preprocessor app",
        default="alex-w67ic4anktp1/flux-preprocessor-demo/",
    )


class TrainingRequest(BaseModel):
    """Request model for training from preprocessed data"""
    
    training_data_url: File = Field(
        description="Preprocessed training data (.pt file)",
        media_type="application/octet-stream",
    )
    learning_rate: float = Field(
        default=5e-4,
        description="Base learning rate for optimizer"
    )
    b_up_factor: float = Field(
        default=3.0,
        description="Learning rate multiplier for lora_B parameters"
    )
    max_train_steps: int = Field(
        default=250,
        description="Number of training steps"
    )
    batch_size: int = Field(
        default=4,
        description="Batch size per GPU"
    )
    gradient_accumulation_steps: int = Field(
        default=1,
        description="Number of steps to accumulate gradients"
    )
    guidance_scale: float = Field(
        default=1.0,
        description="Guidance scale for training"
    )
    use_masks: bool = Field(
        default=True,
        description="Whether to use masks (for face-focused training)"
    )
    lr_scheduler: str = Field(
        default="linear",
        description="Learning rate scheduler type: 'constant', 'linear', 'cosine'"
    )
    lr_warmup_steps: int = Field(
        default=0,
        description="Number of warmup steps for learning rate"
    )
    seed: int = Field(
        default=42,
        description="Random seed for reproducibility"
    )


class TrainingResponse(BaseModel):
    """Response model for training"""
    
    checkpoint: File = Field(
        description="Trained LoRA checkpoint file"
    )
    final_loss: float = Field(
        description="Final average training loss"
    )
    num_steps: int = Field(
        description="Total number of training steps completed"
    )
    message: str = Field(
        description="Status message"
    )


class FluxLoRATrainingApp(fal.App):
    """
    Flux LoRA Training App using DistributedRunner.
    
    This app focuses on TRAINING only and uses all GPUs with DDP.
    For preprocessing, it calls the flux-preprocessor-demo app.
    
    Benefits:
    - Training runner stays warm (no reload)
    - Preprocessor runs on separate GPUs (no conflict)
    - Clean separation of concerns
    - Can scale independently
    """
    
    machine_type = "GPU-H100"
    num_gpus = 2
    keep_alive = 3000
    min_concurrency = 1
    max_concurrency = 1
    
    requirements = [
        "torch==2.4.0",
        "diffusers==0.30.3",
        "transformers==4.46.0",
        "tokenizers==0.20.1",
        "sentencepiece",
        "peft==0.12.0",
        "safetensors==0.4.4",
        "accelerate==1.4.0",
        "pyzmq==26.0.0",
        "huggingface_hub==0.26.5",
        "fal-client",  # For calling preprocessor app
    ]
    
    async def setup(self) -> None:
        """
        Initialize the training runner.
        
        Downloads Flux weights and starts the GPU workers for training.
        Preprocessing is handled by calling a separate app.
        """
        import os
        from huggingface_hub import snapshot_download
        
        os.environ["HF_HOME"] = "/data/models"
        
        # Download Flux weights
        print("Downloading Flux model weights...")
        model_path = snapshot_download(
            repo_id="black-forest-labs/FLUX.1-dev",
            local_dir="/data/flux_weights",
        )
        print(f"Model downloaded to {model_path}")
        
        # Create training runner (uses all GPUs)
        self.runner = DistributedRunner(
            worker_cls=FluxLoRATrainingWorker,
            world_size=self.num_gpus,
        )
        
        # Start training workers
        print(f"Starting {self.num_gpus} training workers...")
        await self.runner.start(model_path=model_path)
        
        print("Training workers ready!")
    
    @fal.endpoint("/train")
    async def train(
        self,
        request: CompleteTrainingRequest,
    ) -> TrainingResponse:
        """
        Complete training pipeline: raw images → trained LoRA.
        
        This endpoint:
        1. Calls flux-preprocessor-demo app (runs on separate GPUs)
        2. Downloads preprocessed data
        3. Trains LoRA on our GPUs
        4. Returns checkpoint
        
        Both apps stay warm, so no model reload overhead!
        """
        import fal_client
        import tempfile
        
        # Step 1: Call preprocessor app (runs on a separate instance with its own GPUs)
        print(f"Calling preprocessor app: {request.preprocessor_app}")
        print(f"Preprocessing {request.images_data_url}...")
        
        try:
            preprocess_result = fal_client.submit(
                request.preprocessor_app,
                arguments={
                    "images_data_url": request.images_data_url,
                    "trigger_word": request.trigger_word,
                    "resolution": request.resolution,
                }
            ).get()
        except Exception as e:
            print(f"Preprocessing failed: {e}")
            print(f"Make sure {request.preprocessor_app} is deployed!")
            raise
        
        # Get preprocessed data URL
        preprocessed_data_url = preprocess_result["preprocessed_data"]["url"]
        num_images = preprocess_result["num_images"]
        
        print(f"✓ Preprocessed {num_images} images")
        
        # Step 2: Download preprocessed data
        with tempfile.TemporaryDirectory() as temp_dir:
            print("Downloading preprocessed data...")
            preprocessed_path = str(download_file(
                preprocessed_data_url,
                target_dir=temp_dir
            ))
            
            # Step 3: Train LoRA (on our GPUs)
            print(f"Training LoRA with {self.num_gpus} GPUs...")
            train_result = await self.runner.invoke({
                "training_data_path": preprocessed_path,
                "learning_rate": request.learning_rate,
                "b_up_factor": request.b_up_factor,
                "max_train_steps": request.steps,
                "batch_size": request.batch_size,
                "gradient_accumulation_steps": request.gradient_accumulation_steps,
                "guidance_scale": 1.0,
                "use_masks": False,
                "lr_scheduler": "constant",
                "lr_warmup_steps": 0,
                "seed": request.seed,
                "streaming": False,
            })
            
            # Upload checkpoint
            checkpoint_file = File.from_path(train_result["checkpoint_path"])
            
            return TrainingResponse(
                checkpoint=checkpoint_file,
                final_loss=train_result["final_loss"],
                num_steps=train_result["num_steps"],
                message=f"Training complete! Processed {num_images} images, {request.steps} steps, final loss: {train_result['final_loss']:.6f}"
            )
    
    @fal.endpoint("/stream")
    async def stream(
        self,
        request: CompleteTrainingRequest,
    ) -> StreamingResponse:
        """
        Complete training pipeline with real-time streaming progress.
        
        This endpoint streams training metrics in real-time:
        - Step number
        - Current loss
        - Average loss
        - Learning rate
        - Progress status
        
        Returns Server-Sent Events (SSE) stream.
        """
        import fal_client
        import tempfile
        
        # Step 1: Call preprocessor app (runs on a separate instance with its own GPUs)
        print(f"Calling preprocessor app: {request.preprocessor_app}")
        print(f"Preprocessing {request.images_data_url}...")
        
        try:
            preprocess_result = fal_client.submit(
                request.preprocessor_app,
                arguments={
                    "images_data_url": request.images_data_url,
                    "trigger_word": request.trigger_word,
                    "resolution": request.resolution,
                }
            ).get()
        except Exception as e:
            print(f"Preprocessing failed: {e}")
            print(f"Make sure {request.preprocessor_app} is deployed!")
            raise
        
        # Get preprocessed data URL
        preprocessed_data_url = preprocess_result["preprocessed_data"]["url"]
        num_images = preprocess_result["num_images"]
        
        print(f"✓ Preprocessed {num_images} images")
        
        # Step 2: Download preprocessed data and stream training
        async def generate_stream():
            with tempfile.TemporaryDirectory() as temp_dir:
                print("Downloading preprocessed data...")
                preprocessed_path = str(download_file(
                    preprocessed_data_url,
                    target_dir=temp_dir
                ))
                
                # Step 3: Train LoRA with streaming (on our GPUs)
                print(f"Training LoRA with {self.num_gpus} GPUs (streaming)...")
                async for event in self.runner.stream(
                    {
                        "training_data_path": preprocessed_path,
                        "learning_rate": request.learning_rate,
                        "b_up_factor": request.b_up_factor,
                        "max_train_steps": request.steps,
                        "batch_size": request.batch_size,
                        "gradient_accumulation_steps": request.gradient_accumulation_steps,
                        "guidance_scale": 1.0,
                        "use_masks": False,
                        "lr_scheduler": "constant",
                        "lr_warmup_steps": 0,
                        "seed": request.seed,
                        "streaming": True,  # Enable streaming!
                    },
                    as_text_events=True,
                ):
                    yield event
        
        return StreamingResponse(
            generate_stream(),
            media_type="text/event-stream",
        )
    
    @fal.endpoint("/train-from-preprocessed")
    async def train_from_preprocessed(
        self,
        request: TrainingRequest,
    ) -> TrainingResponse:
        """
        Train from already preprocessed data.
        
        Use this if you already have preprocessed .pt files.
        """
        import tempfile
        
        with tempfile.TemporaryDirectory() as temp_dir:
            print(f"Downloading preprocessed data from {request.training_data_url.url}")
            training_data_path = str(download_file(
                request.training_data_url.url,
                target_dir=temp_dir
            ))
            
            # Train on our GPUs
            print(f"Training with {self.num_gpus} GPUs...")
            result = await self.runner.invoke({
                "training_data_path": training_data_path,
                "learning_rate": request.learning_rate,
                "b_up_factor": request.b_up_factor,
                "max_train_steps": request.max_train_steps,
                "batch_size": request.batch_size,
                "gradient_accumulation_steps": request.gradient_accumulation_steps,
                "guidance_scale": request.guidance_scale,
                "use_masks": request.use_masks,
                "lr_scheduler": request.lr_scheduler,
                "lr_warmup_steps": request.lr_warmup_steps,
                "seed": request.seed,
                "streaming": False,
            })
            
            # Check for errors
            if "error" in result:
                raise RuntimeError(f"Training failed: {result['error']}")
            
            # Upload checkpoint
            checkpoint_file = File.from_path(result["checkpoint_path"])
            
            return TrainingResponse(
                checkpoint=checkpoint_file,
                final_loss=result["final_loss"],
                num_steps=result["num_steps"],
                message=f"Training complete! {request.max_train_steps} steps, final loss: {result['final_loss']:.6f}"
            )
    
    @fal.endpoint("/train-from-preprocessed-stream")
    async def train_from_preprocessed_stream(
        self,
        request: TrainingRequest,
    ) -> StreamingResponse:
        """
        Train from already preprocessed data with real-time streaming progress.
        
        This endpoint streams training metrics in real-time as Server-Sent Events (SSE).
        """
        import tempfile
        
        async def generate_stream():
            with tempfile.TemporaryDirectory() as temp_dir:
                print(f"Downloading preprocessed data from {request.training_data_url.url}")
                training_data_path = str(download_file(
                    request.training_data_url.url,
                    target_dir=temp_dir
                ))
                
                # Train with streaming enabled
                print(f"Training with {self.num_gpus} GPUs (streaming)...")
                async for event in self.runner.stream(
                    {
                        "training_data_path": training_data_path,
                        "learning_rate": request.learning_rate,
                        "b_up_factor": request.b_up_factor,
                        "max_train_steps": request.max_train_steps,
                        "batch_size": request.batch_size,
                        "gradient_accumulation_steps": request.gradient_accumulation_steps,
                        "guidance_scale": request.guidance_scale,
                        "use_masks": request.use_masks,
                        "lr_scheduler": request.lr_scheduler,
                        "lr_warmup_steps": request.lr_warmup_steps,
                        "seed": request.seed,
                        "streaming": True,  # Enable streaming!
                    },
                    as_text_events=True,
                ):
                    yield event
        
        return StreamingResponse(
            generate_stream(),
            media_type="text/event-stream",
        )


if __name__ == "__main__":
    app = fal.wrap_app(FluxLoRATrainingApp)
    app()
Or clone this repository:
git clone https://github.com/fal-ai-community/fal-demos.git
cd fal-demos
pip install -e .
Step 1: Run the preprocessor app (in terminal 1):
# Start the preprocessor app
fal run flux-preprocessor

# You'll see output like:
# @ https://fal.run/username/0a5f684b-bad1-4fd8-8f1d-c49493ab18bd
# Extract the app name: username/0a5f684b-bad1-4fd8-8f1d-c49493ab18bd
# Copy this - you'll need it for the training app!
Step 2: Run the training app (in terminal 2):
# Start the training app
fal run flux-training

# You'll see output like:
# @ https://fal.run/username/1b2c3d4e-5f6a-7b8c-9d0e-1f2a3b4c5d6e
# Extract the app name: username/1b2c3d4e-5f6a-7b8c-9d0e-1f2a3b4c5d6e
Step 3: Submit a training request:
import fal_client

result = fal_client.submit(
    "username/1b2c3d4e-5f6a-7b8c-9d0e-1f2a3b4c5d6e",  # Training app name from Step 2
    arguments={
        "images_data_url": "https://example.com/training-images.zip",
        "trigger_word": "ohwx",
        "steps": 250,
        "preprocessor_app": "username/0a5f684b-bad1-4fd8-8f1d-c49493ab18bd",  # Preprocessor app name from Step 1!
    }
)
Before you run, make sure you have:
  • Authenticated with fal: fal auth login
  • Activated your virtual environment (recommended): python -m venv venv && source venv/bin/activate
  • A ZIP file containing training images (10-30 images recommended)

Key Features

Microservices Architecture:
  • Two independent apps communicate via API calls (demonstrates orchestrating multiple ML models)
  • Preprocessor app runs separately for image preprocessing, captioning, and encoding
  • Training app calls preprocessor via fal_client.submit() for clean service separation
Distributed Data Parallel (DDP) Training:
  • Each GPU has a full copy of the model wrapped in DDP
  • Each GPU processes different batches with automatic gradient synchronization
  • Only Rank 0 saves the final LoRA checkpoint

Architecture Overview

This training system demonstrates a microservices architecture with two separate apps communicating via API calls.

What is DDP (Distributed Data Parallel)?

DDP is a data parallelism strategy where:
  1. Each GPU has a full model copy: All workers have identical model parameters
  2. Each GPU processes different data: Training data is split across GPUs
  3. Gradients are synchronized: After backward pass, gradients are averaged across all GPUs
  4. Parameters stay in sync: All GPUs update with the same averaged gradients
This is the most common and efficient multi-GPU training strategy for models that fit on a single GPU.
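If you haven't used DDP before, the following minimal, fal-independent sketch shows these four steps with a toy model. The model, data, and script name are illustrative assumptions; launch it with torchrun on a machine with at least two GPUs:
# Minimal DDP demo with a toy model; launch with: torchrun --nproc_per_node=2 ddp_demo.py
import os

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP


def main():
    # torchrun sets RANK, LOCAL_RANK, and WORLD_SIZE for each process
    dist.init_process_group(backend="nccl")
    local_rank = int(os.environ["LOCAL_RANK"])
    device = torch.device(f"cuda:{local_rank}")

    # 1. Every rank builds an identical model copy and wraps it in DDP
    model = DDP(torch.nn.Linear(16, 1).to(device), device_ids=[local_rank])
    optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)

    # 2. Each rank trains on its own (here: synthetic) data
    x = torch.randn(64, 16, device=device)
    y = torch.randn(64, 1, device=device)

    for _ in range(10):
        loss = torch.nn.functional.mse_loss(model(x), y)
        # 3. backward() triggers an all-reduce that averages gradients across ranks
        loss.backward()
        # 4. Every rank applies the same averaged gradients, so parameters stay in sync
        optimizer.step()
        optimizer.zero_grad()

    dist.destroy_process_group()


if __name__ == "__main__":
    main()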

Code Walkthrough

We’ll walk through the code in the order you’d build it: first the preprocessor (which you can test independently), then the training app.

Part 1: Preprocessor App

The preprocessor runs on separate GPUs and handles image preparation. Let’s start with the app definition:

Preprocessor App Configuration

class FluxPreprocessorApp(fal.App):
    """
    Standalone Flux preprocessing app.
    
    Uses all GPUs in parallel to preprocess images:
    - Each GPU processes a subset of images
    - Results are gathered and saved
    - Returns URL to preprocessed data
    """
    
    machine_type = "GPU-H100"
    num_gpus = 2
    keep_alive = 300
    min_concurrency = 0
    max_concurrency = 2
    
    requirements = [
        "torch==2.4.0",
        "torchvision",  # Required by moondream
        "diffusers==0.30.3",
        "transformers==4.46.0",
        "tokenizers==0.20.1",
        "sentencepiece",
        "accelerate==1.4.0",
        "pyzmq==26.0.0",
        "huggingface_hub==0.26.5",
        "moondream==0.0.5",
        "einops",  # Required by moondream
        "pillow>=10.0.0",
        "timm",  # Required by moondream vision encoder
    ]
Understanding the App Configuration:
  • machine_type = "GPU-H100": Specifies the hardware your app runs on. Here we’re using H100 GPUs for fast preprocessing.
  • num_gpus = 2: Requests 2 GPUs per runner. Each GPU will process different images in parallel.
  • keep_alive = 300: Keeps the runner warm for 5 minutes after the last request. Avoids cold starts for subsequent requests.
  • min_concurrency = 0: Allows runners to scale down to zero when idle (saves costs).
  • max_concurrency = 2: Allows up to 2 concurrent requests. Additional requests will queue.
  • requirements: Python packages to install on the runner. Always pin versions for reproducibility.

The Setup Function

The setup() function runs once when each runner starts. It’s where you load models, download weights, and initialize resources:
    async def setup(self) -> None:
        """
        Initialize the preprocessing runner.
        
        Downloads Flux weights and starts GPU workers for preprocessing.
        """
        import os
        from huggingface_hub import snapshot_download
        
        os.environ["HF_HOME"] = "/data/models"
        
        # Download Flux weights
        print("Downloading Flux model weights...")
        model_path = snapshot_download(
            repo_id="black-forest-labs/FLUX.1-dev",
            local_dir="/data/flux_weights",
        )
        print(f"Model downloaded to {model_path}")
        
        # Create preprocessing runner
        self.runner = DistributedRunner(
            worker_cls=FluxPreprocessorWorker,
            world_size=self.num_gpus,
        )
        
        # Start workers
        print(f"Starting {self.num_gpus} preprocessing workers...")
        await self.runner.start(model_path=model_path)
        
        print("Preprocessing workers ready!")
Key Concepts in Setup:
  1. /data/ directory: This is a persistent, shared volume attached to your runner. Files stored here persist across requests and are shared between the main process and all GPU workers. Perfect for model weights that you don’t want to re-download on every request.
  2. snapshot_download(..., local_dir="/data/flux_weights"): Downloads the Flux model from Hugging Face into the persistent /data/ volume. The first runner downloads it once, then subsequent runners (and requests) reuse the cached files.
  3. DistributedRunner(worker_cls, world_size): Creates a runner that orchestrates multiple GPU workers for parallel processing. See API Reference →
    • worker_cls=FluxPreprocessorWorker: Your custom worker class that inherits from DistributedWorker
    • world_size=self.num_gpus: Creates one worker process per GPU (2 workers for 2 GPUs)
  4. await self.runner.start(model_path=model_path): Starts all GPU worker processes and initializes them. See API Reference →
    • Each worker will run its own setup() method to load models onto its assigned GPU
    • The model_path keyword argument is passed to each worker’s setup() method
    • Waits for all workers to signal “READY” before returning
After setup() completes, your app is ready to handle requests. The runner stays warm (based on keep_alive), so subsequent requests skip this expensive setup.
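To see these pieces in isolation, here is a stripped-down sketch that uses only the DistributedRunner and DistributedWorker calls demonstrated in this tutorial. The ShardWorker/ShardApp names, payload keys, and requirements list are illustrative placeholders, not part of the example repository:
import fal
from fal.distributed import DistributedRunner, DistributedWorker
from pydantic import BaseModel, Field


class ShardRequest(BaseModel):
    num_items: int = Field(default=10, description="How many dummy items to shard")


class ShardResponse(BaseModel):
    processed: int = Field(description="Total items processed across all workers")


class ShardWorker(DistributedWorker):
    def setup(self, model_path: str, **kwargs):
        # Runs once per GPU worker during runner.start(); load models onto self.device here
        self.rank_print(f"Worker ready on {self.device} (model_path={model_path})")

    def __call__(self, num_items: int = 10, **kwargs):
        # Each rank handles every world_size-th item, starting at its own rank
        mine = list(range(num_items))[self.rank :: self.world_size]
        self.rank_print(f"Processing {len(mine)} of {num_items} items")
        # Only rank 0's return value is surfaced to the caller
        return {"processed": num_items} if self.rank == 0 else {}


class ShardApp(fal.App):
    machine_type = "GPU-H100"
    num_gpus = 2
    requirements = ["torch==2.4.0", "pyzmq==26.0.0"]  # likely minimum, mirroring the apps above

    async def setup(self):
        self.runner = DistributedRunner(worker_cls=ShardWorker, world_size=self.num_gpus)
        # Keyword arguments passed to start() reach each worker's setup()
        await self.runner.start(model_path="/data/models")

    @fal.endpoint("/")
    async def run(self, request: ShardRequest) -> ShardResponse:
        result = await self.runner.invoke({"num_items": request.num_items})
        return ShardResponse(processed=result["processed"])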

Preprocessor Endpoint

The main endpoint handles the full preprocessing pipeline:
    @fal.endpoint("/")
    async def preprocess(
        self,
        request: PreprocessRequest,
        response: Response,
    ) -> PreprocessResponse:
        """
        Preprocess images for Flux LoRA training.
        
        This endpoint:
        1. Downloads ZIP of images
        2. Generates/loads captions
        3. Injects trigger word
        4. Encodes images with VAE (parallel across GPUs)
        5. Encodes captions with T5/CLIP (parallel across GPUs)
        6. Returns preprocessed data file
        """
        import tempfile
        
        with tempfile.TemporaryDirectory() as temp_dir:
            # Download images ZIP
            print(f"Downloading images from {request.images_data_url}")
            images_zip_path = str(download_file(
                request.images_data_url,
                target_dir=temp_dir
            ))
            
            # Generate unique request ID for this preprocessing job
            import time
            request_id = str(int(time.time() * 1000000))  # Microsecond timestamp
            
            # Preprocess in parallel across all GPUs
            print(f"Preprocessing with {self.num_gpus} GPUs (request_id: {request_id})...")
            result = await self.runner.invoke({
                "images_zip_url": images_zip_path,
                "request_id": request_id,
                "trigger_word": request.trigger_word,
                "resolution": request.resolution,
            })
            # ↑ runner.invoke() sends this dict to all workers' __call__() method
            # Workers process data in parallel, then rank 0 returns the result
            # See: /serverless/distributed/api-reference#invoke
            
            # Upload preprocessed data
            preprocessed_file = File.from_path(result["preprocessed_data_path"])
            num_images = result["num_images"]
            
            print(f"Preprocessing complete! Processed {num_images} images")
            
            return PreprocessResponse(
                preprocessed_data=preprocessed_file,
                num_images=num_images,
                message=f"Preprocessed {num_images} images with trigger word '{request.trigger_word}'"
            )
Testing the Preprocessor: Once you deploy this app with fal run flux-preprocessor, you can test it independently:
import fal_client

# Test preprocessing with your images
result = fal_client.subscribe(
    "username/your-preprocessor-app-id",
    arguments={
        "images_data_url": "https://example.com/training-images.zip",
        "trigger_word": "ohwx",
        "resolution": 512
    },
    with_logs=True,
)

print(f"Preprocessed {result['num_images']} images")
print(f"Data file: {result['preprocessed_data']['url']}")
This lets you verify your images are being captioned and encoded correctly before moving to training.
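You can also download the returned .pt file and inspect it locally before training. The exact keys depend on the preprocessor implementation (the training worker later in this tutorial expects latents, text_embeddings, pooled_embeddings, and text_ids), so treat this as a rough sanity check; the URL below is a placeholder:
import urllib.request

import torch

# URL returned by the preprocessor endpoint: result["preprocessed_data"]["url"]
data_url = "https://example.com/preprocessed_data.pt"  # placeholder
local_path, _ = urllib.request.urlretrieve(data_url, "preprocessed.pt")

data = torch.load(local_path, map_location="cpu")
print("Keys:", list(data.keys()))
print("Number of images:", data.get("num_images"))

# Tensor entries should agree on the first (sample) dimension
for key in ("latents", "text_embeddings", "pooled_embeddings", "text_ids"):
    if key in data:
        print(f"{key}: shape={tuple(data[key].shape)}, dtype={data[key].dtype}")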

Preprocessor Worker Implementation

The worker runs on each GPU and handles the actual preprocessing. Here’s the implementation:
import torch
from fal.distributed import DistributedWorker

# Helpers such as load_images_from_zip, preprocess_image, encode_text, and
# gather_from_all_workers are omitted here for brevity; see the full source on GitHub.


class FluxPreprocessorWorker(DistributedWorker):
    """Worker for preprocessing images in parallel across GPUs"""
    
    def setup(self, model_path: str, **kwargs):
        """
        Load VAE, text encoders, and captioning model on each GPU.
        Called once per worker during runner.start().
        """
        from diffusers import AutoencoderKL
        from transformers import CLIPTextModel, T5EncoderModel
        
        self.rank_print(f"Loading preprocessing models on {self.device}")
        
        # Load VAE for image encoding
        self.vae = AutoencoderKL.from_pretrained(
            model_path,
            subfolder="vae",
            torch_dtype=torch.float16,
        ).to(self.device)
        
        # Load text encoders (T5 and CLIP)
        self.text_encoder = CLIPTextModel.from_pretrained(
            model_path, subfolder="text_encoder"
        ).to(self.device)
        
        self.text_encoder_2 = T5EncoderModel.from_pretrained(
            model_path, subfolder="text_encoder_2"
        ).to(self.device)
        
        # Load captioning model (Moondream)
        from moondream import Moondream
        self.caption_model = Moondream(device=self.device)
        
        self.rank_print("Models loaded successfully")
    
    def __call__(
        self,
        images_zip_url: str,
        request_id: str,
        trigger_word: str,
        resolution: int = 512,
        **kwargs
    ):
        """
        Process a subset of images on this GPU.
        Each worker processes different images in parallel.
        """
        import torch.distributed as dist
        
        # Step 1: Load images on rank 0, then broadcast them to all workers
        if self.rank == 0:
            # Only rank 0 loads the full dataset
            images = load_images_from_zip(images_zip_url)
        else:
            images = None
        
        # broadcast_object_list fills the list in place on every non-source rank
        objects = [images]
        dist.broadcast_object_list(objects, src=0)
        images = objects[0]
        num_images = len(images)
        
        # Step 2: Each worker processes a subset (stride = world_size)
        # Worker 0 processes images [0, 2, 4, ...]
        # Worker 1 processes images [1, 3, 5, ...]
        my_images = images[self.rank::self.world_size]
        
        self.rank_print(f"Processing {len(my_images)} images")
        
        # Step 3: Generate captions for my subset
        captions = []
        for img in my_images:
            caption = self.caption_model.caption(img)
            # Inject trigger word
            caption = f"{trigger_word} {caption}"
            captions.append(caption)
        
        # Step 4: Encode images with VAE
        latents_list = []
        for img in my_images:
            img_tensor = preprocess_image(img, resolution)
            with torch.no_grad():
                latent = self.vae.encode(img_tensor.to(self.device)).latent_dist.sample()
            latents_list.append(latent)
        
        # Step 5: Encode captions with text encoders
        text_embeddings_list = []
        for caption in captions:
            with torch.no_grad():
                text_emb = encode_text(caption, self.text_encoder, self.text_encoder_2)
            text_embeddings_list.append(text_emb)
        
        # Step 6: Gather results from all workers to rank 0
        if self.rank == 0:
            # Rank 0 collects from all workers
            all_latents = gather_from_all_workers(latents_list)
            all_text_embeddings = gather_from_all_workers(text_embeddings_list)
            
            # Save preprocessed data
            output_path = f"/tmp/preprocessed_{request_id}.pt"
            torch.save({
                "latents": torch.cat(all_latents),
                "text_embeddings": torch.cat(all_text_embeddings),
                "num_images": num_images,
            }, output_path)
            
            return {
                "preprocessed_data_path": output_path,
                "num_images": num_images,
            }
        
        return {}  # Other ranks return empty
Key DistributedWorker Concepts in Preprocessing:
  • setup(model_path, **kwargs): Called once during runner.start() to load models. The model_path passed to runner.start() is received here. API Reference →
  • self.device: Each worker loads models onto its assigned GPU using self.device. Worker 0 uses cuda:0, worker 1 uses cuda:1, etc.
  • self.rank and self.world_size: Used to split work. Each worker processes every Nth image where N = world_size, starting at rank.
    • Worker 0 (rank=0): processes images [0, 2, 4, 6, …]
    • Worker 1 (rank=1): processes images [1, 3, 5, 7, …]
  • __call__(**kwargs): Receives the payload from runner.invoke() as keyword arguments. Processes data and returns results. API Reference →
  • self.rank_print(): Prints with rank prefix for debugging. API Reference →
  • Only rank 0 returns data: The final result is assembled on rank 0 and returned. Other workers return empty dict.
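The rank::world_size split is plain Python slicing, so you can sanity-check the sharding without any GPUs (the item count here is arbitrary):
# Round-robin sharding used by the preprocessor worker
items = list(range(10))   # e.g. 10 images
world_size = 2            # 2 GPU workers

for rank in range(world_size):
    print(f"rank {rank} -> {items[rank::world_size]}")

# rank 0 -> [0, 2, 4, 6, 8]
# rank 1 -> [1, 3, 5, 7, 9]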

Part 2: Training App

Now that preprocessing works, let’s build the training app.

Training App Setup

class FluxLoRATrainingApp(fal.App):
    """
    Flux LoRA Training App using DistributedRunner.
    
    This app focuses on TRAINING only and uses GPUs with DDP.
    For preprocessing, it calls the flux-preprocessor app.
    
    Benefits:
    - Training runner stays warm (no reload)
    - Preprocessor runs on separate GPUs (no conflict)
    - Clean separation of concerns
    - Can scale independently
    """
    
    machine_type = "GPU-H100"
    num_gpus = 2
    keep_alive = 3000
    min_concurrency = 1
    max_concurrency = 1
    
    requirements = [
        "torch==2.4.0",
        "diffusers==0.30.3",
        "transformers==4.46.0",
        "tokenizers==0.20.1",
        "sentencepiece",
        "peft==0.12.0",
        "safetensors==0.4.4",
        "accelerate==1.4.0",
        "pyzmq==26.0.0",
        "huggingface_hub==0.26.5",
        "fal-client",  # For calling preprocessor app
    ]
    
    async def setup(self) -> None:
        """
        Initialize the training runner.
        
        Downloads Flux weights and starts GPU workers for training.
        Preprocessing is handled by calling a separate app.
        """
        import os
        from huggingface_hub import snapshot_download
        
        os.environ["HF_HOME"] = "/data/models"
        
        # Download Flux weights
        print("Downloading Flux model weights...")
        model_path = snapshot_download(
            repo_id="black-forest-labs/FLUX.1-dev",
            local_dir="/data/flux_weights",
        )
        print(f"Model downloaded to {model_path}")
        
        # Create training runner (uses all GPUs)
        self.runner = DistributedRunner(
            worker_cls=FluxLoRATrainingWorker,
            world_size=self.num_gpus,
        )
        
        # Start workers
        print(f"Starting {self.num_gpus} training workers...")
        await self.runner.start(model_path=model_path)
        
        print("Training workers ready!")
The training app setup follows the same pattern as the preprocessor: create a DistributedRunner, then call start() to initialize all workers. See the DistributedRunner API Reference for details.

Training Endpoint with Preprocessing

The main endpoint calls the preprocessor, then runs training:
    @fal.endpoint("/train")
    async def train(
        self,
        request: CompleteTrainingRequest,
    ) -> TrainingResponse:
        """
        Complete training pipeline: preprocess + train.
        
        This demonstrates microservices architecture where the training app
        calls a separate preprocessor app via API.
        """
        import fal_client
        import tempfile
        
        # Step 1: Call preprocessor app (runs on separate instance with GPUs)
        print(f"Calling preprocessor app: {request.preprocessor_app}")
        print(f"Preprocessing {request.images_data_url}...")
        
        try:
            preprocess_result = fal_client.subscribe(
                request.preprocessor_app,
                arguments={
                    "images_data_url": request.images_data_url,
                    "trigger_word": request.trigger_word,
                    "resolution": request.resolution,
                },
                with_logs=True,
            )
        except Exception as e:
            print(f"Preprocessing failed: {e}")
            print(f"Make sure {request.preprocessor_app} is deployed!")
            raise
        
        # Get preprocessed data URL
        preprocessed_data_url = preprocess_result["preprocessed_data"]["url"]
        num_images = preprocess_result["num_images"]
        
        print(f"✓ Preprocessed {num_images} images")
        
        # Step 2: Download preprocessed data and train
        with tempfile.TemporaryDirectory() as temp_dir:
            print("Downloading preprocessed data...")
            preprocessed_path = str(download_file(
                preprocessed_data_url,
                target_dir=temp_dir
            ))
            
            # Step 3: Train LoRA (on our GPUs)
            print(f"Training with {self.num_gpus} GPUs...")
            result = await self.runner.invoke({
                "training_data_path": preprocessed_path,
                "learning_rate": request.learning_rate,
                "b_up_factor": request.b_up_factor,
                "max_train_steps": request.steps,
                "batch_size": request.batch_size,
                "gradient_accumulation_steps": request.gradient_accumulation_steps,
                "guidance_scale": 1.0,
                "use_masks": False,
                "lr_scheduler": "constant",
                "lr_warmup_steps": 0,
                "seed": request.seed,
                "streaming": False,
            })
            # ↑ This payload dict is passed to each worker's __call__() method
            # All workers train in parallel with DDP, then rank 0 returns the checkpoint
            # See: /serverless/distributed/api-reference#invoke
            
            # Check for errors
            if "error" in result:
                raise RuntimeError(f"Training failed: {result['error']}")
            
            # Upload checkpoint
            checkpoint_file = File.from_path(result["checkpoint_path"])
            
            return TrainingResponse(
                checkpoint=checkpoint_file,
                final_loss=result["final_loss"],
                num_steps=result["num_steps"],
                message=f"Training complete! {request.steps} steps, final loss: {result['final_loss']:.6f}"
            )
Notice how this endpoint:
  1. Calls the preprocessor app via fal_client.subscribe()
  2. Downloads the preprocessed .pt file
  3. Passes it to the training worker via self.runner.invoke()

Part 3: Training Worker Implementation

The worker implements the actual DDP training logic. Here’s the key setup method:
from typing import Any

import torch
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

from fal.distributed import DistributedWorker


class FluxLoRATrainingWorker(DistributedWorker):
    """
    Production-ready distributed worker for Flux LoRA training.
    """

    def setup(self, model_path: str = "/data/flux_weights", **kwargs: Any) -> None:
        """
        Initialize the model on each GPU worker with proper LoRA configuration.
        """
        from diffusers import FluxTransformer2DModel
        from peft import LoraConfig

        self.rank_print(f"Loading Flux model on {self.device}")
        
        # Load the transformer model
        self.transformer = FluxTransformer2DModel.from_pretrained(
            model_path,
            subfolder="transformer",
            torch_dtype=torch.bfloat16,
        ).to(self.device)
        
        # Configure LoRA targeting all blocks (both double and single stream)
        target_modules = []
        
        # Double stream blocks (19 blocks) - these handle text-image interaction
        for block_num in range(19):
            target_modules.extend([
                f"transformer_blocks.{block_num}.attn.to_q",
                f"transformer_blocks.{block_num}.attn.to_k",
                f"transformer_blocks.{block_num}.attn.to_v",
                f"transformer_blocks.{block_num}.attn.to_out.0",
                # ... more attention and FF layers
            ])
        
        # Single stream blocks (38 blocks) - CRITICAL for image generation
        for block_num in range(38):
            target_modules.extend([
                f"single_transformer_blocks.{block_num}.attn.to_q",
                f"single_transformer_blocks.{block_num}.attn.to_k",
                f"single_transformer_blocks.{block_num}.attn.to_v",
                # ... more attention layers
            ])
        
        lora_config = LoraConfig(
            r=16,  # rank
            lora_alpha=16,
            target_modules=target_modules,
            lora_dropout=0.0,
            bias="none",
            init_lora_weights="gaussian",
        )
        
        # Add LoRA adapters
        self.transformer.add_adapter(lora_config)
        
        # Freeze base model, only train LoRA
        self.transformer.requires_grad_(False)
        for name, param in self.transformer.named_parameters():
            if "lora" in name:
                param.requires_grad = True
        
        # Wrap with DDP for synchronized training
        self.transformer = DDP(
            self.transformer,
            device_ids=[self.rank],
            output_device=self.rank,
            find_unused_parameters=False,
        )
        
        self.rank_print("Model loaded and wrapped with DDP")
Key DistributedWorker Concepts:
  • setup(**kwargs): Called once per worker during runner.start(). This is where you load models, download weights, and initialize resources. See API Reference →
  • self.device: The CUDA device for this worker (cuda:0, cuda:1, etc.). Always load your model with .to(self.device).
  • self.rank: Worker ID (0 to world_size-1). Useful for rank-specific operations like saving checkpoints only on rank 0.
  • self.rank_print(): Prints messages with the rank prefix for easy debugging. See API Reference →
The key steps in this setup:
  1. Load Flux transformer on each GPU using self.device
  2. Add LoRA adapters to specific layers
  3. Freeze base model, only train LoRA parameters
  4. Wrap with DDP for gradient synchronization
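A quick way to confirm step 3 worked (only the LoRA adapters will be updated) is to count parameters by requires_grad. This helper is not part of the tutorial's source; it's a small check you could run on self.transformer at the end of setup():
import torch


def report_trainable(model: torch.nn.Module) -> None:
    """Print how many parameters the optimizer will actually update."""
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    total = sum(p.numel() for p in model.parameters())
    print(f"Trainable params: {trainable:,} / {total:,} ({100 * trainable / total:.2f}%)")

# Example: report_trainable(self.transformer) at the end of the worker's setup()
# should show only a small LoRA fraction of the full transformer.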

Training Loop with Data Distribution

The __call__() method is called for each training request and implements the actual training loop. This method receives the payload dict from runner.invoke() as keyword arguments. See API Reference →
    def __call__(
        self,
        streaming: bool = False,
        training_data_path: str = None,
        learning_rate: float = 4e-4,
        b_up_factor: float = 3.0,
        max_train_steps: int = 100,
        batch_size: int = 1,
        seed: int = 42,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """
        Production-ready training function with proper Flux training logic.
        """
        self.rank_print(f"Starting training: lr={learning_rate}, steps={max_train_steps}")
        
        # Step 1: Load training data (only rank 0)
        if self.rank == 0:
            self.rank_print(f"Loading training data from {training_data_path}")
            data = torch.load(training_data_path, map_location="cpu")
            latents = data["latents"]
            text_embeddings = data["text_embeddings"]
            pooled_embeddings = data["pooled_embeddings"]
            text_ids = data["text_ids"]
        else:
            latents = torch.empty(0)
            text_embeddings = torch.empty(0)
            pooled_embeddings = torch.empty(0)
            text_ids = torch.empty(0, dtype=torch.long)
        
        # Step 2: Broadcast data to all ranks
        torch.cuda.set_device(self.device)
        objects = [latents, text_embeddings, pooled_embeddings, text_ids]
        dist.broadcast_object_list(objects, src=0)
        latents, text_embeddings, pooled_embeddings, text_ids = objects
        
        # Move to GPU
        latents = latents.to(self.device, dtype=torch.bfloat16)
        text_embeddings = text_embeddings.to(self.device, dtype=torch.bfloat16)
        pooled_embeddings = pooled_embeddings.to(self.device, dtype=torch.bfloat16)
        text_ids = text_ids.to(self.device, dtype=torch.long)
        
        num_samples = latents.shape[0]
        self.rank_print(f"Loaded {num_samples} training samples")
        
        # Step 3: Set up optimizer with DUAL learning rates for lora_A and lora_B
        params_A = []
        params_B = []
        
        model = self.transformer.module
        
        for name, param in model.named_parameters():
            if param.requires_grad:
                if "lora_A" in name:
                    params_A.append(param)
                elif "lora_B" in name:
                    params_B.append(param)
        
        optimizer = torch.optim.AdamW([
            {"params": params_A, "lr": learning_rate, "weight_decay": 0.1},
            {"params": params_B, "lr": learning_rate * b_up_factor, "weight_decay": 0.1},
        ], betas=(0.9, 0.999), eps=1e-8)
        
        # Step 4: Training loop
        self.transformer.train()
        total_loss = 0.0
        
        for step in range(max_train_steps):
            # Each GPU gets different batch indices
            batch_indices = torch.randperm(num_samples, device="cpu")[:batch_size * self.world_size]
            local_indices = batch_indices[self.rank * batch_size:(self.rank + 1) * batch_size]
            
            # Get local batch
            batch_latents = latents[local_indices]
            batch_text_emb = text_embeddings[local_indices]
            batch_pooled_emb = pooled_embeddings[local_indices]
            
            # Compute loss
            loss = self.compute_flux_loss(
                latents=batch_latents,
                text_embeddings=batch_text_emb,
                pooled_embeddings=batch_pooled_emb,
                # ... other args
            )
            
            # Backward pass (DDP automatically syncs gradients)
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            
            total_loss += loss.item()
            
            # Stream progress (only rank 0)
            if streaming and self.rank == 0 and step % 10 == 0:
                avg_loss = total_loss / (step + 1)
                self.add_streaming_result({
                    "step": step,
                    "loss": loss.item(),
                    "avg_loss": avg_loss,
                }, as_text_event=True)
                # ↑ Sends intermediate results to the client during streaming
                # Only call from rank 0 to avoid duplicate messages
                # See: /serverless/distributed/api-reference#add_streaming_result
        
        # Step 5: Save checkpoint (only rank 0)
        if self.rank == 0:
            self.rank_print("Saving checkpoint...")
            checkpoint_path = self.save_lora_checkpoint()
            
            return {
                "checkpoint_path": checkpoint_path,
                "final_loss": total_loss / max_train_steps,
                "num_steps": max_train_steps,
            }
        
        return {}  # Other ranks return empty dict
Key DDP patterns:
  1. Data loading: Only rank 0 loads, then broadcasts to all workers
  2. Different batches per GPU: Each GPU gets different local_indices
  3. Automatic gradient sync: DDP handles this during loss.backward()
  4. Rank 0 saves: Only one worker saves the checkpoint to avoid conflicts
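Pattern 1 is worth internalizing on its own: broadcast_object_list fills the list in place on every non-source rank rather than returning the objects. The toy script below (synthetic data, CPU-friendly gloo backend, launched with torchrun; not from the tutorial's source) shows the idiom in isolation:
# Launch with: torchrun --nproc_per_node=2 broadcast_demo.py
import torch
import torch.distributed as dist


def main():
    dist.init_process_group(backend="gloo")  # gloo works on CPU for this demo
    rank = dist.get_rank()

    if rank == 0:
        # Only rank 0 "loads" the data (here: a synthetic tensor plus metadata)
        objects = [torch.arange(6).reshape(2, 3), {"num_images": 6}]
    else:
        # Non-source ranks provide placeholders of the same list length
        objects = [None, None]

    # Fills `objects` in place on every rank that is not the source
    dist.broadcast_object_list(objects, src=0)
    latents, meta = objects
    print(f"rank {rank}: latents shape={tuple(latents.shape)}, meta={meta}")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()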

Using the Application

After running fal run or fal deploy for each app, you’ll see URLs like https://fal.ai/dashboard/sdk/username/app-id/. You can:
  • Test in the Playground: Click the URL or visit it in your browser to open the interactive playground and test your app
  • View on Dashboard: Visit fal.ai/dashboard to see all your apps, monitor usage, and manage deployments
Important: You must provide the preprocessor_app parameter with the app name (username/uuid format)! Get it from the fal run flux-preprocessor output.

Test in the Playground

After deploying both apps, you can test them directly in the browser:
  1. Preprocessor App: Open https://fal.ai/dashboard/sdk/username/preprocessor-app-id/ to test image preprocessing
  2. Training App: Open https://fal.ai/dashboard/sdk/username/training-app-id/ to submit training jobs with a UI
The playground provides a form interface where you can upload images, set parameters, and see results without writing code.

Call from Code

import fal_client

# First, get your preprocessor app name by running:
# fal run flux-preprocessor
# From the output: @ https://fal.run/username/0a5f684b-bad1-4fd8-8f1d-c49493ab18bd
# Extract app name: username/0a5f684b-bad1-4fd8-8f1d-c49493ab18bd

result = fal_client.subscribe(
    "username/1b2c3d4e-5f6a-7b8c-9d0e-1f2a3b4c5d6e",  # Replace with your training app name
    arguments={
        "images_data_url": "https://example.com/training-images.zip",
        "trigger_word": "ohwx",
        "steps": 250,
        "learning_rate": 5e-4,
        "batch_size": 4,
        "resolution": 512,
        "preprocessor_app": "username/0a5f684b-bad1-4fd8-8f1d-c49493ab18bd",  # ⚠️ REQUIRED: Your preprocessor app name
    },
    with_logs=True,
)

print(f"Training complete! Loss: {result['final_loss']}")
print(f"Checkpoint: {result['checkpoint']['url']}")
For other languages (JavaScript, TypeScript, etc.) and advanced client usage, see the Client Libraries documentation.

DDP Best Practices

1. Synchronization Barriers

Use barriers when all GPUs need to wait:
# Wait for all GPUs to finish setup
dist.barrier()

2. Rank-specific Operations

Only perform I/O on rank 0 to avoid conflicts:
if self.rank == 0:
    # Save checkpoint
    # Log metrics
    # Upload results
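These two practices often appear together: rank 0 writes the checkpoint while the other ranks wait at a barrier, so no worker races ahead before the file exists. A sketch of how that might look inside a worker method (save_lora_checkpoint is the helper from this tutorial's training worker):
import torch.distributed as dist

# Inside a DistributedWorker method, after the training loop:
checkpoint_path = None
if self.rank == 0:
    # Only rank 0 touches the filesystem / uploads results
    checkpoint_path = self.save_lora_checkpoint()
    self.rank_print(f"Checkpoint written to {checkpoint_path}")

# Every rank waits here until rank 0 has finished saving
dist.barrier()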

Next Steps

Additional Resources
