Skip to content

Commit

Permalink
Add Elmo primitive (#64)
Browse files Browse the repository at this point in the history
  • Loading branch information
davesque authored Aug 31, 2021
1 parent 832c13f commit 139ecc4
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 0 deletions.
1 change: 1 addition & 0 deletions nlp_primitives/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pkg_resources

from .diversity_score import DiversityScore
from .elmo import Elmo
from .lsa import LSA
from .mean_characters_per_word import MeanCharactersPerWord
from .part_of_speech_count import PartOfSpeechCount
Expand Down
55 changes: 55 additions & 0 deletions nlp_primitives/elmo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
from featuretools.primitives.base import TransformPrimitive
from featuretools.variable_types import NaturalLanguage, Numeric


class Elmo(TransformPrimitive):
    """Transforms a sentence or short paragraph using deep
    contextualized language representations. Uses the following
    pre-trained model: [tfhub model](https://tfhub.dev/google/elmo/2)

    Args:
        None

    Examples:
        >>> elmo = Elmo()
        >>> words = ["I like to eat pizza",
        ...          "The roller coaster was built in 1885.",
        ...          "When will humans go to mars?"]
        >>> output = elmo(words)
        >>> len(output)
        1024
        >>> len(output[0])
        3
        >>> values = output[:3, 0]
        >>> [round(x, 4) for x in values]
        [-0.3457, -0.4546, 0.2538]
    """
    name = "elmo"
    input_types = [NaturalLanguage]
    return_type = Numeric

    def __init__(self):
        # Location of the pre-trained module that `install` loads.
        self.handle = "https://tfhub.dev/google/elmo/2"
        # One output feature per embedding dimension; `n` carries the same
        # value as `number_output_features`.
        self.number_output_features = 1024
        self.n = 1024

    def install(self):
        """Load the hub module named by ``self.handle`` into ``self.embed``.

        TensorFlow and TensorFlow Hub are imported here rather than at module
        level, so they are only required once the primitive is actually used.
        """
        import tensorflow as tf
        import tensorflow_hub as hub

        tf1 = tf.compat.v1
        with tf1.Session():
            init_op = tf1.global_variables_initializer()
            init_op.run()
            self.embed = hub.Module(self.handle)

    def get_function(self):
        """Return the callable that embeds a column of text.

        ``install`` is invoked on every call to this method, before the
        callable is built.

        Returns:
            A function taking a column (anything exposing ``tolist()``) and
            returning the transposed embedding array, i.e. the first axis
            indexes the embedding features and the second the input rows
            (see the class example).
        """
        self.install()

        def elmo(col):
            import tensorflow as tf

            tf1 = tf.compat.v1
            with tf1.Session() as sess:
                # Initialise variables and lookup tables before evaluating
                # the module on the input.
                initializers = [
                    tf1.global_variables_initializer(),
                    tf1.tables_initializer(),
                ]
                sess.run(initializers)
                sentences = col.tolist()
                embeddings = sess.run(self.embed(sentences))
            transposed = embeddings.transpose()
            return transposed

        return elmo
47 changes: 47 additions & 0 deletions nlp_primitives/tests/test_elmo.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
import numpy as np
import pandas as pd

from ..elmo import Elmo
from .test_utils import PrimitiveT, find_applicable_primitives, valid_dfs


class TestElmo(PrimitiveT):
    """Tests for the ``Elmo`` primitive."""
    primitive = Elmo

    def test_regular(self):
        """Compare the primitive's output against stored reference values.

        Column ``i`` of the result is checked for input sentence ``i``: first
        the rounded sum over the whole column, then its first five entries.
        """
        embed = self.primitive().get_function()
        sentences = pd.Series([
            "",
            "I like to eat pizza",
            "The roller coaster was built in 1885.",
            "When will humans go to mars?",
            "Mitochondria is the powerhouse of the cell",
        ])
        new_results = embed(sentences)

        # Expected per-column sums, rounded to three decimals.
        expected_sums = [8.744, 2.774, -8.852, -2.433, -13.623]
        for idx, total in enumerate(expected_sums):
            assert round(sum(new_results[:, idx]), 3) == total

        # Expected first five values of each column.
        expected_heads = [
            [-0.19886, -0.30473, 0.28911, -0.1545, -0.08023],
            [-0.3457, -0.45462, 0.25379, 0.02318, 0.33729],
            [-0.07041, -0.09047, -0.26982, 0.03027, -0.5508],
            [0.25077, -0.01977, -0.10443, -0.26923, 0.453],
            [-0.21744, 0.71178, -0.04611, -0.02631, -0.44973],
        ]
        for idx, head in enumerate(expected_heads):
            correct = np.array(head).round(5)
            answer = new_results[:5, idx].round(5)
            np.testing.assert_allclose(correct, answer, rtol=1e-04)

    def test_with_featuretools(self, es):
        """Run the primitive through ``valid_dfs`` with its applicable primitives."""
        trans_prims, agg_prims = find_applicable_primitives(self.primitive)
        trans_prims.append(self.primitive())
        feature_prefix = self.primitive.name.upper()
        valid_dfs(
            es,
            agg_prims,
            trans_prims,
            feature_prefix,
            multi_output=True,
            max_features=100,
            instance_ids=[0],
        )

0 comments on commit 139ecc4

Please sign in to comment.