alessandrobenedetti commented on code in PR #1435:
URL: https://github.com/apache/solr/pull/1435#discussion_r1154177094
##########
solr/core/src/java/org/apache/solr/schema/DenseVectorField.java:
##########
@@ -205,53 +257,144 @@ public IndexableField createField(SchemaField field, Object parsedVector) {
* org.apache.solr.handler.loader.CSVLoader} produces an ArrayList of String - {@link
* org.apache.solr.handler.loader.JsonLoader} produces an ArrayList of Double - {@link
* org.apache.solr.handler.loader.JavabinLoader} produces an ArrayList of Float
- *
- * @param inputValue - An {@link ArrayList} containing the elements of the vector
- * @return the vector parsed
*/
- float[] parseVector(Object inputValue) {
- if (!(inputValue instanceof List)) {
- throw new SolrException(
- SolrException.ErrorCode.BAD_REQUEST,
- "incorrect vector format."
- + " The expected format is an array :'[f1,f2..f3]' where each element f is a float");
+ public VectorBuilder getVectorBuilder(Object inputValue) {
+ switch (vectorEncoding) {
+ case FLOAT32:
+ return new VectorBuilder.Float32VectorBuilder(dimension, inputValue);
+ case BYTE:
+ return new VectorBuilder.ByteVectorBuilder(dimension, inputValue);
+ default:
+ throw new SolrException(
+ SolrException.ErrorCode.SERVER_ERROR,
+ "Unexpected state. Vector Encoding: " + vectorEncoding);
}
- List<?> inputVector = (List<?>) inputValue;
- if (inputVector.size() != dimension) {
- throw new SolrException(
- SolrException.ErrorCode.BAD_REQUEST,
- "incorrect vector dimension."
- + " The vector value has size "
- + inputVector.size()
- + " while it is expected a vector with size "
- + dimension);
+ }
+
+ abstract static class VectorBuilder {
+
+ protected int dimension;
+ protected Object inputValue;
+
+ public float[] getFloatVector() {
+ throw new RuntimeException("Not implemented");
}
- float[] vector = new float[dimension];
- if (inputVector.get(0) instanceof CharSequence) {
- for (int i = 0; i < dimension; i++) {
- try {
- vector[i] = Float.parseFloat(inputVector.get(i).toString());
- } catch (NumberFormatException e) {
- throw new SolrException(
- SolrException.ErrorCode.BAD_REQUEST,
- "incorrect vector element: '"
- + inputVector.get(i)
- + "'. The expected format is:'[f1,f2..f3]' where each element f is a float");
+ protected BytesRef getByteVector() {
+ throw new RuntimeException("Not implemented");
+ }
+
+ protected void parseVector() {
+ if (!(inputValue instanceof List)) {
+ throw new SolrException(
+ SolrException.ErrorCode.BAD_REQUEST, "incorrect vector format. " + errorMessage());
+ }
+ List<?> inputVector = (List<?>) inputValue;
+ if (inputVector.size() != dimension) {
+ throw new SolrException(
+ SolrException.ErrorCode.BAD_REQUEST,
+ "incorrect vector dimension."
+ + " The vector value has size "
+ + inputVector.size()
+ + " while it is expected a vector with size "
+ + dimension);
+ }
+
+ if (inputVector.get(0) instanceof CharSequence) {
+ for (int i = 0; i < dimension; i++) {
+ try {
+ addStringElement(inputVector.get(i).toString());
+ } catch (NumberFormatException e) {
+ throw new SolrException(
+ SolrException.ErrorCode.BAD_REQUEST,
+ "incorrect vector element: '" + inputVector.get(i) + "'. " + errorMessage());
+ }
}
+ } else if (inputVector.get(0) instanceof Number) {
+ for (int i = 0; i < dimension; i++) {
+ addNumberElement((Number) inputVector.get(i));
+ }
+ } else {
+ throw new SolrException(
+ SolrException.ErrorCode.BAD_REQUEST, "incorrect vector format. " + errorMessage());
}
- } else if (inputVector.get(0) instanceof Number) {
- for (int i = 0; i < dimension; i++) {
- vector[i] = ((Number) inputVector.get(i)).floatValue();
+ }
+
+ protected abstract void addNumberElement(Number element);
+
+ protected abstract void addStringElement(String element);
+
+ protected abstract String errorMessage();
+
+ static class ByteVectorBuilder extends VectorBuilder {
+ private BytesRefBuilder byteRefBuilder;
+ private BytesRef byteVector;
+
+ public ByteVectorBuilder(int dimension, Object inputValue) {
+ this.dimension = dimension;
+ this.inputValue = inputValue;
+ }
+
+ @Override
+ public BytesRef getByteVector() {
+ if (byteVector == null) {
+ byteRefBuilder = new BytesRefBuilder();
+ parseVector();
+ byteVector = byteRefBuilder.toBytesRef();
+ }
+ return byteVector;
+ }
+
+ @Override
+ protected void addNumberElement(Number element) {
+ byteRefBuilder.append(element.byteValue());
+ }
+
+ @Override
+ protected void addStringElement(String element) {
+ byteRefBuilder.append(Byte.parseByte(element));
+ }
+
+ @Override
+ protected String errorMessage() {
+ return "The expected format is:'[fb,b2..b3]' where each element b is a byte (-128 to 127)";
Review Comment:
Typo? `[fb,b2..b3]` should probably be `[b1,b2..b3]`, matching the `[f1,f2..f3]` pattern used in the float error message.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]