Spaces:
Runtime error
Runtime error
# Plot per-utterance word error rate (WER) from a CSV file.
#
# Usage: python <this script> WER_CSV
#
# WER_CSV is read with pandas and must provide an ``id`` and a ``wer`` column.
# The figure (WER histogram on top, per-id WER scatter below) is saved as
# "<WER_CSV>.png", i.e. ".png" is appended to the full input path
# (e.g. "results/wer.csv" -> "results/wer.csv.png").
from pathlib import Path
import sys

import numpy as np
import pandas as pd

# WER cut-off: values below it are drawn green, values at or above it red,
# and a red reference line is drawn at this value on both panels.
threshold = 0.3


def short_ids(ids):
    """Return each id as a string truncated at its first '.' ('utt1.wav' -> 'utt1')."""
    return [str(x).split('.')[0] for x in ids]


def point_colors(wers, limit):
    """Return 'green' for each WER strictly below *limit*, 'red' otherwise."""
    return ['green' if x < limit else 'red' for x in wers]


def plot_wer(df, limit=threshold):
    """Build the two-panel WER figure for *df* and return the matplotlib Figure.

    Args:
        df: DataFrame with an ``id`` column (x-axis labels of the lower panel)
            and a numeric ``wer`` column.
        limit: WER threshold marked by the red reference lines and used to
            colour the scatter points.

    Returns:
        The created ``matplotlib.figure.Figure``; the caller saves or shows it.

    Raises:
        KeyError: if ``id`` or ``wer`` is missing from *df*.
        ValueError: if ``wer`` has no non-NaN values (axis limits would be NaN).
    """
    wer = df['wer']
    if wer.dropna().empty:
        raise ValueError("no WER values to plot")
    ids = short_ids(df['id'])

    # Imported lazily so the pure helpers above can be used without matplotlib.
    import matplotlib.pyplot as plt

    fig, (ax_hist, ax_ids) = plt.subplots(nrows=2, ncols=1, figsize=(25, 15))

    # Top panel: distribution of WER values.
    ax_hist.set_xlabel("Word Error Rate")
    ax_hist.set_ylabel("Counts")
    # Extend the axis to the threshold so its marker stays visible even when
    # every WER is below it.
    ax_hist.set_xlim(left=0.0, right=max(float(wer.max()), limit))
    ax_hist.hist(wer.dropna(), bins=50)
    ax_hist.axvline(x=limit, color="r")

    # Bottom panel: one lollipop per utterance, coloured by the threshold.
    ax_ids.set_xlabel("IDs")
    ax_ids.set_ylabel("Word Error Rate")
    ax_ids.scatter(ids, wer, c=point_colors(wer, limit), marker='o')
    ax_ids.vlines(ids, ymin=0, ymax=wer, colors='grey', linestyle='dotted')
    # axhline's xmin/xmax are axes fractions (0..1); the default spans the
    # full width, which is what is wanted here.
    ax_ids.axhline(y=limit, color='r')
    # Restyle the existing categorical tick labels in place; set_xticklabels
    # without set_xticks triggers a FixedFormatter warning in matplotlib.
    ax_ids.tick_params(axis='x', labelrotation=90, labelsize=10, width=20)

    fig.tight_layout()
    return fig


def main(argv=None):
    """Read the CSV named by the first CLI argument and save '<csv>.png'.

    Args:
        argv: argument list excluding the program name; defaults to
            ``sys.argv[1:]``. Extra arguments after the first are ignored.

    Raises:
        SystemExit: with a usage message when no CSV path is given.
    """
    argv = sys.argv[1:] if argv is None else argv
    if not argv:
        sys.exit(f"usage: {Path(sys.argv[0]).name} WER_CSV")
    wer_csv = argv[0]
    fig = plot_wer(pd.read_csv(wer_csv))
    fig.savefig(f"{wer_csv}.png", format='png')


if __name__ == "__main__":
    main()