TensorFlow.js机器学习预测鸢尾花种类
一、加载IRIS数据集
创建index.html入口文件,跳转到script主文件。
<script src="script.js"></script>
在script.js文件夹中利用预先准备好的脚本生成鸢尾花数据集,包括训练集和验证集,并打印查看。
import {getIrisData, IRIS_CLASSES} from "./data.js"; window.onload = () => { // 加载数据 const [xTrain, yTrain, xTest, yTest] = getIrisData(0.2); // 打印查看数据集 xTrain.print(); yTrain.print(); xTest.print(); yTest.print(); // 打印鸢尾花种类类别 console.log(IRIS_CLASSES); }
getIrisData(0.2):获取数据集的时候,将20%的数据当成测试集,剩下的80%当成训练集。
xTrain:训练集的特征值。
yTrain:训练集的目标值。
xTest:验证集的特征值。
yTest:验证集的目标值。
可以在控制台查看到结果:
其中特征矩阵里面的四个值分别表示:花萼的长度、花萼的宽度、花瓣的长度、花瓣的宽度。
目标值矩阵采用one-hot编码形式。
二、定义模型结构
初始化一个神经网络模型,为神经网络模型添加两层,配置模型的损失函数、激活函数、优化器、添加准确度度量。
// 定义网络模型 const model = tf.sequential(); // 添加隐藏层 model.add(tf.layers.dense({ units: 10, inputShape: [xTrain.shape[1]], activation: 'relu' })); // 添加输出层 model.add(tf.layers.dense({ units: 3, activation: 'softmax' })); // 配置模型 model.compile({ loss: "categoricalCrossentropy", optimizer: tf.train.adam(0.1), metrics: ['accuracy'] });
三、训练模型并可视化
训练结果需要等待,所以采用异步方式训练。
await model.fit(xTrain, yTrain,{ epochs: 100, batchSize: 32, validationData: [xTest, yTest], callbacks: tfvis.show.fitCallbacks( {name: '训练效果'}, ['loss', 'val_loss', 'acc', 'val_acc'], {callbacks: ['onEpochEnd']} ) });
训练结果:
四、预测
编写前端界面输入待预测数据,使用训练好的模型进行预测,将输出的Tensor转成普通数据并显示。
在index.html中编写form表单,用来输入预测数据。
<form action="" onsubmit="predict(this); return false"> 花萼长度:<input type="text" name="a"><br> 花萼宽度:<input type="text" name="b"><br> 花瓣长度:<input type="text" name="c"><br> 花瓣宽度:<input type="text" name="d"><br> <button type="submit">预测</button> </form>
输入数据的顺序不能错,因为我们训练数据的顺序就是花萼长度、花萼宽度、花瓣长度、花瓣宽度。
在Script.js中编写predict预测函数。
window.predict = (form) => { // 将表单获取的到数据转成Tensor const input = tf.tensor([[ form.a.value * 1, form.b.value * 1, form.c.value * 1, form.d.value * 1, ]]); // 预测 const pred = model.predict(input); alert(`预测结果:${IRIS_CLASSES[pred.argMax(1).dataSync(0)]}`) }
预测结果:gif动图有点模糊,可以自己动手试试看哦。
到此这篇关于TensorFlow.js机器学习预测鸢尾花种类的文章就介绍到这了,更多相关TensorFlow.js预测鸢尾花内容请搜索脚本之家以前的文章或继续浏览下面的相关文章希望大家以后多多支持脚本之家!
相关文章
使用js判断数组中是否包含某一元素(类似于php中的in_array())
这篇文章主要是对使用js判断数组中是否包含某一元素(类似于php中的in_array())需要的朋友可以过来参考下,希望对大家有所帮助2013-12-12深入理解JavaScript系列(1) 编写高质量JavaScript代码的基本要点
才华横溢的Stoyan Stefanov,在他写的由O’Reilly初版的新书《JavaScript Patterns》(JavaScript模式)中,我想要是为我们的读者贡献其摘要,那会是件很美妙的事情2012-01-01
最新评论