sharktide commited on
Commit
8547fd7
·
verified ·
1 Parent(s): c37a6ea

Add custom_objects.py for tensorflowtools compatibility

Browse files
Files changed (1) hide show
  1. custom_objects.py +57 -0
custom_objects.py ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tensorflow as tf
2
+ from tensorflow.keras.saving import register_keras_serializable
3
+ from tensorflow.keras import layers, models, backend as K
4
+ import numpy as np
5
+
6
+ @register_keras_serializable()
7
+ def cold_temp_penalty(inputs):
8
+ temp = inputs[:, 0]
9
+ penalty = tf.where(
10
+ temp > 295.0,
11
+ 1.0,
12
+ tf.where(
13
+ temp < 290.0,
14
+ 0.0,
15
+ (temp - 290.0) / 5.0
16
+ )
17
+ )
18
+ return penalty[:, None]
19
+
20
+ @register_keras_serializable()
21
+ def fire_risk_booster(inputs):
22
+ temp = inputs[:, 0]
23
+ humidity = inputs[:, 1]
24
+ wind = inputs[:, 2]
25
+ veg = inputs[:, 3]
26
+
27
+ # Boost ranges
28
+ temp_boost = tf.sigmoid((temp - 305.0) * 1.2)
29
+ humidity_boost = tf.sigmoid((20.0 - humidity) * 0.5)
30
+ wind_boost = tf.sigmoid((wind - 15.0) * 0.8)
31
+ veg_boost = tf.sigmoid((veg - 70.0) * 0.5)
32
+
33
+ # Combine and scale
34
+ combined = temp_boost * humidity_boost * wind_boost * veg_boost
35
+ boost = 1.0 + 0.3 * combined # Up to 30% increase in fire score
36
+ return boost[:, None]
37
+
38
+ @register_keras_serializable()
39
+ def fire_suppression_mask(inputs):
40
+ temp = inputs[:, 0]
41
+ humidity = inputs[:, 1]
42
+ wind = inputs[:, 2]
43
+
44
+ # Suppress if warm but humid and still
45
+ temp_flag = tf.sigmoid((temp - 293.0) * 1.2)
46
+ humid_flag = tf.sigmoid((humidity - 50.0) * 0.4)
47
+ wind_flag = 1 - tf.sigmoid((wind - 5.0) * 0.8)
48
+
49
+ suppression = temp_flag * humid_flag * wind_flag
50
+ penalty = 1.0 - 0.3 * suppression # Max 30% suppression
51
+ return penalty[:, None]
52
+
53
+ CUSTOM_OBJECTS = {
54
+ "cold_temp_penalty": cold_temp_penalty,
55
+ "fire_risk_booster": fire_risk_booster,
56
+ "fire_suppression_mask": fire_suppression_mask
57
+ }