# NOTE(review): this excerpt appears to be a pasted listing -- the leading
# numbers (104-139) look like line numbers from the original file, and
# lines 109-114, 133-134 and 136 are missing, so the body of the `if` on
# line 108 and anything between lines 132/135/137 are not visible here.
# The enclosing function and the definitions of cols, data, uni_accuracies,
# figsize, path, mul, mpc, plt, np and os are also outside this excerpt.
#
# For every unordered pair of column indices (idx, idx2 > idx) this builds a
# grid of subplots and saves it to '{path}/{cols[idx2]}_{cols[idx]}.png'.
# As written on line 115, the grid has one row per unique value of
# data[idx2] and one column per unique value of data[idx]. The missing
# lines 109-114 may alter this; that cannot be told from here.
104 for idx
in range(len(cols)):
105 for idx2
in range(idx + 1, len(cols)):
# NOTE(review): unique1 depends only on idx, yet it is recomputed on every
# inner iteration. It could be hoisted to the outer loop, unless the missing
# lines 109-114 reassign it.
106 unique1 = np.unique(data[idx])
107 unique2 = np.unique(data[idx2])
# The body of this `if` (original lines 109-114) is not part of this excerpt.
108 if len(unique2) < len(unique1):
# NOTE(review): plt.subplots squeezes its result by default. When either
# column has a single unique value, `axs` is a 1-D array (or a bare Axes
# when both do), and `axs[uni2_idx][uni1_idx]` below then fails. Passing
# squeeze=False keeps `axs` 2-D in every case.
115 fig, axs = plt.subplots(nrows=len(unique2), ncols=len(unique1), sharex=
'all', sharey=
'all', figsize=figsize)
116 fig.suptitle(f
'↓ {cols[idx2]} → {cols[idx]}')
117 for uni1_idx, uni1
in enumerate(unique1):
118 for uni2_idx, uni2
in enumerate(unique2):
# For this (uni1, uni2) cell, acc_sum[ind] counts the samples that meet all
# three conditions: the idx value equals uni1, the idx2 value equals uni2,
# and the last entry equals uni_accuracies[ind].
# Iterating data.transpose() suggests data is laid out feature-major
# (data[i] is feature i), with the accuracy as the last feature.
# This is an assumption; confirm it against the caller.
# NOTE(review): this rescans every sample for every (uni1, uni2, acc)
# combination. A boolean mask with np.count_nonzero would avoid the
# Python-level scan.
119 acc_sum = list(np.zeros(uni_accuracies.shape))
120 for ind, acc
in enumerate(uni_accuracies):
121 count_acc = [
True for dat
in data.transpose()
122 if dat[idx] == uni1
and dat[idx2] == uni2
and dat[-1] == acc]
123 acc_sum[ind] = len(count_acc)
# Count-weighted mean of the shifted/scaled accuracies for this cell:
# map(mul, ...) multiplies each count by its scaled accuracy (mul is
# presumably operator.mul), and the result is divided by the total count.
# NOTE(review): (x - uni_accuracies[0]) / uni_accuracies[-1] only spans the
# full [0, 1] range when uni_accuracies[0] == 0. That reading assumes
# uni_accuracies is sorted ascending, e.g. from np.unique. A min-max scale
# would divide by (uni_accuracies[-1] - uni_accuracies[0]); confirm which
# was intended.
124 avg_acc = uni_accuracies - uni_accuracies[0]
125 avg_acc = avg_acc/uni_accuracies[-1]
126 avg_acc = list(map(mul, acc_sum, avg_acc))
# NOTE(review): if no sample matches this (uni1, uni2) pair, sum(acc_sum) is
# 0 and this divides 0 by 0. With numpy scalars that yields nan plus a
# RuntimeWarning, and the nan then feeds an invalid hue into hsv_to_rgb
# below. TODO: guard empty cells.
127 avg_acc = sum(avg_acc)/sum(acc_sum)
# Cumulative counts over uni_accuracies; this is the curve drawn in the cell.
128 acc_sum = np.cumsum(acc_sum)
# hue = 2/3 + avg_acc/3 maps avg_acc 0 to hue 2/3 (blue) and avg_acc 1 to
# hue 1.0 (red, wrapping through magenta). Saturation and value are both 1.
129 hue = (-(1/3)+avg_acc/3)+1
130 hsv_colors = (hue, 1, 1)
# Row index is the position of uni2 in unique2; column index is the
# position of uni1 in unique1.
131 ax = axs[uni2_idx][uni1_idx]
132 ax.plot(uni_accuracies, acc_sum, c=mpc.hsv_to_rgb(hsv_colors))
# Original lines 133-134 and 136 are missing here. They may have restricted
# which cells get a title / y-label; that cannot be told from this excerpt.
135 ax.set_title(f
'{uni1:.4f}')
137 ax.set_ylabel(f
'{uni2:.4f}')
# NOTE(review): the figure is never closed in the visible lines, so one
# figure per column pair stays open (matplotlib warns once more than 20 are
# open). Consider calling plt.close(fig) after savefig. os.makedirs could
# also be hoisted out of the loops, since path is not reassigned in the
# visible lines.
138 os.makedirs(path, exist_ok=
True)
139 plt.savefig(f
'{path}/{cols[idx2]}_{cols[idx]}.png', bbox_inches=
'tight')