sklearn机器学习 Python代码通用模板

news/2025/2/24 10:41:51

以下是一个使用 `scikit-learn`(sklearn)进行机器学习的通用 Python 代码模板。这个模板涵盖了数据加载、预处理、模型训练、评估和预测的基本流程,适用于常见的机器学习任务。

 

```python

# 导入必要的库

import numpy as np

import pandas as pd

from sklearn.model_selection import train_test_split

from sklearn.preprocessing import StandardScaler

from sklearn.metrics import accuracy_score, classification_report, confusion_matrix

from sklearn.ensemble import RandomForestClassifier # 以随机森林为例,可根据任务替换模型

 

# 1. 加载数据

# 假设数据是一个 CSV 文件

data = pd.read_csv('your_dataset.csv')

 

# 2. 数据预处理

# 分离特征和目标变量

X = data.drop('target_column', axis=1) # 替换 'target_column' 为目标列名

y = data['target_column']

 

# 将数据集分为训练集和测试集

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

 

# 特征标准化(根据需求选择)

scaler = StandardScaler()

X_train = scaler.fit_transform(X_train)

X_test = scaler.transform(X_test)

 

# 3. 选择并训练模型

model = RandomForestClassifier(random_state=42) # 以随机森林为例,可替换为其他模型

model.fit(X_train, y_train)

 

# 4. 模型评估

# 在测试集上进行预测

y_pred = model.predict(X_test)

 

# 计算准确率

accuracy = accuracy_score(y_test, y_pred)

print(f'模型准确率: {accuracy:.2f}')

 

# 打印分类报告

print("分类报告:")

print(classification_report(y_test, y_pred))

 

# 打印混淆矩阵

print("混淆矩阵:")

print(confusion_matrix(y_test, y_pred))

 

# 5. 模型保存(可选)

import joblib

joblib.dump(model, 'model.pkl') # 保存模型到文件

 

# 6. 加载模型并进行预测(可选)

loaded_model = joblib.load('model.pkl')

new_predictions = loaded_model.predict(X_test) # 对新数据进行预测

```

 

### 关键步骤说明:

1. **数据加载**:从文件(如 CSV)中加载数据。

2. **数据预处理**:

   - 分离特征(`X`)和目标变量(`y`)。

   - 将数据集分为训练集和测试集。

   - 对特征进行标准化或归一化(可选)。

3. **模型训练**:选择模型(如随机森林、逻辑回归等)并训练。

4. **模型评估**:使用测试集评估模型性能,输出准确率、分类报告和混淆矩阵。

5. **模型保存与加载**:将训练好的模型保存到文件,便于后续使用。

 

### 注意事项:

- 根据任务类型(分类、回归、聚类等)选择合适的模型和评估指标。

- 如果数据量较大,可以使用交叉验证(`cross_val_score`)或网格搜索(`GridSearchCV`)优化模型。

- 对于非数值型数据,需要进行编码(如 `OneHotEncoder` 或 `LabelEncoder`)。

在机器学习中,模型选择和调参是提升性能的关键步骤。Python 的 `scikit-learn` 提供了丰富的工具来实现这些任务。以下是一个完整的模型选择和调参的流程,包括交叉验证、网格搜索和随机搜索。

### 1. 导入必要的库
```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split, GridSearchCV, RandomizedSearchCV, cross_val_score
from sklearn.preprocessing import StandardScaler
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import accuracy_score, classification_report
from scipy.stats import randint
```

### 2. 加载和预处理数据
```python
# 加载数据
data = pd.read_csv('your_dataset.csv')

# 分离特征和目标变量
X = data.drop('target_column', axis=1)  # 替换 'target_column' 为目标列名
y = data['target_column']

# 数据集划分
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# 特征标准化
scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.transform(X_test)
```

### 3. 模型选择
通过交叉验证评估多个模型的性能,选择最佳模型。
```python
# 示例:比较随机森林和支持向量机
from sklearn.svm import SVC

models = {
    'RandomForest': RandomForestClassifier(random_state=42),
    'SVM': SVC(random_state=42)
}

# 交叉验证评估
for name, model in models.items():
    scores = cross_val_score(model, X_train, y_train, cv=5, scoring='accuracy')
    print(f'{name} 的平均准确率: {np.mean(scores):.2f}')
```

### 4. 超参数调优
#### 4.1 网格搜索(Grid Search)
网格搜索会遍历所有给定的参数组合,找到最优参数。
```python
# 定义参数网格
param_grid = {
    'n_estimators': [50, 100, 200],
    'max_depth': [None, 10, 20, 30],
    'min_samples_split': [2, 5, 10]
}

# 初始化模型
model = RandomForestClassifier(random_state=42)

# 网格搜索
grid_search = GridSearchCV(estimator=model, param_grid=param_grid, cv=5, scoring='accuracy')
grid_search.fit(X_train, y_train)

# 输出最佳参数和得分
print(f'最佳参数: {grid_search.best_params_}')
print(f'最佳交叉验证得分: {grid_search.best_score_:.2f}')

# 使用最佳模型进行预测
best_model = grid_search.best_estimator_
y_pred = best_model.predict(X_test)
print(f'测试集准确率: {accuracy_score(y_test, y_pred):.2f}')
```

#### 4.2 随机搜索(Random Search)
随机搜索从参数分布中随机采样,适合参数空间较大的情况。
```python
# 定义参数分布
param_dist = {
    'n_estimators': randint(50, 200),
    'max_depth': [None, 10, 20, 30],
    'min_samples_split': randint(2, 11)
}

# 随机搜索
random_search = RandomizedSearchCV(estimator=model, param_distributions=param_dist, n_iter=10, cv=5, scoring='accuracy', random_state=42)
random_search.fit(X_train, y_train)

# 输出最佳参数和得分
print(f'最佳参数: {random_search.best_params_}')
print(f'最佳交叉验证得分: {random_search.best_score_:.2f}')

# 使用最佳模型进行预测
best_model = random_search.best_estimator_
y_pred = best_model.predict(X_test)
print(f'测试集准确率: {accuracy_score(y_test, y_pred):.2f}')
```

### 5. 模型评估
使用测试集评估最终模型的性能。
```python
# 打印分类报告
print("分类报告:")
print(classification_report(y_test, y_pred))

# 打印混淆矩阵
from sklearn.metrics import confusion_matrix
print("混淆矩阵:")
print(confusion_matrix(y_test, y_pred))
```

### 6. 保存模型
将训练好的模型保存到文件,便于后续使用。
```python
import joblib
joblib.dump(best_model, 'best_model.pkl')
```

### 总结
- **模型选择**:通过交叉验证比较多个模型的性能。
- **调参方法**:
  - 网格搜索(`GridSearchCV`):适合小规模参数空间。
  - 随机搜索(`RandomizedSearchCV`):适合大规模参数空间。
- **模型评估**:使用测试集评估模型性能,输出分类报告和混淆矩阵。
- **模型保存**:将最佳模型保存到文件。

通过以上步骤,可以系统地选择和优化机器学习模型。


http://www.niftyadmin.cn/n/5864230.html

相关文章

Ubuntu编译jetlinks-ui-vue

安装node 需要18.14.0以上 LINUX安装node/nodejs-CSDN博客 启动jetlinks-community Ubuntu安装geteck/jetlinks实战:源码启动-CSDN博客 下载 git clone https://gitee.com/jetlinks/jetlinks-ui-vue 准备 yarn 编译 # npm run build yarn build启动 yarn dev 测试

WebXR教学 01 基础介绍

什么是WebXR? 定义 XR VR AR Web上使用XR技术的API WebXR 是一组用于在 Web 浏览器中实现虚拟现实(VR)和增强现实(AR)应用的技术标准。它由 W3C 的 Immersive Web 工作组开发,旨在提供跨设备的沉浸式体验…

【算法系列】荷兰国旗问题:三指针法原地排序

一、题目(leetcode75 颜色分类 --三分数组) 二、思路 算法核心:三指针分治策略 该问题被称为“荷兰国旗问题”(Dutch National Flag Problem),由计算机科学家Edsger Dijkstra提出。其核心思想是通过三个指针将数组划分为三个区…

Android 技术栈

这里有必要学一下。 Android 串口通信-CSDN博客

LangChain大模型应用开发:构建Agent智能体

介绍 大家好,博主又来给大家分享知识了。今天要给大家分享的内容是使用LangChain进行大模型应用开发中的构建Agent智能体。 在LangChain中,Agent智能体是一种能够根据输入的任务或问题,动态地决定使用哪些工具(如搜索引擎、数据库查询等)来…

深度学习(3)-TensorFlow入门(梯度带)

TensorFlow看起来很像NumPy。但是NumPy无法做到的是,检索任意可微表达式相对于其输入的梯度。你只需要创建一个GradientTape作用域,对一个或多个输入张量做一些计算,然后就可以检索计算结果相对于输入的梯度,如代码清单3-10所示。…

【新手初学】SQL注入之二次注入、中转注入、伪静态注入

二次注入 一、概念 二次注入可以理解为,攻击者构造的恶意数据存储在数据库后,恶意数据被读取并进入到SQL查询语句所导致的注入。 二、原理 防御者可能在用户输入恶意数据时对其中的特殊字符进行了转义处理,但在恶意数据插入到数据库时被处…

代码审计入门学习之sql注入

路由规则 入口文件&#xff1a;index.php <?php // ---------------------------------------------------------------------- // | wuzhicms [ 五指互联网站内容管理系统 ] // | Copyright (c) 2014-2015 http://www.wuzhicms.com All rights reserved. // | Licensed …