当前位置: 首页 > news >正文

好坏质检二分类MLP 实战

任务

1、基于 data-mlp-01.csv 数据,建立 mlp 模型,计算其在测试数据上的准确率,可视化模型预测结果;
2、进行数据分离:test_size=0.33,random_state=10
3、模型结构:一层隐藏层,有 20 个神经元

参考资料

多层感知器MLP(原理)

多层感知器MLP实现非线性分类(原理)

视频:

36.40 实战准备_哔哩哔哩_bilibili

37.41 实战(一)_哔哩哔哩_bilibili

数据准备

数据集名称:data-mlp-01.csv

点我转到百度网盘获取数据集 提取码: 8497 

1、原数据可视化

加载数据

#load data
import pandas as pd
import numpy as np
data = pd.read_csv('data-mlp-01.csv')
data.head()

定义 X,y

#define the X and y
X = data.drop(['y'], axis = 1)
y = data.loc[:,'y']
X.head()

可视化

#visualize the data
%matplotlib inline 
from matplotlib import pyplot as plt
fig1 = plt.figure(figsize = (5,5))
passed = plt.scatter(X.loc[:,'x1'][y==1], X.loc[:,'x2'][y==1])
failed = plt.scatter(X.loc[:,'x1'][y==0], X.loc[:,'x2'][y==0])
plt.legend((passed, failed),('passed','failed'))
plt.xlabel('x1')
plt.ylabel('x2')
plt.title('raw data')
plt.show()
#蓝色是 y==1 的结果, 橙色的是 y == 0  的结果

2、数据分离

#split the data
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X,y, test_size = 0.33, random_state = 10)
print(X_train.shape, X_test.shape, X.shape)#(275, 2) (136, 2) (411, 2)

3、建立 MLP 模型、训练、预测、可视化

创建模型

# set up the model
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Activationmlp = Sequential()#20:隐藏层神经元个数,input_dim =2 为输入有  x1, x2 两个维度,activation 为激活函数
mlp.add(Dense(units = 20, input_dim = 2, activation = 'sigmoid')) #增加一层输出层,输出层是 1 个神经元,激活函数也是 sigmoid
mlp.add(Dense(units =1, activation = 'sigmoid'))
mlp.summary()

模型配置

#compile the model(模型的配置)
mlp.compile(optimizer = 'adam', loss = 'binary_crossentropy')
# optimizer 为优化方法 ,loss这里用的是 二分类的损失函数

模型训练

#模型的训练
# train the model
mlp.fit(X_train, y_train, epochs = 3000) # epochs 迭代次数

进行预测、计算准确率

训练集

# make prediction and calculate the accuracy
y_train_predict = mlp.predict_classes(X_train)
from sklearn.metrics import accuracy_score
#训练数据集的准确率
accuracy_train = accuracy_score(y_train, y_train_predict)
print(accuracy_train)#0.96

 测试集

# 接下来看测试数据集的准确率
y_test_predict = mlp.predict_classes(X_test)
#测试数据集的准确率
accuracy_test = accuracy_score(y_test, y_test_predict)
print(accuracy_test)#0.9632352941176471
#可视化模型的预测结果
#首先看一下输出数据的结构,不合适时进行转换
print(type(y_train_predict)) # <class 'numpy.ndarray'>,不能直接进行索引
y_train_predict_converted = pd.Series(i[0] for i in y_train_predict)
print(y_train_predict_converted)

 创建新的数据、进行预测

#创建数据点集
xx,yy = np.meshgrid(np.arange(0,1,0.01), np.arange(0,1,0.01)) # 生成数据,0-1之间,间隔为 0.01
x_range = np.c_[xx.ravel(), yy.ravel()]
#接下来直接进行预测
y_range_predict = mlp.predict_classes(x_range)
print(type(y_range_predict))#<class 'numpy.ndarray'>
#format the output
y_range_predict_converted = pd.Series(i[0] for i in y_range_predict)
print(type(y_range_predict_converted))#<class 'pandas.core.series.Series'>

结果可视化

#最后一步,画图,把原来数据点的图拿过来,再加上新的数据点
fig2 = plt.figure(figsize = (5,5))#下面是新创建的数据点(预测数据)
passed_predict = plt.scatter(x_range[:,0][y_range_predict_converted==1],x_range[:,1][y_range_predict_converted==1])
failed_predict = plt.scatter(x_range[:,0][y_range_predict_converted==0],x_range[:,1][y_range_predict_converted==0])
#下面是原来的数据点
passed = plt.scatter(X.loc[:,'x1'][y==1], X.loc[:,'x2'][y==1])
failed = plt.scatter(X.loc[:,'x1'][y==0], X.loc[:,'x2'][y==0])plt.legend((passed, failed, passed_predict, failed_predict),('passed','failed','passed_predict','failed_predict'))
plt.xlabel('x1')
plt.ylabel('x2')
plt.title('prediction result')
plt.show()
#蓝色是 y==1 的结果, 橙色的是 y == 0  的结果

4、好坏质检二分类 mlp 实战 总结

1、通过 mlp 模型,在不增加特征项的情况下,实现了非线性二分类任务;
2、掌握了 mlp 模型的建立、配置与训练方法,并实现基于新数据的预测;
3、熟悉了 mlp 分类的预测数据格式,并实现格式转换;
4、核心算法参考链接:https://keras-cn.readthedocs.io/en/latest/#30skeras

http://www.lqws.cn/news/130591.html

相关文章:

  • 数字人技术的核心:AI与动作捕捉的双引擎驱动(210)
  • 网络安全中网络诈骗的攻防博弈
  • Flutter快速上手,入门教程
  • OPenCV CUDA模块图像处理-----对图像执行 均值漂移滤波(Mean Shift Filtering)函数meanShiftFiltering()
  • 架构设计技巧——架构设计模板
  • 区块链技术发展现状与应用前景分析
  • Windows系统工具:WinToolsPlus 之 SQL Server Suspect/质疑/置疑/可疑/单用户等 修复
  • intense-rp-api开源程序是一个具有直观可视化界面的 API,可以将 DeepSeek 非正式地集成到 SillyTavern 中
  • C#学习第27天:时间和日期的处理
  • 【Linux】编译器gcc/g++及其库的详细介绍
  • 《高等数学》(同济大学·第7版)第一章第七节无穷小的比较
  • C++11 defaulted和deleted函数从入门到精通
  • JavaScript 二维数组初始化:为什么 fill([]) 是个大坑?
  • 《波段操盘实战技法》速读笔记
  • 《射频识别(RFID)原理与应用》期末复习 RFID第二章 RFID基础与前端(知识点总结+习题巩固)
  • 【Code】Python金融基础
  • el-input限制输入数字,输入中文后数字校验失效
  • Spark实战能力测评模拟题精析【模拟考】
  • 实时数据湖架构设计:从批处理到流处理的企业数据战略升级
  • HarmonyOS 实战:给笔记应用加防截图水印
  • 【HarmonyOS 5】生活与服务开发实践详解以及服务卡片案例
  • function as a service的极简方案:通过jupyterhub和gradio搭建FAAS平台(一)
  • 如何在 React 中监听 div 的滚动事件
  • 从Node.js到React/Vue3:流式输出技术的全栈实现指南
  • (2025)Windows修改JupyterNotebook的字体,使用JetBrains Mono
  • 前端工具库lodash与lodash-es区别详解
  • Elasticsearch中的刷新(Refresh)和刷新间隔介绍
  • Comparable和Comparator
  • 腾讯位置商业授权AOI边界查询开发指南
  • 【PmHub面试篇】PmHub 整合 TransmittableThreadLocal(TTL)缓存用户数据面试专题解析