Files
test/zx_test.py
2025-06-22 01:43:06 +08:00

215 lines
8.0 KiB
Python

"""
主要是一些方便画图的函数
"""
import numpy as np
import matplotlib.lines as lines
import matplotlib.pyplot as plt
import matplotlib as mpl
from matplotlib import font_manager
import glob, os
from matplotlib.colors import LinearSegmentedColormap, rgb_to_hsv, hsv_to_rgb, to_rgb
def set_tformat():
    """Install matplotlib's ConciseDateConverter for date-like values.

    The same converter instance is registered in matplotlib's unit registry
    for np.datetime64, datetime.date and datetime.datetime, so axes that plot
    such values use the concise date tick formatting.
    """
    import datetime
    import matplotlib.dates as mdates
    import matplotlib.units as munits

    converter = mdates.ConciseDateConverter()
    for date_type in (np.datetime64, datetime.date, datetime.datetime):
        munits.registry[date_type] = converter
# Root directory for bundled resources: the "zhouxu" environment variable if
# set, otherwise $HOME.
base_dir = os.environ.get("zhouxu", os.environ.get("HOME"))
font_dir = f"{base_dir}/Su-A-important/font/"
# Register every regular file found in font_dir with matplotlib's font manager
# (directories are skipped). Nothing happens when font_dir does not exist.
if os.path.isdir(font_dir):
    for font in glob.glob(font_dir + "*"):
        if not os.path.isfile(font):
            continue
        font_manager.fontManager.addfont(font)
def set_font(font="Source Han Sans SC"):
    """Make *font* the only entry of matplotlib's sans-serif font list.

    font: font family name, e.g. Source Han Sans SC, SimHei,
        Sarasa Mono T SC, Inconsolata. To find the family name of a font file:
            a = font_manager.FontProperties(fname=f"{base_dir}/Su-A-important/font/sarasa-monoT-sc-regular.ttf")
            a.get_name()

    Also sets axes.unicode_minus to False, so negative numbers are drawn with
    the ASCII hyphen-minus (presumably because some CJK fonts lack the
    Unicode minus glyph).
    """
    plt.rcParams.update({
        "font.sans-serif": [font],
        "axes.unicode_minus": False,
    })
def get_font_name(font_file):
    """Return the family name matplotlib reads from a font file.

    font_file: path of font file like "./sarasa-monoT-sc-regular.ttf"
    """
    return font_manager.FontProperties(fname=font_file).get_name()
def set_font_file(font_file):
    """Register a font file with matplotlib and select it via set_font.

    font_file: path of font file like "./sarasa-monoT-sc-regular.ttf"
    """
    # Resolve the family name first; the font is then added to matplotlib's
    # font manager so the name can be found when rendering.
    family = get_font_name(font_file)
    font_manager.fontManager.addfont(font_file)
    set_font(font=family)
def set_font_size(size=9):
    """Set matplotlib's default font size (rcParams["font.size"])."""
    plt.rcParams.update({"font.size": size})
def point_transform(tar_prj, src_prj, x, y):
    """Map the point (x, y) from the src_prj coordinate system into tar_prj.

    tar_prj, src_prj: transform objects such as ax.transData, ax.transAxes or
        fig.transFigure. The point is pushed forward through src_prj and then
        pulled back through the inverse of tar_prj.
    Returns whatever the inverted target transform's transform() returns.
    """
    to_target = tar_prj.inverted()
    intermediate = src_prj.transform([x, y])
    return to_target.transform(intermediate)
def point_fig2axes(ax, fig, x, y):
    """Convert a point from fig's Figure coordinates to ax's Axes coordinates."""
    return point_transform(tar_prj=ax.transAxes, src_prj=fig.transFigure, x=x, y=y)
def point_fig2data(ax, fig, x, y):
    """Convert a point from fig's Figure coordinates to ax's Data coordinates."""
    return point_transform(tar_prj=ax.transData, src_prj=fig.transFigure, x=x, y=y)
def point_axes2fig(fig, ax, x, y):
    """Convert a point from ax's Axes coordinates to fig's Figure coordinates."""
    return point_transform(tar_prj=fig.transFigure, src_prj=ax.transAxes, x=x, y=y)
def point_axes2data(tar_ax, src_ax, x, y):
    """Convert a point from src_ax's Axes coordinates to tar_ax's Data coordinates."""
    return point_transform(tar_prj=tar_ax.transData, src_prj=src_ax.transAxes, x=x, y=y)
def point_data2axes(tar_ax, src_ax, x, y):
    """Convert a point from src_ax's Data coordinates to tar_ax's Axes coordinates."""
    return point_transform(tar_prj=tar_ax.transAxes, src_prj=src_ax.transData, x=x, y=y)
def point_data2fig(fig, ax, x, y):
    """Convert a point from ax's Data coordinates to fig's Figure coordinates."""
    return point_transform(tar_prj=fig.transFigure, src_prj=ax.transData, x=x, y=y)
def connect_axs(fig, ax1, ax2, color='k', lw=1):
    """Outline ax2's data range inside ax1 and join the outline to ax2.

    Meant for an overview/zoom layout with ax1 to the left of ax2. The
    rectangle spanned by ax2's current data limits is drawn inside ax1 (the
    data values are interpreted in ax1's data coordinates). Two lines then
    connect the rectangle's right-hand corners to ax2's left-hand corners.

    Parameters:
        fig: Figure containing both axes.
        ax1: Axes in which the outline is drawn.
        ax2: Axes whose data range is outlined.
        color: color of the outline and the connecting lines.
        lw: line width of the outline and the connecting lines.
    """
    def outline_corner(x, y):
        # ax2 Axes-fraction corner -> ax2 data coords -> same values in ax1
        # data coords -> figure coords.
        return point_data2fig(fig, ax1, *point_axes2data(ax2, ax2, x, y))

    # Connecting lines: outline top-right -> ax2 top-left, then
    # outline bottom-right -> ax2 bottom-left.
    for y in (1, 0):
        start = outline_corner(1, y)
        end = point_axes2fig(fig, ax2, 0, y)
        fig.add_artist(
            lines.Line2D([start[0], end[0]], [start[1], end[1]], color=color, lw=lw)
        )

    # Outline drawn in ax1, closed by repeating the first corner.
    corners = [outline_corner(x, y) for x, y in [(0, 0), (0, 1), (1, 1), (1, 0), (0, 0)]]
    # Use add_artist rather than ax1.plot/add_line: those fold a path with a
    # non-data transform into ax1.dataLim as if the figure coordinates were
    # data, which can change ax1's autoscaled view and misplace the box.
    ax1.add_artist(lines.Line2D(
        [p[0] for p in corners],
        [p[1] for p in corners],
        transform=fig.transFigure,
        color=color,
        lw=lw,
    ))
def add_right_cax(ax, pad=0.01, width=0.01):
    """Add a new Axes (a "cax", e.g. for a colorbar) just right of *ax*.

    The new Axes has the same bottom and top as ax.get_position().

    pad:   gap between ax and the new Axes, in the same units as
           ax.get_position() (figure fraction).
    width: width of the new Axes, in the same units.

    Returns the Axes created with ax.figure.add_axes.
    """
    box = ax.get_position()
    left = box.x1 + pad
    right = left + width
    return ax.figure.add_axes(
        mpl.transforms.Bbox.from_extents(left, box.y0, right, box.y1)
    )
def get_boundary(file, pad=0.01, h_pad=None, w_pad=None):
    """
    Get the parameters for clipping the white margin around a figure.

    Parameters:
        file: str, os.PathLike, mpl.figure.Figure or np.ndarray
            An image path readable by plt.imread, a matplotlib Figure (drawn
            before its canvas buffer is read), or image data of shape (ny, nx)
            or (ny, nx, channels). An array argument is never modified.
        pad: int or float, optional
            Extra margin kept around the content. An int is a number of
            pixels; a float is a fraction of the image size.
        h_pad: int or float, optional
            Like pad, for the vertical direction (rows; a float is a fraction
            of the image height). Defaults to pad.
        w_pad: int or float, optional
            Like pad, for the horizontal direction (columns; a float is a
            fraction of the image width). Defaults to pad.
    Returns:
        (dx, dy, xs, ys)
        Example: figure[ys:ys+dy, xs:xs+dx]
        Data with a maximum above 1 is treated as 0-255 and rescaled to 0-1.
        A row/column counts as content when its mean (over pixels and
        channels) is below 1, i.e. it is not entirely white. When the content
        touches the image border, xs/ys minus the padding can be negative, so
        clamp them before using NumPy slicing. An image without any non-white
        pixel yields the whole image plus padding.
    Raises:
        TypeError: if file is not one of the supported types.
    Update:
        2024-08-13 01:20:10 Sola Source code
        2024-08-13 09:24:01 Sola fix bug add fig.canvas.draw()
    """
    if isinstance(file, np.ndarray):
        # Always copy as float: the in-place rescaling below must neither
        # mutate the caller's array nor fail on integer dtypes such as uint8.
        data = np.array(file, dtype=float)
    elif isinstance(file, (str, os.PathLike)):
        data = np.array(plt.imread(file), dtype=float)
    elif isinstance(file, mpl.figure.Figure):
        # Render first so the canvas buffer reflects the current figure.
        file.canvas.draw()
        data = np.array(file.canvas.buffer_rgba(), dtype=float)
    else:
        raise TypeError(
            "file must be a path, a matplotlib Figure or a numpy array, "
            f"got {type(file).__name__}"
        )
    ny, nx = data.shape[0:2]
    h_pad = pad if h_pad is None else h_pad
    h_pad = int(h_pad * ny) if isinstance(h_pad, float) else h_pad
    w_pad = pad if w_pad is None else w_pad
    w_pad = int(w_pad * nx) if isinstance(w_pad, float) else w_pad
    # Rescale 0-255 data so that pure white is exactly 1.
    if np.max(data) > 1:
        data /= 255
    if data.ndim == 3:
        data = np.mean(data, axis=-1)
    # True for rows/columns that are not entirely white.
    row_has_content = np.mean(data, axis=1) < 1
    col_has_content = np.mean(data, axis=0) < 1
    xs, xe = np.argmax(col_has_content), nx - np.argmax(col_has_content[::-1])
    ys, ye = np.argmax(row_has_content), ny - np.argmax(row_has_content[::-1])
    dx, dy = xe - xs, ye - ys
    return dx + 2 * w_pad, dy + 2 * h_pad, xs - w_pad, ys - h_pad
def interp_color_linear(color1, color2, n=16):
    """Interpolate n colors linearly, channel by channel, from color1 to color2.

    color1, color2: equal-length color sequences, e.g. RGB tuples from to_rgb.
    Returns the (n, channels) array produced by interp_linear: the first row
    equals color1 and, for n >= 2, the last row equals color2.
    """
    start = np.array(color1)
    stop = np.array(color2)
    return interp_linear(start, stop, n)
def interp_linear(value1, value2, n=16):
    """Return n evenly spaced linear interpolations from value1 to value2.

    value1, value2: 1-D array-likes of equal length. Lists and tuples are
        accepted as well as ndarrays.
    n: number of rows in the result.

    Row i of the returned (n, len(value1)) array is
    a_i * value1 + b_i * value2, where a = linspace(1, 0, n) and
    b = linspace(0, 1, n). The first row is value1 and, for n >= 2, the last
    row is value2.
    """
    # asarray lets plain sequences through; the [None, :] indexing below
    # requires an ndarray.
    start = np.asarray(value1)
    stop = np.asarray(value2)
    # (n, 1) weight column @ (1, k) value row = outer product.
    w_start = np.linspace(1, 0, n)[:, None]
    w_stop = np.linspace(0, 1, n)[:, None]
    return w_start @ start[None, :] + w_stop @ stop[None, :]
def interp_color_hsv(color1_rgb, color2_rgb, direction="nearest", n=16):
    """Interpolate n colors from color1_rgb to color2_rgb in HSV space.

    Hue is treated as circular on [0, 1]. Between two hues h1 and h2 the
    direct path has length |h1 - h2|, and the path that wraps through
    1 -> 0 has length 1 - |h1 - h2|.

    direction:
        "nearest"  -- follow the shorter path;
        "furthest" -- follow the longer path.
        If the hues are exactly 0.5 apart, both paths are equally long and
        the direct path is used.
    n: number of colors returned.

    Saturation and value are interpolated linearly. Returns the RGB colors
    obtained by hsv_to_rgb on the (n, 3) interpolated HSV array.

    Raises:
        ValueError: if direction is neither "nearest" nor "furthest".
    """
    if direction not in ("nearest", "furthest"):
        raise ValueError(f'direction must be "nearest" or "furthest", got {direction!r}')
    color1_hsv, color2_hsv = rgb_to_hsv(color1_rgb), rgb_to_hsv(color2_rgb)
    h_distance = abs(color1_hsv[0] - color2_hsv[0])
    # Decide whether to take the wrapped path. Exactly 0.5 falls on the
    # direct path for both directions, instead of matching no branch at all.
    if direction == "nearest":
        wrap = h_distance > 0.5
    else:
        wrap = h_distance < 0.5
    if wrap:
        # Lift the smaller hue by one full turn, so that plain linear
        # interpolation runs through the 1 -> 0 seam.
        if color1_hsv[0] > color2_hsv[0]:
            color2_hsv[0] += 1
        else:
            color1_hsv[0] += 1
    result = interp_linear(color1_hsv, color2_hsv, n)
    if wrap:
        # Fold hues back into range. Only the hue column can exceed 1, so
        # saturation and value are left untouched.
        hue = result[:, 0]
        hue[hue > 1] -= 1
    return hsv_to_rgb(result)
# Colormap "hgt #1": seven 37-color gradients chained through the anchor
# colors below, in order. Each entry is (start, end, mode). "hsv" segments use
# interp_color_hsv with direction="nearest" (shorter way round the hue
# circle); "rgb" segments use interp_color_linear.
hgt_1 = LinearSegmentedColormap.from_list(
    "hgt #1",
    np.concatenate(
        [
            interp_color_hsv(to_rgb(start), to_rgb(end), direction="nearest", n=37)
            if mode == "hsv"
            else interp_color_linear(to_rgb(start), to_rgb(end), n=37)
            for start, end, mode in (
                ("#aff0e9", "#ffffb2", "hsv"),
                ("#ffffb2", "#008040", "hsv"),
                ("#008040", "#f7c823", "rgb"),
                ("#f7c823", "#800000", "hsv"),
                ("#800000", "#7a3f12", "rgb"),
                ("#7a3f12", "#b9b9b9", "rgb"),
                ("#b9b9b9", "#fdfdfe", "rgb"),
            )
        ],
        axis=0,
    ),
)