import numpy as np
import matplotlib.pyplot as plt
# Residue sequence, one single-letter code per position (95 entries),
# written as one string and split into a list of characters.
amino_acids = list(
    "MLIVNYLGDLWQVTQTTNSLEKREQFRGLYLNKISEYSGCITKSLDEKLL"
    "GPILCESFFVNGLAQLYDKKQSQHQVAEAKRVMYIVAKEFNADYD"
)
# Attention score per position (96 entries - one more than amino_acids).
attention_scores = [
    0.66119576, 0.8003918, 0.7251454, 0.72951543, 0.65185624, 0.70199203,
    0.80630046, 0.6887356, 0.6589288, 0.7341603, 0.7876512, 0.7288329,
    0.69848996, 0.71069247, 0.67025244, 0.73884994, 0.5722088, 0.8040074,
    0.5754876, 0.71131456, 0.6967427, 0.7382109, 0.81139785, 0.8124091,
    0.68958503, 0.7973248, 0.81046563, 0.6905091, 0.7415714, 0.70773214,
    0.7040749, 0.66361755, 0.75876045, 0.7386744, 0.7828561, 0.7754259,
    0.58837545, 0.7422827, 0.8162603, 0.7287266, 0.735986, 0.7265348,
    0.7996047, 0.6885884, 0.7867183, 0.7414569, 0.5857099, 0.70292705,
    0.76017207, 0.73468393, 0.73548526, 0.7142082, 0.691998, 0.7387566,
    0.70642775, 0.7064969, 0.71976453, 0.71234685, 0.6858974, 0.632145,
    0.6200939, 0.80994266, 0.70530456, 0.7354963, 0.8044978, 0.80209994,
    0.7402193, 0.6048866, 0.73206097, 0.6948871, 0.7328906, 0.8089224,
    0.79542226, 0.701279, 0.7086161, 0.80316234, 0.737327, 0.6821824,
    0.7155741, 0.7198679, 0.81061196, 0.61442053, 0.8037533, 0.63040495,
    0.7192761, 0.76812285, 0.7024652, 0.70812845, 0.80918, 0.64278257,
    0.7862615, 0.81110525, 0.77845424, 0.7100564, 0.78630733, 0.8025117,
]
# Lay the scores out as a grid of `row_len` positions per row. Ceiling
# division keeps a final partial row; plain `// 30` would silently drop the
# trailing scores (e.g. the last 6 of 96). The partial row is padded with NaN
# so reshape() receives a full rectangle.
row_len = 30
num_rows = (len(attention_scores) + row_len - 1) // row_len
padded = np.full(num_rows * row_len, np.nan)
padded[:len(attention_scores)] = attention_scores
data = padded.reshape(num_rows, row_len)

# Fixed 0..1 colour range; NaN padding is drawn with the colormap's "bad"
# colour (transparent by default in matplotlib).
plt.imshow(data, cmap='hot', vmin=0, vmax=1)

# Write each score into its cell, skipping the NaN padding so no "nan" text appears.
for i in range(num_rows):
    for j in range(row_len):
        if not np.isnan(data[i, j]):
            plt.text(j, i, f'{data[i, j]:.2f}', ha='center', va='center', color='black')

# Axis labels, title, ticks and colour bar.
plt.xlabel('Amino Acids')
plt.ylabel('Rows')
plt.title('Attention Visualization')
# NOTE(review): a single imshow shares one x-axis, so these tick labels name
# only the first row's residues; rows 2+ hold different residues.
plt.xticks(range(row_len), amino_acids[:row_len])
plt.yticks(range(num_rows), ['Row {}'.format(row+1) for row in range(num_rows)])
cbar = plt.colorbar()
cbar.set_label('Color Scale')
plt.show()
我尝试用这个脚本生成可视化图。
但它生成的图是这样的（原图片未保留），
而我想要的是类似下面这样的效果（原图片未保留）：
如果把每一行都画成单独的子图，绘图会容易得多。Seaborn 的热图（heatmap）会根据单元格的颜色自动选用黑色或白色的注释文字。
下面是一些示例代码。请注意，示例数据中的 `amino_acids` 似乎少了一项：它只有 95 个元素，而 `attention_scores` 有 96 个。
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
# Residue sequence, one single-letter code per position (95 entries),
# written as one string and split into a list of characters.
amino_acids = list(
    "MLIVNYLGDLWQVTQTTNSLEKREQFRGLYLNKISEYSGCITKSLDEKLL"
    "GPILCESFFVNGLAQLYDKKQSQHQVAEAKRVMYIVAKEFNADYD"
)
# Attention score per position (96 entries - one more than amino_acids).
attention_scores = [
    0.66119576, 0.8003918, 0.7251454, 0.72951543, 0.65185624, 0.70199203,
    0.80630046, 0.6887356, 0.6589288, 0.7341603, 0.7876512, 0.7288329,
    0.69848996, 0.71069247, 0.67025244, 0.73884994, 0.5722088, 0.8040074,
    0.5754876, 0.71131456, 0.6967427, 0.7382109, 0.81139785, 0.8124091,
    0.68958503, 0.7973248, 0.81046563, 0.6905091, 0.7415714, 0.70773214,
    0.7040749, 0.66361755, 0.75876045, 0.7386744, 0.7828561, 0.7754259,
    0.58837545, 0.7422827, 0.8162603, 0.7287266, 0.735986, 0.7265348,
    0.7996047, 0.6885884, 0.7867183, 0.7414569, 0.5857099, 0.70292705,
    0.76017207, 0.73468393, 0.73548526, 0.7142082, 0.691998, 0.7387566,
    0.70642775, 0.7064969, 0.71976453, 0.71234685, 0.6858974, 0.632145,
    0.6200939, 0.80994266, 0.70530456, 0.7354963, 0.8044978, 0.80209994,
    0.7402193, 0.6048866, 0.73206097, 0.6948871, 0.7328906, 0.8089224,
    0.79542226, 0.701279, 0.7086161, 0.80316234, 0.737327, 0.6821824,
    0.7155741, 0.7198679, 0.81061196, 0.61442053, 0.8037533, 0.63040495,
    0.7192761, 0.76812285, 0.7024652, 0.70812845, 0.80918, 0.64278257,
    0.7862615, 0.81110525, 0.77845424, 0.7100564, 0.78630733, 0.8025117,
]
# Draw one single-row seaborn heatmap per `row_len` positions, stacked as
# subplots. The last row may be shorter; it is padded with NaN, which seaborn
# masks (leaves blank) automatically.
row_len = 30
num_rows = (len(attention_scores) + row_len - 1) // row_len  # ceiling division
# squeeze=False always returns a 2-D array of Axes. Without it, a single row
# yields a bare Axes object and the enumerate() loop below would fail.
fig, axs = plt.subplots(nrows=num_rows, figsize=(10, num_rows), squeeze=False)
axs = axs[:, 0]
# One shared 0..1 normalisation so colours are comparable across all rows.
norm = plt.Normalize(vmin=0, vmax=1)
for i, ax in enumerate(axs):
    start = i * row_len
    # Slicing copies, so padding `data` below never modifies attention_scores.
    data = attention_scores[start:start + row_len]
    data_len = len(data)
    data += (row_len - data_len) * [np.nan]  # pad to full length
    sns.heatmap(data=[data],
                cmap='turbo', norm=norm,
                xticklabels=amino_acids[start:start + row_len],
                yticklabels=[f'Row\n{start}-{start + data_len - 1}'],
                square=True, annot=True, annot_kws={'fontsize': 6},
                cbar=False, ax=ax)
    ax.tick_params(length=0, rotation=0)
# All heatmaps share `norm`, so the first one's mesh can drive a single
# colour bar attached to every row.
plt.colorbar(axs[0].collections[0], ax=axs[:], label='Color Scale')
axs[-1].set_xlabel('Amino Acids')  # label for the lowest x-axis
plt.suptitle('Attention Visualization')  # overall title
plt.show()