Machine learning model training and inference in Go using GoMLX. It provides an abstraction to create vectorized computation graphs that can then be JIT-compiled (Just-In-Time) and executed very fast, with backends using XLA (for CPU/CUDA/TPU), Go and others. It includes a rich set of vector (tensor) operations on the graph, a rich ML library with various types of layers, support for training variables, optimizers, training loops, dataset iterators and more. Apply this skill when working on ML projects, or when needing very efficient vectorized (tensor) computations, like image processing, physics or chemistry simulations, etc.
Persona: You are a Go programmer and Machine Learning practitioner who needs to write, update, or code-review a machine learning or vectorized computation task.
Official Resources:
This skill is not exhaustive. Please refer to library documentation and code examples for more information.
go get -u github.com/gomlx/gomlx
DTypes and Shapes (github.com/gomlx/gomlx/pkg/core/dtypes and github.com/gomlx/gomlx/pkg/core/shapes):
dtypes define the underlying type of the data (e.g. dtypes.Float32, dtypes.Int64, dtypes.Bool).
shapes.Shape represents the multi-dimensional structure of a tensor, including its DType and its Dimensions (a slice of integers). Shapes are strictly checked during graph building.
Graph (github.com/gomlx/gomlx/pkg/core/graph): The Graph object is the container for computation nodes, represented as *Node objects. Each node represents an operation or a value.
An executor (graph.Exec or context.Exec) takes a graph-building function, JIT-compiles it, and provides methods to execute it.
Backends (github.com/gomlx/gomlx/backends): It abstracts backend engines to execute computations on devices (accelerators or the CPU itself).
One doesn't need to interact with a backend directly (except when implementing one); just pass it around and know that it exists.
Usually, one imports (import _ "github.com/gomlx/gomlx/backends/default") to include support for the default backends.
The end user can set the environment variable GOMLX_BACKEND to select a different backend at runtime, if they want.
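For example, selecting a backend via the environment variable might look like the following (the backend names and the binary name are illustrative assumptions; valid names depend on which backends are compiled into your program):

```shell
# Force the pure-Go backend at runtime (name assumed; check your installed backends).
GOMLX_BACKEND=go ./my_trainer

# Or explicitly request the XLA backend (name assumed).
GOMLX_BACKEND=xla ./my_trainer
```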
The default backend uses XLA, with GPU/TPU support when available.
Tensors (github.com/gomlx/gomlx/pkg/core/tensors): These represent actual values, which can have local storage or "on-device" (accelerator) storage.
Usually, they are only used as inputs and outputs of computations, or to save, load or print values. Most methods are about conversion or access to the underlying data (e.g., tensor.Value() returns a generic value, or tensor.Local().Copy() for moving back to CPU memory).
Context (github.com/gomlx/gomlx/pkg/ml/context): A container for stateful variables (like model weights) and hyperparameters. It has reference semantics.
Graph-building functions (github.com/gomlx/gomlx/pkg/core/graph) take *Node as inputs and outputs.
Errors are reported using github.com/pkg/errors. The use of exceptions (panics) is limited to building graph computations, not the other packages.
It is idiomatic to dot-import the graph package with import . "github.com/gomlx/gomlx/pkg/core/graph", and to move all graph computation building functions into their own .go file. See the graph package reference for a list of common functions and their PyTorch equivalents. Example:
import . "github.com/gomlx/gomlx/pkg/core/graph"
func EuclideanDistance(a, b *Node) *Node {
	return Sqrt(ReduceAllSum(Square(Sub(a, b))))
}
Every *Node has a shape (and dtype).
When the shape of a *Node is known or fixed, it's often described in a side comment, or asserted (with something like x.Shape().AssertDims(batchSize, embedDim)) to make the code easy to read. Inputs or outputs of functions that take a fixed shape should be documented in the function documentation.
The graph.Exec object: create it with graph.NewExec(backend, fn), where fn is the graph-building function. exec.Call(inputs...) executes the compiled graph, taking tensors.Tensor or standard Go values (slices of slices) and returning tensors.Tensor. Inputs don't need to be tensors.Tensor; they can be any value that can be converted automatically (so slices of slices).
Tensors (github.com/gomlx/gomlx/pkg/core/tensors) can be created locally (tensors.FromValue(...)) or directly on the backend device (which usually happens automatically for outputs of executions). Use tensors.FromValue(any) or tensors.FromShape(shape) to create tensors. Tensors can be donated when passed to an execution (e.g. in exec.Call(input1, input2)); the donated tensor's memory will be overwritten, so it shouldn't be used afterward.
Context (github.com/gomlx/gomlx/pkg/ml/context): Variables are organized in scopes (e.g. /model/layer1). You can enter sub-scopes via ctx.In("layer1"). context.Context has reference semantics and works as a "current scope path", and can cheaply be copied. Variables are created with ctx.VariableWithValue(name, value) or ctx.VariableWithShape(name, shape). Once created, they persist in the context and can be retrieved using ctx.InspectVariable(...). Hyperparameters are set with ctx.SetParam("key", value) and retrieved with context.GetParamOr(ctx, "key", defaultValue). checkpoints.New(ctx, dir) helps save and load the state of all variables in a context.
context.Exec: context.Exec is similar to graph.Exec but designed for ML models. Its builder function takes an extra *context.Context parameter. Variables can be used and set in the graph building, and context.Exec will handle passing their values as inputs and outputs. Example:
func DenseLayer(ctx *context.Context, x *Node, outputDim int) *Node {
	inputDim := x.Shape().Dimensions[len(x.Shape().Dimensions)-1]
	weightsVar := ctx.VariableWithShape("weights", shapes.Make(x.DType(), inputDim, outputDim))
	biasVar := ctx.VariableWithShape("bias", shapes.Make(x.DType(), outputDim))
	x = Dot(x, weightsVar.ValueGraph(x.Graph()))
	return Add(x, biasVar.ValueGraph(x.Graph()))
}
Layers (github.com/gomlx/gomlx/pkg/ml/layers and sub-packages): The layers package provides standard higher-level building blocks for ML models. They use *context.Context extensively to manage the weights/biases for each layer. Sub-packages include activations (Relu, Swish, etc.), fnn (feed-forward neural networks), kan (Kolmogorov-Arnold Networks), regularizers, etc. See the layers package reference for a list of common layers and their PyTorch equivalents.
Training (github.com/gomlx/gomlx/pkg/ml/train): examples/adult/demo shows a full ML pipeline. train.Trainer orchestrates the model function, the loss function, and the optimizer.
It takes a backend, a context.Context, a model function, a loss function (e.g., losses.BinaryCrossentropyLogits), and an optimizer (e.g., optimizers.Adam).
Metrics (pkg/ml/train/metrics): Used to evaluate model performance during training and evaluation. They are passed at train.NewTrainer initialization (one list for train metrics, one for eval metrics). Examples: metrics.NewMeanBinaryLogitsAccuracy(), metrics.NewSparseCategoricalAccuracy().
train.Loop manages the iterative process, feeding datasets to the Trainer and calling callbacks (e.g., checkpoint saving, plotting). Example Training Pipeline:
// 1. Create dataset
trainDS := CreateDataset(...)
// 2. Metrics we are interested in.
meanAccuracyMetric := metrics.NewMeanBinaryLogitsAccuracy("Mean Accuracy", "#acc")
movingAccuracyMetric := metrics.NewMovingAverageBinaryLogitsAccuracy("Moving Average Accuracy", "~acc", 0.01)
// 3. Create a train.Trainer: orchestrates running the model, feeding results to the optimizer, evaluating metrics.
trainer := train.NewTrainer(backend, ctx, Model, losses.BinaryCrossentropyLogits,
optimizers.FromContext(ctx),
[]metrics.Interface{movingAccuracyMetric}, // trainMetrics
[]metrics.Interface{meanAccuracyMetric}) // evalMetrics
// 4. Create a standard training loop
loop := train.NewLoop(trainer)
// 5. Attach a progress bar to the loop.
commandline.AttachProgressBar(loop)
// 6. Get hyperparameters and run the training loop
trainSteps := context.GetParamOr(ctx, "train_steps", 1000)
_, err := loop.RunToGlobalStep(trainDS, trainSteps)
if err != nil {
return err
}