Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Parallelize neuron training processes for each neuron core #566

Merged
merged 9 commits into from
Jan 28, 2025
115 changes: 48 additions & 67 deletions test/cases/neuron-training/bert_training_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,13 @@
package training

import (
"bufio"
"bytes"
"context"
_ "embed"
"fmt"
"log"
"regexp"
"strconv"
"strings"
"testing"
"time"

Expand All @@ -30,19 +28,22 @@ import (
)

var (
//go:embed manifests/neuron-bert-training.yaml
neuronBertTrainingManifest []byte
mselim00 marked this conversation as resolved.
Show resolved Hide resolved
//go:embed manifests/bert-training.yaml
bertTrainingJobManifest []byte

//go:embed manifests/training-comm-service.yaml
trainingPodCommServiceManifest []byte

// Regex to match lines like:
// ...[Rank 0] local_samples=50.0, training_time=10.00s, local_throughput=5.00 samples/s, local_avg_epoch_time=...
// local_throughput=5.00 samples/s
rankThroughputRegex = regexp.MustCompile(
`\[Rank\s+(\d+)\].+local_throughput\s*=\s*([\d\.]+)\s+samples\/s`,
`local_throughput\s*=\s*([\d\.]+)\s+samples\/s`,
)

// Regex to match lines like:
// ...[Rank 0] ... local_avg_epoch_time=12.50s
// local_avg_epoch_time=12.50s
rankEpochTimeRegex = regexp.MustCompile(
`\[Rank\s+(\d+)\].+local_avg_epoch_time\s*=\s*([\d\.]+)s`,
`local_avg_epoch_time=([\d\.]+)s`,
)
)

Expand All @@ -56,27 +57,36 @@ func TestBertTraining(t *testing.T) {
renderVars := map[string]string{
"BertTrainingImage": *bertTrainingImage,
"NodeType": *nodeType,
"SlotsPerWorker": fmt.Sprintf("%d", 1), // Hardcode to 1 for now
"WorkerReplicas": fmt.Sprintf("%d", nodeCount),
"NP": fmt.Sprintf("%d", nodeCount),
"SlotsPerWorker": fmt.Sprintf("%d", nodeCount),
"NodeCount": fmt.Sprintf("%d", nodeCount),
"NeuronPerNode": fmt.Sprintf("%d", neuronPerNode),
"NeuronCorePerNode": fmt.Sprintf("%d", neuronCorePerNode),
"EFARequested": fmt.Sprintf("%d", efaPerNode),
"EFAPerNode": fmt.Sprintf("%d", efaPerNode),
}

// Render the manifest
renderedManifest, err := fwext.RenderManifests(neuronBertTrainingManifest, renderVars)
renderedManifest, err := fwext.RenderManifests(bertTrainingJobManifest, renderVars)
if err != nil {
t.Fatalf("failed to render neuron BERT training manifest: %v", err)
}

renderedCommServiceManifest, err := fwext.RenderManifests(trainingPodCommServiceManifest, renderVars)
if err != nil {
t.Fatalf("failed to render pod communication manifest: %v", err)
}

// Define a feature for the Neuron BERT training
neuronTraining := features.New("neuron-training").
neuronTraining := features.New("bert-training").
WithLabel("suite", "neuron").
WithLabel("hardware", "neuron").
Setup(func(ctx context.Context, t *testing.T, cfg *envconf.Config) context.Context {
log.Println("Applying pod communication service manifest.")
err := fwext.ApplyManifests(cfg.Client().RESTConfig(), renderedCommServiceManifest)
if err != nil {
t.Fatalf("failed to apply communication service manifest: %v", err)
}
log.Println("Applying rendered Neuron training manifest.")
err := fwext.ApplyManifests(cfg.Client().RESTConfig(), renderedManifest)
err = fwext.ApplyManifests(cfg.Client().RESTConfig(), renderedManifest)
if err != nil {
t.Fatalf("failed to apply Neuron training manifest: %v", err)
}
Expand All @@ -86,47 +96,48 @@ func TestBertTraining(t *testing.T) {
Assess("Neuron training Job succeeds", func(ctx context.Context, t *testing.T, cfg *envconf.Config) context.Context {
job := &batchv1.Job{
ObjectMeta: metav1.ObjectMeta{
Name: "neuron-training-launcher",
Name: "bert-training",
Namespace: "default",
},
}

// Step 1: Wait for the Job resource to appear
log.Println("Waiting for the 'neuron-training-launcher' Job resource to be created...")
log.Println("Waiting for the 'bert-training' Job resource to be created...")
err := wait.For(
conditions.New(cfg.Client().Resources()).ResourceMatch(job, func(object k8s.Object) bool {
return true
}),
wait.WithTimeout(time.Minute*5),
)
if err != nil {
t.Fatalf("Failed to detect creation of Job 'neuron-training-launcher': %v", err)
t.Fatalf("Failed to detect creation of Job 'bert-training': %v", err)
}
log.Println("Job 'neuron-training-launcher' is created in the cluster.")
log.Println("Job 'bert-training' is created in the cluster.")

// Step 2: Wait for the Job to succeed (i.e., complete)
log.Println("Waiting for 'neuron-training-launcher' Job to succeed...")
log.Println("Waiting for 'bert-training' Job to succeed...")
err = wait.For(
fwext.NewConditionExtension(cfg.Client().Resources()).JobSucceeded(job),
wait.WithTimeout(30*time.Minute),
// Bake in large margin b/c compile time. TODO: pre-compile and find best fit
wait.WithTimeout(60*time.Minute),
)
if err != nil {
t.Fatalf("Neuron training Job did not succeed: %v", err)
}
log.Println("Job 'neuron-training-launcher' succeeded!")
log.Println("Job 'bert-training' succeeded!")

// Gather logs from the training pods (launcher)
logsBuf, logErr := gatherJobLogs(ctx, cfg, "default", "neuron-training-launcher")
logsBuf, logErr := gatherJobLogs(ctx, cfg, "default", "bert-training")
if logErr != nil {
log.Printf("Warning: failed to retrieve neuron-training job logs: %v", logErr)
log.Printf("Warning: failed to retrieve bert-training job logs: %v", logErr)
return ctx
}

log.Println("== Raw Logs from the launcher pods ==")
log.Println(logsBuf.String())

// 1) Throughput Aggregation
avgThru, sumThru, countThru := aggregateThroughputFromLogs(logsBuf.String())
avgThru, sumThru, countThru := aggregateMetricFromLogs(rankThroughputRegex, logsBuf.String())
if countThru == 0 {
log.Printf("No throughput lines found. Possibly missing in logs.")
} else {
Expand All @@ -137,7 +148,7 @@ func TestBertTraining(t *testing.T) {
}

// 2) Average Epoch Time Aggregation
avgEp, sumEp, countEp := aggregateEpochTimeFromLogs(logsBuf.String())
avgEp, sumEp, countEp := aggregateMetricFromLogs(rankEpochTimeRegex, logsBuf.String())
if countEp == 0 {
log.Printf("No epoch time lines found. Possibly missing in logs.")
} else {
Expand Down Expand Up @@ -188,48 +199,18 @@ func gatherJobLogs(ctx context.Context, cfg *envconf.Config, namespace, jobName
return &out, nil
}

// aggregateThroughputFromLogs scans the log output for lines like:
//
// [Rank 3] ... local_throughput=5.00 ...
// aggregateMetricFromLogs scans the log output for lines based on a provided RegEx.
// The RegEx is assumed to take a sufficiently unique form like <metric>=<value> to avoid
// collisions, but also to simplify parsing.
//
// returning the average, sum, and count for rank throughput lines.
func aggregateThroughputFromLogs(logs string) (avg float64, sum float64, count int) {
scanner := bufio.NewScanner(strings.NewReader(logs))
for scanner.Scan() {
line := scanner.Text()
matches := rankThroughputRegex.FindStringSubmatch(line)
if len(matches) == 3 {
valStr := matches[2] // e.g. "5.00"
val, err := strconv.ParseFloat(valStr, 64)
if err == nil {
sum += val
count++
}
}
}
if count > 0 {
avg = sum / float64(count)
}
return avg, sum, count
}

// aggregateEpochTimeFromLogs scans log output for lines like:
//
// [Rank 0] ... local_avg_epoch_time=12.50s
//
// returning the average, sum, and count for rank epoch times.
func aggregateEpochTimeFromLogs(logs string) (avg float64, sum float64, count int) {
scanner := bufio.NewScanner(strings.NewReader(logs))
for scanner.Scan() {
line := scanner.Text()
matches := rankEpochTimeRegex.FindStringSubmatch(line)
if len(matches) == 3 {
valStr := matches[2] // e.g. "12.50"
val, err := strconv.ParseFloat(valStr, 64)
if err == nil {
sum += val
count++
}
// returns the average, sum, and count for all occurrences of the metric.
func aggregateMetricFromLogs(metricRegex *regexp.Regexp, logs string) (avg float64, sum float64, count int) {
matches := metricRegex.FindAllStringSubmatch(logs, -1)
for _, match := range matches {
val, err := strconv.ParseFloat(match[1], 64)
if err == nil {
sum += val
count++
Comment on lines +207 to +213
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is nice.

}
}
if count > 0 {
Expand Down
36 changes: 9 additions & 27 deletions test/cases/neuron-training/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,9 @@ import (

fwext "github.com/aws/aws-k8s-tester/internal/e2e"
appsv1 "k8s.io/api/apps/v1"
v1 "k8s.io/api/core/v1"
metav1 "k8s.io/apimachinery/pkg/apis/meta/v1"
"k8s.io/client-go/kubernetes"
"sigs.k8s.io/e2e-framework/klient/wait"
"sigs.k8s.io/e2e-framework/klient/wait/conditions"
"sigs.k8s.io/e2e-framework/pkg/env"
"sigs.k8s.io/e2e-framework/pkg/envconf"
)
Expand All @@ -28,8 +26,6 @@ var (
neuronDevicePluginRbacManifest []byte
//go:embed manifests/k8s-neuron-device-plugin.yml
neuronDevicePluginManifest []byte
//go:embed manifests/mpi-operator.yaml
mpiOperatorManifest []byte
//go:embed manifests/efa-device-plugin.yaml
efaDevicePluginManifest []byte
)
Expand All @@ -44,35 +40,17 @@ func TestMain(m *testing.M) {
manifests := [][]byte{
neuronDevicePluginRbacManifest,
neuronDevicePluginManifest,
mpiOperatorManifest,
efaDevicePluginManifest,
}

testenv.Setup(
func(ctx context.Context, config *envconf.Config) (context.Context, error) {
log.Println("Applying Neuron device plugin RBAC, Neuron device plugin, MPI operator, and EFA device plugin manifests.")
log.Println("Applying Neuron device plugin RBAC, Neuron device plugin and EFA device plugin manifests.")
err := fwext.ApplyManifests(config.Client().RESTConfig(), manifests...)
if err != nil {
return ctx, fmt.Errorf("failed to apply manifests: %w", err)
}
log.Println("Successfully applied Neuron device plugin RBAC, Neuron device plugin, MPI operator, and EFA device plugin manifests.")
return ctx, nil
},
func(ctx context.Context, config *envconf.Config) (context.Context, error) {
log.Println("Waiting for MPI Operator deployment to be available.")
deployment := appsv1.Deployment{
ObjectMeta: metav1.ObjectMeta{Name: "mpi-operator", Namespace: "mpi-operator"},
}
err := wait.For(
conditions.New(config.Client().Resources()).DeploymentConditionMatch(
&deployment, appsv1.DeploymentAvailable, v1.ConditionTrue,
),
wait.WithTimeout(time.Minute*5),
)
if err != nil {
return ctx, fmt.Errorf("MPI Operator deployment is not available: %w", err)
}
log.Println("MPI Operator deployment is available.")
log.Println("Successfully applied Neuron device plugin RBAC, Neuron device plugin and EFA device plugin manifests.")
return ctx, nil
},
func(ctx context.Context, config *envconf.Config) (context.Context, error) {
Expand Down Expand Up @@ -110,13 +88,13 @@ func TestMain(m *testing.M) {

testenv.Finish(
func(ctx context.Context, config *envconf.Config) (context.Context, error) {
log.Println("Deleting Neuron device plugin, MPI operator, and EFA device plugin manifests.")
log.Println("Deleting Neuron device plugin and EFA device plugin manifests.")
slices.Reverse(manifests)
err := fwext.DeleteManifests(config.Client().RESTConfig(), manifests...)
if err != nil {
return ctx, fmt.Errorf("failed to delete manifests: %w", err)
}
log.Println("Successfully deleted Neuron device plugin, MPI operator, and EFA device plugin manifests.")
log.Println("Successfully deleted Neuron device plugin and EFA device plugin manifests.")
return ctx, nil
},
)
Expand Down Expand Up @@ -144,8 +122,12 @@ func checkNodeTypes(ctx context.Context, config *envconf.Config) (context.Contex

// Check if all nodes have the same instance type
for i := 1; i < len(nodes.Items); i++ {
if nodes.Items[i].Labels["node.kubernetes.io/instance-type"] != nodes.Items[i-1].Labels["node.kubernetes.io/instance-type"] {
currentInstanceType := nodes.Items[i].Labels["node.kubernetes.io/instance-type"]
if currentInstanceType != nodes.Items[i-1].Labels["node.kubernetes.io/instance-type"] {
return ctx, fmt.Errorf("inconsistent node types detected, all nodes must have the same instance type")
} else if *nodeType == "" {
log.Printf("[INFO] nodeType was not set, discovered type %s", currentInstanceType)
*nodeType = currentInstanceType
}
}

Expand Down
47 changes: 47 additions & 0 deletions test/cases/neuron-training/manifests/bert-training.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# Indexed batch Job that launches one BERT pre-training pod per node.
# Each pod runs torchrun across the node's Neuron cores; pods find each
# other through the headless "training" subdomain (see `subdomain` below).
# {{.Var}} placeholders are filled in by the Go test via RenderManifests.
apiVersion: batch/v1
kind: Job
metadata:
  labels:
    app: bert-training
  name: bert-training
spec:
  # Indexed completion mode gives every pod a stable JOB_COMPLETION_INDEX,
  # which the container command passes to torchrun as the node rank.
  completionMode: Indexed
  completions: {{.NodeCount}}
  parallelism: {{.NodeCount}}
  template:
    spec:
      restartPolicy: Never
      containers:
        - image: {{.BertTrainingImage}}
          name: bert-training
          env:
            # Rank 0's stable DNS name: pod index 0 under the "training"
            # subdomain acts as the torchrun rendezvous master.
            - name: MASTER_ADDR
              value: bert-training-0.training
          args:
            - sh
            - -c
            - |
              # Enable EFA https://awsdocs-neuron.readthedocs-hosted.com/en/latest/neuron-runtime/nrt-troubleshoot.html#fi-efa-fork-safe (AL2 legacy requirement)
              export FI_EFA_FORK_SAFE=1
              export CCOM_SOCKET_IFNAME=eth0
              export NCCL_DEBUG=ERROR
              torchrun --nproc_per_node {{.NeuronCorePerNode}} --nnodes {{.NodeCount}} --node_rank $JOB_COMPLETION_INDEX --master_addr $MASTER_ADDR train.py
          volumeMounts:
            # Memory-backed /dev/shm for PyTorch DataLoader shared memory.
            - name: dshm
              mountPath: /dev/shm
          resources:
            requests:
              aws.amazon.com/neuron: {{.NeuronPerNode}}
              aws.amazon.com/neuroncore: {{.NeuronCorePerNode}}
              vpc.amazonaws.com/efa: {{.EFAPerNode}}
            limits:
              aws.amazon.com/neuron: {{.NeuronPerNode}}
              aws.amazon.com/neuroncore: {{.NeuronCorePerNode}}
              vpc.amazonaws.com/efa: {{.EFAPerNode}}
      # Pin all pods to a single instance type so per-node Neuron/EFA
      # resource counts are uniform across the job.
      nodeSelector:
        node.kubernetes.io/instance-type: {{.NodeType}}
      # Must match the headless Service name so pods get resolvable
      # <pod>.training DNS entries (MASTER_ADDR relies on this).
      subdomain: training
      volumes:
        - name: dshm
          emptyDir:
            medium: Memory
Loading