import pandas as pd
import numpy as np
import xgboost as xgb
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score
from sklearn.utils.class_weight import compute_sample_weight

4. Model Tuning & Threshold Optimization
Load data and split
df = pd.read_csv("engineered_heart_data.csv")
X = df.drop("target", axis=1)
y = df["target"]
X_train, X_val, y_train, y_val = train_test_split(X, y, stratify=y, test_size=0.2, random_state=42)
sample_weights = compute_sample_weight(class_weight="balanced", y=y_train)
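With class_weight="balanced", scikit-learn weights each training sample by n_samples / (n_classes * count of its class), so the minority class carries more weight during fitting. A quick sanity check of the split and the resulting weights can be worth running (a minimal sketch that only reuses the variables defined above):

# Sanity check: class balance in the training split and the distinct
# per-class weights produced by compute_sample_weight("balanced").
print(y_train.value_counts(normalize=True).round(3))
print(np.unique(sample_weights).round(3))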
Grid search over learning rates and thresholds

import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D  # only needed on older matplotlib; recent versions register projection="3d" automatically
learning_rates = [0.01, 0.05, 0.1, 0.2]
thresholds = np.arange(0.3, 0.71, 0.05)
f1_matrix = np.zeros((len(learning_rates), len(thresholds)))

for i, lr in enumerate(learning_rates):
    # use_label_encoder is deprecated in recent XGBoost releases, so it is omitted here
    model = xgb.XGBClassifier(
        eval_metric="logloss",
        objective="binary:logistic",
        learning_rate=lr,
        n_estimators=50,
        random_state=42,
    )
    model.fit(X_train, y_train, sample_weight=sample_weights)
    probs = model.predict_proba(X_val)[:, 1]  # predicted probability of the positive class
    for j, t in enumerate(thresholds):
        preds = (probs >= t).astype(int)
        f1_matrix[i, j] = f1_score(y_val, preds)
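As a cross-check on the discrete threshold sweep, the F1-optimal cutoff can also be read directly off the precision-recall curve. The sketch below reuses probs from the last loop iteration (the lr = 0.2 model), so it is illustrative rather than part of the grid search itself:

from sklearn.metrics import precision_recall_curve

# Derive the F1-optimal threshold from the precision-recall curve.
precision, recall, pr_thresholds = precision_recall_curve(y_val, probs)
f1_curve = 2 * precision * recall / np.clip(precision + recall, 1e-12, None)
# The final precision/recall pair has no associated threshold, so drop it.
best_pr_threshold = pr_thresholds[np.argmax(f1_curve[:-1])]
print(f"PR-curve F1-optimal threshold: {best_pr_threshold:.3f}")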
3D plot of F1 scores

T, LR = np.meshgrid(thresholds, learning_rates)  # both have shape (len(learning_rates), len(thresholds)), matching f1_matrix
fig = plt.figure(figsize=(10, 6))
ax = fig.add_subplot(111, projection='3d')
surf = ax.plot_surface(LR, T, f1_matrix, cmap='viridis', edgecolor='k')
ax.set_xlabel("Learning Rate")
ax.set_ylabel("Threshold")
ax.set_zlabel("F1 Score")
ax.set_title("F1 Score by Learning Rate and Threshold")
plt.tight_layout()
plt.show()
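If the surface is hard to read at a glance, the same grid can be rendered as a flat heatmap. This is an optional alternative view, not part of the original workflow:

# Optional alternative view: the F1 grid as a heatmap.
fig, ax = plt.subplots(figsize=(8, 4))
im = ax.imshow(f1_matrix, cmap="viridis", aspect="auto")
ax.set_xticks(range(len(thresholds)))
ax.set_xticklabels([f"{t:.2f}" for t in thresholds])
ax.set_yticks(range(len(learning_rates)))
ax.set_yticklabels([str(lr) for lr in learning_rates])
ax.set_xlabel("Threshold")
ax.set_ylabel("Learning Rate")
ax.set_title("F1 Score by Learning Rate and Threshold")
fig.colorbar(im, ax=ax, label="F1 Score")
plt.tight_layout()
plt.show()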
Report best configuration

best_idx = np.unravel_index(np.argmax(f1_matrix, axis=None), f1_matrix.shape)  # (row, col) of the highest F1
best_lr = learning_rates[best_idx[0]]
best_threshold = thresholds[best_idx[1]]
best_f1 = f1_matrix[best_idx]
print("🔥 Best Configuration Found:")
print(f"Learning Rate: {best_lr}")
print(f"Threshold : {best_threshold}")
print(f"F1 Score : {best_f1:.4f}")