diff --git a/.travis.yml b/.travis.yml index 5c71a1e..4673191 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,9 +1,8 @@ language: java jdk: - - oraclejdk8 + - openjdk11 os: - linux - - osx cache: bundler # Setting install to 'true' to prevent Travis CI from installing depedencies via: # "mvn install -DskipTests=true -Dmaven.javadoc.skip=true -B -V" which fails due to missing the GPG secret key. diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..4be0e6c --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,7 @@ +# Changelog + +## Version 0.9.1 + +* Use [FastText v0.9.1](https://github.com/facebookresearch/fastText/releases/tag/v0.9.1) +* Minor bug fixes + \ No newline at end of file diff --git a/README.md b/README.md index d32185a..3eb99f9 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,11 @@ -[![Build Status](https://travis-ci.org/vinhkhuc/JFastText.svg?branch=master)](https://travis-ci.org/vinhkhuc/JFastText) +[![Build Status](https://travis-ci.org/carschno/JFastText.svg?branch=master)](https://travis-ci.org/carschno/JFastText) Table of Contents ================= * [Introduction](#introduction) * [Maven Dependency](#maven-dependency) + * [Windows and Mac OS X](#windows-and-mac-os-x) * [Building](#building) * [Quick Application - Language Identification](#quick-application-\--language-identification) * [Detailed Examples](#detailed-examples) @@ -12,6 +13,7 @@ Table of Contents * [FastText's Command Line](#fasttexts-command-line) * [License](#license) * [References](#references) + * [Changelog](CHANGELOG.md) ## Introduction @@ -28,23 +30,39 @@ JFastText is ideal for building fast text classifiers in Java. ## Maven Dependency ```xml - com.github.vinhkhuc + io.github.carschno jfasttext - 0.4 + 0.9.1 ``` -The Jar package on Maven Central is bundled with precompiled fastText library for Windows, Linux and -MacOSX 64bit. +The Jar package on Maven Central is bundled with a precompiled fastText library for ~~Windows,~~ Linux ~~and +MacOSX~~ 64bit.
+ +### Windows and Mac OS X + +Currently, the Maven dependency only contains binaries for Linux (64-bit), _not_ for Windows or Mac OS X. +In order to use JFastText for Windows or Mac OS X (or any other system), you need to build it yourself (see [below](#building)). ## Building -C++ compiler (g++ on Mac/Linux or cl.exe on Windows) is required to compile fastText's code. +A C++ compiler (g++ on Mac/Linux or `cl.exe` on Windows) is required to compile fastText's code. ```bash -git clone --recursive https://github.com/vinhkhuc/JFastText +git clone --recursive https://github.com/carschno/JFastText cd JFastText +git submodule init +git submodule update mvn package ``` +### Building on Windows + +The (automatic) build seems to fail on some Windows systems/C++ compilers. +See [this issue](https://github.com/carschno/JFastText/issues/5#issuecomment-546485377): + +> I used MS's developer tools, not the full-blown Visual Studio. If I run `cl` directly, the compilation fails with the same error. +> +> I was able to build on Windows by changing the call to `cl.exe` and running it outside the Maven build. I changed one parameter in the call to `cl`: I use `/MT` (whereas Maven uses `/MD`). Bundling the generated DLLs works fine. + ## Quick Application - Language Identification JFastText can use FastText's pretrained models directly. Language identification models can be downloaded [here](https://fasttext.cc/docs/en/language-identification.html).
In this quick example, we will use the [quantized model](https://s3-us-west-1.amazonaws.com/fasttext-vectors/supervised_models/lid.176.ftz) diff --git a/examples/api/pom.xml b/examples/api/pom.xml index c2002ae..0ca4a9d 100644 --- a/examples/api/pom.xml +++ b/examples/api/pom.xml @@ -7,12 +7,15 @@ com.github.vinhkhuc java_sandbox 0.1 - + + 1.8 + 1.8 + - com.github.vinhkhuc + io.github.carschno jfasttext - 0.4 + 0.5.0 \ No newline at end of file diff --git a/pom.xml b/pom.xml index 3899d17..01b5822 100644 --- a/pom.xml +++ b/pom.xml @@ -4,9 +4,9 @@ xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd"> 4.0.0 - com.github.vinhkhuc + io.github.carschno jfasttext - 0.4 + 0.9.1 Java interface for fastText JFastText is a Java interface for fastText, a library for efficient learning of @@ -22,6 +22,15 @@ + + carschno + Carsten Schnober + http://github.com/carschno + + developer + maintainer + + vinhkhuc Vinh Khuc @@ -34,13 +43,14 @@ - scm:git:https://github.com/vinhkhuc/JFastText.git - scm:git:git@github.com:vinhkhuc/JFastText.git - https://github.com/vinhkhuc/JFastText + scm:git:https://github.com/carschno/JFastText.git + scm:git:git@github.com:carschno/JFastText.git + https://github.com/carschno/JFastText UTF-8 + 1.5.1 @@ -65,21 +75,42 @@ -Xdoclint:none - - - - - + + release + + + org.sonatype.plugins nexus-staging-maven-plugin 1.6.3 true - ossrh - https://oss.sonatype.org/ - true + ossrh + https://oss.sonatype.org/ + false - + + + org.apache.maven.plugins + maven-gpg-plugin + 1.5 + + + sign-artifacts + verify + + sign + + + + + + + + + + + org.apache.maven.plugins maven-compiler-plugin @@ -110,6 +141,9 @@ org.apache.maven.plugins maven-javadoc-plugin 2.9.1 + + 8 + attach-javadocs @@ -119,20 +153,6 @@ - - org.apache.maven.plugins - maven-gpg-plugin - 1.5 - - - sign-artifacts - verify - - sign - - - - org.apache.maven.plugins maven-enforcer-plugin @@ -155,7 +175,7 @@ org.bytedeco javacpp - 1.3.1 + ${javacpp.version} 
run-javacpp-parser @@ -234,8 +254,8 @@ org.bytedeco javacpp - 1.3.1 + ${javacpp.version} - \ No newline at end of file + diff --git a/src/main/cpp/fastText b/src/main/cpp/fastText index 9c9b9b2..cda295f 160000 --- a/src/main/cpp/fastText +++ b/src/main/cpp/fastText @@ -1 +1 @@ -Subproject commit 9c9b9b2eca4fdc0182182cb69234fffbb847d820 +Subproject commit cda295f1b5851df0a26a6ac2ab04230fb864a89d diff --git a/src/main/cpp/fasttext_wrapper.cc b/src/main/cpp/fasttext_wrapper.cc index ec5c90e..b2362af 100644 --- a/src/main/cpp/fasttext_wrapper.cc +++ b/src/main/cpp/fasttext_wrapper.cc @@ -74,13 +74,20 @@ namespace FastTextWrapper { const std::string& text, int32_t k) { std::vector> predictions; std::istringstream in(text); - fastText.predict(in, k, predictions); + fastText.predictLine(in, predictions, k, 0.0); return predictions; } std::vector FastTextApi::getVector(const std::string& word) { Vector vec(privateMembers->args_->dim); - fastText.getVector(vec, word); + fastText.getWordVector(vec, word); + return std::vector(vec.data(), vec.data() + vec.size()); + } + + std::vector FastTextApi::getSentenceVector(const std::string& sentence) { + Vector vec(privateMembers->args_->dim); + std::istringstream in(sentence); + fastText.getSentenceVector(in, vec); return std::vector(vec.data(), vec.data() + vec.size()); } diff --git a/src/main/cpp/fasttext_wrapper.h b/src/main/cpp/fasttext_wrapper.h index 10562ef..68758ec 100644 --- a/src/main/cpp/fasttext_wrapper.h +++ b/src/main/cpp/fasttext_wrapper.h @@ -29,6 +29,7 @@ namespace FastTextWrapper { std::vector predict(const std::string&, int32_t); std::vector> predictProba(const std::string&, int32_t); std::vector getVector(const std::string&); + std::vector getSentenceVector(const std::string&); std::vector getWords(); std::vector getLabels(); std::string getWord(int32_t); diff --git a/src/main/cpp/fasttext_wrapper_javacpp.h b/src/main/cpp/fasttext_wrapper_javacpp.h index 3991d2e..3666f44 100644 --- 
a/src/main/cpp/fasttext_wrapper_javacpp.h +++ b/src/main/cpp/fasttext_wrapper_javacpp.h @@ -1,12 +1,15 @@ // Added since VS 14.0 complains about missing std::iota #include #include "fastText/src/args.cc" +#include "fastText/src/densematrix.cc" #include "fastText/src/dictionary.cc" #include "fastText/src/fasttext.cc" +#include "fastText/src/loss.cc" #include "fastText/src/matrix.cc" +#include "fastText/src/meter.cc" #include "fastText/src/model.cc" #include "fastText/src/productquantizer.cc" -#include "fastText/src/qmatrix.cc" +#include "fastText/src/quantmatrix.cc" #include "fastText/src/vector.cc" #include "fastText/src/utils.cc" diff --git a/src/main/java/com/github/jfasttext/FastTextWrapper.java b/src/main/java/com/github/jfasttext/FastTextWrapper.java index c583d03..5fe8295 100644 --- a/src/main/java/com/github/jfasttext/FastTextWrapper.java +++ b/src/main/java/com/github/jfasttext/FastTextWrapper.java @@ -1,4 +1,4 @@ -// Targeted by JavaCPP version 1.3.1: DO NOT EDIT THIS FILE +// Targeted by JavaCPP version 1.5.1: DO NOT EDIT THIS FILE package com.github.jfasttext; @@ -13,7 +13,9 @@ public class FastTextWrapper extends com.github.jfasttext.config.FastTextWrapper static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ public StringVector(Pointer p) { super(p); } + public StringVector(BytePointer value) { this(1); put(0, value); } public StringVector(BytePointer ... array) { this(array.length); put(array); } + public StringVector(String value) { this(1); put(0, value); } public StringVector(String ... 
array) { this(array.length); put(array); } public StringVector() { allocate(); } public StringVector(long n) { allocate(n); } @@ -21,13 +23,54 @@ public class FastTextWrapper extends com.github.jfasttext.config.FastTextWrapper private native void allocate(@Cast("size_t") long n); public native @Name("operator=") @ByRef StringVector put(@ByRef StringVector x); + public boolean empty() { return size() == 0; } public native long size(); + public void clear() { resize(0); } public native void resize(@Cast("size_t") long n); - @Index public native @StdString BytePointer get(@Cast("size_t") long i); + @Index(function = "at") public native @StdString BytePointer get(@Cast("size_t") long i); public native StringVector put(@Cast("size_t") long i, BytePointer value); - @ValueSetter @Index public native StringVector put(@Cast("size_t") long i, @StdString String value); + @ValueSetter @Index(function = "at") public native StringVector put(@Cast("size_t") long i, @StdString String value); + + public native @ByVal Iterator insert(@ByVal Iterator pos, @StdString BytePointer value); + public native @ByVal Iterator erase(@ByVal Iterator pos); + public native @ByVal Iterator begin(); + public native @ByVal Iterator end(); + @NoOffset @Name("iterator") public static class Iterator extends Pointer { + public Iterator(Pointer p) { super(p); } + public Iterator() { } + + public native @Name("operator++") @ByRef Iterator increment(); + public native @Name("operator==") boolean equals(@ByRef Iterator it); + public native @Name("operator*") @StdString BytePointer get(); + } + + public BytePointer[] get() { + BytePointer[] array = new BytePointer[size() < Integer.MAX_VALUE ? 
(int)size() : Integer.MAX_VALUE]; + for (int i = 0; i < array.length; i++) { + array[i] = get(i); + } + return array; + } + @Override public String toString() { + return java.util.Arrays.toString(get()); + } + public BytePointer pop_back() { + long size = size(); + BytePointer value = get(size - 1); + resize(size - 1); + return value; + } + public StringVector push_back(BytePointer value) { + long size = size(); + resize(size + 1); + return put(size, value); + } + public StringVector put(BytePointer value) { + if (size() != 1) { resize(1); } + return put(0, value); + } public StringVector put(BytePointer ... array) { if (size() != array.length) { resize(array.length); } for (int i = 0; i < array.length; i++) { @@ -36,6 +79,15 @@ public StringVector put(BytePointer ... array) { return this; } + public StringVector push_back(String value) { + long size = size(); + resize(size + 1); + return put(size, value); + } + public StringVector put(String value) { + if (size() != 1) { resize(1); } + return put(0, value); + } public StringVector put(String ... array) { if (size() != array.length) { resize(array.length); } for (int i = 0; i < array.length; i++) { @@ -49,6 +101,7 @@ public StringVector put(String ... array) { static { Loader.load(); } /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ public RealVector(Pointer p) { super(p); } + public RealVector(float value) { this(1); put(0, value); } public RealVector(float ... array) { this(array.length); put(array); } public RealVector() { allocate(); } public RealVector(long n) { allocate(n); } @@ -56,12 +109,53 @@ public StringVector put(String ... 
array) { private native void allocate(@Cast("size_t") long n); public native @Name("operator=") @ByRef RealVector put(@ByRef RealVector x); + public boolean empty() { return size() == 0; } public native long size(); + public void clear() { resize(0); } public native void resize(@Cast("size_t") long n); - @Index public native @ByRef float get(@Cast("size_t") long i); + @Index(function = "at") public native @ByRef float get(@Cast("size_t") long i); public native RealVector put(@Cast("size_t") long i, float value); + public native @ByVal Iterator insert(@ByVal Iterator pos, @ByRef float value); + public native @ByVal Iterator erase(@ByVal Iterator pos); + public native @ByVal Iterator begin(); + public native @ByVal Iterator end(); + @NoOffset @Name("iterator") public static class Iterator extends Pointer { + public Iterator(Pointer p) { super(p); } + public Iterator() { } + + public native @Name("operator++") @ByRef Iterator increment(); + public native @Name("operator==") boolean equals(@ByRef Iterator it); + public native @Name("operator*") @ByRef @Const float get(); + } + + public float[] get() { + float[] array = new float[size() < Integer.MAX_VALUE ? (int)size() : Integer.MAX_VALUE]; + for (int i = 0; i < array.length; i++) { + array[i] = get(i); + } + return array; + } + @Override public String toString() { + return java.util.Arrays.toString(get()); + } + + public float pop_back() { + long size = size(); + float value = get(size - 1); + resize(size - 1); + return value; + } + public RealVector push_back(float value) { + long size = size(); + resize(size + 1); + return put(size, value); + } + public RealVector put(float value) { + if (size() != 1) { resize(1); } + return put(0, value); + } public RealVector put(float ... array) { if (size() != array.length) { resize(array.length); } for (int i = 0; i < array.length; i++) { @@ -83,12 +177,14 @@ public RealVector put(float ... 
array) { private native void allocate(@Cast("size_t") long n); public native @Name("operator=") @ByRef FloatStringPairVector put(@ByRef FloatStringPairVector x); + public boolean empty() { return size() == 0; } public native long size(); + public void clear() { resize(0); } public native void resize(@Cast("size_t") long n); - @Index public native @ByRef float first(@Cast("size_t") long i); public native FloatStringPairVector first(@Cast("size_t") long i, float first); - @Index public native @StdString BytePointer second(@Cast("size_t") long i); public native FloatStringPairVector second(@Cast("size_t") long i, BytePointer second); - @MemberSetter @Index public native FloatStringPairVector second(@Cast("size_t") long i, @StdString String second); + @Index(function = "at") public native @ByRef float first(@Cast("size_t") long i); public native FloatStringPairVector first(@Cast("size_t") long i, float first); + @Index(function = "at") public native @StdString BytePointer second(@Cast("size_t") long i); public native FloatStringPairVector second(@Cast("size_t") long i, BytePointer second); + @MemberSetter @Index(function = "at") public native FloatStringPairVector second(@Cast("size_t") long i, @StdString String second); public FloatStringPairVector put(float[] firstValue, BytePointer[] secondValue) { for (int i = 0; i < firstValue.length && i < secondValue.length; i++) { @@ -201,6 +297,8 @@ public DoubleIntPair put(float firstValue, int secondValue) { public native @ByVal FloatStringPairVector predictProba(@StdString String arg0, int arg1); public native @ByVal RealVector getVector(@StdString BytePointer arg0); public native @ByVal RealVector getVector(@StdString String arg0); + public native @ByVal RealVector getSentenceVector(@StdString BytePointer arg0); + public native @ByVal RealVector getSentenceVector(@StdString String arg0); public native @ByVal StringVector getWords(); public native @ByVal StringVector getLabels(); public native @StdString BytePointer 
getWord(int arg0); @@ -235,12 +333,15 @@ public DoubleIntPair put(float firstValue, int secondValue) { // Added since VS 14.0 complains about missing std::iota // #include // #include "fastText/src/args.cc" +// #include "fastText/src/densematrix.cc" // #include "fastText/src/dictionary.cc" // #include "fastText/src/fasttext.cc" +// #include "fastText/src/loss.cc" // #include "fastText/src/matrix.cc" +// #include "fastText/src/meter.cc" // #include "fastText/src/model.cc" // #include "fastText/src/productquantizer.cc" -// #include "fastText/src/qmatrix.cc" +// #include "fastText/src/quantmatrix.cc" // #include "fastText/src/vector.cc" // #include "fastText/src/utils.cc" diff --git a/src/main/java/com/github/jfasttext/JFastText.java b/src/main/java/com/github/jfasttext/JFastText.java index 1f2d333..f728f78 100644 --- a/src/main/java/com/github/jfasttext/JFastText.java +++ b/src/main/java/com/github/jfasttext/JFastText.java @@ -3,15 +3,40 @@ import org.bytedeco.javacpp.PointerPointer; import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.net.MalformedURLException; +import java.net.URI; +import java.net.URL; +import java.nio.file.CopyOption; +import java.nio.file.Files; +import java.nio.file.Path; +import java.nio.file.StandardCopyOption; import java.util.ArrayList; import java.util.List; public class JFastText { - private FastTextWrapper.FastTextApi fta; + private FastTextWrapper.FastTextApi fta = new FastTextWrapper.FastTextApi(); public JFastText() { - fta = new FastTextWrapper.FastTextApi(); + } + + public JFastText(final String modelFile) { + loadModel(modelFile); + } + + public JFastText(final URI modelUri) throws IOException { + loadModel(modelUri); + } + + public JFastText(final URL modelUrl) throws IOException { + loadModel(modelUrl); + } + + public JFastText(final InputStream modelStream) throws IOException { + loadModel(modelStream); } public void runCmd(String[] args) { @@ -22,9 +47,9
@@ public void runCmd(String[] args) { fta.runCmd(cArgs.length, new PointerPointer(cArgs)); } - public void loadModel(String modelFile) { + public void loadModel(String modelFile) throws IllegalArgumentException { if (!new File(modelFile).exists()) { - throw new IllegalArgumentException("Model file doesn't exist!"); + throw new IllegalArgumentException("Model file " + modelFile + " doesn't exist!"); } if (!fta.checkModel(modelFile)) { throw new IllegalArgumentException( @@ -33,6 +58,39 @@ public void loadModel(String modelFile) { fta.loadModel(modelFile); } + /** + * Loads model from location specified by URI, by copying its content into local file and then loading it. + * + * @param modelUri location of the model + */ + public void loadModel(URI modelUri) throws IOException { + loadModel(modelUri.toURL()); + } + + /** + * Loads model from location specified by URL, by copying its content into local file and then loading it; the stream opened on the URL is closed before returning. + * + * @param modelUrl location of the model + */ + public void loadModel(URL modelUrl) throws IOException { + try (InputStream modelStream = modelUrl.openStream()) { loadModel(modelStream); } + } + + /** + * Loads model given InputStream, by copying its content into local file and then loading it.
+ * + * @param modelStream stream for model + */ + public void loadModel(InputStream modelStream) throws IOException { + Path tmpFile = Files.createTempFile("jft-", ".model"); + try { + Files.copy(modelStream, tmpFile, StandardCopyOption.REPLACE_EXISTING); + loadModel(tmpFile.toString()); + } finally { + Files.deleteIfExists(tmpFile); + } + } + public void unloadModel() { fta.unloadModel(); } @@ -84,15 +142,39 @@ public List predictProba(String text, int k) { return probaPredictions; } + @Deprecated public List getVector(String word) { - FastTextWrapper.RealVector rv = fta.getVector(word); + float[] vector = getArrayVector(word); List wordVec = new ArrayList<>(); - for (int i = 0; i < rv.size(); i++) { - wordVec.add(rv.get(i)); + for (float f : vector) { + wordVec.add(f); } return wordVec; } + public float[] getArrayVector(String word) { + FastTextWrapper.RealVector rv = fta.getVector(word); + return rv.get(); + } + + @Deprecated + public List getSentenceVector(String sentence) { + float[] vector = getArraySentenceVector(sentence); + List sentenceVec = new ArrayList<>(); + for (float f : vector) { + sentenceVec.add(f); + } + return sentenceVec; + } + + public float[] getArraySentenceVector(String sentence) { + if (!sentence.endsWith("\n")) { + sentence += "\n"; + } + FastTextWrapper.RealVector rv = fta.getSentenceVector(sentence); + return rv.get(); + } + public int getNWords() { return fta.getNWords(); } diff --git a/src/test/java/com/github/jfasttext/JFastTextTest.java b/src/test/java/com/github/jfasttext/JFastTextTest.java index b6f1521..1fd7ffc 100644 --- a/src/test/java/com/github/jfasttext/JFastTextTest.java +++ b/src/test/java/com/github/jfasttext/JFastTextTest.java @@ -4,8 +4,17 @@ import org.junit.Test; import org.junit.runners.MethodSorters; +import java.io.File; +import java.io.FileInputStream; +import java.io.InputStream; +import java.net.URL; +import java.util.Arrays; import java.util.List; +import static org.junit.Assert.assertArrayEquals; +import 
static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + @FixMethodOrder(MethodSorters.NAME_ASCENDING) public class JFastTextTest { @@ -16,7 +25,9 @@ public void test01TrainSupervisedCmd() { jft.runCmd(new String[] { "supervised", "-input", "src/test/resources/data/labeled_data.txt", - "-output", "src/test/resources/models/supervised.model" + "-output", "src/test/resources/models/supervised.model", + "-wordNgrams", "3", + "-bucket", "100" }); } @@ -56,9 +67,18 @@ public void test04Predict() throws Exception { } @Test - public void test05PredictProba() throws Exception { + public void test04getArrayVector() throws Exception { JFastText jft = new JFastText(); jft.loadModel("src/test/resources/models/supervised.model.bin"); + String text = "I like soccer"; + float[] predictedArray = jft.getArrayVector(text); + float[] expected = new float[100]; + assertArrayEquals("Unexpected vector returned by getArrayVector", expected, predictedArray, 0.1f); + } + + @Test + public void test05PredictProba() throws Exception { + JFastText jft = new JFastText("src/test/resources/models/supervised.model.bin"); String text = "What is the most popular sport in the US ?"; JFastText.ProbLabel predictedProbLabel = jft.predictProba(text); System.out.printf("\nText: '%s', label: '%s', probability: %f\n", @@ -67,8 +87,7 @@ public void test05PredictProba() throws Exception { @Test public void test06MultiPredictProba() throws Exception { - JFastText jft = new JFastText(); - jft.loadModel("src/test/resources/models/supervised.model.bin"); + JFastText jft = new JFastText("src/test/resources/models/supervised.model.bin"); String text = "Do you like soccer ?"; System.out.printf("Text: '%s'\n", text); for (JFastText.ProbLabel predictedProbLabel: jft.predictProba(text, 2)) { @@ -79,21 +98,51 @@ public void test06MultiPredictProba() throws Exception { @Test public void test07GetVector() throws Exception { + try (InputStream is = new FileInputStream("src/test/resources/models/supervised.model.bin")) { + JFastText jft
= new JFastText(is); + String word = "soccer"; + List vec = jft.getVector(word); + System.out.printf("\nWord embedding vector of '%s': %s\n", word, vec); + } + } + + @Test + public void test07GetArrayVector() throws Exception { + try (InputStream is = new FileInputStream("src/test/resources/models/supervised.model.bin")) { + JFastText jft = new JFastText(is); + String word = "soccer"; + float[] vec = jft.getArrayVector(word); + System.out.printf("\nWord embedding vector of '%s': %s\n", word, Arrays.toString(vec)); + } + } + + @Test + public void test08GetSentenceVector() throws Exception { + JFastText jft = new JFastText(); + jft.loadModel("src/test/resources/models/supervised.model.bin"); + String word = "soccers"; + List vec = jft.getSentenceVector(word); + int expectedSize = 100; + assertEquals(expectedSize, vec.size()); + } + + @Test + public void test08GetArraySentenceVector() throws Exception { JFastText jft = new JFastText(); jft.loadModel("src/test/resources/models/supervised.model.bin"); - String word = "soccer"; - List vec = jft.getVector(word); - System.out.printf("\nWord embedding vector of '%s': %s\n", word, vec); + String word = "soccers"; + float[] vec = jft.getArraySentenceVector(word); + int expectedSize = 100; + assertEquals(expectedSize, vec.length); } /** * Test retrieving model's information: words, labels, learning rate, etc. 
*/ @Test - public void test08ModelInfo() throws Exception { + public void test09ModelInfo() throws Exception { System.out.printf("\nSupervised model information:\n"); - JFastText jft = new JFastText(); - jft.loadModel("src/test/resources/models/supervised.model.bin"); + JFastText jft = new JFastText("src/test/resources/models/supervised.model.bin"); System.out.printf("\tnumber of words = %d\n", jft.getNWords()); System.out.printf("\twords = %s\n", jft.getWords()); System.out.printf("\tlearning rate = %g\n", jft.getLr()); @@ -113,11 +162,23 @@ public void test08ModelInfo() throws Exception { * allocated by native function calls). */ @Test - public void test09ModelUnloading() throws Exception { + public void test10ModelUnloading() throws Exception { JFastText jft = new JFastText(); System.out.println("\nLoading model ..."); jft.loadModel("src/test/resources/models/supervised.model.bin"); System.out.println("Unloading model ..."); jft.unloadModel(); } + + /** + * Loads model from specified URL (resource, web, etc.) + */ + @Test + public void test10ModelFromURL() throws Exception { + String modelFile = "src/test/resources/models/supervised.model.bin"; + URL modelUrl = new File(modelFile).toURI().toURL(); + assertNotNull(String.format("Failed to locate model '%s'", modelFile), modelUrl); + JFastText jft = new JFastText(modelUrl); + System.out.printf("\tnumber of words = %d\n", jft.getNWords()); + } }