import matplotlib.pyplot as plt
import numpy as np
# Model configurations shown along the x-axis, one bar cluster per entry.
models = ['G(L=2)', 'G(L=4)', 'G(L=8)', 'Ours']

# Timing values per model; each list is index-aligned with ``models``.
train_time = [150, 130, 102, 109]
val_time = [17, 15, 12, 13]
test_time = [1.12, 0.96, 0.82, 0.81]

# Cluster centres and the width of a single bar inside a cluster.
x = np.arange(len(models))
width = 0.25

fig, ax = plt.subplots(figsize=(8, 5))

# One (positions, heights, legend label) entry per bar group: the groups sit
# one bar-width left of, on, and one bar-width right of each cluster centre.
bar_specs = [
    (x - width, train_time, 'Train Time'),
    (x, val_time, 'Val Time'),
    (x + width, test_time, 'Test Time'),
]
bars1, bars2, bars3 = [
    ax.bar(positions, heights, width, label=legend_name)
    for positions, heights, legend_name in bar_specs
]

ax.set(
    ylabel='Time (seconds)',
    title='Training, Validation, and Testing Time per Iteration',
)

# Put one tick per cluster and label it with the model name.
ax.set_xticks(x)
ax.set_xticklabels(models)

# The legend is built after the bars exist so it picks up their labels.
ax.legend()
ax.grid(axis='y', linestyle='--', alpha=0.6)
def add_labels(bars, *, axes=None):
    """Write each bar's height as a text label just above the bar.

    The label is the height formatted with two decimals, horizontally
    centred on the bar and shifted 3 points upward from the bar's top.

    Args:
        bars: Iterable of bar objects exposing ``get_height()``, ``get_x()``
            and ``get_width()`` (for example the container returned by
            ``ax.bar``).
        axes: Object whose ``annotate`` method draws the labels. When
            omitted, the module-level ``ax`` is used, which keeps existing
            ``add_labels(bars)`` calls working unchanged.
    """
    # Resolve the drawing target once; the module-level ``ax`` is only
    # looked up when no explicit axes object was supplied.
    target = ax if axes is None else axes
    for bar in bars:
        height = bar.get_height()
        target.annotate(
            f'{height:.2f}',
            # Anchor at the horizontal middle of the bar, at its top edge.
            xy=(bar.get_x() + bar.get_width() / 2, height),
            # Offset is in points, so the gap is independent of data scale.
            xytext=(0, 3),
            textcoords="offset points",
            ha='center',
            va='bottom',
        )
# Label every bar in each of the three groups, in drawing order.
for bar_group in (bars1, bars2, bars3):
    add_labels(bar_group)

# Tighten the layout before displaying the figure.
plt.tight_layout()
plt.show()