forked from ari/mnist-classify
160 lines
3.8 KiB
Python
160 lines
3.8 KiB
Python
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Visualise training statistics as an animated GIF."""
|
|
|
|
import matplotlib.pyplot as plt
|
|
import pandas as pd
|
|
from matplotlib.animation import FuncAnimation
|
|
|
|
# Column layout of the training log, which is read without a header row.
# First the per-row bookkeeping and accuracy columns ...
column_names = ["epoch", "phase", "progress", "val_acc", "best_val_acc"]
# ... then the hyper-parameter columns.
column_names += ["dropout", "learning_rate", "grad_clip", "l2_lambda"]

# Columns are separated by runs of whitespace.
# NOTE(review): "traning" in the file name looks like a typo of "training";
# confirm against the actual log file on disk before renaming anything.
data = pd.read_csv(
    "model-traning-stats-98.3p.txt",
    sep=r"\s+",
    header=None,
    names=column_names,
)
|
|
|
|
|
|
def get_phase_lt(progress):
    """Return the Lithuanian training-phase label for a progress value.

    Args:
        progress: Training progress, compared against the 0.3 and 0.6
            cut-offs.

    Returns:
        The phase-1 label when ``progress < 0.3``, the phase-2 label when
        ``progress < 0.6``, and the phase-3 label otherwise.
    """
    # Exclusive upper bounds in ascending order; the first match wins.
    upper_bounds = (
        (0.3, "1 fazė: Greitas mokymasis"),
        (0.6, "2 fazė: Stabilus mokymasis"),
    )
    for bound, label in upper_bounds:
        if progress < bound:
            return label
    # Anything not below a bound belongs to the final phase.
    return "3 fazė: Reguliarizacija ir pritaikymas"
|
|
|
|
|
|
# Tag every logged row with the phase label derived from its "progress" value.
progress_values = data["progress"]
data["phase_label"] = progress_values.apply(get_phase_lt)
|
|
|
|
# Create a 3x2 grid of panels and flatten it so the panels can be
# iterated in order as a 1-D sequence.
fig, axs = plt.subplots(nrows=3, ncols=2, figsize=(12, 10))
axs = axs.flatten()
fig.suptitle("Mokymo statistika per epochas", fontsize=16)
|
|
|
|
# Panel captions (Lithuanian), paired with the flattened axes in order.
titles = [
    "Validacijos tikslumas",
    "Geriausias validacijos tikslumas",
    "Išmetimo reikšmė",
    "Mokymosi sparta",
    "Gradiento apkarpymas",
    "L2 lambda",
]

# Panels whose y-range is scaled from the data: caption -> source column.
scaled_panels = {
    "Mokymosi sparta": "learning_rate",
    "Gradiento apkarpymas": "grad_clip",
    "L2 lambda": "l2_lambda",
}

for ax, title in zip(axs, titles):
    ax.set_title(title)
    # Every panel spans from 0 to the last logged epoch.
    ax.set_xlim(0, data["epoch"].max())

    # Accuracy panels and the dropout panel use a fixed 0..1 y-range.
    if "tikslumas" in title.lower() or title == "Išmetimo reikšmė":
        ax.set_ylim(0, 1)
        continue

    # The remaining panels get 10 % headroom above their column maximum.
    scaled_column = scaled_panels.get(title)
    if scaled_column is not None:
        ax.set_ylim(0, data[scaled_column].max() * 1.1)
|
|
|
|
# One initially empty curve per panel; the animation callbacks fill it in.
lines = [ax.plot([], [], lw=2)[0] for ax in axs]

# One label per panel, positioned in axes coordinates near the bottom-right
# corner, used to display the most recent plotted value.
value_texts = [
    ax.text(
        0.95,
        0.1,
        "",
        transform=ax.transAxes,
        ha="right",
        va="bottom",
        fontsize=10,
        bbox=dict(facecolor="white", alpha=0.7),
    )
    for ax in axs
]
|
|
|
|
# Curve colour associated with each phase label.
phase_colors = {
    "1 fazė: Greitas mokymasis": "tab:blue",
    "2 fazė: Stabilus mokymasis": "tab:orange",
    "3 fazė: Reguliarizacija ir pritaikymas": "tab:green",
}

# Figure-level status banner, horizontally centred near the bottom edge.
phase_text = fig.text(
    0.5, 0.05, "",
    ha="center", va="bottom",
    fontsize=14,
    bbox=dict(facecolor="yellow", alpha=0.5),
)
|
|
|
|
|
|
def init():
    """Blank every animated artist and return them all.

    Returns:
        list: The curves, the value labels and the phase banner, in that
        order.
    """
    for curve in lines:
        curve.set_data([], [])
    for label in value_texts:
        label.set_text("")
    phase_text.set_text("")
    return [*lines, *value_texts, phase_text]
|
|
|
|
|
|
def update(frame):
    """Draw every panel using the rows logged up to and including ``frame``.

    Args:
        frame: Epoch number for this animation frame.

    Returns:
        list: The curves, the value labels and the phase banner, in that
        order (the artists this callback may have modified).
    """
    current_data = data[data["epoch"] <= frame]
    artists = lines + value_texts + [phase_text]

    # No row has epoch <= frame (e.g. the log starts after this epoch).
    # Without this guard ``iloc[-1]`` below raises IndexError, so show
    # blank panels for this frame instead.
    if current_data.empty:
        for line, txt in zip(lines, value_texts):
            line.set_data([], [])
            txt.set_text("")
        phase_text.set_text("")
        return artists

    x = current_data["epoch"]

    # Columns listed in the same order as the panels and ``lines``.
    plotted_columns = (
        "val_acc",
        "best_val_acc",
        "dropout",
        "learning_rate",
        "grad_clip",
        "l2_lambda",
    )
    for line, column, txt in zip(lines, plotted_columns, value_texts):
        y = current_data[column]
        line.set_data(x, y)
        # The label shows the most recent value, to four decimal places.
        txt.set_text(f"{y.values[-1]:.4f}")

    # The banner reports the phase and progress of the latest visible row.
    last_row = current_data.iloc[-1]
    phase = last_row["phase_label"]
    progress_pct = last_row["progress"] * 100
    phase_text.set_text(
        f"Epocha: {int(frame)} | {phase} | Progresas: {progress_pct:.1f}%"
    )

    # All curves share the colour of the current phase; an unrecognised
    # label falls back to black.
    color = phase_colors.get(phase, "black")
    for line in lines:
        line.set_color(color)

    return artists
|
|
|
|
|
|
# Playback rate, and the matching per-frame delay in milliseconds.
fps = 30
interval = 1000 / fps

# One frame per epoch number, from 1 through the last logged epoch.
anim = FuncAnimation(
    fig=fig,
    func=update,
    frames=range(1, data["epoch"].max() + 1),
    init_func=init,
    interval=interval,
    blit=True,
)

# Write the animation out as a GIF using the Pillow writer.
anim.save("mokymo_statistika.gif", writer="pillow", fps=fps)
print("Vizualizacija išsaugota į mokymo_statistika.gif")
# To preview interactively instead of only saving, call plt.show() here.
|