From f0f20b41e7eaef25497a9f456b832ee7cd539841 Mon Sep 17 00:00:00 2001 From: Ben Cohen Date: Sun, 28 Jul 2019 13:43:42 -0400 Subject: rekey --- utils.py | 34 ++++++++++++---------------------- 1 file changed, 12 insertions(+), 22 deletions(-) (limited to 'utils.py') diff --git a/utils.py b/utils.py index 73decb8..7257c8a 100644 --- a/utils.py +++ b/utils.py @@ -38,17 +38,15 @@ with open('data/final_data_small.json') as f: -failed = 0 beer_by_style = defaultdict(list) for beer, data in final_data.items(): try: - a, b = beer.split('-') - real_style = beer_names[b + '-' + a][1][0] - if beer_names[b + '-' + a][3] > 7: + real_style = beer_names[beer][1][0] + if beer_names[beer][3] > 7: beer_by_style[real_style].append(data['embed']) except: - failed += 1 + pass style_centers = {} @@ -78,8 +76,7 @@ small_embeddings = [] small_styles = [] bad = 0 for e in embeddings: - b, a = e[0].split('-') - newkey = a + '-' + b + newkey = e[0] if newkey in beer_names: info= beer_names[newkey] @@ -89,18 +86,14 @@ for e in embeddings: else: bad += 1 - - def get_closest(beer_id): one_embed = final_data[beer_id]['embed'] for thing in sorted(small_embeddings, key = lambda x: cosine(one_embed, x[1]), reverse=False)[:25]: - b, a = thing[0].split('-') -# print(b, a) - if a + '-' + b in beer_names: - print(beer_names[a + '-' + b]) + if thing[0] in beer_names: + print(beer_names[thing[0]]) print('=' * 50) def get_closest_to_point(one_embed, style_limit=[]): @@ -116,15 +109,12 @@ def get_closest_to_point(one_embed, style_limit=[]): print(len(possible_beers)) for thing in sorted(possible_beers, key = lambda x: cosine(one_embed, x[1]), reverse=False)[:5]: - b, a = thing[0].split('-') + # print(b, a) - if a + '-' + b in beer_names: - bn = beer_names[a + '-' + b] + if thing[0] in beer_names: + bn = beer_names[thing[0]] print(bn[0], '---', bn[1][0]) - print('=' * 50) - - -# In[16]: + print('=' * 50)# In[16]: def translate_to_attr(embedding, to_attr): @@ -183,7 +173,7 @@ def translate_to_style(embedding, style): beer_id = '4-59' # allagash white # beer_id = '64-33832' # palo salto marron -info = beer_names[beer_id.split('-')[1] + '-' + beer_id.split('-')[0]] +info = beer_names[beer_id] print('Real ABV: {}'.format(info[2])) one_embed = final_data[beer_id]['embed'] @@ -198,7 +188,7 @@ translate_to_style(one_embed, 'Belgian Quadrupel (Quad)') beer_id = '4-59' # allagash white # beer_id = '64-33832' # palo salto marron -info = beer_names[beer_id.split('-')[1] + '-' + beer_id.split('-')[0]] +info = beer_names[beer_id] print('Real ABV: {}'.format(info[2])) one_embed = final_data[beer_id]['embed'] -- cgit v1.2.3