Performance and parallelization patterns for Hamilton including async I/O, Spark, Ray, Dask, caching, and multithreading. Use for scaling Hamilton workflows.
This skill covers parallelization strategies and performance optimization for Apache Hamilton workflows.
Choose your scaling strategy based on workload:
When to Use:
Key Benefits:
Basic Async Pattern:
from hamilton import async_driver
import aiohttp
from typing import List
# Mix async and sync functions
async def api_data(user_id: str) -> dict:
    """Fetch a user record from the example REST API (async I/O).

    :param user_id: identifier interpolated into the request URL.
    :return: the decoded JSON response body as a dict.
    """
    # A fresh session per call keeps the example self-contained; production
    # code would typically share one ClientSession across requests.
    async with aiohttp.ClientSession() as session:
        async with session.get(f"https://api.example.com/users/{user_id}") as resp:
            return await resp.json()
def transform_data(api_data: dict) -> dict:
    """Transform data (sync CPU): upper-case every string value, pass the rest through."""
    transformed = {}
    for key, value in api_data.items():
        if isinstance(value, str):
            transformed[key] = value.upper()
        else:
            transformed[key] = value
    return transformed
async def save_data(transform_data: dict) -> str:
    """Save the transformed record to the database (async I/O).

    :param transform_data: output of the ``transform_data`` node.
    :return: the literal string "success" once the insert completes.

    NOTE(review): relies on a module-level ``db_pool`` connection pool that is
    not defined in this snippet — assumed to be an asyncpg-style pool; confirm.
    """
    async with db_pool.acquire() as conn:
        await conn.execute("INSERT INTO data VALUES ($1)", transform_data)
    return "success"
# Use AsyncDriver
import my_async_module

# Building must itself be awaited: the async builder finishes driver
# initialization on the running event loop before first use.
dr = await async_driver.Builder().with_modules(my_async_module).build()
result = await dr.execute(['save_data'], inputs={'user_id': '123'})
Parallel I/O Operations:
# These three operations execute in parallel automatically!
async def user_data(user_id: str) -> dict:
    """Fetch user data from the users endpoint (async I/O).

    Runs concurrently with user_orders and user_preferences: none of the
    three reads another's output, so the async driver can overlap them.
    """
    async with aiohttp.ClientSession() as session:
        async with session.get(f"https://api.example.com/users/{user_id}") as resp:
            return await resp.json()
async def user_orders(user_id: str) -> List[dict]:
    """Fetch the user's orders (async I/O; runs in parallel with user_data)."""
    async with aiohttp.ClientSession() as session:
        async with session.get(f"https://api.example.com/orders?user={user_id}") as resp:
            return await resp.json()
async def user_preferences(user_id: str) -> dict:
    """Fetch preferences from the database (async I/O; parallel with both fetches above).

    NOTE(review): depends on a module-level ``db_pool`` not shown in this
    snippet — assumed to be an asyncpg-style pool; confirm.
    """
    async with db_pool.acquire() as conn:
        return await conn.fetchrow("SELECT * FROM preferences WHERE user_id=$1", user_id)
def user_profile(user_data: dict, user_orders: List[dict], user_preferences: dict) -> dict:
    """Assemble the combined profile (runs only after all three fetches finish)."""
    profile = dict(
        data=user_data,
        orders=user_orders,
        preferences=user_preferences,
    )
    return profile
For synchronous I/O-bound code (legacy APIs, blocking libraries):
from hamilton import driver
from hamilton.execution import executors

# Use multithreading for sync I/O operations: threads overlap blocking I/O
# even though only one holds the GIL at a time.
# NOTE(review): ``my_functions`` is the module holding the dataflow functions;
# it must be imported above — confirm.
dr = driver.Builder()\
    .with_modules(my_functions)\
    .with_local_executor(executors.MultiThreadingExecutor(max_tasks=10))\
    .build()

# Sync functions that do I/O will run concurrently
results = dr.execute(['final_output'], inputs={...})
When to Use Multithreading: synchronous I/O-bound functions (legacy APIs, blocking client libraries) where rewriting to async is not practical.
When to Use Spark:
Basic PySpark Pattern:
from pyspark.sql import DataFrame as SparkDataFrame, SparkSession
from hamilton.plugins import h_spark
def raw_data(spark_session: SparkSession) -> SparkDataFrame:
    """Load the source CSV into a Spark DataFrame.

    :param spark_session: active SparkSession, supplied as a driver input.
    """
    # header=True treats the first CSV row as column names.
    return spark_session.read.csv("data.csv", header=True)
def filtered_data(raw_data: SparkDataFrame) -> SparkDataFrame:
    """Keep only adult rows (age > 18), executed lazily on the Spark side."""
    is_adult = raw_data.age > 18
    return raw_data.filter(is_adult)
def aggregated_stats(filtered_data: SparkDataFrame) -> SparkDataFrame:
    """Count rows per country via a Spark groupBy aggregation."""
    by_country = filtered_data.groupBy("country")
    return by_country.count()
# Driver Setup
# NOTE(review): verify that ``SPARK_INPUT_CHECK`` exists in the installed
# hamilton.plugins.h_spark version before relying on this adapter.
dr = driver.Builder()\
    .with_modules(my_spark_functions)\
    .with_adapters(h_spark.SPARK_INPUT_CHECK)\
    .build()
# ``spark`` is the live SparkSession, injected as a plain input.
result = dr.execute(['aggregated_stats'], inputs={'spark_session': spark})
Column-Level Transformations with @with_columns:
from hamilton.plugins.h_spark import with_columns
import pandas as pd
# File: map_transforms.py
def normalized_amount(amount: pd.Series) -> pd.Series:
    """Z-score the amount column (pandas UDF): subtract the mean, divide by the std."""
    centered = amount - amount.mean()
    return centered / amount.std()
def amount_category(normalized_amount: pd.Series) -> pd.Series:
    """Bucket z-scores: low (<= -1), medium (-1, 1], high (> 1)."""
    edges = [-float('inf'), -1, 1, float('inf')]
    names = ['low', 'medium', 'high']
    return pd.cut(normalized_amount, bins=edges, labels=names)
# Main dataflow
# NOTE(review): ``map_transforms`` is the module object containing the
# pandas-level functions above; it must be imported in this module — confirm.
@with_columns(
    map_transforms,
    columns_to_pass=["amount"]
)
def enriched_data(raw_data: SparkDataFrame) -> SparkDataFrame:
    """Apply the pandas UDFs from map_transforms to the Spark DataFrame.

    The body returns its input unchanged; the decorator is expected to append
    the computed columns (normalized_amount, amount_category) to the frame.
    """
    return raw_data
Spark Best Practices:
Avoid calling .collect() until the final nodes; prefer @with_columns for row-wise (map) operations; develop and test against a small sample (e.g. .limit(1000)).
When to Use Ray:
from hamilton.plugins import h_ray
import ray

# Start a local Ray runtime (connect to a cluster by passing an address).
ray.init()
ray_executor = h_ray.RayGraphAdapter(result_builder={"base": dict})
# NOTE(review): this uses the legacy ``driver.Driver(config, module, adapter=...)``
# constructor rather than the ``driver.Builder()`` used elsewhere in this doc —
# confirm against the installed Hamilton version.
dr = driver.Driver({}, my_functions, adapter=ray_executor)
results = dr.execute(['large_computation'], inputs={...})
When to Use Dask:
from hamilton.plugins import h_dask
from dask import distributed

# Spin up a local Dask cluster and attach a client; swap LocalCluster for a
# remote scheduler address to scale out.
cluster = distributed.LocalCluster()
client = distributed.Client(cluster)

dask_executor = h_dask.DaskExecutor(client=client)
dr = driver.Builder()\
    .with_remote_executor(dask_executor)\
    .with_modules(my_functions)\
    .build()
When to Use Caching:
Basic Caching Setup:
from hamilton import driver

# Enable caching
dr = driver.Builder()\
    .with_modules(my_functions)\
    .with_cache()\
    .build()

# First execution: computes and caches
result1 = dr.execute(['final_output'], inputs={'data_path': 'data.csv'})
# Second execution: retrieves from cache (instant!)
result2 = dr.execute(['final_output'], inputs={'data_path': 'data.csv'})
Controlling Cache Behavior:
from hamilton.function_modifiers import cache
# Always recompute (for data loaders)
# Always recompute (for data loaders)
@cache(behavior="recompute")
def live_api_data(api_key: str) -> dict:
    """Always fetch fresh data; the recompute behavior bypasses cache reads."""
    # Local import keeps the dependency scoped to this node.
    import requests

    response = requests.get("https://api.example.com/data",
                            headers={"Authorization": api_key})
    return response.json()
# Never cache (for non-deterministic operations)
# Never cache (for non-deterministic operations)
@cache(behavior="disable")
def random_sample(data: pd.DataFrame) -> pd.DataFrame:
    """Return a fresh 10% random sample; caching would freeze one draw."""
    return data.sample(frac=0.1)
# Custom format for efficiency
# Custom format for efficiency
@cache(format="parquet")
def large_dataframe(processed_data: pd.DataFrame) -> pd.DataFrame:
    """Pass the frame through; the cache layer persists it as parquet."""
    return processed_data
Driver-Level Cache Control:
# Per-node overrides at build time, without touching function decorators.
dr = driver.Builder()\
    .with_modules(my_functions)\
    .with_cache(
        recompute=['raw_data'],     # Always recompute these
        disable=['random_sample'],  # Never cache these
        path="./my_cache"           # Custom location
    )\
    .build()

# Force complete refresh
dr = driver.Builder()\
    .with_modules(my_functions)\
    .with_cache(recompute=True)\
    .build()
Cache Inspection:
# Visualize what was cached vs executed (requires a prior dr.execute call)
dr.cache.view_run()  # Green = cache hit, Orange = executed

# Access cached results
run_id = dr.cache.last_run_id
data_version = dr.cache.data_versions[run_id]['my_node']
cached_result = dr.cache.result_store.get(data_version)
Decision Matrix:
| Workload Type | Strategy | Use When |
|---|---|---|
| I/O-bound (async-capable) | AsyncDriver | Multiple API calls, async libraries available |
| I/O-bound (sync only) | MultiThreading | Legacy APIs, blocking I/O |
| Large datasets | Spark | Multi-GB/TB data, cluster available |
| Python computation | Ray/Dask | Custom parallel logic, ML workloads |
| Expensive operations | Caching | Repeated computations, LLM calls |
Combining Strategies:
Related skills: /hamilton-core, /hamilton-llm