-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathphind34_di-unet.py
49 lines (39 loc) · 1.75 KB
/
phind34_di-unet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import tensorflow as tf
from tensorflow.keras.layers import Conv2D, MaxPooling2D, BatchNormalization, ReLU, Concatenate, Conv2DTranspose
def encoder_block(input_tensor, filters, kernel_size=3, strides=1, padding='same'):
    """Apply two Conv2D -> BatchNormalization -> ReLU stages to ``input_tensor``.

    Args:
        input_tensor: Keras tensor to transform.
        filters: Number of output channels of each Conv2D layer.
        kernel_size: Convolution kernel size, shared by both Conv2D layers.
        strides: Convolution strides, applied to BOTH Conv2D layers.
        padding: Keras padding mode for both Conv2D layers.

    Returns:
        The Keras tensor produced by the second ReLU.
    """
    features = input_tensor
    # Two identical stages; fresh layers are created on each pass, so the
    # stages do not share weights.
    for _stage in range(2):
        features = Conv2D(filters, kernel_size, strides=strides, padding=padding)(features)
        features = BatchNormalization()(features)
        features = ReLU()(features)
    return features
def decoder_block(input_tensor, filters, kernel_size=3, strides=1, padding='same'):
    """Apply two Conv2DTranspose -> BatchNormalization -> ReLU stages.

    Args:
        input_tensor: Keras tensor to transform.
        filters: Number of output channels of each Conv2DTranspose layer.
        kernel_size: Kernel size shared by both Conv2DTranspose layers.
        strides: Strides applied to BOTH Conv2DTranspose layers.
        padding: Keras padding mode for both Conv2DTranspose layers.

    Returns:
        The Keras tensor produced by the second ReLU.
    """
    out = input_tensor
    for _ in range(2):
        # Each stage instantiates its own layers (no weight sharing) and
        # applies them in order: transposed conv, batch norm, activation.
        stage = (
            Conv2DTranspose(filters, kernel_size, strides=strides, padding=padding),
            BatchNormalization(),
            ReLU(),
        )
        for layer in stage:
            out = layer(out)
    return out
def build_model(input_shape=(128, 128, 23), split_channel=None):
    """Build a dual-branch encoder/decoder model with a sigmoid output map.

    The channel axis of the input is split into two contiguous groups:
    ``[0, split_channel)`` and ``[split_channel, channels)``. Each group goes
    through its own encoder branch. The two encoder outputs are concatenated
    and fed to two parallel decoder branches. The decoder outputs are
    concatenated and reduced to one channel by a 1x1 sigmoid convolution.

    The encoder/decoder blocks are called with their default ``strides=1``
    and ``padding='same'``, and no pooling or upsampling layers are used.
    Spatial resolution therefore stays constant throughout. There are also
    no skip connections from encoders to decoders.

    Args:
        input_shape: ``(height, width, channels)`` of one sample (no batch
            axis). ``channels`` must be a known integer >= 2.
        split_channel: Index on the channel axis where the input is split
            between the two encoder branches. Defaults to ``channels // 2``
            (for 23 channels: 11 to the first branch, 12 to the second).
            Must satisfy ``0 < split_channel < channels`` so that neither
            branch receives an empty tensor.

    Returns:
        An uncompiled ``tf.keras.Model`` whose output has shape
        ``(batch, height, width, 1)``.

    Raises:
        ValueError: If the channel count is unknown or < 2, or if
            ``split_channel`` is out of range.
    """
    num_channels = input_shape[-1]
    if num_channels is None or num_channels < 2:
        raise ValueError(
            f"input_shape must have at least 2 channels to feed both encoder "
            f"branches; got input_shape={input_shape!r}"
        )
    if split_channel is None:
        split_channel = num_channels // 2
    if not 0 < split_channel < num_channels:
        raise ValueError(
            f"split_channel must be in (0, {num_channels}); got {split_channel}"
        )

    inputs = tf.keras.Input(shape=input_shape)

    # Split the channel axis into two non-empty groups. The previous code
    # sliced at a hard-coded 23, which with a 23-channel input gave the
    # second branch zero channels.
    encoder1_input = inputs[:, :, :, :split_channel]
    encoder2_input = inputs[:, :, :, split_channel:]

    # Encoder pathways (independent weights per branch).
    encoder1 = encoder_block(encoder1_input, 64)
    encoder2 = encoder_block(encoder2_input, 64)

    # Fuse the two branches along the channel axis.
    bottleneck = Concatenate()([encoder1, encoder2])

    # Two parallel decoder pathways over the same fused features.
    decoder1 = decoder_block(bottleneck, 64)
    decoder2 = decoder_block(bottleneck, 64)

    # Merge decoders and project to a single-channel probability map.
    output = Concatenate()([decoder1, decoder2])
    output = Conv2D(1, 1, activation='sigmoid')(output)

    model = tf.keras.Model(inputs=inputs, outputs=output)
    return model
# Instantiate the model with its default (128, 128, 23) input shape and print
# a per-layer summary (runs whenever this module is executed or imported).
model = build_model(input_shape=(128, 128, 23))
model.summary()