Вот моя попытка:
inputs = Input(shape=(config.N_FRAMES_IN_SEQUENCE, config.IMAGE_H, config.IMAGE_W, config.N_CHANNELS))
def cnn_model(inputs):
x = Conv2D(filters=32, kernel_size=(3,3), padding='same', activation='relu')(inputs)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Conv2D(filters=32, kernel_size=(3,3), padding='same', activation='relu')(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Conv2D(filters=64, kernel_size=(3,3), padding='same', activation='relu')(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Conv2D(filters=64, kernel_size=(3,3), padding='same', activation='relu')(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Conv2D(filters=128, kernel_size=(3,3), padding='same', activation='relu')(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
return x
x = TimeDistributed(cnn_model)(inputs)
Это дает следующую ошибку:
AttributeError: 'function' object has no attribute 'built'
Всего 1 ответ
Вы должны использовать слой Lambda
и обернуть свою функцию внутри него:
# cnn_model function the same way as you defined it ...
x = TimeDistributed(Lambda(cnn_model))(inputs)
В качестве альтернативы вы можете определить этот блок как модель, а затем применить к TimeDistributed
слой TimeDistributed
:
def cnn_model():
input_frame = Input(shape=(config.IMAGE_H, config.IMAGE_W, config.N_CHANNELS))
x = Conv2D(filters=32, kernel_size=(3,3), padding='same', activation='relu')(input_frame)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Conv2D(filters=32, kernel_size=(3,3), padding='same', activation='relu')(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Conv2D(filters=64, kernel_size=(3,3), padding='same', activation='relu')(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Conv2D(filters=64, kernel_size=(3,3), padding='same', activation='relu')(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
x = Conv2D(filters=128, kernel_size=(3,3), padding='same', activation='relu')(x)
x = MaxPooling2D(pool_size=(2, 2))(x)
model = Model(input_frame, x)
return model
inputs = Input(shape=(config.N_FRAMES_IN_SEQUENCE, config.IMAGE_H, config.IMAGE_W, config.N_CHANNELS))
x = TimeDistributed(cnn_model())(inputs)