From 7988809a7fdbb6e4c533d4dd74d7029a5373a4da Mon Sep 17 00:00:00 2001 From: Brent Salisbury Date: Wed, 22 Jan 2025 02:52:23 -0500 Subject: [PATCH 1/2] Adds an API server to enable advanced ilab features from the UI Signed-off-by: Brent Salisbury --- api-server/.gitignore | 26 + api-server/README.md | 408 +++++++++++++ api-server/go.mod | 16 + api-server/go.sum | 11 + api-server/handlers.go | 867 +++++++++++++++++++++++++++ api-server/jobs.go | 269 +++++++++ api-server/main.go | 865 ++++++++++++++++++++++++++ api-server/qna-eval/Containerfile | 40 ++ api-server/qna-eval/qna-eval.py | 101 ++++ api-server/qna-eval/requirements.txt | 176 ++++++ api-server/utils.go | 140 +++++ api-server/vllm-serve.go | 178 ++++++ api-server/zap.go | 34 ++ 13 files changed, 3131 insertions(+) create mode 100644 api-server/.gitignore create mode 100644 api-server/README.md create mode 100644 api-server/go.mod create mode 100644 api-server/go.sum create mode 100644 api-server/handlers.go create mode 100644 api-server/jobs.go create mode 100644 api-server/main.go create mode 100644 api-server/qna-eval/Containerfile create mode 100644 api-server/qna-eval/qna-eval.py create mode 100644 api-server/qna-eval/requirements.txt create mode 100644 api-server/utils.go create mode 100644 api-server/vllm-serve.go create mode 100644 api-server/zap.go diff --git a/api-server/.gitignore b/api-server/.gitignore new file mode 100644 index 00000000..cd9e8a34 --- /dev/null +++ b/api-server/.gitignore @@ -0,0 +1,26 @@ +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, built with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# Dependency directories (remove the comment below to include it) +# vendor/ + +# Go workspace file +go.work +go.work.sum + +# env file +.env + +# app specific +logs/ +jobs.json diff --git a/api-server/README.md b/api-server/README.md new file mode 100644 index 00000000..905859f8 --- 
/dev/null +++ b/api-server/README.md @@ -0,0 +1,408 @@ +# Ilab API Server + +## Overview + +This is an Ilab API Server that is a temporary set of APIs for service developing apps against [InstructLab](https://github.com/instructlab/). It provides endpoints for model management, data generation, training, job tracking and job logging. + +## Quickstart + +### Prerequisites + +- Ensure that the required directories (`base-dir` and `taxonomy-path`) exist and are accessible and Go is installed in the $PATH. + +### Install Dependencies + +To install the necessary dependencies, run: + +```bash +# gcc in $PATH required for sqlite +go mod download +``` + +### Run the Server + +#### For macOS with Metal (MPS) + +```bash +go run main.go --base-dir /path/to/base-dir --taxonomy-path /path/to/taxonomy --osx +``` + +#### For CUDA-enabled environments + +```bash +go run main.go --base-dir /path/to/base-dir --taxonomy-path /path/to/taxonomy --cuda +``` + +#### For a RHEL AI machine + +- If you're operating on a Red Hat Enterprise Linux AI (RHEL AI) machine, and the ilab binary is already available in your $PATH, you don't need to specify the --base-dir. Additionally, pass CUDA support with `--cuda`. + +```bash +go run main.go --taxonomy-path ~/.local/share/instructlab/taxonomy/ --rhelai --cuda +``` + +The `--rhelai` flag indicates that the ilab binary is available in the system's $PATH and does not require a virtual environment. +When using `--rhelai`, the `--base-dir` flag is not required since it will be in a known location at least for meow. + +### Example command with paths + +Here's an example command for running the server on a macOS machine with Metal support: + +```bash +go run main.go --base-dir /Users/user/code/instructlab --taxonomy-path ~/.local/share/instructlab/taxonomy/ --osx +``` + +## API Documentation + +### Models + +#### Get Models + +**Endpoint**: `GET /models` +Fetches the list of available models. 
+ +- **Response**: + + ```json + [ + { + "name": "model-name", + "last_modified": "timestamp", + "size": "size-string" + } + ] + ``` + +### Data + +#### Get Data + +**Endpoint**: `GET /data` +Fetches the list of datasets. + +- **Response**: + + ```json + [ + { + "dataset": "dataset-name", + "created_at": "timestamp", + "file_size": "size-string" + } + ] + ``` + +#### Generate Data + +**Endpoint**: `POST /data/generate` +Starts a data generation job. + +- **Request**: None + +- **Response**: + + ```json + { + "job_id": "generated-job-id" + } + ``` + +### Jobs + +#### List Jobs + +**Endpoint**: `GET /jobs` +Fetches the list of all jobs. + +- **Response**: + + ```json + [ + { + "job_id": "job-id", + "status": "running/finished/failed", + "cmd": "command", + "branch": "branch-name", + "start_time": "timestamp", + "end_time": "timestamp" + } + ] + ``` + +#### Job Status + +**Endpoint**: `GET /jobs/{job_id}/status` +Fetches the status of a specific job. + +- **Response**: + + ```json + { + "job_id": "job-id", + "status": "running/finished/failed", + "branch": "branch-name", + "command": "command" + } + ``` + +#### Job Logs + +**Endpoint**: `GET /jobs/{job_id}/logs` +Fetches the logs of a specific job. + +- **Response**: + Text logs of the job. + +### Training + +#### Start Training + +**Endpoint**: `POST /model/train` +Starts a training job. + +- **Request**: + + ```json + { + "modelName": "name-of-the-model", + "branchName": "name-of-the-branch", + "epochs": 10 // Optional + } + ``` + + **Parameters**: + - `modelName` (string, required): The name of the model. Can be provided **with or without** the `models/` prefix. + - Examples: + - Without prefix: `"granite-7b-lab-Q4_K_M.gguf"` + - With prefix: `"models/granite-7b-starter"` + - `branchName` (string, required): The name of the branch to train on. + - `epochs` (integer, optional): The number of training epochs. Must be a positive integer. 
+ +- **Response**: + + ```json + { + "job_id": "training-job-id" + } + ``` + +### Pipeline + +#### Generate and Train Pipeline + +**Endpoint**: `POST /pipeline/generate-train` +Combines data generation and training into a single pipeline job. + +- **Request**: + + ```json + { + "modelName": "name-of-the-model", + "branchName": "name-of-the-branch", + "epochs": 10 // Optional + } + ``` + + **Parameters**: + - `modelName` (string, required): The name of the model. Can be provided **with or without** the `models/` prefix. + - Examples: + - Without prefix: `"granite-7b-lab-Q4_K_M.gguf"` + - With prefix: `"models/granite-7b-starter"` + - `branchName` (string, required): The name of the branch to train on. + - `epochs` (integer, optional): The number of training epochs. Must be a positive integer. + +- **Response**: + + ```json + { + "pipeline_job_id": "pipeline-job-id" + } + ``` + +### Model Serving + +#### Serve Latest Checkpoint + +**Endpoint**: `POST /model/serve-latest` +Serves the latest model checkpoint on port `8001`. + +- **Request**: + + ```json + { + "checkpoint": "samples_12345" // Optional + } + ``` + + **Parameters**: + - `checkpoint` (string, optional): Name of the checkpoint directory (e.g., `"samples_12345"`). If omitted, the server uses the latest checkpoint. + +- **Response**: + + ```json + { + "status": "model process started", + "job_id": "serve-job-id" + } + ``` + +#### Serve Base Model + +**Endpoint**: `POST /model/serve-base` +Serves the base model on port `8000`. + +- **Request**: None + +- **Response**: + + ```json + { + "status": "model process started", + "job_id": "serve-job-id" + } + ``` + +### QnA Evaluation + +#### Run QnA Evaluation + +**Endpoint**: `POST /qna-eval` +Performs QnA evaluation using a specified model and YAML configuration. + +- **Request**: + + ```json + { + "model_path": "/path/to/model", + "yaml_file": "/path/to/config.yaml" + } + ``` + + **Parameters**: + - `model_path` (string, required): The file path to the model. 
+ - `yaml_file` (string, required): The file path to the YAML configuration. + +- **Response**: + - **Success**: + + ```json + { + "result": "evaluation results..." + } + ``` + + - **Error**: + + ```json + { + "error": "error message" + } + ``` + +### Checkpoints + +#### List Checkpoints + +**Endpoint**: `GET /checkpoints` +Lists all available checkpoints. + +- **Response**: + + ```json + [ + "checkpoint1", + "checkpoint2", + "checkpoint3" + ] + ``` + +### VLLM + +#### List VLLM Containers + +**Endpoint**: `GET /vllm-containers` +Fetches the list of VLLM containers. + +- **Response**: + + ```json + { + "containers": [ + { + "container_id": "container-id-1", + "served_model_name": "pre-train", + "status": "running", + "port": "8000" + }, + { + "container_id": "container-id-2", + "served_model_name": "post-train", + "status": "running", + "port": "8001" + } + ] + } + ``` + +#### Unload VLLM Container + +**Endpoint**: `POST /vllm-unload` +Unloads a specific VLLM container. + +- **Request**: + + ```json + { + "model_name": "pre-train" // Must be either "pre-train" or "post-train" for meow + } + ``` + +- **Response**: + + ```json + { + "status": "success", + "message": "Model 'pre-train' unloaded successfully", + "modelName": "pre-train" + } + ``` + + **Error Response**: + + ```json + { + "error": "Failed to unload model 'pre-train': error details..." + } + ``` + +#### VLLM Status + +**Endpoint**: `GET /vllm-status` +Fetches the status of a specific VLLM model. + +- **Query Parameters**: + - `model_name` (string, required): The name of the model. Must be either `"pre-train"` or `"post-train"`. + +- **Response**: + + ```json + { + "status": "running" // Possible values: "running", "loading", "stopped" + } + ``` + +### GPU Information + +#### GPU Free + +**Endpoint**: `GET /gpu-free` +Retrieves the number of free and total GPUs available. 
+ +- **Response**: + + ```json + { + "free_gpus": 2, + "total_gpus": 4 + } + ``` diff --git a/api-server/go.mod b/api-server/go.mod new file mode 100644 index 00000000..91cb5132 --- /dev/null +++ b/api-server/go.mod @@ -0,0 +1,16 @@ +module ilab-api-router + +go 1.21.6 + +require ( + github.com/gorilla/mux v1.8.1 + github.com/mattn/go-sqlite3 v1.14.24 + github.com/spf13/cobra v1.8.1 +) + +require ( + github.com/inconshreveable/mousetrap v1.1.0 // indirect + github.com/spf13/pflag v1.0.5 // indirect + go.uber.org/multierr v1.10.0 // indirect + go.uber.org/zap v1.27.0 // indirect +) diff --git a/api-server/go.sum b/api-server/go.sum new file mode 100644 index 00000000..8246dd0d --- /dev/null +++ b/api-server/go.sum @@ -0,0 +1,11 @@ +github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= +github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= +github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/api-server/handlers.go b/api-server/handlers.go new file mode 100644 index 00000000..2ad30d05 --- /dev/null +++ b/api-server/handlers.go @@ -0,0 +1,867 @@ +package main + +import ( + "bytes" + "encoding/json" + "fmt" + "github.com/gorilla/mux" + 
"io/ioutil" + "net/http" + "os" + "os/exec" + "path/filepath" + "strings" + "time" +) + +// ----------------------------------------------------------------------------- +// HTTP Handlers +// ----------------------------------------------------------------------------- + +// getModelsHandler is the HTTP handler for the /models endpoint. +func (srv *ILabServer) getModelsHandler(w http.ResponseWriter, r *http.Request) { + srv.log.Info("GET /models called") + + srv.modelCache.Mutex.Lock() + cachedTime := srv.modelCache.Time + cachedModels := make([]Model, len(srv.modelCache.Models)) + copy(cachedModels, srv.modelCache.Models) + srv.modelCache.Mutex.Unlock() + + if len(cachedModels) > 0 && time.Since(cachedTime) < 20*time.Minute { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(cachedModels); err != nil { + srv.log.Errorf("Error encoding cached models: %v", err) + http.Error(w, "Failed to encode models", http.StatusInternalServerError) + return + } + srv.log.Info("GET /models returned cached models.") + return + } + + srv.log.Info("Cache is empty or stale. Refreshing model cache now...") + srv.refreshModelCache() + + srv.modelCache.Mutex.Lock() + cachedTime = srv.modelCache.Time + cachedModels = make([]Model, len(srv.modelCache.Models)) + copy(cachedModels, srv.modelCache.Models) + srv.modelCache.Mutex.Unlock() + + if len(cachedModels) > 0 { + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(cachedModels); err != nil { + srv.log.Errorf("Error encoding refreshed models: %v", err) + http.Error(w, "Failed to encode models", http.StatusInternalServerError) + return + } + srv.log.Info("GET /models returned refreshed models.") + } else { + http.Error(w, "Failed to retrieve models", http.StatusInternalServerError) + srv.log.Info("GET /models failed to retrieve models.") + } +} + +// getDataHandler is the HTTP handler for the /data endpoint. 
+func (srv *ILabServer) getDataHandler(w http.ResponseWriter, r *http.Request) { + srv.log.Info("GET /data called") + output, err := srv.runIlabCommand("data", "list") + if err != nil { + srv.log.Errorf("Error running 'ilab data list': %v", err) + http.Error(w, string(output), http.StatusInternalServerError) + return + } + dataList, err := srv.parseDataList(output) + if err != nil { + srv.log.Errorf("Error parsing data list: %v", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(dataList) + srv.log.Info("GET /data successful") +} + +// generateDataHandler is the HTTP handler for the /data/generate endpoint. +func (srv *ILabServer) generateDataHandler(w http.ResponseWriter, r *http.Request) { + srv.log.Info("POST /data/generate called") + jobID, err := srv.startGenerateJob() + if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]string{"job_id": jobID}) + srv.log.Infof("POST /data/generate successful, job_id: %s", jobID) +} + +// trainModelHandler is the HTTP handler for the /model/train endpoint. 
+func (srv *ILabServer) trainModelHandler(w http.ResponseWriter, r *http.Request) { + srv.log.Info("POST /model/train called") + + var reqBody struct { + ModelName string `json:"modelName"` + BranchName string `json:"branchName"` + Epochs *int `json:"epochs,omitempty"` + } + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + srv.log.Errorf("Error parsing request body: %v", err) + http.Error(w, "Invalid request body", http.StatusBadRequest) + return + } + + if reqBody.ModelName == "" || reqBody.BranchName == "" { + srv.log.Info("Missing required parameters: modelName or branchName") + http.Error(w, "Missing required parameters: modelName or branchName", http.StatusBadRequest) + return + } + if reqBody.Epochs != nil && *reqBody.Epochs <= 0 { + srv.log.Info("Invalid 'epochs' parameter: must be a positive integer") + http.Error(w, "'epochs' must be a positive integer", http.StatusBadRequest) + return + } + + sanitizedModelName := srv.sanitizeModelName(reqBody.ModelName) + srv.log.Infof("Sanitized modelName: '%s'", sanitizedModelName) + + // Git checkout + gitCheckoutCmd := exec.Command("git", "checkout", reqBody.BranchName) + gitCheckoutCmd.Dir = srv.taxonomyPath + gitOutput, err := gitCheckoutCmd.CombinedOutput() + srv.log.Infof("Git checkout output: %s", string(gitOutput)) + if err != nil { + srv.log.Errorf("Error checking out branch '%s': %v", reqBody.BranchName, err) + http.Error(w, fmt.Sprintf("Failed to checkout branch '%s': %s", reqBody.BranchName, string(gitOutput)), http.StatusInternalServerError) + return + } + srv.log.Infof("Successfully checked out branch: '%s'", reqBody.BranchName) + + jobID, err := srv.startTrainJob(sanitizedModelName, reqBody.BranchName, reqBody.Epochs) + if err != nil { + srv.log.Errorf("Error starting train job: %v", err) + http.Error(w, "Failed to start train job", http.StatusInternalServerError) + return + } + srv.log.Infof("Train job started successfully with job_id: '%s'", jobID) + + w.Header().Set("Content-Type", 
"application/json") + _ = json.NewEncoder(w).Encode(map[string]string{"job_id": jobID}) + srv.log.Info("POST /model/train response sent successfully") +} + +// listVllmContainersHandler handles the GET /vllm-containers endpoint. +func (srv *ILabServer) listVllmContainersHandler(w http.ResponseWriter, r *http.Request) { + srv.log.Info("GET /vllm-containers called") + + containers, err := srv.ListVllmContainers() + if err != nil { + srv.log.Errorf("Error listing vllm containers: %v", err) + http.Error(w, "Failed to list vllm containers", http.StatusInternalServerError) + return + } + + response := VllmContainerResponse{ + Containers: containers, + } + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(response); err != nil { + srv.log.Errorf("Error encoding vllm containers response: %v", err) + http.Error(w, "Failed to encode response", http.StatusInternalServerError) + return + } + srv.log.Infof("GET /vllm-containers returned %d containers", len(containers)) +} + +// unloadVllmContainerHandler handles the POST /vllm-unload endpoint. +func (srv *ILabServer) unloadVllmContainerHandler(w http.ResponseWriter, r *http.Request) { + srv.log.Info("POST /vllm-unload called") + + var req UnloadModelRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + srv.log.Errorf("Error decoding unload model request: %v", err) + http.Error(w, "Invalid request body", http.StatusBadRequest) + return + } + + modelName := strings.TrimSpace(req.ModelName) + if modelName != "pre-train" && modelName != "post-train" { + srv.log.Errorf("Invalid model_name provided: %s", modelName) + http.Error(w, "Invalid model_name. 
Must be 'pre-train' or 'post-train'", http.StatusBadRequest) + return + } + + err := srv.StopVllmContainer(modelName) + if err != nil { + srv.log.Errorf("Error unloading model '%s': %v", modelName, err) + http.Error(w, fmt.Sprintf("Failed to unload model '%s': %v", modelName, err), http.StatusInternalServerError) + return + } + + response := map[string]string{ + "status": "success", + "message": fmt.Sprintf("Model '%s' unloaded successfully", modelName), + "modelName": modelName, + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + srv.log.Infof("POST /vllm-unload successfully unloaded model '%s'", modelName) +} + +// getVllmStatusHandler handles the GET /vllm-status endpoint. +func (srv *ILabServer) getVllmStatusHandler(w http.ResponseWriter, r *http.Request) { + srv.log.Infof("vllm status called") + modelName := strings.TrimSpace(strings.ToLower(r.URL.Query().Get("model_name"))) + srv.log.Infof("Received model_name: '%s'", modelName) + + if modelName != "pre-train" && modelName != "post-train" { + srv.log.Warnf("Invalid model_name provided: %s", modelName) + http.Error(w, "Invalid model_name (must be 'pre-train' or 'post-train')", http.StatusBadRequest) + return + } + + containers, err := srv.ListVllmContainers() + if err != nil { + srv.log.Errorf("Error listing vllm containers: %v", err) + http.Error(w, "Failed to list vllm containers", http.StatusInternalServerError) + return + } + + var containerRunning bool + for _, c := range containers { + srv.log.Debugf("Checking container %s for model '%s'", c.ContainerID, modelName) + if strings.ToLower(c.ServedModelName) == modelName { + containerRunning = true + break + } + } + + w.Header().Set("Content-Type", "application/json") + + if !containerRunning { + srv.log.Infof("No running container found for model '%s'", modelName) + _ = json.NewEncoder(w).Encode(map[string]string{"status": "stopped"}) + return + } + + srv.jobIDsMutex.RLock() + jobID, ok := 
srv.servedModelJobIDs[modelName] + srv.jobIDsMutex.RUnlock() + + if !ok { + srv.log.Infof("WTF jobid not found for model '%s'", modelName) + _ = json.NewEncoder(w).Encode(map[string]string{"status": "loading"}) + return + } + + srv.log.Infof("Retrieved job ID '%s' for model '%s'", jobID, modelName) + + job, err := srv.getJob(jobID) + if err != nil { + srv.log.Errorf("Error retrieving job '%s': %v", jobID, err) + _ = json.NewEncoder(w).Encode(map[string]string{"status": "loading"}) + return + } + if job == nil { + srv.log.Warnf("Job '%s' not found in DB", jobID) + _ = json.NewEncoder(w).Encode(map[string]string{"status": "loading"}) + return + } + + if job.LogFile == "" { + srv.log.Warnf("No log file specified for job '%s'", jobID) + _ = json.NewEncoder(w).Encode(map[string]string{"status": "loading"}) + return + } + + srv.log.Infof("Reading log file '%s' for job '%s'", job.LogFile, jobID) + logBytes, err := ioutil.ReadFile(job.LogFile) + if err != nil { + srv.log.Errorf("Error reading log file '%s': %v", job.LogFile, err) + _ = json.NewEncoder(w).Encode(map[string]string{"status": "loading"}) + return + } + + logContent := string(logBytes) + + if strings.Contains(logContent, "Uvicorn running") { + srv.log.Infof("vLLM has finished loading model '%s'", modelName) + _ = json.NewEncoder(w).Encode(map[string]string{"status": "running"}) + } else { + srv.log.Debugf("Uvicorn not detected in logs for job '%s', current status: loading", jobID) + _ = json.NewEncoder(w).Encode(map[string]string{"status": "loading"}) + } +} + +// getGpuFreeHandler is the HTTP handler for the /gpu-free endpoint. 
+func (srv *ILabServer) getGpuFreeHandler(w http.ResponseWriter, r *http.Request) { + srv.log.Info("GET /gpu-free called") + + cmd := exec.Command("nvidia-smi", "--query-gpu=memory.used", "--format=csv,noheader") + var out bytes.Buffer + var stderr bytes.Buffer + cmd.Stdout = &out + cmd.Stderr = &stderr + + err := cmd.Run() + if err != nil { + srv.log.Errorf("Error running nvidia-smi: %v, stderr: %s", err, stderr.String()) + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]int{"free_gpus": 0, "total_gpus": 0}) + return + } + + lines := strings.Split(strings.TrimSpace(out.String()), "\n") + freeCount := 0 + totalCount := 0 + + for _, line := range lines { + line = strings.TrimSpace(line) + if line == "" { + continue + } + totalCount++ + if strings.HasPrefix(line, "1 ") { + freeCount++ + } + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]int{ + "free_gpus": freeCount, + "total_gpus": totalCount, + }) + srv.log.Infof("GET /gpu-free => free_gpus=%d, total_gpus=%d", freeCount, totalCount) +} + +// ----------------------------------------------------------------------------- +// Jobs Handlers +// ----------------------------------------------------------------------------- + +func (srv *ILabServer) getJobStatusHandler(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + jobID := vars["job_id"] + srv.log.Infof("GET /jobs/%s/status called", jobID) + + job, err := srv.getJob(jobID) + if err != nil { + srv.log.Errorf("Error retrieving job from DB: %v", err) + http.Error(w, "Failed to retrieve job", http.StatusInternalServerError) + return + } + if job == nil { + srv.log.Infof("Job %s not found", jobID) + http.Error(w, "Job not found", http.StatusNotFound) + return + } + + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(map[string]interface{}{ + "job_id": job.JobID, + "status": job.Status, + "branch": job.Branch, + "command": 
job.Cmd, + }) + srv.log.Infof("GET /jobs/%s/status successful, status: %s", jobID, job.Status) +} + +func (srv *ILabServer) getJobLogsHandler(w http.ResponseWriter, r *http.Request) { + vars := mux.Vars(r) + jobID := vars["job_id"] + srv.log.Debugf("GET /jobs/%s/logs called", jobID) + + job, err := srv.getJob(jobID) + if err != nil { + srv.log.Errorf("Error retrieving job from DB: %v", err) + http.Error(w, "Failed to retrieve job", http.StatusInternalServerError) + return + } + if job == nil { + srv.log.Warnf("Job %s not found in DB", jobID) + http.Error(w, "Job not found", http.StatusNotFound) + return + } + + if _, err := os.Stat(job.LogFile); os.IsNotExist(err) { + srv.log.Warnf("Log file for job %s not found", jobID) + http.Error(w, "Log file not found", http.StatusNotFound) + return + } + logContent, err := ioutil.ReadFile(job.LogFile) + if err != nil { + srv.log.Errorf("Error reading log file for job %s: %v", jobID, err) + http.Error(w, "Failed to read log file", http.StatusInternalServerError) + return + } + + w.Header().Set("Content-Type", "text/plain") + _, _ = w.Write(logContent) + srv.log.Infof("GET /jobs/%s/logs successful", jobID) +} + +func (srv *ILabServer) listJobsHandler(w http.ResponseWriter, r *http.Request) { + srv.log.Debugf("GET /jobs called") + jobList, err := srv.listAllJobs() + if err != nil { + srv.log.Errorf("Error listing jobs: %v", err) + http.Error(w, "Failed to list jobs", http.StatusInternalServerError) + return + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(jobList) +} + +// ----------------------------------------------------------------------------- +// Checkpoints +// ----------------------------------------------------------------------------- + +// listCheckpointsHandler is the HTTP handler for the /checkpoints endpoint. 
+func (srv *ILabServer) listCheckpointsHandler(w http.ResponseWriter, r *http.Request) { + srv.log.Info("GET /checkpoints called") + + homeDir, err := os.UserHomeDir() + if err != nil { + srv.log.Errorf("Error getting user home directory: %v", err) + http.Error(w, "Failed to get user home directory", http.StatusInternalServerError) + return + } + + checkpointsDir := filepath.Join(homeDir, ".local", "share", "instructlab", "checkpoints", "hf_format") + if _, err := os.Stat(checkpointsDir); os.IsNotExist(err) { + srv.log.Infof("Checkpoints directory does not exist: %s", checkpointsDir) + http.Error(w, "Checkpoints directory does not exist", http.StatusNotFound) + return + } + + entries, err := ioutil.ReadDir(checkpointsDir) + if err != nil { + srv.log.Errorf("Error reading checkpoints directory: %v", err) + http.Error(w, "Failed to read checkpoints directory", http.StatusInternalServerError) + return + } + + var directories []string + for _, entry := range entries { + if entry.IsDir() { + directories = append(directories, entry.Name()) + } + } + + w.Header().Set("Content-Type", "application/json") + if err := json.NewEncoder(w).Encode(directories); err != nil { + srv.log.Errorf("Error encoding directories to JSON: %v", err) + http.Error(w, "Failed to encode directories", http.StatusInternalServerError) + return + } + srv.log.Infof("GET /checkpoints successful, %d directories returned", len(directories)) +} + +// ----------------------------------------------------------------------------- +// QnA Evaluation +// ----------------------------------------------------------------------------- + +// runQnaEval is the HTTP handler for /qna-eval. 
+func (srv *ILabServer) runQnaEval(w http.ResponseWriter, r *http.Request) { + srv.log.Info("POST /qna-eval called") + + var req QnaEvalRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + srv.log.Errorf("Error decoding request body: %v", err) + http.Error(w, "Invalid request body", http.StatusBadRequest) + return + } + + if _, err := os.Stat(req.ModelPath); os.IsNotExist(err) { + srv.log.Errorf("Model path does not exist: %s", req.ModelPath) + http.Error(w, fmt.Sprintf("Model path does not exist: %s", req.ModelPath), http.StatusBadRequest) + return + } + if _, err := os.Stat(req.YamlFile); os.IsNotExist(err) { + srv.log.Errorf("YAML file does not exist: %s", req.YamlFile) + http.Error(w, fmt.Sprintf("YAML file does not exist: %s", req.YamlFile), http.StatusBadRequest) + return + } + + homeDir, err := os.UserHomeDir() + if err != nil { + srv.log.Errorf("Failed to get user's home directory: %v", err) + http.Error(w, "Internal server error", http.StatusInternalServerError) + return + } + + cmd := exec.Command("podman", "run", "--rm", + "--device", "nvidia.com/gpu=all", + "-v", fmt.Sprintf("%s:%s", homeDir, homeDir), + "quay.io/bsalisbu/qna-eval", + "--model_path", req.ModelPath, + "--yaml_file", req.YamlFile, + ) + + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + srv.log.Infof("Executing Podman command: %v", cmd.Args) + err = cmd.Run() + if err != nil { + srv.log.Errorf("Podman command failed: %v, stderr: %s", err, stderr.String()) + response := map[string]string{"error": stderr.String()} + w.Header().Set("Content-Type", "application/json") + w.WriteHeader(http.StatusInternalServerError) + _ = json.NewEncoder(w).Encode(response) + return + } + + response := map[string]string{ + "result": stdout.String(), + } + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + srv.log.Info("POST /qna-eval completed successfully") +} + +// 
----------------------------------------------------------------------------- +// Serve Models (CPU-based) or via VLLM +// ----------------------------------------------------------------------------- + +// serveLatestCheckpointHandler serves the latest checkpoint model on port 8001. +func (srv *ILabServer) serveLatestCheckpointHandler(w http.ResponseWriter, r *http.Request) { + srv.log.Info("POST /model/serve-latest called, loading the latest checkpoint") + + // Parse the JSON request body to extract the optional "checkpoint" parameter. + var req ServeModelRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + srv.log.Errorf("Error decoding request body: %v", err) + http.Error(w, "Invalid request body", http.StatusBadRequest) + return + } + + homeDir, err := os.UserHomeDir() + if err != nil { + srv.log.Errorf("Error getting user home directory: %v", err) + http.Error(w, "Failed to get home directory", http.StatusInternalServerError) + return + } + + checkpointsDir := filepath.Join(homeDir, ".local", "share", "instructlab", "checkpoints", "hf_format") + if _, err := os.Stat(checkpointsDir); os.IsNotExist(err) { + srv.log.Errorf("Checkpoints directory does not exist: %s", checkpointsDir) + http.Error(w, "Checkpoints directory does not exist", http.StatusNotFound) + return + } + + var modelPath string + if req.Checkpoint != "" { + // If a checkpoint is provided, construct the model path accordingly. + modelPath = filepath.Join(checkpointsDir, req.Checkpoint) + srv.log.Infof("Checkpoint provided: %s", modelPath) + + // Verify that the specified checkpoint directory exists. + if _, err := os.Stat(modelPath); os.IsNotExist(err) { + srv.log.Errorf("Specified checkpoint directory does not exist: %s", modelPath) + http.Error(w, fmt.Sprintf("Checkpoint '%s' does not exist", req.Checkpoint), http.StatusBadRequest) + return + } + } else { + // If no checkpoint is provided, find the latest "samples_*" directory. 
+ latestDir, err := srv.findLatestDirWithPrefix(checkpointsDir, "samples_") + if err != nil { + srv.log.Errorf("Error finding latest checkpoint: %v", err) + http.Error(w, "Failed to find the latest checkpoint", http.StatusInternalServerError) + return + } + modelPath = latestDir + srv.log.Infof("No checkpoint provided. Using the latest checkpoint: %s", modelPath) + } + + if srv.useVllm { + // Serving using VLLM with Podman + srv.log.Infof("Serving model using vllm at %s on port 8001", modelPath) + srv.runVllmContainerHandler( + modelPath, + "8001", + "post-train", + 1, + "/var/home/cloud-user", + "/var/home/cloud-user", + w, + ) + } else { + // Basic local serve + srv.log.Infof("Serving model at %s on port 8001", modelPath) + srv.serveModelHandler(modelPath, "8001", w) + } +} + +// serveBaseModelHandler serves the "base" model on port 8000. +func (srv *ILabServer) serveBaseModelHandler(w http.ResponseWriter, r *http.Request) { + srv.log.Info("POST /model/serve-base called") + + homeDir, err := os.UserHomeDir() + if err != nil { + srv.log.Errorf("Error getting user home directory: %v", err) + http.Error(w, "Failed to get home directory", http.StatusInternalServerError) + return + } + + if srv.useVllm { + // Spawn container for "pre-train" model + baseModelPath := filepath.Join(homeDir, ".cache", "instructlab", "models", "granite-8b-starter-v1") + srv.log.Infof("Serving base model using vllm at %s on port 8000", baseModelPath) + srv.runVllmContainerHandler( + baseModelPath, + "8000", + "pre-train", + 0, + "/var/home/cloud-user", + "/var/home/cloud-user", + w, + ) + } else { + baseModelPath := filepath.Join(homeDir, ".cache", "instructlab", "models", "granite-7b-lab-Q4_K_M.gguf") + srv.log.Infof("Serving base model at %s on port 8000", baseModelPath) + srv.serveModelHandler(baseModelPath, "8000", w) + } +} + +// runVllmContainerHandler spawns a container for vllm-openai with the specified parameters. 
func (srv *ILabServer) runVllmContainerHandler(
	modelPath, port, servedModelName string,
	gpuIndex int, hostVolume, containerVolume string,
	w http.ResponseWriter,
) {
	// Assemble the podman arguments: pin the container to one GPU, mount the
	// CUDA runtime libraries and the host volume, publish the serving port,
	// and pass the vllm-openai flags after the image name.
	cmdArgs := []string{
		"run", "--rm",
		fmt.Sprintf("--device=nvidia.com/gpu=%d", gpuIndex),
		fmt.Sprintf("-e=NVIDIA_VISIBLE_DEVICES=%d", gpuIndex),
		"-v", "/usr/local/cuda-12.4/lib64:/usr/local/cuda-12.4/lib64",
		"-v", fmt.Sprintf("%s:%s", hostVolume, containerVolume),
		"-p", fmt.Sprintf("%s:%s", port, port),
		"--ipc=host",
		"vllm/vllm-openai:latest",
		"--host", "0.0.0.0",
		"--port", port,
		"--model", modelPath,
		"--load-format", "safetensors",
		"--config-format", "hf",
		"--trust-remote-code",
		"--device", "cuda",
		"--served-model-name", servedModelName,
	}

	// Log the command for debugging
	fullCmd := fmt.Sprintf("podman %s", strings.Join(cmdArgs, " "))
	srv.log.Infof("Executing Podman command: %s", fullCmd)

	// Create a unique job ID ("v-" prefix marks vllm jobs) and a log file.
	jobID := fmt.Sprintf("v-%d", time.Now().UnixNano())
	logFilePath := filepath.Join("logs", fmt.Sprintf("%s.log", jobID))
	srv.log.Infof("Starting vllm-openai container with job_id: %s, logs: %s", jobID, logFilePath)

	cmd := exec.Command("podman", cmdArgs...)

	// Open the log file; container stdout/stderr are both captured here.
	logFile, err := os.Create(logFilePath)
	if err != nil {
		srv.log.Errorf("Error creating log file for vllm job %s: %v", jobID, err)
		http.Error(w, "Failed to create log file for vllm job", http.StatusInternalServerError)
		return
	}
	cmd.Stdout = logFile
	cmd.Stderr = logFile

	// Start the container
	if err := cmd.Start(); err != nil {
		srv.log.Errorf("Error starting podman container for vllm job %s: %v", jobID, err)
		logFile.Close()
		http.Error(w, "Failed to start vllm container", http.StatusInternalServerError)
		return
	}

	srv.log.Infof("Vllm container started with PID %d for job_id: %s", cmd.Process.Pid, jobID)

	// Create a Job record and store it in the DB
	newJob := &Job{
		JobID:     jobID,
		Cmd:       "podman",
		Args:      cmdArgs,
		Status:    "running",
		PID:       cmd.Process.Pid,
		LogFile:   logFilePath,
		StartTime: time.Now(),
	}
	if err := srv.createJob(newJob); err != nil {
		srv.log.Errorf("Failed to create job in DB for %s: %v", jobID, err)
		// We won't terminate here—container is already running, so just log the DB error
	}

	// Record which job serves this model name so it can be unloaded later.
	srv.jobIDsMutex.Lock()
	srv.servedModelJobIDs[servedModelName] = jobID
	srv.jobIDsMutex.Unlock()
	srv.log.Infof("Mapped model '%s' to job ID '%s'", servedModelName, jobID)

	// Monitor the container in a background goroutine; it owns the log file
	// handle and closes it when the container exits.
	go func() {
		defer logFile.Close()

		err := cmd.Wait()
		newJob.Lock.Lock()
		defer newJob.Lock.Unlock()

		if err != nil {
			newJob.Status = "failed"
			srv.log.Errorf("Vllm job '%s' failed: %v", newJob.JobID, err)
		} else if cmd.ProcessState.Success() {
			newJob.Status = "finished"
			srv.log.Infof("Vllm job '%s' finished successfully", newJob.JobID)
		} else {
			newJob.Status = "failed"
			srv.log.Warnf("Vllm job '%s' failed (unknown reason)", newJob.JobID)
		}

		now := time.Now()
		newJob.EndTime = &now

		if errDB := srv.updateJob(newJob); errDB != nil {
			srv.log.Errorf("Failed to update DB for job '%s': %v", newJob.JobID, errDB)
		}

		// **Remove the mapping from servedModelJobIDs if job is finished or failed**
		srv.jobIDsMutex.Lock()
		delete(srv.servedModelJobIDs, servedModelName)
		srv.jobIDsMutex.Unlock()
		srv.log.Infof("Removed mapping for model '%s' from servedModelJobIDs", servedModelName)
	}()

	// Respond with the job ID immediately; the container keeps running.
	w.Header().Set("Content-Type", "application/json")
	_ = json.NewEncoder(w).Encode(map[string]string{
		"status": "vllm container started",
		"job_id": jobID,
	})
	srv.log.Infof("POST /model/serve-%s response sent successfully with job_id: %s", servedModelName, jobID)
}

// serveModelHandler starts serving a model on the specified port (CPU-based approach).
// Only ports 8000 (base) and 8001 (latest) are accepted; an existing process
// on the chosen port is killed before the new one is started.
func (srv *ILabServer) serveModelHandler(modelPath, port string, w http.ResponseWriter) {
	srv.modelLock.Lock()
	defer srv.modelLock.Unlock()

	srv.log.Infof("serveModelHandler called with modelPath=%s, port=%s", modelPath, port)

	// targetProcess points at the struct field tracking whichever port we
	// are (re)starting, so the bookkeeping below works for either port.
	var targetProcess **exec.Cmd
	if port == "8000" {
		targetProcess = &srv.modelProcessBase
	} else if port == "8001" {
		targetProcess = &srv.modelProcessLatest
	} else {
		http.Error(w, "Invalid port specified", http.StatusBadRequest)
		return
	}

	if _, err := os.Stat(modelPath); os.IsNotExist(err) {
		srv.log.Errorf("Model path does not exist: %s", modelPath)
		http.Error(w, fmt.Sprintf("Model path does not exist: %s", modelPath), http.StatusNotFound)
		return
	}

	// Stop any model already serving on this port before starting a new one.
	if *targetProcess != nil && (*targetProcess).Process != nil {
		srv.log.Infof("Stopping existing model process on port %s...", port)
		if err := (*targetProcess).Process.Kill(); err != nil {
			srv.log.Errorf("Failed to kill existing model process on port %s: %v", port, err)
			http.Error(w, "Failed to stop existing model process", http.StatusInternalServerError)
			return
		}
		*targetProcess = nil
	}

	cmdArgs := []string{
		"serve", "model",
		"--model", modelPath,
		"--host", "0.0.0.0",
		"--port", port,
	}
	cmdPath := srv.getIlabCommand()
	cmd := exec.Command(cmdPath, cmdArgs...)
	if !srv.rhelai {
		cmd.Dir = srv.baseDir
	}

	// "ml-" prefix marks model-serve jobs.
	jobID := fmt.Sprintf("ml-%d", time.Now().UnixNano())
	logFilePath := filepath.Join("logs", fmt.Sprintf("%s.log", jobID))
	srv.log.Infof("Model serve logs: %s", logFilePath)

	logFile, err := os.Create(logFilePath)
	if err != nil {
		srv.log.Errorf("Error creating model run log file: %v", err)
		http.Error(w, "Failed to create log file", http.StatusInternalServerError)
		return
	}

	cmd.Stdout = logFile
	cmd.Stderr = logFile

	srv.log.Info("Attempting to start model process...")
	if err := cmd.Start(); err != nil {
		srv.log.Errorf("Error starting model process: %v", err)
		logFile.Close()
		http.Error(w, "Failed to start model process", http.StatusInternalServerError)
		return
	}
	*targetProcess = cmd
	srv.log.Infof("Model process started with PID %d on port %s", cmd.Process.Pid, port)

	serveJob := &Job{
		JobID:     jobID,
		Cmd:       cmdPath,
		Args:      cmdArgs,
		Status:    "running",
		PID:       cmd.Process.Pid,
		LogFile:   logFilePath,
		StartTime: time.Now(),
	}
	_ = srv.createJob(serveJob)

	// Reap the process and record its final status in the background.
	go func() {
		err := cmd.Wait()
		logFile.Sync()
		logFile.Close()

		serveJob.Lock.Lock()
		defer serveJob.Lock.Unlock()

		if err != nil {
			serveJob.Status = "failed"
			srv.log.Infof("Model run job '%s' on port %s failed: %v", jobID, port, err)
		} else if cmd.ProcessState.Success() {
			serveJob.Status = "finished"
			srv.log.Infof("Model run job '%s' on port %s finished successfully", jobID, port)
		} else {
			serveJob.Status = "failed"
			srv.log.Infof("Model run job '%s' on port %s failed (unknown reason)", jobID, port)
		}
		now := time.Now()
		serveJob.EndTime = &now
		_ = srv.updateJob(serveJob)

		// Clear the tracked process pointer only if it still refers to the
		// process we started (a newer serve may have replaced it).
		srv.modelLock.Lock()
		defer srv.modelLock.Unlock()
		if port == "8000" && srv.modelProcessBase == cmd {
			srv.modelProcessBase = nil
		}
		if port == "8001" && srv.modelProcessLatest == cmd {
			srv.modelProcessLatest = nil
		}
	}()

	srv.log.Infof("Model serve started successfully on port %s, returning job_id: %s", port, jobID)
	w.Header().Set("Content-Type", "application/json")
	_ = json.NewEncoder(w).Encode(map[string]string{"status": "model process started", "job_id": jobID})
}

// listServedModelJobIDsHandler is a debug endpoint to list current model to jobID mappings.
func (srv *ILabServer) listServedModelJobIDsHandler(w http.ResponseWriter, r *http.Request) {
	srv.jobIDsMutex.RLock()
	defer srv.jobIDsMutex.RUnlock()
	_ = json.NewEncoder(w).Encode(srv.servedModelJobIDs)
}
diff --git a/api-server/jobs.go b/api-server/jobs.go
new file mode 100644
index 00000000..48d6d164
--- /dev/null
+++ b/api-server/jobs.go
@@ -0,0 +1,269 @@
package main

import (
	"database/sql"
	"encoding/json"
	"fmt"
	"os"
	"syscall"
	"time"
)

// -----------------------------------------------------------------------------
// Database
// -----------------------------------------------------------------------------

// initDB opens (or creates) a local SQLite database file named jobs.db
// and ensures a jobs table exists. Any failure here is fatal: the server
// cannot track jobs without its database.
func (srv *ILabServer) initDB() {
	var err error
	srv.db, err = sql.Open("sqlite3", "jobs.db")
	if err != nil {
		srv.log.Fatalf("Failed to open SQLite database: %v", err)
	}

	// Create the jobs table if it doesn't exist. Timestamps are stored as
	// RFC3339 TEXT; args is a JSON-encoded string slice.
	createTableSQL := `
	CREATE TABLE IF NOT EXISTS jobs (
		job_id TEXT PRIMARY KEY,
		cmd TEXT,
		args TEXT,
		status TEXT,
		pid INTEGER,
		log_file TEXT,
		start_time TEXT,
		end_time TEXT,
		branch TEXT
	);
	`
	_, err = srv.db.Exec(createTableSQL)
	if err != nil {
		srv.log.Fatalf("Failed to create jobs table: %v", err)
	}
}

// -----------------------------------------------------------------------------
// Jobs
// -----------------------------------------------------------------------------

// createJob inserts a new job row into the DB.
+func (srv *ILabServer) createJob(job *Job) error { + argsJSON, err := json.Marshal(job.Args) + if err != nil { + return fmt.Errorf("failed to marshal job Args: %v", err) + } + var endTimeStr *string + if job.EndTime != nil { + s := job.EndTime.Format(time.RFC3339) + endTimeStr = &s + } + _, err = srv.db.Exec(` + INSERT INTO jobs (job_id, cmd, args, status, pid, log_file, start_time, end_time, branch) + VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?) + `, + job.JobID, + job.Cmd, + string(argsJSON), + job.Status, + job.PID, + job.LogFile, + job.StartTime.Format(time.RFC3339), + endTimeStr, + job.Branch, + ) + if err != nil { + return fmt.Errorf("failed to insert job: %v", err) + } + return nil +} + +// getJob fetches a single job by job_id. +func (srv *ILabServer) getJob(jobID string) (*Job, error) { + row := srv.db.QueryRow("SELECT job_id, cmd, args, status, pid, log_file, start_time, end_time, branch FROM jobs WHERE job_id = ?", jobID) + + var j Job + var argsJSON string + var startTimeStr, endTimeStr sql.NullString + + err := row.Scan( + &j.JobID, + &j.Cmd, + &argsJSON, + &j.Status, + &j.PID, + &j.LogFile, + &startTimeStr, + &endTimeStr, + &j.Branch, + ) + if err == sql.ErrNoRows { + return nil, nil // not found + } else if err != nil { + return nil, err + } + + if err := json.Unmarshal([]byte(argsJSON), &j.Args); err != nil { + return nil, fmt.Errorf("failed to unmarshal job Args: %v", err) + } + if startTimeStr.Valid { + t, err := time.Parse(time.RFC3339, startTimeStr.String) + if err == nil { + j.StartTime = t + } + } + if endTimeStr.Valid && endTimeStr.String != "" { + t, err := time.Parse(time.RFC3339, endTimeStr.String) + if err == nil { + j.EndTime = &t + } + } + return &j, nil +} + +// updateJob updates an existing job in the DB. 
+func (srv *ILabServer) updateJob(job *Job) error { + argsJSON, err := json.Marshal(job.Args) + if err != nil { + return fmt.Errorf("failed to marshal job Args: %v", err) + } + var endTimeStr *string + if job.EndTime != nil { + s := job.EndTime.Format(time.RFC3339) + endTimeStr = &s + } + _, err = srv.db.Exec(` + UPDATE jobs + SET cmd = ?, args = ?, status = ?, pid = ?, log_file = ?, start_time = ?, end_time = ?, branch = ? + WHERE job_id = ? + `, + job.Cmd, + string(argsJSON), + job.Status, + job.PID, + job.LogFile, + job.StartTime.Format(time.RFC3339), + endTimeStr, + job.Branch, + job.JobID, + ) + if err != nil { + return fmt.Errorf("failed to update job %s: %v", job.JobID, err) + } + return nil +} + +// listAllJobs returns all jobs in the DB. +func (srv *ILabServer) listAllJobs() ([]*Job, error) { + rows, err := srv.db.Query("SELECT job_id, cmd, args, status, pid, log_file, start_time, end_time, branch FROM jobs") + if err != nil { + return nil, err + } + defer rows.Close() + + var jobs []*Job + for rows.Next() { + var j Job + var argsJSON string + var startTimeStr, endTimeStr sql.NullString + + err := rows.Scan( + &j.JobID, + &j.Cmd, + &argsJSON, + &j.Status, + &j.PID, + &j.LogFile, + &startTimeStr, + &endTimeStr, + &j.Branch, + ) + if err != nil { + return nil, err + } + + if err := json.Unmarshal([]byte(argsJSON), &j.Args); err != nil { + srv.log.Infof("Warning: failed to unmarshal job Args for job %s: %v", j.JobID, err) + } + if startTimeStr.Valid { + t, err := time.Parse(time.RFC3339, startTimeStr.String) + if err == nil { + j.StartTime = t + } + } + if endTimeStr.Valid && endTimeStr.String != "" { + t, err := time.Parse(time.RFC3339, endTimeStr.String) + if err == nil { + j.EndTime = &t + } + } + jobs = append(jobs, &j) + } + + return jobs, rows.Err() +} + +// ----------------------------------------------------------------------------- +// Checking Running Jobs after a server restart +// 
----------------------------------------------------------------------------- + +// checkRunningJobs checks the status of "running" jobs and marks them as failed if their processes are not running. +func (srv *ILabServer) checkRunningJobs() { + rows, err := srv.db.Query("SELECT job_id, pid FROM jobs WHERE status = 'running'") + if err != nil { + srv.log.Errorf("Error querying running jobs: %v", err) + return + } + defer rows.Close() + + var jobsToMarkFailed []string + for rows.Next() { + var jobID string + var pid int + if err := rows.Scan(&jobID, &pid); err != nil { + srv.log.Infof("Error scanning row of running jobs: %v", err) + continue + } + if !srv.isProcessRunning(pid) { + srv.log.Infof("Job %s marked as failed (process not running)", jobID) + jobsToMarkFailed = append(jobsToMarkFailed, jobID) + } + } + + // 3. Update jobs that are no longer running + for _, jobID := range jobsToMarkFailed { + endTime := time.Now() + j, err := srv.getJob(jobID) + if err != nil || j == nil { + srv.log.Infof("Unable to fetch jobID=%s to mark as failed: %v", jobID, err) + continue + } + j.Status = "failed" + j.EndTime = &endTime + if err := srv.updateJob(j); err != nil { + srv.log.Infof("Error marking job %s as failed: %v", jobID, err) + } + } +} + +// isProcessRunning checks if a process with the given PID is still alive. 
+func (srv *ILabServer) isProcessRunning(pid int) bool { + if pid <= 0 { + if srv.debugEnabled { + srv.log.Debugf("[DEBUG] isProcessRunning called with invalid PID=%d", pid) + } + return false + } + process, err := os.FindProcess(pid) + if err != nil { + if srv.debugEnabled { + srv.log.Debugf("[DEBUG] os.FindProcess error: %v", err) + } + return false + } + + err = process.Signal(syscall.Signal(0)) + if srv.debugEnabled { + srv.log.Debugf("[DEBUG] process.Signal(0) on PID %d => err=%v", pid, err) + } + return err == nil +} diff --git a/api-server/main.go b/api-server/main.go new file mode 100644 index 00000000..094cbd9c --- /dev/null +++ b/api-server/main.go @@ -0,0 +1,865 @@ +package main + +import ( + "database/sql" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/gorilla/mux" + _ "github.com/mattn/go-sqlite3" + "github.com/spf13/cobra" + "go.uber.org/zap" +) + +// ----------------------------------------------------------------------------- +// Structs +// ----------------------------------------------------------------------------- + +// Model represents a model record (from 'ilab model list'). +type Model struct { + Name string `json:"name"` + LastModified string `json:"last_modified"` + Size string `json:"size"` +} + +// Data represents a data record (from 'ilab data list'). +type Data struct { + Dataset string `json:"dataset"` + CreatedAt string `json:"created_at"` + FileSize string `json:"file_size"` +} + +// Job represents a background job, including train/generate/pipeline/vllm-run jobs. 
+type Job struct { + JobID string `json:"job_id"` + Cmd string `json:"cmd"` + Args []string `json:"args"` + Status string `json:"status"` // "running", "finished", "failed" + PID int `json:"pid"` + LogFile string `json:"log_file"` + StartTime time.Time `json:"start_time"` + EndTime *time.Time `json:"end_time,omitempty"` + Branch string `json:"branch"` + + // Lock is not serialized; it protects updates to the Job in memory. + Lock sync.Mutex `json:"-"` +} + +// ModelCache encapsulates the cached models and related metadata. +type ModelCache struct { + Models []Model + Time time.Time + Mutex sync.Mutex +} + +// QnaEvalRequest is used by the /qna-eval endpoint. +type QnaEvalRequest struct { + ModelPath string `json:"model_path"` + YamlFile string `json:"yaml_file"` +} + +// VllmContainerResponse is returned by the /vllm-containers endpoint. +type VllmContainerResponse struct { + Containers []VllmContainer `json:"containers"` +} + +type ServeModelRequest struct { + Checkpoint string `json:"checkpoint,omitempty"` // Optional: Name of the checkpoint directory (e.g., "samples_12345") +} + +// UnloadModelRequest is used by the /vllm-unload endpoint. +type UnloadModelRequest struct { + ModelName string `json:"model_name"` // Expected: "pre-train" or "post-train" +} + +// ----------------------------------------------------------------------------- +// ILabServer struct to hold configuration, DB handle, logs, etc. 
// -----------------------------------------------------------------------------

type ILabServer struct {
	baseDir      string // root of the local ilab checkout (unused with --rhelai)
	taxonomyPath string // path to the taxonomy git repository
	rhelai       bool   // true => use the ilab binary from $PATH
	ilabCmd      string // resolved path of the ilab executable
	isOSX        bool   // pass --device=mps to training
	isCuda       bool   // pass --device=cuda to training
	useVllm      bool   // serve models through podman vllm containers
	pipelineType string // "simple", "accelerated" or "full"
	debugEnabled bool

	// Logger
	logger *zap.Logger
	log    *zap.SugaredLogger

	// Database handle
	db *sql.DB

	// Model processes for CPU-based or local serving (if not using VLLM)
	modelLock          sync.Mutex
	modelProcessBase   *exec.Cmd
	modelProcessLatest *exec.Cmd

	// Base model reference (used in some logic but not necessary if VLLM is used with default paths)
	baseModel string

	// Map of "pre-train"/"post-train" => jobID for VLLM serving
	servedModelJobIDs map[string]string
	jobIDsMutex       sync.RWMutex

	// Cache variables
	modelCache ModelCache
}

// -----------------------------------------------------------------------------
// main(), flags and Cobra
// -----------------------------------------------------------------------------

func main() {
	// We create an instance of ILabServer to hold all state and methods.
	srv := &ILabServer{
		baseModel:         "instructlab/granite-7b-lab",
		servedModelJobIDs: make(map[string]string),
		modelCache:        ModelCache{},
	}

	rootCmd := &cobra.Command{
		Use:   "ilab-server",
		Short: "ILab Server Application",
		Run: func(cmd *cobra.Command, args []string) {
			// Now that flags are set, run the server method on the struct.
			srv.runServer(cmd, args)
		},
	}

	// Define flags
	rootCmd.Flags().BoolVar(&srv.rhelai, "rhelai", false, "Use ilab binary from PATH instead of Python virtual environment")
	rootCmd.Flags().StringVar(&srv.baseDir, "base-dir", "", "Base directory for ilab operations (required if --rhelai is not set)")
	rootCmd.Flags().StringVar(&srv.taxonomyPath, "taxonomy-path", "", "Path to the taxonomy repository for Git operations (required)")
	rootCmd.Flags().BoolVar(&srv.isOSX, "osx", false, "Enable OSX-specific settings (default: false)")
	rootCmd.Flags().BoolVar(&srv.isCuda, "cuda", false, "Enable Cuda (default: false)")
	rootCmd.Flags().BoolVar(&srv.useVllm, "vllm", false, "Enable VLLM model serving using podman containers")
	rootCmd.Flags().StringVar(&srv.pipelineType, "pipeline", "", "Pipeline type (simple, accelerated, full)")
	rootCmd.Flags().BoolVar(&srv.debugEnabled, "debug", false, "Enable debug logging")

	// PreRun to validate flag combinations before the server starts.
	rootCmd.PreRunE = func(cmd *cobra.Command, args []string) error {
		if !srv.rhelai && srv.baseDir == "" {
			return fmt.Errorf("--base-dir is required unless --rhelai is set")
		}
		if srv.taxonomyPath == "" {
			return fmt.Errorf("--taxonomy-path is required")
		}

		// Validate or set pipelineType based on --rhelai
		if !srv.rhelai {
			if srv.pipelineType == "" {
				return fmt.Errorf("--pipeline is required unless --rhelai is set")
			}
			switch srv.pipelineType {
			case "simple", "full", "accelerated":
				// Valid
			default:
				return fmt.Errorf("--pipeline must be 'simple', 'accelerated' or 'full'; got '%s'", srv.pipelineType)
			}
		} else {
			// When --rhelai is set and --pipeline is not provided, set a default
			if srv.pipelineType == "" {
				srv.pipelineType = "accelerated"
				fmt.Println("--rhelai is set; defaulting --pipeline to 'accelerated'")
			} else {
				switch srv.pipelineType {
				case "simple", "full", "accelerated":
					// Valid
				default:
					return fmt.Errorf("--pipeline must be 'simple', 'accelerated' or 'full'; got '%s'", srv.pipelineType)
				}
			}
		}
		return nil
	}

	if err := rootCmd.Execute(); err != nil {
		fmt.Printf("Error executing command: %v\n", err)
		os.Exit(1)
	}
}

// runServer is the main entry method after flags are parsed. It wires up the
// logger, database, ilab command path, background checks and HTTP routes, then
// blocks in ListenAndServe.
func (srv *ILabServer) runServer(cmd *cobra.Command, args []string) {
	// Initialize zap logger
	srv.initLogger(srv.debugEnabled)

	if srv.debugEnabled {
		srv.log.Info("Debug logging is ENABLED.")
	} else {
		srv.log.Info("Debug logging is DISABLED.")
	}

	// Initialize the database
	srv.initDB()

	// Determine ilab command path
	if srv.rhelai {
		// Use ilab from PATH
		ilabPath, err := exec.LookPath("ilab")
		if err != nil {
			srv.log.Fatalf("ilab binary not found in PATH. Please ensure ilab is installed and in your PATH.")
		}
		srv.ilabCmd = ilabPath
	} else {
		// Use ilab from virtual environment
		srv.ilabCmd = filepath.Join(srv.baseDir, "venv", "bin", "ilab")
		if _, err := os.Stat(srv.ilabCmd); os.IsNotExist(err) {
			srv.log.Fatalf("ilab binary not found at %s. Please ensure the virtual environment is set up correctly.", srv.ilabCmd)
		}
	}

	srv.log.Infof("Using ilab command: %s", srv.ilabCmd)

	// Validate mandatory arguments if not using rhelai
	if !srv.rhelai {
		if _, err := os.Stat(srv.baseDir); os.IsNotExist(err) {
			srv.log.Fatalf("Base directory does not exist: %s", srv.baseDir)
		}
	}

	if _, err := os.Stat(srv.taxonomyPath); os.IsNotExist(err) {
		srv.log.Fatalf("Taxonomy path does not exist: %s", srv.taxonomyPath)
	}

	srv.log.Infof("Running with baseDir=%s, taxonomyPath=%s, isOSX=%v, isCuda=%v, useVllm=%v, pipeline=%s",
		srv.baseDir, srv.taxonomyPath, srv.isOSX, srv.isCuda, srv.useVllm, srv.pipelineType)
	srv.log.Infof("Current working directory: %s", srv.mustGetCwd())

	// Check statuses of any jobs that might have been running before a restart
	srv.checkRunningJobs()

	// Initialize the model cache
	srv.initializeModelCache()

	// Create the logs directory if it doesn't exist
	err := os.MkdirAll("logs", os.ModePerm)
	if err != nil {
		srv.log.Fatalf("Failed to create logs directory: %v", err)
	}

	// Setup HTTP routes
	r := mux.NewRouter()
	r.HandleFunc("/models", srv.getModelsHandler).Methods("GET")
	r.HandleFunc("/data", srv.getDataHandler).Methods("GET")
	r.HandleFunc("/data/generate", srv.generateDataHandler).Methods("POST")
	r.HandleFunc("/model/train", srv.trainModelHandler).Methods("POST")
	r.HandleFunc("/jobs/{job_id}/status", srv.getJobStatusHandler).Methods("GET")
	r.HandleFunc("/jobs/{job_id}/logs", srv.getJobLogsHandler).Methods("GET")
	r.HandleFunc("/jobs", srv.listJobsHandler).Methods("GET")
	r.HandleFunc("/pipeline/generate-train", srv.generateTrainPipelineHandler).Methods("POST")
	r.HandleFunc("/model/serve-latest", srv.serveLatestCheckpointHandler).Methods("POST")
	r.HandleFunc("/model/serve-base", srv.serveBaseModelHandler).Methods("POST")
	r.HandleFunc("/qna-eval", srv.runQnaEval).Methods("POST")
	r.HandleFunc("/checkpoints", srv.listCheckpointsHandler).Methods("GET")
	r.HandleFunc("/vllm-containers", srv.listVllmContainersHandler).Methods("GET")
	r.HandleFunc("/vllm-unload", srv.unloadVllmContainerHandler).Methods("POST")
	r.HandleFunc("/vllm-status", srv.getVllmStatusHandler).Methods("GET")
	r.HandleFunc("/gpu-free", srv.getGpuFreeHandler).Methods("GET")
	r.HandleFunc("/served-model-jobids", srv.listServedModelJobIDsHandler).Methods("GET")

	// NOTE(review): http.ListenAndServe has no read/write timeouts; consider
	// an http.Server with at least ReadHeaderTimeout for production use.
	srv.log.Info("Server starting on port 8080... (Taxonomy path: ", srv.taxonomyPath, ")")
	if err := http.ListenAndServe("0.0.0.0:8080", r); err != nil {
		srv.log.Fatalf("Server failed to start: %v", err)
	}
}

// -----------------------------------------------------------------------------
// Utility
// -----------------------------------------------------------------------------

// getIlabCommand returns the path to the ilab command, depending on rhelai or local venv.
func (srv *ILabServer) getIlabCommand() string {
	return srv.ilabCmd
}

// mustGetCwd returns the current working directory or "unknown" if it fails.
func (srv *ILabServer) mustGetCwd() string {
	cwd, err := os.Getwd()
	if err != nil {
		return "unknown"
	}
	return cwd
}

// sanitizeModelName checks if the modelName starts with "model/" and replaces it with "models/".
func (srv *ILabServer) sanitizeModelName(modelName string) string {
	if strings.HasPrefix(modelName, "model/") {
		return strings.Replace(modelName, "model/", "models/", 1)
	}
	return modelName
}

// -----------------------------------------------------------------------------
// Model Cache
// -----------------------------------------------------------------------------

// initializeModelCache refreshes the model cache once and then schedules a refresh every 20 minutes.
+func (srv *ILabServer) initializeModelCache() { + srv.refreshModelCache() + go func() { + for { + time.Sleep(20 * time.Minute) + srv.refreshModelCache() + } + }() +} + +// refreshModelCache updates the model cache if it's older than 20 minutes or if empty. +// TODO: this is really slow due to a caching issue upstream/downstream, should probably be async +func (srv *ILabServer) refreshModelCache() { + srv.modelCache.Mutex.Lock() + defer srv.modelCache.Mutex.Unlock() + + if time.Since(srv.modelCache.Time) < 20*time.Minute && len(srv.modelCache.Models) > 0 { + srv.log.Info("Model cache is still valid; no refresh needed.") + return + } + + srv.log.Info("Refreshing model cache... Takes 10-20s") + output, err := srv.runIlabCommand("model", "list") + if err != nil { + srv.log.Errorf("Error refreshing model cache: %v", err) + return + } + models, err := srv.parseModelList(output) + if err != nil { + srv.log.Errorf("Error parsing model list during cache refresh: %v", err) + return + } + srv.modelCache.Models = models + srv.modelCache.Time = time.Now() + srv.log.Infof("Model cache refreshed at %v with %d models.", srv.modelCache.Time, len(models)) +} + +// ----------------------------------------------------------------------------- +// Start Generate Data Job +// ----------------------------------------------------------------------------- + +// startGenerateJob launches a job to run "ilab data generate" and tracks it. +func (srv *ILabServer) startGenerateJob() (string, error) { + ilabPath := srv.getIlabCommand() + + // Hard-coded pipeline choice for data generate, or we could use srv.pipelineType + cmdArgs := []string{"data", "generate", "--pipeline", "full"} + + cmd := exec.Command(ilabPath, cmdArgs...) 
+ if !srv.rhelai { + cmd.Dir = srv.baseDir + } + + jobID := fmt.Sprintf("g-%d", time.Now().UnixNano()) + logFilePath := filepath.Join("logs", fmt.Sprintf("%s.log", jobID)) + srv.log.Infof("Starting generateDataHandler job: %s, logs: %s", jobID, logFilePath) + + logFile, err := os.Create(logFilePath) + if err != nil { + srv.log.Errorf("Error creating log file: %v", err) + return "", fmt.Errorf("Failed to create log file") + } + cmd.Stdout = logFile + cmd.Stderr = logFile + + srv.log.Infof("Running command: %s %v", ilabPath, cmdArgs) + if err := cmd.Start(); err != nil { + srv.log.Errorf("Error starting data generation command: %v", err) + logFile.Close() + return "", err + } + + newJob := &Job{ + JobID: jobID, + Cmd: ilabPath, + Args: cmdArgs, + Status: "running", + PID: cmd.Process.Pid, + LogFile: logFilePath, + StartTime: time.Now(), + } + if err := srv.createJob(newJob); err != nil { + srv.log.Errorf("Error creating job in DB: %v", err) + return "", err + } + + go func() { + defer logFile.Close() + err := cmd.Wait() + + newJob.Lock.Lock() + defer newJob.Lock.Unlock() + + if err != nil { + newJob.Status = "failed" + srv.log.Infof("Job %s failed with error: %v", newJob.JobID, err) + } else { + if cmd.ProcessState.Success() { + newJob.Status = "finished" + srv.log.Infof("Job %s finished successfully", newJob.JobID) + } else { + newJob.Status = "failed" + srv.log.Infof("Job %s failed", newJob.JobID) + } + } + now := time.Now() + newJob.EndTime = &now + _ = srv.updateJob(newJob) + }() + + return jobID, nil +} + +// ----------------------------------------------------------------------------- +// Start Train Job +// ----------------------------------------------------------------------------- + +// startTrainJob starts a training job with the given parameters. 
+func (srv *ILabServer) startTrainJob(modelName, branchName string, epochs *int) (string, error) { + srv.log.Infof("Starting training job for model: '%s', branch: '%s'", modelName, branchName) + + jobID := fmt.Sprintf("t-%d", time.Now().UnixNano()) + logFilePath := filepath.Join("logs", fmt.Sprintf("%s.log", jobID)) + + fullModelPath, err := srv.getFullModelPath(modelName) + if err != nil { + return "", fmt.Errorf("failed to get full model path: %v", err) + } + srv.log.Infof("Resolved fullModelPath: '%s'", fullModelPath) + + modelDir := filepath.Dir(fullModelPath) + if err := os.MkdirAll(modelDir, os.ModePerm); err != nil { + return "", fmt.Errorf("failed to create model directory '%s': %v", modelDir, err) + } + + ilabPath := srv.getIlabCommand() + + var cmdArgs []string + cmdArgs = append(cmdArgs, "model", "train") + + // If not rhelai, add pipeline if set + if !srv.rhelai && srv.pipelineType != "" { + cmdArgs = append(cmdArgs, "--pipeline", srv.pipelineType) + } + cmdArgs = append(cmdArgs, fmt.Sprintf("--model-path=%s", fullModelPath)) + + if srv.isOSX { + cmdArgs = append(cmdArgs, "--device=mps") + } + if srv.isCuda { + cmdArgs = append(cmdArgs, "--device=cuda") + } + if epochs != nil { + cmdArgs = append(cmdArgs, fmt.Sprintf("--num-epochs=%d", *epochs)) + srv.log.Infof("Number of epochs specified: %d", *epochs) + } else { + srv.log.Info("No epochs specified; using default number of epochs.") + } + + // Additional logic if pipelineType == "simple" (and not rhelai) + if srv.pipelineType == "simple" && !srv.rhelai { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("failed to get user home directory: %v", err) + } + datasetDir := filepath.Join(homeDir, ".local", "share", "instructlab", "datasets") + + // Copy the latest knowledge_train_msgs_*.jsonl => train_gen.jsonl + latestTrainFile, err := srv.findLatestFileWithPrefix(datasetDir, "knowledge_train_msgs_") + if err != nil { + return "", fmt.Errorf("failed to find 
knowledge_train_msgs_*.jsonl file: %v", err) + } + trainGenPath := filepath.Join(datasetDir, "train_gen.jsonl") + if err := srv.overwriteCopy(latestTrainFile, trainGenPath); err != nil { + return "", fmt.Errorf("failed to copy %s to %s: %v", latestTrainFile, trainGenPath, err) + } + + // Copy the latest test_ggml-model-*.jsonl => test_gen.jsonl + latestTestFile, err := srv.findLatestFileWithPrefix(datasetDir, "test_ggml-model") + if err != nil { + return "", fmt.Errorf("failed to find test_ggml-model*.jsonl file: %v", err) + } + testGenPath := filepath.Join(datasetDir, "test_gen.jsonl") + if err := srv.overwriteCopy(latestTestFile, testGenPath); err != nil { + return "", fmt.Errorf("failed to copy %s to %s: %v", latestTestFile, testGenPath, err) + } + + // Reset cmdArgs to a simpler set + cmdArgs = []string{ + "model", "train", + "--pipeline", srv.pipelineType, + fmt.Sprintf("--data-path=%s", datasetDir), + fmt.Sprintf("--model-path=%s", fullModelPath), + } + if srv.isOSX { + cmdArgs = append(cmdArgs, "--device=mps") + } + if srv.isCuda { + cmdArgs = append(cmdArgs, "--device=cuda") + } + if epochs != nil { + cmdArgs = append(cmdArgs, fmt.Sprintf("--num-epochs=%d", *epochs)) + srv.log.Infof("Number of epochs specified for simple pipeline: %d", *epochs) + } else { + srv.log.Info("No epochs specified for simple pipeline; using default number of epochs.") + } + } + + if srv.rhelai { + latestDataset, err := srv.getLatestDatasetFile() + if err != nil { + return "", fmt.Errorf("failed to get latest dataset file: %v", err) + } + cmdArgs = []string{ + "model", "train", + fmt.Sprintf("--data-path=%s", latestDataset), + "--max-batch-len=5000", + "--gpus=4", + "--device=cuda", + "--save-samples=1000", + fmt.Sprintf("--model-path=%s", fullModelPath), + "--pipeline", srv.pipelineType, + } + if epochs != nil { + cmdArgs = append(cmdArgs, fmt.Sprintf("--num-epochs=%d", *epochs)) + srv.log.Infof("Number of epochs specified for rhelai pipeline: %d", *epochs) + } else { + 
srv.log.Info("No epochs specified for rhelai pipeline; using default number of epochs.") + } + } + + srv.log.Infof("[ILAB TRAIN COMMAND] %s %v", ilabPath, cmdArgs) + + cmd := exec.Command(ilabPath, cmdArgs...) + if !srv.rhelai { + cmd.Dir = srv.baseDir + } + + logFile, err := os.Create(logFilePath) + if err != nil { + return "", fmt.Errorf("failed to create log file '%s': %v", logFilePath, err) + } + defer logFile.Close() + + cmd.Stdout = logFile + cmd.Stderr = logFile + + srv.log.Infof("[ILAB TRAIN COMMAND] %s %v", ilabPath, cmdArgs) + if err := cmd.Start(); err != nil { + return "", fmt.Errorf("error starting training command: %v", err) + } + srv.log.Infof("Training process started with PID: %d", cmd.Process.Pid) + + newJob := &Job{ + JobID: jobID, + Cmd: ilabPath, + Args: cmdArgs, + Status: "running", + PID: cmd.Process.Pid, + LogFile: logFilePath, + Branch: branchName, + StartTime: time.Now(), + } + if err := srv.createJob(newJob); err != nil { + return "", fmt.Errorf("failed to create job in DB: %v", err) + } + + go func() { + defer logFile.Close() + err := cmd.Wait() + + newJob.Lock.Lock() + defer newJob.Lock.Unlock() + + if err != nil { + newJob.Status = "failed" + srv.log.Infof("Training job '%s' failed: %v", newJob.JobID, err) + } else if cmd.ProcessState.Success() { + newJob.Status = "finished" + srv.log.Infof("Training job '%s' finished successfully", newJob.JobID) + } else { + newJob.Status = "failed" + srv.log.Infof("Training job '%s' failed (unknown reason)", newJob.JobID) + } + now := time.Now() + newJob.EndTime = &now + _ = srv.updateJob(newJob) + }() + + return jobID, nil +} + +// ----------------------------------------------------------------------------- +// Pipeline +// ----------------------------------------------------------------------------- + +func (srv *ILabServer) generateTrainPipelineHandler(w http.ResponseWriter, r *http.Request) { + srv.log.Info("POST /pipeline/generate-train called") + + var reqBody struct { + ModelName string 
`json:"modelName"` + BranchName string `json:"branchName"` + Epochs *int `json:"epochs,omitempty"` + } + if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + srv.log.Errorf("Error parsing request body: %v", err) + http.Error(w, "Invalid request body", http.StatusBadRequest) + return + } + if reqBody.ModelName == "" || reqBody.BranchName == "" { + srv.log.Info("Missing required parameters: modelName or branchName") + http.Error(w, "Missing required parameters: modelName or branchName", http.StatusBadRequest) + return + } + + sanitizedModelName := srv.sanitizeModelName(reqBody.ModelName) + srv.log.Infof("Sanitized modelName for pipeline: '%s'", sanitizedModelName) + + pipelineJobID := fmt.Sprintf("p-%d", time.Now().UnixNano()) + srv.log.Infof("Starting pipeline job with ID: %s", pipelineJobID) + + pipelineJob := &Job{ + JobID: pipelineJobID, + Cmd: "pipeline-generate-train", + Args: []string{sanitizedModelName, reqBody.BranchName}, + Status: "running", + PID: 0, // no direct OS process + LogFile: fmt.Sprintf("logs/%s.log", pipelineJobID), + Branch: reqBody.BranchName, + StartTime: time.Now(), + } + if err := srv.createJob(pipelineJob); err != nil { + srv.log.Errorf("Error creating pipeline job: %v", err) + http.Error(w, "Failed to create pipeline job", http.StatusInternalServerError) + return + } + + go srv.runPipelineJob(pipelineJob, sanitizedModelName, reqBody.BranchName, reqBody.Epochs) + + response := map[string]string{"pipeline_job_id": pipelineJobID} + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + srv.log.Infof("POST /pipeline/generate-train => pipeline_job_id=%s", pipelineJobID) +} + +// runPipelineJob orchestrates data generate + model train steps in sequence. 
+func (srv *ILabServer) runPipelineJob(job *Job, modelName, branchName string, epochs *int) { + // Open the pipeline job log + logFile, err := os.Create(job.LogFile) + if err != nil { + srv.log.Errorf("Error creating pipeline log file for job %s: %v", job.JobID, err) + job.Status = "failed" + _ = srv.updateJob(job) + return + } + defer logFile.Close() + + stdLogger := zap.NewStdLog(srv.logger) + + // Redirect that standard logger's output to our log file + stdLogger.SetOutput(logFile) + + stdLogger.Printf("Starting pipeline job: %s, model: %s, branch: %s, epochs: %v", + job.JobID, modelName, branchName, epochs) + + // 1) Git checkout + gitCheckoutCmd := exec.Command("git", "checkout", branchName) + gitCheckoutCmd.Dir = srv.taxonomyPath + gitOutput, gitErr := gitCheckoutCmd.CombinedOutput() + stdLogger.Printf("Git checkout output: %s", string(gitOutput)) + if gitErr != nil { + stdLogger.Printf("Failed to checkout branch '%s': %v", branchName, gitErr) + job.Status = "failed" + _ = srv.updateJob(job) + return + } + + // 2) Generate data step + stdLogger.Println("Starting data generation step...") + genJobID, genErr := srv.startGenerateJob() + if genErr != nil { + stdLogger.Printf("Data generation step failed: %v", genErr) + job.Status = "failed" + _ = srv.updateJob(job) + return + } + stdLogger.Printf("Data generation step started with job_id=%s", genJobID) + + for { + time.Sleep(5 * time.Second) + genJob, err := srv.getJob(genJobID) + if err != nil || genJob == nil { + stdLogger.Printf("Data generation job %s not found or error: %v", genJobID, err) + job.Status = "failed" + _ = srv.updateJob(job) + return + } + if genJob.Status == "failed" { + stdLogger.Println("Data generation step failed.") + job.Status = "failed" + _ = srv.updateJob(job) + return + } + if genJob.Status == "finished" { + stdLogger.Println("Data generation step completed successfully.") + break + } + } + + // 3) Train step + stdLogger.Println("Starting training step...") + trainJobID, trainErr := 
srv.startTrainJob(modelName, branchName, epochs) + if trainErr != nil { + stdLogger.Printf("Training step failed to start: %v", trainErr) + job.Status = "failed" + _ = srv.updateJob(job) + return + } + stdLogger.Printf("Training step started with job_id=%s", trainJobID) + + for { + time.Sleep(5 * time.Second) + tJob, err := srv.getJob(trainJobID) + if err != nil || tJob == nil { + stdLogger.Printf("Training job %s not found or error: %v", trainJobID, err) + job.Status = "failed" + _ = srv.updateJob(job) + return + } + if tJob.Status == "failed" { + stdLogger.Println("Training step failed.") + job.Status = "failed" + _ = srv.updateJob(job) + return + } + if tJob.Status == "finished" { + stdLogger.Println("Training step completed successfully.") + break + } + } + + job.Status = "finished" + _ = srv.updateJob(job) + stdLogger.Println("Pipeline job completed successfully.") +} + +// findLatestFileWithPrefix returns the newest file in dir that starts with prefix. +func (srv *ILabServer) findLatestFileWithPrefix(dir, prefix string) (string, error) { + files, err := ioutil.ReadDir(dir) + if err != nil { + return "", fmt.Errorf("failed to read directory '%s': %v", dir, err) + } + var latest os.FileInfo + for _, f := range files { + if f.IsDir() { + continue + } + if strings.HasPrefix(f.Name(), prefix) && strings.HasSuffix(f.Name(), ".jsonl") { + if latest == nil || f.ModTime().After(latest.ModTime()) { + latest = f + } + } + } + if latest == nil { + return "", fmt.Errorf("no file found in %s with prefix '%s'", dir, prefix) + } + return filepath.Join(dir, latest.Name()), nil +} + +// overwriteCopy copies src to dst (overwrites if dst exists). +func (srv *ILabServer) overwriteCopy(src, dst string) error { + input, err := ioutil.ReadFile(src) + if err != nil { + return err + } + if err := ioutil.WriteFile(dst, input, 0644); err != nil { + return err + } + return nil +} + +// getFullModelPath returns the directory or file path for a given model name. 
+func (srv *ILabServer) getFullModelPath(modelName string) (string, error) { + // If the user passed something like "models/instructlab/my-model" we keep it + // but place it in ~/.cache/instructlab/models/... + home, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("cannot find home directory: %v", err) + } + base := filepath.Join(home, ".cache", "instructlab") + return filepath.Join(base, modelName), nil +} + +// runIlabCommand executes the ilab command with the provided arguments and returns combined output. +func (srv *ILabServer) runIlabCommand(args ...string) (string, error) { + cmdPath := srv.getIlabCommand() + cmd := exec.Command(cmdPath, args...) + if !srv.rhelai { + cmd.Dir = srv.baseDir + } + out, err := cmd.CombinedOutput() + return string(out), err +} + +// parseModelList parses the output of the "ilab model list" command into a slice of Model. +func (srv *ILabServer) parseModelList(output string) ([]Model, error) { + var models []Model + lines := strings.Split(output, "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "+") || strings.HasPrefix(line, "| Model Name") || line == "" { + continue + } + if strings.HasPrefix(line, "|") { + line = strings.Trim(line, "|") + fields := strings.Split(line, "|") + if len(fields) != 3 { + continue + } + model := Model{ + Name: strings.TrimSpace(fields[0]), + LastModified: strings.TrimSpace(fields[1]), + Size: strings.TrimSpace(fields[2]), + } + models = append(models, model) + } + } + return models, nil +} + +// parseDataList parses the output of the "ilab data list" command into a slice of Data. 
+func (srv *ILabServer) parseDataList(output string) ([]Data, error) {
+	var dataList []Data
+	lines := strings.Split(output, "\n")
+	for _, line := range lines {
+		line = strings.TrimSpace(line)
+		if strings.HasPrefix(line, "+") || strings.HasPrefix(line, "| Dataset") || line == "" {
+			continue
+		}
+		if strings.HasPrefix(line, "|") {
+			line = strings.Trim(line, "|")
+			fields := strings.Split(line, "|")
+			if len(fields) != 3 {
+				continue
+			}
+			data := Data{
+				Dataset:   strings.TrimSpace(fields[0]),
+				CreatedAt: strings.TrimSpace(fields[1]),
+				FileSize:  strings.TrimSpace(fields[2]),
+			}
+			dataList = append(dataList, data)
+		}
+	}
+	return dataList, nil
+}
diff --git a/api-server/qna-eval/Containerfile b/api-server/qna-eval/Containerfile
new file mode 100644
index 00000000..02a5fdca
--- /dev/null
+++ b/api-server/qna-eval/Containerfile
@@ -0,0 +1,40 @@
+# Podman container toolkit support https://docs.nvidia.com/ai-enterprise/deployment/rhel-with-kvm/latest/podman.html
+FROM python:3.11-slim
+
+RUN apt-get update && apt-get install -y --no-install-recommends \
+    build-essential \
+    && rm -rf /var/lib/apt/lists/*
+
+# Set the working directory inside the container
+WORKDIR /app
+
+# Copy the requirements.txt to the container
+COPY requirements.txt .
+
+# Install Python dependencies
+RUN pip install --no-cache-dir -r requirements.txt
+
+# Copy the Python script to the container
+COPY qna-eval.py .
+
+# Set the entrypoint to execute the Python script
+ENTRYPOINT ["python", "qna-eval.py"]
+
+# Build the container image
+# podman build -t qna-eval .
+ +# Run the container with the necessary arguments and volume mount +# podman run --rm \ +# --device nvidia.com/gpu=1 \ +# -v /var/home/cloud-user/:/var/home/cloud-user/ \ +# qna-eval \ +# --model_path "/var/home/cloud-user/.local/share/instructlab/checkpoints/samples_134632/" \ +# --yaml_file "/var/home/cloud-user/.local/share/instructlab/taxonomy/knowledge/history/foo/qna.yaml" + +# Or run from this quay repo +# podman run --rm \ +# --device nvidia.com/gpu=1 \ +# -v /var/home/cloud-user/:/var/home/cloud-user/ \ +# quay.io/bsalisbu/qna-eval \ +# --model_path "/var/home/cloud-user/.local/share/instructlab/checkpoints/samples_134632/" \ +# --yaml_file "/var/home/cloud-user/.local/share/instructlab/taxonomy/knowledge/history/foo/qna.yaml" diff --git a/api-server/qna-eval/qna-eval.py b/api-server/qna-eval/qna-eval.py new file mode 100644 index 00000000..825b594a --- /dev/null +++ b/api-server/qna-eval/qna-eval.py @@ -0,0 +1,101 @@ +import argparse +import yaml +from vllm import LLM, SamplingParams + +def extract_questions(yaml_file): + """ + Extracts all questions from the 'questions_and_answers' sections of the YAML file. + + Args: + yaml_file (str): Path to the qna.yaml file. + + Returns: + list: A list of questions extracted from the YAML file. + """ + with open(yaml_file, 'r') as f: + data = yaml.safe_load(f) + + questions = [] + # Navigate through the YAML structure to find all questions + seed_examples = data.get('seed_examples', []) + for example in seed_examples: + qna_list = example.get('questions_and_answers', []) + for qna in qna_list: + question = qna.get('question') + if question: + # Clean up the question if it starts with 'Q: ' or similar prefixes + if question.lower().startswith('q:'): + question = question[2:].strip() + questions.append(question) + return questions + +def query_model(llm, system_prompt, question): + """ + Constructs the prompt and queries the model to get the answer. + + Args: + llm (LLM): The language model instance. 
+ system_prompt (str): The system prompt to set the model's context. + question (str): The question to query. + + Returns: + str: The answer generated by the model. + """ + prompt = f"<|system|>{system_prompt}<|user|>{question}<|assistant|>" + + sampling_params = SamplingParams( + max_tokens=200, + temperature=0, + ) + + response_generator = llm.generate(prompt, sampling_params) + answer = "" + + for response in response_generator: + # Debugging: Print the entire response object + #print("\n--- Debugging Response ---") + #print(response) + #print("--- End of Response ---\n") + + # Check if 'outputs' exist and have at least one CompletionOutput + if hasattr(response, 'outputs') and len(response.outputs) > 0: + completion = response.outputs[0] + if hasattr(completion, 'text'): + answer += completion.text.strip() + else: + print("Debug: 'text' attribute not found in CompletionOutput.") + else: + print("Debug: 'outputs' not found or empty in the response.") + + return answer + +def main(): + # Set up command-line argument parsing + parser = argparse.ArgumentParser(description='Query model with questions from a YAML file.') + parser.add_argument('--model_path', type=str, required=True, help='Path to the language model.') + parser.add_argument('--yaml_file', type=str, required=True, help='Path to the qna.yaml file.') + args = parser.parse_args() + + # Extract questions from the YAML file + questions = extract_questions(args.yaml_file) + + # Initialize the language model + llm = LLM( + model=args.model_path, + # dtype="bfloat16", # Adjust dtype as needed + ) + + # Define the system prompt + system_prompt = ( + "I am a Red Hat® Instruct Model, an AI language model developed by Red Hat and IBM Research " + "based on the granite-3.0-8b-base model. My primary role is to serve as a chat assistant." 
+ ) + + # Iterate over each question, query the model, and print the Q&A + for idx, question in enumerate(questions, 1): + answer = query_model(llm, system_prompt, question) + print(f"Q{idx}: {question}") + print(f"A{idx}: {answer}\n") + +if __name__ == '__main__': + main() diff --git a/api-server/qna-eval/requirements.txt b/api-server/qna-eval/requirements.txt new file mode 100644 index 00000000..a6182302 --- /dev/null +++ b/api-server/qna-eval/requirements.txt @@ -0,0 +1,176 @@ +aiohappyeyeballs==2.4.4 +aiohttp==3.11.11 +aiohttp-cors==0.7.0 +aiosignal==1.3.2 +airportsdata==20241001 +annotated-types==0.7.0 +anyio==4.8.0 +astor==0.8.1 +asttokens==3.0.0 +attrs==24.3.0 +backcall==0.2.0 +beautifulsoup4==4.12.3 +blake3==1.0.1 +bleach==6.2.0 +cachetools==5.5.0 +certifi==2024.12.14 +charset-normalizer==3.4.1 +click==8.1.8 +cloudpickle==3.1.0 +colorful==0.5.6 +compressed-tensors==0.8.1 +decorator==5.1.1 +defusedxml==0.7.1 +depyf==0.18.0 +dill==0.3.9 +diskcache==5.6.3 +distlib==0.3.9 +distro==1.9.0 +docopt==0.6.2 +einops==0.8.0 +executing==2.1.0 +fastapi==0.115.6 +fastjsonschema==2.21.1 +filelock==3.16.1 +frozenlist==1.5.0 +fsspec==2024.12.0 +gguf==0.10.0 +google-api-core==2.24.0 +google-auth==2.37.0 +googleapis-common-protos==1.66.0 +grpcio==1.69.0 +h11==0.14.0 +httpcore==1.0.7 +httptools==0.6.4 +httpx==0.28.1 +huggingface-hub==0.27.1 +idna==3.10 +importlib_metadata==8.5.0 +iniconfig==2.0.0 +interegular==0.3.3 +ipython==8.12.3 +jedi==0.19.2 +Jinja2==3.1.5 +jiter==0.8.2 +jsonschema==4.23.0 +jsonschema-specifications==2024.10.1 +jupyter_client==8.6.3 +jupyter_core==5.7.2 +jupyterlab_pygments==0.3.0 +lark==1.2.2 +linkify-it-py==2.0.3 +lm-format-enforcer==0.10.9 +markdown-it-py==3.0.0 +MarkupSafe==3.0.2 +matplotlib-inline==0.1.7 +mdit-py-plugins==0.4.2 +mdurl==0.1.2 +memray==1.15.0 +mistral_common==1.5.1 +mistune==3.1.0 +mpmath==1.3.0 +msgpack==1.1.0 +msgspec==0.19.0 +multidict==6.1.0 +nbclient==0.10.2 +nbconvert==7.16.5 +nbformat==5.10.4 +nest-asyncio==1.6.0 +networkx==3.4.2 
+numpy==1.26.4 +nvidia-cublas-cu12==12.4.5.8 +nvidia-cuda-cupti-cu12==12.4.127 +nvidia-cuda-nvrtc-cu12==12.4.127 +nvidia-cuda-runtime-cu12==12.4.127 +nvidia-cudnn-cu12==9.1.0.70 +nvidia-cufft-cu12==11.2.1.3 +nvidia-curand-cu12==10.3.5.147 +nvidia-cusolver-cu12==11.6.1.9 +nvidia-cusparse-cu12==12.3.1.170 +nvidia-ml-py==12.560.30 +nvidia-nccl-cu12==2.21.5 +nvidia-nvjitlink-cu12==12.4.127 +nvidia-nvtx-cu12==12.4.127 +openai==1.59.5 +opencensus==0.11.4 +opencensus-context==0.1.3 +opencv-python-headless==4.10.0.84 +outlines==0.1.11 +outlines_core==0.1.26 +packaging==24.2 +pandocfilters==1.5.1 +parso==0.8.4 +partial-json-parser==0.2.1.1.post5 +pexpect==4.9.0 +pickleshare==0.7.5 +pillow==10.4.0 +pipreqs==0.5.0 +platformdirs==4.3.6 +pluggy==1.5.0 +prometheus-fastapi-instrumentator==7.0.0 +prometheus_client==0.21.1 +prompt_toolkit==3.0.48 +propcache==0.2.1 +proto-plus==1.25.0 +protobuf==5.29.3 +psutil==6.1.1 +ptyprocess==0.7.0 +pure_eval==0.2.3 +py-cpuinfo==9.0.0 +py-spy==0.4.0 +pyasn1==0.6.1 +pyasn1_modules==0.4.1 +pybind11==2.13.6 +pycountry==24.6.1 +pydantic==2.10.4 +pydantic_core==2.27.2 +Pygments==2.19.1 +pytest==8.3.4 +python-dateutil==2.9.0.post0 +python-dotenv==1.0.1 +PyYAML==6.0.2 +pyzmq==26.2.0 +ray==2.40.0 +referencing==0.35.1 +regex==2024.11.6 +requests==2.32.3 +rich==13.9.4 +rpds-py==0.22.3 +rsa==4.9 +safetensors==0.5.2 +sentencepiece==0.2.0 +six==1.17.0 +smart-open==7.1.0 +sniffio==1.3.1 +soupsieve==2.6 +stack-data==0.6.3 +starlette==0.41.3 +sympy==1.13.1 +textual==1.0.0 +tiktoken==0.7.0 +tinycss2==1.4.0 +tokenizers==0.21.0 +torch==2.5.1 +torchvision==0.20.1 +tornado==6.4.2 +tqdm==4.67.1 +traitlets==5.14.3 +transformers==4.47.1 +triton==3.1.0 +typing_extensions==4.12.2 +uc-micro-py==1.0.3 +urllib3==2.3.0 +uvicorn==0.34.0 +uvloop==0.21.0 +virtualenv==20.28.1 +vllm==0.6.6.post1 +watchfiles==1.0.3 +wcwidth==0.2.13 +webencodings==0.5.1 +websockets==14.1 +wrapt==1.17.0 +xformers==0.0.28.post3 +xgrammar==0.1.9 +yarg==0.1.9 +yarl==1.18.3 +zipp==3.21.0 diff --git 
a/api-server/utils.go b/api-server/utils.go new file mode 100644 index 00000000..c9a7ae02 --- /dev/null +++ b/api-server/utils.go @@ -0,0 +1,140 @@ +package main + +import ( + "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + "strings" +) + +// findLatestFileWithPrefix scans `dir` for all files whose name starts with `prefix`, +// and returns the path of the latest modified file. +func findLatestFileWithPrefix(dir, prefix string) (string, error) { + files, err := ioutil.ReadDir(dir) + if err != nil { + return "", fmt.Errorf("failed to read directory '%s': %v", dir, err) + } + + var latestFile os.FileInfo + for _, f := range files { + if strings.HasPrefix(f.Name(), prefix) && strings.HasSuffix(f.Name(), ".jsonl") { + if latestFile == nil || f.ModTime().After(latestFile.ModTime()) { + latestFile = f + } + } + } + if latestFile == nil { + return "", fmt.Errorf("no file found matching prefix '%s' in '%s'", prefix, dir) + } + return filepath.Join(dir, latestFile.Name()), nil +} + +// overwriteCopy removes `destPath` if it exists, then copies srcPath -> destPath. 
+func overwriteCopy(srcPath, destPath string) error { + // If the destination file already exists, remove it + if _, err := os.Stat(destPath); err == nil { + if err := os.Remove(destPath); err != nil { + return fmt.Errorf("could not remove existing file '%s': %v", destPath, err) + } + } + + // Open the source + in, err := os.Open(srcPath) + if err != nil { + return fmt.Errorf("could not open source file '%s': %v", srcPath, err) + } + defer in.Close() + + // Create the destination + out, err := os.Create(destPath) + if err != nil { + return fmt.Errorf("could not create dest file '%s': %v", destPath, err) + } + defer out.Close() + + // Copy contents + if _, err := io.Copy(out, in); err != nil { + return fmt.Errorf("failed to copy '%s' to '%s': %v", srcPath, destPath, err) + } + + return nil +} + +// getBaseCacheDir returns the base cache directory path: ~/.cache/instructlab/ +func getBaseCacheDir() (string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("failed to get user home directory: %v", err) + } + return filepath.Join(homeDir, ".cache", "instructlab"), nil +} + +// getFullModelPath converts a user-supplied model name into a fully qualified path: +// +// ~/.cache/instructlab/models/ +func getFullModelPath(modelName string) (string, error) { + baseCacheDir, err := getBaseCacheDir() + if err != nil { + return "", err + } + // If user-supplied name already starts with "models/", don't prepend again + if strings.HasPrefix(modelName, "models/") { + return filepath.Join(baseCacheDir, modelName), nil + } + return filepath.Join(baseCacheDir, "models", modelName), nil +} + +// findLatestDirWithPrefix finds the most recently modified directory within 'dir' that starts with 'prefix'. 
+func (srv *ILabServer) findLatestDirWithPrefix(dir, prefix string) (string, error) { + entries, err := ioutil.ReadDir(dir) + if err != nil { + return "", fmt.Errorf("failed to read directory '%s': %v", dir, err) + } + + var latestDir os.FileInfo + for _, entry := range entries { + if !entry.IsDir() { + continue + } + if strings.HasPrefix(entry.Name(), prefix) { + if latestDir == nil || entry.ModTime().After(latestDir.ModTime()) { + latestDir = entry + } + } + } + + if latestDir == nil { + return "", fmt.Errorf("no directory found in '%s' with prefix '%s'", dir, prefix) + } + + latestPath := filepath.Join(dir, latestDir.Name()) + return latestPath, nil +} + +// getLatestDatasetFile returns the path to the latest dataset file named "knowledge_train_msgs_*.jsonl". +func (srv *ILabServer) getLatestDatasetFile() (string, error) { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("failed to get user home directory: %v", err) + } + datasetDir := filepath.Join(homeDir, ".local", "share", "instructlab", "datasets") + files, err := ioutil.ReadDir(datasetDir) + if err != nil { + return "", fmt.Errorf("failed to read dataset directory: %v", err) + } + + var latestFile os.FileInfo + for _, file := range files { + if strings.HasPrefix(file.Name(), "knowledge_train_msgs_") && strings.HasSuffix(file.Name(), ".jsonl") { + if latestFile == nil || file.ModTime().After(latestFile.ModTime()) { + latestFile = file + } + } + } + if latestFile == nil { + return "", fmt.Errorf("no dataset file found with the prefix 'knowledge_train_msgs_'") + } + return filepath.Join(datasetDir, latestFile.Name()), nil +} diff --git a/api-server/vllm-serve.go b/api-server/vllm-serve.go new file mode 100644 index 00000000..48436d6a --- /dev/null +++ b/api-server/vllm-serve.go @@ -0,0 +1,178 @@ +package main + +import ( + "bytes" + "encoding/json" + "errors" + "fmt" + "os/exec" + "strings" +) + +// VllmContainer details of a vllm container. 
+type VllmContainer struct { + ContainerID string `json:"container_id"` + Image string `json:"image"` + Command string `json:"command"` + CreatedAt string `json:"created_at"` + Status string `json:"status"` + Ports string `json:"ports"` + Names string `json:"names"` + ServedModelName string `json:"served_model_name"` + ModelPath string `json:"model_path"` +} + +// ListVllmContainers retrieves all running vllm containers and extracts the +// --served-model-name and --model values. +func (srv *ILabServer) ListVllmContainers() ([]VllmContainer, error) { + // Define a custom format with a pipe delimiter to avoid splitting on spaces. + format := "{{.ID}}|{{.Image}}|{{.Command}}|{{.CreatedAt}}|{{.Status}}|{{.Ports}}|{{.Names}}" + + cmd := exec.Command("podman", "ps", + "--filter", "ancestor=vllm/vllm-openai:latest", + "--format", format, + ) + + var out, stderr bytes.Buffer + cmd.Stdout = &out + cmd.Stderr = &stderr + + if err := cmd.Run(); err != nil { + return nil, fmt.Errorf("error running podman ps: %v, stderr: %s", err, stderr.String()) + } + + lines := strings.Split(strings.TrimSpace(out.String()), "\n") + var containers []VllmContainer + + for _, line := range lines { + if strings.TrimSpace(line) == "" { + continue + } + parts := strings.Split(line, "|") + if len(parts) != 7 { + srv.log.Warnf("Skipping malformed podman ps line: %s", line) + continue + } + + containerID := strings.TrimSpace(parts[0]) + image := strings.TrimSpace(parts[1]) + command := strings.TrimSpace(parts[2]) + createdAt := strings.TrimSpace(parts[3]) + status := strings.TrimSpace(parts[4]) + ports := strings.TrimSpace(parts[5]) + names := strings.TrimSpace(parts[6]) + + // Inspect the container to get the full command & extract args + servedModelName, modelPath, err := srv.ExtractVllmArgs(containerID) + if err != nil { + srv.log.Warnf("Error extracting vllm args for container %s: %v", containerID, err) + continue + } + + container := VllmContainer{ + ContainerID: containerID, + Image: image, + 
Command: command, + CreatedAt: createdAt, + Status: status, + Ports: ports, + Names: names, + ServedModelName: servedModelName, + ModelPath: modelPath, + } + containers = append(containers, container) + } + + return containers, nil +} + +// ExtractVllmArgs inspects a container and extracts --served-model-name and --model values. +func (srv *ILabServer) ExtractVllmArgs(containerID string) (string, string, error) { + inspectCmd := exec.Command("podman", "inspect", + "--format", "{{json .Config.Cmd}}", + containerID, + ) + + var inspectOut, inspectErr bytes.Buffer + inspectCmd.Stdout = &inspectOut + inspectCmd.Stderr = &inspectErr + + if err := inspectCmd.Run(); err != nil { + return "", "", fmt.Errorf("error inspecting container %s: %v, stderr: %s", + containerID, err, inspectErr.String()) + } + + // The command is a JSON array, e.g. ["--host","0.0.0.0","--port","8000","--model","/path","--served-model-name","pre-train"] + var cmdArgs []string + if err := json.Unmarshal(inspectOut.Bytes(), &cmdArgs); err != nil { + return "", "", fmt.Errorf("error unmarshalling command args for container %s: %v", + containerID, err) + } + + servedModelName, modelPath, err := srv.parseVllmArgs(cmdArgs) + if err != nil { + return "", "", fmt.Errorf("error parsing vllm args for container %s: %v", containerID, err) + } + return servedModelName, modelPath, nil +} + +// parseVllmArgs parses the command-line arguments to extract --served-model-name and --model values. 
+func (srv *ILabServer) parseVllmArgs(args []string) (string, string, error) { + var servedModelName, modelPath string + + for i := 0; i < len(args); i++ { + switch args[i] { + case "--served-model-name": + if i+1 < len(args) { + servedModelName = args[i+1] + i++ + } else { + return "", "", errors.New("missing value for --served-model-name") + } + case "--model": + if i+1 < len(args) { + modelPath = args[i+1] + i++ + } else { + return "", "", errors.New("missing value for --model") + } + } + } + if servedModelName == "" || modelPath == "" { + return "", "", errors.New("required arguments --served-model-name or --model not found") + } + return servedModelName, modelPath, nil +} + +// StopVllmContainer stops a running vllm container based on the served model name. +func (srv *ILabServer) StopVllmContainer(servedModelName string) error { + containers, err := srv.ListVllmContainers() + if err != nil { + return fmt.Errorf("failed to list vllm containers: %v", err) + } + + var targetContainer *VllmContainer + for _, c := range containers { + if c.ServedModelName == servedModelName { + targetContainer = &c + break + } + } + if targetContainer == nil { + return fmt.Errorf("no vllm container found with served-model-name '%s'", servedModelName) + } + + stopCmd := exec.Command("podman", "stop", targetContainer.ContainerID) + var stopOut, stopErr bytes.Buffer + stopCmd.Stdout = &stopOut + stopCmd.Stderr = &stopErr + + if err := stopCmd.Run(); err != nil { + return fmt.Errorf("error stopping container %s: %v, stderr: %s", + targetContainer.ContainerID, err, stopErr.String()) + } + + srv.log.Infof("Successfully stopped vllm container '%s' with served-model-name '%s'", + targetContainer.ContainerID, servedModelName) + return nil +} diff --git a/api-server/zap.go b/api-server/zap.go new file mode 100644 index 00000000..94b22132 --- /dev/null +++ b/api-server/zap.go @@ -0,0 +1,34 @@ +package main + +import ( + "fmt" + + "go.uber.org/zap" +) + +// 
----------------------------------------------------------------------------- +// Logger Initialization +// ----------------------------------------------------------------------------- +func (srv *ILabServer) initLogger(debug bool) { + var cfg zap.Config + + if debug { + cfg = zap.NewDevelopmentConfig() + cfg.Level = zap.NewAtomicLevelAt(zap.DebugLevel) + } else { + cfg = zap.NewProductionConfig() + cfg.Level = zap.NewAtomicLevelAt(zap.InfoLevel) + } + + logger, err := cfg.Build() + if err != nil { + panic(fmt.Sprintf("Failed to initialize zap logger: %v", err)) + } + + srv.logger = logger + srv.log = logger.Sugar() + + if debug { + srv.log.Debug("Debug logging is ENABLED.") + } +} From 887dab8ba5811ed6001077f320342eeed6133037 Mon Sep 17 00:00:00 2001 From: Brent Salisbury Date: Thu, 23 Jan 2025 01:29:44 -0500 Subject: [PATCH 2/2] Remove $HOME hardcodes and added a build/fmt CI Signed-off-by: Brent Salisbury --- .github/workflows/api-server.yml | 53 ++ api-server/go.mod | 6 +- api-server/go.sum | 16 + api-server/handlers.go | 8 +- api-server/main.go | 1459 +++++++++++++++--------------- api-server/utils.go | 218 ++--- api-server/zap.go | 2 +- 7 files changed, 917 insertions(+), 845 deletions(-) create mode 100644 .github/workflows/api-server.yml diff --git a/.github/workflows/api-server.yml b/.github/workflows/api-server.yml new file mode 100644 index 00000000..fc476c18 --- /dev/null +++ b/.github/workflows/api-server.yml @@ -0,0 +1,53 @@ +name: api-server + +on: + push: + branches: + - main + - release-1.0 + pull_request: + branches: + - main + - release-1.0 + +jobs: + fmt-build-test: + runs-on: ubuntu-latest + + env: + CGO_ENABLED: 1 + + defaults: + run: + working-directory: api-server + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Go + uses: actions/setup-go@v4 + with: + go-version: '1.21.6' + + - name: Install Build Dependencies + run: | + sudo apt-get update + sudo apt-get install -y build-essential pkg-config + + - 
name: Install Go Dependencies + run: | + go mod download + + - name: Go Format + run: | + unformatted=$(gofmt -l .) + if [ -n "$unformatted" ]; then + echo "The following files are not formatted properly:" + echo "$unformatted" + exit 1 + fi + + - name: Build + run: | + go build ./... diff --git a/api-server/go.mod b/api-server/go.mod index 91cb5132..a78b68eb 100644 --- a/api-server/go.mod +++ b/api-server/go.mod @@ -1,16 +1,16 @@ -module ilab-api-router +module github.com/instructlab/ui/api-server -go 1.21.6 +go 1.22.1 require ( github.com/gorilla/mux v1.8.1 github.com/mattn/go-sqlite3 v1.14.24 github.com/spf13/cobra v1.8.1 + go.uber.org/zap v1.27.0 ) require ( github.com/inconshreveable/mousetrap v1.1.0 // indirect github.com/spf13/pflag v1.0.5 // indirect go.uber.org/multierr v1.10.0 // indirect - go.uber.org/zap v1.27.0 // indirect ) diff --git a/api-server/go.sum b/api-server/go.sum index 8246dd0d..27a58e0f 100644 --- a/api-server/go.sum +++ b/api-server/go.sum @@ -1,11 +1,27 @@ github.com/cpuguy83/go-md2man/v2 v2.0.4/go.mod h1:tgQtvFlXSQOSOSIRvRPT7W67SCa46tRHOmNcaadrF8o= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/gorilla/mux v1.8.1 h1:TuBL49tXwgrFYWhqrNgrUNEY92u81SPhu7sTdzQEiWY= github.com/gorilla/mux v1.8.1/go.mod h1:AKf9I4AEqPTmMytcMc0KkNouC66V3BtZ4qD5fmWSiMQ= +github.com/inconshreveable/mousetrap v1.1.0 h1:wN+x4NVGpMsO7ErUn/mUI3vEoE6Jt13X2s0bqwp9tc8= github.com/inconshreveable/mousetrap v1.1.0/go.mod h1:vpF70FUmC8bwa3OWnCshd2FqLfsEA9PFc4w1p2J65bw= +github.com/mattn/go-sqlite3 v1.14.24 h1:tpSp2G2KyMnnQu99ngJ47EIkWVmliIizyZBfPrBWDRM= github.com/mattn/go-sqlite3 v1.14.24/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM= +github.com/spf13/cobra v1.8.1 h1:e5/vxKd/rZsfSJMUX1agtjeTDf+qv1/JdBF8gg5k9ZM= github.com/spf13/cobra v1.8.1/go.mod h1:wHxEcudfqmLYa8iTfL+OuZPbBZkmvliBWKIezN3kD9Y= +github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA= github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg= +github.com/stretchr/testify v1.8.1 h1:w7B6lhMri9wdJUVmEZPGGhZzrYTPvgJArz7wNPgYKsk= +github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o6fzry7u4= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= +go.uber.org/multierr v1.10.0 h1:S0h4aNzvfcFsC3dRF1jLoaov7oRaKqRGC/pUEJ2yvPQ= go.uber.org/multierr v1.10.0/go.mod h1:20+QtiLqy0Nd6FdQB9TLXag12DsQkrbs3htMFfDN80Y= +go.uber.org/zap v1.27.0 h1:aJMhYGrd5QSmlpLMr2MftRKl7t8J8PTZPA732ud/XR8= go.uber.org/zap v1.27.0/go.mod h1:GB2qFLM7cTU87MWRP2mPIjqfIDnGu+VIO4V/SdhGo2E= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= diff --git a/api-server/handlers.go b/api-server/handlers.go index 2ad30d05..147669ce 100644 --- a/api-server/handlers.go +++ b/api-server/handlers.go @@ -581,8 +581,8 @@ func (srv *ILabServer) serveLatestCheckpointHandler(w http.ResponseWriter, r *ht "8001", "post-train", 1, - "/var/home/cloud-user", - "/var/home/cloud-user", + srv.homeDir, + srv.homeDir, w, ) } else { @@ -612,8 +612,8 @@ func (srv *ILabServer) serveBaseModelHandler(w http.ResponseWriter, r *http.Requ "8000", "pre-train", 0, - "/var/home/cloud-user", - "/var/home/cloud-user", + srv.homeDir, + srv.homeDir, w, ) } else { diff --git a/api-server/main.go b/api-server/main.go index 
094cbd9c..f6420aef 100644 --- a/api-server/main.go +++ b/api-server/main.go @@ -1,22 +1,22 @@ package main import ( - "database/sql" - "encoding/json" - "fmt" - "io/ioutil" - "net/http" - "os" - "os/exec" - "path/filepath" - "strings" - "sync" - "time" - - "github.com/gorilla/mux" - _ "github.com/mattn/go-sqlite3" - "github.com/spf13/cobra" - "go.uber.org/zap" + "database/sql" + "encoding/json" + "fmt" + "io/ioutil" + "net/http" + "os" + "os/exec" + "path/filepath" + "strings" + "sync" + "time" + + "github.com/gorilla/mux" + _ "github.com/mattn/go-sqlite3" + "github.com/spf13/cobra" + "go.uber.org/zap" ) // ----------------------------------------------------------------------------- @@ -25,59 +25,59 @@ import ( // Model represents a model record (from 'ilab model list'). type Model struct { - Name string `json:"name"` - LastModified string `json:"last_modified"` - Size string `json:"size"` + Name string `json:"name"` + LastModified string `json:"last_modified"` + Size string `json:"size"` } // Data represents a data record (from 'ilab data list'). type Data struct { - Dataset string `json:"dataset"` - CreatedAt string `json:"created_at"` - FileSize string `json:"file_size"` + Dataset string `json:"dataset"` + CreatedAt string `json:"created_at"` + FileSize string `json:"file_size"` } // Job represents a background job, including train/generate/pipeline/vllm-run jobs. type Job struct { - JobID string `json:"job_id"` - Cmd string `json:"cmd"` - Args []string `json:"args"` - Status string `json:"status"` // "running", "finished", "failed" - PID int `json:"pid"` - LogFile string `json:"log_file"` - StartTime time.Time `json:"start_time"` - EndTime *time.Time `json:"end_time,omitempty"` - Branch string `json:"branch"` - - // Lock is not serialized; it protects updates to the Job in memory. 
- Lock sync.Mutex `json:"-"` + JobID string `json:"job_id"` + Cmd string `json:"cmd"` + Args []string `json:"args"` + Status string `json:"status"` // "running", "finished", "failed" + PID int `json:"pid"` + LogFile string `json:"log_file"` + StartTime time.Time `json:"start_time"` + EndTime *time.Time `json:"end_time,omitempty"` + Branch string `json:"branch"` + + // Lock is not serialized; it protects updates to the Job in memory. + Lock sync.Mutex `json:"-"` } // ModelCache encapsulates the cached models and related metadata. type ModelCache struct { - Models []Model - Time time.Time - Mutex sync.Mutex + Models []Model + Time time.Time + Mutex sync.Mutex } // QnaEvalRequest is used by the /qna-eval endpoint. type QnaEvalRequest struct { - ModelPath string `json:"model_path"` - YamlFile string `json:"yaml_file"` + ModelPath string `json:"model_path"` + YamlFile string `json:"yaml_file"` } // VllmContainerResponse is returned by the /vllm-containers endpoint. type VllmContainerResponse struct { - Containers []VllmContainer `json:"containers"` + Containers []VllmContainer `json:"containers"` } type ServeModelRequest struct { - Checkpoint string `json:"checkpoint,omitempty"` // Optional: Name of the checkpoint directory (e.g., "samples_12345") + Checkpoint string `json:"checkpoint,omitempty"` // Optional: Name of the checkpoint directory (e.g., "samples_12345") } // UnloadModelRequest is used by the /vllm-unload endpoint. 
type UnloadModelRequest struct { - ModelName string `json:"model_name"` // Expected: "pre-train" or "post-train" + ModelName string `json:"model_name"` // Expected: "pre-train" or "post-train" } // ----------------------------------------------------------------------------- @@ -85,37 +85,38 @@ type UnloadModelRequest struct { // ----------------------------------------------------------------------------- type ILabServer struct { - baseDir string - taxonomyPath string - rhelai bool - ilabCmd string - isOSX bool - isCuda bool - useVllm bool - pipelineType string - debugEnabled bool - - // Logger - logger *zap.Logger - log *zap.SugaredLogger - - // Database handle - db *sql.DB - - // Model processes for CPU-based or local serving (if not using VLLM) - modelLock sync.Mutex - modelProcessBase *exec.Cmd - modelProcessLatest *exec.Cmd - - // Base model reference (used in some logic but not necessary if VLLM is used with default paths) - baseModel string - - // Map of "pre-train"/"post-train" => jobID for VLLM serving - servedModelJobIDs map[string]string - jobIDsMutex sync.RWMutex - - // Cache variables - modelCache ModelCache + baseDir string + taxonomyPath string + rhelai bool + ilabCmd string + isOSX bool + isCuda bool + useVllm bool + pipelineType string + debugEnabled bool + homeDir string // New field added + + // Logger + logger *zap.Logger + log *zap.SugaredLogger + + // Database handle + db *sql.DB + + // Model processes for CPU-based or local serving (if not using VLLM) + modelLock sync.Mutex + modelProcessBase *exec.Cmd + modelProcessLatest *exec.Cmd + + // Base model reference + baseModel string + + // Map of "pre-train"/"post-train" => jobID for VLLM serving + servedModelJobIDs map[string]string + jobIDsMutex sync.RWMutex + + // Cache variables + modelCache ModelCache } // ----------------------------------------------------------------------------- @@ -123,158 +124,160 @@ type ILabServer struct { // 
----------------------------------------------------------------------------- func main() { - // We create an instance of ILabServer to hold all state and methods. - srv := &ILabServer{ - baseModel: "instructlab/granite-7b-lab", - servedModelJobIDs: make(map[string]string), - modelCache: ModelCache{}, - } - - rootCmd := &cobra.Command{ - Use: "ilab-server", - Short: "ILab Server Application", - Run: func(cmd *cobra.Command, args []string) { - // Now that flags are set, run the server method on the struct. - srv.runServer(cmd, args) - }, - } - - // Define flags - rootCmd.Flags().BoolVar(&srv.rhelai, "rhelai", false, "Use ilab binary from PATH instead of Python virtual environment") - rootCmd.Flags().StringVar(&srv.baseDir, "base-dir", "", "Base directory for ilab operations (required if --rhelai is not set)") - rootCmd.Flags().StringVar(&srv.taxonomyPath, "taxonomy-path", "", "Path to the taxonomy repository for Git operations (required)") - rootCmd.Flags().BoolVar(&srv.isOSX, "osx", false, "Enable OSX-specific settings (default: false)") - rootCmd.Flags().BoolVar(&srv.isCuda, "cuda", false, "Enable Cuda (default: false)") - rootCmd.Flags().BoolVar(&srv.useVllm, "vllm", false, "Enable VLLM model serving using podman containers") - rootCmd.Flags().StringVar(&srv.pipelineType, "pipeline", "", "Pipeline type (simple, accelerated, full)") - rootCmd.Flags().BoolVar(&srv.debugEnabled, "debug", false, "Enable debug logging") - - // PreRun to validate flags - rootCmd.PreRunE = func(cmd *cobra.Command, args []string) error { - if !srv.rhelai && srv.baseDir == "" { - return fmt.Errorf("--base-dir is required unless --rhelai is set") - } - if srv.taxonomyPath == "" { - return fmt.Errorf("--taxonomy-path is required") - } - - // Validate or set pipelineType based on --rhelai - if !srv.rhelai { - if srv.pipelineType == "" { - return fmt.Errorf("--pipeline is required unless --rhelai is set") - } - switch srv.pipelineType { - case "simple", "full", "accelerated": - // Valid - 
default: - return fmt.Errorf("--pipeline must be 'simple', 'accelerated' or 'full'; got '%s'", srv.pipelineType) - } - } else { - // When --rhelai is set and --pipeline is not provided, set a default - if srv.pipelineType == "" { - srv.pipelineType = "accelerated" - fmt.Println("--rhelai is set; defaulting --pipeline to 'accelerated'") - } else { - switch srv.pipelineType { - case "simple", "full", "accelerated": - // Valid - default: - return fmt.Errorf("--pipeline must be 'simple', 'accelerated' or 'full'; got '%s'", srv.pipelineType) - } - } - } - return nil - } - - if err := rootCmd.Execute(); err != nil { - fmt.Printf("Error executing command: %v\n", err) - os.Exit(1) - } + // We create an instance of ILabServer to hold all state and methods. + srv := &ILabServer{ + baseModel: "instructlab/granite-7b-lab", + servedModelJobIDs: make(map[string]string), + modelCache: ModelCache{}, + } + + rootCmd := &cobra.Command{ + Use: "ilab-server", + Short: "ILab Server Application", + Run: func(cmd *cobra.Command, args []string) { + // Now that flags are set, run the server method on the struct. 
+ srv.runServer(cmd, args) + }, + } + + // Define flags + rootCmd.Flags().BoolVar(&srv.rhelai, "rhelai", false, "Use ilab binary from PATH instead of Python virtual environment") + rootCmd.Flags().StringVar(&srv.baseDir, "base-dir", "", "Base directory for ilab operations (required if --rhelai is not set)") + rootCmd.Flags().StringVar(&srv.taxonomyPath, "taxonomy-path", "", "Path to the taxonomy repository for Git operations (required)") + rootCmd.Flags().BoolVar(&srv.isOSX, "osx", false, "Enable OSX-specific settings (default: false)") + rootCmd.Flags().BoolVar(&srv.isCuda, "cuda", false, "Enable Cuda (default: false)") + rootCmd.Flags().BoolVar(&srv.useVllm, "vllm", false, "Enable VLLM model serving using podman containers") + rootCmd.Flags().StringVar(&srv.pipelineType, "pipeline", "", "Pipeline type (simple, accelerated, full)") + rootCmd.Flags().BoolVar(&srv.debugEnabled, "debug", false, "Enable debug logging") + + // PreRun to validate flags + rootCmd.PreRunE = func(cmd *cobra.Command, args []string) error { + if !srv.rhelai && srv.baseDir == "" { + return fmt.Errorf("--base-dir is required unless --rhelai is set") + } + if srv.taxonomyPath == "" { + return fmt.Errorf("--taxonomy-path is required") + } + + // Validate or set pipelineType based on --rhelai + if !srv.rhelai { + if srv.pipelineType == "" { + return fmt.Errorf("--pipeline is required unless --rhelai is set") + } + switch srv.pipelineType { + case "simple", "full", "accelerated": + // Valid + default: + return fmt.Errorf("--pipeline must be 'simple', 'accelerated' or 'full'; got '%s'", srv.pipelineType) + } + } else { + // When --rhelai is set and --pipeline is not provided, set a default + if srv.pipelineType == "" { + srv.pipelineType = "accelerated" + fmt.Println("--rhelai is set; defaulting --pipeline to 'accelerated'") + } else { + switch srv.pipelineType { + case "simple", "full", "accelerated": + // Valid + default: + return fmt.Errorf("--pipeline must be 'simple', 'accelerated' or 'full'; 
got '%s'", srv.pipelineType) + } + } + } + return nil + } + + if err := rootCmd.Execute(); err != nil { + fmt.Printf("Error executing command: %v\n", err) + os.Exit(1) + } } // runServer is the main entry method after flags are parsed. func (srv *ILabServer) runServer(cmd *cobra.Command, args []string) { - // Initialize zap logger - srv.initLogger(srv.debugEnabled) - - if srv.debugEnabled { - srv.log.Info("Debug logging is ENABLED.") - } else { - srv.log.Info("Debug logging is DISABLED.") - } - - // Initialize the database - srv.initDB() - - // Determine ilab command path - if srv.rhelai { - // Use ilab from PATH - ilabPath, err := exec.LookPath("ilab") - if err != nil { - srv.log.Fatalf("ilab binary not found in PATH. Please ensure ilab is installed and in your PATH.") - } - srv.ilabCmd = ilabPath - } else { - // Use ilab from virtual environment - srv.ilabCmd = filepath.Join(srv.baseDir, "venv", "bin", "ilab") - if _, err := os.Stat(srv.ilabCmd); os.IsNotExist(err) { - srv.log.Fatalf("ilab binary not found at %s. 
Please ensure the virtual environment is set up correctly.", srv.ilabCmd) - } - } - - srv.log.Infof("Using ilab command: %s", srv.ilabCmd) - - // Validate mandatory arguments if not using rhelai - if !srv.rhelai { - if _, err := os.Stat(srv.baseDir); os.IsNotExist(err) { - srv.log.Fatalf("Base directory does not exist: %s", srv.baseDir) - } - } - - if _, err := os.Stat(srv.taxonomyPath); os.IsNotExist(err) { - srv.log.Fatalf("Taxonomy path does not exist: %s", srv.taxonomyPath) - } - - srv.log.Infof("Running with baseDir=%s, taxonomyPath=%s, isOSX=%v, isCuda=%v, useVllm=%v, pipeline=%s", - srv.baseDir, srv.taxonomyPath, srv.isOSX, srv.isCuda, srv.useVllm, srv.pipelineType) - srv.log.Infof("Current working directory: %s", srv.mustGetCwd()) - - // Check statuses of any jobs that might have been running before a restart - srv.checkRunningJobs() - - // Initialize the model cache - srv.initializeModelCache() - - // Create the logs directory if it doesn't exist - err := os.MkdirAll("logs", os.ModePerm) - if err != nil { - srv.log.Fatalf("Failed to create logs directory: %v", err) - } - - // Setup HTTP routes - r := mux.NewRouter() - r.HandleFunc("/models", srv.getModelsHandler).Methods("GET") - r.HandleFunc("/data", srv.getDataHandler).Methods("GET") - r.HandleFunc("/data/generate", srv.generateDataHandler).Methods("POST") - r.HandleFunc("/model/train", srv.trainModelHandler).Methods("POST") - r.HandleFunc("/jobs/{job_id}/status", srv.getJobStatusHandler).Methods("GET") - r.HandleFunc("/jobs/{job_id}/logs", srv.getJobLogsHandler).Methods("GET") - r.HandleFunc("/jobs", srv.listJobsHandler).Methods("GET") - r.HandleFunc("/pipeline/generate-train", srv.generateTrainPipelineHandler).Methods("POST") - r.HandleFunc("/model/serve-latest", srv.serveLatestCheckpointHandler).Methods("POST") - r.HandleFunc("/model/serve-base", srv.serveBaseModelHandler).Methods("POST") - r.HandleFunc("/qna-eval", srv.runQnaEval).Methods("POST") - r.HandleFunc("/checkpoints", 
srv.listCheckpointsHandler).Methods("GET") - r.HandleFunc("/vllm-containers", srv.listVllmContainersHandler).Methods("GET") - r.HandleFunc("/vllm-unload", srv.unloadVllmContainerHandler).Methods("POST") - r.HandleFunc("/vllm-status", srv.getVllmStatusHandler).Methods("GET") - r.HandleFunc("/gpu-free", srv.getGpuFreeHandler).Methods("GET") - r.HandleFunc("/served-model-jobids", srv.listServedModelJobIDsHandler).Methods("GET") - - srv.log.Info("Server starting on port 8080... (Taxonomy path: ", srv.taxonomyPath, ")") - if err := http.ListenAndServe("0.0.0.0:8080", r); err != nil { - srv.log.Fatalf("Server failed to start: %v", err) - } + // Initialize zap logger + srv.initLogger(srv.debugEnabled) + + // Initialize the database + srv.initDB() + + // Determine the user's home directory / TODO: alternative approch here for expected path? + homeDir, err := os.UserHomeDir() + if err != nil { + srv.log.Fatalf("Failed to get user home directory: %v", err) + } + srv.homeDir = homeDir + srv.log.Infof("User home directory set to: %s", srv.homeDir) + + // Determine ilab command path + if srv.rhelai { + // Use ilab from PATH + ilabPath, err := exec.LookPath("ilab") + if err != nil { + srv.log.Fatalf("ilab binary not found in PATH. Please ensure ilab is installed and in your PATH.") + } + srv.ilabCmd = ilabPath + } else { + // Use ilab from virtual environment + srv.ilabCmd = filepath.Join(srv.baseDir, "venv", "bin", "ilab") + if _, err := os.Stat(srv.ilabCmd); os.IsNotExist(err) { + srv.log.Fatalf("ilab binary not found at %s. 
Please ensure the virtual environment is set up correctly.", srv.ilabCmd) + } + } + + srv.log.Infof("Using ilab command: %s", srv.ilabCmd) + + // Validate mandatory arguments if not using rhelai + if !srv.rhelai { + if _, err := os.Stat(srv.baseDir); os.IsNotExist(err) { + srv.log.Fatalf("Base directory does not exist: %s", srv.baseDir) + } + } + + if _, err := os.Stat(srv.taxonomyPath); os.IsNotExist(err) { + srv.log.Fatalf("Taxonomy path does not exist: %s", srv.taxonomyPath) + } + + srv.log.Infof("Running with baseDir=%s, taxonomyPath=%s, isOSX=%v, isCuda=%v, useVllm=%v, pipeline=%s", + srv.baseDir, srv.taxonomyPath, srv.isOSX, srv.isCuda, srv.useVllm, srv.pipelineType) + srv.log.Infof("Current working directory: %s", srv.mustGetCwd()) + + // Check statuses of any jobs that might have been running before a restart + srv.checkRunningJobs() + + // Initialize the model cache + srv.initializeModelCache() + + // Create the logs directory if it doesn't exist + err = os.MkdirAll("logs", os.ModePerm) + if err != nil { + srv.log.Fatalf("Failed to create logs directory: %v", err) + } + + // Setup HTTP routes + r := mux.NewRouter() + r.HandleFunc("/models", srv.getModelsHandler).Methods("GET") + r.HandleFunc("/data", srv.getDataHandler).Methods("GET") + r.HandleFunc("/data/generate", srv.generateDataHandler).Methods("POST") + r.HandleFunc("/model/train", srv.trainModelHandler).Methods("POST") + r.HandleFunc("/jobs/{job_id}/status", srv.getJobStatusHandler).Methods("GET") + r.HandleFunc("/jobs/{job_id}/logs", srv.getJobLogsHandler).Methods("GET") + r.HandleFunc("/jobs", srv.listJobsHandler).Methods("GET") + r.HandleFunc("/pipeline/generate-train", srv.generateTrainPipelineHandler).Methods("POST") + r.HandleFunc("/model/serve-latest", srv.serveLatestCheckpointHandler).Methods("POST") + r.HandleFunc("/model/serve-base", srv.serveBaseModelHandler).Methods("POST") + r.HandleFunc("/qna-eval", srv.runQnaEval).Methods("POST") + r.HandleFunc("/checkpoints", 
srv.listCheckpointsHandler).Methods("GET") + r.HandleFunc("/vllm-containers", srv.listVllmContainersHandler).Methods("GET") + r.HandleFunc("/vllm-unload", srv.unloadVllmContainerHandler).Methods("POST") + r.HandleFunc("/vllm-status", srv.getVllmStatusHandler).Methods("GET") + r.HandleFunc("/gpu-free", srv.getGpuFreeHandler).Methods("GET") + r.HandleFunc("/served-model-jobids", srv.listServedModelJobIDsHandler).Methods("GET") + + srv.log.Info("Server starting on port 8080... (Taxonomy path: ", srv.taxonomyPath, ")") + if err := http.ListenAndServe("0.0.0.0:8080", r); err != nil { + srv.log.Fatalf("Server failed to start: %v", err) + } } // ----------------------------------------------------------------------------- @@ -283,24 +286,24 @@ func (srv *ILabServer) runServer(cmd *cobra.Command, args []string) { // getIlabCommand returns the path to the ilab command, depending on rhelai or local venv. func (srv *ILabServer) getIlabCommand() string { - return srv.ilabCmd + return srv.ilabCmd } // mustGetCwd returns the current working directory or "unknown" if it fails. func (srv *ILabServer) mustGetCwd() string { - cwd, err := os.Getwd() - if err != nil { - return "unknown" - } - return cwd + cwd, err := os.Getwd() + if err != nil { + return "unknown" + } + return cwd } // sanitizeModelName checks if the modelName starts with "model/" and replaces it with "models/". func (srv *ILabServer) sanitizeModelName(modelName string) string { - if strings.HasPrefix(modelName, "model/") { - return strings.Replace(modelName, "model/", "models/", 1) - } - return modelName + if strings.HasPrefix(modelName, "model/") { + return strings.Replace(modelName, "model/", "models/", 1) + } + return modelName } // ----------------------------------------------------------------------------- @@ -309,40 +312,40 @@ func (srv *ILabServer) sanitizeModelName(modelName string) string { // initializeModelCache refreshes the model cache once and then schedules a refresh every 20 minutes. 
func (srv *ILabServer) initializeModelCache() { - srv.refreshModelCache() - go func() { - for { - time.Sleep(20 * time.Minute) - srv.refreshModelCache() - } - }() + srv.refreshModelCache() + go func() { + for { + time.Sleep(20 * time.Minute) + srv.refreshModelCache() + } + }() } // refreshModelCache updates the model cache if it's older than 20 minutes or if empty. // TODO: this is really slow due to a caching issue upstream/downstream, should probably be async func (srv *ILabServer) refreshModelCache() { - srv.modelCache.Mutex.Lock() - defer srv.modelCache.Mutex.Unlock() - - if time.Since(srv.modelCache.Time) < 20*time.Minute && len(srv.modelCache.Models) > 0 { - srv.log.Info("Model cache is still valid; no refresh needed.") - return - } - - srv.log.Info("Refreshing model cache... Takes 10-20s") - output, err := srv.runIlabCommand("model", "list") - if err != nil { - srv.log.Errorf("Error refreshing model cache: %v", err) - return - } - models, err := srv.parseModelList(output) - if err != nil { - srv.log.Errorf("Error parsing model list during cache refresh: %v", err) - return - } - srv.modelCache.Models = models - srv.modelCache.Time = time.Now() - srv.log.Infof("Model cache refreshed at %v with %d models.", srv.modelCache.Time, len(models)) + srv.modelCache.Mutex.Lock() + defer srv.modelCache.Mutex.Unlock() + + if time.Since(srv.modelCache.Time) < 20*time.Minute && len(srv.modelCache.Models) > 0 { + srv.log.Info("Model cache is still valid; no refresh needed.") + return + } + + srv.log.Info("Refreshing model cache... 
Takes 10-20s") + output, err := srv.runIlabCommand("model", "list") + if err != nil { + srv.log.Errorf("Error refreshing model cache: %v", err) + return + } + models, err := srv.parseModelList(output) + if err != nil { + srv.log.Errorf("Error parsing model list during cache refresh: %v", err) + return + } + srv.modelCache.Models = models + srv.modelCache.Time = time.Now() + srv.log.Infof("Model cache refreshed at %v with %d models.", srv.modelCache.Time, len(models)) } // ----------------------------------------------------------------------------- @@ -351,74 +354,74 @@ func (srv *ILabServer) refreshModelCache() { // startGenerateJob launches a job to run "ilab data generate" and tracks it. func (srv *ILabServer) startGenerateJob() (string, error) { - ilabPath := srv.getIlabCommand() - - // Hard-coded pipeline choice for data generate, or we could use srv.pipelineType - cmdArgs := []string{"data", "generate", "--pipeline", "full"} - - cmd := exec.Command(ilabPath, cmdArgs...) - if !srv.rhelai { - cmd.Dir = srv.baseDir - } - - jobID := fmt.Sprintf("g-%d", time.Now().UnixNano()) - logFilePath := filepath.Join("logs", fmt.Sprintf("%s.log", jobID)) - srv.log.Infof("Starting generateDataHandler job: %s, logs: %s", jobID, logFilePath) - - logFile, err := os.Create(logFilePath) - if err != nil { - srv.log.Errorf("Error creating log file: %v", err) - return "", fmt.Errorf("Failed to create log file") - } - cmd.Stdout = logFile - cmd.Stderr = logFile - - srv.log.Infof("Running command: %s %v", ilabPath, cmdArgs) - if err := cmd.Start(); err != nil { - srv.log.Errorf("Error starting data generation command: %v", err) - logFile.Close() - return "", err - } - - newJob := &Job{ - JobID: jobID, - Cmd: ilabPath, - Args: cmdArgs, - Status: "running", - PID: cmd.Process.Pid, - LogFile: logFilePath, - StartTime: time.Now(), - } - if err := srv.createJob(newJob); err != nil { - srv.log.Errorf("Error creating job in DB: %v", err) - return "", err - } - - go func() { - defer 
logFile.Close() - err := cmd.Wait() - - newJob.Lock.Lock() - defer newJob.Lock.Unlock() - - if err != nil { - newJob.Status = "failed" - srv.log.Infof("Job %s failed with error: %v", newJob.JobID, err) - } else { - if cmd.ProcessState.Success() { - newJob.Status = "finished" - srv.log.Infof("Job %s finished successfully", newJob.JobID) - } else { - newJob.Status = "failed" - srv.log.Infof("Job %s failed", newJob.JobID) - } - } - now := time.Now() - newJob.EndTime = &now - _ = srv.updateJob(newJob) - }() - - return jobID, nil + ilabPath := srv.getIlabCommand() + + // Hard-coded pipeline choice for data generate, or we could use srv.pipelineType + cmdArgs := []string{"data", "generate", "--pipeline", "full"} + + cmd := exec.Command(ilabPath, cmdArgs...) + if !srv.rhelai { + cmd.Dir = srv.baseDir + } + + jobID := fmt.Sprintf("g-%d", time.Now().UnixNano()) + logFilePath := filepath.Join("logs", fmt.Sprintf("%s.log", jobID)) + srv.log.Infof("Starting generateDataHandler job: %s, logs: %s", jobID, logFilePath) + + logFile, err := os.Create(logFilePath) + if err != nil { + srv.log.Errorf("Error creating log file: %v", err) + return "", fmt.Errorf("Failed to create log file") + } + cmd.Stdout = logFile + cmd.Stderr = logFile + + srv.log.Infof("Running command: %s %v", ilabPath, cmdArgs) + if err := cmd.Start(); err != nil { + srv.log.Errorf("Error starting data generation command: %v", err) + logFile.Close() + return "", err + } + + newJob := &Job{ + JobID: jobID, + Cmd: ilabPath, + Args: cmdArgs, + Status: "running", + PID: cmd.Process.Pid, + LogFile: logFilePath, + StartTime: time.Now(), + } + if err := srv.createJob(newJob); err != nil { + srv.log.Errorf("Error creating job in DB: %v", err) + return "", err + } + + go func() { + defer logFile.Close() + err := cmd.Wait() + + newJob.Lock.Lock() + defer newJob.Lock.Unlock() + + if err != nil { + newJob.Status = "failed" + srv.log.Infof("Job %s failed with error: %v", newJob.JobID, err) + } else { + if 
cmd.ProcessState.Success() { + newJob.Status = "finished" + srv.log.Infof("Job %s finished successfully", newJob.JobID) + } else { + newJob.Status = "failed" + srv.log.Infof("Job %s failed", newJob.JobID) + } + } + now := time.Now() + newJob.EndTime = &now + _ = srv.updateJob(newJob) + }() + + return jobID, nil } // ----------------------------------------------------------------------------- @@ -427,177 +430,177 @@ func (srv *ILabServer) startGenerateJob() (string, error) { // startTrainJob starts a training job with the given parameters. func (srv *ILabServer) startTrainJob(modelName, branchName string, epochs *int) (string, error) { - srv.log.Infof("Starting training job for model: '%s', branch: '%s'", modelName, branchName) - - jobID := fmt.Sprintf("t-%d", time.Now().UnixNano()) - logFilePath := filepath.Join("logs", fmt.Sprintf("%s.log", jobID)) - - fullModelPath, err := srv.getFullModelPath(modelName) - if err != nil { - return "", fmt.Errorf("failed to get full model path: %v", err) - } - srv.log.Infof("Resolved fullModelPath: '%s'", fullModelPath) - - modelDir := filepath.Dir(fullModelPath) - if err := os.MkdirAll(modelDir, os.ModePerm); err != nil { - return "", fmt.Errorf("failed to create model directory '%s': %v", modelDir, err) - } - - ilabPath := srv.getIlabCommand() - - var cmdArgs []string - cmdArgs = append(cmdArgs, "model", "train") - - // If not rhelai, add pipeline if set - if !srv.rhelai && srv.pipelineType != "" { - cmdArgs = append(cmdArgs, "--pipeline", srv.pipelineType) - } - cmdArgs = append(cmdArgs, fmt.Sprintf("--model-path=%s", fullModelPath)) - - if srv.isOSX { - cmdArgs = append(cmdArgs, "--device=mps") - } - if srv.isCuda { - cmdArgs = append(cmdArgs, "--device=cuda") - } - if epochs != nil { - cmdArgs = append(cmdArgs, fmt.Sprintf("--num-epochs=%d", *epochs)) - srv.log.Infof("Number of epochs specified: %d", *epochs) - } else { - srv.log.Info("No epochs specified; using default number of epochs.") - } - - // Additional logic if 
pipelineType == "simple" (and not rhelai) - if srv.pipelineType == "simple" && !srv.rhelai { - homeDir, err := os.UserHomeDir() - if err != nil { - return "", fmt.Errorf("failed to get user home directory: %v", err) - } - datasetDir := filepath.Join(homeDir, ".local", "share", "instructlab", "datasets") - - // Copy the latest knowledge_train_msgs_*.jsonl => train_gen.jsonl - latestTrainFile, err := srv.findLatestFileWithPrefix(datasetDir, "knowledge_train_msgs_") - if err != nil { - return "", fmt.Errorf("failed to find knowledge_train_msgs_*.jsonl file: %v", err) - } - trainGenPath := filepath.Join(datasetDir, "train_gen.jsonl") - if err := srv.overwriteCopy(latestTrainFile, trainGenPath); err != nil { - return "", fmt.Errorf("failed to copy %s to %s: %v", latestTrainFile, trainGenPath, err) - } - - // Copy the latest test_ggml-model-*.jsonl => test_gen.jsonl - latestTestFile, err := srv.findLatestFileWithPrefix(datasetDir, "test_ggml-model") - if err != nil { - return "", fmt.Errorf("failed to find test_ggml-model*.jsonl file: %v", err) - } - testGenPath := filepath.Join(datasetDir, "test_gen.jsonl") - if err := srv.overwriteCopy(latestTestFile, testGenPath); err != nil { - return "", fmt.Errorf("failed to copy %s to %s: %v", latestTestFile, testGenPath, err) - } - - // Reset cmdArgs to a simpler set - cmdArgs = []string{ - "model", "train", - "--pipeline", srv.pipelineType, - fmt.Sprintf("--data-path=%s", datasetDir), - fmt.Sprintf("--model-path=%s", fullModelPath), - } - if srv.isOSX { - cmdArgs = append(cmdArgs, "--device=mps") - } - if srv.isCuda { - cmdArgs = append(cmdArgs, "--device=cuda") - } - if epochs != nil { - cmdArgs = append(cmdArgs, fmt.Sprintf("--num-epochs=%d", *epochs)) - srv.log.Infof("Number of epochs specified for simple pipeline: %d", *epochs) - } else { - srv.log.Info("No epochs specified for simple pipeline; using default number of epochs.") - } - } - - if srv.rhelai { - latestDataset, err := srv.getLatestDatasetFile() - if err != nil { - 
return "", fmt.Errorf("failed to get latest dataset file: %v", err) - } - cmdArgs = []string{ - "model", "train", - fmt.Sprintf("--data-path=%s", latestDataset), - "--max-batch-len=5000", - "--gpus=4", - "--device=cuda", - "--save-samples=1000", - fmt.Sprintf("--model-path=%s", fullModelPath), - "--pipeline", srv.pipelineType, - } - if epochs != nil { - cmdArgs = append(cmdArgs, fmt.Sprintf("--num-epochs=%d", *epochs)) - srv.log.Infof("Number of epochs specified for rhelai pipeline: %d", *epochs) - } else { - srv.log.Info("No epochs specified for rhelai pipeline; using default number of epochs.") - } - } - - srv.log.Infof("[ILAB TRAIN COMMAND] %s %v", ilabPath, cmdArgs) - - cmd := exec.Command(ilabPath, cmdArgs...) - if !srv.rhelai { - cmd.Dir = srv.baseDir - } - - logFile, err := os.Create(logFilePath) - if err != nil { - return "", fmt.Errorf("failed to create log file '%s': %v", logFilePath, err) - } - defer logFile.Close() - - cmd.Stdout = logFile - cmd.Stderr = logFile - - srv.log.Infof("[ILAB TRAIN COMMAND] %s %v", ilabPath, cmdArgs) - if err := cmd.Start(); err != nil { - return "", fmt.Errorf("error starting training command: %v", err) - } - srv.log.Infof("Training process started with PID: %d", cmd.Process.Pid) - - newJob := &Job{ - JobID: jobID, - Cmd: ilabPath, - Args: cmdArgs, - Status: "running", - PID: cmd.Process.Pid, - LogFile: logFilePath, - Branch: branchName, - StartTime: time.Now(), - } - if err := srv.createJob(newJob); err != nil { - return "", fmt.Errorf("failed to create job in DB: %v", err) - } - - go func() { - defer logFile.Close() - err := cmd.Wait() - - newJob.Lock.Lock() - defer newJob.Lock.Unlock() - - if err != nil { - newJob.Status = "failed" - srv.log.Infof("Training job '%s' failed: %v", newJob.JobID, err) - } else if cmd.ProcessState.Success() { - newJob.Status = "finished" - srv.log.Infof("Training job '%s' finished successfully", newJob.JobID) - } else { - newJob.Status = "failed" - srv.log.Infof("Training job '%s' failed 
(unknown reason)", newJob.JobID) - } - now := time.Now() - newJob.EndTime = &now - _ = srv.updateJob(newJob) - }() - - return jobID, nil + srv.log.Infof("Starting training job for model: '%s', branch: '%s'", modelName, branchName) + + jobID := fmt.Sprintf("t-%d", time.Now().UnixNano()) + logFilePath := filepath.Join("logs", fmt.Sprintf("%s.log", jobID)) + + fullModelPath, err := srv.getFullModelPath(modelName) + if err != nil { + return "", fmt.Errorf("failed to get full model path: %v", err) + } + srv.log.Infof("Resolved fullModelPath: '%s'", fullModelPath) + + modelDir := filepath.Dir(fullModelPath) + if err := os.MkdirAll(modelDir, os.ModePerm); err != nil { + return "", fmt.Errorf("failed to create model directory '%s': %v", modelDir, err) + } + + ilabPath := srv.getIlabCommand() + + var cmdArgs []string + cmdArgs = append(cmdArgs, "model", "train") + + // If not rhelai, add pipeline if set + if !srv.rhelai && srv.pipelineType != "" { + cmdArgs = append(cmdArgs, "--pipeline", srv.pipelineType) + } + cmdArgs = append(cmdArgs, fmt.Sprintf("--model-path=%s", fullModelPath)) + + if srv.isOSX { + cmdArgs = append(cmdArgs, "--device=mps") + } + if srv.isCuda { + cmdArgs = append(cmdArgs, "--device=cuda") + } + if epochs != nil { + cmdArgs = append(cmdArgs, fmt.Sprintf("--num-epochs=%d", *epochs)) + srv.log.Infof("Number of epochs specified: %d", *epochs) + } else { + srv.log.Info("No epochs specified; using default number of epochs.") + } + + // Additional logic if pipelineType == "simple" (and not rhelai) + if srv.pipelineType == "simple" && !srv.rhelai { + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("failed to get user home directory: %v", err) + } + datasetDir := filepath.Join(homeDir, ".local", "share", "instructlab", "datasets") + + // Copy the latest knowledge_train_msgs_*.jsonl => train_gen.jsonl + latestTrainFile, err := srv.findLatestFileWithPrefix(datasetDir, "knowledge_train_msgs_") + if err != nil { + return "", 
fmt.Errorf("failed to find knowledge_train_msgs_*.jsonl file: %v", err) + } + trainGenPath := filepath.Join(datasetDir, "train_gen.jsonl") + if err := srv.overwriteCopy(latestTrainFile, trainGenPath); err != nil { + return "", fmt.Errorf("failed to copy %s to %s: %v", latestTrainFile, trainGenPath, err) + } + + // Copy the latest test_ggml-model-*.jsonl => test_gen.jsonl + latestTestFile, err := srv.findLatestFileWithPrefix(datasetDir, "test_ggml-model") + if err != nil { + return "", fmt.Errorf("failed to find test_ggml-model*.jsonl file: %v", err) + } + testGenPath := filepath.Join(datasetDir, "test_gen.jsonl") + if err := srv.overwriteCopy(latestTestFile, testGenPath); err != nil { + return "", fmt.Errorf("failed to copy %s to %s: %v", latestTestFile, testGenPath, err) + } + + // Reset cmdArgs to a simpler set + cmdArgs = []string{ + "model", "train", + "--pipeline", srv.pipelineType, + fmt.Sprintf("--data-path=%s", datasetDir), + fmt.Sprintf("--model-path=%s", fullModelPath), + } + if srv.isOSX { + cmdArgs = append(cmdArgs, "--device=mps") + } + if srv.isCuda { + cmdArgs = append(cmdArgs, "--device=cuda") + } + if epochs != nil { + cmdArgs = append(cmdArgs, fmt.Sprintf("--num-epochs=%d", *epochs)) + srv.log.Infof("Number of epochs specified for simple pipeline: %d", *epochs) + } else { + srv.log.Info("No epochs specified for simple pipeline; using default number of epochs.") + } + } + + if srv.rhelai { + latestDataset, err := srv.getLatestDatasetFile() + if err != nil { + return "", fmt.Errorf("failed to get latest dataset file: %v", err) + } + cmdArgs = []string{ + "model", "train", + fmt.Sprintf("--data-path=%s", latestDataset), + "--max-batch-len=5000", + "--gpus=4", + "--device=cuda", + "--save-samples=1000", + fmt.Sprintf("--model-path=%s", fullModelPath), + "--pipeline", srv.pipelineType, + } + if epochs != nil { + cmdArgs = append(cmdArgs, fmt.Sprintf("--num-epochs=%d", *epochs)) + srv.log.Infof("Number of epochs specified for rhelai pipeline: %d", 
*epochs) + } else { + srv.log.Info("No epochs specified for rhelai pipeline; using default number of epochs.") + } + } + + srv.log.Infof("[ILAB TRAIN COMMAND] %s %v", ilabPath, cmdArgs) + + cmd := exec.Command(ilabPath, cmdArgs...) + if !srv.rhelai { + cmd.Dir = srv.baseDir + } + + logFile, err := os.Create(logFilePath) + if err != nil { + return "", fmt.Errorf("failed to create log file '%s': %v", logFilePath, err) + } + defer logFile.Close() + + cmd.Stdout = logFile + cmd.Stderr = logFile + + srv.log.Infof("[ILAB TRAIN COMMAND] %s %v", ilabPath, cmdArgs) + if err := cmd.Start(); err != nil { + return "", fmt.Errorf("error starting training command: %v", err) + } + srv.log.Infof("Training process started with PID: %d", cmd.Process.Pid) + + newJob := &Job{ + JobID: jobID, + Cmd: ilabPath, + Args: cmdArgs, + Status: "running", + PID: cmd.Process.Pid, + LogFile: logFilePath, + Branch: branchName, + StartTime: time.Now(), + } + if err := srv.createJob(newJob); err != nil { + return "", fmt.Errorf("failed to create job in DB: %v", err) + } + + go func() { + defer logFile.Close() + err := cmd.Wait() + + newJob.Lock.Lock() + defer newJob.Lock.Unlock() + + if err != nil { + newJob.Status = "failed" + srv.log.Infof("Training job '%s' failed: %v", newJob.JobID, err) + } else if cmd.ProcessState.Success() { + newJob.Status = "finished" + srv.log.Infof("Training job '%s' finished successfully", newJob.JobID) + } else { + newJob.Status = "failed" + srv.log.Infof("Training job '%s' failed (unknown reason)", newJob.JobID) + } + now := time.Now() + newJob.EndTime = &now + _ = srv.updateJob(newJob) + }() + + return jobID, nil } // ----------------------------------------------------------------------------- @@ -605,261 +608,261 @@ func (srv *ILabServer) startTrainJob(modelName, branchName string, epochs *int) // ----------------------------------------------------------------------------- func (srv *ILabServer) generateTrainPipelineHandler(w http.ResponseWriter, r *http.Request) { 
- srv.log.Info("POST /pipeline/generate-train called") - - var reqBody struct { - ModelName string `json:"modelName"` - BranchName string `json:"branchName"` - Epochs *int `json:"epochs,omitempty"` - } - if err := json.NewDecoder(r.Body).Decode(&reqBody); err != nil { - srv.log.Errorf("Error parsing request body: %v", err) - http.Error(w, "Invalid request body", http.StatusBadRequest) - return - } - if reqBody.ModelName == "" || reqBody.BranchName == "" { - srv.log.Info("Missing required parameters: modelName or branchName") - http.Error(w, "Missing required parameters: modelName or branchName", http.StatusBadRequest) - return - } - - sanitizedModelName := srv.sanitizeModelName(reqBody.ModelName) - srv.log.Infof("Sanitized modelName for pipeline: '%s'", sanitizedModelName) - - pipelineJobID := fmt.Sprintf("p-%d", time.Now().UnixNano()) - srv.log.Infof("Starting pipeline job with ID: %s", pipelineJobID) - - pipelineJob := &Job{ - JobID: pipelineJobID, - Cmd: "pipeline-generate-train", - Args: []string{sanitizedModelName, reqBody.BranchName}, - Status: "running", - PID: 0, // no direct OS process - LogFile: fmt.Sprintf("logs/%s.log", pipelineJobID), - Branch: reqBody.BranchName, - StartTime: time.Now(), - } - if err := srv.createJob(pipelineJob); err != nil { - srv.log.Errorf("Error creating pipeline job: %v", err) - http.Error(w, "Failed to create pipeline job", http.StatusInternalServerError) - return - } - - go srv.runPipelineJob(pipelineJob, sanitizedModelName, reqBody.BranchName, reqBody.Epochs) - - response := map[string]string{"pipeline_job_id": pipelineJobID} - w.Header().Set("Content-Type", "application/json") - _ = json.NewEncoder(w).Encode(response) - srv.log.Infof("POST /pipeline/generate-train => pipeline_job_id=%s", pipelineJobID) + srv.log.Info("POST /pipeline/generate-train called") + + var reqBody struct { + ModelName string `json:"modelName"` + BranchName string `json:"branchName"` + Epochs *int `json:"epochs,omitempty"` + } + if err := 
json.NewDecoder(r.Body).Decode(&reqBody); err != nil { + srv.log.Errorf("Error parsing request body: %v", err) + http.Error(w, "Invalid request body", http.StatusBadRequest) + return + } + if reqBody.ModelName == "" || reqBody.BranchName == "" { + srv.log.Info("Missing required parameters: modelName or branchName") + http.Error(w, "Missing required parameters: modelName or branchName", http.StatusBadRequest) + return + } + + sanitizedModelName := srv.sanitizeModelName(reqBody.ModelName) + srv.log.Infof("Sanitized modelName for pipeline: '%s'", sanitizedModelName) + + pipelineJobID := fmt.Sprintf("p-%d", time.Now().UnixNano()) + srv.log.Infof("Starting pipeline job with ID: %s", pipelineJobID) + + pipelineJob := &Job{ + JobID: pipelineJobID, + Cmd: "pipeline-generate-train", + Args: []string{sanitizedModelName, reqBody.BranchName}, + Status: "running", + PID: 0, // no direct OS process + LogFile: fmt.Sprintf("logs/%s.log", pipelineJobID), + Branch: reqBody.BranchName, + StartTime: time.Now(), + } + if err := srv.createJob(pipelineJob); err != nil { + srv.log.Errorf("Error creating pipeline job: %v", err) + http.Error(w, "Failed to create pipeline job", http.StatusInternalServerError) + return + } + + go srv.runPipelineJob(pipelineJob, sanitizedModelName, reqBody.BranchName, reqBody.Epochs) + + response := map[string]string{"pipeline_job_id": pipelineJobID} + w.Header().Set("Content-Type", "application/json") + _ = json.NewEncoder(w).Encode(response) + srv.log.Infof("POST /pipeline/generate-train => pipeline_job_id=%s", pipelineJobID) } // runPipelineJob orchestrates data generate + model train steps in sequence. 
func (srv *ILabServer) runPipelineJob(job *Job, modelName, branchName string, epochs *int) { - // Open the pipeline job log - logFile, err := os.Create(job.LogFile) - if err != nil { - srv.log.Errorf("Error creating pipeline log file for job %s: %v", job.JobID, err) - job.Status = "failed" - _ = srv.updateJob(job) - return - } - defer logFile.Close() - - stdLogger := zap.NewStdLog(srv.logger) - - // Redirect that standard logger's output to our log file - stdLogger.SetOutput(logFile) - - stdLogger.Printf("Starting pipeline job: %s, model: %s, branch: %s, epochs: %v", - job.JobID, modelName, branchName, epochs) - - // 1) Git checkout - gitCheckoutCmd := exec.Command("git", "checkout", branchName) - gitCheckoutCmd.Dir = srv.taxonomyPath - gitOutput, gitErr := gitCheckoutCmd.CombinedOutput() - stdLogger.Printf("Git checkout output: %s", string(gitOutput)) - if gitErr != nil { - stdLogger.Printf("Failed to checkout branch '%s': %v", branchName, gitErr) - job.Status = "failed" - _ = srv.updateJob(job) - return - } - - // 2) Generate data step - stdLogger.Println("Starting data generation step...") - genJobID, genErr := srv.startGenerateJob() - if genErr != nil { - stdLogger.Printf("Data generation step failed: %v", genErr) - job.Status = "failed" - _ = srv.updateJob(job) - return - } - stdLogger.Printf("Data generation step started with job_id=%s", genJobID) - - for { - time.Sleep(5 * time.Second) - genJob, err := srv.getJob(genJobID) - if err != nil || genJob == nil { - stdLogger.Printf("Data generation job %s not found or error: %v", genJobID, err) - job.Status = "failed" - _ = srv.updateJob(job) - return - } - if genJob.Status == "failed" { - stdLogger.Println("Data generation step failed.") - job.Status = "failed" - _ = srv.updateJob(job) - return - } - if genJob.Status == "finished" { - stdLogger.Println("Data generation step completed successfully.") - break - } - } - - // 3) Train step - stdLogger.Println("Starting training step...") - trainJobID, trainErr := 
srv.startTrainJob(modelName, branchName, epochs) - if trainErr != nil { - stdLogger.Printf("Training step failed to start: %v", trainErr) - job.Status = "failed" - _ = srv.updateJob(job) - return - } - stdLogger.Printf("Training step started with job_id=%s", trainJobID) - - for { - time.Sleep(5 * time.Second) - tJob, err := srv.getJob(trainJobID) - if err != nil || tJob == nil { - stdLogger.Printf("Training job %s not found or error: %v", trainJobID, err) - job.Status = "failed" - _ = srv.updateJob(job) - return - } - if tJob.Status == "failed" { - stdLogger.Println("Training step failed.") - job.Status = "failed" - _ = srv.updateJob(job) - return - } - if tJob.Status == "finished" { - stdLogger.Println("Training step completed successfully.") - break - } - } - - job.Status = "finished" - _ = srv.updateJob(job) - stdLogger.Println("Pipeline job completed successfully.") + // Open the pipeline job log + logFile, err := os.Create(job.LogFile) + if err != nil { + srv.log.Errorf("Error creating pipeline log file for job %s: %v", job.JobID, err) + job.Status = "failed" + _ = srv.updateJob(job) + return + } + defer logFile.Close() + + stdLogger := zap.NewStdLog(srv.logger) + + // Redirect that standard logger's output to our log file + stdLogger.SetOutput(logFile) + + stdLogger.Printf("Starting pipeline job: %s, model: %s, branch: %s, epochs: %v", + job.JobID, modelName, branchName, epochs) + + // 1) Git checkout + gitCheckoutCmd := exec.Command("git", "checkout", branchName) + gitCheckoutCmd.Dir = srv.taxonomyPath + gitOutput, gitErr := gitCheckoutCmd.CombinedOutput() + stdLogger.Printf("Git checkout output: %s", string(gitOutput)) + if gitErr != nil { + stdLogger.Printf("Failed to checkout branch '%s': %v", branchName, gitErr) + job.Status = "failed" + _ = srv.updateJob(job) + return + } + + // 2) Generate data step + stdLogger.Println("Starting data generation step...") + genJobID, genErr := srv.startGenerateJob() + if genErr != nil { + stdLogger.Printf("Data 
generation step failed: %v", genErr) + job.Status = "failed" + _ = srv.updateJob(job) + return + } + stdLogger.Printf("Data generation step started with job_id=%s", genJobID) + + for { + time.Sleep(5 * time.Second) + genJob, err := srv.getJob(genJobID) + if err != nil || genJob == nil { + stdLogger.Printf("Data generation job %s not found or error: %v", genJobID, err) + job.Status = "failed" + _ = srv.updateJob(job) + return + } + if genJob.Status == "failed" { + stdLogger.Println("Data generation step failed.") + job.Status = "failed" + _ = srv.updateJob(job) + return + } + if genJob.Status == "finished" { + stdLogger.Println("Data generation step completed successfully.") + break + } + } + + // 3) Train step + stdLogger.Println("Starting training step...") + trainJobID, trainErr := srv.startTrainJob(modelName, branchName, epochs) + if trainErr != nil { + stdLogger.Printf("Training step failed to start: %v", trainErr) + job.Status = "failed" + _ = srv.updateJob(job) + return + } + stdLogger.Printf("Training step started with job_id=%s", trainJobID) + + for { + time.Sleep(5 * time.Second) + tJob, err := srv.getJob(trainJobID) + if err != nil || tJob == nil { + stdLogger.Printf("Training job %s not found or error: %v", trainJobID, err) + job.Status = "failed" + _ = srv.updateJob(job) + return + } + if tJob.Status == "failed" { + stdLogger.Println("Training step failed.") + job.Status = "failed" + _ = srv.updateJob(job) + return + } + if tJob.Status == "finished" { + stdLogger.Println("Training step completed successfully.") + break + } + } + + job.Status = "finished" + _ = srv.updateJob(job) + stdLogger.Println("Pipeline job completed successfully.") } // findLatestFileWithPrefix returns the newest file in dir that starts with prefix. 
func (srv *ILabServer) findLatestFileWithPrefix(dir, prefix string) (string, error) { - files, err := ioutil.ReadDir(dir) - if err != nil { - return "", fmt.Errorf("failed to read directory '%s': %v", dir, err) - } - var latest os.FileInfo - for _, f := range files { - if f.IsDir() { - continue - } - if strings.HasPrefix(f.Name(), prefix) && strings.HasSuffix(f.Name(), ".jsonl") { - if latest == nil || f.ModTime().After(latest.ModTime()) { - latest = f - } - } - } - if latest == nil { - return "", fmt.Errorf("no file found in %s with prefix '%s'", dir, prefix) - } - return filepath.Join(dir, latest.Name()), nil + files, err := ioutil.ReadDir(dir) + if err != nil { + return "", fmt.Errorf("failed to read directory '%s': %v", dir, err) + } + var latest os.FileInfo + for _, f := range files { + if f.IsDir() { + continue + } + if strings.HasPrefix(f.Name(), prefix) && strings.HasSuffix(f.Name(), ".jsonl") { + if latest == nil || f.ModTime().After(latest.ModTime()) { + latest = f + } + } + } + if latest == nil { + return "", fmt.Errorf("no file found in %s with prefix '%s'", dir, prefix) + } + return filepath.Join(dir, latest.Name()), nil } // overwriteCopy copies src to dst (overwrites if dst exists). func (srv *ILabServer) overwriteCopy(src, dst string) error { - input, err := ioutil.ReadFile(src) - if err != nil { - return err - } - if err := ioutil.WriteFile(dst, input, 0644); err != nil { - return err - } - return nil + input, err := ioutil.ReadFile(src) + if err != nil { + return err + } + if err := ioutil.WriteFile(dst, input, 0644); err != nil { + return err + } + return nil } // getFullModelPath returns the directory or file path for a given model name. func (srv *ILabServer) getFullModelPath(modelName string) (string, error) { - // If the user passed something like "models/instructlab/my-model" we keep it - // but place it in ~/.cache/instructlab/models/... 
- home, err := os.UserHomeDir() - if err != nil { - return "", fmt.Errorf("cannot find home directory: %v", err) - } - base := filepath.Join(home, ".cache", "instructlab") - return filepath.Join(base, modelName), nil + // If the user passed something like "models/instructlab/my-model" we keep it + // but place it in ~/.cache/instructlab/models/... + home, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("cannot find home directory: %v", err) + } + base := filepath.Join(home, ".cache", "instructlab") + return filepath.Join(base, modelName), nil } // runIlabCommand executes the ilab command with the provided arguments and returns combined output. func (srv *ILabServer) runIlabCommand(args ...string) (string, error) { - cmdPath := srv.getIlabCommand() - cmd := exec.Command(cmdPath, args...) - if !srv.rhelai { - cmd.Dir = srv.baseDir - } - out, err := cmd.CombinedOutput() - return string(out), err + cmdPath := srv.getIlabCommand() + cmd := exec.Command(cmdPath, args...) + if !srv.rhelai { + cmd.Dir = srv.baseDir + } + out, err := cmd.CombinedOutput() + return string(out), err } // parseModelList parses the output of the "ilab model list" command into a slice of Model. 
func (srv *ILabServer) parseModelList(output string) ([]Model, error) { - var models []Model - lines := strings.Split(output, "\n") - for _, line := range lines { - line = strings.TrimSpace(line) - if strings.HasPrefix(line, "+") || strings.HasPrefix(line, "| Model Name") || line == "" { - continue - } - if strings.HasPrefix(line, "|") { - line = strings.Trim(line, "|") - fields := strings.Split(line, "|") - if len(fields) != 3 { - continue - } - model := Model{ - Name: strings.TrimSpace(fields[0]), - LastModified: strings.TrimSpace(fields[1]), - Size: strings.TrimSpace(fields[2]), - } - models = append(models, model) - } - } - return models, nil + var models []Model + lines := strings.Split(output, "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "+") || strings.HasPrefix(line, "| Model Name") || line == "" { + continue + } + if strings.HasPrefix(line, "|") { + line = strings.Trim(line, "|") + fields := strings.Split(line, "|") + if len(fields) != 3 { + continue + } + model := Model{ + Name: strings.TrimSpace(fields[0]), + LastModified: strings.TrimSpace(fields[1]), + Size: strings.TrimSpace(fields[2]), + } + models = append(models, model) + } + } + return models, nil } // parseDataList parses the output of the "ilab data list" command into a slice of Data. 
func (srv *ILabServer) parseDataList(output string) ([]Data, error) { - var dataList []Data - lines := strings.Split(output, "\n") - for _, line := range lines { - line = strings.TrimSpace(line) - if strings.HasPrefix(line, "+") || strings.HasPrefix(line, "| Dataset") || line == "" { - continue - } - if strings.HasPrefix(line, "|") { - line = strings.Trim(line, "|") - fields := strings.Split(line, "|") - if len(fields) != 3 { - continue - } - data := Data{ - Dataset: strings.TrimSpace(fields[0]), - CreatedAt: strings.TrimSpace(fields[1]), - FileSize: strings.TrimSpace(fields[2]), - } - dataList = append(dataList, data) - } - } - return dataList, nil + var dataList []Data + lines := strings.Split(output, "\n") + for _, line := range lines { + line = strings.TrimSpace(line) + if strings.HasPrefix(line, "+") || strings.HasPrefix(line, "| Dataset") || line == "" { + continue + } + if strings.HasPrefix(line, "|") { + line = strings.Trim(line, "|") + fields := strings.Split(line, "|") + if len(fields) != 3 { + continue + } + data := Data{ + Dataset: strings.TrimSpace(fields[0]), + CreatedAt: strings.TrimSpace(fields[1]), + FileSize: strings.TrimSpace(fields[2]), + } + dataList = append(dataList, data) + } + } + return dataList, nil } diff --git a/api-server/utils.go b/api-server/utils.go index c9a7ae02..6d6323b9 100644 --- a/api-server/utils.go +++ b/api-server/utils.go @@ -1,140 +1,140 @@ package main import ( - "fmt" - "io" - "io/ioutil" - "os" - "path/filepath" - "strings" + "fmt" + "io" + "io/ioutil" + "os" + "path/filepath" + "strings" ) // findLatestFileWithPrefix scans `dir` for all files whose name starts with `prefix`, // and returns the path of the latest modified file. 
func findLatestFileWithPrefix(dir, prefix string) (string, error) { - files, err := ioutil.ReadDir(dir) - if err != nil { - return "", fmt.Errorf("failed to read directory '%s': %v", dir, err) - } - - var latestFile os.FileInfo - for _, f := range files { - if strings.HasPrefix(f.Name(), prefix) && strings.HasSuffix(f.Name(), ".jsonl") { - if latestFile == nil || f.ModTime().After(latestFile.ModTime()) { - latestFile = f - } - } - } - if latestFile == nil { - return "", fmt.Errorf("no file found matching prefix '%s' in '%s'", prefix, dir) - } - return filepath.Join(dir, latestFile.Name()), nil + files, err := ioutil.ReadDir(dir) + if err != nil { + return "", fmt.Errorf("failed to read directory '%s': %v", dir, err) + } + + var latestFile os.FileInfo + for _, f := range files { + if strings.HasPrefix(f.Name(), prefix) && strings.HasSuffix(f.Name(), ".jsonl") { + if latestFile == nil || f.ModTime().After(latestFile.ModTime()) { + latestFile = f + } + } + } + if latestFile == nil { + return "", fmt.Errorf("no file found matching prefix '%s' in '%s'", prefix, dir) + } + return filepath.Join(dir, latestFile.Name()), nil } // overwriteCopy removes `destPath` if it exists, then copies srcPath -> destPath. 
func overwriteCopy(srcPath, destPath string) error { - // If the destination file already exists, remove it - if _, err := os.Stat(destPath); err == nil { - if err := os.Remove(destPath); err != nil { - return fmt.Errorf("could not remove existing file '%s': %v", destPath, err) - } - } - - // Open the source - in, err := os.Open(srcPath) - if err != nil { - return fmt.Errorf("could not open source file '%s': %v", srcPath, err) - } - defer in.Close() - - // Create the destination - out, err := os.Create(destPath) - if err != nil { - return fmt.Errorf("could not create dest file '%s': %v", destPath, err) - } - defer out.Close() - - // Copy contents - if _, err := io.Copy(out, in); err != nil { - return fmt.Errorf("failed to copy '%s' to '%s': %v", srcPath, destPath, err) - } - - return nil + // If the destination file already exists, remove it + if _, err := os.Stat(destPath); err == nil { + if err := os.Remove(destPath); err != nil { + return fmt.Errorf("could not remove existing file '%s': %v", destPath, err) + } + } + + // Open the source + in, err := os.Open(srcPath) + if err != nil { + return fmt.Errorf("could not open source file '%s': %v", srcPath, err) + } + defer in.Close() + + // Create the destination + out, err := os.Create(destPath) + if err != nil { + return fmt.Errorf("could not create dest file '%s': %v", destPath, err) + } + defer out.Close() + + // Copy contents + if _, err := io.Copy(out, in); err != nil { + return fmt.Errorf("failed to copy '%s' to '%s': %v", srcPath, destPath, err) + } + + return nil } // getBaseCacheDir returns the base cache directory path: ~/.cache/instructlab/ func getBaseCacheDir() (string, error) { - homeDir, err := os.UserHomeDir() - if err != nil { - return "", fmt.Errorf("failed to get user home directory: %v", err) - } - return filepath.Join(homeDir, ".cache", "instructlab"), nil + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("failed to get user home directory: %v", err) + } + return 
filepath.Join(homeDir, ".cache", "instructlab"), nil } // getFullModelPath converts a user-supplied model name into a fully qualified path: // // ~/.cache/instructlab/models/ func getFullModelPath(modelName string) (string, error) { - baseCacheDir, err := getBaseCacheDir() - if err != nil { - return "", err - } - // If user-supplied name already starts with "models/", don't prepend again - if strings.HasPrefix(modelName, "models/") { - return filepath.Join(baseCacheDir, modelName), nil - } - return filepath.Join(baseCacheDir, "models", modelName), nil + baseCacheDir, err := getBaseCacheDir() + if err != nil { + return "", err + } + // If user-supplied name already starts with "models/", don't prepend again + if strings.HasPrefix(modelName, "models/") { + return filepath.Join(baseCacheDir, modelName), nil + } + return filepath.Join(baseCacheDir, "models", modelName), nil } // findLatestDirWithPrefix finds the most recently modified directory within 'dir' that starts with 'prefix'. func (srv *ILabServer) findLatestDirWithPrefix(dir, prefix string) (string, error) { - entries, err := ioutil.ReadDir(dir) - if err != nil { - return "", fmt.Errorf("failed to read directory '%s': %v", dir, err) - } - - var latestDir os.FileInfo - for _, entry := range entries { - if !entry.IsDir() { - continue - } - if strings.HasPrefix(entry.Name(), prefix) { - if latestDir == nil || entry.ModTime().After(latestDir.ModTime()) { - latestDir = entry - } - } - } - - if latestDir == nil { - return "", fmt.Errorf("no directory found in '%s' with prefix '%s'", dir, prefix) - } - - latestPath := filepath.Join(dir, latestDir.Name()) - return latestPath, nil + entries, err := ioutil.ReadDir(dir) + if err != nil { + return "", fmt.Errorf("failed to read directory '%s': %v", dir, err) + } + + var latestDir os.FileInfo + for _, entry := range entries { + if !entry.IsDir() { + continue + } + if strings.HasPrefix(entry.Name(), prefix) { + if latestDir == nil || 
entry.ModTime().After(latestDir.ModTime()) { + latestDir = entry + } + } + } + + if latestDir == nil { + return "", fmt.Errorf("no directory found in '%s' with prefix '%s'", dir, prefix) + } + + latestPath := filepath.Join(dir, latestDir.Name()) + return latestPath, nil } // getLatestDatasetFile returns the path to the latest dataset file named "knowledge_train_msgs_*.jsonl". func (srv *ILabServer) getLatestDatasetFile() (string, error) { - homeDir, err := os.UserHomeDir() - if err != nil { - return "", fmt.Errorf("failed to get user home directory: %v", err) - } - datasetDir := filepath.Join(homeDir, ".local", "share", "instructlab", "datasets") - files, err := ioutil.ReadDir(datasetDir) - if err != nil { - return "", fmt.Errorf("failed to read dataset directory: %v", err) - } - - var latestFile os.FileInfo - for _, file := range files { - if strings.HasPrefix(file.Name(), "knowledge_train_msgs_") && strings.HasSuffix(file.Name(), ".jsonl") { - if latestFile == nil || file.ModTime().After(latestFile.ModTime()) { - latestFile = file - } - } - } - if latestFile == nil { - return "", fmt.Errorf("no dataset file found with the prefix 'knowledge_train_msgs_'") - } - return filepath.Join(datasetDir, latestFile.Name()), nil + homeDir, err := os.UserHomeDir() + if err != nil { + return "", fmt.Errorf("failed to get user home directory: %v", err) + } + datasetDir := filepath.Join(homeDir, ".local", "share", "instructlab", "datasets") + files, err := ioutil.ReadDir(datasetDir) + if err != nil { + return "", fmt.Errorf("failed to read dataset directory: %v", err) + } + + var latestFile os.FileInfo + for _, file := range files { + if strings.HasPrefix(file.Name(), "knowledge_train_msgs_") && strings.HasSuffix(file.Name(), ".jsonl") { + if latestFile == nil || file.ModTime().After(latestFile.ModTime()) { + latestFile = file + } + } + } + if latestFile == nil { + return "", fmt.Errorf("no dataset file found with the prefix 'knowledge_train_msgs_'") + } + return 
filepath.Join(datasetDir, latestFile.Name()), nil } diff --git a/api-server/zap.go b/api-server/zap.go index 94b22132..538d8c94 100644 --- a/api-server/zap.go +++ b/api-server/zap.go @@ -29,6 +29,6 @@ func (srv *ILabServer) initLogger(debug bool) { srv.log = logger.Sugar() if debug { - srv.log.Debug("Debug logging is ENABLED.") + srv.log.Debug("Debug logging is enabled.") } }