From 23d2f73b5f721fd0987ecd127c7cf37fb249a705 Mon Sep 17 00:00:00 2001 From: Ben Cohen Date: Sun, 28 Jul 2019 13:29:20 -0400 Subject: add utils and data --- utils.py | 218 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 218 insertions(+) create mode 100644 utils.py (limited to 'utils.py') diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..73decb8 --- /dev/null +++ b/utils.py @@ -0,0 +1,218 @@ +import re +import funcy +import csv +import numpy as np +import datetime +import json +from collections import defaultdict +import glob +from scipy.spatial.distance import cosine + + +with open('data/styles.json') as f: + styles = json.load(f) + + + +with open('data/beer_info.json', 'r') as f: + beer_names = json.load(f) + + +attr_to_styles = defaultdict(list) +style_to_attrs = defaultdict(list) +with open('data/sparser_beer_data.csv') as f: + reader = csv.DictReader(f) + for line in reader: + style = line.pop('Style') + for k, v in line.items(): + if v == '1': + attr_to_styles[k].append(style) + style_to_attrs[style].append(k) + + +# In[7]: + +with open('data/final_data_small.json') as f: + final_data = json.load(f) + + + + +failed = 0 +beer_by_style = defaultdict(list) +for beer, data in final_data.items(): + try: + a, b = beer.split('-') + real_style = beer_names[b + '-' + a][1][0] + if beer_names[b + '-' + a][3] > 7: + beer_by_style[real_style].append(data['embed']) + + except: + failed += 1 + + +style_centers = {} +for style, datas in beer_by_style.items(): + style_centers[style.strip()] = np.mean(datas, axis=0) + + + +style_name_to_num = {y[0]: x for x,y in styles.items()} + + + + +attr_centers = {} +for attr, rel_styles in attr_to_styles.items(): + centers = [style_centers[x] for x in rel_styles] + attr_mean = np.mean(centers, axis=0) + attr_centers[attr] = attr_mean + + + +embeddings = [(x[0], x[1]['embed']) for x in final_data.items()] + + + +small_embeddings = [] +small_styles = [] +bad = 0 +for e in embeddings: + b, a = e[0].split('-') + newkey = a + '-' + b + if newkey in beer_names: + info= beer_names[newkey] + + if info[3] > 25: + small_embeddings.append(e) + small_styles.append(info[1][0]) + else: + bad += 1 + + + +def get_closest(beer_id): + one_embed = final_data[beer_id]['embed'] + + + for thing in sorted(small_embeddings, key = lambda x: cosine(one_embed, x[1]), reverse=False)[:25]: + + b, a = thing[0].split('-') +# print(b, a) + if a + '-' + b in beer_names: + print(beer_names[a + '-' + b]) + print('=' * 50) + +def get_closest_to_point(one_embed, style_limit=[]): +# one_embed = final_data[beer_id]['embed'] + + if style_limit: + possible_beers = [] + for style, e in zip(small_styles, small_embeddings): + if style == style_limit: + possible_beers.append(e) + else: + possible_beers = small_embeddings + print(len(possible_beers)) + for thing in sorted(possible_beers, key = lambda x: cosine(one_embed, x[1]), reverse=False)[:5]: + + b, a = thing[0].split('-') +# print(b, a) + if a + '-' + b in beer_names: + bn = beer_names[a + '-' + b] + print(bn[0], '---', bn[1][0]) + print('=' * 50) + + +# In[16]: + + +def translate_to_attr(embedding, to_attr): + if to_attr.startswith('-'): + back = True + to_attr = to_attr[1:] + + else: + back = False + + relevant_styles = attr_to_styles[to_attr] + small_rel_centers = {x: y for x, y in style_centers.items() if x in relevant_styles} + + sorted_centers = sorted(small_rel_centers, key=lambda x: cosine(small_rel_centers[x], embedding)) + closest_center = sorted_centers[0] + closest_center_vector = style_centers[closest_center] + + vector_between = closest_center_vector - embedding + + print('Moving Towards/From: {}'.format(closest_center)) + print('=' *10) + print('\n') + dist_between = np.linalg.norm(closest_center_vector-embedding) + for x in [4,2, 1]: + print('Moving {}%'.format(1/float(x) * 100)) + if back: + new_point = embedding - (vector_between/x/5) + else: + new_point = embedding + (vector_between/x) + get_closest_to_point(new_point) + print('-' * 25) + +def translate_to_style(embedding, style): + + closest_center_vector = style_centers[style] + + vector_between = closest_center_vector - embedding +# print(vector_between) + print('Moving Towards/From: {}'.format(style)) + print('=' *10) + print('\n') + dist_between = np.linalg.norm(closest_center_vector-embedding) + for x in [4,2, 1]: + print('Moving {}%'.format(1/float(x) * 100)) + + new_point = embedding + (vector_between/x) + get_closest_to_point(new_point, style_limit=style) + print('-' * 25) + + +# In[17]: + + +# beer_id = '388-1703' # cantillio gueueze +# beer_id = '140-276' # SN pale ale +beer_id = '4-59' # allagash white +# beer_id = '64-33832' # palo salto marron + +info = beer_names[beer_id.split('-')[1] + '-' + beer_id.split('-')[0]] +print('Real ABV: {}'.format(info[2])) + +one_embed = final_data[beer_id]['embed'] +translate_to_style(one_embed, 'Belgian Quadrupel (Quad)') + + +# In[22]: + + +# beer_id = '388-1703' # cantillio gueueze +# beer_id = '140-276' # SN pale ale +beer_id = '4-59' # allagash white +# beer_id = '64-33832' # palo salto marron + +info = beer_names[beer_id.split('-')[1] + '-' + beer_id.split('-')[0]] +print('Real ABV: {}'.format(info[2])) + +one_embed = final_data[beer_id]['embed'] +translate_to_attr(one_embed, 'Hoppy') + + +# In[ ]: + + + + + +# In[ ]: + + + + -- cgit v1.2.3