Skip to content

Commit

Permalink
update plotting
Browse files Browse the repository at this point in the history
  • Loading branch information
zahrakolagar committed Oct 10, 2024
1 parent c15f711 commit dead1b0
Showing 1 changed file with 51 additions and 29 deletions.
80 changes: 51 additions & 29 deletions plotting_predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,20 @@
import seaborn as sns
import pandas as pd
from config import PRED_DIR, PLOT_DIR
from dataclasses import dataclass
from typing import List, Dict

@dataclass
class AccuracyData:
file_name: str
step: str
model_size: str
temperature: int
accuracy: float

def extract_accuracies_from_file(file_path):

def extract_accuracies_from_file(file_path: str) -> List[AccuracyData]:
"""Extract accuracy data from a single file."""
accuracies = []
file_name = os.path.basename(file_path)

Expand All @@ -21,51 +32,56 @@ def extract_accuracies_from_file(file_path):
temperature, accuracy = match
step = re.search(r'step(\d+)', file_name).group(1) # Extract step from filename

# Adjusting for model size extraction based on "small" or "middle"
if "small" in file_name:
model_size = "0.5"
elif "middle" in file_name:
model_size = "1.5"
else:
model_size = "unknown" # Default if no pattern matches

accuracies.append({
'file_name': file_name,
'step': step,
'model_size': model_size,
'temperature': int(temperature),
'accuracy': float(accuracy)
})
# Determine model size from the filename
model_size = determine_model_size(file_name)

# Append extracted data as an AccuracyData object
accuracies.append(AccuracyData(
file_name=file_name,
step=step,
model_size=model_size,
temperature=int(temperature),
accuracy=float(accuracy)
))

return accuracies


def determine_model_size(file_name: str) -> str:
"""Determine the model size based on the filename."""
if "small" in file_name:
return "0.5"
elif "middle" in file_name:
return "1.5"
else:
return "unknown" # Default if no pattern matches

def process_files_in_directory(directory_path):

def process_files_in_directory(directory_path: str) -> List[AccuracyData]:
"""Process all text files in a given directory and return a list of AccuracyData."""
all_accuracies = []

for file_name in os.listdir(directory_path):
if file_name.endswith('.txt'):
file_path = os.path.join(directory_path, file_name)
accuracies = extract_accuracies_from_file(file_path)
all_accuracies.extend(accuracies)
print(all_accuracies)

return all_accuracies


def plot_accuracies(accuracies, plot_dir):
df = pd.DataFrame(accuracies)
def plot_accuracies(accuracies: List[AccuracyData], plot_dir: str) -> None:
"""Plot accuracies and save plots to the specified directory."""
df = pd.DataFrame([data.__dict__ for data in accuracies])

# Convert step and temperature to categorical for better plot
# Convert step and temperature to categorical for better plotting
df['temperature'] = df['temperature'].astype(str)

# Loop over unique steps and model sizes
# Loop over unique steps and model sizes to create plots
for step in df['step'].unique():
step_data = df[df['step'] == step]

plt.figure(figsize=(10, 6))

# Create barplot for each step, showing model size in different colors
sns.barplot(x='temperature', y='accuracy', hue='model_size', data=step_data, palette='Set2')

plt.title(f'Accuracy for Different Temperatures - Step {step}')
Expand All @@ -81,10 +97,16 @@ def plot_accuracies(accuracies, plot_dir):
plt.close()


def main():
"""Main function to process accuracy files and generate plots."""
directory_path = PRED_DIR
plots_directory = PLOT_DIR

accuracies_data = process_files_in_directory(directory_path)

# Plot and save the results in the plot directory
plot_accuracies(accuracies_data, plots_directory)

directory_path = PRED_DIR
plots_directory = PLOT_DIR
accuracies_data = process_files_in_directory(directory_path)

# Plot and save the results in the plot directory
plot_accuracies(accuracies_data, plots_directory)
if __name__ == "__main__":
main()

0 comments on commit dead1b0

Please sign in to comment.