Skip to content

Commit

Permalink
Fix average epoch/throughput regexes for multiprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
mselim00 committed Jan 27, 2025
1 parent 714576d commit 874c16d
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 10 deletions.
17 changes: 9 additions & 8 deletions test/cases/neuron-training/bert_training_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,15 +37,15 @@ var (
trainingPodCommServiceManifest []byte

// Regex to match lines like:
// ...[Rank 0] local_samples=50.0, training_time=10.00s, local_throughput=5.00 samples/s, local_avg_epoch_time=...
// local_throughput=5.00 samples/s
rankThroughputRegex = regexp.MustCompile(
`\[Rank\s+(\d+)\].+local_throughput\s*=\s*([\d\.]+)\s+samples\/s`,
`local_throughput\s*=\s*([\d\.]+)\s+samples\/s`,
)

// Regex to match lines like:
// ...[Rank 0] ... local_avg_epoch_time=12.50s
// local_avg_epoch_time=12.50s
rankEpochTimeRegex = regexp.MustCompile(
`\[Rank\s+(\d+)\].+local_avg_epoch_time\s*=\s*([\d\.]+)s`,
`local_avg_epoch_time=([\d\.]+)s`,
)
)

Expand Down Expand Up @@ -211,8 +211,8 @@ func aggregateThroughputFromLogs(logs string) (avg float64, sum float64, count i
for scanner.Scan() {
line := scanner.Text()
matches := rankThroughputRegex.FindStringSubmatch(line)
if len(matches) == 3 {
valStr := matches[2] // e.g. "5.00"
if len(matches) == 2 {
valStr := matches[1] // e.g. "5.00"
val, err := strconv.ParseFloat(valStr, 64)
if err == nil {
sum += val
Expand All @@ -236,8 +236,9 @@ func aggregateEpochTimeFromLogs(logs string) (avg float64, sum float64, count in
for scanner.Scan() {
line := scanner.Text()
matches := rankEpochTimeRegex.FindStringSubmatch(line)
if len(matches) == 3 {
valStr := matches[2] // e.g. "12.50"
fmt.Println(matches)
if len(matches) == 2 {
valStr := matches[1] // e.g. "12.50"
val, err := strconv.ParseFloat(valStr, 64)
if err == nil {
sum += val
Expand Down
2 changes: 0 additions & 2 deletions test/cases/neuron-training/vars.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,11 @@ var (
efaPerNode int
neuronPerNode int
neuronCorePerNode int
masterPort *string
)

// init registers the test suite's command-line flags with the standard flag
// package and stores each returned pointer in a variable declared outside this
// function. The pointed-to values hold the defaults given here until the flags
// are parsed.
//
// NOTE(review): this span comes from a diff view whose hunk header reports
// 2 deletions; the masterPort declaration and flag appear to be the removed
// lines — confirm before relying on -masterPort.
func init() {
// Define command-line flags
// Container image for the BERT training workload; defaults to empty.
bertTrainingImage = flag.String("bertTrainingImage", "", "Docker image used for BERT training workload")
// EFA toggle; defaults to false.
efaEnabled = flag.Bool("efaEnabled", false, "Enable Elastic Fabric Adapter (EFA)")
// Instance type of the cluster nodes; defaults to empty.
nodeType = flag.String("nodeType", "", "Instance type for cluster nodes (e.g., inf1.24xlarge)")
// Port for inter-process communication; defaults to "12355".
masterPort = flag.String("masterPort", "12355", "Port to use for inter-process communication")
}

0 comments on commit 874c16d

Please sign in to comment.