TensorFlow实现简单线性回归

 更新时间:2022年03月30日 14:58:36   作者:kylinxjd  
这篇文章主要为大家详细介绍了TensorFlow实现简单线性回归,文中示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下

本文实例为大家分享了TensorFlow实现简单线性回归的具体代码,供大家参考,具体内容如下

简单的一元线性回归

一元线性回归公式:

其中x是特征:[x1,x2,x3,…,xn,]T
w是权重,b是偏置值

代码实现

导入必须的包

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os

# 屏蔽warning以下的日志信息
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

产生模拟数据

def generate_data():
    x = tf.constant(np.array([i for i in range(0, 100, 5)]).reshape(-1, 1), tf.float32)
    y = tf.add(tf.matmul(x, [[1.3]]) + 1, tf.random_normal([20, 1], stddev=30))
    return x, y

x是100行1列的数据,tf.matmul是矩阵相乘,所以权值设置成二维的。
设置的w是1.3, b是1

实现回归

def myregression():
    """
    自实现线性回归
    :return:
    """
    x, y = generate_data()
    #     建立模型  y = x * w + b
    # w 1x1的二维数据
    w = tf.Variable(tf.random_normal([1, 1], mean=0.0, stddev=1.0), name='weight_a')
    b = tf.Variable(0.0, name='bias_b')

    y_predict = tf.matmul(x, a) + b

    # 建立损失函数
    loss = tf.reduce_mean(tf.square(y_predict - y))
    
    # 训练
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss=loss)

    # 初始化全局变量
    init_op = tf.global_variables_initializer()

  
    with tf.Session() as sess:
        sess.run(init_op)
        print('初始的权重:%f偏置值:%f' % (a.eval(), b.eval()))
    
        # 训练优化
        for i in range(1, 100):
            sess.run(train_op)
            print('第%d次优化的权重:%f偏置值:%f' % (i, a.eval(), b.eval()))
        # 显示回归效果
        show_img(x.eval(), y.eval(), y_predict.eval())

使用matplotlib查看回归效果

def show_img(x, y, y_pre):
    plt.scatter(x, y)
    plt.plot(x, y_pre)
    plt.show()

完整代码

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'


def generate_data():
    x = tf.constant(np.array([i for i in range(0, 100, 5)]).reshape(-1, 1), tf.float32)
    y = tf.add(tf.matmul(x, [[1.3]]) + 1, tf.random_normal([20, 1], stddev=30))
    return x, y


def myregression():
    """
    自实现线性回归
    :return:
    """
    x, y = generate_data()
    # 建立模型  y = x * w + b
    w = tf.Variable(tf.random_normal([1, 1], mean=0.0, stddev=1.0), name='weight_a')
    b = tf.Variable(0.0, name='bias_b')

    y_predict = tf.matmul(x, w) + b

    # 建立损失函数
    loss = tf.reduce_mean(tf.square(y_predict - y))
    # 训练
    train_op = tf.train.GradientDescentOptimizer(0.0001).minimize(loss=loss)

    init_op = tf.global_variables_initializer()

    with tf.Session() as sess:
        sess.run(init_op)
        print('初始的权重:%f偏置值:%f' % (w.eval(), b.eval()))
        # 训练优化
        for i in range(1, 35000):
            sess.run(train_op)
            print('第%d次优化的权重:%f偏置值:%f' % (i, w.eval(), b.eval()))
        show_img(x.eval(), y.eval(), y_predict.eval())


def show_img(x, y, y_pre):
    plt.scatter(x, y)
    plt.plot(x, y_pre)
    plt.show()


if __name__ == '__main__':
    myregression()

看看训练的结果(因为数据是随机产生的,每次的训练结果都会不同,可适当调节梯度下降的学习率和训练步数)

35000次的训练结果

以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持脚本之家。

相关文章

  • 详解python读取matlab数据(.mat文件)

    详解python读取matlab数据(.mat文件)

    本文主要介绍了python读取matlab数据,文中通过示例代码介绍的非常详细,具有一定的参考价值,感兴趣的小伙伴们可以参考一下
    2021-12-12
  • Python中通过property设置类属性的访问

    Python中通过property设置类属性的访问

    为了达到类似C++类的封装性能,可以使用property来设置Python类属性的访问权限,本文就介绍一下Python中通过property设置类属性的访问,感兴趣的可以了解一下,感兴趣的可以了解一下
    2023-09-09
  • 基于Python实现图片浏览器的应用程序

    基于Python实现图片浏览器的应用程序

    图像浏览器应用程序是一种非常常见和实用的工具,这篇文章就来为大家介绍一下如何使用Python编程语言和wxPython库创建一个简单的图像浏览器应用程序,感兴趣的可以了解下
    2023-10-10
  • 解决python写入带有中文的字符到文件错误的问题

    解决python写入带有中文的字符到文件错误的问题

    今天小编就为大家分享一篇解决python写入带有中文的字符到文件错误的问题,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2019-01-01
  • python selenium浏览器复用技术的使用

    python selenium浏览器复用技术的使用

    本文主要介绍了python selenium浏览器复用技术的使用,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友们下面随着小编来一起学习学习吧
    2023-02-02
  • Python中的迭代器与生成器高级用法解析

    Python中的迭代器与生成器高级用法解析

    这篇文章主要介绍了Python中的迭代器与生成器高级用法解析,生成器在Python中是迭代器的一种,这里我们会讲到生成表达式、链式生成器等深层次内容,需要的朋友可以参考下
    2016-06-06
  • 基于python解线性矩阵方程(numpy中的matrix类)

    基于python解线性矩阵方程(numpy中的matrix类)

    这篇文章主要介绍了基于python解线性矩阵方程(numpy中的matrix类),文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可以参考下
    2019-10-10
  • Python多进程之进程同步及通信详解

    Python多进程之进程同步及通信详解

    这篇文章主要为大家介绍了Python多进程之进程同步及通信,具有一定的参考价值,感兴趣的小伙伴们可以参考一下,希望能够给你带来帮助
    2021-11-11
  • 安装python依赖包psycopg2来调用postgresql的操作

    安装python依赖包psycopg2来调用postgresql的操作

    这篇文章主要介绍了安装python依赖包psycopg2来调用postgresql的操作,具有很好的参考价值,希望对大家有所帮助。一起跟随小编过来看看吧
    2021-01-01
  • python爬虫入门教程之糗百图片爬虫代码分享

    python爬虫入门教程之糗百图片爬虫代码分享

    这篇文章主要介绍了python爬虫入门教程之糗百图片爬虫代码分享,本文以抓取糗事百科内涵图为需求写了一个爬虫,,需要的朋友可以参考下
    2014-09-09

最新评论