Skip to content

Commit

Permalink
Merge pull request #325 from nerdalert/prechat-context-bug
Browse files Browse the repository at this point in the history
Resolve precheck context passing
  • Loading branch information
mergify[bot] authored May 2, 2024
2 parents 5ccd2b0 + 5937201 commit e196b41
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 9 deletions.
2 changes: 1 addition & 1 deletion worker/Containerfile.servebase
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ RUN go build -o instructlab-bot-worker main.go && \

FROM fedora:latest as base

RUN dnf install -y python openssh git python3-pip make automake gcc gcc-c++ python3-devel && \
RUN dnf install -y python openssh git python3-pip make automake gcc gcc-c++ python3-devel procps && \
mkdir ~/.ssh && ssh-keyscan github.com > ~/.ssh/known_hosts && \
python -m ensurepip && \
dnf install -y gcc cmake gcc-c++ && \
Expand Down
26 changes: 18 additions & 8 deletions worker/cmd/generate.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ type Worker struct {
tlsClientKeyPath string
tlsServerCaCertPath string
maxSeed int
cmdRun string
}

func NewJobProcessor(ctx context.Context, pool *redis.Pool, svc *s3.Client, logger *zap.SugaredLogger, job, precheckEndpoint, sdgEndpoint, tlsClientCertPath, tlsClientKeyPath, tlsServerCaCertPath string, maxSeed int) *Worker {
Expand Down Expand Up @@ -322,28 +323,33 @@ func (w *Worker) runPrecheck(lab, outputDir, modelName string) error {
continue
}

chatArgs := []string{"chat", "--quick-question", question}
context, hasContext := example["context"].(string)
// Slicing args breaks ilab chat for context, use Sprintf to control spacing
if hasContext {
chatArgs = append(chatArgs, "--context", context)
// Append the context to the question with a specific format
question = fmt.Sprintf("%s Answer this based on the following context: %s.", question, context)
}
commandStr := fmt.Sprintf("chat --quick-question %s", question)
if TlsInsecure {
chatArgs = append(chatArgs, "--tls-insecure")
commandStr += " --tls-insecure"
}
if PreCheckEndpointURL != localEndpoint && modelName != "unknown" {
chatArgs = append(chatArgs, "--endpoint-url", PreCheckEndpointURL, "--model", modelName)
commandStr += fmt.Sprintf(" --endpoint-url %s --model %s", PreCheckEndpointURL, modelName)
}

cmd := exec.Command(lab, chatArgs...)
cmdArgs := strings.Fields(commandStr)
cmd := exec.Command(lab, cmdArgs...)
w.cmdRun = cmd.String()
w.logger.Infof("Running the precheck command: %s", cmd.String())

cmd.Dir = workDir
cmd.Env = os.Environ()
cmd.Stderr = os.Stderr
var out bytes.Buffer
var errOut bytes.Buffer
cmd.Stdout = &out
cmd.Stderr = &errOut
err = cmd.Run()
if err != nil {
w.logger.Error(err)
w.logger.Errorf("Precheck command failed with error: %v; stderr: %s", err, errOut.String())
continue
}

Expand Down Expand Up @@ -757,6 +763,10 @@ func (w *Worker) postJobResults(URL, jobType string) {
w.logger.Errorf("Could not set s3_url in redis: %v", err)
}

if _, err := conn.Do("SET", fmt.Sprintf("jobs:%s:cmd", w.job), w.cmdRun); err != nil {
w.logger.Errorf("Could not set cmd in redis: %v", err)
}

modelName := w.determineModelName(jobType)

if _, err := conn.Do("SET", fmt.Sprintf("jobs:%s:model_name", w.job), modelName); err != nil {
Expand Down

0 comments on commit e196b41

Please sign in to comment.