support high dimension array

This commit is contained in:
2023-03-18 15:52:08 +08:00
parent 2a110bb2a6
commit 6e9288c17e

View File

@ -1,5 +1,7 @@
import numpy as np
import cartopy.crs as ccrs
class model_info_2d(object):
"""
用于创建模式网格, 并包含了相关信息, 提供了方便坐标与经纬度相互转换的工具
@ -34,6 +36,8 @@ class model_info_2d(object):
2022-09-28 16:42:12 Sola v2 加入了转化传入对象为numpy数组的功能
2022-09-28 18:28:38 Sola v2 修正了计算网格id时, 未输出ix, iy的bug
2023-03-14 10:02:41 Sola v3 增加输出边界网格的功能(调整get_grid, 使其支持边界宽度及边缘网格id)
2023-03-18 15:17:40 Sola v4 删除扩展边界的选项
2023-03-18 15:18:04 Sola v4 修正输入高维数组时, 计算报错的问题
测试记录:
2022-09-28 16:28:10 Sola v2 新的简化网格生成方法测试完成, 结果与旧版一致
2022-09-28 18:27:59 Sola v2 测试了使用proj_LC投影的相关方法, 网格与WRF一致
@ -125,12 +129,13 @@ class model_info_2d(object):
2022-09-28 15:46:50 Sola 简化原本的网格计算, 使用转置的方式代替判断返回数组长度
2022-09-28 16:40:27 Sola 增加将输入数组转化为numpy数组的功能, 防止传入列表
2022-10-19 18:52:25 Sola 修正了除错距离的bug
2023-03-18 15:39:06 Sola 在计算前, 先将数组展开到1维, 返回时折叠
注意事项:
当前存在一个bug, 输入的投影必须是cartopy的投影, 否则无法计算经纬度,
但是是否有必要在自己写的proj中加入该功能? 需要考虑
"""
original_x_array = np.array(original_x_array)
original_y_array = np.array(original_y_array)
original_x_array, original_y_array, shape = flat_array(\
np.array(original_x_array), np.array(original_y_array))
if hasattr(self.projection, 'grid_ids_float'): # 如果投影有相应方法
# 判断是否是经纬度坐标, 不是则转化为经纬度坐标
if original_proj != ccrs.PlateCarree():
@ -149,6 +154,7 @@ class model_info_2d(object):
# 将m转化为网格坐标
ix_array = ((ix_array - self.lowerleft_projxy[0])/ self.dx).T
iy_array = ((iy_array - self.lowerleft_projxy[1])/ self.dy).T
ix_array, iy_array = fold_array(ix_array, iy_array, shape)
return ix_array, iy_array
def grid_ids(self, original_x_array, original_y_array,
@ -186,8 +192,9 @@ class model_info_2d(object):
2022-09-28 16:07:40 Sola 增加判断proj是否有计算网格的功能
2022-09-28 16:08:38 Sola 简化原本的网格计算, 使用转置的方式代替判断返回数组长度
2022-09-28 16:40:27 Sola 增加将输入数组转化为numpy数组的功能, 防止传入列表
2023-03-18 15:39:06 Sola 在计算前, 先将数组展开到1维, 返回时折叠
"""
ix_array, iy_array = np.array(ix_array), np.array(iy_array)
ix_array, iy_array, shape = flat_array(np.array(ix_array), np.array(iy_array))
if hasattr(self.projection, 'grid_lonlats'):
lon_array, lat_array = self.projection.grid_lonlats(ix_array, iy_array)
else:
@ -196,18 +203,19 @@ class model_info_2d(object):
lon_array, lat_array, _ = ccrs.PlateCarree().transform_points(
self.projection, x_array, y_array).T
lon_array, lat_array = lon_array.T, lat_array.T
lon_array, lat_array = fold_array(lon_array, lat_array, shape)
return lon_array, lat_array
def get_grid(self, bdy_width=0, type=None):
def get_grid(self, type=None):
"""
范围模式所有网格的经纬度坐标
2023-03-14 10:05:43 Sola 更新边界宽度的功能及边缘网格的功能
获取的边缘网格从左下角开始顺时针排序(左优先)
2023-03-14 10:30:23 Sola 经过测试, 代码可以正常运行
2023-03-18 15:40:20 Sola 删除边界宽度的功能(没有用了)
"""
# 获取网格信息, 下标从0开始
ys, xs = np.meshgrid(range(-bdy_width, self.ny + bdy_width),
range(-bdy_width, self.nx + bdy_width), indexing='ij')
ys, xs = np.meshgrid(range(self.ny), range(self.nx), indexing='ij')
if type is None:
xlon, xlat = self.grid_lonlats(xs, ys) # 从网格信息获取经纬度信息
elif type.lower() in ["corner", "c"]: # 四角的网格
@ -227,3 +235,38 @@ class model_info_2d(object):
xlon = np.array([x[0] for x in result])
xlat = np.array([x[1] for x in result])
return xlon, xlat
def flat_array(
x : np.ndarray,
y : np.ndarray
) -> tuple:
"""
用于将数组展开, 并检查数组性质是否一致
更新记录:
2023-03-18 15:25:30 Sola 编写源代码
"""
shape = x.shape
if not shape == y.shape:
print(f"[WARNING] dimension mismatch, {x.shape}, {y.shape}")
x, y = x.reshape(-1), y.reshape(-1)
return x, y, shape
def fold_array(
x : np.ndarray,
y : np.ndarray,
shape : tuple
) -> tuple:
"""
用于将展开的数组折叠回去
更新记录:
2023-03-18 15:26:42 Sola 编写源代码
"""
x, y = x.reshape(shape), y.reshape(shape)
return x, y