This page covers the essential APIs for building multi-GPU applications with fal.distributed. We focus on the methods you’ll actually use in your code.
DistributedRunner
The DistributedRunner class orchestrates multiple GPU workers for distributed computation. It handles process management, inter-process communication via ZMQ, and coordination between worker processes.
Constructor
DistributedRunner(
    worker_cls: type[DistributedWorker],
    world_size: int,
)
Parameters:
worker_cls (type[DistributedWorker]): Your custom worker class that inherits from DistributedWorker.
world_size (int): Total number of worker processes to spawn (typically equals num_gpus).
Example:
from fal.distributed import DistributedRunner, DistributedWorker

class MyWorker(DistributedWorker):
    def setup(self, **kwargs):
        self.model = load_model().to(self.device)

    def __call__(self, prompt: str, **kwargs):
        return self.model.generate(prompt)

# Create runner for 4 GPUs
runner = DistributedRunner(
    worker_cls=MyWorker,
    world_size=4,
)
start()
Starts all distributed worker processes and initializes them.
async def start(
    self,
    timeout: int = 1800,
    **kwargs: Any
) -> None
Parameters:
timeout (int): Maximum time (in seconds) to wait for all workers to be ready. Default: 1800 (30 minutes).
**kwargs: Additional keyword arguments passed to each worker’s setup() method.
Raises:
RuntimeError: If processes are already running or fail to start.
TimeoutError: If workers don’t become ready within the timeout period.
Example:
class MyApp(fal.App):
    num_gpus = 2

    async def setup(self):
        self.runner = DistributedRunner(
            worker_cls=MyWorker,
            world_size=self.num_gpus,
        )
        # Start workers and pass model path to their setup()
        await self.runner.start(model_path="/data/models/flux")
        # Workers are now ready to process requests
What it does:
- Spawns world_size worker processes (one per GPU)
- Each worker runs its setup() method with the provided **kwargs
- Waits for all workers to signal “READY”
- Starts the keepalive timer if configured
- Returns when all workers are initialized and ready
This method must be called before using invoke() or stream(). It’s typically called once in your app’s setup() method.
invoke()
Executes the worker’s __call__() method across all GPUs and returns the final result from rank 0.
async def invoke(
    self,
    payload: dict[str, Any] = {},
    timeout: int | None = None,
) -> Any
Parameters:
payload (dict[str, Any]): Dictionary of arguments to pass to each worker’s __call__() method. Default: {}.
timeout (int | None): Maximum time (in seconds) to wait for the result. If None, uses the runner’s default timeout. Default: None.
Returns:
Any: The result returned by the rank 0 worker’s __call__() method.
Raises:
RuntimeError: If workers are not running or encounter an error during execution.
TimeoutError: If the operation exceeds the timeout.
Example:
@fal.endpoint("/generate")
async def generate(self, request: GenerateRequest) -> GenerateResponse:
# Invoke workers to generate images
result = await self.runner.invoke({
"prompt": request.prompt,
"num_steps": request.num_steps,
"width": 1024,
"height": 1024,
})
return GenerateResponse(image=result["image"])
How it works:
- Serializes the payload and sends it to all workers
- Each worker executes its __call__() method with streaming=False
- Workers coordinate using PyTorch distributed operations (e.g., dist.gather())
- Only rank 0 returns the result
- Result is deserialized and returned to the caller
Use invoke() for standard (non-streaming) requests where you need the final result only.
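If a single request might run long, you can also pass the per-call timeout documented above. A minimal sketch, assuming the same payload keys as the example endpoint:

try:
    result = await self.runner.invoke(
        {"prompt": request.prompt, "num_steps": request.num_steps},
        timeout=120,  # give up after 2 minutes instead of the runner default
    )
except TimeoutError:
    # Handle the slow request, e.g. return an error response to the client
    ...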
stream()
Streams intermediate results from workers during execution, useful for long-running operations like image generation or training.
async def stream(
    self,
    payload: dict[str, Any] = {},
    timeout: int | None = None,
    streaming_timeout: int | None = None,
    as_text_events: bool = False,
) -> AsyncIterator[Any]
Parameters:
payload (dict[str, Any]): Dictionary of arguments to pass to each worker’s __call__() method. Default: {}.
timeout (int | None): Maximum total time (in seconds) for the entire operation. Default: None (no limit).
streaming_timeout (int | None): Maximum time (in seconds) between consecutive yields. If no data is received within this period, raises TimeoutError. Default: None.
as_text_events (bool): If True, yields Server-Sent Events (SSE) formatted as bytes. If False, yields deserialized Python objects. Default: False.
Returns:
AsyncIterator[Any]: Async iterator yielding intermediate results and the final result.
Raises:
RuntimeError: If workers are not running, encounter an error, or yield no data.
TimeoutError: If the operation exceeds timeout or streaming_timeout.
Example:
@fal.endpoint("/stream")
async def stream_generate(self, request: GenerateRequest) -> StreamingResponse:
return StreamingResponse(
self.runner.stream(
payload={
"prompt": request.prompt,
"num_steps": request.num_steps,
},
as_text_events=True,
),
media_type="text/event-stream",
)
How it works:
- Serializes the payload and sends it to all workers
- Each worker executes its __call__() method with streaming=True
- Workers can call self.add_streaming_result() to send intermediate updates
- The runner yields each intermediate result as it’s received
- After workers finish, yields the final result
- Automatically handles serialization based on as_text_events
Set as_text_events=True when using with StreamingResponse for browser-compatible Server-Sent Events.
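If you want to post-process updates on the server instead of proxying SSE, leave as_text_events at its default and iterate over the deserialized objects directly. A minimal sketch, assuming the payload keys and the shape of each event:

async for event in self.runner.stream(
    payload={"prompt": request.prompt, "num_steps": request.num_steps},
    streaming_timeout=60,  # fail if no update arrives for 60 seconds
):
    # Each event is a deserialized Python object yielded by the workers
    print(event)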
DistributedWorker
The DistributedWorker class is the base class for your custom GPU workers. Each instance runs on a separate GPU and handles model loading, inference, or training.
Create your own worker by inheriting from DistributedWorker and overriding the setup() and __call__() methods.
class MyWorker(DistributedWorker):
    def setup(self, **kwargs):
        # Load model on this GPU
        self.model = load_model().to(self.device)

    def __call__(self, prompt: str, **kwargs):
        # Process request
        return self.model.generate(prompt)
Properties
device
Returns the CUDA device assigned to this worker.
@property
def device(self) -> torch.device
Returns:
torch.device: The PyTorch device for this worker, e.g., cuda:0, cuda:1, etc.
Example:
class MyWorker(DistributedWorker):
    def setup(self):
        # Load model on this worker's GPU
        self.model = MyModel().to(self.device)
        print(f"Model loaded on {self.device}")
rank
The rank (ID) of this worker, from 0 to world_size-1.
Example:
if self.rank == 0:
    print("I'm the master worker!")
    # Only rank 0 saves checkpoints, uploads files, etc.
world_size
Total number of workers in the distributed setup.
Example:
print(f"Running with {self.world_size} GPUs")
Methods to Override
setup()
Called once when the worker is initialized. Use this to load models, download weights, and prepare resources.
def setup(self, **kwargs: Any) -> None
Parameters:
**kwargs: Any keyword arguments passed to runner.start().
Example:
class FluxWorker(DistributedWorker):
    def setup(self, model_path: str = "/data/flux", **kwargs):
        """Initialize the Flux model on this GPU"""
        import torch
        from diffusers import FluxPipeline

        self.rank_print(f"Loading Flux on {self.device}")
        self.pipeline = FluxPipeline.from_pretrained(
            model_path,
            torch_dtype=torch.bfloat16,
        ).to(self.device)

        # Disable progress bar for non-main workers
        if self.rank != 0:
            self.pipeline.set_progress_bar_config(disable=True)

        self.rank_print("Model loaded successfully")
Heavy operations like model loading should go in setup(), not __call__(), so they only happen once per worker.
__call__()
Called for each request. Implement your main processing logic here.
def __call__(self, streaming: bool = False, **kwargs: Any) -> Any
Parameters:
streaming (bool): True if called via runner.stream(), False if called via runner.invoke().
**kwargs: Arguments from the payload dict passed to runner.invoke() or runner.stream().
Returns:
Any: The result to return. Only rank 0’s return value is sent back to the caller.
Example:
class FluxWorker(DistributedWorker):
    def __call__(
        self,
        prompt: str,
        num_steps: int = 20,
        streaming: bool = False,
        **kwargs
    ) -> dict:
        """Generate an image on this GPU"""
        import torch
        import torch.distributed as dist

        # Each GPU generates independently
        image = self.pipeline(
            prompt=prompt,
            num_inference_steps=num_steps,
            output_type="pt",
        ).images[0]

        # Gather all images to rank 0
        if self.rank == 0:
            gather_list = [
                torch.zeros_like(image, device=self.device)
                for _ in range(self.world_size)
            ]
        else:
            gather_list = None
        dist.gather(image, gather_list, dst=0)

        # Only rank 0 returns the result
        if self.rank == 0:
            combined_image = create_grid(gather_list)
            return {"image": combined_image}
        return {}  # Other ranks return empty dict
Utility Methods
add_streaming_result()
Sends an intermediate result to the client during streaming.
def add_streaming_result(
    self,
    result: Any,
    image_format: str = "jpeg",
    as_text_event: bool = False,
) -> None
Parameters:
result (Any): The data to stream. Can be a dict, PIL image, or any serializable object.
image_format (str): Image format for PIL images ("jpeg" or "png"). Default: "jpeg".
as_text_event (bool): If True, formats as a Server-Sent Event. Must match the as_text_events parameter in runner.stream(). Default: False.
Example:
def __call__(self, prompt: str, num_steps: int = 20, streaming: bool = False):
    for step in range(num_steps):
        # Generate intermediate result
        latent = self.model.step(prompt)

        # Stream progress (only rank 0)
        if streaming and self.rank == 0 and step % 5 == 0:
            preview_image = self.decode_latent(latent)
            self.add_streaming_result({
                "step": step,
                "progress": (step + 1) / num_steps,
                "preview": preview_image,
            }, as_text_event=True)

    # Return final result
    return {"image": final_image}
Only call add_streaming_result() from rank 0 to avoid duplicate messages to the client.
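Because result can be a PIL image, you can also stream a preview image directly without wrapping it in a dict. A minimal sketch, where preview is an assumed PIL.Image:

if streaming and self.rank == 0:
    # preview is an assumed PIL image; the runner serializes it in the requested format
    self.add_streaming_result(preview, image_format="png", as_text_event=True)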
rank_print()
Prints a message with the worker’s rank prefix for easy debugging.
def rank_print(self, message: str, debug: bool = False) -> None
Parameters:
message (str): The message to print.
debug (bool): If True, prefixes with [debug]. Default: False.
Example:
self.rank_print("Starting generation...")
# Output: [rank 0] Starting generation...
self.rank_print("Model loaded", debug=True)
# Output: [debug] [rank 0] Model loaded
Common Patterns
Pattern 1: Data Parallelism (Inference)
Each GPU processes different data independently:
class ParallelWorker(DistributedWorker):
    def __call__(self, prompt: str, **kwargs):
        import torch
        import torch.distributed as dist

        # Each GPU generates independently with different seed
        result = self.model.generate(prompt)

        # Gather all results to rank 0
        if self.rank == 0:
            gather_list = [torch.zeros_like(result) for _ in range(self.world_size)]
        else:
            gather_list = None
        dist.gather(result, gather_list, dst=0)

        # Only rank 0 returns combined result
        if self.rank == 0:
            return {"outputs": gather_list}
        return {}
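The comment in the pattern above assumes each GPU uses a different seed; one simple way to get that is to offset a base seed by the worker's rank before generating. A sketch, where base_seed is an assumed parameter:

import torch

# base_seed is an assumed argument; offsetting by rank gives each GPU a distinct seed
torch.manual_seed(base_seed + self.rank)
result = self.model.generate(prompt)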
Pattern 2: Distributed Data Parallel (Training)
All GPUs have the same model, process different batches, and sync gradients:
class DDPWorker(DistributedWorker):
    def setup(self, **kwargs):
        import torch
        from torch.nn.parallel import DistributedDataParallel as DDP

        self.model = MyModel().to(self.device)
        # Wrap with DDP for gradient synchronization
        self.model = DDP(
            self.model,
            device_ids=[self.rank],
            output_device=self.rank,
        )
        self.optimizer = torch.optim.Adam(self.model.parameters())

    def __call__(self, data_path: str, **kwargs):
        import torch
        import torch.distributed as dist

        # Load data on rank 0 only
        if self.rank == 0:
            data = load_data(data_path)
        else:
            data = None

        # Broadcast to all ranks (broadcast_object_list fills the list in place)
        objects = [data]
        dist.broadcast_object_list(objects, src=0)
        data = objects[0]

        # Each GPU processes a different slice of the data
        local_batch = data[self.rank::self.world_size]

        # Training loop
        for batch in local_batch:
            loss = self.model(batch)
            loss.backward()  # DDP syncs gradients automatically
            self.optimizer.step()
            self.optimizer.zero_grad()

        # Only rank 0 saves checkpoint
        if self.rank == 0:
            torch.save(self.model.state_dict(), "checkpoint.pt")
            return {"checkpoint": "checkpoint.pt"}
        return {}
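On the app side, a training worker like this is typically driven with invoke() and a generous timeout. A sketch, assuming hypothetical TrainRequest and TrainResponse models:

@fal.endpoint("/train")
async def train(self, request: TrainRequest) -> TrainResponse:
    # TrainRequest/TrainResponse are assumed request/response models
    result = await self.runner.invoke(
        {"data_path": request.data_path},
        timeout=3600,  # training runs can take much longer than inference
    )
    return TrainResponse(checkpoint=result["checkpoint"])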
Pattern 3: Streaming with Progress Updates
Stream intermediate results during long-running operations:
class StreamingWorker(DistributedWorker):
    def __call__(self, prompt: str, steps: int = 50, streaming: bool = False):
        import torch.distributed as dist

        for step in range(steps):
            result = self.model.step(prompt)

            # Stream progress every 5 steps
            if streaming and self.rank == 0 and step % 5 == 0:
                self.add_streaming_result({
                    "step": step,
                    "progress": step / steps,
                }, as_text_event=True)

            # Sync all workers
            dist.barrier()

        # Return final result
        if self.rank == 0:
            return {"output": result}
        return {}
Next Steps