Skip to content
Snippets Groups Projects
Commit b3c218c4 authored by akshayv4's avatar akshayv4
Browse files

ML models and corresponding plots for predicting BRAN Energy from HAD1-3, RPD,...

ML models and corresponding plots for predicting BRAN Energy from HAD1-3, RPD, and EM module energies.
parent 54a3c7cf
No related branches found
No related tags found
No related merge requests found
Showing
with 164 additions and 0 deletions
from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, HistGradientBoostingRegressor
from sklearn.model_selection import train_test_split, cross_val_score
from sklearn.tree import export_graphviz, plot_tree
from sklearn.inspection import permutation_importance
import graphviz
import numpy as np
import matplotlib.pyplot as plot
def main():
    """Train and evaluate the BRAN-energy regressors for both ZDC sides.

    Loads the side A and side C arrays, derives a variant of each with
    column 4 removed (presumably the RPD column, going by the feature labels
    used elsewhere in this file -- confirm against the data producer), and
    runs ``xg_train_and_test`` on all four datasets. For every dataset the
    leading ``n_features`` columns are the inputs and the column right after
    them is the regression target.
    """
    zdc_sideA_withRPD = np.load("zdc_ZdcModuleTruthEMNonEM_fullNumpy_sideA.npy")
    zdc_sideC_withRPD = np.load("zdc_ZdcModuleTruthEMNonEM_fullNumpy_sideC.npy")

    # np.delete returns a new array and leaves its input untouched, so the
    # no-RPD variants are derived from the arrays already in memory instead
    # of reading each file from disk a second time.
    zdc_sideA_noRPD = np.delete(zdc_sideA_withRPD, 4, 1)
    zdc_sideC_noRPD = np.delete(zdc_sideC_withRPD, 4, 1)

    # A depth below 1 makes xg_train_and_test use the regressor's default
    # depth setting.
    hist_tree_depth = -1
    hist_training_ratio = 0.8
    # Settings for the optional xg_visualize diagnostics below.
    vis_tree_depth = 4
    vis_training_ratio = 0.05

    # (label, data, number of leading feature columns)
    datasets = [
        ("Side A (WITH RPD)", zdc_sideA_withRPD, 5),
        ("Side C (WITH RPD)", zdc_sideC_withRPD, 5),
        ("Side A (NO RPD)", zdc_sideA_noRPD, 4),
        ("Side C (NO RPD)", zdc_sideC_noRPD, 4),
    ]

    for name, data, n_features in datasets:
        X = data[:, :n_features]
        y = data[:, n_features]
        xg_train_and_test(X, y, hist_tree_depth, hist_training_ratio, name)
        # Optional diagnostics, disabled by default:
        # xg_visualize(X, y, vis_tree_depth, vis_training_ratio, name)
        # rf_cross_val(X[:10000], y[:10000], 4)
        # xg_cross_val(X, y, 4)
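#############################################
# Trains and evaluates a HistGradientBoostingRegressor
# tree_depth < 1 uses the regressor's default depth setting;
# any other value is passed through as max_depth
# Train/test split is set by training_ratio (main() uses 0.8, i.e. 80/20)
# Actual model used in predicting BRAN
#############################################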
def xg_train_and_test(X, y, tree_depth, training_ratio, name):
    """Fit a HistGradientBoostingRegressor and report its quality.

    Prints the held-out R^2, the 5-fold cross-validation scores and their
    mean, then saves a bar chart of normalised permutation importances to
    ``"{name}ImportanceBars.png"``.

    Args:
        X: 2-D feature array (4 or 5 columns are expected by the bar labels).
        y: 1-D target array aligned with the rows of ``X``.
        tree_depth: ``max_depth`` for the regressor; values below 1 leave the
            regressor at its default depth setting.
        training_ratio: Fraction of the samples used for training.
        name: Label printed to stdout and used in the plot title/file name.
    """
    print(name)
    X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=training_ratio)

    if tree_depth < 1:
        model = HistGradientBoostingRegressor()
    else:
        model = HistGradientBoostingRegressor(max_depth=tree_depth)
    model.fit(X_train, y_train)

    # Evaluate the model against the held-out set: R^2 = 1 - SSE/SST.
    # Vectorised with numpy instead of per-element Python loops.
    y_true = np.asarray(y_test, dtype=float)
    y_pred = np.asarray(model.predict(X_test), dtype=float)
    sse = np.sum((y_true - y_pred) ** 2)
    sst = np.sum((y_true - y_true.mean()) ** 2)
    accuracy = 1 - sse / sst
    print("Test accuracy:", accuracy)

    # Cross-validation (cross_val_score refits clones of the model per fold).
    scores = cross_val_score(model, X, y, cv=5)
    print("Cross-validation scores:", scores)
    print("Mean cross-validation score:", scores.mean())

    # Permutation importances, normalised so the bars sum to 1.
    result = permutation_importance(model, X_test, y_test, scoring='neg_mean_squared_error', n_repeats=30)
    feature_importances = np.asarray(result.importances_mean, dtype=float)
    total_importance = feature_importances.sum()
    if total_importance != 0:
        # Skip normalisation when the total is zero to avoid NaN/inf bars.
        feature_importances = feature_importances / total_importance

    x_labs_WITHRPD = ["EM", "HAD1", "HAD2", "HAD3", "RPD"]
    x_labs_NORPD = ["EM", "HAD1", "HAD2", "HAD3"]
    x_labs = x_labs_NORPD if len(feature_importances) == 4 else x_labs_WITHRPD

    plot.bar(x_labs, feature_importances, color='blue')
    plot.title(f"{name} Importance Metrics of Modules in Predicting BRAN Energy", fontsize=8)
    plot.xlabel("Modules")
    plot.ylabel("Importance")
    plot.savefig(f"{name}ImportanceBars.png")
    # Clear the shared pyplot figure so the next call starts from a blank one.
    plot.clf()
    print()
########################################################################
# Creates a GradientBoostingRegressor
# Any allowed decision tree heights BUT 4-6 range is ideal for speed
# 0.05-0.2 Training Ratio is ideal for speed
# Primarily for visualizing individual decision trees
# since HistGradientBoostingRegressor's trees can't be viewed
########################################################################
def xg_visualize(X, y, tree_depth, training_ratio, name):
    """Fit a GradientBoostingRegressor and render its first three trees.

    Prints the held-out R^2 and then calls ``visualize_tree`` for estimator
    indices 0, 1 and 2.

    Args:
        X: 2-D feature array (4 or 5 columns are expected by the tree labels).
        y: 1-D target array aligned with the rows of ``X``.
        tree_depth: ``max_depth`` for the regressor; values below 1 leave the
            regressor at its default depth.
        training_ratio: Fraction of the samples used for training.
        name: Label printed to stdout and used as the output file prefix.
    """
    print(name)
    X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=training_ratio)

    if tree_depth < 1:
        model = GradientBoostingRegressor()
    else:
        model = GradientBoostingRegressor(max_depth=tree_depth)
    model.fit(X_train, y_train)

    # Evaluate the model against the held-out set: R^2 = 1 - SSE/SST.
    # Vectorised with numpy instead of per-element Python loops.
    y_true = np.asarray(y_test, dtype=float)
    y_pred = np.asarray(model.predict(X_test), dtype=float)
    sse = np.sum((y_true - y_pred) ** 2)
    sst = np.sum((y_true - y_true.mean()) ** 2)
    accuracy = 1 - sse / sst
    print("Test accuracy:", accuracy)
    print()

    estimators = model.estimators_

    # Pick the label set by feature count (same rule as xg_train_and_test) so
    # that 4-column (no RPD) inputs do not hand export_graphviz a 5-name list.
    feat_names_with_rpd = ["EM", "HAD1", "HAD2", "HAD3", "RPD"]
    feat_names_no_rpd = ["EM", "HAD1", "HAD2", "HAD3"]
    feat_names = feat_names_no_rpd if np.shape(X)[1] == 4 else feat_names_with_rpd

    for i in range(3):
        visualize_tree(estimators, feat_names, i, name)
def visualize_tree(estimators, feat_names, idx, name):
    """Export one boosted tree to Graphviz DOT and render it as a PNG.

    Args:
        estimators: Indexable collection of estimator rows; the tree drawn is
            ``estimators[idx][0]``.
        feat_names: Feature names shown in the tree nodes.
        idx: Index of the boosting stage to draw.
        name: Prefix for the generated ``{name}_tree{idx}`` files.
    """
    dot_path = f"{name}_tree{idx}.dot"
    render_target = f"{name}_tree{idx}"

    tree = estimators[idx][0]
    export_graphviz(
        tree,
        out_file=dot_path,
        feature_names=feat_names,
        filled=True,
        rounded=True,
        special_characters=True,
    )

    # Read the DOT text back so graphviz can render it.
    with open(dot_path) as dot_file:
        dot_source = dot_file.read()

    rendered = graphviz.Source(dot_source)
    rendered.format = "png"
    rendered.render(render_target)
###############################################################################
def rf_cross_val(X, y, tree_depth):
    """Print 5-fold cross-validation scores for a RandomForestRegressor.

    Args:
        X: 2-D feature array.
        y: 1-D target array aligned with the rows of ``X``.
        tree_depth: ``max_depth`` passed to the forest.
    """
    regr = RandomForestRegressor(max_depth=tree_depth, random_state=2)
    # No explicit fit beforehand: cross_val_score clones the estimator and
    # fits a fresh copy on every fold, so a prior full-data fit is wasted work.
    print(cross_val_score(regr, X, y, cv=5))
def xg_cross_val(X, y, tree_depth):
    """Print 5-fold cross-validation scores for a HistGradientBoostingRegressor.

    Args:
        X: 2-D feature array.
        y: 1-D target array aligned with the rows of ``X``.
        tree_depth: ``max_depth`` passed to the regressor.
    """
    regr = HistGradientBoostingRegressor(max_depth=tree_depth)
    # No explicit fit beforehand: cross_val_score clones the estimator and
    # fits a fresh copy on every fold, so a prior full-data fit is wasted work.
    print(cross_val_score(regr, X, y, cv=5))
###############################################################################
# Run the training/evaluation pipeline only when this file is executed as a
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()
\ No newline at end of file
ml/ml_diagnostic_plots/A_no_RPD.png

42 KiB

ml/ml_diagnostic_plots/A_w_RPD.png

43 KiB

ml/ml_diagnostic_plots/C_no_RPD.png

46.9 KiB

ml/ml_diagnostic_plots/C_w_RPD.png

45.8 KiB

ml/ml_diagnostic_plots/Energy_Dists/EnergyFreqPlots_RPD_A_w_RPD (Truncated).png

48.8 KiB

ml/ml_diagnostic_plots/Energy_Dists/EnergyFreqPlots_RPD_A_w_RPD.png

38.6 KiB

ml/ml_diagnostic_plots/Energy_Dists/EnergyFreqPlots_RPD_C_w_RPD (Truncated).png

52 KiB

ml/ml_diagnostic_plots/Energy_Dists/EnergyFreqPlots_RPD_C_w_RPD.png

38.8 KiB

ml/ml_diagnostic_plots/Energy_Dists/EnergyFreqPlots__A_w_RPD_true_BRAN (Truncated).png

50.6 KiB

ml/ml_diagnostic_plots/Energy_Dists/EnergyFreqPlots__A_w_RPD_true_BRAN.png

39.8 KiB

ml/ml_diagnostic_plots/Energy_Dists/EnergyFreqPlots__C_w_RPD_true_BRAN (Truncated).png

53.1 KiB

ml/ml_diagnostic_plots/Energy_Dists/EnergyFreqPlots__C_w_RPD_true_BRAN.png

39.8 KiB

ml/ml_diagnostic_plots/hist_A_no_RPD.png

27.2 KiB

ml/ml_diagnostic_plots/hist_A_w_RPD.png

25.3 KiB

ml/ml_diagnostic_plots/hist_C_no_RPD.png

26.9 KiB

ml/ml_diagnostic_plots/hist_C_w_RPD.png

25.4 KiB

ml/ml_diagnostic_plots/log_A_no_RPD.png

47.5 KiB

ml/ml_diagnostic_plots/log_A_w_RPD.png

46.9 KiB

ml/ml_diagnostic_plots/log_C_no_RPD.png

49 KiB

0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment