Stroke Prediction Using Neural Networks

Overview

This project implements several neural network models to predict strokes from the Stroke Prediction Dataset on Kaggle. The goal is to maximize classification performance while addressing two challenges common in medical prediction: class imbalance and a high false-positive rate.

Dataset

  • Source: Stroke Prediction Dataset on Kaggle
  • Description:
    • Instances: Approximately 5,000 samples.
    • Features: Demographic and health-related attributes, including:
      • age, hypertension, heart_disease, avg_glucose_level, bmi, etc.
    • Target: stroke (binary classification: 1 for stroke, 0 for no stroke).
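
Because the overview lists class imbalance as a challenge, one common mitigation is to weight each class inversely to its frequency during training. The sketch below is not taken from the notebook: the function name `balanced_class_weights` and the 95/5 label split are hypothetical, chosen only to illustrate the arithmetic.

```python
from collections import Counter


def balanced_class_weights(labels):
    """Return {class: weight}, where weight = n_samples / (n_classes * class_count)."""
    counts = Counter(labels)
    n_samples = len(labels)
    n_classes = len(counts)
    return {cls: n_samples / (n_classes * count) for cls, count in counts.items()}


# Hypothetical labels (illustrative only): 95 non-stroke cases and 5 stroke cases.
labels = [0] * 95 + [1] * 5
weights = balanced_class_weights(labels)
print(weights)
```

A dictionary of this shape is what many training APIs accept as per-class weights, so errors on the rare stroke class count more heavily in the loss.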

Key Models

Model 1: Baseline Neural Network

  • Architecture:
    • Input layer, one dense hidden layer (sigmoid activation), and output layer (softmax activation).
  • Performance:
    • Training Time: ~0.89 seconds.
    • Accuracy: ~95.14%.
    • F1 Score: ~0.93.
    • Precision: ~90.46%.
    • Recall: ~95.11%.
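
For reference, the metrics above can be computed from true and predicted labels as in the following minimal sketch. It treats stroke (1) as the positive class; the README does not say whether the reported figures are per-class or averaged, so that choice is an assumption. The labels are hypothetical and describe a trivial classifier, not this project's model.

```python
def binary_metrics(y_true, y_pred, positive=1):
    """Compute accuracy, precision, recall and F1 for the given positive class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t == positive and p == positive)
    fp = sum(1 for t, p in pairs if t != positive and p == positive)
    fn = sum(1 for t, p in pairs if t == positive and p != positive)
    tn = len(pairs) - tp - fp - fn

    accuracy = (tp + tn) / len(pairs)
    # Guard against division by zero when a class is never predicted or never present.
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


# Hypothetical data: 95 negatives, 5 positives, and a classifier that always predicts 0.
y_true = [0] * 95 + [1] * 5
y_pred = [0] * 100
m = binary_metrics(y_true, y_pred)
print(m)
```

On this toy data, accuracy is high while recall for the stroke class is zero, which is why accuracy alone can be misleading on an imbalanced dataset.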

Model 2: Optimized Neural Network

  • Enhancements:
    • Added L2 regularization and callback functions.
    • Improved training time and performance metrics.
  • Performance:
    • Training Time: ~4.2 seconds.
    • Accuracy: ~95%.
    • Best balance between training speed and accuracy.
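
The two enhancements can be illustrated without any framework. The README does not say which callbacks were used, so early stopping is assumed here as a typical choice; the function names, the regularization strength, and the loss values are all hypothetical.

```python
def l2_penalty(weights, lam):
    """L2 regularization term added to the loss: lam * sum of squared weights."""
    return lam * sum(w * w for w in weights)


def early_stopping_epoch(val_losses, patience=3, min_delta=0.0):
    """Return the 0-based epoch at which training would stop.

    Training stops once the validation loss has failed to improve by more than
    min_delta for `patience` consecutive epochs; otherwise the last epoch is returned.
    """
    best = float("inf")
    wait = 0
    for epoch, loss in enumerate(val_losses):
        if loss < best - min_delta:
            best = loss
            wait = 0
        else:
            wait += 1
            if wait >= patience:
                return epoch
    return len(val_losses) - 1


# Hypothetical values for illustration only.
penalty = l2_penalty([0.5, -1.0, 2.0], lam=0.01)
val_losses = [0.60, 0.45, 0.40, 0.41, 0.42, 0.43, 0.39]
stop = early_stopping_epoch(val_losses, patience=3)
print(penalty, stop)
```

In this example the validation loss last improves at epoch 2, so with a patience of 3 the loop stops at epoch 5 and never reaches the final epoch.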

Model 6: LSTM-Based Neural Network

  • Performance:
    • Training, Validation, and Test Accuracy: ~85%.
    • ROC AUC: 0.78.
    • Precision for Minority Class: Low, indicating that many predicted strokes are false positives.
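
ROC AUC can be read as the probability that a randomly chosen stroke case receives a higher score than a randomly chosen non-stroke case. The sketch below computes it directly from that definition; the function name `roc_auc` and the scores are hypothetical and not drawn from the notebook.

```python
def roc_auc(y_true, scores):
    """AUC as the fraction of (positive, negative) pairs ranked correctly.

    Ties between a positive and a negative score count as half a correct pair.
    """
    pos = [s for t, s in zip(y_true, scores) if t == 1]
    neg = [s for t, s in zip(y_true, scores) if t == 0]
    if not pos or not neg:
        raise ValueError("need at least one positive and one negative sample")

    wins = 0.0
    for p in pos:
        for n in neg:
            if p > n:
                wins += 1.0
            elif p == n:
                wins += 0.5
    return wins / (len(pos) * len(neg))


# Hypothetical labels and predicted stroke probabilities.
y_true = [0, 0, 0, 0, 1, 1]
scores = [0.1, 0.3, 0.35, 0.8, 0.4, 0.9]
auc = roc_auc(y_true, scores)
print(auc)
```

The pairwise loop favors clarity over speed. For larger datasets, a rank-based implementation or a library routine would be the more practical choice.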

Results and Insights

  1. General Findings:
    • High training and validation accuracy (~95%) across most models.
    • Reducing false positives and handling the underrepresented minority (stroke) class remain significant challenges.
  2. Clinical Implications:
    • Balancing false positives and false negatives is crucial in stroke prediction.
    • False positives may lead to unnecessary tests, but missing true positives (strokes) can be life-threatening.
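
One way to make this trade-off explicit is to choose the decision threshold that minimizes a weighted error cost. The following is a sketch under stated assumptions: the 5:1 cost ratio, the candidate thresholds, and the scores are hypothetical, and the notebook may handle this differently.

```python
def errors_at_threshold(y_true, scores, threshold):
    """Count (false positives, false negatives) when predicting 1 for score >= threshold."""
    fp = sum(1 for t, s in zip(y_true, scores) if t == 0 and s >= threshold)
    fn = sum(1 for t, s in zip(y_true, scores) if t == 1 and s < threshold)
    return fp, fn


def pick_threshold(y_true, scores, thresholds, fn_cost=5.0, fp_cost=1.0):
    """Return the candidate threshold with the lowest fn_cost * FN + fp_cost * FP."""
    def cost(threshold):
        fp, fn = errors_at_threshold(y_true, scores, threshold)
        return fn_cost * fn + fp_cost * fp

    return min(thresholds, key=cost)


# Hypothetical labels and predicted stroke probabilities.
y_true = [0, 0, 0, 0, 0, 0, 1, 1]
scores = [0.05, 0.10, 0.20, 0.35, 0.45, 0.60, 0.40, 0.80]
thresholds = [0.3, 0.5, 0.7]

best = pick_threshold(y_true, scores, thresholds)                    # missed strokes cost 5x
best_equal = pick_threshold(y_true, scores, thresholds, fn_cost=1.0)  # equal costs
print(best, best_equal)
```

When a missed stroke is penalized more heavily, the lower threshold is selected and more false positives are accepted. With equal costs, the higher threshold is selected instead.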

Visualizations

Key plots include:

  1. Loss vs. Epoch: Training and validation loss curves.
  2. Confusion Matrix: Illustrating true vs. predicted labels.
  3. ROC Curve: Model's ability to distinguish between classes.

Files

  • analysis/csc578_group5.ipynb: Jupyter notebook with data preprocessing, model training, and evaluation.
  • data/: Placeholder for dataset information or preprocessing scripts.

How to Use

  1. Clone the repository:
    git clone https://github.com/tejas-1911/Stroke-Prediction-Using-Neural-Networks.git