server&plugin: simplify API
bzz committed Oct 20, 2023
1 parent d34e9d9 commit d058065
Showing 8 changed files with 89 additions and 137 deletions.
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -7,7 +7,7 @@ Internal hackathon version for IJ 2023.2

### Added
- Initial scaffold from [IntelliJ Platform Plugin Template](https://github.com/JetBrains/intellij-platform-plugin-template)
- server.py & ghost/inline code completion provider
- server.py, plugin for 2023.2 with ghost/inline code completion provider and a tool window


## [Unreleased]
10 changes: 6 additions & 4 deletions README.md
@@ -4,6 +4,12 @@
[![Version](https://img.shields.io/jetbrains/plugin/v/PLUGIN_ID.svg)](https://plugins.jetbrains.com/plugin/PLUGIN_ID)
[![Downloads](https://img.shields.io/jetbrains/plugin/d/PLUGIN_ID.svg)](https://plugins.jetbrains.com/plugin/PLUGIN_ID)

## Run model inference

To run an OSS LLM, use:
* a simple local [inference server](./server).
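
Once the server is running, a quick way to confirm it is reachable is to ask it which models it has loaded. The snippet below is a minimal sketch, not part of this change, and assumes the server's default `localhost:8000` host/port (the same defaults the plugin uses):

```python
# Sketch only: list the models the local inference server has loaded.
# Assumes `python server/server.py` is already running with the default host/port.
import json
import urllib.request

with urllib.request.urlopen("http://localhost:8000", timeout=10) as resp:
    payload = json.load(resp)

# The server's GET handler responds with {"models": ["<hf-checkpoint>", ...]}.
print(payload["models"])
```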



## Plugin development

@@ -23,10 +29,6 @@
Download the [latest release](https://github.com/mloncode/intellij/releases/latest) and install it manually using
<kbd>Settings/Preferences</kbd> > <kbd>Plugins</kbd> > <kbd>⚙️</kbd> > <kbd>Install plugin from disk...</kbd>

## Running LLM Inference

To run OSS LLM using a simple inference server follow instructions in [./server].


<!-- Plugin description -->
Use Open Source LLMs for code completion in IntelliJ IDEs.
15 changes: 6 additions & 9 deletions server/README.md
@@ -1,16 +1,13 @@
# Python server
Usage:

`python server.py [-p PORT] [-a HOSTNAME| address]`

Default address is `localhost`, default port is `8000`
Usage:

Call to API using curl:
`python server.py -p 8080 --model bigcode/starcoderbase-1b`

`curl -X POST [ADDRESS]:[PORT] --data "your_text_here"`

In response json in form `{"text": "another_text_here"}` will be returned.
Call the API using curl:

Example:
```sh
curl -X POST localhost:8080/generate --data '{ "prompt": "def ping_with_back_off():\n " }'
```

`python server.py -p 8000 -a localhost`
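
For reference, here is a minimal Python client sketch equivalent to the curl call above. It is illustrative only (not part of the commit) and assumes the server was started on port 8080 as in the usage example; the handler returns JSON of the form `{"results": [...]}` with the prompt already stripped from each completion:

```python
# Sketch only: POST a prompt to the local server and print each completion.
# Assumes the server is listening on localhost:8080 as started above.
import json
import urllib.request

body = json.dumps(
    {"prompt": "def ping_with_back_off():\n    ", "max_new_tokens": 64}
).encode("utf-8")
request = urllib.request.Request(
    "http://localhost:8080/generate",  # the handler currently accepts any path
    data=body,
    headers={"Content-Type": "application/json"},
    method="POST",
)
with urllib.request.urlopen(request, timeout=60) as resp:
    results = json.load(resp)["results"]

for completion in results:  # the prompt itself is not echoed back
    print(completion)
```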
89 changes: 0 additions & 89 deletions server/models/huggingface_models.py

This file was deleted.

93 changes: 67 additions & 26 deletions server/server.py
@@ -1,66 +1,107 @@
from functools import partial
from http.server import HTTPServer, BaseHTTPRequestHandler
import argparse
import json
from models.huggingface_models import codet5_base_model, starcoder_model
import re

models = [{"model_name": "StarCoder"}, {"model_name": "codeT5-base"}]
maper = {
"StarCoder": starcoder_model,
"codeT5-base": codet5_base_model
}
import torch
import transformers
from transformers import BitsAndBytesConfig, AutoModelForCausalLM, AutoTokenizer


models = {}

class RequestHandler(BaseHTTPRequestHandler):
def __init__(self, model_pipeline, *args, **kwargs):
self.pipeline = model_pipeline
super().__init__(*args, **kwargs)

def _set_headers(self):
self.send_response(200)
self.send_header("Content-type", "application/json")
self.end_headers()

def do_POST(self):
def do_POST(self): #TODO: add a route /generate
self._set_headers()
content_len = int(self.headers.get("Content-Length", 0))
post_body = self.rfile.read(content_len)
json_data = json.loads(post_body)
text_received = json_data["prompt"]
model = maper[json_data["model"]]
processed_texts = model(text_received, json_data["max_new_tokens"])
start_index = len(text_received) if json_data["model"] == "StarCoder" else 0
json_bytes = json.dumps(
{"results" : list(map(lambda x: {"text": x[start_index:]}, processed_texts))}
).encode("utf-8")
print(json_bytes)
self.wfile.write(json_bytes)
text = json_data["prompt"]
max_new_tokens = json_data.get("max_new_tokens", 200)

outputs = self.pipeline(
text,
do_sample=True,
temperature=0.1,
top_p=0.95,
num_return_sequences=1,
max_new_tokens=max_new_tokens,
# max_length=200,
)

# model = models[json_data["model"]]
# processed_texts = model(text, json_data["max_new_tokens"])
results = {"results" : [o["generated_text"][len(text):] for o in outputs]}
self.wfile.write(json.dumps(results).encode("utf-8"))

def do_GET(self):
self._set_headers()
models_json = json.dumps({"models": models})
models_json = json.dumps({"models": list(models.keys())})
self.wfile.write(models_json.encode("utf-8"))


def run(port, addr):
server_address = (addr, port)
httpd = HTTPServer(server_address, RequestHandler)
print(f"Starting httpd server on {addr}:{port}")
def run(args):
# load a model
kwargs = {}
if (re.search("codellama", args.model) or re.search("starcoder(base)?", args.model)) and not torch.backends.mps.is_available():
nf4_config = BitsAndBytesConfig(
load_in_4bit=True,
bnb_4bit_quant_type="nf4",
bnb_4bit_use_double_quant=True,
bnb_4bit_compute_dtype=torch.bfloat16
)
kwargs["quantization_config"] = nf4_config

print(f"Loading {args.model}")
model = AutoModelForCausalLM.from_pretrained(args.model, **kwargs)
tokenizer = AutoTokenizer.from_pretrained(args.model)
global models
models[args.model] = model

pipeline = transformers.pipeline(
"text-generation",
model=model,
torch_dtype=torch.bfloat16,
device="mps" if torch.backends.mps.is_available() else "auto",
tokenizer=tokenizer,
eos_token_id=tokenizer.eos_token_id
)
# print(f"loaded to device {model.hf_device_map}")

server_address = (args.host, args.port)
httpd = HTTPServer(server_address, partial(RequestHandler, pipeline))
print(f"Starting http server on {args.host}:{args.port}")
httpd.serve_forever()


if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Run a python server.")
parser = argparse.ArgumentParser(description="Run a python OSS LLM inference server")
parser.add_argument(
"-p",
dest="port",
type=int,
help="Specify the port, default is 8000",
help="Specify the port",
default=8000,
required=False,
)
parser.add_argument(
"-a",
dest="address",
dest="host",
type=str,
help="Specify the address, default is localhost",
help="Specify the address",
default="localhost",
required=False,
)
parser.add_argument("--model", type=str, default="bigcode/starcoderbase-1b", help="HF checkpoint to use")
args = parser.parse_args()
run(port=args.port, addr=args.address)
run(args)
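
To troubleshoot model loading separately from the HTTP layer, the same generation path can be exercised directly. This is a standalone sketch, not part of the commit, and assumes the default `bigcode/starcoderbase-1b` checkpoint and the sampling settings used in `do_POST`:

```python
# Sketch only: run the same text-generation pipeline server.py builds,
# without the HTTP server, to verify the checkpoint loads and generates.
import torch
import transformers
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "bigcode/starcoderbase-1b"  # the --model default in server.py
tokenizer = AutoTokenizer.from_pretrained(checkpoint)
model = AutoModelForCausalLM.from_pretrained(checkpoint)

pipe = transformers.pipeline(
    "text-generation",
    model=model,
    tokenizer=tokenizer,
    device="mps" if torch.backends.mps.is_available() else -1,  # -1 == CPU
)

prompt = "def ping_with_back_off():\n    "
outputs = pipe(
    prompt,
    do_sample=True,
    temperature=0.1,
    top_p=0.95,
    num_return_sequences=1,
    max_new_tokens=64,
)

# Mirror do_POST: drop the echoed prompt and keep only the generated suffix.
print(outputs[0]["generated_text"][len(prompt):])
```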
2 changes: 1 addition & 1 deletion src/main/kotlin/org/intellij/ml/llm/server/OSSLLMModels.kt
@@ -10,7 +10,7 @@ object OSSLLMModels {
private val logger = thisLogger()

val currentModelIndex = AtomicInteger(0)
val modelsList: List<String> = OSSLLMServer.getAvailableModels().models.map { it.modelName }
val modelsList: List<String> = OSSLLMServer.getAvailableModels().models

fun currentModel() : String? {
return modelsList.getOrNull(currentModelIndex.get())
13 changes: 7 additions & 6 deletions src/main/kotlin/org/intellij/ml/llm/server/OSSLLMServer.kt
@@ -15,9 +15,9 @@ import java.net.http.HttpResponse
object OSSLLMServer {
private const val MAX_TOKEN_LENGTH = 10

// @Serializable data class LLMModel (@SerializedName("model_name") val modelName: String)
@Serializable data class LLMModelsResponse(val models: List<String>)
@Serializable data class LLMResponse (val results: List<String>)
@Serializable data class LLMModelsResponse(val models: List<LLMModel>)
@Serializable data class LLMModel (@SerializedName("model_name") val modelName: String)
@Serializable data class LLMRequest(
val model: String,
val prompt: String,
@@ -54,13 +54,14 @@ object OSSLLMServer {

fun getAvailableModels(host: String = "localhost", port: Int = 8000): LLMModelsResponse {
val url = URI.create("http://$host:$port")
val request = HttpRequest.newBuilder()
val getRequest = HttpRequest.newBuilder()
.uri(url)
.GET()
.timeout(java.time.Duration.ofSeconds(10))
.build()

return serverQuery(url.toString(), request, LLMModelsResponse::class.java)
val models = serverQuery(url.toString(), getRequest, LLMModelsResponse::class.java)
return models
}

fun getSuggestions(context: String, host: String = "localhost", port: Int = 8000): LLMResponse {
@@ -71,13 +72,13 @@
)
val json = Gson().toJson(queryParams)
val url = URI.create("http://$host:$port")
val request = HttpRequest.newBuilder()
val postRequest = HttpRequest.newBuilder()
.uri(url)
.POST(HttpRequest.BodyPublishers.ofString(json))
.header("Content-type", "application/json")
.timeout(java.time.Duration.ofSeconds(10))
.build()

return serverQuery(url.toString(), request, LLMResponse::class.java)
return serverQuery(url.toString(), postRequest, LLMResponse::class.java)
}
}
2 changes: 1 addition & 1 deletion src/main/resources/META-INF/plugin.xml
@@ -9,7 +9,7 @@
<resource-bundle>messages.MyBundle</resource-bundle>

<extensions defaultExtensionNs="com.intellij">
<toolWindow factoryClass="org.intellij.ml.llm.toolWindow.LLMToolWindowFactory" id="LLM"/>
<toolWindow factoryClass="org.intellij.ml.llm.toolWindow.LLMToolWindowFactory" id="OSS LLMs"/>
<inline.completion.provider implementation="org.intellij.ml.llm.OSSLLMCompletionProvider"></inline.completion.provider>
</extensions>

