diff options
| author | Ben Cohen <ben@kensho.com> | 2019-07-28 13:51:38 -0400 | 
|---|---|---|
| committer | Ben Cohen <ben@kensho.com> | 2019-07-28 13:51:38 -0400 | 
| commit | e92085f124cb5f16edfc4c19dac1e6823d4838b7 (patch) | |
| tree | 62a90ab516447efc930ae130008efe572f96ccd8 /utils.py | |
| parent | 52b6c48da443a6fba35dec68348499640944790d (diff) | |
returning
Diffstat (limited to 'utils.py')
| -rw-r--r-- | utils.py | 48 | 
1 files changed, 28 insertions, 20 deletions
@@ -107,20 +107,24 @@ def get_closest_to_point(one_embed, style_limit=[]):      else:          possible_beers = small_embeddings      print(len(possible_beers)) + +    to_return = []      for thing in sorted(possible_beers, key = lambda x: cosine(one_embed, x[1]), reverse=False)[:5]:  #         print(b, a)          if thing[0] in beer_names:              bn = beer_names[thing[0]] -            print(bn[0], '---', bn[1][0]) -            print('=' * 50)# In[16]: + +            to_return.append((thing[0], bn[0], bn[1][0])) +    return to_return -def translate_to_attr(embedding, to_attr): -    if to_attr.startswith('-'): + +def translate_to_attr(embedding, to_attr, amt): +    if amt < 0:          back = True -        to_attr = to_attr[1:] +        amt = abs(amt)      else:          back = False @@ -138,14 +142,19 @@ def translate_to_attr(embedding, to_attr):      print('=' *10)      print('\n')      dist_between = np.linalg.norm(closest_center_vector-embedding) -    for x in [4,2, 1]: -        print('Moving {}%'.format(1/float(x) * 100)) -        if back: -            new_point = embedding - (vector_between/x/5) -        else: -            new_point = embedding + (vector_between/x) -        get_closest_to_point(new_point) -        print('-' * 25) +    amt_dict = { +        1: 4, +        2: 2, +        3: 1 +    } +    x = amt_dict[amt] +    print('Moving {}%'.format(1/float(x) * 100)) +    if back: +        new_point = embedding - (vector_between/x) +    else: +        new_point = embedding + (vector_between/x) +    return get_closest_to_point(new_point) +  def translate_to_style(embedding, style): @@ -157,12 +166,11 @@ def translate_to_style(embedding, style):      print('=' *10)      print('\n')      dist_between = np.linalg.norm(closest_center_vector-embedding) -    for x in [4,2, 1]: -        print('Moving {}%'.format(1/float(x) * 100)) -        new_point = embedding + (vector_between/x) -        get_closest_to_point(new_point, style_limit=style) -        print('-' * 25) +    TRANSLATION_CONSTANT = 4  # 4 is slight +     +    new_point = embedding + (vector_between/TRANSLATION_CONSTANT) +    return get_closest_to_point(new_point, style_limit=style)  # In[17]: @@ -177,7 +185,7 @@ info = beer_names[beer_id]  print('Real ABV: {}'.format(info[2]))  one_embed = final_data[beer_id]['embed'] -translate_to_style(one_embed, 'Belgian Quadrupel (Quad)') +print(translate_to_style(one_embed, 'Belgian Quadrupel (Quad)'))  # In[22]: @@ -192,7 +200,7 @@ info = beer_names[beer_id]  print('Real ABV: {}'.format(info[2]))  one_embed = final_data[beer_id]['embed'] -translate_to_attr(one_embed, 'Hoppy') +print(translate_to_attr(one_embed, 'Hoppy', 1))  # In[ ]:  | 
