1"""Low level functions to write an xarray dataarray to disk in the sft conventions.
3These are functions ported from a collection of utilities initially in https://bitbucket.csiro.au/projects/SF/repos/python_functions/browse/swift_utility/swift_io.py
4"""

import os  # noqa: I001
from enum import Enum
from typing import Optional

import numpy as np
import pandas as pd
import xarray as xr

from efts_io.conventions import (
    DAT_TYPE_ATTR_KEY,
    DAT_TYPE_DESCRIPTION_ATTR_KEY,
    LAT_VARNAME,
    LOCATION_TYPE_ATTR_KEY,
    LON_VARNAME,
    REALISATION_DIMNAME,
    STATION_ID_DIMNAME,
    STF_2_0_URL,
    TYPE_ATTR_KEY,
    TYPE_DESCRIPTION_ATTR_KEY,
    TYPES_CONVERTIBLE_TO_TIMESTAMP,
    AttributesErrorLevel,
    check_optional_variable_attributes,
    convert_to_datetime64_utc,
)

from netCDF4 import Dataset


class StfVariable(Enum):
    """Kinds of hydrologic variables supported by this writer."""

    STREAMFLOW = 1
    POTENTIAL_EVAPOTRANSPIRATION = 2
    RAINFALL = 3
    SNOW_WATER_EQUIVALENT = 4
    MINIMUM_TEMPERATURE = 5
    MAXIMUM_TEMPERATURE = 6


class StfDataType(Enum):
    """Provenance of the data being written (observed, forecast, etc.)."""

    DERIVED_FROM_OBSERVATIONS = 1
    FORECAST = 2
    OBSERVED = 3
    SIMULATED = 4


def _create_cf_time_axis(data: xr.DataArray, timestep_str: str) -> tuple[np.ndarray, str, str]:
    """Create a CF-compliant time axis for the given xarray DataArray.

    For example, for daily data starting on 2020-01-01 UTC, the units string
    produced is "days since 2020-01-01 00:00:00+00:00".

    Args:
        data (xr.DataArray): The input data array.
        timestep_str (str): The time step string (e.g., "days").

    Returns:
        tuple[np.ndarray, str, str]: A tuple containing the encoded time axis,
            the units string, and the calendar string.
    """
    from xarray.coding import times  # noqa: I001
    from efts_io.conventions import TIME_DIMNAME

    tt = data[TIME_DIMNAME].values
    if len(tt) == 0:
        raise ValueError("Cannot create CF time axis from empty data array.")
    origin = tt[0]
    # will be strict in the first instance, relax or expand later on as needed
    if not any(isinstance(origin, t) for t in TYPES_CONVERTIBLE_TO_TIMESTAMP):
        raise TypeError(
            f"Expected data[TIME_DIMNAME] to be of a type convertible to pd.Timestamp, got {type(origin)} instead.",
        )
    origin = convert_to_datetime64_utc(origin)
    dtimes = [convert_to_datetime64_utc(x) for x in tt]
    # NOTE: this is not quite what is suggested by the STF convention in the example string.
    # The form below is closer to the ISO 8601 specification, except that we use a space rather than 'T' as the date/time separator.
    # https://docs.digi.com/resources/documentation/digidocs/90001488-13/reference/r_iso_8601_date_format.htm
    iso_8601_origin = pd.Timestamp(origin).tz_localize("UTC")
    formatted_string = iso_8601_origin.strftime("%Y-%m-%d %H:%M:%S")
    timezone_offset = iso_8601_origin.strftime("%z")
    formatted_timezone_offset = f"{timezone_offset[:3]}:{timezone_offset[3:]}"
    formatted_string_with_tz = f"{formatted_string}{formatted_timezone_offset}"

    axis, units, calendar = times.encode_cf_datetime(
        dates=dtimes,
        units=f"{timestep_str} since {formatted_string_with_tz}",
        calendar=None,
        dtype=None,
    )  # -> 'tuple[T_DuckArray, str, str]'
    # Override the units returned by times.encode_cf_datetime, which vary
    # depending on the input unit string and may lack the time zone, or use a 'T' separator.
    units = f"{timestep_str} since {formatted_string_with_tz}"
    return axis, units, calendar


def write_nc_stf2(
    out_nc_file: str,
    dataset: xr.Dataset,
    data: xr.DataArray,
    var_type: StfVariable = StfVariable.STREAMFLOW,
    data_type: StfDataType = StfDataType.OBSERVED,
    stf_nc_vers: int = 2,
    ens: bool = False,  # noqa: FBT001, FBT002
    timestep: str = "days",
    data_qual: Optional[xr.DataArray] = None,
    overwrite: bool = True,  # noqa: FBT001, FBT002
    # loc_info: Optional[Dict[str, Any]] = None,
) -> None:
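    """Write a data array to a netCDF file following the STF convention.

    Args:
        out_nc_file (str): Path of the netCDF file to create.
        dataset (xr.Dataset): Dataset providing station metadata (identifiers, names,
            coordinates, optional area/elevation) and the mandatory global attributes.
        data (xr.DataArray): Values to write; must have the mandatory STF dimensions.
        var_type (StfVariable): Kind of variable written (streamflow, rainfall, etc.).
        data_type (StfDataType): Provenance of the data (observed, simulated, etc.).
        stf_nc_vers (int): Major version of the STF convention to target (1 or 2).
        ens (bool): Whether the data is an ensemble; affects version 1 variable naming.
        timestep (str): Time step unit, e.g. "days" or "hours".
        data_qual (Optional[xr.DataArray]): Optional quality codes matching `data`.
        overwrite (bool): If True, an existing file at `out_nc_file` is overwritten.
    """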
    from efts_io.conventions import (  # noqa: I001
        X_VARNAME,
        Y_VARNAME,
        AREA_VARNAME,
        ELEVATION_VARNAME,
        AXIS_ATTR_KEY,
        CATCHMENT_ATTR_KEY,
        COMMENT_ATTR_KEY,
        ENS_MEMBER_DIMNAME,
        HISTORY_ATTR_KEY,
        INSTITUTION_ATTR_KEY,
        LEAD_TIME_DIMNAME,
        LONG_NAME_ATTR_KEY,
        SOURCE_ATTR_KEY,
        STANDARD_NAME_ATTR_KEY,
        STATION_DIMNAME,
        STATION_ID_VARNAME,
        STATION_NAME_VARNAME,
        STF_CONVENTION_VERSION_ATTR_KEY,
        STR_LEN_DIMNAME,
        TIME_DIMNAME,
        TIME_STANDARD_ATTR_KEY,
        TITLE_ATTR_KEY,
        UNITS_ATTR_KEY,
        has_required_xarray_dimensions,
        has_required_global_attributes,
        mandatory_xarray_dimensions,
        mandatory_global_attributes,
        has_required_variables_xr,
        mandatory_varnames_xr,
        has_variable,
    )

    if not has_required_xarray_dimensions(data):
        raise ValueError(
            f"DataArray must have the following dimensions: {mandatory_xarray_dimensions}",
        )

    if not has_required_global_attributes(dataset):
        raise ValueError(
            f"Dataset must have the following global attributes: {mandatory_global_attributes}",
        )

    if not has_required_variables_xr(dataset):
        raise ValueError(
            f"Dataset must have the following variables: {mandatory_varnames_xr}",
        )

    # Check that optional variables, if present, have the minimum attributes present.
    def _check_optional_var_attr(dataset: xr.Dataset, var_id: str) -> None:
        if has_variable(dataset, var_id):
            xrvar = dataset[var_id]
            check_optional_variable_attributes(xrvar, AttributesErrorLevel.ERROR)

    for var_id in (AREA_VARNAME, X_VARNAME, Y_VARNAME, ELEVATION_VARNAME):
        _check_optional_var_attr(dataset, var_id)

    intdata_type = "i4"

    var_type = var_type.value
    data_type = data_type.value

    n_stations = len(data[STATION_ID_DIMNAME])

    station = np.arange(1, n_stations + 1)

    # Retrieve arrays from expected variables in the input xarray dataset
    station_id = dataset[STATION_ID_VARNAME].values
    station_name = dataset[STATION_NAME_VARNAME].values
    sub_x_centroid = dataset[LON_VARNAME].values
    sub_y_centroid = dataset[LAT_VARNAME].values

    # NOTE: the original code had an "other_station_id" option, apparently storing some
    # identifiers from the Bureau of Meteorology. This is disabled for the time being,
    # pending discussion; see issue #9.
    # other_station_id = data["other_station_id"].values

    if timestep in ["weeks", "w", "wk", "week"]:
        timestep_str = "weeks"
    elif timestep in ["days", "d", "ds", "day"]:
        timestep_str = "days"
    elif timestep in ["hours", "h", "hr", "hour"]:
        timestep_str = "hours"
    elif timestep in ["minutes", "m", "min", "minute"]:
        timestep_str = "minutes"
    elif timestep in ["seconds", "s", "sec", "second"]:
        timestep_str = "seconds"
    else:
        raise ValueError(f"Unsupported or unrecognised time step unit: {timestep}")

    # Check if the output file already exists
    if os.path.exists(out_nc_file):
        if not overwrite:
            raise FileExistsError(
                f"The file '{out_nc_file}' already exists; set overwrite=True to overwrite it, or choose a new file name.",
            )
        os.remove(out_nc_file)
        # print(f"Warning: The file '{out_nc_file}' has been overwritten.")

    # Create the netCDF file
    ncfile = Dataset(out_nc_file, "w", format="NETCDF4")
    # Global attributes
    # ncfile.description = "CCLIR forecasts"
    ncfile.title = dataset.attrs.get(TITLE_ATTR_KEY, "")  # = nc_title
    ncfile.institution = dataset.attrs.get(INSTITUTION_ATTR_KEY, "")  # = inst
    ncfile.source = dataset.attrs.get(SOURCE_ATTR_KEY, "")  # = source
    ncfile.catchment = dataset.attrs.get(CATCHMENT_ATTR_KEY, "")  # = catchment
    ncfile.STF_convention_version = dataset.attrs.get(STF_CONVENTION_VERSION_ATTR_KEY, "")  # = stf_nc_vers
    ncfile.STF_nc_spec = STF_2_0_URL  # we do not transfer the spec version; this code determines it.
    ncfile.comment = dataset.attrs.get(COMMENT_ATTR_KEY, "")  # = comment
    ncfile.history = dataset.attrs.get(HISTORY_ATTR_KEY, "")
    # = "Created " + datetime.now(tz=timezone.utc).strftime("%Y-%m-%d %H:%M:%S")

    # station
    # --------------------
    ncfile.createDimension(STATION_DIMNAME, n_stations)
    station_var = ncfile.createVariable(STATION_DIMNAME, intdata_type, (STATION_DIMNAME,), fill_value=-9999)
    station_var[:] = station

    # station_id
    station_id_var = ncfile.createVariable(STATION_ID_VARNAME, intdata_type, (STATION_DIMNAME,), fill_value=-9999)
    station_id_var.setncattr(LONG_NAME_ATTR_KEY, "station or node identification code")
    station_id_var[:] = station_id

    # station_name
    ncfile.createDimension(STR_LEN_DIMNAME, 30)
    station_name_var = ncfile.createVariable(STATION_NAME_VARNAME, "c", (STATION_DIMNAME, STR_LEN_DIMNAME))
    station_name_var.setncattr(LONG_NAME_ATTR_KEY, "station or node name")
    # Names are written as fixed-width character arrays, truncated or space-padded to 30 characters.
    for s_i, stn_name in enumerate(station_name):
        char_stn_name = [" "] * 30  # 30 char length
        stn_name_30 = stn_name[:30]
        char_stn_name[: len(stn_name_30)] = stn_name_30
        station_name_var[s_i, :] = char_stn_name

    # additional station id e.g. BoM
    # other_station_id_var = ncfile.createVariable("other_station_id", "c", (STATION_DIMNAME, STR_LEN_DIMNAME))
    # other_station_id_var.setncattr(LONG_NAME_ATTR_KEY, "other station id e.g. BoM")
    # for s_i, stn_name in enumerate(other_station_id):
    #     char_stn_name = [" "] * 30  # 30 char length
    #     stn_name_30 = stn_name[:30]
    #     char_stn_name[: len(stn_name_30)] = stn_name_30
    #     other_station_id_var[s_i, :] = char_stn_name

    # coordinates, area
    # --------------------
    lat_var = ncfile.createVariable(LAT_VARNAME, "f", (STATION_DIMNAME,), fill_value=-9999)
    lat_var.setncattr(LONG_NAME_ATTR_KEY, "latitude")
    lat_var.setncattr(UNITS_ATTR_KEY, "degrees_north")
    lat_var.setncattr(AXIS_ATTR_KEY, "y")
    lat_var[:] = sub_y_centroid

    lon_var = ncfile.createVariable(LON_VARNAME, "f", (STATION_DIMNAME,), fill_value=-9999)
    lon_var.setncattr(LONG_NAME_ATTR_KEY, "longitude")
    lon_var.setncattr(UNITS_ATTR_KEY, "degrees_east")
    lon_var.setncattr(AXIS_ATTR_KEY, "x")
    lon_var[:] = sub_x_centroid

    def _add_optional_variable(dataset: xr.Dataset, ncfile: Dataset, var_id: str) -> None:
        if has_variable(dataset, var_id):
            ncvar_type = "f"
            xrvar = dataset[var_id]
            opt_nc_var = ncfile.createVariable(var_id, ncvar_type, (STATION_DIMNAME,), fill_value=-9999)
            opt_nc_var[:] = xrvar.values
            for x in (STANDARD_NAME_ATTR_KEY, LONG_NAME_ATTR_KEY, UNITS_ATTR_KEY):
                opt_nc_var.setncattr(x, xrvar.attrs[x])

    for var_id in (AREA_VARNAME, X_VARNAME, Y_VARNAME, ELEVATION_VARNAME):
        _add_optional_variable(dataset, ncfile, var_id)

    # lead time
    # ------------
    ncfile.createDimension(LEAD_TIME_DIMNAME, len(data[LEAD_TIME_DIMNAME]))
    lt_var = ncfile.createVariable(LEAD_TIME_DIMNAME, intdata_type, (LEAD_TIME_DIMNAME,), fill_value=-9999)
    lt_var.setncattr(STANDARD_NAME_ATTR_KEY, "lead time")
    lt_var.setncattr(LONG_NAME_ATTR_KEY, "forecast lead time")
    lt_var.setncattr(UNITS_ATTR_KEY, "days since time")
    lt_var.setncattr(AXIS_ATTR_KEY, "v")
    lt_var[:] = data[LEAD_TIME_DIMNAME].values

    # ensemble members
    # ------------------
    ncfile.createDimension(ENS_MEMBER_DIMNAME, len(data[REALISATION_DIMNAME]))
    ens_mem_var = ncfile.createVariable(ENS_MEMBER_DIMNAME, intdata_type, (ENS_MEMBER_DIMNAME,), fill_value=-9999)
    ens_mem_var.setncattr(STANDARD_NAME_ATTR_KEY, ENS_MEMBER_DIMNAME)
    ens_mem_var.setncattr(LONG_NAME_ATTR_KEY, "ensemble member")
    ens_mem_var.setncattr(UNITS_ATTR_KEY, "member id")
    ens_mem_var.setncattr(AXIS_ATTR_KEY, "u")
    ens_mem_var[:] = np.arange(1, len(data[REALISATION_DIMNAME]) + 1)

    # time
    # ------
    ncfile.createDimension(TIME_DIMNAME, len(data[TIME_DIMNAME]))
    time_var = ncfile.createVariable(TIME_DIMNAME, intdata_type, (TIME_DIMNAME,), fill_value=-9999)
    time_var.setncattr(STANDARD_NAME_ATTR_KEY, TIME_DIMNAME)
    time_var.setncattr(LONG_NAME_ATTR_KEY, TIME_DIMNAME)
    time_var.setncattr(TIME_STANDARD_ATTR_KEY, "UTC+00:00")
    time_var.setncattr(AXIS_ATTR_KEY, "t")

    # time_units_str = "days since {} 00:00:00".format(data.attrs["fcast_date"])
    axis_values, time_units_str, _ = _create_cf_time_axis(data, timestep_str)
    time_var.setncattr(UNITS_ATTR_KEY, time_units_str)
    time_var[:] = axis_values

    # Borrowing from create_empty_stfnc.m
    # Name arrays, indexed by the zero-based variable type
    v_type = ["q", "pet", "rain", "swe", "tmin", "tmax", "tave"]
    v_type_long = [
        "streamflow",
        "potential evapotranspiration",
        "rainfall",
        "snow water equivalent",
        "minimum temperature",
        "maximum temperature",
        "average temperature",
    ]
    v_units = ["m3/s", "mm", "mm", "mm", "K", "K", "K"]
    v_ttype = [3, 2, 2, 2, 5, 5, 5]
    # NOTE: the original list had only six entries, one fewer than v_type;
    # the final entry for "tave" below is an assumption added to keep indices aligned.
    v_ttype_name = [
        "averaged over the preceding interval",
        "accumulated over the preceding interval",
        "accumulated over the preceding interval",
        "point value recorded in the preceding interval",
        "point value recorded in the preceding interval",
        "averaged over the preceding interval",
        "averaged over the preceding interval",
    ]

    d_type = [None] * 4
    d_type_long = [None] * 4
    d_type[0] = "der"
    d_type_long[0] = "derived (from observations)"

    if int(stf_nc_vers) == 1:
        d_type[1] = "fcast"
        d_type_long[1] = "forecast"
    elif int(stf_nc_vers) == 2:  # noqa: PLR2004
        d_type[1] = "fct"
        d_type_long[1] = "forecast"
    else:
        raise ValueError("Version not recognised: currently only versions 1.X and 2.X are supported")

    d_type[2] = "obs"
    d_type_long[2] = "observed"
    d_type[3] = "sim"
    d_type_long[3] = "simulated"

    # Convert var_type and data_type to zero-based Python indices
    var_type = var_type - 1
    data_type = data_type - 1
    # print(f"data_type: {data_type}")
    # Create prescribed variable names
    if int(stf_nc_vers) == 1:
        var_name_s = f"{v_type[var_type]}_{d_type[data_type]}"
        var_name_l = f"{d_type_long[data_type]} {v_type_long[var_type]}"
        if ens:
            var_name_s = f"{var_name_s}_ens"
            var_name_l = f"{var_name_l} ensemble"
    else:
        var_name_attr = d_type[data_type]
        dat_type_description = d_type_long[data_type]
        if data_type in [0, 2]:
            # derived from observations, or observed
            var_name_s = f"{v_type[var_type]}_obs"
            var_name_l = f"observed {v_type_long[var_type]}"
        else:
            # forecast or simulated
            var_name_s = f"{v_type[var_type]}_sim"
            var_name_l = f"simulated {v_type_long[var_type]}"

    qsim_var = ncfile.createVariable(
        var_name_s,
        "f",
        (TIME_DIMNAME, ENS_MEMBER_DIMNAME, STATION_DIMNAME, LEAD_TIME_DIMNAME),
        fill_value=-9999,
    )
    qsim_var.setncattr(STANDARD_NAME_ATTR_KEY, var_name_s)
    qsim_var.setncattr(LONG_NAME_ATTR_KEY, var_name_l)
    qsim_var.setncattr(UNITS_ATTR_KEY, v_units[var_type])

    qsim_var.setncattr(TYPE_ATTR_KEY, v_ttype[var_type])
    qsim_var.setncattr(TYPE_DESCRIPTION_ATTR_KEY, v_ttype_name[var_type])
    qsim_var.setncattr(LOCATION_TYPE_ATTR_KEY, "Point")
    if int(stf_nc_vers) == 2:  # noqa: PLR2004
        qsim_var.setncattr(DAT_TYPE_ATTR_KEY, var_name_attr)
        qsim_var.setncattr(DAT_TYPE_DESCRIPTION_ATTR_KEY, dat_type_description)

    # WARNING: this relies on `data` having its dimensions ordered exactly as
    # (time, ens_member, station, lead_time); a mismatch may write transposed data.
    qsim_var[:, :, :, :] = data.values[:]

    # Specify the quality variable
    if data_qual is not None:
        qu_var_name_s = f"{var_name_s}_qual"
        if int(stf_nc_vers) == 1:
            if data_type == 2:  # noqa: PLR2004
                qsim_qual_var = ncfile.createVariable(
                    qu_var_name_s,
                    "f",
                    (TIME_DIMNAME, STATION_DIMNAME, LEAD_TIME_DIMNAME),
                    fill_value=-1,
                )
                qsim_qual_var[:, :, :] = data_qual.values[:]
            else:
                qsim_qual_var = ncfile.createVariable(
                    qu_var_name_s,
                    "f",
                    (TIME_DIMNAME, STATION_DIMNAME),
                    fill_value=-1,
                )
                qsim_qual_var[:, :] = data_qual.values[:]
        else:
            qsim_qual_var = ncfile.createVariable(
                qu_var_name_s,
                "f",
                (TIME_DIMNAME, ENS_MEMBER_DIMNAME, STATION_DIMNAME, LEAD_TIME_DIMNAME),
                fill_value=-1,
            )
            qsim_qual_var[:, :, :, :] = data_qual.values[:]

        qu_var_name_l = f"{var_name_l} data quality"

        qsim_qual_var.setncattr(STANDARD_NAME_ATTR_KEY, qu_var_name_s)
        qsim_qual_var.setncattr(LONG_NAME_ATTR_KEY, qu_var_name_l)
        quality_code = data_qual.attrs.get("quality_code", "Quality codes")
        qsim_qual_var.setncattr(UNITS_ATTR_KEY, quality_code)

    # Close the file, flushing the written data to disk.
    ncfile.close()
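

# A minimal usage sketch (hypothetical names; assumes `ds` is an xr.Dataset carrying the
# mandatory STF dimensions, variables and global attributes checked above, and that
# "q_obs" is a DataArray in `ds` with the dimensions listed in `mandatory_xarray_dimensions`):
#
#     from efts_io._ncdf_stf2 import StfDataType, StfVariable, write_nc_stf2
#
#     write_nc_stf2(
#         "streamflow_obs.nc",
#         dataset=ds,
#         data=ds["q_obs"],
#         var_type=StfVariable.STREAMFLOW,
#         data_type=StfDataType.OBSERVED,
#         stf_nc_vers=2,
#         timestep="days",
#     )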