首頁>Club>
11
回覆列表
  • 1 # Calvin0007

    1,實現線性分類

    import numpy as np

    import matplotlib.pyplot as plt

    from sklearn.datasets.samples_generator import make_blobs

    from sklearn.svm import SVC

    #隨機生成點,n_samples:樣本點個數;centers:樣本點分為幾類;random_state:每次隨機生成一致;cluster_std:每類樣本點間的離散程度,值越大離散程度越大。

    X,y = make_blobs(n_samples=50, centers=2, random_state=0, cluster_std=0.60)

    #畫出所有樣本點

    plt.scatter(X[:,0],X[:,1],c=y,cmap="summer")

    #使用線性分類SVC擬合

    #svc函式還可以包括以下引數(具體例子見文章最後):

    #1,C(C越大意味著分類越嚴格不能有錯誤;當C趨近於很小的時意味著可以有更大的錯誤容忍)

    #2,kernel(kernel必須是[‘linear’, ‘poly’, ‘rbf’, ‘sigmoid’, ‘precomputed’]中的一個,預設為’rbf’)

    #3,gamma(gamma越大模型越複雜,會導致過擬合,對線性核函式無影響)

    model = SVC(kernel="linear")

    model.fit(X,y)

    plot_svc_decision_function(model)

    這裡用到繪製邊界線及圈出支援向量的函式plot_svc_decision_function()

    def plot_svc_decision_function(model, ax=None, plot_support=True):

    #Plot the decision function for a 2D SVC

    if ax is None:

    ax = plt.gca()

    #找出圖片x軸y軸的邊界

    xlim = ax.get_xlim()

    ylim = ax.get_ylim()

    # create grid to evaluate model

    x = np.linspace(xlim[0], xlim[1], 30)

    y = np.linspace(ylim[0], ylim[1], 30)

    Y, X = np.meshgrid(y, x)

    #形成圖片上所有座標點(900,2),900個二維點

    xy = np.vstack([X.ravel(), Y.ravel()]).T

    #計算每點到邊界的距離(30,30)

    P = model.decision_function(xy).reshape(X.shape)

    #繪製等高線(距離邊界線為0的實線,以及距離邊界為1的過支援向量的虛線)

    ax.contour(X, Y, P, colors="k",levels=[-1, 0, 1], alpha=0.5,linestyles=["--", "-", "--"])

    # 圈出支援向量

    if plot_support:

    #model.support_vectors_函式可打印出所有支援向量座標

    ax.scatter(model.support_vectors_[:, 0],model.support_vectors_[:, 1],s=200,c="",edgecolors="k")

    ax.set_xlim(xlim)

    ax.set_ylim(ylim)

    繪製效果圖如下:

    2,實現非線性分類–引入核函式有時候線性核函式不能很好的劃分邊界比如:

    from sklearn.datasets.samples_generator import make_circles

    X,y = make_circles(100, factor=.1, noise=.1)

    plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap="summer")

    clf = SVC(kernel="linear").fit(X, y)

    plot_svc_decision_function(clf, plot_support=False)

    分類結果如下:

    此時,需加入徑向基函式rbf(高斯)

    X,y = make_circles(100, factor=.1, noise=.1)

    plt.scatter(X[:, 0], X[:, 1], c=y, s=50, cmap="summer")

    clf = SVC(kernel="rbf", C=1E6)

    clf.fit(X,y)

    plot_svc_decision_function(clf)

    分類結果如下:

    希望您滿意,能幫助到您~~

  • 中秋節和大豐收的關聯?
  • 兒子在他媽家喝酒,醉酒後作鬧,造成的損失誰負責?