diff options
author | Ben Cohen <ben@kensho.com> | 2019-07-28 13:43:42 -0400 |
---|---|---|
committer | Ben Cohen <ben@kensho.com> | 2019-07-28 13:43:42 -0400 |
commit | f0f20b41e7eaef25497a9f456b832ee7cd539841 (patch) | |
tree | 31b038a5dc72ed2c36df1a9f699523b0a88ab7bf /utils.py | |
parent | 23d2f73b5f721fd0987ecd127c7cf37fb249a705 (diff) |
rekey
Diffstat (limited to 'utils.py')
-rw-r--r-- | utils.py | 34 |
1 files changed, 12 insertions, 22 deletions
@@ -38,17 +38,15 @@ with open('data/final_data_small.json') as f: -failed = 0 beer_by_style = defaultdict(list) for beer, data in final_data.items(): try: - a, b = beer.split('-') - real_style = beer_names[b + '-' + a][1][0] - if beer_names[b + '-' + a][3] > 7: + real_style = beer_names[beer][1][0] + if beer_names[beer][3] > 7: beer_by_style[real_style].append(data['embed']) except: - failed += 1 + pass style_centers = {} @@ -78,8 +76,7 @@ small_embeddings = [] small_styles = [] bad = 0 for e in embeddings: - b, a = e[0].split('-') - newkey = a + '-' + b + newkey = e[0] if newkey in beer_names: info= beer_names[newkey] @@ -89,18 +86,14 @@ for e in embeddings: else: bad += 1 - - def get_closest(beer_id): one_embed = final_data[beer_id]['embed'] for thing in sorted(small_embeddings, key = lambda x: cosine(one_embed, x[1]), reverse=False)[:25]: - b, a = thing[0].split('-') -# print(b, a) - if a + '-' + b in beer_names: - print(beer_names[a + '-' + b]) + if thing[0] in beer_names: + print(beer_names[thing[0]]) print('=' * 50) def get_closest_to_point(one_embed, style_limit=[]): @@ -116,15 +109,12 @@ def get_closest_to_point(one_embed, style_limit=[]): print(len(possible_beers)) for thing in sorted(possible_beers, key = lambda x: cosine(one_embed, x[1]), reverse=False)[:5]: - b, a = thing[0].split('-') + # print(b, a) - if a + '-' + b in beer_names: - bn = beer_names[a + '-' + b] + if thing[0] in beer_names: + bn = beer_names[thing[0]] print(bn[0], '---', bn[1][0]) - print('=' * 50) - - -# In[16]: + print('=' * 50)# In[16]: def translate_to_attr(embedding, to_attr): @@ -183,7 +173,7 @@ def translate_to_style(embedding, style): beer_id = '4-59' # allagash white # beer_id = '64-33832' # palo salto marron -info = beer_names[beer_id.split('-')[1] + '-' + beer_id.split('-')[0]] +info = beer_names[beer_id] print('Real ABV: {}'.format(info[2])) one_embed = final_data[beer_id]['embed'] @@ -198,7 +188,7 @@ translate_to_style(one_embed, 'Belgian Quadrupel (Quad)') beer_id = '4-59' # allagash white # beer_id = '64-33832' # palo salto marron -info = beer_names[beer_id.split('-')[1] + '-' + beer_id.split('-')[0]] +info = beer_names[beer_id] print('Real ABV: {}'.format(info[2])) one_embed = final_data[beer_id]['embed'] |