Skip to content

Commit aae4c56

Browse files
committed
Minor changes
1 parent bfb499f commit aae4c56

File tree

5 files changed

+119
-67
lines changed

5 files changed

+119
-67
lines changed

eclipse/io.github.mzattera.v4j/src/main/java/io/github/mzattera/v4j/text/alphabet/SlotAlphabet.java

Lines changed: 0 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -179,15 +179,6 @@ public char[] getRegularChars() {
179179
p = new ArrayList<String>();
180180
p.add("s");
181181
p.add("d");
182-
// p.add("K");
183-
// p.add("k");
184-
// p.add("T");
185-
// p.add("t");
186-
// p.add("P");
187-
// p.add("p");
188-
// p.add("F");
189-
// p.add("f");
190-
// p.add("x");
191182
SLOTS.add(p);
192183

193184
//// 8 //////
@@ -198,15 +189,9 @@ public char[] getRegularChars() {
198189

199190
//// 9 //////
200191
p = new ArrayList<String>();
201-
202192
p.add("i");
203193
p.add("J");
204194
p.add("U");
205-
206-
// p.add("V");
207-
// p.add("W");
208-
// p.add("Z");
209-
210195
SLOTS.add(p);
211196

212197
//// 10 //////
@@ -221,7 +206,6 @@ public char[] getRegularChars() {
221206
//// 11 //////
222207
p = new ArrayList<String>();
223208
p.add("y");
224-
// p.add("g");
225209
SLOTS.add(p);
226210
}
227211

@@ -578,15 +562,6 @@ private static TermDecomposition internalDecompose(String term) {
578562
pushRight(result.slots1, "d", 0, 7);
579563
pushRight(result.slots1, "s", 0, 7);
580564

581-
// pushRight(result.slots1, "t", 3, 7);
582-
// pushRight(result.slots1, "k", 3, 7);
583-
// pushRight(result.slots1, "p", 3, 7);
584-
// pushRight(result.slots1, "f", 3, 7);
585-
// pushRight(result.slots1, "T", 5, 7);
586-
// pushRight(result.slots1, "K", 5, 7);
587-
// pushRight(result.slots1, "P", 5, 7);
588-
// pushRight(result.slots1, "F", 5, 7);
589-
590565
result.part2 = s;
591566
if (s == null) {
592567
result.classification = TermClassification.REGULAR;

eclipse/io.github.mzattera.v4j/src/main/java/io/github/mzattera/v4j/text/ivtff/LineFilter.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ public class LineFilter implements ElementFilter<IvtffLine> {
2222
private final String transcriber;
2323

2424
/**
25-
* Pre-made filter to return "paragraph" text; that is terxt contained in "P0"
25+
* Pre-made filter to return "paragraph" text; that is text contained in "P0"
2626
* or "P1" loci.
2727
*/
2828
public static final ElementFilter<IvtffLine> PARAGRAPH_TEXT_FILTER = new ElementFilter<IvtffLine>() {

eclipse/io.github.mzattera.v4j/src/main/java/io/github/mzattera/v4j/util/FileUtil.java

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,17 +30,31 @@ private FileUtil() {
3030
* Write a string into an ASCII file.
3131
*/
3232
public static void write(String txt, String fileName) throws IOException {
33-
write(txt, fileName, "ASCII");
33+
write(txt, new File(fileName), "ASCII");
34+
}
35+
36+
/**
37+
* Write a string into an ASCII file.
38+
*/
39+
public static void write(String txt, File file) throws IOException {
40+
write(txt, file, "ASCII");
3441
}
3542

3643
/**
3744
* Write a string into a file with given encoding.
3845
*/
3946
public static void write(String txt, String fileName, String encoding) throws IOException {
47+
write(txt, new File(fileName), encoding);
48+
}
49+
50+
/**
51+
* Write a string into a file with given encoding.
52+
*/
53+
public static void write(String txt, File file, String encoding) throws IOException {
4054

4155
BufferedWriter out = null;
4256
try {
43-
out = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(fileName), encoding));
57+
out = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(file), encoding));
4458
out.write(txt);
4559
out.flush();
4660
} finally {
@@ -56,17 +70,31 @@ public static void write(String txt, String fileName, String encoding) throws IO
5670
* Write a list of strings into a file. Uses UTF-8 encoding.
5771
*/
5872
public static void write(List<String> txt, String fileName) throws IOException {
59-
write(txt, fileName, "UTF-8");
73+
write(txt, new File(fileName), "UTF-8");
74+
}
75+
76+
/**
77+
* Write a list of strings into a file. Uses UTF-8 encoding.
78+
*/
79+
public static void write(List<String> txt, File file) throws IOException {
80+
write(txt, file, "UTF-8");
6081
}
6182

6283
/**
6384
* Write a list of strings into a file with given encoding.
6485
*/
6586
public static void write(List<String> txt, String fileName, String encoding) throws IOException {
87+
write(txt, new File(fileName), encoding);
88+
}
89+
90+
/**
91+
* Write a list of strings into a file with given encoding.
92+
*/
93+
public static void write(List<String> txt, File file, String encoding) throws IOException {
6694

6795
BufferedWriter out = null;
6896
try {
69-
out = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(fileName), encoding));
97+
out = new BufferedWriter(new OutputStreamWriter(new FileOutputStream(file), encoding));
7098
for (String s : txt) {
7199
out.write(s);
72100
out.newLine();

eclipse/io.github.mzattera.v4j/src/main/java/io/github/mzattera/v4j/util/KerasUtil.java

Lines changed: 58 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@
2121
*/
2222
public final class KerasUtil {
2323

24+
private final static Random rnd = new Random();
25+
2426
private KerasUtil() {
2527
}
2628

@@ -32,7 +34,11 @@ private KerasUtil() {
3234
* index of the element with highest probability according to softmax
3335
* - greedy sampling).
3436
*/
35-
public static long greedy(INDArray a) {
37+
public static int greedy(INDArray a) {
38+
if (a == null || a.rank() != 2 || a.shape()[1] == 0) {
39+
throw new IllegalArgumentException("Input must be a non-null one-dimensional INDArray.");
40+
}
41+
3642
long idx = 0;
3743
double max = 0.0;
3844
for (long i = 0; i < a.shape()[1]; ++i) {
@@ -42,18 +48,16 @@ public static long greedy(INDArray a) {
4248
}
4349
}
4450

45-
return idx;
51+
return (int) idx;
4652
}
4753

48-
private final static Random rnd = new Random(42);
49-
5054
/**
5155
* @param a an array of shape [0,N], reflecting the output of a Keras softmax
5256
* layer.
5357
*
5458
* @return A random index, distributed accordingly to a.
5559
*/
56-
public static long random(INDArray a) {
60+
public static int random(INDArray a) {
5761
return random(a, rnd);
5862
}
5963

@@ -64,33 +68,21 @@ public static long random(INDArray a) {
6468
*
6569
* @return A random index, distributed according to a.
6670
*/
67-
public static long random(INDArray a, Random r) {
68-
69-
// TODO should be using long indexes, not int
70-
double[] p = new double[(int) a.shape()[1]];
71-
if (p.length < 1)
72-
throw new IllegalArgumentException();
71+
public static int random(INDArray a, Random r) {
7372

74-
// Cumulative probabilities
75-
p[0] = a.getDouble(0, 0);
76-
for (int i = 1; i < p.length; ++i) {
77-
p[i] = p[i - 1] + a.getDouble(0, i);
73+
if (a == null || a.rank() != 2 || a.shape()[1] == 0) {
74+
throw new IllegalArgumentException("Input must be a non-null one-dimensional INDArray.");
7875
}
7976

80-
double e = r.nextDouble();
81-
for (int i = 0; i < p.length; ++i)
82-
if (p[i] >= e)
83-
return i;
84-
85-
return p.length - 1; // safeguard for rounding
77+
return random(a.toDoubleVector(), r);
8678
}
8779

8880
/**
8981
* @param a an array reflecting the output of a Keras softmax layer.
9082
*
9183
* @return A random index, distributed accordingly to a.
9284
*/
93-
public static long random(double[] a) {
85+
public static int random(double[] a) {
9486
return random(a, rnd);
9587
}
9688

@@ -101,20 +93,16 @@ public static long random(double[] a) {
10193
*
10294
* @return A random index, distributed according to a.
10395
*/
104-
public static long random(double[] a, Random r) {
105-
106-
// TODO should be using long indexes, not int
107-
double[] p = new double[a.length];
108-
if (p.length < 1)
109-
throw new IllegalArgumentException();
96+
public static int random(double[] a, Random r) {
11097

11198
// Cumulative probabilities
99+
double[] p = new double[a.length];
112100
p[0] = a[0];
113101
for (int i = 1; i < p.length; ++i) {
114102
p[i] = p[i - 1] + a[i];
115103
}
116104

117-
double e = r.nextDouble();
105+
double e = rnd.nextDouble();
118106
for (int i = 0; i < p.length; ++i)
119107
if (p[i] >= e)
120108
return i;
@@ -156,6 +144,10 @@ public static long topK(INDArray a, int k) {
156144
*/
157145
public static long topK(INDArray a, int k, Random r) {
158146

147+
if (a == null || a.rank() != 2 || a.shape()[1] == 0) {
148+
throw new IllegalArgumentException("Input must be a non-null one-dimensional INDArray.");
149+
}
150+
159151
// put a in a sorted list
160152
Map<Long, Double> d = new HashMap<>();
161153
for (long i = 0; i < a.shape()[1]; ++i) {
@@ -174,7 +166,7 @@ public static long topK(INDArray a, int k, Random r) {
174166

175167
// normalise and random sample
176168
// TODO this should work with long
177-
int idx = (int) random(normalize(a2));
169+
int idx = random(normalize(a2));
178170

179171
// re-translate into an index on the original array
180172
return list.get(idx).getKey();
@@ -200,6 +192,10 @@ public static long nucleus(INDArray a, double p) {
200192
*/
201193
public static long nucleus(INDArray a, double p, Random r) {
202194

195+
if (a == null || a.rank() != 2 || a.shape()[1] == 0) {
196+
throw new IllegalArgumentException("Input must be a non-null one-dimensional INDArray.");
197+
}
198+
203199
// put a in a sorted list
204200
Map<Long, Double> d = new HashMap<>();
205201
for (long i = 0; i < a.shape()[1]; ++i) {
@@ -217,24 +213,51 @@ public static long nucleus(INDArray a, double p, Random r) {
217213
if (s >= p)
218214
break;
219215
}
220-
double[] a2 = new double[aSize+1];
216+
double[] a2 = new double[aSize + 1];
221217
for (int i = 0; i <= aSize; ++i) {
222218
a2[i] = list.get(i).getValue();
223219
}
224220

225221
// normalise and random sample
226222
// TODO this should work with long
227-
int idx = (int) random(normalize(a2));
223+
int idx = random(normalize(a2));
228224

229225
// re-translate into an index on the original array
230226
return list.get(idx).getKey();
231227
}
232228

233229
/**
234-
* @param args
230+
*
231+
* @param a an array of shape [0,N], reflecting the output of a Keras softmax
232+
* layer.
233+
*
234+
* @return An index in a using random sampling with givne temperature t.
235+
*/
236+
public static long temperature(INDArray a, double t) {
237+
238+
return temperature(a, t, rnd);
239+
}
240+
241+
/**
242+
*
243+
* @param a an array of shape [0,N], reflecting the output of a Keras softmax
244+
* layer.
245+
*
246+
* @return An index in a using random sampling with givne temperature t.
235247
*/
236-
public static void main(String[] args) {
237-
// TODO Auto-generated method stub
248+
public static long temperature(INDArray a, double t, Random r) {
238249

250+
if (a == null || a.rank() != 2 || a.shape()[1] == 0) {
251+
throw new IllegalArgumentException("Input must be a non-null one-dimensional INDArray.");
252+
}
253+
254+
// scale probabilities using temperature
255+
double[] d = new double[(int) a.shape()[1]];
256+
for (int i = 0; i < d.length; ++i) {
257+
d[i] = Math.exp(a.getDouble(0, i) / t);
258+
}
259+
260+
// normalise and random sample
261+
return random(normalize(d));
239262
}
240263
}

eclipse/io.github.mzattera.v4j/src/main/java/io/github/mzattera/v4j/util/StringUtil.java

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,9 +22,9 @@
2222
*
2323
*/
2424
public final class StringUtil {
25-
25+
2626
private final static Random RND = new Random();
27-
27+
2828
private StringUtil() {
2929
}
3030

@@ -211,4 +211,30 @@ public int compare(String o1, String o2) {
211211

212212
return result.toString();
213213
}
214+
215+
/**
216+
* @return Levenshtein distance between two strings.
217+
*/
218+
public static int levenshtein(String s1, String s2) {
219+
int[][] distance = new int[s1.length() + 1][s2.length() + 1];
220+
221+
for (int i = 0; i <= s1.length(); i++) {
222+
distance[i][0] = i;
223+
}
224+
for (int j = 1; j <= s2.length(); j++) {
225+
distance[0][j] = j;
226+
}
227+
228+
for (int i = 1; i <= s1.length(); i++) {
229+
for (int j = 1; j <= s2.length(); j++) {
230+
int cost = (s1.charAt(i - 1) == s2.charAt(j - 1)) ? 0 : 1;
231+
232+
distance[i][j] = Math.min(Math.min(distance[i - 1][j] + 1, // deletion
233+
distance[i][j - 1] + 1), // insertion
234+
distance[i - 1][j - 1] + cost); // substitution
235+
}
236+
}
237+
238+
return distance[s1.length()][s2.length()];
239+
}
214240
}

0 commit comments

Comments
 (0)