about | summary | refs | log | tree | commit | diff
path: root/utils.py
diff options
context:
space:
mode:
author Ben Cohen <ben@kensho.com> 2019-07-28 13:43:42 -0400
committer Ben Cohen <ben@kensho.com> 2019-07-28 13:43:42 -0400
commit f0f20b41e7eaef25497a9f456b832ee7cd539841 (patch)
tree 31b038a5dc72ed2c36df1a9f699523b0a88ab7bf /utils.py
parent 23d2f73b5f721fd0987ecd127c7cf37fb249a705 (diff)
rekey
Diffstat (limited to 'utils.py')
-rw-r--r-- utils.py 34
1 files changed, 12 insertions, 22 deletions
diff --git a/utils.py b/utils.py
index 73decb8..7257c8a 100644
--- a/utils.py
+++ b/utils.py
@@ -38,17 +38,15 @@ with open('data/final_data_small.json') as f:
-failed = 0
beer_by_style = defaultdict(list)
for beer, data in final_data.items():
try:
- a, b = beer.split('-')
- real_style = beer_names[b + '-' + a][1][0]
- if beer_names[b + '-' + a][3] > 7:
+ real_style = beer_names[beer][1][0]
+ if beer_names[beer][3] > 7:
beer_by_style[real_style].append(data['embed'])
except:
- failed += 1
+ pass
style_centers = {}
@@ -78,8 +76,7 @@ small_embeddings = []
small_styles = []
bad = 0
for e in embeddings:
- b, a = e[0].split('-')
- newkey = a + '-' + b
+ newkey = e[0]
if newkey in beer_names:
info= beer_names[newkey]
@@ -89,18 +86,14 @@ for e in embeddings:
else:
bad += 1
-
-
def get_closest(beer_id):
one_embed = final_data[beer_id]['embed']
for thing in sorted(small_embeddings, key = lambda x: cosine(one_embed, x[1]), reverse=False)[:25]:
- b, a = thing[0].split('-')
-# print(b, a)
- if a + '-' + b in beer_names:
- print(beer_names[a + '-' + b])
+ if thing[0] in beer_names:
+ print(beer_names[thing[0]])
print('=' * 50)
def get_closest_to_point(one_embed, style_limit=[]):
@@ -116,15 +109,12 @@ def get_closest_to_point(one_embed, style_limit=[]):
print(len(possible_beers))
for thing in sorted(possible_beers, key = lambda x: cosine(one_embed, x[1]), reverse=False)[:5]:
- b, a = thing[0].split('-')
+
# print(b, a)
- if a + '-' + b in beer_names:
- bn = beer_names[a + '-' + b]
+ if thing[0] in beer_names:
+ bn = beer_names[thing[0]]
print(bn[0], '---', bn[1][0])
- print('=' * 50)
-
-
-# In[16]:
+ print('=' * 50)# In[16]:
def translate_to_attr(embedding, to_attr):
@@ -183,7 +173,7 @@ def translate_to_style(embedding, style):
beer_id = '4-59' # allagash white
# beer_id = '64-33832' # palo salto marron
-info = beer_names[beer_id.split('-')[1] + '-' + beer_id.split('-')[0]]
+info = beer_names[beer_id]
print('Real ABV: {}'.format(info[2]))
one_embed = final_data[beer_id]['embed']
@@ -198,7 +188,7 @@ translate_to_style(one_embed, 'Belgian Quadrupel (Quad)')
beer_id = '4-59' # allagash white
# beer_id = '64-33832' # palo salto marron
-info = beer_names[beer_id.split('-')[1] + '-' + beer_id.split('-')[0]]
+info = beer_names[beer_id]
print('Real ABV: {}'.format(info[2]))
one_embed = final_data[beer_id]['embed']