Commit
Support Gemini 1.5 via Google AI studio playground (#9)
* start work on AI studio proxy

* add a test

* add chat

* add gemini 1.5 doc

* lint

* add function calling support for ai studio

* update viz to support gemini 1.5 over AI studio
extremeheat authored Feb 28, 2024
1 parent b34f389 commit 3cad49f
Showing 14 changed files with 525 additions and 72 deletions.
30 changes: 30 additions & 0 deletions examples/gemini1.5.js
@@ -0,0 +1,30 @@
/* eslint-disable no-unused-vars */
// Gemini 1.5 Pro is *not* currently publicly available.
// This is a demo that will only work if you have access to 1.5 Pro in the Google AI Studio playground *and*
// a special user script (like an extension) that you can run to allow langxlang to use the browser as an API.

const { GoogleAIStudioCompletionService, ChatSession } = require('langxlang')

async function testCompletion () {
// Use port 8095 to host the websocket server
const service = new GoogleAIStudioCompletionService(8095)
await service.ready
const response = await service.requestCompletion('gemini-1.5-pro', '', 'Why is the sky blue?')
console.log('Result', response.text)
}

// With ChatSessions
async function testChatSession () {
const service = new GoogleAIStudioCompletionService(8095)
await service.ready
const session = new ChatSession(service, 'gemini-1.5-pro', '')
const message = await session.sendMessage('Hello! Why is the sky blue?')
  console.log('Done', message.text.length, 'bytes', 'now asking a followup')
  // Ask a related question about the response
const followup = await session.sendMessage('Is this the case everywhere on Earth, what about the poles?')
console.log('Done', followup.text.length, 'bytes')
}

// In order to run this example, you need to have the Google AI Studio user script client running,
// which will connect to the WebSocket server listening on the specified port (8095 in this example).
// The client code is a user script that you can run in the Google AI Studio playground.
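For reference, a user script client along these lines is what bridges the playground to the WebSocket server. The sketch below is illustrative only: the real wire protocol is defined in src/googleAIStudio.js (not shown in this commit view), so the message shapes, the sendToPlayground helper, and the @match URL are assumptions rather than the actual implementation.

// ==UserScript==
// @name   langxlang AI Studio bridge (illustrative sketch, not the real client)
// @match  https://aistudio.google.com/*
// ==/UserScript==
(function () {
  // Connect back to the WebSocket server hosted by GoogleAIStudioCompletionService
  const ws = new WebSocket('ws://localhost:8095')
  ws.onmessage = async (event) => {
    const request = JSON.parse(event.data) // assumed shape: { id, prompt }
    // A real client would enter the prompt into the playground UI and read
    // back the streamed response; sendToPlayground is a hypothetical helper.
    const text = await sendToPlayground(request.prompt)
    ws.send(JSON.stringify({ id: request.id, text }))
  }
})()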
3 changes: 2 additions & 1 deletion package.json
@@ -38,6 +38,7 @@
"@google/generative-ai": "^0.2.1",
"acorn": "^8.11.3",
"debug": "^4.3.4",
"openai": "^4.28.0"
"openai": "^4.28.0",
"ws": "^8.16.0"
}
}
118 changes: 56 additions & 62 deletions src/ChatSession.js
@@ -1,5 +1,5 @@
 const { cleanMessage } = require('./util')
-const { convertFunctionsToOpenAI, convertFunctionsToGemini } = require('./functions')
+const { convertFunctionsToOpenAI, convertFunctionsToGemini, convertFunctionsToGoogleAIStudio } = require('./functions')
 const { getModelInfo } = require('./util')
 const debug = require('debug')('lxl')

@@ -23,8 +23,13 @@ class ChatSession {

   async _loadFunctions (functions) {
     const modelInfo = getModelInfo(this.model)
+    this.modelAuthor = modelInfo.author
+    this.modelFamily = modelInfo.family
-    if (modelInfo.family === 'openai') {
+    if (modelInfo.author === 'googleaistudio') {
+      const { result, metadata } = await convertFunctionsToGoogleAIStudio(functions)
+      this.functionsPayload = result
+      this.functionsMeta = metadata
+    } else if (modelInfo.family === 'openai') {
       const { result, metadata } = await convertFunctionsToOpenAI(functions)
       this.functionsPayload = result
       this.functionsMeta = metadata
@@ -37,39 +42,45 @@ class ChatSession {
     debug('Loaded function metadata: ' + JSON.stringify(this.functionsMeta))
   }
 
+  async _callFunctionWithArgs (functionName, payload) {
+    const fnMeta = this.functionsMeta[functionName]
+    const fn = this.functions[functionName]
+    // payload is an object of { argName: argValue } ... since order is not guaranteed, we map named arguments back to positional ones here
+    const args = []
+    for (const param in payload) {
+      const value = payload[param]
+      const index = fnMeta.argNames.indexOf(param)
+      args[index] = value
+    }
+    // Set default values if they're not provided
+    for (let i = 0; i < fnMeta.args.length; i++) {
+      const meta = fnMeta.args[i]
+      if (!args[i]) {
+        if (meta.default) {
+          args[i] = meta.default
+        }
+      }
+    }
+    const result = await fn.apply(null, args.map(e => e))
+    return result
+  }
+
   // This calls a function and adds the response to the context so the model can be called again
   async _callFunction (functionName, payload, metadata) {
-    if (this.modelFamily === 'openai') {
+    if (this.modelAuthor === 'googleaistudio') {
+      let content
+      if (metadata.text) content = metadata.text + '\n'
+      content = content.trim()
+      const arStr = Object.keys(payload).length ? JSON.stringify(payload) : ''
+      content += `\n<FUNCTION_CALL>${functionName}(${arStr})</FUNCTION_CALL>`
+      this.messages.push({ role: 'assistant', content })
+      const result = await this._callFunctionWithArgs(functionName, payload)
+      this.messages.push({ role: 'function', name: functionName, content: JSON.stringify(result) })
+    } else if (this.modelFamily === 'openai') {
       // https://openai.com/blog/function-calling-and-other-api-updates
       this.messages.push({ role: 'assistant', function_call: { name: functionName, arguments: JSON.stringify(payload) } })
-
-      const fnMeta = this.functionsMeta[functionName]
-
-      if (this.functionsPayload.length === 0) {
-        const fn = this.functions[functionName]
-        const result = await fn()
-        this.messages.push({ role: 'function', name: functionName, content: result })
-      } else {
-        const fn = this.functions[functionName]
-        // payload is an object of { argName: argValue } ... since order is not guaranteed we need to handle it here
-        const args = []
-        for (const param in payload) {
-          const value = payload[param]
-          const index = fnMeta.argNames.indexOf(param)
-          args[index] = value
-        }
-        // Set default values if they're not provided
-        for (let i = 0; i < fnMeta.args.length; i++) {
-          const meta = fnMeta.args[i]
-          if (args[i] === undefined) {
-            if (meta.default !== undefined) {
-              args[i] = meta.default
-            }
-          }
-        }
-        const result = await fn.apply(null, args.map(e => e))
-        this.messages.push({ role: 'function', name: functionName, content: JSON.stringify(result) })
-      }
+      const result = await this._callFunctionWithArgs(functionName, payload)
+      this.messages.push({ role: 'function', name: functionName, content: JSON.stringify(result) })
     } else if (this.modelFamily === 'gemini') {
       /*
       {
@@ -99,35 +110,10 @@
       ]
     }
     */
-      const fnMeta = this.functionsMeta[functionName]
-      this.messages.push({ role: 'model', parts: [{ functionCall: { name: functionName, args: payload } }] })
-
-      // if there's 1 function, we can just call it directly
-      if (this.functionsPayload.length === 0) {
-        const fn = this.functions[functionName]
-        const result = await fn()
-        this.messages.push({ role: 'function', parts: [{ functionResponse: { name: functionName, response: { name: functionName, content: result } } }] })
-      } else {
-        const fn = this.functions[functionName]
-        // payload is an object of { argName: argValue } ... since order is not guaranteed we need to handle it here
-        const args = []
-        for (const param in payload) {
-          const value = payload[param]
-          const index = fnMeta.argNames.indexOf(param)
-          args[index] = value
-        }
-        // Set default values if they're not provided
-        for (let i = 0; i < fnMeta.args.length; i++) {
-          const meta = fnMeta.args[i]
-          if (!args[i]) {
-            if (meta.default) {
-              args[i] = meta.default
-            }
-          }
-        }
-        const result = await fn.apply(null, args.map(e => e))
-        this.messages.push({ role: 'function', parts: [{ functionResponse: { name: functionName, response: { name: functionName, content: result } } }] })
-      }
+      this.messages.push({ role: 'model', parts: [{ functionCall: { name: functionName, args: payload } }] })
+      const result = await this._callFunctionWithArgs(functionName, payload)
+      this.messages.push({ role: 'function', parts: [{ functionResponse: { name: functionName, response: { name: functionName, content: result } } }] })
     }
   }

@@ -136,7 +122,7 @@ class ChatSession {
   }
 
   async _submitRequest (chunkCb) {
-    // console.log('Sending to', this.model, this.messages)
+    debug('Sending to', this.model, this.messages)
     const response = await this.service.requestStreamingChat(this.model, {
       maxTokens: this.maxTokens,
       messages: this.messages,
@@ -145,11 +131,19 @@ class ChatSession {
     }, chunkCb)
     debug('Streaming response', JSON.stringify(response))
     if (response.type === 'function') {
-      this._calledFunctionsForRound.push(response.fnName)
+      this._calledFunctionsForRound.push(response.fnCalls)
+      if (Array.isArray(response.fnCalls) && !response.fnCalls.length) {
+        throw new Error('No function calls returned, but type is function')
+      }
       // we need to call the function with the payload and then send the result back to the model
       for (const index in response.fnCalls) {
         const call = response.fnCalls[index]
-        await this._callFunction(call.name, JSON.parse(call.args))
+        await this._callFunction(call.name, call.args ? JSON.parse(call.args) : {}, response)
       }
+      // Google AI Studio: We can only send one req/second... TODO: throttle this internally
+      if (this.modelAuthor === 'googleaistudio') {
+        // throttle a bit to avoid rate limiting :(
+        await new Promise((resolve) => setTimeout(resolve, 500))
+      }
       return this._submitRequest(chunkCb)
     } else if (response.type === 'text') {
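All three providers now funnel through _callFunctionWithArgs, which maps the model's named-argument payload back onto positional JavaScript arguments and fills in declared defaults; for Google AI Studio, which has no native function-calling API, the call itself is recorded in plain text as <FUNCTION_CALL>name({...})</FUNCTION_CALL> inside the assistant message. Below is a standalone sketch of the argument mapping, with hand-built metadata for illustration only; the real fnMeta shape comes from the converters in src/functions.js and may differ.

// Hypothetical function and metadata, for illustration only
const fnMeta = { argNames: ['city', 'units'], args: [{ name: 'city' }, { name: 'units', default: 'celsius' }] }
async function getWeather (city, units) { return `${city} is 20 degrees ${units}` }

async function callWithNamedArgs (payload) {
  const args = []
  // Named args can arrive in any order, so index them by declared position
  for (const param in payload) args[fnMeta.argNames.indexOf(param)] = payload[param]
  // Fill in defaults for anything the model omitted
  for (let i = 0; i < fnMeta.args.length; i++) {
    if (!args[i] && fnMeta.args[i].default) args[i] = fnMeta.args[i].default
  }
  return getWeather.apply(null, args)
}

callWithNamedArgs({ units: 'fahrenheit', city: 'Paris' }).then(console.log) // 'Paris is 20 degrees fahrenheit'
callWithNamedArgs({ city: 'Tokyo' }).then(console.log) // falls back to the 'celsius' default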
32 changes: 32 additions & 0 deletions src/GoogleAIStudioCompletionService.js
@@ -0,0 +1,32 @@
const studio = require('./googleAIStudio')

const supportedModels = ['gemini-1.0-pro', 'gemini-1.5-pro']

class GoogleAIStudioCompletionService {
constructor (serverPort) {
this.serverPort = serverPort
this.ready = studio.runServer(serverPort)
}

stop () {
studio.stopServer()
}

async requestCompletion (model, system, user, chunkCb) {
if (!supportedModels.includes(model)) {
throw new Error(`Model ${model} is not supported`)
}
const result = await studio.generateCompletion(model, system + '\n' + user, chunkCb)
return { text: result.text }
}

async requestStreamingChat (model, { messages, maxTokens, functions }, chunkCb) {
if (!supportedModels.includes(model)) {
throw new Error(`Model ${model} is not supported`)
}
const result = await studio.requestChatCompletion(model, messages, chunkCb, { maxTokens, functions })
return { ...result, completeMessage: result.text }
}
}

module.exports = GoogleAIStudioCompletionService
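The service mirrors the existing CompletionService surface, so ChatSession can use it unchanged. A direct-usage sketch of the streaming API follows (hedged: the message roles mirror what ChatSession builds above, and the chunk shape passed to chunkCb is defined by src/googleAIStudio.js, so chunk.content is an assumption):

const { GoogleAIStudioCompletionService } = require('langxlang')

async function main () {
  const service = new GoogleAIStudioCompletionService(8095)
  await service.ready // resolves once the WebSocket server is up
  const response = await service.requestStreamingChat('gemini-1.5-pro', {
    maxTokens: 1024,
    messages: [{ role: 'user', content: 'Why is the sky blue?' }]
  }, (chunk) => process.stdout.write(chunk.content ?? '')) // chunk.content is assumed
  console.log('\nDone:', response.completeMessage.length, 'chars')
  service.stop()
}

main()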
13 changes: 12 additions & 1 deletion src/functions.js
@@ -196,4 +196,15 @@ async function convertFunctionsToGemini (functions) {
   return { result, metadata: openai.metadata }
 }
 
-module.exports = { Arg, Desc, processFunctions, convertArgToOpenAI, convertFunctionsToOpenAI, convertFunctionsToGemini }
+async function convertFunctionsToGoogleAIStudio (functions) {
+  const openai = await convertFunctionsToOpenAI(functions)
+  const result = openai.result.map((e) => {
+    return [e.function.name, {
+      description: e.function.description,
+      parameters: e.function.parameters
+    }]
+  })
+  return { result: Object.fromEntries(result), metadata: openai.metadata }
+}
+
+module.exports = { Arg, Desc, processFunctions, convertArgToOpenAI, convertFunctionsToOpenAI, convertFunctionsToGemini, convertFunctionsToGoogleAIStudio }
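The Google AI Studio converter reuses the OpenAI schema and simply re-keys the array of tools by function name. Illustrative input and output shapes, assumed from the mapping above rather than captured from a real run:

// From convertFunctionsToOpenAI:
// [{ function: { name: 'getWeather', description: 'Get weather for a city',
//                parameters: { type: 'object', properties: { city: { type: 'string' } } } } }]
//
// After convertFunctionsToGoogleAIStudio:
// { getWeather: { description: 'Get weather for a city',
//                 parameters: { type: 'object', properties: { city: { type: 'string' } } } } }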
