Skip to content

Commit

Permalink
更新辣
Browse files Browse the repository at this point in the history
  • Loading branch information
BlackTea-c committed Nov 29, 2023
1 parent b4a8ca9 commit 765cc16
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 139 deletions.
64 changes: 0 additions & 64 deletions eg1.py

This file was deleted.

54 changes: 0 additions & 54 deletions eg_dual_form.py

This file was deleted.

53 changes: 32 additions & 21 deletions 支持向量机/主程序代码.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,34 +2,34 @@


class SVM:
def __init__(self, max_iter=100, kernel='linear'): #设置最大迭代量,若超过max_iter仍不满足KKT条件则跳出; kernel即为核函数。
def __init__(self, max_iter, kernel='linear'): #设置最大迭代量,若超过max_iter仍不满足KKT条件则跳出; kernel即为核函数。
self.max_iter = max_iter
self._kernel = kernel

def init_args(self, features, labels):
self.m, self.n = features.shape
def init_args(self, features, labels): #features特征空间 labels变迁
self.m, self.n = features.shape #m行 n列;n为特征数 m为样本量
self.X = features
self.Y = labels
self.b = 0.0

# 将Ei保存在一个列表里
self.alpha = np.ones(self.m)
self.alpha = np.ones(self.m) #alpha_i 阿尔法初始化
self.E = [self._E(i) for i in range(self.m)]
# 松弛变量
self.C = 1.0

def _KKT(self, i):
def _KKT(self, i): #KKT条件的检验,这里没有在精度范围内进行检验呐
y_g = self._g(i) * self.Y[i]
if self.alpha[i] == 0:
return y_g >= 1
elif 0 < self.alpha[i] < self.C:
return y_g == 1
return y_g == 1 #满足式子则return True 反之False
else:
return y_g <= 1

# g(x)预测值,输入xi(X[i])
def _g(self, i):
r = self.b
r = self.b #g(x)=sum(a_y*y_j*K(xi,xj)) + b
for j in range(self.m):
r += self.alpha[j] * self.Y[j] * self.kernel(self.X[i], self.X[j])
return r
Expand All @@ -38,7 +38,7 @@ def _g(self, i):
def kernel(self, x1, x2):
if self._kernel == 'linear':
return sum([x1[k] * x2[k] for k in range(self.n)])
elif self._kernel == 'poly':
elif self._kernel == 'poly': #多项式和核函数 取p=2
return (sum([x1[k] * x2[k] for k in range(self.n)]) + 1)**2

return 0
Expand All @@ -47,41 +47,41 @@ def kernel(self, x1, x2):
def _E(self, i):
return self._g(i) - self.Y[i]

def _init_alpha(self):
def _init_alpha(self):#alpha的选择
# 外层循环首先遍历所有满足0<a<C的样本点,检验是否满足KKT
index_list = [i for i in range(self.m) if 0 < self.alpha[i] < self.C]
# 否则遍历整个训练集
non_satisfy_list = [i for i in range(self.m) if i not in index_list]
non_satisfy_list = [i for i in range(self.m) if i not in index_list] #alpha=0 or C
index_list.extend(non_satisfy_list)

for i in index_list:
if self._KKT(i):
continue
continue#满足KKT条件的a 则终止下列运行返回for循环重新判断

E1 = self.E[i]
E1 = self.E[i] #E1确定
# 如果E2是+,选择最小的;如果E2是负的,选择最大的
if E1 >= 0:
j = min(range(self.m), key=lambda x: self.E[x])
else:
j = max(range(self.m), key=lambda x: self.E[x])
return i, j

def _compare(self, _alpha, L, H):
def _compare(self, _alpha, L, H): #
if _alpha > H:
return H
elif _alpha < L:
return L
else:
return _alpha

def fit(self, features, labels):
def fit(self, features, labels): #Train
self.init_args(features, labels)

for t in range(self.max_iter):
# train
i1, i2 = self._init_alpha()
i1, i2 = self._init_alpha() #得到选择的两个alpha

# 边界
# 边界 L,H的确认:
if self.Y[i1] == self.Y[i2]:
L = max(0, self.alpha[i1] + self.alpha[i2] - self.C)
H = min(self.C, self.alpha[i1] + self.alpha[i2])
Expand All @@ -91,12 +91,13 @@ def fit(self, features, labels):

E1 = self.E[i1]
E2 = self.E[i2]
# eta=K11+K22-2K12
# eta=K11+K22-2K12 (书上的n)
eta = self.kernel(self.X[i1], self.X[i1]) + self.kernel(
self.X[i2],
self.X[i2]) - 2 * self.kernel(self.X[i1], self.X[i2])
if eta <= 0:
# print('eta <= 0')
print(eta)
if eta <= 0: #为什么这里会小于0????为什么小于0就不计算了?
#print(eta)
continue

alpha2_new_unc = self.alpha[i2] + self.Y[i2] * (
Expand Down Expand Up @@ -150,5 +151,15 @@ def score(self, X_test, y_test):
def _weight(self):
# linear model
yx = self.Y.reshape(-1, 1) * self.X
self.w = np.dot(yx.T, self.alpha)
return self.w
self.w = np.dot(yx.T, self.alpha) #np.dot矩阵乘法
return self.w



svm=SVM(max_iter=30)

X_train=np.array([[1,2],[0,2],[1,1],[3,4],[2,3],[1,4],[-1,-1]])
y_train=np.array([-1,-1,-1,1,1,1,-1])
svm.fit(X_train, y_train)

print(svm.score(X_train, y_train))
4 changes: 4 additions & 0 deletions 草稿.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@



print(1>2)

0 comments on commit 765cc16

Please sign in to comment.