Skip to content
Snippets Groups Projects

Hyperparameter tuning

Merged brianjr3 requested to merge hyperparameter-tuning into main
35 files
+ 271
193
Compare changes
  • Side-by-side
  • Inline
Files
35
+ 116
44
@@ -16,47 +16,72 @@ batch_size_seer = 1024
### Metabric ###
CPH_metabric = EasyDict({
'epochs': 200
'epochs': 100
})
### SUPPORT ###
CPH_support = EasyDict({
'epochs': 200
'epochs': 100
})
### SEER ###
CPH_seer = EasyDict({
'epochs': 200,
'epochs': 100,
})
########################## DeepHit ##########################
# alpha: governs the contribution of likelihood and rank loss
# sigma: scales loss function
### Metabric ###
DeepHit_metabric = EasyDict({
'batch_size': batch_size_metabric,
'learning_rate': 0.01,
'epochs': 100,
'hidden_size': 32,
'dropout': 0.1
# network
'hidden_layers_size': [16],
'dropout': 0.1,
# loss
'alpha': 0.2,
'sigma': 0.1,
# optimizer
'learning_rate': 0.01
})
### SUPPORT ###
DeepHit_support = EasyDict({
'batch_size': batch_size_support,
'learning_rate': 0.01,
'epochs': 100,
'hidden_size': 32,
'dropout': 0.1
# network
'hidden_layers_size': [16, 16],
'dropout': 0.1,
# loss
'alpha': 0.2,
'sigma': 0.1,
# optimizer
'learning_rate': 0.01
})
### SEER ###
DeepHit_seer = EasyDict({
'batch_size': batch_size_seer,
'learning_rate': 0.01,
'epochs': 100,
'hidden_size_indiv': 32,
'hidden_size_shared': 64,
'dropout': 0.1
# network
'hidden_layers_size': [32, 32],
'dropout': 0.1,
# loss
'alpha': 0.2,
'sigma': 0.1,
# optimizer
'learning_rate': 0.001
})
@@ -65,32 +90,49 @@ DeepHit_seer = EasyDict({
### Metabric ###
DeepSurv_metabric = EasyDict({
'batch_size': batch_size_metabric,
'learning_rate': 0.01,
'epochs': 100,
'hidden_size': 32,
'dropout': 0.1
# Network
'hidden_layers_size': [32, 32],
'dropout': 0.1,
# Adam
'learning_rate': 1e-2,
'weight_decay': 0,
})
### SUPPORT ###
DeepSurv_support = EasyDict({
'batch_size': batch_size_support,
'learning_rate': 0.01,
'epochs': 100,
'hidden_size': 32,
'dropout': 0.1
# Network
'hidden_layers_size': [32, 32],
'dropout': 0.1,
# Adam
'learning_rate': 1e-2,
'weight_decay': 0,
})
### SEER ###
DeepSurv_seer = EasyDict({
'batch_size': batch_size_support,
'learning_rate': 0.01,
'batch_size': batch_size_seer,
'epochs': 100,
'hidden_size': 32,
# Network
'hidden_layers_size': [64, 64],
'dropout': 0.1,
# Adam
'learning_rate': 1e-2,
'weight_decay': 0.1,
})
########################## Deep Survival Machines ##########################
# hidden_size (list): number of layers and number of nodes per layer
# k: number of underlying parametric distributions
### Metabric ###
DSM_metabric = EasyDict({
@@ -134,46 +176,70 @@ DSM_seer = EasyDict({
### Metabric ###
PCHazard_metabric = EasyDict({
'batch_size': batch_size_metabric,
'learning_rate': 0.01,
'epochs': 100,
'hidden_size': 32,
'dropout': 0.1
# Network
'hidden_layers_size': [64, 64],
'dropout': 0.1,
# AdamWR
'learning_rate': 1e-2,
'decoupled_weight_decay': 0.8,
'cycle_multiplier': 2
})
### SUPPORT ###
PCHazard_support = EasyDict({
'batch_size': batch_size_support,
'learning_rate': 0.01,
'epochs': 100,
'hidden_size': 32,
'dropout': 0.1
# Network
'hidden_layers_size': [32, 32],
'dropout': 0.1,
# AdamWR
'learning_rate': 1e-2,
'decoupled_weight_decay': 0.8,
'cycle_multiplier': 2
})
### SEER ###
PCHazard_seer = EasyDict({
'batch_size': batch_size_seer,
'learning_rate': 0.01,
'epochs': 100,
'hidden_size': 32,
# Network
'hidden_layers_size': [32, 32, 32, 32],
'dropout': 0.1,
# AdamWR
'learning_rate': 1e-3,
'decoupled_weight_decay': 0.8,
'cycle_multiplier': 2
})
########################## Random Survival Forests ##########################
# epochs: in this context, refers to the number of trees to generate in the forest.
# max_depth: maximum depth of the tree.
### Metabric ###
RSF_metabric = EasyDict({
'epochs': 100
'epochs': 200,
'max_depth': 4
})
### SUPPORT ###
RSF_support = EasyDict({
'epochs': 100
'epochs': 200,
'max_depth': 4
})
### SEER ###
RSF_seer = EasyDict({
'epochs': 100,
'epochs': 200,
'max_depth': 4
})
########################## SurvTRACE #########################
@@ -181,7 +247,7 @@ RSF_seer = EasyDict({
# and not on different variants.
### METABRIC ###
SurvTRACE_metabric = EasyDict(
survtrace_metabric = EasyDict(
{
'num_durations': 5, # num of discrete intervals for prediction, e.g., num_dur = 5 means the whole period is discretized to be 5 intervals
'seed': 1234,
@@ -210,15 +276,17 @@ SurvTRACE_metabric = EasyDict(
'pruned_heads': {}, # no use
# hyperparameters
'batch_size': 64,
'batch_size': batch_size_metabric,
'weight_decay': 1e-4,
'learning_rate': 1e-3,
'epochs': 100
'epochs': 100,
'gamma1': 1,
'gamma2': 1
}
)
### SUPPORT ###
SurvTRACE_support = EasyDict(
survtrace_support = EasyDict(
{
'num_durations': 5, # num of discrete intervals for prediction, e.g., num_dur = 5 means the whole period is discretized to be 5 intervals
'seed': 1234,
@@ -247,16 +315,18 @@ SurvTRACE_support = EasyDict(
'pruned_heads': {}, # no use
# hyperparameters
'batch_size': 128,
'batch_size': batch_size_support,
'weight_decay': 0,
'learning_rate': 1e-3,
'epochs': 100
'epochs': 100,
'gamma1': 1,
'gamma2': 1
}
)
### SUPPORT ###
SurvTRACE_seer = EasyDict(
### SEER ###
survtrace_seer = EasyDict(
{
'num_durations': 5, # num of discrete intervals for prediction, e.g., num_dur = 5 means the whole period is discretized to be 5 intervals
'seed': 1234,
@@ -274,7 +344,7 @@ SurvTRACE_seer = EasyDict(
'num_event': 1, # only set when using SurvTraceMulti for competing risks
'hidden_act': 'gelu',
'attention_probs_dropout_prob': 0.1,
'early_stop_patience': 5,
'early_stop_patience': 10,
'initializer_range': 0.02,
'layer_norm_eps': 1e-12,
'max_position_embeddings': 512, # no use
@@ -286,9 +356,11 @@ SurvTRACE_seer = EasyDict(
'val_batch_size': 10000,
# hyperparameters
'batch_size': 1024,
'batch_size': batch_size_seer,
'weight_decay': 0,
'learning_rate': 1e-4,
'epochs': 100
'epochs': 100,
'gamma1': 1,
'gamma2': 1
}
)
\ No newline at end of file
Loading