From 55a3d9bb2cbc6e71f4cc61058e744b065c6cd948 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Tue, 31 Mar 2020 11:56:56 +1100 Subject: [PATCH] Fix loading both model and serializer at once from stream + re-add checks Signed-off-by: Alex Black --- .../deeplearning4j/util/ModelSerializer.java | 37 +++++++++++++------ 1 file changed, 26 insertions(+), 11 deletions(-) diff --git a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java index 997303977..30dbe7b4b 100644 --- a/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java +++ b/deeplearning4j/deeplearning4j-nn/src/main/java/org/deeplearning4j/util/ModelSerializer.java @@ -16,6 +16,7 @@ package org.deeplearning4j.util; +import org.apache.commons.io.input.CloseShieldInputStream; import org.nd4j.shade.guava.io.Files; import lombok.NonNull; import lombok.extern.slf4j.Slf4j; @@ -235,6 +236,11 @@ public class ModelSerializer { */ public static MultiLayerNetwork restoreMultiLayerNetwork(@NonNull InputStream is, boolean loadUpdater) throws IOException { + return restoreMultiLayerNetworkHelper(is, loadUpdater).getFirst(); + } + + private static Pair> restoreMultiLayerNetworkHelper(@NonNull InputStream is, boolean loadUpdater) + throws IOException { checkInputStream(is); Map zipFile = loadZipData(is); @@ -343,7 +349,7 @@ public class ModelSerializer { if (gotUpdaterState && updaterState != null) { network.getUpdater().setStateViewArray(network, updaterState, false); } - return network; + return new Pair<>(network, zipFile); } else throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]"); @@ -399,9 +405,11 @@ public class ModelSerializer { public static Pair restoreMultiLayerNetworkAndNormalizer( @NonNull InputStream is, boolean loadUpdater) throws IOException { checkInputStream(is); + is = new CloseShieldInputStream(is); - MultiLayerNetwork net = restoreMultiLayerNetwork(is, loadUpdater); - Normalizer norm = restoreNormalizerFromInputStream(is); + Pair> p = restoreMultiLayerNetworkHelper(is, loadUpdater); + MultiLayerNetwork net = p.getFirst(); + Normalizer norm = restoreNormalizerFromMap(p.getSecond()); return new Pair<>(net, norm); } @@ -453,6 +461,11 @@ public class ModelSerializer { */ public static ComputationGraph restoreComputationGraph(@NonNull InputStream is, boolean loadUpdater) throws IOException { + return restoreComputationGraphHelper(is, loadUpdater).getFirst(); + } + + private static Pair> restoreComputationGraphHelper(@NonNull InputStream is, boolean loadUpdater) + throws IOException { checkInputStream(is); Map files = loadZipData(is); @@ -564,7 +577,7 @@ public class ModelSerializer { if (gotUpdaterState && updaterState != null) { cg.getUpdater().setStateViewArray(updaterState); } - return cg; + return new Pair<>(cg, files); } else throw new IllegalStateException("Model wasnt found within file: gotConfig: [" + gotConfig + "], gotCoefficients: [" + gotCoefficients + "], gotUpdater: [" + gotUpdaterState + "]"); @@ -604,9 +617,11 @@ public class ModelSerializer { public static Pair restoreComputationGraphAndNormalizer( @NonNull InputStream is, boolean loadUpdater) throws IOException { checkInputStream(is); - - ComputationGraph net = restoreComputationGraph(is, loadUpdater); - Normalizer norm = restoreNormalizerFromInputStream(is); + + + Pair> p = restoreComputationGraphHelper(is, loadUpdater); + ComputationGraph net = p.getFirst(); + Normalizer norm = restoreNormalizerFromMap(p.getSecond()); return new Pair<>(net, norm); } @@ -900,8 +915,11 @@ public class ModelSerializer { */ public static T restoreNormalizerFromInputStream(InputStream is) throws IOException { checkInputStream(is); - Map files = loadZipData(is); + return restoreNormalizerFromMap(files); + } + + private static T restoreNormalizerFromMap(Map files) throws IOException { byte[] norm = files.get(NORMALIZER_BIN); // checking for file existence @@ -937,8 +955,6 @@ public class ModelSerializer { private static void checkInputStream(InputStream inputStream) throws IOException { - - /* //available method can return 0 in some cases: https://github.com/deeplearning4j/deeplearning4j/issues/4887 int available; try{ @@ -953,7 +969,6 @@ public class ModelSerializer { throw new IOException("Cannot read from stream: stream may have been closed or is attempting to be read from" + "multiple times?"); } - */ } private static Map loadZipData(InputStream is) throws IOException {