36 #ifndef VIGRA_NUMPY_ARRAY_TAGGEDSHAPE_HXX
37 #define VIGRA_NUMPY_ARRAY_TAGGEDSHAPE_HXX
39 #ifndef NPY_NO_DEPRECATED_API
40 # define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
44 #include "array_vector.hxx"
45 #include "python_utility.hxx"
46 #include "axistags.hxx"
// Returns the Python type object used for constructing arrays: the attribute
// 'standardArrayType' of the Python module 'vigra'. numpy's PyArray_Type is
// passed as the third argument of pythonGetAttr, presumably the fallback when
// the module or attribute is unavailable.
// NOTE(review): lines between the import and the return are missing from this
// listing; presumably they clear the Python error when 'vigra' cannot be
// imported — confirm against the full source.
53 python_ptr getArrayTypeObject()
55 python_ptr arraytype((PyObject*)&PyArray_Type);
56 python_ptr vigra(PyImport_ImportModule(
"vigra"));
59 return pythonGetAttr(vigra,
"standardArrayType", arraytype);
// Reads the 'defaultOrder' attribute of the array type object returned by
// getArrayTypeObject(). defaultValue ("C" unless overridden) is passed to
// pythonGetAttr, presumably as the result when the attribute is missing.
63 std::string defaultOrder(std::string defaultValue =
"C")
65 python_ptr arraytype = getArrayTypeObject();
66 return pythonGetAttr(arraytype,
"defaultOrder", defaultValue);
// Builds default axistags for an ndim-dimensional array by calling the Python
// method arraytype.defaultAxistags(ndim, order) on the object returned by
// getArrayTypeObject().
// NOTE(review): the condition guarding 'order = defaultOrder()' is not visible
// in this listing; presumably it applies when 'order' is empty (the default
// argument is "") — confirm. The error check and return following the call
// are also missing from this listing.
70 python_ptr defaultAxistags(
int ndim, std::string order =
"")
73 order = defaultOrder();
74 python_ptr arraytype = getArrayTypeObject();
75 python_ptr func(pythonFromData(
"defaultAxistags"));
76 python_ptr d(pythonFromData(ndim));
77 python_ptr o(pythonFromData(order));
// The new reference returned by the call is wrapped with keep_count
// (presumably adopting it without an extra increment).
78 python_ptr axistags(PyObject_CallMethodObjArgs(arraytype, func.get(), d.get(), o.get(), NULL),
79 python_ptr::keep_count);
// Builds an "empty" axistags object for ndim axes by calling the Python
// method arraytype._empty_axistags(ndim) on the object returned by
// getArrayTypeObject().
// NOTE(review): the error check and return that follow the call are missing
// from this listing.
87 python_ptr emptyAxistags(
int ndim)
89 python_ptr arraytype = getArrayTypeObject();
90 python_ptr func(pythonFromData(
"_empty_axistags"));
91 python_ptr d(pythonFromData(ndim));
92 python_ptr axistags(PyObject_CallMethodObjArgs(arraytype, func.get(), d.get(), NULL),
93 python_ptr::keep_count);
// Calls the Python method object.<name>(type) and converts the returned
// sequence of integers into an ArrayVector<npy_intp>.
// - If the call fails and ignoreErrors is set, the branch taken is not visible
//   in this listing (presumably the error is cleared and the function returns
//   early — TODO confirm). Otherwise the Python error becomes a C++ exception
//   via pythonToCppException.
// - A non-sequence result raises a Python ValueError
//   "<name>() did not return a sequence.", converted to a C++ exception.
// - A non-integer element raises ValueError
//   "<name>() did not return a sequence of int.".
// NOTE(review): the statement that transfers 'res' into 'permute' is not
// visible here — confirm it exists after the loop.
// NOTE(review): PySequence_Length returns -1 on error, and that value is passed
// straight to the ArrayVector size constructor. PyLong_AsLong errors (overflow)
// are not checked either.
102 getAxisPermutationImpl(ArrayVector<npy_intp> & permute,
103 python_ptr
object,
const char * name,
104 AxisInfo::AxisType type,
bool ignoreErrors)
106 python_ptr func(pythonFromData(name));
107 python_ptr t(pythonFromData((
long)type));
108 python_ptr permutation(PyObject_CallMethodObjArgs(
object, func.get(), t.get(), NULL),
109 python_ptr::keep_count);
110 if(!permutation && ignoreErrors)
115 pythonToCppException(permutation);
// The result must support the sequence protocol.
117 if(!PySequence_Check(permutation))
121 std::string message = std::string(name) +
"() did not return a sequence.";
122 PyErr_SetString(PyExc_ValueError, message.c_str());
123 pythonToCppException(
false);
// Copy each element, checking that it is an integer (PyInt on Python 2,
// PyLong on Python 3; part of the Python 2 branch is missing from this listing).
126 ArrayVector<npy_intp> res(PySequence_Length(permutation));
127 for(
int k=0; k<(int)res.size(); ++k)
129 python_ptr i(PySequence_GetItem(permutation, k), python_ptr::keep_count);
130 #if PY_MAJOR_VERSION < 3
133 if (!PyLong_Check(i))
138 std::string message = std::string(name) +
"() did not return a sequence of int.";
139 PyErr_SetString(PyExc_ValueError, message.c_str());
140 pythonToCppException(
false);
142 #if PY_MAJOR_VERSION < 3
143 res[k] = PyInt_AsLong(i);
145 res[k] = PyLong_AsLong(i);
// Convenience overload: requests the permutation over all axes by forwarding
// AxisInfo::AllAxes as the axis type.
153 getAxisPermutationImpl(ArrayVector<npy_intp> & permute,
154 python_ptr
object,
const char * name,
bool ignoreErrors)
156 getAxisPermutationImpl(permute,
object, name, AxisInfo::AllAxes, ignoreErrors);
// Raw pointer type of the wrapped Python object.
175 typedef PyObject * pointer;
// Wraps an existing Python axistags object.
// - 'tags' must support the sequence protocol. Otherwise a Python TypeError
//   is set and turned into a C++ exception via pythonToCppException(false).
// - How a null 'tags' and an empty sequence are handled is not visible in this
//   listing (several lines are missing) — TODO confirm.
// - When a copy is requested (the guarding condition is not visible;
//   presumably createCopy), the stored object is the result of tags.__copy__().
179 PyAxisTags(python_ptr tags = python_ptr(),
bool createCopy =
false)
184 if(!PySequence_Check(tags))
186 PyErr_SetString(PyExc_TypeError,
187 "PyAxisTags(tags): tags argument must have type 'AxisTags'.");
188 pythonToCppException(
false);
190 else if(PySequence_Length(tags) == 0)
197 python_ptr func(pythonFromData(
"__copy__"));
198 axistags = python_ptr(PyObject_CallMethodObjArgs(tags, func.get(), NULL),
199 python_ptr::keep_count);
// Copy constructor. On one path the wrapped object is duplicated via
// other.axistags.__copy__(). On the other, the python_ptr is simply assigned,
// so both wrappers then refer to the same Python object. The conditions
// selecting the path are not visible in this listing; presumably createCopy
// chooses the __copy__ path — TODO confirm.
207 PyAxisTags(PyAxisTags
const & other,
bool createCopy =
false)
213 python_ptr func(pythonFromData(
"__copy__"));
214 axistags = python_ptr(PyObject_CallMethodObjArgs(other.axistags, func.get(), NULL),
215 python_ptr::keep_count);
219 axistags = other.axistags;
// Creates axistags for ndim axes, either via detail::defaultAxistags(ndim, order)
// or via detail::emptyAxistags(ndim). The condition choosing between them is
// not visible in this listing — TODO confirm.
223 PyAxisTags(
int ndim, std::string
const & order =
"")
226 axistags = detail::defaultAxistags(ndim, order);
228 axistags = detail::emptyAxistags(ndim);
// Fragment of PyAxisTags::size(): when the (not visible) condition holds, the
// number of axes is the Python sequence length of the wrapped object.
234 ? PySequence_Length(axistags)
// Reads the 'channelIndex' attribute of the wrapped axistags object. defaultVal
// is passed to pythonGetAttr, presumably as the result when it is unavailable.
238 long channelIndex(
long defaultVal)
const
240 return pythonGetAttr(axistags,
"channelIndex", defaultVal);
// Channel index, using size() as the default value.
243 long channelIndex()
const
245 return channelIndex(size());
// True when channelIndex() differs from size(); size() is the default
// channelIndex() falls back to, so it acts as the "no channel axis" value.
248 bool hasChannelAxis()
const
250 return channelIndex() != size();
// Reads the 'innerNonchannelIndex' attribute of the wrapped axistags object.
// defaultVal is passed to pythonGetAttr, presumably as the fallback result.
253 long innerNonchannelIndex(
long defaultVal)
const
255 return pythonGetAttr(axistags,
"innerNonchannelIndex", defaultVal);
// Inner non-channel index, using size() as the default value.
258 long innerNonchannelIndex()
const
260 return innerNonchannelIndex(size());
// Calls axistags.setChannelDescription(description) on the Python object.
// A failed call is converted to a C++ exception by pythonToCppException.
// NOTE(review): lines preceding the call (likely a guard for a null
// axistags) are missing from this listing.
263 void setChannelDescription(std::string
const & description)
267 python_ptr d(pythonFromData(description));
268 python_ptr func(pythonFromData(
"setChannelDescription"));
269 python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), d.get(), NULL),
270 python_ptr::keep_count);
271 pythonToCppException(res);
// Returns axistags.resolution(index) as a double.
// A failed call is converted to a C++ exception. A non-float result raises a
// Python TypeError "AxisTags.resolution() did not return float.", which is
// then converted as well.
// NOTE(review): lines preceding the call (likely a null-axistags guard) are
// missing from this listing.
274 double resolution(
long index)
278 python_ptr func(pythonFromData(
"resolution"));
279 python_ptr i(pythonFromData(index));
280 python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), i.get(), NULL),
281 python_ptr::keep_count);
282 pythonToCppException(res);
283 if(!PyFloat_Check(res))
285 PyErr_SetString(PyExc_TypeError,
"AxisTags.resolution() did not return float.");
286 pythonToCppException(
false);
288 return PyFloat_AsDouble(res);
// Calls axistags.setResolution(index, resolution) on the Python object.
// Errors are converted to C++ exceptions.
291 void setResolution(
long index,
double resolution)
295 python_ptr func(pythonFromData(
"setResolution"));
296 python_ptr i(pythonFromData(index));
297 python_ptr r(PyFloat_FromDouble(resolution), python_ptr::keep_count);
298 python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), i.get(), r.get(), NULL),
299 python_ptr::keep_count);
300 pythonToCppException(res);
// Calls axistags.scaleResolution(index, factor) on the Python object.
// Errors are converted to C++ exceptions.
303 void scaleResolution(
long index,
double factor)
307 python_ptr func(pythonFromData(
"scaleResolution"));
308 python_ptr i(pythonFromData(index));
309 python_ptr f(PyFloat_FromDouble(factor), python_ptr::keep_count);
310 python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), i.get(), f.get(), NULL),
311 python_ptr::keep_count);
312 pythonToCppException(res);
// Calls axistags.toFrequencyDomain(index, size) when sign == 1, and
// axistags.fromFrequencyDomain(index, size) for any other sign.
// Errors are converted to C++ exceptions.
315 void toFrequencyDomain(
long index,
int size,
int sign = 1)
319 python_ptr func(
sign == 1
320 ? pythonFromData(
"toFrequencyDomain")
321 : pythonFromData(
"fromFrequencyDomain"));
322 python_ptr i(pythonFromData(index));
323 python_ptr s(pythonFromData(size));
324 python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), i.get(), s.get(), NULL),
325 python_ptr::keep_count);
326 pythonToCppException(res);
// Inverse of toFrequencyDomain: forwards with sign = -1.
329 void fromFrequencyDomain(
long index,
int size)
331 toFrequencyDomain(index, size, -1);
// Permutation returned by the Python method permutationToNormalOrder, queried
// over all axes (see detail::getAxisPermutationImpl). The return statement is
// not visible in this listing.
334 ArrayVector<npy_intp>
335 permutationToNormalOrder(
bool ignoreErrors =
false)
const
337 ArrayVector<npy_intp> permute;
338 detail::getAxisPermutationImpl(permute, axistags,
"permutationToNormalOrder", ignoreErrors);
// Same as above, but the permutation is requested only for axes of the given
// AxisInfo::AxisType.
342 ArrayVector<npy_intp>
343 permutationToNormalOrder(AxisInfo::AxisType types,
bool ignoreErrors =
false)
const
345 ArrayVector<npy_intp> permute;
346 detail::getAxisPermutationImpl(permute, axistags,
347 "permutationToNormalOrder", types, ignoreErrors);
// Permutation returned by the Python method permutationFromNormalOrder,
// queried over all axes. The return statement is not visible in this listing.
351 ArrayVector<npy_intp>
352 permutationFromNormalOrder(
bool ignoreErrors =
false)
const
354 ArrayVector<npy_intp> permute;
355 detail::getAxisPermutationImpl(permute, axistags,
356 "permutationFromNormalOrder", ignoreErrors);
// Same as above, restricted to axes of the given AxisInfo::AxisType.
360 ArrayVector<npy_intp>
361 permutationFromNormalOrder(AxisInfo::AxisType types,
bool ignoreErrors =
false)
const
363 ArrayVector<npy_intp> permute;
364 detail::getAxisPermutationImpl(permute, axistags,
365 "permutationFromNormalOrder", types, ignoreErrors);
// Calls axistags.dropChannelAxis() on the Python object (mutating it in place).
// Errors are converted to C++ exceptions.
369 void dropChannelAxis()
373 python_ptr func(pythonFromData(
"dropChannelAxis"));
374 python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), NULL),
375 python_ptr::keep_count);
376 pythonToCppException(res);
// Calls axistags.insertChannelAxis() on the Python object (mutating it in place).
// Errors are converted to C++ exceptions.
379 void insertChannelAxis()
383 python_ptr func(pythonFromData(
"insertChannelAxis"));
384 python_ptr res(PyObject_CallMethodObjArgs(axistags, func.get(), NULL),
385 python_ptr::keep_count);
386 pythonToCppException(res);
// Fragment of an accessor whose signature is not visible here: returns the raw
// PyObject pointer held by the wrapped python_ptr.
391 return axistags.get();
// Negation test. The body is missing from this listing; presumably it reports
// whether no axistags object is held — TODO confirm.
394 bool operator!()
const
// Position of the channel axis within 'shape': first, last, or absent (none).
409 enum ChannelAxis { first, last, none };
// 'shape' is the working shape. 'original_shape' is initialized from the same
// input and compared against 'shape' by scaleAxisResolution().
411 ArrayVector<npy_intp> shape, original_shape;
413 ChannelAxis channelAxis;
// If non-empty, applied to the axistags by finalizeTaggedShape().
414 std::string channelDescription;
// Constructors from a TinyVector or an ArrayVector shape plus axistags. Both
// 'shape' and 'original_shape' start as copies of 'sh'. The remainder of the
// initializer lists (and the template header for T) is not visible here.
422 template <
class U,
int N>
423 TaggedShape(TinyVector<U, N>
const & sh, PyAxisTags tags)
424 : shape(sh.begin(), sh.end()),
425 original_shape(sh.begin(), sh.end()),
431 TaggedShape(ArrayVector<T>
const & sh, PyAxisTags tags)
432 : shape(sh.begin(), sh.end()),
433 original_shape(sh.begin(), sh.end()),
// Explicit constructors from a shape alone (no axistags). As above, 'shape'
// and 'original_shape' start as copies of 'sh'. The rest of each initializer
// list is not visible in this listing.
438 template <
class U,
int N>
439 explicit TaggedShape(TinyVector<U, N>
const & sh)
440 : shape(sh.begin(), sh.end()),
441 original_shape(sh.begin(), sh.end()),
446 explicit TaggedShape(ArrayVector<T>
const & sh)
447 : shape(sh.begin(), sh.end()),
448 original_shape(sh.begin(), sh.end()),
// Overwrites the non-channel extents [start, stop) of 'shape' with 'sh'.
// [start, stop) depend on channelAxis (the ternary branches are not visible
// here). Requires N to match the number of non-channel axes unless the shape
// is currently empty.
// NOTE(review): the lines that handle size() == 0 (presumably resizing
// 'shape') are missing from this listing — confirm shape[k+start] is in range
// in that case.
452 template <
class U,
int N>
453 TaggedShape & resize(TinyVector<U, N>
const & sh)
455 int start = channelAxis == first
458 stop = channelAxis == last
462 vigra_precondition(N == stop - start || size() == 0,
463 "TaggedShape.resize(): size mismatch.");
468 for(
int k=0; k<N; ++k)
469 shape[k+start] = sh[k];
// Bodies of the 1- to 4-argument resize convenience overloads (signatures not
// visible): each packs its extents into a TinyVector and delegates to the
// TinyVector resize above.
476 return resize(TinyVector<MultiArrayIndex, 1>(v1));
481 return resize(TinyVector<MultiArrayIndex, 2>(v1, v2));
486 return resize(TinyVector<MultiArrayIndex, 3>(v1, v2, v3));
492 return resize(TinyVector<MultiArrayIndex, 4>(v1, v2, v3, v4));
// Element access to the shape (mutable and const) and the number of axes.
// Bodies are not visible in this listing; presumably they forward to 'shape'.
495 npy_intp & operator[](
int i)
500 npy_intp operator[](
int i)
const
505 unsigned int size()
const
// Fragments of two members whose signatures and loop bodies are not visible in
// this listing. Each computes a [start, stop) range from channelAxis, in the
// same way as resize(), and iterates over it — presumably to update the
// non-channel extents. TODO confirm against the full source.
512 int start = channelAxis == first
515 stop = channelAxis == last
518 for(
int k=start; k<stop; ++k)
531 int start = channelAxis == first
534 stop = channelAxis == last
537 for(
int k=start; k<stop; ++k)
// When axistags are present and the channel axis is last, rotates both
// 'shape' and 'original_shape' right by one, so the last extent (the channel
// count) moves to index 0.
// NOTE(review): lines after the rotation (presumably updating channelAxis to
// 'first') are missing from this listing.
543 void rotateToNormalOrder()
545 if(axistags && channelAxis == last)
547 int ndim = (int)size();
549 npy_intp channelCount = shape[ndim-1];
550 for(
int k=ndim-1; k>0; --k)
551 shape[k] = shape[k-1];
552 shape[0] = channelCount;
// Same rotation applied to the original shape.
554 channelCount = original_shape[ndim-1];
555 for(
int k=ndim-1; k>0; --k)
556 original_shape[k] = original_shape[k-1];
557 original_shape[0] = channelCount;
// Stores the channel description. It is applied to the axistags later, in
// finalizeTaggedShape().
563 TaggedShape & setChannelDescription(std::string
const & description)
567 channelDescription = description;
// Body not visible in this listing; presumably marks the channel axis as last.
571 TaggedShape & setChannelIndexLast()
// Permutes the non-channel extents by p: original_shape[k] = shape[p[k]], then
// 'shape' is replaced by 'original_shape'.
// With axistags, resolutions are permuted the same way, mapping through
// permutationToNormalOrder() and skipping the channel tag/extent when present.
// Requires N to equal the number of non-channel tags.
// The branch without axistags (second loop) is selected by a condition that is
// not visible in this listing.
// NOTE(review): newAxistags is built from axistags.axistags with createCopy
// left at its default (false). If that shares the same Python object, the loop
// reads resolutions it has already overwritten — confirm whether a copy
// (createCopy = true) is intended.
579 template <
class U,
int N>
580 TaggedShape & transposeShape(TinyVector<U, N>
const & p)
584 int ntags = axistags.size();
585 ArrayVector<npy_intp> permute = axistags.permutationToNormalOrder();
587 int tstart = (axistags.channelIndex(ntags) < ntags)
590 int sstart = (channelAxis == first)
593 int ndim = ntags - tstart;
595 vigra_precondition(N == ndim,
596 "TaggedShape.transposeShape(): size mismatch.");
598 PyAxisTags newAxistags(axistags.axistags);
599 for(
int k=0; k<ndim; ++k)
601 original_shape[k+sstart] = shape[p[k]+sstart];
602 newAxistags.setResolution(permute[k+tstart], axistags.resolution(permute[p[k]+tstart]));
604 axistags = newAxistags;
608 for(
int k=0; k<N; ++k)
610 original_shape[k] = shape[p[k]];
613 shape = original_shape;
// Converts each non-channel axis tag to (sign == 1) or from (otherwise) the
// frequency domain, passing the axis extent from 'shape'. Tags are addressed
// through permutationToNormalOrder(), skipping the channel tag when present.
// NOTE(review): shape[] holds npy_intp, but PyAxisTags::toFrequencyDomain
// takes 'int size', so extents above INT_MAX would be narrowed.
618 TaggedShape & toFrequencyDomain(
int sign = 1)
622 int ntags = axistags.size();
624 ArrayVector<npy_intp> permute = axistags.permutationToNormalOrder();
626 int tstart = (axistags.channelIndex(ntags) < ntags)
629 int sstart = (channelAxis == first)
632 int send = (channelAxis == last)
635 int size = send - sstart;
637 for(
int k=0; k<size; ++k)
639 axistags.toFrequencyDomain(permute[k+tstart], shape[k+sstart],
sign);
// True unless the shape was marked as having no channel axis.
645 bool hasChannelAxis()
const
647 return channelAxis !=none;
// Inverse of toFrequencyDomain(): forwards with sign = -1.
650 TaggedShape & fromFrequencyDomain()
652 return toFrequencyDomain(-1);
// Two tagged shapes are compatible when:
// - their channel counts match,
// - they have the same number of non-channel axes, and
// - those non-channel extents are equal in order.
// Each shape's non-channel range is derived from its own channelAxis.
// The early-return statements are not visible in this listing.
655 bool compatible(TaggedShape
const & other)
const
657 if(channelCount() != other.channelCount())
660 int start = channelAxis == first
663 stop = channelAxis == last
666 int ostart = other.channelAxis == first
669 ostop = other.channelAxis == last
670 ? (int)other.size()-1
673 int len = stop - start;
674 if(len != ostop - ostart)
677 for(
int k=0; k<len; ++k)
678 if(shape[k+start] != other.shape[k+ostart])
// Adjusts the channel extent to 'count'. Depending on conditions not visible
// in this listing, this happens in one of these ways:
// - removing a leading channel entry from both shapes,
// - overwriting the last extent,
// - removing a trailing entry from original_shape, or
// - appending 'count' to both shapes.
// TODO confirm the exact branch conditions against the full source.
683 TaggedShape & setChannelCount(
int count)
694 shape.erase(shape.begin());
695 original_shape.erase(original_shape.begin());
702 shape[size()-1] = count;
707 original_shape.pop_back();
714 shape.push_back(count);
715 original_shape.push_back(count);
// Number of channels. On the visible path it is the last extent of 'shape';
// the other branches are not visible in this listing.
// NOTE(review): the npy_intp extent is returned as int.
723 int channelCount()
const
730 return shape[size()-1];
// For each non-channel axis whose extent changed between original_shape and
// shape, scales that axis's resolution by (original - 1) / (new - 1).
// Axes are mapped to tags through permutationToNormalOrder(), skipping the
// channel tag when present.
// The statement executed when the sizes differ, the definition of 'sk', and
// the statement skipping unchanged axes are not visible in this listing.
// NOTE(review): if the new extent is 1 and the original is not, the factor is
// a floating-point division by zero (inf) — confirm that case is excluded.
738 void scaleAxisResolution(TaggedShape & tagged_shape)
740 if(tagged_shape.size() != tagged_shape.original_shape.size())
743 int ntags = tagged_shape.axistags.size();
745 ArrayVector<npy_intp> permute = tagged_shape.axistags.permutationToNormalOrder();
747 int tstart = (tagged_shape.axistags.channelIndex(ntags) < ntags)
750 int sstart = (tagged_shape.channelAxis == TaggedShape::first)
753 int size = (int)tagged_shape.size() - sstart;
755 for(
int k=0; k<size; ++k)
758 if(tagged_shape.shape[sk] == tagged_shape.original_shape[sk])
760 double factor = (tagged_shape.original_shape[sk] - 1.0) / (tagged_shape.shape[sk] - 1.0);
761 tagged_shape.axistags.scaleResolution(permute[k+tstart], factor);
// Reconciles the number of shape entries with the number of axistags,
// depending on whether the shape has a channel axis and whether the tags do
// (channelIndex == ntags means no channel tag). Visible actions:
// - shape without channel axis: either checks ndim == ntags, or drops the
//   channel tag and checks again;
// - shape with channel axis but tags without: checks ndim == ntags + 1, then
//   either erases the leading shape entry or inserts a channel tag;
// - otherwise: checks ndim == ntags.
// Several branch conditions and updates are missing from this listing — TODO
// confirm against the full source.
// NOTE(review): 'axistags' is a PyAxisTags copy made with the default
// createCopy (false). The drop/insert calls therefore presumably mutate the
// Python object shared with tagged_shape.axistags — confirm this is intended.
766 void unifyTaggedShapeSize(TaggedShape & tagged_shape)
768 PyAxisTags axistags = tagged_shape.axistags;
769 ArrayVector<npy_intp> & shape = tagged_shape.shape;
771 int ndim = (int)shape.size();
772 int ntags = axistags.size();
774 long channelIndex = axistags.channelIndex();
776 if(tagged_shape.channelAxis == TaggedShape::none)
779 if(channelIndex == ntags)
783 vigra_precondition(ndim == ntags,
784 "constructArray(): size mismatch between shape and axistags.");
794 axistags.dropChannelAxis();
798 vigra_precondition(ndim == ntags,
799 "constructArray(): size mismatch between shape and axistags.");
806 if(channelIndex == ntags)
810 vigra_precondition(ndim == ntags+1,
811 "constructArray(): size mismatch between shape and axistags.");
817 shape.erase(shape.begin());
824 axistags.insertChannelAxis();
831 vigra_precondition(ndim == ntags,
832 "constructArray(): size mismatch between shape and axistags.");
// Prepares a TaggedShape for array construction. When axistags are present it
// rotates the channel axis to the front, rescales axis resolutions, and
// reconciles shape and tag counts. Any stored channel description is then
// applied to the tags, and the final shape is returned.
// NOTE(review): the braces delimiting the 'if(tagged_shape.axistags)' scope are
// missing from this listing, so which calls it guards cannot be confirmed here.
838 ArrayVector<npy_intp> finalizeTaggedShape(TaggedShape & tagged_shape)
840 if(tagged_shape.axistags)
842 tagged_shape.rotateToNormalOrder();
846 scaleAxisResolution(tagged_shape);
850 unifyTaggedShapeSize(tagged_shape);
852 if(tagged_shape.channelDescription !=
"")
853 tagged_shape.axistags.setChannelDescription(tagged_shape.channelDescription);
855 return tagged_shape.shape;
860 #endif // VIGRA_NUMPY_ARRAY_TAGGEDSHAPE_HXX
std::ptrdiff_t MultiArrayIndex
Definition: multi_fwd.hxx:60
FFTWComplex< R > & operator-=(FFTWComplex< R > &a, const FFTWComplex< R > &b)
subtract-assignment
Definition: fftw3.hxx:867
FFTWComplex< R > & operator+=(FFTWComplex< R > &a, const FFTWComplex< R > &b)
add-assignment
Definition: fftw3.hxx:859
FFTWComplex< R > & operator*=(FFTWComplex< R > &a, const FFTWComplex< R > &b)
multiply-assignment
Definition: fftw3.hxx:875
T sign(T t)
The sign function.
Definition: mathutil.hxx:591