Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Release 0.60.0 #1373

Merged
merged 27 commits into from
Feb 7, 2025
Merged
Show file tree
Hide file tree
Changes from 26 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
56608ff
Fix `patch` behavior for trusses using python DX (#1337)
nnarayen Jan 27, 2025
ff95f73
Bump truss version to 0.9.60rc001 (#1345)
nnarayen Jan 27, 2025
934bcbc
adds resources.node_count to truss spec (#1344)
rcano-baseten Jan 27, 2025
bd655ef
bump truss-transfer to 0.0.1 (#1342)
michaelfeil Jan 29, 2025
ad2fcd6
Add trace ID logging to failed chainlet RPC log - related BT-13465 (#…
marius-baseten Jan 29, 2025
ddbbb3c
Add H200 accelerator type (#1351)
nnarayen Jan 29, 2025
7812b1d
Bug bash improvements to Python DX, better error messages (#1346)
nnarayen Jan 29, 2025
f26589c
Better example code for models using chains framework (#1347)
nnarayen Jan 29, 2025
105eb1d
Update BEI Image (#1354)
michaelfeil Jan 31, 2025
070de58
raising ValueError for py38 or less for truss supported version (#1350)
sumerjoshi Feb 1, 2025
2b96fcb
Update pyproject.toml (#1356)
michaelfeil Feb 2, 2025
53db85f
update: max num tokens for bei (#1357)
michaelfeil Feb 3, 2025
9eafca8
add bei specific user migrations in trt-llm config (#1361)
michaelfeil Feb 3, 2025
0f15e3c
add max gib 16GB (#1362)
michaelfeil Feb 3, 2025
b5ed43b
Use PurePosixPath. (#1358)
squidarth Feb 4, 2025
0dec2a1
rename: b10cache (#1365)
michaelfeil Feb 4, 2025
da548ae
[TaT] Pre-requisites for flexible builds. (#1352)
marius-baseten Feb 4, 2025
14f0f55
Clean Chains Stack Traces and consolidate logging config. Fixes BT-13…
marius-baseten Feb 4, 2025
b8df346
add truss trtllm config (#1366)
michaelfeil Feb 4, 2025
f7be3e0
Add `--python-dx` flag to `truss init` (#1367)
nnarayen Feb 5, 2025
27dced4
Multinode Cleanup (#1360)
rcano-baseten Feb 6, 2025
aee0bd8
Cleanup `init` and `init_directory` usage (#1368)
nnarayen Feb 6, 2025
87c83d9
Introduce truss server passthrough for OpenAI methods (#1364)
nnarayen Feb 6, 2025
911dddf
Fix integration test expectations (#1370)
nnarayen Feb 6, 2025
d5db001
Fix streaming for OpenAI clients (#1371)
nnarayen Feb 7, 2025
4db445b
Bump version to 0.60.0
basetenbot Feb 7, 2025
ef2e43b
Fix integration test for span cleanup (#1374)
nnarayen Feb 7, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
808 changes: 418 additions & 390 deletions poetry.lock

Large diffs are not rendered by default.

5 changes: 3 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "truss"
version = "0.9.59"
version = "0.60.0"
description = "A seamless bridge from model development to model delivery"
license = "MIT"
readme = "README.md"
Expand All @@ -24,6 +24,7 @@ requires-poetry = ">=2.0"

[tool.poetry.scripts]
truss = "truss.cli.cli:truss_cli"
truss-docker-build-setup = "truss.contexts.docker_build_setup:docker_build_setup"

[tool.poetry.urls]
"Homepage" = "https://truss.baseten.co"
Expand Down Expand Up @@ -165,7 +166,7 @@ numpy = ">=1.23.5"
opentelemetry-api = ">=1.25.0"
opentelemetry-exporter-otlp = ">=1.25.0"
opentelemetry-sdk = ">=1.25.0"
truss_transfer="0.0.1rc4"
truss_transfer="0.0.1"
uvicorn = ">=0.24.0"
uvloop = ">=0.17.0"

Expand Down
10 changes: 5 additions & 5 deletions smoketests/test_chains.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from truss.remote.baseten.utils import status as status_utils

from truss_chains import definitions
from truss_chains.remote_chainlet import stub
from truss_chains.remote_chainlet import stub, utils

backend_env_domain = "staging.baseten.co"
BASETEN_API_KEY = os.environ["BASETEN_API_KEY_STAGING"]
Expand Down Expand Up @@ -153,7 +153,7 @@ def test_itest_chain_publish(prepare) -> None:
# Test regular (JSON) invocation.
chain_stub = make_stub(url, definitions.RPCOptions(timeout_sec=10))
trace_parent = generate_traceparent()
with stub.trace_parent_raw(trace_parent):
with utils.trace_parent_raw(trace_parent):
result = chain_stub.predict_sync({"length": 30, "num_partitions": 3})

expected = [
Expand All @@ -169,7 +169,7 @@ def test_itest_chain_publish(prepare) -> None:
invocation_times_sec = []
for i in range(10):
t0 = time.perf_counter()
with stub.trace_parent_raw(trace_parent):
with utils.trace_parent_raw(trace_parent):
chain_stub.predict_sync({"length": 30, "num_partitions": 3})
invocation_times_sec.append(time.perf_counter() - t0)

Expand All @@ -182,7 +182,7 @@ def test_itest_chain_publish(prepare) -> None:
url, definitions.RPCOptions(timeout_sec=10, use_binary=True)
)
trace_parent = generate_traceparent()
with stub.trace_parent_raw(trace_parent):
with utils.trace_parent_raw(trace_parent):
result = chain_stub_binary.predict_sync({"length": 30, "num_partitions": 3})

expected = [
Expand All @@ -198,7 +198,7 @@ def test_itest_chain_publish(prepare) -> None:
invocation_times_sec = []
for i in range(10):
t0 = time.perf_counter()
with stub.trace_parent_raw(trace_parent):
with utils.trace_parent_raw(trace_parent):
chain_stub_binary.predict_sync({"length": 30, "num_partitions": 3})
invocation_times_sec.append(time.perf_counter() - t0)

Expand Down
7 changes: 7 additions & 0 deletions truss-chains/tests/import/model_without_inheritance.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
class ClassWithoutModelInheritance:
    """Test fixture: a model-like class that deliberately does NOT inherit
    from the chains `ModelBase`, used to verify the importer rejects it."""

    def __init__(self):
        # Running total of all increments received via `predict`.
        self._call_count = 0

    async def predict(self, call_count_increment: int) -> int:
        """Add `call_count_increment` to the running total and return it."""
        updated_total = self._call_count + call_count_increment
        self._call_count = updated_total
        return updated_total
19 changes: 19 additions & 0 deletions truss-chains/tests/import/standalone_with_multiple_entrypoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import truss_chains as chains


class FirstModel(chains.ModelBase):
    """First of two entrypoint candidates in this module; together with
    `SecondModel` it exercises the multiple-entrypoint error path."""

    def __init__(self):
        # Accumulated sum of increments passed to `predict` so far.
        self._call_count = 0

    async def predict(self, call_count_increment: int) -> int:
        """Increase the stored count by `call_count_increment`; return the new value."""
        new_count = self._call_count + call_count_increment
        self._call_count = new_count
        return new_count


class SecondModel(chains.ModelBase):
    """Second entrypoint candidate; its presence alongside `FirstModel`
    must make the importer raise on ambiguous entrypoints."""

    def __init__(self):
        # Accumulated sum of increments passed to `predict` so far.
        self._call_count = 0

    async def predict(self, call_count_increment: int) -> int:
        """Increase the stored count by `call_count_increment`; return the new value."""
        new_count = self._call_count + call_count_increment
        self._call_count = new_count
        return new_count
6 changes: 6 additions & 0 deletions truss-chains/tests/itest_chain/itest_chain.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
import logging
import math

from user_package import shared_chainlet
from user_package.nested_package import io_types

import truss_chains as chains

logger = logging.getLogger(__name__)

IMAGE_BASETEN = chains.DockerImage(
base_image=chains.BasetenImage.PY310,
pip_requirements_file=chains.make_abs_path_here("requirements.txt"),
Expand Down Expand Up @@ -103,6 +106,7 @@ def __init__(
text_to_num: TextToNum = chains.depends(TextToNum),
context=chains.depends_context(),
) -> None:
logging.info("User log root during load.")
self._context = context
self._data_generator = data_generator
self._data_splitter = splitter
Expand All @@ -117,6 +121,8 @@ async def run_remote(
),
simple_default_arg: list[str] = ["a", "b"],
) -> tuple[int, str, int, shared_chainlet.SplitTextOutput, list[str]]:
logging.info("User log root.")
logger.info("User log module.")
data = self._data_generator.run_remote(length)
text_parts, number, items = await self._data_splitter.run_remote(
io_types.SplitTextInput(
Expand Down
36 changes: 25 additions & 11 deletions truss-chains/tests/test_e2e.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,9 @@
def test_chain():
with ensure_kill_all():
chain_root = TEST_ROOT / "itest_chain" / "itest_chain.py"
with framework.import_target(chain_root, "ItestChain") as entrypoint:
with framework.ChainletImporter.import_target(
chain_root, "ItestChain"
) as entrypoint:
options = definitions.PushOptionsLocalDocker(
chain_name="integration-test", use_local_chains_src=True
)
Expand All @@ -37,7 +39,7 @@ def test_chain():
response = requests.post(
url,
json={"length": 30, "num_partitions": 3},
headers={"traceparent": "TEST TEST TEST"},
headers={"traceparent": "TRACE_ID"},
)
print(response.content)
assert response.status_code == 200
Expand Down Expand Up @@ -70,7 +72,10 @@ def test_chain():

# Test with errors.
response = requests.post(
url, json={"length": 300, "num_partitions": 3}, stream=True
url,
json={"length": 300, "num_partitions": 3},
stream=True,
headers={"traceparent": "TRACE_ID"},
)
print(response)
assert response.status_code == 500
Expand All @@ -86,12 +91,12 @@ def test_chain():
File \".*?/itest_chain\.py\", line \d+, in _accumulate_parts
value \+= self\._text_to_num\.run_remote\(part\)
ValueError: \(showing chained remote errors, root error at the bottom\)
├─ Error in dependency Chainlet `TextToNum` \(HTTP status 500\):
├─ Error calling dependency Chainlet `TextToNum`, HTTP status=500, trace ID=`TRACE_ID`.
│ Chainlet-Traceback \(most recent call last\):
│ File \".*?/itest_chain\.py\", line \d+, in run_remote
│ generated_text = self\._replicator\.run_remote\(data\)
│ ValueError: \(showing chained remote errors, root error at the bottom\)
│ ├─ Error in dependency Chainlet `TextReplicator` \(HTTP status 500\):
│ ├─ Error calling dependency Chainlet `TextReplicator`, HTTP status=500, trace ID=`TRACE_ID`.
│ │ Chainlet-Traceback \(most recent call last\):
│ │ File \".*?/itest_chain\.py\", line \d+, in run_remote
│ │ validate_data\(data\)
Expand All @@ -106,7 +111,9 @@ def test_chain():
@pytest.mark.asyncio
async def test_chain_local():
chain_root = TEST_ROOT / "itest_chain" / "itest_chain.py"
with framework.import_target(chain_root, "ItestChain") as entrypoint:
with framework.ChainletImporter.import_target(
chain_root, "ItestChain"
) as entrypoint:
with public_api.run_local():
with pytest.raises(ValueError):
# First time `SplitTextFailOnce` raises an error and
Expand Down Expand Up @@ -140,7 +147,9 @@ def test_streaming_chain():
with ensure_kill_all():
examples_root = Path(__file__).parent.parent.resolve() / "examples"
chain_root = examples_root / "streaming" / "streaming_chain.py"
with framework.import_target(chain_root, "Consumer") as entrypoint:
with framework.ChainletImporter.import_target(
chain_root, "Consumer"
) as entrypoint:
service = deployment_client.push(
entrypoint,
options=definitions.PushOptionsLocalDocker(
Expand Down Expand Up @@ -176,7 +185,7 @@ def test_streaming_chain():
async def test_streaming_chain_local():
examples_root = Path(__file__).parent.parent.resolve() / "examples"
chain_root = examples_root / "streaming" / "streaming_chain.py"
with framework.import_target(chain_root, "Consumer") as entrypoint:
with framework.ChainletImporter.import_target(chain_root, "Consumer") as entrypoint:
with public_api.run_local():
result = await entrypoint().run_remote(cause_error=False)
print(result)
Expand All @@ -198,7 +207,7 @@ def test_numpy_chain(mode):
target = "HostBinary"
with ensure_kill_all():
chain_root = TEST_ROOT / "numpy_and_binary" / "chain.py"
with framework.import_target(chain_root, target) as entrypoint:
with framework.ChainletImporter.import_target(chain_root, target) as entrypoint:
service = deployment_client.push(
entrypoint,
options=definitions.PushOptionsLocalDocker(
Expand All @@ -213,11 +222,14 @@ def test_numpy_chain(mode):
print(response.json())


@pytest.mark.integration
@pytest.mark.asyncio
async def test_timeout():
with ensure_kill_all():
chain_root = TEST_ROOT / "timeout" / "timeout_chain.py"
with framework.import_target(chain_root, "TimeoutChain") as entrypoint:
with framework.ChainletImporter.import_target(
chain_root, "TimeoutChain"
) as entrypoint:
options = definitions.PushOptionsLocalDocker(
chain_name="integration-test", use_local_chains_src=True
)
Expand Down Expand Up @@ -284,7 +296,9 @@ def test_traditional_truss():
def test_custom_health_checks_chain():
with ensure_kill_all():
chain_root = TEST_ROOT / "custom_health_checks" / "custom_health_checks.py"
with framework.import_target(chain_root, "CustomHealthChecks") as entrypoint:
with framework.ChainletImporter.import_target(
chain_root, "CustomHealthChecks"
) as entrypoint:
service = deployment_client.push(
entrypoint,
options=definitions.PushOptionsLocalDocker(
Expand Down
18 changes: 18 additions & 0 deletions truss-chains/tests/test_framework.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
import contextlib
import logging
import pathlib
import re
from typing import AsyncIterator, Iterator, List

Expand All @@ -12,6 +13,7 @@

utils.setup_dev_logging(logging.DEBUG)

TEST_ROOT = pathlib.Path(__file__).parent.resolve()

# Assert that naive chainlet initialization is detected and prevented. #################

Expand Down Expand Up @@ -668,3 +670,19 @@ def is_healthy(self) -> str: # type: ignore[misc]

async def run_remote(self) -> str:
return ""


def test_import_model_requires_entrypoint():
    """Importing a source file with no `ModelBase` subclass must raise ValueError."""
    source_file = TEST_ROOT / "import" / "model_without_inheritance.py"
    expected = r"No Model class in `.+` inherits from"
    # Collapse the nested context managers; entry order is unchanged.
    with pytest.raises(ValueError, match=expected), _raise_errors(), framework.ModelImporter.import_target(source_file):
        pass


def test_import_model_requires_single_entrypoint():
    """A source file with more than one `ModelBase` subclass must be rejected."""
    source_file = TEST_ROOT / "import" / "standalone_with_multiple_entrypoints.py"
    expected = r"Multiple Model classes in `.+` inherit from"
    # Collapse the nested context managers; entry order is unchanged.
    with pytest.raises(ValueError, match=expected), _raise_errors(), framework.ModelImporter.import_target(source_file):
        pass
3 changes: 1 addition & 2 deletions truss-chains/truss_chains/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,8 @@
RemoteErrorDetail,
RPCOptions,
)
from truss_chains.framework import ChainletBase, ModelBase
from truss_chains.public_api import (
ChainletBase,
ModelBase,
depends,
depends_context,
mark_entrypoint,
Expand Down
1 change: 1 addition & 0 deletions truss-chains/truss_chains/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,7 @@ def migrate_fields(cls, values):
class ComputeSpec(pydantic.BaseModel):
"""Parsed and validated compute. See ``Compute`` for more information."""

# TODO[rcano] add node count
cpu_count: int = 1
predict_concurrency: int = 1
memory: str = "2Gi"
Expand Down
8 changes: 3 additions & 5 deletions truss-chains/truss_chains/deployment/code_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,9 +495,7 @@ def _gen_predict_src(chainlet_descriptor: definitions.ChainletAPIDescriptor) ->
f"request: starlette.requests.Request) -> {output_type_name}:"
)
# Add error handling context manager:
parts.append(
_indent("with stub.trace_parent(request), utils.exception_to_http_error():")
)
parts.append(_indent("with utils.predict_context(request):"))
# Invoke Chainlet.
if (
chainlet_descriptor.endpoint.is_async
Expand Down Expand Up @@ -733,7 +731,7 @@ def gen_truss_model_from_source(
# TODO(nikhil): Improve detection of directory structure, since right now
# we assume a flat structure
root_dir = model_src.absolute().parent
with framework.import_target(model_src) as entrypoint_cls:
with framework.ModelImporter.import_target(model_src) as entrypoint_cls:
descriptor = framework.get_descriptor(entrypoint_cls)
return gen_truss_model(
model_root=root_dir,
Expand Down Expand Up @@ -773,7 +771,7 @@ def gen_truss_chainlet(
gen_root = pathlib.Path(tempfile.gettempdir())
chainlet_dir = _make_chainlet_dir(chain_name, chainlet_descriptor, gen_root)
logging.info(
f"Code generation for Chainlet `{chainlet_descriptor.name}` "
f"Code generation for {chainlet_descriptor.chainlet_cls.entity_type} `{chainlet_descriptor.name}` "
f"in `{chainlet_dir}`."
)
_write_truss_config_yaml(
Expand Down
Loading
Loading