From 4f97215bd50b12438863f3f5ed9a7ac19fa0a55d Mon Sep 17 00:00:00 2001
From: Jakub Hrozek
Date: Fri, 10 Jan 2025 17:04:30 +0100
Subject: [PATCH 1/7] Create the pipeline sensitive context when creating a
 pipeline instance, not on every request

We used to create the pipeline context during pipeline processing,
which meant we could not reuse the same pipeline for output that spans
several data buffers.
---
 src/codegate/pipeline/base.py | 20 ++++++++++++--------
 1 file changed, 12 insertions(+), 8 deletions(-)

diff --git a/src/codegate/pipeline/base.py b/src/codegate/pipeline/base.py
index 59c73981..aa63dad2 100644
--- a/src/codegate/pipeline/base.py
+++ b/src/codegate/pipeline/base.py
@@ -277,6 +277,13 @@ def __init__(
         self.secret_manager = secret_manager
         self.is_fim = is_fim
         self.context = PipelineContext()
+
+        # we create the sensitive context here so that it is not shared between individual requests
+        # TODO: could we get away with just generating the session ID for an instance?
+        self.context.sensitive = PipelineSensitiveData(
+            manager=self.secret_manager,
+            session_id=str(uuid.uuid4()),
+        )
         self.context.metadata["is_fim"] = is_fim
 
     async def process_request(
         self,
@@ -290,17 +297,14 @@ async def process_request(
         is_copilot: bool = False,
     ) -> PipelineResult:
         """Process a request through all pipeline steps"""
-        self.context.sensitive = PipelineSensitiveData(
-            manager=self.secret_manager,
-            session_id=str(uuid.uuid4()),
-            api_key=api_key,
-            model=model,
-            provider=provider,
-            api_base=api_base,
-        )
         self.context.metadata["extra_headers"] = extra_headers
         current_request = request
 
+        self.context.sensitive.api_key = api_key
+        self.context.sensitive.model = model
+        self.context.sensitive.provider = provider
+        self.context.sensitive.api_base = api_base
+
         # For Copilot provider=openai. Use a flag to not clash with other places that may use that.
         provider_db = "copilot" if is_copilot else provider

From 8dbff6372ea5e1afeb5b20f4a32b797d1f75b80f Mon Sep 17 00:00:00 2001
From: Jakub Hrozek
Date: Fri, 10 Jan 2025 17:05:17 +0100
Subject: [PATCH 2/7] Create the pipeline instance when creating the
 SequentialPipelineProcessor, not for every request

A pipeline instance is what binds the pipeline steps with the context.
Create the instance sooner, not when processing the request.
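Illustrative usage (hypothetical names, not part of this change): the
processor now owns a single pipeline instance, so consecutive requests
served by the same processor share one context and one session ID:

    processor = SequentialPipelineProcessor(steps, secrets_manager, is_fim=False)
    # Both calls run through the same InputPipelineInstance, so they
    # share one PipelineContext and one sensitive-data session.
    result_1 = await processor.process_request(request_1, "openai", "gpt-4o-mini")
    result_2 = await processor.process_request(request_2, "openai", "gpt-4o-mini")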
--- src/codegate/pipeline/base.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/codegate/pipeline/base.py b/src/codegate/pipeline/base.py index aa63dad2..eb92d8a0 100644 --- a/src/codegate/pipeline/base.py +++ b/src/codegate/pipeline/base.py @@ -340,8 +340,9 @@ def __init__( self.pipeline_steps = pipeline_steps self.secret_manager = secret_manager self.is_fim = is_fim + self.instance = self._create_instance() - def create_instance(self) -> InputPipelineInstance: + def _create_instance(self) -> InputPipelineInstance: """Create a new pipeline instance for processing a request""" return InputPipelineInstance(self.pipeline_steps, self.secret_manager, self.is_fim) @@ -356,7 +357,6 @@ async def process_request( is_copilot: bool = False, ) -> PipelineResult: """Create a new pipeline instance and process the request""" - instance = self.create_instance() - return await instance.process_request( + return await self.instance.process_request( request, provider, model, api_key, api_base, extra_headers, is_copilot ) From 606da49acfdb3ec632f785ad4223978c39cf9c14 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Fri, 10 Jan 2025 17:09:35 +0100 Subject: [PATCH 3/7] Create the pipelines only once in the copilot provider Since the copilot provider class instance is created once per connection, let's create the pipelines when establishing the connection and reuse them. --- src/codegate/providers/copilot/pipeline.py | 22 ++++++++++++++++------ src/codegate/providers/copilot/provider.py | 19 ++++++++++++++++--- 2 files changed, 32 insertions(+), 9 deletions(-) diff --git a/src/codegate/providers/copilot/pipeline.py b/src/codegate/providers/copilot/pipeline.py index 5268aeaa..d1ef13da 100644 --- a/src/codegate/providers/copilot/pipeline.py +++ b/src/codegate/providers/copilot/pipeline.py @@ -24,6 +24,7 @@ class CopilotPipeline(ABC): def __init__(self, pipeline_factory: PipelineFactory): self.pipeline_factory = pipeline_factory + self.instance = self._create_pipeline() self.normalizer = self._create_normalizer() self.provider_name = "openai" @@ -33,7 +34,7 @@ def _create_normalizer(self): pass @abstractmethod - def create_pipeline(self) -> SequentialPipelineProcessor: + def _create_pipeline(self) -> SequentialPipelineProcessor: """Each strategy defines which pipeline to create""" pass @@ -84,7 +85,11 @@ def _create_shortcut_response(result: PipelineResult, model: str) -> bytes: body = response.model_dump_json(exclude_none=True, exclude_unset=True).encode() return body - async def process_body(self, headers: list[str], body: bytes) -> Tuple[bytes, PipelineContext]: + async def process_body( + self, + headers: list[str], + body: bytes, + ) -> Tuple[bytes, PipelineContext | None]: """Common processing logic for all strategies""" try: normalized_body = self.normalizer.normalize(body) @@ -97,8 +102,7 @@ async def process_body(self, headers: list[str], body: bytes) -> Tuple[bytes, Pi except ValueError: continue - pipeline = self.create_pipeline() - result = await pipeline.process_request( + result = await self.instance.process_request( request=normalized_body, provider=self.provider_name, model=normalized_body.get("model", "gpt-4o-mini"), @@ -168,10 +172,13 @@ class CopilotFimPipeline(CopilotPipeline): format and the FIM pipeline used by all providers. 
""" + def __init__(self, pipeline_factory: PipelineFactory): + super().__init__(pipeline_factory) + def _create_normalizer(self): return CopilotFimNormalizer() - def create_pipeline(self) -> SequentialPipelineProcessor: + def _create_pipeline(self) -> SequentialPipelineProcessor: return self.pipeline_factory.create_fim_pipeline() @@ -181,8 +188,11 @@ class CopilotChatPipeline(CopilotPipeline): format and the FIM pipeline used by all providers. """ + def __init__(self, pipeline_factory: PipelineFactory): + super().__init__(pipeline_factory) + def _create_normalizer(self): return CopilotChatNormalizer() - def create_pipeline(self) -> SequentialPipelineProcessor: + def _create_pipeline(self) -> SequentialPipelineProcessor: return self.pipeline_factory.create_input_pipeline() diff --git a/src/codegate/providers/copilot/provider.py b/src/codegate/providers/copilot/provider.py index f8bfeb1e..cb9b2694 100644 --- a/src/codegate/providers/copilot/provider.py +++ b/src/codegate/providers/copilot/provider.py @@ -150,8 +150,16 @@ def __init__(self, loop: asyncio.AbstractEventLoop): self.cert_manager = TLSCertDomainManager(self.ca) self._closing = False self.pipeline_factory = PipelineFactory(SecretsManager()) + self.input_pipeline: Optional[CopilotPipeline] = None + self.fim_pipeline: Optional[CopilotPipeline] = None + # the context as provided by the pipeline self.context_tracking: Optional[PipelineContext] = None + def _ensure_pipelines(self): + if not self.input_pipeline or not self.fim_pipeline: + self.input_pipeline = CopilotChatPipeline(self.pipeline_factory) + self.fim_pipeline = CopilotFimPipeline(self.pipeline_factory) + def _select_pipeline(self, method: str, path: str) -> Optional[CopilotPipeline]: if method != "POST": logger.debug("Not a POST request, no pipeline selected") @@ -161,10 +169,10 @@ def _select_pipeline(self, method: str, path: str) -> Optional[CopilotPipeline]: if path == route.path: if route.pipeline_type == PipelineType.FIM: logger.debug("Selected FIM pipeline") - return CopilotFimPipeline(self.pipeline_factory) + return self.fim_pipeline elif route.pipeline_type == PipelineType.CHAT: logger.debug("Selected CHAT pipeline") - return CopilotChatPipeline(self.pipeline_factory) + return self.input_pipeline logger.debug("No pipeline selected") return None @@ -181,7 +189,6 @@ async def _body_through_pipeline( # if we didn't select any strategy that would change the request # let's just pass through the body as-is return body, None - logger.debug(f"Processing body through pipeline: {len(body)} bytes") return await strategy.process_body(headers, body) async def _request_to_target(self, headers: list[str], body: bytes): @@ -288,6 +295,9 @@ async def _forward_data_through_pipeline(self, data: bytes) -> Union[HttpRequest http_request.headers, http_request.body, ) + # TODO: it's weird that we're overwriting the context. + # Should we set the context once? Maybe when + # creating the pipeline instance? 
             self.context_tracking = context
 
             if context and context.shortcut_response:
@@ -442,6 +452,7 @@ def data_received(self, data: bytes) -> None:
             if not self.headers_parsed:
                 self.headers_parsed = self.parse_headers()
                 if self.headers_parsed:
+                    self._ensure_pipelines()
                     if self.request.method == "CONNECT":
                         self.handle_connect()
                         self.buffer.clear()
@@ -756,10 +767,12 @@ def connection_made(self, transport: asyncio.Transport) -> None:
 
     def _ensure_output_processor(self) -> None:
         if self.proxy.context_tracking is None:
+            logger.debug("No context tracking, no need to process pipeline")
             # No context tracking, no need to process pipeline
             return
 
         if self.sse_processor is not None:
+            logger.debug("Already initialized, no need to reinitialize")
             # Already initialized, no need to reinitialize
             return

From 5b621a59f874c855d1cb4aa90f4a1dd5b090d812 Mon Sep 17 00:00:00 2001
From: Jakub Hrozek
Date: Fri, 10 Jan 2025 17:05:40 +0100
Subject: [PATCH 4/7] Don't flush the buffer and the sensitive data if we're
 reusing the output pipeline

Since we can reuse a single pipeline for multiple request-reply round
trips, we shouldn't flush the buffer and destroy the context.
---
 src/codegate/pipeline/output.py | 4 ++++
 1 file changed, 4 insertions(+)

diff --git a/src/codegate/pipeline/output.py b/src/codegate/pipeline/output.py
index 89c31c59..f5bb716a 100644
--- a/src/codegate/pipeline/output.py
+++ b/src/codegate/pipeline/output.py
@@ -162,6 +162,10 @@ async def process_stream(
             logger.error(f"Error processing stream: {e}")
             raise e
         finally:
+            # Don't flush the buffer if we assume we'll call the pipeline again
+            if cleanup_sensitive is False:
+                return
+
             # Process any remaining content in buffer when stream ends
             if self._context.buffer:
                 final_content = "".join(self._context.buffer)

From 3e4790d7d26b12955ebd85ce91e7e974095ae6c8 Mon Sep 17 00:00:00 2001
From: Jakub Hrozek
Date: Sun, 12 Jan 2025 23:00:29 +0100
Subject: [PATCH 5/7] Handle multiple requests in one data_received call

We've seen instances where the data received in a single call
(typically for FIM) contained more than one HTTP request. Let's
dispatch them individually, as sketched below.

Also, let's pass the complete request into the tasks as a parameter
instead of passing self.buffer around.
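To illustrate the splitting logic (a standalone sketch with hypothetical
names; the actual implementation lives in data_received below):

    def split_first_request(buffer: bytes) -> tuple[bytes | None, bytes]:
        """Return (first complete request or None, remaining bytes)."""
        if b"\r\n\r\n" not in buffer:
            return None, buffer  # headers are not complete yet
        headers_end = buffer.index(b"\r\n\r\n")
        content_length = 0
        for header in buffer[:headers_end].split(b"\r\n")[1:]:
            if header.lower().startswith(b"content-length:"):
                content_length = int(header.split(b":", 1)[1])
                break
        request_end = headers_end + 4 + content_length
        if len(buffer) < request_end:
            return None, buffer  # body is not complete yet
        return buffer[:request_end], buffer[request_end:]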
--- src/codegate/providers/copilot/provider.py | 74 ++++++++++++++-------- 1 file changed, 46 insertions(+), 28 deletions(-) diff --git a/src/codegate/providers/copilot/provider.py b/src/codegate/providers/copilot/provider.py index cb9b2694..6a3adf67 100644 --- a/src/codegate/providers/copilot/provider.py +++ b/src/codegate/providers/copilot/provider.py @@ -229,15 +229,15 @@ def connection_made(self, transport: asyncio.Transport) -> None: self.peername = transport.get_extra_info("peername") logger.debug(f"Client connected from {self.peername}") - def get_headers_dict(self) -> Dict[str, str]: + def get_headers_dict(self, complete_request) -> Dict[str, str]: """Convert raw headers to dictionary format""" headers_dict = {} try: - if b"\r\n\r\n" not in self.buffer: + if b"\r\n\r\n" not in complete_request: return {} - headers_end = self.buffer.index(b"\r\n\r\n") - headers = self.buffer[:headers_end].split(b"\r\n")[1:] + headers_end = complete_request.index(b"\r\n\r\n") + headers = complete_request[:headers_end].split(b"\r\n")[1:] for header in headers: try: @@ -449,32 +449,50 @@ def data_received(self, data: bytes) -> None: self.buffer.extend(data) - if not self.headers_parsed: - self.headers_parsed = self.parse_headers() - if self.headers_parsed: - self._ensure_pipelines() - if self.request.method == "CONNECT": - self.handle_connect() - self.buffer.clear() - else: - # Only process the request once we have the complete body - asyncio.create_task(self.handle_http_request()) - else: - if self._has_complete_body(): - # Process the complete request through the pipeline - complete_request = bytes(self.buffer) - # logger.debug(f"Complete request: {complete_request}") - self.buffer.clear() - asyncio.create_task(self._forward_data_to_target(complete_request)) + while self.buffer: # Process as many complete requests as we have + if not self.headers_parsed: + self.headers_parsed = self.parse_headers() + if self.headers_parsed: + self._ensure_pipelines() + if self.request.method == "CONNECT": + if self._has_complete_body(): + self.handle_connect() + self.buffer.clear() # CONNECT requests are handled differently + break # CONNECT handling complete + elif self._has_complete_body(): + # Find where this request ends + headers_end = self.buffer.index(b"\r\n\r\n") + headers = self.buffer[:headers_end].split(b"\r\n")[1:] + content_length = 0 + for header in headers: + if header.lower().startswith(b"content-length:"): + content_length = int(header.split(b":", 1)[1]) + break + + request_end = headers_end + 4 + content_length + complete_request = self.buffer[:request_end] + + self.buffer = self.buffer[request_end:] # Keep remaining data + + self.headers_parsed = False # Reset for next request + + asyncio.create_task(self.handle_http_request(complete_request)) + break # Either processing request or need more data + else: + if self._has_complete_body(): + complete_request = bytes(self.buffer) + self.buffer.clear() # Clear buffer for next request + asyncio.create_task(self._forward_data_to_target(complete_request)) + break # Either processing request or need more data except Exception as e: logger.error(f"Error processing received data: {e}") self.send_error_response(502, str(e).encode()) - async def handle_http_request(self) -> None: + async def handle_http_request(self, complete_request: bytes) -> None: """Handle standard HTTP request""" try: - target_url = await self._get_target_url() + target_url = await self._get_target_url(complete_request) except Exception as e: logger.error(f"Error getting target URL: {e}") 
self.send_error_response(404, b"Not Found") @@ -518,9 +536,9 @@ async def handle_http_request(self) -> None: new_headers.append(f"Host: {self.target_host}") if self.target_transport: - if self.buffer: - body_start = self.buffer.index(b"\r\n\r\n") + 4 - body = self.buffer[body_start:] + if complete_request: + body_start = complete_request.index(b"\r\n\r\n") + 4 + body = complete_request[body_start:] await self._request_to_target(new_headers, body) else: # just skip it @@ -532,9 +550,9 @@ async def handle_http_request(self) -> None: logger.error(f"Error preparing or sending request to target: {e}") self.send_error_response(502, b"Bad Gateway") - async def _get_target_url(self) -> Optional[str]: + async def _get_target_url(self, complete_request) -> Optional[str]: """Determine target URL based on request path and headers""" - headers_dict = self.get_headers_dict() + headers_dict = self.get_headers_dict(complete_request) auth_header = headers_dict.get("authorization", "") if auth_header: From 30e56e3cdd734f64ace374172d79e5391282a2d9 Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Mon, 13 Jan 2025 11:49:54 +0100 Subject: [PATCH 6/7] Initialize pipelines per instance in the base providers, too --- src/codegate/providers/anthropic/provider.py | 14 +++------- src/codegate/providers/base.py | 22 ++++++---------- src/codegate/providers/llamacpp/provider.py | 14 +++------- src/codegate/providers/ollama/provider.py | 13 +++------- src/codegate/providers/openai/provider.py | 14 +++------- src/codegate/providers/vllm/provider.py | 14 +++------- src/codegate/server.py | 27 ++++---------------- tests/test_provider.py | 6 ++--- 8 files changed, 30 insertions(+), 94 deletions(-) diff --git a/src/codegate/providers/anthropic/provider.py b/src/codegate/providers/anthropic/provider.py index 78c5df72..2daa5a8d 100644 --- a/src/codegate/providers/anthropic/provider.py +++ b/src/codegate/providers/anthropic/provider.py @@ -1,11 +1,9 @@ import json -from typing import Optional import structlog from fastapi import Header, HTTPException, Request -from codegate.pipeline.base import SequentialPipelineProcessor -from codegate.pipeline.output import OutputPipelineProcessor +from codegate.pipeline.factory import PipelineFactory from codegate.providers.anthropic.adapter import AnthropicInputNormalizer, AnthropicOutputNormalizer from codegate.providers.anthropic.completion_handler import AnthropicCompletion from codegate.providers.base import BaseProvider @@ -15,20 +13,14 @@ class AnthropicProvider(BaseProvider): def __init__( self, - pipeline_processor: Optional[SequentialPipelineProcessor] = None, - fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None, - output_pipeline_processor: Optional[OutputPipelineProcessor] = None, - fim_output_pipeline_processor: Optional[OutputPipelineProcessor] = None, + pipeline_factory: PipelineFactory, ): completion_handler = AnthropicCompletion(stream_generator=anthropic_stream_generator) super().__init__( AnthropicInputNormalizer(), AnthropicOutputNormalizer(), completion_handler, - pipeline_processor, - fim_pipeline_processor, - output_pipeline_processor, - fim_output_pipeline_processor, + pipeline_factory, ) @property diff --git a/src/codegate/providers/base.py b/src/codegate/providers/base.py index b529ad3b..dc45616e 100644 --- a/src/codegate/providers/base.py +++ b/src/codegate/providers/base.py @@ -10,9 +10,9 @@ from codegate.pipeline.base import ( PipelineContext, PipelineResult, - SequentialPipelineProcessor, ) -from codegate.pipeline.output import 
OutputPipelineInstance, OutputPipelineProcessor +from codegate.pipeline.factory import PipelineFactory +from codegate.pipeline.output import OutputPipelineInstance from codegate.providers.completion.base import BaseCompletionHandler from codegate.providers.formatting.input_pipeline import PipelineResponseFormatter from codegate.providers.normalizer.base import ModelInputNormalizer, ModelOutputNormalizer @@ -34,19 +34,13 @@ def __init__( input_normalizer: ModelInputNormalizer, output_normalizer: ModelOutputNormalizer, completion_handler: BaseCompletionHandler, - pipeline_processor: Optional[SequentialPipelineProcessor] = None, - fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None, - output_pipeline_processor: Optional[OutputPipelineProcessor] = None, - fim_output_pipeline_processor: Optional[OutputPipelineProcessor] = None, + pipeline_factory: PipelineFactory, ): self.router = APIRouter() self._completion_handler = completion_handler self._input_normalizer = input_normalizer self._output_normalizer = output_normalizer - self._pipeline_processor = pipeline_processor - self._fim_pipelin_processor = fim_pipeline_processor - self._output_pipeline_processor = output_pipeline_processor - self._fim_output_pipeline_processor = fim_output_pipeline_processor + self._pipeline_factory = pipeline_factory self._db_recorder = DbRecorder() self._pipeline_response_formatter = PipelineResponseFormatter( output_normalizer, self._db_recorder @@ -73,10 +67,10 @@ async def _run_output_stream_pipeline( # Decide which pipeline processor to use out_pipeline_processor = None if is_fim_request: - out_pipeline_processor = self._fim_output_pipeline_processor + out_pipeline_processor = self._pipeline_factory.create_fim_output_pipeline() logger.info("FIM pipeline selected for output.") else: - out_pipeline_processor = self._output_pipeline_processor + out_pipeline_processor = self._pipeline_factory.create_output_pipeline() logger.info("Chat completion pipeline selected for output.") if out_pipeline_processor is None: logger.info("No output pipeline processor found, passing through") @@ -117,11 +111,11 @@ async def _run_input_pipeline( ) -> PipelineResult: # Decide which pipeline processor to use if is_fim_request: - pipeline_processor = self._fim_pipelin_processor + pipeline_processor = self._pipeline_factory.create_fim_pipeline() logger.info("FIM pipeline selected for execution.") normalized_request = self._fim_normalizer.normalize(normalized_request) else: - pipeline_processor = self._pipeline_processor + pipeline_processor = self._pipeline_factory.create_input_pipeline() logger.info("Chat completion pipeline selected for execution.") if pipeline_processor is None: return PipelineResult(request=normalized_request) diff --git a/src/codegate/providers/llamacpp/provider.py b/src/codegate/providers/llamacpp/provider.py index 37dc64d1..7f90619e 100644 --- a/src/codegate/providers/llamacpp/provider.py +++ b/src/codegate/providers/llamacpp/provider.py @@ -1,11 +1,9 @@ import json -from typing import Optional import structlog from fastapi import HTTPException, Request -from codegate.pipeline.base import SequentialPipelineProcessor -from codegate.pipeline.output import OutputPipelineProcessor +from codegate.pipeline.factory import PipelineFactory from codegate.providers.base import BaseProvider from codegate.providers.llamacpp.completion_handler import LlamaCppCompletionHandler from codegate.providers.llamacpp.normalizer import LLamaCppInputNormalizer, LLamaCppOutputNormalizer @@ -14,20 +12,14 @@ class 
LlamaCppProvider(BaseProvider): def __init__( self, - pipeline_processor: Optional[SequentialPipelineProcessor] = None, - fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None, - output_pipeline_processor: Optional[OutputPipelineProcessor] = None, - fim_output_pipeline_processor: Optional[OutputPipelineProcessor] = None, + pipeline_factory: PipelineFactory, ): completion_handler = LlamaCppCompletionHandler() super().__init__( LLamaCppInputNormalizer(), LLamaCppOutputNormalizer(), completion_handler, - pipeline_processor, - fim_pipeline_processor, - output_pipeline_processor, - fim_output_pipeline_processor, + pipeline_factory, ) @property diff --git a/src/codegate/providers/ollama/provider.py b/src/codegate/providers/ollama/provider.py index f8e901d4..37c06a0d 100644 --- a/src/codegate/providers/ollama/provider.py +++ b/src/codegate/providers/ollama/provider.py @@ -1,13 +1,11 @@ import json -from typing import Optional import httpx import structlog from fastapi import HTTPException, Request from codegate.config import Config -from codegate.pipeline.base import SequentialPipelineProcessor -from codegate.pipeline.output import OutputPipelineProcessor +from codegate.pipeline.factory import PipelineFactory from codegate.providers.base import BaseProvider from codegate.providers.ollama.adapter import OllamaInputNormalizer, OllamaOutputNormalizer from codegate.providers.ollama.completion_handler import OllamaShim @@ -16,10 +14,7 @@ class OllamaProvider(BaseProvider): def __init__( self, - pipeline_processor: Optional[SequentialPipelineProcessor] = None, - fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None, - output_pipeline_processor: Optional[OutputPipelineProcessor] = None, - fim_output_pipeline_processor: Optional[OutputPipelineProcessor] = None, + pipeline_factory: PipelineFactory, ): config = Config.get_config() if config is None: @@ -32,9 +27,7 @@ def __init__( OllamaInputNormalizer(), OllamaOutputNormalizer(), completion_handler, - pipeline_processor, - fim_pipeline_processor, - output_pipeline_processor, + pipeline_factory, ) @property diff --git a/src/codegate/providers/openai/provider.py b/src/codegate/providers/openai/provider.py index 75c201da..53aa7db8 100644 --- a/src/codegate/providers/openai/provider.py +++ b/src/codegate/providers/openai/provider.py @@ -1,11 +1,9 @@ import json -from typing import Optional import structlog from fastapi import Header, HTTPException, Request -from codegate.pipeline.base import SequentialPipelineProcessor -from codegate.pipeline.output import OutputPipelineProcessor +from codegate.pipeline.factory import PipelineFactory from codegate.providers.base import BaseProvider from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator from codegate.providers.openai.adapter import OpenAIInputNormalizer, OpenAIOutputNormalizer @@ -14,20 +12,14 @@ class OpenAIProvider(BaseProvider): def __init__( self, - pipeline_processor: Optional[SequentialPipelineProcessor] = None, - fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None, - output_pipeline_processor: Optional[OutputPipelineProcessor] = None, - fim_output_pipeline_processor: Optional[OutputPipelineProcessor] = None, + pipeline_factory: PipelineFactory, ): completion_handler = LiteLLmShim(stream_generator=sse_stream_generator) super().__init__( OpenAIInputNormalizer(), OpenAIOutputNormalizer(), completion_handler, - pipeline_processor, - fim_pipeline_processor, - output_pipeline_processor, - fim_output_pipeline_processor, + pipeline_factory, ) 
@property diff --git a/src/codegate/providers/vllm/provider.py b/src/codegate/providers/vllm/provider.py index 6d7b9bca..f39ed8d6 100644 --- a/src/codegate/providers/vllm/provider.py +++ b/src/codegate/providers/vllm/provider.py @@ -1,5 +1,4 @@ import json -from typing import Optional import httpx import structlog @@ -7,8 +6,7 @@ from litellm import atext_completion from codegate.config import Config -from codegate.pipeline.base import SequentialPipelineProcessor -from codegate.pipeline.output import OutputPipelineProcessor +from codegate.pipeline.factory import PipelineFactory from codegate.providers.base import BaseProvider from codegate.providers.litellmshim import LiteLLmShim, sse_stream_generator from codegate.providers.vllm.adapter import VLLMInputNormalizer, VLLMOutputNormalizer @@ -17,10 +15,7 @@ class VLLMProvider(BaseProvider): def __init__( self, - pipeline_processor: Optional[SequentialPipelineProcessor] = None, - fim_pipeline_processor: Optional[SequentialPipelineProcessor] = None, - output_pipeline_processor: Optional[OutputPipelineProcessor] = None, - fim_output_pipeline_processor: Optional[OutputPipelineProcessor] = None, + pipeline_factory: PipelineFactory, ): completion_handler = LiteLLmShim( stream_generator=sse_stream_generator, fim_completion_func=atext_completion @@ -29,10 +24,7 @@ def __init__( VLLMInputNormalizer(), VLLMOutputNormalizer(), completion_handler, - pipeline_processor, - fim_pipeline_processor, - output_pipeline_processor, - fim_output_pipeline_processor, + pipeline_factory, ) @property diff --git a/src/codegate/server.py b/src/codegate/server.py index 57206712..9ea9e569 100644 --- a/src/codegate/server.py +++ b/src/codegate/server.py @@ -51,47 +51,30 @@ def init_app(pipeline_factory: PipelineFactory) -> FastAPI: # Register all known providers registry.add_provider( "openai", - OpenAIProvider( - pipeline_processor=pipeline_factory.create_input_pipeline(), - fim_pipeline_processor=pipeline_factory.create_fim_pipeline(), - output_pipeline_processor=pipeline_factory.create_output_pipeline(), - fim_output_pipeline_processor=pipeline_factory.create_fim_output_pipeline(), - ), + OpenAIProvider(pipeline_factory), ) registry.add_provider( "anthropic", AnthropicProvider( - pipeline_processor=pipeline_factory.create_input_pipeline(), - fim_pipeline_processor=pipeline_factory.create_fim_pipeline(), - output_pipeline_processor=pipeline_factory.create_output_pipeline(), - fim_output_pipeline_processor=pipeline_factory.create_fim_output_pipeline(), + pipeline_factory, ), ) registry.add_provider( "llamacpp", LlamaCppProvider( - pipeline_processor=pipeline_factory.create_input_pipeline(), - fim_pipeline_processor=pipeline_factory.create_fim_pipeline(), - output_pipeline_processor=pipeline_factory.create_output_pipeline(), - fim_output_pipeline_processor=pipeline_factory.create_fim_output_pipeline(), + pipeline_factory, ), ) registry.add_provider( "vllm", VLLMProvider( - pipeline_processor=pipeline_factory.create_input_pipeline(), - fim_pipeline_processor=pipeline_factory.create_fim_pipeline(), - output_pipeline_processor=pipeline_factory.create_output_pipeline(), - fim_output_pipeline_processor=pipeline_factory.create_fim_output_pipeline(), + pipeline_factory, ), ) registry.add_provider( "ollama", OllamaProvider( - pipeline_processor=pipeline_factory.create_input_pipeline(), - fim_pipeline_processor=pipeline_factory.create_fim_pipeline(), - output_pipeline_processor=pipeline_factory.create_output_pipeline(), - 
fim_output_pipeline_processor=pipeline_factory.create_fim_output_pipeline(), + pipeline_factory, ), ) diff --git a/tests/test_provider.py b/tests/test_provider.py index f2c4011f..95361c97 100644 --- a/tests/test_provider.py +++ b/tests/test_provider.py @@ -11,14 +11,12 @@ def __init__(self): mocked_input_normalizer = MagicMock() mocked_output_normalizer = MagicMock() mocked_completion_handler = MagicMock() - mocked_pipepeline = MagicMock() - mocked_fim_pipeline = MagicMock() + mocked_factory = MagicMock() super().__init__( mocked_input_normalizer, mocked_output_normalizer, mocked_completion_handler, - mocked_pipepeline, - mocked_fim_pipeline, + mocked_factory, ) def _setup_routes(self) -> None: From cb9540abb3673acc9d1dcc96638f4c9c79bee27d Mon Sep 17 00:00:00 2001 From: Jakub Hrozek Date: Mon, 13 Jan 2025 13:05:55 +0100 Subject: [PATCH 7/7] Fix SecretsManager For some reason we coded up the SecretsManager so that it held only one secret per session. Let's store a dict instead. --- src/codegate/pipeline/secrets/manager.py | 21 ++++++++++------- src/codegate/pipeline/secrets/secrets.py | 3 +++ .../providers/ollama/completion_handler.py | 1 + src/codegate/providers/ollama/provider.py | 1 + tests/pipeline/secrets/test_manager.py | 23 ++++++++++--------- 5 files changed, 30 insertions(+), 19 deletions(-) diff --git a/src/codegate/pipeline/secrets/manager.py b/src/codegate/pipeline/secrets/manager.py index a7b32319..bef07c75 100644 --- a/src/codegate/pipeline/secrets/manager.py +++ b/src/codegate/pipeline/secrets/manager.py @@ -21,7 +21,7 @@ class SecretsManager: def __init__(self): self.crypto = CodeGateCrypto() - self._session_store: dict[str, SecretEntry] = {} + self._session_store: dict[str, dict[str, SecretEntry]] = {} self._encrypted_to_session: dict[str, str] = {} # Reverse lookup index def store_secret(self, value: str, service: str, secret_type: str, session_id: str) -> str: @@ -41,12 +41,14 @@ def store_secret(self, value: str, service: str, secret_type: str, session_id: s encrypted_value = self.crypto.encrypt_token(value, session_id) # Store mappings - self._session_store[session_id] = SecretEntry( + session_secrets = self._session_store.get(session_id, {}) + session_secrets[encrypted_value] = SecretEntry( original=value, encrypted=encrypted_value, service=service, secret_type=secret_type, ) + self._session_store[session_id] = session_secrets self._encrypted_to_session[encrypted_value] = session_id logger.debug("Stored secret", service=service, type=secret_type, encrypted=encrypted_value) @@ -58,7 +60,9 @@ def get_original_value(self, encrypted_value: str, session_id: str) -> Optional[ try: stored_session_id = self._encrypted_to_session.get(encrypted_value) if stored_session_id == session_id: - return self._session_store[session_id].original + session_secrets = self._session_store[session_id].get(encrypted_value) + if session_secrets: + return session_secrets.original except Exception as e: logger.error("Error retrieving secret", error=str(e)) return None @@ -71,9 +75,10 @@ def cleanup(self): """Securely wipe sensitive data""" try: # Convert and wipe original values - for entry in self._session_store.values(): - original_bytes = bytearray(entry.original.encode()) - self.crypto.wipe_bytearray(original_bytes) + for secrets in self._session_store.values(): + for entry in secrets.values(): + original_bytes = bytearray(entry.original.encode()) + self.crypto.wipe_bytearray(original_bytes) # Clear the dictionaries self._session_store.clear() @@ -92,9 +97,9 @@ def cleanup_session(self, 
session_id: str): """ try: # Get the secret entry for the session - entry = self._session_store.get(session_id) + secrets = self._session_store.get(session_id, {}) - if entry: + for entry in secrets.values(): # Securely wipe the original value original_bytes = bytearray(entry.original.encode()) self.crypto.wipe_bytearray(original_bytes) diff --git a/src/codegate/pipeline/secrets/secrets.py b/src/codegate/pipeline/secrets/secrets.py index 606db5bb..43ec17a8 100644 --- a/src/codegate/pipeline/secrets/secrets.py +++ b/src/codegate/pipeline/secrets/secrets.py @@ -362,6 +362,7 @@ async def process_chunk( if match: # Found a complete marker, process it encrypted_value = match.group(1) + print("----> encrypted_value: ", encrypted_value) original_value = input_context.sensitive.manager.get_original_value( encrypted_value, input_context.sensitive.session_id, @@ -370,6 +371,8 @@ async def process_chunk( if original_value is None: # If value not found, leave as is original_value = match.group(0) # Keep the REDACTED marker + else: + print("----> original_value: ", original_value) # Post an alert with the redacted content input_context.add_alert(self.name, trigger_string=encrypted_value) diff --git a/src/codegate/providers/ollama/completion_handler.py b/src/codegate/providers/ollama/completion_handler.py index 49fbc103..f569d988 100644 --- a/src/codegate/providers/ollama/completion_handler.py +++ b/src/codegate/providers/ollama/completion_handler.py @@ -16,6 +16,7 @@ async def ollama_stream_generator( """OpenAI-style SSE format""" try: async for chunk in stream: + print(chunk) try: yield f"{chunk.model_dump_json()}\n\n" except Exception as e: diff --git a/src/codegate/providers/ollama/provider.py b/src/codegate/providers/ollama/provider.py index 37c06a0d..8307f7e0 100644 --- a/src/codegate/providers/ollama/provider.py +++ b/src/codegate/providers/ollama/provider.py @@ -38,6 +38,7 @@ def _setup_routes(self): """ Sets up Ollama API routes. 
""" + @self.router.get(f"/{self.provider_route_name}/api/tags") async def get_tags(request: Request): """ diff --git a/tests/pipeline/secrets/test_manager.py b/tests/pipeline/secrets/test_manager.py index 5cb06ade..177e8f3f 100644 --- a/tests/pipeline/secrets/test_manager.py +++ b/tests/pipeline/secrets/test_manager.py @@ -1,6 +1,6 @@ import pytest -from codegate.pipeline.secrets.manager import SecretEntry, SecretsManager +from codegate.pipeline.secrets.manager import SecretsManager class TestSecretsManager: @@ -21,11 +21,8 @@ def test_store_secret(self): # Verify the secret was stored stored = self.manager.get_by_session_id(self.test_session) - assert isinstance(stored, SecretEntry) - assert stored.original == self.test_value - assert stored.encrypted == encrypted - assert stored.service == self.test_service - assert stored.secret_type == self.test_type + assert isinstance(stored, dict) + assert stored[encrypted].original == self.test_value # Verify encrypted value can be retrieved retrieved = self.manager.get_original_value(encrypted, self.test_session) @@ -86,10 +83,15 @@ def test_multiple_secrets_same_session(self): encrypted1 = self.manager.store_secret("secret1", "service1", "type1", self.test_session) encrypted2 = self.manager.store_secret("secret2", "service2", "type2", self.test_session) - # Latest secret should be retrievable + # Latest secret should be retrievable in the session stored = self.manager.get_by_session_id(self.test_session) - assert stored.original == "secret2" - assert stored.encrypted == encrypted2 + assert isinstance(stored, dict) + assert stored[encrypted1].original == "secret1" + assert stored[encrypted2].original == "secret2" + + # Both secrets should be retrievable directly + assert self.manager.get_original_value(encrypted1, self.test_session) == "secret1" + assert self.manager.get_original_value(encrypted2, self.test_session) == "secret2" # Both encrypted values should map to the session assert self.manager._encrypted_to_session[encrypted1] == self.test_session @@ -119,7 +121,7 @@ def test_secure_cleanup(self): # Get reference to stored data before cleanup stored = self.manager.get_by_session_id(self.test_session) - original_value = stored.original + assert len(stored) == 1 # Perform cleanup self.manager.cleanup() @@ -127,7 +129,6 @@ def test_secure_cleanup(self): # Verify the original string was overwritten, not just removed # This test is a bit tricky since Python strings are immutable, # but we can at least verify the data is no longer accessible - assert original_value not in str(self.manager._session_store) assert self.test_value not in str(self.manager._session_store) def test_session_isolation(self):