Initial commit
This commit is contained in:
4
Tensorflow/concours_foetus/README.md
Normal file
4
Tensorflow/concours_foetus/README.md
Normal file
@@ -0,0 +1,4 @@
|
||||
# Concours HC18: Mesurer un fœtus à partir d'images d'échographie
|
||||
|
||||
La vidéo du tutoriel se trouve à l'adresse suivante:
|
||||
https://www.youtube.com/watch?v=BCOLI8CTF00
|
||||
104
Tensorflow/concours_foetus/common.py
Normal file
104
Tensorflow/concours_foetus/common.py
Normal file
@@ -0,0 +1,104 @@
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras import layers, models
|
||||
import csv
|
||||
import random
|
||||
import cv2
|
||||
import numpy as np
|
||||
import math
|
||||
import csv
|
||||
import random
|
||||
import config
|
||||
|
||||
def rotateImage(image, angle):
    """Rotate *image* by *angle* degrees around its center.

    The output keeps the original (width, height); corners rotated outside
    the frame are cropped and uncovered pixels are filled with black.
    """
    width_height = image.shape[1::-1]
    center = tuple(np.array(width_height) / 2)
    matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
    return cv2.warpAffine(image, matrix, width_height, flags=cv2.INTER_LINEAR)
||||
|
||||
def complete_dataset(image, image_ellipse, tab_images, tab_labels):
    """Extract an ellipse label from an annotation image and append one sample.

    The annotation image is expected to yield exactly two contours (inner and
    outer edge of the drawn ellipse); the label is the average of the two
    fitted ellipses, with center/axes normalised by config.norm and the angle
    normalised by 180 degrees.

    Args:
        image: BGR image; only channel 0 is stored as the sample.
        image_ellipse: BGR annotation image; only channel 0 is used.
        tab_images: list mutated in place, receives the grayscale image.
        tab_labels: list mutated in place, receives [x, y, MA, ma, angle].

    Returns:
        0 when a sample was appended, 1 when the image was rejected.
    """
    contours, hierarchy = cv2.findContours(image_ellipse[:, :, 0], cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE)
    # BUG FIX: the original used "is not 2", an identity test against an int
    # literal whose result is implementation-defined; compare by value.
    if len(contours) != 2:
        return 1
    # cv2.fitEllipse needs >= 5 points; the original required at least 6.
    if len(contours[0]) < 6 or len(contours[1]) < 6:
        return 1
    (x1, y1), (ma1, MA1), a1 = cv2.fitEllipse(contours[0])
    (x2, y2), (ma2, MA2), a2 = cv2.fitEllipse(contours[1])
    # Average the two fitted ellipses (inner + outer edge of the drawing).
    x = (x1 + x2) / 2
    y = (y1 + y2) / 2
    ma = (ma1 + ma2) / 2
    MA = (MA1 + MA2) / 2
    a = (a1 + a2) / 2
    tab_images.append(image[:, :, 0])
    tab_labels.append([x/config.norm, y/config.norm, MA/config.norm, ma/config.norm, a/180])
    return 0
||||
|
||||
def prepare_data(fichier):
    """Build the augmented training set from a CSV listing of image files.

    For each image listed in *fichier* (first CSV column; header skipped),
    the matching *_Annotation.png is loaded and both are augmented:
    optional random padding/shift into a 1.4x canvas, resize to
    (config.largeur, config.hauteur), rotation every 30 degrees, additive
    Gaussian noise, plus vertical / horizontal / both flips. Each variant is
    turned into a (grayscale image, ellipse label) pair by complete_dataset().

    Returns:
        (tab_images, tab_labels): lists of grayscale images and of
        normalised labels [x, y, MA, ma, angle].
    """
    with open(fichier, newline='') as csvfile:
        lignes=csv.reader(csvfile, delimiter=',')
        next(lignes)  # skip the CSV header row
        tab_images=[]
        tab_labels=[]
        nbr=0  # count of rejected augmented images (complete_dataset returns 1)
        for ligne in lignes:
            image_orig=cv2.imread(config.dir_images+ligne[0])
            if image_orig is None:
                print("Fichier absent", config.dir_images+ligne[0])
                continue

            # Annotation file sits next to the image: foo.png -> foo_Annotation.png
            f_ellipse=ligne[0].split('.')[0]+"_Annotation.png"
            image_ellipse_orig=cv2.imread(config.dir_images+f_ellipse)
            if image_ellipse_orig is None:
                print("Fichier absent", config.dir_images+f_ellipse)
                continue

            # One augmented family per rotation angle (0, 30, ..., 330).
            for angle in range(0, 360, 30):

                # 50% of the time, paste image + annotation at the same random
                # offset inside a 1.4x larger black canvas (translation/zoom-out
                # augmentation); the annotation must receive the same shift so
                # the label stays aligned with the image.
                if np.random.randint(2)==0:
                    h, w, c=image_orig.shape
                    H=int(h*1.4)
                    W=int(w*1.4)
                    h_shift=np.random.randint(H-h)
                    w_shift=np.random.randint(W-w)

                    i=np.zeros(shape=(H, W, c), dtype=np.uint8)
                    i[h_shift:h_shift+h, w_shift:w_shift+w, :]=image_orig
                    image_orig2=i

                    i=np.zeros(shape=(H, W, c), dtype=np.uint8)
                    i[h_shift:h_shift+h, w_shift:w_shift+w, :]=image_ellipse_orig
                    image_ellipse_orig2=i
                else:
                    image_orig2=image_orig
                    image_ellipse_orig2=image_ellipse_orig

                # Resize to the network input size, then rotate both images
                # by the same angle.
                image=cv2.resize(image_orig2, (config.largeur, config.hauteur), interpolation=cv2.INTER_AREA)
                image_ellipse=cv2.resize(image_ellipse_orig2, (config.largeur, config.hauteur), interpolation=cv2.INTER_AREA)
                img_r=rotateImage(image, angle)

                #if np.random.randint(3)==0:
                #    kernel_blur=np.random.randint(2)*2+1
                #    img_r=cv2.GaussianBlur(img_r, (kernel_blur, kernel_blur), 0)

                # Additive Gaussian noise with a random amplitude in [1, 50];
                # applied to the image only, never to the annotation.
                bruit=np.random.randn(config.hauteur, config.largeur, 3)*random.randint(1, 50)
                img_r=np.clip(img_r+bruit, 0, 255).astype(np.uint8)

                img_ellipse=rotateImage(image_ellipse, angle)
                nbr+=complete_dataset(img_r, img_ellipse, tab_images, tab_labels)

                # Vertical flip.
                img_f=cv2.flip(img_r, 0)
                img_ellipse_f=cv2.flip(img_ellipse, 0)
                nbr+=complete_dataset(img_f, img_ellipse_f, tab_images, tab_labels)

                # Horizontal flip.
                img_f=cv2.flip(img_r, 1)
                img_ellipse_f=cv2.flip(img_ellipse, 1)
                nbr+=complete_dataset(img_f, img_ellipse_f, tab_images, tab_labels)

                # Both flips (180-degree point reflection).
                img_f=cv2.flip(img_r, -1)
                img_ellipse_f=cv2.flip(img_ellipse, -1)
                nbr+=complete_dataset(img_f, img_ellipse_f, tab_images, tab_labels)

        print("Image(s) rejetée(s):", nbr)
        print("Nombre d'images:", len(tab_images))
        return tab_images, tab_labels
||||
|
||||
14
Tensorflow/concours_foetus/config.py
Normal file
14
Tensorflow/concours_foetus/config.py
Normal file
@@ -0,0 +1,14 @@
|
||||
# Network input size in pixels (width x height).
largeur=200
hauteur=135

#largeur=220
#hauteur=148


# Common normalisation factor for the ellipse labels (center and axes are
# divided by this; the angle is divided by 180 elsewhere).
norm=max(largeur, hauteur)

batch_size=64
# Base filter count for the model; stages use multiples of this value.
input_model=8

# Directories holding the training and test images.
dir_images="training_set/"
dir_images_test="test_set/"
||||
17
Tensorflow/concours_foetus/genere_csv.py
Normal file
17
Tensorflow/concours_foetus/genere_csv.py
Normal file
@@ -0,0 +1,17 @@
|
||||
"""Print one colon-separated line of normalised ellipse parameters per
annotated training image, by fitting an ellipse to each external contour of
the *_Annotation.png files."""
import numpy as np
import os
import cv2
import glob
import config


for image in glob.glob(config.dir_images+'*_HC.png'):
    # Annotation file sits next to the image: foo_HC.png -> foo_HC_Annotation.png
    image_ellipse=image.split('.')[0]+"_Annotation.png"
    img=cv2.imread(image_ellipse)
    img=cv2.resize(img, (config.largeur, config.hauteur))
    print(img.shape)
    h, w, c=img.shape
    # Single channel for findContours (annotation drawing is monochrome).
    img=img[:, :, 0]
    contours, hierarchy=cv2.findContours(img, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    for cont in contours:
        (x, y),(ma, MA) ,angle = cv2.fitEllipse(cont)
        # Center/axes normalised by the image size, angle by 180 degrees.
        print("{}:{:f}:{:f}:{:f}:{:f}:{:f}".format(image, x/w, y/h, ma/w, MA/h, angle/180))
||||
29
Tensorflow/concours_foetus/images.py
Normal file
29
Tensorflow/concours_foetus/images.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""Visual sanity check: draw each training label's ellipse over its image."""
import tensorflow as tf
import sys
import time
import cv2
import numpy as np
import common
import config
import model


# Build the augmented dataset, normalise pixels to [0, 1] and shuffle.
images, labels=common.prepare_data('training_set.csv')
images=np.array(images, dtype=np.float32)/255
labels=np.array(labels, dtype=np.float32)
index=np.random.permutation(len(images))
images=images[index].reshape(-1, config.hauteur, config.largeur, 1)
labels=labels[index]

print("Nombre d'image:", len(images))

for i in range(len(images)):
    # Labels are normalised by config.norm (angle by 180) in common.py.
    x, y, grand_axe, petit_axe, angle=labels[i]
    print("Label:", labels[i], angle*180)
    # Replicate the single channel to get a 3-channel image we can draw on.
    img_couleur=np.tile(images[i], (1, 1, 3))
    cv2.ellipse(img_couleur, (int(x*config.norm), int(y*config.norm)), (int(petit_axe*config.norm/2), int(grand_axe*config.norm/2)), angle*180, 0., 360., (0, 0, 255), 2)
    cv2.imshow("Image", img_couleur)

    # 'q' quits, any other key advances to the next image.
    key=cv2.waitKey()&0xFF
    if key==ord('q'):
        quit()
|
||||
|
||||
87
Tensorflow/concours_foetus/inference.py
Normal file
87
Tensorflow/concours_foetus/inference.py
Normal file
@@ -0,0 +1,87 @@
|
||||
"""Run the trained model on a CSV-listed image set, display the predicted
ellipse (red) alongside the ground-truth ellipse (green, training set only),
and print the estimated head circumference (HC) in millimetres."""
import tensorflow as tf
import sys
import time
import cv2
import numpy as np
import math
import common
import config
import model
import csv


model=model.model(config.input_model)


rouge=(0, 0, 255)  # BGR red: prediction
vert=(0, 255, 0)   # BGR green: ground truth


# Toggle manually between training set (with annotations) and test set.
if True:
    dir=config.dir_images
    fichier="training_set.csv"
    test=False
else:
    dir=config.dir_images_test
    fichier="test_set.csv"
    test=True

checkpoint=tf.train.Checkpoint(model=model)
checkpoint.restore(tf.train.latest_checkpoint("./training/"))

with open(fichier, newline='') as csvfile:
    lignes=csv.reader(csvfile, delimiter=',')
    for ligne in lignes:
        print("LIGNE:", ligne)
        print("XXX", dir+ligne[0])
        img_originale=cv2.imread(dir+ligne[0])
        if img_originale is None:
            continue
        print("WWW", ligne[0], dir+ligne[0], img_originale.shape)
        H, W, C=img_originale.shape
        # Column 1 of the CSV: millimetres per pixel of the original image.
        mm_pixel=float(ligne[1])
        img=cv2.resize(img_originale, (config.largeur, config.hauteur))
        img2=img.copy()
        img=np.array(img, dtype=np.float32)/255
        img=np.expand_dims(img[:, :, 0], axis=-1)
        predictions=model(np.array([img]))
        # Outputs are sigmoid-normalised: x, y, axes by config.norm; angle by 180.
        x, y, grand_axe, petit_axe, angle=predictions[0]
        # NOTE(review): center/axes here are float tensors; cv2.ellipse
        # expects integer coordinates — confirm this call works on the
        # installed OpenCV version (int() conversion may be needed).
        cv2.ellipse(img2, (x*config.norm, y*config.norm), (petit_axe*config.norm/2, grand_axe*config.norm/2), angle*180, 0., 360., rouge, 2)
        print("Prediction", np.array(predictions[0]))

        if test is False:
            # Training set: overlay the ground-truth ellipse from the annotation.
            f_ellipse=ligne[0].split('.')[0]+"_Annotation.png"
            image_ellipse=cv2.imread(dir+f_ellipse)
            if image_ellipse is None:
                print("Fichier absent", dir+f_ellipse)
                continue
            img_ellipse_f_=image_ellipse[:, :, 0]
            contours, hierarchy=cv2.findContours(img_ellipse_f_, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
            (x_, y_), (ma_, MA_), a_=cv2.fitEllipse(contours[0])
            cv2.ellipse(img_originale, (int(x_), int(y_)), (int(ma_/2), int(MA_/2)), a_, 0., 360., vert, 3)

        # Prediction scaled back onto the full-size image.
        # NOTE(review): y and both axes are scaled by W (width), not H —
        # matches the label normalisation by config.norm=max(w, h), but
        # verify this is intentional for non-square images.
        cv2.ellipse(img_originale, (x*W, y*W), (petit_axe*W/2, grand_axe*W/2), angle*180, 0., 360., rouge, 2)

        # Convert to millimetres.
        x=float(x*W*mm_pixel)
        y=float(y*W*mm_pixel)
        axis_x=float(grand_axe*W*mm_pixel/2)
        axis_y=float(petit_axe*W*mm_pixel/2)

        # Map the normalised angle back to degrees, recentered so that 0.5
        # corresponds to 0 degrees.
        r=180.
        r_2=r/2
        if angle>=0.5:
            angle=angle*r-r_2
        else:
            angle=angle*r+r_2

        # Ellipse perimeter approximation: pi * sqrt(2 * (a^2 + b^2)).
        HC=np.pi*np.sqrt(2*(axis_x**2+axis_y**2))

        cv2.putText(img_originale, "HC: {:5.2f}mm".format(HC), (20, 30), cv2.FONT_HERSHEY_DUPLEX, 1, (0, 0, 255), 2)
        print("{},{:f},{:f},{:f},{:f},{:f} HC: {:f}mm".format(ligne[0], x, y, axis_x, axis_y, angle, HC))
        if len(ligne)==3:
            # Third CSV column, when present: the reference HC in millimetres.
            cv2.putText(img_originale, "HC: {:5.2f}mm".format(float(ligne[2])), (20, 60), cv2.FONT_HERSHEY_DUPLEX, 1, (0, 255, 0), 2)
            print("HC: {}mm prediction: {:f}mm".format(ligne[2], HC))

        cv2.imshow("Image originale", img_originale)
        cv2.imshow("Inference", img2)

        # 'q' quits, any other key advances to the next image.
        key=cv2.waitKey()&0xFF
        if key==ord('q'):
            quit()
|
||||
92
Tensorflow/concours_foetus/model.py
Normal file
92
Tensorflow/concours_foetus/model.py
Normal file
@@ -0,0 +1,92 @@
|
||||
import tensorflow as tf
|
||||
from tensorflow.keras import layers, models
|
||||
import config
|
||||
|
||||
def block_resnet(input, filters, kernel_size, reduce, dropout=0.):
    """Residual block: Conv+ReLU (+Dropout) -> Conv, added to a shortcut.

    Args:
        input: input tensor.
        filters: filter count for both convolutions.
        kernel_size: convolution kernel size.
        reduce: when True, the second convolution and the shortcut use
            stride 2, halving the spatial resolution.
        dropout: dropout rate; 0 disables the two Dropout layers.

    Returns:
        Output tensor after add + ReLU + BatchNormalization.
    """
    strides=2 if reduce else 1
    result=layers.Conv2D(filters, kernel_size, strides=1, padding='SAME', activation='relu')(input)
    # BUG FIX: the original tested "dropout is not 0.", an identity check
    # against a float literal whose outcome is implementation-defined;
    # compare by value instead.
    if dropout != 0.:
        result=layers.Dropout(dropout)(result)
    result=layers.Conv2D(filters, kernel_size, strides=strides, padding='SAME')(result)

    # Shortcut: identity when both shape and resolution are preserved,
    # otherwise a 1x1 convolution matching filters and stride (the original's
    # four branches collapse to this single rule — two of them were identical).
    if input.shape[-1]==filters and not reduce:
        shortcut=input
    else:
        shortcut=layers.Conv2D(filters, 1, strides=strides, padding='SAME')(input)

    result=layers.add([result, shortcut])
    if dropout != 0.:
        result=layers.Dropout(dropout)(result)
    result=layers.Activation('relu')(result)
    result=layers.BatchNormalization()(result)
    return result
||||
|
||||
def model(nbr):
    """Build the ellipse-regression ResNet.

    Args:
        nbr: base filter count; the four stages use 2/4/8/16 times this value.

    Returns:
        A Keras Model mapping a (config.hauteur, config.largeur, 1) float
        image to 5 sigmoid outputs: x, y, major axis, minor axis (normalised
        by config.norm) and angle/180.
    """
    # BUG FIX: the original declared Input(shape=(largeur, hauteur, 1)) while
    # train.py/images.py reshape the data to (hauteur, largeur, 1); with
    # largeur != hauteur (200 vs 135) the model could not accept the batches.
    entree=layers.Input(shape=(config.hauteur, config.largeur, 1), dtype='float32')

    result=entree
    # One tuple per stage: (filter multiplier, non-reducing blocks, dropout,
    # whether the stage ends with a stride-2 downsampling block). This loop
    # reproduces the original's 44 hand-written block_resnet calls exactly.
    for mult, repeats, dropout, downsample in ((2, 3, 0.3, True),
                                               (4, 6, 0.4, True),
                                               (8, 12, 0.4, True),
                                               (16, 24, 0.5, False)):
        for _ in range(repeats):
            result=block_resnet(result, mult*nbr, 3, False, dropout)
        if downsample:
            result=block_resnet(result, mult*nbr, 3, True, dropout)

    result=layers.AveragePooling2D()(result)
    result=layers.Flatten()(result)
    # Sigmoid keeps all 5 outputs in [0, 1], matching the normalised labels.
    sortie=layers.Dense(5, activation='sigmoid')(result)

    model=models.Model(inputs=entree, outputs=sortie)
    return model
|
||||
|
||||
45
Tensorflow/concours_foetus/result.py
Normal file
45
Tensorflow/concours_foetus/result.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Generate the competition submission CSV on stdout: one line of predicted
ellipse parameters (in millimetres, angle in radians) per test image."""
import tensorflow as tf
import sys
import time
import cv2
import numpy as np
import math
import common
import config
import model
import csv


model=model.model(config.input_model)


checkpoint=tf.train.Checkpoint(model=model)
checkpoint.restore(tf.train.latest_checkpoint("./training/"))


print("filename,center_x_mm,center_y_mm,semi_axes_a_mm,semi_axes_b_mm,angle_rad")
with open("test_set.csv", newline='') as csvfile:
    lignes=csv.reader(csvfile, delimiter=',')
    for ligne in lignes:
        img=cv2.imread(config.dir_images_test+ligne[0])
        if img is None:
            continue
        # Column 1 of the CSV: millimetres per pixel of the original image.
        mm_pixel=float(ligne[1])
        H, W, C=img.shape
        img=cv2.resize(img, (config.largeur, config.hauteur))
        img=np.array(img, dtype=np.float32)/255
        img=np.expand_dims(img[:, :, 0], axis=-1)
        predictions=model(np.array([img]))
        # Outputs are sigmoid-normalised: x, y, axes by config.norm; angle by 180.
        x, y, grand_axe, petit_axe, angle=predictions[0]

        # Convert to millimetres (full axes halved to semi-axes).
        # NOTE(review): y and both axes are scaled by W (width), not H —
        # consistent with inference.py, but verify for non-square images.
        x=float(x*W*mm_pixel)
        y=float(y*W*mm_pixel)
        axis_x=float(grand_axe*W*mm_pixel/2)
        axis_y=float(petit_axe*W*mm_pixel/2)

        # Map the normalised angle to radians, recentered so that 0.5
        # corresponds to 0 rad (the submission format wants angle_rad).
        r=np.pi
        r_2=r/2
        if angle>=0.5:
            angle=angle*r-r_2
        else:
            angle=angle*r+r_2

        print("{},{:f},{:f},{:f},{:f},{:f}".format(ligne[0], x, y, axis_x, axis_y, angle))
||||
|
||||
75
Tensorflow/concours_foetus/train.py
Normal file
75
Tensorflow/concours_foetus/train.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""Train the ellipse-regression model on the augmented training set."""
import tensorflow as tf
import sys
import time
import cv2
import numpy as np
import common
import config
import model


# Build the augmented dataset, normalise pixels to [0, 1] and shuffle.
images, labels=common.prepare_data('training_set.csv')
images=np.array(images, dtype=np.float32)/255
labels=np.array(labels, dtype=np.float32)
index=np.random.permutation(len(images))
images=images[index].reshape(-1, config.hauteur, config.largeur, 1)
labels=labels[index]

print("Nbr images:", len(images))

train_ds=tf.data.Dataset.from_tensor_slices((images, labels)).batch(config.batch_size)

# Free the large numpy arrays; the tf.data pipeline holds its own copy.
del images
del labels
||||
|
||||
def my_loss(labels, preds):
    """Weighted squared-error loss over the 5 ellipse parameters.

    Per sample: 5 * SSE(center x,y) + 5 * SSE(axes) + 1 * SE(angle).
    Returns one loss value per batch element.
    """
    lambda_xy=5
    lambda_Aa=5
    lambda_angle=1

    # Square all residuals at once, then split by parameter group.
    squared_error=tf.math.square(preds-labels)

    loss_xy=tf.reduce_sum(squared_error[:, 0:2], axis=-1)      # center (x, y)
    loss_Aa=tf.reduce_sum(squared_error[:, 2:4], axis=-1)      # axes (MA, ma)
    loss_angle=squared_error[:, 4]                             # orientation

    return lambda_xy*loss_xy+lambda_Aa*loss_Aa+lambda_angle*loss_angle
||||
|
||||
model=model.model(config.input_model)
|
||||
|
||||
@tf.function
def train_step(images, labels):
    """Run one optimisation step on a batch (graph-compiled by tf.function).

    Uses the module-level `model`, `optimizer` and `train_loss` metric.
    """
    with tf.GradientTape() as tape:
        predictions=model(images)
        loss=my_loss(labels, predictions)
    gradients=tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    # Fold this batch's loss into the running mean used for reporting.
    train_loss(loss)
||||
|
||||
def train(train_ds, nbr_entrainement):
    """Train for *nbr_entrainement* epochs, printing loss and time per epoch.

    Saves a checkpoint every 10 epochs (epochs 1, 11, 21, ...). Uses the
    module-level `train_step`, `train_loss` and `checkpoint` objects.
    """
    for entrainement in range(nbr_entrainement):
        start=time.time()
        for images, labels in train_ds:
            train_step(images, labels)
        message='Entrainement {:04d}: loss: {:6.4f}, temps: {:7.4f}'
        print(message.format(entrainement+1,
                             train_loss.result(),
                             time.time()-start))
        # BUG FIX: reset the Mean metric so each epoch reports its own loss;
        # the original never reset it, so the printed value was a running
        # average over all epochs since the start of the run.
        train_loss.reset_states()
        if not entrainement%10:
            checkpoint.save(file_prefix="./training/")
||||
|
||||
# Optimiser, reporting metric and checkpoint used by train_step()/train().
optimizer=tf.keras.optimizers.Adam(learning_rate=1E-4)
train_loss=tf.keras.metrics.Mean()

# FIX: the original built tf.train.Checkpoint(model=model) twice in a row;
# a single instance is enough to both restore and save.
checkpoint=tf.train.Checkpoint(model=model)
# Resume from the latest checkpoint when one exists (no-op on a fresh run).
checkpoint.restore(tf.train.latest_checkpoint("./training/"))

train(train_ds, 60)
checkpoint.save(file_prefix="./training/")
|
||||
Reference in New Issue
Block a user