# TensorFlow训练单特征和多特征的线性回归

2017-04-17

TensorFlow训练单特征和多特征的线性回归：线性回归是很常见的一种回归，线性回归可以用来预测或者分类，主要解决线性问题。

TensorFlow训练单特征和多特征的线性回归：线性回归是很常见的一种回归，线性回归可以用来预测或者分类，主要解决线性问题。

## 单特征线性回归

```X = tf.placeholder(tf.float32, [None, 1])
w = tf.Variable(tf.zeros([1, 1]))
b = tf.Variable(tf.zeros([1]))
y = tf.matmul(X, w) + b
Y = tf.placeholder(tf.float32, [None, 1])```

`cost = tf.reduce_mean(tf.square(Y-y))`

`train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cost)`

```import tensorflow as tf

X = tf.placeholder(tf.float32, [None, 1])
w = tf.Variable(tf.zeros([1, 1]))
b = tf.Variable(tf.zeros([1]))
y = tf.matmul(X, w) + b
Y = tf.placeholder(tf.float32, [None, 1])

# 成本函数 sum(sqr(y_-y))/n
cost = tf.reduce_mean(tf.square(Y-y))

# 用梯度下降训练

init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

x_train = [[1],[2],[3],[4],[5],[6],[7],[8],[9],[10]]
y_train = [[10],[11.5],[12],[13],[14.5],[15.5],[16.8],[17.3],[18],[18.7]]

for i in range(10000):
sess.run(train_step, feed_dict={X: x_train, Y: y_train})
print("w:%f" % sess.run(w))
print("b:%f" % sess.run(b))```

## 多特征线性回归

y为m行1列矩阵，x为m行n列矩阵，w为n行1列矩阵。TensorFlow中用如下来表示模型。

```X = tf.placeholder(tf.float32, [None, n])
w = tf.Variable(tf.zeros([n, 1]))
b = tf.Variable(tf.zeros([1]))
y = tf.matmul(X, w) + b
Y = tf.placeholder(tf.float32, [None, 1])```

`cost = tf.reduce_mean(tf.square(Y-y))`

`train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cost)`

```import tensorflow as tf

X = tf.placeholder(tf.float32, [None, 2])
w = tf.Variable(tf.zeros([2, 1]))
b = tf.Variable(tf.zeros([1]))
y = tf.matmul(X, w) + b
Y = tf.placeholder(tf.float32, [None, 1])

# 成本函数 sum(sqr(y_-y))/n
cost = tf.reduce_mean(tf.square(Y-y))

# 用梯度下降训练

init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)

x_train = [[1, 2], [2, 1], [2, 3], [3, 5], [1, 3], [4, 2], [7, 3], [4, 5], [11, 3], [8, 7]]
y_train = [[7], [8], [10], [14], [8], [13], [20], [16], [28], [26]]

for i in range(10000):
sess.run(train_step, feed_dict={X: x_train, Y: y_train})
print("w0:%f" % sess.run(w[0]))
print("w1:%f" % sess.run(w[1]))
print("b:%f" % sess.run(b))```