Skip to content
This repository has been archived by the owner on Jun 12, 2024. It is now read-only.

Commit

Permalink
Add greetings and input labels for AIs
Browse files Browse the repository at this point in the history
  • Loading branch information
janpawellek committed Sep 13, 2023
1 parent 489acf6 commit 1e72274
Show file tree
Hide file tree
Showing 16 changed files with 135 additions and 26 deletions.
12 changes: 11 additions & 1 deletion backaind/aifile.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,18 +73,28 @@ def add_ai(aifile_path):
aifile = read_aifile_from_path(aifile_path)
name = aifile["name"]
input_keys = list(get_input_keys(aifile))
input_labels = aifile.get("input_labels")
chain = aifile["chain"]
greeting = aifile.get("greeting")

existing_ai = db.session.query(Ai).filter_by(name=name).first()

if existing_ai is None:
new_ai = Ai(name=name, input_keys=input_keys, chain=chain)
new_ai = Ai(
name=name,
input_keys=input_keys,
input_labels=input_labels,
chain=chain,
greeting=greeting,
)
db.session.add(new_ai)
db.session.commit()
click.echo(f"Added {name}. Say hello!")
else:
existing_ai.input_keys = input_keys
existing_ai.input_labels = input_labels
existing_ai.chain = chain
existing_ai.greeting = greeting
existing_ai.name = name
db.session.commit()
click.echo(f"Updated {name}. Say hello!")
Expand Down
2 changes: 2 additions & 0 deletions backaind/ainteraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ def index():
"id": aifile.id,
"name": aifile.name,
"input_keys": aifile.input_keys,
"input_labels": aifile.input_labels,
"greeting": aifile.greeting,
}
)
knowledge_entries = db.session.query(Knowledge).all()
Expand Down
39 changes: 30 additions & 9 deletions backaind/api/ai.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,16 @@ def validate(ai_json):
400,
)
)
if "input_labels" in ai_json and not isinstance(ai_json["input_labels"], dict):
abort(
make_response(
jsonify(
error='The property "input_labels" has to be an object'
+ " assigning input keys to labels."
),
400,
)
)
if not "chain" in ai_json:
abort(make_response(jsonify(error='The property "chain" is required.'), 400))
if not isinstance(ai_json["chain"], dict):
Expand All @@ -38,6 +48,12 @@ def validate(ai_json):
jsonify(error='The property "chain" has to be a chain object.'), 400
)
)
if "greeting" in ai_json and not isinstance(ai_json["greeting"], str):
abort(
make_response(
jsonify(error='The property "greeting" has to be a string.'), 400
)
)


@bp.route("/", methods=["GET"])
Expand All @@ -64,8 +80,16 @@ def create_ai():

name = request.json["name"]
input_keys = request.json["input_keys"]
input_labels = request.json.get("input_labels")
chain = request.json["chain"]
new_ai = Ai(name=name, input_keys=input_keys, chain=chain)
greeting = request.json.get("greeting")
new_ai = Ai(
name=name,
input_keys=input_keys,
input_labels=input_labels,
chain=chain,
greeting=greeting,
)
db.session.add(new_ai)
db.session.commit()
return (
Expand All @@ -83,22 +107,19 @@ def update_ai(ai_id):

name = request.json["name"]
input_keys = request.json["input_keys"]
input_labels = request.json.get("input_labels")
chain = request.json["chain"]
greeting = request.json.get("greeting")

existing_ai = db.get_or_404(Ai, ai_id)
existing_ai.name = name
existing_ai.input_keys = input_keys
existing_ai.input_labels = input_labels
existing_ai.chain = chain
existing_ai.greeting = greeting
db.session.commit()
reset_global_chain(ai_id)
return jsonify(
{
"id": ai_id,
"name": name,
"input_keys": request.json["input_keys"],
"chain": request.json["chain"],
}
)
return existing_ai.as_dict()


@bp.route("/<int:ai_id>", methods=["DELETE"])
Expand Down
4 changes: 4 additions & 0 deletions backaind/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,15 +20,19 @@ class Ai(db.Model):
id = db.Column(db.Integer, primary_key=True, autoincrement=True)
name = db.Column(db.String, nullable=False)
input_keys = db.Column(db.JSON, nullable=False)
input_labels = db.Column(db.JSON, nullable=True)
chain = db.Column(db.JSON, nullable=False)
greeting = db.Column(db.String, nullable=True)

def as_dict(self):
"""Return the model as a dictionary"""
return {
"id": self.id,
"name": self.name,
"input_keys": self.input_keys,
"input_labels": self.input_labels,
"chain": self.chain,
"greeting": self.greeting,
}


Expand Down
2 changes: 1 addition & 1 deletion backaind/static/ownai-ainteraction.css

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion backaind/static/ownai-ainteraction.js

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion backaind/static/ownai-workshop.js

Large diffs are not rendered by default.

10 changes: 9 additions & 1 deletion frontaind/components/ainteraction/Ainteraction.vue
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,14 @@
:disabled="selectionDisabled"
@select-knowledge="selectKnowledge"
/>
<MessageHistory :messages="messages" @clear-messages="clearMessages" />
<MessageHistory
:greeting="selectedAi?.greeting"
:messages="messages"
@clear-messages="clearMessages"
/>
<MessageInput
v-if="selectedAi && (!needsKnowledge || selectedKnowledge)"
:label="textInputLabel"
@send-message="sendMessage"
/>
</template>
Expand Down Expand Up @@ -50,6 +55,9 @@ const selectKnowledge = (knowledge: BasicKnowledge) => {
const needsKnowledge = computed(
() => !!selectedAi.value?.input_keys.includes("input_knowledge")
);
const textInputLabel = computed(
() => selectedAi.value?.input_labels?.input_text || "Send a message"
);
const selectionDisabled = ref<boolean>(false);
const messages = ref<Message[]>([]);
Expand Down
8 changes: 7 additions & 1 deletion frontaind/components/ainteraction/MessageHistory.vue
Original file line number Diff line number Diff line change
@@ -1,4 +1,9 @@
<template>
<div v-if="greeting" class="card mb-2 bg-light">
<div class="card-body">
{{ greeting }}
</div>
</div>
<div
v-for="message in messages"
class="card mb-2"
Expand All @@ -20,7 +25,8 @@
<script setup lang="ts">
import type { Message } from "@/types/ainteraction/Message";
const { messages } = defineProps<{
const { greeting, messages } = defineProps<{
greeting?: string;
messages: Message[];
}>();
const emit = defineEmits(["clear-messages"]);
Expand Down
6 changes: 5 additions & 1 deletion frontaind/components/ainteraction/MessageInput.vue
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
<textarea
class="form-control"
v-model="messageInput"
placeholder="Send a message"
:placeholder="label"
required
@keydown="checkForSubmit"
></textarea>
Expand All @@ -30,6 +30,10 @@
<script setup lang="ts">
import { ref } from "vue";
const { label } = defineProps<{
label: string;
}>();
const messageInput = ref("");
const emit = defineEmits(["send-message"]);
Expand Down
28 changes: 19 additions & 9 deletions frontaind/components/workshop/AiWorkshop.vue
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ import { indentationMarkers } from "@replit/codemirror-indentation-markers";
import Workshop from "./Workshop.vue";
import NewAiDropdown from "./ai/NewAiDropdown.vue";
import type { Ai } from "@/types/Ai";
import type { Aifile } from "@/types/Aifile";
import type { IdName } from "@/types/IdName";
import type { Optional } from "@/types/Optional";
Expand All @@ -94,15 +95,18 @@ const aiToJson = (ai: Ai | undefined) => {
if (!ai) {
return "";
}
return JSON.stringify(
{
name: ai.name,
aifileversion: 1,
chain: ai.chain,
},
null,
2
);
const aifile: Aifile = {
name: ai.name,
aifileversion: 1,
chain: ai.chain,
};
if (ai.input_labels) {
aifile["input_labels"] = ai.input_labels;
}
if (ai.greeting) {
aifile["greeting"] = ai.greeting;
}
return JSON.stringify(aifile, null, 2);
};
const jsonToAi = (id: number | undefined, jsonString: string) => {
Expand Down Expand Up @@ -133,6 +137,12 @@ const jsonToAi = (id: number | undefined, jsonString: string) => {
input_keys: Array.from(inputKeys),
chain: json["chain"],
};
if (json["input_labels"]) {
ai.input_labels = json["input_labels"];
}
if (json["greeting"]) {
ai.greeting = json["greeting"];
}
if (id !== undefined) {
ai.id = id;
}
Expand Down
2 changes: 2 additions & 0 deletions frontaind/types/Ai.ts
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,7 @@ import { IdName } from "./IdName";

export interface Ai extends IdName {
input_keys: string[];
input_labels?: { [key: string]: string };
chain: object;
greeting?: string;
}
3 changes: 3 additions & 0 deletions frontaind/types/Aifile.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
import { Ai } from "./Ai";

export type Aifile = Omit<Ai, "id" | "input_keys"> & { aifileversion: number };
5 changes: 4 additions & 1 deletion frontaind/types/ainteraction/BasicAi.ts
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import { Ai } from "@/types/Ai";

export type BasicAi = Pick<Ai, "id" | "name" | "input_keys">;
export type BasicAi = Pick<
Ai,
"id" | "name" | "input_keys" | "input_labels" | "greeting"
>;
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
"""Add greeting and input_labels columns
Revision ID: 16c3aa6e11aa
Revises: 58ae6bdf9d33
Create Date: 2023-09-13 12:08:56.515625
"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "16c3aa6e11aa"
down_revision = "58ae6bdf9d33"
branch_labels = None
depends_on = None


def upgrade():
with op.batch_alter_table("ai", schema=None) as batch_op:
batch_op.add_column(sa.Column("input_labels", sa.JSON(), nullable=True))
batch_op.add_column(sa.Column("greeting", sa.String(), nullable=True))


def downgrade():
with op.batch_alter_table("ai", schema=None) as batch_op:
batch_op.drop_column("greeting")
batch_op.drop_column("input_labels")
8 changes: 8 additions & 0 deletions tests/api/test_ai_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,14 @@ def test_delete_ai(client, auth, app):
{"name": "Test", "input_keys": [], "chain": []},
'The property "chain" has to be a chain object.',
),
(
{"name": "Test", "input_keys": [], "chain": {}, "input_labels": "Test"},
'The property "input_labels" has to be an object assigning input keys to labels.',
),
(
{"name": "Test", "input_keys": [], "chain": {}, "greeting": True},
'The property "greeting" has to be a string.',
),
),
)
def test_validation(client, auth, data, message):
Expand Down

0 comments on commit 1e72274

Please sign in to comment.