Skip to content

Instantly share code, notes, and snippets.

@shepai
Created January 8, 2022 14:48
Show Gist options
  • Save shepai/b045a4bc3e4177a2f46a89ef90921622 to your computer and use it in GitHub Desktop.
Save shepai/b045a4bc3e4177a2f46a89ef90921622 to your computer and use it in GitHub Desktop.
ConvNN
class model:
    """Small Keras convolutional network wrapper for image classification.

    Builds a ``Sequential`` model of three 3x3 ``Conv2D`` layers (the first
    two followed by 2x2 max-pooling), a ``Flatten`` and a ``Dense`` output
    layer with one unit per outcome. The model is compiled with the Adam
    optimizer and tracks accuracy.

    NOTE(review): ``Sequential``, ``Conv2D``, ``MaxPooling2D``, ``Flatten``,
    ``Dense`` and ``np`` are expected to be imported at module level
    (presumably from Keras / ``tensorflow.keras`` and NumPy). Those imports
    are not visible here.
    """

    def __init__(self, outcomes=10, *, input_shape=(28, 28, 1),
                 activation='sigmoid', loss='binary_crossentropy'):
        """Build and compile the network.

        Args:
            outcomes: number of units in the output layer.
            input_shape: shape of one input sample. The default ``(28, 28, 1)``
                presumably means 28x28 single-channel images in Keras'
                default channels-last layout.
            activation: activation of the output ``Dense`` layer.
            loss: loss passed to ``compile``.

        The defaults (sigmoid + binary cross-entropy) score each output unit
        independently, so ``train`` typically needs ``y`` with one column
        per outcome (e.g. one-hot rows). For single-label multi-class
        problems, ``activation='softmax'`` with
        ``loss='categorical_crossentropy'`` is the usual Keras pairing.
        """
        self.model = Sequential()
        self.model.add(Conv2D(28, (3, 3), activation='relu', input_shape=input_shape))
        self.model.add(MaxPooling2D((2, 2)))
        self.model.add(Conv2D(64, (3, 3), activation='relu'))
        self.model.add(MaxPooling2D((2, 2)))
        self.model.add(Conv2D(64, (3, 3), activation='relu'))
        self.model.add(Flatten())
        self.model.add(Dense(outcomes, activation=activation))
        self.model.compile(loss=loss, optimizer='adam', metrics=['accuracy'])

    def train(self, x, y, epochs=30, batch=32, validation_split=0.1):
        """Fit the model on ``x``/``y``.

        The last ``validation_split`` fraction of the data is held out for
        validation. The training log returned by ``fit`` is stored in
        ``self.history``.
        """
        history = self.model.fit(x, y, batch_size=batch, epochs=epochs,
                                 validation_split=validation_split)
        self.history = history.history  # gather training log

    def test(self, X, y):
        """Return the classification accuracy of the model on ``X``.

        Args:
            X: input samples accepted by ``self.model.predict``.
            y: expected labels. Either integer class indices (shape ``(n,)``
                or ``(n, 1)``) or one-hot / per-class rows (shape
                ``(n, outcomes)``), such as the targets given to ``train``.

        Returns:
            Fraction of samples whose predicted class equals the label, as a
            float.

        Raises:
            AssertionError: if ``X`` and ``y`` differ in length (the check
                is an ``assert``, so it is skipped under ``python -O``).
            ValueError: if ``X`` is empty.
        """
        assert len(X) == len(y), "Error, the arrays do not match length"
        if len(X) == 0:
            # Nothing to score; the original code divided by zero here.
            raise ValueError("Error, cannot evaluate on an empty dataset")

        scores = np.asarray(self.model.predict(X))
        scores = scores.reshape(len(scores), -1)
        if scores.shape[1] == 1:
            # A single output unit: argmax would always be 0, so threshold
            # the score instead (0.5, as Keras' binary accuracy does).
            predicted = (scores[:, 0] > 0.5).astype(int)
        else:
            predicted = np.argmax(scores, axis=1)

        labels = np.asarray(y)
        if labels.ndim > 1:
            labels = labels.reshape(len(labels), -1)
            # One column holds the class index itself; several columns are
            # one-hot / per-class rows, so take the index of the largest.
            if labels.shape[1] == 1:
                labels = labels[:, 0]
            else:
                labels = np.argmax(labels, axis=1)

        return float(np.mean(predicted == labels))

    def save(self, name):
        """Save the wrapped Keras model to ``name`` (a str or path-like)."""
        self.model.save(name)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment