Skip to content

Commit

Permalink
Add evaluate_passkey_retrieval method
Browse files Browse the repository at this point in the history
  • Loading branch information
jshuadvd committed Jul 5, 2024
1 parent 815c717 commit e7da903
Showing 1 changed file with 7 additions and 0 deletions.
7 changes: 7 additions & 0 deletions evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,3 +75,10 @@ def passkey_retrieval_test(model, tokenizer, max_length, num_trials=10):
accuracies[length] = correct_retrievals / num_trials

return accuracies


def evaluate_passkey_retrieval(model, tokenizer, max_length):
accuracies = passkey_retrieval_test(model, tokenizer, max_length)
for length, accuracy in accuracies.items():
print(f"Passkey retrieval accuracy at {length} tokens: {accuracy:.2f}")
wandb.log({f"passkey_retrieval_{length}": accuracy})

0 comments on commit e7da903

Please sign in to comment.