mnist-classify/visualise.py
Arija A. 62eefb93e0
Add visualisation.
Signed-off-by: Arija A. <ari@ari.lt>
2025-05-19 21:21:00 +03:00

160 lines
3.8 KiB
Python

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""visualise training statistics"""
import matplotlib.pyplot as plt
import pandas as pd
from matplotlib.animation import FuncAnimation
# Column layout of the training log: whitespace-separated, no header row,
# columns in exactly this order.
column_names = (
    "epoch phase progress val_acc best_val_acc "
    "dropout learning_rate grad_clip l2_lambda"
).split()
# NOTE(review): the log's own "phase" column is loaded but not used for the
# plot labels; those are derived from "progress" via get_phase_lt.
stats_path = "model-traning-stats-98.3p.txt"
data = pd.read_csv(stats_path, names=column_names, header=None, sep=r"\s+")
def get_phase_lt(progress):
    """Return the Lithuanian training-phase label for a progress value.

    Values below 0.3 map to phase 1 and values below 0.6 map to phase 2.
    Everything else maps to phase 3, including NaN, because NaN fails both
    ``<`` comparisons.
    """
    for upper_bound, label in (
        (0.3, "1 fazė: Greitas mokymasis"),
        (0.6, "2 fazė: Stabilus mokymasis"),
    ):
        if progress < upper_bound:
            return label
    return "3 fazė: Reguliarizacija ir pritaikymas"
# Attach the human-readable phase label to every row, derived from its
# training progress value.
data["phase_label"] = [get_phase_lt(p) for p in data["progress"]]
# 3x2 grid of subplots, one per tracked metric (flattened for easy pairing
# with the title list below).
fig, axs = plt.subplots(3, 2, figsize=(12, 10))
fig.suptitle("Mokymo statistika per epochas", fontsize=16)
axs = axs.flatten()

titles = [
    "Validacijos tikslumas",
    "Geriausias validacijos tikslumas",
    "Išmetimo reikšmė",
    "Mokymosi sparta",
    "Gradiento apkarpymas",
    "L2 lambda",
]

# Subplots whose y-range is the column maximum plus 10% headroom.
ylim_columns = {
    "Mokymosi sparta": "learning_rate",
    "Gradiento apkarpymas": "grad_clip",
    "L2 lambda": "l2_lambda",
}
last_epoch = data["epoch"].max()
for ax, title in zip(axs, titles):
    ax.set_title(title)
    ax.set_xlim(0, last_epoch)
    if "tikslumas" in title.lower() or title == "Išmetimo reikšmė":
        # Accuracy and dropout axes are fixed to [0, 1]
        # (assumes both are logged as fractions — confirm against the log).
        ax.set_ylim(0, 1)
    elif title in ylim_columns:
        ax.set_ylim(0, data[ylim_columns[title]].max() * 1.1)
# Animated artists: one initially empty line per subplot, plus a text box in
# its lower-right corner that shows the latest value.
lines = []
value_texts = []
for ax in axs:
    line = ax.plot([], [], lw=2)[0]
    # Position is in axes-fraction coordinates (transAxes), so the box stays
    # in the corner regardless of the data limits.
    txt = ax.text(
        0.95, 0.1, "",
        transform=ax.transAxes,
        ha="right",
        va="bottom",
        fontsize=10,
        bbox={"facecolor": "white", "alpha": 0.7},
    )
    lines.append(line)
    value_texts.append(txt)
# Figure-level status banner at the bottom centre; update() fills it with the
# current epoch, phase label and progress percentage.
phase_text = fig.text(
    0.5, 0.05, "",
    ha="center", va="bottom",
    fontsize=14,
    bbox={"facecolor": "yellow", "alpha": 0.5},
)

# Colour applied to all metric lines while training is in each phase.
# The keys must match the labels returned by get_phase_lt exactly.
phase_colors = {
    "1 fazė: Greitas mokymasis": "tab:blue",
    "2 fazė: Stabilus mokymasis": "tab:orange",
    "3 fazė: Reguliarizacija ir pritaikymas": "tab:green",
}
def init():
    """Reset every animated artist to an empty state.

    Used as FuncAnimation's ``init_func``; returns the artists it touched.
    """
    for artist in lines:
        artist.set_data([], [])
    for label in value_texts:
        label.set_text("")
    phase_text.set_text("")
    return [*lines, *value_texts, phase_text]
def update(frame):
    """Draw one animation frame: all logged data up to and including epoch ``frame``.

    Updates each subplot's line and value read-out and the phase banner.
    It also recolours every line according to the phase of the latest
    visible row.

    Args:
        frame: epoch number supplied by FuncAnimation's ``frames`` iterable.

    Returns:
        The list of artists modified for this frame.
    """
    artists = lines + value_texts + [phase_text]
    current_data = data[data["epoch"] <= frame]
    x = current_data["epoch"]
    y_values = [
        current_data["val_acc"],
        current_data["best_val_acc"],
        current_data["dropout"],
        current_data["learning_rate"],
        current_data["grad_clip"],
        current_data["l2_lambda"],
    ]
    for line, y, txt in zip(lines, y_values, value_texts):
        line.set_data(x, y)
        # Show the latest value rounded to 4 decimals, or nothing if empty.
        txt.set_text(f"{y.values[-1]:.4f}" if len(y) > 0 else "")

    if current_data.empty:
        # No row has epoch <= frame yet (e.g. the log starts after epoch 1).
        # There is no latest row to describe, so blank the banner instead of
        # raising IndexError from iloc[-1].
        phase_text.set_text("")
        return artists

    # Assumes the log rows are in chronological order, so the last
    # filtered row is the most recent epoch — TODO confirm.
    last_row = current_data.iloc[-1]
    phase = last_row["phase_label"]
    progress_pct = last_row["progress"] * 100
    phase_text.set_text(
        f"Epocha: {int(frame)} | {phase} | Progresas: {progress_pct:.1f}%"
    )
    color = phase_colors.get(phase, "black")
    for line in lines:
        line.set_color(color)
    return artists
# 30 frames per second. ``interval`` (ms between frames) only affects live
# playback; for the saved GIF the ``fps`` passed to save() sets the rate.
fps = 30
interval = 1000 / fps  # interval in ms

anim = FuncAnimation(
    fig,
    update,
    # One frame per epoch; int() keeps range() working even if pandas
    # parsed the epoch column as float.
    frames=range(1, int(data["epoch"].max()) + 1),
    init_func=init,
    # Blitting is disabled on purpose. With blit=True, FuncAnimation marks
    # every artist returned by init() as "animated", which excludes it from
    # ordinary figure draws. The figure-level phase_text lives in no Axes,
    # so it can be dropped from the saved frames. save() redraws whole
    # frames anyway, so blitting brings no speed-up here.
    blit=False,
    interval=interval,
)

# Save animation
anim.save("mokymo_statistika.gif", writer="pillow", fps=fps)
print("Vizualizacija išsaugota į mokymo_statistika.gif")
# To preview interactively instead of (or after) saving, call plt.show() here.