Skip to content
Snippets Groups Projects
Commit fca282f9 authored by sheng's avatar sheng
Browse files

Update file diabetes.py

parent 34e235fd
Branches main
No related tags found
No related merge requests found
......@@ -105,7 +105,6 @@ class DiabetesRiskApp(QMainWindow):
# Data preprocessing including outlier handling and normalization
def data_preprocess(self):
if self.data is not None:
# Select target variable
target_variable = self.comboOutcome.currentText()
if target_variable:
# Separate target variable and features
......@@ -231,7 +230,6 @@ class DiabetesRiskApp(QMainWindow):
if y.dtype != 'int' and y.dtype != 'object':
y = (y > y.mean()).astype(int)
# Split into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
param_grid = {
......@@ -242,7 +240,6 @@ class DiabetesRiskApp(QMainWindow):
lr_model = GridSearchCV(LogisticRegression(random_state=42), param_grid, cv=3, scoring='roc_auc')
lr_model.fit(X_train, y_train)
# Evaluate model performance on test set
best_model = lr_model.best_estimator_
y_pred = best_model.predict(X_test)
y_pred_proba = best_model.predict_proba(X_test)[:, 1]
......@@ -263,7 +260,6 @@ class DiabetesRiskApp(QMainWindow):
f"Best Parameters: {lr_model.best_params_}\nAccuracy: {accuracy:.2f}\nAUC: {roc_auc:.2f}")
def display_model_results(self, accuracy, fpr, tpr, roc_auc):
# Show model performance metrics and ROC curve
plt.figure()
plt.plot(fpr, tpr, label='AUC = %0.2f' % roc_auc)
plt.plot([0, 1], [0, 1], 'r--')
......@@ -287,17 +283,14 @@ class DiabetesRiskApp(QMainWindow):
X = self.data.drop(columns=[outcome_col])
y = self.data[outcome_col]
# 连续目标变量二值化
if y.dtype != 'int' and y.dtype != 'object':
y = (y > y.mean()).astype(int)
X = pd.get_dummies(X, drop_first=True)
X = X.astype(float)
# Split into training and testing sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Define list of models
models = [
('Logistic Regression', LogisticRegression(max_iter=1000)),
('Random Forest', RandomForestClassifier(n_estimators=100, random_state=42)),
......@@ -312,13 +305,11 @@ class DiabetesRiskApp(QMainWindow):
model_results = []
plt.figure(figsize=(10, 8))
# Model comparison analysis
for name, model in models:
model.fit(X_train, y_train)
y_pred = model.predict(X_test)
y_proba = model.predict_proba(X_test)[:, 1] if hasattr(model, "predict_proba") else None
accuracy = accuracy_score(y_test, y_pred)
precision = precision_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
......@@ -344,17 +335,14 @@ class DiabetesRiskApp(QMainWindow):
self.display_model_comparison(model_results)
def display_model_comparison(self, model_results):
# Create dialog
dialog = QDialog(self)
dialog.setWindowTitle("Model Comparison Results")
layout = QVBoxLayout(dialog)
# Create table
table = QTableWidget()
table.setColumnCount(6)
table.setHorizontalHeaderLabels(["Model Name", "Accuracy", "Precision", "Recall", "F1 Score", "AUC"])
# Adjust column width
header = table.horizontalHeader()
header.setSectionResizeMode(0, QHeaderView.Stretch)
header.setSectionResizeMode(1, QHeaderView.ResizeToContents)
......@@ -363,7 +351,6 @@ class DiabetesRiskApp(QMainWindow):
header.setSectionResizeMode(4, QHeaderView.ResizeToContents)
header.setSectionResizeMode(5, QHeaderView.ResizeToContents)
# Fill table data
table.setRowCount(len(model_results))
for row, (name, accuracy, precision, recall, f1, roc_auc) in enumerate(model_results):
table.setItem(row, 0, QTableWidgetItem(name))
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment