Skip to content

Commit

Permalink
Add patch to handle gdal update to num_threads option (#424)
Browse files Browse the repository at this point in the history
* Add patch to handle gdal update to num_threads option

* Indent line as per pep8 suggestion

* Update gdal requirement to reflect earliest compatible version

---------

Co-authored-by: Simran S Sangha <ssangha@trappist.jpl.nasa.gov>
  • Loading branch information
sssangha and Simran S Sangha authored Jun 25, 2024
1 parent b0f042a commit 0d2e438
Show file tree
Hide file tree
Showing 7 changed files with 233 additions and 210 deletions.
2 changes: 1 addition & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ dependencies:
- asf_search
- dask
- dem_stitcher>=2.5.0
- gdal>=3.4.1
- gdal>=3.7.0
- h5py
- joblib
- matplotlib
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
asf_search
dask
dem_stitcher>=2.5.0
gdal>=3.4.1
gdal>=3.7.0
h5py
joblib
matplotlib
Expand Down
180 changes: 92 additions & 88 deletions tools/ARIAtools/extractProduct.py
Original file line number Diff line number Diff line change
Expand Up @@ -479,49 +479,52 @@ def merged_productbbox(
OG_bounds = list(
ARIAtools.util.shp.open_shp(bbox_file).bounds)
gdal_warp_kwargs = {
'format': 'MEM', 'multithread': True, 'dstSRS': f'EPSG:{lyr_proj}',
'options': [f'NUM_THREADS={num_threads}']}
'format': 'MEM', 'multithread': True, 'dstSRS': f'EPSG:{lyr_proj}'}
vrt = osgeo.gdal.BuildVRT('', product_dict[0]['unwrappedPhase'][0])
warp_options = osgeo.gdal.WarpOptions(**gdal_warp_kwargs)
ds = osgeo.gdal.Warp('', vrt, options=warp_options)
arrres = [abs(ds.GetGeoTransform()[1]),
abs(ds.GetGeoTransform()[-1])]
ds = None
with osgeo.gdal.config_options({"GDAL_NUM_THREADS": num_threads}):
warp_options = osgeo.gdal.WarpOptions(**gdal_warp_kwargs)
ds = osgeo.gdal.Warp('', vrt, options=warp_options)
arrres = [abs(ds.GetGeoTransform()[1]),
abs(ds.GetGeoTransform()[-1])]
ds = None

# warp again with fixed transform and bounds
gdal_warp_kwargs['outputBounds'] = OG_bounds
gdal_warp_kwargs['xRes'] = arrres[0]
gdal_warp_kwargs['yRes'] = arrres[1]
gdal_warp_kwargs['targetAlignedPixels'] = True
vrt = osgeo.gdal.BuildVRT('', product_dict[0]['unwrappedPhase'][0])
warp_options = osgeo.gdal.WarpOptions(**gdal_warp_kwargs)
ds = osgeo.gdal.Warp('', vrt, options=warp_options)

# Get shape of full res layers
arrshape = [ds.RasterYSize, ds.RasterXSize]
ds_gt = ds.GetGeoTransform()
new_bounds = [ds_gt[0], ds_gt[3] + (ds_gt[-1] * arrshape[0]),
ds_gt[0] + (ds_gt[1] * arrshape[1]), ds_gt[3]]

if OG_bounds != new_bounds:
# Use shapely to make list
user_bbox = shapely.geometry.Polygon(np.column_stack((
np.array([new_bounds[0], new_bounds[2], new_bounds[2],
new_bounds[0], new_bounds[0]]),
np.array([new_bounds[1], new_bounds[1], new_bounds[3],
new_bounds[3], new_bounds[1]]))))

# Save polygon in shapefile
bbox_file = os.path.join(os.path.dirname(workdir), 'user_bbox.json')
ARIAtools.util.shp.save_shp(
bbox_file, user_bbox, lyr_proj)
total_bbox = ARIAtools.util.shp.open_shp(prods_TOTbbox)
user_bbox = user_bbox.intersection(total_bbox)
ARIAtools.util.shp.save_shp(
prods_TOTbbox, user_bbox, lyr_proj)
with osgeo.gdal.config_options({"GDAL_NUM_THREADS": num_threads}):
warp_options = osgeo.gdal.WarpOptions(**gdal_warp_kwargs)
ds = osgeo.gdal.Warp('', vrt, options=warp_options)

# Get shape of full res layers
arrshape = [ds.RasterYSize, ds.RasterXSize]
ds_gt = ds.GetGeoTransform()
new_bounds = [ds_gt[0], ds_gt[3] + (ds_gt[-1] * arrshape[0]),
ds_gt[0] + (ds_gt[1] * arrshape[1]), ds_gt[3]]

if OG_bounds != new_bounds:
# Use shapely to make list
user_bbox = shapely.geometry.Polygon(np.column_stack((
np.array([new_bounds[0], new_bounds[2], new_bounds[2],
new_bounds[0], new_bounds[0]]),
np.array([new_bounds[1], new_bounds[1], new_bounds[3],
new_bounds[3], new_bounds[1]]))))

# Save polygon in shapefile
bbox_file = os.path.join(
os.path.dirname(workdir), 'user_bbox.json')
ARIAtools.util.shp.save_shp(
bbox_file, user_bbox, lyr_proj)
total_bbox = ARIAtools.util.shp.open_shp(prods_TOTbbox)
user_bbox = user_bbox.intersection(total_bbox)
ARIAtools.util.shp.save_shp(
prods_TOTbbox, user_bbox, lyr_proj)

# Get projection of full res layers
proj = ds.GetProjection()
ds = None
# Get projection of full res layers
proj = ds.GetProjection()
ds = None

return (metadata_dict, product_dict, bbox_file, prods_TOTbbox,
prods_TOTbbox_metadatalyr, arrres, proj, is_nisar_file)
Expand Down Expand Up @@ -915,8 +918,7 @@ def export_product_worker(
gdal_warp_kwargs = {
'format': outputFormat, 'cutlineDSName': prods_TOTbbox,
'outputBounds': bounds, 'xRes': arrres[0], 'yRes': arrres[1],
'targetAlignedPixels': True, 'multithread': True, 'dstSRS': proj,
'options': [f'NUM_THREADS={num_threads}']}
'targetAlignedPixels': True, 'multithread': True, 'dstSRS': proj}

mask = None if maskfile is None else osgeo.gdal.Open(maskfile)
dem = None if demfile is None else osgeo.gdal.Open(demfile)
Expand Down Expand Up @@ -964,25 +966,26 @@ def export_product_worker(
# Extract/crop full res layers, except for "unw" and "conn_comp"
# which requires advanced stitching
elif layer != 'unwrappedPhase' and layer != 'connectedComponents':
warp_options = osgeo.gdal.WarpOptions(**gdal_warp_kwargs)
if outputFormat == 'VRT':
# building the virtual vrt
osgeo.gdal.BuildVRT(outname + "_uncropped" + '.vrt', product)

# building the cropped vrt
osgeo.gdal.Warp(
outname + '.vrt', outname + '_uncropped.vrt',
options=warp_options)
else:
# building the VRT
osgeo.gdal.BuildVRT(outname + '.vrt', product)
osgeo.gdal.Warp(
outname, outname + '.vrt', options=warp_options)
with osgeo.gdal.config_options({"GDAL_NUM_THREADS": num_threads}):
warp_options = osgeo.gdal.WarpOptions(**gdal_warp_kwargs)
if outputFormat == 'VRT':
# building the virtual vrt
osgeo.gdal.BuildVRT(outname + "_uncropped" + '.vrt', product)

# building the cropped vrt
osgeo.gdal.Warp(
outname + '.vrt', outname + '_uncropped.vrt',
options=warp_options)
else:
# building the VRT
osgeo.gdal.BuildVRT(outname + '.vrt', product)
osgeo.gdal.Warp(
outname, outname + '.vrt', options=warp_options)

# Update VRT
osgeo.gdal.Translate(
outname + '.vrt', outname,
options=osgeo.gdal.TranslateOptions(format="VRT"))
# Update VRT
osgeo.gdal.Translate(
outname + '.vrt', outname,
options=osgeo.gdal.TranslateOptions(format="VRT"))

# Extract/crop phs and conn_comp layers
else:
Expand Down Expand Up @@ -1422,9 +1425,9 @@ def finalize_metadata(outname, bbox_bounds, arrres, dem_bounds, prods_TOTbbox,

# load layered metadata array
tmp_name = outname + '.vrt'
warp_options = osgeo.gdal.WarpOptions(
format="MEM", options=['NUM_THREADS=%s' % (num_threads)])
data_array = osgeo.gdal.Warp('', tmp_name, options=warp_options)
with osgeo.gdal.config_options({"GDAL_NUM_THREADS": num_threads}):
warp_options = osgeo.gdal.WarpOptions(format="MEM")
data_array = osgeo.gdal.Warp('', tmp_name, options=warp_options)

# get minimum version
version_check = []
Expand Down Expand Up @@ -1513,23 +1516,24 @@ def finalize_metadata(outname, bbox_bounds, arrres, dem_bounds, prods_TOTbbox,
# it must be cut to conform with these bounds.
# Crop to track extents
data_array_nodata = data_array.GetRasterBand(1).GetNoDataValue()
gdal_warp_kwargs = {
'format': outputFormat, 'cutlineDSName': prods_TOTbbox,
'outputBounds': dem_bounds, 'dstNodata': data_array_nodata,
'xRes': dem_arrres[0], 'yRes': dem_arrres[1],
'targetAlignedPixels': True, 'multithread': True,
'options': [f'NUM_THREADS={num_threads}']}
warp_options = osgeo.gdal.WarpOptions(**gdal_warp_kwargs)
osgeo.gdal.Warp(tmp_name + '_temp', tmp_name, options=warp_options)
with osgeo.gdal.config_options({"GDAL_NUM_THREADS": num_threads}):
gdal_warp_kwargs = {
'format': outputFormat, 'cutlineDSName': prods_TOTbbox,
'outputBounds': dem_bounds, 'dstNodata': data_array_nodata,
'xRes': dem_arrres[0], 'yRes': dem_arrres[1],
'targetAlignedPixels': True, 'multithread': True}
warp_options = osgeo.gdal.WarpOptions(**gdal_warp_kwargs)
osgeo.gdal.Warp(tmp_name + '_temp', tmp_name, options=warp_options)

# Adjust shape
gdal_warp_kwargs = {
'format': outputFormat, 'cutlineDSName': prods_TOTbbox,
'outputBounds': bbox_bounds, 'dstNodata': data_array_nodata,
'xRes': arrres[0], 'yRes': arrres[1], 'targetAlignedPixels': True,
'multithread': True, 'options': [f'NUM_THREADS={num_threads}']}
warp_options = osgeo.gdal.WarpOptions(**gdal_warp_kwargs)
osgeo.gdal.Warp(outname, tmp_name + '_temp', options=warp_options)
with osgeo.gdal.config_options({"GDAL_NUM_THREADS": num_threads}):
gdal_warp_kwargs = {
'format': outputFormat, 'cutlineDSName': prods_TOTbbox,
'outputBounds': bbox_bounds, 'dstNodata': data_array_nodata,
'xRes': arrres[0], 'yRes': arrres[1], 'targetAlignedPixels': True,
'multithread': True}
warp_options = osgeo.gdal.WarpOptions(**gdal_warp_kwargs)
osgeo.gdal.Warp(outname, tmp_name + '_temp', options=warp_options)

# remove temp files
for i in glob.glob(outname + '*_temp*'):
Expand Down Expand Up @@ -1847,21 +1851,21 @@ def gacos_correction(full_product_dict, gacos_products, bbox_file,
product_dict[2][i][0])

# Open corresponding tropo products and pass the difference
gdal_warp_kwargs = {'format': outputFormat,
'cutlineDSName': prods_TOTbbox,
'outputBounds': bounds,
'xRes': arrres[0],
'yRes': arrres[1],
'targetAlignedPixels': True,
'multithread': True,
'options': [f'NUM_THREADS={num_threads}']}
tropo_reference = osgeo.gdal.Warp(
'', tropo_reference, options=osgeo.gdal.WarpOptions(
**gdal_warp_kwargs)).ReadAsArray()
tropo_secondary = osgeo.gdal.Warp(
'', tropo_secondary, options=osgeo.gdal.WarpOptions(
**gdal_warp_kwargs)).ReadAsArray()
tropo_product = np.subtract(tropo_secondary, tropo_reference)
with osgeo.gdal.config_options({"GDAL_NUM_THREADS": num_threads}):
gdal_warp_kwargs = {'format': outputFormat,
'cutlineDSName': prods_TOTbbox,
'outputBounds': bounds,
'xRes': arrres[0],
'yRes': arrres[1],
'targetAlignedPixels': True,
'multithread': True}
tropo_reference = osgeo.gdal.Warp(
'', tropo_reference, options=osgeo.gdal.WarpOptions(
**gdal_warp_kwargs)).ReadAsArray()
tropo_secondary = osgeo.gdal.Warp(
'', tropo_secondary, options=osgeo.gdal.WarpOptions(
**gdal_warp_kwargs)).ReadAsArray()
tropo_product = np.subtract(tropo_secondary, tropo_reference)

# Convert troposphere from m to rad
scale = float(metadata_dict[1][i][0]) / (4 * np.pi)
Expand Down
34 changes: 18 additions & 16 deletions tools/ARIAtools/util/dem.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,15 @@ def prep_dem(demfilename, bbox_file, prods_TOTbbox, prods_TOTbbox_metadatalyr,
ds_aria = osgeo.gdal.Open(aria_dem)

else:
gdal_warp_kwargs = {
'format': outputFormat, 'cutlineDSName': prods_TOTbbox,
'outputBounds': bounds, 'outputType': osgeo.gdal.GDT_Int16,
'xRes': arrres[0], 'yRes': arrres[1], 'targetAlignedPixels': True,
'multithread': True, 'options': [f'NUM_THREADS={num_threads}']}
osgeo.gdal.Warp(
aria_dem, demfilename,
options=osgeo.gdal.WarpOptions(**gdal_warp_kwargs))
with osgeo.gdal.config_options({"GDAL_NUM_THREADS": num_threads}):
gdal_warp_kwargs = {
'format': outputFormat, 'cutlineDSName': prods_TOTbbox,
'outputBounds': bounds, 'outputType': osgeo.gdal.GDT_Int16,
'xRes': arrres[0], 'yRes': arrres[1],
'targetAlignedPixels': True, 'multithread': True}
osgeo.gdal.Warp(
aria_dem, demfilename,
options=osgeo.gdal.WarpOptions(**gdal_warp_kwargs))

update_file = osgeo.gdal.Open(aria_dem, osgeo.gdal.GA_Update)
update_file.SetProjection(proj)
Expand All @@ -99,14 +100,15 @@ def prep_dem(demfilename, bbox_file, prods_TOTbbox, prods_TOTbbox_metadatalyr,
bounds = list(
ARIAtools.util.shp.open_shp(prods_TOTbbox_metadatalyr).bounds)

gdal_warp_kwargs = {
'format': outputFormat, 'outputBounds': bounds, 'xRes': arrres[0],
'yRes': arrres[1], 'targetAlignedPixels': True, 'multithread': True,
'options': ['NUM_THREADS=%s' % (num_threads) + ' -overwrite']}
demfile_expanded = aria_dem.replace('.dem', '_expanded.dem')
ds_aria_expanded = osgeo.gdal.Warp(
demfile_expanded, aria_dem,
options=osgeo.gdal.WarpOptions(**gdal_warp_kwargs))
with osgeo.gdal.config_options({"GDAL_NUM_THREADS": num_threads}):
gdal_warp_kwargs = {
'format': outputFormat, 'outputBounds': bounds, 'xRes': arrres[0],
'yRes': arrres[1], 'targetAlignedPixels': True,
'multithread': True, 'options': ['-overwrite']}
demfile_expanded = aria_dem.replace('.dem', '_expanded.dem')
ds_aria_expanded = osgeo.gdal.Warp(
demfile_expanded, aria_dem,
options=osgeo.gdal.WarpOptions(**gdal_warp_kwargs))

# Delete temporary dem-stitcher directory
if os.path.exists(f'{dem_name}_tiles'):
Expand Down
44 changes: 24 additions & 20 deletions tools/ARIAtools/util/mask.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,11 +81,12 @@ def prep_mask(

# get output parameters from temp file
crs = pyproj.CRS.from_wkt(proj)
osgeo.gdal.Warp(
ref_file, product_dict[0], format=outputFormat,
outputBounds=bounds, xRes=arrres[0], yRes=arrres[1],
targetAlignedPixels=True, multithread=True,
options=['NUM_THREADS=%s' % (num_threads)])
with osgeo.gdal.config_options({"GDAL_NUM_THREADS": num_threads}):
osgeo.gdal.Warp(
ref_file, product_dict[0], format=outputFormat,
outputBounds=bounds, xRes=arrres[0], yRes=arrres[1],
targetAlignedPixels=True, multithread=True)

with rasterio.open(ref_file) as src:
reference_gt = src.transform
resize_col = src.width
Expand All @@ -108,11 +109,12 @@ def prep_mask(
resampling=resampling_mode)

# save cropped mask with precise spacing
osgeo.gdal.Warp(
maskfilename, uncropped_maskfilename, format=outputFormat,
outputBounds=bounds, outputType=osgeo.gdal.GDT_Byte,
xRes=arrres[0], yRes=arrres[1], targetAlignedPixels=True,
multithread=True, options=['NUM_THREADS=%s' % (num_threads)])
with osgeo.gdal.config_options({"GDAL_NUM_THREADS": num_threads}):
osgeo.gdal.Warp(
maskfilename, uncropped_maskfilename, format=outputFormat,
outputBounds=bounds, outputType=osgeo.gdal.GDT_Byte,
xRes=arrres[0], yRes=arrres[1], targetAlignedPixels=True,
multithread=True)

update_file = osgeo.gdal.Open(maskfilename, osgeo.gdal.GA_Update)
update_file.SetProjection(proj)
Expand Down Expand Up @@ -157,11 +159,12 @@ def prep_mask(
assert ds is not None, f'Could not open user mask: {user_mask}'

# crop the user mask and write
osgeo.gdal.Warp(
f'{local_mask}', ds, format=outputFormat,
cutlineDSName=prods_TOTbbox, outputBounds=bounds, xRes=arrres[0],
yRes=arrres[1], targetAlignedPixels=True, multithread=True,
options=[f'NUM_THREADS={num_threads}'])
with osgeo.gdal.config_options({"GDAL_NUM_THREADS": num_threads}):
osgeo.gdal.Warp(
f'{local_mask}', ds, format=outputFormat,
cutlineDSName=prods_TOTbbox, outputBounds=bounds,
xRes=arrres[0], yRes=arrres[1], targetAlignedPixels=True,
multithread=True)

# set projection of the local mask
mask_file = osgeo.gdal.Open(local_mask, osgeo.gdal.GA_Update)
Expand All @@ -186,11 +189,12 @@ def prep_mask(
mask_file.GetRasterBand(1).WriteArray(mask_arr * amp_file)

# crop/expand mask to DEM size?
osgeo.gdal.Warp(
maskfilename, maskfilename, format=outputFormat,
cutlineDSName=prods_TOTbbox, outputBounds=bounds, xRes=arrres[0],
yRes=arrres[1], targetAlignedPixels=True, multithread=True,
options=[f'NUM_THREADS={num_threads} -overwrite'])
with osgeo.gdal.config_options({"GDAL_NUM_THREADS": num_threads}):
osgeo.gdal.Warp(
maskfilename, maskfilename, format=outputFormat,
cutlineDSName=prods_TOTbbox, outputBounds=bounds, xRes=arrres[0],
yRes=arrres[1], targetAlignedPixels=True, multithread=True,
options=['-overwrite'])

mask = osgeo.gdal.Open(maskfilename, osgeo.gdal.GA_Update)
mask.SetProjection(proj)
Expand Down
Loading

0 comments on commit 0d2e438

Please sign in to comment.