1"""Low level functions to write an xarray dataarray to disk in the sft conventions. 

2 

3These are functions ported from a collection of utilities initially in https://bitbucket.csiro.au/projects/SF/repos/python_functions/browse/swift_utility/swift_io.py 

4""" 

import os  # noqa: I001
from enum import Enum
from typing import Optional

import numpy as np
import pandas as pd
import xarray as xr

from efts_io.conventions import (
    DAT_TYPE_ATTR_KEY,
    DAT_TYPE_DESCRIPTION_ATTR_KEY,
    LAT_VARNAME,
    LOCATION_TYPE_ATTR_KEY,
    LON_VARNAME,
    REALISATION_DIMNAME,
    STATION_ID_DIMNAME,
    STF_2_0_URL,
    TYPE_ATTR_KEY,
    TYPE_DESCRIPTION_ATTR_KEY,
    TYPES_CONVERTIBLE_TO_TIMESTAMP,
    AttributesErrorLevel,
    check_optional_variable_attributes,
    convert_to_datetime64_utc,
)

from netCDF4 import Dataset

class StfVariable(Enum):
    STREAMFLOW = 1
    POTENTIAL_EVAPOTRANSPIRATION = 2
    RAINFALL = 3
    SNOW_WATER_EQUIVALENT = 4
    MINIMUM_TEMPERATURE = 5
    MAXIMUM_TEMPERATURE = 6


class StfDataType(Enum):
    DERIVED_FROM_OBSERVATIONS = 1
    FORECAST = 2
    OBSERVED = 3
    SIMULATED = 4
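
# The integer values of StfVariable and StfDataType are 1-based indices into the
# prescribed name arrays defined in write_nc_stf2 (v_type, v_type_long, v_units,
# d_type, d_type_long); write_nc_stf2 subtracts one before indexing.
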

def _create_cf_time_axis(data: xr.DataArray, timestep_str: str) -> tuple[np.ndarray, str, str]:
    """Create a CF-compliant time axis for the given xarray DataArray.

    Args:
        data (xr.DataArray): The input data array.
        timestep_str (str): The time step string (e.g., "days").

    Returns:
        tuple[np.ndarray, str, str]: A tuple containing the encoded time axis,
            the units string, and the calendar string.
    """
    from xarray.coding import times  # noqa: I001
    from efts_io.conventions import TIME_DIMNAME

    tt = data[TIME_DIMNAME].values
    if len(tt) == 0:
        raise ValueError("Cannot create CF time axis from empty data array.")
    origin = tt[0]
    # will be strict in the first instance, relax or expand later on as needed
    if not any(isinstance(origin, t) for t in TYPES_CONVERTIBLE_TO_TIMESTAMP):
        raise TypeError(
            f"Expected data[TIME_DIMNAME] to be of a type convertible to pd.Timestamp, got {type(origin)} instead.",
        )
    origin = convert_to_datetime64_utc(origin)
    dtimes = [convert_to_datetime64_utc(x) for x in tt]
    # NOTE: this is not quite what is suggested by the STF convention in the example string.
    # The format below is closer to the ISO 8601 specification, except that we use a space
    # rather than 'T' as the date/time separator.
    # https://docs.digi.com/resources/documentation/digidocs/90001488-13/reference/r_iso_8601_date_format.htm
    iso_8601_origin = pd.Timestamp(origin).tz_localize("UTC")
    formatted_string = iso_8601_origin.strftime("%Y-%m-%d %H:%M:%S")
    timezone_offset = iso_8601_origin.strftime("%z")
    formatted_timezone_offset = f"{timezone_offset[:3]}:{timezone_offset[3:]}"
    formatted_string_with_tz = f"{formatted_string}{formatted_timezone_offset}"

    axis, units, calendar = times.encode_cf_datetime(
        dates=dtimes,  #: 'T_DuckArray',
        units=f"{timestep_str} since {formatted_string_with_tz}",  #: 'str | None' = None,
        calendar=None,  #: 'str | None' = None,
        dtype=None,  #: 'np.dtype | None' = None,
    )  # -> 'tuple[T_DuckArray, str, str]'
    # Override the units returned by times.encode_cf_datetime, which vary with the
    # input unit string and may lack the time zone, or use a 'T' separator.
    units = f"{timestep_str} since {formatted_string_with_tz}"
    return axis, units, calendar
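
# Note: for daily data whose first time stamp is 2010-01-01 00:00 UTC, the function
# above yields a units string of the form "days since 2010-01-01 00:00:00+00:00".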

def write_nc_stf2(
    out_nc_file: str,
    dataset: xr.Dataset,
    data: xr.DataArray,
    var_type: StfVariable = StfVariable.STREAMFLOW,
    data_type: StfDataType = StfDataType.OBSERVED,
    stf_nc_vers: int = 2,
    ens: bool = False,  # noqa: FBT001, FBT002
    timestep: str = "days",
    data_qual: Optional[xr.DataArray] = None,
    overwrite: bool = True,  # noqa: FBT001, FBT002
    # loc_info: Optional[Dict[str, Any]] = None,
) -> None:
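    """Write a data array to a netCDF file following the STF netCDF conventions.

    Args:
        out_nc_file: Path of the netCDF file to create.
        dataset: Dataset providing station metadata and the mandatory STF
            global attributes and variables.
        data: Values to write, with the mandatory STF dimensions.
        var_type: The kind of variable written (streamflow, rainfall, etc.);
            determines the prescribed variable name, units and type attributes.
        data_type: The kind of data written (observed, simulated, etc.);
            determines the variable name suffix.
        stf_nc_vers: STF netCDF convention version; only 1 and 2 are supported.
        ens: If True, append an ensemble suffix to the variable name
            (used by the version 1 naming scheme only).
        timestep: Time step unit, e.g. "days" or "hours".
        data_qual: Optional data quality codes, written as a companion variable.
        overwrite: If True, an existing output file is overwritten.
    """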

    from efts_io.conventions import (  # noqa: I001
        X_VARNAME,
        Y_VARNAME,
        AREA_VARNAME,
        ELEVATION_VARNAME,
        AXIS_ATTR_KEY,
        CATCHMENT_ATTR_KEY,
        COMMENT_ATTR_KEY,
        ENS_MEMBER_DIMNAME,
        HISTORY_ATTR_KEY,
        INSTITUTION_ATTR_KEY,
        LEAD_TIME_DIMNAME,
        LONG_NAME_ATTR_KEY,
        SOURCE_ATTR_KEY,
        STANDARD_NAME_ATTR_KEY,
        STATION_DIMNAME,
        STATION_ID_VARNAME,
        STATION_NAME_VARNAME,
        STF_CONVENTION_VERSION_ATTR_KEY,
        STR_LEN_DIMNAME,
        TIME_DIMNAME,
        TIME_STANDARD_ATTR_KEY,
        TITLE_ATTR_KEY,
        UNITS_ATTR_KEY,
        has_required_xarray_dimensions,
        has_required_global_attributes,
        mandatory_xarray_dimensions,
        mandatory_global_attributes,
        has_required_variables_xr,
        mandatory_varnames_xr,
        has_variable,
    )

    if not has_required_xarray_dimensions(data):
        raise ValueError(
            f"DataArray must have the following dimensions: {mandatory_xarray_dimensions}",
        )

    if not has_required_global_attributes(dataset):
        raise ValueError(
            f"Dataset must have the following global attributes: {mandatory_global_attributes}",
        )

    if not has_required_variables_xr(dataset):
        raise ValueError(
            f"Dataset must have the following variables: {mandatory_varnames_xr}",
        )

    # Check that optional variables, if present, have the minimum attributes present.
    def _check_optional_var_attr(dataset: xr.Dataset, var_id: str) -> None:
        if has_variable(dataset, var_id):
            xrvar = dataset[var_id]
            check_optional_variable_attributes(xrvar, AttributesErrorLevel.ERROR)

    for var_id in (AREA_VARNAME, X_VARNAME, Y_VARNAME, ELEVATION_VARNAME):
        _check_optional_var_attr(dataset, var_id)

    intdata_type = "i4"

    var_type = var_type.value
    data_type = data_type.value

    n_stations = len(data[STATION_ID_DIMNAME])

    station = np.arange(1, n_stations + 1)

    # Retrieve arrays from expected variables in the input xarray dataset `dataset`
    station_id = dataset[STATION_ID_VARNAME].values
    station_name = dataset[STATION_NAME_VARNAME].values
    sub_x_centroid = dataset[LON_VARNAME].values
    sub_y_centroid = dataset[LAT_VARNAME].values

    # NOTE: the original code had an "other_station_id" option, apparently storing some
    # identifiers from the Bureau of Meteorology. For the time being, disable,
    # but initiate a discussion. See issue #9.
    # other_station_id = data["other_station_id"].values

    if timestep in ["weeks", "w", "wk", "week"]:
        timestep_str = "weeks"
    elif timestep in ["days", "d", "ds", "day"]:
        timestep_str = "days"
    elif timestep in ["hours", "h", "hr", "hour"]:
        timestep_str = "hours"
    elif timestep in ["minutes", "m", "min", "minute"]:
        timestep_str = "minutes"
    elif timestep in ["seconds", "s", "sec", "second"]:
        timestep_str = "seconds"
    else:
        raise ValueError(f"Unsupported or unrecognised time step unit: {timestep}")

    # Check if file exists
    if os.path.exists(out_nc_file):
        if not overwrite:
            raise FileExistsError(
                f"The file '{out_nc_file}' already exists; set overwrite=True to overwrite it, or give a new filename.",
            )
        os.remove(out_nc_file)
        # print(f"Warning: The file '{out_nc_file}' has been overwritten.")

    # Create netcdf file
    ncfile = Dataset(out_nc_file, "w", format="NETCDF4")
    # Global Attributes
    # ncfile.description = "CCLIR forecasts"
    ncfile.title = dataset.attrs.get(TITLE_ATTR_KEY, "")  # = nc_title
    ncfile.institution = dataset.attrs.get(INSTITUTION_ATTR_KEY, "")  # = inst
    ncfile.source = dataset.attrs.get(SOURCE_ATTR_KEY, "")  # = source
    ncfile.catchment = dataset.attrs.get(CATCHMENT_ATTR_KEY, "")  # = catchment
    ncfile.STF_convention_version = dataset.attrs.get(STF_CONVENTION_VERSION_ATTR_KEY, "")  # = stf_nc_vers
    ncfile.STF_nc_spec = STF_2_0_URL  # we do not transfer the spec version, this code determines it.
    ncfile.comment = dataset.attrs.get(COMMENT_ATTR_KEY, "")  # = comment
    ncfile.history = dataset.attrs.get(HISTORY_ATTR_KEY, "")
    # = "Created " + datetime.now(tz=timezone.utc).strftime("%Y-%m-%d %H:%M:%S")

    # station
    # --------------------
    ncfile.createDimension(STATION_DIMNAME, n_stations)
    station_var = ncfile.createVariable(STATION_DIMNAME, intdata_type, (STATION_DIMNAME,), fill_value=-9999)
    station_var[:] = station

    # station_id
    station_id_var = ncfile.createVariable(STATION_ID_VARNAME, intdata_type, (STATION_DIMNAME,), fill_value=-9999)
    station_id_var.setncattr(LONG_NAME_ATTR_KEY, "station or node identification code")
    station_id_var[:] = station_id

    # station_name
    ncfile.createDimension(STR_LEN_DIMNAME, 30)
    station_name_var = ncfile.createVariable(STATION_NAME_VARNAME, "c", (STATION_DIMNAME, STR_LEN_DIMNAME))
    station_name_var.setncattr(LONG_NAME_ATTR_KEY, "station or node name")
    for s_i, stn_name in enumerate(station_name):
        char_stn_name = [" "] * 30  # 30 char length
        stn_name_30 = stn_name[:30]
        char_stn_name[: len(stn_name_30)] = stn_name_30
        station_name_var[s_i, :] = char_stn_name

    # additional station id e.g. BoM
    # other_station_id_var = ncfile.createVariable("other_station_id", "c", (STATION_DIMNAME, STR_LEN_DIMNAME))
    # other_station_id_var.setncattr(LONG_NAME_ATTR_KEY, "other station id e.g. BoM")
    # for s_i, stn_name in enumerate(other_station_id):
    #     char_stn_name = [" "] * 30  # 30 char length
    #     stn_name_30 = stn_name[:30]
    #     char_stn_name[: len(stn_name_30)] = stn_name_30
    #     other_station_id_var[s_i, :] = char_stn_name

    # coordinates, area
    # --------------------
    lat_var = ncfile.createVariable(LAT_VARNAME, "f", (STATION_DIMNAME,), fill_value=-9999)
    lat_var.setncattr(LONG_NAME_ATTR_KEY, "latitude")
    lat_var.setncattr(UNITS_ATTR_KEY, "degrees_north")
    lat_var.setncattr(AXIS_ATTR_KEY, "y")
    lat_var[:] = sub_y_centroid

    lon_var = ncfile.createVariable(LON_VARNAME, "f", (STATION_DIMNAME,), fill_value=-9999)
    lon_var.setncattr(LONG_NAME_ATTR_KEY, "longitude")
    lon_var.setncattr(UNITS_ATTR_KEY, "degrees_east")
    lon_var.setncattr(AXIS_ATTR_KEY, "x")
    lon_var[:] = sub_x_centroid

    def add_optional_variables(dataset: xr.Dataset, ncfile: Dataset, var_id: str) -> None:
        if has_variable(dataset, var_id):
            ncvar_type = "f"
            xrvar = dataset[var_id]
            opt_nc_var = ncfile.createVariable(var_id, ncvar_type, (STATION_DIMNAME,), fill_value=-9999)
            opt_nc_var[:] = xrvar.values
            for x in (STANDARD_NAME_ATTR_KEY, LONG_NAME_ATTR_KEY, UNITS_ATTR_KEY):
                opt_nc_var.setncattr(x, xrvar.attrs[x])

    for var_id in (AREA_VARNAME, X_VARNAME, Y_VARNAME, ELEVATION_VARNAME):
        add_optional_variables(dataset, ncfile, var_id)
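
    # The optional variables (area, x, y, elevation), when present, have their
    # standard_name, long_name and units attributes copied verbatim onto the
    # corresponding netCDF variables by the loop above.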

    # lead time
    # ------------
    ncfile.createDimension(LEAD_TIME_DIMNAME, len(data[LEAD_TIME_DIMNAME]))
    lt_var = ncfile.createVariable(LEAD_TIME_DIMNAME, intdata_type, (LEAD_TIME_DIMNAME,), fill_value=-9999)
    lt_var.setncattr(STANDARD_NAME_ATTR_KEY, "lead time")
    lt_var.setncattr(LONG_NAME_ATTR_KEY, "forecast lead time")
    lt_var.setncattr(UNITS_ATTR_KEY, "days since time")
    lt_var.setncattr(AXIS_ATTR_KEY, "v")
    lt_var[:] = data[LEAD_TIME_DIMNAME].values

    # ensemble members
    # ------------------
    ncfile.createDimension(ENS_MEMBER_DIMNAME, len(data[REALISATION_DIMNAME]))
    ens_mem_var = ncfile.createVariable(ENS_MEMBER_DIMNAME, intdata_type, (ENS_MEMBER_DIMNAME,), fill_value=-9999)
    ens_mem_var.setncattr(STANDARD_NAME_ATTR_KEY, ENS_MEMBER_DIMNAME)
    ens_mem_var.setncattr(LONG_NAME_ATTR_KEY, "ensemble member")
    ens_mem_var.setncattr(UNITS_ATTR_KEY, "member id")
    ens_mem_var.setncattr(AXIS_ATTR_KEY, "u")
    ens_mem_var[:] = np.arange(1, len(data[REALISATION_DIMNAME]) + 1)

    # time
    # ------
    ncfile.createDimension(TIME_DIMNAME, len(data[TIME_DIMNAME]))
    time_var = ncfile.createVariable(TIME_DIMNAME, intdata_type, (TIME_DIMNAME,), fill_value=-9999)
    time_var.setncattr(STANDARD_NAME_ATTR_KEY, TIME_DIMNAME)
    time_var.setncattr(LONG_NAME_ATTR_KEY, TIME_DIMNAME)
    time_var.setncattr(TIME_STANDARD_ATTR_KEY, "UTC+00:00")
    time_var.setncattr(AXIS_ATTR_KEY, "t")

    # time_units_str = "days since {} 00:00:00".format(data.attrs["fcast_date"])
    axis_values, time_units_str, _ = _create_cf_time_axis(data, timestep_str)
    time_var.setncattr(UNITS_ATTR_KEY, time_units_str)
    time_var[:] = axis_values

    # Borrowing from create_empty_stfnc.m
    # Name Arrays
    v_type = ["q", "pet", "rain", "swe", "tmin", "tmax", "tave"]
    v_type_long = [
        "streamflow",
        "potential evapotranspiration",
        "rainfall",
        "snow water equivalent",
        "minimum temperature",
        "maximum temperature",
        "average temperature",
    ]
    v_units = ["m3/s", "mm", "mm", "mm", "K", "K", "K"]
    v_ttype = [3, 2, 2, 2, 5, 5, 5]
    # One description per entry in v_type; swe, tmin and tmax are assumed to be
    # point values recorded in the preceding interval, and tave averaged.
    v_ttype_name = [
        "averaged over the preceding interval",
        "accumulated over the preceding interval",
        "accumulated over the preceding interval",
        "point value recorded in the preceding interval",
        "point value recorded in the preceding interval",
        "point value recorded in the preceding interval",
        "averaged over the preceding interval",
    ]

    d_type = [None] * 4
    d_type_long = [None] * 4
    d_type[0] = "der"
    d_type_long[0] = "derived (from observations)"

    if int(stf_nc_vers) == 1:
        d_type[1] = "fcast"
        d_type_long[1] = "forecast"
    elif int(stf_nc_vers) == 2:  # noqa: PLR2004
        d_type[1] = "fct"
        d_type_long[1] = "forecast"
    else:
        raise ValueError("Version not recognised: currently only versions 1.x and 2.x are supported")

    d_type[2] = "obs"
    d_type_long[2] = "observed"
    d_type[3] = "sim"
    d_type_long[3] = "simulated"

    # change var_type and data_type to Python-based indices starting from 0
    var_type = var_type - 1
    data_type = data_type - 1
    # print(f"data_type: {data_type}")
    # Create prescribed variable names
    if int(stf_nc_vers) == 1:
        var_name_s = f"{v_type[var_type]}_{d_type[data_type]}"
        var_name_l = f"{d_type_long[data_type]} {v_type_long[var_type]}"
        if ens:
            var_name_s = f"{var_name_s}_ens"
            var_name_l = f"{var_name_l} ensemble"
    else:
        var_name_attr = d_type[data_type]
        dat_type_description = d_type_long[data_type]
        if data_type in [0, 2]:
            # print("Obs")
            var_name_s = f"{v_type[var_type]}_obs"
            var_name_l = f"observed {v_type_long[var_type]}"
        else:
            # print("Sim")
            var_name_s = f"{v_type[var_type]}_sim"
            var_name_l = f"simulated {v_type_long[var_type]}"
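
    # For example, with the version 2 naming scheme above, STREAMFLOW + OBSERVED
    # yields var_name_s == "q_obs", while RAINFALL with SIMULATED (or FORECAST)
    # yields "rain_sim".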

    qsim_var = ncfile.createVariable(
        var_name_s,
        "f",
        (TIME_DIMNAME, ENS_MEMBER_DIMNAME, STATION_DIMNAME, LEAD_TIME_DIMNAME),
        fill_value=-9999,
    )
    qsim_var.setncattr(STANDARD_NAME_ATTR_KEY, var_name_s)
    qsim_var.setncattr(LONG_NAME_ATTR_KEY, var_name_l)
    qsim_var.setncattr(UNITS_ATTR_KEY, v_units[var_type])

    qsim_var.setncattr(TYPE_ATTR_KEY, v_ttype[var_type])
    qsim_var.setncattr(TYPE_DESCRIPTION_ATTR_KEY, v_ttype_name[var_type])
    if int(stf_nc_vers) == 2:  # noqa: PLR2004
        qsim_var.setncattr(DAT_TYPE_ATTR_KEY, var_name_attr)
        qsim_var.setncattr(DAT_TYPE_DESCRIPTION_ATTR_KEY, dat_type_description)
        qsim_var.setncattr(LOCATION_TYPE_ATTR_KEY, "Point")
    else:
        qsim_var.setncattr(LOCATION_TYPE_ATTR_KEY, "Point")

    # WARNING: I do not like the look of the following; is it bug prone?
    qsim_var[:, :, :, :] = data.values[:]

    # Specify the quality variable
    if data_qual is not None:
        qu_var_name_s = f"{var_name_s}_qual"
        if int(stf_nc_vers) == 1:
            if data_type == 2:  # noqa: PLR2004
                qsim_qual_var = ncfile.createVariable(
                    qu_var_name_s, "f", (TIME_DIMNAME, STATION_DIMNAME, LEAD_TIME_DIMNAME), fill_value=-1,
                )
                qsim_qual_var[:, :, :] = data_qual.values[:]
            else:
                qsim_qual_var = ncfile.createVariable(
                    qu_var_name_s, "f", (TIME_DIMNAME, STATION_DIMNAME), fill_value=-1,
                )
                qsim_qual_var[:, :] = data_qual.values[:]
        else:
            qsim_qual_var = ncfile.createVariable(
                qu_var_name_s,
                "f",
                (TIME_DIMNAME, ENS_MEMBER_DIMNAME, STATION_DIMNAME, LEAD_TIME_DIMNAME),
                fill_value=-1,
            )
            qsim_qual_var[:, :, :, :] = data_qual.values[:]

        qu_var_name_l = f"{var_name_l} data quality"

        qsim_qual_var.setncattr(STANDARD_NAME_ATTR_KEY, qu_var_name_s)
        qsim_qual_var.setncattr(LONG_NAME_ATTR_KEY, qu_var_name_l)
        quality_code = data_qual.attrs.get("quality_code", "Quality codes")

        qsim_qual_var.setncattr(UNITS_ATTR_KEY, quality_code)

    # close file
    ncfile.close()
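
# Usage sketch (illustrative only): assumes `ds` is an xr.Dataset that already
# satisfies the mandatory STF dimensions, variables and global attributes
# checked at the top of write_nc_stf2, and `da` is the matching xr.DataArray.
#
#   write_nc_stf2(
#       "streamflow_obs.nc",
#       dataset=ds,
#       data=da,
#       var_type=StfVariable.STREAMFLOW,
#       data_type=StfDataType.OBSERVED,
#       timestep="days",
#   )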