-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
3 changed files
with
103 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,55 @@ | ||
from featuretools.primitives.base import TransformPrimitive | ||
from featuretools.variable_types import NaturalLanguage, Numeric | ||
|
||
|
||
class Elmo(TransformPrimitive):
    """Transforms a sentence or short paragraph using deep
    contextualized language representations. Uses the following
    pre-trained model: [tfhub model](https://tfhub.dev/google/elmo/2)

    Args:
        None

    Examples:
        >>> elmo = Elmo()
        >>> words = ["I like to eat pizza",
        ...          "The roller coaster was built in 1885.",
        ...          "When will humans go to mars?"]
        >>> output = elmo(words)
        >>> len(output)
        1024
        >>> len(output[0])
        3
        >>> values = output[:3, 0]
        >>> [round(x, 4) for x in values]
        [-0.3457, -0.4546, 0.2538]
    """
    name = "elmo"
    input_types = [NaturalLanguage]
    return_type = Numeric

    def __init__(self):
        # Location of the pre-trained ELMo module on TensorFlow Hub.
        self.handle = "https://tfhub.dev/google/elmo/2"
        # One output feature per embedding dimension.
        self.number_output_features = 1024
        # Mirrors number_output_features; not read anywhere in this class,
        # kept so existing external readers of ``n`` keep working.
        self.n = 1024

    def install(self):
        """Load the TensorFlow Hub module at ``self.handle`` into ``self.embed``.

        Only the module object is created here. Variable and table
        initialisation is done in the session that actually evaluates the
        embeddings (see ``get_function``), so no session is opened at this
        point: running an initializer before the module exists would have
        nothing to initialise.
        """
        # Imported lazily so the dependency is only needed when the
        # primitive is actually used.
        import tensorflow_hub as hub

        self.embed = hub.Module(self.handle)

    def get_function(self):
        """Return the function that computes ELMo embeddings for a column.

        Calls ``install()`` first so that ``self.embed`` is available.

        Returns:
            callable: takes a column exposing ``tolist()`` (e.g. a pandas
            Series of strings) and returns the evaluated embeddings
            transposed, i.e. indexed by output feature first and by input
            row second (as in the class docstring example).
        """
        self.install()

        def elmo(col):
            import tensorflow as tf

            # A fresh session is opened per call; variables and lookup
            # tables must be initialised in that same session before the
            # module output can be evaluated.
            with tf.compat.v1.Session() as session:
                session.run([tf.compat.v1.global_variables_initializer(),
                             tf.compat.v1.tables_initializer()])
                embeddings = session.run(self.embed(col.tolist()))
            return embeddings.transpose()
        return elmo
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import numpy as np | ||
import pandas as pd | ||
|
||
from ..elmo import Elmo | ||
from .test_utils import PrimitiveT, find_applicable_primitives, valid_dfs | ||
|
||
|
||
class TestElmo(PrimitiveT):
    """Tests for the ``Elmo`` primitive."""
    primitive = Elmo

    def test_regular(self):
        """Check embedding values for a fixed set of sentences.

        For each of the first five columns of the result, verifies the
        rounded column sum and the first five entries of the column.
        """
        embed = self.primitive().get_function()
        sentences = pd.Series([
            "",
            "I like to eat pizza",
            "The roller coaster was built in 1885.",
            "When will humans go to mars?",
            "Mitochondria is the powerhouse of the cell",
        ])
        embeddings = embed(sentences)

        # Column sums, rounded to three decimals.
        expected_sums = (8.744, 2.774, -8.852, -2.433, -13.623)
        for column, expected_sum in enumerate(expected_sums):
            assert round(sum(embeddings[:, column]), 3) == expected_sum

        # Leading five entries of each of the first five columns.
        expected_heads = (
            [-0.19886, -0.30473, 0.28911, -0.1545, -0.08023],
            [-0.3457, -0.45462, 0.25379, 0.02318, 0.33729],
            [-0.07041, -0.09047, -0.26982, 0.03027, -0.5508],
            [0.25077, -0.01977, -0.10443, -0.26923, 0.453],
            [-0.21744, 0.71178, -0.04611, -0.02631, -0.44973],
        )
        for column, head in enumerate(expected_heads):
            correct = np.array(head).round(5)
            answer = embeddings[:5, column].round(5)
            np.testing.assert_allclose(correct, answer, rtol=1e-04)

    def test_with_featuretools(self, es):
        """Run the primitive through ``valid_dfs`` together with the other
        applicable primitives on the ``es`` fixture."""
        transform_primitives, aggregation_primitives = \
            find_applicable_primitives(self.primitive)
        transform_primitives.append(self.primitive())
        valid_dfs(
            es,
            aggregation_primitives,
            transform_primitives,
            self.primitive.name.upper(),
            multi_output=True,
            max_features=100,
            instance_ids=[0],
        )