001/*-
002 *******************************************************************************
003 * Copyright (c) 2011, 2016 Diamond Light Source Ltd.
004 * All rights reserved. This program and the accompanying materials
005 * are made available under the terms of the Eclipse Public License v1.0
006 * which accompanies this distribution, and is available at
007 * http://www.eclipse.org/legal/epl-v10.html
008 *
009 * Contributors:
010 *    Peter Chang - initial API and implementation and/or initial documentation
011 *******************************************************************************/
012
013package org.eclipse.january.dataset;
014
015import java.util.Arrays;
016import java.util.List;
017
018/**
019 * Class to run over a single dataset with NumPy broadcasting to promote shapes
020 * which have lower rank and outputs to a second dataset
021 */
022public class SingleInputBroadcastIterator extends IndexIterator {
023        private int[] maxShape;
024        private int[] aShape;
025        private final Dataset aDataset;
026        private final Dataset oDataset;
027        private int[] aStride;
028        private int[] oStride;
029
030        final private int endrank;
031
032        /**
033         * position in dataset
034         */
035        private final int[] pos;
036        private final int[] aDelta;
037        private final int[] oDelta; // this being non-null means output is different from inputs
038        private final int aStep, oStep;
039        private int aMax;
040        private int aStart, oStart;
041        private final boolean outputA;
042
043        /**
044         * Index in array
045         */
046        public int aIndex, oIndex;
047
048        /**
049         * Current value in array
050         */
051        public double aDouble;
052
053        /**
054         * Current value in array
055         */
056        public long aLong;
057
058        private boolean asDouble = true;
059
060        /**
061         * @param a
062         * @param o (can be null for new dataset, or a)
063         */
064        public SingleInputBroadcastIterator(Dataset a, Dataset o) {
065                this(a, o, false);
066        }
067
068        /**
069         * @param a
070         * @param o (can be null for new dataset, or a)
071         * @param createIfNull (by default, can create float or complex datasets)
072         */
073        public SingleInputBroadcastIterator(Dataset a, Dataset o, boolean createIfNull) {
074                this(a, o, createIfNull, false, true);
075        }
076
077        /**
078         * @param a
079         * @param o (can be null for new dataset, or a)
080         * @param createIfNull
081         * @param allowInteger if true, can create integer datasets
082         * @param allowComplex if true, can create complex datasets
083         */
084        @SuppressWarnings("deprecation")
085        public SingleInputBroadcastIterator(Dataset a, Dataset o, boolean createIfNull, boolean allowInteger, boolean allowComplex) {
086                List<int[]> fullShapes = BroadcastUtils.broadcastShapes(a.getShapeRef(), o == null ? null : o.getShapeRef());
087
088                BroadcastUtils.checkItemSize(a, o);
089
090                maxShape = fullShapes.remove(0);
091
092                oStride = null;
093                if (o != null) {
094                        if (!Arrays.equals(maxShape, o.getShapeRef())) {
095                                throw new IllegalArgumentException("Output does not match broadcasted shape");
096                        }
097                        o.setDirty();
098                }
099
100                aShape = fullShapes.remove(0);
101
102                int rank = maxShape.length;
103                endrank = rank - 1;
104
105                aDataset = a.reshape(aShape);
106                aStride = BroadcastUtils.createBroadcastStrides(aDataset, maxShape);
107                outputA = o == a;
108                if (outputA) {
109                        oStride = aStride;
110                        oDelta = null;
111                        oStep = 0;
112                        oDataset = aDataset;
113                } else if (o != null) {
114                        oStride = BroadcastUtils.createBroadcastStrides(o, maxShape);
115                        oDelta = new int[rank];
116                        oStep = o.getElementsPerItem();
117                        oDataset = o;
118                } else if (createIfNull) {
119                        int is = aDataset.getElementsPerItem();
120                        int dt = aDataset.getDType();
121                        if (aDataset.isComplex() && !allowComplex) {
122                                is = 1;
123                                dt = DTypeUtils.getBestFloatDType(dt);
124                        } else if (!aDataset.hasFloatingPointElements() && !allowInteger) {
125                                dt = DTypeUtils.getBestFloatDType(dt);
126                        }
127                        oDataset = DatasetFactory.zeros(is, maxShape, dt);
128                        oStride = BroadcastUtils.createBroadcastStrides(oDataset, maxShape);
129                        oDelta = new int[rank];
130                        oStep = oDataset.getElementsPerItem();
131                } else {
132                        oDelta = null;
133                        oStep = 0;
134                        oDataset = o;
135                }
136
137                pos = new int[rank];
138                aDelta = new int[rank];
139                aStep = aDataset.getElementsPerItem();
140                for (int j = endrank; j >= 0; j--) {
141                        aDelta[j] = aStride[j] * aShape[j];
142                        if (oDelta != null) {
143                                oDelta[j] = oStride[j] * maxShape[j];
144                        }
145                }
146                aStart = aDataset.getOffset();
147                aMax = endrank < 0 ? aStep + aStart: Integer.MIN_VALUE;
148                oStart = oDelta == null ? 0 : oDataset.getOffset();
149                asDouble = aDataset.hasFloatingPointElements();
150                reset();
151        }
152
153        /**
154         * @return true if output from iterator is double
155         */
156        public boolean isOutputDouble() {
157                return asDouble;
158        }
159
160        /**
161         * Set to output doubles
162         * @param asDouble
163         */
164        public void setOutputDouble(boolean asDouble) {
165                if (this.asDouble != asDouble) {
166                        this.asDouble = asDouble;
167                        storeCurrentValues();
168                }
169        }
170
171        @Override
172        public int[] getShape() {
173                return maxShape;
174        }
175
176        @Override
177        public boolean hasNext() {
178                int j = endrank;
179                int oldA = aIndex;
180                for (; j >= 0; j--) {
181                        pos[j]++;
182                        aIndex += aStride[j];
183                        if (oDelta != null) {
184                                oIndex += oStride[j];
185                        }
186                        if (pos[j] >= maxShape[j]) {
187                                pos[j] = 0;
188                                aIndex -= aDelta[j]; // reset these dimensions
189                                if (oDelta != null) {
190                                        oIndex -= oDelta[j];
191                                }
192                        } else {
193                                break;
194                        }
195                }
196                if (j == -1) {
197                        if (endrank >= 0) {
198                                return false;
199                        }
200                        aIndex += aStep;
201                        if (oDelta != null) {
202                                oIndex += oStep;
203                        }
204                }
205                if (outputA) {
206                        oIndex = aIndex;
207                }
208
209                if (aIndex == aMax) {
210                        return false; // used for zero-rank datasets
211                }
212
213                if (oldA != aIndex) {
214                        if (asDouble) {
215                                aDouble = aDataset.getElementDoubleAbs(aIndex);
216                        } else {
217                                aLong = aDataset.getElementLongAbs(aIndex);
218                        }
219                }
220
221                return true;
222        }
223
224        /**
225         * @return output dataset (can be null)
226         */
227        public Dataset getOutput() {
228                return oDataset;
229        }
230
231        @Override
232        public int[] getPos() {
233                return pos;
234        }
235
236        @Override
237        public void reset() {
238                for (int i = 0; i <= endrank; i++) {
239                        pos[i] = 0;
240                }
241
242                if (endrank >= 0) {
243                        pos[endrank] = -1;
244                        aIndex = aStart - aStride[endrank];
245                        oIndex = oStart - (oStride == null ? 0 : oStride[endrank]);
246                } else {
247                        aIndex = -aStep;
248                        oIndex = -oStep;
249                }
250
251                // for zero-ranked datasets
252                if (aIndex == 0) {
253                        storeCurrentValues();
254                }
255        }
256
257        private void storeCurrentValues() {
258                if (aIndex >= 0) {
259                        if (asDouble) {
260                                aDouble = aDataset.getElementDoubleAbs(aIndex);
261                        } else {
262                                aLong = aDataset.getElementLongAbs(aIndex);
263                        }
264                }
265        }
266}