Skip to content

Commit

Permalink
Share state with jupyter lab
Browse files Browse the repository at this point in the history
  • Loading branch information
jakubno committed May 14, 2024
1 parent 911028d commit 2ae7297
Show file tree
Hide file tree
Showing 8 changed files with 163 additions and 53 deletions.
122 changes: 102 additions & 20 deletions python/e2b_code_interpreter/main.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
from __future__ import annotations

import json
import logging
import threading
import uuid

import requests

from concurrent.futures import Future
Expand Down Expand Up @@ -103,6 +106,47 @@ def exec_cell(
f"Received result from kernel {kernel_id}, session_id: {session_id}, result: {result}"
)

nb = self._sandbox.filesystem.read(f"/home/user/default.ipynb", timeout=timeout)
nb_parsed = json.loads(nb)
cell = {
"cell_type": "code",
"metadata": {},
"source": code,
}

outputs = []
if result.logs.stdout:
outputs.append(
{"output_type": "stream", "name": "stdout", "text": result.logs.stdout}
)
if result.logs.stderr:
outputs.append(
{"output_type": "stream", "name": "stderr", "text": result.logs.stderr}
)

if result.results:
outputs = [
{
"output_type": (
"execute_result" if r.is_main_result else "display_data"
),
"data": r.raw,
"metadata": {},
}
for r in result.results
]

cell["execution_count"] = result.execution_count
cell["outputs"] = outputs

if nb_parsed["cells"] and not nb_parsed["cells"][-1]["source"]:
nb_parsed["cells"][-1] = cell
else:
nb_parsed["cells"].append(cell)

self._sandbox.filesystem.write(
f"/home/user/default.ipynb", json.dumps(nb_parsed), timeout=timeout
)
return result

@property
Expand All @@ -120,8 +164,9 @@ def default_kernel_id(self) -> str:

def create_kernel(
self,
cwd: str = "/home/user",
kernel_name: Optional[str] = None,
name: str,
path: str = "/home/user",
kernel_name: str = "python3",
timeout: Optional[float] = TIMEOUT,
) -> str:
"""
Expand All @@ -132,32 +177,60 @@ def create_kernel(
Once the kernel is created, this method establishes a WebSocket connection to the new kernel for
real-time communication.
:param cwd: Sets the current working directory for the kernel. Defaults to "/home/user".
:param kernel_name:
Specifies which kernel should be used, useful if you have multiple kernel types.
If not provided, the default kernel will be used.
:param name: Name of the Jupyter session that is created together with the kernel
:param path: Sets the current working directory for the kernel. Defaults to "/home/user".
:param kernel_name: Specifies which kernel type should be used; useful if you have multiple kernel types. Defaults to "python3".
:param timeout: Timeout for the kernel creation request.
:return: Kernel id of the created kernel
"""
data = {"cwd": cwd}
if kernel_name:
data["kernel_name"] = kernel_name

x = {
"metadata": {
"signature": "hex-digest",
"kernel_info": {"name": kernel_name},
"language_info": {
"name": kernel_name,
"version": "3.10.14",
}, # TODO: get version
},
"nbformat": 4,
"nbformat_minor": 0,
"cells": [],
}

self._sandbox.filesystem.write(
f"/home/user/default.ipynb", json.dumps(x), timeout=timeout
)
self._sandbox.process.start("chmod 777 /home/user/default.ipynb")

data = {
"name": name,
"kernel": {"name": kernel_name},
"notebook": {"name": "default.ipynb"},
"path": path,
"type": "notebook",
}

logger.debug(f"Creating kernel with data: {data}")

response = requests.post(
f"{self._sandbox.get_protocol()}://{self._sandbox.get_hostname(8888)}/api/kernels",
f"{self._sandbox.get_protocol()}://{self._sandbox.get_hostname(8888)}/api/sessions",
json=data,
timeout=timeout,
)
if not response.ok:
raise KernelException(f"Failed to create kernel: {response.text}")

kernel_id = response.json()["id"]
logger.debug(f"Created kernel {kernel_id}")
response_data = response.json()
kernel_id = response_data["kernel"]["id"]
session_id = response_data["id"]

logger.debug(f"Created kernel {kernel_id}, session {session_id}")

threading.Thread(
target=self._connect_to_kernel_ws, args=(kernel_id, timeout)
target=self._connect_to_kernel_ws, args=(kernel_id, session_id, timeout)
).start()

return kernel_id

def restart_kernel(
Expand Down Expand Up @@ -243,22 +316,30 @@ def close(self):
ws.result().close()

def _connect_to_kernel_ws(
self, kernel_id: str, timeout: Optional[float] = TIMEOUT
self,
kernel_id: str,
session_id: Optional[str],
timeout: Optional[float] = TIMEOUT,
) -> JupyterKernelWebSocket:
"""
Establishes a WebSocket connection to a specified Jupyter kernel.
:param kernel_id: Kernel id
:param session_id: Session id
:param timeout: The timeout for the kernel connection request.
:return: Websocket connection
"""
if not session_id:
session_id = uuid.uuid4()

logger.debug(f"Connecting to kernel's ({kernel_id}) websocket")
future = Future()
self._connected_kernels[kernel_id] = future

ws = JupyterKernelWebSocket(
url=f"{self._sandbox.get_protocol('ws')}://{self._sandbox.get_hostname(8888)}/api/kernels/{kernel_id}/channels",
session_id=session_id,
)
ws.connect(timeout=timeout)
logger.debug(f"Connected to kernel's ({kernel_id}) websocket.")
Expand All @@ -276,15 +357,16 @@ def _start_connecting_to_default_kernel(
logger.debug("Starting to connect to the default kernel")

def setup_default_kernel():
kernel_id = self._sandbox.filesystem.read(
"/root/.jupyter/kernel_id", timeout=timeout
session_info = self._sandbox.filesystem.read(
"/root/.jupyter/.session_info", timeout=timeout
)
if kernel_id is None and not self._sandbox.is_open:
if session_info is None and not self._sandbox.is_open:
return

kernel_id = kernel_id.strip()
data = json.loads(session_info)
kernel_id = data["kernel"]["id"]
session_id = data["id"]
logger.debug(f"Default kernel id: {kernel_id}")
self._connect_to_kernel_ws(kernel_id, timeout=timeout)
self._connect_to_kernel_ws(kernel_id, session_id, timeout=timeout)
self._kernel_id_set.set_result(kernel_id)

threading.Thread(target=setup_default_kernel).start()
9 changes: 5 additions & 4 deletions python/e2b_code_interpreter/messaging.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ class JupyterKernelWebSocket(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)

url: str
session_id: str

_cells: Dict[str, CellExecution] = {}
_waiting_for_replies: Dict[str, DeferredFuture] = PrivateAttr(default_factory=dict)
Expand Down Expand Up @@ -103,14 +104,13 @@ def connect(self, timeout: float = TIMEOUT):

logger.debug("WebSocket started")

@staticmethod
def _get_execute_request(msg_id: str, code: str) -> str:
def _get_execute_request(self, msg_id: str, code: str) -> str:
return json.dumps(
{
"header": {
"msg_id": msg_id,
"username": "e2b",
"session": str(uuid.uuid4()),
"session": self.session_id,
"msg_type": "execute_request",
"version": "5.3",
},
Expand All @@ -119,7 +119,7 @@ def _get_execute_request(msg_id: str, code: str) -> str:
"content": {
"code": code,
"silent": False,
"store_history": False,
"store_history": True,
"user_expressions": {},
"allow_stdin": False,
},
Expand Down Expand Up @@ -239,6 +239,7 @@ def _receive_message(self, data: dict):

elif data["msg_type"] == "execute_input":
logger.debug(f"Input accepted for {parent_msg_ig}")
cell.partial_result.execution_count = data["content"]["execution_count"]
cell.input_accepted = True
else:
logger.warning(f"[UNHANDLED MESSAGE TYPE]: {data['msg_type']}")
Expand Down
4 changes: 3 additions & 1 deletion python/e2b_code_interpreter/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,8 @@ class Config:
"Logs printed to stdout and stderr during execution."
error: Optional[Error] = None
"Error object if an error occurred, None otherwise."
execution_count: Optional[int] = None
"Execution count reported by the kernel for this cell, None if the kernel has not reported one."

@property
def text(self) -> Optional[str]:
Expand Down Expand Up @@ -249,7 +251,7 @@ def serialize_results(results: List[Result]) -> List[Dict[str, str]]:
serialized = []
for result in results:
serialized_dict = {key: result[key] for key in result.formats()}
serialized_dict['text'] = result.text
serialized_dict["text"] = result.text
serialized.append(serialized_dict)
return serialized

Expand Down
27 changes: 15 additions & 12 deletions python/example.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,13 @@
load_dotenv()


code = """
import matplotlib.pyplot as plt
code = """class Welcome:
def _repr_html_(self):
return "<h1>Welcome to the code interpreter!</h1>"
Welcome()
"""

graph = """import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 20, 100)
Expand All @@ -14,17 +19,15 @@
plt.plot(x, y)
plt.show()
x = np.linspace(0, 10, 100)
plt.plot(x, y)
plt.show()
import pandas
pandas.DataFrame({"a": [1, 2, 3]})
"""

with CodeInterpreter() as sandbox:
result = sandbox.notebook.exec_cell(code)

print(result.result.formats())
print(len(result.display_data))
with CodeInterpreter(template="code-interpreter-stateful-lab") as sandbox:
print(f"https://{sandbox.get_hostname(8888)}/doc/tree/RTC:default.ipynb")
sandbox.notebook.exec_cell(code)
sandbox.notebook.exec_cell(graph)
while True:
r = sandbox.notebook.exec_cell(input("code: "))
if r.results:
print(r.results[0].text)
2 changes: 1 addition & 1 deletion template/e2b.Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -17,4 +17,4 @@ RUN mkdir -p /home/user/.ipython/profile_default
COPY ipython_kernel_config.py /home/user/.ipython/profile_default/

COPY ./start-up.sh /home/user/.jupyter/
RUN chmod +x /home/user/.jupyter/start-up.sh
RUN chmod +x /home/user/.jupyter/start-up.sh
10 changes: 5 additions & 5 deletions template/e2b.toml
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
# This is a config for E2B sandbox template.
# You can use 'template_id' (1tsfj5yvigmgc5gmgqz2) or 'template_name (code-interpreter-stateful) from this config to spawn a sandbox:
# You can use 'template_id' (w9nfeev5mm8a64f8tqw2) or 'template_name' (code-interpreter-stateful-lab) from this config to spawn a sandbox:

# Python SDK
# from e2b import Sandbox
# sandbox = Sandbox(template='code-interpreter-stateful')
# sandbox = Sandbox(template='code-interpreter-stateful-lab')

# JS SDK
# import { Sandbox } from 'e2b'
# const sandbox = await Sandbox.create({ template: 'code-interpreter-stateful' })
# const sandbox = await Sandbox.create({ template: 'code-interpreter-stateful-lab' })

memory_mb = 1_024
start_cmd = "/home/user/.jupyter/start-up.sh"
dockerfile = "e2b.Dockerfile"
template_name = "code-interpreter-stateful"
template_id = "1tsfj5yvigmgc5gmgqz2"
template_name = "code-interpreter-stateful-lab"
template_id = "w9nfeev5mm8a64f8tqw2"
7 changes: 4 additions & 3 deletions template/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
# Jupyter server requirements
jupyter-server==2.13.0
ipykernel==6.29.3
ipython==8.22.2
jupyter-collaboration==3.0.0a2
jupyterlab==4.2.0

# Other packages
aiohttp==3.9.3
Expand All @@ -18,10 +19,10 @@ opencv-python==4.9.0.80
openpyxl==3.1.2
pandas==1.5.3
plotly==5.19.0
pytest==8.1.0
pytest==8.2.0
python-docx==1.1.0
pytz==2024.1
requests==2.26.0
requests==2.31.0
scikit-image==0.22.0
scikit-learn==1.4.1.post1
scipy==1.12.0
Expand Down
Loading

0 comments on commit 2ae7297

Please sign in to comment.