Skip to content

Commit

Permalink
Updated the code to speed up the process of making summary plots.
Browse files Browse the repository at this point in the history
  • Loading branch information
azadeh-gh committed May 1, 2024
1 parent 7af7fe0 commit dbdcf79
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 26 deletions.
57 changes: 36 additions & 21 deletions ush/SpatialTemporalStatsTool/SpatialTemporalStats.py
Original file line number Diff line number Diff line change
Expand Up @@ -355,43 +355,50 @@ def make_summary_plots(
ds = xarray.open_dataset(studied_cycle_files[index[0]])
unique_channels = np.unique(ds["Channel_Index"].data).tolist()
print('Total Number of Channels ', len(unique_channels))
print('Channels ', unique_channels)

Allchannels_data={}
for this_channel in unique_channels:
this_channel_values = np.empty(shape=(0,))
for this_cycle_obs_file in studied_cycle_files:
ds = xarray.open_dataset(this_cycle_obs_file)
Combined_bool = ds["Channel_Index"].data == this_channel

if QC_filter:
QC_bool = ds["QC_Flag"].data == 0
Combined_bool = Combined_bool * QC_bool

this_cycle_var_values = ds[var_name].data[Combined_bool]
this_channel_values = np.append(
this_channel_values, this_cycle_var_values
Allchannels_data[this_channel] = np.empty(shape=(0,))
for this_cycle_obs_file in studied_cycle_files:
ds = xarray.open_dataset(this_cycle_obs_file)
if QC_filter:
QC_bool = ds["QC_Flag"].data == 0

for this_channel in unique_channels:
channel_bool = ds["Channel_Index"].data == this_channel

this_cycle_channel_var_values = ds[var_name].data[channel_bool*QC_bool]
Allchannels_data[this_channel] = np.append(
Allchannels_data[this_channel], this_cycle_channel_var_values
)

for this_channel in unique_channels:
this_channel_values=Allchannels_data[this_channel]
squared_values = [x**2 for x in this_channel_values]
mean_of_squares = sum(squared_values) / len(squared_values)
rms_value=mean_of_squares ** 0.5
Summary_results.append(
[
this_channel,
np.size(this_channel_values),
np.std(this_channel_values),
np.mean(this_channel_values),
rms_value
]
)


Summary_resultsDF = pd.DataFrame(
Summary_results, columns=["channel", "count", "std", "mean"]
Summary_results, columns=["channel", "count", "std", "mean", "rms"]
)
# Plotting
plt.figure(figsize=(10, 6))
plt.scatter(Summary_resultsDF["channel"], Summary_resultsDF["count"], s=50)
plt.xlabel("Channel")
plt.ylabel("Count")
plt.title("%s %s" % ((self.sensor, var_name)))
plt.xticks(Summary_resultsDF["channel"])
plt.xticks(rotation=45)
#plt.xticks(Summary_resultsDF["channel"])
#plt.xticks(rotation=45)
plt.grid(True)
plt.tight_layout()
plt.savefig(
Expand All @@ -402,26 +409,34 @@ def make_summary_plots(
plt.close()

# Plotting scatter plot for mean and std
plt.figure(figsize=(10, 6))
plt.figure(figsize=(15, 6))
plt.scatter(
Summary_resultsDF["channel"],
Summary_resultsDF["mean"],
s=50,
c="red",
c="green",
label="Mean",
)
plt.scatter(
Summary_resultsDF["channel"],
Summary_resultsDF["std"],
s=50,
c="green",
c="red",
label="Std",
)
plt.scatter(
Summary_resultsDF["channel"],
Summary_resultsDF["rms"],
s=50,
label="Rms",
facecolors='none',
edgecolors='blue'
)
plt.xlabel("Channel")
plt.ylabel("Statistics")
plt.title("%s %s" % ((self.sensor, var_name)))
plt.xticks(Summary_resultsDF["channel"])
plt.xticks(rotation=45)
#plt.xticks(Summary_resultsDF["channel"])
#plt.xticks(rotation=45)
plt.grid(True)
plt.tight_layout()
plt.legend()
Expand Down
16 changes: 11 additions & 5 deletions ush/SpatialTemporalStatsTool/user_Analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,24 @@

# Set input and output paths
input_path = "/PATH/TO/Input/Files"
output_path = r'./Results'

#output_path = r'./Results'


# Set sensor name
sensor = "atms_n20"
sensor = "iasi_metop-c"

# Set variable name and channel number
var_name = "Obs_Minus_Forecast_adjusted"
channel_no = 1

# Set start and end dates
start_date, end_date = '2023-03-01', '2023-03-10'
start_date, end_date = '2024-01-01', '2024-01-31'

# Set region
# 1: global, 2: polar region, 3: mid-latitudes region,
# 4: tropics region, 5:southern mid-latitudes region, 6: southern polar region
region = 3
region = 1

# Initialize SpatialTemporalStats object
my_tool = SpatialTemporalStats()
Expand All @@ -28,6 +29,7 @@

# Generate grid
my_tool.generate_grid(resolution) # Call generate_grid method)
print("grid created!")

# Set QC filter
QC_filter = True # should be always False or true
Expand All @@ -52,16 +54,20 @@
QC_filter,
)

print("read obs values!")
# Can save the results in a gpkg file
# o_minus_f_gdf.to_file("filename.gpkg", driver='GPKG')

# Plot observations
print("creating plots...")
my_tool.plot_obs(o_minus_f_gdf, var_name, region, resolution, output_path)
print("Time/Area stats plots created!")

# Make summary plots
print("Creating summary plots...")
summary_results = my_tool.make_summary_plots(
input_path, sensor, var_name, start_date, end_date, QC_filter, output_path
)
print("Summary plots created!")

# Print summary results
print(summary_results)

0 comments on commit dbdcf79

Please sign in to comment.