Train and run inference on one model per data partition using ManyModelTraining and ManyModelInference: auto-serialization, get_model(), and distributed inference.
Train separate ML models for each data partition in parallel, then run distributed inference using those models. MMT/MMI handles distributed orchestration, model serialization, and automatic model loading.
Load this skill when:
- User wants one model per partition with automatic serialization and the get_model() convenience

For Model Registry: If the user wants to register models in the Model Registry (versioning, SQL access), see ../../model-registry/partitioned-inference/SKILL.md for the @partitioned_api approach.
Ask user:
I'll help you train models per partition. Please confirm:
1. Training data table?
2. Partition column? (e.g., REGION, STORE_ID)
3. Target/label column?
4. Feature columns?
5. Model type? (XGBoost / LightGBM / sklearn / custom)
6. Stage for storing models?
⚠️ STOP: Wait for user response.
```python
from snowflake.ml.modeling.distributors.many_model import ManyModelTraining
from snowflake.ml.modeling.distributors.distributed_partition_function.entities import RunStatus

def train_model(data_connector, context):
    """
    Args:
        data_connector: Access partition data via .to_pandas()
        context: Partition info via context.partition_id

    Returns:
        Trained model (auto-serialized)
    """
    df = data_connector.to_pandas()
    print(f"Training for partition: {context.partition_id}")

    from xgboost import XGBRegressor
    model = XGBRegressor(n_estimators=100)
    model.fit(df[['feature1', 'feature2']], df['target'])
    return model  # Auto-serialized to stage

trainer = ManyModelTraining(train_model, "<STAGE_NAME>")
training_run = trainer.run(
    partition_by="<PARTITION_COLUMN>",
    snowpark_dataframe=session.table("<TABLE>"),
    run_id="<DESCRIPTIVE_RUN_ID>"
)

final_status = training_run.wait()
print(f"Training completed with status: {final_status}")
```
⚠️ STOP: After training completes, verify with user:
Training complete for [N] partitions.
Status: [SUCCESS/PARTIAL/FAILED]
Partition results:
- partition_1: SUCCESS
- partition_2: SUCCESS
...
Would you like to:
1. Retrieve models for use
2. Check failed partitions
3. Proceed to partitioned inference
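Option 2 above (checking failed partitions) can be sketched as a simple filter. The real per-partition entries in `partition_details` are richer objects whose exact attributes are not shown in this doc, so this sketch simulates them as plain status strings:

```python
# Hypothetical sketch: find partitions that did not finish successfully.
# `partition_details` is simulated here as a dict of partition_id -> status
# string; the real ManyModelRun exposes per-partition detail objects.
partition_details = {
    "EAST": "SUCCESS",
    "WEST": "FAILED",
    "NORTH": "SUCCESS",
}

failed = sorted(pid for pid, status in partition_details.items() if status != "SUCCESS")
print(failed)  # ['WEST']
```

Failed partitions can then be retried with a fresh run over just those partition values.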
```python
if final_status == RunStatus.SUCCESS:
    # Get models by partition
    for partition_id in training_run.partition_details:
        model = training_run.get_model(partition_id)
        print(f"Retrieved model for {partition_id}")

    # Or collect all into a dict
    models = {
        pid: training_run.get_model(pid)
        for pid in training_run.partition_details
    }
```
To restore a completed run later:
```python
from snowflake.ml.modeling.distributors.many_model import ManyModelRun

restored_run = ManyModelRun.restore_from("<RUN_ID>", "<STAGE_NAME>")
model = restored_run.get_model("<PARTITION_ID>")
```
After training with MMT, use ManyModelInference to run distributed inference across partitions. Models are automatically loaded from the training run.
```python
from snowflake.ml.modeling.distributors.many_model import ManyModelInference

def predict_with_model(data_connector, model, context):
    """
    Args:
        data_connector: Access partition data via .to_pandas()
        model: Pre-loaded model (auto-loaded from training run)
        context: Partition info via context.partition_id

    Returns:
        Prediction results
    """
    df = data_connector.to_pandas()
    print(f"Running inference for partition: {context.partition_id}")

    # Model is already loaded - just use it
    predictions = model.predict(df[['feature1', 'feature2']])

    results = df.copy()
    results['predictions'] = predictions

    # Save results to stage
    context.upload_to_stage(
        results,
        "predictions.csv",
        write_function=lambda df, path: df.to_csv(path, index=False)
    )
    return results

inference = ManyModelInference(
    predict_with_model,
    "<STAGE_NAME>",                      # Same stage as training
    training_run_id="<TRAINING_RUN_ID>"  # Run ID from Step 3
)

inference_run = inference.run(
    partition_by="<PARTITION_COLUMN>",   # Must match training
    snowpark_dataframe=session.table("<NEW_DATA_TABLE>"),
    run_id="<INFERENCE_RUN_ID>"
)

final_status = inference_run.wait()
print(f"Inference completed with status: {final_status}")
```
⚠️ STOP: After inference completes, verify with user:
Inference complete for [N] partitions.
Status: [SUCCESS/PARTIAL/FAILED]
Results stored in stage: @<STAGE_NAME>/<INFERENCE_RUN_ID>/
Would you like to:
1. Download prediction results
2. Check failed partitions
3. Run another inference batch
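The per-partition result layout implied above can be sketched as a small path helper. The exact nesting and the presence of a per-partition subfolder are assumptions for illustration, not a documented contract:

```python
# Hypothetical helper: build the stage path where a partition's artifacts are
# expected to land, following an @<STAGE>/<RUN_ID>/<PARTITION_ID>/ layout.
def partition_stage_path(stage: str, run_id: str, partition_id: str, filename: str) -> str:
    return f"@{stage}/{run_id}/{partition_id}/{filename}"

path = partition_stage_path("ML_STAGE", "inference_run_v1", "EAST", "predictions.csv")
print(path)  # @ML_STAGE/inference_run_v1/EAST/predictions.csv
```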
For large-scale inference, write results directly to a table:
```python
def predict_to_table(data_connector, model, context):
    df = data_connector.to_pandas()
    predictions = model.predict(df[['feature1', 'feature2']])

    results = df.copy()
    results['predictions'] = predictions
    results['partition_id'] = context.partition_id

    # Write to Snowflake table (uses bounded session pool)
    context.with_session(lambda session:
        session.create_dataframe(results)
               .write.mode("append")
               .save_as_table("MY_PREDICTIONS_TABLE")
    )
    return {"rows_written": len(results)}
```
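Because each partition returns a small summary dict (`{"rows_written": N}`), totals are easy to sanity-check once results are collected. The run object's result-access API isn't shown in this doc, so this sketch simulates collected return values as plain dicts:

```python
# Hypothetical: per-partition return values, simulated as plain dicts.
partition_results = {
    "EAST": {"rows_written": 1200},
    "WEST": {"rows_written": 950},
}

total = sum(r["rows_written"] for r in partition_results.values())
print(total)  # 2150
```

Comparing this total against a `COUNT(*)` on MY_PREDICTIONS_TABLE is a cheap end-to-end check.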
Use the same serde from training:
```python
from snowflake.ml.modeling.distributors.many_model import (
    ManyModelInference, PickleSerde, TorchSerde, TensorFlowSerde
)

# Default (pickle) - XGBoost, sklearn, LightGBM
inference = ManyModelInference(predict_func, "stage", "train_run_v1")

# PyTorch models
inference = ManyModelInference(predict_func, "stage", "train_run_v1", serde=TorchSerde())

# TensorFlow models
inference = ManyModelInference(predict_func, "stage", "train_run_v1", serde=TensorFlowSerde())

# Custom serde (must match training)
inference = ManyModelInference(predict_func, "stage", "train_run_v1", serde=CustomSerde())
```
Training (MMT):
- Models are auto-serialized to the stage; retrieve them via get_model(partition_id) or restore_from()

Inference (MMI):
- Results are stored under @<STAGE>/<RUN_ID>/<PARTITION_ID>/

Default serialization uses PickleSerde, which works for most sklearn-compatible models. For specific frameworks:
```python
from snowflake.ml.modeling.distributors.many_model import (
    ManyModelTraining, PickleSerde, TorchSerde, TensorFlowSerde, ModelSerde
)

# Default (pickle) - works for XGBoost, sklearn, LightGBM
trainer = ManyModelTraining(train_func, "stage")  # Uses PickleSerde()

# PyTorch models
trainer = ManyModelTraining(train_func, "stage", serde=TorchSerde())

# TensorFlow/Keras models
trainer = ManyModelTraining(train_func, "stage", serde=TensorFlowSerde())

# Custom serialization
class CustomSerde(ModelSerde):
    @property
    def filename(self) -> str:
        return "model.joblib"

    def write(self, model, file_path: str) -> None:
        import joblib
        joblib.dump(model, file_path)

    def read(self, file_path: str):
        import joblib
        return joblib.load(file_path)

trainer = ManyModelTraining(train_func, "stage", serde=CustomSerde())
```
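The write/read contract of a custom serde can be smoke-tested locally, without Snowflake. This standalone sketch mirrors the same three-member shape (filename, write, read) using only the standard library: pickle stands in for joblib, a plain dict stands in for a real model, and the class here is a stand-in rather than the library's ModelSerde base:

```python
import os
import pickle
import tempfile

class PickleFileSerde:
    """Stand-in mirroring the ModelSerde contract: filename, write, read."""

    @property
    def filename(self) -> str:
        return "model.pkl"

    def write(self, model, file_path: str) -> None:
        with open(file_path, "wb") as f:
            pickle.dump(model, f)

    def read(self, file_path: str):
        with open(file_path, "rb") as f:
            return pickle.load(f)

serde = PickleFileSerde()
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, serde.filename)
    serde.write({"coef": [0.5, 1.5]}, path)  # roughly what MMT does after train_func returns
    restored = serde.read(path)              # roughly what MMI does before calling predict_func

print(restored)  # {'coef': [0.5, 1.5]}
```

If the round-trip fails locally, it will fail in the distributed run too, so this is a cheap pre-flight check before submitting a job with a custom serde.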
See ../references/compute-pool-sizing.md for instance families and node count sizing. For ExecutionOptions, see ../dpf/SKILL.md. For monitoring and troubleshooting, see ../references/monitoring-troubleshooting.md.
ManyModelInference signature: The inference function takes 3 arguments (not 2 like the MMT training function):

```python
def predict_func(data_connector, model, context):  # model is auto-loaded and passed as the 2nd arg
```

ManyModelRun: Extends DPFRun (see ../dpf/SKILL.md) with get_model(partition_id) to load trained models.
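The two callback shapes can be checked side by side with `inspect`; the function names below are placeholders for your own callbacks:

```python
import inspect

def train_func(data_connector, context):           # MMT: 2 arguments
    ...

def predict_func(data_connector, model, context):  # MMI: 3 arguments, model injected 2nd
    ...

print(list(inspect.signature(train_func).parameters))    # ['data_connector', 'context']
print(list(inspect.signature(predict_func).parameters))  # ['data_connector', 'model', 'context']
```

Passing a 2-argument function to ManyModelInference is a common mistake; the extra `model` parameter is the giveaway.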
After MMT/MMI:
- ../../model-registry/partitioned-inference/SKILL.md for the @partitioned_api approach
- ../references/compute-pool-sizing.md
- ../references/monitoring-troubleshooting.md