Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -8,183 +8,172 @@ import io
|
|
| 8 |
import pickle
|
| 9 |
import random
|
| 10 |
|
| 11 |
-
def
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
)
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
#
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
#
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
#
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
|
| 39 |
-
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
# if (mod1 == "Image"):
|
| 47 |
-
# query = id_to_image_emb_dict[id]
|
| 48 |
-
# elif(mod1 == "DNA"):
|
| 49 |
-
# query = id_to_dna_emb_dict[id]
|
| 50 |
-
# query = query.astype(np.float32)
|
| 51 |
-
# D, I = index.search(query, num_neighbors)
|
| 52 |
-
|
| 53 |
-
# id_list = []
|
| 54 |
-
# for indx in I[0]:
|
| 55 |
-
# id = indx_to_id_dict[indx]
|
| 56 |
-
# id_list.append(id)
|
| 57 |
|
| 58 |
-
#
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
| 71 |
-
#
|
| 72 |
-
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
|
| 84 |
-
|
| 85 |
-
|
| 86 |
-
|
| 87 |
-
|
| 88 |
-
|
| 89 |
-
|
| 90 |
-
|
| 91 |
-
|
| 92 |
-
|
| 93 |
-
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
| 97 |
-
|
| 98 |
-
#
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
| 112 |
-
#
|
| 113 |
-
|
| 114 |
-
|
| 115 |
-
|
| 116 |
-
#
|
| 117 |
-
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
#
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
-
|
| 138 |
-
|
| 139 |
-
|
| 140 |
-
|
| 141 |
-
|
| 142 |
-
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
| 147 |
-
|
| 148 |
-
|
| 149 |
-
|
| 150 |
-
|
| 151 |
-
|
| 152 |
-
|
| 153 |
-
|
| 154 |
-
|
| 155 |
-
|
| 156 |
-
|
| 157 |
-
|
| 158 |
-
|
| 159 |
-
|
| 160 |
-
|
| 161 |
-
|
| 162 |
-
|
| 163 |
-
|
| 164 |
-
|
| 165 |
-
|
| 166 |
-
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
|
| 177 |
-
|
| 178 |
-
|
| 179 |
-
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
|
|
|
| 8 |
import pickle
|
| 9 |
import random
|
| 10 |
|
| 11 |
+
def get_image(image1, image2, dataset_image_mask, processid_to_index, idx):
    """Decode and return the stored image for dataset row ``idx``.

    The encoded images are split across two containers: rows below 162834
    are read from ``image1`` and all later rows from ``image2`` (re-based so
    that row 162834 maps to ``image2[0]``). Each entry is cast to ``uint8``
    and truncated to ``dataset_image_mask[idx]`` elements before being handed
    to ``Image.open`` as an in-memory byte stream.

    Args:
        image1: First shard of encoded images, indexed by row.
        image2: Second shard, indexed by ``idx - 162834``.
        dataset_image_mask: Per-row count of leading elements to keep.
        processid_to_index: Unused; kept so existing callers keep working.
        idx: Global dataset row index.

    Returns:
        The object returned by ``Image.open`` for the truncated buffer.
    """
    # Row count of the first shard -- presumably len(image1); TODO confirm
    # before deriving it from the data instead of hard-coding it.
    first_shard_rows = 162834

    # Plain if/else so the buffer is always bound (the original elif could
    # fall through with no assignment).
    if idx < first_shard_rows:
        padded = image1[idx]
    else:
        padded = image2[idx - first_shard_rows]
    padded = padded.astype(np.uint8)

    # Keep only the leading enc_length elements; the rest is treated as padding.
    enc_length = dataset_image_mask[idx]
    encoded = padded[:enc_length]
    return Image.open(io.BytesIO(encoded))
|
| 20 |
+
|
| 21 |
+
def searchEmbeddings(id, mod1, mod2):
    """Look up the 10 nearest neighbours of a sample and gather their data.

    The sample's embedding in modality ``mod1`` is used as the query against
    the FAISS index file for modality ``mod2``. For the query sample and each
    neighbour, the stored image and a taxonomy string are collected.

    Args:
        id: Sample ID to search for (name kept for caller compatibility
            even though it shadows the builtin).
        mod1: Modality of the query embedding, ``"DNA"`` or ``"Image"``.
        mod2: Modality of the index to search, ``"DNA"`` or ``"Image"``.

    Returns:
        A 23-tuple: the list of neighbour IDs, then 11 images (query sample
        first, followed by neighbours in ranked order), then 11 taxonomy
        strings in the same order.

    Raises:
        gr.Error: If either modality is not selected/recognised, or the
            sample ID is not present in the lookup tables.
    """
    num_neighbors = 10
    index_files = {"Image": "image_index.index", "DNA": "dna_index.index"}
    query_embeddings = {"Image": id_to_image_emb_dict, "DNA": id_to_dna_emb_dict}

    # An unselected radio button arrives as None; fail with a readable
    # message instead of an UnboundLocalError or a search on an empty index.
    if mod1 not in query_embeddings or mod2 not in index_files:
        raise gr.Error(
            'Select "DNA" or "Image" for both "Search From:" and "Search To:".'
        )

    sample_id = (id or "").strip()
    try:
        original_indx = processid_to_index[sample_id]
        query = query_embeddings[mod1][sample_id]
    except KeyError:
        raise gr.Error(f"Sample ID {sample_id!r} was not found.") from None

    # The index is read from disk on every call, as in the original.
    # NOTE(review): caching it would speed up searches but keeps it resident
    # in memory -- decide based on the host's RAM budget.
    index = faiss.read_index(index_files[mod2])

    # FAISS expects a 2-D float32 batch; atleast_2d is a no-op if the stored
    # embedding is already (1, dim) -- TODO confirm the stored shape.
    query = np.atleast_2d(query).astype(np.float32)
    _, I = index.search(query, num_neighbors)
    neighbor_rows = I[0]

    id_list = [indx_to_id_dict[row] for row in neighbor_rows]

    # Query sample first, then neighbours in the order the index returned them.
    rows = [original_indx, *neighbor_rows]
    images = [
        get_image(dataset_image1, dataset_image2, dataset_image_mask,
                  processid_to_index, row)
        for row in rows
    ]
    taxa = [getTax(row) for row in rows]

    return (id_list, *images, *taxa)
|
| 74 |
+
|
| 75 |
+
def getRandID():
    """Pick a random dataset row and return its sample ID with the row index.

    Returns:
        A ``(sample_id, row)`` tuple, where ``row`` is drawn by
        ``random.randrange`` from ``0`` up to (not including) ``325667`` and
        ``sample_id`` is ``indx_to_id_dict[row]``.
    """
    # Hard-coded upper bound -- presumably the number of dataset rows;
    # TODO confirm against indx_to_id_dict.
    upper_bound = 325667
    row = random.randrange(0, upper_bound)
    sample_id = indx_to_id_dict[row]
    return sample_id, row
|
| 78 |
+
|
| 79 |
+
def getTax(indx):
    """Format the taxonomy labels of dataset row ``indx`` as display text.

    Args:
        indx: Row index into the module-level ``species``, ``genus`` and
            ``family`` sequences.

    Returns:
        A three-line string: ``Species: ...``, ``Genus: ...``, ``Family: ...``.
    """
    # f-string instead of a local named ``str``, which shadowed the builtin.
    return (
        f"Species: {species[indx]}\n"
        f"Genus: {genus[indx]}\n"
        f"Family: {family[indx]}"
    )
|
| 85 |
+
|
| 86 |
+
# Data tables are loaded once at import time and read as module globals by
# searchEmbeddings, get_image, getRandID and getTax.
# NOTE(security): pickle.load can execute arbitrary code from the file, so
# these paths must only ever point at trusted files shipped with the app.
def _load_pickle(path):
    """Return the object unpickled from ``path``."""
    with open(path, "rb") as f:
        return pickle.load(f)


def _load_text_list(path):
    """Unpickle a sequence of UTF-8 byte strings and decode each to ``str``."""
    return [item.decode("utf-8") for item in _load_pickle(path)]


def _result_cell(rank):
    """Add one result column (image above its taxonomy box) to the open row."""
    with gr.Column():
        # Labels are passed as strings; the original passed bare ints.
        image = gr.Image(label=str(rank))
        tax = gr.Textbox(label="Taxonomy")
    return image, tax


# general files
dataset_image1 = _load_pickle("dataset_image1.pickle")
dataset_image2 = _load_pickle("dataset_image2.pickle")
dataset_processid_list = _load_pickle("dataset_processid_list.pickle")
dataset_image_mask = _load_pickle("dataset_image_mask.pickle")
processid_to_index = _load_pickle("processid_to_index.pickle")
indx_to_id_dict = _load_pickle("indx_to_id_dict.pickle")

# image embeddings
id_to_image_emb_dict = _load_pickle("id_to_image_emb_dict.pickle")

# dna embeddings
id_to_dna_emb_dict = _load_pickle("id_to_dna_emb_dict.pickle")

# taxonomy labels
family = _load_text_list("family.pickle")
genus = _load_text_list("genus.pickle")
species = _load_text_list("species.pickle")

# NOTE(review): the row/column nesting below is reconstructed from a view of
# the file that had lost its indentation -- confirm against the live layout.
with gr.Blocks(title="Bioscan-Clip") as demo:
    # Search controls.
    with gr.Column():
        process_id = gr.Textbox(label="ID:", info="Enter a sample ID to search for")
        process_id_list = gr.Textbox(label="Closest 10 matches:" )
        mod1 = gr.Radio(choices=["DNA", "Image"], label="Search From:")
        mod2 = gr.Radio(choices=["DNA", "Image"], label="Search To:")
        search_btn = gr.Button("Search")

    # Query sample on the left, random-ID helper on the right.
    with gr.Row():
        with gr.Column():
            image0 = gr.Image(label="Original", height=550)
            tax0 = gr.Textbox(label="Taxonomy")
        with gr.Column():
            rand_id = gr.Textbox(label="Random ID:")
            rand_id_indx = gr.Textbox(label="Index:")
            id_btn = gr.Button("Get Random ID")

    # Ten result cells laid out as rows of 3, 3 and 4, in ranked order.
    result_images = []
    result_taxa = []
    for ranks in ((1, 2, 3), (4, 5, 6), (7, 8, 9, 10)):
        with gr.Row():
            for rank in ranks:
                cell_image, cell_tax = _result_cell(rank)
                result_images.append(cell_image)
                result_taxa.append(cell_tax)

    id_btn.click(fn=getRandID, inputs=[], outputs=[rand_id, rand_id_indx])
    # Output order must match searchEmbeddings' return tuple:
    # ID list, 11 images (query first), 11 taxonomy strings (query first).
    search_btn.click(
        fn=searchEmbeddings,
        inputs=[process_id, mod1, mod2],
        outputs=[process_id_list, image0, *result_images, tax0, *result_taxa],
    )
    examples = gr.Examples(
        examples=[["ABOTH966-22", "DNA", "DNA"],
                  ["CRTOB8472-22", "DNA", "Image"],
                  ["PLOAD050-20", "Image", "DNA"],
                  ["HELAC26711-21", "Image", "Image"]],
        inputs=[process_id, mod1, mod2],)

demo.launch()
|