p4ul committed on
Commit
039ae87
·
1 Parent(s): 11d9034

app.py update

Files changed (1)
  1. app.py +343 -281
app.py CHANGED
@@ -33,284 +33,346 @@ from plasticorigins.tracking.track_video import track_video
 from plasticorigins.tracking.trackers import get_tracker


- logger = logging.getLogger()
- logger.setLevel(logging.DEBUG)
- ch = logging.StreamHandler()
- ch.setLevel(logging.DEBUG)
- formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
- ch.setFormatter(formatter)
- logger.addHandler(ch)
-
- class DotDict(dict):
-     """dot.notation access to dictionary attributes"""
-     __getattr__ = dict.get
-     __setattr__ = dict.__setitem__
-     __delattr__ = dict.__delitem__
-
- id_categories = {
-     0: 'Fragment',       # 'Sheet / tarp / plastic bag / fragment'
-     1: 'Insulating',     # 'Insulating material'
-     2: 'Bottle',         # 'Bottle-shaped'
-     3: 'Can',            # 'Can-shaped'
-     4: 'Drum',
-     5: 'Packaging',      # 'Other packaging'
-     6: 'Tire',
-     7: 'Fishing net',    # 'Fishing net / cord'
-     8: 'Easily namable',
-     9: 'Unclear'
- }
-
-
- config_track = DotDict({
-     "yolo_conf_thrld": 0.35,
-     "yolo_iou_thrld": 0.5,
-
-     "confidence_threshold": 0.004,  # for the tracking part
-     "detection_threshold": 0.3,     # for centernet
-     "downsampling_factor": 4,
-     "noise_covariances_path": "data/tracking_parameters",
-     "output_shape": (960, 544),
-     "size": 768,
-     "skip_frames": 3,
-     "arch": "mobilenet_v3_small",
-     "device": "cpu",
-     "detection_batch_size": 1,
-     "display": 0,
-     "kappa": 4,
-     "tau": 3,
-     "max_length": 240,
-     "downscale_output": 2
- })
-
-
- logger.info('---Yolo model...')
- # YOLO emits noisy warnings, so we silence them via an env variable
- os.environ["VERBOSE"] = "False"
- URL_MODEL = "https://github.com/surfriderfoundationeurope/IA_Pau/releases/download/v0.1/yolov5.pt"
- FILE_MODEL = "yolov5.pt"
- model_path = download_from_url(URL_MODEL, FILE_MODEL, "./models", logger)
- model_yolo = load_model(model_path, config_track.device, config_track.yolo_conf_thrld, config_track.yolo_iou_thrld)
-
-
- logger.info('---Centernet model...')
- URL_MODEL = "https://partage.imt.fr/index.php/s/sJi22N6gedN6T4q/download"
- FILE_MODEL = "mobilenet_v3_pretrained.pth"
- model_path = download_from_url(URL_MODEL, FILE_MODEL, "./models", logger)
- model = get_mobilenet_v3_small(num_layers=0, heads={'hm': 1}, head_conv=256)
- checkpoint = torch.load(model_path, map_location="cpu")
- model.load_state_dict(checkpoint['model'], strict=True)
-
-
- URL_DEMO1 = "https://etlplasticostorageacc.blob.core.windows.net/surfnetbenchmark/video_niv15.mp4"
- FILE_DEMO1 = "video_niv15.mp4"
- download_from_url(URL_DEMO1, FILE_DEMO1, "./data/", logger)
- video1_path = op.join("./data", FILE_DEMO1)
- URL_DEMO2 = "https://etlplasticostorageacc.blob.core.windows.net/surfnetbenchmark/video_midouze15.mp4"
- FILE_DEMO2 = "video_midouze15.mp4"
- download_from_url(URL_DEMO2, FILE_DEMO2, "./data/", logger)
- video2_path = op.join("./data", FILE_DEMO2)
- URL_DEMO3 = "https://etlplasticostorageacc.blob.core.windows.net/surfnetbenchmark/video_antoine15.mp4"
- FILE_DEMO3 = "video_antoine15.mp4"
- download_from_url(URL_DEMO3, FILE_DEMO3, "./data/", logger)
- video3_path = op.join("./data", FILE_DEMO3)
- JSON_FILE_PATH = "data/"
-
-
- labels2icons = load_trash_icons("./data/icons/")
-
- def track(args):
-     device = torch.device("cpu")
-
-     engine = get_tracker('EKF')
-
-     # select the detector according to the requested model type
-     detector = None
-     if args.model_type == "yolo":
-         logger.info("---Using Yolo")
-         detector = lambda frame: predict_yolo(model_yolo, frame, size=config_track.size, augment=False)
-     elif args.model_type == "centernet":
-         logger.info("---Using Centernet")
-         detector = lambda frame: detect(frame, threshold=args.detection_threshold, model=model)
-
-     transition_variance = np.load(op.join(args.noise_covariances_path, 'transition_variance.npy'))
-     observation_variance = np.load(op.join(args.noise_covariances_path, 'observation_variance.npy'))
-
-     logger.info(f'---Processing {args.video_path}')
-     reader = IterableFrameReader(video_filename=args.video_path,
-                                  skip_frames=args.skip_frames,
-                                  output_shape=args.output_shape,
-                                  progress_bar=True,
-                                  preload=False,
-                                  max_frame=args.max_length)
-
-     input_shape = reader.input_shape
-     output_shape = reader.output_shape
-     ratio_y = input_shape[0] / (output_shape[0] // args.downsampling_factor)
-     ratio_x = input_shape[1] / (output_shape[1] // args.downsampling_factor)
-
-     detections = []
-     logger.info('---Detecting...')
-     if args.model_type == "yolo":
-         with warnings.catch_warnings():
-             warnings.filterwarnings("ignore")
-             for frame in reader:
-                 detections.append(detector(frame))
-     elif args.model_type == "centernet":
-         detections = get_detections_for_video(reader, detector, batch_size=args.detection_batch_size, device=device)
-
-     logger.info('---Tracking...')
-     display = None
-     results = track_video(reader, iter(detections), args, engine, transition_variance, observation_variance, display, is_yolo=args.model_type == "yolo")
-     reader.video.release()
-
-     # store unfiltered results
-     datestr = datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')
-     output_filename = op.splitext(args.video_path)[0] + "_" + datestr + '_unfiltered.txt'
-     write_tracking_results_to_file(results, ratio_x=ratio_x, ratio_y=ratio_y, output_filename=output_filename)
-     logger.info('---Filtering...')
-
-     # read the results back from the file
-     results = read_tracking_results(output_filename)
-     filtered_results = filter_tracks(results, config_track.kappa, config_track.tau)
-     # store filtered results
-     output_filename = op.splitext(args.video_path)[0] + "_" + datestr + '_filtered.txt'
-     write_tracking_results_to_file(filtered_results, ratio_x=ratio_x, ratio_y=ratio_y, output_filename=output_filename)
-
-     return filtered_results
-
-
- def run_model(video_path, model_type, seconds, skip, tau, kappa, gps_file):
-     logger.info('---video filename: ' + video_path)
-
-     # launch the tracking
-     config_track.video_path = video_path
-     config_track.model_type = model_type
-     config_track.skip_frames = int(skip)
-     config_track.tau = int(tau)
-     config_track.kappa = int(kappa)
-     config_track.max_length = int(seconds) * 24  # ~24 fps assumed
-
-     out_folder = create_unique_folder("/tmp/", "output")
-     output_path = op.join(out_folder, "video.mp4")
-     filtered_results = track(config_track)
-
-     # postprocess
-     logger.info('---Postprocessing...')
-     output_json_path = op.join(out_folder, "output.json")
-     output_json = postprocess_for_api(filtered_results, id_categories)
-     with open(output_json_path, 'w') as f_out:
-         json.dump(output_json, f_out)
-
-     # build the video output
-     logger.info('---Generating new video...')
-     reader = IterableFrameReader(video_filename=config_track.video_path,
-                                  skip_frames=0,
-                                  progress_bar=True,
-                                  preload=False,
-                                  max_frame=config_track.max_length)
-
-     # get GPS data
-     video_duration = reader.total_num_frames / reader.fps
-     gps_data = get_filled_gps(gps_file, video_duration)
-
-     # generate the annotated video
-     generate_video_with_annotations(reader, output_json, output_path,
-                                     config_track.skip_frames, config_track.max_length,
-                                     config_track.downscale_output, logger, gps_data=gps_data,
-                                     labels2icons=labels2icons)
-     output_label = count_objects(output_json, id_categories)
-
-     # build the plastic map
-     map_frame = None  # default value when no GPS file is provided
-     if gps_data is not None:
-         logger.info('---Creating Plastic Map...')
-         # get trash predictions
-         with open(output_json_path) as json_file:
-             predictions = json.load(json_file)
-         trash_df = get_df_prediction(predictions, reader.fps)
-         if len(trash_df) != 0:
-             # join trash predictions with GPS data
-             trash_gps_df = get_trash_gps_df(trash_df, gps_data)
-             trash_gps_geo_df = get_trash_gps_geo_df(trash_gps_df)
-             # create the map
-             center_lat = trash_gps_df.iloc[0]['Latitude']
-             center_long = trash_gps_df.iloc[0]['Longitude']
-             map_path = get_plastic_map(center_lat, center_long, trash_gps_geo_df, out_folder)
-             html_content = codecs.open(map_path, 'r')
-             map_html = html_content.read()
-             map_frame = f"""<iframe style="width: 100%; height: 480px" name="result" allow="midi; geolocation; microphone; camera;
-                 display-capture; encrypted-media;" sandbox="allow-modals allow-forms
-                 allow-scripts allow-same-origin allow-popups
-                 allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
-                 allowpaymentrequest="" frameborder="0" srcdoc='{map_html}'></iframe>"""
-
-     logger.info('---Surfnet End processing...')
-
-     return output_path, map_frame, output_label, output_json_path
-
-
- def get_filled_gps(file_obj, video_duration) -> list:
-     """Get a filled GPS point list from a Plastic Origins mobile GPS JSON track
-     Args:
-         file_obj: a file object from the gradio File input
-         video_duration: in seconds
-     Returns:
-         gps_data (list): the filled GPS data as a list
-     """
-     if file_obj is not None:
-         json_data = parse_json(file_obj)
-         json_data_list = get_json_gps_list(json_data)
-         gps_data = fill_gps(json_data_list, video_duration)
-         return gps_data
-     else:
-         return None
-
-
- def get_plastic_map(center_lat, center_long, trash_gps_gdf, out_folder) -> str:
-     """Get the map with the plastic trash detections
-     Args:
-         center_lat (float): latitude to center the map on
-         center_long (float): longitude to center the map on
-         trash_gps_gdf (GeoDataFrame): trash & GPS geo dataframe
-         out_folder (str): folder in which to save the HTML map
-     Returns:
-         map_html_path (str): full path to the HTML map
-     """
-     m = folium.Map([center_lat, center_long], zoom_start=16)
-     locs = zip(trash_gps_gdf.geometry.y, trash_gps_gdf.geometry.x)
-     labels = list(trash_gps_gdf['label'])
-     for i, location in enumerate(locs):
-         folium.CircleMarker(location=location).add_child(folium.Popup(labels[i])).add_to(m)
-     map_html_path = op.join(out_folder, "plasticmap.html")
-     m.save(map_html_path)
-     return map_html_path
-
-
- video_in = gr.inputs.Video(type="mp4", source="upload", label="Video Upload", optional=False)
- model_type = gr.inputs.Dropdown(choices=["centernet", "yolo"], type="value", default="yolo", label="model")
- skip_slider = gr.inputs.Slider(minimum=0, maximum=15, step=1, default=3, label="skip frames")
- tau_slider = gr.inputs.Slider(minimum=1, maximum=7, step=1, default=3, label="tau")
- kappa_slider = gr.inputs.Slider(minimum=1, maximum=7, step=1, default=4, label="kappa")
- seconds_num = gr.inputs.Number(default=10, label="seconds")
- gps_in = gr.inputs.File(type="file", label="GPS Upload", optional=True)
-
-
- gr.Interface(fn=run_model, inputs=[video_in, model_type, seconds_num, skip_slider, tau_slider, kappa_slider, gps_in],
-              outputs=["playable_video", "html", "label", "file"],
-              title="Surfnet demo",
-              examples=[[video1_path, "yolo", 10, 3, 3, 4, JSON_FILE_PATH + "gavepau.json"],
-                        [video2_path, "yolo", 10, 3, 3, 4, JSON_FILE_PATH + "midouze.json"],
-                        [video3_path, "yolo", 10, 3, 3, 4, JSON_FILE_PATH + "gavepau.json"]],
-              description="Upload a video and, optionally, a GPS file, and you'll get plastic detection on the river.",
-              theme="huggingface",
-              allow_screenshot=False, allow_flagging="never").launch(debug=True, enable_queue=True)
+
+
+ demo = gr.Blocks()
+ title = "Surfnet AI Demo"
+ with demo:
+     with gr.Tabs():
+         with gr.TabItem("The Project"):
+             gr.HTML("""<!DOCTYPE html>
+ <html lang="en-us">
+ <head>
+     <meta charset="utf-8">
+     <meta name="viewport" content="width=device-width, initial-scale=1.0, user-scalable=no">
+     <title>Surfnet AI Demo</title>
+ </head>
+
+ <body>
+     <h1 style="text-align: center;"><b>Welcome to the Surfnet demo, the algo that tracks Plastic Pollution</b> 🌊</h1>
+     <p style="text-align: center;">We all dream about swimming in clear blue waters and walking
+     barefoot on a beautiful white-sand beach. But our dream is
+     threatened. Plastics are invading every corner of the earth,
+     from remote alpine lakes to the deepest oceanic trench.
+     Thankfully, there are many things we can do. 🤝
+     <b>Plastic Origins</b>, a citizen science project from <a href="https://surfrider.eu" target="_blank"><b><u>Surfrider Europe</u></b></a>
+     that uses artificial intelligence to map river plastic pollution, is one of them.
+     This demo is here for you to test the AI model we use to detect and count litter items on riverbanks.
+     </p>
+     <br>
+     <p style="text-align: center;">
+     Read more on <a href="https://plasticorigins.eu" target="_blank"><b><u>www.plasticorigins.eu</u></b></a>
+     <br>
+     💻 Join the dev team on <a href="https://github.com/surfriderfoundationeurope/The-Plastic-Origins-Project" target="_blank"><b><u>Github</u></b></a>
+     <br>
+     🏷️ Help us label images on <a href="https://www.trashroulette.com/#/" target="_blank"><b><u>www.trashroulette.com</u></b></a>
+     <br>
+     <br>
+     </p>
+     <p style="text-align: center;">
+     📧 Contact:
+     <br>
+     <a href="mailto:[email protected]"><b><u>[email protected]</u></b></a>
+     </p>
+ </body>
+ </html>""")
+         with gr.TabItem("Surfnet AI"):
+             gr.HTML("""<!DOCTYPE html>
+ <html lang="en">
+ <head>
+     <title>Left Side Panel</title>
+     <meta charset="utf-8">
+ </head>
+
+ <body>
+     <p style="text-align: center;">
+     <b>Surfnet</b> is an AI model that detects trash on riverbanks.
+     We use it to map river plastic pollution and act to reduce the introduction of litter into the environment.
+     Developed & maintained by a bunch of amazing volunteers from the NGO Surfrider Foundation Europe.
+     </p>
+ </body>
+ </html>""")
+
+
+ logger = logging.getLogger()
+ logger.setLevel(logging.DEBUG)
+ ch = logging.StreamHandler()
+ ch.setLevel(logging.DEBUG)
+ formatter = logging.Formatter('%(asctime)s - %(name)s - %(levelname)s - %(message)s')
+ ch.setFormatter(formatter)
+ logger.addHandler(ch)
+
+ class DotDict(dict):
+     """dot.notation access to dictionary attributes"""
+     __getattr__ = dict.get
+     __setattr__ = dict.__setitem__
+     __delattr__ = dict.__delitem__
+
+ id_categories = {
+     0: 'Fragment',       # 'Sheet / tarp / plastic bag / fragment'
+     1: 'Insulating',     # 'Insulating material'
+     2: 'Bottle',         # 'Bottle-shaped'
+     3: 'Can',            # 'Can-shaped'
+     4: 'Drum',
+     5: 'Packaging',      # 'Other packaging'
+     6: 'Tire',
+     7: 'Fishing net',    # 'Fishing net / cord'
+     8: 'Easily namable',
+     9: 'Unclear'
+ }
+
+
+ config_track = DotDict({
+     "yolo_conf_thrld": 0.35,
+     "yolo_iou_thrld": 0.5,
+
+     "confidence_threshold": 0.004,  # for the tracking part
+     "detection_threshold": 0.3,     # for centernet
+     "downsampling_factor": 4,
+     "noise_covariances_path": "data/tracking_parameters",
+     "output_shape": (960, 544),
+     "size": 768,
+     "skip_frames": 3,
+     "arch": "mobilenet_v3_small",
+     "device": "cpu",
+     "detection_batch_size": 1,
+     "display": 0,
+     "kappa": 4,
+     "tau": 3,
+     "max_length": 240,
+     "downscale_output": 2
+ })
+
+
+ logger.info('---Yolo model...')
+ # YOLO emits noisy warnings, so we silence them via an env variable
+ os.environ["VERBOSE"] = "False"
+ URL_MODEL = "https://github.com/surfriderfoundationeurope/IA_Pau/releases/download/v0.1/yolov5.pt"
+ FILE_MODEL = "yolov5.pt"
+ model_path = download_from_url(URL_MODEL, FILE_MODEL, "./models", logger)
+ model_yolo = load_model(model_path, config_track.device, config_track.yolo_conf_thrld, config_track.yolo_iou_thrld)
+
+
+ logger.info('---Centernet model...')
+ URL_MODEL = "https://partage.imt.fr/index.php/s/sJi22N6gedN6T4q/download"
+ FILE_MODEL = "mobilenet_v3_pretrained.pth"
+ model_path = download_from_url(URL_MODEL, FILE_MODEL, "./models", logger)
+ model = get_mobilenet_v3_small(num_layers=0, heads={'hm': 1}, head_conv=256)
+ checkpoint = torch.load(model_path, map_location="cpu")
+ model.load_state_dict(checkpoint['model'], strict=True)
+
+
+ URL_DEMO1 = "https://etlplasticostorageacc.blob.core.windows.net/surfnetbenchmark/video_niv15.mp4"
+ FILE_DEMO1 = "video_niv15.mp4"
+ download_from_url(URL_DEMO1, FILE_DEMO1, "./data/", logger)
+ video1_path = op.join("./data", FILE_DEMO1)
+ URL_DEMO2 = "https://etlplasticostorageacc.blob.core.windows.net/surfnetbenchmark/video_midouze15.mp4"
+ FILE_DEMO2 = "video_midouze15.mp4"
+ download_from_url(URL_DEMO2, FILE_DEMO2, "./data/", logger)
+ video2_path = op.join("./data", FILE_DEMO2)
+ URL_DEMO3 = "https://etlplasticostorageacc.blob.core.windows.net/surfnetbenchmark/video_antoine15.mp4"
+ FILE_DEMO3 = "video_antoine15.mp4"
+ download_from_url(URL_DEMO3, FILE_DEMO3, "./data/", logger)
+ video3_path = op.join("./data", FILE_DEMO3)
+ JSON_FILE_PATH = "data/"
+
+
+ labels2icons = load_trash_icons("./data/icons/")
+
+ def track(args):
+     device = torch.device("cpu")
+
+     engine = get_tracker('EKF')
+
+     # select the detector according to the requested model type
+     detector = None
+     if args.model_type == "yolo":
+         logger.info("---Using Yolo")
+         detector = lambda frame: predict_yolo(model_yolo, frame, size=config_track.size, augment=False)
+     elif args.model_type == "centernet":
+         logger.info("---Using Centernet")
+         detector = lambda frame: detect(frame, threshold=args.detection_threshold, model=model)
+
+     transition_variance = np.load(op.join(args.noise_covariances_path, 'transition_variance.npy'))
+     observation_variance = np.load(op.join(args.noise_covariances_path, 'observation_variance.npy'))
+
+     logger.info(f'---Processing {args.video_path}')
+     reader = IterableFrameReader(video_filename=args.video_path,
+                                  skip_frames=args.skip_frames,
+                                  output_shape=args.output_shape,
+                                  progress_bar=True,
+                                  preload=False,
+                                  max_frame=args.max_length)
+
+     input_shape = reader.input_shape
+     output_shape = reader.output_shape
+     ratio_y = input_shape[0] / (output_shape[0] // args.downsampling_factor)
+     ratio_x = input_shape[1] / (output_shape[1] // args.downsampling_factor)
+
+     detections = []
+     logger.info('---Detecting...')
+     if args.model_type == "yolo":
+         with warnings.catch_warnings():
+             warnings.filterwarnings("ignore")
+             for frame in reader:
+                 detections.append(detector(frame))
+     elif args.model_type == "centernet":
+         detections = get_detections_for_video(reader, detector, batch_size=args.detection_batch_size, device=device)
+
+     logger.info('---Tracking...')
+     display = None
+     results = track_video(reader, iter(detections), args, engine, transition_variance, observation_variance, display, is_yolo=args.model_type == "yolo")
+     reader.video.release()
+
+     # store unfiltered results
+     datestr = datetime.datetime.now().strftime('%Y%m%d%H%M%S%f')
+     output_filename = op.splitext(args.video_path)[0] + "_" + datestr + '_unfiltered.txt'
+     write_tracking_results_to_file(results, ratio_x=ratio_x, ratio_y=ratio_y, output_filename=output_filename)
+     logger.info('---Filtering...')
+
+     # read the results back from the file
+     results = read_tracking_results(output_filename)
+     filtered_results = filter_tracks(results, config_track.kappa, config_track.tau)
+     # store filtered results
+     output_filename = op.splitext(args.video_path)[0] + "_" + datestr + '_filtered.txt'
+     write_tracking_results_to_file(filtered_results, ratio_x=ratio_x, ratio_y=ratio_y, output_filename=output_filename)
+
+     return filtered_results
+
+
+ def run_model(video_path, model_type, seconds, skip, tau, kappa, gps_file):
+     logger.info('---video filename: ' + video_path)
+
+     # launch the tracking
+     config_track.video_path = video_path
+     config_track.model_type = model_type
+     config_track.skip_frames = int(skip)
+     config_track.tau = int(tau)
+     config_track.kappa = int(kappa)
+     config_track.max_length = int(seconds) * 24  # ~24 fps assumed
+
+     out_folder = create_unique_folder("/tmp/", "output")
+     output_path = op.join(out_folder, "video.mp4")
+     filtered_results = track(config_track)
+
+     # postprocess
+     logger.info('---Postprocessing...')
+     output_json_path = op.join(out_folder, "output.json")
+     output_json = postprocess_for_api(filtered_results, id_categories)
+     with open(output_json_path, 'w') as f_out:
+         json.dump(output_json, f_out)
+
+     # build the video output
+     logger.info('---Generating new video...')
+     reader = IterableFrameReader(video_filename=config_track.video_path,
+                                  skip_frames=0,
+                                  progress_bar=True,
+                                  preload=False,
+                                  max_frame=config_track.max_length)
+
+     # get GPS data
+     video_duration = reader.total_num_frames / reader.fps
+     gps_data = get_filled_gps(gps_file, video_duration)
+
+     # generate the annotated video
+     generate_video_with_annotations(reader, output_json, output_path,
+                                     config_track.skip_frames, config_track.max_length,
+                                     config_track.downscale_output, logger, gps_data=gps_data,
+                                     labels2icons=labels2icons)
+     output_label = count_objects(output_json, id_categories)
+
+     # build the plastic map
+     map_frame = None  # default value when no GPS file is provided
+     if gps_data is not None:
+         logger.info('---Creating Plastic Map...')
+         # get trash predictions
+         with open(output_json_path) as json_file:
+             predictions = json.load(json_file)
+         trash_df = get_df_prediction(predictions, reader.fps)
+         if len(trash_df) != 0:
+             # join trash predictions with GPS data
+             trash_gps_df = get_trash_gps_df(trash_df, gps_data)
+             trash_gps_geo_df = get_trash_gps_geo_df(trash_gps_df)
+             # create the map
+             center_lat = trash_gps_df.iloc[0]['Latitude']
+             center_long = trash_gps_df.iloc[0]['Longitude']
+             map_path = get_plastic_map(center_lat, center_long, trash_gps_geo_df, out_folder)
+             html_content = codecs.open(map_path, 'r')
+             map_html = html_content.read()
+             map_frame = f"""<iframe style="width: 100%; height: 480px" name="result" allow="midi; geolocation; microphone; camera;
+                 display-capture; encrypted-media;" sandbox="allow-modals allow-forms
+                 allow-scripts allow-same-origin allow-popups
+                 allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
+                 allowpaymentrequest="" frameborder="0" srcdoc='{map_html}'></iframe>"""
+
+     logger.info('---Surfnet End processing...')
+
+     return output_path, map_frame, output_label, output_json_path
+
+
+ def get_filled_gps(file_obj, video_duration) -> list:
+     """Get a filled GPS point list from a Plastic Origins mobile GPS JSON track
+     Args:
+         file_obj: a file object from the gradio File input
+         video_duration: in seconds
+     Returns:
+         gps_data (list): the filled GPS data as a list
+     """
+     if file_obj is not None:
+         json_data = parse_json(file_obj)
+         json_data_list = get_json_gps_list(json_data)
+         gps_data = fill_gps(json_data_list, video_duration)
+         return gps_data
+     else:
+         return None
+
+
+ def get_plastic_map(center_lat, center_long, trash_gps_gdf, out_folder) -> str:
+     """Get the map with the plastic trash detections
+     Args:
+         center_lat (float): latitude to center the map on
+         center_long (float): longitude to center the map on
+         trash_gps_gdf (GeoDataFrame): trash & GPS geo dataframe
+         out_folder (str): folder in which to save the HTML map
+     Returns:
+         map_html_path (str): full path to the HTML map
+     """
+     m = folium.Map([center_lat, center_long], zoom_start=16)
+     locs = zip(trash_gps_gdf.geometry.y, trash_gps_gdf.geometry.x)
+     labels = list(trash_gps_gdf['label'])
+     for i, location in enumerate(locs):
+         folium.CircleMarker(location=location).add_child(folium.Popup(labels[i])).add_to(m)
+     map_html_path = op.join(out_folder, "plasticmap.html")
+     m.save(map_html_path)
+     return map_html_path
+
+
+ video_in = gr.inputs.Video(type="mp4", source="upload", label="Video Upload", optional=False)
+ model_type = gr.inputs.Dropdown(choices=["centernet", "yolo"], type="value", default="yolo", label="model")
+ skip_slider = gr.inputs.Slider(minimum=0, maximum=15, step=1, default=3, label="skip frames")
+ tau_slider = gr.inputs.Slider(minimum=1, maximum=7, step=1, default=3, label="tau")
+ kappa_slider = gr.inputs.Slider(minimum=1, maximum=7, step=1, default=4, label="kappa")
+ seconds_num = gr.inputs.Number(default=10, label="seconds")
+ gps_in = gr.inputs.File(type="file", label="GPS Upload", optional=True)
+
+
+ gr.Interface(fn=run_model, inputs=[video_in, model_type, seconds_num, skip_slider, tau_slider, kappa_slider, gps_in],
+              outputs=["playable_video", "html", "label", "file"],
+              examples=[[video1_path, "yolo", 10, 3, 3, 4, JSON_FILE_PATH + "gavepau.json"],
+                        [video2_path, "yolo", 10, 3, 3, 4, JSON_FILE_PATH + "midouze.json"],
+                        [video3_path, "yolo", 10, 3, 3, 4, JSON_FILE_PATH + "gavepau.json"]],
+              description="Upload a video and, optionally, a GPS file, and you'll get plastic detection on the river.",
+              theme="huggingface",
+              allow_screenshot=False, allow_flagging="never")
+
+ demo.launch()
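
A note on the DotDict helper used throughout: it gives the tracking config attribute-style access while remaining a plain dict, and because __getattr__ is bound to dict.get, reading a missing key yields None instead of raising. A minimal standalone sketch (the values here are hypothetical):

    class DotDict(dict):
        """dot.notation access to dictionary attributes"""
        __getattr__ = dict.get           # missing keys read as None, not AttributeError
        __setattr__ = dict.__setitem__
        __delattr__ = dict.__delitem__

    cfg = DotDict({"skip_frames": 3, "tau": 3})
    cfg.kappa = 4                        # same as cfg["kappa"] = 4
    print(cfg.skip_frames, cfg.kappa)    # -> 3 4
    print(cfg.video_path)                # -> None: unset keys fail silently

That silent None is what lets run_model attach video_path and model_type to config_track lazily, though it also means a typoed key goes unnoticed.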
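
Two of the constants above are worth a worked example. run_model converts the requested duration with max_length = int(seconds) * 24, i.e. it assumes roughly 24 frames per second, so the default 10 s caps processing at 240 frames, which matches the max_length of 240 that config_track starts with. And in track, detections live on a grid downscaled by downsampling_factor: with an output dimension of 960 and a factor of 4, the grid dimension is 960 // 4 = 240, so an input dimension of 1920 pixels gives a rescaling ratio of 1920 / 240 = 8.0. ratio_x and ratio_y are exactly these factors, used to map tracked coordinates back to input pixels when writing results.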
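
Finally, get_plastic_map is a stock folium pattern: one CircleMarker per detection, with the class label attached as a popup. A self-contained sketch of the same pattern, with made-up coordinates standing in for trash_gps_geo_df:

    import folium

    # Hypothetical (lat, lon, label) detections; the app derives these from GPS-joined predictions.
    points = [(43.2951, -0.3708, "Bottle"), (43.2954, -0.3701, "Fragment")]

    m = folium.Map([points[0][0], points[0][1]], zoom_start=16)
    for lat, lon, label in points:
        folium.CircleMarker(location=(lat, lon)).add_child(folium.Popup(label)).add_to(m)
    m.save("plasticmap.html")  # open in a browser, or embed as the app does via <iframe srcdoc=...>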