如何構建一個 CNN 模型,以從圖像中對幼苗的種類進行分類?
new_train = np.asarray(new_train)
# CLEANED IMAGES
for i in range(8):
plt.subplot(2,4,i+1)
plt.imshow(new_train[i])


將標簽轉換為數字
標簽是字符串,這些很難處理。因此,我們將這些標簽轉換為二元分類。
分類可以由 12 個數字組成的數組表示,這些數字將遵循以下條件:
如果未檢測到物種,則為 0。
1 如果檢測到該物種。
示例:如果檢測到 Blackgrass,則數組將為 = [1,0,0,0,0,0,0,0,0,0,0,0]
labels = preprocessing.LabelEncoder()
labels.fit(traininglabels[0])
print('Classes'+str(labels.classes_))
encodedlabels = labels.transform(traininglabels[0])
clearalllabels = np_utils.to_categorical(encodedlabels)
classes = clearalllabels.shape[1]
print(str(classes))
traininglabels[0].value_counts().plot(kind='pie')


定義我們的模型并拆分數據集
在這一步中,我們將拆分訓練數據集進行驗證。我們正在使用 scikit-learn 中的 train_test_split() 函數。這里我們拆分數據集,保持 test_size=0.1。這意味著總數據的 10% 用作測試數據,其余 90% 用作訓練數據。檢查以下代碼以拆分數據集。new_train = new_train/255
x_train,x_test,y_train,y_test = train_test_split(new_train,clearalllabels,test_size=0.1,random_state=seed,stratify=clearalllabels)
防止過擬合
過擬合是機器學習中的一個問題,我們的模型在訓練數據上表現非常好,但在測試數據上表現不佳。在深度神經網絡過度擬合的深度學習中,過度擬合的問題很嚴重。過度擬合的問題嚴重影響了我們的最終結果。為了擺脫它,我們需要減少它。在這個問題中,我們使用 ImageDataGenerator() 函數隨機改變圖像的特征并提供數據的隨機性。、為了避免過擬合,我們需要一個函數。此函數隨機改變圖像特性。檢查以下代碼以了解如何減少過度擬合generator = ImageDataGenerator(rotation_range = 180,zoom_range = 0.1,width_shift_range = 0.1,height_shift_range = 0.1,horizontal_flip = True,vertical_flip = True)
generator.fit(x_train)
定義卷積神經網絡
我們的數據集由圖像組成,因此我們不能使用線性回歸、邏輯回歸、決策樹等機器學習算法。我們需要一個用于圖像的深度神經網絡。在這個問題中,我們將使用卷積神經網絡。該神經網絡將圖像作為輸入,并將提供最終輸出作為物種值。我們隨機使用了 4 個卷積層和 3 個全連接層。此外,我們使用了多個函數,如 Sequential()、Conv2D()、Batch Normalization、Max Pooling、Dropout 和 Flatting。
我們使用卷積神經網絡進行訓練。
該模型有 4 個卷積層。
該模型有 3 個全連接層。
np.random.seed(seed)
model = Sequential()
model.add(Conv2D(filters=64, kernel_size=(5, 5), input_shape=(scale, scale, 3), activation='relu'))
model.add(BatchNormalization(axis=3))
model.add(Conv2D(filters=64, kernel_size=(5, 5), activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(BatchNormalization(axis=3))
model.add(Dropout(0.1))
model.add(Conv2D(filters=128, kernel_size=(5, 5), activation='relu'))
model.add(BatchNormalization(axis=3))
model.add(Conv2D(filters=128, kernel_size=(5, 5), activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(BatchNormalization(axis=3))
model.add(Dropout(0.1))
model.add(Conv2D(filters=256, kernel_size=(5, 5), activation='relu'))
model.add(BatchNormalization(axis=3))
model.add(Conv2D(filters=256, kernel_size=(5, 5), activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(BatchNormalization(axis=3))
model.add(Dropout(0.1))
model.add(Flatten())
model.add(Dense(256, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(Dense(256, activation='relu'))
model.add(BatchNormalization())
model.add(Dropout(0.5))
model.add(Dense(classes, activation='softmax'))
model.compile(loss='categorical_crossentropy', optimizer='adam', metrics=['accuracy'])
model.summary()
請輸入評論內容...
請輸入評論/評論長度6~500個字
最新活動更多
- 1 AI狂歡遇上油價破百,全球股市還能漲多久? | 產聯看全球
- 2 OpenAI深夜王炸!ChatGPT Images 2.0實測:中文穩、細節炸,設計師慌了
- 3 6000億美元估值錨定:字節跳動的“去單一化”突圍與估值重構
- 4 Tesla AI5芯片最新進展總結
- 5 連夜測了一波DeepSeek-V4,我發現它可能只剩“審美”這個短板了
- 6 熱點丨AI“瑜亮之爭”:既生OpenClaw,何生Hermes?
- 7 AI界的殺豬盤:9秒刪庫跑路,全員被封號,還繼續扣錢!
- 8 2026,人形機器人只贏了面子
- 9 DeepSeek降價90%:價格屠夫不是身份,是戰略
- 10 AI Infra產業鏈卡在哪里了?


分享













