From 340aa342070809938d7ff5de9328d9e650da6cf7 Mon Sep 17 00:00:00 2001
From: Finn Andersen
Date: Fri, 10 Jan 2025 14:55:09 +0000
Subject: [PATCH] refactor: Simplify weak and editor model configuration

---
 aider/main.py              |  5 ++-
 aider/models.py            | 86 ++++++++++++++++++++--------------
 benchmark/benchmark.py     |  4 +-
 tests/basic/test_models.py | 72 +++++++++++++++++++++++++++++++
 4 files changed, 118 insertions(+), 49 deletions(-)

diff --git a/aider/main.py b/aider/main.py
index 4b440434715..0f97fe18082 100644
--- a/aider/main.py
+++ b/aider/main.py
@@ -217,6 +217,7 @@ def check_streamlit_install(io):
 
 def write_streamlit_credentials():
     from streamlit.file_util import get_streamlit_file_path
+
     # See https://github.com/Aider-AI/aider/issues/772
     credential_path = Path(get_streamlit_file_path()) / "credentials.toml"
 
@@ -753,8 +754,8 @@ def get_io(pretty):
 
     main_model = models.Model(
         args.model,
-        weak_model=args.weak_model,
-        editor_model=args.editor_model,
+        weak_model_name=args.weak_model,
+        editor_model_name=args.editor_model,
         editor_edit_format=args.editor_edit_format,
     )
 
diff --git a/aider/models.py b/aider/models.py
index 74e5e5cc7b9..888efe60433 100644
--- a/aider/models.py
+++ b/aider/models.py
@@ -744,20 +744,18 @@ def get_model_info(self, model):
 
 
 class Model(ModelSettings):
-    def __init__(self, model, weak_model=None, editor_model=None, editor_edit_format=None):
+    def __init__(
+        self, model_name, weak_model_name=None, editor_model_name=None, editor_edit_format=None
+    ):
         # Map any alias to its canonical name
-        model = MODEL_ALIASES.get(model, model)
-
-        self.name = model
+        model_name = MODEL_ALIASES.get(model_name, model_name)
 
-        self.max_chat_history_tokens = 1024
-        self.weak_model = None
-        self.editor_model = None
+        self.name = model_name
 
         # Find the extra settings
         self.extra_model_settings = get_model_settings("aider/extra_params")
 
-        self.info = self.get_model_info(model)
+        self.info = self.get_model_info(model_name)
 
         # Are all needed keys/params available?
         res = self.validate_environment()
@@ -769,19 +767,14 @@ def __init__(self, model, weak_model=None, editor_model=None, editor_edit_format
         # with minimum 1k and maximum 8k
         self.max_chat_history_tokens = min(max(max_input_tokens / 16, 1024), 8192)
 
-        self.configure_model_settings(model)
-        if weak_model is False:
-            self.weak_model_name = None
-        else:
-            self.get_weak_model(weak_model)
+        self.configure_model_settings(model_name)
 
-        if editor_model is False:
-            self.editor_model_name = None
-        else:
-            self.get_editor_model(editor_model, editor_edit_format)
+        self._set_weak_model(weak_model_name)
 
-    def get_model_info(self, model):
-        return model_info_manager.get_model_info(model)
+        self._set_editor_model_and_format(editor_model_name, editor_edit_format)
+
+    def get_model_info(self, model_name):
+        return model_info_manager.get_model_info(model_name)
 
     def _copy_fields(self, source):
         """Helper to copy fields from a ModelSettings instance to self"""
@@ -789,13 +782,13 @@ def _copy_fields(self, source):
             val = getattr(source, field.name)
             setattr(self, field.name, val)
 
-    def configure_model_settings(self, model):
+    def configure_model_settings(self, model_name):
         # Look for exact model match
-        if ms := get_model_settings(model):
+        if ms := get_model_settings(model_name):
             self._copy_fields(ms)
         else:
             # If no exact match, try generic settings
-            self.apply_generic_model_settings(model.lower())
+            self.apply_generic_model_settings(model_name.lower())
 
         # Apply override settings last if they exist
         if self.extra_model_settings and self.extra_model_settings.extra_params:
@@ -869,43 +862,46 @@ def apply_generic_model_settings(self, model):
     def __str__(self):
         return self.name
 
-    def get_weak_model(self, provided_weak_model_name):
+    def _set_weak_model(self, provided_weak_model_name):
+        if provided_weak_model_name is False:
+            self.weak_model = None
+            self.weak_model_name = None
+            return
+
         # If weak_model_name is provided, override the model settings
-        if provided_weak_model_name:
-            self.weak_model_name = provided_weak_model_name
+        self.weak_model_name = provided_weak_model_name or self.weak_model_name
 
-        if (not self.weak_model_name) or (self.weak_model_name == self.name):
+        if (self.weak_model_name is None) or (self.weak_model_name == self.name):
             self.weak_model = self
-            return
-
-        self.weak_model = Model(
-            self.weak_model_name,
-            weak_model=False,
-        )
-        return self.weak_model
+        else:
+            self.weak_model = Model(
+                self.weak_model_name,
+                weak_model_name=False,
+            )
 
     def commit_message_models(self):
         return [self.weak_model, self]
 
-    def get_editor_model(self, provided_editor_model_name, editor_edit_format):
+    def _set_editor_model_and_format(self, provided_editor_model_name, provided_editor_edit_format):
+        if provided_editor_model_name is False:
+            self.editor_model = None
+            self.editor_model_name = None
+            return
+
         # If editor_model_name is provided, override the model settings
-        if provided_editor_model_name:
-            self.editor_model_name = provided_editor_model_name
-        if editor_edit_format:
-            self.editor_edit_format = editor_edit_format
+        self.editor_model_name = provided_editor_model_name or self.editor_model_name
 
-        if not self.editor_model_name or self.editor_model_name == self.name:
+        if (self.editor_model_name is None) or (self.editor_model_name == self.name):
             self.editor_model = self
         else:
             self.editor_model = Model(
-                self.editor_model_name,
-                editor_model=False,
+                provided_editor_model_name,
+                editor_model_name=False,
             )
 
-        if not self.editor_edit_format:
-            self.editor_edit_format = self.editor_model.edit_format
-
-        return self.editor_model
+        self.editor_edit_format = (
+            provided_editor_edit_format or self.editor_edit_format or self.editor_model.edit_format
+        )
 
     def tokenizer(self, text):
         return litellm.encode(model=self.name, text=text)
diff --git a/benchmark/benchmark.py b/benchmark/benchmark.py
index 51050b13a2b..89eea062c0e 100755
--- a/benchmark/benchmark.py
+++ b/benchmark/benchmark.py
@@ -753,8 +753,8 @@ def run_test_real(
 
     main_model = models.Model(
         model_name,
-        weak_model=weak_model_name,
-        editor_model=editor_model,
+        weak_model_name=weak_model_name,
+        editor_model_name=editor_model,
         editor_edit_format=editor_edit_format,
     )
 
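For reference, a minimal usage sketch of the renamed constructor arguments (not part of the patch). Model, weak_model_name, editor_model_name and editor_edit_format come from the diff above; the specific model names and the "editor-diff" format string are only illustrative assumptions.

    from aider.models import Model

    # Override the configured companion models by name.
    model = Model(
        "openai/gpt-4o",
        weak_model_name="openai/gpt-4o-mini",
        editor_model_name="openai/gpt-4o",
        editor_edit_format="editor-diff",
    )

    # Passing False disables a companion model instead of falling back to defaults,
    # matching the short-circuit at the top of the new _set_* helpers.
    model = Model("openai/gpt-4o", weak_model_name=False, editor_model_name=False)
    assert model.weak_model is None
    assert model.editor_model is None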
diff --git a/tests/basic/test_models.py b/tests/basic/test_models.py
index a367370aefa..12751303143 100644
--- a/tests/basic/test_models.py
+++ b/tests/basic/test_models.py
@@ -206,6 +206,78 @@ def test_get_model_settings_invalid(self):
         self.assertIsNone(get_model_settings("openai/invalid-model"))
         self.assertIsNone(get_model_settings("not-a-provider/gpt-4"))
 
+    def test_weak_model_settings(self):
+        # When weak_model_name is None and no weak model is configured, use self
+        model = Model("openai/gpt-undefined", weak_model_name=None)
+        self.assertIs(model.weak_model, model)
+
+        # When weak_model_name is None and the model has a weak model configured
+        model = Model("openai/gpt-4o", weak_model_name=None)
+        self.assertEqual(model.weak_model.name, model.weak_model_name)
+
+        # Test when weak_model_name is False
+        model = Model("openai/gpt-4o", weak_model_name=False)
+        self.assertIsNone(model.weak_model)
+
+        # Test when weak_model_name matches model name
+        model = Model("openai/gpt-undefined", weak_model_name="openai/gpt-undefined")
+        self.assertIs(model.weak_model, model)
+
+        # Test when weak_model_name is different, and none configured
+        model = Model("openai/gpt-undefined", weak_model_name="gpt-3.5-turbo")
+        self.assertNotEqual(model.weak_model, model)
+        self.assertEqual(model.weak_model.name, "openai/gpt-3.5-turbo")
+
+        # Test when weak_model_name is different, and other configured
+        model = Model("openai/gpt-4o", weak_model_name="gpt-3.5-turbo")
+        self.assertNotEqual(model.weak_model, model)
+        self.assertEqual(model.weak_model.name, "openai/gpt-3.5-turbo")
+
+    def test_editor_model_settings(self):
+        # Test when model has no editor model configured, use self
+        model = Model("openai/gpt-undefined", editor_model_name=None)
+        self.assertIs(model.editor_model, model)
+
+        # Test when model has editor model configured
+        model = Model("anthropic/claude-3-5-sonnet-20240620")
+        self.assertEqual(model.editor_model.name, "anthropic/claude-3-5-sonnet-20240620")
+        self.assertIs(model.editor_model, model)
+
+        # Test when editor_model_name is False
+        model = Model("anthropic/claude-3-5-sonnet-20240620", editor_model_name=False)
+        self.assertIsNone(model.editor_model)
+
+        # Test when editor_model_name matches model name
+        model = Model("openai/gpt-4o", editor_model_name="openai/gpt-4o")
+        self.assertIs(model.editor_model, model)
+
+    def test_editor_edit_format(self):
+        # Test when editor_edit_format is provided, override the model settings
+        model = Model("openai/gpt-4o", editor_edit_format="whole")
+        self.assertEqual(model.editor_edit_format, "whole")
+
+        # Test when editor_edit_format is not provided, use the model settings
+        model = Model("openai/gpt-4o")
+        self.assertEqual(model.editor_edit_format, "editor-diff")
+
+        # Test when editor_model_name is provided, use the model settings
+        model = Model("openai/gpt-4o", editor_model_name="openai/gpt-4o")
+        self.assertEqual(model.editor_edit_format, "editor-diff")
+
+        # When editor_model_name and editor_edit_format are provided, override the model settings
+        model = Model(
+            "openai/gpt-4o", editor_model_name="openai/gpt-4o", editor_edit_format="whole"
+        )
+        self.assertEqual(model.editor_edit_format, "whole")
+
+        # When model editor_edit_format is not specified, use the editor_model settings
+        model = Model("openai/gpt-4-turbo", editor_model_name="openai/gpt-4o")
+        self.assertEqual(model.editor_edit_format, "diff")
+
+        # When editor_model_name=False, ignore provided editor_edit_format
+        model = Model("openai/gpt-4-turbo", editor_model_name=False, editor_edit_format="whole")
+        self.assertIsNone(model.editor_edit_format)
+
 
 if __name__ == "__main__":
     unittest.main()
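The editor_edit_format handling added in _set_editor_model_and_format reduces to a three-way fallback. A standalone sketch of that precedence follows; the helper name is hypothetical, and the values mirror the tests above.

    # Precedence: explicit argument, then the format configured for the main model,
    # then the editor model's own edit_format.
    def resolve_editor_edit_format(provided, configured, editor_model_format):
        return provided or configured or editor_model_format

    assert resolve_editor_edit_format("whole", "editor-diff", "diff") == "whole"
    assert resolve_editor_edit_format(None, "editor-diff", "diff") == "editor-diff"
    assert resolve_editor_edit_format(None, None, "diff") == "diff"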