-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathmodel.js
90 lines (74 loc) · 2.37 KB
/
model.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import { GoogleGenerativeAI } from 'https://cdn.jsdelivr.net/npm/@google/generative-ai/+esm'
import { FilesetResolver, LlmInference } from 'https://cdn.jsdelivr.net/npm/@mediapipe/tasks-genai'
import gemini_api_key from './key.js'
/**
 * Connects to a hosted Gemini model through the Google Generative AI SDK.
 *
 * @param {object} [options={}] - Extra settings for `getGenerativeModel`;
 *   they override the default `{ model: 'gemini-1.5-flash' }`.
 * @returns {Promise<{generateResponse: Function}>} An object whose
 *   `generateResponse(content, display)` streams the reply: `display(text)`
 *   is called once per streamed chunk, then `display('', true)` once the
 *   stream has ended. On failure it rejects with the error message as a
 *   plain string (not an Error object).
 */
const connectRemoteModel = async (options = {}) => {
  const genAI = new GoogleGenerativeAI(gemini_api_key)
  const defaultOption = { model: 'gemini-1.5-flash' }
  const model = genAI.getGenerativeModel({
    ...defaultOption,
    ...options
  })

  const generateResponse = async (content, display) => {
    try {
      const result = await model.generateContentStream([content])
      for await (const chunk of result.stream) {
        display(chunk.text())
      }
      // Empty final chunk flagged as complete, so a listener with a
      // (partialResults, complete) signature can tell the stream ended.
      display('', true)
    } catch (error) {
      // Callers receive the message string. Fall back to the raw value when
      // something other than an Error was thrown, instead of rethrowing
      // `undefined` (string thrown) or crashing on `null.message`.
      throw error?.message ?? String(error)
    }
  }

  return { generateResponse }
}
/**
 * Loads an on-device Gemma model with MediaPipe's LLM Inference task.
 *
 * @param {object} [options={}] - LlmInference options. Top-level keys override
 *   the defaults below. This is a shallow merge, so passing `baseOptions`
 *   replaces the default `baseOptions` object entirely.
 * @returns {Promise<{generateResponse: Function}>} An object whose
 *   `generateResponse(content, display)` forwards to the loaded model's
 *   `generateResponse`, passing `display` through unchanged (presumably used
 *   by MediaPipe as the streaming progress listener — see its docs).
 *   Any setup failure rejects with the error message as a plain string.
 */
const connectLocalModel = async (options = {}) => {
  const defaultOption = {
    baseOptions: { modelAssetPath: '/models/gemma-1.1-2b-it-gpu-int4.bin' },
    maxTokens: 5120,
    randomSeed: 1,
    topK: 1,
    temperature: 1.0
  }
  try {
    // Resolved inside the try so that a failed WASM fileset load is reported
    // the same way (as a message string) as a failed model load.
    const wasmFileset = await FilesetResolver.forGenAiTasks('https://cdn.jsdelivr.net/npm/@mediapipe/tasks-genai/wasm')
    const model = await LlmInference.createFromOptions(wasmFileset, {
      ...defaultOption,
      ...options
    })
    const generateResponse = (content, display) => model.generateResponse(content, display)
    return { generateResponse }
  } catch (error) {
    // Same normalisation as the remote connector: always a string, even when
    // a non-Error value was thrown.
    throw error?.message ?? String(error)
  }
}
/**
 * Shows `output` sized to its content, or hides it when it holds no text.
 *
 * @param {object} output - Element exposing `textContent`, `style` and
 *   `scrollHeight` (presumably a DOM node such as a textarea).
 */
const autoResizeOutput = (output) => {
  if (!output.textContent) {
    output.style.display = 'none'
    return
  }
  output.style.display = 'block'
  // Collapse to 'auto' before measuring so a previously set explicit height
  // does not influence the scrollHeight reading below.
  output.style.height = 'auto'
  output.style.height = `${output.scrollHeight}px`
}
/**
 * Builds a streaming listener that appends model output to `output`.
 *
 * The returned `(partialResults, complete)` function turns the output text
 * blue and appends `partialResults` with every `*` removed (presumably to
 * drop Markdown emphasis markers). When `complete` is truthy, it re-enables
 * `button` and restores its label from `button.title`. If `complete` is
 * itself a function, it receives the full text and the output turns red.
 * Otherwise `callback`, when it is a function, receives the full text.
 * `autoResizeOutput` runs after every call.
 *
 * @param {object} output - Element accumulating the response text.
 * @param {object} button - Trigger button, re-enabled on completion.
 * @param {Function} [callback] - Optional handler for the final text.
 * @returns {Function} Listener `(partialResults, complete) => void`.
 */
const displayResult = (output, button, callback) => {
  // Re-enables the button and hands the final text to whichever handler applies.
  const finish = (complete) => {
    button.disabled = false
    button.innerText = button.title
    const finalText = output.textContent
    if (typeof complete === 'function') {
      complete(finalText)
      output.style.color = 'red'
      return
    }
    if (typeof callback === 'function') {
      callback(finalText)
    }
  }

  return (partialResults, complete) => {
    output.style.color = 'blue'
    output.textContent += partialResults.replaceAll('*', '')
    if (complete) {
      finish(complete)
    }
    autoResizeOutput(output)
  }
}
/**
 * Resets the UI before a new generation.
 *
 * Disables `button`, shows `message` as its label, empties `output`, and then
 * calls `autoResizeOutput` (which hides an output that has no text).
 *
 * @param {object} output - Element holding the response text.
 * @param {object} button - Trigger button to lock.
 * @param {string} message - Label shown on the button while it is disabled.
 */
const clearResult = (output, button, message) => {
  Object.assign(button, { disabled: true, innerText: message })
  output.textContent = ''
  autoResizeOutput(output)
}
export { connectRemoteModel, connectLocalModel, displayResult, clearResult }