Skip to content

Commit

Permalink
Fix metrics parsing and unspecified nodetype handling
Browse files Browse the repository at this point in the history
  • Loading branch information
mselim00 committed Jan 28, 2025
1 parent f3bc20f commit 74ffd5d
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 46 deletions.
58 changes: 13 additions & 45 deletions test/cases/neuron-training/bert_training_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,13 @@
package training

import (
"bufio"
"bytes"
"context"
_ "embed"
"fmt"
"log"
"regexp"
"strconv"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -139,7 +137,7 @@ func TestBertTraining(t *testing.T) {
log.Println(logsBuf.String())

// 1) Throughput Aggregation
avgThru, sumThru, countThru := aggregateThroughputFromLogs(logsBuf.String())
avgThru, sumThru, countThru := aggregateMetricFromLogs(rankThroughputRegex, logsBuf.String())
if countThru == 0 {
log.Printf("No throughput lines found. Possibly missing in logs.")
} else {
Expand All @@ -150,7 +148,7 @@ func TestBertTraining(t *testing.T) {
}

// 2) Average Epoch Time Aggregation
avgEp, sumEp, countEp := aggregateEpochTimeFromLogs(logsBuf.String())
avgEp, sumEp, countEp := aggregateMetricFromLogs(rankEpochTimeRegex, logsBuf.String())
if countEp == 0 {
log.Printf("No epoch time lines found. Possibly missing in logs.")
} else {
Expand Down Expand Up @@ -201,48 +199,18 @@ func gatherJobLogs(ctx context.Context, cfg *envconf.Config, namespace, jobName
return &out, nil
}

// aggregateThroughputFromLogs scans the log output for lines like:
// aggregateMetricFromLogs scans the log output for every match of the provided regular expression.
// The expression is expected to be specific enough to avoid false matches (e.g. <metric>=<value>)
// and to capture the numeric value in its first group, which keeps parsing simple.
//
// [Rank 3] ... local_throughput=5.00 ...
//
// returning the average, sum, and count for rank throughput lines.
func aggregateThroughputFromLogs(logs string) (avg float64, sum float64, count int) {
scanner := bufio.NewScanner(strings.NewReader(logs))
for scanner.Scan() {
line := scanner.Text()
matches := rankThroughputRegex.FindStringSubmatch(line)
if len(matches) == 2 {
valStr := matches[1] // e.g. "5.00"
val, err := strconv.ParseFloat(valStr, 64)
if err == nil {
sum += val
count++
}
}
}
if count > 0 {
avg = sum / float64(count)
}
return avg, sum, count
}

// aggregateEpochTimeFromLogs scans log output for lines like:
//
// [Rank 0] ... local_avg_epoch_time=12.50s
//
// returning the average, sum, and count for rank epoch times.
func aggregateEpochTimeFromLogs(logs string) (avg float64, sum float64, count int) {
scanner := bufio.NewScanner(strings.NewReader(logs))
for scanner.Scan() {
line := scanner.Text()
matches := rankEpochTimeRegex.FindStringSubmatch(line)
if len(matches) == 2 {
valStr := matches[1] // e.g. "12.50"
val, err := strconv.ParseFloat(valStr, 64)
if err == nil {
sum += val
count++
}
// returns the average, sum, and count for all occurrences of the metric.
func aggregateMetricFromLogs(metricRegex *regexp.Regexp, logs string) (avg float64, sum float64, count int) {
matches := metricRegex.FindAllStringSubmatch(logs, -1)
for _, match := range matches {
val, err := strconv.ParseFloat(match[1], 64)
if err == nil {
sum += val
count++
}
}
if count > 0 {
Expand Down
6 changes: 5 additions & 1 deletion test/cases/neuron-training/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,8 +122,12 @@ func checkNodeTypes(ctx context.Context, config *envconf.Config) (context.Contex

// Check if all nodes have the same instance type
for i := 1; i < len(nodes.Items); i++ {
if nodes.Items[i].Labels["node.kubernetes.io/instance-type"] != nodes.Items[i-1].Labels["node.kubernetes.io/instance-type"] {
currentInstanceType := nodes.Items[i].Labels["node.kubernetes.io/instance-type"]
if currentInstanceType != nodes.Items[i-1].Labels["node.kubernetes.io/instance-type"] {
return ctx, fmt.Errorf("inconsistent node types detected, all nodes must have the same instance type")
} else if *nodeType == "" {
log.Printf("[INFO] nodeType was not set, discovered type %s.", currentInstanceType)
*nodeType = currentInstanceType
}
}

Expand Down

0 comments on commit 74ffd5d

Please sign in to comment.