refactor: Simplify weak and editor model configuration
Finndersen committed Jan 10, 2025
1 parent 6f82661 commit 340aa34
Showing 4 changed files with 118 additions and 49 deletions.
5 changes: 3 additions & 2 deletions aider/main.py
@@ -217,6 +217,7 @@ def check_streamlit_install(io):

def write_streamlit_credentials():
from streamlit.file_util import get_streamlit_file_path

# See https://github.com/Aider-AI/aider/issues/772

credential_path = Path(get_streamlit_file_path()) / "credentials.toml"
@@ -753,8 +754,8 @@ def get_io(pretty):

main_model = models.Model(
args.model,
weak_model=args.weak_model,
editor_model=args.editor_model,
weak_model_name=args.weak_model,
editor_model_name=args.editor_model,
editor_edit_format=args.editor_edit_format,
)

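For context, the call site above now passes the renamed keyword arguments. A minimal sketch of how the new constructor is invoked, with illustrative model names and an import path that is assumed rather than shown in this diff:

from aider import models  # import path assumed; not part of the diff above

# Keyword names follow the new signature introduced in aider/models.py below.
# The concrete model names are placeholders, not recommendations.
main_model = models.Model(
    "gpt-4o",
    weak_model_name="gpt-4o-mini",   # pass False to disable the weak model entirely
    editor_model_name=None,          # None falls back to the configured editor model
    editor_edit_format=None,         # None falls back to the editor model's edit format
)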
86 changes: 41 additions & 45 deletions aider/models.py
@@ -744,20 +744,18 @@ def get_model_info(self, model):


class Model(ModelSettings):
def __init__(self, model, weak_model=None, editor_model=None, editor_edit_format=None):
def __init__(
self, model_name, weak_model_name=None, editor_model_name=None, editor_edit_format=None
):
# Map any alias to its canonical name
model = MODEL_ALIASES.get(model, model)

self.name = model
model_name = MODEL_ALIASES.get(model_name, model_name)

self.max_chat_history_tokens = 1024
self.weak_model = None
self.editor_model = None
self.name = model_name

# Find the extra settings
self.extra_model_settings = get_model_settings("aider/extra_params")

self.info = self.get_model_info(model)
self.info = self.get_model_info(model_name)

# Are all needed keys/params available?
res = self.validate_environment()
@@ -769,33 +767,28 @@ def __init__(self, model, weak_model=None, editor_model=None, editor_edit_format
# with minimum 1k and maximum 8k
self.max_chat_history_tokens = min(max(max_input_tokens / 16, 1024), 8192)

self.configure_model_settings(model)
if weak_model is False:
self.weak_model_name = None
else:
self.get_weak_model(weak_model)
self.configure_model_settings(model_name)

if editor_model is False:
self.editor_model_name = None
else:
self.get_editor_model(editor_model, editor_edit_format)
self._set_weak_model(weak_model_name)

def get_model_info(self, model):
return model_info_manager.get_model_info(model)
self._set_editor_model_and_format(editor_model_name, editor_edit_format)

def get_model_info(self, model_name):
return model_info_manager.get_model_info(model_name)

def _copy_fields(self, source):
"""Helper to copy fields from a ModelSettings instance to self"""
for field in fields(ModelSettings):
val = getattr(source, field.name)
setattr(self, field.name, val)

def configure_model_settings(self, model):
def configure_model_settings(self, model_name):
# Look for exact model match
if ms := get_model_settings(model):
if ms := get_model_settings(model_name):
self._copy_fields(ms)
else:
# If no exact match, try generic settings
self.apply_generic_model_settings(model.lower())
self.apply_generic_model_settings(model_name.lower())

# Apply override settings last if they exist
if self.extra_model_settings and self.extra_model_settings.extra_params:
@@ -869,43 +862,46 @@ def apply_generic_model_settings(self, model):
def __str__(self):
return self.name

def get_weak_model(self, provided_weak_model_name):
def _set_weak_model(self, provided_weak_model_name):
if provided_weak_model_name is False:
self.weak_model = None
self.weak_model_name = None
return

# If weak_model_name is provided, override the model settings
if provided_weak_model_name:
self.weak_model_name = provided_weak_model_name
self.weak_model_name = provided_weak_model_name or self.weak_model_name

if (not self.weak_model_name) or (self.weak_model_name == self.name):
if (self.weak_model_name is None) or (self.weak_model_name == self.name):
self.weak_model = self
return

self.weak_model = Model(
self.weak_model_name,
weak_model=False,
)
return self.weak_model
else:
self.weak_model = Model(
self.weak_model_name,
weak_model_name=False,
)

def commit_message_models(self):
return [self.weak_model, self]

def get_editor_model(self, provided_editor_model_name, editor_edit_format):
def _set_editor_model_and_format(self, provided_editor_model_name, provided_editor_edit_format):
if provided_editor_model_name is False:
self.editor_model = None
self.editor_model_name = None
return

# If editor_model_name is provided, override the model settings
if provided_editor_model_name:
self.editor_model_name = provided_editor_model_name
if editor_edit_format:
self.editor_edit_format = editor_edit_format
self.editor_model_name = provided_editor_model_name or self.editor_model_name

if not self.editor_model_name or self.editor_model_name == self.name:
if (self.editor_model_name is None) or (self.editor_model_name == self.name):
self.editor_model = self
else:
self.editor_model = Model(
self.editor_model_name,
editor_model=False,
provided_editor_model_name,
editor_model_name=False,
)

if not self.editor_edit_format:
self.editor_edit_format = self.editor_model.edit_format

return self.editor_model
self.editor_edit_format = (
provided_editor_edit_format or self.editor_edit_format or self.editor_model.edit_format
)

def tokenizer(self, text):
return litellm.encode(model=self.name, text=text)
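As a reading aid, the refactored helpers collapse the old branching into "explicitly provided value, else configured value" fallbacks. A minimal standalone sketch of that precedence, with hypothetical function names and the surrounding class machinery from the diff omitted:

def resolve_weak_model_name(provided_name, configured_name):
    # Mirrors _set_weak_model: False disables the weak model outright,
    # None means "use whatever the model settings already configure".
    if provided_name is False:
        return None
    return provided_name or configured_name


def resolve_editor_edit_format(provided_format, configured_format, editor_model_format):
    # Mirrors the final assignment in _set_editor_model_and_format:
    # the explicit argument wins, then the model's own setting, then the
    # editor model's edit format. (When the editor model is disabled with
    # False, the real method returns early and skips this step.)
    return provided_format or configured_format or editor_model_format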
4 changes: 2 additions & 2 deletions benchmark/benchmark.py
@@ -753,8 +753,8 @@ def run_test_real(

main_model = models.Model(
model_name,
weak_model=weak_model_name,
editor_model=editor_model,
weak_model_name=weak_model_name,
editor_model_name=editor_model,
editor_edit_format=editor_edit_format,
)

72 changes: 72 additions & 0 deletions tests/basic/test_models.py
@@ -206,6 +206,78 @@ def test_get_model_settings_invalid(self):
self.assertIsNone(get_model_settings("openai/invalid-model"))
self.assertIsNone(get_model_settings("not-a-provider/gpt-4"))

def test_weak_model_settings(self):
# When weak_model is None and no weak model configured, use self
model = Model("openai/gpt-undefined", weak_model_name=None)
self.assertIs(model.weak_model, model)

# When weak_model is None and model has weak model configured
model = Model("openai/gpt-4o", weak_model_name=None)
self.assertEqual(model.weak_model.name, model.weak_model_name)

# Test when weak_model is False
model = Model("openai/gpt-4o", weak_model_name=False)
self.assertIsNone(model.weak_model)

# Test when weak_model_name matches model name
model = Model("openai/gpt-undefined", weak_model_name="openai/gpt-undefined")
self.assertIs(model.weak_model, model)

# Test when weak_model_name is different, and none configured
model = Model("openai/gpt-undefined", weak_model_name="gpt-3.5-turbo")
self.assertNotEqual(model.weak_model, model)
self.assertEqual(model.weak_model.name, "openai/gpt-3.5-turbo")

# Test when weak_model_name is different, and other configured
model = Model("openai/gpt-4o", weak_model_name="gpt-3.5-turbo")
self.assertNotEqual(model.weak_model, model)
self.assertEqual(model.weak_model.name, "openai/gpt-3.5-turbo")

def test_editor_model_settings(self):
# Test when model has no editor model configured, use self
model = Model("openai/gpt-undefined", editor_model_name=None)
self.assertIs(model.editor_model, model)

# Test when model has editor model configured
model = Model("anthropic/claude-3-5-sonnet-20240620")
self.assertEqual(model.editor_model.name, "anthropic/claude-3-5-sonnet-20240620")
self.assertIs(model.editor_model, model)

# Test when editor_model is False
model = Model("anthropic/claude-3-5-sonnet-20240620", editor_model_name=False)
self.assertIsNone(model.editor_model)

# Test when editor_model_name matches model name
model = Model("openai/gpt-4o", editor_model_name="openai/gpt-4o")
self.assertIs(model.editor_model, model)

def test_editor_edit_format(self):
# Test when editor_edit_format is provided, override the model settings
model = Model("openai/gpt-4o", editor_edit_format="whole")
self.assertEqual(model.editor_edit_format, "whole")

# Test when editor_edit_format is not provided, use the model settings
model = Model("openai/gpt-4o")
self.assertEqual(model.editor_edit_format, "editor-diff")

# Test when editor_model_name is provided, use the model settings
model = Model("openai/gpt-4o", editor_model_name="openai/gpt-4o")
self.assertEqual(model.editor_edit_format, "editor-diff")

# When editor_model_name and editor_edit_format are provided, they override the model settings
model = Model(
"openai/gpt-4o", editor_model_name="openai/gpt-4o", editor_edit_format="whole"
)
self.assertEqual(model.editor_edit_format, "whole")

# When model editor_edit_format is not specified, use the editor_model settings
model = Model("openai/gpt-4-turbo", editor_model_name="openai/gpt-4o")
self.assertEqual(model.editor_edit_format, "diff")

# When editor_model_name=False, ignore provided editor_edit_format
model = Model("openai/gpt-4-turbo", editor_model_name=False, editor_edit_format="whole")
self.assertIsNone(model.editor_edit_format)


if __name__ == "__main__":
unittest.main()
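
These additions use the standard unittest framework (note the unittest.main() entry point above). One way to run just this file, assuming it is importable from the repository root, is the sketch below; the discovery path is taken from the file listing above:

import unittest

# Discover and run only tests/basic/test_models.py; verbosity=2 prints each test name.
suite = unittest.defaultTestLoader.discover("tests/basic", pattern="test_models.py")
unittest.TextTestRunner(verbosity=2).run(suite)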
