import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import Axes3D
# Layer stack for the diagram, listed from the bottom (first drawn) to the
# top. Each entry is a dict with:
#   "name"   - label written on the slab
#   "color"  - face colour as a hex string
#   "height" - slab thickness in plot units (z axis); not to physical scale
layers = [
    dict(name=label, color=face, height=thickness)
    for label, face, thickness in (
        ("SiO₂", "#8A2BE2", 0.2),
        ("h-BN", "#87CEEB", 0.15),
        ("InSe", "#0000FF", 0.1),
        ("Graphene", "#808080", 0.05),
        ("h-BN", "#87CEEB", 0.15),
        ("Au", "#FFD700", 0.1),
    )
]
def draw_layer_stack(ax, stack, x=0, y=0, z0=0, width=1, depth=1,
                     alpha=0.8, fontsize=10):
    """Draw ``stack`` on a 3-D axes as a column of boxes, bottom layer first.

    Each layer becomes one ``bar3d`` box sharing the footprint
    ``(x, y, width, depth)``. Each box sits directly on top of the previous
    one and is labelled at its centre with the layer's name.

    Args:
        ax: A 3-D axes (created with ``projection='3d'``).
        stack: Iterable of dicts with keys ``"name"``, ``"color"`` and
            ``"height"``, ordered bottom to top.
        x, y: Lower-left corner of the column's footprint.
        z0: Height at which the first (bottom) layer starts.
        width, depth: Footprint size along x and y.
        alpha: Face transparency of every box.
        fontsize: Font size of the layer labels.

    Returns:
        The z coordinate of the top of the stack (``z0`` plus the sum of
        all layer heights).
    """
    z = z0
    for layer in stack:
        height = layer["height"]
        # bar3d(x, y, z, dx, dy, dz): corner position followed by box size.
        ax.bar3d(x, y, z,
                 width, depth, height,
                 color=layer["color"],
                 edgecolor='black',
                 alpha=alpha)
        # Label at the geometric centre of this layer's box.
        ax.text(x + width / 2, y + depth / 2, z + height / 2,
                layer["name"],
                color='black',
                ha='center',
                va='center',
                fontsize=fontsize)
        z += height
    return z


fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')

# Footprint of the column and the height of its bottom face.
x_pos = 0
y_pos = 0
z_bottom = 0
width = 1
depth = 1

# z_base ends up as the top of the stack (bottom + total layer height).
z_base = draw_layer_stack(ax, layers,
                          x=x_pos, y=y_pos, z0=z_bottom,
                          width=width, depth=depth)

# Schematic diagram: hide numeric ticks and fit the axes to the stack.
ax.set_xticks([])
ax.set_yticks([])
ax.set_zticks([])
ax.set_xlim(x_pos, x_pos + width)
ax.set_ylim(y_pos, y_pos + depth)
ax.set_zlim(z_bottom, z_base)
ax.view_init(elev=20, azim=-45)
plt.title("Multilayer Structure Diagram")
plt.tight_layout()
plt.show()