I noticed that the provided metric implementation makes many repeated calls inside Python loops, which makes the computation slow. I wrote a vectorized implementation that ran about 16x faster in my tests.
Note that both true_intervals and predicted_intervals must be NumPy arrays with one [start, end] row per interval.
# fast metric computation
import numpy as np
import pandas as pd
def bb_iou_array(boxes, new_box):
    """Match one reference interval against an array of candidate intervals.

    Despite the "bb" (bounding box) name, this works on 1-D intervals: each
    row of ``boxes`` is ``[start, end]`` and ``new_box`` is one
    ``[start, end]`` pair.

    The candidate with the largest overlap with ``new_box`` is selected
    (``argmax`` picks the first such row on ties). The statistics of that
    pairing are returned.

    Parameters
    ----------
    boxes : array-like
        2-D collection of candidate intervals, one ``[start, end]`` row each.
    new_box : sequence
        The reference interval ``[start, end]``.

    Returns
    -------
    dict
        The dict contains these keys:

        * ``intersection``, ``union`` and ``iou`` of the selected pair.
        * The selected row as ``interval``.
        * ``new_box`` as ``true``.
        * The absolute start offset ``diff_start``.
        * ``exp_diff_start = exp(-diff_start / tau)``, where ``tau`` is 5
          when the selected interval starts after the reference and 10
          otherwise.

        If ``boxes`` is empty, the result reports ``intersection=0``,
        ``interval=None``, ``diff_start=inf`` and ``exp_diff_start=0.0``.
    """
    boxes = np.asarray(boxes)
    if len(boxes) == 0:
        # No candidates: report a zero-overlap match instead of crashing on
        # argmax of an empty array. Callers that drop intersection == 0
        # rows will simply skip it.
        return dict(
            intersection=0,
            union=new_box[1] - new_box[0],
            iou=0.0,
            interval=None,
            true=new_box,
            diff_start=np.inf,
            exp_diff_start=0.0,
        )

    # Overlap length of every candidate with the reference (0 if disjoint).
    inter_start = np.maximum(boxes[:, 0], new_box[0])
    inter_end = np.minimum(boxes[:, 1], new_box[1])
    inter_width = np.maximum(inter_end - inter_start, 0)

    # Selection is by overlap length, not by IoU.
    idx_best = inter_width.argmax()
    interval = boxes[idx_best]
    intersection = inter_width[idx_best]
    union = (interval[1] - interval[0]) + (new_box[1] - new_box[0]) - intersection
    # Guard against 0/0 when both intervals have zero width.
    iou = intersection / union if union > 0 else 0.0

    true_start = new_box[0]
    predicted_start = interval[0]
    diff_start = abs(true_start - predicted_start)
    # Asymmetric decay: a prediction that starts after the truth decays with
    # tau=5, one that starts at or before it with tau=10.
    tau = 5 if predicted_start > true_start else 10
    exp_diff_start = np.exp(-diff_start / tau)

    return dict(
        intersection=intersection,
        union=union,
        iou=iou,
        interval=interval,
        true=new_box,
        diff_start=diff_start,
        exp_diff_start=exp_diff_start,
    )
def get_score(true_intervals, predicted_intervals, beta=0.5):
    """Score predicted intervals against ground-truth intervals.

    Each true interval is paired, via ``bb_iou_array``, with the predicted
    interval it overlaps most. Pairs with zero overlap are discarded. The
    remaining pairs ("assignations") feed three components:

    * ``IoU``: total intersection over total union of the kept pairs.
    * ``NegExpDiffStarts``: mean ``exp_diff_start`` of the kept pairs.
    * ``FBeta``: F-beta of precision ``kept / len(predicted_intervals)``
      and recall ``kept / len(true_intervals)``.

    The final score is ``(IoU + 2 * NegExpDiffStarts + 3 * FBeta) / 6``.

    Parameters
    ----------
    true_intervals : numpy.ndarray
        Ground-truth ``[start, end]`` rows.
    predicted_intervals : numpy.ndarray
        Predicted ``[start, end]`` rows. The array must be 2-D because
        ``bb_iou_array`` indexes it column-wise.
    beta : float, optional
        F-beta weight. The default of 0.5 reproduces the original metric.

    Returns
    -------
    float
        The combined score, or 0.0 when either input is empty.
    """
    n_true = len(true_intervals)
    n_pred = len(predicted_intervals)
    # With no true or no predicted intervals nothing can match and every
    # component is 0. Returning early also avoids argmax on an empty array
    # and filtering a DataFrame that has no columns.
    if n_true == 0 or n_pred == 0:
        return 0.0

    matches = pd.DataFrame(
        [bb_iou_array(predicted_intervals, t) for t in true_intervals]
    )
    assignations = matches[matches["intersection"] != 0].reset_index(drop=True)
    n_matched = len(assignations)

    total_intersection = assignations["intersection"].sum()
    total_union = assignations["union"].sum()
    iou = (
        total_intersection / total_union
        if total_union > 0 and total_intersection > 0
        else 0
    )
    neg_exp_diff_starts = (
        assignations["exp_diff_start"].mean() if n_matched > 0 else 0
    )

    # NOTE(review): several true intervals may select the same predicted
    # interval, so precision can exceed 1. Confirm this matches the
    # reference metric.
    recall = n_matched / n_true
    precision = n_matched / n_pred
    beta2 = beta * beta
    denom = beta2 * precision + recall
    # The original guard was a typo, "(Recall>0) | (Recall>0)". The real
    # precondition is a non-zero denominator.
    f_beta = (1 + beta2) * precision * recall / denom if denom > 0 else 0

    return (iou + 2 * neg_exp_diff_starts + 3 * f_beta) / 6