🚀 Try this Example
View the complete source code on GitHub.
# Only import Python's bundled packages, fal, and fastapi in the global scope; app-specific imports must go inside a function
import math
from typing import Literal
import fal
from fal.toolkit.image import ImageSizeInput, Image, ImageSize, get_image_size
from fal.toolkit.image.safety_checker import postprocess_images
from fastapi import Response
from pydantic import Field, BaseModel
# Base output model; it can be reused across image endpoints
class Output(BaseModel):
images: list[Image] = Field(description="The generated image files info.")
seed: int = Field(
description="""
            Seed of the generated image. It is either the value passed in the input or,
            if none was passed, the randomly generated seed that was used.
"""
)
has_nsfw_concepts: list[bool] = Field(
description="Whether the generated images contain NSFW concepts."
)
prompt: str = Field(
description="The prompt used for generating the image.",
)
# The input model for the inference request. Make sure to set a title and description for each field,
# and set examples for the required fields
class BaseInput(BaseModel):
prompt: str = Field(
title="Prompt",
description="The prompt to generate an image from.",
# Set the example to show it on the playground
examples=[
"Underwater coral reef ecosystem during peak bioluminescent activity, multiple layers of marine life - from microscopic plankton to massive coral structures, light refracting through crystal-clear tropical waters, creating prismatic color gradients, hyper-detailed texture of marine organisms",
],
)
negative_prompt: str = Field(
default="",
description="""
The negative prompt to use. Use it to address details that you don't want
in the image. This could be colors, objects, scenery and even the small details
(e.g. moustache, blurry, low resolution).
""",
examples=[
"",
],
)
    # Use ImageSizeInput to set the image size; it will be converted to ImageSize
image_size: ImageSizeInput = Field(
default=ImageSize(width=3840, height=2160),
description="The size of the generated image.",
)
num_inference_steps: int = Field(
default=18,
description="The number of inference steps to perform.",
        # Set the min and max values whenever possible to limit the input range
ge=1,
le=50,
)
seed: int | None = Field(
default=None,
description="""
The same seed and the same prompt given to the same version of the model
will output the same image every time.
""",
)
guidance_scale: float = Field(
default=5.0,
description="""
The CFG (Classifier Free Guidance) scale is a measure of how close you want
the model to stick to your prompt when looking for a related image to show you.
""",
ge=0.0,
le=20.0,
title="Guidance scale (CFG)",
)
enable_safety_checker: bool = Field(
default=True,
description="If set to true, the safety checker will be enabled.",
)
num_images: int = Field(
default=1,
description="The number of images to generate.",
ge=1,
le=4,
)
output_format: Literal["jpeg", "png"] = Field(
default="jpeg",
description="The format of the generated image.",
)
# For the base endpoint
class TextToImageInput(BaseInput):
pass
# For the sprint endpoint, we can reuse the base input model and override the fields that we want to change
class SprintInput(BaseInput):
num_inference_steps: int = Field(
default=2,
description="The number of inference steps to perform.",
ge=1,
le=20,
)
class SanaOutput(Output):
images: list[Image] = Field(
description="The generated image files info.",
# Set default examples to show a generated image when the user visits the playground
examples=[
[
Image(
url="https://fal.media/files/kangaroo/QAABS8yM6X99WhiMeLcoL.jpeg",
width=3840,
height=2160,
content_type="image/jpeg",
)
],
],
)
class SanaSprintOutput(Output):
images: list[Image] = Field(
description="The generated image files info.",
# Set default examples to show a generated image when the user visits the playground
examples=[
[
Image(
url="https://v3.fal.media/files/penguin/Q-i_zCk-Xf5EggWA9OmG2_e753bacc9b324050855f9664deda3618.jpg",
width=3840,
height=2160,
content_type="image/jpeg",
)
],
],
)
class Sana(
fal.App,
keep_alive=600, # The worker will be kept alive for 10 minutes after the last request
    min_concurrency=0, # The minimum number of concurrent workers to keep alive; if set to 0, the app starts up when the first request is received
    max_concurrency=10, # The maximum number of concurrent workers to acquire; it limits the number of concurrent requests to the app
name="sana", # set the app name, the endpoint will be served at username/sana
):
"""
Specify requirements as follows and make sure to pin the versions of packages and commit hashes to ensure reliability.
"""
requirements = [
"torch==2.6.0",
"accelerate==1.6.0",
"transformers==4.51.3",
"git+https://github.com/huggingface/diffusers.git@f4fa3beee7f49b80ce7a58f9c8002f43299175c9",
"hf_transfer==0.1.9",
"peft==0.15.0",
"sentencepiece==0.2.0",
"--extra-index-url",
"https://download.pytorch.org/whl/cu124",
]
machine_type = "GPU-H100" # Choose machine type from https://docs.fal.ai/private-serverless-models/resources/
def setup(self):
"""
This method is called once when the app is started. Use it to load your model and cache it for all requests.
"""
        # Import the libraries inside the setup method since they are installed in the worker environment as specified in the requirements
import torch
from diffusers import SanaPipeline, SanaSprintPipeline
self.pipes = {}
self.pipes["base"] = SanaPipeline.from_pretrained(
"Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
torch_dtype=torch.bfloat16,
).to("cuda")
self.pipes["base"].text_encoder.to(torch.bfloat16)
self.pipes["sprint"] = SanaSprintPipeline.from_pretrained(
"Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
text_encoder=self.pipes["base"].text_encoder, # Reuse the text encoder from the base pipeline
torch_dtype=torch.bfloat16,
).to("cuda")
async def _generate(
self,
input: TextToImageInput,
response: Response,
model_id: str,
) -> Output:
"""
Reuse the Inference code for both endpoints. Both the base and sprint pipelines have very similar inference code.
"""
import torch
# Preprocess the input
image_size = get_image_size(
input.image_size,
)
seed = input.seed or torch.seed()
generator = torch.Generator("cuda").manual_seed(seed)
model_input = {
"prompt": input.prompt,
"negative_prompt": input.negative_prompt,
"num_inference_steps": input.num_inference_steps,
"guidance_scale": input.guidance_scale,
"height": image_size.height,
"width": image_size.width,
"num_images_per_prompt": input.num_images,
"generator": generator,
}
if model_id == "sprint":
# Negative prompt is not supported in the sprint pipeline
model_input.pop("negative_prompt")
# Generate the images
images = self.pipes[model_id](**model_input).images
# Perform the safety check
postprocessed_images = postprocess_images(
images,
input.enable_safety_checker,
)
# Pricing
resolution_factor = math.ceil(
(image_size.width * image_size.height) / (1024 * 1024)
)
# The number of billable units is the resolution factor multiplied by the number of images
response.headers["x-fal-billable-units"] = str(
resolution_factor * input.num_images
)
        # The cost per unit is set in the billing dashboard; the total charge is the number of billable units multiplied by the cost per unit
return Output(
images=[
Image.from_pil(image, input.output_format)
for image in postprocessed_images["images"]
],
seed=seed,
has_nsfw_concepts=postprocessed_images["has_nsfw_concepts"],
prompt=input.prompt,
)
@fal.endpoint("/")
async def generate(
self,
        input: TextToImageInput, # This will be used to auto-generate the OpenAPI spec and the playground form
        response: Response, # This is the response object used to set the billing-units header
    ) -> SanaOutput: # The output model is used to auto-generate the OpenAPI spec
return await self._generate(input, response, "base")
@fal.endpoint("/sprint")
async def generate_sprint(
self,
        input: SprintInput, # Use a different input class for the sprint endpoint to override defaults such as the number of inference steps
response: Response,
) -> SanaSprintOutput:
return await self._generate(input, response, "sprint")
# Run the app with:
# cd fal_demos/image
# fal run sana
#
# Or directly with:
# fal run fal_demos/image/sana.py::Sana
#
# The app will be served on an ephemeral URL, example: https://fal.ai/dashboard/sdk/fal-ai/9fe9b6fc-534d-4926-95b1-87b7f15a67de
# Visit https://fal.ai/dashboard/sdk/fal-ai/9fe9b6fc-534d-4926-95b1-87b7f15a67de to test the root endpoint
# To test the sprint endpoint, visit https://fal.ai/dashboard/sdk/fal-ai/9fe9b6fc-534d-4926-95b1-87b7f15a67de/sprint
Or clone this repository:
git clone https://github.com/fal-ai-community/fal-demos.git
cd fal-demos
pip install -e .
# Use the app name (defined in pyproject.toml)
fal run sana
# Or use the full file path:
# fal run fal_demos/image/sana.py::Sana
Before you run, make sure you have:
- Authenticated with fal: fal auth login
- Activated your virtual environment (recommended): python -m venv venv && source venv/bin/activate (macOS/Linux) or venv\Scripts\activate (Windows)
Key Features
- Multiple endpoints with shared model components
- Input validation using Pydantic models
- Safety checking for generated content
- Flexible image generation with customizable parameters
Project Structure
# Only import Python's bundled packages, fal, and fastapi in the global scope
import math
from typing import Literal
import fal
from fal.toolkit.image import ImageSizeInput, Image, ImageSize, get_image_size
from fal.toolkit.image.safety_checker import postprocess_images
from fastapi import Response
from pydantic import Field, BaseModel
from fal_demos.image.common import Output
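The Output model imported from fal_demos.image.common is the same base output model shown in the full listing above; a minimal sketch of that shared module:

# fal_demos/image/common.py -- shared base output model for image endpoints
from fal.toolkit.image import Image
from pydantic import BaseModel, Field

class Output(BaseModel):
    images: list[Image] = Field(description="The generated image files info.")
    seed: int = Field(
        description="Seed of the generated image: the value passed in the input, or the randomly generated seed used when none was passed."
    )
    has_nsfw_concepts: list[bool] = Field(
        description="Whether the generated images contain NSFW concepts."
    )
    prompt: str = Field(description="The prompt used for generating the image.")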
Input Models
Define your input schemas with clear documentation and examples:
class BaseInput(BaseModel):
prompt: str = Field(
title="Prompt",
description="The prompt to generate an image from.",
examples=[
"Underwater coral reef ecosystem during peak bioluminescent activity, multiple layers of marine life - from microscopic plankton to massive coral structures, light refracting through crystal-clear tropical waters, creating prismatic color gradients, hyper-detailed texture of marine organisms",
],
)
negative_prompt: str = Field(
default="",
description="""
The negative prompt to use. Use it to address details that you don't want
in the image. This could be colors, objects, scenery and even the small details
(e.g. moustache, blurry, low resolution).
""",
examples=[""],
)
image_size: ImageSizeInput = Field(
default=ImageSize(width=3840, height=2160),
description="The size of the generated image.",
)
num_inference_steps: int = Field(
default=18,
description="The number of inference steps to perform.",
ge=1,
le=50,
)
seed: int | None = Field(
default=None,
description="""
The same seed and the same prompt given to the same version of the model
will output the same image every time.
""",
)
guidance_scale: float = Field(
default=5.0,
description="""
The CFG (Classifier Free Guidance) scale is a measure of how close you want
the model to stick to your prompt when looking for a related image to show you.
""",
ge=0.0,
le=20.0,
title="Guidance scale (CFG)",
)
enable_safety_checker: bool = Field(
default=True,
description="If set to true, the safety checker will be enabled.",
)
num_images: int = Field(
default=1,
description="The number of images to generate.",
ge=1,
le=4,
)
output_format: Literal["jpeg", "png"] = Field(
default="jpeg",
description="The format of the generated image.",
)
# Endpoint-specific input models
class TextToImageInput(BaseInput):
pass
class SprintInput(BaseInput):
# Override settings for the faster endpoint
num_inference_steps: int = Field(
default=2,
description="The number of inference steps to perform.",
ge=1,
le=20,
)
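Because the constraints live on the Pydantic fields, invalid requests are rejected before any GPU work happens. A quick local sanity check (hypothetical values, not part of the app):

from pydantic import ValidationError

# Defaults fill in everything except the required prompt.
ok = BaseInput(prompt="a lighthouse at dusk")
print(ok.num_inference_steps, ok.guidance_scale)  # 18 5.0

# Values outside the declared bounds fail validation up front (le=50 on num_inference_steps).
try:
    BaseInput(prompt="a lighthouse at dusk", num_inference_steps=100)
except ValidationError as err:
    print(err)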
Output Models
Create output models with example data for the playground:
class SanaOutput(Output):
images: list[Image] = Field(
description="The generated image files info.",
examples=[
[
Image(
url="https://fal.media/files/kangaroo/QAABS8yM6X99WhiMeLcoL.jpeg",
width=3840,
height=2160,
content_type="image/jpeg",
)
],
],
)
class SanaSprintOutput(Output):
images: list[Image] = Field(
description="The generated image files info.",
examples=[
[
Image(
url="https://v3.fal.media/files/penguin/Q-i_zCk-Xf5EggWA9OmG2_e753bacc9b324050855f9664deda3618.jpg",
width=3840,
height=2160,
content_type="image/jpeg",
)
],
],
)
Main Application Class
class Sana(
    fal.App,
    keep_alive=600,  # Keep the worker alive for 10 minutes after the last request
    min_concurrency=0,  # Scale to zero when idle
    max_concurrency=10,  # Limit concurrent workers
    name="sana",  # Endpoint served at username/sana
):
requirements = [
"torch==2.6.0",
"accelerate==1.6.0",
"transformers==4.51.3",
"git+https://github.com/huggingface/diffusers.git@f4fa3beee7f49b80ce7a58f9c8002f43299175c9",
"hf_transfer==0.1.9",
"peft==0.15.0",
"sentencepiece==0.2.0",
"--extra-index-url",
"https://download.pytorch.org/whl/cu124",
]
local_python_modules = [
"fal_demos",
]
machine_type = "GPU-H100"
def setup(self):
"""Load and cache models for all requests."""
import torch
from diffusers import SanaPipeline, SanaSprintPipeline
self.pipes = {}
# Load base pipeline
self.pipes["base"] = SanaPipeline.from_pretrained(
"Efficient-Large-Model/Sana_1600M_1024px_BF16_diffusers",
torch_dtype=torch.bfloat16,
).to("cuda")
self.pipes["base"].text_encoder.to(torch.bfloat16)
# Load sprint pipeline, reusing text encoder for efficiency
self.pipes["sprint"] = SanaSprintPipeline.from_pretrained(
"Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",
text_encoder=self.pipes["base"].text_encoder, # Reuse component
torch_dtype=torch.bfloat16,
).to("cuda")
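The pinned hf_transfer package only speeds up weight downloads when huggingface_hub is told to use it; one option (an addition, not part of the original example) is to set the flag before the pipelines are loaded, for example at the top of setup():

import os

# Assumption: must be set before the first Hugging Face download is triggered.
os.environ["HF_HUB_ENABLE_HF_TRANSFER"] = "1"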
Shared Generation Logic
Create a reusable generation method for both endpoints:
async def _generate(
self,
input: TextToImageInput,
response: Response,
model_id: str,
) -> Output:
import torch
# Preprocess input
image_size = get_image_size(input.image_size)
seed = input.seed or torch.seed()
generator = torch.Generator("cuda").manual_seed(seed)
# Prepare model input
model_input = {
"prompt": input.prompt,
"negative_prompt": input.negative_prompt,
"num_inference_steps": input.num_inference_steps,
"guidance_scale": input.guidance_scale,
"height": image_size.height,
"width": image_size.width,
"num_images_per_prompt": input.num_images,
"generator": generator,
}
# Handle model-specific differences
if model_id == "sprint":
model_input.pop("negative_prompt") # Not supported in sprint
# Generate images
images = self.pipes[model_id](**model_input).images
# Apply safety checking
postprocessed_images = postprocess_images(
images,
input.enable_safety_checker,
)
# Calculate billing
resolution_factor = math.ceil(
(image_size.width * image_size.height) / (1024 * 1024)
)
response.headers["x-fal-billable-units"] = str(
resolution_factor * input.num_images
)
return Output(
images=[
Image.from_pil(image, input.output_format)
for image in postprocessed_images["images"]
],
seed=seed,
has_nsfw_concepts=postprocessed_images["has_nsfw_concepts"],
prompt=input.prompt,
)
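As a concrete example of the billing math: the default 3840x2160 output covers 8,294,400 pixels, so the resolution factor is ceil(8294400 / 1048576) = 8, and a request for two images reports 16 billable units:

import math

width, height, num_images = 3840, 2160, 2  # default UHD size, two images requested
resolution_factor = math.ceil((width * height) / (1024 * 1024))  # ceil(7.91) == 8
print(resolution_factor * num_images)  # 16 billable units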
Endpoint Definitions
Define multiple endpoints using the shared generation logic:
@fal.endpoint("/")
async def generate(
self,
input: TextToImageInput,
response: Response,
) -> SanaOutput:
return await self._generate(input, response, "base")
@fal.endpoint("/sprint")
async def generate_sprint(
self,
input: SprintInput,
response: Response,
) -> SanaSprintOutput:
return await self._generate(input, response, "sprint")
Running the Application
Development
fal run fal_demos/image/sana.py::Sana
Using pyproject.toml
Add to your pyproject.toml:
[tool.fal.apps]
sana = "fal_demos.image.sana:Sana"
fal run sana
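Deployment
fal run serves the app on an ephemeral URL for development. For a persistent endpoint, the fal CLI also provides a deploy command (flags may vary by CLI version); for example:
fal deploy sana
# Or, with the full path:
fal deploy fal_demos/image/sana.py::Sana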
Testing Your Endpoints
Once deployed, your app will be available at URLs like:
- Base endpoint: https://fal.ai/dashboard/sdk/username/app-id
- Sprint endpoint: https://fal.ai/dashboard/sdk/username/app-id/sprint
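You can also call the endpoints programmatically. A sketch using the fal Python client (assumes pip install fal-client and a FAL_KEY credential; replace username/sana with your own app reference):

import fal_client

# Assumption: "username/sana" is the app reference printed by fal run / fal deploy.
result = fal_client.subscribe(
    "username/sana",
    arguments={
        "prompt": "Underwater coral reef ecosystem during peak bioluminescent activity",
        "num_images": 1,
    },
)
print(result["images"][0]["url"])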
Best Practices Demonstrated
- Resource Sharing: The text encoder is shared between pipelines to save memory
- Input Validation: Comprehensive Pydantic models with examples and constraints
- Error Handling: Safety checking and proper response formatting
- Billing Integration: Resolution-based pricing
- Endpoint Flexibility: Different configurations for different use cases
- Documentation: Rich field descriptions and examples for auto-generated docs
Key Takeaways
- Use setup() to load models once and cache them
- Share components between models when possible to optimize memory
- Create endpoint-specific input models for different use cases
- Implement proper billing with x-fal-billable-units
- Use the fal image toolkit for safety checking and processing
- Pin all dependency versions for reliability