diff --git a/.travis.yml b/.travis.yml
index 5c71a1e..4673191 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,9 +1,8 @@
language: java
jdk:
- - oraclejdk8
+ - openjdk11
os:
- linux
- - osx
cache: bundler
# Setting install to 'true' to prevent Travis CI from installing depedencies via:
# "mvn install -DskipTests=true -Dmaven.javadoc.skip=true -B -V" which fails due to missing the GPG secret key.
diff --git a/CHANGELOG.md b/CHANGELOG.md
new file mode 100644
index 0000000..4be0e6c
--- /dev/null
+++ b/CHANGELOG.md
@@ -0,0 +1,7 @@
+# Changelog
+
+## Version 0.9.1
+
+* Use [FastText v0.9.1](https://github.com/facebookresearch/fastText/releases/tag/v0.9.1)
+* Minor bug fixes
+
\ No newline at end of file
diff --git a/README.md b/README.md
index d32185a..3eb99f9 100644
--- a/README.md
+++ b/README.md
@@ -1,10 +1,11 @@
-[![Build Status](https://travis-ci.org/vinhkhuc/JFastText.svg?branch=master)](https://travis-ci.org/vinhkhuc/JFastText)
+[![Build Status](https://travis-ci.org/carschno/JFastText.svg?branch=master)](https://travis-ci.org/carschno/JFastText)
Table of Contents
=================
* [Introduction](#introduction)
* [Maven Dependency](#maven-dependency)
+ * [Windows and Mac OS X](#windows-and-mac-os-x)
* [Building](#building)
* [Quick Application - Language Identification](#quick-application-\--language-identification)
* [Detailed Examples](#detailed-examples)
@@ -12,6 +13,7 @@ Table of Contents
* [FastText's Command Line](#fasttexts-command-line)
* [License](#license)
* [References](#references)
+ * [Changelog](CHANGELOG.md)
## Introduction
@@ -28,23 +30,39 @@ JFastText is ideal for building fast text classifiers in Java.
## Maven Dependency
```xml
- com.github.vinhkhuc
+ io.github.carschno
jfasttext
- 0.4
+ 0.9.1
```
-The Jar package on Maven Central is bundled with precompiled fastText library for Windows, Linux and
-MacOSX 64bit.
+The Jar package on Maven Central is bundled with a precompiled fastText library for ~~Windows,~~ Linux
+~~and MacOSX~~ 64-bit.
+
+### Windows and Mac OS X
+
+Currently, the Maven dependency only contains binaries for Linux (64 bit), _not_ for Windows or Mac OS X.
+In order to use JFastText for Windows or Mac OS X (or any other system), you need to build it yourself (see [below](#building)).
## Building
-C++ compiler (g++ on Mac/Linux or cl.exe on Windows) is required to compile fastText's code.
+A C++ compiler (g++ on Mac/Linux or `cl.exe` on Windows) is required to compile fastText's code.
```bash
-git clone --recursive https://github.com/vinhkhuc/JFastText
+git clone --recursive https://github.com/carschno/JFastText
cd JFastText
+git submodule init
+git submodule update
mvn package
```
+### Building on Windows
+
+The (automatic) build seems to fail on some Windows systems/C++ compilers.
+See [this issue](https://github.com/carschno/JFastText/issues/5#issuecomment-546485377):
+
+> I used MS's developer tools, not the full-blown Visual Studio. If I run `cl` directly, the compilation fails with the same error.
+>
+> I was able to build on Windows by changing the call to `cl.exe` and running it outside the Maven build. I changed one parameter in the call to `cl`: I use `/MT` (whereas Maven uses `/MD`). Bundling the generated DLLs works fine.
+
## Quick Application - Language Identification
JFastText can use FastText's pretrained models directly. Language identification models can be downloaded [here](https://fasttext.cc/docs/en/language-identification.html).
In this quick example, we will use the [quantized model](https://s3-us-west-1.amazonaws.com/fasttext-vectors/supervised_models/lid.176.ftz)
diff --git a/examples/api/pom.xml b/examples/api/pom.xml
index c2002ae..0ca4a9d 100644
--- a/examples/api/pom.xml
+++ b/examples/api/pom.xml
@@ -7,12 +7,15 @@
com.github.vinhkhuc
java_sandbox
0.1
-
+
+ 1.8
+ 1.8
+
- com.github.vinhkhuc
+ io.github.carschno
jfasttext
- 0.4
+ 0.5.0
\ No newline at end of file
diff --git a/pom.xml b/pom.xml
index 3899d17..01b5822 100644
--- a/pom.xml
+++ b/pom.xml
@@ -4,9 +4,9 @@
xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
4.0.0
- com.github.vinhkhuc
+ io.github.carschno
jfasttext
- 0.4
+ 0.9.1
Java interface for fastText
JFastText is a Java interface for fastText, a library for efficient learning of
@@ -22,6 +22,15 @@
+
+ carschno
+ Carsten Schnober
+ http://github.com/carschno
+
+ developer
+ maintainer
+
+
vinhkhuc
Vinh Khuc
@@ -34,13 +43,14 @@
- scm:git:https://github.com/vinhkhuc/JFastText.git
- scm:git:git@github.com:vinhkhuc/JFastText.git
- https://github.com/vinhkhuc/JFastText
+ scm:git:https://github.com/carschno/JFastText.git
+ scm:git:git@github.com:carschno/JFastText.git
+ https://github.com/carschno/JFastText
UTF-8
+ 1.5.1
@@ -65,21 +75,42 @@
-Xdoclint:none
-
-
-
-
-
+
+ release
+
+
+
org.sonatype.plugins
nexus-staging-maven-plugin
1.6.3
true
- ossrh
- https://oss.sonatype.org/
- true
+ ossrh
+ https://oss.sonatype.org/
+ false
-
+
+
+ org.apache.maven.plugins
+ maven-gpg-plugin
+ 1.5
+
+
+ sign-artifacts
+ verify
+
+ sign
+
+
+
+
+
+
+
+
+
+
+
org.apache.maven.plugins
maven-compiler-plugin
@@ -110,6 +141,9 @@
org.apache.maven.plugins
maven-javadoc-plugin
2.9.1
+
+
+
attach-javadocs
@@ -119,20 +153,6 @@
-
- org.apache.maven.plugins
- maven-gpg-plugin
- 1.5
-
-
- sign-artifacts
- verify
-
- sign
-
-
-
-
org.apache.maven.plugins
maven-enforcer-plugin
@@ -155,7 +175,7 @@
org.bytedeco
javacpp
- 1.3.1
+ ${javacpp.version}
run-javacpp-parser
@@ -234,8 +254,8 @@
org.bytedeco
javacpp
- 1.3.1
+ ${javacpp.version}
-
\ No newline at end of file
+
diff --git a/src/main/cpp/fastText b/src/main/cpp/fastText
index 9c9b9b2..cda295f 160000
--- a/src/main/cpp/fastText
+++ b/src/main/cpp/fastText
@@ -1 +1 @@
-Subproject commit 9c9b9b2eca4fdc0182182cb69234fffbb847d820
+Subproject commit cda295f1b5851df0a26a6ac2ab04230fb864a89d
diff --git a/src/main/cpp/fasttext_wrapper.cc b/src/main/cpp/fasttext_wrapper.cc
index ec5c90e..b2362af 100644
--- a/src/main/cpp/fasttext_wrapper.cc
+++ b/src/main/cpp/fasttext_wrapper.cc
@@ -74,13 +74,20 @@ namespace FastTextWrapper {
const std::string& text, int32_t k) {
std::vector> predictions;
std::istringstream in(text);
- fastText.predict(in, k, predictions);
+ fastText.predictLine(in, predictions, k, 0.0);
return predictions;
}
std::vector FastTextApi::getVector(const std::string& word) {
Vector vec(privateMembers->args_->dim);
- fastText.getVector(vec, word);
+ fastText.getWordVector(vec, word);
+ return std::vector(vec.data(), vec.data() + vec.size());
+ }
+
+ std::vector FastTextApi::getSentenceVector(const std::string& sentence) {
+ Vector vec(privateMembers->args_->dim);
+ std::istringstream in(sentence);
+ fastText.getSentenceVector(in, vec);
return std::vector(vec.data(), vec.data() + vec.size());
}
diff --git a/src/main/cpp/fasttext_wrapper.h b/src/main/cpp/fasttext_wrapper.h
index 10562ef..68758ec 100644
--- a/src/main/cpp/fasttext_wrapper.h
+++ b/src/main/cpp/fasttext_wrapper.h
@@ -29,6 +29,7 @@ namespace FastTextWrapper {
std::vector predict(const std::string&, int32_t);
std::vector> predictProba(const std::string&, int32_t);
std::vector getVector(const std::string&);
+ std::vector getSentenceVector(const std::string&);
std::vector getWords();
std::vector getLabels();
std::string getWord(int32_t);
diff --git a/src/main/cpp/fasttext_wrapper_javacpp.h b/src/main/cpp/fasttext_wrapper_javacpp.h
index 3991d2e..3666f44 100644
--- a/src/main/cpp/fasttext_wrapper_javacpp.h
+++ b/src/main/cpp/fasttext_wrapper_javacpp.h
@@ -1,12 +1,15 @@
// Added since VS 14.0 complains about missing std::iota
#include
#include "fastText/src/args.cc"
+#include "fastText/src/densematrix.cc"
#include "fastText/src/dictionary.cc"
#include "fastText/src/fasttext.cc"
+#include "fastText/src/loss.cc"
#include "fastText/src/matrix.cc"
+#include "fastText/src/meter.cc"
#include "fastText/src/model.cc"
#include "fastText/src/productquantizer.cc"
-#include "fastText/src/qmatrix.cc"
+#include "fastText/src/quantmatrix.cc"
#include "fastText/src/vector.cc"
#include "fastText/src/utils.cc"
diff --git a/src/main/java/com/github/jfasttext/FastTextWrapper.java b/src/main/java/com/github/jfasttext/FastTextWrapper.java
index c583d03..5fe8295 100644
--- a/src/main/java/com/github/jfasttext/FastTextWrapper.java
+++ b/src/main/java/com/github/jfasttext/FastTextWrapper.java
@@ -1,4 +1,4 @@
-// Targeted by JavaCPP version 1.3.1: DO NOT EDIT THIS FILE
+// Targeted by JavaCPP version 1.5.1: DO NOT EDIT THIS FILE
package com.github.jfasttext;
@@ -13,7 +13,9 @@ public class FastTextWrapper extends com.github.jfasttext.config.FastTextWrapper
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public StringVector(Pointer p) { super(p); }
+ public StringVector(BytePointer value) { this(1); put(0, value); }
public StringVector(BytePointer ... array) { this(array.length); put(array); }
+ public StringVector(String value) { this(1); put(0, value); }
public StringVector(String ... array) { this(array.length); put(array); }
public StringVector() { allocate(); }
public StringVector(long n) { allocate(n); }
@@ -21,13 +23,54 @@ public class FastTextWrapper extends com.github.jfasttext.config.FastTextWrapper
private native void allocate(@Cast("size_t") long n);
public native @Name("operator=") @ByRef StringVector put(@ByRef StringVector x);
+ public boolean empty() { return size() == 0; }
public native long size();
+ public void clear() { resize(0); }
public native void resize(@Cast("size_t") long n);
- @Index public native @StdString BytePointer get(@Cast("size_t") long i);
+ @Index(function = "at") public native @StdString BytePointer get(@Cast("size_t") long i);
public native StringVector put(@Cast("size_t") long i, BytePointer value);
- @ValueSetter @Index public native StringVector put(@Cast("size_t") long i, @StdString String value);
+ @ValueSetter @Index(function = "at") public native StringVector put(@Cast("size_t") long i, @StdString String value);
+
+ public native @ByVal Iterator insert(@ByVal Iterator pos, @StdString BytePointer value);
+ public native @ByVal Iterator erase(@ByVal Iterator pos);
+ public native @ByVal Iterator begin();
+ public native @ByVal Iterator end();
+ @NoOffset @Name("iterator") public static class Iterator extends Pointer {
+ public Iterator(Pointer p) { super(p); }
+ public Iterator() { }
+
+ public native @Name("operator++") @ByRef Iterator increment();
+ public native @Name("operator==") boolean equals(@ByRef Iterator it);
+ public native @Name("operator*") @StdString BytePointer get();
+ }
+
+ public BytePointer[] get() {
+ BytePointer[] array = new BytePointer[size() < Integer.MAX_VALUE ? (int)size() : Integer.MAX_VALUE];
+ for (int i = 0; i < array.length; i++) {
+ array[i] = get(i);
+ }
+ return array;
+ }
+ @Override public String toString() {
+ return java.util.Arrays.toString(get());
+ }
+ public BytePointer pop_back() {
+ long size = size();
+ BytePointer value = get(size - 1);
+ resize(size - 1);
+ return value;
+ }
+ public StringVector push_back(BytePointer value) {
+ long size = size();
+ resize(size + 1);
+ return put(size, value);
+ }
+ public StringVector put(BytePointer value) {
+ if (size() != 1) { resize(1); }
+ return put(0, value);
+ }
public StringVector put(BytePointer ... array) {
if (size() != array.length) { resize(array.length); }
for (int i = 0; i < array.length; i++) {
@@ -36,6 +79,15 @@ public StringVector put(BytePointer ... array) {
return this;
}
+ public StringVector push_back(String value) {
+ long size = size();
+ resize(size + 1);
+ return put(size, value);
+ }
+ public StringVector put(String value) {
+ if (size() != 1) { resize(1); }
+ return put(0, value);
+ }
public StringVector put(String ... array) {
if (size() != array.length) { resize(array.length); }
for (int i = 0; i < array.length; i++) {
@@ -49,6 +101,7 @@ public StringVector put(String ... array) {
static { Loader.load(); }
/** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */
public RealVector(Pointer p) { super(p); }
+ public RealVector(float value) { this(1); put(0, value); }
public RealVector(float ... array) { this(array.length); put(array); }
public RealVector() { allocate(); }
public RealVector(long n) { allocate(n); }
@@ -56,12 +109,53 @@ public StringVector put(String ... array) {
private native void allocate(@Cast("size_t") long n);
public native @Name("operator=") @ByRef RealVector put(@ByRef RealVector x);
+ public boolean empty() { return size() == 0; }
public native long size();
+ public void clear() { resize(0); }
public native void resize(@Cast("size_t") long n);
- @Index public native @ByRef float get(@Cast("size_t") long i);
+ @Index(function = "at") public native @ByRef float get(@Cast("size_t") long i);
public native RealVector put(@Cast("size_t") long i, float value);
+ public native @ByVal Iterator insert(@ByVal Iterator pos, @ByRef float value);
+ public native @ByVal Iterator erase(@ByVal Iterator pos);
+ public native @ByVal Iterator begin();
+ public native @ByVal Iterator end();
+ @NoOffset @Name("iterator") public static class Iterator extends Pointer {
+ public Iterator(Pointer p) { super(p); }
+ public Iterator() { }
+
+ public native @Name("operator++") @ByRef Iterator increment();
+ public native @Name("operator==") boolean equals(@ByRef Iterator it);
+ public native @Name("operator*") @ByRef @Const float get();
+ }
+
+ public float[] get() {
+ float[] array = new float[size() < Integer.MAX_VALUE ? (int)size() : Integer.MAX_VALUE];
+ for (int i = 0; i < array.length; i++) {
+ array[i] = get(i);
+ }
+ return array;
+ }
+ @Override public String toString() {
+ return java.util.Arrays.toString(get());
+ }
+
+ public float pop_back() {
+ long size = size();
+ float value = get(size - 1);
+ resize(size - 1);
+ return value;
+ }
+ public RealVector push_back(float value) {
+ long size = size();
+ resize(size + 1);
+ return put(size, value);
+ }
+ public RealVector put(float value) {
+ if (size() != 1) { resize(1); }
+ return put(0, value);
+ }
public RealVector put(float ... array) {
if (size() != array.length) { resize(array.length); }
for (int i = 0; i < array.length; i++) {
@@ -83,12 +177,14 @@ public RealVector put(float ... array) {
private native void allocate(@Cast("size_t") long n);
public native @Name("operator=") @ByRef FloatStringPairVector put(@ByRef FloatStringPairVector x);
+ public boolean empty() { return size() == 0; }
public native long size();
+ public void clear() { resize(0); }
public native void resize(@Cast("size_t") long n);
- @Index public native @ByRef float first(@Cast("size_t") long i); public native FloatStringPairVector first(@Cast("size_t") long i, float first);
- @Index public native @StdString BytePointer second(@Cast("size_t") long i); public native FloatStringPairVector second(@Cast("size_t") long i, BytePointer second);
- @MemberSetter @Index public native FloatStringPairVector second(@Cast("size_t") long i, @StdString String second);
+ @Index(function = "at") public native @ByRef float first(@Cast("size_t") long i); public native FloatStringPairVector first(@Cast("size_t") long i, float first);
+ @Index(function = "at") public native @StdString BytePointer second(@Cast("size_t") long i); public native FloatStringPairVector second(@Cast("size_t") long i, BytePointer second);
+ @MemberSetter @Index(function = "at") public native FloatStringPairVector second(@Cast("size_t") long i, @StdString String second);
public FloatStringPairVector put(float[] firstValue, BytePointer[] secondValue) {
for (int i = 0; i < firstValue.length && i < secondValue.length; i++) {
@@ -201,6 +297,8 @@ public DoubleIntPair put(float firstValue, int secondValue) {
public native @ByVal FloatStringPairVector predictProba(@StdString String arg0, int arg1);
public native @ByVal RealVector getVector(@StdString BytePointer arg0);
public native @ByVal RealVector getVector(@StdString String arg0);
+ public native @ByVal RealVector getSentenceVector(@StdString BytePointer arg0);
+ public native @ByVal RealVector getSentenceVector(@StdString String arg0);
public native @ByVal StringVector getWords();
public native @ByVal StringVector getLabels();
public native @StdString BytePointer getWord(int arg0);
@@ -235,12 +333,15 @@ public DoubleIntPair put(float firstValue, int secondValue) {
// Added since VS 14.0 complains about missing std::iota
// #include
// #include "fastText/src/args.cc"
+// #include "fastText/src/densematrix.cc"
// #include "fastText/src/dictionary.cc"
// #include "fastText/src/fasttext.cc"
+// #include "fastText/src/loss.cc"
// #include "fastText/src/matrix.cc"
+// #include "fastText/src/meter.cc"
// #include "fastText/src/model.cc"
// #include "fastText/src/productquantizer.cc"
-// #include "fastText/src/qmatrix.cc"
+// #include "fastText/src/quantmatrix.cc"
// #include "fastText/src/vector.cc"
// #include "fastText/src/utils.cc"
diff --git a/src/main/java/com/github/jfasttext/JFastText.java b/src/main/java/com/github/jfasttext/JFastText.java
index 1f2d333..f728f78 100644
--- a/src/main/java/com/github/jfasttext/JFastText.java
+++ b/src/main/java/com/github/jfasttext/JFastText.java
@@ -3,15 +3,40 @@
import org.bytedeco.javacpp.PointerPointer;
import java.io.File;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.net.MalformedURLException;
+import java.net.URI;
+import java.net.URL;
+import java.nio.file.CopyOption;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.StandardCopyOption;
import java.util.ArrayList;
import java.util.List;
public class JFastText {
- private FastTextWrapper.FastTextApi fta;
+ private FastTextWrapper.FastTextApi fta = new FastTextWrapper.FastTextApi(); // JavaCPP binding to the native fastText API
public JFastText() {
- fta = new FastTextWrapper.FastTextApi();
+ }
+
+ public JFastText(final String modelFile) {
+ loadModel(modelFile);
+ }
+
+ public JFastText(final URI modelUri) throws IOException {
+ loadModel(modelUri);
+ }
+
+ public JFastText(final URL modelUrl) throws IOException {
+ loadModel(modelUrl);
+ }
+
+ public JFastText(final InputStream modelStream) throws IOException {
+ loadModel(modelStream);
}
public void runCmd(String[] args) {
@@ -22,9 +47,9 @@ public void runCmd(String[] args) {
fta.runCmd(cArgs.length, new PointerPointer(cArgs));
}
- public void loadModel(String modelFile) {
+ public void loadModel(String modelFile) throws IllegalArgumentException {
if (!new File(modelFile).exists()) {
- throw new IllegalArgumentException("Model file doesn't exist!");
+ throw new IllegalArgumentException("Model file " + modelFile + " doesn't exist!");
}
if (!fta.checkModel(modelFile)) {
throw new IllegalArgumentException(
@@ -33,6 +58,39 @@ public void loadModel(String modelFile) {
fta.loadModel(modelFile);
}
+ /**
+ * Loads a model from the location specified by a URI by copying its content into a temporary local file and then loading that file.
+ *
+ * @param modelUri location of the model
+ */
+ public void loadModel(URI modelUri) throws IOException {
+ loadModel(modelUri.toURL());
+ }
+
+ /**
+  * Loads a model from a URL by copying its content into a temporary local file and loading that file; the stream opened on the URL is always closed.
+  * @param modelUrl location of the model
+  * @throws IOException if the URL cannot be opened or its content cannot be copied
+  */
+ public void loadModel(URL modelUrl) throws IOException {
+     try (InputStream modelStream = modelUrl.openStream()) { loadModel(modelStream); } // close the stream even when loading fails
+ }
+
+ /**
+ * Loads a model from the given InputStream by copying its content into a temporary local file and then loading that file; the temporary file is deleted afterwards.
+ *
+ * @param modelStream stream for model
+ */
+ public void loadModel(InputStream modelStream) throws IOException {
+ Path tmpFile = Files.createTempFile("jft-", ".model");
+ try {
+ Files.copy(modelStream, tmpFile, StandardCopyOption.REPLACE_EXISTING);
+ loadModel(tmpFile.toString());
+ } finally {
+ Files.deleteIfExists(tmpFile);
+ }
+ }
+
public void unloadModel() {
fta.unloadModel();
}
@@ -84,15 +142,39 @@ public List predictProba(String text, int k) {
return probaPredictions;
}
+ @Deprecated
public List getVector(String word) {
- FastTextWrapper.RealVector rv = fta.getVector(word);
+ float[] vector = getArrayVector(word);
List wordVec = new ArrayList<>();
- for (int i = 0; i < rv.size(); i++) {
- wordVec.add(rv.get(i));
+ for (float f : vector) {
+ wordVec.add(f);
}
return wordVec;
}
+ public float[] getArrayVector(String word) {
+ FastTextWrapper.RealVector rv = fta.getVector(word);
+ return rv.get();
+ }
+
+ @Deprecated
+ public List getSentenceVector(String sentence) {
+ float[] vector = getArraySentenceVector(sentence);
+ List sentenceVec = new ArrayList<>();
+ for (float f : vector) {
+ sentenceVec.add(f);
+ }
+ return sentenceVec;
+ }
+
+ public float[] getArraySentenceVector(String sentence) {
+ if (!sentence.endsWith("\n")) {
+ sentence += "\n";
+ }
+ FastTextWrapper.RealVector rv = fta.getSentenceVector(sentence);
+ return rv.get();
+ }
+
public int getNWords() {
return fta.getNWords();
}
diff --git a/src/test/java/com/github/jfasttext/JFastTextTest.java b/src/test/java/com/github/jfasttext/JFastTextTest.java
index b6f1521..1fd7ffc 100644
--- a/src/test/java/com/github/jfasttext/JFastTextTest.java
+++ b/src/test/java/com/github/jfasttext/JFastTextTest.java
@@ -4,8 +4,17 @@
import org.junit.Test;
import org.junit.runners.MethodSorters;
+import java.io.File;
+import java.io.FileInputStream;
+import java.io.InputStream;
+import java.net.URL;
+import java.util.Arrays;
import java.util.List;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+
@FixMethodOrder(MethodSorters.NAME_ASCENDING)
public class JFastTextTest {
@@ -16,7 +25,9 @@ public void test01TrainSupervisedCmd() {
jft.runCmd(new String[] {
"supervised",
"-input", "src/test/resources/data/labeled_data.txt",
- "-output", "src/test/resources/models/supervised.model"
+ "-output", "src/test/resources/models/supervised.model",
+ "-wordNgrams", "3",
+ "-bucket", "100"
});
}
@@ -56,9 +67,18 @@ public void test04Predict() throws Exception {
}
@Test
- public void test05PredictProba() throws Exception {
+ public void test04getArrayVector() throws Exception {
JFastText jft = new JFastText();
jft.loadModel("src/test/resources/models/supervised.model.bin");
+     String text = "I like soccer";
+     float[] predictedArray = jft.getArrayVector(text);
+     float[] expected = new float[100]; // Java zero-initialises the array, so this is 100 zeros
+     assertArrayEquals("getArrayVector result differs from the expected vector", expected, predictedArray, 0.1f); // JUnit order: message, expected, actual, delta
+ }
+
+ @Test
+ public void test05PredictProba() throws Exception {
+ JFastText jft = new JFastText("src/test/resources/models/supervised.model.bin");
String text = "What is the most popular sport in the US ?";
JFastText.ProbLabel predictedProbLabel = jft.predictProba(text);
System.out.printf("\nText: '%s', label: '%s', probability: %f\n",
@@ -67,8 +87,7 @@ public void test05PredictProba() throws Exception {
@Test
public void test06MultiPredictProba() throws Exception {
- JFastText jft = new JFastText();
- jft.loadModel("src/test/resources/models/supervised.model.bin");
+ JFastText jft = new JFastText("src/test/resources/models/supervised.model.bin");
String text = "Do you like soccer ?";
System.out.printf("Text: '%s'\n", text);
for (JFastText.ProbLabel predictedProbLabel: jft.predictProba(text, 2)) {
@@ -79,21 +98,51 @@ public void test06MultiPredictProba() throws Exception {
@Test
public void test07GetVector() throws Exception {
+ try (InputStream is = new FileInputStream("src/test/resources/models/supervised.model.bin")) {
+ JFastText jft = new JFastText(is);
+ String word = "soccer";
+ List vec = jft.getVector(word);
+ System.out.printf("\nWord embedding vector of '%s': %s\n", word, vec);
+ }
+ }
+
+ @Test
+ public void test07GetArrayVector() throws Exception {
+ try (InputStream is = new FileInputStream("src/test/resources/models/supervised.model.bin")) {
+ JFastText jft = new JFastText(is);
+ String word = "soccer";
+ float[] vec = jft.getArrayVector(word);
+ System.out.printf("\nWord embedding vector of '%s': %s\n", word, Arrays.toString(vec));
+ }
+ }
+
+ @Test
+ public void test08GetSentenceVector() throws Exception {
+ JFastText jft = new JFastText();
+ jft.loadModel("src/test/resources/models/supervised.model.bin");
+ String word = "soccers";
+ List vec = jft.getSentenceVector(word);
+ int expectedSize = 100;
+ assertEquals(expectedSize, vec.size());
+ }
+
+ @Test
+ public void test08GetArraySentenceVector() throws Exception {
JFastText jft = new JFastText();
jft.loadModel("src/test/resources/models/supervised.model.bin");
- String word = "soccer";
- List vec = jft.getVector(word);
- System.out.printf("\nWord embedding vector of '%s': %s\n", word, vec);
+ String word = "soccers";
+ float[] vec = jft.getArraySentenceVector(word);
+ int expectedSize = 100;
+ assertEquals(expectedSize, vec.length);
}
/**
* Test retrieving model's information: words, labels, learning rate, etc.
*/
@Test
- public void test08ModelInfo() throws Exception {
+ public void test09ModelInfo() throws Exception {
System.out.printf("\nSupervised model information:\n");
- JFastText jft = new JFastText();
- jft.loadModel("src/test/resources/models/supervised.model.bin");
+ JFastText jft = new JFastText("src/test/resources/models/supervised.model.bin");
System.out.printf("\tnumber of words = %d\n", jft.getNWords());
System.out.printf("\twords = %s\n", jft.getWords());
System.out.printf("\tlearning rate = %g\n", jft.getLr());
@@ -113,11 +162,23 @@ public void test08ModelInfo() throws Exception {
* allocated by native function calls).
*/
@Test
- public void test09ModelUnloading() throws Exception {
+ public void test10ModelUnloading() throws Exception {
JFastText jft = new JFastText();
System.out.println("\nLoading model ...");
jft.loadModel("src/test/resources/models/supervised.model.bin");
System.out.println("Unloading model ...");
jft.unloadModel();
}
+
+ /**
+ * Loads model from specified URL (resource, web, etc.)
+ */
+ @Test
+ public void test10ModelFromURL() throws Exception {
+ String modelFile = "src/test/resources/models/supervised.model.bin";
+ URL modelUrl = new File(modelFile).toURI().toURL();
+ assertNotNull(String.format("Failed to locate model '%s'", modelFile), modelUrl);
+ JFastText jft = new JFastText(modelUrl);
+ System.out.printf("\tnumber of words = %d\n", jft.getNWords());
+ }
}