Matplotlib - 图中 3D 子图的大小

问题描述 投票:0回答:1

我想创建一些数据的 3D 绘图。首先,我下载数据:

# Import the libraries

import numpy as np
import pandas as pd
import pandas_datareader as pdr
import datetime
import matplotlib.pyplot as plt

# Import the data

start = datetime.date(2020, 1, 1)
end = datetime.date.today()

assets = {'DGS3MO': '3m', 'DGS6MO': '6m', 'DGS1': '1y', 'DGS2': '2y', 'DGS3': '3y', 'DGS5': '5y', 'DGS7': '7y', 'DGS10': '10y', 'DGS20': '20y', 'DGS30': '30y'}
df_yields = pdr.DataReader(list(assets.keys()), 'fred', start, end)
df_yields.rename(columns=assets, inplace=True)
df_yields = df_yields.dropna(how='any')

然后,我创建 3D 绘图:

# Create 3D plot

dates = np.array([x.toordinal() for x in df_yields.index])
maturities = np.array([3/12, 6/12, 1, 2, 3, 5, 7, 10, 20, 30])

X, Y = np.meshgrid(dates, maturities)
Z = df_yields.values.T

fig = plt.figure(figsize=(10, 10))
ax = fig.add_subplot(111, projection='3d')
ax.set_box_aspect(aspect=(0.9, 0.5, 0.4))

surf = ax.plot_surface(X, Y, Z, rstride=10, cmap='Blues', edgecolor='black', linewidth=0.5)

ax.set_xlabel('Date')
ax.set_ylabel('Maturity (Y)')
ax.set_zlabel('Yield (%)')

years = np.arange(start.year, end.year + 1)
selected_years = np.linspace(start.year, end.year, 5, dtype=int)
ax.set_xticks([datetime.date(year, 1, 1).toordinal() for year in selected_years])
ax.set_xticklabels(selected_years)
ax.xaxis.labelpad = 5
ax.zaxis.labelpad = -1

plt.subplots_adjust(left=0.05, right=100, top=0.95, bottom=0.05)

这是输出:

enter image description here

有两件事我想解决:

  1. 绘图右侧尺寸被图形边缘切割的事实
  2. 减少顶部和底部的宽边距(理想情况下,使左侧和右侧的边距也尽可能紧)。

我已经尝试了很长时间来解决这两个问题,但我没有取得任何进展。有什么建议吗?

matplotlib 3d surface
1个回答
0
投票

设置绘图轴的限制肯定会让它变得更紧。您还可以使用

plt.tight_layout()
,它会自动将边距调整得更紧。

因此,将这些合并到您的代码中时:

ax.set_ylim(0,max(maturities))
ax.set_xlim(min(dates), max(dates))
ax.set_zlim(0, max(df_yields.max()))

与提到的

plt.tight_layout()
一起,边距将明显变小,至少在侧面和底部: enter image description here

如果您想要更紧,我建议使用

plotly

import numpy as np
import pandas_datareader as pdr
import datetime
import plotly.graph_objects as go


# import the data
start = datetime.date(2020, 1, 1)
end = datetime.date.today()

assets = {'DGS3MO': '3m', 'DGS6MO': '6m', 'DGS1': '1y', 'DGS2': '2y', 'DGS3': '3y', 'DGS5': '5y', 'DGS7': '7y', 'DGS10': '10y', 'DGS20': '20y', 'DGS30': '30y'}
df_yields = pdr.DataReader(list(assets.keys()), 'fred', start, end)
df_yields.rename(columns=assets, inplace=True)
df_yields = df_yields.dropna(how='any')


dates = np.array([x.toordinal() for x in df_yields.index])
maturities = np.array([3/12, 6/12, 1, 2, 3, 5, 7, 10, 20, 30])

X, Y = np.meshgrid(dates, maturities)
Z = df_yields.values.T

# plot
fig = go.Figure(data=[go.Surface(z=Z, x=dates, y=maturities, colorscale='Blues')])

fig.update_layout(
    scene=dict(
        xaxis_title='Date',
        yaxis_title='Maturity (Years)',
        zaxis_title='Yield (%)',
        xaxis=dict(tickvals=[datetime.date(year, 1, 1).toordinal() for year in np.linspace(start.year, end.year, 5, dtype=int)],
                   ticktext=[str(year) for year in np.linspace(start.year, end.year, 5, dtype=int)]),
    ),
    margin=dict(l=0, r=0, t=40, b=0),  # Tight margins
    title='3D Curve'
)

fig.show()

结果是: enter image description here

© www.soinside.com 2019 - 2024. All rights reserved.