I want to create a 3D plot of some data. First, I download the data:
# Import the libraries
import numpy as np
import pandas as pd
import pandas_datareader as pdr
import datetime
import matplotlib.pyplot as plt
# Import the data
start = datetime.date(2020, 1, 1)
end = datetime.date.today()
assets = {'DGS3MO': '3m', 'DGS6MO': '6m', 'DGS1': '1y', 'DGS2': '2y', 'DGS3': '3y', 'DGS5': '5y', 'DGS7': '7y', 'DGS10': '10y', 'DGS20': '20y', 'DGS30': '30y'}
df_yields = pdr.DataReader(list(assets.keys()), 'fred', start, end)
df_yields.rename(columns=assets, inplace=True)
df_yields = df_yields.dropna(how='any')
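The plotting code that follows only needs df_yields to be a DataFrame with one row per date (a DatetimeIndex) and one column per maturity, in the same order as the maturities array. To try it without network access, here is a minimal sketch of an offline stand-in. The yields are made up, and the names maturity_labels, level and curve are hypothetical, not part of the FRED data:

```python
import numpy as np
import pandas as pd

# Hypothetical offline stand-in for the FRED download. The yields are synthetic,
# but the layout matches df_yields after rename/dropna: rows are dates,
# columns are maturities in the same order as the `maturities` array.
maturity_labels = ['3m', '6m', '1y', '2y', '3y', '5y', '7y', '10y', '20y', '30y']
maturities = np.array([3/12, 6/12, 1, 2, 3, 5, 7, 10, 20, 30])

index = pd.bdate_range('2020-01-01', '2022-12-30')  # business days only
level = 1.5 + np.sin(np.linspace(0, 3 * np.pi, len(index)))  # made-up level drifting over time
curve = 0.5 * np.log1p(maturities)  # made-up upward-sloping term structure
df_yields = pd.DataFrame(level[:, None] + curve[None, :], index=index, columns=maturity_labels)

# pandas Timestamps subclass datetime.datetime, so .toordinal() works as in the question
dates = np.array([x.toordinal() for x in df_yields.index])
```

Replacing the FRED block with this stand-in should leave the plotting code unchanged, since it only reads df_yields.index and df_yields.values.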
Then I create the 3D plot:
# Create 3D plot
dates = np.array([x.toordinal() for x in df_yields.index])
maturities = np.array([3/12, 6/12, 1, 2, 3, 5, 7, 10, 20, 30])
X, Y = np.meshgrid(dates, maturities)
Z = df_yields.values.T
fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection='3d')
ax.set_box_aspect(aspect=(0.9, 0.5, 0.4))
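# rstride=1 draws every maturity row; with only 10 rows, rstride=10 samples just the 3m and 30y rows
surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=10, cmap='Blues', edgecolor='black', linewidth=0.5)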
ax.set_xlabel('Date')
ax.set_ylabel('Maturity (Y)')
ax.set_zlabel('Yield (%)')
years = np.arange(start.year, end.year + 1)
selected_years = np.linspace(start.year, end.year, 5, dtype=int)
ax.set_xticks([datetime.date(year, 1, 1).toordinal() for year in selected_years])
ax.set_xticklabels(selected_years)
ax.xaxis.labelpad = 5
ax.zaxis.labelpad = -1
# Margins are fractions of the figure size, so right must be <= 1
plt.subplots_adjust(left=0.05, right=0.95, top=0.95, bottom=0.05)
plt.show()
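The grid construction relies on np.meshgrid(dates, maturities) returning arrays shaped (len(maturities), len(dates)) under its default indexing='xy', which is why the yields are transposed with .T before plotting. A small sketch with hypothetical values (five consecutive dates, made-up yields) checks that the three arrays line up:

```python
import datetime
import numpy as np

# Hypothetical example: 5 consecutive dates and the 10 maturities from the question
dates = np.array([datetime.date(2020, 1, d).toordinal() for d in range(2, 7)])
maturities = np.array([3/12, 6/12, 1, 2, 3, 5, 7, 10, 20, 30])

# Made-up yields in the DataFrame layout: one row per date, one column per maturity
values = np.arange(len(dates) * len(maturities), dtype=float).reshape(len(dates), len(maturities))

X, Y = np.meshgrid(dates, maturities)  # default indexing='xy': shape (n_maturities, n_dates)
Z = values.T                           # transpose so rows are maturities as well

print(X.shape, Y.shape, Z.shape)  # (10, 5) (10, 5) (10, 5)
```

Row i of every array corresponds to maturity i and column j to date j. The tick positions built with datetime.date(year, 1, 1).toordinal() use the same day-number units as dates, so they fall on the matching columns.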
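This is the output:
There are two things I would like to fix:
I have been trying to solve these two problems for a long time, but I haven't made any progress. Any suggestions?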
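Setting the limits of the plot axes will definitely make it tighter. You can also use plt.tight_layout(), which automatically adjusts the margins to make them tighter.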
So, adding these to your code:
ax.set_ylim(0,max(maturities))
ax.set_xlim(min(dates), max(dates))
ax.set_zlim(0, max(df_yields.max()))
Together with the plt.tight_layout() mentioned above, the margins become noticeably smaller, at least on the sides and at the bottom:
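To check these changes without a display or network access, here is a minimal sketch assuming synthetic yields (the surface formula is made up) and matplotlib's non-interactive Agg backend. It applies the axis limits and tight_layout from this answer; saving with bbox_inches='tight' is an extra option not mentioned above, which crops the saved image to the drawn content:

```python
import datetime
import io

import matplotlib
matplotlib.use('Agg')  # non-interactive backend: renders without a display
import matplotlib.pyplot as plt
import numpy as np

# Hypothetical synthetic data standing in for the FRED yields
dates = np.array([datetime.date(2020, 1, 1).toordinal() + d for d in range(0, 1000, 7)])
maturities = np.array([3/12, 6/12, 1, 2, 3, 5, 7, 10, 20, 30])
X, Y = np.meshgrid(dates, maturities)
Z = 1.0 + np.log1p(Y) + 0.5 * np.sin((X - dates[0]) / 200.0)  # made-up yield surface

fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection='3d')
ax.set_box_aspect(aspect=(0.9, 0.5, 0.4))
ax.plot_surface(X, Y, Z, rstride=1, cstride=1, cmap='Blues', edgecolor='black', linewidth=0.5)

# The answer's suggestion: clamp each axis to the data range
ax.set_xlim(dates.min(), dates.max())
ax.set_ylim(0, maturities.max())
ax.set_zlim(0, Z.max())

fig.tight_layout()  # shrink the figure margins around the axes

# Extra option: bbox_inches='tight' crops the saved image to the drawn content.
# In real use, pass a file name such as 'yields.png' instead of the buffer.
buf = io.BytesIO()
fig.savefig(buf, format='png', bbox_inches='tight')
```

The in-memory buffer only keeps the sketch free of side effects; with the real data, the same calls follow the plot_surface line of the question's code.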
If you want it even tighter, I suggest using plotly:
import numpy as np
import pandas_datareader as pdr
import datetime
import plotly.graph_objects as go
# import the data
start = datetime.date(2020, 1, 1)
end = datetime.date.today()
assets = {'DGS3MO': '3m', 'DGS6MO': '6m', 'DGS1': '1y', 'DGS2': '2y', 'DGS3': '3y', 'DGS5': '5y', 'DGS7': '7y', 'DGS10': '10y', 'DGS20': '20y', 'DGS30': '30y'}
df_yields = pdr.DataReader(list(assets.keys()), 'fred', start, end)
df_yields.rename(columns=assets, inplace=True)
df_yields = df_yields.dropna(how='any')
dates = np.array([x.toordinal() for x in df_yields.index])
maturities = np.array([3/12, 6/12, 1, 2, 3, 5, 7, 10, 20, 30])
X, Y = np.meshgrid(dates, maturities)
Z = df_yields.values.T
# plot
tick_years = np.linspace(start.year, end.year, 5, dtype=int)  # five evenly spaced years for the date ticks
fig = go.Figure(data=[go.Surface(z=Z, x=dates, y=maturities, colorscale='Blues')])
fig.update_layout(
    scene=dict(
        xaxis_title='Date',
        yaxis_title='Maturity (Years)',
        zaxis_title='Yield (%)',
        xaxis=dict(tickvals=[datetime.date(year, 1, 1).toordinal() for year in tick_years],
                   ticktext=[str(year) for year in tick_years]),
    ),
    margin=dict(l=0, r=0, t=40, b=0),  # tight margins in pixels; t=40 leaves room for the title
    title='3D Curve'
)
fig.show()