"""Convenience helpers for plotting with Matplotlib.

Covers date-axis formatting, font selection, point conversions between
coordinate systems (Figure / Axes / Data), zoom-box connectors, colorbar axes,
locating the non-white content of rendered figures, colour interpolation and
the ``hgt_1`` colormap.
"""
import glob
import os

import numpy as np
import matplotlib as mpl
import matplotlib.lines as lines
import matplotlib.pyplot as plt
from matplotlib import font_manager
from matplotlib.colors import LinearSegmentedColormap, rgb_to_hsv, hsv_to_rgb, to_rgb


def set_tformat():
    """Register Matplotlib's ``ConciseDateConverter`` for date-like values.

    After calling this, data of type ``numpy.datetime64``, ``datetime.date``
    and ``datetime.datetime`` is converted (and its ticks labelled) by a
    single shared ``ConciseDateConverter`` instance.
    """
    import datetime
    import matplotlib.dates as mdates
    import matplotlib.units as munits

    converter = mdates.ConciseDateConverter()
    for date_type in (np.datetime64, datetime.date, datetime.datetime):
        munits.registry[date_type] = converter


# Optional personal font directory: "$zhouxu/Su-A-important/font/" when the
# ``zhouxu`` environment variable is set, otherwise "$HOME/Su-A-important/font/".
# Every regular file found there is registered with Matplotlib's font manager
# at import time, so it can afterwards be selected by family name (set_font).
base_dir = os.environ.get("zhouxu", os.environ.get("HOME"))
font_dir = f"{base_dir}/Su-A-important/font/"
if os.path.isdir(font_dir):
    for font in glob.glob(font_dir + "*"):
        if os.path.isfile(font):
            font_manager.fontManager.addfont(font)


def set_font(font="Source Han Sans SC"):
    """Make *font* the sans-serif font and draw minus signs as ASCII hyphens.

    Parameters:
        font: str
            Font family name, e.g. "Source Han Sans SC", "SimHei",
            "Sarasa Mono T SC", "Inconsolata". Only affects text whose
            ``font.family`` resolves to sans-serif (Matplotlib's default).
            The family name of a font file can be found with
            ``get_font_name``, e.g.
            ``get_font_name(f"{base_dir}/Su-A-important/font/sarasa-monoT-sc-regular.ttf")``.
    """
    plt.rcParams["font.sans-serif"] = [font]
    # Use "-" instead of U+2212 for negative numbers; presumably because some
    # of the fonts above have no glyph for the Unicode minus sign.
    plt.rcParams["axes.unicode_minus"] = False


def get_font_name(font_file):
    """Return the family name Matplotlib reports for a font file.

    Parameters:
        font_file: str
            Path of a font file, e.g. "./sarasa-monoT-sc-regular.ttf".
    """
    return font_manager.FontProperties(fname=font_file).get_name()


def set_font_file(font_file):
    """Register a font file with Matplotlib and select it via ``set_font``.

    Parameters:
        font_file: str
            Path of a font file, e.g. "./sarasa-monoT-sc-regular.ttf".
    """
    font_name = get_font_name(font_file)
    font_manager.fontManager.addfont(font_file)
    set_font(font_name)


def set_font_size(size=9):
    """Set the default font size (``rcParams["font.size"]``)."""
    plt.rcParams["font.size"] = size


def point_transform(tar_prj, src_prj, x, y):
    """Map the point (x, y) from coordinate system *src_prj* to *tar_prj*.

    Both arguments are Matplotlib transforms (e.g. ``ax.transData``,
    ``ax.transAxes``, ``fig.transFigure``). The point is taken to display
    coordinates with *src_prj* and back with the inverse of *tar_prj*.
    Returns an array of shape (2,).
    """
    return tar_prj.inverted().transform(src_prj.transform([x, y]))


def point_fig2axes(ax, fig, x, y):
    """Convert a point from Figure coordinates of *fig* to Axes coordinates of *ax*."""
    return point_transform(ax.transAxes, fig.transFigure, x, y)


def point_fig2data(ax, fig, x, y):
    """Convert a point from Figure coordinates of *fig* to Data coordinates of *ax*."""
    return point_transform(ax.transData, fig.transFigure, x, y)


def point_axes2fig(fig, ax, x, y):
    """Convert a point from Axes coordinates of *ax* to Figure coordinates of *fig*."""
    return point_transform(fig.transFigure, ax.transAxes, x, y)


def point_axes2data(tar_ax, src_ax, x, y):
    """Convert a point from Axes coordinates of *src_ax* to Data coordinates of *tar_ax*."""
    return point_transform(tar_ax.transData, src_ax.transAxes, x, y)


def point_data2axes(tar_ax, src_ax, x, y):
    """Convert a point from Data coordinates of *src_ax* to Axes coordinates of *tar_ax*."""
    return point_transform(tar_ax.transAxes, src_ax.transData, x, y)


def point_data2fig(fig, ax, x, y):
    """Convert a point from Data coordinates of *ax* to Figure coordinates of *fig*."""
    return point_transform(fig.transFigure, ax.transData, x, y)


def _add_figure_line(fig, verts, color, lw):
    """Add a line through *verts* (Figure coordinates) as a figure-level artist."""
    xs = [vert[0] for vert in verts]
    ys = [vert[1] for vert in verts]
    fig.add_artist(lines.Line2D(xs, ys, color=color, lw=lw))


def connect_axs(fig, ax1, ax2, color='k', lw=1):
    """Outline ax2's data range inside ax1 and connect the outline to ax2.

    Intended for ax1 on the left and ax2 (a zoomed view of part of ax1's
    data) on its right. Draws, in figure coordinates:

    * a closed rectangle in ax1 covering ax2's current x/y limits;
    * a line from the rectangle's top-right corner to ax2's top-left corner;
    * a line from the rectangle's bottom-right corner to ax2's bottom-left corner.

    Positions are computed once, from the current limits and layout. If axes
    limits or positions change afterwards (e.g. ``tight_layout``), the drawn
    artists do not follow.
    """
    def ax2_corner_in_fig(x, y):
        # ax2 corner (Axes coords) -> ax2 data coords -> figure coords through ax1's data transform
        return point_data2fig(fig, ax1, *point_axes2data(ax2, ax2, x, y))

    # connector: rectangle top-right -> ax2 top-left
    _add_figure_line(fig, [ax2_corner_in_fig(1, 1), point_axes2fig(fig, ax2, 0, 1)], color, lw)
    # connector: rectangle bottom-right -> ax2 bottom-left
    _add_figure_line(fig, [ax2_corner_in_fig(1, 0), point_axes2fig(fig, ax2, 0, 0)], color, lw)

    # closed outline of ax2's data range, drawn in ax1 with figure coordinates
    box = [ax2_corner_in_fig(x, y) for x, y in [[0, 0], [0, 1], [1, 1], [1, 0], [0, 0]]]
    ax1.plot([p[0] for p in box], [p[1] for p in box],
             transform=fig.transFigure, color=color, lw=lw)


def add_right_cax(ax, pad=0.01, width=0.01):
    """Add an Axes (e.g. for a colorbar) right of *ax* with the same height.

    Parameters:
        ax: matplotlib.axes.Axes
        pad: float
            Gap between *ax* and the new Axes, in figure-fraction units.
        width: float
            Width of the new Axes, in figure-fraction units.

    Returns:
        The new Axes. Its position is fixed from ``ax.get_position()`` at call
        time, so it does not follow later layout changes of *ax*.
    """
    axpos = ax.get_position()
    caxpos = mpl.transforms.Bbox.from_extents(
        axpos.x1 + pad, axpos.y0, axpos.x1 + pad + width, axpos.y1
    )
    return ax.figure.add_axes(caxpos)


def _pad_to_pixels(value, size):
    """Convert a padding spec to pixels: floats are a fraction of *size*, ints are kept."""
    return int(value * size) if isinstance(value, (float, np.floating)) else value


def get_boundary(file, pad=0.01, h_pad=None, w_pad=None):
    """
    Locate the non-white content of an image, e.g. to crop surrounding white space.

    A row/column counts as content when its mean intensity is below 1, i.e. it
    contains at least one non-white pixel. Colour channels (alpha included) are
    averaged first. Values above 1 are assumed to be on a 0-255 scale.

    Parameters:
        file: str, os.PathLike, mpl.figure.Figure or np.ndarray
            Path of an image file readable by ``plt.imread`` (recent Matplotlib
            versions reject URLs), a Figure (it is drawn first and its RGBA
            canvas buffer is read), or an image array of shape (H, W) or
            (H, W, C). An array argument is never modified.
        pad: int or float, optional
            Extra margin kept around the content. An int is in pixels; a float
            is a fraction of the image height (for h_pad) / width (for w_pad).
        h_pad: int or float, optional
            Like pad, applied vertically (top and bottom). Defaults to pad.
        w_pad: int or float, optional
            Like pad, applied horizontally (left and right). Defaults to pad.

    Returns:
        (dx, dy, xs, ys): width, height, first column and first row of the
        padded content box. The box is not clipped to the image: xs/ys can be
        negative (and xs+dx / ys+dy can exceed the image size) when the padding
        reaches past the border. A fully white image yields the whole image.

    Raises:
        TypeError: if *file* is not one of the supported types.

    Example:
        dx, dy, xs, ys = get_boundary(figure)
        cropped = image[ys:ys+dy, xs:xs+dx]   # beware: negative starts wrap in NumPy

    Update:
        2024-08-13 01:20:10 Sola Source code
        2024-08-13 09:24:01 Sola fix bug add fig.canvas.draw()
    """
    if isinstance(file, (str, os.PathLike)):
        data = np.asarray(plt.imread(os.fspath(file))).astype(float)
    elif isinstance(file, mpl.figure.Figure):
        file.canvas.draw()  # the pixel buffer reflects the figure only after a draw
        data = np.asarray(file.canvas.buffer_rgba()).astype(float)
    elif isinstance(file, np.ndarray):
        # astype() copies: the in-place rescale below must not touch the caller's
        # array, and integer images must not hit an int /= float casting error.
        data = file.astype(float)
    else:
        raise TypeError(
            f"file must be a path, a matplotlib Figure or a numpy array, not {type(file).__name__}"
        )

    ny, nx = data.shape[0:2]
    h_pad = _pad_to_pixels(pad if h_pad is None else h_pad, ny)
    w_pad = _pad_to_pixels(pad if w_pad is None else w_pad, nx)

    if np.max(data) > 1:  # treat as 0-255 image
        data /= 255
    if data.ndim == 3:
        data = np.mean(data, axis=-1)

    rows_with_content = np.mean(data, axis=1) < 1
    cols_with_content = np.mean(data, axis=0) < 1
    # first True from the front / from the back; argmax of an all-False mask is 0
    xs, xe = np.argmax(cols_with_content), nx - np.argmax(cols_with_content[::-1])
    ys, ye = np.argmax(rows_with_content), ny - np.argmax(rows_with_content[::-1])
    dx, dy = xe - xs, ye - ys
    return dx + 2*w_pad, dy + 2*h_pad, xs - w_pad, ys - h_pad


def interp_color_linear(color1, color2, n=16):
    """Return *n* colours interpolated component-wise (in RGB) from color1 to color2.

    Colours are sequences of floats (e.g. from ``to_rgb``); the result has
    shape (n, len(color1)), with the first row equal to color1 and the last to color2.
    """
    return interp_linear(np.asarray(color1, dtype=float), np.asarray(color2, dtype=float), n)


def interp_linear(value1, value2, n=16):
    """Return *n* evenly spaced linear interpolants from value1 to value2.

    Row i equals ``(1 - t_i) * value1 + t_i * value2`` with t going 0 -> 1, so
    the result has shape (n, *value1.shape). Scalars and sequences are accepted.
    """
    value1 = np.asarray(value1)
    value2 = np.asarray(value2)
    return (np.multiply.outer(np.linspace(1, 0, n), value1)
            + np.multiply.outer(np.linspace(0, 1, n), value2))


def interp_color_hsv(color1_rgb, color2_rgb, direction="nearest", n=16):
    """Return *n* RGB colours interpolated between two RGB colours in HSV space.

    Hue is circular (0 and 1 are the same red), so two arcs join h1 and h2.
    ``direction="nearest"`` follows the shorter arc, ``"furthest"`` the longer
    one; saturation and value are interpolated linearly. When both arcs are
    equally long (hue distance 0.5) the direct path h1 -> h2 is used.

    Returns:
        ndarray of shape (n, 3) with RGB colours.

    Raises:
        ValueError: if *direction* is neither "nearest" nor "furthest".
    """
    if direction not in ("nearest", "furthest"):
        raise ValueError(f'direction must be "nearest" or "furthest", not {direction!r}')

    # copies, because the hue of one endpoint may be shifted below
    color1_hsv = np.array(rgb_to_hsv(color1_rgb), dtype=float)
    color2_hsv = np.array(rgb_to_hsv(color2_rgb), dtype=float)
    h_distance = abs(color1_hsv[0] - color2_hsv[0])

    # The direct path covers h_distance; the other way round covers 1 - h_distance.
    if direction == "nearest":
        wrap = h_distance > 0.5
    else:
        wrap = h_distance < 0.5
    if not wrap:
        return hsv_to_rgb(interp_linear(color1_hsv, color2_hsv, n))

    # Lift the smaller hue by 1 so the linear ramp crosses the 1 -> 0 seam,
    # then fold the hue column (only) back into [0, 1].
    if color1_hsv[0] > color2_hsv[0]:
        color2_hsv[0] += 1
    else:
        color1_hsv[0] += 1
    result = interp_linear(color1_hsv, color2_hsv, n)
    hue = result[:, 0]
    hue[hue > 1] -= 1
    return hsv_to_rgb(result)


# Colormap "hgt #1": 7 consecutive segments of 37 colours each, joined end to
# end (cyan-white -> pale yellow -> green -> yellow -> dark red -> brown ->
# grey -> near white). Presumably meant for terrain height ("hgt") — confirm.
hgt_1 = LinearSegmentedColormap.from_list("hgt #1", np.concatenate(
    [
        interp_color_hsv(to_rgb("#aff0e9"), to_rgb("#ffffb2"), direction="nearest", n=37),
        interp_color_hsv(to_rgb("#ffffb2"), to_rgb("#008040"), direction="nearest", n=37),
        interp_color_linear(to_rgb("#008040"), to_rgb("#f7c823"), n=37),
        interp_color_hsv(to_rgb("#f7c823"), to_rgb("#800000"), direction="nearest", n=37),
        interp_color_linear(to_rgb("#800000"), to_rgb("#7a3f12"), n=37),
        interp_color_linear(to_rgb("#7a3f12"), to_rgb("#b9b9b9"), n=37),
        interp_color_linear(to_rgb("#b9b9b9"), to_rgb("#fdfdfe"), n=37),
    ], axis=0
))