-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathShpTif2ImageMask_singleband.py
322 lines (250 loc) · 10.6 KB
/
ShpTif2ImageMask_singleband.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
# 输入:所有的tif路径和shp文件的路径
# 输出:png格式的image和mask
# -*- coding: utf-8 -*-
import gdal
import osgeo
import os
import shutil
import shapefile
from osgeo import osr
import numpy as np
# import arcpy
from PIL import Image, ImageDraw
import sys
Image.MAX_IMAGE_PIXELS = None
# arcpy.env.workspace = r"D:\xzr\process\data.gbd" # arcgis地理数据库目录
def lonlat2geo(dataset, lon, lat):
'''
将经纬度坐标转为投影坐标(具体的投影坐标系由给定数据确定)
:param dataset: GDAL地理数据
:param lon: 地理坐标lon经度
:param lat: 地理坐标lat纬度
:return: 经纬度坐标(lon, lat)对应的投影坐标
'''
prosrs, geosrs = getSRSPair(dataset)
ct = osr.CoordinateTransformation(geosrs, prosrs)
coords = ct.TransformPoint(lon, lat)
return coords[:2]
def getSRSPair(dataset):
'''
获得给定数据的投影参考系和地理参考系
:param dataset: GDAL地理数据
:return: 投影参考系和地理参考系
'''
prosrs = osr.SpatialReference()
prosrs.ImportFromWkt(dataset.GetProjection())
geosrs = prosrs.CloneGeogCS()
return prosrs, geosrs
def geo2imagexy(dataset, x, y):
'''
根据GDAL的六 参数模型将给定的投影或地理坐标转为影像图上坐标(行列号)
:param dataset: GDAL地理数据
:param x: 投影或地理坐标x
:param y: 投影或地理坐标y
:return: 影坐标或地理坐标(x, y)对应的影像图上行列号(row, col)
'''
trans = dataset.GetGeoTransform()
a = np.array([[trans[1], trans[2]], [trans[4], trans[5]]])
b = np.array([x - trans[0], y - trans[3]])
return np.linalg.solve(a, b) # 使用numpy的linalg.solve进行二元一次方程的求解
def lonlat2imagexy(dataset, x, y):
'''
影像行列转经纬度:
:通过经纬度转平面坐标
:平面坐标转影像行列
'''
coords = lonlat2geo(dataset, x, y)
coords2 = geo2imagexy(dataset, coords[0], coords[1])
return (int(round(abs(coords2[0]))), int(round(abs(coords2[1]))))
# 获取影像边界
def raster_boarder(im_geotrans, im_width, im_height):
raster_x_min = im_geotrans[0]
raster_x_max = raster_x_min + im_width * im_geotrans[1]
raster_y_max = im_geotrans[3]
raster_y_min = raster_y_max + im_height * im_geotrans[5]
# raster_x_min+=1024*im_geotrans[1]
# raster_x_max-=1024*im_geotrans[1]
# raster_y_max+=1024*im_geotrans[5]
# raster_y_min-=1024*im_geotrans[5]
return raster_x_min, raster_x_max, raster_y_min, raster_y_max
# 迭代获取路径下所有tif影像
def getFileName(file_dir):
file_path_list = []
for root, dirs, files in os.walk(file_dir):
for file in files:
if (file[-3:] == "tif"):
file_path_list.append((root + '\\' + file))
return file_path_list
# 看一个polygon的xy是否与影像完全没有交集
def isOutOfRaster(x_list, y_list, raster_x_min, raster_x_max, raster_y_min, raster_y_max):
x_max = max(x_list)
y_max = max(y_list)
x_min = min(x_list)
y_min = min(y_list)
if (x_max < raster_x_min):
return True
if (y_max < raster_y_min):
return True
if (y_min > raster_y_max):
return True
if (x_min > raster_x_max):
return True
return False
# 需要改的参数
#########################################################################################################
date_string = 'bengta-ludian-dem' # 文件的独特命名
shp_path = r"I:\ludian_slope\good\bengta_valid.shp" # shp文件的路径, shapefile不支持中文路径
cell = 200
# tif的位置,两种方式
# tif_folder = r"H:\四川泥石流补充样本\2018" # 存放tif的总文件夹
# tif_files = getFileName(tif_folder)
# tif_files.extend(getFileName(r"H:\xzr\wanli\DOM\拼接\truecolor"))
tif_files=[r"I:\ludian_slope\good\slope_good_bengta.tif"]
# 生成文件总文件夹
all_dir = r"I:\ludian_slope\good\bengta\\"
# 颜色
# 滑坡
# color = (255, 255, 255)
# #崩塌
color=(0, 255, 0)
# #泥石流
# color=(0, 0, 255)
# 如果影像和shp都是投影坐标系
use_proj_coord = False
#########################################################################################################
print_shp=False
out_dir = all_dir + "tif" # 裁剪后图像保存路径
image_folder = all_dir + "image"
mask_folder = all_dir + "mask"
if not os.path.exists(out_dir):
os.mkdir(out_dir)
if not os.path.exists(image_folder):
os.mkdir(image_folder)
if not os.path.exists(mask_folder):
os.mkdir(mask_folder)
num = 0
print('共有影像:',len(tif_files))
for tif_file in tif_files:
# if('2019' not in tif_file):
# continue
print(tif_file)
dataset = gdal.Open(tif_file)
im_width = dataset.RasterXSize # 栅格矩阵的列数
im_height = dataset.RasterYSize # 栅格矩阵的行数
im_geotrans = dataset.GetGeoTransform() # 仿射矩阵
im_proj = dataset.GetProjection() # 地图投影信息
raster_x_min, raster_x_max, raster_y_min, raster_y_max = raster_boarder(im_geotrans, im_width, im_height)
print(raster_x_min, raster_x_max, raster_y_min, raster_y_max)
in_band1 = dataset.GetRasterBand(1)
# in_band2 = dataset.GetRasterBand(2)
# in_band3 = dataset.GetRasterBand(3)
sf = shapefile.Reader(shp_path) # 读取shp文件
shapes = sf.shapes()
# raster = np.zeros((im_height,im_width), dtype=np.int)
mask_all = Image.new('RGB', (im_width, im_height))
drawed_mask = 0
for i in range(len(shapes)):
if(not print_shp):
print_shp=True
print('共有shp:',len(shapes))
# print (str(i) + '/' + str(len(shapes)))
shp = shapes[i] # 获取shp文件中的每一个形状
point = shp.points # 获取每一个最小外接矩形的四个点
x_list = [ii[0] for ii in point]
y_list = [ii[1] for ii in point]
if (isOutOfRaster(x_list, y_list, raster_x_min, raster_x_max, raster_y_min, raster_y_max)):
continue
drawed_mask += 1
vertice = []
for j in range(len(x_list) - 1):
if (use_proj_coord):
coords = geo2imagexy(dataset, x_list[j], y_list[j])
coords = (int(round(abs(coords[0]))), int(round(abs(coords[1]))))
else:
coords = lonlat2imagexy(dataset, x_list[j], y_list[j])
vertice.append(coords)
draw = ImageDraw.Draw(mask_all)
# 滑坡
draw.polygon(vertice, fill=color)
print('影像共包含灾害数:', drawed_mask)
# mask_all.save(r"H:\xzr\duxiang_buffer\huapo_self\huapo_shp\2kmclip.png")
#
# sys.exit()
for i in range(len(shapes)):
# if(i>5):
# break
shp = shapes[i] # 获取shp文件中的每一个形状
point = shp.points # 获取每一个最小外接矩形的四个点
x_list = [ii[0] for ii in point]
y_list = [ii[1] for ii in point]
if (isOutOfRaster(x_list, y_list, raster_x_min, raster_x_max, raster_y_min, raster_y_max)):
continue
x_min = min(x_list)
y_min = min(y_list)
x_max = max(x_list)
y_max = max(y_list)
x_cen = (x_min + x_max) / 2
y_cen = (y_max + y_min) / 2
if (not use_proj_coord):
coords = lonlat2imagexy(dataset, x_cen, y_cen)
else:
coords = geo2imagexy(dataset, x_cen, y_cen)
coords = (int(round(abs(coords[0]))), int(round(abs(coords[1]))))
offset_x = coords[0] - cell / 2
offset_y = coords[1] - cell / 2
if (offset_x < 0 or offset_y < 0 or offset_x + cell > im_width or offset_y + cell > im_height):
continue
out_band1 = in_band1.ReadAsArray(offset_x, offset_y, cell, cell)
# out_band2 = in_band2.ReadAsArray(offset_x, offset_y, cell, cell)
# out_band3 = in_band3.ReadAsArray(offset_x, offset_y, cell, cell)
if (np.where(out_band1 == 0)[0].shape[0] > 1024 ** 2 / 2):
continue
print('灾害:', str(i) + '/' + str(len(shapes)))
num += 1
mask = mask_all.crop((offset_x, offset_y, offset_x + cell, offset_y + cell))
mask.save(mask_folder + '\\' + date_string + '-' + str(num) + '.png')
# 获取Tif的驱动,为创建切出来的图文件做准备
gtif_driver = gdal.GetDriverByName("GTiff")
# 创建切出来的要存的文件(3代表3个不都按,最后一个参数为数据类型,跟原文件一致)
out_ds = gtif_driver.Create(out_dir + '\\' + date_string + '-' + str(num) + '.tif', cell, cell, 1,
in_band1.DataType)
# print("create new tif file succeed")
# 获取原图的原点坐标信息
ori_transform = dataset.GetGeoTransform()
# if ori_transform:
# print (ori_transform)
# print("Origin = ({}, {})".format(ori_transform[0], ori_transform[3]))
# print("Pixel Size = ({}, {})".format(ori_transform[1], ori_transform[5]))
# 读取原图仿射变换参数值
top_left_x = ori_transform[0] # 左上角x坐标
w_e_pixel_resolution = ori_transform[1] # 东西方向像素分辨率
top_left_y = ori_transform[3] # 左上角y坐标
n_s_pixel_resolution = ori_transform[5] # 南北方向像素分辨率
# 根据反射变换参数计算新图的原点坐标
top_left_x = top_left_x + offset_x * w_e_pixel_resolution
top_left_y = top_left_y + offset_y * n_s_pixel_resolution
# 将计算后的值组装为一个元组,以方便设置
dst_transform = (top_left_x, ori_transform[1], ori_transform[2], top_left_y, ori_transform[4], ori_transform[5])
# 设置裁剪出来图的原点坐标
out_ds.SetGeoTransform(dst_transform)
# 设置SRS属性(投影信息)
out_ds.SetProjection(dataset.GetProjection())
# 写入目标文件
out_ds.GetRasterBand(1).WriteArray(out_band1)
# out_ds.GetRasterBand(2).WriteArray(out_band2)
# out_ds.GetRasterBand(3).WriteArray(out_band3)
# 将缓存写入磁盘
out_ds.FlushCache()
print("FlushCache succeed")
# 计算统计值
# for i in range(1, 3):
# out_ds.GetRasterBand(i).ComputeStatistics(False)
# print("ComputeStatistics succeed")
# del out_ds
img = Image.open(out_dir + '\\' + date_string + '-' + str(num) + '.tif')
img=np.asarray(img).astype(np.int)
img = Image.fromarray(img)
img.save(image_folder + '\\' + date_string + '-' + str(num) + '.png')
# print("End!")
print('image保存至:', image_folder)
print('mask保存至:', mask_folder)