@@ -1055,6 +1055,225 @@ def identity(in_ar, out_ar, *args, **kwargs):
1055
1055
assert vrt_ds .GetRasterBand (1 ).DataType == dtype
1056
1056
1057
1057
1058
+ ###############################################################################
1059
+ # Test arbitrary expression pixel functions
1060
+
1061
+
1062
+ def vrt_expression_xml (tmpdir , expression , dialect , sources ):
1063
+
1064
+ drv = gdal .GetDriverByName ("GTiff" )
1065
+
1066
+ nx = 1
1067
+ ny = 1
1068
+
1069
+ expression = expression .replace ("<" , "<" ).replace (">" , ">" )
1070
+
1071
+ xml = f"""<VRTDataset rasterXSize="{ nx } " rasterYSize="{ ny } ">
1072
+ <VRTRasterBand dataType="Float64" band="1" subClass="VRTDerivedRasterBand">
1073
+ <PixelFunctionType>expression</PixelFunctionType>
1074
+ <PixelFunctionArguments expression="{ expression } " dialect="{ dialect } " />"""
1075
+
1076
+ for i , source in enumerate (sources ):
1077
+ if type (source ) is tuple :
1078
+ source_name , source_value = source
1079
+ else :
1080
+ source_name = ""
1081
+ source_value = source
1082
+
1083
+ src_fname = tmpdir / f"source_{ i } .tif"
1084
+
1085
+ with drv .Create (src_fname , 1 , 1 , 1 , gdal .GDT_Float64 ) as ds :
1086
+ ds .GetRasterBand (1 ).Fill (source_value )
1087
+
1088
+ xml += f"""<SimpleSource name="{ source_name } ">
1089
+ <SourceFilename relativeToVRT="0">{ src_fname } </SourceFilename>
1090
+ <SourceBand>1</SourceBand>
1091
+ </SimpleSource>"""
1092
+
1093
+ xml += "</VRTRasterBand></VRTDataset>"
1094
+
1095
+ return xml
1096
+
1097
+
1098
+ @pytest .mark .parametrize (
1099
+ "expression,sources,result,dialects" ,
1100
+ [
1101
+ pytest .param ("A" , [("A" , 77 )], 77 , None , id = "identity" ),
1102
+ pytest .param (
1103
+ "(NIR-R)/(NIR+R)" ,
1104
+ [("NIR" , 77 ), ("R" , 63 )],
1105
+ (77 - 63 ) / (77 + 63 ),
1106
+ None ,
1107
+ id = "simple expression" ,
1108
+ ),
1109
+ pytest .param (
1110
+ "if (A > B) 1.5*C ; else A" ,
1111
+ [("A" , 77 ), ("B" , 63 ), ("C" , 18 )],
1112
+ 27 ,
1113
+ ["exprtk" ],
1114
+ id = "exprtk conditional (explicit)" ,
1115
+ ),
1116
+ pytest .param (
1117
+ "(A > B) ? 1.5*C : A" ,
1118
+ [("A" , 77 ), ("B" , 63 ), ("C" , 18 )],
1119
+ 27 ,
1120
+ ["muparser" ],
1121
+ id = "muparser conditional (explicit)" ,
1122
+ ),
1123
+ pytest .param (
1124
+ "(A > B)*(1.5*C) + (A <= B)*(A)" ,
1125
+ [("A" , 77 ), ("B" , 63 ), ("C" , 18 )],
1126
+ 27 ,
1127
+ None ,
1128
+ id = "conditional (implicit)" ,
1129
+ ),
1130
+ pytest .param (
1131
+ "B2 * PopDensity" ,
1132
+ [("PopDensity" , 3 ), ("" , 7 )],
1133
+ 21 ,
1134
+ None ,
1135
+ id = "implicit source name" ,
1136
+ ),
1137
+ pytest .param (
1138
+ "B1 / sum(BANDS)" ,
1139
+ [("" , 3 ), ("" , 5 ), ("" , 31 )],
1140
+ 3 / (3 + 5 + 31 ),
1141
+ None ,
1142
+ id = "use of BANDS variable" ,
1143
+ ),
1144
+ pytest .param (
1145
+ "B1 / sum(B2, B3) " ,
1146
+ [("" , 3 ), ("" , 5 ), ("" , 31 )],
1147
+ 3 / (5 + 31 ),
1148
+ None ,
1149
+ id = "aggregate specified inputs" ,
1150
+ ),
1151
+ pytest .param (
1152
+ "var q[2] := {B2, B3}; B1 * q" ,
1153
+ [("" , 3 ), ("" , 5 ), ("" , 31 )],
1154
+ 15 , # First value in returned vector. This behavior doesn't seem desirable
1155
+ # but I haven't figured out how to detect a vector return.
1156
+ ["exprtk" ],
1157
+ id = "return vector" ,
1158
+ ),
1159
+ pytest .param (
1160
+ "B1 + B2 + B3" ,
1161
+ (5 , 9 , float ("nan" )),
1162
+ float ("nan" ),
1163
+ None ,
1164
+ id = "nan propagated via arithmetic" ,
1165
+ ),
1166
+ pytest .param (
1167
+ "if (B3) B1 ; else B2" ,
1168
+ (5 , 9 , float ("nan" )),
1169
+ 5 ,
1170
+ ["exprtk" ],
1171
+ id = "exprtk nan = truth in conditional?" ,
1172
+ ),
1173
+ pytest .param (
1174
+ "B3 ? B1 : B2" ,
1175
+ (5 , 9 , float ("nan" )),
1176
+ 5 ,
1177
+ ["muparser" ],
1178
+ id = "muparser nan = truth in conditional?" ,
1179
+ ),
1180
+ pytest .param (
1181
+ "if (B3 > 0) B1 ; else B2" ,
1182
+ (5 , 9 , float ("nan" )),
1183
+ 9 ,
1184
+ ["exprtk" ],
1185
+ id = "exprtk nan comparison is false in conditional" ,
1186
+ ),
1187
+ pytest .param (
1188
+ "(B3 > 0) ? B1 : B2" ,
1189
+ (5 , 9 , float ("nan" )),
1190
+ 9 ,
1191
+ ["muparser" ],
1192
+ id = "muparser nan comparison is false in conditional" ,
1193
+ ),
1194
+ pytest .param (
1195
+ "if (B1 > 5) B1" ,
1196
+ (1 ,),
1197
+ float ("nan" ),
1198
+ ["exprtk" ],
1199
+ id = "expression returns nodata" ,
1200
+ ),
1201
+ ],
1202
+ )
1203
+ @pytest .mark .parametrize ("dialect" , ("exprtk" , "muparser" ))
1204
+ def test_vrt_pixelfn_expression (
1205
+ tmp_vsimem , expression , sources , result , dialect , dialects
1206
+ ):
1207
+ pytest .importorskip ("numpy" )
1208
+
1209
+ if not gdaltest .gdal_has_vrt_expression_dialect (dialect ):
1210
+ pytest .skip (f"Expression dialect { dialect } is not available" )
1211
+
1212
+ if dialects and dialect not in dialects :
1213
+ pytest .skip (f"Expression not supported for dialect { dialect } " )
1214
+
1215
+ xml = vrt_expression_xml (tmp_vsimem , expression , dialect , sources )
1216
+
1217
+ with gdal .Open (xml ) as ds :
1218
+ assert pytest .approx (ds .ReadAsArray ()[0 ][0 ], nan_ok = True ) == result
1219
+
1220
+
1221
+ @pytest .mark .parametrize (
1222
+ "expression,sources,dialect,exception" ,
1223
+ [
1224
+ pytest .param (
1225
+ "A*B + C" ,
1226
+ [("A" , 77 ), ("B" , 63 )],
1227
+ "exprtk" ,
1228
+ "Undefined symbol" ,
1229
+ id = "exprtk undefined variable" ,
1230
+ ),
1231
+ pytest .param (
1232
+ "A*B + C" ,
1233
+ [("A" , 77 ), ("B" , 63 )],
1234
+ "muparser" ,
1235
+ "Unexpected token" ,
1236
+ id = "muparser undefined variable" ,
1237
+ ),
1238
+ pytest .param (
1239
+ "(" .join (["asin" , "sin" , "acos" , "cos" ] * 100 ) + "(X" + 100 * 4 * ")" ,
1240
+ [("X" , 0.5 )],
1241
+ "exprtk" ,
1242
+ "exceeds maximum allowed stack depth" ,
1243
+ id = "expression is too complex" ,
1244
+ ),
1245
+ pytest .param (
1246
+ " " .join (["sin(x) + cos(x)" ] * 10000 ),
1247
+ [("x" , 0.5 )],
1248
+ "exprtk" ,
1249
+ "exceeds maximum of 100000 set by GDAL_EXPRTK_MAX_EXPRESSION_LENGTH" ,
1250
+ id = "expression is too long" ,
1251
+ ),
1252
+ ],
1253
+ )
1254
+ def test_vrt_pixelfn_expression_invalid (
1255
+ tmp_vsimem , expression , sources , dialect , exception
1256
+ ):
1257
+ pytest .importorskip ("numpy" )
1258
+
1259
+ if not gdaltest .gdal_has_vrt_expression_dialect (dialect ):
1260
+ pytest .skip (f"Expression dialect { dialect } is not available" )
1261
+
1262
+ messages = []
1263
+
1264
+ def handle (ecls , ecode , emsg ):
1265
+ messages .append (emsg )
1266
+
1267
+ xml = vrt_expression_xml (tmp_vsimem , expression , dialect , sources )
1268
+
1269
+ with gdaltest .error_handler (handle ):
1270
+ ds = gdal .Open (xml )
1271
+ if ds :
1272
+ assert ds .ReadAsArray () is None
1273
+
1274
+ assert exception in "" .join (messages )
1275
+
1276
+
1058
1277
###############################################################################
1059
1278
# Cleanup.
1060
1279
0 commit comments